aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar Vijay Vasudevan <vrv@google.com>2015-11-11 18:45:21 -0800
committerGravatar Vijay Vasudevan <vrv@google.com>2015-11-11 18:45:21 -0800
commitf2102f4e2c1c87f1d1bf9ab856a2849c54478760 (patch)
tree54ffdbb4081d6e75d4e626682ea9c70e6866599b /tensorflow
parent3961abed9560cd852fff4add393b451483bbc3af (diff)
TensorFlow: upstream changes from the afternoon.
Changes: - futurize --stage2 changes for Python 3 compatibility by @girving. - Small updates to documentation by @vrv, schuster and others - Account for failure of std::thread::hardware_concurrency by @ebrevdo. - More changes for backwards-compatibility tests by Josh - Updates to python op doc generation by Josh - Added support for using the best-fit allocator via ConfigProto by @vrv. - Rename LocalSession to DirectSession, since local was a bad name for it. - Enable tf.nn.moments() to work with tensors of unknown shape by @mrry. GITHUB_ISSUE: 139 - Changes for Android build by Andrew. Base CL: 107645181
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/__init__.py4
-rw-r--r--tensorflow/cc/BUILD1
-rw-r--r--tensorflow/core/BUILD18
-rw-r--r--tensorflow/core/common_runtime/direct_session.cc (renamed from tensorflow/core/common_runtime/local_session.cc)42
-rw-r--r--tensorflow/core/common_runtime/direct_session.h (renamed from tensorflow/core/common_runtime/local_session.h)14
-rw-r--r--tensorflow/core/common_runtime/direct_session_test.cc (renamed from tensorflow/core/common_runtime/local_session_test.cc)20
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_device.cc11
-rw-r--r--tensorflow/core/common_runtime/gpu/process_state.cc11
-rw-r--r--tensorflow/core/common_runtime/gpu/process_state.h7
-rw-r--r--tensorflow/core/common_runtime/rendezvous_mgr.cc6
-rw-r--r--tensorflow/core/common_runtime/session.cc2
-rw-r--r--tensorflow/core/framework/attr_value_util.cc44
-rw-r--r--tensorflow/core/framework/config.proto10
-rw-r--r--tensorflow/core/framework/node_def_builder_test.cc33
-rw-r--r--tensorflow/core/framework/node_def_util_test.cc2
-rw-r--r--tensorflow/core/framework/op.cc6
-rw-r--r--tensorflow/core/framework/op_compatibility_test.cc306
-rw-r--r--tensorflow/core/framework/op_def_util_test.cc27
-rw-r--r--tensorflow/core/framework/register_types.h6
-rw-r--r--tensorflow/core/framework/types.cc6
-rw-r--r--tensorflow/core/graph/optimizer_cse.cc2
-rw-r--r--tensorflow/core/kernels/aggregate_ops.cc4
-rw-r--r--tensorflow/core/kernels/cwise_op_abs.cc2
-rw-r--r--tensorflow/core/kernels/cwise_ops_common.h6
-rw-r--r--tensorflow/core/kernels/lrn_op.cc8
-rw-r--r--tensorflow/core/platform/posix/port.cc5
-rw-r--r--tensorflow/core/platform/test.cc3
-rw-r--r--tensorflow/core/platform/test_main.cc3
-rw-r--r--tensorflow/core/platform/tracing.h2
-rw-r--r--tensorflow/g3doc/api_docs/python/control_flow_ops.md15
-rw-r--r--tensorflow/g3doc/api_docs/python/framework.md6
-rw-r--r--tensorflow/g3doc/api_docs/python/index.md3
-rw-r--r--tensorflow/g3doc/api_docs/python/math_ops.md74
-rw-r--r--tensorflow/g3doc/get_started/os_setup.md5
-rw-r--r--tensorflow/g3doc/how_tos/adding_an_op/fact_test.py2
-rw-r--r--tensorflow/g3doc/how_tos/adding_an_op/index.md11
-rw-r--r--tensorflow/g3doc/how_tos/adding_an_op/zero_out_1_test.py4
-rw-r--r--tensorflow/g3doc/how_tos/adding_an_op/zero_out_2_test.py4
-rw-r--r--tensorflow/g3doc/how_tos/adding_an_op/zero_out_3_test.py4
-rw-r--r--tensorflow/g3doc/how_tos/adding_an_op/zero_out_grad_2.py4
-rw-r--r--tensorflow/g3doc/how_tos/reading_data/convert_to_records.py2
-rw-r--r--tensorflow/g3doc/how_tos/reading_data/fully_connected_preloaded.py3
-rw-r--r--tensorflow/g3doc/how_tos/reading_data/fully_connected_preloaded_var.py3
-rw-r--r--tensorflow/g3doc/how_tos/reading_data/fully_connected_reader.py2
-rw-r--r--tensorflow/g3doc/tutorials/mnist/beginners/index.md11
-rw-r--r--tensorflow/g3doc/tutorials/mnist/fully_connected_feed.py10
-rw-r--r--tensorflow/g3doc/tutorials/mnist/input_data.py4
-rw-r--r--tensorflow/g3doc/tutorials/mnist/mnist.py206
-rw-r--r--tensorflow/g3doc/tutorials/mnist/mnist_softmax.py2
-rw-r--r--tensorflow/g3doc/tutorials/word2vec/word2vec_basic.py10
-rw-r--r--tensorflow/models/embedding/word2vec.py3
-rw-r--r--tensorflow/models/embedding/word2vec_optimized.py3
-rw-r--r--tensorflow/models/embedding/word2vec_optimized_test.py4
-rw-r--r--tensorflow/models/embedding/word2vec_test.py4
-rw-r--r--tensorflow/models/image/alexnet/alexnet_benchmark.py8
-rw-r--r--tensorflow/models/image/cifar10/cifar10.py6
-rw-r--r--tensorflow/models/image/cifar10/cifar10_eval.py5
-rw-r--r--tensorflow/models/image/cifar10/cifar10_input.py4
-rw-r--r--tensorflow/models/image/cifar10/cifar10_input_test.py4
-rw-r--r--tensorflow/models/image/cifar10/cifar10_multi_gpu_train.py8
-rw-r--r--tensorflow/models/image/cifar10/cifar10_train.py7
-rw-r--r--tensorflow/models/image/mnist/BUILD8
-rw-r--r--tensorflow/models/image/mnist/convolutional.py13
-rw-r--r--tensorflow/models/rnn/__init__.py4
-rw-r--r--tensorflow/models/rnn/linear.py4
-rw-r--r--tensorflow/models/rnn/linear_test.py4
-rw-r--r--tensorflow/models/rnn/ptb/ptb_word_lm.py6
-rw-r--r--tensorflow/models/rnn/ptb/reader.py11
-rw-r--r--tensorflow/models/rnn/ptb/reader_test.py4
-rw-r--r--tensorflow/models/rnn/rnn.py4
-rw-r--r--tensorflow/models/rnn/rnn_cell.py29
-rw-r--r--tensorflow/models/rnn/rnn_cell_test.py5
-rw-r--r--tensorflow/models/rnn/rnn_test.py4
-rw-r--r--tensorflow/models/rnn/seq2seq.py4
-rw-r--r--tensorflow/models/rnn/seq2seq_test.py6
-rw-r--r--tensorflow/models/rnn/translate/BUILD4
-rw-r--r--tensorflow/models/rnn/translate/data_utils.py2
-rw-r--r--tensorflow/models/rnn/translate/seq2seq_model.py5
-rw-r--r--tensorflow/models/rnn/translate/translate.py3
-rw-r--r--tensorflow/python/BUILD6
-rw-r--r--tensorflow/python/__init__.py4
-rw-r--r--tensorflow/python/client/client_lib.py4
-rw-r--r--tensorflow/python/client/events_writer_test.py4
-rw-r--r--tensorflow/python/client/graph_util.py4
-rw-r--r--tensorflow/python/client/graph_util_test.py4
-rw-r--r--tensorflow/python/client/notebook.py3
-rw-r--r--tensorflow/python/client/session.py6
-rw-r--r--tensorflow/python/client/session_test.py5
-rw-r--r--tensorflow/python/framework/device.py4
-rw-r--r--tensorflow/python/framework/device_test.py4
-rw-r--r--tensorflow/python/framework/docs.py22
-rw-r--r--tensorflow/python/framework/errors.py4
-rw-r--r--tensorflow/python/framework/errors_test.py4
-rw-r--r--tensorflow/python/framework/framework_lib.py4
-rw-r--r--tensorflow/python/framework/gen_docs_combined.py5
-rw-r--r--tensorflow/python/framework/importer.py6
-rw-r--r--tensorflow/python/framework/importer_test.py4
-rw-r--r--tensorflow/python/framework/op_def_registry.py4
-rw-r--r--tensorflow/python/framework/ops.py55
-rw-r--r--tensorflow/python/framework/ops_test.py15
-rw-r--r--tensorflow/python/framework/random_seed.py4
-rw-r--r--tensorflow/python/framework/registry.py4
-rw-r--r--tensorflow/python/framework/registry_test.py4
-rw-r--r--tensorflow/python/framework/tensor_shape.py21
-rw-r--r--tensorflow/python/framework/tensor_shape_test.py18
-rw-r--r--tensorflow/python/framework/tensor_util.py17
-rw-r--r--tensorflow/python/framework/tensor_util_test.py4
-rw-r--r--tensorflow/python/framework/test_util.py3
-rw-r--r--tensorflow/python/framework/test_util_test.py6
-rw-r--r--tensorflow/python/framework/types.py11
-rw-r--r--tensorflow/python/framework/types_test.py17
-rw-r--r--tensorflow/python/kernel_tests/argmax_op_test.py4
-rw-r--r--tensorflow/python/kernel_tests/array_ops_test.py4
-rw-r--r--tensorflow/python/kernel_tests/attention_ops_test.py16
-rw-r--r--tensorflow/python/kernel_tests/batch_matmul_op_test.py3
-rw-r--r--tensorflow/python/kernel_tests/bcast_ops_test.py4
-rw-r--r--tensorflow/python/kernel_tests/bias_op_test.py3
-rw-r--r--tensorflow/python/kernel_tests/candidate_sampler_ops_test.py4
-rw-r--r--tensorflow/python/kernel_tests/cast_op_test.py3
-rw-r--r--tensorflow/python/kernel_tests/cholesky_op_test.py5
-rw-r--r--tensorflow/python/kernel_tests/clip_ops_test.py4
-rw-r--r--tensorflow/python/kernel_tests/concat_op_test.py4
-rw-r--r--tensorflow/python/kernel_tests/constant_op_test.py4
-rw-r--r--tensorflow/python/kernel_tests/control_flow_ops_py_test.py7
-rw-r--r--tensorflow/python/kernel_tests/conv_ops_test.py11
-rw-r--r--tensorflow/python/kernel_tests/cwise_ops_test.py70
-rw-r--r--tensorflow/python/kernel_tests/decode_csv_op_test.py4
-rw-r--r--tensorflow/python/kernel_tests/decode_raw_op_test.py4
-rw-r--r--tensorflow/python/kernel_tests/dense_update_ops_no_tsan_test.py4
-rw-r--r--tensorflow/python/kernel_tests/dense_update_ops_test.py4
-rw-r--r--tensorflow/python/kernel_tests/determinant_op_test.py4
-rw-r--r--tensorflow/python/kernel_tests/diag_op_test.py4
-rw-r--r--tensorflow/python/kernel_tests/division_future_test.py50
-rw-r--r--tensorflow/python/kernel_tests/division_past_test.py50
-rw-r--r--tensorflow/python/kernel_tests/dynamic_partition_op_test.py5
-rw-r--r--tensorflow/python/kernel_tests/dynamic_stitch_op_test.py4
-rw-r--r--tensorflow/python/kernel_tests/edit_distance_op_test.py4
-rw-r--r--tensorflow/python/kernel_tests/embedding_ops_test.py14
-rw-r--r--tensorflow/python/kernel_tests/fifo_queue_test.py17
-rw-r--r--tensorflow/python/kernel_tests/gather_op_test.py4
-rw-r--r--tensorflow/python/kernel_tests/gradient_checker.py4
-rw-r--r--tensorflow/python/kernel_tests/gradient_checker_test.py3
-rw-r--r--tensorflow/python/kernel_tests/identity_op_py_test.py4
-rw-r--r--tensorflow/python/kernel_tests/in_topk_op_test.py4
-rw-r--r--tensorflow/python/kernel_tests/init_ops_test.py4
-rw-r--r--tensorflow/python/kernel_tests/io_ops_test.py4
-rw-r--r--tensorflow/python/kernel_tests/linalg_grad_test.py3
-rw-r--r--tensorflow/python/kernel_tests/listdiff_op_test.py13
-rw-r--r--tensorflow/python/kernel_tests/logging_ops_test.py4
-rw-r--r--tensorflow/python/kernel_tests/lookup_table_op_test.py4
-rw-r--r--tensorflow/python/kernel_tests/lrn_op_test.py3
-rw-r--r--tensorflow/python/kernel_tests/matmul_op_test.py3
-rw-r--r--tensorflow/python/kernel_tests/matrix_inverse_op_test.py4
-rw-r--r--tensorflow/python/kernel_tests/numerics_test.py4
-rw-r--r--tensorflow/python/kernel_tests/pack_op_test.py4
-rw-r--r--tensorflow/python/kernel_tests/pad_op_test.py4
-rw-r--r--tensorflow/python/kernel_tests/parsing_ops_test.py6
-rw-r--r--tensorflow/python/kernel_tests/pooling_ops_test.py3
-rw-r--r--tensorflow/python/kernel_tests/random_ops_test.py3
-rw-r--r--tensorflow/python/kernel_tests/random_shuffle_queue_test.py9
-rw-r--r--tensorflow/python/kernel_tests/reader_ops_test.py4
-rw-r--r--tensorflow/python/kernel_tests/reduction_ops_test.py4
-rw-r--r--tensorflow/python/kernel_tests/relu_op_test.py3
-rw-r--r--tensorflow/python/kernel_tests/reshape_op_test.py3
-rw-r--r--tensorflow/python/kernel_tests/reverse_sequence_op_test.py3
-rw-r--r--tensorflow/python/kernel_tests/save_restore_ops_test.py4
-rw-r--r--tensorflow/python/kernel_tests/scatter_ops_test.py4
-rw-r--r--tensorflow/python/kernel_tests/segment_reduction_ops_test.py11
-rw-r--r--tensorflow/python/kernel_tests/shape_ops_test.py5
-rw-r--r--tensorflow/python/kernel_tests/slice_op_test.py5
-rw-r--r--tensorflow/python/kernel_tests/softmax_op_test.py4
-rw-r--r--tensorflow/python/kernel_tests/softplus_op_test.py3
-rw-r--r--tensorflow/python/kernel_tests/sparse_concat_op_test.py4
-rw-r--r--tensorflow/python/kernel_tests/sparse_matmul_op_test.py2
-rw-r--r--tensorflow/python/kernel_tests/sparse_reorder_op_test.py4
-rw-r--r--tensorflow/python/kernel_tests/sparse_to_dense_op_py_test.py4
-rw-r--r--tensorflow/python/kernel_tests/sparsemask_op_test.py4
-rw-r--r--tensorflow/python/kernel_tests/split_op_test.py6
-rw-r--r--tensorflow/python/kernel_tests/string_to_hash_bucket_op_test.py4
-rw-r--r--tensorflow/python/kernel_tests/string_to_number_op_test.py4
-rw-r--r--tensorflow/python/kernel_tests/summary_image_op_test.py7
-rw-r--r--tensorflow/python/kernel_tests/summary_ops_test.py4
-rw-r--r--tensorflow/python/kernel_tests/topk_op_test.py4
-rw-r--r--tensorflow/python/kernel_tests/transpose_op_test.py4
-rw-r--r--tensorflow/python/kernel_tests/unique_op_test.py4
-rw-r--r--tensorflow/python/kernel_tests/unpack_op_test.py5
-rw-r--r--tensorflow/python/kernel_tests/variable_ops_test.py4
-rw-r--r--tensorflow/python/kernel_tests/variable_scope_test.py4
-rw-r--r--tensorflow/python/kernel_tests/variables_test.py4
-rw-r--r--tensorflow/python/kernel_tests/where_op_test.py4
-rw-r--r--tensorflow/python/kernel_tests/xent_op_test.py3
-rw-r--r--tensorflow/python/lib/core/pywrap_status_test.py4
-rw-r--r--tensorflow/python/lib/io/python_io.py4
-rw-r--r--tensorflow/python/lib/io/tf_record.py4
-rw-r--r--tensorflow/python/ops/array_grad.py4
-rw-r--r--tensorflow/python/ops/array_ops.py10
-rw-r--r--tensorflow/python/ops/attention_ops.py4
-rw-r--r--tensorflow/python/ops/candidate_sampling_ops.py4
-rw-r--r--tensorflow/python/ops/clip_ops.py3
-rw-r--r--tensorflow/python/ops/common_shapes.py23
-rw-r--r--tensorflow/python/ops/constant_op.py9
-rw-r--r--tensorflow/python/ops/control_flow_grad.py5
-rw-r--r--tensorflow/python/ops/control_flow_ops.py12
-rw-r--r--tensorflow/python/ops/control_flow_ops_test.py4
-rw-r--r--tensorflow/python/ops/data_flow_grad.py8
-rw-r--r--tensorflow/python/ops/data_flow_ops.py4
-rw-r--r--tensorflow/python/ops/embedding_ops.py8
-rw-r--r--tensorflow/python/ops/gradients.py7
-rw-r--r--tensorflow/python/ops/gradients_test.py4
-rw-r--r--tensorflow/python/ops/image_ops.py12
-rw-r--r--tensorflow/python/ops/image_ops_test.py7
-rw-r--r--tensorflow/python/ops/init_ops.py10
-rw-r--r--tensorflow/python/ops/io_ops.py4
-rw-r--r--tensorflow/python/ops/linalg_grad.py4
-rw-r--r--tensorflow/python/ops/linalg_ops.py4
-rw-r--r--tensorflow/python/ops/logging_ops.py4
-rw-r--r--tensorflow/python/ops/math_grad.py9
-rw-r--r--tensorflow/python/ops/math_ops.py109
-rw-r--r--tensorflow/python/ops/math_ops_test.py4
-rw-r--r--tensorflow/python/ops/nn.py23
-rw-r--r--tensorflow/python/ops/nn_grad.py4
-rw-r--r--tensorflow/python/ops/nn_ops.py5
-rw-r--r--tensorflow/python/ops/nn_test.py39
-rw-r--r--tensorflow/python/ops/numerics.py4
-rw-r--r--tensorflow/python/ops/op_def_library.py5
-rw-r--r--tensorflow/python/ops/op_def_library_test.py4
-rw-r--r--tensorflow/python/ops/parsing_ops.py4
-rw-r--r--tensorflow/python/ops/random_ops.py5
-rw-r--r--tensorflow/python/ops/sparse_grad.py4
-rw-r--r--tensorflow/python/ops/sparse_ops.py7
-rw-r--r--tensorflow/python/ops/sparse_ops_test.py4
-rw-r--r--tensorflow/python/ops/standard_ops.py4
-rw-r--r--tensorflow/python/ops/state_grad.py4
-rw-r--r--tensorflow/python/ops/state_ops.py4
-rw-r--r--tensorflow/python/ops/string_ops.py4
-rw-r--r--tensorflow/python/ops/summary_ops.py4
-rw-r--r--tensorflow/python/ops/variable_scope.py14
-rw-r--r--tensorflow/python/ops/variables.py4
-rw-r--r--tensorflow/python/platform/__init__.py3
-rw-r--r--tensorflow/python/platform/app.py3
-rw-r--r--tensorflow/python/platform/default/_app.py12
-rw-r--r--tensorflow/python/platform/default/_flags.py4
-rw-r--r--tensorflow/python/platform/default/_gfile.py4
-rw-r--r--tensorflow/python/platform/default/_googletest.py4
-rw-r--r--tensorflow/python/platform/default/_logging.py4
-rw-r--r--tensorflow/python/platform/default/_parameterized.py4
-rw-r--r--tensorflow/python/platform/default/_resource_loader.py6
-rw-r--r--tensorflow/python/platform/default/_status_bar.py4
-rw-r--r--tensorflow/python/platform/default/flags_test.py4
-rw-r--r--tensorflow/python/platform/default/gfile_test.py8
-rw-r--r--tensorflow/python/platform/default/logging_test.py4
-rw-r--r--tensorflow/python/platform/flags.py3
-rw-r--r--tensorflow/python/platform/gfile.py3
-rw-r--r--tensorflow/python/platform/googletest.py3
-rw-r--r--tensorflow/python/platform/logging.py3
-rw-r--r--tensorflow/python/platform/parameterized.py3
-rw-r--r--tensorflow/python/platform/resource_loader.py3
-rw-r--r--tensorflow/python/platform/status_bar.py3
-rw-r--r--tensorflow/python/platform/test.py4
-rw-r--r--tensorflow/python/summary/event_accumulator.py5
-rw-r--r--tensorflow/python/summary/event_accumulator_test.py9
-rw-r--r--tensorflow/python/summary/event_multiplexer.py23
-rw-r--r--tensorflow/python/summary/event_multiplexer_test.py16
-rw-r--r--tensorflow/python/summary/impl/directory_watcher.py4
-rw-r--r--tensorflow/python/summary/impl/directory_watcher_test.py4
-rw-r--r--tensorflow/python/summary/impl/event_file_loader.py2
-rw-r--r--tensorflow/python/summary/impl/event_file_loader_test.py4
-rw-r--r--tensorflow/python/summary/impl/reservoir.py6
-rw-r--r--tensorflow/python/summary/impl/reservoir_test.py13
-rw-r--r--tensorflow/python/training/adagrad.py4
-rw-r--r--tensorflow/python/training/adagrad_test.py4
-rw-r--r--tensorflow/python/training/adam.py4
-rw-r--r--tensorflow/python/training/adam_test.py4
-rw-r--r--tensorflow/python/training/coordinator.py4
-rw-r--r--tensorflow/python/training/coordinator_test.py4
-rw-r--r--tensorflow/python/training/ftrl.py4
-rw-r--r--tensorflow/python/training/ftrl_test.py4
-rw-r--r--tensorflow/python/training/gradient_descent.py4
-rw-r--r--tensorflow/python/training/gradient_descent_test.py4
-rw-r--r--tensorflow/python/training/input.py10
-rw-r--r--tensorflow/python/training/input_test.py17
-rw-r--r--tensorflow/python/training/learning_rate_decay.py4
-rw-r--r--tensorflow/python/training/learning_rate_decay_test.py8
-rw-r--r--tensorflow/python/training/momentum.py4
-rw-r--r--tensorflow/python/training/momentum_test.py5
-rw-r--r--tensorflow/python/training/moving_averages.py4
-rw-r--r--tensorflow/python/training/moving_averages_test.py5
-rw-r--r--tensorflow/python/training/optimizer.py6
-rw-r--r--tensorflow/python/training/queue_runner.py4
-rw-r--r--tensorflow/python/training/queue_runner_test.py4
-rw-r--r--tensorflow/python/training/rmsprop.py4
-rw-r--r--tensorflow/python/training/rmsprop_test.py4
-rw-r--r--tensorflow/python/training/saver.py9
-rw-r--r--tensorflow/python/training/saver_test.py4
-rw-r--r--tensorflow/python/training/summary_io.py8
-rw-r--r--tensorflow/python/training/summary_writer_test.py4
-rw-r--r--tensorflow/python/training/training.py4
-rw-r--r--tensorflow/python/training/training_ops.py4
-rw-r--r--tensorflow/python/training/training_ops_test.py10
-rw-r--r--tensorflow/python/training/training_util.py4
-rw-r--r--tensorflow/python/user_ops/user_ops.py4
-rw-r--r--tensorflow/python/util/protobuf/compare.py16
-rw-r--r--tensorflow/python/util/protobuf/compare_test.py17
-rw-r--r--tensorflow/tensorboard/components/tf-graph-board/tf-graph-board.html6
-rw-r--r--tensorflow/tensorboard/components/tf-graph-common/lib/hierarchy.ts12
-rw-r--r--tensorflow/tensorboard/components/tf-graph-common/lib/render.ts114
-rw-r--r--tensorflow/tensorboard/components/tf-graph-common/lib/scene/node.ts38
-rw-r--r--tensorflow/tensorboard/components/tf-graph-info/tf-graph-info.html10
-rw-r--r--tensorflow/tensorboard/components/tf-graph-info/tf-node-info.html79
-rw-r--r--tensorflow/tensorboard/components/tf-graph-info/tf-node-list-item.html16
-rw-r--r--tensorflow/tensorboard/components/tf-graph/tf-graph-icon.html110
-rw-r--r--tensorflow/tensorboard/components/tf-graph/tf-graph-scene.html50
-rw-r--r--tensorflow/tensorboard/components/tf-graph/tf-graph.html20
-rw-r--r--tensorflow/tensorboard/float_wrapper.py6
-rw-r--r--tensorflow/tensorboard/float_wrapper_test.py4
-rw-r--r--tensorflow/tensorboard/scripts/demo_from_server.py9
-rw-r--r--tensorflow/tensorboard/tensorboard.py2
-rw-r--r--tensorflow/tensorboard/tensorboard_handler.py9
-rw-r--r--tensorflow/tools/docker/simple_console.py12
-rw-r--r--tensorflow/tools/pip_package/setup.py4
-rw-r--r--tensorflow/tools/pip_package/simple_console.py12
321 files changed, 2608 insertions, 709 deletions
diff --git a/tensorflow/__init__.py b/tensorflow/__init__.py
index 3e28aa85ec..ddf5b30211 100644
--- a/tensorflow/__init__.py
+++ b/tensorflow/__init__.py
@@ -1,4 +1,8 @@
# Bring in all of the public TensorFlow interface into this
# module.
# pylint: disable=wildcard-import
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
from tensorflow.python import *
diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD
index 8a5bf87a29..7b8cfc56ed 100644
--- a/tensorflow/cc/BUILD
+++ b/tensorflow/cc/BUILD
@@ -71,7 +71,6 @@ cc_binary(
deps = [
":cc_ops",
"//tensorflow/core:kernels",
- "//tensorflow/core:local",
"//tensorflow/core:tensorflow",
],
)
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index c2fcfeed8c..31ec4913f8 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -70,8 +70,8 @@ tf_cuda_library(
"common_runtime/gpu/*.cc",
"common_runtime/copy_tensor.cc",
"common_runtime/gpu_device_factory.cc",
- "common_runtime/local_session.cc",
- "common_runtime/local_session.h",
+ "common_runtime/direct_session.cc",
+ "common_runtime/direct_session.h",
],
),
hdrs = glob(["public/**/*.h"]),
@@ -113,10 +113,10 @@ tf_cuda_library(
)
tf_cuda_library(
- name = "local",
+ name = "direct_session",
srcs = [
- "common_runtime/local_session.cc",
- "common_runtime/local_session.h",
+ "common_runtime/direct_session.cc",
+ "common_runtime/direct_session.h",
],
copts = tf_copts(),
cuda_deps = [
@@ -200,10 +200,10 @@ tf_cuda_library(
visibility = ["//visibility:public"],
deps = [
":core",
+ ":direct_session",
":gpu_runtime",
":kernels",
":lib",
- ":local",
],
)
@@ -300,7 +300,7 @@ tf_proto_library(
cc_api_version = 2,
go_api_version = 2,
java_api_version = 2,
- py_api_version = 2,
+ py_api_version = 2, # TODO(irving): Handle 3
visibility = ["//tensorflow:internal"],
)
@@ -402,9 +402,9 @@ tf_cc_tests(
),
deps = [
":core",
+ ":direct_session",
":kernels",
":lib",
- ":local",
":test_main",
":testlib",
"//tensorflow/cc:cc_ops",
@@ -423,8 +423,8 @@ tf_cc_tests(
],
),
deps = [
+ ":direct_session",
":kernels",
- ":local",
":test_main",
":testlib",
"//tensorflow/cc:cc_ops",
diff --git a/tensorflow/core/common_runtime/local_session.cc b/tensorflow/core/common_runtime/direct_session.cc
index 4c41e5b3b3..7727b828ed 100644
--- a/tensorflow/core/common_runtime/local_session.cc
+++ b/tensorflow/core/common_runtime/direct_session.cc
@@ -1,4 +1,4 @@
-#include "tensorflow/core/common_runtime/local_session.h"
+#include "tensorflow/core/common_runtime/direct_session.h"
#include <string>
#include <vector>
@@ -42,7 +42,7 @@ static bool InitModule(const SessionOptions& options) {
// Default to using the number of cores available in the process.
inter_op_parallelism_threads = port::NumSchedulableCPUs();
}
- LOG(INFO) << "Local session inter op parallelism threads: "
+ LOG(INFO) << "Direct session inter op parallelism threads: "
<< inter_op_parallelism_threads;
kernel_thread_pool_ = new thread::ThreadPool(options.env, "Compute",
inter_op_parallelism_threads);
@@ -92,7 +92,7 @@ void SchedClosure(std::function<void()> c) {
} // namespace
-LocalSession::LocalSession(const SessionOptions& options,
+DirectSession::DirectSession(const SessionOptions& options,
const DeviceMgr* device_mgr)
: options_(options),
device_mgr_(device_mgr),
@@ -124,7 +124,7 @@ LocalSession::LocalSession(const SessionOptions& options,
}
}
-LocalSession::~LocalSession() {
+DirectSession::~DirectSession() {
for (auto d : device_mgr_->ListDevices()) {
d->op_segment()->RemoveHold(session_handle_);
}
@@ -134,7 +134,7 @@ LocalSession::~LocalSession() {
delete cancellation_manager_;
}
-Status LocalSession::Create(const GraphDef& graph) {
+Status DirectSession::Create(const GraphDef& graph) {
mutex_lock l(graph_def_lock_);
if (graph_created_) {
return errors::AlreadyExists(
@@ -143,18 +143,18 @@ Status LocalSession::Create(const GraphDef& graph) {
return ExtendLocked(graph);
}
-Status LocalSession::Extend(const GraphDef& graph) {
+Status DirectSession::Extend(const GraphDef& graph) {
mutex_lock l(graph_def_lock_);
return ExtendLocked(graph);
}
-Status LocalSession::ExtendLocked(const GraphDef& graph) {
+Status DirectSession::ExtendLocked(const GraphDef& graph) {
graph_created_ = true; // In case this is first call
graph_def_.MergeFrom(graph);
return Status::OK();
}
-Status LocalSession::Run(const std::vector<std::pair<string, Tensor>>& inputs,
+Status DirectSession::Run(const std::vector<std::pair<string, Tensor>>& inputs,
const std::vector<string>& output_names,
const std::vector<string>& target_nodes,
std::vector<Tensor>* outputs) {
@@ -250,7 +250,7 @@ Status LocalSession::Run(const std::vector<std::pair<string, Tensor>>& inputs,
return s;
}
-Status LocalSession::GetOrCreateExecutors(
+Status DirectSession::GetOrCreateExecutors(
gtl::ArraySlice<string> inputs, gtl::ArraySlice<string> outputs,
gtl::ArraySlice<string> target_nodes,
ExecutorsAndKeys** executors_and_keys) {
@@ -358,7 +358,7 @@ Status LocalSession::GetOrCreateExecutors(
return Status::OK();
}
-void LocalSession::SaveStatefulNodes(Graph* graph) {
+void DirectSession::SaveStatefulNodes(Graph* graph) {
for (Node* n : graph->nodes()) {
if (n->op_def().is_stateful()) {
VLOG(2) << "Saving " << n->DebugString();
@@ -367,7 +367,7 @@ void LocalSession::SaveStatefulNodes(Graph* graph) {
}
}
-void LocalSession::RestoreStatefulNodes(Graph* graph) {
+void DirectSession::RestoreStatefulNodes(Graph* graph) {
for (Node* n : graph->nodes()) {
if (n->op_def().is_stateful()) {
auto iter = stateful_placements_.find(n->name());
@@ -379,7 +379,7 @@ void LocalSession::RestoreStatefulNodes(Graph* graph) {
}
}
-Status LocalSession::CreateGraphs(gtl::ArraySlice<string> feeds,
+Status DirectSession::CreateGraphs(gtl::ArraySlice<string> feeds,
gtl::ArraySlice<string> fetches,
gtl::ArraySlice<string> target_nodes,
std::unordered_map<string, Graph*>* outputs) {
@@ -422,7 +422,7 @@ Status LocalSession::CreateGraphs(gtl::ArraySlice<string> feeds,
return strings::StrCat(prefix, "/_", name_counter_++);
};
popts.get_incarnation = [](const string& name) {
- // The local session does not have changing incarnation numbers.
+ // The direct session does not have changing incarnation numbers.
// Just return '1'.
return 1;
};
@@ -476,29 +476,29 @@ Status LocalSession::CreateGraphs(gtl::ArraySlice<string> feeds,
return Status::OK();
}
-::tensorflow::Status LocalSession::Close() {
+::tensorflow::Status DirectSession::Close() {
cancellation_manager_->StartCancel();
return ::tensorflow::Status::OK();
}
-class LocalSessionFactory : public SessionFactory {
+class DirectSessionFactory : public SessionFactory {
public:
- LocalSessionFactory() {}
+ DirectSessionFactory() {}
Session* NewSession(const SessionOptions& options) override {
std::vector<Device*> devices;
DeviceFactory::AddDevices(options, "/job:localhost/replica:0/task:0",
&devices);
- return new LocalSession(options, new DeviceMgr(devices));
+ return new DirectSession(options, new DeviceMgr(devices));
}
};
-class LocalSessionRegistrar {
+class DirectSessionRegistrar {
public:
- LocalSessionRegistrar() {
- SessionFactory::Register("LOCAL_SESSION", new LocalSessionFactory());
+ DirectSessionRegistrar() {
+ SessionFactory::Register("DIRECT_SESSION", new DirectSessionFactory());
}
};
-static LocalSessionRegistrar registrar;
+static DirectSessionRegistrar registrar;
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/local_session.h b/tensorflow/core/common_runtime/direct_session.h
index 453cfdde47..6a2d58b081 100644
--- a/tensorflow/core/common_runtime/local_session.h
+++ b/tensorflow/core/common_runtime/direct_session.h
@@ -1,5 +1,5 @@
-#ifndef TENSORFLOW_COMMON_RUNTIME_LOCAL_SESSION_H_
-#define TENSORFLOW_COMMON_RUNTIME_LOCAL_SESSION_H_
+#ifndef TENSORFLOW_COMMON_RUNTIME_DIRECT_SESSION_H_
+#define TENSORFLOW_COMMON_RUNTIME_DIRECT_SESSION_H_
#include <memory>
#include <string>
@@ -22,11 +22,11 @@ namespace tensorflow {
class Device;
-class LocalSession : public Session {
+class DirectSession : public Session {
public:
// Takes ownership of 'device_mgr'.
- LocalSession(const SessionOptions& options, const DeviceMgr* device_mgr);
- ~LocalSession() override;
+ DirectSession(const SessionOptions& options, const DeviceMgr* device_mgr);
+ ~DirectSession() override;
::tensorflow::Status Create(const GraphDef& graph) override;
::tensorflow::Status Extend(const GraphDef& graph) override;
@@ -101,9 +101,9 @@ class LocalSession : public Session {
// For generating unique names.
int64 name_counter_ GUARDED_BY(mu_) = 0;
- TF_DISALLOW_COPY_AND_ASSIGN(LocalSession);
+ TF_DISALLOW_COPY_AND_ASSIGN(DirectSession);
};
} // end namespace tensorflow
-#endif // TENSORFLOW_COMMON_RUNTIME_LOCAL_SESSION_H_
+#endif // TENSORFLOW_COMMON_RUNTIME_DIRECT_SESSION_H_
diff --git a/tensorflow/core/common_runtime/local_session_test.cc b/tensorflow/core/common_runtime/direct_session_test.cc
index 9325fe44c3..d021a3a5c1 100644
--- a/tensorflow/core/common_runtime/local_session_test.cc
+++ b/tensorflow/core/common_runtime/direct_session_test.cc
@@ -1,4 +1,4 @@
-#include "tensorflow/core/common_runtime/local_session.h"
+#include "tensorflow/core/common_runtime/direct_session.h"
#include <map>
#include <string>
@@ -29,7 +29,7 @@ Session* CreateSession() {
return NewSession(options);
}
-class LocalSessionMinusAXTest : public ::testing::Test {
+class DirectSessionMinusAXTest : public ::testing::Test {
public:
void Initialize(std::initializer_list<float> a_values) {
RequireDefaultOps();
@@ -64,7 +64,7 @@ class LocalSessionMinusAXTest : public ::testing::Test {
GraphDef def_;
};
-TEST_F(LocalSessionMinusAXTest, RunSimpleNetwork) {
+TEST_F(DirectSessionMinusAXTest, RunSimpleNetwork) {
Initialize({3, 2, -1, 0});
std::unique_ptr<Session> session(CreateSession());
ASSERT_TRUE(session != nullptr);
@@ -86,7 +86,7 @@ TEST_F(LocalSessionMinusAXTest, RunSimpleNetwork) {
EXPECT_FLOAT_EQ(5.0, mat(0, 0));
}
-TEST_F(LocalSessionMinusAXTest, TestFeed) {
+TEST_F(DirectSessionMinusAXTest, TestFeed) {
Initialize({1, 2, 3, 4});
std::unique_ptr<Session> session(CreateSession());
ASSERT_TRUE(session != nullptr);
@@ -115,7 +115,7 @@ TEST_F(LocalSessionMinusAXTest, TestFeed) {
EXPECT_FLOAT_EQ(39.0, mat(1, 0));
}
-TEST_F(LocalSessionMinusAXTest, TestConcurrency) {
+TEST_F(DirectSessionMinusAXTest, TestConcurrency) {
Initialize({1, 2, 3, 4});
std::unique_ptr<Session> session(CreateSession());
ASSERT_TRUE(session != nullptr);
@@ -147,7 +147,7 @@ TEST_F(LocalSessionMinusAXTest, TestConcurrency) {
delete tp;
}
-TEST_F(LocalSessionMinusAXTest, TwoCreateCallsFails) {
+TEST_F(DirectSessionMinusAXTest, TwoCreateCallsFails) {
Initialize({1, 2, 3, 4});
std::unique_ptr<Session> session(CreateSession());
ASSERT_TRUE(session != nullptr);
@@ -157,7 +157,7 @@ TEST_F(LocalSessionMinusAXTest, TwoCreateCallsFails) {
ASSERT_FALSE(session->Create(def_).ok());
}
-TEST_F(LocalSessionMinusAXTest, ForgetToCreate) {
+TEST_F(DirectSessionMinusAXTest, ForgetToCreate) {
Initialize({1, 2, 3, 4});
std::unique_ptr<Session> session(CreateSession());
ASSERT_TRUE(session != nullptr);
@@ -166,7 +166,7 @@ TEST_F(LocalSessionMinusAXTest, ForgetToCreate) {
ASSERT_FALSE(session->Run(inputs, {y_ + ":0"}, {y_neg_}, &outputs).ok());
}
-TEST_F(LocalSessionMinusAXTest, InvalidDevice) {
+TEST_F(DirectSessionMinusAXTest, InvalidDevice) {
GraphDef def;
Graph graph(OpRegistry::Global());
@@ -203,7 +203,7 @@ TEST_F(LocalSessionMinusAXTest, InvalidDevice) {
ASSERT_OK(session->Run(inputs, output_names, {}, &outputs));
}
-TEST(LocalSessionTest, KeepsStateAcrossRunsOfSession) {
+TEST(DirectSessionTest, KeepsStateAcrossRunsOfSession) {
GraphDef def;
Graph g(OpRegistry::Global());
Node* var = test::graph::Var(&g, DT_FLOAT, TensorShape({10}));
@@ -242,7 +242,7 @@ TEST(LocalSessionTest, KeepsStateAcrossRunsOfSession) {
EXPECT_EQ(20.0, outputs[0].flat<float>()(0));
}
-TEST(LocalSessionTest, MultipleFeedTest) {
+TEST(DirectSessionTest, MultipleFeedTest) {
GraphDef def;
Graph g(OpRegistry::Global());
Node* var = test::graph::Var(&g, DT_FLOAT, TensorShape({10}));
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.cc b/tensorflow/core/common_runtime/gpu/gpu_device.cc
index 26d34645f1..65174135d8 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_device.cc
@@ -558,11 +558,12 @@ LocalDevice* BaseGPUDeviceFactory::CreateGPUDevice(
<< " numa: " << desc.numa_node() << " pci: " << desc.pci_bus_id();
ProcessState* process_state = ProcessState::singleton();
- return CreateGPUDevice(
- options, name, allocated_bytes, bus_adjacency, gpu_id,
- GetShortDeviceDescription(gpu_id, desc),
- process_state->GetGPUAllocator(gpu_id, allocated_memory),
- process_state->GetCPUAllocator(desc.numa_node()));
+ return CreateGPUDevice(options, name, allocated_bytes, bus_adjacency, gpu_id,
+ GetShortDeviceDescription(gpu_id, desc),
+ process_state->GetGPUAllocator(
+ gpu_id, allocated_memory,
+ options.config.gpu_options().allocator_type()),
+ process_state->GetCPUAllocator(desc.numa_node()));
}
static int GetMinGPUMultiprocessorCount() {
diff --git a/tensorflow/core/common_runtime/gpu/process_state.cc b/tensorflow/core/common_runtime/gpu/process_state.cc
index 70ac6130c2..474b988d2f 100644
--- a/tensorflow/core/common_runtime/gpu/process_state.cc
+++ b/tensorflow/core/common_runtime/gpu/process_state.cc
@@ -87,7 +87,8 @@ void ProcessState::SetGPUCount(int c) {
int ProcessState::GPUCount() const { return gpu_count_; }
-Allocator* ProcessState::GetGPUAllocator(int gpu_id, size_t total_bytes) {
+Allocator* ProcessState::GetGPUAllocator(int gpu_id, size_t total_bytes,
+ const string& allocator_type) {
#if GOOGLE_CUDA
mutex_lock lock(mu_);
gpu::Platform* gpu_platform = GPUMachineManager();
@@ -104,7 +105,13 @@ Allocator* ProcessState::GetGPUAllocator(int gpu_id, size_t total_bytes) {
if (gpu_allocators_[gpu_id] == nullptr) {
VisitableAllocator* gpu_allocator;
- if (FLAGS_brain_gpu_use_bfc_allocator) {
+ // Validate allocator types.
+ if (!allocator_type.empty() && allocator_type != "BFC") {
+ LOG(ERROR) << "Invalid allocator type: " << allocator_type;
+ return nullptr;
+ }
+
+ if (FLAGS_brain_gpu_use_bfc_allocator || allocator_type == "BFC") {
gpu_allocator = new GPUBFCAllocator(gpu_id, total_bytes);
} else {
gpu_allocator = new GPURegionAllocator(gpu_id, total_bytes);
diff --git a/tensorflow/core/common_runtime/gpu/process_state.h b/tensorflow/core/common_runtime/gpu/process_state.h
index 527d12c10d..79b72eb13c 100644
--- a/tensorflow/core/common_runtime/gpu/process_state.h
+++ b/tensorflow/core/common_runtime/gpu/process_state.h
@@ -63,9 +63,14 @@ class ProcessState {
// a given gpu_id creates the allocator, so only the total_bytes
// used on that first call is used.
//
+ // "Allocator type" describes the type of algorithm to use for the
+ // underlying allocator. REQURES: Must be a valid type (see
+ // config.proto for the list of supported strings.).
+ //
// REQUIRES: gpu_id must be a valid ordinal for a GPU available in the
// current system environment. Otherwise returns nullptr.
- Allocator* GetGPUAllocator(int gpu_id, size_t total_bytes);
+ Allocator* GetGPUAllocator(int gpu_id, size_t total_bytes,
+ const string& allocator_type);
Allocator* GetCUDAHostAllocator(int numa_node);
diff --git a/tensorflow/core/common_runtime/rendezvous_mgr.cc b/tensorflow/core/common_runtime/rendezvous_mgr.cc
index 111dea6d4c..7ffeed3507 100644
--- a/tensorflow/core/common_runtime/rendezvous_mgr.cc
+++ b/tensorflow/core/common_runtime/rendezvous_mgr.cc
@@ -4,8 +4,7 @@
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
-#if (!defined(PLATFORM_POSIX_ANDROID) && !defined(PLATFORM_GOOGLE_ANDROID)) && \
- (defined(PLATFORM_GOOGLE) || GOOGLE_CUDA)
+#if !defined(__ANDROID__) && (defined(PLATFORM_GOOGLE) || GOOGLE_CUDA)
#include "tensorflow/core/common_runtime/gpu/gpu_util.h"
#endif
#include "tensorflow/core/framework/types.h"
@@ -39,8 +38,7 @@ void CopyTensorBetweenDevices(const string& id, DeviceContext* send_dev_context,
done(Status::OK());
}
-#if (!defined(PLATFORM_POSIX_ANDROID) && !defined(PLATFORM_GOOGLE_ANDROID)) && \
- (defined(PLATFORM_GOOGLE) || GOOGLE_CUDA)
+#if !defined(__ANDROID__) && (defined(PLATFORM_GOOGLE) || GOOGLE_CUDA)
constexpr auto CopyTensorBetweenDevicesFunc = &GPUUtil::CopyViaDMA;
#else
constexpr auto CopyTensorBetweenDevicesFunc = &CopyTensorBetweenDevices;
diff --git a/tensorflow/core/common_runtime/session.cc b/tensorflow/core/common_runtime/session.cc
index 6d1ab5cea4..2dd5d9b3cc 100644
--- a/tensorflow/core/common_runtime/session.cc
+++ b/tensorflow/core/common_runtime/session.cc
@@ -9,7 +9,7 @@ namespace tensorflow {
namespace {
Status GetFactory(const SessionOptions& options, SessionFactory** ret) {
- string runtime_type = "LOCAL_SESSION";
+ string runtime_type = "DIRECT_SESSION";
if (!options.target.empty()) {
// Use the service based session.
runtime_type = "REMOTE_SESSION";
diff --git a/tensorflow/core/framework/attr_value_util.cc b/tensorflow/core/framework/attr_value_util.cc
index 400ef118b8..645c069b93 100644
--- a/tensorflow/core/framework/attr_value_util.cc
+++ b/tensorflow/core/framework/attr_value_util.cc
@@ -110,25 +110,25 @@ string SummarizeAttrValue(const AttrValue& attr_value) {
Status AttrValueHasType(const AttrValue& attr_value, StringPiece type) {
int num_set = 0;
-#define VALIDATE_FIELD(name, type_string, oneof_case) \
- do { \
- if (attr_value.has_list()) { \
- if (attr_value.list().name##_size() > 0) { \
- if (type != "list(" type_string ")") { \
- return errors::InvalidArgument( \
- "AttrValue had value with type list(" type_string ") when ", \
- type, " expected"); \
- } \
- ++num_set; \
- } \
- } else if (attr_value.value_case() == AttrValue::oneof_case) { \
- if (type != type_string) { \
- return errors::InvalidArgument( \
- "AttrValue had value with type " type_string " when ", type, \
- " expected"); \
- } \
- ++num_set; \
- } \
+#define VALIDATE_FIELD(name, type_string, oneof_case) \
+ do { \
+ if (attr_value.has_list()) { \
+ if (attr_value.list().name##_size() > 0) { \
+ if (type != "list(" type_string ")") { \
+ return errors::InvalidArgument( \
+ "AttrValue had value with type 'list(" type_string ")' when '", \
+ type, "' expected"); \
+ } \
+ ++num_set; \
+ } \
+ } else if (attr_value.value_case() == AttrValue::oneof_case) { \
+ if (type != type_string) { \
+ return errors::InvalidArgument( \
+ "AttrValue had value with type '" type_string "' when '", type, \
+ "' expected"); \
+ } \
+ ++num_set; \
+ } \
} while (false)
VALIDATE_FIELD(s, "string", kS);
@@ -144,7 +144,7 @@ Status AttrValueHasType(const AttrValue& attr_value, StringPiece type) {
if (attr_value.value_case() == AttrValue::kFunc) {
if (type != "func") {
return errors::InvalidArgument(
- "AttrValue had value with type 'func' when ", type, " expected");
+ "AttrValue had value with type 'func' when '", type, "' expected");
}
++num_set;
}
@@ -161,7 +161,7 @@ Status AttrValueHasType(const AttrValue& attr_value, StringPiece type) {
if (StringPiece(type).starts_with("list(") && !attr_value.has_list()) {
if (num_set) {
return errors::InvalidArgument(
- "AttrValue missing value with expected type ", type);
+ "AttrValue missing value with expected type '", type, "'");
} else {
// Indicate that we have a list, but an empty one.
++num_set;
@@ -171,7 +171,7 @@ Status AttrValueHasType(const AttrValue& attr_value, StringPiece type) {
// Okay to have an empty list, but not to be missing a non-list value.
if (num_set == 0 && !StringPiece(type).starts_with("list(")) {
return errors::InvalidArgument(
- "AttrValue missing value with expected type ", type);
+ "AttrValue missing value with expected type '", type, "'");
}
// Ref types and DT_INVALID are illegal.
diff --git a/tensorflow/core/framework/config.proto b/tensorflow/core/framework/config.proto
index f0def3d6d7..81f4a151a7 100644
--- a/tensorflow/core/framework/config.proto
+++ b/tensorflow/core/framework/config.proto
@@ -9,6 +9,16 @@ message GPUOptions {
// to pre-allocate all of the GPU memory, 0.5 means the process
// allocates ~50% of the available GPU memory.
double per_process_gpu_memory_fraction = 1;
+
+ // The type of GPU allocation strategy to use.
+ //
+ // Allowed values:
+ // "": The empty string (default) uses a system-chosen default
+ // which may change over time.
+ //
+ // "BFC": A "Best-fit with coalescing" algorithm, simplified from a
+ // version of dlmalloc.
+ string allocator_type = 2;
};
// Session configuration parameters.
diff --git a/tensorflow/core/framework/node_def_builder_test.cc b/tensorflow/core/framework/node_def_builder_test.cc
index 6fd4a8d1ed..1a2ae586b6 100644
--- a/tensorflow/core/framework/node_def_builder_test.cc
+++ b/tensorflow/core/framework/node_def_builder_test.cc
@@ -252,11 +252,12 @@ TEST_F(NodeDefBuilderTest, PolymorphicOut) {
ExpectInvalid(Builder(), "NodeDef missing attr 'T' from");
// Attr has the wrong type
- ExpectInvalid(Builder().Attr("T", {DT_INT32, DT_BOOL}),
- "AttrValue had value with type list(type) when type expected");
+ ExpectInvalid(
+ Builder().Attr("T", {DT_INT32, DT_BOOL}),
+ "AttrValue had value with type 'list(type)' when 'type' expected");
ExpectInvalid(Builder().Attr("T", 12),
- "AttrValue had value with type int when type expected");
+ "AttrValue had value with type 'int' when 'type' expected");
}
TEST_F(NodeDefBuilderTest, PolymorphicDefaultOut) {
@@ -405,8 +406,9 @@ TEST_F(NodeDefBuilderTest, OutTypeList) {
op: "OutTypeList"
attr { key: "T" value { list { } } } )proto");
- ExpectInvalid(Builder().Attr("T", DT_FLOAT),
- "AttrValue had value with type type when list(type) expected");
+ ExpectInvalid(
+ Builder().Attr("T", DT_FLOAT),
+ "AttrValue had value with type 'type' when 'list(type)' expected");
}
TEST_F(NodeDefBuilderTest, TypeListRestrict) {
@@ -447,10 +449,11 @@ TEST_F(NodeDefBuilderTest, Attr) {
// Attr has wrong type
ExpectInvalid(Builder().Attr("a", "bad"),
- "AttrValue had value with type string when int expected");
+ "AttrValue had value with type 'string' when 'int' expected");
- ExpectInvalid(Builder().Attr("a", {12}),
- "AttrValue had value with type list(int) when int expected");
+ ExpectInvalid(
+ Builder().Attr("a", {12}),
+ "AttrValue had value with type 'list(int)' when 'int' expected");
// Missing attr
ExpectInvalid(Builder(), "NodeDef missing attr 'a' from Op<");
@@ -477,7 +480,7 @@ TEST_F(NodeDefBuilderTest, AttrFloat) {
// Won't automatically cast int to float
ExpectInvalid(Builder().Attr("a", 12),
- "AttrValue had value with type int when float expected");
+ "AttrValue had value with type 'int' when 'float' expected");
}
TEST_F(NodeDefBuilderTest, AttrBoolList) {
@@ -494,7 +497,7 @@ TEST_F(NodeDefBuilderTest, AttrBoolList) {
// Won't cast int -> bool.
ExpectInvalid(Builder().Attr("a", {0}),
- "AttrValue had value with type list(int) when list(bool) "
+ "AttrValue had value with type 'list(int)' when 'list(bool)' "
"expected");
}
@@ -892,8 +895,9 @@ TEST_F(NodeDefBuilderTest, NIntsOut) {
ExpectInvalid(Builder().Attr("N", 1),
"Value for attr 'N' of 1 must be at least minimum 2");
- ExpectInvalid(Builder().Attr("N", {3}),
- "AttrValue had value with type list(int) when int expected");
+ ExpectInvalid(
+ Builder().Attr("N", {3}),
+ "AttrValue had value with type 'list(int)' when 'int' expected");
ExpectInvalid(Builder(), "NodeDef missing attr 'N' from");
}
@@ -933,8 +937,9 @@ TEST_F(NodeDefBuilderTest, NPolymorphicOut) {
ExpectInvalid(Builder().Attr("N", 1).Attr("T", DT_STRING),
"Value for attr 'N' of 1 must be at least minimum 2");
- ExpectInvalid(Builder().Attr("N", 3).Attr("T", {DT_STRING}),
- "AttrValue had value with type list(type) when type expected");
+ ExpectInvalid(
+ Builder().Attr("N", 3).Attr("T", {DT_STRING}),
+ "AttrValue had value with type 'list(type)' when 'type' expected");
}
TEST_F(NodeDefBuilderTest, NPolymorphicOutDefault) {
diff --git a/tensorflow/core/framework/node_def_util_test.cc b/tensorflow/core/framework/node_def_util_test.cc
index 71f1760a09..765fdfcb8d 100644
--- a/tensorflow/core/framework/node_def_util_test.cc
+++ b/tensorflow/core/framework/node_def_util_test.cc
@@ -85,7 +85,7 @@ TEST(NodeDefUtilTest, In) {
AddNodeAttr("T", 17, &bad);
ExpectFailure(
bad, op,
- "AttrValue had value with type int when type expected\n\t for attr "
+ "AttrValue had value with type 'int' when 'type' expected\n\t for attr "
"'T'\n\t; NodeDef: ");
// Wrong number of inputs
diff --git a/tensorflow/core/framework/op.cc b/tensorflow/core/framework/op.cc
index 15b7eab4da..90b39efbce 100644
--- a/tensorflow/core/framework/op.cc
+++ b/tensorflow/core/framework/op.cc
@@ -45,15 +45,15 @@ const OpDef* OpRegistry::LookUp(const string& op_type_name,
if (op_def == nullptr) {
status->Update(
errors::NotFound("Op type not registered '", op_type_name, "'"));
- static bool first = true;
- if (first) {
+ static bool first_unregistered = true;
+ if (first_unregistered) {
OpList op_list;
Export(true, &op_list);
LOG(INFO) << "All registered Ops:";
for (const auto& op : op_list.op()) {
LOG(INFO) << SummarizeOpDef(op);
}
- first = false;
+ first_unregistered = false;
}
}
return op_def;
diff --git a/tensorflow/core/framework/op_compatibility_test.cc b/tensorflow/core/framework/op_compatibility_test.cc
index 068130136b..7716b34ca1 100644
--- a/tensorflow/core/framework/op_compatibility_test.cc
+++ b/tensorflow/core/framework/op_compatibility_test.cc
@@ -32,7 +32,7 @@ class OpCompatibilityTest : public OpsTestBase {
return new_op_def;
}
- void Run(const OpDef& old_op_def) {
+ void ExpectSuccess(const OpDef& old_op_def) {
// Record the original signature before we change *node_def().
DataTypeVector old_in_types, old_out_types;
ASSERT_OK(InOutTypesForNode(*node_def(), old_op_def, &old_in_types,
@@ -56,6 +56,49 @@ class OpCompatibilityTest : public OpsTestBase {
}
string Result() { return GetOutput(0)->scalar<string>()(); }
+
+ void ExpectInvalid(const OpDef& old_op_def, string error) {
+ // Record the original signature before we change *node_def().
+ DataTypeVector old_in_types, old_out_types;
+ ASSERT_OK(InOutTypesForNode(*node_def(), old_op_def, &old_in_types,
+ &old_out_types));
+
+ // This should be all that is needed to get compatiblity.
+ const OpDef* new_op_def = RegisteredOpDef();
+ AddDefaultsToNodeDef(*new_op_def, node_def());
+
+ // Validate that it does not pass validation.
+ Status status = ValidateNodeDef(*node_def(), *new_op_def);
+ if (status.ok()) {
+ ADD_FAILURE() << SummarizeNodeDef(*node_def());
+ } else {
+ EXPECT_TRUE(StringPiece(status.error_message()).contains(error))
+ << status << " does not contain " << error;
+ }
+ }
+
+ void ExpectTypeMismatch(const OpDef& old_op_def) {
+ // Record the original signature before we change *node_def().
+ DataTypeVector old_in_types, old_out_types;
+ ASSERT_OK(InOutTypesForNode(*node_def(), old_op_def, &old_in_types,
+ &old_out_types));
+
+ // This should be all that is needed to get compatiblity.
+ const OpDef* new_op_def = RegisteredOpDef();
+ AddDefaultsToNodeDef(*new_op_def, node_def());
+
+ // Validate that it is valid, but with incompatible types.
+ ASSERT_OK(ValidateNodeDef(*node_def(), *new_op_def));
+
+ DataTypeVector new_in_types, new_out_types;
+ ASSERT_OK(InOutTypesForNode(*node_def(), *new_op_def, &new_in_types,
+ &new_out_types));
+ if (new_in_types == old_in_types && new_out_types == old_out_types) {
+ ADD_FAILURE() << SummarizeNodeDef(*node_def()) << "\n"
+ << DataTypeVectorString(new_in_types) << " -> "
+ << DataTypeVectorString(new_out_types);
+ }
+ }
};
// Should be compatible if the Op hasn't changed (sanity check).
@@ -79,7 +122,7 @@ TEST_F(OpCompatibilityTest, Same) {
.Input(FakeInput(3, DT_FLOAT))
.Input(FakeInput(2, DT_BOOL))
.Finalize(node_def()));
- Run(*RegisteredOpDef());
+ ExpectSuccess(*RegisteredOpDef());
EXPECT_EQ(
"same = Same[N=3, T=DT_FLOAT, TList=[DT_BOOL, DT_BOOL]](a, b, c, c:1, "
"c:2, d, d:1, d:2, e, e:1)",
@@ -95,7 +138,7 @@ TEST_F(OpCompatibilityTest, AddAttr) {
ASSERT_OK(
OpDefBuilder("AddAttr").Output("ndef: string").Finalize(&old_op_def));
ASSERT_OK(NodeDefBuilder("add_attr", &old_op_def).Finalize(node_def()));
- Run(old_op_def);
+ ExpectSuccess(old_op_def);
EXPECT_EQ("add_attr = AddAttr[a=42]()", Result());
}
@@ -112,7 +155,7 @@ TEST_F(OpCompatibilityTest, LessStrict) {
ASSERT_OK(NodeDefBuilder("less_strict", &old_op_def)
.Attr("a", "B")
.Finalize(node_def()));
- Run(old_op_def);
+ ExpectSuccess(old_op_def);
EXPECT_EQ("less_strict = LessStrict[a=\"B\"]()", Result());
}
@@ -130,7 +173,7 @@ TEST_F(OpCompatibilityTest, RemoveRestriction) {
ASSERT_OK(NodeDefBuilder("remove_restriction", &old_op_def)
.Attr("a", DT_INT32)
.Finalize(node_def()));
- Run(old_op_def);
+ ExpectSuccess(old_op_def);
EXPECT_EQ("remove_restriction = RemoveRestriction[a=DT_INT32]()", Result());
}
@@ -149,7 +192,7 @@ TEST_F(OpCompatibilityTest, AttrOrder) {
.Attr("b", true)
.Attr("a", 7)
.Finalize(node_def()));
- Run(old_op_def);
+ ExpectSuccess(old_op_def);
EXPECT_EQ("attr_order = AttrOrder[a=7, b=true]()", Result());
}
@@ -166,7 +209,7 @@ TEST_F(OpCompatibilityTest, AddDefault) {
ASSERT_OK(NodeDefBuilder("add_default", &old_op_def)
.Attr("a", 765)
.Finalize(node_def()));
- Run(old_op_def);
+ ExpectSuccess(old_op_def);
EXPECT_EQ("add_default = AddDefault[a=765]()", Result());
}
@@ -182,11 +225,11 @@ TEST_F(OpCompatibilityTest, RemoveDefault) {
.Attr("a: int = 91")
.Finalize(&old_op_def));
ASSERT_OK(NodeDefBuilder("remove_default", &old_op_def).Finalize(node_def()));
- Run(old_op_def);
+ ExpectSuccess(old_op_def);
EXPECT_EQ("remove_default = RemoveDefault[a=91]()", Result());
}
-// Should be able to make an input polymorphic.
+// Should be able to make an input/output polymorphic.
// Changing from int32 -> T (where T: type = DT_INT32 by default).
REGISTER_OP("TypePolymorphic")
.Input("a: T")
@@ -203,11 +246,11 @@ TEST_F(OpCompatibilityTest, TypePolymorphic) {
ASSERT_OK(NodeDefBuilder("type_polymorphic", &old_op_def)
.Input(FakeInput())
.Finalize(node_def()));
- Run(old_op_def);
+ ExpectSuccess(old_op_def);
EXPECT_EQ("type_polymorphic = TypePolymorphic[T=DT_INT32](a)", Result());
}
-// Should be able to make a single input into a list.
+// Should be able to make a single input/output into a list.
// Changing from int32 -> N * int32 (where N: int = 1 by default).
REGISTER_OP("MakeList")
.Input("a: N * int32")
@@ -224,11 +267,11 @@ TEST_F(OpCompatibilityTest, MakeList) {
ASSERT_OK(NodeDefBuilder("make_list", &old_op_def)
.Input(FakeInput())
.Finalize(node_def()));
- Run(old_op_def);
+ ExpectSuccess(old_op_def);
EXPECT_EQ("make_list = MakeList[N=1](a)", Result());
}
-// Should be able to make a single input into a polymorphic list.
+// Should be able to make a single input/output into a polymorphic list.
// Changing from int32 -> N * T (where N: int = 1 by default and
// T: type = DT_INT32 by default).
REGISTER_OP("MakePolyList")
@@ -247,11 +290,11 @@ TEST_F(OpCompatibilityTest, MakePolyList) {
ASSERT_OK(NodeDefBuilder("make_poly_list", &old_op_def)
.Input(FakeInput())
.Finalize(node_def()));
- Run(old_op_def);
+ ExpectSuccess(old_op_def);
EXPECT_EQ("make_poly_list = MakePolyList[N=1, T=DT_INT32](a)", Result());
}
-// Should be able to make a single input into an arbitrary list.
+// Should be able to make a single input/output into an arbitrary list.
// Changing from int32 -> T (where T: list(type) = [DT_INT32] by default).
REGISTER_OP("MakeAnyList")
.Input("a: T")
@@ -268,11 +311,11 @@ TEST_F(OpCompatibilityTest, MakeAnyList) {
ASSERT_OK(NodeDefBuilder("make_any_list", &old_op_def)
.Input(FakeInput())
.Finalize(node_def()));
- Run(old_op_def);
+ ExpectSuccess(old_op_def);
EXPECT_EQ("make_any_list = MakeAnyList[T=[DT_INT32]](a)", Result());
}
-// Should be able to make a single polymorphic input into a list of
+// Should be able to make a single polymorphic input/output into a list of
// the same type. Changing from T -> N * T (where N: int = 1 by default).
REGISTER_OP("PolyIntoList")
.Input("a: N * T")
@@ -291,11 +334,11 @@ TEST_F(OpCompatibilityTest, PolyIntoList) {
ASSERT_OK(NodeDefBuilder("poly_into_list", &old_op_def)
.Input(FakeInput(DT_INT32))
.Finalize(node_def()));
- Run(old_op_def);
+ ExpectSuccess(old_op_def);
EXPECT_EQ("poly_into_list = PolyIntoList[N=1, T=DT_INT32](a)", Result());
}
-// Should be able to change the name of an input.
+// Should be able to change the name of an input/output.
REGISTER_OP("ChangeName").Input("y: int32").Output("ndef: string");
REGISTER_KERNEL_BUILDER(Name("ChangeName").Device(DEVICE_CPU), TestKernel);
@@ -308,11 +351,11 @@ TEST_F(OpCompatibilityTest, ChangeName) {
ASSERT_OK(NodeDefBuilder("change_name", &old_op_def)
.Input(FakeInput())
.Finalize(node_def()));
- Run(old_op_def);
+ ExpectSuccess(old_op_def);
EXPECT_EQ("change_name = ChangeName[](a)", Result());
}
-// Should be able to add an input of type
+// Should be able to add an input/output of type
// N * int32 (where N: int = 0 by default).
REGISTER_OP("AddNInts")
.Input("a: N * int32")
@@ -325,11 +368,11 @@ TEST_F(OpCompatibilityTest, AddNInts) {
ASSERT_OK(
OpDefBuilder("AddNInts").Output("ndef: string").Finalize(&old_op_def));
ASSERT_OK(NodeDefBuilder("add_n_ints", &old_op_def).Finalize(node_def()));
- Run(old_op_def);
+ ExpectSuccess(old_op_def);
EXPECT_EQ("add_n_ints = AddNInts[N=0]()", Result());
}
-// Should be able to add an input of type N * T
+// Should be able to add an input/output of type N * T
// (where N: int = 0 by default, and T: type = any valid default).
REGISTER_OP("AddNSame")
.Input("a: N * T")
@@ -343,12 +386,38 @@ TEST_F(OpCompatibilityTest, AddNSame) {
ASSERT_OK(
OpDefBuilder("AddNSame").Output("ndef: string").Finalize(&old_op_def));
ASSERT_OK(NodeDefBuilder("add_n_same", &old_op_def).Finalize(node_def()));
- Run(old_op_def);
+ ExpectSuccess(old_op_def);
EXPECT_EQ("add_n_same = AddNSame[N=0, T=DT_BOOL]()", Result());
}
-// Should be able to add an input of type T
-// (where T: list(type) = [] by default).
+// Should be able to add an input/output of type N * T
+// (where N: int >= 0 = 0 by default, and T an existing type attr).
+REGISTER_OP("AddNSameAsExisting")
+ .Input("a: T")
+ .Input("b: N * T")
+ .Output("ndef: string")
+ .Attr("N: int >= 0 = 0")
+ .Attr("T: type");
+REGISTER_KERNEL_BUILDER(Name("AddNSameAsExisting").Device(DEVICE_CPU),
+ TestKernel);
+
+TEST_F(OpCompatibilityTest, AddNSameAsExisting) {
+ OpDef old_op_def;
+ ASSERT_OK(OpDefBuilder("AddNSameAsExisting")
+ .Input("a: T")
+ .Output("ndef: string")
+ .Attr("T: type")
+ .Finalize(&old_op_def));
+ ASSERT_OK(NodeDefBuilder("add_n_same_as_existing", &old_op_def)
+ .Input(FakeInput(DT_STRING))
+ .Finalize(node_def()));
+ ExpectSuccess(old_op_def);
+ EXPECT_EQ("add_n_same_as_existing = AddNSameAsExisting[N=0, T=DT_STRING](a)",
+ Result());
+}
+
+// Should be able to add an input/output of type T
+// (where T: list(type) >= 0 = [] by default).
REGISTER_OP("AddAnyList")
.Input("a: T")
.Output("ndef: string")
@@ -360,7 +429,7 @@ TEST_F(OpCompatibilityTest, AddAnyList) {
ASSERT_OK(
OpDefBuilder("AddAnyList").Output("ndef: string").Finalize(&old_op_def));
ASSERT_OK(NodeDefBuilder("add_any_list", &old_op_def).Finalize(node_def()));
- Run(old_op_def);
+ ExpectSuccess(old_op_def);
EXPECT_EQ("add_any_list = AddAnyList[T=[]]()", Result());
}
@@ -381,7 +450,7 @@ TEST_F(OpCompatibilityTest, ShorterAnyList) {
ASSERT_OK(NodeDefBuilder("shorter_any_list", &old_op_def)
.Input(FakeInput(2, DT_BOOL))
.Finalize(node_def()));
- Run(old_op_def);
+ ExpectSuccess(old_op_def);
EXPECT_EQ("shorter_any_list = ShorterAnyList[T=[DT_BOOL, DT_BOOL]](a, a:1)",
Result());
}
@@ -402,21 +471,178 @@ TEST_F(OpCompatibilityTest, ShorterSameList) {
ASSERT_OK(NodeDefBuilder("shorter_same_list", &old_op_def)
.Input(FakeInput(2))
.Finalize(node_def()));
- Run(old_op_def);
+ ExpectSuccess(old_op_def);
EXPECT_EQ("shorter_same_list = ShorterSameList[N=2](a, a:1)", Result());
}
-// TODO(josh11b): Negative tests?
-// * add attr w/out default
-// * add non-list input/output
-// * add list input/output with non-empty default
-// * change type of input/output
-// * change order of input/output
-// * change type of attr
-// * Input("foo: T").Attr("T: type") -> Input("foo: T").Attr("T: list(type)")
-
-// What about changing an attr's default? Not technically illegal, but
-// likely should be forbidden since it likely affects semantics.
+// Negative tests -------------------------------------------------------------
+
+// Can't add an attr without a default.
+REGISTER_OP("AddAttrNoDefault").Attr("a: int");
+
+TEST_F(OpCompatibilityTest, AddAttrNoDefaultFails) {
+ OpDef old_op_def;
+ ASSERT_OK(OpDefBuilder("AddAttrNoDefault").Finalize(&old_op_def));
+ ASSERT_OK(NodeDefBuilder("fails", &old_op_def).Finalize(node_def()));
+ ExpectInvalid(old_op_def, "NodeDef missing attr 'a'");
+}
+
+// Can't add a non-list input/output.
+REGISTER_OP("AddSingleInput").Input("a: int32");
+
+TEST_F(OpCompatibilityTest, AddSingleInputFails) {
+ OpDef old_op_def;
+ ASSERT_OK(OpDefBuilder("AddSingleInput").Finalize(&old_op_def));
+ ASSERT_OK(NodeDefBuilder("fails", &old_op_def).Finalize(node_def()));
+ ExpectInvalid(old_op_def,
+ "expected inputs 'int32' do not match 0 inputs specified");
+}
+
+// Can't add a list input/output without an empty default.
+
+REGISTER_OP("AddNIntsBigDefault").Input("a: N * int32").Attr("N: int = 1");
+REGISTER_OP("AddNSameBigDefault")
+ .Input("a: N * T")
+ .Attr("N: int = 1")
+ .Attr("T: type = DT_INT32");
+REGISTER_OP("AddListBigDefault")
+ .Input("a: T")
+ .Attr("T: list(type) = [DT_INT32]");
+
+TEST_F(OpCompatibilityTest, AddNIntsBigDefaultFails) {
+ OpDef old_op_def;
+ ASSERT_OK(OpDefBuilder("AddNIntsBigDefault").Finalize(&old_op_def));
+ ASSERT_OK(NodeDefBuilder("fails", &old_op_def).Finalize(node_def()));
+ ExpectInvalid(old_op_def,
+ "expected inputs 'int32' do not match 0 inputs specified");
+}
+
+TEST_F(OpCompatibilityTest, AddNSameBigDefaultFails) {
+ OpDef old_op_def;
+ ASSERT_OK(OpDefBuilder("AddNSameBigDefault").Finalize(&old_op_def));
+ ASSERT_OK(NodeDefBuilder("fails", &old_op_def).Finalize(node_def()));
+ ExpectInvalid(old_op_def,
+ "expected inputs 'int32' do not match 0 inputs specified");
+}
+
+TEST_F(OpCompatibilityTest, AddListBigDefaultFails) {
+ OpDef old_op_def;
+ ASSERT_OK(OpDefBuilder("AddListBigDefault").Finalize(&old_op_def));
+ ASSERT_OK(NodeDefBuilder("fails", &old_op_def).Finalize(node_def()));
+ ExpectInvalid(old_op_def,
+ "expected inputs 'int32' do not match 0 inputs specified");
+}
+
+// Can't change the type of an input/output.
+
+REGISTER_OP("ChangeType").Input("a: float");
+
+TEST_F(OpCompatibilityTest, ChangeTypeFails) {
+ OpDef old_op_def;
+ ASSERT_OK(OpDefBuilder("ChangeType").Input("a: int32").Finalize(&old_op_def));
+ ASSERT_OK(NodeDefBuilder("fails", &old_op_def)
+ .Input(FakeInput())
+ .Finalize(node_def()));
+ ExpectTypeMismatch(old_op_def);
+}
+
+// Can't change the order of inputs/outputs.
+
+REGISTER_OP("ChangeOrder").Input("a: float").Input("b: int32");
+
+TEST_F(OpCompatibilityTest, ChangeOrderFails) {
+ OpDef old_op_def;
+ ASSERT_OK(OpDefBuilder("ChangeOrder")
+ .Input("b: int32")
+ .Input("a: float")
+ .Finalize(&old_op_def));
+ ASSERT_OK(NodeDefBuilder("fails", &old_op_def)
+ .Input(FakeInput())
+ .Input(FakeInput())
+ .Finalize(node_def()));
+ ExpectTypeMismatch(old_op_def);
+}
+
+// Can't change the type of an attr.
+
+REGISTER_OP("ChangeAttrType").Attr("a: int");
+
+TEST_F(OpCompatibilityTest, ChangeAttrTypeFails) {
+ OpDef old_op_def;
+ ASSERT_OK(
+ OpDefBuilder("ChangeAttrType").Attr("a: bool").Finalize(&old_op_def));
+ ASSERT_OK(NodeDefBuilder("fails", &old_op_def)
+ .Attr("a", true)
+ .Finalize(node_def()));
+ ExpectInvalid(old_op_def, "value with type 'bool' when 'int' expected");
+}
+
+// Can't change an attr from a list.
+
+REGISTER_OP("AttrFromList").Attr("a: int");
+
+TEST_F(OpCompatibilityTest, AttrFromListFails) {
+ OpDef old_op_def;
+ ASSERT_OK(
+ OpDefBuilder("AttrFromList").Attr("a: list(int)").Finalize(&old_op_def));
+ ASSERT_OK(
+ NodeDefBuilder("fails", &old_op_def).Attr("a", {5}).Finalize(node_def()));
+ ExpectInvalid(old_op_def, "value with type 'list(int)' when 'int' expected");
+}
+
+// Can't change an attr to a list.
+
+REGISTER_OP("AttrToList").Attr("a: list(int)");
+
+TEST_F(OpCompatibilityTest, AttrToListFails) {
+ OpDef old_op_def;
+ ASSERT_OK(OpDefBuilder("AttrToList").Attr("a: int").Finalize(&old_op_def));
+ ASSERT_OK(
+ NodeDefBuilder("fails", &old_op_def).Attr("a", 5).Finalize(node_def()));
+ ExpectInvalid(old_op_def, "value with type 'int' when 'list(int)' expected");
+}
+
+// Can't change an input from polymorphic to a list of any type.
+
+REGISTER_OP("PolymorphicToAnyList").Input("a: T").Attr("T: list(type)");
+
+TEST_F(OpCompatibilityTest, PolymorphicToAnyListFails) {
+ OpDef old_op_def;
+ ASSERT_OK(OpDefBuilder("PolymorphicToAnyList")
+ .Input("a: T")
+ .Attr("T: type")
+ .Finalize(&old_op_def));
+ ASSERT_OK(NodeDefBuilder("fails", &old_op_def)
+ .Input(FakeInput(DT_INT32))
+ .Finalize(node_def()));
+ ExpectInvalid(old_op_def,
+ "value with type 'type' when 'list(type)' expected");
+}
+
+// Can't change an input from a list of the same type to a list of any type.
+
+REGISTER_OP("SameToAnyList")
+ .Input("a: T")
+ .Attr("T: list(type)")
+ .Attr("N: int = 1");
+
+TEST_F(OpCompatibilityTest, SameToAnyListFails) {
+ OpDef old_op_def;
+ ASSERT_OK(OpDefBuilder("SameToAnyList")
+ .Input("a: N * T")
+ .Attr("T: type")
+ .Attr("N: int")
+ .Finalize(&old_op_def));
+ ASSERT_OK(NodeDefBuilder("fails", &old_op_def)
+ .Input(FakeInput(1, DT_INT32))
+ .Finalize(node_def()));
+ ExpectInvalid(old_op_def,
+ "value with type 'type' when 'list(type)' expected");
+}
+
+// Changing an attr's default is not technically illegal, but should
+// be forbidden if it the attr ever didn't exist since it likely
+// affects semantics.
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/core/framework/op_def_util_test.cc b/tensorflow/core/framework/op_def_util_test.cc
index 515e8bb288..dabf87e18a 100644
--- a/tensorflow/core/framework/op_def_util_test.cc
+++ b/tensorflow/core/framework/op_def_util_test.cc
@@ -117,39 +117,39 @@ TEST_F(ValidateOpDefTest, BadAttrDefault) {
ExpectFailure(
TestProto("name: 'BadAttrDef' attr { name: 'a' "
"type: 'int' default_value { s: 'x' } }"),
- "AttrValue had value with type string when int expected\n\t for "
+ "AttrValue had value with type 'string' when 'int' expected\n\t for "
"attr 'a'\n\t in Op 'BadAttrDef'");
ExpectFailure(TestProto("name: 'BadAttrDef' attr { name: 'a' "
"type: 'int' default_value { f: 0.5 } }"),
- "AttrValue had value with type float when int expected\n\t for "
- "attr 'a'\n\t in Op 'BadAttrDef'");
+ "AttrValue had value with type 'float' when 'int' expected\n"
+ "\t for attr 'a'\n\t in Op 'BadAttrDef'");
ExpectFailure(
TestProto("name: 'BadAttrDef' attr { name: 'a' type: 'int' "
"default_value { i: 5 list { i: [2] } } }"),
- "AttrValue had value with type list(int) when int expected\n\t for "
+ "AttrValue had value with type 'list(int)' when 'int' expected\n\t for "
"attr 'a'\n\t in Op 'BadAttrDef'");
ExpectFailure(
TestProto("name: 'BadAttrDef' attr { name: 'a' "
"type: 'list(int)' default_value { f: 0.5 } }"),
- "AttrValue had value with type float when list(int) expected\n\t "
+ "AttrValue had value with type 'float' when 'list(int)' expected\n\t "
"for attr 'a'\n\t in Op 'BadAttrDef'");
ExpectFailure(
TestProto("name: 'BadAttrDef' attr { name: 'a' type: 'list(int)' "
"default_value { list { i: [5] f: [0.5] } } }"),
- "AttrValue had value with type list(float) when list(int) "
+ "AttrValue had value with type 'list(float)' when 'list(int)' "
"expected\n\t for attr 'a'\n\t in Op 'BadAttrDef'");
ExpectFailure(TestProto("name: 'BadAttrDef' attr { name: 'a' "
"type: 'type' default_value { } }"),
- "AttrValue missing value with expected type type\n\t for "
+ "AttrValue missing value with expected type 'type'\n\t for "
"attr 'a'\n\t in Op 'BadAttrDef'");
ExpectFailure(TestProto("name: 'BadAttrDef' attr { name: 'a' "
"type: 'shape' default_value { } }"),
- "AttrValue missing value with expected type shape\n\t for "
+ "AttrValue missing value with expected type 'shape'\n\t for "
"attr 'a'\n\t in Op 'BadAttrDef'");
ExpectFailure(TestProto("name: 'BadAttrDef' attr { name: 'a' "
"type: 'tensor' default_value { } }"),
- "AttrValue missing value with expected type tensor\n\t for "
+ "AttrValue missing value with expected type 'tensor'\n\t for "
"attr 'a'\n\t in Op 'BadAttrDef'");
// default_value {} is indistinguishable from default_value{ list{} } (one
@@ -242,14 +242,13 @@ TEST_F(ValidateOpDefTest, BadAttrAllowed) {
ExpectFailure(TestBuilder(OpDefBuilder("BadAttrtude")
.Attr("x: list({'foo', 'bar'}) = ['baz']")),
"attr 'x' of \"baz\" is not in the list of allowed values");
- ExpectFailure(TestProto(
- "name: 'BadAttrtude' attr { name: 'a' "
- "type: 'string' allowed_values { s: 'not list' } }"),
- "with type string when list(string) expected");
+ ExpectFailure(TestProto("name: 'BadAttrtude' attr { name: 'a' "
+ "type: 'string' allowed_values { s: 'not list' } }"),
+ "with type 'string' when 'list(string)' expected");
ExpectFailure(
TestProto("name: 'BadAttrtude' attr { name: 'a' "
"type: 'string' allowed_values { list { i: [6] } } }"),
- "with type list(int) when list(string) expected");
+ "with type 'list(int)' when 'list(string)' expected");
}
TEST_F(ValidateOpDefTest, BadArgType) {
diff --git a/tensorflow/core/framework/register_types.h b/tensorflow/core/framework/register_types.h
index 18473aea2e..37b3b0e6ef 100644
--- a/tensorflow/core/framework/register_types.h
+++ b/tensorflow/core/framework/register_types.h
@@ -27,7 +27,7 @@
#undef REGISTER_PARTITION
*/
-#ifndef __ANDROID__
+#if !defined(__ANDROID__)
// Call "m" for all number types that support the comparison operations "<" and
// ">".
@@ -68,7 +68,7 @@
m(float); \
m(double)
-#else // __ANDROID__
+#else // defined(__ANDROID__)
#define TF_CALL_REAL_NUMBER_TYPES(m) \
m(float); \
@@ -85,6 +85,6 @@
// Maybe we could put an empty macro here for Android?
#define TF_CALL_GPU_NUMBER_TYPES(m) m(float)
-#endif // __ANDROID__
+#endif // defined(__ANDROID__)
#endif // TENSORFLOW_FRAMEWORK_REGISTER_TYPES_H_
diff --git a/tensorflow/core/framework/types.cc b/tensorflow/core/framework/types.cc
index 01b9fca3b6..666f5e4593 100644
--- a/tensorflow/core/framework/types.cc
+++ b/tensorflow/core/framework/types.cc
@@ -139,7 +139,7 @@ DataTypeVector AllTypes() {
DT_QINT8, DT_QUINT8, DT_QINT32};
}
-#ifndef __ANDROID__
+#if !defined(__ANDROID__)
DataTypeVector RealNumberTypes() {
return {DT_FLOAT, DT_DOUBLE, DT_INT32, DT_INT64, DT_UINT8, DT_INT16, DT_INT8};
@@ -157,7 +157,7 @@ DataTypeVector NumberTypes() {
DT_INT8, DT_COMPLEX64, DT_QINT8, DT_QUINT8, DT_QINT32};
}
-#else // __ANDROID__
+#else // defined(__ANDROID__)
DataTypeVector RealNumberTypes() { return {DT_FLOAT, DT_INT32}; }
@@ -171,7 +171,7 @@ DataTypeVector RealAndQuantizedTypes() {
return {DT_FLOAT, DT_INT32, DT_QINT8, DT_QUINT8, DT_QINT32};
}
-#endif // __ANDROID__
+#endif // defined(__ANDROID__)
// TODO(jeff): Maybe unify this with Tensor::CanUseDMA, or the underlying
// is_simple<T> in tensor.cc (and possible choose a more general name?)
diff --git a/tensorflow/core/graph/optimizer_cse.cc b/tensorflow/core/graph/optimizer_cse.cc
index 2fa6f075c0..1a523d5214 100644
--- a/tensorflow/core/graph/optimizer_cse.cc
+++ b/tensorflow/core/graph/optimizer_cse.cc
@@ -89,7 +89,7 @@ size_t OptimizerCSE::NodeHash(const Node* n) {
size_t h = Hash64(str_to_hash);
-#if !defined(__ANDROID__) && !defined(ANDROID)
+#if !defined(__ANDROID__)
// Hash the attrs. For example, this makes sure different constants
// end up in different hash buckets.
string tmp;
diff --git a/tensorflow/core/kernels/aggregate_ops.cc b/tensorflow/core/kernels/aggregate_ops.cc
index 426e868735..77e4be6c66 100644
--- a/tensorflow/core/kernels/aggregate_ops.cc
+++ b/tensorflow/core/kernels/aggregate_ops.cc
@@ -34,7 +34,7 @@ class AddNOp : public OpKernel {
#define I(IDX) ctx->input(IDX).flat<T>()
-#if defined(PLATFORM_POSIX_ANDROID) || defined(PLATFORM_GOOGLE_ANDROID)
+#if defined(__ANDROID__)
// On Android, we only support additions of two arguments, so we
// can reduce the number of template instantiations.
OP_REQUIRES(ctx, num == 2,
@@ -103,7 +103,7 @@ class AddNOp : public OpKernel {
functor8p(ctx->template eigen_device<Device>(), To, I(r), I(r + 1),
I(r + 2), I(r + 3), I(r + 4), I(r + 5), I(r + 6), I(r + 7));
}
-#endif // defined(PLATFORM_POSIX_ANDROID) || defined(PLATFORM_GOOGLE_ANDROID)
+#endif // defined(__ANDROID__)
#undef I
}
diff --git a/tensorflow/core/kernels/cwise_op_abs.cc b/tensorflow/core/kernels/cwise_op_abs.cc
index 5d39b88166..377346691f 100644
--- a/tensorflow/core/kernels/cwise_op_abs.cc
+++ b/tensorflow/core/kernels/cwise_op_abs.cc
@@ -2,7 +2,7 @@
namespace tensorflow {
REGISTER4(UnaryOp, CPU, "Abs", functor::abs, float, double, int32, int64);
-#ifndef __ANDROID__
+#if !defined(__ANDROID__)
REGISTER_KERNEL_BUILDER(Name("ComplexAbs").Device(DEVICE_CPU),
UnaryOp<CPUDevice, functor::abs<complex64>>);
#endif
diff --git a/tensorflow/core/kernels/cwise_ops_common.h b/tensorflow/core/kernels/cwise_ops_common.h
index cf848b86d1..42fd5ced69 100644
--- a/tensorflow/core/kernels/cwise_ops_common.h
+++ b/tensorflow/core/kernels/cwise_ops_common.h
@@ -355,7 +355,7 @@ struct SelectFunctor<CPUDevice, T> {
// device type "D" (CPU or GPU) for operatin "N" (e.g., sqrt) using
// the functor "F" (e.g., functor:sqrt).
-#ifdef __ANDROID__
+#if defined(__ANDROID__)
// On Android, only register the first type (float)
#define REGISTER2(OP, D, N, F, T0, T1) REGISTER(OP, D, N, F, T0)
#define REGISTER3(OP, D, N, F, T0, T1, T2) REGISTER(OP, D, N, F, T0)
@@ -364,7 +364,7 @@ struct SelectFunctor<CPUDevice, T> {
#define REGISTER6(OP, D, N, F, T0, T1, T2, T3, T4, T5) REGISTER(OP, D, N, F, T0)
#define REGISTER7(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6) \
REGISTER(OP, D, N, F, T0)
-#else // !__ANDROID__
+#else // !defined(__ANDROID__)
#define REGISTER2(OP, D, N, F, T0, T1) \
REGISTER(OP, D, N, F, T0) \
REGISTER(OP, D, N, F, T1)
@@ -383,7 +383,7 @@ struct SelectFunctor<CPUDevice, T> {
#define REGISTER7(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6) \
REGISTER4(OP, D, N, F, T0, T1, T2, T3) \
REGISTER3(OP, D, N, F, T4, T5, T6)
-#endif // __ANDROID__
+#endif // defined(__ANDROID__)
} // end namespace tensorflow
diff --git a/tensorflow/core/kernels/lrn_op.cc b/tensorflow/core/kernels/lrn_op.cc
index e5abf5906f..4b67304a37 100644
--- a/tensorflow/core/kernels/lrn_op.cc
+++ b/tensorflow/core/kernels/lrn_op.cc
@@ -9,7 +9,7 @@
#include "tensorflow/core/public/tensor.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#ifndef __ANDROID__
+#if !defined(__ANDROID__)
#include "tensorflow/core/util/work_sharder.h"
#endif
@@ -51,7 +51,7 @@ class LRNOp : public OpKernel {
context->allocate_output(
0, TensorShape({batch, rows, cols, depth}), &output));
-#ifdef __ANDROID__
+#if !defined(__ANDROID__)
MognetLRN(in, batch, rows, cols, depth, output);
#else
const int nodes = cols * rows;
@@ -123,7 +123,7 @@ class LRNOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("LRN").Device(DEVICE_CPU), LRNOp);
-#ifndef __ANDROID__
+#if !defined(__ANDROID__)
class LRNGradOp : public OpKernel {
public:
@@ -223,6 +223,6 @@ class LRNGradOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("LRNGrad").Device(DEVICE_CPU), LRNGradOp);
-#endif // __ANDROID__
+#endif // !defined(__ANDROID__)
} // namespace tensorflow
diff --git a/tensorflow/core/platform/posix/port.cc b/tensorflow/core/platform/posix/port.cc
index 1fb168ef5d..5e9ab88979 100644
--- a/tensorflow/core/platform/posix/port.cc
+++ b/tensorflow/core/platform/posix/port.cc
@@ -34,7 +34,8 @@ int NumSchedulableCPUs() {
perror("sched_getaffinity");
#endif
#if defined(__APPLE__) && defined(__MACH__)
- return std::thread::hardware_concurrency();
+ unsigned int count = std::thread::hardware_concurrency();
+ if (count > 0) return static_cast<int>(count);
#endif
const int kDefaultCores = 4; // Semi-conservative guess
fprintf(stderr, "can't determine number of CPU cores: assuming %d\n",
@@ -45,7 +46,7 @@ int NumSchedulableCPUs() {
void* aligned_malloc(size_t size, int minimum_alignment) {
#if defined(__ANDROID__)
return memalign(minimum_alignment, size);
-#else // !__ANDROID__
+#else // !defined(__ANDROID__)
void* ptr = NULL;
// posix_memalign requires that the requested alignment be at least
// sizeof(void*). In this case, fall back on malloc which should return
diff --git a/tensorflow/core/platform/test.cc b/tensorflow/core/platform/test.cc
index 21c6905683..afc5cc57be 100644
--- a/tensorflow/core/platform/test.cc
+++ b/tensorflow/core/platform/test.cc
@@ -8,8 +8,7 @@
namespace tensorflow {
namespace testing {
-#if defined(PLATFORM_GOOGLE) || defined(PLATFORM_POSIX_ANDROID) || \
- defined(PLATFORM_GOOGLE_ANDROID)
+#if defined(PLATFORM_GOOGLE) || defined(__ANDROID__)
string TmpDir() { return FLAGS_test_tmpdir; }
int RandomSeed() { return FLAGS_test_random_seed; }
#else
diff --git a/tensorflow/core/platform/test_main.cc b/tensorflow/core/platform/test_main.cc
index 11230c3f7b..4b2451e22a 100644
--- a/tensorflow/core/platform/test_main.cc
+++ b/tensorflow/core/platform/test_main.cc
@@ -7,8 +7,7 @@
#include "tensorflow/core/platform/port.h"
-#if defined(PLATFORM_GOOGLE) || defined(PLATFORM_POSIX_ANDROID) || \
- defined(PLATFORM_GOOGLE_ANDROID)
+#if defined(PLATFORM_GOOGLE) || defined(__ANDROID__)
// main() is supplied by gunit_main
#else
#include "gtest/gtest.h"
diff --git a/tensorflow/core/platform/tracing.h b/tensorflow/core/platform/tracing.h
index 2b53a64cf1..09f8d722df 100644
--- a/tensorflow/core/platform/tracing.h
+++ b/tensorflow/core/platform/tracing.h
@@ -196,7 +196,7 @@ class Tracing::TraceMe {
} // namespace port
} // namespace tensorflow
-#if defined(PLATFORM_GOOGLE) && !defined(ANDROID) && !defined(__ANDROID__)
+#if defined(PLATFORM_GOOGLE)
#include "tensorflow/core/platform/google/tracing_impl.h"
#else
#include "tensorflow/core/platform/default/tracing_impl.h"
diff --git a/tensorflow/g3doc/api_docs/python/control_flow_ops.md b/tensorflow/g3doc/api_docs/python/control_flow_ops.md
index 8778739c8a..98f703f95c 100644
--- a/tensorflow/g3doc/api_docs/python/control_flow_ops.md
+++ b/tensorflow/g3doc/api_docs/python/control_flow_ops.md
@@ -37,6 +37,8 @@ Note: Functions taking `Tensor` arguments can also take anything accepted by
* [`tf.add_check_numerics_ops()`](#add_check_numerics_ops)
* [`tf.Assert(condition, data, summarize=None, name=None)`](#Assert)
* [`tf.Print(input_, data, message=None, first_n=None, summarize=None, name=None)`](#Print)
+* [Other Functions and Classes](#AUTOGENERATED-other-functions-and-classes)
+ * [`class tf.xrange`](#xrange)
<!-- TOC-END This section was generated by neural network, THANKS FOR READING! -->
@@ -593,3 +595,16 @@ evaluating.
Same tensor as `input_`.
+
+## Other Functions and Classes <a class="md-anchor" id="AUTOGENERATED-other-functions-and-classes"></a>
+- - -
+
+### `class tf.xrange` <a class="md-anchor" id="xrange"></a>
+
+xrange(stop) -> xrange object
+xrange(start, stop[, step]) -> xrange object
+
+Like range(), but instead of returning a list, returns an object that
+generates the numbers in the range on demand. For looping, this is
+slightly faster than range() and more memory efficient.
+
diff --git a/tensorflow/g3doc/api_docs/python/framework.md b/tensorflow/g3doc/api_docs/python/framework.md
index a62ef9e711..1f88ad0ca4 100644
--- a/tensorflow/g3doc/api_docs/python/framework.md
+++ b/tensorflow/g3doc/api_docs/python/framework.md
@@ -1229,6 +1229,12 @@ types.as_dtype() function.
- - -
+#### `tf.DType.is_floating` <a class="md-anchor" id="DType.is_floating"></a>
+
+Returns whether this is a (real) floating point type.
+
+- - -
+
#### `tf.DType.max` <a class="md-anchor" id="DType.max"></a>
Returns the maximum representable value in this data type.
diff --git a/tensorflow/g3doc/api_docs/python/index.md b/tensorflow/g3doc/api_docs/python/index.md
index c2b2a72d2b..f104011ee6 100644
--- a/tensorflow/g3doc/api_docs/python/index.md
+++ b/tensorflow/g3doc/api_docs/python/index.md
@@ -121,6 +121,7 @@
* [`edit_distance`](../../api_docs/python/math_ops.md#edit_distance)
* [`exp`](../../api_docs/python/math_ops.md#exp)
* [`floor`](../../api_docs/python/math_ops.md#floor)
+ * [`floordiv`](../../api_docs/python/math_ops.md#floordiv)
* [`imag`](../../api_docs/python/math_ops.md#imag)
* [`inv`](../../api_docs/python/math_ops.md#inv)
* [`invert_permutation`](../../api_docs/python/math_ops.md#invert_permutation)
@@ -158,6 +159,7 @@
* [`square`](../../api_docs/python/math_ops.md#square)
* [`sub`](../../api_docs/python/math_ops.md#sub)
* [`transpose`](../../api_docs/python/math_ops.md#transpose)
+ * [`truediv`](../../api_docs/python/math_ops.md#truediv)
* [`unique`](../../api_docs/python/math_ops.md#unique)
* [`unsorted_segment_sum`](../../api_docs/python/math_ops.md#unsorted_segment_sum)
* [`where`](../../api_docs/python/math_ops.md#where)
@@ -188,6 +190,7 @@
* [`tuple`](../../api_docs/python/control_flow_ops.md#tuple)
* [`verify_tensor_all_finite`](../../api_docs/python/control_flow_ops.md#verify_tensor_all_finite)
* [`where`](../../api_docs/python/control_flow_ops.md#where)
+ * [`xrange`](../../api_docs/python/control_flow_ops.md#xrange)
* **[Images](../../api_docs/python/image.md)**:
* [`adjust_brightness`](../../api_docs/python/image.md#adjust_brightness)
diff --git a/tensorflow/g3doc/api_docs/python/math_ops.md b/tensorflow/g3doc/api_docs/python/math_ops.md
index 1f8fece3fe..c0d742c89b 100644
--- a/tensorflow/g3doc/api_docs/python/math_ops.md
+++ b/tensorflow/g3doc/api_docs/python/math_ops.md
@@ -76,6 +76,9 @@ Note: Functions taking `Tensor` arguments can also take anything accepted by
* [`tf.unique(x, name=None)`](#unique)
* [`tf.edit_distance(hypothesis, truth, normalize=True, name='edit_distance')`](#edit_distance)
* [`tf.invert_permutation(x, name=None)`](#invert_permutation)
+* [Other Functions and Classes](#AUTOGENERATED-other-functions-and-classes)
+ * [`tf.floordiv(x, y, name=None)`](#floordiv)
+ * [`tf.truediv(x, y, name=None)`](#truediv)
<!-- TOC-END This section was generated by neural network, THANKS FOR READING! -->
@@ -1892,3 +1895,74 @@ invert_permutation(x) ==> [2, 4, 3, 0, 1]
A `Tensor` of type `int32`. 1-D.
+
+## Other Functions and Classes <a class="md-anchor" id="AUTOGENERATED-other-functions-and-classes"></a>
+- - -
+
+### `tf.floordiv(x, y, name=None)` <a class="md-anchor" id="floordiv"></a>
+
+Divides `x / y` elementwise, rounding down for floating point.
+
+The same as `tf.div(x,y)`, but uses `tf.floor(tf.div(x,y))` for floating
+point arguments so that the result is always an integer (though possibly an
+integer represented as floating point). This op is generated by `x // y`
+floor division in Python 3 and in Python 2.7 with
+`from __future__ import division`.
+
+Note that for efficiency, __floordiv__ uses C semantics for negative numbers
+(unlike Python and Numpy).
+
+`x` and `y` must have the same type, and the result will have the same type
+as well.
+
+##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
+
+
+* <b>`x`</b>: `Tensor` numerator of real numeric type.
+* <b>`y`</b>: `Tensor` numerator of real numeric type.
+* <b>`name`</b>: A name for the operation (optional).
+
+##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
+
+ `x / y` rounded down (except possibly for integers in C).
+
+##### Raises: <a class="md-anchor" id="AUTOGENERATED-raises-"></a>
+
+
+* <b>`TypeError`</b>: If the inputs are complex.
+
+
+- - -
+
+### `tf.truediv(x, y, name=None)` <a class="md-anchor" id="truediv"></a>
+
+Divides x / y elementwise, always producing floating point results.
+
+The same as `tf.div` for floating point arguments, but casts integer arguments
+to floating point before dividing so that the result is always floating point.
+This op is generated by normal `x / y` division in Python 3 and in Python 2.7
+with `from __future__ import division`. If you want integer division that
+rounds down, use `x // y` or `tf.floordiv`.
+
+`x` and `y` must have the same numeric type. If the inputs are floating
+point, the output will have the same type. If the inputs are integral, the
+inputs are cast to `float32` for `int8` and `int16` and `float64` for `int32`
+and `int64` (matching the behavior of Numpy).
+
+##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
+
+
+* <b>`x`</b>: `Tensor` numerator of numeric type.
+* <b>`y`</b>: `Tensor` denominator of numeric type.
+* <b>`name`</b>: A name for the operation (optional).
+
+##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
+
+ `x / y` evaluated in floating point.
+
+##### Raises: <a class="md-anchor" id="AUTOGENERATED-raises-"></a>
+
+
+* <b>`TypeError`</b>: If `x` and `y` have different dtypes.
+
+
diff --git a/tensorflow/g3doc/get_started/os_setup.md b/tensorflow/g3doc/get_started/os_setup.md
index 713606b29f..9014a81150 100644
--- a/tensorflow/g3doc/get_started/os_setup.md
+++ b/tensorflow/g3doc/get_started/os_setup.md
@@ -6,7 +6,7 @@ You can install TensorFlow using our provided binary packages or from source.
The TensorFlow Python API currently requires Python 2.7: we are
[working](https://github.com/tensorflow/tensorflow/issues/1) on adding support
-for Python 3.0.
+for Python 3.
The simplest way to install TensorFlow is using
[pip](https://pypi.python.org/pypi/pip) for both Linux and Mac.
@@ -307,6 +307,9 @@ Follow installation instructions [here](http://docs.scipy.org/doc/numpy/user/ins
```bash
$ bazel build -c opt //tensorflow/tools/pip_package:build_pip_package
+# To build with GPU support:
+$ bazel build -c opt --config=cuda //tensorflow/tools/pip_package:build_pip_package
+
$ bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/tensorflow_pkg
# The name of the .whl file will depend on your platform.
diff --git a/tensorflow/g3doc/how_tos/adding_an_op/fact_test.py b/tensorflow/g3doc/how_tos/adding_an_op/fact_test.py
index e2f44db60b..2a0b992dc9 100644
--- a/tensorflow/g3doc/how_tos/adding_an_op/fact_test.py
+++ b/tensorflow/g3doc/how_tos/adding_an_op/fact_test.py
@@ -1,4 +1,6 @@
"""Test that user ops can be used as expected."""
+from __future__ import absolute_import
+from __future__ import division
from __future__ import print_function
import tensorflow.python.platform
diff --git a/tensorflow/g3doc/how_tos/adding_an_op/index.md b/tensorflow/g3doc/how_tos/adding_an_op/index.md
index 9f3b3985f1..9dd2456e0b 100644
--- a/tensorflow/g3doc/how_tos/adding_an_op/index.md
+++ b/tensorflow/g3doc/how_tos/adding_an_op/index.md
@@ -137,6 +137,11 @@ Tensorflow system can reference and use the Op when requested.
Python op wrappers are created automatically in
`bazel-genfiles/tensorflow/python/ops/gen_user_ops.py` for all ops placed in the
[`tensorflow/core/user_ops`][user_ops] directory when you build Tensorflow.
+
+> Note: The generated function will be given a snake_case name (to comply with
+> [PEP8](https://www.python.org/dev/peps/pep-0008/)). So if your op is named
+> `ZeroOut` in the C++ files, the python function will be called `zero_out`.
+
Those ops are imported into
[`tensorflow/python/user_ops/user_ops.py`][python-user_ops] with the statement:
@@ -897,8 +902,10 @@ There are several ways to preserve backwards-compatibility.
can't be done in a compatible way (for example, adding an input, or making a
single input into a list).
-If you cannot make your change to an operation backwards compatible, then
-create a new operation with a new name with the new semantics.
+The full list of safe and unsafe changes can be found in
+[tensorflow/core/framework/op_compatibility_test.cc](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/framework/op_compatibility_test.cc).
+If you cannot make your change to an operation backwards compatible, then create
+a new operation with a new name with the new semantics.
## GPU Support <a class="md-anchor" id="mult-archs"></a>
diff --git a/tensorflow/g3doc/how_tos/adding_an_op/zero_out_1_test.py b/tensorflow/g3doc/how_tos/adding_an_op/zero_out_1_test.py
index 321f603adf..88964be10d 100644
--- a/tensorflow/g3doc/how_tos/adding_an_op/zero_out_1_test.py
+++ b/tensorflow/g3doc/how_tos/adding_an_op/zero_out_1_test.py
@@ -1,5 +1,9 @@
"""Test for version 1 of the zero_out op."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow.python.platform
import tensorflow as tf
diff --git a/tensorflow/g3doc/how_tos/adding_an_op/zero_out_2_test.py b/tensorflow/g3doc/how_tos/adding_an_op/zero_out_2_test.py
index ce38e435fa..a4a1cf1a48 100644
--- a/tensorflow/g3doc/how_tos/adding_an_op/zero_out_2_test.py
+++ b/tensorflow/g3doc/how_tos/adding_an_op/zero_out_2_test.py
@@ -1,5 +1,9 @@
"""Test for version 2 of the zero_out op."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow.python.platform
import tensorflow as tf
diff --git a/tensorflow/g3doc/how_tos/adding_an_op/zero_out_3_test.py b/tensorflow/g3doc/how_tos/adding_an_op/zero_out_3_test.py
index eaf45d1ec4..5f54ce592b 100644
--- a/tensorflow/g3doc/how_tos/adding_an_op/zero_out_3_test.py
+++ b/tensorflow/g3doc/how_tos/adding_an_op/zero_out_3_test.py
@@ -1,5 +1,9 @@
"""Test for version 3 of the zero_out op."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow.python.platform
import tensorflow as tf
diff --git a/tensorflow/g3doc/how_tos/adding_an_op/zero_out_grad_2.py b/tensorflow/g3doc/how_tos/adding_an_op/zero_out_grad_2.py
index 61fb92db27..07e312cfa1 100644
--- a/tensorflow/g3doc/how_tos/adding_an_op/zero_out_grad_2.py
+++ b/tensorflow/g3doc/how_tos/adding_an_op/zero_out_grad_2.py
@@ -1,5 +1,9 @@
"""The gradient of the tutorial zero_out op."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import sparse_ops
diff --git a/tensorflow/g3doc/how_tos/reading_data/convert_to_records.py b/tensorflow/g3doc/how_tos/reading_data/convert_to_records.py
index 6d77ed50e2..1a1fc1af79 100644
--- a/tensorflow/g3doc/how_tos/reading_data/convert_to_records.py
+++ b/tensorflow/g3doc/how_tos/reading_data/convert_to_records.py
@@ -1,4 +1,6 @@
"""Converts MNIST data to TFRecords file format with Example protos."""
+from __future__ import absolute_import
+from __future__ import division
from __future__ import print_function
import os
diff --git a/tensorflow/g3doc/how_tos/reading_data/fully_connected_preloaded.py b/tensorflow/g3doc/how_tos/reading_data/fully_connected_preloaded.py
index 7e9d8355a9..c65d808c6a 100644
--- a/tensorflow/g3doc/how_tos/reading_data/fully_connected_preloaded.py
+++ b/tensorflow/g3doc/how_tos/reading_data/fully_connected_preloaded.py
@@ -5,7 +5,10 @@ Command to run this py_binary target:
bazel run -c opt \
<...>/tensorflow/g3doc/how_tos/reading_data:fully_connected_preloaded
"""
+from __future__ import absolute_import
+from __future__ import division
from __future__ import print_function
+
import os.path
import time
diff --git a/tensorflow/g3doc/how_tos/reading_data/fully_connected_preloaded_var.py b/tensorflow/g3doc/how_tos/reading_data/fully_connected_preloaded_var.py
index ef62242d1e..98c51eba24 100644
--- a/tensorflow/g3doc/how_tos/reading_data/fully_connected_preloaded_var.py
+++ b/tensorflow/g3doc/how_tos/reading_data/fully_connected_preloaded_var.py
@@ -5,7 +5,10 @@ Command to run this py_binary target:
bazel run -c opt \
<...>/tensorflow/g3doc/how_tos/reading_data:fully_connected_preloaded_var
"""
+from __future__ import absolute_import
+from __future__ import division
from __future__ import print_function
+
import os.path
import time
diff --git a/tensorflow/g3doc/how_tos/reading_data/fully_connected_reader.py b/tensorflow/g3doc/how_tos/reading_data/fully_connected_reader.py
index e467535ffe..7eb9eb5c51 100644
--- a/tensorflow/g3doc/how_tos/reading_data/fully_connected_reader.py
+++ b/tensorflow/g3doc/how_tos/reading_data/fully_connected_reader.py
@@ -8,6 +8,8 @@ for context.
YOU MUST run convert_to_records before running this (but you only need to
run it once).
"""
+from __future__ import absolute_import
+from __future__ import division
from __future__ import print_function
import os.path
diff --git a/tensorflow/g3doc/tutorials/mnist/beginners/index.md b/tensorflow/g3doc/tutorials/mnist/beginners/index.md
index 60eaf5978a..fcd891c58b 100644
--- a/tensorflow/g3doc/tutorials/mnist/beginners/index.md
+++ b/tensorflow/g3doc/tutorials/mnist/beginners/index.md
@@ -90,8 +90,8 @@ which digit a given image is of.
For the purposes of this tutorial, we're going to want our labels
as "one-hot vectors". A one-hot vector is a vector which is 0 in most
dimensions, and 1 in a single dimension. In this case, the \\(n\\)th digit will be
-represented as a vector which is 1 in the \\(n\\)th dimensions. For example, 0
-would be \\([1,0,0,0,0,0,0,0,0,0,0]\\).
+represented as a vector which is 1 in the \\(n\\)th dimensions. For example, 3
+would be \\([0,0,0,1,0,0,0,0,0,0]\\).
Consequently, `mnist.train.labels` is a
`[60000, 10]` array of floats.
@@ -157,9 +157,10 @@ If you expand that equation out, you get:
$$\text{softmax}(x)_i = \frac{\exp(x_i)}{\sum_j \exp(x_j)}$$
But it's often more helpful to think of softmax the first way:
-exponentiating its inputs and then normalizing them. The exponentiation
-means that one unit more evidence increases the weight given to any hypothesis
-multiplicatively. And conversely, having one less unit of evidence means that a
+exponentiating its inputs and then normalizing them.
+The exponentiation means that one more unit of evidence increases the weight
+given to any hypothesis multiplicatively.
+And conversely, having one less unit of evidence means that a
hypothesis gets a fraction of its earlier weight. No hypothesis ever has zero
or negative weight. Softmax then normalizes these weights, so that they add up
to one, forming a valid probability distribution. (To get more intuition about
diff --git a/tensorflow/g3doc/tutorials/mnist/fully_connected_feed.py b/tensorflow/g3doc/tutorials/mnist/fully_connected_feed.py
index df974ce715..3fbdfc0697 100644
--- a/tensorflow/g3doc/tutorials/mnist/fully_connected_feed.py
+++ b/tensorflow/g3doc/tutorials/mnist/fully_connected_feed.py
@@ -7,13 +7,17 @@ MNIST tutorial:
https://tensorflow.org/tutorials/mnist/tf/index.html
"""
-from __future__ import print_function
# pylint: disable=missing-docstring
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import os.path
import time
import tensorflow.python.platform
import numpy
+from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
from tensorflow.g3doc.tutorials.mnist import input_data
@@ -101,14 +105,14 @@ def do_eval(sess,
"""
# And run one epoch of eval.
true_count = 0 # Counts the number of correct predictions.
- steps_per_epoch = int(data_set.num_examples / FLAGS.batch_size)
+ steps_per_epoch = data_set.num_examples // FLAGS.batch_size
num_examples = steps_per_epoch * FLAGS.batch_size
for step in xrange(steps_per_epoch):
feed_dict = fill_feed_dict(data_set,
images_placeholder,
labels_placeholder)
true_count += sess.run(eval_correct, feed_dict=feed_dict)
- precision = float(true_count) / float(num_examples)
+ precision = true_count / num_examples
print(' Num examples: %d Num correct: %d Precision @ 1: %0.04f' %
(num_examples, true_count, precision))
diff --git a/tensorflow/g3doc/tutorials/mnist/input_data.py b/tensorflow/g3doc/tutorials/mnist/input_data.py
index e700680aa4..391d133ea1 100644
--- a/tensorflow/g3doc/tutorials/mnist/input_data.py
+++ b/tensorflow/g3doc/tutorials/mnist/input_data.py
@@ -1,10 +1,14 @@
"""Functions for downloading and reading MNIST data."""
+from __future__ import absolute_import
+from __future__ import division
from __future__ import print_function
+
import gzip
import os
import urllib
import numpy
+from six.moves import xrange # pylint: disable=redefined-builtin
SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'
diff --git a/tensorflow/g3doc/tutorials/mnist/mnist.py b/tensorflow/g3doc/tutorials/mnist/mnist.py
index acf4d01dd1..64be52293a 100644
--- a/tensorflow/g3doc/tutorials/mnist/mnist.py
+++ b/tensorflow/g3doc/tutorials/mnist/mnist.py
@@ -17,6 +17,10 @@ https://tensorflow.org/get_started/os_setup.html
MNIST tutorial:
https://tensorflow.org/tutorials/mnist/tf/index.html
"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import math
import tensorflow.python.platform
@@ -31,118 +35,118 @@ IMAGE_PIXELS = IMAGE_SIZE * IMAGE_SIZE
def inference(images, hidden1_units, hidden2_units):
- """Build the MNIST model up to where it may be used for inference.
-
- Args:
- images: Images placeholder, from inputs().
- hidden1: Size of the first hidden layer.
- hidden2: Size of the second hidden layer.
-
- Returns:
- softmax_linear: Output tensor with the computed logits.
- """
- # Hidden 1
- with tf.name_scope('hidden1') as scope:
- weights = tf.Variable(
- tf.truncated_normal([IMAGE_PIXELS, hidden1_units],
- stddev=1.0 / math.sqrt(float(IMAGE_PIXELS))),
- name='weights')
- biases = tf.Variable(tf.zeros([hidden1_units]),
- name='biases')
- hidden1 = tf.nn.relu(tf.matmul(images, weights) + biases)
- # Hidden 2
- with tf.name_scope('hidden2') as scope:
- weights = tf.Variable(
- tf.truncated_normal([hidden1_units, hidden2_units],
- stddev=1.0 / math.sqrt(float(hidden1_units))),
- name='weights')
- biases = tf.Variable(tf.zeros([hidden2_units]),
- name='biases')
- hidden2 = tf.nn.relu(tf.matmul(hidden1, weights) + biases)
- # Linear
- with tf.name_scope('softmax_linear') as scope:
- weights = tf.Variable(
- tf.truncated_normal([hidden2_units, NUM_CLASSES],
- stddev=1.0 / math.sqrt(float(hidden2_units))),
- name='weights')
- biases = tf.Variable(tf.zeros([NUM_CLASSES]),
- name='biases')
- logits = tf.matmul(hidden2, weights) + biases
- return logits
+ """Build the MNIST model up to where it may be used for inference.
+
+ Args:
+ images: Images placeholder, from inputs().
+ hidden1: Size of the first hidden layer.
+ hidden2: Size of the second hidden layer.
+
+ Returns:
+ softmax_linear: Output tensor with the computed logits.
+ """
+ # Hidden 1
+ with tf.name_scope('hidden1') as scope:
+ weights = tf.Variable(
+ tf.truncated_normal([IMAGE_PIXELS, hidden1_units],
+ stddev=1.0 / math.sqrt(float(IMAGE_PIXELS))),
+ name='weights')
+ biases = tf.Variable(tf.zeros([hidden1_units]),
+ name='biases')
+ hidden1 = tf.nn.relu(tf.matmul(images, weights) + biases)
+ # Hidden 2
+ with tf.name_scope('hidden2') as scope:
+ weights = tf.Variable(
+ tf.truncated_normal([hidden1_units, hidden2_units],
+ stddev=1.0 / math.sqrt(float(hidden1_units))),
+ name='weights')
+ biases = tf.Variable(tf.zeros([hidden2_units]),
+ name='biases')
+ hidden2 = tf.nn.relu(tf.matmul(hidden1, weights) + biases)
+ # Linear
+ with tf.name_scope('softmax_linear') as scope:
+ weights = tf.Variable(
+ tf.truncated_normal([hidden2_units, NUM_CLASSES],
+ stddev=1.0 / math.sqrt(float(hidden2_units))),
+ name='weights')
+ biases = tf.Variable(tf.zeros([NUM_CLASSES]),
+ name='biases')
+ logits = tf.matmul(hidden2, weights) + biases
+ return logits
def loss(logits, labels):
- """Calculates the loss from the logits and the labels.
-
- Args:
- logits: Logits tensor, float - [batch_size, NUM_CLASSES].
- labels: Labels tensor, int32 - [batch_size].
-
- Returns:
- loss: Loss tensor of type float.
- """
- # Convert from sparse integer labels in the range [0, NUM_CLASSSES)
- # to 1-hot dense float vectors (that is we will have batch_size vectors,
- # each with NUM_CLASSES values, all of which are 0.0 except there will
- # be a 1.0 in the entry corresponding to the label).
- batch_size = tf.size(labels)
- labels = tf.expand_dims(labels, 1)
- indices = tf.expand_dims(tf.range(0, batch_size, 1), 1)
- concated = tf.concat(1, [indices, labels])
- onehot_labels = tf.sparse_to_dense(
- concated, tf.pack([batch_size, NUM_CLASSES]), 1.0, 0.0)
- cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits,
- onehot_labels,
- name='xentropy')
- loss = tf.reduce_mean(cross_entropy, name='xentropy_mean')
- return loss
+ """Calculates the loss from the logits and the labels.
+
+ Args:
+ logits: Logits tensor, float - [batch_size, NUM_CLASSES].
+ labels: Labels tensor, int32 - [batch_size].
+
+ Returns:
+ loss: Loss tensor of type float.
+ """
+ # Convert from sparse integer labels in the range [0, NUM_CLASSSES)
+ # to 1-hot dense float vectors (that is we will have batch_size vectors,
+ # each with NUM_CLASSES values, all of which are 0.0 except there will
+ # be a 1.0 in the entry corresponding to the label).
+ batch_size = tf.size(labels)
+ labels = tf.expand_dims(labels, 1)
+ indices = tf.expand_dims(tf.range(0, batch_size, 1), 1)
+ concated = tf.concat(1, [indices, labels])
+ onehot_labels = tf.sparse_to_dense(
+ concated, tf.pack([batch_size, NUM_CLASSES]), 1.0, 0.0)
+ cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits,
+ onehot_labels,
+ name='xentropy')
+ loss = tf.reduce_mean(cross_entropy, name='xentropy_mean')
+ return loss
def training(loss, learning_rate):
- """Sets up the training Ops.
+ """Sets up the training Ops.
- Creates a summarizer to track the loss over time in TensorBoard.
+ Creates a summarizer to track the loss over time in TensorBoard.
- Creates an optimizer and applies the gradients to all trainable variables.
+ Creates an optimizer and applies the gradients to all trainable variables.
- The Op returned by this function is what must be passed to the
- `sess.run()` call to cause the model to train.
+ The Op returned by this function is what must be passed to the
+ `sess.run()` call to cause the model to train.
- Args:
- loss: Loss tensor, from loss().
- learning_rate: The learning rate to use for gradient descent.
+ Args:
+ loss: Loss tensor, from loss().
+ learning_rate: The learning rate to use for gradient descent.
- Returns:
- train_op: The Op for training.
- """
- # Add a scalar summary for the snapshot loss.
- tf.scalar_summary(loss.op.name, loss)
- # Create the gradient descent optimizer with the given learning rate.
- optimizer = tf.train.GradientDescentOptimizer(learning_rate)
- # Create a variable to track the global step.
- global_step = tf.Variable(0, name='global_step', trainable=False)
- # Use the optimizer to apply the gradients that minimize the loss
- # (and also increment the global step counter) as a single training step.
- train_op = optimizer.minimize(loss, global_step=global_step)
- return train_op
+ Returns:
+ train_op: The Op for training.
+ """
+ # Add a scalar summary for the snapshot loss.
+ tf.scalar_summary(loss.op.name, loss)
+ # Create the gradient descent optimizer with the given learning rate.
+ optimizer = tf.train.GradientDescentOptimizer(learning_rate)
+ # Create a variable to track the global step.
+ global_step = tf.Variable(0, name='global_step', trainable=False)
+ # Use the optimizer to apply the gradients that minimize the loss
+ # (and also increment the global step counter) as a single training step.
+ train_op = optimizer.minimize(loss, global_step=global_step)
+ return train_op
def evaluation(logits, labels):
- """Evaluate the quality of the logits at predicting the label.
-
- Args:
- logits: Logits tensor, float - [batch_size, NUM_CLASSES].
- labels: Labels tensor, int32 - [batch_size], with values in the
- range [0, NUM_CLASSES).
-
- Returns:
- A scalar int32 tensor with the number of examples (out of batch_size)
- that were predicted correctly.
- """
- # For a classifier model, we can use the in_top_k Op.
- # It returns a bool tensor with shape [batch_size] that is true for
- # the examples where the label's is was in the top k (here k=1)
- # of all logits for that example.
- correct = tf.nn.in_top_k(logits, labels, 1)
- # Return the number of true entries.
- return tf.reduce_sum(tf.cast(correct, tf.int32))
+ """Evaluate the quality of the logits at predicting the label.
+
+ Args:
+ logits: Logits tensor, float - [batch_size, NUM_CLASSES].
+ labels: Labels tensor, int32 - [batch_size], with values in the
+ range [0, NUM_CLASSES).
+
+ Returns:
+ A scalar int32 tensor with the number of examples (out of batch_size)
+ that were predicted correctly.
+ """
+ # For a classifier model, we can use the in_top_k Op.
+ # It returns a bool tensor with shape [batch_size] that is true for
+ # the examples where the label's is was in the top k (here k=1)
+ # of all logits for that example.
+ correct = tf.nn.in_top_k(logits, labels, 1)
+ # Return the number of true entries.
+ return tf.reduce_sum(tf.cast(correct, tf.int32))
diff --git a/tensorflow/g3doc/tutorials/mnist/mnist_softmax.py b/tensorflow/g3doc/tutorials/mnist/mnist_softmax.py
index c766af474c..8bc9b04911 100644
--- a/tensorflow/g3doc/tutorials/mnist/mnist_softmax.py
+++ b/tensorflow/g3doc/tutorials/mnist/mnist_softmax.py
@@ -3,6 +3,8 @@
See extensive documentation at
http://tensorflow.org/tutorials/mnist/beginners/index.md
"""
+from __future__ import absolute_import
+from __future__ import division
from __future__ import print_function
# Import data
diff --git a/tensorflow/g3doc/tutorials/word2vec/word2vec_basic.py b/tensorflow/g3doc/tutorials/word2vec/word2vec_basic.py
index 173944950c..a9e0f28436 100644
--- a/tensorflow/g3doc/tutorials/word2vec/word2vec_basic.py
+++ b/tensorflow/g3doc/tutorials/word2vec/word2vec_basic.py
@@ -1,4 +1,7 @@
+from __future__ import absolute_import
+from __future__ import division
from __future__ import print_function
+
import tensorflow.python.platform
import collections
@@ -6,6 +9,7 @@ import math
import numpy as np
import os
import random
+from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
import urllib
import zipfile
@@ -81,7 +85,7 @@ def generate_batch(batch_size, num_skips, skip_window):
for _ in range(span):
buffer.append(data[data_index])
data_index = (data_index + 1) % len(data)
- for i in range(batch_size / num_skips):
+ for i in range(batch_size // num_skips):
target = skip_window # target label at the center of the buffer
targets_to_avoid = [ skip_window ]
for j in range(num_skips):
@@ -111,7 +115,7 @@ num_skips = 2 # How many times to reuse an input to generate a label.
# construction are also the most frequent.
valid_size = 16 # Random set of words to evaluate similarity on.
valid_window = 100 # Only pick dev samples in the head of the distribution.
-valid_examples = np.array(random.sample(xrange(valid_window), valid_size))
+valid_examples = np.array(random.sample(np.arange(valid_window), valid_size))
num_sampled = 64 # Number of negative examples to sample.
graph = tf.Graph()
@@ -216,7 +220,7 @@ try:
tsne = TSNE(perplexity=30, n_components=2, init='pca', n_iter=5000)
plot_only = 500
low_dim_embs = tsne.fit_transform(final_embeddings[:plot_only,:])
- labels = dictionary.keys()[:plot_only]
+ labels = list(dictionary.keys())[:plot_only]
plot_with_labels(low_dim_embs, labels)
except ImportError:
diff --git a/tensorflow/models/embedding/word2vec.py b/tensorflow/models/embedding/word2vec.py
index c417b086d6..29bd9602bb 100644
--- a/tensorflow/models/embedding/word2vec.py
+++ b/tensorflow/models/embedding/word2vec.py
@@ -13,9 +13,12 @@ The key ops used are:
* GradientDescentOptimizer for optimizing the loss.
* skipgram custom op that does input processing.
"""
+from __future__ import absolute_import
+from __future__ import division
from __future__ import print_function
import os
+from six.moves import xrange # pylint: disable=redefined-builtin
import sys
import threading
import time
diff --git a/tensorflow/models/embedding/word2vec_optimized.py b/tensorflow/models/embedding/word2vec_optimized.py
index 4d5b3fe58d..e6fccf689a 100644
--- a/tensorflow/models/embedding/word2vec_optimized.py
+++ b/tensorflow/models/embedding/word2vec_optimized.py
@@ -12,9 +12,12 @@ The key ops used are:
* neg_train custom op that efficiently calculates and applies the gradient using
true SGD.
"""
+from __future__ import absolute_import
+from __future__ import division
from __future__ import print_function
import os
+from six.moves import xrange # pylint: disable=redefined-builtin
import sys
import threading
import time
diff --git a/tensorflow/models/embedding/word2vec_optimized_test.py b/tensorflow/models/embedding/word2vec_optimized_test.py
index bc109d594e..caf1d6f58c 100644
--- a/tensorflow/models/embedding/word2vec_optimized_test.py
+++ b/tensorflow/models/embedding/word2vec_optimized_test.py
@@ -1,5 +1,9 @@
"""Tests for word2vec_optimized module."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import os
import tensorflow.python.platform
diff --git a/tensorflow/models/embedding/word2vec_test.py b/tensorflow/models/embedding/word2vec_test.py
index 62a033da6f..24172b74b4 100644
--- a/tensorflow/models/embedding/word2vec_test.py
+++ b/tensorflow/models/embedding/word2vec_test.py
@@ -1,5 +1,9 @@
"""Tests for word2vec module."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import os
import tensorflow.python.platform
diff --git a/tensorflow/models/image/alexnet/alexnet_benchmark.py b/tensorflow/models/image/alexnet/alexnet_benchmark.py
index e4be47ff38..740524fd0b 100644
--- a/tensorflow/models/image/alexnet/alexnet_benchmark.py
+++ b/tensorflow/models/image/alexnet/alexnet_benchmark.py
@@ -14,9 +14,13 @@ Forward-backward pass:
Run on Tesla K40c: 480 +/- 48 ms / batch
Run on Titan X: 244 +/- 30 ms / batch
"""
+from __future__ import absolute_import
+from __future__ import division
from __future__ import print_function
+
from datetime import datetime
import math
+from six.moves import xrange # pylint: disable=redefined-builtin
import time
import tensorflow.python.platform
@@ -194,7 +198,9 @@ def run_benchmark():
init = tf.initialize_all_variables()
# Start running operations on the Graph.
- sess = tf.Session('')
+ config = tf.ConfigProto()
+ config.gpu_options.allocator_type = 'BFC'
+ sess = tf.Session(config=config)
sess.run(init)
# Run the forward benchmark.
diff --git a/tensorflow/models/image/cifar10/cifar10.py b/tensorflow/models/image/cifar10/cifar10.py
index 6d79029dc8..8fcd790130 100644
--- a/tensorflow/models/image/cifar10/cifar10.py
+++ b/tensorflow/models/image/cifar10/cifar10.py
@@ -15,8 +15,11 @@ Summary of available functions:
# Create a graph to run one step of training with respect to the loss.
train_op = train(loss, global_step)
"""
-from __future__ import print_function
# pylint: disable=missing-docstring
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import gzip
import os
import re
@@ -25,6 +28,7 @@ import tarfile
import urllib
import tensorflow.python.platform
+from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
from tensorflow.models.image.cifar10 import cifar10_input
diff --git a/tensorflow/models/image/cifar10/cifar10_eval.py b/tensorflow/models/image/cifar10/cifar10_eval.py
index c8e6ec067f..789cf5e9d3 100644
--- a/tensorflow/models/image/cifar10/cifar10_eval.py
+++ b/tensorflow/models/image/cifar10/cifar10_eval.py
@@ -15,7 +15,10 @@ data set, compile the program and train the model.
http://tensorflow.org/tutorials/deep_cnn/
"""
+from __future__ import absolute_import
+from __future__ import division
from __future__ import print_function
+
from datetime import datetime
import math
import time
@@ -83,7 +86,7 @@ def eval_once(saver, summary_writer, top_k_op, summary_op):
step += 1
# Compute precision @ 1.
- precision = float(true_count) / float(total_sample_count)
+ precision = true_count / total_sample_count
print('%s: precision @ 1 = %.3f' % (datetime.now(), precision))
summary = tf.Summary()
diff --git a/tensorflow/models/image/cifar10/cifar10_input.py b/tensorflow/models/image/cifar10/cifar10_input.py
index 686f1bf987..361d7d0139 100644
--- a/tensorflow/models/image/cifar10/cifar10_input.py
+++ b/tensorflow/models/image/cifar10/cifar10_input.py
@@ -1,5 +1,9 @@
"""Routine for decoding the CIFAR-10 binary file format."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow.python.platform
import tensorflow as tf
diff --git a/tensorflow/models/image/cifar10/cifar10_input_test.py b/tensorflow/models/image/cifar10/cifar10_input_test.py
index d43f5aedcf..d3806e4937 100644
--- a/tensorflow/models/image/cifar10/cifar10_input_test.py
+++ b/tensorflow/models/image/cifar10/cifar10_input_test.py
@@ -1,5 +1,9 @@
"""Tests for cifar10 input."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import os
import tensorflow.python.platform
diff --git a/tensorflow/models/image/cifar10/cifar10_multi_gpu_train.py b/tensorflow/models/image/cifar10/cifar10_multi_gpu_train.py
index e0396ea782..5867a6ac94 100644
--- a/tensorflow/models/image/cifar10/cifar10_multi_gpu_train.py
+++ b/tensorflow/models/image/cifar10/cifar10_multi_gpu_train.py
@@ -20,7 +20,10 @@ data set, compile the program and train the model.
http://tensorflow.org/tutorials/deep_cnn/
"""
+from __future__ import absolute_import
+from __future__ import division
from __future__ import print_function
+
from datetime import datetime
import os.path
import re
@@ -30,6 +33,7 @@ import time
import tensorflow.python.platform
from tensorflow.python.platform import gfile
import numpy as np
+from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
from tensorflow.models.image.cifar10 import cifar10
# pylint: disable=unused-import,g-bad-import-order
@@ -236,8 +240,8 @@ def train():
if step % 10 == 0:
num_examples_per_step = FLAGS.batch_size * FLAGS.num_gpus
- examples_per_sec = num_examples_per_step / float(duration)
- sec_per_batch = float(duration) / FLAGS.num_gpus
+ examples_per_sec = num_examples_per_step / duration
+ sec_per_batch = duration / FLAGS.num_gpus
format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
'sec/batch)')
diff --git a/tensorflow/models/image/cifar10/cifar10_train.py b/tensorflow/models/image/cifar10/cifar10_train.py
index 1a70b39d57..87cac46927 100644
--- a/tensorflow/models/image/cifar10/cifar10_train.py
+++ b/tensorflow/models/image/cifar10/cifar10_train.py
@@ -17,7 +17,10 @@ data set, compile the program and train the model.
http://tensorflow.org/tutorials/deep_cnn/
"""
+from __future__ import absolute_import
+from __future__ import division
from __future__ import print_function
+
from datetime import datetime
import os.path
import time
@@ -26,7 +29,7 @@ import tensorflow.python.platform
from tensorflow.python.platform import gfile
import numpy as np
-
+from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
from tensorflow.models.image.cifar10 import cifar10
@@ -90,7 +93,7 @@ def train():
if step % 10 == 0:
num_examples_per_step = FLAGS.batch_size
- examples_per_sec = num_examples_per_step / float(duration)
+ examples_per_sec = num_examples_per_step / duration
sec_per_batch = float(duration)
format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
diff --git a/tensorflow/models/image/mnist/BUILD b/tensorflow/models/image/mnist/BUILD
index 76b31d0feb..6774810e82 100644
--- a/tensorflow/models/image/mnist/BUILD
+++ b/tensorflow/models/image/mnist/BUILD
@@ -11,9 +11,7 @@ py_binary(
"convolutional.py",
],
visibility = ["//tensorflow:__subpackages__"],
- deps = [
- "//tensorflow:tensorflow_py",
- ],
+ deps = ["//tensorflow:tensorflow_py"],
)
py_test(
@@ -26,9 +24,7 @@ py_test(
"--self_test=True",
],
main = "convolutional.py",
- deps = [
- "//tensorflow:tensorflow_py",
- ],
+ deps = ["//tensorflow:tensorflow_py"],
)
filegroup(
diff --git a/tensorflow/models/image/mnist/convolutional.py b/tensorflow/models/image/mnist/convolutional.py
index f9453654ef..e388b772fe 100644
--- a/tensorflow/models/image/mnist/convolutional.py
+++ b/tensorflow/models/image/mnist/convolutional.py
@@ -4,7 +4,10 @@ This should achieve a test error of 0.8%. Please keep this model as simple and
linear as possible, it is meant as a tutorial for simple convolutional models.
Run with --self_test on the command line to exectute a short self-test.
"""
+from __future__ import absolute_import
+from __future__ import division
from __future__ import print_function
+
import gzip
import os
import sys
@@ -13,6 +16,7 @@ import urllib
import tensorflow.python.platform
import numpy
+from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'
@@ -145,9 +149,10 @@ def main(argv=None): # pylint: disable=unused-argument
seed=SEED))
conv2_biases = tf.Variable(tf.constant(0.1, shape=[64]))
fc1_weights = tf.Variable( # fully connected, depth 512.
- tf.truncated_normal([IMAGE_SIZE / 4 * IMAGE_SIZE / 4 * 64, 512],
- stddev=0.1,
- seed=SEED))
+ tf.truncated_normal(
+ [IMAGE_SIZE // 4 * IMAGE_SIZE // 4 * 64, 512],
+ stddev=0.1,
+ seed=SEED))
fc1_biases = tf.Variable(tf.constant(0.1, shape=[512]))
fc2_weights = tf.Variable(
tf.truncated_normal([512, NUM_LABELS],
@@ -236,7 +241,7 @@ def main(argv=None): # pylint: disable=unused-argument
tf.initialize_all_variables().run()
print('Initialized!')
# Loop through training steps.
- for step in xrange(int(num_epochs * train_size / BATCH_SIZE)):
+ for step in xrange(num_epochs * train_size // BATCH_SIZE):
# Compute the offset of the current minibatch in the data.
# Note that we could use better randomization across epochs.
offset = (step * BATCH_SIZE) % (train_size - BATCH_SIZE)
diff --git a/tensorflow/models/rnn/__init__.py b/tensorflow/models/rnn/__init__.py
index 475b6f8d31..9c4f2aa656 100644
--- a/tensorflow/models/rnn/__init__.py
+++ b/tensorflow/models/rnn/__init__.py
@@ -2,6 +2,10 @@
This file helps simplify the import process:
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow.python.platform
from tensorflow.models.rnn import rnn
from tensorflow.models.rnn import rnn_cell
diff --git a/tensorflow/models/rnn/linear.py b/tensorflow/models/rnn/linear.py
index 96278e73e4..2f29c32184 100644
--- a/tensorflow/models/rnn/linear.py
+++ b/tensorflow/models/rnn/linear.py
@@ -1,5 +1,9 @@
"""Basic linear combinations that implicitly generate variables."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow as tf
diff --git a/tensorflow/models/rnn/linear_test.py b/tensorflow/models/rnn/linear_test.py
index 93ef10144f..83bf9747bf 100644
--- a/tensorflow/models/rnn/linear_test.py
+++ b/tensorflow/models/rnn/linear_test.py
@@ -1,4 +1,8 @@
# pylint: disable=g-bad-import-order,unused-import
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow.python.platform
import numpy as np
diff --git a/tensorflow/models/rnn/ptb/ptb_word_lm.py b/tensorflow/models/rnn/ptb/ptb_word_lm.py
index 1b01e355ee..8a6c7203c2 100644
--- a/tensorflow/models/rnn/ptb/ptb_word_lm.py
+++ b/tensorflow/models/rnn/ptb/ptb_word_lm.py
@@ -41,6 +41,8 @@ To run:
--data_path=/tmp/simple-examples/data/ --alsologtostderr
"""
+from __future__ import absolute_import
+from __future__ import division
from __future__ import print_function
import time
@@ -218,7 +220,7 @@ class LargeConfig(object):
def run_epoch(session, m, data, eval_op, verbose=False):
"""Runs the model on the given data."""
- epoch_size = ((len(data) / m.batch_size) - 1) / m.num_steps
+ epoch_size = ((len(data) // m.batch_size) - 1) // m.num_steps
start_time = time.time()
costs = 0.0
iters = 0
@@ -232,7 +234,7 @@ def run_epoch(session, m, data, eval_op, verbose=False):
costs += cost
iters += m.num_steps
- if verbose and step % (epoch_size / 10) == 10:
+ if verbose and step % (epoch_size // 10) == 10:
print("%.3f perplexity: %.3f speed: %.0f wps" %
(step * 1.0 / epoch_size, np.exp(costs / iters),
iters * m.batch_size / (time.time() - start_time)))
diff --git a/tensorflow/models/rnn/ptb/reader.py b/tensorflow/models/rnn/ptb/reader.py
index 9a0db9c525..62181d66e6 100644
--- a/tensorflow/models/rnn/ptb/reader.py
+++ b/tensorflow/models/rnn/ptb/reader.py
@@ -1,6 +1,10 @@
# pylint: disable=unused-import,g-bad-import-order
"""Utilities for parsing PTB text files."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import collections
import os
import sys
@@ -9,6 +13,7 @@ import time
import tensorflow.python.platform
import numpy as np
+from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
from tensorflow.python.platform import gfile
@@ -25,7 +30,7 @@ def _build_vocab(filename):
counter = collections.Counter(data)
count_pairs = sorted(counter.items(), key=lambda x: -x[1])
- words, _ = zip(*count_pairs)
+ words, _ = list(zip(*count_pairs))
word_to_id = dict(zip(words, range(len(words))))
return word_to_id
@@ -89,12 +94,12 @@ def ptb_iterator(raw_data, batch_size, num_steps):
raw_data = np.array(raw_data, dtype=np.int32)
data_len = len(raw_data)
- batch_len = data_len / batch_size
+ batch_len = data_len // batch_size
data = np.zeros([batch_size, batch_len], dtype=np.int32)
for i in range(batch_size):
data[i] = raw_data[batch_len * i:batch_len * (i + 1)]
- epoch_size = (batch_len - 1) / num_steps
+ epoch_size = (batch_len - 1) // num_steps
if epoch_size == 0:
raise ValueError("epoch_size == 0, decrease batch_size or num_steps")
diff --git a/tensorflow/models/rnn/ptb/reader_test.py b/tensorflow/models/rnn/ptb/reader_test.py
index c722cdb939..026f58cfc5 100644
--- a/tensorflow/models/rnn/ptb/reader_test.py
+++ b/tensorflow/models/rnn/ptb/reader_test.py
@@ -1,5 +1,9 @@
"""Tests for tensorflow.models.ptb_lstm.ptb_reader."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import os.path
# pylint: disable=g-bad-import-order,unused-import
diff --git a/tensorflow/models/rnn/rnn.py b/tensorflow/models/rnn/rnn.py
index 24582bcae7..412dfcc357 100644
--- a/tensorflow/models/rnn/rnn.py
+++ b/tensorflow/models/rnn/rnn.py
@@ -1,5 +1,9 @@
"""RNN helpers for TensorFlow models."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow as tf
from tensorflow.models.rnn import rnn_cell
diff --git a/tensorflow/models/rnn/rnn_cell.py b/tensorflow/models/rnn/rnn_cell.py
index 55d417fc2b..ff93b47892 100644
--- a/tensorflow/models/rnn/rnn_cell.py
+++ b/tensorflow/models/rnn/rnn_cell.py
@@ -1,7 +1,11 @@
"""Module for constructing RNN Cells."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
import math
+from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
from tensorflow.models.rnn import linear
@@ -270,16 +274,15 @@ class LSTMCell(RNNCell):
dtype = input_.dtype
- unit_shard_size = (4 * self._num_units) / self._num_unit_shards
+ unit_shard_size = (4 * self._num_units) // self._num_unit_shards
with tf.variable_scope(scope or type(self).__name__): # "LSTMCell"
w = tf.concat(
- 1, [tf.get_variable("W_%d" % i,
- shape=[self.input_size + num_proj,
- unit_shard_size],
- initializer=self._initializer,
- dtype=dtype)
- for i in range(self._num_unit_shards)])
+ 1,
+ [tf.get_variable("W_%d" % i,
+ shape=[self.input_size + num_proj, unit_shard_size],
+ initializer=self._initializer,
+ dtype=dtype) for i in xrange(self._num_unit_shards)])
b = tf.get_variable(
"B", shape=[4 * self._num_units],
@@ -313,12 +316,14 @@ class LSTMCell(RNNCell):
m = tf.sigmoid(o) * tf.tanh(c)
if self._num_proj is not None:
- proj_shard_size = self._num_proj / self._num_proj_shards
+ proj_shard_size = self._num_proj // self._num_proj_shards
w_proj = tf.concat(
- 1, [tf.get_variable("W_P_%d" % i,
- shape=[self._num_units, proj_shard_size],
- initializer=self._initializer, dtype=dtype)
- for i in range(self._num_proj_shards)])
+ 1,
+ [tf.get_variable("W_P_%d" % i,
+ shape=[self._num_units, proj_shard_size],
+ initializer=self._initializer,
+ dtype=dtype)
+ for i in xrange(self._num_proj_shards)])
# TODO(ebrevdo), use matmulsum
m = tf.matmul(m, w_proj)
diff --git a/tensorflow/models/rnn/rnn_cell_test.py b/tensorflow/models/rnn/rnn_cell_test.py
index 8b4b209028..937e1557bd 100644
--- a/tensorflow/models/rnn/rnn_cell_test.py
+++ b/tensorflow/models/rnn/rnn_cell_test.py
@@ -1,9 +1,14 @@
"""Tests for RNN cells."""
# pylint: disable=g-bad-import-order,unused-import
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow.python.platform
import numpy as np
+from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
from tensorflow.models.rnn import rnn_cell
diff --git a/tensorflow/models/rnn/rnn_test.py b/tensorflow/models/rnn/rnn_test.py
index 378315d296..5c1fd8a87c 100644
--- a/tensorflow/models/rnn/rnn_test.py
+++ b/tensorflow/models/rnn/rnn_test.py
@@ -1,6 +1,10 @@
"""Tests for rnn module."""
# pylint: disable=g-bad-import-order,unused-import
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow.python.platform
import numpy as np
diff --git a/tensorflow/models/rnn/seq2seq.py b/tensorflow/models/rnn/seq2seq.py
index a3b6a838ca..875bcb5e6e 100644
--- a/tensorflow/models/rnn/seq2seq.py
+++ b/tensorflow/models/rnn/seq2seq.py
@@ -1,7 +1,11 @@
"""Library for creating sequence-to-sequence models."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
import tensorflow.python.platform
+from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
from tensorflow.models.rnn import linear
diff --git a/tensorflow/models/rnn/seq2seq_test.py b/tensorflow/models/rnn/seq2seq_test.py
index d2949ecae2..228c16a4c8 100644
--- a/tensorflow/models/rnn/seq2seq_test.py
+++ b/tensorflow/models/rnn/seq2seq_test.py
@@ -1,11 +1,15 @@
"""Tests for functional style sequence-to-sequence models."""
+from __future__ import absolute_import
+from __future__ import division
from __future__ import print_function
+
import math
import random
import tensorflow.python.platform
import numpy as np
+from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
from tensorflow.models.rnn import rnn
@@ -363,7 +367,7 @@ class Seq2SeqTest(tf.test.TestCase):
for ep in xrange(3):
log_perp = 0.0
for _ in xrange(50):
- bucket = random.choice(range(len(buckets)))
+ bucket = random.choice(np.arange(len(buckets)))
length = buckets[bucket][0]
i = [np.array([np.random.randint(9) + 1 for _ in xrange(batch_size)],
dtype=np.int32) for _ in xrange(length)]
diff --git a/tensorflow/models/rnn/translate/BUILD b/tensorflow/models/rnn/translate/BUILD
index 0899bf689e..57f17fb5ab 100644
--- a/tensorflow/models/rnn/translate/BUILD
+++ b/tensorflow/models/rnn/translate/BUILD
@@ -12,9 +12,7 @@ py_library(
srcs = [
"data_utils.py",
],
- deps = [
- "//tensorflow:tensorflow_py",
- ],
+ deps = ["//tensorflow:tensorflow_py"],
)
py_library(
diff --git a/tensorflow/models/rnn/translate/data_utils.py b/tensorflow/models/rnn/translate/data_utils.py
index 6073f9197b..b9d951ccd7 100644
--- a/tensorflow/models/rnn/translate/data_utils.py
+++ b/tensorflow/models/rnn/translate/data_utils.py
@@ -1,4 +1,6 @@
"""Utilities for downloading data from WMT, tokenizing, vocabularies."""
+from __future__ import absolute_import
+from __future__ import division
from __future__ import print_function
import gzip
diff --git a/tensorflow/models/rnn/translate/seq2seq_model.py b/tensorflow/models/rnn/translate/seq2seq_model.py
index 3c9cfb007f..43e719570b 100644
--- a/tensorflow/models/rnn/translate/seq2seq_model.py
+++ b/tensorflow/models/rnn/translate/seq2seq_model.py
@@ -1,8 +1,13 @@
"""Sequence-to-sequence model with an attention mechanism."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import random
import numpy as np
+from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
from tensorflow.models.rnn import rnn_cell
diff --git a/tensorflow/models/rnn/translate/translate.py b/tensorflow/models/rnn/translate/translate.py
index ec408eb127..51cf60157b 100644
--- a/tensorflow/models/rnn/translate/translate.py
+++ b/tensorflow/models/rnn/translate/translate.py
@@ -12,6 +12,8 @@ See the following papers for more information on neural translation models.
* http://arxiv.org/abs/1409.0473
* http://arxiv.org/pdf/1412.2007v2.pdf
"""
+from __future__ import absolute_import
+from __future__ import division
from __future__ import print_function
import math
@@ -23,6 +25,7 @@ import time
import tensorflow.python.platform
import numpy as np
+from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
from tensorflow.models.rnn.translate import data_utils
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 89eb22daba..7002ebfd65 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -620,9 +620,7 @@ py_library(
py_library(
name = "util",
srcs = glob(["util/**/*.py"]),
- deps = [
- "//google/protobuf:protobuf_python",
- ],
+ deps = ["//google/protobuf:protobuf_python"],
)
tf_proto_library_py(
@@ -672,9 +670,9 @@ tf_cuda_library(
":construction_fails_op",
":test_kernel_label_op_kernel",
"//tensorflow/core",
+ "//tensorflow/core:direct_session",
"//tensorflow/core:kernels",
"//tensorflow/core:lib",
- "//tensorflow/core:local",
"//tensorflow/core:protos_cc",
],
)
diff --git a/tensorflow/python/__init__.py b/tensorflow/python/__init__.py
index 5527c01173..2cbdf191c6 100644
--- a/tensorflow/python/__init__.py
+++ b/tensorflow/python/__init__.py
@@ -4,6 +4,10 @@
Programs that want to build Brain Ops and Graphs without having to import the
constructors and utilities individually can import this file:
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow.python.platform
import tensorflow as tf
diff --git a/tensorflow/python/client/client_lib.py b/tensorflow/python/client/client_lib.py
index 4c7caa8a24..18baf696ae 100644
--- a/tensorflow/python/client/client_lib.py
+++ b/tensorflow/python/client/client_lib.py
@@ -32,6 +32,10 @@ examples of how a graph is launched in a [`tf.Session`](#Session).
@@DataLossError
"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
from tensorflow.python.client.session import InteractiveSession
from tensorflow.python.client.session import Session
diff --git a/tensorflow/python/client/events_writer_test.py b/tensorflow/python/client/events_writer_test.py
index 60bce49b1f..6ba10ced3b 100644
--- a/tensorflow/python/client/events_writer_test.py
+++ b/tensorflow/python/client/events_writer_test.py
@@ -1,4 +1,8 @@
"""Tests for the SWIG-wrapped events writer."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import os.path
from tensorflow.core.framework import summary_pb2
diff --git a/tensorflow/python/client/graph_util.py b/tensorflow/python/client/graph_util.py
index 4c65a445ae..44227a0476 100644
--- a/tensorflow/python/client/graph_util.py
+++ b/tensorflow/python/client/graph_util.py
@@ -1,6 +1,10 @@
"""Helpers to manipulate a tensor graph in python.
"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow.python.platform
from tensorflow.core.framework import graph_pb2
diff --git a/tensorflow/python/client/graph_util_test.py b/tensorflow/python/client/graph_util_test.py
index 8066f722a8..e2ee625713 100644
--- a/tensorflow/python/client/graph_util_test.py
+++ b/tensorflow/python/client/graph_util_test.py
@@ -1,4 +1,8 @@
"""Tests for tensorflow.python.client.graph_util."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow.python.platform
from tensorflow.python.client import graph_util
diff --git a/tensorflow/python/client/notebook.py b/tensorflow/python/client/notebook.py
index 585c4e0f8f..316d3f6316 100644
--- a/tensorflow/python/client/notebook.py
+++ b/tensorflow/python/client/notebook.py
@@ -12,9 +12,10 @@ Press "a" in command mode to insert cell above or "b" to insert cell below.
Your root notebooks directory is FLAGS.notebook_dir
"""
+from __future__ import absolute_import
+from __future__ import division
from __future__ import print_function
-
import os
import socket
import sys
diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py
index 0ac7f52df9..fedaa2c2ca 100644
--- a/tensorflow/python/client/session.py
+++ b/tensorflow/python/client/session.py
@@ -1,5 +1,9 @@
"""A client interface for TensorFlow."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import re
import sys
import threading
@@ -321,7 +325,7 @@ class BaseSession(SessionInterface):
# Validate and process feed_dict.
if feed_dict:
- for feed, feed_val in feed_dict.iteritems():
+ for feed, feed_val in feed_dict.items():
for subfeed, subfeed_val in _feed_fn(feed, feed_val):
try:
subfeed_t = self.graph.as_graph_element(subfeed, allow_tensor=True,
diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py
index 4492840dcf..5472c96a75 100644
--- a/tensorflow/python/client/session_test.py
+++ b/tensorflow/python/client/session_test.py
@@ -1,10 +1,15 @@
"""Tests for tensorflow.python.client.session.Session."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import threading
import time
import tensorflow.python.platform
import numpy as np
+from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.core.framework import config_pb2
from tensorflow.core.lib.core import error_codes_pb2
diff --git a/tensorflow/python/framework/device.py b/tensorflow/python/framework/device.py
index b6d7256e61..9870c109af 100644
--- a/tensorflow/python/framework/device.py
+++ b/tensorflow/python/framework/device.py
@@ -1,4 +1,8 @@
"""Class to represent a device."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import copy
diff --git a/tensorflow/python/framework/device_test.py b/tensorflow/python/framework/device_test.py
index 0a244b0815..aaf57a406c 100644
--- a/tensorflow/python/framework/device_test.py
+++ b/tensorflow/python/framework/device_test.py
@@ -1,4 +1,8 @@
"""Tests for tensorflow.python.framework.device."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow.python.platform
from tensorflow.python.framework import device
diff --git a/tensorflow/python/framework/docs.py b/tensorflow/python/framework/docs.py
index d468d0fb68..abcfe7ca5e 100644
--- a/tensorflow/python/framework/docs.py
+++ b/tensorflow/python/framework/docs.py
@@ -3,6 +3,8 @@
Both updates the files in the file-system and executes g4 commands to
make sure any changes are ready to be submitted.
"""
+from __future__ import absolute_import
+from __future__ import division
from __future__ import print_function
import inspect
@@ -33,7 +35,8 @@ class Document(object):
class Index(Document):
"""An automatically generated index for a collection of documents."""
- def __init__(self, module_to_name, members, filename_to_library_map):
+ def __init__(self, module_to_name, members, filename_to_library_map,
+ path_prefix):
"""Creates a new Index.
Args:
@@ -41,10 +44,12 @@ class Index(Document):
members: Dictionary mapping member name to (fullname, member).
filename_to_library_map: A list of (filename, Library) pairs. The order
corresponds to the order in which the libraries appear in the index.
+ path_prefix: Prefix to add to links in the index.
"""
self._module_to_name = module_to_name
self._members = members
self._filename_to_library_map = filename_to_library_map
+ self._path_prefix = path_prefix
def write_markdown_to_file(self, f):
"""Writes this index to file `f`.
@@ -65,11 +70,11 @@ class Index(Document):
anchor_f = lambda name: _get_anchor(self._module_to_name, fullname_f(name))
for filename, library in self._filename_to_library_map:
- sorted_names = sorted(library.mentioned, key=str.lower)
+ sorted_names = sorted(library.mentioned, key=lambda x: (str.lower(x), x))
member_names = [n for n in sorted_names if n in self._members]
# TODO: This is a hack that should be removed as soon as the website code
# allows it.
- full_filename = '../../api_docs/python/' + filename
+ full_filename = self._path_prefix + filename
links = ["[`%s`](%s#%s)" % (name, full_filename, anchor_f(name))
for name in member_names]
if links:
@@ -94,7 +99,7 @@ def collect_members(module_to_name):
Dictionary mapping name to (fullname, member) pairs.
"""
members = {}
- for module, module_name in module_to_name.iteritems():
+ for module, module_name in module_to_name.items():
for name, member in inspect.getmembers(module):
if ((inspect.isfunction(member) or inspect.isclass(member)) and
not _always_drop_symbol_re.match(name)):
@@ -131,7 +136,7 @@ def _get_anchor(module_to_name, fullname):
if not _anchor_re.match(fullname):
raise ValueError("'%s' is not a valid anchor" % fullname)
anchor = fullname
- for module_name in module_to_name.itervalues():
+ for module_name in module_to_name.values():
if fullname.startswith(module_name + "."):
rest = fullname[len(module_name)+1:]
# Use this prefix iff it is longer than any found before
@@ -419,8 +424,7 @@ class Library(Document):
# if some methods are not categorized.
any_method_called_out = (len(methods) != num_methods)
if any_method_called_out:
- other_methods = {n: m for n, m in methods.iteritems()
- if n in cls.__dict__}
+ other_methods = {n: m for n, m in methods.items() if n in cls.__dict__}
if other_methods:
print("\n#### Other Methods", file=f)
else:
@@ -463,7 +467,7 @@ class Library(Document):
Otherwise, document missing symbols from just this module.
"""
if catch_all:
- names = self._members.iteritems()
+ names = self._members.items()
else:
names = inspect.getmembers(self._module)
leftovers = []
@@ -482,7 +486,7 @@ class Library(Document):
def assert_no_leftovers(self):
"""Generate an error if there are leftover members."""
leftovers = []
- for name in self._members.iterkeys():
+ for name in self._members.keys():
if name in self._members and name not in self._documented:
leftovers.append(name)
if leftovers:
diff --git a/tensorflow/python/framework/errors.py b/tensorflow/python/framework/errors.py
index 948057b8aa..68c656da85 100644
--- a/tensorflow/python/framework/errors.py
+++ b/tensorflow/python/framework/errors.py
@@ -1,4 +1,8 @@
"""Exception types for TensorFlow errors."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import traceback
import warnings
diff --git a/tensorflow/python/framework/errors_test.py b/tensorflow/python/framework/errors_test.py
index ab59a729f6..23bb5d7f66 100644
--- a/tensorflow/python/framework/errors_test.py
+++ b/tensorflow/python/framework/errors_test.py
@@ -1,4 +1,8 @@
"""Tests for tensorflow.python.framework.errors."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow.python.platform
import warnings
diff --git a/tensorflow/python/framework/framework_lib.py b/tensorflow/python/framework/framework_lib.py
index eacd4f5a0b..3c9d941a35 100644
--- a/tensorflow/python/framework/framework_lib.py
+++ b/tensorflow/python/framework/framework_lib.py
@@ -39,6 +39,10 @@
"""
# Classes used when building a Graph.
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
from tensorflow.python.framework.ops import Graph
from tensorflow.python.framework.ops import Operation
from tensorflow.python.framework.ops import Tensor
diff --git a/tensorflow/python/framework/gen_docs_combined.py b/tensorflow/python/framework/gen_docs_combined.py
index 5f62311f12..1f426e00b5 100644
--- a/tensorflow/python/framework/gen_docs_combined.py
+++ b/tensorflow/python/framework/gen_docs_combined.py
@@ -1,4 +1,6 @@
"""Updates generated docs from Python doc comments."""
+from __future__ import absolute_import
+from __future__ import division
from __future__ import print_function
import os.path
@@ -121,7 +123,8 @@ def main(unused_argv):
# Generate index
with open(os.path.join(FLAGS.out_dir, "index.md"), "w") as f:
- docs.Index(module_to_name, members, libraries).write_markdown_to_file(f)
+ docs.Index(module_to_name, members, libraries,
+ "../../api_docs/python/").write_markdown_to_file(f)
if __name__ == "__main__":
diff --git a/tensorflow/python/framework/importer.py b/tensorflow/python/framework/importer.py
index 6ad2a1b009..f10a33ae33 100644
--- a/tensorflow/python/framework/importer.py
+++ b/tensorflow/python/framework/importer.py
@@ -1,4 +1,8 @@
"""A utility function for importing TensorFlow graphs."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import contextlib
import tensorflow.python.platform
@@ -242,7 +246,7 @@ def import_graph_def(graph_def, input_map=None, return_elements=None,
operation_name, output_index = _ParseTensorName(input_name)
try:
source_op = name_to_op[operation_name]
- source_tensor = source_op.values()[output_index]
+ source_tensor = list(source_op.values())[output_index]
except (KeyError, IndexError):
raise ValueError(
_InvalidNodeMessage(
diff --git a/tensorflow/python/framework/importer_test.py b/tensorflow/python/framework/importer_test.py
index 470092313a..12f1cc0de7 100644
--- a/tensorflow/python/framework/importer_test.py
+++ b/tensorflow/python/framework/importer_test.py
@@ -1,5 +1,9 @@
"""Tests for tensorflow.python.framework.importer."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow.python.platform
import tensorflow as tf
diff --git a/tensorflow/python/framework/op_def_registry.py b/tensorflow/python/framework/op_def_registry.py
index 2ec8c94a10..28e55a47a4 100644
--- a/tensorflow/python/framework/op_def_registry.py
+++ b/tensorflow/python/framework/op_def_registry.py
@@ -1,5 +1,9 @@
"""Global registry for OpDefs."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
from tensorflow.core.framework import op_def_pb2
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index 024cd69b60..70321e76dc 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -1,5 +1,9 @@
"""Classes and functions used to construct graphs."""
# pylint: disable=g-bad-name
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import collections
import contextlib
import copy
@@ -11,6 +15,7 @@ import weakref
import tensorflow.python.platform
+import six
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import graph_pb2
from tensorflow.python.framework import device as pydev
@@ -128,21 +133,36 @@ class Tensor(object):
# List of Python operators that we allow to override.
OVERLOADABLE_OPERATORS = {
# Binary.
- "__add__", "__radd__",
- "__sub__", "__rsub__",
- "__mul__", "__rmul__",
- "__div__", "__rdiv__",
- "__truediv__", "__rtruediv__",
- "__mod__", "__rmod__",
- "__lt__", "__le__",
- "__gt__", "__ge__",
- "__and__", "__rand__",
- "__or__", "__ror__",
- "__xor__", "__rxor__",
+ "__add__",
+ "__radd__",
+ "__sub__",
+ "__rsub__",
+ "__mul__",
+ "__rmul__",
+ "__div__",
+ "__rdiv__",
+ "__truediv__",
+ "__rtruediv__",
+ "__floordiv__",
+ "__rfloordiv__",
+ "__mod__",
+ "__rmod__",
+ "__lt__",
+ "__le__",
+ "__gt__",
+ "__ge__",
+ "__and__",
+ "__rand__",
+ "__or__",
+ "__ror__",
+ "__xor__",
+ "__rxor__",
"__getitem__",
# Unary.
"__invert__",
- "__neg__", "__abs__"}
+ "__neg__",
+ "__abs__"
+ }
def __init__(self, op, value_index, dtype):
"""Creates a new `Tensor`.
@@ -848,7 +868,7 @@ def _NodeDef(op_type, name, device=None, attrs=None):
node_def.op = str(op_type)
node_def.name = str(name)
if attrs is not None:
- for k, v in attrs.iteritems():
+ for k, v in six.iteritems(attrs):
node_def.attr[k].CopyFrom(v)
if device is not None:
if callable(device):
@@ -959,8 +979,8 @@ class Operation(object):
if output_types is None:
output_types = []
self._output_types = output_types
- self._outputs = [Tensor(self, i, output_types[i])
- for i in xrange(len(output_types))]
+ self._outputs = [Tensor(self, i, output_type)
+ for i, output_type in enumerate(output_types)]
if input_types is None:
input_types = [i.dtype.base_dtype for i in self._inputs]
else:
@@ -1150,6 +1170,9 @@ class Operation(object):
def __bool__(self):
return bool(self._op._inputs)
+ # Python 3 wants __bool__, Python 2.7 wants __nonzero__
+ __nonzero__ = __bool__
+
def __getitem__(self, i):
return self._op._inputs[i]
# pylint: enable=protected-access
@@ -1845,7 +1868,7 @@ class Graph(object):
Returns:
A list of Operations.
"""
- return self._nodes_by_id.values()
+ return list(self._nodes_by_id.values())
def get_operation_by_name(self, name):
"""Returns the `Operation` with the given `name`.
diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py
index a406c5e56e..8656228edb 100644
--- a/tensorflow/python/framework/ops_test.py
+++ b/tensorflow/python/framework/ops_test.py
@@ -1,4 +1,8 @@
"""Tests for tensorflow.python.framework.ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow.python.platform
from tensorflow.python.framework import device as pydev
@@ -199,11 +203,12 @@ class CreateOpTest(test_util.TensorFlowTestCase):
[],
[types.float32, types.string], None,
name="myop2")
- op3 = g.create_op(
- "foo",
- [op1.values()[0], op2.values()[1], op2.values()[0]],
- [types.float32, types.int32], None,
- name="myop3")
+ op3 = g.create_op("foo",
+ [list(op1.values())[0], list(op2.values())[1],
+ list(op2.values())[0]],
+ [types.float32, types.int32],
+ None,
+ name="myop3")
self.assertEquals(None, op1.device)
self.assertEquals("/device:GPU", op2.device)
self.assertEquals(None, op3.device)
diff --git a/tensorflow/python/framework/random_seed.py b/tensorflow/python/framework/random_seed.py
index 11a49195c2..96e599e1c7 100644
--- a/tensorflow/python/framework/random_seed.py
+++ b/tensorflow/python/framework/random_seed.py
@@ -1,6 +1,10 @@
"""For seeding individual ops based on a graph-level seed.
"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
from tensorflow.python.framework import ops
diff --git a/tensorflow/python/framework/registry.py b/tensorflow/python/framework/registry.py
index d9556f0a06..ec914c37aa 100644
--- a/tensorflow/python/framework/registry.py
+++ b/tensorflow/python/framework/registry.py
@@ -4,6 +4,10 @@ This is typically used with a decorator that calls Register for adding
a class or function to a registry.
"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import traceback
from tensorflow.python.platform import logging
diff --git a/tensorflow/python/framework/registry_test.py b/tensorflow/python/framework/registry_test.py
index 5b4f261ceb..702f20c7eb 100644
--- a/tensorflow/python/framework/registry_test.py
+++ b/tensorflow/python/framework/registry_test.py
@@ -1,5 +1,9 @@
"""Tests for tensorflow.ops.registry."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
from tensorflow.python.framework import registry
from tensorflow.python.platform import googletest
diff --git a/tensorflow/python/framework/tensor_shape.py b/tensorflow/python/framework/tensor_shape.py
index 1fef77e479..b5462dcd17 100644
--- a/tensorflow/python/framework/tensor_shape.py
+++ b/tensorflow/python/framework/tensor_shape.py
@@ -1,4 +1,8 @@
"""Helper classes for tensor shape inference."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow.python.platform
@@ -163,8 +167,8 @@ class Dimension(object):
else:
return Dimension(self._value * other.value)
- def __div__(self, other):
- """Returns the quotient of `self` and `other`.
+ def __floordiv__(self, other):
+ """Returns the quotient of `self` and `other` rounded down.
Dimensions are summed as follows:
@@ -183,7 +187,7 @@ class Dimension(object):
if self._value is None or other.value is None:
return Dimension(None)
else:
- return Dimension(self._value / other.value)
+ return Dimension(self._value // other.value)
def __mod__(self, other):
"""Returns `self` modulo `other.
@@ -376,10 +380,10 @@ class TensorShape(object):
self._dims = [as_dimension(dims)]
else:
# Got a list of dimensions
- self._dims = map(as_dimension, dims_iter)
+ self._dims = [as_dimension(d) for d in dims_iter]
def __repr__(self):
- return "TensorShape(%s)" % str(self._dims)
+ return "TensorShape(%s)" % self._dims
@property
def dims(self):
@@ -400,10 +404,13 @@ class TensorShape(object):
raise ValueError("Cannot take the length of Shape with unknown rank.")
return len(self._dims)
- def __nonzero__(self):
+ def __bool__(self):
"""Returns True if this shape contains non-zero information."""
return self._dims is not None
+ # Python 3 wants __bool__, Python 2.7 wants __nonzero__
+ __nonzero__ = __bool__
+
def __getitem__(self, key):
"""Returns the value of a dimension or a shape, depending on the key.
@@ -710,7 +717,7 @@ def unknown_shape(ndims=None):
if ndims is None:
return TensorShape(None)
else:
- return TensorShape([Dimension(None) for _ in range(ndims)])
+ return TensorShape([Dimension(None)] * ndims)
def scalar():
diff --git a/tensorflow/python/framework/tensor_shape_test.py b/tensorflow/python/framework/tensor_shape_test.py
index 9743a8d199..be5fbb51cb 100644
--- a/tensorflow/python/framework/tensor_shape_test.py
+++ b/tensorflow/python/framework/tensor_shape_test.py
@@ -1,4 +1,8 @@
"""Functional tests for shape inference helper classes."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow.python.platform
from tensorflow.python.framework import tensor_shape
@@ -19,8 +23,9 @@ class DimensionTest(test_util.TensorFlowTestCase):
self.assertEqual(tensor_shape.Dimension(24),
dim * tensor_shape.Dimension(2))
self.assertEqual(tensor_shape.Dimension(24), dim * 2)
- self.assertEqual(tensor_shape.Dimension(6), dim / tensor_shape.Dimension(2))
- self.assertEqual(tensor_shape.Dimension(6), dim / 2)
+ self.assertEqual(
+ tensor_shape.Dimension(6), dim // tensor_shape.Dimension(2))
+ self.assertEqual(tensor_shape.Dimension(6), dim // 2)
self.assertEqual(tensor_shape.Dimension(12),
dim.merge_with(tensor_shape.Dimension(12)))
self.assertEqual(tensor_shape.Dimension(12), dim.merge_with(12))
@@ -44,8 +49,9 @@ class DimensionTest(test_util.TensorFlowTestCase):
(dim + tensor_shape.Dimension(None)).value)
self.assertEqual(tensor_shape.Dimension(None).value,
(dim * tensor_shape.Dimension(None)).value)
- self.assertEqual(tensor_shape.Dimension(None).value,
- (dim / tensor_shape.Dimension(None)).value)
+ self.assertEqual(
+ tensor_shape.Dimension(None).value,
+ (dim // tensor_shape.Dimension(None)).value)
self.assertEqual(tensor_shape.Dimension(None).value,
dim.merge_with(tensor_shape.Dimension(None)).value)
self.assertIs(None,
@@ -69,9 +75,9 @@ class DimensionTest(test_util.TensorFlowTestCase):
self.assertEqual(
tensor_shape.Dimension(None).value, (unknown * known).value)
self.assertEqual(
- tensor_shape.Dimension(None).value, (known / unknown).value)
+ tensor_shape.Dimension(None).value, (known // unknown).value)
self.assertEqual(
- tensor_shape.Dimension(None).value, (unknown / known).value)
+ tensor_shape.Dimension(None).value, (unknown // known).value)
self.assertEqual(
tensor_shape.Dimension(12), known.merge_with(unknown))
self.assertEqual(
diff --git a/tensorflow/python/framework/tensor_util.py b/tensorflow/python/framework/tensor_util.py
index 43e2128483..e00b3b6d91 100644
--- a/tensorflow/python/framework/tensor_util.py
+++ b/tensorflow/python/framework/tensor_util.py
@@ -1,7 +1,12 @@
"""Utilities to create TensorProtos."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import numbers
import tensorflow.python.platform
import numpy as np
+import six
from tensorflow.core.framework import tensor_pb2
from tensorflow.core.framework import tensor_shape_pb2
@@ -33,12 +38,11 @@ if _FAST_TENSOR_UTIL_AVAILABLE:
np.complex128: fast_tensor_util.AppendComplex128ArrayToTensorProto,
np.object: fast_tensor_util.AppendObjectArrayToTensorProto,
np.bool: fast_tensor_util.AppendBoolArrayToTensorProto,
- types.qint8.as_numpy_dtype:
- fast_tensor_util.AppendInt8ArrayToTensorProto,
+ types.qint8.as_numpy_dtype: fast_tensor_util.AppendInt8ArrayToTensorProto,
types.quint8.as_numpy_dtype:
- fast_tensor_util.AppendUInt8ArrayToTensorProto,
+ fast_tensor_util.AppendUInt8ArrayToTensorProto,
types.qint32.as_numpy_dtype:
- fast_tensor_util.AppendInt32ArrayToTensorProto,
+ fast_tensor_util.AppendInt32ArrayToTensorProto,
# NOTE(touts): Intentionally no way to feed a DT_BFLOAT16.
}
else:
@@ -87,7 +91,7 @@ else:
def GetFromNumpyDTypeDict(dtype_dict, dtype):
# NOTE: dtype_dict.get(dtype) always returns None.
- for key, val in dtype_dict.iteritems():
+ for key, val in six.iteritems(dtype_dict):
if key == dtype:
return val
return None
@@ -524,7 +528,6 @@ def ConstantValue(tensor):
delta = ConstantValue(tensor.op.inputs[2])
if delta is None:
return None
- return np.array(range(start, limit, delta),
- dtype=tensor.dtype.as_numpy_dtype)
+ return np.arange(start, limit, delta, dtype=tensor.dtype.as_numpy_dtype)
else:
return None
diff --git a/tensorflow/python/framework/tensor_util_test.py b/tensorflow/python/framework/tensor_util_test.py
index 7c1c0b8d3e..6f5c34d185 100644
--- a/tensorflow/python/framework/tensor_util_test.py
+++ b/tensorflow/python/framework/tensor_util_test.py
@@ -1,4 +1,8 @@
"""Functional tests for tensor_util."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow.python.platform
import numpy as np
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index 2b683c5123..e645d772e9 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -1,6 +1,9 @@
# pylint: disable=invalid-name
"""Test utils for tensorflow."""
+from __future__ import absolute_import
+from __future__ import division
from __future__ import print_function
+
import contextlib
import math
import re
diff --git a/tensorflow/python/framework/test_util_test.py b/tensorflow/python/framework/test_util_test.py
index 38b8be3c54..fc84797501 100644
--- a/tensorflow/python/framework/test_util_test.py
+++ b/tensorflow/python/framework/test_util_test.py
@@ -1,9 +1,13 @@
"""Tests for tensorflow.ops.test_util."""
+from __future__ import absolute_import
+from __future__ import division
from __future__ import print_function
+
import threading
import tensorflow.python.platform
import numpy as np
+from six.moves import xrange # pylint: disable=redefined-builtin
from google.protobuf import text_format
@@ -58,7 +62,7 @@ class TestUtilTest(test_util.TensorFlowTestCase):
def testCheckedThreadFails(self):
def err_func():
- return 1 / 0
+ return 1 // 0
t = self.checkedThread(target=err_func)
t.start()
diff --git a/tensorflow/python/framework/types.py b/tensorflow/python/framework/types.py
index ed45f1b2d0..ffb9d0f213 100644
--- a/tensorflow/python/framework/types.py
+++ b/tensorflow/python/framework/types.py
@@ -1,4 +1,8 @@
"""Library of dtypes (Tensor element types)."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow.python.platform
import numpy as np
@@ -108,6 +112,11 @@ class DType(object):
issubclass(self.as_numpy_dtype, np.integer))
@property
+ def is_floating(self):
+ """Returns whether this is a (real) floating point type."""
+ return issubclass(self.as_numpy_dtype, np.floating)
+
+ @property
def is_quantized(self):
"""Returns whether this is a quantized data type."""
return self.base_dtype in [qint8, quint8, qint32, bfloat16]
@@ -299,7 +308,7 @@ _TYPE_TO_STRING = {
types_pb2.DT_BFLOAT16_REF: "bfloat16_ref",
}
_STRING_TO_TF = {value: _INTERN_TABLE[key]
- for key, value in _TYPE_TO_STRING.iteritems()}
+ for key, value in _TYPE_TO_STRING.items()}
# Add non-canonical aliases.
_STRING_TO_TF["float"] = float32
_STRING_TO_TF["float_ref"] = float32_ref
diff --git a/tensorflow/python/framework/types_test.py b/tensorflow/python/framework/types_test.py
index 5c50080db3..8686942baf 100644
--- a/tensorflow/python/framework/types_test.py
+++ b/tensorflow/python/framework/types_test.py
@@ -1,5 +1,8 @@
"""Tests for tensorflow.python.framework.importer."""
+from __future__ import absolute_import
+from __future__ import division
from __future__ import print_function
+
import tensorflow.python.platform
import numpy as np
@@ -115,6 +118,18 @@ class TypesTest(test_util.TensorFlowTestCase):
self.assertEqual(types.as_dtype("string").is_integer, False)
self.assertEqual(types.as_dtype("bool").is_integer, False)
+ def testIsFloating(self):
+ self.assertEqual(types.as_dtype("int8").is_floating, False)
+ self.assertEqual(types.as_dtype("int16").is_floating, False)
+ self.assertEqual(types.as_dtype("int32").is_floating, False)
+ self.assertEqual(types.as_dtype("int64").is_floating, False)
+ self.assertEqual(types.as_dtype("uint8").is_floating, False)
+ self.assertEqual(types.as_dtype("complex64").is_floating, False)
+ self.assertEqual(types.as_dtype("float32").is_floating, True)
+ self.assertEqual(types.as_dtype("float64").is_floating, True)
+ self.assertEqual(types.as_dtype("string").is_floating, False)
+ self.assertEqual(types.as_dtype("bool").is_floating, False)
+
def testMinMax(self):
# make sure min/max evaluates for all data types that have min/max
for datatype_enum in types_pb2.DataType.values():
@@ -163,7 +178,7 @@ class TypesTest(test_util.TensorFlowTestCase):
self.assertEquals(dtype.max, np.finfo(numpy_dtype).max)
def testRepr(self):
- for enum, name in types._TYPE_TO_STRING.iteritems():
+ for enum, name in types._TYPE_TO_STRING.items():
dtype = types.DType(enum)
self.assertEquals(repr(dtype), 'tf.' + name)
dtype2 = eval(repr(dtype))
diff --git a/tensorflow/python/kernel_tests/argmax_op_test.py b/tensorflow/python/kernel_tests/argmax_op_test.py
index 2cd6101a87..f5f2569f53 100644
--- a/tensorflow/python/kernel_tests/argmax_op_test.py
+++ b/tensorflow/python/kernel_tests/argmax_op_test.py
@@ -1,4 +1,8 @@
"""Tests for tensorflow.ops.argmax_op."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow.python.platform
import numpy as np
diff --git a/tensorflow/python/kernel_tests/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops_test.py
index 108cc7599e..bf5e98518d 100644
--- a/tensorflow/python/kernel_tests/array_ops_test.py
+++ b/tensorflow/python/kernel_tests/array_ops_test.py
@@ -1,4 +1,8 @@
"""Tests for array_ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import math
import tensorflow.python.platform
diff --git a/tensorflow/python/kernel_tests/attention_ops_test.py b/tensorflow/python/kernel_tests/attention_ops_test.py
index 5541c541b2..500d85e6d0 100644
--- a/tensorflow/python/kernel_tests/attention_ops_test.py
+++ b/tensorflow/python/kernel_tests/attention_ops_test.py
@@ -1,5 +1,9 @@
"""Tests for tensorflow.ops.attention_ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow.python.platform
import tensorflow as tf
@@ -80,8 +84,8 @@ class ExtractGlimpseTest(tf.test.TestCase):
# Check entries.
min_random_val = 0
max_random_val = max(rows, cols)
- for i in range(0, glimpse_sizes[0]):
- for j in range(0, glimpse_sizes[1]):
+ for i in range(glimpse_sizes[0]):
+ for j in range(glimpse_sizes[1]):
if expected_rows[i] is None or expected_cols[j] is None:
self.assertGreaterEqual(value_rows[0][i][j][0], min_random_val)
self.assertLessEqual(value_rows[0][i][j][0], max_random_val)
@@ -102,15 +106,15 @@ class ExtractGlimpseTest(tf.test.TestCase):
self._VerifyValues(tensor_in_sizes=[41, 61],
glimpse_sizes=[41, 61],
offsets=[0.0, 0.0],
- expected_rows=range(1, 42),
- expected_cols=range(1, 62))
+ expected_rows=list(range(1, 42)),
+ expected_cols=list(range(1, 62)))
def testTooLargeCenterGlimpse(self):
self._VerifyValues(tensor_in_sizes=[41, 61],
glimpse_sizes=[43, 63],
offsets=[0.0, 0.0],
- expected_rows=[None] + range(1, 42) + [None],
- expected_cols=[None] + range(1, 62) + [None])
+ expected_rows=[None] + list(range(1, 42)) + [None],
+ expected_cols=[None] + list(range(1, 62)) + [None])
def testGlimpseFullOverlap(self):
self._VerifyValues(tensor_in_sizes=[41, 61],
diff --git a/tensorflow/python/kernel_tests/batch_matmul_op_test.py b/tensorflow/python/kernel_tests/batch_matmul_op_test.py
index 8ae37fec3a..d27b38b40e 100644
--- a/tensorflow/python/kernel_tests/batch_matmul_op_test.py
+++ b/tensorflow/python/kernel_tests/batch_matmul_op_test.py
@@ -1,4 +1,7 @@
"""Tests for tensorflow.ops.tf.BatchMatMul."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
import tensorflow.python.platform
diff --git a/tensorflow/python/kernel_tests/bcast_ops_test.py b/tensorflow/python/kernel_tests/bcast_ops_test.py
index c62a910496..640c29a550 100644
--- a/tensorflow/python/kernel_tests/bcast_ops_test.py
+++ b/tensorflow/python/kernel_tests/bcast_ops_test.py
@@ -1,5 +1,9 @@
"""Tests for tensorflow.kernels.bcast_ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow.python.platform
import tensorflow as tf
diff --git a/tensorflow/python/kernel_tests/bias_op_test.py b/tensorflow/python/kernel_tests/bias_op_test.py
index da706e8259..2badc0f8bc 100644
--- a/tensorflow/python/kernel_tests/bias_op_test.py
+++ b/tensorflow/python/kernel_tests/bias_op_test.py
@@ -1,5 +1,8 @@
"""Functional tests for BiasAdd."""
+from __future__ import absolute_import
+from __future__ import division
from __future__ import print_function
+
import tensorflow.python.platform
import numpy as np
diff --git a/tensorflow/python/kernel_tests/candidate_sampler_ops_test.py b/tensorflow/python/kernel_tests/candidate_sampler_ops_test.py
index a36b8587d5..48857d7150 100644
--- a/tensorflow/python/kernel_tests/candidate_sampler_ops_test.py
+++ b/tensorflow/python/kernel_tests/candidate_sampler_ops_test.py
@@ -1,4 +1,8 @@
"""Tests for CandidateSamplerOp."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow.python.platform
import numpy as np
diff --git a/tensorflow/python/kernel_tests/cast_op_test.py b/tensorflow/python/kernel_tests/cast_op_test.py
index e67b0694c4..b8d9786daa 100644
--- a/tensorflow/python/kernel_tests/cast_op_test.py
+++ b/tensorflow/python/kernel_tests/cast_op_test.py
@@ -1,4 +1,7 @@
"""Tests for tensorflow.ops.tf.cast."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
import tensorflow.python.platform
diff --git a/tensorflow/python/kernel_tests/cholesky_op_test.py b/tensorflow/python/kernel_tests/cholesky_op_test.py
index 17e8d116be..aadf47ba73 100644
--- a/tensorflow/python/kernel_tests/cholesky_op_test.py
+++ b/tensorflow/python/kernel_tests/cholesky_op_test.py
@@ -1,7 +1,12 @@
"""Tests for tensorflow.ops.tf.Cholesky."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow.python.platform
import numpy as np
+from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
diff --git a/tensorflow/python/kernel_tests/clip_ops_test.py b/tensorflow/python/kernel_tests/clip_ops_test.py
index 46bba7514d..7d4ee20a55 100644
--- a/tensorflow/python/kernel_tests/clip_ops_test.py
+++ b/tensorflow/python/kernel_tests/clip_ops_test.py
@@ -1,5 +1,9 @@
"""Tests for tensorflow.ops.clip_ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow.python.platform
import tensorflow as tf
diff --git a/tensorflow/python/kernel_tests/concat_op_test.py b/tensorflow/python/kernel_tests/concat_op_test.py
index 3f6c43f0a6..b495948554 100644
--- a/tensorflow/python/kernel_tests/concat_op_test.py
+++ b/tensorflow/python/kernel_tests/concat_op_test.py
@@ -1,4 +1,8 @@
"""Functional tests for Concat Op."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow.python.platform
import numpy as np
diff --git a/tensorflow/python/kernel_tests/constant_op_test.py b/tensorflow/python/kernel_tests/constant_op_test.py
index 92f9b5fe4a..c02c21250c 100644
--- a/tensorflow/python/kernel_tests/constant_op_test.py
+++ b/tensorflow/python/kernel_tests/constant_op_test.py
@@ -1,4 +1,8 @@
"""Tests for ConstantOp."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow.python.platform
import numpy as np
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index adf3552739..e293dea581 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -1,10 +1,15 @@
# pylint: disable=g-long-lambda
"""Tests for tensorflow.ops.control_flow_ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import math
import tensorflow.python.platform
import numpy as np
+from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
from tensorflow.python.ops import control_flow_ops
@@ -29,7 +34,7 @@ def check_consumers(graph):
for v in op.inputs:
cnt = consumer_count.get(v, 0)
consumer_count[v] = cnt + 1
- for k, v in consumer_count.iteritems():
+ for k, v in consumer_count.items():
if len(k.consumers()) != v:
return False
return True
diff --git a/tensorflow/python/kernel_tests/conv_ops_test.py b/tensorflow/python/kernel_tests/conv_ops_test.py
index 88f37ef952..4d88809c88 100644
--- a/tensorflow/python/kernel_tests/conv_ops_test.py
+++ b/tensorflow/python/kernel_tests/conv_ops_test.py
@@ -1,6 +1,7 @@
"""Functional tests for convolutional operations."""
+from __future__ import absolute_import
+from __future__ import division
from __future__ import print_function
-import math
import tensorflow.python.platform
@@ -372,11 +373,11 @@ class Conv2DTest(tf.test.TestCase):
filter_shape = [filter_rows, filter_cols, in_depth, out_depth]
# TODO(yangke): re-factor the computation of output shape.
if padding == "VALID":
- output_rows = int(math.ceil((input_rows - filter_rows + 1.0) / stride))
- output_cols = int(math.ceil((input_cols - filter_cols + 1.0) / stride))
+ output_rows = (input_rows - filter_rows + stride) // stride
+ output_cols = (input_cols - filter_cols + stride) // stride
else:
- output_rows = int(math.ceil(float(input_rows) / stride))
- output_cols = int(math.ceil(float(input_cols) / stride))
+ output_rows = (input_rows + stride - 1) // stride
+ output_cols = (input_cols + stride - 1) // stride
output_shape = [batch, output_rows, output_cols, out_depth]
input_size = 1
for x in input_shape:
diff --git a/tensorflow/python/kernel_tests/cwise_ops_test.py b/tensorflow/python/kernel_tests/cwise_ops_test.py
index 22491f231a..a225db20d5 100644
--- a/tensorflow/python/kernel_tests/cwise_ops_test.py
+++ b/tensorflow/python/kernel_tests/cwise_ops_test.py
@@ -1,5 +1,9 @@
"""Functional tests for coefficient-wise operations.
"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow.python.platform
import numpy as np
@@ -10,7 +14,8 @@ from tensorflow.python.kernel_tests import gradient_checker as gc
_ADD = lambda x, y: x + y
_SUB = lambda x, y: x - y
_MUL = lambda x, y: x * y
-_DIV = lambda x, y: x / y
+_TRUEDIV = lambda x, y: x / y
+_FLOORDIV = lambda x, y: x // y
_MOD = lambda x, y: x % y
_NEG = lambda x: -x
_ABS = abs
@@ -229,9 +234,10 @@ class BinaryOpTest(tf.test.TestCase):
def _compareBoth(self, x, y, np_func, tf_func):
self._compareCpu(x, y, np_func, tf_func)
- if x.dtype == np.float32 or x.dtype == np.float64:
- self._compareGradientX(x, y, np_func, tf_func)
- self._compareGradientY(x, y, np_func, tf_func)
+ if x.dtype in (np.float32, np.float64):
+ if tf_func not in (_FLOORDIV, tf.floordiv):
+ self._compareGradientX(x, y, np_func, tf_func)
+ self._compareGradientY(x, y, np_func, tf_func)
self._compareGpu(x, y, np_func, tf_func)
def testFloatBasic(self):
@@ -240,11 +246,13 @@ class BinaryOpTest(tf.test.TestCase):
self._compareBoth(x, y, np.add, tf.add)
self._compareBoth(x, y, np.subtract, tf.sub)
self._compareBoth(x, y, np.multiply, tf.mul)
- self._compareBoth(x, y + 0.1, np.divide, tf.div)
+ self._compareBoth(x, y + 0.1, np.true_divide, tf.truediv)
+ self._compareBoth(x, y + 0.1, np.floor_divide, tf.floordiv)
self._compareBoth(x, y, np.add, _ADD)
self._compareBoth(x, y, np.subtract, _SUB)
self._compareBoth(x, y, np.multiply, _MUL)
- self._compareBoth(x, y + 0.1, np.divide, _DIV)
+ self._compareBoth(x, y + 0.1, np.true_divide, _TRUEDIV)
+ self._compareBoth(x, y + 0.1, np.floor_divide, _FLOORDIV)
def testFloatDifferentShapes(self):
x = np.array([1, 2, 3, 4]).reshape(2, 2).astype(np.float32)
@@ -267,11 +275,13 @@ class BinaryOpTest(tf.test.TestCase):
self._compareBoth(x, y, np.add, tf.add)
self._compareBoth(x, y, np.subtract, tf.sub)
self._compareBoth(x, y, np.multiply, tf.mul)
- self._compareBoth(x, y + 0.1, np.divide, tf.div)
+ self._compareBoth(x, y + 0.1, np.true_divide, tf.truediv)
+ self._compareBoth(x, y + 0.1, np.floor_divide, tf.floordiv)
self._compareBoth(x, y, np.add, _ADD)
self._compareBoth(x, y, np.subtract, _SUB)
self._compareBoth(x, y, np.multiply, _MUL)
- self._compareBoth(x, y + 0.1, np.divide, _DIV)
+ self._compareBoth(x, y + 0.1, np.true_divide, _TRUEDIV)
+ self._compareBoth(x, y + 0.1, np.floor_divide, _FLOORDIV)
def testInt8Basic(self):
x = np.arange(1, 13, 2).reshape(1, 3, 2).astype(np.int8)
@@ -291,14 +301,14 @@ class BinaryOpTest(tf.test.TestCase):
self._compareBoth(x, y, np.add, tf.add)
self._compareBoth(x, y, np.subtract, tf.sub)
self._compareBoth(x, y, np.multiply, tf.mul)
- # NOTE: int32 division is ill-defined.
- self._compareBoth(x, y, np.divide, tf.div)
+ self._compareBoth(x, y, np.true_divide, tf.truediv)
+ self._compareBoth(x, y, np.floor_divide, tf.floordiv)
self._compareBoth(x, y, np.mod, tf.mod)
self._compareBoth(x, y, np.add, _ADD)
self._compareBoth(x, y, np.subtract, _SUB)
self._compareBoth(x, y, np.multiply, _MUL)
- # NOTE: int32 division is ill-defined.
- self._compareBoth(x, y, np.divide, _DIV)
+ self._compareBoth(x, y, np.true_divide, _TRUEDIV)
+ self._compareBoth(x, y, np.floor_divide, _FLOORDIV)
self._compareBoth(x, y, np.mod, _MOD)
def testInt64Basic(self):
@@ -306,13 +316,13 @@ class BinaryOpTest(tf.test.TestCase):
y = np.arange(1, 7, 1).reshape(1, 3, 2).astype(np.int64)
self._compareBoth(x, y, np.subtract, tf.sub)
self._compareBoth(x, y, np.multiply, tf.mul)
- # NOTE: int64 division is ill-defined.
- self._compareBoth(x, y, np.divide, tf.div)
+ self._compareBoth(x, y, np.true_divide, tf.truediv)
+ self._compareBoth(x, y, np.floor_divide, tf.floordiv)
self._compareBoth(x, y, np.mod, tf.mod)
self._compareBoth(x, y, np.subtract, _SUB)
self._compareBoth(x, y, np.multiply, _MUL)
- # NOTE: int64 division is ill-defined.
- self._compareBoth(x, y, np.divide, _DIV)
+ self._compareBoth(x, y, np.true_divide, _TRUEDIV)
+ self._compareBoth(x, y, np.floor_divide, _FLOORDIV)
self._compareBoth(x, y, np.mod, _MOD)
def testComplex64Basic(self):
@@ -323,19 +333,20 @@ class BinaryOpTest(tf.test.TestCase):
self._compareCpu(x, y, np.add, tf.add)
self._compareCpu(x, y, np.subtract, tf.sub)
self._compareCpu(x, y, np.multiply, tf.mul)
- self._compareCpu(x, y + 0.1, np.divide, tf.div)
+ self._compareCpu(x, y + 0.1, np.true_divide, tf.truediv)
self._compareCpu(x, y, np.add, _ADD)
self._compareCpu(x, y, np.subtract, _SUB)
self._compareCpu(x, y, np.multiply, _MUL)
- self._compareCpu(x, y + 0.1, np.divide, _DIV)
+ self._compareCpu(x, y + 0.1, np.true_divide, _TRUEDIV)
def _compareBCast(self, xs, ys, dtype, np_func, tf_func):
x = (1 + np.linspace(0, 5, np.prod(xs))).astype(dtype).reshape(xs)
y = (1 + np.linspace(0, 5, np.prod(ys))).astype(dtype).reshape(ys)
self._compareCpu(x, y, np_func, tf_func)
- if x.dtype == np.float32 or x.dtype == np.float64:
- self._compareGradientX(x, y, np_func, tf_func)
- self._compareGradientY(x, y, np_func, tf_func)
+ if x.dtype in (np.float32, np.float64):
+ if tf_func not in (_FLOORDIV, tf.floordiv):
+ self._compareGradientX(x, y, np_func, tf_func)
+ self._compareGradientY(x, y, np_func, tf_func)
self._compareGpu(x, y, np_func, tf_func)
# TODO(josh11b,vrv): Refactor this to use parameterized tests.
@@ -349,6 +360,8 @@ class BinaryOpTest(tf.test.TestCase):
]
for dtype in dtypes:
for (np_func, tf_func) in funcs:
+ if dtype == np.complex64 and tf_func in (_FLOORDIV, tf.floordiv):
+ continue # floordiv makes no sense for complex numbers
self._compareBCast(xs, ys, dtype, np_func, tf_func)
self._compareBCast(ys, xs, dtype, np_func, tf_func)
@@ -376,8 +389,10 @@ class BinaryOpTest(tf.test.TestCase):
def _testBCastD(self, xs, ys):
funcs = [
- (np.divide, tf.div),
- (np.divide, _DIV)
+ (np.true_divide, tf.truediv),
+ (np.floor_divide, tf.floordiv),
+ (np.true_divide, _TRUEDIV),
+ (np.floor_divide, _FLOORDIV),
]
self._testBCastByFunc(funcs, xs, ys)
@@ -574,8 +589,8 @@ class BinaryOpTest(tf.test.TestCase):
self._testBCastD([10, 3, 1, 2], [3, 1, 2])
def testMismatchedDimensions(self):
- for func in [tf.add, tf.sub, tf.mul, tf.div,
- _ADD, _SUB, _MUL, _DIV]:
+ for func in [tf.add, tf.sub, tf.mul, tf.div, _ADD, _SUB, _MUL, _TRUEDIV,
+ _FLOORDIV]:
with self.assertRaisesWithPredicateMatch(
ValueError, lambda e: "Incompatible shapes" in e.message):
func(tf.convert_to_tensor([10.0, 20.0, 30.0]),
@@ -959,10 +974,13 @@ class MathOpsOverloadTest(tf.test.TestCase):
(np.add, _ADD),
(np.subtract, _SUB),
(np.multiply, _MUL),
- (np.divide, _DIV)
+ (np.true_divide, _TRUEDIV),
+ (np.floor_divide, _FLOORDIV),
]
for dtype in dtypes:
for np_func, tf_func in funcs:
+ if dtype == tf.complex64 and tf_func == _FLOORDIV:
+ continue # floordiv makes no sense for complex
self._compareBinary(10, 5, dtype, np_func, tf_func)
# Mod only works for int32 and int64.
for dtype in [tf.int32, tf.int64]:
diff --git a/tensorflow/python/kernel_tests/decode_csv_op_test.py b/tensorflow/python/kernel_tests/decode_csv_op_test.py
index ae0917f8c4..2a7e2f6625 100644
--- a/tensorflow/python/kernel_tests/decode_csv_op_test.py
+++ b/tensorflow/python/kernel_tests/decode_csv_op_test.py
@@ -1,5 +1,9 @@
"""Tests for DecodeCSV op from parsing_ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow.python.platform
import numpy as np
diff --git a/tensorflow/python/kernel_tests/decode_raw_op_test.py b/tensorflow/python/kernel_tests/decode_raw_op_test.py
index abd50a7527..2cee3b1484 100644
--- a/tensorflow/python/kernel_tests/decode_raw_op_test.py
+++ b/tensorflow/python/kernel_tests/decode_raw_op_test.py
@@ -1,5 +1,9 @@
"""Tests for DecodeRaw op from parsing_ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow.python.platform
import tensorflow as tf
diff --git a/tensorflow/python/kernel_tests/dense_update_ops_no_tsan_test.py b/tensorflow/python/kernel_tests/dense_update_ops_no_tsan_test.py
index ad0724931e..4c0caf14af 100644
--- a/tensorflow/python/kernel_tests/dense_update_ops_no_tsan_test.py
+++ b/tensorflow/python/kernel_tests/dense_update_ops_no_tsan_test.py
@@ -1,4 +1,8 @@
"""Tests for state updating ops that may have benign race conditions."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow.python.platform
import numpy as np
diff --git a/tensorflow/python/kernel_tests/dense_update_ops_test.py b/tensorflow/python/kernel_tests/dense_update_ops_test.py
index 2e1ea468c3..1051cd56f3 100644
--- a/tensorflow/python/kernel_tests/dense_update_ops_test.py
+++ b/tensorflow/python/kernel_tests/dense_update_ops_test.py
@@ -1,4 +1,8 @@
"""Tests for tensorflow.ops.tf.Assign*."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow.python.platform
import numpy as np
diff --git a/tensorflow/python/kernel_tests/determinant_op_test.py b/tensorflow/python/kernel_tests/determinant_op_test.py
index d4e2b88339..a478811af3 100644
--- a/tensorflow/python/kernel_tests/determinant_op_test.py
+++ b/tensorflow/python/kernel_tests/determinant_op_test.py
@@ -1,4 +1,8 @@
"""Tests for tensorflow.ops.tf.MatrixDeterminant."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow.python.platform
import numpy as np
diff --git a/tensorflow/python/kernel_tests/diag_op_test.py b/tensorflow/python/kernel_tests/diag_op_test.py
index 7b53ee26fa..0fd87c7e97 100644
--- a/tensorflow/python/kernel_tests/diag_op_test.py
+++ b/tensorflow/python/kernel_tests/diag_op_test.py
@@ -1,3 +1,7 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow.python.platform
import numpy
diff --git a/tensorflow/python/kernel_tests/division_future_test.py b/tensorflow/python/kernel_tests/division_future_test.py
new file mode 100644
index 0000000000..b9b3b13b68
--- /dev/null
+++ b/tensorflow/python/kernel_tests/division_future_test.py
@@ -0,0 +1,50 @@
+"""Tests for division with division imported from __future__.
+
+This file should be exactly the same as division_past_test.py except
+for the __future__ division line.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+
+class DivisionTestCase(tf.test.TestCase):
+
+ def testDivision(self):
+ """Test all the different ways to divide."""
+ values = [1, 2, 7, 11]
+ functions = (lambda x: x), tf.constant
+ # TODO(irving): Test int8, int16 once we support casts for those.
+ dtypes = np.int32, np.int64, np.float32, np.float64
+
+ def check(x, y):
+ if isinstance(x, tf.Tensor):
+ x = x.eval()
+ if isinstance(y, tf.Tensor):
+ y = y.eval()
+ self.assertEqual(x.dtype, y.dtype)
+ self.assertEqual(x, y)
+ with self.test_session():
+ for dtype in dtypes:
+ for x in map(dtype, values):
+ for y in map(dtype, values):
+ for fx in functions:
+ for fy in functions:
+ tf_x = fx(x)
+ tf_y = fy(y)
+ div = x / y
+ tf_div = tf_x / tf_y
+ check(div, tf_div)
+ floordiv = x // y
+ tf_floordiv = tf_x // tf_y
+ check(floordiv, tf_floordiv)
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/kernel_tests/division_past_test.py b/tensorflow/python/kernel_tests/division_past_test.py
new file mode 100644
index 0000000000..b65a724f7e
--- /dev/null
+++ b/tensorflow/python/kernel_tests/division_past_test.py
@@ -0,0 +1,50 @@
+"""Tests for division with division imported from __future__.
+
+This file should be exactly the same as division_past_test.py except
+for the __future__ division line.
+"""
+
+from __future__ import absolute_import
+# from __future__ import division # Intentionally skip this import
+from __future__ import print_function
+
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+
+class DivisionTestCase(tf.test.TestCase):
+
+ def testDivision(self):
+ """Test all the different ways to divide."""
+ values = [1, 2, 7, 11]
+ functions = (lambda x: x), tf.constant
+ # TODO(irving): Test int8, int16 once we support casts for those.
+ dtypes = np.int32, np.int64, np.float32, np.float64
+
+ def check(x, y):
+ if isinstance(x, tf.Tensor):
+ x = x.eval()
+ if isinstance(y, tf.Tensor):
+ y = y.eval()
+ self.assertEqual(x.dtype, y.dtype)
+ self.assertEqual(x, y)
+ with self.test_session():
+ for dtype in dtypes:
+ for x in map(dtype, values):
+ for y in map(dtype, values):
+ for fx in functions:
+ for fy in functions:
+ tf_x = fx(x)
+ tf_y = fy(y)
+ div = x / y
+ tf_div = tf_x / tf_y
+ check(div, tf_div)
+ floordiv = x // y
+ tf_floordiv = tf_x // tf_y
+ check(floordiv, tf_floordiv)
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/kernel_tests/dynamic_partition_op_test.py b/tensorflow/python/kernel_tests/dynamic_partition_op_test.py
index a7a276893d..fdb82d5220 100644
--- a/tensorflow/python/kernel_tests/dynamic_partition_op_test.py
+++ b/tensorflow/python/kernel_tests/dynamic_partition_op_test.py
@@ -1,7 +1,12 @@
"""Tests for the DynamicPartition op."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow.python.platform
import numpy as np
+from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
diff --git a/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py b/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py
index 9ac49390b9..9644b2100d 100644
--- a/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py
+++ b/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py
@@ -1,4 +1,8 @@
"""Tests for tensorflow.ops.data_flow_ops.dynamic_stitch."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow.python.platform
import numpy as np
diff --git a/tensorflow/python/kernel_tests/edit_distance_op_test.py b/tensorflow/python/kernel_tests/edit_distance_op_test.py
index 5919adcfaf..b04720d070 100644
--- a/tensorflow/python/kernel_tests/edit_distance_op_test.py
+++ b/tensorflow/python/kernel_tests/edit_distance_op_test.py
@@ -1,4 +1,8 @@
"""Tests for tensorflow.kernels.edit_distance_op."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow.python.platform
import numpy as np
diff --git a/tensorflow/python/kernel_tests/embedding_ops_test.py b/tensorflow/python/kernel_tests/embedding_ops_test.py
index 50755b6c46..03844d6177 100644
--- a/tensorflow/python/kernel_tests/embedding_ops_test.py
+++ b/tensorflow/python/kernel_tests/embedding_ops_test.py
@@ -1,10 +1,14 @@
"""Functional tests for ops used with embeddings."""
+from __future__ import absolute_import
+from __future__ import division
from __future__ import print_function
+
import itertools
import tensorflow.python.platform
import numpy as np
+from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
from tensorflow.python.kernel_tests import gradient_checker as gc
@@ -12,7 +16,7 @@ from tensorflow.python.kernel_tests import gradient_checker as gc
def _AsLong(array):
"""Casts arrays elements to long type. Used to convert from numpy tf."""
- return [long(x) for x in array]
+ return [int(x) for x in array]
class ScatterAddSubTest(tf.test.TestCase):
@@ -104,7 +108,7 @@ def _EmbeddingParams(num_shards, vocab_size,
feed_dict = {}
if not shape: shape = [10]
assert not vocab_size % num_shards
- shape = [vocab_size / num_shards] + shape
+ shape = [vocab_size // num_shards] + shape
for i in range(num_shards):
param_name = _PName(i)
constant_t = tf.constant(1.0, shape=shape, dtype=dtype,
@@ -130,8 +134,8 @@ def _EmbeddingResult(params, id_vals, num_shards, weight_vals=None):
ids = [ids]
wts = [wts]
for i, wt_val in zip(ids, wts):
- val = np.copy(params[_PName(i % num_shards) + ":0"]
- [i / num_shards, :]) * wt_val
+ val = np.copy(params[_PName(i % num_shards) + ":0"][
+ i // num_shards, :]) * wt_val
if val_aggr is None:
assert wt_aggr is None
val_aggr = val
@@ -258,7 +262,7 @@ class EmbeddingLookupTest(tf.test.TestCase):
self.assertAllEqual(simple, tf.gather(params, ids).eval())
# Run a few random sharded versions
for procs in 1, 2, 3:
- stride = procs * tf.range(0, params.shape[0] / procs)
+ stride = procs * tf.range(0, params.shape[0] // procs)
split_params = [tf.gather(params, stride + p)
for p in xrange(procs)]
sharded = tf.nn.embedding_lookup(split_params, ids).eval()
diff --git a/tensorflow/python/kernel_tests/fifo_queue_test.py b/tensorflow/python/kernel_tests/fifo_queue_test.py
index 11b7d46318..fbbea309e1 100644
--- a/tensorflow/python/kernel_tests/fifo_queue_test.py
+++ b/tensorflow/python/kernel_tests/fifo_queue_test.py
@@ -1,4 +1,8 @@
"""Tests for tensorflow.ops.data_flow_ops.FIFOQueue."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import random
import re
import time
@@ -6,6 +10,7 @@ import time
import tensorflow.python.platform
import numpy as np
+from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
@@ -448,7 +453,9 @@ class FIFOQueueTest(tf.test.TestCase):
elements_enqueued += 1
else:
count = random.randint(0, min(20, 250 - elements_enqueued))
- range_to_enqueue = range(elements_enqueued, elements_enqueued + count)
+ range_to_enqueue = np.arange(elements_enqueued,
+ elements_enqueued + count,
+ dtype=np.int32)
enqueuemany_op.run({enqueuemany_placeholder: range_to_enqueue})
elements_enqueued += count
@@ -459,7 +466,7 @@ class FIFOQueueTest(tf.test.TestCase):
def testMixtureOfDequeueAndDequeueMany(self):
with self.test_session() as sess:
q = tf.FIFOQueue(10, tf.int32, shapes=())
- enqueue_op = q.enqueue_many((range(250),))
+ enqueue_op = q.enqueue_many((np.arange(250, dtype=np.int32),))
dequeued_t = q.dequeue()
count_placeholder = tf.placeholder(tf.int32, shape=())
dequeuemany_t = q.dequeue_many(count_placeholder)
@@ -477,7 +484,9 @@ class FIFOQueueTest(tf.test.TestCase):
elements_dequeued += 1
else:
count = random.randint(0, min(20, 250 - elements_dequeued))
- expected_range = range(elements_dequeued, elements_dequeued + count)
+ expected_range = np.arange(elements_dequeued,
+ elements_dequeued + count,
+ dtype=np.int32)
self.assertAllEqual(
expected_range, dequeuemany_t.eval({count_placeholder: count}))
elements_dequeued += count
@@ -1045,7 +1054,7 @@ class FIFOQueueTest(tf.test.TestCase):
def testBigDequeueMany(self):
with self.test_session() as sess:
q = tf.FIFOQueue(2, tf.int32, ((),))
- elem = range(4)
+ elem = np.arange(4, dtype=np.int32)
enq_list = [q.enqueue((e,)) for e in elem]
deq = q.dequeue_many(4)
diff --git a/tensorflow/python/kernel_tests/gather_op_test.py b/tensorflow/python/kernel_tests/gather_op_test.py
index 39e97531d2..347b2d501e 100644
--- a/tensorflow/python/kernel_tests/gather_op_test.py
+++ b/tensorflow/python/kernel_tests/gather_op_test.py
@@ -1,4 +1,8 @@
"""Tests for tensorflow.ops.tf.gather."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow.python.platform
import numpy as np
diff --git a/tensorflow/python/kernel_tests/gradient_checker.py b/tensorflow/python/kernel_tests/gradient_checker.py
index fe74768986..1adf22d512 100644
--- a/tensorflow/python/kernel_tests/gradient_checker.py
+++ b/tensorflow/python/kernel_tests/gradient_checker.py
@@ -3,6 +3,10 @@
The gradient checker verifies numerically that an op/graph properly
computes the gradients
"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow.python.platform
import numpy as np
diff --git a/tensorflow/python/kernel_tests/gradient_checker_test.py b/tensorflow/python/kernel_tests/gradient_checker_test.py
index a844b7c637..f303141de9 100644
--- a/tensorflow/python/kernel_tests/gradient_checker_test.py
+++ b/tensorflow/python/kernel_tests/gradient_checker_test.py
@@ -1,4 +1,7 @@
"""Tests for tensorflow.kernels.gradient_checker."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
import tensorflow.python.platform
diff --git a/tensorflow/python/kernel_tests/identity_op_py_test.py b/tensorflow/python/kernel_tests/identity_op_py_test.py
index 2209cf08ad..c64b0be950 100644
--- a/tensorflow/python/kernel_tests/identity_op_py_test.py
+++ b/tensorflow/python/kernel_tests/identity_op_py_test.py
@@ -1,4 +1,8 @@
"""Tests for IdentityOp."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow.python.platform
import numpy as np
diff --git a/tensorflow/python/kernel_tests/in_topk_op_test.py b/tensorflow/python/kernel_tests/in_topk_op_test.py
index d2a51788c4..c27cfd34f8 100644
--- a/tensorflow/python/kernel_tests/in_topk_op_test.py
+++ b/tensorflow/python/kernel_tests/in_topk_op_test.py
@@ -1,4 +1,8 @@
"""Tests for PrecisionOp."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow.python.platform
import numpy as np
diff --git a/tensorflow/python/kernel_tests/init_ops_test.py b/tensorflow/python/kernel_tests/init_ops_test.py
index 4ce6081b7b..1b9f1323e8 100644
--- a/tensorflow/python/kernel_tests/init_ops_test.py
+++ b/tensorflow/python/kernel_tests/init_ops_test.py
@@ -1,4 +1,8 @@
"""Tests for tensorflow.ops.ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow.python.platform
import numpy as np
diff --git a/tensorflow/python/kernel_tests/io_ops_test.py b/tensorflow/python/kernel_tests/io_ops_test.py
index 2eb8bdd26f..d1171102cd 100644
--- a/tensorflow/python/kernel_tests/io_ops_test.py
+++ b/tensorflow/python/kernel_tests/io_ops_test.py
@@ -1,6 +1,10 @@
"""Tests for tensorflow.python.ops.io_ops."""
# -*- coding: utf-8 -*-
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tempfile
import tensorflow.python.platform
diff --git a/tensorflow/python/kernel_tests/linalg_grad_test.py b/tensorflow/python/kernel_tests/linalg_grad_test.py
index 50e5328c3e..2ec3ebe938 100644
--- a/tensorflow/python/kernel_tests/linalg_grad_test.py
+++ b/tensorflow/python/kernel_tests/linalg_grad_test.py
@@ -1,4 +1,7 @@
"""Tests for tensorflow.ops.linalg_grad."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
import tensorflow.python.platform
diff --git a/tensorflow/python/kernel_tests/listdiff_op_test.py b/tensorflow/python/kernel_tests/listdiff_op_test.py
index b4607be1fb..14d657c805 100644
--- a/tensorflow/python/kernel_tests/listdiff_op_test.py
+++ b/tensorflow/python/kernel_tests/listdiff_op_test.py
@@ -1,8 +1,13 @@
"""Tests for tensorflow.kernels.listdiff_op."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow.python.platform
import numpy as np
+from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
@@ -69,9 +74,7 @@ class ListDiffTest(tf.test.TestCase):
y = np.random.randint(int_low, int_high, size=y_size)
out_idx = [(entry, pos) for pos, entry in enumerate(x) if entry not in y]
if out_idx:
- out_idx = map(list, zip(*out_idx))
- out = out_idx[0]
- idx = out_idx[1]
+ out, idx = map(list, zip(*out_idx))
else:
out = []
idx = []
@@ -89,7 +92,7 @@ class ListDiffTest(tf.test.TestCase):
x = [1, 2, 3, 4]
y = [5, 6]
out = x
- idx = range(len(x))
+ idx = np.arange(len(x))
self._testListDiff(x, y, out, idx)
def testInt32EmptyX(self):
@@ -103,7 +106,7 @@ class ListDiffTest(tf.test.TestCase):
x = [1, 2, 3, 4]
y = []
out = x
- idx = range(len(x))
+ idx = np.arange(len(x))
self._testListDiff(x, y, out, idx)
def testInt32EmptyXY(self):
diff --git a/tensorflow/python/kernel_tests/logging_ops_test.py b/tensorflow/python/kernel_tests/logging_ops_test.py
index 50e7422878..e9c461d535 100644
--- a/tensorflow/python/kernel_tests/logging_ops_test.py
+++ b/tensorflow/python/kernel_tests/logging_ops_test.py
@@ -1,5 +1,9 @@
"""Tests for tensorflow.kernels.logging_ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow.python.platform
import tensorflow as tf
diff --git a/tensorflow/python/kernel_tests/lookup_table_op_test.py b/tensorflow/python/kernel_tests/lookup_table_op_test.py
index a6b91560f1..7b5942cacd 100644
--- a/tensorflow/python/kernel_tests/lookup_table_op_test.py
+++ b/tensorflow/python/kernel_tests/lookup_table_op_test.py
@@ -1,4 +1,8 @@
"""Tests for lookup table ops from tf."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow.python.platform
import numpy as np
diff --git a/tensorflow/python/kernel_tests/lrn_op_test.py b/tensorflow/python/kernel_tests/lrn_op_test.py
index 85ef65b653..22da927467 100644
--- a/tensorflow/python/kernel_tests/lrn_op_test.py
+++ b/tensorflow/python/kernel_tests/lrn_op_test.py
@@ -1,5 +1,8 @@
"""Tests for local response normalization."""
+from __future__ import absolute_import
+from __future__ import division
from __future__ import print_function
+
import copy
import tensorflow.python.platform
diff --git a/tensorflow/python/kernel_tests/matmul_op_test.py b/tensorflow/python/kernel_tests/matmul_op_test.py
index c38a2a91f1..561ab300ef 100644
--- a/tensorflow/python/kernel_tests/matmul_op_test.py
+++ b/tensorflow/python/kernel_tests/matmul_op_test.py
@@ -1,5 +1,8 @@
"""Tests for tensorflow.ops.math_ops.matmul."""
+from __future__ import absolute_import
+from __future__ import division
from __future__ import print_function
+
import tensorflow.python.platform
import numpy as np
diff --git a/tensorflow/python/kernel_tests/matrix_inverse_op_test.py b/tensorflow/python/kernel_tests/matrix_inverse_op_test.py
index 541a937185..8c1cda5e15 100644
--- a/tensorflow/python/kernel_tests/matrix_inverse_op_test.py
+++ b/tensorflow/python/kernel_tests/matrix_inverse_op_test.py
@@ -1,4 +1,8 @@
"""Tests for tensorflow.ops.math_ops.matrix_inverse."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow.python.platform
import numpy as np
diff --git a/tensorflow/python/kernel_tests/numerics_test.py b/tensorflow/python/kernel_tests/numerics_test.py
index 8cb2fe2f8b..f071eeaeda 100644
--- a/tensorflow/python/kernel_tests/numerics_test.py
+++ b/tensorflow/python/kernel_tests/numerics_test.py
@@ -1,4 +1,8 @@
"""Tests for tensorflow.ops.numerics."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow.python.platform
import numpy as np
diff --git a/tensorflow/python/kernel_tests/pack_op_test.py b/tensorflow/python/kernel_tests/pack_op_test.py
index 5f3b1823c0..14359b28c6 100644
--- a/tensorflow/python/kernel_tests/pack_op_test.py
+++ b/tensorflow/python/kernel_tests/pack_op_test.py
@@ -1,4 +1,8 @@
"""Functional tests for Pack Op."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow.python.platform
import numpy as np
diff --git a/tensorflow/python/kernel_tests/pad_op_test.py b/tensorflow/python/kernel_tests/pad_op_test.py
index cb939a3a9e..95123cf84f 100644
--- a/tensorflow/python/kernel_tests/pad_op_test.py
+++ b/tensorflow/python/kernel_tests/pad_op_test.py
@@ -1,5 +1,9 @@
"""Tests for tensorflow.ops.nn_ops.Pad."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow.python.platform
import numpy as np
diff --git a/tensorflow/python/kernel_tests/parsing_ops_test.py b/tensorflow/python/kernel_tests/parsing_ops_test.py
index 529ad81e16..0398037f41 100644
--- a/tensorflow/python/kernel_tests/parsing_ops_test.py
+++ b/tensorflow/python/kernel_tests/parsing_ops_test.py
@@ -1,5 +1,9 @@
"""Tests for tensorflow.ops.parsing_ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import itertools
import tensorflow.python.platform
@@ -33,7 +37,7 @@ def _compare_output_to_expected(
tester.assertEqual(set(dict_tensors.keys()), set(expected_tensors.keys()))
i = 0 # Index into the flattened output of session.run()
- for k, v in dict_tensors.iteritems():
+ for k, v in dict_tensors.items():
expected_v = expected_tensors[k]
tf.logging.info("Comparing key: %s", k)
if isinstance(v, tf.SparseTensor):
diff --git a/tensorflow/python/kernel_tests/pooling_ops_test.py b/tensorflow/python/kernel_tests/pooling_ops_test.py
index 5a35fd17fc..45779c1661 100644
--- a/tensorflow/python/kernel_tests/pooling_ops_test.py
+++ b/tensorflow/python/kernel_tests/pooling_ops_test.py
@@ -1,5 +1,8 @@
"""Functional tests for pooling operations."""
+from __future__ import absolute_import
+from __future__ import division
from __future__ import print_function
+
import tensorflow.python.platform
import numpy as np
diff --git a/tensorflow/python/kernel_tests/random_ops_test.py b/tensorflow/python/kernel_tests/random_ops_test.py
index aa107a22de..2ba4dac3a1 100644
--- a/tensorflow/python/kernel_tests/random_ops_test.py
+++ b/tensorflow/python/kernel_tests/random_ops_test.py
@@ -1,9 +1,12 @@
"""Tests for tensorflow.ops.random_ops."""
+from __future__ import absolute_import
+from __future__ import division
from __future__ import print_function
import tensorflow.python.platform
import numpy as np
+from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
diff --git a/tensorflow/python/kernel_tests/random_shuffle_queue_test.py b/tensorflow/python/kernel_tests/random_shuffle_queue_test.py
index 343ffdcb76..9078dce6ca 100644
--- a/tensorflow/python/kernel_tests/random_shuffle_queue_test.py
+++ b/tensorflow/python/kernel_tests/random_shuffle_queue_test.py
@@ -1,4 +1,8 @@
"""Tests for tensorflow.ops.data_flow_ops.Queue."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import random
import re
import time
@@ -6,6 +10,7 @@ import time
import tensorflow.python.platform
import numpy as np
+from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
@@ -264,7 +269,7 @@ class RandomShuffleQueueTest(tf.test.TestCase):
for _ in range(8):
float_val, int_val = sess.run(dequeued_t)
results.append((float_val, [int_val[0], int_val[1]]))
- expected = zip(float_elems, int_elems) + zip(float_elems, int_elems)
+ expected = list(zip(float_elems, int_elems)) * 2
self.assertItemsEqual(expected, results)
def testDequeueMany(self):
@@ -1028,7 +1033,7 @@ class RandomShuffleQueueTest(tf.test.TestCase):
def testBigDequeueMany(self):
with self.test_session() as sess:
q = tf.RandomShuffleQueue(2, 0, tf.int32, ((),))
- elem = range(4)
+ elem = np.arange(4, dtype=np.int32)
enq_list = [q.enqueue((e,)) for e in elem]
deq = q.dequeue_many(4)
diff --git a/tensorflow/python/kernel_tests/reader_ops_test.py b/tensorflow/python/kernel_tests/reader_ops_test.py
index 484e3eca43..a92640fdf9 100644
--- a/tensorflow/python/kernel_tests/reader_ops_test.py
+++ b/tensorflow/python/kernel_tests/reader_ops_test.py
@@ -1,5 +1,9 @@
"""Tests for Reader ops from io_ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import os
import tensorflow.python.platform
diff --git a/tensorflow/python/kernel_tests/reduction_ops_test.py b/tensorflow/python/kernel_tests/reduction_ops_test.py
index e5cab62c09..c3b7d7a70d 100644
--- a/tensorflow/python/kernel_tests/reduction_ops_test.py
+++ b/tensorflow/python/kernel_tests/reduction_ops_test.py
@@ -1,4 +1,8 @@
"""Functional tests for reduction ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow.python.platform
import numpy as np
diff --git a/tensorflow/python/kernel_tests/relu_op_test.py b/tensorflow/python/kernel_tests/relu_op_test.py
index be79dd40ac..c3f69ee323 100644
--- a/tensorflow/python/kernel_tests/relu_op_test.py
+++ b/tensorflow/python/kernel_tests/relu_op_test.py
@@ -1,5 +1,8 @@
"""Tests for Relu and ReluGrad."""
+from __future__ import absolute_import
+from __future__ import division
from __future__ import print_function
+
import tensorflow.python.platform
import numpy as np
diff --git a/tensorflow/python/kernel_tests/reshape_op_test.py b/tensorflow/python/kernel_tests/reshape_op_test.py
index 3c91db1221..ef7b8a60ee 100644
--- a/tensorflow/python/kernel_tests/reshape_op_test.py
+++ b/tensorflow/python/kernel_tests/reshape_op_test.py
@@ -1,5 +1,8 @@
"""Tests for tensorflow.ops.reshape_op."""
+from __future__ import absolute_import
+from __future__ import division
from __future__ import print_function
+
import tensorflow.python.platform
import numpy as np
diff --git a/tensorflow/python/kernel_tests/reverse_sequence_op_test.py b/tensorflow/python/kernel_tests/reverse_sequence_op_test.py
index e7d8e70ae8..bc1f06317b 100644
--- a/tensorflow/python/kernel_tests/reverse_sequence_op_test.py
+++ b/tensorflow/python/kernel_tests/reverse_sequence_op_test.py
@@ -1,5 +1,8 @@
"""Tests for tensorflow.ops.reverse_sequence_op."""
+from __future__ import absolute_import
+from __future__ import division
from __future__ import print_function
+
import tensorflow.python.platform
import numpy as np
diff --git a/tensorflow/python/kernel_tests/save_restore_ops_test.py b/tensorflow/python/kernel_tests/save_restore_ops_test.py
index d59d76c58f..ed666e9592 100644
--- a/tensorflow/python/kernel_tests/save_restore_ops_test.py
+++ b/tensorflow/python/kernel_tests/save_restore_ops_test.py
@@ -1,4 +1,8 @@
"""Tests for tensorflow.ops.io_ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow.python.platform
import tensorflow as tf
diff --git a/tensorflow/python/kernel_tests/scatter_ops_test.py b/tensorflow/python/kernel_tests/scatter_ops_test.py
index dd645819a3..cb2211d6a5 100644
--- a/tensorflow/python/kernel_tests/scatter_ops_test.py
+++ b/tensorflow/python/kernel_tests/scatter_ops_test.py
@@ -1,4 +1,8 @@
"""Tests for tensorflow.ops.tf.scatter."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow.python.platform
import numpy as np
diff --git a/tensorflow/python/kernel_tests/segment_reduction_ops_test.py b/tensorflow/python/kernel_tests/segment_reduction_ops_test.py
index 558ce06285..df84519337 100644
--- a/tensorflow/python/kernel_tests/segment_reduction_ops_test.py
+++ b/tensorflow/python/kernel_tests/segment_reduction_ops_test.py
@@ -1,4 +1,8 @@
"""Functional tests for segment reduction ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow.python.platform
import numpy as np
@@ -13,9 +17,8 @@ class SegmentReductionHelper(tf.test.TestCase):
num_elem = 1
for x in input_shape:
num_elem *= x
- values = range(1, num_elem + 1)
- np_values = np.array(values).reshape(input_shape).astype(
- dtype.as_numpy_dtype)
+ values = np.arange(1, num_elem + 1)
+ np_values = values.reshape(input_shape).astype(dtype.as_numpy_dtype)
return tf.constant(values, shape=input_shape,
dtype=dtype), np_values
@@ -68,7 +71,7 @@ class SegmentReductionOpTest(SegmentReductionHelper):
n = 10
shape = [n, 2]
- indices = [int(i / 3) for i in range(n)]
+ indices = [i // 3 for i in range(n)]
for dtype in dtypes:
with self.test_session(use_gpu=False):
tf_x, np_x = self._input(shape, dtype=dtype)
diff --git a/tensorflow/python/kernel_tests/shape_ops_test.py b/tensorflow/python/kernel_tests/shape_ops_test.py
index ad5425e6b5..89658ce851 100644
--- a/tensorflow/python/kernel_tests/shape_ops_test.py
+++ b/tensorflow/python/kernel_tests/shape_ops_test.py
@@ -1,5 +1,8 @@
"""Tests for various tensorflow.ops.tf."""
+from __future__ import absolute_import
+from __future__ import division
from __future__ import print_function
+
import tensorflow.python.platform
import numpy as np
@@ -229,7 +232,7 @@ class TileTest(tf.test.TestCase):
"int64": (tf.int64, int),
"string": (tf.string, str)
}
- for dtype_np, v in types_to_test.iteritems():
+ for dtype_np, v in types_to_test.items():
with self.test_session():
dtype_tf = v[0]
cast = v[1]
diff --git a/tensorflow/python/kernel_tests/slice_op_test.py b/tensorflow/python/kernel_tests/slice_op_test.py
index 62d7e31dfc..d61e6984df 100644
--- a/tensorflow/python/kernel_tests/slice_op_test.py
+++ b/tensorflow/python/kernel_tests/slice_op_test.py
@@ -1,7 +1,12 @@
"""Functional tests for slice op."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow.python.platform
import numpy as np
+from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
diff --git a/tensorflow/python/kernel_tests/softmax_op_test.py b/tensorflow/python/kernel_tests/softmax_op_test.py
index fd25970093..6cd2cec2e0 100644
--- a/tensorflow/python/kernel_tests/softmax_op_test.py
+++ b/tensorflow/python/kernel_tests/softmax_op_test.py
@@ -1,4 +1,8 @@
"""Tests for SoftmaxOp."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow.python.platform
import numpy as np
diff --git a/tensorflow/python/kernel_tests/softplus_op_test.py b/tensorflow/python/kernel_tests/softplus_op_test.py
index 216362340c..da8a2b918b 100644
--- a/tensorflow/python/kernel_tests/softplus_op_test.py
+++ b/tensorflow/python/kernel_tests/softplus_op_test.py
@@ -1,5 +1,8 @@
"""Tests for Softplus and SoftplusGrad."""
+from __future__ import absolute_import
+from __future__ import division
from __future__ import print_function
+
import tensorflow.python.platform
import numpy as np
diff --git a/tensorflow/python/kernel_tests/sparse_concat_op_test.py b/tensorflow/python/kernel_tests/sparse_concat_op_test.py
index 0f5650b89c..4d5ad7d674 100644
--- a/tensorflow/python/kernel_tests/sparse_concat_op_test.py
+++ b/tensorflow/python/kernel_tests/sparse_concat_op_test.py
@@ -1,5 +1,9 @@
"""Tests for SparseConcat."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow.python.platform
import numpy as np
diff --git a/tensorflow/python/kernel_tests/sparse_matmul_op_test.py b/tensorflow/python/kernel_tests/sparse_matmul_op_test.py
index 4529be21fc..ff442dfc53 100644
--- a/tensorflow/python/kernel_tests/sparse_matmul_op_test.py
+++ b/tensorflow/python/kernel_tests/sparse_matmul_op_test.py
@@ -1,4 +1,6 @@
"""Tests for tensorflow.ops.tf.matmul."""
+from __future__ import absolute_import
+from __future__ import division
from __future__ import print_function
import tensorflow.python.platform
diff --git a/tensorflow/python/kernel_tests/sparse_reorder_op_test.py b/tensorflow/python/kernel_tests/sparse_reorder_op_test.py
index c3bcc25311..277cd35de7 100644
--- a/tensorflow/python/kernel_tests/sparse_reorder_op_test.py
+++ b/tensorflow/python/kernel_tests/sparse_reorder_op_test.py
@@ -1,5 +1,9 @@
"""Tests for SparseReorder."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow.python.platform
import numpy as np
diff --git a/tensorflow/python/kernel_tests/sparse_to_dense_op_py_test.py b/tensorflow/python/kernel_tests/sparse_to_dense_op_py_test.py
index 2bab89923e..315defe0b7 100644
--- a/tensorflow/python/kernel_tests/sparse_to_dense_op_py_test.py
+++ b/tensorflow/python/kernel_tests/sparse_to_dense_op_py_test.py
@@ -1,4 +1,8 @@
"""Tests for tensorflow.kernels.sparse_op."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow.python.platform
import numpy as np
diff --git a/tensorflow/python/kernel_tests/sparsemask_op_test.py b/tensorflow/python/kernel_tests/sparsemask_op_test.py
index ffde8f7944..8fa4c1ea8c 100644
--- a/tensorflow/python/kernel_tests/sparsemask_op_test.py
+++ b/tensorflow/python/kernel_tests/sparsemask_op_test.py
@@ -1,3 +1,7 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow.python.platform
import numpy as np
diff --git a/tensorflow/python/kernel_tests/split_op_test.py b/tensorflow/python/kernel_tests/split_op_test.py
index 19906aa02b..aaa4fdaa55 100644
--- a/tensorflow/python/kernel_tests/split_op_test.py
+++ b/tensorflow/python/kernel_tests/split_op_test.py
@@ -1,4 +1,8 @@
"""Functional tests for Split Op."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow.python.platform
import numpy as np
@@ -79,7 +83,7 @@ class SplitOpTest(tf.test.TestCase):
result = sess.run(tf.split(split_dim, num_split, inp))
slices = [slice(0, x) for x in shape]
offset = 0
- length = shape[split_dim] / num_split
+ length = shape[split_dim] // num_split
for i in range(num_split):
slices[split_dim] = slice(offset, offset + length)
offset += length
diff --git a/tensorflow/python/kernel_tests/string_to_hash_bucket_op_test.py b/tensorflow/python/kernel_tests/string_to_hash_bucket_op_test.py
index 8615b271b8..1804b3aeae 100644
--- a/tensorflow/python/kernel_tests/string_to_hash_bucket_op_test.py
+++ b/tensorflow/python/kernel_tests/string_to_hash_bucket_op_test.py
@@ -1,4 +1,8 @@
"""Tests for StringToHashBucket op from string_ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow.python.platform
import tensorflow as tf
diff --git a/tensorflow/python/kernel_tests/string_to_number_op_test.py b/tensorflow/python/kernel_tests/string_to_number_op_test.py
index 39505e18ba..2358975549 100644
--- a/tensorflow/python/kernel_tests/string_to_number_op_test.py
+++ b/tensorflow/python/kernel_tests/string_to_number_op_test.py
@@ -1,5 +1,9 @@
"""Tests for StringToNumber op from parsing_ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow.python.platform
import tensorflow as tf
diff --git a/tensorflow/python/kernel_tests/summary_image_op_test.py b/tensorflow/python/kernel_tests/summary_image_op_test.py
index dfdb2c8938..d51681b8e8 100644
--- a/tensorflow/python/kernel_tests/summary_image_op_test.py
+++ b/tensorflow/python/kernel_tests/summary_image_op_test.py
@@ -1,7 +1,12 @@
"""Tests for summary image op."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow.python.platform
import numpy as np
+from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
from tensorflow.python.ops import image_ops
@@ -32,7 +37,7 @@ class SummaryImageOpTest(tf.test.TestCase):
scale = 127 / np.abs(const.reshape(4, -1)).max(axis=1)
offset = 128
adjusted = np.floor(scale[:, None, None, None] * const + offset)
- const[0, 1, 2, depth / 2] = np.nan
+ const[0, 1, 2, depth // 2] = np.nan
# Summarize
summ = tf.image_summary("img", const)
diff --git a/tensorflow/python/kernel_tests/summary_ops_test.py b/tensorflow/python/kernel_tests/summary_ops_test.py
index 13e5021ccc..9a59c26787 100644
--- a/tensorflow/python/kernel_tests/summary_ops_test.py
+++ b/tensorflow/python/kernel_tests/summary_ops_test.py
@@ -1,4 +1,8 @@
"""Tests for summary ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow.python.platform
import tensorflow as tf
diff --git a/tensorflow/python/kernel_tests/topk_op_test.py b/tensorflow/python/kernel_tests/topk_op_test.py
index 497dc9ac1e..6675de31b1 100644
--- a/tensorflow/python/kernel_tests/topk_op_test.py
+++ b/tensorflow/python/kernel_tests/topk_op_test.py
@@ -1,4 +1,8 @@
"""Tests for TopK op."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow.python.platform
import numpy as np
diff --git a/tensorflow/python/kernel_tests/transpose_op_test.py b/tensorflow/python/kernel_tests/transpose_op_test.py
index 2786eaf37b..e424a4faed 100644
--- a/tensorflow/python/kernel_tests/transpose_op_test.py
+++ b/tensorflow/python/kernel_tests/transpose_op_test.py
@@ -1,4 +1,8 @@
"""Functional tests for Transpose op."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import itertools
import tensorflow.python.platform
diff --git a/tensorflow/python/kernel_tests/unique_op_test.py b/tensorflow/python/kernel_tests/unique_op_test.py
index 4d6543a206..3585152b06 100644
--- a/tensorflow/python/kernel_tests/unique_op_test.py
+++ b/tensorflow/python/kernel_tests/unique_op_test.py
@@ -1,4 +1,8 @@
"""Tests for tensorflow.kernels.unique_op."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow.python.platform
import numpy as np
diff --git a/tensorflow/python/kernel_tests/unpack_op_test.py b/tensorflow/python/kernel_tests/unpack_op_test.py
index 4929af035f..7f5ed77a0d 100644
--- a/tensorflow/python/kernel_tests/unpack_op_test.py
+++ b/tensorflow/python/kernel_tests/unpack_op_test.py
@@ -1,7 +1,12 @@
"""Functional tests for Unpack Op."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow.python.platform
import numpy as np
+from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
from tensorflow.python.kernel_tests import gradient_checker
diff --git a/tensorflow/python/kernel_tests/variable_ops_test.py b/tensorflow/python/kernel_tests/variable_ops_test.py
index 6185b2256e..e4c9849729 100644
--- a/tensorflow/python/kernel_tests/variable_ops_test.py
+++ b/tensorflow/python/kernel_tests/variable_ops_test.py
@@ -1,4 +1,8 @@
"""Tests for tensorflow.ops.tf.variable_op."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow.python.platform
import numpy as np
diff --git a/tensorflow/python/kernel_tests/variable_scope_test.py b/tensorflow/python/kernel_tests/variable_scope_test.py
index bb538198ea..f3521d41af 100644
--- a/tensorflow/python/kernel_tests/variable_scope_test.py
+++ b/tensorflow/python/kernel_tests/variable_scope_test.py
@@ -1,4 +1,8 @@
"""Tests for variable store."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow.python.platform
import tensorflow as tf
diff --git a/tensorflow/python/kernel_tests/variables_test.py b/tensorflow/python/kernel_tests/variables_test.py
index f2a7ea0af8..0b17ffa82d 100644
--- a/tensorflow/python/kernel_tests/variables_test.py
+++ b/tensorflow/python/kernel_tests/variables_test.py
@@ -1,4 +1,8 @@
"""Tests for tf.py."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import operator
import tensorflow.python.platform
diff --git a/tensorflow/python/kernel_tests/where_op_test.py b/tensorflow/python/kernel_tests/where_op_test.py
index 263f98f622..ebb0a6d881 100644
--- a/tensorflow/python/kernel_tests/where_op_test.py
+++ b/tensorflow/python/kernel_tests/where_op_test.py
@@ -1,4 +1,8 @@
"""Tests for tensorflow.ops.reverse_sequence_op."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow.python.platform
import numpy as np
diff --git a/tensorflow/python/kernel_tests/xent_op_test.py b/tensorflow/python/kernel_tests/xent_op_test.py
index c6ecaff799..7d1c2d5a22 100644
--- a/tensorflow/python/kernel_tests/xent_op_test.py
+++ b/tensorflow/python/kernel_tests/xent_op_test.py
@@ -1,5 +1,8 @@
"""Tests for SoftmaxCrossEntropyWithLogits op."""
+from __future__ import absolute_import
+from __future__ import division
from __future__ import print_function
+
import tensorflow.python.platform
import numpy as np
diff --git a/tensorflow/python/lib/core/pywrap_status_test.py b/tensorflow/python/lib/core/pywrap_status_test.py
index 000a784b6c..ef4127dc53 100644
--- a/tensorflow/python/lib/core/pywrap_status_test.py
+++ b/tensorflow/python/lib/core/pywrap_status_test.py
@@ -1,5 +1,9 @@
"""Tests for SWIG wrapped brain::Status."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
from tensorflow.core.lib.core import error_codes_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.platform import googletest
diff --git a/tensorflow/python/lib/io/python_io.py b/tensorflow/python/lib/io/python_io.py
index aedcd2ef03..eabe3d49dc 100644
--- a/tensorflow/python/lib/io/python_io.py
+++ b/tensorflow/python/lib/io/python_io.py
@@ -26,4 +26,8 @@ and the mask of a CRC is
masked_crc = ((crc >> 15) | (crc << 17)) + 0xa282ead8ul
"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
from tensorflow.python.lib.io.tf_record import *
diff --git a/tensorflow/python/lib/io/tf_record.py b/tensorflow/python/lib/io/tf_record.py
index 00825bbda2..0ba943e354 100644
--- a/tensorflow/python/lib/io/tf_record.py
+++ b/tensorflow/python/lib/io/tf_record.py
@@ -1,5 +1,9 @@
"""For reading and writing TFRecords files."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
from tensorflow.python import pywrap_tensorflow
diff --git a/tensorflow/python/ops/array_grad.py b/tensorflow/python/ops/array_grad.py
index 2a463940d6..311cbd2558 100644
--- a/tensorflow/python/ops/array_grad.py
+++ b/tensorflow/python/ops/array_grad.py
@@ -1,5 +1,9 @@
"""Gradients for operators defined in array_ops.py."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import constant_op
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index ed780db625..28138fbf39 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -42,6 +42,10 @@ or join multiple tensors together.
@@dynamic_partition
@@dynamic_stitch
"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import sys
import tensorflow.python.platform
import numpy as np
@@ -852,7 +856,7 @@ def _ReshapeShape(op):
if num_elements % known_elements != 0:
raise ValueError("input has %s elements, which isn't divisible by %d" %
(num_elements, known_elements))
- new_shape[unknown_index] = num_elements / known_elements
+ new_shape[unknown_index] = num_elements // known_elements
return [tensor_shape.TensorShape(new_shape)]
else:
# We don't know the input shape, but we know n-1 of the dimensions
@@ -1042,7 +1046,7 @@ def _SplitShape(op):
"dimension but got split_dim %d (size = %d) and num_split %d" %
(split_dim, input_shape[split_dim].value, num_split))
prefix = input_shape[:split_dim]
- size_in_split_dim = input_shape[split_dim] / num_split
+ size_in_split_dim = input_shape[split_dim] // num_split
suffix = input_shape[split_dim + 1:]
output_shape = prefix.concatenate(size_in_split_dim).concatenate(suffix)
return [output_shape] * num_split
@@ -1091,7 +1095,7 @@ def _TileGradShape(op):
else:
output_dims = []
for i, dim in enumerate(input_shape.dims):
- output_dims.append(dim / multiples[i])
+ output_dims.append(dim // multiples[i])
return [tensor_shape.TensorShape(output_dims)]
diff --git a/tensorflow/python/ops/attention_ops.py b/tensorflow/python/ops/attention_ops.py
index 4829bcd7cd..59c4f3783b 100644
--- a/tensorflow/python/ops/attention_ops.py
+++ b/tensorflow/python/ops/attention_ops.py
@@ -1,5 +1,9 @@
"""Operations for implementing attention.
"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow.python.platform
from tensorflow.python.framework import ops
diff --git a/tensorflow/python/ops/candidate_sampling_ops.py b/tensorflow/python/ops/candidate_sampling_ops.py
index 06857c0adc..3e3520cbe4 100644
--- a/tensorflow/python/ops/candidate_sampling_ops.py
+++ b/tensorflow/python/ops/candidate_sampling_ops.py
@@ -1,5 +1,9 @@
"""Wrappers for primitive Neural Net (NN) Operations."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import tensor_shape
diff --git a/tensorflow/python/ops/clip_ops.py b/tensorflow/python/ops/clip_ops.py
index 08781932f9..3dedd33cb9 100644
--- a/tensorflow/python/ops/clip_ops.py
+++ b/tensorflow/python/ops/clip_ops.py
@@ -1,4 +1,7 @@
"""Operations for clipping (gradient, weight) tensors to min/max values."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
import collections
diff --git a/tensorflow/python/ops/common_shapes.py b/tensorflow/python/ops/common_shapes.py
index c41d1ff71d..8bab14c186 100644
--- a/tensorflow/python/ops/common_shapes.py
+++ b/tensorflow/python/ops/common_shapes.py
@@ -1,5 +1,7 @@
"""A library of common shape functions."""
-import math
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
from tensorflow.python.framework import tensor_shape
@@ -114,12 +116,10 @@ def _Get2DOutputSize(input_height, input_width, filter_height, filter_width,
if input_height.value is None or filter_height.value is None:
out_rows = None
elif padding_type == "VALID":
- out_rows = int(
- math.ceil((input_height.value - filter_height.value + 1.0)
- / row_stride))
+ out_rows = ((input_height.value - filter_height.value + row_stride) //
+ row_stride)
elif padding_type == "SAME":
- out_rows = int(math.ceil(input_height.value * 1.0
- / row_stride))
+ out_rows = (input_height.value + row_stride - 1) // row_stride
else:
raise ValueError("Invalid value for padding: %r" % padding_type)
@@ -127,11 +127,10 @@ def _Get2DOutputSize(input_height, input_width, filter_height, filter_width,
if input_width.value is None or filter_width.value is None:
out_cols = None
elif padding_type == "VALID":
- out_cols = int(
- math.ceil((input_width.value - filter_width.value + 1.0)
- / col_stride))
+ out_cols = ((input_width.value - filter_width.value + col_stride) //
+ col_stride)
elif padding_type == "SAME":
- out_cols = int(math.ceil(input_width.value * 1.0 / col_stride))
+ out_cols = (input_width.value + col_stride - 1) // col_stride
return out_rows, out_cols
@@ -357,8 +356,8 @@ def max_pool_shape(op):
if stride_d != ksize_d:
raise ValueError("Depthwise max pooling requires the depth window "
"to equal the depth stride.")
- return [tensor_shape.TensorShape(
- [batch_size, in_rows, in_cols, depth / ksize_d])]
+ return [tensor_shape.TensorShape([batch_size, in_rows, in_cols, depth //
+ ksize_d])]
def no_outputs(unused_op):
diff --git a/tensorflow/python/ops/constant_op.py b/tensorflow/python/ops/constant_op.py
index 58e4222bc7..2e8fed0b49 100644
--- a/tensorflow/python/ops/constant_op.py
+++ b/tensorflow/python/ops/constant_op.py
@@ -79,10 +79,13 @@ print sess.run(var)
@@set_random_seed
"""
-"""Constant Operation.
-Has to be separate from array_ops to avoid a cyclic dependency.
-"""
+# Must be separate from array_ops to avoid a cyclic dependency.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow.python.platform
import numpy as np
diff --git a/tensorflow/python/ops/control_flow_grad.py b/tensorflow/python/ops/control_flow_grad.py
index 3a1a5b91c0..d6a0c6e6c2 100644
--- a/tensorflow/python/ops/control_flow_grad.py
+++ b/tensorflow/python/ops/control_flow_grad.py
@@ -1,4 +1,9 @@
"""Gradients for operators defined in control_flow_ops.py."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
# pylint: disable=wildcard-import,undefined-variable
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index 2c089deed2..fe7c73bd31 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -47,6 +47,12 @@ debug your graph.
@@Assert
@@Print
"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import six
+from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import types
@@ -1230,12 +1236,12 @@ def group(*inputs, **kwargs):
ops_on_device[dev] = [inp]
if len(ops_on_device) == 1:
# 1-level tree. The root node is the returned NoOp node.
- dev, deps = ops_on_device.items()[0]
+ (dev, deps), = ops_on_device.items()
return _GroupControlDeps(dev, deps, name=name)
# 2-level tree. The root node is the returned NoOp node.
# deps contains 1 NoOp node for each device.
deps = []
- for dev in sorted(ops_on_device.iterkeys()):
+ for dev in sorted(six.iterkeys(ops_on_device)):
deps.append(_GroupControlDeps(dev, ops_on_device[dev]))
return _GroupControlDeps(None, deps, name=name)
@@ -1409,7 +1415,7 @@ def case(pred_fn_pairs, default, exclusive=False, name="Case"):
if not exclusive:
logging.warn("%s: Provided dictionary of predicate/fn pairs, but "
"exclusive=False. Order of conditional tests is "
- "not guaranteed." % name)
+ "not guaranteed.", name)
for tup in pfp:
if not isinstance(tup, _basetuple) or len(tup) != 2:
raise TypeError("Each entry in pred_fn_pairs must be a 2-tuple")
diff --git a/tensorflow/python/ops/control_flow_ops_test.py b/tensorflow/python/ops/control_flow_ops_test.py
index 34b1ab0a25..35937a65f7 100644
--- a/tensorflow/python/ops/control_flow_ops_test.py
+++ b/tensorflow/python/ops/control_flow_ops_test.py
@@ -1,4 +1,8 @@
"""Tests for control_flow_ops.py."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow.python.platform
from tensorflow.core.framework import graph_pb2
diff --git a/tensorflow/python/ops/data_flow_grad.py b/tensorflow/python/ops/data_flow_grad.py
index d2473490ce..248ce44533 100644
--- a/tensorflow/python/ops/data_flow_grad.py
+++ b/tensorflow/python/ops/data_flow_grad.py
@@ -1,5 +1,9 @@
"""Gradients for operators defined in data_flow_ops.py."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.python.framework import ops
from tensorflow.python.framework import types
from tensorflow.python.ops import array_ops
@@ -13,13 +17,13 @@ from tensorflow.python.ops import math_ops
def _DynamicStitchGrads(op, grad):
"""Gradients for DynamicStitch."""
- num_values = len(op.inputs) / 2
+ num_values = len(op.inputs) // 2
indices_grad = [None] * num_values
def AsInt32(x):
return (x if op.inputs[0].dtype == types.int32 else
math_ops.cast(x, types.int32))
- inputs = [AsInt32(op.inputs[i]) for i in range(num_values)]
+ inputs = [AsInt32(op.inputs[i]) for i in xrange(num_values)]
if isinstance(grad, ops.IndexedSlices):
output_shape = array_ops.shape(op.outputs[0])
output_rows = output_shape[0]
diff --git a/tensorflow/python/ops/data_flow_ops.py b/tensorflow/python/ops/data_flow_ops.py
index b625aba899..178f716e48 100644
--- a/tensorflow/python/ops/data_flow_ops.py
+++ b/tensorflow/python/ops/data_flow_ops.py
@@ -1,5 +1,9 @@
"""Data Flow Operations."""
# pylint: disable=g-bad-name
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import re
from tensorflow.python.framework import ops
diff --git a/tensorflow/python/ops/embedding_ops.py b/tensorflow/python/ops/embedding_ops.py
index 6c73e687d6..80bedd4984 100644
--- a/tensorflow/python/ops/embedding_ops.py
+++ b/tensorflow/python/ops/embedding_ops.py
@@ -1,5 +1,9 @@
"""Operations for embeddings."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.python.framework import ops
from tensorflow.python.framework import types
from tensorflow.python.ops import array_ops
@@ -59,10 +63,10 @@ def embedding_lookup(params, ids, name=None):
np)
# Do np separate lookups, finding embeddings for plist[p] in params[p]
partitioned_result = []
- for p in range(np):
+ for p in xrange(np):
# TODO(agarwal): handle device allocations here and later in the
# colocate code.
- gather_ids = plist[p] / np
+ gather_ids = plist[p] // np
with ops.device(params[p].device):
partitioned_result.append(array_ops.gather(params[p], gather_ids))
# Stitch these back together
diff --git a/tensorflow/python/ops/gradients.py b/tensorflow/python/ops/gradients.py
index ffa7828c04..e4fe65c5a1 100644
--- a/tensorflow/python/ops/gradients.py
+++ b/tensorflow/python/ops/gradients.py
@@ -1,11 +1,16 @@
"""Implements the graph generation for computation of gradients."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import collections
import warnings
import tensorflow.python.platform
import numpy as np
+from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
@@ -407,7 +412,7 @@ def gradients(ys, xs, grad_ys=None, name="gradients",
with ops.name_scope(op.name + "_grad"):
# pylint: disable=protected-access
with ops.get_default_graph()._original_op(op):
- # pylint: enable=protected-access
+ # pylint: enable=protected-access
op_wrapper = op
if has_control_flow:
op_wrapper = control_flow_ops.MakeWrapper(op)
diff --git a/tensorflow/python/ops/gradients_test.py b/tensorflow/python/ops/gradients_test.py
index 96af01e1d8..11e1a4a88d 100644
--- a/tensorflow/python/ops/gradients_test.py
+++ b/tensorflow/python/ops/gradients_test.py
@@ -1,4 +1,8 @@
"""Tests for tensorflow.ops.gradients."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import warnings
import tensorflow.python.platform
diff --git a/tensorflow/python/ops/image_ops.py b/tensorflow/python/ops/image_ops.py
index 189b72e0c7..4cc3a4e1f3 100644
--- a/tensorflow/python/ops/image_ops.py
+++ b/tensorflow/python/ops/image_ops.py
@@ -86,6 +86,10 @@ adjustments are often useful to expand a training set and reduce overfitting.
@@per_image_whitening
"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import math
import tensorflow.python.platform
@@ -407,16 +411,16 @@ def resize_image_with_crop_or_pad(image, target_height, target_width):
offset_crop_width = 0
offset_pad_width = 0
if target_width < original_width:
- offset_crop_width = int((original_width - target_width) / 2)
+ offset_crop_width = (original_width - target_width) // 2
elif target_width > original_width:
- offset_pad_width = int((target_width - original_width) / 2)
+ offset_pad_width = (target_width - original_width) // 2
offset_crop_height = 0
offset_pad_height = 0
if target_height < original_height:
- offset_crop_height = int((original_height - target_height) / 2)
+ offset_crop_height = (original_height - target_height) // 2
elif target_height > original_height:
- offset_pad_height = int((target_height - original_height) / 2)
+ offset_pad_height = (target_height - original_height) // 2
# Maybe crop if needed.
cropped = crop_to_bounding_box(image, offset_crop_height, offset_crop_width,
diff --git a/tensorflow/python/ops/image_ops_test.py b/tensorflow/python/ops/image_ops_test.py
index 2c51299198..c8a5eb3756 100644
--- a/tensorflow/python/ops/image_ops_test.py
+++ b/tensorflow/python/ops/image_ops_test.py
@@ -1,9 +1,14 @@
"""Tests for tensorflow.ops.image_ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import math
import tensorflow.python.platform
import numpy as np
+from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.python.framework import test_util
from tensorflow.python.ops import constant_op
@@ -683,7 +688,7 @@ class JpegTest(test_util.TensorFlowTestCase):
def averageError(self, image0, image1):
self.assertEqual(image0.shape, image1.shape)
image0 = image0.astype(int) # Avoid overflow
- return np.abs(image0 - image1).sum() / float(np.prod(image0.shape))
+ return np.abs(image0 - image1).sum() / np.prod(image0.shape)
def testExisting(self):
# Read a real jpeg and verify shape
diff --git a/tensorflow/python/ops/init_ops.py b/tensorflow/python/ops/init_ops.py
index fe074d2556..610feb742e 100644
--- a/tensorflow/python/ops/init_ops.py
+++ b/tensorflow/python/ops/init_ops.py
@@ -1,4 +1,7 @@
"""Operations often used for initializing tensors."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
import math
from tensorflow.python.framework import types
@@ -121,11 +124,12 @@ def uniform_unit_scaling_initializer(factor=1.0, seed=None):
# is the right thing for matrix multiply and convolutions (see above).
for dim in shape[:-1]:
input_size *= float(dim)
- max_val = math.sqrt(float(3) / float(input_size)) * factor
+ max_val = math.sqrt(3 / input_size) * factor
return random_ops.random_uniform(shape, -max_val, max_val,
dtype, seed=seed)
return _initializer
+
# TODO(vrv): Unhide when we are ready to expose this publicly.
def _random_walk(shape, nonlinearity, dtype=types.float32, seed=None,
name="random_walk"):
@@ -151,9 +155,9 @@ def _random_walk(shape, nonlinearity, dtype=types.float32, seed=None,
# layer widths.
rwg = 1.13
elif nonlinearity == array_ops.identity:
- rwg = math.exp(1.0 / float(2.0 * num_inputs))
+ rwg = math.exp(1.0 / (2.0 * num_inputs))
elif nonlinearity == nn_ops.relu:
- rwg = math.sqrt(2.0) * math.exp(1.2 / float(max(num_inputs, 6) - 2.4))
+ rwg = math.sqrt(2.0) * math.exp(1.2 / (max(num_inputs, 6) - 2.4))
else:
assert False, "Unsupported nonlinearity for Random Walk initialization."
diff --git a/tensorflow/python/ops/io_ops.py b/tensorflow/python/ops/io_ops.py
index 829081fe58..282a520546 100644
--- a/tensorflow/python/ops/io_ops.py
+++ b/tensorflow/python/ops/io_ops.py
@@ -101,6 +101,10 @@ want them run by N threads.
@@shuffle_batch_join
"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import types
diff --git a/tensorflow/python/ops/linalg_grad.py b/tensorflow/python/ops/linalg_grad.py
index 893618c9dd..90de52887b 100644
--- a/tensorflow/python/ops/linalg_grad.py
+++ b/tensorflow/python/ops/linalg_grad.py
@@ -1,4 +1,8 @@
"""Gradients for operators defined in linalg_ops.py."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
diff --git a/tensorflow/python/ops/linalg_ops.py b/tensorflow/python/ops/linalg_ops.py
index 76fd83fb3d..22f4d72ff9 100644
--- a/tensorflow/python/ops/linalg_ops.py
+++ b/tensorflow/python/ops/linalg_ops.py
@@ -1,5 +1,9 @@
"""Operations for linear algebra."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import gen_linalg_ops
diff --git a/tensorflow/python/ops/logging_ops.py b/tensorflow/python/ops/logging_ops.py
index daf208da9e..c4337268bf 100644
--- a/tensorflow/python/ops/logging_ops.py
+++ b/tensorflow/python/ops/logging_ops.py
@@ -1,5 +1,9 @@
"""Logging Operations."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
from tensorflow.python.framework import ops
from tensorflow.python.ops import common_shapes
from tensorflow.python.ops import gen_logging_ops
diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py
index cb808ff5b8..b404fbc7d7 100644
--- a/tensorflow/python/ops/math_grad.py
+++ b/tensorflow/python/ops/math_grad.py
@@ -1,4 +1,7 @@
"""Gradients for operators defined in math_ops.py."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
from tensorflow.python.framework import ops
from tensorflow.python.framework import types
@@ -29,7 +32,7 @@ def _ReductionGradAssist(op):
def _SumGrad(op, grad):
"""Gradient for Sum."""
_, new_output_shape, input_shape = _ReductionGradAssist(op)
- tile_scaling = input_shape / new_output_shape
+ tile_scaling = input_shape // new_output_shape
grad = array_ops.reshape(grad, new_output_shape)
return [array_ops.tile(grad, tile_scaling), None]
@@ -61,7 +64,7 @@ def _MeanGrad(op, grad):
sum_grad = _SumGrad(op, grad)[0]
input_shape = array_ops.shape(op.inputs[0])
output_shape = array_ops.shape(op.outputs[0])
- factor = (math_ops.reduce_prod(input_shape) /
+ factor = (math_ops.reduce_prod(input_shape) //
math_ops.reduce_prod(output_shape))
return sum_grad / math_ops.cast(factor, sum_grad.dtype), None
@@ -71,7 +74,7 @@ def _ProdGrad(op, grad):
"""Gradient for Prod."""
# TODO(kearnes): this gives NaNs for 0s in the input tensor
_, new_output_shape, input_shape = _ReductionGradAssist(op)
- tile_scaling = input_shape / new_output_shape
+ tile_scaling = input_shape // new_output_shape
grad = array_ops.reshape(grad * op.outputs[0], new_output_shape)
grad = math_ops.div(array_ops.tile(grad, tile_scaling), op.inputs[0])
return grad, None
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index b0e548be70..f7289ff234 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -130,11 +130,14 @@ a tensor.
@@invert_permutation
"""
-import itertools
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
import tensorflow.python.platform
import numpy as np
+import six.moves
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
@@ -415,10 +418,105 @@ def _OverrideBinaryOperatorHelper(func, op_name):
del r_binary_op_wrapper
+# Conversion table for __truediv__. None entries mean no conversion required.
+_TRUEDIV_TABLE = {
+ types.uint8: types.float32,
+ types.int8: types.float32,
+ types.int16: types.float32,
+ types.int32: types.float64,
+ types.int64: types.float64,
+ types.float32: None,
+ types.float64: None,
+ types.complex64: None,
+}
+
+
+def truediv(x, y, name=None):
+ """Divides x / y elementwise, always producing floating point results.
+
+ The same as `tf.div` for floating point arguments, but casts integer arguments
+ to floating point before dividing so that the result is always floating point.
+ This op is generated by normal `x / y` division in Python 3 and in Python 2.7
+ with `from __future__ import division`. If you want integer division that
+ rounds down, use `x // y` or `tf.floordiv`.
+
+ `x` and `y` must have the same numeric type. If the inputs are floating
+ point, the output will have the same type. If the inputs are integral, the
+ inputs are cast to `float32` for `int8` and `int16` and `float64` for `int32`
+ and `int64` (matching the behavior of Numpy).
+
+ Args:
+ x: `Tensor` numerator of numeric type.
+ y: `Tensor` denominator of numeric type.
+ name: A name for the operation (optional).
+
+ Returns:
+ `x / y` evaluated in floating point.
+
+ Raises:
+ TypeError: If `x` and `y` have different dtypes.
+ """
+ with ops.op_scope([x, y], name, "truediv") as name:
+ x = ops.convert_to_tensor(x, name="x")
+ y = ops.convert_to_tensor(y, name="y")
+ x_dtype = x.dtype.base_dtype
+ y_dtype = y.dtype.base_dtype
+ if x_dtype != y_dtype:
+ raise TypeError("x and y must have the same dtype, got %r != %r" %
+ (x_dtype, y_dtype))
+ try:
+ dtype = _TRUEDIV_TABLE[x_dtype]
+ except KeyError:
+ raise TypeError("Invalid dtype %r in __truediv__" % x_dtype)
+ if dtype is not None:
+ x = cast(x, dtype)
+ y = cast(y, dtype)
+ return div(x, y, name=name)
+
+
+def floordiv(x, y, name=None):
+ """Divides `x / y` elementwise, rounding down for floating point.
+
+ The same as `tf.div(x,y)`, but uses `tf.floor(tf.div(x,y))` for floating
+ point arguments so that the result is always an integer (though possibly an
+ integer represented as floating point). This op is generated by `x // y`
+ floor division in Python 3 and in Python 2.7 with
+ `from __future__ import division`.
+
+ Note that for efficiency, __floordiv__ uses C semantics for negative numbers
+ (unlike Python and Numpy).
+
+ `x` and `y` must have the same type, and the result will have the same type
+ as well.
+
+ Args:
+ x: `Tensor` numerator of real numeric type.
+ y: `Tensor` numerator of real numeric type.
+ name: A name for the operation (optional).
+
+ Returns:
+ `x / y` rounded down (except possibly for integers in C).
+
+ Raises:
+ TypeError: If the inputs are complex.
+ """
+ with ops.op_scope([x, y], name, "floordiv") as name:
+ x = ops.convert_to_tensor(x, name="x")
+ dtype = x.dtype
+ if dtype.is_floating:
+ return floor(div(x, y), name=name)
+ else:
+ if not dtype.is_integer:
+ raise TypeError("Expected floating point or integer, got %r" % dtype)
+ return div(x, y, name=name)
+
+
_OverrideBinaryOperatorHelper(add, "add")
_OverrideBinaryOperatorHelper(sub, "sub")
_OverrideBinaryOperatorHelper(mul, "mul")
_OverrideBinaryOperatorHelper(div, "div")
+_OverrideBinaryOperatorHelper(truediv, "truediv")
+_OverrideBinaryOperatorHelper(floordiv, "floordiv")
_OverrideBinaryOperatorHelper(mod, "mod")
@@ -475,8 +573,8 @@ def _RangeShape(op):
if start_value is None or limit_value is None or delta_value is None:
return [tensor_shape.vector(None)]
else:
- return [tensor_shape.vector(
- (limit_value - start_value + delta_value - 1) / delta_value)]
+ return [tensor_shape.vector((limit_value - start_value + delta_value - 1) //
+ delta_value)]
# Reduction operations
@@ -1020,8 +1118,9 @@ def _BroadcastShape(op):
# To compute the broadcasted dimensions, we zip together shape_x and shape_y,
# and pad with 1 to make them the same length.
- broadcasted_dims = reversed(list(itertools.izip_longest(
- reversed(shape_x.dims), reversed(shape_y.dims),
+ broadcasted_dims = reversed(list(six.moves.zip_longest(
+ reversed(shape_x.dims),
+ reversed(shape_y.dims),
fillvalue=tensor_shape.Dimension(1))))
# Next we combine the dimensions according to the numpy broadcasting rules.
# http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
diff --git a/tensorflow/python/ops/math_ops_test.py b/tensorflow/python/ops/math_ops_test.py
index 86ea04f54d..5b6541a848 100644
--- a/tensorflow/python/ops/math_ops_test.py
+++ b/tensorflow/python/ops/math_ops_test.py
@@ -1,4 +1,8 @@
"""Tests for tensorflow.ops.math_ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import math
import tensorflow.python.platform
diff --git a/tensorflow/python/ops/nn.py b/tensorflow/python/ops/nn.py
index 1ad08ca4c9..caf47b1431 100644
--- a/tensorflow/python/ops/nn.py
+++ b/tensorflow/python/ops/nn.py
@@ -167,7 +167,11 @@ classes when using one of the sampled loss functions above.
@@compute_accidental_hits
"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.python.framework import ops
from tensorflow.python.framework import types
from tensorflow.python.ops import array_ops
@@ -526,15 +530,24 @@ def moments(x, axes, name=None):
name: Name used to scope the operations that compute the moments.
Returns:
- Two `Tensors`: `mean` and `variance`.
+ Two `Tensor` objects: `mean` and `variance`.
"""
with ops.op_scope([x, axes], name, "moments"):
x = ops.convert_to_tensor(x, name="x")
- divisor = 1.0
- for d in xrange(len(x.get_shape())):
- if d in axes:
+ x_shape = x.get_shape()
+ if all(x_shape[d].value is not None for d in axes):
+ # The shape is known in the relevant axes, so we can statically
+ # compute the divisor.
+ divisor = 1.0
+ for d in set(axes):
divisor *= x.get_shape()[d].value
- divisor = constant_op.constant(1.0 / divisor, x.dtype, name="divisor")
+ divisor = constant_op.constant(1.0 / divisor, x.dtype, name="divisor")
+ else:
+ divisor = constant_op.constant(1.0, dtype=x.dtype)
+ x_dynamic_shape = array_ops.shape(x)
+ for d in set(axes):
+ divisor *= math_ops.cast(x_dynamic_shape[d], x.dtype)
+ divisor = math_ops.inv(divisor, name="divisor")
axes = constant_op.constant(axes, name="axes")
# Note: We do not use Mean here because it is very slow on GPU.
# Note 2: The expression below is potentially more stable.
diff --git a/tensorflow/python/ops/nn_grad.py b/tensorflow/python/ops/nn_grad.py
index 0cf867d217..535d55f00f 100644
--- a/tensorflow/python/ops/nn_grad.py
+++ b/tensorflow/python/ops/nn_grad.py
@@ -1,5 +1,9 @@
"""Gradients for operators defined in nn_ops.py."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index 0ffe95de2b..728a33aa20 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -1,5 +1,9 @@
"""Wrappers for primitive Neural Net (NN) Operations."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow.python.platform
import numpy as np
@@ -75,6 +79,7 @@ def deconv2d(value, filter, output_shape, strides, padding="SAME",
padding=padding,
name=name)
+
# pylint: disable=protected-access
def bias_add(value, bias, name=None):
"""Adds `bias` to `value`.
diff --git a/tensorflow/python/ops/nn_test.py b/tensorflow/python/ops/nn_test.py
index 0bb9e787a5..48f7a4c987 100644
--- a/tensorflow/python/ops/nn_test.py
+++ b/tensorflow/python/ops/nn_test.py
@@ -1,14 +1,19 @@
"""Tests for tensorflow.ops.nn."""
+from __future__ import absolute_import
+from __future__ import division
from __future__ import print_function
+
import math
import tensorflow.python.platform
import numpy as np
+from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.python.framework import test_util
from tensorflow.python.framework import types
from tensorflow.python.kernel_tests import gradient_checker as gc
+from tensorflow.python.ops import array_ops
from tensorflow.python.ops import constant_op
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import gradients
@@ -80,9 +85,9 @@ class ZeroFractionTest(test_util.TensorFlowTestCase):
def _ZeroFraction(self, x):
assert x.shape
- total_elements = float(np.prod(x.shape))
- nonzeros = float(np.count_nonzero(x.flatten()))
- return 1.0 - (nonzeros / total_elements)
+ total_elements = np.prod(x.shape)
+ nonzeros = np.count_nonzero(x.flatten())
+ return 1.0 - nonzeros / total_elements
def testZeroFraction(self):
x_shape = [5, 17]
@@ -561,6 +566,30 @@ class BatchNormWithGlobalNormalizationTest(test_util.TensorFlowTestCase):
class MomentsTest(test_util.TensorFlowTestCase):
+ def RunMomentTestWithDynamicShape(self, shape, global_norm):
+ with self.test_session():
+ # shape = [batch, width, height, depth]
+ assert len(shape) == 4
+
+ x_numpy = np.random.normal(size=shape).astype(np.float32)
+ x = array_ops.placeholder(types.float32, shape=[None] * len(shape))
+
+ axes = [0, 1, 2] if global_norm else [0]
+ mean, var = nn.moments(x, axes)
+
+ num_elements = np.prod([shape[i] for i in axes])
+
+ ax = (0, 1, 2) if global_norm else (0)
+ expected_mean = np.sum(x_numpy, axis=ax) / num_elements
+ expected_mean_squared = np.multiply(expected_mean, expected_mean)
+ expected_x_squared = np.sum(
+ np.multiply(x_numpy, x_numpy), axis=ax) / num_elements
+ expected_variance = expected_x_squared - expected_mean_squared
+
+ # Check that the moments are correct.
+ self.assertAllClose(expected_mean, mean.eval(feed_dict={x: x_numpy}))
+ self.assertAllClose(expected_variance, var.eval(feed_dict={x: x_numpy}))
+
def RunMomentTest(self, shape, global_norm):
with self.test_session():
# shape = [batch, width, height, depth]
@@ -568,7 +597,7 @@ class MomentsTest(test_util.TensorFlowTestCase):
x_numpy = np.random.normal(size=shape).astype(np.float32)
x = constant_op.constant(x_numpy)
- x.set_shape(shape)
+
axes = [0, 1, 2] if global_norm else [0]
mean, var = nn.moments(x, axes)
@@ -587,9 +616,11 @@ class MomentsTest(test_util.TensorFlowTestCase):
def testBasic(self):
self.RunMomentTest(shape=[2, 3, 5, 4], global_norm=False)
+ self.RunMomentTestWithDynamicShape(shape=[2, 3, 5, 4], global_norm=False)
def testGlobalNormalization(self):
self.RunMomentTest(shape=[2, 3, 5, 4], global_norm=True)
+ self.RunMomentTestWithDynamicShape(shape=[2, 3, 5, 4], global_norm=True)
def _testGlobalGradient(self, from_y="mean"):
with self.test_session():
diff --git a/tensorflow/python/ops/numerics.py b/tensorflow/python/ops/numerics.py
index 93f5d5db20..dda6569b5a 100644
--- a/tensorflow/python/ops/numerics.py
+++ b/tensorflow/python/ops/numerics.py
@@ -1,5 +1,9 @@
"""Connects all float and double tensors to CheckNumericsOp."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
from tensorflow.python.framework import ops
from tensorflow.python.framework import types
from tensorflow.python.ops import array_ops
diff --git a/tensorflow/python/ops/op_def_library.py b/tensorflow/python/ops/op_def_library.py
index 5947b6df89..957c5123bc 100644
--- a/tensorflow/python/ops/op_def_library.py
+++ b/tensorflow/python/ops/op_def_library.py
@@ -1,5 +1,9 @@
"""Class to hold a library of OpDefs and use it to create Brain operations."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import numbers
from tensorflow.core.framework import attr_value_pb2
@@ -138,6 +142,7 @@ def _MakeStr(v, arg_name):
if not isinstance(v, basestring):
raise TypeError("Expected string for argument '%s' not %s." %
(arg_name, repr(v)))
+ # TODO(irving): Figure out what to do here from Python 3
return str(v) # Convert unicode strings to bytes.
diff --git a/tensorflow/python/ops/op_def_library_test.py b/tensorflow/python/ops/op_def_library_test.py
index ed2341b2bc..d62e6c0df0 100644
--- a/tensorflow/python/ops/op_def_library_test.py
+++ b/tensorflow/python/ops/op_def_library_test.py
@@ -1,5 +1,9 @@
"""Tests for tensorflow.python.ops.op_def_library."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
from google.protobuf import text_format
from tensorflow.core.framework import op_def_pb2
diff --git a/tensorflow/python/ops/parsing_ops.py b/tensorflow/python/ops/parsing_ops.py
index 74368d9e69..860a046da0 100644
--- a/tensorflow/python/ops/parsing_ops.py
+++ b/tensorflow/python/ops/parsing_ops.py
@@ -1,4 +1,8 @@
"""Parsing Ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import re
from tensorflow.python.framework import ops
diff --git a/tensorflow/python/ops/random_ops.py b/tensorflow/python/ops/random_ops.py
index 07af214d6c..9c2c8d0bc1 100644
--- a/tensorflow/python/ops/random_ops.py
+++ b/tensorflow/python/ops/random_ops.py
@@ -1,5 +1,9 @@
"""Operations for generating random numbers."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
@@ -21,6 +25,7 @@ def _ShapeTensor(shape):
dtype = None
return ops.convert_to_tensor(shape, dtype=dtype, name="shape")
+
# pylint: disable=protected-access
def random_normal(shape, mean=0.0, stddev=1.0, dtype=types.float32,
seed=None, name=None):
diff --git a/tensorflow/python/ops/sparse_grad.py b/tensorflow/python/ops/sparse_grad.py
index 3685b671b7..fc18f4407b 100644
--- a/tensorflow/python/ops/sparse_grad.py
+++ b/tensorflow/python/ops/sparse_grad.py
@@ -1,4 +1,8 @@
"""Gradients for operators defined in sparse_ops.py."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
from tensorflow.python.framework import ops
from tensorflow.python.ops import sparse_ops
diff --git a/tensorflow/python/ops/sparse_ops.py b/tensorflow/python/ops/sparse_ops.py
index e0209bd100..1a7af78a33 100644
--- a/tensorflow/python/ops/sparse_ops.py
+++ b/tensorflow/python/ops/sparse_ops.py
@@ -22,9 +22,14 @@ dimension, and dense along all other dimensions.
@@sparse_retain
@@sparse_fill_empty_rows
"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow.python.platform
import numpy as np
+from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
@@ -136,7 +141,7 @@ def _SparseConcatShape(op):
output_val_elems = tensor_shape.Dimension(0)
output_shape_shape = tensor_shape.TensorShape(None)
- for i in range(num_inputs):
+ for i in xrange(num_inputs):
num_elems_i = ind_shapes[i][0].merge_with(val_shapes[i][0])
output_ind_rows += num_elems_i
output_ind_cols = output_ind_cols.merge_with(ind_shapes[i][1])
diff --git a/tensorflow/python/ops/sparse_ops_test.py b/tensorflow/python/ops/sparse_ops_test.py
index 07a5e6c6da..23979e68f7 100644
--- a/tensorflow/python/ops/sparse_ops_test.py
+++ b/tensorflow/python/ops/sparse_ops_test.py
@@ -1,5 +1,9 @@
"""Tests for Python ops defined in sparse_ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow.python.platform
import numpy as np
diff --git a/tensorflow/python/ops/standard_ops.py b/tensorflow/python/ops/standard_ops.py
index beef8e75b5..c67b1813ed 100644
--- a/tensorflow/python/ops/standard_ops.py
+++ b/tensorflow/python/ops/standard_ops.py
@@ -2,6 +2,10 @@
"""Import names of Tensor Flow standard Ops."""
# Imports the following modules so that @RegisterGradient get executed.
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
from tensorflow.python.ops import array_grad
from tensorflow.python.ops import data_flow_grad
from tensorflow.python.ops import math_grad
diff --git a/tensorflow/python/ops/state_grad.py b/tensorflow/python/ops/state_grad.py
index d9b084693c..5bfae8d9a9 100644
--- a/tensorflow/python/ops/state_grad.py
+++ b/tensorflow/python/ops/state_grad.py
@@ -1,5 +1,9 @@
"""Gradients for operators defined in state_ops.py."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
from tensorflow.python.framework import ops
from tensorflow.python.ops import state_ops
diff --git a/tensorflow/python/ops/state_ops.py b/tensorflow/python/ops/state_ops.py
index 011761904e..93c2877254 100644
--- a/tensorflow/python/ops/state_ops.py
+++ b/tensorflow/python/ops/state_ops.py
@@ -60,6 +60,10 @@ automatically by the optimizers in most cases.
@@IndexedSlices
"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
diff --git a/tensorflow/python/ops/string_ops.py b/tensorflow/python/ops/string_ops.py
index 8181fe9a2a..d409c20ec4 100644
--- a/tensorflow/python/ops/string_ops.py
+++ b/tensorflow/python/ops/string_ops.py
@@ -1,5 +1,9 @@
"""String Ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import common_shapes
diff --git a/tensorflow/python/ops/summary_ops.py b/tensorflow/python/ops/summary_ops.py
index d65fd1ea7c..a2acd80e0d 100644
--- a/tensorflow/python/ops/summary_ops.py
+++ b/tensorflow/python/ops/summary_ops.py
@@ -1,5 +1,9 @@
"""Summary Operations."""
# pylint: disable=wildcard-import,protected-access
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import gen_summary_ops
diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py
index 532d9fa83e..47149163af 100644
--- a/tensorflow/python/ops/variable_scope.py
+++ b/tensorflow/python/ops/variable_scope.py
@@ -1,5 +1,9 @@
"""A class to store named variables and a scope operator to manage sharing."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import contextlib
from tensorflow.python.framework import ops
@@ -72,14 +76,14 @@ class _VariableStore(object):
found_var = self._vars[name]
if not shape.is_compatible_with(found_var.get_shape()):
raise ValueError("Trying to share variable %s, but specified shape %s"
- " and found shape %s." % (name, str(shape),
- str(found_var.get_shape())))
+ " and found shape %s." % (name, shape,
+ found_var.get_shape()))
if not dtype.is_compatible_with(found_var.dtype):
dtype_str = dtype.name
found_type_str = found_var.dtype.name
raise ValueError("Trying to share variable %s, but specified dtype %s"
- " and found dtype %s." % (name, str(dtype_str),
- str(found_type_str)))
+ " and found dtype %s." % (name, dtype_str,
+ found_type_str))
return found_var
# The code below handles only the case of creating a new variable.
@@ -97,7 +101,7 @@ class _VariableStore(object):
collections=collections)
self._vars[name] = v
logging.info("Created variable %s with shape %s and init %s", v.name,
- format(shape), str(initializer))
+ format(shape), initializer)
return v
diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py
index 2f30c85da2..c650141cfa 100644
--- a/tensorflow/python/ops/variables.py
+++ b/tensorflow/python/ops/variables.py
@@ -1,4 +1,8 @@
"""Variable class."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow.python.platform
from tensorflow.python.framework import ops
diff --git a/tensorflow/python/platform/__init__.py b/tensorflow/python/platform/__init__.py
index 10b12f4abc..e67e33997f 100644
--- a/tensorflow/python/platform/__init__.py
+++ b/tensorflow/python/platform/__init__.py
@@ -1,5 +1,8 @@
"""Setup system-specific platform environment for TensorFlow."""
from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
from . import control_imports
if control_imports.USE_OSS:
from tensorflow.python.platform.default._init import *
diff --git a/tensorflow/python/platform/app.py b/tensorflow/python/platform/app.py
index 7186d6e0b5..589df603d4 100644
--- a/tensorflow/python/platform/app.py
+++ b/tensorflow/python/platform/app.py
@@ -1,5 +1,8 @@
"""Switch between depending on pyglib.app or an OSS replacement."""
from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
# pylint: disable=unused-import
# pylint: disable=g-import-not-at-top
# pylint: disable=wildcard-import
diff --git a/tensorflow/python/platform/default/_app.py b/tensorflow/python/platform/default/_app.py
index 5917d00ce3..0b6b802c43 100644
--- a/tensorflow/python/platform/default/_app.py
+++ b/tensorflow/python/platform/default/_app.py
@@ -1,11 +1,15 @@
"""Generic entry point script."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import sys
from tensorflow.python.platform import flags
def run():
- f = flags.FLAGS
- f._parse_flags()
- main = sys.modules['__main__'].main
- sys.exit(main(sys.argv))
+ f = flags.FLAGS
+ f._parse_flags()
+ main = sys.modules['__main__'].main
+ sys.exit(main(sys.argv))
diff --git a/tensorflow/python/platform/default/_flags.py b/tensorflow/python/platform/default/_flags.py
index ceccda6e5c..3a9b0f18c1 100644
--- a/tensorflow/python/platform/default/_flags.py
+++ b/tensorflow/python/platform/default/_flags.py
@@ -1,4 +1,8 @@
"""Implementation of the flags interface."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow.python.platform
import argparse
diff --git a/tensorflow/python/platform/default/_gfile.py b/tensorflow/python/platform/default/_gfile.py
index b3c4b8f9b9..54a33c3093 100644
--- a/tensorflow/python/platform/default/_gfile.py
+++ b/tensorflow/python/platform/default/_gfile.py
@@ -1,5 +1,9 @@
"""File processing utilities."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import errno
import functools
import glob as _glob
diff --git a/tensorflow/python/platform/default/_googletest.py b/tensorflow/python/platform/default/_googletest.py
index 42e0eac18a..04d1efefa9 100644
--- a/tensorflow/python/platform/default/_googletest.py
+++ b/tensorflow/python/platform/default/_googletest.py
@@ -1,4 +1,8 @@
"""Imports unittest as a replacement for testing.pybase.googletest."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import inspect
import itertools
import os
diff --git a/tensorflow/python/platform/default/_logging.py b/tensorflow/python/platform/default/_logging.py
index 5f0ace51fb..66bf2c0889 100644
--- a/tensorflow/python/platform/default/_logging.py
+++ b/tensorflow/python/platform/default/_logging.py
@@ -2,6 +2,10 @@
# pylint: disable=unused-import
# pylint: disable=g-bad-import-order
# pylint: disable=invalid-name
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import os
import sys
import time
diff --git a/tensorflow/python/platform/default/_parameterized.py b/tensorflow/python/platform/default/_parameterized.py
index 5d141568ed..51d9e92516 100644
--- a/tensorflow/python/platform/default/_parameterized.py
+++ b/tensorflow/python/platform/default/_parameterized.py
@@ -1,2 +1,6 @@
"""Extension to unittest to run parameterized tests."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
raise ImportError("Not implemented yet.")
diff --git a/tensorflow/python/platform/default/_resource_loader.py b/tensorflow/python/platform/default/_resource_loader.py
index 69f425072f..2d67591468 100644
--- a/tensorflow/python/platform/default/_resource_loader.py
+++ b/tensorflow/python/platform/default/_resource_loader.py
@@ -1,5 +1,9 @@
"""Read a file and return its contents."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import os.path
from tensorflow.python.platform import logging
@@ -23,4 +27,4 @@ def load_resource(path):
with open(path, 'rb') as f:
return f.read()
except IOError as e:
- logging.warning('IOError %s on path %s' % (e, path))
+ logging.warning('IOError %s on path %s', e, path)
diff --git a/tensorflow/python/platform/default/_status_bar.py b/tensorflow/python/platform/default/_status_bar.py
index 2953908724..33c8db6e3e 100644
--- a/tensorflow/python/platform/default/_status_bar.py
+++ b/tensorflow/python/platform/default/_status_bar.py
@@ -1,5 +1,9 @@
"""A no-op implementation of status bar functions."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
def SetupStatusBarInsideGoogle(unused_link_text, unused_port):
pass
diff --git a/tensorflow/python/platform/default/flags_test.py b/tensorflow/python/platform/default/flags_test.py
index 1b15ca138a..18db845545 100644
--- a/tensorflow/python/platform/default/flags_test.py
+++ b/tensorflow/python/platform/default/flags_test.py
@@ -1,4 +1,8 @@
"""Tests for our flags implementation."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import sys
from tensorflow.python.platform.default import _googletest as googletest
diff --git a/tensorflow/python/platform/default/gfile_test.py b/tensorflow/python/platform/default/gfile_test.py
index 9eec952e95..2efd073281 100644
--- a/tensorflow/python/platform/default/gfile_test.py
+++ b/tensorflow/python/platform/default/gfile_test.py
@@ -1,3 +1,7 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import os
import shutil
@@ -24,8 +28,8 @@ class _BaseTest(object):
try:
shutil.rmtree(self._tmp_dir)
except OSError:
- logging.warn("[%s] Post-test directory cleanup failed: %s"
- % (self, self._tmp_dir))
+ logging.warn("[%s] Post-test directory cleanup failed: %s",
+ self, self._tmp_dir)
class _GFileBaseTest(_BaseTest):
diff --git a/tensorflow/python/platform/default/logging_test.py b/tensorflow/python/platform/default/logging_test.py
index fd492bc384..5b36eda5b5 100644
--- a/tensorflow/python/platform/default/logging_test.py
+++ b/tensorflow/python/platform/default/logging_test.py
@@ -1,3 +1,7 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
from tensorflow.python.platform.default import _googletest as googletest
from tensorflow.python.platform.default import _logging as logging
diff --git a/tensorflow/python/platform/flags.py b/tensorflow/python/platform/flags.py
index 85bb200e18..641c73dc1a 100644
--- a/tensorflow/python/platform/flags.py
+++ b/tensorflow/python/platform/flags.py
@@ -1,5 +1,8 @@
"""Switch between depending on pyglib.flags or open-source gflags."""
from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
# pylint: disable=unused-import
# pylint: disable=g-import-not-at-top
# pylint: disable=wildcard-import
diff --git a/tensorflow/python/platform/gfile.py b/tensorflow/python/platform/gfile.py
index a0737cd59b..4f539f043a 100644
--- a/tensorflow/python/platform/gfile.py
+++ b/tensorflow/python/platform/gfile.py
@@ -1,5 +1,8 @@
"""Switch between depending on pyglib.gfile or an OSS replacement."""
from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
# pylint: disable=unused-import
# pylint: disable=g-import-not-at-top
# pylint: disable=wildcard-import
diff --git a/tensorflow/python/platform/googletest.py b/tensorflow/python/platform/googletest.py
index 2b4808552a..bb07578cbe 100644
--- a/tensorflow/python/platform/googletest.py
+++ b/tensorflow/python/platform/googletest.py
@@ -1,5 +1,8 @@
"""Switch between depending on googletest or unittest."""
from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
# pylint: disable=unused-import
# pylint: disable=g-import-not-at-top
# pylint: disable=wildcard-import
diff --git a/tensorflow/python/platform/logging.py b/tensorflow/python/platform/logging.py
index 6a064398d5..0627b3e70c 100644
--- a/tensorflow/python/platform/logging.py
+++ b/tensorflow/python/platform/logging.py
@@ -1,5 +1,8 @@
"""Switch between depending on pyglib.logging or regular logging."""
from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
# pylint: disable=unused-import
# pylint: disable=g-import-not-at-top
# pylint: disable=wildcard-import
diff --git a/tensorflow/python/platform/parameterized.py b/tensorflow/python/platform/parameterized.py
index 62b615474f..59c0777193 100644
--- a/tensorflow/python/platform/parameterized.py
+++ b/tensorflow/python/platform/parameterized.py
@@ -1,5 +1,8 @@
"""Switch between depending on pyglib.gfile or an OSS replacement."""
from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
# pylint: disable=unused-import
# pylint: disable=g-import-not-at-top
# pylint: disable=wildcard-import
diff --git a/tensorflow/python/platform/resource_loader.py b/tensorflow/python/platform/resource_loader.py
index 44ae05caf7..46f0ab654b 100644
--- a/tensorflow/python/platform/resource_loader.py
+++ b/tensorflow/python/platform/resource_loader.py
@@ -1,5 +1,8 @@
"""Load a file resource and return the contents."""
from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
# pylint: disable=unused-import
# pylint: disable=g-import-not-at-top
# pylint: disable=wildcard-import
diff --git a/tensorflow/python/platform/status_bar.py b/tensorflow/python/platform/status_bar.py
index 87a84d9898..c31a7f233c 100644
--- a/tensorflow/python/platform/status_bar.py
+++ b/tensorflow/python/platform/status_bar.py
@@ -1,5 +1,8 @@
"""Switch between an internal status bar and a no-op version."""
from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
# pylint: disable=unused-import
# pylint: disable=g-import-not-at-top
# pylint: disable=wildcard-import
diff --git a/tensorflow/python/platform/test.py b/tensorflow/python/platform/test.py
index 7d46f9cbc2..c21821243d 100644
--- a/tensorflow/python/platform/test.py
+++ b/tensorflow/python/platform/test.py
@@ -1,3 +1,7 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
from tensorflow.python.platform.googletest import GetTempDir
from tensorflow.python.platform.googletest import main
from tensorflow.python.framework.test_util import TensorFlowTestCase as TestCase
diff --git a/tensorflow/python/summary/event_accumulator.py b/tensorflow/python/summary/event_accumulator.py
index ae067d94fe..4c4963c1b9 100644
--- a/tensorflow/python/summary/event_accumulator.py
+++ b/tensorflow/python/summary/event_accumulator.py
@@ -1,4 +1,7 @@
"""Takes a generator of values, and accumulates them for a frontend."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
import collections
import threading
@@ -393,7 +396,7 @@ class EventAccumulator(object):
bucket_limit = list(histo.bucket_limit)
bucket_total = sum(bucket)
- fraction_weights = [float(10000*x)/bucket_total for x in bucket]
+ fraction_weights = [10000 * x / bucket_total for x in bucket]
cumsum_weights = _CumulativeSum(fraction_weights)
percentiles = [
diff --git a/tensorflow/python/summary/event_accumulator_test.py b/tensorflow/python/summary/event_accumulator_test.py
index c8de80ccba..a28906c71b 100644
--- a/tensorflow/python/summary/event_accumulator_test.py
+++ b/tensorflow/python/summary/event_accumulator_test.py
@@ -1,7 +1,12 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import os
import tensorflow.python.platform
+from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
from tensorflow.python.platform import gfile
@@ -189,8 +194,8 @@ class MockingEventAccumulatorTest(EventAccumulatorTest):
# Create the expected values after compressing hst2
expected_vals2 = [
ea.CompressedHistogramValue(bp, val)
- for bp, val in [(0, -2), (2500, 2), (5000, 2 + float(1) / 3), (
- 7500, 2 + float(2) / 3), (10000, 3)]
+ for bp, val in [(0, -2), (2500, 2), (5000, 2 + 1 / 3), (7500, 2 + 2 / 3
+ ), (10000, 3)]
]
expected_cmphst2 = ea.CompressedHistogramEvent(
wall_time=2,
diff --git a/tensorflow/python/summary/event_multiplexer.py b/tensorflow/python/summary/event_multiplexer.py
index 9966d76b21..9070d75e3c 100644
--- a/tensorflow/python/summary/event_multiplexer.py
+++ b/tensorflow/python/summary/event_multiplexer.py
@@ -1,11 +1,16 @@
"""Provides an interface for working with multiple event files."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import os
import threading
from tensorflow.python.platform import gfile
from tensorflow.python.platform import logging
from tensorflow.python.summary import event_accumulator
+import six
class EventMultiplexer(object):
@@ -77,7 +82,7 @@ class EventMultiplexer(object):
self._autoupdate_interval = None
self._size_guidance = size_guidance
if run_path_map is not None:
- for (run, path) in run_path_map.iteritems():
+ for (run, path) in six.iteritems(run_path_map):
self.AddRun(path, run)
def AddRun(self, path, name=None):
@@ -108,8 +113,8 @@ class EventMultiplexer(object):
if name in self._paths and self._paths[name] != path:
# TODO(danmane) - Make it impossible to overwrite an old path with
# a new path (just give the new path a distinct name)
- logging.warning('Conflict for name %s: old path %s, new path %s' %
- (name, self._paths[name], path))
+ logging.warning('Conflict for name %s: old path %s, new path %s',
+ name, self._paths[name], path)
logging.info('Constructing EventAccumulator for %s', path)
accumulator = event_accumulator.EventAccumulator(path,
self._size_guidance)
@@ -162,9 +167,9 @@ class EventMultiplexer(object):
subname = s
self.AddRun(os.path.join(path, s), subname)
- if filter(event_accumulator.IsTensorFlowEventsFile, paths):
+ if list(filter(event_accumulator.IsTensorFlowEventsFile, paths)):
directory_name = os.path.split(path)[1]
- logging.info('Directory %s has event files; loading' % directory_name)
+ logging.info('Directory %s has event files; loading', directory_name)
if name:
dname = name
else:
@@ -176,7 +181,7 @@ class EventMultiplexer(object):
"""Call `Reload` on every `EventAccumulator`."""
self._reload_called = True
with self._accumulators_mutex:
- loaders = self._accumulators.values()
+ loaders = list(self._accumulators.values())
for l in loaders:
l.Reload()
@@ -187,7 +192,7 @@ class EventMultiplexer(object):
self._autoupdate_interval = interval
self._autoupdate_called = True
with self._accumulators_mutex:
- loaders = self._accumulators.values()
+ loaders = list(self._accumulators.values())
for l in loaders:
l.AutoUpdate(interval)
return self
@@ -295,7 +300,7 @@ class EventMultiplexer(object):
"""
with self._accumulators_mutex:
# To avoid nested locks, we construct a copy of the run-accumulator map
- items = list(self._accumulators.iteritems())
+ items = list(six.iteritems(self._accumulators))
return {
run_name: accumulator.Tags()
for run_name, accumulator in items
@@ -334,7 +339,7 @@ def AutoloadingMultiplexer(path_to_run, interval_secs=60,
if not isinstance(path_to_run, dict):
raise TypeError('path_to_run should be a dict, was %s', path_to_run)
def Load():
- for (path, name) in path_to_run.iteritems():
+ for (path, name) in six.iteritems(path_to_run):
logging.info('Checking for new runs in %s', path)
multiplexer.AddRunsFromDirectory(path, name)
t = threading.Timer(interval_secs, Load)
diff --git a/tensorflow/python/summary/event_multiplexer_test.py b/tensorflow/python/summary/event_multiplexer_test.py
index 35a8aed266..4b6b29f66d 100644
--- a/tensorflow/python/summary/event_multiplexer_test.py
+++ b/tensorflow/python/summary/event_multiplexer_test.py
@@ -1,3 +1,7 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import os
import tensorflow.python.platform
@@ -67,7 +71,7 @@ class EventMultiplexerTest(test_util.TensorFlowTestCase):
def testRunNamesRespected(self):
x = event_multiplexer.EventMultiplexer({'run1': 'path1', 'run2': 'path2'})
- self.assertItemsEqual(x.Runs().keys(), ['run1', 'run2'])
+ self.assertItemsEqual(sorted(x.Runs().keys()), ['run1', 'run2'])
self.assertEqual(x._GetAccumulator('run1')._path, 'path1')
self.assertEqual(x._GetAccumulator('run2')._path, 'path2')
@@ -127,14 +131,14 @@ class EventMultiplexerTest(test_util.TensorFlowTestCase):
path1 = join(realdir, 'path1')
gfile.MkDir(path1)
x.AddRunsFromDirectory(realdir)
- self.assertEqual(x.Runs().keys(), ['path1'], 'loaded run: path1')
+ self.assertEqual(sorted(x.Runs().keys()), ['path1'], 'loaded run: path1')
loader1 = x._GetAccumulator('path1')
self.assertEqual(loader1._path, path1, 'has the correct path')
path2 = join(realdir, 'path2')
gfile.MkDir(path2)
x.AddRunsFromDirectory(realdir)
- self.assertItemsEqual(x.Runs().keys(), ['path1', 'path2'])
+ self.assertItemsEqual(sorted(x.Runs().keys()), ['path1', 'path2'])
self.assertEqual(x._GetAccumulator('path1'), loader1,
'loader1 not regenerated')
loader2 = x._GetAccumulator('path2')
@@ -142,7 +146,7 @@ class EventMultiplexerTest(test_util.TensorFlowTestCase):
path2_2 = join(path2, 'path2')
gfile.MkDir(path2_2)
x.AddRunsFromDirectory(path2)
- self.assertItemsEqual(x.Runs().keys(), ['path1', 'path2'])
+ self.assertItemsEqual(sorted(x.Runs().keys()), ['path1', 'path2'])
self.assertNotEqual(loader2, x._GetAccumulator('path2'),
'loader2 regenerated')
self.assertEqual(x._GetAccumulator('path2')._path, path2_2,
@@ -207,7 +211,7 @@ class EventMultiplexerTest(test_util.TensorFlowTestCase):
x = event_multiplexer.EventMultiplexer()
x.AddRun('run1_path', 'run1')
run1 = x._GetAccumulator('run1')
- self.assertEqual(x.Runs().keys(), ['run1'])
+ self.assertEqual(sorted(x.Runs().keys()), ['run1'])
self.assertEqual(run1._path, 'run1_path')
x.AddRun('run1_path', 'run1')
@@ -219,7 +223,7 @@ class EventMultiplexerTest(test_util.TensorFlowTestCase):
self.assertNotEqual(run1, new_run1)
x.AddRun('runName3')
- self.assertItemsEqual(x.Runs().keys(), ['run1', 'runName3'])
+ self.assertItemsEqual(sorted(x.Runs().keys()), ['run1', 'runName3'])
self.assertEqual(x._GetAccumulator('runName3')._path, 'runName3')
def testAddRunMaintainsLoading(self):
diff --git a/tensorflow/python/summary/impl/directory_watcher.py b/tensorflow/python/summary/impl/directory_watcher.py
index 830e538cb6..d5557022a1 100644
--- a/tensorflow/python/summary/impl/directory_watcher.py
+++ b/tensorflow/python/summary/impl/directory_watcher.py
@@ -1,4 +1,8 @@
"""Contains the implementation for the DirectoryWatcher class."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import os
from tensorflow.python.platform import gfile
diff --git a/tensorflow/python/summary/impl/directory_watcher_test.py b/tensorflow/python/summary/impl/directory_watcher_test.py
index a22e3f2922..6b7f9ec33e 100644
--- a/tensorflow/python/summary/impl/directory_watcher_test.py
+++ b/tensorflow/python/summary/impl/directory_watcher_test.py
@@ -1,5 +1,9 @@
"""Tests for directory_watcher."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import os
import shutil
diff --git a/tensorflow/python/summary/impl/event_file_loader.py b/tensorflow/python/summary/impl/event_file_loader.py
index ac7c4be2b1..eab8783a10 100644
--- a/tensorflow/python/summary/impl/event_file_loader.py
+++ b/tensorflow/python/summary/impl/event_file_loader.py
@@ -1,4 +1,6 @@
"""Functionality for loading events from a record file."""
+from __future__ import absolute_import
+from __future__ import division
from __future__ import print_function
from tensorflow.core.util import event_pb2
diff --git a/tensorflow/python/summary/impl/event_file_loader_test.py b/tensorflow/python/summary/impl/event_file_loader_test.py
index 1dc29d85d5..0b98c69d70 100644
--- a/tensorflow/python/summary/impl/event_file_loader_test.py
+++ b/tensorflow/python/summary/impl/event_file_loader_test.py
@@ -1,5 +1,9 @@
"""Tests for event_file_loader."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import os
from tensorflow.python.framework import test_util
diff --git a/tensorflow/python/summary/impl/reservoir.py b/tensorflow/python/summary/impl/reservoir.py
index 2c9b294841..6af510a814 100644
--- a/tensorflow/python/summary/impl/reservoir.py
+++ b/tensorflow/python/summary/impl/reservoir.py
@@ -1,5 +1,9 @@
"""A key-value[] store that implements reservoir sampling on the values."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import collections
import random
import threading
@@ -64,7 +68,7 @@ class Reservoir(object):
['list', 'of', 'keys'] in the Reservoir.
"""
with self._mutex:
- return self._buckets.keys()
+ return list(self._buckets.keys())
def Items(self, key):
"""Return items associated with given key.
diff --git a/tensorflow/python/summary/impl/reservoir_test.py b/tensorflow/python/summary/impl/reservoir_test.py
index 46cbde5940..6bc5a7ed76 100644
--- a/tensorflow/python/summary/impl/reservoir_test.py
+++ b/tensorflow/python/summary/impl/reservoir_test.py
@@ -1,5 +1,10 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow.python.platform
+from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.python.platform import googletest
from tensorflow.python.summary.impl import reservoir
@@ -84,7 +89,7 @@ class ReservoirBucketTest(googletest.TestCase):
b = reservoir._ReservoirBucket(100)
for i in xrange(100):
b.AddItem(i)
- self.assertEqual(b.Items(), range(100))
+ self.assertEqual(b.Items(), list(xrange(100)))
def testDoesntOverfill(self):
b = reservoir._ReservoirBucket(10)
@@ -119,7 +124,7 @@ class ReservoirBucketTest(googletest.TestCase):
b = reservoir._ReservoirBucket(0)
for i in xrange(20):
b.AddItem(i)
- self.assertEqual(b.Items(), range(i+1))
+ self.assertEqual(b.Items(), list(range(i + 1)))
def testSizeRequirement(self):
with self.assertRaises(ValueError):
@@ -134,7 +139,7 @@ class ReservoirBucketStatisticalDistributionTest(googletest.TestCase):
self.total = 1000000
self.samples = 10000
self.n_buckets = 100
- self.total_per_bucket = self.total / self.n_buckets
+ self.total_per_bucket = self.total // self.n_buckets
self.assertEqual(self.total % self.n_buckets, 0, 'total must be evenly '
'divisible by the number of buckets')
self.assertTrue(self.total > self.samples, 'need to have more items '
@@ -164,7 +169,7 @@ class ReservoirBucketStatisticalDistributionTest(googletest.TestCase):
modbins = [0] * self.n_buckets
# Slice off the last item when we iterate.
for item in b.Items()[0:-1]:
- divbins[item / self.total_per_bucket] += 1
+ divbins[item // self.total_per_bucket] += 1
modbins[item % self.n_buckets] += 1
for bucket_index in xrange(self.n_buckets):
diff --git a/tensorflow/python/training/adagrad.py b/tensorflow/python/training/adagrad.py
index 41cf2e00f4..4e07552a5f 100644
--- a/tensorflow/python/training/adagrad.py
+++ b/tensorflow/python/training/adagrad.py
@@ -1,4 +1,8 @@
"""Adagrad for TensorFlow."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
from tensorflow.python.framework import ops
from tensorflow.python.ops import constant_op
from tensorflow.python.training import optimizer
diff --git a/tensorflow/python/training/adagrad_test.py b/tensorflow/python/training/adagrad_test.py
index ee83791eb5..451d9054b9 100644
--- a/tensorflow/python/training/adagrad_test.py
+++ b/tensorflow/python/training/adagrad_test.py
@@ -1,4 +1,8 @@
"""Functional tests for aggregate operations."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow.python.platform
import numpy as np
diff --git a/tensorflow/python/training/adam.py b/tensorflow/python/training/adam.py
index 266430bb13..bb7ee5de0b 100644
--- a/tensorflow/python/training/adam.py
+++ b/tensorflow/python/training/adam.py
@@ -1,4 +1,8 @@
"""Adam for TensorFlow."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
from tensorflow.python.framework import ops
from tensorflow.python.ops import constant_op
from tensorflow.python.ops import control_flow_ops
diff --git a/tensorflow/python/training/adam_test.py b/tensorflow/python/training/adam_test.py
index f92728d0c7..b8f6d24e69 100644
--- a/tensorflow/python/training/adam_test.py
+++ b/tensorflow/python/training/adam_test.py
@@ -1,4 +1,8 @@
"""Tests for Adam."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow.python.platform
import numpy as np
diff --git a/tensorflow/python/training/coordinator.py b/tensorflow/python/training/coordinator.py
index f090e6d222..2d0f1c2be9 100644
--- a/tensorflow/python/training/coordinator.py
+++ b/tensorflow/python/training/coordinator.py
@@ -1,4 +1,8 @@
"""Coordinator to help multiple threads stop when requested."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import sys
import threading
import time
diff --git a/tensorflow/python/training/coordinator_test.py b/tensorflow/python/training/coordinator_test.py
index bcd1234f3d..ed5de04613 100644
--- a/tensorflow/python/training/coordinator_test.py
+++ b/tensorflow/python/training/coordinator_test.py
@@ -1,4 +1,8 @@
"""Tests for Coordinator."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import sys
import threading
import time
diff --git a/tensorflow/python/training/ftrl.py b/tensorflow/python/training/ftrl.py
index 6b9471a5ed..d479162e7f 100644
--- a/tensorflow/python/training/ftrl.py
+++ b/tensorflow/python/training/ftrl.py
@@ -1,4 +1,8 @@
"""FTRL-Proximal for Tensor Flow."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
from tensorflow.python.framework import ops
from tensorflow.python.framework import types
from tensorflow.python.ops import array_ops
diff --git a/tensorflow/python/training/ftrl_test.py b/tensorflow/python/training/ftrl_test.py
index eb581048f1..9a4f6a87ee 100644
--- a/tensorflow/python/training/ftrl_test.py
+++ b/tensorflow/python/training/ftrl_test.py
@@ -1,4 +1,8 @@
"""Functional tests for Ftrl operations."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow.python.platform
import numpy as np
diff --git a/tensorflow/python/training/gradient_descent.py b/tensorflow/python/training/gradient_descent.py
index 21247aacf3..1ed11209ad 100644
--- a/tensorflow/python/training/gradient_descent.py
+++ b/tensorflow/python/training/gradient_descent.py
@@ -1,4 +1,8 @@
"""GradientDescent for TensorFlow."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
from tensorflow.python.framework import ops
from tensorflow.python.ops import constant_op
# pylint: disable=unused-import
diff --git a/tensorflow/python/training/gradient_descent_test.py b/tensorflow/python/training/gradient_descent_test.py
index d5b0cae401..00be5ffd6c 100644
--- a/tensorflow/python/training/gradient_descent_test.py
+++ b/tensorflow/python/training/gradient_descent_test.py
@@ -1,4 +1,8 @@
"""Functional test for GradientDescent."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow.python.platform
import numpy as np
diff --git a/tensorflow/python/training/input.py b/tensorflow/python/training/input.py
index 0a383efcf9..6734690397 100644
--- a/tensorflow/python/training/input.py
+++ b/tensorflow/python/training/input.py
@@ -40,7 +40,11 @@ want them run by N threads.
@@shuffle_batch_join
"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import types
@@ -249,9 +253,9 @@ def _merge_shapes(shape_list, enqueue_many):
def _shapes(tensor_list_list, shapes, enqueue_many):
if shapes is None:
l = len(tensor_list_list[0])
- shapes = [_merge_shapes([tl[i].get_shape().as_list()
- for tl in tensor_list_list],
- enqueue_many) for i in range(l)]
+ shapes = [_merge_shapes(
+ [tl[i].get_shape().as_list() for tl in tensor_list_list], enqueue_many)
+ for i in xrange(l)]
return shapes
diff --git a/tensorflow/python/training/input_test.py b/tensorflow/python/training/input_test.py
index fe8c195e77..d85df52431 100644
--- a/tensorflow/python/training/input_test.py
+++ b/tensorflow/python/training/input_test.py
@@ -1,9 +1,14 @@
"""Tests for training.input."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
import os
import itertools
import tensorflow.python.platform
+import numpy as np
+from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
@@ -127,7 +132,7 @@ class RangeInputProducerTest(tf.test.TestCase):
# No randomness, so just see repeated copies of the input.
output = dequeue_many.eval()
- self.assertAllEqual(range(range_size) * num_epochs, output)
+ self.assertAllEqual(list(xrange(range_size)) * num_epochs, output)
# Reached the limit.
with self.assertRaises(tf.errors.OutOfRangeError):
@@ -254,8 +259,8 @@ class BatchTest(tf.test.TestCase):
for i in range(num_batches):
results = sess.run(batched)
- self.assertAllEqual(results[0],
- range(i * batch_size, (i + 1) * batch_size))
+ self.assertAllEqual(results[0], np.arange(i * batch_size,
+ (i + 1) * batch_size))
self.assertAllEqual(results[1], ["string"] * batch_size)
# Reached the limit.
@@ -318,7 +323,7 @@ class BatchJoinTest(tf.test.TestCase):
all_a = []
seen_b = 0
saw_both = 0
- num_batches = (num_a + num_b) / batch_size
+ num_batches = (num_a + num_b) // batch_size
for i in range(num_batches):
results = sess.run(batched)
tf.logging.info("Batch %d: %s", i, results[0])
@@ -337,7 +342,7 @@ class BatchJoinTest(tf.test.TestCase):
self.assertGreater(saw_both, 1)
# Verify the order of results from "a" were preserved.
- self.assertAllEqual(all_a, range(num_a))
+ self.assertAllEqual(all_a, np.arange(num_a))
self.assertEqual(seen_b, num_b)
# Reached the limit.
@@ -441,7 +446,7 @@ class ShuffleBatchJoinTest(tf.test.TestCase):
all_a = []
seen_b = 0
saw_both = 0
- num_batches = (num_a + num_b) / batch_size
+ num_batches = (num_a + num_b) // batch_size
for i in range(num_batches):
results = sess.run(batched)
tf.logging.info("Batch %d: %s", i, results[0])
diff --git a/tensorflow/python/training/learning_rate_decay.py b/tensorflow/python/training/learning_rate_decay.py
index cafcb26d01..8450fae5cb 100644
--- a/tensorflow/python/training/learning_rate_decay.py
+++ b/tensorflow/python/training/learning_rate_decay.py
@@ -1,4 +1,8 @@
"""Various learning rate decay functions."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
diff --git a/tensorflow/python/training/learning_rate_decay_test.py b/tensorflow/python/training/learning_rate_decay_test.py
index b85d58cae7..bc103591c2 100644
--- a/tensorflow/python/training/learning_rate_decay_test.py
+++ b/tensorflow/python/training/learning_rate_decay_test.py
@@ -1,4 +1,8 @@
"""Functional test for learning rate decay."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow.python.platform
from tensorflow.python.framework import test_util
@@ -33,7 +37,7 @@ class LRDecayTest(test_util.TensorFlowTestCase):
self.assertAllClose(decayed_lr.eval(), .1, 1e-6)
# Decayed learning rate
assign_100.op.run()
- expected = .1 * 0.96 ** (100 / 3)
+ expected = .1 * 0.96**(100 // 3)
self.assertAllClose(decayed_lr.eval(), expected, 1e-6)
def testVariables(self):
@@ -52,7 +56,7 @@ class LRDecayTest(test_util.TensorFlowTestCase):
self.assertAllClose(decayed_lr.eval(), .1, 1e-6)
# Decayed learning rate
assign_100.op.run()
- expected = .1 * 0.96 ** (100 / 3)
+ expected = .1 * 0.96**(100 // 3)
self.assertAllClose(decayed_lr.eval(), expected, 1e-6)
diff --git a/tensorflow/python/training/momentum.py b/tensorflow/python/training/momentum.py
index fdd434359f..cdd8f0cd5c 100644
--- a/tensorflow/python/training/momentum.py
+++ b/tensorflow/python/training/momentum.py
@@ -1,4 +1,8 @@
"""Momentum for TensorFlow."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
from tensorflow.python.framework import ops
from tensorflow.python.ops import constant_op
from tensorflow.python.training import optimizer
diff --git a/tensorflow/python/training/momentum_test.py b/tensorflow/python/training/momentum_test.py
index 2cf86d97c9..580294cf6a 100644
--- a/tensorflow/python/training/momentum_test.py
+++ b/tensorflow/python/training/momentum_test.py
@@ -1,7 +1,12 @@
"""Tests for Momentum."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow.python.platform
import numpy as np
+from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
diff --git a/tensorflow/python/training/moving_averages.py b/tensorflow/python/training/moving_averages.py
index e0062a902b..d834b698ba 100644
--- a/tensorflow/python/training/moving_averages.py
+++ b/tensorflow/python/training/moving_averages.py
@@ -1,4 +1,8 @@
"""Maintain moving averages of parameters."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
from tensorflow.python.framework import ops
from tensorflow.python.framework import types
from tensorflow.python.ops import array_ops
diff --git a/tensorflow/python/training/moving_averages_test.py b/tensorflow/python/training/moving_averages_test.py
index 73ee94b400..53d524a325 100644
--- a/tensorflow/python/training/moving_averages_test.py
+++ b/tensorflow/python/training/moving_averages_test.py
@@ -1,6 +1,11 @@
"""Functional test for moving_averages.py."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow.python.platform
+from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.python.framework import test_util
from tensorflow.python.framework import types
from tensorflow.python.ops import constant_op
diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py
index c0480f6c5c..63826170be 100644
--- a/tensorflow/python/training/optimizer.py
+++ b/tensorflow/python/training/optimizer.py
@@ -1,6 +1,10 @@
"""Base class for optimizers."""
# pylint: disable=g-bad-name
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
from tensorflow.python.framework import ops
from tensorflow.python.framework import types as tf_types
from tensorflow.python.ops import array_ops
@@ -204,7 +208,7 @@ class Optimizer(object):
loss, var_list, gate_gradients=(gate_gradients == Optimizer.GATE_OP))
if gate_gradients == Optimizer.GATE_GRAPH:
grads = control_flow_ops.tuple(grads)
- grads_and_vars = zip(grads, var_list)
+ grads_and_vars = list(zip(grads, var_list))
self._assert_valid_dtypes([v for g, v in grads_and_vars if g is not None])
return grads_and_vars
diff --git a/tensorflow/python/training/queue_runner.py b/tensorflow/python/training/queue_runner.py
index af9048c114..6c91a89ab7 100644
--- a/tensorflow/python/training/queue_runner.py
+++ b/tensorflow/python/training/queue_runner.py
@@ -1,4 +1,8 @@
"""Create threads to run multiple enqueue ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import threading
import tensorflow.python.platform
diff --git a/tensorflow/python/training/queue_runner_test.py b/tensorflow/python/training/queue_runner_test.py
index c94c02da66..034004162c 100644
--- a/tensorflow/python/training/queue_runner_test.py
+++ b/tensorflow/python/training/queue_runner_test.py
@@ -1,4 +1,8 @@
"""Tests for QueueRunner."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import time
import tensorflow.python.platform
diff --git a/tensorflow/python/training/rmsprop.py b/tensorflow/python/training/rmsprop.py
index 6dc0ce11ea..02eeb3ff0e 100644
--- a/tensorflow/python/training/rmsprop.py
+++ b/tensorflow/python/training/rmsprop.py
@@ -13,6 +13,10 @@ delta = - mom
"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
from tensorflow.python.framework import ops
from tensorflow.python.ops import constant_op
from tensorflow.python.training import optimizer
diff --git a/tensorflow/python/training/rmsprop_test.py b/tensorflow/python/training/rmsprop_test.py
index 520df73ca8..c7456897d5 100644
--- a/tensorflow/python/training/rmsprop_test.py
+++ b/tensorflow/python/training/rmsprop_test.py
@@ -1,4 +1,8 @@
"""Tests for rmsprop."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import math
import tensorflow.python.platform
diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py
index 321e1cdd34..fcd02716d2 100644
--- a/tensorflow/python/training/saver.py
+++ b/tensorflow/python/training/saver.py
@@ -1,5 +1,9 @@
# pylint: disable=invalid-name
"""Save and restore variables."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import collections
import numbers
import os.path
@@ -228,8 +232,7 @@ class BaseSaverBuilder(object):
per_device = collections.defaultdict(lambda: [])
for var_to_save in vars_to_save:
per_device[var_to_save.var.device].append(var_to_save)
- return sorted([(dev, tup) for dev, tup in per_device.iteritems()],
- key=lambda t: t[0])
+ return sorted(per_device.items(), key=lambda t: t[0])
def _VarListToDict(self, var_list):
"""Create a dictionary of names to variable lists.
@@ -295,7 +298,7 @@ class BaseSaverBuilder(object):
vars_to_save = []
seen_variables = set()
- for name in sorted(names_to_variables.iterkeys()):
+ for name in sorted(names_to_variables.keys()):
if not isinstance(name, basestring):
raise TypeError("names_to_variables must be a dict mapping string "
"names to variable Tensors. Name is not a string: %s" %
diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py
index bfc856cbdb..4e248f625c 100644
--- a/tensorflow/python/training/saver_test.py
+++ b/tensorflow/python/training/saver_test.py
@@ -1,4 +1,8 @@
"""Tests for tensorflow.ops.io_ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import os.path
import time
diff --git a/tensorflow/python/training/summary_io.py b/tensorflow/python/training/summary_io.py
index 4d099b1609..7d1838d9f2 100644
--- a/tensorflow/python/training/summary_io.py
+++ b/tensorflow/python/training/summary_io.py
@@ -1,5 +1,9 @@
"""Reads Summaries from and writes Summaries to event files."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import os.path
import Queue
import threading
@@ -105,7 +109,7 @@ class SummaryWriter(object):
summary = summ
event = event_pb2.Event(wall_time=time.time(), summary=summary)
if global_step is not None:
- event.step = long(global_step)
+ event.step = int(global_step)
self.add_event(event)
def add_event(self, event):
@@ -129,7 +133,7 @@ class SummaryWriter(object):
"""
event = event_pb2.Event(wall_time=time.time(), graph_def=graph_def)
if global_step is not None:
- event.step = long(global_step)
+ event.step = int(global_step)
self._event_queue.put(event)
def flush(self):
diff --git a/tensorflow/python/training/summary_writer_test.py b/tensorflow/python/training/summary_writer_test.py
index 2ec416f68f..b1fd97bde2 100644
--- a/tensorflow/python/training/summary_writer_test.py
+++ b/tensorflow/python/training/summary_writer_test.py
@@ -1,4 +1,8 @@
"""Tests for training_coordinator.py."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import glob
import os.path
import shutil
diff --git a/tensorflow/python/training/training.py b/tensorflow/python/training/training.py
index cb6b93bfd0..a045bd4183 100644
--- a/tensorflow/python/training/training.py
+++ b/tensorflow/python/training/training.py
@@ -109,6 +109,10 @@ overview of summaries, event files, and visualization in TensorBoard.
"""
# Optimizers.
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
from tensorflow.python.training.adagrad import AdagradOptimizer
from tensorflow.python.training.adam import AdamOptimizer
from tensorflow.python.training.ftrl import FtrlOptimizer
diff --git a/tensorflow/python/training/training_ops.py b/tensorflow/python/training/training_ops.py
index 410b23e04d..42242e468b 100644
--- a/tensorflow/python/training/training_ops.py
+++ b/tensorflow/python/training/training_ops.py
@@ -1,5 +1,9 @@
"""Python wrappers for training ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.training import gen_training_ops
diff --git a/tensorflow/python/training/training_ops_test.py b/tensorflow/python/training/training_ops_test.py
index 902b9b0d78..945d35ec04 100644
--- a/tensorflow/python/training/training_ops_test.py
+++ b/tensorflow/python/training/training_ops_test.py
@@ -1,5 +1,9 @@
"""Tests for tensorflow.learning.training_ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import itertools
import tensorflow.python.platform
@@ -94,12 +98,12 @@ class TrainingOpsTest(TensorFlowTestCase):
def testSparseApplyAdagrad(self):
for (dtype, index_type) in itertools.product(
[np.float32, np.float64], [np.int32, np.int64]):
- x_val = [range(10), range(10, 20), range(20, 30)]
- y_val = [range(1, 11), range(11, 21), range(21, 31)]
+ x_val = [np.arange(10), np.arange(10, 20), np.arange(20, 30)]
+ y_val = [np.arange(1, 11), np.arange(11, 21), np.arange(21, 31)]
x = np.array(x_val).astype(dtype)
y = np.array(y_val).astype(dtype)
lr = np.array(2.0).astype(dtype)
- grad_val = [range(10), range(10)]
+ grad_val = [np.arange(10), np.arange(10)]
grad = np.array(grad_val).astype(dtype)
indices = np.array([0, 2]).astype(index_type)
self._testTypesForSparseAdagrad(x, y, lr, grad, indices)
diff --git a/tensorflow/python/training/training_util.py b/tensorflow/python/training/training_util.py
index 14166e25c6..d92b3e1bd6 100644
--- a/tensorflow/python/training/training_util.py
+++ b/tensorflow/python/training/training_util.py
@@ -1,4 +1,8 @@
"""Utility functions for training."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import os.path
from tensorflow.python.platform import gfile
diff --git a/tensorflow/python/user_ops/user_ops.py b/tensorflow/python/user_ops/user_ops.py
index 20e2604e05..0562d084db 100644
--- a/tensorflow/python/user_ops/user_ops.py
+++ b/tensorflow/python/user_ops/user_ops.py
@@ -1,5 +1,9 @@
"""All user ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow.python.platform
from tensorflow.python.ops import gen_user_ops
from tensorflow.python.ops.gen_user_ops import *
diff --git a/tensorflow/python/util/protobuf/compare.py b/tensorflow/python/util/protobuf/compare.py
index 19f7128f4e..51cef58a98 100644
--- a/tensorflow/python/util/protobuf/compare.py
+++ b/tensorflow/python/util/protobuf/compare.py
@@ -52,11 +52,17 @@ Alternatively:
self.assertProto2SameElements(a, c)
"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import copy
from google.protobuf import descriptor
from google.protobuf import message
from google.protobuf import text_format
+import six
+from six.moves import xrange # pylint: disable=redefined-builtin
def assertProto2Equal(self, a, b, check_initialized=True,
@@ -201,7 +207,7 @@ def NormalizeRepeatedFields(pb, dedupe=True):
# This is a map, only recurse if the values have a message type.
if (desc.message_type.fields_by_number[2].type ==
descriptor.FieldDescriptor.TYPE_MESSAGE):
- for v in values.itervalues():
+ for v in six.itervalues(values):
NormalizeRepeatedFields(v, dedupe=dedupe)
else:
if (desc.type == descriptor.FieldDescriptor.TYPE_MESSAGE or
@@ -256,7 +262,7 @@ def NormalizeNumberFields(pb):
if desc.type in (descriptor.FieldDescriptor.TYPE_INT64,
descriptor.FieldDescriptor.TYPE_UINT64,
descriptor.FieldDescriptor.TYPE_SINT64):
- normalized_values = [long(x) for x in values]
+ normalized_values = [int(x) for x in values]
elif desc.type in (descriptor.FieldDescriptor.TYPE_INT32,
descriptor.FieldDescriptor.TYPE_UINT32,
descriptor.FieldDescriptor.TYPE_SINT32,
@@ -282,7 +288,7 @@ def NormalizeNumberFields(pb):
# This is a map, only recurse if the values have a message type.
if (desc.message_type.fields_by_number[2].type ==
descriptor.FieldDescriptor.TYPE_MESSAGE):
- for v in values.itervalues():
+ for v in six.itervalues(values):
NormalizeNumberFields(v)
else:
for v in values:
@@ -331,7 +337,7 @@ def Proto2Cmp(a, b):
if isinstance(pb, message.Message):
return dict((desc.number, value) for desc, value in pb.ListFields())
elif _IsRepeatedContainer(pb):
- return dict(enumerate(list(pb)))
+ return dict(enumerate(pb))
else:
return pb
@@ -344,7 +350,7 @@ def Proto2Cmp(a, b):
# this list performs double duty: it compares two messages by tag value *or*
# two repeated fields by element, in order. the magic is in the format()
# function, which converts them both to the same easily comparable format.
- for tag in sorted(set(a.keys() + b.keys())):
+ for tag in sorted(set(a.keys()) | set(b.keys())):
if tag not in a:
return -1 # b is greater
elif tag not in b:
diff --git a/tensorflow/python/util/protobuf/compare_test.py b/tensorflow/python/util/protobuf/compare_test.py
index 25d1fb2914..d8cb53bc2b 100644
--- a/tensorflow/python/util/protobuf/compare_test.py
+++ b/tensorflow/python/util/protobuf/compare_test.py
@@ -2,6 +2,10 @@
"""Tests for python.util.protobuf.compare."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import copy
import re
import textwrap
@@ -9,6 +13,7 @@ import textwrap
from tensorflow.python.platform import googletest
from tensorflow.python.util.protobuf import compare
from tensorflow.python.util.protobuf import compare_test_pb2
+import six
from google.protobuf import text_format
@@ -282,23 +287,23 @@ class NormalizeNumbersTest(googletest.TestCase):
pb = compare_test_pb2.Large()
pb.int64_ = 4
compare.NormalizeNumberFields(pb)
- self.assertTrue(isinstance(pb.int64_, long))
+ self.assertTrue(isinstance(pb.int64_, six.integer_types))
pb.int64_ = 4
compare.NormalizeNumberFields(pb)
- self.assertTrue(isinstance(pb.int64_, long))
+ self.assertTrue(isinstance(pb.int64_, six.integer_types))
pb.int64_ = 9999999999999999
compare.NormalizeNumberFields(pb)
- self.assertTrue(isinstance(pb.int64_, long))
+ self.assertTrue(isinstance(pb.int64_, six.integer_types))
def testNormalizesRepeatedInts(self):
pb = compare_test_pb2.Large()
pb.int64s.extend([1, 400, 999999999999999])
compare.NormalizeNumberFields(pb)
- self.assertTrue(isinstance(pb.int64s[0], long))
- self.assertTrue(isinstance(pb.int64s[1], long))
- self.assertTrue(isinstance(pb.int64s[2], long))
+ self.assertTrue(isinstance(pb.int64s[0], six.integer_types))
+ self.assertTrue(isinstance(pb.int64s[1], six.integer_types))
+ self.assertTrue(isinstance(pb.int64s[2], six.integer_types))
def testNormalizesFloats(self):
pb1 = compare_test_pb2.Large()
diff --git a/tensorflow/tensorboard/components/tf-graph-board/tf-graph-board.html b/tensorflow/tensorboard/components/tf-graph-board/tf-graph-board.html
index c053d6f7a7..6ff365a6c1 100644
--- a/tensorflow/tensorboard/components/tf-graph-board/tf-graph-board.html
+++ b/tensorflow/tensorboard/components/tf-graph-board/tf-graph-board.html
@@ -88,6 +88,7 @@ paper-progress {
<div id="main">
<tf-graph id="graph"
graph-hierarchy="[[graphHierarchy]]"
+ render-hierarchy="{{_renderHierarchy}}"
selected-node="{{_selectedNode}}"
highlighted-node="{{_highlightedNode}}"
color-by="[[colorBy]]"
@@ -100,9 +101,12 @@ paper-progress {
<tf-graph-info id="graph-info"
title="selected"
graph-hierarchy="[[graphHierarchy]]"
+ render-hierarchy="[[_renderHierarchy]]"
graph="[[graph]]"
selected-node="{{_selectedNode}}"
highlighted-node="{{_highlightedNode}}"
+ color-by="[[colorBy]]"
+ color-by-params="[[colorByParams]]"
></tf-graph-info>
</div>
</div>
@@ -126,6 +130,7 @@ Polymer({
* for the progress bar and the displayed message.
*/
progress: Object,
+ colorBy: String,
colorByParams: {
type: Object,
notify: true,
@@ -133,6 +138,7 @@ Polymer({
// Private API: Data routing between child components.
_selectedNode: String,
_highlightedNode: String,
+ _renderHierarchy: Object,
},
/** True if the progress is not complete yet (< 100 %). */
_isNotComplete: function(progress) {
diff --git a/tensorflow/tensorboard/components/tf-graph-common/lib/hierarchy.ts b/tensorflow/tensorboard/components/tf-graph-common/lib/hierarchy.ts
index 6c9333de4c..8ada82770c 100644
--- a/tensorflow/tensorboard/components/tf-graph-common/lib/hierarchy.ts
+++ b/tensorflow/tensorboard/components/tf-graph-common/lib/hierarchy.ts
@@ -28,6 +28,7 @@ export interface Hierarchy {
getPredecessors(nodeName: string): Edges;
getSuccessors(nodeName: string): Edges;
getTopologicalOrdering(nodeName: string): { [childName: string]: number };
+ getTemplateIndex(): (string) => number;
}
/**
@@ -323,6 +324,17 @@ class HierarchyImpl implements Hierarchy {
return ordering;
}
+ /**
+ * Returns a d3 Ordinal function that can be used to look up the index of
+ * a node based on its template id.
+ */
+ getTemplateIndex(): (string) => number {
+ let templateNames = d3.keys(this.templates);
+ let templateIndex = d3.scale.ordinal()
+ .domain(templateNames)
+ .range(d3.range(0, templateNames.length));
+ return (templateId: string) => <number>templateIndex(templateId);
+ }
}
/**
diff --git a/tensorflow/tensorboard/components/tf-graph-common/lib/render.ts b/tensorflow/tensorboard/components/tf-graph-common/lib/render.ts
index 363f006fd5..b4970a9cdc 100644
--- a/tensorflow/tensorboard/components/tf-graph-common/lib/render.ts
+++ b/tensorflow/tensorboard/components/tf-graph-common/lib/render.ts
@@ -8,10 +8,23 @@
module tf.graph.render {
/**
+ * Color parameters for op nodes.
+ */
+export let OpNodeColors = {
+ DEFAULT_FILL: "white",
+ DEFAULT_STROKE: "#b2b2b2"
+};
+
+/**
* Color parameters for node encoding.
* @type {Object}
*/
export let MetanodeColors = {
+ /**
+ * Default fill and stroke to use when no other information is available.
+ */
+ DEFAULT_FILL: "#d9d9d9",
+ DEFAULT_STROKE: "#a6a6a6",
SATURATION: 0.6,
LIGHTNESS: 0.85,
/**
@@ -42,6 +55,14 @@ export let MetanodeColors = {
};
/**
+ * Color parameters for op nodes.
+ */
+export let SeriesNodeColors = {
+ DEFAULT_FILL: "white",
+ DEFAULT_STROKE: "#b2b2b2"
+};
+
+/**
* Parameters that affect how the graph is rendered on the screen.
*/
interface RenderGraphParams {
@@ -169,11 +190,66 @@ export class RenderGraphInformation {
this.root.expanded = true;
}
+ /**
+ * Get a previously created RenderNodeInformation by its node name.
+ */
getRenderNodeByName(nodeName: string): RenderNodeInformation {
return this.index[nodeName];
}
/**
+ * Get a previously created RenderNodeInformation for the specified node name,
+ * or create one if it hasn't been created yet.
+ */
+ getOrCreateRenderNodeByName(nodeName: string): RenderNodeInformation {
+ // Polymer may invoke this with null.
+ if (!nodeName) {
+ return null;
+ }
+
+ if (nodeName in this.index) {
+ return this.index[nodeName];
+ }
+
+ let node = this.hierarchy.node(nodeName);
+ let renderInfo = node.isGroupNode ?
+ new RenderGroupNodeInformation(<GroupNode>node) :
+ new RenderNodeInformation(node);
+ this.index[nodeName] = renderInfo;
+
+ if (node.stats) {
+ renderInfo.memoryColor = this.memoryUsageScale(node.stats.totalBytes);
+ renderInfo.computeTimeColor =
+ this.computeTimeScale(node.stats.totalMicros);
+ }
+
+ if (node.isGroupNode) {
+ // Make a list of tuples (device, proportion), where proportion
+ // is the fraction of op nodes that have that device.
+ let pairs = _.pairs((<GroupNode>node).deviceHistogram);
+ if (pairs.length > 0) {
+ // Compute the total # of devices.
+ let numDevices = _.sum(pairs, _.last);
+ renderInfo.deviceColors = _.map(pairs, pair => ({
+ color: this.deviceColorMap(pair[0]),
+ // Normalize to a proportion of total # of devices.
+ proportion: pair[1] / numDevices
+ }));
+ }
+ } else {
+ let device = (<OpNode>renderInfo.node).device;
+ if (device) {
+ renderInfo.deviceColors = [{
+ color: this.deviceColorMap(device),
+ proportion: 1.0
+ }];
+ }
+ }
+
+ return this.index[nodeName];
+ }
+
+ /**
* Return the nearest ancestor node, including itself, that is visible
* in the visualization. This method is used so that we can select
* (highlight) a node that isn't drawn yet, by selecting (highlighting)
@@ -223,19 +299,10 @@ export class RenderGraphInformation {
// groups between which there is no visible path (other than annotations).
_.each(metagraph.nodes(), childName => {
- let childNode = metagraph.node(childName);
- let childRenderInfo = childNode.isGroupNode ?
- new RenderGroupNodeInformation(<GroupNode>childNode) :
- new RenderNodeInformation(childNode);
- this.index[childName] = childRenderInfo;
- coreGraph.setNode(childName, childRenderInfo);
+ let childRenderInfo = this.getOrCreateRenderNodeByName(childName);
+ let childNode = childRenderInfo.node;
- if (childRenderInfo.node.stats != null) {
- childRenderInfo.memoryColor =
- this.memoryUsageScale(childRenderInfo.node.stats.totalBytes);
- childRenderInfo.computeTimeColor =
- this.computeTimeScale(childRenderInfo.node.stats.totalMicros);
- }
+ coreGraph.setNode(childName, childRenderInfo);
if (!childNode.isGroupNode) {
_.each((<OpNode>childNode).inEmbeddings, embedding => {
@@ -250,29 +317,8 @@ export class RenderGraphInformation {
AnnotationType.SUMMARY, this.params);
this.index[embedding.name] = new RenderNodeInformation(embedding);
});
- let device = (<OpNode>childRenderInfo.node).device;
- if (device != null) {
- childRenderInfo.deviceColors = [{
- color: this.deviceColorMap(device),
- proportion: 1.0
- }];
- }
- } else {
- // Make a list of tuples (device, proportion), where proportion
- // is the fraction of op nodes that have that device.
- let pairs = _.pairs((<GroupNode> childNode).deviceHistogram);
- if (pairs.length > 0) {
- // Compute the total # of devices.
- let numDevices = _.sum(pairs, _.last);
- childRenderInfo.deviceColors = _.map(pairs, pair => {
- return {
- color: this.deviceColorMap(pair[0]),
- // Normalize to a proportion of total # of devices.
- proportion: pair[1] / numDevices
- };
- });
- }
}
+
});
// Add render metaedge info for edges in the metagraph.
diff --git a/tensorflow/tensorboard/components/tf-graph-common/lib/scene/node.ts b/tensorflow/tensorboard/components/tf-graph-common/lib/scene/node.ts
index 8c74b37e07..37d409f41a 100644
--- a/tensorflow/tensorboard/components/tf-graph-common/lib/scene/node.ts
+++ b/tensorflow/tensorboard/components/tf-graph-common/lib/scene/node.ts
@@ -411,31 +411,27 @@ function position(nodeGroup, d: render.RenderNodeInformation, sceneBehavior) {
};
/** Enum specifying the options to color nodes by */
-let ColorBy = {
- STRUCTURE: 0,
- DEVICE: 1,
- COMPUTE_TIME: 2,
- MEMORY: 3
-};
+export enum ColorBy { STRUCTURE, DEVICE, COMPUTE_TIME, MEMORY };
/**
* Returns the fill color for the node given its state and the "color by"
* option.
*/
-function getFillForNode(sceneBehavior, colorBy,
+export function getFillForNode(templateIndex, colorBy,
renderInfo: render.RenderNodeInformation, isExpanded: boolean): string {
let colorParams = tf.graph.render.MetanodeColors;
switch (colorBy) {
case ColorBy.STRUCTURE:
if (renderInfo.node.type === tf.graph.NodeType.META) {
let tid = (<Metanode>renderInfo.node).templateId;
- return tid === null ? colorParams.UNKNOWN : colorParams.STRUCTURE_PALETTE(
- sceneBehavior.templateIndex(tid), renderInfo.expanded);
+ return tid === null ?
+ colorParams.UNKNOWN :
+ colorParams.STRUCTURE_PALETTE(templateIndex(tid), isExpanded);
} else if (renderInfo.node.type === tf.graph.NodeType.SERIES) {
// If expanded, we're showing the background rect, which we want to
// appear gray. Otherwise we're showing a stack of ellipses which we
// want to show white.
- return renderInfo.expanded ? colorParams.EXPANDED_COLOR : "white";
+ return isExpanded ? colorParams.EXPANDED_COLOR : "white";
} else if (renderInfo.node.type === NodeType.BRIDGE) {
return renderInfo.structural ? "#f0e" :
(<BridgeNode>renderInfo.node).inbound ? "#0ef" : "#fe0";
@@ -504,22 +500,24 @@ export function stylize(nodeGroup, renderInfo: render.RenderNodeInformation,
// Main node always exists here and it will be reached before subscene,
// so d3 selection is fine here.
let node = nodeGroup.select("." + nodeClass + " ." + Class.Node.COLOR_TARGET);
- let fillColor = getFillForNode(sceneBehavior,
+ let fillColor = getFillForNode(sceneBehavior.templateIndex,
ColorBy[sceneBehavior.colorBy.toUpperCase()],
renderInfo, isExpanded);
node.style("fill", fillColor);
// Choose outline to be darker version of node color if the node is a single
// color and is not selected.
- if (isSelected) {
- node.style("stroke", null);
- } else {
- // If node is colored by a gradient, then use a dark gray outline.
- let outlineColor = fillColor.substring(0, 3) === "url" ?
- tf.graph.render.MetanodeColors.GRADIENT_OUTLINE :
- d3.rgb(fillColor).darker().toString();
- node.style("stroke", outlineColor);
- }
+ node.style("stroke", isSelected ? null : getStrokeForFill(fillColor));
};
+/**
+ * Given a node's fill color/gradient, determine the stroke for the node.
+ */
+export function getStrokeForFill(fill: string) {
+ // If node is colored by a gradient, then use a dark gray outline.
+ return fill.substring(0, 3) === "url" ?
+ tf.graph.render.MetanodeColors.GRADIENT_OUTLINE :
+ d3.rgb(fill).darker().toString();
+}
+
} // close module
diff --git a/tensorflow/tensorboard/components/tf-graph-info/tf-graph-info.html b/tensorflow/tensorboard/components/tf-graph-info/tf-graph-info.html
index b7900f86de..e47269861b 100644
--- a/tensorflow/tensorboard/components/tf-graph-info/tf-graph-info.html
+++ b/tensorflow/tensorboard/components/tf-graph-info/tf-graph-info.html
@@ -18,10 +18,12 @@ h2 {
</style>
<template is="dom-if" if="{{selectedNode}}">
<paper-material elevation="1" class="card">
- <tf-node-info graph-hierarchy='[[graphHierarchy]]'
+ <tf-node-info graph-hierarchy="[[graphHierarchy]]"
+ render-hierarchy="[[renderHierarchy]]"
flat-graph="[[graph]]"
- node-name='[[selectedNode]]'
- highlighted-node='{{highlightedNode}}'>
+ node-name="[[selectedNode]]"
+ highlighted-node="{{highlightedNode}}"
+ color-by="[[colorBy]]">
</tf-node-info>
</paper-material>
</template>
@@ -35,6 +37,8 @@ h2 {
title: String,
graphHierarchy: Object,
graph: Object,
+ renderHierarchy: Object,
+ colorBy: String,
// Two-ways
selectedNode: {
type: String,
diff --git a/tensorflow/tensorboard/components/tf-graph-info/tf-node-info.html b/tensorflow/tensorboard/components/tf-graph-info/tf-node-info.html
index 5044bf2bb1..056c851924 100644
--- a/tensorflow/tensorboard/components/tf-graph-info/tf-node-info.html
+++ b/tensorflow/tensorboard/components/tf-graph-info/tf-node-info.html
@@ -88,6 +88,10 @@
max-width: 20px;
padding: 0;
}
+
+ .non-control-list-item {
+ padding-left: 10px;
+ }
</style>
<template>
<paper-item>
@@ -101,7 +105,11 @@
<div class="node-name">[[_getNodeName(nodeName)]]</div>
</div>
<div secondary>
- <tf-graph-icon class="node-icon" node="[[_node]]"></tf-graph-icon>
+ <tf-graph-icon class="node-icon" node="[[_node]]"
+ render-info="[[_getRenderInfo(nodeName, renderHierarchy)]]"
+ color-by="[[colorBy]]"
+ template-index="[[_templateIndex]]"
+ ></tf-graph-icon>
<template is="dom-if" if="{{_node.op}}">
<div class="subtitle">
Operation:
@@ -147,10 +155,15 @@
<iron-list class="sub-list" id ="inputsList"
items="[[_predecessors.regular]]">
<template>
- <tf-node-list-item card-node="[[_node]]"
- item-node="[[_getNode(item, graphHierarchy)]]"
- name="[[item]]"
- item-type="predecessors">
+ <tf-node-list-item
+ class="non-control-list-item"
+ card-node="[[_node]]"
+ item-node="[[_getNode(item, graphHierarchy)]]"
+ item-render-info="[[_getRenderInfo(item, renderHierarchy)]]"
+ name="[[item]]"
+ item-type="predecessors"
+ color-by="[[colorBy]]"
+ template-index="[[_templateIndex]]">
</tf-node-list-item>
</template>
</iron-list>
@@ -168,10 +181,15 @@
<template is="dom-if" if="{{_openedControlPred}}" restamp="true">
<iron-list class="sub-list" items="[[_predecessors.control]]">
<template>
- <tf-node-list-item card-node="[[_node]]"
- item-node="[[_getNode(item, graphHierarchy)]]"
- name="[[item]]"
- item-type="predecessors">
+ <tf-node-list-item
+ card-node="[[_node]]"
+ item-node="[[_getNode(item, graphHierarchy)]]"
+ item-render-info=
+ "[[_getRenderInfo(item, renderHierarchy)]]"
+ name="[[item]]"
+ item-type="predecessors"
+ color-by="[[colorBy]]"
+ template-index="[[_templateIndex]]">
</tf-node-list-item>
</template>
</iron-list>
@@ -187,10 +205,15 @@
<iron-list class="sub-list" id ="outputsList"
items="[[_successors.regular]]">
<template>
- <tf-node-list-item card-node="[[_node]]"
- item-node="[[_getNode(item, graphHierarchy)]]"
- name="[[item]]"
- item-type="successor">
+ <tf-node-list-item
+ class="non-control-list-item"
+ card-node="[[_node]]"
+ item-node="[[_getNode(item, graphHierarchy)]]"
+ item-render-info="[[_getRenderInfo(item, renderHierarchy)]]"
+ name="[[item]]"
+ item-type="successor"
+ color-by="[[colorBy]]"
+ template-index="[[_templateIndex]]">
</tf-node-list-item>
</template>
</iron-list>
@@ -208,10 +231,15 @@
<template is="dom-if" if="{{_openedControlSucc}}" restamp="true">
<iron-list class="sub-list" items="[[_successors.control]]">
<template>
- <tf-node-list-item card-node="[[_node]]"
- item-node="[[_getNode(item, graphHierarchy)]]"
- name="[[item]]"
- item-type="successors">
+ <tf-node-list-item
+ card-node="[[_node]]"
+ item-node="[[_getNode(item, graphHierarchy)]]"
+ item-render-info=
+ "[[_getRenderInfo(item, renderHierarchy)]]"
+ name="[[item]]"
+ item-type="successors"
+ color-by="[[colorBy]]"
+ template-index="[[_templateIndex]]">
</tf-node-list-item>
</template>
</iron-list>
@@ -233,6 +261,13 @@
properties: {
nodeName: String,
graphHierarchy: Object,
+ renderHierarchy: Object,
+ /** What to color the nodes by (compute time, memory, device etc.) */
+ colorBy: String,
+ _templateIndex: {
+ type: Function,
+ computed: '_getTemplateIndex(graphHierarchy)'
+ },
_node: {
type: Object,
computed: '_getNode(nodeName, graphHierarchy)',
@@ -282,14 +317,20 @@
expandNode: function() {
this.fire('_node.expand', this.node);
},
- _getNode: function(n, graphHierarchy) {
- return graphHierarchy.node(n);
+ _getTemplateIndex: function(graphHierarchy) {
+ return graphHierarchy.getTemplateIndex();
+ },
+ _getNode: function(nodeName, graphHierarchy) {
+ return graphHierarchy.node(nodeName);
},
_getNodeName: function(nodeName) {
// Insert a zero-width whitespace character before each slash so that
// long node names wrap cleanly at path boundaries.
return (nodeName || '').replace(/\//g, '\u200B/');
},
+ _getRenderInfo: function(nodeName, renderHierarchy) {
+ return this.renderHierarchy.getOrCreateRenderNodeByName(nodeName);
+ },
_getAttributes: function(node) {
this.async(this._resizeList.bind(this, "#attributesList"));
return node && node.attr ? node.attr.map(function(entry) {
diff --git a/tensorflow/tensorboard/components/tf-graph-info/tf-node-list-item.html b/tensorflow/tensorboard/components/tf-graph-info/tf-node-list-item.html
index f16e9e4511..f97efcd9b7 100644
--- a/tensorflow/tensorboard/components/tf-graph-info/tf-node-list-item.html
+++ b/tensorflow/tensorboard/components/tf-graph-info/tf-node-list-item.html
@@ -39,8 +39,10 @@
on-mouseover="_nodeListener"
on-mouseout="_nodeListener"
on-click="_nodeListener">
- <tf-graph-icon class="node-icon"
- node="[[itemNode]]" height="12"></tf-graph-icon>
+ <tf-graph-icon class="node-icon" height="12"
+ color-by="[[colorBy]]" color-by-params="[[colorByParams]]"
+ node="[[itemNode]]" render-info="[[itemRenderInfo]]"
+ template-index="[[templateIndex]]"></tf-graph-icon>
<span title$="[[name]]">[[name]]</span>
</div>
</template>
@@ -61,11 +63,19 @@
* @type {tf.graph.Node}
*/
itemNode: Object,
+ /**
+ * The render node information for the item node. Used by the graph
+ * icon in determining fill color.
+ */
+ itemRenderInfo: Object,
name: String,
itemType: {
type: String,
observer: '_itemTypeChanged'
- }
+ },
+ colorBy: String,
+ colorByParams: Object,
+ templateIndex: Function,
},
_itemTypeChanged: function() {
diff --git a/tensorflow/tensorboard/components/tf-graph/tf-graph-icon.html b/tensorflow/tensorboard/components/tf-graph/tf-graph-icon.html
index fafaa3b954..d752bb583f 100644
--- a/tensorflow/tensorboard/components/tf-graph/tf-graph-icon.html
+++ b/tensorflow/tensorboard/components/tf-graph/tf-graph-icon.html
@@ -6,7 +6,9 @@
<template is="dom-if" if="[[_isConst(node, const)]]">
<svg height$="[[height]]"
preserveAspectRatio="xMinYMid meet" viewBox="0 0 10 10">
- <circle fill="white" stroke="#848484" cx="5" cy="5" r="3" />
+ <circle cx="5" cy="5" r="3"
+ fill$="[[_getFill(_computedFill, 'OP')]]"
+ stroke$="[[_getStroke(_computedFill, 'OP')]]" />
</svg>
</template>
<template is="dom-if" if="[[_isSummary(node, summary)]]">
@@ -18,8 +20,8 @@
preserveAspectRatio="xMinYMid meet" viewBox="0 0 16 8">
<use xmlns:xlink="http://www.w3.org/1999/xlink"
xlink:href="#op-node-stamp"
- fill="white"
- stroke="#ccc"
+ fill$="[[_getFill(_computedFill, 'OP')]]"
+ stroke$="[[_getStroke(_computedFill, 'OP')]]"
x="8" y="4" />
</svg>
</template>
@@ -28,7 +30,9 @@
<svg height$="[[height]]"
preserveAspectRatio="xMinYMid meet" viewBox="0 0 37 16">
<rect x="1" y="1"
- fill="#d9d9d9" stroke="#ccc" stroke-width="2px"
+ fill$="[[_getFill(_computedFill, 'META')]]"
+ stroke$="[[_getStroke(_computedFill, 'META')]]"
+ stroke-width="2px"
height="14" width="35"
rx="5" ry="5" />
</svg>
@@ -39,8 +43,8 @@
preserveAspectRatio="xMinYMid meet" viewBox="0 0 16 15">
<use xmlns:xlink="http://www.w3.org/1999/xlink"
xlink:href="#op-series-vertical-stamp"
- fill="white"
- stroke="#ccc"
+ fill$="[[_getFill(_computedFill, 'SERIES')]]"
+ stroke$="[[_getStroke(_computedFill, 'SERIES')]]"
x="0" y="2" />
</svg>
</template>
@@ -49,8 +53,8 @@
preserveAspectRatio="xMinYMid meet" viewBox="0 0 24 10">
<use xmlns:xlink="http://www.w3.org/1999/xlink"
xlink:href="#op-series-horizontal-stamp"
- fill="white"
- stroke="#ccc"
+ fill$="[[_getFill(_computedFill, 'SERIES')]]"
+ stroke$="[[_getStroke(_computedFill, 'SERIES')]]"
x="0" y="1" />
</svg>
</template>
@@ -74,7 +78,36 @@
value: null
},
- /** Type of node to draw. */
+ /**
+ * Render node information associated with this node. Optional. If
+ * specified, this is only used when computing the fill of the icon
+ * element.
+ * @type {tf.graph.render.RenderNodeInformation}
+ */
+ renderInfo: {
+ type: Object,
+ value: null
+ },
+
+ /**
+ * String indicating the type of coloring to use for this node, used
+ * only for deterimining the fill.
+ */
+ colorBy: {
+ type: Object,
+ value: "structural"
+ },
+
+ /**
+ * Function used by structural coloring algorithim to determine which
+ * color to use based on the template ID of the node. Optional.
+ */
+ templateIndex: {
+ type: Function,
+ value: null
+ },
+
+ /** Type of node to draw (ignored if node is set). */
type: {
type: String,
value: null
@@ -98,11 +131,70 @@
value: false
},
+ /**
+ * Fill for the icon, optional. If fill is specified and node is not
+ * specified, then this value will override the default for the
+ * element. However, if node is specified, this value will be ignored.
+ */
+ fill: {
+ type: String,
+ value: null
+ },
+
/** Height of the SVG element in pixels, used for scaling. */
height: {
type: Number,
value: 20
+ },
+
+ /** The computed fill for the node. **/
+ _computedFill: {
+ type: String,
+ computed:
+ "_getComputedFill(node, renderInfo, colorBy, templateIndex, fill)"
}
+
+ },
+
+ /**
+ * Get the computed fill value for the element.
+ */
+ _getComputedFill: function(inputNode, inputRenderInfo, inputColorBy,
+ inputTemplateIndex, inputFill) {
+ if (inputNode && inputRenderInfo &&
+ inputColorBy && inputTemplateIndex) {
+ var ns = tf.graph.scene.node;
+ var colorBy = ns.ColorBy[inputColorBy.toUpperCase()];
+ return ns.getFillForNode(inputTemplateIndex, colorBy,
+ inputRenderInfo, false);
+ }
+ return inputFill;
+ },
+
+ /**
+ * Get the fill value for the element, or if that's not possible, return
+ * the default fill value for the node type.
+ */
+ _getFill: function(inputComputedFill, inputNodeType) {
+ return inputComputedFill || ({
+ OP: tf.graph.render.OpNodeColors.DEFAULT_FILL,
+ META: tf.graph.render.MetanodeColors.DEFAULT_FILL,
+ SERIES: tf.graph.render.SeriesNodeColors.DEFAULT_FILL
+ })[inputNodeType];
+ },
+
+ /**
+ * Get the stroke value for the element, or if that's not possible,
+ * return the default stroke value for the node type.
+ */
+ _getStroke: function(inputComputedFill, inputNodeType) {
+ return inputComputedFill ?
+ tf.graph.scene.node.getStrokeForFill(inputComputedFill) :
+ ({
+ OP: tf.graph.render.OpNodeColors.DEFAULT_STROKE,
+ META: tf.graph.render.MetanodeColors.DEFAULT_STROKE,
+ SERIES: tf.graph.render.SeriesNodeColors.DEFAULT_STROKE
+ })[inputNodeType];
},
/**
diff --git a/tensorflow/tensorboard/components/tf-graph/tf-graph-scene.html b/tensorflow/tensorboard/components/tf-graph/tf-graph-scene.html
index 34c2d3dc3d..5984cb67ec 100644
--- a/tensorflow/tensorboard/components/tf-graph/tf-graph-scene.html
+++ b/tensorflow/tensorboard/components/tf-graph/tf-graph-scene.html
@@ -4,8 +4,8 @@
<script src="../tf-graph-common/lib/layout.js"></script>
<script src="../../bower_components/dagre/dist/dagre.core.js"></script>
<!--
- A module that takes graph-hierarchy as input and produce
- a svg dom using dagre and d3.
+ A module that takes a render hierarchy as input and produces an SVG DOM using
+ dagre and d3.
-->
<dom-module id="tf-graph-scene">
<template>
@@ -94,7 +94,7 @@
Polymer({
is: 'tf-graph-scene',
properties: {
- graphHierarchy: Object,
+ renderHierarchy: Object,
name: String,
colorBy: {
type: String,
@@ -134,9 +134,10 @@ Polymer({
/**
* @type {d3.scale.ordinal}
* Scale mapping from template name to a number between 0 and N-1
- * where N is the number of different template names.
+ * where N is the number of different template names. Used by
+ * tf.graph.scene.node when computing node color by structure.
*/
- templateIndex: Object,
+ templateIndex: Function,
/**
* @type {tf.scene.Minimap}
* A minimap object to notify for zoom events.
@@ -198,16 +199,17 @@ Polymer({
progress: Object
},
observers: [
- '_buildAndFit(graphHierarchy)'
+ '_buildAndFit(renderHierarchy)'
],
getNode: function(nodeName) {
- return this.graphHierarchy.getRenderNodeByName(nodeName);
+ return this.renderHierarchy.getRenderNodeByName(nodeName);
},
isNodeExpanded: function(node) {
return node.expanded;
},
setNodeExpanded: function(renderNode) {
- this._build(this.graphHierarchy);
+ this._build(this.renderHierarchy);
+ this._updateLabels(!this._zoomed);
},
/**
* Resets the state of the component. Called whenever the whole graph
@@ -226,20 +228,15 @@ Polymer({
.selectAll('*').remove();
},
/** Main method for building the scene */
- _build: function(graphHierarchy) {
- if (!graphHierarchy) { return; } //handle untruthy input
- var templateNames = d3.keys(graphHierarchy.hierarchy.templates);
-
- this.templateIndex = d3.scale.ordinal()
- .domain(templateNames)
- .range(d3.range(0, templateNames.length));
+ _build: function(renderHierarchy) {
+ this.templateIndex = renderHierarchy.hierarchy.getTemplateIndex();
tf.time('tf-graph-scene (layout):', function() {
// layout the scene for this meta / series node
- tf.graph.layout.scene(graphHierarchy.root, this);
+ tf.graph.layout.scene(renderHierarchy.root, this);
}.bind(this));
tf.time('tf-graph-scene (build scene):', function() {
- tf.graph.scene.buildGroup(d3.select(this.$.root), graphHierarchy.root, this);
+ tf.graph.scene.buildGroup(d3.select(this.$.root), renderHierarchy.root, this);
tf.graph.scene.addGraphClickListener(this.$.svg, this);
}.bind(this));
// Update the minimap again when the graph is done animating.
@@ -302,21 +299,24 @@ Polymer({
tf.graph.layout.PARAMS.minimap.size,
tf.graph.layout.PARAMS.subscene.meta.labelHeight);
},
- _buildAndFit: function(graphHierarchy) {
+ _buildAndFit: function(renderHierarchy) {
this._resetState();
- this._build(graphHierarchy);
+ this._build(renderHierarchy);
// Fit to screen after the graph is done animating.
setTimeout(this.fit.bind(this), tf.graph.layout.PARAMS.animation.duration);
},
_updateLabels: function(showLabels) {
var titleStyle = this.getElementsByClassName('title')[0].style;
var auxTitleStyle = this.getElementsByClassName('auxTitle')[0].style;
- var core = this.getElementsByClassName(tf.graph.scene.Class.Scene.CORE)[0];
+ var core = d3.select("." + tf.graph.scene.Class.Scene.GROUP + ">." +
+ tf.graph.scene.Class.Scene.CORE)[0][0];
// Only show labels if the graph is fully loaded.
if (showLabels && core && this.progress && this.progress.value === 100) {
var aux =
- this.getElementsByClassName(tf.graph.scene.Class.Scene.INEXTRACT)[0] ||
- this.getElementsByClassName(tf.graph.scene.Class.Scene.OUTEXTRACT)[0];
+ d3.select("." + tf.graph.scene.Class.Scene.GROUP + ">." +
+ tf.graph.scene.Class.Scene.INEXTRACT)[0][0] ||
+ d3.select("." + tf.graph.scene.Class.Scene.GROUP + ">." +
+ tf.graph.scene.Class.Scene.OUTEXTRACT)[0][0];
var coreX = core.getCTM().e;
var auxX = aux ? aux.getCTM().e : null;
titleStyle.display = 'inline';
@@ -421,7 +421,7 @@ Polymer({
}
// Update the minimap to reflect the highlighted (selected) node.
this.minimap.update();
- var node = this.graphHierarchy.hierarchy.node(selectedNode);
+ var node = this.renderHierarchy.hierarchy.node(selectedNode);
var nodeParents = [];
// Create list of all metanode parents of the selected node.
while (node.parentNode != null
@@ -432,8 +432,8 @@ Polymer({
// Ensure each parent metanode is built and expanded.
var topParentNodeToBeExpanded;
_.forEachRight(nodeParents, function(parentName) {
- this.graphHierarchy.buildSubhierarchy(parentName);
- var renderNode = this.graphHierarchy.getRenderNodeByName(parentName);
+ this.renderHierarchy.buildSubhierarchy(parentName);
+ var renderNode = this.renderHierarchy.getRenderNodeByName(parentName);
if (renderNode.node.isGroupNode && !renderNode.expanded) {
renderNode.expanded = true;
if (!topParentNodeToBeExpanded) {
diff --git a/tensorflow/tensorboard/components/tf-graph/tf-graph.html b/tensorflow/tensorboard/components/tf-graph/tf-graph.html
index 905d96e237..0bcd4d5521 100644
--- a/tensorflow/tensorboard/components/tf-graph/tf-graph.html
+++ b/tensorflow/tensorboard/components/tf-graph/tf-graph.html
@@ -39,7 +39,7 @@ paper-button {
<div class="vertical">
<h2>[[title]]</h2>
<tf-graph-scene id="scene" class="auto"
- graph-hierarchy="[[_renderHierarchy]]"
+ render-hierarchy="[[renderHierarchy]]"
highlighted-node="[[_getVisible(highlightedNode)]]"
selected-node="[[selectedNode]]"
color-by="[[colorBy]]"
@@ -78,6 +78,12 @@ Polymer({
notify: true,
readOnly: true, // Produces and doesn't consume.
},
+ renderHierarchy: {
+ type: Object,
+ readOnly: true,
+ notify: true,
+ computed: '_buildRenderHierarchy(graphHierarchy, _graphParams)'
+ },
// internal properties
_graphParams: {
type: Object,
@@ -89,12 +95,6 @@ Polymer({
type: Number,
value: 1
},
- _renderHierarchy: {
- type: Object,
- readOnly: true,
- notify: true,
- computed: '_buildRenderHierarchy(graphHierarchy, _graphParams)'
- },
_allowGraphSelect: {
type: Boolean,
value: true
@@ -142,7 +142,7 @@ Polymer({
if (!name) {
return name;
}
- return this._renderHierarchy.getNearestVisibleAncestor(name);
+ return this.renderHierarchy.getNearestVisibleAncestor(name);
},
listeners: {
'graph-select': '_graphSelected',
@@ -203,12 +203,12 @@ Polymer({
},
_nodeToggleExpand: function(event) {
var nodeName = event.detail.name;
- var renderNode = this._renderHierarchy.getRenderNodeByName(nodeName);
+ var renderNode = this.renderHierarchy.getRenderNodeByName(nodeName);
// Op nodes are not expandable.
if (renderNode.node.type === tf.graph.NodeType.OP) {
return;
}
- this._renderHierarchy.buildSubhierarchy(nodeName);
+ this.renderHierarchy.buildSubhierarchy(nodeName);
renderNode.expanded = !renderNode.expanded;
this.querySelector('#scene').setNodeExpanded(renderNode);
// Also select the expanded node.
diff --git a/tensorflow/tensorboard/float_wrapper.py b/tensorflow/tensorboard/float_wrapper.py
index 9fe45d9070..a437eb54bf 100644
--- a/tensorflow/tensorboard/float_wrapper.py
+++ b/tensorflow/tensorboard/float_wrapper.py
@@ -8,6 +8,10 @@ JSONEncoder nor passing a function in the |default| keyword argument overrides
this.
"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import math
@@ -20,7 +24,7 @@ def WrapSpecialFloats(obj):
elif isinstance(obj, float) and math.isnan(obj):
return 'NaN'
elif isinstance(obj, list) or isinstance(obj, tuple):
- return map(WrapSpecialFloats, obj)
+ return list(map(WrapSpecialFloats, obj))
elif isinstance(obj, dict):
return {
WrapSpecialFloats(k): WrapSpecialFloats(v)
diff --git a/tensorflow/tensorboard/float_wrapper_test.py b/tensorflow/tensorboard/float_wrapper_test.py
index 5f6594733c..773cf10e68 100644
--- a/tensorflow/tensorboard/float_wrapper_test.py
+++ b/tensorflow/tensorboard/float_wrapper_test.py
@@ -1,3 +1,7 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import tensorflow.python.platform
from tensorflow.python.platform import googletest
diff --git a/tensorflow/tensorboard/scripts/demo_from_server.py b/tensorflow/tensorboard/scripts/demo_from_server.py
index 9b453f82a0..26c0187d2d 100644
--- a/tensorflow/tensorboard/scripts/demo_from_server.py
+++ b/tensorflow/tensorboard/scripts/demo_from_server.py
@@ -1,10 +1,15 @@
+from __future__ import absolute_import
+from __future__ import division
from __future__ import print_function
+
import json
import os
import urllib2
import requests
import os.path
import shutil
+import six
+
class TensorBoardStaticSerializer(object):
"""Serialize all the routes from a TensorBoard server to static json."""
@@ -81,8 +86,8 @@ class TensorBoardStaticSerializer(object):
def Run(self):
"""Main method that loads and serializes everything."""
runs = self._RetrieveAndSave('runs')
- for run, tag_type_to_tags in runs.iteritems():
- for tag_type, tags in tag_type_to_tags.iteritems():
+ for run, tag_type_to_tags in six.iteritems(runs):
+ for tag_type, tags in six.iteritems(tag_type_to_tags):
try:
if tag_type == 'graph':
if tags:
diff --git a/tensorflow/tensorboard/tensorboard.py b/tensorflow/tensorboard/tensorboard.py
index c75db0d1f7..e1a09655f2 100644
--- a/tensorflow/tensorboard/tensorboard.py
+++ b/tensorflow/tensorboard/tensorboard.py
@@ -3,6 +3,8 @@
This is a simple web server to proxy data from the event_loader to the web, and
serve static web files.
"""
+from __future__ import absolute_import
+from __future__ import division
from __future__ import print_function
import BaseHTTPServer
diff --git a/tensorflow/tensorboard/tensorboard_handler.py b/tensorflow/tensorboard/tensorboard_handler.py
index cd50f43069..0ea7f3a58d 100644
--- a/tensorflow/tensorboard/tensorboard_handler.py
+++ b/tensorflow/tensorboard/tensorboard_handler.py
@@ -5,6 +5,10 @@ and for handling the API calls to endpoints like /tags that require information
about loaded events.
"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import BaseHTTPServer
import csv
import gzip
@@ -16,6 +20,7 @@ import StringIO
import urllib
import urlparse
+from six.moves import xrange # pylint: disable=redefined-builtin
from google.protobuf import text_format
import tensorflow.python.platform
@@ -316,7 +321,7 @@ class TensorboardHandler(BaseHTTPServer.BaseHTTPRequestHandler):
# Strip off the leading forward slash.
path = path.lstrip('/')
if not self._path_is_safe(path):
- logging.info('path %s not safe, sending 404' % path)
+ logging.info('path %s not safe, sending 404', path)
# Traversal attack, so 404.
self.send_error(404)
return
@@ -329,7 +334,7 @@ class TensorboardHandler(BaseHTTPServer.BaseHTTPRequestHandler):
try:
contents = resource_loader.load_resource(path)
except IOError:
- logging.info('path %s not found, sending 404' % path)
+ logging.info('path %s not found, sending 404', path)
self.send_error(404)
return
diff --git a/tensorflow/tools/docker/simple_console.py b/tensorflow/tools/docker/simple_console.py
index b2fdc739e7..a7872b222f 100644
--- a/tensorflow/tools/docker/simple_console.py
+++ b/tensorflow/tools/docker/simple_console.py
@@ -1,14 +1,18 @@
"""Start a simple interactive console with TensorFlow available."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import code
import sys
def main(_):
- """Run an interactive console."""
- code.interact()
- return 0
+ """Run an interactive console."""
+ code.interact()
+ return 0
if __name__ == '__main__':
- sys.exit(main(sys.argv))
+ sys.exit(main(sys.argv))
diff --git a/tensorflow/tools/pip_package/setup.py b/tensorflow/tools/pip_package/setup.py
index e5d25c52ee..f6ddca8b25 100644
--- a/tensorflow/tools/pip_package/setup.py
+++ b/tensorflow/tools/pip_package/setup.py
@@ -1,3 +1,7 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import fnmatch
import os
from setuptools import find_packages, setup, Extension
diff --git a/tensorflow/tools/pip_package/simple_console.py b/tensorflow/tools/pip_package/simple_console.py
index b2fdc739e7..a7872b222f 100644
--- a/tensorflow/tools/pip_package/simple_console.py
+++ b/tensorflow/tools/pip_package/simple_console.py
@@ -1,14 +1,18 @@
"""Start a simple interactive console with TensorFlow available."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
import code
import sys
def main(_):
- """Run an interactive console."""
- code.interact()
- return 0
+ """Run an interactive console."""
+ code.interact()
+ return 0
if __name__ == '__main__':
- sys.exit(main(sys.argv))
+ sys.exit(main(sys.argv))