-rw-r--r-- RELEASE.md | 2
-rw-r--r-- configure.py | 18
-rw-r--r-- tensorflow/BUILD | 15
-rw-r--r-- tensorflow/cc/BUILD | 1
-rw-r--r-- tensorflow/compiler/jit/BUILD | 7
-rw-r--r-- tensorflow/compiler/jit/jit_compilation_pass_registration.cc | 8
-rw-r--r-- tensorflow/compiler/jit/kernels/xla_launch_op.cc | 3
-rw-r--r-- tensorflow/compiler/jit/mark_for_compilation_pass.cc | 102
-rw-r--r-- tensorflow/compiler/jit/mark_for_compilation_pass.h | 8
-rw-r--r-- tensorflow/compiler/jit/mark_for_compilation_pass_test.cc | 70
-rw-r--r-- tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc | 40
-rw-r--r-- tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h | 35
-rw-r--r-- tensorflow/compiler/jit/partially_decluster_pass.cc | 177
-rw-r--r-- tensorflow/compiler/jit/partially_decluster_pass.h | 58
-rw-r--r-- tensorflow/compiler/jit/partially_decluster_pass_test.cc | 284
-rw-r--r-- tensorflow/compiler/jit/xla_cluster_util.cc | 22
-rw-r--r-- tensorflow/compiler/jit/xla_cluster_util.h | 11
-rw-r--r-- tensorflow/compiler/jit/xla_compile_on_demand_op.cc | 3
-rw-r--r-- tensorflow/compiler/jit/xla_device.cc | 41
-rw-r--r-- tensorflow/compiler/jit/xla_device.h | 14
-rw-r--r-- tensorflow/compiler/jit/xla_device_context.cc | 89
-rw-r--r-- tensorflow/compiler/jit/xla_device_context.h | 31
-rw-r--r-- tensorflow/compiler/jit/xla_launch_util.cc | 62
-rw-r--r-- tensorflow/compiler/jit/xla_launch_util.h | 6
-rw-r--r-- tensorflow/compiler/jit/xla_tensor.cc | 7
-rw-r--r-- tensorflow/compiler/jit/xla_tensor.h | 6
-rw-r--r-- tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc | 20
-rw-r--r-- tensorflow/compiler/xla/service/BUILD | 14
-rw-r--r-- tensorflow/compiler/xla/service/algebraic_simplifier.cc | 9
-rw-r--r-- tensorflow/compiler/xla/service/algebraic_simplifier_test.cc | 31
-rw-r--r-- tensorflow/compiler/xla/service/buffer_assignment.cc | 1
-rw-r--r-- tensorflow/compiler/xla/service/call_graph.cc | 1
-rw-r--r-- tensorflow/compiler/xla/service/copy_insertion.cc | 9
-rw-r--r-- tensorflow/compiler/xla/service/copy_insertion_test.cc | 41
-rw-r--r-- tensorflow/compiler/xla/service/cpu/BUILD | 13
-rw-r--r-- tensorflow/compiler/xla/service/cpu/cpu_compiler.cc | 13
-rw-r--r-- tensorflow/compiler/xla/service/cpu/cpu_executable.cc | 38
-rw-r--r-- tensorflow/compiler/xla/service/cpu/cpu_executable.h | 10
-rw-r--r-- tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc | 4
-rw-r--r-- tensorflow/compiler/xla/service/cpu/ir_emitter.cc | 14
-rw-r--r-- tensorflow/compiler/xla/service/despecializer.cc | 25
-rw-r--r-- tensorflow/compiler/xla/service/gpu/BUILD | 35
-rw-r--r-- tensorflow/compiler/xla/service/gpu/buffer_comparator.cc | 205
-rw-r--r-- tensorflow/compiler/xla/service/gpu/buffer_comparator.h | 71
-rw-r--r-- tensorflow/compiler/xla/service/gpu/buffer_comparator_test.cc | 126
-rw-r--r-- tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc | 73
-rw-r--r-- tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h | 9
-rw-r--r-- tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc | 28
-rw-r--r-- tensorflow/compiler/xla/service/gpu/ir_emitter.cc | 2
-rw-r--r-- tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc | 2
-rw-r--r-- tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc | 20
-rw-r--r-- tensorflow/compiler/xla/service/hlo_creation_utils.cc | 55
-rw-r--r-- tensorflow/compiler/xla/service/hlo_creation_utils.h | 16
-rw-r--r-- tensorflow/compiler/xla/service/hlo_lexer.cc | 55
-rw-r--r-- tensorflow/compiler/xla/service/hlo_lexer.h | 1
-rw-r--r-- tensorflow/compiler/xla/service/hlo_matchers.h | 1
-rw-r--r-- tensorflow/compiler/xla/service/hlo_module_group_metadata.h | 4
-rw-r--r-- tensorflow/compiler/xla/service/hlo_module_group_util.cc | 72
-rw-r--r-- tensorflow/compiler/xla/service/hlo_parser.cc | 1
-rw-r--r-- tensorflow/compiler/xla/service/hlo_parser_test.cc | 75
-rw-r--r-- tensorflow/compiler/xla/service/hlo_token.h | 1
-rw-r--r-- tensorflow/compiler/xla/service/interpreter/executor.h | 2
-rw-r--r-- tensorflow/compiler/xla/service/scatter_expander.cc | 350
-rw-r--r-- tensorflow/compiler/xla/service/scatter_expander.h | 34
-rw-r--r-- tensorflow/compiler/xla/shape_util.cc | 11
-rw-r--r-- tensorflow/compiler/xla/tests/BUILD | 17
-rw-r--r-- tensorflow/compiler/xla/tests/client_library_test_base.h | 5
-rw-r--r-- tensorflow/compiler/xla/tests/reduce_window_test.cc | 19
-rw-r--r-- tensorflow/compiler/xla/tests/scatter_test.cc | 615
-rw-r--r-- tensorflow/compiler/xla/tests/test_utils.cc | 71
-rw-r--r-- tensorflow/compiler/xla/tests/test_utils_test.cc | 55
-rw-r--r-- tensorflow/compiler/xla/xla.proto | 19
-rw-r--r-- tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc | 11
-rw-r--r-- tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc | 16
-rw-r--r-- tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc | 18
-rw-r--r-- tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc | 16
-rw-r--r-- tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc | 14
-rw-r--r-- tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc | 14
-rw-r--r-- tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc | 16
-rw-r--r-- tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc | 184
-rw-r--r-- tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py | 75
-rw-r--r-- tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py | 127
-rw-r--r-- tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc | 3
-rw-r--r-- tensorflow/contrib/boosted_trees/proto/learner.proto | 8
-rw-r--r-- tensorflow/contrib/boosted_trees/proto/split_info.proto | 7
-rw-r--r-- tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py | 9
-rw-r--r-- tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py | 3
-rw-r--r-- tensorflow/contrib/checkpoint/__init__.py | 4
-rw-r--r-- tensorflow/contrib/constrained_optimization/python/candidates.py | 2
-rw-r--r-- tensorflow/contrib/constrained_optimization/python/constrained_minimization_problem.py | 19
-rw-r--r-- tensorflow/contrib/constrained_optimization/python/constrained_optimizer.py | 128
-rw-r--r-- tensorflow/contrib/constrained_optimization/python/external_regret_optimizer.py | 53
-rw-r--r-- tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer.py | 65
-rw-r--r-- tensorflow/contrib/data/kernels/assert_next_dataset_op.cc | 4
-rw-r--r-- tensorflow/contrib/data/kernels/csv_dataset_op.cc | 4
-rw-r--r-- tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc | 4
-rw-r--r-- tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc | 4
-rw-r--r-- tensorflow/contrib/data/kernels/prefetching_kernels.cc | 5
-rw-r--r-- tensorflow/contrib/data/kernels/threadpool_dataset_op.cc | 11
-rw-r--r-- tensorflow/contrib/data/kernels/unique_dataset_op.cc | 4
-rw-r--r-- tensorflow/contrib/distribute/__init__.py | 4
-rw-r--r-- tensorflow/contrib/distribute/python/combinations.py | 4
-rw-r--r-- tensorflow/contrib/distribute/python/estimator_integration_test.py | 16
-rw-r--r-- tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py | 52
-rw-r--r-- tensorflow/contrib/distribute/python/mirrored_strategy_test.py | 5
-rw-r--r-- tensorflow/contrib/distribute/python/parameter_server_strategy_test.py | 11
-rw-r--r-- tensorflow/contrib/distribute/python/strategy_test_lib.py | 12
-rw-r--r-- tensorflow/contrib/distribute/python/values.py | 39
-rw-r--r-- tensorflow/contrib/distributions/python/kernel_tests/deterministic_test.py | 14
-rw-r--r-- tensorflow/contrib/distributions/python/ops/deterministic.py | 3
-rw-r--r-- tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb | 2
-rw-r--r-- tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb | 7
-rw-r--r-- tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb | 4
-rw-r--r-- tensorflow/contrib/eager/python/examples/pix2pix/pix2pix_eager.ipynb | 2
-rw-r--r-- tensorflow/contrib/estimator/BUILD | 28
-rw-r--r-- tensorflow/contrib/estimator/__init__.py | 1
-rw-r--r-- tensorflow/contrib/estimator/python/estimator/exporter.py | 280
-rw-r--r-- tensorflow/contrib/estimator/python/estimator/exporter_test.py | 206
-rw-r--r-- tensorflow/contrib/estimator/python/estimator/hooks.py | 53
-rw-r--r-- tensorflow/contrib/estimator/python/estimator/hooks_test.py | 57
-rw-r--r-- tensorflow/contrib/factorization/python/ops/kmeans.py | 1
-rw-r--r-- tensorflow/contrib/gan/BUILD | 2
-rw-r--r-- tensorflow/contrib/gan/python/eval/python/summaries_impl.py | 91
-rw-r--r-- tensorflow/contrib/gan/python/eval/python/summaries_test.py | 40
-rw-r--r-- tensorflow/contrib/gan/python/train.py | 14
-rw-r--r-- tensorflow/contrib/hadoop/kernels/hadoop_dataset_ops.cc | 7
-rw-r--r-- tensorflow/contrib/integrate/python/ops/odes.py | 4
-rw-r--r-- tensorflow/contrib/kafka/kernels/kafka_dataset_ops.cc | 4
-rw-r--r-- tensorflow/contrib/kfac/examples/convnet.py | 8
-rw-r--r-- tensorflow/contrib/kfac/python/ops/estimator.py | 6
-rw-r--r-- tensorflow/contrib/kfac/python/ops/fisher_blocks.py | 2
-rw-r--r-- tensorflow/contrib/kfac/python/ops/fisher_factors.py | 12
-rw-r--r-- tensorflow/contrib/kfac/python/ops/layer_collection.py | 8
-rw-r--r-- tensorflow/contrib/kfac/python/ops/loss_functions.py | 6
-rw-r--r-- tensorflow/contrib/kfac/python/ops/optimizer.py | 8
-rw-r--r-- tensorflow/contrib/kinesis/kernels/kinesis_dataset_ops.cc | 4
-rw-r--r-- tensorflow/contrib/layers/python/layers/initializers.py | 2
-rw-r--r-- tensorflow/contrib/learn/BUILD | 3
-rw-r--r-- tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py | 6
-rw-r--r-- tensorflow/contrib/lite/build_def.bzl | 1
-rw-r--r-- tensorflow/contrib/lite/context.h | 8
-rw-r--r-- tensorflow/contrib/lite/delegates/eager/BUILD | 73
-rw-r--r-- tensorflow/contrib/lite/delegates/eager/delegate.cc | 13
-rw-r--r-- tensorflow/contrib/lite/delegates/eager/delegate.h | 4
-rw-r--r-- tensorflow/contrib/lite/delegates/eager/delegate_data.h | 10
-rw-r--r-- tensorflow/contrib/lite/delegates/eager/delegate_data_test.cc | 7
-rw-r--r-- tensorflow/contrib/lite/delegates/eager/delegate_test.cc | 52
-rw-r--r-- tensorflow/contrib/lite/delegates/eager/kernel.cc | 4
-rw-r--r-- tensorflow/contrib/lite/delegates/eager/kernel_test.cc | 8
-rw-r--r-- tensorflow/contrib/lite/error_reporter.cc | 13
-rw-r--r-- tensorflow/contrib/lite/examples/ios/camera/CameraExampleViewController.mm | 2
-rw-r--r-- tensorflow/contrib/lite/examples/ios/camera/Podfile | 2
-rw-r--r-- tensorflow/contrib/lite/examples/ios/simple/Podfile | 2
-rw-r--r-- tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.mm | 2
-rw-r--r-- tensorflow/contrib/lite/interpreter.cc | 4
-rw-r--r-- tensorflow/contrib/lite/interpreter.h | 2
-rw-r--r-- tensorflow/contrib/lite/interpreter_test.cc | 17
-rw-r--r-- tensorflow/contrib/lite/kernels/activations.cc | 58
-rw-r--r-- tensorflow/contrib/lite/kernels/activations_test.cc | 22
-rw-r--r-- tensorflow/contrib/lite/kernels/conv.cc | 48
-rw-r--r-- tensorflow/contrib/lite/kernels/conv_test.cc | 124
-rw-r--r-- tensorflow/contrib/lite/kernels/internal/BUILD | 5
-rw-r--r-- tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h | 36
-rw-r--r-- tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h | 476
-rw-r--r-- tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h | 36
-rw-r--r-- tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h | 425
-rw-r--r-- tensorflow/contrib/lite/kernels/internal/types.h | 56
-rw-r--r-- tensorflow/contrib/lite/kernels/mul.cc | 5
-rw-r--r-- tensorflow/contrib/lite/kernels/register.cc | 6
-rw-r--r-- tensorflow/contrib/lite/nnapi_delegate.cc | 27
-rw-r--r-- tensorflow/contrib/lite/python/lite.py | 66
-rw-r--r-- tensorflow/contrib/lite/rpi_makefile.inc | 33
-rw-r--r-- tensorflow/contrib/lite/testing/BUILD | 3
-rw-r--r-- tensorflow/contrib/lite/testing/generate_examples.py | 69
-rw-r--r-- tensorflow/contrib/lite/testing/generate_testspec.cc | 8
-rw-r--r-- tensorflow/contrib/lite/testing/generated_examples_zip_test.cc | 94
-rw-r--r-- tensorflow/contrib/lite/testing/tf_driver.cc | 4
-rw-r--r-- tensorflow/contrib/lite/testing/tflite_diff_flags.h | 27
-rw-r--r-- tensorflow/contrib/lite/testing/tflite_diff_util.cc | 2
-rw-r--r-- tensorflow/contrib/lite/testing/tflite_diff_util.h | 3
-rw-r--r-- tensorflow/contrib/lite/testing/tflite_driver.cc | 16
-rw-r--r-- tensorflow/contrib/lite/testing/tflite_driver.h | 4
-rw-r--r-- tensorflow/contrib/lite/toco/BUILD | 2
-rw-r--r-- tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc | 24
-rw-r--r-- tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h | 2
-rw-r--r-- tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc | 18
-rw-r--r-- tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc | 24
-rw-r--r-- tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h | 3
-rw-r--r-- tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_select.cc | 78
-rw-r--r-- tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_tile.cc | 173
-rw-r--r-- tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc | 4
-rw-r--r-- tensorflow/contrib/lite/toco/toco_port.cc | 14
-rw-r--r-- tensorflow/contrib/lite/toco/toco_tooling.cc | 2
-rw-r--r-- tensorflow/contrib/lite/tools/make/Makefile (renamed from tensorflow/contrib/lite/Makefile) | 132
-rwxr-xr-x tensorflow/contrib/lite/tools/make/build_ios_universal_lib.sh (renamed from tensorflow/contrib/lite/build_ios_universal_lib.sh) | 18
-rwxr-xr-x tensorflow/contrib/lite/tools/make/build_rpi_lib.sh (renamed from tensorflow/contrib/lite/build_rpi_lib.sh) | 4
-rwxr-xr-x tensorflow/contrib/lite/tools/make/download_dependencies.sh (renamed from tensorflow/contrib/lite/download_dependencies.sh) | 4
-rw-r--r-- tensorflow/contrib/lite/tools/make/targets/ios_makefile.inc (renamed from tensorflow/contrib/lite/ios_makefile.inc) | 26
-rw-r--r-- tensorflow/contrib/lite/tools/make/targets/linux_makefile.inc | 10
-rw-r--r-- tensorflow/contrib/lite/tools/make/targets/riscv_makefile.inc | 10
-rw-r--r-- tensorflow/contrib/lite/tools/make/targets/rpi_makefile.inc | 60
-rw-r--r-- tensorflow/contrib/lite/tools/make/targets/stm32f1_makefile.inc | 21
-rw-r--r-- tensorflow/contrib/lite/tools/make/targets/stm32f7_makefile.inc | 41
-rw-r--r-- tensorflow/contrib/lookup/lookup_ops.py | 33
-rw-r--r-- tensorflow/contrib/lookup/lookup_ops_test.py | 11
-rwxr-xr-x tensorflow/contrib/makefile/download_dependencies.sh | 4
-rw-r--r-- tensorflow/contrib/metrics/python/metrics/classification.py | 4
-rw-r--r-- tensorflow/contrib/optimizer_v2/optimizer_v2.py | 11
-rw-r--r-- tensorflow/contrib/optimizer_v2/rmsprop.py | 32
-rw-r--r-- tensorflow/contrib/optimizer_v2/rmsprop_test.py | 128
-rw-r--r-- tensorflow/contrib/quantize/python/quantize.py | 24
-rw-r--r-- tensorflow/contrib/quantize/python/quantize_test.py | 54
-rw-r--r-- tensorflow/contrib/rnn/BUILD | 8
-rw-r--r-- tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py | 5
-rw-r--r-- tensorflow/contrib/saved_model/BUILD | 1
-rw-r--r-- tensorflow/contrib/stat_summarizer/BUILD | 5
-rw-r--r-- tensorflow/contrib/tensor_forest/client/random_forest.py | 328
-rw-r--r-- tensorflow/contrib/tensor_forest/client/random_forest_test.py | 305
-rw-r--r-- tensorflow/contrib/tpu/BUILD | 1
-rw-r--r-- tensorflow/contrib/tpu/profiler/pip_package/setup.py | 2
-rw-r--r-- tensorflow/contrib/tpu/profiler/version.h | 2
-rw-r--r-- tensorflow/contrib/tpu/python/tpu/tpu_estimator.py | 160
-rw-r--r-- tensorflow/contrib/training/python/training/training.py | 6
-rw-r--r-- tensorflow/core/BUILD | 56
-rw-r--r-- tensorflow/core/api_def/api_test.cc | 4
-rw-r--r-- tensorflow/core/api_def/base_api/api_def_GatherNd.pbtxt | 2
-rw-r--r-- tensorflow/core/api_def/base_api/api_def_GatherV2.pbtxt | 2
-rw-r--r-- tensorflow/core/api_def/base_api/api_def_HostConst.pbtxt | 11
-rw-r--r-- tensorflow/core/api_def/python_api/api_def_ScatterSub.pbtxt | 4
-rw-r--r-- tensorflow/core/common_runtime/collective_rma_local.h | 2
-rw-r--r-- tensorflow/core/common_runtime/direct_session.cc | 8
-rw-r--r-- tensorflow/core/common_runtime/eager/context.h | 2
-rw-r--r-- tensorflow/core/common_runtime/executor.cc | 2
-rw-r--r-- tensorflow/core/common_runtime/executor.h | 2
-rw-r--r-- tensorflow/core/common_runtime/step_stats_collector.h | 36
-rw-r--r-- tensorflow/core/framework/dataset.cc | 54
-rw-r--r-- tensorflow/core/framework/dataset.h | 120
-rw-r--r-- tensorflow/core/framework/function.h | 4
-rw-r--r-- tensorflow/core/framework/op_kernel.h | 6
-rw-r--r-- tensorflow/core/framework/shape_inference.cc | 3
-rw-r--r-- tensorflow/core/graph/gradients.cc | 41
-rw-r--r-- tensorflow/core/graph/testlib.cc | 25
-rw-r--r-- tensorflow/core/grappler/costs/op_level_cost_estimator.cc | 18
-rw-r--r-- tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc | 72
-rw-r--r-- tensorflow/core/kernels/BUILD | 73
-rw-r--r-- tensorflow/core/kernels/constant_op.cc | 38
-rw-r--r-- tensorflow/core/kernels/constant_op.h | 20
-rw-r--r-- tensorflow/core/kernels/cwise_op_select.cc | 56
-rw-r--r-- tensorflow/core/kernels/data/batch_dataset_op.cc | 4
-rw-r--r-- tensorflow/core/kernels/data/cache_dataset_ops.cc | 10
-rw-r--r-- tensorflow/core/kernels/data/concatenate_dataset_op.cc | 4
-rw-r--r-- tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc | 4
-rw-r--r-- tensorflow/core/kernels/data/filter_by_component_dataset_op.cc | 4
-rw-r--r-- tensorflow/core/kernels/data/filter_dataset_op.cc | 4
-rw-r--r-- tensorflow/core/kernels/data/flat_map_dataset_op.cc | 4
-rw-r--r-- tensorflow/core/kernels/data/generator_dataset_op.cc | 13
-rw-r--r-- tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc | 4
-rw-r--r-- tensorflow/core/kernels/data/group_by_window_dataset_op.cc | 4
-rw-r--r-- tensorflow/core/kernels/data/interleave_dataset_op.cc | 4
-rw-r--r-- tensorflow/core/kernels/data/iterator_ops.cc | 28
-rw-r--r-- tensorflow/core/kernels/data/map_and_batch_dataset_op.cc | 4
-rw-r--r-- tensorflow/core/kernels/data/map_dataset_op.cc | 4
-rw-r--r-- tensorflow/core/kernels/data/optimize_dataset_op.cc | 4
-rw-r--r-- tensorflow/core/kernels/data/padded_batch_dataset_op.cc | 4
-rw-r--r-- tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc | 4
-rw-r--r-- tensorflow/core/kernels/data/parallel_map_dataset_op.cc | 4
-rw-r--r-- tensorflow/core/kernels/data/prefetch_dataset_op.cc | 6
-rw-r--r-- tensorflow/core/kernels/data/random_dataset_op.cc | 4
-rw-r--r-- tensorflow/core/kernels/data/range_dataset_op.cc | 7
-rw-r--r-- tensorflow/core/kernels/data/reader_dataset_ops.cc | 12
-rw-r--r-- tensorflow/core/kernels/data/repeat_dataset_op.cc | 4
-rw-r--r-- tensorflow/core/kernels/data/scan_dataset_op.cc | 4
-rw-r--r-- tensorflow/core/kernels/data/shuffle_dataset_op.cc | 4
-rw-r--r-- tensorflow/core/kernels/data/skip_dataset_op.cc | 4
-rw-r--r-- tensorflow/core/kernels/data/slide_dataset_op.cc | 4
-rw-r--r-- tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc | 4
-rw-r--r-- tensorflow/core/kernels/data/sql_dataset_ops.cc | 4
-rw-r--r-- tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc | 12
-rw-r--r-- tensorflow/core/kernels/data/stats_dataset_ops.cc | 18
-rw-r--r-- tensorflow/core/kernels/data/take_dataset_op.cc | 4
-rw-r--r-- tensorflow/core/kernels/data/tensor_dataset_op.cc | 4
-rw-r--r-- tensorflow/core/kernels/data/tensor_queue_dataset_op.cc | 4
-rw-r--r-- tensorflow/core/kernels/data/tensor_slice_dataset_op.cc | 4
-rw-r--r-- tensorflow/core/kernels/data/unbatch_dataset_op.cc | 4
-rw-r--r-- tensorflow/core/kernels/data/window_dataset.cc | 14
-rw-r--r-- tensorflow/core/kernels/data/window_dataset_op.cc | 6
-rw-r--r-- tensorflow/core/kernels/data/writer_ops.cc | 11
-rw-r--r-- tensorflow/core/kernels/data/zip_dataset_op.cc | 4
-rw-r--r-- tensorflow/core/kernels/host_constant_op.cc | 78
-rw-r--r-- tensorflow/core/kernels/host_constant_op.h | 42
-rw-r--r-- tensorflow/core/kernels/lookup_table_op.cc | 73
-rw-r--r-- tensorflow/core/kernels/non_max_suppression_op.cc | 2
-rw-r--r-- tensorflow/core/kernels/resource_variable_ops.cc | 12
-rw-r--r-- tensorflow/core/kernels/shape_ops.h | 8
-rw-r--r-- tensorflow/core/ops/array_ops.cc | 24
-rw-r--r-- tensorflow/core/ops/array_ops_test.cc | 15
-rw-r--r-- tensorflow/core/ops/compat/ops_history.v1.pbtxt | 15
-rw-r--r-- tensorflow/core/ops/image_ops.cc | 77
-rw-r--r-- tensorflow/core/ops/lookup_ops.cc | 139
-rw-r--r-- tensorflow/core/ops/ops.pbtxt | 15
-rw-r--r-- tensorflow/core/platform/default/build_config.bzl | 9
-rw-r--r-- tensorflow/docs_src/community/index.md | 8
-rw-r--r-- tensorflow/docs_src/guide/eager.md | 6
-rw-r--r-- tensorflow/docs_src/install/install_sources_windows.md | 320
-rw-r--r-- tensorflow/docs_src/install/install_windows.md | 2
-rw-r--r-- tensorflow/docs_src/install/leftnav_files | 1
-rw-r--r-- tensorflow/docs_src/tutorials/sequences/recurrent.md | 6
-rw-r--r-- tensorflow/java/BUILD | 6
-rw-r--r-- tensorflow/java/maven/hadoop/pom.xml | 2
-rw-r--r-- tensorflow/java/maven/libtensorflow/pom.xml | 2
-rw-r--r-- tensorflow/java/maven/libtensorflow_jni/pom.xml | 2
-rw-r--r-- tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml | 2
-rw-r--r-- tensorflow/java/maven/pom.xml | 2
-rw-r--r-- tensorflow/java/maven/proto/pom.xml | 2
-rw-r--r-- tensorflow/java/maven/run_inside_container.sh | 6
-rw-r--r-- tensorflow/java/maven/spark-connector/pom.xml | 2
-rw-r--r-- tensorflow/java/maven/tensorflow/pom.xml | 2
-rw-r--r-- tensorflow/java/src/main/java/org/tensorflow/types/UInt8.java | 29
-rw-r--r-- tensorflow/python/BUILD | 9
-rw-r--r-- tensorflow/python/compat/compat.py | 2
-rw-r--r-- tensorflow/python/debug/BUILD | 1
-rw-r--r-- tensorflow/python/eager/benchmarks_test.py | 58
-rw-r--r-- tensorflow/python/eager/function.py | 153
-rw-r--r-- tensorflow/python/eager/function_test.py | 28
-rw-r--r-- tensorflow/python/estimator/canned/boosted_trees.py | 118
-rw-r--r-- tensorflow/python/estimator/estimator.py | 491
-rw-r--r-- tensorflow/python/estimator/estimator_test.py | 35
-rw-r--r-- tensorflow/python/estimator/keras.py | 6
-rw-r--r-- tensorflow/python/framework/ops.py | 12
-rw-r--r-- tensorflow/python/framework/test_util.py | 2
-rwxr-xr-x tensorflow/python/keras/BUILD | 111
-rw-r--r-- tensorflow/python/keras/applications/__init__.py | 2
-rw-r--r-- tensorflow/python/keras/applications/applications_test.py | 58
-rw-r--r-- tensorflow/python/keras/applications/densenet_test.py | 101
-rw-r--r-- tensorflow/python/keras/applications/imagenet_utils_test.py | 93
-rw-r--r-- tensorflow/python/keras/applications/inception_resnet_v2_test.py | 59
-rw-r--r-- tensorflow/python/keras/applications/inception_v3_test.py | 58
-rw-r--r-- tensorflow/python/keras/applications/mobilenet_test.py | 71
-rw-r--r-- tensorflow/python/keras/applications/mobilenet_v2.py | 12
-rw-r--r-- tensorflow/python/keras/applications/nasnet_test.py | 76
-rw-r--r-- tensorflow/python/keras/applications/resnet50_test.py | 51
-rw-r--r-- tensorflow/python/keras/applications/vgg16_test.py | 50
-rw-r--r-- tensorflow/python/keras/applications/vgg19_test.py | 50
-rw-r--r-- tensorflow/python/keras/applications/xception_test.py | 57
-rw-r--r-- tensorflow/python/keras/callbacks.py | 2
-rw-r--r-- tensorflow/python/keras/integration_test.py | 26
-rw-r--r-- tensorflow/python/keras/layers/local.py | 340
-rw-r--r-- tensorflow/python/keras/layers/local_test.py | 461
-rw-r--r-- tensorflow/python/keras/layers/normalization.py | 10
-rw-r--r-- tensorflow/python/keras/layers/recurrent.py | 16
-rw-r--r-- tensorflow/python/keras/metrics.py | 4
-rw-r--r-- tensorflow/python/keras/optimizers.py | 4
-rw-r--r-- tensorflow/python/keras/utils/conv_utils.py | 166
-rw-r--r-- tensorflow/python/keras/utils/conv_utils_test.py | 232
-rw-r--r-- tensorflow/python/kernel_tests/BUILD | 12
-rw-r--r-- tensorflow/python/kernel_tests/array_ops_test.py | 9
-rw-r--r-- tensorflow/python/kernel_tests/batch_gather_op_test.py | 116
-rw-r--r-- tensorflow/python/kernel_tests/clip_ops_test.py | 16
-rw-r--r-- tensorflow/python/kernel_tests/confusion_matrix_test.py | 4
-rw-r--r-- tensorflow/python/kernel_tests/partitioned_variables_test.py | 80
-rw-r--r-- tensorflow/python/kernel_tests/resource_variable_ops_test.py | 6
-rw-r--r-- tensorflow/python/kernel_tests/rnn_test.py | 20
-rw-r--r-- tensorflow/python/ops/array_ops.py | 70
-rw-r--r-- tensorflow/python/ops/clip_ops.py | 6
-rw-r--r-- tensorflow/python/ops/image_ops_test.py | 35
-rw-r--r-- tensorflow/python/ops/losses/losses_impl.py | 22
-rw-r--r-- tensorflow/python/ops/math_ops.py | 21
-rw-r--r-- tensorflow/python/ops/metrics_impl.py | 41
-rw-r--r-- tensorflow/python/ops/rnn_cell_impl.py | 9
-rw-r--r-- tensorflow/python/ops/state_ops.py | 54
-rw-r--r-- tensorflow/python/ops/summary_op_util.py | 4
-rw-r--r-- tensorflow/python/ops/variables.py | 7
-rw-r--r-- tensorflow/python/summary/summary.py | 6
-rw-r--r-- tensorflow/python/summary/writer/writer.py | 2
-rw-r--r-- tensorflow/python/tools/api/generator/api_init_files.bzl | 1
-rw-r--r-- tensorflow/python/tools/api/generator/api_init_files_v1.bzl | 1
-rw-r--r-- tensorflow/python/tools/freeze_graph.py | 20
-rw-r--r-- tensorflow/python/training/checkpoint_management.py | 293
-rw-r--r-- tensorflow/python/training/checkpoint_management_test.py | 201
-rw-r--r-- tensorflow/python/training/checkpoint_state.proto | 8
-rw-r--r-- tensorflow/python/training/checkpoint_utils.py | 6
-rw-r--r-- tensorflow/python/training/checkpointable/BUILD | 5
-rw-r--r-- tensorflow/python/training/checkpointable/base.py | 4
-rw-r--r-- tensorflow/python/training/checkpointable/util.py | 10
-rw-r--r-- tensorflow/python/training/checkpointable/util_test.py | 9
-rw-r--r-- tensorflow/python/training/distribute.py | 232
-rw-r--r-- tensorflow/python/training/distribute_test.py | 53
-rw-r--r-- tensorflow/python/training/distribution_strategy_context.py | 203
-rw-r--r-- tensorflow/python/training/optimizer.py | 16
-rw-r--r-- tensorflow/python/training/slot_creator.py | 8
-rw-r--r-- tensorflow/stream_executor/BUILD | 2
-rw-r--r-- tensorflow/stream_executor/host/host_gpu_executor.h | 2
-rw-r--r-- tensorflow/stream_executor/host/host_stream.cc | 26
-rw-r--r-- tensorflow/tensorflow.bzl | 2743
-rw-r--r-- tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-locally-connected1-d.pbtxt | 2
-rw-r--r-- tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-locally-connected2-d.pbtxt | 2
-rw-r--r-- tensorflow/tools/api/golden/v1/tensorflow.pbtxt | 4
-rw-r--r-- tensorflow/tools/api/golden/v1/tensorflow.train.pbtxt | 4
-rwxr-xr-x tensorflow/tools/ci_build/linux/mkl/build-dev-container.sh | 21
-rw-r--r-- tensorflow/tools/docker/Dockerfile | 2
-rw-r--r-- tensorflow/tools/docker/Dockerfile.devel | 2
-rw-r--r-- tensorflow/tools/docker/Dockerfile.devel-gpu | 2
-rw-r--r-- tensorflow/tools/docker/Dockerfile.devel-gpu-cuda9-cudnn7 | 2
-rwxr-xr-x tensorflow/tools/docker/Dockerfile.devel-mkl | 27
-rwxr-xr-x tensorflow/tools/docker/Dockerfile.devel-mkl-horovod | 168
-rw-r--r-- tensorflow/tools/docker/Dockerfile.gpu | 2
-rwxr-xr-x tensorflow/tools/docker/Dockerfile.mkl | 2
-rwxr-xr-x tensorflow/tools/docker/Dockerfile.mkl-horovod | 111
-rwxr-xr-x tensorflow/tools/docker/parameterized_docker_build.sh | 40
-rw-r--r-- tensorflow/tools/docs/BUILD | 19
-rw-r--r-- tensorflow/tools/docs/doc_controls.py | 319
-rw-r--r-- tensorflow/tools/docs/doc_controls_test.py | 183
-rw-r--r-- tensorflow/tools/docs/generate_lib.py | 19
-rw-r--r-- tensorflow/tools/docs/parser.py | 16
-rw-r--r-- tensorflow/tools/docs/parser_test.py | 115
-rw-r--r-- tensorflow/tools/pip_package/BUILD | 5
-rwxr-xr-x tensorflow/tools/pip_package/build_pip_package.sh | 2
-rw-r--r-- tensorflow/tools/proto_text/BUILD | 1
-rw-r--r-- tensorflow/workspace.bzl | 16
-rw-r--r-- third_party/curl.BUILD | 14
-rw-r--r-- third_party/double_conversion.BUILD | 16
-rw-r--r-- third_party/farmhash.BUILD | 8
-rw-r--r-- third_party/fft2d/fft2d.BUILD | 10
-rw-r--r-- third_party/flatbuffers/flatbuffers.BUILD | 15
-rw-r--r-- third_party/gif.BUILD | 9
-rw-r--r-- third_party/gpus/cuda_configure.bzl | 4
-rw-r--r-- third_party/jpeg/jpeg.BUILD | 8
-rw-r--r-- third_party/lmdb.BUILD | 6
-rw-r--r-- third_party/mkl/BUILD | 17
-rw-r--r-- third_party/mkl/build_defs.bzl | 83
-rw-r--r-- third_party/mkl_dnn/BUILD | 5
-rw-r--r-- third_party/nasm.BUILD | 9
-rw-r--r-- third_party/png.BUILD | 18
-rw-r--r-- third_party/snappy.BUILD | 12
-rw-r--r-- third_party/sqlite.BUILD | 8
-rw-r--r-- third_party/swig.BUILD | 6
-rw-r--r-- third_party/zlib.BUILD | 1
437 files changed, 14228 insertions(+), 5278 deletions(-)
diff --git a/RELEASE.md b/RELEASE.md
index ae41d56e14..763ef3b279 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -3,7 +3,7 @@
## Major Features And Improvements
* The `tf.lite` runtime now supports `complex64`.
-* Initial Bigtable integration for `tf.data`.
+* Initial [Google Cloud Bigtable integration](https://github.com/tensorflow/tensorflow/tree/r1.10/tensorflow/contrib/bigtable) for `tf.data`.
* Improved local run behavior in `tf.estimator.train_and_evaluate`, which no longer reloads checkpoints for evaluation.
* `RunConfig` now sets `device_filters` to restrict how workers and PS can communicate. This can speed up training and ensure clean shutdowns in some situations. However, if you have jobs that require communication between workers, you will have to set custom `session_options` in your `RunConfig`.
* Moved Distributions and Bijectors from `tf.contrib.distributions` to [TensorFlow Probability (TFP)](https://github.com/tensorflow/probability). `tf.contrib.distributions` is now deprecated and will be removed by the end of 2018.
diff --git a/configure.py b/configure.py
index 6d0c077406..7acc6932eb 100644
--- a/configure.py
+++ b/configure.py
@@ -839,15 +839,16 @@ def set_tf_cuda_version(environ_cp):
cuda_toolkit_path = cygpath(cuda_toolkit_path)
if is_windows():
- cuda_rt_lib_path = 'lib/x64/cudart.lib'
+ cuda_rt_lib_paths = ['lib/x64/cudart.lib']
elif is_linux():
- cuda_rt_lib_path = 'lib64/libcudart.so.%s' % tf_cuda_version
+ cuda_rt_lib_paths = ['%s/libcudart.so.%s' % (x, tf_cuda_version)
+ for x in ['lib64', 'lib/x86_64-linux-gnu']]
elif is_macos():
- cuda_rt_lib_path = 'lib/libcudart.%s.dylib' % tf_cuda_version
+ cuda_rt_lib_paths = ['lib/libcudart.%s.dylib' % tf_cuda_version]
- cuda_toolkit_path_full = os.path.join(cuda_toolkit_path, cuda_rt_lib_path)
- if os.path.exists(cuda_toolkit_path_full):
- break
+ cuda_toolkit_paths_full = [os.path.join(cuda_toolkit_path, x) for x in cuda_rt_lib_paths]
+ if any([os.path.exists(x) for x in cuda_toolkit_paths_full]):
+ break
# Reset and retry
print('Invalid path to CUDA %s toolkit. %s cannot be found' %
@@ -1398,10 +1399,6 @@ def set_grpc_build_flags():
write_to_bazelrc('build --define grpc_no_ares=true')
-def set_build_strip_flag():
- write_to_bazelrc('build --strip=always')
-
-
def set_windows_build_flags(environ_cp):
"""Set Windows specific build options."""
# The non-monolithic build is not supported yet
@@ -1560,7 +1557,6 @@ def main():
set_grpc_build_flags()
set_cc_opt_flags(environ_cp)
- set_build_strip_flag()
if is_windows():
set_windows_build_flags(environ_cp)
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index f3d8d558ac..e5654a5141 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -125,12 +125,6 @@ config_setting(
)
config_setting(
- name = "windows_msvc",
- values = {"cpu": "x64_windows_msvc"},
- visibility = ["//visibility:public"],
-)
-
-config_setting(
name = "no_tensorflow_py_deps",
define_values = {"no_tensorflow_py_deps": "true"},
visibility = ["//visibility:public"],
@@ -439,14 +433,14 @@ package_group(
load(
"//third_party/mkl:build_defs.bzl",
- "if_mkl",
+ "if_mkl_ml",
)
filegroup(
name = "intel_binary_blob",
- data = if_mkl(
+ data = if_mkl_ml(
[
- "//third_party/mkl:intel_binary_blob",
+ "//third_party/intel_mkl_ml",
],
),
)
@@ -497,7 +491,6 @@ tf_cc_shared_object(
linkopts = select({
"//tensorflow:darwin": [],
"//tensorflow:windows": [],
- "//tensorflow:windows_msvc": [],
"//conditions:default": [
"-Wl,--version-script", # This line must be directly followed by the version_script.lds file
"$(location //tensorflow:tf_framework_version_script.lds)",
@@ -539,7 +532,6 @@ tf_cc_shared_object(
"-Wl,-install_name,@rpath/libtensorflow.so",
],
"//tensorflow:windows": [],
- "//tensorflow:windows_msvc": [],
"//conditions:default": [
"-z defs",
"-Wl,--version-script", # This line must be directly followed by the version_script.lds file
@@ -564,7 +556,6 @@ tf_cc_shared_object(
"$(location //tensorflow:tf_exported_symbols.lds)",
],
"//tensorflow:windows": [],
- "//tensorflow:windows_msvc": [],
"//conditions:default": [
"-z defs",
"-Wl,--version-script", # This line must be directly followed by the version_script.lds file
diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD
index f4be60a183..f56521dac0 100644
--- a/tensorflow/cc/BUILD
+++ b/tensorflow/cc/BUILD
@@ -628,7 +628,6 @@ tf_cc_binary(
copts = tf_copts(),
linkopts = select({
"//tensorflow:windows": [],
- "//tensorflow:windows_msvc": [],
"//tensorflow:darwin": [
"-lm",
"-lpthread",
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index 55b98da472..e059f77563 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -314,12 +314,16 @@ cc_library(
"deadness_analysis_internal.h",
"encapsulate_subgraphs_pass.cc",
"mark_for_compilation_pass.cc",
+ "mark_for_compilation_pass_test_helper.cc",
+ "partially_decluster_pass.cc",
],
hdrs = [
"build_xla_launch_ops_pass.h",
"deadness_analysis.h",
"encapsulate_subgraphs_pass.h",
"mark_for_compilation_pass.h",
+ "mark_for_compilation_pass_test_helper.h",
+ "partially_decluster_pass.h",
],
deps = [
":common",
@@ -354,6 +358,7 @@ cc_library(
"//tensorflow/compiler/jit/graphcycles",
"//tensorflow/core:framework",
"//tensorflow/core:graph",
+ "//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/kernels:bounds_check",
],
@@ -418,10 +423,12 @@ tf_cc_test(
srcs = [
"encapsulate_subgraphs_pass_test.cc",
"mark_for_compilation_pass_test.cc",
+ "partially_decluster_pass_test.cc",
],
deps = [
":common",
":compilation_passes",
+ ":xla_cluster_util",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:cc_ops_internal",
"//tensorflow/cc:function_ops",
diff --git a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc
index 4d49a14b24..c37b6112cc 100644
--- a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc
+++ b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/build_xla_launch_ops_pass.h"
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
+#include "tensorflow/compiler/jit/partially_decluster_pass.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
namespace tensorflow {
@@ -23,15 +24,18 @@ namespace tensorflow {
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 10,
MarkForCompilationPass);
+REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 20,
+ PartiallyDeclusterPass);
+
// The EncapsulateSubgraphs pass must run after the MarkForCompilationPass. We
// also need to run it after the graph has been rewritten to have _Send nodes added
// for fetches. Before the _Send nodes are added, fetch nodes are identified by
// name, and encapsulation might remove that node from the graph.
-REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 20,
+REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 30,
EncapsulateSubgraphsPass);
// Must run after EncapsulateSubgraphsPass.
-REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 30,
+REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 40,
BuildXlaLaunchOpsPass);
} // namespace tensorflow
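
Within an OptimizationPassRegistry grouping, passes execute in ascending phase order, which is why the existing registrations are renumbered above to slot PartiallyDeclusterPass (phase 20) between clustering (10) and encapsulation (30). The following is a minimal sketch of this registration pattern, not part of the change itself; MyDebugPass and phase 25 are hypothetical:

#include "tensorflow/core/common_runtime/optimization_registry.h"

namespace tensorflow {

// Hypothetical pass, for illustration only: phase 25 would place it after
// PartiallyDeclusterPass (20) and before EncapsulateSubgraphsPass (30).
class MyDebugPass : public GraphOptimizationPass {
 public:
  Status Run(const GraphOptimizationPassOptions& options) override {
    VLOG(1) << "Graph has " << (*options.graph)->num_nodes() << " nodes.";
    return Status::OK();
  }
};

REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 25,
                      MyDebugPass);

}  // namespace tensorflow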
diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc
index 37a2f3b5ac..7f4370b5b0 100644
--- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc
+++ b/tensorflow/compiler/jit/kernels/xla_launch_op.cc
@@ -210,7 +210,8 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
auto elapsed = env->NowMicros() - start_time;
VLOG(2) << "Elapsed time: " << elapsed << "us";
- launch_context.PopulateOutputs(ctx, kernel, run_result.ConsumeValueOrDie());
+ OP_REQUIRES_OK(ctx, launch_context.PopulateOutputs(
+ ctx, kernel, run_result.ConsumeValueOrDie()));
VLOG(1) << "Done";
}
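
PopulateOutputs now returns a Status, and the call site above checks it with OP_REQUIRES_OK rather than dropping it. A minimal sketch of that pattern, assuming a hypothetical kernel and helper (MyKernel, DoWork):

#include "tensorflow/core/framework/op_kernel.h"

namespace tensorflow {

// Hypothetical kernel, for illustration only.
class MyKernel : public OpKernel {
 public:
  explicit MyKernel(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    // On a non-OK Status, OP_REQUIRES_OK records the error on `ctx` and
    // returns from Compute immediately, so the failure propagates to the
    // caller instead of being silently ignored.
    OP_REQUIRES_OK(ctx, DoWork(ctx));
  }

 private:
  // Hypothetical Status-returning helper.
  static Status DoWork(OpKernelContext* ctx) { return Status::OK(); }
};

}  // namespace tensorflow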
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
index 45d422943c..90d5d56998 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
@@ -65,6 +65,7 @@ bool HasXLAKernel(const Node& node, const DeviceType& jit_device_type) {
// XLA cluster so it can't implement the forward-tensor-ref semantic. Leave
// such nodes out of XLA clusters.
if (HasForwardedRefInput(node)) {
+ VLOG(2) << "Rejecting " << node.name() << ": Identity with unsafe cast.";
return false;
}
@@ -84,14 +85,13 @@ bool IsCompilableCall(const NodeDef& call_def,
bool IsCompilableWhile(const Node& while_node,
const DeviceType& jit_device_type, int depth,
FunctionLibraryRuntime* lib_runtime) {
- VLOG(2) << "Loop marking: " << while_node.type_string();
-
const NameAttrList* name_attr;
NodeDef call;
Status status;
status = GetNodeAttr(while_node.attrs(), "cond", &name_attr);
if (!status.ok()) {
- VLOG(2) << "Missing 'cond' attribute on While node.";
+ VLOG(2) << "Rejecting While " << while_node.name()
+ << ": missing 'cond' attribute on While node.";
return false;
}
const string cond_func = name_attr->name();
@@ -99,12 +99,14 @@ bool IsCompilableWhile(const Node& while_node,
call.set_op(cond_func);
*call.mutable_attr() = name_attr->attr();
if (!IsCompilableCall(call, jit_device_type, depth + 1, lib_runtime)) {
- VLOG(2) << "Can't compile loop condition: " << cond_func;
+ VLOG(2) << "Rejecting While " << while_node.name()
+ << ": can't compile loop condition: " << cond_func;
return false;
}
status = GetNodeAttr(while_node.attrs(), "body", &name_attr);
if (!status.ok()) {
- VLOG(2) << "Missing 'body' attribute on While node.";
+ VLOG(2) << "Rejecting While " << while_node.name()
+ << ": missing 'body' attribute on While node.";
return false;
}
const string body_func = name_attr->name();
@@ -112,10 +114,10 @@ bool IsCompilableWhile(const Node& while_node,
call.set_op(body_func);
*call.mutable_attr() = name_attr->attr();
if (!IsCompilableCall(call, jit_device_type, depth + 1, lib_runtime)) {
- VLOG(2) << "Can't compile loop body: " << body_func;
+ VLOG(2) << "Rejecting While " << while_node.name()
+ << ": can't compile loop body: " << body_func;
return false;
}
- VLOG(2) << "Loop is compilable.";
return true;
}
@@ -125,10 +127,9 @@ bool IsCompilableWhile(const Node& while_node,
bool IsCompilableCall(const NodeDef& call_def,
const DeviceType& jit_device_type, int depth,
FunctionLibraryRuntime* lib_runtime) {
- VLOG(2) << "Function marking: " << call_def.op();
-
if (depth > kMaxRecursionDepth) {
- VLOG(2) << "Function depth limit exceeded";
+ VLOG(2) << "Rejecting " << call_def.op()
+ << ": function depth limit exceeded.";
return false;
}
@@ -136,7 +137,8 @@ bool IsCompilableCall(const NodeDef& call_def,
Status status =
lib_runtime->Instantiate(call_def.op(), AttrSlice(call_def), &handle);
if (!status.ok()) {
- VLOG(2) << "Could not instantiate " << call_def.op() << ": " << status;
+ VLOG(2) << "Rejecting " << call_def.op()
+ << ": could not instantiate: " << status;
return false;
}
const FunctionBody* fbody = lib_runtime->GetFunctionBody(handle);
@@ -150,7 +152,8 @@ bool IsCompilableCall(const NodeDef& call_def,
// tf2xla to translate the TF graph into XLA. So we avoid this for now.
//
// TODO(b/36139787): Create a mechanism to set inlining hints.
- VLOG(2) << "Can't compile noinline function: " << fdef.DebugString();
+ VLOG(2) << "Rejecting " << call_def.op()
+ << ": can't compile noinline function.";
return false;
}
@@ -164,23 +167,14 @@ bool IsCompilableCall(const NodeDef& call_def,
if (!HasXLAKernel(*node, jit_device_type) &&
!IsCompilableCall(node->def(), jit_device_type, depth + 1,
lib_runtime)) {
- VLOG(2) << "Function marking failed: unsupported op " << node->name()
- << ": " << node->def().ShortDebugString();
+ VLOG(2) << "Rejecting " << call_def.op() << ": unsupported op "
+ << node->name() << ": " << node->def().ShortDebugString();
return false;
}
}
- VLOG(2) << "Function is compilable: " << call_def.op();
return true;
}
-// Tests whether `node` has a DT_RESOURCE typed input or output.
-bool HasResourceInputOrOutput(const Node& node) {
- return std::find(node.input_types().begin(), node.input_types().end(),
- DT_RESOURCE) != node.input_types().end() ||
- std::find(node.output_types().begin(), node.output_types().end(),
- DT_RESOURCE) != node.output_types().end();
-}
-
// Returns true if the op can be decomposed into XLA ops for which
// there are fusable elemental implementations.
//
@@ -357,24 +351,27 @@ Status FindCompilationCandidates(
}
std::sort(sorted_nodes.begin(), sorted_nodes.end(), NodeComparatorID());
+ if (fuel >= std::numeric_limits<int64>::max() / 2) {
+ // The assumption is that if fuel started out as INT64_MAX, it will forever
+ // stay greater than INT64_MAX / 2.
+ VLOG(2) << "Starting fuel: infinity";
+ } else {
+ VLOG(2) << "Starting fuel: " << fuel;
+ }
+
for (Node* node : sorted_nodes) {
- VLOG(2) << "Fuel: " << fuel;
if (fuel <= 0) {
- VLOG(2)
+ VLOG(1)
<< "Hit fuel limit; not marking any remaining ops as clusterable.";
break;
}
- VLOG(2) << "FindCompilationCandidates(): Processing "
- << node->DebugString();
-
DeviceType device_type("");
TF_RETURN_IF_ERROR(
DeviceToDeviceType(node->assigned_device_name(), &device_type));
if (is_compilable_fn && !is_compilable_fn(node, device_type)) {
- VLOG(2) << "Compilation rejected node: not compilable " << node->name()
- << ": " << node->type_string();
+ // is_compilable_fn has already logged the reason if it returned false.
continue;
}
@@ -384,14 +381,14 @@ Status FindCompilationCandidates(
DeviceType jit_device_type(registration->compilation_device_name);
if (!HasXLAKernel(*node, jit_device_type) &&
!IsCompilableCall(node->def(), jit_device_type, 0, lib_runtime)) {
- VLOG(2) << "Compilation rejected node: unsupported op " << node->name()
- << ": " << node->type_string();
+ VLOG(2) << "Rejecting " << node->name() << ": unsupported op "
+ << node->type_string();
continue;
}
if (!registration->compile_resource_ops &&
HasResourceInputOrOutput(*node)) {
- VLOG(2) << "Compilation rejected node: resource input/output "
- << node->name() << ": " << node->type_string();
+ VLOG(2) << "Rejecting: " << node->name() << ": resource input/output "
+ << node->type_string();
continue;
}
if (node->type_string() == "While" &&
@@ -401,15 +398,11 @@ Status FindCompilationCandidates(
// _Arg nodes in a top-level function represent feeds.
// Do not compile them.
if (node->type_string() == "_Arg") {
- VLOG(2) << "Skipping jit compilation for '_Arg'-typed node "
- << node->DebugString();
continue;
}
// _Retval nodes in a top-level function represent fetches.
// Do not compile them.
if (node->type_string() == "_Retval") {
- VLOG(2) << "Compilation rejected node: return value " << node->name()
- << ": " << node->type_string();
continue;
}
candidates->insert(node);
@@ -475,6 +468,7 @@ Status MarkForCompilationPass::Run(
const XlaOpRegistry::DeviceRegistration* registration;
if (!XlaOpRegistry::GetCompilationDevice(device_type.type(),
&registration)) {
+ VLOG(2) << "Rejecting " << node->name() << ": could not find JIT device.";
return false;
}
@@ -484,21 +478,36 @@ Status MarkForCompilationPass::Run(
// If there is a _XlaCompile annotation, use its value.
bool compile = false;
Status status = GetNodeAttr(node->attrs(), kXlaCompileAttr, &compile);
- if (status.ok()) return compile;
+ if (status.ok()) {
+ if (!compile) {
+ VLOG(2) << "Rejecting " << node->name() << ": kXlaCompileAttr("
+ << kXlaCompileAttr << ") is false.";
+ }
+ return compile;
+ }
status = fld->GetAttr(*node, kXlaCompileAttr, &compile);
- if (status.ok()) return compile;
+ if (status.ok()) {
+ if (!compile) {
+ VLOG(2) << "Rejecting " << node->name() << ": kXlaCompileAttr("
+ << kXlaCompileAttr << ") on callee is false.";
+ }
+ return compile;
+ }
// If inputs to `node` can have conflicting deadness (i.e. some are alive
// and some are dead) then don't compile it. XLA cannot represent the
// deadness semantics of these nodes correctly and auto-clustering these
// nodes can cause deadness to propagate to nodes that should be live.
if (node->IsMerge() || deadness->HasInputsWithMismatchingDeadness(*node)) {
+ VLOG(2) << "Rejecting " << node->name() << ": mismatching deadness.";
return false;
}
// Check for fusable ops only if requested.
if (global_jit_level > 0 && fusion_only && !IsXlaFusable(node->def())) {
+ VLOG(2) << "Rejecting " << node->name()
+ << ": not fusable op but fusion_only enabled.";
return false;
}
@@ -506,8 +515,17 @@ Status MarkForCompilationPass::Run(
// Ignore enable_jit_by_default if global jit compilation for CPU
// is explicitly requested via tf_xla_cpu_global_jit flag
bool ignore_registration = cpu_global_jit && device_type == DEVICE_CPU;
- return (ignore_registration || registration->enable_jit_by_default) &&
- global_jit_level > 0;
+ bool should_compile =
+ (ignore_registration || registration->enable_jit_by_default) &&
+ global_jit_level > 0;
+ if (!should_compile) {
+ if (global_jit_level <= 0) {
+ VLOG(2) << "Rejecting " << node->name() << ": global jit disabled.";
+ } else {
+ VLOG(2) << "Rejecting " << node->name() << ": JIT for device disabled.";
+ }
+ }
+ return should_compile;
};
return RunImpl(options, is_compilable);
}
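
Many of the new log lines fire from the is_compilable predicate above, which honors an explicit _XlaCompile attribute before falling back to the global JIT level. For reference, a small illustrative fragment showing how a node acquires that attribute (the same pattern appears in the ConstOp test below); the node name is arbitrary:

#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/compiler/jit/defs.h"  // declares kXlaCompileAttr ("_XlaCompile")

// Explicitly opting a node out of clustering. With the logging added above,
// this node would be reported at VLOG(2) as:
//   Rejecting const: kXlaCompileAttr(_XlaCompile) is false.
tensorflow::Scope root = tensorflow::Scope::NewRootScope();
auto c = tensorflow::ops::Const(root.WithOpName("const"), 0.5f);
c.node()->AddAttr(tensorflow::kXlaCompileAttr, false);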
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.h b/tensorflow/compiler/jit/mark_for_compilation_pass.h
index e9acbfb19e..f1137af3c1 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass.h
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass.h
@@ -40,20 +40,18 @@ class MarkForCompilationPass : public GraphOptimizationPass {
Status Run(const GraphOptimizationPassOptions& options) override;
- // Run() just calls RunImpl() if --tf_xla_auto_jit is enabled. To run the pass
- // unconditionally, call RunImpl() directly.
- // is_compilable_fn, if set, is a predicate that must be true for a node to
- // be compiled.
+ private:
Status RunImpl(const GraphOptimizationPassOptions& options,
const std::function<bool(const Node*, const DeviceType&)>&
is_compilable_fn = {});
+
+ friend class MarkForCompilationPassTestHelper;
};
// Returns true iff 'ndef' is a call to a function that is compilable. A
// function is compilable iff every operator in the function body is
// compilable.
bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef);
-
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_MARK_FOR_COMPILATION_PASS_H_
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
index 2c5f4fb774..a780d4a936 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
+#include "tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/array_ops.h"
@@ -39,27 +39,6 @@ namespace {
REGISTER_OP("UncompilableNullary").Output("o: float");
REGISTER_OP("UncompilableUnary").Input("a: float").Output("o: float");
-Status MarkForCompilation(std::unique_ptr<Graph>* graph,
- FunctionLibraryDefinition* flib_def) {
- // Assign all nodes to the CPU device.
- static const char* kCpuDevice = "/job:localhost/replica:0/task:0/cpu:0";
- for (Node* n : (*graph)->nodes()) {
- n->set_assigned_device_name(kCpuDevice);
- }
-
- GraphOptimizationPassOptions opt_options;
- opt_options.graph = graph;
- opt_options.flib_def = flib_def;
- MarkForCompilationPass pass;
- return pass.RunImpl(opt_options);
-}
-
-Status MarkForCompilation(std::unique_ptr<Graph>* graph) {
- FunctionDefLibrary flib;
- FunctionLibraryDefinition flib_def((*graph)->op_registry(), flib);
- return MarkForCompilation(graph, &flib_def);
-}
-
std::unordered_map<string, string> GetClusters(const Graph& graph) {
std::unordered_map<string, string> ids;
for (Node* node : graph.nodes()) {
@@ -88,7 +67,7 @@ TEST(XlaCompilationTest, Chains) {
TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
}
- TF_ASSERT_OK(MarkForCompilation(&graph));
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
auto clusters = GetClusters(*graph);
EXPECT_EQ(4, clusters.size());
EXPECT_EQ(clusters["B"], clusters["C"]);
@@ -113,7 +92,7 @@ TEST(XlaCompilationTest, UncompilableCycles) {
TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
}
- TF_ASSERT_OK(MarkForCompilation(&graph));
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
auto clusters = GetClusters(*graph);
EXPECT_TRUE(clusters.empty());
@@ -133,7 +112,7 @@ TEST(XlaCompilationTest, CompilableCycles) {
TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
}
- TF_ASSERT_OK(MarkForCompilation(&graph));
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
auto clusters = GetClusters(*graph);
EXPECT_EQ(3, clusters.size());
@@ -156,7 +135,7 @@ TEST(XlaCompilationTest, Complex128Unsupported) {
TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
}
- TF_ASSERT_OK(MarkForCompilation(&graph));
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
auto clusters = GetClusters(*graph);
EXPECT_TRUE(clusters.empty());
}
@@ -177,7 +156,7 @@ TEST(XlaCompilationTest, HalfSupported) {
TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
}
- TF_ASSERT_OK(MarkForCompilation(&graph));
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
auto clusters = GetClusters(*graph);
EXPECT_FALSE(clusters.empty());
}
@@ -206,7 +185,7 @@ TEST(XlaCompilationTest, ConcatWithConstArg) {
TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
}
- TF_ASSERT_OK(MarkForCompilation(&graph));
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
auto clusters = GetClusters(*graph);
EXPECT_EQ(3, clusters.size()); // Everything should be compiled.
}
@@ -241,7 +220,8 @@ TEST(XlaCompilationTest, FunctionCalls) {
TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
}
- TF_ASSERT_OK(MarkForCompilation(&graph, &flib_def));
+ TF_ASSERT_OK(
+ MarkForCompilationPassTestHelper::MarkForCompilation(&graph, &flib_def));
auto clusters = GetClusters(*graph);
EXPECT_EQ(2, clusters.size());
@@ -272,7 +252,7 @@ TEST(XlaCompilationTest, MetadataOpsDontStartClusters) {
ops::UnaryOp("Shape", d, builder.opts().WithName("E"));
TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
}
- TF_ASSERT_OK(MarkForCompilation(&graph));
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
auto clusters = GetClusters(*graph);
EXPECT_EQ(0, clusters.size()); // Nothing should be compiled.
}
@@ -359,7 +339,7 @@ TEST(XlaCompilationTest, SymbolicGradients) {
TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
}
- TF_ASSERT_OK(MarkForCompilation(&graph));
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
auto clusters = GetClusters(*graph);
EXPECT_EQ(2, clusters.size());
@@ -384,7 +364,7 @@ TEST(XlaCompilationTest, Loops) {
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
TF_EXPECT_OK(root.ToGraph(graph.get()));
- TF_ASSERT_OK(MarkForCompilation(&graph));
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
auto clusters = GetClusters(*graph);
// Nothing should be compiled. In particular, 'd' and 'c' must not be
@@ -411,7 +391,7 @@ TEST(XlaCompilationTest, CyclesWithAllDifferentScopes) {
TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get()));
}
- TF_ASSERT_OK(MarkForCompilation(&graph));
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
auto clusters = GetClusters(*graph);
// The computation is: C = A + relu(A)
@@ -442,7 +422,7 @@ TEST(XlaCompilationTest, CyclesWithSplittingScopes) {
TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get()));
}
- TF_ASSERT_OK(MarkForCompilation(&graph));
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
auto clusters = GetClusters(*graph);
// The computation is: D = relu(A) + (A @ relu(A))
@@ -472,7 +452,7 @@ TEST(XlaCompilationTest, CyclesWithDifferentScopesAndBridge) {
TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get()));
}
- TF_ASSERT_OK(MarkForCompilation(&graph));
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
auto clusters = GetClusters(*graph);
// The computation is: C = A @ relu(A)
@@ -512,7 +492,7 @@ TEST(XlaCompilationTest, Resources) {
ops::UnaryOp("Relu", d, builder.opts().WithName("E"));
TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
}
- TF_ASSERT_OK(MarkForCompilation(&graph));
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
auto clusters = GetClusters(*graph);
EXPECT_EQ(0, clusters.size()); // Nothing should be compiled.
}
@@ -542,7 +522,7 @@ TEST(XlaCompilationTest, IllegalCycle_UsefulErrorMessage) {
TF_EXPECT_OK(root.ToGraph(graph.get()));
- Status status = MarkForCompilation(&graph);
+ Status status = MarkForCompilationPassTestHelper::MarkForCompilation(&graph);
EXPECT_FALSE(status.ok());
EXPECT_TRUE(str_util::StrContains(status.ToString(),
"Edge from c to a would create a cycle.\n"
@@ -570,7 +550,7 @@ TEST(XlaCompilationTest, Retval) {
TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
}
- TF_ASSERT_OK(MarkForCompilation(&graph));
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
auto clusters = GetClusters(*graph);
EXPECT_EQ(2, clusters.size());
@@ -588,7 +568,7 @@ TEST(XlaCompilationTest, DontCountIdentityOps) {
auto r = ops::_Retval(root.WithOpName("R"), c, 0);
}
TF_ASSERT_OK(root.ToGraph(graph.get()));
- TF_ASSERT_OK(MarkForCompilation(&graph));
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
auto clusters = GetClusters(*graph);
EXPECT_TRUE(clusters.empty());
@@ -604,7 +584,7 @@ TEST(XlaCompilationTest, DontCountIdentityOpsWithLocalJit) {
auto r = ops::_Retval(root.WithOpName("R"), b, 0);
}
TF_ASSERT_OK(root.ToGraph(graph.get()));
- TF_ASSERT_OK(MarkForCompilation(&graph));
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
auto clusters = GetClusters(*graph);
EXPECT_TRUE(clusters.empty());
@@ -618,7 +598,7 @@ TEST(XlaCompilationTest, ConstOp) {
auto c = ops::Const(root.WithOpName("const"), 0.5f);
c.node()->AddAttr(kXlaCompileAttr, true);
TF_ASSERT_OK(root.ToGraph(graph.get()));
- TF_ASSERT_OK(MarkForCompilation(&graph));
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
EXPECT_EQ(1, GetClusters(*graph).size());
}
@@ -629,7 +609,7 @@ TEST(XlaCompilationTest, ConstOp) {
auto c = ops::Const(root.WithOpName("const"), string("string"));
c.node()->AddAttr(kXlaCompileAttr, true);
TF_ASSERT_OK(root.ToGraph(graph.get()));
- TF_ASSERT_OK(MarkForCompilation(&graph));
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
EXPECT_TRUE(GetClusters(*graph).empty());
}
}
@@ -644,7 +624,7 @@ TEST(XlaCompilationTest, DontClusterIdentityWithRefInput) {
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
TF_ASSERT_OK(root.ToGraph(graph.get()));
- TF_ASSERT_OK(MarkForCompilation(&graph));
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
std::unordered_map<string, string> clusters = GetClusters(*graph);
@@ -667,7 +647,7 @@ TEST(XlaCompilationTest, ClusterIdentityWithNonRefInput) {
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
TF_ASSERT_OK(root.ToGraph(graph.get()));
- TF_ASSERT_OK(MarkForCompilation(&graph));
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
std::unordered_map<string, string> clusters = GetClusters(*graph);
@@ -699,7 +679,7 @@ TEST(XlaCompilationTest, ClusterControlTrigger) {
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
TF_ASSERT_OK(root.ToGraph(graph.get()));
- TF_ASSERT_OK(MarkForCompilation(&graph));
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
std::unordered_map<string, string> clusters = GetClusters(*graph);
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc
new file mode 100644
index 0000000000..a84b82e479
--- /dev/null
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc
@@ -0,0 +1,40 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h"
+
+namespace tensorflow {
+/*static*/ Status MarkForCompilationPassTestHelper::MarkForCompilation(
+ std::unique_ptr<Graph>* graph, FunctionLibraryDefinition* flib_def) {
+ // Assign all nodes to the CPU device.
+ static const char* kCpuDevice = "/job:localhost/replica:0/task:0/cpu:0";
+ for (Node* n : (*graph)->nodes()) {
+ n->set_assigned_device_name(kCpuDevice);
+ }
+
+ GraphOptimizationPassOptions opt_options;
+ opt_options.graph = graph;
+ opt_options.flib_def = flib_def;
+ MarkForCompilationPass pass;
+ return pass.RunImpl(opt_options);
+}
+
+/*static*/ Status MarkForCompilationPassTestHelper::MarkForCompilation(
+ std::unique_ptr<Graph>* graph) {
+ FunctionDefLibrary flib;
+ FunctionLibraryDefinition flib_def((*graph)->op_registry(), flib);
+ return MarkForCompilation(graph, &flib_def);
+}
+} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h
new file mode 100644
index 0000000000..b9a0531cb0
--- /dev/null
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h
@@ -0,0 +1,35 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_JIT_MARK_FOR_COMPILATION_PASS_TEST_HELPER_H_
+#define TENSORFLOW_COMPILER_JIT_MARK_FOR_COMPILATION_PASS_TEST_HELPER_H_
+
+#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
+
+namespace tensorflow {
+class MarkForCompilationPassTestHelper {
+ public:
+ // Runs the MarkForCompilation pass on `graph` after assigning all nodes in
+ // `graph` to the CPU device. To make testing easier, ignores device
+ // registration, _XlaCompile attributes, input deadness and global jit level.
+ static Status MarkForCompilation(std::unique_ptr<Graph>* graph,
+ FunctionLibraryDefinition* flib_def);
+
+ // Like `MarkForCompilation` but creates `flib_def` from the op registry.
+ static Status MarkForCompilation(std::unique_ptr<Graph>* graph);
+};
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_JIT_MARK_FOR_COMPILATION_PASS_TEST_HELPER_H_
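
For orientation, a test written against this helper looks roughly like the sketch below (hypothetical test; `GetClusters` is the file-local helper already used by the tests updated above):

// Illustrative only -- not part of this change.
TEST(XlaCompilationTest, TrivialAddCluster) {
  Scope root = Scope::NewRootScope().ExitOnError();
  auto a = ops::Placeholder(root.WithOpName("a"), DT_FLOAT);
  auto b = ops::Placeholder(root.WithOpName("b"), DT_FLOAT);
  ops::Add(root.WithOpName("add"), a, b);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  // The helper pins every node to the local CPU device, then runs
  // MarkForCompilationPass::RunImpl with a default FunctionLibraryDefinition.
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  EXPECT_FALSE(GetClusters(*graph).empty());
}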
diff --git a/tensorflow/compiler/jit/partially_decluster_pass.cc b/tensorflow/compiler/jit/partially_decluster_pass.cc
new file mode 100644
index 0000000000..68ead39424
--- /dev/null
+++ b/tensorflow/compiler/jit/partially_decluster_pass.cc
@@ -0,0 +1,177 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/partially_decluster_pass.h"
+#include "tensorflow/compiler/jit/xla_cluster_util.h"
+#include "tensorflow/core/framework/memory_types.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/lib/gtl/flatset.h"
+
+namespace tensorflow {
+namespace {
+Status FindNodesToDecluster(const Graph& graph, gtl::FlatSet<Node*>* result,
+ gtl::ArraySlice<Node*> post_order) {
+ // Find nodes that have at least one user outside their cluster that expects
+ // hostmem output. These nodes should be cloned to outside the cluster to
+ // avoid the device-host copy we'd otherwise need.
+
+ MemoryTypeVector input_mtypes, output_mtypes;
+
+ for (Node* n : post_order) {
+ gtl::optional<StringPiece> from_cluster = GetXlaClusterForNode(*n);
+ if (!from_cluster) {
+ continue;
+ }
+
+ // We assume the only XLA-auto-clusterable operations with side effects are
+ // resource variable updates. We can't execute these twice.
+ if (HasResourceInputOrOutput(*n)) {
+ continue;
+ }
+
+ DeviceType device_type("");
+ TF_RETURN_IF_ERROR(
+ DeviceToDeviceType(n->assigned_device_name(), &device_type));
+ TF_RETURN_IF_ERROR(MemoryTypesForNode(graph.op_registry(), device_type,
+ n->def(), &input_mtypes,
+ &output_mtypes));
+ for (const Edge* e : n->out_edges()) {
+ Node* dst = e->dst();
+
+ if (e->IsControlEdge()) {
+ continue;
+ }
+
+ bool edge_incurs_extra_device_to_host_copy;
+ if (output_mtypes[e->src_output()] == DEVICE_MEMORY) {
+ // If the output of the *TensorFlow* operation is in DEVICE_MEMORY then
+ // keep the node clustered -- XLA will also produce the output in device
+ // memory and we will get some benefit from clustering.
+ edge_incurs_extra_device_to_host_copy = false;
+ } else {
+ MemoryTypeVector dst_input_mtypes, dst_output_mtypes;
+ DeviceType dst_device_type("");
+ TF_RETURN_IF_ERROR(
+ DeviceToDeviceType(dst->assigned_device_name(), &dst_device_type));
+ TF_RETURN_IF_ERROR(MemoryTypesForNode(graph.op_registry(), dst_device_type,
+ dst->def(), &dst_input_mtypes,
+ &dst_output_mtypes));
+ edge_incurs_extra_device_to_host_copy =
+ dst_input_mtypes[e->dst_input()] == HOST_MEMORY;
+ }
+
+ if (!edge_incurs_extra_device_to_host_copy) {
+ continue;
+ }
+
+ // Check if `dst` is in a different cluster, unclustered, or about to be
+ // partially declustered (here we rely on the post-order traversal order).
+ // If yes, decluster `n` to avoid the device-to-host memcpy.
+ gtl::optional<StringPiece> dst_cluster =
+ result->count(dst) ? gtl::nullopt : GetXlaClusterForNode(*dst);
+ if (from_cluster != dst_cluster) {
+ CHECK(result->insert(n).second);
+ break;
+ }
+ }
+ }
+ return Status::OK();
+}
+
+Status PartiallyDeclusterNode(Graph* graph, Node* n) {
+ StringPiece cluster_name = *GetXlaClusterForNode(*n);
+ gtl::InlinedVector<const Edge*, 6> out_edges_to_clone;
+ for (const Edge* out_edge : n->out_edges()) {
+ if (out_edge->IsControlEdge()) {
+ continue;
+ }
+
+ Node* dst = out_edge->dst();
+ gtl::optional<StringPiece> dst_cluster_name = GetXlaClusterForNode(*dst);
+ if (dst_cluster_name != cluster_name) {
+ out_edges_to_clone.push_back(out_edge);
+ }
+ }
+
+ CHECK(!out_edges_to_clone.empty()) << n->DebugString();
+
+ NodeDef ndef = n->def();
+ ndef.set_name(strings::StrCat(n->name(), "/declustered"));
+ RemoveFromXlaCluster(&ndef);
+ Status s;
+ Node* cloned_node = graph->AddNode(ndef, &s);
+ TF_RETURN_IF_ERROR(s);
+ cloned_node->set_assigned_device_name(n->assigned_device_name());
+
+ for (const Edge* in_edge : n->in_edges()) {
+ graph->AddEdge(in_edge->src(), in_edge->src_output(), cloned_node,
+ in_edge->dst_input());
+ }
+
+ for (const Edge* out_edge_to_clone : out_edges_to_clone) {
+ graph->AddEdge(cloned_node, out_edge_to_clone->src_output(),
+ out_edge_to_clone->dst(), out_edge_to_clone->dst_input());
+ graph->RemoveEdge(out_edge_to_clone);
+ }
+
+ return Status::OK();
+}
+} // namespace
+
+Status PartiallyDeclusterPass::Run(
+ const GraphOptimizationPassOptions& options) {
+ // NB! In this pass we assume the only XLA-auto-clusterable operations that
+ // may have side effects are resource variable updates, so we never
+ // decluster those. The pass will have to be updated if this assumption
+ // becomes invalid.
+
+ Graph* graph = options.graph->get();
+
+ // When deciding whether to decluster a particular node, we base our decision
+ // on if we've decided that some of its consumers have to be declustered too.
+ // Iterating the graph in post-order guarantees that consumers have been
+ // visited before producers.
+ std::vector<Node*> post_order;
+ GetPostOrder(*graph, &post_order, /*stable_comparator=*/NodeComparatorName(),
+ /*edge_filter=*/[](const Edge& edge) {
+ return !edge.src()->IsNextIteration();
+ });
+
+ gtl::FlatSet<Node*> nodes_to_partially_decluster;
+ TF_RETURN_IF_ERROR(FindNodesToDecluster(
+ **options.graph, &nodes_to_partially_decluster, post_order));
+
+ if (VLOG_IS_ON(3)) {
+ for (Node* n : post_order) {
+ if (nodes_to_partially_decluster.count(n)) {
+ VLOG(3) << n->DebugString();
+ }
+ }
+ }
+
+ for (Node* n : post_order) {
+ if (nodes_to_partially_decluster.count(n)) {
+ TF_RETURN_IF_ERROR(PartiallyDeclusterNode(graph, n));
+ }
+ }
+
+ nodes_to_partially_decluster.clear();
+ TF_RETURN_IF_ERROR(FindNodesToDecluster(
+ **options.graph, &nodes_to_partially_decluster, post_order));
+ CHECK(nodes_to_partially_decluster.empty());
+
+ return Status::OK();
+}
+} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/partially_decluster_pass.h b/tensorflow/compiler/jit/partially_decluster_pass.h
new file mode 100644
index 0000000000..6949b5028e
--- /dev/null
+++ b/tensorflow/compiler/jit/partially_decluster_pass.h
@@ -0,0 +1,58 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_JIT_PARTIALLY_DECLUSTER_PASS_H_
+#define TENSORFLOW_COMPILER_JIT_PARTIALLY_DECLUSTER_PASS_H_
+
+#include "tensorflow/core/common_runtime/optimization_registry.h"
+
+namespace tensorflow {
+
+// Clones nodes from within a cluster to outside the cluster if profitable.
+//
+// Today this only clones to avoid device-to-host copies, but in the future we
+// may consider other reasons to clone. For instance, we convert this:
+//
+// .....
+// |
+// v
+// A_Clustered ====> C_Unclustered
+// |
+// v
+// B_Clustered
+//
+// to:
+//
+// .....
+// | |
+// | +-------------+
+// | |
+// v v
+// A_Clustered A_Unclustered ====> C_Unclustered
+// |
+// v
+// B_Clustered
+//
+// where the ===> arrow has a hostmem source and destination and would entail a
+// device to host copy if the source and destination were not in the same XLA
+// cluster.
+class PartiallyDeclusterPass : public GraphOptimizationPass {
+ public:
+ Status Run(const GraphOptimizationPassOptions& options) override;
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_JIT_PARTIALLY_DECLUSTER_PASS_H_
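
A pass like this takes effect only once it is registered with the global optimization registry. A plausible registration is sketched below (the phase number is illustrative, and the real registration lives in the JIT's pass-registration file, not shown here):

// Sketch, assuming the usual POST_REWRITE_FOR_EXEC grouping for JIT passes.
#include "tensorflow/compiler/jit/partially_decluster_pass.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"

namespace tensorflow {
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 20,
                      PartiallyDeclusterPass);
}  // namespace tensorflow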
diff --git a/tensorflow/compiler/jit/partially_decluster_pass_test.cc b/tensorflow/compiler/jit/partially_decluster_pass_test.cc
new file mode 100644
index 0000000000..08a956e4c6
--- /dev/null
+++ b/tensorflow/compiler/jit/partially_decluster_pass_test.cc
@@ -0,0 +1,284 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/partially_decluster_pass.h"
+
+#include "tensorflow/cc/framework/ops.h"
+#include "tensorflow/cc/ops/array_ops.h"
+#include "tensorflow/cc/ops/control_flow_ops_internal.h"
+#include "tensorflow/cc/ops/function_ops.h"
+#include "tensorflow/cc/ops/sendrecv_ops.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/compiler/jit/defs.h"
+#include "tensorflow/compiler/jit/xla_cluster_util.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/graph/graph_def_builder.h"
+#include "tensorflow/core/graph/graph_def_builder_util.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace {
+REGISTER_OP("FakeNullary").Output("out: float");
+
+REGISTER_OP("FakeBinary")
+ .Input("host_in: float")
+ .Input("device_in: float")
+ .Output("host_out: float")
+ .Output("device_out: float");
+
+REGISTER_OP("FakeResourceVar").Output("out: resource");
+
+REGISTER_OP("FakeResourceUpdate")
+ .Input("in: resource")
+ .Output("out: resource")
+ .Output("something_else: float");
+
+class FakeBinaryOp : public OpKernel {
+ public:
+ explicit FakeBinaryOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+ void Compute(OpKernelContext* ctx) override { CHECK(false); }
+};
+
+class FakeResourceVarUpdateOp : public OpKernel {
+ public:
+ explicit FakeResourceVarUpdateOp(OpKernelConstruction* context)
+ : OpKernel(context) {}
+
+ void Compute(OpKernelContext* ctx) override { CHECK(false); }
+};
+
+REGISTER_KERNEL_BUILDER(Name("FakeBinary")
+ .Device(DEVICE_CPU)
+ .HostMemory("host_in")
+ .HostMemory("host_out"),
+ FakeBinaryOp);
+
+REGISTER_KERNEL_BUILDER(Name("FakeResourceVarUpdate")
+ .Device(DEVICE_CPU)
+ .HostMemory("something_else"),
+ FakeResourceVarUpdateOp);
+
+Status PartiallyDecluster(std::unique_ptr<Graph>* graph) {
+ FixupSourceAndSinkEdges(graph->get());
+ // Assign all nodes to the CPU device.
+ static const char* kCpuDevice = "/job:localhost/replica:0/task:0/cpu:0";
+ for (Node* n : (*graph)->nodes()) {
+ n->set_assigned_device_name(kCpuDevice);
+ }
+
+ GraphOptimizationPassOptions opt_options;
+ opt_options.graph = graph;
+ PartiallyDeclusterPass pass;
+ return pass.Run(opt_options);
+}
+
+const Node* FindNodeByName(const Graph& graph, const string& name) {
+ for (const Node* node : graph.nodes()) {
+ if (node->name() == name) {
+ return node;
+ }
+ }
+ return nullptr;
+}
+
+bool GetInputsForNode(const Graph& graph, const string& node_name,
+ std::vector<Node*>* inputs) {
+ const Node* node = FindNodeByName(graph, node_name);
+ if (node == nullptr) {
+ return false;
+ }
+ for (const Edge* e : node->in_edges()) {
+ inputs->push_back(e->src());
+ }
+ std::sort(inputs->begin(), inputs->end(), NodeComparatorName());
+ return true;
+}
+
+TEST(PartiallyDeclusterPassTest, ClusteredAndUnclustered) {
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ {
+ GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
+ Node* input =
+ ops::SourceOp("FakeNullary", builder.opts().WithName("Input"));
+ Node* clustered_producer =
+ ops::BinaryOp("FakeBinary", input, input,
+ builder.opts().WithName("ClusteredProducer"));
+ ops::BinaryOp("FakeBinary", clustered_producer, input,
+ builder.opts().WithName("UnclusteredConsumer"));
+ Node* clustered_consumer =
+ ops::BinaryOp("FakeBinary", {clustered_producer, 1}, input,
+ builder.opts().WithName("ClusteredConsumer"));
+ clustered_producer->AddAttr(kXlaClusterAttr, "cluster_0");
+ clustered_consumer->AddAttr(kXlaClusterAttr, "cluster_0");
+ TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
+ }
+
+ TF_ASSERT_OK(PartiallyDecluster(&graph));
+ std::vector<Node*> unclustered_consumer_inputs;
+ ASSERT_TRUE(GetInputsForNode(*graph, "UnclusteredConsumer",
+ &unclustered_consumer_inputs));
+ ASSERT_EQ(unclustered_consumer_inputs.size(), 2);
+ EXPECT_EQ(unclustered_consumer_inputs[0]->name(),
+ "ClusteredProducer/declustered");
+ EXPECT_EQ(unclustered_consumer_inputs[1]->name(), "Input");
+
+ std::vector<Node*> clustered_consumer_inputs;
+ ASSERT_TRUE(GetInputsForNode(*graph, "ClusteredConsumer",
+ &clustered_consumer_inputs));
+ ASSERT_EQ(clustered_consumer_inputs.size(), 2);
+ EXPECT_EQ(clustered_consumer_inputs[0]->name(), "ClusteredProducer");
+ EXPECT_EQ(clustered_consumer_inputs[1]->name(), "Input");
+}
+
+TEST(PartiallyDeclusterPassTest, DifferentClusters) {
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ {
+ GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
+ Node* input =
+ ops::SourceOp("FakeNullary", builder.opts().WithName("Input"));
+ Node* clustered_producer =
+ ops::BinaryOp("FakeBinary", input, input,
+ builder.opts().WithName("ClusteredProducer"));
+ Node* consumer_in_different_cluster =
+ ops::BinaryOp("FakeBinary", clustered_producer, input,
+ builder.opts().WithName("ConsumerInDifferentCluster"));
+ Node* clustered_consumer =
+ ops::BinaryOp("FakeBinary", input, {clustered_producer, 1},
+ builder.opts().WithName("ClusteredConsumer"));
+ clustered_producer->AddAttr(kXlaClusterAttr, "cluster_0");
+ clustered_consumer->AddAttr(kXlaClusterAttr, "cluster_0");
+ consumer_in_different_cluster->AddAttr(kXlaClusterAttr, "cluster_1");
+ TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
+ }
+
+ TF_ASSERT_OK(PartiallyDecluster(&graph));
+ std::vector<Node*> inputs;
+ ASSERT_TRUE(GetInputsForNode(*graph, "ConsumerInDifferentCluster", &inputs));
+ ASSERT_EQ(inputs.size(), 2);
+ EXPECT_EQ(inputs[0]->name(), "ClusteredProducer/declustered");
+ EXPECT_EQ(inputs[1]->name(), "Input");
+}
+
+TEST(PartiallyDeclusterPassTest, DontDeclusterIfUserIsDeviceMem) {
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ {
+ GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
+ Node* input =
+ ops::SourceOp("FakeNullary", builder.opts().WithName("Input"));
+ Node* clustered_producer =
+ ops::BinaryOp("FakeBinary", input, input,
+ builder.opts().WithName("ClusteredProducer"));
+ // The first input is hostmem and the second input is devicemem.
+ Node* consumer_in_different_cluster =
+ ops::BinaryOp("FakeBinary", input, clustered_producer,
+ builder.opts().WithName("ConsumerInDifferentCluster"));
+ Node* clustered_consumer =
+ ops::BinaryOp("FakeBinary", input, {clustered_producer, 1},
+ builder.opts().WithName("ClusteredConsumer"));
+ clustered_producer->AddAttr(kXlaClusterAttr, "cluster_0");
+ clustered_consumer->AddAttr(kXlaClusterAttr, "cluster_0");
+ consumer_in_different_cluster->AddAttr(kXlaClusterAttr, "cluster_1");
+ TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
+ }
+
+ TF_ASSERT_OK(PartiallyDecluster(&graph));
+ std::vector<Node*> inputs;
+ ASSERT_TRUE(GetInputsForNode(*graph, "ConsumerInDifferentCluster", &inputs));
+ ASSERT_EQ(inputs.size(), 2);
+ EXPECT_EQ(inputs[0]->name(), "ClusteredProducer");
+ EXPECT_EQ(inputs[1]->name(), "Input");
+}
+
+TEST(PartiallyDeclusterPassTest, DontDuplicateResourceVarOps) {
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ {
+ GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
+ Node* input =
+ ops::SourceOp("FakeNullary", builder.opts().WithName("Input"));
+ Node* resource_var = ops::SourceOp("FakeResourceVar",
+ builder.opts().WithName("ResourceVar"));
+ Node* clustered_producer =
+ ops::UnaryOp("FakeResourceUpdate", resource_var,
+ builder.opts().WithName("ClusteredProducer"));
+ Node* consumer_in_different_cluster =
+ ops::BinaryOp("FakeBinary", {clustered_producer, 1}, input,
+ builder.opts().WithName("ConsumerInDifferentCluster"));
+ Node* clustered_consumer =
+ ops::BinaryOp("FakeBinary", input, {clustered_producer, 1},
+ builder.opts().WithName("ClusteredConsumer"));
+ clustered_producer->AddAttr(kXlaClusterAttr, "cluster_0");
+ clustered_consumer->AddAttr(kXlaClusterAttr, "cluster_0");
+ consumer_in_different_cluster->AddAttr(kXlaClusterAttr, "cluster_1");
+ TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
+ }
+
+ TF_ASSERT_OK(PartiallyDecluster(&graph));
+ std::vector<Node*> inputs;
+ ASSERT_TRUE(GetInputsForNode(*graph, "ConsumerInDifferentCluster", &inputs));
+ ASSERT_EQ(inputs.size(), 2);
+ EXPECT_EQ(inputs[0]->name(), "ClusteredProducer");
+ EXPECT_EQ(inputs[1]->name(), "Input");
+}
+
+TEST(PartiallyDeclusterPassTest, DeclusterDependentNodes) {
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ {
+ GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
+ Node* input =
+ ops::SourceOp("FakeNullary", builder.opts().WithName("Input"));
+ Node* clustered_producer_0 =
+ ops::BinaryOp("FakeBinary", input, input,
+ builder.opts().WithName("ClusteredProducer0"));
+ Node* clustered_producer_1 =
+ ops::BinaryOp("FakeBinary", clustered_producer_0, input,
+ builder.opts().WithName("ClusteredProducer1"));
+ ops::BinaryOp("FakeBinary", clustered_producer_1, input,
+ builder.opts().WithName("UnclusteredConsumer"));
+ Node* clustered_consumer =
+ ops::BinaryOp("FakeBinary", {clustered_producer_1, 1}, input,
+ builder.opts().WithName("ClusteredConsumer"));
+ clustered_producer_0->AddAttr(kXlaClusterAttr, "cluster_0");
+ clustered_producer_1->AddAttr(kXlaClusterAttr, "cluster_0");
+ clustered_consumer->AddAttr(kXlaClusterAttr, "cluster_0");
+ TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
+ }
+
+ TF_ASSERT_OK(PartiallyDecluster(&graph));
+ std::vector<Node*> unclustered_consumer_inputs, declustered_producer_1_inputs;
+
+ ASSERT_TRUE(GetInputsForNode(*graph, "UnclusteredConsumer",
+ &unclustered_consumer_inputs));
+ ASSERT_EQ(unclustered_consumer_inputs.size(), 2);
+ EXPECT_EQ(unclustered_consumer_inputs[0]->name(),
+ "ClusteredProducer1/declustered");
+ EXPECT_EQ(unclustered_consumer_inputs[1]->name(), "Input");
+
+ ASSERT_TRUE(GetInputsForNode(*graph, "ClusteredProducer1/declustered",
+ &declustered_producer_1_inputs));
+ ASSERT_EQ(declustered_producer_1_inputs.size(), 2);
+ EXPECT_EQ(declustered_producer_1_inputs[0]->name(),
+ "ClusteredProducer0/declustered");
+ EXPECT_EQ(declustered_producer_1_inputs[1]->name(), "Input");
+}
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/xla_cluster_util.cc b/tensorflow/compiler/jit/xla_cluster_util.cc
index a5628b12a2..0a025a1fc0 100644
--- a/tensorflow/compiler/jit/xla_cluster_util.cc
+++ b/tensorflow/compiler/jit/xla_cluster_util.cc
@@ -185,4 +185,26 @@ Status CreateCycleDetectionGraph(const Graph* graph, GraphCycles* cycles) {
return Status::OK();
}
+gtl::optional<StringPiece> GetXlaClusterForNode(const Node& node) {
+ const AttrValue* attr_value = node.attrs().Find(kXlaClusterAttr);
+ if (attr_value == nullptr) {
+ return gtl::nullopt;
+ }
+ Status s = AttrValueHasType(*attr_value, "string");
+ if (!s.ok()) {
+ return gtl::nullopt;
+ }
+ return attr_value->s();
+}
+
+bool HasResourceInputOrOutput(const Node& node) {
+ return std::find(node.input_types().begin(), node.input_types().end(),
+ DT_RESOURCE) != node.input_types().end() ||
+ std::find(node.output_types().begin(), node.output_types().end(),
+ DT_RESOURCE) != node.output_types().end();
+}
+
+void RemoveFromXlaCluster(NodeDef* node_def) {
+ node_def->mutable_attr()->erase(kXlaClusterAttr);
+}
} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/xla_cluster_util.h b/tensorflow/compiler/jit/xla_cluster_util.h
index bcce082aaf..bff76da6f9 100644
--- a/tensorflow/compiler/jit/xla_cluster_util.h
+++ b/tensorflow/compiler/jit/xla_cluster_util.h
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/lib/gtl/optional.h"
namespace tensorflow {
@@ -44,6 +45,16 @@ bool HasForwardedRefInput(const Node& node);
// the enclosing graph.
Status CreateCycleDetectionGraph(const Graph* graph, GraphCycles* cycles);
+// Returns the XLA cluster in which `node` is placed if it is in an XLA cluster,
+// otherwise returns nullopt.
+gtl::optional<StringPiece> GetXlaClusterForNode(const Node& node);
+
+// Removes `node_def` from its XLA cluster (by clearing its _XlaCluster attribute).
+void RemoveFromXlaCluster(NodeDef* node_def);
+
+// Returns true if `node` has a DT_RESOURCE typed input or output.
+bool HasResourceInputOrOutput(const Node& node);
+
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_
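
Together these helpers cover the query/strip pattern the decluster pass needs. A minimal sketch of their use (hypothetical function, not from the patch):

// Count clustered nodes that carry DT_RESOURCE values.
int CountClusteredResourceNodes(const Graph& graph) {
  int count = 0;
  for (const Node* n : graph.nodes()) {
    gtl::optional<StringPiece> cluster = GetXlaClusterForNode(*n);
    if (cluster && HasResourceInputOrOutput(*n)) {
      ++count;
    }
  }
  return count;
}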
diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
index f65f89ebf5..dd84fb34c1 100644
--- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
+++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
@@ -78,7 +78,8 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
executable->Run(launch_context.arguments(), run_options);
TF_RETURN_IF_ERROR(run_result.status());
- launch_context.PopulateOutputs(ctx, result, run_result.ConsumeValueOrDie());
+ TF_RETURN_IF_ERROR(launch_context.PopulateOutputs(
+ ctx, result, run_result.ConsumeValueOrDie()));
return Status::OK();
}
diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc
index 4ddeaebd3e..2a2691a6a4 100644
--- a/tensorflow/compiler/jit/xla_device.cc
+++ b/tensorflow/compiler/jit/xla_device.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/client_library.h"
+#include "tensorflow/compiler/xla/service/stream_pool.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
@@ -216,6 +217,8 @@ XlaDevice::XlaDevice(
transfer_as_literal_(transfer_as_literal),
shape_representation_fn_(shape_representation_fn) {
VLOG(1) << "Created XLA device " << jit_device_name << " " << this;
+ thread_pool_.reset(new thread::ThreadPool(options.env, "xla_device",
+ /*num_threads=*/1));
}
XlaDevice::~XlaDevice() {
@@ -262,10 +265,12 @@ Status XlaDevice::EnsureDeviceContextOk() {
Status XlaDevice::EnsureStreamOkLocked(xla::Backend* backend,
const string& name,
- xla::StreamPool::Ptr* stream,
+ std::shared_ptr<se::Stream>* stream,
bool* stream_was_changed) {
if (!(*stream) || !(*stream)->ok()) {
- TF_ASSIGN_OR_RETURN(*stream, backend->BorrowStream(device_ordinal_));
+ xla::StreamPool::Ptr ptr;
+ TF_ASSIGN_OR_RETURN(ptr, backend->BorrowStream(device_ordinal_));
+ *stream = std::shared_ptr<se::Stream>(std::move(ptr));
VLOG(1) << "XlaDevice " << this << " new " << name << " "
<< (*stream)->DebugStreamPointers();
*stream_was_changed = true;
@@ -281,8 +286,8 @@ xla::StatusOr<XlaDeviceContext*> XlaDevice::GetDeviceContextLocked() {
TF_RETURN_IF_ERROR(EnsureStreamOkLocked(backend, "stream", &stream_,
&need_new_device_context));
- se::Stream* host_to_device_stream = stream_.get();
- se::Stream* device_to_host_stream = stream_.get();
+ std::shared_ptr<se::Stream> host_to_device_stream = stream_;
+ std::shared_ptr<se::Stream> device_to_host_stream = stream_;
if (use_multiple_streams_) {
TF_RETURN_IF_ERROR(EnsureStreamOkLocked(backend, "host_to_device_stream",
&host_to_device_stream_,
@@ -290,8 +295,8 @@ xla::StatusOr<XlaDeviceContext*> XlaDevice::GetDeviceContextLocked() {
TF_RETURN_IF_ERROR(EnsureStreamOkLocked(backend, "device_to_host_stream",
&device_to_host_stream_,
&need_new_device_context));
- host_to_device_stream = host_to_device_stream_.get();
- device_to_host_stream = device_to_host_stream_.get();
+ host_to_device_stream = host_to_device_stream_;
+ device_to_host_stream = device_to_host_stream_;
}
if (!need_new_device_context) {
@@ -304,9 +309,13 @@ xla::StatusOr<XlaDeviceContext*> XlaDevice::GetDeviceContextLocked() {
if (device_context_) {
device_context_->Unref();
}
+ // The XlaDeviceContext holds shared references to the streams, and the
+ // XlaDeviceContext remains live for the duration of an Executor run. This
+ // ensures that the streams remain live for the duration of a run, even if
+ // an error is encountered and the streams are replaced with new ones.
device_context_ = new XlaDeviceContext(
- stream_.get(), host_to_device_stream, device_to_host_stream, client(),
- transfer_as_literal_, shape_representation_fn_);
+ stream_, host_to_device_stream, device_to_host_stream, client(),
+ transfer_as_literal_, shape_representation_fn_, thread_pool_.get());
VLOG(1) << "XlaDevice " << this << " new XlaDeviceContext "
<< device_context_;
@@ -371,6 +380,22 @@ void XlaDevice::ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
op_kernel->ComputeAsync(context, done);
}
+Status XlaDevice::Sync() {
+ VLOG(1) << "XlaDevice::Sync";
+ std::shared_ptr<se::Stream> stream;
+ {
+ mutex_lock lock(mu_);
+ stream = stream_;
+ }
+ if (!stream) return Status::OK();
+
+ if (!stream->parent()->SynchronizeAllActivity() || !stream->ok()) {
+ return errors::Internal("XlaDevice::Sync() failed.");
+ }
+ VLOG(1) << "XlaDevice::Sync completed";
+ return Status::OK();
+}
+
Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto,
const AllocatorAttributes alloc_attrs,
Tensor* tensor) {
diff --git a/tensorflow/compiler/jit/xla_device.h b/tensorflow/compiler/jit/xla_device.h
index d8906419b0..dbf35f349f 100644
--- a/tensorflow/compiler/jit/xla_device.h
+++ b/tensorflow/compiler/jit/xla_device.h
@@ -30,7 +30,6 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/local_client.h"
-#include "tensorflow/compiler/xla/service/stream_pool.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/local_device.h"
#include "tensorflow/core/framework/allocator.h"
@@ -124,7 +123,7 @@ class XlaDevice : public LocalDevice {
void Compute(OpKernel* op_kernel, OpKernelContext* context) override;
void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
AsyncOpKernel::DoneCallback done) override;
- Status Sync() override { return Status::OK(); }
+ Status Sync() override;
Status FillContextMap(const Graph* graph,
DeviceContextMap* device_context_map) override
@@ -153,7 +152,7 @@ class XlaDevice : public LocalDevice {
Allocator* GetAllocatorLocked(AllocatorAttributes attr)
EXCLUSIVE_LOCKS_REQUIRED(mu_);
Status EnsureStreamOkLocked(xla::Backend* backend, const string& name,
- xla::StreamPool::Ptr* stream,
+ std::shared_ptr<se::Stream>* stream,
bool* stream_was_changed)
EXCLUSIVE_LOCKS_REQUIRED(mu_);
xla::StatusOr<XlaDeviceContext*> GetDeviceContextLocked()
@@ -174,17 +173,17 @@ class XlaDevice : public LocalDevice {
// stream are executed on the device. Operations include data
// copying back and forth between CPU and the device, and
// computations enqueued by XLA.
- xla::StreamPool::Ptr stream_ GUARDED_BY(mu_);
+ std::shared_ptr<se::Stream> stream_ GUARDED_BY(mu_);
// If false, only stream_ is valid and all computation and transfers use
// stream_. If true, computation is performed by stream_ and transfers are
// performed by host_to_device/device_to_host_stream.
const bool use_multiple_streams_;
// If use_multiple_streams_, host to device transfers are performed using this
// stream.
- xla::StreamPool::Ptr host_to_device_stream_ GUARDED_BY(mu_);
+ std::shared_ptr<se::Stream> host_to_device_stream_ GUARDED_BY(mu_);
// If use_multiple_streams_, device to host transfers are performed using this
// stream.
- xla::StreamPool::Ptr device_to_host_stream_ GUARDED_BY(mu_);
+ std::shared_ptr<se::Stream> device_to_host_stream_ GUARDED_BY(mu_);
// Must we use XLA's transfer manager for correct host<->device transfers? if
// false, we can use ThenMemcpy() instead.
const bool transfer_as_literal_;
@@ -198,6 +197,9 @@ class XlaDevice : public LocalDevice {
// Holds extra information for GPU and TPU devices, e.g. the device context.
bool use_gpu_device_info_ GUARDED_BY(mu_) = false;
std::unique_ptr<GpuDeviceInfo> gpu_device_info_ GUARDED_BY(mu_);
+
+ // Thread pool used for running closures
+ std::unique_ptr<thread::ThreadPool> thread_pool_;
};
// Builds OpKernel registrations on 'device' for the JIT operators
diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc
index 0100bf51ed..0a0c089241 100644
--- a/tensorflow/compiler/jit/xla_device_context.cc
+++ b/tensorflow/compiler/jit/xla_device_context.cc
@@ -15,6 +15,9 @@ limitations under the License.
#include "tensorflow/compiler/jit/xla_device_context.h"
+#include <memory>
+
+#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_launch_util.h"
#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
@@ -48,17 +51,20 @@ void XlaDeviceAllocator::DeallocateRaw(void* ptr) {
void XlaDeviceAllocator::GetStats(AllocatorStats* stats) { stats->Clear(); }
XlaTransferManager::XlaTransferManager(
- se::Stream* compute_stream, se::Stream* host_to_device_stream,
- se::Stream* device_to_host_stream, xla::LocalClient* client,
+ std::shared_ptr<se::Stream> compute_stream,
+ std::shared_ptr<se::Stream> host_to_device_stream,
+ std::shared_ptr<se::Stream> device_to_host_stream, xla::LocalClient* client,
bool transfer_as_literal,
- XlaCompiler::ShapeRepresentationFn shape_representation_fn)
- : stream_(compute_stream),
- host_to_device_stream_(host_to_device_stream),
- device_to_host_stream_(device_to_host_stream),
+ XlaCompiler::ShapeRepresentationFn shape_representation_fn,
+ thread::ThreadPool* thread_pool)
+ : stream_(std::move(compute_stream)),
+ host_to_device_stream_(std::move(host_to_device_stream)),
+ device_to_host_stream_(std::move(device_to_host_stream)),
client_(client),
transfer_manager_(client->backend().transfer_manager()),
transfer_as_literal_(transfer_as_literal),
- shape_representation_fn_(std::move(shape_representation_fn)) {
+ shape_representation_fn_(std::move(shape_representation_fn)),
+ thread_pool_(thread_pool) {
CHECK(host_to_device_stream_ != nullptr);
CHECK(device_to_host_stream_ != nullptr);
CHECK(stream_ != nullptr);
@@ -88,15 +94,15 @@ Status XlaTransferManager::TransferLiteralToDevice(
if (UseMultipleStreams()) {
// Initially wait for the compute stream so that memory allocations are
// synchronized.
- host_to_device_stream_->ThenWaitFor(stream_);
+ host_to_device_stream_->ThenWaitFor(stream_.get());
}
TF_RETURN_IF_ERROR(transfer_manager_->TransferLiteralToDeviceAsync(
- host_to_device_stream_, *literal, shaped_buffer));
+ host_to_device_stream_.get(), *literal, shaped_buffer));
if (UseMultipleStreams()) {
- se::Event event(stream_->parent());
- TF_RET_CHECK(event.Init()) << "Event failed to initialize!";
- host_to_device_stream_->ThenRecordEvent(&event);
- xla_tensor->SetDefinedOn(host_to_device_stream_, std::move(event));
+ auto event = std::make_shared<se::Event>(stream_->parent());
+ TF_RET_CHECK(event->Init()) << "Event failed to initialize!";
+ host_to_device_stream_->ThenRecordEvent(event.get());
+ xla_tensor->SetDefinedOn(host_to_device_stream_.get(), std::move(event));
}
// Unref the host tensor, and capture the literal shared_ptr too so it goes
// out of scope when the lambda completes.
@@ -116,7 +122,7 @@ void XlaTransferManager::TransferLiteralFromDevice(
TensorReference ref(device_tensor);
transfer_manager_->TransferLiteralFromDevice(
- device_to_host_stream_, shaped_buffer, literal,
+ device_to_host_stream_.get(), shaped_buffer, literal,
[=, &shaped_buffer, &literal](xla::Status status) {
ref.Unref();
done([&]() -> Status {
@@ -179,8 +185,14 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
status = TransferLiteralToDevice(reshaped_cpu_tensor, device_tensor);
if (status.ok()) {
xla_tensor->set_host_tensor(*cpu_tensor);
- host_to_device_stream_->ThenDoHostCallback(
- [done]() { done(Status::OK()); });
+ host_to_device_stream_->ThenDoHostCallback([this, done]() {
+ // We must not call the done closure directly from DoHostCallback
+ // to avoid a deadlock. If done() is the callback that ends an
+ // Executor's run, the Executor may call XlaDevice::Sync() inside the
+ // callback. This deadlocks, because XlaDevice::Sync() waits for all
+ // stream activity to complete.
+ thread_pool_->Schedule([done]() { done(Status::OK()); });
+ });
return;
}
} else {
@@ -192,7 +204,7 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
if (!block_status.ok()) {
status = xla::InternalError(
"Failed to complete data transfer on stream %p: %s",
- host_to_device_stream_, block_status.error_message().c_str());
+ host_to_device_stream_.get(), block_status.error_message().c_str());
}
}
xla_tensor->set_host_tensor(*cpu_tensor);
@@ -225,9 +237,9 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor,
XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor);
if (se::Event* event =
- xla_tensor->GetDefinitionEvent(device_to_host_stream_)) {
+ xla_tensor->GetDefinitionEvent(device_to_host_stream_.get())) {
device_to_host_stream_->ThenWaitFor(event);
- xla_tensor->SetDefinedOn(device_to_host_stream_);
+ xla_tensor->SetDefinedOn(device_to_host_stream_.get());
}
Status status;
@@ -240,7 +252,7 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor,
Status block_status = device_to_host_stream_->BlockHostUntilDone();
if (!block_status.ok()) {
status = xla::InternalError(
- "Failed to complete data transfer on stream %p: %s", stream_,
+ "Failed to complete data transfer on stream %p: %s", stream_.get(),
block_status.error_message().c_str());
}
}
@@ -278,14 +290,14 @@ void XlaTransferManager::CopyDeviceTensorToDevice(const Tensor& src_tensor,
if (stream_ != device_to_device_stream) {
// Initially wait for the compute stream so that memory allocations are
// synchronized.
- device_to_device_stream->ThenWaitFor(stream_);
+ device_to_device_stream->ThenWaitFor(stream_.get());
}
}
if (se::Event* event =
- xla_src->GetDefinitionEvent(device_to_device_stream)) {
+ xla_src->GetDefinitionEvent(device_to_device_stream.get())) {
device_to_device_stream->ThenWaitFor(event);
- xla_src->SetDefinedOn(device_to_device_stream);
+ xla_src->SetDefinedOn(device_to_device_stream.get());
}
auto from_iter = xla_src->shaped_buffer().buffers().begin();
@@ -297,28 +309,37 @@ void XlaTransferManager::CopyDeviceTensorToDevice(const Tensor& src_tensor,
}
if (UseMultipleStreams()) {
- se::Event event(stream_->parent());
- CHECK(event.Init());
- device_to_device_stream->ThenRecordEvent(&event);
- xla_dst->SetDefinedOn(device_to_device_stream, std::move(event));
+ auto event = std::make_shared<se::Event>(stream_->parent());
+ TF_RET_CHECK(event->Init()) << "Event failed to initialize";
+ device_to_device_stream->ThenRecordEvent(event.get());
+ xla_dst->SetDefinedOn(device_to_device_stream.get(), std::move(event));
}
return Status::OK();
}();
if (!status.ok()) {
return done(status);
} else {
- stream_->ThenDoHostCallback([=]() { done(Status::OK()); });
+ stream_->ThenDoHostCallback([this, done]() {
+ // We must not call the done closure directly from DoHostCallback to avoid
+ // a deadlock. If done() is the callback that ends an Executor's run, the
+ // Executor may call XlaDevice::Sync() inside the callback. This
+ // deadlocks, because XlaDevice::Sync() waits for all stream activity to
+ // complete.
+ thread_pool_->Schedule([done]() { done(Status::OK()); });
+ });
}
}
XlaDeviceContext::XlaDeviceContext(
- se::Stream* compute_stream, se::Stream* host_to_device_stream,
- se::Stream* device_to_host_stream, xla::LocalClient* client,
+ std::shared_ptr<se::Stream> compute_stream,
+ std::shared_ptr<se::Stream> host_to_device_stream,
+ std::shared_ptr<se::Stream> device_to_host_stream, xla::LocalClient* client,
bool transfer_as_literal,
- XlaCompiler::ShapeRepresentationFn shape_representation_fn)
- : manager_(compute_stream, host_to_device_stream, device_to_host_stream,
- client, transfer_as_literal,
- std::move(shape_representation_fn)) {}
+ XlaCompiler::ShapeRepresentationFn shape_representation_fn,
+ thread::ThreadPool* thread_pool)
+ : manager_(std::move(compute_stream), std::move(host_to_device_stream),
+ std::move(device_to_host_stream), client, transfer_as_literal,
+ std::move(shape_representation_fn), thread_pool) {}
void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
Device* device,
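
The thread-pool indirection above deserves a standalone illustration: if the done closure ran inline on the stream's callback thread, and done() ends an Executor run that calls XlaDevice::Sync(), Sync() would wait for all stream activity -- including the very callback it is running on. A condensed sketch of the safe pattern (hypothetical helper; names are illustrative):

void CompleteTransfer(se::Stream* stream, thread::ThreadPool* pool,
                      StatusCallback done) {
  stream->ThenDoHostCallback([pool, done]() {
    // Calling done(Status::OK()) here directly could re-enter
    // XlaDevice::Sync(), which blocks on this callback: deadlock.
    // Scheduling it on a separate thread breaks the cycle.
    pool->Schedule([done]() { done(Status::OK()); });
  });
}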
diff --git a/tensorflow/compiler/jit/xla_device_context.h b/tensorflow/compiler/jit/xla_device_context.h
index 912f8d779e..2e7445340c 100644
--- a/tensorflow/compiler/jit/xla_device_context.h
+++ b/tensorflow/compiler/jit/xla_device_context.h
@@ -47,10 +47,12 @@ class XlaDeviceAllocator : public Allocator {
class XlaTransferManager {
public:
explicit XlaTransferManager(
- se::Stream* compute_stream, se::Stream* host_to_device_stream,
- se::Stream* device_to_host_stream, xla::LocalClient* client,
- bool transfer_as_literal,
- XlaCompiler::ShapeRepresentationFn shape_representation_fn);
+ std::shared_ptr<se::Stream> compute_stream,
+ std::shared_ptr<se::Stream> host_to_device_stream,
+ std::shared_ptr<se::Stream> device_to_host_stream,
+ xla::LocalClient* client, bool transfer_as_literal,
+ XlaCompiler::ShapeRepresentationFn shape_representation_fn,
+ thread::ThreadPool* thread_pool);
void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device,
Tensor* device_tensor, StatusCallback done) const;
@@ -61,7 +63,7 @@ class XlaTransferManager {
void CopyDeviceTensorToDevice(const Tensor& src_tensor, Tensor* dst_tensor,
const StatusCallback& done);
- se::Stream* stream() const { return stream_; }
+ se::Stream* stream() const { return stream_.get(); }
private:
Status TransferLiteralToDevice(const Tensor& host_tensor,
@@ -73,13 +75,13 @@ class XlaTransferManager {
// The main compute stream of the device, used to synchronize the transfer
// streams if they are set.
- se::Stream* stream_;
+ std::shared_ptr<se::Stream> stream_;
// The stream to use for transferring data from host to device. Can be
// identical to stream_, but must not be nullptr.
- se::Stream* host_to_device_stream_;
+ std::shared_ptr<se::Stream> host_to_device_stream_;
// The stream to use for transferring data from device to host. Can be
// identical to stream_, but must not be nullptr.
- se::Stream* device_to_host_stream_;
+ std::shared_ptr<se::Stream> device_to_host_stream_;
// For the underlying memory allocator and XLA's TransferManager.
xla::LocalClient* client_;
// Transfer manager, for marshalling data to and from the device.
@@ -87,6 +89,9 @@ class XlaTransferManager {
// True if we must use XLA's TransferManager for correct device transfers.
const bool transfer_as_literal_;
XlaCompiler::ShapeRepresentationFn shape_representation_fn_;
+
+ // Thread pool used for running closures
+ thread::ThreadPool* thread_pool_;
};
// DeviceContext for operators assigned to XlaDevice devices. The
@@ -95,10 +100,12 @@ class XlaTransferManager {
class XlaDeviceContext : public DeviceContext {
public:
explicit XlaDeviceContext(
- se::Stream* compute_stream, se::Stream* host_to_device_stream,
- se::Stream* device_to_host_stream, xla::LocalClient* client,
- bool transfer_as_literal,
- XlaCompiler::ShapeRepresentationFn shape_representation_fn);
+ std::shared_ptr<se::Stream> compute_stream,
+ std::shared_ptr<se::Stream> host_to_device_stream,
+ std::shared_ptr<se::Stream> device_to_host_stream,
+ xla::LocalClient* client, bool transfer_as_literal,
+ XlaCompiler::ShapeRepresentationFn shape_representation_fn,
+ thread::ThreadPool* thread_pool);
void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device,
Tensor* device_tensor,
diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc
index 6134b8c694..4efbb2d5d7 100644
--- a/tensorflow/compiler/jit/xla_launch_util.cc
+++ b/tensorflow/compiler/jit/xla_launch_util.cc
@@ -15,6 +15,8 @@ limitations under the License.
#include "tensorflow/compiler/jit/xla_launch_util.h"
+#include <memory>
+
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
@@ -182,7 +184,7 @@ void XlaComputationLaunchContext::PopulateInputs(
}
}
-void XlaComputationLaunchContext::PopulateOutputs(
+Status XlaComputationLaunchContext::PopulateOutputs(
OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel,
ScopedShapedBuffer output) {
se::Stream* stream =
@@ -211,6 +213,15 @@ void XlaComputationLaunchContext::PopulateOutputs(
output = ScopedShapedBuffer(std::move(buffer), output.memory_allocator());
}
+ std::shared_ptr<se::Event> definition_event;
+ if (use_multiple_streams_) {
+ definition_event = std::make_shared<se::Event>(stream->parent());
+ if (!definition_event->Init()) {
+ return errors::Internal("Failed to initialize tensor definition event.");
+ }
+ stream->ThenRecordEvent(definition_event.get());
+ }
+
// Copy XLA results to the OpOutputList.
int output_num = 0;
for (int i = 0; i < ctx->num_outputs(); ++i) {
@@ -228,12 +239,13 @@ void XlaComputationLaunchContext::PopulateOutputs(
// reallocate the device buffer later.
VLOG(1) << "Constant output tensor on device";
- OP_REQUIRES_OK(
- ctx, ctx->allocate_output(i, const_tensor.shape(), &output_tensor));
+ TF_RETURN_IF_ERROR(
+ ctx->allocate_output(i, const_tensor.shape(), &output_tensor));
Device* device = dynamic_cast<Device*>(ctx->device());
- OP_REQUIRES(ctx, device != nullptr,
- errors::Internal("DeviceBase was not a Device."));
+ if (device == nullptr) {
+ return errors::Internal("DeviceBase was not a Device.");
+ }
ctx->op_device_context()->CopyCPUTensorToDevice(
&const_tensor, device, output_tensor,
[&](Status status) { TF_CHECK_OK(status); });
@@ -263,16 +275,13 @@ void XlaComputationLaunchContext::PopulateOutputs(
se::DeviceMemoryBase buffer = output.buffer({output_num});
if (allocate_xla_tensors_) {
Tensor* output_tensor;
- OP_REQUIRES_OK(ctx, ctx->allocate_output(i, shape, &output_tensor));
+ TF_RETURN_IF_ERROR(ctx->allocate_output(i, shape, &output_tensor));
XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor);
if (xla_tensor) {
xla_tensor->set_shaped_buffer(ScopedShapedBuffer(
ExtractSubShapedBuffer(&output, output_num, xla_allocator_)));
if (use_multiple_streams_) {
- se::Event event(stream->parent());
- CHECK(event.Init());
- stream->ThenRecordEvent(&event);
- xla_tensor->SetDefinedOn(stream, std::move(event));
+ xla_tensor->SetDefinedOn(stream, definition_event);
}
} else {
// xla_tensor wasn't valid, which must mean this is a zero-element
@@ -298,41 +307,39 @@ void XlaComputationLaunchContext::PopulateOutputs(
for (int i = 0; i < kernel->resource_updates.size(); ++i) {
Allocator* allocator = ctx->device()->GetAllocator({});
const XlaCompiler::ResourceUpdate& write = kernel->resource_updates[i];
- OP_REQUIRES(ctx,
- write.input_index >= 0 && write.input_index < ctx->num_inputs(),
- errors::Internal("Invalid input index for variable write."));
+ if (write.input_index < 0 || write.input_index >= ctx->num_inputs()) {
+ return errors::Internal("Invalid input index for variable write.");
+ }
se::DeviceMemoryBase buffer = output.buffer({output_num});
Var* variable = nullptr;
// TODO(b/35625933): tensorflow::Var should contain a PersistentTensor,
// not a Tensor.
- OP_REQUIRES_OK(ctx, LookupOrCreateResource<Var>(
- ctx, HandleFromInput(ctx, write.input_index),
- &variable, [this, ctx, &write](Var** ptr) {
- *ptr = new Var(write.type);
- return Status::OK();
- }));
+ TF_RETURN_IF_ERROR(LookupOrCreateResource<Var>(
+ ctx, HandleFromInput(ctx, write.input_index), &variable,
+ [&write](Var** ptr) {
+ *ptr = new Var(write.type);
+ return Status::OK();
+ }));
core::ScopedUnref s(variable);
mutex_lock ml(*variable->mu());
- OP_REQUIRES(ctx, variable->tensor()->dtype() == write.type,
- errors::Internal("Mismatched type in variable write"));
+ if (variable->tensor()->dtype() != write.type) {
+ return errors::Internal("Mismatched type in variable write");
+ }
if (allocate_xla_tensors_) {
Tensor output_tensor;
- OP_REQUIRES_OK(
- ctx, ctx->allocate_temp(write.type, write.shape, &output_tensor));
+ TF_RETURN_IF_ERROR(
+ ctx->allocate_temp(write.type, write.shape, &output_tensor));
XlaTensor* xla_tensor = XlaTensor::FromTensor(&output_tensor);
CHECK(xla_tensor);
xla_tensor->set_shaped_buffer(
ExtractSubShapedBuffer(&output, output_num, xla_allocator_));
if (use_multiple_streams_) {
- se::Event event(stream->parent());
- CHECK(event.Init());
- stream->ThenRecordEvent(&event);
- xla_tensor->SetDefinedOn(stream, std::move(event));
+ xla_tensor->SetDefinedOn(stream, definition_event);
}
*variable->tensor() = output_tensor;
} else {
@@ -343,6 +350,7 @@ void XlaComputationLaunchContext::PopulateOutputs(
}
++output_num;
}
+ return Status::OK();
}
} // namespace tensorflow
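
With PopulateOutputs returning Status, error handling moves from the OP_REQUIRES* macros inside the helper to the caller. A hypothetical synchronous caller (names illustrative) would translate the Status roughly like so:

Status RunAndPopulate(OpKernelContext* ctx,
                      const XlaCompiler::CompilationResult* kernel,
                      XlaComputationLaunchContext* launch_context,
                      xla::ScopedShapedBuffer output) {
  // Allocation and variable-write failures now surface through the return
  // value instead of aborting via OP_REQUIRES inside the helper.
  TF_RETURN_IF_ERROR(
      launch_context->PopulateOutputs(ctx, kernel, std::move(output)));
  return Status::OK();
}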
diff --git a/tensorflow/compiler/jit/xla_launch_util.h b/tensorflow/compiler/jit/xla_launch_util.h
index 1ea3fa4cf2..4232f514b3 100644
--- a/tensorflow/compiler/jit/xla_launch_util.h
+++ b/tensorflow/compiler/jit/xla_launch_util.h
@@ -93,9 +93,9 @@ class XlaComputationLaunchContext {
const std::map<int, OptionalTensor>& variables);
// Given the XLA output in `output`, populate all outputs of `ctx`.
- void PopulateOutputs(OpKernelContext* ctx,
- const XlaCompiler::CompilationResult* kernel,
- xla::ScopedShapedBuffer output);
+ Status PopulateOutputs(OpKernelContext* ctx,
+ const XlaCompiler::CompilationResult* kernel,
+ xla::ScopedShapedBuffer output);
// Return the argument list. Only valid after PopulateInputs() has been
// called.
diff --git a/tensorflow/compiler/jit/xla_tensor.cc b/tensorflow/compiler/jit/xla_tensor.cc
index d777dfa5a3..92ba7de1b7 100644
--- a/tensorflow/compiler/jit/xla_tensor.cc
+++ b/tensorflow/compiler/jit/xla_tensor.cc
@@ -75,7 +75,7 @@ Status XlaTensor::AllocateShapedBuffer(DataType dtype, const TensorShape& shape,
se::Event* XlaTensor::GetDefinitionEvent(se::Stream* stream) {
mutex_lock lock(mu_);
- if (!definition_event_.has_value()) {
+ if (!definition_event_) {
return nullptr;
}
@@ -87,10 +87,11 @@ se::Event* XlaTensor::GetDefinitionEvent(se::Stream* stream) {
return nullptr;
}
- return &*definition_event_;
+ return definition_event_.get();
}
-void XlaTensor::SetDefinedOn(se::Stream* stream, se::Event event) {
+void XlaTensor::SetDefinedOn(se::Stream* stream,
+ std::shared_ptr<se::Event> event) {
mutex_lock lock(mu_);
definition_event_ = std::move(event);
streams_defined_on_ = {stream};
diff --git a/tensorflow/compiler/jit/xla_tensor.h b/tensorflow/compiler/jit/xla_tensor.h
index f7e401c731..8d36d0fa0a 100644
--- a/tensorflow/compiler/jit/xla_tensor.h
+++ b/tensorflow/compiler/jit/xla_tensor.h
@@ -16,6 +16,8 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_JIT_XLA_TENSOR_H_
#define TENSORFLOW_COMPILER_JIT_XLA_TENSOR_H_
+#include <memory>
+
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/core/framework/allocator.h"
@@ -94,7 +96,7 @@ class XlaTensor {
// Assert that the tensor's content is defined on 'stream' by the time 'event'
// triggers.
- void SetDefinedOn(se::Stream* stream, se::Event event);
+ void SetDefinedOn(se::Stream* stream, std::shared_ptr<se::Event> event);
// Assert that the tensor's content is defined on 'stream'. This version does
// not provide an event, and must be called *after* SetDefinedOn(Stream,
@@ -116,7 +118,7 @@ class XlaTensor {
// An optional event that is triggered when the tensor's content has been
// defined. If this event is nullptr, it is assumed that the tensor's content
// is always defined.
- gtl::optional<se::Event> definition_event_;
+ std::shared_ptr<se::Event> definition_event_;
// A list of all streams for which the tensor's content is defined for any
// newly enqueued command.
gtl::InlinedVector<se::Stream*, 2> streams_defined_on_ GUARDED_BY(mu_);
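
Because the definition event is now reference counted, a single recorded event can mark many tensors defined -- which is what PopulateOutputs does with its per-execution definition event. A condensed sketch (hypothetical helper; assumes both tensors wrap XlaTensors):

void MarkDefined(se::Stream* stream, Tensor* out_a, Tensor* out_b) {
  auto event = std::make_shared<se::Event>(stream->parent());
  CHECK(event->Init());
  stream->ThenRecordEvent(event.get());
  // Copies of the shared_ptr keep the event alive until the last tensor
  // referencing it goes away.
  XlaTensor::FromTensor(out_a)->SetDefinedOn(stream, event);
  XlaTensor::FromTensor(out_b)->SetDefinedOn(stream, event);
}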
diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc
index f42fb92359..1bf8948ef6 100644
--- a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc
+++ b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc
@@ -31,7 +31,6 @@ std::vector<tensorflow::Flag>* flag_objects;
std::once_flag flags_init;
void SetDebugOptionsDefaults(DebugOptions* flags) {
- flags->set_xla_enable_fast_math(true);
flags->set_xla_llvm_enable_alias_scope_metadata(true);
flags->set_xla_llvm_enable_noalias_metadata(true);
flags->set_xla_llvm_enable_invariant_load_metadata(true);
@@ -53,6 +52,11 @@ void SetDebugOptionsDefaults(DebugOptions* flags) {
// the heuristics needed to decide when to run on multiple streams. See
// b/77879207.
flags->set_xla_gpu_disable_multi_streaming(true);
+
+ // TODO(jlebar): Disable fastmath once doing so is not a performance
+ // regression.
+ flags->set_xla_cpu_enable_fast_math(true);
+ flags->set_xla_gpu_enable_fast_math(true);
}
// Allocates flag_values and flag_objects; this function must not be called more
@@ -150,10 +154,16 @@ void AllocateFlags() {
flag_values->mutable_xla_generate_hlo_text_to(),
"Dump all HLO modules as text into the provided directory path."),
tensorflow::Flag(
- "xla_enable_fast_math",
- bool_setter_for(&DebugOptions::set_xla_enable_fast_math),
- flag_values->xla_enable_fast_math(),
- "Enable unsafe fast-math optimizations in the compiler; "
+ "xla_cpu_enable_fast_math",
+ bool_setter_for(&DebugOptions::set_xla_cpu_enable_fast_math),
+ flag_values->xla_cpu_enable_fast_math(),
+ "Enable unsafe fast-math optimizations in the CPU compiler; "
+ "this may produce faster code at the expense of some accuracy."),
+ tensorflow::Flag(
+ "xla_gpu_enable_fast_math",
+ bool_setter_for(&DebugOptions::set_xla_gpu_enable_fast_math),
+ flag_values->xla_gpu_enable_fast_math(),
+ "Enable unsafe fast-math optimizations in the GPU compiler; "
"this may produce faster code at the expense of some accuracy."),
tensorflow::Flag(
"xla_llvm_enable_alias_scope_metadata",
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 7d315fa0d3..7331d2b54c 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -1234,6 +1234,20 @@ cc_library(
],
)
+cc_library(
+ name = "scatter_expander",
+ srcs = ["scatter_expander.cc"],
+ hdrs = ["scatter_expander.h"],
+ deps = [
+ ":hlo",
+ ":hlo_creation_utils",
+ ":hlo_pass",
+ ":while_util",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:statusor",
+ ],
+)
+
tf_cc_test(
name = "batchnorm_expander_test",
size = "small",
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index 37834e1cc2..f7812d9661 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -1705,6 +1705,10 @@ Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) {
reshape, HloInstruction::CreateReshape(reshape->shape(),
operand->mutable_operand(0)));
}
+ if (operand->opcode() == HloOpcode::kRng && operand->user_count() == 1) {
+ *operand->mutable_shape() = reshape->shape();
+ return ReplaceInstruction(reshape, operand);
+ }
if (HloOpcode::kBroadcast == reshape->operand(0)->opcode()) {
auto opt_dims = ReshapeLeavesDimensionsUnmodified(
@@ -2144,6 +2148,11 @@ Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) {
transpose->dimensions())));
}
+ if (operand->opcode() == HloOpcode::kRng && operand->user_count() == 1) {
+ *operand->mutable_shape() = transpose->shape();
+ return ReplaceInstruction(transpose, operand);
+ }
+
if (is_layout_sensitive_ && TransposeIsBitcast(transpose)) {
ReplaceWithBitcast(transpose);
return Status::OK();
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
index 862cbeeba6..5837391d75 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -1428,6 +1428,37 @@ TEST_F(AlgebraicSimplifierTest, NoBitcastAdded) {
EXPECT_THAT(computation->root_instruction(), op::Reshape(param0));
}
+// Test transforming reshapes and transposes of rng.
+TEST_F(AlgebraicSimplifierTest, ReshapeOfTransposeOfRngToRng) {
+ HloComputation::Builder builder(TestName());
+ HloInstruction* zero = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
+ HloInstruction* one = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
+ HloInstruction* rng0 = builder.AddInstruction(
+ HloInstruction::CreateRng(ShapeUtil::MakeShape(F32, {2, 2}),
+ RandomDistribution::RNG_UNIFORM, {zero, one}));
+
+ HloInstruction* transpose = builder.AddInstruction(
+ HloInstruction::CreateTranspose(rng0->shape(), rng0, {1, 0}));
+ Shape reshape_shape = builder
+ .AddInstruction(HloInstruction::CreateReshape(
+ ShapeUtil::MakeShape(F32, {4}), transpose))
+ ->shape();
+
+ auto computation = module().AddEntryComputation(builder.Build());
+
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ bitcasting_callback());
+ EXPECT_TRUE(simplifier.Run(&module()).ValueOrDie());
+
+ // Verify that reshape(transpose(rng)) is replaced by a single rng of the
+ // same shape as the reshape. The rewrite is valid only because the rng has
+ // no user other than the transpose, so its shape can be mutated in place.
+ EXPECT_THAT(computation->root_instruction(), op::Rng());
+ EXPECT_TRUE(ShapeUtil::Equal(computation->root_instruction()->shape(),
+ reshape_shape));
+}
+
// Test transforming reshapes to bitcasts under various conditions.
TEST_F(AlgebraicSimplifierTest, ReshapeReplacedWithBitcast) {
HloComputation::Builder builder(TestName());
diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc
index 118a11c8de..cfd26fc778 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment.cc
@@ -139,6 +139,7 @@ Status GatherComputationsByAllocationType(
case HloOpcode::kMap:
case HloOpcode::kReduce:
case HloOpcode::kReduceWindow:
+ case HloOpcode::kScatter:
case HloOpcode::kSelectAndScatter:
case HloOpcode::kFusion:
// Map/reduce etc computations are always thread-local.
diff --git a/tensorflow/compiler/xla/service/call_graph.cc b/tensorflow/compiler/xla/service/call_graph.cc
index a23427f00c..985ff30e80 100644
--- a/tensorflow/compiler/xla/service/call_graph.cc
+++ b/tensorflow/compiler/xla/service/call_graph.cc
@@ -61,6 +61,7 @@ CallContext GetInstructionCallContext(HloOpcode opcode) {
case HloOpcode::kMap:
case HloOpcode::kReduce:
case HloOpcode::kReduceWindow:
+ case HloOpcode::kScatter:
case HloOpcode::kSelectAndScatter:
case HloOpcode::kFusion:
return CallContext::kParallel;
diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc
index 36fb9b43aa..3e39c1bab1 100644
--- a/tensorflow/compiler/xla/service/copy_insertion.cc
+++ b/tensorflow/compiler/xla/service/copy_insertion.cc
@@ -312,7 +312,7 @@ Status AddCopiesForWhile(const HloAliasAnalysis& alias_analysis,
return Status::OK();
}
-// We add copies for all the indices of the true and false computaiton roots,
+// We add copies for all the indices of the true and false computation roots,
// in order to resolve interference. We later rely on the CopyRemover to drop
// the unnecessary ones.
Status AddCopiesForConditional(const HloAliasAnalysis& alias_analysis,
@@ -648,7 +648,12 @@ class CopyRemover {
// We can only perform copy elision if the resulting merged values have
// totally ordered live ranges; otherwise the merged buffer would have
// live range interference.
- if (IsHead(*dest)) {
+ if (src->next == dest) {
+ // In the process of eliding copies, it's possible for a copy to have the
+ // same source and destination buffer. In this case, the copy can be
+ // safely removed.
+ VLOG(2) << copy->name() << " source and destination buffers are the same.";
+ } else if (IsHead(*dest)) {
// The copy copies an arbitrary value in the source buffer (call it s_x)
// and defines d_0, the first value in the destination buffer. After
// merging, the values in the combined buffer must be strictly ordered
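For reference, a standalone sketch of the new early-out above, with a hypothetical node type standing in for CopyRemover's internal value-list entries:

    // Hypothetical stand-in for CopyRemover's per-buffer value-list node.
    struct ValueNode {
      ValueNode* next;  // next value defined in the same buffer
    };

    // The new early-out: if the destination value directly follows the
    // source value, the copy's source and destination buffers are already
    // the same, so the copy can simply be dropped.
    bool IsRemovableSelfCopy(const ValueNode* src, const ValueNode* dest) {
      return src->next == dest;
    }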
diff --git a/tensorflow/compiler/xla/service/copy_insertion_test.cc b/tensorflow/compiler/xla/service/copy_insertion_test.cc
index cd735256b8..892d0d7b54 100644
--- a/tensorflow/compiler/xla/service/copy_insertion_test.cc
+++ b/tensorflow/compiler/xla/service/copy_insertion_test.cc
@@ -2007,5 +2007,46 @@ ENTRY TestComputation {
InsertCopies(module.get());
}
+TEST_F(CopyInsertionTest, NestedWhiles) {
+ // Verify that no unnecessary copies remain after copy insertion for
+ // trivial nested whiles (b/112472605).
+ const string& hlo_string = R"(
+HloModule TestModule
+
+cond.inner {
+ ROOT param.cond.inner = pred[] parameter(0)
+}
+
+body.inner {
+ param.body.inner = pred[] parameter(0)
+ ROOT neg = pred[] negate(param.body.inner)
+}
+
+cond.outer {
+ ROOT param.cond.outer = pred[] parameter(0)
+}
+
+body.outer {
+ param.cond.outer = pred[] parameter(0)
+ ROOT while = pred[] while(param.cond.outer), condition=cond.inner, body=body.inner
+}
+
+ENTRY TestComputation {
+ entry_param = pred[] parameter(0)
+ ROOT while = pred[] while(entry_param), condition=cond.outer, body=body.outer
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<HloModule> module,
+ HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()));
+ InsertCopies(module.get());
+
+ // There should only be a single copy inserted, and it's in the entry
+ // computation.
+ EXPECT_EQ(CountCopies(*module), 1);
+ EXPECT_THAT(module->entry_computation()->root_instruction(),
+ op::While(op::Copy(op::Parameter())));
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index 3efe3e2f93..84779c60b0 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -20,7 +20,7 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_binary")
load("//tensorflow/compiler/xla:xla.bzl", "ORC_JIT_MEMORY_MAPPER_TARGETS")
load(
"//third_party/mkl:build_defs.bzl",
- "if_mkl",
+ "mkl_deps",
)
# Filegroup used to collect source files for dependency checking.
@@ -86,6 +86,7 @@ cc_library(
":parallel_task_assignment",
":simple_orc_jit",
"//tensorflow/compiler/tf2xla:cpu_function_runtime",
+ "//tensorflow/compiler/xla/service:scatter_expander",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:protobuf_util",
"//tensorflow/compiler/xla:status_macros",
@@ -497,10 +498,7 @@ cc_library(
"//tensorflow/core:framework_lite",
"//tensorflow/core/kernels:eigen_helpers",
"//third_party/eigen3",
- ] + if_mkl([
- "@mkl_dnn",
- "//third_party/mkl:intel_binary_blob",
- ]),
+ ] + mkl_deps(),
)
cc_library(
@@ -554,10 +552,7 @@ cc_library(
"//tensorflow/compiler/xla:executable_run_options",
"//tensorflow/core:framework_lite",
"//third_party/eigen3",
- ] + if_mkl([
- "//third_party/mkl:intel_binary_blob",
- "@mkl_dnn",
- ]),
+ ] + mkl_deps(),
)
cc_library(
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index 2df959c4dc..35154af048 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -88,6 +88,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/reduce_precision_insertion.h"
#include "tensorflow/compiler/xla/service/reshape_mover.h"
+#include "tensorflow/compiler/xla/service/scatter_expander.h"
#include "tensorflow/compiler/xla/service/transpose_folding.h"
#include "tensorflow/compiler/xla/service/tuple_simplifier.h"
#include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h"
@@ -299,6 +300,8 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile,
pipeline.AddPass<HloCSE>(/*is_layout_sensitive=*/false);
pipeline.AddPass<CpuInstructionFusion>();
+ pipeline.AddPass<ScatterExpander>();
+
ReducePrecisionInsertion::AddPasses(
&pipeline, module->config().debug_options(),
ReducePrecisionInsertion::PassTiming::AFTER_FUSION);
@@ -356,7 +359,7 @@ llvm::TargetOptions CompilerTargetOptions(
llvm::TargetOptions target_options;
llvm_ir::SetTargetOptions(
/*fast_math_enabled=*/module_config.debug_options()
- .xla_enable_fast_math(),
+ .xla_cpu_enable_fast_math(),
&target_options);
return target_options;
}
@@ -523,7 +526,7 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
CompilerTargetOptions(module->config()),
CodeGenOptLevel(module->config()),
options::OptimizeForSizeRequested(module->config()),
- module->config().debug_options().xla_enable_fast_math(),
+ module->config().debug_options().xla_cpu_enable_fast_math(),
module->config().debug_options().xla_llvm_disable_expensive_passes(),
pre_optimization_ir_hook, post_optimization_ir_hook);
llvm_module->setDataLayout(jit->data_layout());
@@ -653,9 +656,9 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
// so we bail if the configs have conflicting flags. At the moment, the only
// flag that needs to be consistent is fast-math.
const bool fast_math_enabled =
- modules[0]->config().debug_options().xla_enable_fast_math();
+ modules[0]->config().debug_options().xla_cpu_enable_fast_math();
for (const auto& module : modules) {
- if (module->config().debug_options().xla_enable_fast_math() !=
+ if (module->config().debug_options().xla_cpu_enable_fast_math() !=
fast_math_enabled) {
return InvalidArgument(
"All HLO module configs must have the same value for "
@@ -832,7 +835,7 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
CompilerFunctor compiler_functor(
target_machine.get(), &disassembler, opt_level,
options::OptimizeForSizeRequested(module->config()),
- module->config().debug_options().xla_enable_fast_math(),
+ module->config().debug_options().xla_cpu_enable_fast_math(),
module->config().debug_options().xla_llvm_disable_expensive_passes(),
pre_optimization_ir_dump_hook, post_optimization_ir_dump_hook);
std::unique_ptr<llvm::MemoryBuffer> object_file =
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
index 946f5124b8..c376864c3e 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
@@ -249,24 +249,11 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteOnStream(
const ServiceExecutableRunOptions* run_options,
tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
HloExecutionProfile* hlo_execution_profile) {
- if (GetRootPointsToSet().IsAmbiguous()) {
- return Unimplemented("Points-to set of root instruction is ambiguous");
- }
-
- se::Stream* stream = run_options->stream();
- DeviceMemoryAllocator* memory_allocator = run_options->allocator();
-
- std::vector<OwningDeviceMemory> owning_buffers;
- std::vector<se::DeviceMemoryBase> unowning_buffers;
TF_ASSIGN_OR_RETURN(
- std::tie(unowning_buffers, owning_buffers),
- CreateTempArray(memory_allocator, stream->parent()->device_ordinal(),
- arguments));
-
- TF_RETURN_IF_ERROR(ExecuteComputeFunction(
- &run_options->run_options(), unowning_buffers, hlo_execution_profile));
-
- return CreateResultShapedBuffer(run_options, &owning_buffers);
+ auto result,
+ ExecuteAsyncOnStreamImpl(run_options, arguments, hlo_execution_profile));
+ TF_RETURN_IF_ERROR(run_options->stream()->BlockHostUntilDone());
+ return std::move(result);
}
StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteAsyncOnStream(
@@ -277,6 +264,16 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteAsyncOnStream(
"Asynchronous execution on stream with hlo profiling is not yet "
"supported on CPU.");
}
+ return ExecuteAsyncOnStreamImpl(run_options, arguments, nullptr);
+}
+
+StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteAsyncOnStreamImpl(
+ const ServiceExecutableRunOptions* run_options,
+ tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ HloExecutionProfile* hlo_execution_profile) {
+ if (GetRootPointsToSet().IsAmbiguous()) {
+ return Unimplemented("Points-to set of root instruction is ambiguous");
+ }
auto* host_stream = dynamic_cast<se::host::HostStream*>(
run_options->stream()->implementation());
@@ -310,19 +307,20 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteAsyncOnStream(
ServiceExecutableRunOptions run_options;
std::vector<se::DeviceMemoryBase> unowning_buffers;
std::shared_ptr<std::vector<OwningDeviceMemory>> buffers;
+ HloExecutionProfile* hlo_execution_profile;
void operator()() {
// Failing a CHECK here is not great, but I don't see an obvious way to
// return a failed Status asynchronously.
TF_CHECK_OK(executable->ExecuteComputeFunction(
- &run_options.run_options(), unowning_buffers,
- /*hlo_execution_profile=*/nullptr));
+ &run_options.run_options(), unowning_buffers, hlo_execution_profile));
}
};
host_stream->EnqueueTask(
AsyncRunTask{this, *run_options, std::move(unowning_buffers),
std::make_shared<std::vector<OwningDeviceMemory>>(
- std::move(owning_buffers))});
+ std::move(owning_buffers)),
+ hlo_execution_profile});
return std::move(result);
}
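The shared_ptr inside AsyncRunTask is what keeps the temp buffers alive until the enqueued task has actually run. A self-contained sketch of that ownership pattern (std types only; all names hypothetical):

    #include <functional>
    #include <memory>
    #include <utility>
    #include <vector>

    using Task = std::function<void()>;

    // The returned task copies the shared_ptr, so the buffers survive until
    // the closure is destroyed after it runs, even though the enclosing
    // scope has long since returned.
    Task MakeTaskOwningBuffers(std::vector<int> owning_buffers) {
      auto shared =
          std::make_shared<std::vector<int>>(std::move(owning_buffers));
      return [shared] { /* execute the computation against *shared */ };
    }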
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.h b/tensorflow/compiler/xla/service/cpu/cpu_executable.h
index 8af8a5dfec..96e53de57e 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_executable.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.h
@@ -85,6 +85,16 @@ class CpuExecutable : public Executable {
const BufferAssignment& buffer_assignment() const { return *assignment_; }
private:
+ // Shares the implementation between ExecuteOnStream and
+ // ExecuteAsyncOnStream.
+ //
+ // Note that this is tricky to use correctly: the profile object (when it
+ // exists) must outlive the asynchronously launched task.
+ StatusOr<ScopedShapedBuffer> ExecuteAsyncOnStreamImpl(
+ const ServiceExecutableRunOptions* run_options,
+ tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ HloExecutionProfile* hlo_execution_profile);
+
// Creates an array suitable for passing as the "temps" argument to the JIT
// compiled function pointer.
//
diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
index 645888de78..f2ac742b6e 100644
--- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
@@ -1066,7 +1066,7 @@ bool DotOpEmitter::EmitExperimentalGebpDotIfEnabled(
<< config.GetCacheKey();
const bool enable_fast_math =
- hlo_module_config_.debug_options().xla_enable_fast_math();
+ hlo_module_config_.debug_options().xla_cpu_enable_fast_math();
const bool optimize_for_size =
options::OptimizeForSizeRequested(hlo_module_config_);
@@ -1149,7 +1149,7 @@ bool DotOpEmitter::EmitLlvmIrDotIfProfitable() {
swap_operands ? lhs_array_.GetBasePointer() : rhs_array_.GetBasePointer();
const bool enable_fast_math =
- hlo_module_config_.debug_options().xla_enable_fast_math();
+ hlo_module_config_.debug_options().xla_cpu_enable_fast_math();
const bool optimize_for_size =
options::OptimizeForSizeRequested(hlo_module_config_);
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index 09909b62ba..6f433b4f30 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -99,7 +99,7 @@ IrEmitter::IrEmitter(
target_machine_features_(*target_machine_features) {
b_.setFastMathFlags(llvm_ir::GetFastMathFlags(
/*fast_math_enabled=*/hlo_module_config_.debug_options()
- .xla_enable_fast_math()));
+ .xla_cpu_enable_fast_math()));
}
StatusOr<llvm::Function*> IrEmitter::EmitComputation(
@@ -158,11 +158,11 @@ void IrEmitter::InitializeIrFunction(const string& function_name) {
is_top_level_computation_ ? llvm::GlobalValue::ExternalLinkage
: llvm::GlobalValue::InternalLinkage;
// Create and initialize new IrFunction.
- compute_function_.reset(
- new IrFunction(function_name, linkage,
- options::OptimizeForSizeRequested(hlo_module_config_),
- hlo_module_config_.debug_options().xla_enable_fast_math(),
- module_, &b_, num_dynamic_loop_bounds_));
+ compute_function_.reset(new IrFunction(
+ function_name, linkage,
+ options::OptimizeForSizeRequested(hlo_module_config_),
+ hlo_module_config_.debug_options().xla_cpu_enable_fast_math(), module_,
+ &b_, num_dynamic_loop_bounds_));
}
IrEmitter::~IrEmitter() {}
@@ -577,7 +577,7 @@ Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) {
TF_RETURN_IF_ERROR(ElementTypesSameAndSupported(
/*instruction=*/*reduce_window,
/*operands=*/{reduce_window->operand(0)},
- /*supported_types=*/{F32, BF16, S32}));
+ /*supported_types=*/{F32, BF16, S32, F16}));
// TODO(b/31410564): Implement dilation for reduce-window.
if (window_util::HasDilation(reduce_window->window())) {
diff --git a/tensorflow/compiler/xla/service/despecializer.cc b/tensorflow/compiler/xla/service/despecializer.cc
index d938f3a2c4..48e4471499 100644
--- a/tensorflow/compiler/xla/service/despecializer.cc
+++ b/tensorflow/compiler/xla/service/despecializer.cc
@@ -21,8 +21,33 @@ limitations under the License.
namespace xla {
+namespace {
+
+// Pass which strips control dependencies from all instructions in the module.
+class ControlDepRemover : public HloPassInterface {
+ public:
+ ControlDepRemover() = default;
+ tensorflow::StringPiece name() const override {
+ return "control-dep-remover";
+ }
+
+ StatusOr<bool> Run(HloModule* module) override {
+ bool changed = false;
+ for (HloComputation* computation : module->computations()) {
+ for (HloInstruction* instruction : computation->instructions()) {
+ changed = changed || !instruction->control_predecessors().empty();
+ TF_RETURN_IF_ERROR(instruction->DropAllControlDeps());
+ }
+ }
+ return changed;
+ }
+};
+
+} // namespace
+
Despecializer::Despecializer() : pipeline_("despecializer") {
// TODO(b/70588125): Also deal with window reversal in a fast way.
+ pipeline_.AddPass<ControlDepRemover>();
pipeline_.AddPass<Defuser>();
pipeline_.AddPass<ImplicitBroadcastRemover>();
pipeline_.AddPass<BFloat16MixedPrecisionRemoval>();
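Since Despecializer is itself an HLO pass, the new ControlDepRemover runs implicitly whenever the despecializer does. A hedged usage sketch, assumed to sit inside a function returning Status:

    Despecializer despecializer;
    TF_ASSIGN_OR_RETURN(bool changed, despecializer.Run(module));
    // Control dependencies are now stripped, in addition to the defusion and
    // implicit-broadcast/bfloat16 cleanups the pipeline already performed.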
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index a3f6e8d989..19575c7905 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -1,6 +1,7 @@
# Description:
# GPU-specific components in XLA service implementation.
+load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test")
load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library")
licenses(["notice"]) # Apache 2.0
@@ -365,6 +366,7 @@ cc_library(
":gpu_executable",
":ir_emission_utils",
"//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla/service:compiler",
"//tensorflow/compiler/xla/service:device_memory_allocator",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_pass",
@@ -652,6 +654,7 @@ cc_library(
"//tensorflow/compiler/xla/service:llvm_compiler",
"//tensorflow/compiler/xla/service:reduce_precision_insertion",
"//tensorflow/compiler/xla/service:reshape_mover",
+ "//tensorflow/compiler/xla/service:scatter_expander",
"//tensorflow/compiler/xla/service:transpose_folding",
"//tensorflow/compiler/xla/service:tuple_simplifier",
"//tensorflow/compiler/xla/service:while_loop_constant_sinking",
@@ -852,3 +855,35 @@ tf_cc_test(
"//tensorflow/core:test",
],
)
+
+cc_library(
+ name = "buffer_comparator",
+ srcs = ["buffer_comparator.cc"],
+ hdrs = ["buffer_comparator.h"],
+ deps = [
+ ":gpu_executable",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla/service:compiler",
+ "//tensorflow/compiler/xla/service:device_memory_allocator",
+ "//tensorflow/compiler/xla/service:hlo_parser",
+ "//tensorflow/compiler/xla/service:hlo_runner",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:stream_executor_no_cuda",
+ ],
+)
+
+xla_test(
+ name = "buffer_comparator_test",
+ srcs = ["buffer_comparator_test.cc"],
+ backends = [
+ "cpu",
+ "gpu",
+ ],
+ deps = [
+ ":buffer_comparator",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla/service:backend",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
diff --git a/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc b/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc
new file mode 100644
index 0000000000..6a285a6b98
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc
@@ -0,0 +1,205 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/buffer_comparator.h"
+
+#include <cmath>
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+
+namespace xla {
+namespace gpu {
+
+static constexpr float kTolerance = 0.1f;
+
+static string GetCompHloText(size_t num_elements) {
+ // The comparison routine is defined in HLO text form, as that's more
+ // readable than constructing the HLO graph programmatically.
+ static constexpr char kF16CompHloText[] = R"(
+HloModule CompareF16
+
+MaxF32 {
+ %lhs = f32[] parameter(0)
+ %rhs = f32[] parameter(1)
+ ROOT %max = f32[] maximum(%lhs, %rhs)
+}
+
+Canonicalize (aparam: f16[SIZE]) -> f32[SIZE] {
+ %min_constant = f32[] constant(-65505)
+ %max_constant = f32[] constant(65505)
+ %large_constant = f32[] constant(1048576)
+ %min_values = f32[SIZE] broadcast(%min_constant), dimensions={}
+ %max_values = f32[SIZE] broadcast(%max_constant), dimensions={}
+ %large_values = f32[SIZE] broadcast(%large_constant), dimensions={}
+
+ %a = f16[SIZE] parameter(0)
+ %converted = f32[SIZE] convert(%a)
+ %clamped = f32[SIZE] clamp(%min_values, %converted, %max_values)
+
+ // Since the clamp() above already took care of infs, only NaNs will cause
+ // is-finite() to return false.
+ %is_finite = pred[SIZE] is-finite(%clamped)
+ ROOT %result = f32[SIZE] select(%is_finite, %clamped, %large_values)
+}
+
+ENTRY MaxDifference {
+ %one_constant = f32[] constant(1.0)
+ %zero_constant = f32[] constant(0.0)
+
+ %ones = f32[SIZE] broadcast(%one_constant), dimensions={}
+
+ %lhs = f16[SIZE] parameter(0)
+ %rhs = f16[SIZE] parameter(1)
+ %lhs_canonical = f32[SIZE] call(%lhs), to_apply=Canonicalize
+ %rhs_canonical = f32[SIZE] call(%rhs), to_apply=Canonicalize
+ %sub = f32[SIZE] subtract(%lhs_canonical, %rhs_canonical)
+ %sub_abs = f32[SIZE] abs(%sub)
+ %lhs_abs = f32[SIZE] abs(%lhs_canonical)
+ %rhs_abs = f32[SIZE] abs(%rhs_canonical)
+ %max = f32[SIZE] maximum(%lhs_abs, %rhs_abs)
+ %denominator = f32[SIZE] add(%max, %ones)
+ %error = f32[SIZE] divide(%sub_abs, %denominator)
+ ROOT %max_diff = f32[] reduce(%error, %zero_constant), dimensions={0}, to_apply=MaxF32
+})";
+ auto size_string = std::to_string(num_elements);
+ return tensorflow::str_util::StringReplace(
+ kF16CompHloText, "SIZE", {size_string.data(), size_string.size()}, true);
+}
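Per element, the HLO above computes the damped relative error |lhs - rhs| / (max(|lhs|, |rhs|) + 1) and then max-reduces it. An equivalent host-side scalar check, shown as a restatement rather than code from this patch:

    #include <algorithm>
    #include <cmath>

    // Comparison metric used by the HLO above: relative error with a +1 in
    // the denominator, so near-zero values degrade gracefully into an
    // absolute-error test.
    bool ElementCompareEqual(float lhs, float rhs, float tolerance = 0.1f) {
      float err =
          std::abs(lhs - rhs) / (std::max(std::abs(lhs), std::abs(rhs)) + 1);
      return err < tolerance;
    }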
+
+StatusOr<F16BufferComparator> F16BufferComparator::Create(
+ se::DeviceMemory<Eigen::half> ref_buffer, Compiler* compiler,
+ DeviceMemoryAllocator* allocator, se::Stream* stream) {
+ auto stream_exec = stream->parent();
+ int64 num_elements = ref_buffer.ElementCount();
+
+ // One may consider using hlo_runner to do all the compilation and
+ // execution. However, as of this writing hlo_runner doesn't support
+ // injecting a Compiler*, a Stream*, or even the allocator. We may revisit
+ // this in the future if the current approach proves to be a maintenance
+ // burden.
+ TF_ASSIGN_OR_RETURN(
+ auto exec, ([&]() -> StatusOr<std::unique_ptr<Executable>> {
+ HloModuleConfig config;
+ DebugOptions debug_options;
+ debug_options.set_xla_backend_optimization_level(2);
+ config.set_debug_options(debug_options);
+ TF_ASSIGN_OR_RETURN(
+ auto module, ParseHloString(GetCompHloText(num_elements), config));
+ TF_ASSIGN_OR_RETURN(
+ module,
+ compiler->RunHloPasses(std::move(module), stream_exec, nullptr));
+ return compiler->RunBackend(std::move(module), stream_exec, nullptr);
+ }()));
+
+ TF_ASSIGN_OR_RETURN(
+ auto shaped_buffer, ([&]() -> StatusOr<ScopedShapedBuffer> {
+ auto device_ordinal = stream_exec->device_ordinal();
+ TF_ASSIGN_OR_RETURN(
+ auto owning_buffer,
+ allocator->Allocate(device_ordinal, ref_buffer.size()));
+ se::DeviceMemory<Eigen::half> buffer(
+ owning_buffer.AsDeviceMemoryBase());
+ stream->ThenMemcpy(&buffer, ref_buffer, ref_buffer.size());
+ Shape shape = ShapeUtil::MakeShape(xla::F16, {num_elements});
+ ScopedShapedBuffer ret(shape, shape, allocator, device_ordinal);
+ ret.set_buffer(std::move(owning_buffer), {});
+ return std::move(ret);
+ }()));
+
+ return F16BufferComparator(stream, allocator, std::move(exec),
+ std::move(shaped_buffer));
+}
+
+StatusOr<bool> F16BufferComparator::CompareEqualImpl(
+ se::DeviceMemory<Eigen::half> test_buffer) {
+ if (ref_buffer_.root_buffer().size() != test_buffer.size()) {
+ return InternalError("Mismatched buffer size: %lld vs %lld",
+ ref_buffer_.root_buffer().size(), test_buffer.size());
+ }
+
+ int64 num_elements = test_buffer.ElementCount();
+
+ TF_ASSIGN_OR_RETURN(
+ auto result_buffer, ([&]() -> StatusOr<ScopedShapedBuffer> {
+ auto stream_exec = stream_->parent();
+ Shape shape = ShapeUtil::MakeShape(xla::F16, {num_elements});
+ auto device_ordinal = stream_exec->device_ordinal();
+ ShapedBuffer shaped_test_buffer(shape, shape, stream_exec->platform(),
+ device_ordinal);
+ shaped_test_buffer.set_buffer(test_buffer, {});
+ ExecutableRunOptions run_options;
+ run_options.set_device_ordinal(stream_exec->device_ordinal());
+ run_options.set_stream(stream_);
+ run_options.set_allocator(allocator_);
+ ServiceExecutableRunOptions service_run_options(run_options);
+ return exec_->ExecuteOnStream(
+ &service_run_options, {&ref_buffer_, &shaped_test_buffer}, nullptr);
+ }()));
+
+ float result;
+ CHECK(result_buffer.root_buffer().size() == sizeof(result));
+ stream_->ThenMemcpy(&result, result_buffer.root_buffer(), sizeof(result));
+ TF_RETURN_IF_ERROR(stream_->BlockHostUntilDone());
+ return result < kTolerance;
+}
+
+StatusOr<bool> F16BufferComparator::CompareEqual(
+ se::DeviceMemory<Eigen::half> test_buffer) {
+ TF_ASSIGN_OR_RETURN(auto result, CompareEqualImpl(test_buffer));
+ if (result) {
+ return true;
+ }
+ // Host-side code that does the same comparison, but also reports some of
+ // the differences.
+ int64 n = test_buffer.ElementCount();
+ std::vector<half> host_ref_buffer(n), host_test_buffer(n);
+ stream_->ThenMemcpy(host_ref_buffer.data(), ref_buffer_.root_buffer(),
+ ref_buffer_.root_buffer().size());
+ stream_->ThenMemcpy(host_test_buffer.data(), test_buffer, test_buffer.size());
+ TF_RETURN_IF_ERROR(stream_->BlockHostUntilDone());
+
+ const auto canonicalize = [](float a) -> float {
+ constexpr float kBigNumber = 1048576.;
+ constexpr float kMaxFp16Value = 65504.;
+ if (std::isnan(a)) {
+ return kBigNumber;
+ }
+ if (std::isinf(a)) {
+ if (a < 0) {
+ return -(kMaxFp16Value + 1);
+ }
+ return kMaxFp16Value + 1;
+ }
+ return a;
+ };
+ int differences_seen = 0;
+ for (int64 i = 0; i < n && differences_seen < 10; i++) {
+ float original_ref = static_cast<float>(host_ref_buffer[i]);
+ float original_test = static_cast<float>(host_test_buffer[i]);
+ float ref = canonicalize(original_ref);
+ float test = canonicalize(original_test);
+ if (!(std::abs(ref - test) / (std::max(std::abs(ref), std::abs(test)) + 1) <
+ kTolerance)) {
+ differences_seen++;
+ LOG(ERROR) << "Difference at " << i << ": " << original_ref << " vs "
+ << original_test;
+ }
+ }
+
+ return false;
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/buffer_comparator.h b/tensorflow/compiler/xla/service/gpu/buffer_comparator.h
new file mode 100644
index 0000000000..bf2ba78cea
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/buffer_comparator.h
@@ -0,0 +1,71 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_BUFFER_COMPARATOR_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_BUFFER_COMPARATOR_H_
+
+#include "tensorflow/compiler/xla/service/compiler.h"
+#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
+#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+
+namespace xla {
+namespace gpu {
+
+// An fp16 comparator that internally keeps a reference buffer, and compares
+// it against other test buffers.
+class F16BufferComparator {
+ public:
+ F16BufferComparator(const F16BufferComparator&) = delete;
+ F16BufferComparator(F16BufferComparator&&) = default;
+
+ // Creates a new comparator. It internally allocates a buffer initialized by
+ // ref_buffer.
+ static StatusOr<F16BufferComparator> Create(
+ se::DeviceMemory<Eigen::half> ref_buffer, Compiler* compiler,
+ DeviceMemoryAllocator* allocator, se::Stream* stream);
+
+ // Returns true if the internally allocated buffer "compares equal" to
+ // test_buffer. The definition of "equal" is:
+ // * All NaNs equal.
+ // * All infs are treated as 65505 or -65505, so that this checker is tolerant
+ // to fp16 overflows.
+ // * With NaNs and infs taken care of, a and b compare equal iff:
+ // abs(a - b) / (max(abs(a), abs(b)) + 1) < tolerance
+ //
+ // See the implementation for the tolerance value.
+ StatusOr<bool> CompareEqual(se::DeviceMemory<Eigen::half> test_buffer);
+
+ private:
+ F16BufferComparator(se::Stream* stream, DeviceMemoryAllocator* allocator,
+ std::unique_ptr<Executable> exec,
+ ScopedShapedBuffer ref_buffer)
+ : stream_(stream),
+ allocator_(allocator),
+ exec_(std::move(exec)),
+ ref_buffer_(std::move(ref_buffer)) {}
+
+ StatusOr<bool> CompareEqualImpl(se::DeviceMemory<Eigen::half> test_buffer);
+
+ se::Stream* stream_;
+ DeviceMemoryAllocator* allocator_;
+ std::unique_ptr<Executable> exec_;
+ ScopedShapedBuffer ref_buffer_;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_BUFFER_COMPARATOR_H_
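A hedged end-to-end usage sketch (it mirrors buffer_comparator_test.cc; the stream, compiler, allocator, and device buffers are assumed to be set up already):

    TF_ASSIGN_OR_RETURN(
        F16BufferComparator comparator,
        F16BufferComparator::Create(ref_buffer, compiler, allocator, stream));
    TF_ASSIGN_OR_RETURN(bool equal, comparator.CompareEqual(test_buffer));
    // On mismatch, CompareEqual has already logged up to ten differing
    // elements.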
diff --git a/tensorflow/compiler/xla/service/gpu/buffer_comparator_test.cc b/tensorflow/compiler/xla/service/gpu/buffer_comparator_test.cc
new file mode 100644
index 0000000000..33761d1bd8
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/buffer_comparator_test.cc
@@ -0,0 +1,126 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/buffer_comparator.h"
+
+#include <limits>
+#include "tensorflow/compiler/xla/service/backend.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+class BufferComparatorTest : public testing::Test {
+ protected:
+ BufferComparatorTest()
+ : backend_(Backend::CreateDefaultBackend().ConsumeValueOrDie()),
+ stream_exec_(backend_->default_stream_executor()),
+ allocator_(stream_exec_->platform(), {stream_exec_}),
+ compiler_(Compiler::GetForPlatform(stream_exec_->platform())
+ .ConsumeValueOrDie()) {}
+
+ // Takes floats only for convenience. Still uses half internally.
+ bool CompareEqualFloatBuffers(const std::vector<float>& lhs_float,
+ const std::vector<float>& rhs_float) {
+ std::vector<half> lhs(lhs_float.begin(), lhs_float.end());
+ std::vector<half> rhs(rhs_float.begin(), rhs_float.end());
+ se::Stream stream(stream_exec_);
+ stream.Init();
+
+ auto owning_lhs_buffer =
+ allocator_
+ .Allocate(stream_exec_->device_ordinal(), lhs.size() * sizeof(half))
+ .ConsumeValueOrDie();
+
+ auto owning_rhs_buffer =
+ allocator_
+ .Allocate(stream_exec_->device_ordinal(), rhs.size() * sizeof(half))
+ .ConsumeValueOrDie();
+
+ auto lhs_buffer =
+ se::DeviceMemory<Eigen::half>(owning_lhs_buffer.AsDeviceMemoryBase());
+ auto rhs_buffer =
+ se::DeviceMemory<Eigen::half>(owning_rhs_buffer.AsDeviceMemoryBase());
+
+ stream.ThenMemcpy(&lhs_buffer, lhs.data(), lhs_buffer.size());
+ stream.ThenMemcpy(&rhs_buffer, rhs.data(), rhs_buffer.size());
+
+ TF_CHECK_OK(stream.BlockHostUntilDone());
+
+ return F16BufferComparator::Create(lhs_buffer, compiler_, &allocator_,
+ &stream)
+ .ConsumeValueOrDie()
+ .CompareEqual(rhs_buffer)
+ .ConsumeValueOrDie();
+ }
+
+ std::unique_ptr<Backend> backend_;
+ se::StreamExecutor* stream_exec_;
+ StreamExecutorMemoryAllocator allocator_;
+ Compiler* compiler_;
+};
+
+TEST_F(BufferComparatorTest, TestNaNs) {
+ EXPECT_TRUE(CompareEqualFloatBuffers({std::nanf("")}, {std::nanf("")}));
+ // NaN values with different bit patterns should compare equal.
+ EXPECT_TRUE(CompareEqualFloatBuffers({std::nanf("")}, {std::nanf("1234")}));
+ EXPECT_FALSE(CompareEqualFloatBuffers({std::nanf("")}, {1.}));
+}
+
+TEST_F(BufferComparatorTest, TestInfs) {
+ const auto inf = std::numeric_limits<float>::infinity();
+ EXPECT_FALSE(CompareEqualFloatBuffers({inf}, {std::nanf("")}));
+ EXPECT_TRUE(CompareEqualFloatBuffers({inf}, {inf}));
+ EXPECT_TRUE(CompareEqualFloatBuffers({inf}, {65504}));
+ EXPECT_TRUE(CompareEqualFloatBuffers({-inf}, {-65504}));
+ EXPECT_FALSE(CompareEqualFloatBuffers({inf}, {-65504}));
+ EXPECT_FALSE(CompareEqualFloatBuffers({-inf}, {65504}));
+
+ EXPECT_FALSE(CompareEqualFloatBuffers({inf}, {20}));
+ EXPECT_FALSE(CompareEqualFloatBuffers({inf}, {-20}));
+ EXPECT_FALSE(CompareEqualFloatBuffers({-inf}, {20}));
+ EXPECT_FALSE(CompareEqualFloatBuffers({-inf}, {-20}));
+}
+
+TEST_F(BufferComparatorTest, TestNumbers) {
+ EXPECT_TRUE(CompareEqualFloatBuffers({20}, {20.1}));
+ EXPECT_FALSE(CompareEqualFloatBuffers({0}, {1}));
+ EXPECT_TRUE(CompareEqualFloatBuffers({0.9}, {1}));
+ EXPECT_TRUE(CompareEqualFloatBuffers({9}, {10}));
+ EXPECT_TRUE(CompareEqualFloatBuffers({10}, {9}));
+}
+
+TEST_F(BufferComparatorTest, TestMultiple) {
+ EXPECT_TRUE(CompareEqualFloatBuffers({20, 30, 40, 50, 60},
+ {20.1, 30.1, 40.1, 50.1, 60.1}));
+ std::vector<float> lhs(200);
+ std::vector<float> rhs(200);
+ for (int i = 0; i < 200; i++) {
+ EXPECT_TRUE(CompareEqualFloatBuffers(lhs, rhs))
+ << "should be the same at index " << i;
+ lhs[i] = 3;
+ rhs[i] = 5;
+ EXPECT_FALSE(CompareEqualFloatBuffers(lhs, rhs))
+ << "should be the different at index " << i;
+ lhs[i] = 0;
+ rhs[i] = 0;
+ }
+}
+
+} // namespace
+} // namespace gpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
index 7348307ec8..7d93bdfc8b 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
@@ -30,7 +30,6 @@ namespace {
using se::DeviceMemoryBase;
using se::dnn::AlgorithmConfig;
using se::dnn::AlgorithmDesc;
-using tensorflow::gtl::nullopt;
using tensorflow::gtl::optional;
class ScratchAllocator : public se::ScratchAllocator {
@@ -173,7 +172,7 @@ tensorflow::mutex_lock LockGpu(const se::StreamExecutor* stream_exec) {
// cache misses and doing extra work. Overall, caching doesn't seem worth the
// trouble, but we may want to revisit this if we ever find a model where
// caching would speed up compilation a lot.
-optional<std::tuple<int64, bool, int64>>
+StatusOr<std::tuple<int64, bool, int64>>
CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
const Shape& output_shape, const Window& window,
@@ -206,45 +205,25 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
// Allocate space for the input, filter, and output of the convolution. We
// use a ScratchAllocator for this instead of calling allocator_ directly so
// that our allocations don't leak.
- //
- // We don't put any data in these buffers, because (in theory, anyway) the
- // speed of a conv isn't affected by the data being convolved.
ScratchAllocator input_output_allocator(device_ordinal, allocator);
- StatusOr<DeviceMemoryBase> maybe_input_buf =
- input_output_allocator.AllocateBytes(&stream,
- ShapeUtil::ByteSizeOf(input_shape));
- StatusOr<DeviceMemoryBase> maybe_filter_buf =
- input_output_allocator.AllocateBytes(&stream,
- ShapeUtil::ByteSizeOf(filter_shape));
- StatusOr<DeviceMemoryBase> maybe_output_buf =
- input_output_allocator.AllocateBytes(&stream,
- ShapeUtil::ByteSizeOf(output_shape));
- if (!maybe_input_buf.ok() || !maybe_filter_buf.ok() ||
- !maybe_output_buf.ok()) {
- LOG(WARNING)
- << "Couldn't allocate space for input/filter/output of convolution "
- << instr->ToString() << ". Falling back to default algorithm.";
- return nullopt;
- }
-
- DeviceMemoryBase input_buf = maybe_input_buf.ValueOrDie();
- DeviceMemoryBase filter_buf = maybe_filter_buf.ValueOrDie();
- DeviceMemoryBase output_buf = maybe_output_buf.ValueOrDie();
+ TF_ASSIGN_OR_RETURN(DeviceMemoryBase input_buf,
+ input_output_allocator.AllocateBytes(
+ &stream, ShapeUtil::ByteSizeOf(input_shape)));
+ TF_ASSIGN_OR_RETURN(DeviceMemoryBase filter_buf,
+ input_output_allocator.AllocateBytes(
+ &stream, ShapeUtil::ByteSizeOf(filter_shape)));
+ TF_ASSIGN_OR_RETURN(DeviceMemoryBase output_buf,
+ input_output_allocator.AllocateBytes(
+ &stream, ShapeUtil::ByteSizeOf(output_shape)));
// Although we don't have evidence this matters, zero out the buffers before
// autotuning. It's conceivable that using uninitialized memory as the inputs
// might affect performance if e.g. the inputs contain denormals, and this is
// easy enough.
- if (!stream.ThenMemZero(&input_buf, input_buf.size())
- .ThenMemZero(&filter_buf, filter_buf.size())
- .ThenMemZero(&output_buf, output_buf.size())
- .BlockHostUntilDone()
- .ok()) {
- LOG(WARNING)
- << "Couldn't zero out input/filter/output buffer for convolution "
- << instr->ToString() << ". Falling back to default algorithm.";
- return nullopt;
- }
+ TF_RETURN_IF_ERROR(stream.ThenMemZero(&input_buf, input_buf.size())
+ .ThenMemZero(&filter_buf, filter_buf.size())
+ .ThenMemZero(&output_buf, output_buf.size())
+ .BlockHostUntilDone());
const bool use_winograd_nonfused = ShouldIncludeWinogradNonfusedAlgo(
input_shape, output_shape, dnums, stream_exec_);
@@ -292,9 +271,10 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
best_result_bytes_used);
}
- LOG(WARNING) << "All algorithms tried for convolution " << instr->ToString()
- << " failed. Falling back to default algorithm.";
- return nullopt;
+ return InternalError(
+ "All algorithms tried for convolution %s failed. Falling back to "
+ "default algorithm.",
+ instr->ToString().c_str());
}
StatusOr<bool> CudnnConvolutionAlgorithmPicker::RunOnInstruction(
@@ -305,12 +285,13 @@ StatusOr<bool> CudnnConvolutionAlgorithmPicker::RunOnInstruction(
const auto& lhs_shape = instr->operand(0)->shape();
const auto& rhs_shape = instr->operand(1)->shape();
const auto& conv_result_shape = instr->shape().tuple_shapes(0);
- optional<std::tuple<int64, bool, int64>> alg_scratch_and_tc;
+ StatusOr<std::tuple<int64, bool, int64>> alg_scratch_and_tc;
if (call_target == kCudnnConvForwardCallTarget) {
- alg_scratch_and_tc = PickBestAlgorithm(
- CudnnConvKind::kForward, /*input_shape=*/lhs_shape,
- /*filter_shape=*/rhs_shape, /*output_shape=*/conv_result_shape,
- instr->window(), instr->convolution_dimension_numbers(), instr);
+ alg_scratch_and_tc =
+ PickBestAlgorithm(CudnnConvKind::kForward, /*input_shape=*/lhs_shape,
+ /*filter_shape=*/rhs_shape,
+ /*output_shape=*/conv_result_shape, instr->window(),
+ instr->convolution_dimension_numbers(), instr);
} else if (call_target == kCudnnConvBackwardInputCallTarget) {
alg_scratch_and_tc = PickBestAlgorithm(
CudnnConvKind::kBackwardInput, /*input_shape=*/conv_result_shape,
@@ -326,7 +307,8 @@ StatusOr<bool> CudnnConvolutionAlgorithmPicker::RunOnInstruction(
<< instr->ToString();
}
- if (!alg_scratch_and_tc.has_value()) {
+ if (!alg_scratch_and_tc.ok()) {
+ LOG(ERROR) << alg_scratch_and_tc.status();
return false;
}
@@ -334,7 +316,8 @@ StatusOr<bool> CudnnConvolutionAlgorithmPicker::RunOnInstruction(
bool tensor_ops_enabled;
int64 scratch_bytes;
- std::tie(algorithm, tensor_ops_enabled, scratch_bytes) = *alg_scratch_and_tc;
+ std::tie(algorithm, tensor_ops_enabled, scratch_bytes) =
+ alg_scratch_and_tc.ConsumeValueOrDie();
VLOG(1) << "Setting cudnn conv to use algorithm " << algorithm << " and "
<< NumBytesToString(scratch_bytes)
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h
index bc5d1ce94a..8b7749628a 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h
@@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_ALGORITHM_PICKER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_ALGORITHM_PICKER_H_
+#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
@@ -34,8 +35,9 @@ class CudnnConvolutionAlgorithmPicker : public HloPassInterface {
// memory while timing the various convolution algorithms. If it's null,
// we'll use the default allocator on the StreamExecutor.
CudnnConvolutionAlgorithmPicker(se::StreamExecutor* stream_exec,
- DeviceMemoryAllocator* allocator)
- : stream_exec_(stream_exec), allocator_(allocator) {}
+ DeviceMemoryAllocator* allocator,
+ Compiler* compiler)
+ : stream_exec_(stream_exec), allocator_(allocator), compiler_(compiler) {}
tensorflow::StringPiece name() const override {
return "cudnn-convolution-algorithm-picker";
@@ -46,13 +48,14 @@ class CudnnConvolutionAlgorithmPicker : public HloPassInterface {
private:
StatusOr<bool> RunOnComputation(HloComputation* computation);
StatusOr<bool> RunOnInstruction(HloInstruction* instr);
- tensorflow::gtl::optional<std::tuple<int64, bool, int64>> PickBestAlgorithm(
+ StatusOr<std::tuple<int64, bool, int64>> PickBestAlgorithm(
CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
const Shape& output_shape, const Window& window,
const ConvolutionDimensionNumbers& dnums, HloInstruction* instr);
se::StreamExecutor* stream_exec_; // never null
DeviceMemoryAllocator* allocator_; // may be null
+ Compiler* compiler_;
};
} // namespace gpu
diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
index 69ba91793d..9b6de115ad 100644
--- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
@@ -210,11 +210,13 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitPowerOp(
return make_sqrt();
}
- if (hlo_module_config_.debug_options().xla_enable_fast_math() &&
- IsFPLiteralWithValue(rhs, -.5)) {
+ if (IsFPLiteralWithValue(rhs, -.5)) {
VLOG(10) << "emitting pow(A, -.5) as 1/sqrt(A): " << op->ToString();
// LLVM's NVPTX backend knows how to transform 1/sqrt(A) into the NVPTX
// rsqrt.approx instruction.
+ //
+ // TODO(jlebar): Does this happen with fastmath disabled? If not, should
+ // we force-enable it?
TF_ASSIGN_OR_RETURN(auto* sqrt, make_sqrt());
return b_->CreateFDiv(llvm::ConstantFP::get(llvm_ty, 1), sqrt);
}
@@ -274,16 +276,18 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitAtan2(
StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitTanh(
PrimitiveType prim_type, llvm::Value* value) const {
- // If we don't care much about precision, emit a fast approximation of
- // tanh.
- if (hlo_module_config_.debug_options().xla_enable_fast_math()) {
- // Upcast F16 to F32 if necessary.
- llvm::Type* type = prim_type == F16 ? b_->getFloatTy() : value->getType();
- llvm::Value* input = b_->CreateFPCast(value, type);
- llvm::Value* fast_tanh = llvm_ir::EmitFastTanh(b_, input);
- return b_->CreateFPCast(fast_tanh, value->getType());
- }
- return EmitLibdeviceMathCall("__nv_tanh", {value}, {prim_type}, prim_type);
+ // Emit a fast approximation of tanh instead of calling __nv_tanh.
+ // __nv_tanh is particularly bad because it contains branches, thus
+ // preventing LLVM's load-store vectorizer from working its magic across a
+ // function which contains tanh calls.
+ //
+ // This routine isn't numerically precise, but it's good enough for ML.
+
+ // Upcast F16 to F32 if necessary.
+ llvm::Type* type = prim_type == F16 ? b_->getFloatTy() : value->getType();
+ llvm::Value* input = b_->CreateFPCast(value, type);
+ llvm::Value* fast_tanh = llvm_ir::EmitFastTanh(b_, input);
+ return b_->CreateFPCast(fast_tanh, value->getType());
}
llvm::Value* GpuElementalIrEmitter::EmitDeviceFunctionCall(
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
index 66aeb4efef..6675dbd3f9 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
@@ -64,7 +64,7 @@ IrEmitter::IrEmitter(const HloModuleConfig& hlo_module_config,
hlo_module_config_(hlo_module_config) {
b_.setFastMathFlags(llvm_ir::GetFastMathFlags(
/*fast_math_enabled=*/hlo_module_config.debug_options()
- .xla_enable_fast_math()));
+ .xla_gpu_enable_fast_math()));
}
Status IrEmitter::DefaultAction(HloInstruction* hlo) {
diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc
index cf44458a2e..ff4ae1f9ef 100644
--- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc
+++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc
@@ -180,7 +180,7 @@ std::unique_ptr<llvm::TargetMachine> GetTargetMachine(
TargetOptions target_options = InitTargetOptionsFromCodeGenFlags();
llvm_ir::SetTargetOptions(
/*fast_math_enabled=*/hlo_module_config.debug_options()
- .xla_enable_fast_math(),
+ .xla_gpu_enable_fast_math(),
&target_options);
// Enable FMA synthesis.
diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
index 76c9b6ab33..d937123357 100644
--- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
@@ -72,6 +72,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/reduce_precision_insertion.h"
#include "tensorflow/compiler/xla/service/reshape_mover.h"
+#include "tensorflow/compiler/xla/service/scatter_expander.h"
#include "tensorflow/compiler/xla/service/transpose_folding.h"
#include "tensorflow/compiler/xla/service/tuple_simplifier.h"
#include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h"
@@ -130,8 +131,12 @@ string GetLibdeviceDir(const string& config_cuda_data_dir) {
}
// Runs optimization passes on the given HLO module.
+//
+// It takes a compiler pointer, as passes may compile and execute HLOs on the
+// fly for cuDNN verification or other purposes.
Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
- DeviceMemoryAllocator* device_allocator) {
+ DeviceMemoryAllocator* device_allocator,
+ Compiler* compiler) {
{
HloPassPipeline pipeline("optimization");
pipeline.AddInvariantChecker<HloVerifier>();
@@ -167,6 +172,8 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
// elimination has to come after that pass.
pipeline.AddPass<ZeroSizedHloElimination>();
+ pipeline.AddPass<ScatterExpander>();
+
pass.AddPass<AlgebraicSimplifier>(
/*is_layout_sensitive=*/false,
[](const Shape&, const Shape&) { return false; });
@@ -245,8 +252,8 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
// the gte(customcall, 0) would probably already be into a fusion node. We
// can't simplify across HloComputation boundaries, so in this case we
// wouldn't be able to simplify away the new_tuple bits.
- pipeline.AddPass<CudnnConvolutionAlgorithmPicker>(stream_exec,
- device_allocator);
+ pipeline.AddPass<CudnnConvolutionAlgorithmPicker>(
+ stream_exec, device_allocator, compiler);
// Clean up new_tuple described above.
pipeline.AddPass<TupleSimplifier>();
@@ -492,11 +499,15 @@ NVPTXCompiler::NVPTXCompiler()
StatusOr<std::unique_ptr<HloModule>> NVPTXCompiler::RunHloPasses(
std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
DeviceMemoryAllocator* device_allocator) {
+ // We dump the post-optimization HLO in RunBackend, so there's no need to
+ // dump it here.
+ VLOG(2) << "*** HLO Before Optimization";
+ XLA_VLOG_LINES(2, module->ToString());
+
XLA_SCOPED_LOGGING_TIMER("NVPTXCompiler::RunHloPasses");
tracing::ScopedActivity activity("HLO Transforms", module->name(),
/*is_expensive=*/true);
TF_RETURN_IF_ERROR(
- OptimizeHloModule(module.get(), stream_exec, device_allocator));
+ OptimizeHloModule(module.get(), stream_exec, device_allocator, this));
return std::move(module);
}
@@ -548,6 +559,7 @@ StatusOr<std::unique_ptr<Executable>> NVPTXCompiler::RunBackend(
// include headers, so no need for us to print them ourselves.
XLA_VLOG_LINES(1, buffer_assignment->GetStats().ToString());
XLA_VLOG_LINES(2, buffer_assignment->ToString());
+ VLOG(2) << "*** HLO After Optimization";
XLA_VLOG_LINES(2, module->ToString());
const string xla_dump_optimized_hlo_proto_to =
module->config().debug_options().xla_dump_optimized_hlo_proto_to();
diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.cc b/tensorflow/compiler/xla/service/hlo_creation_utils.cc
index 90d2be118d..858992a326 100644
--- a/tensorflow/compiler/xla/service/hlo_creation_utils.cc
+++ b/tensorflow/compiler/xla/service/hlo_creation_utils.cc
@@ -174,6 +174,29 @@ StatusOr<HloInstruction*> MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs,
HloInstruction::CreateDot(dot_shape, lhs, rhs, dim_numbers));
}
+StatusOr<HloInstruction*> MakeMapHlo(
+ tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ HloComputation* map_computation) {
+ CHECK(!operands.empty()) << "Map Hlo requires at least one operand.";
+ HloComputation* computation = operands.front()->parent();
+ std::vector<const Shape*> operand_shapes;
+ int64 max_operand_rank = 0;
+ for (const HloInstruction* operand : operands) {
+ CHECK_EQ(computation, operand->parent());
+ operand_shapes.push_back(&operand->shape());
+ max_operand_rank =
+ std::max(max_operand_rank, ShapeUtil::Rank(operand->shape()));
+ }
+ std::vector<int64> map_dims(max_operand_rank);
+ std::iota(map_dims.begin(), map_dims.end(), 0);
+ TF_ASSIGN_OR_RETURN(
+ Shape map_shape,
+ ShapeInference::InferMapShape(
+ operand_shapes, map_computation->ComputeProgramShape(), map_dims));
+ return computation->AddInstruction(
+ HloInstruction::CreateMap(map_shape, operands, map_computation));
+}
+
StatusOr<HloInstruction*> CollapseFirstNDims(HloInstruction* operand, int64 n) {
CHECK_GT(n, 0);
@@ -251,6 +274,38 @@ StatusOr<HloInstruction*> ElideDegenerateDims(HloInstruction* operand,
return MakeReshapeHlo(output_shape, operand);
}
+StatusOr<HloInstruction*> InsertDegenerateDims(
+ HloInstruction* operand, ArraySlice<int64> dims_to_insert) {
+ CHECK(c_is_sorted(dims_to_insert));
+
+ const Shape& operand_shape = operand->shape();
+ int64 output_shape_rank =
+ operand_shape.dimensions_size() + dims_to_insert.size();
+ for (auto dim_to_insert : dims_to_insert) {
+ CHECK_LT(dim_to_insert, output_shape_rank);
+ }
+
+ std::vector<int64> output_shape_dim_bounds;
+ output_shape_dim_bounds.reserve(output_shape_rank);
+ int64 operand_dims_idx = 0;
+ int64 dims_to_insert_idx = 0;
+ for (int64 i = 0; i < output_shape_rank; ++i) {
+ if (dims_to_insert_idx < dims_to_insert.size() &&
+ i == dims_to_insert[dims_to_insert_idx]) {
+ output_shape_dim_bounds.push_back(1);
+ ++dims_to_insert_idx;
+ } else {
+ output_shape_dim_bounds.push_back(
+ operand_shape.dimensions(operand_dims_idx));
+ ++operand_dims_idx;
+ }
+ }
+
+ Shape output_shape = ShapeUtil::MakeShape(operand_shape.element_type(),
+ output_shape_dim_bounds);
+ return MakeReshapeHlo(output_shape, operand);
+}
+
StatusOr<HloInstruction*> PadVectorWithZeros(HloInstruction* operand,
int64 zeros_to_prepend,
int64 zeros_to_append) {
diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.h b/tensorflow/compiler/xla/service/hlo_creation_utils.h
index 49b1402d68..5ff8946fb0 100644
--- a/tensorflow/compiler/xla/service/hlo_creation_utils.h
+++ b/tensorflow/compiler/xla/service/hlo_creation_utils.h
@@ -102,6 +102,12 @@ StatusOr<HloInstruction*> MakeConcatHlo(
StatusOr<HloInstruction*> MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs,
const DotDimensionNumbers& dim_numbers);
+// Creates a Map HLO instruction and adds it to the computation containing the
+// operands. All operands must be in the same computation.
+StatusOr<HloInstruction*> MakeMapHlo(
+ tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ HloComputation* map_computation);
+
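A usage sketch under the stated precondition (x and y are assumed to be same-shaped instructions in one computation, and add_computation a scalar computation such as f32 addition):

    TF_ASSIGN_OR_RETURN(HloInstruction * mapped,
                        MakeMapHlo({x, y}, add_computation));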
// -----------------------------------------------------------------------------
// Some other miscellaneous helpers to generate common HLO patterns. All of
// these add all the instructions they generate into the computation containing
@@ -144,6 +150,16 @@ StatusOr<HloInstruction*> ExpandFirstDimIntoNDims(
StatusOr<HloInstruction*> ElideDegenerateDims(
HloInstruction* operand, tensorflow::gtl::ArraySlice<int64> dims_to_elide);
+// Inserts (via reshape) a set of degenerate dimensions (dimensions containing
+// exactly one element), `dims_to_insert` into `operand`. The dimensions in
+// `dims_to_insert` refer to the dimensions in the result, and hence should be
+// less than the rank of the result. Also, `dims_to_insert` must be sorted.
+//
+// For example, if `operand` is of shape f32[12,21,8,34] and dims_to_insert is
+// {0, 2}, then the result is `operand` reshaped to [1,12,1,21,8,34].
+StatusOr<HloInstruction*> InsertDegenerateDims(
+ HloInstruction* operand, tensorflow::gtl::ArraySlice<int64> dims_to_insert);
+
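A usage sketch matching the example in the comment above (operand is assumed to be an instruction of shape f32[12,21,8,34]):

    TF_ASSIGN_OR_RETURN(HloInstruction * reshaped,
                        InsertDegenerateDims(operand, {0, 2}));
    // reshaped->shape() is f32[1,12,1,21,8,34].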
// Pads `operand` (which must have rank 1) with `zeros_to_prepend` zeros in the
// front and `zeros_to_append` zeros in the back.
StatusOr<HloInstruction*> PadVectorWithZeros(HloInstruction* operand,
diff --git a/tensorflow/compiler/xla/service/hlo_lexer.cc b/tensorflow/compiler/xla/service/hlo_lexer.cc
index 71b44507cc..8e0d38b6a6 100644
--- a/tensorflow/compiler/xla/service/hlo_lexer.cc
+++ b/tensorflow/compiler/xla/service/hlo_lexer.cc
@@ -143,8 +143,47 @@ TokKind HloLexer::LexToken() {
return TokKind::kLparen;
case ')':
return TokKind::kRparen;
- case '/':
- return LexComment();
+ case '/': {
+ if (PeekCurrentChar() == '*') {
+ // This is the start of a /*...*/ delimited comment. Save the current
+ // location in case the comment is unterminated so the error message
+ // will point to the beginning of the comment.
+ const char* comment_start = current_ptr_;
+ current_ptr_++;
+ // Advance until '*/' is found.
+ while (true) {
+ int current = GetNextChar();
+ if (current == '*' && PeekCurrentChar() == '/') {
+ // End of comment.
+ current_ptr_++;
+ break;
+ }
+ if (current == kEOF) {
+ // Unterminated comment.
+ current_ptr_ = comment_start;
+ return TokKind::kError;
+ }
+ }
+ // Return no token for the comment. Keep lexing.
+ continue;
+ } else if (PeekCurrentChar() == '/') {
+ // This is the start of a '//' delimited comment. Throw away
+ // everything until end of line or file. The end-of-line character(s)
+ // are left unlexed in the buffer which is harmless because these are
+ // skipped later by the lexer. This approach enables support for
+ // different end-of-line encodings.
+ while (true) {
+ int current = PeekCurrentChar();
+ if (current == kEOF || current == '\n' || current == '\r') {
+ break;
+ }
+ current_ptr_++;
+ }
+ continue;
+ }
+ // A lone '/' is an error.
+ return TokKind::kError;
+ }
case '"':
return LexString();
}
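With this change the lexer consumes both comment styles without emitting any tokens for them. A hypothetical input exercising both, written as a C++ raw string like the other HLO text in this patch:

    const char* kCommentedHlo = R"(
    HloModule Comments  // line comment, consumed up to the end of the line
    /* block comment,
       spanning several lines */
    ENTRY %e {
      ROOT %p = f32[] parameter(0)
    }
    )";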
@@ -357,16 +396,6 @@ tensorflow::StringPiece HloLexer::GetLine(LocTy loc) const {
return StringPieceFromPointers(start, end);
}
-TokKind HloLexer::LexComment() {
- auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end());
- static LazyRE2 comment_pattern = {R"(\/\*.*?\*\/)"};
- if (RE2::Consume(&consumable, *comment_pattern)) {
- current_ptr_ = consumable.begin();
- return TokKind::kComment;
- }
- return TokKind::kError;
-}
-
// Lexes quoted string with escaping characters. If matched, the quoted string
// will be unescaped and stored to str_val_.
TokKind HloLexer::LexString() {
@@ -412,8 +441,6 @@ string TokKindToString(TokKind kind) {
return "kRparen";
case TokKind::kArrow:
return "kArrow";
- case TokKind::kComment:
- return "kComment";
case TokKind::kw_HloModule:
return "kw_HloModule";
case TokKind::kw_ENTRY:
diff --git a/tensorflow/compiler/xla/service/hlo_lexer.h b/tensorflow/compiler/xla/service/hlo_lexer.h
index ceb674f25e..003ac34ace 100644
--- a/tensorflow/compiler/xla/service/hlo_lexer.h
+++ b/tensorflow/compiler/xla/service/hlo_lexer.h
@@ -105,7 +105,6 @@ class HloLexer {
TokKind LexShape();
TokKind LexConstant();
TokKind LexNumberOrPattern();
- TokKind LexComment();
TokKind LexString();
const tensorflow::StringPiece buf_;
diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h
index b57c940238..c577b4359a 100644
--- a/tensorflow/compiler/xla/service/hlo_matchers.h
+++ b/tensorflow/compiler/xla/service/hlo_matchers.h
@@ -231,6 +231,7 @@ HLO_MATCHER(Tanh);
HLO_MATCHER(Trace);
HLO_MATCHER(Transpose);
HLO_MATCHER(Tuple);
+HLO_MATCHER(TupleSelect);
HLO_MATCHER(While);
// The special cases below let you check additional information about the
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h
index 84f2d3f5fb..1b256cd00e 100644
--- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h
+++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h
@@ -166,7 +166,7 @@ class HloModuleGroupMetadata {
//
// Precondition: IsCompanionWhile(instruction) is true.
const std::unordered_set<HloInstruction*>& Companions(
- HloInstruction* instruction) const {
+ const HloInstruction* instruction) const {
CHECK_EQ(companion_set_index_.count(instruction), 1);
return companion_set(companion_set_index_.at(instruction));
}
@@ -243,7 +243,7 @@ class HloModuleGroupMetadata {
companion_sets_;
// Map from each companion while instruction to the index into companion_set_.
- tensorflow::gtl::FlatMap<HloInstruction*, int64> companion_set_index_;
+ tensorflow::gtl::FlatMap<const HloInstruction*, int64> companion_set_index_;
// Map from computation to the instruction using it (a kWhile, kConditional).
tensorflow::gtl::FlatMap<const HloComputation*, TrackedInstruction>
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_util.cc b/tensorflow/compiler/xla/service/hlo_module_group_util.cc
index 9fd0ade153..0dc5676148 100644
--- a/tensorflow/compiler/xla/service/hlo_module_group_util.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_group_util.cc
@@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
@@ -37,24 +38,38 @@ namespace xla {
std::vector<HloInstruction*> HloModuleGroupUtil::GlobalPredecessors(
HloInstruction* instruction) {
- std::vector<HloInstruction*> predecessors;
-
- // Adds to the unique predecessors list and also add companion instructions
- // if the given predecessor has those.
+ std::vector<HloInstruction*>
+ predecessors; // Use a vector to avoid non-determinism.
+ tensorflow::gtl::FlatSet<HloInstruction*> unique;
+
+ // Adds to the unique predecessors list; if the predecessor is a companion
+ // instruction, also add its companion instructions; if the predecessor is a
+ // cross-module all-reduce, also add the all-reduce instructions in the same
+ // group.
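+ // For example (illustrative), if a predecessor belongs to a companion set
+ // {%w0, %w1}, both %w0 and %w1 are appended, exactly once each.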
auto add_unique_predecessor = [&](HloInstruction* predecessor) {
- if (std::find(predecessors.begin(), predecessors.end(), predecessor) !=
- predecessors.end()) {
+ if (unique.find(predecessor) != unique.end()) {
return;
}
- if (!metadata_.IsCompanionInstruction(predecessor)) {
- predecessors.push_back(predecessor);
+ if (metadata_.IsCompanionInstruction(predecessor)) {
+ for (HloInstruction* instr : metadata_.Companions(predecessor)) {
+ if (unique.insert(instr).second) {
+ predecessors.push_back(instr);
+ }
+ }
return;
}
- for (HloInstruction* companion : metadata_.Companions(predecessor)) {
- predecessors.push_back(companion);
+ if (predecessor->IsCrossModuleAllReduce()) {
+ for (HloInstruction* instr :
+ metadata_.GetAllReduceGroup(*predecessor->all_reduce_id())) {
+ if (unique.insert(instr).second) {
+ predecessors.push_back(instr);
+ }
+ }
+ return;
}
+ unique.insert(predecessor);
+ predecessors.push_back(predecessor);
};
-
// If the given instruction is a companion instruction, we need to find the
// predecessors of all of its companion instructions. If the instruction is an
// all-reduce, we need to find the predecessors of all the peer all-reduce
@@ -98,22 +113,37 @@ std::vector<HloInstruction*> HloModuleGroupUtil::GlobalPredecessors(
std::vector<HloInstruction*> HloModuleGroupUtil::GlobalSuccessors(
HloInstruction* instruction) {
- std::vector<HloInstruction*> successors;
-
- // Adds to the unique successors list and also add companion instructions
- // if the given successor has those.
+ std::vector<HloInstruction*>
+ successors; // Use a vector to avoid non-determinism.
+ tensorflow::gtl::FlatSet<HloInstruction*> unique;
+
+ // Adds to the unique successors list; if the successor is a companion
+ // instruction, also add companion instructions; if the successor is a
+ // cross-module all-reduce, also add the all-reduce instructions in the same
+ // group.
auto add_unique_successor = [&](HloInstruction* successor) {
- if (std::find(successors.begin(), successors.end(), successor) !=
- successors.end()) {
+ if (unique.find(successor) != unique.end()) {
return;
}
- if (!metadata_.IsCompanionInstruction(successor)) {
- successors.push_back(successor);
+ if (metadata_.IsCompanionInstruction(successor)) {
+ for (HloInstruction* instr : metadata_.Companions(successor)) {
+ if (unique.insert(instr).second) {
+ successors.push_back(instr);
+ }
+ }
return;
}
- for (HloInstruction* companion : metadata_.Companions(successor)) {
- successors.push_back(companion);
+ if (successor->IsCrossModuleAllReduce()) {
+ for (HloInstruction* instr :
+ metadata_.GetAllReduceGroup(*successor->all_reduce_id())) {
+ if (unique.insert(instr).second) {
+ successors.push_back(instr);
+ }
+ }
+ return;
}
+ unique.insert(successor);
+ successors.push_back(successor);
};
// If the given instruction is a companion instruction, we need to find the
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index 2a8c6ecd92..4b3cd99dc0 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -1824,7 +1824,6 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal,
break;
}
case TokKind::kComma:
- case TokKind::kComment:
// Skip.
lexer_.Lex();
break;
diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc
index 4cd21841f4..5990a3d478 100644
--- a/tensorflow/compiler/xla/service/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc
@@ -1560,6 +1560,81 @@ ENTRY consts {
"last");
}
+TEST_F(HloParserTest, Comments) {
+ const string original = R"(/* module description. */
+HloModule comments:
+
+ENTRY /*comment*/ c1 {
+ /* blah */
+ ROOT const1 = /*foo*/f32[1]{0} constant({12345 /*bar*/})
+ /* comment */
+}
+
+/* something else */
+
+)";
+ auto module = ParseHloString(original);
+ TF_ASSERT_OK(module.status());
+}
+
+TEST_F(HloParserTest, MultilineComments) {
+ const string original = R"(HloModule multiline_comment:
+ENTRY c1 {
+ /*
+ ROOT foo = f32[1]{0} constant({12345})
+ */
+ ROOT const1 = f32[1]{0} constant({12345})
+/*
+a
+b
+c
+d
+
+*/
+})";
+ auto module = ParseHloString(original);
+ TF_ASSERT_OK(module.status());
+}
+
+TEST_F(HloParserTest, UnterminatedComment) {
+ const string original = R"(HloModule unterminated_comment:
+ENTRY c1 {
+/* unterminated
+ ROOT const1 = f32[1]{0} constant({12345})
+})";
+ // Verify that the error message points to the beginning of the unterminated
+ // comment.
+ ExpectHasSubstr(ParseHloString(original).status().error_message(),
+ "/* unterminated\n^");
+}
+
+TEST_F(HloParserTest, SlashSlashComments) {
+ const string original = R"(HloModule slash_slash_comment:
+// Garbage
+ENTRY c1 {
+ // Foo bar
+ ROOT const1 = f32[1]{0} constant({12345}) // Something else
+})";
+ auto module = ParseHloString(original);
+ TF_ASSERT_OK(module.status());
+}
+
+TEST_F(HloParserTest, SlashSlashCommentMsDosEolFormat) {
+ const string original =
+ "HloModule slash_slash_comment:\r\n// Garbage\r\nENTRY c1 {\r\n// Foo "
+ "bar\r\nROOT const1 = f32[1]{0} constant({12345}) // Something else\r\n}";
+ auto module = ParseHloString(original);
+ TF_ASSERT_OK(module.status());
+}
+
+TEST_F(HloParserTest, SlashSlashCommentMacEolFormat) {
+ const string original =
+ "HloModule slash_slash_comment:\r// Garbage\rENTRY c1 {\r// Foo "
+ "bar\rROOT const1 = f32[1]{0} constant({12345}) // Something else\r}";
+ auto module = ParseHloString(original);
+ TF_ASSERT_OK(module.status());
+}
+
TEST_F(HloParserTest, MultipleEntries) {
const string original = R"(HloModule multiple_entries:
ENTRY c1 {
diff --git a/tensorflow/compiler/xla/service/hlo_token.h b/tensorflow/compiler/xla/service/hlo_token.h
index 533429608b..4458c251de 100644
--- a/tensorflow/compiler/xla/service/hlo_token.h
+++ b/tensorflow/compiler/xla/service/hlo_token.h
@@ -44,7 +44,6 @@ enum class TokKind {
kRparen, // ( )
kArrow, // ->
- kComment, // /*xxx*/
// Keywords
kw_HloModule,
diff --git a/tensorflow/compiler/xla/service/interpreter/executor.h b/tensorflow/compiler/xla/service/interpreter/executor.h
index 9b109022fb..db6b910b32 100644
--- a/tensorflow/compiler/xla/service/interpreter/executor.h
+++ b/tensorflow/compiler/xla/service/interpreter/executor.h
@@ -104,7 +104,7 @@ class XlaInterpreterExecutor : public internal::StreamExecutorInterface {
}
// No "synchronize all activity" implemented for this platform at the moment.
- bool SynchronizeAllActivity() override { return false; }
+ bool SynchronizeAllActivity() override { return true; }
bool SynchronousMemZero(DeviceMemoryBase *location, uint64 size) override {
return false;
}
diff --git a/tensorflow/compiler/xla/service/scatter_expander.cc b/tensorflow/compiler/xla/service/scatter_expander.cc
new file mode 100644
index 0000000000..45ca731153
--- /dev/null
+++ b/tensorflow/compiler/xla/service/scatter_expander.cc
@@ -0,0 +1,350 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/scatter_expander.h"
+
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/while_util.h"
+#include "tensorflow/compiler/xla/statusor.h"
+
+namespace xla {
+
+using tensorflow::gtl::ArraySlice;
+
+// Transposes the given scatter_indices such that the index_vector_dim becomes
+// the most-minor dimension.
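+//
+// For example (illustrative), for scatter_indices of shape [4,5,6] and
+// index_vector_dim=1, the permutation is {0,2,1} and the result has shape
+// [4,6,5].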
+static StatusOr<HloInstruction*> TransposeIndexVectorDimToLast(
+ HloInstruction* scatter_indices, int64 index_vector_dim) {
+ const Shape& scatter_indices_shape = scatter_indices->shape();
+
+ if (scatter_indices_shape.dimensions_size() == index_vector_dim) {
+ return scatter_indices;
+ }
+
+ if (index_vector_dim == (scatter_indices_shape.dimensions_size() - 1)) {
+ return scatter_indices;
+ }
+
+ std::vector<int64> permutation;
+ permutation.reserve(scatter_indices_shape.dimensions_size());
+ for (int64 i = 0, e = scatter_indices_shape.dimensions_size(); i < e; i++) {
+ if (i != index_vector_dim) {
+ permutation.push_back(i);
+ }
+ }
+ permutation.push_back(index_vector_dim);
+ return MakeTransposeHlo(scatter_indices, permutation);
+}
+
+// Canonicalizes the scatter_indices tensor so that it can be iterated over
+// uniformly while performing the scatter operation.
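+//
+// For example (illustrative), a scatter_indices tensor of shape [4,5,2] with
+// index_vector_dim=2 canonicalizes to shape [20,2]: one row per scatter
+// index, each row holding one index vector.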
+static StatusOr<HloInstruction*> CanonicalizeScatterIndices(
+ HloInstruction* scatter_indices, int64 index_vector_dim) {
+ // Transpose the non-index-vector dimensions to the front.
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * transposed_scatter_indices,
+ TransposeIndexVectorDimToLast(scatter_indices, index_vector_dim));
+ bool indices_are_scalar =
+ index_vector_dim == scatter_indices->shape().dimensions_size();
+
+ // The number of dimensions in scatter_indices that are index dimensions.
+ const int64 index_dims_in_scatter_indices = indices_are_scalar ? 0 : 1;
+
+ // If there is only one index (i.e. scatter_indices has rank 1 and this
+ // scatter is really just a dynamic update slice), add a leading degenerate
+ // dimension for uniformity. Otherwise, create a "collapsed" leading dimension
+ // that subsumes all of the non-index-vector dimensions.
+ const Shape& shape = transposed_scatter_indices->shape();
+ if (shape.dimensions_size() == index_dims_in_scatter_indices) {
+ return PrependDegenerateDims(transposed_scatter_indices, 1);
+ } else {
+ // Collapse all but the dimensions (0 or 1) in scatter_indices containing
+ // the index vectors.
+ return CollapseFirstNDims(
+ transposed_scatter_indices,
+ shape.dimensions_size() - index_dims_in_scatter_indices);
+ }
+}
+
+// Permutes the `updates` tensor such that all the scatter dims appear in the
+// major dimensions and all the window dimensions appear in the minor
+// dimensions.
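+//
+// For example (illustrative), for updates of rank 4 and
+// update_window_dims={1,3}, the permutation is {0,2,1,3}.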
+static StatusOr<HloInstruction*> PermuteScatterAndWindowDims(
+ HloInstruction* updates, ArraySlice<int64> update_window_dims) {
+ std::vector<int64> permutation;
+ const int64 updates_rank = ShapeUtil::Rank(updates->shape());
+ permutation.reserve(updates_rank);
+
+ for (int64 i = 0; i < updates_rank; ++i) {
+ bool is_scatter_dim = !c_binary_search(update_window_dims, i);
+ if (is_scatter_dim) {
+ permutation.push_back(i);
+ }
+ }
+ for (auto window_dim : update_window_dims) {
+ permutation.push_back(window_dim);
+ }
+
+ return MakeTransposeHlo(updates, permutation);
+}
+
+// Expands or contracts the scatter indices in the updates tensor.
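+//
+// For example (illustrative), with scatter_indices of shape [4,5,2] and
+// index_vector_dim=2, updates of shape [4,5,7,9] collapse to [20,7,9], making
+// the most-major dimension match the loop trip count.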
+static StatusOr<HloInstruction*> AdjustScatterDims(
+ const Shape& scatter_indices_shape, HloInstruction* updates,
+ int64 index_vector_dim) {
+ int64 num_scatter_dims = scatter_indices_shape.dimensions_size();
+ if (index_vector_dim < scatter_indices_shape.dimensions_size()) {
+ --num_scatter_dims;
+ }
+ if (num_scatter_dims == 0) {
+ // If there are no scatter dims, this must be a dynamic-update-slice kind of
+ // scatter. In this case, we prepend a degenerate dimension to work
+ // uniformly in the while loop.
+ return PrependDegenerateDims(updates, 1);
+ }
+ return CollapseFirstNDims(updates, num_scatter_dims);
+}
+
+// Expands an index vector from the scatter_indices tensor into a vector that
+// can be used to dynamic-update-slice to perform the scatter update.
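+//
+// For example (illustrative), with scatter_dims_to_operand_dims={0,2} and
+// operand_rank=3, an index vector [i,j] expands to [i,0,j].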
+static StatusOr<HloInstruction*> ExpandIndexVectorIntoOperandSpace(
+ HloInstruction* index_vector, const ScatterDimensionNumbers& dim_numbers,
+ int64 operand_rank) {
+ HloComputation* computation = index_vector->parent();
+ const Shape& index_shape = index_vector->shape();
+ HloInstruction* zero =
+ computation->AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateFromDimensions(index_shape.element_type(), {1})));
+
+ // We extract individual components from the smaller index and concatenate
+ // them (interspersing zeros as needed) into the larger index.
+ std::vector<HloInstruction*> expanded_index_components;
+
+ for (int i = 0; i < operand_rank; i++) {
+ int64 index_vector_dim_index =
+ FindIndex(dim_numbers.scatter_dims_to_operand_dims(), i);
+ if (index_vector_dim_index !=
+ dim_numbers.scatter_dims_to_operand_dims_size()) {
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * component_to_concat,
+ MakeSliceHlo(index_vector, /*start_indices=*/{index_vector_dim_index},
+ /*limit_indices=*/{index_vector_dim_index + 1},
+ /*strides=*/{1}));
+ expanded_index_components.push_back(component_to_concat);
+ } else {
+ expanded_index_components.push_back(zero);
+ }
+ }
+
+ return MakeConcatHlo(expanded_index_components, /*dimension=*/0);
+}
+
+// Body of the while loop that performs the scatter operation using other HLOs.
+static StatusOr<std::vector<HloInstruction*>> ScatterLoopBody(
+ HloInstruction* scatter, HloInstruction* induction_var,
+ const std::vector<HloInstruction*>& loop_state) {
+ const ScatterDimensionNumbers& dim_numbers =
+ scatter->scatter_dimension_numbers();
+ CHECK_EQ(loop_state.size(), 3);
+ HloInstruction* operand = loop_state[0];
+ HloInstruction* scatter_indices = loop_state[1];
+ HloInstruction* updates = loop_state[2];
+
+ bool has_scalar_indices = scatter_indices->shape().dimensions_size() == 1;
+ CHECK_EQ(has_scalar_indices,
+ dim_numbers.index_vector_dim() ==
+ scatter->operand(1)->shape().dimensions_size());
+
+ // Build a vector form of the induction variable of the while loop.
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * induction_var_as_vector,
+ MakeBroadcastHlo(induction_var, /*broadcast_dimensions=*/{},
+ /*result_shape_bounds=*/{1}));
+
+ // Pick the index to scatter from scatter_indices based on the induction_var
+ // and transform that to an index into the `operand` space.
+ HloInstruction* index_vector;
+ if (has_scalar_indices) {
+ TF_ASSIGN_OR_RETURN(
+ index_vector,
+ MakeDynamicSliceHlo(scatter_indices, induction_var_as_vector, {1}));
+ } else {
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * index_into_scatter_indices,
+ PadVectorWithZeros(induction_var_as_vector,
+ /*zeros_to_prepend=*/0, /*zeros_to_append=*/1));
+ int index_vector_size = scatter_indices->shape().dimensions(1);
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * index_vector_2d,
+ MakeDynamicSliceHlo(scatter_indices, index_into_scatter_indices,
+ {1, index_vector_size}));
+ TF_ASSIGN_OR_RETURN(index_vector,
+ ElideDegenerateDims(index_vector_2d, {0}));
+ }
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * scatter_slice_start,
+ ExpandIndexVectorIntoOperandSpace(index_vector, dim_numbers,
+ operand->shape().dimensions_size()));
+
+ // Extract the slice to be used to update from `updates` tensor for the
+ // induction_var corresponding to this iteration of the while loop.
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * index_into_updates,
+ PadVectorWithZeros(
+ induction_var_as_vector, /*zeros_to_prepend=*/0,
+ /*zeros_to_append=*/updates->shape().dimensions_size() - 1));
+ std::vector<int64> update_slice_bounds(updates->shape().dimensions().begin(),
+ updates->shape().dimensions().end());
+ update_slice_bounds[0] = 1;
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * update_slice,
+ MakeDynamicSliceHlo(updates, index_into_updates, update_slice_bounds));
+ TF_ASSIGN_OR_RETURN(HloInstruction * update_slice_for_scatter,
+ ElideDegenerateDims(update_slice, {0}));
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * update_slice_with_dims_inserted,
+ InsertDegenerateDims(update_slice_for_scatter,
+ AsInt64Slice(dim_numbers.inserted_window_dims())));
+
+ // Extract the slice to update from the `operand` tensor.
+ const Shape& update_slice_shape = update_slice_with_dims_inserted->shape();
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * operand_slice_to_update,
+ MakeDynamicSliceHlo(operand, scatter_slice_start,
+ AsInt64Slice(update_slice_shape.dimensions())));
+
+ // Compute the new value for the slice to be updated in `operand` tensor by
+ // combining the existing value and the update value using the update
+ // computation.
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * updated_operand_slice,
+ MakeMapHlo({operand_slice_to_update, update_slice_with_dims_inserted},
+ scatter->to_apply()));
+
+ // Write the updated value of the slice into `operand` tensor.
+ TF_ASSIGN_OR_RETURN(HloInstruction * updated_operand,
+ MakeDynamicUpdateSliceHlo(operand, updated_operand_slice,
+ scatter_slice_start));
+
+ return StatusOr<std::vector<HloInstruction*>>{
+ {updated_operand, scatter_indices, updates}};
+}
+
+// High Level Algorithm.
+//
+// 1. Canonicalize the scatter_indices tensor such that it has rank 2, where
+// each row is an index into the operand.
+// 2. Canonicalize the updates tensor such that it has rank `num_window_dims+1`
+// and the scatter dim is the most-major dimension.
+// 3. Iterate over the set of indices in the canonicalized scatter_indices
+// tensor using a while loop, updating the operand for each such index. Each
+// iteration of this while loop performs the following:
+// a. Pick the index from scatter_indices for this iteration.
+// b. Transform this index into an index into the operand space.
+// c. Extract the slice to be used to update from the updates tensor.
+// d. Extract the slice to update from the operand tensor.
+// e. Compute the new value for the slice to update by combining the slices
+// from c. and d. using the update_computation of scatter.
+// f. Write the updated value of the slice into the operand tensor.
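+//
+// For example (an illustrative walk-through): given a s32[3,3] operand, rank-1
+// scatter_indices {0, 2}, s32[2,3] updates, and an "update" to_apply
+// computation, the loop runs two iterations; iteration i dynamic-update-slices
+// updates[i, :] into row scatter_indices[i] of the operand.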
+
+StatusOr<HloInstruction*> ScatterExpander::ExpandScatter(
+ HloInstruction* scatter) {
+ HloInstruction* operand = scatter->mutable_operand(0);
+ HloInstruction* scatter_indices = scatter->mutable_operand(1);
+ HloInstruction* updates = scatter->mutable_operand(2);
+ const ScatterDimensionNumbers& dim_numbers =
+ scatter->scatter_dimension_numbers();
+
+ // If the updates tensor is empty, there is no need to update the operand. We
+ // can return the operand as is.
+ if (ShapeUtil::IsZeroElementArray(updates->shape())) {
+ return operand;
+ }
+
+ // Compute the trip count for the while loop to be used for scatter. This is
+ // the number of indices we scatter into the operand.
+ const Shape& scatter_indices_shape = scatter_indices->shape();
+ int64 scatter_loop_trip_count = 1;
+ for (int64 i = 0, e = scatter_indices_shape.dimensions_size(); i < e; i++) {
+ if (i != dim_numbers.index_vector_dim()) {
+ scatter_loop_trip_count *= scatter_indices_shape.dimensions(i);
+ }
+ }
+ if (!IsInt32(scatter_loop_trip_count)) {
+ return Unimplemented(
+ "Scatter operations with more than 2147483647 scatter indices are not "
+ "supported. This error occurred for %s.",
+ scatter->ToString().c_str());
+ }
+
+ // Canonicalize the scatter_indices, after which the size of its most-major
+ // dimension must be the same as the while loop trip count.
+ TF_ASSIGN_OR_RETURN(HloInstruction * canonical_scatter_indices,
+ CanonicalizeScatterIndices(
+ scatter_indices, dim_numbers.index_vector_dim()));
+ CHECK_EQ(scatter_loop_trip_count,
+ canonical_scatter_indices->shape().dimensions(0));
+
+ // Canonicalize the updates, after which the size of its most-major dimension
+ // must be the same as the while loop trip count.
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * canonical_updates,
+ PermuteScatterAndWindowDims(
+ updates, AsInt64Slice(dim_numbers.update_window_dims())));
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * adjusted_canonical_updates,
+ AdjustScatterDims(scatter_indices->shape(), canonical_updates,
+ dim_numbers.index_vector_dim()));
+ CHECK_EQ(scatter_loop_trip_count,
+ adjusted_canonical_updates->shape().dimensions(0));
+
+ // The while loop that implements the scatter operation.
+ StatusOr<std::vector<HloInstruction*>> scatter_loop_result_status =
+ WhileUtil::MakeCountedLoop(
+ scatter->parent(), scatter_loop_trip_count,
+ {operand, canonical_scatter_indices, adjusted_canonical_updates},
+ [&](HloInstruction* induction_var,
+ const std::vector<HloInstruction*>& loop_state) {
+ return ScatterLoopBody(scatter, induction_var, loop_state);
+ });
+ TF_ASSIGN_OR_RETURN(std::vector<HloInstruction*> scatter_loop_result,
+ scatter_loop_result_status);
+ return scatter_loop_result.front();
+}
+
+StatusOr<bool> ScatterExpander::Run(HloModule* module) {
+ std::vector<HloInstruction*> scatter_instrs;
+ for (HloComputation* computation : module->MakeNonfusionComputations()) {
+ for (HloInstruction* instr : computation->instructions()) {
+ if (instr->opcode() == HloOpcode::kScatter) {
+ scatter_instrs.push_back(instr);
+ }
+ }
+ }
+
+ for (auto instr : scatter_instrs) {
+ TF_ASSIGN_OR_RETURN(HloInstruction * expanded_root, ExpandScatter(instr));
+ TF_RETURN_IF_ERROR(
+ instr->parent()->ReplaceInstruction(instr, expanded_root));
+ }
+
+ return !scatter_instrs.empty();
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/scatter_expander.h b/tensorflow/compiler/xla/service/scatter_expander.h
new file mode 100644
index 0000000000..8f735e877d
--- /dev/null
+++ b/tensorflow/compiler/xla/service/scatter_expander.h
@@ -0,0 +1,34 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SCATTER_EXPANDER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_SCATTER_EXPANDER_H_
+
+#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
+
+namespace xla {
+
+class ScatterExpander : public HloPassInterface {
+ public:
+ tensorflow::StringPiece name() const override { return "scatter_expander"; }
+ StatusOr<bool> Run(HloModule* module) override;
+
+ private:
+ StatusOr<HloInstruction*> ExpandScatter(HloInstruction* scatter);
+};
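+
+// Typical usage (an illustrative sketch, not part of this change): the pass is
+// run like any other HLO pass, e.g. from an HloPassPipeline:
+//
+//   HloPassPipeline pipeline("expand-scatter");
+//   pipeline.AddPass<ScatterExpander>();
+//   TF_ASSIGN_OR_RETURN(bool changed, pipeline.Run(module));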
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_SCATTER_EXPANDER_H_
diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc
index 34869cc507..b69c346f1e 100644
--- a/tensorflow/compiler/xla/shape_util.cc
+++ b/tensorflow/compiler/xla/shape_util.cc
@@ -1014,12 +1014,13 @@ bool ShapeUtil::IsLeafIndex(const Shape& shape, const ShapeIndex& index) {
}
/* static */ int64 ShapeUtil::GetLeafCount(const Shape& shape) {
+ if (!IsTuple(shape)) {
+ return 1;
+ }
int64 count = 0;
- ForEachSubshape(shape, [&](const Shape&, const ShapeIndex& index) {
- if (IsLeafIndex(shape, index)) {
- ++count;
- }
- });
+ for (const Shape& subshape : shape.tuple_shapes()) {
+ count += GetLeafCount(subshape);
+ }
return count;
}
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index 42d52aee78..0f8cffd466 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -709,6 +709,21 @@ xla_test(
],
)
+xla_test(
+ name = "scatter_test",
+ srcs = ["scatter_test.cc"],
+ deps = [
+ ":client_library_test_base",
+ ":hlo_test_base",
+ "//tensorflow/compiler/xla:execution_options_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:test",
+ "//tensorflow/compiler/xla/client:xla_builder",
+ "//tensorflow/compiler/xla/service:hlo_parser",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ ],
+)
+
# Repeat dot_operation_runtime_test with single-threaded eigen.
xla_test(
name = "dot_operation_single_threaded_runtime_test",
@@ -2061,6 +2076,8 @@ tf_cc_test(
xla_test(
name = "test_utils_test",
srcs = ["test_utils_test.cc"],
+ # There is nothing backend-specific in this test, so just pick an arbitrary backend.
+ backends = ["cpu"],
deps = [
":local_client_test_base",
":test_utils",
diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h
index 4a6e8a3124..b04a3b105c 100644
--- a/tensorflow/compiler/xla/tests/client_library_test_base.h
+++ b/tensorflow/compiler/xla/tests/client_library_test_base.h
@@ -74,8 +74,9 @@ class ClientLibraryTestBase : public ::testing::Test {
string TestName() const;
void SetFastMathDisabled(bool disabled) {
- execution_options_.mutable_debug_options()->set_xla_enable_fast_math(
- !disabled);
+ auto* opts = execution_options_.mutable_debug_options();
+ opts->set_xla_cpu_enable_fast_math(!disabled);
+ opts->set_xla_gpu_enable_fast_math(!disabled);
}
void SetSeed(uint64 seed) { execution_options_.set_seed(seed); }
diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc
index 73edad89dc..92c93f08b2 100644
--- a/tensorflow/compiler/xla/tests/reduce_window_test.cc
+++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc
@@ -1464,5 +1464,24 @@ ENTRY %reduce-window (parameter.0: s32[81,8], parameter.1: s32[]) -> s32[82,8] {
EXPECT_TRUE(RunAndCompare(hlo_string, tensorflow::gtl::nullopt));
}
+XLA_TEST_F(HloTestBase, ReduceWindowF16) {
+ const string hlo_string = R"(
+HloModule reduce-window
+
+%identity.pad_to_reduce_window (param0: f16[], param1: f16[]) -> f16[] {
+ %param0 = f16[] parameter(0)
+ ROOT %param1 = f16[] parameter(1)
+}
+
+ENTRY %reduce-window (parameter.0: f16[81,8], parameter.1: f16[]) -> f16[82,8] {
+ %parameter.0 = f16[81,8]{1,0} parameter(0)
+ %parameter.1 = f16[] parameter(1)
+ ROOT %reduce-window = f16[82,8]{1,0} reduce-window(f16[81,8]{1,0} %parameter.0, f16[] %parameter.1), window={size=1x1 pad=0_1x0_0}, to_apply=%identity.pad_to_reduce_window
+}
+
+)";
+ EXPECT_TRUE(RunAndCompare(hlo_string, tensorflow::gtl::nullopt));
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/scatter_test.cc b/tensorflow/compiler/xla/tests/scatter_test.cc
new file mode 100644
index 0000000000..922d70b752
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/scatter_test.cc
@@ -0,0 +1,615 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+
+namespace xla {
+namespace {
+
+using tensorflow::gtl::nullopt;
+
+class ScatterTest : public HloTestBase {
+ protected:
+ void RunTest(const string& hlo_text, Literal* operand,
+ Literal* scatter_indices, Literal* updates) {
+ RunTest(hlo_text, {operand, scatter_indices, updates});
+ }
+
+ void RunTest(const string& hlo_text,
+ tensorflow::gtl::ArraySlice<Literal*> args) {
+ HloModuleConfig config;
+ config.set_debug_options(GetDebugOptionsForTest());
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(hlo_text, config));
+ EXPECT_TRUE(RunAndCompare(std::move(module), args, nullopt));
+ }
+};
+
+XLA_TEST_F(ScatterTest, TensorFlowScatterV1_Update) {
+ const string hlo_text = R"(
+HloModule TensorFlowScatterV1
+
+update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
+ lhs = s32[] parameter(0)
+ ROOT rhs = s32[] parameter(1)
+}
+
+ENTRY main {
+ operand = s32[3,3] parameter(0)
+ indices = s32[2] parameter(1)
+ updates = s32[2,3] parameter(2)
+ ROOT scatter = s32[3,3] scatter(operand, indices, updates),
+ to_apply=update_s32,
+ update_window_dims={1},
+ inserted_window_dims={0},
+ scatter_dims_to_operand_dims={0},
+ index_vector_dim=1
+}
+)";
+ std::unique_ptr<Literal> operand =
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ std::unique_ptr<Literal> scatter_indices =
+ LiteralUtil::CreateR1<int32>({0, 2});
+ std::unique_ptr<Literal> updates =
+ LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
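+ // With "update" semantics and indices {0, 2}, rows 0 and 2 are replaced, so
+ // the expected result (for illustration) is
+ // {{10, 20, 30}, {4, 5, 6}, {70, 80, 90}}.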
+ RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+}
+
+XLA_TEST_F(ScatterTest, TensorFlowScatterV2_Update) {
+ const char* hlo_text = R"(
+HloModule TensorFlowScatterV2
+
+update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
+ lhs = s32[] parameter(0)
+ ROOT rhs = s32[] parameter(1)
+}
+
+ENTRY main {
+ operand = s32[3,3] parameter(0)
+ indices = s32[2] parameter(1)
+ updates = s32[3,2] parameter(2)
+ ROOT scatter = s32[3,3] scatter(operand, indices, updates),
+ to_apply=update_s32,
+ update_window_dims={0},
+ inserted_window_dims={1},
+ scatter_dims_to_operand_dims={1},
+ index_vector_dim=1
+}
+)";
+ std::unique_ptr<Literal> operand =
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ std::unique_ptr<Literal> scatter_indices =
+ LiteralUtil::CreateR1<int32>({0, 2});
+ std::unique_ptr<Literal> updates =
+ LiteralUtil::CreateR2<int32>({{10, 30}, {40, 60}, {70, 90}});
+ RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+}
+
+XLA_TEST_F(ScatterTest, TensorFlowScatter_Add) {
+ const string hlo_text = R"(
+HloModule TensorFlowScatter_Add
+
+add_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
+ lhs = s32[] parameter(0)
+ rhs = s32[] parameter(1)
+ ROOT add = s32[] add(s32[] lhs, s32[] rhs)
+}
+
+ENTRY main {
+ operand = s32[3,3] parameter(0)
+ indices = s32[2] parameter(1)
+ updates = s32[2,3] parameter(2)
+ ROOT scatter = s32[3,3] scatter(operand, indices, updates),
+ to_apply=add_s32,
+ update_window_dims={1},
+ inserted_window_dims={0},
+ scatter_dims_to_operand_dims={0},
+ index_vector_dim=1
+}
+)";
+ std::unique_ptr<Literal> operand =
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ std::unique_ptr<Literal> scatter_indices =
+ LiteralUtil::CreateR1<int32>({0, 2});
+ std::unique_ptr<Literal> updates =
+ LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
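+ // With add semantics, the expected result (for illustration) is
+ // {{11, 22, 33}, {4, 5, 6}, {77, 88, 99}}.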
+ RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+}
+
+XLA_TEST_F(ScatterTest, TensorFlowScatter_Mul) {
+ const string hlo_text = R"(
+HloModule TensorFlowScatter_Mul
+
+mul_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
+ lhs = s32[] parameter(0)
+ rhs = s32[] parameter(1)
+ ROOT mul = s32[] multiply(s32[] lhs, s32[] rhs)
+}
+
+ENTRY main {
+ operand = s32[3,3] parameter(0)
+ indices = s32[2] parameter(1)
+ updates = s32[2,3] parameter(2)
+ ROOT scatter = s32[3,3] scatter(operand, indices, updates),
+ to_apply=mul_s32,
+ update_window_dims={1},
+ inserted_window_dims={0},
+ scatter_dims_to_operand_dims={0},
+ index_vector_dim=1
+}
+)";
+ std::unique_ptr<Literal> operand =
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ std::unique_ptr<Literal> scatter_indices =
+ LiteralUtil::CreateR1<int32>({0, 2});
+ std::unique_ptr<Literal> updates =
+ LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
+ RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+}
+
+XLA_TEST_F(ScatterTest, TensorFlowScatter_F32) {
+ const string hlo_text = R"(
+HloModule TensorFlowScatter_F32
+
+add_f32 (lhs: f32[], rhs: f32[]) -> f32[] {
+ lhs = f32[] parameter(0)
+ rhs = f32[] parameter(1)
+ ROOT add = f32[] add(f32[] lhs, f32[] rhs)
+}
+
+ENTRY main {
+ operand = f32[3,3] parameter(0)
+ indices = s32[2] parameter(1)
+ updates = f32[2,3] parameter(2)
+ ROOT scatter = f32[3,3] scatter(operand, indices, updates),
+ to_apply=add_f32,
+ update_window_dims={1},
+ inserted_window_dims={0},
+ scatter_dims_to_operand_dims={0},
+ index_vector_dim=1
+}
+)";
+ std::unique_ptr<Literal> operand = LiteralUtil::CreateR2<float>(
+ {{1.1, 2.2, 3.3}, {4.4, 5.5, 6.6}, {7.7, 8.8, 9.9}});
+ std::unique_ptr<Literal> scatter_indices =
+ LiteralUtil::CreateR1<int32>({2, 1});
+ std::unique_ptr<Literal> updates =
+ LiteralUtil::CreateR2<float>({{0.4, 1.1, 0.7}, {2.3, 3.1, 1.6}});
+ RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+}
+
+XLA_TEST_F(ScatterTest, TensorFlowScatter_RepeatedIndices) {
+ const char* hlo_text = R"(
+HloModule TensorFlowScatter
+
+add_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
+ lhs = s32[] parameter(0)
+ rhs = s32[] parameter(1)
+ ROOT add = s32[] add(s32[] lhs, s32[] rhs)
+}
+
+ENTRY main {
+ operand = s32[3,3] parameter(0)
+ indices = s32[2] parameter(1)
+ updates = s32[2,3] parameter(2)
+ ROOT scatter = s32[3,3] scatter(operand, indices, updates),
+ to_apply=add_s32,
+ update_window_dims={1},
+ inserted_window_dims={0},
+ scatter_dims_to_operand_dims={0},
+ index_vector_dim=1
+}
+)";
+ std::unique_ptr<Literal> operand =
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ std::unique_ptr<Literal> scatter_indices =
+ LiteralUtil::CreateR1<int32>({1, 1});
+ std::unique_ptr<Literal> updates =
+ LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
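+ // Both indices target row 1, and the adds accumulate, so the expected
+ // result (for illustration) is {{1, 2, 3}, {84, 105, 126}, {7, 8, 9}}.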
+ RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+}
+
+XLA_TEST_F(ScatterTest, TensorFlowScatter_MultipleBatchDims) {
+ const char* hlo_text = R"(
+HloModule TensorFlowScatterMultipleBatchDims
+
+add_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
+ lhs = s32[] parameter(0)
+ rhs = s32[] parameter(1)
+ ROOT add = s32[] add(s32[] lhs, s32[] rhs)
+}
+
+ENTRY main {
+ operand = s32[3,3] parameter(0)
+ indices = s32[2,2] parameter(1)
+ updates = s32[2,3,2] parameter(2)
+ ROOT scatter = s32[3,3] scatter(operand, indices, updates),
+ to_apply=add_s32,
+ update_window_dims={1},
+ inserted_window_dims={1},
+ scatter_dims_to_operand_dims={1},
+ index_vector_dim=2
+}
+)";
+ std::unique_ptr<Literal> operand =
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ std::unique_ptr<Literal> scatter_indices =
+ LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}});
+ std::unique_ptr<Literal> updates = LiteralUtil::CreateR3<int32>(
+ {{{10, 30}, {40, 60}, {70, 90}}, {{5, 5}, {5, 5}, {5, 5}}});
+ RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+}
+
+XLA_TEST_F(ScatterTest, TensorFlowScatterNd) {
+ const char* hlo_text = R"(
+HloModule TensorFlowScatterNd
+
+update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
+ lhs = s32[] parameter(0)
+ ROOT rhs = s32[] parameter(1)
+}
+
+ENTRY main {
+ operand = s32[3,3,2] parameter(0)
+ indices = s32[2,2] parameter(1)
+ updates = s32[2,2] parameter(2)
+ ROOT scatter = s32[3,3,2] scatter(operand, indices, updates),
+ to_apply=update_s32,
+ update_window_dims={1},
+ inserted_window_dims={0,1},
+ scatter_dims_to_operand_dims={0,1},
+ index_vector_dim=1
+}
+)";
+ std::unique_ptr<Literal> operand =
+ LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
+ {{-4, 4}, {-5, 5}, {-6, 6}}, //
+ {{-7, 7}, {-8, 8}, {-9, 9}}});
+ std::unique_ptr<Literal> scatter_indices =
+ LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
+ std::unique_ptr<Literal> updates =
+ LiteralUtil::CreateR2<int32>({{-10, 10}, {-40, 40}});
+ RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+}
+
+XLA_TEST_F(ScatterTest, TensorFlowScatterNd_NonDefaultIndexVectorDim) {
+ const char* hlo_text = R"(
+HloModule TensorFlowScatterNdNonDefaultIndexVectorDim
+
+update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
+ lhs = s32[] parameter(0)
+ ROOT rhs = s32[] parameter(1)
+}
+
+ENTRY main {
+ operand = s32[3,3,2] parameter(0)
+ indices = s32[2,2] parameter(1)
+ updates = s32[2,2] parameter(2)
+ ROOT scatter = s32[3,3,2] scatter(operand, indices, updates),
+ to_apply=update_s32,
+ update_window_dims={1},
+ inserted_window_dims={0,1},
+ scatter_dims_to_operand_dims={0,1},
+ index_vector_dim=0
+}
+)";
+ std::unique_ptr<Literal> operand =
+ LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
+ {{-4, 4}, {-5, 5}, {-6, 6}}, //
+ {{-7, 7}, {-8, 8}, {-9, 9}}});
+ std::unique_ptr<Literal> scatter_indices =
+ LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
+ std::unique_ptr<Literal> updates =
+ LiteralUtil::CreateR2<int32>({{-10, 10}, {-20, 20}});
+ RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+}
+
+XLA_TEST_F(ScatterTest, DynamicUpdateSlice) {
+ const char* hlo_text = R"(
+HloModule DynamicUpdateSlice
+
+update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
+ lhs = s32[] parameter(0)
+ ROOT rhs = s32[] parameter(1)
+}
+
+ENTRY main {
+ operand = s32[3,3] parameter(0)
+ indices = s32[2] parameter(1)
+ updates = s32[1,1] parameter(2)
+ ROOT scatter = s32[3,3] scatter(operand, indices, updates),
+ to_apply=update_s32,
+ update_window_dims={0,1},
+ inserted_window_dims={},
+ scatter_dims_to_operand_dims={0,1},
+ index_vector_dim=0
+}
+)";
+ std::unique_ptr<Literal> operand =
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ std::unique_ptr<Literal> scatter_indices =
+ LiteralUtil::CreateR1<int32>({1, 1});
+ std::unique_ptr<Literal> updates = LiteralUtil::CreateR2<int32>({{10}});
+ RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+}
+
+XLA_TEST_F(ScatterTest, BatchDynamicUpdateSlice) {
+ const char* hlo_text = R"(
+HloModule BatchDynamicUpdateSlice
+
+update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
+ lhs = s32[] parameter(0)
+ ROOT rhs = s32[] parameter(1)
+}
+
+ENTRY main {
+ operand = s32[3,3] parameter(0)
+ indices = s32[2,2] parameter(1)
+ updates = s32[2,1,1] parameter(2)
+ ROOT scatter = s32[3,3] scatter(operand, indices, updates),
+ to_apply=update_s32,
+ update_window_dims={1,2},
+ inserted_window_dims={},
+ scatter_dims_to_operand_dims={0,1},
+ index_vector_dim=0
+}
+)";
+ std::unique_ptr<Literal> operand =
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ std::unique_ptr<Literal> scatter_indices =
+ LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}});
+ std::unique_ptr<Literal> updates =
+ LiteralUtil::CreateR3<int32>({{{10}}, {{20}}});
+ RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+}
+
+XLA_TEST_F(ScatterTest, ZeroDimBounds) {
+ const char* hlo_text = R"(
+HloModule TensorFlowScatter_ZeroDimBounds
+
+update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
+ lhs = s32[] parameter(0)
+ ROOT rhs = s32[] parameter(1)
+}
+
+ENTRY main {
+ operand = s32[3,0] parameter(0)
+ indices = s32[2] parameter(1)
+ updates = s32[2,0] parameter(2)
+ ROOT scatter = s32[3,0] scatter(operand, indices, updates),
+ to_apply=update_s32,
+ update_window_dims={1},
+ inserted_window_dims={0},
+ scatter_dims_to_operand_dims={0},
+ index_vector_dim=1
+}
+)";
+ std::unique_ptr<Literal> operand = LiteralUtil::CreateR2<int32>({{}, {}, {}});
+ std::unique_ptr<Literal> scatter_indices =
+ LiteralUtil::CreateR1<int32>({0, 2});
+ std::unique_ptr<Literal> updates = LiteralUtil::CreateR2<int32>({{}, {}});
+ RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+}
+
+XLA_TEST_F(ScatterTest, NoUpdateWindowDims) {
+ const string hlo_text = R"(
+HloModule Scatter_NoUpdateWindowDims
+
+add_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
+ lhs = s32[] parameter(0)
+ rhs = s32[] parameter(1)
+ ROOT add = s32[] add(s32[] lhs, s32[] rhs)
+}
+
+ENTRY main {
+ operand = s32[3] parameter(0)
+ indices = s32[2,2,1] parameter(1)
+ updates = s32[2,2] parameter(2)
+ ROOT scatter = s32[3] scatter(operand, indices, updates),
+ to_apply=add_s32,
+ update_window_dims={},
+ inserted_window_dims={0},
+ scatter_dims_to_operand_dims={0},
+ index_vector_dim=2
+}
+)";
+ std::unique_ptr<Literal> operand = LiteralUtil::CreateR1<int32>({0, 1, 2});
+ std::unique_ptr<Literal> scatter_indices =
+ LiteralUtil::CreateR3<int32>({{{0}, {1}}, {{2}, {1}}});
+ std::unique_ptr<Literal> updates =
+ LiteralUtil::CreateR2<int32>({{10, 20}, {30, 40}});
+ RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+}
+
+XLA_TEST_F(ScatterTest, OutOfBoundsIndex) {
+ const string hlo_text = R"(
+HloModule BatchDynamicSlice
+
+update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
+ lhs = s32[] parameter(0)
+ ROOT rhs = s32[] parameter(1)
+}
+
+ENTRY main {
+ operand = s32[3,3]{1,0} parameter(0)
+ indices = s32[6,2]{1,0} parameter(1)
+ updates = s32[6,1,1]{2,1,0} parameter(2)
+ ROOT scatter = s32[3,3]{1,0} scatter(operand, indices, updates),
+ to_apply=update_s32,
+ update_window_dims={1,2},
+ inserted_window_dims={},
+ scatter_dims_to_operand_dims={0,1},
+ index_vector_dim=1
+}
+)";
+ std::unique_ptr<Literal> operand =
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ std::unique_ptr<Literal> scatter_indices = LiteralUtil::CreateR2<int32>(
+ {{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483647, 1}, {1, 2}});
+ std::unique_ptr<Literal> updates = LiteralUtil::CreateR3<int32>(
+ {{{10}}, {{20}}, {{30}}, {{40}}, {{50}}, {{60}}});
+ RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+}
+
+XLA_TEST_F(ScatterTest, OutOfBoundsUnsignedIndex) {
+ const string hlo_text = R"(
+HloModule BatchDynamicSlice
+
+update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
+ lhs = s32[] parameter(0)
+ ROOT rhs = s32[] parameter(1)
+}
+
+ENTRY main {
+ operand = s32[3,3]{1,0} parameter(0)
+ indices = u32[6,2]{1,0} parameter(1)
+ updates = s32[6,1,1]{2,1,0} parameter(2)
+ ROOT scatter = s32[3,3]{1,0} scatter(operand, indices, updates),
+ to_apply=update_s32,
+ update_window_dims={1,2},
+ inserted_window_dims={},
+ scatter_dims_to_operand_dims={0,1},
+ index_vector_dim=1
+}
+)";
+ std::unique_ptr<Literal> operand =
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ std::unique_ptr<Literal> scatter_indices = LiteralUtil::CreateR2<uint32>(
+ {{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483648u, 1}, {1, 2}});
+ std::unique_ptr<Literal> updates = LiteralUtil::CreateR3<int32>(
+ {{{10}}, {{20}}, {{30}}, {{40}}, {{50}}, {{60}}});
+ RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+}
+
+XLA_TEST_F(ScatterTest, NegativeIndex) {
+ const string hlo_text = R"(
+HloModule BatchDynamicSlice
+
+update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
+ lhs = s32[] parameter(0)
+ ROOT rhs = s32[] parameter(1)
+}
+
+ENTRY main {
+ operand = s32[3,3]{1,0} parameter(0)
+ indices = s32[6,2]{1,0} parameter(1)
+ updates = s32[6,1,1]{2,1,0} parameter(2)
+ ROOT scatter = s32[3,3]{1,0} scatter(operand, indices, updates),
+ to_apply=update_s32,
+ update_window_dims={1,2},
+ inserted_window_dims={},
+ scatter_dims_to_operand_dims={0,1},
+ index_vector_dim=1
+}
+)";
+ std::unique_ptr<Literal> operand =
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ std::unique_ptr<Literal> scatter_indices = LiteralUtil::CreateR2<int32>(
+ {{2, 7}, {2, 1}, {1, 1}, {-500, 1}, {-2147483648, 1}, {1, 2}});
+ std::unique_ptr<Literal> updates = LiteralUtil::CreateR3<int32>(
+ {{{10}}, {{20}}, {{30}}, {{40}}, {{50}}, {{60}}});
+ RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+}
+
+XLA_TEST_F(ScatterTest, OneScalarIndex) {
+ const char* hlo_text = R"(
+HloModule OneScalarIndex
+
+update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
+ lhs = s32[] parameter(0)
+ ROOT rhs = s32[] parameter(1)
+}
+
+ENTRY main {
+ operand = s32[2,3,2]{2,1,0} parameter(0)
+ index = s32[] parameter(1)
+ updates = s32[1,3,2]{2,1,0} parameter(2)
+ ROOT scatter = s32[2,3,2]{2,1,0} scatter(operand, index, updates),
+ to_apply=update_s32,
+ update_window_dims={0,1,2},
+ inserted_window_dims={},
+ scatter_dims_to_operand_dims={0},
+ index_vector_dim=0
+}
+)";
+ std::unique_ptr<Literal> operand = LiteralUtil::CreateR3<int32>(
+ {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}});
+ std::unique_ptr<Literal> scatter_indices = LiteralUtil::CreateR0<int32>(1);
+ std::unique_ptr<Literal> updates =
+ LiteralUtil::CreateR3<int32>({{{10, 20}, {30, 40}, {50, 60}}});
+ RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+}
+
+XLA_TEST_F(ScatterTest, ScalarUpdate) {
+ const char* hlo_text = R"(
+HloModule ScalarUpdate
+
+update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
+ lhs = s32[] parameter(0)
+ ROOT rhs = s32[] parameter(1)
+}
+
+ENTRY main {
+ operand = s32[4]{0} parameter(0)
+ index = s32[] parameter(1)
+ updates = s32[] parameter(2)
+ ROOT scatter = s32[4]{0} scatter(operand, index, updates),
+ to_apply=update_s32,
+ update_window_dims={},
+ inserted_window_dims={0},
+ scatter_dims_to_operand_dims={0},
+ index_vector_dim=0
+}
+)";
+ std::unique_ptr<Literal> operand = LiteralUtil::CreateR1<int32>({1, 2, 3, 4});
+ std::unique_ptr<Literal> scatter_indices = LiteralUtil::CreateR0<int32>(1);
+ std::unique_ptr<Literal> updates = LiteralUtil::CreateR0<int32>(25);
+ RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+}
+
+XLA_TEST_F(ScatterTest, EmptyIndices) {
+ const string hlo_text = R"(
+HloModule EmptyIndices
+
+update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
+ lhs = s32[] parameter(0)
+ ROOT rhs = s32[] parameter(1)
+}
+
+ENTRY main {
+ operand = s32[3] parameter(0)
+ indices = s32[0] parameter(1)
+ updates = s32[0] parameter(2)
+ ROOT scatter = s32[3] scatter(operand, indices, updates),
+ to_apply=update_s32,
+ update_window_dims={},
+ inserted_window_dims={0},
+ scatter_dims_to_operand_dims={0},
+ index_vector_dim=1
+}
+)";
+ std::unique_ptr<Literal> operand = LiteralUtil::CreateR1<int32>({1, 2, 3});
+ std::unique_ptr<Literal> scatter_indices = LiteralUtil::CreateR1<int32>({});
+ std::unique_ptr<Literal> updates = LiteralUtil::CreateR1<int32>({});
+ RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+}
+
+} // namespace
+} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc
index 2647937013..faeec657b6 100644
--- a/tensorflow/compiler/xla/tests/test_utils.cc
+++ b/tensorflow/compiler/xla/tests/test_utils.cc
@@ -208,16 +208,12 @@ bool NeedsInitValue(const HloUse& use) {
// Generate random values that are constrained to the input_shape minus the
// output_shape so as not to produce wrapping slices, for instance.
-std::unique_ptr<Literal> MakeRandomNonwrappingSliceIndex(
- const Shape& input_shape, const Shape& slice_shape,
- std::minstd_rand0* engine) {
- const int64 rank = ShapeUtil::Rank(input_shape);
- std::vector<int32> start_indices(rank);
+std::unique_ptr<Literal> MakeRandomIndex(
+ tensorflow::gtl::ArraySlice<int64> index_space, std::minstd_rand0* engine) {
+ std::vector<int32> start_indices(index_space.size());
if (engine != nullptr) {
- for (int i = 0; i < rank; ++i) {
- const int32 upper_bound = ShapeUtil::GetDimension(input_shape, i) -
- ShapeUtil::GetDimension(slice_shape, i);
- std::uniform_int_distribution<int32> generator(0, upper_bound);
+ for (int i = 0; i < index_space.size(); ++i) {
+ std::uniform_int_distribution<int32> generator(0, index_space[i]);
start_indices[i] = generator(*engine);
}
}
@@ -267,37 +263,42 @@ std::vector<HloInstruction*> FindConstrainedUses(
StatusOr<std::unique_ptr<Literal>> CreateLiteralForConstrainedUses(
const tensorflow::gtl::ArraySlice<HloInstruction*> constrained_uses,
const HloInstruction& param, std::minstd_rand0* engine) {
- HloInstruction* needs_index = nullptr;
- HloInstruction* needs_constant = nullptr;
+ std::vector<int64> index_space;
+ bool needs_constant = false;
ConstantType constant_type = ConstantType::kUnknown;
for (HloInstruction* use : constrained_uses) {
switch (use->opcode()) {
case HloOpcode::kDynamicSlice:
- case HloOpcode::kDynamicUpdateSlice:
- if (needs_index != nullptr) {
- auto needs_index_shape = needs_index->shape();
- auto use_shape = use->shape();
- if (needs_index->opcode() == HloOpcode::kDynamicSlice) {
- needs_index_shape = needs_index->operand(0)->shape();
- }
- if (use->opcode() == HloOpcode::kDynamicSlice) {
- use_shape = use->operand(0)->shape();
+ case HloOpcode::kDynamicUpdateSlice: {
+ const Shape& indexed_shape = use->operand(0)->shape();
+ const Shape& slice_shape = use->opcode() == HloOpcode::kDynamicSlice
+ ? use->shape()
+ : use->operand(1)->shape();
+ const int64 rank = ShapeUtil::Rank(indexed_shape);
+ if (!index_space.empty()) {
+ TF_RET_CHECK(rank == index_space.size());
+ for (int64 i = 0; i < rank; ++i) {
+ index_space[i] = std::min(
+ index_space[i], ShapeUtil::GetDimension(indexed_shape, i) -
+ ShapeUtil::GetDimension(slice_shape, i));
}
- if (!ShapeUtil::Equal(needs_index_shape, use_shape)) {
- return Unimplemented(
- "Conflicting operand generation slice index constraints\n");
+ } else {
+ index_space.resize(rank);
+ for (int64 i = 0; i < rank; ++i) {
+ index_space[i] = ShapeUtil::GetDimension(indexed_shape, i) -
+ ShapeUtil::GetDimension(slice_shape, i);
}
}
- needs_index = use;
break;
+ }
case HloOpcode::kReduce:
case HloOpcode::kReduceWindow:
- needs_constant = use;
+ needs_constant = true;
constant_type = GetInitValue(*use->to_apply());
break;
case HloOpcode::kSelectAndScatter:
- needs_constant = use;
+ needs_constant = true;
constant_type = GetInitValue(*use->scatter());
break;
@@ -307,16 +308,14 @@ StatusOr<std::unique_ptr<Literal>> CreateLiteralForConstrainedUses(
use->ToString().c_str());
}
}
- if (needs_index != nullptr && needs_constant != nullptr) {
+ if (!index_space.empty() && needs_constant) {
return Unimplemented(
- "Conflicting operand generation constraints.\nNeeds index: %s\nNeeds "
- "constant: %s\n",
- needs_index->ToString().c_str(), needs_constant->ToString().c_str());
+ "Conflicting operand generation constraints. Dynamically indexes a "
+ "shape and is the init value of a reduction.");
}
- if (needs_index != nullptr) {
- return MakeRandomNonwrappingSliceIndex(needs_index->operand(0)->shape(),
- needs_index->shape(), engine);
- } else if (needs_constant != nullptr) {
+ if (!index_space.empty()) {
+ return MakeRandomIndex(index_space, engine);
+ } else if (needs_constant) {
switch (constant_type) {
case ConstantType::kZero:
return LiteralUtil::Zero(param.shape().element_type()).CloneToUnique();
@@ -356,8 +355,8 @@ StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments(
auto engine = pseudo_random ? MakeUnique<std::minstd_rand0>() : nullptr;
std::vector<std::unique_ptr<Literal>> arguments(params.size());
for (int i = 0; i < params.size(); ++i) {
- TF_ASSIGN_OR_RETURN(arguments[i], MakeConstrainedArgument(
- *dataflow, *params[i], engine.get()));
+ arguments[i] = MakeConstrainedArgument(*dataflow, *params[i], engine.get())
+ .ValueOrDie();
}
return std::move(arguments);
}
diff --git a/tensorflow/compiler/xla/tests/test_utils_test.cc b/tensorflow/compiler/xla/tests/test_utils_test.cc
index a2f0338e25..64d9e2031e 100644
--- a/tensorflow/compiler/xla/tests/test_utils_test.cc
+++ b/tensorflow/compiler/xla/tests/test_utils_test.cc
@@ -72,5 +72,60 @@ XLA_TEST_F(TestUtilsTest, Token) {
TF_ASSERT_OK(MakeFakeArguments(module.get()).status());
}
+XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicSlices) {
+ auto module = ParseHloString(
+ R"(HloModule index_space_module
+
+ ENTRY IndexSpace {
+ index_param = s32[3]{0} parameter(0)
+ array_param.1 = f32[123,4,789]{0,1,2} parameter(1)
+ array_param.2 = f32[3,3000,5]{0,1,2} parameter(2)
+ dynamic-slice.1 = f32[1,2,3] dynamic-slice(array_param.1, index_param), dynamic_slice_sizes={1,2,3}
+ ROOT dynamic-slice.2 = f32[3,2,2] dynamic-slice(array_param.2, index_param), dynamic_slice_sizes={3,2,2}
+ })")
+ .ValueOrDie();
+ TF_ASSERT_OK_AND_ASSIGN(std::vector<std::unique_ptr<Literal>> args,
+ MakeFakeArguments(module.get()));
+ ASSERT_EQ(args.size(), 3);
+ const Literal& index_arg = *args[0];
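+
+ // The index is drawn from the elementwise intersection (min) of the two
+ // slices' index spaces (illustrative): dim0 min(123-1, 3-3) = 0,
+ // dim1 min(4-2, 3000-2) = 2, dim2 min(789-3, 5-2) = 3, matching the checks
+ // below.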
+
+ EXPECT_EQ(index_arg.Get<int32>({0}), 0);
+
+ EXPECT_GE(index_arg.Get<int32>({1}), 0);
+ EXPECT_LE(index_arg.Get<int32>({1}), 2);
+
+ EXPECT_GE(index_arg.Get<int32>({2}), 0);
+ EXPECT_LE(index_arg.Get<int32>({2}), 3);
+}
+
+XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicUpdateSlices) {
+ auto module = ParseHloString(
+ R"(HloModule index_space_module
+
+ ENTRY IndexSpace {
+ index_param = s32[3]{0} parameter(0)
+ array_param.1 = f32[123,4,789]{0,1,2} parameter(1)
+ array_param.2 = f32[3,3000,5]{0,1,2} parameter(2)
+ update_param.1 = f32[1,2,3]{0,1,2} parameter(3)
+ update_param.2 = f32[3,2,2]{0,1,2} parameter(4)
+
+ dynamic-update-slice.1 = f32[123,4,789] dynamic-update-slice(array_param.1, update_param.1, index_param)
+ ROOT dynamic-update-slice.2 = f32[3,3000,5] dynamic-update-slice(array_param.2, update_param.2, index_param)
+ })")
+ .ValueOrDie();
+ TF_ASSERT_OK_AND_ASSIGN(std::vector<std::unique_ptr<Literal>> args,
+ MakeFakeArguments(module.get()));
+ ASSERT_EQ(args.size(), 5);
+ const Literal& index_arg = *args[0];
+
+ EXPECT_EQ(index_arg.Get<int32>({0}), 0);
+
+ EXPECT_GE(index_arg.Get<int32>({1}), 0);
+ EXPECT_LE(index_arg.Get<int32>({1}), 2);
+
+ EXPECT_GE(index_arg.Get<int32>({2}), 0);
+ EXPECT_LE(index_arg.Get<int32>({2}), 3);
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto
index 10c0adc670..3b72eb17c6 100644
--- a/tensorflow/compiler/xla/xla.proto
+++ b/tensorflow/compiler/xla/xla.proto
@@ -104,15 +104,6 @@ message DebugOptions {
// interpretation of this value is left to the backends.
int32 xla_backend_optimization_level = 31;
- // When true, "unsafe" mathematical optimizations are enabled. These
- // transformations include but are not limited to:
- //
- // - Reducing the precision of operations (e.g. using an approximate sin
- // function, or transforming x/y into x * (1/y)).
- // - Assuming that operations never produce or consume NaN or +/- Inf.
- // - Assuming that +0 and -0 are indistinguishable.
- bool xla_enable_fast_math = 32;
-
// Embed the compiler IR as a string in the executable.
bool xla_embed_ir_in_executable = 33;
@@ -194,6 +185,16 @@ message DebugOptions {
// Maximum kernel unroll factor for the GPU backend.
int32 xla_gpu_max_kernel_unroll_factor = 98;
+ // When true, "unsafe" mathematical optimizations are enabled. These
+ // transformations include but are not limited to:
+ //
+ // - Reducing the precision of operations (e.g. using an approximate sin
+ // function, or transforming x/y into x * (1/y)).
+ // - Assuming that operations never produce or consume NaN or +/- Inf.
+ // - Assuming that +0 and -0 are indistinguishable.
+ bool xla_cpu_enable_fast_math = 99;
+ bool xla_gpu_enable_fast_math = 100;
+
// Extra options to pass to the compilation backend; specific interpretation
// of these values is left to the backend.
map<string, string> xla_backend_extra_options = 500;
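
The single fast-math knob is thus split into independent per-backend fields. A
minimal sketch of setting them from Python, assuming the generated xla_pb2
bindings for this proto are importable as shown:

from tensorflow.compiler.xla import xla_pb2

debug_options = xla_pb2.DebugOptions()
# The old xla_enable_fast_math (field 32) is retired; CPU and GPU are now
# toggled separately (fields 99 and 100).
debug_options.xla_cpu_enable_fast_math = True
debug_options.xla_gpu_enable_fast_math = False
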
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc b/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc
index 1790b4bc11..a25a641cdb 100644
--- a/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc
@@ -218,11 +218,11 @@ class ToBigtableOp : public AsyncOpKernel {
OP_REQUIRES_OK_ASYNC(
ctx, GetDatasetFromVariantTensor(ctx->input(1), &dataset), done);
- IteratorContext iter_ctx = dataset::MakeIteratorContext(ctx);
std::unique_ptr<IteratorBase> iterator;
OP_REQUIRES_OK_ASYNC(
ctx,
- dataset->MakeIterator(&iter_ctx, "ToBigtableOpIterator", &iterator),
+ dataset->MakeIterator(IteratorContext(ctx), "ToBigtableOpIterator",
+ &iterator),
done);
int64 timestamp_int;
@@ -245,9 +245,10 @@ class ToBigtableOp : public AsyncOpKernel {
::google::cloud::bigtable::BulkMutation mutation;
// TODO(saeta): Make # of mutations configurable.
for (uint64 i = 0; i < 100 && !end_of_sequence; ++i) {
- OP_REQUIRES_OK_ASYNC(
- ctx, iterator->GetNext(&iter_ctx, &components, &end_of_sequence),
- done);
+ OP_REQUIRES_OK_ASYNC(ctx,
+ iterator->GetNext(IteratorContext(ctx),
+ &components, &end_of_sequence),
+ done);
if (!end_of_sequence) {
OP_REQUIRES_OK_ASYNC(
ctx,
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc
index 9e49fa35db..bd32672aa9 100644
--- a/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc
@@ -53,7 +53,7 @@ class BigtableLookupDatasetOp : public UnaryDatasetOpKernel {
}
private:
- class Dataset : public GraphDatasetBase {
+ class Dataset : public DatasetBase {
public:
explicit Dataset(OpKernelContext* ctx, const DatasetBase* input,
BigtableTableResource* table,
@@ -61,7 +61,7 @@ class BigtableLookupDatasetOp : public UnaryDatasetOpKernel {
std::vector<string> columns,
const DataTypeVector& output_types,
std::vector<PartialTensorShape> output_shapes)
- : GraphDatasetBase(ctx),
+ : DatasetBase(DatasetContext(ctx)),
input_(input),
table_(table),
column_families_(std::move(column_families)),
@@ -80,8 +80,8 @@ class BigtableLookupDatasetOp : public UnaryDatasetOpKernel {
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
- return std::unique_ptr<IteratorBase>(new Iterator(
- {this, strings::StrCat(prefix, "::BigtableLookupDataset")}));
+ return std::unique_ptr<IteratorBase>(
+ new Iterator({this, strings::StrCat(prefix, "::BigtableLookup")}));
}
const DataTypeVector& output_dtypes() const override {
@@ -96,6 +96,14 @@ class BigtableLookupDatasetOp : public UnaryDatasetOpKernel {
return "BigtableLookupDatasetOp::Dataset";
}
+ protected:
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
+ Node** output) const override {
+ return errors::Unimplemented(DebugString(),
+ " does not support serialization");
+ }
+
private:
static ::google::cloud::bigtable::Filter MakeFilter(
const std::vector<string>& column_families,
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc
index e960719614..a803fdcb49 100644
--- a/tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc
@@ -35,11 +35,13 @@ class BigtablePrefixKeyDatasetOp : public DatasetOpKernel {
}
private:
- class Dataset : public GraphDatasetBase {
+ class Dataset : public DatasetBase {
public:
explicit Dataset(OpKernelContext* ctx, BigtableTableResource* table,
string prefix)
- : GraphDatasetBase(ctx), table_(table), prefix_(std::move(prefix)) {
+ : DatasetBase(DatasetContext(ctx)),
+ table_(table),
+ prefix_(std::move(prefix)) {
table_->Ref();
}
@@ -47,8 +49,8 @@ class BigtablePrefixKeyDatasetOp : public DatasetOpKernel {
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
- return std::unique_ptr<IteratorBase>(new Iterator(
- {this, strings::StrCat(prefix, "::BigtablePrefixKeyDataset")}));
+ return std::unique_ptr<IteratorBase>(
+ new Iterator({this, strings::StrCat(prefix, "::BigtablePrefixKey")}));
}
const DataTypeVector& output_dtypes() const override {
@@ -68,6 +70,14 @@ class BigtablePrefixKeyDatasetOp : public DatasetOpKernel {
BigtableTableResource* table() const { return table_; }
+ protected:
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
+ Node** output) const override {
+ return errors::Unimplemented(DebugString(),
+ " does not support serialization");
+ }
+
private:
class Iterator : public BigtableReaderDatasetIterator<Dataset> {
public:
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc
index 96d3565d9b..5cd0371c79 100644
--- a/tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc
@@ -39,11 +39,11 @@ class BigtableRangeKeyDatasetOp : public DatasetOpKernel {
}
private:
- class Dataset : public GraphDatasetBase {
+ class Dataset : public DatasetBase {
public:
explicit Dataset(OpKernelContext* ctx, BigtableTableResource* table,
string start_key, string end_key)
- : GraphDatasetBase(ctx),
+ : DatasetBase(DatasetContext(ctx)),
table_(table),
start_key_(std::move(start_key)),
end_key_(std::move(end_key)) {
@@ -54,8 +54,8 @@ class BigtableRangeKeyDatasetOp : public DatasetOpKernel {
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
- return std::unique_ptr<IteratorBase>(new Iterator(
- {this, strings::StrCat(prefix, "::BigtableRangeKeyDataset")}));
+ return std::unique_ptr<IteratorBase>(
+ new Iterator({this, strings::StrCat(prefix, "::BigtableRangeKey")}));
}
const DataTypeVector& output_dtypes() const override {
@@ -75,6 +75,14 @@ class BigtableRangeKeyDatasetOp : public DatasetOpKernel {
BigtableTableResource* table() const { return table_; }
+ protected:
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
+ Node** output) const override {
+ return errors::Unimplemented(DebugString(),
+ " does not support serialization");
+ }
+
private:
class Iterator : public BigtableReaderDatasetIterator<Dataset> {
public:
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc
index a1a63a975a..6928d9423c 100644
--- a/tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc
@@ -52,11 +52,11 @@ class BigtableSampleKeyPairsDatasetOp : public DatasetOpKernel {
}
private:
- class Dataset : public GraphDatasetBase {
+ class Dataset : public DatasetBase {
public:
explicit Dataset(OpKernelContext* ctx, BigtableTableResource* table,
string prefix, string start_key, string end_key)
- : GraphDatasetBase(ctx),
+ : DatasetBase(DatasetContext(ctx)),
table_(table),
key_range_(MakeMultiModeKeyRange(
std::move(prefix), std::move(start_key), std::move(end_key))) {
@@ -68,7 +68,7 @@ class BigtableSampleKeyPairsDatasetOp : public DatasetOpKernel {
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(new Iterator(
- {this, strings::StrCat(prefix, "::BigtableSampleKeyPairsDataset")}));
+ {this, strings::StrCat(prefix, "::BigtableSampleKeyPairs")}));
}
const DataTypeVector& output_dtypes() const override {
@@ -87,6 +87,14 @@ class BigtableSampleKeyPairsDatasetOp : public DatasetOpKernel {
return "BigtableSampleKeyPairsDatasetOp::Dataset";
}
+ protected:
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
+ Node** output) const override {
+ return errors::Unimplemented(DebugString(),
+ " does not support serialization");
+ }
+
private:
static MultiModeKeyRange MakeMultiModeKeyRange(string prefix,
string start_key,
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc
index a5a47cfe2d..a759fb5063 100644
--- a/tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc
@@ -31,10 +31,10 @@ class BigtableSampleKeysDatasetOp : public DatasetOpKernel {
}
private:
- class Dataset : public GraphDatasetBase {
+ class Dataset : public DatasetBase {
public:
explicit Dataset(OpKernelContext* ctx, BigtableTableResource* table)
- : GraphDatasetBase(ctx), table_(table) {
+ : DatasetBase(DatasetContext(ctx)), table_(table) {
table_->Ref();
}
@@ -43,7 +43,7 @@ class BigtableSampleKeysDatasetOp : public DatasetOpKernel {
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(new Iterator(
- {this, strings::StrCat(prefix, "::BigtableSampleKeysDataset")}));
+ {this, strings::StrCat(prefix, "::BigtableSampleKeys")}));
}
const DataTypeVector& output_dtypes() const override {
@@ -63,6 +63,14 @@ class BigtableSampleKeysDatasetOp : public DatasetOpKernel {
BigtableTableResource* table() const { return table_; }
+ protected:
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
+ Node** output) const override {
+ return errors::Unimplemented(DebugString(),
+ " does not support serialization");
+ }
+
private:
class Iterator : public DatasetIterator<Dataset> {
public:
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc
index 13cb868167..78a920b077 100644
--- a/tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc
@@ -84,7 +84,7 @@ class BigtableScanDatasetOp : public DatasetOpKernel {
}
private:
- class Dataset : public GraphDatasetBase {
+ class Dataset : public DatasetBase {
public:
explicit Dataset(OpKernelContext* ctx, BigtableTableResource* table,
string prefix, string start_key, string end_key,
@@ -92,7 +92,7 @@ class BigtableScanDatasetOp : public DatasetOpKernel {
std::vector<string> columns, float probability,
const DataTypeVector& output_types,
std::vector<PartialTensorShape> output_shapes)
- : GraphDatasetBase(ctx),
+ : DatasetBase(DatasetContext(ctx)),
table_(table),
prefix_(std::move(prefix)),
start_key_(std::move(start_key)),
@@ -111,8 +111,8 @@ class BigtableScanDatasetOp : public DatasetOpKernel {
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
- return std::unique_ptr<IteratorBase>(new Iterator(
- {this, strings::StrCat(prefix, "::BigtableScanDataset")}));
+ return std::unique_ptr<IteratorBase>(
+ new Iterator({this, strings::StrCat(prefix, "::BigtableScan")}));
}
const DataTypeVector& output_dtypes() const override {
@@ -129,6 +129,14 @@ class BigtableScanDatasetOp : public DatasetOpKernel {
BigtableTableResource* table() const { return table_; }
+ protected:
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
+ Node** output) const override {
+ return errors::Unimplemented(DebugString(),
+ " does not support serialization");
+ }
+
private:
class Iterator : public BigtableReaderDatasetIterator<Dataset> {
public:
diff --git a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc
index 401bec84a2..d9e7a0f466 100644
--- a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc
+++ b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc
@@ -34,7 +34,9 @@
namespace tensorflow {
+using boosted_trees::learner::LearnerConfig;
using boosted_trees::learner::LearnerConfig_MultiClassStrategy;
+using boosted_trees::learner::ObliviousSplitInfo;
using boosted_trees::learner::SplitInfo;
using boosted_trees::learner::stochastic::GradientStats;
using boosted_trees::learner::stochastic::NodeStats;
@@ -158,6 +160,11 @@ class BuildDenseInequalitySplitsOp : public OpKernel {
const Tensor* hessians_t;
OP_REQUIRES_OK(context, context->input("hessians", &hessians_t));
+ const Tensor* weak_learner_type_t;
+ OP_REQUIRES_OK(context,
+ context->input("weak_learner_type", &weak_learner_type_t));
+ const int32 weak_learner_type = weak_learner_type_t->scalar<int32>()();
+
// Find the number of unique partitions before we allocate the output.
std::vector<int32> partition_boundaries;
partition_boundaries.push_back(0);
@@ -188,20 +195,59 @@ class BuildDenseInequalitySplitsOp : public OpKernel {
tensorflow::TTypes<int32>::Vec output_partition_ids =
output_partition_ids_t->vec<int32>();
- Tensor* gains_t = nullptr;
- OP_REQUIRES_OK(
- context, context->allocate_output("gains", TensorShape({num_elements}),
- &gains_t));
+ // For a normal tree, we output one split per partition. For an oblivious
+ // tree, we output a single split shared by all partitions of the layer.
+ int32 size_output = num_elements;
+ if (weak_learner_type == LearnerConfig::OBLIVIOUS_DECISION_TREE &&
+ num_elements > 0) {
+ size_output = 1;
+ }
+ Tensor* gains_t = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output(
+ "gains", TensorShape({size_output}), &gains_t));
tensorflow::TTypes<float>::Vec gains = gains_t->vec<float>();
Tensor* output_splits_t = nullptr;
- OP_REQUIRES_OK(context, context->allocate_output(
- "split_infos", TensorShape({num_elements}),
- &output_splits_t));
+ OP_REQUIRES_OK(context, context->allocate_output("split_infos",
+ TensorShape({size_output}),
+ &output_splits_t));
tensorflow::TTypes<string>::Vec output_splits =
output_splits_t->vec<string>();
+
+ if (num_elements == 0) {
+ return;
+ }
SplitBuilderState state(context);
+ switch (weak_learner_type) {
+ case LearnerConfig::NORMAL_DECISION_TREE: {
+ ComputeNormalDecisionTree(
+ &state, normalizer_ratio, num_elements, partition_boundaries,
+ bucket_boundaries, partition_ids, bucket_ids, gradients_t,
+ hessians_t, &output_partition_ids, &gains, &output_splits);
+ break;
+ }
+ case LearnerConfig::OBLIVIOUS_DECISION_TREE: {
+ ComputeObliviousDecisionTree(
+ &state, normalizer_ratio, num_elements, partition_boundaries,
+ bucket_boundaries, partition_ids, bucket_ids, gradients_t,
+ hessians_t, &output_partition_ids, &gains, &output_splits);
+ break;
+ }
+ }
+ }
+
+ private:
+ void ComputeNormalDecisionTree(
+ SplitBuilderState* state, const float normalizer_ratio,
+ const int num_elements, const std::vector<int32>& partition_boundaries,
+ const tensorflow::TTypes<float>::ConstVec& bucket_boundaries,
+ const tensorflow::TTypes<int32>::ConstVec& partition_ids,
+ const tensorflow::TTypes<int64>::ConstMatrix& bucket_ids,
+ const Tensor* gradients_t, const Tensor* hessians_t,
+ tensorflow::TTypes<int32>::Vec* output_partition_ids,
+ tensorflow::TTypes<float>::Vec* gains,
+ tensorflow::TTypes<string>::Vec* output_splits) {
for (int root_idx = 0; root_idx < num_elements; ++root_idx) {
float best_gain = std::numeric_limits<float>::lowest();
int start_index = partition_boundaries[root_idx];
@@ -213,7 +259,7 @@ class BuildDenseInequalitySplitsOp : public OpKernel {
GradientStats(*gradients_t, *hessians_t, bucket_idx);
}
root_gradient_stats *= normalizer_ratio;
- NodeStats root_stats = state.ComputeNodeStats(root_gradient_stats);
+ NodeStats root_stats = state->ComputeNodeStats(root_gradient_stats);
int32 best_bucket_idx = 0;
NodeStats best_right_node_stats(0);
NodeStats best_left_node_stats(0);
@@ -223,10 +269,10 @@ class BuildDenseInequalitySplitsOp : public OpKernel {
GradientStats g(*gradients_t, *hessians_t, bucket_idx);
g *= normalizer_ratio;
left_gradient_stats += g;
- NodeStats left_stats = state.ComputeNodeStats(left_gradient_stats);
+ NodeStats left_stats = state->ComputeNodeStats(left_gradient_stats);
GradientStats right_gradient_stats =
root_gradient_stats - left_gradient_stats;
- NodeStats right_stats = state.ComputeNodeStats(right_gradient_stats);
+ NodeStats right_stats = state->ComputeNodeStats(right_gradient_stats);
if (left_stats.gain + right_stats.gain > best_gain) {
best_gain = left_stats.gain + right_stats.gain;
best_left_node_stats = left_stats;
@@ -237,20 +283,124 @@ class BuildDenseInequalitySplitsOp : public OpKernel {
SplitInfo split_info;
auto* dense_split =
split_info.mutable_split_node()->mutable_dense_float_binary_split();
- dense_split->set_feature_column(state.feature_column_group_id());
+ dense_split->set_feature_column(state->feature_column_group_id());
dense_split->set_threshold(
bucket_boundaries(bucket_ids(best_bucket_idx, 0)));
auto* left_child = split_info.mutable_left_child();
auto* right_child = split_info.mutable_right_child();
- state.FillLeaf(best_left_node_stats, left_child);
- state.FillLeaf(best_right_node_stats, right_child);
- split_info.SerializeToString(&output_splits(root_idx));
- gains(root_idx) =
- best_gain - root_stats.gain - state.tree_complexity_regularization();
- output_partition_ids(root_idx) = partition_ids(start_index);
+ state->FillLeaf(best_left_node_stats, left_child);
+ state->FillLeaf(best_right_node_stats, right_child);
+ split_info.SerializeToString(&(*output_splits)(root_idx));
+ (*gains)(root_idx) =
+ best_gain - root_stats.gain - state->tree_complexity_regularization();
+ (*output_partition_ids)(root_idx) = partition_ids(start_index);
+ }
+ }
+ void ComputeObliviousDecisionTree(
+ SplitBuilderState* state, const float normalizer_ratio,
+ const int num_elements, const std::vector<int32>& partition_boundaries,
+ const tensorflow::TTypes<float>::ConstVec& bucket_boundaries,
+ const tensorflow::TTypes<int32>::ConstVec& partition_ids,
+ const tensorflow::TTypes<int64>::ConstMatrix& bucket_ids,
+ const Tensor* gradients_t, const Tensor* hessians_t,
+ tensorflow::TTypes<int32>::Vec* output_partition_ids,
+ tensorflow::TTypes<float>::Vec* gains,
+ tensorflow::TTypes<string>::Vec* output_splits) {
+ // Holds the root stats per each node to be split.
+ std::vector<GradientStats> current_layer_stats;
+ current_layer_stats.reserve(num_elements);
+ for (int root_idx = 0; root_idx < num_elements; root_idx++) {
+ const int start_index = partition_boundaries[root_idx];
+ const int end_index = partition_boundaries[root_idx + 1];
+ GradientStats root_gradient_stats;
+ for (int64 bucket_idx = start_index; bucket_idx < end_index;
+ ++bucket_idx) {
+ root_gradient_stats +=
+ GradientStats(*gradients_t, *hessians_t, bucket_idx);
+ }
+ root_gradient_stats *= normalizer_ratio;
+ current_layer_stats.push_back(root_gradient_stats);
+ }
+
+ float best_gain = std::numeric_limits<float>::lowest();
+ int64 best_bucket_id = 0;
+ std::vector<NodeStats> best_right_node_stats(num_elements, NodeStats(0));
+ std::vector<NodeStats> best_left_node_stats(num_elements, NodeStats(0));
+ std::vector<NodeStats> current_left_node_stats(num_elements, NodeStats(0));
+ std::vector<NodeStats> current_right_node_stats(num_elements, NodeStats(0));
+ int64 current_bucket_id = 0;
+ int64 last_bucket_id = -1;
+ // Per-partition index offsets, used to access the gradient stats of each
+ // partition for the bucket currently under consideration.
+ std::vector<int> current_layer_offsets(num_elements, 0);
+ std::vector<GradientStats> left_gradient_stats(num_elements);
+ // The idea is to try every bucket id in increasing order. In each iteration
+ // we calculate the layer's gain using the current bucket id as the split
+ // value, and we also obtain the next bucket id to try.
+ while (current_bucket_id > last_bucket_id) {
+ last_bucket_id = current_bucket_id;
+ int64 next_bucket_id = -1;
+ for (int root_idx = 0; root_idx < num_elements; root_idx++) {
+ int idx =
+ current_layer_offsets[root_idx] + partition_boundaries[root_idx];
+ const int end_index = partition_boundaries[root_idx + 1];
+ if (idx < end_index && bucket_ids(idx, 0) == current_bucket_id) {
+ GradientStats g(*gradients_t, *hessians_t, idx);
+ g *= normalizer_ratio;
+ left_gradient_stats[root_idx] += g;
+ current_layer_offsets[root_idx]++;
+ idx++;
+ }
+ if (idx < end_index &&
+ (bucket_ids(idx, 0) < next_bucket_id || next_bucket_id == -1)) {
+ next_bucket_id = bucket_ids(idx, 0);
+ }
+ }
+ float gain_of_split = 0.0;
+ for (int root_idx = 0; root_idx < num_elements; root_idx++) {
+ GradientStats right_gradient_stats =
+ current_layer_stats[root_idx] - left_gradient_stats[root_idx];
+ NodeStats left_stat =
+ state->ComputeNodeStats(left_gradient_stats[root_idx]);
+ NodeStats right_stat = state->ComputeNodeStats(right_gradient_stats);
+ gain_of_split += left_stat.gain + right_stat.gain;
+ current_left_node_stats[root_idx] = left_stat;
+ current_right_node_stats[root_idx] = right_stat;
+ }
+ if (gain_of_split > best_gain) {
+ best_gain = gain_of_split;
+ // Remember the bucket id realizing the best layer gain; it becomes the
+ // split threshold shared by every partition of the layer.
+ best_bucket_id = current_bucket_id;
+ best_left_node_stats = current_left_node_stats;
+ best_right_node_stats = current_right_node_stats;
+ }
+ current_bucket_id = next_bucket_id;
+ }
+
+ for (int root_idx = 0; root_idx < num_elements; root_idx++) {
+ best_gain -= state->ComputeNodeStats(current_layer_stats[root_idx]).gain;
+ }
+ best_gain -= num_elements * state->tree_complexity_regularization();
+
+ ObliviousSplitInfo oblivious_split_info;
+ auto* oblivious_dense_split = oblivious_split_info.mutable_split_node()
+ ->mutable_dense_float_binary_split();
+ oblivious_dense_split->set_feature_column(state->feature_column_group_id());
+ oblivious_dense_split->set_threshold(bucket_boundaries(best_bucket_id));
+ (*gains)(0) = best_gain;
+
+ for (int root_idx = 0; root_idx < num_elements; root_idx++) {
+ auto* left_children = oblivious_split_info.add_children_leaves();
+ auto* right_children = oblivious_split_info.add_children_leaves();
+
+ state->FillLeaf(best_left_node_stats[root_idx], left_children);
+ state->FillLeaf(best_right_node_stats[root_idx], right_children);
+
+ const int start_index = partition_boundaries[root_idx];
+ (*output_partition_ids)(root_idx) = partition_ids(start_index);
}
+ oblivious_split_info.SerializeToString(&(*output_splits)(0));
}
};
REGISTER_KERNEL_BUILDER(Name("BuildDenseInequalitySplits").Device(DEVICE_CPU),
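
The oblivious path differs from the normal path in one essential way: every
partition of the layer must share a single (feature, threshold) pair, so
candidate gains are summed across partitions before taking the argmax. A
self-contained Python sketch of that search; GradientStats is reduced to a
(grad, hess) pair and the node gain to g^2 / (h + l2), an illustrative
stand-in for SplitBuilderState::ComputeNodeStats rather than the production
math:

def best_oblivious_split(partitions, l2=1.0):
    """partitions: list of {bucket_id: (grad, hess)} dicts, one per node."""
    def gain(g, h):
        return g * g / (h + l2)

    buckets = sorted({b for part in partitions for b in part})
    totals = [(sum(g for g, _ in part.values()),
               sum(h for _, h in part.values())) for part in partitions]
    best_gain, best_bucket = float("-inf"), None
    for bucket_id in buckets:  # every partition shares this threshold
        layer_gain = 0.0
        for (tg, th), part in zip(totals, partitions):
            lg = sum(g for b, (g, _) in part.items() if b <= bucket_id)
            lh = sum(h for b, (_, h) in part.items() if b <= bucket_id)
            layer_gain += gain(lg, lh) + gain(tg - lg, th - lh)
        if layer_gain > best_gain:
            best_gain, best_bucket = layer_gain, bucket_id
    return best_bucket, best_gain

# Stats keyed by bucket id, matching the unit test data further below:
parts = [{0: (1.2, 0.2), 2: (-0.3, 0.19)}, {1: (4.0, 0.13)}]
print(best_oblivious_split(parts))  # one shared threshold for the layer
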
diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py
index 2559fe9913..f45010ec26 100644
--- a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py
+++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py
@@ -64,6 +64,7 @@ from __future__ import print_function
import re
from tensorflow.contrib.boosted_trees.lib.learner.batch import base_split_handler
+from tensorflow.contrib.boosted_trees.proto import learner_pb2
from tensorflow.contrib.boosted_trees.python.ops import gen_quantile_ops
from tensorflow.contrib.boosted_trees.python.ops import gen_stats_accumulator_ops
from tensorflow.contrib.boosted_trees.python.ops import quantile_ops
@@ -171,6 +172,7 @@ class DenseSplitHandler(InequalitySplitHandler):
multiclass_strategy,
init_stamp_token=0,
loss_uses_sum_reduction=False,
+ weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE,
name=None):
"""Initialize the internal state for this split handler.
@@ -192,6 +194,7 @@ class DenseSplitHandler(InequalitySplitHandler):
stamped objects.
loss_uses_sum_reduction: A scalar boolean tensor that specifies whether
SUM or MEAN reduction was used for the loss.
+ weak_learner_type: Specifies the type of weak learner to use.
name: An optional handler name.
"""
super(DenseSplitHandler, self).__init__(
@@ -209,6 +212,7 @@ class DenseSplitHandler(InequalitySplitHandler):
multiclass_strategy=multiclass_strategy,
loss_uses_sum_reduction=loss_uses_sum_reduction)
self._dense_float_column = dense_float_column
+ self._weak_learner_type = weak_learner_type
# Register dense_make_stats_update function as an Op to the graph.
g = ops.get_default_graph()
dense_make_stats_update.add_to_graph(g)
@@ -269,16 +273,17 @@ class DenseSplitHandler(InequalitySplitHandler):
next_stamp_token, self._multiclass_strategy, class_id,
self._feature_column_group_id, self._l1_regularization,
self._l2_regularization, self._tree_complexity_regularization,
- self._min_node_weight, self._loss_uses_sum_reduction))
-
+ self._min_node_weight, self._loss_uses_sum_reduction,
+ self._weak_learner_type))
return are_splits_ready, partition_ids, gains, split_infos
-def _make_dense_split(
- quantile_accumulator_handle, stats_accumulator_handle, stamp_token,
- next_stamp_token, multiclass_strategy, class_id, feature_column_id,
- l1_regularization, l2_regularization, tree_complexity_regularization,
- min_node_weight, is_multi_dimentional, loss_uses_sum_reduction):
+def _make_dense_split(quantile_accumulator_handle, stats_accumulator_handle,
+ stamp_token, next_stamp_token, multiclass_strategy,
+ class_id, feature_column_id, l1_regularization,
+ l2_regularization, tree_complexity_regularization,
+ min_node_weight, is_multi_dimentional,
+ loss_uses_sum_reduction, weak_learner_type):
"""Function that builds splits for a dense feature column."""
# Get the bucket boundaries
are_splits_ready, buckets = (
@@ -327,7 +332,8 @@ def _make_dense_split(
l2_regularization=l2_regularization,
tree_complexity_regularization=tree_complexity_regularization,
min_node_weight=min_node_weight,
- multiclass_strategy=multiclass_strategy))
+ multiclass_strategy=multiclass_strategy,
+ weak_learner_type=weak_learner_type))
return are_splits_ready, partition_ids, gains, split_infos
@@ -507,7 +513,40 @@ def _make_sparse_split(
return are_splits_ready, partition_ids, gains, split_infos
-def _specialize_make_split(func, is_multi_dimentional):
+def _specialize_make_split_dense(func, is_multi_dimentional):
+ """Builds a specialized version of the function."""
+
+ @function.Defun(
+ dtypes.resource,
+ dtypes.resource,
+ dtypes.int64,
+ dtypes.int64,
+ dtypes.int32,
+ dtypes.int32,
+ dtypes.int32,
+ dtypes.float32,
+ dtypes.float32,
+ dtypes.float32,
+ dtypes.float32,
+ dtypes.bool,
+ dtypes.int32,
+ noinline=True)
+ def f(quantile_accumulator_handle, stats_accumulator_handle, stamp_token,
+ next_stamp_token, multiclass_strategy, class_id, feature_column_id,
+ l1_regularization, l2_regularization, tree_complexity_regularization,
+ min_node_weight, loss_uses_sum_reduction, weak_learner_type):
+ """Function that builds splits for a sparse feature column."""
+ return func(quantile_accumulator_handle, stats_accumulator_handle,
+ stamp_token, next_stamp_token, multiclass_strategy, class_id,
+ feature_column_id, l1_regularization, l2_regularization,
+ tree_complexity_regularization, min_node_weight,
+ is_multi_dimentional, loss_uses_sum_reduction,
+ weak_learner_type)
+
+ return f
+
+
+def _specialize_make_split_sparse(func, is_multi_dimentional):
"""Builds a specialized version of the function."""
@function.Defun(
@@ -537,15 +576,17 @@ def _specialize_make_split(func, is_multi_dimentional):
return f
-make_dense_split_scalar = _specialize_make_split(_make_dense_split,
- is_multi_dimentional=False)
-make_dense_split_tensor = _specialize_make_split(_make_dense_split,
- is_multi_dimentional=True)
-make_sparse_split_scalar = _specialize_make_split(_make_sparse_split,
- is_multi_dimentional=False)
-make_sparse_split_tensor = _specialize_make_split(_make_sparse_split,
- is_multi_dimentional=True)
+make_dense_split_scalar = _specialize_make_split_dense(
+ _make_dense_split, is_multi_dimentional=False)
+
+make_dense_split_tensor = _specialize_make_split_dense(
+ _make_dense_split, is_multi_dimentional=True)
+
+make_sparse_split_scalar = _specialize_make_split_sparse(
+ _make_sparse_split, is_multi_dimentional=False)
+make_sparse_split_tensor = _specialize_make_split_sparse(
+ _make_sparse_split, is_multi_dimentional=True)
@function.Defun(
diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py
index 5d82c4cae5..6572f2f414 100644
--- a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py
+++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py
@@ -182,6 +182,133 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
self.assertAllClose(0.52, split_node.threshold, 0.00001)
+ def testObliviousFeatureSplitGeneration(self):
+ with self.test_session() as sess:
+ # The data looks like the following:
+ # Example | Gradients | Partition | Dense Quantile |
+ # i0 | (0.2, 0.12) | 0 | 2 |
+ # i1 | (-0.5, 0.07) | 0 | 2 |
+ # i2 | (1.2, 0.2) | 0 | 0 |
+ # i3 | (4.0, 0.13) | 1 | 1 |
+ dense_column = array_ops.constant([0.62, 0.62, 0.3, 0.52])
+ gradients = array_ops.constant([0.2, -0.5, 1.2, 4.0])
+ hessians = array_ops.constant([0.12, 0.07, 0.2, 0.13])
+ partition_ids = array_ops.constant([0, 0, 0, 1], dtype=dtypes.int32)
+ class_id = -1
+
+ gradient_shape = tensor_shape.scalar()
+ hessian_shape = tensor_shape.scalar()
+ split_handler = ordinal_split_handler.DenseSplitHandler(
+ l1_regularization=0.1,
+ l2_regularization=1.,
+ tree_complexity_regularization=0.,
+ min_node_weight=0.,
+ epsilon=0.001,
+ num_quantiles=10,
+ feature_column_group_id=0,
+ dense_float_column=dense_column,
+ init_stamp_token=0,
+ gradient_shape=gradient_shape,
+ hessian_shape=hessian_shape,
+ multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS,
+ weak_learner_type=learner_pb2.LearnerConfig.OBLIVIOUS_DECISION_TREE)
+ resources.initialize_resources(resources.shared_resources()).run()
+
+ empty_gradients, empty_hessians = get_empty_tensors(
+ gradient_shape, hessian_shape)
+ example_weights = array_ops.ones([4, 1], dtypes.float32)
+
+ update_1 = split_handler.update_stats_sync(
+ 0,
+ partition_ids,
+ gradients,
+ hessians,
+ empty_gradients,
+ empty_hessians,
+ example_weights,
+ is_active=array_ops.constant([True, True]))
+ with ops.control_dependencies([update_1]):
+ are_splits_ready = split_handler.make_splits(
+ np.int64(0), np.int64(1), class_id)[0]
+
+ with ops.control_dependencies([are_splits_ready]):
+ update_2 = split_handler.update_stats_sync(
+ 1,
+ partition_ids,
+ gradients,
+ hessians,
+ empty_gradients,
+ empty_hessians,
+ example_weights,
+ is_active=array_ops.constant([True, True]))
+ with ops.control_dependencies([update_2]):
+ are_splits_ready2, partitions, gains, splits = (
+ split_handler.make_splits(np.int64(1), np.int64(2), class_id))
+ are_splits_ready, are_splits_ready2, partitions, gains, splits = (
+ sess.run([
+ are_splits_ready, are_splits_ready2, partitions, gains, splits
+ ]))
+
+ # During the first iteration, inequality split handlers are not going to
+ # have any splits. Make sure that we return not_ready in that case.
+ self.assertFalse(are_splits_ready)
+ self.assertTrue(are_splits_ready2)
+
+ self.assertAllEqual([0, 1], partitions)
+
+ oblivious_split_info = split_info_pb2.ObliviousSplitInfo()
+ oblivious_split_info.ParseFromString(splits[0])
+ split_node = oblivious_split_info.split_node.dense_float_binary_split
+
+ self.assertAllClose(0.3, split_node.threshold, 0.00001)
+ self.assertEqual(0, split_node.feature_column)
+
+ # Check the split on partition 0.
+ # -(1.2 - 0.1) / (0.2 + 1)
+ expected_left_weight_0 = -0.9166666666666666
+
+ # expected_left_weight_0 * -(1.2 - 0.1)
+ expected_left_gain_0 = 1.008333333333333
+
+ # -(-0.5 + 0.2 + 0.1) / (0.19 + 1)
+ expected_right_weight_0 = 0.1680672
+
+ # expected_right_weight_0 * -(-0.5 + 0.2 + 0.1)
+ expected_right_gain_0 = 0.033613445378151252
+
+ # (0.2 + -0.5 + 1.2 - 0.1) ** 2 / (0.12 + 0.07 + 0.2 + 1)
+ expected_bias_gain_0 = 0.46043165467625896
+
+ left_child = oblivious_split_info.children_leaves[0].vector
+ right_child = oblivious_split_info.children_leaves[1].vector
+
+ self.assertAllClose([expected_left_weight_0], left_child.value, 0.00001)
+
+ self.assertAllClose([expected_right_weight_0], right_child.value, 0.00001)
+
+ # Check the split on partition 1.
+ expected_left_weight_1 = 0
+ expected_left_gain_1 = 0
+ # -(4 - 0.1) / (0.13 + 1)
+ expected_right_weight_1 = -3.4513274336283186
+ # expected_right_weight_1 * -(4 - 0.1)
+ expected_right_gain_1 = 13.460176991150442
+ # (-4 + 0.1) ** 2 / (0.13 + 1)
+ expected_bias_gain_1 = 13.460176991150442
+
+ left_child = oblivious_split_info.children_leaves[2].vector
+ right_child = oblivious_split_info.children_leaves[3].vector
+
+ self.assertAllClose([expected_left_weight_1], left_child.value, 0.00001)
+
+ self.assertAllClose([expected_right_weight_1], right_child.value, 0.00001)
+
+ # The layer gain is the sum of the gains of each partition.
+ layer_gain = (
+ expected_left_gain_0 + expected_right_gain_0 - expected_bias_gain_0) + (
+ expected_left_gain_1 + expected_right_gain_1 - expected_bias_gain_1)
+ self.assertAllClose(layer_gain, gains[0], 0.00001)
+
def testGenerateFeatureSplitCandidatesLossUsesSumReduction(self):
with self.test_session() as sess:
# The data looks like the following:
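
For reference, the hand-computed expectations in the oblivious test above can
be reproduced with shrunken-gradient leaf formulas; the soft-threshold
handling of l1 below is an assumption consistent with the test's comments
(weight = -shrunk(g) / (h + l2), gain = weight * -shrunk(g)):

l1, l2 = 0.1, 1.0

def shrink(g):
    # Soft-threshold by l1 (assumed form, consistent with the comments).
    return g - l1 if g > l1 else (g + l1 if g < -l1 else 0.0)

def weight(g, h):
    return -shrink(g) / (h + l2)

def gain(g, h):
    return weight(g, h) * -shrink(g)

print(weight(1.2, 0.2))    # -0.91666...  expected_left_weight_0
print(gain(1.2, 0.2))      #  1.00833...  expected_left_gain_0
print(weight(-0.3, 0.19))  #  0.16806...  expected_right_weight_0
print(gain(-0.3, 0.19))    #  0.03361...  expected_right_gain_0
print(gain(0.9, 0.39))     #  0.46043...  expected_bias_gain_0
print(weight(4.0, 0.13))   # -3.45132...  expected_right_weight_1
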
diff --git a/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc b/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc
index ca5c7f3d8c..9b68a9de96 100644
--- a/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc
+++ b/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc
@@ -36,6 +36,7 @@ REGISTER_OP("BuildDenseInequalitySplits")
.Input("tree_complexity_regularization: float")
.Input("min_node_weight: float")
.Input("multiclass_strategy: int32")
+ .Input("weak_learner_type: int32")
.Output("output_partition_ids: int32")
.Output("gains: float32")
.Output("split_infos: string")
@@ -84,6 +85,8 @@ min_node_weight: A scalar, minimum sum of example hessian needed in a child.
be considered.
multiclass_strategy: A scalar, specifying the multiclass handling strategy.
See LearnerConfig.MultiClassStrategy for valid values.
+weak_learner_type: A scalar, specifying the weak learner type to use.
+ See LearnerConfig.WeakLearnerType for valid values.
output_partition_ids: A rank 1 tensor, the partition IDs that we created splits
for.
gains: A rank 1 tensor, for the computed gain for the created splits.
diff --git a/tensorflow/contrib/boosted_trees/proto/learner.proto b/tensorflow/contrib/boosted_trees/proto/learner.proto
index d84ba7438e..c49cb48cde 100644
--- a/tensorflow/contrib/boosted_trees/proto/learner.proto
+++ b/tensorflow/contrib/boosted_trees/proto/learner.proto
@@ -108,6 +108,11 @@ message LearnerConfig {
DIAGONAL_HESSIAN = 3;
}
+ enum WeakLearnerType {
+ NORMAL_DECISION_TREE = 0;
+ OBLIVIOUS_DECISION_TREE = 1;
+ }
+
// Number of classes.
uint32 num_classes = 1;
@@ -141,4 +146,7 @@ message LearnerConfig {
// If you want to average the ensembles (for regularization), provide the
// config below.
AveragingConfig averaging_config = 11;
+
+ // By default we use NORMAL_DECISION_TREE as weak learner.
+ WeakLearnerType weak_learner_type = 12;
}
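
Usage sketch for the new field; learner_pb2 is the generated module already
imported by the Python split handlers in this change:

from tensorflow.contrib.boosted_trees.proto import learner_pb2

config = learner_pb2.LearnerConfig()
config.num_classes = 2
# Left unset, the field defaults to NORMAL_DECISION_TREE (enum value 0).
config.weak_learner_type = learner_pb2.LearnerConfig.OBLIVIOUS_DECISION_TREE
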
diff --git a/tensorflow/contrib/boosted_trees/proto/split_info.proto b/tensorflow/contrib/boosted_trees/proto/split_info.proto
index a300c24c8e..850340f5c2 100644
--- a/tensorflow/contrib/boosted_trees/proto/split_info.proto
+++ b/tensorflow/contrib/boosted_trees/proto/split_info.proto
@@ -17,3 +17,10 @@ message SplitInfo {
// Right Leaf node.
tensorflow.boosted_trees.trees.Leaf right_child = 3;
}
+
+message ObliviousSplitInfo {
+ // The split node with the feature_column and threshold defined.
+ tensorflow.boosted_trees.trees.TreeNode split_node = 1;
+ // The new leaves of the tree.
+ repeated tensorflow.boosted_trees.trees.Leaf children_leaves = 2;
+}
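
Read-side sketch: ComputeObliviousDecisionTree appends the leaves pairwise, so
partition i's children land at indices 2*i (left) and 2*i + 1 (right) of
children_leaves, which is how the handler test above indexes them:

from tensorflow.contrib.boosted_trees.proto import split_info_pb2

def leaves_by_partition(serialized):
    """Returns [(left_leaf, right_leaf)] pairs, one per partition."""
    info = split_info_pb2.ObliviousSplitInfo()
    info.ParseFromString(serialized)
    return [(info.children_leaves[2 * i], info.children_leaves[2 * i + 1])
            for i in range(len(info.children_leaves) // 2)]
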
diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py
index 5cd37ec67e..2589504762 100644
--- a/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py
+++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py
@@ -59,7 +59,8 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase):
min_node_weight=0,
class_id=-1,
feature_column_group_id=0,
- multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS))
+ multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS,
+ weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE))
partitions, gains, splits = sess.run([partitions, gains, splits])
self.assertAllEqual([0, 1], partitions)
@@ -132,7 +133,8 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase):
min_node_weight=0,
class_id=-1,
feature_column_group_id=0,
- multiclass_strategy=learner_pb2.LearnerConfig.FULL_HESSIAN))
+ multiclass_strategy=learner_pb2.LearnerConfig.FULL_HESSIAN,
+ weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE))
partitions, gains, splits = sess.run([partitions, gains, splits])
self.assertAllEqual([0, 1], partitions)
@@ -171,7 +173,8 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase):
min_node_weight=0,
class_id=-1,
feature_column_group_id=0,
- multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS))
+ multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS,
+ weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE))
partitions, gains, splits = sess.run([partitions, gains, splits])
# .assertEmpty doesn't exist on ubuntu-contrib
self.assertEqual(0, len(partitions))
diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
index d0d1249bd6..20ff48c360 100644
--- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
+++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
@@ -672,6 +672,8 @@ class GradientBoostedDecisionTreeModel(object):
self._learner_config.constraints.min_node_weight, dtypes.float32)
loss_uses_sum_reduction = self._loss_reduction == losses.Reduction.SUM
loss_uses_sum_reduction = constant_op.constant(loss_uses_sum_reduction)
+ weak_learner_type = constant_op.constant(
+ self._learner_config.weak_learner_type)
epsilon = 0.01
num_quantiles = 100
strategy_tensor = constant_op.constant(strategy)
@@ -696,6 +698,7 @@ class GradientBoostedDecisionTreeModel(object):
multiclass_strategy=strategy_tensor,
init_stamp_token=init_stamp_token,
loss_uses_sum_reduction=loss_uses_sum_reduction,
+ weak_learner_type=weak_learner_type,
))
fc_name_idx += 1
diff --git a/tensorflow/contrib/checkpoint/__init__.py b/tensorflow/contrib/checkpoint/__init__.py
index 2fbaa31d5e..e92f0bb841 100644
--- a/tensorflow/contrib/checkpoint/__init__.py
+++ b/tensorflow/contrib/checkpoint/__init__.py
@@ -31,6 +31,9 @@ Checkpointable data structures:
@@List
@@Mapping
@@UniqueNameTracker
+
+Checkpoint management:
+@@CheckpointManager
"""
from __future__ import absolute_import
@@ -41,6 +44,7 @@ from tensorflow.contrib.checkpoint.python.containers import UniqueNameTracker
from tensorflow.contrib.checkpoint.python.split_dependency import split_dependency
from tensorflow.contrib.checkpoint.python.visualize import dot_graph_from_checkpoint
from tensorflow.core.protobuf.checkpointable_object_graph_pb2 import CheckpointableObjectGraph
+from tensorflow.python.training.checkpoint_management import CheckpointManager
from tensorflow.python.training.checkpointable.base import CheckpointableBase
from tensorflow.python.training.checkpointable.data_structures import List
from tensorflow.python.training.checkpointable.data_structures import Mapping
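
A short usage sketch for the re-exported symbol; the constructor arguments
(checkpoint, directory, max_to_keep) are assumed from checkpoint_management at
this revision:

import tensorflow as tf

step = tf.Variable(0, name="step", trainable=False)
ckpt = tf.train.Checkpoint(step=step)
manager = tf.contrib.checkpoint.CheckpointManager(
    ckpt, directory="/tmp/ckpt_demo", max_to_keep=5)
manager.save()  # older checkpoints beyond max_to_keep are pruned
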
diff --git a/tensorflow/contrib/constrained_optimization/python/candidates.py b/tensorflow/contrib/constrained_optimization/python/candidates.py
index ac86a6741b..66d7ebed74 100644
--- a/tensorflow/contrib/constrained_optimization/python/candidates.py
+++ b/tensorflow/contrib/constrained_optimization/python/candidates.py
@@ -204,7 +204,7 @@ def find_best_candidate_distribution(objective_vector,
assert best_pp is not None
# Throughout this loop, a maximum_violation of "lower" is not achievable,
- # but a maximum_violation of "upper" is achiveable.
+ # but a maximum_violation of "upper" is achievable.
while True:
middle = 0.5 * (lower + upper)
if (middle - lower <= epsilon) or (upper - middle <= epsilon):
diff --git a/tensorflow/contrib/constrained_optimization/python/constrained_minimization_problem.py b/tensorflow/contrib/constrained_optimization/python/constrained_minimization_problem.py
index 70813fb217..41258edd90 100644
--- a/tensorflow/contrib/constrained_optimization/python/constrained_minimization_problem.py
+++ b/tensorflow/contrib/constrained_optimization/python/constrained_minimization_problem.py
@@ -72,7 +72,8 @@ class ConstrainedMinimizationProblem(object):
else:
proxy_constraints_shape = self.proxy_constraints.get_shape()
- if (constraints_shape is None or proxy_constraints_shape is None or
+ if (constraints_shape.ndims is None or
+ proxy_constraints_shape.ndims is None or
any([ii is None for ii in constraints_shape.as_list()]) or
any([ii is None for ii in proxy_constraints_shape.as_list()])):
raise ValueError(
@@ -121,3 +122,19 @@ class ConstrainedMinimizationProblem(object):
A tensor of proxy constraint functions.
"""
return None
+
+ # This is a property, instead of an abstract property, since it doesn't need
+ # to be overridden: if pre_train_ops returns None, then there are no ops to
+ # run before train_op.
+ @property
+ def pre_train_ops(self):
+ """Returns a list of `Operation`s to run before the train_op.
+
+ When a `ConstrainedOptimizer` creates a train_op (in `minimize`
+ `minimize_unconstrained`, or `minimize_constrained`), it will include these
+ ops before the main training step.
+
+ Returns:
+ A list of `Operation`s.
+ """
+ return None
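
A minimal sketch of a problem that uses the new hook (hypothetical subclass;
the import path is assumed from this package's layout):

import tensorflow as tf
from tensorflow.contrib.constrained_optimization.python import (
    constrained_minimization_problem as cmp)

class CountingProblem(cmp.ConstrainedMinimizationProblem):
  """Toy problem whose pre_train_ops bumps a counter before each step."""

  def __init__(self, objective, constraints):
    self._objective = objective
    self._constraints = constraints
    self._steps = tf.Variable(0, trainable=False, name="pre_train_steps")

  @property
  def objective(self):
    return self._objective

  @property
  def constraints(self):
    return self._constraints

  @property
  def pre_train_ops(self):
    # Wrapped in tf.control_dependencies() by the minimize_* methods, so
    # this runs before every train_op.
    return [self._steps.assign_add(1)]
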
diff --git a/tensorflow/contrib/constrained_optimization/python/constrained_optimizer.py b/tensorflow/contrib/constrained_optimization/python/constrained_optimizer.py
index 8055545366..0b79bdf7c0 100644
--- a/tensorflow/contrib/constrained_optimization/python/constrained_optimizer.py
+++ b/tensorflow/contrib/constrained_optimization/python/constrained_optimizer.py
@@ -55,20 +55,21 @@ class ConstrainedOptimizer(object):
"""Returns the `tf.train.Optimizer` used for optimization."""
return self._optimizer
- def minimize_unconstrained(self,
- minimization_problem,
- global_step=None,
- var_list=None,
- gate_gradients=train_optimizer.Optimizer.GATE_OP,
- aggregation_method=None,
- colocate_gradients_with_ops=False,
- name=None,
- grad_loss=None):
- """Returns an `Op` for minimizing the unconstrained problem.
+ @abc.abstractmethod
+ def _minimize_constrained(self,
+ minimization_problem,
+ global_step=None,
+ var_list=None,
+ gate_gradients=train_optimizer.Optimizer.GATE_OP,
+ aggregation_method=None,
+ colocate_gradients_with_ops=False,
+ name=None,
+ grad_loss=None):
+ """Version of `minimize_constrained` to be overridden by subclasses.
- Unlike `minimize_constrained`, this function ignores the `constraints` (and
- `proxy_constraints`) portion of the minimization problem entirely, and only
- minimizes `objective`.
+ Implementations of this method should ignore the `pre_train_ops` property of
+ the `minimization_problem`. The public `minimize_constrained` method will
+ take care of executing these before the returned train_op.
Args:
minimization_problem: ConstrainedMinimizationProblem, the problem to
@@ -83,19 +84,10 @@ class ConstrainedOptimizer(object):
grad_loss: as in `tf.train.Optimizer`'s `minimize` method.
Returns:
- TensorFlow Op.
+ `Operation`, the train_op.
"""
- return self.optimizer.minimize(
- minimization_problem.objective,
- global_step=global_step,
- var_list=var_list,
- gate_gradients=gate_gradients,
- aggregation_method=aggregation_method,
- colocate_gradients_with_ops=colocate_gradients_with_ops,
- name=name,
- grad_loss=grad_loss)
+ pass
- @abc.abstractmethod
def minimize_constrained(self,
minimization_problem,
global_step=None,
@@ -105,7 +97,7 @@ class ConstrainedOptimizer(object):
colocate_gradients_with_ops=False,
name=None,
grad_loss=None):
- """Returns an `Op` for minimizing the constrained problem.
+ """Returns an `Operation` for minimizing the constrained problem.
Unlike `minimize_unconstrained`, this function attempts to find a solution
that minimizes the `objective` portion of the minimization problem while
@@ -124,9 +116,83 @@ class ConstrainedOptimizer(object):
grad_loss: as in `tf.train.Optimizer`'s `minimize` method.
Returns:
- TensorFlow Op.
+ `Operation`, the train_op.
"""
- pass
+
+ def train_op_callback():
+ return self._minimize_constrained(
+ minimization_problem,
+ global_step=global_step,
+ var_list=var_list,
+ gate_gradients=gate_gradients,
+ aggregation_method=aggregation_method,
+ colocate_gradients_with_ops=colocate_gradients_with_ops,
+ name=name,
+ grad_loss=grad_loss)
+
+ # If we have pre_train_ops, use tf.control_dependencies() to ensure that
+ # they execute before the train_op.
+ pre_train_ops = minimization_problem.pre_train_ops
+ if pre_train_ops:
+ with ops.control_dependencies(pre_train_ops):
+ train_op = train_op_callback()
+ else:
+ train_op = train_op_callback()
+
+ return train_op
+
+ def minimize_unconstrained(self,
+ minimization_problem,
+ global_step=None,
+ var_list=None,
+ gate_gradients=train_optimizer.Optimizer.GATE_OP,
+ aggregation_method=None,
+ colocate_gradients_with_ops=False,
+ name=None,
+ grad_loss=None):
+ """Returns an `Operation` for minimizing the unconstrained problem.
+
+ Unlike `minimize_constrained`, this function ignores the `constraints` (and
+ `proxy_constraints`) portion of the minimization problem entirely, and only
+ minimizes `objective`.
+
+ Args:
+ minimization_problem: ConstrainedMinimizationProblem, the problem to
+ optimize.
+ global_step: as in `tf.train.Optimizer`'s `minimize` method.
+ var_list: as in `tf.train.Optimizer`'s `minimize` method.
+ gate_gradients: as in `tf.train.Optimizer`'s `minimize` method.
+ aggregation_method: as in `tf.train.Optimizer`'s `minimize` method.
+ colocate_gradients_with_ops: as in `tf.train.Optimizer`'s `minimize`
+ method.
+ name: as in `tf.train.Optimizer`'s `minimize` method.
+ grad_loss: as in `tf.train.Optimizer`'s `minimize` method.
+
+ Returns:
+ `Operation`, the train_op.
+ """
+
+ def train_op_callback():
+ return self.optimizer.minimize(
+ minimization_problem.objective,
+ global_step=global_step,
+ var_list=var_list,
+ gate_gradients=gate_gradients,
+ aggregation_method=aggregation_method,
+ colocate_gradients_with_ops=colocate_gradients_with_ops,
+ name=name,
+ grad_loss=grad_loss)
+
+ # If we have pre_train_ops, use tf.control_dependencies() to ensure that
+ # they execute before the train_op.
+ pre_train_ops = minimization_problem.pre_train_ops
+ if pre_train_ops:
+ with ops.control_dependencies(pre_train_ops):
+ train_op = train_op_callback()
+ else:
+ train_op = train_op_callback()
+
+ return train_op
def minimize(self,
minimization_problem,
@@ -138,7 +204,7 @@ class ConstrainedOptimizer(object):
colocate_gradients_with_ops=False,
name=None,
grad_loss=None):
- """Returns an `Op` for minimizing the constrained problem.
+ """Returns an `Operation` for minimizing the constrained problem.
This method combines the functionality of `minimize_unconstrained` and
`minimize_constrained`. If global_step < unconstrained_steps, it will
@@ -164,14 +230,14 @@ class ConstrainedOptimizer(object):
grad_loss: as in `tf.train.Optimizer`'s `minimize` method.
Returns:
- TensorFlow Op.
+ `Operation`, the train_op.
Raises:
ValueError: If unconstrained_steps is provided, but global_step is not.
"""
def unconstrained_fn():
- """Returns an `Op` for minimizing the unconstrained problem."""
+ """Returns an `Operation` for minimizing the unconstrained problem."""
return self.minimize_unconstrained(
minimization_problem=minimization_problem,
global_step=global_step,
@@ -183,7 +249,7 @@ class ConstrainedOptimizer(object):
grad_loss=grad_loss)
def constrained_fn():
- """Returns an `Op` for minimizing the constrained problem."""
+ """Returns an `Operation` for minimizing the constrained problem."""
return self.minimize_constrained(
minimization_problem=minimization_problem,
global_step=global_step,
diff --git a/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer.py b/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer.py
index 01c6e4f08a..d1af15f7e4 100644
--- a/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer.py
+++ b/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer.py
@@ -70,11 +70,13 @@ def _project_multipliers_wrt_euclidean_norm(multipliers, radius):
region w.r.t. the Euclidean norm.
Raises:
- ValueError: if the `multipliers` tensor does not have a fully-known shape,
- or is not one-dimensional.
+ ValueError: if the `multipliers` tensor is not floating-point, does not have
+ a fully-known shape, or is not one-dimensional.
"""
+ if not multipliers.dtype.is_floating:
+ raise ValueError("multipliers must have a floating-point dtype")
multipliers_shape = multipliers.get_shape()
- if multipliers_shape is None:
+ if multipliers_shape.ndims is None:
raise ValueError("multipliers must have known shape")
if multipliers_shape.ndims != 1:
raise ValueError(
@@ -101,12 +103,12 @@ def _project_multipliers_wrt_euclidean_norm(multipliers, radius):
(radius - standard_ops.reduce_sum(multipliers)) / standard_ops.maximum(
1.0, standard_ops.reduce_sum(inactive)))
multipliers += scale * inactive
- new_inactive = standard_ops.to_float(multipliers > 0)
+ new_inactive = standard_ops.cast(multipliers > 0, multipliers.dtype)
multipliers *= new_inactive
return (iteration, multipliers, new_inactive, inactive)
iteration = standard_ops.constant(0)
- inactive = standard_ops.ones_like(multipliers)
+ inactive = standard_ops.ones_like(multipliers, dtype=multipliers.dtype)
# We actually want a do-while loop, so we explicitly call while_loop_body()
# once before tf.while_loop().
@@ -189,16 +191,16 @@ class _ExternalRegretOptimizer(constrained_optimizer.ConstrainedOptimizer):
def _projection_op(self, state, name=None):
pass
- def minimize_constrained(self,
- minimization_problem,
- global_step=None,
- var_list=None,
- gate_gradients=train_optimizer.Optimizer.GATE_OP,
- aggregation_method=None,
- colocate_gradients_with_ops=False,
- name=None,
- grad_loss=None):
- """Returns an `Op` for minimizing the constrained problem.
+ def _minimize_constrained(self,
+ minimization_problem,
+ global_step=None,
+ var_list=None,
+ gate_gradients=train_optimizer.Optimizer.GATE_OP,
+ aggregation_method=None,
+ colocate_gradients_with_ops=False,
+ name=None,
+ grad_loss=None):
+ """Returns an `Operation` for minimizing the constrained problem.
The `optimizer` constructor parameter will be used to update the model
parameters, while the Lagrange multipliers will be updated using
@@ -216,8 +218,11 @@ class _ExternalRegretOptimizer(constrained_optimizer.ConstrainedOptimizer):
name: as in `tf.train.Optimizer`'s `minimize` method.
grad_loss: as in `tf.train.Optimizer`'s `minimize` method.
+ Raises:
+ ValueError: If the minimization_problem tensors have different dtypes.
+
Returns:
- TensorFlow Op.
+ `Operation`, the train_op.
"""
objective = minimization_problem.objective
@@ -225,6 +230,14 @@ class _ExternalRegretOptimizer(constrained_optimizer.ConstrainedOptimizer):
proxy_constraints = minimization_problem.proxy_constraints
if proxy_constraints is None:
proxy_constraints = constraints
+
+ # Make sure that the objective, constraints and proxy constraints all have
+ # the same dtype.
+ if (objective.dtype.base_dtype != constraints.dtype.base_dtype or
+ objective.dtype.base_dtype != proxy_constraints.dtype.base_dtype):
+ raise ValueError("objective, constraints and proxy_constraints must "
+ "have the same dtype")
+
# Flatten both constraints tensors to 1d.
num_constraints = minimization_problem.num_constraints
constraints = standard_ops.reshape(constraints, shape=(num_constraints,))
@@ -241,8 +254,10 @@ class _ExternalRegretOptimizer(constrained_optimizer.ConstrainedOptimizer):
multipliers = self._lagrange_multipliers(state)
loss = (
- objective + standard_ops.tensordot(multipliers, proxy_constraints, 1))
- multipliers_gradient = constraints
+ objective + standard_ops.tensordot(
+ standard_ops.cast(multipliers, proxy_constraints.dtype),
+ proxy_constraints, 1))
+ multipliers_gradient = standard_ops.cast(constraints, multipliers.dtype)
update_ops = []
if self.constraint_optimizer is None:
@@ -356,6 +371,8 @@ class AdditiveExternalRegretOptimizer(_ExternalRegretOptimizer):
# For an AdditiveExternalRegretOptimizer, the internal state is simply a
# tensor of Lagrange multipliers with shape (m,), where m is the number of
# constraints.
+ #
+ # FUTURE WORK: make the dtype a parameter.
return standard_ops.zeros((num_constraints,), dtype=dtypes.float32)
def _lagrange_multipliers(self, state):
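
For reference, the projection computed by the graph code above, onto
{m : m >= 0, sum(m) <= radius} with respect to the Euclidean norm, in plain
NumPy, using the same uniform-shift-then-clip iteration:

import numpy as np

def project_multipliers(multipliers, radius):
    """Euclidean projection onto {m : m >= 0, sum(m) <= radius}."""
    m = np.array(multipliers, dtype=np.float64)
    inactive = np.ones_like(m)
    for _ in range(m.size):  # each pass can only shrink the support
        scale = min(0.0, (radius - m.sum()) / max(1.0, inactive.sum()))
        m += scale * inactive
        new_inactive = (m > 0).astype(m.dtype)
        m *= new_inactive
        if np.array_equal(new_inactive, inactive):
            break
        inactive = new_inactive
    return m

print(project_multipliers([0.9, 0.8, -0.3], radius=1.0))  # [0.55 0.45 0.]
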
diff --git a/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer.py b/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer.py
index ff846b191a..2c673d9347 100644
--- a/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer.py
+++ b/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer.py
@@ -79,9 +79,11 @@ def _maximal_eigenvector_power_method(matrix,
The maximal right-eigenvector of `matrix`.
Raises:
- ValueError: If the epsilon or maximum_iterations parameters violate their
- bounds.
+ ValueError: If the `matrix` tensor is not floating-point, or if the
+ `epsilon` or `maximum_iterations` parameters violate their bounds.
"""
+ if not matrix.dtype.is_floating:
+ raise ValueError("multipliers must have a floating-point dtype")
if epsilon <= 0.0:
raise ValueError("epsilon must be strictly positive")
if maximum_iterations <= 0:
@@ -139,11 +141,13 @@ def _project_stochastic_matrix_wrt_euclidean_norm(matrix):
(i.e. the Frobenius norm).
Raises:
- ValueError: if the `matrix` tensor does not have a fully-known shape, or is
- not two-dimensional and square.
+ ValueError: if the `matrix` tensor is not floating-point, does not have a
+ fully-known shape, or is not two-dimensional and square.
"""
+ if not matrix.dtype.is_floating:
+ raise ValueError("multipliers must have a floating-point dtype")
matrix_shape = matrix.get_shape()
- if matrix_shape is None:
+ if matrix_shape.ndims is None:
raise ValueError("matrix must have known shape")
if matrix_shape.ndims != 2:
raise ValueError(
@@ -172,12 +176,12 @@ def _project_stochastic_matrix_wrt_euclidean_norm(matrix):
matrix, axis=0, keepdims=True)) / standard_ops.maximum(
1.0, standard_ops.reduce_sum(inactive, axis=0, keepdims=True))
matrix += scale * inactive
- new_inactive = standard_ops.to_float(matrix > 0)
+ new_inactive = standard_ops.cast(matrix > 0, matrix.dtype)
matrix *= new_inactive
return (iteration, matrix, new_inactive, inactive)
iteration = standard_ops.constant(0)
- inactive = standard_ops.ones_like(matrix)
+ inactive = standard_ops.ones_like(matrix, dtype=matrix.dtype)
# We actually want a do-while loop, so we explicitly call while_loop_body()
# once before tf.while_loop().
@@ -218,7 +222,7 @@ class _SwapRegretOptimizer(constrained_optimizer.ConstrainedOptimizer):
"""Base class representing a `_SwapRegretOptimizer`.
This class contains most of the logic for performing constrained optimization,
- minimizing external regret for the constraints player. What it *doesn't* do is
+ minimizing swap regret for the constraints player. What it *doesn't* do is
keep track of the internal state (the stochastic matrix). Instead, the state
is accessed via the _initial_state(), _stochastic_matrix(),
_constraint_grad_and_var() and _projection_op() methods.
@@ -291,16 +295,16 @@ class _SwapRegretOptimizer(constrained_optimizer.ConstrainedOptimizer):
def _projection_op(self, state, name=None):
pass
- def minimize_constrained(self,
- minimization_problem,
- global_step=None,
- var_list=None,
- gate_gradients=train_optimizer.Optimizer.GATE_OP,
- aggregation_method=None,
- colocate_gradients_with_ops=False,
- name=None,
- grad_loss=None):
- """Returns an `Op` for minimizing the constrained problem.
+ def _minimize_constrained(self,
+ minimization_problem,
+ global_step=None,
+ var_list=None,
+ gate_gradients=train_optimizer.Optimizer.GATE_OP,
+ aggregation_method=None,
+ colocate_gradients_with_ops=False,
+ name=None,
+ grad_loss=None):
+ """Returns an `Operation` for minimizing the constrained problem.
The `optimizer` constructor parameter will be used to update the model
parameters, while the constraint/objective weight matrix (the analogue of
@@ -320,8 +324,11 @@ class _SwapRegretOptimizer(constrained_optimizer.ConstrainedOptimizer):
name: as in `tf.train.Optimizer`'s `minimize` method.
grad_loss: as in `tf.train.Optimizer`'s `minimize` method.
+ Raises:
+ ValueError: If the minimization_problem tensors have different dtypes.
+
Returns:
- TensorFlow Op.
+ `Operation`, the train_op.
"""
objective = minimization_problem.objective
@@ -329,6 +336,14 @@ class _SwapRegretOptimizer(constrained_optimizer.ConstrainedOptimizer):
proxy_constraints = minimization_problem.proxy_constraints
if proxy_constraints is None:
proxy_constraints = constraints
+
+ # Make sure that the objective, constraints and proxy constraints all have
+ # the same dtype.
+ if (objective.dtype.base_dtype != constraints.dtype.base_dtype or
+ objective.dtype.base_dtype != proxy_constraints.dtype.base_dtype):
+ raise ValueError("objective, constraints and proxy_constraints must "
+ "have the same dtype")
+
# Flatten both constraints tensors to 1d.
num_constraints = minimization_problem.num_constraints
constraints = standard_ops.reshape(constraints, shape=(num_constraints,))
@@ -344,15 +359,18 @@ class _SwapRegretOptimizer(constrained_optimizer.ConstrainedOptimizer):
name="swap_regret_optimizer_state")
zero_and_constraints = standard_ops.concat(
- (standard_ops.zeros((1,)), constraints), axis=0)
+ (standard_ops.zeros((1,), dtype=constraints.dtype), constraints),
+ axis=0)
objective_and_proxy_constraints = standard_ops.concat(
(standard_ops.expand_dims(objective, 0), proxy_constraints), axis=0)
distribution = self._distribution(state)
- loss = standard_ops.tensordot(distribution, objective_and_proxy_constraints,
- 1)
+ loss = standard_ops.tensordot(
+ standard_ops.cast(distribution, objective_and_proxy_constraints.dtype),
+ objective_and_proxy_constraints, 1)
matrix_gradient = standard_ops.matmul(
- standard_ops.expand_dims(zero_and_constraints, 1),
+ standard_ops.expand_dims(
+ standard_ops.cast(zero_and_constraints, distribution.dtype), 1),
standard_ops.expand_dims(distribution, 0))
update_ops = []
@@ -555,6 +573,7 @@ class MultiplicativeSwapRegretOptimizer(_SwapRegretOptimizer):
log_initial_one = math.log(1.0 - (self._initial_multiplier_radius *
(dimension - 1) / (dimension)))
log_initial_zero = math.log(self._initial_multiplier_radius / dimension)
+ # FUTURE WORK: make the dtype a parameter.
return standard_ops.concat(
(standard_ops.constant(
log_initial_one, dtype=dtypes.float32, shape=(1, dimension)),
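For context, `_maximal_eigenvector_power_method` is plain power iteration. A hedged NumPy sketch of the idea (the convergence test differs in detail from the TF implementation):

```python
import numpy as np

def maximal_right_eigenvector(matrix, epsilon=1e-6, maximum_iterations=100):
    """Power iteration for the eigenvector with the largest eigenvalue."""
    if not np.issubdtype(matrix.dtype, np.floating):
        raise ValueError("matrix must have a floating-point dtype")
    if epsilon <= 0.0:
        raise ValueError("epsilon must be strictly positive")
    if maximum_iterations <= 0:
        raise ValueError("maximum_iterations must be strictly positive")
    v = np.ones(matrix.shape[1], dtype=matrix.dtype) / matrix.shape[1]
    for _ in range(maximum_iterations):
        new_v = matrix.dot(v)
        new_v /= np.linalg.norm(new_v)
        if np.linalg.norm(new_v - v) <= epsilon:
            return new_v
        v = new_v
    return v

# For a column-stochastic matrix the maximal eigenvalue is 1 and the
# eigenvector is (proportional to) the stationary distribution.
m = np.array([[0.9, 0.2], [0.1, 0.8]])
print(maximal_right_eigenvector(m))
```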
diff --git a/tensorflow/contrib/data/kernels/assert_next_dataset_op.cc b/tensorflow/contrib/data/kernels/assert_next_dataset_op.cc
index bff6301250..e36c9c0634 100644
--- a/tensorflow/contrib/data/kernels/assert_next_dataset_op.cc
+++ b/tensorflow/contrib/data/kernels/assert_next_dataset_op.cc
@@ -42,13 +42,13 @@ class AssertNextDatasetOp : public UnaryDatasetOpKernel {
}
private:
- class Dataset : public GraphDatasetBase {
+ class Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, const DatasetBase* input,
const std::vector<string>& transformations,
const DataTypeVector& output_types,
const std::vector<PartialTensorShape>& output_shapes)
- : GraphDatasetBase(ctx),
+ : DatasetBase(DatasetContext(ctx)),
input_(input),
transformations_(transformations),
output_types_(output_types),
diff --git a/tensorflow/contrib/data/kernels/csv_dataset_op.cc b/tensorflow/contrib/data/kernels/csv_dataset_op.cc
index 51e1b9aa65..d242cfdf49 100644
--- a/tensorflow/contrib/data/kernels/csv_dataset_op.cc
+++ b/tensorflow/contrib/data/kernels/csv_dataset_op.cc
@@ -131,7 +131,7 @@ class CSVDatasetOp : public DatasetOpKernel {
}
private:
- class Dataset : public GraphDatasetBase {
+ class Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, std::vector<string> filenames, bool header,
string compression_type, io::ZlibCompressionOptions options,
@@ -139,7 +139,7 @@ class CSVDatasetOp : public DatasetOpKernel {
const std::vector<PartialTensorShape>& output_shapes,
std::vector<Tensor> record_defaults, std::vector<int64> select_cols,
bool use_quote_delim, char delim, string na_value)
- : GraphDatasetBase(ctx),
+ : DatasetBase(DatasetContext(ctx)),
filenames_(std::move(filenames)),
header_(header),
out_type_(output_types),
diff --git a/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc b/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc
index b9306f611b..ccf7ec1f84 100644
--- a/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc
+++ b/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc
@@ -63,11 +63,11 @@ class DirectedInterleaveDatasetOp : public DatasetOpKernel {
}
private:
- class Dataset : public GraphDatasetBase {
+ class Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, const DatasetBase* selector_input,
std::vector<DatasetBase*> data_inputs)
- : GraphDatasetBase(ctx),
+ : DatasetBase(DatasetContext(ctx)),
selector_input_(selector_input),
data_inputs_(std::move(data_inputs)) {
selector_input_->Ref();
diff --git a/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc b/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc
index d77beb8e10..db24e60846 100644
--- a/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc
+++ b/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc
@@ -35,10 +35,10 @@ class IgnoreErrorsDatasetOp : public UnaryDatasetOpKernel {
}
private:
- class Dataset : public GraphDatasetBase {
+ class Dataset : public DatasetBase {
public:
explicit Dataset(OpKernelContext* ctx, const DatasetBase* input)
- : GraphDatasetBase(ctx), input_(input) {
+ : DatasetBase(DatasetContext(ctx)), input_(input) {
input_->Ref();
}
diff --git a/tensorflow/contrib/data/kernels/prefetching_kernels.cc b/tensorflow/contrib/data/kernels/prefetching_kernels.cc
index 13bcd77b4a..74df1e42a8 100644
--- a/tensorflow/contrib/data/kernels/prefetching_kernels.cc
+++ b/tensorflow/contrib/data/kernels/prefetching_kernels.cc
@@ -929,10 +929,9 @@ class MultiDeviceIteratorInitOp : public OpKernel {
LookupResource(ctx, HandleFromInput(ctx, 1), &resource));
core::ScopedUnref unref(resource);
- IteratorContext iter_ctx = dataset::MakeIteratorContext(ctx);
std::unique_ptr<IteratorBase> iterator;
- OP_REQUIRES_OK(ctx,
- dataset->MakeIterator(&iter_ctx, "Iterator", &iterator));
+ OP_REQUIRES_OK(ctx, dataset->MakeIterator(IteratorContext(ctx), "Iterator",
+ &iterator));
int64 incarnation_id;
OP_REQUIRES_OK(ctx, resource->Init(std::move(iterator), max_buffer_size,
&incarnation_id));
diff --git a/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc b/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc
index 4dc69dc2ef..ab584504a0 100644
--- a/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc
+++ b/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc
@@ -130,11 +130,13 @@ class ThreadPoolDatasetOp : public UnaryDatasetOpKernel {
}
private:
- class Dataset : public GraphDatasetBase {
+ class Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, const DatasetBase* input,
ThreadPoolResource* threadpool)
- : GraphDatasetBase(ctx), input_(input), threadpool_(threadpool) {
+ : DatasetBase(DatasetContext(ctx)),
+ input_(input),
+ threadpool_(threadpool) {
input_->Ref();
threadpool_->Ref();
}
@@ -165,9 +167,8 @@ class ThreadPoolDatasetOp : public UnaryDatasetOpKernel {
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** output) const override {
- return errors::Unimplemented(
- "Cannot currently serialize the thread pool for a "
- "ThreadPoolDataset.");
+ return errors::Unimplemented("%s does not support serialization",
+ DebugString());
}
private:
diff --git a/tensorflow/contrib/data/kernels/unique_dataset_op.cc b/tensorflow/contrib/data/kernels/unique_dataset_op.cc
index f6bfc982e9..6fbf5d2ebb 100644
--- a/tensorflow/contrib/data/kernels/unique_dataset_op.cc
+++ b/tensorflow/contrib/data/kernels/unique_dataset_op.cc
@@ -47,10 +47,10 @@ class UniqueDatasetOp : public UnaryDatasetOpKernel {
}
private:
- class Dataset : public GraphDatasetBase {
+ class Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, const DatasetBase* input)
- : GraphDatasetBase(ctx), input_(input) {
+ : DatasetBase(DatasetContext(ctx)), input_(input) {
input_->Ref();
}
diff --git a/tensorflow/contrib/distribute/__init__.py b/tensorflow/contrib/distribute/__init__.py
index 9123ca749b..2c93ce92ce 100644
--- a/tensorflow/contrib/distribute/__init__.py
+++ b/tensorflow/contrib/distribute/__init__.py
@@ -22,13 +22,14 @@ from __future__ import print_function
from tensorflow.contrib.distribute.python.collective_all_reduce_strategy import CollectiveAllReduceStrategy
from tensorflow.contrib.distribute.python.cross_tower_ops import *
from tensorflow.contrib.distribute.python.mirrored_strategy import MirroredStrategy
-from tensorflow.contrib.distribute.python.multi_worker_strategy import MultiWorkerMirroredStrategy
from tensorflow.contrib.distribute.python.monitor import Monitor
+from tensorflow.contrib.distribute.python.multi_worker_strategy import MultiWorkerMirroredStrategy
from tensorflow.contrib.distribute.python.one_device_strategy import OneDeviceStrategy
from tensorflow.contrib.distribute.python.parameter_server_strategy import ParameterServerStrategy
from tensorflow.contrib.distribute.python.step_fn import *
from tensorflow.contrib.distribute.python.tpu_strategy import TPUStrategy
from tensorflow.python.training.distribute import *
+from tensorflow.python.training.distribution_strategy_context import *
from tensorflow.python.util.all_util import remove_undocumented
@@ -55,6 +56,7 @@ _allowed_symbols = [
'get_tower_context',
'has_distribution_strategy',
'require_tower_context',
+ 'UpdateContext',
]
remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/contrib/distribute/python/combinations.py b/tensorflow/contrib/distribute/python/combinations.py
index a1efbcaf9a..aeec9c44d7 100644
--- a/tensorflow/contrib/distribute/python/combinations.py
+++ b/tensorflow/contrib/distribute/python/combinations.py
@@ -56,7 +56,7 @@ from tensorflow.contrib.optimizer_v2 import gradient_descent as gradient_descent
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.training import adam
-from tensorflow.python.training import distribute as distribute_lib
+from tensorflow.python.training import distribution_strategy_context
from tensorflow.python.training import gradient_descent
from tensorflow.python.util import tf_inspect
@@ -320,7 +320,7 @@ class NamedDistribution(object):
# pylint: disable=g-long-lambda
default_strategy = NamedDistribution(
"Default",
- lambda: distribute_lib._default_distribution_strategy, # pylint: disable=protected-access
+ distribution_strategy_context._get_default_distribution_strategy, # pylint: disable=protected-access
required_gpus=None)
one_device_strategy = NamedDistribution(
"OneDeviceCPU", lambda: one_device_lib.OneDeviceStrategy("/cpu:0"),
diff --git a/tensorflow/contrib/distribute/python/estimator_integration_test.py b/tensorflow/contrib/distribute/python/estimator_integration_test.py
index 3e00cf4332..cc626c33bf 100644
--- a/tensorflow/contrib/distribute/python/estimator_integration_test.py
+++ b/tensorflow/contrib/distribute/python/estimator_integration_test.py
@@ -29,6 +29,7 @@ from tensorflow.contrib.optimizer_v2 import adagrad
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import test
from tensorflow.python.estimator import run_config
+from tensorflow.python.estimator import training
from tensorflow.python.estimator.canned import dnn_linear_combined
from tensorflow.python.estimator.canned import prediction_keys
from tensorflow.python.estimator.export import export
@@ -63,8 +64,9 @@ class DNNLinearCombinedClassifierIntegrationTest(test.TestCase,
combinations.one_device_strategy,
combinations.mirrored_strategy_with_gpu_and_cpu,
combinations.mirrored_strategy_with_two_gpus
- ]))
- def test_complete_flow_with_mode(self, distribution):
+ ],
+ use_train_and_evaluate=[True, False]))
+ def test_complete_flow_with_mode(self, distribution, use_train_and_evaluate):
label_dimension = 2
input_dimension = label_dimension
batch_size = 10
@@ -103,9 +105,15 @@ class DNNLinearCombinedClassifierIntegrationTest(test.TestCase,
train_distribute=distribution, eval_distribute=distribution))
num_steps = 10
- estimator.train(train_input_fn, steps=num_steps)
+ if use_train_and_evaluate:
+ scores, _ = training.train_and_evaluate(
+ estimator,
+ training.TrainSpec(train_input_fn, max_steps=num_steps),
+ training.EvalSpec(eval_input_fn))
+ else:
+ estimator.train(train_input_fn, steps=num_steps)
+ scores = estimator.evaluate(eval_input_fn)
- scores = estimator.evaluate(eval_input_fn)
self.assertEqual(num_steps, scores[ops.GraphKeys.GLOBAL_STEP])
self.assertIn('loss', six.iterkeys(scores))
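The two branches in the test correspond to the two supported training flows. A hedged sketch of both, assuming a constructed `estimator` and input functions (and the tuple return that the updated test relies on):

```python
import tensorflow as tf

def run(estimator, train_input_fn, eval_input_fn, num_steps,
        use_train_and_evaluate):
    if use_train_and_evaluate:
        # One call that interleaves training and evaluation; here it is
        # assumed to return (metrics, export_results), as in the test above.
        scores, _ = tf.estimator.train_and_evaluate(
            estimator,
            tf.estimator.TrainSpec(train_input_fn, max_steps=num_steps),
            tf.estimator.EvalSpec(eval_input_fn))
    else:
        # Two explicit calls: train to completion, then evaluate once.
        estimator.train(train_input_fn, steps=num_steps)
        scores = estimator.evaluate(eval_input_fn)
    return scores
```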
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
index e064cfe37d..9a4cc0a897 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
@@ -40,7 +40,7 @@ from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training import device_util
-from tensorflow.python.training import distribute as distribute_lib
+from tensorflow.python.training import distribution_strategy_context
GPU_TEST = "test_gpu" in sys.argv[0]
@@ -164,7 +164,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
# This variable should be created only once across the threads because of
# special variable_creator functions used by `dist.call_for_each_tower`.
v = variable_scope.variable(1.0, name="foo")
- distribute_lib.get_tower_context().merge_call(lambda _: _)
+ distribution_strategy_context.get_tower_context().merge_call(lambda _: _)
return v
dist = mirrored_strategy.MirroredStrategy(
@@ -181,7 +181,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
def model_fn():
v = variable_scope.variable(1.0)
- distribute_lib.get_tower_context().merge_call(lambda _: _)
+ distribution_strategy_context.get_tower_context().merge_call(lambda _: _)
return v
dist = mirrored_strategy.MirroredStrategy(
@@ -201,7 +201,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
vs = []
for i in range(5):
vs.append(variable_scope.variable(1.0, name="foo" + str(i)))
- distribute_lib.get_tower_context().merge_call(lambda _: _)
+ distribution_strategy_context.get_tower_context().merge_call(lambda _: _)
return vs
dist = mirrored_strategy.MirroredStrategy(
@@ -223,7 +223,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
vs.append(variable_scope.variable(1.0, name="foo_1/bar"))
vs.append(variable_scope.variable(1.0, name="foo_1/bar_1"))
vs.append(variable_scope.variable(1.0, name="foo/bar_1"))
- distribute_lib.get_tower_context().merge_call(lambda _: _)
+ distribution_strategy_context.get_tower_context().merge_call(lambda _: _)
return vs
dist = mirrored_strategy.MirroredStrategy(
@@ -245,7 +245,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
def model_fn(device_id):
v = variable_scope.variable(1.0, name="foo_" + str(device_id))
- distribute_lib.get_tower_context().merge_call(lambda _: _)
+ distribution_strategy_context.get_tower_context().merge_call(lambda _: _)
return v
dist = mirrored_strategy.MirroredStrategy(
@@ -268,7 +268,8 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
layer2 = core.Dense(1)
layer2(features)
# This will pause the current thread, and execute the other thread.
- distribute_lib.get_tower_context().merge_call(lambda _: _)
+ distribution_strategy_context.get_tower_context().merge_call(
+ lambda _: _)
layer3 = core.Dense(1)
layer3(features)
return [(layer1.kernel, layer1.bias),
@@ -300,7 +301,8 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
with variable_scope.variable_scope("common"):
v1 = variable_scope.variable(1.0, name="var1")
# This will pause the current thread, and execute the other thread.
- distribute_lib.get_tower_context().merge_call(lambda _: _)
+ distribution_strategy_context.get_tower_context().merge_call(
+ lambda _: _)
v2 = variable_scope.variable(
1.0,
name="var2",
@@ -343,7 +345,8 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
with variable_scope.variable_scope("common"):
v1 = variable_scope.get_variable("var1", [1])
# This will pause the current thread, and execute the other thread.
- distribute_lib.get_tower_context().merge_call(lambda _: _)
+ distribution_strategy_context.get_tower_context().merge_call(
+ lambda _: _)
v2 = variable_scope.get_variable(
"var2", [1],
synchronization=variable_scope.VariableSynchronization.ON_READ,
@@ -453,7 +456,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
def model_fn():
v = variable_scope.variable(1.0, name="foo")
- distribute_lib.get_tower_context().merge_call(lambda _: _)
+ distribution_strategy_context.get_tower_context().merge_call(lambda _: _)
return v
dist = mirrored_strategy.MirroredStrategy(
@@ -470,7 +473,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
def model_fn(name):
v = variable_scope.variable(1.0, name=name)
- distribute_lib.get_tower_context().merge_call(lambda _: _)
+ distribution_strategy_context.get_tower_context().merge_call(lambda _: _)
return v
dist = mirrored_strategy.MirroredStrategy(
@@ -570,7 +573,8 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
def model_fn():
with ops.name_scope("foo"):
a = constant_op.constant(1.0, name="a")
- distribute_lib.get_tower_context().merge_call(lambda _: _)
+ distribution_strategy_context.get_tower_context().merge_call(
+ lambda _: _)
b = constant_op.constant(1.0, name="b")
return a, b
@@ -591,7 +595,8 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
def model_fn():
with ops.name_scope(None, "foo"):
a = constant_op.constant(1.0, name="a")
- distribute_lib.get_tower_context().merge_call(lambda _: _)
+ distribution_strategy_context.get_tower_context().merge_call(
+ lambda _: _)
b = constant_op.constant(2.0, name="b")
return a, b
@@ -619,7 +624,8 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
def model_fn():
b = variable_scope.variable(1.0, name="b")
with ops.name_scope("foo"):
- c = distribute_lib.get_tower_context().merge_call(in_cross_tower)
+ c = distribution_strategy_context.get_tower_context().merge_call(
+ in_cross_tower)
return b, c
dist = mirrored_strategy.MirroredStrategy(
@@ -651,7 +657,8 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
def model_fn():
b = variable_scope.get_variable("b", [1])
with ops.name_scope("foo"):
- c = distribute_lib.get_tower_context().merge_call(in_cross_tower)
+ c = distribution_strategy_context.get_tower_context().merge_call(
+ in_cross_tower)
return b, c
dist = mirrored_strategy.MirroredStrategy(
@@ -833,8 +840,9 @@ class MirroredVariableUpdateTest(test.TestCase):
self.assertEquals(1.0, self.evaluate(mirrored_var))
def model_fn():
- value = math_ops.cast(distribute_lib.get_tower_context().tower_id,
- mirrored_var.dtype)
+ value = math_ops.cast(
+ distribution_strategy_context.get_tower_context().tower_id,
+ mirrored_var.dtype)
return mirrored_var.assign(value)
self.evaluate(dist.unwrap(dist.call_for_each_tower(
@@ -898,8 +906,9 @@ class MirroredVariableUpdateTest(test.TestCase):
self.assertEquals(1.0, self.evaluate(mirrored_var))
def model_fn():
- value = math_ops.cast(distribute_lib.get_tower_context().tower_id,
- mirrored_var.dtype)
+ value = math_ops.cast(
+ distribution_strategy_context.get_tower_context().tower_id,
+ mirrored_var.dtype)
return mirrored_var.assign_add(value)
self.evaluate(dist.unwrap(dist.call_for_each_tower(
@@ -963,8 +972,9 @@ class MirroredVariableUpdateTest(test.TestCase):
self.assertEquals(5.0, self.evaluate(mirrored_var))
def model_fn():
- value = math_ops.cast(distribute_lib.get_tower_context().tower_id,
- mirrored_var.dtype)
+ value = math_ops.cast(
+ distribution_strategy_context.get_tower_context().tower_id,
+ mirrored_var.dtype)
return mirrored_var.assign_sub(value)
self.evaluate(dist.unwrap(dist.call_for_each_tower(
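All of the edits in this test file are the same mechanical move: `get_tower_context` now lives in `distribution_strategy_context`. A minimal sketch of the pattern the tests exercise (the import path is TF-internal, as in the tests):

```python
from tensorflow.python.training import distribution_strategy_context

def model_fn():
    # Inside call_for_each_tower() each tower runs in its own thread;
    # merge_call() pauses the tower and runs its argument once in the
    # cross-tower context, which is how the tests above synchronize
    # variable creation across towers.
    tower_context = distribution_strategy_context.get_tower_context()
    tower_context.merge_call(lambda _: _)
    return tower_context.tower_id
```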
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_test.py
index a066adf124..5db2fff239 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy_test.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy_test.py
@@ -24,7 +24,7 @@ from tensorflow.python.eager import context
from tensorflow.python.eager import test
from tensorflow.python.framework import test_util
from tensorflow.python.ops import variable_scope
-from tensorflow.python.training import distribute as distribute_lib
+from tensorflow.python.training import distribution_strategy_context
class MirroredOneCPUDistributionTest(strategy_test_lib.DistributionTestBase):
@@ -68,7 +68,8 @@ class VariableCreatorStackTest(test.TestCase):
v = variable_scope.variable(1.0)
# This will pause the current thread, and execute the other thread.
- distribute_lib.get_tower_context().merge_call(lambda _: _)
+ distribution_strategy_context.get_tower_context().merge_call(
+ lambda _: _)
return v
def main_thread_creator(next_creator, *args, **kwargs):
diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py
index cf29c0ed91..02eb68227d 100644
--- a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py
+++ b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py
@@ -37,7 +37,7 @@ from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import device_util
-from tensorflow.python.training import distribute as distribute_lib
+from tensorflow.python.training import distribution_strategy_context
class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase,
@@ -101,7 +101,8 @@ class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase,
last_part_device = 'device:CPU:0'
else:
last_part_device = (
- 'device:GPU:%d' % distribute_lib.get_tower_context().tower_id)
+ 'device:GPU:%d' %
+ distribution_strategy_context.get_tower_context().tower_id)
a = constant_op.constant(1.0)
b = constant_op.constant(2.0)
@@ -192,14 +193,16 @@ class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase,
tower_compute_device = '/device:CPU:0'
else:
tower_compute_device = (
- '/device:GPU:%d' % distribute_lib.get_tower_context().tower_id)
+ '/device:GPU:%d' %
+ distribution_strategy_context.get_tower_context().tower_id)
tower_compute_device = device_util.canonicalize(tower_compute_device)
if 'CPU' in variable_device:
tower_variable_device = '/device:CPU:0'
else:
tower_variable_device = (
- '/device:GPU:%d' % distribute_lib.get_tower_context().tower_id)
+ '/device:GPU:%d' %
+ distribution_strategy_context.get_tower_context().tower_id)
tower_variable_device = device_util.canonicalize(tower_variable_device)
a = constant_op.constant(1.0)
diff --git a/tensorflow/contrib/distribute/python/strategy_test_lib.py b/tensorflow/contrib/distribute/python/strategy_test_lib.py
index baed0ebaae..371b97ba96 100644
--- a/tensorflow/contrib/distribute/python/strategy_test_lib.py
+++ b/tensorflow/contrib/distribute/python/strategy_test_lib.py
@@ -28,7 +28,7 @@ from tensorflow.python.layers import core
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
-from tensorflow.python.training import distribute as distribute_lib
+from tensorflow.python.training import distribution_strategy_context
from tensorflow.python.training import optimizer
@@ -45,7 +45,8 @@ def _raise_exception_fn(_=None):
# Must be the argument to a distribution.call_for_each_tower() call, calls a
# get_tower_context().merge_call() that raises an exception.
def _merge_raises_fn():
- distribute_lib.get_tower_context().merge_call(_raise_exception_fn)
+ distribution_strategy_context.get_tower_context().merge_call(
+ _raise_exception_fn)
# Must be the argument to a get_tower_context().merge_call() call, calls
@@ -58,7 +59,7 @@ def _call_raises_fn(dist):
# calls a get_tower_context().merge_call() that calls a
# call_for_each_tower() that raises an exception.
def _merge_call_raises_fn():
- distribute_lib.get_tower_context().merge_call(_call_raises_fn)
+ distribution_strategy_context.get_tower_context().merge_call(_call_raises_fn)
# Must be the argument to a get_tower_context().merge_call() call, calls
@@ -72,7 +73,8 @@ def _call_merge_raises_fn(dist):
# get_tower_context().merge_call() that calls a call_for_each_tower() that
# calls a get_tower_context().merge_call() that raises an exception.
def _merge_call_merge_raises_fn():
- distribute_lib.get_tower_context().merge_call(_call_merge_raises_fn)
+ distribution_strategy_context.get_tower_context().merge_call(
+ _call_merge_raises_fn)
class DistributionTestBase(test.TestCase):
@@ -208,7 +210,7 @@ class DistributionTestBase(test.TestCase):
expected_devices = [False] * len(d.worker_devices)
def mark_devices_fn():
- tower_id = distribute_lib.get_tower_context().tower_id
+ tower_id = distribution_strategy_context.get_tower_context().tower_id
self.assertLess(tower_id, len(d.worker_devices))
self.assertFalse(expected_devices[tower_id])
expected_devices[tower_id] = True
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py
index 5fd4c9de69..8548a86421 100644
--- a/tensorflow/contrib/distribute/python/values.py
+++ b/tensorflow/contrib/distribute/python/values.py
@@ -38,6 +38,7 @@ from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.training import device_util
from tensorflow.python.training import distribute as distribute_lib
+from tensorflow.python.training import distribution_strategy_context
from tensorflow.python.training import saver
from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.util import nest
@@ -56,7 +57,7 @@ class DistributedValues(object):
def get(self, device=None):
"""Returns the value for the current device or raises a ValueError."""
if device is None:
- tower_context = distribute_lib.get_tower_context()
+ tower_context = distribution_strategy_context.get_tower_context()
if tower_context:
device = tower_context.device
else:
@@ -289,14 +290,15 @@ class DistributedVariable(DistributedDelegate):
# We want cross-tower code that does some var.op.X calls
# to work (even if the current device isn't in self.devices), but
# other uses of var.op in a cross-tower context to fail.
- if distribute_lib.get_cross_tower_context():
+ if distribution_strategy_context.get_cross_tower_context():
return DistributedVarOp(self._primary_var.op.name,
self._primary_var.op.graph,
self._primary_var.op.type)
return self.get().op
def read_value(self):
- return distribute_lib.get_distribution_strategy().read_var(self)
+ return distribution_strategy_context.get_distribution_strategy().read_var(
+ self)
def _should_act_as_resource_variable(self):
"""Pass resource_variable_ops.is_resource_variable check."""
@@ -362,7 +364,7 @@ class MirroredVariable(DistributedVariable, Mirrored,
# update several non-slot variables in one call.
def _assign_func(self, *args, **kwargs):
f = kwargs.pop("f")
- if distribute_lib.get_cross_tower_context():
+ if distribution_strategy_context.get_cross_tower_context():
update_device = distribute_lib.get_update_device()
# We are calling update on the mirrored variable in cross tower context.
if update_device is not None:
@@ -371,7 +373,7 @@ class MirroredVariable(DistributedVariable, Mirrored,
v = self.get(device=update_device)
return f(v, *args, **kwargs)
- return distribute_lib.get_distribution_strategy().update(
+ return distribution_strategy_context.get_distribution_strategy().update(
self, f, *args, **kwargs)
else:
_assert_tower_context()
@@ -392,8 +394,8 @@ class MirroredVariable(DistributedVariable, Mirrored,
aggregation=self._aggregation, value=value, destinations=self),
*other_args, **other_kwargs)
- return distribute_lib.get_tower_context().merge_call(merge_fn, *args,
- **kwargs)
+ return distribution_strategy_context.get_tower_context().merge_call(
+ merge_fn, *args, **kwargs)
def assign_sub(self, *args, **kwargs):
assign_sub_fn = lambda var, *a, **kw: var.assign_sub(*a, **kw)
@@ -419,7 +421,7 @@ class MirroredVariable(DistributedVariable, Mirrored,
def _as_graph_element(self):
# pylint: disable=protected-access
- if distribute_lib.get_cross_tower_context():
+ if distribution_strategy_context.get_cross_tower_context():
return self._primary_var._as_graph_element()
return self.get()._as_graph_element()
@@ -459,7 +461,7 @@ class _TowerLocalSaveable(saver.BaseSaverBuilder.SaveableObject):
# We use a callable so that we don't have to evaluate this expression
# in the case where we are trying to restore instead of save.
def tensor():
- return distribute_lib.get_distribution_strategy().read_var(
+ return distribution_strategy_context.get_distribution_strategy().read_var(
tower_local_variable)
spec = saver.BaseSaverBuilder.SaveSpec(
tensor=tensor,
@@ -475,7 +477,7 @@ class _TowerLocalSaveable(saver.BaseSaverBuilder.SaveableObject):
def _assert_tower_context():
- if not distribute_lib.get_tower_context():
+ if not distribution_strategy_context.get_tower_context():
raise RuntimeError(
"Tower-local variables may only be assigned in a tower context.")
@@ -498,7 +500,7 @@ class TowerLocalVariable(DistributedVariable, PerDevice,
return self.get().assign_add(*args, **kwargs)
def assign(self, *args, **kwargs):
- if distribute_lib.get_cross_tower_context():
+ if distribution_strategy_context.get_cross_tower_context():
# To preserve the sum across save and restore, we have to divide the
# total across all devices when restoring a variable that was summed
# when saving.
@@ -526,7 +528,7 @@ class TowerLocalVariable(DistributedVariable, PerDevice,
def _as_graph_element(self):
# pylint: disable=protected-access
- if distribute_lib.get_cross_tower_context():
+ if distribution_strategy_context.get_cross_tower_context():
return self._get_cross_tower()
return self.get()._as_graph_element()
@@ -994,12 +996,12 @@ class MultiStepContext(object):
outputs as already reduced or not.
"""
- if distribute_lib.get_cross_tower_context():
+ if distribution_strategy_context.get_cross_tower_context():
self._last_step_outputs_aggregations[name] = aggregation
if aggregation is variables_lib.VariableAggregation.NONE:
self._last_step_outputs[name] = output
else:
- distribution = distribute_lib.get_distribution_strategy()
+ distribution = distribution_strategy_context.get_distribution_strategy()
self._last_step_outputs[name] = distribution.reduce(
aggregation, output, destinations="/device:CPU:0")
else:
@@ -1011,7 +1013,9 @@ class MultiStepContext(object):
# context object, so it's more robust to set it only once (even if all
# the towers are trying to set the same value).
self._last_step_outputs_aggregations[name] = aggregation
- distribute_lib.get_tower_context().merge_call(merge_fn, output)
+
+ distribution_strategy_context.get_tower_context().merge_call(
+ merge_fn, output)
@property
def non_tensor_outputs(self):
@@ -1020,14 +1024,15 @@ class MultiStepContext(object):
def set_non_tensor_output(self, name, output):
"""Set `output` with `name` to be captured as a non tensor output."""
- if distribute_lib.get_cross_tower_context():
+ if distribution_strategy_context.get_cross_tower_context():
self._non_tensor_outputs[name] = output
else:
def merge_fn(distribution, value):
# NOTE(priyag): For non tensor outputs, we simply return all the values
# in a list as aggregation doesn't make sense on non tensors.
self._non_tensor_outputs[name] = distribution.unwrap(value)
- distribute_lib.get_tower_context().merge_call(merge_fn, output)
+ distribution_strategy_context.get_tower_context().merge_call(
+ merge_fn, output)
def value_container(val):
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/deterministic_test.py b/tensorflow/contrib/distributions/python/kernel_tests/deterministic_test.py
index 90910f3839..200310bc41 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/deterministic_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/deterministic_test.py
@@ -173,6 +173,13 @@ class DeterministicTest(test.TestCase):
self.assertAllClose(
np.zeros(sample_shape_ + (2,)).astype(np.float32), sample_)
+ def testEntropy(self):
+ loc = np.array([-0.1, -3.2, 7.])
+ deterministic = deterministic_lib.Deterministic(loc=loc)
+ with self.test_session() as sess:
+ entropy_ = sess.run(deterministic.entropy())
+ self.assertAllEqual(np.zeros(3), entropy_)
+
class VectorDeterministicTest(test.TestCase):
@@ -290,6 +297,13 @@ class VectorDeterministicTest(test.TestCase):
self.assertAllClose(
np.zeros(sample_shape_ + (2, 1)).astype(np.float32), sample_)
+ def testEntropy(self):
+ loc = np.array([[8.3, 1.2, 3.3], [-0.1, -3.2, 7.]])
+ deterministic = deterministic_lib.VectorDeterministic(loc=loc)
+ with self.test_session() as sess:
+ entropy_ = sess.run(deterministic.entropy())
+ self.assertAllEqual(np.zeros(2), entropy_)
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/distributions/python/ops/deterministic.py b/tensorflow/contrib/distributions/python/ops/deterministic.py
index ad853ee293..affc64a14f 100644
--- a/tensorflow/contrib/distributions/python/ops/deterministic.py
+++ b/tensorflow/contrib/distributions/python/ops/deterministic.py
@@ -152,6 +152,9 @@ class _BaseDeterministic(distribution.Distribution):
"""Relative tolerance for comparing points to `self.loc`."""
return self._rtol
+ def _entropy(self):
+ return array_ops.zeros(self.batch_shape_tensor(), dtype=self.dtype)
+
def _mean(self):
return array_ops.identity(self.loc)
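With the new `_entropy`, a point mass reports zero entropy per batch member, mirroring the tests added above (import path as in those tests):

```python
import numpy as np
import tensorflow as tf
from tensorflow.contrib.distributions.python.ops import deterministic as deterministic_lib

loc = np.array([-0.1, -3.2, 7.])
dist = deterministic_lib.Deterministic(loc=loc)
with tf.Session() as sess:
    # A point mass carries no uncertainty, so entropy is a zero tensor
    # with the distribution's batch shape.
    print(sess.run(dist.entropy()))  # -> [0. 0. 0.]
```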
diff --git a/tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb b/tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb
index 975105a179..5621d6a358 100644
--- a/tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb
+++ b/tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb
@@ -495,7 +495,7 @@
" random_vector_for_generation)\n",
" \n",
" # saving (checkpoint) the model every 15 epochs\n",
- " if epoch % 15 == 0:\n",
+ " if (epoch + 1) % 15 == 0:\n",
" checkpoint.save(file_prefix = checkpoint_prefix)\n",
" \n",
" print ('Time taken for epoch {} is {} sec'.format(epoch + 1,\n",
diff --git a/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb b/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb
index 78a711548d..027097908f 100644
--- a/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb
+++ b/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb
@@ -132,6 +132,7 @@
"tf.enable_eager_execution()\n",
"\n",
"import numpy as np\n",
+ "import os\n",
"import re\n",
"import random\n",
"import unidecode\n",
@@ -313,7 +314,7 @@
"outputs": [],
"source": [
"dataset = tf.data.Dataset.from_tensor_slices((input_text, target_text)).shuffle(BUFFER_SIZE)\n",
- "dataset = dataset.apply(tf.contrib.data.batch_and_drop_remainder(BATCH_SIZE))"
+ "dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)"
]
},
{
@@ -493,7 +494,7 @@
"source": [
"# Training step\n",
"\n",
- "EPOCHS = 30\n",
+ "EPOCHS = 20\n",
"\n",
"for epoch in range(EPOCHS):\n",
" start = time.time()\n",
@@ -520,7 +521,7 @@
" batch,\n",
" loss))\n",
" # saving (checkpoint) the model every 5 epochs\n",
- " if epoch % 5 == 0:\n",
+ " if (epoch + 1) % 5 == 0:\n",
" checkpoint.save(file_prefix = checkpoint_prefix)\n",
"\n",
" print ('Epoch {} Loss {:.4f}'.format(epoch+1, loss))\n",
diff --git a/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb b/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb
index 1d07721e3b..08d8364978 100644
--- a/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb
+++ b/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb
@@ -319,7 +319,7 @@
"vocab_tar_size = len(targ_lang.word2idx)\n",
"\n",
"dataset = tf.data.Dataset.from_tensor_slices((input_tensor_train, target_tensor_train)).shuffle(BUFFER_SIZE)\n",
- "dataset = dataset.apply(tf.contrib.data.batch_and_drop_remainder(BATCH_SIZE))"
+ "dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)"
]
},
{
@@ -619,7 +619,7 @@
" batch,\n",
" batch_loss.numpy()))\n",
" # saving (checkpoint) the model every 2 epochs\n",
- " if epoch % 2 == 0:\n",
+ " if (epoch + 1) % 2 == 0:\n",
" checkpoint.save(file_prefix = checkpoint_prefix)\n",
" \n",
" print('Epoch {} Loss {:.4f}'.format(epoch + 1,\n",
diff --git a/tensorflow/contrib/eager/python/examples/pix2pix/pix2pix_eager.ipynb b/tensorflow/contrib/eager/python/examples/pix2pix/pix2pix_eager.ipynb
index acc0f5b653..ee25d25b52 100644
--- a/tensorflow/contrib/eager/python/examples/pix2pix/pix2pix_eager.ipynb
+++ b/tensorflow/contrib/eager/python/examples/pix2pix/pix2pix_eager.ipynb
@@ -701,7 +701,7 @@
" generate_images(generator, inp, tar)\n",
" \n",
" # saving (checkpoint) the model every 20 epochs\n",
- " if epoch % 20 == 0:\n",
+ " if (epoch + 1) % 20 == 0:\n",
" checkpoint.save(file_prefix = checkpoint_prefix)\n",
"\n",
" print ('Time taken for epoch {} is {} sec\\n'.format(epoch + 1,\n",
diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD
index 82272bf120..77f62df99d 100644
--- a/tensorflow/contrib/estimator/BUILD
+++ b/tensorflow/contrib/estimator/BUILD
@@ -20,6 +20,7 @@ py_library(
":dnn_linear_combined",
":early_stopping",
":export",
+ ":exporter",
":extenders",
":head",
":hooks",
@@ -220,6 +221,33 @@ py_test(
)
py_library(
+ name = "exporter",
+ srcs = [
+ "python/estimator/exporter.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:summary",
+ "//tensorflow/python/estimator:exporter",
+ ],
+)
+
+py_test(
+ name = "exporter_test",
+ size = "medium",
+ srcs = ["python/estimator/exporter_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":exporter",
+ "//tensorflow/python:platform",
+ "//tensorflow/python/estimator",
+ "//tensorflow/python/estimator:exporter",
+ ],
+)
+
+py_library(
name = "head",
srcs = [
"python/estimator/head.py",
diff --git a/tensorflow/contrib/estimator/__init__.py b/tensorflow/contrib/estimator/__init__.py
index e1453ae1d0..6ad3a4a604 100644
--- a/tensorflow/contrib/estimator/__init__.py
+++ b/tensorflow/contrib/estimator/__init__.py
@@ -45,6 +45,7 @@ _allowed_symbols = [
'clip_gradients_by_norm',
'forward_features',
'InMemoryEvaluatorHook',
+ 'StopAtCheckpointStepHook',
'logistic_regression_head',
'multi_class_head',
'multi_head',
diff --git a/tensorflow/contrib/estimator/python/estimator/exporter.py b/tensorflow/contrib/estimator/python/estimator/exporter.py
new file mode 100644
index 0000000000..09d7440605
--- /dev/null
+++ b/tensorflow/contrib/estimator/python/estimator/exporter.py
@@ -0,0 +1,280 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Implements StepsExporter to export the model in user specified steps."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from tensorflow.python.estimator import exporter
+from tensorflow.python.framework import ops
+from tensorflow.python.platform import gfile
+from tensorflow.python.platform import tf_logging
+from tensorflow.python.summary import summary_iterator
+
+DEFAULT_GLOBAL_STEP_KEY = ops.GraphKeys.GLOBAL_STEP
+
+
+class StepsExporter(exporter.Exporter):
+ """This class exports the model in user specified steps.
+
+ This class exports the model at the steps given by the `steps_to_keep`
+ argument. Each number in the list is treated as a lower bound for model
+ exports, to handle the case when evaluation is performed at different steps.
+
+ Consider this example:
+
+ ```
+ steps_to_keep = [1, 2, 3, 6, 7, 10, 12, 25]
+ ```
+
+ The model is evaluated at step increments of 5: `[5, 10, 15, 20, 25, 30]`.
+ The `StepsExporter` will export the model when it has reached steps
+ `[5, 10, 15, 25]`.
+
+ This example illustrates the two cases when the model is exported:
+
+ 1. Model is evaluated on a step defined in the list `steps_to_keep`.
+
+ In the example, the model is exported on step `10` and `25`.
+
+ 2. Model is evaluated on a step not defined in the list `steps_to_keep`, but
+ is still exported because a step in `steps_to_keep` was missed.
+
+ In the example, when the model reaches step `5`, the model is exported even
+ though `steps_to_keep` does not contain `5`. Step `5` is exported to make
+ up for step `3`, which was missed. Steps `1` and `2` in `steps_to_keep` are
+ skipped completely (e.g. if the model is evaluated at step `6`, it will
+ **not** be exported again just to make up for step `2`).
+
+ Using the `steps_to_keep` list as a lower bound allows users to define
+ approximate step boundaries for exporting their models, and avoid frustrating
+ off-by-one calculation errors.
+
+ Sample Use Cases:
+ There are specific points during training when having a saved version of
+ the model is useful, for example at the end of each training phase, when
+ the set of frozen weights changes. Another good use case is saving the
+ model at the end of each epoch for visualization or retraining.
+ """
+
+ def __init__(self,
+ steps_to_keep,
+ name='steps_exporter',
+ serving_input_receiver_fn=None,
+ event_file_pattern='eval/*.tfevents.*',
+ assets_extra=None,
+ as_text=False):
+ """Create an `StepsExporter` to use with `tf.estimator.EvalSpec`.
+
+ Example of creating a StepsExporter for training and evaluation:
+
+ ```python
+ categorical_feature_a = categorical_column_with_hash_bucket(...)
+ categorical_feature_b = categorical_column_with_hash_bucket(...)
+
+ categorical_feature_a_emb = embedding_column(
+ categorical_column=categorical_feature_a, ...)
+ categorical_feature_b_emb = embedding_column(
+ categorical_column=categorical_feature_b, ...)
+
+ estimator = tf.estimator.DNNClassifier(
+ feature_columns=[categorical_feature_a_emb, categorical_feature_b_emb],
+ hidden_units=[1024, 512, 256])
+
+ # Input pipeline for train and evaluate.
+ def train_input_fn(): # returns x, y
+ # please shuffle the data.
+ pass
+ def eval_input_fn(): # returns x, y
+ pass
+
+ exporter = tf.contrib.estimator.exporter.StepsExporter(
+ name="steps_exporter",
+ serving_input_receiver_fn=serving_input_receiver_fn,
+ event_file_pattern='eval/*.tfevents.*',
+ steps_to_keep=[...])
+
+ train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn, max_steps=1000)
+
+ eval_spec = [tf.estimator.EvalSpec(
+ input_fn=eval_input_fn,
+ steps=1,
+ exporters=exporter,
+ start_delay_secs=0,
+ throttle_secs=5)]
+
+ tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
+
+ # Models will be exported to estimator.model_dir in timestamped directories,
+ # which can be used for serving, analysis with TFMA, or directly loaded in.
+ # For example:
+ export_dir = os.path.join(estimator.model_dir,
+ <timestamped directory name>)
+
+ with ops.Graph().as_default() as graph:
+ with session.Session(graph=graph) as sess:
+ tf.saved_model.loader.load(
+ sess, [tf.saved_model.tag_constants.SERVING], export_dir)
+
+ ```
+
+ Args:
+ steps_to_keep: Non-empty list of positive integers containing
+ the step numbers at which the model should be exported. All the exports
+ will be kept, so there is no garbage collection.
+ name: Unique name of this `Exporter` that is going to be used in the
+ export path.
+ serving_input_receiver_fn: A function that takes no arguments and returns
+ a `ServingInputReceiver`.
+ event_file_pattern: Event file name pattern relative to model_dir. If
+ None, the exporter is not preemption-safe; to be preemption-safe,
+ event_file_pattern must be specified.
+ assets_extra: An optional dict specifying how to populate the assets.extra
+ directory within the exported SavedModel. Each key should give the
+ destination path (including the filename) relative to the assets.extra
+ directory. The corresponding value gives the full path of the source
+ file to be copied. For example, the simple case of copying a single
+ file without renaming it is specified as `{'my_asset_file.txt':
+ '/path/to/my_asset_file.txt'}`.
+ as_text: Whether to write the SavedModel proto in text format. Defaults to
+ `False`.
+
+ Raises:
+ ValueError: If any argument is invalid.
+ """
+ # pylint: disable=protected-access
+ self._saved_model_exporter = exporter._SavedModelExporter(
+ name, serving_input_receiver_fn, assets_extra, as_text)
+ # pylint: enable=protected-access
+
+ self._event_file_pattern = event_file_pattern
+ self._model_dir = None
+
+ self._input_steps_to_keep = steps_to_keep
+ steps_to_keep = [step for step in steps_to_keep if isinstance(step, int)]
+ steps_to_keep = [step for step in steps_to_keep if step > 0]
+ if not steps_to_keep:
+ raise ValueError(
+ '`steps_to_keep` list must have at least one positive integer')
+ elif self._input_steps_to_keep != steps_to_keep:
+ tf_logging.warn('Changed `steps_to_keep` by omitting non-integer elements'
+ ' and elements less than 1; new value: [%s]',
+ ', '.join(str(step) for step in steps_to_keep))
+ self._steps_to_keep = sorted(steps_to_keep)
+ self._steps_kept = []
+
+ @property
+ def name(self):
+ return self._saved_model_exporter.name
+
+ def export(self, estimator, export_path, checkpoint_path, eval_result,
+ is_the_final_export):
+ """Exports the given Estimator to a specific format.
+
+ Args:
+ estimator: A `tf.estimator.Estimator` instance to export.
+ export_path: A string containing a directory where to write the export.
+ checkpoint_path: The checkpoint path to export.
+ eval_result: The output of Estimator.evaluate on this checkpoint.
+ is_the_final_export: This boolean is True when this is an export at the
+ end of training, and False for intermediate exports during training.
+ When passing an Exporter to tf.estimator.train_and_evaluate,
+ is_the_final_export is always False if TrainSpec.max_steps is None.
+
+ Returns:
+ The string path to the exported directory or None if export is skipped.
+
+ Raises:
+ ValueError: If `eval_result` is None or doesn't have
+ `ops.GraphKeys.GLOBAL_STEP` as a key.
+ """
+ export_result = None
+
+ if not eval_result or DEFAULT_GLOBAL_STEP_KEY not in eval_result:
+ raise ValueError(
+ '`eval_result` is empty, or does not have global step. This'
+ ' should never happen as Estimator always sets the global step in '
+ '`eval_result`. Please file a bug report. Got eval_result: %s'
+ % str(eval_result))
+
+ if self._model_dir != estimator.model_dir and self._event_file_pattern:
+ tf_logging.info('Loading the steps at which the model was already '
+ 'evaluated from event files.')
+ self._model_dir = estimator.model_dir
+ full_event_file_pattern = os.path.join(self._model_dir,
+ self._event_file_pattern)
+ self._steps_kept = self._get_kept_steps(full_event_file_pattern)
+
+ if self._steps_kept:
+ self._steps_kept = sorted(self._steps_kept)
+ self._steps_to_keep = [step for step in self._steps_to_keep if
+ step > self._steps_kept[-1]]
+ # It is assumed that the model was exported at any evaluated step `n` if
+ # there is any step in `_steps_to_keep` at or below `n`. As a result, all
+ # steps in `_steps_to_keep` at or below the last evaluated step are removed.
+ steps_missed = [step for step in self._steps_to_keep
+ if step <= eval_result[DEFAULT_GLOBAL_STEP_KEY]]
+
+ if steps_missed:
+ # Update the `_steps_to_keep` list by removing all missed steps at or
+ # below the current global step, since this export covers them.
+ export_result = self._saved_model_exporter.export(estimator, export_path,
+ checkpoint_path,
+ eval_result,
+ is_the_final_export)
+ self._steps_to_keep = [step for step in self._steps_to_keep if step
+ not in steps_missed]
+ # Record this step; `_steps_kept` contains all steps at which an export
+ # has happened.
+ self._steps_kept.append(eval_result[DEFAULT_GLOBAL_STEP_KEY])
+ # Show warning for all the missed steps except the last one
+ if steps_missed[:-1]:
+ tf_logging.warn('Missed steps [%s] for exporting, as no evaluation'
+ ' took place at them.', ', '.join(str(step) for step in
+ steps_missed[:-1]))
+ # Log model export if the last missed step is the same as the current step
+ if steps_missed[-1] == eval_result[DEFAULT_GLOBAL_STEP_KEY]:
+ tf_logging.info('Performing model export at step %d.',
+ eval_result[DEFAULT_GLOBAL_STEP_KEY])
+ # Warn when the model is exported at a step other than the
+ # user-specified one.
+ else:
+ tf_logging.warn('Performing model export at step %d instead of %d, as'
+ ' no evaluation took place at step %d.',
+ eval_result[DEFAULT_GLOBAL_STEP_KEY], steps_missed[-1],
+ steps_missed[-1])
+ return export_result
+
+ def _get_kept_steps(self, event_files):
+ """Get the steps that the model was evaluated at, from event files.
+
+ Args:
+ event_files: Absolute pattern of event files.
+
+ Returns:
+ steps_kept: A list of steps in which the model was evaluated.
+ """
+ if not event_files:
+ return None
+
+ steps_kept = []
+ for event_file in gfile.Glob(os.path.join(event_files)):
+ for event in summary_iterator.summary_iterator(event_file):
+ if event.step not in steps_kept:
+ steps_kept.append(event.step)
+ return steps_kept
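The lower-bound semantics of `steps_to_keep` can be checked with a pure-Python simulation of the missed-steps bookkeeping (illustrative only, not the class's API):

```python
def exported_steps(steps_to_keep, evaluated_steps):
    """Return the evaluation steps at which StepsExporter would export."""
    remaining = sorted(steps_to_keep)
    exports = []
    for step in sorted(evaluated_steps):
        missed = [s for s in remaining if s <= step]
        if missed:  # Some requested step was reached or passed: export once.
            exports.append(step)
            remaining = [s for s in remaining if s not in missed]
    return exports

# The example from the docstring above:
print(exported_steps([1, 2, 3, 6, 7, 10, 12, 25],
                     [5, 10, 15, 20, 25, 30]))  # -> [5, 10, 15, 25]
```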
diff --git a/tensorflow/contrib/estimator/python/estimator/exporter_test.py b/tensorflow/contrib/estimator/python/estimator/exporter_test.py
new file mode 100644
index 0000000000..0d009b945e
--- /dev/null
+++ b/tensorflow/contrib/estimator/python/estimator/exporter_test.py
@@ -0,0 +1,206 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for `StepsExporter`."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import shutil
+import tempfile
+
+from tensorflow.contrib.estimator.python.estimator import exporter as exporter_lib
+from tensorflow.python.estimator import estimator as estimator_lib
+from tensorflow.python.platform import gfile
+from tensorflow.python.platform import test
+
+
+class StepsExporterTest(test.TestCase):
+
+ def test_error_out_if_steps_to_keep_has_no_positive_integers(self):
+
+ def _serving_input_receiver_fn():
+ pass
+
+ with self.assertRaisesRegexp(ValueError, "positive integer"):
+ exporter = exporter_lib.StepsExporter(
+ name="specified_steps_exporter",
+ serving_input_receiver_fn=_serving_input_receiver_fn,
+ steps_to_keep=[-1, 0, 1.1])
+ self.assertEqual("specified_steps_exporter", exporter.name)
+
+ def test_steps_exporter(self):
+
+ def _serving_input_receiver_fn():
+ pass
+
+ export_dir_base = tempfile.mkdtemp()
+ gfile.MkDir(export_dir_base)
+ gfile.MkDir(export_dir_base + "/export")
+ gfile.MkDir(export_dir_base + "/eval")
+
+ exporter = exporter_lib.StepsExporter(
+ name="steps_exporter",
+ serving_input_receiver_fn=_serving_input_receiver_fn,
+ assets_extra={"from/path": "to/path"},
+ as_text=False,
+ steps_to_keep=[1])
+ estimator = test.mock.Mock(spec=estimator_lib.Estimator)
+ estimator.export_savedmodel.return_value = "export_result_path"
+ estimator.model_dir = export_dir_base
+
+ export_result = exporter.export(estimator, export_dir_base,
+ "checkpoint_path", {"global_step": 1},
+ False)
+
+ self.assertEqual("export_result_path", export_result)
+ estimator.export_savedmodel.assert_called_with(
+ export_dir_base,
+ _serving_input_receiver_fn,
+ assets_extra={"from/path": "to/path"},
+ as_text=False,
+ checkpoint_path="checkpoint_path",
+ strip_default_attrs=True)
+
+ shutil.rmtree(export_dir_base, ignore_errors=True)
+
+ def test_steps_exporter_with_preemption(self):
+
+ def _serving_input_receiver_fn():
+ pass
+
+ export_dir_base = tempfile.mkdtemp()
+ gfile.MkDir(export_dir_base)
+ gfile.MkDir(export_dir_base + "/export")
+ gfile.MkDir(export_dir_base + "/eval")
+
+ eval_dir_base = os.path.join(export_dir_base, "eval_continuous")
+ estimator_lib._write_dict_to_summary(eval_dir_base, {}, 1)
+ estimator_lib._write_dict_to_summary(eval_dir_base, {}, 2)
+
+ exporter = exporter_lib.StepsExporter(
+ name="steps_exporter",
+ serving_input_receiver_fn=_serving_input_receiver_fn,
+ event_file_pattern="eval_continuous/*.tfevents.*",
+ assets_extra={"from/path": "to/path"},
+ as_text=False,
+ steps_to_keep=[1, 2, 6, 8])
+
+ estimator = test.mock.Mock(spec=estimator_lib.Estimator)
+ estimator.model_dir = export_dir_base
+ estimator.export_savedmodel.return_value = "export_result_path"
+
+ export_result = exporter.export(estimator, export_dir_base,
+ "checkpoint_path", {"global_step": 3},
+ False)
+ self.assertEqual(None, export_result)
+
+ export_result = exporter.export(estimator, export_dir_base,
+ "checkpoint_path", {"global_step": 6},
+ False)
+ self.assertEqual("export_result_path", export_result)
+
+ export_result = exporter.export(estimator, export_dir_base,
+ "checkpoint_path", {"global_step": 7},
+ False)
+ self.assertEqual(None, export_result)
+
+ shutil.rmtree(export_dir_base, ignore_errors=True)
+
+ def test_specified_step_is_saved(self):
+
+ def _serving_input_receiver_fn():
+ pass
+
+ export_dir_base = tempfile.mkdtemp()
+ gfile.MkDir(export_dir_base)
+ gfile.MkDir(export_dir_base + "/export")
+ gfile.MkDir(export_dir_base + "/eval")
+
+ exporter = exporter_lib.StepsExporter(
+ name="steps_exporter",
+ serving_input_receiver_fn=_serving_input_receiver_fn,
+ assets_extra={"from/path": "to/path"},
+ as_text=False,
+ steps_to_keep=[1, 5, 8, 10, 11])
+ estimator = test.mock.Mock(spec=estimator_lib.Estimator)
+ estimator.export_savedmodel.return_value = "export_result_path"
+ estimator.model_dir = export_dir_base
+
+ export_result = exporter.export(estimator, export_dir_base,
+ "checkpoint_path", {"global_step": 1},
+ False)
+
+ self.assertTrue(estimator.export_savedmodel.called)
+ self.assertEqual("export_result_path", export_result)
+
+ export_result = exporter.export(estimator, export_dir_base,
+ "checkpoint_path", {"global_step": 2},
+ False)
+ self.assertEqual(None, export_result)
+
+ export_result = exporter.export(estimator, export_dir_base,
+ "checkpoint_path", {"global_step": 5},
+ False)
+ self.assertTrue(estimator.export_savedmodel.called)
+ self.assertEqual("export_result_path", export_result)
+
+ export_result = exporter.export(estimator, export_dir_base,
+ "checkpoint_path", {"global_step": 10},
+ False)
+ self.assertTrue(estimator.export_savedmodel.called)
+ self.assertEqual("export_result_path", export_result)
+
+ export_result = exporter.export(estimator, export_dir_base,
+ "checkpoint_path", {"global_step": 15},
+ False)
+ self.assertTrue(estimator.export_savedmodel.called)
+ self.assertEqual("export_result_path", export_result)
+
+ export_result = exporter.export(estimator, export_dir_base,
+ "checkpoint_path", {"global_step": 20},
+ False)
+ self.assertEqual(None, export_result)
+
+ shutil.rmtree(export_dir_base, ignore_errors=True)
+
+ def test_steps_exporter_with_no_global_step_key(self):
+
+ def _serving_input_receiver_fn():
+ pass
+
+ export_dir_base = tempfile.mkdtemp()
+ gfile.MkDir(export_dir_base)
+ gfile.MkDir(export_dir_base + "/export")
+ gfile.MkDir(export_dir_base + "/eval")
+
+ exporter = exporter_lib.StepsExporter(
+ name="steps_exporter",
+ serving_input_receiver_fn=_serving_input_receiver_fn,
+ assets_extra={"from/path": "to/path"},
+ as_text=False,
+ steps_to_keep=[1])
+ estimator = test.mock.Mock(spec=estimator_lib.Estimator)
+ estimator.export_savedmodel.return_value = "export_result_path"
+ estimator.model_dir = export_dir_base
+
+ with self.assertRaisesRegexp(ValueError, "does not have global step"):
+ exporter.export(estimator, export_dir_base, "checkpoint_path", {}, False)
+
+ shutil.rmtree(export_dir_base, ignore_errors=True)
+
+
+if __name__ == "__main__":
+ test.main()
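In normal use, `StepsExporter` is attached to an `EvalSpec` rather than having `export` called directly as in the tests above. A minimal sketch, assuming TF 1.x; the serving signature, step list, and the commented-out spec wiring are made-up placeholders, not part of the patch:

```python
import tensorflow as tf
from tensorflow.contrib.estimator.python.estimator import exporter as exporter_lib

def serving_input_receiver_fn():
  # Hypothetical serving signature with a single float feature "x".
  features = {'x': tf.placeholder(tf.float32, [None, 1], name='x')}
  return tf.estimator.export.ServingInputReceiver(features, features)

exporter = exporter_lib.StepsExporter(
    name='steps_exporter',
    serving_input_receiver_fn=serving_input_receiver_fn,
    steps_to_keep=[1000, 5000, 10000])

# With estimator/train_spec/eval_input_fn defined elsewhere:
# eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn,
#                                   exporters=[exporter])
# tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
```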
diff --git a/tensorflow/contrib/estimator/python/estimator/hooks.py b/tensorflow/contrib/estimator/python/estimator/hooks.py
index caadafdfa6..faefda7c48 100644
--- a/tensorflow/contrib/estimator/python/estimator/hooks.py
+++ b/tensorflow/contrib/estimator/python/estimator/hooks.py
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
import os
+import time
from tensorflow.python.estimator import estimator as estimator_lib
from tensorflow.python.framework import ops
@@ -26,6 +27,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.training import training
+from tensorflow.python.training import training_util
# pylint: disable=protected-access
@@ -210,4 +212,55 @@ class InMemoryEvaluatorHook(training.SessionRunHook):
self._evaluate(session)
+class StopAtCheckpointStepHook(training.SessionRunHook):
+ """Hook that requests stop at a specified step based on checkpoint."""
+
+ def __init__(self, model_dir, last_step,
+ wait_after_file_check_secs=30):
+ """Initializes a `StopAtCheckpointStepHook`.
+
+    This hook requests stop after the last step has been reached. It checks
+    the latest checkpoint to verify whether the last step has been written to
+    disk.
+
+    Args:
+      model_dir: Directory from which to read the global step in the latest
+        checkpoint.
+      last_step: Step after which to stop.
+      wait_after_file_check_secs: Reading the same file from many workers may
+        create I/O issues. To throttle that, we wait the given number of
+        seconds after each read of the file.
+
+ Raises:
+ ValueError: If one of the arguments is invalid.
+ """
+ if last_step is None:
+ raise ValueError('last_step must be specified.')
+ if model_dir is None:
+ raise ValueError('model_dir must be specified.')
+
+ self._model_dir = model_dir
+ self._last_step = last_step
+ self._wait_after_file_check_secs = wait_after_file_check_secs
+
+ def begin(self):
+ self._global_step_tensor = training_util._get_or_create_global_step_read() # pylint: disable=protected-access
+ if self._global_step_tensor is None:
+ raise RuntimeError(
+ 'Global step should be created to use StopAtCheckpointStepHook.')
+
+ def before_run(self, run_context): # pylint: disable=unused-argument
+ return training.SessionRunArgs(self._global_step_tensor)
+
+ def after_run(self, run_context, run_values):
+ global_step = run_values.results + 1
+ if global_step >= self._last_step:
+      # Check the latest global step in the checkpoint to ensure that the
+      # targeted last step has been written to disk.
+
+ step = estimator_lib._load_global_step_from_checkpoint_dir(
+ self._model_dir)
+ if step >= self._last_step:
+ run_context.request_stop()
+ else:
+ time.sleep(self._wait_after_file_check_secs)
+
# pylint: enable=protected-access
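A minimal usage sketch for the new hook (not part of the patch; `estimator` and `train_input_fn` are assumed to exist):

```python
# Stop this worker once the checkpoint on disk has reached step 10000.
from tensorflow.contrib.estimator.python.estimator import hooks as hooks_lib

stop_hook = hooks_lib.StopAtCheckpointStepHook(
    model_dir='/tmp/model',         # directory the chief writes checkpoints to
    last_step=10000,
    wait_after_file_check_secs=30)  # throttle re-reads of the checkpoint file

# estimator.train(input_fn=train_input_fn, hooks=[stop_hook])
```

Because the hook only requests stop once the checkpoint itself has caught up, workers cannot outrun the chief and exit before the final state is saved.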
diff --git a/tensorflow/contrib/estimator/python/estimator/hooks_test.py b/tensorflow/contrib/estimator/python/estimator/hooks_test.py
index ee88d5ecf5..42352aa3ff 100644
--- a/tensorflow/contrib/estimator/python/estimator/hooks_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/hooks_test.py
@@ -21,8 +21,11 @@ from __future__ import print_function
import glob
import json
import os
+import tempfile
+import time
from tensorflow.contrib.estimator.python.estimator import hooks as hooks_lib
+from tensorflow.python.client import session as tf_session
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.estimator import estimator_lib
from tensorflow.python.estimator import run_config as run_config_lib
@@ -316,5 +319,59 @@ class InMemoryEvaluatorHookTest(test.TestCase):
estimator.train(input_fn, hooks=[evaluator])
+class StopAtCheckpointStepHookTest(test.TestCase):
+
+ def test_do_not_stop_if_checkpoint_is_not_there(self):
+ with ops.Graph().as_default():
+ step = training.create_global_step()
+ assign_ten = step.assign(10)
+ no_op = control_flow_ops.no_op()
+ hook = hooks_lib.StopAtCheckpointStepHook(
+ model_dir=tempfile.mkdtemp(), last_step=10)
+ with training.SingularMonitoredSession(hooks=[hook]) as mon_sess:
+ mon_sess.raw_session().run(assign_ten)
+ with test.mock.patch.object(time, 'sleep') as mock_sleep:
+ mon_sess.run(no_op)
+ self.assertTrue(mock_sleep.called)
+ self.assertFalse(mon_sess.should_stop())
+
+ def test_do_not_stop_if_checkpoint_step_is_smaller(self):
+ model_dir = tempfile.mkdtemp()
+ with ops.Graph().as_default():
+ step = training.create_global_step()
+ assign_nine = step.assign(9)
+ assign_ten = step.assign(10)
+ no_op = control_flow_ops.no_op()
+ hook = hooks_lib.StopAtCheckpointStepHook(
+ model_dir=model_dir, last_step=10)
+ with tf_session.Session() as sess:
+ sess.run(assign_nine)
+ training.Saver().save(sess, os.path.join(model_dir, 'model.ckpt'))
+ with training.SingularMonitoredSession(hooks=[hook]) as mon_sess:
+ mon_sess.raw_session().run(assign_ten)
+ with test.mock.patch.object(time, 'sleep') as mock_sleep:
+ mon_sess.run(no_op)
+ self.assertTrue(mock_sleep.called)
+ self.assertFalse(mon_sess.should_stop())
+
+ def test_stop_if_checkpoint_step_is_laststep(self):
+ model_dir = tempfile.mkdtemp()
+ with ops.Graph().as_default():
+ step = training.create_global_step()
+ assign_ten = step.assign(10)
+ no_op = control_flow_ops.no_op()
+ hook = hooks_lib.StopAtCheckpointStepHook(
+ model_dir=model_dir, last_step=10)
+ with tf_session.Session() as sess:
+ sess.run(assign_ten)
+ training.Saver().save(sess, os.path.join(model_dir, 'model.ckpt'))
+ with training.SingularMonitoredSession(hooks=[hook]) as mon_sess:
+ mon_sess.raw_session().run(assign_ten)
+ with test.mock.patch.object(time, 'sleep') as mock_sleep:
+ mon_sess.run(no_op)
+ self.assertFalse(mock_sleep.called)
+ self.assertTrue(mon_sess.should_stop())
+
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/factorization/python/ops/kmeans.py b/tensorflow/contrib/factorization/python/ops/kmeans.py
index 4d8d5004fe..f384d761a8 100644
--- a/tensorflow/contrib/factorization/python/ops/kmeans.py
+++ b/tensorflow/contrib/factorization/python/ops/kmeans.py
@@ -188,7 +188,6 @@ class _ModelFn(object):
# center.
# is_initialized: scalar indicating whether the initial cluster centers
# have been chosen; see init_op.
- # cluster_centers_var: a Variable containing the cluster centers.
# init_op: an op to choose the initial cluster centers. A single worker
# repeatedly executes init_op until is_initialized becomes True.
# training_op: an op that runs an iteration of training, either an entire
diff --git a/tensorflow/contrib/gan/BUILD b/tensorflow/contrib/gan/BUILD
index 82e3bbe3c0..9866fccfba 100644
--- a/tensorflow/contrib/gan/BUILD
+++ b/tensorflow/contrib/gan/BUILD
@@ -424,9 +424,11 @@ py_library(
":namedtuples",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_ops",
+ "//tensorflow/python:functional_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:summary",
"//tensorflow/python:util",
+ "//tensorflow/python:variable_scope",
"//tensorflow/python/ops/losses",
],
)
diff --git a/tensorflow/contrib/gan/python/eval/python/summaries_impl.py b/tensorflow/contrib/gan/python/eval/python/summaries_impl.py
index 508f487722..f9995bb19d 100644
--- a/tensorflow/contrib/gan/python/eval/python/summaries_impl.py
+++ b/tensorflow/contrib/gan/python/eval/python/summaries_impl.py
@@ -22,7 +22,9 @@ from tensorflow.contrib.gan.python import namedtuples
from tensorflow.contrib.gan.python.eval.python import eval_utils
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import variable_scope
from tensorflow.python.ops.losses import util as loss_util
from tensorflow.python.summary import summary
@@ -32,6 +34,7 @@ __all__ = [
'add_gan_model_summaries',
'add_regularization_loss_summaries',
'add_cyclegan_image_summaries',
+ 'add_stargan_image_summaries'
]
@@ -179,6 +182,94 @@ def add_image_comparison_summaries(gan_model, num_comparisons=2,
max_outputs=1)
+def add_stargan_image_summaries(stargan_model,
+ num_images=2,
+ display_diffs=False):
+ """Adds image summaries to see StarGAN image results.
+
+ If display_diffs is True, each image result has `2` rows and `num_domains + 1`
+ columns.
+ The first row looks like:
+ [original_image, transformed_to_domain_0, transformed_to_domain_1, ...]
+ The second row looks like:
+ [no_modification_baseline, transformed_to_domain_0-original_image, ...]
+ If display_diffs is False, only the first row is shown.
+
+ IMPORTANT:
+  Since the model does not originally transform the image to every domain,
+  we transform it on-the-fly within this function, in parallel across domains.
+
+ Args:
+ stargan_model: A StarGANModel tuple.
+ num_images: The number of examples/images to be transformed and shown.
+ display_diffs: Also display the difference between generated and target.
+
+ Raises:
+ ValueError: If input_data is not images.
+ ValueError: If input_data_domain_label is not rank 2.
+ ValueError: If dimension 2 of input_data_domain_label is not fully defined.
+ """
+
+ _assert_is_image(stargan_model.input_data)
+ stargan_model.input_data_domain_label.shape.assert_has_rank(2)
+ stargan_model.input_data_domain_label.shape[1:].assert_is_fully_defined()
+
+ num_domains = stargan_model.input_data_domain_label.get_shape().as_list()[-1]
+
+ def _build_image(image):
+ """Helper function to create a result for each image on the fly."""
+
+ # Expand the first dimension as batch_size = 1.
+ images = array_ops.expand_dims(image, axis=0)
+
+    # Tile the image num_domains times so all domain transforms run in one batch.
+ images = array_ops.tile(images, [num_domains, 1, 1, 1])
+
+    # Create one-hot targets for domains 0, 1, 2, ..., num_domains-1.
+ targets = array_ops.one_hot(list(range(num_domains)), num_domains)
+
+ with variable_scope.variable_scope(
+ stargan_model.generator_scope, reuse=True):
+
+ # Add the original image.
+ output_images_list = [image]
+
+ # Generate the image and add to the list.
+ gen_images = stargan_model.generator_fn(images, targets)
+ gen_images_list = array_ops.split(gen_images, num_domains)
+ gen_images_list = [
+ array_ops.squeeze(img, axis=0) for img in gen_images_list
+ ]
+ output_images_list.extend(gen_images_list)
+
+ # Display diffs.
+ if display_diffs:
+ diff_images = gen_images - images
+ diff_images_list = array_ops.split(diff_images, num_domains)
+ diff_images_list = [
+ array_ops.squeeze(img, axis=0) for img in diff_images_list
+ ]
+ output_images_list.append(array_ops.zeros_like(image))
+ output_images_list.extend(diff_images_list)
+
+ # Create the final image.
+ final_image = eval_utils.image_reshaper(
+ output_images_list, num_cols=num_domains + 1)
+
+    # Remove the leading batch dimension.
+ return array_ops.squeeze(final_image, axis=0)
+
+ summary.image(
+ 'stargan_image_generation',
+ functional_ops.map_fn(
+ _build_image,
+ stargan_model.input_data[:num_images],
+ parallel_iterations=num_images,
+ back_prop=False,
+ swap_memory=True),
+ max_outputs=num_images)
+
+
def add_gan_model_summaries(gan_model):
"""Adds typical GANModel summaries.
diff --git a/tensorflow/contrib/gan/python/eval/python/summaries_test.py b/tensorflow/contrib/gan/python/eval/python/summaries_test.py
index 33d51bfc21..54a6f8d4d9 100644
--- a/tensorflow/contrib/gan/python/eval/python/summaries_test.py
+++ b/tensorflow/contrib/gan/python/eval/python/summaries_test.py
@@ -18,7 +18,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-
from tensorflow.contrib.gan.python import namedtuples
from tensorflow.contrib.gan.python.eval.python import summaries_impl as summaries
from tensorflow.python.framework import ops
@@ -37,6 +36,10 @@ def discriminator_model(inputs, _):
return variable_scope.get_variable('dummy_d', initializer=2.0) * inputs
+def stargan_generator_model(inputs, _):
+ return generator_model(inputs)
+
+
def get_gan_model():
# TODO(joelshor): Find a better way of creating a variable scope.
with variable_scope.variable_scope('generator') as gen_scope:
@@ -57,6 +60,31 @@ def get_gan_model():
discriminator_fn=discriminator_model)
+def get_stargan_model():
+ """Similar to get_gan_model()."""
+ # TODO(joelshor): Find a better way of creating a variable scope.
+ with variable_scope.variable_scope('discriminator') as dis_scope:
+ pass
+ with variable_scope.variable_scope('generator') as gen_scope:
+ return namedtuples.StarGANModel(
+ input_data=array_ops.ones([1, 2, 2, 3]),
+ input_data_domain_label=array_ops.ones([1, 2]),
+ generated_data=stargan_generator_model(
+ array_ops.ones([1, 2, 2, 3]), None),
+ generated_data_domain_target=array_ops.ones([1, 2]),
+ reconstructed_data=array_ops.ones([1, 2, 2, 3]),
+ discriminator_input_data_source_predication=array_ops.ones([1]),
+ discriminator_generated_data_source_predication=array_ops.ones([1]),
+ discriminator_input_data_domain_predication=array_ops.ones([1, 2]),
+ discriminator_generated_data_domain_predication=array_ops.ones([1, 2]),
+ generator_variables=None,
+ generator_scope=gen_scope,
+ generator_fn=stargan_generator_model,
+ discriminator_variables=None,
+ discriminator_scope=dis_scope,
+ discriminator_fn=discriminator_model)
+
+
def get_cyclegan_model():
with variable_scope.variable_scope('x2y'):
model_x2y = get_gan_model()
@@ -143,6 +171,16 @@ class SummariesTest(test.TestCase):
with self.test_session(use_gpu=True):
summary.merge_all().eval()
+  def test_add_stargan_image_summaries(self):
+
+ summaries.add_stargan_image_summaries(get_stargan_model())
+
+    self.assertEqual(1, len(ops.get_collection(ops.GraphKeys.SUMMARIES)))
+
+ with self.test_session(use_gpu=True) as sess:
+ sess.run(variables.global_variables_initializer())
+ summary.merge_all().eval()
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/gan/python/train.py b/tensorflow/contrib/gan/python/train.py
index 03f52d214b..9e5aea1498 100644
--- a/tensorflow/contrib/gan/python/train.py
+++ b/tensorflow/contrib/gan/python/train.py
@@ -52,7 +52,6 @@ from tensorflow.python.training import session_run_hook
from tensorflow.python.training import sync_replicas_optimizer
from tensorflow.python.training import training_util
-
__all__ = [
'gan_model',
'infogan_model',
@@ -61,6 +60,7 @@ __all__ = [
'stargan_model',
'gan_loss',
'cyclegan_loss',
+ 'stargan_loss',
'gan_train_ops',
'gan_train',
'get_sequential_train_hooks',
@@ -646,8 +646,9 @@ def gan_loss(
type(model))
# Optionally create pooled model.
- pooled_model = (_tensor_pool_adjusted_model(model, tensor_pool_fn) if
- tensor_pool_fn else model)
+ pooled_model = (
+ _tensor_pool_adjusted_model(model, tensor_pool_fn)
+ if tensor_pool_fn else model)
# Create standard losses.
gen_loss = generator_loss_fn(model, add_summaries=add_summaries)
@@ -665,9 +666,10 @@ def gan_loss(
if _use_aux_loss(mutual_information_penalty_weight):
gen_info_loss = tfgan_losses.mutual_information_penalty(
model, add_summaries=add_summaries)
- dis_info_loss = (gen_info_loss if tensor_pool_fn is None else
- tfgan_losses.mutual_information_penalty(
- pooled_model, add_summaries=add_summaries))
+ dis_info_loss = (
+ gen_info_loss
+ if tensor_pool_fn is None else tfgan_losses.mutual_information_penalty(
+ pooled_model, add_summaries=add_summaries))
gen_loss += mutual_information_penalty_weight * gen_info_loss
dis_loss += mutual_information_penalty_weight * dis_info_loss
if _use_aux_loss(aux_cond_generator_weight):
diff --git a/tensorflow/contrib/hadoop/kernels/hadoop_dataset_ops.cc b/tensorflow/contrib/hadoop/kernels/hadoop_dataset_ops.cc
index b510994152..80b2d3e08b 100644
--- a/tensorflow/contrib/hadoop/kernels/hadoop_dataset_ops.cc
+++ b/tensorflow/contrib/hadoop/kernels/hadoop_dataset_ops.cc
@@ -204,11 +204,11 @@ class SequenceFileDatasetOp : public DatasetOpKernel {
}
private:
- class Dataset : public GraphDatasetBase {
+ class Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, const std::vector<string>& filenames,
const DataTypeVector& output_types)
- : GraphDatasetBase(ctx),
+ : DatasetBase(DatasetContext(ctx)),
filenames_(filenames),
output_types_(output_types) {}
@@ -233,7 +233,8 @@ class SequenceFileDatasetOp : public DatasetOpKernel {
}
protected:
- Status AsGraphDefInternal(DatasetGraphDefBuilder* b,
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
Node** output) const override {
Node* filenames = nullptr;
TF_RETURN_IF_ERROR(b->AddVector(filenames_, &filenames));
diff --git a/tensorflow/contrib/integrate/python/ops/odes.py b/tensorflow/contrib/integrate/python/ops/odes.py
index 61f78febfc..7b7ac4f347 100644
--- a/tensorflow/contrib/integrate/python/ops/odes.py
+++ b/tensorflow/contrib/integrate/python/ops/odes.py
@@ -73,7 +73,7 @@ def _scaled_dot_product(scale, xs, ys, name=None):
# _possibly_nonzero lets us avoid wasted computation.
return math_ops.add_n(
[(scale * x) * y for x, y in zip(xs, ys)
- if _possibly_nonzero(x) or _possibly_nonzero(y)],
+ if _possibly_nonzero(x) and _possibly_nonzero(y)],
name=scope)
@@ -122,7 +122,7 @@ def _runge_kutta_step(func,
yi = y0 + _scaled_dot_product(dt_cast, beta_i, k)
k.append(func(yi, ti))
- if not (tableau.c_sol[-1] == 0 and tableau.c_sol == tableau.beta[-1]):
+ if not (tableau.c_sol[-1] == 0 and tableau.c_sol[:-1] == tableau.beta[-1]):
# This property (true for Dormand-Prince) lets us save a few FLOPs.
yi = y0 + _scaled_dot_product(dt_cast, tableau.c_sol, k)
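The first hunk above tightens the zero-term filter: a term `(scale * x) * y` can only be nonzero when *both* factors are possibly nonzero, so filtering with `and` skips strictly more dead work than `or` while leaving the sum unchanged. A toy sketch of the idea in plain Python (not the library code):

```python
def _possibly_nonzero(v):
  # Treat literal numeric zeros as known-zero; anything else may be nonzero.
  return not (isinstance(v, (int, float)) and v == 0.0)

def scaled_dot_product(scale, xs, ys):
  # Terms where either factor is a known zero contribute nothing to the sum.
  return sum(scale * x * y for x, y in zip(xs, ys)
             if _possibly_nonzero(x) and _possibly_nonzero(y))

# Only the (3.0, 2.0) pair survives the filter: 2.0 * 3.0 * 2.0 == 12.0.
assert scaled_dot_product(2.0, [0.0, 1.0, 3.0], [5.0, 0.0, 2.0]) == 12.0
```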
diff --git a/tensorflow/contrib/kafka/kernels/kafka_dataset_ops.cc b/tensorflow/contrib/kafka/kernels/kafka_dataset_ops.cc
index 92ae79d3c7..d0ea961473 100644
--- a/tensorflow/contrib/kafka/kernels/kafka_dataset_ops.cc
+++ b/tensorflow/contrib/kafka/kernels/kafka_dataset_ops.cc
@@ -52,12 +52,12 @@ class KafkaDatasetOp : public DatasetOpKernel {
}
private:
- class Dataset : public GraphDatasetBase {
+ class Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, std::vector<string> topics,
const string& servers, const string& group, const bool eof,
const int64 timeout)
- : GraphDatasetBase(ctx),
+ : DatasetBase(DatasetContext(ctx)),
topics_(std::move(topics)),
servers_(servers),
group_(group),
diff --git a/tensorflow/contrib/kfac/examples/convnet.py b/tensorflow/contrib/kfac/examples/convnet.py
index d6b1a61b71..44e01e1aeb 100644
--- a/tensorflow/contrib/kfac/examples/convnet.py
+++ b/tensorflow/contrib/kfac/examples/convnet.py
@@ -202,7 +202,7 @@ def minimize_loss_single_machine(loss,
accuracy: 0-D Tensor. Accuracy of classifier on current minibatch.
layer_collection: LayerCollection instance describing model architecture.
Used by K-FAC to construct preconditioner.
- device: string, Either '/cpu:0' or '/gpu:0'. The covaraince and invserse
+ device: string, Either '/cpu:0' or '/gpu:0'. The covariance and inverse
update ops are run on this device.
session_config: None or tf.ConfigProto. Configuration for tf.Session().
@@ -470,7 +470,7 @@ def train_mnist_single_machine(data_dir,
data_dir: string. Directory to read MNIST examples from.
num_epochs: int. Number of passes to make over the training set.
use_fake_data: bool. If True, generate a synthetic dataset.
- device: string, Either '/cpu:0' or '/gpu:0'. The covaraince and inverse
+ device: string, Either '/cpu:0' or '/gpu:0'. The covariance and inverse
update ops are run on this device.
Returns:
@@ -509,7 +509,7 @@ def train_mnist_multitower(data_dir, num_epochs, num_towers,
num_epochs: int. Number of passes to make over the training set.
num_towers: int. Number of CPUs to split inference across.
use_fake_data: bool. If True, generate a synthetic dataset.
- devices: string, Either list of CPU or GPU. The covaraince and inverse
+ devices: string, Either list of CPU or GPU. The covariance and inverse
update ops are run on this device.
Returns:
@@ -621,7 +621,7 @@ def train_mnist_distributed_sync_replicas(task_id,
data_dir: string. Directory to read MNIST examples from.
num_epochs: int. Number of passes to make over the training set.
op_strategy: `string`, Strategy to run the covariance and inverse
- ops. If op_strategy == `chief_worker` then covaraiance and inverse
+ ops. If op_strategy == `chief_worker` then covariance and inverse
update ops are run on chief worker otherwise they are run on dedicated
workers.
diff --git a/tensorflow/contrib/kfac/python/ops/estimator.py b/tensorflow/contrib/kfac/python/ops/estimator.py
index 854f885c26..323234c403 100644
--- a/tensorflow/contrib/kfac/python/ops/estimator.py
+++ b/tensorflow/contrib/kfac/python/ops/estimator.py
@@ -97,8 +97,8 @@ class FisherEstimator(object):
and to regularize the update direction by making it closer to the
gradient. (Higher damping means the update looks more like a standard
gradient update - see Tikhonov regularization.)
- layer_collection: The layer collection object, which holds the fisher
- blocks, kronecker factors, and losses associated with the
+ layer_collection: The layer collection object, which holds the Fisher
+ blocks, Kronecker factors, and losses associated with the
graph.
exps: List of floats or ints. These represent the different matrix
powers of the approximate Fisher that the FisherEstimator will be able
@@ -464,7 +464,7 @@ class FisherEstimator(object):
def _get_grads_lists_empirical(self, tensors):
# Passing in a list of loss values is better than passing in the sum as
- # the latter creates unnessesary ops on the default device
+ # the latter creates unnecessary ops on the default device
grads_flat = gradients_impl.gradients(
self._layers.eval_losses(),
nest.flatten(tensors),
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
index 3a5c8eb5f9..9fa6eb7dcd 100644
--- a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
+++ b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
@@ -870,7 +870,7 @@ class ConvKFCBasicFB(InputOutputMultiTower, KroneckerProductFB):
Estimates the Fisher Information matrix's blog for a convolutional
layer.
- Consider a convoluational layer in this model with (unshared) filter matrix
+ Consider a convolutional layer in this model with (unshared) filter matrix
'w'. For a minibatch that produces inputs 'a' and output preactivations 's',
this FisherBlock estimates,
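For context, the KFC factorization this block implements can be written schematically as follows; this is the standard form from Grosse & Martens (2016), not quoted from the truncated docstring:

```latex
% |T|: number of spatial locations; \hat{a}: extracted input patches;
% Ds: gradient of the loss with respect to the preactivations s.
F(w) \;\approx\; |T| \cdot
  \mathbb{E}\!\left[\hat{a}\,\hat{a}^{\top}\right] \otimes
  \mathbb{E}\!\left[(Ds)(Ds)^{\top}\right]
```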
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_factors.py b/tensorflow/contrib/kfac/python/ops/fisher_factors.py
index b43232dfaf..afa2fd1ca7 100644
--- a/tensorflow/contrib/kfac/python/ops/fisher_factors.py
+++ b/tensorflow/contrib/kfac/python/ops/fisher_factors.py
@@ -71,15 +71,15 @@ _MAX_NUM_OUTER_PRODUCTS_PER_COV_ROW = 1
# factor. This parameter is used only if `_SUB_SAMPLE_INPUTS` is True.
_INPUTS_TO_EXTRACT_PATCHES_FACTOR = 0.5
-# If True, then subsamples the tensor passed to compute the covaraince matrix.
+# If True, then subsamples the tensor passed to compute the covariance matrix.
_SUB_SAMPLE_OUTER_PRODUCTS = False
-# If True, then subsamples the tensor passed to compute the covaraince matrix.
+# If True, then subsamples the tensor passed to compute the covariance matrix.
_SUB_SAMPLE_INPUTS = False
# TOWER_STRATEGY can be one of "concat" or "separate". If "concat", the data
# passed to the factors from the blocks will be concatenated across towers
-# (lazilly via PartitionedTensor objects). Otherwise a tuple of tensors over
+# (lazily via PartitionedTensor objects). Otherwise a tuple of tensors over
# towers will be passed in, and the factors will iterate over this and do the
# cov computations separately for each one, averaging the results together.
TOWER_STRATEGY = "concat"
@@ -309,7 +309,7 @@ def _subsample_for_cov_computation(array, name=None):
def _random_tensor_gather(array, max_size):
- """Generates a random set of indices and gathers the value at the indcices.
+ """Generates a random set of indices and gathers the value at the indices.
Args:
array: Tensor, of shape `[batch_size, dim_2]`.
@@ -1762,8 +1762,8 @@ class FullyConnectedMultiKF(FullyConnectedKroneckerFactor):
# Might need to enforce symmetry lost due to numerical issues.
invsqrtC0 = (invsqrtC0 + array_ops.transpose(invsqrtC0)) / 2.0
- # The following line imposses the symmetry assumed by "Option 1" on C1.
- # Stangely the code can work okay with this line commented out,
+ # The following line imposes the symmetry assumed by "Option 1" on C1.
+ # Strangely the code can work okay with this line commented out,
# depending on how psd_eig is defined. I'm not sure why.
C1 = (C1 + array_ops.transpose(C1)) / 2.0
diff --git a/tensorflow/contrib/kfac/python/ops/layer_collection.py b/tensorflow/contrib/kfac/python/ops/layer_collection.py
index cbbfe7212c..43aa713edc 100644
--- a/tensorflow/contrib/kfac/python/ops/layer_collection.py
+++ b/tensorflow/contrib/kfac/python/ops/layer_collection.py
@@ -609,7 +609,7 @@ class LayerCollection(object):
outputs,
approx=None,
reuse=VARIABLE_SCOPE):
- """Registers a fully connnected layer.
+ """Registers a fully connected layer.
Args:
params: Tensor or 2-tuple of Tensors corresponding to weight and bias of
@@ -975,7 +975,7 @@ class LayerCollection(object):
block for this layer (which must have already been registered). If
"VARIABLE_SCOPE", use tf.get_variable_scope().reuse. (Note that the
word `use` here has a completely different meaning to "use in the graph"
- as it perturns to the `inputs`, `outputs`, and `num_uses` arguments.)
+ as it pertains to the `inputs`, `outputs`, and `num_uses` arguments.)
(Default: "VARIABLE_SCOPE")
Raises:
@@ -1045,7 +1045,7 @@ class LayerCollection(object):
block for this layer (which must have already been registered). If
"VARIABLE_SCOPE", use tf.get_variable_scope().reuse. (Note that the
word `use` here has a completely different meaning to "use in the graph"
- as it perturns to the `inputs`, `outputs`, and `num_uses` arguments.)
+ as it pertains to the `inputs`, `outputs`, and `num_uses` arguments.)
(Default: "VARIABLE_SCOPE")
Raises:
@@ -1116,7 +1116,7 @@ class LayerCollection(object):
block for this layer (which must have already been registered). If
"VARIABLE_SCOPE", use tf.get_variable_scope().reuse. (Note that the
word `use` here has a completely different meaning to "use in the graph"
- as it perturns to the `inputs`, `outputs`, and `num_uses` arguments.)
+ as it pertains to the `inputs`, `outputs`, and `num_uses` arguments.)
(Default: "VARIABLE_SCOPE")
Raises:
diff --git a/tensorflow/contrib/kfac/python/ops/loss_functions.py b/tensorflow/contrib/kfac/python/ops/loss_functions.py
index 42d525c2c2..c8cebc42cb 100644
--- a/tensorflow/contrib/kfac/python/ops/loss_functions.py
+++ b/tensorflow/contrib/kfac/python/ops/loss_functions.py
@@ -214,7 +214,7 @@ class NegativeLogProbLoss(LossFunction):
Here the 'Fisher' is the Fisher information matrix (i.e. expected outer-
product of gradients) with respect to the parameters of the underlying
- probability distribtion (whose log-prob defines the loss). Typically this
+ probability distribution (whose log-prob defines the loss). Typically this
will be block-diagonal across different cases in the batch, since the
distribution is usually (but not always) conditionally iid across different
cases.
@@ -238,7 +238,7 @@ class NegativeLogProbLoss(LossFunction):
Here the 'Fisher' is the Fisher information matrix (i.e. expected outer-
product of gradients) with respect to the parameters of the underlying
- probability distribtion (whose log-prob defines the loss). Typically this
+ probability distribution (whose log-prob defines the loss). Typically this
will be block-diagonal across different cases in the batch, since the
distribution is usually (but not always) conditionally iid across different
cases.
@@ -262,7 +262,7 @@ class NegativeLogProbLoss(LossFunction):
Here the 'Fisher' is the Fisher information matrix (i.e. expected outer-
product of gradients) with respect to the parameters of the underlying
- probability distribtion (whose log-prob defines the loss). Typically this
+ probability distribution (whose log-prob defines the loss). Typically this
will be block-diagonal across different cases in the batch, since the
distribution is usually (but not always) conditionally iid across different
cases.
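In symbols, the 'Fisher' these docstrings describe is the standard Fisher information matrix, the expected outer product of log-probability gradients under the model's own distribution:

```latex
F(\theta) \;=\; \mathbb{E}_{x \sim p_{\theta}}\!\left[
  \nabla_{\theta} \log p_{\theta}(x)\,
  \nabla_{\theta} \log p_{\theta}(x)^{\top}\right]
```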
diff --git a/tensorflow/contrib/kfac/python/ops/optimizer.py b/tensorflow/contrib/kfac/python/ops/optimizer.py
index 03b9da7933..38605259b5 100644
--- a/tensorflow/contrib/kfac/python/ops/optimizer.py
+++ b/tensorflow/contrib/kfac/python/ops/optimizer.py
@@ -72,7 +72,7 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
(Higher damping means the update looks more like a standard gradient
update - see Tikhonov regularization.)
layer_collection: The layer collection object, which holds the fisher
- blocks, kronecker factors, and losses associated with the
+ blocks, Kronecker factors, and losses associated with the
graph. The layer_collection cannot be modified after KfacOptimizer's
initialization.
var_list: Optional list or tuple of variables to train. Defaults to the
@@ -99,7 +99,7 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
placement_strategy: string, Device placement strategy used when creating
covariance variables, covariance ops, and inverse ops.
(Default: `None`)
- **kwargs: Arguments to be passesd to specific placement
+ **kwargs: Arguments to be passed to specific placement
strategy mixin. Check `placement.RoundRobinPlacementMixin` for example.
Raises:
@@ -120,7 +120,7 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
self._estimation_mode = estimation_mode
self._colocate_gradients_with_ops = colocate_gradients_with_ops
- # The below parameters are required only if damping needs to be adapated.
+ # The below parameters are required only if damping needs to be adapted.
# These parameters can be set by calling
# set_damping_adaptation_params() explicitly.
self._damping_adaptation_decay = 0.95
@@ -574,7 +574,7 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
"""Wrapper function for `self._compute_qmodel_hyperparams`.
Constructs a list of preconditioned gradients and variables. Also creates a
- op to asssign the computed q model change to `self._q_model_change`.
+ op to assign the computed q model change to `self._q_model_change`.
Args:
grads_and_vars: List of (gradient, variable) pairs.
diff --git a/tensorflow/contrib/kinesis/kernels/kinesis_dataset_ops.cc b/tensorflow/contrib/kinesis/kernels/kinesis_dataset_ops.cc
index 7b28bb5e4d..95c7001371 100644
--- a/tensorflow/contrib/kinesis/kernels/kinesis_dataset_ops.cc
+++ b/tensorflow/contrib/kinesis/kernels/kinesis_dataset_ops.cc
@@ -164,11 +164,11 @@ class KinesisDatasetOp : public DatasetOpKernel {
}
private:
- class Dataset : public GraphDatasetBase {
+ class Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, const string& stream, const string& shard,
const bool read_indefinitely, const int64 interval)
- : GraphDatasetBase(ctx),
+ : DatasetBase(DatasetContext(ctx)),
stream_(stream),
shard_(shard),
read_indefinitely_(read_indefinitely),
diff --git a/tensorflow/contrib/layers/python/layers/initializers.py b/tensorflow/contrib/layers/python/layers/initializers.py
index 1192198ec2..655f038b18 100644
--- a/tensorflow/contrib/layers/python/layers/initializers.py
+++ b/tensorflow/contrib/layers/python/layers/initializers.py
@@ -111,7 +111,7 @@ def variance_scaling_initializer(factor=2.0, mode='FAN_IN', uniform=False,
if not dtype.is_floating:
raise TypeError('Cannot create initializer for non-floating point type.')
if mode not in ['FAN_IN', 'FAN_OUT', 'FAN_AVG']:
- raise TypeError('Unknow mode %s [FAN_IN, FAN_OUT, FAN_AVG]', mode)
+ raise TypeError('Unknown mode %s [FAN_IN, FAN_OUT, FAN_AVG]', mode)
# pylint: disable=unused-argument
def _initializer(shape, dtype=dtype, partition_info=None):
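For reference, a usage sketch of the initializer whose error message is fixed above (TF 1.x contrib API; the variable name and shape are made up):

```python
import tensorflow as tf
from tensorflow.contrib.layers.python.layers import initializers

# factor=2.0 with mode='FAN_IN' gives He-style initialization.
init = initializers.variance_scaling_initializer(
    factor=2.0, mode='FAN_IN', uniform=False)
w = tf.get_variable('w', shape=[784, 256], initializer=init)

# Any other mode string now raises:
#   TypeError: Unknown mode %s [FAN_IN, FAN_OUT, FAN_AVG]
```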
diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD
index d3aa3fa92c..418b0cf392 100644
--- a/tensorflow/contrib/learn/BUILD
+++ b/tensorflow/contrib/learn/BUILD
@@ -108,7 +108,6 @@ py_test(
size = "small",
srcs = ["python/learn/learn_io/data_feeder_test.py"],
srcs_version = "PY2AND3",
- tags = ["no_windows"], # TODO: needs investigation on Windows
deps = [
":learn",
"//tensorflow/python:client_testlib",
@@ -164,7 +163,6 @@ tf_py_test(
"//tensorflow/python:variables",
"//tensorflow/python/estimator:estimator_py",
],
- tags = ["no_windows"], # TODO: needs investigation on Windows
)
py_test(
@@ -591,7 +589,6 @@ py_test(
size = "small",
srcs = ["python/learn/learn_io/io_test.py"],
srcs_version = "PY2AND3",
- tags = ["no_windows"], # TODO: needs investigation on Windows
deps = [
":learn",
"//tensorflow/contrib/learn/python/learn/datasets",
diff --git a/tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py b/tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py
index 9872c6f97c..8ebe45d851 100644
--- a/tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py
+++ b/tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py
@@ -158,7 +158,7 @@ class SDCAOptimizer(object):
# exactly 2 (i.e., its shape should be [batch_size, column.dim]).
check_rank_op = control_flow_ops.Assert(
math_ops.less_equal(array_ops.rank(transformed_tensor), 2),
- ['transformed_tensor shouls have rank at most 2.'])
+ ['transformed_tensor should have rank at most 2.'])
# Reshape to [batch_size, dense_column_dimension].
with ops.control_dependencies([check_rank_op]):
transformed_tensor = array_ops.reshape(transformed_tensor, [
@@ -172,7 +172,7 @@ class SDCAOptimizer(object):
elif isinstance(column, layers.feature_column._BucketizedColumn): # pylint: disable=protected-access
# A bucketized column corresponds to a sparse feature in SDCA. The
# bucketized feature is "sparsified" for SDCA by converting it to a
- # SparseFeatureColumn respresenting the one-hot encoding of the
+ # SparseFeatureColumn representing the one-hot encoding of the
# bucketized feature.
#
# TODO(sibyl-vie3Poto): Explore whether it is more efficient to translate a
@@ -220,7 +220,7 @@ class SDCAOptimizer(object):
# occur multiple times for a single example.
projected_ids = projection_length * example_ids + flat_ids
- # Remove any redudant ids.
+ # Remove any redundant ids.
ids, idx = array_ops.unique(projected_ids)
# Keep only one example id per duplicated ids.
example_ids_filtered = math_ops.unsorted_segment_min(
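The last hunk's projection folds each `(example_id, feature_id)` pair into a single integer key so duplicates can be dropped with one `unique` pass; a toy numpy sketch of the same trick (not the library code):

```python
import numpy as np

projection_length = 10                   # assumed > the maximum feature id
example_ids = np.array([0, 0, 1, 1, 1])
flat_ids    = np.array([3, 3, 2, 7, 2])

projected = projection_length * example_ids + flat_ids  # [3, 3, 12, 17, 12]
unique_keys = np.unique(projected)       # [3, 12, 17]: duplicates collapsed
print(unique_keys // projection_length)  # example ids: [0 1 1]
print(unique_keys % projection_length)   # feature ids: [3 2 7]
```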
diff --git a/tensorflow/contrib/lite/build_def.bzl b/tensorflow/contrib/lite/build_def.bzl
index 81844756bc..ab694d768f 100644
--- a/tensorflow/contrib/lite/build_def.bzl
+++ b/tensorflow/contrib/lite/build_def.bzl
@@ -227,6 +227,7 @@ def generated_test_models():
"constant",
"control_dep",
"conv",
+ "conv_with_shared_weights",
"depthwiseconv",
"div",
"equal",
diff --git a/tensorflow/contrib/lite/context.h b/tensorflow/contrib/lite/context.h
index 5bc20106d3..c920f6a508 100644
--- a/tensorflow/contrib/lite/context.h
+++ b/tensorflow/contrib/lite/context.h
@@ -452,13 +452,15 @@ typedef struct _TfLiteDelegate {
// Copy the data from delegate buffer handle to raw memory.
// This can be null if the delegate doesn't use its own buffer.
- TfLiteStatus (*CopyFromBufferHandle)(TfLiteDelegate* delegate,
+ TfLiteStatus (*CopyFromBufferHandle)(TfLiteContext* context,
+ TfLiteDelegate* delegate,
TfLiteBufferHandle buffer_handle,
void* data, size_t size);
// Copy the data from raw memory to delegate buffer handle.
// This can be null if the delegate doesn't use its own buffer.
- TfLiteStatus (*CopyToBufferHandle)(TfLiteDelegate* delegate,
+ TfLiteStatus (*CopyToBufferHandle)(TfLiteContext* context,
+ TfLiteDelegate* delegate,
TfLiteBufferHandle buffer_handle,
void* data, size_t size);
@@ -466,7 +468,7 @@ typedef struct _TfLiteDelegate {
// this doesn't release the underlying resource (e.g. textures). The
// resources are either owned by application layer or the delegate.
// This can be null if the delegate doesn't use its own buffer.
- void (*FreeBufferHandle)(TfLiteDelegate* delegate,
+ void (*FreeBufferHandle)(TfLiteContext* context, TfLiteDelegate* delegate,
TfLiteBufferHandle* handle);
} TfLiteDelegate;
diff --git a/tensorflow/contrib/lite/delegates/eager/BUILD b/tensorflow/contrib/lite/delegates/eager/BUILD
index bb518becc5..5a7eb370f6 100644
--- a/tensorflow/contrib/lite/delegates/eager/BUILD
+++ b/tensorflow/contrib/lite/delegates/eager/BUILD
@@ -18,18 +18,21 @@ cc_library(
"//tensorflow/c:c_api_internal",
"//tensorflow/contrib/lite:framework",
"//tensorflow/contrib/lite:kernel_api",
- "//tensorflow/core:framework",
- "//tensorflow/core:protos_all_cc",
- ],
+ ] + select({
+ "//tensorflow:android": [
+ "//tensorflow/core:android_tensorflow_lib",
+ ],
+ "//conditions:default": [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:protos_all_cc",
+ ],
+ }),
)
tf_cc_test(
name = "buffer_map_test",
size = "small",
srcs = ["buffer_map_test.cc"],
- tags = [
- "tflite_not_portable",
- ],
deps = [
":buffer_map",
"//tensorflow/contrib/lite:framework",
@@ -55,17 +58,20 @@ cc_library(
"//tensorflow/contrib/lite:framework",
"//tensorflow/contrib/lite:kernel_api",
"//tensorflow/contrib/lite:util",
- "//tensorflow/core:lib",
- ],
+ ] + select({
+ "//tensorflow:android": [
+ "//tensorflow/core:android_tensorflow_lib",
+ ],
+ "//conditions:default": [
+ "//tensorflow/core:lib",
+ ],
+ }),
)
tf_cc_test(
name = "delegate_test",
size = "small",
srcs = ["delegate_test.cc"],
- tags = [
- "tflite_not_portable",
- ],
deps = [
":delegate",
":test_util",
@@ -80,19 +86,22 @@ cc_library(
hdrs = ["delegate_data.h"],
deps = [
":buffer_map",
- "//tensorflow/core:core_cpu",
- "//tensorflow/core:lib",
"//tensorflow/core/common_runtime/eager:context",
- ],
+ ] + select({
+ "//tensorflow:android": [
+ "//tensorflow/core:android_tensorflow_lib",
+ ],
+ "//conditions:default": [
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:lib",
+ ],
+ }),
)
tf_cc_test(
name = "delegate_data_test",
size = "small",
srcs = ["delegate_data_test.cc"],
- tags = [
- "tflite_not_portable",
- ],
deps = [
":delegate_data",
"//tensorflow/contrib/lite:framework",
@@ -109,25 +118,28 @@ cc_library(
deps = [
":delegate_data",
":util",
+ "@flatbuffers",
"//tensorflow/contrib/lite:framework",
"//tensorflow/contrib/lite:kernel_api",
"//tensorflow/contrib/lite:string",
"//tensorflow/contrib/lite/kernels:kernel_util",
- "//tensorflow/core:protos_all_cc",
"//tensorflow/core/common_runtime/eager:context",
"//tensorflow/core/common_runtime/eager:execute",
"//tensorflow/core/common_runtime/eager:tensor_handle",
- "@flatbuffers",
- ],
+ ] + select({
+ "//tensorflow:android": [
+ "//tensorflow/core:android_tensorflow_lib",
+ ],
+ "//conditions:default": [
+ "//tensorflow/core:protos_all_cc",
+ ],
+ }),
)
tf_cc_test(
name = "kernel_test",
size = "small",
srcs = ["kernel_test.cc"],
- tags = [
- "tflite_not_portable",
- ],
deps = [
":delegate_data",
":kernel",
@@ -159,18 +171,21 @@ cc_library(
"//tensorflow/c:c_api_internal",
"//tensorflow/contrib/lite:framework",
"//tensorflow/contrib/lite:kernel_api",
- "//tensorflow/core:framework",
- "//tensorflow/core:lib",
- ],
+ ] + select({
+ "//tensorflow:android": [
+ "//tensorflow/core:android_tensorflow_lib",
+ ],
+ "//conditions:default": [
+ "//tensorflow/core:lib",
+ "//tensorflow/core:framework",
+ ],
+ }),
)
tf_cc_test(
name = "util_test",
size = "small",
srcs = ["util_test.cc"],
- tags = [
- "tflite_not_portable",
- ],
deps = [
":util",
"//tensorflow/contrib/lite:string",
diff --git a/tensorflow/contrib/lite/delegates/eager/delegate.cc b/tensorflow/contrib/lite/delegates/eager/delegate.cc
index 7d22b45419..8ab768575e 100644
--- a/tensorflow/contrib/lite/delegates/eager/delegate.cc
+++ b/tensorflow/contrib/lite/delegates/eager/delegate.cc
@@ -55,17 +55,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteDelegate* delegate) {
return kTfLiteOk;
}
-TfLiteStatus CopyFromBufferHandle(TfLiteDelegate* delegate,
+TfLiteStatus CopyFromBufferHandle(TfLiteContext* context,
+ TfLiteDelegate* delegate,
TfLiteBufferHandle buffer_handle, void* data,
size_t size) {
- // TODO(nupurgarg): Make BufferMap unique to each interpreter in order to
- // support multiple interpreters using a single delegate.
BufferMap* buffer_map =
- reinterpret_cast<DelegateData*>(delegate->data_)->GetBufferMap();
+ reinterpret_cast<DelegateData*>(delegate->data_)->GetBufferMap(context);
- // TODO(nupurgarg): Use TfLiteContext's ReportError instead of fprinf.
if (!buffer_map->HasTensor(buffer_handle)) {
- fprintf(stderr, "Invalid tensor index %d.\n", buffer_handle);
+ context->ReportError(context, "Invalid tensor index %d.", buffer_handle);
return kTfLiteError;
}
@@ -73,7 +71,8 @@ TfLiteStatus CopyFromBufferHandle(TfLiteDelegate* delegate,
tensorflow::StringPiece t_data = t.tensor_data();
if (size != t_data.size()) {
- fprintf(stderr, "Not enough space to store TensorFlow's aligned buffer.\n");
+ context->ReportError(
+ context, "Not enough space to store TensorFlow's aligned buffer.");
return kTfLiteError;
}
diff --git a/tensorflow/contrib/lite/delegates/eager/delegate.h b/tensorflow/contrib/lite/delegates/eager/delegate.h
index 0defca7c32..a07002f487 100644
--- a/tensorflow/contrib/lite/delegates/eager/delegate.h
+++ b/tensorflow/contrib/lite/delegates/eager/delegate.h
@@ -26,8 +26,8 @@ namespace tflite {
// executed by TensorFlow's runtime via Eager.
//
// The interpreter must be constructed after the EagerDelegate and destructed
-// before the EagerDelegate. This delegate can only be used with one
-// interpreter.
+// before the EagerDelegate. This delegate may be used with multiple
+// interpreters, but it is *not* thread-safe.
//
// Usage:
// EagerDelegate delegate;
diff --git a/tensorflow/contrib/lite/delegates/eager/delegate_data.h b/tensorflow/contrib/lite/delegates/eager/delegate_data.h
index 8a0e8ba8bf..772d26f44e 100644
--- a/tensorflow/contrib/lite/delegates/eager/delegate_data.h
+++ b/tensorflow/contrib/lite/delegates/eager/delegate_data.h
@@ -32,14 +32,18 @@ class DelegateData {
// The EagerContext that is required for execution of Eager Ops.
tensorflow::EagerContext* GetEagerContext() { return eager_context_.get(); }
- // Map from TF Lite tensor index to TensorFlow tensor.
- BufferMap* GetBufferMap() { return &buffer_map_; }
+ // Map from TF Lite tensor index to TensorFlow tensor for a given context.
+ BufferMap* GetBufferMap(const TfLiteContext* context) {
+ return &buffer_map_[context];
+ }
private:
explicit DelegateData(tensorflow::EagerContext* eager_context);
std::unique_ptr<tensorflow::EagerContext> eager_context_;
- BufferMap buffer_map_;
+ // TODO(b/112439500): Clean up stale BufferMap instances after adding the
+ // necessary cleanup hook from a TfLiteContext to a TfLiteDelegate.
+ std::unordered_map<const TfLiteContext*, BufferMap> buffer_map_;
};
} // namespace eager
diff --git a/tensorflow/contrib/lite/delegates/eager/delegate_data_test.cc b/tensorflow/contrib/lite/delegates/eager/delegate_data_test.cc
index 30251b8f82..b3a0ffcec1 100644
--- a/tensorflow/contrib/lite/delegates/eager/delegate_data_test.cc
+++ b/tensorflow/contrib/lite/delegates/eager/delegate_data_test.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include <gmock/gmock.h>
#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/context.h"
#include "tensorflow/contrib/lite/testing/util.h"
namespace tflite {
@@ -29,8 +30,12 @@ TEST(DelegateDataTest, Basic) {
// binary.
EXPECT_TRUE(DelegateData::Create(&data).ok());
+ TfLiteContext dummy_context1 = {};
+ TfLiteContext dummy_context2 = {};
EXPECT_NE(data->GetEagerContext(), nullptr);
- EXPECT_NE(data->GetBufferMap(), nullptr);
+ EXPECT_NE(data->GetBufferMap(&dummy_context1), nullptr);
+ EXPECT_NE(data->GetBufferMap(&dummy_context1),
+ data->GetBufferMap(&dummy_context2));
}
} // namespace
diff --git a/tensorflow/contrib/lite/delegates/eager/delegate_test.cc b/tensorflow/contrib/lite/delegates/eager/delegate_test.cc
index 88fb34044e..511a239363 100644
--- a/tensorflow/contrib/lite/delegates/eager/delegate_test.cc
+++ b/tensorflow/contrib/lite/delegates/eager/delegate_test.cc
@@ -25,8 +25,6 @@ namespace {
using ::testing::ContainsRegex;
using ::testing::ElementsAre;
-// TODO(nupurgarg): Add a test with multiple interpreters for one delegate.
-
class DelegateTest : public testing::EagerModelTest {
public:
DelegateTest() {
@@ -139,6 +137,56 @@ TEST_F(DelegateTest, OnlyTFLite) {
ASSERT_THAT(GetValues(2), ElementsAre(1.1f, 4.4f, 9.9f, 17.6f));
}
+TEST_F(DelegateTest, MultipleInterpretersSameDelegate) {
+ // Build a graph, configure the delegate and set inputs.
+ {
+ AddTensors(9, {0, 3}, {8}, kTfLiteFloat32, {3});
+ AddTfOp(testing::kUnpack, {0}, {1, 2});
+ AddTfOp(testing::kUnpack, {3}, {4, 5});
+ AddTfOp(testing::kAdd, {1, 4}, {6});
+ AddTfOp(testing::kAdd, {2, 5}, {7});
+ AddTfOp(testing::kMul, {6, 7}, {8});
+ ConfigureDelegate();
+ SetShape(0, {2, 2, 1});
+ SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f});
+ SetShape(3, {2, 2, 1});
+ SetValues(3, {1.1f, 2.2f, 3.3f, 4.4f});
+ }
+
+ // Create a new interpreter, inject into the test framework and build
+ // a different graph using the *same* delegate.
+ std::unique_ptr<Interpreter> interpreter(new Interpreter(&error_reporter_));
+ interpreter_.swap(interpreter);
+ {
+ AddTensors(10, {0}, {9}, kTfLiteFloat32, {3});
+ AddTfOp(testing::kUnpack, {0}, {1, 2});
+ AddTfOp(testing::kAdd, {1, 2}, {3});
+ AddTfOp(testing::kUnpack, {3}, {4, 5});
+ AddTfLiteMulOp({4, 5}, {6});
+ AddTfOp(testing::kUnpack, {6}, {7, 8});
+ AddTfOp(testing::kAdd, {7, 8}, {9});
+ ConfigureDelegate();
+ SetShape(0, {2, 2, 2, 1});
+ SetValues(0, {3.0f, 1.0f, 0.5f, -1.0f, 0.0f, 1.0f, 1.5f, 3.0f});
+ }
+
+ // Swap back in the first interpreter and validate inference.
+ interpreter_.swap(interpreter);
+ {
+ ASSERT_TRUE(Invoke());
+ EXPECT_THAT(GetShape(8), ElementsAre(2, 1));
+ EXPECT_THAT(GetValues(8), ElementsAre(14.52f, 38.72f));
+ }
+
+ // Swap in the second interpreter and validate inference.
+ interpreter_.swap(interpreter);
+ {
+ ASSERT_TRUE(Invoke());
+ EXPECT_THAT(GetShape(9), ElementsAre(1));
+ EXPECT_THAT(GetValues(9), ElementsAre(10.0f));
+ }
+}
+
} // namespace
} // namespace eager
} // namespace tflite
diff --git a/tensorflow/contrib/lite/delegates/eager/kernel.cc b/tensorflow/contrib/lite/delegates/eager/kernel.cc
index 1bd17a3bca..1082b78725 100644
--- a/tensorflow/contrib/lite/delegates/eager/kernel.cc
+++ b/tensorflow/contrib/lite/delegates/eager/kernel.cc
@@ -150,8 +150,8 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
op_data->eager_context =
reinterpret_cast<DelegateData*>(params->delegate->data_)
->GetEagerContext();
- op_data->buffer_map =
- reinterpret_cast<DelegateData*>(params->delegate->data_)->GetBufferMap();
+ op_data->buffer_map = reinterpret_cast<DelegateData*>(params->delegate->data_)
+ ->GetBufferMap(context);
CHECK(params->output_tensors);
for (auto tensor_index : TfLiteIntArrayView(params->output_tensors)) {
diff --git a/tensorflow/contrib/lite/delegates/eager/kernel_test.cc b/tensorflow/contrib/lite/delegates/eager/kernel_test.cc
index b7bfbb34e4..66f2226626 100644
--- a/tensorflow/contrib/lite/delegates/eager/kernel_test.cc
+++ b/tensorflow/contrib/lite/delegates/eager/kernel_test.cc
@@ -55,12 +55,14 @@ class KernelTest : public testing::EagerModelTest {
delegate_.data_ = delegate_data_.get();
delegate_.FreeBufferHandle = nullptr;
delegate_.Prepare = prepare_function;
- delegate_.CopyFromBufferHandle = [](TfLiteDelegate* delegate,
+ delegate_.CopyFromBufferHandle = [](TfLiteContext* context,
+ TfLiteDelegate* delegate,
TfLiteBufferHandle buffer_handle,
void* data, size_t size) {
auto* delegate_data = reinterpret_cast<DelegateData*>(delegate->data_);
- tensorflow::StringPiece values =
- delegate_data->GetBufferMap()->GetTensor(buffer_handle).tensor_data();
+ tensorflow::StringPiece values = delegate_data->GetBufferMap(context)
+ ->GetTensor(buffer_handle)
+ .tensor_data();
memcpy(data, values.data(), values.size());
return kTfLiteOk;
};
diff --git a/tensorflow/contrib/lite/error_reporter.cc b/tensorflow/contrib/lite/error_reporter.cc
index 03fcd5409c..646913c026 100644
--- a/tensorflow/contrib/lite/error_reporter.cc
+++ b/tensorflow/contrib/lite/error_reporter.cc
@@ -16,6 +16,10 @@ limitations under the License.
#include <cstdarg>
#include <cstdio>
+#ifdef __ANDROID__
+#include <android/log.h>
+#endif
+
namespace tflite {
ErrorReporter::~ErrorReporter() {}
@@ -39,6 +43,15 @@ int ErrorReporter::ReportError(void*, const char* format, ...) {
}
int StderrReporter::Report(const char* format, va_list args) {
+#ifdef __ANDROID__
+ // On Android stderr is not captured for applications, only for code run from
+ // the shell. Rather than assume all users will set up a custom error
+  // reporter, let's also output to logcat here.
+ va_list args_for_log;
+ va_copy(args_for_log, args);
+ __android_log_vprint(ANDROID_LOG_ERROR, "tflite", format, args_for_log);
+ va_end(args_for_log);
+#endif
const int result = vfprintf(stderr, format, args);
fputc('\n', stderr);
return result;
diff --git a/tensorflow/contrib/lite/examples/ios/camera/CameraExampleViewController.mm b/tensorflow/contrib/lite/examples/ios/camera/CameraExampleViewController.mm
index 30fee64a6f..734b15e0a1 100644
--- a/tensorflow/contrib/lite/examples/ios/camera/CameraExampleViewController.mm
+++ b/tensorflow/contrib/lite/examples/ios/camera/CameraExampleViewController.mm
@@ -26,7 +26,7 @@
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/model.h"
#include "tensorflow/contrib/lite/string_util.h"
-#include "tensorflow/contrib/lite/tools/mutable_op_resolver.h"
+#include "tensorflow/contrib/lite/op_resolver.h"
#define LOG(x) std::cerr
diff --git a/tensorflow/contrib/lite/examples/ios/camera/Podfile b/tensorflow/contrib/lite/examples/ios/camera/Podfile
index cd8c39043f..8084307ac7 100644
--- a/tensorflow/contrib/lite/examples/ios/camera/Podfile
+++ b/tensorflow/contrib/lite/examples/ios/camera/Podfile
@@ -2,4 +2,4 @@ platform :ios, '8.0'
inhibit_all_warnings!
target 'tflite_camera_example'
- pod 'TensorFlowLite', '0.1.7'
+ pod 'TensorFlowLite', '1.10.0'
diff --git a/tensorflow/contrib/lite/examples/ios/simple/Podfile b/tensorflow/contrib/lite/examples/ios/simple/Podfile
index c885398f44..eea7ecb759 100644
--- a/tensorflow/contrib/lite/examples/ios/simple/Podfile
+++ b/tensorflow/contrib/lite/examples/ios/simple/Podfile
@@ -2,4 +2,4 @@ platform :ios, '8.0'
inhibit_all_warnings!
target 'tflite_simple_example'
- pod 'TensorFlowLite', '0.1.7'
+ pod 'TensorFlowLite', '1.10.0'
diff --git a/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.mm b/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.mm
index 0ab7aa25d0..650c73f732 100644
--- a/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.mm
+++ b/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.mm
@@ -25,7 +25,7 @@
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/model.h"
#include "tensorflow/contrib/lite/string_util.h"
-#include "tensorflow/contrib/lite/tools/mutable_op_resolver.h"
+#include "tensorflow/contrib/lite/op_resolver.h"
#include "ios_image_load.h"
diff --git a/tensorflow/contrib/lite/interpreter.cc b/tensorflow/contrib/lite/interpreter.cc
index 7a680f5c64..362e588725 100644
--- a/tensorflow/contrib/lite/interpreter.cc
+++ b/tensorflow/contrib/lite/interpreter.cc
@@ -157,7 +157,7 @@ Interpreter::~Interpreter() {
TfLiteTensor* tensor = &context_.tensors[i];
if (tensor->buffer_handle != kTfLiteNullBufferHandle &&
tensor->delegate->FreeBufferHandle != nullptr) {
- tensor->delegate->FreeBufferHandle(tensor->delegate,
+ tensor->delegate->FreeBufferHandle(&context_, tensor->delegate,
&tensor->buffer_handle);
}
TfLiteTensorFree(tensor);
@@ -988,7 +988,7 @@ TfLiteStatus Interpreter::SetBufferHandle(int tensor_index,
tensor->delegate = delegate;
if (tensor->buffer_handle != kTfLiteNullBufferHandle) {
TF_LITE_ENSURE(&context_, tensor->delegate->FreeBufferHandle != nullptr);
- tensor->delegate->FreeBufferHandle(tensor->delegate,
+ tensor->delegate->FreeBufferHandle(&context_, tensor->delegate,
&tensor->buffer_handle);
}
tensor->buffer_handle = buffer_handle;
diff --git a/tensorflow/contrib/lite/interpreter.h b/tensorflow/contrib/lite/interpreter.h
index 159ff7bc20..a27df4b964 100644
--- a/tensorflow/contrib/lite/interpreter.h
+++ b/tensorflow/contrib/lite/interpreter.h
@@ -350,7 +350,7 @@ class Interpreter {
// This can be null if the delegate doesn't use its own buffer.
TF_LITE_ENSURE(&context_,
tensor->delegate->CopyFromBufferHandle != nullptr);
- tensor->delegate->CopyFromBufferHandle(tensor->delegate,
+ tensor->delegate->CopyFromBufferHandle(&context_, tensor->delegate,
tensor->buffer_handle,
tensor->data.raw, tensor->bytes);
tensor->data_is_stale = false;
diff --git a/tensorflow/contrib/lite/interpreter_test.cc b/tensorflow/contrib/lite/interpreter_test.cc
index 2bf598bad7..f00697826c 100644
--- a/tensorflow/contrib/lite/interpreter_test.cc
+++ b/tensorflow/contrib/lite/interpreter_test.cc
@@ -1080,21 +1080,22 @@ class TestDelegate : public ::testing::Test {
return kTfLiteOk;
};
delegate_.CopyToBufferHandle =
- [](TfLiteDelegate* delegate, TfLiteBufferHandle buffer_handle,
- void* data, size_t size) -> TfLiteStatus {
+ [](TfLiteContext* context, TfLiteDelegate* delegate,
+ TfLiteBufferHandle buffer_handle, void* data,
+ size_t size) -> TfLiteStatus {
// TODO(ycling): Implement tests to test buffer copying logic.
return kTfLiteOk;
};
delegate_.CopyFromBufferHandle =
- [](TfLiteDelegate* delegate, TfLiteBufferHandle buffer_handle,
- void* data, size_t size) -> TfLiteStatus {
+ [](TfLiteContext* context, TfLiteDelegate* delegate,
+ TfLiteBufferHandle buffer_handle, void* data,
+ size_t size) -> TfLiteStatus {
// TODO(ycling): Implement tests to test buffer copying logic.
return kTfLiteOk;
};
- delegate_.FreeBufferHandle = [](TfLiteDelegate* delegate,
- TfLiteBufferHandle* handle) {
- *handle = kTfLiteNullBufferHandle;
- };
+ delegate_.FreeBufferHandle =
+ [](TfLiteContext* context, TfLiteDelegate* delegate,
+ TfLiteBufferHandle* handle) { *handle = kTfLiteNullBufferHandle; };
// Store type-punned data SimpleDelegate structure.
delegate_.data_ = reinterpret_cast<void*>(this);
}
diff --git a/tensorflow/contrib/lite/kernels/activations.cc b/tensorflow/contrib/lite/kernels/activations.cc
index 817266a471..d6d62580e2 100644
--- a/tensorflow/contrib/lite/kernels/activations.cc
+++ b/tensorflow/contrib/lite/kernels/activations.cc
@@ -40,6 +40,11 @@ struct OpData {
int diff_min = 0;
};
+struct LogSoftmaxOpData : public OpData {
+ int32_t reverse_scaling_divisor = 0;
+ int32_t reverse_scaling_right_shift = 0;
+};
+
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
// This is a builtin op, so we don't use the contents in 'buffer', if any.
// Instead, we allocate a new object to carry information from Prepare() to
@@ -47,10 +52,19 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
return new OpData;
}
+void* LogSoftmaxInit(TfLiteContext* context, const char* buffer,
+ size_t length) {
+ return new LogSoftmaxOpData;
+}
+
void Free(TfLiteContext* context, void* buffer) {
delete reinterpret_cast<OpData*>(buffer);
}
+void LogSoftmaxFree(TfLiteContext* context, void* buffer) {
+ delete reinterpret_cast<LogSoftmaxOpData*>(buffer);
+}
+
TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
@@ -205,6 +219,34 @@ TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteIntArrayCopy(input->dims));
}
+TfLiteStatus LogSoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {
+ LogSoftmaxOpData* data = reinterpret_cast<LogSoftmaxOpData*>(node->user_data);
+
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+ const TfLiteTensor* input = GetInput(context, node, 0);
+ TfLiteTensor* output = GetOutput(context, node, 0);
+ TF_LITE_ENSURE_EQ(context, input->type, output->type);
+
+ if (input->type == kTfLiteUInt8) {
+ TF_LITE_ENSURE_EQ(context, output->params.zero_point, 255);
+ TF_LITE_ENSURE_EQ(context, output->params.scale, 16.0 / 256);
+
+ static const double kBeta = 1.0;
+ static const int kScaledDiffIntegerBits = 5;
+ tflite::PreprocessLogSoftmaxScalingExp(
+ kBeta, input->params.scale, kScaledDiffIntegerBits,
+ &data->input_multiplier, &data->input_left_shift,
+ &data->reverse_scaling_divisor, &data->reverse_scaling_right_shift);
+ data->reverse_scaling_right_shift *= -1;
+ data->diff_min = -1.0 * tflite::CalculateInputRadius(
+ kScaledDiffIntegerBits, data->input_left_shift);
+ }
+
+ return context->ResizeTensor(context, output,
+ TfLiteIntArrayCopy(input->dims));
+}
+
TfLiteStatus PreluPrepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
@@ -509,6 +551,8 @@ TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
}
TfLiteStatus LogSoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
+ const LogSoftmaxOpData* data =
+ reinterpret_cast<LogSoftmaxOpData*>(node->user_data);
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
switch (input->type) {
@@ -517,6 +561,14 @@ TfLiteStatus LogSoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
GetTensorData<float>(input), GetTensorShape(input),
GetTensorData<float>(output), GetTensorShape(output));
return kTfLiteOk;
+ case kTfLiteUInt8:
+ optimized_ops::LogSoftmax(
+ GetTensorData<uint8_t>(input), GetTensorShape(input),
+ data->input_multiplier, data->input_left_shift,
+ data->reverse_scaling_divisor, data->reverse_scaling_right_shift,
+ data->diff_min, GetTensorData<uint8_t>(output),
+ GetTensorShape(output));
+ return kTfLiteOk;
default:
      context->ReportError(context, "Only float32 and uint8 are supported currently, got %d",
input->type);
@@ -590,9 +642,9 @@ TfLiteRegistration* Register_SOFTMAX() {
}
TfLiteRegistration* Register_LOG_SOFTMAX() {
- static TfLiteRegistration r = {activations::Init, activations::Free,
- activations::GenericPrepare,
- activations::LogSoftmaxEval};
+ static TfLiteRegistration r = {
+ activations::LogSoftmaxInit, activations::LogSoftmaxFree,
+ activations::LogSoftmaxPrepare, activations::LogSoftmaxEval};
return &r;
}
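
A note on the fixed output quantization that LogSoftmaxPrepare enforces: log-softmax outputs always lie in (-inf, 0], and the asserted scale of 16/256 with zero point 255 pins the representable uint8 range to

    \text{real} = S\,(q - Z), \quad S = \tfrac{16}{256}, \quad Z = 255
    \;\Rightarrow\; [\,S(0 - 255),\ S(255 - 255)\,] = [-15.9375,\ 0]

which is why the checks above reject any other output quantization.
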
diff --git a/tensorflow/contrib/lite/kernels/activations_test.cc b/tensorflow/contrib/lite/kernels/activations_test.cc
index 083cdf78d7..e577e3a762 100644
--- a/tensorflow/contrib/lite/kernels/activations_test.cc
+++ b/tensorflow/contrib/lite/kernels/activations_test.cc
@@ -471,6 +471,28 @@ TEST(FloatActivationsOpTest, LogSoftmax) {
})));
}
+TEST(QuantizedActivationsOpTest, LogSoftmax) {
+ const float kLogSoftmaxQuantizedTolerance = 16 / 256.0;
+ QuantizedActivationsOpModel m(
+ BuiltinOperator_LOG_SOFTMAX,
+ /*input=*/{TensorType_UINT8, {2, 4}, -10, 10},
+ /*output=*/{TensorType_UINT8, {}, 0, 0, 16. / 256, 255});
+ m.SetInput<uint8_t>({
+ 0, -6, 2, 4, //
+ 3, -2, 10, 1, //
+ });
+ m.Invoke();
+ EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(),
+ ElementsAreArray(ArrayFloatNear(
+ {
+ -4.14297, -10.14297, -2.14297, -.142971, //
+ -7.00104, -12.00104, -.00104087, -9.00104, //
+ },
+ kLogSoftmaxQuantizedTolerance)));
+ EXPECT_THAT(m.GetOutput<uint8_t>(),
+ ElementsAreArray({189, 93, 221, 253, 142, 63, 255, 111}));
+}
+
class PReluOpModel : public SingleOpModel {
public:
PReluOpModel(const TensorData& input, const TensorData& alpha) {
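
As a sanity check on the expected values in the quantized LogSoftmax test above, dequantizing one element with the same scale and zero point gives

    S\,(q - Z) = \tfrac{16}{256}\,(253 - 255) = -0.125, \qquad
    \bigl|{-0.125} - ({-0.142971})\bigr| \approx 0.018 < \tfrac{16}{256} = 0.0625

so q = 253 reproduces the float reference -0.142971 within kLogSoftmaxQuantizedTolerance.
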
diff --git a/tensorflow/contrib/lite/kernels/conv.cc b/tensorflow/contrib/lite/kernels/conv.cc
index 04c0263b78..50fe5c2e04 100644
--- a/tensorflow/contrib/lite/kernels/conv.cc
+++ b/tensorflow/contrib/lite/kernels/conv.cc
@@ -334,18 +334,31 @@ void EvalQuantized(TfLiteContext* context, TfLiteNode* node,
auto filter_offset = -filter->params.zero_point;
auto output_offset = output->params.zero_point;
- switch (kernel_type) {
+ KernelType effective_kernel_type;
+ if ((kernel_type == kMultithreadOptimized ||
+ kernel_type == kCblasOptimized) &&
+ (params->dilation_width_factor != 1 ||
+ params->dilation_height_factor != 1)) {
+ // kMultithreadOptimized and kCblasOptimized do not support dilation.
+    // Therefore, fall back to the generic optimized kernel.
+ effective_kernel_type = kGenericOptimized;
+ } else {
+ effective_kernel_type = kernel_type;
+ }
+
+ switch (effective_kernel_type) {
case kReference:
reference_ops::Conv(
GetTensorData<uint8_t>(input), GetTensorDims(input), input_offset,
GetTensorData<uint8_t>(filter), GetTensorDims(filter), filter_offset,
GetTensorData<int32_t>(bias), GetTensorDims(bias),
- params->stride_width, params->stride_height, data->padding.width,
- data->padding.height, output_offset, data->output_multiplier,
- data->output_shift, data->output_activation_min,
- data->output_activation_max, GetTensorData<uint8_t>(output),
- GetTensorDims(output), GetTensorData<uint8_t>(im2col),
- GetTensorDims(im2col), gemm_context);
+ params->stride_width, params->stride_height,
+ params->dilation_width_factor, params->dilation_height_factor,
+ data->padding.width, data->padding.height, output_offset,
+ data->output_multiplier, data->output_shift,
+ data->output_activation_min, data->output_activation_max,
+ GetTensorData<uint8_t>(output), GetTensorDims(output),
+ GetTensorData<uint8_t>(im2col), GetTensorDims(im2col), gemm_context);
break;
case kGenericOptimized:
case kMultithreadOptimized:
@@ -355,12 +368,13 @@ void EvalQuantized(TfLiteContext* context, TfLiteNode* node,
GetTensorData<uint8_t>(input), GetTensorDims(input), input_offset,
GetTensorData<uint8_t>(filter), GetTensorDims(filter), filter_offset,
GetTensorData<int32_t>(bias), GetTensorDims(bias),
- params->stride_width, params->stride_height, data->padding.width,
- data->padding.height, output_offset, data->output_multiplier,
- data->output_shift, data->output_activation_min,
- data->output_activation_max, GetTensorData<uint8_t>(output),
- GetTensorDims(output), GetTensorData<uint8_t>(im2col),
- GetTensorDims(im2col), gemm_context);
+ params->stride_width, params->stride_height,
+ params->dilation_width_factor, params->dilation_height_factor,
+ data->padding.width, data->padding.height, output_offset,
+ data->output_multiplier, data->output_shift,
+ data->output_activation_min, data->output_activation_max,
+ GetTensorData<uint8_t>(output), GetTensorDims(output),
+ GetTensorData<uint8_t>(im2col), GetTensorDims(im2col), gemm_context);
break;
}
}
@@ -374,10 +388,10 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node,
CalculateActivationRange(params->activation, &output_activation_min,
&output_activation_max);
KernelType effective_kernel_type;
- if (((kernel_type == kMultithreadOptimized) ||
- (kernel_type == kCblasOptimized)) &&
- ((params->dilation_width_factor != 1) ||
- (params->dilation_height_factor != 1))) {
+ if ((kernel_type == kMultithreadOptimized ||
+ kernel_type == kCblasOptimized) &&
+ (params->dilation_width_factor != 1 ||
+ params->dilation_height_factor != 1)) {
// kMultithreadOptimized and kCblasOptimized do not support dilation.
// Therefore, fallback to optimized.
effective_kernel_type = kGenericOptimized;
diff --git a/tensorflow/contrib/lite/kernels/conv_test.cc b/tensorflow/contrib/lite/kernels/conv_test.cc
index 24633c2fd7..98152043c9 100644
--- a/tensorflow/contrib/lite/kernels/conv_test.cc
+++ b/tensorflow/contrib/lite/kernels/conv_test.cc
@@ -370,6 +370,65 @@ TEST_P(ConvolutionOpTest, HandCalculatedValidFloat32) {
EXPECT_THAT(m.GetOutput(), ElementsAreArray({312, 357}));
}
+TEST_P(ConvolutionOpTest, SimpleTestFloatWithDilation) {
+ const int depth = 1;
+ const int image_width = 9;
+ const int image_height = 9;
+ const int image_batch_count = 1;
+ const int filter_size = 3;
+ const int filter_count = 1;
+ const int stride_width = 1;
+ const int stride_height = 1;
+ const int dilation_width_factor = 3;
+ const int dilation_height_factor = 3;
+ const Padding padding = Padding_VALID;
+ ConvolutionOpModel m(
+ GetRegistration(),
+ {TensorType_FLOAT32,
+ {image_batch_count, image_height, image_width, depth}},
+ {TensorType_FLOAT32, {depth, filter_size, filter_size, filter_count}},
+ {TensorType_FLOAT32, {}}, stride_width, stride_height, padding,
+ ActivationFunctionType_NONE, dilation_width_factor,
+ dilation_height_factor);
+
+ // The image matrix is:
+ // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+ // clang-format off
+ m.SetInput({0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 1, 1, 1, 0, 0, 0,
+ 0, 0, 0, 1, 1, 1, 0, 0, 0,
+ 0, 0, 0, 1, 1, 1, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0});
+ // clang-format on
+ // The filter matrix is:
+ // | 1 | 2 | 3 |
+ // | 4 | 5 | 6 |
+ // | 7 | 8 | 9 |
+ m.SetFilter({1, 2, 3, 4, 5, 6, 7, 8, 9});
+ // No bias for this test.
+ m.SetBias({0});
+ m.Invoke();
+
+  // With a dilation rate of 3 the effective filter size is 7x7, so the
+  // VALID output shrinks from 7x7 (undilated) to 3x3 of all 5s. Specifically:
+ // | 5 | 5 | 5 |
+ // | 5 | 5 | 5 |
+ // | 5 | 5 | 5 |
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({5, 5, 5, 5, 5, 5, 5, 5, 5}));
+}
+
class QuantizedConvolutionOpModel : public BaseConvolutionOpModel {
public:
using BaseConvolutionOpModel::BaseConvolutionOpModel;
@@ -500,6 +559,71 @@ TEST_P(ConvolutionOpTest, SimpleTestQuantizedWithAnisotropicStrides) {
}));
}
+TEST_P(ConvolutionOpTest, SimpleTestQuantizedWithDilation) {
+ const int depth = 1;
+ const int image_width = 9;
+ const int image_height = 9;
+ const int image_batch_count = 1;
+ const int filter_size = 3;
+ const int filter_count = 1;
+ const int stride_width = 1;
+ const int stride_height = 1;
+ const int dilation_width_factor = 3;
+ const int dilation_height_factor = 3;
+ const Padding padding = Padding_VALID;
+ QuantizedConvolutionOpModel m(
+ GetRegistration(),
+ {TensorType_UINT8,
+ {image_batch_count, image_height, image_width, depth},
+ 0,
+ 255},
+ {TensorType_UINT8,
+ {depth, filter_size, filter_size, filter_count},
+ 0,
+ 255},
+ {TensorType_UINT8, {}, 0, 255}, stride_width, stride_height, padding,
+ ActivationFunctionType_NONE, dilation_width_factor,
+ dilation_height_factor);
+
+ // The image matrix is:
+ // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+ // clang-format off
+ m.SetInput({0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 1, 1, 1, 0, 0, 0,
+ 0, 0, 0, 1, 1, 1, 0, 0, 0,
+ 0, 0, 0, 1, 1, 1, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0});
+ // clang-format on
+ // The filter matrix is:
+ // | 1 | 2 | 3 |
+ // | 4 | 5 | 6 |
+ // | 7 | 8 | 9 |
+ m.SetFilter({1, 2, 3, 4, 5, 6, 7, 8, 9});
+ // No bias for this test.
+ m.SetBias({0});
+ m.Invoke();
+
+  // With a dilation rate of 3 the effective filter size is 7x7, so the
+  // VALID output shrinks from 7x7 (undilated) to 3x3 of all 5s. Specifically:
+ // | 5 | 5 | 5 |
+ // | 5 | 5 | 5 |
+ // | 5 | 5 | 5 |
+ EXPECT_THAT(m.GetDequantizedOutput(),
+ ElementsAreArray({5, 5, 5, 5, 5, 5, 5, 5, 5}));
+}
+
INSTANTIATE_TEST_CASE_P(
ConvolutionOpTest, ConvolutionOpTest,
::testing::ValuesIn(SingleOpTest::GetKernelTags(*kKernelMap)));
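
Why both dilation tests expect a constant output of 5: with dilation 3, each output location samples input positions spaced 3 apart, and exactly one sampled position per output falls inside the central 3x3 block of ones, namely the one weighted by the filter's center tap. For output (0, 0), for instance,

    \sum_{f_y, f_x \in \{0,1,2\}} w(f_y, f_x)\, x(3 f_y, 3 f_x) = w(1, 1)\, x(3, 3) = 5 \cdot 1 = 5

and the same single-tap argument holds at every other output location.
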
diff --git a/tensorflow/contrib/lite/kernels/internal/BUILD b/tensorflow/contrib/lite/kernels/internal/BUILD
index 87155e4ba4..a97db6c6b2 100644
--- a/tensorflow/contrib/lite/kernels/internal/BUILD
+++ b/tensorflow/contrib/lite/kernels/internal/BUILD
@@ -539,7 +539,10 @@ cc_test(
cc_test(
name = "depthwiseconv_quantized_test",
srcs = ["depthwiseconv_quantized_test.cc"],
- tags = ["no_oss"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable_ios",
+ ],
deps = [
":optimized_base",
":reference_base",
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h
index d5503073a7..7f0676be27 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h
@@ -30,11 +30,6 @@ namespace optimized_ops {
using reference_ops::Relu1;
using reference_ops::Relu6;
-inline RuntimeShape DimsToShape(const tflite::Dims<4>& dims) {
- return RuntimeShape(
- {dims.sizes[3], dims.sizes[2], dims.sizes[1], dims.sizes[0]});
-}
-
template <FusedActivationFunctionType Ac>
void L2Normalization(const float* input_data, const Dims<4>& input_dims,
float* output_data, const Dims<4>& output_dims) {
@@ -294,6 +289,37 @@ void Sub(const T* input1_data, const Dims<4>& input1_dims, const T* input2_data,
output_data);
}
+inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims,
+ int32 input1_offset, const uint8* input2_data,
+ const Dims<4>& input2_dims, int32 input2_offset,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ BroadcastMul4DSlow(
+ input1_data, input1_dims, input1_offset, input2_data, input2_dims,
+ input2_offset, output_offset, output_multiplier,
+ // This legacy version switches the sign of the output shift.
+ kReverseShift * output_shift,
+ output_activation_min, output_activation_max, output_data, output_dims);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims,
+ int32 input1_offset, const uint8* input2_data,
+ const Dims<4>& input2_dims, int32 input2_offset,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ BroadcastMul(input1_data, input1_dims, input1_offset, input2_data,
+ input2_dims, input2_offset, output_offset, output_multiplier,
+ output_shift, output_activation_min, output_activation_max,
+ output_data, output_dims);
+}
+
inline void AveragePool(const float* input_data, const Dims<4>& input_dims,
int stride_width, int stride_height, int pad_width,
int pad_height, int kwidth, int kheight,
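
The sign flip via kReverseShift in the wrapper above encodes the shift-convention migration noted in optimized_ops.h: legacy callers pass a right-shift amount s, while the new kernels take a left-shift exponent, so the same rescaling is expressed as

    x \cdot M \cdot 2^{-s_{\text{legacy}}} = x \cdot M \cdot 2^{\,s_{\text{new}}}, \qquad
    s_{\text{new}} = \texttt{kReverseShift} \cdot s_{\text{legacy}} = -s_{\text{legacy}}
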
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
index b870789772..2d172315da 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
@@ -47,6 +47,7 @@ using reference_ops::BroadcastGreater;
using reference_ops::BroadcastGreaterEqual;
using reference_ops::BroadcastLess;
using reference_ops::BroadcastLessEqual;
+using reference_ops::BroadcastMul4DSlow;
using reference_ops::BroadcastSub4DSlow;
using reference_ops::Concatenation;
using reference_ops::DepthConcatenation;
@@ -75,6 +76,11 @@ using reference_ops::Transpose;
// Used mainly to convert from old-style shifts (right) to new-style (left).
static constexpr int kReverseShift = -1;
+inline RuntimeShape DimsToShape(const tflite::Dims<4>& dims) {
+ return RuntimeShape(
+ {dims.sizes[3], dims.sizes[2], dims.sizes[1], dims.sizes[0]});
+}
+
// Make a local VectorMap typedef that lets us map a float array
// as an Eigen vector expression. The std::conditional here is to
// construct the suitable Eigen type for the constness of the
@@ -1978,12 +1984,12 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
int32 input_offset, const uint8* filter_data,
const Dims<4>& filter_dims, int32 filter_offset,
const int32* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int32 output_offset, int32 output_multiplier,
- int output_shift, int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims, uint8* im2col_data,
- const Dims<4>& im2col_dims,
+ int stride_width, int stride_height, int dilation_width_factor,
+ int dilation_height_factor, int pad_width, int pad_height,
+ int32 output_offset, int32 output_multiplier, int output_shift,
+ int32 output_activation_min, int32 output_activation_max,
+ uint8* output_data, const Dims<4>& output_dims,
+ uint8* im2col_data, const Dims<4>& im2col_dims,
gemmlowp::GemmContext* gemm_context) {
gemmlowp::ScopedProfilingLabel label("Conv/8bit");
@@ -1995,9 +2001,22 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
const Dims<4>* gemm_input_dims = nullptr;
const int filter_width = ArraySize(filter_dims, 1);
const int filter_height = ArraySize(filter_dims, 2);
+ const bool need_dilated_im2col =
+ dilation_width_factor != 1 || dilation_height_factor != 1;
const bool need_im2col = stride_width != 1 || stride_height != 1 ||
filter_width != 1 || filter_height != 1;
- if (need_im2col) {
+ if (need_dilated_im2col) {
+ TFLITE_DCHECK(im2col_data);
+ const int input_zero_point = -input_offset;
+ TFLITE_DCHECK_GE(input_zero_point, 0);
+ TFLITE_DCHECK_LE(input_zero_point, 255);
+ DilatedIm2col(input_data, input_dims, filter_dims, stride_width,
+ stride_height, dilation_width_factor, dilation_height_factor,
+ pad_width, pad_height, output_dims, input_zero_point,
+ im2col_data);
+ gemm_input_data = im2col_data;
+ gemm_input_dims = &im2col_dims;
+ } else if (need_im2col) {
TFLITE_DCHECK(im2col_data);
const int input_zero_point = -input_offset;
TFLITE_DCHECK_GE(input_zero_point, 0);
@@ -2053,6 +2072,24 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
input_offset, output_pipeline);
}
+inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims, uint8* im2col_data,
+ const Dims<4>& im2col_dims,
+ gemmlowp::GemmContext* gemm_context) {
+ Conv(input_data, input_dims, input_offset, filter_data, filter_dims,
+ filter_offset, bias_data, bias_dims, stride_width, stride_height, 1, 1,
+ pad_width, pad_height, output_offset, output_multiplier, output_shift,
+ output_activation_min, output_activation_max, output_data, output_dims,
+ im2col_data, im2col_dims, gemm_context);
+}
+
// legacy, for compatibility with old checked-in code
template <FusedActivationFunctionType Ac>
inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
@@ -2904,68 +2941,130 @@ void BroadcastMul(const T* input1_data, const Dims<4>& input1_dims,
output_dims);
}
-inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims,
- int32 input1_offset, const uint8* input2_data,
- const Dims<4>& input2_dims, int32 input2_offset,
- int32 output_offset, int32 output_multiplier,
- int output_shift, int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("BroadcastMul/8bit");
+// Element-wise mul that can often be used for the inner loop of a broadcast
+// Mul as well as for the non-broadcast Mul.
+inline void MulElementwise(int size, const ArithmeticParams& params,
+ const uint8* input1_data, const uint8* input2_data,
+ uint8* output_data) {
+ for (int i = 0; i < size; ++i) {
+ const int32 input1_val = params.input1_offset + input1_data[i];
+ const int32 input2_val = params.input2_offset + input2_data[i];
+ const int32 unclamped_result =
+ params.output_offset +
+ MultiplyByQuantizedMultiplierSmallerThanOneExp(input1_val * input2_val,
+ params.output_multiplier,
+ params.output_shift);
+ const int32 clamped_output =
+ std::min(params.quantized_activation_max,
+ std::max(params.quantized_activation_min, unclamped_result));
+ output_data[i] = static_cast<uint8>(clamped_output);
+ }
+}
- NdArrayDesc<4> desc1;
- NdArrayDesc<4> desc2;
- NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
+// Scalar-broadcast mul used for the inner loop of a broadcast Mul when one
+// operand reduces to a single broadcast value.
+inline void MulSimpleBroadcast(int size, const ArithmeticParams& params,
+ const uint8 broadcast_value,
+ const uint8* input2_data, uint8* output_data) {
+ const int32 input1_val = params.input1_offset + broadcast_value;
- // In Tensorflow, the dimensions are canonically named (batch_number, row,
- // col, channel), with extents (batches, height, width, depth), with the
- // trailing dimension changing most rapidly (channels has the smallest stride,
- // typically 1 element).
- //
- // In generated C code, we store arrays with the dimensions reversed. The
- // first dimension has smallest stride.
- //
- // We name our variables by their Tensorflow convention, but generate C code
- // nesting loops such that the innermost loop has the smallest stride for the
- // best cache behavior.
- for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
- for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
- for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
- for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
- const int32 input1_val =
- input1_offset + input1_data[SubscriptToIndex(desc1, c, x, y, b)];
- const int32 input2_val =
- input2_offset + input2_data[SubscriptToIndex(desc2, c, x, y, b)];
- const int32 unclamped_result =
- output_offset + MultiplyByQuantizedMultiplierSmallerThanOneExp(
- input1_val * input2_val, output_multiplier,
- kReverseShift * output_shift);
- const int32 clamped_output =
- std::min(output_activation_max,
- std::max(output_activation_min, unclamped_result));
- output_data[Offset(output_dims, c, x, y, b)] =
- static_cast<uint8>(clamped_output);
+ for (int i = 0; i < size; ++i) {
+ const int32 input2_val = params.input2_offset + input2_data[i];
+ const int32 unclamped_result =
+ params.output_offset +
+ MultiplyByQuantizedMultiplierSmallerThanOneExp(input1_val * input2_val,
+ params.output_multiplier,
+ params.output_shift);
+ const int32 clamped_output =
+ std::min(params.quantized_activation_max,
+ std::max(params.quantized_activation_min, unclamped_result));
+ output_data[i] = static_cast<uint8>(clamped_output);
+ }
+}
+
+inline void Mul(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape, const uint8* input1_data,
+ const RuntimeShape& input2_shape, const uint8* input2_data,
+ const RuntimeShape& output_shape, uint8* output_data) {
+ TFLITE_DCHECK_LE(params.quantized_activation_min,
+ params.quantized_activation_max);
+ gemmlowp::ScopedProfilingLabel label("Mul/8bit");
+ const int flat_size =
+ MatchingFlatSize(input1_shape, input2_shape, output_shape);
+
+ MulElementwise(flat_size, params, input1_data, input2_data, output_data);
+}
+
+inline void BroadcastMulFivefold(const ArithmeticParams& unswitched_params,
+ const RuntimeShape& unswitched_input1_shape,
+ const uint8* unswitched_input1_data,
+ const RuntimeShape& unswitched_input2_shape,
+ const uint8* unswitched_input2_data,
+ const RuntimeShape& output_shape,
+ uint8* output_data) {
+ gemmlowp::ScopedProfilingLabel label("BroadcastMulFivefold/8bit");
+
+ ArithmeticParams switched_params = unswitched_params;
+ switched_params.input1_offset = unswitched_params.input2_offset;
+ switched_params.input2_offset = unswitched_params.input1_offset;
+
+ const bool use_unswitched =
+ unswitched_params.broadcast_category ==
+ tflite::BroadcastableOpCategory::kFirstInputBroadcastsFast;
+
+ const ArithmeticParams& params =
+ use_unswitched ? unswitched_params : switched_params;
+ const uint8* input1_data =
+ use_unswitched ? unswitched_input1_data : unswitched_input2_data;
+ const uint8* input2_data =
+ use_unswitched ? unswitched_input2_data : unswitched_input1_data;
+
+ // Fivefold nested loops. The second input resets its position for each
+ // iteration of the second loop. The first input resets its position at the
+ // beginning of the fourth loop. The innermost loop is an elementwise Mul of
+ // sections of the arrays.
+ uint8* output_data_ptr = output_data;
+ const uint8* input1_data_ptr = input1_data;
+ const uint8* input2_data_reset = input2_data;
+ int y0 = params.broadcast_shape[0];
+ int y1 = params.broadcast_shape[1];
+ int y2 = params.broadcast_shape[2];
+ int y3 = params.broadcast_shape[3];
+ int y4 = params.broadcast_shape[4];
+ if (y4 > 1) {
+ for (int i0 = 0; i0 < y0; ++i0) {
+ const uint8* input2_data_ptr;
+ for (int i1 = 0; i1 < y1; ++i1) {
+ input2_data_ptr = input2_data_reset;
+ for (int i2 = 0; i2 < y2; ++i2) {
+ for (int i3 = 0; i3 < y3; ++i3) {
+ MulElementwise(y4, params, input1_data_ptr, input2_data_ptr,
+ output_data_ptr);
+ input2_data_ptr += y4;
+ output_data_ptr += y4;
+ }
+ input1_data_ptr += y4;
}
}
+ input2_data_reset = input2_data_ptr;
+ }
+ } else {
+ for (int i0 = 0; i0 < y0; ++i0) {
+ const uint8* input2_data_ptr;
+ for (int i1 = 0; i1 < y1; ++i1) {
+ input2_data_ptr = input2_data_reset;
+ for (int i2 = 0; i2 < y2; ++i2) {
+ MulSimpleBroadcast(y3, params, *input1_data_ptr, input2_data_ptr,
+ output_data_ptr);
+ input2_data_ptr += y3;
+ output_data_ptr += y3;
+ ++input1_data_ptr;
+ }
+ }
+ input2_data_reset = input2_data_ptr;
}
}
}
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims,
- int32 input1_offset, const uint8* input2_data,
- const Dims<4>& input2_dims, int32 input2_offset,
- int32 output_offset, int32 output_multiplier,
- int output_shift, int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims) {
- BroadcastMul(input1_data, input1_dims, input1_offset, input2_data,
- input2_dims, input2_offset, output_offset, output_multiplier,
- output_shift, output_activation_min, output_activation_max,
- output_data, output_dims);
-}
-
// TODO(jiawen): We can implement BroadcastDiv on buffers of arbitrary
// dimensionality if the runtime code does a single loop over one dimension
// that handles broadcasting as the base case. The code generator would then
@@ -5383,31 +5482,53 @@ void TypedMemset(void* ptr, T value, size_t num) {
}
}
-template <typename T>
-inline void PadV2(const T* input_data, const Dims<4>& input_dims,
- const std::vector<int>& left_paddings,
- const std::vector<int>& right_paddings, T* output_data,
- const Dims<4>& output_dims, const T pad_value) {
+// There are two versions of pad: Pad and PadV2. In PadV2 there is a second
+// scalar input that provides the padding value, so pad_value_ptr can simply
+// point at that input's data. For Pad, it should point to a zero value.
+//
+// Note that two typenames are required, so that the T=P=int32 case gets its
+// own specialization, distinct from the generic overload taking P=int32.
+template <typename T, typename P>
+inline void PadImpl(const tflite::PadParams& op_params,
+ const RuntimeShape& input_shape, const T* input_data,
+ const P* pad_value_ptr, const RuntimeShape& output_shape,
+ T* output_data) {
gemmlowp::ScopedProfilingLabel label("Pad");
- TFLITE_DCHECK_EQ(left_paddings.size(), 4);
- TFLITE_DCHECK_EQ(right_paddings.size(), 4);
+ RuntimeShape ext_input_shape = RuntimeShape::ExtendedShape(4, input_shape);
+ RuntimeShape ext_output_shape = RuntimeShape::ExtendedShape(4, output_shape);
+ TFLITE_DCHECK_LE(op_params.left_padding_count, 4);
+ TFLITE_DCHECK_LE(op_params.right_padding_count, 4);
+
+ // Runtime calls are currently fixed at 4 dimensions. Copy inputs so
+ // we can pad them to 4 dims (yes, we are "padding the padding").
+ std::vector<int> left_padding_copy(4, 0);
+ for (int i = 0; i < op_params.left_padding_count; ++i) {
+ left_padding_copy[i] = op_params.left_padding[i];
+ }
+ std::vector<int> right_padding_copy(4, 0);
+ for (int i = 0; i < op_params.right_padding_count; ++i) {
+ right_padding_copy[i] = op_params.right_padding[i];
+ }
- const int output_batch = ArraySize(output_dims, 3);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
- const int output_depth = ArraySize(output_dims, 0);
+ const int output_batch = ext_output_shape.Dims(0);
+ const int output_height = ext_output_shape.Dims(1);
+ const int output_width = ext_output_shape.Dims(2);
+ const int output_depth = ext_output_shape.Dims(3);
- const int left_b_padding = left_paddings[3];
- const int left_h_padding = left_paddings[2];
- const int left_w_padding = left_paddings[1];
- const int left_d_padding = left_paddings[0];
+ const int left_b_padding = left_padding_copy[0];
+ const int left_h_padding = left_padding_copy[1];
+ const int left_w_padding = left_padding_copy[2];
+ const int left_d_padding = left_padding_copy[3];
- const int right_b_padding = right_paddings[3];
- const int right_h_padding = right_paddings[2];
- const int right_w_padding = right_paddings[1];
- const int right_d_padding = right_paddings[0];
+ const int right_b_padding = right_padding_copy[0];
+ const int right_h_padding = right_padding_copy[1];
+ const int right_w_padding = right_padding_copy[2];
+ const int right_d_padding = right_padding_copy[3];
- const int input_depth = ArraySize(input_dims, 0);
+ const int input_depth = ext_input_shape.Dims(3);
+ // const T pad_value = ExtractFloatOrInt<T>(op_params.pad_value);
+ const T pad_value = *pad_value_ptr;
if (left_b_padding != 0) {
TypedMemset<T>(
@@ -5417,61 +5538,113 @@ inline void PadV2(const T* input_data, const Dims<4>& input_dims,
for (int out_b = left_b_padding; out_b < output_batch - right_b_padding;
++out_b) {
if (left_h_padding != 0) {
- TypedMemset<T>(output_data + Offset(output_dims, 0, 0, 0, out_b),
+ TypedMemset<T>(output_data + Offset(ext_output_shape, out_b, 0, 0, 0),
pad_value, left_h_padding * output_width * output_depth);
}
for (int out_h = left_h_padding; out_h < output_height - right_h_padding;
++out_h) {
if (left_w_padding != 0) {
- TypedMemset<T>(output_data + Offset(output_dims, 0, 0, out_h, out_b),
- pad_value, left_w_padding * output_depth);
+ TypedMemset<T>(
+ output_data + Offset(ext_output_shape, out_b, out_h, 0, 0),
+ pad_value, left_w_padding * output_depth);
}
for (int out_w = left_w_padding; out_w < output_width - right_w_padding;
++out_w) {
if (left_d_padding != 0) {
TypedMemset<T>(
- output_data + Offset(output_dims, 0, out_w, out_h, out_b),
+ output_data + Offset(ext_output_shape, out_b, out_h, out_w, 0),
pad_value, left_d_padding);
}
T* out = output_data +
- Offset(output_dims, left_d_padding, out_w, out_h, out_b);
- const T* in =
- input_data + Offset(input_dims, 0, out_w - left_w_padding,
- out_h - left_h_padding, out_b - left_b_padding);
+ Offset(ext_output_shape, out_b, out_h, out_w, left_d_padding);
+ const T* in = input_data +
+ Offset(ext_input_shape, out_b - left_b_padding,
+ out_h - left_h_padding, out_w - left_w_padding, 0);
memcpy(out, in, input_depth * sizeof(T));
if (right_d_padding != 0) {
TypedMemset<T>(
- output_data + Offset(output_dims, output_depth - right_d_padding,
- out_w, out_h, out_b),
+ output_data + Offset(ext_output_shape, out_b, out_h, out_w,
+ output_depth - right_d_padding),
pad_value, right_d_padding);
}
}
if (right_w_padding != 0) {
- TypedMemset<T>(
- output_data + Offset(output_dims, 0, output_width - right_w_padding,
- out_h, out_b),
- pad_value, right_w_padding * output_depth);
+ TypedMemset<T>(output_data + Offset(ext_output_shape, out_b, out_h,
+ output_width - right_w_padding, 0),
+ pad_value, right_w_padding * output_depth);
}
}
if (right_h_padding != 0) {
TypedMemset<T>(
- output_data +
- Offset(output_dims, 0, 0, output_height - right_h_padding, out_b),
+ output_data + Offset(ext_output_shape, out_b,
+ output_height - right_h_padding, 0, 0),
pad_value, right_h_padding * output_width * output_depth);
}
}
if (right_b_padding != 0) {
TypedMemset<T>(
output_data +
- Offset(output_dims, 0, 0, 0, output_batch - right_b_padding),
+ Offset(ext_output_shape, output_batch - right_b_padding, 0, 0, 0),
pad_value,
right_b_padding * output_height * output_width * output_depth);
}
}
-// Legacy Pad() method that casts an int32_t to T before padding.
+template <typename T, typename P>
+inline void Pad(const tflite::PadParams& op_params,
+ const RuntimeShape& input_shape, const T* input_data,
+ const P* pad_value_ptr, const RuntimeShape& output_shape,
+ T* output_data) {
+ PadImpl(op_params, input_shape, input_data, pad_value_ptr, output_shape,
+ output_data);
+}
+
+// The second (pad-value) input can be int32 when, say, the first is uint8.
+template <typename T>
+inline void Pad(const tflite::PadParams& op_params,
+ const RuntimeShape& input_shape, const T* input_data,
+ const int32* pad_value_ptr, const RuntimeShape& output_shape,
+ T* output_data) {
+ const T converted_pad_value = static_cast<T>(*pad_value_ptr);
+ PadImpl(op_params, input_shape, input_data, &converted_pad_value,
+ output_shape, output_data);
+}
+
+// This full specialization avoids ambiguous template matching when T and P
+// are both int32.
+template <>
+inline void Pad(const tflite::PadParams& op_params,
+ const RuntimeShape& input_shape, const int32* input_data,
+ const int32* pad_value_ptr, const RuntimeShape& output_shape,
+ int32* output_data) {
+ PadImpl(op_params, input_shape, input_data, pad_value_ptr, output_shape,
+ output_data);
+}
+
+// Legacy signature; this one function covered both Pad and PadV2.
+template <typename T>
+inline void PadV2(const T* input_data, const Dims<4>& input_dims,
+ const std::vector<int>& left_paddings,
+ const std::vector<int>& right_paddings, T* output_data,
+ const Dims<4>& output_dims, const T pad_value) {
+ TFLITE_DCHECK_EQ(left_paddings.size(), 4);
+ TFLITE_DCHECK_EQ(right_paddings.size(), 4);
+ tflite::PadParams op_params;
+ op_params.left_padding_count = 4;
+ op_params.right_padding_count = 4;
+ for (int i = 0; i < 4; ++i) {
+ op_params.left_padding[i] = left_paddings[3 - i];
+ op_params.right_padding[i] = right_paddings[3 - i];
+ }
+ // SetFloatOrInt(pad_value, &op_params.pad_value);
+ const T pad_value_copy = pad_value;
+
+ Pad(op_params, DimsToShape(input_dims), input_data, &pad_value_copy,
+ DimsToShape(output_dims), output_data);
+}
+
+// Old Pad that calls legacy PadV2.
template <typename T>
inline void Pad(const T* input_data, const Dims<4>& input_dims,
const std::vector<int>& left_paddings,
@@ -5482,34 +5655,45 @@ inline void Pad(const T* input_data, const Dims<4>& input_dims,
output_dims, converted_pad_value);
}
+// Old Pad that only padded with 0.
template <typename T>
inline void Pad(const T* input_data, const Dims<4>& input_dims,
const std::vector<int>& left_paddings,
const std::vector<int>& right_paddings, T* output_data,
const Dims<4>& output_dims) {
- Pad(input_data, input_dims, left_paddings, right_paddings, output_data,
- output_dims, 0);
+ const T pad_value = static_cast<T>(0);
+ PadV2<T>(input_data, input_dims, left_paddings, right_paddings, output_data,
+ output_dims, pad_value);
}
template <typename T>
-inline void Slice(const T* input_data, const Dims<4>& input_dims,
- const std::vector<int>& begin, const std::vector<int>& size,
- T* output_data, const Dims<4>& output_dims) {
- // TODO(dkalenichenko): This op only supports 4D tensors.
- TFLITE_DCHECK_EQ(begin.size(), 4);
- TFLITE_DCHECK_EQ(size.size(), 4);
- const int start_b = begin[3];
- const int stop_b =
- size[3] == -1 ? input_dims.sizes[3] - start_b : start_b + size[3];
- const int start_h = begin[2];
- const int stop_h =
- size[2] == -1 ? input_dims.sizes[2] - start_h : start_h + size[2];
- const int start_w = begin[1];
- const int stop_w =
- size[1] == -1 ? input_dims.sizes[1] - start_w : start_w + size[1];
- const int start_d = begin[0];
- const int stop_d =
- size[0] == -1 ? input_dims.sizes[0] - start_d : start_d + size[0];
+inline void Slice(const tflite::SliceParams& op_params,
+ const RuntimeShape& input_shape, const T* input_data,
+ const RuntimeShape& output_shape, T* output_data) {
+ gemmlowp::ScopedProfilingLabel label("Slice");
+ RuntimeShape ext_shape = RuntimeShape::ExtendedShape(4, input_shape);
+ // TODO(dkalenichenko): This op only supports 4D tensors or smaller.
+ TFLITE_DCHECK_LE(op_params.begin_count, 4);
+ TFLITE_DCHECK_LE(op_params.size_count, 4);
+ const int begin_count = op_params.begin_count;
+ const int size_count = op_params.size_count;
+ // We front-pad the begin and size vectors.
+ const int start_b = 4 - begin_count > 0 ? 0 : op_params.begin[0];
+ const int stop_b = (4 - size_count > 0 || op_params.size[0] == -1)
+ ? ext_shape.Dims(0) - start_b
+ : start_b + op_params.size[0];
+ const int start_h = begin_count < 3 ? 0 : op_params.begin[begin_count - 3];
+ const int stop_h = (size_count < 3 || op_params.size[size_count - 3] == -1)
+ ? ext_shape.Dims(1) - start_h
+ : start_h + op_params.size[size_count - 3];
+ const int start_w = begin_count < 2 ? 0 : op_params.begin[begin_count - 2];
+ const int stop_w = (size_count < 2 || op_params.size[size_count - 2] == -1)
+ ? ext_shape.Dims(2) - start_w
+ : start_w + op_params.size[size_count - 2];
+ const int start_d = begin_count < 1 ? 0 : op_params.begin[begin_count - 1];
+ const int stop_d = (size_count < 1 || op_params.size[size_count - 1] == -1)
+ ? ext_shape.Dims(3) - start_d
+ : start_d + op_params.size[size_count - 1];
T* out_ptr = output_data;
for (int in_b = start_b; in_b < stop_b; ++in_b) {
@@ -5517,7 +5701,7 @@ inline void Slice(const T* input_data, const Dims<4>& input_dims,
for (int in_w = start_w; in_w < stop_w; ++in_w) {
const int len = stop_d - start_d;
memcpy(out_ptr,
- input_data + Offset(input_dims, start_d, in_w, in_h, in_b),
+ input_data + Offset(ext_shape, in_b, in_h, in_w, start_d),
len * sizeof(T));
out_ptr += len;
}
@@ -5526,28 +5710,60 @@ inline void Slice(const T* input_data, const Dims<4>& input_dims,
}
template <typename T>
-void TensorFlowMinimum(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, T* output_data,
- const Dims<4>& output_dims) {
+inline void Slice(const T* input_data, const Dims<4>& input_dims,
+ const std::vector<int>& begin, const std::vector<int>& size,
+ T* output_data, const Dims<4>& output_dims) {
+ tflite::SliceParams op_params;
+ op_params.begin_count = 4;
+ op_params.size_count = 4;
+ for (int i = 0; i < 4; ++i) {
+ op_params.begin[i] = begin[3 - i];
+ op_params.size[i] = size[3 - i];
+ }
+
+ Slice(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <typename T>
+void Minimum(const RuntimeShape& input1_shape, const T* input1_data,
+ const T* input2_data, const RuntimeShape& output_shape,
+ T* output_data) {
gemmlowp::ScopedProfilingLabel label("TensorFlowMinimum");
- auto input1_map = MapAsVector(input1_data, input1_dims);
- auto output_map = MapAsVector(output_data, output_dims);
+ auto input1_map = MapAsVector(input1_data, input1_shape);
+ auto output_map = MapAsVector(output_data, output_shape);
auto min_value = input2_data[0];
output_map.array() = input1_map.array().min(min_value);
}
template <typename T>
-void TensorFlowMaximum(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, T* output_data,
- const Dims<4>& output_dims) {
+void Maximum(const RuntimeShape& input1_shape, const T* input1_data,
+ const T* input2_data, const RuntimeShape& output_shape,
+ T* output_data) {
gemmlowp::ScopedProfilingLabel label("TensorFlowMaximum");
- auto input1_map = MapAsVector(input1_data, input1_dims);
- auto output_map = MapAsVector(output_data, output_dims);
+ auto input1_map = MapAsVector(input1_data, input1_shape);
+ auto output_map = MapAsVector(output_data, output_shape);
auto max_value = input2_data[0];
output_map.array() = input1_map.array().max(max_value);
}
template <typename T>
+void TensorFlowMinimum(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, T* output_data,
+ const Dims<4>& output_dims) {
+ Minimum(DimsToShape(input1_dims), input1_data, input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <typename T>
+void TensorFlowMaximum(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, T* output_data,
+ const Dims<4>& output_dims) {
+ Maximum(DimsToShape(input1_dims), input1_data, input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <typename T>
void TransposeIm2col(const T* input_data, const Dims<4>& input_dims,
const Dims<4>& filter_dims, int stride_width,
int stride_height, int pad_width, int pad_height,
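
Written out per element, MulElementwise (and hence Mul and both broadcast paths above) computes

    q_{\text{out}} = \operatorname{clamp}\!\Bigl( Z_{\text{out}} + \operatorname{rnd}\bigl( (q_1 + o_1)(q_2 + o_2)\, M \cdot 2^{\,\text{shift}} \bigr),\ a_{\min},\ a_{\max} \Bigr)

where rnd denotes the fixed-point rounding done by MultiplyByQuantizedMultiplierSmallerThanOneExp. By the convention used elsewhere in this patch (see EvalQuantized in conv.cc), the offsets o_1, o_2 are the negated input zero points; reading M and shift as an encoding of the combined real scale S_1 S_2 / S_out is the usual gemmlowp-style interpretation, stated here as background rather than taken from this diff.
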
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h
index bcf5e4e4f6..b862ae38c7 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h
@@ -26,11 +26,6 @@ namespace tflite {
namespace reference_ops {
-inline RuntimeShape DimsToShape(const tflite::Dims<4>& dims) {
- return RuntimeShape(
- {dims.sizes[3], dims.sizes[2], dims.sizes[1], dims.sizes[0]});
-}
-
template <FusedActivationFunctionType Ac>
void L2Normalization(const float* input_data, const Dims<4>& input_dims,
float* output_data, const Dims<4>& output_dims) {
@@ -316,6 +311,37 @@ inline void AveragePool(const float* input_data, const Dims<4>& input_dims,
DimsToShape(output_dims), output_data);
}
+inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims,
+ int32 input1_offset, const uint8* input2_data,
+ const Dims<4>& input2_dims, int32 input2_offset,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ BroadcastMul4DSlow(
+ input1_data, input1_dims, input1_offset, input2_data, input2_dims,
+ input2_offset, output_offset, output_multiplier,
+      // This legacy version switches the sign of the output shift.
+      kReverseShift * output_shift,
+ output_activation_min, output_activation_max, output_data, output_dims);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims,
+ int32 input1_offset, const uint8* input2_data,
+ const Dims<4>& input2_dims, int32 input2_offset,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ BroadcastMul(input1_data, input1_dims, input1_offset, input2_data,
+ input2_dims, input2_offset, output_offset, output_multiplier,
+ output_shift, output_activation_min, output_activation_max,
+ output_data, output_dims);
+}
+
// legacy, for compatibility with old checked-in code
template <FusedActivationFunctionType Ac>
void AveragePool(const float* input_data, const Dims<4>& input_dims,
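
One convention all of these legacy wrappers rely on is worth spelling out: Dims<4> stores sizes innermost-first (index 0 varies fastest), while RuntimeShape is batch-major, so DimsToShape and the padding/begin/size copy loops reverse index order. A minimal sketch with illustrative sizes:

    // Hypothetical values; Dims<4>::sizes is ordered {depth, width, height, batch}.
    tflite::Dims<4> dims;
    dims.sizes[0] = 8;  // depth (fastest-varying)
    dims.sizes[1] = 4;  // width
    dims.sizes[2] = 2;  // height
    dims.sizes[3] = 1;  // batch
    RuntimeShape shape = DimsToShape(dims);  // {1, 2, 4, 8}, i.e. {b, h, w, d}
    // The same reversal is why the legacy PadV2 wrapper copies
    // op_params.left_padding[i] = left_paddings[3 - i].
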
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
index f4176e474e..cb254f36cc 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
@@ -105,6 +105,11 @@ namespace reference_ops {
// Used mainly to convert from old-style shifts (right) to new-style (left).
static constexpr int kReverseShift = -1;
+inline RuntimeShape DimsToShape(const tflite::Dims<4>& dims) {
+ return RuntimeShape(
+ {dims.sizes[3], dims.sizes[2], dims.sizes[1], dims.sizes[0]});
+}
+
template <typename T>
int CountLeadingZeros(T integer_input) {
static_assert(std::is_unsigned<T>::value,
@@ -271,12 +276,12 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
int32 input_offset, const uint8* filter_data,
const Dims<4>& filter_dims, int32 filter_offset,
const int32* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int32 output_offset, int32 output_multiplier,
- int output_shift, int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims, uint8* im2col_data,
- const Dims<4>& im2col_dims,
+ int stride_width, int stride_height, int dilation_width_factor,
+ int dilation_height_factor, int pad_width, int pad_height,
+ int32 output_offset, int32 output_multiplier, int output_shift,
+ int32 output_activation_min, int32 output_activation_max,
+ uint8* output_data, const Dims<4>& output_dims,
+ uint8* im2col_data, const Dims<4>& im2col_dims,
gemmlowp::GemmContext* gemm_context) {
(void)im2col_data; // only used in optimized code.
(void)im2col_dims; // only used in optimized code.
@@ -302,8 +307,9 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
for (int in_channel = 0; in_channel < input_depth; ++in_channel) {
- const int in_x = in_x_origin + filter_x;
- const int in_y = in_y_origin + filter_y;
+ const int in_x = in_x_origin + dilation_width_factor * filter_x;
+ const int in_y =
+ in_y_origin + dilation_height_factor * filter_y;
// If the location is outside the bounds of the input image,
// use zero as a default value.
if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) &&
@@ -335,6 +341,24 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
}
}
+inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims, uint8* im2col_data,
+ const Dims<4>& im2col_dims,
+ gemmlowp::GemmContext* gemm_context) {
+ Conv(input_data, input_dims, input_offset, filter_data, filter_dims,
+ filter_offset, bias_data, bias_dims, stride_width, stride_height, 1, 1,
+ pad_width, pad_height, output_offset, output_multiplier, output_shift,
+ output_activation_min, output_activation_max, output_data, output_dims,
+ im2col_data, im2col_dims, gemm_context);
+}
+
// legacy, for compatibility with old checked-in code
template <FusedActivationFunctionType Ac>
inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
@@ -1374,13 +1398,143 @@ void BroadcastMul(const T* input1_data, const Dims<4>& input1_dims,
output_dims);
}
-inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims,
- int32 input1_offset, const uint8* input2_data,
- const Dims<4>& input2_dims, int32 input2_offset,
- int32 output_offset, int32 output_multiplier,
- int output_shift, int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims) {
+// Element-wise mul that can often be used for the inner loop of a broadcast
+// Mul as well as for the non-broadcast Mul.
+inline void MulElementwise(int size, const ArithmeticParams& params,
+ const uint8* input1_data, const uint8* input2_data,
+ uint8* output_data) {
+ for (int i = 0; i < size; ++i) {
+ const int32 input1_val = params.input1_offset + input1_data[i];
+ const int32 input2_val = params.input2_offset + input2_data[i];
+ const int32 unclamped_result =
+ params.output_offset +
+ MultiplyByQuantizedMultiplierSmallerThanOneExp(input1_val * input2_val,
+ params.output_multiplier,
+ params.output_shift);
+ const int32 clamped_output =
+ std::min(params.quantized_activation_max,
+ std::max(params.quantized_activation_min, unclamped_result));
+ output_data[i] = static_cast<uint8>(clamped_output);
+ }
+}
+
+inline void Mul(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape, const uint8* input1_data,
+ const RuntimeShape& input2_shape, const uint8* input2_data,
+ const RuntimeShape& output_shape, uint8* output_data) {
+ TFLITE_DCHECK_LE(params.quantized_activation_min,
+ params.quantized_activation_max);
+ gemmlowp::ScopedProfilingLabel label("Mul/8bit");
+ const int flat_size =
+ MatchingFlatSize(input1_shape, input2_shape, output_shape);
+
+ MulElementwise(flat_size, params, input1_data, input2_data, output_data);
+}
+
+inline void BroadcastMulFivefold(const ArithmeticParams& unswitched_params,
+ const RuntimeShape& unswitched_input1_shape,
+ const uint8* unswitched_input1_data,
+ const RuntimeShape& unswitched_input2_shape,
+ const uint8* unswitched_input2_data,
+ const RuntimeShape& output_shape,
+ uint8* output_data) {
+ ArithmeticParams switched_params = unswitched_params;
+ switched_params.input1_offset = unswitched_params.input2_offset;
+ switched_params.input2_offset = unswitched_params.input1_offset;
+
+ const bool use_unswitched =
+ unswitched_params.broadcast_category ==
+ tflite::BroadcastableOpCategory::kFirstInputBroadcastsFast;
+
+ const ArithmeticParams& params =
+ use_unswitched ? unswitched_params : switched_params;
+ const uint8* input1_data =
+ use_unswitched ? unswitched_input1_data : unswitched_input2_data;
+ const uint8* input2_data =
+ use_unswitched ? unswitched_input2_data : unswitched_input1_data;
+
+ // Fivefold nested loops. The second input resets its position for each
+ // iteration of the second loop. The first input resets its position at the
+ // beginning of the fourth loop. The innermost loop is an elementwise Mul of
+ // sections of the arrays.
+ uint8* output_data_ptr = output_data;
+ const uint8* input1_data_ptr = input1_data;
+ const uint8* input2_data_reset = input2_data;
+ int y0 = params.broadcast_shape[0];
+ int y1 = params.broadcast_shape[1];
+ int y2 = params.broadcast_shape[2];
+ int y3 = params.broadcast_shape[3];
+ int y4 = params.broadcast_shape[4];
+ for (int i0 = 0; i0 < y0; ++i0) {
+ const uint8* input2_data_ptr;
+ for (int i1 = 0; i1 < y1; ++i1) {
+ input2_data_ptr = input2_data_reset;
+ for (int i2 = 0; i2 < y2; ++i2) {
+ for (int i3 = 0; i3 < y3; ++i3) {
+ MulElementwise(y4, params, input1_data_ptr, input2_data_ptr,
+ output_data_ptr);
+ input2_data_ptr += y4;
+ output_data_ptr += y4;
+ }
+ input1_data_ptr += y4;
+ }
+ }
+ input2_data_reset = input2_data_ptr;
+ }
+}
+
+inline void BroadcastMul4DSlow(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape,
+ const uint8* input1_data,
+ const RuntimeShape& input2_shape,
+ const uint8* input2_data,
+ const RuntimeShape& output_shape,
+ uint8* output_data) {
+ gemmlowp::ScopedProfilingLabel label("BroadcastMul4DSlow/8bit");
+
+ NdArrayDesc<4> desc1;
+ NdArrayDesc<4> desc2;
+ NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
+ &desc2);
+ RuntimeShape extended_output_shape =
+ RuntimeShape::ExtendedShape(4, output_shape);
+
+ for (int b = 0; b < extended_output_shape.Dims(0); ++b) {
+ for (int y = 0; y < extended_output_shape.Dims(1); ++y) {
+ for (int x = 0; x < extended_output_shape.Dims(2); ++x) {
+ for (int c = 0; c < extended_output_shape.Dims(3); ++c) {
+ const int32 input1_val =
+ params.input1_offset +
+ input1_data[SubscriptToIndex(desc1, b, y, x, c)];
+ const int32 input2_val =
+ params.input2_offset +
+ input2_data[SubscriptToIndex(desc2, b, y, x, c)];
+ const int32 unclamped_result =
+ params.output_offset +
+ MultiplyByQuantizedMultiplierSmallerThanOneExp(
+ input1_val * input2_val, params.output_multiplier,
+ params.output_shift);
+ const int32 clamped_output = std::min(
+ params.quantized_activation_max,
+ std::max(params.quantized_activation_min, unclamped_result));
+ output_data[Offset(extended_output_shape, b, y, x, c)] =
+ static_cast<uint8>(clamped_output);
+ }
+ }
+ }
+ }
+}
+
+// Transitional version that will be moved shortly to legacy_reference_ops, as
+// part of RuntimeShape revisions.
+inline void BroadcastMul4DSlow(const uint8* input1_data,
+ const Dims<4>& input1_dims, int32 input1_offset,
+ const uint8* input2_data,
+ const Dims<4>& input2_dims, int32 input2_offset,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
gemmlowp::ScopedProfilingLabel label("BroadcastMul/8bit");
NdArrayDesc<4> desc1;
@@ -1407,9 +1561,9 @@ inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims,
const int32 input2_val =
input2_offset + input2_data[SubscriptToIndex(desc2, c, x, y, b)];
const int32 unclamped_result =
- output_offset + MultiplyByQuantizedMultiplierSmallerThanOneExp(
- input1_val * input2_val, output_multiplier,
- kReverseShift * output_shift);
+ output_offset +
+ MultiplyByQuantizedMultiplierSmallerThanOneExp(
+ input1_val * input2_val, output_multiplier, output_shift);
const int32 clamped_output =
std::min(output_activation_max,
std::max(output_activation_min, unclamped_result));
@@ -1464,21 +1618,6 @@ inline void Mul(const int16* input1_data, const Dims<4>& input1_dims,
}
}
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims,
- int32 input1_offset, const uint8* input2_data,
- const Dims<4>& input2_dims, int32 input2_offset,
- int32 output_offset, int32 output_multiplier,
- int output_shift, int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims) {
- BroadcastMul(input1_data, input1_dims, input1_offset, input2_data,
- input2_dims, input2_offset, output_offset, output_multiplier,
- output_shift, output_activation_min, output_activation_max,
- output_data, output_dims);
-}
-
// TODO(jiawen): We can implement BroadcastDiv on buffers of arbitrary
// dimensionality if the runtime code does a single loop over one dimension
// that handles broadcasting as the base case. The code generator would then
@@ -3370,28 +3509,50 @@ inline void BatchToSpaceND(const T* input_data, const Dims<4>& input_dims,
}
}
-template <typename T>
-inline void PadV2(const T* input_data, const Dims<4>& input_dims,
- const std::vector<int>& left_paddings,
- const std::vector<int>& right_paddings, T* output_data,
- const Dims<4>& output_dims, const T pad_value) {
- TFLITE_DCHECK_EQ(left_paddings.size(), 4);
- TFLITE_DCHECK_EQ(right_paddings.size(), 4);
-
- const int output_batch = ArraySize(output_dims, 3);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
- const int output_depth = ArraySize(output_dims, 0);
-
- const int left_b_padding = left_paddings[3];
- const int left_h_padding = left_paddings[2];
- const int left_w_padding = left_paddings[1];
- const int left_d_padding = left_paddings[0];
-
- const int right_b_padding = right_paddings[3];
- const int right_h_padding = right_paddings[2];
- const int right_w_padding = right_paddings[1];
- const int right_d_padding = right_paddings[0];
+// There are two versions of pad: Pad and PadV2. In PadV2 there is a second
+// scalar input that provides the padding value, so pad_value_ptr can simply
+// point at that input's data. For Pad, it should point to a zero value.
+//
+// Note that two typenames are required, so that the T=P=int32 case gets its
+// own specialization, distinct from the generic overload taking P=int32.
+template <typename T, typename P>
+inline void PadImpl(const tflite::PadParams& op_params,
+ const RuntimeShape& input_shape, const T* input_data,
+ const P* pad_value_ptr, const RuntimeShape& output_shape,
+ T* output_data) {
+ RuntimeShape ext_input_shape = RuntimeShape::ExtendedShape(4, input_shape);
+ RuntimeShape ext_output_shape = RuntimeShape::ExtendedShape(4, output_shape);
+ TFLITE_DCHECK_LE(op_params.left_padding_count, 4);
+ TFLITE_DCHECK_LE(op_params.right_padding_count, 4);
+
+ // Runtime calls are currently fixed at 4 dimensions. Copy inputs so
+ // we can pad them to 4 dims (yes, we are "padding the padding").
+ std::vector<int> left_padding_copy(4, 0);
+ for (int i = 0; i < op_params.left_padding_count; ++i) {
+ left_padding_copy[i] = op_params.left_padding[i];
+ }
+ std::vector<int> right_padding_copy(4, 0);
+ for (int i = 0; i < op_params.right_padding_count; ++i) {
+ right_padding_copy[i] = op_params.right_padding[i];
+ }
+
+ const int output_batch = ext_output_shape.Dims(0);
+ const int output_height = ext_output_shape.Dims(1);
+ const int output_width = ext_output_shape.Dims(2);
+ const int output_depth = ext_output_shape.Dims(3);
+
+ const int left_b_padding = left_padding_copy[0];
+ const int left_h_padding = left_padding_copy[1];
+ const int left_w_padding = left_padding_copy[2];
+ const int left_d_padding = left_padding_copy[3];
+
+ const int right_b_padding = right_padding_copy[0];
+ const int right_h_padding = right_padding_copy[1];
+ const int right_w_padding = right_padding_copy[2];
+ const int right_d_padding = right_padding_copy[3];
+
+ const T pad_value = *pad_value_ptr;
const T* in_ptr = input_data;
T* out_ptr = output_data;
@@ -3417,7 +3578,59 @@ inline void PadV2(const T* input_data, const Dims<4>& input_dims,
}
}
-// Legacy Pad() method that casts an int32_t to T before padding.
+template <typename T, typename P>
+inline void Pad(const tflite::PadParams& op_params,
+ const RuntimeShape& input_shape, const T* input_data,
+ const P* pad_value_ptr, const RuntimeShape& output_shape,
+ T* output_data) {
+ PadImpl(op_params, input_shape, input_data, pad_value_ptr, output_shape,
+ output_data);
+}
+
+// The second (pad-value) input can be int32 when, say, the first is uint8.
+template <typename T>
+inline void Pad(const tflite::PadParams& op_params,
+ const RuntimeShape& input_shape, const T* input_data,
+ const int32* pad_value_ptr, const RuntimeShape& output_shape,
+ T* output_data) {
+ const T converted_pad_value = static_cast<T>(*pad_value_ptr);
+ PadImpl(op_params, input_shape, input_data, &converted_pad_value,
+ output_shape, output_data);
+}
+
+// This version avoids conflicting template matching.
+template <>
+inline void Pad(const tflite::PadParams& op_params,
+ const RuntimeShape& input_shape, const int32* input_data,
+ const int32* pad_value_ptr, const RuntimeShape& output_shape,
+ int32* output_data) {
+ PadImpl(op_params, input_shape, input_data, pad_value_ptr, output_shape,
+ output_data);
+}
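
To see how the three overloads above dispatch, here is a rough sketch of call sites (illustrative only, not part of the patch; names and values are assumptions):

    // uint8 data with an int32 pad value: the casting overload converts first.
    tflite::PadParams params;
    params.left_padding_count = 4;
    params.right_padding_count = 4;
    for (int i = 0; i < 4; ++i) {
      params.left_padding[i] = (i == 1 || i == 2) ? 1 : 0;   // pad h and w by 1
      params.right_padding[i] = (i == 1 || i == 2) ? 1 : 0;
    }
    const int32 dims_in[] = {1, 4, 4, 1};
    const int32 dims_out[] = {1, 6, 6, 1};
    RuntimeShape in_shape(4, dims_in), out_shape(4, dims_out);
    uint8 in8[16] = {0}, out8[36];
    const int32 zero = 0;
    Pad(params, in_shape, in8, &zero, out_shape, out8);
    // int32 data with an int32 pad value: resolves to the full specialization,
    // which forwards straight to PadImpl without the cast.
    int32 in32[16] = {0}, out32[36];
    Pad(params, in_shape, in32, &zero, out_shape, out32);
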
+
+// Legacy signature; this function covered both Pad and PadV2.
+template <typename T>
+inline void PadV2(const T* input_data, const Dims<4>& input_dims,
+ const std::vector<int>& left_paddings,
+ const std::vector<int>& right_paddings, T* output_data,
+ const Dims<4>& output_dims, const T pad_value) {
+ TFLITE_DCHECK_EQ(left_paddings.size(), 4);
+ TFLITE_DCHECK_EQ(right_paddings.size(), 4);
+ tflite::PadParams op_params;
+ op_params.left_padding_count = 4;
+ op_params.right_padding_count = 4;
+ for (int i = 0; i < 4; ++i) {
+ op_params.left_padding[i] = left_paddings[3 - i];
+ op_params.right_padding[i] = right_paddings[3 - i];
+ }
+ // SetFloatOrInt(pad_value, &op_params.pad_value);
+ const T pad_value_copy = pad_value;
+
+ Pad(op_params, DimsToShape(input_dims), input_data, &pad_value_copy,
+ DimsToShape(output_dims), output_data);
+}
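
The `3 - i` reversal is the heart of this adapter: `Dims<4>` vectors are indexed fastest-varying first ({depth, width, height, batch}), while `PadParams` and `RuntimeShape` are batch-first. A minimal sketch of the mapping (illustrative values):

    // Legacy Dims<4> order {d, w, h, b} ...
    std::vector<int> left_paddings = {5, 4, 3, 2};
    int32 left_padding[4];
    for (int i = 0; i < 4; ++i) {
      left_padding[i] = left_paddings[3 - i];
    }
    // ... becomes batch-first {b, h, w, d} == {2, 3, 4, 5}.
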
+
+// Old Pad that calls legacy PadV2.
template <typename T>
inline void Pad(const T* input_data, const Dims<4>& input_dims,
const std::vector<int>& left_paddings,
@@ -3428,13 +3641,15 @@ inline void Pad(const T* input_data, const Dims<4>& input_dims,
output_dims, converted_pad_value);
}
+// Old Pad that only padded with 0.
template <typename T>
inline void Pad(const T* input_data, const Dims<4>& input_dims,
const std::vector<int>& left_paddings,
const std::vector<int>& right_paddings, T* output_data,
const Dims<4>& output_dims) {
- Pad(input_data, input_dims, left_paddings, right_paddings, output_data,
- output_dims, 0);
+ const T pad_value = static_cast<T>(0);
+ PadV2<T>(input_data, input_dims, left_paddings, right_paddings, output_data,
+ output_dims, pad_value);
}
template <typename T>
@@ -3491,31 +3706,39 @@ inline void StridedSlice(const T* input_data, const Dims<4>& input_dims,
}
template <typename T>
-inline void Slice(const T* input_data, const Dims<4>& input_dims,
- const std::vector<int>& begin, const std::vector<int>& size,
- T* output_data, const Dims<4>& output_dims) {
- // TODO(dkalenichenko): This op only supports 4D tensors.
- TFLITE_DCHECK_EQ(begin.size(), 4);
- TFLITE_DCHECK_EQ(size.size(), 4);
- const int start_b = begin[3];
- const int stop_b =
- size[3] == -1 ? input_dims.sizes[3] - start_b : start_b + size[3];
- const int start_h = begin[2];
- const int stop_h =
- size[2] == -1 ? input_dims.sizes[2] - start_h : start_h + size[2];
- const int start_w = begin[1];
- const int stop_w =
- size[1] == -1 ? input_dims.sizes[1] - start_w : start_w + size[1];
- const int start_d = begin[0];
- const int stop_d =
- size[0] == -1 ? input_dims.sizes[0] - start_d : start_d + size[0];
+inline void Slice(const tflite::SliceParams& op_params,
+ const RuntimeShape& input_shape, const T* input_data,
+ const RuntimeShape& output_shape, T* output_data) {
+ RuntimeShape ext_shape = RuntimeShape::ExtendedShape(4, input_shape);
+ // TODO(dkalenichenko): This op only supports 4D tensors or smaller.
+ TFLITE_DCHECK_LE(op_params.begin_count, 4);
+ TFLITE_DCHECK_LE(op_params.size_count, 4);
+ const int begin_count = op_params.begin_count;
+ const int size_count = op_params.size_count;
+ // We front-pad the begin and size vectors.
+ const int start_b = 4 - begin_count > 0 ? 0 : op_params.begin[0];
+ const int stop_b = (4 - size_count > 0 || op_params.size[0] == -1)
+ ? ext_shape.Dims(0) - start_b
+ : start_b + op_params.size[0];
+ const int start_h = begin_count < 3 ? 0 : op_params.begin[begin_count - 3];
+ const int stop_h = (size_count < 3 || op_params.size[size_count - 3] == -1)
+ ? ext_shape.Dims(1) - start_h
+ : start_h + op_params.size[size_count - 3];
+ const int start_w = begin_count < 2 ? 0 : op_params.begin[begin_count - 2];
+ const int stop_w = (size_count < 2 || op_params.size[size_count - 2] == -1)
+ ? ext_shape.Dims(2) - start_w
+ : start_w + op_params.size[size_count - 2];
+ const int start_d = begin_count < 1 ? 0 : op_params.begin[begin_count - 1];
+ const int stop_d = (size_count < 1 || op_params.size[size_count - 1] == -1)
+ ? ext_shape.Dims(3) - start_d
+ : start_d + op_params.size[size_count - 1];
T* out_ptr = output_data;
for (int in_b = start_b; in_b < stop_b; ++in_b) {
for (int in_h = start_h; in_h < stop_h; ++in_h) {
for (int in_w = start_w; in_w < stop_w; ++in_w) {
for (int in_d = start_d; in_d < stop_d; ++in_d) {
- *out_ptr++ = input_data[Offset(input_dims, in_d, in_w, in_h, in_b)];
+ *out_ptr++ = input_data[Offset(ext_shape, in_b, in_h, in_w, in_d)];
}
}
}
@@ -3523,6 +3746,22 @@ inline void Slice(const T* input_data, const Dims<4>& input_dims,
}
template <typename T>
+inline void Slice(const T* input_data, const Dims<4>& input_dims,
+ const std::vector<int>& begin, const std::vector<int>& size,
+ T* output_data, const Dims<4>& output_dims) {
+ tflite::SliceParams op_params;
+ op_params.begin_count = 4;
+ op_params.size_count = 4;
+ for (int i = 0; i < 4; ++i) {
+ op_params.begin[i] = begin[3 - i];
+ op_params.size[i] = size[3 - i];
+ }
+
+ Slice(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(output_dims), output_data);
+}
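
Since begin/size are front-padded, a lower-rank begin/size pair applies to the innermost dimensions of the extended 4-D shape. A worked sketch under that assumption (values are illustrative):

    tflite::SliceParams op_params;
    op_params.begin_count = 2;   // applies to the two innermost dims
    op_params.size_count = 2;
    op_params.begin[0] = 1;      // start at row 1
    op_params.begin[1] = 0;      // start at column 0
    op_params.size[0] = 2;       // take two rows
    op_params.size[1] = -1;      // -1 means "to the end of the dimension"
    // With input shape {5, 3} extended to {1, 1, 5, 3}, the loop bounds become
    //   start_b = start_h = 0, stop_b = stop_h = 1,
    //   start_w = 1, stop_w = 3, start_d = 0, stop_d = 3,
    // i.e. rows 1..2 and all three columns are copied.
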
+
+template <typename T>
inline void Exp(const T* input_data, const size_t num_elements,
T* output_data) {
for (size_t idx = 0; idx < num_elements; ++idx) {
@@ -3790,10 +4029,10 @@ inline void Mean(const T* input_data, const Dims<4>& input_dims,
}
template <typename T>
-void TensorFlowMinimum(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, T* output_data,
- const Dims<4>& output_dims) {
- const int flat_size = MatchingFlatSize(output_dims, input1_dims);
+void Minimum(const RuntimeShape& input1_shape, const T* input1_data,
+ const T* input2_data, const RuntimeShape& output_shape,
+ T* output_data) {
+ const int flat_size = MatchingFlatSize(input1_shape, output_shape);
auto min_value = input2_data[0];
for (int i = 0; i < flat_size; i++) {
@@ -3802,10 +4041,10 @@ void TensorFlowMinimum(const T* input1_data, const Dims<4>& input1_dims,
}
template <typename T>
-void TensorFlowMaximum(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, T* output_data,
- const Dims<4>& output_dims) {
- const int flat_size = MatchingFlatSize(output_dims, input1_dims);
+void Maximum(const RuntimeShape& input1_shape, const T* input1_data,
+ const T* input2_data, const RuntimeShape& output_shape,
+ T* output_data) {
+ const int flat_size = MatchingFlatSize(input1_shape, output_shape);
auto max_value = input2_data[0];
for (int i = 0; i < flat_size; i++) {
@@ -3813,6 +4052,22 @@ void TensorFlowMaximum(const T* input1_data, const Dims<4>& input1_dims,
}
}
+template <typename T>
+void TensorFlowMinimum(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, T* output_data,
+ const Dims<4>& output_dims) {
+ Minimum(DimsToShape(input1_dims), input1_data, input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <typename T>
+void TensorFlowMaximum(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, T* output_data,
+ const Dims<4>& output_dims) {
+ Maximum(DimsToShape(input1_dims), input1_data, input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
template <typename T, typename Op>
void TensorFlowMaximumMinimum(const T* input1_data, const Dims<4>& input1_dims,
const T* input2_data, const Dims<4>& input2_dims,
diff --git a/tensorflow/contrib/lite/kernels/internal/types.h b/tensorflow/contrib/lite/kernels/internal/types.h
index c44698b677..7b6838db53 100644
--- a/tensorflow/contrib/lite/kernels/internal/types.h
+++ b/tensorflow/contrib/lite/kernels/internal/types.h
@@ -129,6 +129,13 @@ class RuntimeShape {
}
}
+ RuntimeShape(int shape_size, int32 value) : size_(0) {
+ Resize(shape_size);
+ for (int i = 0; i < shape_size; ++i) {
+ SetDim(i, value);
+ }
+ }
+
RuntimeShape(int dimensions_count, const int32* dims_data) : size_(0) {
ReplaceWith(dimensions_count, dims_data);
}
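
The new fill constructor complements the existing ExtendedShape() helper used by PadImpl and Slice above. A quick sketch (illustrative):

    // A 4-D shape with every dimension set to 1.
    RuntimeShape ones(4, 1);                                   // {1, 1, 1, 1}
    // ExtendedShape pads a lower-rank shape with leading 1s.
    const int32 dims[] = {5, 3};
    RuntimeShape plane(2, dims);                               // {5, 3}
    RuntimeShape ext = RuntimeShape::ExtendedShape(4, plane);  // {1, 1, 5, 3}
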
@@ -237,7 +244,7 @@ class RuntimeShape {
bool operator!=(const RuntimeShape& comp) const { return !((*this) == comp); }
private:
- // For use only by ExtendFrom(), written to guarantee (return-value) copy
+ // For use only by ExtendedShape(), written to guarantee (return-value) copy
// elision in C++17.
// This creates a shape padded to the desired size with the specified value.
RuntimeShape(int new_shape_size, const RuntimeShape& shape, int pad_value)
@@ -645,22 +652,6 @@ void ComputeStrides(Dims<N>* dims) {
}
}
-struct PoolParams {
- FusedActivationFunctionType activation;
- PaddingType padding_type;
- PaddingValues padding_values;
- int stride_height;
- int stride_width;
- int filter_height;
- int filter_width;
- // uint8, etc, activation params.
- int32 quantized_activation_min;
- int32 quantized_activation_max;
- // float activation params.
- float float_activation_min;
- float float_activation_max;
-};
-
enum class BroadcastableOpCategory : uint8 {
kNone,
kNonBroadcast, // Matching input shapes.
@@ -721,6 +712,37 @@ inline void SetActivationParams(int32 min, int32 max,
params->quantized_activation_max = max;
}
+struct PadParams {
+ int8 left_padding_count;
+ int32 left_padding[4];
+ int8 right_padding_count;
+ int32 right_padding[4];
+ // FloatOrInt pad_value;
+};
+
+struct PoolParams {
+ FusedActivationFunctionType activation;
+ PaddingType padding_type;
+ PaddingValues padding_values;
+ int stride_height;
+ int stride_width;
+ int filter_height;
+ int filter_width;
+ // uint8, etc, activation params.
+ int32 quantized_activation_min;
+ int32 quantized_activation_max;
+ // float activation params.
+ float float_activation_min;
+ float float_activation_max;
+};
+
+struct SliceParams {
+ int8 begin_count;
+ int32 begin[4];
+ int8 size_count;
+ int32 size[4];
+};
+
} // namespace tflite
#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_
diff --git a/tensorflow/contrib/lite/kernels/mul.cc b/tensorflow/contrib/lite/kernels/mul.cc
index 349f3e6726..561e39cfc6 100644
--- a/tensorflow/contrib/lite/kernels/mul.cc
+++ b/tensorflow/contrib/lite/kernels/mul.cc
@@ -93,7 +93,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
input1->params.scale * input2->params.scale / output->params.scale;
QuantizeMultiplierSmallerThanOneExp(
real_multiplier, &data->output_multiplier, &data->output_shift);
- data->output_shift *= -1;
}
return context->ResizeTensor(context, output, output_size);
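
With `data->output_shift *= -1` gone, the shift coming out of QuantizeMultiplierSmallerThanOneExp is kept as a left-shift exponent (non-positive for multipliers below one) and consumed unchanged by the kernel, which is why the `kReverseShift` factor disappears from BroadcastMul4DSlow at the same time. A sketch of the convention, assuming TFLite's quantization helpers:

    double real_multiplier = 0.25;   // in1_scale * in2_scale / out_scale
    int32 output_multiplier;
    int output_shift;                // exponent; <= 0 for multipliers < 1
    QuantizeMultiplierSmallerThanOneExp(real_multiplier, &output_multiplier,
                                        &output_shift);
    const int32 acc = 123;           // some int32 accumulator
    // Consumed as-is, with no kReverseShift negation:
    const int32 scaled = MultiplyByQuantizedMultiplierSmallerThanOneExp(
        acc, output_multiplier, output_shift);
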
@@ -161,9 +160,9 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
// The quantized version of Mul doesn't support activations, so we
// always use BroadcastMul.
if (kernel_type == kReference) {
- TF_LITE_MUL(reference_ops, BroadcastMul);
+ TF_LITE_MUL(reference_ops, BroadcastMul4DSlow);
} else {
- TF_LITE_MUL(optimized_ops, BroadcastMul);
+ TF_LITE_MUL(optimized_ops, BroadcastMul4DSlow);
}
#undef TF_LITE_MUL
} else if (input1->type == kTfLiteInt16 && input2->type == kTfLiteInt16 &&
diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc
index 8d2c108116..6159311910 100644
--- a/tensorflow/contrib/lite/kernels/register.cc
+++ b/tensorflow/contrib/lite/kernels/register.cc
@@ -127,9 +127,11 @@ const TfLiteRegistration* BuiltinOpResolver::FindOp(tflite::BuiltinOperator op,
const TfLiteRegistration* BuiltinOpResolver::FindOp(const char* op,
int version) const {
- // Return the NULL Op for all ops whose name start with "Eager:", allowing
+ // Return the NULL Op for all ops whose name starts with "Eager", allowing
// the interpreter to delegate their execution.
- if (string(op).find("Eager:") == 0) {
+ // TODO(ycling): Refactor and extract an `IsEagerOp` function into
+ // `lite:framework` build target.
+ if (string(op).find("Eager") == 0) {
static TfLiteRegistration null_op{
nullptr, nullptr, &UnsupportedTensorFlowOp,
nullptr, nullptr, BuiltinOperator_CUSTOM,
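
The TODO suggests extracting the prefix check into a shared helper; a minimal sketch of what `IsEagerOp` might look like (hypothetical, not part of the patch):

    #include <cstring>

    inline bool IsEagerOp(const char* custom_name) {
      return custom_name != nullptr &&
             std::strncmp(custom_name, "Eager", std::strlen("Eager")) == 0;
    }
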
diff --git a/tensorflow/contrib/lite/nnapi_delegate.cc b/tensorflow/contrib/lite/nnapi_delegate.cc
index 13325a8c7c..45c92a8671 100644
--- a/tensorflow/contrib/lite/nnapi_delegate.cc
+++ b/tensorflow/contrib/lite/nnapi_delegate.cc
@@ -24,20 +24,27 @@ limitations under the License.
#include "tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h"
#ifdef __ANDROID__
+#include <android/log.h>
#include <sys/system_properties.h>
#endif
namespace tflite {
void logError(const char* format, ...) {
- // TODO(mikie): use android logging, stderr is not captured for Java
- // applications
- va_list args;
- va_start(args, format);
- vfprintf(stderr, format, args);
- va_end(args);
+ // stderr is convenient for native tests, but is not captured for apps
+ va_list args_for_stderr;
+ va_start(args_for_stderr, format);
+ vfprintf(stderr, format, args_for_stderr);
+ va_end(args_for_stderr);
fprintf(stderr, "\n");
fflush(stderr);
+#ifdef __ANDROID__
+ // produce logcat output for general consumption
+ va_list args_for_log;
+ va_start(args_for_log, format);
+ __android_log_vprint(ANDROID_LOG_ERROR, "tflite", format, args_for_log);
+ va_end(args_for_log);
+#endif
}
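
Two va_lists are needed because a va_list may only be traversed once; after vfprintf() consumes args_for_stderr it cannot be reused for the logcat call. An equivalent alternative (a sketch, not the patch) uses va_copy:

    va_list args, copy;
    va_start(args, format);
    va_copy(copy, args);
    vfprintf(stderr, format, args);
    __android_log_vprint(ANDROID_LOG_ERROR, "tflite", format, copy);
    va_end(copy);
    va_end(args);
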
#define FATAL(...) \
@@ -564,8 +571,14 @@ TfLiteStatus AddOpsAndParams(
nn_op_type = ANEURALNETWORKS_L2_NORMALIZATION;
if (reinterpret_cast<TfLiteL2NormParams*>(node.builtin_data)
->activation != kTfLiteActNone) {
- FATAL(
+ logError(
"NNAPI does not support L2Normalization with fused activations");
+ return kTfLiteError;
+ }
+ if ((node.inputs->size > 0) &&
+ (interpreter->tensor(node.inputs->data[0])->dims->size != 4)) {
+ logError("NNAPI only supports input rank 4 for L2Normalization");
+ return kTfLiteError;
}
break;
case tflite::BuiltinOperator_HASHTABLE_LOOKUP:
diff --git a/tensorflow/contrib/lite/python/lite.py b/tensorflow/contrib/lite/python/lite.py
index 52ef43d71f..5ec52035ad 100644
--- a/tensorflow/contrib/lite/python/lite.py
+++ b/tensorflow/contrib/lite/python/lite.py
@@ -53,6 +53,7 @@ from tensorflow.core.framework import graph_pb2 as _graph_pb2
from tensorflow.python import keras as _keras
from tensorflow.python.client import session as _session
from tensorflow.python.framework import graph_util as _tf_graph_util
+from tensorflow.python.framework import ops as _ops
from tensorflow.python.framework.importer import import_graph_def as _import_graph_def
from tensorflow.python.saved_model import signature_constants as _signature_constants
from tensorflow.python.saved_model import tag_constants as _tag_constants
@@ -193,40 +194,41 @@ class TocoConverter(object):
The graph is not frozen.
input_arrays or output_arrays contains an invalid tensor name.
"""
- with _session.Session() as sess:
- # Read GraphDef from file.
- graph_def = _graph_pb2.GraphDef()
- with open(graph_def_file, "rb") as f:
- file_content = f.read()
- try:
- graph_def.ParseFromString(file_content)
- except (_text_format.ParseError, DecodeError):
+ with _ops.Graph().as_default():
+ with _session.Session() as sess:
+ # Read GraphDef from file.
+ graph_def = _graph_pb2.GraphDef()
+ with open(graph_def_file, "rb") as f:
+ file_content = f.read()
try:
- print("Ignore 'tcmalloc: large alloc' warnings.")
-
- if not isinstance(file_content, str):
- if PY3:
- file_content = file_content.decode('utf-8')
- else:
- file_content = file_content.encode('utf-8')
- _text_format.Merge(file_content, graph_def)
+ graph_def.ParseFromString(file_content)
except (_text_format.ParseError, DecodeError):
- raise ValueError(
- "Unable to parse input file '{}'.".format(graph_def_file))
- sess.graph.as_default()
- _import_graph_def(graph_def, name="")
-
- # Get input and output tensors.
- input_tensors = _get_tensors_from_tensor_names(sess.graph, input_arrays)
- output_tensors = _get_tensors_from_tensor_names(sess.graph, output_arrays)
- _set_tensor_shapes(input_tensors, input_shapes)
-
- # Check if graph is frozen.
- if not _is_frozen_graph(sess):
- raise ValueError("Please freeze the graph using freeze_graph.py.")
-
- # Create TocoConverter class.
- return cls(sess.graph_def, input_tensors, output_tensors)
+ try:
+ print("Ignore 'tcmalloc: large alloc' warnings.")
+
+ if not isinstance(file_content, str):
+ if PY3:
+ file_content = file_content.decode("utf-8")
+ else:
+ file_content = file_content.encode("utf-8")
+ _text_format.Merge(file_content, graph_def)
+ except (_text_format.ParseError, DecodeError):
+ raise ValueError(
+ "Unable to parse input file '{}'.".format(graph_def_file))
+ _import_graph_def(graph_def, name="")
+
+ # Get input and output tensors.
+ input_tensors = _get_tensors_from_tensor_names(sess.graph, input_arrays)
+ output_tensors = _get_tensors_from_tensor_names(sess.graph,
+ output_arrays)
+ _set_tensor_shapes(input_tensors, input_shapes)
+
+ # Check if graph is frozen.
+ if not _is_frozen_graph(sess):
+ raise ValueError("Please freeze the graph using freeze_graph.py.")
+
+ # Create TocoConverter class.
+ return cls(sess.graph_def, input_tensors, output_tensors)
@classmethod
def from_saved_model(cls,
diff --git a/tensorflow/contrib/lite/rpi_makefile.inc b/tensorflow/contrib/lite/rpi_makefile.inc
deleted file mode 100644
index 832ef5824b..0000000000
--- a/tensorflow/contrib/lite/rpi_makefile.inc
+++ /dev/null
@@ -1,33 +0,0 @@
-# Settings for Raspberry Pi.
-ifeq ($(TARGET), RPI)
- ifeq ($(TARGET_ARCH), armv7)
- CXXFLAGS += \
- -march=armv7-a \
- -mfpu=neon-vfpv4 \
- -funsafe-math-optimizations \
- -ftree-vectorize
-
- CCFLAGS += \
- -march=armv7-a \
- -mfpu=neon-vfpv4 \
- -funsafe-math-optimizations \
- -ftree-vectorize
-
- LDFLAGS := \
- -Wl,--no-export-dynamic \
- -Wl,--exclude-libs,ALL \
- -Wl,--gc-sections \
- -Wl,--as-needed
- endif
-
- LIBS := \
- -lstdc++ \
- -lpthread \
- -lm \
- -ldl
-
- OBJDIR := $(OBJDIR)rpi_$(TARGET_ARCH)/
- LIBDIR := $(LIBDIR)rpi_$(TARGET_ARCH)/
- BINDIR := $(BINDIR)rpi_$(TARGET_ARCH)/
- DEPDIR := $(DEPDIR)rpi_$(TARGET_ARCH)/
-endif
diff --git a/tensorflow/contrib/lite/testing/BUILD b/tensorflow/contrib/lite/testing/BUILD
index a788d41ba7..89912fd116 100644
--- a/tensorflow/contrib/lite/testing/BUILD
+++ b/tensorflow/contrib/lite/testing/BUILD
@@ -162,11 +162,12 @@ cc_library(
":test_runner",
"//tensorflow/contrib/lite:builtin_op_data",
"//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/delegates/eager:delegate",
"//tensorflow/contrib/lite/kernels:builtin_ops",
],
)
-cc_test(
+tf_cc_test(
name = "tflite_driver_test",
size = "small",
srcs = ["tflite_driver_test.cc"],
diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py
index 52ef0d5b86..9dd5c8ae44 100644
--- a/tensorflow/contrib/lite/testing/generate_examples.py
+++ b/tensorflow/contrib/lite/testing/generate_examples.py
@@ -1255,6 +1255,75 @@ def make_conv_tests(zip_path):
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+# Note: This is a regression test for a bug (b/112436267) where Toco
+# incorrectly fused weights when multiple Conv2D/FULLY_CONNECTED ops shared the
+# same constant weight tensor.
+def make_conv_with_shared_weights_tests(zip_path):
+ """Make a test where 2 Conv ops shared the same constant weight tensor."""
+
+ test_parameters = [{
+ "input_shape": [[1, 10, 10, 3]],
+ "filter_shape": [[3, 3]],
+ "strides": [[1, 1, 1, 1]],
+ "dilations": [[1, 1, 1, 1]],
+ "padding": ["SAME"],
+ "data_format": ["NHWC"],
+ "channel_multiplier": [1],
+ }]
+
+ def get_tensor_shapes(parameters):
+ input_shape = parameters["input_shape"]
+ filter_size = parameters["filter_shape"]
+ filter_shape = filter_size + [
+ input_shape[3], parameters["channel_multiplier"]
+ ]
+ return [input_shape, filter_shape]
+
+ def build_graph(parameters):
+ """Build a conv graph given `parameters`."""
+ input_shape, filter_shape = get_tensor_shapes(parameters)
+ input_tensor = tf.placeholder(
+ dtype=tf.float32, name="input", shape=input_shape)
+
+ # Construct a constant weights tensor which will be used by both Conv2D ops.
+ filter_tensor = tf.constant(
+ create_tensor_data(np.float32, filter_shape), dtype=tf.float32)
+ input_tensors = [input_tensor]
+
+ # Construct 2 Conv2D operations which use exactly the same input and
+ # weights.
+ result1 = tf.nn.conv2d(
+ input_tensor,
+ filter_tensor,
+ strides=parameters["strides"],
+ dilations=parameters["dilations"],
+ padding=parameters["padding"],
+ data_format=parameters["data_format"])
+ result2 = tf.nn.conv2d(
+ input_tensor,
+ filter_tensor,
+ strides=parameters["strides"],
+ dilations=parameters["dilations"],
+ padding=parameters["padding"],
+ data_format=parameters["data_format"])
+ # Add MUL ops after Conv2D ops. These MUL ops should be fused into the
+ # weights of Conv2D.
+ result1 = result1 * 2
+ result2 = result2 * 3
+ # Add the 2 results up.
+ out = result1 + result2
+ return input_tensors, [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ # Build a list of input values containing either 1 tensor (input) or
+ # 2 tensors (input, filter), depending on whether the filter is a constant
+ # or a variable input.
+ input_shape, unused_filter_shape = get_tensor_shapes(parameters)
+ values = [create_tensor_data(np.float32, input_shape)]
+ return values, sess.run(outputs, feed_dict=dict(zip(inputs, values)))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
def make_depthwiseconv_tests(zip_path):
"""Make a set of tests to do convolution."""
diff --git a/tensorflow/contrib/lite/testing/generate_testspec.cc b/tensorflow/contrib/lite/testing/generate_testspec.cc
index f29c188e6c..62cbeccd33 100644
--- a/tensorflow/contrib/lite/testing/generate_testspec.cc
+++ b/tensorflow/contrib/lite/testing/generate_testspec.cc
@@ -114,7 +114,13 @@ bool GenerateTestSpecFromTensorflowModel(
// different set.
std::vector<string> input_values =
GenerateInputValues(input_layer, input_layer_type, input_layer_shape);
- if (input_values.empty()) return false;
+ if (input_values.empty()) {
+ std::cerr << "Unable to generate input values for the TensorFlow model. "
+ "Make sure the correct values are defined for "
+ "input_layer, input_layer_type, and input_layer_shape."
+ << std::endl;
+ return false;
+ }
// Run TensorFlow.
for (int j = 0; j < input_values.size(); j++) {
diff --git a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
index e475f256c0..e67fee2a1c 100644
--- a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
+++ b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
@@ -33,13 +33,18 @@ namespace testing {
namespace {
bool FLAGS_ignore_known_bugs = true;
-// TODO(b/71769302) zip_files_dir should have a more accurate default, if
-// possible
-string* FLAGS_zip_file_path = new string("./");
+// As archive file names are test-specific, no default is possible.
+//
+// This test supports input as both zip and tar, as a stock Android image does
+// not have unzip but does have tar.
+string* FLAGS_zip_file_path = new string;
+string* FLAGS_tar_file_path = new string;
#ifndef __ANDROID__
string* FLAGS_unzip_binary_path = new string("/usr/bin/unzip");
+string* FLAGS_tar_binary_path = new string("/bin/tar");
#else
string* FLAGS_unzip_binary_path = new string("/system/bin/unzip");
+string* FLAGS_tar_binary_path = new string("/system/bin/tar");
#endif
bool FLAGS_use_nnapi = false;
bool FLAGS_ignore_unsupported_nnapi = false;
@@ -98,11 +103,11 @@ std::map<string, string> kBrokenTests = {
"77546240"},
};
-// Allows test data to be unzipped into a temporary directory and makes
+// Allows test data to be unarchived into a temporary directory and makes
// sure those temporary directories are removed later.
-class ZipEnvironment : public ::testing::Environment {
+class ArchiveEnvironment : public ::testing::Environment {
public:
- ~ZipEnvironment() override {}
+ ~ArchiveEnvironment() override {}
// Delete all temporary directories on teardown.
void TearDown() override {
@@ -114,15 +119,26 @@ class ZipEnvironment : public ::testing::Environment {
temporary_directories_.clear();
}
- // Unzip `zip` file into a new temporary directory `out_dir`.
- tensorflow::Status UnZip(const string& zip, string* out_dir) {
+ // Unarchive `archive` file into a new temporary directory `out_dir`.
+ tensorflow::Status UnArchive(const string& zip, const string& tar,
+ string* out_dir) {
string dir;
TF_CHECK_OK(MakeTemporaryDirectory(&dir));
tensorflow::SubProcess proc;
- string unzip_binary = *FLAGS_unzip_binary_path;
- TF_CHECK_OK(env->FileExists(unzip_binary));
- TF_CHECK_OK(env->FileExists(zip));
- proc.SetProgram(unzip_binary, {"unzip", "-d", dir, zip});
+ if (!zip.empty()) {
+ string unzip_binary = *FLAGS_unzip_binary_path;
+ TF_CHECK_OK(env->FileExists(unzip_binary));
+ TF_CHECK_OK(env->FileExists(zip));
+ proc.SetProgram(unzip_binary, {"unzip", "-d", dir, zip});
+ } else {
+ string tar_binary = *FLAGS_tar_binary_path;
+ TF_CHECK_OK(env->FileExists(tar_binary));
+ TF_CHECK_OK(env->FileExists(tar));
+ // The 'o' flag needs to be set explicitly on Android so that
+ // untarring works as non-root (otherwise tar tries to chown
+ // files, which fails).
+ proc.SetProgram(tar_binary, {"tar", "xfo", tar, "-C", dir});
+ }
proc.SetChannelAction(tensorflow::CHAN_STDOUT, tensorflow::ACTION_PIPE);
proc.SetChannelAction(tensorflow::CHAN_STDERR, tensorflow::ACTION_PIPE);
if (!proc.Start())
@@ -156,15 +172,15 @@ class ZipEnvironment : public ::testing::Environment {
std::vector<string> temporary_directories_;
};
-// Return the singleton zip_environment.
-ZipEnvironment* zip_environment() {
- static ZipEnvironment* env = new ZipEnvironment;
+// Return the singleton archive_environment.
+ArchiveEnvironment* archive_environment() {
+ static ArchiveEnvironment* env = new ArchiveEnvironment;
return env;
}
-// Read the manifest.txt out of the unarchived zip file. Specifically
+// Read the manifest.txt out of the unarchived archive file. Specifically
// `original_file` is the original zip file for error messages. `dir` is
-// the temporary directory where the zip file has been unarchived and
+// the temporary directory where the archive file has been unarchived and
// `test_paths` is the list of test prefixes that were in the manifest.
// Note, it is an error for a manifest to contain no tests.
tensorflow::Status ReadManifest(const string& original_file, const string& dir,
@@ -190,12 +206,22 @@ tensorflow::Status ReadManifest(const string& original_file, const string& dir,
return tensorflow::Status::OK();
}
-// Get a list of tests from a zip file `zip_file_name`.
-std::vector<string> UnarchiveZipAndFindTestNames(const string& zip_file) {
+// Get a list of tests from either a zip or a tar file.
+std::vector<string> UnarchiveAndFindTestNames(const string& zip_file,
+ const string& tar_file) {
+ if (zip_file.empty() && tar_file.empty()) {
+ TF_CHECK_OK(tensorflow::Status(tensorflow::error::UNKNOWN,
+ "Neither zip_file nor tar_file was given"));
+ }
string decompress_tmp_dir;
- TF_CHECK_OK(zip_environment()->UnZip(zip_file, &decompress_tmp_dir));
+ TF_CHECK_OK(archive_environment()->UnArchive(zip_file, tar_file,
+ &decompress_tmp_dir));
std::vector<string> stuff;
- TF_CHECK_OK(ReadManifest(zip_file, decompress_tmp_dir, &stuff));
+ if (!zip_file.empty()) {
+ TF_CHECK_OK(ReadManifest(zip_file, decompress_tmp_dir, &stuff));
+ } else {
+ TF_CHECK_OK(ReadManifest(tar_file, decompress_tmp_dir, &stuff));
+ }
return stuff;
}
@@ -223,8 +249,7 @@ TEST_P(OpsTest, RunZipTests) {
string message = test_driver.GetErrorMessage();
if (bug_number.empty()) {
if (FLAGS_use_nnapi && FLAGS_ignore_unsupported_nnapi && !result) {
- EXPECT_EQ(message, string("Failed to invoke NNAPI interpreter"))
- << message;
+ EXPECT_EQ(message, string("Failed to invoke interpreter")) << message;
} else {
EXPECT_TRUE(result) << message;
}
@@ -256,27 +281,34 @@ struct ZipPathParamName {
}
};
-INSTANTIATE_TEST_CASE_P(
- tests, OpsTest,
- ::testing::ValuesIn(UnarchiveZipAndFindTestNames(*FLAGS_zip_file_path)),
- ZipPathParamName());
+INSTANTIATE_TEST_CASE_P(tests, OpsTest,
+ ::testing::ValuesIn(UnarchiveAndFindTestNames(
+ *FLAGS_zip_file_path, *FLAGS_tar_file_path)),
+ ZipPathParamName());
} // namespace testing
} // namespace tflite
int main(int argc, char** argv) {
- ::testing::AddGlobalTestEnvironment(tflite::testing::zip_environment());
+ ::testing::AddGlobalTestEnvironment(tflite::testing::archive_environment());
std::vector<tensorflow::Flag> flags = {
tensorflow::Flag(
"ignore_known_bugs", &tflite::testing::FLAGS_ignore_known_bugs,
"If a particular model is affected by a known bug, the "
"corresponding test should expect the outputs to not match."),
- tensorflow::Flag("zip_file_path", tflite::testing::FLAGS_zip_file_path,
- "Required: Location of the test zip file."),
+ tensorflow::Flag(
+ "tar_file_path", tflite::testing::FLAGS_tar_file_path,
+ "Required (or zip_file_path): Location of the test tar file."),
+ tensorflow::Flag(
+ "zip_file_path", tflite::testing::FLAGS_zip_file_path,
+ "Required (or tar_file_path): Location of the test zip file."),
tensorflow::Flag("unzip_binary_path",
tflite::testing::FLAGS_unzip_binary_path,
- "Required: Location of a suitable unzip binary."),
+ "Location of a suitable unzip binary."),
+ tensorflow::Flag("tar_binary_path",
+ tflite::testing::FLAGS_tar_binary_path,
+ "Location of a suitable tar binary."),
tensorflow::Flag("use_nnapi", &tflite::testing::FLAGS_use_nnapi,
"Whether to enable the NNAPI delegate"),
tensorflow::Flag("ignore_unsupported_nnapi",
diff --git a/tensorflow/contrib/lite/testing/tf_driver.cc b/tensorflow/contrib/lite/testing/tf_driver.cc
index ec435ca60d..30381ba028 100644
--- a/tensorflow/contrib/lite/testing/tf_driver.cc
+++ b/tensorflow/contrib/lite/testing/tf_driver.cc
@@ -179,7 +179,9 @@ void TfDriver::Invoke() {
auto status = session_->Run({input_tensors_.begin(), input_tensors_.end()},
output_names_, {}, &output_tensors_);
if (!status.ok()) {
- Invalidate("Failed to run input data on graph");
+ Invalidate(
+ "Failed to run input data on graph. Make sure the correct value is "
+ "defined for the input and output arrays.");
}
}
diff --git a/tensorflow/contrib/lite/testing/tflite_diff_flags.h b/tensorflow/contrib/lite/testing/tflite_diff_flags.h
index 695c2a3de6..3874bc31d7 100644
--- a/tensorflow/contrib/lite/testing/tflite_diff_flags.h
+++ b/tensorflow/contrib/lite/testing/tflite_diff_flags.h
@@ -33,6 +33,7 @@ DiffOptions ParseTfliteDiffFlags(int* argc, char** argv) {
string input_layer_shape;
string output_layer;
int32_t num_runs_per_pass = 100;
+ string delegate;
} values;
std::vector<tensorflow::Flag> flags = {
@@ -42,18 +43,21 @@ DiffOptions ParseTfliteDiffFlags(int* argc, char** argv) {
"Path of tensorflow lite model."),
tensorflow::Flag("input_layer", &values.input_layer,
"Names of input tensors, separated by comma. Example: "
- "input_1,input_2"),
+ "input_1,input_2."),
tensorflow::Flag("input_layer_type", &values.input_layer_type,
"Data types of input tensors, separated by comma. "
- "Example: float,int"),
+ "Example: float,int."),
tensorflow::Flag(
"input_layer_shape", &values.input_layer_shape,
- "Shapes of input tensors, separated by colon. Example: 1,3,4,1:2"),
+ "Shapes of input tensors, separated by colon. Example: 1,3,4,1:2."),
tensorflow::Flag("output_layer", &values.output_layer,
- "Names of output tensors, separated by comma. Example "
- "output_1,output_2"),
+ "Names of output tensors, separated by comma. Example: "
+ "output_1,output_2."),
tensorflow::Flag("num_runs_per_pass", &values.num_runs_per_pass,
- "Number of full runs in each pass."),
+ "[optional] Number of full runs in each pass."),
+ tensorflow::Flag("delegate", &values.delegate,
+ "[optional] Delegate to use for executing ops. Must be "
+ "`{\"\", EAGER}`"),
};
bool no_inputs = *argc == 1;
@@ -61,6 +65,14 @@ DiffOptions ParseTfliteDiffFlags(int* argc, char** argv) {
if (!success || no_inputs || (*argc == 2 && !strcmp(argv[1], "--helpfull"))) {
fprintf(stderr, "%s", tensorflow::Flags::Usage(argv[0], flags).c_str());
return {};
+ } else if (values.tensorflow_model.empty() || values.tflite_model.empty() ||
+ values.input_layer.empty() || values.input_layer_type.empty() ||
+ values.input_layer_shape.empty() || values.output_layer.empty()) {
+ fprintf(stderr, "%s", tensorflow::Flags::Usage(argv[0], flags).c_str());
+ return {};
+ } else if (!(values.delegate == "" || values.delegate == "EAGER")) {
+ fprintf(stderr, "%s", tensorflow::Flags::Usage(argv[0], flags).c_str());
+ return {};
}
return {values.tensorflow_model,
@@ -69,7 +81,8 @@ DiffOptions ParseTfliteDiffFlags(int* argc, char** argv) {
Split<string>(values.input_layer_type, ","),
Split<string>(values.input_layer_shape, ":"),
Split<string>(values.output_layer, ","),
- values.num_runs_per_pass};
+ values.num_runs_per_pass,
+ values.delegate};
}
} // namespace testing
diff --git a/tensorflow/contrib/lite/testing/tflite_diff_util.cc b/tensorflow/contrib/lite/testing/tflite_diff_util.cc
index 19f34c0a51..c6ca796ac2 100644
--- a/tensorflow/contrib/lite/testing/tflite_diff_util.cc
+++ b/tensorflow/contrib/lite/testing/tflite_diff_util.cc
@@ -33,7 +33,7 @@ bool RunDiffTest(const DiffOptions& options, int num_invocations) {
options.input_layer_shape, options.output_layer)) {
return false;
}
- TfLiteDriver tflite_driver(/*use_nnapi=*/true);
+ TfLiteDriver tflite_driver(/*use_nnapi=*/true, options.delegate);
tflite_driver.LoadModel(options.tflite_model);
return tflite::testing::ParseAndRunTests(&tflite_stream, &tflite_driver);
}
diff --git a/tensorflow/contrib/lite/testing/tflite_diff_util.h b/tensorflow/contrib/lite/testing/tflite_diff_util.h
index 4ab2f230fd..f67992139f 100644
--- a/tensorflow/contrib/lite/testing/tflite_diff_util.h
+++ b/tensorflow/contrib/lite/testing/tflite_diff_util.h
@@ -44,6 +44,9 @@ struct DiffOptions {
// each of the passes. The first pass has a single inference, while the
// second pass does multiple inferences back to back.
int num_runs_per_pass;
+ // Delegate to use for executing ops. Must be either the empty string (no
+ // delegate) or `EAGER`.
+ string delegate;
};
// Run a single TensorFlow Lite diff test with the given options.
diff --git a/tensorflow/contrib/lite/testing/tflite_driver.cc b/tensorflow/contrib/lite/testing/tflite_driver.cc
index 4d08fb5458..71a98a3d56 100644
--- a/tensorflow/contrib/lite/testing/tflite_driver.cc
+++ b/tensorflow/contrib/lite/testing/tflite_driver.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include <iostream>
#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/delegates/eager/delegate.h"
#include "tensorflow/contrib/lite/testing/split.h"
namespace tflite {
@@ -135,7 +136,13 @@ class TfLiteDriver::Expectation {
size_t num_elements_;
};
-TfLiteDriver::TfLiteDriver(bool use_nnapi) : use_nnapi_(use_nnapi) {}
+TfLiteDriver::TfLiteDriver(bool use_nnapi, const string& delegate_name)
+ : use_nnapi_(use_nnapi) {
+ if (delegate_name == "EAGER") {
+ delegate_.reset(new EagerDelegate());
+ }
+}
+
TfLiteDriver::~TfLiteDriver() {}
void TfLiteDriver::AllocateTensors() {
@@ -165,6 +172,13 @@ void TfLiteDriver::LoadModel(const string& bin_file_path) {
}
interpreter_->UseNNAPI(use_nnapi_);
+ if (delegate_) {
+ if (delegate_->Apply(interpreter_.get()) != kTfLiteOk) {
+ Invalidate("Unable to the build graph using the delegate");
+ return;
+ }
+ }
+
must_allocate_tensors_ = true;
}
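
A hypothetical driver set-up exercising the new parameter (model path is illustrative):

    tflite::testing::TfLiteDriver driver(/*use_nnapi=*/false, "EAGER");
    driver.LoadModel("/tmp/model_with_tf_ops.tflite");
    driver.AllocateTensors();
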
diff --git a/tensorflow/contrib/lite/testing/tflite_driver.h b/tensorflow/contrib/lite/testing/tflite_driver.h
index 5493ba3631..aed35f877d 100644
--- a/tensorflow/contrib/lite/testing/tflite_driver.h
+++ b/tensorflow/contrib/lite/testing/tflite_driver.h
@@ -17,6 +17,7 @@ limitations under the License.
#include <map>
+#include "tensorflow/contrib/lite/delegates/eager/delegate.h"
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/model.h"
@@ -28,7 +29,7 @@ namespace testing {
// A test runner that feeds inputs into TF Lite and verifies its outputs.
class TfLiteDriver : public TestRunner {
public:
- explicit TfLiteDriver(bool use_nnapi);
+ explicit TfLiteDriver(bool use_nnapi, const string& delegate = "");
~TfLiteDriver() override;
void LoadModel(const string& bin_file_path) override;
@@ -52,6 +53,7 @@ class TfLiteDriver : public TestRunner {
class Expectation;
+ std::unique_ptr<EagerDelegate> delegate_;
bool use_nnapi_ = false;
std::unique_ptr<FlatBufferModel> model_;
std::unique_ptr<Interpreter> interpreter_;
diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD
index aa4a4d8854..02d0890a7a 100644
--- a/tensorflow/contrib/lite/toco/BUILD
+++ b/tensorflow/contrib/lite/toco/BUILD
@@ -242,9 +242,11 @@ cc_library(
"graph_transformations/resolve_constant_random_uniform.cc",
"graph_transformations/resolve_constant_range.cc",
"graph_transformations/resolve_constant_reshape.cc",
+ "graph_transformations/resolve_constant_select.cc",
"graph_transformations/resolve_constant_shape_or_rank.cc",
"graph_transformations/resolve_constant_slice.cc",
"graph_transformations/resolve_constant_strided_slice.cc",
+ "graph_transformations/resolve_constant_tile.cc",
"graph_transformations/resolve_constant_transpose.cc",
"graph_transformations/resolve_constant_unary.cc",
"graph_transformations/resolve_fake_quant_args_from_vars.cc",
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc b/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc
index 76c6be00d4..b324631579 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc
@@ -274,8 +274,14 @@ bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) {
return false;
}
- const auto& weights = model->GetArray(preceding_op->inputs[1]);
- const auto& bias = model->GetArray(preceding_op->inputs[2]);
+ const auto& weights_name = preceding_op->inputs[1];
+ const auto& bias_name = preceding_op->inputs[2];
+ const auto& weights = model->GetArray(weights_name);
+ const auto& bias = model->GetArray(bias_name);
+ const int count_ops_consuming_bias = CountOpsWithInput(*model, bias_name);
+ const int count_ops_consuming_weights =
+ CountOpsWithInput(*model, weights_name);
+
if (binary_op->type == OperatorType::kAdd ||
binary_op->type == OperatorType::kSub) {
if (!bias.buffer) {
@@ -285,6 +291,13 @@ bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) {
LogName(*binary_op), LogName(*preceding_op));
return false;
}
+ if (count_ops_consuming_bias > 1) {
+ AddMessageF(
+ "Not fusing %s because the bias of the preceding %s is consumed by "
+ "another op",
+ LogName(*binary_op), LogName(*preceding_op));
+ return false;
+ }
} else {
if (!weights.buffer || !bias.buffer) {
AddMessageF(
@@ -293,6 +306,13 @@ bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) {
LogName(*binary_op), LogName(*preceding_op));
return false;
}
+ if (count_ops_consuming_weights > 1 || count_ops_consuming_bias > 1) {
+ AddMessageF(
+ "Not fusing %s because the weights or bias of the preceding %s is "
+ "consumed by another op",
+ LogName(*binary_op), LogName(*preceding_op));
+ return false;
+ }
}
int count_ops_consuming_output =
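
To see why the new consumer counts matter, consider the shared-weights case from the regression test added in generate_examples.py:

    // With a weight array W consumed by two ops,
    //   y1 = Conv(x, W) * 2    and    y2 = Conv(x, W) * 3,
    // fusing the "* 2" into W would rewrite the shared buffer in place and
    // silently change y2 as well; the count_ops_consuming_* guards prevent it.
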
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
index 8d9a4c4700..99f4a7d8f6 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
+++ b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
@@ -190,6 +190,8 @@ DECLARE_GRAPH_TRANSFORMATION(ResolveConstantSlice)
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantStridedSlice)
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantFill)
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantGather)
+DECLARE_GRAPH_TRANSFORMATION(ResolveConstantSelect)
+DECLARE_GRAPH_TRANSFORMATION(ResolveConstantTile)
DECLARE_GRAPH_TRANSFORMATION(ResolveMultiplyByZero)
DECLARE_GRAPH_TRANSFORMATION(Dequantize)
DECLARE_GRAPH_TRANSFORMATION(UnpartitionEmbeddingLookup)
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc b/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc
index d26c3b2878..502de88f7c 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc
@@ -274,6 +274,19 @@ bool PropagateMinMaxAmongArrays(Model* model,
return changed;
}
+bool HardcodeMinMaxForReshape(Model* model, Operator* op) {
+ Array& input = model->GetArray(op->inputs[0]);
+ Array& output = model->GetArray(op->outputs[0]);
+
+ // If minmax is present on both input and output, or on neither, do nothing.
+ if ((!input.minmax && !output.minmax) || (input.minmax && output.minmax)) {
+ return false;
+ }
+
+ // Otherwise propagate minmax info between the input and output arrays.
+ return PropagateMinMaxAmongArrays(model, {op->inputs[0], op->outputs[0]});
+}
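
In effect the hook propagates only when exactly one side carries minmax information:

    // (input.minmax, output.minmax) -> action:
    //   (unset, unset) -> nothing to propagate, return false
    //   (set,   set)   -> both already known, return false
    //   (set, unset) / (unset, set) -> copy the known range across
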
+
bool HardcodeMinMaxForLstmCell(Model* model, Operator* op) {
CHECK_EQ(op->inputs.size(), LstmCellOperator::NUM_INPUTS);
CHECK_EQ(op->outputs.size(), LstmCellOperator::NUM_OUTPUTS);
@@ -370,7 +383,6 @@ bool HardcodeMinMax::Run(Model* model, std::size_t op_index) {
case OperatorType::kSlice:
case OperatorType::kStridedSlice:
case OperatorType::kSqueeze:
- case OperatorType::kReshape:
case OperatorType::kExpandDims:
case OperatorType::kPad:
case OperatorType::kGather:
@@ -416,6 +428,10 @@ bool HardcodeMinMax::Run(Model* model, std::size_t op_index) {
changed = HardcodeMinMaxForLstmCell(model, op);
break;
+ case OperatorType::kReshape:
+ changed = HardcodeMinMaxForReshape(model, op);
+ break;
+
default:
break;
}
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc
index 9f5d8b9450..fc49fbda59 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc
@@ -48,20 +48,26 @@ void RerouteEdges(const string& from_array, const string& to_array,
} // namespace
bool RemoveTrivialPassthroughOp(GraphTransformation* transformation,
- Model* model, std::size_t op_index) {
+ Model* model, std::size_t op_index,
+ int input_index) {
const auto passthru_it = model->operators.begin() + op_index;
auto* passthru_op = passthru_it->get();
CHECK_EQ(passthru_op->outputs.size(), 1);
CHECK_GE(passthru_op->inputs.size(), 1);
- int count_nonconstant_input_arrays = 0;
- // We call 'main input' the unique nonconstant input array if there is one,
- // or else the 0-th input.
+
int main_input_array_index = 0;
- for (int i = 0; i < passthru_op->inputs.size(); i++) {
- if (!model->GetArray(passthru_op->inputs[i]).buffer) {
- count_nonconstant_input_arrays++;
- if (count_nonconstant_input_arrays == 1) {
- main_input_array_index = i;
+ if (input_index != -1) {
+ main_input_array_index = input_index;
+ } else {
+ // We call 'main input' the unique nonconstant input array if there is one,
+ // or else the 0-th input.
+ int count_nonconstant_input_arrays = 0;
+ for (int i = 0; i < passthru_op->inputs.size(); i++) {
+ if (!model->GetArray(passthru_op->inputs[i]).buffer) {
+ count_nonconstant_input_arrays++;
+ if (count_nonconstant_input_arrays == 1) {
+ main_input_array_index = i;
+ }
}
}
}
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h
index 9d448c3ee9..663704e5ac 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h
+++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h
@@ -50,7 +50,8 @@ namespace toco {
// and then discards it and returns true, or, if it's not trivial (if neither
// the input nor the output may be discarded), returns false.
bool RemoveTrivialPassthroughOp(GraphTransformation* transformation,
- Model* model, std::size_t op_index);
+ Model* model, std::size_t op_index,
+ int input_index = -1);
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_select.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_select.cc
new file mode 100644
index 0000000000..e880a3f44d
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_select.cc
@@ -0,0 +1,78 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+// Resolves a constant Select operation.
+//
+// This implementation handles only conditions that are uniformly true or
+// uniformly false across all elements. It could be enhanced to operate
+// per-element, possibly by producing a Mul op.
+bool ResolveConstantSelect::Run(Model* model, std::size_t op_index) {
+ auto it = model->operators.begin() + op_index;
+ const auto* base_op = it->get();
+ if (base_op->type != OperatorType::kSelect) {
+ return false;
+ }
+ const auto* op = static_cast<const SelectOperator*>(base_op);
+
+ CHECK_GE(op->inputs.size(), 3);
+ CHECK_EQ(op->outputs.size(), 1);
+ auto& output_array = model->GetArray(op->outputs[0]);
+ if (output_array.data_type == ArrayDataType::kNone) {
+ // Yield until the output type has been set by PropagateArrayDataTypes.
+ return false;
+ }
+ if (!output_array.has_shape()) {
+ // Yield until the output shape has been set by PropagateFixedShapes.
+ return false;
+ }
+
+ // We require the cond input to be constant.
+ if (!IsConstantParameterArray(*model, op->inputs[0])) {
+ return false;
+ }
+ const Array& cond_array = model->GetArray(op->inputs[0]);
+ CHECK(cond_array.data_type == ArrayDataType::kBool)
+ << "Only bool conditions are supported";
+ const auto& cond_data = cond_array.GetBuffer<ArrayDataType::kBool>().data;
+ if (cond_data.empty()) {
+ return false;
+ }
+
+ // Check if the condition is the same for all elements.
+ bool cond_value = cond_data[0];
+ for (size_t i = 1; i < cond_data.size(); ++i) {
+ if (cond_data[i] != cond_value) {
+ AddMessageF(
+ "Cannot resolve %s as constant; cond_array has differing "
+ "per-element values",
+ LogName(*op));
+ return false;
+ }
+ }
+
+ // Pass-through the selected input.
+ return RemoveTrivialPassthroughOp(this, model, op_index, cond_value ? 1 : 2);
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_tile.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_tile.cc
new file mode 100644
index 0000000000..0b0d070714
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_tile.cc
@@ -0,0 +1,173 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+namespace {
+
+// NOTE: the Tile implementation here is taken from tflite's Tile kernel.
+
+template <typename T>
+void CopyMultipleTimes(const T* in_data, int32_t in_size, int32_t multiplier,
+ T* out_data) {
+ for (int i = 0; i < multiplier; ++i) {
+ const T* in_end = in_data + in_size;
+ T* new_out_data = std::copy(in_data, in_end, out_data);
+ in_data = out_data;
+ out_data = new_out_data;
+ }
+}
+
+template <typename T, typename M>
+std::pair<int, int> TileOneDimension(const Shape& in_dimensions,
+ const T* in_data, const M* multipliers,
+ T* out_data, int dimension) {
+ const int dimension_size = in_dimensions.dims(dimension);
+ if (dimension == in_dimensions.dimensions_count() - 1) {
+ CopyMultipleTimes(in_data, dimension_size, multipliers[dimension],
+ out_data);
+ return std::make_pair(
+ dimension_size,
+ dimension_size * static_cast<int>(multipliers[dimension]));
+ }
+ int total_stride_size = 0, total_tiled_stride_size = 0;
+ const T* copy_from_data = in_data;
+ T* copy_to_data = out_data;
+ for (int i = 0; i < dimension_size; ++i) {
+ int stride_size = 0, tiled_stride_size = 0;
+ std::tie(stride_size, tiled_stride_size) =
+ TileOneDimension(in_dimensions, copy_from_data, multipliers,
+ copy_to_data, dimension + 1);
+ copy_from_data += stride_size;
+ copy_to_data += tiled_stride_size;
+ total_stride_size += stride_size;
+ total_tiled_stride_size += tiled_stride_size;
+ }
+ CopyMultipleTimes(out_data, total_tiled_stride_size,
+ multipliers[dimension] - 1,
+ out_data + total_tiled_stride_size);
+ return std::make_pair(total_stride_size,
+ total_tiled_stride_size * multipliers[dimension]);
+}
+
+template <ArrayDataType Type>
+inline void Tile(const Array& input_array, const Array& multiples_array,
+ Array* output_array) {
+ // Allocate output storage.
+ auto& output_data = output_array->GetMutableBuffer<Type>().data;
+ output_data.resize(RequiredBufferSizeForShape(output_array->shape()));
+
+ switch (multiples_array.data_type) {
+ case ArrayDataType::kInt32:
+ TileOneDimension(
+ input_array.shape(), input_array.GetBuffer<Type>().data.data(),
+ multiples_array.GetBuffer<ArrayDataType::kInt32>().data.data(),
+ output_array->GetMutableBuffer<Type>().data.data(), 0);
+ break;
+ case ArrayDataType::kInt64:
+ TileOneDimension(
+ input_array.shape(), input_array.GetBuffer<Type>().data.data(),
+ multiples_array.GetBuffer<ArrayDataType::kInt64>().data.data(),
+ output_array->GetMutableBuffer<Type>().data.data(), 0);
+ break;
+ default:
+ CHECK(false);
+ break;
+ }
+}
+
+} // namespace
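
To make the recursion concrete, a worked sketch (not part of the patch): tiling a 2x2 input with multiples {2, 3}:

    // Innermost dimension first: each row {a, b} is copied 3 times, then the
    // resulting 2x6 block gets multipliers[0] - 1 = 1 extra copy, for 4x6 out.
    const int32_t in[] = {1, 2, 3, 4};        // rows {1, 2} and {3, 4}
    const int32_t multipliers[] = {2, 3};
    int32_t out[24];
    // Given a Shape with dims {2, 2}:
    //   TileOneDimension(shape, in, multipliers, out, 0);
    // out == {1,2,1,2,1,2, 3,4,3,4,3,4, 1,2,1,2,1,2, 3,4,3,4,3,4}
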
+
+// Resolves a constant Tile operation.
+bool ResolveConstantTile::Run(Model* model, std::size_t op_index) {
+ auto it = model->operators.begin() + op_index;
+ const auto* base_op = it->get();
+ if (base_op->type != OperatorType::kTile) {
+ return false;
+ }
+ const auto* op = static_cast<const TensorFlowTileOperator*>(base_op);
+
+ CHECK_GE(op->inputs.size(), 2);
+ CHECK_EQ(op->outputs.size(), 1);
+ auto& output_array = model->GetArray(op->outputs[0]);
+ if (output_array.data_type == ArrayDataType::kNone) {
+ // Yield until the output type has been set by PropagateArrayDataTypes.
+ return false;
+ }
+ if (!output_array.has_shape()) {
+ // Yield until the output shape has been set by PropagateFixedShapes.
+ return false;
+ }
+
+ // We require constant inputs.
+ if (!IsConstantParameterArray(*model, op->inputs[0]) ||
+ !IsConstantParameterArray(*model, op->inputs[1])) {
+ return false;
+ }
+ const Array& input_array = model->GetArray(op->inputs[0]);
+ const Array& multiples_array = model->GetArray(op->inputs[1]);
+ CHECK(multiples_array.data_type == ArrayDataType::kInt32 ||
+ multiples_array.data_type == ArrayDataType::kInt64)
+ << "Only int32/int64 indices are supported";
+
+ // Copy min/max info if present. The output merely repeats the input
+ // values, so the range is unchanged, and we want to ensure the quantization
+ // params stay the same.
+ if (input_array.minmax) {
+ const auto& input_minmax = input_array.GetMinMax();
+ auto& output_minmax = output_array.GetOrCreateMinMax();
+ output_minmax.min = input_minmax.min;
+ output_minmax.max = input_minmax.max;
+ }
+
+ CHECK(!output_array.buffer);
+ switch (output_array.data_type) {
+ case ArrayDataType::kFloat:
+ Tile<ArrayDataType::kFloat>(input_array, multiples_array, &output_array);
+ break;
+ case ArrayDataType::kUint8:
+ Tile<ArrayDataType::kUint8>(input_array, multiples_array, &output_array);
+ break;
+ case ArrayDataType::kInt16:
+ Tile<ArrayDataType::kInt16>(input_array, multiples_array, &output_array);
+ break;
+ case ArrayDataType::kInt32:
+ Tile<ArrayDataType::kInt32>(input_array, multiples_array, &output_array);
+ break;
+ case ArrayDataType::kInt64:
+ Tile<ArrayDataType::kInt64>(input_array, multiples_array, &output_array);
+ break;
+ default:
+ LOG(FATAL) << "Unsupported data type given to Tile op with output \""
+ << op->outputs[0] << "\"";
+ break;
+ }
+
+ // Erase input arrays if no longer used after we remove the op.
+ DeleteArrayIfUsedOnce(op->inputs[0], model);
+ DeleteArrayIfUsedOnce(op->inputs[1], model);
+
+ // Erase the operator.
+ model->operators.erase(it);
+ return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc
index fe3882c28d..475415e481 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc
@@ -246,8 +246,8 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) {
}
output_float_data[i] = outval;
}
- } else if (unary_op->type == OperatorType::kRelu6 &&
- unary_op->type == OperatorType::kRelu1 &&
+ } else if (unary_op->type == OperatorType::kRelu6 ||
+ unary_op->type == OperatorType::kRelu1 ||
unary_op->type == OperatorType::kRelu) {
for (size_t i = 0; i < output_buffer_size; ++i) {
const float value = (*input_float_data)[i];
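
The one-character change above is the whole fix: with &&, an operator would have to be kRelu6, kRelu1, and kRelu at once, so the branch was unreachable and constant relu inputs were never folded. A small Python sketch of the clamping the branch performs, assuming the usual relu/relu1/relu6 definitions:

def fold_relu(op_type, values):
    # Per-element clamping applied once the branch is reachable.
    clamp = {
        "Relu": lambda v: max(v, 0.0),
        "Relu1": lambda v: min(max(v, -1.0), 1.0),
        "Relu6": lambda v: min(max(v, 0.0), 6.0),
    }[op_type]
    return [clamp(v) for v in values]

assert fold_relu("Relu6", [-2.0, 3.0, 9.0]) == [0.0, 3.0, 6.0]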
diff --git a/tensorflow/contrib/lite/toco/toco_port.cc b/tensorflow/contrib/lite/toco/toco_port.cc
index 14168fa33f..204c0d101e 100644
--- a/tensorflow/contrib/lite/toco/toco_port.cc
+++ b/tensorflow/contrib/lite/toco/toco_port.cc
@@ -138,13 +138,15 @@ namespace port {
#define close _close
#define open _open
#define read _read
-#define O_RDONLY _O_RDONLY
-#define O_CREAT _O_CREAT
-#define O_WRONLY _O_WRONLY
-// Windows does not support the same set of file permissions as other platforms.
+// Windows does not support the same set of file permissions as other platforms,
+// and also requires an explicit flag for binary file read/write support.
constexpr int kFileCreateMode = _S_IREAD | _S_IWRITE;
+constexpr int kFileReadFlags = _O_RDONLY | _O_BINARY;
+constexpr int kFileWriteFlags = _O_WRONLY | _O_BINARY | _O_CREAT;
#else
constexpr int kFileCreateMode = 0664;
+constexpr int kFileReadFlags = O_RDONLY;
+constexpr int kFileWriteFlags = O_CREAT | O_WRONLY;
#endif // _WIN32
static bool port_initialized = false;
@@ -197,7 +199,7 @@ tensorflow::Status GetContents(const string& path, string* output,
const file::Options& options) {
output->clear();
- int fd = open(path.c_str(), O_RDONLY);
+ int fd = open(path.c_str(), kFileReadFlags);
if (fd == -1) {
return tensorflow::errors::NotFound("can't open() for read");
}
@@ -226,7 +228,7 @@ tensorflow::Status GetContents(const string& path, string* output,
tensorflow::Status SetContents(const string& filename, const string& contents,
const file::Options& options) {
- int fd = open(filename.c_str(), O_WRONLY | O_CREAT, kFileCreateMode);
+ int fd = open(filename.c_str(), kFileWriteFlags, kFileCreateMode);
if (fd == -1) {
return tensorflow::errors::Internal("can't open() for write");
}
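
The _O_BINARY flag is the substance of this hunk: without it, Windows opens files in text mode, translating CRLF sequences and treating byte 0x1A as end-of-file, which silently corrupts binary payloads such as serialized models. A Python analogue of the same distinction, as a sketch (the path is a placeholder):

# Text mode ("r") performs newline translation on Windows, the analogue of
# omitting _O_BINARY; binary mode ("rb") returns the bytes verbatim.
with open("model.tflite", "rb") as f:  # always binary mode for model files
    raw = f.read()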
diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc
index fcd3cbab07..34130a02b0 100644
--- a/tensorflow/contrib/lite/toco/toco_tooling.cc
+++ b/tensorflow/contrib/lite/toco/toco_tooling.cc
@@ -90,8 +90,10 @@ void MakeGeneralGraphTransformationsSet(
transformations->Add(new ResolveConstantRandomUniform);
transformations->Add(new ResolveConstantRange);
transformations->Add(new ResolveConstantReshape);
+ transformations->Add(new ResolveConstantSelect);
transformations->Add(new ResolveConstantSlice);
transformations->Add(new ResolveConstantStridedSlice);
+ transformations->Add(new ResolveConstantTile);
transformations->Add(new ResolveConstantTranspose);
transformations->Add(new ResolveConstantUnaryOperator);
transformations->Add(new ResolveTensorFlowMerge);
diff --git a/tensorflow/contrib/lite/Makefile b/tensorflow/contrib/lite/tools/make/Makefile
index 9cc8f10b42..e30cc1d70e 100644
--- a/tensorflow/contrib/lite/Makefile
+++ b/tensorflow/contrib/lite/tools/make/Makefile
@@ -6,120 +6,74 @@ endif
# Try to figure out the host system
HOST_OS :=
ifeq ($(OS),Windows_NT)
- HOST_OS = WINDOWS
+ HOST_OS = windows
else
UNAME_S := $(shell uname -s)
ifeq ($(UNAME_S),Linux)
- HOST_OS := LINUX
+ HOST_OS := linux
endif
ifeq ($(UNAME_S),Darwin)
- HOST_OS := OSX
+ HOST_OS := osx
endif
endif
HOST_ARCH := $(shell if [[ $(shell uname -m) =~ i[345678]86 ]]; then echo x86_32; else echo $(shell uname -m); fi)
-# Self-hosting
-TARGET_ARCH := ${HOST_ARCH}
+# Override these on the make command line to target a specific architecture. For example:
+# make -f tensorflow/contrib/lite/tools/make/Makefile TARGET=rpi TARGET_ARCH=armv7l
+TARGET := $(HOST_OS)
+TARGET_ARCH := $(HOST_ARCH)
-# Cross compiling
-ifeq ($(CROSS),rpi)
- TARGET_ARCH := armv7l
- TARGET_TOOLCHAIN_PREFIX := arm-linux-gnueabihf-
-endif
-
-ifeq ($(CROSS),riscv)
- TARGET_ARCH := riscv
- TARGET_TOOLCHAIN_PREFIX := riscv32-unknown-elf-
-endif
-ifeq ($(CROSS),stm32f7)
- TARGET_ARCH := armf7
- TARGET_TOOLCHAIN_PREFIX := arm-none-eabi-
-endif
-ifeq ($(CROSS),stm32f1)
- TARGET_ARCH := armm1
- TARGET_TOOLCHAIN_PREFIX := arm-none-eabi-
-endif
-
-# Where compiled objects are stored.
-OBJDIR := $(MAKEFILE_DIR)/gen/obj/
-BINDIR := $(MAKEFILE_DIR)/gen/bin/
-LIBDIR := $(MAKEFILE_DIR)/gen/lib/
-GENDIR := $(MAKEFILE_DIR)/gen/obj/
-
-LIBS :=
-ifeq ($(TARGET_ARCH),x86_64)
- CXXFLAGS += -fPIC -DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK -pthread # -msse4.2
-endif
-
-ifeq ($(TARGET_ARCH),armv7l)
- CXXFLAGS += -mfpu=neon -pthread -fPIC
- LIBS += -ldl
-endif
-
-ifeq ($(TARGET_ARCH),riscv)
-# CXXFLAGS += -march=gap8
- CXXFLAGS += -DTFLITE_MCU
- LIBS += -ldl
- BUILD_TYPE := micro
-endif
-
-ifeq ($(TARGET_ARCH),armf7)
- CXXFLAGS += -DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK -DTFLITE_MCU
- CXXFLAGS += -fno-rtti -fmessage-length=0 -fno-exceptions -fno-builtin -ffunction-sections -fdata-sections
- CXXFLAGS += -funsigned-char -MMD
- CXXFLAGS += -mcpu=cortex-m7 -mthumb -mfpu=fpv5-sp-d16 -mfloat-abi=softfp
- CXXFLAGS += '-std=gnu++11' '-fno-rtti' '-Wvla' '-c' '-Wall' '-Wextra' '-Wno-unused-parameter' '-Wno-missing-field-initializers' '-fmessage-length=0' '-fno-exceptions' '-fno-builtin' '-ffunction-sections' '-fdata-sections' '-funsigned-char' '-MMD' '-fno-delete-null-pointer-checks' '-fomit-frame-pointer' '-Os'
- LIBS += -ldl
- BUILD_TYPE := micro
-endif
-ifeq ($(TARGET_ARCH),armm1)
- CXXFLAGS += -DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK -mcpu=cortex-m1 -mthumb -DTFLITE_MCU
- CXXFLAGS += -fno-rtti -fmessage-length=0 -fno-exceptions -fno-builtin -ffunction-sections -fdata-sections
- CXXFLAGS += -funsigned-char -MMD
- LIBS += -ldl
-endif
+# These are the default libraries needed, but they can be added to or
+# overridden by the platform-specific settings in target makefiles.
+LIBS := \
+-lstdc++ \
+-lpthread \
+-lm \
+-lz
-# Settings for the host compiler.
-CXX := $(CC_PREFIX) ${TARGET_TOOLCHAIN_PREFIX}g++
-CXXFLAGS += -O3 -DNDEBUG
+# There are no rules for compiling objects for the host system (since we don't
+# generate things like the protobuf compiler that require that), so all of
+# these settings are for the target compiler.
+CXXFLAGS := -O3 -DNDEBUG
CCFLAGS := ${CXXFLAGS}
CXXFLAGS += --std=c++11
-CC := $(CC_PREFIX) ${TARGET_TOOLCHAIN_PREFIX}gcc
-AR := $(CC_PREFIX) ${TARGET_TOOLCHAIN_PREFIX}ar
CFLAGS :=
-LDOPTS :=
-LDOPTS += -L/usr/local/lib
+LDOPTS := -L/usr/local/lib
ARFLAGS := -r
+TARGET_TOOLCHAIN_PREFIX :=
+CC_PREFIX :=
+
+# These target-specific makefiles should modify or replace options like
+# CXXFLAGS or LIBS to work for a specific targeted architecture. All logic
+# based on platforms or architectures should happen within these files, to
+# keep this main makefile focused on the sources and dependencies.
+include $(wildcard $(MAKEFILE_DIR)/targets/*_makefile.inc)
+
+# Where compiled objects are stored.
+GENDIR := $(MAKEFILE_DIR)/gen/$(TARGET)_$(TARGET_ARCH)/
+OBJDIR := $(GENDIR)obj/
+BINDIR := $(GENDIR)bin/
+LIBDIR := $(GENDIR)lib/
INCLUDES := \
-I. \
--I$(MAKEFILE_DIR)/../../../ \
--I$(MAKEFILE_DIR)/../../../../ \
+-I$(MAKEFILE_DIR)/../../../../../ \
+-I$(MAKEFILE_DIR)/../../../../../../ \
-I$(MAKEFILE_DIR)/downloads/ \
-I$(MAKEFILE_DIR)/downloads/eigen \
-I$(MAKEFILE_DIR)/downloads/gemmlowp \
-I$(MAKEFILE_DIR)/downloads/neon_2_sse \
-I$(MAKEFILE_DIR)/downloads/farmhash/src \
-I$(MAKEFILE_DIR)/downloads/flatbuffers/include \
--I$(GENDIR)
+-I$(OBJDIR)
# This is at the end so any globally-installed frameworks like protobuf don't
# override local versions in the source tree.
INCLUDES += -I/usr/local/include
-LIBS += \
--lstdc++ \
--lpthread \
--lm \
--lz
-
-# If we're on Linux, also link in the dl library.
-ifeq ($(HOST_OS),LINUX)
- LIBS += -ldl
-endif
-
-include $(MAKEFILE_DIR)/ios_makefile.inc
-include $(MAKEFILE_DIR)/rpi_makefile.inc
+CXX := $(CC_PREFIX)${TARGET_TOOLCHAIN_PREFIX}g++
+CC := $(CC_PREFIX)${TARGET_TOOLCHAIN_PREFIX}gcc
+AR := $(CC_PREFIX)${TARGET_TOOLCHAIN_PREFIX}ar
# This library is the main target for this makefile. It will contain a minimal
# runtime that can be linked in to other programs.
@@ -163,8 +117,8 @@ $(wildcard tensorflow/contrib/lite/kernels/*.c) \
$(wildcard tensorflow/contrib/lite/kernels/internal/*.c) \
$(wildcard tensorflow/contrib/lite/kernels/internal/optimized/*.c) \
$(wildcard tensorflow/contrib/lite/kernels/internal/reference/*.c) \
-$(wildcard tensorflow/contrib/lite/downloads/farmhash/src/farmhash.cc) \
-$(wildcard tensorflow/contrib/lite/downloads/fft2d/fftsg.c)
+$(wildcard tensorflow/contrib/lite/tools/make/downloads/farmhash/src/farmhash.cc) \
+$(wildcard tensorflow/contrib/lite/tools/make/downloads/fft2d/fftsg.c)
endif
# Remove any duplicates.
CORE_CC_ALL_SRCS := $(sort $(CORE_CC_ALL_SRCS))
@@ -179,10 +133,6 @@ ifeq ($(BUILD_TYPE),micro)
CORE_CC_EXCLUDE_SRCS += \
tensorflow/contrib/lite/mmap_allocation.cc \
tensorflow/contrib/lite/nnapi_delegate.cc
-else
-CORE_CC_EXCLUDE_SRCS += \
-tensorflow/contrib/lite/mmap_allocation_disabled.cc \
-tensorflow/contrib/lite/nnapi_delegate_disabled.cc
endif
# Filter out all the excluded files.
TF_LITE_CC_SRCS := $(filter-out $(CORE_CC_EXCLUDE_SRCS), $(CORE_CC_ALL_SRCS))
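
A consequence of the per-target directory scheme above is that builds for different platforms no longer overwrite each other's objects. A short Python sketch of the layout GENDIR now encodes (the helper is illustrative):

def gen_dir(target, target_arch):
    # Mirrors GENDIR := $(MAKEFILE_DIR)/gen/$(TARGET)_$(TARGET_ARCH)/
    return "tensorflow/contrib/lite/tools/make/gen/%s_%s/" % (target, target_arch)

assert gen_dir("rpi", "armv7l").endswith("gen/rpi_armv7l/")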
diff --git a/tensorflow/contrib/lite/build_ios_universal_lib.sh b/tensorflow/contrib/lite/tools/make/build_ios_universal_lib.sh
index 31df43a175..fe056945a6 100755
--- a/tensorflow/contrib/lite/build_ios_universal_lib.sh
+++ b/tensorflow/contrib/lite/tools/make/build_ios_universal_lib.sh
@@ -17,23 +17,23 @@
set -e
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
-cd "$SCRIPT_DIR/../../.."
+cd "$SCRIPT_DIR/../../../../.."
# Builds the library for each supported architecture and packs them into a fat binary.
make_library() {
for arch in x86_64 armv7 armv7s arm64
do
- make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=${arch} \
- -j 8 \
- $SCRIPT_DIR/gen/lib/ios_${arch}/${1}
+ make -f tensorflow/contrib/lite/tools/make/Makefile TARGET=ios TARGET_ARCH=${arch} \
+ -j 8
done
+ mkdir -p tensorflow/contrib/lite/tools/make/gen/lib
lipo \
- tensorflow/contrib/lite/gen/lib/ios_x86_64/${1} \
- tensorflow/contrib/lite/gen/lib/ios_armv7/${1} \
- tensorflow/contrib/lite/gen/lib/ios_armv7s/${1} \
- tensorflow/contrib/lite/gen/lib/ios_arm64/${1} \
+ tensorflow/contrib/lite/tools/make/gen/ios_x86_64/lib/${1} \
+ tensorflow/contrib/lite/tools/make/gen/ios_armv7/lib/${1} \
+ tensorflow/contrib/lite/tools/make/gen/ios_armv7s/lib/${1} \
+ tensorflow/contrib/lite/tools/make/gen/ios_arm64/lib/${1} \
-create \
- -output tensorflow/contrib/lite/gen/lib/${1}
+ -output tensorflow/contrib/lite/tools/make/gen/lib/${1}
}
make_library libtensorflow-lite.a
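
The script now expects each architecture's build under its own gen/ios_<arch>/lib/ tree before lipo merges them. A Python sketch of the input list it assembles, using the paths from the diff:

archs = ["x86_64", "armv7", "armv7s", "arm64"]
inputs = [
    "tensorflow/contrib/lite/tools/make/gen/ios_%s/lib/libtensorflow-lite.a" % a
    for a in archs
]
# lipo -create folds these four thin archives into one fat library under
# tensorflow/contrib/lite/tools/make/gen/lib/.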
diff --git a/tensorflow/contrib/lite/build_rpi_lib.sh b/tensorflow/contrib/lite/tools/make/build_rpi_lib.sh
index 3824b16412..24ecd4356d 100755
--- a/tensorflow/contrib/lite/build_rpi_lib.sh
+++ b/tensorflow/contrib/lite/tools/make/build_rpi_lib.sh
@@ -17,6 +17,6 @@
set -e
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
-cd "$SCRIPT_DIR/../../.."
+cd "$SCRIPT_DIR/../../../../.."
-CC_PREFIX=arm-linux-gnueabihf- make -j 3 -f tensorflow/contrib/lite/Makefile TARGET=RPI TARGET_ARCH=armv7
+CC_PREFIX=arm-linux-gnueabihf- make -j 3 -f tensorflow/contrib/lite/tools/make/Makefile TARGET=rpi TARGET_ARCH=armv7l
diff --git a/tensorflow/contrib/lite/download_dependencies.sh b/tensorflow/contrib/lite/tools/make/download_dependencies.sh
index 8c7df474d5..29afa45133 100755
--- a/tensorflow/contrib/lite/download_dependencies.sh
+++ b/tensorflow/contrib/lite/tools/make/download_dependencies.sh
@@ -17,9 +17,9 @@
set -e
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
-cd "$SCRIPT_DIR/../../.."
+cd "$SCRIPT_DIR/../../../../.."
-DOWNLOADS_DIR=tensorflow/contrib/lite/downloads
+DOWNLOADS_DIR=tensorflow/contrib/lite/tools/make/downloads
BZL_FILE_PATH=tensorflow/workspace.bzl
# Ensure it is being run from repo root
diff --git a/tensorflow/contrib/lite/ios_makefile.inc b/tensorflow/contrib/lite/tools/make/targets/ios_makefile.inc
index 079320586f..7f36b8ecef 100644
--- a/tensorflow/contrib/lite/ios_makefile.inc
+++ b/tensorflow/contrib/lite/tools/make/targets/ios_makefile.inc
@@ -1,11 +1,11 @@
# Settings for iOS.
-ifeq ($(TARGET), IOS)
- BUILD_FOR_IOS_SIMULATOR := false
- ifeq ($(IOS_ARCH), x86_64)
- BUILD_FOR_IOS_SIMULATOR := true
+ifeq ($(TARGET), ios)
+ BUILD_FOR_IOS_SIMULATOR := false
+ ifeq ($(TARGET_ARCH), x86_64)
+ BUILD_FOR_IOS_SIMULATOR := true
endif
- ifeq ($(IOS_ARCH), i386)
- BUILD_FOR_IOS_SIMULATOR := true
+ ifeq ($(TARGET_ARCH), i386)
+ BUILD_FOR_IOS_SIMULATOR := true
endif
ifeq ($(BUILD_FOR_IOS_SIMULATOR), true)
IPHONEOS_PLATFORM := $(shell xcrun --sdk iphonesimulator \
@@ -18,8 +18,8 @@ ifeq ($(TARGET), IOS)
endif
IOS_SDK_VERSION := $(shell xcrun --sdk iphoneos --show-sdk-version)
MIN_SDK_VERSION := 9.0
- # Override IOS_ARCH with armv7, armv7s, arm64, i386, or x86_64.
- IOS_ARCH := x86_64
+ # Override TARGET_ARCH with armv7, armv7s, arm64, i386, or x86_64.
+ TARGET_ARCH := x86_64
CXXFLAGS += -miphoneos-version-min=$(MIN_SDK_VERSION) \
-DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK \
-DTFLITE_USE_APPLE_ACCELERATE_FOR_CONV \
@@ -29,21 +29,17 @@ ifeq ($(TARGET), IOS)
-fno-exceptions \
-isysroot \
${IPHONEOS_SYSROOT} \
- -arch $(IOS_ARCH) \
+ -arch $(TARGET_ARCH) \
-O3
CCFLAGS += -miphoneos-version-min=$(MIN_SDK_VERSION) \
-fembed-bitcode \
-mno-thumb \
-isysroot \
${IPHONEOS_SYSROOT} \
- -arch $(IOS_ARCH) \
+ -arch $(TARGET_ARCH) \
-O3
LDFLAGS := -fembed-bitcode \
-miphoneos-version-min=${MIN_SDK_VERSION} \
-framework Accelerate \
- -arch $(IOS_ARCH)
- OBJDIR := $(OBJDIR)ios_$(IOS_ARCH)/
- LIBDIR := $(LIBDIR)ios_$(IOS_ARCH)/
- BINDIR := $(BINDIR)ios_$(IOS_ARCH)/
- DEPDIR := $(DEPDIR)ios_$(IOS_ARCH)/
+ -arch $(TARGET_ARCH)
endif
diff --git a/tensorflow/contrib/lite/tools/make/targets/linux_makefile.inc b/tensorflow/contrib/lite/tools/make/targets/linux_makefile.inc
new file mode 100644
index 0000000000..86499da99e
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/make/targets/linux_makefile.inc
@@ -0,0 +1,10 @@
+# Settings for Linux.
+ifeq ($(TARGET), linux)
+ CXXFLAGS += \
+ -fPIC \
+ -DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK \
+ -pthread
+ # TODO(petewarden): In the future we may want to add architecture-specific
+ # flags like -msse4.2
+ LIBS += -ldl
+endif
diff --git a/tensorflow/contrib/lite/tools/make/targets/riscv_makefile.inc b/tensorflow/contrib/lite/tools/make/targets/riscv_makefile.inc
new file mode 100644
index 0000000000..1a82afec33
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/make/targets/riscv_makefile.inc
@@ -0,0 +1,10 @@
+# Settings for RISC-V platforms.
+ifeq ($(TARGET), riscv)
+ TARGET_ARCH := riscv
+ TARGET_TOOLCHAIN_PREFIX := riscv32-unknown-elf-
+
+ #CXXFLAGS += -march=gap8
+ CXXFLAGS += -DTFLITE_MCU
+ LIBS += -ldl
+ BUILD_TYPE := micro
+endif
diff --git a/tensorflow/contrib/lite/tools/make/targets/rpi_makefile.inc b/tensorflow/contrib/lite/tools/make/targets/rpi_makefile.inc
new file mode 100644
index 0000000000..1ad0c50237
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/make/targets/rpi_makefile.inc
@@ -0,0 +1,60 @@
+# Settings for Raspberry Pi.
+ifeq ($(TARGET),rpi)
+  # Default to the architecture used on the Pi Two/Three (ARMv7), but override
+  # this with TARGET_ARCH=armv6 to build for the Pi Zero or One.
+ TARGET_ARCH := armv7l
+ TARGET_TOOLCHAIN_PREFIX := arm-linux-gnueabihf-
+
+ ifeq ($(TARGET_ARCH), armv7l)
+ CXXFLAGS += \
+ -march=armv7-a \
+ -mfpu=neon-vfpv4 \
+ -funsafe-math-optimizations \
+ -ftree-vectorize \
+ -fPIC
+
+ CCFLAGS += \
+ -march=armv7-a \
+ -mfpu=neon-vfpv4 \
+ -funsafe-math-optimizations \
+ -ftree-vectorize \
+ -fPIC
+
+ LDFLAGS := \
+ -Wl,--no-export-dynamic \
+ -Wl,--exclude-libs,ALL \
+ -Wl,--gc-sections \
+ -Wl,--as-needed
+ endif
+
+  # TODO(petewarden): In the future, we'll want to use OpenBLAS as a faster
+  # alternative to Eigen on non-NEON ARM hardware like armv6.
+ ifeq ($(TARGET_ARCH), armv6)
+ CXXFLAGS += \
+ -march=armv6 \
+ -mfpu=vfp \
+ -funsafe-math-optimizations \
+ -ftree-vectorize \
+ -fPIC
+
+ CCFLAGS += \
+ -march=armv6 \
+ -mfpu=vfp \
+ -funsafe-math-optimizations \
+ -ftree-vectorize \
+ -fPIC
+
+ LDFLAGS := \
+ -Wl,--no-export-dynamic \
+ -Wl,--exclude-libs,ALL \
+ -Wl,--gc-sections \
+ -Wl,--as-needed
+ endif
+
+ LIBS := \
+ -lstdc++ \
+ -lpthread \
+ -lm \
+ -ldl
+
+endif
diff --git a/tensorflow/contrib/lite/tools/make/targets/stm32f1_makefile.inc b/tensorflow/contrib/lite/tools/make/targets/stm32f1_makefile.inc
new file mode 100644
index 0000000000..7418e4d196
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/make/targets/stm32f1_makefile.inc
@@ -0,0 +1,21 @@
+# Settings for STM32F1 platforms.
+ifeq ($(TARGET), stm32f1)
+ TARGET_ARCH := armm1
+ TARGET_TOOLCHAIN_PREFIX := arm-none-eabi-
+
+ CXXFLAGS += \
+ -DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK \
+ -mcpu=cortex-m1 \
+ -mthumb \
+ -DTFLITE_MCU \
+ -fno-rtti \
+ -fmessage-length=0 \
+ -fno-exceptions \
+ -fno-builtin \
+ -ffunction-sections \
+ -fdata-sections \
+ -funsigned-char \
+ -MMD
+ LIBS += -ldl
+ BUILD_TYPE := micro
+endif
diff --git a/tensorflow/contrib/lite/tools/make/targets/stm32f7_makefile.inc b/tensorflow/contrib/lite/tools/make/targets/stm32f7_makefile.inc
new file mode 100644
index 0000000000..48af71e5b4
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/make/targets/stm32f7_makefile.inc
@@ -0,0 +1,33 @@
+# Settings for STM32F7 platforms.
+ifeq ($(TARGET), stm32f7)
+ TARGET_ARCH := armf7
+ TARGET_TOOLCHAIN_PREFIX := arm-none-eabi-
+
+  CXXFLAGS += \
+  -DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK \
+  -DTFLITE_MCU \
+  -mcpu=cortex-m7 \
+  -mthumb \
+  -mfpu=fpv5-sp-d16 \
+  -mfloat-abi=softfp \
+  -std=gnu++11 \
+  -fno-rtti \
+  -fno-exceptions \
+  -fno-builtin \
+  -fno-delete-null-pointer-checks \
+  -fomit-frame-pointer \
+  -ffunction-sections \
+  -fdata-sections \
+  -fmessage-length=0 \
+  -funsigned-char \
+  -Wvla \
+  -Wall \
+  -Wextra \
+  -Wno-unused-parameter \
+  -Wno-missing-field-initializers \
+  -MMD \
+  -c \
+  -Os
+ LIBS += -ldl
+ BUILD_TYPE := micro
+endif
diff --git a/tensorflow/contrib/lookup/lookup_ops.py b/tensorflow/contrib/lookup/lookup_ops.py
index 4942d94176..8c0bfefb30 100644
--- a/tensorflow/contrib/lookup/lookup_ops.py
+++ b/tensorflow/contrib/lookup/lookup_ops.py
@@ -20,7 +20,6 @@ from __future__ import print_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import gen_lookup_ops
from tensorflow.python.ops import lookup_ops
# pylint: disable=unused-import
@@ -395,17 +394,12 @@ class MutableHashTable(LookupInterface):
Raises:
TypeError: when `keys` do not match the table data types.
"""
- if keys.dtype.base_dtype != self._key_dtype:
- raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." %
- (self._key_dtype, keys.dtype))
-
with ops.name_scope(name, "%s_lookup_table_find" % self._name,
(self._table_ref, keys, self._default_value)) as name:
+ keys = ops.convert_to_tensor(keys, dtype=self._key_dtype, name="keys")
with ops.colocate_with(self._table_ref):
values = gen_lookup_ops.lookup_table_find_v2(
self._table_ref, keys, self._default_value, name=name)
-
- values.set_shape(keys.get_shape().concatenate(self._value_shape))
return values
def insert(self, keys, values, name=None):
@@ -451,9 +445,6 @@ class MutableHashTable(LookupInterface):
with ops.colocate_with(self._table_ref):
exported_keys, exported_values = gen_lookup_ops.lookup_table_export_v2(
self._table_ref, self._key_dtype, self._value_dtype, name=name)
-
- exported_values.set_shape(exported_keys.get_shape().concatenate(
- self._value_shape))
return exported_keys, exported_values
class _Saveable(BaseSaverBuilder.SaveableObject):
@@ -537,14 +528,15 @@ class MutableDenseHashTable(LookupInterface):
ValueError: If checkpoint is True and no name was specified.
"""
self._default_value = ops.convert_to_tensor(
- default_value, dtype=value_dtype)
+ default_value, dtype=value_dtype, name="default_value")
self._value_shape = self._default_value.get_shape()
# The table must be shared if checkpointing is requested for multi-worker
# training to work correctly. Use the node name if no shared_name has been
# explicitly specified.
use_node_name_sharing = checkpoint and shared_name is None
- empty_key = ops.convert_to_tensor(empty_key, dtype=key_dtype)
+ empty_key = ops.convert_to_tensor(
+ empty_key, dtype=key_dtype, name="empty_key")
self._table_ref = gen_lookup_ops.mutable_dense_hash_table_v2(
empty_key=empty_key,
shared_name=shared_name,
@@ -591,20 +583,13 @@ class MutableDenseHashTable(LookupInterface):
Raises:
TypeError: when `keys` do not match the table data types.
"""
- if keys.dtype.base_dtype != self._key_dtype:
- raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." %
- (self._key_dtype, keys.dtype))
-
with ops.name_scope(name, "%s_lookup_table_find" % self._name,
[self._table_ref, keys]) as name:
+ keys = ops.convert_to_tensor(keys, dtype=self._key_dtype, name="keys")
with ops.colocate_with(self._table_ref):
values = gen_lookup_ops.lookup_table_find_v2(
self._table_ref, keys, self._default_value, name=name)
- if keys.get_shape().ndims is not None and keys.get_shape().ndims > 0:
- values.set_shape(
- tensor_shape.TensorShape([keys.get_shape().dims[0]]).concatenate(
- self._value_shape))
return values
def insert(self, keys, values, name=None):
@@ -624,11 +609,11 @@ class MutableDenseHashTable(LookupInterface):
TypeError: when `keys` or `values` doesn't match the table data
types.
"""
- # pylint: disable=protected-access
- lookup_ops._check_table_dtypes(self, keys.dtype, values.dtype)
- # pylint: enable=protected-access
with ops.name_scope(name, "%s_lookup_table_insert" % self._name,
[self._table_ref, keys, values]) as name:
+ keys = ops.convert_to_tensor(keys, dtype=self._key_dtype, name="keys")
+ values = ops.convert_to_tensor(
+ values, dtype=self._value_dtype, name="values")
with ops.colocate_with(self._table_ref):
op = gen_lookup_ops.lookup_table_insert_v2(
self._table_ref, keys, values, name=name)
@@ -650,8 +635,6 @@ class MutableDenseHashTable(LookupInterface):
exported_keys, exported_values = gen_lookup_ops.lookup_table_export_v2(
self._table_ref, self._key_dtype, self._value_dtype, name=name)
- exported_values.set_shape(exported_keys.get_shape().concatenate(
- self._value_shape))
return exported_keys, exported_values
class _Saveable(BaseSaverBuilder.SaveableObject):
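
Routing keys through ops.convert_to_tensor changes both where and how a dtype mismatch surfaces: it now fails while the lookup op is being built, raising ValueError, instead of the hand-rolled TypeError (the test update below reflects this). A sketch, assuming a contrib MutableHashTable with int64 keys:

import tensorflow as tf

table = tf.contrib.lookup.MutableHashTable(
    key_dtype=tf.int64, value_dtype=tf.int64, default_value=-1)
bad_keys = tf.constant(["a", "b"])  # string keys against an int64 table
try:
    table.lookup(bad_keys)  # convert_to_tensor now raises ValueError
except ValueError as e:
    print("dtype mismatch:", e)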
diff --git a/tensorflow/contrib/lookup/lookup_ops_test.py b/tensorflow/contrib/lookup/lookup_ops_test.py
index 8d510ede58..6fb5244fc6 100644
--- a/tensorflow/contrib/lookup/lookup_ops_test.py
+++ b/tensorflow/contrib/lookup/lookup_ops_test.py
@@ -434,8 +434,10 @@ class MutableHashTableOpTest(test.TestCase):
self.assertAllEqual([[0, 1], [2, 3], [-1, -1]], result)
exported_keys, exported_values = table.export()
- self.assertAllEqual([None], exported_keys.get_shape().as_list())
- self.assertAllEqual([None, 2], exported_values.get_shape().as_list())
+ self.assertAllEqual([None], exported_keys.get_shape().as_list(),
+ msg="Saw shape %s" % exported_keys.shape)
+ self.assertAllEqual([None, 2], exported_values.get_shape().as_list(),
+ msg="Saw shape %s" % exported_values.shape)
# exported data is in the order of the internal map, i.e. undefined
sorted_keys = np.sort(exported_keys.eval())
sorted_values = np.sort(exported_values.eval())
@@ -669,7 +671,7 @@ class MutableHashTableOpTest(test.TestCase):
# lookup with keys of the wrong type
input_string = constant_op.constant([1, 2, 3], dtypes.int64)
- with self.assertRaises(TypeError):
+ with self.assertRaises(ValueError):
table.lookup(input_string).eval()
# default value of the wrong type
@@ -853,7 +855,8 @@ class MutableDenseHashTableOpTest(test.TestCase):
input_string = constant_op.constant([11, 12, 15], dtypes.int64)
output = table.lookup(input_string)
- self.assertAllEqual([3, 4], output.get_shape())
+ self.assertAllEqual(
+ [3, 4], output.shape, msg="Saw shape: %s" % output.shape)
result = output.eval()
self.assertAllEqual([[0, 1, 2, 3], [3, 4, 5, 6], [-1, -2, -3, -4]],
diff --git a/tensorflow/contrib/makefile/download_dependencies.sh b/tensorflow/contrib/makefile/download_dependencies.sh
index 448ae6d22e..dc9b17a627 100755
--- a/tensorflow/contrib/makefile/download_dependencies.sh
+++ b/tensorflow/contrib/makefile/download_dependencies.sh
@@ -35,7 +35,9 @@ NSYNC_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/nsync/.*tar\.
# process. For now we're hardcoding to the version which is used by
# TensorFlow 1.9.
PROTOBUF_URL="https://mirror.bazel.build/github.com/google/protobuf/archive/396336eb961b75f03b25824fe86cf6490fb75e3a.tar.gz"
-RE2_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/re2/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)"
+# TODO(yongtang): Replace the following with 'https://mirror.bazel.build/github.com/google/re2/.*tar\.gz' once
+# the archive has been propagated to mirror.bazel.build.
+RE2_URL="$(grep -o 'https://github.com/google/re2/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)"
FFT2D_URL="$(grep -o 'http.*fft\.tgz' "${BZL_FILE_PATH}" | grep -v bazel-mirror | head -n1)"
DOUBLE_CONVERSION_URL="$(grep -o "https.*google/double-conversion.*\.zip" "${BZL_FILE_PATH}" | head -n1)"
ABSL_URL="$(grep -o 'https://github.com/abseil/abseil-cpp/.*tar.gz' "${BZL_FILE_PATH}" | head -n1)"
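
Only the pattern changed; the URL is still scraped out of workspace.bzl. A Python analogue of the grep -o line above, as a sketch:

import re

def first_re2_url(bzl_text):
    # Equivalent to: grep -o 'https://github.com/google/re2/.*tar\.gz' | head -n1
    m = re.search(r"https://github\.com/google/re2/.*tar\.gz", bzl_text)
    return m.group(0) if m else None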
diff --git a/tensorflow/contrib/metrics/python/metrics/classification.py b/tensorflow/contrib/metrics/python/metrics/classification.py
index e553612269..7053907da0 100644
--- a/tensorflow/contrib/metrics/python/metrics/classification.py
+++ b/tensorflow/contrib/metrics/python/metrics/classification.py
@@ -24,7 +24,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import metrics_impl
from tensorflow.python.ops import variable_scope
-from tensorflow.python.training import distribute as distribute_lib
+from tensorflow.python.training import distribution_strategy_context
# TODO(nsilberman): move into metrics/python/ops/
@@ -174,7 +174,7 @@ def f1_score(labels, predictions, weights=None, num_thresholds=200,
ops.add_to_collections(metrics_collections, best_f1)
return best_f1
- best_f1 = distribute_lib.get_tower_context().merge_call(
+ best_f1 = distribution_strategy_context.get_tower_context().merge_call(
f1_across_towers, values)
update_op = compute_best_f1_score(tp=update_ops['tp'], fp=update_ops['fp'],
diff --git a/tensorflow/contrib/optimizer_v2/optimizer_v2.py b/tensorflow/contrib/optimizer_v2/optimizer_v2.py
index 8c11d8bcfd..f6ecaba834 100644
--- a/tensorflow/contrib/optimizer_v2/optimizer_v2.py
+++ b/tensorflow/contrib/optimizer_v2/optimizer_v2.py
@@ -34,6 +34,7 @@ from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training import distribute as distribute_lib
+from tensorflow.python.training import distribution_strategy_context
from tensorflow.python.training import optimizer as optimizer_v1
from tensorflow.python.training import slot_creator
from tensorflow.python.training.checkpointable import base as checkpointable
@@ -620,7 +621,7 @@ class OptimizerV2(optimizer_v1.Optimizer):
# Map from graph_key to state for that graph. We use the graph_key
# since it works in both eager and graph mode, and gives the outer
# graph inside functions.
- tower_context = distribute_lib.get_tower_context()
+ tower_context = distribution_strategy_context.get_tower_context()
if tower_context is None:
# In a cross-tower context for a DistributionStrategy, which means
# only one Optimizer will be created, not one per tower.
@@ -769,7 +770,8 @@ class OptimizerV2(optimizer_v1.Optimizer):
distribute_lib.get_loss_reduction() ==
variable_scope.VariableAggregation.MEAN)
if scale_loss_by_num_towers:
- num_towers = distribute_lib.get_distribution_strategy().num_towers
+ num_towers = distribution_strategy_context.get_distribution_strategy(
+ ).num_towers
if num_towers > 1:
loss_value *= 1. / num_towers
@@ -788,7 +790,8 @@ class OptimizerV2(optimizer_v1.Optimizer):
distribute_lib.get_loss_reduction() ==
variable_scope.VariableAggregation.MEAN)
if scale_loss_by_num_towers:
- num_towers = distribute_lib.get_distribution_strategy().num_towers
+ num_towers = distribution_strategy_context.get_distribution_strategy(
+ ).num_towers
if num_towers > 1:
loss *= 1. / num_towers
@@ -862,7 +865,7 @@ class OptimizerV2(optimizer_v1.Optimizer):
if not filtered:
raise ValueError("No gradients provided for any variable: %s." %
([str(v) for _, v in grads_and_vars],))
- return distribute_lib.get_tower_context().merge_call(
+ return distribution_strategy_context.get_tower_context().merge_call(
self._distributed_apply, filtered, global_step=global_step, name=name)
def _get_or_create_state(self, var_list=None):
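
Aside from the module rename, the loss-scaling hunks above are behavior-preserving: under MEAN aggregation each tower scales its loss by 1/num_towers so that the per-tower gradients sum to an average. A toy numeric sketch:

num_towers = 4
per_tower_loss = 8.0
if num_towers > 1:
    per_tower_loss *= 1.0 / num_towers  # each tower now contributes 2.0
# Summed across the 4 towers this yields 8.0, the mean of the original losses.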
diff --git a/tensorflow/contrib/optimizer_v2/rmsprop.py b/tensorflow/contrib/optimizer_v2/rmsprop.py
index 164ff0ea06..3de53405ec 100644
--- a/tensorflow/contrib/optimizer_v2/rmsprop.py
+++ b/tensorflow/contrib/optimizer_v2/rmsprop.py
@@ -22,7 +22,7 @@ A detailed description of rmsprop.
- divide gradient by the root of this average
mean_square = decay * mean_square{t-1} + (1-decay) * gradient ** 2
-mom = momentum * mom{t-1} + learning_rate * g_t / sqrt(mean_square + epsilon)
+mom = momentum * mom{t-1} + learning_rate * g_t / sqrt(mean_square)
delta = - mom
This implementation of RMSProp uses plain momentum, not Nesterov momentum.
@@ -33,7 +33,7 @@ gradients, and uses that average to estimate the variance:
mean_grad = decay * mean_grad{t-1} + (1-decay) * gradient
mean_square = decay * mean_square{t-1} + (1-decay) * gradient ** 2
mom = momentum * mom{t-1} + learning_rate * g_t /
- sqrt(mean_square - mean_grad**2 + epsilon)
+ sqrt(mean_square - mean_grad**2)
delta = - mom
"""
@@ -43,7 +43,6 @@ from __future__ import print_function
from tensorflow.contrib.optimizer_v2 import optimizer_v2
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import init_ops
from tensorflow.python.training import training_ops
@@ -87,7 +86,8 @@ class RMSPropOptimizer(optimizer_v2.OptimizerV2):
decay: A float hyperparameter. Discounting factor for the history/coming
gradient.
momentum: A float hyperparameter.
- epsilon: A float hyperparameter. Small value to avoid zero denominator.
+ epsilon: A float hyperparameter. Small value to initialize the average
+ square gradient variable and avoid zero denominator.
use_locking: If True use locks for update operation.
centered: If True, gradients are normalized by the estimated variance of
the gradient; if False, by the uncentered second moment. Setting this to
@@ -106,10 +106,8 @@ class RMSPropOptimizer(optimizer_v2.OptimizerV2):
def _create_vars(self, var_list, state):
for v in var_list:
- if v.get_shape().is_fully_defined():
- init_rms = init_ops.ones_initializer(dtype=v.dtype.base_dtype)
- else:
- init_rms = array_ops.ones_like(v)
+ init_rms = state.get_hyper(
+ "epsilon", v.dtype.base_dtype) * array_ops.ones_like(v)
state.create_slot_with_initializer(v, init_rms, v.get_shape(),
v.dtype.base_dtype, "rms")
if self._centered:
@@ -129,7 +127,9 @@ class RMSPropOptimizer(optimizer_v2.OptimizerV2):
state.get_hyper("learning_rate", var.dtype.base_dtype),
state.get_hyper("decay", var.dtype.base_dtype),
state.get_hyper("momentum", var.dtype.base_dtype),
- state.get_hyper("epsilon", var.dtype.base_dtype),
+ # epsilon is now the rms initial value and is not added to the
+ # denominator anymore, hence calling the kernel op with epsilon=0.
+ 0,
grad,
use_locking=self._use_locking).op
else:
@@ -140,7 +140,7 @@ class RMSPropOptimizer(optimizer_v2.OptimizerV2):
state.get_hyper("learning_rate", var.dtype.base_dtype),
state.get_hyper("decay", var.dtype.base_dtype),
state.get_hyper("momentum", var.dtype.base_dtype),
- state.get_hyper("epsilon", var.dtype.base_dtype),
+ 0,
grad,
use_locking=self._use_locking).op
@@ -157,7 +157,7 @@ class RMSPropOptimizer(optimizer_v2.OptimizerV2):
state.get_hyper("learning_rate", var.dtype.base_dtype),
state.get_hyper("decay", var.dtype.base_dtype),
state.get_hyper("momentum", var.dtype.base_dtype),
- state.get_hyper("epsilon", var.dtype.base_dtype),
+ 0,
grad,
use_locking=self._use_locking)
else:
@@ -168,7 +168,7 @@ class RMSPropOptimizer(optimizer_v2.OptimizerV2):
state.get_hyper("learning_rate", var.dtype.base_dtype),
state.get_hyper("decay", var.dtype.base_dtype),
state.get_hyper("momentum", var.dtype.base_dtype),
- state.get_hyper("epsilon", var.dtype.base_dtype),
+ 0,
grad,
use_locking=self._use_locking)
@@ -185,7 +185,7 @@ class RMSPropOptimizer(optimizer_v2.OptimizerV2):
state.get_hyper("learning_rate", var.dtype.base_dtype),
state.get_hyper("decay", var.dtype.base_dtype),
state.get_hyper("momentum", var.dtype.base_dtype),
- state.get_hyper("epsilon", var.dtype.base_dtype),
+ 0,
grad.values,
grad.indices,
use_locking=self._use_locking)
@@ -197,7 +197,7 @@ class RMSPropOptimizer(optimizer_v2.OptimizerV2):
state.get_hyper("learning_rate", var.dtype.base_dtype),
state.get_hyper("decay", var.dtype.base_dtype),
state.get_hyper("momentum", var.dtype.base_dtype),
- state.get_hyper("epsilon", var.dtype.base_dtype),
+ 0,
grad.values,
grad.indices,
use_locking=self._use_locking)
@@ -215,7 +215,7 @@ class RMSPropOptimizer(optimizer_v2.OptimizerV2):
state.get_hyper("learning_rate", var.dtype.base_dtype),
state.get_hyper("decay", var.dtype.base_dtype),
state.get_hyper("momentum", var.dtype.base_dtype),
- state.get_hyper("epsilon", var.dtype.base_dtype),
+ 0,
grad,
indices,
use_locking=self._use_locking)
@@ -227,7 +227,7 @@ class RMSPropOptimizer(optimizer_v2.OptimizerV2):
state.get_hyper("learning_rate", var.dtype.base_dtype),
state.get_hyper("decay", var.dtype.base_dtype),
state.get_hyper("momentum", var.dtype.base_dtype),
- state.get_hyper("epsilon", var.dtype.base_dtype),
+ 0,
grad,
indices,
use_locking=self._use_locking)
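
The effect of relocating epsilon is easiest to see on a single step: the rms slot now starts at epsilon rather than 1.0, and the denominator no longer adds epsilon. A numeric sketch that reproduces the 0.901 terms the updated tests below expect:

import math

lr, decay, epsilon, g = 2.0, 0.9, 1.0, 0.1
rms = epsilon                            # slot initialized to epsilon * ones
rms = decay * rms + (1 - decay) * g * g  # 0.9 * 1.0 + 0.1 * 0.01 = 0.901
step = lr * g / math.sqrt(rms)           # kernel is now called with epsilon=0
assert abs(step - 0.1 * 2.0 / math.sqrt(0.901)) < 1e-12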
diff --git a/tensorflow/contrib/optimizer_v2/rmsprop_test.py b/tensorflow/contrib/optimizer_v2/rmsprop_test.py
index dc23ef241a..628d0418dd 100644
--- a/tensorflow/contrib/optimizer_v2/rmsprop_test.py
+++ b/tensorflow/contrib/optimizer_v2/rmsprop_test.py
@@ -39,34 +39,34 @@ _DATA_TYPES = [dtypes.half, dtypes.float32]
_TEST_PARAM_VALUES = [
# learning_rate, decay, momentum, epsilon, centered, use_resource
- [0.5, 0.9, 0.0, 1e-3, True, False],
- [0.5, 0.9, 0.0, 1e-3, False, False],
- [0.5, 0.9, 0.0, 1e-3, True, True],
- [0.5, 0.9, 0.0, 1e-3, False, True],
- [0.1, 0.9, 0.0, 1e-3, True, False],
- [0.5, 0.95, 0.0, 1e-3, False, False],
- [0.5, 0.95, 0.0, 1e-5, True, False],
- [0.5, 0.95, 0.9, 1e-5, True, False],
+ [0.5, 0.9, 0.0, 1.0, True, False],
+ [0.5, 0.9, 0.0, 1.0, False, False],
+ [0.5, 0.9, 0.0, 1.0, True, True],
+ [0.5, 0.9, 0.0, 1.0, False, True],
+ [0.1, 0.9, 0.0, 1.0, True, False],
+ [0.5, 0.95, 0.0, 1.0, False, False],
+ [0.5, 0.8, 0.0, 1e-3, True, False],
+ [0.5, 0.8, 0.9, 1e-3, True, False],
]
class RMSPropOptimizerTest(test.TestCase, parameterized.TestCase):
def _rmsprop_update_numpy(self, var, g, mg, rms, mom, lr, decay, momentum,
- epsilon, centered):
+ centered):
rms_t = rms * decay + (1 - decay) * g * g
- denom_t = rms_t + epsilon
if centered:
mg_t = mg * decay + (1 - decay) * g
- denom_t -= mg_t * mg_t
+ denom_t = rms_t - mg_t * mg_t
else:
mg_t = mg
+ denom_t = rms_t
mom_t = momentum * mom + lr * g / np.sqrt(denom_t, dtype=denom_t.dtype)
var_t = var - mom_t
return var_t, mg_t, rms_t, mom_t
def _sparse_rmsprop_update_numpy(self, var, gindexs, gvalues, mg, rms, mom,
- lr, decay, momentum, epsilon, centered):
+ lr, decay, momentum, centered):
mg_t = copy.deepcopy(mg)
rms_t = copy.deepcopy(rms)
mom_t = copy.deepcopy(mom)
@@ -75,7 +75,7 @@ class RMSPropOptimizerTest(test.TestCase, parameterized.TestCase):
gindex = gindexs[i]
gvalue = gvalues[i]
rms_t[gindex] = rms[gindex] * decay + (1 - decay) * gvalue * gvalue
- denom_t = rms_t[gindex] + epsilon
+ denom_t = rms_t[gindex]
if centered:
mg_t[gindex] = mg_t[gindex] * decay + (1 - decay) * gvalue
denom_t -= mg_t[gindex] * mg_t[gindex]
@@ -129,8 +129,8 @@ class RMSPropOptimizerTest(test.TestCase, parameterized.TestCase):
mg0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
mg1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
- rms0_np = np.array([1.0, 1.0], dtype=dtype.as_numpy_dtype)
- rms1_np = np.array([1.0, 1.0], dtype=dtype.as_numpy_dtype)
+ rms0_np = np.array([epsilon, epsilon], dtype=dtype.as_numpy_dtype)
+ rms1_np = np.array([epsilon, epsilon], dtype=dtype.as_numpy_dtype)
mom0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
mom1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
@@ -144,10 +144,10 @@ class RMSPropOptimizerTest(test.TestCase, parameterized.TestCase):
var0_np, mg0_np, rms0_np, mom0_np = self._rmsprop_update_numpy(
var0_np, grads0_np, mg0_np, rms0_np, mom0_np, learning_rate,
- decay, momentum, epsilon, centered)
+ decay, momentum, centered)
var1_np, mg1_np, rms1_np, mom1_np = self._rmsprop_update_numpy(
var1_np, grads1_np, mg1_np, rms1_np, mom1_np, learning_rate,
- decay, momentum, epsilon, centered)
+ decay, momentum, centered)
# Validate updated params
if centered:
@@ -191,7 +191,7 @@ class RMSPropOptimizerTest(test.TestCase, parameterized.TestCase):
loss = pred * pred
sgd_op = rmsprop.RMSPropOptimizer(
learning_rate=1.0,
- decay=0.0,
+ decay=0.1,
momentum=0.0,
epsilon=1.0,
centered=True).minimize(loss)
@@ -202,7 +202,7 @@ class RMSPropOptimizerTest(test.TestCase, parameterized.TestCase):
sgd_op.run()
# Validate updated params
self.assertAllCloseAccordingToType(
- [[-111, -138]], var0.eval(), atol=0.01)
+ [[-7/3.0, -4/3.0]], var0.eval(), atol=0.01)
@parameterized.named_parameters(
*test_util.generate_combinations_with_testcase_name(
@@ -251,8 +251,8 @@ class RMSPropOptimizerTest(test.TestCase, parameterized.TestCase):
mg0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
mg1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
- rms0_np = np.array([1.0, 1.0], dtype=dtype.as_numpy_dtype)
- rms1_np = np.array([1.0, 1.0], dtype=dtype.as_numpy_dtype)
+ rms0_np = np.array([epsilon, epsilon], dtype=dtype.as_numpy_dtype)
+ rms1_np = np.array([epsilon, epsilon], dtype=dtype.as_numpy_dtype)
mom0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
mom1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
@@ -266,10 +266,10 @@ class RMSPropOptimizerTest(test.TestCase, parameterized.TestCase):
var0_np, mg0_np, rms0_np, mom0_np = self._sparse_rmsprop_update_numpy(
var0_np, grads0_np_indices, grads0_np, mg0_np, rms0_np, mom0_np,
- learning_rate, decay, momentum, epsilon, centered)
+ learning_rate, decay, momentum, centered)
var1_np, mg1_np, rms1_np, mom1_np = self._sparse_rmsprop_update_numpy(
var1_np, grads1_np_indices, grads1_np, mg1_np, rms1_np, mom1_np,
- learning_rate, decay, momentum, epsilon, centered)
+ learning_rate, decay, momentum, centered)
# Validate updated params
if centered:
@@ -317,13 +317,13 @@ class RMSPropOptimizerTest(test.TestCase, parameterized.TestCase):
# Check the parameters.
self.assertAllCloseAccordingToType(
np.array([
- 1.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1.0)),
- 2.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1.0))
+ 1.0 - (0.1 * 2.0 / math.sqrt(0.901)),
+ 2.0 - (0.1 * 2.0 / math.sqrt(0.901))
]), var0.eval())
self.assertAllCloseAccordingToType(
np.array([
- 3.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1.0)),
- 4.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1.0))
+ 3.0 - (0.01 * 2.0 / math.sqrt(0.90001)),
+ 4.0 - (0.01 * 2.0 / math.sqrt(0.90001))
]), var1.eval())
# Step 2: the root mean square accumulators contain the previous update.
update.run()
@@ -335,17 +335,17 @@ class RMSPropOptimizerTest(test.TestCase, parameterized.TestCase):
# Check the parameters.
self.assertAllCloseAccordingToType(
np.array([
- 1.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1.0)) -
- (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001 + 1.0)),
- 2.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1.0)) -
- (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001 + 1.0))
+ 1.0 - (0.1 * 2.0 / math.sqrt(0.901)) -
+ (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001)),
+ 2.0 - (0.1 * 2.0 / math.sqrt(0.901)) -
+ (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001))
]), var0.eval())
self.assertAllCloseAccordingToType(
np.array([
- 3.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1.0)) -
- (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 1e-5 + 1.0)),
- 4.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1.0)) -
- (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 1e-5 + 1.0))
+ 3.0 - (0.01 * 2.0 / math.sqrt(0.90001)) -
+ (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 1e-5)),
+ 4.0 - (0.01 * 2.0 / math.sqrt(0.90001)) -
+ (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 1e-5))
]), var1.eval())
@parameterized.parameters(_DATA_TYPES)
@@ -357,7 +357,7 @@ class RMSPropOptimizerTest(test.TestCase, parameterized.TestCase):
grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
opt = rmsprop.RMSPropOptimizer(
- learning_rate=2.0, decay=0.9, momentum=0.5, epsilon=1e-5)
+ learning_rate=2.0, decay=0.9, momentum=0.5, epsilon=1.0)
update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
variables.global_variables_initializer().run()
@@ -383,22 +383,22 @@ class RMSPropOptimizerTest(test.TestCase, parameterized.TestCase):
np.array([0.90001, 0.90001]), rms1.eval())
# Check the momentum accumulators
self.assertAllCloseAccordingToType(
- np.array([(0.1 * 2.0 / math.sqrt(0.901 + 1e-5)),
- (0.1 * 2.0 / math.sqrt(0.901 + 1e-5))]), mom0.eval())
+ np.array([(0.1 * 2.0 / math.sqrt(0.901)),
+ (0.1 * 2.0 / math.sqrt(0.901))]), mom0.eval())
self.assertAllCloseAccordingToType(
- np.array([(0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)),
- (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5))]), mom1.eval())
+ np.array([(0.01 * 2.0 / math.sqrt(0.90001)),
+ (0.01 * 2.0 / math.sqrt(0.90001))]), mom1.eval())
    # Check the parameters.
self.assertAllCloseAccordingToType(
np.array([
- 1.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)),
- 2.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1e-5))
+ 1.0 - (0.1 * 2.0 / math.sqrt(0.901)),
+ 2.0 - (0.1 * 2.0 / math.sqrt(0.901))
]), var0.eval())
self.assertAllCloseAccordingToType(
np.array([
- 3.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)),
- 4.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5))
+ 3.0 - (0.01 * 2.0 / math.sqrt(0.90001)),
+ 4.0 - (0.01 * 2.0 / math.sqrt(0.90001))
]), var1.eval())
# Step 2: the root mean square accumulators contain the previous update.
@@ -410,38 +410,38 @@ class RMSPropOptimizerTest(test.TestCase, parameterized.TestCase):
np.array([0.90001 * 0.9 + 1e-5, 0.90001 * 0.9 + 1e-5]), rms1.eval())
self.assertAllCloseAccordingToType(
np.array([
- 0.5 * (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)) +
- (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001 + 1e-5)),
- 0.5 * (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)) +
- (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001 + 1e-5))
+ 0.5 * (0.1 * 2.0 / math.sqrt(0.901)) +
+ (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001)),
+ 0.5 * (0.1 * 2.0 / math.sqrt(0.901)) +
+ (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001))
]), mom0.eval())
self.assertAllCloseAccordingToType(
np.array([
- 0.5 * (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)) +
- (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 2e-5)),
- 0.5 * (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)) +
- (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 2e-5))
+ 0.5 * (0.01 * 2.0 / math.sqrt(0.90001)) +
+ (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 1e-5)),
+ 0.5 * (0.01 * 2.0 / math.sqrt(0.90001)) +
+ (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 1e-5))
]), mom1.eval())
# Check the parameters.
self.assertAllCloseAccordingToType(
np.array([
- 1.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)) -
- (0.5 * (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)) +
- (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001 + 1e-5))),
- 2.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)) -
- (0.5 * (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)) +
- (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001 + 1e-5)))
+ 1.0 - (0.1 * 2.0 / math.sqrt(0.901)) -
+ (0.5 * (0.1 * 2.0 / math.sqrt(0.901)) +
+ (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001))),
+ 2.0 - (0.1 * 2.0 / math.sqrt(0.901)) -
+ (0.5 * (0.1 * 2.0 / math.sqrt(0.901)) +
+ (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001)))
]), var0.eval())
self.assertAllCloseAccordingToType(
np.array([
- 3.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)) -
- (0.5 * (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)) +
- (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 2e-5))),
- 4.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)) -
- (0.5 * (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)) +
- (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 2e-5)))
+ 3.0 - (0.01 * 2.0 / math.sqrt(0.90001)) -
+ (0.5 * (0.01 * 2.0 / math.sqrt(0.90001)) +
+ (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 1e-5))),
+ 4.0 - (0.01 * 2.0 / math.sqrt(0.90001)) -
+ (0.5 * (0.01 * 2.0 / math.sqrt(0.90001)) +
+ (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 1e-5)))
]), var1.eval())
diff --git a/tensorflow/contrib/quantize/python/quantize.py b/tensorflow/contrib/quantize/python/quantize.py
index cb66fd1f76..2ddbd73ea6 100644
--- a/tensorflow/contrib/quantize/python/quantize.py
+++ b/tensorflow/contrib/quantize/python/quantize.py
@@ -455,6 +455,24 @@ class _LayerMatch(object):
return self._bias_add_op
+def _FollowedByFakeQuant(tensor):
+ """Returns True if the tensor is followed by a FakeQuant."""
+ fake_quant_ops = set([
+ 'FakeQuantWithMinMaxVars', 'FakeQuantWithMinMaxArgs',
+ 'FakeQuantWithMinMaxVarsPerChannel'
+ ])
+ pass_through_ops = set(['Reshape', 'Identity'])
+ consumers = tensor.consumers()
+ while consumers:
+ c = consumers.pop()
+ if c.type in fake_quant_ops:
+ return True
+ elif c.type in pass_through_ops:
+ for output in c.outputs:
+ consumers.extend(output.consumers())
+ return False
+
+
def _InsertQuantOp(context,
name,
producer,
@@ -535,11 +553,7 @@ def _InsertQuantOp(context,
# Prevent ops from being quantized multiple times. Bypass ops can sometimes
# overlap between multiple matches, so we need to ensure that we don't
# add duplicate FakeQuant operations.
- fake_quant_ops = set([
- 'FakeQuantWithMinMaxVars',
- 'FakeQuantWithMinMaxArgs'
- ])
- if fake_quant_ops.intersection(set([c.type for c in inputs.consumers()])):
+ if _FollowedByFakeQuant(inputs):
return
if moving_avg:
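
A framework-free sketch of the traversal _FollowedByFakeQuant performs, with a plain dict standing in for TensorFlow's consumer graph (names illustrative):

def followed_by_fake_quant(consumers_of, tensor):
    # consumers_of maps a tensor name to a list of (op_type, output_tensors).
    fake_quant_ops = {"FakeQuantWithMinMaxVars", "FakeQuantWithMinMaxArgs",
                      "FakeQuantWithMinMaxVarsPerChannel"}
    pass_through_ops = {"Reshape", "Identity"}
    stack = list(consumers_of.get(tensor, []))
    while stack:
        op_type, outputs = stack.pop()
        if op_type in fake_quant_ops:
            return True
        if op_type in pass_through_ops:
            for out in outputs:
                stack.extend(consumers_of.get(out, []))
    return False

graph = {"conv_out": [("Reshape", ["reshaped"])],
         "reshaped": [("FakeQuantWithMinMaxVars", ["q"])]}
assert followed_by_fake_quant(graph, "conv_out")  # a second quant op is skipped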
diff --git a/tensorflow/contrib/quantize/python/quantize_test.py b/tensorflow/contrib/quantize/python/quantize_test.py
index 06ebcdfee1..212d902a3c 100644
--- a/tensorflow/contrib/quantize/python/quantize_test.py
+++ b/tensorflow/contrib/quantize/python/quantize_test.py
@@ -471,6 +471,60 @@ class QuantizeTest(test_util.TensorFlowTestCase):
self.assertTrue(
'part/test/test/weights_quant/FakeQuantWithMinMaxVars' in op_names)
+ def testSkipReshapeQuantization(self):
+ self._RunTestOverParameters(self._TestSkipReshapeQuantization)
+
+ def _TestSkipReshapeQuantization(self, is_training):
+ graph = ops.Graph()
+ with graph.as_default():
+ batch_size, height, width, depth = 5, 128, 128, 3
+ input1 = array_ops.zeros((batch_size, height, width, depth))
+ conv = conv2d(
+ input1,
+ 32, [5, 5],
+ stride=2,
+ padding='SAME',
+ weights_initializer=self._WeightInit(0.09),
+ activation_fn=nn_ops.relu6,
+ scope='test/test')
+
+ reshape = array_ops.reshape(
+ conv, (int(10), int(height / 2), int(width / 2), int(16)))
+
+      # Insert a fake quant node after the reshape. We will check that one
+      # isn't inserted before the reshape.
+ array_ops.fake_quant_with_min_max_vars(reshape, -1, 1)
+
+ quantize.Quantize(graph, is_training, weight_bits=8, activation_bits=8)
+
+ # Ensure that there isn't a FakeQuant added before the reshape.
+ self.assertFalse(
+ 'FakeQuantWithMinMaxVars' in [i.op.type for i in reshape.op.inputs])
+
+ graph = ops.Graph()
+ with graph.as_default():
+ batch_size, height, width, depth = 5, 128, 128, 3
+ input1 = array_ops.zeros((batch_size, height, width, depth))
+ conv = conv2d(
+ input1,
+ 32, [5, 5],
+ stride=2,
+ padding='SAME',
+ weights_initializer=self._WeightInit(0.09),
+ activation_fn=nn_ops.relu6,
+ scope='test/test')
+
+ reshape = array_ops.reshape(
+ conv, (int(10), int(height / 2), int(width / 2), int(16)))
+
+ # If no fake quant is added after the reshape, a FakeQuant should be added
+ # before the reshape.
+ quantize.Quantize(graph, is_training, weight_bits=8, activation_bits=8)
+
+      # Ensure that there is a FakeQuant added before the reshape.
+ self.assertTrue(
+ 'FakeQuantWithMinMaxVars' in [i.op.type for i in reshape.op.inputs])
+
def _WeightInit(self, stddev):
"""Returns truncated normal variable initializer.
diff --git a/tensorflow/contrib/rnn/BUILD b/tensorflow/contrib/rnn/BUILD
index 2a84629080..5874245d58 100644
--- a/tensorflow/contrib/rnn/BUILD
+++ b/tensorflow/contrib/rnn/BUILD
@@ -149,7 +149,7 @@ cuda_py_tests(
cuda_py_tests(
name = "core_rnn_test",
- size = "large",
+ size = "medium",
srcs = ["python/kernel_tests/core_rnn_test.py"],
additional_deps = [
":rnn_py",
@@ -175,7 +175,7 @@ cuda_py_tests(
tf_py_test(
name = "fused_rnn_cell_test",
- size = "small",
+ size = "medium",
srcs = ["python/kernel_tests/fused_rnn_cell_test.py"],
additional_deps = [
":rnn_py",
@@ -192,10 +192,6 @@ tf_py_test(
"//tensorflow/python:variable_scope",
"//tensorflow/python:variables",
],
- tags = [
- "manual",
- "notap",
- ],
)
cuda_py_tests(
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py
index 1c20d88fe4..d62ec45d18 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py
@@ -1288,7 +1288,10 @@ class LSTMTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def testDynamicEquivalentToStaticRNN(self):
self._testDynamicEquivalentToStaticRNN(use_sequence_length=False)
- self._testDynamicEquivalentToStaticRNN(use_sequence_length=False)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testDynamicEquivalentToStaticRNNWithSequenceLength(self):
+ self._testDynamicEquivalentToStaticRNN(use_sequence_length=True)
class BidirectionalRNNTest(test.TestCase):
diff --git a/tensorflow/contrib/saved_model/BUILD b/tensorflow/contrib/saved_model/BUILD
index fbb50befdf..e7eb4ac563 100644
--- a/tensorflow/contrib/saved_model/BUILD
+++ b/tensorflow/contrib/saved_model/BUILD
@@ -113,7 +113,6 @@ py_test(
size = "small",
srcs = ["python/saved_model/keras_saved_model_test.py"],
srcs_version = "PY2AND3",
- tags = ["no_windows"],
deps = [
":saved_model_py",
"//tensorflow/python:client_testlib",
diff --git a/tensorflow/contrib/stat_summarizer/BUILD b/tensorflow/contrib/stat_summarizer/BUILD
index 0b8fc0cdc6..412a2c81a1 100644
--- a/tensorflow/contrib/stat_summarizer/BUILD
+++ b/tensorflow/contrib/stat_summarizer/BUILD
@@ -31,8 +31,5 @@ tf_py_test(
"//tensorflow/python:math_ops",
"//tensorflow/python:variables",
],
- tags = [
- "no_windows",
- "notap", # TODO(b/80546574): test is flaky
- ],
+ tags = ["notap"], # TODO(b/80546574): test is flaky
)
diff --git a/tensorflow/contrib/tensor_forest/client/random_forest.py b/tensorflow/contrib/tensor_forest/client/random_forest.py
index 35e8c92aba..8fa0b3ada9 100644
--- a/tensorflow/contrib/tensor_forest/client/random_forest.py
+++ b/tensorflow/contrib/tensor_forest/client/random_forest.py
@@ -22,10 +22,12 @@ from tensorflow.contrib.learn.python.learn.estimators import constants
from tensorflow.contrib.learn.python.learn.estimators import estimator
from tensorflow.contrib.learn.python.learn.estimators import head as head_lib
from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_fn_lib
-
from tensorflow.contrib.tensor_forest.client import eval_metrics
from tensorflow.contrib.tensor_forest.python import tensor_forest
-
+from tensorflow.python.estimator import estimator as core_estimator
+from tensorflow.python.estimator.canned import head as core_head_lib
+from tensorflow.python.estimator.export.export_output import PredictOutput
+from tensorflow.python.feature_column import feature_column as fc_core
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
@@ -34,12 +36,12 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops.losses import losses
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.summary import summary
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training_util
-
KEYS_NAME = 'keys'
LOSS_NAME = 'rf_training_loss'
TREE_PATHS_PREDICTION_KEY = 'tree_paths'
@@ -48,6 +50,11 @@ ALL_SERVING_KEY = 'tensorforest_all'
EPSILON = 0.000001
+class ModelBuilderOutputType(object):
+ MODEL_FN_OPS = 0
+ ESTIMATOR_SPEC = 1
+
+
class TensorForestRunOpAtEndHook(session_run_hook.SessionRunHook):
def __init__(self, op_dict):
@@ -106,20 +113,34 @@ class TensorForestLossHook(session_run_hook.SessionRunHook):
run_context.request_stop()
-def get_default_head(params, weights_name, name=None):
- if params.regression:
- return head_lib.regression_head(
- weight_column_name=weights_name,
- label_dimension=params.num_outputs,
- enable_centered_bias=False,
- head_name=name)
+def _get_default_head(params, weights_name, output_type, name=None):
+ """Creates a default head based on a type of a problem."""
+ if output_type == ModelBuilderOutputType.MODEL_FN_OPS:
+ if params.regression:
+ return head_lib.regression_head(
+ weight_column_name=weights_name,
+ label_dimension=params.num_outputs,
+ enable_centered_bias=False,
+ head_name=name)
+ else:
+ return head_lib.multi_class_head(
+ params.num_classes,
+ weight_column_name=weights_name,
+ enable_centered_bias=False,
+ head_name=name)
else:
- return head_lib.multi_class_head(
- params.num_classes,
- weight_column_name=weights_name,
- enable_centered_bias=False,
- head_name=name)
-
+ if params.regression:
+ return core_head_lib._regression_head( # pylint:disable=protected-access
+ weight_column=weights_name,
+ label_dimension=params.num_outputs,
+ name=name,
+ loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS)
+ else:
+ return core_head_lib._multi_class_head_with_softmax_cross_entropy_loss( # pylint:disable=protected-access
+ n_classes=params.num_classes,
+ weight_column=weights_name,
+ name=name,
+ loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS)
def get_model_fn(params,
graph_builder_class,
@@ -135,19 +156,27 @@ def get_model_fn(params,
report_feature_importances=False,
local_eval=False,
head_scope=None,
- include_all_in_serving=False):
+ include_all_in_serving=False,
+ output_type=ModelBuilderOutputType.MODEL_FN_OPS):
"""Return a model function given a way to construct a graph builder."""
if model_head is None:
- model_head = get_default_head(params, weights_name)
+ model_head = _get_default_head(params, weights_name, output_type)
def _model_fn(features, labels, mode):
"""Function that returns predictions, training loss, and training op."""
+
if (isinstance(features, ops.Tensor) or
isinstance(features, sparse_tensor.SparseTensor)):
features = {'features': features}
if feature_columns:
features = features.copy()
- features.update(layers.transform_features(features, feature_columns))
+
+ if output_type == ModelBuilderOutputType.MODEL_FN_OPS:
+ features.update(layers.transform_features(features, feature_columns))
+ else:
+ for fc in feature_columns:
+ tensor = fc_core._transform_features(features, [fc])[fc] # pylint: disable=protected-access
+ features[fc.name] = tensor
weights = None
if weights_name and weights_name in features:
@@ -201,52 +230,92 @@
def _train_fn(unused_loss):
return training_graph
- model_ops = model_head.create_model_fn_ops(
- features=features,
- labels=labels,
- mode=mode,
- train_op_fn=_train_fn,
- logits=logits,
- scope=head_scope)
    # Ops are run in lexicographical order of their keys. Run the resource
# clean-up op last.
all_handles = graph_builder.get_all_resource_handles()
ops_at_end = {
- '9: clean up resources': control_flow_ops.group(
- *[resource_variable_ops.destroy_resource_op(handle)
- for handle in all_handles])}
+ '9: clean up resources':
+ control_flow_ops.group(*[
+ resource_variable_ops.destroy_resource_op(handle)
+ for handle in all_handles
+ ])
+ }
if report_feature_importances:
ops_at_end['1: feature_importances'] = (
graph_builder.feature_importances())
- training_hooks.append(TensorForestRunOpAtEndHook(ops_at_end))
-
- if early_stopping_rounds:
- training_hooks.append(
- TensorForestLossHook(
- early_stopping_rounds,
- early_stopping_loss_threshold=early_stopping_loss_threshold,
- loss_op=model_ops.loss))
-
- model_ops.training_hooks.extend(training_hooks)
-
- if keys is not None:
- model_ops.predictions[keys_name] = keys
-
- if params.inference_tree_paths:
- model_ops.predictions[TREE_PATHS_PREDICTION_KEY] = tree_paths
-
- model_ops.predictions[VARIANCE_PREDICTION_KEY] = regression_variance
- if include_all_in_serving:
- # In order to serve the variance we need to add the prediction dict
- # to output_alternatives dict.
- if not model_ops.output_alternatives:
- model_ops.output_alternatives = {}
- model_ops.output_alternatives[ALL_SERVING_KEY] = (
- constants.ProblemType.UNSPECIFIED, model_ops.predictions)
- return model_ops
+ training_hooks = [TensorForestRunOpAtEndHook(ops_at_end)]
+
+ if output_type == ModelBuilderOutputType.MODEL_FN_OPS:
+ model_ops = model_head.create_model_fn_ops(
+ features=features,
+ labels=labels,
+ mode=mode,
+ train_op_fn=_train_fn,
+ logits=logits,
+ scope=head_scope)
+
+ if early_stopping_rounds:
+ training_hooks.append(
+ TensorForestLossHook(
+ early_stopping_rounds,
+ early_stopping_loss_threshold=early_stopping_loss_threshold,
+ loss_op=model_ops.loss))
+
+ model_ops.training_hooks.extend(training_hooks)
+
+ if keys is not None:
+ model_ops.predictions[keys_name] = keys
+
+ if params.inference_tree_paths:
+ model_ops.predictions[TREE_PATHS_PREDICTION_KEY] = tree_paths
+
+ model_ops.predictions[VARIANCE_PREDICTION_KEY] = regression_variance
+
+ if include_all_in_serving:
+ # In order to serve the variance we need to add the prediction dict
+ # to output_alternatives dict.
+ if not model_ops.output_alternatives:
+ model_ops.output_alternatives = {}
+ model_ops.output_alternatives[ALL_SERVING_KEY] = (
+ constants.ProblemType.UNSPECIFIED, model_ops.predictions)
+
+ return model_ops
+
+ else:
+ # Estimator spec
+ estimator_spec = model_head.create_estimator_spec(
+ features=features,
+ mode=mode,
+ labels=labels,
+ train_op_fn=_train_fn,
+ logits=logits)
+
+ if early_stopping_rounds:
+ training_hooks.append(
+ TensorForestLossHook(
+ early_stopping_rounds,
+ early_stopping_loss_threshold=early_stopping_loss_threshold,
+ loss_op=estimator_spec.loss))
+
+ estimator_spec = estimator_spec._replace(
+ training_hooks=training_hooks + list(estimator_spec.training_hooks))
+ if keys is not None:
+ estimator_spec.predictions[keys_name] = keys
+ if params.inference_tree_paths:
+ estimator_spec.predictions[TREE_PATHS_PREDICTION_KEY] = tree_paths
+ estimator_spec.predictions[VARIANCE_PREDICTION_KEY] = regression_variance
+
+      if include_all_in_serving:
+        # In order to serve the variance we need to add the prediction dict
+        # to the export_outputs dict.
+        outputs = estimator_spec.export_outputs or {}
+        outputs[ALL_SERVING_KEY] = PredictOutput(estimator_spec.predictions)
+        estimator_spec = estimator_spec._replace(export_outputs=outputs)
+
+ return estimator_spec
return _model_fn
@@ -493,8 +565,11 @@ class MultiForestMultiHeadEstimator(estimator.Estimator):
params,
graph_builder_class,
device_assigner,
- model_head=get_default_head(
- params, weight_column, name='head{0}'.format(i)),
+ model_head=_get_default_head(
+ params,
+ weight_column,
+ name='head{0}'.format(i),
+ output_type=ModelBuilderOutputType.MODEL_FN_OPS),
weights_name=weight_column,
keys_name=keys_column,
early_stopping_rounds=early_stopping_rounds,
@@ -509,3 +584,142 @@ class MultiForestMultiHeadEstimator(estimator.Estimator):
model_dir=model_dir,
config=config,
feature_engineering_fn=feature_engineering_fn)
+
+
+class CoreTensorForestEstimator(core_estimator.Estimator):
+ """A CORE estimator that can train and evaluate a random forest.
+
+ Example:
+
+ ```python
+ params = tf.contrib.tensor_forest.python.tensor_forest.ForestHParams(
+ num_classes=2, num_features=40, num_trees=10, max_nodes=1000)
+
+ # Estimator using the default graph builder.
+ estimator = CoreTensorForestEstimator(params, model_dir=model_dir)
+
+ # Or estimator using TrainingLossForest as the graph builder.
+ estimator = CoreTensorForestEstimator(
+ params, graph_builder_class=tensor_forest.TrainingLossForest,
+ model_dir=model_dir)
+
+ # Input builders
+  def input_fn_train():  # returns x, y
+    ...
+  def input_fn_eval():  # returns x, y
+    ...
+ estimator.train(input_fn=input_fn_train)
+ estimator.evaluate(input_fn=input_fn_eval)
+
+ # Predict returns an iterable of dicts.
+  results = list(estimator.predict(input_fn=input_fn_eval))
+ prob0 = results[0][eval_metrics.INFERENCE_PROB_NAME]
+ prediction0 = results[0][eval_metrics.INFERENCE_PRED_NAME]
+ ```
+ """
+
+ def __init__(self,
+ params,
+ device_assigner=None,
+ model_dir=None,
+ feature_columns=None,
+ graph_builder_class=tensor_forest.RandomForestGraphs,
+ config=None,
+ weight_column=None,
+ keys_column=None,
+ feature_engineering_fn=None,
+ early_stopping_rounds=100,
+ early_stopping_loss_threshold=0.001,
+ num_trainers=1,
+ trainer_id=0,
+ report_feature_importances=False,
+ local_eval=False,
+ version=None,
+ head=None,
+ include_all_in_serving=False):
+ """Initializes a TensorForestEstimator instance.
+
+ Args:
+ params: ForestHParams object that holds random forest hyperparameters.
+ These parameters will be passed into `model_fn`.
+ device_assigner: An `object` instance that controls how trees get
+ assigned to devices. If `None`, will use
+ `tensor_forest.RandomForestDeviceAssigner`.
+ model_dir: Directory to save model parameters, graph, etc. To continue
+ training a previously saved model, load checkpoints saved to this
+ directory into an estimator.
+ feature_columns: An iterable containing all the feature columns used by
+ the model. All items in the set should be instances of classes derived
+ from `_FeatureColumn`.
+ graph_builder_class: An `object` instance that defines how TF graphs for
+ random forest training and inference are built. By default will use
+ `tensor_forest.RandomForestGraphs`. Can be overridden by version
+ kwarg.
+ config: `RunConfig` object to configure the runtime settings.
+ weight_column: A string defining feature column name representing
+ weights. Will be multiplied by the loss of the example. Used to
+ downweight or boost examples during training.
+ keys_column: A string naming one of the features to strip out and
+ pass through into the inference/eval results dict. Useful for
+ associating specific examples with their prediction.
+ feature_engineering_fn: Feature engineering function. Takes features and
+ labels which are the output of `input_fn` and returns features and
+ labels which will be fed into the model.
+      early_stopping_rounds: Allows training to terminate early if the forest
+        is no longer growing. 100 by default. Set to a falsy value to disable
+        the default training hook.
+ early_stopping_loss_threshold: Percentage (as fraction) that loss must
+ improve by within early_stopping_rounds steps, otherwise training will
+ terminate.
+ num_trainers: Number of training jobs, which will partition trees
+ among them.
+ trainer_id: Which trainer this instance is.
+ report_feature_importances: If True, print out feature importances
+ during evaluation.
+ local_eval: If True, don't use a device assigner for eval. This is to
+ support some common setups where eval is done on a single machine, even
+ though training might be distributed.
+ version: Unused.
+ head: A heads_lib.Head object that calculates losses and such. If None,
+ one will be automatically created based on params.
+      include_all_in_serving: If True, prepare the complete prediction dict,
+        including the variance, to be exported for serving with the Servo
+        lib. This also requires calling export_savedmodel with
+        default_output_alternative_key=ALL_SERVING_KEY, i.e.
+        estimator.export_savedmodel(export_dir_base=your_export_dir,
+        serving_input_fn=your_export_input_fn,
+        default_output_alternative_key=ALL_SERVING_KEY).
+        If False, fall back to the default behavior, i.e. export scores and
+        probabilities but no variances. In that case,
+        default_output_alternative_key should be None when calling
+        export_savedmodel().
+        Note that, for backward compatibility, we cannot always set
+        include_all_in_serving to True: calling export_savedmodel() without
+        default_output_alternative_key=ALL_SERVING_KEY (the legacy behavior)
+        would then make saved_model_export_utils.get_output_alternatives()
+        raise a ValueError.
+
+ Returns:
+      A `CoreTensorForestEstimator` instance.
+ """
+
+ super(CoreTensorForestEstimator, self).__init__(
+ model_fn=get_model_fn(
+ params.fill(),
+ graph_builder_class,
+ device_assigner,
+ feature_columns=feature_columns,
+ model_head=head,
+ weights_name=weight_column,
+ keys_name=keys_column,
+ early_stopping_rounds=early_stopping_rounds,
+ early_stopping_loss_threshold=early_stopping_loss_threshold,
+ num_trainers=num_trainers,
+ trainer_id=trainer_id,
+ report_feature_importances=report_feature_importances,
+ local_eval=local_eval,
+ include_all_in_serving=include_all_in_serving,
+ output_type=ModelBuilderOutputType.ESTIMATOR_SPEC),
+ model_dir=model_dir,
+ config=config)
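For readers following the `include_all_in_serving` docstring above, here is a minimal, unverified sketch of the legacy serving flow it describes. `train_input_fn` and `my_serving_input_fn` are hypothetical placeholders; `ALL_SERVING_KEY` is the module-level constant used in this file.

```python
# Hedged sketch of the export flow from the docstring above, not a
# verified end-to-end run. `train_input_fn` / `my_serving_input_fn`
# are assumed user-provided input functions.
from tensorflow.contrib.tensor_forest.client import random_forest
from tensorflow.contrib.tensor_forest.python import tensor_forest

params = tensor_forest.ForestHParams(
    num_classes=3, num_features=4, num_trees=10, max_nodes=1000)
est = random_forest.TensorForestEstimator(
    params.fill(), include_all_in_serving=True)
est.fit(input_fn=train_input_fn, steps=100)
# Serving the variance requires exporting under ALL_SERVING_KEY.
est.export_savedmodel(
    export_dir_base='/tmp/forest_export',
    serving_input_fn=my_serving_input_fn,
    default_output_alternative_key=random_forest.ALL_SERVING_KEY)
```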
diff --git a/tensorflow/contrib/tensor_forest/client/random_forest_test.py b/tensorflow/contrib/tensor_forest/client/random_forest_test.py
index ac42364d25..aa0016b740 100644
--- a/tensorflow/contrib/tensor_forest/client/random_forest_test.py
+++ b/tensorflow/contrib/tensor_forest/client/random_forest_test.py
@@ -23,7 +23,39 @@ import numpy as np
from tensorflow.contrib.learn.python.learn.datasets import base
from tensorflow.contrib.tensor_forest.client import random_forest
from tensorflow.contrib.tensor_forest.python import tensor_forest
+from tensorflow.python.estimator.canned import head as head_lib
+from tensorflow.python.estimator.inputs import numpy_io
+from tensorflow.python.feature_column import feature_column_lib as core_feature_column
+from tensorflow.python.framework import ops
+from tensorflow.python.ops.losses import losses
from tensorflow.python.platform import test
+from tensorflow.python.training import checkpoint_utils
+
+
+def _get_classification_input_fns():
+ iris = base.load_iris()
+ data = iris.data.astype(np.float32)
+ labels = iris.target.astype(np.int32)
+
+ train_input_fn = numpy_io.numpy_input_fn(
+ x=data, y=labels, batch_size=150, num_epochs=None, shuffle=False)
+
+ predict_input_fn = numpy_io.numpy_input_fn(
+ x=data[:1,], y=None, batch_size=1, num_epochs=1, shuffle=False)
+ return train_input_fn, predict_input_fn
+
+
+def _get_regression_input_fns():
+ boston = base.load_boston()
+ data = boston.data.astype(np.float32)
+ labels = boston.target.astype(np.int32)
+
+ train_input_fn = numpy_io.numpy_input_fn(
+ x=data, y=labels, batch_size=506, num_epochs=None, shuffle=False)
+
+ predict_input_fn = numpy_io.numpy_input_fn(
+ x=data[:1,], y=None, batch_size=1, num_epochs=1, shuffle=False)
+ return train_input_fn, predict_input_fn
class TensorForestTrainerTests(test.TestCase):
@@ -39,32 +71,287 @@ class TensorForestTrainerTests(test.TestCase):
inference_tree_paths=True)
classifier = random_forest.TensorForestEstimator(hparams.fill())
+ input_fn, predict_input_fn = _get_classification_input_fns()
+ classifier.fit(input_fn=input_fn, steps=100)
+ res = classifier.evaluate(input_fn=input_fn, steps=10)
+
+ self.assertEqual(1.0, res['accuracy'])
+ self.assertAllClose(0.55144483, res['loss'])
+
+ predictions = list(classifier.predict(input_fn=predict_input_fn))
+ self.assertAllClose([[0.576117, 0.211942, 0.211942]],
+ [pred['probabilities'] for pred in predictions])
+
+ def testRegression(self):
+ """Tests regression using matrix data as input."""
+
+ hparams = tensor_forest.ForestHParams(
+ num_trees=5,
+ max_nodes=1000,
+ num_classes=1,
+ num_features=13,
+ regression=True,
+ split_after_samples=20)
+
+ regressor = random_forest.TensorForestEstimator(hparams.fill())
+
+ input_fn, predict_input_fn = _get_regression_input_fns()
+
+ regressor.fit(input_fn=input_fn, steps=100)
+ res = regressor.evaluate(input_fn=input_fn, steps=10)
+ self.assertGreaterEqual(0.1, res['loss'])
+
+ predictions = list(regressor.predict(input_fn=predict_input_fn))
+ self.assertAllClose([24.], [pred['scores'] for pred in predictions], atol=1)
+
+ def testAdditionalOutputs(self):
+ """Tests multi-class classification using matrix data as input."""
+ hparams = tensor_forest.ForestHParams(
+ num_trees=1,
+ max_nodes=100,
+ num_classes=3,
+ num_features=4,
+ split_after_samples=20,
+ inference_tree_paths=True)
+ classifier = random_forest.TensorForestEstimator(
+ hparams.fill(), keys_column='keys', include_all_in_serving=True)
+
iris = base.load_iris()
data = iris.data.astype(np.float32)
labels = iris.target.astype(np.int32)
- classifier.fit(x=data, y=labels, steps=100, batch_size=50)
- classifier.evaluate(x=data, y=labels, steps=10)
+ input_fn = numpy_io.numpy_input_fn(
+ x={
+ 'x': data,
+ 'keys': np.arange(len(iris.data)).reshape(150, 1)
+ },
+ y=labels,
+ batch_size=10,
+ num_epochs=1,
+ shuffle=False)
- def testRegression(self):
+ classifier.fit(input_fn=input_fn, steps=100)
+ predictions = list(classifier.predict(input_fn=input_fn))
+    # Check that there is a key column, tree paths and variance.
+ for pred in predictions:
+ self.assertTrue('keys' in pred)
+ self.assertTrue('tree_paths' in pred)
+ self.assertTrue('prediction_variance' in pred)
+
+ def _assert_checkpoint(self, model_dir, global_step):
+ reader = checkpoint_utils.load_checkpoint(model_dir)
+ self.assertLessEqual(
+ reader.get_tensor(ops.GraphKeys.GLOBAL_STEP), global_step)
+
+ def testEarlyStopping(self):
"""Tests multi-class classification using matrix data as input."""
+ hparams = tensor_forest.ForestHParams(
+ num_trees=100,
+ max_nodes=10000,
+ num_classes=3,
+ num_features=4,
+ split_after_samples=20,
+ inference_tree_paths=True)
+ classifier = random_forest.TensorForestEstimator(
+ hparams.fill(),
+ # Set a crazy threshold - 30% loss change.
+ early_stopping_loss_threshold=0.3,
+ early_stopping_rounds=2)
+
+ input_fn, _ = _get_classification_input_fns()
+ classifier.fit(input_fn=input_fn, steps=100)
+
+ # We stopped early.
+ self._assert_checkpoint(classifier.model_dir, global_step=5)
+
+
+class CoreTensorForestTests(test.TestCase):
+
+ def testTrainEvaluateInferDoesNotThrowErrorForClassifier(self):
+ head_fn = head_lib._multi_class_head_with_softmax_cross_entropy_loss(
+ n_classes=3, loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS)
hparams = tensor_forest.ForestHParams(
num_trees=3,
max_nodes=1000,
+ num_classes=3,
+ num_features=4,
+ split_after_samples=20,
+ inference_tree_paths=True)
+
+ est = random_forest.CoreTensorForestEstimator(hparams.fill(), head=head_fn)
+
+ input_fn, predict_input_fn = _get_classification_input_fns()
+
+ est.train(input_fn=input_fn, steps=100)
+ res = est.evaluate(input_fn=input_fn, steps=1)
+
+ self.assertEqual(1.0, res['accuracy'])
+ self.assertAllClose(0.55144483, res['loss'])
+
+ predictions = list(est.predict(input_fn=predict_input_fn))
+ self.assertAllClose([[0.576117, 0.211942, 0.211942]],
+ [pred['probabilities'] for pred in predictions])
+
+ def testRegression(self):
+ """Tests regression using matrix data as input."""
+ head_fn = head_lib._regression_head(
+ label_dimension=1,
+ loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS)
+
+ hparams = tensor_forest.ForestHParams(
+ num_trees=5,
+ max_nodes=1000,
num_classes=1,
num_features=13,
regression=True,
split_after_samples=20)
- regressor = random_forest.TensorForestEstimator(hparams.fill())
+ regressor = random_forest.CoreTensorForestEstimator(
+ hparams.fill(), head=head_fn)
+
+ input_fn, predict_input_fn = _get_regression_input_fns()
+
+ regressor.train(input_fn=input_fn, steps=100)
+ res = regressor.evaluate(input_fn=input_fn, steps=10)
+ self.assertGreaterEqual(0.1, res['loss'])
+
+ predictions = list(regressor.predict(input_fn=predict_input_fn))
+ self.assertAllClose(
+ [[24.]], [pred['predictions'] for pred in predictions], atol=1)
+
+ def testWithFeatureColumns(self):
+ head_fn = head_lib._multi_class_head_with_softmax_cross_entropy_loss(
+ n_classes=3, loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS)
+
+ hparams = tensor_forest.ForestHParams(
+ num_trees=3,
+ max_nodes=1000,
+ num_classes=3,
+ num_features=4,
+ split_after_samples=20,
+ inference_tree_paths=True)
+
+ est = random_forest.CoreTensorForestEstimator(
+ hparams.fill(),
+ head=head_fn,
+ feature_columns=[core_feature_column.numeric_column('x')])
+
+ iris = base.load_iris()
+ data = {'x': iris.data.astype(np.float32)}
+ labels = iris.target.astype(np.int32)
+
+ input_fn = numpy_io.numpy_input_fn(
+ x=data, y=labels, batch_size=150, num_epochs=None, shuffle=False)
+
+ est.train(input_fn=input_fn, steps=100)
+ res = est.evaluate(input_fn=input_fn, steps=1)
+
+ self.assertEqual(1.0, res['accuracy'])
+ self.assertAllClose(0.55144483, res['loss'])
+
+ def testAutofillsClassificationHead(self):
+ hparams = tensor_forest.ForestHParams(
+ num_trees=3,
+ max_nodes=1000,
+ num_classes=3,
+ num_features=4,
+ split_after_samples=20,
+ inference_tree_paths=True)
+
+ est = random_forest.CoreTensorForestEstimator(hparams.fill())
+
+ input_fn, _ = _get_classification_input_fns()
+
+ est.train(input_fn=input_fn, steps=100)
+ res = est.evaluate(input_fn=input_fn, steps=1)
+
+ self.assertEqual(1.0, res['accuracy'])
+ self.assertAllClose(0.55144483, res['loss'])
+
+ def testAutofillsRegressionHead(self):
+ hparams = tensor_forest.ForestHParams(
+ num_trees=5,
+ max_nodes=1000,
+ num_classes=1,
+ num_features=13,
+ regression=True,
+ split_after_samples=20)
+
+ regressor = random_forest.CoreTensorForestEstimator(hparams.fill())
+
+ input_fn, predict_input_fn = _get_regression_input_fns()
+
+ regressor.train(input_fn=input_fn, steps=100)
+ res = regressor.evaluate(input_fn=input_fn, steps=10)
+ self.assertGreaterEqual(0.1, res['loss'])
+
+ predictions = list(regressor.predict(input_fn=predict_input_fn))
+ self.assertAllClose(
+ [[24.]], [pred['predictions'] for pred in predictions], atol=1)
+
+ def testAdditionalOutputs(self):
+ """Tests multi-class classification using matrix data as input."""
+ hparams = tensor_forest.ForestHParams(
+ num_trees=1,
+ max_nodes=100,
+ num_classes=3,
+ num_features=4,
+ split_after_samples=20,
+ inference_tree_paths=True)
+ classifier = random_forest.CoreTensorForestEstimator(
+ hparams.fill(), keys_column='keys', include_all_in_serving=True)
+
+ iris = base.load_iris()
+ data = iris.data.astype(np.float32)
+ labels = iris.target.astype(np.int32)
+
+ input_fn = numpy_io.numpy_input_fn(
+ x={
+ 'x': data,
+ 'keys': np.arange(len(iris.data)).reshape(150, 1)
+ },
+ y=labels,
+ batch_size=10,
+ num_epochs=1,
+ shuffle=False)
+
+ classifier.train(input_fn=input_fn, steps=100)
+ predictions = list(classifier.predict(input_fn=input_fn))
+    # Check that there is a key column, tree paths and variance.
+ for pred in predictions:
+ self.assertTrue('keys' in pred)
+ self.assertTrue('tree_paths' in pred)
+ self.assertTrue('prediction_variance' in pred)
+
+ def _assert_checkpoint(self, model_dir, global_step):
+ reader = checkpoint_utils.load_checkpoint(model_dir)
+ self.assertLessEqual(
+ reader.get_tensor(ops.GraphKeys.GLOBAL_STEP), global_step)
+
+ def testEarlyStopping(self):
+ head_fn = head_lib._multi_class_head_with_softmax_cross_entropy_loss(
+ n_classes=3, loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS)
+
+ hparams = tensor_forest.ForestHParams(
+ num_trees=3,
+ max_nodes=1000,
+ num_classes=3,
+ num_features=4,
+ split_after_samples=20,
+ inference_tree_paths=True)
- boston = base.load_boston()
- data = boston.data.astype(np.float32)
- labels = boston.target.astype(np.int32)
+ est = random_forest.CoreTensorForestEstimator(
+ hparams.fill(),
+ head=head_fn,
+ # Set a crazy threshold - 30% loss change.
+ early_stopping_loss_threshold=0.3,
+ early_stopping_rounds=2)
- regressor.fit(x=data, y=labels, steps=100, batch_size=50)
- regressor.evaluate(x=data, y=labels, steps=10)
+ input_fn, _ = _get_classification_input_fns()
+ est.train(input_fn=input_fn, steps=100)
+ # We stopped early.
+ self._assert_checkpoint(est.model_dir, global_step=8)
if __name__ == "__main__":
diff --git a/tensorflow/contrib/tpu/BUILD b/tensorflow/contrib/tpu/BUILD
index 2abf402e6c..56e451e2e3 100644
--- a/tensorflow/contrib/tpu/BUILD
+++ b/tensorflow/contrib/tpu/BUILD
@@ -265,7 +265,6 @@ tf_py_test(
":datasets",
],
grpc_enabled = True,
- tags = ["no_windows"],
)
tf_py_test(
diff --git a/tensorflow/contrib/tpu/profiler/pip_package/setup.py b/tensorflow/contrib/tpu/profiler/pip_package/setup.py
index 19f088f8b8..d4ccb0f246 100644
--- a/tensorflow/contrib/tpu/profiler/pip_package/setup.py
+++ b/tensorflow/contrib/tpu/profiler/pip_package/setup.py
@@ -20,7 +20,7 @@ from __future__ import print_function
from setuptools import setup
-_VERSION = '1.9.0'
+_VERSION = '1.10.0'
CONSOLE_SCRIPTS = [
'capture_tpu_profile=cloud_tpu_profiler.main:run_main',
diff --git a/tensorflow/contrib/tpu/profiler/version.h b/tensorflow/contrib/tpu/profiler/version.h
index 1bf49966d1..aee094177b 100644
--- a/tensorflow/contrib/tpu/profiler/version.h
+++ b/tensorflow/contrib/tpu/profiler/version.h
@@ -16,6 +16,6 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_TPU_PROFILER_VERSION_H_
#define TENSORFLOW_CONTRIB_TPU_PROFILER_VERSION_H_
-#define TPU_PROFILER_VERSION "1.9.0"
+#define TPU_PROFILER_VERSION "1.10.0"
#endif // TENSORFLOW_CONTRIB_TPU_PROFILER_VERSION_H_
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
index 029492b489..f221155568 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
@@ -45,6 +45,7 @@ from tensorflow.core.framework import variable_pb2
from tensorflow.core.framework.summary_pb2 import Summary
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.util import nest as data_nest
from tensorflow.python.estimator import estimator as estimator_lib
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.estimator import util as estimator_util
@@ -204,6 +205,12 @@ def _increase_eval_step_op(iterations_per_loop):
use_locking=True)
+def _extract_key_names(tensor_or_dict):
+ if isinstance(tensor_or_dict, dict):
+ return sorted(tensor_or_dict.keys())
+ return []
+
+
class _SIGNAL(object):
"""Signal used to control the thread of infeed/outfeed.
@@ -224,7 +231,7 @@ class TPUEstimatorSpec(model_fn_lib._TPUEstimatorSpec): # pylint: disable=prote
`metric_fn` runs on CPU to generate metrics and `tensors` represents the
`Tensor`s transferred from TPU system to CPU host and passed to `metric_fn`.
To be precise, TPU evaluation expects a slightly different signature from the
- `tf.estimator.Estimator`. While `EstimatorSpec.eval_metric_ops` expects a
+ @{tf.estimator.Estimator}. While `EstimatorSpec.eval_metric_ops` expects a
dict, `TPUEstimatorSpec.eval_metrics` is a tuple of `metric_fn` and `tensors`.
The `tensors` could be a list of `Tensor`s or dict of names to `Tensor`s. The
`tensors` usually specify the model logits, which are transferred back from
@@ -247,7 +254,7 @@ class TPUEstimatorSpec(model_fn_lib._TPUEstimatorSpec): # pylint: disable=prote
sending tensors from TPU to CPU. To reduce the overhead, try reducing the
size of the tensors. The `tensors` are concatenated along their major (batch)
dimension, and so must be >= rank 1. The `host_call` is useful for writing
- summaries with `tf.contrib.summary.create_file_writer`.
+ summaries with @{tf.contrib.summary.create_file_writer}.
"""
def __new__(cls,
@@ -711,8 +718,7 @@ def generate_per_host_enqueue_ops_fn_for_host(
features, labels = inputs.features_and_labels()
signals = inputs.signals()
- inputs_structure_recorder.validate_and_record_structure(
- features, labels, signals)
+ inputs_structure_recorder.validate_and_record_structure(features, labels)
unsharded_tensor_list = (
inputs_structure_recorder.flatten_features_and_labels(
features, labels, signals))
@@ -859,7 +865,7 @@ def generate_broadcast_enqueue_ops_fn(ctx, input_fn, inputs_structure_recorder,
signals = inputs.signals()
inputs_structure_recorder.validate_and_record_structure(
- features, labels, signals)
+ features, labels)
flattened_inputs = (
inputs_structure_recorder.flatten_features_and_labels(
features, labels, signals))
@@ -901,17 +907,19 @@ class _InputPipeline(object):
inputs returned by the `input_fn` can have one of the following forms:
1. features
2. (features, labels)
+ 3. ((arbitrarily nested structure of features), labels)
Internally, form 1 is reformed to `(features, None)` as features and labels
are passed separately to underlying methods. For TPU training, TPUEstimator
may expect multiple `features` and `labels` tuples one for each core.
TPUEstimator allows various different structures for inputs (namely `features`
- and `labels`). `features` can be `Tensor` or dict of string name to `Tensor`,
- and `labels` could be `None`, `Tensor`, or dict of string name to `Tensor`.
- TPU infeed/outfeed library expects flattened tensor list. So, `features` and
- `labels` need to be flattened, before infeed enqueue, and the structure of
- them needs to be recorded, in order to restore them after infeed dequeue.
+  and `labels`). `features` can be a `Tensor`, a dict of string name to
+  `Tensor`, or nested tuples, and `labels` could be `None`, a `Tensor`, or
+  a dict of string name to `Tensor`. The TPU infeed/outfeed library expects
+  a flattened tensor list, so `features` and `labels` need to be flattened
+  before infeed enqueue, and their structure needs to be recorded so that it
+  can be restored after infeed dequeue.
"""
class InputsStructureRecorder(object):
@@ -919,10 +927,7 @@ class _InputPipeline(object):
def __init__(self, input_partition_dims=None):
# Holds the structure of inputs
- self._feature_names = []
- self._label_names = []
- self._has_labels = False
- self._signals_helper = None
+ self._feature_structure = {}
self._flattened_input_dims = None
if input_partition_dims:
@@ -949,7 +954,7 @@ class _InputPipeline(object):
return self._flattened_input_dims
def has_labels(self):
- return self._has_labels
+ return 'labels' in self._feature_structure
def _flatten_input_dims(self, feature_dims, feature_dims_names, label_dims,
label_dims_names, label_names, has_labels):
@@ -977,35 +982,16 @@ class _InputPipeline(object):
return flattened_input_dims
- def validate_and_record_structure(self, features, labels, signals=None):
+ def validate_and_record_structure(self, features, labels):
"""Validates and records the structure of `features` and `labels`."""
-
- def _extract_key_names(tensor_or_dict):
- if tensor_or_dict is None:
- return []
- return sorted(tensor_or_dict.keys()) if isinstance(
- tensor_or_dict, dict) else []
-
# Extract structure.
has_labels = labels is not None
feature_names = _extract_key_names(features)
label_names = _extract_key_names(labels)
- if signals is not None and self._signals_helper is None:
- # Record signals helper.
- self._signals_helper = _SignalsHelper(signals)
-
- if self._initialized:
- # Verify the structure is same. The following should never happen.
- assert feature_names == self._feature_names, 'feature keys mismatched'
- assert label_names == self._label_names, 'label keys mismatched'
- assert has_labels == self._has_labels, 'label presence mismatched'
- else:
+ if not self._initialized:
# Record structure.
self._initialized = True
- self._feature_names = feature_names
- self._label_names = label_names
- self._has_labels = has_labels
if self._feature_dims is not None:
feature_dims_names = _extract_key_names(self._feature_dims)
if feature_dims_names != feature_names:
@@ -1027,24 +1013,12 @@ class _InputPipeline(object):
def flatten_features_and_labels(self, features, labels, signals=None):
"""Flattens the `features` and `labels` to a single tensor list."""
- flattened_inputs = []
- if self._feature_names:
- # We need a fixed ordering for enqueueing and dequeueing.
- flattened_inputs.extend(
- [features[name] for name in self._feature_names])
- else:
- flattened_inputs.append(features)
-
+ self._feature_structure['features'] = features
if labels is not None:
- if self._label_names:
- # We need a fixed ordering for enqueueing and dequeueing.
- flattened_inputs.extend([labels[name] for name in self._label_names])
- else:
- flattened_inputs.append(labels)
-
+ self._feature_structure['labels'] = labels
if signals is not None:
- flattened_inputs.extend(_SignalsHelper.as_tensor_list(signals))
- return flattened_inputs
+ self._feature_structure['signals'] = signals
+ return data_nest.flatten(self._feature_structure)
def unflatten_features_and_labels(self, flattened_inputs):
"""Restores the flattened inputs to original features and labels form.
@@ -1061,49 +1035,13 @@ class _InputPipeline(object):
ValueError: If the number of expected tensors from `flattened_inputs`
mismatches the recorded structure.
"""
- expected_num_features = (
- len(self._feature_names) if self._feature_names else 1)
- if self._has_labels:
- expected_num_labels = (
- len(self._label_names) if self._label_names else 1)
- else:
- expected_num_labels = 0
- expected_num_signals = (
- self._signals_helper.num_signals if self._signals_helper else 0)
-
- expected_num_tensors = (
- expected_num_features + expected_num_labels + expected_num_signals)
-
- if expected_num_tensors != len(flattened_inputs):
- raise ValueError(
- 'The number of flattened tensors mismatches expected num. '
- 'Expected {}, got {}'.format(expected_num_tensors,
- len(flattened_inputs)))
- if self._feature_names:
- unflattened_features = dict(
- zip(self._feature_names, flattened_inputs[:expected_num_features]))
- else:
- # Single tensor case
- unflattened_features = flattened_inputs[0]
-
- if expected_num_labels == 0:
- unflattened_label = None
- elif self._label_names:
- label_list = flattened_inputs[
- expected_num_features:expected_num_features + expected_num_labels]
- unflattened_label = dict(zip(self._label_names, label_list))
- else:
- # Single tensor case.
- unflattened_label = flattened_inputs[expected_num_features]
-
- signals = None
- if expected_num_signals != 0:
- tensor_list_for_signals = flattened_inputs[
- expected_num_features + expected_num_labels:]
- signals = self._signals_helper.unflatten(tensor_list_for_signals)
-
- return _Inputs(unflattened_features, unflattened_label, signals=signals)
+ unflattened_inputs = data_nest.pack_sequence_as(self._feature_structure,
+ flattened_inputs)
+ return _Inputs(
+ unflattened_inputs['features'],
+ unflattened_inputs.get('labels'),
+ signals=unflattened_inputs.get('signals'))
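The refactor above replaces the manual per-key index arithmetic with `data_nest`; a minimal sketch of the flatten/pack round trip it relies on (the structure and leaf values here are placeholders):

```python
# Hedged sketch: nest flattens dicts in a deterministic (sorted-key)
# order, and pack_sequence_as restores the original structure, which is
# what lets the recorder drop its index bookkeeping.
import numpy as np
from tensorflow.python.data.util import nest as data_nest

structure = {'features': {'keys': np.zeros(3), 'x': np.ones(3)},
             'labels': np.arange(3)}
flat = data_nest.flatten(structure)
restored = data_nest.pack_sequence_as(structure, flat)
assert sorted(restored['features']) == ['keys', 'x']
```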
def __init__(self, input_fn, batch_axis, ctx):
"""Constructor.
@@ -1505,12 +1443,14 @@ class _ModelFnWrapper(object):
'The {} to the model returned by input_fn must have static shape.'
' Tensor: {}'.format(obj_name, obj))
else:
- for (key, tensor) in obj.items():
- if not tensor.get_shape().is_fully_defined():
- raise ValueError(
- 'The {} to the model returned by input_fn must have static '
- 'shape. Key: \'{}\', Tensor: {}'.format(
- obj_name, key, tensor))
+ for (key, value) in obj.items():
+ flattened_tensors = data_nest.flatten(value)
+ for tensor in flattened_tensors:
+ if not tensor.get_shape().is_fully_defined():
+ raise ValueError(
+ 'The {} to the model returned by input_fn must have static '
+ 'shape. Key: \'{}\', Tensor: {}'.format(
+ obj_name, key, tensor))
validate(features, 'features')
if labels is not None:
@@ -3338,26 +3278,6 @@ class _PaddingSignals(object):
return padding_mask
-class _SignalsHelper(object):
- """A general helper class to handle common signals manipulation."""
-
- def __init__(self, signals):
- self._signal_keys = []
- for key in sorted(iter(signals.keys())):
- self._signal_keys.append(key)
-
- @property
- def num_signals(self):
- return len(self._signal_keys)
-
- def unflatten(self, tensor_list):
- return dict(zip(self._signal_keys, tensor_list))
-
- @staticmethod
- def as_tensor_list(signals):
- return [signals[key] for key in sorted(iter(signals.keys()))]
-
-
def _verify_cross_hosts_transfer_size(tensor_dict, message):
total_size = 0
tensor_structure = {}
diff --git a/tensorflow/contrib/training/python/training/training.py b/tensorflow/contrib/training/python/training/training.py
index f72e0a3f83..c272a2ac14 100644
--- a/tensorflow/contrib/training/python/training/training.py
+++ b/tensorflow/contrib/training/python/training/training.py
@@ -484,7 +484,8 @@ def train(train_op,
save_checkpoint_secs=600,
save_summaries_steps=100,
config=None,
- max_wait_secs=7200):
+ max_wait_secs=7200,
+ run_metadata=None):
"""Runs the training loop.
Args:
@@ -511,6 +512,7 @@ def train(train_op,
become available. This should be kept relatively short to help detect
incorrect code, but sometimes may need to be increased if the chief takes
a while to start up.
+      run_metadata: A `RunMetadata` protocol buffer.
Returns:
the value of the loss function after training.
@@ -541,5 +543,5 @@ def train(train_op,
max_wait_secs=max_wait_secs) as session:
loss = None
while not session.should_stop():
- loss = session.run(train_op)
+ loss = session.run(train_op, run_metadata=run_metadata)
return loss
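A minimal sketch of the new `run_metadata` argument; the training op and log directory are placeholders. Note that without a tracing `RunOptions` (which this `train` signature does not expose), the proto captures only what `session.run` records by default.

```python
# Hedged sketch: pass a RunMetadata proto so each training step's
# session.run call can record metadata into it.
import tensorflow as tf

run_metadata = tf.RunMetadata()
loss = tf.contrib.training.train(
    train_op,                  # assumed: an existing training op
    logdir='/tmp/train_logs',
    run_metadata=run_metadata)
```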
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 82443fd7e8..9a8c20b1fd 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -149,6 +149,7 @@ load("@io_bazel_rules_closure//closure:defs.bzl", "closure_proto_library")
load(
"//third_party/mkl:build_defs.bzl",
"if_mkl",
+ "mkl_deps",
)
exports_files(["ops/ops.pbtxt"])
@@ -735,7 +736,10 @@ cc_library(
"util/reporter.h",
],
copts = tf_copts(),
- linkopts = ["-lm"],
+ linkopts = select({
+ "//tensorflow:windows": [],
+ "//conditions:default": ["-lm"],
+ }),
visibility = ["//visibility:public"],
deps = [
":lib",
@@ -860,7 +864,6 @@ tf_cuda_library(
"util/work_sharder.h",
] + select({
"//tensorflow:windows": [],
- "//tensorflow:windows_msvc": [],
"//conditions:default": [
"util/memmapped_file_system.h",
"util/memmapped_file_system_writer.h",
@@ -2036,7 +2039,6 @@ cc_library(
linkopts = select({
"//tensorflow:freebsd": [],
"//tensorflow:windows": [],
- "//tensorflow:windows_msvc": [],
"//tensorflow:android": [],
"//conditions:default": [
"-ldl",
@@ -2126,7 +2128,6 @@ cc_library(
linkopts = select({
"//tensorflow:freebsd": [],
"//tensorflow:windows": [],
- "//tensorflow:windows_msvc": [],
"//conditions:default": ["-ldl"],
}),
deps = [
@@ -2151,7 +2152,6 @@ cc_library(
linkopts = select({
"//tensorflow:freebsd": [],
"//tensorflow:windows": [],
- "//tensorflow:windows_msvc": [],
"//conditions:default": ["-ldl"],
}),
deps = [
@@ -2183,7 +2183,6 @@ cc_library(
linkopts = select({
"//tensorflow:freebsd": [],
"//tensorflow:windows": [],
- "//tensorflow:windows_msvc": [],
"//conditions:default": ["-ldl"],
}),
deps = [
@@ -2489,7 +2488,6 @@ tf_cuda_library(
],
) + select({
"//tensorflow:windows": [],
- "//tensorflow:windows_msvc": [],
"//conditions:default": [
"util/memmapped_file_system.cc",
"util/memmapped_file_system_writer.cc",
@@ -2498,13 +2496,13 @@ tf_cuda_library(
hdrs = FRAMEWORK_INTERNAL_PUBLIC_HEADERS,
copts = tf_copts(),
linkopts = select({
- "//tensorflow:freebsd": [],
+ "//tensorflow:freebsd": ["-lm"],
"//tensorflow:windows": [],
- "//tensorflow:windows_msvc": [],
- "//conditions:default": ["-ldl"],
- }) + [
- "-lm",
- ],
+ "//conditions:default": [
+ "-ldl",
+ "-lm",
+ ],
+ }),
deps = [
":lib",
":lib_internal",
@@ -2519,12 +2517,7 @@ tf_cuda_library(
] + if_static(
extra_deps = ["@protobuf_archive//:protobuf"],
otherwise = ["@protobuf_archive//:protobuf_headers"],
- ) + if_mkl(
- [
- "//third_party/mkl:intel_binary_blob",
- "@mkl_dnn",
- ],
- ),
+ ) + mkl_deps(),
alwayslink = 1,
)
@@ -2806,12 +2799,7 @@ tf_cuda_library(
":protos_all_cc",
"//third_party/eigen3",
"//tensorflow/core/grappler:grappler_item",
- ] + if_mkl(
- [
- "//third_party/mkl:intel_binary_blob",
- "@mkl_dnn",
- ],
- ),
+ ] + mkl_deps(),
alwayslink = 1,
)
@@ -2851,12 +2839,7 @@ tf_cuda_library(
"//tensorflow/core/grappler/optimizers:meta_optimizer",
"//third_party/eigen3",
"//tensorflow/core/kernels:required",
- ] + if_mkl(
- [
- "//third_party/mkl:intel_binary_blob",
- "@mkl_dnn",
- ],
- ) + tf_additional_core_deps() + if_static([":core_cpu_impl"]),
+ ] + mkl_deps() + tf_additional_core_deps() + if_static([":core_cpu_impl"]),
alwayslink = 1,
)
@@ -3149,7 +3132,10 @@ cc_library(
testonly = 1,
srcs = ["platform/test_main.cc"],
copts = tf_copts(),
- linkopts = ["-lm"],
+ linkopts = select({
+ "//tensorflow:windows": [],
+ "//conditions:default": ["-lm"],
+ }),
visibility = ["//tensorflow:internal"],
deps = [
":lib",
@@ -3860,11 +3846,7 @@ tf_cuda_only_cc_test(
":test",
":test_main",
"//third_party/eigen3",
- ] + if_mkl(
- [
- "//third_party/mkl:intel_binary_blob",
- ],
- ),
+ ] + mkl_deps(),
)
tf_cc_test_gpu(
diff --git a/tensorflow/core/api_def/api_test.cc b/tensorflow/core/api_def/api_test.cc
index ae03a61ae6..51812caeb2 100644
--- a/tensorflow/core/api_def/api_test.cc
+++ b/tensorflow/core/api_def/api_test.cc
@@ -59,8 +59,8 @@ void GetGoldenApiDefs(Env* env, const string& api_files_dir,
file_contents = PBTxtFromMultiline(file_contents);
ApiDefs api_defs;
- CHECK(tensorflow::protobuf::TextFormat::ParseFromString(file_contents,
- &api_defs))
+ QCHECK(tensorflow::protobuf::TextFormat::ParseFromString(file_contents,
+ &api_defs))
<< "Failed to load " << file_path;
CHECK_EQ(api_defs.op_size(), 1);
(*name_to_api_def)[api_defs.op(0).graph_op_name()] = api_defs.op(0);
diff --git a/tensorflow/core/api_def/base_api/api_def_GatherNd.pbtxt b/tensorflow/core/api_def/base_api/api_def_GatherNd.pbtxt
index a0e42dd02c..9f3f9b276b 100644
--- a/tensorflow/core/api_def/base_api/api_def_GatherNd.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_GatherNd.pbtxt
@@ -123,5 +123,7 @@ Batched indexing into a 3-tensor:
[['a1', 'b1'], ['c1', 'd1']]]
output = [['b0', 'b1'], ['d0', 'c1']]
```
+
+See also `tf.gather` and `tf.batch_gather`.
END
}
diff --git a/tensorflow/core/api_def/base_api/api_def_GatherV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_GatherV2.pbtxt
index 162ef2b033..c6104da4a6 100644
--- a/tensorflow/core/api_def/base_api/api_def_GatherV2.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_GatherV2.pbtxt
@@ -54,5 +54,7 @@ params.shape[axis + 1:]` where:
Note that on CPU, if an out of bound index is found, an error is returned.
On GPU, if an out of bound index is found, a 0 is stored in the
corresponding output value.
+
+See also `tf.batch_gather` and `tf.gather_nd`.
END
}
diff --git a/tensorflow/core/api_def/base_api/api_def_HostConst.pbtxt b/tensorflow/core/api_def/base_api/api_def_HostConst.pbtxt
new file mode 100644
index 0000000000..9d04a01f6f
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_HostConst.pbtxt
@@ -0,0 +1,11 @@
+op {
+ graph_op_name: "HostConst"
+ attr {
+ name: "value"
+ description: <<END
+Attr `value` is the tensor to return.
+END
+ }
+ visibility: SKIP
+ summary: "Returns a constant tensor on the host. Only for writing C++ tests."
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_ScatterSub.pbtxt b/tensorflow/core/api_def/python_api/api_def_ScatterSub.pbtxt
new file mode 100644
index 0000000000..f1a4cccbc3
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_ScatterSub.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "ScatterSub"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/common_runtime/collective_rma_local.h b/tensorflow/core/common_runtime/collective_rma_local.h
index dbb2e67c7d..44408438b9 100644
--- a/tensorflow/core/common_runtime/collective_rma_local.h
+++ b/tensorflow/core/common_runtime/collective_rma_local.h
@@ -34,7 +34,7 @@ class CollectiveRemoteAccessLocal : public PerStepCollectiveRemoteAccess {
virtual ~CollectiveRemoteAccessLocal() {}
- void StartAbort(const Status& s);
+ void StartAbort(const Status& s) override;
void RecvFromPeer(const string& peer_device, const string& peer_task,
bool peer_is_local, const string& key, Device* to_device,
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc
index 0695278c0d..bf1d78ec65 100644
--- a/tensorflow/core/common_runtime/direct_session.cc
+++ b/tensorflow/core/common_runtime/direct_session.cc
@@ -602,7 +602,7 @@ Status DirectSession::RunInternal(int64 step_id, const RunOptions& run_options,
if (tracer) {
TF_RETURN_IF_ERROR(tracer->Stop());
- TF_RETURN_IF_ERROR(tracer->Collect(args.stats_collector));
+ TF_RETURN_IF_ERROR(tracer->Collect(run_state.collector.get()));
}
{
@@ -618,8 +618,8 @@ Status DirectSession::RunInternal(int64 step_id, const RunOptions& run_options,
&session_state_));
}
- if (args.stats_collector) {
- args.stats_collector->Finalize();
+ if (run_state.collector) {
+ run_state.collector->Finalize();
}
// Build and return the cost model as instructed.
@@ -634,7 +634,7 @@ Status DirectSession::RunInternal(int64 step_id, const RunOptions& run_options,
}
mutex_lock l(executor_lock_);
- args.stats_collector->BuildCostModel(&cost_model_manager_, device_to_graph);
+ run_state.collector->BuildCostModel(&cost_model_manager_, device_to_graph);
// annotate stats onto cost graph.
CostGraphDef* cost_graph = run_metadata->mutable_cost_graph();
diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h
index 21c5bdf8e9..9835b19511 100644
--- a/tensorflow/core/common_runtime/eager/context.h
+++ b/tensorflow/core/common_runtime/eager/context.h
@@ -206,6 +206,7 @@ class EagerContext {
// Only one of the below is set.
std::unique_ptr<DeviceMgr> local_device_manager_;
DeviceMgr* local_unowned_device_manager_;
+ std::unique_ptr<DeviceMgr> remote_device_manager_;
// Devices owned by device_manager
std::vector<Device*> devices_;
@@ -253,7 +254,6 @@ class EagerContext {
#ifndef __ANDROID__
void CloseRemoteContexts();
- std::unique_ptr<DeviceMgr> remote_device_manager_;
// The server_ is not const since we release it when the context is destroyed.
// Therefore the server_ object is not marked as const (even though it should
diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc
index c2fac4c2c8..951bc4197e 100644
--- a/tensorflow/core/common_runtime/executor.cc
+++ b/tensorflow/core/common_runtime/executor.cc
@@ -1319,7 +1319,7 @@ class ExecutorState {
TensorStore* tensor_store_;
// Step-local container.
ScopedStepContainer* step_container_;
- StepStatsCollector* stats_collector_;
+ StepStatsCollectorInterface* const stats_collector_;
// QUESTION: Make it a checkpoint::TensorSliceReaderCacheWrapper
// instead of a pointer? (avoids having to delete).
checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache_;
diff --git a/tensorflow/core/common_runtime/executor.h b/tensorflow/core/common_runtime/executor.h
index cd01b43aea..a238a6763a 100644
--- a/tensorflow/core/common_runtime/executor.h
+++ b/tensorflow/core/common_runtime/executor.h
@@ -83,7 +83,7 @@ class Executor {
struct Args {
int64 step_id = 0;
Rendezvous* rendezvous = nullptr;
- StepStatsCollector* stats_collector = nullptr;
+ StepStatsCollectorInterface* stats_collector = nullptr;
CallFrameInterface* call_frame = nullptr;
CancellationManager* cancellation_manager = nullptr;
SessionState* session_state = nullptr;
diff --git a/tensorflow/core/common_runtime/step_stats_collector.h b/tensorflow/core/common_runtime/step_stats_collector.h
index 996dbb59bc..0394f25839 100644
--- a/tensorflow/core/common_runtime/step_stats_collector.h
+++ b/tensorflow/core/common_runtime/step_stats_collector.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_STEP_STATS_COLLECTOR_H_
-#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_STEP_STATS_COLLECTOR_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_STEP_STATS_COLLECTOR_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_STEP_STATS_COLLECTOR_H_
#include <memory>
#include <unordered_map>
@@ -62,10 +62,29 @@ class NodeExecStatsWrapper {
std::unique_ptr<NodeExecStats> stats_;
};
+// Statistics collection interface for individual node execution.
+//
+// See `StepStatsCollector` for a concrete implementation of this interface
+// that interfaces with the `Session` layer.
+class StepStatsCollectorInterface {
+ public:
+ virtual ~StepStatsCollectorInterface() {}
+
+ // Saves `stats` to the collector.
+ virtual void Save(const string& device, NodeExecStatsWrapper* stats) = 0;
+
+ // Generates a string reporting the currently used memory based
+ // on ResourceExhausted OOM `err` message.
+ // `err` message needs to contain device name and allocator name, e.g.:
+ // "ResourceExhaustedError: OOM when allocating tensor ...
+ // on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc"
+ virtual string ReportAllocsOnResourceExhausted(const string& err) = 0;
+};
+
// StepStatsCollector manages the collection of a StepStats object.
// The StepStats object holds multiple DeviceStats.
// Each DeviceStats object holds multiple NodeExecStats.
-class StepStatsCollector {
+class StepStatsCollector : public StepStatsCollectorInterface {
public:
// Does not take ownership of `ss`.
explicit StepStatsCollector(StepStats* ss);
@@ -80,14 +99,9 @@ class StepStatsCollector {
// Save saves nt to the DeviceStats object associated with device.
// Should be called before Finalize.
void Save(const string& device, NodeExecStats* nt);
- void Save(const string& device, NodeExecStatsWrapper* stats);
+ void Save(const string& device, NodeExecStatsWrapper* stats) override;
- // Generates a string reporting the currently used memory based
- // on ResourceExhausted OOM `err` message.
- // `err` message needs to contain device name and allocator name, E.g.:
- // "ResourceExhaustedError: OOM when allocating tensor ...
- // on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc"
- string ReportAllocsOnResourceExhausted(const string& err);
+ string ReportAllocsOnResourceExhausted(const string& err) override;
// The following 2 Finalize methods populate the StepStats passed
// from the constructor. Calling it more than once won't have any effect.
@@ -112,4 +126,4 @@ class StepStatsCollector {
} // namespace tensorflow
-#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_STEP_STATS_COLLECTOR_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_STEP_STATS_COLLECTOR_H_
diff --git a/tensorflow/core/framework/dataset.cc b/tensorflow/core/framework/dataset.cc
index e886ef7b8e..f3c7189292 100644
--- a/tensorflow/core/framework/dataset.cc
+++ b/tensorflow/core/framework/dataset.cc
@@ -74,18 +74,18 @@ class DatasetVariantWrapper {
} // namespace
Status GraphDefBuilderWrapper::AddDataset(
- const GraphDatasetBase* dataset,
+ const DatasetBase* dataset,
const std::vector<std::pair<size_t, Node*>>& inputs,
const std::vector<std::pair<size_t, gtl::ArraySlice<Node*>>>& list_inputs,
const std::vector<std::pair<StringPiece, AttrValue>>& attrs,
Node** output) {
- const string& op_type_name = dataset->op_name();
+ const string& name = dataset->name();
std::unique_ptr<const GraphDefBuilder::Options> opts(
new GraphDefBuilder::Options(b_->opts()));
// TODO(srbs|mrry): Not all datasets have output_types and output_shapes
// attributes defined. It will be nice to have a consistent pattern.
- bool has_output_types_attr = HasAttr(op_type_name, "output_types");
- bool has_output_shapes_attr = HasAttr(op_type_name, "output_shapes");
+ bool has_output_types_attr = HasAttr(name, "output_types");
+ bool has_output_shapes_attr = HasAttr(name, "output_shapes");
if (has_output_shapes_attr) {
opts.reset(new GraphDefBuilder::Options(
opts->WithAttr("output_shapes", dataset->output_shapes())));
@@ -102,8 +102,7 @@ Status GraphDefBuilderWrapper::AddDataset(
return errors::Internal("AddDataset: Failed to build Options with error ",
opts->StatusToString());
}
- NodeBuilder node_builder(opts->GetNameForOp(op_type_name), op_type_name,
- opts->op_registry());
+ NodeBuilder node_builder(opts->GetNameForOp(name), name, opts->op_registry());
{
size_t total_size = inputs.size() + list_inputs.size();
auto inputs_iter = inputs.begin();
@@ -128,7 +127,7 @@ Status GraphDefBuilderWrapper::AddDataset(
}
*output = opts->FinalizeBuilder(&node_builder);
if (*output == nullptr) {
- return errors::Internal("AddDataset: Failed to build ", op_type_name,
+ return errors::Internal("AddDataset: Failed to build ", name,
" op with error ", opts->StatusToString());
}
return Status::OK();
@@ -184,27 +183,32 @@ void GraphDefBuilderWrapper::AddTensorInternal(const Tensor& val,
b_->opts().WithAttr("dtype", val.dtype()).WithAttr("value", val));
}
-bool GraphDefBuilderWrapper::HasAttr(const string& op_type_name,
+bool GraphDefBuilderWrapper::HasAttr(const string& name,
const string& attr_name) const {
const OpDef* op_def = nullptr;
- Status s = b_->opts().op_registry()->LookUpOpDef(op_type_name, &op_def);
+ Status s = b_->opts().op_registry()->LookUpOpDef(name, &op_def);
if (!s.ok() || op_def == nullptr) {
return false;
}
return HasAttr(op_def, attr_name);
}
-Status GraphDatasetBase::Serialize(SerializationContext* ctx,
- string* serialized_graph_def,
- string* output_node) const {
+Status DatasetBase::Save(SerializationContext* ctx,
+ IteratorStateWriter* writer) const {
+ string serialized_graph_def;
+ string output_node;
GraphDefBuilder b;
DatasetGraphDefBuilder db(&b);
Node* node = nullptr;
TF_RETURN_IF_ERROR(AsGraphDefInternal(ctx, &db, &node));
- *output_node = node->name();
+ output_node = node->name();
GraphDef graph_def;
TF_RETURN_IF_ERROR(b.ToGraphDef(&graph_def));
- graph_def.SerializeToString(serialized_graph_def);
+ graph_def.SerializeToString(&serialized_graph_def);
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(kDatasetGraphKey, serialized_graph_def));
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(kDatasetGraphOutputNodeKey, output_node));
return Status::OK();
}
@@ -264,8 +268,8 @@ void BinaryDatasetOpKernel::MakeDataset(OpKernelContext* ctx,
MakeDataset(ctx, input, another_input, output);
}
-const char GraphDatasetBase::kDatasetGraphKey[] = "_DATASET_GRAPH";
-const char GraphDatasetBase::kDatasetGraphOutputNodeKey[] =
+const char DatasetBase::kDatasetGraphKey[] = "_DATASET_GRAPH";
+const char DatasetBase::kDatasetGraphOutputNodeKey[] =
"_DATASET_GRAPH_OUTPUT_NODE";
BackgroundWorker::BackgroundWorker(Env* env, const string& name) {
@@ -315,22 +319,4 @@ void BackgroundWorker::WorkerLoop() {
}
}
-namespace dataset {
-
-IteratorContext MakeIteratorContext(OpKernelContext* ctx) {
- IteratorContext::Params params;
- params.env = ctx->env();
- params.runner = *(ctx->runner());
- params.lib = ctx->function_library();
- // Note: must use reinterpret_cast because function.h forward-declares Device.
- DeviceBase* device =
- reinterpret_cast<DeviceBase*>(ctx->function_library()->device());
- params.allocator_getter = [device](AllocatorAttributes attrs) {
- return device->GetAllocator(attrs);
- };
- return IteratorContext(params);
-}
-
-} // namespace dataset
-
} // namespace tensorflow
diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h
index 66e836f9a6..e0c26d9286 100644
--- a/tensorflow/core/framework/dataset.h
+++ b/tensorflow/core/framework/dataset.h
@@ -40,6 +40,8 @@ limitations under the License.
namespace tensorflow {
+class DatasetBase;
+
// Interface for reading values from a key-value store.
// Used for restoring iterator state.
class IteratorStateReader {
@@ -66,7 +68,6 @@ class IteratorStateWriter {
// Forward declarations to avoid introducing a dependency on headers in
// "tensorflow/core/graph/...".
class GraphDefBuilder;
-class GraphDatasetBase;
class Node;
// Wrapper around GraphDefBuilder. Used to serialize Dataset graph.
@@ -120,7 +121,7 @@ class GraphDefBuilderWrapper {
return Status::OK();
}
- Status AddDataset(const GraphDatasetBase* dataset,
+ Status AddDataset(const DatasetBase* dataset,
const std::vector<Node*>& inputs, Node** output) {
return AddDataset(dataset, inputs, {}, output);
}
@@ -133,7 +134,7 @@ class GraphDefBuilderWrapper {
// `*output` contains a pointer to the output `Node`. It is guaranteed to be
// non-null if the method returns with an OK status.
// The returned Node pointer is owned by the backing Graph of GraphDefBuilder.
- Status AddDataset(const GraphDatasetBase* dataset,
+ Status AddDataset(const DatasetBase* dataset,
const std::vector<Node*>& inputs,
const std::vector<std::pair<StringPiece, AttrValue>>& attrs,
Node** output) {
@@ -145,7 +146,7 @@ class GraphDefBuilderWrapper {
}
Status AddDataset(
- const GraphDatasetBase* dataset,
+ const DatasetBase* dataset,
const std::vector<std::pair<size_t, Node*>>& inputs,
const std::vector<std::pair<size_t, gtl::ArraySlice<Node*>>>& list_inputs,
const std::vector<std::pair<StringPiece, AttrValue>>& attrs,
@@ -276,6 +277,19 @@ class IteratorContext {
explicit IteratorContext(Params params) : params_(std::move(params)) {}
+ explicit IteratorContext(OpKernelContext* ctx) {
+ params_.env = ctx->env();
+ params_.runner = *(ctx->runner());
+ params_.lib = ctx->function_library();
+ // NOTE: must use reinterpret_cast because function.h forward-declares
+ // Device.
+ DeviceBase* device =
+ reinterpret_cast<DeviceBase*>(ctx->function_library()->device());
+ params_.allocator_getter = [device](AllocatorAttributes attrs) {
+ return device->GetAllocator(attrs);
+ };
+ }
+
Env* env() const { return params_.env; }
std::function<void(std::function<void()>)>* runner() {
@@ -355,6 +369,11 @@ class IteratorBase {
virtual Status GetNext(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
bool* end_of_sequence) = 0;
+ Status GetNext(IteratorContext&& ctx, std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) {
+ return GetNext(&ctx, out_tensors, end_of_sequence);
+ }
+
// Returns a vector of DataType values, representing the respective
// element types of each tuple component in the outputs of this
// iterator.
@@ -406,10 +425,40 @@ class IteratorBase {
}
};
+// Represents runtime information needed to construct a dataset.
+class DatasetContext {
+ public:
+ struct Params {
+ string name;
+ };
+
+ explicit DatasetContext(Params params) : params_(std::move(params)) {}
+
+ explicit DatasetContext(OpKernelContext* ctx) {
+ params_.name = ctx->op_kernel().type_string();
+ }
+
+ const string& name() const { return params_.name; }
+
+ private:
+ Params params_;
+};
+
// Represents a (potentially infinite) range of outputs, where each
// output is a tuple of tensors.
class DatasetBase : public core::RefCounted {
public:
+ // Key for storing the Dataset graph in the serialized format.
+ TF_EXPORT static const char kDatasetGraphKey[];
+
+ // Key for storing the output node of the Dataset graph in the serialized
+ // format.
+ TF_EXPORT static const char kDatasetGraphOutputNodeKey[];
+
+ explicit DatasetBase(DatasetContext&& ctx) : name_(ctx.name()) {}
+
+ const string& name() const { return name_; }
+
// Returns a new iterator for iterating over the range of elements in
// this dataset.
//
@@ -426,6 +475,11 @@ class DatasetBase : public core::RefCounted {
return (*iterator)->Initialize(ctx);
}
+ Status MakeIterator(IteratorContext&& ctx, const string& prefix,
+ std::unique_ptr<IteratorBase>* iterator) const {
+ return MakeIterator(&ctx, prefix, iterator);
+ }
+
// Returns a vector of DataType values, representing the respective
// element types of each tuple component in the outputs of this
// dataset.
@@ -441,16 +495,9 @@ class DatasetBase : public core::RefCounted {
// Serializes the dataset and writes it to the `writer`.
virtual Status Save(SerializationContext* ctx,
- IteratorStateWriter* writer) const {
- return errors::Unimplemented("%s does not support serialization",
- DebugString());
- }
+ IteratorStateWriter* writer) const;
protected:
- // TODO(srbs): Ideally all graph related logic should reside in
- // GraphDatasetBase. However, that would require Datasets defined in all ops
- // to derive from GraphDatasetBase. Once that is done we can move
- // DatasetGraphDefBuilder and AsGraphDefInternal to GraphDatasetBase.
class DatasetGraphDefBuilder : public GraphDefBuilderWrapper {
public:
DatasetGraphDefBuilder(GraphDefBuilder* b) : GraphDefBuilderWrapper(b) {}
@@ -463,54 +510,15 @@ class DatasetBase : public core::RefCounted {
// TODO(jsimsa): Consolidate overloading into a single method.
virtual Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
- Node** node) const {
- return AsGraphDefInternal(b, node);
- }
-
- virtual Status AsGraphDefInternal(DatasetGraphDefBuilder* b,
- Node** node) const {
- return errors::Unimplemented("%s does not support serialization",
- DebugString());
- }
+ Node** node) const = 0;
virtual std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const = 0;
friend class DatasetToGraphOp; // For access to graph related members.
-};
-
-// Base-class for datasets that are built by ops.
-class GraphDatasetBase : public DatasetBase {
- public:
- GraphDatasetBase(OpKernelContext* ctx)
- : op_name_(ctx->op_kernel().type_string()) {}
-
- const string op_name() const { return op_name_; }
-
- Status Save(SerializationContext* ctx,
- IteratorStateWriter* writer) const override {
- string serialized_graph_def;
- string output_node;
- TF_RETURN_IF_ERROR(Serialize(ctx, &serialized_graph_def, &output_node));
- TF_RETURN_IF_ERROR(
- writer->WriteScalar(kDatasetGraphKey, serialized_graph_def));
- TF_RETURN_IF_ERROR(
- writer->WriteScalar(kDatasetGraphOutputNodeKey, output_node));
- return Status::OK();
- }
-
- // Key for storing the Dataset graph in the serialized format.
- TF_EXPORT static const char kDatasetGraphKey[];
-
- // Key for storing the output node of the Dataset graph in the serialized
- // format.
- TF_EXPORT static const char kDatasetGraphOutputNodeKey[];
private:
- Status Serialize(SerializationContext* ctx, string* serialized_graph_def,
- string* output_node) const;
-
- const string op_name_;
+ const string name_;
};
// Represents an iterator that is associated with a particular dataset.
@@ -718,12 +726,6 @@ class BackgroundWorker {
std::deque<std::function<void()>> work_queue_ GUARDED_BY(mu_);
};
-namespace dataset {
-
-IteratorContext MakeIteratorContext(OpKernelContext* ctx);
-
-} // namespace dataset
-
} // namespace tensorflow
#endif // TENSORFLOW_CORE_FRAMEWORK_DATASET_H_
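Taken together, the dataset.h changes fold the old GraphDatasetBase into DatasetBase: the serialization keys and the name move onto DatasetBase, Save() gets an out-of-line definition, and AsGraphDefInternal() becomes pure virtual. The many kernel diffs below all apply the same mechanical migration, which reduces to:

    // Before:
    class Dataset : public GraphDatasetBase {
     public:
      explicit Dataset(OpKernelContext* ctx) : GraphDatasetBase(ctx) {}
      ...
    };

    // After: wrap the OpKernelContext in a DatasetContext; datasets that do
    // not support serialization now say so in an explicit override.
    class Dataset : public DatasetBase {
     public:
      explicit Dataset(OpKernelContext* ctx)
          : DatasetBase(DatasetContext(ctx)) {}

     protected:
      Status AsGraphDefInternal(SerializationContext* ctx,
                                DatasetGraphDefBuilder* b,
                                Node** output) const override;
      ...
    };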
diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h
index c81f4a4450..edb7ed01e9 100644
--- a/tensorflow/core/framework/function.h
+++ b/tensorflow/core/framework/function.h
@@ -41,7 +41,7 @@ class ProcessFunctionLibraryRuntime;
class ResourceMgr;
class Rendezvous;
class ScopedStepContainer;
-class StepStatsCollector;
+class StepStatsCollectorInterface;
class Node;
// FunctionDefHelper::Create is a convenient helper to construct a
@@ -527,7 +527,7 @@ class FunctionLibraryRuntime {
CancellationManager* cancellation_manager = nullptr;
CollectiveExecutor* collective_executor = nullptr;
ScopedStepContainer* step_container = nullptr;
- StepStatsCollector* stats_collector = nullptr;
+ StepStatsCollectorInterface* stats_collector = nullptr;
std::function<void(std::function<void()>)>* runner = nullptr;
diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h
index aab95b785b..e752599de1 100644
--- a/tensorflow/core/framework/op_kernel.h
+++ b/tensorflow/core/framework/op_kernel.h
@@ -70,7 +70,7 @@ class OpRegistryInterface;
class ResourceMgr;
class ScopedStepContainer;
class CollectiveExecutor;
-class StepStatsCollector;
+class StepStatsCollectorInterface;
class OpKernel {
public:
@@ -569,7 +569,7 @@ class OpKernelContext {
CallFrameInterface* call_frame = nullptr;
FunctionLibraryRuntime* function_library = nullptr;
std::function<void(std::function<void()>)>* runner = nullptr;
- StepStatsCollector* stats_collector = nullptr;
+ StepStatsCollectorInterface* stats_collector = nullptr;
// TensorSliceReaderCache support.
checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache = nullptr;
@@ -984,7 +984,7 @@ class OpKernelContext {
std::function<void(std::function<void()>)>* runner() const {
return params_->runner;
}
- StepStatsCollector* stats_collector() const {
+ StepStatsCollectorInterface* stats_collector() const {
return params_->stats_collector;
}
diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc
index 8d597e198d..3e77028a5f 100644
--- a/tensorflow/core/framework/shape_inference.cc
+++ b/tensorflow/core/framework/shape_inference.cc
@@ -950,8 +950,7 @@ Status InferenceContext::GetScalarFromTensor(const Tensor* t, int64* val) {
*val = t->scalar<int64>()();
return Status::OK();
} else {
- return errors::InvalidArgument(
- "Scalar input for dim size must be int32 or int64");
+ return errors::InvalidArgument("Scalar input must be int32 or int64.");
}
}
diff --git a/tensorflow/core/graph/gradients.cc b/tensorflow/core/graph/gradients.cc
index c1a8a63784..bec41712b1 100644
--- a/tensorflow/core/graph/gradients.cc
+++ b/tensorflow/core/graph/gradients.cc
@@ -65,16 +65,37 @@ struct NodeOutEq {
static Node* AddZerosLike(Graph* g, NodeOut input) {
DCHECK_LT(0, input.dtype());
DCHECK_LT(input.dtype(), DT_FLOAT_REF);
- NodeDef ndef;
- ndef.set_name(g->NewName(kNodeLabel));
- ndef.set_op("ZerosLike");
- ndef.add_input(input.name());
- AddNodeAttr("T", input.dtype(), &ndef);
- Status s;
- Node* ret = g->AddNode(ndef, &s);
- TF_CHECK_OK(s);
- g->AddEdge(input.node, input.index, ret, 0);
- return ret;
+ if (input.dtype() == DT_RESOURCE) {
+ NodeDef read_def;
+ read_def.set_name(g->NewName("Read"));
+ read_def.set_op("ReadVariableOp");
+ read_def.add_input(input.name());
+ AddNodeAttr("dtype", DT_FLOAT, &read_def);
+ Status s;
+ Node* read = g->AddNode(read_def, &s);
+ TF_CHECK_OK(s);
+ g->AddEdge(input.node, input.index, read, 0);
+ NodeDef ndef;
+ ndef.set_name(g->NewName(kNodeLabel));
+ ndef.set_op("ZerosLike");
+ ndef.add_input(read_def.name());
+ AddNodeAttr("T", DT_FLOAT, &ndef);
+ Node* ret = g->AddNode(ndef, &s);
+ TF_CHECK_OK(s);
+ g->AddEdge(read, 0, ret, 0);
+ return ret;
+ } else {
+ NodeDef ndef;
+ ndef.set_name(g->NewName(kNodeLabel));
+ ndef.set_op("ZerosLike");
+ ndef.add_input(input.name());
+ AddNodeAttr("T", input.dtype(), &ndef);
+ Status s;
+ Node* ret = g->AddNode(ndef, &s);
+ TF_CHECK_OK(s);
+ g->AddEdge(input.node, input.index, ret, 0);
+ return ret;
+ }
}
static Node* AddSymGrad(Graph* g, Node* n, gtl::ArraySlice<NodeOut> grads) {
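For DT_RESOURCE inputs, zeroing the handle itself is meaningless, so the gradient graph now inserts a read and zeros the value it produces instead: input -> ReadVariableOp -> ZerosLike. Note that both the read's dtype attr and the ZerosLike's T attr are hardcoded to DT_FLOAT, so as written this branch assumes float resource variables.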
diff --git a/tensorflow/core/graph/testlib.cc b/tensorflow/core/graph/testlib.cc
index 67b252cb6c..ea7788f654 100644
--- a/tensorflow/core/graph/testlib.cc
+++ b/tensorflow/core/graph/testlib.cc
@@ -21,39 +21,14 @@ limitations under the License.
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
-#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/node_builder.h"
-#include "tensorflow/core/kernels/constant_op.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
namespace tensorflow {
-
-// HostConst: forced to generate output on the host.
-// Only used by testlib; no op is registered for this kernel
-// externally (i.e., in array_ops.cc)
-REGISTER_KERNEL_BUILDER(Name("HostConst").Device(DEVICE_CPU), HostConstantOp);
-REGISTER_KERNEL_BUILDER(
- Name("HostConst").Device(DEVICE_GPU).HostMemory("output"), HostConstantOp);
-#ifdef TENSORFLOW_USE_SYCL
-REGISTER_KERNEL_BUILDER(
- Name("HostConst").Device(DEVICE_SYCL).HostMemory("output"), HostConstantOp);
-#endif // TENSORFLOW_USE_SYCL
-
-// Register the HostConst Op
-// Returns a constant tensor on the host. Useful for writing C++ tests
-// and benchmarks which run on GPU but require arguments pinned to the host.
-// Used by test::graph::HostConstant.
-// value: Attr `value` is the tensor to return.
-REGISTER_OP("HostConst")
- .Output("output: dtype")
- .Attr("value: tensor")
- .Attr("dtype: type")
- .SetShapeFn(shape_inference::UnknownShape);
-
namespace test {
namespace graph {
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
index 6406a4bdbf..0341d7f8e1 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
@@ -175,14 +175,24 @@ int64 CwiseOutputElementCount(const TensorShapeProto& input_shape_1,
TensorShapeProto MaybeGetMinimumShape(const TensorShapeProto& original_shape,
int rank, bool* found_unknown_shapes) {
auto shape = original_shape;
- if (shape.unknown_rank() || shape.dim_size() < rank) {
+ bool is_scalar = !shape.unknown_rank() && shape.dim_size() == 0;
+
+ if (shape.unknown_rank() || (!is_scalar && shape.dim_size() < rank)) {
*found_unknown_shapes = true;
- TensorShapeProto::Dim dim;
VLOG(2) << "Use minimum shape because the rank is unknown.";
// The size of each dimension is at least 1, if unknown.
- dim.set_size(1);
+ for (int i = shape.dim_size(); i < rank; i++) {
+ shape.add_dim()->set_size(1);
+ }
+ } else if (is_scalar) {
+ for (int i = 0; i < rank; i++) {
+ shape.add_dim()->set_size(1);
+ }
+ } else if (shape.dim_size() > rank) {
+ *found_unknown_shapes = true;
+ shape.clear_dim();
for (int i = 0; i < rank; i++) {
- *shape.add_dim() = dim;
+ shape.add_dim()->set_size(original_shape.dim(i).size());
}
} else {
for (int i = 0; i < shape.dim_size(); i++) {
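The rewritten helper now pads, truncates, or synthesizes dimensions rather than always returning a rank-sized shape of ones. Concretely, mirroring the new tests below: a known {10, 20} requested at rank 4 becomes {10, 20, 1, 1} with *found_unknown_shapes set; at rank 2 it is returned as-is; a scalar requested at rank 2 becomes {1, 1} without setting the flag; and {10, 20, 30, 20} requested at rank 2 is truncated to {10, 20} with the flag set.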
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc
index 7271a29319..9e579098ef 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc
@@ -1126,5 +1126,77 @@ TEST_F(OpLevelCostEstimatorTest, PredictFusedBatchNormGrad) {
EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
}
}
+
+TEST_F(OpLevelCostEstimatorTest, MaybeGetMinimumShape) {
+ {
+ TensorShapeProto x;
+ x.set_unknown_rank(true);
+ bool unknown_shapes = false;
+ TensorShapeProto y = MaybeGetMinimumShape(x, 4, &unknown_shapes);
+ EXPECT_TRUE(unknown_shapes);
+ ExpectTensorShape({1, 1, 1, 1}, y);
+ }
+
+ {
+ TensorShapeProto x;
+ x.set_unknown_rank(false);
+ bool unknown_shapes = false;
+ TensorShapeProto y = MaybeGetMinimumShape(x, 1, &unknown_shapes);
+ EXPECT_FALSE(unknown_shapes);
+ ExpectTensorShape({1}, y);
+ }
+
+ {
+ TensorShapeProto x;
+ x.set_unknown_rank(false);
+ bool unknown_shapes = false;
+ TensorShapeProto y = MaybeGetMinimumShape(x, 2, &unknown_shapes);
+ EXPECT_FALSE(unknown_shapes);
+ ExpectTensorShape({1, 1}, y);
+ }
+
+ {
+ TensorShapeProto x;
+ x.set_unknown_rank(false);
+ x.add_dim()->set_size(10);
+ x.add_dim()->set_size(20);
+ bool unknown_shapes = false;
+ TensorShapeProto y = MaybeGetMinimumShape(x, 2, &unknown_shapes);
+ EXPECT_FALSE(unknown_shapes);
+ ExpectTensorShape({10, 20}, y);
+
+ unknown_shapes = false;
+ TensorShapeProto z = MaybeGetMinimumShape(x, 4, &unknown_shapes);
+ EXPECT_TRUE(unknown_shapes);
+ EXPECT_EQ(4, z.dim_size());
+ ExpectTensorShape({10, 20, 1, 1}, z);
+ }
+
+ {
+ TensorShapeProto x;
+ x.set_unknown_rank(false);
+ x.add_dim()->set_size(10);
+ x.add_dim()->set_size(20);
+ x.add_dim()->set_size(-1);
+ x.add_dim()->set_size(20);
+ bool unknown_shapes = false;
+ TensorShapeProto y = MaybeGetMinimumShape(x, 4, &unknown_shapes);
+ EXPECT_TRUE(unknown_shapes);
+ ExpectTensorShape({10, 20, 1, 20}, y);
+ }
+
+ {
+ TensorShapeProto x;
+ x.set_unknown_rank(false);
+ x.add_dim()->set_size(10);
+ x.add_dim()->set_size(20);
+ x.add_dim()->set_size(30);
+ x.add_dim()->set_size(20);
+ bool unknown_shapes = false;
+ TensorShapeProto y = MaybeGetMinimumShape(x, 2, &unknown_shapes);
+ EXPECT_TRUE(unknown_shapes);
+ ExpectTensorShape({10, 20}, y);
+ }
+}
} // end namespace grappler
} // end namespace tensorflow
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index a30916d8b9..46c234d057 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -52,6 +52,8 @@ load(
load(
"//third_party/mkl:build_defs.bzl",
"if_mkl",
+ "if_mkl_ml",
+ "mkl_deps",
)
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
@@ -628,6 +630,7 @@ cc_library(
":gather_nd_op",
":gather_op",
":guarantee_const_op",
+ ":host_constant_op",
":identity_n_op",
":identity_op",
":inplace_ops",
@@ -702,6 +705,12 @@ tf_kernel_library(
)
tf_kernel_library(
+ name = "host_constant_op",
+ prefix = "host_constant_op",
+ deps = ARRAY_DEPS,
+)
+
+tf_kernel_library(
name = "diag_op",
prefix = "diag_op",
deps = ARRAY_DEPS,
@@ -902,10 +911,7 @@ if_mkl(
"transpose_op.cc",
],
hdrs = ["transpose_op.h"],
- deps = ARRAY_DEPS + if_mkl([
- "//third_party/mkl:intel_binary_blob",
- "@mkl_dnn",
- ]),
+ deps = ARRAY_DEPS + mkl_deps(),
)],
[tf_kernel_library(
name = "transpose_op",
@@ -2558,6 +2564,7 @@ tf_kernel_library(
# allow multiple definitions when linking this.
linkopts = select({
"//tensorflow:darwin": [],
+ "//tensorflow:windows": [],
"//conditions:default": ["-Wl,-z,muldefs"],
}),
visibility = [":friends"],
@@ -2867,7 +2874,7 @@ tf_kernel_library(
tf_kernel_library(
name = "batch_matmul_op",
- srcs = [] + if_mkl([
+ srcs = if_mkl_ml([
"mkl_batch_matmul_op.cc",
]),
# <prefix>*impl.h are excluded by default from the CPU build, add explicitly.
@@ -2876,8 +2883,8 @@ tf_kernel_library(
# to avoid long compiling time. See https://github.com/tensorflow/tensorflow/issues/10521
copts = if_override_eigen_strong_inline(["/DEIGEN_STRONG_INLINE=inline"]),
prefix = "batch_matmul_op",
- deps = MATH_DEPS + if_mkl([
- "//third_party/mkl:intel_binary_blob",
+ deps = MATH_DEPS + if_mkl_ml([
+ "//third_party/intel_mkl_ml",
]),
)
@@ -2959,10 +2966,7 @@ tf_kernel_library(
"@libxsmm_archive//:xsmm_avx",
],
"//conditions:default": [],
- }) + if_mkl([
- "//third_party/mkl:intel_binary_blob",
- "@mkl_dnn",
- ]) + if_cuda([
+ }) + mkl_deps() + if_cuda([
"//tensorflow/core/platform/default/build_config:cublas_plugin",
]),
)
@@ -6152,8 +6156,7 @@ tf_mkl_kernel_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:nn_ops_op_lib",
- "//third_party/mkl:intel_binary_blob",
- ] + if_mkl(["@mkl_dnn"]),
+ ] + mkl_deps(),
)
tf_mkl_kernel_library(
@@ -6167,8 +6170,7 @@ tf_mkl_kernel_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:nn_ops_op_lib",
- "//third_party/mkl:intel_binary_blob",
- ] + if_mkl(["@mkl_dnn"]),
+ ] + mkl_deps(),
)
tf_mkl_kernel_library(
@@ -6183,8 +6185,7 @@ tf_mkl_kernel_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:nn_ops_op_lib",
- "//third_party/mkl:intel_binary_blob",
- ] + if_mkl(["@mkl_dnn"]),
+ ] + mkl_deps(),
)
tf_mkl_kernel_library(
@@ -6203,8 +6204,7 @@ tf_mkl_kernel_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:nn_ops_op_lib",
- "//third_party/mkl:intel_binary_blob",
- ] + if_mkl(["@mkl_dnn"]),
+ ] + mkl_deps(),
)
tf_mkl_kernel_library(
@@ -6219,8 +6219,7 @@ tf_mkl_kernel_library(
"//tensorflow/core:lib_internal",
"//tensorflow/core:nn_ops_op_lib",
"//third_party/eigen3",
- "//third_party/mkl:intel_binary_blob",
- ] + if_mkl(["@mkl_dnn"]),
+ ] + mkl_deps(),
)
tf_mkl_kernel_library(
@@ -6235,56 +6234,43 @@ tf_mkl_kernel_library(
"//tensorflow/core:lib_internal",
"//tensorflow/core:nn_ops_op_lib",
"//third_party/eigen3",
- "//third_party/mkl:intel_binary_blob",
- ] + if_mkl(["@mkl_dnn"]),
+ ] + mkl_deps(),
)
tf_mkl_kernel_library(
name = "mkl_fused_batch_norm_op",
srcs = ["mkl_fused_batch_norm_op.cc"],
- deps = NN_DEPS + [
- "//third_party/mkl:intel_binary_blob",
- ] + if_mkl(["@mkl_dnn"]),
+ deps = NN_DEPS + mkl_deps(),
)
tf_mkl_kernel_library(
name = "mkl_aggregate_ops",
prefix = "mkl_aggregate_ops",
- deps = MATH_DEPS + [
- "//third_party/mkl:intel_binary_blob",
- ] + if_mkl(["@mkl_dnn"]),
+ deps = MATH_DEPS + mkl_deps(),
)
tf_mkl_kernel_library(
name = "mkl_concat_op",
prefix = "mkl_concat_op",
- deps = ARRAY_DEPS + [
- "//third_party/mkl:intel_binary_blob",
- ] + if_mkl(["@mkl_dnn"]),
+ deps = ARRAY_DEPS + mkl_deps(),
)
tf_mkl_kernel_library(
name = "mkl_reshape_op",
prefix = "mkl_reshape_op",
- deps = ARRAY_DEPS + [
- "//third_party/mkl:intel_binary_blob",
- ] + if_mkl(["@mkl_dnn"]),
+ deps = ARRAY_DEPS + mkl_deps(),
)
tf_mkl_kernel_library(
name = "mkl_identity_op",
prefix = "mkl_identity_op",
- deps = ARRAY_DEPS + [
- "//third_party/mkl:intel_binary_blob",
- ] + if_mkl(["@mkl_dnn"]),
+ deps = ARRAY_DEPS + mkl_deps(),
)
tf_mkl_kernel_library(
name = "mkl_lrn_op",
prefix = "mkl_lrn_op",
- deps = NN_DEPS + [
- "//third_party/mkl:intel_binary_blob",
- ] + if_mkl(["@mkl_dnn"]),
+ deps = NN_DEPS + mkl_deps(),
)
tf_mkl_kernel_library(
@@ -6295,10 +6281,7 @@ tf_mkl_kernel_library(
"cwise_ops_gradients.h",
],
prefix = "mkl_cwise_ops_common",
- deps = NN_DEPS + [
- "cwise_op",
- "//third_party/mkl:intel_binary_blob",
- ],
+ deps = NN_DEPS + mkl_deps() + [":cwise_op"],
)
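The repeated dependency edits in this BUILD file all funnel through the new mkl_deps() macro loaded at the top of the diff; presumably it expands to the "//third_party/mkl:intel_binary_blob" and "@mkl_dnn" pair it replaces everywhere, with if_mkl_ml() gating the MKL-ML-only pieces such as the mkl_batch_matmul_op.cc source and the "//third_party/intel_mkl_ml" dependency.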
# NOTE(lespeholt): This rule is deprecated, please use:
diff --git a/tensorflow/core/kernels/constant_op.cc b/tensorflow/core/kernels/constant_op.cc
index a888422d49..375819a8a2 100644
--- a/tensorflow/core/kernels/constant_op.cc
+++ b/tensorflow/core/kernels/constant_op.cc
@@ -140,44 +140,6 @@ REGISTER_SYCL_KERNEL(SYCL, bool);
#undef REGISTER_SYCL_KERNEL
#endif
-HostConstantOp::HostConstantOp(OpKernelConstruction* ctx)
- : OpKernel(ctx), tensor_(ctx->output_type(0)) {
- const TensorProto* proto = nullptr;
- AllocatorAttributes alloc_attr;
- alloc_attr.set_on_host(true);
- OP_REQUIRES_OK(ctx, ctx->GetAttr("value", &proto));
- OP_REQUIRES_OK(
- ctx, ctx->device()->MakeTensorFromProto(*proto, alloc_attr, &tensor_));
- OP_REQUIRES(
- ctx, ctx->output_type(0) == tensor_.dtype(),
- errors::InvalidArgument("Type mismatch between value (",
- DataTypeString(tensor_.dtype()), ") and dtype (",
- DataTypeString(ctx->output_type(0)), ")"));
-}
-
-void HostConstantOp::Compute(OpKernelContext* ctx) {
- ctx->set_output(0, tensor_);
-}
-
-#if GOOGLE_CUDA
-// A special GPU kernel for int32.
-// TODO(b/25387198): Also enable int32 in device memory. This kernel
-// registration requires all int32 inputs and outputs to be in host memory.
-REGISTER_KERNEL_BUILDER(Name("Const")
- .Device(DEVICE_GPU)
- .HostMemory("output")
- .TypeConstraint<int32>("dtype"),
- HostConstantOp);
-#endif
-
-#ifdef TENSORFLOW_USE_SYCL
-REGISTER_KERNEL_BUILDER(Name("Const")
- .Device(DEVICE_SYCL)
- .HostMemory("output")
- .TypeConstraint<int32>("dtype"),
- HostConstantOp);
-#endif // TENSORFLOW_USE_SYCL
-
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
#ifdef TENSORFLOW_USE_SYCL
diff --git a/tensorflow/core/kernels/constant_op.h b/tensorflow/core/kernels/constant_op.h
index b98153e347..77ba441863 100644
--- a/tensorflow/core/kernels/constant_op.h
+++ b/tensorflow/core/kernels/constant_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_CONSTANT_OP_H_
-#define TENSORFLOW_KERNELS_CONSTANT_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_CONSTANT_OP_H_
+#define TENSORFLOW_CORE_KERNELS_CONSTANT_OP_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
@@ -36,20 +36,6 @@ class ConstantOp : public OpKernel {
TF_DISALLOW_COPY_AND_ASSIGN(ConstantOp);
};
-// HostConstantOp differs from ConstantOp in that its output is always
-// in host memory.
-class HostConstantOp : public OpKernel {
- public:
- explicit HostConstantOp(OpKernelConstruction* ctx);
- void Compute(OpKernelContext* ctx) override;
- bool IsExpensive() override { return false; }
- ~HostConstantOp() override {}
-
- private:
- Tensor tensor_;
- TF_DISALLOW_COPY_AND_ASSIGN(HostConstantOp);
-};
-
class PlaceholderOp : public OpKernel {
public:
explicit PlaceholderOp(OpKernelConstruction* ctx);
@@ -61,4 +47,4 @@ class PlaceholderOp : public OpKernel {
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_CONSTANT_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_CONSTANT_OP_H_
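HostConstantOp is relocated rather than removed: the kernel leaves constant_op.{h,cc} and testlib.cc (which also loses its local "HostConst" op registration), and the BUILD diff above adds a dedicated host_constant_op kernel library, which presumably now owns the class along with the op and kernel registrations deleted here.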
diff --git a/tensorflow/core/kernels/cwise_op_select.cc b/tensorflow/core/kernels/cwise_op_select.cc
index 98df0844ea..d6988a562c 100644
--- a/tensorflow/core/kernels/cwise_op_select.cc
+++ b/tensorflow/core/kernels/cwise_op_select.cc
@@ -33,6 +33,11 @@ typedef Eigen::GpuDevice GPUDevice;
typedef Eigen::SyclDevice SYCLDevice;
#endif // TENSORFLOW_USE_SYCL
+namespace functor {
+template <typename Device, typename T>
+struct SelectScalarHandler;
+} // namespace functor
+
template <typename Device, typename T>
class SelectOp : public OpKernel {
public:
@@ -131,16 +136,8 @@ class SelectOp : public OpKernel {
then->shape().DebugString(), " vs. ",
else_->shape().DebugString()));
- Tensor* output = nullptr;
- OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
- {"t", "e"}, "output", then->shape(), &output));
-
- if (output->NumElements() > 0) {
- functor::SelectScalarFunctor<Device, T> func;
- TTypes<bool>::ConstScalar cond_scalar = cond->scalar<bool>();
- func(ctx->eigen_device<Device>(), output->flat<T>(), cond_scalar,
- then->flat<T>(), else_->flat<T>());
- }
+ functor::SelectScalarHandler<Device, T> handler;
+ handler(ctx, cond, then, else_);
}
private:
@@ -209,6 +206,40 @@ struct SelectFunctor<SYCLDevice, T> : SelectFunctorBase<SYCLDevice, T> {};
#endif // TENSORFLOW_USE_SYCL
template <typename Device, typename T>
+struct SelectScalarHandler {
+ void operator()(OpKernelContext* ctx, const Tensor* cond, const Tensor* then,
+ const Tensor* else_) {
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
+ {"t", "e"}, "output", then->shape(), &output));
+
+ if (output->NumElements() > 0) {
+ functor::SelectScalarFunctor<Device, T> func;
+ TTypes<bool>::ConstScalar cond_scalar = cond->scalar<bool>();
+ func(ctx->eigen_device<Device>(), output->flat<T>(), cond_scalar,
+ then->flat<T>(), else_->flat<T>());
+ }
+ }
+};
+
+// Specialization for CPU device. Forwards the chosen input to the output
+// depending on the `cond` value.
+// TODO(sjhwang): Consider specializing for GPUDevice as well by using
+// GPUDevice::memcpyDeviceToHost() to fetch bool value.
+template <typename T>
+struct SelectScalarHandler<CPUDevice, T> {
+ void operator()(OpKernelContext* ctx, const Tensor* cond, const Tensor* then,
+ const Tensor* else_) {
+ if (cond->scalar<bool>()()) {
+ OP_REQUIRES_OK(ctx, ctx->set_output("output", *then));
+ } else {
+ OP_REQUIRES_OK(ctx, ctx->set_output("output", *else_));
+ }
+ }
+};
+
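Forwarding the winning input with set_output turns the scalar-predicate case into a reference-count bump instead of the elementwise copy the generic handler performs. Conceptually the CPU fast path is just:

    // No tensor data is copied: the output aliases whichever input was chosen,
    // where cond_value is the bool read out of the scalar predicate.
    OP_REQUIRES_OK(ctx, ctx->set_output("output", cond_value ? *then : *else_));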
+#ifdef TENSORFLOW_USE_SYCL
+template <typename Device, typename T>
struct SelectScalarFunctorBase {
void operator()(const Device& d, typename TTypes<T>::Flat out,
TTypes<bool>::ConstScalar cond,
@@ -218,11 +249,6 @@ struct SelectScalarFunctorBase {
}
};
-// CPU Specializations of Select functors with scalar
-template <typename T>
-struct SelectScalarFunctor<CPUDevice, T>
- : SelectScalarFunctorBase<CPUDevice, T> {};
-#ifdef TENSORFLOW_USE_SYCL
template <typename T>
struct SelectScalarFunctor<SYCLDevice, T>
: SelectScalarFunctorBase<SYCLDevice, T> {};
diff --git a/tensorflow/core/kernels/data/batch_dataset_op.cc b/tensorflow/core/kernels/data/batch_dataset_op.cc
index 5295c9d2a6..f9b5353724 100644
--- a/tensorflow/core/kernels/data/batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/batch_dataset_op.cc
@@ -49,11 +49,11 @@ class BatchDatasetOp : public UnaryDatasetOpKernel {
}
private:
- class Dataset : public GraphDatasetBase {
+ class Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, int64 batch_size, bool drop_remainder,
const DatasetBase* input)
- : GraphDatasetBase(ctx),
+ : DatasetBase(DatasetContext(ctx)),
batch_size_(batch_size),
drop_remainder_(drop_remainder),
input_(input) {
diff --git a/tensorflow/core/kernels/data/cache_dataset_ops.cc b/tensorflow/core/kernels/data/cache_dataset_ops.cc
index 3762e403a9..6ca0bcd37d 100644
--- a/tensorflow/core/kernels/data/cache_dataset_ops.cc
+++ b/tensorflow/core/kernels/data/cache_dataset_ops.cc
@@ -46,11 +46,11 @@ class CacheDatasetOp : public UnaryDatasetOpKernel {
}
private:
- class FileDataset : public GraphDatasetBase {
+ class FileDataset : public DatasetBase {
public:
explicit FileDataset(OpKernelContext* ctx, const DatasetBase* input,
string filename, Env* env)
- : GraphDatasetBase(ctx),
+ : DatasetBase(DatasetContext(ctx)),
input_(input),
filename_(std::move(filename)),
env_(env),
@@ -539,10 +539,12 @@ class CacheDatasetOp : public UnaryDatasetOpKernel {
const string tensor_format_string_;
}; // FileDataset
- class MemoryDataset : public GraphDatasetBase {
+ class MemoryDataset : public DatasetBase {
public:
explicit MemoryDataset(OpKernelContext* ctx, const DatasetBase* input)
- : GraphDatasetBase(ctx), input_(input), cache_(new MemoryCache()) {
+ : DatasetBase(DatasetContext(ctx)),
+ input_(input),
+ cache_(new MemoryCache()) {
input->Ref();
}
diff --git a/tensorflow/core/kernels/data/concatenate_dataset_op.cc b/tensorflow/core/kernels/data/concatenate_dataset_op.cc
index 6393005cdc..c361a9adcb 100644
--- a/tensorflow/core/kernels/data/concatenate_dataset_op.cc
+++ b/tensorflow/core/kernels/data/concatenate_dataset_op.cc
@@ -39,11 +39,11 @@ class ConcatenateDatasetOp : public BinaryDatasetOpKernel {
}
private:
- class Dataset : public GraphDatasetBase {
+ class Dataset : public DatasetBase {
public:
explicit Dataset(OpKernelContext* ctx, const DatasetBase* input,
const DatasetBase* to_concatenate)
- : GraphDatasetBase(ctx),
+ : DatasetBase(DatasetContext(ctx)),
input_(input),
to_concatenate_(to_concatenate) {
input_->Ref();
diff --git a/tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc b/tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc
index 9105587cf4..9770bc025d 100644
--- a/tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc
@@ -76,11 +76,11 @@ class DenseToSparseBatchDatasetOp : public UnaryDatasetOpKernel {
private:
// TODO(mrry): Push the templated code down to the raw copying routine.
template <class T>
- class Dataset : public GraphDatasetBase {
+ class Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, int64 batch_size,
const PartialTensorShape& row_shape, const DatasetBase* input)
- : GraphDatasetBase(ctx),
+ : DatasetBase(DatasetContext(ctx)),
batch_size_(batch_size),
row_shape_(row_shape),
input_(input) {
diff --git a/tensorflow/core/kernels/data/filter_by_component_dataset_op.cc b/tensorflow/core/kernels/data/filter_by_component_dataset_op.cc
index 4b6d808af0..ce577397c5 100644
--- a/tensorflow/core/kernels/data/filter_by_component_dataset_op.cc
+++ b/tensorflow/core/kernels/data/filter_by_component_dataset_op.cc
@@ -48,12 +48,12 @@ class FilterByLastComponentDatasetOp : public UnaryDatasetOpKernel {
DataTypeVector output_types_;
std::vector<PartialTensorShape> output_shapes_;
- class Dataset : public GraphDatasetBase {
+ class Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, const DatasetBase* input,
const DataTypeVector& output_types,
std::vector<PartialTensorShape> output_shapes)
- : GraphDatasetBase(ctx),
+ : DatasetBase(DatasetContext(ctx)),
input_(input),
output_types_(output_types),
output_shapes_(std::move(output_shapes)) {
diff --git a/tensorflow/core/kernels/data/filter_dataset_op.cc b/tensorflow/core/kernels/data/filter_dataset_op.cc
index b11d7cf2ef..a80e102ccf 100644
--- a/tensorflow/core/kernels/data/filter_dataset_op.cc
+++ b/tensorflow/core/kernels/data/filter_dataset_op.cc
@@ -79,12 +79,12 @@ class FilterDatasetOp : public UnaryDatasetOpKernel {
private:
const int graph_def_version_;
- class FilterDatasetBase : public GraphDatasetBase {
+ class FilterDatasetBase : public DatasetBase {
public:
FilterDatasetBase(OpKernelContext* ctx, const DatasetBase* input,
const NameAttrList& func,
std::unique_ptr<CapturedFunction> captured_func)
- : GraphDatasetBase(ctx),
+ : DatasetBase(DatasetContext(ctx)),
input_(input),
func_(func),
captured_func_(std::move(captured_func)) {
diff --git a/tensorflow/core/kernels/data/flat_map_dataset_op.cc b/tensorflow/core/kernels/data/flat_map_dataset_op.cc
index 3419eed6c6..07bcb9d414 100644
--- a/tensorflow/core/kernels/data/flat_map_dataset_op.cc
+++ b/tensorflow/core/kernels/data/flat_map_dataset_op.cc
@@ -56,14 +56,14 @@ class FlatMapDatasetOp : public UnaryDatasetOpKernel {
}
private:
- class Dataset : public GraphDatasetBase {
+ class Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, const DatasetBase* input,
const NameAttrList& func,
std::unique_ptr<CapturedFunction> captured_func,
const DataTypeVector& output_types,
const std::vector<PartialTensorShape>& output_shapes)
- : GraphDatasetBase(ctx),
+ : DatasetBase(DatasetContext(ctx)),
input_(input),
func_(func),
captured_func_(std::move(captured_func)),
diff --git a/tensorflow/core/kernels/data/generator_dataset_op.cc b/tensorflow/core/kernels/data/generator_dataset_op.cc
index c4dd849b8b..3c3d78b724 100644
--- a/tensorflow/core/kernels/data/generator_dataset_op.cc
+++ b/tensorflow/core/kernels/data/generator_dataset_op.cc
@@ -26,14 +26,14 @@ namespace tensorflow {
// See documentation in ../ops/dataset_ops.cc for a high-level
// description of the following op.
-class GeneratorDatasetOp::Dataset : public GraphDatasetBase {
+class GeneratorDatasetOp::Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, std::unique_ptr<CapturedFunction> init_func,
std::unique_ptr<CapturedFunction> next_func,
std::unique_ptr<CapturedFunction> finalize_func,
const DataTypeVector& output_types,
const std::vector<PartialTensorShape>& output_shapes)
- : GraphDatasetBase(ctx),
+ : DatasetBase(DatasetContext(ctx)),
init_func_(std::move(init_func)),
next_func_(std::move(next_func)),
finalize_func_(std::move(finalize_func)),
@@ -47,12 +47,21 @@ class GeneratorDatasetOp::Dataset : public GraphDatasetBase {
}
const DataTypeVector& output_dtypes() const override { return output_types_; }
+
const std::vector<PartialTensorShape>& output_shapes() const override {
return output_shapes_;
}
string DebugString() const override { return "GeneratorDatasetOp::Dataset"; }
+ protected:
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
+ Node** output) const override {
+ return errors::Unimplemented("%s does not support serialization",
+ DebugString());
+ }
+
private:
class Iterator : public DatasetIterator<Dataset> {
public:
diff --git a/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc b/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc
index bcf0adacc7..be4132a064 100644
--- a/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc
+++ b/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc
@@ -66,7 +66,7 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel {
}
private:
- class Dataset : public GraphDatasetBase {
+ class Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, const DatasetBase* input,
std::unique_ptr<CapturedFunction> captured_key_func,
@@ -75,7 +75,7 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel {
std::unique_ptr<CapturedFunction> captured_finalize_func,
const DataTypeVector& output_types,
const std::vector<PartialTensorShape>& output_shapes)
- : GraphDatasetBase(ctx),
+ : DatasetBase(DatasetContext(ctx)),
input_(input),
captured_key_func_(std::move(captured_key_func)),
captured_init_func_(std::move(captured_init_func)),
diff --git a/tensorflow/core/kernels/data/group_by_window_dataset_op.cc b/tensorflow/core/kernels/data/group_by_window_dataset_op.cc
index 683a50e71c..288695f3cd 100644
--- a/tensorflow/core/kernels/data/group_by_window_dataset_op.cc
+++ b/tensorflow/core/kernels/data/group_by_window_dataset_op.cc
@@ -93,7 +93,7 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
}
private:
- class Dataset : public GraphDatasetBase {
+ class Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, const DatasetBase* input,
const NameAttrList& key_func, const NameAttrList& reduce_func,
@@ -103,7 +103,7 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
std::unique_ptr<CapturedFunction> captured_window_size_func,
const DataTypeVector& output_types,
const std::vector<PartialTensorShape>& output_shapes)
- : GraphDatasetBase(ctx),
+ : DatasetBase(DatasetContext(ctx)),
input_(input),
key_func_(key_func),
reduce_func_(reduce_func),
diff --git a/tensorflow/core/kernels/data/interleave_dataset_op.cc b/tensorflow/core/kernels/data/interleave_dataset_op.cc
index 8fee29d4d0..58b79d6026 100644
--- a/tensorflow/core/kernels/data/interleave_dataset_op.cc
+++ b/tensorflow/core/kernels/data/interleave_dataset_op.cc
@@ -76,14 +76,14 @@ class InterleaveDatasetOp : public UnaryDatasetOpKernel {
}
private:
- class Dataset : public GraphDatasetBase {
+ class Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, const DatasetBase* input,
const NameAttrList& func,
std::unique_ptr<CapturedFunction> captured_func, int64 cycle_length,
int64 block_length, const DataTypeVector& output_types,
const std::vector<PartialTensorShape>& output_shapes)
- : GraphDatasetBase(ctx),
+ : DatasetBase(DatasetContext(ctx)),
input_(input),
func_(func),
captured_func_(std::move(captured_func)),
diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc
index da9d29dd76..61a6c06135 100644
--- a/tensorflow/core/kernels/data/iterator_ops.cc
+++ b/tensorflow/core/kernels/data/iterator_ops.cc
@@ -130,7 +130,7 @@ class IteratorResource : public ResourceBase {
Status Restore(OpKernelContext* ctx, IteratorStateReader* reader) {
string serialized_graph_def;
- TF_RETURN_IF_ERROR(reader->ReadScalar(GraphDatasetBase::kDatasetGraphKey,
+ TF_RETURN_IF_ERROR(reader->ReadScalar(DatasetBase::kDatasetGraphKey,
&serialized_graph_def));
GraphDef graph_def;
if (!graph_def.ParseFromString(serialized_graph_def)) {
@@ -138,7 +138,7 @@ class IteratorResource : public ResourceBase {
}
string output_node;
TF_RETURN_IF_ERROR(reader->ReadScalar(
- GraphDatasetBase::kDatasetGraphOutputNodeKey, &output_node));
+ DatasetBase::kDatasetGraphOutputNodeKey, &output_node));
DatasetBase* dataset = nullptr;
Graph graph(OpRegistry::Global());
TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, nullptr));
@@ -161,9 +161,9 @@ class IteratorResource : public ResourceBase {
graph_runner.Run(&graph, lib, {}, {output_node}, &outputs));
TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(outputs[0], &dataset));
- IteratorContext iter_ctx = dataset::MakeIteratorContext(ctx);
std::unique_ptr<IteratorBase> iterator;
- TF_RETURN_IF_ERROR(dataset->MakeIterator(&iter_ctx, "Iterator", &iterator));
+ TF_RETURN_IF_ERROR(
+ dataset->MakeIterator(IteratorContext(ctx), "Iterator", &iterator));
TF_RETURN_IF_ERROR(set_iterator(std::move(iterator)));
std::shared_ptr<IteratorBase> captured_iterator(iterator_);
@@ -611,9 +611,9 @@ void MakeIteratorOp::Compute(OpKernelContext* ctx) {
ctx, LookupResource(ctx, HandleFromInput(ctx, 1), &iterator_resource));
core::ScopedUnref unref(iterator_resource);
- IteratorContext iter_ctx = dataset::MakeIteratorContext(ctx);
std::unique_ptr<IteratorBase> iterator;
- OP_REQUIRES_OK(ctx, dataset->MakeIterator(&iter_ctx, "Iterator", &iterator));
+ OP_REQUIRES_OK(
+ ctx, dataset->MakeIterator(IteratorContext(ctx), "Iterator", &iterator));
OP_REQUIRES_OK(ctx, iterator_resource->set_iterator(std::move(iterator)));
}
@@ -633,11 +633,11 @@ class ToSingleElementOp : public AsyncOpKernel {
DatasetBase* dataset;
OP_REQUIRES_OK_ASYNC(
ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset), done);
- IteratorContext iter_ctx = dataset::MakeIteratorContext(ctx);
std::unique_ptr<IteratorBase> iterator;
OP_REQUIRES_OK_ASYNC(
ctx,
- dataset->MakeIterator(&iter_ctx, "SingleElementIterator", &iterator),
+ dataset->MakeIterator(IteratorContext(ctx), "SingleElementIterator",
+ &iterator),
done);
// NOTE(jsimsa): We must destroy the iterator before calling `done()`, to
@@ -651,8 +651,8 @@ class ToSingleElementOp : public AsyncOpKernel {
components.reserve(dataset->output_dtypes().size());
bool end_of_sequence = false;
- Status s =
- raw_iterator->GetNext(&iter_ctx, &components, &end_of_sequence);
+ Status s = raw_iterator->GetNext(IteratorContext(ctx), &components,
+ &end_of_sequence);
if (!s.ok()) {
ctx->SetStatus(s);
return;
@@ -667,8 +667,8 @@ class ToSingleElementOp : public AsyncOpKernel {
}
components.clear();
- Status s2 =
- raw_iterator->GetNext(&iter_ctx, &components, &end_of_sequence);
+ Status s2 = raw_iterator->GetNext(IteratorContext(ctx), &components,
+ &end_of_sequence);
if (!s2.ok()) {
ctx->SetStatus(s2);
return;
@@ -836,9 +836,9 @@ class OneShotIteratorOp : public AsyncOpKernel {
// factory function.
DatasetBase* dataset;
TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(return_values[0], &dataset));
- IteratorContext iter_ctx = dataset::MakeIteratorContext(ctx);
std::unique_ptr<IteratorBase> iter;
- TF_RETURN_IF_ERROR(dataset->MakeIterator(&iter_ctx, "Iterator", &iter));
+ TF_RETURN_IF_ERROR(
+ dataset->MakeIterator(IteratorContext(ctx), "Iterator", &iter));
TF_RETURN_IF_ERROR((*iterator)->set_iterator(std::move(iter)));
(*iterator)->Ref();
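With the rvalue overloads in place, every call site that previously went through dataset::MakeIteratorContext collapses to a single expression, and the helper's declaration disappears from dataset.h above. One behavioral nuance visible in ToSingleElementOp: the loop now constructs a fresh IteratorContext(ctx) per GetNext call instead of reusing one context across the whole iteration.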
diff --git a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
index 51a7fd23a8..0e17011b05 100644
--- a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
@@ -101,7 +101,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
}
private:
- class Dataset : public GraphDatasetBase {
+ class Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, const DatasetBase* input, int64 batch_size,
int64 num_parallel_calls, bool drop_remainder,
@@ -110,7 +110,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
const NameAttrList& func,
std::unique_ptr<CapturedFunction> captured_func,
const Eigen::ThreadPoolDevice* device)
- : GraphDatasetBase(ctx),
+ : DatasetBase(DatasetContext(ctx)),
input_(input),
batch_size_(batch_size),
num_parallel_calls_(num_parallel_calls),
diff --git a/tensorflow/core/kernels/data/map_dataset_op.cc b/tensorflow/core/kernels/data/map_dataset_op.cc
index ec9e12453b..294fb1c49a 100644
--- a/tensorflow/core/kernels/data/map_dataset_op.cc
+++ b/tensorflow/core/kernels/data/map_dataset_op.cc
@@ -55,14 +55,14 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
}
private:
- class Dataset : public GraphDatasetBase {
+ class Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, const DatasetBase* input,
const NameAttrList& func,
std::unique_ptr<CapturedFunction> captured_func,
const DataTypeVector& output_types,
const std::vector<PartialTensorShape>& output_shapes)
- : GraphDatasetBase(ctx),
+ : DatasetBase(DatasetContext(ctx)),
input_(input),
func_(func),
captured_func_(std::move(captured_func)),
diff --git a/tensorflow/core/kernels/data/optimize_dataset_op.cc b/tensorflow/core/kernels/data/optimize_dataset_op.cc
index 8add049123..b097598cd9 100644
--- a/tensorflow/core/kernels/data/optimize_dataset_op.cc
+++ b/tensorflow/core/kernels/data/optimize_dataset_op.cc
@@ -59,13 +59,13 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel {
}
private:
- class Dataset : public GraphDatasetBase {
+ class Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, const DatasetBase* input,
const std::vector<string>& optimizations,
const DataTypeVector& output_types,
const std::vector<PartialTensorShape>& output_shapes)
- : GraphDatasetBase(ctx),
+ : DatasetBase(DatasetContext(ctx)),
input_(input),
optimizations_(optimizations),
output_types_(output_types),
diff --git a/tensorflow/core/kernels/data/padded_batch_dataset_op.cc b/tensorflow/core/kernels/data/padded_batch_dataset_op.cc
index 755d46dac2..be45eac46e 100644
--- a/tensorflow/core/kernels/data/padded_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/padded_batch_dataset_op.cc
@@ -98,12 +98,12 @@ class PaddedBatchDatasetOp : public UnaryDatasetOpKernel {
}
private:
- class Dataset : public GraphDatasetBase {
+ class Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, int64 batch_size, bool drop_remainder,
std::vector<PartialTensorShape> padded_shapes,
std::vector<Tensor> padding_values, const DatasetBase* input)
- : GraphDatasetBase(ctx),
+ : DatasetBase(DatasetContext(ctx)),
batch_size_(batch_size),
drop_remainder_(drop_remainder),
padded_shapes_(std::move(padded_shapes)),
diff --git a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
index d2b83f9eab..e492a8215a 100644
--- a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
@@ -92,7 +92,7 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
}
private:
- class Dataset : public GraphDatasetBase {
+ class Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, const DatasetBase* input,
const NameAttrList& func,
@@ -100,7 +100,7 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
int64 block_length, bool sloppy, int64 buffer_output_elements,
int64 prefetch_input_elements, const DataTypeVector& output_types,
const std::vector<PartialTensorShape>& output_shapes)
- : GraphDatasetBase(ctx),
+ : DatasetBase(DatasetContext(ctx)),
input_(input),
interleave_func_(func),
captured_func_(std::move(captured_func)),
diff --git a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
index c56a7ea808..a407abfce4 100644
--- a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
@@ -67,14 +67,14 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
}
private:
- class Dataset : public GraphDatasetBase {
+ class Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, const DatasetBase* input,
const NameAttrList& func, int32 num_parallel_calls,
const DataTypeVector& output_types,
const std::vector<PartialTensorShape>& output_shapes,
std::unique_ptr<CapturedFunction> captured_func)
- : GraphDatasetBase(ctx),
+ : DatasetBase(DatasetContext(ctx)),
input_(input),
func_(func),
num_parallel_calls_(num_parallel_calls),
diff --git a/tensorflow/core/kernels/data/prefetch_dataset_op.cc b/tensorflow/core/kernels/data/prefetch_dataset_op.cc
index 20148a4378..50efbcbe2a 100644
--- a/tensorflow/core/kernels/data/prefetch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/prefetch_dataset_op.cc
@@ -25,10 +25,12 @@ namespace tensorflow {
// See documentation in ../ops/dataset_ops.cc for a high-level
// description of the following op.
-class PrefetchDatasetOp::Dataset : public GraphDatasetBase {
+class PrefetchDatasetOp::Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, const DatasetBase* input, int64 buffer_size)
- : GraphDatasetBase(ctx), input_(input), buffer_size_(buffer_size) {
+ : DatasetBase(DatasetContext(ctx)),
+ input_(input),
+ buffer_size_(buffer_size) {
input_->Ref();
}
diff --git a/tensorflow/core/kernels/data/random_dataset_op.cc b/tensorflow/core/kernels/data/random_dataset_op.cc
index 7e48428b3f..7817170e73 100644
--- a/tensorflow/core/kernels/data/random_dataset_op.cc
+++ b/tensorflow/core/kernels/data/random_dataset_op.cc
@@ -49,10 +49,10 @@ class RandomDatasetOp : public DatasetOpKernel {
}
private:
- class Dataset : public GraphDatasetBase {
+ class Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, int64 seed, int64 seed2)
- : GraphDatasetBase(ctx), seed_(seed), seed2_(seed2) {}
+ : DatasetBase(DatasetContext(ctx)), seed_(seed), seed2_(seed2) {}
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
diff --git a/tensorflow/core/kernels/data/range_dataset_op.cc b/tensorflow/core/kernels/data/range_dataset_op.cc
index 50bd3dac4e..aa38775125 100644
--- a/tensorflow/core/kernels/data/range_dataset_op.cc
+++ b/tensorflow/core/kernels/data/range_dataset_op.cc
@@ -43,10 +43,13 @@ class RangeDatasetOp : public DatasetOpKernel {
}
private:
- class Dataset : public GraphDatasetBase {
+ class Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, int64 start, int64 stop, int64 step)
- : GraphDatasetBase(ctx), start_(start), stop_(stop), step_(step) {}
+ : DatasetBase(DatasetContext(ctx)),
+ start_(start),
+ stop_(stop),
+ step_(step) {}
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
diff --git a/tensorflow/core/kernels/data/reader_dataset_ops.cc b/tensorflow/core/kernels/data/reader_dataset_ops.cc
index 6a71a7af1d..086b552936 100644
--- a/tensorflow/core/kernels/data/reader_dataset_ops.cc
+++ b/tensorflow/core/kernels/data/reader_dataset_ops.cc
@@ -78,12 +78,12 @@ class TextLineDatasetOp : public DatasetOpKernel {
}
private:
- class Dataset : public GraphDatasetBase {
+ class Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, std::vector<string> filenames,
const string& compression_type,
const io::ZlibCompressionOptions& options)
- : GraphDatasetBase(ctx),
+ : DatasetBase(DatasetContext(ctx)),
filenames_(std::move(filenames)),
compression_type_(compression_type),
use_compression_(!compression_type.empty()),
@@ -312,12 +312,12 @@ class FixedLengthRecordDatasetOp : public DatasetOpKernel {
}
private:
- class Dataset : public GraphDatasetBase {
+ class Dataset : public DatasetBase {
public:
explicit Dataset(OpKernelContext* ctx, std::vector<string> filenames,
int64 header_bytes, int64 record_bytes, int64 footer_bytes,
int64 buffer_size)
- : GraphDatasetBase(ctx),
+ : DatasetBase(DatasetContext(ctx)),
filenames_(std::move(filenames)),
header_bytes_(header_bytes),
record_bytes_(record_bytes),
@@ -531,11 +531,11 @@ class TFRecordDatasetOp : public DatasetOpKernel {
}
private:
- class Dataset : public GraphDatasetBase {
+ class Dataset : public DatasetBase {
public:
explicit Dataset(OpKernelContext* ctx, std::vector<string> filenames,
const string& compression_type, int64 buffer_size)
- : GraphDatasetBase(ctx),
+ : DatasetBase(DatasetContext(ctx)),
filenames_(std::move(filenames)),
compression_type_(compression_type),
options_(io::RecordReaderOptions::CreateRecordReaderOptions(
diff --git a/tensorflow/core/kernels/data/repeat_dataset_op.cc b/tensorflow/core/kernels/data/repeat_dataset_op.cc
index 093ea563b4..5e9ace3486 100644
--- a/tensorflow/core/kernels/data/repeat_dataset_op.cc
+++ b/tensorflow/core/kernels/data/repeat_dataset_op.cc
@@ -39,10 +39,10 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel {
}
private:
- class Dataset : public GraphDatasetBase {
+ class Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, int64 count, const DatasetBase* input)
- : GraphDatasetBase(ctx), count_(count), input_(input) {
+ : DatasetBase(DatasetContext(ctx)), count_(count), input_(input) {
input_->Ref();
}
diff --git a/tensorflow/core/kernels/data/scan_dataset_op.cc b/tensorflow/core/kernels/data/scan_dataset_op.cc
index 7c59874d96..e4cb31e2b2 100644
--- a/tensorflow/core/kernels/data/scan_dataset_op.cc
+++ b/tensorflow/core/kernels/data/scan_dataset_op.cc
@@ -69,7 +69,7 @@ class ScanDatasetOp : public UnaryDatasetOpKernel {
}
private:
- class Dataset : public GraphDatasetBase {
+ class Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, const DatasetBase* input,
const NameAttrList& func, std::vector<Tensor> initial_state,
@@ -77,7 +77,7 @@ class ScanDatasetOp : public UnaryDatasetOpKernel {
const DataTypeVector& state_types,
const DataTypeVector& output_types,
const std::vector<PartialTensorShape>& output_shapes)
- : GraphDatasetBase(ctx),
+ : DatasetBase(DatasetContext(ctx)),
input_(input),
func_(func),
initial_state_(std::move(initial_state)),
diff --git a/tensorflow/core/kernels/data/shuffle_dataset_op.cc b/tensorflow/core/kernels/data/shuffle_dataset_op.cc
index 603c3feb79..93a4376836 100644
--- a/tensorflow/core/kernels/data/shuffle_dataset_op.cc
+++ b/tensorflow/core/kernels/data/shuffle_dataset_op.cc
@@ -40,11 +40,11 @@ class ShuffleDatasetOpBase : public UnaryDatasetOpKernel {
protected:
// Abstract base dataset that implements a shuffling iterator.
- class ShuffleDatasetBase : public GraphDatasetBase {
+ class ShuffleDatasetBase : public DatasetBase {
public:
ShuffleDatasetBase(OpKernelContext* ctx, const DatasetBase* input,
int64 buffer_size, int64 count)
- : GraphDatasetBase(ctx),
+ : DatasetBase(DatasetContext(ctx)),
input_(input),
buffer_size_(buffer_size),
count_(count) {
diff --git a/tensorflow/core/kernels/data/skip_dataset_op.cc b/tensorflow/core/kernels/data/skip_dataset_op.cc
index 61db6a0a54..fe7ef38d5f 100644
--- a/tensorflow/core/kernels/data/skip_dataset_op.cc
+++ b/tensorflow/core/kernels/data/skip_dataset_op.cc
@@ -38,10 +38,10 @@ class SkipDatasetOp : public UnaryDatasetOpKernel {
}
private:
- class Dataset : public GraphDatasetBase {
+ class Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, int64 count, const DatasetBase* input)
- : GraphDatasetBase(ctx), count_(count), input_(input) {
+ : DatasetBase(DatasetContext(ctx)), count_(count), input_(input) {
input_->Ref();
}
diff --git a/tensorflow/core/kernels/data/slide_dataset_op.cc b/tensorflow/core/kernels/data/slide_dataset_op.cc
index fd8c5ccd92..14df3a6801 100644
--- a/tensorflow/core/kernels/data/slide_dataset_op.cc
+++ b/tensorflow/core/kernels/data/slide_dataset_op.cc
@@ -63,11 +63,11 @@ class SlideDatasetOp : public UnaryDatasetOpKernel {
}
private:
- class Dataset : public GraphDatasetBase {
+ class Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, int64 window_size, int64 window_shift,
int64 window_stride, const DatasetBase* input)
- : GraphDatasetBase(ctx),
+ : DatasetBase(DatasetContext(ctx)),
window_size_(window_size),
window_shift_(window_shift),
window_stride_(window_stride),
diff --git a/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc b/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc
index 9bb86e76a2..e526578701 100644
--- a/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc
+++ b/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc
@@ -28,11 +28,11 @@ namespace {
// description of the following op.
template <typename T>
-class Dataset : public GraphDatasetBase {
+class Dataset : public DatasetBase {
public:
explicit Dataset(OpKernelContext* ctx,
const sparse::SparseTensor& sparse_tensor)
- : GraphDatasetBase(ctx),
+ : DatasetBase(DatasetContext(ctx)),
sparse_tensor_(sparse_tensor),
dtypes_({DT_INT64, sparse_tensor.dtype(), DT_INT64}),
shapes_({{-1, sparse_tensor.dims() - 1},
diff --git a/tensorflow/core/kernels/data/sql_dataset_ops.cc b/tensorflow/core/kernels/data/sql_dataset_ops.cc
index 9b0190e3fc..2aa153fcfa 100644
--- a/tensorflow/core/kernels/data/sql_dataset_ops.cc
+++ b/tensorflow/core/kernels/data/sql_dataset_ops.cc
@@ -75,13 +75,13 @@ class SqlDatasetOp : public DatasetOpKernel {
}
private:
- class Dataset : public GraphDatasetBase {
+ class Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, const string& driver_name,
const string& data_source_name, const string& query,
const DataTypeVector& output_types,
const std::vector<PartialTensorShape>& output_shapes)
- : GraphDatasetBase(ctx),
+ : DatasetBase(DatasetContext(ctx)),
driver_name_(driver_name),
data_source_name_(data_source_name),
query_(query),
diff --git a/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc b/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc
index 8465a1d2c0..75af73df54 100644
--- a/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc
+++ b/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc
@@ -37,11 +37,11 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel {
}
private:
- class Dataset : public GraphDatasetBase {
+ class Dataset : public DatasetBase {
public:
explicit Dataset(OpKernelContext* ctx, const DatasetBase* input,
StatsAggregatorResource* stats_aggregator_resource)
- : GraphDatasetBase(ctx),
+ : DatasetBase(DatasetContext(ctx)),
input_(input),
stats_aggregator_resource_(stats_aggregator_resource) {
input_->Ref();
@@ -70,6 +70,14 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel {
return "SetStatsAggregatorDatasetOp::Dataset";
}
+ protected:
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
+ Node** output) const override {
+ return errors::Unimplemented("%s does not support serialization",
+ DebugString());
+ }
+
private:
class Iterator : public DatasetIterator<Dataset> {
public:
diff --git a/tensorflow/core/kernels/data/stats_dataset_ops.cc b/tensorflow/core/kernels/data/stats_dataset_ops.cc
index 85fed31773..52753a3ccd 100644
--- a/tensorflow/core/kernels/data/stats_dataset_ops.cc
+++ b/tensorflow/core/kernels/data/stats_dataset_ops.cc
@@ -49,10 +49,12 @@ class LatencyStatsDatasetOp : public UnaryDatasetOpKernel {
}
private:
- class Dataset : public GraphDatasetBase {
+ class Dataset : public DatasetBase {
public:
explicit Dataset(OpKernelContext* ctx, const DatasetBase* input, string tag)
- : GraphDatasetBase(ctx), input_(input), tag_(std::move(tag)) {
+ : DatasetBase(DatasetContext(ctx)),
+ input_(input),
+ tag_(std::move(tag)) {
input_->Ref();
}
@@ -149,10 +151,12 @@ class BytesProducedStatsDatasetOp : public UnaryDatasetOpKernel {
}
private:
- class Dataset : public GraphDatasetBase {
+ class Dataset : public DatasetBase {
public:
explicit Dataset(OpKernelContext* ctx, const DatasetBase* input, string tag)
- : GraphDatasetBase(ctx), input_(input), tag_(std::move(tag)) {
+ : DatasetBase(DatasetContext(ctx)),
+ input_(input),
+ tag_(std::move(tag)) {
input_->Ref();
}
@@ -255,10 +259,12 @@ class FeatureStatsDatasetOp : public UnaryDatasetOpKernel {
}
private:
- class Dataset : public GraphDatasetBase {
+ class Dataset : public DatasetBase {
public:
explicit Dataset(OpKernelContext* ctx, const DatasetBase* input, string tag)
- : GraphDatasetBase(ctx), input_(input), tag_(std::move(tag)) {
+ : DatasetBase(DatasetContext(ctx)),
+ input_(input),
+ tag_(std::move(tag)) {
input_->Ref();
}
diff --git a/tensorflow/core/kernels/data/take_dataset_op.cc b/tensorflow/core/kernels/data/take_dataset_op.cc
index d4a3c7a978..e5c237dfaa 100644
--- a/tensorflow/core/kernels/data/take_dataset_op.cc
+++ b/tensorflow/core/kernels/data/take_dataset_op.cc
@@ -38,10 +38,10 @@ class TakeDatasetOp : public UnaryDatasetOpKernel {
}
private:
- class Dataset : public GraphDatasetBase {
+ class Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, int64 count, const DatasetBase* input)
- : GraphDatasetBase(ctx), count_(count), input_(input) {
+ : DatasetBase(DatasetContext(ctx)), count_(count), input_(input) {
input_->Ref();
}
diff --git a/tensorflow/core/kernels/data/tensor_dataset_op.cc b/tensorflow/core/kernels/data/tensor_dataset_op.cc
index ac2015c865..fc21c3235a 100644
--- a/tensorflow/core/kernels/data/tensor_dataset_op.cc
+++ b/tensorflow/core/kernels/data/tensor_dataset_op.cc
@@ -43,10 +43,10 @@ class TensorDatasetOp : public DatasetOpKernel {
}
private:
- class Dataset : public GraphDatasetBase {
+ class Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, std::vector<Tensor> tensors)
- : GraphDatasetBase(ctx), tensors_(std::move(tensors)) {
+ : DatasetBase(DatasetContext(ctx)), tensors_(std::move(tensors)) {
for (const Tensor& t : tensors_) {
dtypes_.push_back(t.dtype());
shapes_.emplace_back(t.shape().dim_sizes());
diff --git a/tensorflow/core/kernels/data/tensor_queue_dataset_op.cc b/tensorflow/core/kernels/data/tensor_queue_dataset_op.cc
index ea472e2b79..ccd5e60acc 100644
--- a/tensorflow/core/kernels/data/tensor_queue_dataset_op.cc
+++ b/tensorflow/core/kernels/data/tensor_queue_dataset_op.cc
@@ -61,14 +61,14 @@ std::vector<PartialTensorShape> PrependQueueShapeWithBatch(
class EnqueueInQueueDatasetOp;
-class PrependFromQueueAndPaddedBatchDataset : public GraphDatasetBase {
+class PrependFromQueueAndPaddedBatchDataset : public DatasetBase {
public:
PrependFromQueueAndPaddedBatchDataset(
OpKernelContext* ctx, const int64 batch_size, const DatasetBase* input,
const DataTypeVector& dtypes,
const std::vector<PartialTensorShape>& shapes,
std::vector<Tensor> padding_values)
- : GraphDatasetBase(ctx),
+ : DatasetBase(DatasetContext(ctx)),
batch_size_(batch_size),
input_(input),
dtypes_(dtypes),
diff --git a/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc b/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc
index 8f18d38f83..5b051e0e08 100644
--- a/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc
+++ b/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc
@@ -54,10 +54,10 @@ class TensorSliceDatasetOp : public DatasetOpKernel {
}
private:
- class Dataset : public GraphDatasetBase {
+ class Dataset : public DatasetBase {
public:
explicit Dataset(OpKernelContext* ctx, std::vector<Tensor> tensors)
- : GraphDatasetBase(ctx), tensors_(std::move(tensors)) {
+ : DatasetBase(DatasetContext(ctx)), tensors_(std::move(tensors)) {
for (const Tensor& t : tensors_) {
dtypes_.push_back(t.dtype());
gtl::InlinedVector<int64, 4> partial_dim_sizes;
diff --git a/tensorflow/core/kernels/data/unbatch_dataset_op.cc b/tensorflow/core/kernels/data/unbatch_dataset_op.cc
index 02c3c5315a..1a79f72b28 100644
--- a/tensorflow/core/kernels/data/unbatch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/unbatch_dataset_op.cc
@@ -35,10 +35,10 @@ class UnbatchDatasetOp : public UnaryDatasetOpKernel {
}
private:
- class Dataset : public GraphDatasetBase {
+ class Dataset : public DatasetBase {
public:
explicit Dataset(OpKernelContext* ctx, DatasetBase* input)
- : GraphDatasetBase(ctx), input_(input) {
+ : DatasetBase(DatasetContext(ctx)), input_(input) {
input_->Ref();
for (const PartialTensorShape& shape : input->output_shapes()) {
gtl::InlinedVector<int64, 4> partial_dim_sizes;
diff --git a/tensorflow/core/kernels/data/window_dataset.cc b/tensorflow/core/kernels/data/window_dataset.cc
index 17551bccd9..0ab6beabfc 100644
--- a/tensorflow/core/kernels/data/window_dataset.cc
+++ b/tensorflow/core/kernels/data/window_dataset.cc
@@ -13,17 +13,18 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/window_dataset.h"
+#include "tensorflow/core/lib/core/errors.h"
namespace tensorflow {
namespace {
-// TODO(b/110981596): Support checkpointing.
class WindowDataset : public DatasetBase {
public:
WindowDataset(std::vector<std::vector<Tensor>> elements,
DataTypeVector output_types,
std::vector<PartialTensorShape> output_shapes)
- : elements_(std::move(elements)),
+ : DatasetBase(DatasetContext({"Window"})),
+ elements_(std::move(elements)),
output_types_(std::move(output_types)),
output_shapes_(std::move(output_shapes)) {}
@@ -41,6 +42,15 @@ class WindowDataset : public DatasetBase {
string DebugString() const override { return "WindowDataset"; }
+ protected:
+ // TODO(b/110981596): Support checkpointing.
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
+ Node** output) const override {
+ return errors::Unimplemented(DebugString(),
+ " does not support serialization");
+ }
+
private:
class Iterator : public DatasetIterator<WindowDataset> {
public:
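
Note that the `errors::` helpers build their message by concatenating all arguments `StrCat`-style rather than interpreting a printf-style format string (compare the `errors::InvalidArgument(...)` calls elsewhere in this patch), which is why the unimplemented-serialization stub above passes `DebugString()` as a plain argument. A minimal sketch of that convention:

```c++
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {

// Sketch: errors::Unimplemented(...) concatenates its arguments into the
// message; a "%s" placeholder would appear literally in the output.
Status NotSerializable(const string& name) {
  return errors::Unimplemented(name, " does not support serialization");
}

}  // namespace tensorflow
```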
diff --git a/tensorflow/core/kernels/data/window_dataset_op.cc b/tensorflow/core/kernels/data/window_dataset_op.cc
index f9fd5b5a83..41bf9d43fe 100644
--- a/tensorflow/core/kernels/data/window_dataset_op.cc
+++ b/tensorflow/core/kernels/data/window_dataset_op.cc
@@ -43,10 +43,12 @@ class WindowDatasetOp : public UnaryDatasetOpKernel {
}
private:
- class Dataset : public GraphDatasetBase {
+ class Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, int64 window_size, const DatasetBase* input)
- : GraphDatasetBase(ctx), window_size_(window_size), input_(input) {
+ : DatasetBase(DatasetContext(ctx)),
+ window_size_(window_size),
+ input_(input) {
input_->Ref();
}
diff --git a/tensorflow/core/kernels/data/writer_ops.cc b/tensorflow/core/kernels/data/writer_ops.cc
index 80d9a5b867..1c49874a6a 100644
--- a/tensorflow/core/kernels/data/writer_ops.cc
+++ b/tensorflow/core/kernels/data/writer_ops.cc
@@ -70,20 +70,21 @@ class ToTFRecordOp : public AsyncOpKernel {
DatasetBase* dataset;
OP_REQUIRES_OK_ASYNC(
ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset), done);
- IteratorContext iter_ctx = dataset::MakeIteratorContext(ctx);
std::unique_ptr<IteratorBase> iterator;
OP_REQUIRES_OK_ASYNC(
ctx,
- dataset->MakeIterator(&iter_ctx, "ToTFRecordOpIterator", &iterator),
+ dataset->MakeIterator(IteratorContext(ctx), "ToTFRecordOpIterator",
+ &iterator),
done);
std::vector<Tensor> components;
components.reserve(dataset->output_dtypes().size());
bool end_of_sequence;
do {
- OP_REQUIRES_OK_ASYNC(
- ctx, iterator->GetNext(&iter_ctx, &components, &end_of_sequence),
- done);
+ OP_REQUIRES_OK_ASYNC(ctx,
+ iterator->GetNext(IteratorContext(ctx),
+ &components, &end_of_sequence),
+ done);
if (!end_of_sequence) {
OP_REQUIRES_OK_ASYNC(
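
The op no longer caches a named `IteratorContext`; each call constructs one inline from the `OpKernelContext`. This matches `MakeIterator`/`GetNext` overloads that take the context by value, an assumption based only on the call sites in this hunk. A condensed sketch of the resulting pattern (function and iterator names are illustrative):

```c++
#include "tensorflow/core/framework/dataset.h"

namespace tensorflow {

// Sketch of the call pattern after this change; signatures are assumed from
// the call sites above (MakeIterator/GetNext accepting an IteratorContext by
// value rather than by pointer).
Status ConsumeAll(OpKernelContext* ctx, DatasetBase* dataset) {
  std::unique_ptr<IteratorBase> iterator;
  TF_RETURN_IF_ERROR(dataset->MakeIterator(IteratorContext(ctx),
                                           "ExampleIterator", &iterator));
  std::vector<Tensor> components;
  bool end_of_sequence = false;
  while (!end_of_sequence) {
    TF_RETURN_IF_ERROR(iterator->GetNext(IteratorContext(ctx), &components,
                                         &end_of_sequence));
    // ... consume `components` here ...
    components.clear();
  }
  return Status::OK();
}

}  // namespace tensorflow
```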
diff --git a/tensorflow/core/kernels/data/zip_dataset_op.cc b/tensorflow/core/kernels/data/zip_dataset_op.cc
index 63e9b99d4b..e4306579ed 100644
--- a/tensorflow/core/kernels/data/zip_dataset_op.cc
+++ b/tensorflow/core/kernels/data/zip_dataset_op.cc
@@ -38,11 +38,11 @@ class ZipDatasetOp : public DatasetOpKernel {
}
private:
- class Dataset : public GraphDatasetBase {
+ class Dataset : public DatasetBase {
public:
explicit Dataset(OpKernelContext* ctx,
const std::vector<DatasetBase*>& inputs)
- : GraphDatasetBase(ctx), inputs_(inputs) {
+ : DatasetBase(DatasetContext(ctx)), inputs_(inputs) {
for (const auto& input : inputs_) {
input->Ref();
for (DataType dt : input->output_dtypes()) {
diff --git a/tensorflow/core/kernels/host_constant_op.cc b/tensorflow/core/kernels/host_constant_op.cc
new file mode 100644
index 0000000000..d08a7c9bd2
--- /dev/null
+++ b/tensorflow/core/kernels/host_constant_op.cc
@@ -0,0 +1,78 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/host_constant_op.h"
+
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/macros.h"
+
+namespace tensorflow {
+
+_HostConstantOp::_HostConstantOp(OpKernelConstruction* ctx)
+ : OpKernel(ctx), tensor_(ctx->output_type(0)) {
+ const TensorProto* proto = nullptr;
+ AllocatorAttributes alloc_attr;
+ alloc_attr.set_on_host(true);
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("value", &proto));
+ OP_REQUIRES_OK(
+ ctx, ctx->device()->MakeTensorFromProto(*proto, alloc_attr, &tensor_));
+ OP_REQUIRES(
+ ctx, ctx->output_type(0) == tensor_.dtype(),
+ errors::InvalidArgument("Type mismatch between value (",
+ DataTypeString(tensor_.dtype()), ") and dtype (",
+ DataTypeString(ctx->output_type(0)), ")"));
+}
+
+void _HostConstantOp::Compute(OpKernelContext* ctx) {
+ ctx->set_output(0, tensor_);
+}
+
+#if GOOGLE_CUDA
+// A special GPU kernel for int32.
+// TODO(b/25387198): Also enable int32 in device memory. This kernel
+// registration requires all int32 inputs and outputs to be in host memory.
+REGISTER_KERNEL_BUILDER(Name("Const")
+ .Device(DEVICE_GPU)
+ .HostMemory("output")
+ .TypeConstraint<int32>("dtype"),
+ _HostConstantOp);
+#endif
+
+#ifdef TENSORFLOW_USE_SYCL
+REGISTER_KERNEL_BUILDER(Name("Const")
+ .Device(DEVICE_SYCL)
+ .HostMemory("output")
+ .TypeConstraint<int32>("dtype"),
+ _HostConstantOp);
+#endif // TENSORFLOW_USE_SYCL
+
+// HostConst: forced to generate output on the host.
+// Only used in tests; no op is registered for this kernel
+// externally (i.e., in array_ops.cc)
+REGISTER_KERNEL_BUILDER(Name("HostConst").Device(DEVICE_CPU), _HostConstantOp);
+REGISTER_KERNEL_BUILDER(
+ Name("HostConst").Device(DEVICE_GPU).HostMemory("output"), _HostConstantOp);
+#ifdef TENSORFLOW_USE_SYCL
+REGISTER_KERNEL_BUILDER(
+ Name("HostConst").Device(DEVICE_SYCL).HostMemory("output"),
+ _HostConstantOp);
+#endif // TENSORFLOW_USE_SYCL
+
+} // end namespace tensorflow
+
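The key detail in this new kernel is the allocator attribute: by calling `set_on_host(true)` before `MakeTensorFromProto`, the constant is materialized in host memory even when the kernel is registered for `DEVICE_GPU` or `DEVICE_SYCL`, which is what the int32 `Const` registration above relies on. Condensed from the constructor as a sketch:

```c++
// Condensed sketch of the host-pinning pattern used by _HostConstantOp:
// request a host allocation regardless of the kernel's device.
AllocatorAttributes alloc_attr;
alloc_attr.set_on_host(true);
OP_REQUIRES_OK(ctx, ctx->device()->MakeTensorFromProto(*proto, alloc_attr,
                                                       &tensor_));
```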
diff --git a/tensorflow/core/kernels/host_constant_op.h b/tensorflow/core/kernels/host_constant_op.h
new file mode 100644
index 0000000000..1b887ea1aa
--- /dev/null
+++ b/tensorflow/core/kernels/host_constant_op.h
@@ -0,0 +1,42 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_KERNELS_HOST_CONSTANT_OP_H_
+#define TENSORFLOW_CORE_KERNELS_HOST_CONSTANT_OP_H_
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/platform/macros.h"
+
+namespace tensorflow {
+
+// HostConstantOp differs from ConstantOp in that its output is always
+// in host memory.
+class _HostConstantOp : public OpKernel {
+ public:
+ explicit _HostConstantOp(OpKernelConstruction* ctx);
+ void Compute(OpKernelContext* ctx) override;
+ bool IsExpensive() override { return false; }
+ ~_HostConstantOp() override {}
+
+ private:
+ Tensor tensor_;
+ TF_DISALLOW_COPY_AND_ASSIGN(_HostConstantOp);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_KERNELS_HOST_CONSTANT_OP_H_
diff --git a/tensorflow/core/kernels/lookup_table_op.cc b/tensorflow/core/kernels/lookup_table_op.cc
index 07e754a6ef..2e8d9c623c 100644
--- a/tensorflow/core/kernels/lookup_table_op.cc
+++ b/tensorflow/core/kernels/lookup_table_op.cc
@@ -341,7 +341,7 @@ class MutableDenseHashTable final : public LookupInterface {
Status Find(OpKernelContext* ctx, const Tensor& key, Tensor* value,
const Tensor& default_value) override LOCKS_EXCLUDED(mu_) {
- const int64 num_elements = key.dim_size(0);
+ const int64 num_elements = (key.dims() == 0) ? 1 : key.dim_size(0);
const int64 key_size = key_shape_.num_elements();
const int64 value_size = value_shape_.num_elements();
if (key.NumElements() != num_elements * key_size) {
@@ -403,8 +403,9 @@ class MutableDenseHashTable final : public LookupInterface {
Status Insert(OpKernelContext* ctx, const Tensor& key,
const Tensor& value) override LOCKS_EXCLUDED(mu_) {
- if (key.NumElements() != key.dim_size(0) * key_shape_.num_elements()) {
- TensorShape expected_shape({key.dim_size(0)});
+ const int64 batch_size = (key.dims() == 0) ? 1 : key.dim_size(0);
+ if (key.NumElements() != batch_size * key_shape_.num_elements()) {
+ TensorShape expected_shape({batch_size});
expected_shape.AppendShape(key_shape_);
return errors::InvalidArgument("Expected key shape ",
expected_shape.DebugString(), " got ",
@@ -415,7 +416,7 @@ class MutableDenseHashTable final : public LookupInterface {
// rather than updates. That means we may grow the table even though we
// don't need to. As long as the number of keys inserted in one call is
// small compared to the size of the map, the impact of this is minimal.
- const int64 pending_num_entries = num_entries_ + key.dim_size(0);
+ const int64 pending_num_entries = num_entries_ + batch_size;
if (pending_num_entries > num_buckets_ * max_load_factor_) {
int64 new_num_buckets = num_buckets_;
do {
@@ -500,7 +501,7 @@ class MutableDenseHashTable final : public LookupInterface {
private:
Status DoInsert(OpKernelContext* ctx, const Tensor& key, const Tensor& value,
bool ignore_empty_key) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- const int64 num_elements = key.dim_size(0);
+ const int64 num_elements = (key.dims() == 0) ? 1 : key.dim_size(0);
const int64 value_size = value_shape_.num_elements();
const int64 key_size = key_shape_.num_elements();
const auto key_matrix = key.shaped<K, 2>({num_elements, key_size});
@@ -812,17 +813,21 @@ REGISTER_KERNEL_BUILDER(Name("LookupTableImportV2").Device(DEVICE_CPU),
LookupTableOp<lookup::HashTable<key_dtype, value_dtype>, key_dtype, \
value_dtype>)
+REGISTER_KERNEL(int32, double);
+REGISTER_KERNEL(int32, float);
+REGISTER_KERNEL(int32, int32);
+REGISTER_KERNEL(int32, string);
+REGISTER_KERNEL(int64, double);
+REGISTER_KERNEL(int64, float);
+REGISTER_KERNEL(int64, int32);
+REGISTER_KERNEL(int64, int64);
+REGISTER_KERNEL(int64, string);
+REGISTER_KERNEL(string, bool);
REGISTER_KERNEL(string, double);
REGISTER_KERNEL(string, float);
REGISTER_KERNEL(string, int32);
REGISTER_KERNEL(string, int64);
-REGISTER_KERNEL(int64, string);
-REGISTER_KERNEL(int64, int64);
-REGISTER_KERNEL(int64, float);
REGISTER_KERNEL(string, string);
-REGISTER_KERNEL(string, bool);
-REGISTER_KERNEL(int32, int32);
-REGISTER_KERNEL(int32, string);
#undef REGISTER_KERNEL
@@ -843,12 +848,20 @@ REGISTER_KERNEL(int32, string);
LookupTableOp<lookup::MutableHashTableOfScalars<key_dtype, value_dtype>, \
key_dtype, value_dtype>)
-REGISTER_KERNEL(string, float);
-REGISTER_KERNEL(string, int64);
-REGISTER_KERNEL(int64, string);
-REGISTER_KERNEL(string, bool);
+REGISTER_KERNEL(int32, double);
+REGISTER_KERNEL(int32, float);
+REGISTER_KERNEL(int32, int32);
+REGISTER_KERNEL(int64, double);
REGISTER_KERNEL(int64, float);
+REGISTER_KERNEL(int64, int32);
+REGISTER_KERNEL(int64, int64);
+REGISTER_KERNEL(int64, string);
REGISTER_KERNEL(int64, Variant);
+REGISTER_KERNEL(string, bool);
+REGISTER_KERNEL(string, double);
+REGISTER_KERNEL(string, float);
+REGISTER_KERNEL(string, int32);
+REGISTER_KERNEL(string, int64);
#undef REGISTER_KERNEL
@@ -869,10 +882,19 @@ REGISTER_KERNEL(int64, Variant);
LookupTableOp<lookup::MutableHashTableOfTensors<key_dtype, value_dtype>, \
key_dtype, value_dtype>)
-REGISTER_KERNEL(string, float);
-REGISTER_KERNEL(string, int64);
+REGISTER_KERNEL(int32, double);
+REGISTER_KERNEL(int32, float);
+REGISTER_KERNEL(int32, int32);
+REGISTER_KERNEL(int64, double);
+REGISTER_KERNEL(int64, float);
+REGISTER_KERNEL(int64, int32);
+REGISTER_KERNEL(int64, int64);
REGISTER_KERNEL(int64, string);
REGISTER_KERNEL(string, bool);
+REGISTER_KERNEL(string, double);
+REGISTER_KERNEL(string, float);
+REGISTER_KERNEL(string, int32);
+REGISTER_KERNEL(string, int64);
#undef REGISTER_KERNEL
@@ -893,13 +915,20 @@ REGISTER_KERNEL(string, bool);
LookupTableOp<lookup::MutableDenseHashTable<key_dtype, value_dtype>, \
key_dtype, value_dtype>)
-REGISTER_KERNEL(int64, int64);
-REGISTER_KERNEL(int64, float);
-REGISTER_KERNEL(int64, double);
-REGISTER_KERNEL(string, float);
-REGISTER_KERNEL(string, bool);
+REGISTER_KERNEL(int32, double);
+REGISTER_KERNEL(int32, float);
+REGISTER_KERNEL(int32, int32);
REGISTER_KERNEL(int64, bool);
+REGISTER_KERNEL(int64, double);
+REGISTER_KERNEL(int64, float);
+REGISTER_KERNEL(int64, int32);
+REGISTER_KERNEL(int64, int64);
REGISTER_KERNEL(int64, Variant);
+REGISTER_KERNEL(string, bool);
+REGISTER_KERNEL(string, double);
+REGISTER_KERNEL(string, float);
+REGISTER_KERNEL(string, int32);
+REGISTER_KERNEL(string, int64);
#undef REGISTER_KERNEL
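
The `Find`/`Insert`/`DoInsert` hunks above all make the same change: `MutableDenseHashTable` now accepts rank-0 (scalar) keys by treating them as a batch of one, instead of reading a nonexistent leading dimension. The guard, extracted as a sketch:

```c++
#include "tensorflow/core/framework/tensor.h"

// Sketch: a scalar key is a batch of one; otherwise the batch size is the
// leading dimension, exactly as in the hunks above.
tensorflow::int64 KeyBatchSize(const tensorflow::Tensor& key) {
  return (key.dims() == 0) ? 1 : key.dim_size(0);
}
```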
diff --git a/tensorflow/core/kernels/non_max_suppression_op.cc b/tensorflow/core/kernels/non_max_suppression_op.cc
index c7d0d4de0d..5d9257e20b 100644
--- a/tensorflow/core/kernels/non_max_suppression_op.cc
+++ b/tensorflow/core/kernels/non_max_suppression_op.cc
@@ -126,7 +126,7 @@ void DoNonMaxSuppressionOp(
const Tensor& max_output_size, const float score_threshold,
const std::function<bool(int, int)>& suppress_check_fn,
bool pad_to_max_output_size = false, int* ptr_num_valid_outputs = nullptr) {
- const int output_size = std::min(max_output_size.scalar<int>()(), num_boxes);
+ const int output_size = max_output_size.scalar<int>()();
std::vector<float> scores_data(num_boxes);
std::copy_n(scores.flat<float>().data(), num_boxes, scores_data.begin());
diff --git a/tensorflow/core/kernels/resource_variable_ops.cc b/tensorflow/core/kernels/resource_variable_ops.cc
index 115a8eb251..ebcfb673d1 100644
--- a/tensorflow/core/kernels/resource_variable_ops.cc
+++ b/tensorflow/core/kernels/resource_variable_ops.cc
@@ -232,12 +232,12 @@ class AssignVariableOp : public OpKernel {
return Status::OK();
}));
core::ScopedUnref s(variable);
+ mutex_lock ml(*variable->mu());
OP_REQUIRES(context, variable->tensor()->dtype() == dtype_,
errors::InvalidArgument(
"Trying to assign variable with wrong dtype. Expected ",
DataTypeString(variable->tensor()->dtype()), " got ",
DataTypeString(dtype_)));
- mutex_lock ml(*variable->mu());
variable->is_initialized = true;
*variable->tensor() = value;
}
@@ -268,11 +268,6 @@ class AssignVariableOp<Device, Variant> : public OpKernel {
return Status::OK();
}));
core::ScopedUnref s(variable);
- OP_REQUIRES(context, variable->tensor()->dtype() == DT_VARIANT,
- errors::InvalidArgument(
- "Trying to assign variable with wrong dtype. Expected ",
- DataTypeString(variable->tensor()->dtype()), " got ",
- DataTypeString(DT_VARIANT)));
// For purposes of forwarding DT_VARIANT, we want the least
// restrictive attr; we already know the input is on host.
@@ -293,6 +288,11 @@ class AssignVariableOp<Device, Variant> : public OpKernel {
attr);
mutex_lock ml(*variable->mu());
+ OP_REQUIRES(context, variable->tensor()->dtype() == DT_VARIANT,
+ errors::InvalidArgument(
+ "Trying to assign variable with wrong dtype. Expected ",
+ DataTypeString(variable->tensor()->dtype()), " got ",
+ DataTypeString(DT_VARIANT)));
variable->is_initialized = true;
*variable->tensor() = Tensor(DT_VARIANT, value.shape());
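
Both hunks fix the same ordering problem: the dtype check used to run before the variable's mutex was acquired, so another assignment could change the tensor between the check and the write. Moving the `mutex_lock` (or, in the `Variant` specialization, the check itself) puts validation and assignment in one critical section. A condensed sketch, where `expected_dtype` stands in for the kernel's `dtype_` or `DT_VARIANT`:

```c++
// Sketch of the corrected ordering: lock first, then validate, then write,
// all under the same critical section.
mutex_lock ml(*variable->mu());
OP_REQUIRES(context, variable->tensor()->dtype() == expected_dtype,
            errors::InvalidArgument("Trying to assign variable with wrong dtype."));
variable->is_initialized = true;
*variable->tensor() = value;
```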
diff --git a/tensorflow/core/kernels/shape_ops.h b/tensorflow/core/kernels/shape_ops.h
index 55be308901..f75723af7d 100644
--- a/tensorflow/core/kernels/shape_ops.h
+++ b/tensorflow/core/kernels/shape_ops.h
@@ -154,6 +154,9 @@ class ExpandDimsOp : public OpKernel {
OP_REQUIRES(ctx, ctx->input(0).dtype() != DT_VARIANT,
errors::InvalidArgument("ExpandDims on Variant not supported"));
+ OP_REQUIRES(
+ ctx, (ctx->input(1).NumElements() == 1),
+ errors::InvalidArgument("'dim' must be a tensor with a single value"));
Tdim dim = ctx->input(1).flat<Tdim>()(0);
OP_REQUIRES(
ctx, (dim >= -1 - ctx->input(0).dims() && dim <= ctx->input(0).dims()),
@@ -236,9 +239,8 @@ class SqueezeOp : public OpKernel {
if (wrapped_squeeze_dims.count(i) > 0) {
OP_REQUIRES(ctx, existing_dim == 1,
errors::InvalidArgument(
- "Tried to explicitly squeeze "
- "dimension ",
- i, " but dimension was not 1: ", existing_dim));
+ "Can not squeeze dim[", i,
+ "], expected a dimension of 1, got ", existing_dim));
} else {
// This dimension is not being squeezed.
new_shape.push_back(existing_dim);
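
The new `ExpandDims` check exists because the very next statement reads element 0 of the `dim` input via `flat<Tdim>()(0)`; without it, a multi-element `dim` would be silently truncated and an empty one read out of bounds. Combined with the existing range check, a rank-`r` input accepts `dim` values in `[-r - 1, r]`; for a rank-3 input, `dim` may be anywhere in `[-4, 3]`. A condensed sketch of the validation sequence:

```c++
// Condensed sketch of the ExpandDims 'dim' validation above.
OP_REQUIRES(ctx, ctx->input(1).NumElements() == 1,
            errors::InvalidArgument("'dim' must be a tensor with a single value"));
Tdim dim = ctx->input(1).flat<Tdim>()(0);  // safe: exactly one element
OP_REQUIRES(ctx,
            dim >= -1 - ctx->input(0).dims() && dim <= ctx->input(0).dims(),
            errors::InvalidArgument("'dim' is out of range"));
```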
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
index ef8ad7972c..1d11ec00ce 100644
--- a/tensorflow/core/ops/array_ops.cc
+++ b/tensorflow/core/ops/array_ops.cc
@@ -427,7 +427,19 @@ REGISTER_OP("UnravelIndex")
.Input("dims: Tidx")
.Output("output: Tidx")
.Attr("Tidx: {int32, int64} = DT_INT32")
- .SetShapeFn([](InferenceContext* c) { return Status::OK(); });
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle indices = c->input(0);
+ ShapeHandle dims;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &dims));
+ if (c->RankKnown(indices) && c->Rank(indices) == 0) {
+ c->set_output(0, c->Vector(c->Dim(dims, 0)));
+ } else if (c->RankKnown(indices)) {
+ c->set_output(0, c->Matrix(c->Dim(dims, 0), c->NumElements(indices)));
+ } else {
+ c->set_output(0, c->UnknownShape());
+ }
+ return Status::OK();
+ });
REGISTER_OP("BroadcastTo")
.Input("input: T")
@@ -690,6 +702,16 @@ REGISTER_OP("Const")
return Status::OK();
});
+// Returns a constant tensor on the host. Useful for writing C++ tests
+// and benchmarks which run on GPU but require arguments pinned to the host.
+// Used by test::graph::HostConstant.
+// value: Attr `value` is the tensor to return.
+REGISTER_OP("HostConst")
+ .Output("output: dtype")
+ .Attr("value: tensor")
+ .Attr("dtype: type")
+ .SetShapeFn(shape_inference::UnknownShape);
+
// --------------------------------------------------------------------------
// TODO(mgubin): Update the doc when the freeze_graph script supports converting
// into memmapped format.
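
The new `UnravelIndex` shape function mirrors NumPy's `unravel_index`: a scalar `indices` produces one coordinate vector of length `dims[0]` (the rank of the target shape), while a non-scalar `indices` with `n` elements produces a `dims[0] x n` matrix. For example (values illustrative): `indices = [22, 41, 37]` with `dims = [7, 6]` infers an output shape of `[2, 3]`. The `INFER_OK` cases in the test hunk below exercise exactly these rules:

```c++
// Worked instances of the new UnravelIndex shape rules (cf. the test below):
//   indices: []      dims: [?]  ->  output: [d1_0]      (one coordinate vector)
//   indices: [4, 5]  dims: [?]  ->  output: [d1_0, 20]  (20 = NumElements(indices))
//   indices: [?]     dims: [?]  ->  output: [d1_0, ?]
//   indices: ?       dims: [?]  ->  output: ?           (unknown rank passes through)
```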
diff --git a/tensorflow/core/ops/array_ops_test.cc b/tensorflow/core/ops/array_ops_test.cc
index b1463338fb..c15409a246 100644
--- a/tensorflow/core/ops/array_ops_test.cc
+++ b/tensorflow/core/ops/array_ops_test.cc
@@ -27,6 +27,21 @@ limitations under the License.
namespace tensorflow {
+TEST(ArrayOpsTest, UnravelIndex_ShapeFn) {
+ ShapeInferenceTestOp op("UnravelIndex");
+
+ INFER_OK(op, "?;?", "?");
+
+ INFER_OK(op, "[];[?]", "[d1_0]");
+
+ INFER_OK(op, "[4,5];[?]", "[d1_0,20]");
+ INFER_OK(op, "[2,3,4];[?]", "[d1_0,24]");
+ INFER_OK(op, "?;[?]", "?");
+ INFER_OK(op, "[?];[?]", "[d1_0,?]");
+
+ INFER_ERROR("Shape must be rank 1 but is rank 2", op, "?;[1,1]");
+}
+
TEST(ArrayOpsTest, Pack_ShapeFn) {
ShapeInferenceTestOp op("Pack");
auto set_axis = [&op](int axis) {
diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
index 44dddffd59..92ccbd979d 100644
--- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
@@ -25613,6 +25613,21 @@ op {
}
}
op {
+ name: "HostConst"
+ output_arg {
+ name: "output"
+ type_attr: "dtype"
+ }
+ attr {
+ name: "value"
+ type: "tensor"
+ }
+ attr {
+ name: "dtype"
+ type: "type"
+ }
+}
+op {
name: "IFFT"
input_arg {
name: "input"
diff --git a/tensorflow/core/ops/image_ops.cc b/tensorflow/core/ops/image_ops.cc
index 81f324a3ef..11ca0bd259 100644
--- a/tensorflow/core/ops/image_ops.cc
+++ b/tensorflow/core/ops/image_ops.cc
@@ -108,6 +108,29 @@ Status ColorspaceShapeFn(InferenceContext* c) {
return Status::OK();
}
+Status NMSShapeFn(InferenceContext* c) {
+ // Get inputs and validate ranks.
+ ShapeHandle boxes;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &boxes));
+ ShapeHandle scores;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &scores));
+ ShapeHandle max_output_size;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &max_output_size));
+ ShapeHandle iou_threshold;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &iou_threshold));
+ ShapeHandle score_threshold;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &score_threshold));
+ // boxes is a 2-D float Tensor of shape [num_boxes, 4].
+ DimensionHandle unused;
+ // The first dimensions of boxes and scores are both num_boxes.
+ TF_RETURN_IF_ERROR(c->Merge(c->Dim(boxes, 0), c->Dim(scores, 0), &unused));
+ // The second dimension of boxes is 4.
+ TF_RETURN_IF_ERROR(c->WithValue(c->Dim(boxes, 1), 4, &unused));
+
+ c->set_output(0, c->Vector(c->UnknownDim()));
+ return Status::OK();
+}
+
} // namespace
// --------------------------------------------------------------------------
@@ -694,29 +717,7 @@ REGISTER_OP("NonMaxSuppressionV3")
.Input("iou_threshold: float")
.Input("score_threshold: float")
.Output("selected_indices: int32")
- .SetShapeFn([](InferenceContext* c) {
- // Get inputs and validate ranks.
- ShapeHandle boxes;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &boxes));
- ShapeHandle scores;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &scores));
- ShapeHandle max_output_size;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &max_output_size));
- ShapeHandle iou_threshold;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &iou_threshold));
- ShapeHandle score_threshold;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &score_threshold));
- // The boxes is a 2-D float Tensor of shape [num_boxes, 4].
- DimensionHandle unused;
- // The boxes[0] and scores[0] are both num_boxes.
- TF_RETURN_IF_ERROR(
- c->Merge(c->Dim(boxes, 0), c->Dim(scores, 0), &unused));
- // The boxes[1] is 4.
- TF_RETURN_IF_ERROR(c->WithValue(c->Dim(boxes, 1), 4, &unused));
-
- c->set_output(0, c->Vector(c->UnknownDim()));
- return Status::OK();
- });
+ .SetShapeFn(NMSShapeFn);
REGISTER_OP("NonMaxSuppressionV4")
.Input("boxes: float")
@@ -728,26 +729,16 @@ REGISTER_OP("NonMaxSuppressionV4")
.Output("valid_outputs: int32")
.Attr("pad_to_max_output_size: bool = false")
.SetShapeFn([](InferenceContext* c) {
- // Get inputs and validate ranks.
- ShapeHandle boxes;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &boxes));
- ShapeHandle scores;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &scores));
- ShapeHandle max_output_size;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &max_output_size));
- ShapeHandle iou_threshold;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &iou_threshold));
- ShapeHandle score_threshold;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &score_threshold));
- // The boxes is a 2-D float Tensor of shape [num_boxes, 4].
- DimensionHandle unused;
- // The boxes[0] and scores[0] are both num_boxes.
- TF_RETURN_IF_ERROR(
- c->Merge(c->Dim(boxes, 0), c->Dim(scores, 0), &unused));
- // The boxes[1] is 4.
- TF_RETURN_IF_ERROR(c->WithValue(c->Dim(boxes, 1), 4, &unused));
-
- c->set_output(0, c->Vector(c->UnknownDim()));
+ TF_RETURN_IF_ERROR(NMSShapeFn(c));
+
+ bool pad_to_max;
+ TF_RETURN_IF_ERROR(c->GetAttr("pad_to_max_output_size", &pad_to_max));
+ if (pad_to_max) {
+ // If padded, overwrite the shape of the output to be static.
+ DimensionHandle output_dim;
+ TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(2, &output_dim));
+ c->set_output(0, c->MakeShape({output_dim}));
+ }
c->set_output(1, c->MakeShape({}));
return Status::OK();
});
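
With the shared validation factored into `NMSShapeFn`, the `NonMaxSuppressionV4` shape function adds only the padding rule: when `pad_to_max_output_size` is true, output 0 is pinned to the (possibly constant) scalar `max_output_size` input via `MakeDimForScalarInput`, giving a static padded shape, and output 1, the count of valid entries, is always a scalar. An illustration under hypothetical values:

```c++
// Illustration (values hypothetical): with pad_to_max_output_size = true and
// a constant max_output_size of 100, shape inference yields
//   selected_indices : [100]  (statically padded)
//   valid_outputs    : []     (scalar count of valid entries)
// With pad_to_max_output_size = false, selected_indices remains [?].
```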
diff --git a/tensorflow/core/ops/lookup_ops.cc b/tensorflow/core/ops/lookup_ops.cc
index 2059741da9..7c71406c6b 100644
--- a/tensorflow/core/ops/lookup_ops.cc
+++ b/tensorflow/core/ops/lookup_ops.cc
@@ -23,6 +23,7 @@ namespace tensorflow {
using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
+using shape_inference::ShapeAndType;
using shape_inference::ShapeHandle;
// --------------------------------------------------------------------------
@@ -86,6 +87,74 @@ REGISTER_OP("LookupTableFind")
return Status::OK();
});
+Status ValidateTableResourceHandle(InferenceContext* c, ShapeHandle keys,
+ const string& key_dtype_attr,
+ const string& value_dtype_attr,
+ bool is_lookup,
+ ShapeAndType* output_shape_and_type) {
+ auto* handle_data = c->input_handle_shapes_and_types(0);
+ if (handle_data == nullptr || handle_data->size() != 2) {
+ output_shape_and_type->shape = c->UnknownShape();
+ output_shape_and_type->dtype = DT_INVALID;
+ } else {
+ const ShapeAndType& key_shape_and_type = (*handle_data)[0];
+ const ShapeAndType& value_shape_and_type = (*handle_data)[1];
+ DataType key_dtype;
+ TF_RETURN_IF_ERROR(c->GetAttr(key_dtype_attr, &key_dtype));
+ if (key_shape_and_type.dtype != key_dtype) {
+ return errors::InvalidArgument(
+ "Trying to read value with wrong dtype. "
+ "Expected ",
+ DataTypeString(key_shape_and_type.dtype), " got ",
+ DataTypeString(key_dtype));
+ }
+ DataType value_dtype;
+ TF_RETURN_IF_ERROR(c->GetAttr(value_dtype_attr, &value_dtype));
+ if (value_shape_and_type.dtype != value_dtype) {
+ return errors::InvalidArgument(
+ "Trying to read value with wrong dtype. "
+ "Expected ",
+ DataTypeString(value_shape_and_type.dtype), " got ",
+ DataTypeString(value_dtype));
+ }
+ output_shape_and_type->dtype = value_shape_and_type.dtype;
+
+ if (is_lookup) {
+ if (c->RankKnown(key_shape_and_type.shape) && c->RankKnown(keys)) {
+ int keys_rank = c->Rank(keys);
+ int key_suffix_rank = c->Rank(key_shape_and_type.shape);
+ if (keys_rank < key_suffix_rank) {
+ return errors::InvalidArgument(
+ "Expected keys to have suffix ",
+ c->DebugString(key_shape_and_type.shape),
+ " but saw shape: ", c->DebugString(keys));
+ }
+ for (int d = 0; d < key_suffix_rank; d++) {
+ // Ensure the suffix of keys match what's in the Table.
+ DimensionHandle dim = c->Dim(key_shape_and_type.shape, d);
+ TF_RETURN_IF_ERROR(
+ c->ReplaceDim(keys, keys_rank - key_suffix_rank + d, dim, &keys));
+ }
+ std::vector<DimensionHandle> keys_prefix_vec;
+ keys_prefix_vec.reserve(keys_rank - key_suffix_rank);
+ for (int d = 0; d < keys_rank - key_suffix_rank; ++d) {
+ keys_prefix_vec.push_back(c->Dim(keys, d));
+ }
+ ShapeHandle keys_prefix = c->MakeShape(keys_prefix_vec);
+ TF_RETURN_IF_ERROR(c->Concatenate(keys_prefix,
+ value_shape_and_type.shape,
+ &output_shape_and_type->shape));
+ } else {
+ output_shape_and_type->shape = c->UnknownShape();
+ }
+ } else {
+ TF_RETURN_IF_ERROR(c->Concatenate(keys, value_shape_and_type.shape,
+ &output_shape_and_type->shape));
+ }
+ }
+ return Status::OK();
+}
+
REGISTER_OP("LookupTableFindV2")
.Input("table_handle: resource")
.Input("keys: Tin")
@@ -98,9 +167,18 @@ REGISTER_OP("LookupTableFindV2")
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
// Default value must be scalar or vector.
- ShapeHandle unused;
- TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(2), 1, &unused));
- c->set_output(0, c->UnknownShape());
+ ShapeHandle keys;
+ TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(2), 1, &keys));
+
+ ShapeAndType value_shape_and_type;
+ TF_RETURN_IF_ERROR(ValidateTableResourceHandle(
+ c,
+ /*keys=*/c->input(1),
+ /*key_dtype_attr=*/"Tin",
+ /*value_dtype_attr=*/"Tout",
+ /*is_lookup=*/true, &value_shape_and_type));
+ c->set_output(0, value_shape_and_type.shape);
+
return Status::OK();
});
WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS("LookupTableFindV2");
@@ -177,12 +255,16 @@ REGISTER_OP("LookupTableExportV2")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle handle;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
-
- ShapeHandle values = c->UnknownShape();
- TF_RETURN_IF_ERROR(c->WithRankAtLeast(values, 1, &values));
- ShapeHandle keys = c->Vector(c->Dim(values, 0));
+ ShapeHandle keys = c->UnknownShapeOfRank(1);
+ ShapeAndType value_shape_and_type;
+ TF_RETURN_IF_ERROR(ValidateTableResourceHandle(
+ c,
+ /*keys=*/keys,
+ /*key_dtype_attr=*/"Tkeys",
+ /*value_dtype_attr=*/"Tvalues",
+ /*is_lookup=*/false, &value_shape_and_type));
c->set_output(0, keys);
- c->set_output(1, values);
+ c->set_output(1, value_shape_and_type.shape);
return Status::OK();
});
@@ -216,6 +298,26 @@ REGISTER_OP("LookupTableImportV2")
return Status::OK();
});
+Status MutableHashTableShape(InferenceContext* c, const ShapeHandle& key,
+ const ShapeHandle& value) {
+ c->set_output(0, c->Scalar());
+
+ ShapeHandle key_s;
+ TF_RETURN_IF_ERROR(c->WithRankAtMost(key, 1, &key_s));
+
+ DataType key_t;
+ TF_RETURN_IF_ERROR(c->GetAttr("key_dtype", &key_t));
+
+ DataType value_t;
+ TF_RETURN_IF_ERROR(c->GetAttr("value_dtype", &value_t));
+
+ // ShapeAndType vector for {key, value}.
+ c->set_output_handle_shapes_and_types(
+ 0, std::vector<ShapeAndType>{{key_s, key_t}, {value, value_t}});
+
+ return Status::OK();
+}
+
REGISTER_OP("HashTable")
.Output("table_handle: Ref(string)")
.Attr("container: string = ''")
@@ -254,7 +356,10 @@ REGISTER_OP("MutableHashTableV2")
.Attr("key_dtype: type")
.Attr("value_dtype: type")
.SetIsStateful()
- .SetShapeFn(ScalarOutput);
+ .SetShapeFn([](InferenceContext* c) {
+ return MutableHashTableShape(c, /*key=*/c->Scalar(),
+ /*value=*/c->Scalar());
+ });
REGISTER_OP("MutableHashTableOfTensors")
.Output("table_handle: Ref(string)")
@@ -276,7 +381,13 @@ REGISTER_OP("MutableHashTableOfTensorsV2")
.Attr("value_dtype: type")
.Attr("value_shape: shape = {}")
.SetIsStateful()
- .SetShapeFn(ScalarOutput);
+ .SetShapeFn([](InferenceContext* c) {
+ PartialTensorShape value_p;
+ TF_RETURN_IF_ERROR(c->GetAttr("value_shape", &value_p));
+ ShapeHandle value_s;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(value_p, &value_s));
+ return MutableHashTableShape(c, /*key=*/c->Scalar(), /*value=*/value_s);
+ });
REGISTER_OP("MutableDenseHashTable")
.Input("empty_key: key_dtype")
@@ -304,7 +415,13 @@ REGISTER_OP("MutableDenseHashTableV2")
.Attr("initial_num_buckets: int = 131072") // 2^17
.Attr("max_load_factor: float = 0.8")
.SetIsStateful()
- .SetShapeFn(ScalarOutput);
+ .SetShapeFn([](InferenceContext* c) {
+ PartialTensorShape value_p;
+ TF_RETURN_IF_ERROR(c->GetAttr("value_shape", &value_p));
+ ShapeHandle value_s;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(value_p, &value_s));
+ return MutableHashTableShape(c, /*key=*/c->input(0), /*value=*/value_s);
+ });
REGISTER_OP("InitializeTable")
.Input("table_handle: Ref(string)")
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index 1fda569b8e..eda82f9c18 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -12257,6 +12257,21 @@ op {
}
}
op {
+ name: "HostConst"
+ output_arg {
+ name: "output"
+ type_attr: "dtype"
+ }
+ attr {
+ name: "value"
+ type: "tensor"
+ }
+ attr {
+ name: "dtype"
+ type: "type"
+ }
+}
+op {
name: "IFFT"
input_arg {
name: "input"
diff --git a/tensorflow/core/platform/default/build_config.bzl b/tensorflow/core/platform/default/build_config.bzl
index 28891320c4..6383180e94 100644
--- a/tensorflow/core/platform/default/build_config.bzl
+++ b/tensorflow/core/platform/default/build_config.bzl
@@ -8,7 +8,7 @@ load("//tensorflow/core:platform/default/build_config_root.bzl", "if_static")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
load(
"//third_party/mkl:build_defs.bzl",
- "if_mkl",
+ "if_mkl_ml",
)
# Appends a suffix to a list of deps.
@@ -467,7 +467,6 @@ def tf_platform_srcs(files):
return select({
"//tensorflow:windows" : native.glob(windows_set),
- "//tensorflow:windows_msvc" : native.glob(windows_set),
"//conditions:default" : native.glob(posix_set),
})
@@ -479,7 +478,6 @@ def tf_additional_lib_hdrs(exclude = []):
], exclude = exclude)
return select({
"//tensorflow:windows" : windows_hdrs,
- "//tensorflow:windows_msvc" : windows_hdrs,
"//conditions:default" : native.glob([
"platform/default/*.h",
"platform/posix/*.h",
@@ -494,7 +492,6 @@ def tf_additional_lib_srcs(exclude = []):
], exclude = exclude)
return select({
"//tensorflow:windows" : windows_srcs,
- "//tensorflow:windows_msvc" : windows_srcs,
"//conditions:default" : native.glob([
"platform/default/*.cc",
"platform/posix/*.cc",
@@ -703,8 +700,8 @@ def tf_additional_binary_deps():
# core).
"//tensorflow/core/kernels:lookup_util",
"//tensorflow/core/util/tensor_bundle",
- ] + if_mkl(
+ ] + if_mkl_ml(
[
- "//third_party/mkl:intel_binary_blob",
+ "//third_party/intel_mkl_ml",
],
)
diff --git a/tensorflow/docs_src/community/index.md b/tensorflow/docs_src/community/index.md
index 0aa8e7612a..865a203bf8 100644
--- a/tensorflow/docs_src/community/index.md
+++ b/tensorflow/docs_src/community/index.md
@@ -25,10 +25,10 @@ the appropriate repository for the project. Major repositories include:
### Security
-Before using TensorFlow, please take a look at our security model, list of
-recent security announcements, and ways you can report security issues to the
-TensorFlow team at the
-[Using TensorFlow Securely](https://github.com/tensorflow/tensorflow/blob/master/SECURITY.md) page on GitHub.
+Before using TensorFlow, please take a look at our [security model](https://github.com/tensorflow/tensorflow/blob/master/SECURITY.md#tensorflow-models-are-programs),
+[list of recent security advisories and announcements](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/security/index.md),
+and [ways you can report security issues](https://github.com/tensorflow/tensorflow/blob/master/SECURITY.md#reporting-vulnerabilities)
+to the TensorFlow team at the [Using TensorFlow Securely](https://github.com/tensorflow/tensorflow/blob/master/SECURITY.md) page on GitHub.
## Stay Informed
diff --git a/tensorflow/docs_src/guide/eager.md b/tensorflow/docs_src/guide/eager.md
index 017fdaf81e..e47a8b599c 100644
--- a/tensorflow/docs_src/guide/eager.md
+++ b/tensorflow/docs_src/guide/eager.md
@@ -193,7 +193,8 @@ class MNISTModel(tf.keras.Model):
def call(self, input):
"""Run the model."""
result = self.dense1(input)
- result = self.dense2(result) # reuse variables from dense1 layer
+ result = self.dense2(result)
+ result = self.dense2(result) # reuse variables from dense2 layer
return result
model = MNISTModel()
@@ -567,9 +568,8 @@ inserted during model construction. For example, to record summaries once every
100 global steps:
```py
+global_step = tf.train.get_or_create_global_step()
writer = tf.contrib.summary.create_file_writer(logdir)
-global_step=tf.train.get_or_create_global_step() # return global step var
-
writer.set_as_default()
for _ in range(iterations):
diff --git a/tensorflow/docs_src/install/install_sources_windows.md b/tensorflow/docs_src/install/install_sources_windows.md
new file mode 100644
index 0000000000..a1da122317
--- /dev/null
+++ b/tensorflow/docs_src/install/install_sources_windows.md
@@ -0,0 +1,320 @@
+# Install TensorFlow from Sources on Windows
+
+This guide explains how to build TensorFlow sources into a TensorFlow binary and
+how to install that TensorFlow binary on Windows.
+
+## Determine which TensorFlow to install
+
+You must choose one of the following types of TensorFlow to build and install:
+
+* **TensorFlow with CPU support only**. If your system does not have an NVIDIA®
+ GPU, build and install this version. Note that this version of TensorFlow is
+ typically easier to build and install, so even if you have an NVIDIA GPU, we
+ recommend building and installing this version first.
+* **TensorFlow with GPU support**. TensorFlow programs typically run
+ significantly faster on a GPU than on a CPU. Therefore, if your system has an
+ NVIDIA GPU and you need to run performance-critical applications, you should
+ ultimately build and install this version. Beyond the NVIDIA GPU itself,
+ your system must also fulfill the NVIDIA software requirements described in
+ the following document:
+
+ * [Installing TensorFlow on Windows](install_windows.md#NVIDIARequirements)
+
+## Prepare environment for Windows
+
+Before building TensorFlow on Windows, install the following build tools on your
+system:
+
+* [MSYS2](#InstallMSYS2)
+* [Visual C++ build tools](#InstallVCBuildTools)
+* [Bazel for Windows](#InstallBazel)
+* [TensorFlow Python dependencies](#InstallPython)
+* [optionally, NVIDIA packages to support TensorFlow for GPU](#InstallCUDA)
+
+<a name="InstallMSYS2"></a>
+
+### Install MSYS2
+
+The TensorFlow Bazel build relies on several bash tools, which you can install through [MSYS2](https://www.msys2.org/).
+
+Assuming you installed MSYS2 at `C:\msys64`, add `C:\msys64\usr\bin` to your `%PATH%` environment variable.
+
+To install the necessary bash tools, issue the following command in `cmd.exe`:
+
+<pre>
+C:\> <b>pacman -S git patch unzip</b>
+</pre>
+
+<a name="InstallVCBuildTools"></a>
+
+### Install Visual C++ Build Tools 2015
+
+To build TensorFlow, you need to install Visual C++ Build Tools 2015. It is part of Visual Studio 2015,
+but you can also install it separately as follows:
+
+ * Open the [official download page](https://visualstudio.microsoft.com/vs/older-downloads/).
+ * Go to the <b>Redistributables and Build Tools</b> section.
+ * Find <b>Microsoft Build Tools 2015 Update 3</b> and click download.
+ * Run the installer.
+
+It's possible to build TensorFlow with a newer version of the Visual C++ build tools,
+but we only test against Visual Studio 2015 Update 3.
+
+<a name="InstallBazel"></a>
+
+### Install Bazel
+
+If Bazel is not installed on your system, install it now by following
+[these instructions](https://docs.bazel.build/versions/master/install-windows.html).
+It is recommended to use a Bazel version >= `0.15.0`.
+
+Add the directory where you installed Bazel to your `%PATH%` environment variable.
+
+<a name="InstallPython"></a>
+
+### Install TensorFlow Python dependencies
+
+If you don't have Python 3.5 or Python 3.6 installed, install it now:
+
+ * [Python 3.5.x 64-bit from python.org](https://www.python.org/downloads/release/python-352/)
+ * [Python 3.6.x 64-bit from python.org](https://www.python.org/downloads/release/python-362/)
+
+To build and install TensorFlow, you must install the following Python packages:
+
+* `six`, which provides simple utilities for wrapping over differences between
+ Python 2 and Python 3.
+* `numpy`, which is a numerical processing package that TensorFlow requires.
+* `wheel`, which enables you to manage Python compressed packages in the wheel
+ (.whl) format.
+* `keras_applications`, the applications module of the Keras deep learning library.
+* `keras_preprocessing`, the data preprocessing and data augmentation module
+ of the Keras deep learning library.
+
+Assuming you already have `pip3` in `%PATH%`, issue the following commands:
+
+<pre>
+C:\> <b>pip3 install six numpy wheel</b>
+C:\> <b>pip3 install keras_applications==1.0.4 --no-deps</b>
+C:\> <b>pip3 install keras_preprocessing==1.0.2 --no-deps</b>
+</pre>
+
+<a name="InstallCUDA"></a>
+
+### Optional: install TensorFlow for GPU prerequisites
+
+If you are building TensorFlow without GPU support, skip this section.
+
+The following NVIDIA® _hardware_ must be installed on your system:
+
+* GPU card with CUDA Compute Capability 3.5 or higher. See
+ [NVIDIA documentation](https://developer.nvidia.com/cuda-gpus) for a list of
+ supported GPU cards.
+
+The following NVIDIA® _software_ must be installed on your system:
+
+* [GPU drivers](http://nvidia.com/driver). CUDA 9.0 requires 384.x or higher.
+* [CUDA Toolkit](http://nvidia.com/cuda) (>= 8.0). We recommend version 9.0.
+* [cuDNN SDK](http://developer.nvidia.com/cudnn) (>= 6.0). We recommend
+ version 7.1.x.
+* [CUPTI](http://docs.nvidia.com/cuda/cupti/) ships with the CUDA Toolkit, but
+ you also need to append its path to the `%PATH%` environment
+ variable.
+
+Assuming you have the CUDA Toolkit installed at `C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v9.0`
+and cuDNN at `C:\tools\cuda`, issue the following commands:
+
+<pre>
+C:\> SET PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v9.0\bin;%PATH%
+C:\> SET PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v9.0\extras\CUPTI\libx64;%PATH%
+C:\> SET PATH=C:\tools\cuda\bin;%PATH%
+</pre>
+
+## Clone the TensorFlow repository
+
+Now you need to clone **the latest** TensorFlow repository.
+Thanks to MSYS2, `git` is already available; issue the following command:
+
+<pre>C:\> <b>git clone https://github.com/tensorflow/tensorflow.git</b> </pre>
+
+The preceding <code>git clone</code> command creates a subdirectory named
+`tensorflow`. After cloning, you may optionally build a **specific branch**
+(such as a release branch) by invoking the following commands:
+
+<pre>
+C:\> <b>cd tensorflow</b>
+C:\> <b>git checkout</b> <i>Branch</i> # where <i>Branch</i> is the desired branch
+</pre>
+
+For example, to work with the `r1.11` release instead of the master release,
+issue the following command:
+
+<pre>C:\> <b>git checkout r1.11</b></pre>
+
+Next, configure the installation.
+
+## Configure the installation
+
+The root of the source tree contains a Python script named <code>configure.py</code>.
+This script asks you to identify the pathnames of all relevant TensorFlow
+dependencies and specify other build configuration options such as compiler
+flags. You must run this script *prior* to creating the pip package and
+installing TensorFlow.
+
+If you wish to build TensorFlow with GPU support, `configure.py` will ask you to specify
+the version numbers of CUDA and cuDNN. If several versions of CUDA or cuDNN are
+installed on your system, explicitly select the desired version instead of
+relying on the default.
+
+One of the questions that `configure.py` will ask is as follows:
+
+<pre>
+Please specify optimization flags to use during compilation when bazel option "--config=opt" is specified [Default is /arch:AVX]:
+</pre>
+
+Here is an example execution of the `configure.py` script. Note that your own input
+will likely differ from our sample input:
+
+<pre>
+C:\> <b>cd tensorflow</b> # cd to the top-level directory created
+C:\tensorflow> <b>python ./configure.py</b>
+Starting local Bazel server and connecting to it...
+................
+You have bazel 0.15.0 installed.
+Please specify the location of python. [Default is C:\python36\python.exe]:
+
+Found possible Python library paths:
+ C:\python36\lib\site-packages
+Please input the desired Python library path to use. Default is [C:\python36\lib\site-packages]
+
+Do you wish to build TensorFlow with CUDA support? [y/N]: <b>Y</b>
+CUDA support will be enabled for TensorFlow.
+
+Please specify the CUDA SDK version you want to use. [Leave empty to default to CUDA 9.0]:
+
+Please specify the location where CUDA 9.0 toolkit is installed. Refer to README.md for more details. [Default is C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v9.0]:
+
+Please specify the cuDNN version you want to use. [Leave empty to default to cuDNN 7.0]: <b>7.0</b>
+
+Please specify the location where cuDNN 7 library is installed. Refer to README.md for more details. [Default is C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v9.0]: <b>C:\tools\cuda</b>
+
+Please specify a list of comma-separated Cuda compute capabilities you want to build with.
+You can find the compute capability of your device at: https://developer.nvidia.com/cuda-gpus.
+Please note that each additional compute capability significantly increases your build time and binary size. [Default is: 3.5,7.0]: <b>3.7</b>
+
+Please specify optimization flags to use during compilation when bazel option "--config=opt" is specified [Default is /arch:AVX]:
+
+Would you like to override eigen strong inline for some C++ compilation to reduce the compilation time? [Y/n]:
+Eigen strong inline overridden.
+
+Configuration finished
+</pre>
+
+## Build the pip package
+
+### CPU-only support
+
+To build a pip package for TensorFlow with CPU-only support:
+
+<pre>
+C:\tensorflow> <b>bazel build --config=opt //tensorflow/tools/pip_package:build_pip_package</b>
+</pre>
+
+### GPU support
+
+To build a pip package for TensorFlow with GPU support:
+
+<pre>
+C:\tensorflow> <b>bazel build --config=opt --config=cuda //tensorflow/tools/pip_package:build_pip_package</b>
+</pre>
+
+**NOTE:** When building with GPU support, you might want to add `--copt=-nvcc_options=disable-warnings`
+to suppress nvcc warning messages.
+
+The `bazel build` command builds a binary named `build_pip_package`
+(an executable that launches bash and runs a bash script to create the pip package).
+Running this binary as follows builds a `.whl` file within the `C:/tmp/tensorflow_pkg` directory:
+
+<pre>
+C:\tensorflow> <b>bazel-bin\tensorflow\tools\pip_package\build_pip_package C:/tmp/tensorflow_pkg</b>
+</pre>
+
+## Install the pip package
+
+Invoke `pip3 install` to install that pip package. The filename of the `.whl`
+file depends on the TensorFlow version and your platform. For example, the
+following command will install the pip package for TensorFlow 1.11.0rc0:
+
+<pre>
+C:\tensorflow> <b>pip3 install C:/tmp/tensorflow_pkg/tensorflow-1.11.0rc0-cp36-cp36m-win_amd64.whl</b>
+</pre>
+
+## Validate your installation
+
+Validate your TensorFlow installation by doing the following:
+
+Start a terminal.
+
+Change directory (`cd`) to any directory on your system other than the
+`tensorflow` subdirectory from which you invoked the `configure` command.
+
+Invoke python:
+
+<pre>$ <b>python</b></pre>
+
+Enter the following short program inside the python interactive shell:
+
+```python
+# Python
+import tensorflow as tf
+hello = tf.constant('Hello, TensorFlow!')
+sess = tf.Session()
+print(sess.run(hello))
+```
+
+If the system outputs the following, then you are ready to begin writing
+TensorFlow programs:
+
+<pre>Hello, TensorFlow!</pre>
+
+To learn more, see the [TensorFlow tutorials](../tutorials/).
+
+## Build under MSYS shell
+
+The above instructions assume you are building under the Windows native command line (`cmd.exe`), but you can also
+build TensorFlow from the MSYS shell. There are a few things to note:
+
+* Disable the path conversion heuristic in MSYS. MSYS automatically converts arguments that look
+ like Unix paths to Windows paths when running a program, which confuses Bazel.
+ (e.g., the Bazel label `//foo/bar:bin` is considered a Unix absolute path only because it starts with a slash.)
+
+ ```sh
+$ export MSYS_NO_PATHCONV=1
+$ export MSYS2_ARG_CONV_EXCL="*"
+```
+
+* Add the directory where you installed Bazel to `$PATH`. Assuming you have Bazel
+ installed at `C:\tools\bazel.exe`, issue the following command:
+
+ ```sh
+# `:` is used as path separator, so we have to convert the path to Unix style.
+$ export PATH="/c/tools:$PATH"
+```
+
+* Add the directory where you installed Python to `$PATH`. Assuming you have
+ Python installed at `C:\Python36\python.exe`, issue the following command:
+
+ ```sh
+$ export PATH="/c/Python36:$PATH"
+```
+
+* If you have Python in `$PATH`, you can run the configure script simply as
+ `./configure`; a shell wrapper will invoke Python for you.
+
+* (For GPU builds only) Add the CUDA and cuDNN bin directories to `$PATH` in the following way:
+
+ ```sh
+$ export PATH="/c/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v9.0/bin:$PATH"
+$ export PATH="/c/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v9.0/extras/CUPTI/libx64:$PATH"
+$ export PATH="/c/tools/cuda/bin:$PATH"
+```
+
+The remaining steps are the same as building under `cmd.exe`.
diff --git a/tensorflow/docs_src/install/install_windows.md b/tensorflow/docs_src/install/install_windows.md
index e9061bf3c1..0bb0e5aeb9 100644
--- a/tensorflow/docs_src/install/install_windows.md
+++ b/tensorflow/docs_src/install/install_windows.md
@@ -24,6 +24,8 @@ You must choose one of the following types of TensorFlow to install:
and you need to run performance-critical applications, you should
ultimately install this version.
+<a name="NVIDIARequirements"></a>
+
### Requirements to run TensorFlow with GPU support
If you are installing TensorFlow with GPU support using one of the mechanisms
diff --git a/tensorflow/docs_src/install/leftnav_files b/tensorflow/docs_src/install/leftnav_files
index ace275c0e8..59292f7121 100644
--- a/tensorflow/docs_src/install/leftnav_files
+++ b/tensorflow/docs_src/install/leftnav_files
@@ -6,6 +6,7 @@ install_mac.md: MacOS
install_windows.md: Windows
install_raspbian.md: Raspbian
install_sources.md: From source
+install_sources_windows.md: From source on Windows
>>>
migration.md
diff --git a/tensorflow/docs_src/tutorials/sequences/recurrent.md b/tensorflow/docs_src/tutorials/sequences/recurrent.md
index 715cc7856a..10d60f7966 100644
--- a/tensorflow/docs_src/tutorials/sequences/recurrent.md
+++ b/tensorflow/docs_src/tutorials/sequences/recurrent.md
@@ -77,9 +77,7 @@ The basic pseudocode is as follows:
words_in_dataset = tf.placeholder(tf.float32, [time_steps, batch_size, num_features])
lstm = tf.contrib.rnn.BasicLSTMCell(lstm_size)
# Initial state of the LSTM memory.
-hidden_state = tf.zeros([batch_size, lstm.state_size])
-current_state = tf.zeros([batch_size, lstm.state_size])
-state = hidden_state, current_state
+state = lstm.zero_state(batch_size, dtype=tf.float32)
probabilities = []
loss = 0.0
for current_batch_of_words in words_in_dataset:
@@ -112,7 +110,7 @@ words = tf.placeholder(tf.int32, [batch_size, num_steps])
lstm = tf.contrib.rnn.BasicLSTMCell(lstm_size)
# Initial state of the LSTM memory.
-initial_state = state = tf.zeros([batch_size, lstm.state_size])
+initial_state = state = lstm.zero_state(batch_size, dtype=tf.float32)
for i in range(num_steps):
# The value of state is updated after processing each batch of words.
diff --git a/tensorflow/java/BUILD b/tensorflow/java/BUILD
index 87e6107c2d..9dce78b9a3 100644
--- a/tensorflow/java/BUILD
+++ b/tensorflow/java/BUILD
@@ -86,7 +86,10 @@ tf_cc_binary(
"src/gen/cc/op_gen_main.cc",
],
copts = tf_copts(),
- linkopts = ["-lm"],
+ linkopts = select({
+ "//tensorflow:windows": [],
+ "//conditions:default": ["-lm"],
+ }),
linkstatic = 1,
deps = [
":java_op_gen_lib",
@@ -368,7 +371,6 @@ tf_cc_binary(
"$(location {})".format(LINKER_EXPORTED_SYMBOLS),
],
"//tensorflow:windows": [],
- "//tensorflow:windows_msvc": [],
"//conditions:default": [
"-z defs",
"-s",
diff --git a/tensorflow/java/maven/hadoop/pom.xml b/tensorflow/java/maven/hadoop/pom.xml
index 7fa751a46a..e0409fa41b 100644
--- a/tensorflow/java/maven/hadoop/pom.xml
+++ b/tensorflow/java/maven/hadoop/pom.xml
@@ -5,7 +5,7 @@
<groupId>org.tensorflow</groupId>
<artifactId>hadoop</artifactId>
<packaging>jar</packaging>
- <version>1.10.0-rc1</version>
+ <version>1.10.0</version>
<name>tensorflow-hadoop</name>
<url>https://www.tensorflow.org</url>
<description>TensorFlow TFRecord InputFormat/OutputFormat for Apache Hadoop</description>
diff --git a/tensorflow/java/maven/libtensorflow/pom.xml b/tensorflow/java/maven/libtensorflow/pom.xml
index 8ecabfd399..f9093ce385 100644
--- a/tensorflow/java/maven/libtensorflow/pom.xml
+++ b/tensorflow/java/maven/libtensorflow/pom.xml
@@ -6,7 +6,7 @@
<parent>
<groupId>org.tensorflow</groupId>
<artifactId>parentpom</artifactId>
- <version>1.10.0-rc1</version>
+ <version>1.10.0</version>
<relativePath>../</relativePath>
</parent>
<artifactId>libtensorflow</artifactId>
diff --git a/tensorflow/java/maven/libtensorflow_jni/pom.xml b/tensorflow/java/maven/libtensorflow_jni/pom.xml
index e03ce32216..1208956dec 100644
--- a/tensorflow/java/maven/libtensorflow_jni/pom.xml
+++ b/tensorflow/java/maven/libtensorflow_jni/pom.xml
@@ -6,7 +6,7 @@
<parent>
<groupId>org.tensorflow</groupId>
<artifactId>parentpom</artifactId>
- <version>1.10.0-rc1</version>
+ <version>1.10.0</version>
<relativePath>../</relativePath>
</parent>
<artifactId>libtensorflow_jni</artifactId>
diff --git a/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml b/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml
index fee840f547..755449cb3c 100644
--- a/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml
+++ b/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml
@@ -6,7 +6,7 @@
<parent>
<groupId>org.tensorflow</groupId>
<artifactId>parentpom</artifactId>
- <version>1.10.0-rc1</version>
+ <version>1.10.0</version>
<relativePath>../</relativePath>
</parent>
<artifactId>libtensorflow_jni_gpu</artifactId>
diff --git a/tensorflow/java/maven/pom.xml b/tensorflow/java/maven/pom.xml
index 0c33819b2b..035077e1e0 100644
--- a/tensorflow/java/maven/pom.xml
+++ b/tensorflow/java/maven/pom.xml
@@ -6,7 +6,7 @@
<modelVersion>4.0.0</modelVersion>
<groupId>org.tensorflow</groupId>
<artifactId>parentpom</artifactId>
- <version>1.10.0-rc1</version>
+ <version>1.10.0</version>
<packaging>pom</packaging>
<url>https://www.tensorflow.org</url>
diff --git a/tensorflow/java/maven/proto/pom.xml b/tensorflow/java/maven/proto/pom.xml
index 2af7a5cd2e..b89f042567 100644
--- a/tensorflow/java/maven/proto/pom.xml
+++ b/tensorflow/java/maven/proto/pom.xml
@@ -6,7 +6,7 @@
<parent>
<groupId>org.tensorflow</groupId>
<artifactId>parentpom</artifactId>
- <version>1.10.0-rc1</version>
+ <version>1.10.0</version>
<relativePath>../</relativePath>
</parent>
<artifactId>proto</artifactId>
diff --git a/tensorflow/java/maven/run_inside_container.sh b/tensorflow/java/maven/run_inside_container.sh
index f4794d68a9..8c4c9d498c 100644
--- a/tensorflow/java/maven/run_inside_container.sh
+++ b/tensorflow/java/maven/run_inside_container.sh
@@ -110,11 +110,17 @@ download_libtensorflow_jni_gpu() {
cd "${NATIVE_DIR}"
mkdir linux-x86_64
+ mkdir windows-x86_64
curl -L "${RELEASE_URL_PREFIX}/libtensorflow_jni-gpu-linux-x86_64-${TF_VERSION}.tar.gz" | tar -xvz -C linux-x86_64
+ curl -L "${RELEASE_URL_PREFIX}/libtensorflow_jni-gpu-windows-x86_64-${TF_VERSION}.zip" -o /tmp/windows.zip
+
+ unzip /tmp/windows.zip -d windows-x86_64
+ rm -f /tmp/windows.zip
# Updated timestamps seem to be required to get Maven to pick up the file.
touch linux-x86_64/*
+ touch windows-x86_64/*
cd "${DIR}"
}
diff --git a/tensorflow/java/maven/spark-connector/pom.xml b/tensorflow/java/maven/spark-connector/pom.xml
index 27d9b54c6c..31e39c588a 100644
--- a/tensorflow/java/maven/spark-connector/pom.xml
+++ b/tensorflow/java/maven/spark-connector/pom.xml
@@ -6,7 +6,7 @@
<groupId>org.tensorflow</groupId>
<artifactId>spark-connector_2.11</artifactId>
<packaging>jar</packaging>
- <version>1.10.0-rc1</version>
+ <version>1.10.0</version>
<name>spark-tensorflow-connector</name>
<url>https://www.tensorflow.org</url>
<description>TensorFlow TFRecord connector for Apache Spark DataFrames</description>
diff --git a/tensorflow/java/maven/tensorflow/pom.xml b/tensorflow/java/maven/tensorflow/pom.xml
index c952545bc6..0de90244b1 100644
--- a/tensorflow/java/maven/tensorflow/pom.xml
+++ b/tensorflow/java/maven/tensorflow/pom.xml
@@ -6,7 +6,7 @@
<parent>
<groupId>org.tensorflow</groupId>
<artifactId>parentpom</artifactId>
- <version>1.10.0-rc1</version>
+ <version>1.10.0</version>
<relativePath>../</relativePath>
</parent>
<artifactId>tensorflow</artifactId>
diff --git a/tensorflow/java/src/main/java/org/tensorflow/types/UInt8.java b/tensorflow/java/src/main/java/org/tensorflow/types/UInt8.java
index 0c751aed9f..824f7fbe32 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/types/UInt8.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/types/UInt8.java
@@ -16,6 +16,33 @@ limitations under the License.
package org.tensorflow.types;
/** Represents an 8-bit unsigned integer. */
-public class UInt8 {
+public class UInt8 extends Number {
+
+ private static final long serialVersionUID = 1L;
+
+ // This class is only used for generic parameterization and is not instantiable. Thus,
+ // it is safe to implement the Number abstract methods with all zeros, as they will
+ // never be invoked.
+
+ @Override
+ public double doubleValue() {
+ return 0.0;
+ }
+
+ @Override
+ public float floatValue() {
+ return 0.0f;
+ }
+
+ @Override
+ public int intValue() {
+ return 0;
+ }
+
+ @Override
+ public long longValue() {
+ return 0L;
+ }
+
private UInt8() {}
}
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 8d8d4792fa..5f985654f0 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -1871,6 +1871,7 @@ py_library(
":framework_for_generated_wrappers",
":math_ops",
":nn_ops_gen",
+ ":numerics",
"@six_archive//:six",
],
)
@@ -1884,7 +1885,6 @@ py_test(
":client_testlib",
":clip_ops",
":framework_for_generated_wrappers",
- ":numerics",
"//third_party/py/numpy",
],
)
@@ -3341,7 +3341,10 @@ py_library(
py_library(
name = "distribute",
- srcs = ["training/distribute.py"],
+ srcs = [
+ "training/distribute.py",
+ "training/distribution_strategy_context.py",
+ ],
srcs_version = "PY2AND3",
deps = [
":array_ops",
@@ -4207,7 +4210,6 @@ cuda_py_test(
":math_ops",
"//tensorflow/core:protos_all_py",
],
- tags = ["no_windows"],
)
cuda_py_test(
@@ -4501,7 +4503,6 @@ py_test(
srcs = ["training/saver_large_partitioned_variable_test.py"],
srcs_version = "PY2AND3",
tags = [
- "no_windows",
"noasan", # http://b/30782289
"notsan", # http://b/30782289
],
diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py
index 42f96a002a..73adb7a559 100644
--- a/tensorflow/python/compat/compat.py
+++ b/tensorflow/python/compat/compat.py
@@ -26,7 +26,7 @@ import datetime
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util.tf_export import tf_export
-_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2018, 8, 12)
+_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2018, 8, 15)
@tf_export("compat.forward_compatible")
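Editor's note: a hedged sketch of how the horizon date bumped above is typically consumed. Code that wants to emit a newly added op checks `forward_compatible()` first; `new_path` and `old_path` are hypothetical stand-ins.

```python
from tensorflow.python.compat import compat

def choose_implementation(x):
    # Returns True only once the horizon above is strictly past the given date.
    if compat.forward_compatible(2018, 8, 15):
        return new_path(x)  # hypothetical: uses the newly added op
    return old_path(x)      # hypothetical: fallback for older graph consumers
```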
diff --git a/tensorflow/python/debug/BUILD b/tensorflow/python/debug/BUILD
index 27b8ebd362..8a4ac6aaef 100644
--- a/tensorflow/python/debug/BUILD
+++ b/tensorflow/python/debug/BUILD
@@ -936,7 +936,6 @@ py_test(
size = "small",
srcs = ["cli/profile_analyzer_cli_test.py"],
srcs_version = "PY2AND3",
- tags = ["no_windows"],
deps = [
":debugger_cli_common",
":profile_analyzer_cli",
diff --git a/tensorflow/python/eager/benchmarks_test.py b/tensorflow/python/eager/benchmarks_test.py
index 1a78559ac0..e2b1890c2f 100644
--- a/tensorflow/python/eager/benchmarks_test.py
+++ b/tensorflow/python/eager/benchmarks_test.py
@@ -77,19 +77,54 @@ class SubclassedKerasModel(keras.Model):
def __init__(self):
super(SubclassedKerasModel, self).__init__()
- self.layer = keras.layers.Dense(
+ self.layer_a = keras.layers.Dense(
+ 64, kernel_initializer="ones", bias_initializer="zeros")
+ self.layer_b = keras.layers.Dense(
+ 128, kernel_initializer="ones", bias_initializer="zeros")
+ self.layer_c = keras.layers.Dense(
+ 256, kernel_initializer="ones", bias_initializer="zeros")
+ self.layer_d = keras.layers.Dense(
+ 256, kernel_initializer="ones", bias_initializer="zeros")
+ self.layer_e = keras.layers.Dense(
10, kernel_initializer="ones", bias_initializer="zeros")
def call(self, x):
- return self.layer(x)
+ x = self.layer_a(x)
+ x = self.layer_b(x)
+ x = self.layer_c(x)
+ x = self.layer_d(x)
+ return self.layer_e(x)
def make_keras_model():
- x = keras.Input(shape=(10,))
- y = keras.layers.Dense(
- 10, kernel_initializer="ones", bias_initializer="zeros")(
- x)
- return keras.Model(inputs=x, outputs=y)
+ model_input = keras.Input(shape=(10,))
+ x = keras.layers.Dense(
+ 64, kernel_initializer="ones", bias_initializer="zeros")(model_input)
+ x = keras.layers.Dense(
+ 128, kernel_initializer="ones", bias_initializer="zeros")(x)
+ x = keras.layers.Dense(
+ 256, kernel_initializer="ones", bias_initializer="zeros")(x)
+ x = keras.layers.Dense(
+ 256, kernel_initializer="ones", bias_initializer="zeros")(x)
+ x = keras.layers.Dense(
+ 10, kernel_initializer="ones", bias_initializer="zeros")(x)
+ return keras.Model(inputs=model_input, outputs=x)
+
+
+def make_sequential_keras_model():
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(
+ 64, kernel_initializer="ones", bias_initializer="zeros",
+ input_shape=(10,)))
+ model.add(keras.layers.Dense(
+ 128, kernel_initializer="ones", bias_initializer="zeros"))
+ model.add(keras.layers.Dense(
+ 256, kernel_initializer="ones", bias_initializer="zeros"))
+ model.add(keras.layers.Dense(
+ 256, kernel_initializer="ones", bias_initializer="zeros"))
+ model.add(keras.layers.Dense(
+ 10, kernel_initializer="ones", bias_initializer="zeros"))
+ return model
class MicroBenchmarks(test.Benchmark):
@@ -638,6 +673,15 @@ class MicroBenchmarks(test.Benchmark):
assert np.equal(func(), SubclassedKerasModel()(data)).all()
self._run(func, 30000)
+ def benchmark_keras_model_sequential(self):
+ model = make_sequential_keras_model()
+ data = random_ops.random_uniform((10, 10))
+ func = lambda: model(data)
+ # Symmetry with benchmark_keras_model_functional
+ func()
+ assert np.equal(func(), make_keras_model()(data)).all()
+ self._run(func, 30000)
+
if __name__ == "__main__":
test.main()
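Editor's note: a quick sanity check mirroring the benchmark's assertion above, assuming the helper functions from this test file are importable and eager execution is available.

```python
import numpy as np
import tensorflow as tf

tf.enable_eager_execution()
data = tf.random_uniform((10, 10))
functional = make_keras_model()
sequential = make_sequential_keras_model()
# Identical layer sizes and ones/zeros initializers imply identical outputs.
assert np.equal(functional(data).numpy(), sequential(data).numpy()).all()
```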
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index f87d88040f..5afba466bc 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -42,7 +42,8 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import resource_variable_ops
-from tensorflow.python.training import distribute
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.training import distribution_strategy_context
from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
@@ -227,6 +228,9 @@ class FuncGraph(CapturingGraph):
self.get_collection_ref(collection)[:] = graph.get_collection(
collection)
+ # Copy distribution strategy scope from the containing graph as well.
+ self._distribution_strategy_stack = graph._distribution_strategy_stack # pylint: disable=protected-access
+
if context.executing_eagerly():
self.seed = context.global_seed()
else:
@@ -243,78 +247,6 @@ class FuncGraph(CapturingGraph):
return internal_tensor
-# pylint: disable=invalid-name
-class HelperContext(object):
- """ControlFlowContext with a customizable AddOp method."""
-
- def __init__(self, add_op_internal):
- self._add_op_internal = add_op_internal
- self._values = set() # control flow code sometimes updates this.
-
- def _AddOpInternal(self, op):
- self._add_op_internal(op)
-
- @property
- def outer_context(self):
- return self._outer_context
-
- def GetWhileContext(self):
- if self._outer_context:
- return self._outer_context.GetWhileContext()
-
- def IsWhileContext(self):
- return False
-
- def IsCondContext(self):
- return False
-
- def IsXLAContext(self):
- return False
-
- def AddOp(self, op): # pylint: disable=invalid-name
- self._AddOpInternal(op)
- if self._outer_context:
- self._outer_context.AddOp(op)
-
- def AddName(self, _):
- pass
-
- def AddInnerOp(self, op):
- self._AddOpInternal(op)
- if self._outer_context:
- self._outer_context.AddInnerOp(op)
-
- def AddValue(self, val):
- if self._outer_context:
- return self._outer_context.AddValue(val)
- else:
- return val
-
- def EnterGradientColocation(self, op, gradient_uid):
- """Start building a gradient colocated with an op."""
- if self._outer_context:
- self._outer_context.EnterGradientColocation(op, gradient_uid)
-
- def ExitGradientColocation(self, op, gradient_uid):
- """Start building a gradient colocated with an op."""
- if self._outer_context:
- self._outer_context.ExitGradientColocation(op, gradient_uid)
-
- def __enter__(self):
- # pylint: disable=protected-access
- self._g = ops.get_default_graph()
- self._outer_context = self._g._get_control_flow_context()
- self._g._set_control_flow_context(self)
- self._nested_contexts = (
- self._outer_context._nested_contexts
- if self._outer_context is not None else None)
- # pylint: enable=protected-access
-
- def __exit__(self, *_):
- self._g._set_control_flow_context(self._outer_context) # pylint: disable=protected-access
-# pylint: enable=invalid-name
-
-
def _forward_name(n):
"""The name of a generated forward defun named n."""
return "__forward_%s_%s" % (n, ops.uid())
@@ -479,11 +411,6 @@ class _EagerDefinedFunction(object):
return outputs
-def _map_sequence_obj_to_idx(sequence):
- """Maps objs in the sequence from id(obj) to sequence index."""
- return {id(x): i for i, x in enumerate(sequence)}
-
-
def _flatten(sequence):
"""A wrapper around `nest.flatten` that also unpacks `IndexedSlices`."""
# TODO(akshayka): Support `SparseTensor` in a similar fashion.
@@ -568,7 +495,7 @@ class GraphModeFunction(object):
# Find the variables that are components of something distributed and
# put them into a {handle_tensor -> distributed variable object} map.
self._distributed_variables = {}
- strategy = distribute.get_distribution_strategy()
+ strategy = distribution_strategy_context.get_distribution_strategy()
for variable in self._variables:
# If variable is not distributed, unwrap returns [variable].
component_variables = strategy.unwrap(variable)
@@ -832,6 +759,8 @@ def _trace_and_define_function(name, python_func, compiled, args, kwds,
func_graph = FuncGraph(_inference_name(name), graph=ops.get_default_graph())
with func_graph.as_default(), AutomaticControlDependencies() as a:
+ variable_scope.get_variable_scope().set_use_resource(True)
+
if signature is None:
func_args = _get_defun_inputs_from_args(args)
func_kwds = _get_defun_inputs_from_args(kwds)
@@ -898,7 +827,7 @@ def _trace_and_define_function(name, python_func, compiled, args, kwds,
# the function is run on a different device). Thus, instead of storing
# the specific captured variable, we replace it with its distributed
# container.
- strategy = distribute.get_distribution_strategy()
+ strategy = distribution_strategy_context.get_distribution_strategy()
for i, variable in enumerate(variables):
# If variable is not distributed value_container returns itself.
variables[i] = strategy.value_container(variable)
@@ -1322,18 +1251,60 @@ def defun(func=None, input_signature=None, compiled=False):
generates and placed in the eager context if executing eagerly or into an
outer graph otherwise.
- _Tracing and Input Signatures_.
- The signature of inputs supplied to `F` is defined to be a tuple of the shapes
- and dtypes of Tensor-typed arguments and the values of non-Tensor arguments,
- where "arguments" includes both args and kwargs. Every time `F` is invoked,
- the signature of its inputs are inferred. The first time `F(*args, **kwargs)`
- is invoked with a particular signature, `f(*args, **kwargs)` is executed and
- all the TensorFlow operations that `f` executes, along with the Tensors that
- flow between them, are recorded in a TensorFlow graph. `F` caches this graph
- and binds it to the inputs' signature; every subsequent invocation of `F` with
- inputs conforming to this signature will immediately retrieve the cached graph
- and pass it to the TensorFlow runtime for execution.
+ _Input Signatures_
+ By default, `F = tf.contrib.eager.defun(f)` instantiates a separate graph
+ for every unique sequence of the shapes and dtypes of Tensor arguments and
+ the values of Python objects it is invoked with. For example, calling
+   `F(tf.random_uniform([2]))` will execute a different graph than
+   `F(tf.random_uniform([3]))` because the two inputs have different shapes.
+ The first time that `F(*args, **kwargs)` is called with a particular sequence
+ of Tensor shapes and dtypes and Python values, it constructs a graph by
+ tracing the execution of `f(*args, **kwargs)`; this graph is bound to an
+ input signature inferred from `(*args, **kwargs)` and cached for future reuse.
+
+ `tf.contrib.eager.defun` caches graphs for your convenience, letting you
+ define TensorFlow functions without explicitly specifying their signatures.
+ However, this policy is conservative and potentially expensive; for example,
+ when different invocations of your function have differently-shaped Tensor
+ inputs, this policy might generate more graph functions than necessary. To
+ eliminate such costs, `tf.contrib.eager.defun` allows you to supply an
+ optional `input_signature` argument specifying the shapes and dtypes of the
+ inputs. In particular, the shapes may be partially unspecified, with `None`s
+ in the unknown dimensions. When an input signature is provided,
+ `tf.contrib.eager.defun` will only instantiate a single graph for the
+ decorated Python function. The following is an example:
+
+ ```python
+ import tensorflow as tf
+
+ # The first `TensorSpec` below describes the shape and dtype of `words`,
+ # and the second describes the shape and dtype of `another_tensor`. Note that
+ # the last dimension of the `words` `TensorSpec` is left unspecified.
+ @tf.contrib.eager.defun(input_signature=[
+ tf.contrib.eager.TensorSpec(shape=[50, 300, None], dtype=tf.float32),
+ tf.contrib.eager.TensorSpec(shape=[300, 100], dtype=tf.float32)
+ ])
+ def my_sequence_model(words, another_tensor):
+ ...
+
+ # Note how the third dimension of the first input can vary freely.
+   words = tf.random_uniform([50, 300, 10])
+ second_input = tf.random_uniform([300, 100])
+ my_sequence_model(words, second_input)
+
+   words = tf.random_uniform([50, 300, 20])
+ my_sequence_model(words, second_input)
+
+ # Passing an input with an incompatible shape will raise an error.
+   words = tf.random_uniform([50, 100, 20])
+ my_sequence_model(words, second_input) # <---- This will raise an error.
+
+ ```
+
+ Python functions that are compiled with an `input_signature` must only accept
+ Tensors as arguments and must not take unnamed keyword arguments (**kwargs).
+ _Tracing_
Be aware that because `F` only logs TensorFlow operations, all the other
Python code that `f` executes will only shape the _construction_ of the graphs
that `F` executes: the Python code won't be executed when the graphs
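Editor's note: a small runnable sketch of the per-signature tracing described in the new docstring text; each distinct input shape triggers one trace, and later calls with a previously seen shape reuse the cached graph.

```python
import tensorflow as tf

tf.enable_eager_execution()
F = tf.contrib.eager.defun(lambda x: x * x)
F(tf.random_uniform([2]))  # traces a graph specialized to shape [2]
F(tf.random_uniform([3]))  # new shape: traces a second graph
F(tf.random_uniform([2]))  # shape [2] seen before: cached graph is reused
```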
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index 0488dc9752..380bcf763f 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -397,6 +397,18 @@ class FunctionTest(test.TestCase):
compiled = function.defun(f)
compiled()
+ @test_util.run_in_graph_and_eager_modes
+ def testDefunForcesResourceVariables(self):
+
+ def variable_creator():
+ return variables.Variable(0.0).read_value()
+
+ defined = function.defun(variable_creator)
+ defined() # Create the variable.
+ self.assertEqual(len(defined.variables), 1)
+ self.assertIsInstance(
+ defined.variables[0], resource_variable_ops.ResourceVariable)
+
def testDefunDifferentiable(self):
v = resource_variable_ops.ResourceVariable(1.0)
@@ -434,6 +446,22 @@ class FunctionTest(test.TestCase):
op = call()
self.assertAllEqual(sess.run(op), 2.0)
+ def testSymbolicGradientVariableZerosLike(self):
+ with ops.Graph().as_default():
+ v = resource_variable_ops.ResourceVariable(1.0)
+
+ @function.defun
+ def f(x, v):
+ v.read_value()
+ return x * x
+
+ x = constant_op.constant(1.0)
+ l = f(x, v)
+ _, dv = gradients_impl.gradients(l, [x, v])
+ with self.test_session():
+ v.initializer.run()
+ self.assertAllEqual(dv.eval(), 0.0)
+
def testGraphModeManyFunctions(self):
with context.graph_mode(), self.test_session():
diff --git a/tensorflow/python/estimator/canned/boosted_trees.py b/tensorflow/python/estimator/canned/boosted_trees.py
index 8b423f76de..16928ca4b7 100644
--- a/tensorflow/python/estimator/canned/boosted_trees.py
+++ b/tensorflow/python/estimator/canned/boosted_trees.py
@@ -703,9 +703,30 @@ def _bt_model_fn(
global_step = training_util.get_or_create_global_step()
bucket_size_list, feature_ids_list = _group_features_by_num_buckets(
sorted_feature_columns)
+ # Create Ensemble resources.
+ tree_ensemble = boosted_trees_ops.TreeEnsemble(name=name)
+
+ # Create logits.
+ if mode != model_fn.ModeKeys.TRAIN:
+ input_feature_list = _get_transformed_features(features,
+ sorted_feature_columns)
+ logits = boosted_trees_ops.predict(
+ # For non-TRAIN mode, ensemble doesn't change after initialization,
+ # so no local copy is needed; using tree_ensemble directly.
+ tree_ensemble_handle=tree_ensemble.resource_handle,
+ bucketized_features=input_feature_list,
+ logits_dimension=head.logits_dimension)
+ return head.create_estimator_spec(
+ features=features,
+ mode=mode,
+ labels=labels,
+ train_op_fn=control_flow_ops.no_op,
+ logits=logits)
+
+ # ============== Training graph ==============
# Extract input features and set up cache for training.
training_state_cache = None
- if mode == model_fn.ModeKeys.TRAIN and train_in_memory:
+ if train_in_memory:
# cache transformed features as well for in-memory training.
batch_size = array_ops.shape(labels)[0]
input_feature_list, input_cache_op = (
@@ -717,63 +738,51 @@ def _bt_model_fn(
else:
input_feature_list = _get_transformed_features(features,
sorted_feature_columns)
- if mode == model_fn.ModeKeys.TRAIN and example_id_column_name:
+ if example_id_column_name:
example_ids = features[example_id_column_name]
training_state_cache = _CacheTrainingStatesUsingHashTable(
example_ids, head.logits_dimension)
- # Create Ensemble resources.
- tree_ensemble = boosted_trees_ops.TreeEnsemble(name=name)
# Variable that determines whether bias centering is needed.
center_bias_var = variable_scope.variable(
initial_value=center_bias, name='center_bias_needed', trainable=False)
- # Create logits.
- if mode != model_fn.ModeKeys.TRAIN:
- logits = boosted_trees_ops.predict(
- # For non-TRAIN mode, ensemble doesn't change after initialization,
- # so no local copy is needed; using tree_ensemble directly.
- tree_ensemble_handle=tree_ensemble.resource_handle,
+ if is_single_machine:
+ local_tree_ensemble = tree_ensemble
+ ensemble_reload = control_flow_ops.no_op()
+ else:
+ # Have a local copy of ensemble for the distributed setting.
+ with ops.device(worker_device):
+ local_tree_ensemble = boosted_trees_ops.TreeEnsemble(
+ name=name + '_local', is_local=True)
+ # TODO(soroush): Do partial updates if this becomes a bottleneck.
+ ensemble_reload = local_tree_ensemble.deserialize(
+ *tree_ensemble.serialize())
+
+ if training_state_cache:
+ cached_tree_ids, cached_node_ids, cached_logits = (
+ training_state_cache.lookup())
+ else:
+ # Always start from the beginning when no cache is set up.
+ batch_size = array_ops.shape(labels)[0]
+ cached_tree_ids, cached_node_ids, cached_logits = (
+ array_ops.zeros([batch_size], dtype=dtypes.int32),
+ _DUMMY_NODE_ID * array_ops.ones([batch_size], dtype=dtypes.int32),
+ array_ops.zeros(
+ [batch_size, head.logits_dimension], dtype=dtypes.float32))
+
+ with ops.control_dependencies([ensemble_reload]):
+ (stamp_token, num_trees, num_finalized_trees, num_attempted_layers,
+ last_layer_nodes_range) = local_tree_ensemble.get_states()
+ summary.scalar('ensemble/num_trees', num_trees)
+ summary.scalar('ensemble/num_finalized_trees', num_finalized_trees)
+ summary.scalar('ensemble/num_attempted_layers', num_attempted_layers)
+
+ partial_logits, tree_ids, node_ids = boosted_trees_ops.training_predict(
+ tree_ensemble_handle=local_tree_ensemble.resource_handle,
+ cached_tree_ids=cached_tree_ids,
+ cached_node_ids=cached_node_ids,
bucketized_features=input_feature_list,
logits_dimension=head.logits_dimension)
- else:
- if is_single_machine:
- local_tree_ensemble = tree_ensemble
- ensemble_reload = control_flow_ops.no_op()
- else:
- # Have a local copy of ensemble for the distributed setting.
- with ops.device(worker_device):
- local_tree_ensemble = boosted_trees_ops.TreeEnsemble(
- name=name + '_local', is_local=True)
- # TODO(soroush): Do partial updates if this becomes a bottleneck.
- ensemble_reload = local_tree_ensemble.deserialize(
- *tree_ensemble.serialize())
-
- if training_state_cache:
- cached_tree_ids, cached_node_ids, cached_logits = (
- training_state_cache.lookup())
- else:
- # Always start from the beginning when no cache is set up.
- batch_size = array_ops.shape(labels)[0]
- cached_tree_ids, cached_node_ids, cached_logits = (
- array_ops.zeros([batch_size], dtype=dtypes.int32),
- _DUMMY_NODE_ID * array_ops.ones([batch_size], dtype=dtypes.int32),
- array_ops.zeros(
- [batch_size, head.logits_dimension], dtype=dtypes.float32))
-
- with ops.control_dependencies([ensemble_reload]):
- (stamp_token, num_trees, num_finalized_trees, num_attempted_layers,
- last_layer_nodes_range) = local_tree_ensemble.get_states()
- summary.scalar('ensemble/num_trees', num_trees)
- summary.scalar('ensemble/num_finalized_trees', num_finalized_trees)
- summary.scalar('ensemble/num_attempted_layers', num_attempted_layers)
-
- partial_logits, tree_ids, node_ids = boosted_trees_ops.training_predict(
- tree_ensemble_handle=local_tree_ensemble.resource_handle,
- cached_tree_ids=cached_tree_ids,
- cached_node_ids=cached_node_ids,
- bucketized_features=input_feature_list,
- logits_dimension=head.logits_dimension)
-
logits = cached_logits + partial_logits
# Create training graph.
@@ -846,12 +855,11 @@ def _bt_model_fn(
labels=labels,
train_op_fn=_train_op_fn,
logits=logits)
- if mode == model_fn.ModeKeys.TRAIN:
- # Add an early stop hook.
- estimator_spec = estimator_spec._replace(
- training_hooks=estimator_spec.training_hooks +
- (_StopAtAttemptsHook(num_finalized_trees, num_attempted_layers,
- tree_hparams.n_trees, tree_hparams.max_depth),))
+ # Add an early stop hook.
+ estimator_spec = estimator_spec._replace(
+ training_hooks=estimator_spec.training_hooks +
+ (_StopAtAttemptsHook(num_finalized_trees, num_attempted_layers,
+ tree_hparams.n_trees, tree_hparams.max_depth),))
return estimator_spec
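Editor's note: a toy illustration of the restructured `_bt_model_fn` control flow above. Inference modes now return early, so the code after the early return can assume TRAIN mode and drop its per-mode guards. All names below are illustrative stand-ins, not the real helpers.

```python
def bt_model_fn_sketch(mode, features):
    ensemble = {"trees": []}  # stands in for the shared TreeEnsemble resource
    if mode != "train":
        # Non-TRAIN: predict against the shared ensemble and return early.
        return {"mode": mode, "logits": sum(features)}
    # TRAIN-only graph: caches, local ensemble copy, early-stop hook, all of
    # which were previously wrapped in `if mode == TRAIN` checks.
    cached_logits = [0.0] * len(features)
    logits = sum(features) + sum(cached_logits)
    return {"mode": mode, "logits": logits, "hooks": ["stop_at_attempts"]}

print(bt_model_fn_sketch("eval", [1.0, 2.0]))   # early return, no hooks
print(bt_model_fn_sketch("train", [1.0, 2.0]))  # full training path
```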
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py
index b8cd55c806..eab608813b 100644
--- a/tensorflow/python/estimator/estimator.py
+++ b/tensorflow/python/estimator/estimator.py
@@ -86,14 +86,15 @@ class Estimator(object):
subdirectory thereof. If `model_dir` is not set, a temporary directory is
used.
- The `config` argument can be passed `RunConfig` object containing information
- about the execution environment. It is passed on to the `model_fn`, if the
- `model_fn` has a parameter named "config" (and input functions in the same
- manner). If the `config` parameter is not passed, it is instantiated by the
- `Estimator`. Not passing config means that defaults useful for local execution
- are used. `Estimator` makes config available to the model (for instance, to
- allow specialization based on the number of workers available), and also uses
- some of its fields to control internals, especially regarding checkpointing.
+ The `config` argument can be passed a `tf.estimator.RunConfig` object containing
+ information about the execution environment. It is passed on to the
+ `model_fn`, if the `model_fn` has a parameter named "config" (and input
+ functions in the same manner). If the `config` parameter is not passed, it is
+ instantiated by the `Estimator`. Not passing config means that defaults useful
+ for local execution are used. `Estimator` makes config available to the model
+ (for instance, to allow specialization based on the number of workers
+ available), and also uses some of its fields to control internals, especially
+ regarding checkpointing.
The `params` argument contains hyperparameters. It is passed to the
`model_fn`, if the `model_fn` has a parameter named "params", and to the input
@@ -138,15 +139,16 @@ class Estimator(object):
* `features`: This is the first item returned from the `input_fn`
passed to `train`, `evaluate`, and `predict`. This should be a
- single `Tensor` or `dict` of same.
+ single `tf.Tensor` or `dict` of same.
* `labels`: This is the second item returned from the `input_fn`
passed to `train`, `evaluate`, and `predict`. This should be a
- single `Tensor` or `dict` of same (for multi-head models). If
- mode is `ModeKeys.PREDICT`, `labels=None` will be passed. If
- the `model_fn`'s signature does not accept `mode`, the
- `model_fn` must still be able to handle `labels=None`.
+ single `tf.Tensor` or `dict` of same (for multi-head models).
+ If mode is @{tf.estimator.ModeKeys.PREDICT}, `labels=None` will
+ be passed. If the `model_fn`'s signature does not accept
+ `mode`, the `model_fn` must still be able to handle
+ `labels=None`.
* `mode`: Optional. Specifies if this is training, evaluation or
- prediction. See `ModeKeys`.
+ prediction. See `tf.estimator.ModeKeys`.
* `params`: Optional `dict` of hyperparameters. Will receive what
is passed to Estimator in `params` parameter. This allows
to configure Estimators from hyper parameter tuning.
@@ -156,10 +158,10 @@ class Estimator(object):
configuration such as `num_ps_replicas`, or `model_dir`.
* Returns:
- `EstimatorSpec`
+ `tf.estimator.EstimatorSpec`
model_dir: Directory to save model parameters, graph, etc. This can
- also be used to load checkpoints from the directory into a estimator to
+ also be used to load checkpoints from the directory into an estimator to
continue training a previously saved model. If `PathLike` object, the
path will be resolved. If `None`, the model_dir in `config` will be used
if set. If both are set, they must be same. If both are `None`, a
@@ -170,9 +172,10 @@ class Estimator(object):
warm_start_from: Optional string filepath to a checkpoint or SavedModel to
warm-start from, or a `tf.estimator.WarmStartSettings`
object to fully configure warm-starting. If the string
- filepath is provided instead of a `WarmStartSettings`,
- then all variables are warm-started, and it is assumed
- that vocabularies and Tensor names are unchanged.
+ filepath is provided instead of a
+ `tf.estimator.WarmStartSettings`, then all variables are
+ warm-started, and it is assumed that vocabularies
+ and `tf.Tensor` names are unchanged.
Raises:
ValueError: parameters of `model_fn` don't match `params`.
@@ -220,10 +223,10 @@ class Estimator(object):
@property
def model_fn(self):
- """Returns the model_fn which is bound to self.params.
+ """Returns the `model_fn` which is bound to `self.params`.
Returns:
- The model_fn with following signature:
+ The `model_fn` with following signature:
`def model_fn(features, labels, mode, config)`
"""
@@ -243,7 +246,7 @@ class Estimator(object):
Numpy array - value of the tensor.
Raises:
- ValueError: If the Estimator has not produced a checkpoint yet.
+ ValueError: If the `Estimator` has not produced a checkpoint yet.
"""
_check_checkpoint_available(self.model_dir)
with context.graph_mode():
@@ -256,14 +259,14 @@ class Estimator(object):
List of names.
Raises:
- ValueError: If the Estimator has not produced a checkpoint yet.
+ ValueError: If the `Estimator` has not produced a checkpoint yet.
"""
_check_checkpoint_available(self.model_dir)
with context.graph_mode():
return [name for name, _ in training.list_variables(self.model_dir)]
def latest_checkpoint(self):
- """Finds the filename of latest saved checkpoint file in `model_dir`.
+ """Finds the filename of the latest saved checkpoint file in `model_dir`.
Returns:
The full path to the latest checkpoint or `None` if no checkpoint was
@@ -278,40 +281,36 @@ class Estimator(object):
steps=None,
max_steps=None,
saving_listeners=None):
- """Trains a model given training data input_fn.
+ """Trains a model given training data `input_fn`.
Args:
input_fn: A function that provides input data for training as minibatches.
- See @{$premade_estimators#create_input_functions} for more
- information. The function should construct and return one of
- the following:
-
- * A 'tf.data.Dataset' object: Outputs of `Dataset` object must be a
- tuple (features, labels) with same constraints as below.
- * A tuple (features, labels): Where `features` is a `Tensor` or a
- dictionary of string feature name to `Tensor` and `labels` is a
- `Tensor` or a dictionary of string label name to `Tensor`. Both
- `features` and `labels` are consumed by `model_fn`. They should
- satisfy the expectation of `model_fn` from inputs.
-
- hooks: List of `SessionRunHook` subclass instances. Used for callbacks
- inside the training loop.
- steps: Number of steps for which to train model. If `None`, train forever
- or train until input_fn generates the `OutOfRange` error or
- `StopIteration` exception. 'steps' works incrementally. If you call two
- times train(steps=10) then training occurs in total 20 steps. If
- `OutOfRange` or `StopIteration` occurs in the middle, training stops
+ See @{$premade_estimators#create_input_functions} for more information.
+ The function should construct and return one of the following:
+
+ * A `tf.data.Dataset` object: Outputs of `Dataset` object must be a
+ tuple `(features, labels)` with same constraints as below.
+ * A tuple `(features, labels)`: Where `features` is a `tf.Tensor` or a
+ dictionary of string feature name to `Tensor` and `labels` is a
+ `Tensor` or a dictionary of string label name to `Tensor`. Both
+ `features` and `labels` are consumed by `model_fn`. They should
+ satisfy the expectation of `model_fn` from inputs.
+ hooks: List of `tf.train.SessionRunHook` subclass instances. Used for
+ callbacks inside the training loop.
+ steps: Number of steps for which to train the model. If `None`, train
+ forever or train until `input_fn` generates the `tf.errors.OutOfRange`
+ error or `StopIteration` exception. `steps` works incrementally. If you
+ call two times `train(steps=10)` then training occurs in total 20 steps.
+ If `OutOfRange` or `StopIteration` occurs in the middle, training stops
before 20 steps. If you don't want to have incremental behavior please
set `max_steps` instead. If set, `max_steps` must be `None`.
max_steps: Number of total steps for which to train model. If `None`,
- train forever or train until input_fn generates the `OutOfRange` error
- or `StopIteration` exception. If set, `steps` must be `None`. If
- `OutOfRange` or `StopIteration` occurs in the middle, training stops
- before `max_steps` steps.
- Two calls to `train(steps=100)` means 200 training
- iterations. On the other hand, two calls to `train(max_steps=100)` means
- that the second call will not do any iteration since first call did
- all 100 steps.
+ train forever or train until `input_fn` generates the
+ `tf.errors.OutOfRange` error or `StopIteration` exception. If set,
+ `steps` must be `None`. If `OutOfRange` or `StopIteration` occurs in the
+ middle, training stops before `max_steps` steps. Two calls to
+ `train(steps=100)` means 200 training iterations. On the other hand, two
+ calls to `train(max_steps=100)` means that the second call will not do
+ any iteration since first call did all 100 steps.
saving_listeners: list of `CheckpointSaverListener` objects. Used for
callbacks that run immediately before or after checkpoint savings.
@@ -320,7 +319,7 @@ class Estimator(object):
Raises:
ValueError: If both `steps` and `max_steps` are not `None`.
- ValueError: If either `steps` or `max_steps` is <= 0.
+ ValueError: If either `steps` or `max_steps` is `<= 0`.
"""
with context.graph_mode():
if (steps is not None) and (max_steps is not None):
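Editor's note: a hedged sketch of the incremental `steps` semantics documented above; `my_model_fn` and `make_input_fn` are assumed to be defined elsewhere.

```python
import tensorflow as tf

est = tf.estimator.Estimator(model_fn=my_model_fn)
est.train(input_fn=make_input_fn(), steps=100)       # runs 100 steps
est.train(input_fn=make_input_fn(), steps=100)       # 100 more, 200 in total

est2 = tf.estimator.Estimator(model_fn=my_model_fn)
est2.train(input_fn=make_input_fn(), max_steps=100)  # runs 100 steps
est2.train(input_fn=make_input_fn(), max_steps=100)  # no-op: already at 100
```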
@@ -368,7 +367,7 @@ class Estimator(object):
return []
def eval_dir(self, name=None):
- """Shows directory name where evaluation metrics are dumped.
+ """Shows the directory name where evaluation metrics are dumped.
Args:
name: Name of the evaluation if user needs to run multiple evaluations on
@@ -384,36 +383,34 @@ class Estimator(object):
def evaluate(self, input_fn, steps=None, hooks=None, checkpoint_path=None,
name=None):
- """Evaluates the model given evaluation data input_fn.
+ """Evaluates the model given evaluation data `input_fn`.
For each step, calls `input_fn`, which returns one batch of data.
Evaluates until:
- `steps` batches are processed, or
- - `input_fn` raises an end-of-input exception (`OutOfRangeError` or
+ - `input_fn` raises an end-of-input exception (`tf.errors.OutOfRangeError`
+ or
`StopIteration`).
Args:
- input_fn: A function that constructs the input data for evaluation.
- See @{$premade_estimators#create_input_functions} for more
- information. The function should construct and return one of
- the following:
-
- * A 'tf.data.Dataset' object: Outputs of `Dataset` object must be a
- tuple (features, labels) with same constraints as below.
- * A tuple (features, labels): Where `features` is a `Tensor` or a
- dictionary of string feature name to `Tensor` and `labels` is a
- `Tensor` or a dictionary of string label name to `Tensor`. Both
- `features` and `labels` are consumed by `model_fn`. They should
- satisfy the expectation of `model_fn` from inputs.
-
+ input_fn: A function that constructs the input data for evaluation. See
+ @{$premade_estimators#create_input_functions} for more information. The
+ function should construct and return one of the following:
+
+ * A `tf.data.Dataset` object: Outputs of `Dataset` object must be a
+ tuple `(features, labels)` with same constraints as below.
+ * A tuple `(features, labels)`: Where `features` is a `tf.Tensor` or a
+ dictionary of string feature name to `Tensor` and `labels` is a
+ `Tensor` or a dictionary of string label name to `Tensor`. Both
+ `features` and `labels` are consumed by `model_fn`. They should
+ satisfy the expectation of `model_fn` from inputs.
steps: Number of steps for which to evaluate model. If `None`, evaluates
until `input_fn` raises an end-of-input exception.
- hooks: List of `SessionRunHook` subclass instances. Used for callbacks
- inside the evaluation call.
+ hooks: List of `tf.train.SessionRunHook` subclass instances. Used for
+ callbacks inside the evaluation call.
checkpoint_path: Path of a specific checkpoint to evaluate. If `None`, the
latest checkpoint in `model_dir` is used. If there are no checkpoints
in `model_dir`, evaluation is run with newly initialized `Variables`
- instead of restored from checkpoint.
+ instead of ones restored from checkpoint.
name: Name of the evaluation if user needs to run multiple evaluations on
different data sets, such as on training data vs test data. Metrics for
different evaluations are saved in separate folders, and appear
@@ -479,33 +476,33 @@ class Estimator(object):
Args:
input_fn: A function that constructs the features. Prediction continues
- until `input_fn` raises an end-of-input exception (`OutOfRangeError` or
- `StopIteration`).
+ until `input_fn` raises an end-of-input exception
+ (`tf.errors.OutOfRangeError` or `StopIteration`).
See @{$premade_estimators#create_input_functions} for more
information. The function should construct and return one of
the following:
- * A 'tf.data.Dataset' object: Outputs of `Dataset` object must have
+ * A `tf.data.Dataset` object: Outputs of `Dataset` object must have
same constraints as below.
- * features: A `Tensor` or a dictionary of string feature name to
+ * features: A `tf.Tensor` or a dictionary of string feature name to
`Tensor`. features are consumed by `model_fn`. They should satisfy
the expectation of `model_fn` from inputs.
* A tuple, in which case the first item is extracted as features.
predict_keys: list of `str`, name of the keys to predict. It is used if
- the `EstimatorSpec.predictions` is a `dict`. If `predict_keys` is used
- then rest of the predictions will be filtered from the dictionary. If
- `None`, returns all.
- hooks: List of `SessionRunHook` subclass instances. Used for callbacks
- inside the prediction call.
+ the `tf.estimator.EstimatorSpec.predictions` is a `dict`. If
+ `predict_keys` is used then rest of the predictions will be filtered
+ from the dictionary. If `None`, returns all.
+ hooks: List of `tf.train.SessionRunHook` subclass instances. Used for
+ callbacks inside the prediction call.
checkpoint_path: Path of a specific checkpoint to predict. If `None`, the
latest checkpoint in `model_dir` is used. If there are no checkpoints
in `model_dir`, prediction is run with newly initialized `Variables`
- instead of restored from checkpoint.
- yield_single_examples: If False, yield the whole batch as returned by the
- `model_fn` instead of decomposing the batch into individual elements.
- This is useful if `model_fn` returns some tensors whose first dimension
- is not equal to the batch size.
+ instead of ones restored from checkpoint.
+ yield_single_examples: If `False`, yields the whole batch as returned by
+ the `model_fn` instead of decomposing the batch into individual
+ elements. This is useful if `model_fn` returns some tensors whose first
+ dimension is not equal to the batch size.
Yields:
Evaluated values of `predictions` tensors.
@@ -513,10 +510,10 @@ class Estimator(object):
Raises:
ValueError: Could not find a trained model in `model_dir`.
ValueError: If batch length of predictions is not the same and
- `yield_single_examples` is True.
+ `yield_single_examples` is `True`.
ValueError: If there is a conflict between `predict_keys` and
`predictions`. For example if `predict_keys` is not `None` but
- `EstimatorSpec.predictions` is not a `dict`.
+ `tf.estimator.EstimatorSpec.predictions` is not a `dict`.
"""
with context.graph_mode():
hooks = _check_hooks_type(hooks)
@@ -571,14 +568,10 @@ class Estimator(object):
return
allowed_overrides = set([
- '_call_input_fn', '_call_model_fn',
- '_convert_train_steps_to_hooks', '_convert_eval_steps_to_hooks',
- '_create_global_step', '_create_and_assert_global_step',
+ '_create_and_assert_global_step',
'_tf_api_names', '_tf_api_names_v1', '_estimator_api_names',
'_estimator_api_names_v1', '_estimator_api_constants',
'_estimator_api_constants_v1',
- '_validate_features_in_predict_input',
- '_add_meta_graph_for_mode'
])
estimator_members = set([m for m in Estimator.__dict__.keys()
if not m.startswith('__')])
@@ -599,30 +592,34 @@ class Estimator(object):
checkpoint_path=None,
strip_default_attrs=False):
# pylint: disable=line-too-long
- """Exports inference graph as a SavedModel into given dir.
+ """Exports inference graph as a `SavedModel` into the given dir.
For a detailed guide, see
- @{$saved_model#using_savedmodel_with_estimators$Using SavedModel with Estimators}.
+ @{$saved_model#using_savedmodel_with_estimators$Using SavedModel with Estimators}.
This method builds a new graph by first calling the
- serving_input_receiver_fn to obtain feature `Tensor`s, and then calling
- this `Estimator`'s model_fn to generate the model graph based on those
+ `serving_input_receiver_fn` to obtain feature `Tensor`s, and then calling
+ this `Estimator`'s `model_fn` to generate the model graph based on those
features. It restores the given checkpoint (or, lacking that, the most
recent checkpoint) into this graph in a fresh session. Finally it creates
- a timestamped export directory below the given export_dir_base, and writes
- a `SavedModel` into it containing a single `MetaGraphDef` saved from this
+ a timestamped export directory below the given `export_dir_base`, and writes
+ a `SavedModel` into it containing a single `tf.MetaGraphDef` saved from this
session.
The exported `MetaGraphDef` will provide one `SignatureDef` for each
- element of the export_outputs dict returned from the model_fn, named using
+ element of the `export_outputs` dict returned from the `model_fn`, named using
the same keys. One of these keys is always
- signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY, indicating which
+ `tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY`,
+ indicating which
signature will be served when a serving request does not specify one.
For each signature, the outputs are provided by the corresponding
- `ExportOutput`s, and the inputs are always the input receivers provided by
- the serving_input_receiver_fn.
+ `tf.estimator.export.ExportOutput`s, and the inputs are always the input
+ receivers provided by
+ the `serving_input_receiver_fn`.
- Extra assets may be written into the SavedModel via the assets_extra
+ Extra assets may be written into the `SavedModel` via the `assets_extra`
argument. This should be a dict, where each key gives a destination path
(including the filename) relative to the assets.extra directory. The
corresponding value gives the full path of the source file to be copied.
@@ -631,23 +628,27 @@ class Estimator(object):
Args:
export_dir_base: A string containing a directory in which to create
- timestamped subdirectories containing exported SavedModels.
- serving_input_receiver_fn: A function that takes no argument and
- returns a `ServingInputReceiver` or `TensorServingInputReceiver`.
+ timestamped subdirectories containing exported `SavedModel`s.
+ serving_input_receiver_fn: A function that takes no argument and returns a
+ `tf.estimator.export.ServingInputReceiver` or
+ `tf.estimator.export.TensorServingInputReceiver`.
assets_extra: A dict specifying how to populate the assets.extra directory
- within the exported SavedModel, or `None` if no extra assets are needed.
- as_text: whether to write the SavedModel proto in text format.
+ within the exported `SavedModel`, or `None` if no extra assets are
+ needed.
+ as_text: whether to write the `SavedModel` proto in text format.
checkpoint_path: The checkpoint path to export. If `None` (the default),
the most recent checkpoint found within the model directory is chosen.
strip_default_attrs: Boolean. If `True`, default-valued attributes will be
- removed from the NodeDefs. For a detailed guide, see
- [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
+ removed from the `NodeDef`s. For a detailed guide, see [Stripping
+ Default-Valued
+ Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
Returns:
The string path to the exported directory.
Raises:
- ValueError: if no serving_input_receiver_fn is provided, no export_outputs
+ ValueError: if no `serving_input_receiver_fn` is provided, no `export_outputs`
are provided, or no checkpoint can be found.
"""
# pylint: enable=line-too-long
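Editor's note: a minimal export sketch matching this description, assuming `est` is a trained `tf.estimator.Estimator` and `feature_spec` maps feature names to parsing specs.

```python
import tensorflow as tf

serving_input_fn = (
    tf.estimator.export.build_parsing_serving_input_receiver_fn(feature_spec))
export_dir = est.export_savedmodel("/tmp/exports", serving_input_fn)
```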
@@ -668,35 +669,37 @@ class Estimator(object):
strip_default_attrs=False,
mode=model_fn_lib.ModeKeys.PREDICT):
# pylint: disable=line-too-long
- """Exports a single train/eval/predict graph as a SavedModel.
+ """Exports a single train/eval/predict graph as a `SavedModel`.
- This method is a wrapper for _export_all_saved_models, and wraps a raw
- input_receiver_fn in a dictionary to pass in to that function.
- See _export_all_saved_models for full docs.
+ This method is a wrapper for `_export_all_saved_models`, and wraps a raw
+ `input_receiver_fn` in a dictionary to pass in to that function.
+ See `_export_all_saved_models` for full docs.
- See tf.contrib.estimator.export_saved_model_for_mode for the currently
+ See `tf.contrib.estimator.export_saved_model_for_mode` for the currently
exposed version of this function.
Args:
export_dir_base: A string containing a directory in which to create
- timestamped subdirectories containing exported SavedModels.
- input_receiver_fn: a function that takes no argument and
- returns the appropriate subclass of `InputReceiver`.
+ timestamped subdirectories containing exported `SavedModel`s.
+ input_receiver_fn: a function that takes no argument and returns the
+ appropriate subclass of `InputReceiver`.
assets_extra: A dict specifying how to populate the assets.extra directory
- within the exported SavedModel, or `None` if no extra assets are needed.
- as_text: whether to write the SavedModel proto in text format.
+ within the exported `SavedModel`, or `None` if no extra assets are
+ needed.
+ as_text: whether to write the `SavedModel` proto in text format.
checkpoint_path: The checkpoint path to export. If `None` (the default),
the most recent checkpoint found within the model directory is chosen.
strip_default_attrs: Boolean. If `True`, default-valued attributes will be
- removed from the NodeDefs. For a detailed guide, see
- [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
- mode: tf.estimator.ModeKeys value indicating with mode will be exported.
+ removed from the `NodeDef`s. For a detailed guide, see [Stripping
+ Default-Valued
+ Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
+ mode: `tf.estimator.ModeKeys` value indicating which mode will be exported.
Returns:
The string path to the exported directory.
Raises:
- ValueError: if input_receiver_fn is None, no export_outputs
+ ValueError: if `input_receiver_fn` is `None`, no `export_outputs`
are provided, or no checkpoint can be found.
"""
# pylint: enable=line-too-long
@@ -720,40 +723,46 @@ class Estimator(object):
checkpoint_path=None,
strip_default_attrs=False):
# pylint: disable=line-too-long
- """Exports a SavedModel containing MetaGraphDefs for each requested mode.
+ """Exports a `SavedModel` containing `tf.MetaGraphDefs` for each requested mode.
- See tf.contrib.estimator.export_all_saved_models for the currently
+ See `tf.contrib.estimator.export_all_saved_models` for the currently
exposed version of this function.
- For each mode passed in via the input_receiver_fn_map,
- this method builds a new graph by calling the input_receiver_fn to obtain
+ For each mode passed in via the `input_receiver_fn_map`,
+ this method builds a new graph by calling the `input_receiver_fn` to obtain
feature and label `Tensor`s. Next, this method calls the `Estimator`'s
- model_fn in the passed mode to generate the model graph based on
+ `model_fn` in the passed mode to generate the model graph based on
those features and labels, and restores the given checkpoint
(or, lacking that, the most recent checkpoint) into the graph.
- Only one of the modes is used for saving variables to the SavedModel
- (order of preference: TRAIN, EVAL, then PREDICT), such that up to three
- MetaGraphDefs are saved with a single set of variables in a single
- SavedModel directory.
-
- For the variables and MetaGraphDefs, a timestamped export directory below
- export_dir_base, and writes a `SavedModel` into it containing
- the `MetaGraphDef` for the given mode and its associated signatures.
+ Only one of the modes is used for saving variables to the `SavedModel`
+ (order of preference: @{tf.estimator.ModeKeys#TRAIN$TRAIN},
+ @{tf.estimator.ModeKeys#EVAL$EVAL}, then
+ @{tf.estimator.ModeKeys#PREDICT$PREDICT}), such that up to three
+ `tf.MetaGraphDefs` are saved with a single set of variables in a single
+ `SavedModel` directory.
+
+ For the variables and `tf.MetaGraphDefs`, this method creates a timestamped
+ export directory below `export_dir_base` and writes a `SavedModel` into it
+ containing the `tf.MetaGraphDef` for the given mode and its associated
+ signatures.
For prediction, the exported `MetaGraphDef` will provide one `SignatureDef`
- for each element of the export_outputs dict returned from the model_fn,
+ for each element of the `export_outputs` dict returned from the `model_fn`,
named using the same keys. One of these keys is always
- signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY, indicating which
+ `tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY`,
+ indicating which
signature will be served when a serving request does not specify one.
For each signature, the outputs are provided by the corresponding
- `ExportOutput`s, and the inputs are always the input receivers provided by
- the serving_input_receiver_fn.
+ `tf.estimator.export.ExportOutput`s, and the inputs are always the input
+ receivers provided by
+ the `serving_input_receiver_fn`.
- For training and evaluation, the train_op is stored in an extra collection,
- and loss, metrics, and predictions are included in a SignatureDef for the
+ For training and evaluation, the `train_op` is stored in an extra collection,
+ and loss, metrics, and predictions are included in a `SignatureDef` for the
mode in question.
- Extra assets may be written into the SavedModel via the assets_extra
+ Extra assets may be written into the `SavedModel` via the `assets_extra`
argument. This should be a dict, where each key gives a destination path
(including the filename) relative to the assets.extra directory. The
corresponding value gives the full path of the source file to be copied.
@@ -762,25 +771,28 @@ class Estimator(object):
Args:
export_dir_base: A string containing a directory in which to create
- timestamped subdirectories containing exported SavedModels.
- input_receiver_fn_map: dict of tf.estimator.ModeKeys to input_receiver_fn
- mappings, where the input_receiver_fn is a function that takes no
- argument and returns the appropriate subclass of `InputReceiver`.
+ timestamped subdirectories containing exported `SavedModel`s.
+ input_receiver_fn_map: dict of `tf.estimator.ModeKeys` to
+ `input_receiver_fn` mappings, where the `input_receiver_fn` is a
+ function that takes no arguments and returns the appropriate subclass of
+ `InputReceiver`.
assets_extra: A dict specifying how to populate the assets.extra directory
- within the exported SavedModel, or `None` if no extra assets are needed.
- as_text: whether to write the SavedModel proto in text format.
+ within the exported `SavedModel`, or `None` if no extra assets are
+ needed.
+ as_text: whether to write the `SavedModel` proto in text format.
checkpoint_path: The checkpoint path to export. If `None` (the default),
the most recent checkpoint found within the model directory is chosen.
strip_default_attrs: Boolean. If `True`, default-valued attributes will be
- removed from the NodeDefs. For a detailed guide, see
- [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
+ removed from the `NodeDef`s. For a detailed guide, see [Stripping
+ Default-Valued
+ Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
Returns:
- A dict of tf.estimator.ModeKeys value to string path for each exported
+ A dict of `tf.estimator.ModeKeys` value to string path for each exported
directory.
Raises:
- ValueError: if any input_receiver_fn is None, no export_outputs
+ ValueError: if any `input_receiver_fn` is `None`, no `export_outputs`
are provided, or no checkpoint can be found.
"""
# pylint: enable=line-too-long
@@ -853,25 +865,29 @@ class Estimator(object):
export_tags=None,
check_variables=True):
# pylint: disable=line-too-long
- """Loads variables and adds them along with a MetaGraphDef for saving.
+ """Loads variables and adds them along with a `tf.MetaGraphDef` for saving.
Args:
- builder: instance of SavedModelBuilder that will be used for saving.
- input_receiver_fn_map: dict of tf.estimator.ModeKeys to input_receiver_fn
- mappings, where the input_receiver_fn is a function that takes no
- argument and returns the appropriate subclass of `InputReceiver`.
+ builder: instance of `tf.saved_model.builder.SavedModelBuilder` that will
+ be used for saving.
+ input_receiver_fn_map: dict of `tf.estimator.ModeKeys` to
+ `input_receiver_fn` mappings, where the `input_receiver_fn` is a
+ function that takes no argument and returns the appropriate subclass of
+ `InputReceiver`.
checkpoint_path: The checkpoint path to export. If `None` (the default),
the most recent checkpoint found within the model directory is chosen.
strip_default_attrs: Boolean. If `True`, default-valued attributes will be
- removed from the NodeDefs. For a detailed guide, see
- [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
- save_variables: bool, whether variables should be saved. If False, just
- the MetaGraphDef will be saved. Note that save_variables should only be
- True for the first call to this function, and the SavedModelBuilder will
- raise an error if that is not the case.
- mode: tf.estimator.ModeKeys value indicating which mode will be exported.
- export_tags: The set of tags with which to save `MetaGraphDef`. If None,
- a default set will be selected to matched the passed mode.
+ removed from the `NodeDef`s. For a detailed guide, see [Stripping
+ Default-Valued
+ Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
+ save_variables: bool, whether variables should be saved. If `False`, just
+ the `tf.MetaGraphDef` will be saved. Note that `save_variables` should
+ only be `True` for the first call to this function, and the
+ `SavedModelBuilder` will raise an error if that is not the case.
+ mode: `tf.estimator.ModeKeys` value indicating which mode will be
+ exported.
+ export_tags: The set of tags with which to save `tf.MetaGraphDef`. If
+ `None`, a default set will be selected to match the passed mode.
check_variables: bool, whether to check the checkpoint has all variables.
Raises:
@@ -953,21 +969,23 @@ class Estimator(object):
builder.add_meta_graph(**meta_graph_kwargs)
def _get_export_outputs_for_spec(self, estimator_spec):
- """Given an EstimatorSpec, determine what our export outputs should be.
+ """Given an `EstimatorSpec`, determine what our export outputs should be.
- EstimatorSpecs contain export_outputs that are used for serving, but for
+ `EstimatorSpec`s contain `export_outputs` that are used for serving, but for
training and eval graphs, we must wrap the tensors of interest in
- appropriate ExportOutput objects.
+ appropriate `tf.estimator.export.ExportOutput` objects.
Args:
- estimator_spec: EstimatorSpec object that will be exported.
+ estimator_spec: `tf.estimator.EstimatorSpec` object that will be exported.
Returns:
- a dict mapping export_output_name to ExportOutput object.
+ a dict mapping `export_output_name` to `tf.estimator.export.ExportOutput`
+ object.
Raises:
- ValueError: if an appropriate ExportOutput cannot be found for the
- passed EstimatorSpec.mode
+ ValueError: if an appropriate `ExportOutput` cannot be found for the
+ passed `EstimatorSpec.mode`
"""
mode = estimator_spec.mode
if mode == model_fn_lib.ModeKeys.PREDICT:
@@ -1044,13 +1062,13 @@ class Estimator(object):
"""Creates the global step tensor in graph.
The global step tensor must be an integer type with name 'global_step' and
- be added to the collection `tf.GraphKeys.GLOBAL_STEP`.
+ be added to the collection @{tf.GraphKeys#GLOBAL_STEP$GLOBAL_STEP}.
Args:
graph: The graph in which to create the global step tensor.
Returns:
- The global step `Tensor`.
+ The global step `tf.Tensor`.
"""
return training.create_global_step(graph)
@@ -1061,7 +1079,7 @@ class Estimator(object):
graph: The graph in which to create the global step tensor.
Returns:
- The global step `Tensor`.
+ The global step `tf.Tensor`.
"""
step = self._create_global_step(graph)
assert step == training.get_global_step()
@@ -1073,21 +1091,21 @@ class Estimator(object):
Args:
input_fn: The input function.
- mode: ModeKeys
+ mode: `tf.estimator.ModeKeys`
Returns:
- The return value of the passed input_fn, which should be one of:
+ The return value of the passed `input_fn`, which should be one of:
* A 'tf.data.Dataset' object: Outputs of `Dataset` object must be a
- tuple (features, labels) with same constraints as below.
- * A tuple (features, labels): Where `features` is a `Tensor` or a
+ tuple `(features, labels)` with same constraints as below.
+ * A tuple `(features, labels)`: Where `features` is a `Tensor` or a
dictionary of string feature name to `Tensor` and `labels` is a
`Tensor` or a dictionary of string label name to `Tensor`. Both
`features` and `labels` are consumed by `model_fn`. They should
satisfy the expectation of `model_fn` from inputs.
Raises:
- ValueError: if input_fn takes invalid arguments.
+ ValueError: if `input_fn` takes invalid arguments.
"""
input_fn_args = function_utils.fn_args(input_fn)
kwargs = {}
@@ -1106,14 +1124,14 @@ class Estimator(object):
Args:
features: features dict.
labels: labels dict.
- mode: ModeKeys
- config: RunConfig
+ mode: `tf.estimator.ModeKeys`
+ config: `tf.estimator.RunConfig`
Returns:
- An `EstimatorSpec` object.
+ A `tf.estimator.EstimatorSpec` object.
Raises:
- ValueError: if model_fn returns invalid objects.
+ ValueError: if `model_fn` returns invalid objects.
"""
model_fn_args = function_utils.fn_args(self._model_fn)
kwargs = {}
@@ -1146,14 +1164,14 @@ class Estimator(object):
return self._train_model_default(input_fn, hooks, saving_listeners)
def _train_model_default(self, input_fn, hooks, saving_listeners):
- """Initiate training with input_fn, without DistributionStrategies.
+ """Initiate training with `input_fn`, without `DistributionStrategies`.
Args:
input_fn: A function that provides input data for training as minibatches.
- hooks: List of `SessionRunHook` subclass instances. Used for callbacks
- inside the training loop.
- saving_listeners: list of `CheckpointSaverListener` objects. Used for
- callbacks that run immediately before or after checkpoint savings.
+ hooks: List of `tf.train.SessionRunHook` subclass instances. Used for
+ callbacks inside the training loop.
+ saving_listeners: list of `tf.train.CheckpointSaverListener` objects. Used
+ for callbacks that run immediately before or after checkpoint savings.
Returns:
Loss from training
@@ -1180,14 +1198,14 @@ class Estimator(object):
saving_listeners)
def _train_model_distributed(self, input_fn, hooks, saving_listeners):
- """Initiate training with input_fn, using DistributionStrategies.
+ """Initiate training with `input_fn`, using `DistributionStrategies`.
Args:
input_fn: A function that provides input data for training as minibatches.
- hooks: List of `SessionRunHook` subclass instances. Used for callbacks
- inside the training loop.
- saving_listeners: list of `CheckpointSaverListener` objects. Used for
- callbacks that run immediately before or after checkpoint savings.
+ hooks: List of `tf.train.SessionRunHook` subclass instances. Used for
+ callbacks inside the training loop.
+ saving_listeners: list of `tf.train.CheckpointSaverListener` objects. Used
+ for callbacks that run immediately before or after checkpoint savings.
Returns:
Loss from training
@@ -1535,9 +1553,9 @@ def maybe_overwrite_model_dir_and_session_config(config, model_dir):
"`model_dir` are set both in constructor and `RunConfig`, but with "
"different values. In constructor: '{}', in `RunConfig`: "
"'{}' ".format(model_dir, config.model_dir))
- if model_dir:
- config = run_config.RunConfig.replace(config, model_dir=model_dir)
- if getattr(config, 'model_dir', None) is None:
+ if model_dir:
+ config = run_config.RunConfig.replace(config, model_dir=model_dir)
+ elif getattr(config, 'model_dir', None) is None:
model_dir = tempfile.mkdtemp()
logging.warning('Using temporary folder as model directory: %s', model_dir)
config = run_config.RunConfig.replace(config, model_dir=model_dir)
@@ -1546,7 +1564,7 @@ def maybe_overwrite_model_dir_and_session_config(config, model_dir):
def create_per_tower_ready_op(scaffold):
- """Create a Scaffold.ready_op inside a tower."""
+ """Create a `tf.train.Scaffold.ready_op` inside a tower."""
if scaffold.ready_op:
return scaffold.ready_op
@@ -1561,7 +1579,7 @@ def create_per_tower_ready_op(scaffold):
def create_per_tower_ready_for_local_init_op(scaffold):
- """Create a Scaffold.ready_for_local_init_op inside a tower."""
+ """Create a `tf.train.Scaffold.ready_for_local_init_op` inside a tower."""
if scaffold.ready_for_local_init_op:
return scaffold.ready_for_local_init_op
@@ -1659,7 +1677,7 @@ def _check_checkpoint_available(model_dir):
def _check_hooks_type(hooks):
- """Returns hooks if all are SessionRunHook, raises TypeError otherwise."""
+ """Returns hooks if all are `SessionRunHook`, raises TypeError otherwise."""
hooks = list(hooks or [])
for h in hooks:
if not isinstance(h, training.SessionRunHook):
@@ -1679,17 +1697,18 @@ def _check_listeners_type(saving_listeners):
def _get_replica_device_setter(config):
- """Creates a replica device setter if required as a default device_fn.
+ """Creates a replica device setter if required as a default `device_fn`.
- `Estimator` uses ReplicaDeviceSetter as a default device placer. It sets the
- distributed related arguments such as number of ps_replicas based on given
- config.
+ `Estimator` uses `tf.train.ReplicaDeviceSetter` as a default device placer.
+ It sets distribution-related arguments such as the number of `ps_replicas`
+ based on the given `config`.
Args:
- config: A `RunConfig` instance.
+ config: A `tf.estimator.RunConfig` instance.
Returns:
- A replica device setter, or None.
+ A replica device setter, or `None`.
"""
if config.task_type:
worker_device = '/job:%s/task:%d' % (config.task_type, config.task_id)
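For context, a minimal sketch of what the device function assembled here does, assuming a hypothetical cluster with two ps tasks; `tf.train.replica_device_setter` is the helper underlying the placement described above:

    # Sketch only: variables created under the returned device function are
    # placed round-robin on the ps job; other ops stay on the worker device.
    import tensorflow as tf

    device_fn = tf.train.replica_device_setter(
        worker_device='/job:worker/task:0',  # same format as built above
        ps_tasks=2)                          # hypothetical: two ps replicas

    with tf.device(device_fn):
        v = tf.get_variable('v', shape=[10])  # lands on /job:ps/task:0 or 1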
@@ -1708,7 +1727,7 @@ def _get_replica_device_setter(config):
def _verify_model_fn_args(model_fn, params):
- """Verifies model fn arguments."""
+ """Verifies `model_fn` arguments."""
args = set(function_utils.fn_args(model_fn))
if 'features' not in args:
raise ValueError('model_fn (%s) must include features argument.' % model_fn)
@@ -1853,7 +1872,7 @@ def _write_checkpoint_path_to_summary(output_dir, checkpoint_path,
def _has_dataset_or_queue_runner(maybe_tensor):
- """Returns True if TF dataset or QueueRunner has been used."""
+ """Returns `True` if `Dataset` or `QueueRunner` has been used."""
# Check TF dataset first. Here, we use a simple algorithm to check the top
# level Tensors only, which should be sufficient for most users.
tensors = [x for x in nest.flatten(maybe_tensor) if isinstance(x, ops.Tensor)]
@@ -1876,9 +1895,9 @@ class WarmStartSettings(
'var_name_to_vocab_info',
'var_name_to_prev_var_name',
])):
- """Settings for warm-starting in Estimators.
+ """Settings for warm-starting in `tf.estimator.Estimators`.
- Example Use with canned `DNNEstimator`:
+ Example Use with canned `tf.estimator.DNNEstimator`:
```
emb_vocab_file = tf.feature_column.embedding_column(
@@ -1995,23 +2014,19 @@ class WarmStartSettings(
ckpt_to_initialize_from: [Required] A string specifying the directory with
checkpoint file(s) or path to checkpoint from which to warm-start the
model parameters.
- vars_to_warm_start: [Optional] One of the following:
-
- - A regular expression (string) that captures which variables to
- warm-start (see tf.get_collection). This expression will only consider
- variables in the TRAINABLE_VARIABLES collection.
- - A list of Variables to warm-start.
- - A list of strings, each representing a full variable name to warm-start.
- - `None`, in which case only variables specified in
- `var_name_to_vocab_info` will be warm-started.
-
- Defaults to `'.*'`, which warm-starts all variables in the
- TRAINABLE_VARIABLES collection. Note that this excludes variables such as
- accumulators and moving statistics from batch norm.
+ vars_to_warm_start: [Optional] One of the following:
+
+ - A regular expression (string) that captures which variables to
+ warm-start (see `tf.get_collection`). This expression will only consider
+ variables in the `TRAINABLE_VARIABLES` collection.
+ - A list of Variables to warm-start.
+ - A list of strings, each representing a full variable name to warm-start.
+ - `None`, in which case only variables specified in
+ `var_name_to_vocab_info` will be warm-started.
+
+ Defaults to `'.*'`, which warm-starts all variables in the
+ `TRAINABLE_VARIABLES` collection. Note that this excludes variables such as
+ accumulators and moving statistics from batch norm.
var_name_to_vocab_info: [Optional] Dict of variable names (strings) to
- VocabInfo. The variable names should be "full" variables, not the names
- of the partitions. If not explicitly provided, the variable is assumed to
- have no vocabulary.
+ `tf.estimator.VocabInfo`. The variable names should be "full" variables,
+ not the names of the partitions. If not explicitly provided, the variable
+ is assumed to have no vocabulary.
var_name_to_prev_var_name: [Optional] Dict of variable names (strings) to
name of the previously-trained variable in `ckpt_to_initialize_from`. If
not explicitly provided, the name of the variable is assumed to be same
@@ -2036,7 +2051,7 @@ class WarmStartSettings(
def _get_saved_model_ckpt(saved_model_dir):
- """Return path to variables checkpoint in a SavedModel directory."""
+ """Return path to variables checkpoint in a `SavedModel` directory."""
if not gfile.Exists(
os.path.join(saved_model_utils.get_variables_dir(saved_model_dir),
compat.as_text('variables.index'))):
@@ -2046,18 +2061,20 @@ def _get_saved_model_ckpt(saved_model_dir):
def _get_default_warm_start_settings(warm_start_from):
- """Returns default WarmStartSettings.
+ """Returns default `tf.estimator.WarmStartSettings`.
Args:
warm_start_from: Either a string representing the filepath of a checkpoint
- or SavedModel to initialize from, or an instance of WarmStartSettings.
+ or `SavedModel` to initialize from, or an instance of
+ `tf.estimator.WarmStartSettings`.
Returns:
- Either None or an instance of WarmStartSettings.
+ Either `None` or an instance of `WarmStartSettings`.
Raises:
- ValueError: If warm_start_from is not None but is neither a string nor an
- instance of WarmStartSettings.
+ ValueError: If `warm_start_from` is not `None` but is neither a string nor
+ an instance of `WarmStartSettings`.
"""
if warm_start_from is None:
return None
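A condensed sketch of the precedence implemented by the `if`/`elif` fix above; names are standalone, and the real function also validates conflicting values and returns an updated `RunConfig`:

    # Sketch of the model_dir resolution order after the elif fix.
    import tempfile

    def resolve_model_dir(config_model_dir, model_dir):
        if model_dir:                    # explicit argument always wins
            return model_dir
        elif config_model_dir is None:   # neither source set one: use tmp dir
            return tempfile.mkdtemp()
        return config_model_dir          # otherwise keep the RunConfig value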
diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py
index e3f22d9010..05d1a04d2f 100644
--- a/tensorflow/python/estimator/estimator_test.py
+++ b/tensorflow/python/estimator/estimator_test.py
@@ -58,6 +58,7 @@ from tensorflow.python.ops import string_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.ops.losses import losses
+from tensorflow.python.ops.random_ops import random_uniform
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
@@ -158,16 +159,7 @@ class EstimatorInheritanceConstraintTest(test.TestCase):
def __init__(self):
super(_Estimator, self).__init__(model_fn=dummy_model_fn)
- def _call_input_fn(self, input_fn, mode):
- return input_fn()
-
- def _create_global_step(self, graph):
- pass
-
- def _convert_train_steps_to_hooks(self, steps, max_steps):
- pass
-
- def _convert_eval_steps_to_hooks(self, steps):
+ def _tf_api_names(self):
pass
_Estimator()
@@ -473,6 +465,29 @@ class EstimatorTrainTest(test.TestCase):
est.train(InputFn(), steps=1)
self.assertEqual(1, input_fn_call_count[0])
+ def test_nested_input_fn(self):
+ expected_params = {'batch_size': 10}
+
+ def _input_fn():
+ dataset_features = dataset_ops.Dataset.from_tensor_slices(
+ (random_uniform([4]),
+ random_uniform([4, 100], maxval=100, dtype=dtypes.int32)))
+ dataset_labels = dataset_ops.Dataset.from_tensor_slices(
+ random_uniform([4, 10]))
+ dataset = dataset_ops.Dataset.zip((dataset_features, dataset_labels))
+ dataset = dataset.repeat(-1)
+ iterator = dataset.make_initializable_iterator()
+ return iterator.get_next()
+
+ def _model_fn(features, labels, mode, params, config):
+ del params, config
+ return model_fn_global_step_incrementer(features, labels, mode)
+
+ expected_config = run_config.RunConfig().replace(tf_random_seed=4321)
+ est = estimator.Estimator(
+ model_fn=_model_fn, params=expected_params, config=expected_config)
+ est.train(_input_fn, steps=4)
+
def test_input_fn_args(self):
expected_mode = model_fn_lib.ModeKeys.TRAIN
expected_params = {'batch_size': 10}
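The new `test_nested_input_fn` above feeds features that are themselves a tuple. A minimal sketch of that nested `input_fn` shape, with synthetic arrays standing in for the test's random tensors:

    # Sketch: Dataset yielding ((f1, f2), labels) rather than a flat pair.
    import numpy as np
    import tensorflow as tf

    def nested_input_fn():
        features = tf.data.Dataset.from_tensor_slices(
            (np.random.rand(4).astype(np.float32),       # f1: one float/row
             np.random.randint(0, 100, (4, 100))))       # f2: int features
        labels = tf.data.Dataset.from_tensor_slices(
            np.random.rand(4, 10).astype(np.float32))
        return tf.data.Dataset.zip((features, labels)).repeat()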
diff --git a/tensorflow/python/estimator/keras.py b/tensorflow/python/estimator/keras.py
index a5f07fea3b..e4ce5339d0 100644
--- a/tensorflow/python/estimator/keras.py
+++ b/tensorflow/python/estimator/keras.py
@@ -43,7 +43,7 @@ from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.training import checkpoint_management
-from tensorflow.python.training import distribute as distribute_lib
+from tensorflow.python.training import distribution_strategy_context
from tensorflow.python.training import optimizer as tf_optimizer_module
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training import training_util
@@ -361,7 +361,7 @@ def _create_keras_model_fn(keras_model, custom_objects=None):
"""model_fn for keras Estimator."""
# Raise an error when users use DistributionStrategy with native Keras
# optimizers. Currently we only support native TensorFlow optimizers.
- if distribute_lib.has_distribution_strategy() and \
+ if distribution_strategy_context.has_distribution_strategy() and \
not isinstance(keras_model.optimizer,
(tf_optimizer_module.Optimizer, optimizers.TFOptimizer)):
raise ValueError('Only TensorFlow native optimizers are supported with '
@@ -373,7 +373,7 @@ def _create_keras_model_fn(keras_model, custom_objects=None):
# We need to make sure that the output names of the last layer in the model
# is the same for each of the cloned models. This is required for mirrored
# strategy when we call regroup.
- if distribute_lib.has_distribution_strategy():
+ if distribution_strategy_context.has_distribution_strategy():
for name in model.output_names:
name = re.compile(r'_\d$').sub('', name)
model_output_names.append(name)
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index 98a1802490..5527f52860 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -4860,6 +4860,18 @@ class Graph(object):
else:
self._graph_control_dependencies_stack = control_dependencies
+ @property
+ def _distribution_strategy_stack(self):
+ """A stack to maintain distribution strategy context for each thread."""
+ if not hasattr(self._thread_local, "_distribution_strategy_stack"):
+ self._thread_local._distribution_strategy_stack = [] # pylint: disable=protected-access
+ return self._thread_local._distribution_strategy_stack # pylint: disable=protected-access
+
+ @_distribution_strategy_stack.setter
+ def _distribution_strategy_stack(self, _distribution_strategy_stack):
+ self._thread_local._distribution_strategy_stack = ( # pylint: disable=protected-access
+ _distribution_strategy_stack)
+
def _mutation_lock(self):
"""Returns a lock to guard code that creates & mutates ops.
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index 9be6391b04..c2c97dd684 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -718,7 +718,7 @@ def run_in_graph_and_eager_modes(func=None,
def decorated(self, **kwargs):
try:
- with context.graph_mode():
+ with ops.Graph().as_default():
with self.test_session(use_gpu=use_gpu, config=config):
f(self, **kwargs)
except unittest.case.SkipTest:
diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD
index e04d0e93e2..fa1ec51aa7 100755
--- a/tensorflow/python/keras/BUILD
+++ b/tensorflow/python/keras/BUILD
@@ -296,109 +296,15 @@ py_test(
)
py_test(
- name = "densenet_test",
- size = "large",
- srcs = ["applications/densenet_test.py"],
- srcs_version = "PY2AND3",
- tags = ["nomsan"], # times out, http://b/78650237
- deps = [
- ":keras",
- "//tensorflow/python:client_testlib",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "inception_resnet_v2_test",
- size = "medium",
- srcs = ["applications/inception_resnet_v2_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":keras",
- "//tensorflow/python:client_testlib",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "inception_v3_test",
- size = "medium",
- srcs = ["applications/inception_v3_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":keras",
- "//tensorflow/python:client_testlib",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "mobilenet_test",
- size = "medium",
- srcs = ["applications/mobilenet_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":keras",
- "//tensorflow/python:client_testlib",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "nasnet_test",
- size = "large",
- srcs = ["applications/nasnet_test.py"],
- srcs_version = "PY2AND3",
- tags = ["nomsan"], # times out, http://b/78573625
- deps = [
- ":keras",
- "//tensorflow/python:client_testlib",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "resnet50_test",
- size = "medium",
- srcs = ["applications/resnet50_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":keras",
- "//tensorflow/python:client_testlib",
- ],
-)
-
-py_test(
- name = "vgg16_test",
- size = "small",
- srcs = ["applications/vgg16_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":keras",
- "//tensorflow/python:client_testlib",
- ],
-)
-
-py_test(
- name = "vgg19_test",
- size = "small",
- srcs = ["applications/vgg19_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":keras",
- "//tensorflow/python:client_testlib",
- ],
-)
-
-py_test(
- name = "xception_test",
- size = "medium",
- srcs = ["applications/xception_test.py"],
+ name = "applications_test",
+ size = "enormous",
+ srcs = ["applications/applications_test.py"],
+ shard_count = 2,
srcs_version = "PY2AND3",
deps = [
":keras",
"//tensorflow/python:client_testlib",
- "//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
],
)
@@ -493,7 +399,7 @@ py_test(
py_test(
name = "local_test",
- size = "medium",
+ size = "large",
srcs = ["layers/local_test.py"],
srcs_version = "PY2AND3",
deps = [
@@ -719,14 +625,15 @@ cuda_py_test(
)
py_test(
- name = "imagenet_utils_test",
+ name = "conv_utils_test",
size = "small",
- srcs = ["applications/imagenet_utils_test.py"],
+ srcs = ["utils/conv_utils_test.py"],
srcs_version = "PY2AND3",
deps = [
":keras",
"//tensorflow/python:client_testlib",
"//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
],
)
diff --git a/tensorflow/python/keras/applications/__init__.py b/tensorflow/python/keras/applications/__init__.py
index 51cc51998c..cd9462d6b5 100644
--- a/tensorflow/python/keras/applications/__init__.py
+++ b/tensorflow/python/keras/applications/__init__.py
@@ -39,7 +39,7 @@ from tensorflow.python.keras.applications.densenet import DenseNet201
from tensorflow.python.keras.applications.inception_resnet_v2 import InceptionResNetV2
from tensorflow.python.keras.applications.inception_v3 import InceptionV3
from tensorflow.python.keras.applications.mobilenet import MobileNet
-from tensorflow.python.keras.applications.mobilenet_v2 import MobileNetV2
+# TODO(fchollet): enable MobileNetV2 in next version.
from tensorflow.python.keras.applications.nasnet import NASNetLarge
from tensorflow.python.keras.applications.nasnet import NASNetMobile
from tensorflow.python.keras.applications.resnet50 import ResNet50
diff --git a/tensorflow/python/keras/applications/applications_test.py b/tensorflow/python/keras/applications/applications_test.py
new file mode 100644
index 0000000000..ef3198a937
--- /dev/null
+++ b/tensorflow/python/keras/applications/applications_test.py
@@ -0,0 +1,58 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Integration tests for Keras applications."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+
+from tensorflow.python.keras import applications
+from tensorflow.python.platform import test
+
+
+MODEL_LIST = [
+ (applications.ResNet50, 2048),
+ (applications.VGG16, 512),
+ (applications.VGG19, 512),
+ (applications.Xception, 2048),
+ (applications.InceptionV3, 2048),
+ (applications.InceptionResNetV2, 1536),
+ (applications.MobileNet, 1024),
+ # TODO(fchollet): enable MobileNetV2 in next version.
+ (applications.DenseNet121, 1024),
+ (applications.DenseNet169, 1664),
+ (applications.DenseNet201, 1920),
+ (applications.NASNetMobile, 1056),
+ (applications.NASNetLarge, 4032),
+]
+
+
+class ApplicationsTest(test.TestCase, parameterized.TestCase):
+
+ @parameterized.parameters(*MODEL_LIST)
+ def test_classification_model(self, model_fn, _):
+ model = model_fn(classes=1000, weights=None)
+ self.assertEqual(model.output_shape[-1], 1000)
+
+ @parameterized.parameters(*MODEL_LIST)
+ def test_feature_extraction_model(self, model_fn, output_dim):
+ model = model_fn(include_top=False, weights=None)
+ self.assertEqual(model.output_shape, (None, None, None, output_dim))
+
+
+if __name__ == '__main__':
+ test.main()
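The decorator used above expands `MODEL_LIST` into one generated test case per `(model_fn, output_dim)` tuple; the same mechanism reduced to a standalone sketch (hypothetical `SquareTest`):

    # absl's parameterized decorator generates one test method per tuple.
    from absl.testing import parameterized

    class SquareTest(parameterized.TestCase):

        @parameterized.parameters((2, 4), (3, 9))
        def test_square(self, x, expected):
            self.assertEqual(x * x, expected)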
diff --git a/tensorflow/python/keras/applications/densenet_test.py b/tensorflow/python/keras/applications/densenet_test.py
deleted file mode 100644
index 8b6aa281ad..0000000000
--- a/tensorflow/python/keras/applications/densenet_test.py
+++ /dev/null
@@ -1,101 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for DenseNet application."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.python import keras
-from tensorflow.python.platform import test
-
-
-class DenseNet121Test(test.TestCase):
-
- def test_with_top(self):
- model = keras.applications.DenseNet121(weights=None)
- self.assertEqual(model.output_shape, (None, 1000))
-
- def test_no_top(self):
- model = keras.applications.DenseNet121(weights=None, include_top=False)
- self.assertEqual(model.output_shape, (None, None, None, 1024))
-
- def test_with_pooling(self):
- model = keras.applications.DenseNet121(weights=None,
- include_top=False,
- pooling='avg')
- self.assertEqual(model.output_shape, (None, 1024))
-
- def test_weight_loading(self):
- with self.assertRaises(ValueError):
- keras.applications.DenseNet121(weights='unknown',
- include_top=False)
- with self.assertRaises(ValueError):
- keras.applications.DenseNet121(weights='imagenet',
- classes=2000)
-
-
-class DenseNet169Test(test.TestCase):
-
- def test_with_top(self):
- model = keras.applications.DenseNet169(weights=None)
- self.assertEqual(model.output_shape, (None, 1000))
-
- def test_no_top(self):
- model = keras.applications.DenseNet169(weights=None, include_top=False)
- self.assertEqual(model.output_shape, (None, None, None, 1664))
-
- def test_with_pooling(self):
- model = keras.applications.DenseNet169(weights=None,
- include_top=False,
- pooling='max')
- self.assertEqual(model.output_shape, (None, 1664))
-
- def test_weight_loading(self):
- with self.assertRaises(ValueError):
- keras.applications.DenseNet169(weights='unknown',
- include_top=False)
- with self.assertRaises(ValueError):
- keras.applications.DenseNet169(weights='imagenet',
- classes=2000)
-
-
-class DenseNet201(test.TestCase):
-
- def test_with_top(self):
- model = keras.applications.DenseNet201(weights=None)
- self.assertEqual(model.output_shape, (None, 1000))
-
- def test_no_top(self):
- model = keras.applications.DenseNet201(weights=None, include_top=False)
- self.assertEqual(model.output_shape, (None, None, None, 1920))
-
- def test_with_pooling(self):
- model = keras.applications.DenseNet201(weights=None,
- include_top=False,
- pooling='avg')
- self.assertEqual(model.output_shape, (None, 1920))
-
- def test_weight_loading(self):
- with self.assertRaises(ValueError):
- keras.applications.DenseNet201(weights='unknown',
- include_top=False)
- with self.assertRaises(ValueError):
- keras.applications.DenseNet201(weights='imagenet',
- classes=2000)
-
-
-if __name__ == '__main__':
- test.main()
diff --git a/tensorflow/python/keras/applications/imagenet_utils_test.py b/tensorflow/python/keras/applications/imagenet_utils_test.py
deleted file mode 100644
index 037e939ac5..0000000000
--- a/tensorflow/python/keras/applications/imagenet_utils_test.py
+++ /dev/null
@@ -1,93 +0,0 @@
-# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for Inception V3 application."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-
-from tensorflow.python import keras
-from tensorflow.python.keras.applications.imagenet_utils import preprocess_input
-from tensorflow.python.platform import test
-
-
-class ImageNetUtilsTest(test.TestCase):
-
- def test_preprocess_input(self):
- # Test batch of images
- x = np.random.uniform(0, 255, (2, 10, 10, 3))
- self.assertEqual(preprocess_input(x).shape, x.shape)
- out1 = preprocess_input(x, 'channels_last')
- out2 = preprocess_input(np.transpose(x, (0, 3, 1, 2)), 'channels_first')
- self.assertAllClose(out1, out2.transpose(0, 2, 3, 1))
-
- # Test single image
- x = np.random.uniform(0, 255, (10, 10, 3))
- self.assertEqual(preprocess_input(x).shape, x.shape)
- out1 = preprocess_input(x, 'channels_last')
- out2 = preprocess_input(np.transpose(x, (2, 0, 1)), 'channels_first')
- self.assertAllClose(out1, out2.transpose(1, 2, 0))
-
- def test_preprocess_input_symbolic(self):
- # Test image batch
- x = np.random.uniform(0, 255, (2, 10, 10, 3))
- inputs = keras.layers.Input(shape=x.shape[1:])
- outputs = keras.layers.Lambda(
- preprocess_input, output_shape=x.shape[1:])(inputs)
- model = keras.models.Model(inputs, outputs)
- assert model.predict(x).shape == x.shape
- # pylint: disable=g-long-lambda
- outputs1 = keras.layers.Lambda(lambda x:
- preprocess_input(x, 'channels_last'),
- output_shape=x.shape[1:])(inputs)
- model1 = keras.models.Model(inputs, outputs1)
- out1 = model1.predict(x)
- x2 = np.transpose(x, (0, 3, 1, 2))
- inputs2 = keras.layers.Input(shape=x2.shape[1:])
- # pylint: disable=g-long-lambda
- outputs2 = keras.layers.Lambda(lambda x:
- preprocess_input(x, 'channels_first'),
- output_shape=x2.shape[1:])(inputs2)
- model2 = keras.models.Model(inputs2, outputs2)
- out2 = model2.predict(x2)
- self.assertAllClose(out1, out2.transpose(0, 2, 3, 1))
-
- # Test single image
- x = np.random.uniform(0, 255, (10, 10, 3))
- inputs = keras.layers.Input(shape=x.shape)
- outputs = keras.layers.Lambda(preprocess_input,
- output_shape=x.shape)(inputs)
- model = keras.models.Model(inputs, outputs)
- assert model.predict(x[np.newaxis])[0].shape == x.shape
- # pylint: disable=g-long-lambda
- outputs1 = keras.layers.Lambda(lambda x:
- preprocess_input(x, 'channels_last'),
- output_shape=x.shape)(inputs)
- model1 = keras.models.Model(inputs, outputs1)
- out1 = model1.predict(x[np.newaxis])[0]
- x2 = np.transpose(x, (2, 0, 1))
- inputs2 = keras.layers.Input(shape=x2.shape)
- outputs2 = keras.layers.Lambda(lambda x:
- preprocess_input(x, 'channels_first'),
- output_shape=x2.shape)(inputs2) # pylint: disable=g-long-lambda
- model2 = keras.models.Model(inputs2, outputs2)
- out2 = model2.predict(x2[np.newaxis])[0]
- self.assertAllClose(out1, out2.transpose(1, 2, 0))
-
-
-if __name__ == '__main__':
- test.main()
diff --git a/tensorflow/python/keras/applications/inception_resnet_v2_test.py b/tensorflow/python/keras/applications/inception_resnet_v2_test.py
deleted file mode 100644
index 0a12f88505..0000000000
--- a/tensorflow/python/keras/applications/inception_resnet_v2_test.py
+++ /dev/null
@@ -1,59 +0,0 @@
-# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for Inception V3 application."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-
-from tensorflow.python import keras
-from tensorflow.python.platform import test
-
-
-class InceptionResNetV2Test(test.TestCase):
-
- def test_with_top(self):
- model = keras.applications.InceptionResNetV2(weights=None)
- self.assertEqual(model.output_shape, (None, 1000))
-
- def test_no_top(self):
- model = keras.applications.InceptionResNetV2(weights=None,
- include_top=False)
- self.assertEqual(model.output_shape, (None, None, None, 1536))
-
- def test_with_pooling(self):
- model = keras.applications.InceptionResNetV2(weights=None,
- include_top=False,
- pooling='avg')
- self.assertEqual(model.output_shape, (None, 1536))
-
- def test_weight_loading(self):
- with self.assertRaises(ValueError):
- keras.applications.InceptionResNetV2(weights='unknown',
- include_top=False)
- with self.assertRaises(ValueError):
- keras.applications.InceptionResNetV2(weights='imagenet',
- classes=2000)
-
- def test_preprocess_input(self):
- x = np.random.uniform(0, 255, (2, 300, 200, 3))
- out1 = keras.applications.inception_resnet_v2.preprocess_input(x)
- self.assertAllClose(np.mean(out1), 0., atol=0.1)
-
-
-if __name__ == '__main__':
- test.main()
diff --git a/tensorflow/python/keras/applications/inception_v3_test.py b/tensorflow/python/keras/applications/inception_v3_test.py
deleted file mode 100644
index a3fcdd5564..0000000000
--- a/tensorflow/python/keras/applications/inception_v3_test.py
+++ /dev/null
@@ -1,58 +0,0 @@
-# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for Inception V3 application."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-
-from tensorflow.python import keras
-from tensorflow.python.platform import test
-
-
-class InceptionV3Test(test.TestCase):
-
- def test_with_top(self):
- model = keras.applications.InceptionV3(weights=None)
- self.assertEqual(model.output_shape, (None, 1000))
-
- def test_no_top(self):
- model = keras.applications.InceptionV3(weights=None, include_top=False)
- self.assertEqual(model.output_shape, (None, None, None, 2048))
-
- def test_with_pooling(self):
- model = keras.applications.InceptionV3(weights=None,
- include_top=False,
- pooling='avg')
- self.assertEqual(model.output_shape, (None, 2048))
-
- def test_weight_loading(self):
- with self.assertRaises(ValueError):
- keras.applications.InceptionV3(weights='unknown',
- include_top=False)
- with self.assertRaises(ValueError):
- keras.applications.InceptionV3(weights='imagenet',
- classes=2000)
-
- def test_preprocess_input(self):
- x = np.random.uniform(0, 255, (2, 300, 200, 3))
- out1 = keras.applications.inception_v3.preprocess_input(x)
- self.assertAllClose(np.mean(out1), 0., atol=0.1)
-
-
-if __name__ == '__main__':
- test.main()
diff --git a/tensorflow/python/keras/applications/mobilenet_test.py b/tensorflow/python/keras/applications/mobilenet_test.py
deleted file mode 100644
index 65e4991ded..0000000000
--- a/tensorflow/python/keras/applications/mobilenet_test.py
+++ /dev/null
@@ -1,71 +0,0 @@
-# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for MobileNet application."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-
-from tensorflow.python import keras
-from tensorflow.python.platform import test
-
-
-class MobileNetTest(test.TestCase):
-
- def test_with_top(self):
- model = keras.applications.MobileNet(weights=None)
- self.assertEqual(model.output_shape, (None, 1000))
-
- def test_no_top(self):
- model = keras.applications.MobileNet(weights=None, include_top=False)
- self.assertEqual(model.output_shape, (None, None, None, 1024))
-
- def test_with_pooling(self):
- model = keras.applications.MobileNet(weights=None,
- include_top=False,
- pooling='avg')
- self.assertEqual(model.output_shape, (None, 1024))
-
- def test_weight_loading(self):
- with self.assertRaises(ValueError):
- keras.applications.MobileNet(weights='unknown',
- include_top=False)
- with self.assertRaises(ValueError):
- keras.applications.MobileNet(weights='imagenet',
- classes=2000)
-
- def test_preprocess_input(self):
- x = np.random.uniform(0, 255, (2, 300, 200, 3))
- out1 = keras.applications.mobilenet.preprocess_input(x)
- self.assertAllClose(np.mean(out1), 0., atol=0.1)
-
- def test_mobilenet_variable_input_channels(self):
- input_shape = (None, None, 1)
- model = keras.applications.MobileNet(weights=None,
- include_top=False,
- input_shape=input_shape)
- self.assertEqual(model.output_shape, (None, None, None, 1024))
-
- input_shape = (None, None, 4)
- model = keras.applications.MobileNet(weights=None,
- include_top=False,
- input_shape=input_shape)
- self.assertEqual(model.output_shape, (None, None, None, 1024))
-
-
-if __name__ == '__main__':
- test.main()
diff --git a/tensorflow/python/keras/applications/mobilenet_v2.py b/tensorflow/python/keras/applications/mobilenet_v2.py
index 74b8b029f8..9194c3ee14 100644
--- a/tensorflow/python/keras/applications/mobilenet_v2.py
+++ b/tensorflow/python/keras/applications/mobilenet_v2.py
@@ -19,14 +19,4 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from keras_applications import mobilenet_v2
-
-from tensorflow.python.util.tf_export import tf_export
-
-MobileNetV2 = mobilenet_v2.MobileNetV2
-decode_predictions = mobilenet_v2.decode_predictions
-preprocess_input = mobilenet_v2.preprocess_input
-
-tf_export('keras.applications.mobilenet_v2.MobileNetV2',
- 'keras.applications.MobileNetV2')(MobileNetV2)
-tf_export('keras.applications.mobilenet_v2.preprocess_input')(preprocess_input)
+# TODO(fchollet): export MobileNetV2 as part of the public API in next version.
diff --git a/tensorflow/python/keras/applications/nasnet_test.py b/tensorflow/python/keras/applications/nasnet_test.py
deleted file mode 100644
index f96c3aa51c..0000000000
--- a/tensorflow/python/keras/applications/nasnet_test.py
+++ /dev/null
@@ -1,76 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for Nasnet application."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.python import keras
-from tensorflow.python.platform import test
-
-
-class NASNetMobileTest(test.TestCase):
-
- def test_with_top(self):
- model = keras.applications.NASNetMobile(weights=None)
- self.assertEqual(model.output_shape, (None, 1000))
-
- def test_no_top(self):
- model = keras.applications.NASNetMobile(weights=None, include_top=False)
- self.assertEqual(model.output_shape, (None, None, None, 1056))
-
- def test_with_pooling(self):
- model = keras.applications.NASNetMobile(weights=None,
- include_top=False,
- pooling='avg')
- self.assertEqual(model.output_shape, (None, 1056))
-
- def test_weight_loading(self):
- with self.assertRaises(ValueError):
- keras.applications.NASNetMobile(weights='unknown',
- include_top=False)
- with self.assertRaises(ValueError):
- keras.applications.NASNetMobile(weights='imagenet',
- classes=2000)
-
-
-class NASNetLargeTest(test.TestCase):
-
- def test_with_top(self):
- model = keras.applications.NASNetLarge(weights=None)
- self.assertEqual(model.output_shape, (None, 1000))
-
- def test_no_top(self):
- model = keras.applications.NASNetLarge(weights=None, include_top=False)
- self.assertEqual(model.output_shape, (None, None, None, 4032))
-
- def test_with_pooling(self):
- model = keras.applications.NASNetLarge(weights=None,
- include_top=False,
- pooling='avg')
- self.assertEqual(model.output_shape, (None, 4032))
-
- def test_weight_loading(self):
- with self.assertRaises(ValueError):
- keras.applications.NASNetLarge(weights='unknown',
- include_top=False)
- with self.assertRaises(ValueError):
- keras.applications.NASNetLarge(weights='imagenet',
- classes=2000)
-
-
-if __name__ == '__main__':
- test.main()
diff --git a/tensorflow/python/keras/applications/resnet50_test.py b/tensorflow/python/keras/applications/resnet50_test.py
deleted file mode 100644
index 22a3f05580..0000000000
--- a/tensorflow/python/keras/applications/resnet50_test.py
+++ /dev/null
@@ -1,51 +0,0 @@
-# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for ResNet50 application."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.python import keras
-from tensorflow.python.platform import test
-
-
-class ResNet50Test(test.TestCase):
-
- def test_with_top(self):
- model = keras.applications.ResNet50(weights=None)
- self.assertEqual(model.output_shape, (None, 1000))
-
- def test_no_top(self):
- model = keras.applications.ResNet50(weights=None, include_top=False)
- self.assertEqual(model.output_shape, (None, None, None, 2048))
-
- def test_with_pooling(self):
- model = keras.applications.ResNet50(weights=None,
- include_top=False,
- pooling='avg')
- self.assertEqual(model.output_shape, (None, 2048))
-
- def test_weight_loading(self):
- with self.assertRaises(ValueError):
- keras.applications.ResNet50(weights='unknown',
- include_top=False)
-
- with self.assertRaises(ValueError):
- keras.applications.ResNet50(weights='imagenet',
- classes=2000)
-
-if __name__ == '__main__':
- test.main()
diff --git a/tensorflow/python/keras/applications/vgg16_test.py b/tensorflow/python/keras/applications/vgg16_test.py
deleted file mode 100644
index cad65765f3..0000000000
--- a/tensorflow/python/keras/applications/vgg16_test.py
+++ /dev/null
@@ -1,50 +0,0 @@
-# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for VGG16 application."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.python import keras
-from tensorflow.python.platform import test
-
-
-class VGG16Test(test.TestCase):
-
- def test_with_top(self):
- model = keras.applications.VGG16(weights=None)
- self.assertEqual(model.output_shape, (None, 1000))
-
- def test_no_top(self):
- model = keras.applications.VGG16(weights=None, include_top=False)
- self.assertEqual(model.output_shape, (None, None, None, 512))
-
- def test_with_pooling(self):
- model = keras.applications.VGG16(weights=None,
- include_top=False,
- pooling='avg')
- self.assertEqual(model.output_shape, (None, 512))
-
- def test_weight_loading(self):
- with self.assertRaises(ValueError):
- keras.applications.VGG16(weights='unknown',
- include_top=False)
- with self.assertRaises(ValueError):
- keras.applications.VGG16(weights='imagenet',
- classes=2000)
-
-if __name__ == '__main__':
- test.main()
diff --git a/tensorflow/python/keras/applications/vgg19_test.py b/tensorflow/python/keras/applications/vgg19_test.py
deleted file mode 100644
index 61dccc0c5c..0000000000
--- a/tensorflow/python/keras/applications/vgg19_test.py
+++ /dev/null
@@ -1,50 +0,0 @@
-# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for VGG19 application."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.python import keras
-from tensorflow.python.platform import test
-
-
-class VGG19Test(test.TestCase):
-
- def test_with_top(self):
- model = keras.applications.VGG19(weights=None)
- self.assertEqual(model.output_shape, (None, 1000))
-
- def test_no_top(self):
- model = keras.applications.VGG19(weights=None, include_top=False)
- self.assertEqual(model.output_shape, (None, None, None, 512))
-
- def test_with_pooling(self):
- model = keras.applications.VGG19(weights=None,
- include_top=False,
- pooling='avg')
- self.assertEqual(model.output_shape, (None, 512))
-
- def test_weight_loading(self):
- with self.assertRaises(ValueError):
- keras.applications.VGG19(weights='unknown',
- include_top=False)
- with self.assertRaises(ValueError):
- keras.applications.VGG19(weights='imagenet',
- classes=2000)
-
-if __name__ == '__main__':
- test.main()
diff --git a/tensorflow/python/keras/applications/xception_test.py b/tensorflow/python/keras/applications/xception_test.py
deleted file mode 100644
index 7e2efd0017..0000000000
--- a/tensorflow/python/keras/applications/xception_test.py
+++ /dev/null
@@ -1,57 +0,0 @@
-# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for Xception application."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-
-from tensorflow.python import keras
-from tensorflow.python.platform import test
-
-
-class XceptionTest(test.TestCase):
-
- def test_with_top(self):
- model = keras.applications.Xception(weights=None)
- self.assertEqual(model.output_shape, (None, 1000))
-
- def test_no_top(self):
- model = keras.applications.Xception(weights=None, include_top=False)
- self.assertEqual(model.output_shape, (None, None, None, 2048))
-
- def test_with_pooling(self):
- model = keras.applications.Xception(weights=None,
- include_top=False,
- pooling='avg')
- self.assertEqual(model.output_shape, (None, 2048))
-
- def test_weight_loading(self):
- with self.assertRaises(ValueError):
- keras.applications.Xception(weights='unknown',
- include_top=False)
- with self.assertRaises(ValueError):
- keras.applications.Xception(weights='imagenet',
- classes=2000)
-
- def test_preprocess_input(self):
- x = np.random.uniform(0, 255, (2, 300, 200, 3))
- out1 = keras.applications.xception.preprocess_input(x)
- self.assertAllClose(np.mean(out1), 0., atol=0.1)
-
-if __name__ == '__main__':
- test.main()
diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py
index f2feeb85a1..befe82f4ec 100644
--- a/tensorflow/python/keras/callbacks.py
+++ b/tensorflow/python/keras/callbacks.py
@@ -833,7 +833,7 @@ class TensorBoard(Callback):
Raises:
ValueError: If histogram_freq is set and no validation data is provided.
- @compatbility(eager)
+ @compatibility(eager)
Using `Tensorboard` callback will work while eager execution is enabled,
however outputting histogram summaries of weights and gradients is not
supported, and thus `histogram_freq` will be ignored.
diff --git a/tensorflow/python/keras/integration_test.py b/tensorflow/python/keras/integration_test.py
index 2a05699407..a103b9fbf2 100644
--- a/tensorflow/python/keras/integration_test.py
+++ b/tensorflow/python/keras/integration_test.py
@@ -21,9 +21,11 @@ from __future__ import print_function
import numpy as np
from tensorflow.python import keras
+from tensorflow.python.framework import dtypes
from tensorflow.python.keras import testing_utils
from tensorflow.python.layers import core as tf_core_layers
from tensorflow.python.ops import nn
+from tensorflow.python.ops import rnn_cell
from tensorflow.python.platform import test
@@ -103,6 +105,30 @@ class KerasIntegrationTest(test.TestCase):
verbose=2)
self.assertGreater(history.history['val_acc'][-1], 0.7)
+ def test_temporal_classification_sequential_tf_rnn(self):
+ with self.test_session():
+ np.random.seed(1337)
+ (x_train, y_train), _ = testing_utils.get_test_data(
+ train_samples=100,
+ test_samples=0,
+ input_shape=(4, 10),
+ num_classes=2)
+ y_train = keras.utils.to_categorical(y_train)
+
+ model = keras.models.Sequential()
+ model.add(keras.layers.RNN(rnn_cell.LSTMCell(5), return_sequences=True,
+ input_shape=x_train.shape[1:]))
+ model.add(keras.layers.RNN(rnn_cell.GRUCell(y_train.shape[-1],
+ activation='softmax',
+ dtype=dtypes.float32)))
+ model.compile(loss='categorical_crossentropy',
+ optimizer=keras.optimizers.Adam(lr=0.1),
+ metrics=['accuracy'])
+ history = model.fit(x_train, y_train, epochs=15, batch_size=16,
+ validation_data=(x_train, y_train),
+ verbose=2)
+ self.assertGreater(history.history['val_acc'][-1], 0.7)
+
def test_image_classification_sequential(self):
with self.test_session():
np.random.seed(1337)
diff --git a/tensorflow/python/keras/layers/local.py b/tensorflow/python/keras/layers/local.py
index 0ebafe07cc..33d09a1660 100644
--- a/tensorflow/python/keras/layers/local.py
+++ b/tensorflow/python/keras/layers/local.py
@@ -85,6 +85,28 @@ class LocallyConnected1D(Layer):
the output of the layer (its "activation").
kernel_constraint: Constraint function applied to the kernel matrix.
bias_constraint: Constraint function applied to the bias vector.
+ implementation: implementation mode, either `1` or `2`.
+ `1` loops over input spatial locations to perform the forward pass.
+ It is memory-efficient but performs a lot of (small) ops.
+
+ `2` stores layer weights in a dense but sparsely-populated 2D matrix
+ and implements the forward pass as a single matrix-multiply. It uses
+ a lot of RAM but performs few (large) ops.
+
+ Depending on the inputs, layer parameters, hardware, and
+ `tf.executing_eagerly()` one implementation can be dramatically faster
+ (e.g. 50X) than another.
+
+ It is recommended to benchmark both in the setting of interest to pick
+ the most efficient one (in terms of speed and memory usage).
+
+ The following scenarios could benefit from setting `implementation=2`:
+ - eager execution;
+ - inference;
+ - running on CPU;
+ - large amount of RAM available;
+ - small models (few filters, small kernel);
+ - using `padding="same"` (only possible with `implementation=2`).
Input shape:
3D tensor with shape: `(batch_size, steps, input_dim)`
@@ -109,15 +131,17 @@ class LocallyConnected1D(Layer):
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
+ implementation=1,
**kwargs):
super(LocallyConnected1D, self).__init__(**kwargs)
self.filters = filters
self.kernel_size = conv_utils.normalize_tuple(kernel_size, 1, 'kernel_size')
self.strides = conv_utils.normalize_tuple(strides, 1, 'strides')
self.padding = conv_utils.normalize_padding(padding)
- if self.padding != 'valid':
+ if self.padding != 'valid' and implementation == 1:
raise ValueError('Invalid border mode for LocallyConnected1D '
- '(only "valid" is supported): ' + padding)
+ '(only "valid" is supported if implementation is 1): '
+ + padding)
self.data_format = conv_utils.normalize_data_format(data_format)
self.activation = activations.get(activation)
self.use_bias = use_bias
@@ -128,6 +152,7 @@ class LocallyConnected1D(Layer):
self.activity_regularizer = regularizers.get(activity_regularizer)
self.kernel_constraint = constraints.get(kernel_constraint)
self.bias_constraint = constraints.get(bias_constraint)
+ self.implementation = implementation
self.input_spec = InputSpec(ndim=3)
@tf_utils.shape_type_conversion
@@ -142,14 +167,45 @@ class LocallyConnected1D(Layer):
'Found shape:', input_shape)
self.output_length = conv_utils.conv_output_length(
input_length, self.kernel_size[0], self.padding, self.strides[0])
- self.kernel_shape = (self.output_length, self.kernel_size[0] * input_dim,
- self.filters)
- self.kernel = self.add_weight(
- shape=self.kernel_shape,
- initializer=self.kernel_initializer,
- name='kernel',
- regularizer=self.kernel_regularizer,
- constraint=self.kernel_constraint)
+
+ if self.implementation == 1:
+ self.kernel_shape = (self.output_length, self.kernel_size[0] * input_dim,
+ self.filters)
+
+ self.kernel = self.add_weight(
+ shape=self.kernel_shape,
+ initializer=self.kernel_initializer,
+ name='kernel',
+ regularizer=self.kernel_regularizer,
+ constraint=self.kernel_constraint)
+
+ elif self.implementation == 2:
+ if self.data_format == 'channels_first':
+ self.kernel_shape = (input_dim, input_length,
+ self.filters, self.output_length)
+ else:
+ self.kernel_shape = (input_length, input_dim,
+ self.output_length, self.filters)
+
+ self.kernel = self.add_weight(shape=self.kernel_shape,
+ initializer=self.kernel_initializer,
+ name='kernel',
+ regularizer=self.kernel_regularizer,
+ constraint=self.kernel_constraint)
+
+ self.kernel_mask = get_locallyconnected_mask(
+ input_shape=(input_length,),
+ kernel_shape=self.kernel_size,
+ strides=self.strides,
+ padding=self.padding,
+ data_format=self.data_format,
+ dtype=self.kernel.dtype
+ )
+
+ else:
+ raise ValueError('Unrecognized implementation mode: %d.'
+ % self.implementation)
+
if self.use_bias:
self.bias = self.add_weight(
shape=(self.output_length, self.filters),
@@ -182,8 +238,17 @@ class LocallyConnected1D(Layer):
return (input_shape[0], length, self.filters)
def call(self, inputs):
- output = K.local_conv(inputs, self.kernel, self.kernel_size, self.strides,
- (self.output_length,), self.data_format)
+ if self.implementation == 1:
+ output = K.local_conv(inputs, self.kernel, self.kernel_size, self.strides,
+ (self.output_length,), self.data_format)
+
+ elif self.implementation == 2:
+ output = local_conv_matmul(inputs, self.kernel, self.kernel_mask,
+ self.compute_output_shape(inputs.shape))
+
+ else:
+ raise ValueError('Unrecognized implementation mode: %d.'
+ % self.implementation)
if self.use_bias:
output = K.bias_add(output, self.bias, data_format=self.data_format)
@@ -220,7 +285,9 @@ class LocallyConnected1D(Layer):
'kernel_constraint':
constraints.serialize(self.kernel_constraint),
'bias_constraint':
- constraints.serialize(self.bias_constraint)
+ constraints.serialize(self.bias_constraint),
+ 'implementation':
+ self.implementation
}
base_config = super(LocallyConnected1D, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
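As a worked illustration of the `implementation=2` scheme documented above (a dense weight element-wise multiplied by a connectivity mask, then a single matmul), here is a NumPy sketch for a simplified single-channel 1D case with 'valid' padding; shapes are hypothetical and flatter than the layer's real kernels:

    # Mask a dense weight so output j only connects to its k-wide window,
    # then one matmul implements the whole locally-connected layer.
    import numpy as np

    in_len, k, stride, filters = 6, 3, 1, 2
    out_len = (in_len - k) // stride + 1            # 'valid' output length

    w = np.random.randn(in_len, out_len * filters)  # fully-connected weight

    mask = np.zeros((in_len, out_len))
    for j in range(out_len):
        mask[j * stride:j * stride + k, j] = 1.0    # k-wide receptive field
    mask = np.repeat(mask, filters, axis=1)         # same mask per filter

    x = np.random.randn(4, in_len)                  # batch of 4 inputs
    y = (x @ (w * mask)).reshape(4, out_len, filters)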
@@ -284,9 +351,31 @@ class LocallyConnected2D(Layer):
the `kernel` weights matrix.
bias_regularizer: Regularizer function applied to the bias vector.
activity_regularizer: Regularizer function applied to
- the output of the layer (its "activation")..
+ the output of the layer (its "activation").
kernel_constraint: Constraint function applied to the kernel matrix.
bias_constraint: Constraint function applied to the bias vector.
+ implementation: implementation mode, either `1` or `2`.
+ `1` loops over input spatial locations to perform the forward pass.
+ It is memory-efficient but performs a lot of (small) ops.
+
+ `2` stores layer weights in a dense but sparsely-populated 2D matrix
+ and implements the forward pass as a single matrix-multiply. It uses
+ a lot of RAM but performs few (large) ops.
+
+ Depending on the inputs, layer parameters, hardware, and
+ `tf.executing_eagerly()`, one implementation can be dramatically faster
+ (e.g. 50X) than another.
+
+ It is recommended to benchmark both in the setting of interest to pick
+ the most efficient one (in terms of speed and memory usage).
+
+ The following scenarios could benefit from setting `implementation=2`:
+ - eager execution;
+ - inference;
+ - running on CPU;
+ - large amount of RAM available;
+ - small models (few filters, small kernel);
+ - using `padding="same"` (only possible with `implementation=2`).
Input shape:
4D tensor with shape:
@@ -317,15 +406,17 @@ class LocallyConnected2D(Layer):
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
+ implementation=1,
**kwargs):
super(LocallyConnected2D, self).__init__(**kwargs)
self.filters = filters
self.kernel_size = conv_utils.normalize_tuple(kernel_size, 2, 'kernel_size')
self.strides = conv_utils.normalize_tuple(strides, 2, 'strides')
self.padding = conv_utils.normalize_padding(padding)
- if self.padding != 'valid':
+ if self.padding != 'valid' and implementation == 1:
raise ValueError('Invalid border mode for LocallyConnected2D '
- '(only "valid" is supported): ' + padding)
+ '(only "valid" is supported if implementation is 1): '
+ + padding)
self.data_format = conv_utils.normalize_data_format(data_format)
self.activation = activations.get(activation)
self.use_bias = use_bias
@@ -336,6 +427,7 @@ class LocallyConnected2D(Layer):
self.activity_regularizer = regularizers.get(activity_regularizer)
self.kernel_constraint = constraints.get(kernel_constraint)
self.bias_constraint = constraints.get(bias_constraint)
+ self.implementation = implementation
self.input_spec = InputSpec(ndim=4)
@tf_utils.shape_type_conversion
@@ -357,15 +449,47 @@ class LocallyConnected2D(Layer):
self.padding, self.strides[1])
self.output_row = output_row
self.output_col = output_col
- self.kernel_shape = (
- output_row * output_col,
- self.kernel_size[0] * self.kernel_size[1] * input_filter, self.filters)
- self.kernel = self.add_weight(
- shape=self.kernel_shape,
- initializer=self.kernel_initializer,
- name='kernel',
- regularizer=self.kernel_regularizer,
- constraint=self.kernel_constraint)
+
+ if self.implementation == 1:
+ self.kernel_shape = (
+ output_row * output_col,
+ self.kernel_size[0] * self.kernel_size[1] * input_filter,
+ self.filters)
+
+ self.kernel = self.add_weight(
+ shape=self.kernel_shape,
+ initializer=self.kernel_initializer,
+ name='kernel',
+ regularizer=self.kernel_regularizer,
+ constraint=self.kernel_constraint)
+
+ elif self.implementation == 2:
+ if self.data_format == 'channels_first':
+ self.kernel_shape = (input_filter, input_row, input_col,
+ self.filters, self.output_row, self.output_col)
+ else:
+ self.kernel_shape = (input_row, input_col, input_filter,
+ self.output_row, self.output_col, self.filters)
+
+ self.kernel = self.add_weight(shape=self.kernel_shape,
+ initializer=self.kernel_initializer,
+ name='kernel',
+ regularizer=self.kernel_regularizer,
+ constraint=self.kernel_constraint)
+
+ self.kernel_mask = get_locallyconnected_mask(
+ input_shape=(input_row, input_col),
+ kernel_shape=self.kernel_size,
+ strides=self.strides,
+ padding=self.padding,
+ data_format=self.data_format,
+ dtype=self.kernel.dtype
+ )
+
+ else:
+ raise ValueError('Unrecognized implementation mode: %d.'
+ % self.implementation)
+
if self.use_bias:
self.bias = self.add_weight(
shape=(output_row, output_col, self.filters),
@@ -401,8 +525,18 @@ class LocallyConnected2D(Layer):
return (input_shape[0], rows, cols, self.filters)
def call(self, inputs):
- output = K.local_conv(inputs, self.kernel, self.kernel_size, self.strides,
- (self.output_row, self.output_col), self.data_format)
+ if self.implementation == 1:
+ output = K.local_conv(inputs, self.kernel, self.kernel_size, self.strides,
+ (self.output_row, self.output_col),
+ self.data_format)
+
+ elif self.implementation == 2:
+ output = local_conv_matmul(inputs, self.kernel, self.kernel_mask,
+ self.compute_output_shape(inputs.shape))
+
+ else:
+ raise ValueError('Unrecognized implementation mode: %d.'
+ % self.implementation)
if self.use_bias:
output = K.bias_add(output, self.bias, data_format=self.data_format)
@@ -439,7 +573,157 @@ class LocallyConnected2D(Layer):
'kernel_constraint':
constraints.serialize(self.kernel_constraint),
'bias_constraint':
- constraints.serialize(self.bias_constraint)
+ constraints.serialize(self.bias_constraint),
+ 'implementation':
+ self.implementation
}
base_config = super(LocallyConnected2D, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
+
+
+def get_locallyconnected_mask(input_shape,
+ kernel_shape,
+ strides,
+ padding,
+ data_format,
+ dtype):
+ """Return a mask representing connectivity of a locally-connected operation.
+
+ This method returns a masking tensor of 0s and 1s (of type `dtype`) that,
+ when element-wise multiplied with a fully-connected weight tensor, masks out
+ the weights between disconnected input-output pairs and thus implements local
+ connectivity through a sparse fully-connected weight tensor.
+
+ Assume an unshared convolution with given parameters is applied to an input
+ having N spatial dimensions with `input_shape = (d_in1, ..., d_inN)`
+ to produce an output with spatial shape `(d_out1, ..., d_outN)` (determined
+ by layer parameters such as `strides`).
+
+  This method returns a mask which can be broadcast-multiplied (element-wise)
+  with a 2*(N+1)-D weight matrix (equivalent to a fully-connected layer
+  between (N+1)-D activations, with N spatial + 1 channel dimensions for both
+  input and output) to make it perform an unshared convolution with the given
+  `kernel_shape`, `strides`, `padding` and `data_format`.
+
+ Arguments:
+ input_shape: tuple of size N: `(d_in1, ..., d_inN)`
+ spatial shape of the input.
+ kernel_shape: tuple of size N, spatial shape of the convolutional kernel
+ / receptive field.
+ strides: tuple of size N, strides along each spatial dimension.
+ padding: type of padding, string `"same"` or `"valid"`.
+ data_format: a string, `"channels_first"` or `"channels_last"`.
+ dtype: type of the layer operation, e.g. `tf.float64`.
+
+ Returns:
+ a `dtype`-tensor of shape
+ `(1, d_in1, ..., d_inN, 1, d_out1, ..., d_outN)`
+      if `data_format == "channels_first"`, or
+ `(d_in1, ..., d_inN, 1, d_out1, ..., d_outN, 1)`
+ if `data_format == "channels_last"`.
+
+ Raises:
+ ValueError: if `data_format` is neither `"channels_first"` nor
+ `"channels_last"`.
+ """
+ mask = conv_utils.conv_kernel_mask(
+ input_shape=input_shape,
+ kernel_shape=kernel_shape,
+ strides=strides,
+ padding=padding
+ )
+
+  ndims = mask.ndim // 2
+ mask = K.variable(mask, dtype)
+
+ if data_format == 'channels_first':
+ mask = K.expand_dims(mask, 0)
+    mask = K.expand_dims(mask, -ndims - 1)
+
+ elif data_format == 'channels_last':
+ mask = K.expand_dims(mask, ndims)
+ mask = K.expand_dims(mask, -1)
+
+ else:
+ raise ValueError('Unrecognized data_format: ' + str(data_format))
+
+ return mask
+
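To illustrate the broadcasting performed above, a NumPy sketch for a 1-D `channels_last` case (assumed shapes: input length 4, kernel 2, stride 1, `valid` padding, so the connectivity mask is `(4, 3)`):

```python
import numpy as np

# (d_in, d_out) connectivity of a length-2 kernel sliding over 4 inputs.
mask = np.array([[1, 0, 0],
                 [1, 1, 0],
                 [0, 1, 1],
                 [0, 0, 1]], dtype=np.float32)

# channels_last: insert singleton channel axes -> (d_in, 1, d_out, 1).
mask = mask[:, np.newaxis, :, np.newaxis]

# Broadcasts against a dense kernel of shape (d_in, c_in, d_out, c_out),
# zeroing every weight between a disconnected input-output pair.
kernel = np.random.normal(size=(4, 2, 3, 5))
masked_kernel = mask * kernel
```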
+
+def local_conv_matmul(inputs, kernel, kernel_mask, output_shape):
+ """Apply N-D convolution with un-shared weights using a single matmul call.
+
+  This method outputs `inputs . (kernel * kernel_mask)`
+  (with `.` standing for matrix-multiply and `*` for element-wise multiply)
+  and requires a precomputed `kernel_mask` that zeros out the weights in
+  `kernel` between disconnected input-output pairs, so that the product
+  performs the same operation as a convolution with un-shared weights (the
+  remaining entries in `kernel`). It also does the necessary reshapes to make
+  `inputs` and `kernel` 2-D and `output` (N+2)-D.
+
+ Arguments:
+ inputs: (N+2)-D tensor with shape
+ `(batch_size, channels_in, d_in1, ..., d_inN)`
+ or
+ `(batch_size, d_in1, ..., d_inN, channels_in)`.
+    kernel: the unshared weights for N-D convolution,
+        a 2*(N+1)-D tensor of shape:
+        `(d_in1, ..., d_inN, channels_in, d_out1, ..., d_outN, channels_out)`
+        or
+        `(channels_in, d_in1, ..., d_inN, channels_out, d_out1, ..., d_outN)`,
+ with the ordering of channels and spatial dimensions matching
+ that of the input.
+ Each entry is the weight between a particular input and
+ output location, similarly to a fully-connected weight matrix.
+    kernel_mask: a float 0/1 mask tensor of shape:
+        `(d_in1, ..., d_inN, 1, d_out1, ..., d_outN, 1)`
+        or
+        `(1, d_in1, ..., d_inN, 1, d_out1, ..., d_outN)`,
+ with the ordering of singleton and spatial dimensions
+ matching that of the input.
+ Mask represents the connectivity pattern of the layer and is
+ precomputed elsewhere based on layer parameters: stride,
+ padding, and the receptive field shape.
+ output_shape: a tuple of (N+2) elements representing the output shape:
+ `(batch_size, channels_out, d_out1, ..., d_outN)`
+ or
+ `(batch_size, d_out1, ..., d_outN, channels_out)`,
+ with the ordering of channels and spatial dimensions matching that of
+ the input.
+
+ Returns:
+ Output (N+2)-D tensor with shape `output_shape`.
+ """
+ inputs_flat = K.reshape(inputs, (K.shape(inputs)[0], -1))
+
+ kernel = kernel_mask * kernel
+ kernel = make_2d(kernel, split_dim=K.ndim(kernel) // 2)
+
+ output_flat = K.math_ops.sparse_matmul(inputs_flat, kernel, b_is_sparse=True)
+ output = K.reshape(output_flat,
+ [K.shape(output_flat)[0],] + output_shape.as_list()[1:])
+ return output
+
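The identity behind `local_conv_matmul`, restated as a NumPy sketch under assumed 1-D shapes: once the masked kernel is reshaped to 2-D (split at `ndim // 2`, as `make_2d` below does), the unshared convolution becomes a single matrix multiply:

```python
import numpy as np

batch, d_in, c_in, d_out, c_out = 2, 4, 3, 3, 5
inputs = np.random.normal(size=(batch, d_in, c_in))
kernel = np.random.normal(size=(d_in, c_in, d_out, c_out))
mask = np.random.randint(0, 2, (d_in, 1, d_out, 1)).astype(kernel.dtype)

inputs_flat = inputs.reshape(batch, -1)               # (batch, d_in * c_in)
kernel_2d = (mask * kernel).reshape(d_in * c_in, -1)  # (d_in*c_in, d_out*c_out)
output = inputs_flat.dot(kernel_2d).reshape(batch, d_out, c_out)
```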
+
+def make_2d(tensor, split_dim):
+ """Reshapes an N-dimensional tensor into a 2D tensor.
+
+ Dimensions before (excluding) and after (including) `split_dim` are grouped
+ together.
+
+ Arguments:
+ tensor: a tensor of shape `(d0, ..., d(N-1))`.
+ split_dim: an integer from 1 to N-1, index of the dimension to group
+ dimensions before (excluding) and after (including).
+
+ Returns:
+ Tensor of shape
+ `(d0 * ... * d(split_dim-1), d(split_dim) * ... * d(N-1))`.
+ """
+ shape = K.array_ops.shape(tensor)
+ in_dims = shape[:split_dim]
+ out_dims = shape[split_dim:]
+
+ in_size = K.math_ops.reduce_prod(in_dims)
+ out_size = K.math_ops.reduce_prod(out_dims)
+
+ return K.array_ops.reshape(tensor, (in_size, out_size))
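For example (a sketch of the same reshape in NumPy), `make_2d` on a `(3, 4, 5)` tensor with `split_dim=1` groups the trailing dimensions together:

```python
import numpy as np

t = np.arange(60).reshape(3, 4, 5)
t2d = t.reshape(3, 4 * 5)  # what make_2d(t, split_dim=1) computes: (3, 20)
```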
diff --git a/tensorflow/python/keras/layers/local_test.py b/tensorflow/python/keras/layers/local_test.py
index 9639e0251f..4781bcae07 100644
--- a/tensorflow/python/keras/layers/local_test.py
+++ b/tensorflow/python/keras/layers/local_test.py
@@ -24,6 +24,7 @@ from tensorflow.python import keras
from tensorflow.python.framework import test_util as tf_test_util
from tensorflow.python.keras import testing_utils
from tensorflow.python.platform import test
+from tensorflow.python.training.rmsprop import RMSPropOptimizer
class LocallyConnectedLayersTest(test.TestCase):
@@ -36,21 +37,30 @@ class LocallyConnectedLayersTest(test.TestCase):
filter_length = 3
filters = 4
- for padding in ['valid']:
+ for padding in ['valid', 'same']:
for strides in [1]:
if padding == 'same' and strides != 1:
continue
for data_format in ['channels_first', 'channels_last']:
- testing_utils.layer_test(
- keras.layers.LocallyConnected1D,
- kwargs={
- 'filters': filters,
- 'kernel_size': filter_length,
- 'padding': padding,
- 'strides': strides,
- 'data_format': data_format
- },
- input_shape=(num_samples, num_steps, input_dim))
+ for implementation in [1, 2]:
+ kwargs = {
+ 'filters': filters,
+ 'kernel_size': filter_length,
+ 'padding': padding,
+ 'strides': strides,
+ 'data_format': data_format,
+ 'implementation': implementation
+ }
+
+ if padding == 'same' and implementation == 1:
+ self.assertRaises(ValueError,
+ keras.layers.LocallyConnected1D,
+ **kwargs)
+ else:
+ testing_utils.layer_test(
+ keras.layers.LocallyConnected1D,
+ kwargs=kwargs,
+ input_shape=(num_samples, num_steps, input_dim))
def test_locallyconnected_1d_regularization(self):
num_samples = 2
@@ -59,38 +69,47 @@ class LocallyConnectedLayersTest(test.TestCase):
filter_length = 3
filters = 4
for data_format in ['channels_first', 'channels_last']:
- kwargs = {
- 'filters': filters,
- 'kernel_size': filter_length,
- 'kernel_regularizer': 'l2',
- 'bias_regularizer': 'l2',
- 'activity_regularizer': 'l2',
- 'data_format': data_format
- }
-
- with self.test_session():
- layer = keras.layers.LocallyConnected1D(**kwargs)
- layer.build((num_samples, num_steps, input_dim))
- self.assertEqual(len(layer.losses), 2)
- layer(
- keras.backend.variable(np.ones((num_samples,
- num_steps,
- input_dim))))
- self.assertEqual(len(layer.losses), 3)
-
- k_constraint = keras.constraints.max_norm(0.01)
- b_constraint = keras.constraints.max_norm(0.01)
- kwargs = {
- 'filters': filters,
- 'kernel_size': filter_length,
- 'kernel_constraint': k_constraint,
- 'bias_constraint': b_constraint,
- }
- with self.test_session():
- layer = keras.layers.LocallyConnected1D(**kwargs)
- layer.build((num_samples, num_steps, input_dim))
- self.assertEqual(layer.kernel.constraint, k_constraint)
- self.assertEqual(layer.bias.constraint, b_constraint)
+ for padding in ['valid', 'same']:
+ for implementation in [1, 2]:
+ kwargs = {
+ 'filters': filters,
+ 'kernel_size': filter_length,
+ 'kernel_regularizer': 'l2',
+ 'bias_regularizer': 'l2',
+ 'activity_regularizer': 'l2',
+ 'data_format': data_format,
+ 'implementation': implementation,
+ 'padding': padding
+ }
+
+ if padding == 'same' and implementation == 1:
+ self.assertRaises(ValueError,
+ keras.layers.LocallyConnected1D,
+ **kwargs)
+ else:
+ with self.test_session():
+ layer = keras.layers.LocallyConnected1D(**kwargs)
+ layer.build((num_samples, num_steps, input_dim))
+ self.assertEqual(len(layer.losses), 2)
+ layer(
+ keras.backend.variable(np.ones((num_samples,
+ num_steps,
+ input_dim))))
+ self.assertEqual(len(layer.losses), 3)
+
+ k_constraint = keras.constraints.max_norm(0.01)
+ b_constraint = keras.constraints.max_norm(0.01)
+ kwargs = {
+ 'filters': filters,
+ 'kernel_size': filter_length,
+ 'kernel_constraint': k_constraint,
+ 'bias_constraint': b_constraint,
+ }
+ with self.test_session():
+ layer = keras.layers.LocallyConnected1D(**kwargs)
+ layer.build((num_samples, num_steps, input_dim))
+ self.assertEqual(layer.kernel.constraint, k_constraint)
+ self.assertEqual(layer.bias.constraint, b_constraint)
@tf_test_util.run_in_graph_and_eager_modes
def test_locallyconnected_2d(self):
@@ -100,23 +119,32 @@ class LocallyConnectedLayersTest(test.TestCase):
num_row = 6
num_col = 10
- for padding in ['valid']:
+ for padding in ['valid', 'same']:
for strides in [(1, 1), (2, 2)]:
- if padding == 'same' and strides != (1, 1):
- continue
+ for implementation in [1, 2]:
+ if padding == 'same' and strides != (1, 1):
+ continue
- testing_utils.layer_test(
- keras.layers.LocallyConnected2D,
- kwargs={
- 'filters': filters,
- 'kernel_size': 3,
- 'padding': padding,
- 'kernel_regularizer': 'l2',
- 'bias_regularizer': 'l2',
- 'strides': strides,
- 'data_format': 'channels_last'
- },
- input_shape=(num_samples, num_row, num_col, stack_size))
+ kwargs = {
+ 'filters': filters,
+ 'kernel_size': 3,
+ 'padding': padding,
+ 'kernel_regularizer': 'l2',
+ 'bias_regularizer': 'l2',
+ 'strides': strides,
+ 'data_format': 'channels_last',
+ 'implementation': implementation
+ }
+
+ if padding == 'same' and implementation == 1:
+ self.assertRaises(ValueError,
+ keras.layers.LocallyConnected2D,
+ **kwargs)
+ else:
+ testing_utils.layer_test(
+ keras.layers.LocallyConnected2D,
+ kwargs=kwargs,
+ input_shape=(num_samples, num_row, num_col, stack_size))
@tf_test_util.run_in_graph_and_eager_modes
def test_locallyconnected_2d_channels_first(self):
@@ -126,14 +154,25 @@ class LocallyConnectedLayersTest(test.TestCase):
num_row = 6
num_col = 10
- testing_utils.layer_test(
- keras.layers.LocallyConnected2D,
- kwargs={
+ for implementation in [1, 2]:
+ for padding in ['valid', 'same']:
+ kwargs = {
'filters': filters,
'kernel_size': 3,
- 'data_format': 'channels_first'
- },
- input_shape=(num_samples, num_row, num_col, stack_size))
+ 'data_format': 'channels_first',
+ 'implementation': implementation,
+ 'padding': padding
+ }
+
+ if padding == 'same' and implementation == 1:
+ self.assertRaises(ValueError,
+ keras.layers.LocallyConnected2D,
+ **kwargs)
+ else:
+ testing_utils.layer_test(
+ keras.layers.LocallyConnected2D,
+ kwargs=kwargs,
+ input_shape=(num_samples, num_row, num_col, stack_size))
def test_locallyconnected_2d_regularization(self):
num_samples = 8
@@ -141,35 +180,271 @@ class LocallyConnectedLayersTest(test.TestCase):
stack_size = 4
num_row = 6
num_col = 10
- kwargs = {
- 'filters': filters,
- 'kernel_size': 3,
- 'kernel_regularizer': 'l2',
- 'bias_regularizer': 'l2',
- 'activity_regularizer': 'l2',
- }
- with self.test_session():
- layer = keras.layers.LocallyConnected2D(**kwargs)
- layer.build((num_samples, num_row, num_col, stack_size))
- self.assertEqual(len(layer.losses), 2)
- layer(
- keras.backend.variable(
- np.ones((num_samples, num_row, num_col, stack_size))))
- self.assertEqual(len(layer.losses), 3)
-
- k_constraint = keras.constraints.max_norm(0.01)
- b_constraint = keras.constraints.max_norm(0.01)
- kwargs = {
- 'filters': filters,
- 'kernel_size': 3,
- 'kernel_constraint': k_constraint,
- 'bias_constraint': b_constraint,
- }
- with self.test_session():
- layer = keras.layers.LocallyConnected2D(**kwargs)
- layer.build((num_samples, num_row, num_col, stack_size))
- self.assertEqual(layer.kernel.constraint, k_constraint)
- self.assertEqual(layer.bias.constraint, b_constraint)
+ for implementation in [1, 2]:
+ for padding in ['valid', 'same']:
+ kwargs = {
+ 'filters': filters,
+ 'kernel_size': 3,
+ 'kernel_regularizer': 'l2',
+ 'bias_regularizer': 'l2',
+ 'activity_regularizer': 'l2',
+ 'implementation': implementation,
+ 'padding': padding
+ }
+
+ if padding == 'same' and implementation == 1:
+ self.assertRaises(ValueError,
+ keras.layers.LocallyConnected2D,
+ **kwargs)
+ else:
+ with self.test_session():
+ layer = keras.layers.LocallyConnected2D(**kwargs)
+ layer.build((num_samples, num_row, num_col, stack_size))
+ self.assertEqual(len(layer.losses), 2)
+ layer(
+ keras.backend.variable(
+ np.ones((num_samples, num_row, num_col, stack_size))))
+ self.assertEqual(len(layer.losses), 3)
+
+ k_constraint = keras.constraints.max_norm(0.01)
+ b_constraint = keras.constraints.max_norm(0.01)
+ kwargs = {
+ 'filters': filters,
+ 'kernel_size': 3,
+ 'kernel_constraint': k_constraint,
+ 'bias_constraint': b_constraint,
+ }
+ with self.test_session():
+ layer = keras.layers.LocallyConnected2D(**kwargs)
+ layer.build((num_samples, num_row, num_col, stack_size))
+ self.assertEqual(layer.kernel.constraint, k_constraint)
+ self.assertEqual(layer.bias.constraint, b_constraint)
+
+ @tf_test_util.run_in_graph_and_eager_modes
+ def test_locallyconnected_implementation(self):
+ n_train = 4
+ n_classes = 3
+ n_epochs = 2
+
+ np.random.seed(1)
+ targets = np.random.randint(0, n_classes, (n_train,))
+
+ for width in [1, 17]:
+ for height in [16]:
+ for filters in [2]:
+ for data_format in ['channels_first', 'channels_last']:
+ inputs = get_inputs(data_format, filters, height, n_train, width)
+
+ for kernel_x in [(3,)]:
+ for kernel_y in [()] if width == 1 else [(2,)]:
+ for stride_x in [(1,)]:
+ for stride_y in [()] if width == 1 else [(3,)]:
+ for layers in [2]:
+ kwargs = {
+ 'layers': layers,
+ 'filters': filters,
+ 'kernel_size': kernel_x + kernel_y,
+ 'strides': stride_x + stride_y,
+ 'data_format': data_format,
+ 'n_classes': n_classes,
+ 'input_shape': inputs.shape
+ }
+
+ model_1 = get_model(implementation=1, **kwargs)
+ model_2 = get_model(implementation=2, **kwargs)
+
+ copy_model_weights(model_2, model_1)
+
+ # Compare outputs at initialization.
+ out_1 = model_1.call(inputs)
+ out_2 = model_2.call(inputs)
+ self.assertAllCloseAccordingToType(out_1, out_2,
+ rtol=1e-5, atol=1e-5)
+
+ # Train.
+ model_1.fit(x=inputs,
+ y=targets,
+ epochs=n_epochs,
+ batch_size=n_train)
+
+ model_2.fit(x=inputs,
+ y=targets,
+ epochs=n_epochs,
+ batch_size=n_train)
+
+ # Compare outputs after a few training steps.
+ out_1 = model_1.call(inputs)
+ out_2 = model_2.call(inputs)
+ self.assertAllCloseAccordingToType(out_1, out_2,
+ rtol=1e-5, atol=1e-5)
+
+ @tf_test_util.run_in_graph_and_eager_modes
+ def test_make_2d(self):
+ input_shapes = [
+ (0,),
+ (0, 0),
+ (1,),
+ (2,),
+ (3,),
+ (1, 0),
+ (0, 3),
+ (1, 1),
+ (1, 2),
+ (3, 1),
+ (2, 2),
+ (3, 3),
+ (1, 0, 1),
+ (5, 2, 3),
+ (3, 5, 6, 7, 0),
+ (3, 2, 2, 4, 4),
+ (1, 2, 3, 4, 7, 2),
+ ]
+ np.random.seed(1)
+
+ for input_shape in input_shapes:
+ inputs = np.random.normal(0, 1, input_shape)
+ inputs_tf = keras.backend.variable(inputs)
+
+ split_dim = np.random.randint(0, inputs.ndim + 1)
+ shape_2d = (int(np.prod(inputs.shape[:split_dim])),
+ int(np.prod(inputs.shape[split_dim:])))
+ inputs_2d = np.reshape(inputs, shape_2d)
+
+ inputs_2d_tf = keras.layers.local.make_2d(inputs_tf, split_dim)
+ inputs_2d_tf = keras.backend.get_value(inputs_2d_tf)
+
+ self.assertAllCloseAccordingToType(inputs_2d, inputs_2d_tf)
+
+
+def get_inputs(data_format, filters, height, n_train, width):
+ if data_format == 'channels_first':
+ if width == 1:
+ input_shape = (filters, height)
+ else:
+ input_shape = (filters, height, width)
+
+ elif data_format == 'channels_last':
+ if width == 1:
+ input_shape = (height, filters)
+ else:
+ input_shape = (height, width, filters)
+
+ else:
+ raise NotImplementedError(data_format)
+
+ inputs = np.random.normal(0, 1,
+ (n_train,) + input_shape).astype(np.float32)
+ return inputs
+
+
+def xent(y_true, y_pred):
+ y_true = keras.backend.cast(
+ keras.backend.reshape(y_true, (-1,)),
+ keras.backend.dtypes_module.int32)
+
+ return keras.backend.nn.sparse_softmax_cross_entropy_with_logits(
+ labels=y_true,
+ logits=y_pred)
+
+
+def get_model(implementation,
+ filters,
+ kernel_size,
+ strides,
+ layers,
+ n_classes,
+ data_format,
+ input_shape):
+ model = keras.Sequential()
+
+ if len(kernel_size) == 1:
+ lc_layer = keras.layers.LocallyConnected1D
+ elif len(kernel_size) == 2:
+ lc_layer = keras.layers.LocallyConnected2D
+ else:
+ raise NotImplementedError(kernel_size)
+
+ for _ in range(layers):
+ model.add(lc_layer(
+ padding='valid',
+ kernel_initializer=keras.initializers.random_normal(),
+ bias_initializer=keras.initializers.random_normal(),
+ filters=filters,
+ strides=strides,
+ kernel_size=kernel_size,
+ activation=keras.activations.relu,
+ data_format=data_format,
+ implementation=implementation))
+
+ model.add(keras.layers.Flatten())
+ model.add(keras.layers.Dense(n_classes))
+ model.compile(
+ optimizer=RMSPropOptimizer(0.01),
+ metrics=[keras.metrics.categorical_accuracy],
+ loss=xent
+ )
+ model.build(input_shape)
+ return model
+
+
+def copy_lc_weights(lc_layer_2_from, lc_layer_1_to):
+ lc_2_kernel, lc_2_bias = lc_layer_2_from.weights
+ lc_2_kernel_masked = lc_2_kernel * lc_layer_2_from.kernel_mask
+
+ data_format = lc_layer_2_from.data_format
+
+ if data_format == 'channels_first':
+ if isinstance(lc_layer_2_from, keras.layers.LocallyConnected1D):
+ permutation = (3, 0, 1, 2)
+ elif isinstance(lc_layer_2_from, keras.layers.LocallyConnected2D):
+ permutation = (4, 5, 0, 1, 2, 3)
+ else:
+ raise NotImplementedError(lc_layer_2_from)
+
+ elif data_format == 'channels_last':
+ if isinstance(lc_layer_2_from, keras.layers.LocallyConnected1D):
+ permutation = (2, 0, 1, 3)
+ elif isinstance(lc_layer_2_from, keras.layers.LocallyConnected2D):
+ permutation = (3, 4, 0, 1, 2, 5)
+ else:
+ raise NotImplementedError(lc_layer_2_from)
+
+ else:
+ raise NotImplementedError(data_format)
+
+ lc_2_kernel_masked = keras.backend.permute_dimensions(
+ lc_2_kernel_masked, permutation)
+
+ lc_2_kernel_mask = keras.backend.math_ops.not_equal(
+ lc_2_kernel_masked, 0)
+ lc_2_kernel_flat = keras.backend.array_ops.boolean_mask(
+ lc_2_kernel_masked, lc_2_kernel_mask)
+ lc_2_kernel_reshaped = keras.backend.reshape(lc_2_kernel_flat,
+ lc_layer_1_to.kernel.shape)
+
+ lc_2_kernel_reshaped = keras.backend.get_value(lc_2_kernel_reshaped)
+ lc_2_bias = keras.backend.get_value(lc_2_bias)
+
+ lc_layer_1_to.set_weights([lc_2_kernel_reshaped, lc_2_bias])
+
+
+def copy_model_weights(model_2_from, model_1_to):
+ for l in range(len(model_2_from.layers)):
+ layer_2_from = model_2_from.layers[l]
+ layer_1_to = model_1_to.layers[l]
+
+ if isinstance(layer_2_from, (keras.layers.LocallyConnected2D,
+ keras.layers.LocallyConnected1D)):
+ copy_lc_weights(layer_2_from, layer_1_to)
+
+ elif isinstance(layer_2_from, keras.layers.Dense):
+ weights_2, bias_2 = layer_2_from.weights
+ weights_2 = keras.backend.get_value(weights_2)
+ bias_2 = keras.backend.get_value(bias_2)
+ layer_1_to.set_weights([weights_2, bias_2])
+
+ else:
+ continue
if __name__ == '__main__':
diff --git a/tensorflow/python/keras/layers/normalization.py b/tensorflow/python/keras/layers/normalization.py
index a7835bc0a2..cd26e04c39 100644
--- a/tensorflow/python/keras/layers/normalization.py
+++ b/tensorflow/python/keras/layers/normalization.py
@@ -36,7 +36,7 @@ from tensorflow.python.ops import nn
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
-from tensorflow.python.training import distribute as distribute_lib
+from tensorflow.python.training import distribution_strategy_context
from tensorflow.python.util.tf_export import tf_export
@@ -345,16 +345,16 @@ class BatchNormalization(Layer):
aggregation=variable_scope.VariableAggregation.MEAN)
return var
- with distribute_lib.get_distribution_strategy().colocate_vars_with(
- self.moving_mean):
+ with distribution_strategy_context.get_distribution_strategy(
+ ).colocate_vars_with(self.moving_mean):
self.renorm_mean = _renorm_variable('renorm_mean', param_shape)
self.renorm_mean_weight = _renorm_variable('renorm_mean_weight', ())
# We initialize renorm_stddev to 0, and maintain the (0-initialized)
# renorm_stddev_weight. This allows us to (1) mix the average
# stddev with the minibatch stddev early in training, and (2) compute
# the unbiased average stddev by dividing renorm_stddev by the weight.
- with distribute_lib.get_distribution_strategy().colocate_vars_with(
- self.moving_variance):
+ with distribution_strategy_context.get_distribution_strategy(
+ ).colocate_vars_with(self.moving_variance):
self.renorm_stddev = _renorm_variable('renorm_stddev', param_shape)
self.renorm_stddev_weight = _renorm_variable('renorm_stddev_weight',
())
diff --git a/tensorflow/python/keras/layers/recurrent.py b/tensorflow/python/keras/layers/recurrent.py
index 66c68e2085..12c82a53f6 100644
--- a/tensorflow/python/keras/layers/recurrent.py
+++ b/tensorflow/python/keras/layers/recurrent.py
@@ -670,6 +670,8 @@ class RNN(Layer):
if generic_utils.has_arg(self.cell.call, 'training'):
kwargs['training'] = training
+    # TF RNN cells expect a single tensor as state instead of a list-wrapped
+    # tensor.
+ is_tf_rnn_cell = getattr(self.cell, '_is_tf_rnn_cell', None) is not None
if constants:
if not generic_utils.has_arg(self.cell.call, 'constants'):
raise ValueError('RNN cell does not support constants')
@@ -677,11 +679,21 @@ class RNN(Layer):
def step(inputs, states):
constants = states[-self._num_constants:] # pylint: disable=invalid-unary-operand-type
states = states[:-self._num_constants] # pylint: disable=invalid-unary-operand-type
- return self.cell.call(inputs, states, constants=constants, **kwargs)
+
+ states = states[0] if len(states) == 1 and is_tf_rnn_cell else states
+ output, new_states = self.cell.call(
+ inputs, states, constants=constants, **kwargs)
+ if not nest.is_sequence(new_states):
+ new_states = [new_states]
+ return output, new_states
else:
def step(inputs, states):
- return self.cell.call(inputs, states, **kwargs)
+ states = states[0] if len(states) == 1 and is_tf_rnn_cell else states
+ output, new_states = self.cell.call(inputs, states, **kwargs)
+ if not nest.is_sequence(new_states):
+ new_states = [new_states]
+ return output, new_states
last_output, outputs, states = K.rnn(
step,
diff --git a/tensorflow/python/keras/metrics.py b/tensorflow/python/keras/metrics.py
index 2dde9ee41f..9b87170ebe 100644
--- a/tensorflow/python/keras/metrics.py
+++ b/tensorflow/python/keras/metrics.py
@@ -55,7 +55,7 @@ from tensorflow.python.ops import nn
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import weights_broadcast_ops
-from tensorflow.python.training import distribute as distribute_lib
+from tensorflow.python.training import distribution_strategy_context
from tensorflow.python.util import tf_decorator
from tensorflow.python.util.tf_export import tf_export
@@ -111,7 +111,7 @@ def result_wrapper(result_fn):
def decorated(metric_obj, *args):
"""Decorated function with merge_call."""
- tower_context = distribute_lib.get_tower_context()
+ tower_context = distribution_strategy_context.get_tower_context()
if tower_context is None: # if in cross tower context already
result_t = result_fn(*args)
else:
diff --git a/tensorflow/python/keras/optimizers.py b/tensorflow/python/keras/optimizers.py
index 4f97442e82..f339a7e047 100644
--- a/tensorflow/python/keras/optimizers.py
+++ b/tensorflow/python/keras/optimizers.py
@@ -28,7 +28,7 @@ from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
-from tensorflow.python.training import distribute as distribute_lib
+from tensorflow.python.training import distribution_strategy_context
from tensorflow.python.training import optimizer as tf_optimizer_module
from tensorflow.python.training import training_util
from tensorflow.python.training.checkpointable import base as checkpointable
@@ -705,7 +705,7 @@ class TFOptimizer(Optimizer, checkpointable.CheckpointableBase):
return self.optimizer.compute_gradients(loss, params)
def get_updates(self, loss, params):
- if distribute_lib.has_distribution_strategy():
+ if distribution_strategy_context.has_distribution_strategy():
self.updates = []
if not params:
diff --git a/tensorflow/python/keras/utils/conv_utils.py b/tensorflow/python/keras/utils/conv_utils.py
index 5419e7ae05..3a176c3316 100644
--- a/tensorflow/python/keras/utils/conv_utils.py
+++ b/tensorflow/python/keras/utils/conv_utils.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import itertools
import numpy as np
from six.moves import range # pylint: disable=redefined-builtin
@@ -199,3 +200,168 @@ def convert_kernel(kernel):
no_flip = (slice(None, None), slice(None, None))
slices[-2:] = no_flip
return np.copy(kernel[slices])
+
+
+def conv_kernel_mask(input_shape, kernel_shape, strides, padding):
+ """Compute a mask representing the connectivity of a convolution operation.
+
+ Assume a convolution with given parameters is applied to an input having N
+ spatial dimensions with `input_shape = (d_in1, ..., d_inN)` to produce an
+ output with shape `(d_out1, ..., d_outN)`. This method returns a boolean array
+ of shape `(d_in1, ..., d_inN, d_out1, ..., d_outN)` with `True` entries
+ indicating pairs of input and output locations that are connected by a weight.
+
+ Example:
+ ```python
+ >>> input_shape = (4,)
+ >>> kernel_shape = (2,)
+ >>> strides = (1,)
+ >>> padding = "valid"
+ >>> conv_kernel_mask(input_shape, kernel_shape, strides, padding)
+ array([[ True, False, False],
+ [ True, True, False],
+ [False, True, True],
+ [False, False, True]], dtype=bool)
+ ```
+ where rows and columns correspond to inputs and outputs respectively.
+
+
+ Args:
+ input_shape: tuple of size N: `(d_in1, ..., d_inN)`,
+ spatial shape of the input.
+ kernel_shape: tuple of size N, spatial shape of the convolutional kernel
+ / receptive field.
+ strides: tuple of size N, strides along each spatial dimension.
+ padding: type of padding, string `"same"` or `"valid"`.
+
+ Returns:
+ A boolean 2N-D `np.ndarray` of shape
+ `(d_in1, ..., d_inN, d_out1, ..., d_outN)`, where `(d_out1, ..., d_outN)`
+ is the spatial shape of the output. `True` entries in the mask represent
+ pairs of input-output locations that are connected by a weight.
+
+ Raises:
+ ValueError: if `input_shape`, `kernel_shape` and `strides` don't have the
+ same number of dimensions.
+ NotImplementedError: if `padding` is not in {`"same"`, `"valid"`}.
+ """
+ if padding not in {'same', 'valid'}:
+ raise NotImplementedError('Padding type %s not supported. '
+ 'Only "valid" and "same" '
+ 'are implemented.' % padding)
+
+ in_dims = len(input_shape)
+ if isinstance(kernel_shape, int):
+ kernel_shape = (kernel_shape,) * in_dims
+ if isinstance(strides, int):
+ strides = (strides,) * in_dims
+
+ kernel_dims = len(kernel_shape)
+ stride_dims = len(strides)
+ if kernel_dims != in_dims or stride_dims != in_dims:
+ raise ValueError('Number of strides, input and kernel dimensions must all '
+ 'match. Received: %d, %d, %d.' %
+ (stride_dims, in_dims, kernel_dims))
+
+ output_shape = conv_output_shape(input_shape, kernel_shape, strides, padding)
+
+ mask_shape = input_shape + output_shape
+ mask = np.zeros(mask_shape, np.bool)
+
+ output_axes_ticks = [range(dim) for dim in output_shape]
+ for output_position in itertools.product(*output_axes_ticks):
+ input_axes_ticks = conv_connected_inputs(input_shape,
+ kernel_shape,
+ output_position,
+ strides,
+ padding)
+ for input_position in itertools.product(*input_axes_ticks):
+ mask[input_position + output_position] = True
+
+ return mask
+
+
+def conv_connected_inputs(input_shape,
+ kernel_shape,
+ output_position,
+ strides,
+ padding):
+ """Return locations of the input connected to an output position.
+
+ Assume a convolution with given parameters is applied to an input having N
+ spatial dimensions with `input_shape = (d_in1, ..., d_inN)`. This method
+ returns N ranges specifying the input region that was convolved with the
+ kernel to produce the output at position
+ `output_position = (p_out1, ..., p_outN)`.
+
+ Example:
+ ```python
+ >>> input_shape = (4, 4)
+ >>> kernel_shape = (2, 1)
+ >>> output_position = (1, 1)
+ >>> strides = (1, 1)
+ >>> padding = "valid"
+ >>> conv_connected_inputs(input_shape, kernel_shape, output_position,
+ >>> strides, padding)
+    [range(1, 3), range(1, 2)]
+ ```
+ Args:
+ input_shape: tuple of size N: `(d_in1, ..., d_inN)`,
+ spatial shape of the input.
+ kernel_shape: tuple of size N, spatial shape of the convolutional kernel
+ / receptive field.
+ output_position: tuple of size N: `(p_out1, ..., p_outN)`,
+ a single position in the output of the convolution.
+ strides: tuple of size N, strides along each spatial dimension.
+ padding: type of padding, string `"same"` or `"valid"`.
+
+ Returns:
+ N ranges `[[p_in_left1, ..., p_in_right1], ...,
+ [p_in_leftN, ..., p_in_rightN]]` specifying the region in the
+ input connected to output_position.
+ """
+ ranges = []
+
+ ndims = len(input_shape)
+ for d in range(ndims):
+ left_shift = int(kernel_shape[d] / 2)
+ right_shift = kernel_shape[d] - left_shift
+
+ center = output_position[d] * strides[d]
+
+ if padding == 'valid':
+ center += left_shift
+
+ start = max(0, center - left_shift)
+ end = min(input_shape[d], center + right_shift)
+
+ ranges.append(range(start, end))
+
+ return ranges
+
+
+def conv_output_shape(input_shape, kernel_shape, strides, padding):
+ """Return the output shape of an N-D convolution.
+
+ Forces dimensions where input is empty (size 0) to remain empty.
+
+ Args:
+ input_shape: tuple of size N: `(d_in1, ..., d_inN)`,
+ spatial shape of the input.
+ kernel_shape: tuple of size N, spatial shape of the convolutional kernel
+ / receptive field.
+ strides: tuple of size N, strides along each spatial dimension.
+ padding: type of padding, string `"same"` or `"valid"`.
+
+ Returns:
+ tuple of size N: `(d_out1, ..., d_outN)`, spatial shape of the output.
+ """
+ dims = range(len(kernel_shape))
+ output_shape = [conv_output_length(input_shape[d],
+ kernel_shape[d],
+ padding,
+ strides[d])
+ for d in dims]
+ output_shape = tuple([0 if input_shape[d] == 0 else output_shape[d]
+ for d in dims])
+ return output_shape
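A quick check of the helper (a sketch; `conv_output_shape` is the function added above, and it delegates per-dimension to `conv_output_length`):

```python
from tensorflow.python.keras.utils import conv_utils

conv_utils.conv_output_shape((8, 8), (3, 3), (1, 1), 'valid')  # -> (6, 6)
conv_utils.conv_output_shape((8, 8), (3, 3), (1, 1), 'same')   # -> (8, 8)
conv_utils.conv_output_shape((8, 0), (3, 3), (1, 1), 'same')   # -> (8, 0)
```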
diff --git a/tensorflow/python/keras/utils/conv_utils_test.py b/tensorflow/python/keras/utils/conv_utils_test.py
new file mode 100644
index 0000000000..eb2a360bfd
--- /dev/null
+++ b/tensorflow/python/keras/utils/conv_utils_test.py
@@ -0,0 +1,232 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for conv_utils."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import itertools
+
+from absl.testing import parameterized
+import numpy as np
+
+from tensorflow.python.keras.utils import conv_utils
+from tensorflow.python.platform import test
+
+
+def _get_const_output_shape(input_shape, dim):
+ return tuple([min(d, dim) for d in input_shape])
+
+
+input_shapes = [
+ (0,),
+ (0, 0),
+ (1,),
+ (2,),
+ (3,),
+ (1, 0),
+ (0, 3),
+ (1, 1),
+ (1, 2),
+ (3, 1),
+ (2, 2),
+ (3, 3),
+ (1, 0, 1),
+ (5, 2, 3),
+ (3, 5, 6, 7, 0),
+ (3, 2, 2, 4, 4),
+ (1, 2, 3, 4, 7, 2),
+]
+
+
+@parameterized.parameters(input_shapes)
+class TestConvUtils(test.TestCase, parameterized.TestCase):
+
+ def test_conv_kernel_mask_fc(self, *input_shape):
+ padding = 'valid'
+ kernel_shape = input_shape
+ ndims = len(input_shape)
+ strides = (1,) * ndims
+ output_shape = _get_const_output_shape(input_shape, dim=1)
+ mask = np.ones(input_shape + output_shape, np.bool)
+ self.assertAllEqual(
+ mask,
+ conv_utils.conv_kernel_mask(
+ input_shape,
+ kernel_shape,
+ strides,
+ padding
+ )
+ )
+
+ def test_conv_kernel_mask_diag(self, *input_shape):
+ ndims = len(input_shape)
+ kernel_shape = (1,) * ndims
+ strides = (1,) * ndims
+
+ for padding in ['valid', 'same']:
+ mask = np.identity(int(np.prod(input_shape)), np.bool)
+ mask = np.reshape(mask, input_shape * 2)
+ self.assertAllEqual(
+ mask,
+ conv_utils.conv_kernel_mask(
+ input_shape,
+ kernel_shape,
+ strides,
+ padding
+ )
+ )
+
+ def test_conv_kernel_mask_full_stride(self, *input_shape):
+ padding = 'valid'
+ ndims = len(input_shape)
+ kernel_shape = (1,) * ndims
+ strides = tuple([max(d, 1) for d in input_shape])
+ output_shape = _get_const_output_shape(input_shape, dim=1)
+
+ mask = np.zeros(input_shape + output_shape, np.bool)
+ if all(d > 0 for d in mask.shape):
+ mask[(0,) * len(output_shape)] = True
+
+ self.assertAllEqual(
+ mask,
+ conv_utils.conv_kernel_mask(
+ input_shape,
+ kernel_shape,
+ strides,
+ padding
+ )
+ )
+
+ def test_conv_kernel_mask_almost_full_stride(self, *input_shape):
+ padding = 'valid'
+ ndims = len(input_shape)
+ kernel_shape = (1,) * ndims
+ strides = tuple([max(d - 1, 1) for d in input_shape])
+ output_shape = _get_const_output_shape(input_shape, dim=2)
+
+ mask = np.zeros(input_shape + output_shape, np.bool)
+ if all(d > 0 for d in mask.shape):
+ for in_position in itertools.product(*[[0, d - 1] for d in input_shape]):
+ out_position = tuple([min(p, 1) for p in in_position])
+ mask[in_position + out_position] = True
+
+ self.assertAllEqual(
+ mask,
+ conv_utils.conv_kernel_mask(
+ input_shape,
+ kernel_shape,
+ strides,
+ padding
+ )
+ )
+
+ def test_conv_kernel_mask_rect_kernel(self, *input_shape):
+ padding = 'valid'
+ ndims = len(input_shape)
+ strides = (1,) * ndims
+
+ for d in range(ndims):
+ kernel_shape = [1] * ndims
+ kernel_shape[d] = input_shape[d]
+
+ output_shape = list(input_shape)
+ output_shape[d] = min(1, input_shape[d])
+
+ mask = np.identity(int(np.prod(input_shape)), np.bool)
+ mask = np.reshape(mask, input_shape * 2)
+
+ for p in itertools.product(*[range(input_shape[dim])
+ for dim in range(ndims)]):
+ p = list(p)
+ p[d] = slice(None)
+ mask[p * 2] = True
+
+ mask = np.take(mask, range(0, min(1, input_shape[d])), ndims + d)
+
+ self.assertAllEqual(
+ mask,
+ conv_utils.conv_kernel_mask(
+ input_shape,
+ kernel_shape,
+ strides,
+ padding
+ )
+ )
+
+ def test_conv_kernel_mask_wrong_padding(self, *input_shape):
+ ndims = len(input_shape)
+ kernel_shape = (1,) * ndims
+ strides = (1,) * ndims
+
+ conv_utils.conv_kernel_mask(
+ input_shape,
+ kernel_shape,
+ strides,
+ 'valid'
+ )
+
+ conv_utils.conv_kernel_mask(
+ input_shape,
+ kernel_shape,
+ strides,
+ 'same'
+ )
+
+ self.assertRaises(NotImplementedError,
+ conv_utils.conv_kernel_mask,
+ input_shape, kernel_shape, strides, 'full')
+
+ def test_conv_kernel_mask_wrong_dims(self, *input_shape):
+ kernel_shape = 1
+ strides = 1
+
+ conv_utils.conv_kernel_mask(
+ input_shape,
+ kernel_shape,
+ strides,
+ 'valid'
+ )
+
+ ndims = len(input_shape)
+
+ kernel_shape = (2,) * (ndims + 1)
+ self.assertRaises(ValueError,
+ conv_utils.conv_kernel_mask,
+ input_shape, kernel_shape, strides, 'same')
+
+ strides = (1,) * ndims
+ self.assertRaises(ValueError,
+ conv_utils.conv_kernel_mask,
+ input_shape, kernel_shape, strides, 'valid')
+
+ kernel_shape = (1,) * ndims
+ strides = (2,) * (ndims - 1)
+ self.assertRaises(ValueError,
+ conv_utils.conv_kernel_mask,
+ input_shape, kernel_shape, strides, 'valid')
+
+ strides = (2,) * ndims
+ conv_utils.conv_kernel_mask(
+ input_shape,
+ kernel_shape,
+ strides,
+ 'valid'
+ )
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index 2451dc7257..9fe52f3d28 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -73,6 +73,17 @@ tf_py_test(
)
tf_py_test(
+ name = "batch_gather_op_test",
+ srcs = ["batch_gather_op_test.py"],
+ additional_deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ ],
+)
+
+tf_py_test(
name = "bcast_ops_test",
size = "small",
srcs = ["bcast_ops_test.py"],
@@ -2181,7 +2192,6 @@ cuda_py_test(
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:parsing_ops",
],
- tags = ["no_windows"],
)
cuda_py_test(
diff --git a/tensorflow/python/kernel_tests/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops_test.py
index 40567571e6..81442d12e9 100644
--- a/tensorflow/python/kernel_tests/array_ops_test.py
+++ b/tensorflow/python/kernel_tests/array_ops_test.py
@@ -245,6 +245,7 @@ class BooleanMaskTest(test_util.TensorFlowTestCase):
array_ops.boolean_mask(tensor, mask).eval()
+@test_util.run_all_in_graph_and_eager_modes
class OperatorShapeTest(test_util.TensorFlowTestCase):
def testExpandScalar(self):
@@ -262,7 +263,8 @@ class OperatorShapeTest(test_util.TensorFlowTestCase):
matrix_squeezed = array_ops.squeeze(matrix, [0])
self.assertEqual(matrix_squeezed.get_shape(), (3))
- with self.assertRaises(ValueError):
+ with self.assertRaisesRegexp(
+ Exception, "Can not squeeze dim.1., expected a dimension of 1, got 3"):
matrix_squeezed = array_ops.squeeze(matrix, [1])
def testSqueezeScalarDim(self):
@@ -270,6 +272,11 @@ class OperatorShapeTest(test_util.TensorFlowTestCase):
matrix_squeezed = array_ops.squeeze(matrix, 0)
self.assertEqual(matrix_squeezed.get_shape(), (3))
+ def testExpandDimsWithNonScalarDim(self):
+ with self.assertRaisesRegexp(Exception,
+ "must be a tensor with a single value"):
+ array_ops.expand_dims(1, axis=[0, 1])
+
class ReverseV2Test(test_util.TensorFlowTestCase):
diff --git a/tensorflow/python/kernel_tests/batch_gather_op_test.py b/tensorflow/python/kernel_tests/batch_gather_op_test.py
new file mode 100644
index 0000000000..8e7ae89f9d
--- /dev/null
+++ b/tensorflow/python/kernel_tests/batch_gather_op_test.py
@@ -0,0 +1,116 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tensorflow.ops.tf.gather."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+
+_TEST_TYPES = (dtypes.int64, dtypes.float32,
+ dtypes.complex64, dtypes.complex128)
+
+
+class GatherTest(test.TestCase):
+
+ def _buildParams(self, data, dtype):
+ data = data.astype(dtype.as_numpy_dtype)
+ # For complex types, add an index-dependent imaginary component so we can
+ # tell we got the right value.
+ if dtype.is_complex:
+ return data + 10j * data
+ return data
+
+ def testSimpleGather(self):
+ data = np.array([0, 1, 2, 3, 7, 5, 8, 9, 10, 11, 15, 13])
+ indices = [3, 4]
+ with self.test_session(use_gpu=True):
+ for dtype in _TEST_TYPES:
+ params_np = self._buildParams(data, dtype)
+ params = constant_op.constant(params_np)
+ indices_tf = constant_op.constant(indices)
+ gather_t = array_ops.batch_gather(params, indices_tf)
+ expected_result = np.array([3, 7])
+ np_val = self._buildParams(expected_result, dtype)
+ gather_val = gather_t.eval()
+ self.assertAllEqual(np_val, gather_val)
+ self.assertEqual(np_val.shape, gather_t.get_shape())
+
+ def test2DArray(self):
+ data = np.array([[0, 1, 2, 3, 7, 5], [8, 9, 10, 11, 15, 13]])
+ indices = [[3], [4]]
+ with self.test_session(use_gpu=True):
+ for dtype in _TEST_TYPES:
+ params_np = self._buildParams(data, dtype)
+ params = constant_op.constant(params_np)
+ indices_tf = constant_op.constant(indices)
+ gather_t = array_ops.batch_gather(params, indices_tf)
+ expected_result = np.array([[3], [15]])
+ np_val = self._buildParams(expected_result, dtype)
+ gather_val = gather_t.eval()
+ self.assertAllEqual(np_val, gather_val)
+ self.assertEqual(np_val.shape, gather_t.get_shape())
+
+ def testHigherRank(self):
+ data = np.array([[[0, 1, 2], [3, 7, 5]], [[8, 9, 10], [11, 15, 13]]])
+ indices = [[[2, 0], [1, 2]], [[2, 0], [0, 1]]]
+ with self.test_session(use_gpu=True):
+ for dtype in _TEST_TYPES:
+ params_np = self._buildParams(data, dtype)
+ params = constant_op.constant(params_np)
+ indices_tf = constant_op.constant(indices)
+ gather_t = array_ops.batch_gather(params, indices_tf)
+ gather_val = gather_t.eval()
+ expected_result = np.array([[[2, 0], [7, 5]], [[10, 8], [11, 15]]])
+ np_val = self._buildParams(expected_result, dtype)
+ self.assertAllEqual(np_val, gather_val)
+ self.assertEqual(np_val.shape, gather_t.get_shape())
+
+ def testString(self):
+ params = np.array([[b"asdf", b"zxcv"], [b"qwer", b"uiop"]])
+ with self.test_session():
+ indices_tf = constant_op.constant([1])
+ self.assertAllEqual([[b"qwer", b"uiop"]],
+ array_ops.batch_gather(params, indices_tf).eval())
+
+ def testUnknownIndices(self):
+ params = constant_op.constant([[0, 1, 2]])
+ indices = array_ops.placeholder(dtypes.int32, shape=[None, None])
+ gather_t = array_ops.batch_gather(params, indices)
+ self.assertEqual([1, None], gather_t.get_shape().as_list())
+
+ def testBadIndicesCPU(self):
+ with self.test_session(use_gpu=False):
+ params = [[0, 1, 2], [3, 4, 5]]
+ with self.assertRaisesOpError(r"indices\[0\] = 7 is not in \[0, 2\)"):
+ array_ops.batch_gather(params, [7]).eval()
+
+ def testEmptySlices(self):
+ with self.test_session(use_gpu=True):
+ for dtype in _TEST_TYPES:
+ for itype in np.int32, np.int64:
+ params = np.zeros((7, 0, 0), dtype=dtype.as_numpy_dtype)
+ indices = np.array([3, 4], dtype=itype)
+ gather = array_ops.batch_gather(params, indices)
+ self.assertAllEqual(gather.eval(), np.zeros((2, 0, 0)))
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/kernel_tests/clip_ops_test.py b/tensorflow/python/kernel_tests/clip_ops_test.py
index fb52d10475..400d38b936 100644
--- a/tensorflow/python/kernel_tests/clip_ops_test.py
+++ b/tensorflow/python/kernel_tests/clip_ops_test.py
@@ -22,6 +22,7 @@ import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
@@ -369,6 +370,21 @@ class ClipTest(test.TestCase):
self.assertAllClose(np_ans_0, tf_ans_1)
self.assertAllClose(np_ans_1, tf_ans_2)
+ def testClipByGlobalNormInf(self):
+ with self.test_session(use_gpu=True):
+ x0 = constant_op.constant([-2.0, 0.0, np.inf, 4.0, 0.0, 0.0],
+ shape=[2, 3])
+ x1 = constant_op.constant([1.0, -2.0])
+ clip_norm = 6.0
+
+ ans, norm = clip_ops.clip_by_global_norm([x0, x1], clip_norm)
+ with self.assertRaisesRegexp(errors.InvalidArgumentError, "global norm"):
+ norm.eval()
+ with self.assertRaisesRegexp(errors.InvalidArgumentError, "global norm"):
+ ans[0].eval()
+ with self.assertRaisesRegexp(errors.InvalidArgumentError, "global norm"):
+ ans[1].eval()
+
def testClipByAverageNormClipped(self):
# Norm clipping when average clip_norm < 0.83333333
with self.test_session(use_gpu=True):
diff --git a/tensorflow/python/kernel_tests/confusion_matrix_test.py b/tensorflow/python/kernel_tests/confusion_matrix_test.py
index ae6875340e..93f5323c41 100644
--- a/tensorflow/python/kernel_tests/confusion_matrix_test.py
+++ b/tensorflow/python/kernel_tests/confusion_matrix_test.py
@@ -448,7 +448,7 @@ class RemoveSqueezableDimensionsTest(test.TestCase):
}
with self.assertRaisesRegexp(
errors_impl.InvalidArgumentError,
- "Tried to explicitly squeeze dimension 2"):
+ "Can not squeeze dim\[2\]"):
dynamic_labels.eval(feed_dict=feed_dict)
self.assertAllEqual(
prediction_values, dynamic_predictions.eval(feed_dict=feed_dict))
@@ -475,7 +475,7 @@ class RemoveSqueezableDimensionsTest(test.TestCase):
label_values, dynamic_labels.eval(feed_dict=feed_dict))
with self.assertRaisesRegexp(
errors_impl.InvalidArgumentError,
- "Tried to explicitly squeeze dimension 2"):
+ "Can not squeeze dim\[2\]"):
dynamic_predictions.eval(feed_dict=feed_dict)
diff --git a/tensorflow/python/kernel_tests/partitioned_variables_test.py b/tensorflow/python/kernel_tests/partitioned_variables_test.py
index f5c6255c34..ba9359d923 100644
--- a/tensorflow/python/kernel_tests/partitioned_variables_test.py
+++ b/tensorflow/python/kernel_tests/partitioned_variables_test.py
@@ -25,12 +25,15 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import math_ops
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
+from tensorflow.python.training import gradient_descent
class PartitionerCreatorsTest(test.TestCase):
@@ -543,32 +546,6 @@ class PartitionedVariablesTestCase(test.TestCase):
partitioned_variables.create_partitioned_variables(
[10, 43], [1, 50], rnd.initialized_value())
- def testControlDepsNone(self):
- with self.test_session() as session:
- c = constant_op.constant(1.0)
- with ops.control_dependencies([c]):
- # d get the control dependency.
- d = constant_op.constant(2.0)
- # Partitioned variables do not.
- var_x = variable_scope.get_variable(
- "x",
- shape=[2],
- initializer=init_ops.ones_initializer(),
- partitioner=partitioned_variables.variable_axis_size_partitioner(4))
-
- ops_before_read = session.graph.get_operations()
- var_x.as_tensor() # Caches the ops for subsequent reads.
- reading_ops = [
- op for op in session.graph.get_operations()
- if op not in ops_before_read
- ]
-
- self.assertEqual([c.op], d.op.control_inputs)
- # Tests that no control dependencies are added to reading a partitioned
- # variable which is similar to reading a variable.
- for op in reading_ops:
- self.assertEqual([], op.control_inputs)
-
def testConcat(self):
with self.test_session() as session:
var_x = variable_scope.get_variable(
@@ -594,6 +571,57 @@ class PartitionedVariablesTestCase(test.TestCase):
variables.global_variables_initializer().run()
self.assertAllClose(value.eval(), var_x.as_tensor().eval())
+ def testVariableCreationInALoop(self):
+ """Tests the variable created inside a loop can be used outside the loop."""
+ with self.test_session():
+ with variable_scope.variable_scope("ascope") as scope:
+ def Body(i, _):
+ var_x = variable_scope.get_variable(
+ "x",
+ shape=[2],
+ initializer=init_ops.ones_initializer(),
+ partitioner=partitioned_variables.variable_axis_size_partitioner(
+ 4))
+ return (i + 1, var_x.as_tensor())
+
+ cond = lambda i, _: i < 2
+ _, x = control_flow_ops.while_loop(
+ cond, Body, (0, constant_op.constant([7, 8], dtypes.float32)))
+ variables.global_variables_initializer().run()
+ self.assertAllClose([1.0, 1.0], x.eval())
+
+ scope.reuse_variables()
+ var_x = variable_scope.get_variable(
+ "x",
+ shape=[2],
+ initializer=init_ops.ones_initializer(),
+ partitioner=partitioned_variables.variable_axis_size_partitioner(4))
+
+ self.assertAllClose([1.0, 1.0], var_x.as_tensor().eval())
+
+ def testReadInWhileLoop(self):
+ """Tests the value is current (not cached) when read within a loop."""
+ with self.test_session():
+ var_x = variable_scope.get_variable(
+ "x",
+ shape=[2],
+ initializer=init_ops.ones_initializer(),
+ partitioner=partitioned_variables.variable_axis_size_partitioner(4))
+
+ def Body(i, _):
+ # Use a SGD step to update the variable's value.
+ loss = math_ops.reduce_sum(var_x)
+ optimizer = gradient_descent.GradientDescentOptimizer(1.0)
+ minimize = optimizer.minimize(loss * 0.7)
+ with ops.control_dependencies([minimize]):
+ return (i + 1, var_x.as_tensor())
+
+ cond = lambda i, _: i < 2
+ _, x = control_flow_ops.while_loop(
+ cond, Body, (0, constant_op.constant([7, 8], dtypes.float32)))
+ variables.global_variables_initializer().run()
+ self.assertAllClose([-0.4, -0.4], x.eval())
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
index c739cd2c0d..b1ef46f2a1 100644
--- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py
+++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
@@ -835,6 +835,12 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
state_ops.scatter_add(v, [1], [3])
self.assertAllEqual([1.0, 5.0], v.numpy())
+ def testScatterSubStateOps(self):
+ with context.eager_mode():
+ v = resource_variable_ops.ResourceVariable([1.0, 2.0], name="sub")
+ state_ops.scatter_sub(v, [1], [3])
+ self.assertAllEqual([1.0, -1.0], v.numpy())
+
def testScatterNdAddStateOps(self):
with context.eager_mode():
v = resource_variable_ops.ResourceVariable(
diff --git a/tensorflow/python/kernel_tests/rnn_test.py b/tensorflow/python/kernel_tests/rnn_test.py
index e32d7c4e67..c72ada11da 100644
--- a/tensorflow/python/kernel_tests/rnn_test.py
+++ b/tensorflow/python/kernel_tests/rnn_test.py
@@ -301,14 +301,12 @@ class RNNTest(test.TestCase):
self.assertAllEqual([10.] * 4, self.evaluate(lstm_cell._bias))
def testRNNCellSerialization(self):
- for cell in [
+ for cell in [
rnn_cell_impl.LSTMCell(32, use_peepholes=True, cell_clip=True),
rnn_cell_impl.BasicLSTMCell(32, dtype=dtypes.float32),
- # TODO(scottzhu): GRU and BasicRNN cell are not compatible with Keras.
- # rnn_cell_impl.BasicRNNCell(
- # 32, activation="relu", dtype=dtypes.float32),
- # rnn_cell_impl.GRUCell(
- # 32, kernel_initializer="ones", dtype=dtypes.float32)
+ rnn_cell_impl.BasicRNNCell(32, activation="relu", dtype=dtypes.float32),
+ rnn_cell_impl.GRUCell(
+ 32, kernel_initializer="ones", dtype=dtypes.float32)
]:
with self.test_session():
x = keras.Input((None, 5))
@@ -326,11 +324,13 @@ class RNNTest(test.TestCase):
# not visible as a Keras layer, and also has a name conflict with
# keras.LSTMCell and GRUCell.
layer = keras.layers.RNN.from_config(
- config, custom_objects={
- # "BasicRNNCell": rnn_cell_impl.BasicRNNCell,
- # "GRUCell": rnn_cell_impl.GRUCell,
+ config,
+ custom_objects={
+ "BasicRNNCell": rnn_cell_impl.BasicRNNCell,
+ "GRUCell": rnn_cell_impl.GRUCell,
"LSTMCell": rnn_cell_impl.LSTMCell,
- "BasicLSTMCell": rnn_cell_impl.BasicLSTMCell})
+ "BasicLSTMCell": rnn_cell_impl.BasicLSTMCell
+ })
y = layer(x)
model = keras.models.Model(x, y)
model.set_weights(weights)
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index a917f51087..4b096cb73d 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -2662,6 +2662,76 @@ def gather(params, indices, validate_indices=None, name=None, axis=0):
gather.__doc__ = gen_array_ops.gather_v2.__doc__
+@tf_export("batch_gather")
+def batch_gather(params, indices, name=None):
+ """Gather slices from `params` according to `indices` with leading batch dims.
+
+  This operation assumes that the leading dimensions of `indices` are dense,
+  and it gathers on the axis corresponding to the last dimension of `indices`.
+  More concretely, it computes:
+
+ result[i1, ..., in] = params[i1, ..., in-1, indices[i1, ..., in]]
+
+  Therefore `params` should be a Tensor of shape `[A1, ..., AN, B1, ..., BM]`,
+  `indices` should be a Tensor of shape `[A1, ..., AN-1, C]`, and `result` will
+  be a Tensor of size `[A1, ..., AN-1, C, B1, ..., BM]`.
+
+  In the case in which `indices` is a 1-D tensor, this operation is equivalent
+  to `tf.gather`.
+
+ See also `tf.gather` and `tf.gather_nd`.
+
+ Args:
+ params: A Tensor. The tensor from which to gather values.
+    indices: A Tensor. Must be one of the following types: int32, int64. Index
+        tensor. Must be in range `[0, params.shape[axis])`, where `axis` is the
+        last dimension of `indices` itself.
+ name: A name for the operation (optional).
+
+ Returns:
+ A Tensor. Has the same type as `params`.
+
+ Raises:
+ ValueError: If `indices` has an unknown shape.
+ """
+
+ with ops.name_scope(name):
+ indices = ops.convert_to_tensor(indices, name="indices")
+ params = ops.convert_to_tensor(params, name="params")
+ indices_shape = shape(indices)
+ params_shape = shape(params)
+ ndims = indices.shape.ndims
+ if ndims is None:
+ raise ValueError("batch_gather does not allow indices with unknown "
+ "shape.")
+ batch_indices = indices
+ accum_dim_value = 1
+ for dim in range(ndims-1, 0, -1):
+ dim_value = params_shape[dim-1]
+ accum_dim_value *= params_shape[dim]
+ dim_indices = gen_math_ops._range(0, dim_value, 1)
+ dim_indices *= accum_dim_value
+ dim_shape = stack([1] * (dim - 1) + [dim_value] + [1] * (ndims - dim),
+ axis=0)
+ batch_indices += reshape(dim_indices, dim_shape)
+
+ flat_indices = reshape(batch_indices, [-1])
+ outer_shape = params_shape[ndims:]
+ flat_inner_shape = gen_math_ops.prod(
+ params_shape[:ndims], [0], False)
+
+ flat_params = reshape(
+ params, concat([[flat_inner_shape], outer_shape], axis=0))
+ flat_result = gather(flat_params, flat_indices)
+ result = reshape(flat_result, concat([indices_shape, outer_shape], axis=0))
+ final_shape = indices.get_shape()[:ndims-1].merge_with(
+ params.get_shape()[:ndims - 1])
+ final_shape = final_shape.concatenate(indices.get_shape()[ndims-1])
+ final_shape = final_shape.concatenate(params.get_shape()[ndims:])
+ result.set_shape(final_shape)
+ return result
+
+
# Define quantize_v2 here in order to make name the second-to-last attribute,
# because round_mode was added later.
@tf_export("quantize_v2")
diff --git a/tensorflow/python/ops/clip_ops.py b/tensorflow/python/ops/clip_ops.py
index e2580e8a2e..78b395a6c1 100644
--- a/tensorflow/python/ops/clip_ops.py
+++ b/tensorflow/python/ops/clip_ops.py
@@ -29,6 +29,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import numerics
from tensorflow.python.util.tf_export import tf_export
@@ -57,7 +58,7 @@ def clip_by_value(t, clip_value_min, clip_value_max,
A clipped `Tensor`.
Raises:
- ValueError: if the clip tensors would trigger array broadcasting
+ ValueError: If the clip tensors would trigger array broadcasting
that would make the returned tensor larger than the input.
"""
with ops.name_scope(name, "clip_by_value",
@@ -246,6 +247,7 @@ def clip_by_global_norm(t_list, clip_norm, use_norm=None, name=None):
Raises:
TypeError: If `t_list` is not a sequence.
+ InvalidArgumentError: If global norm is not finite.
"""
if (not isinstance(t_list, collections.Sequence)
or isinstance(t_list, six.string_types)):
@@ -253,6 +255,8 @@ def clip_by_global_norm(t_list, clip_norm, use_norm=None, name=None):
t_list = list(t_list)
if use_norm is None:
use_norm = global_norm(t_list, name)
+ use_norm = numerics.verify_tensor_all_finite(use_norm,
+ "Found Inf or NaN global norm.")
with ops.name_scope(name, "clip_by_global_norm",
t_list + [clip_norm]) as name:
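A minimal sketch of the newly documented failure mode: with the added `verify_tensor_all_finite` check, a non-finite global norm now surfaces as an `InvalidArgumentError` at run time (TF 1.x graph mode assumed):

```python
import numpy as np
import tensorflow as tf

grads = [tf.constant([np.nan, 1.0])]
clipped, norm = tf.clip_by_global_norm(grads, clip_norm=1.0)
with tf.Session() as sess:
    try:
        sess.run(clipped)
    except tf.errors.InvalidArgumentError as err:
        print(err.message)  # includes "Found Inf or NaN global norm."
```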
diff --git a/tensorflow/python/ops/image_ops_test.py b/tensorflow/python/ops/image_ops_test.py
index 0e4193e23b..2c61bb232a 100644
--- a/tensorflow/python/ops/image_ops_test.py
+++ b/tensorflow/python/ops/image_ops_test.py
@@ -3658,6 +3658,41 @@ class NonMaxSuppressionTest(test_util.TensorFlowTestCase):
image_ops.non_max_suppression(boxes, scores, 3, [[0.5]])
+class NonMaxSuppressionPaddedTest(test_util.TensorFlowTestCase):
+
+ def testSelectFromThreeClusters(self):
+ boxes_np = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9],
+ [0, 10, 1, 11], [0, 10.1, 1, 11.1], [0, 100, 1, 101]]
+ scores_np = [0.9, 0.75, 0.6, 0.95, 0.5, 0.3]
+ max_output_size_np = 5
+ iou_threshold_np = 0.5
+ boxes = constant_op.constant(boxes_np)
+ scores = constant_op.constant(scores_np)
+ max_output_size = constant_op.constant(max_output_size_np)
+ iou_threshold = constant_op.constant(iou_threshold_np)
+ selected_indices_padded, num_valid_padded = \
+ image_ops.non_max_suppression_padded(
+ boxes,
+ scores,
+ max_output_size,
+ iou_threshold,
+ pad_to_max_output_size=True)
+ selected_indices, num_valid = image_ops.non_max_suppression_padded(
+ boxes,
+ scores,
+ max_output_size,
+ iou_threshold,
+ pad_to_max_output_size=False)
+ # The output shape of the padded operation must be fully defined.
+ self.assertTrue(selected_indices_padded.shape.is_fully_defined())
+ self.assertFalse(selected_indices.shape.is_fully_defined())
+ with self.test_session():
+ self.assertAllClose(selected_indices_padded.eval(), [3, 0, 5, 0, 0])
+ self.assertEqual(num_valid_padded.eval(), 3)
+ self.assertAllClose(selected_indices.eval(), [3, 0, 5])
+ self.assertEqual(num_valid.eval(), 3)
+
+
class VerifyCompatibleImageShapesTest(test_util.TensorFlowTestCase):
"""Tests utility function used by ssim() and psnr()."""
diff --git a/tensorflow/python/ops/losses/losses_impl.py b/tensorflow/python/ops/losses/losses_impl.py
index 51fb4cbac8..806539747e 100644
--- a/tensorflow/python/ops/losses/losses_impl.py
+++ b/tensorflow/python/ops/losses/losses_impl.py
@@ -193,7 +193,7 @@ def compute_weighted_loss(
gradient, you need to apply `tf.stop_gradient` to `weights` before
passing them to `compute_weighted_loss`.
- @compatbility(eager)
+ @compatibility(eager)
The `loss_collection` argument is ignored when executing eagerly. Consider
holding on to the return value or collecting losses via a `tf.keras.Model`.
@end_compatibility
@@ -266,7 +266,7 @@ def absolute_difference(
`labels` or if the shape of `weights` is invalid or if `labels`
or `predictions` is None.
- @compatbility(eager)
+ @compatibility(eager)
The `loss_collection` argument is ignored when executing eagerly. Consider
holding on to the return value or collecting losses via a `tf.keras.Model`.
@end_compatibility
@@ -317,7 +317,7 @@ def cosine_distance(
ValueError: If `predictions` shape doesn't match `labels` shape, or
`axis`, `labels`, `predictions` or `weights` is `None`.
- @compatbility(eager)
+ @compatibility(eager)
The `loss_collection` argument is ignored when executing eagerly. Consider
holding on to the return value or collecting losses via a `tf.keras.Model`.
@end_compatibility
@@ -369,7 +369,7 @@ def hinge_loss(labels, logits, weights=1.0, scope=None,
ValueError: If the shapes of `logits` and `labels` don't match or
if `labels` or `logits` is None.
- @compatbility(eager)
+ @compatibility(eager)
The `loss_collection` argument is ignored when executing eagerly. Consider
holding on to the return value or collecting losses via a `tf.keras.Model`.
@end_compatibility
@@ -437,7 +437,7 @@ def huber_loss(labels, predictions, weights=1.0, delta=1.0, scope=None,
if the shape of `weights` is invalid. Also if `labels` or
`predictions` is None.
- @compatbility(eager)
+ @compatibility(eager)
The `loss_collection` argument is ignored when executing eagerly. Consider
holding on to the return value or collecting losses via a `tf.keras.Model`.
@end_compatibility
@@ -503,7 +503,7 @@ def log_loss(labels, predictions, weights=1.0, epsilon=1e-7, scope=None,
if the shape of `weights` is invalid. Also if `labels` or `predictions`
is None.
- @compatbility(eager)
+ @compatibility(eager)
The `loss_collection` argument is ignored when executing eagerly. Consider
holding on to the return value or collecting losses via a `tf.keras.Model`.
@end_compatibility
@@ -571,7 +571,7 @@ def mean_pairwise_squared_error(
if the shape of `weights` is invalid. Also if `labels` or `predictions`
is None.
- @compatbility(eager)
+ @compatibility(eager)
The `loss_collection` argument is ignored when executing eagerly. Consider
holding on to the return value or collecting losses via a `tf.keras.Model`.
@end_compatibility
@@ -654,7 +654,7 @@ def mean_squared_error(
if the shape of `weights` is invalid. Also if `labels` or `predictions`
is None.
- @compatbility(eager)
+ @compatibility(eager)
The `loss_collection` argument is ignored when executing eagerly. Consider
holding on to the return value or collecting losses via a `tf.keras.Model`.
@end_compatibility
@@ -711,7 +711,7 @@ def sigmoid_cross_entropy(
`multi_class_labels` or if the shape of `weights` is invalid, or if
`weights` is None. Also if `multi_class_labels` or `logits` is None.
- @compatbility(eager)
+ @compatibility(eager)
The `loss_collection` argument is ignored when executing eagerly. Consider
holding on to the return value or collecting losses via a `tf.keras.Model`.
@end_compatibility
@@ -777,7 +777,7 @@ def softmax_cross_entropy(
or if the shape of `weights` is invalid or if `weights` is None. Also if
`onehot_labels` or `logits` is None.
- @compatbility(eager)
+ @compatibility(eager)
The `loss_collection` argument is ignored when executing eagerly. Consider
holding on to the return value or collecting losses via a `tf.keras.Model`.
@end_compatibility
@@ -894,7 +894,7 @@ def sparse_softmax_cross_entropy(
ValueError: If the shapes of `logits`, `labels`, and `weights` are
incompatible, or if any of them are None.
- @compatbility(eager)
+ @compatibility(eager)
The `loss_collection` argument is ignored when executing eagerly. Consider
holding on to the return value or collecting losses via a `tf.keras.Model`.
@end_compatibility
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index 81499bee56..c9da1a0bba 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -2130,7 +2130,8 @@ def add_n(inputs, name=None):
"""Adds all input tensors element-wise.
Args:
- inputs: A list of `Tensor` objects, each with same shape and type.
+ inputs: A list of `Tensor` or `IndexedSlices` objects, each with same shape
+ and type.
name: A name for the operation (optional).
Returns:
@@ -2141,17 +2142,21 @@ def add_n(inputs, name=None):
cannot be inferred.
"""
if not inputs or not isinstance(inputs, (list, tuple)):
- raise ValueError("inputs must be a list of at least one Tensor with the "
- "same dtype and shape")
+ raise ValueError("inputs must be a list of at least one"
+ "Tensor/IndexedSlices with the same dtype and shape")
inputs = ops.convert_n_to_tensor_or_indexed_slices(inputs)
- if not all(isinstance(x, ops.Tensor) for x in inputs):
- raise ValueError("inputs must be a list of at least one Tensor with the "
- "same dtype and shape")
+ if not all(isinstance(x, (ops.Tensor, ops.IndexedSlices)) for x in inputs):
+ raise ValueError("inputs must be a list of at least one"
+ "Tensor/IndexedSlices with the same dtype and shape")
if len(inputs) == 1:
+ if isinstance(inputs[0], ops.IndexedSlices):
+ values = inputs[0].values
+ else:
+ values = inputs[0]
if name:
- return array_ops.identity(inputs[0], name=name)
- return inputs[0]
+ return array_ops.identity(values, name=name)
+ return values
return gen_math_ops.add_n(inputs, name=name)
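A short sketch of the single-input `IndexedSlices` path added above, the one case the new branch handles directly (values are illustrative):

```python
import tensorflow as tf

slices = tf.IndexedSlices(
    values=tf.constant([[1.0, 2.0]]),
    indices=tf.constant([0]),
    dense_shape=tf.constant([2, 2]))
# With a single IndexedSlices input, add_n now returns its `values`
# tensor (wrapped in an identity op when `name` is given).
out = tf.add_n([slices])
```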
diff --git a/tensorflow/python/ops/metrics_impl.py b/tensorflow/python/ops/metrics_impl.py
index 3aedeb6acd..9461a01515 100644
--- a/tensorflow/python/ops/metrics_impl.py
+++ b/tensorflow/python/ops/metrics_impl.py
@@ -34,7 +34,7 @@ from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import weights_broadcast_ops
from tensorflow.python.platform import tf_logging as logging
-from tensorflow.python.training import distribute as distribute_lib
+from tensorflow.python.training import distribution_strategy_context
from tensorflow.python.util.deprecation import deprecated
from tensorflow.python.util.tf_export import tf_export
@@ -57,7 +57,8 @@ def metric_variable(shape, dtype, validate_shape=True, name=None):
Furthermore, the final answer should be computed once instead of
in every replica/tower. Both of these are accomplished by
running the computation of the final result value inside
- `tf.contrib.distribute.get_tower_context().merge_call(fn)`.
+ `tf.contrib.distribution_strategy_context.get_tower_context(
+ ).merge_call(fn)`.
Inside the `merge_call()`, ops are only added to the graph once
and access to a tower-local variable in a computation returns
the sum across all replicas/towers.
@@ -373,7 +374,7 @@ def mean(values,
ops.add_to_collections(metrics_collections, mean_t)
return mean_t
- mean_t = distribute_lib.get_tower_context().merge_call(
+ mean_t = distribution_strategy_context.get_tower_context().merge_call(
aggregate_across_towers, total, count)
update_op = _safe_div(update_total_op, update_count_op, 'update_op')
@@ -618,7 +619,7 @@ def _aggregate_variable(v, collections):
ops.add_to_collections(collections, value)
return value
- return distribute_lib.get_tower_context().merge_call(f, v)
+ return distribution_strategy_context.get_tower_context().merge_call(f, v)
@tf_export('metrics.auc')
@@ -813,7 +814,7 @@ def auc(labels,
ops.add_to_collections(metrics_collections, auc_value)
return auc_value
- auc_value = distribute_lib.get_tower_context().merge_call(
+ auc_value = distribution_strategy_context.get_tower_context().merge_call(
aggregate_auc, values)
update_op = compute_auc(update_ops['tp'], update_ops['fn'],
update_ops['tn'], update_ops['fp'], 'update_op')
@@ -1053,8 +1054,8 @@ def mean_per_class_accuracy(labels,
ops.add_to_collections(metrics_collections, mean_accuracy_v)
return mean_accuracy_v
- mean_accuracy_v = distribute_lib.get_tower_context().merge_call(
- aggregate_mean_accuracy, count, total)
+ mean_accuracy_v = distribution_strategy_context.get_tower_context(
+ ).merge_call(aggregate_mean_accuracy, count, total)
update_op = _safe_div(update_count_op, update_total_op, name='update_op')
if updates_collections:
@@ -1160,7 +1161,7 @@ def mean_iou(labels,
ops.add_to_collections(metrics_collections, mean_iou_v)
return mean_iou_v
- mean_iou_v = distribute_lib.get_tower_context().merge_call(
+ mean_iou_v = distribution_strategy_context.get_tower_context().merge_call(
mean_iou_across_towers, total_cm)
if updates_collections:
@@ -1376,7 +1377,7 @@ def mean_tensor(values,
ops.add_to_collections(metrics_collections, mean_t)
return mean_t
- mean_t = distribute_lib.get_tower_context().merge_call(
+ mean_t = distribution_strategy_context.get_tower_context().merge_call(
aggregate_across_towers, total, count)
update_op = _safe_div(update_total_op, update_count_op, 'update_op')
@@ -2008,7 +2009,7 @@ def precision(labels,
ops.add_to_collections(metrics_collections, p)
return p
- p = distribute_lib.get_tower_context().merge_call(
+ p = distribution_strategy_context.get_tower_context().merge_call(
once_across_towers, true_p, false_p)
update_op = compute_precision(true_positives_update_op,
@@ -2092,7 +2093,7 @@ def precision_at_thresholds(labels,
ops.add_to_collections(metrics_collections, prec)
return prec
- prec = distribute_lib.get_tower_context().merge_call(
+ prec = distribution_strategy_context.get_tower_context().merge_call(
precision_across_towers, values)
update_op = compute_precision(update_ops['tp'], update_ops['fp'],
@@ -2188,7 +2189,7 @@ def recall(labels,
ops.add_to_collections(metrics_collections, rec)
return rec
- rec = distribute_lib.get_tower_context().merge_call(
+ rec = distribution_strategy_context.get_tower_context().merge_call(
once_across_towers, true_p, false_n)
update_op = compute_recall(true_positives_update_op,
@@ -2627,7 +2628,7 @@ def recall_at_top_k(labels,
ops.add_to_collections(metrics_collections, metric)
return metric
- metric = distribute_lib.get_tower_context().merge_call(
+ metric = distribution_strategy_context.get_tower_context().merge_call(
aggregate_across_towers, tp, fn)
update = math_ops.div(
@@ -2708,7 +2709,7 @@ def recall_at_thresholds(labels,
ops.add_to_collections(metrics_collections, rec)
return rec
- rec = distribute_lib.get_tower_context().merge_call(
+ rec = distribution_strategy_context.get_tower_context().merge_call(
recall_across_towers, values)
update_op = compute_recall(update_ops['tp'], update_ops['fn'], 'update_op')
@@ -2783,7 +2784,7 @@ def root_mean_squared_error(labels,
ops.add_to_collections(metrics_collections, rmse)
return rmse
- rmse = distribute_lib.get_tower_context().merge_call(
+ rmse = distribution_strategy_context.get_tower_context().merge_call(
once_across_towers, mse)
update_rmse_op = math_ops.sqrt(update_mse_op)
@@ -2886,7 +2887,7 @@ def sensitivity_at_specificity(labels,
ops.add_to_collections(metrics_collections, sensitivity)
return sensitivity
- sensitivity = distribute_lib.get_tower_context().merge_call(
+ sensitivity = distribution_strategy_context.get_tower_context().merge_call(
aggregate_across_towers, values)
update_op = compute_sensitivity_at_specificity(
@@ -3162,8 +3163,8 @@ def _streaming_sparse_average_precision_at_top_k(labels,
ops.add_to_collections(metrics_collections, mean_average_precision)
return mean_average_precision
- mean_average_precision = distribute_lib.get_tower_context().merge_call(
- aggregate_across_towers, total_var, max_var)
+ mean_average_precision = distribution_strategy_context.get_tower_context(
+ ).merge_call(aggregate_across_towers, total_var, max_var)
update = _safe_scalar_div(total_update, max_update, name=scope)
if updates_collections:
@@ -3448,7 +3449,7 @@ def precision_at_top_k(labels,
ops.add_to_collections(metrics_collections, metric)
return metric
- metric = distribute_lib.get_tower_context().merge_call(
+ metric = distribution_strategy_context.get_tower_context().merge_call(
aggregate_across_towers, tp, fp)
update = math_ops.div(
@@ -3687,7 +3688,7 @@ def specificity_at_sensitivity(labels,
ops.add_to_collections(metrics_collections, specificity)
return specificity
- specificity = distribute_lib.get_tower_context().merge_call(
+ specificity = distribution_strategy_context.get_tower_context().merge_call(
aggregate_across_towers, values)
update_op = compute_specificity_at_sensitivity(
diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py
index f481726d54..85a6a2233c 100644
--- a/tensorflow/python/ops/rnn_cell_impl.py
+++ b/tensorflow/python/ops/rnn_cell_impl.py
@@ -193,6 +193,13 @@ class RNNCell(base_layer.Layer):
for each `s` in `self.batch_size`.
"""
+ def __init__(self, trainable=True, name=None, dtype=None, **kwargs):
+ super(RNNCell, self).__init__(
+ trainable=trainable, name=name, dtype=dtype, **kwargs)
+ # Attribute that indicates whether the cell is a TF RNN cell, due to the
+ # slight difference between TF and Keras RNN cells.
+ self._is_tf_rnn_cell = True
+
def __call__(self, inputs, state, scope=None):
"""Run this RNN cell on inputs, starting from the given state.
@@ -524,8 +531,8 @@ class GRUCell(LayerRNNCell):
def get_config(self):
config = {
"num_units": self._num_units,
- "initializer": initializers.serialize(self._initializer),
"kernel_initializer": initializers.serialize(self._kernel_initializer),
+ "bias_initializer": initializers.serialize(self._bias_initializer),
"activation": activations.serialize(self._activation),
"reuse": self._reuse,
}
diff --git a/tensorflow/python/ops/state_ops.py b/tensorflow/python/ops/state_ops.py
index 35fc1226ec..d556d11a1b 100644
--- a/tensorflow/python/ops/state_ops.py
+++ b/tensorflow/python/ops/state_ops.py
@@ -470,3 +470,57 @@ def scatter_nd_add(ref, indices, updates, use_locking=False, name=None):
return ref._lazy_read(gen_state_ops.resource_scatter_nd_add( # pylint: disable=protected-access
ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype),
name=name))
+
+
+@tf_export("scatter_sub")
+def scatter_sub(ref, indices, updates, use_locking=False, name=None):
+ r"""Subtracts sparse updates to a variable reference.
+
+ ```python
+ # Scalar indices
+ ref[indices, ...] -= updates[...]
+
+ # Vector indices (for each i)
+ ref[indices[i], ...] -= updates[i, ...]
+
+ # High rank indices (for each i, ..., j)
+ ref[indices[i, ..., j], ...] -= updates[i, ..., j, ...]
+ ```
+
+ This operation outputs `ref` after the update is done.
+ This makes it easier to chain operations that need to use the updated value.
+
+ Duplicate entries are handled correctly: if multiple `indices` reference
+ the same location, their (negated) contributions add.
+
+ Requires `updates.shape = indices.shape + ref.shape[1:]` or
+ `updates.shape = []`.
+
+ <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
+ <img style="width:100%"
+ src="https://www.tensorflow.org/images/ScatterSub.png" alt>
+ </div>
+
+ Args:
+ ref: A mutable `Tensor`. Must be one of the following types: `float32`,
+ `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`,
+ `qint8`, `quint8`, `qint32`, `bfloat16`, `uint16`, `complex128`, `half`,
+ `uint32`, `uint64`. Should be from a `Variable` node.
+ indices: A `Tensor`. Must be one of the following types: `int32`, `int64`.
+ A tensor of indices into the first dimension of `ref`.
+ updates: A `Tensor`. Must have the same type as `ref`.
+ A tensor of updated values to subtract from `ref`.
+ use_locking: An optional `bool`. Defaults to `False`.
+ If True, the subtraction will be protected by a lock;
+ otherwise the behavior is undefined, but may exhibit less contention.
+ name: A name for the operation (optional).
+
+ Returns:
+ A mutable `Tensor`. Has the same type as `ref`.
+ """
+ if ref.dtype._is_ref_dtype:
+ return gen_state_ops.scatter_sub(ref, indices, updates,
+ use_locking=use_locking, name=name)
+ return ref._lazy_read(gen_resource_variable_ops.resource_scatter_sub( # pylint: disable=protected-access
+ ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype),
+ name=name))
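An eager-mode sketch of the new resource-variable path, mirroring `testScatterSubStateOps` above (assumes TF 1.x with eager execution enabled; `tf.contrib.eager.Variable` creates a `ResourceVariable`):

```python
import tensorflow as tf

tf.enable_eager_execution()

v = tf.contrib.eager.Variable([1.0, 2.0])  # a ResourceVariable
tf.scatter_sub(v, indices=[1], updates=[3.0])
print(v.numpy())  # [ 1. -1.]
```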
diff --git a/tensorflow/python/ops/summary_op_util.py b/tensorflow/python/ops/summary_op_util.py
index a793f634bd..b382c3b7ce 100644
--- a/tensorflow/python/ops/summary_op_util.py
+++ b/tensorflow/python/ops/summary_op_util.py
@@ -23,7 +23,7 @@ import re
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging
-from tensorflow.python.training import distribute
+from tensorflow.python.training import distribution_strategy_context
def collect(val, collections, default_collections):
@@ -49,7 +49,7 @@ def skip_summary():
# TODO(priyag): Add a new optional argument that will provide multiple
# alternatives to override default behavior. (e.g. run on last tower,
# compute sum or mean across towers).
- tower_context = distribute.get_tower_context()
+ tower_context = distribution_strategy_context.get_tower_context()
return tower_context and tower_context.tower_id > 0
diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py
index 464c1167d9..402ab2dd9d 100644
--- a/tensorflow/python/ops/variables.py
+++ b/tensorflow/python/ops/variables.py
@@ -1917,15 +1917,10 @@ class PartitionedVariable(object):
def as_tensor(self):
"""Returns the overall concatenated value as a `Tensor`.
- The returned tensor will not inherit the control dependencies from the scope
- where the value is used, which is similar to getting the value of
- `Variable`.
-
Returns:
`Tensor` containing the concatenated value.
"""
- with ops.control_dependencies(None):
- return self._concat()
+ return self._concat()
@staticmethod
def _TensorConversionFunction(v, dtype=None, name=None, as_ref=False):
diff --git a/tensorflow/python/summary/summary.py b/tensorflow/python/summary/summary.py
index f3a6d47500..980320cc66 100644
--- a/tensorflow/python/summary/summary.py
+++ b/tensorflow/python/summary/summary.py
@@ -268,7 +268,7 @@ def merge(inputs, collections=None, name=None):
@compatibility(eager)
Not compatible with eager execution. To write TensorBoard
summaries under eager execution, use `tf.contrib.summary` instead.
- @end_compatbility
+ @end_compatibility
"""
# pylint: enable=line-too-long
if _context.executing_eagerly():
@@ -304,7 +304,7 @@ def merge_all(key=_ops.GraphKeys.SUMMARIES, scope=None, name=None):
@compatibility(eager)
Not compatible with eager execution. To write TensorBoard
summaries under eager execution, use `tf.contrib.summary` instead.
- @end_compatbility
+ @end_compatibility
"""
if _context.executing_eagerly():
raise RuntimeError(
@@ -336,7 +336,7 @@ def get_summary_description(node_def):
@compatibility(eager)
Not compatible with eager execution. To write TensorBoard
summaries under eager execution, use `tf.contrib.summary` instead.
- @end_compatbility
+ @end_compatibility
"""
if node_def.op != 'TensorSummary':
diff --git a/tensorflow/python/summary/writer/writer.py b/tensorflow/python/summary/writer/writer.py
index 861a3e920d..16b8626476 100644
--- a/tensorflow/python/summary/writer/writer.py
+++ b/tensorflow/python/summary/writer/writer.py
@@ -352,7 +352,7 @@ class FileWriter(SummaryToEventTransformer):
@compatibility(eager)
`FileWriter` is not compatible with eager execution. To write TensorBoard
summaries under eager execution, use `tf.contrib.summary` instead.
- @end_compatbility
+ @end_compatibility
"""
if context.executing_eagerly():
raise RuntimeError(
diff --git a/tensorflow/python/tools/api/generator/api_init_files.bzl b/tensorflow/python/tools/api/generator/api_init_files.bzl
index 64f0469482..7001e566ce 100644
--- a/tensorflow/python/tools/api/generator/api_init_files.bzl
+++ b/tensorflow/python/tools/api/generator/api_init_files.bzl
@@ -25,7 +25,6 @@ TENSORFLOW_API_INIT_FILES = [
"keras/applications/inception_resnet_v2/__init__.py",
"keras/applications/inception_v3/__init__.py",
"keras/applications/mobilenet/__init__.py",
- "keras/applications/mobilenet_v2/__init__.py",
"keras/applications/nasnet/__init__.py",
"keras/applications/resnet50/__init__.py",
"keras/applications/vgg16/__init__.py",
diff --git a/tensorflow/python/tools/api/generator/api_init_files_v1.bzl b/tensorflow/python/tools/api/generator/api_init_files_v1.bzl
index bc2f3516d1..73d11199d9 100644
--- a/tensorflow/python/tools/api/generator/api_init_files_v1.bzl
+++ b/tensorflow/python/tools/api/generator/api_init_files_v1.bzl
@@ -25,7 +25,6 @@ TENSORFLOW_API_INIT_FILES_V1 = [
"keras/applications/inception_resnet_v2/__init__.py",
"keras/applications/inception_v3/__init__.py",
"keras/applications/mobilenet/__init__.py",
- "keras/applications/mobilenet_v2/__init__.py",
"keras/applications/nasnet/__init__.py",
"keras/applications/resnet50/__init__.py",
"keras/applications/vgg16/__init__.py",
diff --git a/tensorflow/python/tools/freeze_graph.py b/tensorflow/python/tools/freeze_graph.py
index 130fe70beb..acf070075e 100644
--- a/tensorflow/python/tools/freeze_graph.py
+++ b/tensorflow/python/tools/freeze_graph.py
@@ -59,6 +59,21 @@ from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import saver as saver_lib
+def _has_variables(sess):
+ """Determines whether the graph has any variables.
+
+ Args:
+ sess: TensorFlow Session.
+
+ Returns:
+ Bool. True if the graph contains at least one variable op.
+ """
+ for op in sess.graph.get_operations():
+ if op.type.startswith("Variable") or op.type.endswith("VariableOp"):
+ return True
+ return False
+
+
def freeze_graph_with_def_protos(input_graph_def,
input_saver_def,
input_checkpoint,
@@ -152,6 +167,11 @@ def freeze_graph_with_def_protos(input_graph_def,
"from checkpoint files. Please pass in a SavedModel using "
"the flag --input_saved_model_dir.")
return -1
+ # Models that have been frozen previously do not contain Variables.
+ elif not _has_variables(sess):
+ print("No variables were found in this model. It is likely the model "
+ "was frozen previously. You cannot freeze a graph twice.")
+ return 0
else:
raise e
diff --git a/tensorflow/python/training/checkpoint_management.py b/tensorflow/python/training/checkpoint_management.py
index aaddc015ed..85f2904318 100644
--- a/tensorflow/python/training/checkpoint_management.py
+++ b/tensorflow/python/training/checkpoint_management.py
@@ -19,16 +19,23 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import collections
import os.path
import re
+import time
from google.protobuf import text_format
from tensorflow.core.protobuf import saver_pb2
+from tensorflow.python.eager import context
from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
from tensorflow.python.lib.io import file_io
+from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training import training_util
from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
+from tensorflow.python.util import compat
from tensorflow.python.util.tf_export import tf_export
@@ -51,7 +58,9 @@ def _GetCheckpointFilename(save_dir, latest_filename):
@tf_export("train.generate_checkpoint_state_proto")
def generate_checkpoint_state_proto(save_dir,
model_checkpoint_path,
- all_model_checkpoint_paths=None):
+ all_model_checkpoint_paths=None,
+ all_model_checkpoint_timestamps=None,
+ last_preserved_timestamp=None):
"""Generates a checkpoint state proto.
Args:
@@ -61,11 +70,20 @@ def generate_checkpoint_state_proto(save_dir,
checkpoints, sorted from oldest to newest. If this is a non-empty list,
the last element must be equal to model_checkpoint_path. These paths
are also saved in the CheckpointState proto.
-
+ all_model_checkpoint_timestamps: A list of floats, indicating the number of
+ seconds since the Epoch when each checkpoint was generated.
+ last_preserved_timestamp: A float, indicating the number of seconds since
+ the Epoch when the last preserved checkpoint was written, e.g. due to a
+ `keep_checkpoint_every_n_hours` parameter (see
+ `tf.contrib.checkpoint.CheckpointManager` for an implementation).
Returns:
CheckpointState proto with model_checkpoint_path and
all_model_checkpoint_paths updated to either absolute paths or
relative paths to the current save_dir.
+
+ Raises:
+ ValueError: If `all_model_checkpoint_timestamps` was provided but its length
+ does not match `all_model_checkpoint_paths`.
"""
if all_model_checkpoint_paths is None:
all_model_checkpoint_paths = []
@@ -76,6 +94,14 @@ def generate_checkpoint_state_proto(save_dir,
model_checkpoint_path)
all_model_checkpoint_paths.append(model_checkpoint_path)
+ if (all_model_checkpoint_timestamps
+ and (len(all_model_checkpoint_timestamps)
+ != len(all_model_checkpoint_paths))):
+ raise ValueError(
+ ("Checkpoint timestamps, if provided, must match checkpoint paths (got "
+ "paths %s and timestamps %s)")
+ % (all_model_checkpoint_paths, all_model_checkpoint_timestamps))
+
# Relative paths need to be rewritten to be relative to the "save_dir"
# if model_checkpoint_path already contains "save_dir".
if not os.path.isabs(save_dir):
@@ -88,7 +114,9 @@ def generate_checkpoint_state_proto(save_dir,
coord_checkpoint_proto = CheckpointState(
model_checkpoint_path=model_checkpoint_path,
- all_model_checkpoint_paths=all_model_checkpoint_paths)
+ all_model_checkpoint_paths=all_model_checkpoint_paths,
+ all_model_checkpoint_timestamps=all_model_checkpoint_timestamps,
+ last_preserved_timestamp=last_preserved_timestamp)
return coord_checkpoint_proto
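A small sketch of the extended signature (paths and timestamps here are hypothetical; the two lists must have matching lengths or the new check raises `ValueError`):

```python
import tensorflow as tf

proto = tf.train.generate_checkpoint_state_proto(
    save_dir="/tmp/model",
    model_checkpoint_path="/tmp/model/ckpt-2",
    all_model_checkpoint_paths=["/tmp/model/ckpt-1", "/tmp/model/ckpt-2"],
    all_model_checkpoint_timestamps=[1000.0, 2000.0],  # seconds since Epoch
    last_preserved_timestamp=500.0)
```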
@@ -97,7 +125,9 @@ def generate_checkpoint_state_proto(save_dir,
def update_checkpoint_state(save_dir,
model_checkpoint_path,
all_model_checkpoint_paths=None,
- latest_filename=None):
+ latest_filename=None,
+ all_model_checkpoint_timestamps=None,
+ last_preserved_timestamp=None):
"""Updates the content of the 'checkpoint' file.
This updates the checkpoint file containing a CheckpointState
@@ -112,7 +142,13 @@ def update_checkpoint_state(save_dir,
are also saved in the CheckpointState proto.
latest_filename: Optional name of the checkpoint file. Default to
'checkpoint'.
-
+ all_model_checkpoint_timestamps: Optional list of timestamps (floats,
+ seconds since the Epoch) indicating when the checkpoints in
+ `all_model_checkpoint_paths` were created.
+ last_preserved_timestamp: A float, indicating the number of seconds since
+ the Epoch when the last preserved checkpoint was written, e.g. due to a
+ `keep_checkpoint_every_n_hours` parameter (see
+ `tf.contrib.checkpoint.CheckpointManager` for an implementation).
Raises:
RuntimeError: If any of the model checkpoint paths conflict with the file
containing CheckpointSate.
@@ -122,14 +158,18 @@ def update_checkpoint_state(save_dir,
model_checkpoint_path=model_checkpoint_path,
all_model_checkpoint_paths=all_model_checkpoint_paths,
latest_filename=latest_filename,
- save_relative_paths=False)
+ save_relative_paths=False,
+ all_model_checkpoint_timestamps=all_model_checkpoint_timestamps,
+ last_preserved_timestamp=last_preserved_timestamp)
def update_checkpoint_state_internal(save_dir,
model_checkpoint_path,
all_model_checkpoint_paths=None,
latest_filename=None,
- save_relative_paths=False):
+ save_relative_paths=False,
+ all_model_checkpoint_timestamps=None,
+ last_preserved_timestamp=None):
"""Updates the content of the 'checkpoint' file.
This updates the checkpoint file containing a CheckpointState
@@ -146,6 +186,13 @@ def update_checkpoint_state_internal(save_dir,
'checkpoint'.
save_relative_paths: If `True`, will write relative paths to the checkpoint
state file.
+ all_model_checkpoint_timestamps: Optional list of timestamps (floats,
+ seconds since the Epoch) indicating when the checkpoints in
+ `all_model_checkpoint_paths` were created.
+ last_preserved_timestamp: A float, indicating the number of seconds since
+ the Epoch when the last preserved checkpoint was written, e.g. due to a
+ `keep_checkpoint_every_n_hours` parameter (see
+ `tf.contrib.checkpoint.CheckpointManager` for an implementation).
Raises:
RuntimeError: If any of the model checkpoint paths conflict with the file
@@ -168,12 +215,16 @@ def update_checkpoint_state_internal(save_dir,
ckpt = generate_checkpoint_state_proto(
save_dir,
rel_model_checkpoint_path,
- all_model_checkpoint_paths=rel_all_model_checkpoint_paths)
+ all_model_checkpoint_paths=rel_all_model_checkpoint_paths,
+ all_model_checkpoint_timestamps=all_model_checkpoint_timestamps,
+ last_preserved_timestamp=last_preserved_timestamp)
else:
ckpt = generate_checkpoint_state_proto(
save_dir,
model_checkpoint_path,
- all_model_checkpoint_paths=all_model_checkpoint_paths)
+ all_model_checkpoint_paths=all_model_checkpoint_paths,
+ all_model_checkpoint_timestamps=all_model_checkpoint_timestamps,
+ last_preserved_timestamp=last_preserved_timestamp)
if coord_checkpoint_filename == ckpt.model_checkpoint_path:
raise RuntimeError("Save path '%s' conflicts with path used for "
@@ -404,3 +455,227 @@ def meta_graph_filename(checkpoint_filename, meta_graph_suffix="meta"):
basename = re.sub(r"-[\d\?]+-of-\d+$", "", checkpoint_filename)
suffixed_filename = ".".join([basename, meta_graph_suffix])
return suffixed_filename
+
+
+# TODO(allenl): Allow tf.keras.Model instances in the constructor directly?
+class CheckpointManager(object):
+ """Deletes old checkpoints.
+
+ Example usage:
+ ```python
+ import tensorflow as tf
+ checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
+ manager = tf.contrib.checkpoint.CheckpointManager(
+ checkpoint, directory="/tmp/model", max_to_keep=5)
+ status = checkpoint.restore(manager.latest_checkpoint)
+ while True:
+ # train
+ manager.save()
+ ```
+
+ `CheckpointManager` preserves its own state across instantiations (see the
+ `__init__` documentation for details). Only one should be active in a
+ particular directory at a time.
+ """
+
+ def __init__(self, checkpoint, directory,
+ max_to_keep, keep_checkpoint_every_n_hours=None):
+ """Configure a `CheckpointManager` for use in `directory`.
+
+ If a `CheckpointManager` was previously used in `directory`, its
+ state will be restored. This includes the list of managed checkpoints and
+ the timestamp bookkeeping necessary to support
+ `keep_checkpoint_every_n_hours`. The behavior of the new `CheckpointManager`
+ will be the same as the previous `CheckpointManager`, including cleaning up
+ existing checkpoints if appropriate.
+
+ Checkpoints are only considered for deletion just after a new checkpoint has
+ been added. At that point, `max_to_keep` checkpoints will remain in an
+ "active set". Once a checkpoint is preserved by
+ `keep_checkpoint_every_n_hours` it will not be deleted by this
+ `CheckpointManager` or any future `CheckpointManager` instantiated in
+ `directory` (regardless of the new setting of
+ `keep_checkpoint_every_n_hours`). The `max_to_keep` checkpoints in the
+ active set may be deleted by this `CheckpointManager` or a future
+ `CheckpointManager` instantiated in `directory` (subject to its
+ `max_to_keep` and `keep_checkpoint_every_n_hours` settings).
+
+ Args:
+ checkpoint: The `tf.train.Checkpoint` instance to save and manage
+ checkpoints for.
+ directory: The path to a directory in which to write checkpoints. A
+ special file named "checkpoint" is also written to this directory (in a
+ human-readable text format) which contains the state of the
+ `CheckpointManager`.
+ max_to_keep: An integer, the number of checkpoints to keep. Unless
+ preserved by `keep_checkpoint_every_n_hours`, checkpoints will be
+ deleted from the active set, oldest first, until only `max_to_keep`
+ checkpoints remain.
+ keep_checkpoint_every_n_hours: Upon removal from the active set, a
+ checkpoint will be preserved if it has been at least
+ `keep_checkpoint_every_n_hours` since the last preserved checkpoint. The
+ default setting of `None` does not preserve any checkpoints in this way.
+
+ Raises:
+ ValueError: If `max_to_keep` is not a positive integer.
+ """
+ self._checkpoint = checkpoint
+ self._save_counter_assign = None
+ if not max_to_keep or max_to_keep < 0:
+ raise ValueError(
+ "Expected a positive integer for `max_to_keep`, got %s."
+ % (max_to_keep,))
+ self._max_to_keep = max_to_keep
+ self._keep_checkpoint_every_n_hours = keep_checkpoint_every_n_hours
+ self._directory = directory
+ self._checkpoint_prefix = os.path.join(directory, "ckpt")
+ recovered_state = get_checkpoint_state(directory)
+ current_clock = time.time()
+ self._maybe_delete = collections.OrderedDict()
+ if recovered_state is None:
+ self._latest_checkpoint = None
+ self._last_preserved_timestamp = current_clock
+ else:
+ self._latest_checkpoint = recovered_state.model_checkpoint_path
+ self._last_preserved_timestamp = recovered_state.last_preserved_timestamp
+ if current_clock < self._last_preserved_timestamp:
+ # Time seems to have reversed itself. In addition to this warning, we'll
+ # min() saved checkpoint timestamps with the current time to ensure that
+ # old checkpoints don't get deleted accidentally.
+ logging.warning(
+ ("time.time() returned a value %f seconds behind the last "
+ "preserved checkpoint timestamp.")
+ % (self._last_preserved_timestamp - current_clock,))
+ self._last_preserved_timestamp = current_clock
+ all_timestamps = recovered_state.all_model_checkpoint_timestamps
+ all_paths = recovered_state.all_model_checkpoint_paths
+ del recovered_state # Uses modified values from now on
+ if not all_timestamps:
+ all_timestamps = [self._last_preserved_timestamp] * len(all_paths)
+
+ for filename, timestamp in zip(all_paths, all_timestamps):
+ timestamp = min(timestamp, current_clock)
+ if timestamp > self._last_preserved_timestamp:
+ self._maybe_delete[filename] = timestamp
+
+ @property
+ def latest_checkpoint(self):
+ """The prefix of the most recent checkpoint in `directory`.
+
+ Equivalent to `tf.train.latest_checkpoint(directory)` where `directory` is
+ the constructor argument to `CheckpointManager`.
+
+ Suitable for passing to `tf.train.Checkpoint.restore` to resume training.
+
+ Returns:
+ The checkpoint prefix. If there are no checkpoints, returns `None`.
+ """
+ return self._latest_checkpoint
+
+ @property
+ def checkpoints(self):
+ """A list of managed checkpoints.
+
+ Note that checkpoints saved due to `keep_checkpoint_every_n_hours` will not
+ show up in this list (to avoid ever-growing filename lists).
+
+ Returns:
+ A list of filenames, sorted from oldest to newest.
+ """
+ return list(self._maybe_delete.keys())
+
+ def _sweep(self):
+ """Deletes or preserves managed checkpoints."""
+ while len(self._maybe_delete) > self._max_to_keep:
+ filename, timestamp = self._maybe_delete.popitem(last=False)
+ # Even if we're keeping this checkpoint due to
+ # keep_checkpoint_every_n_hours, we won't reference it to avoid
+ # infinitely-growing CheckpointState protos.
+ if (self._keep_checkpoint_every_n_hours
+ and (timestamp - self._keep_checkpoint_every_n_hours * 3600.
+ >= self._last_preserved_timestamp)):
+ self._last_preserved_timestamp = timestamp
+ continue
+ remove_checkpoint(filename)
+
+ def _record_state(self):
+ """Saves the `CheckpointManager`'s state in `directory`."""
+ filenames, timestamps = zip(*self._maybe_delete.items())
+ update_checkpoint_state_internal(
+ self._directory,
+ model_checkpoint_path=self.latest_checkpoint,
+ all_model_checkpoint_paths=filenames,
+ all_model_checkpoint_timestamps=timestamps,
+ last_preserved_timestamp=self._last_preserved_timestamp,
+ save_relative_paths=True)
+
+ @property
+ def _prefix(self):
+ """A common prefix for all checkpoints saved with this manager.
+
+ For example, if `directory` (a constructor argument) were `"/tmp/tf-model"`,
+ `prefix` would be `"/tmp/tf-model/ckpt"` and checkpoints would generally be
+ numbered `"/tmp/tf-model/ckpt-1"`, `"/tmp/tf-model/ckpt-2"`, and so on. Each
+ checkpoint has several associated files
+ (e.g. `"/tmp/tf-model/ckpt-2.index"`).
+
+ Returns:
+ A string prefix.
+ """
+ return self._checkpoint_prefix
+
+ def save(self, session=None, checkpoint_number=None):
+ """Creates a new checkpoint and manages it.
+
+ Args:
+ session: The session to evaluate variables in. Ignored when executing
+ eagerly. If not provided when graph building, the default session is
+ used.
+ checkpoint_number: An optional integer, or an integer-dtype `Variable` or
+ `Tensor`, used to number the checkpoint. If `None` (default),
+ checkpoints are numbered using `checkpoint.save_counter`. Even if
+ `checkpoint_number` is provided, `save_counter` is still incremented. A
+ user-provided `checkpoint_number` is not incremented even if it is a
+ `Variable`.
+
+ Returns:
+ The path to the new checkpoint. It is also recorded in the `checkpoints`
+ and `latest_checkpoint` properties.
+ """
+ # Save counter logic duplicated from tf.train.Checkpoint, soon to diverge
+ # slightly with a custom numbering option.
+ if context.executing_eagerly():
+ save_counter = self._checkpoint.save_counter
+ save_counter.assign_add(1)
+ else:
+ if session is None:
+ session = ops.get_default_session()
+
+ def _initializing_creator(next_creator, **kwargs):
+ """Initialize the save counter if it has been newly created."""
+ v = next_creator(**kwargs)
+ session.run(v.initializer)
+ return v
+
+ with variable_scope.variable_creator_scope(_initializing_creator):
+ save_counter = self._checkpoint.save_counter
+ if self._save_counter_assign is None:
+ self._save_counter_assign = save_counter.assign_add(1, read_value=False)
+ session.run(self._save_counter_assign)
+ if checkpoint_number is None:
+ checkpoint_number = save_counter
+ if not isinstance(checkpoint_number, compat.integral_types):
+ checkpoint_number = training_util.global_step(
+ sess=session, global_step_tensor=checkpoint_number)
+ prefix = "%s-%d" % (self._prefix, checkpoint_number)
+ save_path = self._checkpoint.write(prefix)
+ timestamp = time.time()
+ # If this is an overwritten checkpoint we were previously tracking, delete
+ # and reinsert it to make sure it goes to the end of the queue.
+ if save_path in self._maybe_delete:
+ del self._maybe_delete[save_path]
+ self._maybe_delete[save_path] = timestamp
+ self._latest_checkpoint = save_path
+ self._sweep()
+ self._record_state()
+ return save_path
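Putting the new manager together with custom numbering, a sketch along the lines of `testCustomNumbering` below (assumes eager execution and that the class is exported as `tf.contrib.checkpoint.CheckpointManager`, as the docstring suggests):

```python
import tensorflow as tf

tf.enable_eager_execution()

step = tf.contrib.eager.Variable(0, dtype=tf.int64)
checkpoint = tf.train.Checkpoint(step=step)
manager = tf.contrib.checkpoint.CheckpointManager(
    checkpoint, directory="/tmp/model", max_to_keep=2)
for _ in range(5):
    step.assign_add(1)
    manager.save(checkpoint_number=step)  # writes ckpt-1 ... ckpt-5
print(manager.checkpoints)  # only the two newest prefixes remain
```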
diff --git a/tensorflow/python/training/checkpoint_management_test.py b/tensorflow/python/training/checkpoint_management_test.py
index 4b31d0c613..1e2827d0a4 100644
--- a/tensorflow/python/training/checkpoint_management_test.py
+++ b/tensorflow/python/training/checkpoint_management_test.py
@@ -26,14 +26,18 @@ import tempfile
from google.protobuf import text_format
from tensorflow.core.protobuf import saver_pb2
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops as ops_lib
+from tensorflow.python.framework import test_util
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
+from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import saver as saver_module
from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
+from tensorflow.python.training.checkpointable import util
class LatestCheckpointWithRelativePaths(test.TestCase):
@@ -312,5 +316,202 @@ class SaverUtilsTest(test.TestCase):
self.assertFalse(checkpoint_management.checkpoint_exists(ckpt_prefix))
+class CheckpointManagerTest(test.TestCase):
+
+ @test_util.run_in_graph_and_eager_modes
+ def testDeletion(self):
+ checkpoint = util.Checkpoint()
+ manager = checkpoint_management.CheckpointManager(
+ checkpoint, self.get_temp_dir(), max_to_keep=3)
+ first_path = manager.save()
+ second_path = manager.save()
+ third_path = manager.save()
+ fourth_path = manager.save()
+ self.assertTrue(checkpoint_management.checkpoint_exists(fourth_path))
+ self.assertTrue(checkpoint_management.checkpoint_exists(third_path))
+ self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
+ self.assertFalse(checkpoint_management.checkpoint_exists(first_path))
+
+ @test_util.run_in_graph_and_eager_modes
+ @test.mock.patch.object(checkpoint_management, "time")
+ def testSaveRestoreState(self, mock_time):
+ directory = self.get_temp_dir()
+ mock_time.time.return_value = 3.
+ checkpoint = util.Checkpoint()
+ first_manager = checkpoint_management.CheckpointManager(
+ checkpoint, directory, max_to_keep=2)
+ first_time = 10000.
+ first_name = os.path.join(directory, "ckpt-1")
+ mock_time.time.return_value = first_time
+ first_manager.save()
+ state = checkpoint_management.get_checkpoint_state(directory)
+ self.assertEqual([first_time], state.all_model_checkpoint_timestamps)
+ self.assertEqual(3., state.last_preserved_timestamp)
+ second_time = first_time + 3610.
+ second_name = os.path.join(directory, "ckpt-2")
+ mock_time.time.return_value = second_time
+ first_manager.save()
+ state = checkpoint_management.get_checkpoint_state(directory)
+ self.assertEqual([first_time, second_time],
+ state.all_model_checkpoint_timestamps)
+ self.assertEqual(3., state.last_preserved_timestamp)
+ self.assertEqual([first_name, second_name], first_manager.checkpoints)
+ self.assertEqual(second_name, first_manager.latest_checkpoint)
+ del first_manager
+
+ second_manager = checkpoint_management.CheckpointManager(
+ checkpoint, directory,
+ max_to_keep=2, keep_checkpoint_every_n_hours=1.5)
+ self.assertEqual([first_name, second_name], second_manager.checkpoints)
+ self.assertEqual(second_name, second_manager.latest_checkpoint)
+ third_name = os.path.join(directory, "ckpt-3")
+ third_time = second_time + 3600. * 0.2
+ mock_time.time.return_value = third_time
+ second_manager.save()
+ self.assertTrue(checkpoint_management.checkpoint_exists(first_name))
+ self.assertTrue(checkpoint_management.checkpoint_exists(second_name))
+ self.assertEqual([second_name, third_name],
+ second_manager.checkpoints)
+ state = checkpoint_management.get_checkpoint_state(directory)
+ self.assertEqual(first_time, state.last_preserved_timestamp)
+ fourth_time = third_time + 3600. * 0.5
+ mock_time.time.return_value = fourth_time
+ fourth_name = os.path.join(directory, "ckpt-4")
+ second_manager.save()
+ self.assertTrue(checkpoint_management.checkpoint_exists(first_name))
+ self.assertFalse(checkpoint_management.checkpoint_exists(second_name))
+ self.assertEqual([third_name, fourth_name],
+ second_manager.checkpoints)
+ fifth_time = fourth_time + 3600. * 0.5
+ mock_time.time.return_value = fifth_time
+ fifth_name = os.path.join(directory, "ckpt-5")
+ second_manager.save()
+ self.assertEqual([fourth_name, fifth_name],
+ second_manager.checkpoints)
+ state = checkpoint_management.get_checkpoint_state(directory)
+ self.assertEqual(first_time, state.last_preserved_timestamp)
+ del second_manager
+ third_manager = checkpoint_management.CheckpointManager(
+ checkpoint, directory,
+ max_to_keep=2, keep_checkpoint_every_n_hours=1.5)
+ self.assertEqual(fifth_name, third_manager.latest_checkpoint)
+ mock_time.time.return_value += 10.
+ third_manager.save()
+ sixth_name = os.path.join(directory, "ckpt-6")
+ state = checkpoint_management.get_checkpoint_state(directory)
+ self.assertEqual(fourth_time, state.last_preserved_timestamp)
+ self.assertTrue(checkpoint_management.checkpoint_exists(first_name))
+ self.assertTrue(checkpoint_management.checkpoint_exists(fourth_name))
+ self.assertTrue(checkpoint_management.checkpoint_exists(fifth_name))
+ self.assertTrue(checkpoint_management.checkpoint_exists(sixth_name))
+ self.assertFalse(checkpoint_management.checkpoint_exists(second_name))
+ self.assertFalse(checkpoint_management.checkpoint_exists(third_name))
+ self.assertEqual([fifth_name, sixth_name],
+ third_manager.checkpoints)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testContinueFromUnmanaged(self):
+ directory = self.get_temp_dir()
+ prefix = os.path.join(directory, "unusual_prefix")
+ checkpoint = util.Checkpoint()
+ first_path = checkpoint.save(prefix)
+ second_path = checkpoint.save(prefix)
+ del checkpoint
+ checkpoint = util.Checkpoint()
+ manager = checkpoint_management.CheckpointManager(
+ checkpoint, directory, max_to_keep=2)
+ checkpoint.restore(manager.latest_checkpoint).run_restore_ops()
+ self.assertEqual(2, self.evaluate(checkpoint.save_counter))
+ third_path = manager.save()
+ self.assertEqual([third_path], manager.checkpoints)
+ fourth_path = manager.save()
+ self.assertEqual([third_path, fourth_path],
+ manager.checkpoints)
+ fifth_path = manager.save()
+ self.assertEqual([fourth_path, fifth_path],
+ manager.checkpoints)
+ self.assertTrue(checkpoint_management.checkpoint_exists(first_path))
+ self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
+ self.assertFalse(checkpoint_management.checkpoint_exists(third_path))
+ self.assertTrue(checkpoint_management.checkpoint_exists(fourth_path))
+ self.assertTrue(checkpoint_management.checkpoint_exists(fifth_path))
+
+ @test_util.run_in_graph_and_eager_modes
+ @test.mock.patch.object(checkpoint_management, "time")
+ def testClockReset(self, mock_time):
+ directory = self.get_temp_dir()
+ mock_time.time.return_value = 10000.
+ checkpoint = util.Checkpoint()
+ first_manager = checkpoint_management.CheckpointManager(
+ checkpoint, directory, max_to_keep=1, keep_checkpoint_every_n_hours=1.)
+ first_path = first_manager.save()
+ mock_time.time.return_value += 3600.
+ second_path = first_manager.save()
+ mock_time.time.return_value += 3600.
+ third_path = first_manager.save()
+ self.assertFalse(checkpoint_management.checkpoint_exists(first_path))
+ self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
+ self.assertTrue(checkpoint_management.checkpoint_exists(third_path))
+ self.assertEqual([third_path], first_manager.checkpoints)
+ state = checkpoint_management.get_checkpoint_state(directory)
+ self.assertEqual(13600., state.last_preserved_timestamp)
+ # Set the clock back in time
+ mock_time.time.return_value = 5000.
+ del first_manager
+ with test.mock.patch.object(logging, "warning") as mock_log:
+ second_manager = checkpoint_management.CheckpointManager(
+ checkpoint, directory, max_to_keep=1)
+ self.assertRegexpMatches(
+ str(mock_log.call_args),
+ "behind the last preserved checkpoint timestamp")
+ # We should err on the side of keeping checkpoints around when we're not
+ # sure whether they were preserved or not due to clock funkiness.
+ self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
+ # We know about the existing checkpoints, but they'll never be deleted and
+ # so won't go in the CheckpointState proto on save.
+ self.assertEqual(third_path, second_manager.latest_checkpoint)
+ self.assertEqual([], second_manager.checkpoints)
+ mock_time.time.return_value += 10.
+ fourth_path = second_manager.save()
+ self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
+ self.assertTrue(checkpoint_management.checkpoint_exists(third_path))
+ self.assertEqual(fourth_path, second_manager.latest_checkpoint)
+ self.assertEqual([fourth_path], second_manager.checkpoints)
+ mock_time.time.return_value += 10.
+ fifth_path = second_manager.save()
+ self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
+ self.assertTrue(checkpoint_management.checkpoint_exists(third_path))
+ self.assertEqual([fifth_path], second_manager.checkpoints)
+ state = checkpoint_management.get_checkpoint_state(directory)
+ self.assertEqual(5000., state.last_preserved_timestamp)
+ self.assertEqual([5020.],
+ state.all_model_checkpoint_timestamps)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testCustomNumbering(self):
+ directory = self.get_temp_dir()
+ step = variables.Variable(0, dtype=dtypes.int64)
+ checkpoint = util.Checkpoint(step=step)
+ manager = checkpoint_management.CheckpointManager(
+ checkpoint, directory, max_to_keep=2)
+ self.evaluate(step.initializer)
+ for i in range(5):
+ path = manager.save(checkpoint_number=step)
+ expected_suffix = "-%d" % (2 * i,)
+ if not path.endswith(expected_suffix):
+ self.fail("%s should have suffix %s" % (path, expected_suffix))
+ self.evaluate(step.assign_add(2))
+ self.assertEqual(5, self.evaluate(checkpoint.save_counter))
+ # Test regular integers
+ last_path = manager.save(checkpoint_number=32)
+ self.assertIn("-32", last_path)
+ self.assertEqual(last_path, manager.latest_checkpoint)
+ self.assertEqual(
+ last_path, checkpoint_management.latest_checkpoint(directory))
+ state = checkpoint_management.get_checkpoint_state(directory)
+ # Only the most recent two checkpoints are saved
+ self.assertEqual([path, last_path], state.all_model_checkpoint_paths)
+
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/training/checkpoint_state.proto b/tensorflow/python/training/checkpoint_state.proto
index 9172a5c331..704f7fdc88 100644
--- a/tensorflow/python/training/checkpoint_state.proto
+++ b/tensorflow/python/training/checkpoint_state.proto
@@ -4,8 +4,6 @@ package tensorflow;
option cc_enable_arenas = true;
// Protocol buffer representing the checkpoint state.
-//
-// TODO(touts): Add other attributes as needed.
message CheckpointState {
// Path to the most-recent model checkpoint.
string model_checkpoint_path = 1;
@@ -15,4 +13,10 @@ message CheckpointState {
// Note that the value of model_checkpoint_path should be the last item in
// this list.
repeated string all_model_checkpoint_paths = 2;
+ // Unix timestamps corresponding to all_model_checkpoint_paths, indicating
+ // when each checkpoint was created.
+ repeated double all_model_checkpoint_timestamps = 3;
+ // Unix timestamp indicating the creation time for the last preserved
+ // checkpoint.
+ double last_preserved_timestamp = 4;
}
diff --git a/tensorflow/python/training/checkpoint_utils.py b/tensorflow/python/training/checkpoint_utils.py
index 9b72b09f08..e6118177fd 100644
--- a/tensorflow/python/training/checkpoint_utils.py
+++ b/tensorflow/python/training/checkpoint_utils.py
@@ -29,7 +29,7 @@ from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import checkpoint_management
-from tensorflow.python.training import distribute as distribute_lib
+from tensorflow.python.training import distribution_strategy_context
from tensorflow.python.training import saver
from tensorflow.python.util.tf_export import tf_export
@@ -180,10 +180,10 @@ def init_from_checkpoint(ckpt_dir_or_file, assignment_map):
tf.errors.OpError: If missing checkpoints or tensors in checkpoints.
ValueError: If missing variables in current graph.
"""
- if distribute_lib.get_cross_tower_context():
+ if distribution_strategy_context.get_cross_tower_context():
_init_from_checkpoint(None, ckpt_dir_or_file, assignment_map)
else:
- distribute_lib.get_tower_context().merge_call(
+ distribution_strategy_context.get_tower_context().merge_call(
_init_from_checkpoint, ckpt_dir_or_file, assignment_map)
diff --git a/tensorflow/python/training/checkpointable/BUILD b/tensorflow/python/training/checkpointable/BUILD
index 8a289b31b5..6e9b8ff905 100644
--- a/tensorflow/python/training/checkpointable/BUILD
+++ b/tensorflow/python/training/checkpointable/BUILD
@@ -118,10 +118,7 @@ py_test(
name = "util_test",
srcs = ["util_test.py"],
srcs_version = "PY2AND3",
- tags = [
- "no_windows", # TODO: needs investigation on Windows
- "notsan", # b/74395663
- ],
+ tags = ["notsan"], # b/74395663
deps = [
":base",
":tracking",
diff --git a/tensorflow/python/training/checkpointable/base.py b/tensorflow/python/training/checkpointable/base.py
index 66837ee52f..390434c0a2 100644
--- a/tensorflow/python/training/checkpointable/base.py
+++ b/tensorflow/python/training/checkpointable/base.py
@@ -79,10 +79,6 @@ class CheckpointInitialValue(ops.Tensor):
self.wrapped_value.set_shape(shape)
self._checkpoint_position = checkpoint_position
- @property
- def __class__(self):
- return (self.wrapped_value.__class__, CheckpointInitialValue)
-
def __getattr__(self, attr):
try:
return getattr(self.wrapped_value, attr)
diff --git a/tensorflow/python/training/checkpointable/util.py b/tensorflow/python/training/checkpointable/util.py
index 6c03bf0d51..e42f989469 100644
--- a/tensorflow/python/training/checkpointable/util.py
+++ b/tensorflow/python/training/checkpointable/util.py
@@ -35,8 +35,8 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_io_ops as io_ops
from tensorflow.python.ops import init_ops
-from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import optimizer as optimizer_lib
from tensorflow.python.training import saveable_object as saveable_object_lib
@@ -227,10 +227,11 @@ def _default_getter(name, shape, dtype, initializer=None,
def initial_value():
return initializer(
shape_object.as_list(), dtype=dtype, partition_info=partition_info)
- return resource_variable_ops.ResourceVariable(
+ return variables.Variable(
initial_value=initial_value,
name=name,
dtype=variable_dtype,
+ use_resource=True,
**kwargs
)
@@ -1528,8 +1529,6 @@ class Checkpoint(tracking.Checkpointable):
self._maybe_create_save_counter()
return self._save_counter
- # TODO(allenl): Update save's docstring with a pointer to
- # tf.contrib.checkpoint.CheckpointManager once that's in.
def save(self, file_prefix, session=None):
"""Saves a training checkpoint and provides basic checkpoint management.
@@ -1541,7 +1540,8 @@ class Checkpoint(tracking.Checkpointable):
sequentially numbering checkpoints using `save_counter` and updating the
metadata used by `tf.train.latest_checkpoint`. More advanced checkpoint
management, for example garbage collection and custom numbering, may be
- provided by other utilities which also wrap `write`.
+ provided by other utilities that also wrap `write`
+ (for example, `tf.contrib.checkpoint.CheckpointManager`).
Args:
file_prefix: A prefix to use for the checkpoint filenames
diff --git a/tensorflow/python/training/checkpointable/util_test.py b/tensorflow/python/training/checkpointable/util_test.py
index 98a42b1c20..a0a87b6b79 100644
--- a/tensorflow/python/training/checkpointable/util_test.py
+++ b/tensorflow/python/training/checkpointable/util_test.py
@@ -522,7 +522,6 @@ class CheckpointingTests(test.TestCase):
# Does create garbage when executing eagerly due to ops.Graph() creation.
num_training_steps = 10
checkpoint_directory = self.get_temp_dir()
- checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
for training_continuation in range(3):
with ops.Graph().as_default(), self.test_session(
graph=ops.get_default_graph()), test_util.device(use_gpu=True):
@@ -531,9 +530,9 @@ class CheckpointingTests(test.TestCase):
root = checkpointable_utils.Checkpoint(
optimizer=optimizer, model=model,
global_step=training_util.get_or_create_global_step())
- checkpoint_path = checkpoint_management.latest_checkpoint(
- checkpoint_directory)
- status = root.restore(save_path=checkpoint_path)
+ manager = checkpoint_management.CheckpointManager(
+ root, checkpoint_directory, max_to_keep=1)
+ status = root.restore(save_path=manager.latest_checkpoint)
input_value = constant_op.constant([[3.]])
train_fn = functools.partial(
optimizer.minimize,
@@ -544,7 +543,7 @@ class CheckpointingTests(test.TestCase):
status.initialize_or_restore()
for _ in range(num_training_steps):
train_fn()
- root.save(file_prefix=checkpoint_prefix)
+ manager.save()
self.assertEqual((training_continuation + 1) * num_training_steps,
self.evaluate(root.global_step))
self.assertEqual(training_continuation + 1,
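
The updated test follows the restore-then-train pattern that `CheckpointManager` enables; a self-contained sketch, assuming eager execution (directory and step counts are illustrative):

import tensorflow as tf

tf.enable_eager_execution()
root = tf.train.Checkpoint(step=tf.Variable(0, dtype=tf.int64))
manager = tf.contrib.checkpoint.CheckpointManager(
    root, "/tmp/train_ckpts", max_to_keep=1)

# latest_checkpoint is None on a fresh directory; restore() accepts
# that, and initialize_or_restore() falls back to initialization.
status = root.restore(save_path=manager.latest_checkpoint)
status.initialize_or_restore()

for _ in range(10):
    root.step.assign_add(1)  # stand-in for a real training step
manager.save()  # numbers the checkpoint and prunes to max_to_keep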
diff --git a/tensorflow/python/training/distribute.py b/tensorflow/python/training/distribute.py
index 581db45e80..28c60ad809 100644
--- a/tensorflow/python/training/distribute.py
+++ b/tensorflow/python/training/distribute.py
@@ -21,7 +21,7 @@ from __future__ import print_function
import threading
from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.eager import context
+from tensorflow.python.eager import context as eager_context
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
@@ -31,71 +31,11 @@ from tensorflow.python.ops import variable_scope
from tensorflow.python.ops.losses import losses_impl
from tensorflow.python.platform import tf_logging
from tensorflow.python.training import device_util
+from tensorflow.python.training import distribution_strategy_context
from tensorflow.python.util import nest
# ------------------------------------------------------------------------------
-# Internal API for setting the current thread mode as being either in a
-# tower or cross-tower context for a particular distribution strategy.
-
-
-class _ThreadMode(object):
-
- def __init__(self, dist, cross, tower):
- self.distribution_strategy = dist
- self.cross_tower_context = cross
- self.tower_context = tower
-
-
-class _CrossTowerThreadMode(_ThreadMode):
-
- def __init__(self, distribution_strategy):
- _ThreadMode.__init__(
- self, distribution_strategy, distribution_strategy, None)
-
-
-class _InTowerThreadMode(_ThreadMode):
-
- def __init__(self, tower_ctx):
- _ThreadMode.__init__(
- self, tower_ctx.distribution_strategy, None, tower_ctx)
-
-
-_per_thread_mode = threading.local()
-
-
-def _push_per_thread_mode(context):
- if not hasattr(_per_thread_mode, "stack"):
- _per_thread_mode.stack = []
- _per_thread_mode.stack.append(context)
-
-
-def _pop_per_thread_mode():
- _per_thread_mode.stack.pop(-1)
-
-
-class _DefaultTowerThreadMode(_ThreadMode):
- """Type of default value returned by `_get_per_thread_mode()`.
-
- Used when the thread-local stack is empty.
- """
-
- def __init__(self):
- # _default_distribution_strategy and _default_tower_context are
- # defined at the bottom of this file.
- _ThreadMode.__init__(
- self, _default_distribution_strategy, None, _default_tower_context)
-
-
-def _get_per_thread_mode():
- try:
- return _per_thread_mode.stack[-1]
- except (AttributeError, IndexError):
- # _default_tower_mode is defined at the bottom of this file.
- return _default_tower_mode
-
-
-# ------------------------------------------------------------------------------
# Context tracking whether in a distribution.update() or .update_non_slot()
# call.
@@ -128,96 +68,6 @@ class UpdateContext(object):
# ------------------------------------------------------------------------------
-# Public API for accessing the current thread mode
-
-
-def get_tower_context():
- """Returns the current TowerContext or None if in a cross-tower context.
-
- Note that execution:
- 1. starts in the default (single-tower) tower context (this function
- will return the default TowerContext object);
- 2. switches to cross-tower context (in which case this will return
- None) when entering a `with DistributionStrategy.scope():` block;
- 3. switches to a (non-default) tower context inside
- `call_for_each_tower(fn, ...)`;
- 4. if `fn` calls `get_tower_context()->merge_call(merge_fn, ...)`, then
- inside `merge_fn` you are back in the cross-tower context (and again
- this function will return None).
-
- Note that you can also go directly from step 1 to 4 to switch to a
- cross-tower context for the default `DistributionStrategy`. You may
- also switch from the cross-tower context of 4 to a tower context by
- calling `call_for_each_tower()`, jumping back to step 3.
-
- Most `DistributionStrategy` methods may only be executed in
- a cross-tower context, in a tower context you should use the
- `TowerContext` API instead.
-
- Returns:
- The current `TowerContext` object when in a tower context scope, else None.
-
- Exactly one of `get_tower_context()` and `get_cross_tower_context()`
- will return None in a particular block.
- """
- return _get_per_thread_mode().tower_context
-
-
-def get_cross_tower_context():
- """Returns the current DistributionStrategy if in a cross-tower context.
-
- Note that execution:
- 1. starts in the default (single-tower) tower context;
- 2. switches to cross-tower context when entering a
- `with DistributionStrategy.scope():` block;
- 3. switches to a (non-default) tower context inside
- `call_for_each_tower(fn, ...)`;
- 4. if `fn` calls `get_tower_context()->merge_call(merge_fn, ...)`, then
- inside `merge_fn` you are back in the cross-tower context.
-
- Note that you can also go directly from step 1 to 4 to switch to a
- cross-tower context for the default `DistributionStrategy`. You may
- also switch from the cross-tower context of 4 to a tower context by
- calling `call_for_each_tower()`, jumping back to step 3.
-
- Most `DistributionStrategy` methods may only be executed in
- a cross-tower context.
-
- Returns:
- Returns the current `DistributionStrategy` object in a cross-tower
- context, or None.
-
- Exactly one of `get_tower_context()` and `get_cross_tower_context()`
- will return None in a particular block.
- """
- return _get_per_thread_mode().cross_tower_context
-
-
-def get_distribution_strategy():
- """Returns the current `DistributionStrategy` object.
-
- Prefer to use `get_tower_context()` or `get_cross_tower_context()`
- instead when possible.
-
- Returns:
- A `DistributionStrategy` object. Inside a
- `with distribution_strategy.scope()` block, it returns
- `distribution_strategy`, otherwise it returns the default
- (single-tower) `DistributionStrategy` object.
- """
- return _get_per_thread_mode().distribution_strategy
-
-
-def has_distribution_strategy():
- """Return if there is a current non-default `DistributionStrategy`.
-
- Returns:
- True if inside a `with distribution_strategy.scope():`.
- """
- return get_distribution_strategy() is not _default_distribution_strategy
-
-
-# ------------------------------------------------------------------------------
# Public utility functions.
@@ -239,7 +89,8 @@ def _require_cross_tower_context(distribution_strategy):
if context.cross_tower_context is distribution_strategy: return
# We have an error to report, figure out the right message.
if context.distribution_strategy is not distribution_strategy:
- if context.distribution_strategy is _default_distribution_strategy:
+ if (context.distribution_strategy is
+ distribution_strategy_context._get_default_distribution_strategy()): # pylint: disable=protected-access
raise RuntimeError(
'Need to be inside "with distribution_strategy.scope()" for %s' %
(distribution_strategy,))
@@ -272,7 +123,8 @@ def _require_distribution_strategy_scope(distribution_strategy):
context = _get_per_thread_mode()
if context.distribution_strategy is distribution_strategy: return
# We have an error to report, figure out the right message.
- if context.distribution_strategy is _default_distribution_strategy:
+ if (context.distribution_strategy is
+ distribution_strategy_context._get_default_distribution_strategy()): # pylint: disable=protected-access
raise RuntimeError(
'Need to be inside "with distribution_strategy.scope()" for %s' %
(distribution_strategy,))
@@ -295,7 +147,8 @@ class _CurrentDistributionContext(object):
var_creator_scope,
var_scope=None,
default_device=None):
- self._context = _CrossTowerThreadMode(distribution_strategy)
+ self._context = distribution_strategy_context._CrossTowerThreadMode( # pylint: disable=protected-access
+ distribution_strategy)
self._var_creator_scope = var_creator_scope
self._var_scope = var_scope
if default_device:
@@ -588,7 +441,7 @@ class DistributionStrategy(object):
Returns:
A context manager.
"""
- if has_distribution_strategy():
+ if distribution_strategy_context.has_distribution_strategy():
_require_cross_tower_context(self)
return _SameScopeAgainContext(self)
@@ -740,7 +593,7 @@ class DistributionStrategy(object):
In eager mode, returns `None`.
In graph mode, a list of ops to execute. Empty list if nothing to be done.
"""
- if context.executing_eagerly():
+ if eager_context.executing_eagerly():
return
else:
return []
@@ -757,7 +610,7 @@ class DistributionStrategy(object):
In eager mode, returns `None`.
In graph mode, a list of ops to execute. Empty list if nothing to be done.
"""
- if context.executing_eagerly():
+ if eager_context.executing_eagerly():
return
else:
return []
@@ -1077,9 +930,37 @@ class DistributionStrategy(object):
def _worker_device_index(self):
raise NotImplementedError("must be implemented in descendants")
- def configure(self, session_config=None):
- """Find the best configuration given a tensorflow session config."""
- del session_config
+ @property
+ def between_graph(self):
+ """Whether the strategy uses between-graph replication or not.
+
+ This is expected to return a constant value that will not change
+ throughout the strategy's life cycle.
+ """
+ raise NotImplementedError("must be implemented in descendants")
+
+ def configure(self,
+ session_config=None,
+ cluster_spec=None,
+ task_type=None,
+ task_id=None):
+ """Configures the strategy class."""
+ del session_config, cluster_spec, task_type, task_id
+
+ @property
+ def should_init(self):
+ """Whether initialization is needed."""
+ raise NotImplementedError("must be implemented in descendants")
+
+ @property
+ def should_checkpoint(self):
+ """Whether checkpointing is needed."""
+ raise NotImplementedError("must be implemented in descendants")
+
+ @property
+ def should_save_summary(self):
+ """Whether saving summaries is needed."""
+ raise NotImplementedError("must be implemented in descendants")
# A note about the difference between the context managers
@@ -1106,7 +987,8 @@ class TowerContext(object):
def __init__(self, distribution_strategy, tower_id):
self._distribution_strategy = distribution_strategy
- self._thread_context = _InTowerThreadMode(self)
+ self._thread_context = distribution_strategy_context._InTowerThreadMode( # pylint: disable=protected-access
+ self)
self._tower_id = tower_id
def __enter__(self):
@@ -1149,7 +1031,8 @@ class TowerContext(object):
def _merge_call(self, merge_fn, *args, **kwargs):
"""Default implementation for single tower."""
_push_per_thread_mode( # thread-local, so not needed with multiple threads
- _CrossTowerThreadMode(self._distribution_strategy))
+ distribution_strategy_context._CrossTowerThreadMode( # pylint: disable=protected-access
+ self._distribution_strategy))
try:
return merge_fn(self._distribution_strategy, *args, **kwargs)
finally:
@@ -1196,7 +1079,7 @@ class _DefaultDistributionStrategy(DistributionStrategy):
def scope(self):
"""Context manager setting a variable creator and `self` as current."""
- if has_distribution_strategy():
+ if distribution_strategy_context.has_distribution_strategy():
raise RuntimeError("Must not nest DistributionStrategy scopes.")
def creator(next_creator, *args, **kwargs):
@@ -1277,6 +1160,7 @@ class _DefaultDistributionStrategy(DistributionStrategy):
raise RuntimeError("worker_device_index() method unsupported by "
"_DefaultDistributionStrategy.")
+
# ------------------------------------------------------------------------------
# Common operations
@@ -1292,20 +1176,11 @@ def increment_var(v, amount=1):
def merge_fn(dist, vm):
return dist.group(dist.update(vm, update))
- tower_context = get_tower_context()
+ tower_context = distribution_strategy_context.get_tower_context()
return tower_context.merge_call(merge_fn, v)
# ------------------------------------------------------------------------------
-# Singletons
-
-_default_distribution_strategy = _DefaultDistributionStrategy()
-_default_tower_context = TowerContext(
- _default_distribution_strategy, tower_id=0)
-_default_tower_mode = _DefaultTowerThreadMode()
-
-
-# ------------------------------------------------------------------------------
# We haven't yet implemented deserialization for DistributedVariables.
# So here we catch any attempts to deserialize variables
# when using distribution strategies.
@@ -1314,7 +1189,7 @@ _original_from_proto = resource_variable_ops._from_proto_fn
def _from_proto_fn(v, import_scope=None):
- if has_distribution_strategy():
+ if distribution_strategy_context.has_distribution_strategy():
raise NotImplementedError(
"Deserialization of variables is not yet supported when using"
"distributed strategies.")
@@ -1323,3 +1198,10 @@ def _from_proto_fn(v, import_scope=None):
resource_variable_ops._from_proto_fn = _from_proto_fn
# pylint: enable=protected-access
+
+
+# ------------------------------------------------------------------------------
+# Shorthand for some methods from distribution_strategy_context.
+_push_per_thread_mode = distribution_strategy_context._push_per_thread_mode # pylint: disable=protected-access
+_get_per_thread_mode = distribution_strategy_context._get_per_thread_mode # pylint: disable=protected-access
+_pop_per_thread_mode = distribution_strategy_context._pop_per_thread_mode # pylint: disable=protected-access
diff --git a/tensorflow/python/training/distribute_test.py b/tensorflow/python/training/distribute_test.py
index 694145ede7..f03bd39100 100644
--- a/tensorflow/python/training/distribute_test.py
+++ b/tensorflow/python/training/distribute_test.py
@@ -21,6 +21,7 @@ from __future__ import print_function
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import test
from tensorflow.python.training import distribute
+from tensorflow.python.training import distribution_strategy_context
class _TestTowerContext(distribute.TowerContext):
@@ -49,12 +50,12 @@ class _TestStrategy(distribute.DistributionStrategy):
def _assert_in_default_state(t):
- t.assertIs(distribute._default_tower_context,
- distribute.get_tower_context())
- t.assertIs(None, distribute.get_cross_tower_context())
- t.assertIs(distribute._default_distribution_strategy,
- distribute.get_distribution_strategy())
- t.assertFalse(distribute.has_distribution_strategy())
+ t.assertIs(distribution_strategy_context._get_default_tower_context(),
+ distribution_strategy_context.get_tower_context())
+ t.assertIs(None, distribution_strategy_context.get_cross_tower_context())
+ t.assertIs(distribution_strategy_context._get_default_distribution_strategy(),
+ distribution_strategy_context.get_distribution_strategy())
+ t.assertFalse(distribution_strategy_context.has_distribution_strategy())
class TestStrategyTest(test.TestCase):
@@ -64,11 +65,13 @@ class TestStrategyTest(test.TestCase):
dist = _TestStrategy()
def run_fn():
- tower_context = distribute.get_tower_context()
+ tower_context = distribution_strategy_context.get_tower_context()
self.assertTrue(tower_context is not None)
- self.assertIs(None, distribute.get_cross_tower_context())
- self.assertTrue(distribute.has_distribution_strategy())
- self.assertIs(dist, distribute.get_distribution_strategy())
+ self.assertIs(None,
+ distribution_strategy_context.get_cross_tower_context())
+ self.assertTrue(distribution_strategy_context.has_distribution_strategy())
+ self.assertIs(dist,
+ distribution_strategy_context.get_distribution_strategy())
self.assertEqual("foo", tower_context.merge_call(None, test_arg="foo"))
expected_value = _get_test_variable(
"bar", variable_scope.VariableSynchronization.AUTO,
@@ -86,10 +89,12 @@ class TestStrategyTest(test.TestCase):
_assert_in_default_state(self)
dist = _TestStrategy()
with dist.scope():
- self.assertIs(None, distribute.get_tower_context())
- self.assertIs(dist, distribute.get_cross_tower_context())
- self.assertTrue(distribute.has_distribution_strategy())
- self.assertIs(dist, distribute.get_distribution_strategy())
+ self.assertIs(None, distribution_strategy_context.get_tower_context())
+ self.assertIs(dist,
+ distribution_strategy_context.get_cross_tower_context())
+ self.assertTrue(distribution_strategy_context.has_distribution_strategy())
+ self.assertIs(dist,
+ distribution_strategy_context.get_distribution_strategy())
expected_value = _get_test_variable(
"baz", variable_scope.VariableSynchronization.AUTO,
variable_scope.VariableAggregation.NONE)
@@ -120,15 +125,21 @@ class DefaultDistributionStrategyTest(test.TestCase):
_assert_in_default_state(self)
def merge_fn(dist, s):
- self.assertIs(distribute._default_distribution_strategy, dist)
- self.assertIs(None, distribute.get_tower_context())
- self.assertIs(dist, distribute.get_cross_tower_context())
- self.assertIs(dist, distribute.get_distribution_strategy())
- self.assertFalse(distribute.has_distribution_strategy())
+ self.assertIs(
+ distribution_strategy_context._get_default_distribution_strategy(),
+ dist)
+ self.assertIs(None, distribution_strategy_context.get_tower_context())
+ self.assertIs(dist,
+ distribution_strategy_context.get_cross_tower_context())
+ self.assertIs(dist,
+ distribution_strategy_context.get_distribution_strategy())
+ self.assertFalse(
+ distribution_strategy_context.has_distribution_strategy())
return "foo_" + s
- tower_ctx = distribute.get_tower_context()
- self.assertIs(distribute._default_tower_context, tower_ctx)
+ tower_ctx = distribution_strategy_context.get_tower_context()
+ self.assertIs(distribution_strategy_context._get_default_tower_context(),
+ tower_ctx)
self.assertEqual("foo_bar", tower_ctx.merge_call(merge_fn, "bar"))
_assert_in_default_state(self)
diff --git a/tensorflow/python/training/distribution_strategy_context.py b/tensorflow/python/training/distribution_strategy_context.py
new file mode 100644
index 0000000000..998b5c35ce
--- /dev/null
+++ b/tensorflow/python/training/distribution_strategy_context.py
@@ -0,0 +1,203 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utility to get distribution strategy related contexts."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import ops
+from tensorflow.python.util.lazy_loader import LazyLoader
+
+
+# There is a circular dependency between this and the `distribute` module, so
+# we load it lazily to work around it.
+distribute_lib = LazyLoader(
+ "distribute_lib", globals(),
+ "tensorflow.python.training.distribute")
+
+# ------------------------------------------------------------------------------
+# Internal API for setting the current thread mode as being either in a
+# tower or cross-tower context for a particular distribution strategy.
+
+
+class _ThreadMode(object):
+
+ def __init__(self, dist, cross, tower):
+ self.distribution_strategy = dist
+ self.cross_tower_context = cross
+ self.tower_context = tower
+
+
+class _CrossTowerThreadMode(_ThreadMode):
+
+ def __init__(self, distribution_strategy):
+ _ThreadMode.__init__(
+ self, distribution_strategy, distribution_strategy, None)
+
+
+class _InTowerThreadMode(_ThreadMode):
+
+ def __init__(self, tower_ctx):
+ _ThreadMode.__init__(
+ self, tower_ctx.distribution_strategy, None, tower_ctx)
+
+
+def _push_per_thread_mode(context):
+ ops.get_default_graph()._distribution_strategy_stack.append(context) # pylint: disable=protected-access
+
+
+def _pop_per_thread_mode():
+ ops.get_default_graph()._distribution_strategy_stack.pop(-1) # pylint: disable=protected-access
+
+
+class _DefaultTowerThreadMode(_ThreadMode):
+ """Type of default value returned by `_get_per_thread_mode()`.
+
+ Used when the thread-local stack is empty.
+ """
+
+ def __init__(self):
+ _ThreadMode.__init__(self, _get_default_distribution_strategy(), None,
+ _get_default_tower_context())
+
+
+def _get_per_thread_mode():
+ try:
+ return ops.get_default_graph()._distribution_strategy_stack[-1] # pylint: disable=protected-access
+ except (AttributeError, IndexError):
+ return _get_default_tower_mode()
+
+
+# ------------------------------------------------------------------------------
+# Public API for accessing the current thread mode
+
+
+def get_tower_context():
+ """Returns the current TowerContext or None if in a cross-tower context.
+
+ Note that execution:
+ 1. starts in the default (single-tower) tower context (this function
+ will return the default TowerContext object);
+ 2. switches to cross-tower context (in which case this will return
+ None) when entering a `with DistributionStrategy.scope():` block;
+ 3. switches to a (non-default) tower context inside
+ `call_for_each_tower(fn, ...)`;
+ 4. if `fn` calls `get_tower_context()->merge_call(merge_fn, ...)`, then
+ inside `merge_fn` you are back in the cross-tower context (and again
+ this function will return None).
+
+ Note that you can also go directly from step 1 to 4 to switch to a
+ cross-tower context for the default `DistributionStrategy`. You may
+ also switch from the cross-tower context of 4 to a tower context by
+ calling `call_for_each_tower()`, jumping back to step 3.
+
+ Most `DistributionStrategy` methods may only be executed in
+ a cross-tower context, in a tower context you should use the
+ `TowerContext` API instead.
+
+ Returns:
+ The current `TowerContext` object when in a tower context scope, else None.
+
+ Exactly one of `get_tower_context()` and `get_cross_tower_context()`
+ will return None in a particular block.
+ """
+ return _get_per_thread_mode().tower_context
+
+
+def get_cross_tower_context():
+ """Returns the current DistributionStrategy if in a cross-tower context.
+
+ Note that execution:
+ 1. starts in the default (single-tower) tower context;
+ 2. switches to cross-tower context when entering a
+ `with DistributionStrategy.scope():` block;
+ 3. switches to a (non-default) tower context inside
+ `call_for_each_tower(fn, ...)`;
+ 4. if `fn` calls `get_tower_context()->merge_call(merge_fn, ...)`, then
+ inside `merge_fn` you are back in the cross-tower context.
+
+ Note that you can also go directly from step 1 to 4 to switch to a
+ cross-tower context for the default `DistributionStrategy`. You may
+ also switch from the cross-tower context of 4 to a tower context by
+ calling `call_for_each_tower()`, jumping back to step 3.
+
+ Most `DistributionStrategy` methods may only be executed in
+ a cross-tower context.
+
+ Returns:
+ Returns the current `DistributionStrategy` object in a cross-tower
+ context, or None.
+
+ Exactly one of `get_tower_context()` and `get_cross_tower_context()`
+ will return None in a particular block.
+ """
+ return _get_per_thread_mode().cross_tower_context
+
+
+def get_distribution_strategy():
+ """Returns the current `DistributionStrategy` object.
+
+ Prefer to use `get_tower_context()` or `get_cross_tower_context()`
+ instead when possible.
+
+ Returns:
+ A `DistributionStrategy` object. Inside a
+ `with distribution_strategy.scope()` block, it returns
+ `distribution_strategy`, otherwise it returns the default
+ (single-tower) `DistributionStrategy` object.
+ """
+ return _get_per_thread_mode().distribution_strategy
+
+
+def has_distribution_strategy():
+ """Return if there is a current non-default `DistributionStrategy`.
+
+ Returns:
+ True if inside a `with distribution_strategy.scope():`.
+ """
+ return get_distribution_strategy() is not _get_default_distribution_strategy()
+
+
+# ------------------------------------------------------------------------------
+# Defaults that are used when no distribution strategy is explicitly created.
+# We create them lazily in a function so that we can work around the circular
+# dependency on distribute_lib. See the lazy loader at the top of this file.
+
+_defaults = {
+ "distribution_strategy": None,
+ "tower_context": None,
+ "tower_mode": None
+}
+
+
+def _get_default_distribution_strategy():
+ if _defaults["distribution_strategy"] is None:
+ _defaults["distribution_strategy"] = (
+ distribute_lib._DefaultDistributionStrategy()) # pylint: disable=protected-access
+ return _defaults["distribution_strategy"]
+
+
+def _get_default_tower_context():
+ if _defaults["tower_context"] is None:
+ _defaults["tower_context"] = distribute_lib.TowerContext(
+ _get_default_distribution_strategy(), tower_id=0)
+ return _defaults["tower_context"]
+
+
+def _get_default_tower_mode():
+ if _defaults["tower_mode"] is None:
+ _defaults["tower_mode"] = _DefaultTowerThreadMode()
+ return _defaults["tower_mode"]
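
The new module's circular-import workaround is self-contained enough to show in isolation; a minimal sketch of the same `LazyLoader` pattern (the accessor function is illustrative):

from tensorflow.python.util.lazy_loader import LazyLoader

# Resolution is deferred until first attribute access, so this module
# can finish importing even while tensorflow.python.training.distribute
# is still mid-import and importing us back.
distribute_lib = LazyLoader(
    "distribute_lib", globals(),
    "tensorflow.python.training.distribute")

def make_default_strategy():
    # Touching an attribute triggers the real import.
    return distribute_lib._DefaultDistributionStrategy()  # pylint: disable=protected-access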
diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py
index 6d95b144d5..1b6bce2865 100644
--- a/tensorflow/python/training/optimizer.py
+++ b/tensorflow/python/training/optimizer.py
@@ -35,6 +35,7 @@ from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training import distribute as distribute_lib
+from tensorflow.python.training import distribution_strategy_context
from tensorflow.python.training import slot_creator
from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.util import nest
@@ -464,7 +465,8 @@ class Optimizer(
# TODO(josh11b): Test that we handle weight decay in a reasonable way.
if (distribute_lib.get_loss_reduction() ==
variable_scope.VariableAggregation.MEAN):
- num_towers = distribute_lib.get_distribution_strategy().num_towers
+ num_towers = distribution_strategy_context.get_distribution_strategy(
+ ).num_towers
if num_towers > 1:
loss_value *= (1. / num_towers)
@@ -482,7 +484,8 @@ class Optimizer(
# Scale loss if using a "mean" loss reduction and multiple towers.
if (distribute_lib.get_loss_reduction() ==
variable_scope.VariableAggregation.MEAN):
- num_towers = distribute_lib.get_distribution_strategy().num_towers
+ num_towers = distribution_strategy_context.get_distribution_strategy(
+ ).num_towers
if num_towers > 1:
loss *= (1. / num_towers)
@@ -548,15 +551,15 @@ class Optimizer(
# methods: _create_slots(), _prepare(), _apply_dense(), and _apply_sparse().
# Handle DistributionStrategy case.
- if distribute_lib.get_cross_tower_context():
+ if distribution_strategy_context.get_cross_tower_context():
raise RuntimeError("Use `_distributed_apply()` instead of "
"`apply_gradients()` in a cross-tower context.")
# TODO(isaprykin): Get rid of `has_distribution_strategy()` check by
# always calling _distributed_apply(), using the default distribution
# as needed.
- if distribute_lib.has_distribution_strategy():
+ if distribution_strategy_context.has_distribution_strategy():
grads_and_vars = get_filtered_grad_fn(lambda: grads_and_vars)()
- return distribute_lib.get_tower_context().merge_call(
+ return distribution_strategy_context.get_tower_context().merge_call(
self._distributed_apply, grads_and_vars, global_step, name)
# No DistributionStrategy case.
@@ -799,7 +802,8 @@ class Optimizer(
v = self._non_slot_dict.get(key, None)
if v is None:
self._maybe_initialize_checkpointable()
- distribution_strategy = distribute_lib.get_distribution_strategy()
+ distribution_strategy = (
+ distribution_strategy_context.get_distribution_strategy())
with distribution_strategy.colocate_vars_with(colocate_with):
if eager:
restored_initial_value = self._preload_simple_restoration(
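
The scaling in both hunks above is a plain mean correction; a worked one-liner with illustrative numbers:

# With VariableAggregation.MEAN across num_towers towers, each tower's
# loss is scaled by 1/num_towers so the implicitly summed gradients
# average rather than add up.
num_towers = 4
loss = 8.0
loss *= (1. / num_towers)
assert loss == 2.0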
diff --git a/tensorflow/python/training/slot_creator.py b/tensorflow/python/training/slot_creator.py
index 258a6f045d..d76b22acd8 100644
--- a/tensorflow/python/training/slot_creator.py
+++ b/tensorflow/python/training/slot_creator.py
@@ -45,7 +45,7 @@ from tensorflow.python.ops import init_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
-from tensorflow.python.training import distribute as distribute_lib
+from tensorflow.python.training import distribution_strategy_context
def _create_slot_var(primary, val, scope, validate_shape, shape, dtype):
@@ -112,7 +112,8 @@ def create_slot(primary, val, name, colocate_with_primary=True):
prefix = primary.op.name
with variable_scope.variable_scope(None, prefix + "/" + name):
if colocate_with_primary:
- distribution_strategy = distribute_lib.get_distribution_strategy()
+ distribution_strategy = (
+ distribution_strategy_context.get_distribution_strategy())
with distribution_strategy.colocate_vars_with(primary):
return _create_slot_var(primary, val, "", validate_shape, None, None)
else:
@@ -149,7 +150,8 @@ def create_slot_with_initializer(primary, initializer, shape, dtype, name,
prefix = primary.op.name
with variable_scope.variable_scope(None, prefix + "/" + name):
if colocate_with_primary:
- distribution_strategy = distribute_lib.get_distribution_strategy()
+ distribution_strategy = (
+ distribution_strategy_context.get_distribution_strategy())
with distribution_strategy.colocate_vars_with(primary):
return _create_slot_var(primary, initializer, "", validate_shape, shape,
dtype)
diff --git a/tensorflow/stream_executor/BUILD b/tensorflow/stream_executor/BUILD
index e742f8e8d5..d4d97087ba 100644
--- a/tensorflow/stream_executor/BUILD
+++ b/tensorflow/stream_executor/BUILD
@@ -30,6 +30,7 @@ cc_library(
hdrs = STREAM_EXECUTOR_HEADERS,
linkopts = select({
"//tensorflow:freebsd": [],
+ "//tensorflow:windows": [],
"//conditions:default": ["-ldl"],
}),
visibility = ["//visibility:public"],
@@ -79,6 +80,7 @@ cc_library(
}),
linkopts = select({
"//tensorflow:freebsd": [],
+ "//tensorflow:windows": [],
"//conditions:default": ["-ldl"],
}),
visibility = ["//visibility:public"],
diff --git a/tensorflow/stream_executor/host/host_gpu_executor.h b/tensorflow/stream_executor/host/host_gpu_executor.h
index 858396ef96..7ba1f18101 100644
--- a/tensorflow/stream_executor/host/host_gpu_executor.h
+++ b/tensorflow/stream_executor/host/host_gpu_executor.h
@@ -88,7 +88,7 @@ class HostExecutor : public internal::StreamExecutorInterface {
uint64 size) override;
// No "synchronize all activity" implemented for this platform at the moment.
- bool SynchronizeAllActivity() override { return false; }
+ bool SynchronizeAllActivity() override { return true; }
bool SynchronousMemZero(DeviceMemoryBase *location, uint64 size) override;
bool SynchronousMemSet(DeviceMemoryBase *location, int value,
diff --git a/tensorflow/stream_executor/host/host_stream.cc b/tensorflow/stream_executor/host/host_stream.cc
index 5a7d3b3dd4..bfbfb56cd7 100644
--- a/tensorflow/stream_executor/host/host_stream.cc
+++ b/tensorflow/stream_executor/host/host_stream.cc
@@ -28,18 +28,28 @@ HostStream::HostStream()
HostStream::~HostStream() {}
bool HostStream::EnqueueTask(std::function<void()> task) {
+ struct NotifiedTask {
+ HostStream* stream;
+ std::function<void()> task;
+
+ void operator()() {
+ task();
+ // Destroy the task before unblocking its waiters, as BlockHostUntilDone()
+ // should guarantee that all tasks are destroyed.
+ task = std::function<void()>();
+ {
+ mutex_lock lock(stream->mu_);
+ --stream->pending_tasks_;
+ }
+ stream->completion_condition_.notify_all();
+ }
+ };
+
{
mutex_lock lock(mu_);
++pending_tasks_;
}
- host_executor_->Schedule([this, task]() {
- task();
- {
- mutex_lock lock(mu_);
- --pending_tasks_;
- }
- completion_condition_.notify_all();
- });
+ host_executor_->Schedule(NotifiedTask{this, std::move(task)});
return true;
}
diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl
index 4f89aebe90..2f19147dbb 100644
--- a/tensorflow/tensorflow.bzl
+++ b/tensorflow/tensorflow.bzl
@@ -4,12 +4,12 @@
# Uses the ":optmode" config_setting to pick the options.
load(
"//tensorflow/core:platform/default/build_config_root.bzl",
- "tf_cuda_tests_tags",
- "tf_sycl_tests_tags",
+ "if_dynamic_kernels",
+ "if_static",
"tf_additional_grpc_deps_py",
"tf_additional_xla_deps_py",
- "if_static",
- "if_dynamic_kernels",
+ "tf_cuda_tests_tags",
+ "tf_sycl_tests_tags",
)
load(
"@local_config_tensorrt//:build_defs.bzl",
@@ -17,17 +17,19 @@ load(
)
load(
"@local_config_cuda//cuda:build_defs.bzl",
- "if_cuda",
"cuda_default_copts",
+ "if_cuda",
)
load(
"//third_party/mkl:build_defs.bzl",
"if_mkl",
- "if_mkl_lnx_x64"
+ "if_mkl_lnx_x64",
+ "if_mkl_ml",
+ "mkl_deps",
)
load(
"//third_party/mkl_dnn:build_defs.bzl",
- "if_mkl_open_source_only"
+ "if_mkl_open_source_only",
)
load(
@@ -42,155 +44,154 @@ def register_extension_info(**kwargs):
# i.e. "common_runtime/direct_session_test.cc" becomes
# "common_runtime_direct_session_test"
def src_to_test_name(src):
- return src.replace("/", "_").split(".")[0]
+ return src.replace("/", "_").split(".")[0]
def full_path(relative_paths):
- return [native.package_name() + "/" + relative for relative in relative_paths]
+ return [native.package_name() + "/" + relative for relative in relative_paths]
def _add_tfcore_prefix(src):
- if src.startswith("//"):
- return src
- return "//tensorflow/core:" + src
+ if src.startswith("//"):
+ return src
+ return "//tensorflow/core:" + src
# List of proto files for android builds
def tf_android_core_proto_sources(core_proto_sources_relative):
- return [
- _add_tfcore_prefix(p) for p in core_proto_sources_relative
- ]
+ return [
+ _add_tfcore_prefix(p)
+ for p in core_proto_sources_relative
+ ]
# Returns the list of pb.h and proto.h headers that are generated for
# tf_android_core_proto_sources().
def tf_android_core_proto_headers(core_proto_sources_relative):
- return ([
- _add_tfcore_prefix(p).replace(":", "/").replace(".proto", ".pb.h")
- for p in core_proto_sources_relative
- ] + [
- _add_tfcore_prefix(p).replace(":", "/").replace(".proto", ".proto.h")
- for p in core_proto_sources_relative
- ])
+ return ([
+ _add_tfcore_prefix(p).replace(":", "/").replace(".proto", ".pb.h")
+ for p in core_proto_sources_relative
+ ] + [
+ _add_tfcore_prefix(p).replace(":", "/").replace(".proto", ".proto.h")
+ for p in core_proto_sources_relative
+ ])
# Sanitize a dependency so that it works correctly from code that includes
# TensorFlow as a submodule.
def clean_dep(dep):
- return str(Label(dep))
+ return str(Label(dep))
def if_android_x86(a):
- return select({
- clean_dep("//tensorflow:android_x86"): a,
- clean_dep("//tensorflow:android_x86_64"): a,
- "//conditions:default": [],
- })
+ return select({
+ clean_dep("//tensorflow:android_x86"): a,
+ clean_dep("//tensorflow:android_x86_64"): a,
+ "//conditions:default": [],
+ })
def if_android_arm(a):
- return select({
- clean_dep("//tensorflow:android_arm"): a,
- "//conditions:default": [],
- })
+ return select({
+ clean_dep("//tensorflow:android_arm"): a,
+ "//conditions:default": [],
+ })
def if_android_arm64(a):
- return select({
- clean_dep("//tensorflow:android_arm64"): a,
- "//conditions:default": [],
- })
+ return select({
+ clean_dep("//tensorflow:android_arm64"): a,
+ "//conditions:default": [],
+ })
def if_android_mips(a):
- return select({
- clean_dep("//tensorflow:android_mips"): a,
- "//conditions:default": [],
- })
+ return select({
+ clean_dep("//tensorflow:android_mips"): a,
+ "//conditions:default": [],
+ })
def if_not_android(a):
- return select({
- clean_dep("//tensorflow:android"): [],
- "//conditions:default": a,
- })
+ return select({
+ clean_dep("//tensorflow:android"): [],
+ "//conditions:default": a,
+ })
def if_not_android_mips_and_mips64(a):
- return select({
- clean_dep("//tensorflow:android_mips"): [],
- clean_dep("//tensorflow:android_mips64"): [],
- "//conditions:default": a,
- })
+ return select({
+ clean_dep("//tensorflow:android_mips"): [],
+ clean_dep("//tensorflow:android_mips64"): [],
+ "//conditions:default": a,
+ })
def if_android(a):
- return select({
- clean_dep("//tensorflow:android"): a,
- "//conditions:default": [],
- })
+ return select({
+ clean_dep("//tensorflow:android"): a,
+ "//conditions:default": [],
+ })
def if_ios(a):
- return select({
- clean_dep("//tensorflow:ios"): a,
- "//conditions:default": [],
- })
+ return select({
+ clean_dep("//tensorflow:ios"): a,
+ "//conditions:default": [],
+ })
def if_ios_x86_64(a):
- return select({
- clean_dep("//tensorflow:ios_x86_64"): a,
- "//conditions:default": [],
- })
+ return select({
+ clean_dep("//tensorflow:ios_x86_64"): a,
+ "//conditions:default": [],
+ })
def if_mobile(a):
- return select({
- clean_dep("//tensorflow:android"): a,
- clean_dep("//tensorflow:ios"): a,
- "//conditions:default": [],
- })
+ return select({
+ clean_dep("//tensorflow:android"): a,
+ clean_dep("//tensorflow:ios"): a,
+ "//conditions:default": [],
+ })
def if_not_mobile(a):
- return select({
- clean_dep("//tensorflow:android"): [],
- clean_dep("//tensorflow:ios"): [],
- "//conditions:default": a,
- })
+ return select({
+ clean_dep("//tensorflow:android"): [],
+ clean_dep("//tensorflow:ios"): [],
+ "//conditions:default": a,
+ })
# Config setting selector used when building for products
# which requires restricted licenses to be avoided.
def if_not_lgpl_restricted(a):
- _ = (a,)
- return select({
- "//conditions:default": [],
- })
+ _ = (a,)
+ return select({
+ "//conditions:default": [],
+ })
def if_not_windows(a):
- return select({
- clean_dep("//tensorflow:windows"): [],
- clean_dep("//tensorflow:windows_msvc"): [],
- "//conditions:default": a,
- })
+ return select({
+ clean_dep("//tensorflow:windows"): [],
+ "//conditions:default": a,
+ })
def if_windows(a):
- return select({
- clean_dep("//tensorflow:windows"): a,
- clean_dep("//tensorflow:windows_msvc"): a,
- "//conditions:default": [],
- })
+ return select({
+ clean_dep("//tensorflow:windows"): a,
+ "//conditions:default": [],
+ })
def if_not_windows_cuda(a):
- return select({
- clean_dep("//tensorflow:with_cuda_support_windows_override"): [],
- "//conditions:default": a,
- })
+ return select({
+ clean_dep("//tensorflow:with_cuda_support_windows_override"): [],
+ "//conditions:default": a,
+ })
def if_linux_x86_64(a):
- return select({
- clean_dep("//tensorflow:linux_x86_64"): a,
- "//conditions:default": [],
- })
+ return select({
+ clean_dep("//tensorflow:linux_x86_64"): a,
+ "//conditions:default": [],
+ })
def if_darwin(a):
- return select({
- clean_dep("//tensorflow:darwin"): a,
- "//conditions:default": [],
- })
+ return select({
+ clean_dep("//tensorflow:darwin"): a,
+ "//conditions:default": [],
+ })
def if_override_eigen_strong_inline(a):
- return select({
- clean_dep("//tensorflow:override_eigen_strong_inline"): a,
- "//conditions:default": [],
- })
+ return select({
+ clean_dep("//tensorflow:override_eigen_strong_inline"): a,
+ "//conditions:default": [],
+ })
-def get_win_copts(is_external=False):
+def get_win_copts(is_external = False):
WINDOWS_COPTS = [
"/DPLATFORM_WINDOWS",
"/DEIGEN_HAS_C99_MATH",
@@ -208,164 +209,170 @@ def get_win_copts(is_external=False):
"/DNOGDI",
]
if is_external:
- return WINDOWS_COPTS + ["/UTF_COMPILE_LIBRARY"]
+ return WINDOWS_COPTS + ["/UTF_COMPILE_LIBRARY"]
else:
- return WINDOWS_COPTS + ["/DTF_COMPILE_LIBRARY"]
+ return WINDOWS_COPTS + ["/DTF_COMPILE_LIBRARY"]
# LINT.IfChange
-def tf_copts(android_optimization_level_override="-O2", is_external=False):
- # For compatibility reasons, android_optimization_level_override
- # is currently only being set for Android.
- # To clear this value, and allow the CROSSTOOL default
- # to be used, pass android_optimization_level_override=None
- android_copts = [
- "-std=c++11",
- "-DTF_LEAN_BINARY",
- "-Wno-narrowing",
- "-fomit-frame-pointer",
- ]
- if android_optimization_level_override:
- android_copts.append(android_optimization_level_override)
- return (
- if_not_windows([
- "-DEIGEN_AVOID_STL_ARRAY",
- "-Iexternal/gemmlowp",
- "-Wno-sign-compare",
- "-fno-exceptions",
- "-ftemplate-depth=900"])
- + if_cuda(["-DGOOGLE_CUDA=1"])
- + if_tensorrt(["-DGOOGLE_TENSORRT=1"])
- + if_mkl(["-DINTEL_MKL=1", "-DEIGEN_USE_VML"])
- + if_mkl_open_source_only(["-DINTEL_MKL_DNN_ONLY"])
- + if_ngraph(["-DINTEL_NGRAPH=1"])
- + if_mkl_lnx_x64(["-fopenmp"])
- + if_android_arm(["-mfpu=neon"])
- + if_linux_x86_64(["-msse3"])
- + if_ios_x86_64(["-msse4.1"])
- + select({
+def tf_copts(android_optimization_level_override = "-O2", is_external = False):
+ # For compatibility reasons, android_optimization_level_override
+ # is currently only being set for Android.
+ # To clear this value, and allow the CROSSTOOL default
+ # to be used, pass android_optimization_level_override=None
+ android_copts = [
+ "-std=c++11",
+ "-DTF_LEAN_BINARY",
+ "-Wno-narrowing",
+ "-fomit-frame-pointer",
+ ]
+ if android_optimization_level_override:
+ android_copts.append(android_optimization_level_override)
+ return (
+ if_not_windows([
+ "-DEIGEN_AVOID_STL_ARRAY",
+ "-Iexternal/gemmlowp",
+ "-Wno-sign-compare",
+ "-fno-exceptions",
+ "-ftemplate-depth=900",
+ ]) +
+ if_cuda(["-DGOOGLE_CUDA=1"]) +
+ if_tensorrt(["-DGOOGLE_TENSORRT=1"]) +
+ if_mkl(["-DINTEL_MKL=1", "-DEIGEN_USE_VML"]) +
+ if_mkl_open_source_only(["-DINTEL_MKL_DNN_ONLY"]) +
+ if_ngraph(["-DINTEL_NGRAPH=1"]) +
+ if_mkl_lnx_x64(["-fopenmp"]) +
+ if_android_arm(["-mfpu=neon"]) +
+ if_linux_x86_64(["-msse3"]) +
+ if_ios_x86_64(["-msse4.1"]) +
+ select({
clean_dep("//tensorflow:framework_shared_object"): [],
"//conditions:default": ["-DTENSORFLOW_MONOLITHIC_BUILD"],
- })
- + select({
+ }) +
+ select({
clean_dep("//tensorflow:android"): android_copts,
clean_dep("//tensorflow:darwin"): [],
clean_dep("//tensorflow:windows"): get_win_copts(is_external),
- clean_dep("//tensorflow:windows_msvc"): get_win_copts(is_external),
clean_dep("//tensorflow:ios"): ["-std=c++11"],
clean_dep("//tensorflow:no_lgpl_deps"): ["-D__TENSORFLOW_NO_LGPL_DEPS__", "-pthread"],
- "//conditions:default": ["-pthread"]
- }))
-
+ "//conditions:default": ["-pthread"],
+ })
+ )
def tfe_xla_copts():
- return select({
- "//tensorflow:with_xla_support": ["-DTENSORFLOW_EAGER_USE_XLA"],
- "//conditions:default": [],
- })
+ return select({
+ "//tensorflow:with_xla_support": ["-DTENSORFLOW_EAGER_USE_XLA"],
+ "//conditions:default": [],
+ })
def tf_opts_nortti_if_android():
- return if_android([
- "-fno-rtti",
- "-DGOOGLE_PROTOBUF_NO_RTTI",
- "-DGOOGLE_PROTOBUF_NO_STATIC_INITIALIZER",
- ])
+ return if_android([
+ "-fno-rtti",
+ "-DGOOGLE_PROTOBUF_NO_RTTI",
+ "-DGOOGLE_PROTOBUF_NO_STATIC_INITIALIZER",
+ ])
# LINT.ThenChange(//tensorflow/contrib/android/cmake/CMakeLists.txt)
def tf_features_nomodules_if_android():
- return if_android(["-use_header_modules"])
+ return if_android(["-use_header_modules"])
# Given a list of "op_lib_names" (a list of files in the ops directory
# without their .cc extensions), generate a library for that file.
-def tf_gen_op_libs(op_lib_names, deps=None, is_external=True):
- # Make library out of each op so it can also be used to generate wrappers
- # for various languages.
- if not deps:
- deps = []
- for n in op_lib_names:
- native.cc_library(
- name=n + "_op_lib",
- copts=tf_copts(is_external=is_external),
- srcs=["ops/" + n + ".cc"],
- deps=deps + [clean_dep("//tensorflow/core:framework")],
- visibility=["//visibility:public"],
- alwayslink=1,
- linkstatic=1,)
+def tf_gen_op_libs(op_lib_names, deps = None, is_external = True):
+ # Make library out of each op so it can also be used to generate wrappers
+ # for various languages.
+ if not deps:
+ deps = []
+ for n in op_lib_names:
+ native.cc_library(
+ name = n + "_op_lib",
+ copts = tf_copts(is_external = is_external),
+ srcs = ["ops/" + n + ".cc"],
+ deps = deps + [clean_dep("//tensorflow/core:framework")],
+ visibility = ["//visibility:public"],
+ alwayslink = 1,
+ linkstatic = 1,
+ )
def _make_search_paths(prefix, levels_to_root):
- return ",".join(
- ["-rpath,%s/%s" % (prefix, "/".join([".."] * search_level))
- for search_level in range(levels_to_root + 1)])
+ return ",".join(
+ [
+ "-rpath,%s/%s" % (prefix, "/".join([".."] * search_level))
+ for search_level in range(levels_to_root + 1)
+ ],
+ )
def _rpath_linkopts(name):
- # Search parent directories up to the TensorFlow root directory for shared
- # object dependencies, even if this op shared object is deeply nested
- # (e.g. tensorflow/contrib/package:python/ops/_op_lib.so). tensorflow/ is then
- # the root and tensorflow/libtensorflow_framework.so should exist when
- # deployed. Other shared object dependencies (e.g. shared between contrib/
- # ops) are picked up as long as they are in either the same or a parent
- # directory in the tensorflow/ tree.
- levels_to_root = native.package_name().count("/") + name.count("/")
- return select({
- clean_dep("//tensorflow:darwin"): [
- "-Wl,%s" % (_make_search_paths("@loader_path", levels_to_root),),
- ],
- clean_dep("//tensorflow:windows"): [],
- clean_dep("//tensorflow:windows_msvc"): [],
- "//conditions:default": [
- "-Wl,%s" % (_make_search_paths("$$ORIGIN", levels_to_root),),
- ],
- })
+ # Search parent directories up to the TensorFlow root directory for shared
+ # object dependencies, even if this op shared object is deeply nested
+ # (e.g. tensorflow/contrib/package:python/ops/_op_lib.so). tensorflow/ is then
+ # the root and tensorflow/libtensorflow_framework.so should exist when
+ # deployed. Other shared object dependencies (e.g. shared between contrib/
+ # ops) are picked up as long as they are in either the same or a parent
+ # directory in the tensorflow/ tree.
+ levels_to_root = native.package_name().count("/") + name.count("/")
+ return select({
+ clean_dep("//tensorflow:darwin"): [
+ "-Wl,%s" % (_make_search_paths("@loader_path", levels_to_root),),
+ ],
+ clean_dep("//tensorflow:windows"): [],
+ "//conditions:default": [
+ "-Wl,%s" % (_make_search_paths("$$ORIGIN", levels_to_root),),
+ ],
+ })
# Bazel-generated shared objects which must be linked into TensorFlow binaries
# to define symbols from //tensorflow/core:framework and //tensorflow/core:lib.
def tf_binary_additional_srcs():
- return if_static(
- extra_deps=[],
- otherwise=[
- clean_dep("//tensorflow:libtensorflow_framework.so"),
- ])
-
+ return if_static(
+ extra_deps = [],
+ otherwise = [
+ clean_dep("//tensorflow:libtensorflow_framework.so"),
+ ],
+ )
# Helper functions to add kernel dependencies to tf binaries when using dynamic
# kernel linking.
def tf_binary_dynamic_kernel_dsos(kernels):
- return if_dynamic_kernels(
- extra_deps=["libtfkernel_%s.so" % clean_dep(k) for k in kernels],
- otherwise=[])
+ return if_dynamic_kernels(
+ extra_deps = ["libtfkernel_%s.so" % clean_dep(k) for k in kernels],
+ otherwise = [],
+ )
# Helper functions to add kernel dependencies to tf binaries when using static
# kernel linking.
def tf_binary_dynamic_kernel_deps(kernels):
- return if_dynamic_kernels(
- extra_deps=[],
- otherwise=kernels)
+ return if_dynamic_kernels(
+ extra_deps = [],
+ otherwise = kernels,
+ )
def tf_cc_shared_object(
- name,
- srcs=[],
- deps=[],
- data=[],
- linkopts=[],
- framework_so=tf_binary_additional_srcs(),
- kernels=[],
- **kwargs):
- native.cc_binary(
- name=name,
- srcs=srcs + framework_so,
- deps=deps + tf_binary_dynamic_kernel_deps(kernels),
- linkshared = 1,
- data = data + tf_binary_dynamic_kernel_dsos(kernels),
- linkopts=linkopts + _rpath_linkopts(name) + select({
- clean_dep("//tensorflow:darwin"): [
- "-Wl,-install_name,@rpath/" + name.split("/")[-1],
- ],
- clean_dep("//tensorflow:windows"): [],
- "//conditions:default": [
- "-Wl,-soname," + name.split("/")[-1],
- ],
- }),
- **kwargs)
+ name,
+ srcs = [],
+ deps = [],
+ data = [],
+ linkopts = [],
+ framework_so = tf_binary_additional_srcs(),
+ kernels = [],
+ **kwargs):
+ native.cc_binary(
+ name = name,
+ srcs = srcs + framework_so,
+ deps = deps + tf_binary_dynamic_kernel_deps(kernels),
+ linkshared = 1,
+ data = data + tf_binary_dynamic_kernel_dsos(kernels),
+ linkopts = linkopts + _rpath_linkopts(name) + select({
+ clean_dep("//tensorflow:darwin"): [
+ "-Wl,-install_name,@rpath/" + name.split("/")[-1],
+ ],
+ clean_dep("//tensorflow:windows"): [],
+ "//conditions:default": [
+ "-Wl,-soname," + name.split("/")[-1],
+ ],
+ }),
+ **kwargs
+ )
register_extension_info(
extension_name = "tf_cc_shared_object",
@@ -376,26 +383,28 @@ register_extension_info(
# (//third_party/tensorflow:libtensorflow_framework.so) when not building
# statically. Also adds linker options (rpaths) so that the framework shared
# object can be found.
-def tf_cc_binary(name,
- srcs=[],
- deps=[],
- data=[],
- linkopts=[],
- copts=tf_copts(),
- kernels=[],
- **kwargs):
- native.cc_binary(
- name=name,
- copts=copts,
- srcs=srcs + tf_binary_additional_srcs(),
- deps=deps + tf_binary_dynamic_kernel_deps(kernels) + if_mkl(
- [
- "//third_party/mkl:intel_binary_blob",
- ],
- ),
- data=data + tf_binary_dynamic_kernel_dsos(kernels),
- linkopts=linkopts + _rpath_linkopts(name),
- **kwargs)
+def tf_cc_binary(
+ name,
+ srcs = [],
+ deps = [],
+ data = [],
+ linkopts = [],
+ copts = tf_copts(),
+ kernels = [],
+ **kwargs):
+ native.cc_binary(
+ name = name,
+ copts = copts,
+ srcs = srcs + tf_binary_additional_srcs(),
+ deps = deps + tf_binary_dynamic_kernel_deps(kernels) + if_mkl_ml(
+ [
+ "//third_party/intel_mkl_ml",
+ ],
+ ),
+ data = data + tf_binary_dynamic_kernel_dsos(kernels),
+ linkopts = linkopts + _rpath_linkopts(name),
+ **kwargs
+ )
register_extension_info(
extension_name = "tf_cc_binary",
@@ -405,64 +414,72 @@ register_extension_info(
# A simple wrap around native.cc_binary rule.
# When using this rule, you should realize it doesn't link to any tensorflow
# dependencies by default.
-def tf_native_cc_binary(name,
- copts=tf_copts(),
- **kwargs):
- native.cc_binary(
- name=name,
- copts=copts,
- **kwargs)
+def tf_native_cc_binary(
+ name,
+ copts = tf_copts(),
+ **kwargs):
+ native.cc_binary(
+ name = name,
+ copts = copts,
+ **kwargs
+ )
register_extension_info(
extension_name = "tf_native_cc_binary",
label_regex_for_dep = "{extension_name}.*",
)
-def tf_gen_op_wrapper_cc(name,
- out_ops_file,
- pkg="",
- op_gen=clean_dep("//tensorflow/cc:cc_op_gen_main"),
- deps=None,
- include_internal_ops=0,
- # ApiDefs will be loaded in the order specified in this list.
- api_def_srcs=[]):
- # Construct an op generator binary for these ops.
- tool = out_ops_file + "_gen_cc"
- if deps == None:
- deps = [pkg + ":" + name + "_op_lib"]
- tf_cc_binary(
- name=tool,
- copts=tf_copts(),
- linkopts=if_not_windows(["-lm"]),
- linkstatic=1, # Faster to link this one-time-use binary dynamically
- deps=[op_gen] + deps)
-
- srcs = api_def_srcs[:]
-
- if not api_def_srcs:
- api_def_args_str = ","
- else:
- api_def_args = []
- for api_def_src in api_def_srcs:
- # Add directory of the first ApiDef source to args.
- # We are assuming all ApiDefs in a single api_def_src are in the
- # same directory.
- api_def_args.append(
- " $$(dirname $$(echo $(locations " + api_def_src +
- ") | cut -d\" \" -f1))")
- api_def_args_str = ",".join(api_def_args)
-
- native.genrule(
- name=name + "_genrule",
- outs=[
- out_ops_file + ".h", out_ops_file + ".cc",
- out_ops_file + "_internal.h", out_ops_file + "_internal.cc"
- ],
- srcs=srcs,
- tools=[":" + tool] + tf_binary_additional_srcs(),
- cmd=("$(location :" + tool + ") $(location :" + out_ops_file + ".h) " +
- "$(location :" + out_ops_file + ".cc) " +
- str(include_internal_ops) + " " + api_def_args_str))
+def tf_gen_op_wrapper_cc(
+ name,
+ out_ops_file,
+ pkg = "",
+ op_gen = clean_dep("//tensorflow/cc:cc_op_gen_main"),
+ deps = None,
+ include_internal_ops = 0,
+ # ApiDefs will be loaded in the order specified in this list.
+ api_def_srcs = []):
+ # Construct an op generator binary for these ops.
+ tool = out_ops_file + "_gen_cc"
+ if deps == None:
+ deps = [pkg + ":" + name + "_op_lib"]
+ tf_cc_binary(
+ name = tool,
+ copts = tf_copts(),
+ linkopts = if_not_windows(["-lm"]),
+ linkstatic = 1, # Faster to link this one-time-use binary dynamically
+ deps = [op_gen] + deps,
+ )
+
+ srcs = api_def_srcs[:]
+
+ if not api_def_srcs:
+ api_def_args_str = ","
+ else:
+ api_def_args = []
+ for api_def_src in api_def_srcs:
+ # Add directory of the first ApiDef source to args.
+ # We are assuming all ApiDefs in a single api_def_src are in the
+ # same directory.
+ api_def_args.append(
+ " $$(dirname $$(echo $(locations " + api_def_src +
+ ") | cut -d\" \" -f1))",
+ )
+ api_def_args_str = ",".join(api_def_args)
+
+ native.genrule(
+ name = name + "_genrule",
+ outs = [
+ out_ops_file + ".h",
+ out_ops_file + ".cc",
+ out_ops_file + "_internal.h",
+ out_ops_file + "_internal.cc",
+ ],
+ srcs = srcs,
+ tools = [":" + tool] + tf_binary_additional_srcs(),
+ cmd = ("$(location :" + tool + ") $(location :" + out_ops_file + ".h) " +
+ "$(location :" + out_ops_file + ".cc) " +
+ str(include_internal_ops) + " " + api_def_args_str),
+ )
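A rough call-site sketch for this macro; the op library name and package below are assumptions for illustration, not values from this file:

    tf_gen_op_wrapper_cc(
        name = "array_ops",  # assumes an "array_ops_op_lib" target exists in pkg
        out_ops_file = "ops/array_ops",
        pkg = "//tensorflow/core",
    )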
# Given a list of "op_lib_names" (a list of files in the ops directory
# without their .cc extensions), generate individual C++ .cc and .h
@@ -491,68 +508,72 @@ def tf_gen_op_wrapper_cc(name,
# "ops/math_ops_internal.h" ],
# deps = [ ... ])
# TODO(joshl): Cleaner approach for hidden ops.
-def tf_gen_op_wrappers_cc(name,
- op_lib_names=[],
- other_srcs=[],
- other_hdrs=[],
- pkg="",
- deps=[
- clean_dep("//tensorflow/cc:ops"),
- clean_dep("//tensorflow/cc:scope"),
- clean_dep("//tensorflow/cc:const_op"),
- ],
- op_gen=clean_dep("//tensorflow/cc:cc_op_gen_main"),
- include_internal_ops=0,
- visibility=None,
- # ApiDefs will be loaded in the order apecified in this list.
- api_def_srcs=[]):
- subsrcs = other_srcs[:]
- subhdrs = other_hdrs[:]
- internalsrcs = []
- internalhdrs = []
- for n in op_lib_names:
- tf_gen_op_wrapper_cc(
- n,
- "ops/" + n,
- pkg=pkg,
- op_gen=op_gen,
- include_internal_ops=include_internal_ops,
- api_def_srcs=api_def_srcs)
- subsrcs += ["ops/" + n + ".cc"]
- subhdrs += ["ops/" + n + ".h"]
- internalsrcs += ["ops/" + n + "_internal.cc"]
- internalhdrs += ["ops/" + n + "_internal.h"]
-
- native.cc_library(
- name=name,
- srcs=subsrcs,
- hdrs=subhdrs,
- deps=deps + if_not_android([
- clean_dep("//tensorflow/core:core_cpu"),
- clean_dep("//tensorflow/core:framework"),
- clean_dep("//tensorflow/core:lib"),
- clean_dep("//tensorflow/core:protos_all_cc"),
- ]) + if_android([
- clean_dep("//tensorflow/core:android_tensorflow_lib"),
- ]),
- copts=tf_copts(),
- alwayslink=1,
- visibility=visibility)
- native.cc_library(
- name=name + "_internal",
- srcs=internalsrcs,
- hdrs=internalhdrs,
- deps=deps + if_not_android([
- clean_dep("//tensorflow/core:core_cpu"),
- clean_dep("//tensorflow/core:framework"),
- clean_dep("//tensorflow/core:lib"),
- clean_dep("//tensorflow/core:protos_all_cc"),
- ]) + if_android([
- clean_dep("//tensorflow/core:android_tensorflow_lib"),
- ]),
- copts=tf_copts(),
- alwayslink=1,
- visibility=[clean_dep("//tensorflow:internal")])
+def tf_gen_op_wrappers_cc(
+ name,
+ op_lib_names = [],
+ other_srcs = [],
+ other_hdrs = [],
+ pkg = "",
+ deps = [
+ clean_dep("//tensorflow/cc:ops"),
+ clean_dep("//tensorflow/cc:scope"),
+ clean_dep("//tensorflow/cc:const_op"),
+ ],
+ op_gen = clean_dep("//tensorflow/cc:cc_op_gen_main"),
+ include_internal_ops = 0,
+ visibility = None,
+ # ApiDefs will be loaded in the order specified in this list.
+ api_def_srcs = []):
+ subsrcs = other_srcs[:]
+ subhdrs = other_hdrs[:]
+ internalsrcs = []
+ internalhdrs = []
+ for n in op_lib_names:
+ tf_gen_op_wrapper_cc(
+ n,
+ "ops/" + n,
+ pkg = pkg,
+ op_gen = op_gen,
+ include_internal_ops = include_internal_ops,
+ api_def_srcs = api_def_srcs,
+ )
+ subsrcs += ["ops/" + n + ".cc"]
+ subhdrs += ["ops/" + n + ".h"]
+ internalsrcs += ["ops/" + n + "_internal.cc"]
+ internalhdrs += ["ops/" + n + "_internal.h"]
+
+ native.cc_library(
+ name = name,
+ srcs = subsrcs,
+ hdrs = subhdrs,
+ deps = deps + if_not_android([
+ clean_dep("//tensorflow/core:core_cpu"),
+ clean_dep("//tensorflow/core:framework"),
+ clean_dep("//tensorflow/core:lib"),
+ clean_dep("//tensorflow/core:protos_all_cc"),
+ ]) + if_android([
+ clean_dep("//tensorflow/core:android_tensorflow_lib"),
+ ]),
+ copts = tf_copts(),
+ alwayslink = 1,
+ visibility = visibility,
+ )
+ native.cc_library(
+ name = name + "_internal",
+ srcs = internalsrcs,
+ hdrs = internalhdrs,
+ deps = deps + if_not_android([
+ clean_dep("//tensorflow/core:core_cpu"),
+ clean_dep("//tensorflow/core:framework"),
+ clean_dep("//tensorflow/core:lib"),
+ clean_dep("//tensorflow/core:protos_all_cc"),
+ ]) + if_android([
+ clean_dep("//tensorflow/core:android_tensorflow_lib"),
+ ]),
+ copts = tf_copts(),
+ alwayslink = 1,
+ visibility = [clean_dep("//tensorflow:internal")],
+ )
# Generates a Python library target wrapping the ops registered in "deps".
#
@@ -578,96 +599,102 @@ def tf_gen_op_wrappers_cc(name,
# is invalid to specify both "hidden" and "op_whitelist".
# cc_linkopts: Optional linkopts to be added to tf_cc_binary that contains the
# specified ops.
-def tf_gen_op_wrapper_py(name,
- out=None,
- hidden=None,
- visibility=None,
- deps=[],
- require_shape_functions=False,
- hidden_file=None,
- generated_target_name=None,
- op_whitelist=[],
- cc_linkopts=[],
- api_def_srcs=[]):
- if (hidden or hidden_file) and op_whitelist:
- fail('Cannot pass specify both hidden and op_whitelist.')
-
- # Construct a cc_binary containing the specified ops.
- tool_name = "gen_" + name + "_py_wrappers_cc"
- if not deps:
- deps = [str(Label("//tensorflow/core:" + name + "_op_lib"))]
- tf_cc_binary(
- name=tool_name,
- linkopts=if_not_windows(["-lm"]) + cc_linkopts,
- copts=tf_copts(),
- linkstatic=1, # Faster to link this one-time-use binary dynamically
- deps=([
- clean_dep("//tensorflow/core:framework"),
- clean_dep("//tensorflow/python:python_op_gen_main")
- ] + deps),
- visibility=[clean_dep("//tensorflow:internal")],)
-
- # Invoke the previous cc_binary to generate a python file.
- if not out:
- out = "ops/gen_" + name + ".py"
-
- if hidden:
- op_list_arg = ",".join(hidden)
- op_list_is_whitelist = False
- elif op_whitelist:
- op_list_arg = ",".join(op_whitelist)
- op_list_is_whitelist = True
- else:
- op_list_arg = "''"
- op_list_is_whitelist = False
-
- # Prepare ApiDef directories to pass to the genrule.
- if not api_def_srcs:
- api_def_args_str = ","
- else:
- api_def_args = []
- for api_def_src in api_def_srcs:
- # Add directory of the first ApiDef source to args.
- # We are assuming all ApiDefs in a single api_def_src are in the
- # same directory.
- api_def_args.append(
- "$$(dirname $$(echo $(locations " + api_def_src +
- ") | cut -d\" \" -f1))")
- api_def_args_str = ",".join(api_def_args)
-
- if hidden_file:
- # `hidden_file` is file containing a list of op names to be hidden in the
- # generated module.
- native.genrule(
- name=name + "_pygenrule",
- outs=[out],
- srcs=api_def_srcs + [hidden_file],
- tools=[tool_name] + tf_binary_additional_srcs(),
- cmd=("$(location " + tool_name + ") " + api_def_args_str +
- " @$(location " + hidden_file + ") " +
- ("1" if require_shape_functions else "0") + " > $@"))
- else:
- native.genrule(
- name=name + "_pygenrule",
- outs=[out],
- srcs=api_def_srcs,
- tools=[tool_name] + tf_binary_additional_srcs(),
- cmd=("$(location " + tool_name + ") " + api_def_args_str + " " +
- op_list_arg + " " +
- ("1" if require_shape_functions else "0") + " " +
- ("1" if op_list_is_whitelist else "0") + " > $@"))
-
- # Make a py_library out of the generated python file.
- if not generated_target_name:
- generated_target_name = name
- native.py_library(
- name=generated_target_name,
- srcs=[out],
- srcs_version="PY2AND3",
- visibility=visibility,
- deps=[
- clean_dep("//tensorflow/python:framework_for_generated_wrappers_v2"),
- ],)
+def tf_gen_op_wrapper_py(
+ name,
+ out = None,
+ hidden = None,
+ visibility = None,
+ deps = [],
+ require_shape_functions = False,
+ hidden_file = None,
+ generated_target_name = None,
+ op_whitelist = [],
+ cc_linkopts = [],
+ api_def_srcs = []):
+ if (hidden or hidden_file) and op_whitelist:
+ fail("Cannot pass specify both hidden and op_whitelist.")
+
+ # Construct a cc_binary containing the specified ops.
+ tool_name = "gen_" + name + "_py_wrappers_cc"
+ if not deps:
+ deps = [str(Label("//tensorflow/core:" + name + "_op_lib"))]
+ tf_cc_binary(
+ name = tool_name,
+ linkopts = if_not_windows(["-lm"]) + cc_linkopts,
+ copts = tf_copts(),
+ linkstatic = 1, # Faster to link this one-time-use binary statically
+ deps = ([
+ clean_dep("//tensorflow/core:framework"),
+ clean_dep("//tensorflow/python:python_op_gen_main"),
+ ] + deps),
+ visibility = [clean_dep("//tensorflow:internal")],
+ )
+
+ # Invoke the previous cc_binary to generate a python file.
+ if not out:
+ out = "ops/gen_" + name + ".py"
+
+ if hidden:
+ op_list_arg = ",".join(hidden)
+ op_list_is_whitelist = False
+ elif op_whitelist:
+ op_list_arg = ",".join(op_whitelist)
+ op_list_is_whitelist = True
+ else:
+ op_list_arg = "''"
+ op_list_is_whitelist = False
+
+ # Prepare ApiDef directories to pass to the genrule.
+ if not api_def_srcs:
+ api_def_args_str = ","
+ else:
+ api_def_args = []
+ for api_def_src in api_def_srcs:
+ # Add directory of the first ApiDef source to args.
+ # We are assuming all ApiDefs in a single api_def_src are in the
+ # same directory.
+ api_def_args.append(
+ "$$(dirname $$(echo $(locations " + api_def_src +
+ ") | cut -d\" \" -f1))",
+ )
+ api_def_args_str = ",".join(api_def_args)
+
+ if hidden_file:
+ # `hidden_file` is a file containing a list of op names to be hidden in the
+ # generated module.
+ native.genrule(
+ name = name + "_pygenrule",
+ outs = [out],
+ srcs = api_def_srcs + [hidden_file],
+ tools = [tool_name] + tf_binary_additional_srcs(),
+ cmd = ("$(location " + tool_name + ") " + api_def_args_str +
+ " @$(location " + hidden_file + ") " +
+ ("1" if require_shape_functions else "0") + " > $@"),
+ )
+ else:
+ native.genrule(
+ name = name + "_pygenrule",
+ outs = [out],
+ srcs = api_def_srcs,
+ tools = [tool_name] + tf_binary_additional_srcs(),
+ cmd = ("$(location " + tool_name + ") " + api_def_args_str + " " +
+ op_list_arg + " " +
+ ("1" if require_shape_functions else "0") + " " +
+ ("1" if op_list_is_whitelist else "0") + " > $@"),
+ )
+
+ # Make a py_library out of the generated python file.
+ if not generated_target_name:
+ generated_target_name = name
+ native.py_library(
+ name = generated_target_name,
+ srcs = [out],
+ srcs_version = "PY2AND3",
+ visibility = visibility,
+ deps = [
+ clean_dep("//tensorflow/python:framework_for_generated_wrappers_v2"),
+ ],
+ )
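For orientation, a minimal hypothetical invocation (the op and hidden-op names are placeholders):

    tf_gen_op_wrapper_py(
        name = "sparse_ops",  # placeholder; generates ops/gen_sparse_ops.py by default
        hidden = ["SparseConcat"],  # placeholder op name to hide
        require_shape_functions = True,
    )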
# Define a bazel macro that creates cc_test for tensorflow.
#
@@ -678,53 +705,54 @@ def tf_gen_op_wrapper_py(name,
#
# TODO(opensource): we need to enable this to work around the hidden symbol
# __cudaRegisterFatBinary error. Needs more investigation.
-def tf_cc_test(name,
- srcs,
- deps,
- data=[],
- linkstatic=0,
- extra_copts=[],
- suffix="",
- linkopts=[],
- nocopts=None,
- kernels=[],
- **kwargs):
- native.cc_test(
- name="%s%s" % (name, suffix),
- srcs=srcs + tf_binary_additional_srcs(),
- copts=tf_copts() + extra_copts,
- linkopts=select({
- clean_dep("//tensorflow:android"): [
- "-pie",
- ],
- clean_dep("//tensorflow:windows"): [],
- clean_dep("//tensorflow:windows_msvc"): [],
- clean_dep("//tensorflow:darwin"): [
- "-lm",
- ],
- "//conditions:default": [
- "-lpthread",
- "-lm"
- ],
- }) + linkopts + _rpath_linkopts(name),
- deps=deps + tf_binary_dynamic_kernel_deps(kernels) + if_mkl(
- [
- "//third_party/mkl:intel_binary_blob",
- ],
- ),
- data=data + tf_binary_dynamic_kernel_dsos(kernels),
- # Nested select() statements seem not to be supported when passed to
- # linkstatic, and we already have a cuda select() passed in to this
- # function.
- linkstatic=linkstatic or select({
- # cc_tests with ".so"s in srcs incorrectly link on Darwin unless
- # linkstatic=1 (https://github.com/bazelbuild/bazel/issues/3450).
- # TODO(allenl): Remove Mac static linking when Bazel 0.6 is out.
- clean_dep("//tensorflow:darwin"): 1,
- "//conditions:default": 0,
- }),
- nocopts=nocopts,
- **kwargs)
+def tf_cc_test(
+ name,
+ srcs,
+ deps,
+ data = [],
+ linkstatic = 0,
+ extra_copts = [],
+ suffix = "",
+ linkopts = [],
+ nocopts = None,
+ kernels = [],
+ **kwargs):
+ native.cc_test(
+ name = "%s%s" % (name, suffix),
+ srcs = srcs + tf_binary_additional_srcs(),
+ copts = tf_copts() + extra_copts,
+ linkopts = select({
+ clean_dep("//tensorflow:android"): [
+ "-pie",
+ ],
+ clean_dep("//tensorflow:windows"): [],
+ clean_dep("//tensorflow:darwin"): [
+ "-lm",
+ ],
+ "//conditions:default": [
+ "-lpthread",
+ "-lm",
+ ],
+ }) + linkopts + _rpath_linkopts(name),
+ deps = deps + tf_binary_dynamic_kernel_deps(kernels) + if_mkl_ml(
+ [
+ "//third_party/intel_mkl_ml",
+ ],
+ ),
+ data = data + tf_binary_dynamic_kernel_dsos(kernels),
+ # Nested select() statements seem not to be supported when passed to
+ # linkstatic, and we already have a cuda select() passed in to this
+ # function.
+ linkstatic = linkstatic or select({
+ # cc_tests with ".so"s in srcs incorrectly link on Darwin unless
+ # linkstatic=1 (https://github.com/bazelbuild/bazel/issues/3450).
+ # TODO(allenl): Remove Mac static linking when Bazel 0.6 is out.
+ clean_dep("//tensorflow:darwin"): 1,
+ "//conditions:default": 0,
+ }),
+ nocopts = nocopts,
+ **kwargs
+ )
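A minimal sketch of a call site, assuming a conventional test-main dependency (all names here are illustrative only):

    tf_cc_test(
        name = "example_test",  # illustrative
        srcs = ["example_test.cc"],
        deps = ["//tensorflow/core:test_main"],  # assumed dependency
    )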
register_extension_info(
extension_name = "tf_cc_test",
@@ -733,107 +761,115 @@ register_extension_info(
# Part of the testing workflow requires a distinguishable name for the build
# rules that involve a GPU, even if otherwise identical to the base rule.
-def tf_cc_test_gpu(name,
- srcs,
- deps,
- linkstatic=0,
- tags=[],
- data=[],
- size="medium",
- suffix="",
- args=None):
- tf_cc_test(
- name,
- srcs,
- deps,
- linkstatic=linkstatic,
- tags=tags,
- data=data,
- size=size,
- suffix=suffix,
- args=args)
+def tf_cc_test_gpu(
+ name,
+ srcs,
+ deps,
+ linkstatic = 0,
+ tags = [],
+ data = [],
+ size = "medium",
+ suffix = "",
+ args = None):
+ tf_cc_test(
+ name,
+ srcs,
+ deps,
+ linkstatic = linkstatic,
+ tags = tags,
+ data = data,
+ size = size,
+ suffix = suffix,
+ args = args,
+ )
register_extension_info(
extension_name = "tf_cc_test_gpu",
label_regex_for_dep = "{extension_name}",
)
-def tf_cuda_cc_test(name,
- srcs=[],
- deps=[],
- tags=[],
- data=[],
- size="medium",
- extra_copts=[],
- linkstatic=0,
- args=[],
- linkopts=[]):
- tf_cc_test(
- name=name,
- srcs=srcs,
- deps=deps,
- tags=tags + ["manual"],
- data=data,
- size=size,
- extra_copts=extra_copts,
- linkstatic=linkstatic,
- linkopts=linkopts,
- args=args)
- tf_cc_test(
- name=name,
- srcs=srcs,
- suffix="_gpu",
- deps=deps + if_cuda([
- clean_dep("//tensorflow/core:gpu_runtime"),
- ]),
- linkstatic=select({
- # TODO(allenl): Remove Mac static linking when Bazel 0.6 is out.
- clean_dep("//tensorflow:darwin"): 1,
- "@local_config_cuda//cuda:using_nvcc": 1,
- "@local_config_cuda//cuda:using_clang": 1,
- "//conditions:default": 0,
- }),
- tags=tags + tf_cuda_tests_tags(),
- data=data,
- size=size,
- extra_copts=extra_copts,
- linkopts=linkopts,
- args=args)
+def tf_cuda_cc_test(
+ name,
+ srcs = [],
+ deps = [],
+ tags = [],
+ data = [],
+ size = "medium",
+ extra_copts = [],
+ linkstatic = 0,
+ args = [],
+ linkopts = []):
+ tf_cc_test(
+ name = name,
+ srcs = srcs,
+ deps = deps,
+ tags = tags + ["manual"],
+ data = data,
+ size = size,
+ extra_copts = extra_copts,
+ linkstatic = linkstatic,
+ linkopts = linkopts,
+ args = args,
+ )
+ tf_cc_test(
+ name = name,
+ srcs = srcs,
+ suffix = "_gpu",
+ deps = deps + if_cuda([
+ clean_dep("//tensorflow/core:gpu_runtime"),
+ ]),
+ linkstatic = select({
+ # TODO(allenl): Remove Mac static linking when Bazel 0.6 is out.
+ clean_dep("//tensorflow:darwin"): 1,
+ "@local_config_cuda//cuda:using_nvcc": 1,
+ "@local_config_cuda//cuda:using_clang": 1,
+ "//conditions:default": 0,
+ }),
+ tags = tags + tf_cuda_tests_tags(),
+ data = data,
+ size = size,
+ extra_copts = extra_copts,
+ linkopts = linkopts,
+ args = args,
+ )
register_extension_info(
extension_name = "tf_cuda_cc_test",
label_regex_for_dep = "{extension_name}",
)
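Sketch of a typical invocation; note that one call produces both the plain test and a "_gpu"-suffixed variant (names below are hypothetical):

    tf_cuda_cc_test(
        name = "stream_test",  # also creates "stream_test_gpu"
        srcs = ["stream_test.cc"],
        deps = ["//tensorflow/core:test_main"],  # assumed dependency
    )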
-def tf_cuda_only_cc_test(name,
- srcs=[],
- deps=[],
- tags=[],
- data=[],
- size="medium",
- linkstatic=0,
- args=[],
- kernels=[],
- linkopts=[]):
- native.cc_test(
- name="%s%s" % (name, "_gpu"),
- srcs=srcs + tf_binary_additional_srcs(),
- size=size,
- args=args,
- copts= _cuda_copts() + tf_copts(),
- data=data + tf_binary_dynamic_kernel_dsos(kernels),
- deps=deps + tf_binary_dynamic_kernel_deps(kernels) + if_cuda([
- clean_dep("//tensorflow/core:cuda"),
- clean_dep("//tensorflow/core:gpu_lib")]),
- linkopts=if_not_windows(["-lpthread", "-lm"]) + linkopts + _rpath_linkopts(name),
- linkstatic=linkstatic or select({
- # cc_tests with ".so"s in srcs incorrectly link on Darwin
- # unless linkstatic=1.
- # TODO(allenl): Remove Mac static linking when Bazel 0.6 is out.
- clean_dep("//tensorflow:darwin"): 1,
- "//conditions:default": 0,
- }),
- tags=tags + tf_cuda_tests_tags())
+def tf_cuda_only_cc_test(
+ name,
+ srcs = [],
+ deps = [],
+ tags = [],
+ data = [],
+ size = "medium",
+ linkstatic = 0,
+ args = [],
+ kernels = [],
+ linkopts = []):
+ native.cc_test(
+ name = "%s%s" % (name, "_gpu"),
+ srcs = srcs + tf_binary_additional_srcs(),
+ size = size,
+ args = args,
+ copts = _cuda_copts() + tf_copts(),
+ data = data + tf_binary_dynamic_kernel_dsos(kernels),
+ deps = deps + tf_binary_dynamic_kernel_deps(kernels) + if_cuda([
+ clean_dep("//tensorflow/core:cuda"),
+ clean_dep("//tensorflow/core:gpu_lib"),
+ ]),
+ linkopts = if_not_windows(["-lpthread", "-lm"]) + linkopts + _rpath_linkopts(name),
+ linkstatic = linkstatic or select({
+ # cc_tests with ".so"s in srcs incorrectly link on Darwin
+ # unless linkstatic=1.
+ # TODO(allenl): Remove Mac static linking when Bazel 0.6 is out.
+ clean_dep("//tensorflow:darwin"): 1,
+ "//conditions:default": 0,
+ }),
+ tags = tags + tf_cuda_tests_tags(),
+ )
register_extension_info(
extension_name = "tf_cuda_only_cc_test",
@@ -841,109 +877,112 @@ register_extension_info(
)
# Create a cc_test for each of the tensorflow test sources listed in "srcs"
-def tf_cc_tests(srcs,
- deps,
- name="",
- linkstatic=0,
- tags=[],
- size="medium",
- args=None,
- linkopts=[],
- nocopts=None):
- for src in srcs:
- tf_cc_test(
- name=src_to_test_name(src),
- srcs=[src],
- deps=deps,
- linkstatic=linkstatic,
- tags=tags,
- size=size,
- args=args,
- linkopts=linkopts,
- nocopts=nocopts)
-
-def tf_cc_test_mkl(srcs,
- deps,
- name="",
- data=[],
- linkstatic=0,
- tags=[],
- size="medium",
- kernels=[],
- args=None):
- # -fno-exceptions in nocopts breaks compilation if header modules are enabled.
- disable_header_modules = ["-use_header_modules"]
-
- for src in srcs:
- native.cc_test(
- name=src_to_test_name(src),
- srcs=if_mkl([src]) + tf_binary_additional_srcs(),
- copts=tf_copts(),
- linkopts=select({
- clean_dep("//tensorflow:android"): [
- "-pie",
- ],
- clean_dep("//tensorflow:windows"): [],
- clean_dep("//tensorflow:windows_msvc"): [],
- "//conditions:default": [
- "-lpthread",
- "-lm"
- ],
- }) + _rpath_linkopts(src_to_test_name(src)),
- deps=deps + tf_binary_dynamic_kernel_deps(kernels) + if_mkl(
- [
- "//third_party/mkl:intel_binary_blob",
- ],
- ),
- data=data + tf_binary_dynamic_kernel_dsos(kernels),
- linkstatic=linkstatic,
- tags=tags,
- size=size,
- args=args,
- features=disable_header_modules,
- nocopts="-fno-exceptions")
-
-
-def tf_cc_tests_gpu(srcs,
- deps,
- name="",
- linkstatic=0,
- tags=[],
- size="medium",
- args=None):
- tf_cc_tests(srcs, deps, linkstatic, tags=tags, size=size, args=args)
-
-def tf_cuda_cc_tests(srcs,
- deps,
- name="",
- tags=[],
- size="medium",
- linkstatic=0,
- args=None,
- linkopts=[]):
- for src in srcs:
- tf_cuda_cc_test(
- name=src_to_test_name(src),
- srcs=[src],
- deps=deps,
- tags=tags,
- size=size,
- linkstatic=linkstatic,
- args=args,
- linkopts=linkopts)
-
-def tf_java_test(name,
- srcs=[],
- deps=[],
- kernels=[],
- *args,
- **kwargs):
- native.java_test(
- name=name,
- srcs=srcs,
- deps=deps + tf_binary_additional_srcs() + tf_binary_dynamic_kernel_dsos(kernels) + tf_binary_dynamic_kernel_deps(kernels),
- *args,
- **kwargs)
+def tf_cc_tests(
+ srcs,
+ deps,
+ name = "",
+ linkstatic = 0,
+ tags = [],
+ size = "medium",
+ args = None,
+ linkopts = [],
+ nocopts = None):
+ for src in srcs:
+ tf_cc_test(
+ name = src_to_test_name(src),
+ srcs = [src],
+ deps = deps,
+ linkstatic = linkstatic,
+ tags = tags,
+ size = size,
+ args = args,
+ linkopts = linkopts,
+ nocopts = nocopts,
+ )
+
+def tf_cc_test_mkl(
+ srcs,
+ deps,
+ name = "",
+ data = [],
+ linkstatic = 0,
+ tags = [],
+ size = "medium",
+ kernels = [],
+ args = None):
+ # -fno-exceptions in nocopts breaks compilation if header modules are enabled.
+ disable_header_modules = ["-use_header_modules"]
+
+ for src in srcs:
+ native.cc_test(
+ name = src_to_test_name(src),
+ srcs = if_mkl([src]) + tf_binary_additional_srcs(),
+ copts = tf_copts(),
+ linkopts = select({
+ clean_dep("//tensorflow:android"): [
+ "-pie",
+ ],
+ clean_dep("//tensorflow:windows"): [],
+ "//conditions:default": [
+ "-lpthread",
+ "-lm",
+ ],
+ }) + _rpath_linkopts(src_to_test_name(src)),
+ deps = deps + tf_binary_dynamic_kernel_deps(kernels) + mkl_deps(),
+ data = data + tf_binary_dynamic_kernel_dsos(kernels),
+ linkstatic = linkstatic,
+ tags = tags,
+ size = size,
+ args = args,
+ features = disable_header_modules,
+ nocopts = "-fno-exceptions",
+ )
+
+def tf_cc_tests_gpu(
+ srcs,
+ deps,
+ name = "",
+ linkstatic = 0,
+ tags = [],
+ size = "medium",
+ args = None):
+ tf_cc_tests(srcs, deps, linkstatic = linkstatic, tags = tags, size = size, args = args)
+
+def tf_cuda_cc_tests(
+ srcs,
+ deps,
+ name = "",
+ tags = [],
+ size = "medium",
+ linkstatic = 0,
+ args = None,
+ linkopts = []):
+ for src in srcs:
+ tf_cuda_cc_test(
+ name = src_to_test_name(src),
+ srcs = [src],
+ deps = deps,
+ tags = tags,
+ size = size,
+ linkstatic = linkstatic,
+ args = args,
+ linkopts = linkopts,
+ )
+
+def tf_java_test(
+ name,
+ srcs = [],
+ deps = [],
+ kernels = [],
+ *args,
+ **kwargs):
+ native.java_test(
+ name = name,
+ srcs = srcs,
+ deps = deps + tf_binary_additional_srcs() + tf_binary_dynamic_kernel_dsos(kernels) + tf_binary_dynamic_kernel_deps(kernels),
+ *args,
+ **kwargs
+ )
register_extension_info(
extension_name = "tf_java_test",
@@ -951,85 +990,89 @@ register_extension_info(
)
def _cuda_copts():
- """Gets the appropriate set of copts for (maybe) CUDA compilation.
-
- If we're doing CUDA compilation, returns copts for our particular CUDA
- compiler. If we're not doing CUDA compilation, returns an empty list.
-
- """
- return cuda_default_copts() + select({
- "//conditions:default": [],
- "@local_config_cuda//cuda:using_nvcc": ([
- "-nvcc_options=relaxed-constexpr",
- "-nvcc_options=ftz=true",
- ]),
- "@local_config_cuda//cuda:using_clang": ([
- "-fcuda-flush-denormals-to-zero",
- ]),
- })
+ """Gets the appropriate set of copts for (maybe) CUDA compilation.
+
+ If we're doing CUDA compilation, returns copts for our particular CUDA
+ compiler. If we're not doing CUDA compilation, returns an empty list.
+
+ """
+ return cuda_default_copts() + select({
+ "//conditions:default": [],
+ "@local_config_cuda//cuda:using_nvcc": ([
+ "-nvcc_options=relaxed-constexpr",
+ "-nvcc_options=ftz=true",
+ ]),
+ "@local_config_cuda//cuda:using_clang": ([
+ "-fcuda-flush-denormals-to-zero",
+ ]),
+ })
# Build defs for TensorFlow kernels
# When this target is built using --config=cuda, a cc_library is built
# that passes -DGOOGLE_CUDA=1 and '-x cuda', linking in additional
# libraries needed by GPU kernels.
-def tf_gpu_kernel_library(srcs,
- copts=[],
- cuda_copts=[],
- deps=[],
- hdrs=[],
- **kwargs):
- copts = copts + _cuda_copts() + if_cuda(cuda_copts) + tf_copts()
- kwargs["features"] = kwargs.get("features", []) + ["-use_header_modules"]
-
- native.cc_library(
- srcs=srcs,
- hdrs=hdrs,
- copts=copts,
- deps=deps + if_cuda([
- clean_dep("//tensorflow/core:cuda"),
- clean_dep("//tensorflow/core:gpu_lib"),
- ]),
- alwayslink=1,
- **kwargs)
+def tf_gpu_kernel_library(
+ srcs,
+ copts = [],
+ cuda_copts = [],
+ deps = [],
+ hdrs = [],
+ **kwargs):
+ copts = copts + _cuda_copts() + if_cuda(cuda_copts) + tf_copts()
+ kwargs["features"] = kwargs.get("features", []) + ["-use_header_modules"]
+
+ native.cc_library(
+ srcs = srcs,
+ hdrs = hdrs,
+ copts = copts,
+ deps = deps + if_cuda([
+ clean_dep("//tensorflow/core:cuda"),
+ clean_dep("//tensorflow/core:gpu_lib"),
+ ]),
+ alwayslink = 1,
+ **kwargs
+ )
register_extension_info(
extension_name = "tf_gpu_kernel_library",
label_regex_for_dep = "{extension_name}",
)
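Illustrative only (file and target names are assumptions): a kernel library whose CUDA sources are compiled with the copts selected above:

    tf_gpu_kernel_library(
        name = "relu_op_gpu",  # hypothetical
        srcs = ["relu_op_gpu.cu.cc"],
        deps = ["//tensorflow/core:framework"],  # assumed dependency
    )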
-def tf_cuda_library(deps=None, cuda_deps=None, copts=tf_copts(), **kwargs):
- """Generate a cc_library with a conditional set of CUDA dependencies.
-
- When the library is built with --config=cuda:
-
- - Both deps and cuda_deps are used as dependencies.
- - The cuda runtime is added as a dependency (if necessary).
- - The library additionally passes -DGOOGLE_CUDA=1 to the list of copts.
- - In addition, when the library is also built with TensorRT enabled, it
- additionally passes -DGOOGLE_TENSORRT=1 to the list of copts.
-
- Args:
- - cuda_deps: BUILD dependencies which will be linked if and only if:
- '--config=cuda' is passed to the bazel command line.
- - deps: dependencies which will always be linked.
- - copts: copts always passed to the cc_library.
- - kwargs: Any other argument to cc_library.
- """
- if not deps:
- deps = []
- if not cuda_deps:
- cuda_deps = []
-
- kwargs["features"] = kwargs.get("features", []) + ["-use_header_modules"]
- native.cc_library(
- deps=deps + if_cuda(cuda_deps + [
- clean_dep("//tensorflow/core:cuda"),
- "@local_config_cuda//cuda:cuda_headers"
- ]),
- copts=(copts + if_cuda(["-DGOOGLE_CUDA=1"]) + if_mkl(["-DINTEL_MKL=1"]) +
- if_tensorrt(["-DGOOGLE_TENSORRT=1"])),
- **kwargs)
+def tf_cuda_library(deps = None, cuda_deps = None, copts = tf_copts(), **kwargs):
+ """Generate a cc_library with a conditional set of CUDA dependencies.
+
+ When the library is built with --config=cuda:
+
+ - Both deps and cuda_deps are used as dependencies.
+ - The cuda runtime is added as a dependency (if necessary).
+ - The library additionally passes -DGOOGLE_CUDA=1 to the list of copts.
+ - In addition, when the library is also built with TensorRT enabled, it
+ additionally passes -DGOOGLE_TENSORRT=1 to the list of copts.
+
+ Args:
+ - cuda_deps: BUILD dependencies which will be linked if and only if
+ '--config=cuda' is passed to the bazel command line.
+ - deps: dependencies which will always be linked.
+ - copts: copts always passed to the cc_library.
+ - kwargs: Any other argument to cc_library.
+ """
+ if not deps:
+ deps = []
+ if not cuda_deps:
+ cuda_deps = []
+
+ kwargs["features"] = kwargs.get("features", []) + ["-use_header_modules"]
+ native.cc_library(
+ deps = deps + if_cuda(cuda_deps + [
+ clean_dep("//tensorflow/core:cuda"),
+ "@local_config_cuda//cuda:cuda_headers",
+ ]),
+ copts = (copts + if_cuda(["-DGOOGLE_CUDA=1"]) + if_mkl(["-DINTEL_MKL=1"]) +
+ if_mkl_open_source_only(["-DINTEL_MKL_DNN_ONLY"]) +
+ if_tensorrt(["-DGOOGLE_TENSORRT=1"])),
+ **kwargs
+ )
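A sketch that mirrors the docstring above: cuda_deps are linked only under --config=cuda, while deps are always linked (target names are hypothetical):

    tf_cuda_library(
        name = "stream_util",  # hypothetical
        srcs = ["stream_util.cc"],
        deps = ["//tensorflow/core:lib"],  # always linked; assumed label
        cuda_deps = ["//tensorflow/core:gpu_lib"],  # linked only with --config=cuda
    )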
register_extension_info(
extension_name = "tf_cuda_library",
@@ -1047,126 +1090,138 @@ def tf_kernel_library(
copts = None,
is_external = False,
**kwargs):
- """A rule to build a TensorFlow OpKernel.
-
- May either specify srcs/hdrs or prefix. Similar to tf_cuda_library,
- but with alwayslink=1 by default. If prefix is specified:
- * prefix*.cc (except *.cu.cc) is added to srcs
- * prefix*.h (except *.cu.h) is added to hdrs
- * prefix*.cu.cc and prefix*.h (including *.cu.h) are added to gpu_srcs.
- With the exception that test files are excluded.
- For example, with prefix = "cast_op",
- * srcs = ["cast_op.cc"]
- * hdrs = ["cast_op.h"]
- * gpu_srcs = ["cast_op_gpu.cu.cc", "cast_op.h"]
- * "cast_op_test.cc" is excluded
- With prefix = "cwise_op"
- * srcs = ["cwise_op_abs.cc", ..., "cwise_op_tanh.cc"],
- * hdrs = ["cwise_ops.h", "cwise_ops_common.h"],
- * gpu_srcs = ["cwise_op_gpu_abs.cu.cc", ..., "cwise_op_gpu_tanh.cu.cc",
- "cwise_ops.h", "cwise_ops_common.h",
- "cwise_ops_gpu_common.cu.h"]
- * "cwise_ops_test.cc" is excluded
- """
- if not srcs:
- srcs = []
- if not hdrs:
- hdrs = []
- if not deps:
- deps = []
- if not copts:
- copts = []
- textual_hdrs = []
- copts = copts + tf_copts(is_external=is_external)
- if prefix:
- if native.glob([prefix + "*.cu.cc"], exclude=["*test*"]):
- if not gpu_srcs:
- gpu_srcs = []
- gpu_srcs = gpu_srcs + native.glob(
- [prefix + "*.cu.cc", prefix + "*.h"], exclude=[prefix + "*test*"])
- srcs = srcs + native.glob(
- [prefix + "*.cc"], exclude=[prefix + "*test*", prefix + "*.cu.cc"])
- hdrs = hdrs + native.glob(
+ """A rule to build a TensorFlow OpKernel.
+
+ May either specify srcs/hdrs or prefix. Similar to tf_cuda_library,
+ but with alwayslink=1 by default. If prefix is specified:
+ * prefix*.cc (except *.cu.cc) is added to srcs
+ * prefix*.h (except *.cu.h) is added to hdrs
+ * prefix*.cu.cc and prefix*.h (including *.cu.h) are added to gpu_srcs.
+ With the exception that test files are excluded.
+ For example, with prefix = "cast_op",
+ * srcs = ["cast_op.cc"]
+ * hdrs = ["cast_op.h"]
+ * gpu_srcs = ["cast_op_gpu.cu.cc", "cast_op.h"]
+ * "cast_op_test.cc" is excluded
+ With prefix = "cwise_op"
+ * srcs = ["cwise_op_abs.cc", ..., "cwise_op_tanh.cc"],
+ * hdrs = ["cwise_ops.h", "cwise_ops_common.h"],
+ * gpu_srcs = ["cwise_op_gpu_abs.cu.cc", ..., "cwise_op_gpu_tanh.cu.cc",
+ "cwise_ops.h", "cwise_ops_common.h",
+ "cwise_ops_gpu_common.cu.h"]
+ * "cwise_ops_test.cc" is excluded
+ """
+ if not srcs:
+ srcs = []
+ if not hdrs:
+ hdrs = []
+ if not deps:
+ deps = []
+ if not copts:
+ copts = []
+ textual_hdrs = []
+ copts = copts + tf_copts(is_external = is_external)
+ if prefix:
+ if native.glob([prefix + "*.cu.cc"], exclude = ["*test*"]):
+ if not gpu_srcs:
+ gpu_srcs = []
+ gpu_srcs = gpu_srcs + native.glob(
+ [prefix + "*.cu.cc", prefix + "*.h"],
+ exclude = [prefix + "*test*"],
+ )
+ srcs = srcs + native.glob(
+ [prefix + "*.cc"],
+ exclude = [prefix + "*test*", prefix + "*.cu.cc"],
+ )
+ hdrs = hdrs + native.glob(
[prefix + "*.h"],
exclude = [prefix + "*test*", prefix + "*.cu.h", prefix + "*impl.h"],
)
- textual_hdrs = native.glob(
+ textual_hdrs = native.glob(
[prefix + "*impl.h"],
exclude = [prefix + "*test*", prefix + "*.cu.h"],
)
- cuda_deps = [clean_dep("//tensorflow/core:gpu_lib")]
- if gpu_srcs:
- for gpu_src in gpu_srcs:
- if gpu_src.endswith(".cc") and not gpu_src.endswith(".cu.cc"):
- fail("{} not allowed in gpu_srcs. .cc sources must end with .cu.cc".
- format(gpu_src))
- tf_gpu_kernel_library(
- name=name + "_gpu", srcs=gpu_srcs, deps=deps, **kwargs)
- cuda_deps.extend([":" + name + "_gpu"])
- kwargs["tags"] = kwargs.get("tags", []) + [
- "req_dep=%s" % clean_dep("//tensorflow/core:gpu_lib"),
- "req_dep=@local_config_cuda//cuda:cuda_headers",
- ]
- tf_cuda_library(
- name=name,
- srcs=srcs,
- hdrs=hdrs,
- textual_hdrs = textual_hdrs,
- copts=copts,
- cuda_deps=cuda_deps,
- linkstatic=1, # Needed since alwayslink is broken in bazel b/27630669
- alwayslink=alwayslink,
- deps=deps,
- **kwargs)
-
- # TODO(gunan): CUDA dependency not clear here. Fix it.
- tf_cc_shared_object(
- name="libtfkernel_%s.so" % name,
- srcs=srcs + hdrs,
- copts=copts,
- deps=deps,
- tags=["manual", "notap"])
+ cuda_deps = [clean_dep("//tensorflow/core:gpu_lib")]
+ if gpu_srcs:
+ for gpu_src in gpu_srcs:
+ if gpu_src.endswith(".cc") and not gpu_src.endswith(".cu.cc"):
+ fail("{} not allowed in gpu_srcs. .cc sources must end with .cu.cc"
+ .format(gpu_src))
+ tf_gpu_kernel_library(
+ name = name + "_gpu",
+ srcs = gpu_srcs,
+ deps = deps,
+ **kwargs
+ )
+ cuda_deps.extend([":" + name + "_gpu"])
+ kwargs["tags"] = kwargs.get("tags", []) + [
+ "req_dep=%s" % clean_dep("//tensorflow/core:gpu_lib"),
+ "req_dep=@local_config_cuda//cuda:cuda_headers",
+ ]
+ tf_cuda_library(
+ name = name,
+ srcs = srcs,
+ hdrs = hdrs,
+ textual_hdrs = textual_hdrs,
+ copts = copts,
+ cuda_deps = cuda_deps,
+ linkstatic = 1, # Needed since alwayslink is broken in bazel b/27630669
+ alwayslink = alwayslink,
+ deps = deps,
+ **kwargs
+ )
+ # TODO(gunan): CUDA dependency not clear here. Fix it.
+ tf_cc_shared_object(
+ name = "libtfkernel_%s.so" % name,
+ srcs = srcs + hdrs,
+ copts = copts,
+ deps = deps,
+ tags = ["manual", "notap"],
+ )
register_extension_info(
extension_name = "tf_kernel_library",
label_regex_for_dep = "{extension_name}(_gpu)?",
)
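Following the cast_op example in the docstring above, a typical call reduces to:

    tf_kernel_library(
        name = "cast_op",
        prefix = "cast_op",  # globs cast_op*.cc/.h and cast_op*.cu.cc as described
        deps = ["//tensorflow/core:framework"],  # assumed dependency
    )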
-def tf_mkl_kernel_library(name,
- prefix=None,
- srcs=None,
- hdrs=None,
- deps=None,
- alwayslink=1,
- copts=tf_copts(),
- nocopts="-fno-exceptions"):
- """A rule to build MKL-based TensorFlow kernel libraries."""
-
- if not bool(srcs):
- srcs = []
- if not bool(hdrs):
- hdrs = []
-
- if prefix:
- srcs = srcs + native.glob(
- [prefix + "*.cc"])
- hdrs = hdrs + native.glob(
- [prefix + "*.h"])
-
- # -fno-exceptions in nocopts breaks compilation if header modules are enabled.
- disable_header_modules = ["-use_header_modules"]
-
- native.cc_library(
- name=name,
- srcs=if_mkl(srcs),
- hdrs=hdrs,
- deps=deps,
- alwayslink=alwayslink,
- copts=copts,
- nocopts=nocopts,
- features = disable_header_modules
- )
+def tf_mkl_kernel_library(
+ name,
+ prefix = None,
+ srcs = None,
+ hdrs = None,
+ deps = None,
+ alwayslink = 1,
+ copts = tf_copts(),
+ nocopts = "-fno-exceptions"):
+ """A rule to build MKL-based TensorFlow kernel libraries."""
+
+ if not bool(srcs):
+ srcs = []
+ if not bool(hdrs):
+ hdrs = []
+
+ if prefix:
+ srcs = srcs + native.glob(
+ [prefix + "*.cc"],
+ )
+ hdrs = hdrs + native.glob(
+ [prefix + "*.h"],
+ )
+
+ # -fno-exceptions in nocopts breaks compilation if header modules are enabled.
+ disable_header_modules = ["-use_header_modules"]
+
+ native.cc_library(
+ name = name,
+ srcs = if_mkl(srcs),
+ hdrs = hdrs,
+ deps = deps,
+ alwayslink = alwayslink,
+ copts = copts,
+ nocopts = nocopts,
+ features = disable_header_modules,
+ )
register_extension_info(
extension_name = "tf_mkl_kernel_library",
@@ -1175,35 +1230,42 @@ register_extension_info(
# Bazel rules for building swig files.
def _py_wrap_cc_impl(ctx):
- srcs = ctx.files.srcs
- if len(srcs) != 1:
- fail("Exactly one SWIG source file label must be specified.", "srcs")
- module_name = ctx.attr.module_name
- src = ctx.files.srcs[0]
- inputs = depset([src])
- inputs += ctx.files.swig_includes
- for dep in ctx.attr.deps:
- inputs += dep.cc.transitive_headers
- inputs += ctx.files._swiglib
- inputs += ctx.files.toolchain_deps
- swig_include_dirs = depset(_get_repository_roots(ctx, inputs))
- swig_include_dirs += sorted([f.dirname for f in ctx.files._swiglib])
- args = [
- "-c++", "-python", "-module", module_name, "-o", ctx.outputs.cc_out.path,
- "-outdir", ctx.outputs.py_out.dirname
- ]
- args += ["-l" + f.path for f in ctx.files.swig_includes]
- args += ["-I" + i for i in swig_include_dirs]
- args += [src.path]
- outputs = [ctx.outputs.cc_out, ctx.outputs.py_out]
- ctx.action(
- executable=ctx.executable._swig,
- arguments=args,
- inputs=list(inputs),
- outputs=outputs,
- mnemonic="PythonSwig",
- progress_message="SWIGing " + src.path)
- return struct(files=depset(outputs))
+ srcs = ctx.files.srcs
+ if len(srcs) != 1:
+ fail("Exactly one SWIG source file label must be specified.", "srcs")
+ module_name = ctx.attr.module_name
+ src = ctx.files.srcs[0]
+ inputs = depset([src])
+ inputs += ctx.files.swig_includes
+ for dep in ctx.attr.deps:
+ inputs += dep.cc.transitive_headers
+ inputs += ctx.files._swiglib
+ inputs += ctx.files.toolchain_deps
+ swig_include_dirs = depset(_get_repository_roots(ctx, inputs))
+ swig_include_dirs += sorted([f.dirname for f in ctx.files._swiglib])
+ args = [
+ "-c++",
+ "-python",
+ "-module",
+ module_name,
+ "-o",
+ ctx.outputs.cc_out.path,
+ "-outdir",
+ ctx.outputs.py_out.dirname,
+ ]
+ args += ["-l" + f.path for f in ctx.files.swig_includes]
+ args += ["-I" + i for i in swig_include_dirs]
+ args += [src.path]
+ outputs = [ctx.outputs.cc_out, ctx.outputs.py_out]
+ ctx.action(
+ executable = ctx.executable._swig,
+ arguments = args,
+ inputs = list(inputs),
+ outputs = outputs,
+ mnemonic = "PythonSwig",
+ progress_message = "SWIGing " + src.path,
+ )
+ return struct(files = depset(outputs))
_py_wrap_cc = rule(
attrs = {
@@ -1241,40 +1303,40 @@ _py_wrap_cc = rule(
)
def _get_repository_roots(ctx, files):
- """Returns abnormal root directories under which files reside.
-
- When running a ctx.action, source files within the main repository are all
- relative to the current directory; however, files that are generated or exist
- in remote repositories will have their root directory be a subdirectory,
- e.g. bazel-out/local-fastbuild/genfiles/external/jpeg_archive. This function
- returns the set of these devious directories, ranked and sorted by popularity
- in order to hopefully minimize the number of I/O system calls within the
- compiler, because includes have quadratic complexity.
- """
- result = {}
- for f in files:
- root = f.root.path
- if root:
- if root not in result:
- result[root] = 0
- result[root] -= 1
- work = f.owner.workspace_root
- if work:
- if root:
- root += "/"
- root += work
- if root:
- if root not in result:
- result[root] = 0
- result[root] -= 1
- return [k for v, k in sorted([(v, k) for k, v in result.items()])]
+ """Returns abnormal root directories under which files reside.
+
+ When running a ctx.action, source files within the main repository are all
+ relative to the current directory; however, files that are generated or exist
+ in remote repositories will have their root directory be a subdirectory,
+ e.g. bazel-out/local-fastbuild/genfiles/external/jpeg_archive. This function
+ returns the set of these devious directories, ranked and sorted by popularity
+ in order to hopefully minimize the number of I/O system calls within the
+ compiler, because includes have quadratic complexity.
+ """
+ result = {}
+ for f in files:
+ root = f.root.path
+ if root:
+ if root not in result:
+ result[root] = 0
+ result[root] -= 1
+ work = f.owner.workspace_root
+ if work:
+ if root:
+ root += "/"
+ root += work
+ if root:
+ if root not in result:
+ result[root] = 0
+ result[root] -= 1
+ return [k for v, k in sorted([(v, k) for k, v in result.items()])]
# Bazel rule for collecting the header files that a target depends on.
def _transitive_hdrs_impl(ctx):
- outputs = depset()
- for dep in ctx.attr.deps:
- outputs += dep.cc.transitive_headers
- return struct(files=outputs)
+ outputs = depset()
+ for dep in ctx.attr.deps:
+ outputs += dep.cc.transitive_headers
+ return struct(files = outputs)
_transitive_hdrs = rule(
attrs = {
@@ -1286,52 +1348,54 @@ _transitive_hdrs = rule(
implementation = _transitive_hdrs_impl,
)
-def transitive_hdrs(name, deps=[], **kwargs):
- _transitive_hdrs(name=name + "_gather", deps=deps)
- native.filegroup(name=name, srcs=[":" + name + "_gather"])
+def transitive_hdrs(name, deps = [], **kwargs):
+ _transitive_hdrs(name = name + "_gather", deps = deps)
+ native.filegroup(name = name, srcs = [":" + name + "_gather"])
# Create a header only library that includes all the headers exported by
# the libraries in deps.
-def cc_header_only_library(name, deps=[], includes=[], **kwargs):
- _transitive_hdrs(name=name + "_gather", deps=deps)
- native.cc_library(name=name,
- hdrs=[":" + name + "_gather"],
- includes=includes,
- **kwargs)
+def cc_header_only_library(name, deps = [], includes = [], **kwargs):
+ _transitive_hdrs(name = name + "_gather", deps = deps)
+ native.cc_library(
+ name = name,
+ hdrs = [":" + name + "_gather"],
+ includes = includes,
+ **kwargs
+ )
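A minimal hypothetical use, exporting only the headers of an implementation target:

    cc_header_only_library(
        name = "foo_headers",  # hypothetical
        deps = [":foo"],  # hypothetical cc_library whose headers are gathered
    )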
def tf_custom_op_library_additional_deps():
- return [
+ return [
"@protobuf_archive//:protobuf_headers",
- clean_dep("//third_party/eigen3"),
- clean_dep("//tensorflow/core:framework_headers_lib"),
- ] + if_windows(["//tensorflow/python:pywrap_tensorflow_import_lib"])
+ clean_dep("//third_party/eigen3"),
+ clean_dep("//tensorflow/core:framework_headers_lib"),
+ ] + if_windows(["//tensorflow/python:pywrap_tensorflow_import_lib"])
# A list of targets that contains the implementation of
# tf_custom_op_library_additional_deps. It's used to generate a DEF file for
# exporting symbols from _pywrap_tensorflow.dll on Windows.
def tf_custom_op_library_additional_deps_impl():
- return [
+ return [
"@protobuf_archive//:protobuf",
"@nsync//:nsync_cpp",
- # for //third_party/eigen3
- clean_dep("//third_party/eigen3"),
- # for //tensorflow/core:framework_headers_lib
- clean_dep("//tensorflow/core:framework"),
- clean_dep("//tensorflow/core:reader_base"),
- ]
+ # for //third_party/eigen3
+ clean_dep("//third_party/eigen3"),
+ # for //tensorflow/core:framework_headers_lib
+ clean_dep("//tensorflow/core:framework"),
+ clean_dep("//tensorflow/core:reader_base"),
+ ]
# Traverse the dependency graph along the "deps" attribute of the
# target and return a struct with one field called 'tf_collected_deps'.
# tf_collected_deps will be the union of the deps of the current target
# and the tf_collected_deps of the dependencies of this target.
def _collect_deps_aspect_impl(target, ctx):
- alldeps = depset()
- if hasattr(ctx.rule.attr, "deps"):
- for dep in ctx.rule.attr.deps:
- alldeps = alldeps | depset([dep.label])
- if hasattr(dep, "tf_collected_deps"):
- alldeps = alldeps | dep.tf_collected_deps
- return struct(tf_collected_deps=alldeps)
+ alldeps = depset()
+ if hasattr(ctx.rule.attr, "deps"):
+ for dep in ctx.rule.attr.deps:
+ alldeps = alldeps | depset([dep.label])
+ if hasattr(dep, "tf_collected_deps"):
+ alldeps = alldeps | dep.tf_collected_deps
+ return struct(tf_collected_deps = alldeps)
collect_deps_aspect = aspect(
attr_aspects = ["deps"],
@@ -1339,24 +1403,26 @@ collect_deps_aspect = aspect(
)
def _dep_label(dep):
- label = dep.label
- return label.package + ":" + label.name
+ label = dep.label
+ return label.package + ":" + label.name
# This rule checks that the transitive dependencies of targets listed
# in the 'deps' attribute don't depend on the targets listed in
# the 'disallowed_deps' attribute.
def _check_deps_impl(ctx):
- disallowed_deps = ctx.attr.disallowed_deps
- for input_dep in ctx.attr.deps:
- if not hasattr(input_dep, "tf_collected_deps"):
- continue
- for dep in input_dep.tf_collected_deps:
- for disallowed_dep in disallowed_deps:
- if dep == disallowed_dep.label:
- fail(
- _dep_label(input_dep) + " cannot depend on " + _dep_label(
- disallowed_dep))
- return struct()
+ disallowed_deps = ctx.attr.disallowed_deps
+ for input_dep in ctx.attr.deps:
+ if not hasattr(input_dep, "tf_collected_deps"):
+ continue
+ for dep in input_dep.tf_collected_deps:
+ for disallowed_dep in disallowed_deps:
+ if dep == disallowed_dep.label:
+ fail(
+ _dep_label(input_dep) + " cannot depend on " + _dep_label(
+ disallowed_dep,
+ ),
+ )
+ return struct()
check_deps = rule(
_check_deps_impl,
@@ -1375,66 +1441,70 @@ check_deps = rule(
# Helper to build a dynamic library (.so) from the sources containing
# implementations of custom ops and kernels.
-def tf_custom_op_library(name, srcs=[], gpu_srcs=[], deps=[], linkopts=[]):
- cuda_deps = [
- clean_dep("//tensorflow/core:stream_executor_headers_lib"),
- "@local_config_cuda//cuda:cuda_headers",
- "@local_config_cuda//cuda:cudart_static",
- ]
- deps = deps + tf_custom_op_library_additional_deps()
- if gpu_srcs:
- basename = name.split(".")[0]
- native.cc_library(
- name=basename + "_gpu",
- srcs=gpu_srcs,
- copts=_cuda_copts() + if_tensorrt(["-DGOOGLE_TENSORRT=1"]),
- features = if_cuda(["-use_header_modules"]),
- deps=deps + if_cuda(cuda_deps))
- cuda_deps.extend([":" + basename + "_gpu"])
-
- check_deps(
- name=name + "_check_deps",
- deps=deps + if_cuda(cuda_deps),
- disallowed_deps=[
- clean_dep("//tensorflow/core:framework"),
- clean_dep("//tensorflow/core:lib")
- ])
- tf_cc_shared_object(
- name=name,
- srcs=srcs,
- deps=deps + if_cuda(cuda_deps),
- data=if_static([name + "_check_deps"]),
- copts=tf_copts(is_external=True),
- features = ["windows_export_all_symbols"],
- linkopts=linkopts + select({
- "//conditions:default": [
- "-lm",
- ],
- clean_dep("//tensorflow:windows"): [],
- clean_dep("//tensorflow:windows_msvc"): [],
- clean_dep("//tensorflow:darwin"): [],
- }),)
+def tf_custom_op_library(name, srcs = [], gpu_srcs = [], deps = [], linkopts = []):
+ cuda_deps = [
+ clean_dep("//tensorflow/core:stream_executor_headers_lib"),
+ "@local_config_cuda//cuda:cuda_headers",
+ "@local_config_cuda//cuda:cudart_static",
+ ]
+ deps = deps + tf_custom_op_library_additional_deps()
+ if gpu_srcs:
+ basename = name.split(".")[0]
+ native.cc_library(
+ name = basename + "_gpu",
+ srcs = gpu_srcs,
+ copts = _cuda_copts() + if_tensorrt(["-DGOOGLE_TENSORRT=1"]),
+ features = if_cuda(["-use_header_modules"]),
+ deps = deps + if_cuda(cuda_deps),
+ )
+ cuda_deps.extend([":" + basename + "_gpu"])
+
+ check_deps(
+ name = name + "_check_deps",
+ deps = deps + if_cuda(cuda_deps),
+ disallowed_deps = [
+ clean_dep("//tensorflow/core:framework"),
+ clean_dep("//tensorflow/core:lib"),
+ ],
+ )
+ tf_cc_shared_object(
+ name = name,
+ srcs = srcs,
+ deps = deps + if_cuda(cuda_deps),
+ data = if_static([name + "_check_deps"]),
+ copts = tf_copts(is_external = True),
+ features = ["windows_export_all_symbols"],
+ linkopts = linkopts + select({
+ "//conditions:default": [
+ "-lm",
+ ],
+ clean_dep("//tensorflow:windows"): [],
+ clean_dep("//tensorflow:darwin"): [],
+ }),
+ )
register_extension_info(
extension_name = "tf_custom_op_library",
label_regex_for_dep = "{extension_name}",
)
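Sketch of a typical call; the op name follows the common custom-op tutorial pattern and is illustrative, not from this change:

    tf_custom_op_library(
        name = "zero_out.so",  # illustrative .so target name
        srcs = ["zero_out.cc"],
    )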
-def tf_custom_op_py_library(name,
- srcs=[],
- dso=[],
- kernels=[],
- srcs_version="PY2AND3",
- visibility=None,
- deps=[]):
- kernels = kernels # unused argument
- native.py_library(
- name=name,
- data=dso,
- srcs=srcs,
- srcs_version=srcs_version,
- visibility=visibility,
- deps=deps,)
+def tf_custom_op_py_library(
+ name,
+ srcs = [],
+ dso = [],
+ kernels = [],
+ srcs_version = "PY2AND3",
+ visibility = None,
+ deps = []):
+ kernels = kernels # unused argument
+ native.py_library(
+ name = name,
+ data = dso,
+ srcs = srcs,
+ srcs_version = srcs_version,
+ visibility = visibility,
+ deps = deps,
+ )
register_extension_info(
extension_name = "tf_custom_op_py_library",
@@ -1448,119 +1518,127 @@ register_extension_info(
# This function attempts to append init_module_name to list of
# exported functions in version script
def _append_init_to_versionscript_impl(ctx):
- mod_name = ctx.attr.module_name
- if ctx.attr.is_version_script:
- ctx.actions.expand_template(
- template=ctx.file.template_file,
- output=ctx.outputs.versionscript,
- substitutions={
- "global:":"global:\n init_%s;\n PyInit_*;"%(mod_name),
- },
- is_executable=False,
- )
- else:
- ctx.actions.expand_template(
- template=ctx.file.template_file,
- output=ctx.outputs.versionscript,
- substitutions={
- "*tensorflow*":"*tensorflow*\ninit_%s\nPyInit_*\n"%(mod_name),
- },
- is_executable=False,
- )
-
+ mod_name = ctx.attr.module_name
+ if ctx.attr.is_version_script:
+ ctx.actions.expand_template(
+ template = ctx.file.template_file,
+ output = ctx.outputs.versionscript,
+ substitutions = {
+ "global:": "global:\n init_%s;\n PyInit_*;" % (mod_name),
+ },
+ is_executable = False,
+ )
+ else:
+ ctx.actions.expand_template(
+ template = ctx.file.template_file,
+ output = ctx.outputs.versionscript,
+ substitutions = {
+ "*tensorflow*": "*tensorflow*\ninit_%s\nPyInit_*\n" % (mod_name),
+ },
+ is_executable = False,
+ )
-_append_init_to_versionscript= rule(
- implementation=_append_init_to_versionscript_impl,
- attrs={
- "module_name":attr.string(mandatory=True),
- "template_file":attr.label(allow_files=True,single_file=True,mandatory=True),
- "is_version_script":attr.bool(default=True,
- doc='whether target is a ld version script or exported symbol list',
- mandatory=False),
- },
- outputs={"versionscript":"%{name}.lds"},
+_append_init_to_versionscript = rule(
+ implementation = _append_init_to_versionscript_impl,
+ attrs = {
+ "module_name": attr.string(mandatory = True),
+ "template_file": attr.label(allow_files = True, single_file = True, mandatory = True),
+ "is_version_script": attr.bool(
+ default = True,
+ doc = "whether target is a ld version script or exported symbol list",
+ mandatory = False,
+ ),
+ },
+ outputs = {"versionscript": "%{name}.lds"},
)
-def tf_py_wrap_cc(name,
- srcs,
- swig_includes=[],
- deps=[],
- copts=[],
- **kwargs):
- module_name = name.split("/")[-1]
- # Convert a rule name such as foo/bar/baz to foo/bar/_baz.so
- # and use that as the name for the rule producing the .so file.
- cc_library_name = "/".join(name.split("/")[:-1] + ["_" + module_name + ".so"])
- cc_library_pyd_name = "/".join(
- name.split("/")[:-1] + ["_" + module_name + ".pyd"])
- extra_deps = []
- _py_wrap_cc(
- name=name + "_py_wrap",
- srcs=srcs,
- swig_includes=swig_includes,
- deps=deps + extra_deps,
- toolchain_deps=["@bazel_tools//tools/cpp:current_cc_toolchain"],
- module_name=module_name,
- py_module_name=name)
- vscriptname=name+"_versionscript"
- _append_init_to_versionscript(
- name=vscriptname,
- module_name=module_name,
- is_version_script=select({
- "@local_config_cuda//cuda:darwin":False,
- "//conditions:default":True,
- }),
- template_file=select({
- "@local_config_cuda//cuda:darwin":clean_dep("//tensorflow:tf_exported_symbols.lds"),
- "//conditions:default":clean_dep("//tensorflow:tf_version_script.lds")
- })
- )
- extra_linkopts = select({
- "@local_config_cuda//cuda:darwin": [
- "-Wl,-exported_symbols_list",
- "$(location %s.lds)"%vscriptname,
- ],
- clean_dep("//tensorflow:windows"): [],
- clean_dep("//tensorflow:windows_msvc"): [],
- "//conditions:default": [
- "-Wl,--version-script",
- "$(location %s.lds)"%vscriptname,
- ]
- })
- extra_deps += select({
- "@local_config_cuda//cuda:darwin": [
- "%s.lds"%vscriptname,
- ],
- clean_dep("//tensorflow:windows"): [],
- clean_dep("//tensorflow:windows_msvc"): [],
- "//conditions:default": [
- "%s.lds"%vscriptname,
- ]
- })
-
- tf_cc_shared_object(
- name=cc_library_name,
- srcs=[module_name + ".cc"],
- copts=copts + if_not_windows([
- "-Wno-self-assign", "-Wno-sign-compare", "-Wno-write-strings"
- ]),
- linkopts=extra_linkopts,
- linkstatic=1,
- deps=deps + extra_deps,
- **kwargs)
- native.genrule(
- name="gen_" + cc_library_pyd_name,
- srcs=[":" + cc_library_name],
- outs=[cc_library_pyd_name],
- cmd="cp $< $@",)
- native.py_library(
- name=name,
- srcs=[":" + name + ".py"],
- srcs_version="PY2AND3",
- data=select({
- clean_dep("//tensorflow:windows"): [":" + cc_library_pyd_name],
- "//conditions:default": [":" + cc_library_name],
- }))
+def tf_py_wrap_cc(
+ name,
+ srcs,
+ swig_includes = [],
+ deps = [],
+ copts = [],
+ **kwargs):
+ module_name = name.split("/")[-1]
+
+ # Convert a rule name such as foo/bar/baz to foo/bar/_baz.so
+ # and use that as the name for the rule producing the .so file.
+ cc_library_name = "/".join(name.split("/")[:-1] + ["_" + module_name + ".so"])
+ cc_library_pyd_name = "/".join(
+ name.split("/")[:-1] + ["_" + module_name + ".pyd"],
+ )
+ extra_deps = []
+ _py_wrap_cc(
+ name = name + "_py_wrap",
+ srcs = srcs,
+ swig_includes = swig_includes,
+ deps = deps + extra_deps,
+ toolchain_deps = ["@bazel_tools//tools/cpp:current_cc_toolchain"],
+ module_name = module_name,
+ py_module_name = name,
+ )
+ vscriptname = name + "_versionscript"
+ _append_init_to_versionscript(
+ name = vscriptname,
+ module_name = module_name,
+ is_version_script = select({
+ "@local_config_cuda//cuda:darwin": False,
+ "//conditions:default": True,
+ }),
+ template_file = select({
+ "@local_config_cuda//cuda:darwin": clean_dep("//tensorflow:tf_exported_symbols.lds"),
+ "//conditions:default": clean_dep("//tensorflow:tf_version_script.lds"),
+ }),
+ )
+ extra_linkopts = select({
+ "@local_config_cuda//cuda:darwin": [
+ "-Wl,-exported_symbols_list",
+ "$(location %s.lds)" % vscriptname,
+ ],
+ clean_dep("//tensorflow:windows"): [],
+ "//conditions:default": [
+ "-Wl,--version-script",
+ "$(location %s.lds)" % vscriptname,
+ ],
+ })
+ extra_deps += select({
+ "@local_config_cuda//cuda:darwin": [
+ "%s.lds" % vscriptname,
+ ],
+ clean_dep("//tensorflow:windows"): [],
+ "//conditions:default": [
+ "%s.lds" % vscriptname,
+ ],
+ })
+
+ tf_cc_shared_object(
+ name = cc_library_name,
+ srcs = [module_name + ".cc"],
+ copts = copts + if_not_windows([
+ "-Wno-self-assign",
+ "-Wno-sign-compare",
+ "-Wno-write-strings",
+ ]),
+ linkopts = extra_linkopts,
+ linkstatic = 1,
+ deps = deps + extra_deps,
+ **kwargs
+ )
+ native.genrule(
+ name = "gen_" + cc_library_pyd_name,
+ srcs = [":" + cc_library_name],
+ outs = [cc_library_pyd_name],
+ cmd = "cp $< $@",
+ )
+ native.py_library(
+ name = name,
+ srcs = [":" + name + ".py"],
+ srcs_version = "PY2AND3",
+ data = select({
+ clean_dep("//tensorflow:windows"): [":" + cc_library_pyd_name],
+ "//conditions:default": [":" + cc_library_name],
+ }),
+ )
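A minimal sketch, assuming a SWIG interface file and a wrapped C++ library (both names hypothetical):

    tf_py_wrap_cc(
        name = "pywrap_example",  # hypothetical; produces _pywrap_example.so
        srcs = ["example.i"],  # hypothetical SWIG interface
        deps = [":example_lib"],  # hypothetical cc_library
    )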
# This macro is for running python tests against a system-installed pip package
# on Windows.
@@ -1578,246 +1656,263 @@ def tf_py_wrap_cc(name,
# Note that this only works on Windows. See the definition of
# //third_party/tensorflow/tools/pip_package:win_pip_package_marker for specific reasons.
# 2. When --define=no_tensorflow_py_deps=false (by default), it's a normal py_test.
-def py_test(deps=[], data=[], **kwargs):
- native.py_test(
- # TODO(jlebar): Ideally we'd use tcmalloc here.,
- deps=select({
- "//conditions:default": deps,
- clean_dep("//tensorflow:no_tensorflow_py_deps"): [],
- }),
- data = data + select({
- "//conditions:default": [],
- clean_dep("//tensorflow:no_tensorflow_py_deps"):
- ["//tensorflow/tools/pip_package:win_pip_package_marker"],
- }),
- **kwargs)
+def py_test(deps = [], data = [], **kwargs):
+ native.py_test(
+ # TODO(jlebar): Ideally we'd use tcmalloc here.
+ deps = select({
+ "//conditions:default": deps,
+ clean_dep("//tensorflow:no_tensorflow_py_deps"): [],
+ }),
+ data = data + select({
+ "//conditions:default": [],
+ clean_dep("//tensorflow:no_tensorflow_py_deps"): ["//tensorflow/tools/pip_package:win_pip_package_marker"],
+ }),
+ **kwargs
+ )
register_extension_info(
extension_name = "py_test",
label_regex_for_dep = "{extension_name}",
)
-def tf_py_test(name,
- srcs,
- size="medium",
- data=[],
- main=None,
- args=[],
- tags=[],
- shard_count=1,
- additional_deps=[],
- flaky=0,
- xla_enabled=False,
- grpc_enabled=False):
- if xla_enabled:
- additional_deps = additional_deps + tf_additional_xla_deps_py()
- if grpc_enabled:
- additional_deps = additional_deps + tf_additional_grpc_deps_py()
- py_test(
- name=name,
- size=size,
- srcs=srcs,
- main=main,
- args=args,
- tags=tags,
- visibility=[clean_dep("//tensorflow:internal")],
- shard_count=shard_count,
- data=data,
- deps=[
+def tf_py_test(
+ name,
+ srcs,
+ size = "medium",
+ data = [],
+ main = None,
+ args = [],
+ tags = [],
+ shard_count = 1,
+ additional_deps = [],
+ flaky = 0,
+ xla_enabled = False,
+ grpc_enabled = False):
+ if xla_enabled:
+ additional_deps = additional_deps + tf_additional_xla_deps_py()
+ if grpc_enabled:
+ additional_deps = additional_deps + tf_additional_grpc_deps_py()
+ py_test(
+ name = name,
+ size = size,
+ srcs = srcs,
+ main = main,
+ args = args,
+ tags = tags,
+ visibility = [clean_dep("//tensorflow:internal")],
+ shard_count = shard_count,
+ data = data,
+ deps = [
clean_dep("//tensorflow/python:extra_py_tests_deps"),
clean_dep("//tensorflow/python:gradient_checker"),
- ] + additional_deps,
- flaky=flaky,
- srcs_version="PY2AND3")
+ ] + additional_deps,
+ flaky = flaky,
+ srcs_version = "PY2AND3",
+ )
register_extension_info(
extension_name = "tf_py_test",
label_regex_map = {"additional_deps": "deps:{extension_name}"},
)
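Illustrative call site (the test name and extra dependency are placeholders):

    tf_py_test(
        name = "example_py_test",  # placeholder
        size = "small",
        srcs = ["example_py_test.py"],
        additional_deps = ["//tensorflow/python:client_testlib"],  # assumed dependency
    )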
-def cuda_py_test(name,
- srcs,
- size="medium",
- data=[],
- main=None,
- args=[],
- shard_count=1,
- additional_deps=[],
- tags=[],
- flaky=0,
- xla_enabled=False,
- grpc_enabled=False):
- test_tags = tags + tf_cuda_tests_tags()
- tf_py_test(
- name=name,
- size=size,
- srcs=srcs,
- data=data,
- main=main,
- args=args,
- tags=test_tags,
- shard_count=shard_count,
- additional_deps=additional_deps,
- flaky=flaky,
- xla_enabled=xla_enabled,
- grpc_enabled=grpc_enabled)
+def cuda_py_test(
+ name,
+ srcs,
+ size = "medium",
+ data = [],
+ main = None,
+ args = [],
+ shard_count = 1,
+ additional_deps = [],
+ tags = [],
+ flaky = 0,
+ xla_enabled = False,
+ grpc_enabled = False):
+ test_tags = tags + tf_cuda_tests_tags()
+ tf_py_test(
+ name = name,
+ size = size,
+ srcs = srcs,
+ data = data,
+ main = main,
+ args = args,
+ tags = test_tags,
+ shard_count = shard_count,
+ additional_deps = additional_deps,
+ flaky = flaky,
+ xla_enabled = xla_enabled,
+ grpc_enabled = grpc_enabled,
+ )
register_extension_info(
extension_name = "cuda_py_test",
label_regex_map = {"additional_deps": "additional_deps:{extension_name}"},
)
-def sycl_py_test(name,
- srcs,
- size="medium",
- data=[],
- main=None,
- args=[],
- shard_count=1,
- additional_deps=[],
- tags=[],
- flaky=0,
- xla_enabled=False,
- grpc_enabled=False):
- test_tags = tags + tf_sycl_tests_tags()
- tf_py_test(
- name=name,
- size=size,
- srcs=srcs,
- data=data,
- main=main,
- args=args,
- tags=test_tags,
- shard_count=shard_count,
- additional_deps=additional_deps,
- flaky=flaky,
- xla_enabled=xla_enabled,
- grpc_enabled=grpc_enabled)
+def sycl_py_test(
+ name,
+ srcs,
+ size = "medium",
+ data = [],
+ main = None,
+ args = [],
+ shard_count = 1,
+ additional_deps = [],
+ tags = [],
+ flaky = 0,
+ xla_enabled = False,
+ grpc_enabled = False):
+ test_tags = tags + tf_sycl_tests_tags()
+ tf_py_test(
+ name = name,
+ size = size,
+ srcs = srcs,
+ data = data,
+ main = main,
+ args = args,
+ tags = test_tags,
+ shard_count = shard_count,
+ additional_deps = additional_deps,
+ flaky = flaky,
+ xla_enabled = xla_enabled,
+ grpc_enabled = grpc_enabled,
+ )
register_extension_info(
extension_name = "sycl_py_test",
label_regex_map = {"additional_deps": "additional_deps:{extension_name}"},
)
-def py_tests(name,
- srcs,
- size="medium",
- additional_deps=[],
- data=[],
- tags=[],
- shard_count=1,
- prefix="",
- xla_enabled=False,
- grpc_enabled=False):
- for src in srcs:
- test_name = src.split("/")[-1].split(".")[0]
- if prefix:
- test_name = "%s_%s" % (prefix, test_name)
- tf_py_test(
- name=test_name,
- size=size,
- srcs=[src],
- main=src,
- tags=tags,
- shard_count=shard_count,
- data=data,
- additional_deps=additional_deps,
- xla_enabled=xla_enabled,
- grpc_enabled=grpc_enabled)
-
-def cuda_py_tests(name,
- srcs,
- size="medium",
- additional_deps=[],
- data=[],
- shard_count=1,
- tags=[],
- prefix="",
- xla_enabled=False,
- grpc_enabled=False):
- test_tags = tags + tf_cuda_tests_tags()
- py_tests(
- name=name,
- size=size,
- srcs=srcs,
- additional_deps=additional_deps,
- data=data,
- tags=test_tags,
- shard_count=shard_count,
- prefix=prefix,
- xla_enabled=xla_enabled,
- grpc_enabled=grpc_enabled)
+def py_tests(
+ name,
+ srcs,
+ size = "medium",
+ additional_deps = [],
+ data = [],
+ tags = [],
+ shard_count = 1,
+ prefix = "",
+ xla_enabled = False,
+ grpc_enabled = False):
+ for src in srcs:
+ test_name = src.split("/")[-1].split(".")[0]
+ if prefix:
+ test_name = "%s_%s" % (prefix, test_name)
+ tf_py_test(
+ name = test_name,
+ size = size,
+ srcs = [src],
+ main = src,
+ tags = tags,
+ shard_count = shard_count,
+ data = data,
+ additional_deps = additional_deps,
+ xla_enabled = xla_enabled,
+ grpc_enabled = grpc_enabled,
+ )
+
+def cuda_py_tests(
+ name,
+ srcs,
+ size = "medium",
+ additional_deps = [],
+ data = [],
+ shard_count = 1,
+ tags = [],
+ prefix = "",
+ xla_enabled = False,
+ grpc_enabled = False):
+ test_tags = tags + tf_cuda_tests_tags()
+ py_tests(
+ name = name,
+ size = size,
+ srcs = srcs,
+ additional_deps = additional_deps,
+ data = data,
+ tags = test_tags,
+ shard_count = shard_count,
+ prefix = prefix,
+ xla_enabled = xla_enabled,
+ grpc_enabled = grpc_enabled,
+ )
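
For reference, py_tests derives each target name from its source file plus the optional prefix, so a call like the following hypothetical one expands into one tf_py_test per entry in srcs:

    py_tests(
        name = "math_tests",  # illustrative
        size = "small",
        srcs = ["add_test.py", "sub_test.py"],
        prefix = "cpu",
    )
    # Expands to targets cpu_add_test and cpu_sub_test, each with main set to its src.
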
# Creates a genrule named <name> for running tools/proto_text's generator to
# make the proto_text functions, for the protos passed in <srcs>.
#
# Return a struct with fields (hdrs, srcs) containing the names of the
# generated files.
-def tf_generate_proto_text_sources(name, srcs_relative_dir, srcs, protodeps=[], deps=[], visibility=None):
- out_hdrs = (
- [p.replace(".proto", ".pb_text.h")
- for p in srcs] + [p.replace(".proto", ".pb_text-impl.h") for p in srcs])
- out_srcs = [p.replace(".proto", ".pb_text.cc") for p in srcs]
- native.genrule(
- name=name + "_srcs",
- srcs=srcs + protodeps + [clean_dep("//tensorflow/tools/proto_text:placeholder.txt")],
- outs=out_hdrs + out_srcs,
- visibility=visibility,
- cmd=
- "$(location //tensorflow/tools/proto_text:gen_proto_text_functions) "
- + "$(@D) " + srcs_relative_dir + " $(SRCS)",
- tools=[
- clean_dep("//tensorflow/tools/proto_text:gen_proto_text_functions")
- ],)
-
- native.filegroup(
- name=name + "_hdrs",
- srcs=out_hdrs,
- visibility=visibility,
- )
-
- native.cc_library(
- name=name,
- srcs=out_srcs,
- hdrs=out_hdrs,
- visibility=visibility,
- deps = deps,
- )
+def tf_generate_proto_text_sources(name, srcs_relative_dir, srcs, protodeps = [], deps = [], visibility = None):
+ out_hdrs = (
+ [
+ p.replace(".proto", ".pb_text.h")
+ for p in srcs
+ ] + [p.replace(".proto", ".pb_text-impl.h") for p in srcs]
+ )
+ out_srcs = [p.replace(".proto", ".pb_text.cc") for p in srcs]
+ native.genrule(
+ name = name + "_srcs",
+ srcs = srcs + protodeps + [clean_dep("//tensorflow/tools/proto_text:placeholder.txt")],
+ outs = out_hdrs + out_srcs,
+ visibility = visibility,
+ cmd =
+ "$(location //tensorflow/tools/proto_text:gen_proto_text_functions) " +
+ "$(@D) " + srcs_relative_dir + " $(SRCS)",
+ tools = [
+ clean_dep("//tensorflow/tools/proto_text:gen_proto_text_functions"),
+ ],
+ )
+
+ native.filegroup(
+ name = name + "_hdrs",
+ srcs = out_hdrs,
+ visibility = visibility,
+ )
+
+ native.cc_library(
+ name = name,
+ srcs = out_srcs,
+ hdrs = out_hdrs,
+ visibility = visibility,
+ deps = deps,
+ )
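
A hedged usage sketch for tf_generate_proto_text_sources (proto name and directory are placeholders): given foo.proto, it generates foo.pb_text.h, foo.pb_text-impl.h, and foo.pb_text.cc.

    tf_generate_proto_text_sources(
        name = "foo_proto_text",  # hypothetical
        srcs = ["foo.proto"],
        srcs_relative_dir = "tensorflow/core/",
    )
    # Exposes :foo_proto_text (cc_library) and :foo_proto_text_hdrs (filegroup).
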
def tf_genrule_cmd_append_to_srcs(to_append):
- return ("cat $(SRCS) > $(@) && " + "echo >> $(@) && " + "echo " + to_append +
- " >> $(@)")
+ return ("cat $(SRCS) > $(@) && " + "echo >> $(@) && " + "echo " + to_append +
+ " >> $(@)")
def tf_version_info_genrule():
- native.genrule(
- name="version_info_gen",
- srcs=[
- clean_dep("@local_config_git//:gen/spec.json"),
- clean_dep("@local_config_git//:gen/head"),
- clean_dep("@local_config_git//:gen/branch_ref"),
- ],
- outs=["util/version_info.cc"],
- cmd=
- "$(location //tensorflow/tools/git:gen_git_source.py) --generate $(SRCS) \"$@\" --git_tag_override=$${GIT_TAG_OVERRIDE:-}",
- local=1,
- tools=[clean_dep("//tensorflow/tools/git:gen_git_source.py")],)
+ native.genrule(
+ name = "version_info_gen",
+ srcs = [
+ clean_dep("@local_config_git//:gen/spec.json"),
+ clean_dep("@local_config_git//:gen/head"),
+ clean_dep("@local_config_git//:gen/branch_ref"),
+ ],
+ outs = ["util/version_info.cc"],
+ cmd =
+ "$(location //tensorflow/tools/git:gen_git_source.py) --generate $(SRCS) \"$@\" --git_tag_override=$${GIT_TAG_OVERRIDE:-}",
+ local = 1,
+ tools = [clean_dep("//tensorflow/tools/git:gen_git_source.py")],
+ )
def tf_py_build_info_genrule():
- native.genrule(
- name="py_build_info_gen",
- outs=["platform/build_info.py"],
- cmd=
- "$(location //tensorflow/tools/build_info:gen_build_info.py) --raw_generate \"$@\" --build_config " + if_cuda("cuda", "cpu"),
- local=1,
- tools=[clean_dep("//tensorflow/tools/build_info:gen_build_info.py")],)
-
-def cc_library_with_android_deps(deps,
- android_deps=[],
- common_deps=[],
- copts=tf_copts(),
- **kwargs):
- deps = if_not_android(deps) + if_android(android_deps) + common_deps
- native.cc_library(deps=deps, copts=copts, **kwargs)
+ native.genrule(
+ name = "py_build_info_gen",
+ outs = ["platform/build_info.py"],
+ cmd =
+ "$(location //tensorflow/tools/build_info:gen_build_info.py) --raw_generate \"$@\" --build_config " + if_cuda("cuda", "cpu"),
+ local = 1,
+ tools = [clean_dep("//tensorflow/tools/build_info:gen_build_info.py")],
+ )
+
+def cc_library_with_android_deps(
+ deps,
+ android_deps = [],
+ common_deps = [],
+ copts = tf_copts(),
+ **kwargs):
+ deps = if_not_android(deps) + if_android(android_deps) + common_deps
+ native.cc_library(deps = deps, copts = copts, **kwargs)
register_extension_info(
extension_name = "cc_library_with_android_deps",
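
A usage sketch for cc_library_with_android_deps (all labels hypothetical); the macro selects between deps and android_deps per platform and always adds common_deps:

    cc_library_with_android_deps(
        name = "platform_port",
        srcs = ["port.cc"],
        deps = ["//foo:full_impl"],          # non-Android builds only
        android_deps = ["//foo:slim_impl"],  # Android builds only
        common_deps = ["//foo:shared"],      # both
    )
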
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-locally-connected1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-locally-connected1-d.pbtxt
index f754fa1da8..ff19dcc3a3 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-locally-connected1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-locally-connected1-d.pbtxt
@@ -82,7 +82,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'1\', \'valid\', \'None\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\', \'implementation\'], varargs=None, keywords=kwargs, defaults=[\'1\', \'valid\', \'None\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\', \'1\'], "
}
member_method {
name: "add_loss"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-locally-connected2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-locally-connected2-d.pbtxt
index c9516b8f07..3c278fead6 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-locally-connected2-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-locally-connected2-d.pbtxt
@@ -82,7 +82,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1)\', \'valid\', \'None\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\', \'implementation\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1)\', \'valid\', \'None\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\', \'1\'], "
}
member_method {
name: "add_loss"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt
index 4f19627691..4de662fe33 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt
@@ -785,6 +785,10 @@ tf_module {
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "batch_gather"
+ argspec: "args=[\'params\', \'indices\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "batch_to_space"
argspec: "args=[\'input\', \'crops\', \'block_size\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
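
The new tf.batch_gather entry gathers per-row along the last dimension of indices; a small sketch consistent with the argspec:

    import tensorflow as tf

    params = tf.constant([[10, 11, 12], [20, 21, 22]])
    indices = tf.constant([[2, 0], [0, 1]])
    out = tf.batch_gather(params, indices)
    # out == [[12, 10], [20, 21]]: row i of indices selects entries from row i of params.
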
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.train.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.pbtxt
index b0fb04d7d4..9f35395284 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.train.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.train.pbtxt
@@ -298,7 +298,7 @@ tf_module {
}
member_method {
name: "generate_checkpoint_state_proto"
- argspec: "args=[\'save_dir\', \'model_checkpoint_path\', \'all_model_checkpoint_paths\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ argspec: "args=[\'save_dir\', \'model_checkpoint_path\', \'all_model_checkpoint_paths\', \'all_model_checkpoint_timestamps\', \'last_preserved_timestamp\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
}
member_method {
name: "get_checkpoint_mtimes"
@@ -446,7 +446,7 @@ tf_module {
}
member_method {
name: "update_checkpoint_state"
- argspec: "args=[\'save_dir\', \'model_checkpoint_path\', \'all_model_checkpoint_paths\', \'latest_filename\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ argspec: "args=[\'save_dir\', \'model_checkpoint_path\', \'all_model_checkpoint_paths\', \'latest_filename\', \'all_model_checkpoint_timestamps\', \'last_preserved_timestamp\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "warm_start"
diff --git a/tensorflow/tools/ci_build/linux/mkl/build-dev-container.sh b/tensorflow/tools/ci_build/linux/mkl/build-dev-container.sh
index a1d91a6123..b497326d98 100755
--- a/tensorflow/tools/ci_build/linux/mkl/build-dev-container.sh
+++ b/tensorflow/tools/ci_build/linux/mkl/build-dev-container.sh
@@ -57,6 +57,17 @@ TF_DOCKER_BUILD_TYPE="MKL" \
TF_BAZEL_BUILD_OPTIONS="${TF_BAZEL_BUILD_OPTIONS}" \
${WORKSPACE}/tensorflow/tools/docker/parameterized_docker_build.sh
+# Build the Python 3.6 container and whl
+TF_DOCKER_BUILD_TYPE="MKL" \
+ TF_DOCKER_BUILD_IS_DEVEL="YES" \
+ TF_DOCKER_BUILD_DEVEL_BRANCH="${TF_DOCKER_BUILD_DEVEL_BRANCH}" \
+ TF_DOCKER_BUILD_IMAGE_NAME="${TF_DOCKER_BUILD_IMAGE_NAME}" \
+ TF_DOCKER_BUILD_VERSION="${TF_DOCKER_BUILD_VERSION}" \
+ TF_DOCKER_BUILD_PYTHON_VERSION="PYTHON3.6" \
+ TF_BAZEL_BUILD_OPTIONS="${TF_BAZEL_BUILD_OPTIONS}" \
+ ${WORKSPACE}/tensorflow/tools/docker/parameterized_docker_build.sh
+
# Build containers for AVX2
# Include the instructions for haswell and later, but tune for broadwell
TF_BAZEL_BUILD_OPTIONS="--config=mkl --copt=-march=haswell --copt=-mtune=broadwell --copt=-O3 --cxxopt=-D_GLIBCXX_USE_CXX11_ABI=0"
@@ -80,3 +91,13 @@ TF_DOCKER_BUILD_TYPE="MKL" \
TF_BAZEL_BUILD_OPTIONS="${TF_BAZEL_BUILD_OPTIONS}" \
${WORKSPACE}/tensorflow/tools/docker/parameterized_docker_build.sh
+# Build the Python 3.6 container and whl
+TF_DOCKER_BUILD_TYPE="MKL" \
+ TF_DOCKER_BUILD_IS_DEVEL="YES" \
+ TF_DOCKER_BUILD_DEVEL_BRANCH="${TF_DOCKER_BUILD_DEVEL_BRANCH}" \
+ TF_DOCKER_BUILD_IMAGE_NAME="${TF_DOCKER_BUILD_IMAGE_NAME}" \
+ TF_DOCKER_BUILD_VERSION="${TF_DOCKER_BUILD_VERSION}-avx2" \
+ TF_DOCKER_BUILD_PYTHON_VERSION="PYTHON3.6" \
+ TF_BAZEL_BUILD_OPTIONS="${TF_BAZEL_BUILD_OPTIONS}" \
+ ${WORKSPACE}/tensorflow/tools/docker/parameterized_docker_build.sh
+
diff --git a/tensorflow/tools/docker/Dockerfile b/tensorflow/tools/docker/Dockerfile
index bf06214009..2c31d784e5 100644
--- a/tensorflow/tools/docker/Dockerfile
+++ b/tensorflow/tools/docker/Dockerfile
@@ -29,6 +29,8 @@ RUN pip --no-cache-dir install \
h5py \
ipykernel \
jupyter \
+ keras_applications==1.0.4 \
+ keras_preprocessing==1.0.2 \
matplotlib \
numpy==1.14.5 \
pandas \
diff --git a/tensorflow/tools/docker/Dockerfile.devel b/tensorflow/tools/docker/Dockerfile.devel
index 6552588fac..bacdea72ce 100644
--- a/tensorflow/tools/docker/Dockerfile.devel
+++ b/tensorflow/tools/docker/Dockerfile.devel
@@ -33,6 +33,8 @@ RUN pip --no-cache-dir install \
h5py \
ipykernel \
jupyter \
+ keras_applications==1.0.4 \
+ keras_preprocessing==1.0.2 \
matplotlib \
mock \
numpy==1.14.5 \
diff --git a/tensorflow/tools/docker/Dockerfile.devel-gpu b/tensorflow/tools/docker/Dockerfile.devel-gpu
index f4c83f85d4..4f89e3f701 100644
--- a/tensorflow/tools/docker/Dockerfile.devel-gpu
+++ b/tensorflow/tools/docker/Dockerfile.devel-gpu
@@ -49,6 +49,8 @@ RUN pip --no-cache-dir install \
h5py \
ipykernel \
jupyter \
+ keras_applications==1.0.4 \
+ keras_preprocessing==1.0.2 \
matplotlib \
mock \
numpy==1.14.5 \
diff --git a/tensorflow/tools/docker/Dockerfile.devel-gpu-cuda9-cudnn7 b/tensorflow/tools/docker/Dockerfile.devel-gpu-cuda9-cudnn7
index 30bc2d2806..056b4755f4 100644
--- a/tensorflow/tools/docker/Dockerfile.devel-gpu-cuda9-cudnn7
+++ b/tensorflow/tools/docker/Dockerfile.devel-gpu-cuda9-cudnn7
@@ -37,6 +37,8 @@ RUN pip --no-cache-dir install --upgrade \
RUN pip --no-cache-dir install \
ipykernel \
jupyter \
+ keras_applications==1.0.4 \
+ keras_preprocessing==1.0.2 \
matplotlib \
numpy \
scipy \
diff --git a/tensorflow/tools/docker/Dockerfile.devel-mkl b/tensorflow/tools/docker/Dockerfile.devel-mkl
index f0c7118ecb..2df770e525 100755
--- a/tensorflow/tools/docker/Dockerfile.devel-mkl
+++ b/tensorflow/tools/docker/Dockerfile.devel-mkl
@@ -18,18 +18,29 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
libhdf5-serial-dev \
libpng12-dev \
libzmq3-dev \
+ libssl-dev \
pkg-config \
- python-dev \
- ${PYTHON3_DEV} \
rsync \
software-properties-common \
unzip \
zip \
zlib1g-dev \
openjdk-8-jdk \
- openjdk-8-jre-headless \
- && \
- apt-get clean && \
+ openjdk-8-jre-headless
+
+# Install Python 3
+RUN if [ ${PYTHON} = "python3.6" ]; then \
+ curl https://www.python.org/ftp/python/3.6.5/Python-3.6.5.tar.xz -o /opt/python.tar.xz && \
+ cd /opt && tar xvf python.tar.xz && \
+ cd /opt/*/ && ./configure && \
+ make && make install; \
+ else \
+ apt-get install -y --no-install-recommends \
+ python-dev \
+ ${PYTHON3_DEV}; \
+ fi
+
+RUN apt-get clean && \
rm -rf /var/lib/apt/lists/*
RUN curl -fSsL -O https://bootstrap.pypa.io/get-pip.py && \
@@ -41,6 +52,8 @@ RUN ${PIP} --no-cache-dir install \
h5py \
ipykernel \
jupyter \
+ keras_applications==1.0.4 \
+ keras_preprocessing==1.0.2 \
matplotlib \
mock \
numpy \
@@ -51,7 +64,9 @@ RUN ${PIP} --no-cache-dir install \
${PYTHON} -m ipykernel.kernelspec
RUN if [ "${PYTHON}" = "python3" ]; then \
- ln -s -f /usr/bin/python3 /usr/bin/python; \
+ ln -s -f /usr/bin/python3 /usr/bin/python; \
+ elif [ "${PYTHON}" = "python3.6" ]; then \
+ ln -s -f /usr/local/bin/python3.6 /usr/bin/python; \
fi
# Set up our notebook config.
diff --git a/tensorflow/tools/docker/Dockerfile.devel-mkl-horovod b/tensorflow/tools/docker/Dockerfile.devel-mkl-horovod
new file mode 100755
index 0000000000..ab2eec1728
--- /dev/null
+++ b/tensorflow/tools/docker/Dockerfile.devel-mkl-horovod
@@ -0,0 +1,168 @@
+FROM ubuntu:16.04
+
+LABEL maintainer="Cong Xu <cong.xu@intel.com>"
+
+# These parameters can be overridden by parameterized_docker_build.sh
+ARG TF_BUILD_VERSION=r1.9
+ARG PYTHON="python"
+ARG PYTHON3_DEV=""
+ARG WHL_DIR="/tmp/pip"
+ARG PIP="pip"
+
+RUN apt-get update && apt-get install -y --no-install-recommends \
+ build-essential \
+ curl \
+ git \
+ libcurl3-dev \
+ libfreetype6-dev \
+ libhdf5-serial-dev \
+ libpng12-dev \
+ libzmq3-dev \
+ pkg-config \
+ python-dev \
+ ${PYTHON3_DEV} \
+ rsync \
+ software-properties-common \
+ unzip \
+ zip \
+ zlib1g-dev \
+ openjdk-8-jdk \
+ openjdk-8-jre-headless \
+ wget \
+ numactl \
+ openssh-client \
+ openssh-server \
+ && \
+ apt-get clean && \
+ rm -rf /var/lib/apt/lists/*
+
+RUN curl -fSsL -O https://bootstrap.pypa.io/get-pip.py && \
+ ${PYTHON} get-pip.py && \
+ rm get-pip.py
+
+RUN ${PIP} --no-cache-dir install \
+ Pillow \
+ h5py \
+ ipykernel \
+ jupyter \
+ keras_applications==1.0.4 \
+ keras_preprocessing==1.0.2 \
+ matplotlib \
+ mock \
+ numpy \
+ scipy \
+ sklearn \
+ pandas \
+ && \
+ ${PYTHON} -m ipykernel.kernelspec
+
+RUN if [ "${PYTHON}" = "python3" ]; then \
+ ln -s -f /usr/bin/python3 /usr/bin/python; \
+ fi
+
+# Set up our notebook config.
+COPY jupyter_notebook_config.py /root/.jupyter/
+
+# Jupyter has issues with being run directly:
+# https://github.com/ipython/ipython/issues/7062
+# We just add a little wrapper script.
+COPY run_jupyter.sh /
+
+# Set up Bazel.
+
+# Running bazel inside a `docker build` command causes trouble, cf:
+# https://github.com/bazelbuild/bazel/issues/134
+# The easiest solution is to set up a bazelrc file forcing --batch.
+RUN echo "startup --batch" >>/etc/bazel.bazelrc
+# Similarly, we need to workaround sandboxing issues:
+# https://github.com/bazelbuild/bazel/issues/418
+RUN echo "build --spawn_strategy=standalone --genrule_strategy=standalone" \
+ >>/etc/bazel.bazelrc
+# Install the most recent bazel release.
+ENV BAZEL_VERSION 0.15.0
+WORKDIR /
+RUN mkdir /bazel && \
+ cd /bazel && \
+ curl -H "User-Agent: Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/57.0.2987.133 Safari/537.36" -fSsL -O https://github.com/bazelbuild/bazel/releases/download/$BAZEL_VERSION/bazel-$BAZEL_VERSION-installer-linux-x86_64.sh && \
+ curl -H "User-Agent: Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/57.0.2987.133 Safari/537.36" -fSsL -o /bazel/LICENSE.txt https://raw.githubusercontent.com/bazelbuild/bazel/master/LICENSE && \
+ chmod +x bazel-*.sh && \
+ ./bazel-$BAZEL_VERSION-installer-linux-x86_64.sh && \
+ cd / && \
+ rm -f /bazel/bazel-$BAZEL_VERSION-installer-linux-x86_64.sh
+
+# Download and build TensorFlow.
+WORKDIR /tensorflow
+
+# Enable checking out both tags and branches
+RUN export TAG_PREFIX="v" && \
+ echo ${TF_BUILD_VERSION} | grep -q ^${TAG_PREFIX}; \
+ if [ $? -eq 0 ]; then \
+ git clone --depth=1 https://github.com/tensorflow/tensorflow.git . && \
+ git fetch --tags && \
+ git checkout ${TF_BUILD_VERSION}; \
+ else \
+ git clone --depth=1 --branch=${TF_BUILD_VERSION} https://github.com/tensorflow/tensorflow.git . ; \
+ fi
+
+RUN yes "" | ${PYTHON} configure.py
+
+ENV CI_BUILD_PYTHON ${PYTHON}
+
+# Set bazel build parameters in .bazelrc in parameterized_docker_build.sh
+# Use --copt=-march values to get optimized builds appropriate for the hardware
+# platform of your choice.
+# For ivy-bridge or sandy-bridge
+# --copt=-march="avx" \
+# For haswell, broadwell, or skylake
+# --copt=-march="avx2" \
+COPY .bazelrc /root/.bazelrc
+
+RUN tensorflow/tools/ci_build/builds/configured CPU \
+ bazel --bazelrc=/root/.bazelrc build -c opt \
+ tensorflow/tools/pip_package:build_pip_package && \
+ bazel-bin/tensorflow/tools/pip_package/build_pip_package "${WHL_DIR}" && \
+ ${PIP} --no-cache-dir install --upgrade "${WHL_DIR}"/tensorflow-*.whl && \
+ rm -rf /root/.cache
+# Clean up Bazel cache when done.
+
+WORKDIR /root
+
+# Install Open MPI
+RUN mkdir /tmp/openmpi && \
+ cd /tmp/openmpi && \
+ wget https://www.open-mpi.org/software/ompi/v3.0/downloads/openmpi-3.0.0.tar.gz && \
+ tar zxf openmpi-3.0.0.tar.gz && \
+ cd openmpi-3.0.0 && \
+ ./configure --enable-orterun-prefix-by-default && \
+ make -j $(nproc) all && \
+ make install && \
+ ldconfig && \
+ rm -rf /tmp/openmpi
+
+# Create a wrapper for OpenMPI to allow running as root by default
+RUN mv /usr/local/bin/mpirun /usr/local/bin/mpirun.real && \
+ echo '#!/bin/bash' > /usr/local/bin/mpirun && \
+ echo 'mpirun.real --allow-run-as-root "$@"' >> /usr/local/bin/mpirun && \
+ chmod a+x /usr/local/bin/mpirun
+
+# Configure OpenMPI with good defaults:
+RUN echo "btl_tcp_if_exclude = lo,docker0" >> /usr/local/etc/openmpi-mca-params.conf
+
+# Install Horovod
+RUN ${PIP} install --no-cache-dir horovod
+
+# Install OpenSSH for MPI to communicate between containers
+RUN mkdir -p /var/run/sshd
+
+# Allow OpenSSH to talk to containers without asking for confirmation
+RUN cat /etc/ssh/ssh_config | grep -v StrictHostKeyChecking > /etc/ssh/ssh_config.new && \
+ echo " StrictHostKeyChecking no" >> /etc/ssh/ssh_config.new && \
+ mv /etc/ssh/ssh_config.new /etc/ssh/ssh_config
+
+# TensorBoard
+EXPOSE 6006
+# IPython
+EXPOSE 8888
+
+WORKDIR /root
diff --git a/tensorflow/tools/docker/Dockerfile.gpu b/tensorflow/tools/docker/Dockerfile.gpu
index 5ec1e60f00..aa0e0face1 100644
--- a/tensorflow/tools/docker/Dockerfile.gpu
+++ b/tensorflow/tools/docker/Dockerfile.gpu
@@ -37,6 +37,8 @@ RUN pip --no-cache-dir install \
h5py \
ipykernel \
jupyter \
+ keras_applications==1.0.4 \
+ keras_preprocessing==1.0.2 \
matplotlib \
numpy==1.14.5 \
pandas \
diff --git a/tensorflow/tools/docker/Dockerfile.mkl b/tensorflow/tools/docker/Dockerfile.mkl
index ad5109f26d..69553302d8 100755
--- a/tensorflow/tools/docker/Dockerfile.mkl
+++ b/tensorflow/tools/docker/Dockerfile.mkl
@@ -38,6 +38,8 @@ RUN ${PIP} --no-cache-dir install \
h5py \
ipykernel \
jupyter \
+ keras_applications==1.0.4 \
+ keras_preprocessing==1.0.2 \
matplotlib \
numpy \
pandas \
diff --git a/tensorflow/tools/docker/Dockerfile.mkl-horovod b/tensorflow/tools/docker/Dockerfile.mkl-horovod
new file mode 100755
index 0000000000..756716ee0e
--- /dev/null
+++ b/tensorflow/tools/docker/Dockerfile.mkl-horovod
@@ -0,0 +1,111 @@
+FROM ubuntu:16.04
+
+LABEL maintainer="Cong Xu <cong.xu@intel.com>"
+
+# This parameter MUST be set by parameterized_docker_build.sh
+ARG TF_WHL_URL
+
+# Optional parameters
+ARG TF_BUILD_VERSION=r1.9
+ARG PYTHON="python"
+ARG PYTHON_DEV="python-dev"
+ARG PIP="pip"
+
+# Pick up some TF dependencies
+RUN apt-get update && apt-get install -y --no-install-recommends \
+ build-essential \
+ curl \
+ libfreetype6-dev \
+ libhdf5-serial-dev \
+ libpng12-dev \
+ libzmq3-dev \
+ pkg-config \
+ python \
+ ${PYTHON_DEV} \
+ rsync \
+ software-properties-common \
+ unzip \
+ && \
+ apt-get clean && \
+ rm -rf /var/lib/apt/lists/*
+
+RUN curl -O https://bootstrap.pypa.io/get-pip.py && \
+ python get-pip.py && \
+ rm get-pip.py
+
+RUN ${PIP} --no-cache-dir install \
+ Pillow \
+ h5py \
+ ipykernel \
+ jupyter \
+ keras_applications==1.0.4 \
+ keras_preprocessing==1.0.2 \
+ matplotlib \
+ numpy \
+ pandas \
+ scipy \
+ sklearn \
+ && \
+ python -m ipykernel.kernelspec
+
+COPY ${TF_WHL_URL} /
+RUN ${PIP} install --no-cache-dir --force-reinstall /${TF_WHL_URL} && \
+ rm -rf /${TF_WHL_URL}
+
+RUN if [ "${PYTHON}" = "python3" ]; then \
+ ln -s -f /usr/bin/python3 /usr/bin/python; \
+ fi
+
+# Set up our notebook config.
+COPY jupyter_notebook_config.py /root/.jupyter/
+
+# Copy sample notebooks.
+COPY notebooks /notebooks
+
+# Jupyter has issues with being run directly:
+# https://github.com/ipython/ipython/issues/7062
+# We just add a little wrapper script.
+COPY run_jupyter.sh /
+
+WORKDIR /root
+
+# Install Open MPI
+RUN mkdir /tmp/openmpi && \
+ cd /tmp/openmpi && \
+ wget https://www.open-mpi.org/software/ompi/v3.0/downloads/openmpi-3.0.0.tar.gz && \
+ tar zxf openmpi-3.0.0.tar.gz && \
+ cd openmpi-3.0.0 && \
+ ./configure --enable-orterun-prefix-by-default && \
+ make -j $(nproc) all && \
+ make install && \
+ ldconfig && \
+ rm -rf /tmp/openmpi
+
+# Create a wrapper for OpenMPI to allow running as root by default
+RUN mv /usr/local/bin/mpirun /usr/local/bin/mpirun.real && \
+ echo '#!/bin/bash' > /usr/local/bin/mpirun && \
+ echo 'mpirun.real --allow-run-as-root "$@"' >> /usr/local/bin/mpirun && \
+ chmod a+x /usr/local/bin/mpirun
+
+# Configure OpenMPI with good defaults:
+RUN echo "btl_tcp_if_exclude = lo,docker0" >> /usr/local/etc/openmpi-mca-params.conf
+
+# Install Horovod
+RUN ${PIP} install --no-cache-dir horovod
+
+# Install OpenSSH for MPI to communicate between containers
+RUN mkdir -p /var/run/sshd
+
+# Allow OpenSSH to talk to containers without asking for confirmation
+RUN cat /etc/ssh/ssh_config | grep -v StrictHostKeyChecking > /etc/ssh/ssh_config.new && \
+ echo " StrictHostKeyChecking no" >> /etc/ssh/ssh_config.new && \
+ mv /etc/ssh/ssh_config.new /etc/ssh/ssh_config
+
+# TensorBoard
+EXPOSE 6006
+# IPython
+EXPOSE 8888
+
+WORKDIR "/notebooks"
+
+CMD ["/run_jupyter.sh", "--allow-root"]
diff --git a/tensorflow/tools/docker/parameterized_docker_build.sh b/tensorflow/tools/docker/parameterized_docker_build.sh
index 4681c5fd61..c9f17a8242 100755
--- a/tensorflow/tools/docker/parameterized_docker_build.sh
+++ b/tensorflow/tools/docker/parameterized_docker_build.sh
@@ -19,8 +19,8 @@
# parameterized_docker_build.sh
#
# The script obeys the following environment variables:
-# TF_DOCKER_BUILD_TYPE: (CPU | GPU | MKL)
-# CPU, GPU, or MKL image
+# TF_DOCKER_BUILD_TYPE: (CPU | GPU | MKL | MKL-HOROVOD)
+# CPU, GPU, MKL or MKL-HOROVOD image
#
# TF_DOCKER_BUILD_IS_DEVEL: (NO | YES)
# Is this developer image
@@ -169,6 +169,15 @@ elif [[ ${TF_DOCKER_BUILD_TYPE} == "mkl" ]]; then
else
ORIG_DOCKERFILE="${ORIG_DOCKERFILE}.mkl"
fi
+elif [[ ${TF_DOCKER_BUILD_TYPE} == "mkl-horovod" ]]; then
+ DOCKER_BINARY="docker"
+ FINAL_TAG="${FINAL_TAG}-mkl-horovod"
+ if [[ ${ORIG_DOCKERFILE} == *"."* ]]; then
+ # There is already a dot in the tag, use "-"
+ ORIG_DOCKERFILE="${ORIG_DOCKERFILE}-mkl-horovod"
+ else
+ ORIG_DOCKERFILE="${ORIG_DOCKERFILE}.mkl-horovod"
+ fi
elif [[ ${TF_DOCKER_BUILD_TYPE} == "gpu" ]]; then
DOCKER_BINARY="nvidia-docker"
@@ -188,6 +197,8 @@ if [[ "${TF_DOCKER_BUILD_PYTHON_VERSION}" == "python2" ]]; then
:
elif [[ "${TF_DOCKER_BUILD_PYTHON_VERSION}" == "python3" ]]; then
FINAL_TAG="${FINAL_TAG}-py3"
+elif [[ "${TF_DOCKER_BUILD_PYTHON_VERSION}" == "python3.6" ]]; then
+ FINAL_TAG="${FINAL_TAG}-py3.6"
else
die "Unrecognized value in TF_DOCKER_BUILD_PYTHON_VERSION: "\
"${TF_DOCKER_BUILD_PYTHON_VERSION}"
@@ -227,6 +238,10 @@ if [[ "${TF_DOCKER_BUILD_IS_DEVEL}" == "no" ]]; then
die "FAIL: Non-development MKL builds require a pre-built pip whl."
fi
+ if [[ "${TF_DOCKER_BUILD_TYPE}" == "mkl-horovod" ]]; then
+ die "FAIL: Non-development MKL-HOROVOD builds require a pre-built pip whl."
+ fi
+
if [[ "${TF_DOCKER_BUILD_TYPE}" == "gpu" ]]; then
export TF_BUILD_APPEND_CI_DOCKER_EXTRA_PARAMS=\
"${TF_BUILD_APPEND_CI_DOCKER_EXTRA_PARAMS} -e TF_CUDA_COMPUTE_CAPABILITIES=3.0,3.5,5.2"
@@ -279,7 +294,8 @@ if [[ "${TF_DOCKER_BUILD_IS_DEVEL}" == "no" ]]; then
# Use string replacement to put the correct file name into the Dockerfile
PIP_WHL=$(basename "${PIP_WHL}")
- if [[ ${TF_DOCKER_BUILD_TYPE} == "mkl" ]]; then
+ if [[ ${TF_DOCKER_BUILD_TYPE} == "mkl" ]] || \
+ [[ ${TF_DOCKER_BUILD_TYPE} == "mkl-horovod" ]]; then
TF_DOCKER_BUILD_ARGS+=("--build-arg TF_WHL_URL=${PIP_WHL}" )
cp "${ORIG_DOCKERFILE}" "${DOCKERFILE}"
else
@@ -295,7 +311,8 @@ if [[ "${TF_DOCKER_BUILD_IS_DEVEL}" == "no" ]]; then
echo
else
echo "Downloading pip wheel from: ${TF_DOCKER_BUILD_CENTRAL_PIP}"
- if [[ ${TF_DOCKER_BUILD_TYPE} == "mkl" ]]; then
+ if [[ ${TF_DOCKER_BUILD_TYPE} == "mkl" ]] || \
+ [[ ${TF_DOCKER_BUILD_TYPE} == "mkl-horovod" ]]; then
pushd "${TMP_DIR}/"
curl -O ${TF_DOCKER_BUILD_CENTRAL_PIP}
popd
@@ -319,7 +336,8 @@ if [[ "${TF_DOCKER_BUILD_IS_DEVEL}" == "no" ]]; then
# Modify python/pip version if necessary.
if [[ "${TF_DOCKER_BUILD_PYTHON_VERSION}" == "python3" ]]; then
- if [[ ${TF_DOCKER_BUILD_TYPE} == "mkl" ]]; then
+ if [[ ${TF_DOCKER_BUILD_TYPE} == "mkl" ]] || \
+ [[ ${TF_DOCKER_BUILD_TYPE} == "mkl-horovod" ]]; then
TF_DOCKER_BUILD_ARGS+=("--build-arg PYTHON=${TF_DOCKER_BUILD_PYTHON_VERSION}")
TF_DOCKER_BUILD_ARGS+=("--build-arg PYTHON_DEV=python3-dev")
TF_DOCKER_BUILD_ARGS+=("--build-arg PIP=pip3")
@@ -340,8 +358,9 @@ if [[ "${TF_DOCKER_BUILD_IS_DEVEL}" == "no" ]]; then
else # TF_DOCKER_BUILD_IS_DEVEL == 'yes'
DOCKERFILE="${TMP_DIR}/Dockerfile"
- # Set up Dockerfile ARGS for mkl build
- if [[ ${TF_DOCKER_BUILD_TYPE} == "mkl" ]]; then
+ # Set up Dockerfile ARGS for mkl and mkl-horovod build
+ if [[ ${TF_DOCKER_BUILD_TYPE} == "mkl" ]] || \
+ [[ ${TF_DOCKER_BUILD_TYPE} == "mkl-horovod" ]]; then
if [[ -z "${TF_BAZEL_BUILD_OPTIONS// }" ]]; then
TF_BAZEL_BUILD_OPTIONS=("--config=mkl --copt=-mavx --cxxopt=-D_GLIBCXX_USE_CXX11_ABI=0")
else
@@ -360,14 +379,17 @@ else # TF_DOCKER_BUILD_IS_DEVEL == 'yes'
fi
# Modify python/pip version if necessary.
- if [[ "${TF_DOCKER_BUILD_PYTHON_VERSION}" == "python3" ]]; then
- if [[ ${TF_DOCKER_BUILD_TYPE} == "mkl" ]]; then
+ if [[ "${TF_DOCKER_BUILD_PYTHON_VERSION}" == "python3" ]] || [[ "${TF_DOCKER_BUILD_PYTHON_VERSION}" == "python3.6" ]]; then
+ if [[ ${TF_DOCKER_BUILD_TYPE} == "mkl" ]] || [[ ${TF_DOCKER_BUILD_TYPE} == "mkl-horovod" ]]; then
TF_DOCKER_BUILD_ARGS+=("--build-arg PYTHON=${TF_DOCKER_BUILD_PYTHON_VERSION}")
TF_DOCKER_BUILD_ARGS+=("--build-arg PYTHON3_DEV=python3-dev")
TF_DOCKER_BUILD_ARGS+=("--build-arg WHL_DIR=/tmp/pip3")
TF_DOCKER_BUILD_ARGS+=("--build-arg PIP=pip3")
cp "${ORIG_DOCKERFILE}" "${DOCKERFILE}"
else
+ if [[ "${TF_DOCKER_BUILD_TYPE}" != "mkl" ]]; then
+ die "Python 3.6 build only supported for MKL builds."
+ fi
if sed -i -e 's/python-dev/python-dev python3-dev/g' "${DOCKERFILE}" && \
sed -i -e 's/python /python3 /g' "${DOCKERFILE}" && \
sed -i -e 's^/tmp/pip^/tmp/pip3^g' "${DOCKERFILE}" && \
diff --git a/tensorflow/tools/docs/BUILD b/tensorflow/tools/docs/BUILD
index cc7885ab1b..4f7efe193f 100644
--- a/tensorflow/tools/docs/BUILD
+++ b/tensorflow/tools/docs/BUILD
@@ -34,11 +34,29 @@ py_test(
)
py_library(
+ name = "doc_controls",
+ srcs = ["doc_controls.py"],
+ srcs_version = "PY2AND3",
+)
+
+py_test(
+ name = "doc_controls_test",
+ size = "small",
+ srcs = ["doc_controls_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":doc_controls",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+py_library(
name = "parser",
srcs = ["parser.py"],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [
+ ":doc_controls",
"//tensorflow/python:platform",
"//tensorflow/python:util",
"@astor_archive//:astor",
@@ -68,6 +86,7 @@ py_binary(
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [
+ ":doc_controls",
":doc_generator_visitor",
":parser",
":pretty_docs",
diff --git a/tensorflow/tools/docs/doc_controls.py b/tensorflow/tools/docs/doc_controls.py
new file mode 100644
index 0000000000..5e526443cc
--- /dev/null
+++ b/tensorflow/tools/docs/doc_controls.py
@@ -0,0 +1,319 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Documentation control decorators."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+_DO_NOT_DOC = "_tf_docs_do_not_document"
+
+
+def do_not_generate_docs(obj):
+ """A decorator: Do not generate docs for this object.
+
+  For example, the following classes:
+
+ ```
+ class Parent(object):
+ def method1(self):
+ pass
+ def method2(self):
+ pass
+
+ class Child(Parent):
+ def method1(self):
+ pass
+ def method2(self):
+ pass
+ ```
+
+  These produce the following api_docs:
+
+ ```
+ /Parent.md
+ # method1
+ # method2
+ /Child.md
+ # method1
+ # method2
+ ```
+
+ This decorator allows you to skip classes or methods:
+
+ ```
+ @do_not_generate_docs
+ class Parent(object):
+ def method1(self):
+ pass
+ def method2(self):
+ pass
+
+ class Child(Parent):
+ @do_not_generate_docs
+ def method1(self):
+ pass
+ def method2(self):
+ pass
+ ```
+
+ This will only produce the following docs:
+
+ ```
+ /Child.md
+ # method2
+ ```
+
+ Note: This is implemented by adding a hidden attribute on the object, so it
+ cannot be used on objects which do not allow new attributes to be added. So
+ this decorator must go *below* `@property`, `@classmethod`,
+ or `@staticmethod`:
+
+ ```
+ class Example(object):
+ @property
+ @do_not_generate_docs
+ def x(self):
+ return self._x
+ ```
+
+ Args:
+ obj: The object to hide from the generated docs.
+
+ Returns:
+ obj
+ """
+ setattr(obj, _DO_NOT_DOC, None)
+ return obj
+
+
+_DO_NOT_DOC_INHERITABLE = "_tf_docs_do_not_doc_inheritable"
+
+
+def do_not_doc_inheritable(obj):
+ """A decorator: Do not generate docs for this method.
+
+  This version of the decorator is "inherited" by subclasses. No docs will be
+  generated for the decorated method in any subclass, even if the subclass
+  overrides the method.
+
+ For example, to ensure that `method1` is **never documented** use this
+ decorator on the base-class:
+
+ ```
+ class Parent(object):
+ @do_not_doc_inheritable
+ def method1(self):
+ pass
+ def method2(self):
+ pass
+
+ class Child(Parent):
+ def method1(self):
+ pass
+ def method2(self):
+ pass
+ ```
+ This will produce the following docs:
+
+ ```
+ /Parent.md
+ # method2
+ /Child.md
+ # method2
+ ```
+
+  When generating docs for a class's attributes, the `__mro__` is searched and
+  the attribute is skipped if this decorator is detected on the attribute in
+  any class of the `__mro__`.
+
+ Note: This is implemented by adding a hidden attribute on the object, so it
+ cannot be used on objects which do not allow new attributes to be added. So
+ this decorator must go *below* `@property`, `@classmethod`,
+ or `@staticmethod`:
+
+ ```
+ class Example(object):
+ @property
+ @do_not_doc_inheritable
+ def x(self):
+ return self._x
+ ```
+
+ Args:
+ obj: The class-attribute to hide from the generated docs.
+
+ Returns:
+ obj
+ """
+ setattr(obj, _DO_NOT_DOC_INHERITABLE, None)
+ return obj
+
+
+_FOR_SUBCLASS_IMPLEMENTERS = "_tf_docs_tools_for_subclass_implementers"
+
+
+def for_subclass_implementers(obj):
+ """A decorator: Only generate docs for this method in the defining class.
+
+  Also group this method's docs with an `@abstractmethod` in the class's docs.
+
+  No docs will be generated for this class attribute in subclasses.
+
+ The canonical use case for this is `tf.keras.layers.Layer.call`: It's a
+ public method, essential for anyone implementing a subclass, but it should
+ never be called directly.
+
+  Works on methods and other class attributes.
+
+  When generating docs for a class's attributes, the `__mro__` is searched and
+  the attribute is skipped if this decorator is detected on the attribute in
+  any **parent** class of the `__mro__`.
+
+ For example:
+
+ ```
+ class Parent(object):
+ @for_subclass_implementers
+ def method1(self):
+ pass
+ def method2(self):
+ pass
+
+ class Child1(Parent):
+ def method1(self):
+ pass
+ def method2(self):
+ pass
+
+ class Child2(Parent):
+ def method1(self):
+ pass
+ def method2(self):
+ pass
+ ```
+
+ This will produce the following docs:
+
+ ```
+ /Parent.md
+ # method1
+ # method2
+ /Child1.md
+ # method2
+ /Child2.md
+ # method2
+ ```
+
+ Note: This is implemented by adding a hidden attribute on the object, so it
+ cannot be used on objects which do not allow new attributes to be added. So
+ this decorator must go *below* `@property`, `@classmethod`,
+ or `@staticmethod`:
+
+ ```
+ class Example(object):
+ @property
+ @for_subclass_implementers
+ def x(self):
+ return self._x
+ ```
+
+ Args:
+ obj: The class-attribute to hide from the generated docs.
+
+ Returns:
+ obj
+ """
+ setattr(obj, _FOR_SUBCLASS_IMPLEMENTERS, None)
+ return obj
+
+
+def should_skip(obj):
+ """Returns true if docs generation should be skipped for this object.
+
+  Checks for the `do_not_generate_docs` or `do_not_doc_inheritable` decorators.
+
+ Args:
+ obj: The object to document, or skip.
+
+ Returns:
+    True if the object should be skipped.
+ """
+ # Unwrap fget if the object is a property
+ if isinstance(obj, property):
+ obj = obj.fget
+
+ return hasattr(obj, _DO_NOT_DOC) or hasattr(obj, _DO_NOT_DOC_INHERITABLE)
+
+
+def should_skip_class_attr(cls, name):
+ """Returns true if docs should be skipped for this class attribute.
+
+ Args:
+ cls: The class the attribute belongs to.
+ name: The name of the attribute.
+
+ Returns:
+ True if the attribute should be skipped.
+ """
+ # Get the object with standard lookup, from the nearest
+ # defining parent.
+ try:
+ obj = getattr(cls, name)
+ except AttributeError:
+ # Avoid error caused by enum metaclasses in python3
+ if name in ("name", "value"):
+ return True
+ raise
+
+ # Unwrap fget if the object is a property
+ if isinstance(obj, property):
+ obj = obj.fget
+
+ # Skip if the object is decorated with `do_not_generate_docs` or
+ # `do_not_doc_inheritable`
+ if should_skip(obj):
+ return True
+
+ # Use __dict__ lookup to get the version defined in *this* class.
+ obj = cls.__dict__.get(name, None)
+ if isinstance(obj, property):
+ obj = obj.fget
+ if obj is not None:
+ # If not none, the object is defined in *this* class.
+ # Do not skip if decorated with `for_subclass_implementers`.
+ if hasattr(obj, _FOR_SUBCLASS_IMPLEMENTERS):
+ return False
+
+  # For each parent class:
+ for parent in cls.__mro__[1:]:
+ obj = getattr(parent, name, None)
+
+ if obj is None:
+ continue
+
+ if isinstance(obj, property):
+ obj = obj.fget
+
+ # Skip if the parent's definition is decorated with `do_not_doc_inheritable`
+ # or `for_subclass_implementers`
+ if hasattr(obj, _DO_NOT_DOC_INHERITABLE):
+ return True
+
+ if hasattr(obj, _FOR_SUBCLASS_IMPLEMENTERS):
+ return True
+
+  # No blocking decorators --> don't skip
+ return False
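
Taken together, a sketch of how the three decorators in this new module compose (the class and method names are illustrative):

    from tensorflow.tools.docs import doc_controls

    class Base(object):  # hypothetical

      @doc_controls.do_not_generate_docs
      def debug_only(self):  # hidden here; documented again if a subclass overrides it
        pass

      @doc_controls.do_not_doc_inheritable
      def internal(self):  # hidden here and in every subclass
        pass

      @doc_controls.for_subclass_implementers
      def call(self, inputs):  # documented only where it is defined
        pass
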
diff --git a/tensorflow/tools/docs/doc_controls_test.py b/tensorflow/tools/docs/doc_controls_test.py
new file mode 100644
index 0000000000..410342fb69
--- /dev/null
+++ b/tensorflow/tools/docs/doc_controls_test.py
@@ -0,0 +1,183 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for documentation control decorators."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.platform import googletest
+from tensorflow.tools.docs import doc_controls
+
+
+class DocControlsTest(googletest.TestCase):
+
+ def test_do_not_generate_docs(self):
+
+ @doc_controls.do_not_generate_docs
+ def dummy_function():
+ pass
+
+ self.assertTrue(doc_controls.should_skip(dummy_function))
+
+ def test_do_not_doc_on_method(self):
+ """The simple decorator is not aware of inheritance."""
+
+ class Parent(object):
+
+ @doc_controls.do_not_generate_docs
+ def my_method(self):
+ pass
+
+ class Child(Parent):
+
+ def my_method(self):
+ pass
+
+ class GrandChild(Child):
+ pass
+
+ self.assertTrue(doc_controls.should_skip(Parent.my_method))
+ self.assertFalse(doc_controls.should_skip(Child.my_method))
+ self.assertFalse(doc_controls.should_skip(GrandChild.my_method))
+
+ self.assertTrue(doc_controls.should_skip_class_attr(Parent, 'my_method'))
+ self.assertFalse(doc_controls.should_skip_class_attr(Child, 'my_method'))
+ self.assertFalse(
+ doc_controls.should_skip_class_attr(GrandChild, 'my_method'))
+
+ def test_do_not_doc_inheritable(self):
+
+ class Parent(object):
+
+ @doc_controls.do_not_doc_inheritable
+ def my_method(self):
+ pass
+
+ class Child(Parent):
+
+ def my_method(self):
+ pass
+
+ class GrandChild(Child):
+ pass
+
+ self.assertTrue(doc_controls.should_skip(Parent.my_method))
+ self.assertFalse(doc_controls.should_skip(Child.my_method))
+ self.assertFalse(doc_controls.should_skip(GrandChild.my_method))
+
+ self.assertTrue(doc_controls.should_skip_class_attr(Parent, 'my_method'))
+ self.assertTrue(doc_controls.should_skip_class_attr(Child, 'my_method'))
+ self.assertTrue(
+ doc_controls.should_skip_class_attr(GrandChild, 'my_method'))
+
+ def test_do_not_doc_inheritable_property(self):
+
+ class Parent(object):
+
+ @property
+ @doc_controls.do_not_doc_inheritable
+ def my_method(self):
+ pass
+
+ class Child(Parent):
+
+ @property
+ def my_method(self):
+ pass
+
+ class GrandChild(Child):
+ pass
+
+ self.assertTrue(doc_controls.should_skip(Parent.my_method))
+ self.assertFalse(doc_controls.should_skip(Child.my_method))
+ self.assertFalse(doc_controls.should_skip(GrandChild.my_method))
+
+ self.assertTrue(doc_controls.should_skip_class_attr(Parent, 'my_method'))
+ self.assertTrue(doc_controls.should_skip_class_attr(Child, 'my_method'))
+ self.assertTrue(
+ doc_controls.should_skip_class_attr(GrandChild, 'my_method'))
+
+ def test_do_not_doc_inheritable_staticmethod(self):
+
+ class GrandParent(object):
+
+ def my_method(self):
+ pass
+
+ class Parent(GrandParent):
+
+ @staticmethod
+ @doc_controls.do_not_doc_inheritable
+ def my_method():
+ pass
+
+ class Child(Parent):
+
+ @staticmethod
+ def my_method():
+ pass
+
+ class GrandChild(Child):
+ pass
+
+ self.assertFalse(doc_controls.should_skip(GrandParent.my_method))
+ self.assertTrue(doc_controls.should_skip(Parent.my_method))
+ self.assertFalse(doc_controls.should_skip(Child.my_method))
+ self.assertFalse(doc_controls.should_skip(GrandChild.my_method))
+
+ self.assertFalse(
+ doc_controls.should_skip_class_attr(GrandParent, 'my_method'))
+ self.assertTrue(doc_controls.should_skip_class_attr(Parent, 'my_method'))
+ self.assertTrue(doc_controls.should_skip_class_attr(Child, 'my_method'))
+ self.assertTrue(
+ doc_controls.should_skip_class_attr(GrandChild, 'my_method'))
+
+  def test_for_subclass_implementers(self):
+
+ class GrandParent(object):
+
+ def my_method(self):
+ pass
+
+ class Parent(GrandParent):
+
+ @doc_controls.for_subclass_implementers
+ def my_method(self):
+ pass
+
+ class Child(Parent):
+ pass
+
+ class GrandChild(Child):
+
+ def my_method(self):
+ pass
+
+ class Grand2Child(Child):
+ pass
+
+ self.assertFalse(
+ doc_controls.should_skip_class_attr(GrandParent, 'my_method'))
+ self.assertFalse(doc_controls.should_skip_class_attr(Parent, 'my_method'))
+ self.assertTrue(doc_controls.should_skip_class_attr(Child, 'my_method'))
+ self.assertTrue(
+ doc_controls.should_skip_class_attr(GrandChild, 'my_method'))
+ self.assertTrue(
+ doc_controls.should_skip_class_attr(Grand2Child, 'my_method'))
+
+
+if __name__ == '__main__':
+ googletest.main()
diff --git a/tensorflow/tools/docs/generate_lib.py b/tensorflow/tools/docs/generate_lib.py
index 4bc8cbf4b4..9387042224 100644
--- a/tensorflow/tools/docs/generate_lib.py
+++ b/tensorflow/tools/docs/generate_lib.py
@@ -28,6 +28,7 @@ import six
from tensorflow.python.util import tf_inspect
from tensorflow.tools.common import public_api
from tensorflow.tools.common import traverse
+from tensorflow.tools.docs import doc_controls
from tensorflow.tools.docs import doc_generator_visitor
from tensorflow.tools.docs import parser
from tensorflow.tools.docs import pretty_docs
@@ -96,7 +97,7 @@ def write_docs(output_dir,
symbol_to_file = {}
# Collect redirects for an api _redirects.yaml file.
- redirects = ['redirects:\n']
+ redirects = []
# Parse and write Markdown pages, resolving cross-links (@{symbol}).
for full_name, py_object in six.iteritems(parser_config.index):
@@ -110,6 +111,9 @@ def write_docs(output_dir,
_is_free_function(py_object, full_name, parser_config.index)):
continue
+ if doc_controls.should_skip(py_object):
+ continue
+
sitepath = os.path.join('api_docs/python',
parser.documentation_path(full_name)[:-3])
@@ -162,17 +166,20 @@ def write_docs(output_dir,
continue
duplicates = [item for item in duplicates if item != full_name]
- template = ('- from: /{}\n'
- ' to: /{}\n')
+
for dup in duplicates:
from_path = os.path.join(site_api_path, dup.replace('.', '/'))
to_path = os.path.join(site_api_path, full_name.replace('.', '/'))
- redirects.append(
- template.format(from_path, to_path))
+ redirects.append((from_path, to_path))
- if site_api_path:
+ if site_api_path and redirects:
+ redirects = sorted(redirects)
+ template = ('- from: /{}\n'
+ ' to: /{}\n')
+ redirects = [template.format(f, t) for f, t in redirects]
api_redirects_path = os.path.join(output_dir, '_redirects.yaml')
with open(api_redirects_path, 'w') as redirect_file:
+ redirect_file.write('redirects:\n')
redirect_file.write(''.join(redirects))
if yaml_toc:
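
With the rewrite above, redirects are accumulated as (from, to) tuples, sorted, and rendered once under a single header, so the emitted _redirects.yaml looks roughly like this (paths illustrative):

    redirects:
    - from: /api_docs/python/tf/alias_name
      to: /api_docs/python/tf/canonical_name
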
diff --git a/tensorflow/tools/docs/parser.py b/tensorflow/tools/docs/parser.py
index ffb93027ed..801c8bcb4a 100644
--- a/tensorflow/tools/docs/parser.py
+++ b/tensorflow/tools/docs/parser.py
@@ -32,6 +32,7 @@ import six
from google.protobuf.message import Message as ProtoMessage
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import tf_inspect
+from tensorflow.tools.docs import doc_controls
# A regular expression capturing a python identifier.
@@ -1175,15 +1176,18 @@ class _ClassPageInfo(object):
# Don't document anything that is defined in object or by protobuf.
defining_class = _get_defining_class(py_class, short_name)
- if (defining_class is object or
- defining_class is type or defining_class is tuple or
- defining_class is BaseException or defining_class is Exception or
- # The following condition excludes most protobuf-defined symbols.
- defining_class and defining_class.__name__ in ['CMessage', 'Message',
- 'MessageMeta']):
+ if defining_class in [object, type, tuple, BaseException, Exception]:
+ continue
+
+ # The following condition excludes most protobuf-defined symbols.
+ if (defining_class and
+ defining_class.__name__ in ['CMessage', 'Message', 'MessageMeta']):
continue
# TODO(markdaoust): Add a note in child docs showing the defining class.
+ if doc_controls.should_skip_class_attr(py_class, short_name):
+ continue
+
child_doc = _parse_md_docstring(child, relative_path,
parser_config.reference_resolver)
diff --git a/tensorflow/tools/docs/parser_test.py b/tensorflow/tools/docs/parser_test.py
index 274d48ef66..9f6b185e81 100644
--- a/tensorflow/tools/docs/parser_test.py
+++ b/tensorflow/tools/docs/parser_test.py
@@ -24,6 +24,7 @@ import sys
from tensorflow.python.platform import googletest
from tensorflow.python.util import tf_inspect
+from tensorflow.tools.docs import doc_controls
from tensorflow.tools.docs import parser
@@ -37,13 +38,27 @@ def test_function_with_args_kwargs(unused_arg, *unused_args, **unused_kwargs):
pass
-class TestClass(object):
+class ParentClass(object):
+
+ @doc_controls.do_not_doc_inheritable
+ def hidden_method(self):
+ pass
+
+
+class TestClass(ParentClass):
"""Docstring for TestClass itself."""
def a_method(self, arg='default'):
"""Docstring for a method."""
pass
+ def hidden_method(self):
+ pass
+
+ @doc_controls.do_not_generate_docs
+ def hidden_method2(self):
+ pass
+
class ChildClass(object):
"""Docstring for a child class."""
pass
@@ -175,6 +190,104 @@ class ParserTest(googletest.TestCase):
# Make sure this file is contained as the definition location.
self.assertEqual(os.path.relpath(__file__, '/'), page_info.defined_in.path)
+ def test_docs_for_class_should_skip(self):
+
+ class Parent(object):
+
+ @doc_controls.do_not_doc_inheritable
+ def a_method(self, arg='default'):
+ pass
+
+ class Child(Parent):
+
+ def a_method(self, arg='default'):
+ pass
+
+ index = {
+ 'Child': Child,
+ 'Child.a_method': Child.a_method,
+ }
+
+ visitor = DummyVisitor(index=index, duplicate_of={})
+
+ reference_resolver = parser.ReferenceResolver.from_visitor(
+ visitor=visitor, doc_index={}, py_module_names=['tf'])
+
+ tree = {
+ 'Child': ['a_method'],
+ }
+
+ parser_config = parser.ParserConfig(
+ reference_resolver=reference_resolver,
+ duplicates={},
+ duplicate_of={},
+ tree=tree,
+ index=index,
+ reverse_index={},
+ guide_index={},
+ base_dir='/')
+
+ page_info = parser.docs_for_object(
+ full_name='Child', py_object=Child, parser_config=parser_config)
+
+ # Make sure the `a_method` is not present
+ self.assertEqual(0, len(page_info.methods))
+
+ def test_docs_for_message_class(self):
+
+ class CMessage(object):
+
+ def hidden(self):
+ pass
+
+ class Message(object):
+
+ def hidden2(self):
+ pass
+
+ class MessageMeta(object):
+
+ def hidden3(self):
+ pass
+
+ class ChildMessage(CMessage, Message, MessageMeta):
+
+ def my_method(self):
+ pass
+
+ index = {
+ 'ChildMessage': ChildMessage,
+ 'ChildMessage.hidden': ChildMessage.hidden,
+ 'ChildMessage.hidden2': ChildMessage.hidden2,
+ 'ChildMessage.hidden3': ChildMessage.hidden3,
+ 'ChildMessage.my_method': ChildMessage.my_method,
+ }
+
+ visitor = DummyVisitor(index=index, duplicate_of={})
+
+ reference_resolver = parser.ReferenceResolver.from_visitor(
+ visitor=visitor, doc_index={}, py_module_names=['tf'])
+
+ tree = {'ChildMessage': ['hidden', 'hidden2', 'hidden3', 'my_method']}
+
+ parser_config = parser.ParserConfig(
+ reference_resolver=reference_resolver,
+ duplicates={},
+ duplicate_of={},
+ tree=tree,
+ index=index,
+ reverse_index={},
+ guide_index={},
+ base_dir='/')
+
+ page_info = parser.docs_for_object(
+ full_name='ChildMessage',
+ py_object=ChildMessage,
+ parser_config=parser_config)
+
+ self.assertEqual(1, len(page_info.methods))
+ self.assertEqual('my_method', page_info.methods[0].short_name)
+
def test_docs_for_module(self):
# Get the current module.
module = sys.modules[__name__]
diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD
index ef7ae1aa25..6bba139b4d 100644
--- a/tensorflow/tools/pip_package/BUILD
+++ b/tensorflow/tools/pip_package/BUILD
@@ -9,7 +9,7 @@ load(
"if_windows",
"transitive_hdrs",
)
-load("//third_party/mkl:build_defs.bzl", "if_mkl")
+load("//third_party/mkl:build_defs.bzl", "if_mkl", "if_mkl_ml")
load("//tensorflow:tensorflow.bzl", "if_cuda")
load("@local_config_syslibs//:build_defs.bzl", "if_not_system_lib")
load("//tensorflow/core:platform/default/build_config_root.bzl", "tf_additional_license_deps")
@@ -208,14 +208,13 @@ sh_binary(
srcs = ["build_pip_package.sh"],
data = select({
"//tensorflow:windows": [":simple_console_for_windows"],
- "//tensorflow:windows_msvc": [":simple_console_for_windows"],
"//conditions:default": COMMON_PIP_DEPS + [
":simple_console",
"//tensorflow/contrib/lite/python:interpreter_test_data",
"//tensorflow/contrib/lite/python:tflite_convert",
"//tensorflow/contrib/lite/toco/python:toco_from_protos",
],
- }) + if_mkl(["//third_party/mkl:intel_binary_blob"]),
+ }) + if_mkl_ml(["//third_party/intel_mkl_ml"]),
)
# A genrule for generating a marker file for the pip package on Windows
diff --git a/tensorflow/tools/pip_package/build_pip_package.sh b/tensorflow/tools/pip_package/build_pip_package.sh
index ca40f2eaa8..666ea75d46 100755
--- a/tensorflow/tools/pip_package/build_pip_package.sh
+++ b/tensorflow/tools/pip_package/build_pip_package.sh
@@ -44,7 +44,7 @@ function cp_external() {
PLATFORM="$(uname -s | tr 'A-Z' 'a-z')"
function is_windows() {
# On windows, the shell script is actually running in msys
- if [[ "${PLATFORM}" =~ msys_nt* ]]; then
+ if [[ "${PLATFORM}" =~ (mingw64|msys)_nt* ]]; then
true
else
false
diff --git a/tensorflow/tools/proto_text/BUILD b/tensorflow/tools/proto_text/BUILD
index 31e8fb9120..fc2c041b6c 100644
--- a/tensorflow/tools/proto_text/BUILD
+++ b/tensorflow/tools/proto_text/BUILD
@@ -49,7 +49,6 @@ cc_library(
copts = if_ios(["-DGOOGLE_LOGGING"]),
linkopts = select({
"//tensorflow:windows": [],
- "//tensorflow:windows_msvc": [],
"//tensorflow:darwin": [
"-lm",
"-lpthread",
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index 8b9dd7c14d..10bfe7e0f6 100644
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -157,11 +157,11 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
tf_http_archive(
name = "com_googlesource_code_re2",
urls = [
- "https://mirror.bazel.build/github.com/google/re2/archive/2018-04-01.tar.gz",
- "https://github.com/google/re2/archive/2018-04-01.tar.gz",
+ "https://mirror.bazel.build/github.com/google/re2/archive/2018-07-01.tar.gz",
+ "https://github.com/google/re2/archive/2018-07-01.tar.gz",
],
- sha256 = "2f945446b71336e7f5a2bcace1abcf0b23fbba368266c6a1be33de3de3b3c912",
- strip_prefix = "re2-2018-04-01",
+ sha256 = "803c7811146edeef8f91064de37c6f19136ff01a2a8cdb3230e940b2fd9f07fe",
+ strip_prefix = "re2-2018-07-01",
system_build_file = clean_dep("//third_party/systemlibs:re2.BUILD"),
)
@@ -486,11 +486,11 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
tf_http_archive(
name = "llvm",
urls = [
- "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/36f54002c931a026f490f9fb074c11d91e3487a2.tar.gz",
- "https://github.com/llvm-mirror/llvm/archive/36f54002c931a026f490f9fb074c11d91e3487a2.tar.gz",
+ "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/17454e67ca55357e103cec104c3dc973bbb11ff0.tar.gz",
+ "https://github.com/llvm-mirror/llvm/archive/17454e67ca55357e103cec104c3dc973bbb11ff0.tar.gz",
],
- sha256 = "e360a9e9b0d4f1adedcdb89fc1efc171f68e250c115ddfaeb82d71edef7a10c8",
- strip_prefix = "llvm-36f54002c931a026f490f9fb074c11d91e3487a2",
+ sha256 = "7543322052e27e70f882801ef70a45afc268e09aaf6a07b840450bfcac366eb6",
+ strip_prefix = "llvm-17454e67ca55357e103cec104c3dc973bbb11ff0",
build_file = clean_dep("//third_party/llvm:llvm.autogenerated.BUILD"),
)
diff --git a/third_party/curl.BUILD b/third_party/curl.BUILD
index 1638b72161..c93fac6549 100644
--- a/third_party/curl.BUILD
+++ b/third_party/curl.BUILD
@@ -243,7 +243,6 @@ cc_library(
"lib/vtls/darwinssl.c",
],
"@org_tensorflow//tensorflow:windows": CURL_WIN_SRCS,
- "@org_tensorflow//tensorflow:windows_msvc": CURL_WIN_SRCS,
"//conditions:default": [
"lib/vtls/openssl.c",
],
@@ -260,7 +259,6 @@ cc_library(
],
copts = select({
"@org_tensorflow//tensorflow:windows": CURL_WIN_COPTS,
- "@org_tensorflow//tensorflow:windows_msvc": CURL_WIN_COPTS,
"//conditions:default": [
"-Iexternal/curl/lib",
"-D_GNU_SOURCE",
@@ -280,10 +278,6 @@ cc_library(
# See curl.h for discussion of write size and Windows
"/DCURL_MAX_WRITE_SIZE=16384",
],
- "@org_tensorflow//tensorflow:windows_msvc": [
- # See curl.h for discussion of write size and Windows
- "/DCURL_MAX_WRITE_SIZE=16384",
- ],
"//conditions:default": [
"-DCURL_MAX_WRITE_SIZE=65536",
],
@@ -307,12 +301,6 @@ cc_library(
"-DEFAULTLIB:crypt32.lib",
"-DEFAULTLIB:Normaliz.lib",
],
- "@org_tensorflow//tensorflow:windows_msvc": [
- "-DEFAULTLIB:ws2_32.lib",
- "-DEFAULTLIB:advapi32.lib",
- "-DEFAULTLIB:crypt32.lib",
- "-DEFAULTLIB:Normaliz.lib",
- ],
"//conditions:default": [
"-lrt",
],
@@ -323,7 +311,6 @@ cc_library(
] + select({
"@org_tensorflow//tensorflow:ios": [],
"@org_tensorflow//tensorflow:windows": [],
- "@org_tensorflow//tensorflow:windows_msvc": [],
"//conditions:default": [
"@boringssl//:ssl",
],
@@ -426,7 +413,6 @@ cc_binary(
],
copts = select({
"@org_tensorflow//tensorflow:windows": CURL_BIN_WIN_COPTS,
- "@org_tensorflow//tensorflow:windows_msvc": CURL_BIN_WIN_COPTS,
"//conditions:default": [
"-Iexternal/curl/lib",
"-D_GNU_SOURCE",
diff --git a/third_party/double_conversion.BUILD b/third_party/double_conversion.BUILD
index 9f905216c0..d875a1a2b5 100644
--- a/third_party/double_conversion.BUILD
+++ b/third_party/double_conversion.BUILD
@@ -4,6 +4,11 @@ licenses(["notice"])
exports_files(["LICENSE"])
+config_setting(
+ name = "windows",
+ values = {"cpu": "x64_windows"},
+)
+
cc_library(
name = "double-conversion",
srcs = [
@@ -28,11 +33,10 @@ cc_library(
"double-conversion/ieee.h",
"double-conversion/strtod.h",
],
- includes = [
- ".",
- ],
- linkopts = [
- "-lm",
- ],
+ includes = ["."],
+ linkopts = select({
+ ":windows": [],
+ "//conditions:default": ["-lm"],
+ }),
visibility = ["//visibility:public"],
)
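This is the pattern most of the remaining files repeat: with the x64_windows_msvc CPU value retired, each external BUILD file keeps a single windows config_setting keyed on --cpu=x64_windows and select()s per-platform flags. A standalone sketch with a hypothetical library (MSVC has no -lm, so libm is linked only on the default, POSIX, branch):

    config_setting(
        name = "windows",
        values = {"cpu": "x64_windows"},
    )

    cc_library(
        name = "mathlib",  # hypothetical
        srcs = ["mathlib.c"],
        linkopts = select({
            ":windows": [],                   # MSVC: no libm to link
            "//conditions:default": ["-lm"],  # POSIX toolchains need libm
        }),
    )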
diff --git a/third_party/farmhash.BUILD b/third_party/farmhash.BUILD
index a51e1511c1..4b8464684a 100644
--- a/third_party/farmhash.BUILD
+++ b/third_party/farmhash.BUILD
@@ -3,13 +3,6 @@ licenses(["notice"]) # MIT
exports_files(["COPYING"])
config_setting(
- name = "windows_msvc",
- values = {
- "cpu": "x64_windows_msvc",
- },
-)
-
-config_setting(
name = "windows",
values = {
"cpu": "x64_windows",
@@ -23,7 +16,6 @@ cc_library(
# Disable __builtin_expect support on Windows
copts = select({
":windows": ["/DFARMHASH_OPTIONAL_BUILTIN_EXPECT"],
- ":windows_msvc": ["/DFARMHASH_OPTIONAL_BUILTIN_EXPECT"],
"//conditions:default": [],
}),
includes = ["src/."],
diff --git a/third_party/fft2d/fft2d.BUILD b/third_party/fft2d/fft2d.BUILD
index 3dbd36aec0..74dd3112fc 100644
--- a/third_party/fft2d/fft2d.BUILD
+++ b/third_party/fft2d/fft2d.BUILD
@@ -14,6 +14,11 @@ FFT2D_SRCS = [
"fft/fftsg.c",
]
+config_setting(
+ name = "windows",
+ values = {"cpu": "x64_windows"},
+)
+
# This is the main 2D FFT library. The 2D FFTs in this library call
# 1D FFTs. In addition, fast DCTs are provided for the special case
# of 8x8 and 16x16. The code in this library is referred to as
@@ -21,7 +26,10 @@ FFT2D_SRCS = [
cc_library(
name = "fft2d",
srcs = FFT2D_SRCS,
- linkopts = ["-lm"],
+ linkopts = select({
+ ":windows": [],
+ "//conditions:default": ["-lm"],
+ }),
)
objc_library(
diff --git a/third_party/flatbuffers/flatbuffers.BUILD b/third_party/flatbuffers/flatbuffers.BUILD
index 639dff2cd0..4a3701e893 100644
--- a/third_party/flatbuffers/flatbuffers.BUILD
+++ b/third_party/flatbuffers/flatbuffers.BUILD
@@ -12,12 +12,14 @@ config_setting(
visibility = ["//visibility:public"],
)
-FLATBUFFERS_COPTS = [
- "-fexceptions",
-] + select({
- "@bazel_tools//src:windows": [],
- "@bazel_tools//src:windows_msvc": [],
- "//conditions:default": ["-Wno-implicit-fallthrough"],
+config_setting(
+ name = "windows",
+ values = {"cpu": "x64_windows"},
+)
+
+FLATBUFFERS_COPTS = select({
+ ":windows": [],
+ "//conditions:default": ["-Wno-implicit-fallthrough", "-fexceptions"],
})
# Public flatc library to compile flatbuffer files at runtime.
@@ -121,6 +123,7 @@ cc_binary(
":freebsd": [
"-lm",
],
+ ":windows": [],
"//conditions:default": [
"-lm",
"-ldl",
diff --git a/third_party/gif.BUILD b/third_party/gif.BUILD
index 78fbd6c0e0..cbe730fe10 100644
--- a/third_party/gif.BUILD
+++ b/third_party/gif.BUILD
@@ -21,7 +21,6 @@ cc_library(
],
hdrs = ["lib/gif_lib.h"],
defines = select({
- #"@org_tensorflow//tensorflow:android": [
":android": [
"S_IREAD=S_IRUSR",
"S_IWRITE=S_IWUSR",
@@ -33,7 +32,6 @@ cc_library(
visibility = ["//visibility:public"],
deps = select({
":windows": [":windows_polyfill"],
- ":windows_msvc": [":windows_polyfill"],
"//conditions:default": [],
}),
)
@@ -51,13 +49,6 @@ genrule(
)
config_setting(
- name = "windows_msvc",
- values = {
- "cpu": "x64_windows_msvc",
- },
-)
-
-config_setting(
name = "windows",
values = {
"cpu": "x64_windows",
diff --git a/third_party/gpus/cuda_configure.bzl b/third_party/gpus/cuda_configure.bzl
index e848fa175c..f6a39aeaf1 100644
--- a/third_party/gpus/cuda_configure.bzl
+++ b/third_party/gpus/cuda_configure.bzl
@@ -61,6 +61,7 @@ CUDA_LIB_PATHS = [
CUPTI_HEADER_PATHS = [
"extras/CUPTI/include/",
"include/cuda/CUPTI/",
+ "include/",
]
# Lookup paths for the cupti library, relative to the
@@ -69,7 +70,7 @@ CUPTI_HEADER_PATHS = [
# the other CUDA libraries but rather in a special extras/CUPTI directory.
CUPTI_LIB_PATHS = [
"extras/CUPTI/lib64/",
- "lib/x86_64-linux-gnu",
+ "lib/x86_64-linux-gnu/",
"lib64/",
"extras/CUPTI/libx64/",
"extras/CUPTI/lib/",
@@ -96,6 +97,7 @@ CUDNN_INCLUDE_PATHS = [
NVVM_LIBDEVICE_PATHS = [
"nvvm/libdevice/",
"share/cuda/",
+ "lib/nvidia-cuda-toolkit/libdevice/",
]
# Files used to detect the NVVM libdevice path.
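The *_PATHS lists are candidate directories relative to the configured CUDA toolkit root; adding "include/" and the Debian-style "lib/nvidia-cuda-toolkit/libdevice/" lets the repository rule locate CUPTI headers and libdevice on distro-packaged CUDA installs. A minimal sketch of how such a lookup could work (the helper below is illustrative, not the actual code in cuda_configure.bzl):

    def _find_first(repository_ctx, toolkit_root, relative_paths, marker_file):
        """Returns the first candidate directory containing marker_file, else None."""
        for rel in relative_paths:
            candidate = toolkit_root + "/" + rel
            if repository_ctx.path(candidate + marker_file).exists:
                return candidate
        return None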
diff --git a/third_party/jpeg/jpeg.BUILD b/third_party/jpeg/jpeg.BUILD
index b36295ad06..96e7ac061c 100644
--- a/third_party/jpeg/jpeg.BUILD
+++ b/third_party/jpeg/jpeg.BUILD
@@ -22,7 +22,6 @@ libjpegturbo_copts = select({
"-w",
],
":windows": WIN_COPTS,
- ":windows_msvc": WIN_COPTS,
"//conditions:default": [
"-O3",
"-w",
@@ -425,7 +424,6 @@ genrule(
outs = ["jconfig.h"],
cmd = select({
":windows": "cp $(location jconfig_win.h) $@",
- ":windows_msvc": "cp $(location jconfig_win.h) $@",
":k8": "cp $(location jconfig_nowin_simd.h) $@",
":armeabi-v7a": "cp $(location jconfig_nowin_simd.h) $@",
":arm64-v8a": "cp $(location jconfig_nowin_simd.h) $@",
@@ -443,7 +441,6 @@ genrule(
outs = ["jconfigint.h"],
cmd = select({
":windows": "cp $(location jconfigint_win.h) $@",
- ":windows_msvc": "cp $(location jconfigint_win.h) $@",
"//conditions:default": "cp $(location jconfigint_nowin.h) $@",
}),
)
@@ -544,11 +541,6 @@ config_setting(
)
config_setting(
- name = "windows_msvc",
- values = {"cpu": "x64_windows_msvc"},
-)
-
-config_setting(
name = "linux_ppc64le",
values = {"cpu": "ppc"},
)
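Both genrules pick a pre-generated config header by select()'ing on the cmd attribute, so dropping the :windows_msvc branch leaves :windows covering every Windows build. The same technique in a self-contained sketch (file names are hypothetical):

    genrule(
        name = "configure",  # hypothetical
        srcs = ["config_win.h", "config_posix.h"],
        outs = ["config.h"],
        # cmd is chosen per platform; $@ expands to the single declared output.
        cmd = select({
            ":windows": "cp $(location config_win.h) $@",
            "//conditions:default": "cp $(location config_posix.h) $@",
        }),
    )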
diff --git a/third_party/lmdb.BUILD b/third_party/lmdb.BUILD
index 9b3e1d97c8..f36a698ee3 100644
--- a/third_party/lmdb.BUILD
+++ b/third_party/lmdb.BUILD
@@ -20,7 +20,6 @@ cc_library(
],
linkopts = select({
":windows": ["-DEFAULTLIB:advapi32.lib"], # InitializeSecurityDescriptor, SetSecurityDescriptorDacl
- ":windows_msvc": ["-DEFAULTLIB:advapi32.lib"],
"//conditions:default": ["-lpthread"],
}),
visibility = ["//visibility:public"],
@@ -30,8 +29,3 @@ config_setting(
name = "windows",
values = {"cpu": "x64_windows"},
)
-
-config_setting(
- name = "windows_msvc",
- values = {"cpu": "x64_windows_msvc"},
-)
diff --git a/third_party/mkl/BUILD b/third_party/mkl/BUILD
index a058c46cc4..efff7fd51b 100644
--- a/third_party/mkl/BUILD
+++ b/third_party/mkl/BUILD
@@ -2,17 +2,28 @@ licenses(["notice"]) # 3-Clause BSD
config_setting(
name = "using_mkl",
- values = {
- "define": "using_mkl=true",
+ define_values = {
+ "using_mkl": "true",
+ },
+ visibility = ["//visibility:public"],
+)
+
+config_setting(
+ name = "using_mkl_ml_only",
+ define_values = {
+ "using_mkl": "true",
+ "using_mkl_ml_only": "true",
},
visibility = ["//visibility:public"],
)
config_setting(
name = "using_mkl_lnx_x64",
+ define_values = {
+ "using_mkl": "true",
+ },
values = {
"cpu": "k8",
- "define": "using_mkl=true",
},
visibility = ["//visibility:public"],
)
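The move from values = {"define": "using_mkl=true"} to define_values is what makes the new three-way configuration possible: values is a dict and so can carry only a single "define" entry, while define_values matches an arbitrary set of --define flags, letting :using_mkl_ml_only require both defines at once. An illustrative config_setting with a hypothetical name:

    # Matches only when BOTH flags are present on the command line, e.g.:
    #   bazel build --define using_mkl=true --define using_mkl_ml_only=true //...
    config_setting(
        name = "using_both_defines",  # hypothetical example
        define_values = {
            "using_mkl": "true",
            "using_mkl_ml_only": "true",
        },
    )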
diff --git a/third_party/mkl/build_defs.bzl b/third_party/mkl/build_defs.bzl
index 53e02769da..06a8c3518c 100644
--- a/third_party/mkl/build_defs.bzl
+++ b/third_party/mkl/build_defs.bzl
@@ -1,6 +1,9 @@
# -*- Python -*-
"""Skylark macros for MKL.
if_mkl is a conditional to check if MKL is enabled or not.
+if_mkl_ml is a conditional to check if MKL-ML is enabled.
+if_mkl_ml_only is a conditional to check for MKL-ML-only (no MKL-DNN) mode.
+if_mkl_lnx_x64 is a conditional to check for MKL on Linux x86-64 builds.
mkl_repository is a repository rule for creating an MKL repository that can
be pointed at a local folder, or download MKL from the internet.
@@ -15,27 +18,89 @@ _TF_MKL_ROOT = "TF_MKL_ROOT"
def if_mkl(if_true, if_false = []):
"""Shorthand for select()'ing on whether we're building with MKL.
- Returns a select statement which evaluates to if_true if we're building
- with MKL enabled. Otherwise, the select statement evaluates to if_false.
+ Args:
+ if_true: expression to evaluate if building with MKL.
+ if_false: expression to evaluate if building without MKL.
+ Returns:
+ a select evaluating to either if_true or if_false as appropriate.
"""
return select({
- str(Label("//third_party/mkl:using_mkl")): if_true,
- "//conditions:default": if_false
+ "//third_party/mkl:using_mkl": if_true,
+ "//conditions:default": if_false,
+ })
+
+def if_mkl_ml(if_true, if_false = []):
+ """Shorthand for select()'ing on whether we're building with MKL-ML.
+
+ Args:
+ if_true: expression to evaluate if building with MKL-ML.
+ if_false: expression to evaluate if building without MKL-ML
+ (i.e. without MKL at all, or with MKL-DNN only).
+
+ Returns:
+ a select evaluating to either if_true or if_false as appropriate.
+ """
+ return select({
+ "//third_party/mkl_dnn:using_mkl_dnn_only":
+ if_false,
+ "//third_party/mkl:using_mkl": if_true,
+ "//conditions:default": if_false,
+ })
+
+def if_mkl_ml_only(if_true, if_false = []):
+ """Shorthand for select()'ing on whether we're building with MKL-ML only.
+
+ Args:
+ if_true: expression to evaluate if building with MKL-ML only.
+ if_false: expression to evaluate if building without MKL, or with MKL-DNN.
+
+ Returns:
+ a select evaluating to either if_true or if_false as appropriate.
+ """
+ return select({
+ "//third_party/mkl:using_mkl_ml_only": if_true,
+ "//conditions:default": if_false,
})
def if_mkl_lnx_x64(if_true, if_false = []):
- """Shorthand for select()'ing on whether we're building with MKL.
+ """Shorthand to select() on if MKL is on and the target is Linux x86-64.
- Returns a select statement which evaluates to if_true if we're building
- with MKL enabled. Otherwise, the select statement evaluates to if_false.
+ Args:
+ if_true: expression to evaluate if building with MKL is enabled and the
+ target platform is Linux x86-64.
+ if_false: expression to evaluate if building without MKL or for a
+ different platform.
+ Returns:
+ a select evaluating to either if_true or if_false as appropriate.
"""
return select({
- str(Label("//third_party/mkl:using_mkl_lnx_x64")): if_true,
- "//conditions:default": if_false
+ "//third_party/mkl:using_mkl_lnx_x64": if_true,
+ "//conditions:default": if_false,
})
+def mkl_deps():
+ """Shorthand for select() to pull in the correct set of MKL library deps.
+
+ Can pull in MKL-ML, MKL-DNN, both, or neither depending on config settings.
+
+ Returns:
+ a select evaluating to a list of library dependencies, suitable for
+ inclusion in the deps attribute of rules.
+ """
+ return select({
+ "//third_party/mkl_dnn:using_mkl_dnn_only":
+ ["@mkl_dnn"],
+ "//third_party/mkl:using_mkl_ml_only":
+ ["//third_party/mkl:intel_binary_blob"],
+ "//third_party/mkl:using_mkl":
+ [
+ "//third_party/mkl:intel_binary_blob",
+ "@mkl_dnn"
+ ],
+ "//conditions:default": []
+ })
def _enable_local_mkl(repository_ctx):
return _TF_MKL_ROOT in repository_ctx.os.environ
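With mkl_deps() in place, a kernel's BUILD file no longer needs hand-rolled if_mkl() selects for its dependencies; the macro resolves to MKL-ML, MKL-DNN, both, or nothing based on the --define flags. A minimal usage sketch (the target below is hypothetical):

    load("//third_party/mkl:build_defs.bzl", "if_mkl_lnx_x64", "mkl_deps")

    cc_library(
        name = "mkl_matmul_op",  # hypothetical
        srcs = ["mkl_matmul_op.cc"],
        # Extra flags only when MKL is enabled and the target is Linux x86-64.
        copts = if_mkl_lnx_x64(["-fopenmp"]),
        # MKL-ML, MKL-DNN, both, or neither, per the configured defines.
        deps = mkl_deps(),
    )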
diff --git a/third_party/mkl_dnn/BUILD b/third_party/mkl_dnn/BUILD
index d075809ee9..3e567fa9fc 100644
--- a/third_party/mkl_dnn/BUILD
+++ b/third_party/mkl_dnn/BUILD
@@ -4,8 +4,9 @@ exports_files(["LICENSE"])
config_setting(
name = "using_mkl_dnn_only",
- values = {
- "define": "using_mkl_dnn_only=true",
+ define_values = {
+ "using_mkl": "true",
+ "using_mkl_dnn_only": "true",
},
visibility = ["//visibility:public"],
)
diff --git a/third_party/nasm.BUILD b/third_party/nasm.BUILD
index 89330eac54..2b877883b9 100644
--- a/third_party/nasm.BUILD
+++ b/third_party/nasm.BUILD
@@ -142,7 +142,6 @@ cc_binary(
],
copts = select({
":windows": [],
- ":windows_msvc": [],
"//conditions:default": [
"-w",
"-std=c99",
@@ -150,7 +149,6 @@ cc_binary(
}),
defines = select({
":windows": [],
- ":windows_msvc": [],
"//conditions:default": [
"HAVE_SNPRINTF",
"HAVE_SYS_TYPES_H",
@@ -160,13 +158,6 @@ cc_binary(
)
config_setting(
- name = "windows_msvc",
- values = {
- "cpu": "x64_windows_msvc",
- },
-)
-
-config_setting(
name = "windows",
values = {
"cpu": "x64_windows",
diff --git a/third_party/png.BUILD b/third_party/png.BUILD
index 17c5449cc0..c26a289717 100644
--- a/third_party/png.BUILD
+++ b/third_party/png.BUILD
@@ -29,6 +29,10 @@ cc_library(
"pngwtran.c",
"pngwutil.c",
] + select({
+ ":windows": [
+ "intel/intel_init.c",
+ "intel/filter_sse2_intrinsics.c",
+ ],
"@org_tensorflow//tensorflow:linux_ppc64le": [
"powerpc/powerpc_init.c",
"powerpc/filter_vsx_intrinsics.c",
@@ -41,7 +45,14 @@ cc_library(
"pngconf.h",
],
includes = ["."],
- linkopts = ["-lm"],
+ copts = select({
+ ":windows": ["-DPNG_INTEL_SSE_OPT=1"],
+ "//conditions:default": [],
+ }),
+ linkopts = select({
+ ":windows": [],
+ "//conditions:default": ["-lm"],
+ }),
visibility = ["//visibility:public"],
deps = ["@zlib_archive//:zlib"],
)
@@ -52,3 +63,8 @@ genrule(
outs = ["pnglibconf.h"],
cmd = "sed -e 's/PNG_ZLIB_VERNUM 0/PNG_ZLIB_VERNUM 0x12b0/' $< >$@",
)
+
+config_setting(
+ name = "windows",
+ values = {"cpu": "x64_windows"},
+)
diff --git a/third_party/snappy.BUILD b/third_party/snappy.BUILD
index cc11f52d0e..d93f030769 100644
--- a/third_party/snappy.BUILD
+++ b/third_party/snappy.BUILD
@@ -18,17 +18,9 @@ cc_library(
"snappy-stubs-public.h",
],
hdrs = ["snappy.h"],
- copts = select({
- "@org_tensorflow//tensorflow:windows": [
- "/DHAVE_CONFIG_H",
- "/EHsc",
- ],
- "@org_tensorflow//tensorflow:windows_msvc": [
- "/DHAVE_CONFIG_H",
- "/EHsc",
- ],
+ copts = ["-DHAVE_CONFIG_H"] + select({
+ "@org_tensorflow//tensorflow:windows": [],
"//conditions:default": [
- "-DHAVE_CONFIG_H",
"-fno-exceptions",
"-Wno-sign-compare",
"-Wno-shift-negative-value",
diff --git a/third_party/sqlite.BUILD b/third_party/sqlite.BUILD
index 2876f305f1..8b876fb56f 100644
--- a/third_party/sqlite.BUILD
+++ b/third_party/sqlite.BUILD
@@ -4,7 +4,6 @@
licenses(["unencumbered"]) # Public Domain
SQLITE_COPTS = [
- "-Os",
"-DSQLITE_ENABLE_JSON1",
"-DHAVE_DECL_STRERROR_R=1",
"-DHAVE_STDINT_H=1",
@@ -15,15 +14,14 @@ SQLITE_COPTS = [
"@org_tensorflow//tensorflow:windows": [
"-DSQLITE_MAX_TRIGGER_DEPTH=100",
],
- "@org_tensorflow//tensorflow:windows_msvc": [
- "-DSQLITE_MAX_TRIGGER_DEPTH=100",
- ],
"@org_tensorflow//tensorflow:darwin": [
+ "-Os",
"-DHAVE_GMTIME_R=1",
"-DHAVE_LOCALTIME_R=1",
"-DHAVE_USLEEP=1",
],
"//conditions:default": [
+ "-Os",
"-DHAVE_FDATASYNC=1",
"-DHAVE_GMTIME_R=1",
"-DHAVE_LOCALTIME_R=1",
@@ -48,7 +46,7 @@ cc_library(
"SQLITE_OMIT_DEPRECATED",
],
linkopts = select({
- "@org_tensorflow//tensorflow:windows_msvc": [],
+ "@org_tensorflow//tensorflow:windows": [],
"//conditions:default": [
"-ldl",
"-lpthread",
diff --git a/third_party/swig.BUILD b/third_party/swig.BUILD
index f2f647401b..59a3d9e671 100644
--- a/third_party/swig.BUILD
+++ b/third_party/swig.BUILD
@@ -71,7 +71,6 @@ cc_binary(
],
copts = ["$(STACK_FRAME_UNLIMITED)"] + select({
":windows": [],
- ":windows_msvc": [],
"//conditions:default": [
"-Wno-parentheses",
"-Wno-unused-variable",
@@ -332,11 +331,6 @@ genrule(
)
config_setting(
- name = "windows_msvc",
- values = {"cpu": "x64_windows_msvc"},
-)
-
-config_setting(
name = "windows",
values = {"cpu": "x64_windows"},
)
diff --git a/third_party/zlib.BUILD b/third_party/zlib.BUILD
index e8048dd98a..33694eaaae 100644
--- a/third_party/zlib.BUILD
+++ b/third_party/zlib.BUILD
@@ -34,7 +34,6 @@ cc_library(
hdrs = ["zlib.h"],
copts = select({
"@org_tensorflow//tensorflow:windows": [],
- "@org_tensorflow//tensorflow:windows_msvc": [],
"//conditions:default": [
"-Wno-shift-negative-value",
"-DZ_HAVE_UNISTD_H",