author    Sami Kama <skama@nvidia.com>    2018-06-12 22:25:58 -0700
committer Sami Kama <skama@nvidia.com>    2018-06-12 22:25:58 -0700
commit    d56675f6662b651dc996786262a6d28ccf9e06e7 (patch)
tree      5fc14a49a2b59649815aef242a4a4c66ca8a4a75
parent    4979c54d90a7fdd7429feb50edd1520d819c9653 (diff)
parent    565640eae327b092edf43613f77ba5ab0747d20d (diff)
Merge conflict fix
-rw-r--r--CONTRIBUTING.md2
-rw-r--r--README.md1
-rw-r--r--RELEASE.md59
-rw-r--r--SECURITY.md11
-rw-r--r--tensorflow/BUILD7
-rw-r--r--tensorflow/api_template.__init__.py17
-rw-r--r--tensorflow/c/c_api.cc17
-rw-r--r--tensorflow/cc/framework/cc_op_gen.cc71
-rw-r--r--tensorflow/compiler/jit/BUILD2
-rw-r--r--tensorflow/compiler/jit/mark_for_compilation_pass.cc22
-rw-r--r--tensorflow/compiler/jit/xla_cluster_util.cc22
-rw-r--r--tensorflow/compiler/jit/xla_cluster_util.h3
-rw-r--r--tensorflow/compiler/jit/xla_device_ops.h41
-rw-r--r--tensorflow/compiler/jit/xla_fusion_optimizer.cc7
-rw-r--r--tensorflow/compiler/tests/BUILD2
-rw-r--r--tensorflow/compiler/tests/eager_test.py71
-rw-r--r--tensorflow/compiler/tests/random_ops_test.py38
-rw-r--r--tensorflow/compiler/tf2xla/functionalize_control_flow.cc8
-rw-r--r--tensorflow/compiler/tf2xla/kernels/if_op.cc30
-rw-r--r--tensorflow/compiler/tf2xla/kernels/random_ops.cc74
-rw-r--r--tensorflow/compiler/tf2xla/kernels/shape_op.cc15
-rw-r--r--tensorflow/compiler/tf2xla/lib/cholesky.cc1
-rw-r--r--tensorflow/compiler/tf2xla/literal_util.cc31
-rw-r--r--tensorflow/compiler/tf2xla/literal_util.h12
-rw-r--r--tensorflow/compiler/tf2xla/xla_compiler.cc1
-rw-r--r--tensorflow/compiler/tf2xla/xla_compiler_test.cc38
-rw-r--r--tensorflow/compiler/tf2xla/xla_helpers.cc11
-rw-r--r--tensorflow/compiler/xla/literal_util.cc22
-rw-r--r--tensorflow/compiler/xla/literal_util.h6
-rw-r--r--tensorflow/compiler/xla/literal_util_test.cc4
-rw-r--r--tensorflow/compiler/xla/python/local_computation_builder.cc2
-rw-r--r--tensorflow/compiler/xla/python/local_computation_builder.h2
-rw-r--r--tensorflow/compiler/xla/python/local_computation_builder.i2
-rw-r--r--tensorflow/compiler/xla/python/xla_client.py2
-rw-r--r--tensorflow/compiler/xla/python/xla_client_test.py12
-rw-r--r--tensorflow/compiler/xla/rpc/grpc_client_test.cc4
-rw-r--r--tensorflow/compiler/xla/rpc/grpc_service.h2
-rw-r--r--tensorflow/compiler/xla/rpc/grpc_service_main.cc6
-rw-r--r--tensorflow/compiler/xla/rpc/xla_service.proto16
-rw-r--r--tensorflow/compiler/xla/service/BUILD52
-rw-r--r--tensorflow/compiler/xla/service/algebraic_simplifier.cc31
-rw-r--r--tensorflow/compiler/xla/service/algebraic_simplifier_test.cc85
-rw-r--r--tensorflow/compiler/xla/service/batchnorm_expander.cc168
-rw-r--r--tensorflow/compiler/xla/service/batchnorm_expander.h6
-rw-r--r--tensorflow/compiler/xla/service/buffer_assignment_test.cc38
-rw-r--r--tensorflow/compiler/xla/service/channel_tracker.h1
-rw-r--r--tensorflow/compiler/xla/service/compilation_cache.cc78
-rw-r--r--tensorflow/compiler/xla/service/compilation_cache.h78
-rw-r--r--tensorflow/compiler/xla/service/computation_layout.h9
-rw-r--r--tensorflow/compiler/xla/service/copy_insertion_test.cc9
-rw-r--r--tensorflow/compiler/xla/service/cpu/BUILD1
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_compiler.cc3
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_options.cc39
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_options.h2
-rw-r--r--tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc291
-rw-r--r--tensorflow/compiler/xla/service/cpu/dot_op_emitter.h11
-rw-r--r--tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.cc2
-rw-r--r--tensorflow/compiler/xla/service/cpu/vector_support_library.cc22
-rw-r--r--tensorflow/compiler/xla/service/cpu/vector_support_library.h16
-rw-r--r--tensorflow/compiler/xla/service/dfs_hlo_visitor.h2
-rw-r--r--tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h3
-rw-r--r--tensorflow/compiler/xla/service/elemental_ir_emitter.cc8
-rw-r--r--tensorflow/compiler/xla/service/executable.h7
-rw-r--r--tensorflow/compiler/xla/service/gpu/BUILD30
-rw-r--r--tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc10
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_compiler.cc7
-rw-r--r--tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc3
-rw-r--r--tensorflow/compiler/xla/service/gpu/instruction_fusion.cc25
-rw-r--r--tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc31
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc498
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h35
-rw-r--r--tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc118
-rw-r--r--tensorflow/compiler/xla/service/gpu/multi_output_fusion.h55
-rw-r--r--tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc138
-rw-r--r--tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc3
-rw-r--r--tensorflow/compiler/xla/service/hlo_casting_utils.h5
-rw-r--r--tensorflow/compiler/xla/service/hlo_casting_utils_test.cc1
-rw-r--r--tensorflow/compiler/xla/service/hlo_computation.cc45
-rw-r--r--tensorflow/compiler/xla/service/hlo_cost_analysis.cc7
-rw-r--r--tensorflow/compiler/xla/service/hlo_cost_analysis.h1
-rw-r--r--tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc15
-rw-r--r--tensorflow/compiler/xla/service/hlo_cse.cc11
-rw-r--r--tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc31
-rw-r--r--tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc123
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator.cc8
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator.h2
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator_test.cc2
-rw-r--r--tensorflow/compiler/xla/service/hlo_graph_dumper.cc63
-rw-r--r--tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc2
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.cc627
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.h223
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction_test.cc21
-rw-r--r--tensorflow/compiler/xla/service/hlo_instructions.cc589
-rw-r--r--tensorflow/compiler/xla/service/hlo_instructions.h438
-rw-r--r--tensorflow/compiler/xla/service/hlo_module.cc19
-rw-r--r--tensorflow/compiler/xla/service/hlo_module.h17
-rw-r--r--tensorflow/compiler/xla/service/hlo_opcode.h1
-rw-r--r--tensorflow/compiler/xla/service/hlo_opcode_test.cc1
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser.cc18
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser_test.cc2
-rw-r--r--tensorflow/compiler/xla/service/hlo_rematerialization.cc29
-rw-r--r--tensorflow/compiler/xla/service/hlo_rematerialization_test.cc122
-rw-r--r--tensorflow/compiler/xla/service/hlo_sharding.cc27
-rw-r--r--tensorflow/compiler/xla/service/hlo_sharding.h15
-rw-r--r--tensorflow/compiler/xla/service/hlo_sharding_metadata.cc71
-rw-r--r--tensorflow/compiler/xla/service/hlo_verifier.cc50
-rw-r--r--tensorflow/compiler/xla/service/hlo_verifier.h1
-rw-r--r--tensorflow/compiler/xla/service/instruction_fusion.cc1
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc48
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h175
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc33
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h59
-rw-r--r--tensorflow/compiler/xla/service/local_service.cc1
-rw-r--r--tensorflow/compiler/xla/service/multi_output_fusion.cc342
-rw-r--r--tensorflow/compiler/xla/service/multi_output_fusion.h160
-rw-r--r--tensorflow/compiler/xla/service/service.h5
-rw-r--r--tensorflow/compiler/xla/service/shape_inference.cc357
-rw-r--r--tensorflow/compiler/xla/service/shape_inference.h25
-rw-r--r--tensorflow/compiler/xla/service/shape_inference_test.cc144
-rw-r--r--tensorflow/compiler/xla/service/transpose_folding_test.cc2
-rw-r--r--tensorflow/compiler/xla/service/tuple_points_to_analysis.cc27
-rw-r--r--tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc25
-rw-r--r--tensorflow/compiler/xla/service/tuple_simplifier.cc31
-rw-r--r--tensorflow/compiler/xla/service/tuple_simplifier.h9
-rw-r--r--tensorflow/compiler/xla/service/tuple_simplifier_test.cc77
-rw-r--r--tensorflow/compiler/xla/service/versioned_computation_handle.cc32
-rw-r--r--tensorflow/compiler/xla/service/versioned_computation_handle.h55
-rw-r--r--tensorflow/compiler/xla/tests/BUILD16
-rw-r--r--tensorflow/compiler/xla/tests/broadcast_simple_test.cc4
-rw-r--r--tensorflow/compiler/xla/tests/convolution_test.cc4
-rw-r--r--tensorflow/compiler/xla/tests/hlo_test_base.cc3
-rw-r--r--tensorflow/compiler/xla/tests/hlo_test_base.h17
-rw-r--r--tensorflow/compiler/xla/tests/hlo_verified_test_base.cc20
-rw-r--r--tensorflow/compiler/xla/tests/hlo_verified_test_base.h16
-rw-r--r--tensorflow/compiler/xla/tests/llvm_compiler_test.cc3
-rw-r--r--tensorflow/compiler/xla/tests/map_test.cc7
-rw-r--r--tensorflow/compiler/xla/tests/multioutput_fusion_test.cc134
-rw-r--r--tensorflow/compiler/xla/tests/token_hlo_test.cc124
-rw-r--r--tensorflow/compiler/xla/tools/BUILD1
-rw-r--r--tensorflow/compiler/xla/tools/replay_computation.cc52
-rw-r--r--tensorflow/compiler/xla/xla.proto94
-rw-r--r--tensorflow/compiler/xla/xla_data.proto516
-rw-r--r--tensorflow/contrib/autograph/__init__.py7
-rw-r--r--tensorflow/contrib/autograph/converters/BUILD12
-rw-r--r--tensorflow/contrib/autograph/converters/lists.py233
-rw-r--r--tensorflow/contrib/autograph/converters/lists_test.py130
-rw-r--r--tensorflow/contrib/autograph/converters/slices.py83
-rw-r--r--tensorflow/contrib/autograph/converters/slices_test.py59
-rw-r--r--tensorflow/contrib/autograph/impl/BUILD1
-rw-r--r--tensorflow/contrib/autograph/impl/conversion.py5
-rw-r--r--tensorflow/contrib/autograph/impl/directives.py68
-rw-r--r--tensorflow/contrib/autograph/pyct/static_analysis/type_info.py40
-rw-r--r--tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py18
-rw-r--r--tensorflow/contrib/autograph/pyct/templates.py15
-rw-r--r--tensorflow/contrib/batching/__init__.py1
-rw-r--r--tensorflow/contrib/batching/python/ops/batch_ops.py69
-rw-r--r--tensorflow/contrib/batching/python/ops/batch_ops_test.py50
-rw-r--r--tensorflow/contrib/bayesflow/python/ops/monte_carlo.py5
-rw-r--r--tensorflow/contrib/boosted_trees/estimator_batch/estimator.py40
-rw-r--r--tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py22
-rw-r--r--tensorflow/contrib/boosted_trees/estimator_batch/model.py8
-rw-r--r--tensorflow/contrib/boosted_trees/kernels/prediction_ops.cc54
-rw-r--r--tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.cc14
-rw-r--r--tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.h7
-rw-r--r--tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees_test.cc48
-rw-r--r--tensorflow/contrib/boosted_trees/ops/prediction_ops.cc70
-rw-r--r--tensorflow/contrib/boosted_trees/python/ops/prediction_ops.py1
-rw-r--r--tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py87
-rw-r--r--tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py117
-rw-r--r--tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py4
-rw-r--r--tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py44
-rwxr-xr-xtensorflow/contrib/cmake/tf_python.cmake45
-rw-r--r--tensorflow/contrib/control_flow/BUILD5
-rw-r--r--tensorflow/contrib/control_flow/python/cond_v2.py69
-rw-r--r--tensorflow/contrib/control_flow/python/cond_v2_test.py59
-rw-r--r--tensorflow/contrib/data/kernels/csv_dataset_op.cc84
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/BUILD30
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py21
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py12
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py207
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py218
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py100
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py45
-rw-r--r--tensorflow/contrib/data/python/ops/BUILD5
-rw-r--r--tensorflow/contrib/data/python/ops/batching.py12
-rw-r--r--tensorflow/contrib/data/python/ops/grouping.py10
-rw-r--r--tensorflow/contrib/data/python/ops/readers.py3
-rw-r--r--tensorflow/contrib/data/python/ops/scan_ops.py2
-rw-r--r--tensorflow/contrib/data/python/ops/stats_ops.py21
-rw-r--r--tensorflow/contrib/distribute/python/BUILD7
-rw-r--r--tensorflow/contrib/distribute/python/combinations.py29
-rw-r--r--tensorflow/contrib/distribute/python/cross_tower_ops.py221
-rw-r--r--tensorflow/contrib/distribute/python/cross_tower_ops_test.py156
-rw-r--r--tensorflow/contrib/distribute/python/cross_tower_utils.py145
-rw-r--r--tensorflow/contrib/distribute/python/multi_worker_strategy.py2
-rw-r--r--tensorflow/contrib/distributions/BUILD57
-rw-r--r--tensorflow/contrib/distributions/__init__.py2
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/fill_triangular_test.py98
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/scale_tril_test.py69
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/sinh_arcsinh_bijector_test.py28
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/transform_diagonal_test.py66
-rw-r--r--tensorflow/contrib/distributions/python/ops/bijectors/__init__.py6
-rw-r--r--tensorflow/contrib/distributions/python/ops/bijectors/fill_triangular.py148
-rw-r--r--tensorflow/contrib/distributions/python/ops/bijectors/scale_tril.py114
-rw-r--r--tensorflow/contrib/distributions/python/ops/bijectors/transform_diagonal.py102
-rw-r--r--tensorflow/contrib/eager/python/examples/l2hmc/l2hmc.py98
-rw-r--r--tensorflow/contrib/eager/python/examples/l2hmc/l2hmc_test.py202
-rw-r--r--tensorflow/contrib/eager/python/examples/l2hmc/neural_nets.py2
-rw-r--r--tensorflow/contrib/eager/python/examples/notebooks/4_high_level.ipynb4
-rw-r--r--tensorflow/contrib/layers/__init__.py2
-rw-r--r--tensorflow/contrib/layers/python/layers/layers.py8
-rw-r--r--tensorflow/contrib/layers/python/layers/layers_test.py23
-rw-r--r--tensorflow/contrib/learn/python/learn/experiment.py2
-rw-r--r--tensorflow/contrib/learn/python/learn/experiment_test.py2
-rw-r--r--tensorflow/contrib/lite/build_def.bzl3
-rw-r--r--tensorflow/contrib/lite/builtin_ops.h3
-rw-r--r--tensorflow/contrib/lite/examples/label_image/BUILD31
-rw-r--r--tensorflow/contrib/lite/examples/label_image/bitmap_helpers.cc28
-rw-r--r--tensorflow/contrib/lite/examples/label_image/bitmap_helpers.h4
-rw-r--r--tensorflow/contrib/lite/examples/label_image/label_image.cc12
-rw-r--r--tensorflow/contrib/lite/examples/label_image/label_image_test.cc16
-rw-r--r--tensorflow/contrib/lite/g3doc/ops_versioning.md206
-rw-r--r--tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md41
-rw-r--r--tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc28
-rw-r--r--tensorflow/contrib/lite/kernels/comparisons.cc66
-rw-r--r--tensorflow/contrib/lite/kernels/comparisons_test.cc333
-rw-r--r--tensorflow/contrib/lite/kernels/conv.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/elementwise.cc23
-rw-r--r--tensorflow/contrib/lite/kernels/elementwise_test.cc18
-rw-r--r--tensorflow/contrib/lite/kernels/embedding_lookup.cc57
-rw-r--r--tensorflow/contrib/lite/kernels/embedding_lookup_test.cc110
-rw-r--r--tensorflow/contrib/lite/kernels/fully_connected.cc2
-rw-r--r--tensorflow/contrib/lite/kernels/internal/kernel_utils.cc6
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h8
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h20
-rw-r--r--tensorflow/contrib/lite/kernels/internal/types.h16
-rw-r--r--tensorflow/contrib/lite/kernels/register.cc6
-rw-r--r--tensorflow/contrib/lite/kernels/transpose_conv.cc2
-rw-r--r--tensorflow/contrib/lite/kernels/transpose_conv_test.cc18
-rw-r--r--tensorflow/contrib/lite/model.cc3
-rw-r--r--tensorflow/contrib/lite/nnapi_delegate.cc3
-rw-r--r--tensorflow/contrib/lite/profiling/profile_summarizer.cc11
-rw-r--r--tensorflow/contrib/lite/python/BUILD3
-rw-r--r--tensorflow/contrib/lite/python/convert.py32
-rw-r--r--tensorflow/contrib/lite/python/convert_saved_model.py31
-rw-r--r--tensorflow/contrib/lite/python/lite.py20
-rw-r--r--tensorflow/contrib/lite/python/lite_test.py68
-rw-r--r--tensorflow/contrib/lite/python/tflite_convert.py38
-rw-r--r--tensorflow/contrib/lite/schema/schema.fbs11
-rwxr-xr-xtensorflow/contrib/lite/schema/schema_generated.h239
-rw-r--r--tensorflow/contrib/lite/testing/generate_examples.py122
-rw-r--r--tensorflow/contrib/lite/toco/export_tensorflow.cc6
-rw-r--r--tensorflow/contrib/lite/toco/g3doc/python_api.md3
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc5
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc2
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc4
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/quantize.cc2
-rw-r--r--tensorflow/contrib/lite/toco/import_tensorflow.cc103
-rw-r--r--tensorflow/contrib/lite/toco/model.h23
-rw-r--r--tensorflow/contrib/lite/toco/model_cmdline_flags.cc2
-rw-r--r--tensorflow/contrib/lite/toco/python/BUILD1
-rw-r--r--tensorflow/contrib/lite/toco/python/toco_python_api.cc13
-rw-r--r--tensorflow/contrib/lite/toco/tflite/export.cc18
-rw-r--r--tensorflow/contrib/lite/toco/tflite/import.cc27
-rw-r--r--tensorflow/contrib/lite/toco/tflite/operator.cc6
-rw-r--r--tensorflow/contrib/lite/toco/tflite/operator_test.cc5
-rw-r--r--tensorflow/contrib/lite/toco/toco_flags.proto9
-rw-r--r--tensorflow/contrib/lite/toco/tooling_util.cc17
-rw-r--r--tensorflow/contrib/lite/tools/benchmark/BUILD5
-rw-r--r--tensorflow/contrib/lite/tools/benchmark/README.md108
-rw-r--r--tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc50
-rw-r--r--tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h4
-rw-r--r--tensorflow/contrib/lite/tools/benchmark/command_line_flags_test.cc6
-rw-r--r--tensorflow/contrib/lite/tools/verifier_test.cc6
-rw-r--r--tensorflow/contrib/lookup/lookup_ops_test.py20
-rw-r--r--tensorflow/contrib/mixed_precision/python/loss_scale_optimizer.py42
-rw-r--r--tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc2
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu.py5
-rw-r--r--tensorflow/contrib/training/python/training/tensor_queue_dataset.py7
-rw-r--r--tensorflow/contrib/verbs/grpc_verbs_service.cc6
-rw-r--r--tensorflow/contrib/verbs/grpc_verbs_service_impl.cc16
-rw-r--r--tensorflow/contrib/verbs/grpc_verbs_service_impl.h16
-rw-r--r--tensorflow/core/BUILD2
-rw-r--r--tensorflow/core/api_def/base_api/api_def_BatchDatasetV2.pbtxt18
-rw-r--r--tensorflow/core/api_def/base_api/api_def_FeatureStatsDataset.pbtxt3
-rw-r--r--tensorflow/core/api_def/base_api/api_def_PaddedBatchDatasetV2.pbtxt35
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ResourceScatterNdAdd.pbtxt69
-rw-r--r--tensorflow/core/api_def/base_api/api_def_TensorArrayGradWithShape.pbtxt40
-rw-r--r--tensorflow/core/api_def/python_api/api_def_FeatureStatsDataset.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_ResourceScatterNdAdd.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_ScatterNdAdd.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_TensorArrayGradWithShape.pbtxt4
-rw-r--r--tensorflow/core/common_runtime/build_graph_options.cc3
-rw-r--r--tensorflow/core/common_runtime/build_graph_options.h3
-rw-r--r--tensorflow/core/common_runtime/collective_executor_mgr.cc18
-rw-r--r--tensorflow/core/common_runtime/collective_executor_mgr.h9
-rw-r--r--tensorflow/core/common_runtime/collective_executor_mgr_test.cc11
-rw-r--r--tensorflow/core/common_runtime/collective_param_resolver_local.h2
-rw-r--r--tensorflow/core/common_runtime/direct_session.cc10
-rw-r--r--tensorflow/core/common_runtime/eager/kernel_and_device.cc1
-rw-r--r--tensorflow/core/common_runtime/eager/kernel_and_device.h6
-rw-r--r--tensorflow/core/common_runtime/executor.cc27
-rw-r--r--tensorflow/core/common_runtime/executor_factory.cc85
-rw-r--r--tensorflow/core/common_runtime/executor_factory.h51
-rw-r--r--tensorflow/core/common_runtime/executor_test.cc10
-rw-r--r--tensorflow/core/common_runtime/function.cc16
-rw-r--r--tensorflow/core/common_runtime/function_test.cc72
-rw-r--r--tensorflow/core/common_runtime/kernel_benchmark_testlib.cc18
-rw-r--r--tensorflow/core/common_runtime/kernel_benchmark_testlib.h4
-rw-r--r--tensorflow/core/common_runtime/mkl_cpu_allocator.h6
-rw-r--r--tensorflow/core/debug/debug_grpc_testlib.h2
-rw-r--r--tensorflow/core/debug/debug_io_utils.cc2
-rw-r--r--tensorflow/core/distributed_runtime/BUILD50
-rw-r--r--tensorflow/core/distributed_runtime/cancellable_call.h65
-rw-r--r--tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc48
-rw-r--r--tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc7
-rw-r--r--tensorflow/core/distributed_runtime/collective_rma_distributed.cc42
-rw-r--r--tensorflow/core/distributed_runtime/graph_mgr.cc26
-rw-r--r--tensorflow/core/distributed_runtime/graph_mgr.h8
-rw-r--r--tensorflow/core/distributed_runtime/master_env.h5
-rw-r--r--tensorflow/core/distributed_runtime/master_session.cc78
-rw-r--r--tensorflow/core/distributed_runtime/master_session.h3
-rw-r--r--tensorflow/core/distributed_runtime/master_test.cc2
-rw-r--r--tensorflow/core/distributed_runtime/rpc/BUILD3
-rw-r--r--tensorflow/core/distributed_runtime/rpc/eager/eager_grpc_server_lib.h2
-rw-r--r--tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.cc2
-rw-r--r--tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.cc16
-rw-r--r--tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.h16
-rw-r--r--tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h6
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_call.h6
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_channel.cc2
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_channel.h2
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h2
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc4
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.cc16
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h16
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc4
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc45
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h15
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_state.h4
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding.cc4
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding_test.cc4
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server.cc6
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_testlib_server.cc6
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_util.h6
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc4
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.cc16
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h18
-rw-r--r--tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.cc142
-rw-r--r--tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h79
-rw-r--r--tensorflow/core/distributed_runtime/rpc_collective_executor_mgr_test.cc124
-rw-r--r--tensorflow/core/distributed_runtime/worker.cc10
-rw-r--r--tensorflow/core/framework/cost_graph.proto3
-rw-r--r--tensorflow/core/framework/device_base.h4
-rw-r--r--tensorflow/core/framework/function.cc4
-rw-r--r--tensorflow/core/framework/function.h6
-rw-r--r--tensorflow/core/framework/op_kernel.cc26
-rw-r--r--tensorflow/core/framework/op_kernel.h12
-rw-r--r--tensorflow/core/framework/op_kernel_test.cc22
-rw-r--r--tensorflow/core/graph/control_flow.cc11
-rw-r--r--tensorflow/core/graph/control_flow.h6
-rw-r--r--tensorflow/core/grappler/costs/analytical_cost_estimator.cc1
-rw-r--r--tensorflow/core/grappler/costs/op_level_cost_estimator.cc2
-rw-r--r--tensorflow/core/grappler/optimizers/BUILD3
-rw-r--r--tensorflow/core/grappler/optimizers/data/BUILD35
-rw-r--r--tensorflow/core/grappler/optimizers/data/graph_utils.cc17
-rw-r--r--tensorflow/core/grappler/optimizers/data/graph_utils.h4
-rw-r--r--tensorflow/core/grappler/optimizers/data/graph_utils_test.cc15
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc26
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.h8
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_and_batch_fusion_test.cc112
-rw-r--r--tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc112
-rw-r--r--tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.h46
-rw-r--r--tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion_test.cc149
-rw-r--r--tensorflow/core/grappler/optimizers/layout_optimizer.cc8
-rw-r--r--tensorflow/core/grappler/optimizers/layout_optimizer_test.cc15
-rw-r--r--tensorflow/core/grappler/optimizers/remapper.cc2
-rw-r--r--tensorflow/core/kernels/BUILD12
-rw-r--r--tensorflow/core/kernels/batch_matmul_op_complex.cc2
-rw-r--r--tensorflow/core/kernels/batch_matmul_op_real.cc2
-rw-r--r--tensorflow/core/kernels/boosted_trees/stats_ops.cc54
-rw-r--r--tensorflow/core/kernels/data/BUILD1
-rw-r--r--tensorflow/core/kernels/data/batch_dataset_op.cc46
-rw-r--r--tensorflow/core/kernels/data/iterator_ops.cc6
-rw-r--r--tensorflow/core/kernels/data/padded_batch_dataset_op.cc49
-rw-r--r--tensorflow/core/kernels/data/shuffle_dataset_op.cc217
-rw-r--r--tensorflow/core/kernels/data/stats_dataset_ops.cc185
-rw-r--r--tensorflow/core/kernels/functional_ops.cc10
-rw-r--r--tensorflow/core/kernels/matmul_op.cc3
-rw-r--r--tensorflow/core/kernels/mkl_aggregate_ops.cc11
-rw-r--r--tensorflow/core/kernels/mkl_batch_matmul_op.cc2
-rw-r--r--tensorflow/core/kernels/mkl_concat_op.cc7
-rw-r--r--tensorflow/core/kernels/mkl_conv_grad_bias_ops.cc2
-rw-r--r--tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc8
-rw-r--r--tensorflow/core/kernels/mkl_conv_grad_input_ops.cc2
-rw-r--r--tensorflow/core/kernels/mkl_fused_batch_norm_op.cc8
-rw-r--r--tensorflow/core/kernels/mkl_identity_op.cc2
-rw-r--r--tensorflow/core/kernels/mkl_input_conversion_op.cc16
-rw-r--r--tensorflow/core/kernels/mkl_lrn_op.cc10
-rw-r--r--tensorflow/core/kernels/mkl_matmul_op.cc2
-rw-r--r--tensorflow/core/kernels/mkl_relu_op.cc7
-rw-r--r--tensorflow/core/kernels/mkl_reshape_op.cc10
-rw-r--r--tensorflow/core/kernels/mkl_softmax_op.cc2
-rw-r--r--tensorflow/core/kernels/mkl_tfconv_op.h2
-rw-r--r--tensorflow/core/kernels/mkl_transpose_op.cc2
-rw-r--r--tensorflow/core/kernels/scatter_nd_op.cc24
-rw-r--r--tensorflow/core/kernels/tensor_array.cc10
-rw-r--r--tensorflow/core/kernels/tensor_array.h4
-rw-r--r--tensorflow/core/kernels/tensor_array_ops.cc46
-rw-r--r--tensorflow/core/kernels/transpose_op.cc2
-rw-r--r--tensorflow/core/kernels/transpose_op.h4
-rw-r--r--tensorflow/core/lib/io/random_inputstream.cc10
-rw-r--r--tensorflow/core/lib/strings/numbers_test.cc10
-rw-r--r--tensorflow/core/ops/compat/ops_history.v1.pbtxt169
-rw-r--r--tensorflow/core/ops/data_flow_ops.cc44
-rw-r--r--tensorflow/core/ops/dataset_ops.cc69
-rw-r--r--tensorflow/core/ops/functional_ops.cc2
-rw-r--r--tensorflow/core/ops/ops.pbtxt169
-rw-r--r--tensorflow/core/ops/state_ops.cc9
-rw-r--r--tensorflow/core/public/version.h4
-rw-r--r--tensorflow/core/util/exec_on_stall_test.cc23
-rw-r--r--tensorflow/core/util/mkl_util.h59
-rw-r--r--tensorflow/core/util/sparse/group_iterator.h7
-rw-r--r--tensorflow/docs_src/api_guides/python/contrib.bayesflow.monte_carlo.md50
-rw-r--r--tensorflow/docs_src/api_guides/python/contrib.distributions.bijectors.md32
-rw-r--r--tensorflow/docs_src/api_guides/python/contrib.distributions.md83
-rw-r--r--tensorflow/docs_src/get_started/eager.md2
-rw-r--r--tensorflow/docs_src/install/install_c.md2
-rw-r--r--tensorflow/docs_src/install/install_go.md2
-rw-r--r--tensorflow/docs_src/install/install_java.md24
-rw-r--r--tensorflow/docs_src/install/install_linux.md18
-rw-r--r--tensorflow/docs_src/install/install_mac.md10
-rw-r--r--tensorflow/docs_src/install/install_sources.md9
-rw-r--r--tensorflow/docs_src/mobile/tflite/index.md17
-rw-r--r--tensorflow/docs_src/programmers_guide/debugger.md26
-rw-r--r--tensorflow/docs_src/programmers_guide/keras.md870
-rw-r--r--tensorflow/go/op/wrappers.go1086
-rw-r--r--tensorflow/java/maven/libtensorflow/pom.xml2
-rw-r--r--tensorflow/java/maven/libtensorflow_jni/pom.xml2
-rw-r--r--tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml2
-rw-r--r--tensorflow/java/maven/pom.xml2
-rw-r--r--tensorflow/java/maven/proto/pom.xml4
-rw-r--r--tensorflow/java/maven/run_inside_container.sh2
-rw-r--r--tensorflow/java/maven/tensorflow/pom.xml2
-rw-r--r--tensorflow/java/src/gen/cc/op_generator.cc11
-rw-r--r--tensorflow/python/client/session.py159
-rw-r--r--tensorflow/python/data/kernel_tests/BUILD1
-rw-r--r--tensorflow/python/data/kernel_tests/batch_dataset_op_test.py243
-rw-r--r--tensorflow/python/data/kernel_tests/map_dataset_op_test.py21
-rw-r--r--tensorflow/python/data/ops/dataset_ops.py222
-rw-r--r--tensorflow/python/data/ops/readers.py3
-rw-r--r--tensorflow/python/data/util/BUILD1
-rw-r--r--tensorflow/python/data/util/convert.py37
-rw-r--r--tensorflow/python/data/util/convert_test.py73
-rw-r--r--tensorflow/python/eager/function.py17
-rw-r--r--tensorflow/python/eager/function_test.py17
-rw-r--r--tensorflow/python/eager/pywrap_tfe_src.cc82
-rw-r--r--tensorflow/python/estimator/BUILD9
-rw-r--r--tensorflow/python/estimator/api/BUILD17
-rw-r--r--tensorflow/python/estimator/canned/baseline.py6
-rw-r--r--tensorflow/python/estimator/canned/boosted_trees.py6
-rw-r--r--tensorflow/python/estimator/canned/dnn.py6
-rw-r--r--tensorflow/python/estimator/canned/dnn_linear_combined.py6
-rw-r--r--tensorflow/python/estimator/canned/linear.py6
-rw-r--r--tensorflow/python/estimator/canned/parsing_utils.py6
-rw-r--r--tensorflow/python/estimator/estimator.py24
-rw-r--r--tensorflow/python/estimator/estimator_test.py42
-rw-r--r--tensorflow/python/estimator/export/export.py10
-rw-r--r--tensorflow/python/estimator/export/export_output.py10
-rw-r--r--tensorflow/python/estimator/exporter.py10
-rw-r--r--tensorflow/python/estimator/inputs/numpy_io.py4
-rw-r--r--tensorflow/python/estimator/inputs/pandas_io.py4
-rw-r--r--tensorflow/python/estimator/model_fn.py6
-rw-r--r--tensorflow/python/estimator/run_config.py4
-rw-r--r--tensorflow/python/estimator/training.py8
-rw-r--r--tensorflow/python/feature_column/feature_column.py4
-rw-r--r--tensorflow/python/framework/function.py10
-rw-r--r--tensorflow/python/framework/test_util.py99
-rw-r--r--tensorflow/python/keras/backend.py19
-rw-r--r--tensorflow/python/keras/backend_test.py47
-rw-r--r--tensorflow/python/keras/datasets/boston_housing.py3
-rw-r--r--tensorflow/python/keras/datasets/fashion_mnist.py8
-rw-r--r--tensorflow/python/keras/datasets/imdb.py6
-rw-r--r--tensorflow/python/keras/datasets/mnist.py10
-rw-r--r--tensorflow/python/keras/datasets/reuters.py6
-rw-r--r--tensorflow/python/keras/engine/network.py10
-rw-r--r--tensorflow/python/keras/engine/saving_test.py52
-rw-r--r--tensorflow/python/keras/engine/sequential.py7
-rw-r--r--tensorflow/python/keras/engine/sequential_test.py24
-rw-r--r--tensorflow/python/keras/engine/training.py11
-rw-r--r--tensorflow/python/keras/engine/training_eager.py25
-rw-r--r--tensorflow/python/keras/engine/training_eager_test.py72
-rw-r--r--tensorflow/python/keras/engine/training_test.py70
-rw-r--r--tensorflow/python/keras/engine/training_utils.py30
-rw-r--r--tensorflow/python/keras/layers/convolutional.py83
-rw-r--r--tensorflow/python/keras/layers/core.py30
-rw-r--r--tensorflow/python/keras/layers/local.py44
-rw-r--r--tensorflow/python/keras/layers/local_test.py83
-rw-r--r--tensorflow/python/keras/layers/normalization.py6
-rw-r--r--tensorflow/python/kernel_tests/fifo_queue_test.py20
-rw-r--r--tensorflow/python/kernel_tests/linalg_grad_test.py6
-rw-r--r--tensorflow/python/kernel_tests/py_func_test.py79
-rw-r--r--tensorflow/python/kernel_tests/resource_variable_ops_test.py10
-rw-r--r--tensorflow/python/kernel_tests/tensor_array_ops_test.py54
-rw-r--r--tensorflow/python/lib/core/ndarray_tensor.cc40
-rw-r--r--tensorflow/python/ops/custom_gradient.py2
-rw-r--r--tensorflow/python/ops/data_flow_ops.py46
-rw-r--r--tensorflow/python/ops/gradients_test.py1
-rw-r--r--tensorflow/python/ops/image_ops_impl.py22
-rw-r--r--tensorflow/python/ops/image_ops_test.py83
-rw-r--r--tensorflow/python/ops/linalg_grad.py11
-rw-r--r--tensorflow/python/ops/lookup_ops.py8
-rw-r--r--tensorflow/python/ops/math_ops.py6
-rw-r--r--tensorflow/python/ops/script_ops.py128
-rw-r--r--tensorflow/python/ops/special_math_ops_test.py21
-rw-r--r--tensorflow/python/ops/state_ops.py72
-rw-r--r--tensorflow/python/ops/tensor_array_grad.py1
-rw-r--r--tensorflow/python/saved_model/BUILD1
-rw-r--r--tensorflow/python/saved_model/builder_impl.py46
-rw-r--r--tensorflow/python/saved_model/saved_model_test.py75
-rw-r--r--tensorflow/python/training/adadelta.py17
-rw-r--r--tensorflow/python/training/adadelta_test.py116
-rw-r--r--tensorflow/python/training/adagrad.py12
-rw-r--r--tensorflow/python/training/adagrad_test.py73
-rw-r--r--tensorflow/python/training/adam.py20
-rw-r--r--tensorflow/python/training/adam_test.py18
-rw-r--r--tensorflow/python/training/distribute.py2
-rw-r--r--tensorflow/python/training/gradient_descent.py15
-rw-r--r--tensorflow/python/training/gradient_descent_test.py26
-rw-r--r--tensorflow/python/training/momentum.py4
-rw-r--r--tensorflow/python/training/optimizer.py4
-rw-r--r--tensorflow/python/training/rmsprop.py22
-rw-r--r--tensorflow/python/training/rmsprop_test.py54
-rw-r--r--tensorflow/python/util/tf_export.py58
-rw-r--r--tensorflow/python/util/tf_export_test.py7
-rw-r--r--tensorflow/python/util/util.cc11
-rw-r--r--tensorflow/security/index.md4
-rw-r--r--tensorflow/stream_executor/cuda/cuda_dnn.cc59
-rw-r--r--tensorflow/tools/api/generator/api_gen.bzl20
-rw-r--r--tensorflow/tools/api/generator/create_python_api.py35
-rw-r--r--tensorflow/tools/api/generator/create_python_api_test.py9
-rw-r--r--tensorflow/tools/api/golden/tensorflow.data.-dataset.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/tensorflow.data.-fixed-length-record-dataset.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/tensorflow.data.-t-f-record-dataset.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/tensorflow.data.-text-line-dataset.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/tensorflow.image.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.saved_model.builder.-saved-model-builder.pbtxt4
-rwxr-xr-xtensorflow/tools/ci_build/builds/pip.sh4
-rwxr-xr-xtensorflow/tools/ci_build/copy_binary.py3
-rwxr-xr-xtensorflow/tools/ci_build/linux/mkl/basic-mkl-test.sh29
-rw-r--r--tensorflow/tools/docker/Dockerfile.devel2
-rw-r--r--tensorflow/tools/docker/Dockerfile.devel-cpu-mkl2
-rw-r--r--tensorflow/tools/docker/Dockerfile.devel-gpu2
-rw-r--r--tensorflow/tools/graph_transforms/fold_constants_lib.cc26
-rw-r--r--tensorflow/tools/graph_transforms/fold_constants_test.cc46
-rw-r--r--tensorflow/tools/pip_package/BUILD1
-rw-r--r--tensorflow/tools/pip_package/setup.py3
-rw-r--r--tensorflow/workspace.bzl17
-rw-r--r--third_party/clang_toolchain/download_clang.bzl8
-rw-r--r--third_party/eigen_fix_cuda_compilation.patch38
560 files changed, 16153 insertions, 6741 deletions
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 8669c25c45..db4b1581ae 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -90,7 +90,7 @@ Bazel BUILD files also need to include a license section, e.g.,
Changes to TensorFlow C++ code should conform to
[Google C++ Style Guide](https://google.github.io/styleguide/cppguide.html).
-Use `clang-tidy` to check your C/C++ changes. To install clang-tidy on ubuntu:16.04, do:
+Use `clang-tidy` to check your C/C++ changes. To install `clang-tidy` on ubuntu:16.04, do:
```bash
apt-get install -y clang-tidy
diff --git a/README.md b/README.md
index 6fb4486d0d..63853137cf 100644
--- a/README.md
+++ b/README.md
@@ -56,6 +56,7 @@ $ python
42
>>> sess.close()
```
+For more examples of how to perform specific tasks in TensorFlow, see the [tutorials page of tensorflow.org](https://www.tensorflow.org/tutorials/).
## Contribution guidelines
diff --git a/RELEASE.md b/RELEASE.md
index 27f73b7fc6..e09e9c6190 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -1,3 +1,62 @@
+# Release 1.9.0
+
+## Major Features And Improvements
+* Update tf.keras to the Keras 2.1.6 API.
+* `tfe.Network` is deprecated. Please inherit from `tf.keras.Model`.
+* Added support for core feature columns and losses to gradient boosted trees estimators.
+* The distributions.Bijector API supports broadcasting for Bijectors with new API changes. See [here](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/distributions/bijectors/Bijector) for more details.
+* Layered variable names have changed in the following conditions:
+ * Using `tf.keras.layers` with custom variable scopes.
+ * Using `tf.layers` in a subclassed `tf.keras.Model` class. See [here](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/layers) for more details
+
+## Breaking Changes
+ * If you're opening empty variable scopes, replace `variable_scope('', ...)` with `variable_scope(tf.get_variable_scope(), ...)`.
+
+## Bug Fixes and Other Changes
+* `tf.data`:
+ * The `DatasetBase::DebugString()` method is now `const`.
+ * Added the `tf.contrib.data.sample_from_datasets()` API for randomly sampling from multiple datasets.
+* Eager Execution:
+* `tf.keras`:
+ * Move Keras code out of _impl folder and remove API files.
+ * `tf.keras.Model.save_weights` now saves in TensorFlow format by default.
+ * Enable dataset iterators to be passed to `tf.keras.Model` training/eval methods.
+* Accelerated Linear Algebra (XLA):
+* TensorFlow Debugger (tfdbg): fix an issue in which the TensorBoard Debugger Plugin could not handle total source file size exceeding gRPC message size limit (4 MB).
+* `tf.contrib`:
+ * Add `tf.contrib.data.choose_from_datasets()`.
+ * `tf.contrib.data.make_csv_dataset()` now supports line breaks in quoted strings. Two arguments were removed from `make_csv_dataset`.
+ * `tf.contrib.framework.zero_initializer` supports ResourceVariable.
+ * Adding "constrained_optimization" to tensorflow/contrib.
+* Other:
+ * Add GCS Configuration Ops.
+ * Changing signature of `MakeIterator` to enable propagating error status.
+ * KL divergence for two Dirichlet distributions.
+ * More consistent GcsFileSystem behavior for certain reads past EOF.
+ * Update benchmark for tf.scan to match ranges across eager and graph modes.
+ * Fixed bug in the `tf.reduce_prod` gradient for complex dtypes.
+ * Add optional `args` argument to `Dataset.from_generator()`.
+ * Allow the use of '.' in variables (e.g. "hparams.parse('a.b=1.0')"), which would previously raise an error. This will correspond to an attribute name with an embedded '.' symbol (e.g. 'a.b'), which can only be accessed indirectly (e.g. through getattr and setattr). To set this up the user will first need to explicitly add the variable to the hparam object (e.g. "hparams.add_hparam(name='a.b', value=0.0)").
+ * Benchmark for tf.scan in graph and eager modes.
+ * Added complex128 support to FFT, FFT2D, FFT3D, IFFT, IFFT2D, and IFFT3D.
+ * Making ids unique in `nn.embedding_lookup_sparse`. This helps to reduce RPC calls for looking up the embeddings when there are repeated ids in the batch.
+ * Support indicator column in boosted trees.
+ * Prevent `tf.gradients()` from backpropagating through integer tensors.
+ * LinearOperator[1D,2D,3D]Circulant added to `tensorflow.linalg`.
+ * Conv3D, Conv3DBackpropInput, Conv3DBackpropFilter now support arbitrary.
+ * Added `tf.train.Checkpoint` for reading/writing object-based checkpoints.
+ * `Dataset.list_files()` now produces deterministic results when `shuffle=False` or a `seed` is passed.
+ * Added LinearOperatorKronecker, a dense-free implementation of the Kronecker Product.
+ * Allow LinearOperator to broadcast.
+ * SavedModelBuilder will now deduplicate asset names that point to files with the same basename and the same contents. Note that this may result in new asset files included in SavedModels in cases where assets with the same name but different contents were previously overwriting each other.
+
+
+## Thanks to our Contributors
+
+This release contains contributions from many people at Google, as well as:
+
+Abdullah Alrasheed, Achal Shah, Ad-530, ADiegoCAlonso, Aditya Yogi, Ag Ramesh, akindyakov, Andy Kernahan, Anya Petrova, Aurelien Geron, Ben, Ben Barsdell, Bhavani-Subramanian, braincodercn, Brett Koonce, Brian Nemsick, Brian Zier, Bryan Heden, candy.dc, cclauss, Clayne Robison, ctiijima, Dalmo Cirne, David Norman, David T.H. Kao, DosLin, ekelsen, Elson Rodriguez, Erik Smistad, Felix Abecassis, Fergal Cotter, fo40225, foo0x29a, Freedom" Koan-Sin Tan, FréDéRic Branchaud-Charron, gdh1995, Geoffrey Irving, Giuseppe, gracehoney, Guido Zuidhof, Guillaume Klein, Guozhong Zhuang, Haggai, Harald Husum, imsheridan, Ivan Zhang, Jan Zikes, Jayaram Bobba, Jesse Benson, Jesse Gumz, Jiajia Li, Jie, jinghuangintel, Jingwen, jjsjann123, Joe Yearsley, Joel Hestness, Joel Shor, josephyearsley, Junpeng Lao, Karol M. Langner, Kb Sriram, krantideep95, Krish Ravindranath, Letian Feng, Loo Rong Jie, Lukas Geiger, Maciej, Mahmoud Abuzaina, ManHyuk, Mark Ryan, mbhuiyan, Michal Turek, Mostafa Alaa, Myungsung Kwak, Nand Dalal, Nehal J Wani, Neil Tenenholtz, ngc92, Nicholas Nadeau, P.Eng., Avs, Niranjan Hasabnis, P-Hidringer, Paul Van Eck, Peng Yu, Qing Zhao, Qingying Chen, Quanlong, Rajendra Arora, Rholais Lii, rmanyari, Robin Richtsfeld, Russell Klopfer, Sagi, Sam Sendelbach, Sandeep N Gupta, Sandip Giri, Sarah Edkins, Scott Tseng, Sdalbsoo, Sergii Khomenko, Seungwoo Choi (Biggie), Seyed Majid Azimi, Shaoning Zeng, shengfuintel, Siu Kei, Muk, Smit Shilu, soonson, Stefan Schweter, Sukhwan Kim, Sunitha Kambhampati, Taehoon Lee, tamimaddari82, Tang, Wenyi, Ted Chang, u2takey, Utkarsh Upadhyay, Vadim Markovtsev, voegtlel, Wai Hon Law, wangsiyu, Wenhao Hu, wenhao.hu, William D. Irons, Yan Facai (颜发才), Yanbo Liang, Yihong Wang, Yilei (Dolee) Yang, Yong Tang, Yuan (Terry) Tang
+
# Release 1.8.0
## Major Features And Improvements
diff --git a/SECURITY.md b/SECURITY.md
index 0a4be37cbc..e2f6ff353a 100644
--- a/SECURITY.md
+++ b/SECURITY.md
@@ -242,12 +242,7 @@ v//Fw6ZeY+HmRDFdirjD7wXtIuER4vqCryIqR6Xe9X8oJXz9L/Jhslc=
-----END PGP PUBLIC KEY BLOCK-----
```
-### Known vulnerabilities
-
-| Type | Versions affected | Reported by | Additional Information |
-|--------------------|:-----------------:|-----------------------|-----------------------------|
-| TensorFlow Lite TOCO FlatBuffer Parsing Vulnerability | <= 1.7 | Blade Team of Tencent | [security advisory](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/docs_src/security/advisory/tfsa-2018-003.md) |
-| GIF File Parsing Null Pointer Dereference Error | <= 1.5 | Blade Team of Tencent | [security advisory](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/docs_src/security/advisory/tfsa-2018-002.md) |
-| BMP File Parser Out-of-bounds Read | <= 1.6 | Blade Team of Tencent | [security advisory](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/docs_src/security/advisory/tfsa-2018-001.md) |
-| Out Of Bounds Read | <=1.4 | Blade Team of Tencent | [issue report](https://github.com/tensorflow/tensorflow/issues/14959) |
+### Known Vulnerabilities
+For a list of known vulnerabilities and security advisories for TensorFlow,
+[click here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/security/index.md).
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index 9b07669a5d..6d134dbb80 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -539,14 +539,17 @@ exports_files(
)
gen_api_init_files(
- name = "python_api_gen",
+ name = "tensorflow_python_api_gen",
srcs = ["api_template.__init__.py"],
root_init_template = "api_template.__init__.py",
)
py_library(
name = "tensorflow_py",
- srcs = [":python_api_gen"],
+ srcs = [
+ ":tensorflow_python_api_gen",
+ "//tensorflow/python/estimator/api:estimator_python_api_gen",
+ ],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = ["//tensorflow/python"],
diff --git a/tensorflow/api_template.__init__.py b/tensorflow/api_template.__init__.py
index 9b0d7d48af..9662d7b478 100644
--- a/tensorflow/api_template.__init__.py
+++ b/tensorflow/api_template.__init__.py
@@ -22,7 +22,22 @@ from __future__ import print_function
from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import
# API IMPORTS PLACEHOLDER
-from tensorflow.python.util.lazy_loader import LazyLoader
+try:
+ import os # pylint: disable=g-import-not-at-top
+ # Add `estimator` attribute to allow access to estimator APIs via
+ # "tf.estimator..."
+ from tensorflow.python.estimator.api import estimator # pylint: disable=g-import-not-at-top
+
+ # Add `estimator` to the __path__ to allow "from tensorflow.estimator..."
+ # style imports.
+ from tensorflow.python.estimator import api as estimator_api # pylint: disable=g-import-not-at-top
+ __path__ += [os.path.dirname(estimator_api.__file__)]
+ del estimator_api
+ del os
+except (ImportError, AttributeError):
+ print('tf.estimator package not installed.')
+
+from tensorflow.python.util.lazy_loader import LazyLoader # pylint: disable=g-import-not-at-top
contrib = LazyLoader('contrib', globals(), 'tensorflow.contrib')
del LazyLoader
diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc
index b86b277ac3..cb0b093ad2 100644
--- a/tensorflow/c/c_api.cc
+++ b/tensorflow/c/c_api.cc
@@ -631,7 +631,22 @@ Status MessageToBuffer(const tensorflow::protobuf::Message& in,
"Failed to allocate memory to serialize message of type '",
in.GetTypeName(), "' and size ", proto_size);
}
- in.SerializeToArray(buf, proto_size);
+ // SerializeToArray takes size as an int.
+ // This next 'if' is a workaround till we update to depend on a version
+ // of protocol buffers that includes
+ // https://github.com/google/protobuf/pull/4739
+ if (proto_size > std::numeric_limits<int>::max()) {
+ return InvalidArgument("Cannot serialize protocol buffer of type ",
+ in.GetTypeName(), " as the serialized size (",
+ proto_size,
+ "bytes) would be larger than the limit (",
+ std::numeric_limits<int>::max(), " bytes)");
+ }
+ if (!in.SerializeToArray(buf, proto_size)) {
+ return InvalidArgument("Unable to serialize ", in.GetTypeName(),
+ " protocol buffer, perhaps the serialized size (",
+ proto_size, " bytes) is too large?");
+ }
out->data = buf;
out->length = proto_size;
out->data_deallocator = [](void* data, size_t length) {
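
The hunk above works around `SerializeToArray()` taking an `int` size by rejecting messages whose serialized size exceeds `std::numeric_limits<int>::max()` before serializing. Below is a minimal standalone sketch of the same guard pattern; it is not the TensorFlow code itself, and the helper name and use of `std::string` as the output buffer are illustrative only.

```c++
#include <limits>
#include <string>

#include "google/protobuf/message.h"

// Illustrative guard: protobuf's SerializeToArray() takes an `int` size, so
// check that ByteSizeLong() fits in an int before serializing, mirroring the
// check the hunk above adds to MessageToBuffer().
bool SerializeChecked(const google::protobuf::Message& in, std::string* out) {
  const size_t proto_size = in.ByteSizeLong();
  if (proto_size > static_cast<size_t>(std::numeric_limits<int>::max())) {
    return false;  // Too large to pass through the int-sized API.
  }
  out->resize(proto_size);
  return in.SerializeToArray(&(*out)[0], static_cast<int>(proto_size));
}
```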
diff --git a/tensorflow/cc/framework/cc_op_gen.cc b/tensorflow/cc/framework/cc_op_gen.cc
index d6a4f141b6..dfdef88945 100644
--- a/tensorflow/cc/framework/cc_op_gen.cc
+++ b/tensorflow/cc/framework/cc_op_gen.cc
@@ -273,6 +273,12 @@ string PrintAttrValue(const string& op, const AttrValue& attr_value) {
return "<Unknown AttrValue type>"; // Prevent missing return warning
}
+bool IsEmptyList(const AttrValue::ListValue& list) {
+ return list.s_size() == 0 && list.i_size() == 0 && list.f_size() == 0 &&
+ list.b_size() == 0 && list.type_size() == 0 &&
+ list.shape_size() == 0 && list.tensor_size() == 0;
+}
+
string ToCamelCase(const string& str) {
string result;
const char joiner = '_';
@@ -297,9 +303,9 @@ string ToCamelCase(const string& str) {
// indicate whether to treat the type as const when accepting the C++ type as an
// argument to a function.
std::pair<const char*, bool> AttrTypeName(StringPiece attr_type) {
- static const std::unordered_map<StringPiece, std::pair<const char*, bool>,
- StringPieceHasher>
- attr_type_map{
+ static const auto* attr_type_map =
+ new std::unordered_map<StringPiece, std::pair<const char*, bool>,
+ StringPieceHasher>{
{"string", {"StringPiece", false}},
{"list(string)", {"gtl::ArraySlice<string>", true}},
{"int", {"int64", false}},
@@ -317,14 +323,34 @@ std::pair<const char*, bool> AttrTypeName(StringPiece attr_type) {
{"func", {"NameAttrList", true}},
};
- auto entry = attr_type_map.find(attr_type);
- if (entry == attr_type_map.end()) {
+ auto entry = attr_type_map->find(attr_type);
+ if (entry == attr_type_map->end()) {
LOG(FATAL) << "Unsupported Attr type: " << attr_type;
return {"", false};
}
return entry->second;
}
+const char* ListElementTypeName(StringPiece attr_type) {
+ static const auto* attr_list_type_map =
+ new std::unordered_map<StringPiece, const char*, StringPieceHasher>{
+ {"list(string)", "string"},
+ {"list(int)", "int"},
+ {"list(float)", "float"},
+ {"list(bool)", "bool"},
+ {"list(type)", "DataType"},
+ {"list(shape)", "PartialTensorShape"},
+ {"list(tensor)", "TensorProto"},
+ };
+
+ auto entry = attr_list_type_map->find(attr_type);
+ if (entry == attr_list_type_map->end()) {
+ LOG(FATAL) << "Unsupported or non-list Attr type: " << attr_type;
+ return "";
+ }
+ return entry->second;
+}
+
bool IsCPPKeyword(StringPiece name) {
static const std::unordered_set<StringPiece, StringPieceHasher>
// Keywords obtained from http://en.cppreference.com/w/cpp/keyword
@@ -668,6 +694,7 @@ OpInfo::OpInfo(const OpDef& graph_op_def, const ApiDef& api_def,
string OpInfo::GetOpAttrStruct() const {
string struct_fields;
string setters;
+ string defaults_static_storage;
for (int i = 0; i < graph_op_def.attr_size(); ++i) {
const auto& attr(graph_op_def.attr(i));
@@ -705,11 +732,32 @@ string OpInfo::GetOpAttrStruct() const {
"_ = x;\n");
strings::StrAppend(&setters, " return ret;\n }\n\n");
- strings::StrAppend(
- &struct_fields, " ", attr_type_name, " ", api_def_attr.rename_to(),
- "_ = ",
- PrintAttrValue(graph_op_def.name(), api_def_attr.default_value()),
- ";\n");
+ string field_initiliazer;
+ auto& default_value = api_def_attr.default_value();
+ if (default_value.value_case() == AttrValue::kList &&
+ !IsEmptyList(default_value.list())) {
+ // Non-empty lists need static storage for their defaults. Define a
+ // function with static local variable that stores the array.
+ strings::StrAppend(&defaults_static_storage, " static ",
+ attr_type_name, " Default_", api_def_attr.rename_to(),
+ "() {\n");
+ strings::StrAppend(
+ &defaults_static_storage, " static const ",
+ ListElementTypeName(attr.type()), " kStorage[] = ",
+ PrintAttrValue(graph_op_def.name(), api_def_attr.default_value()),
+ ";\n");
+ strings::StrAppend(&defaults_static_storage, " return ",
+ attr_type_name, "(kStorage);\n }\n");
+ // Set the field_initializer to call the defined function.
+ strings::StrAppend(&field_initiliazer, "Default_",
+ api_def_attr.rename_to(), "()");
+ } else {
+ field_initiliazer =
+ PrintAttrValue(graph_op_def.name(), api_def_attr.default_value());
+ }
+ strings::StrAppend(&struct_fields, " ", attr_type_name, " ",
+ api_def_attr.rename_to(), "_ = ", field_initiliazer,
+ ";\n");
}
if (struct_fields.empty()) {
@@ -721,6 +769,9 @@ string OpInfo::GetOpAttrStruct() const {
string struct_decl = MakeComment(attrs_comment, " ");
strings::StrAppend(&struct_decl, " struct Attrs {\n");
strings::StrAppend(&struct_decl, setters, struct_fields);
+ if (!defaults_static_storage.empty()) {
+ strings::StrAppend(&struct_decl, " private:\n", defaults_static_storage);
+ }
strings::StrAppend(&struct_decl, " };\n");
return struct_decl;
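
For attrs whose default is a non-empty list, the generated `Attrs` struct can no longer use a plain member initializer, because an `ArraySlice` default needs backing storage with static lifetime. The following is a hedged sketch of the shape of the generated code for a hypothetical `list(int)` attr defaulting to `[1, 1]`; the names are made up, and `absl::Span` stands in for `gtl::ArraySlice` so the snippet is self-contained.

```c++
#include "absl/types/span.h"

struct Attrs {
  // Fluent setter, as emitted for every attr.
  Attrs Strides(absl::Span<const int> x) {
    Attrs ret = *this;
    ret.strides_ = x;
    return ret;
  }

  // The field default calls a private helper instead of embedding the list.
  absl::Span<const int> strides_ = Default_strides();

 private:
  // A function-local static array gives the default slice storage that
  // outlives any Attrs instance.
  static absl::Span<const int> Default_strides() {
    static const int kStorage[] = {1, 1};
    return absl::Span<const int>(kStorage);
  }
};
```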
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index ab8cd8f4bc..51a79e2cd9 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -181,6 +181,7 @@ cc_library(
"//tensorflow/core/kernels:no_op",
"//tensorflow/core/kernels:resource_variable_ops",
"//tensorflow/core/kernels:sendrecv_ops",
+ "//tensorflow/core/kernels:shape_ops",
"//tensorflow/core/kernels:variable_ops",
],
)
@@ -342,6 +343,7 @@ cc_library(
"//tensorflow/compiler/jit/graphcycles",
"//tensorflow/core:framework",
"//tensorflow/core:graph",
+ "//tensorflow/core:protos_all_cc",
"//tensorflow/core/kernels:bounds_check",
],
)
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
index 74468266b9..8c3882116d 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
@@ -44,12 +44,6 @@ namespace tensorflow {
namespace {
-// Returns true if, when executed in TensorFlow, `node` is guaranteed to forward
-// a ref tensor input to its output.
-static bool AlwaysForwardsRefInput(const Node& node) {
- return node.IsIdentity();
-}
-
bool HasXLAKernel(const Node& node, const DeviceType& jit_device_type) {
// There is a SymbolicGradient kernel on the XLA_JIT device, but the gradient
// is really a kind of function call and will be handled by
@@ -68,20 +62,8 @@ bool HasXLAKernel(const Node& node, const DeviceType& jit_device_type) {
// XLA does not offer guaranteed aliasing between the input and output of the
// XLA cluster so it can't implement the forward-tensor-ref semantic. Leave
// such nodes out of XLA clusters.
- if (AlwaysForwardsRefInput(node)) {
- for (const Edge* incoming_edge : node.in_edges()) {
- if (incoming_edge->IsControlEdge()) {
- continue;
- }
-
- Node* incoming_node = incoming_edge->src();
- if (IsRefType(incoming_node->output_type(incoming_edge->src_output()))) {
- VLOG(2) << "Not clustering " << node.def().ShortDebugString()
- << " because of ref input " << incoming_node->name() << " "
- << incoming_node->type_string();
- return false;
- }
- }
+ if (HasForwardedRefInput(node)) {
+ return false;
}
return FindKernelDef(jit_device_type, node.def(), nullptr, nullptr).ok();
diff --git a/tensorflow/compiler/jit/xla_cluster_util.cc b/tensorflow/compiler/jit/xla_cluster_util.cc
index 70bd10336b..05b7821b88 100644
--- a/tensorflow/compiler/jit/xla_cluster_util.cc
+++ b/tensorflow/compiler/jit/xla_cluster_util.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include <unordered_map>
+#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/util/device_name_utils.h"
@@ -66,6 +67,9 @@ string DescribeCycle(const GraphCycles* cycles, const Graph& graph, int src,
}
return description;
}
+
+bool AlwaysForwardsRefInput(const Node& node) { return node.IsIdentity(); }
+
} // namespace
Status DeviceToDeviceType(const string& device, DeviceType* device_type) {
@@ -77,6 +81,24 @@ Status DeviceToDeviceType(const string& device, DeviceType* device_type) {
return Status::OK();
}
+bool HasForwardedRefInput(const Node& node) {
+ if (AlwaysForwardsRefInput(node)) {
+ for (const Edge* incoming_edge : node.in_edges()) {
+ if (incoming_edge->IsControlEdge()) {
+ continue;
+ }
+
+ Node* incoming_node = incoming_edge->src();
+ if (IsRefType(incoming_node->output_type(incoming_edge->src_output()))) {
+ VLOG(2) << "Node " << node.def().ShortDebugString() << " has ref input "
+ << incoming_node->name() << " " << incoming_node->type_string();
+ return true;
+ }
+ }
+ }
+ return false;
+}
+
Status CreateCycleDetectionGraph(const Graph* graph, GraphCycles* cycles) {
for (int i = 0; i < graph->num_node_ids(); ++i) {
// We rely on the node IDs in the cycle detection graph being consecutive
diff --git a/tensorflow/compiler/jit/xla_cluster_util.h b/tensorflow/compiler/jit/xla_cluster_util.h
index 5b673bdc27..bcce082aaf 100644
--- a/tensorflow/compiler/jit/xla_cluster_util.h
+++ b/tensorflow/compiler/jit/xla_cluster_util.h
@@ -36,6 +36,9 @@ using OrderedNodeSet = std::set<Node*, NodeComparatorID>;
// Returns the DeviceType corresponding to 'device'.
Status DeviceToDeviceType(const string& device, DeviceType* device_type);
+// Returns true if `node` has a ref tensor input that it forwards to its output.
+bool HasForwardedRefInput(const Node& node);
+
// Creates a graph representation to enable cycle detection when clustering.
// This representation handles loops in graph by disconnecting each loop from
// the enclosing graph.
diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h
index 0c49286acd..11e45d2823 100644
--- a/tensorflow/compiler/jit/xla_device_ops.h
+++ b/tensorflow/compiler/jit/xla_device_ops.h
@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/core/kernels/no_op.h"
#include "tensorflow/core/kernels/resource_variable_ops.h"
#include "tensorflow/core/kernels/sendrecv_ops.h"
+#include "tensorflow/core/kernels/shape_ops.h"
#include "tensorflow/core/kernels/variable_ops.h"
namespace tensorflow {
@@ -87,6 +88,46 @@ class XlaAssignVariableOp : public AsyncOpKernel {
REGISTER_KERNEL_BUILDER( \
Name("ReadVariableOp").Device(DEVICE).HostMemory("resource"), \
ReadVariableOp); \
+ REGISTER_KERNEL_BUILDER(Name("Shape") \
+ .Device(DEVICE) \
+ .HostMemory("output") \
+ .TypeConstraint<int32>("out_type") \
+ .TypeConstraint("T", TYPES), \
+ ShapeOp<int32>); \
+ REGISTER_KERNEL_BUILDER(Name("Shape") \
+ .Device(DEVICE) \
+ .HostMemory("output") \
+ .TypeConstraint<int64>("out_type") \
+ .TypeConstraint("T", TYPES), \
+ ShapeOp<int64>); \
+ REGISTER_KERNEL_BUILDER(Name("ShapeN") \
+ .Device(DEVICE) \
+ .HostMemory("output") \
+ .TypeConstraint<int32>("out_type") \
+ .TypeConstraint("T", TYPES), \
+ ShapeNOp<int32>); \
+ REGISTER_KERNEL_BUILDER(Name("ShapeN") \
+ .Device(DEVICE) \
+ .HostMemory("output") \
+ .TypeConstraint<int64>("out_type") \
+ .TypeConstraint("T", TYPES), \
+ ShapeNOp<int64>); \
+ REGISTER_KERNEL_BUILDER(Name("Size") \
+ .Device(DEVICE) \
+ .HostMemory("output") \
+ .TypeConstraint<int32>("out_type") \
+ .TypeConstraint("T", TYPES), \
+ SizeOp<int32>); \
+ REGISTER_KERNEL_BUILDER(Name("Size") \
+ .Device(DEVICE) \
+ .HostMemory("output") \
+ .TypeConstraint<int64>("out_type") \
+ .TypeConstraint("T", TYPES), \
+ SizeOp<int64>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Rank").Device(DEVICE).HostMemory("output").TypeConstraint("T", \
+ TYPES), \
+ RankOp); \
REGISTER_KERNEL_BUILDER( \
Name("AssignVariableOp").Device(DEVICE).HostMemory("resource"), \
XlaAssignVariableOp); \
diff --git a/tensorflow/compiler/jit/xla_fusion_optimizer.cc b/tensorflow/compiler/jit/xla_fusion_optimizer.cc
index 96016521ea..74257b09a8 100644
--- a/tensorflow/compiler/jit/xla_fusion_optimizer.cc
+++ b/tensorflow/compiler/jit/xla_fusion_optimizer.cc
@@ -178,6 +178,13 @@ Status XlaFusionOptimizer::Optimize(grappler::Cluster* cluster,
continue;
}
+ // XLA does not offer guaranteed aliasing between the input and output of
+ // the XLA cluster so it can't implement the forward-tensor-ref semantic.
+ // Leave such nodes out of XLA clusters.
+ if (HasForwardedRefInput(*node)) {
+ continue;
+ }
+
compilation_candidates.insert(node);
}
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index b51c11bf6e..e6c92f9720 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -545,7 +545,9 @@ tf_xla_py_test(
],
deps = [
":xla_test",
+ "//tensorflow/python:array_ops",
"//tensorflow/python:framework",
+ "//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
"//tensorflow/python:random_ops",
],
diff --git a/tensorflow/compiler/tests/eager_test.py b/tensorflow/compiler/tests/eager_test.py
index 4dff5f0f40..fceb61ef87 100644
--- a/tensorflow/compiler/tests/eager_test.py
+++ b/tensorflow/compiler/tests/eager_test.py
@@ -160,6 +160,77 @@ class EagerTest(XLATestCase):
for _ in range(100):
values.append(var.value())
+  # The shape, shape_n, size, and rank ops are tested here because their
+  # execution kernels (as opposed to the compilation-only tf2xla kernels)
+  # are distinct from the tf2xla kernels.
+
+ def testShape(self):
+ def const(value):
+ return array_ops.shape(
+ constant_op.constant(value)).numpy()
+
+ def ones(value):
+ return array_ops.shape(
+ array_ops.ones(value)).numpy()
+
+ with self.test_scope():
+ # Shapes of directly constructed tensors
+ self.assertAllEqual([], const(3))
+ self.assertAllEqual([3], const([1.0, 2.0, 3.0]))
+ self.assertAllEqual([2, 2], const([[1.0, 2.0], [3.0, 4.0]]))
+ self.assertAllEqual([2, 1, 2], const([[[1.0, 2.0]], [[3.0, 4.0]]]))
+
+      # Shapes of tensors created by an op running on the device
+ # We make this distinction because directly constructed tensors
+ # are treated differently in a few places that can influence shape:
+ # - they always have on_host_tensor
+ # - they and their shapes can be cached
+ # - they end up on device via a copy, instead of as program output
+ self.assertAllEqual([], ones([]))
+ self.assertAllEqual([3], ones([3]))
+ self.assertAllEqual([2, 2], ones([2, 2]))
+ self.assertAllEqual([2, 1, 2], ones([2, 1, 2]))
+
+ def testShapeN(self):
+ with self.test_scope():
+ # Shapes of directly constructed tensors
+ shapes = array_ops.shape_n([
+ constant_op.constant(1.0),
+ constant_op.constant([1.0, 2.0, 3.0]),
+ constant_op.constant([[1.0, 2.0], [3.0, 4.0]])])
+ self.assertAllEqual(
+ [[], [3], [2, 2]],
+ [x.numpy().tolist() for x in shapes])
+
+      # Shapes of tensors created by an op running on the device
+ shapes = array_ops.shape_n([
+ array_ops.ones([]),
+ array_ops.ones([3]),
+ array_ops.ones([2, 2])])
+ self.assertAllEqual(
+ [[], [3], [2, 2]],
+ [x.numpy().tolist() for x in shapes])
+
+ def testSize(self):
+ with self.test_scope():
+ self.assertEqual(
+ 1, array_ops.size(constant_op.constant(1.0)).numpy())
+ self.assertEqual(
+ 3, array_ops.size(constant_op.constant([1.0, 2.0, 3.0])).numpy())
+ self.assertEqual(
+ 4, array_ops.size(
+ constant_op.constant([[1.0, 2.0], [3.0, 4.0]])).numpy())
+
+ def testRank(self):
+ with self.test_scope():
+ self.assertEqual(
+ 0, array_ops.rank(constant_op.constant(1.0)).numpy())
+ self.assertEqual(
+ 1, array_ops.rank(constant_op.constant([1.0, 2.0, 3.0])).numpy())
+ self.assertEqual(
+ 2, array_ops.rank(
+ constant_op.constant([[1.0, 2.0], [3.0, 4.0]])).numpy())
+
class EagerFunctionTest(XLATestCase):
diff --git a/tensorflow/compiler/tests/random_ops_test.py b/tensorflow/compiler/tests/random_ops_test.py
index 70be22936a..f13dff9620 100644
--- a/tensorflow/compiler/tests/random_ops_test.py
+++ b/tensorflow/compiler/tests/random_ops_test.py
@@ -22,6 +22,8 @@ import numpy as np
from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.platform import googletest
@@ -47,18 +49,18 @@ class RandomOpsTest(XLATestCase):
# We use exact equality here. If the random-number generator is producing
# deterministic output, all three outputs will be bitwise identical.
self.assertTrue((not np.array_equal(y, z)) or
- (not np.array_equal(z, w)) or
- (not np.array_equal(y, w)))
+ (not np.array_equal(z, w)) or (not np.array_equal(y, w)))
def testRandomUniformIsNotConstant(self):
+
def rng(dtype):
- return random_ops.random_uniform(shape=[2], dtype=dtype,
- maxval=1000000)
+ return random_ops.random_uniform(shape=[2], dtype=dtype, maxval=1000000)
for dtype in self._random_types():
self._testRngIsNotConstant(rng, dtype)
def testRandomNormalIsNotConstant(self):
+
def rng(dtype):
return random_ops.random_normal(shape=[2], dtype=dtype)
@@ -70,13 +72,14 @@ class RandomOpsTest(XLATestCase):
for dtype in self._random_types():
with self.test_session() as sess:
with self.test_scope():
- x = random_ops.random_uniform(shape=[1000], dtype=dtype, minval=-2,
- maxval=33)
+ x = random_ops.random_uniform(
+ shape=[1000], dtype=dtype, minval=-2, maxval=33)
y = sess.run(x)
self.assertTrue((y >= -2).sum() == 1000)
self.assertTrue((y < 33).sum() == 1000)
def testTruncatedNormalIsNotConstant(self):
+
def rng(dtype):
return random_ops.truncated_normal(shape=[2], dtype=dtype)
@@ -94,6 +97,29 @@ class RandomOpsTest(XLATestCase):
self.assertTrue((y >= -2).sum() == count)
self.assertTrue((y <= 2).sum() == count)
+ def testShuffle1d(self):
+ with self.test_session() as sess:
+ with self.test_scope():
+ x = math_ops.range(20)
+ shuffle = random_ops.random_shuffle(x)
+ result = sess.run(shuffle)
+ expected = range(20)
+      # Compare as sets so the test is insensitive to the shuffle order while
+      # still checking that all the values are present.
+ self.assertAllEqual(set(result), set(expected))
+
+ def testShuffle2d(self):
+ with self.test_session() as sess:
+ with self.test_scope():
+ x = array_ops.diag(math_ops.range(20))
+ shuffle = random_ops.random_shuffle(x)
+ result = sess.run(shuffle)
+ expected = np.diag(range(20)).flatten()
+      # Compare as sets so the test is insensitive to the shuffle order while
+      # still checking that all the values are present.
+ self.assertAllEqual(len(result.flatten()), len(expected))
+ self.assertAllEqual(set(result.flatten()), set(expected))
+
if __name__ == '__main__':
googletest.main()
diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
index 42585ad4d8..1438f6b48c 100644
--- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
+++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
@@ -1438,7 +1438,13 @@ Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library,
// connected to all source nodes in the graph. Many graphs violate this
// invariant.
std::vector<ControlFlowInfo> cf_info;
- TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph, &cf_info));
+ std::vector<string> unreachable_nodes;
+ TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph, &cf_info, &unreachable_nodes));
+ if (!unreachable_nodes.empty()) {
+ return errors::InvalidArgument(
+ "The following nodes are unreachable from the source in the graph: ",
+ tensorflow::str_util::Join(unreachable_nodes, ", "));
+ }
// Builds Frames, indexed by name.
std::unordered_map<string, Frame> frames;
diff --git a/tensorflow/compiler/tf2xla/kernels/if_op.cc b/tensorflow/compiler/tf2xla/kernels/if_op.cc
index 8b9b026643..d48c6eea75 100644
--- a/tensorflow/compiler/tf2xla/kernels/if_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/if_op.cc
@@ -48,11 +48,11 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) {
VLOG(1) << "Building If: " << input_types_.size() << " inputs";
- std::vector<xla::XlaOp> inputs(input_types_.size());
std::vector<XlaCompiler::Argument> arguments(input_types_.size());
for (int i = 0; i < input_types_.size(); ++i) {
XlaCompiler::Argument& arg = arguments[i];
DataType type = ctx->input_type(i + 1);
+
if (type == DT_RESOURCE) {
XlaResource* resource;
OP_REQUIRES_OK(ctx, ctx->GetResourceInput(i + 1, &resource));
@@ -60,7 +60,6 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) {
arg.initialized = resource->initialized();
arg.kind = XlaCompiler::Argument::kResource;
arg.resource_kind = resource->kind();
- OP_REQUIRES_OK(ctx, resource->Pack(&inputs[i], b));
arg.type = resource->type();
arg.shape = resource->shape();
@@ -79,7 +78,6 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) {
arg.kind = XlaCompiler::Argument::kParameter;
arg.type = input_types_[i];
arg.shape = ctx->InputShape(i + 1);
- inputs[i] = ctx->Input(i + 1);
VLOG(2) << "Arg type: " << DataTypeString(arg.type)
<< " shape: " << arg.shape.DebugString();
}
@@ -100,6 +98,7 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) {
OP_REQUIRES_OK(ctx, compiler->CompileFunction(options, else_branch_,
arguments, &else_result));
+ bool has_tensor_array_gradients = false;
for (XlaCompiler::CompilationResult* result : {&then_result, &else_result}) {
for (const XlaCompiler::ResourceUpdate& update : result->resource_updates) {
XlaResource* resource;
@@ -121,9 +120,21 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) {
for (const auto& gradient : resource->tensor_array_gradients()) {
arg.tensor_array_gradients.insert(gradient.first);
}
+ if (!resource->tensor_array_gradients().empty())
+ has_tensor_array_gradients = true;
}
}
+ // Recompile the functions to update the argument shapes for tensor arrays.
+ if (has_tensor_array_gradients) {
+ then_result = {};
+ OP_REQUIRES_OK(ctx, compiler->CompileFunction(options, then_branch_,
+ arguments, &then_result));
+ else_result = {};
+ OP_REQUIRES_OK(ctx, compiler->CompileFunction(options, else_branch_,
+ arguments, &else_result));
+ }
+
// Check that both branches have identical input shapes.
OP_REQUIRES(ctx, then_result.xla_input_shapes.size() == 1,
errors::FailedPrecondition("Expected one input shape"));
@@ -175,6 +186,19 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) {
"Mismatch in resource of then and else branch for resource ", i));
}
+ int num_inputs = then_result.input_mapping.size();
+ std::vector<xla::XlaOp> inputs(num_inputs);
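+  // Compiled parameter i corresponds to op input input_mapping[i] + 1; the
+  // offset skips input 0, which is the branch predicate.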
+ for (int i = 0; i < num_inputs; ++i) {
+ int input_num = then_result.input_mapping[i] + 1;
+ if (ctx->input_type(input_num) == DT_RESOURCE) {
+ XlaResource* resource;
+ OP_REQUIRES_OK(ctx, ctx->GetResourceInput(input_num, &resource));
+ OP_REQUIRES_OK(ctx, resource->Pack(&inputs[i], b));
+ } else {
+ inputs[i] = ctx->Input(i + 1);
+ }
+ }
+
xla::XlaOp outputs =
b->Conditional(ctx->Input(0), b->Tuple(inputs), *then_result.computation,
b->Tuple(inputs), *else_result.computation);
diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops.cc b/tensorflow/compiler/tf2xla/kernels/random_ops.cc
index 39149d56ad..105be38fe2 100644
--- a/tensorflow/compiler/tf2xla/kernels/random_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/random_ops.cc
@@ -17,6 +17,8 @@ limitations under the License.
// TODO(misard,phawkins): handle random number generator seeds/states correctly.
// TODO(misard,phawkins): add tests.
+#include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h"
+#include "tensorflow/compiler/tf2xla/lib/util.h"
#include "tensorflow/compiler/tf2xla/lib/while_loop.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
@@ -56,6 +58,78 @@ class RandomUniformOp : public XlaOpKernel {
REGISTER_XLA_OP(Name("RandomUniform").CompileTimeConstInput("shape"),
RandomUniformOp);
+class RandomShuffleOp : public XlaOpKernel {
+ public:
+ explicit RandomShuffleOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ auto builder = ctx->builder();
+ xla::XlaOp input = ctx->Input(0);
+ TensorShape input_shape = ctx->InputShape(0);
+ const int64 n = input_shape.dim_size(0);
+ int64 num_elements = 1;
+ for (tensorflow::TensorShapeDim dimension : input_shape) {
+ num_elements *= dimension.size;
+ }
+ if (num_elements <= 1 || n <= 1) {
+ // No shuffling is required, so copy input directly to output
+ ctx->SetOutput(0, input);
+ } else {
+ // Generate the random swaps for the indices.
+ auto swaps_shape = xla::ShapeUtil::MakeShape(xla::S32, {n});
+ auto swaps =
+ builder->RngUniform(builder->ConstantR0<int32>(0),
+ builder->ConstantR0<int32>(n), swaps_shape);
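+      // Each swaps[i] is drawn uniformly from [0, n); the loop below swaps
+      // position i with position swaps[i].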
+
+ // Generate range(n) as the initial value for the indices to be swapped.
+ xla::XlaOp indices;
+ TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32, n, &indices));
+
+ // Swap the indices at i and swaps[i].
+ auto swap_body_fn = [&](xla::XlaOp i,
+ gtl::ArraySlice<xla::XlaOp> loop_vars,
+ xla::XlaBuilder* builder)
+ -> xla::StatusOr<std::vector<xla::XlaOp>> {
+ auto swaps = loop_vars[0];
+ auto indices = loop_vars[1];
+ i = builder->Reshape(i, {1});
+ // temp = indices[i]
+ auto temp = builder->DynamicSlice(indices, i, {1});
+ // swap_index = swaps[i]
+ auto swap_index = builder->DynamicSlice(swaps, i, {1});
+ // swap_value = indices[swaps[i]]
+ auto swap_value = builder->DynamicSlice(indices, swap_index, {1});
+ // indices[i] = indices[swaps[i]]
+ indices = builder->DynamicUpdateSlice(indices, swap_value, i);
+ // indices[swaps[i]] = temp
+ indices = builder->DynamicUpdateSlice(indices, temp, swap_index);
+ return std::vector<xla::XlaOp>{swaps, indices};
+ };
+ // for i in range(n):
+ auto swap_loop_result =
+ XlaForEachIndex(n, xla::S32, swap_body_fn, {swaps, indices},
+ "indices_swap_loop", builder)
+ .ValueOrDie();
+ auto swapped_indices = swap_loop_result[1];
+
+ // Gather the data using the swapped indices as the shuffled order.
+ auto indices_tensor_shape = TensorShape({n});
+ DataType type = ctx->expected_output_dtype(0);
+ xla::XlaOp gather;
+ OP_REQUIRES_OK(ctx, XlaGather(input, input_shape, swapped_indices,
+ indices_tensor_shape,
+ /*axis=*/0, /*indices_are_nd=*/false, type,
+ DT_INT32, builder, &gather));
+ ctx->SetOutput(0, gather);
+ }
+ }
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(RandomShuffleOp);
+};
+
+REGISTER_XLA_OP(Name("RandomShuffle"), RandomShuffleOp);
+
class RandomUniformIntOp : public XlaOpKernel {
public:
explicit RandomUniformIntOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
diff --git a/tensorflow/compiler/tf2xla/kernels/shape_op.cc b/tensorflow/compiler/tf2xla/kernels/shape_op.cc
index 05354bca5b..d59720bef7 100644
--- a/tensorflow/compiler/tf2xla/kernels/shape_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/shape_op.cc
@@ -43,7 +43,7 @@ class ShapeOp : public XlaOpKernel {
DataType out_dtype_;
};
-REGISTER_XLA_OP(Name("Shape"), ShapeOp);
+REGISTER_XLA_OP(Name("Shape").CompilationOnly(), ShapeOp);
class ShapeNOp : public XlaOpKernel {
public:
@@ -65,7 +65,7 @@ class ShapeNOp : public XlaOpKernel {
private:
DataType out_dtype_;
};
-REGISTER_XLA_OP(Name("ShapeN"), ShapeNOp);
+REGISTER_XLA_OP(Name("ShapeN").CompilationOnly(), ShapeNOp);
class RankOp : public XlaOpKernel {
public:
@@ -81,7 +81,7 @@ class RankOp : public XlaOpKernel {
}
};
-REGISTER_XLA_OP(Name("Rank"), RankOp);
+REGISTER_XLA_OP(Name("Rank").CompilationOnly(), RankOp);
class SizeOp : public XlaOpKernel {
public:
@@ -100,7 +100,7 @@ class SizeOp : public XlaOpKernel {
}
};
-REGISTER_XLA_OP(Name("Size"), SizeOp);
+REGISTER_XLA_OP(Name("Size").CompilationOnly(), SizeOp);
class ExpandDimsOp : public XlaOpKernel {
public:
@@ -189,10 +189,9 @@ class SqueezeOp : public XlaOpKernel {
if (!wrapped_squeeze_dims.empty()) {
if (wrapped_squeeze_dims.count(i) > 0) {
OP_REQUIRES(ctx, existing_dim == 1,
- errors::InvalidArgument("Tried to explicitly squeeze "
- "dimension ",
- i, " but dimension was not 1: ",
- existing_dim));
+ errors::InvalidArgument(
+ "Tried to explicitly squeeze dimension ", i,
+ " but dimension was not 1: ", existing_dim));
} else {
// This dimension is not being squeezed.
new_shape.push_back(existing_dim);
diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.cc b/tensorflow/compiler/tf2xla/lib/cholesky.cc
index 3f1384bc86..20925118bf 100644
--- a/tensorflow/compiler/tf2xla/lib/cholesky.cc
+++ b/tensorflow/compiler/tf2xla/lib/cholesky.cc
@@ -110,7 +110,6 @@ xla::StatusOr<xla::XlaOp> CholeskyUnblocked(xla::XlaBuilder* builder,
FloatLiteral(body_builder, a_shape.element_type(), 0.5));
// a[..., i+1:, i]
- auto ip1 = body_builder->Add(i, body_builder->ConstantR0<int32>(1));
// select the whole i-th column, then mask out all rows above i+1
TF_ASSIGN_OR_RETURN(
auto a_0i, DynamicSliceInMinorDims(body_builder, body_a, {i}, {1}));
diff --git a/tensorflow/compiler/tf2xla/literal_util.cc b/tensorflow/compiler/tf2xla/literal_util.cc
index 43e1c1e9fe..db56b12837 100644
--- a/tensorflow/compiler/tf2xla/literal_util.cc
+++ b/tensorflow/compiler/tf2xla/literal_util.cc
@@ -40,6 +40,37 @@ Status HostTensorToLiteral(const Tensor& host_tensor, xla::Literal* literal) {
return Status::OK();
}
+Status HostTensorToBorrowingLiteral(const Tensor& host_tensor,
+ xla::BorrowingLiteral* literal) {
+ xla::Shape xla_shape;
+ TF_RETURN_IF_ERROR(TensorShapeToXLAShape(host_tensor.dtype(),
+ host_tensor.shape(), &xla_shape));
+ *literal = xla::BorrowingLiteral(
+ static_cast<const char*>(DMAHelper::base(&host_tensor)), xla_shape);
+ return Status::OK();
+}
+
+Status HostTensorsToBorrowingLiteralTuple(
+ tensorflow::gtl::ArraySlice<Tensor> host_tensors,
+ xla::BorrowingLiteral* literal) {
+ std::vector<const char*> buf_ptrs;
+ buf_ptrs.reserve(host_tensors.size());
+ std::vector<xla::Shape> tensor_shapes(host_tensors.size());
+
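+  // Record a borrowed buffer pointer and an XLA shape for each tensor; the
+  // resulting literal aliases the tensors' memory rather than copying it.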
+ for (int i = 0; i < host_tensors.size(); i++) {
+    // Validate runtime shapes and fail if they don't match the contract.
+ const Tensor* tensor = &host_tensors[i];
+ buf_ptrs.emplace_back(static_cast<const char*>(DMAHelper::base(tensor)));
+ TF_RETURN_IF_ERROR(TensorShapeToXLAShape(tensor->dtype(), tensor->shape(),
+ &tensor_shapes[i]));
+ }
+
+ *literal = xla::BorrowingLiteral(
+ buf_ptrs, xla::ShapeUtil::MakeTupleShape(tensor_shapes));
+
+ return Status::OK();
+}
+
Status CopyLiteralToHostTensor(const xla::LiteralSlice& literal,
Tensor* host_tensor) {
TF_RET_CHECK(xla::ShapeUtil::IsArray(literal.shape()) &&
diff --git a/tensorflow/compiler/tf2xla/literal_util.h b/tensorflow/compiler/tf2xla/literal_util.h
index 220bec1553..74685025c1 100644
--- a/tensorflow/compiler/tf2xla/literal_util.h
+++ b/tensorflow/compiler/tf2xla/literal_util.h
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
namespace tensorflow {
@@ -29,6 +30,17 @@ namespace tensorflow {
// unsupported type.
Status HostTensorToLiteral(const Tensor& host_tensor, xla::Literal* literal);
+// Returns a BorrowingLiteral that utilizes the same underlying buffer owned by
+// 'host_tensor'.
+Status HostTensorToBorrowingLiteral(const Tensor& host_tensor,
+ xla::BorrowingLiteral* literal);
+
+// Returns a BorrowingLiteral tuple that utilizes the same underlying buffers
+// owned by 'host_tensors'.
+Status HostTensorsToBorrowingLiteralTuple(
+ tensorflow::gtl::ArraySlice<Tensor> host_tensors,
+ xla::BorrowingLiteral* literal);
+
// Copies 'literal' to freshly allocated 'host_tensor', which is allocated of
// type <target_type>.
// Fails if the literal's primitive type !=
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc
index a8bd199675..9c8e56a17e 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler.cc
@@ -652,6 +652,7 @@ Status XlaCompiler::CompileSingleOp(
.Finalize(graph.get(), &node);
TF_RETURN_IF_ERROR(status);
}
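+  // Connect the new node to the graph's source and sink: CompileGraph rejects
+  // graphs that contain nodes unreachable from the source.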
+ FixupSourceAndSinkEdges(graph.get());
return CompileGraph(options, name, std::move(graph), args, result);
}
diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
index 5fbf4b952c..613230452b 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
@@ -34,6 +34,7 @@ limitations under the License.
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -1049,5 +1050,42 @@ TEST_F(XlaCompilerTest, NodeWithInvalidDataType) {
<< status.error_message();
}
+TEST_F(XlaCompilerTest, SingleOpWithoutInputs) {
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ NodeDef no_op;
+ no_op.set_name("NoOp");
+ no_op.set_op("NoOp");
+ Status status;
+ graph->AddNode(no_op, &status);
+ TF_ASSERT_OK(status);
+
+ std::vector<XlaCompiler::Argument> args;
+ XlaCompiler compiler(DefaultOptions());
+ // No control edge linking NoOp with source/sink.
+ {
+ std::unique_ptr<Graph> graph_copy(new Graph(OpRegistry::Global()));
+ CopyGraph(*graph, graph_copy.get());
+ XlaCompiler::CompilationResult result;
+ status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "NoOp",
+ std::move(graph_copy), args, &result);
+ ASSERT_FALSE(status.ok());
+ EXPECT_TRUE(str_util::StrContains(status.error_message(),
+ "The following nodes are unreachable "
+ "from the source in the graph: NoOp"))
+ << status.error_message();
+ }
+
+ // Fix control edges for NoOp.
+ {
+ std::unique_ptr<Graph> graph_copy(new Graph(OpRegistry::Global()));
+ CopyGraph(*graph, graph_copy.get());
+ EXPECT_TRUE(FixupSourceAndSinkEdges(graph_copy.get()));
+ XlaCompiler::CompilationResult result;
+ TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "NoOp",
+ std::move(graph_copy), args, &result));
+ EXPECT_EQ(0, result.resource_updates.size());
+ }
+}
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc
index f1594193af..a1da176fe3 100644
--- a/tensorflow/compiler/tf2xla/xla_helpers.cc
+++ b/tensorflow/compiler/tf2xla/xla_helpers.cc
@@ -19,11 +19,13 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/lib/util.h"
#include "tensorflow/compiler/tf2xla/literal_util.h"
+#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
@@ -210,8 +212,9 @@ Status XlaHelpers::Iota(xla::XlaBuilder* builder, DataType dtype, int64 size,
return errors::InvalidArgument("Invalid argument type ",
DataTypeString(dtype));
}
- xla::Literal linspace_literal;
- TF_RETURN_IF_ERROR(HostTensorToLiteral(linspace, &linspace_literal));
+ xla::BorrowingLiteral linspace_literal;
+ TF_RETURN_IF_ERROR(HostTensorToBorrowingLiteral(linspace, &linspace_literal));
+
*iota = builder->ConstantLiteral(linspace_literal);
return Status::OK();
}
@@ -245,8 +248,8 @@ Status XlaHelpers::OneHot(xla::XlaBuilder* builder, int64 depth, int axis,
return errors::InvalidArgument("Invalid argument type ",
DataTypeString(index_type));
}
- xla::Literal linspace_literal;
- TF_RETURN_IF_ERROR(HostTensorToLiteral(linspace, &linspace_literal));
+ xla::BorrowingLiteral linspace_literal;
+ TF_RETURN_IF_ERROR(HostTensorToBorrowingLiteral(linspace, &linspace_literal));
// Broadcast the linspace constant across the indices along the new axis,
// and test equality at each position.
diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc
index 61afc311a7..6b29589700 100644
--- a/tensorflow/compiler/xla/literal_util.cc
+++ b/tensorflow/compiler/xla/literal_util.cc
@@ -2341,28 +2341,28 @@ LiteralSlice::LiteralSlice(const LiteralBase& literal,
: LiteralBase(), root_piece_(&literal.piece(view_root)) {}
BorrowingLiteral::BorrowingLiteral(const char* src_buf_ptr, const Shape& shape)
- : LiteralBase(), shape_(shape) {
- CHECK(ShapeUtil::IsArray(shape_));
+ : LiteralBase(), shape_(MakeUnique<Shape>(shape)) {
+ CHECK(ShapeUtil::IsArray(*shape_));
CHECK_NE(src_buf_ptr, nullptr);
- CHECK(LayoutUtil::HasLayout(shape_));
+ CHECK(LayoutUtil::HasLayout(*shape_));
root_piece_ = Piece();
root_piece_.set_buffer(const_cast<char*>(src_buf_ptr));
- root_piece_.set_subshape(&shape_);
+ root_piece_.set_subshape(shape_.get());
}
BorrowingLiteral::BorrowingLiteral(
tensorflow::gtl::ArraySlice<const char*> src_buf_ptrs, const Shape& shape)
- : LiteralBase(), shape_(shape) {
- CHECK(ShapeUtil::IsTuple(shape_));
- CHECK(!ShapeUtil::IsNestedTuple(shape_));
- CHECK_EQ(src_buf_ptrs.size(), ShapeUtil::TupleElementCount(shape_));
+ : LiteralBase(), shape_(MakeUnique<Shape>(shape)) {
+ CHECK(ShapeUtil::IsTuple(*shape_));
+ CHECK(!ShapeUtil::IsNestedTuple(*shape_));
+ CHECK_EQ(src_buf_ptrs.size(), ShapeUtil::TupleElementCount(*shape_));
root_piece_ = Piece();
- root_piece_.set_subshape(&shape_);
- BuildPieceSubtree(shape_, &root_piece_);
+ root_piece_.set_subshape(shape_.get());
+ BuildPieceSubtree(*shape_, &root_piece_);
for (int i = 0; i < src_buf_ptrs.size(); ++i) {
- const auto& src_shape = shape_.tuple_shapes(i);
+ const auto& src_shape = shape_->tuple_shapes(i);
CHECK(ShapeUtil::IsArray(src_shape));
root_piece_.child(i).set_buffer(const_cast<char*>(src_buf_ptrs[i]));
}
diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h
index 1e26eb7ad4..8e4159e360 100644
--- a/tensorflow/compiler/xla/literal_util.h
+++ b/tensorflow/compiler/xla/literal_util.h
@@ -1099,8 +1099,10 @@ class BorrowingLiteral : public LiteralBase {
const Piece& root_piece() const override { return root_piece_; };
Piece root_piece_;
- // Shape of this literal.
- const Shape shape_;
+  // Shape of this literal. Stored as a unique_ptr so that the (default) move
+  // construction of this class is trivially correct: the pointer to the Shape
+  // that root_piece_ stores will still point to the correct address.
+ std::unique_ptr<Shape> shape_;
};
template <typename NativeT>
diff --git a/tensorflow/compiler/xla/literal_util_test.cc b/tensorflow/compiler/xla/literal_util_test.cc
index f127cee0fd..53b926163c 100644
--- a/tensorflow/compiler/xla/literal_util_test.cc
+++ b/tensorflow/compiler/xla/literal_util_test.cc
@@ -1431,7 +1431,7 @@ TEST_F(LiteralUtilTest, LiteralSliceOfALiteralSlice) {
EXPECT_EQ(matrix_view, *Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}));
}
-TEST_F(LiteralUtilTest, BorrowingLiteralFromOneBufferPtrTest) {
+TEST_F(LiteralUtilTest, BorrowingLiteralFromOneBufferPtr) {
std::vector<int64> int64_values = {1, 2, 3};
const Shape literal_shape = ShapeUtil::MakeShape(S64, {3});
@@ -1443,7 +1443,7 @@ TEST_F(LiteralUtilTest, BorrowingLiteralFromOneBufferPtrTest) {
EXPECT_EQ(literal.Get<int64>({2}), 3);
}
-TEST_F(LiteralUtilTest, BorrowingLiteralFromMultipleBufferPtrsTest) {
+TEST_F(LiteralUtilTest, BorrowingLiteralFromMultipleBufferPtrs) {
std::vector<int64> one_two_three = {1, 2, 3};
const Shape one_two_three_shape = ShapeUtil::MakeShape(S64, {3});
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc
index f808990cad..ac058feccd 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.cc
+++ b/tensorflow/compiler/xla/python/local_computation_builder.cc
@@ -598,10 +598,12 @@ _FORWARD_BINOP(Or)
_FORWARD_UNOP(Not)
_FORWARD_UNOP(Abs)
_FORWARD_UNOP(Exp)
+_FORWARD_UNOP(Expm1)
_FORWARD_UNOP(Floor)
_FORWARD_UNOP(Ceil)
_FORWARD_UNOP(Round)
_FORWARD_UNOP(Log)
+_FORWARD_UNOP(Log1p)
_FORWARD_UNOP(Sign)
_FORWARD_UNOP(Cos)
_FORWARD_UNOP(Sin)
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h
index 9ac13b6523..e30c7790b9 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.h
+++ b/tensorflow/compiler/xla/python/local_computation_builder.h
@@ -305,10 +305,12 @@ class LocalComputationBuilder {
_FORWARD_UNOP(Not)
_FORWARD_UNOP(Abs)
_FORWARD_UNOP(Exp)
+ _FORWARD_UNOP(Expm1)
_FORWARD_UNOP(Floor)
_FORWARD_UNOP(Ceil)
_FORWARD_UNOP(Round)
_FORWARD_UNOP(Log)
+ _FORWARD_UNOP(Log1p)
_FORWARD_UNOP(Sign)
_FORWARD_UNOP(Cos)
_FORWARD_UNOP(Sin)
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i
index 536b93c6f9..fcd30b6c2f 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.i
+++ b/tensorflow/compiler/xla/python/local_computation_builder.i
@@ -974,10 +974,12 @@ tensorflow::ImportNumpy();
%unignore xla::swig::LocalComputationBuilder::Not;
%unignore xla::swig::LocalComputationBuilder::Abs;
%unignore xla::swig::LocalComputationBuilder::Exp;
+%unignore xla::swig::LocalComputationBuilder::Expm1;
%unignore xla::swig::LocalComputationBuilder::Floor;
%unignore xla::swig::LocalComputationBuilder::Ceil;
%unignore xla::swig::LocalComputationBuilder::Round;
%unignore xla::swig::LocalComputationBuilder::Log;
+%unignore xla::swig::LocalComputationBuilder::Log1p;
%unignore xla::swig::LocalComputationBuilder::Sign;
%unignore xla::swig::LocalComputationBuilder::Cos;
%unignore xla::swig::LocalComputationBuilder::Sin;
diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py
index 11611ac612..8b03682892 100644
--- a/tensorflow/compiler/xla/python/xla_client.py
+++ b/tensorflow/compiler/xla/python/xla_client.py
@@ -89,10 +89,12 @@ _UNARY_OPS = [
'Not',
'Abs',
'Exp',
+ 'Expm1',
'Floor',
'Round',
'Ceil',
'Log',
+ 'Log1p',
'Sign',
'Cos',
'Sin',
diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py
index 375e720f9b..6c0680f443 100644
--- a/tensorflow/compiler/xla/python/xla_client_test.py
+++ b/tensorflow/compiler/xla/python/xla_client_test.py
@@ -571,6 +571,12 @@ class SingleOpTest(LocalComputationTest):
c.Exp(c.Constant(arr))
self._ExecuteAndCompareClose(c, expected=np.exp(arr))
+ def testExpm1(self):
+ c = self._NewComputation()
+ arr = NumpyArrayF32([3.3, 12.1])
+ c.Expm1(c.Constant(arr))
+ self._ExecuteAndCompareClose(c, expected=np.expm1(arr))
+
def testRound(self):
c = self._NewComputation()
arr = NumpyArrayF32([3.3, 12.1])
@@ -583,6 +589,12 @@ class SingleOpTest(LocalComputationTest):
c.Log(c.Constant(arr))
self._ExecuteAndCompareClose(c, expected=np.log(arr))
+ def testLog1p(self):
+ c = self._NewComputation()
+ arr = NumpyArrayF32([3.3, 12.1])
+ c.Log1p(c.Constant(arr))
+ self._ExecuteAndCompareClose(c, expected=np.log1p(arr))
+
def testNeg(self):
c = self._NewComputation()
arr = NumpyArrayF32([3.3, 12.1])
diff --git a/tensorflow/compiler/xla/rpc/grpc_client_test.cc b/tensorflow/compiler/xla/rpc/grpc_client_test.cc
index 313f11a9a9..d7dd9786a2 100644
--- a/tensorflow/compiler/xla/rpc/grpc_client_test.cc
+++ b/tensorflow/compiler/xla/rpc/grpc_client_test.cc
@@ -20,8 +20,8 @@ limitations under the License.
#include <memory>
#include <vector>
-#include "grpc++/create_channel.h"
-#include "grpc++/security/credentials.h"
+#include "grpcpp/create_channel.h"
+#include "grpcpp/security/credentials.h"
#include "tensorflow/compiler/xla/client/client.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
diff --git a/tensorflow/compiler/xla/rpc/grpc_service.h b/tensorflow/compiler/xla/rpc/grpc_service.h
index 5cd573167a..ca1b09b648 100644
--- a/tensorflow/compiler/xla/rpc/grpc_service.h
+++ b/tensorflow/compiler/xla/rpc/grpc_service.h
@@ -16,7 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_RPC_GRPC_SERVICE_H_
#define TENSORFLOW_COMPILER_XLA_RPC_GRPC_SERVICE_H_
-#include "grpc++/server_context.h"
+#include "grpcpp/server_context.h"
#include "tensorflow/compiler/xla/rpc/xla_service.grpc.pb.h"
#include "tensorflow/compiler/xla/service/service.h"
diff --git a/tensorflow/compiler/xla/rpc/grpc_service_main.cc b/tensorflow/compiler/xla/rpc/grpc_service_main.cc
index e29908ccec..c68c857c30 100644
--- a/tensorflow/compiler/xla/rpc/grpc_service_main.cc
+++ b/tensorflow/compiler/xla/rpc/grpc_service_main.cc
@@ -15,9 +15,9 @@ limitations under the License.
// Basic server binary that exposes a xla::Service through a GRPC interface
// on a configurable port.
-#include "grpc++/security/server_credentials.h"
-#include "grpc++/server.h"
-#include "grpc++/server_builder.h"
+#include "grpcpp/security/server_credentials.h"
+#include "grpcpp/server.h"
+#include "grpcpp/server_builder.h"
#include "tensorflow/compiler/xla/rpc/grpc_service.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/init_main.h"
diff --git a/tensorflow/compiler/xla/rpc/xla_service.proto b/tensorflow/compiler/xla/rpc/xla_service.proto
index 92eb19ec0f..551ae895e0 100644
--- a/tensorflow/compiler/xla/rpc/xla_service.proto
+++ b/tensorflow/compiler/xla/rpc/xla_service.proto
@@ -115,10 +115,6 @@ service XlaService {
returns (ComputeConstantResponse) {
}
- // Retrieves the inferred shape for a value within a computation.
- rpc GetLocalShape(GetLocalShapeRequest) returns (GetLocalShapeResponse) {
- }
-
// Requests one or more device handles from the target. The returned device
// handles can be used to specify the device on which to execute computations
// or transfer data.
@@ -132,18 +128,6 @@ service XlaService {
returns (CreateChannelHandleResponse) {
}
- // Requests that the referenced computation be specialized for the provided
- // arguments for subsequent execution. This permits things such as value
- // specialization.
- rpc Specialize(SpecializeRequest) returns (SpecializeResponse) {
- }
-
- // Modifies the provided computation so that subsequent executions
- // will compute the provided ComputationDataHandle, rather than the
- // last expression enqueued on that Computation.
- rpc SetReturnValue(SetReturnValueRequest) returns (SetReturnValueResponse) {
- }
-
// Invokes the provided computation with the provided global data passed as
// immutable arguments. The request contains the whole computation graph.
// Returns global data output and execution timing.
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 345f5ddeb2..6f34703fec 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -269,6 +269,7 @@ cc_library(
"dfs_hlo_visitor.cc",
"hlo_computation.cc",
"hlo_instruction.cc",
+ "hlo_instructions.cc",
"hlo_module.cc",
"hlo_opcode.cc",
"hlo_sharding.cc",
@@ -280,16 +281,17 @@ cc_library(
"hlo_computation.h",
"hlo_domain_metadata.h",
"hlo_instruction.h",
+ "hlo_instructions.h",
"hlo_module.h",
"hlo_opcode.h",
"hlo_sharding.h",
],
deps = [
+ ":hlo_casting_utils",
":hlo_module_config",
":hlo_proto",
":hlo_reachability",
":name_uniquer",
- ":versioned_computation_handle",
"//tensorflow/compiler/xla:array",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:protobuf_util",
@@ -398,17 +400,6 @@ tf_cc_test(
],
)
-cc_library(
- name = "versioned_computation_handle",
- srcs = ["versioned_computation_handle.cc"],
- hdrs = ["versioned_computation_handle.h"],
- deps = [
- "//tensorflow/compiler/xla:types",
- "//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/core:lib",
- ],
-)
-
tf_cc_test(
name = "hlo_instruction_test",
srcs = ["hlo_instruction_test.cc"],
@@ -588,7 +579,6 @@ cc_library(
":allocation_tracker",
":backend",
":channel_tracker",
- ":compilation_cache",
":compiler",
":computation_layout",
":device_memory_allocator",
@@ -603,7 +593,6 @@ cc_library(
":platform_util",
":source_map_util",
":transfer_manager",
- ":versioned_computation_handle",
"//tensorflow/compiler/xla:executable_run_options",
"//tensorflow/compiler/xla:execution_options_util",
"//tensorflow/compiler/xla:service_interface",
@@ -638,7 +627,6 @@ cc_library(
":platform_util",
":service",
":shaped_buffer",
- ":versioned_computation_handle",
"//tensorflow/compiler/xla:execution_options_util",
"//tensorflow/compiler/xla:shape_layout",
"//tensorflow/compiler/xla:shape_util",
@@ -759,7 +747,6 @@ cc_library(
":hlo_proto",
":pool",
":shaped_buffer",
- ":versioned_computation_handle",
"//tensorflow/compiler/xla:executable_run_options",
"//tensorflow/compiler/xla:status",
"//tensorflow/compiler/xla:status_macros",
@@ -861,7 +848,6 @@ cc_library(
hdrs = ["channel_tracker.h"],
deps = [
":hlo",
- ":versioned_computation_handle",
"//tensorflow/compiler/xla:status",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
@@ -1163,6 +1149,19 @@ tf_cc_test(
)
cc_library(
+ name = "multi_output_fusion",
+ srcs = ["multi_output_fusion.cc"],
+ hdrs = ["multi_output_fusion.h"],
+ deps = [
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:hlo_pass",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
name = "hlo_creation_utils",
srcs = ["hlo_creation_utils.cc"],
hdrs = ["hlo_creation_utils.h"],
@@ -1643,7 +1642,6 @@ tf_cc_test(
":hlo_cost_analysis",
":local_service",
":service",
- ":versioned_computation_handle",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:test_helpers",
@@ -1985,20 +1983,6 @@ tf_cc_test(
)
cc_library(
- name = "compilation_cache",
- srcs = ["compilation_cache.cc"],
- hdrs = ["compilation_cache.h"],
- deps = [
- ":executable",
- ":hlo_module_config",
- ":versioned_computation_handle",
- "//tensorflow/compiler/xla:types",
- "//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/core:lib",
- ],
-)
-
-cc_library(
name = "layout_assignment",
srcs = [
"layout_assignment.cc",
@@ -2168,6 +2152,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "//tensorflow/core:test",
],
)
@@ -3015,13 +3000,14 @@ cc_library(
cc_library(
name = "hlo_casting_utils",
hdrs = ["hlo_casting_utils.h"],
- deps = [":hlo"],
+ deps = ["//tensorflow/core:lib"],
)
tf_cc_test(
name = "hlo_casting_utils_test",
srcs = ["hlo_casting_utils_test.cc"],
deps = [
+ ":hlo",
":hlo_casting_utils",
"//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
"//tensorflow/core:test",
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index dc5f1b31bf..3b36939b8a 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -1783,6 +1783,37 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) {
return ReplaceWithNewInstruction(
reduce, HloInstruction::CreateReshape(reduce->shape(), arg));
}
+
+ // If a reduce feeds a reduce with the same computation and initial value,
+ // they can be combined into a single reduce.
+ if (arg->opcode() == HloOpcode::kReduce &&
+ init_value->Identical(*arg->operand(1)) &&
+ *function == *arg->to_apply()) {
+ // Create a new reduce with the combined reduction dimensions of both
+ // reduces.
+ std::vector<int64> arg_dims = arg->dimensions();
+ std::sort(arg_dims.begin(), arg_dims.end());
+ std::vector<int64> reduce_dims = reduce->dimensions();
+ std::sort(reduce_dims.begin(), reduce_dims.end());
+    // Remap reduce_dims so they index the dimensions of arg's operand (the
+    // original, unreduced input).
+ for (int64 arg_dim : arg_dims) {
+ for (int64& dim : reduce_dims) {
+ if (dim >= arg_dim) {
+ ++dim;
+ }
+ }
+ }
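+    // Example: if the inner reduce removed dim 0 of a rank-4 input, the outer
+    // reduce's dims {1, 2} refer to dims {2, 3} of that input.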
+ std::vector<int64> new_dimensions;
+ new_dimensions.reserve(arg->dimensions().size() +
+ reduce->dimensions().size());
+ std::merge(arg_dims.begin(), arg_dims.end(), reduce_dims.begin(),
+ reduce_dims.end(), std::back_inserter(new_dimensions));
+ return ReplaceWithNewInstruction(
+ reduce,
+ HloInstruction::CreateReduce(reduce->shape(), arg->mutable_operand(0),
+ init_value, new_dimensions, function));
+ }
+
// A reshape that collapses multiple dimensions into a dimension being
// reduced can just reduce all of those dimensions instead of doing a
// collapsing reshape before a reduction.
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
index cda157f9fa..2605b0488c 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -74,6 +74,44 @@ TEST_F(AlgebraicSimplifierTest, AddZero) {
EXPECT_EQ(root, param0);
}
+// Test that Reduce(Reduce(A)) -> Reduce(A)
+TEST_F(AlgebraicSimplifierTest, TwoReducesToOne) {
+ HloComputation::Builder builder(TestName());
+ // Create add computation.
+ HloInstruction* zero = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
+ HloComputation* add_computation = nullptr;
+ {
+ HloComputation::Builder builder(TestName() + ".add");
+ const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
+ HloInstruction* p0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, scalar_shape, "p0"));
+ HloInstruction* p1 = builder.AddInstruction(
+ HloInstruction::CreateParameter(1, scalar_shape, "p1"));
+ builder.AddInstruction(
+ HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1));
+ add_computation = module().AddEmbeddedComputation(builder.Build());
+ }
+ Shape r4f32 = ShapeUtil::MakeShape(F32, {4, 5, 6, 7});
+ HloInstruction* param = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, r4f32, "param"));
+ std::vector<int64> dims0({0});
+ Shape r3f32 = ShapeUtil::MakeShape(F32, {5, 6, 7});
+ HloInstruction* reduce0 = builder.AddInstruction(
+ HloInstruction::CreateReduce(r3f32, param, zero, dims0, add_computation));
+ std::vector<int64> dims1({1, 2});
+ Shape r1f32 = ShapeUtil::MakeShape(F32, {5});
+ builder.AddInstruction(HloInstruction::CreateReduce(r1f32, reduce0, zero,
+ dims1, add_computation));
+ module().AddEntryComputation(builder.Build());
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
+ HloInstruction* root = module().entry_computation()->root_instruction();
+ EXPECT_THAT(root, op::Reduce(param, zero));
+ EXPECT_EQ(root->dimensions(), std::vector<int64>({0, 2, 3}));
+}
+
// Test that Const + A is canonicalized to A + Const.
TEST_F(AlgebraicSimplifierTest, AddConstOnLHS) {
Shape r0f32 = ShapeUtil::MakeShape(F32, {});
@@ -1714,7 +1752,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopPad) {
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(module).ValueOrDie());
EXPECT_THAT(computation->root_instruction(), param);
}
@@ -1759,7 +1797,7 @@ TEST_F(AlgebraicSimplifierTest, NegativePadding) {
EXPECT_THAT(computation->root_instruction(), op::Pad(param, zero));
EXPECT_TRUE(has_negative_padding(pad));
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(module).ValueOrDie());
EXPECT_THAT(computation->root_instruction(), op::Slice(op::Pad(param, zero)));
EXPECT_FALSE(
@@ -1781,7 +1819,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopReshape) {
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(module).ValueOrDie());
EXPECT_THAT(computation->root_instruction(), param);
}
@@ -1804,7 +1842,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopSlice) {
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(module).ValueOrDie());
EXPECT_THAT(computation->root_instruction(), param);
}
@@ -1932,7 +1970,8 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) {
b.AddInstruction(HloInstruction::CreateConvolve(out_shape, input, filter,
window, dnums));
- auto module = CreateNewModule();
+ // TODO(b/80488902): verify this module.
+ auto module = HloTestBase::CreateNewModule();
auto* computation = module->AddEntryComputation(b.Build());
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true,
@@ -2060,7 +2099,7 @@ TEST_F(AlgebraicSimplifierTest, MaxMinToClamp) {
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(module).ValueOrDie());
EXPECT_THAT(computation->root_instruction(),
op::Clamp(max_value, param0, min_value));
@@ -2090,7 +2129,7 @@ TEST_F(AlgebraicSimplifierTest, MinMaxToClamp) {
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(module).ValueOrDie());
EXPECT_THAT(computation->root_instruction(),
op::Clamp(max_value, param0, min_value));
@@ -2121,7 +2160,7 @@ TEST_F(AlgebraicSimplifierTest, MinMaxWithBroadcastToClamp) {
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(module).ValueOrDie());
EXPECT_THAT(computation->root_instruction(),
op::Clamp(max_value, param0, min_value));
@@ -2151,7 +2190,7 @@ TEST_F(AlgebraicSimplifierTest, MinMaxNotToClamp) {
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie());
+ EXPECT_FALSE(simplifier.Run(module).ValueOrDie());
EXPECT_THAT(computation->root_instruction(),
op::Minimum(op::Maximum(param0, max_value), min_value));
@@ -2184,7 +2223,7 @@ TEST_F(AlgebraicSimplifierTest, MinEquationWithMaxNotToClamp) {
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie());
+ EXPECT_FALSE(simplifier.Run(module).ValueOrDie());
EXPECT_THAT(computation->root_instruction(),
op::Minimum(op::Add(op::Maximum(param0, max_value), max_value),
@@ -2200,10 +2239,8 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToSlice) {
HloInstruction::CreateParameter(0, r0f32, "scalar_param"));
Shape broadcast_shape = ShapeUtil::MakeShape(F32, {4, 5, 6, 7});
- HloInstruction* broadcast =
- builder.AddInstruction(HloInstruction::CreateBroadcast(
- broadcast_shape, scalar_param,
- AsInt64Slice(broadcast_shape.dimensions())));
+ HloInstruction* broadcast = builder.AddInstruction(
+ HloInstruction::CreateBroadcast(broadcast_shape, scalar_param, {}));
Shape slice_shape = ShapeUtil::MakeShape(F32, {2, 2, 3, 3});
HloInstruction* slice = builder.AddInstruction(HloInstruction::CreateSlice(
@@ -2219,10 +2256,10 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToSlice) {
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(module).ValueOrDie());
// Running simplification again should not result in any further changes.
- ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_FALSE(simplifier.Run(module).ValueOrDie());
root = computation->root_instruction();
EXPECT_THAT(root, op::Broadcast(scalar_param));
@@ -2237,10 +2274,8 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) {
HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
Shape broadcast_shape = ShapeUtil::MakeShape(F32, {4, 5, 6});
- HloInstruction* broadcast =
- builder.AddInstruction(HloInstruction::CreateBroadcast(
- broadcast_shape, forty_two,
- AsInt64Slice(broadcast_shape.dimensions())));
+ HloInstruction* broadcast = builder.AddInstruction(
+ HloInstruction::CreateBroadcast(broadcast_shape, forty_two, {}));
HloInstruction* transpose =
builder.AddInstruction(HloInstruction::CreateTranspose(
@@ -2259,7 +2294,7 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) {
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(module).ValueOrDie());
root = computation->root_instruction();
EXPECT_THAT(root, op::Broadcast(forty_two));
@@ -2268,7 +2303,8 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) {
// Test that ReduceWindow(Pad(op, x), y) can simplify to ReduceWindow(op, x).
TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) {
- auto module = CreateNewModule();
+ // TODO(b/80488902): verify this module.
+ auto module = HloTestBase::CreateNewModule();
HloComputation::Builder builder(TestName());
// Create operand to the pad.
@@ -2349,7 +2385,8 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) {
// Test that ReduceWindow(Convert(Pad(op, x)), y) can simplify to
// ReduceWindow(Convert(op), x).
TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) {
- auto module = CreateNewModule();
+ // TODO(b/80488902): verify this module.
+ auto module = HloTestBase::CreateNewModule();
HloComputation::Builder builder(TestName());
// Create operand to the pad.
@@ -2444,7 +2481,7 @@ TEST_F(AlgebraicSimplifierTest, ReversalOfTrivialDimensionsToBitcast) {
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(module).ValueOrDie());
HloInstruction* root = computation->root_instruction();
EXPECT_EQ(a, root);
diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.cc b/tensorflow/compiler/xla/service/batchnorm_expander.cc
index 598718c72c..ec13fadbc7 100644
--- a/tensorflow/compiler/xla/service/batchnorm_expander.cc
+++ b/tensorflow/compiler/xla/service/batchnorm_expander.cc
@@ -58,8 +58,7 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault {
// Runs the visitor on a computation.
static bool Run(HloComputation* computation, bool rewrite_training_op,
- bool rewrite_inference_op, bool rewrite_grad_op,
- bool use_fusion);
+ bool rewrite_inference_op, bool rewrite_grad_op);
// Returns whether any batch norm ops were rewritten.
const bool changed() const { return changed_; }
@@ -70,21 +69,14 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault {
explicit BatchNormExpanderVisitor(HloComputation* computation,
bool rewrite_training_op,
bool rewrite_inference_op,
- bool rewrite_grad_op, bool use_fusion)
+ bool rewrite_grad_op)
: computation_(computation),
rewrite_training_op_(rewrite_training_op),
rewrite_inference_op_(rewrite_inference_op),
- rewrite_grad_op_(rewrite_grad_op),
- use_fusion_(use_fusion) {}
+ rewrite_grad_op_(rewrite_grad_op) {}
HloComputation* GetOrCreateScalarAddComputation(
PrimitiveType primitive_type) {
- HloComputation** scalar_add_computation =
- &scalar_add_computations_[primitive_type];
- if (*scalar_add_computation) {
- return *scalar_add_computation;
- }
-
HloComputation::Builder b("scalar_add_computation");
Shape shape = ShapeUtil::MakeShape(primitive_type, {});
auto scalar_lhs = b.AddInstruction(
@@ -93,71 +85,38 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault {
HloInstruction::CreateParameter(1, shape, "scalar_rhs"));
auto scalar_op = b.AddInstruction(HloInstruction::CreateBinary(
shape, HloOpcode::kAdd, scalar_lhs, scalar_rhs));
- *scalar_add_computation =
- computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op));
- return *scalar_add_computation;
- }
-
- // TODO(b/80534766): Remove maps after performance issues with scalar
- // broadcasts are resolved on all backends.
- HloComputation* GetOrCreateScalarRsqrtComputation(
- PrimitiveType primitive_type) {
- HloComputation** scalar_rsqrt_computation =
- &scalar_rsqrt_computations_[primitive_type];
- if (*scalar_rsqrt_computation) {
- return *scalar_rsqrt_computation;
- }
-
- HloComputation::Builder b("scalar_add_computation");
- Shape shape = ShapeUtil::MakeShape(primitive_type, {});
- auto scalar_lhs = b.AddInstruction(
- HloInstruction::CreateParameter(0, shape, "scalar_lhs"));
- auto scalar_rhs = b.AddInstruction(HloInstruction::CreateConvert(
- shape, b.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR0<float>(-0.5f)))));
- auto scalar_op = b.AddInstruction(HloInstruction::CreateBinary(
- shape, HloOpcode::kPower, scalar_lhs, scalar_rhs));
- *scalar_rsqrt_computation =
- computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op));
- return *scalar_rsqrt_computation;
+ return computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op));
}
- std::unique_ptr<HloInstruction> Rsqrt(HloInstruction* operand) {
- return HloInstruction::CreateMap(
- operand->shape(), {operand},
- GetOrCreateScalarRsqrtComputation(operand->shape().element_type()));
- }
-
- HloComputation* GetOrCreateScalarMeanComputation(PrimitiveType primitive_type,
- int64 element_count) {
- HloComputation** scalar_mean_computation =
- &scalar_mean_computations_[std::pair<PrimitiveType, int64>(
- primitive_type, element_count)];
- if (*scalar_mean_computation) {
- return *scalar_mean_computation;
- }
-
- HloComputation::Builder b("scalar_add_computation");
- Shape shape = ShapeUtil::MakeShape(primitive_type, {});
- auto scalar_lhs = b.AddInstruction(
- HloInstruction::CreateParameter(0, shape, "scalar_lhs"));
- auto scalar_rhs = b.AddInstruction(HloInstruction::CreateConvert(
- shape, b.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(
- 1.0f / static_cast<float>(element_count))))));
- auto scalar_op = b.AddInstruction(HloInstruction::CreateBinary(
- shape, HloOpcode::kMultiply, scalar_lhs, scalar_rhs));
- *scalar_mean_computation =
- computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op));
- return *scalar_mean_computation;
+ std::unique_ptr<HloInstruction> Rsqrt(
+ HloInstruction* operand,
+ const std::function<HloInstruction*(std::unique_ptr<HloInstruction>)>&
+ add_instruction) {
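+    // rsqrt(x) = x^(-0.5): broadcast a scalar -0.5 exponent to operand's
+    // shape and raise operand to it with a Power op.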
+ HloInstruction* exponent = add_instruction(HloInstruction::CreateBroadcast(
+ operand->shape(),
+ add_instruction(HloInstruction::CreateConvert(
+ ShapeUtil::MakeShape(operand->shape().element_type(), {}),
+ add_instruction(HloInstruction::CreateConstant(
+ Literal::CreateR0<float>(-0.5f))))),
+ {}));
+ return HloInstruction::CreateBinary(operand->shape(), HloOpcode::kPower,
+ operand, exponent);
}
- std::unique_ptr<HloInstruction> Mean(int64 element_count,
- HloInstruction* operand) {
- return HloInstruction::CreateMap(
- operand->shape(), {operand},
- GetOrCreateScalarMeanComputation(operand->shape().element_type(),
- element_count));
+ std::unique_ptr<HloInstruction> Mean(
+ int64 element_count, HloInstruction* operand,
+ const std::function<HloInstruction*(std::unique_ptr<HloInstruction>)>&
+ add_instruction) {
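+    // mean = operand * (1 / element_count); the scalar reciprocal is
+    // converted to operand's element type and broadcast to its shape.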
+ HloInstruction* elem_count_recip =
+ add_instruction(HloInstruction::CreateBroadcast(
+ operand->shape(),
+ add_instruction(HloInstruction::CreateConvert(
+ ShapeUtil::MakeShape(operand->shape().element_type(), {}),
+ add_instruction(HloInstruction::CreateConstant(
+ Literal::CreateR0<float>(1.0 / element_count))))),
+ {}));
+ return HloInstruction::CreateBinary(operand->shape(), HloOpcode::kMultiply,
+ operand, elem_count_recip);
}
// Replaces the existing HLO instruction old_instruction, with
@@ -189,18 +148,9 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault {
bool rewrite_training_op_;
bool rewrite_inference_op_;
bool rewrite_grad_op_;
- bool use_fusion_;
// Whether rewrite has occurred.
bool changed_ = false;
-
- // Cached computations for adding two scalars.
- tensorflow::gtl::FlatMap<PrimitiveType, HloComputation*>
- scalar_add_computations_;
- tensorflow::gtl::FlatMap<PrimitiveType, HloComputation*>
- scalar_rsqrt_computations_;
- tensorflow::gtl::FlatMap<std::pair<PrimitiveType, int64>, HloComputation*>
- scalar_mean_computations_;
};
} // namespace
@@ -208,13 +158,12 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault {
bool BatchNormExpanderVisitor::Run(HloComputation* computation,
bool rewrite_training_op,
bool rewrite_inference_op,
- bool rewrite_grad_op, bool use_fusion) {
+ bool rewrite_grad_op) {
BatchNormExpanderVisitor visitor(
computation,
/*rewrite_training_op=*/rewrite_training_op,
/*rewrite_inference_op=*/rewrite_inference_op,
- /*rewrite_grad_op=*/rewrite_grad_op,
- /*use_fusion=*/use_fusion);
+ /*rewrite_grad_op=*/rewrite_grad_op);
TF_CHECK_OK(computation->Accept(&visitor));
return visitor.changed_;
}
@@ -290,28 +239,14 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining(
feature_shape, operand_squared, zero, dimensions_without_feature,
add_reduce_computation));
- // Fuse two parallel reduces together to improve performance.
- if (use_fusion_ && !batch_norm->has_sharding()) {
- auto tuple = add(HloInstruction::CreateTuple({sum, squared_sum}));
-
- auto fused = computation_->CreateFusionInstruction(
- {tuple, sum, squared_sum, operand_squared},
- HloInstruction::FusionKind::kInput);
-
- sum = add(HloInstruction::CreateGetTupleElement(feature_shape, fused, 0));
-
- squared_sum =
- add(HloInstruction::CreateGetTupleElement(feature_shape, fused, 1));
- }
-
// E[X].
- auto mean = add(Mean(elements_per_feature_int64, sum));
+ auto mean = add(Mean(elements_per_feature_int64, sum, add));
auto mean_broadcasted = add(
HloInstruction::CreateBroadcast(operand_shape, mean, {feature_index}));
// E[X^2].
- auto square_mean = add(Mean(elements_per_feature_int64, squared_sum));
+ auto square_mean = add(Mean(elements_per_feature_int64, squared_sum, add));
// E^2[X].
auto mean_square =
@@ -329,7 +264,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining(
add_binary(operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon);
// 1 / Sqrt[Var[X] + epsilon].
- auto rsqrt_var_add_epsilon = add(Rsqrt(var_add_epsilon));
+ auto rsqrt_var_add_epsilon = add(Rsqrt(var_add_epsilon, add));
// X - E[X].
auto operand_minus_mean = add_binary(operand_shape, HloOpcode::kSubtract,
@@ -431,7 +366,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormInference(
add_binary(operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon);
// 1 / Sqrt[Var[X] + epsilon].
- auto rsqrt_var_add_epsilon = add(Rsqrt(var_add_epsilon));
+ auto rsqrt_var_add_epsilon = add(Rsqrt(var_add_epsilon, add));
// X - E[X].
auto operand_minus_mean = add_binary(operand_shape, HloOpcode::kSubtract,
@@ -545,10 +480,12 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad(
// rsqrt[Var[X] + epsilon].
auto rsqrt_var_add_epsilon_broadcasted =
add(Rsqrt(add_binary(activation_shape, HloOpcode::kAdd,
- variance_broadcasted, epsilon_activation)));
+ variance_broadcasted, epsilon_activation),
+ add));
auto rsqrt_var_add_epsilon = add(Rsqrt(
- add_binary(feature_shape, HloOpcode::kAdd, variance, epsilon_feature)));
+ add_binary(feature_shape, HloOpcode::kAdd, variance, epsilon_feature),
+ add));
// X - E[X].
auto activation_minus_mean = add_binary(
@@ -573,21 +510,6 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad(
feature_shape, grad_output, zero, dimensions_without_feature,
add_reduce_computation));
- if (use_fusion_ && !batch_norm->has_sharding()) {
- auto tuple = add(HloInstruction::CreateTuple(
- {sum_grad_output_times_activiation_minus_mean, grad_beta}));
-
- auto fused = computation_->CreateFusionInstruction(
- {tuple, sum_grad_output_times_activiation_minus_mean, grad_beta},
- HloInstruction::FusionKind::kInput);
-
- sum_grad_output_times_activiation_minus_mean =
- add(HloInstruction::CreateGetTupleElement(feature_shape, fused, 0));
-
- grad_beta =
- add(HloInstruction::CreateGetTupleElement(feature_shape, fused, 1));
- }
-
// Grad[scale] = Sum(Grad[Y] * (X - E[X]) * rsqrt[Var[X] + epsilon]).
auto grad_scale = add_binary(feature_shape, HloOpcode::kMultiply,
sum_grad_output_times_activiation_minus_mean,
@@ -616,8 +538,8 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad(
add_binary(activation_shape, HloOpcode::kMultiply, scale_broadcasted,
rsqrt_var_add_epsilon_broadcasted);
- scale_times_rsqrt_var_add_epsilon =
- add(Mean(elements_per_feature_int64, scale_times_rsqrt_var_add_epsilon));
+ scale_times_rsqrt_var_add_epsilon = add(
+ Mean(elements_per_feature_int64, scale_times_rsqrt_var_add_epsilon, add));
auto elements_per_feature_literal =
Literal::CreateR0<float>(elements_per_feature_int64);
@@ -665,8 +587,8 @@ StatusOr<bool> BatchNormExpander::Run(HloModule* module) {
bool changed = false;
for (auto* comp : module->MakeNonfusionComputations()) {
if (BatchNormExpanderVisitor::Run(comp, rewrite_training_op_,
- rewrite_inference_op_, rewrite_grad_op_,
- use_fusion_)) {
+ rewrite_inference_op_,
+ rewrite_grad_op_)) {
changed = true;
}
}
diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.h b/tensorflow/compiler/xla/service/batchnorm_expander.h
index 4ad987085d..7ae202c583 100644
--- a/tensorflow/compiler/xla/service/batchnorm_expander.h
+++ b/tensorflow/compiler/xla/service/batchnorm_expander.h
@@ -31,11 +31,10 @@ class BatchNormExpander : public HloPassInterface {
// When use_fusion is set, a multi-output fusion node is created.
BatchNormExpander(bool rewrite_training_op = false,
bool rewrite_inference_op = false,
- bool rewrite_grad_op = false, bool use_fusion = true)
+ bool rewrite_grad_op = false)
: rewrite_training_op_(rewrite_training_op),
rewrite_inference_op_(rewrite_inference_op),
- rewrite_grad_op_(rewrite_grad_op),
- use_fusion_(use_fusion) {}
+ rewrite_grad_op_(rewrite_grad_op) {}
~BatchNormExpander() = default;
tensorflow::StringPiece name() const override { return "batchnorm_expander"; }
@@ -47,7 +46,6 @@ class BatchNormExpander : public HloPassInterface {
bool rewrite_training_op_;
bool rewrite_inference_op_;
bool rewrite_grad_op_;
- bool use_fusion_;
};
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
index 7e86c33687..96d25675de 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
@@ -371,11 +371,11 @@ TEST_F(BufferAssignmentTest, Basic) {
// param1[100] --------------/--------/
auto builder = HloComputation::Builder(TestName());
auto paramscalar =
- builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, ""));
+ builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p"));
auto param0 = builder.AddInstruction(
- HloInstruction::CreateParameter(1, f32vec100_, ""));
+ HloInstruction::CreateParameter(1, f32vec100_, "p1"));
auto param1 = builder.AddInstruction(
- HloInstruction::CreateParameter(2, f32vec100_, ""));
+ HloInstruction::CreateParameter(2, f32vec100_, "p2"));
auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
f32vec100_, HloOpcode::kMultiply, paramscalar, param0));
auto add = builder.AddInstruction(
@@ -418,11 +418,11 @@ TEST_F(BufferAssignmentTest, BasicUniquelyColored) {
// share anything.
auto builder = HloComputation::Builder(TestName());
auto paramscalar =
- builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, ""));
+ builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p"));
auto param0 = builder.AddInstruction(
- HloInstruction::CreateParameter(1, f32vec100_, ""));
+ HloInstruction::CreateParameter(1, f32vec100_, "p1"));
auto param1 = builder.AddInstruction(
- HloInstruction::CreateParameter(2, f32vec100_, ""));
+ HloInstruction::CreateParameter(2, f32vec100_, "p2"));
auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
f32vec100_, HloOpcode::kMultiply, paramscalar, param0));
auto add = builder.AddInstruction(
@@ -477,11 +477,11 @@ TEST_F(BufferAssignmentTest, BasicPartiallyColored) {
// have the color 0, which allows the mul and add to share buffers.
auto builder = HloComputation::Builder(TestName());
auto paramscalar =
- builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, ""));
+ builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p"));
auto param0 = builder.AddInstruction(
- HloInstruction::CreateParameter(1, f32vec100_, ""));
+ HloInstruction::CreateParameter(1, f32vec100_, "p1"));
auto param1 = builder.AddInstruction(
- HloInstruction::CreateParameter(2, f32vec100_, ""));
+ HloInstruction::CreateParameter(2, f32vec100_, "p2"));
auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
f32vec100_, HloOpcode::kMultiply, paramscalar, param0));
auto add = builder.AddInstruction(
@@ -547,11 +547,11 @@ TEST_F(BufferAssignmentTest, MultipleUsersForNode) {
//
auto builder = HloComputation::Builder(TestName());
auto paramscalar =
- builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, ""));
+ builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p"));
auto param0 = builder.AddInstruction(
- HloInstruction::CreateParameter(1, f32vec100_, ""));
+ HloInstruction::CreateParameter(1, f32vec100_, "p1"));
auto param1 = builder.AddInstruction(
- HloInstruction::CreateParameter(2, f32vec100_, ""));
+ HloInstruction::CreateParameter(2, f32vec100_, "p2"));
auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
f32vec100_, HloOpcode::kMultiply, paramscalar, param0));
auto add = builder.AddInstruction(
@@ -601,7 +601,7 @@ TEST_F(BufferAssignmentTest, TrivialMap) {
// Creates the main kernel and verifies instruction counts.
auto builder = HloComputation::Builder(TestName());
auto param0 = builder.AddInstruction(
- HloInstruction::CreateParameter(0, f32a100x10_, ""));
+ HloInstruction::CreateParameter(0, f32a100x10_, "p"));
auto map = builder.AddInstruction(
HloInstruction::CreateMap(f32a100x10_, {param0}, map_computation));
module->AddEntryComputation(builder.Build());
@@ -654,7 +654,7 @@ TEST_F(BufferAssignmentTest, CannotReuseInputBufferOfReduce) {
auto builder = HloComputation::Builder(TestName());
auto param0 = builder.AddInstruction(
- HloInstruction::CreateParameter(0, f32a100x10_, ""));
+ HloInstruction::CreateParameter(0, f32a100x10_, "p"));
auto exp1 = builder.AddInstruction(
HloInstruction::CreateUnary(f32a100x10_, HloOpcode::kExp, param0));
auto exp2 = builder.AddInstruction(
@@ -818,7 +818,7 @@ TEST_F(BufferAssignmentTest, UnaryOpReuseChain) {
// param0[100] ---> (exp) ---> (tanh) ---> (exp) ---> (neg)
auto builder = HloComputation::Builder(TestName());
auto param0 = builder.AddInstruction(
- HloInstruction::CreateParameter(0, f32vec100_, ""));
+ HloInstruction::CreateParameter(0, f32vec100_, "p"));
auto exp1 = builder.AddInstruction(
HloInstruction::CreateUnary(f32vec100_, HloOpcode::kExp, param0));
auto tanh = builder.AddInstruction(
@@ -1496,11 +1496,11 @@ TEST_F(BufferAssignmentTest, TrivialPeakBuffers) {
// param1[100] --------------/--------/
auto builder = HloComputation::Builder(TestName());
auto paramscalar =
- builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, ""));
+ builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p"));
auto param0 = builder.AddInstruction(
- HloInstruction::CreateParameter(1, f32vec100_, ""));
+ HloInstruction::CreateParameter(1, f32vec100_, "p1"));
auto param1 = builder.AddInstruction(
- HloInstruction::CreateParameter(2, f32vec100_, ""));
+ HloInstruction::CreateParameter(2, f32vec100_, "p2"));
auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
f32vec100_, HloOpcode::kMultiply, paramscalar, param0));
auto add = builder.AddInstruction(
@@ -1536,7 +1536,7 @@ TEST_F(BufferAssignmentTest, PeakBuffers) {
// be {%rev, %neg, %concat}. This occurs right at the concat itself.
auto builder = HloComputation::Builder(TestName());
auto param = builder.AddInstruction(
- HloInstruction::CreateParameter(0, f32vec100_, ""));
+ HloInstruction::CreateParameter(0, f32vec100_, "p"));
auto log = builder.AddInstruction(
HloInstruction::CreateUnary(f32vec100_, HloOpcode::kLog, param));
auto rev = builder.AddInstruction(
diff --git a/tensorflow/compiler/xla/service/channel_tracker.h b/tensorflow/compiler/xla/service/channel_tracker.h
index 52f33a1318..fac0afd672 100644
--- a/tensorflow/compiler/xla/service/channel_tracker.h
+++ b/tensorflow/compiler/xla/service/channel_tracker.h
@@ -19,7 +19,6 @@ limitations under the License.
#include <map>
#include "tensorflow/compiler/xla/service/hlo_module.h"
-#include "tensorflow/compiler/xla/service/versioned_computation_handle.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
diff --git a/tensorflow/compiler/xla/service/compilation_cache.cc b/tensorflow/compiler/xla/service/compilation_cache.cc
deleted file mode 100644
index b16907da9e..0000000000
--- a/tensorflow/compiler/xla/service/compilation_cache.cc
+++ /dev/null
@@ -1,78 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/compiler/xla/service/compilation_cache.h"
-
-#include <utility>
-
-#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/strings/strcat.h"
-#include "tensorflow/core/platform/logging.h"
-
-namespace xla {
-
-std::shared_ptr<Executable> CompilationCache::Insert(
- std::unique_ptr<Executable> executable,
- const HloModuleConfig& module_config) {
- tensorflow::mutex_lock lock(mutex_);
-
- CacheKey key =
- BuildKey(executable->entry_computation_handle(), module_config);
- VLOG(2) << "inserting cache key: " << key;
- if (cache_.count(key) == 0) {
- cache_.emplace(key, std::move(executable));
- } else {
- // Executable already exists in the cache. This can happen if two Execute
- // calls for a new computation are received simultaneously by the
- // service. In this case, we discard the Executable given as a parameter and
- // return what is in the cache. This is necessary because the service relies
- // on the cache to keep ownership of the Executable. We only want to store
- // one Executable for a given computation version and we can't discard the
- // executable which is in the cache because it may be in use.
- executable.reset();
- }
- return cache_.at(key);
-}
-
-std::shared_ptr<Executable> CompilationCache::LookUp(
- const VersionedComputationHandle& versioned_handle,
- const HloModuleConfig& module_config) const {
- tensorflow::mutex_lock lock(mutex_);
-
- CacheKey key = BuildKey(versioned_handle, module_config);
- VLOG(2) << "looking up cache key: " << key;
- if (cache_.count(key) == 0) {
- VLOG(2) << "cache key not found: " << key;
- return nullptr;
- } else {
- std::shared_ptr<Executable> result = cache_.at(key);
- VLOG(2) << "hit executable with module config: "
- << result->module_config().compilation_cache_key();
- return result;
- }
-}
-
-CompilationCache::CacheKey CompilationCache::BuildKey(
- const VersionedComputationHandle& versioned_handle,
- const HloModuleConfig& module_config) const {
- // The computation shape is represented entirely by its ProgramShape member,
- // so just serialize the proto as part of the key.
- return tensorflow::strings::StrCat(versioned_handle.handle.handle(), "::",
- versioned_handle.version, "::",
- module_config.compilation_cache_key());
-}
-
-} // namespace xla
diff --git a/tensorflow/compiler/xla/service/compilation_cache.h b/tensorflow/compiler/xla/service/compilation_cache.h
deleted file mode 100644
index 09989726ae..0000000000
--- a/tensorflow/compiler/xla/service/compilation_cache.h
+++ /dev/null
@@ -1,78 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COMPILATION_CACHE_H_
-#define TENSORFLOW_COMPILER_XLA_SERVICE_COMPILATION_CACHE_H_
-
-#include <map>
-#include <memory>
-#include <string>
-
-#include "tensorflow/compiler/xla/service/executable.h"
-#include "tensorflow/compiler/xla/service/hlo_module_config.h"
-#include "tensorflow/compiler/xla/service/versioned_computation_handle.h"
-#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/platform/macros.h"
-#include "tensorflow/core/platform/mutex.h"
-#include "tensorflow/core/platform/thread_annotations.h"
-
-namespace xla {
-
-// A cache which stores Executables indexed by computation handle and version.
-class CompilationCache {
- public:
- CompilationCache() {}
-
- // Insert the given Executable into the cache. Return a bare Executable
- // pointer for the caller to use. Note: the returned pointer will *not* be the
- // same as the given unique pointer if the computation already exists in the
- // cache. See comments in the .cc implementation for details of this case.
- //
- // module_config is provided by the caller, instead of being taken from the
- // executable, so that we can insert keys into the compilation cache that are
- // devoid of layout (where XLA gets to choose what layout to compile).
- //
- // A shared_ptr is returned so the caller can keep the Executable from being
- // destructed in the event that the Executable is evicted from the
- // computation cache (and the cache's shared_ptr to the Executable is
- // destructed).
- std::shared_ptr<Executable> Insert(std::unique_ptr<Executable> executable,
- const HloModuleConfig& module_config);
-
- // Lookup the Executable for the specified versioned computation in the cache.
- // Return a shared_ptr to the Executable if it exists in the cache. Return
- // nullptr otherwise.
- std::shared_ptr<Executable> LookUp(
- const VersionedComputationHandle& versioned_handle,
- const HloModuleConfig& module_config) const;
-
- protected:
- mutable tensorflow::mutex mutex_;
-
- // Map from versioned handle with program layout to Executable built
- // for that computation version and program layout.
- using CacheKey = string;
-
- CacheKey BuildKey(const VersionedComputationHandle& versioned_handle,
- const HloModuleConfig& module_config) const;
- std::map<CacheKey, std::shared_ptr<Executable>> cache_ GUARDED_BY(mutex_);
-
- private:
- TF_DISALLOW_COPY_AND_ASSIGN(CompilationCache);
-};
-
-} // namespace xla
-
-#endif // TENSORFLOW_COMPILER_XLA_SERVICE_COMPILATION_CACHE_H_
diff --git a/tensorflow/compiler/xla/service/computation_layout.h b/tensorflow/compiler/xla/service/computation_layout.h
index 53c3a3f7b7..6975f387b4 100644
--- a/tensorflow/compiler/xla/service/computation_layout.h
+++ b/tensorflow/compiler/xla/service/computation_layout.h
@@ -32,12 +32,21 @@ namespace xla {
// mutable layouts.
class ComputationLayout {
public:
+ // Creates a new ComputationLayout with the given result layout.
+ explicit ComputationLayout(ShapeLayout result_layout)
+ : result_layout_(std::move(result_layout)) {}
+
// Constructs a ComputationLayout from a ProgramShape. The layouts of the
// parameters and results are set to the default layout. Layouts in the
// ProgramShape are ignored if ignore_layouts is true.
explicit ComputationLayout(const ProgramShape& program_shape,
bool ignore_layouts = true);
+ // Adds a new parameter layout to the computation layout.
+ void add_parameter_layout(ShapeLayout shape_layout) {
+ parameter_layouts_.push_back(std::move(shape_layout));
+ }
+
// Returns the layout of a particular parameter.
const ShapeLayout& parameter_layout(int64 param_no) const {
return parameter_layouts_[param_no];
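
Note: the new constructor and add_parameter_layout let callers build a ComputationLayout incrementally, result layout first and parameter layouts appended one at a time, instead of deriving everything from a ProgramShape. A hedged usage sketch (assumes ShapeLayout can be constructed directly from a Shape, which is not shown in this diff; requires the XLA headers, so it is not standalone):

// Sketch only: layout for a computation (f32[10], f32[10]) -> f32[10].
// ShapeLayout(Shape) construction is assumed here.
Shape vec = ShapeUtil::MakeShape(F32, {10});
ComputationLayout layout(ShapeLayout(vec));     // result layout first
layout.add_parameter_layout(ShapeLayout(vec));  // parameter 0
layout.add_parameter_layout(ShapeLayout(vec));  // parameter 1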
diff --git a/tensorflow/compiler/xla/service/copy_insertion_test.cc b/tensorflow/compiler/xla/service/copy_insertion_test.cc
index 153f062d01..684fff8a6f 100644
--- a/tensorflow/compiler/xla/service/copy_insertion_test.cc
+++ b/tensorflow/compiler/xla/service/copy_insertion_test.cc
@@ -1636,8 +1636,7 @@ void BM_SequentialWhiles(int num_iters, int num_whiles) {
for (int i = 0; i < num_iters; ++i) {
HloModuleConfig config;
config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags());
- HloModule module("BM_SequentialWhiles", VersionedComputationHandle(),
- config);
+ HloModule module("BM_SequentialWhiles", config);
auto builder = HloComputation::Builder("BM_SequentialWhiles");
HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
@@ -1677,8 +1676,7 @@ void BM_ParallelWhiles(int num_iters, int num_whiles) {
for (int i = 0; i < num_iters; ++i) {
HloModuleConfig config;
config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags());
- HloModule module("BM_SequentialWhiles", VersionedComputationHandle(),
- config);
+ HloModule module("BM_SequentialWhiles", config);
auto builder = HloComputation::Builder("BM_ParallelWhiles");
HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
@@ -1750,8 +1748,7 @@ void BM_ManyElementTuple(int num_iters, const int num_tuple_inputs) {
std::vector<HloInstruction*> tuple_params(num_tuple_inputs);
for (int i = 0; i < num_iters; ++i) {
auto builder = HloComputation::Builder("BM_ParallelWhiles");
- HloModule module("BM_ManyElementTuple", VersionedComputationHandle(),
- config);
+ HloModule module("BM_ManyElementTuple", config);
for (int j = 0; j < num_tuple_inputs; ++j) {
tuple_params[j] = builder.AddInstruction(
HloInstruction::CreateParameter(j, element_shape, ""));
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index 278bb1bebf..1067b38f93 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -898,6 +898,7 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
+ "//tensorflow/core:lib",
"@llvm//:core",
"@llvm//:support",
],
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index 25b18eff20..4c0e189e78 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -264,8 +264,7 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile,
pass.AddPass<BatchNormExpander>(
/*rewrite_training_op=*/true,
/*rewrite_inference_op=*/true,
- /*rewrite_grad_op=*/true,
- /*use_fusion=*/false);
+ /*rewrite_grad_op=*/true);
pass.AddPass<AlgebraicSimplifier>(
/*is_layout_sensitive=*/false,
[](const Shape&, const Shape&) { return false; },
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_options.cc b/tensorflow/compiler/xla/service/cpu/cpu_options.cc
index e75fcb6bc9..3ed7876715 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_options.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_options.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/cpu/cpu_options.h"
#include "tensorflow/core/lib/strings/numbers.h"
+#include "tensorflow/core/lib/strings/str_util.h"
namespace {
@@ -24,6 +25,7 @@ const char* const kXlaDisableVectorizedReduce = "xla_disable_vectorized_reduce";
const char* const kLlvmIrDotTilingFactor = "xla_llvm_dot_tiling_factor";
const char* const kXlaEnableExperimentalLlvmIrGemm =
"xla_enable_experimental_llvm_ir_gemm";
+const char* const kLlvmIrGemmTileSize = "xla_llvm_ir_gemm_tile_size";
} // namespace
@@ -62,6 +64,43 @@ bool EnableExperimentalLlvmIrGemm(const HloModuleConfig& config) {
return extra_options_map.count(kXlaEnableExperimentalLlvmIrGemm) > 0;
}
+static tensorflow::StringPiece RemoveSuffix(tensorflow::StringPiece str,
+ tensorflow::StringPiece suffix) {
+ CHECK_GE(str.size(), suffix.size());
+ CHECK_EQ(str.substr(str.size() - suffix.size()), suffix);
+ return str.substr(0, str.size() - suffix.size());
+}
+
+tensorflow::gtl::optional<std::tuple<int64, int64, int64>> LlvmIrGemmTileSize(
+ const HloModuleConfig& config) {
+ const auto& extra_options_map =
+ config.debug_options().xla_backend_extra_options();
+ auto it = extra_options_map.find(kLlvmIrGemmTileSize);
+ if (it == extra_options_map.end()) {
+ return tensorflow::gtl::nullopt;
+ }
+
+ std::vector<string> tile_components =
+ tensorflow::str_util::Split(it->second, ':');
+ CHECK_EQ(tile_components.size(), 3);
+
+ int64 tile_size_m;
+ int64 tile_size_k;
+ int64 tile_size_n_in_vector_width;
+
+ CHECK(tensorflow::strings::safe_strto64(tile_components[0], &tile_size_m));
+ CHECK(tensorflow::strings::safe_strto64(tile_components[1], &tile_size_k));
+
+ tensorflow::StringPiece tile_size_n_in_vector_width_str =
+ RemoveSuffix(tile_components[2], "*vectwidth");
+
+ CHECK(tensorflow::strings::safe_strto64(tile_size_n_in_vector_width_str,
+ &tile_size_n_in_vector_width));
+
+ return std::tuple<int64, int64, int64>(tile_size_m, tile_size_k,
+ tile_size_n_in_vector_width);
+}
+
} // namespace options
} // namespace cpu
} // namespace xla
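
Note: the new xla_llvm_ir_gemm_tile_size backend option is parsed as "m:k:n*vectwidth", e.g. "11:9:1*vectwidth" for the default tile used below. A standalone C++ sketch of the same parsing, without the TensorFlow string helpers (the option name and format come from the code above; this parser is illustrative, not the one in the diff):

#include <cstdint>
#include <cstdio>
#include <optional>
#include <sstream>
#include <string>
#include <tuple>

// Parses "m:k:n*vectwidth" (e.g. "11:9:1*vectwidth") into (m, k, n).
std::optional<std::tuple<int64_t, int64_t, int64_t>> ParseGemmTileSize(
    const std::string& value) {
  std::istringstream in(value);
  std::string m_str, k_str, n_str;
  if (!std::getline(in, m_str, ':') || !std::getline(in, k_str, ':') ||
      !std::getline(in, n_str)) {
    return std::nullopt;
  }
  const std::string suffix = "*vectwidth";
  if (n_str.size() < suffix.size() ||
      n_str.compare(n_str.size() - suffix.size(), suffix.size(), suffix) != 0) {
    return std::nullopt;
  }
  n_str.resize(n_str.size() - suffix.size());  // drop the "*vectwidth" suffix
  try {
    return std::tuple<int64_t, int64_t, int64_t>(
        std::stoll(m_str), std::stoll(k_str), std::stoll(n_str));
  } catch (...) {
    return std::nullopt;
  }
}

int main() {
  auto tile = ParseGemmTileSize("11:9:1*vectwidth");
  if (tile) {
    std::printf("m=%lld k=%lld n=%lld (n is in units of the vector width)\n",
                (long long)std::get<0>(*tile), (long long)std::get<1>(*tile),
                (long long)std::get<2>(*tile));
  }
  return 0;
}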
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_options.h b/tensorflow/compiler/xla/service/cpu/cpu_options.h
index 106dfbbc62..429b9e16cb 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_options.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_options.h
@@ -29,6 +29,8 @@ bool VectorizedReduceDisabled(const HloModuleConfig& config);
bool EnableExperimentalLlvmIrGemm(const HloModuleConfig& config);
tensorflow::gtl::optional<int64> LlvmIrGemvTilingFactor(
const HloModuleConfig& config);
+tensorflow::gtl::optional<std::tuple<int64, int64, int64>> LlvmIrGemmTileSize(
+ const HloModuleConfig& config);
} // namespace options
} // namespace cpu
diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
index d77076546f..8eb39d615f 100644
--- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
@@ -324,11 +324,11 @@ void ColumnMajorMatrixVectorProductEmitter::Emit() {
int64 column_remainder = k() % tile_cols();
int64 column_limit = k() - column_remainder;
- ksl_.For("dot.outer.tiled",
- /*start=*/0, /*end=*/column_limit, /*step=*/tile_cols(),
- [&](llvm::Value* column, bool is_first_column) {
- EmitOuterLoopBody(column, tile_cols(), is_first_column);
- });
+ ksl_.ForReturnVoid("dot.outer.tiled",
+ /*start=*/0, /*end=*/column_limit, /*step=*/tile_cols(),
+ [&](llvm::Value* column, bool is_first_column) {
+ EmitOuterLoopBody(column, tile_cols(), is_first_column);
+ });
if (column_remainder != 0) {
EmitOuterLoopBody(ir_builder_->getInt64(column_limit), column_remainder,
@@ -341,19 +341,20 @@ void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopTiled(
int64 columns, bool is_first_column) {
int64 row_limit = m() - (m() % tile_rows());
- ksl_.For("dot.inner.tiled", /*start=*/0, /*end=*/row_limit,
- /*step=*/tile_rows(), [&](llvm::Value* row) {
- std::vector<llvm::Value*> lhs_tile =
- lhs_memory_tile->LoadTile(/*minor_dim_offset=*/row);
- llvm::Value* accumulator =
- is_first_column ? (addend_ ? vsl_.LoadVector(addend_, row)
- : vsl_.GetZeroVector())
- : vsl_.LoadVector(result_, row);
- for (int i = 0; i < columns; i++) {
- accumulator = vsl_.MulAdd(lhs_tile[i], rhs_tile[i], accumulator);
- }
- vsl_.StoreVector(accumulator, result_, row);
- });
+ ksl_.ForReturnVoid(
+ "dot.inner.tiled", /*start=*/0, /*end=*/row_limit,
+ /*step=*/tile_rows(), [&](llvm::Value* row) {
+ std::vector<llvm::Value*> lhs_tile =
+ lhs_memory_tile->LoadTile(/*minor_dim_offset=*/row);
+ llvm::Value* accumulator =
+ is_first_column ? (addend_ ? vsl_.LoadVector(addend_, row)
+ : vsl_.GetZeroVector())
+ : vsl_.LoadVector(result_, row);
+ for (int i = 0; i < columns; i++) {
+ accumulator = vsl_.MulAdd(lhs_tile[i], rhs_tile[i], accumulator);
+ }
+ vsl_.StoreVector(accumulator, result_, row);
+ });
}
void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue(
@@ -372,7 +373,7 @@ void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue(
// // initialized.
// }
- ksl_.For(
+ ksl_.ForReturnVoid(
"dot.inner.epilg.outer", /*start=*/current_tile_col,
/*end=*/ir_builder_->CreateAdd(columns_llvm, current_tile_col),
/*step=*/1, /*peel_first_iteration=*/false,
@@ -382,7 +383,7 @@ void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue(
ir_builder_->CreateMul(col, ir_builder_->getInt64(m()));
llvm::Value* lhs_base_pointer =
vsl_.ComputeOffsetPointer(lhs_, total_offset);
- ksl_.For(
+ ksl_.ForReturnVoid(
"dot.inner.epilg.inner", /*start=*/row_start, /*end=*/m(),
/*step=*/1, [&](llvm::Value* scalar_row) {
llvm::Value* product = vsl_.Mul(
@@ -390,7 +391,7 @@ void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue(
llvm::Value* setting_result_first_time = ir_builder_->CreateAnd(
is_first_scalar_col,
ir_builder_->getInt1(is_first_tiled_column));
- ksl_.If(
+ ksl_.IfReturnVoid(
setting_result_first_time,
/*true_block_generator=*/
[&]() {
@@ -571,9 +572,10 @@ void RowMajorMatrixVectorProductEmitter::Emit() {
int64 row_remainder = m() % tile_rows();
int64 row_limit = m() - row_remainder;
- ksl_.For("dot.outer.tiled",
- /*start=*/0, /*end=*/row_limit, /*step=*/tile_rows(),
- [&](llvm::Value* row) { EmitOuterLoopBody(row, tile_rows()); });
+ ksl_.ForReturnVoid(
+ "dot.outer.tiled",
+ /*start=*/0, /*end=*/row_limit, /*step=*/tile_rows(),
+ [&](llvm::Value* row) { EmitOuterLoopBody(row, tile_rows()); });
if (row_remainder != 0) {
EmitOuterLoopBody(ir_builder_->getInt64(row_limit), row_remainder);
@@ -585,17 +587,17 @@ void RowMajorMatrixVectorProductEmitter::EmitInnerLoopTiled(
std::vector<VectorVariable>* vector_accumulators) {
int64 column_limit = k() - (k() % tile_cols());
- ksl_.For("dot.inner.tiled", /*start=*/0, /*end=*/column_limit,
- /*step=*/tile_cols(), [&](llvm::Value* col) {
- std::vector<llvm::Value*> lhs_tile =
- lhs_memory_tile->LoadTile(/*minor_dim_offset=*/col);
- llvm::Value* rhs_value = vsl_.LoadVector(rhs_, col);
- for (int i = 0; i < rows; i++) {
- llvm::Value* old_sum = (*vector_accumulators)[i].Get();
- (*vector_accumulators)[i].Set(
- vsl_.Add(old_sum, vsl_.Mul(rhs_value, lhs_tile[i])));
- }
- });
+ ksl_.ForReturnVoid("dot.inner.tiled", /*start=*/0, /*end=*/column_limit,
+ /*step=*/tile_cols(), [&](llvm::Value* col) {
+ std::vector<llvm::Value*> lhs_tile =
+ lhs_memory_tile->LoadTile(/*minor_dim_offset=*/col);
+ llvm::Value* rhs_value = vsl_.LoadVector(rhs_, col);
+ for (int i = 0; i < rows; i++) {
+ llvm::Value* old_sum = (*vector_accumulators)[i].Get();
+ (*vector_accumulators)[i].Set(vsl_.Add(
+ old_sum, vsl_.Mul(rhs_value, lhs_tile[i])));
+ }
+ });
}
void RowMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue(
@@ -612,14 +614,15 @@ void RowMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue(
ir_builder_->getInt64(k()));
llvm::Value* lhs_base_pointer =
vsl_.ComputeOffsetPointer(lhs_, total_offset);
- ksl_.For("dot.inner.epilg.inner", /*start=*/column_start, /*end=*/k(),
- /*step=*/1, [&](llvm::Value* scalar_col) {
- llvm::Value* product =
- vsl_.Mul(vsl_.LoadScalar(lhs_base_pointer, scalar_col),
- vsl_.LoadScalar(rhs_, scalar_col));
- llvm::Value* old_value = (*scalar_accumulators)[r].Get();
- (*scalar_accumulators)[r].Set(vsl_.Add(old_value, product));
- });
+ ksl_.ForReturnVoid(
+ "dot.inner.epilg.inner", /*start=*/column_start, /*end=*/k(),
+ /*step=*/1, [&](llvm::Value* scalar_col) {
+ llvm::Value* product =
+ vsl_.Mul(vsl_.LoadScalar(lhs_base_pointer, scalar_col),
+ vsl_.LoadScalar(rhs_, scalar_col));
+ llvm::Value* old_value = (*scalar_accumulators)[r].Get();
+ (*scalar_accumulators)[r].Set(vsl_.Add(old_value, product));
+ });
}
}
@@ -665,6 +668,10 @@ class MatrixMatrixBlockPanelEmitter {
// the largest vector register we will use). This can be larger than the
// largest vector register supported by the machine -- LLVM will legalize
// these large vector widths into legally sized vectors.
+ //
+ // `max_vector_count` is the maximum number of vectors of size
+ // `max_vectorization_width` that we will attempt to process at once.
+ //
// `min_vectorization_width` is the smallest vector width the emitter will use
// -- below that it will devolve to using a scalar loop.
//
@@ -674,12 +681,13 @@ class MatrixMatrixBlockPanelEmitter {
class Config {
public:
explicit Config(PrimitiveType scalar_type, Dimensions dims,
- int64 max_vectorization_width,
+ int64 max_vectorization_width, int64 max_vector_count,
int64 min_vectorization_width, int64 tile_size_m,
int64 tile_size_k)
: scalar_type_(scalar_type),
dims_(dims),
max_vectorization_width_(max_vectorization_width),
+ max_vector_count_(max_vector_count),
min_vectorization_width_(min_vectorization_width),
tile_size_m_(tile_size_m),
tile_size_k_(tile_size_k) {}
@@ -694,6 +702,7 @@ class MatrixMatrixBlockPanelEmitter {
PrimitiveType scalar_type() const { return scalar_type_; }
Dimensions dims() const { return dims_; }
int64 max_vectorization_width() const { return max_vectorization_width_; }
+ int64 max_vector_count() const { return max_vector_count_; }
int64 min_vectorization_width() const { return min_vectorization_width_; }
int64 tile_size_m() const { return tile_size_m_; }
@@ -703,6 +712,7 @@ class MatrixMatrixBlockPanelEmitter {
PrimitiveType scalar_type_;
Dimensions dims_;
int64 max_vectorization_width_;
+ int64 max_vector_count_;
int64 min_vectorization_width_;
int64 tile_size_m_;
int64 tile_size_k_;
@@ -721,39 +731,35 @@ class MatrixMatrixBlockPanelEmitter {
ksl_(ir_builder_) {
CHECK(max_vectorization_width() > 0 &&
IsPowerOfTwo(static_cast<uint64>(max_vectorization_width())));
+ CHECK_GT(max_vector_count(), 0);
CHECK(min_vectorization_width() > 0 &&
IsPowerOfTwo(static_cast<uint64>(min_vectorization_width())));
+ CHECK_GE(max_vectorization_width(), min_vectorization_width());
CHECK_GT(tile_size_k(), 0);
}
void Emit();
private:
- // This emits a loop that loops over the `n` dimension in multiples of
- // `max_vectorization_width` as much as possible and then emits a remainder
- // epilogue.
- void EmitLoopOverN();
-
- // This emits a loop that loops over the `k` dimension in multiples of
- // `tile_size_k` as much as possible and then emits a remainder epilogue.
- void EmitLoopOverK(VectorSupportLibrary* vsl, llvm::Value* n_start,
- llvm::Value* n_end);
-
- // This emits a loop that loops over the `m` dimension in multiples of
- // `tile_size_m` as much as possible and then emits a remainder epilogue.
- void EmitLoopOverM(VectorSupportLibrary* vsl, int64 tile_size_k,
+ // The HandleResiduesOnX helpers split the iteration space for dimension X
+ // into a multiple of the tile size on dimension X and an epilogue. These
+ // helpers ultimately call into `EmitTiledGemm` for emitting the
+ // tiled GEMM kernel.
+
+ void HandleResiduesOnN();
+ void HandleResiduesOnK(VectorSupportLibrary* vsl, llvm::Value* n_start,
+ llvm::Value* n_end);
+ void HandleResiduesOnM(VectorSupportLibrary* vsl, int64 tile_size_k,
+ llvm::Value* k_start, llvm::Value* k_end,
+ llvm::Value* n_start, llvm::Value* n_end);
+
+ // This emits a tiled GEMM kernel. For a detailed description see the comment
+ // on the implementation.
+ void EmitTiledGemm(VectorSupportLibrary* vsl, int64 tile_size_k,
llvm::Value* k_start, llvm::Value* k_end,
- llvm::Value* n_start, llvm::Value* n_end);
-
- // This emits the inner reduction loop. This inner reduction loop multiplies
- // a tile from the LHS of size [tile_size_m,tile_size_k] and a tile from the
- // RHS of size [`tile_size_k`, vls->vector_width()] to update a tile of size
- // [`tile_size_m`, vls->vector_width()] in the result.
- void EmitTiledReductionLoop(VectorSupportLibrary* vsl, int64 tile_size_k,
- llvm::Value* k_start, llvm::Value* k_end,
- llvm::Value* n_start, llvm::Value* n_end,
- int64 tile_size_m, llvm::Value* m_start,
- llvm::Value* m_end);
+ llvm::Value* n_start, llvm::Value* n_end,
+ int64 tile_size_m, llvm::Value* m_start,
+ llvm::Value* m_end);
llvm::Value* GetInt64(int64 value) { return ir_builder_->getInt64(value); }
@@ -763,6 +769,7 @@ class MatrixMatrixBlockPanelEmitter {
int64 max_vectorization_width() const {
return config().max_vectorization_width();
}
+ int64 max_vector_count() const { return config().max_vector_count(); }
int64 min_vectorization_width() const {
return config().min_vectorization_width();
}
@@ -779,16 +786,19 @@ class MatrixMatrixBlockPanelEmitter {
KernelSupportLibrary ksl_;
};
-void MatrixMatrixBlockPanelEmitter::Emit() { EmitLoopOverN(); }
+void MatrixMatrixBlockPanelEmitter::Emit() { HandleResiduesOnN(); }
-void MatrixMatrixBlockPanelEmitter::EmitLoopOverN() {
+void MatrixMatrixBlockPanelEmitter::HandleResiduesOnN() {
// We can only iterate the `n` dimension for an extent that is divisible by
// the vectorization width. So we emit an outer loop that first processes the
// largest extent in `n` that is divisible by max_vectorization_width, then
// the largest remaining extent that is divisible by max_vectorization_width /
// 2 etc.
- int64 current_vectorization_width = max_vectorization_width();
+ int64 current_vectorization_width =
+ max_vector_count() * max_vectorization_width();
+ int64 current_vector_count = max_vector_count();
+
int64 n_start = 0;
while (n_start != dims().n() &&
current_vectorization_width >= min_vectorization_width()) {
@@ -796,53 +806,67 @@ void MatrixMatrixBlockPanelEmitter::EmitLoopOverN() {
if (n_start != n_end) {
VectorSupportLibrary vsl(scalar_type(), current_vectorization_width,
ir_builder_, "gebp");
- EmitLoopOverK(&vsl, GetInt64(n_start), GetInt64(n_end));
+ HandleResiduesOnK(&vsl, GetInt64(n_start), GetInt64(n_end));
n_start = n_end;
}
- current_vectorization_width /= 2;
+ if (current_vector_count == 1) {
+ current_vectorization_width /= 2;
+ } else {
+ current_vector_count--;
+ current_vectorization_width =
+ current_vector_count * max_vectorization_width();
+ }
}
if (n_start != dims().n()) {
VectorSupportLibrary vsl(scalar_type(), 1, ir_builder_, "gebp");
- ksl_.For("epi.n", n_start, dims().n(), 1, [&](llvm::Value* n_i) {
+ ksl_.ForReturnVoid("epi.n", n_start, dims().n(), 1, [&](llvm::Value* n_i) {
llvm::Value* n_i_next =
ir_builder_->CreateAdd(n_i, ir_builder_->getInt64(1));
- EmitLoopOverK(&vsl, n_i, n_i_next);
+ HandleResiduesOnK(&vsl, n_i, n_i_next);
});
}
}
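
Note: HandleResiduesOnN peels the n dimension into progressively narrower vectorized chunks: first multiples of max_vector_count * max_vectorization_width, then fewer vectors per iteration, then halved vector widths, and finally a scalar epilogue. A standalone C++ sketch of how the iteration space gets partitioned (the widths and counts are illustrative; the control flow mirrors the loop above):

#include <cstdio>

int main() {
  const long n = 23;                        // extent of the `n` dimension
  const long max_vectorization_width = 8;   // e.g. 8 floats per vector register
  const long min_vectorization_width = 4;
  const long max_vector_count = 2;          // vectors processed per iteration

  long current_vector_count = max_vector_count;
  long current_width = max_vector_count * max_vectorization_width;
  long n_start = 0;
  while (n_start != n && current_width >= min_vectorization_width) {
    long n_end = n - (n % current_width);   // largest extent divisible by width
    if (n_start != n_end) {
      std::printf("[%ld, %ld) handled with %ld elements per iteration\n",
                  n_start, n_end, current_width);
      n_start = n_end;
    }
    if (current_vector_count == 1) {
      current_width /= 2;
    } else {
      --current_vector_count;
      current_width = current_vector_count * max_vectorization_width;
    }
  }
  if (n_start != n) {
    std::printf("[%ld, %ld) handled by the scalar epilogue\n", n_start, n);
  }
  return 0;
}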
-void MatrixMatrixBlockPanelEmitter::EmitLoopOverK(VectorSupportLibrary* vsl,
- llvm::Value* n_start,
- llvm::Value* n_end) {
+void MatrixMatrixBlockPanelEmitter::HandleResiduesOnK(VectorSupportLibrary* vsl,
+ llvm::Value* n_start,
+ llvm::Value* n_end) {
int64 k_start = 0;
int64 k_end = dims().k() - (dims().k() % tile_size_k());
if (k_end != k_start) {
- EmitLoopOverM(vsl, tile_size_k(), GetInt64(k_start), GetInt64(k_end),
- n_start, n_end);
+ HandleResiduesOnM(vsl, tile_size_k(), GetInt64(k_start), GetInt64(k_end),
+ n_start, n_end);
k_start = k_end;
}
if (k_start != dims().k()) {
- EmitLoopOverM(vsl, dims().k() - k_start, GetInt64(k_start),
- GetInt64(dims().k()), n_start, n_end);
+ HandleResiduesOnM(vsl, dims().k() - k_start, GetInt64(k_start),
+ GetInt64(dims().k()), n_start, n_end);
}
}
-void MatrixMatrixBlockPanelEmitter::EmitLoopOverM(
+void MatrixMatrixBlockPanelEmitter::HandleResiduesOnM(
VectorSupportLibrary* vsl, int64 tile_size_k, llvm::Value* k_start,
llvm::Value* k_end, llvm::Value* n_start, llvm::Value* n_end) {
const int64 m_end = dims().m() - dims().m() % tile_size_m();
- EmitTiledReductionLoop(vsl, tile_size_k, k_start, k_end, n_start, n_end,
- tile_size_m(), GetInt64(0), GetInt64(m_end));
+ EmitTiledGemm(vsl, tile_size_k, k_start, k_end, n_start, n_end, tile_size_m(),
+ GetInt64(0), GetInt64(m_end));
if (m_end != dims().m()) {
- EmitTiledReductionLoop(vsl, tile_size_k, k_start, k_end, n_start, n_end,
- dims().m() - m_end, GetInt64(m_end),
- GetInt64(dims().m()));
+ EmitTiledGemm(vsl, tile_size_k, k_start, k_end, n_start, n_end,
+ dims().m() - m_end, GetInt64(m_end), GetInt64(dims().m()));
}
}
+// The loop structure is:
+//
+// Iterate over dimension M as m:
+//   Iterate over dimension N as n:
+//     Iterate over dimension K as k:
+//       OutputTile[m,n] += Dot(LhsTile[m,k], RhsTile[k,n])
+//
+// I.e., just a tiled version of a "naive" GEMM.
+//
// The tiling scheme is as follows:
//
// Let the LHS be:
@@ -904,41 +928,48 @@ void MatrixMatrixBlockPanelEmitter::EmitLoopOverM(
// +-------------------+-------------------+-------------------+---------
// | a0*p0+b0*q0+c0*r0 | a0*p1+b0*q1+c0*r1 | a0*p2+b0*q2+c0*r2 | ...
// +-------------------+-------------------+-------------------+---------
-void MatrixMatrixBlockPanelEmitter::EmitTiledReductionLoop(
+void MatrixMatrixBlockPanelEmitter::EmitTiledGemm(
VectorSupportLibrary* vsl, int64 tile_size_k, llvm::Value* k_start,
llvm::Value* k_end, llvm::Value* n_start, llvm::Value* n_end,
int64 tile_size_m, llvm::Value* m_start, llvm::Value* m_end) {
- ksl_.For("dot.m", m_start, m_end, tile_size_m, [&](llvm::Value* m_i) {
- MemoryTile result_memory_tile(vsl, ir_builder_, /*matrix=*/result_,
- /*matrix_size_along_minor_dim=*/dims().n(),
- /*major_dim_offset=*/m_i,
- /*tile_size_along_major_dim=*/tile_size_m);
- MemoryTile lhs_memory_tile(vsl, ir_builder_, /*matrix=*/lhs_,
- /*matrix_size_along_minor_dim=*/dims().k(),
- /*major_dim_offset=*/m_i,
- /*tile_size_along_major_dim=*/tile_size_m);
-
- ksl_.For("dot.k", k_start, k_end, tile_size_k, [&](llvm::Value* k_i) {
- MemoryTile rhs_memory_tile(vsl, ir_builder_, rhs_, dims().n(), k_i,
- tile_size_k);
- std::vector<std::vector<llvm::Value*>> lhs_tile =
- lhs_memory_tile.LoadBroadcastTile(k_i, tile_size_k);
- ksl_.For(
- "dot.n", n_start, n_end, vsl->vector_size(), [&](llvm::Value* n_i) {
- std::vector<llvm::Value*> rhs_tile = rhs_memory_tile.LoadTile(n_i);
- std::vector<llvm::Value*> result_tile =
- result_memory_tile.LoadTile(n_i);
- for (int64 r_m_i = 0; r_m_i < tile_size_m; r_m_i++) {
- for (int64 r_k_i = 0; r_k_i < tile_size_k; r_k_i++) {
- result_tile[r_m_i] =
- vsl->MulAdd(lhs_tile[r_m_i][r_k_i], rhs_tile[r_k_i],
- result_tile[r_m_i]);
- }
- }
- result_memory_tile.StoreTile(result_tile, n_i);
- });
- });
- });
+ ksl_.ForReturnVoid(
+ "dot.m", m_start, m_end, tile_size_m, [&](llvm::Value* m_i) {
+ MemoryTile result_memory_tile(
+ vsl, ir_builder_, /*matrix=*/result_,
+ /*matrix_size_along_minor_dim=*/dims().n(),
+ /*major_dim_offset=*/m_i,
+ /*tile_size_along_major_dim=*/tile_size_m);
+ MemoryTile lhs_memory_tile(vsl, ir_builder_, /*matrix=*/lhs_,
+ /*matrix_size_along_minor_dim=*/dims().k(),
+ /*major_dim_offset=*/m_i,
+ /*tile_size_along_major_dim=*/tile_size_m);
+ ksl_.ForReturnVoid(
+ "dot.n", n_start, n_end, vsl->vector_size(), [&](llvm::Value* n_i) {
+ TileVariable result_tile_var(vsl,
+ result_memory_tile.LoadTile(n_i));
+ ksl_.ForReturnVoid(
+ "dot.k", k_start, k_end, tile_size_k, [&](llvm::Value* k_i) {
+ MemoryTile rhs_memory_tile(vsl, ir_builder_, rhs_,
+ dims().n(), k_i, tile_size_k);
+ std::vector<std::vector<llvm::Value*>> lhs_tile =
+ lhs_memory_tile.LoadBroadcastTile(k_i, tile_size_k);
+ std::vector<llvm::Value*> rhs_tile =
+ rhs_memory_tile.LoadTile(n_i);
+ std::vector<llvm::Value*> result_tile =
+ result_tile_var.Get();
+ for (int64 r_m_i = 0; r_m_i < tile_size_m; r_m_i++) {
+ for (int64 r_k_i = 0; r_k_i < tile_size_k; r_k_i++) {
+ result_tile[r_m_i] =
+ vsl->MulAdd(lhs_tile[r_m_i][r_k_i], rhs_tile[r_k_i],
+ result_tile[r_m_i]);
+ }
+ }
+ result_tile_var.Set(result_tile);
+ });
+
+ result_memory_tile.StoreTile(result_tile_var.Get(), n_i);
+ });
+ });
}
} // namespace
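
Note: the restructured kernel now keeps the result tile in a TileVariable for the whole k loop, loading and storing the memory tile once per (m, n) tile instead of once per k step. A standalone C++ sketch of the scalar equivalent of the loop nest described in the comment above (plain arrays and illustrative tile sizes; the real emitter works on vectors of width vsl->vector_size()):

#include <cstdio>
#include <vector>

int main() {
  const int m = 4, k = 6, n = 8;
  const int tile_m = 2, tile_k = 3, tile_n = 4;  // illustrative tile sizes
  std::vector<float> lhs(m * k, 1.0f), rhs(k * n, 2.0f), out(m * n, 0.0f);

  // dot.m / dot.n / dot.k: a tiled version of the naive GEMM, with the
  // output tile held in a local accumulator across the whole k loop.
  for (int mi = 0; mi < m; mi += tile_m) {
    for (int ni = 0; ni < n; ni += tile_n) {
      std::vector<float> acc(tile_m * tile_n, 0.0f);  // plays the TileVariable role
      for (int ki = 0; ki < k; ki += tile_k) {
        for (int rm = 0; rm < tile_m; ++rm) {
          for (int rk = 0; rk < tile_k; ++rk) {
            for (int rn = 0; rn < tile_n; ++rn) {
              acc[rm * tile_n + rn] +=
                  lhs[(mi + rm) * k + (ki + rk)] * rhs[(ki + rk) * n + (ni + rn)];
            }
          }
        }
      }
      for (int rm = 0; rm < tile_m; ++rm) {  // store the finished tile once
        for (int rn = 0; rn < tile_n; ++rn) {
          out[(mi + rm) * n + (ni + rn)] = acc[rm * tile_n + rn];
        }
      }
    }
  }
  std::printf("out[0] = %f (expected %d)\n", out[0], 2 * k);
  return 0;
}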
@@ -1023,16 +1054,21 @@ bool DotOpEmitter::EmitExperimentalGebpDotIfEnabled(
target, ir_builder_->getInt8(0), size_bytes,
target_machine_features_.minimum_alignment_for_allocation(size_bytes));
- int64 max_vector_width =
+ int64 max_target_vector_width =
target_machine_features_.vector_register_num_elements(
*ir_builder_->GetInsertBlock()->getParent(), primitive_type);
+ int64 tile_size_m, tile_size_k, tile_size_n_in_vector_width;
+ std::tie(tile_size_m, tile_size_k, tile_size_n_in_vector_width) =
+ GetGemmTileSize();
+
MatrixMatrixBlockPanelEmitter::Config config(
/*scalar_type=*/primitive_type,
MatrixMatrixBlockPanelEmitter::Dimensions{/*m=*/m, /*k=*/k, /*n=*/n},
- /*max_vectorization_width=*/max_vector_width,
- /*min_vectorization_width=*/std::min<int64>(4, max_vector_width),
- /*tile_size_m=*/3, /*tile_size_k=*/5);
+ /*max_vectorization_width=*/max_target_vector_width,
+ /*max_vector_count=*/tile_size_n_in_vector_width,
+ /*min_vectorization_width=*/std::min<int64>(4, max_target_vector_width),
+ /*tile_size_m=*/tile_size_m, /*tile_size_k=*/tile_size_k);
VLOG(2) << "Emitting GEBP kernel in LLVM IR with config "
<< config.GetCacheKey();
@@ -1265,8 +1301,11 @@ Status DotOpEmitter::Emit() {
// from messing up the vectorization.
std::unique_ptr<llvm_ir::ForLoop> reduction_loop = loop_nest.AddLoop(
0, lhs_shape.dimensions(lhs_reduction_dimension), "reduction",
- /*prevent_unrolling=*/lhs_reduction_along_minor_dimension &&
- rhs_reduction_along_minor_dimension);
+ /*unroll_mode=*/
+ (lhs_reduction_along_minor_dimension &&
+ rhs_reduction_along_minor_dimension)
+ ? xla::llvm_ir::UnrollMode::kNoUnroll
+ : xla::llvm_ir::UnrollMode::kDefaultUnroll);
// The final entry in the rhs and lhs indexes is the indvar of the
// reduction loop.
diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h
index d88ccea0db..ed2a18976a 100644
--- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h
@@ -143,6 +143,17 @@ class DotOpEmitter {
.value_or(kDefaultTilingFactor);
}
+ std::tuple<int64, int64, int64> GetGemmTileSize() const {
+ // Tuned for broadwell - Intel(R) Xeon(R) CPU E5-2690 v4 @ 2.60GHz
+ //
+ // TODO(b/80093688): Tune for other architectures and centralize this
+ // information in one place.
+ const std::tuple<int64, int64, int64> kDefaultTileSize =
+ std::tuple<int64, int64, int64>(11, 9, 1);
+ return options::LlvmIrGemmTileSize(hlo_module_config_)
+ .value_or(kDefaultTileSize);
+ }
+
// Returns true if we should use an experimental implementation of GEMM
// (general matrix matrix multiplication) if possible.
bool EnableExperimentalLlvmIrGemm() const {
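
Note: the (11, 9, 1) default maps onto the GEBP Config as tile_size_m = 11, tile_size_k = 9, and max_vector_count = 1, and can be overridden per module through the xla_llvm_ir_gemm_tile_size option parsed earlier in this change. A small standalone sketch of the fallback behavior (mirrors the value_or pattern above; the variable names here are illustrative):

#include <cstdio>
#include <optional>
#include <tuple>

int main() {
  // Use the parsed override when present, otherwise the Broadwell-tuned
  // default of m=11, k=9, one vector of width max_vectorization_width.
  std::optional<std::tuple<long, long, long>> parsed_override;  // unset here
  auto tile = parsed_override.value_or(std::make_tuple(11L, 9L, 1L));
  std::printf("tile_size_m=%ld tile_size_k=%ld max_vector_count=%ld\n",
              std::get<0>(tile), std::get<1>(tile), std::get<2>(tile));
  return 0;
}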
diff --git a/tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.cc b/tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.cc
index 92da5f71c2..f8c8dd5e93 100644
--- a/tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.cc
+++ b/tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifdef INTEL_MKL
+#if defined(INTEL_MKL) && !defined(DO_NOT_USE_ML)
#include "tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h"
#include "third_party/intel_mkl_ml/include/mkl_cblas.h"
#include "third_party/intel_mkl_ml/include/mkl_service.h"
diff --git a/tensorflow/compiler/xla/service/cpu/vector_support_library.cc b/tensorflow/compiler/xla/service/cpu/vector_support_library.cc
index cd1165e238..c444d15185 100644
--- a/tensorflow/compiler/xla/service/cpu/vector_support_library.cc
+++ b/tensorflow/compiler/xla/service/cpu/vector_support_library.cc
@@ -427,5 +427,27 @@ llvm::Value* LlvmVariable::Get() const {
void LlvmVariable::Set(llvm::Value* new_value) {
ir_builder_->CreateStore(new_value, alloca_);
}
+
+TileVariable::TileVariable(VectorSupportLibrary* vector_support,
+ std::vector<llvm::Value*> initial_value) {
+ for (llvm::Value* initial_vector_value : initial_value) {
+ storage_.emplace_back(vector_support, initial_vector_value);
+ }
+}
+
+std::vector<llvm::Value*> TileVariable::Get() const {
+ std::vector<llvm::Value*> result;
+ c_transform(storage_, std::back_inserter(result),
+ [&](VectorVariable vect_var) { return vect_var.Get(); });
+ return result;
+}
+
+void TileVariable::Set(tensorflow::gtl::ArraySlice<llvm::Value*> value) {
+ CHECK_EQ(value.size(), storage_.size());
+ for (int64 i = 0, e = value.size(); i < e; i++) {
+ storage_[i].Set(value[i]);
+ }
+}
+
} // namespace cpu
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/vector_support_library.h b/tensorflow/compiler/xla/service/cpu/vector_support_library.h
index edcaec5849..49c2a4e2f4 100644
--- a/tensorflow/compiler/xla/service/cpu/vector_support_library.h
+++ b/tensorflow/compiler/xla/service/cpu/vector_support_library.h
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
namespace xla {
namespace cpu {
@@ -317,6 +318,21 @@ class ScalarVariable : public LlvmVariable {
Set(initial_value);
}
};
+
+// This wraps a set of alloca-backed stack variables that can, as a whole, store
+// a tile. A "tile" is a sequence of vectors that is typically used as a 2D
+// grid of scalar values (e.g. for tiled GEMMs).
+class TileVariable {
+ public:
+ TileVariable(VectorSupportLibrary* vector_support,
+ std::vector<llvm::Value*> initial_value);
+
+ std::vector<llvm::Value*> Get() const;
+ void Set(tensorflow::gtl::ArraySlice<llvm::Value*> value);
+
+ private:
+ std::vector<VectorVariable> storage_;
+};
} // namespace cpu
} // namespace xla
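
Note: TileVariable bundles one alloca-backed VectorVariable per vector in the tile, so the emitted code can treat the whole tile as a single mutable value inside a loop while standard LLVM passes can typically promote the allocas to registers. A standalone host-side C++ analog of the Get/Set pattern (purely illustrative; the real class stores llvm::Value*s, not floats):

#include <cassert>
#include <cstdio>
#include <vector>

// Host-side analog of TileVariable: a fixed-size bundle of mutable slots.
class TileVar {
 public:
  explicit TileVar(std::vector<float> initial) : storage_(std::move(initial)) {}
  std::vector<float> Get() const { return storage_; }
  void Set(const std::vector<float>& value) {
    assert(value.size() == storage_.size());
    storage_ = value;
  }
 private:
  std::vector<float> storage_;
};

int main() {
  TileVar acc(std::vector<float>(4, 0.0f));  // result tile, initially zero
  for (int k = 0; k < 3; ++k) {              // the "dot.k" loop
    std::vector<float> tile = acc.Get();
    for (float& v : tile) v += 1.0f;         // accumulate one k step
    acc.Set(tile);
  }
  std::printf("acc[0] = %f\n", acc.Get()[0]);  // 3.0
  return 0;
}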
diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
index 64678d9d74..ee2b455730 100644
--- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
+++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
@@ -243,6 +243,8 @@ class DfsHloVisitorBase {
virtual Status HandleBatchNormGrad(HloInstructionPtr hlo) = 0;
+ virtual Status HandleGenerateToken(HloInstructionPtr token) = 0;
+
// Invoked to inform the visitor that the traversal has completed, and that
// the root was "root".
virtual Status FinishVisit(HloInstructionPtr root) = 0;
diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h
index 240faebe62..6934e00a4b 100644
--- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h
+++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h
@@ -188,6 +188,9 @@ class DfsHloVisitorWithDefaultBase
Status HandleGather(HloInstructionPtr gather) override {
return DefaultAction(gather);
}
+ Status HandleGenerateToken(HloInstructionPtr token) override {
+ return DefaultAction(token);
+ }
// Invoked to inform the visitor that the traversal has completed, and that
// the root was "root".
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
index 9a8bab353e..93fea7ead7 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
@@ -456,17 +456,15 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp(
llvm::ConstantFP::get(type, 1.0)));
}
case HloOpcode::kIsFinite: {
- // (x == x) && abs(x) != inf
+      // abs(x) o!= inf; this works because the ordered comparison returns
+      // false if either operand is NaN.
auto type = operand_value->getType();
- auto equal_self =
- ir_builder_->CreateFCmpOEQ(operand_value, operand_value);
auto abs_value = llvm_ir::EmitCallToIntrinsic(
llvm::Intrinsic::fabs, {operand_value}, {type}, ir_builder_);
auto infinity = llvm::ConstantFP::getInfinity(type);
auto not_infinite = ir_builder_->CreateFCmpONE(abs_value, infinity);
- auto result_i1 = ir_builder_->CreateAnd(equal_self, not_infinite);
return ir_builder_->CreateZExt(
- result_i1, llvm_ir::PrimitiveTypeToIrType(PRED, module_));
+ not_infinite, llvm_ir::PrimitiveTypeToIrType(PRED, module_));
}
case HloOpcode::kNegate:
return ir_builder_->CreateFNeg(operand_value);
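
Note: the kIsFinite simplification drops the explicit x == x self-comparison because FCmpONE is an ordered comparison: it is false whenever either operand is NaN, so abs(x) ONE inf is already false for NaN inputs as well as for +/-inf. A standalone C++ check of the equivalence, emulating LLVM's ordered not-equal on the host (illustrative only):

#include <cmath>
#include <cstdio>
#include <limits>

// LLVM's fcmp one (ordered, not-equal): false if either operand is NaN.
bool OrderedNotEqual(double a, double b) {
  return !std::isnan(a) && !std::isnan(b) && a != b;
}

int main() {
  const double inf = std::numeric_limits<double>::infinity();
  const double vals[] = {0.0, -3.5, inf, -inf, std::nan("")};
  for (double x : vals) {
    bool old_way = (x == x) && std::fabs(x) != inf;     // (x==x) && abs(x)!=inf
    bool new_way = OrderedNotEqual(std::fabs(x), inf);  // abs(x) one inf
    std::printf("x=%g old=%d new=%d isfinite=%d\n", x, old_way, new_way,
                (int)std::isfinite(x));
  }
  return 0;
}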
diff --git a/tensorflow/compiler/xla/service/executable.h b/tensorflow/compiler/xla/service/executable.h
index 087bd14329..dc1f26ea65 100644
--- a/tensorflow/compiler/xla/service/executable.h
+++ b/tensorflow/compiler/xla/service/executable.h
@@ -28,7 +28,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/service_executable_run_options.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
-#include "tensorflow/compiler/xla/service/versioned_computation_handle.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -131,12 +130,6 @@ class Executable {
const HloModuleConfig& module_config() const { return hlo_module_->config(); }
- // Returns the versioned computation handle of the computation computed by
- // this executable.
- const VersionedComputationHandle& entry_computation_handle() const {
- return hlo_module_->entry_computation_handle();
- }
-
// The shape (including layout) that results from this execution. This is the
// shape of the DeviceMemoryBase result value in ExecuteOnStream above.
const Shape& host_result_shape() const {
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index 6bd9d4c31d..5e02631a58 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -164,6 +164,7 @@ cc_library(
"//tensorflow/compiler/xla/service:name_uniquer",
"//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter",
"//tensorflow/compiler/xla/service/llvm_ir:ir_array",
+ "//tensorflow/compiler/xla/service/llvm_ir:kernel_support_library",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_loop",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
"//tensorflow/compiler/xla/service/llvm_ir:loop_emitter",
@@ -423,6 +424,34 @@ tf_cc_test(
)
cc_library(
+ name = "multi_output_fusion",
+ srcs = ["multi_output_fusion.cc"],
+ hdrs = ["multi_output_fusion.h"],
+ deps = [
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:multi_output_fusion",
+ "//tensorflow/core:lib",
+ ],
+)
+
+tf_cc_test(
+ name = "multi_output_fusion_test",
+ srcs = ["multi_output_fusion_test.cc"],
+ deps = [
+ ":multi_output_fusion",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:hlo_matchers",
+ "//tensorflow/compiler/xla/service:hlo_parser",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
name = "gpu_copy_insertion",
srcs = ["gpu_copy_insertion.cc"],
hdrs = ["gpu_copy_insertion.h"],
@@ -522,6 +551,7 @@ cc_library(
":instruction_fusion",
":ir_emission_utils",
":ir_emitter",
+ ":multi_output_fusion",
":pad_insertion",
":partition_assignment",
":stream_assignment",
diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
index e5e2a0478a..b812dd7d3f 100644
--- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
@@ -53,11 +53,17 @@ using llvm_ir::IrName;
using llvm_ir::SetToFirstInsertPoint;
using tensorflow::strings::StrAppend;
+namespace {
// Returns whether operand is a floating-point literal with the given value.
bool IsFPLiteralWithValue(const HloInstruction* operand, float value) {
- return operand->opcode() == HloOpcode::kConstant &&
- operand->literal().IsAllFloat(value);
+ if (operand->opcode() == HloOpcode::kConstant &&
+ operand->literal().IsAllFloat(value)) {
+ return true;
+ }
+ return operand->opcode() == HloOpcode::kBroadcast &&
+ IsFPLiteralWithValue(operand->operand(0), value);
}
+} // namespace
GpuElementalIrEmitter::GpuElementalIrEmitter(
const HloModuleConfig& hlo_module_config, llvm::Module* module,
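
Note: IsFPLiteralWithValue now also looks through broadcasts, so an operand like broadcast(2.0) matches the same way a scalar constant 2.0 does when this emitter pattern-matches operands against specific constants. A small standalone C++ sketch of the same "peel broadcasts, then test the literal" recursion over a toy expression node (the types and helpers here are illustrative, not XLA's):

#include <cstdio>

// Toy expression node: either a scalar constant or a broadcast of one operand.
struct Node {
  enum Kind { kConstant, kBroadcast } kind;
  float value;           // used when kind == kConstant
  const Node* operand;   // used when kind == kBroadcast
};

// Mirrors the new IsFPLiteralWithValue: accept a constant with the given
// value, or a (possibly nested) broadcast of such a constant.
bool IsFPLiteralWithValue(const Node* n, float value) {
  if (n->kind == Node::kConstant && n->value == value) return true;
  return n->kind == Node::kBroadcast && IsFPLiteralWithValue(n->operand, value);
}

int main() {
  Node two{Node::kConstant, 2.0f, nullptr};
  Node bcast{Node::kBroadcast, 0.0f, &two};
  std::printf("constant 2.0: %d\n", IsFPLiteralWithValue(&two, 2.0f));          // 1
  std::printf("broadcast(2.0): %d\n", IsFPLiteralWithValue(&bcast, 2.0f));      // 1
  std::printf("broadcast(2.0) vs 3.0: %d\n", IsFPLiteralWithValue(&bcast, 3.0f));  // 0
  return 0;
}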
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
index b857219807..afefc740d7 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
@@ -52,6 +52,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h"
#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h"
+#include "tensorflow/compiler/xla/service/gpu/multi_output_fusion.h"
#include "tensorflow/compiler/xla/service/gpu/pad_insertion.h"
#include "tensorflow/compiler/xla/service/gpu/partition_assignment.h"
#include "tensorflow/compiler/xla/service/gpu/stream_assignment.h"
@@ -159,13 +160,10 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
if (hlo_module->config().debug_options().xla_gpu_use_cudnn_batchnorm()) {
pass.AddPass<CudnnBatchNormRewriter>();
}
- // TODO(kramerb): Remove use_fusion once instruction fusion can create
- // multi-output fusions from the unfused expander output.
pass.AddPass<BatchNormExpander>(
/*rewrite_training_op=*/true,
/*rewrite_inference_op=*/true,
- /*rewrite_grad_op=*/true,
- /*use_fusion=*/true);
+ /*rewrite_grad_op=*/true);
// Rewrite gather ops into smaller ones.
pass.AddPass<GatherExpander>();
@@ -261,6 +259,7 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
fusion.AddPass<GpuInstructionFusion>(/*may_duplicate=*/false);
fusion.AddPass<GpuInstructionFusion>(/*may_duplicate=*/true);
fusion.AddPass<FusionMerger>();
+ fusion.AddPass<GpuMultiOutputFusion>();
TF_RETURN_IF_ERROR(fusion.Run(hlo_module).status());
HloPassPipeline reduce_pipeline("reduce-precision");
diff --git a/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc b/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc
index e230d538cc..45f0a1c645 100644
--- a/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc
@@ -47,8 +47,7 @@ class HloScheduleTest : public HloTestBase {
auto debug_options = GetDebugOptionsForTest();
debug_options.set_xla_gpu_disable_multi_streaming(false);
config.set_debug_options(debug_options);
- return MakeUnique<HloModule>("test_module", VersionedComputationHandle(),
- config);
+ return MakeUnique<HloModule>("test_module", config);
}
HloVec RemoveHlo(const HloVec& input,
diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc
index 36a1b82a26..6c4519185b 100644
--- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc
@@ -77,15 +77,14 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer,
HloInstruction* producer = consumer->mutable_operand(operand_index);
// Check if we can use output fusion for (A @ B) * alpha
- if (consumer->operand_count() == 2 &&
- (producer->opcode() == HloOpcode::kDot ||
- (producer->opcode() == HloOpcode::kFusion &&
- producer->fused_expression_root()->opcode() == HloOpcode::kDot))) {
+ if (producer->opcode() == HloOpcode::kDot ||
+ (producer->opcode() == HloOpcode::kFusion &&
+ producer->fused_expression_root()->opcode() == HloOpcode::kDot)) {
int64 other_operand_index = 1 - operand_index;
- const HloInstruction* alpha = consumer->operand(other_operand_index);
HloInstruction* op1 = nullptr;
HloInstruction* op2 = nullptr;
- if (consumer->opcode() == HloOpcode::kFusion &&
+ if (consumer->operand_count() == 1 &&
+ consumer->opcode() == HloOpcode::kFusion &&
consumer->fusion_kind() == HloInstruction::FusionKind::kLoop &&
Match(consumer->fused_expression_root(),
match::Op()
@@ -103,10 +102,12 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer,
op2->opcode() != HloOpcode::kBroadcast) {
return false;
}
- if (IsIEEEFloatingPointScalarConstant(alpha)) {
+ if (IsIEEEFloatingPointScalarConstant(op2->operand(0))) {
return true;
}
- } else if (consumer->opcode() == HloOpcode::kMultiply) {
+ } else if (consumer->operand_count() == 2 &&
+ consumer->opcode() == HloOpcode::kMultiply) {
+ const HloInstruction* alpha = consumer->operand(other_operand_index);
// Fuse if 'alpha' is a broadcast of a scalar constant.
if (alpha->opcode() == HloOpcode::kBroadcast &&
alpha->dimensions().empty() &&
@@ -173,6 +174,14 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer,
return false;
}
+ // Fuse scalar constants into loop fusion nodes; this reduces the number of
+ // parameters and makes matching scalar broadcasts easier.
+ if (ShapeUtil::IsEffectiveScalar(producer->shape()) &&
+ consumer->opcode() == HloOpcode::kFusion &&
+ producer->opcode() == HloOpcode::kConstant) {
+ return true;
+ }
+
return IsFusile(*producer) && IsFusile(*consumer) &&
InstructionFusion::ShouldFuse(consumer, operand_index);
}
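The new early-exit in ShouldFuse fuses effective-scalar constants into fusion consumers, which removes a parameter from the fused computation and lets later passes match broadcasts of constants directly. A rough standalone restatement of that predicate over toy types (not the real HloInstruction interface):

#include <iostream>

// Toy opcode set and instruction; illustrative only.
enum class Opcode { kConstant, kFusion, kAdd };

struct ToyInstr {
  Opcode opcode;
  bool is_effective_scalar;  // Rank-0 or all dimensions of size 1.
};

// Mirrors the added check: always fuse a scalar constant producer into a
// fusion consumer.
bool ShouldFuseScalarConstant(const ToyInstr& producer,
                              const ToyInstr& consumer) {
  return producer.opcode == Opcode::kConstant &&
         producer.is_effective_scalar &&
         consumer.opcode == Opcode::kFusion;
}

int main() {
  ToyInstr scalar_const{Opcode::kConstant, true};
  ToyInstr fusion{Opcode::kFusion, false};
  ToyInstr add{Opcode::kAdd, false};
  std::cout << ShouldFuseScalarConstant(scalar_const, fusion) << "\n";  // 1
  std::cout << ShouldFuseScalarConstant(scalar_const, add) << "\n";     // 0
  return 0;
}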
diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc
index 426b1d235c..1963d9eef7 100644
--- a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc
@@ -168,7 +168,7 @@ TEST_F(InstructionFusionTest, BroadcastIntoReduce) {
HloInstruction* root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, op::Fusion());
EXPECT_THAT(root->fused_expression_root(),
- op::Reduce(op::Broadcast(op::Parameter()), op::Parameter()));
+ op::Reduce(op::Broadcast(op::Constant()), op::Constant()));
}
TEST_F(InstructionFusionTest, BitcastIntoAdd) {
@@ -255,7 +255,7 @@ TEST_F(InstructionFusionTest, DotOutputFusion) {
EXPECT_THAT(
root->fused_expression_root(),
op::Multiply(op::Dot(op::Parameter(), op::Transpose(op::Parameter())),
- op::Broadcast(op::Parameter())));
+ op::Broadcast(op::Constant())));
}
// Compute sum(1/p0), where p0 has type f32, twice. Check that the division is
@@ -339,7 +339,7 @@ TEST_F(InstructionFusionTest, DotOutputFusionImpossible) {
EXPECT_EQ(root->fusion_kind(), HloInstruction::FusionKind::kLoop);
EXPECT_THAT(root->fused_expression_root(),
op::Multiply(op::Multiply(op::Parameter(), op::Parameter()),
- op::Broadcast(op::Parameter())));
+ op::Broadcast(op::Constant())));
}
// Counts the HLO ops with a given op code in the specified module.
@@ -581,5 +581,30 @@ TEST_F(InstructionFusionTest, FuseIntoInputFusionInstruction) {
<< module->ToString();
}
+TEST_F(InstructionFusionTest, FuseScalarConstant) {
+ auto module = ParseHloString(R"(
+ HloModule test_module
+
+ ENTRY FuseScalarConstant {
+ p0 = f32[] parameter(0)
+ c0 = f32[] constant(1)
+ add1 = f32[] add(p0, c0)
+ b0 = f32[2]{0} broadcast(add1), dimensions={}
+ c1 = f32[2]{0} constant({1, 2})
+ ROOT add2 = f32[2]{0} add(b0, c1)
+ })")
+ .ValueOrDie();
+
+ EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true)
+ .Run(module.get())
+ .ValueOrDie());
+
+ HloInstruction* root = module->entry_computation()->root_instruction();
+ EXPECT_THAT(root, op::Fusion());
+ EXPECT_THAT(root->fused_expression_root(),
+ op::Add(op::Broadcast(op::Add(op::Parameter(), op::Constant())),
+ op::Parameter()));
+}
+
} // namespace gpu
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index b40b557cab..726434c3df 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -59,6 +59,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ops.h"
#include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h"
@@ -501,20 +502,27 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
case HloOpcode::kReduce: {
VLOG(3) << "Emitting fused reduction to vector: " << fusion->ToString();
std::vector<std::unique_ptr<Thunk>> thunks;
- ArraySlice<HloInstruction*> reduces =
+ ArraySlice<HloInstruction*> output_instructions =
root->opcode() == HloOpcode::kTuple
? root->operands()
: ArraySlice<HloInstruction*>(&root, 1);
// For multi-output fusion emit an initializer for each tuple element.
// Otherwise it's sufficient to just initialize the single output.
- for (int i = 0, e = reduces.size(); i != e; ++i) {
- TF_ASSIGN_OR_RETURN(
- std::unique_ptr<Thunk> initializer_thunk,
- BuildInitializerThunk(
- fusion, reduces[i] == root ? ShapeIndex() : ShapeIndex({i})));
- thunks.push_back(std::move(initializer_thunk));
+ HloInstruction* first_reduce = nullptr;
+ for (int i = 0, e = output_instructions.size(); i != e; ++i) {
+ if (output_instructions[i]->opcode() == HloOpcode::kReduce) {
+ TF_ASSIGN_OR_RETURN(
+ std::unique_ptr<Thunk> initializer_thunk,
+ BuildInitializerThunk(fusion, output_instructions[i] == root
+ ? ShapeIndex()
+ : ShapeIndex({i})));
+ thunks.push_back(std::move(initializer_thunk));
+ first_reduce =
+ first_reduce == nullptr ? output_instructions[i] : first_reduce;
+ }
}
+ CHECK(first_reduce != nullptr);
thunks.push_back(BuildKernelThunk(fusion));
thunk_sequence_->emplace_back(
MakeUnique<SequentialThunk>(std::move(thunks), fusion));
@@ -533,29 +541,45 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
// fusion is a special case of that.
InlinedVector<llvm_ir::ElementGenerator, 1> input_gens;
InlinedVector<llvm_ir::ElementGenerator, 1> init_value_gens;
+ std::vector<std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
+ extra_output_gens;
InlinedVector<HloComputation*, 1> reducers;
- for (const HloInstruction* reduce : reduces) {
- CHECK_EQ(HloOpcode::kReduce, reduce->opcode());
+ InlinedVector<ShapeIndex, 1> reduce_output_shapes;
+ for (int i = 0, e = output_instructions.size(); i != e; ++i) {
+ const HloInstruction* inst = output_instructions[i];
+ ShapeIndex output_shape_index;
+ if (root->opcode() == HloOpcode::kTuple) {
+ output_shape_index = {i};
+ }
// TODO(kramerb): CHECK that layouts are equal. Currently this
// breaks multioutputfusion_test. The test has pre-fused
// instructions, but layout_assignment will not assign any layouts
// for instructions inside of a fused computation. It just removes
// the layouts instead.
- CHECK(ShapeUtil::Compatible(reduces[0]->shape(), reduce->shape()));
- CHECK(ShapeUtil::Compatible(reduces[0]->operand(0)->shape(),
- reduce->operand(0)->shape()));
- CHECK(ShapeUtil::Compatible(reduces[0]->operand(1)->shape(),
- reduce->operand(1)->shape()));
- CHECK(reduces[0]->dimensions() == reduce->dimensions());
- input_gens.push_back(fused_emitter.GetGenerator(reduce->operand(0)));
- init_value_gens.push_back(
- fused_emitter.GetGenerator(reduce->operand(1)));
- reducers.push_back(reduce->to_apply());
+ if (inst->opcode() == HloOpcode::kReduce) {
+ CHECK(ShapeUtil::Compatible(first_reduce->shape(), inst->shape()));
+ CHECK(ShapeUtil::Compatible(first_reduce->operand(0)->shape(),
+ inst->operand(0)->shape()));
+ CHECK(ShapeUtil::Compatible(first_reduce->operand(1)->shape(),
+ inst->operand(1)->shape()));
+ CHECK(first_reduce->dimensions() == inst->dimensions());
+ input_gens.push_back(fused_emitter.GetGenerator(inst->operand(0)));
+ init_value_gens.push_back(
+ fused_emitter.GetGenerator(inst->operand(1)));
+ reducers.push_back(inst->to_apply());
+ reduce_output_shapes.push_back(std::move(output_shape_index));
+ } else {
+ CHECK(ShapeUtil::Compatible(first_reduce->operand(0)->shape(),
+ inst->shape()));
+ extra_output_gens.emplace_back(fused_emitter.GetGenerator(inst),
+ std::move(output_shape_index));
+ }
}
- const Shape& input_shape = reduces[0]->operand(0)->shape();
- return EmitReductionToVector(reduces[0], input_shape, input_gens,
- init_value_gens, reduces[0]->dimensions(),
- reducers);
+ const Shape& input_shape = first_reduce->operand(0)->shape();
+ return EmitReductionToVector(first_reduce, input_shape, input_gens,
+ init_value_gens,
+ first_reduce->dimensions(), reducers,
+ reduce_output_shapes, extra_output_gens);
}
default:
LOG(FATAL) << "Bad opcode for input fusion: "
@@ -940,11 +964,33 @@ Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) {
return IrEmitter::HandleCopy(copy);
}
+Status IrEmitterUnnested::EmitExtraOutputsForReduce(
+ const HloInstruction* reduce, const llvm_ir::IrArray::Index& index,
+ tensorflow::gtl::ArraySlice<
+ std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
+ extra_output_gens) {
+ for (int i = 0; i != extra_output_gens.size(); ++i) {
+ const HloInstruction* output = reduce->parent()->FusionInstruction();
+ llvm::Value* extra_output_address =
+ GetIrArray(*output, *output, extra_output_gens[i].second)
+ .EmitArrayElementAddress(index, &ir_builder_,
+ "extra_output_element_address");
+ TF_ASSIGN_OR_RETURN(llvm::Value* const extra_output_ir_value,
+ extra_output_gens[i].first(index));
+ ir_builder_.CreateStore(extra_output_ir_value, extra_output_address);
+ }
+ return Status::OK();
+}
+
Status IrEmitterUnnested::EmitReductionToScalar(
HloInstruction* reduce, const Shape& input_shape,
tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> input_gens,
tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> init_value_gens,
- tensorflow::gtl::ArraySlice<HloComputation*> reducers) {
+ tensorflow::gtl::ArraySlice<HloComputation*> reducers,
+ tensorflow::gtl::ArraySlice<ShapeIndex> reduce_output_shapes,
+ tensorflow::gtl::ArraySlice<
+ std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
+ extra_output_gens) {
// Number of elements processed by a single thread.
constexpr int64 kTileSize = 16;
int64 num_elems = ShapeUtil::ElementsIn(input_shape);
@@ -1050,7 +1096,7 @@ Status IrEmitterUnnested::EmitReductionToScalar(
{partial_reduction_result_addresses[i], input_address},
partial_reduction_result_addresses[i]));
}
- return Status::OK();
+ return EmitExtraOutputsForReduce(reduce, input_index, extra_output_gens);
};
// x_end = kTileSize + x_in_tiles * kTileSize, i.e., the location that's
@@ -1120,17 +1166,13 @@ Status IrEmitterUnnested::EmitReductionToScalar(
&ir_builder_);
for (int i = 0; i != num_reduces; ++i) {
- ShapeIndex output_shape_index;
- if (output->IsMultiOutputFusion()) {
- output_shape_index = {i};
- }
llvm::Value* output_address =
- GetIrArray(*output, *output, output_shape_index)
+ GetIrArray(*output, *output, reduce_output_shapes[i])
.EmitArrayElementAddress(
llvm_ir::IrArray::Index(
/*linear=*/ir_builder_.getInt64(0),
ShapeUtil::GetSubshape(output->shape(),
- output_shape_index),
+ reduce_output_shapes[i]),
&ir_builder_),
&ir_builder_, "output_element_address");
TF_RETURN_IF_ERROR(EmitAtomicOperationForNestedComputation(
@@ -1158,7 +1200,11 @@ Status IrEmitterUnnested::EmitColumnReduction(
int64 height, int64 width, HloInstruction* reduce, const Shape& input_shape,
tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> input_gens,
tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> init_value_gens,
- tensorflow::gtl::ArraySlice<HloComputation*> reducers) {
+ tensorflow::gtl::ArraySlice<HloComputation*> reducers,
+ tensorflow::gtl::ArraySlice<ShapeIndex> reduce_output_shapes,
+ tensorflow::gtl::ArraySlice<
+ std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
+ extra_output_gens) {
// Divide the input matrix into tiles of size Kx1. For example, when the
// input matrix is 4x4 and K=2, the tiled matrix looks like
//
@@ -1284,7 +1330,8 @@ Status IrEmitterUnnested::EmitColumnReduction(
{partial_reduction_result_addresses[i], input_address},
partial_reduction_result_addresses[i]));
}
- return Status::OK();
+ return EmitExtraOutputsForReduce(reduce, input_index,
+ extra_output_gens);
}
};
@@ -1315,17 +1362,13 @@ Status IrEmitterUnnested::EmitColumnReduction(
const HloInstruction* output =
reduce->IsFused() ? reduce->parent()->FusionInstruction() : reduce;
for (int i = 0; i != num_reduces; ++i) {
- ShapeIndex output_shape_index;
- if (output->IsMultiOutputFusion()) {
- output_shape_index = {i};
- }
llvm::Value* output_address =
- GetIrArray(*output, *output, output_shape_index)
+ GetIrArray(*output, *output, reduce_output_shapes[i])
.EmitArrayElementAddress(
llvm_ir::IrArray::Index(
x,
ShapeUtil::GetSubshape(output->shape(),
- output_shape_index),
+ reduce_output_shapes[i]),
&ir_builder_),
&ir_builder_, "output_element_address");
TF_RETURN_IF_ERROR(EmitAtomicOperationForNestedComputation(
@@ -1349,14 +1392,42 @@ Status IrEmitterUnnested::EmitColumnReduction(
.EmitLoop(IrName(reduce));
}
+static std::pair<int64, int64> ComputeTilingSchemeForReduction(
+ int64 depth, int64 width, int64 kWarpSize) {
+ constexpr int64 kTargetNumElementsPerThread = 64;
+ int64 x_tile_size = kTargetNumElementsPerThread;
+ int64 z_tile_size = 1;
+
+ // Only tile along the x dimension with tile size kTargetNumElementsPerThread
+ // if doing so doesn't require a slow version of the loop with a bounds check
+ // on each dimension. A more sophisticated heuristic would be to tile along
+ // the x dimension with tile size kTargetNumElementsPerThread when either
+ // width is a multiple of (kWarpSize * kTargetNumElementsPerThread) or width
+ // is big enough that only a small fraction of the threads execute the slow
+ // version of the loop with the bounds check.
+ if (width % (kWarpSize * kTargetNumElementsPerThread) != 0) {
+ x_tile_size = 8;
+ z_tile_size = 8;
+ while (depth % z_tile_size != 0) {
+ z_tile_size -= 1;
+ }
+ }
+
+ return std::pair<int64, int64>(x_tile_size, z_tile_size);
+}
+
Status IrEmitterUnnested::EmitRowReduction(
int64 depth, int64 height, int64 width, HloInstruction* reduce,
const Shape& input_shape,
tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> input_gens,
tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> init_value_gens,
- tensorflow::gtl::ArraySlice<HloComputation*> reducers) {
+ tensorflow::gtl::ArraySlice<HloComputation*> reducers,
+ tensorflow::gtl::ArraySlice<ShapeIndex> reduce_output_shapes,
+ tensorflow::gtl::ArraySlice<
+ std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
+ extra_output_gens) {
// A naive algorithm is:
- // 1. Divide the input tensor into tiles of size 1x1xK.
+ // 1. Divide the x dimension of the input tensor into tiles of size 1x1xX.
// 2. Partially reduces each tile to a scalar using one thread.
// 3. Accumulates that scalar to the output vector using atomic operations.
//
@@ -1367,15 +1438,15 @@ Status IrEmitterUnnested::EmitRowReduction(
// int y = linear_index / width_in_tiles % height;
// int z = linear_index / (height * width_in_tiles);
// float partial_result = 0;
- // for (element_id_in_tile : range(kTileSize)) {
- // int x = x_in_tiles * kTileSize + element_id_in_tile;
+ // for (element_id_in_tile : range(x_tile_size)) {
+ // int x = x_in_tiles * x_tile_size + element_id_in_tile;
// if (x < width)
//       partial_result = reducer(partial_result, input[z][y][x]);
// }
// AtomicReducer(&output[y], partial_result);
// }
//
- // Three optimizations are performed.
+ // Four optimizations are performed.
//
// 1. To coalesce global memory accesses, dilate the tile with a factor of 32
// (i.e. the warp size). For example, suppose the width is 8x32=256. Instead
@@ -1402,29 +1473,44 @@ Status IrEmitterUnnested::EmitRowReduction(
// element_id_in_tile, which makes the code more friendly to optimizations
// such as LICM.
//
+ // 4. When the width is too small for x_tile_size to reach the target number
+ //    of elements per thread, use a small factor of depth as z_tile_size to
+ //    increase the number of elements calculated by each partial sum. This
+ //    reduces the number of dynamic shfl_down and atomic operations needed.
+ //
// for (linear_index = threadIdx.x + blockIdx.x * blockDim.x;
// linear_index < depth * height * width_in_tiles;
// linear_index += blockDim.x * gridDim.x) {
// int x_in_tiles = linear_index % width_in_tiles;
// int y = linear_index / width_in_tiles % height;
- // int z = linear_index / (height * width_in_tiles);
+ // int z_in_tiles = linear_index / (height * width_in_tiles);
// int warp_id = x_in_tiles / warpSize;
// int lane_id = x_in_tiles % warpSize;
// float partial_result = 0;
// int x = warp_id * kTileSize * warpSize + lane_id;
- // if (width % (kTileSize * warpSize) == 0 ||
- // x + (kTileSize - 1) * warpSize < width) {
- // // The entire tile is in bounds.
- // for (int element_id_in_tile = 0; element_id_in_tile < kTileSize;
- // ++element_id_in_tile, x += warpSize) {
- // partial_result = Reducer(partial_result, input[z][y][x]);
+ // if (width % (x_tile_size * warpSize) == 0 ||
+ // x + (x_tile_size - 1) * warpSize < width) {
+ // // The entire x_tile is in bounds.
+ // for (int element_id_in_z_tile = 0; element_id_in_z_tile < z_tile_size;
+ // ++element_id_in_z_tile) {
+ // z = z_in_tiles * z_tile_size + element_id_in_z_tile;
+ // for (int element_id_in_x_tile = 0;element_id_in_x_tile < x_tile_size;
+ // ++element_id_in_x_tile, x += warpSize) {
+ // partial_result = Reducer(partial_result, input[z][y][x]);
+ // }
// }
// } else {
// // The tile is partially in bounds.
- // for (int element_id_in_tile = 0; element_id_in_tile < kTileSize;
+ // for (int element_id_in_z_tile = 0; element_id_in_z_tile < z_tile_size;
+ // ++element_id_in_z_tile) {
+ // z = z_in_tiles * z_tile_size + element_id_in_z_tile;
+ // for (int element_id_in_x_tile = 0; element_id_in_x_tile <
+ // x_tile_size;
// ++element_id_in_tile, x += warpSize) {
- // if (x < width)
- // partial_result = Reducer(partial_result, input[z][y][x]);
+ // if (x < width)
+ // partial_result = Reducer(partial_result, input[z][y][x]);
+ // }
// }
// }
// for (shuffle_distance = 16; shuffle_distance > 0; shuffle_distance /= 2)
@@ -1435,17 +1521,20 @@ Status IrEmitterUnnested::EmitRowReduction(
// AtomicReducer(&output[y], partial_result);
// }
//
- // Choose 8 as the tile size, which matches Eigen's RowReduceKernel.
- constexpr int64 kTileSize = 8;
+
+ int64 x_tile_size;
+ int64 z_tile_size;
+ std::tie(x_tile_size, z_tile_size) =
+ ComputeTilingSchemeForReduction(depth, width, kWarpSize);
+
// Round the width in tiles up to the nearest multiple of kWarpSize, so that
// the use of shfl_down is valid.
const int64 width_in_tiles =
- RoundUpToNearest(CeilOfRatio(width, kTileSize), kWarpSize);
+ RoundUpToNearest(CeilOfRatio(width, x_tile_size), kWarpSize);
- auto loop_body_emitter =
- [=](const llvm_ir::IrArray::Index& tile_index) -> Status {
+ auto loop_body_emitter = [=](const llvm_ir::IrArray::Index& tile_index) {
+ // Emit the loop body that reduces one z-x-tile.
const int num_reduces = reducers.size();
- // Emit the loop body that reduces one tile.
llvm::Type* element_ir_type = llvm_ir::PrimitiveTypeToIrType(
input_shape.element_type(), ir_emitter_context_->llvm_module());
std::vector<llvm::Value*> partial_reduction_result_addresses;
@@ -1460,9 +1549,7 @@ Status IrEmitterUnnested::EmitRowReduction(
partial_reduction_result_address);
}
- // Emit an inner for-loop that partially reduces the elements in the given
- // tile.
- llvm::Value* z = tile_index[0];
+ llvm::Value* z_tile = tile_index[0];
llvm::Value* y = tile_index[1];
llvm::Value* x_tile = tile_index[2];
llvm::Value* warp_id = ir_builder_.CreateUDiv(
@@ -1470,106 +1557,132 @@ Status IrEmitterUnnested::EmitRowReduction(
llvm::Value* lane_id = ir_builder_.CreateURem(
x_tile, ir_builder_.getInt64(kWarpSize), "lane_id");
- // The x-location of the last element in this tile.
- // last_x = lane_id + warpSize * (kTileSize - 1 + warp_id * kTileSize);
+ // The x-location of the last element in this z-x-tile.
+ // last_x = lane_id + warpSize * (x_tile_size - 1 + warp_id *
+ // x_tile_size);
llvm::Value* last_x = ir_builder_.CreateNSWAdd(
- lane_id,
- ir_builder_.CreateNSWMul(
- ir_builder_.getInt64(kWarpSize),
- ir_builder_.CreateNSWAdd(
- ir_builder_.getInt64(kTileSize - 1),
- ir_builder_.CreateNSWMul(warp_id,
- ir_builder_.getInt64(kTileSize)))));
-
- auto emit_tile_element_loop = [=](bool tile_in_bounds) -> Status {
- std::unique_ptr<llvm_ir::ForLoop> tile_element_loop =
- llvm_ir::ForLoop::EmitForLoop("element_id_in_tile",
- ir_builder_.getInt64(0),
- ir_builder_.getInt64(kTileSize),
- ir_builder_.getInt64(1), &ir_builder_);
-
- // Emit the body of the partial reduction loop.
- llvm_ir::SetToFirstInsertPoint(tile_element_loop->GetBodyBasicBlock(),
- &ir_builder_);
- // x = lane_id + warpSize * (element_id_in_tile + warp_id * kTileSize);
- llvm::Value* x = ir_builder_.CreateNSWAdd(
- lane_id,
- ir_builder_.CreateNSWMul(
- ir_builder_.getInt64(kWarpSize),
- ir_builder_.CreateNSWAdd(
- tile_element_loop->GetIndVarValue(),
- ir_builder_.CreateNSWMul(warp_id,
- ir_builder_.getInt64(kTileSize)))));
-
- // Unless we know the tile is entirely in bounds, we have to emit a
- // x-in-bounds check before reading from the input.
- if (!tile_in_bounds) {
- llvm_ir::LlvmIfData if_x_in_bounds_data = llvm_ir::EmitIfThenElse(
- ir_builder_.CreateICmpULT(x, ir_builder_.getInt64(width)),
- "x_in_bounds", &ir_builder_);
-
- // Points ir_builder_ to the then-block.
- llvm_ir::SetToFirstInsertPoint(if_x_in_bounds_data.true_block,
- &ir_builder_);
- }
-
- // Emit code that reads the input element and accumulates it to the
- // partial reduction result.
- llvm::Value* input_address = ir_builder_.CreateAlloca(element_ir_type);
- {
- // {z,y,x} is an index to input_3d_tensor_shape [depth,height,width]. We
- // need to convert that to an index to input_shape (the shape of the
- // operand of "reduce"). This conversion is composed of a transposition
- // from input_shape to normalized_input_shape and a reshape from
- // normalized_input_shape to input_3d_tensor_shape.
- const Shape normalized_input_shape =
- ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
- input_shape);
- auto input_shape_min2maj = LayoutUtil::MinorToMajor(input_shape);
- const std::vector<int64> transpose_dimension_mapping(
- input_shape_min2maj.rbegin(), input_shape_min2maj.rend());
- const Shape input_3d_tensor_shape =
- ShapeUtil::MakeShapeWithDescendingLayout(input_shape.element_type(),
- {depth, height, width});
- const llvm_ir::IrArray::Index input_3d_tensor_index(
- {z, y, x}, input_3d_tensor_shape, &ir_builder_);
- const llvm_ir::IrArray::Index input_index =
- input_3d_tensor_index
- .SourceIndexOfReshape(input_3d_tensor_shape,
- normalized_input_shape, &ir_builder_)
- .SourceIndexOfTranspose(normalized_input_shape, input_shape,
- transpose_dimension_mapping,
- &ir_builder_);
- for (int i = 0; i != num_reduces; ++i) {
- TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value,
- input_gens[i](input_index));
- ir_builder_.CreateStore(input_ir_value, input_address);
- TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
- *reducers[i],
- {partial_reduction_result_addresses[i], input_address},
- partial_reduction_result_addresses[i]));
- }
+ lane_id, ir_builder_.CreateNSWMul(
+ ir_builder_.getInt64(kWarpSize),
+ ir_builder_.CreateNSWAdd(
+ ir_builder_.getInt64(x_tile_size - 1),
+ ir_builder_.CreateNSWMul(
+ warp_id, ir_builder_.getInt64(x_tile_size)))));
+
+ KernelSupportLibrary ksl(
+ &ir_builder_,
+ /*unroll_mode=*/xla::llvm_ir::UnrollMode::kFullyUnroll,
+ /*prevent_vectorization=*/false);
+
+ // Emit a for-loop that partially reduces the elements in the given
+ // z-x-tile.
+ auto emit_z_x_tile_element_loop = [&](bool x_tile_in_bounds,
+ int64 x_tile_loop_bound) -> Status {
+ auto emit_z_tile_element_loop = [&](llvm::Value* z_indvar) -> Status {
+ llvm::Value* z = ir_builder_.CreateNSWAdd(
+ z_indvar, ir_builder_.CreateNSWMul(
+ ir_builder_.getInt64(z_tile_size), z_tile));
+
+ TF_RETURN_IF_ERROR(ksl.For(
+ "x_tile",
+ /*start=*/0, /*end=*/x_tile_loop_bound, /*step=*/1,
+ [&](llvm::Value* x_indvar) -> Status {
+ // x = lane_id + warpSize * (element_id_in_x_tile + warp_id *
+ // x_tile_size);
+ llvm::Value* x = ir_builder_.CreateNSWAdd(
+ lane_id,
+ ir_builder_.CreateNSWMul(
+ ir_builder_.getInt64(kWarpSize),
+ ir_builder_.CreateNSWAdd(
+ x_indvar,
+ ir_builder_.CreateNSWMul(
+ warp_id, ir_builder_.getInt64(x_tile_size)))));
+
+ // Unless we know the x-tile is entirely in bounds, we have to
+ // emit a x-in-bounds check before reading from the input.
+ if (!x_tile_in_bounds) {
+ llvm_ir::LlvmIfData if_x_in_bounds_data =
+ llvm_ir::EmitIfThenElse(ir_builder_.CreateICmpULT(
+ x, ir_builder_.getInt64(width)),
+ "x_in_bounds", &ir_builder_);
+ // Points ir_builder_ to the then-block.
+ llvm_ir::SetToFirstInsertPoint(if_x_in_bounds_data.true_block,
+ &ir_builder_);
+ }
+
+ // Emit code that reads the input element and accumulates it
+ // to the partial reduction result.
+ llvm::Value* input_address =
+ ir_builder_.CreateAlloca(element_ir_type);
+ {
+ // {z,y,x} is an index to input_3d_tensor_shape
+ // [depth,height,width]. We need to convert that to an index
+ // to input_shape (the shape of the operand of "reduce").
+ // This conversion is composed of a transposition from
+ // input_shape to normalized_input_shape and a reshape from
+ // normalized_input_shape to input_3d_tensor_shape.
+ const Shape normalized_input_shape = ShapeUtil::
+ MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
+ input_shape);
+ auto input_shape_min2maj =
+ LayoutUtil::MinorToMajor(input_shape);
+ const std::vector<int64> transpose_dimension_mapping(
+ input_shape_min2maj.rbegin(), input_shape_min2maj.rend());
+ const Shape input_3d_tensor_shape =
+ ShapeUtil::MakeShapeWithDescendingLayout(
+ input_shape.element_type(), {depth, height, width});
+ const llvm_ir::IrArray::Index input_3d_tensor_index(
+ {z, y, x}, input_3d_tensor_shape, &ir_builder_);
+ const llvm_ir::IrArray::Index input_index =
+ input_3d_tensor_index
+ .SourceIndexOfReshape(input_3d_tensor_shape,
+ normalized_input_shape,
+ &ir_builder_)
+ .SourceIndexOfTranspose(
+ normalized_input_shape, input_shape,
+ transpose_dimension_mapping, &ir_builder_);
+
+ for (int i = 0; i != num_reduces; ++i) {
+ TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value,
+ input_gens[i](input_index));
+ ir_builder_.CreateStore(input_ir_value, input_address);
+ TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
+ *reducers[i],
+ {partial_reduction_result_addresses[i], input_address},
+ partial_reduction_result_addresses[i]));
+ }
+ return EmitExtraOutputsForReduce(reduce, input_index,
+ extra_output_gens);
+ }
+ }));
return Status::OK();
- }
+ };
+
+ return ksl.For("z_tile",
+ /*start=*/0, /*end=*/z_tile_size, /*step=*/1,
+ emit_z_tile_element_loop);
};
llvm::Value* tile_in_bounds = ir_builder_.CreateOr(
- ir_builder_.getInt1(width % (kTileSize * kWarpSize) == 0),
+ ir_builder_.getInt1(width % (x_tile_size * kWarpSize) == 0),
ir_builder_.CreateICmpULT(last_x, ir_builder_.getInt64(width)));
- llvm_ir::LlvmIfData if_tile_in_bounds_data =
- llvm_ir::EmitIfThenElse(tile_in_bounds, "tile_in_bounds", &ir_builder_);
- llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.true_block,
- &ir_builder_);
- TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_bounds=*/true));
- llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.false_block,
- &ir_builder_);
- TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_bounds=*/false));
- // After the if-then-else statement on tile_in_bounds, emit calls to
- // shfl_down that accumulate the partial reduction results of all threads
- // from the warp.
- llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.after_block,
- &ir_builder_);
+ TF_RETURN_IF_ERROR(
+ ksl.If(tile_in_bounds,
+ /*true_block_generator=*/
+ [&]() -> Status {
+ return emit_z_x_tile_element_loop(/*x_tile_in_bounds=*/true,
+ x_tile_size);
+ },
+ /*false_block_generator=*/
+ [&]() -> Status {
+ return emit_z_x_tile_element_loop(
+ /*x_tile_in_bounds=*/false,
+ CeilOfRatio(width % (x_tile_size * kWarpSize), kWarpSize));
+ }));
+
+ // After accumulating the elements of the z_x_tile, emit calls to
+ // shfl_down that accumulate the partial reduction results of all
+ // threads in a warp.
int bit_width = llvm_ir::GetSizeInBits(element_ir_type);
// bitcast cannot be applied to aggregate types (even packed ones), so we
// instead bitcast addresses of load/store to intN* of the same bit-width.
@@ -1610,29 +1723,33 @@ Status IrEmitterUnnested::EmitRowReduction(
llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block,
&ir_builder_);
for (int i = 0; i != num_reduces; ++i) {
- ShapeIndex output_shape_index;
- if (output->IsMultiOutputFusion()) {
- output_shape_index = {i};
- }
llvm::Value* output_address =
- GetIrArray(*output, *output, output_shape_index)
+ GetIrArray(*output, *output, reduce_output_shapes[i])
.EmitArrayElementAddress(
llvm_ir::IrArray::Index(
y,
ShapeUtil::GetSubshape(output->shape(),
- output_shape_index),
+ reduce_output_shapes[i]),
&ir_builder_),
&ir_builder_, "output_element_address");
- TF_RETURN_IF_ERROR(EmitAtomicOperationForNestedComputation(
- *reducers[i], output_address, partial_reduction_result_addresses[i]));
+ if (x_tile_size * z_tile_size < depth * width) {
+ TF_RETURN_IF_ERROR(EmitAtomicOperationForNestedComputation(
+ *reducers[i], output_address,
+ partial_reduction_result_addresses[i]));
+ } else {
+ TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
+ *reducers[i],
+ {output_address, partial_reduction_result_addresses[i]},
+ output_address));
+ }
}
return Status::OK();
};
// Emit a parallel loop that iterates through every input tiles.
Shape tiled_input_shape = ShapeUtil::MakeShapeWithLayout(
- reduce->shape().element_type(), {depth, height, width_in_tiles},
- {2, 1, 0});
+ reduce->shape().element_type(),
+ {depth / z_tile_size, height, width_in_tiles}, {2, 1, 0});
LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
tiled_input_shape, ir_emitter_context_->device_description());
CHECK(LastThunk()->kind() == Thunk::Kind::kSequential);
@@ -1656,7 +1773,11 @@ Status IrEmitterUnnested::EmitReductionToVector(
tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> input_gens,
tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> init_value_gens,
tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
- tensorflow::gtl::ArraySlice<HloComputation*> reducers) {
+ tensorflow::gtl::ArraySlice<HloComputation*> reducers,
+ tensorflow::gtl::ArraySlice<ShapeIndex> reduce_output_shapes,
+ tensorflow::gtl::ArraySlice<
+ std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
+ extra_output_gens) {
// This emission requires "reduce" to have an input layout. It is either set
// by LayoutAssignment (for a top-level kReduce) or by InstructionFusion (for
// a fused kReduce).
@@ -1692,7 +1813,8 @@ Status IrEmitterUnnested::EmitReductionToVector(
// dimension of the input is to keep.
if (input_dims_to_keep.empty()) {
return EmitReductionToScalar(reduce, input_shape, input_gens,
- init_value_gens, reducers);
+ init_value_gens, reducers,
+ reduce_output_shapes, extra_output_gens);
} else if (input_dims_to_keep.front() ==
LayoutUtil::Minor(input_shape.layout(), 0)) {
// Column reduction. Treat the result of "input" as a matrix whose width
@@ -1710,7 +1832,8 @@ Status IrEmitterUnnested::EmitReductionToVector(
}
}
return EmitColumnReduction(height, width, reduce, input_shape, input_gens,
- init_value_gens, reducers);
+ init_value_gens, reducers, reduce_output_shapes,
+ extra_output_gens);
} else {
// Reduce the row dimension of a matrix or reduce dimension 0 and 2 in a
// 3D tensor. The size of dimension 1 (the height) is the size of the
@@ -1736,7 +1859,8 @@ Status IrEmitterUnnested::EmitReductionToVector(
}
const int64 height = ShapeUtil::ElementsIn(reduce->shape());
return EmitRowReduction(depth, height, width, reduce, input_shape,
- input_gens, init_value_gens, reducers);
+ input_gens, init_value_gens, reducers,
+ reduce_output_shapes, extra_output_gens);
}
}
@@ -1768,7 +1892,7 @@ Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) {
return GetIrArray(*init_value, *reduce)
.EmitReadArrayElement(index, &ir_builder_);
}},
- dimensions_to_reduce, {reducer});
+ dimensions_to_reduce, {reducer}, {{}}, {});
}
thunk_sequence_->emplace_back(BuildKernelThunk(reduce));
@@ -2390,7 +2514,9 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildGemmThunk(
if (alpha->opcode() == HloOpcode::kBroadcast) {
alpha = alpha->operand(0);
}
- alpha = inst->operand(alpha->parameter_number());
+ if (alpha->opcode() == HloOpcode::kParameter) {
+ alpha = inst->operand(alpha->parameter_number());
+ }
// TODO(b/74185543): Remove the following if block once we support fusion
// with a non-constant as well. Then we will just always use the constant
// on the device.
@@ -2436,7 +2562,7 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
const HloInstruction* hlo, const ShapeIndex& index) {
bool fused = HloOpcode::kFusion == hlo->opcode();
const HloInstruction* inst = fused ? hlo->fused_expression_root() : hlo;
- const HloInstruction* init_value = [&] {
+ const HloInstruction* init_value_operand = [&] {
switch (inst->opcode()) {
case HloOpcode::kSelectAndScatter:
return inst->operand(2);
@@ -2456,6 +2582,7 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
}
}();
+ const HloInstruction* init_value = init_value_operand;
if (fused && init_value->opcode() == HloOpcode::kParameter) {
init_value = hlo->operand(init_value->parameter_number());
}
@@ -2507,13 +2634,24 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
// Otherwise fall back to our slow initializer code.
std::unique_ptr<KernelThunk> kernel_thunk = BuildKernelThunk(hlo);
- TF_RETURN_IF_ERROR(EmitTargetElementLoopInThunk(
- *hlo,
- [=](const llvm_ir::IrArray::Index& index) {
- return GetIrArray(*init_value, *hlo)
- .EmitReadArrayElement(index, &ir_builder_);
- },
- kernel_thunk.get()));
+ LaunchDimensions launch_dimensions =
+ CalculateLaunchDimensions(ShapeUtil::GetSubshape(hlo->shape(), index),
+ ir_emitter_context_->device_description());
+ UpdateLaunchDimensions(launch_dimensions, kernel_thunk.get(),
+ ir_emitter_context_->llvm_module());
+ // If the init_value was fused into this reduce we have to generate it first.
+ if (fused && init_value_operand->opcode() != HloOpcode::kParameter) {
+ CHECK_EQ(HloOpcode::kConstant, init_value_operand->opcode());
+ TF_RETURN_IF_ERROR(HandleConstant(const_cast<HloInstruction*>(init_value)));
+ }
+ TF_RETURN_IF_ERROR(ParallelLoopEmitter(
+ [=](const llvm_ir::IrArray::Index& index) {
+ return GetIrArray(*init_value, *hlo)
+ .EmitReadArrayElement(index, &ir_builder_);
+ },
+ GetIrArray(*hlo, *hlo, index), launch_dimensions,
+ &ir_builder_)
+ .EmitLoop(IrName(hlo)));
// Clean up state left behind by emitting the loop above. (This is normally
// done in IrEmitterUnnested::Postprocess().)
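The ComputeTilingSchemeForReduction helper added above chooses a 64-element x-tile when the width divides evenly into warp-sized tiles, and otherwise falls back to an 8-wide x-tile with a z-tile set to the largest divisor of depth that is at most 8. A self-contained sketch of that selection in plain C++ (no XLA types; int64_t stands in for int64, and the names are illustrative only):

#include <cstdint>
#include <iostream>
#include <utility>

// Rough restatement of the tiling heuristic above.
// Returns {x_tile_size, z_tile_size}.
std::pair<int64_t, int64_t> ComputeTiling(int64_t depth, int64_t width,
                                          int64_t warp_size) {
  constexpr int64_t kTargetElementsPerThread = 64;
  int64_t x_tile = kTargetElementsPerThread;
  int64_t z_tile = 1;
  // If the width does not divide into full warp-sized x-tiles, shrink the
  // x-tile and instead tile along z with the largest factor of depth <= 8.
  if (width % (warp_size * kTargetElementsPerThread) != 0) {
    x_tile = 8;
    z_tile = 8;
    while (depth % z_tile != 0) {
      --z_tile;
    }
  }
  return {x_tile, z_tile};
}

int main() {
  // width divides 32 * 64 = 2048 evenly: one long x-tile per thread.
  auto a = ComputeTiling(/*depth=*/6, /*width=*/4096, /*warp_size=*/32);
  std::cout << a.first << " " << a.second << "\n";  // prints "64 1"
  // width does not divide 2048: fall back to 8-wide x-tiles; depth=6 gives
  // z_tile_size=6 (largest divisor of depth not exceeding 8).
  auto b = ComputeTiling(/*depth=*/6, /*width=*/1000, /*warp_size=*/32);
  std::cout << b.first << " " << b.second << "\n";  // prints "8 6"
  return 0;
}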
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
index b41eaa303b..202231b82f 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
@@ -100,6 +100,13 @@ class IrEmitterUnnested : public IrEmitter {
const HloInstruction& inst,
tensorflow::gtl::ArraySlice<const BufferAllocation*> args);
+ // Helper for writing extra outputs from inside a reduce kernel.
+ Status EmitExtraOutputsForReduce(
+ const HloInstruction* reduce, const llvm_ir::IrArray::Index& index,
+ tensorflow::gtl::ArraySlice<
+ std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
+ extra_output_gens);
+
// EmitColumnReduction and EmitRowReduction emit code for column and row
// reduction of a matrix and/or 3D tensor. Row and column reduction have
// different memory access pattern, so for performance their implementations
@@ -115,7 +122,11 @@ class IrEmitterUnnested : public IrEmitter {
const Shape& input_shape,
tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> input_gens,
tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> init_value_gens,
- tensorflow::gtl::ArraySlice<HloComputation*> reducers);
+ tensorflow::gtl::ArraySlice<HloComputation*> reducers,
+ tensorflow::gtl::ArraySlice<ShapeIndex> reduce_output_shapes,
+ tensorflow::gtl::ArraySlice<
+ std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
+ extra_output_gens);
// Emits code that reduces a 3D tensor of shape [depth x height x width] to a
// vector of shape [height]. Other parameters have the same meaning as those
@@ -127,14 +138,22 @@ class IrEmitterUnnested : public IrEmitter {
const Shape& input_shape,
tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> input_gens,
tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> init_value_gens,
- tensorflow::gtl::ArraySlice<HloComputation*> reducers);
+ tensorflow::gtl::ArraySlice<HloComputation*> reducers,
+ tensorflow::gtl::ArraySlice<ShapeIndex> reduce_output_shapes,
+ tensorflow::gtl::ArraySlice<
+ std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
+ extra_output_gens);
// Emits code that reduces a tensor of arbitrary rank to a scalar.
Status EmitReductionToScalar(
HloInstruction* reduce, const Shape& input_shape,
tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> input_gens,
tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> init_value_gens,
- tensorflow::gtl::ArraySlice<HloComputation*> reducers);
+ tensorflow::gtl::ArraySlice<HloComputation*> reducers,
+ tensorflow::gtl::ArraySlice<ShapeIndex> reduce_output_shapes,
+ tensorflow::gtl::ArraySlice<
+ std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
+ extra_output_gens);
// Figures out whether `reduce` is a row or column reduction, and which
// dimensions to reduce, and calls either `EmitRowReduction` or
@@ -147,13 +166,21 @@ class IrEmitterUnnested : public IrEmitter {
// Multiple reduces can be emitted in the same loop, assuming they have the
// same input and output shapes, and the same reduce dimensions.
//
+ // extra_output_gens can contain extra generators for intermediate outputs.
+ // These must have the same shape as the reduce input, since they are
+ // computed while the reduce inputs are being read.
+ //
// Prerequisite: `IsReductionToVector(*reduce)`
Status EmitReductionToVector(
HloInstruction* reduce, const Shape& input_shape,
tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> input_gens,
tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> init_value_gens,
tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
- tensorflow::gtl::ArraySlice<HloComputation*> reducers);
+ tensorflow::gtl::ArraySlice<HloComputation*> reducers,
+ tensorflow::gtl::ArraySlice<ShapeIndex> reduce_output_shapes,
+ tensorflow::gtl::ArraySlice<
+ std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
+ extra_output_gens);
// Returns a KernelThunk that invokes the kernel emitted for `inst`. The
// caller needs to make sure `inst` outlives the lifetime of the returned
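The extra_output_gens parameter threaded through the reduction emitters above materializes elementwise outputs at the same index where the reduce input is loaded, which is why they must share the reduce input's shape. A minimal CPU-side analogue of fusing one extra elementwise output into a row reduction (plain C++, no XLA; function and variable names are illustrative only):

#include <cstddef>
#include <iostream>
#include <vector>

// Reduce each row of a height x width matrix to a scalar (sum), and at the
// same time emit one "extra output" per input element (here: its square),
// mirroring how extra output generators are evaluated at the same index where
// the reduce input is loaded.
void RowReduceWithExtraOutput(const std::vector<float>& input, size_t height,
                              size_t width, std::vector<float>* row_sums,
                              std::vector<float>* squares) {
  row_sums->assign(height, 0.0f);
  squares->assign(height * width, 0.0f);
  for (size_t y = 0; y < height; ++y) {
    for (size_t x = 0; x < width; ++x) {
      const size_t index = y * width + x;
      const float value = input[index];   // A single read of the input...
      (*row_sums)[y] += value;            // ...feeds the reduction...
      (*squares)[index] = value * value;  // ...and the extra output.
    }
  }
}

int main() {
  std::vector<float> input = {1, 2, 3, 4, 5, 6};
  std::vector<float> sums, squares;
  RowReduceWithExtraOutput(input, /*height=*/2, /*width=*/3, &sums, &squares);
  std::cout << sums[0] << " " << sums[1] << "\n";  // prints "6 15"
  std::cout << squares[5] << "\n";                 // prints "36"
  return 0;
}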
diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc
new file mode 100644
index 0000000000..86c5c4fb6f
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc
@@ -0,0 +1,118 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/multi_output_fusion.h"
+
+#include <stdint.h>
+#include <algorithm>
+#include <iterator>
+#include <list>
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace gpu {
+
+GpuMultiOutputFusion::GpuMultiOutputFusion() : MultiOutputFusion(INT64_MAX) {}
+
+bool GpuMultiOutputFusion::ShapesCompatibleForFusion(HloInstruction* instr1,
+ HloInstruction* instr2) {
+ auto get_element_shape = [&](HloInstruction* instr) {
+ const HloInstruction* element_instr = instr;
+ if (instr->opcode() == HloOpcode::kFusion) {
+ auto fused_expression_root = instr->fused_expression_root();
+ if (instr->IsMultiOutputFusion()) {
+ // The shapes in all tuple operands should agree. Just pick the first
+ // one.
+ element_instr = fused_expression_root->operands()[0];
+ } else {
+ element_instr = fused_expression_root;
+ }
+ }
+ return element_instr->shape();
+ };
+
+ // The elementwise output shapes must be the same (including layout).
+ return ShapeUtil::Equal(get_element_shape(instr1),
+ get_element_shape(instr2));
+}
+
+bool GpuMultiOutputFusion::IsProfitableOperand(HloInstruction* instr) {
+ // kConstant instructions will not have memory reads, so they won't be a
+ // profit source. Skip them.
+ if (instr->opcode() == HloOpcode::kConstant &&
+ ShapeUtil::IsEffectiveScalar(instr->shape())) {
+ return false;
+ }
+ // We don't aim to fuse producer/consumer instructions -- this should be
+ // taken care of by the instruction_fusion pass. If instr has only one user,
+ // it has no sibling instructions, so we don't consider it.
+ if (instr->user_count() < 2) {
+ return false;
+ }
+ return true;
+}
+
+namespace {
+bool IsReduction(HloInstruction* instr) {
+ if (instr->IsMultiOutputFusion()) {
+ for (const HloInstruction* operand :
+ instr->fused_expression_root()->operands()) {
+ if (operand->opcode() == HloOpcode::kReduce) {
+ return true;
+ }
+ }
+ return false;
+ } else if (instr->opcode() == HloOpcode::kFusion) {
+ return instr->fused_expression_root()->opcode() == HloOpcode::kReduce;
+ } else {
+ return instr->opcode() == HloOpcode::kReduce;
+ }
+}
+} // namespace
+
+bool GpuMultiOutputFusion::IsFusible(HloInstruction* instr) {
+ return IsReduction(instr);
+}
+
+int64 GpuMultiOutputFusion::GetProfit(HloInstruction* instr1,
+ HloInstruction* instr2) {
+ tensorflow::gtl::FlatSet<HloInstruction*> in_list;
+ for (auto instr : instr1->operands()) {
+ if (!IsProfitableOperand(instr)) {
+ continue;
+ }
+ in_list.insert(instr);
+ }
+ int64 profit = 0;
+ for (auto instr : instr2->operands()) {
+ if (!IsProfitableOperand(instr) || in_list.count(instr) == 0) {
+ continue;
+ }
+ profit += ShapeUtil::ByteSizeOf(instr->shape());
+ }
+ VLOG(2) << "Fusing instr1=" << instr1->name() << " instr2=" << instr2->name()
+ << ", the profit is =" << profit;
+ return profit;
+}
+
+} // namespace gpu
+} // namespace xla
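GetProfit above scores a candidate pair by the bytes of operands the two instructions have in common, since fusing them means those operands are read from memory once instead of twice. A standalone sketch of the same bookkeeping over toy operands (ShapeUtil::ByteSizeOf replaced by an explicit byte_size field; names here are assumptions, not the XLA API):

#include <cstdint>
#include <iostream>
#include <unordered_set>
#include <vector>

// Toy operand: an id plus its size in bytes.
struct ToyOperand {
  int id;
  int64_t byte_size;
};

// Estimated memory-read savings from multi-output-fusing two instructions:
// the total size of the operands they have in common.
int64_t EstimateProfit(const std::vector<ToyOperand>& operands1,
                       const std::vector<ToyOperand>& operands2) {
  std::unordered_set<int> in_first;
  for (const ToyOperand& op : operands1) {
    in_first.insert(op.id);
  }
  int64_t profit = 0;
  for (const ToyOperand& op : operands2) {
    if (in_first.count(op.id)) {
      profit += op.byte_size;
    }
  }
  return profit;
}

int main() {
  // Both instructions read operand 1 (a large tensor); the small operands
  // differ, so only operand 1 counts toward the profit.
  std::vector<ToyOperand> a = {{1, 1 << 20}, {2, 4}};
  std::vector<ToyOperand> b = {{1, 1 << 20}, {3, 4}};
  std::cout << EstimateProfit(a, b) << "\n";  // prints 1048576
  return 0;
}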
diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h
new file mode 100644
index 0000000000..5451a93cec
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h
@@ -0,0 +1,55 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_MULTI_OUTPUT_FUSION_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_MULTI_OUTPUT_FUSION_H_
+
+#include "tensorflow/compiler/xla/service/multi_output_fusion.h"
+
+namespace xla {
+namespace gpu {
+
+// Multi-output fusion of sibling and producer-consumer instructions for the
+// Jellyfish backend.
+class GpuMultiOutputFusion : public MultiOutputFusion {
+ public:
+ GpuMultiOutputFusion();
+
+ protected:
+ // Test whether instr1 and instr2 have compatible shapes that can be legally
+ // fused.
+ bool ShapesCompatibleForFusion(HloInstruction* instr1,
+ HloInstruction* instr2) override;
+
+ // We currently only consider reduce and reduce fusion nodes as candidates.
+ bool IsFusible(HloInstruction* instr) override;
+
+ // This function estimates the amount of memory-read traffic saved by merging
+ // instr1 and instr2 into one multi-output fusion instruction. For a fusion
+ // instruction, all the operands need to be loaded from memory. If we merge
+ // instr1 and instr2, common operands will not be loaded twice. The profit is
+ // estimated as the total size of the operands shared by instr1 and instr2.
+ int64 GetProfit(HloInstruction* instr1, HloInstruction* instr2) override;
+
+ // Whether fusing the instruction can reduce memory reads.
+ //
+ // TODO(tjoerg): Move this method up into the MultiOutputFusion base class.
+ bool IsProfitableOperand(HloInstruction* instr) override;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_MULTI_OUTPUT_FUSION_H_
diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc
new file mode 100644
index 0000000000..d0b4c88487
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc
@@ -0,0 +1,138 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/multi_output_fusion.h"
+
+#include "tensorflow/compiler/xla/service/hlo_matchers.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+
+namespace op = xla::testing::opcode_matchers;
+
+namespace xla {
+namespace gpu {
+
+using InstructionFusionTest = HloTestBase;
+
+const char kModulePrefix[] = R"(
+ HloModule test_module
+
+ scalar_add_computation {
+ scalar_lhs = f32[] parameter(0)
+ scalar_rhs = f32[] parameter(1)
+ ROOT add = f32[] add(scalar_lhs, scalar_rhs)
+ })";
+
+TEST_F(InstructionFusionTest, MultiOutputFusionSiblingReduceAndReduceFusion) {
+ // Fusion with reduce instruction root and a sibling reduce instruction
+ // sharing the same input param.
+ auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
+ fused_computation {
+ p1.1 = f32[128,512,28,28]{3,2,1,0} parameter(1)
+ mul = f32[128,512,28,28]{3,2,1,0} multiply(p1.1, p1.1)
+ const.1 = f32[] parameter(0)
+ ROOT reduce.1 = f32[512]{0} reduce(mul, const.1), dimensions={0,2,3}, to_apply=scalar_add_computation
+ }
+
+ ENTRY entry {
+ p0 = f32[] parameter(0)
+ p1 = f32[128,512,28,28]{3,2,1,0} parameter(1)
+ const.2 = f32[] constant(1)
+ fusion = f32[512] fusion(p0, p1), kind=kInput, calls=fused_computation
+ reduce.2 = f32[512]{0} reduce(p1, const.2), dimensions={0,2,3}, to_apply=scalar_add_computation
+ ROOT root = (f32[512]{0}, f32[512]{0}) tuple(fusion, reduce.2)
+ })"))
+ .ValueOrDie();
+ ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
+ SCOPED_TRACE(module->ToString());
+ const HloInstruction* fusion =
+ module->entry_computation()->root_instruction()->operand(0)->operand(0);
+ ASSERT_TRUE(fusion->IsMultiOutputFusion());
+ EXPECT_THAT(fusion->fused_expression_root(),
+ op::Tuple(op::Reduce(), op::Reduce()));
+}
+
+TEST_F(InstructionFusionTest, MultiOutputFusionSiblingReduceFusions) {
+ // Two sibling fusions with reduce instruction roots sharing the same input
+ // param.
+ auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
+ fused_computation_1 {
+ p1.1 = f32[128,512,28,28]{3,2,1,0} parameter(1)
+ mul = f32[128,512,28,28]{3,2,1,0} multiply(p1.1, p1.1)
+ const.1 = f32[] parameter(0)
+ ROOT reduce.1 = f32[512]{0} reduce(mul, const.1), dimensions={0,2,3}, to_apply=scalar_add_computation
+ }
+
+ fused_computation_2 {
+ p1.2 = f32[128,512,28,28]{3,2,1,0} parameter(1)
+ const.2 = f32[] parameter(0)
+ ROOT reduce.2 = f32[512]{0} reduce(p1.2, const.2), dimensions={0,2,3}, to_apply=scalar_add_computation
+ }
+
+ ENTRY entry {
+ p0 = f32[] parameter(0)
+ p1 = f32[128,512,28,28]{3,2,1,0} parameter(1)
+ fusion.1 = f32[512] fusion(p0, p1), kind=kInput, calls=fused_computation_1
+ fusion.2 = f32[512] fusion(p0, p1), kind=kInput, calls=fused_computation_2
+ ROOT root = (f32[512]{0}, f32[512]{0}) tuple(fusion.1, fusion.2)
+ })"))
+ .ValueOrDie();
+ ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
+ SCOPED_TRACE(module->ToString());
+ const HloInstruction* fusion =
+ module->entry_computation()->root_instruction()->operand(0)->operand(0);
+ ASSERT_TRUE(fusion->IsMultiOutputFusion());
+ EXPECT_THAT(fusion->fused_expression_root(),
+ op::Tuple(op::Reduce(), op::Reduce()));
+}
+
+TEST_F(InstructionFusionTest,
+ MultiOutputFusionSiblingReduceAndReduceMultiOutputFusion) {
+ // Multi-output fusion with two reduce instructions as its root and a sibling
+ // reduce instruction sharing the same input param.
+ auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
+ fused_computation (p0: f32[128,512,28,28]) -> (f32[512], f32[512]) {
+ const.1 = f32[] constant(1)
+ p0.1 = f32[128,512,28,28]{3,2,1,0} parameter(0)
+ mul = f32[128,512,28,28]{3,2,1,0} multiply(f32[128,512,28,28]{3,2,1,0} p0.1, f32[128,512,28,28]{3,2,1,0} p0.1)
+ reduce.1 = f32[512]{0} reduce(f32[128,512,28,28]{3,2,1,0} mul, f32[] const.1), dimensions={0,2,3}, to_apply=scalar_add_computation
+ reduce.2 = f32[512]{0} reduce(f32[128,512,28,28]{3,2,1,0} p0.1, f32[] const.1), dimensions={0,2,3}, to_apply=scalar_add_computation
+ ROOT tuple = (f32[512]{0}, f32[512]{0}) tuple(f32[512]{0} reduce.1, f32[512]{0} reduce.2)
+ }
+
+ ENTRY entry (p0: f32[128,512,28,28]) -> (f32[512], f32[512], f32[512]) {
+ p0 = f32[128,512,28,28]{3,2,1,0} parameter(0)
+ const = f32[] constant(1)
+ fusion = (f32[512]{0}, f32[512]{0}) fusion(f32[128,512,28,28]{3,2,1,0} p0), kind=kInput, calls=fused_computation
+ get-tuple-element = f32[512]{0} get-tuple-element((f32[512]{0}, f32[512]{0}) fusion), index=0
+ get-tuple-element.1 = f32[512]{0} get-tuple-element((f32[512]{0}, f32[512]{0}) fusion), index=1
+ reduce.3 = f32[512]{0} reduce(p0, const), dimensions={0,2,3}, to_apply=scalar_add_computation
+ ROOT root = (f32[512]{0}, f32[512]{0}, f32[512]{0}) tuple(f32[512]{0} get-tuple-element, f32[512]{0} get-tuple-element.1, f32[512]{0} reduce.3)
+ })"))
+ .ValueOrDie();
+ ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
+ SCOPED_TRACE(module->ToString());
+ const HloInstruction* fusion =
+ module->entry_computation()->root_instruction()->operand(0)->operand(0);
+ ASSERT_TRUE(fusion->IsMultiOutputFusion());
+ EXPECT_THAT(fusion->fused_expression_root(),
+ op::Tuple(op::Reduce(), op::Reduce(), op::Reduce()));
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc
index 696fa7e019..6f4bb0580e 100644
--- a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc
@@ -33,8 +33,7 @@ class StreamAssignmentTest : public HloTestBase {
auto debug_options = GetDebugOptionsForTest();
debug_options.set_xla_gpu_disable_multi_streaming(false);
config.set_debug_options(debug_options);
- return MakeUnique<HloModule>("test_module", VersionedComputationHandle(),
- config);
+ return MakeUnique<HloModule>("test_module", config);
}
// Pre-canned shapes.
diff --git a/tensorflow/compiler/xla/service/hlo_casting_utils.h b/tensorflow/compiler/xla/service/hlo_casting_utils.h
index b15f1f24c6..7f73bba036 100644
--- a/tensorflow/compiler/xla/service/hlo_casting_utils.h
+++ b/tensorflow/compiler/xla/service/hlo_casting_utils.h
@@ -18,10 +18,13 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CASTING_UTILS_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CASTING_UTILS_H_
-#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include <type_traits>
+#include "tensorflow/core/platform/logging.h"
namespace xla {
+class HloInstruction;
+
template <class T>
using EnableIfDerivedFromHlo =
typename std::enable_if<std::is_base_of<HloInstruction, T>::value>::type;
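Because hlo_casting_utils.h now only forward-declares HloInstruction, the std::is_base_of check inside EnableIfDerivedFromHlo is evaluated only where the alias is actually instantiated, so callers (like the test change below) must include hlo_instruction.h themselves. A minimal sketch of a SFINAE-constrained helper built on this alias; the helper name and its use of dynamic_cast are illustrative assumptions, not the real casting utilities:

    // Hypothetical helper, assuming hlo_instruction.h is already included at
    // the point of instantiation. Only participates in overload resolution
    // when T derives from HloInstruction.
    template <class T, typename = EnableIfDerivedFromHlo<T>>
    bool IsA(const HloInstruction* instr) {
      return dynamic_cast<const T*>(instr) != nullptr;
    }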
diff --git a/tensorflow/compiler/xla/service/hlo_casting_utils_test.cc b/tensorflow/compiler/xla/service/hlo_casting_utils_test.cc
index 436a922234..a336427540 100644
--- a/tensorflow/compiler/xla/service/hlo_casting_utils_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_casting_utils_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/core/platform/test.h"
namespace xla {
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc
index b61eabbbf5..763d9d2269 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation.cc
@@ -64,7 +64,7 @@ HloComputation::HloComputation(
const string& name, int parameter_count,
std::vector<std::unique_ptr<HloInstruction>>* instructions,
HloInstruction* root_instruction, HloInstruction* fusion_instruction)
- : name_(name),
+ : name_(NameUniquer::GetSanitizedName(name)),
unique_id_(-1),
root_instruction_(root_instruction),
fusion_instruction_(fusion_instruction) {
@@ -315,12 +315,49 @@ void ComputeComputationPostOrder(
}
}
+std::list<HloInstruction*> ComputeInstructionPostOrder(
+ HloInstruction* root, tensorflow::gtl::FlatSet<HloInstruction*>* visited) {
+ std::list<HloInstruction*> post_order;
+ std::vector<std::pair<HloInstruction*, bool>> dfs_stack;
+ dfs_stack.emplace_back(root, false);
+ while (!dfs_stack.empty()) {
+ const auto current = dfs_stack.back();
+ if (current.second) {
+ dfs_stack.pop_back();
+ if (!visited->insert(current.first).second) {
+ continue;
+ }
+ post_order.push_back(current.first);
+ } else {
+ if (visited->count(current.first)) {
+ dfs_stack.pop_back();
+ continue;
+ }
+ dfs_stack.back().second = true;
+
+ // Add the operands to the stack in reverse order so the first operand is
+ // processed first. This will produce a more natural ordering and a nicer
+ // result for things like HLO stringification.
+ const auto& operands = current.first->operands();
+ for (int64 i = operands.size() - 1; i >= 0; --i) {
+ dfs_stack.emplace_back(operands[i], false);
+ }
+
+ for (HloInstruction* op : current.first->control_predecessors()) {
+ dfs_stack.emplace_back(op, false);
+ }
+ }
+ }
+ return post_order;
+}
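ComputeInstructionPostOrder replaces the old visitor-based orderer with an explicit DFS stack of (instruction, children-already-pushed) pairs: a node is emitted only when it is popped the second time, after its operands and control predecessors. The same pattern in isolation, as a small self-contained sketch with a generic node type (names and types here are illustrative, not from the patch):

    #include <list>
    #include <set>
    #include <utility>
    #include <vector>

    struct Node {
      std::vector<Node*> operands;
    };

    // Iterative post-order over a DAG: every operand appears before its user.
    std::list<Node*> PostOrder(Node* root) {
      std::list<Node*> order;
      std::set<Node*> visited;
      std::vector<std::pair<Node*, bool>> stack;  // (node, children pushed?)
      stack.emplace_back(root, false);
      while (!stack.empty()) {
        auto current = stack.back();
        if (current.second) {  // Second visit: children are already emitted.
          stack.pop_back();
          if (visited.insert(current.first).second) {
            order.push_back(current.first);
          }
          continue;
        }
        if (visited.count(current.first)) {  // Reached via another user.
          stack.pop_back();
          continue;
        }
        stack.back().second = true;
        // Push operands in reverse so the first operand is emitted first,
        // mirroring the ordering comment in the patch.
        const auto& ops = current.first->operands;
        for (auto it = ops.rbegin(); it != ops.rend(); ++it) {
          stack.emplace_back(*it, false);
        }
      }
      return order;
    }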
+
} // namespace
std::list<HloInstruction*> HloComputation::MakeInstructionPostOrder() const {
std::list<HloInstruction*> post_order;
std::list<HloInstruction*> trace_instructions;
tensorflow::gtl::FlatSet<HloInstruction*> added_instructions;
+ std::vector<HloInstruction*> dfs_stack;
for (auto& instruction : instructions_) {
if (instruction->opcode() == HloOpcode::kTrace) {
// Trace instructions aren't handled by the DFS visitor. Add trace
@@ -328,9 +365,9 @@ std::list<HloInstruction*> HloComputation::MakeInstructionPostOrder() const {
// users).
trace_instructions.push_back(instruction.get());
} else if (instruction->users().empty()) {
- post_order.splice(post_order.end(),
- InstructionPostOrderer::GetOrder(instruction.get(),
- &added_instructions));
+ post_order.splice(
+ post_order.end(),
+ ComputeInstructionPostOrder(instruction.get(), &added_instructions));
}
}
post_order.splice(post_order.end(), trace_instructions);
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
index 94c9c7eabc..92a66681a9 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
@@ -172,7 +172,8 @@ Status HloCostAnalysis::HandleReverse(const HloInstruction*) {
return Status::OK();
}
-Status HloCostAnalysis::HandleSlice(const HloInstruction*) {
+Status HloCostAnalysis::HandleSlice(const HloInstruction* slice) {
+ current_properties_[kBytesAccessedKey] = shape_size_(slice->shape()) * 2;
return Status::OK();
}
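HandleSlice now charges the slice for both reading the sliced data and writing the result, hence the factor of two on the output shape's size. A worked example matching the HloCostAnalysisTest.Slice case added further down:

    // slice result shape = f32[1] -> shape_size = 1 * sizeof(float) = 4 bytes
    // bytes_accessed     = 4 * 2  = 8   (one read of the data, one write)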
@@ -386,6 +387,10 @@ Status HloCostAnalysis::HandleTranspose(const HloInstruction*) {
return Status::OK();
}
+Status HloCostAnalysis::HandleGenerateToken(const HloInstruction*) {
+ return Status::OK();
+}
+
Status HloCostAnalysis::HandleConvolution(const HloInstruction* convolution) {
auto lhs = convolution->operand(0);
auto rhs = convolution->operand(1);
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h
index d17678d20f..0d66736fe1 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h
@@ -97,6 +97,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor {
Status HandleBroadcast(const HloInstruction* broadcast) override;
Status HandlePad(const HloInstruction* pad) override;
Status HandleReshape(const HloInstruction* reshape) override;
+ Status HandleGenerateToken(const HloInstruction* token) override;
Status HandleTranspose(const HloInstruction* transpose) override;
Status HandleWhile(const HloInstruction* xla_while) override;
Status HandleConditional(const HloInstruction* conditional) override;
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc
index 16fdda8a8b..72adf09c83 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc
@@ -460,5 +460,20 @@ TEST_F(HloCostAnalysisTest, BaseDilatedConvolution) {
EXPECT_EQ(analysis.flop_count(), 1472);
}
+TEST_F(HloCostAnalysisTest, Slice) {
+ // Test the analysis on a slice.
+ XlaBuilder builder("slice");
+ auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2}), "x");
+ auto slice = builder.Slice(x, {0}, {1}, {1});
+ auto hlo_module = BuildHloGraph(&builder);
+
+ // Run HLO cost analysis.
+ HloCostAnalysis analysis(ShapeSize);
+ ASSERT_IS_OK(
+ hlo_module->entry_computation()->root_instruction()->Accept(&analysis));
+
+ EXPECT_EQ(analysis.bytes_accessed(), 8);
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_cse.cc b/tensorflow/compiler/xla/service/hlo_cse.cc
index dab946a099..a0ee889623 100644
--- a/tensorflow/compiler/xla/service/hlo_cse.cc
+++ b/tensorflow/compiler/xla/service/hlo_cse.cc
@@ -135,17 +135,18 @@ StatusOr<bool> HloCSE::Run(HloModule* module) {
// instruction for each class.
tensorflow::gtl::FlatSet<HloInstruction*, decltype(&CseHash),
decltype(cse_equal)>
- representatives(/*N=*/1024, &CseHash, cse_equal);
-
+ representatives(/*N=*/computation->instruction_count() + 1, &CseHash,
+ cse_equal);
for (auto instruction : computation->MakeInstructionPostOrder()) {
// If the instruction has zero operands (constants, parameters, etc.) skip
// over it.
if (instruction->operand_count() == 0) {
continue;
}
-
- // Skip instructions which have side effects.
- if (instruction->HasSideEffect()) {
+ // Skip instructions which have side effects or are a domain (which must
+ // not be CSE-ed).
+ if (instruction->HasSideEffect() ||
+ instruction->opcode() == HloOpcode::kDomain) {
continue;
}
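Two independent changes in this hunk: the representatives table is pre-sized to the computation's instruction count so it never rehashes during the pass, and kDomain instructions are now skipped just like side-effecting ones. A rough, standard-library sketch of the underlying CSE idea and of the pre-sizing; this is not the actual HloCSE implementation, and the fingerprint field is only a stand-in for CseHash/cse_equal:

    #include <string>
    #include <unordered_map>
    #include <vector>

    struct Expr { std::string fingerprint; };  // stand-in for an HLO instruction

    // Maps each expression to the index of its representative: the first
    // expression seen with an equal fingerprint.
    std::vector<int> DeduplicateToRepresentative(const std::vector<Expr>& exprs) {
      std::unordered_map<std::string, int> representatives;
      representatives.reserve(exprs.size() + 1);  // pre-size, as the patch does
      std::vector<int> replacement(exprs.size());
      for (int i = 0; i < static_cast<int>(exprs.size()); ++i) {
        auto inserted = representatives.emplace(exprs[i].fingerprint, i);
        replacement[i] = inserted.first->second;
      }
      return replacement;
    }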
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
index cc130a4900..d020005868 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
@@ -931,16 +931,17 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser(
}
const HloUse& use = value.uses()[0];
- if (user->fusion_kind() == HloInstruction::FusionKind::kLoop &&
- user->fused_expression_root()->opcode() ==
- HloOpcode::kDynamicUpdateSlice) {
- // Loop fusion with kDynamicUpdateSlice fused root.
- //
- // Returns true iff there is exactly one use of 'operand' at shape index
- // 'operand_index', and this singleton use is the fused root at operand
- // index 0.
- return use.instruction == user->fused_expression_root() &&
- use.operand_number == 0;
+ if (user->fusion_kind() == HloInstruction::FusionKind::kLoop) {
+ if (user->fused_expression_root()->opcode() ==
+ HloOpcode::kDynamicUpdateSlice) {
+ // Loop fusion with kDynamicUpdateSlice fused root.
+ //
+ // Returns true iff there is exactly one use of 'operand' at shape index
+ // 'operand_index', and this singleton use is the fused root at operand
+ // index 0.
+ return use.instruction == user->fused_expression_root() &&
+ use.operand_number == 0;
+ }
} else if (user->fusion_kind() == HloInstruction::FusionKind::kOutput &&
user->fused_expression_root()->opcode() == HloOpcode::kAdd) {
// Output fusion with kAdd fused root.
@@ -967,6 +968,7 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser(
use.operand_number == other_add_operand_index;
}
}
+
if (user->opcode() == HloOpcode::kDynamicUpdateSlice ||
user->opcode() == HloOpcode::kWhile) {
// We eliminated other users in BufferLiveness::live_range_strictly_before,
@@ -998,8 +1000,13 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser(
}) != uses.end();
return uses.size() == 2 && found_caller_use && found_elementwise_callee_use;
}
- // Check if 'user' is element-wise.
- return user->IsElementwise();
+
+ // Loop fusions that contain transposing copies won't reach here as they have
+ // different layouts, which fails the check at the beginning of this function.
+ //
+ // Multi-output fusion will fail the check here as tuples are not considered
+ // an elementwise operation.
+ return user->IsElementwiseOnOperand(user->operand_index(operand));
}
} // namespace xla
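The net effect of the new final return is that buffer sharing is decided per operand rather than for the user as a whole: the elementwise loop fusion test added below still shares, while the reverse-containing loop fusion and the multi-output fusion (whose fused root is a tuple, and a tuple is not elementwise) do not. A tiny illustrative wrapper around that final check, using only calls that appear in the hunk above; the wrapper itself is hypothetical:

    bool MayReuseOperandBuffer(const HloInstruction* operand,
                               const HloInstruction* user) {
      return user->IsElementwiseOnOperand(user->operand_index(operand));
    }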
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
index 5798326dcb..db1822ec47 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
@@ -1974,6 +1974,89 @@ TEST_F(CanShareOperandBufferWithUserTest, ElementWiseSameShape) {
dataflow_analysis_->CanShareOperandBufferWithUser(exp, {}, log, {}));
}
+TEST_F(CanShareOperandBufferWithUserTest,
+ NonElementwiseLoopFusionCantAliasOperandBuffer) {
+ auto builder = HloComputation::Builder(TestName());
+ Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2});
+
+ auto param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, data_shape, "param0"));
+
+ auto neg = builder.AddInstruction(
+ HloInstruction::CreateUnary(data_shape, HloOpcode::kNegate, param0));
+
+ auto reverse = builder.AddInstruction(
+ HloInstruction::CreateReverse(data_shape, neg, {0, 1}));
+
+ BuildModule(builder.Build());
+ auto fusion = computation_->CreateFusionInstruction(
+ {reverse, neg}, HloInstruction::FusionKind::kLoop);
+ RunAnalysis();
+
+ EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param0, {},
+ fusion, {}));
+}
+
+TEST_F(CanShareOperandBufferWithUserTest,
+ MultiOutputFusionCantAliasOperandBuffer) {
+ auto builder = HloComputation::Builder(TestName());
+ Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2});
+
+ Shape in_shape = ShapeUtil::MakeShape(F32, {8});
+ Shape out_shape = ShapeUtil::MakeShape(PRED, {8});
+ auto param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, in_shape, "param0"));
+ auto param1 = builder.AddInstruction(
+ HloInstruction::CreateParameter(1, in_shape, "param1"));
+
+ auto copy0 = builder.AddInstruction(
+ HloInstruction::CreateUnary(in_shape, HloOpcode::kCopy, param0));
+ auto copy1 = builder.AddInstruction(
+ HloInstruction::CreateUnary(in_shape, HloOpcode::kCopy, param1));
+
+ auto tuple =
+ builder.AddInstruction(HloInstruction::CreateTuple({copy1, copy0}));
+
+ BuildModule(builder.Build());
+ auto fusion = computation_->CreateFusionInstruction(
+ {tuple, copy1, copy0}, HloInstruction::FusionKind::kLoop);
+ RunAnalysis();
+
+ EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param0, {},
+ fusion, {0}));
+ EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param0, {},
+ fusion, {1}));
+ EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param1, {},
+ fusion, {0}));
+ EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param1, {},
+ fusion, {1}));
+}
+
+TEST_F(CanShareOperandBufferWithUserTest,
+ ElementwiseLoopFusionCanAliasOperandBuffer) {
+ auto builder = HloComputation::Builder(TestName());
+ Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2});
+
+ auto one = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ auto operand = builder.AddInstruction(
+ HloInstruction::CreateBroadcast(data_shape, one, {1}));
+
+ auto neg = builder.AddInstruction(
+ HloInstruction::CreateUnary(data_shape, HloOpcode::kNegate, operand));
+
+ auto exp = builder.AddInstruction(
+ HloInstruction::CreateUnary(data_shape, HloOpcode::kExp, neg));
+
+ BuildModule(builder.Build());
+ auto fusion = computation_->CreateFusionInstruction(
+ {exp, neg}, HloInstruction::FusionKind::kLoop);
+ RunAnalysis();
+
+ EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(operand, {},
+ fusion, {}));
+}
+
TEST_F(CanShareOperandBufferWithUserTest, ElementWiseDifferentShape) {
auto builder = HloComputation::Builder(TestName());
@@ -2048,6 +2131,46 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDynamicUpdateSlice) {
fusion, {}));
}
+TEST_F(CanShareOperandBufferWithUserTest,
+ FusedDynamicUpdateSliceWithConvertCantShare) {
+ auto builder = HloComputation::Builder(TestName());
+
+ Shape data_shape = ShapeUtil::MakeShape(F32, {8});
+ Shape data_shape_bf16 = ShapeUtil::MakeShape(BF16, {8});
+ auto tuple = builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "tuple"));
+ auto gte0 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(data_shape, tuple, 0));
+ auto gte1 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(data_shape, tuple, 1));
+
+ auto convert1 = builder.AddInstruction(
+ HloInstruction::CreateConvert(data_shape_bf16, gte1));
+
+ // Create a DynamicUpdateSlice instruction of tuple element 1.
+ auto starts = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR1<int32>({2})));
+ auto update = builder.AddInstruction(HloInstruction::CreateConstant(
+ Literal::CreateR1<float>({2.f, 2.f, 2.f})));
+ auto dynamic_update_slice =
+ builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
+ data_shape_bf16, convert1, update, starts));
+
+ auto convert2 = builder.AddInstruction(
+ HloInstruction::CreateConvert(data_shape, dynamic_update_slice));
+ builder.AddInstruction(HloInstruction::CreateTuple({gte0, convert2}));
+
+ BuildModule(builder.Build());
+ auto fusion = computation_->CreateFusionInstruction(
+ {convert2, dynamic_update_slice, starts, update, convert1},
+ HloInstruction::FusionKind::kLoop);
+ RunAnalysis();
+
+ // The fusion instruction can't share with tuple element 1.
+ EXPECT_FALSE(
+ dataflow_analysis_->CanShareOperandBufferWithUser(gte1, {}, fusion, {}));
+}
+
TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) {
auto builder = HloComputation::Builder(TestName());
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc
index 1e78d775c8..e0648e1467 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc
@@ -910,6 +910,14 @@ Status HloEvaluator::HandleBroadcast(HloInstruction* broadcast) {
return Status::OK();
}
+Status HloEvaluator::HandleGenerateToken(HloInstruction* token) {
+ // Literals cannot represent a TOKEN shape so just create an empty tuple as
+ // the "result" of the kGenerateToken operation.
+ // TODO(b/109929053): Add support for TOKENs in Literals.
+ evaluated_[token] = Literal::MakeTuple({});
+ return Status::OK();
+}
+
Status HloEvaluator::HandleGetTupleElement(HloInstruction* get_tuple_element) {
const auto result_shape = get_tuple_element->shape();
const int64 index = get_tuple_element->tuple_index();
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h
index b53d5644de..fc2fc9437b 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.h
@@ -174,6 +174,8 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
Status HandleBroadcast(HloInstruction* broadcast) override;
+ Status HandleGenerateToken(HloInstruction* token) override;
+
// Returns the already-evaluated literal result for the instruction.
// A Constant instruction is considered evaluated and its literal will be
// returned directly without looking up the cache.
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
index 84b4ead2dd..72eb9930e9 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
@@ -1248,7 +1248,7 @@ void BM_ReducePrecisely(int num_iters) {
HloComputation::Builder b("BM_ReducePrecisely");
HloModuleConfig config;
config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags());
- HloModule module("BM_ReducePrecisely", VersionedComputationHandle(), config);
+ HloModule module("BM_ReducePrecisely", config);
constexpr int kNumElements = 1 << 25; // float += 1 saturates at 1<<24
std::vector<float> v(kNumElements, 1.0f);
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
index 61612bebd1..05aab9a2cd 100644
--- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
+++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
@@ -723,11 +723,28 @@ string HloDotDumper::DumpRootTag() {
to_id, node_body, node_shape, NodeColorAttributes(color));
}
+static const HloInstruction* TryGetFusionParameterConstant(
+ const HloInstruction* instr) {
+ if (instr->opcode() != HloOpcode::kParameter || !instr->IsFused()) {
+ return nullptr;
+ }
+ const HloInstruction* fusion = instr->parent()->FusionInstruction();
+ const HloInstruction* operand = fusion->operand(instr->parameter_number());
+ if (operand->opcode() == HloOpcode::kConstant) {
+ return operand;
+ }
+ return nullptr;
+}
+
bool HloDotDumper::ShouldMergeIntoUsers(const HloInstruction* instr) const {
// If a node:
//
- // - is a tuple-shaped parameter,
- // - is not a parameter to a fusion node,
+ // - is a parameter of a fusion node which is bound to a constant,
+ //
+ // or
+ //
+ // - is a tuple-shaped parameter, and
+ // - is not a parameter to a fusion node, and
// - has at least kMinUsersToOmit users shown, and
// - all of the shown users are get-tuple-elements,
//
@@ -735,6 +752,9 @@ bool HloDotDumper::ShouldMergeIntoUsers(const HloInstruction* instr) const {
//
// This helps us handle the common case where a while loop body has one big
// tuple-shaped parameter.
+ if (TryGetFusionParameterConstant(instr) != nullptr) {
+ return true;
+ }
const int kMinUsersToOmit = 3;
return instr->opcode() == HloOpcode::kParameter &&
ShapeUtil::IsTuple(instr->shape()) && !instr->IsFused() &&
@@ -841,17 +861,6 @@ string HloDotDumper::GetInstructionNodeInlinedOperands(
ShapeUtil::HumanString(constant->shape()));
};
- // Special case: If instr is a parameter to a fusion node, check whether the
- // corresponding operand to the fusion node is a constant.
- if (instr->opcode() == HloOpcode::kParameter && instr->IsFused()) {
- const HloInstruction* fusion = instr->parent()->FusionInstruction();
- const HloInstruction* operand = fusion->operand(instr->parameter_number());
- if (operand->opcode() != HloOpcode::kConstant) {
- return "";
- }
- return StrCat("<b>constant</b> ", stringify_constant(operand));
- }
-
std::vector<string> lines;
for (int64 i = 0; i < instr->operand_count(); ++i) {
const HloInstruction* operand = instr->operand(i);
@@ -859,11 +868,18 @@ string HloDotDumper::GetInstructionNodeInlinedOperands(
if (operand->opcode() == HloOpcode::kConstant) {
operand_str = stringify_constant(operand);
} else if (ShouldMergeIntoUsers(operand)) {
- // Special case: If the operand is a parameter, use its parameter number
- // rather than its name, because that's generally how people think of the
- // node.
+ // Special case: If the operand is a parameter to a fusion node and it
+ // always has a constant value, display it like a regular constant.
+ //
+ // For other parameters, use the parameter number rather than the proper
+ // name, because that's generally how people think of the node.
if (operand->opcode() == HloOpcode::kParameter) {
- operand_str = Printf("Parameter %lld", operand->parameter_number());
+ if (const HloInstruction* constant =
+ TryGetFusionParameterConstant(operand)) {
+ operand_str = stringify_constant(constant);
+ } else {
+ operand_str = Printf("Parameter %lld", operand->parameter_number());
+ }
} else {
operand_str = operand->name();
}
@@ -897,11 +913,14 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
const auto kParameterColor = kOrange;
// Special case: If this instruction has a parameter merged into it, paint it
- // the same color as a parameter.
+ // the same color as a parameter. Unless the merged-in parameter is a
+ // parameter to a fusion node that is bound to a constant -- these aren't
+ // "real" parameters from the user's perspective.
if (std::any_of(instr->operands().begin(), instr->operands().end(),
[&](const HloInstruction* operand) {
return operand->opcode() == HloOpcode::kParameter &&
- ShouldMergeIntoUsers(operand);
+ ShouldMergeIntoUsers(operand) &&
+ TryGetFusionParameterConstant(operand) == nullptr;
})) {
return kParameterColor;
}
@@ -964,6 +983,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
case HloOpcode::kBitcast:
case HloOpcode::kGetTupleElement:
case HloOpcode::kTrace:
+ case HloOpcode::kGenerateToken:
case HloOpcode::kTuple:
return kWhite;
case HloOpcode::kBroadcast:
@@ -975,7 +995,6 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
}
return kGreen;
case HloOpcode::kConcatenate:
- case HloOpcode::kCopy:
case HloOpcode::kDynamicSlice:
case HloOpcode::kGather:
case HloOpcode::kPad:
@@ -997,6 +1016,10 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
return kWhite;
}
return kGreen;
+ case HloOpcode::kCopy:
+ // Emphasize copy nodes, which are either physical transposes (and thus
+ // significant), or copies of read-only buffers (and thus dead weight).
+ return kGreen;
case HloOpcode::kConvolution:
case HloOpcode::kDot:
case HloOpcode::kFft:
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc
index 8e52d926d8..68f41a1cbb 100644
--- a/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc
@@ -121,7 +121,7 @@ TEST(HloGraphDumperTest, Constant) {
HloComputation::Builder b("b");
auto instruction = b.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<float>(-42)));
- instruction->set_name("i_am_a_constant_root_instruction");
+ instruction->SetAndSanitizeName("i_am_a_constant_root_instruction");
HloModuleConfig config;
HloModule m(TestName(), config);
HloComputation* root_computation = m.AddEntryComputation(b.Build());
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 06775d6a9a..c89d836888 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -27,7 +27,9 @@ limitations under the License.
#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -60,17 +62,145 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(proto.opcode()));
TF_RET_CHECK(proto.has_shape());
- auto instruction = WrapUnique(new HloInstruction(opcode, proto.shape()));
- for (const int64 operand_id : proto.operand_ids()) {
- TF_RET_CHECK(ContainsKey(instruction_map, operand_id))
- << "No instruction with id " << operand_id;
- instruction->AppendOperand(instruction_map.at(operand_id));
- }
- for (const int64 predecessor_id : proto.control_predecessor_ids()) {
- TF_RET_CHECK(ContainsKey(instruction_map, predecessor_id))
- << "No instruction with id " << predecessor_id;
- TF_RETURN_IF_ERROR(instruction_map.at(predecessor_id)
- ->AddControlDependencyTo(instruction.get()));
+ std::unique_ptr<HloInstruction> instruction;
+ const auto operands = [&instruction_map, &proto](int index) {
+ return instruction_map.at(proto.operand_ids(index));
+ };
+ const auto computations = [&computation_map, &proto](int index) {
+ return computation_map.at(proto.called_computation_ids(index));
+ };
+ switch (opcode) {
+ // Ops migrated to subclasses.
+ case HloOpcode::kBatchNormTraining:
+ CHECK_EQ(proto.operand_ids_size(), 3);
+ instruction = CreateBatchNormTraining(
+ proto.shape(), operands(0), operands(1), operands(2), proto.epsilon(),
+ proto.feature_index());
+ break;
+ case HloOpcode::kBatchNormInference:
+ CHECK_EQ(proto.operand_ids_size(), 5);
+ instruction = CreateBatchNormInference(
+ proto.shape(), operands(0), operands(1), operands(2), operands(3),
+ operands(4), proto.epsilon(), proto.feature_index());
+ break;
+ case HloOpcode::kBatchNormGrad:
+ CHECK_EQ(proto.operand_ids_size(), 5);
+ instruction = CreateBatchNormGrad(proto.shape(), operands(0), operands(1),
+ operands(2), operands(3), operands(4),
+ proto.epsilon(), proto.feature_index());
+ break;
+ case HloOpcode::kFft: {
+ CHECK_EQ(proto.operand_ids_size(), 1);
+ std::vector<int64> fft_length(proto.fft_length().begin(),
+ proto.fft_length().end());
+ instruction = CreateFft(proto.shape(), operands(0), proto.fft_type(),
+ tensorflow::gtl::ArraySlice<int64>(fft_length));
+ break;
+ }
+ case HloOpcode::kSend:
+ CHECK_EQ(proto.operand_ids_size(), 1);
+ instruction = CreateSend(operands(0), proto.channel_id());
+ break;
+ case HloOpcode::kSendDone:
+ CHECK_EQ(proto.operand_ids_size(), 1);
+ instruction = CreateSendDone(operands(0));
+ break;
+ case HloOpcode::kRecv:
+ CHECK_EQ(proto.operand_ids_size(), 0);
+ instruction =
+ CreateRecv(proto.shape().tuple_shapes(0), proto.channel_id());
+ break;
+ case HloOpcode::kRecvDone:
+ CHECK_EQ(proto.operand_ids_size(), 1);
+ instruction = CreateRecvDone(operands(0));
+ break;
+ case HloOpcode::kReverse:
+ CHECK_EQ(proto.operand_ids_size(), 1);
+ instruction = CreateReverse(proto.shape(), operands(0),
+ std::vector<int64>(proto.dimensions().begin(),
+ proto.dimensions().end()));
+ break;
+ case HloOpcode::kConcatenate: {
+ CHECK_EQ(proto.dimensions_size(), 1);
+ std::vector<HloInstruction*> concat_operands(proto.operand_ids_size());
+ std::transform(proto.operand_ids().begin(), proto.operand_ids().end(),
+ concat_operands.begin(),
+ [&instruction_map](int64 operand_id) {
+ return instruction_map.at(operand_id);
+ });
+ instruction = CreateConcatenate(proto.shape(), concat_operands,
+ proto.dimensions(0));
+ break;
+ }
+ case HloOpcode::kReduce:
+ CHECK_EQ(proto.operand_ids_size(), 2);
+ CHECK_EQ(proto.called_computation_ids_size(), 1);
+ instruction = CreateReduce(proto.shape(), operands(0), operands(1),
+ std::vector<int64>(proto.dimensions().begin(),
+ proto.dimensions().end()),
+ computations(0));
+ break;
+ case HloOpcode::kTranspose:
+ CHECK_EQ(proto.operand_ids_size(), 1);
+ instruction =
+ CreateTranspose(proto.shape(), operands(0),
+ std::vector<int64>(proto.dimensions().begin(),
+ proto.dimensions().end()));
+ break;
+ case HloOpcode::kBroadcast:
+ CHECK_EQ(proto.operand_ids_size(), 1);
+ instruction =
+ CreateBroadcast(proto.shape(), operands(0),
+ std::vector<int64>(proto.dimensions().begin(),
+ proto.dimensions().end()));
+ break;
+ case HloOpcode::kMap: {
+ CHECK_EQ(proto.called_computation_ids_size(), 1);
+ std::vector<HloInstruction*> map_operands(proto.operand_ids_size());
+ std::transform(proto.operand_ids().begin(), proto.operand_ids().end(),
+ map_operands.begin(),
+ [&instruction_map](int64 operand_id) {
+ return instruction_map.at(operand_id);
+ });
+ instruction = CreateMap(proto.shape(), map_operands, computations(0));
+ break;
+ }
+ case HloOpcode::kSlice: {
+ CHECK_EQ(proto.operand_ids_size(), 1);
+ std::vector<int64> slice_starts, slice_limits, slice_strides;
+ for (const HloInstructionProto::SliceDimensions& slice_dimensions :
+ proto.slice_dimensions()) {
+ slice_starts.push_back(slice_dimensions.start());
+ slice_limits.push_back(slice_dimensions.limit());
+ slice_strides.push_back(slice_dimensions.stride());
+ }
+ instruction = CreateSlice(proto.shape(), operands(0), slice_starts,
+ slice_limits, slice_strides);
+ break;
+ }
+ default: {
+ instruction = WrapUnique(new HloInstruction(opcode, proto.shape()));
+ for (const int64 operand_id : proto.operand_ids()) {
+ TF_RET_CHECK(ContainsKey(instruction_map, operand_id))
+ << "No instruction with id " << operand_id;
+ instruction->AppendOperand(instruction_map.at(operand_id));
+ }
+ for (const int64 predecessor_id : proto.control_predecessor_ids()) {
+ TF_RET_CHECK(ContainsKey(instruction_map, predecessor_id))
+ << "No instruction with id " << predecessor_id;
+ TF_RETURN_IF_ERROR(instruction_map.at(predecessor_id)
+ ->AddControlDependencyTo(instruction.get()));
+ }
+ if (instruction->opcode() != HloOpcode::kFusion) {
+ for (const int64 computation_id : proto.called_computation_ids()) {
+ TF_RET_CHECK(ContainsKey(computation_map, computation_id))
+ << "No computation with id " << computation_id;
+ instruction->called_computations_.push_back(
+ computation_map.at(computation_id));
+ }
+ }
+ break;
+ }
}
// In the proto, fused computations are held exclusively within the
@@ -91,13 +221,6 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
<< "No fusion computation with id " << fusion_id;
fused_computation->SetFusionInstruction(instruction.get());
instruction->called_computations_.push_back(fused_computation);
- } else {
- for (const int64 computation_id : proto.called_computation_ids()) {
- TF_RET_CHECK(ContainsKey(computation_map, computation_id))
- << "No computation with id " << computation_id;
- instruction->called_computations_.push_back(
- computation_map.at(computation_id));
- }
}
if (instruction->opcode() == HloOpcode::kTrace) {
@@ -108,7 +231,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
}
TF_RET_CHECK(!proto.name().empty());
- instruction->name_ = proto.name();
+ instruction->SetAndSanitizeName(proto.name());
instruction->metadata_ = proto.metadata();
instruction->backend_config_ = proto.backend_config();
@@ -119,9 +242,6 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
instruction->parameter_number_ = proto.parameter_number();
instruction->tuple_index_ = proto.tuple_index();
- for (int64 dimension : proto.dimensions()) {
- instruction->dimensions_.push_back(dimension);
- }
if (proto.has_window()) {
instruction->window_ = MakeUnique<Window>(proto.window());
}
@@ -134,12 +254,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
instruction->dot_dimension_numbers_ =
MakeUnique<DotDimensionNumbers>(proto.dot_dimension_numbers());
}
- for (const HloInstructionProto::SliceDimensions& slice_dimensions :
- proto.slice_dimensions()) {
- instruction->slice_starts_.push_back(slice_dimensions.start());
- instruction->slice_limits_.push_back(slice_dimensions.limit());
- instruction->slice_strides_.push_back(slice_dimensions.stride());
- }
+
instruction->exponent_bits_ = proto.exponent_bits();
instruction->mantissa_bits_ = proto.mantissa_bits();
for (int64 dynamic_slice_size : proto.dynamic_slice_sizes()) {
@@ -151,16 +266,9 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
}
instruction->outfeed_config_ = proto.outfeed_config();
instruction->distribution_ = proto.distribution();
- instruction->epsilon_ = proto.epsilon();
- instruction->feature_index_ = proto.feature_index();
- instruction->channel_id_ = proto.channel_id();
instruction->infeed_config_ = proto.infeed_config();
instruction->custom_call_target_ = proto.custom_call_target();
instruction->outfeed_shape_ = proto.outfeed_shape();
- instruction->fft_type_ = proto.fft_type();
- for (int64 fft_len : proto.fft_length()) {
- instruction->fft_length_.push_back(fft_len);
- }
if (proto.has_sharding()) {
TF_ASSIGN_OR_RETURN(const auto& sharding,
@@ -187,7 +295,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
auto instruction =
WrapUnique(new HloInstruction(HloOpcode::kParameter, shape));
instruction->parameter_number_ = parameter_number;
- instruction->name_ = name;
+ instruction->SetAndSanitizeName(name);
return instruction;
}
@@ -344,13 +452,8 @@ HloInstruction::CreateGetTupleElement(const Shape& shape,
const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
HloComputation* map_computation,
tensorflow::gtl::ArraySlice<HloInstruction*> static_operands) {
- CHECK(static_operands.empty()) << "static_operands not yet supported";
- auto instruction = WrapUnique(new HloInstruction(HloOpcode::kMap, shape));
- for (auto operand : operands) {
- instruction->AppendOperand(operand);
- }
- instruction->called_computations_.push_back(map_computation);
- return instruction;
+ return MakeUnique<HloMapInstruction>(shape, operands, map_computation,
+ static_operands);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConvolve(
@@ -376,11 +479,7 @@ HloInstruction::CreateGetTupleElement(const Shape& shape,
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFft(
const Shape& shape, HloInstruction* operand, FftType fft_type,
tensorflow::gtl::ArraySlice<int64> fft_length) {
- auto instruction = WrapUnique(new HloInstruction(HloOpcode::kFft, shape));
- instruction->AppendOperand(operand);
- instruction->fft_type_ = fft_type;
- instruction->fft_length_.assign(fft_length.begin(), fft_length.end());
- return instruction;
+ return MakeUnique<HloFftInstruction>(shape, operand, fft_type, fft_length);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateDot(
@@ -462,56 +561,44 @@ HloInstruction::CreateCrossReplicaSum(
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSend(
HloInstruction* operand, int64 channel_id) {
- // Send instruction produces a tuple of {aliased operand, U32 context}.
- Shape output_shape = ShapeUtil::MakeTupleShape(
- {operand->shape(), ShapeUtil::MakeShape(U32, {})});
- auto instruction =
- WrapUnique(new HloInstruction(HloOpcode::kSend, output_shape));
- instruction->AppendOperand(operand);
- instruction->channel_id_ = channel_id;
- return instruction;
+ return MakeUnique<HloSendInstruction>(operand, channel_id);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSendDone(
HloInstruction* operand) {
- CHECK(operand->opcode() == HloOpcode::kSend)
+ auto send_operand = DynCast<HloSendInstruction>(operand);
+ CHECK(send_operand != nullptr)
<< "SendDone must take the context operand from Send";
- auto instruction = WrapUnique(
- new HloInstruction(HloOpcode::kSendDone, ShapeUtil::MakeNil()));
- instruction->AppendOperand(operand);
- instruction->channel_id_ = operand->channel_id();
- return instruction;
+ return MakeUnique<HloSendDoneInstruction>(send_operand);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRecv(
const Shape& shape, int64 channel_id) {
- // Recv instruction produces a tuple of {receive buffer, U32 context}.
- Shape output_shape =
- ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {})});
- auto instruction =
- WrapUnique(new HloInstruction(HloOpcode::kRecv, output_shape));
- instruction->channel_id_ = channel_id;
- return instruction;
+ return MakeUnique<HloRecvInstruction>(shape, channel_id);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRecvDone(
HloInstruction* operand) {
- CHECK(operand->opcode() == HloOpcode::kRecv)
+ auto recv_operand = DynCast<HloRecvInstruction>(operand);
+ CHECK(recv_operand != nullptr)
<< "RecvDone must take the context operand from Recv";
- Shape output_shape = ShapeUtil::GetTupleElementShape(operand->shape(), 0);
- auto instruction =
- WrapUnique(new HloInstruction(HloOpcode::kRecvDone, output_shape));
- instruction->AppendOperand(operand);
- instruction->channel_id_ = operand->channel_id();
- return instruction;
+ return MakeUnique<HloRecvDoneInstruction>(recv_operand);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReverse(
const Shape& shape, HloInstruction* operand,
tensorflow::gtl::ArraySlice<int64> dimensions) {
- auto instruction = WrapUnique(new HloInstruction(HloOpcode::kReverse, shape));
- instruction->AppendOperand(operand);
- instruction->dimensions_.assign(dimensions.begin(), dimensions.end());
+ return MakeUnique<HloReverseInstruction>(shape, operand, dimensions);
+}
+
+/* static */ std::unique_ptr<HloInstruction>
+HloInstruction::CreateGenerateToken(
+ tensorflow::gtl::ArraySlice<HloInstruction*> operands) {
+ auto instruction = WrapUnique(new HloInstruction(
+ HloOpcode::kGenerateToken, ShapeUtil::MakeTokenShape()));
+ for (auto operand : operands) {
+ instruction->AppendOperand(operand);
+ }
return instruction;
}
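CreateGenerateToken produces a token-shaped value that merely depends on its operands, which makes it usable as an ordering anchor for side-effecting instructions. A hedged usage sketch; the builder and the two operand instructions are assumptions for illustration, not taken from this patch:

    // Hypothetical: derive a token from two instructions so that later ops can
    // take the token as an operand and thereby be ordered after both of them.
    HloInstruction* token = builder.AddInstruction(
        HloInstruction::CreateGenerateToken({first_effectful, second_effectful}));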
@@ -548,18 +635,8 @@ HloInstruction::CreateCrossReplicaSum(
tensorflow::gtl::ArraySlice<int64> start_indices,
tensorflow::gtl::ArraySlice<int64> limit_indices,
tensorflow::gtl::ArraySlice<int64> strides) {
- auto instruction = WrapUnique(new HloInstruction(HloOpcode::kSlice, shape));
- instruction->AppendOperand(operand);
- instruction->slice_starts_.assign(start_indices.begin(), start_indices.end());
- instruction->slice_limits_.assign(limit_indices.begin(), limit_indices.end());
- instruction->slice_strides_.assign(strides.begin(), strides.end());
- // For backward compatibility with old serialized computations: if there are
- // no strides, assume all strides are 1.
- // TODO(b/63317920): remove this code.
- if (instruction->slice_strides_.empty()) {
- instruction->slice_strides_ = std::vector<int64>(start_indices.size(), 1LL);
- }
- return instruction;
+ return MakeUnique<HloSliceInstruction>(shape, operand, start_indices,
+ limit_indices, strides);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateDynamicSlice(
@@ -590,13 +667,7 @@ HloInstruction::CreateDynamicUpdateSlice(const Shape& shape,
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConcatenate(
const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
int64 dimension) {
- auto instruction =
- WrapUnique(new HloInstruction(HloOpcode::kConcatenate, shape));
- for (auto operand : operands) {
- instruction->AppendOperand(operand);
- }
- instruction->dimensions_.push_back(dimension);
- return instruction;
+ return MakeUnique<HloConcatenateInstruction>(shape, operands, dimension);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConvert(
@@ -619,13 +690,8 @@ HloInstruction::CreateBitcastConvert(const Shape& shape,
const Shape& shape, HloInstruction* arg, HloInstruction* init_value,
tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
HloComputation* reduce_computation) {
- auto instruction = WrapUnique(new HloInstruction(HloOpcode::kReduce, shape));
- instruction->AppendOperand(arg);
- instruction->AppendOperand(init_value);
- instruction->dimensions_.assign(dimensions_to_reduce.begin(),
- dimensions_to_reduce.end());
- instruction->called_computations_.push_back(reduce_computation);
- return instruction;
+ return MakeUnique<HloReduceInstruction>(
+ shape, arg, init_value, dimensions_to_reduce, reduce_computation);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReduceWindow(
@@ -646,14 +712,8 @@ HloInstruction::CreateBatchNormTraining(const Shape& shape,
HloInstruction* scale,
HloInstruction* offset, float epsilon,
int64 feature_index) {
- auto instruction =
- WrapUnique(new HloInstruction(HloOpcode::kBatchNormTraining, shape));
- instruction->AppendOperand(operand);
- instruction->AppendOperand(scale);
- instruction->AppendOperand(offset);
- instruction->epsilon_ = epsilon;
- instruction->feature_index_ = feature_index;
- return instruction;
+ return MakeUnique<HloBatchNormTrainingInstruction>(
+ shape, operand, scale, offset, epsilon, feature_index);
}
/* static */ std::unique_ptr<HloInstruction>
@@ -661,16 +721,8 @@ HloInstruction::CreateBatchNormInference(
const Shape& shape, HloInstruction* operand, HloInstruction* scale,
HloInstruction* offset, HloInstruction* mean, HloInstruction* variance,
float epsilon, int64 feature_index) {
- auto instruction =
- WrapUnique(new HloInstruction(HloOpcode::kBatchNormInference, shape));
- instruction->AppendOperand(operand);
- instruction->AppendOperand(scale);
- instruction->AppendOperand(offset);
- instruction->AppendOperand(mean);
- instruction->AppendOperand(variance);
- instruction->epsilon_ = epsilon;
- instruction->feature_index_ = feature_index;
- return instruction;
+ return MakeUnique<HloBatchNormInferenceInstruction>(
+ shape, operand, scale, offset, mean, variance, epsilon, feature_index);
}
/* static */ std::unique_ptr<HloInstruction>
@@ -679,16 +731,9 @@ HloInstruction::CreateBatchNormGrad(const Shape& shape, HloInstruction* operand,
HloInstruction* variance,
HloInstruction* grad_output, float epsilon,
int64 feature_index) {
- auto instruction =
- WrapUnique(new HloInstruction(HloOpcode::kBatchNormGrad, shape));
- instruction->AppendOperand(operand);
- instruction->AppendOperand(scale);
- instruction->AppendOperand(mean);
- instruction->AppendOperand(variance);
- instruction->AppendOperand(grad_output);
- instruction->epsilon_ = epsilon;
- instruction->feature_index_ = feature_index;
- return instruction;
+ return MakeUnique<HloBatchNormGradInstruction>(shape, operand, scale, mean,
+ variance, grad_output, epsilon,
+ feature_index);
}
/* static */ std::unique_ptr<HloInstruction>
@@ -711,12 +756,8 @@ HloInstruction::CreateSelectAndScatter(
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateBroadcast(
const Shape& shape, HloInstruction* operand,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
- auto instruction =
- WrapUnique(new HloInstruction(HloOpcode::kBroadcast, shape));
- instruction->AppendOperand(operand);
- instruction->dimensions_.assign(broadcast_dimensions.begin(),
- broadcast_dimensions.end());
- return instruction;
+ return MakeUnique<HloBroadcastInstruction>(shape, operand,
+ broadcast_dimensions);
}
/* static */ std::unique_ptr<HloInstruction>
@@ -795,19 +836,7 @@ HloInstruction::CreateBroadcastSequence(
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTranspose(
const Shape& shape, HloInstruction* operand,
tensorflow::gtl::ArraySlice<int64> dimensions) {
- CHECK_EQ(shape.dimensions().size(), dimensions.size());
- CHECK_EQ(shape.dimensions().size(), operand->shape().dimensions().size());
- CHECK(std::equal(operand->shape().dimensions().begin(),
- operand->shape().dimensions().end(),
- Permute(dimensions, shape.dimensions()).begin()))
- << "shape: " << ShapeUtil::HumanString(shape)
- << ", operand->shape(): " << ShapeUtil::HumanString(shape)
- << ", dimensions: {" << Join(dimensions, ", ") << "}";
- auto instruction =
- WrapUnique(new HloInstruction(HloOpcode::kTranspose, shape));
- instruction->AppendOperand(operand);
- instruction->dimensions_.assign(dimensions.begin(), dimensions.end());
- return instruction;
+ return MakeUnique<HloTransposeInstruction>(shape, operand, dimensions);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFusion(
@@ -836,12 +865,12 @@ HloInstruction::CreateBroadcastSequence(
return instruction;
}
-void HloInstruction::set_device_sharding(int64 device) {
- HloSharding device_sharding = HloSharding::AssignDevice(device);
+void HloInstruction::set_single_sharding(const HloSharding& sharding) {
+ CHECK(!sharding.IsTuple()) << sharding;
if (ShapeUtil::IsTuple(shape())) {
- set_sharding(HloSharding::Tuple(device_sharding.GetAsShapeTree(shape())));
+ set_sharding(HloSharding::Tuple(sharding.GetAsShapeTree(shape())));
} else {
- set_sharding(device_sharding);
+ set_sharding(sharding);
}
}
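set_single_sharding generalizes the removed set_device_sharding: any non-tuple sharding can now be applied, and for tuple-shaped instructions it is replicated across the shape tree. A minimal usage sketch equivalent to the old set_device_sharding(0); the instruction pointer and device number are illustrative:

    instruction->set_single_sharding(HloSharding::AssignDevice(/*device=*/0));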
@@ -1275,6 +1304,25 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
// in the face of code changes than copying fields explicitly. This also
// properly sets the user fields of the operands.
switch (opcode_) {
+ // Ops migrated to subclasses.
+ // TODO(b/80131774): Remove this switch when migration is complete.
+ case HloOpcode::kBatchNormTraining:
+ case HloOpcode::kBatchNormInference:
+ case HloOpcode::kBatchNormGrad:
+ case HloOpcode::kFft:
+ case HloOpcode::kSend:
+ case HloOpcode::kSendDone:
+ case HloOpcode::kRecv:
+ case HloOpcode::kRecvDone:
+ case HloOpcode::kReverse:
+ case HloOpcode::kConcatenate:
+ case HloOpcode::kReduce:
+ case HloOpcode::kTranspose:
+ case HloOpcode::kBroadcast:
+ case HloOpcode::kMap:
+ case HloOpcode::kSlice:
+ clone = CloneWithNewOperandsImpl(shape, new_operands, context);
+ break;
// Unary ops.
case HloOpcode::kAbs:
case HloOpcode::kRoundNearestAfz:
@@ -1333,10 +1381,6 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
new_operands[2]);
break;
// Other supported ops.
- case HloOpcode::kBroadcast:
- CHECK_EQ(new_operands.size(), 1);
- clone = CreateBroadcast(shape, new_operands[0], dimensions_);
- break;
case HloOpcode::kCall:
clone = CreateCall(shape, new_operands, to_apply());
break;
@@ -1355,9 +1399,6 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
clone = CreateHostCompute(shape, new_operands, channel_name_,
cost_estimate_ns_);
break;
- case HloOpcode::kConcatenate:
- clone = CreateConcatenate(shape, new_operands, dimensions(0));
- break;
case HloOpcode::kConvert:
CHECK_EQ(new_operands.size(), 1);
clone = CreateConvert(shape, new_operands[0]);
@@ -1381,10 +1422,6 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
clone = CreateDot(shape, new_operands[0], new_operands[1],
*dot_dimension_numbers_);
break;
- case HloOpcode::kFft:
- CHECK_EQ(new_operands.size(), 1);
- clone = CreateFft(shape, new_operands[0], fft_type_, fft_length_);
- break;
case HloOpcode::kCrossReplicaSum:
clone = CreateCrossReplicaSum(shape, new_operands, to_apply());
break;
@@ -1392,19 +1429,11 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
CHECK_EQ(new_operands.size(), 1);
clone = CreateGetTupleElement(shape, new_operands[0], tuple_index());
break;
- case HloOpcode::kMap:
- clone = CreateMap(shape, new_operands, to_apply());
- break;
case HloOpcode::kPad:
CHECK_EQ(new_operands.size(), 2);
clone =
CreatePad(shape, new_operands[0], new_operands[1], *padding_config_);
break;
- case HloOpcode::kReduce:
- CHECK_EQ(new_operands.size(), 2);
- clone = CreateReduce(shape, new_operands[0], new_operands[1], dimensions_,
- to_apply());
- break;
case HloOpcode::kReduceWindow:
CHECK_EQ(new_operands.size(), 2);
clone = CreateReduceWindow(shape, new_operands[0], new_operands[1],
@@ -1416,10 +1445,6 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
CreateSelectAndScatter(shape, new_operands[0], select(), *window_,
new_operands[1], new_operands[2], scatter());
break;
- case HloOpcode::kReverse:
- CHECK_EQ(new_operands.size(), 1);
- clone = CreateReverse(shape, new_operands[0], dimensions_);
- break;
case HloOpcode::kRng:
clone = CreateRng(shape, distribution_, new_operands);
break;
@@ -1427,11 +1452,6 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
CHECK_EQ(new_operands.size(), 1);
clone = CreateReshape(shape, new_operands[0]);
break;
- case HloOpcode::kSlice:
- CHECK_EQ(new_operands.size(), 1);
- clone = CreateSlice(shape, new_operands[0], slice_starts_, slice_limits_,
- slice_strides_);
- break;
case HloOpcode::kDynamicSlice:
clone = CreateDynamicSlice(shape, new_operands[0], new_operands[1],
dynamic_slice_sizes_);
@@ -1441,10 +1461,6 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
clone = CreateDynamicUpdateSlice(shape, new_operands[0], new_operands[1],
new_operands[2]);
break;
- case HloOpcode::kTranspose:
- CHECK_EQ(new_operands.size(), 1);
- clone = CreateTranspose(shape, new_operands[0], dimensions_);
- break;
case HloOpcode::kTuple:
clone = CreateTuple(new_operands);
*clone->mutable_shape() = shape;
@@ -1476,18 +1492,6 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
case HloOpcode::kParameter:
clone = CreateParameter(parameter_number_, shape, name_);
break;
- case HloOpcode::kBatchNormTraining:
- CHECK_EQ(new_operands.size(), 3);
- clone =
- CreateBatchNormTraining(shape, new_operands[0], new_operands[1],
- new_operands[2], epsilon(), feature_index());
- break;
- case HloOpcode::kBatchNormInference:
- CHECK_EQ(new_operands.size(), 5);
- clone = CreateBatchNormInference(
- shape, new_operands[0], new_operands[1], new_operands[2],
- new_operands[3], new_operands[4], epsilon(), feature_index());
- break;
case HloOpcode::kInfeed:
CHECK_EQ(new_operands.size(), 0);
clone = CreateInfeed(shape, infeed_config());
@@ -1496,36 +1500,12 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
CHECK_EQ(new_operands.size(), 1);
clone = CreateOutfeed(outfeed_shape_, new_operands[0], outfeed_config());
break;
- case HloOpcode::kBatchNormGrad:
- CHECK_EQ(new_operands.size(), 5);
- clone = CreateBatchNormGrad(shape, new_operands[0], new_operands[1],
- new_operands[2], new_operands[3],
- new_operands[4], epsilon(), feature_index());
- break;
case HloOpcode::kConditional:
CHECK_EQ(new_operands.size(), 3);
clone = CreateConditional(shape, new_operands[0], new_operands[1],
true_computation(), new_operands[2],
false_computation());
break;
- case HloOpcode::kSend:
- CHECK_EQ(new_operands.size(), 1);
- clone = CreateSend(new_operands[0], channel_id());
- break;
- case HloOpcode::kSendDone:
- CHECK_EQ(new_operands.size(), 1);
- clone = CreateSendDone(new_operands[0]);
- break;
- case HloOpcode::kRecv:
- CHECK_EQ(new_operands.size(), 0);
- // The shape is a tuple, but CreateRecv() wants the raw data shape.
- clone =
- CreateRecv(ShapeUtil::GetTupleElementShape(shape, 0), channel_id());
- break;
- case HloOpcode::kRecvDone:
- CHECK_EQ(new_operands.size(), 1);
- clone = CreateRecvDone(new_operands[0]);
- break;
case HloOpcode::kGather:
CHECK_EQ(new_operands.size(), 2);
clone = CreateGather(shape, new_operands[0], new_operands[1],
@@ -1537,6 +1517,9 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
CreateDomain(shape, new_operands[0], operand_side_metadata_->Clone(),
user_side_metadata_->Clone());
break;
+ case HloOpcode::kGenerateToken:
+ clone = CreateGenerateToken(new_operands);
+ break;
case HloOpcode::kTrace:
LOG(FATAL) << "Not yet implemented, clone: " << HloOpcodeString(opcode_);
}
@@ -1626,28 +1609,6 @@ const Literal& HloInstruction::literal() const {
bool HloInstruction::HasLiteral() const { return literal_ != nullptr; }
-bool HloInstruction::CanHaveDimensionsField() const {
- return (opcode() == HloOpcode::kReverse ||
- opcode() == HloOpcode::kConcatenate ||
- opcode() == HloOpcode::kReduce || opcode() == HloOpcode::kBroadcast ||
- opcode() == HloOpcode::kTranspose);
-}
-
-const std::vector<int64>& HloInstruction::dimensions() const {
- CHECK(CanHaveDimensionsField());
- return dimensions_;
-}
-
-int64 HloInstruction::dimensions(int64 index) const {
- return dimensions()[index];
-}
-
-int64 HloInstruction::concatenate_dimension() const {
- CHECK(opcode() == HloOpcode::kConcatenate);
- CHECK_EQ(1, dimensions_.size());
- return dimensions(0);
-}
-
int64 HloInstruction::tuple_index() const {
CHECK_EQ(HloOpcode::kGetTupleElement, opcode_);
return tuple_index_;
@@ -1813,12 +1774,6 @@ bool HloInstruction::IdenticalSlowPath(
case HloOpcode::kTuple:
return true;
- // Broadcast, Concatenate, and Transpose need the same dimensions field.
- case HloOpcode::kBroadcast:
- case HloOpcode::kConcatenate:
- case HloOpcode::kTranspose:
- return dimensions() == other.dimensions();
-
case HloOpcode::kFusion:
return fusion_kind() == other.fusion_kind() &&
eq_computations(fused_instructions_computation(),
@@ -1829,17 +1784,12 @@ bool HloInstruction::IdenticalSlowPath(
case HloOpcode::kRng:
case HloOpcode::kTrace:
case HloOpcode::kWhile:
+ case HloOpcode::kGenerateToken:
return false;
case HloOpcode::kParameter:
return parameter_number() == other.parameter_number();
- case HloOpcode::kBatchNormTraining:
- case HloOpcode::kBatchNormInference:
- case HloOpcode::kBatchNormGrad:
- return feature_index() == other.feature_index() &&
- epsilon() == other.epsilon();
-
// A constant is defined by the value in the literal.
case HloOpcode::kConstant:
return literal() == other.literal();
@@ -1865,16 +1815,6 @@ bool HloInstruction::IdenticalSlowPath(
other.gather_dimension_numbers()) &&
gather_window_bounds() == other.gather_window_bounds();
- // FFT has various types & lengths.
- case HloOpcode::kFft:
- return fft_type() == other.fft_type() &&
- fft_length() == other.fft_length();
-
- // Reduction results are determined by the reduction dimension and the
- // reduction computation.
- case HloOpcode::kReduce:
- return dimensions() == other.dimensions() &&
- eq_computations(to_apply(), other.to_apply());
case HloOpcode::kReduceWindow:
return eq_computations(to_apply(), other.to_apply()) &&
protobuf_util::ProtobufEquals(window(), other.window());
@@ -1886,20 +1826,14 @@ bool HloInstruction::IdenticalSlowPath(
eq_computations(scatter(), other.scatter()) &&
protobuf_util::ProtobufEquals(window(), other.window());
-
// Remaining instructions with special values.
case HloOpcode::kGetTupleElement:
return tuple_index() == other.tuple_index();
case HloOpcode::kPad:
return protobuf_util::ProtobufEquals(padding_config(),
other.padding_config());
- case HloOpcode::kSlice:
- return slice_starts_ == other.slice_starts_ &&
- slice_limits_ == other.slice_limits_ &&
- slice_strides_ == other.slice_strides_;
case HloOpcode::kCall:
case HloOpcode::kCrossReplicaSum:
- case HloOpcode::kMap:
return eq_computations(to_apply(), other.to_apply());
case HloOpcode::kCustomCall:
if ((window_ == nullptr) != (other.window_ == nullptr) ||
@@ -1916,8 +1850,6 @@ bool HloInstruction::IdenticalSlowPath(
return false;
}
return custom_call_target_ == other.custom_call_target_;
- case HloOpcode::kReverse:
- return dimensions() == other.dimensions();
case HloOpcode::kConditional:
return eq_computations(true_computation(), other.true_computation()) &&
eq_computations(false_computation(), other.false_computation());
@@ -1926,21 +1858,29 @@ bool HloInstruction::IdenticalSlowPath(
case HloOpcode::kInfeed:
case HloOpcode::kOutfeed:
case HloOpcode::kSort:
- case HloOpcode::kRecv:
- case HloOpcode::kRecvDone:
- case HloOpcode::kSend:
- case HloOpcode::kSendDone:
case HloOpcode::kHostCompute:
return false;
- }
-}
-bool HloInstruction::IsRank2Transpose() const {
- return (opcode_ == HloOpcode::kTranspose) &&
- dimensions_ == std::vector<int64>({1, 0}) &&
- shape_.dimensions_size() == 2 &&
- std::equal(shape_.dimensions().begin(), shape_.dimensions().end(),
- operands_[0]->shape_.dimensions().rbegin());
+ // Ops migrated to subclasses should never reach this point.
+ // TODO(b/80131774): Remove this switch when migration is complete.
+ case HloOpcode::kBatchNormTraining:
+ case HloOpcode::kBatchNormInference:
+ case HloOpcode::kBatchNormGrad:
+ case HloOpcode::kFft:
+ case HloOpcode::kSend:
+ case HloOpcode::kSendDone:
+ case HloOpcode::kRecv:
+ case HloOpcode::kRecvDone:
+ case HloOpcode::kReverse:
+ case HloOpcode::kConcatenate:
+ case HloOpcode::kReduce:
+ case HloOpcode::kTranspose:
+ case HloOpcode::kBroadcast:
+ case HloOpcode::kMap:
+ case HloOpcode::kSlice:
+ LOG(FATAL) << "Base class impl called for opcode with subclass: "
+ << opcode();
+ }
}
void HloInstruction::RemoveUser(HloInstruction* user) {
@@ -2295,13 +2235,11 @@ string HloInstruction::OperandsToStringWithCanonicalNameMap(
std::vector<string> HloInstruction::ExtraAttributesToString(
const HloPrintOptions& options) const {
- std::vector<string> extra;
+ std::vector<string> extra = ExtraAttributesToStringImpl(options);
+
if (opcode() == HloOpcode::kFusion) {
extra.push_back(StrCat("kind=", xla::ToString(fusion_kind())));
}
- if (CanHaveDimensionsField()) {
- extra.push_back(StrCat("dimensions={", Join(dimensions(), ","), "}"));
- }
if (window_ != nullptr && window_->dimensions_size() != 0) {
extra.push_back(StrCat("window={", window_util::ToString(*window_), "}"));
}
@@ -2309,29 +2247,11 @@ std::vector<string> HloInstruction::ExtraAttributesToString(
extra.push_back(
StrCat("padding=", xla::PaddingConfigToString(*padding_config_)));
}
- if (opcode() == HloOpcode::kSlice) {
- std::vector<string> bounds;
- bounds.reserve(slice_starts_.size());
- const bool omit_stride =
- std::all_of(slice_strides_.begin(), slice_strides_.end(),
- [](int64 stride) { return stride == 1; });
- for (int i = 0; i < slice_starts_.size(); ++i) {
- string stride_str = omit_stride ? "" : StrCat(":", slice_strides_[i]);
- bounds.push_back(StrCat("[", slice_starts_[i], ":", slice_limits_[i],
- stride_str, "]"));
- }
- extra.push_back(StrCat("slice={", Join(bounds, ", "), "}"));
- }
+
if (opcode() == HloOpcode::kDynamicSlice) {
extra.push_back(
StrCat("dynamic_slice_sizes={", Join(dynamic_slice_sizes(), ","), "}"));
}
- if (opcode() == HloOpcode::kBatchNormTraining ||
- opcode() == HloOpcode::kBatchNormInference ||
- opcode() == HloOpcode::kBatchNormGrad) {
- extra.push_back(StrCat("epsilon=", epsilon()));
- extra.push_back(StrCat("feature_index=", feature_index()));
- }
if (convolution_dimension_numbers_ != nullptr) {
extra.push_back(StrCat(
@@ -2346,10 +2266,6 @@ std::vector<string> HloInstruction::ExtraAttributesToString(
extra.push_back(
StrCat("window_bounds={", Join(gather_window_bounds(), ","), "}"));
}
- if (opcode() == HloOpcode::kFft) {
- extra.push_back(StrCat("fft_type=", FftType_Name(fft_type())));
- extra.push_back(StrCat("fft_length={", Join(fft_length(), ","), "}"));
- }
if (options.print_subcomputation_mode() ==
HloPrintOptions::PrintSubcomputationMode::kNameOnly) {
@@ -2420,10 +2336,6 @@ std::vector<string> HloInstruction::ExtraAttributesToString(
break;
}
}
- if (opcode() == HloOpcode::kSend || opcode() == HloOpcode::kRecv ||
- opcode() == HloOpcode::kSendDone || opcode() == HloOpcode::kRecvDone) {
- extra.push_back(StrCat("channel_id=", channel_id_));
- }
if (opcode() == HloOpcode::kGetTupleElement) {
extra.push_back(StrCat("index=", tuple_index()));
@@ -2513,9 +2425,6 @@ HloInstructionProto HloInstruction::ToProto() const {
}
proto.set_tuple_index(tuple_index_);
- for (int64 dimension : dimensions_) {
- proto.add_dimensions(dimension);
- }
if (window_ != nullptr) {
*proto.mutable_window() = *window_;
}
@@ -2534,12 +2443,7 @@ HloInstructionProto HloInstruction::ToProto() const {
proto.add_gather_window_bounds(bound);
}
}
- for (int i = 0; i < slice_starts_.size(); ++i) {
- auto* slice_dimension = proto.add_slice_dimensions();
- slice_dimension->set_start(slice_starts_[i]);
- slice_dimension->set_limit(slice_limits_[i]);
- slice_dimension->set_stride(slice_strides_[i]);
- }
+
proto.set_exponent_bits(exponent_bits_);
proto.set_mantissa_bits(mantissa_bits_);
for (int64 slice_size : dynamic_slice_sizes_) {
@@ -2552,16 +2456,9 @@ HloInstructionProto HloInstruction::ToProto() const {
if (opcode() == HloOpcode::kRng) {
proto.set_distribution(distribution_);
}
- proto.set_epsilon(epsilon_);
- proto.set_feature_index(feature_index_);
- proto.set_channel_id(channel_id_);
proto.set_infeed_config(infeed_config_);
proto.set_custom_call_target(custom_call_target_);
*proto.mutable_outfeed_shape() = outfeed_shape_;
- proto.set_fft_type(fft_type_);
- for (int64 fft_len : fft_length_) {
- proto.add_fft_length(fft_len);
- }
if (has_sharding()) {
*proto.mutable_sharding() = sharding().ToProto();
@@ -2868,6 +2765,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) {
return visitor->HandleGather(this);
case HloOpcode::kDomain:
return visitor->HandleDomain(this);
+ case HloOpcode::kGenerateToken:
+ return visitor->HandleGenerateToken(this);
// These opcodes are not handled here.
case HloOpcode::kTrace:
@@ -3199,7 +3098,6 @@ bool HloInstruction::IsElementwise() const {
// Other operations.
case HloOpcode::kRng:
- case HloOpcode::kMap:
return true;
case HloOpcode::kFusion:
if (fusion_kind() != FusionKind::kLoop) {
@@ -3619,4 +3517,61 @@ void HloInstruction::RelayoutConstant(const Layout& new_layout,
}
}
+// TODO(b/80131774): Remove these temporary methods after transition.
+int64 HloInstruction::feature_index() const {
+ return Cast<HloBatchNormInstruction>(this)->feature_index();
+}
+
+float HloInstruction::epsilon() const {
+ return Cast<HloBatchNormInstruction>(this)->epsilon();
+}
+
+FftType HloInstruction::fft_type() const {
+ return Cast<HloFftInstruction>(this)->fft_type();
+}
+
+const std::vector<int64>& HloInstruction::fft_length() const {
+ return Cast<HloFftInstruction>(this)->fft_length();
+}
+
+int64 HloInstruction::channel_id() const {
+ return Cast<HloSendRecvInstruction>(this)->channel_id();
+}
+
+int64 HloInstruction::concatenate_dimension() const {
+ return Cast<HloConcatenateInstruction>(this)->concatenate_dimension();
+}
+
+bool HloInstruction::IsRank2Transpose() const {
+ auto transpose = DynCast<HloTransposeInstruction>(this);
+ return transpose != nullptr && transpose->IsRank2Transpose();
+}
+
+int64 HloInstruction::slice_starts(int64 dimension) const {
+ return Cast<HloSliceInstruction>(this)->slice_starts(dimension);
+}
+
+const std::vector<int64>& HloInstruction::slice_starts() const {
+ return Cast<HloSliceInstruction>(this)->slice_starts();
+}
+
+int64 HloInstruction::slice_limits(int64 dimension) const {
+ return Cast<HloSliceInstruction>(this)->slice_limits(dimension);
+}
+
+const std::vector<int64>& HloInstruction::slice_limits() const {
+ return Cast<HloSliceInstruction>(this)->slice_limits();
+}
+
+int64 HloInstruction::slice_strides(int64 dimension) const {
+ return Cast<HloSliceInstruction>(this)->slice_strides(dimension);
+}
+
+const std::vector<int64>& HloInstruction::slice_strides() const {
+ return Cast<HloSliceInstruction>(this)->slice_strides();
+}
+
+bool HloInstruction::IsInPlaceSlice() const {
+ return Cast<HloSliceInstruction>(this)->IsInPlaceSlice();
+}
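+
+// For illustration only (not part of this change's API): existing call sites
+// such as
+//   int64 start = instr->slice_starts(0);  // instr has opcode kSlice
+// keep compiling unchanged; the base-class accessor now forwards to
+// HloSliceInstruction via Cast<>, which is expected to CHECK-fail if the
+// instruction is not actually a slice.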
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index ef55c6668f..ae1c563b56 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -664,6 +664,11 @@ class HloInstruction {
const Shape& shape, HloInstruction* operand,
tensorflow::gtl::ArraySlice<int64> dimensions);
+ // Creates a token instruction used for joining or creating new token values,
+ // which thread through side-effecting operations.
+ static std::unique_ptr<HloInstruction> CreateGenerateToken(
+ tensorflow::gtl::ArraySlice<HloInstruction*> operands);
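+ // A minimal usage sketch (instruction names are illustrative): a builder can
+ // join the tokens of two side-effecting instructions with
+ //   builder.AddInstruction(HloInstruction::CreateGenerateToken({a, b}));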
+
// Creates an instance of GatherDimensionNumbers.
static GatherDimensionNumbers MakeGatherDimNumbers(
tensorflow::gtl::ArraySlice<int64> output_window_dims,
@@ -802,9 +807,6 @@ class HloInstruction {
// Returns whether the instruction has a constant operand.
bool HasConstantOperand() const;
- // Returns whether this instruction does a rank-2 transposition.
- bool IsRank2Transpose() const;
-
// Replaces the use of this instruction in "user" with "new_producer". Note
// that there might be multiple uses of this instruction in "user"; all will
// be replaced.
@@ -889,17 +891,6 @@ class HloInstruction {
return parameter_number_;
}
- // Returns the dimension sizes or numbers associated with this instruction.
- //
- // Precondition: opcode() is one of: concatenate, reduce, broadcast, reshape,
- // and reverse.
- const std::vector<int64>& dimensions() const;
- int64 dimensions(int64 index) const;
-
- // Accessor for the dimension in which a concatenate HLO should occur.
- // Precondition: opcode() == HloOpcode::kConcatenate
- int64 concatenate_dimension() const;
-
// Returns the tuple index associated with this instruction.
//
// Precondition: opcode() == HloOpcode::kGetTupleElement
@@ -999,7 +990,7 @@ class HloInstruction {
string ToShortString() const;
// Returns a serialized representation of this instruction.
- HloInstructionProto ToProto() const;
+ virtual HloInstructionProto ToProto() const;
// Returns a category for the HLO. This could be something like "convolution"
// or "elementwise".
@@ -1011,33 +1002,12 @@ class HloInstruction {
HloInstruction* tracing() const;
void set_tracing(HloInstruction* trace_instruction);
- // Returns the channel id associated with the instruction. The id is
- // shared between each Send/Recv pair and is globally unique to identify each
- // channel.
- //
- // Precondition: opcode() == HloOpcode::kSend or HloOpcode::kRecv
- int64 channel_id() const { return channel_id_; }
-
// Returns the channel name associated with the instruction. The name is
// used to identify host Send/Recv operations.
//
// Precondition: opcode() == HloOpcode::kHostCompute
string channel_name() const { return channel_name_; }
- // Returns feature_index field associated with the instruction. The index
- // represents the index of the feature dimension.
- //
- // Precondition: opcode() is one of kBatchNormTraining, kBatchNormInference,
- // or kBatchNormGrad.
- int64 feature_index() const { return feature_index_; }
-
- // Returns a epsilon value associated with the instruction. The is a small
- // number added to the variance to avoid divide-by-zero error.
- //
- // Precondition: opcode() is one of kBatchNormTraining, kBatchNormInference,
- // or kBatchNormGrad.
- float epsilon() const { return epsilon_; }
-
// Returns the infeed configuration string. The infeed configuration includes
// any metadata needed for the backend compiler (e.g., infeed buffer address)
// and is target-dependent.
@@ -1140,8 +1110,11 @@ class HloInstruction {
void set_sharding(const HloSharding& sharding) {
sharding_ = MakeUnique<HloSharding>(sharding);
}
+ void set_single_sharding(const HloSharding& sharding);
// Sets a sharding that assigns the current instruction to device.
- void set_device_sharding(int64 device);
+ void set_device_sharding(int64 device) {
+ set_single_sharding(HloSharding::AssignDevice(device));
+ }
// Remove any sharding from this operator.
void clear_sharding() { sharding_ = nullptr; }
// Return true if this operator has a sharding assigned.
@@ -1216,48 +1189,6 @@ class HloInstruction {
return FuseInstructionInternal(instruction_to_fuse, /* add_output */ true);
}
- // Returns the start index in the given dimension for a slice node.
- //
- // Precondition: opcode() == HloOpcode::kSlice
- int64 slice_starts(int64 dimension) const {
- CHECK_EQ(HloOpcode::kSlice, opcode_);
- return slice_starts_[dimension];
- }
- const std::vector<int64>& slice_starts() const { return slice_starts_; }
-
- // Returns the (exclusive) limit index in the given dimension for a slice
- // node.
- //
- // Precondition: opcode() == HloOpcode::kSlice
- int64 slice_limits(int64 dimension) const {
- CHECK_EQ(HloOpcode::kSlice, opcode_);
- return slice_limits_[dimension];
- }
- const std::vector<int64>& slice_limits() const {
- CHECK_EQ(HloOpcode::kSlice, opcode_);
- return slice_limits_;
- }
-
- // Returns the stride in the given dimension for a slice node.
- //
- // Precondition: opcode() == HloOpcode::kSlice
- int64 slice_strides(int64 dimension) const {
- CHECK_EQ(HloOpcode::kSlice, opcode_);
- return slice_strides_[dimension];
- }
- const std::vector<int64>& slice_strides() const { return slice_strides_; }
-
- // Returns the flag that describes whether a slice must be lowered into an
- // offset into the original operand.
- bool IsInPlaceSlice() const { return is_in_place_slice_; }
-
- // Sets and returns the flag that describes whether a slice must be lowered
- // into an offset into the original operand.
- bool SetIsInPlaceSlice(bool value) {
- is_in_place_slice_ = value;
- return value;
- }
-
// Returns the size of the slice in the given dimension for a dynamic
// slice node.
//
@@ -1324,16 +1255,6 @@ class HloInstruction {
MakeUnique<ConvolutionDimensionNumbers>(dnums);
}
- FftType fft_type() const {
- CHECK_EQ(HloOpcode::kFft, opcode_);
- return fft_type_;
- }
-
- const std::vector<int64>& fft_length() const {
- CHECK_EQ(HloOpcode::kFft, opcode_);
- return fft_length_;
- }
-
// Returns data on the dimension numbers used for a dot operation.
const DotDimensionNumbers& dot_dimension_numbers() const {
CHECK(dot_dimension_numbers_ != nullptr);
@@ -1371,7 +1292,8 @@ class HloInstruction {
// Clones the HLO instruction as above but with new shape and operands.
std::unique_ptr<HloInstruction> CloneWithNewOperands(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
HloCloneContext* context = nullptr) const;
// Returns the computations this instruction directly calls (if any).
@@ -1412,7 +1334,7 @@ class HloInstruction {
bool IsElementwiseOnOperand(int64 operand_idx) const;
// Returns true if this instruction is elementwise on all its operands.
- bool IsElementwise() const;
+ virtual bool IsElementwise() const;
// Returns true if this elementwise instruction implicitly broadcasts operand
// `operand_idx`.
@@ -1442,9 +1364,14 @@ class HloInstruction {
std::tuple<bool, std::vector<int64>, std::vector<int64>>
ReshapeMerelyInsertsOrDeletes1SizedDimensions() const;
- // Gets/sets the string identifier for this instruction.
+ // Gets the string identifier for this instruction.
const string& name() const { return name_; }
- void set_name(tensorflow::StringPiece name) { name_ = std::string(name); }
+
+ // Sets the string identifier for this instruction. Name will be sanitized to
+ // match the regexp "[a-zA-Z_][a-zA-Z0-9_.-]*".
+ void SetAndSanitizeName(const string& name) {
+ name_ = NameUniquer::GetSanitizedName(name);
+ }
// Use the given NameUniquer to select a unique name for the instruction based
// on the instruction's existing name.
@@ -1531,12 +1458,82 @@ class HloInstruction {
void RelayoutConstant(const Layout& new_layout,
const ShapeIndex& shape_index = {});
+ // Old methods kept for smooth subclassing transition BEGIN.
+ // TODO(b/80131774): Remove this code.
+
+ // Delegates to HloBatchNormInstruction::feature_index.
+ int64 feature_index() const;
+
+ // Delegates to HloBatchNormInstruction::epsilon.
+ float epsilon() const;
+
+ // Delegates to HloFftInstruction::fft_type.
+ FftType fft_type() const;
+
+ // Delegates to HloFftInstruction::fft_length.
+ const std::vector<int64>& fft_length() const;
+
+ // Delegates to HloSendRecvInstruction::channel_id.
+ int64 channel_id() const;
+
+ // Returns the dimension sizes or numbers associated with this instruction.
+ virtual const std::vector<int64>& dimensions() const {
+ LOG(FATAL) << "Unimplemented method.";
+ }
+ virtual int64 dimensions(int64 index) const {
+ LOG(FATAL) << "Unimplemented method.";
+ }
+
+ // Delegates to HloConcatenateInstruction::concatenate_dimension.
+ int64 concatenate_dimension() const;
+
+ // Returns whether this instruction does a rank-2 transposition.
+ bool IsRank2Transpose() const;
+
+ // Delegates to HloSliceInstruction::slice_start.
+ int64 slice_starts(int64 dimension) const;
+ const std::vector<int64>& slice_starts() const;
+
+ // Delegates to HloSliceInstruction::slice_limits.
+ int64 slice_limits(int64 dimension) const;
+ const std::vector<int64>& slice_limits() const;
+
+ // Delegates to HloSliceInstruction::slice_strides.
+ int64 slice_strides(int64 dimension) const;
+ const std::vector<int64>& slice_strides() const;
+
+ // Delegates to HloSliceInstruction::IsInPlaceSlice.
+ bool IsInPlaceSlice() const;
+ // Old methods kept for smooth subclassing transition END.
+
protected:
// Internal constructor for a given opcode/shape, other fields must be filled
// by factory methods.
HloInstruction(HloOpcode opcode, const Shape& shape);
+ // Appends operand to the list of operands and adds this instruction as a user
+ // of the operand.
+ void AppendOperand(HloInstruction* operand);
+
+ void AppendComputation(HloComputation* computation) {
+ called_computations_.push_back(computation);
+ }
+
private:
+ // Implementation for non-common logic of CloneWithNewOperands.
+ virtual std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const {
+ // TODO(b/80131774): This should be pure virtual.
+ LOG(FATAL) << "Unimplemented method.";
+ }
+
+ // Implementation for non-common logic of ExtraAttributesToString.
+ virtual std::vector<string> ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const {
+ return {};
+ }
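+ // Note: subclasses are expected to override the hook above to report their
+ // own attributes; e.g. an FFT instruction contributes strings of the form
+ //   {"fft_type=FFT", "fft_length={16}"}   (illustrative values only).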
// Prints an instruction to a string.
//
// The canonical string representation needs to name operands and instruction
@@ -1561,7 +1558,7 @@ class HloInstruction {
class FusionReusesParamElements;
// See comments on Identical().
- bool IdenticalSlowPath(
+ virtual bool IdenticalSlowPath(
const HloInstruction& other,
const std::function<bool(const HloComputation*, const HloComputation*)>&
eq_computations) const;
@@ -1571,10 +1568,6 @@ class HloInstruction {
const Shape& shape, HloOpcode opcode,
tensorflow::gtl::ArraySlice<HloInstruction*> operands);
- // Appends operand to the list of operands and adds this instruction as a user
- // of the operand.
- void AppendOperand(HloInstruction* operand);
-
// Adds a user for this instruction.
void AddUser(HloInstruction* user);
@@ -1609,10 +1602,6 @@ class HloInstruction {
const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
HloCloneContext* context = nullptr) const;
- // Returns true if this instruction can legally have the dimensions field
- // set. Used for checking precondition of dimensions field accessors.
- bool CanHaveDimensionsField() const;
-
// Returns how this instruction uses elements of its `i`th operand.
UseKind OperandElementUse(int64 i) const;
@@ -1656,10 +1645,6 @@ class HloInstruction {
// Constant index, only present for kGetTupleElement.
int64 tuple_index_ = -1;
- // Dimensions present for some operations that require reshaping or
- // broadcasting, including Reshape, Reduce, ReduceWindow, and Reverse.
- std::vector<int64> dimensions_;
-
// Describes the window in a windowed operation such as convolution.
std::unique_ptr<Window> window_;
@@ -1672,20 +1657,6 @@ class HloInstruction {
std::unique_ptr<GatherDimensionNumbers> gather_dimension_numbers_;
std::vector<int64> gather_window_bounds_;
- // Describes FFT type for an FFT instruction.
- FftType fft_type_ = FftType::FFT;
-
- // Indicates the FFT length for an FFT instruction.
- std::vector<int64> fft_length_;
-
- // Describes the [begin, end) index range for a slice.
- std::vector<int64> slice_starts_;
- std::vector<int64> slice_limits_;
- std::vector<int64> slice_strides_;
-
- // Describes whether the slice can be lowered to an offset into the operand.
- bool is_in_place_slice_ = false;
-
// The bit sizes for a reduce-precision operation.
int32 exponent_bits_ = 0;
int32 mantissa_bits_ = 0;
@@ -1752,18 +1723,6 @@ class HloInstruction {
// Only present for kRng.
RandomDistribution distribution_;
- // A small float number added to the variance to avoid divide-by-zero error.
- // Only present for kBatchNormTraining.
- float epsilon_ = 0.0f;
-
- // An integer value representing the index of the feature dimension.
- // Only present for kBatchNormTraining.
- int64 feature_index_ = -1;
-
- // Represents a unique identifier for each Send/Recv instruction pair.
- // Only present for kSend or kRecv.
- int64 channel_id_ = -1;
-
// The string representation of the infeed configuration.
string infeed_config_;
diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
index 313033ddad..5d6f8b931f 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
@@ -342,7 +342,7 @@ TEST_F(HloInstructionTest, TrivialMap) {
// Builds a parameter and feeds it to the map.
HloComputation::Builder builder(TestName());
auto param0 = builder.AddInstruction(
- HloInstruction::CreateParameter(0, f32a100x10, ""));
+ HloInstruction::CreateParameter(0, f32a100x10, "p"));
auto map = builder.AddInstruction(
HloInstruction::CreateMap(f32a100x10, {param0}, add_f32));
module->AddEntryComputation(builder.Build());
@@ -381,7 +381,7 @@ TEST_F(HloInstructionTest, TrivialReduce) {
// Builds a parameter and an initial value and feeds them to the reduce.
HloComputation::Builder builder(TestName());
auto param0 = builder.AddInstruction(
- HloInstruction::CreateParameter(0, f32a100x10, ""));
+ HloInstruction::CreateParameter(0, f32a100x10, "p"));
auto const0 = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
builder.AddInstruction(
@@ -980,6 +980,23 @@ TEST_F(HloInstructionTest, FullyElementwise) {
}
}
+TEST_F(HloInstructionTest, MapIsElementwise) {
+ auto module = CreateNewModule();
+ const Shape r2f32 = ShapeUtil::MakeShapeWithLayout(F32, {10, 10}, {1, 0});
+ HloComputation::Builder builder(TestName());
+ HloComputation::Builder map_builder("id");
+ map_builder.AddInstruction(
+ HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "p0"));
+ auto map_computation = module->AddEmbeddedComputation(map_builder.Build());
+ auto x =
+ builder.AddInstruction(HloInstruction::CreateParameter(0, r2f32, "x"));
+ auto map = builder.AddInstruction(
+ HloInstruction::CreateMap(r2f32, {x}, map_computation));
+ module->AddEntryComputation(builder.Build());
+
+ EXPECT_TRUE(map->IsElementwise());
+}
+
TEST_F(HloInstructionTest, PartiallyElementwise) {
const Shape r1f32 = ShapeUtil::MakeShape(F32, {5});
const Shape r2f32 = ShapeUtil::MakeShape(F32, {3, 5});
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
new file mode 100644
index 0000000000..56792f8b1b
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -0,0 +1,589 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
+
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
+
+namespace xla {
+
+using ::tensorflow::str_util::Join;
+using ::tensorflow::strings::StrCat;
+
+HloBatchNormInstruction::HloBatchNormInstruction(
+ HloOpcode opcode, const Shape& shape, HloInstruction* operand,
+ HloInstruction* scale, float epsilon, int64 feature_index)
+ : HloInstruction(opcode, shape),
+ epsilon_(epsilon),
+ feature_index_(feature_index) {
+ AppendOperand(operand);
+ AppendOperand(scale);
+}
+
+bool HloBatchNormInstruction::IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const {
+ const auto& casted_other = static_cast<const HloBatchNormInstruction&>(other);
+ return feature_index() == casted_other.feature_index() &&
+ epsilon() == casted_other.epsilon();
+}
+
+HloInstructionProto HloBatchNormInstruction::ToProto() const {
+ HloInstructionProto proto = HloInstruction::ToProto();
+ proto.set_epsilon(epsilon_);
+ proto.set_feature_index(feature_index_);
+ return proto;
+}
+
+std::vector<string> HloBatchNormInstruction::ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const {
+ return {StrCat("epsilon=", epsilon()),
+ StrCat("feature_index=", feature_index())};
+}
+
+HloBatchNormTrainingInstruction::HloBatchNormTrainingInstruction(
+ const Shape& shape, HloInstruction* operand, HloInstruction* scale,
+ HloInstruction* offset, float epsilon, int64 feature_index)
+ : HloBatchNormInstruction(HloOpcode::kBatchNormTraining, shape, operand,
+ scale, epsilon, feature_index) {
+ AppendOperand(offset);
+}
+
+std::unique_ptr<HloInstruction>
+HloBatchNormTrainingInstruction::CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const {
+ CHECK_EQ(new_operands.size(), 3);
+ return MakeUnique<HloBatchNormTrainingInstruction>(
+ shape, new_operands[0], new_operands[1], new_operands[2], epsilon(),
+ feature_index());
+}
+
+HloBatchNormInferenceInstruction::HloBatchNormInferenceInstruction(
+ const Shape& shape, HloInstruction* operand, HloInstruction* scale,
+ HloInstruction* offset, HloInstruction* mean, HloInstruction* variance,
+ float epsilon, int64 feature_index)
+ : HloBatchNormInstruction(HloOpcode::kBatchNormInference, shape, operand,
+ scale, epsilon, feature_index) {
+ AppendOperand(offset);
+ AppendOperand(mean);
+ AppendOperand(variance);
+}
+
+std::unique_ptr<HloInstruction>
+HloBatchNormInferenceInstruction::CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const {
+ CHECK_EQ(new_operands.size(), 5);
+ return MakeUnique<HloBatchNormInferenceInstruction>(
+ shape, new_operands[0], new_operands[1], new_operands[2], new_operands[3],
+ new_operands[4], epsilon(), feature_index());
+}
+
+HloBatchNormGradInstruction::HloBatchNormGradInstruction(
+ const Shape& shape, HloInstruction* operand, HloInstruction* scale,
+ HloInstruction* mean, HloInstruction* variance, HloInstruction* grad_output,
+ float epsilon, int64 feature_index)
+ : HloBatchNormInstruction(HloOpcode::kBatchNormGrad, shape, operand, scale,
+ epsilon, feature_index) {
+ AppendOperand(mean);
+ AppendOperand(variance);
+ AppendOperand(grad_output);
+}
+
+std::unique_ptr<HloInstruction>
+HloBatchNormGradInstruction::CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const {
+ CHECK_EQ(new_operands.size(), 5);
+ return MakeUnique<HloBatchNormGradInstruction>(
+ shape, new_operands[0], new_operands[1], new_operands[2], new_operands[3],
+ new_operands[4], epsilon(), feature_index());
+}
+
+HloFftInstruction::HloFftInstruction(
+ const Shape& shape, HloInstruction* operand, FftType fft_type,
+ tensorflow::gtl::ArraySlice<int64> fft_length)
+ : HloInstruction(HloOpcode::kFft, shape), fft_type_(fft_type) {
+ fft_length_.assign(fft_length.begin(), fft_length.end());
+ AppendOperand(operand);
+}
+
+HloInstructionProto HloFftInstruction::ToProto() const {
+ HloInstructionProto proto = HloInstruction::ToProto();
+ proto.set_fft_type(fft_type_);
+ for (int64 fft_len : fft_length_) {
+ proto.add_fft_length(fft_len);
+ }
+ return proto;
+}
+
+std::vector<string> HloFftInstruction::ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const {
+ return {StrCat("fft_type=", FftType_Name(fft_type())),
+ StrCat("fft_length={", Join(fft_length(), ","), "}")};
+}
+
+bool HloFftInstruction::IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const {
+ const auto& casted_other = static_cast<const HloFftInstruction&>(other);
+ return fft_type() == casted_other.fft_type() &&
+ fft_length() == casted_other.fft_length();
+}
+
+std::unique_ptr<HloInstruction> HloFftInstruction::CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const {
+ CHECK_EQ(new_operands.size(), 1);
+ return MakeUnique<HloFftInstruction>(shape, new_operands[0], fft_type_,
+ fft_length_);
+}
+
+HloSendRecvInstruction::HloSendRecvInstruction(HloOpcode opcode,
+ const Shape& shape,
+ int64 channel_id)
+ : HloInstruction(opcode, shape), channel_id_(channel_id) {}
+
+HloInstructionProto HloSendRecvInstruction::ToProto() const {
+ HloInstructionProto proto = HloInstruction::ToProto();
+ proto.set_channel_id(channel_id_);
+ return proto;
+}
+
+std::vector<string> HloSendRecvInstruction::ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const {
+ return {StrCat("channel_id=", channel_id_)};
+}
+
+bool HloSendRecvInstruction::IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const {
+ // Not yet supported.
+ return false;
+}
+
+// Send instruction produces a tuple of {aliased operand, U32 context}.
+HloSendInstruction::HloSendInstruction(HloInstruction* operand,
+ int64 channel_id)
+ : HloSendRecvInstruction(
+ HloOpcode::kSend,
+ ShapeUtil::MakeTupleShape(
+ {CHECK_NOTNULL(operand)->shape(), ShapeUtil::MakeShape(U32, {})}),
+ channel_id) {
+ AppendOperand(operand);
+}
+
+std::unique_ptr<HloInstruction> HloSendInstruction::CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const {
+ CHECK_EQ(new_operands.size(), 1);
+ return MakeUnique<HloSendInstruction>(new_operands[0], channel_id());
+}
+
+HloSendDoneInstruction::HloSendDoneInstruction(HloSendInstruction* operand)
+ : HloSendRecvInstruction(HloOpcode::kSendDone, ShapeUtil::MakeNil(),
+ CHECK_NOTNULL(operand)->channel_id()) {
+ AppendOperand(operand);
+}
+
+std::unique_ptr<HloInstruction>
+HloSendDoneInstruction::CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const {
+ CHECK_EQ(new_operands.size(), 1);
+ return MakeUnique<HloSendDoneInstruction>(
+ Cast<HloSendInstruction>(new_operands[0]));
+}
+
+// Recv instruction produces a tuple of {receive buffer, U32 context}.
+HloRecvInstruction::HloRecvInstruction(const Shape& shape, int64 channel_id)
+ : HloSendRecvInstruction(
+ HloOpcode::kRecv,
+ ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {})}),
+ channel_id) {}
+
+std::unique_ptr<HloInstruction> HloRecvInstruction::CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const {
+ CHECK_EQ(new_operands.size(), 0);
+ return MakeUnique<HloRecvInstruction>(
+ ShapeUtil::GetTupleElementShape(shape, 0), channel_id());
+}
+
+HloRecvDoneInstruction::HloRecvDoneInstruction(HloRecvInstruction* operand)
+ : HloSendRecvInstruction(
+ HloOpcode::kRecvDone,
+ ShapeUtil::GetTupleElementShape(operand->shape(), 0),
+ CHECK_NOTNULL(operand)->channel_id()) {
+ AppendOperand(operand);
+}
+
+std::unique_ptr<HloInstruction>
+HloRecvDoneInstruction::CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const {
+ CHECK_EQ(new_operands.size(), 1);
+ return MakeUnique<HloRecvDoneInstruction>(
+ Cast<HloRecvInstruction>(new_operands[0]));
+}
+
+HloReverseInstruction::HloReverseInstruction(
+ const Shape& shape, HloInstruction* operand,
+ tensorflow::gtl::ArraySlice<int64> dimensions)
+ : HloInstruction(HloOpcode::kReverse, shape),
+ dimensions_(dimensions.begin(), dimensions.end()) {
+ AppendOperand(operand);
+}
+
+HloInstructionProto HloReverseInstruction::ToProto() const {
+ HloInstructionProto proto = HloInstruction::ToProto();
+ for (int64 dimension : dimensions_) {
+ proto.add_dimensions(dimension);
+ }
+ return proto;
+}
+
+std::vector<string> HloReverseInstruction::ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const {
+ return {StrCat("dimensions={", Join(dimensions(), ","), "}")};
+}
+
+bool HloReverseInstruction::IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const {
+ const auto& casted_other = static_cast<const HloReverseInstruction&>(other);
+ return dimensions() == casted_other.dimensions();
+}
+
+std::unique_ptr<HloInstruction> HloReverseInstruction::CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const {
+ CHECK_EQ(new_operands.size(), 1);
+ return MakeUnique<HloReverseInstruction>(shape, new_operands[0],
+ dimensions());
+}
+
+HloConcatenateInstruction::HloConcatenateInstruction(
+ const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ int64 dimension)
+ : HloInstruction(HloOpcode::kConcatenate, shape), dimensions_({dimension}) {
+ for (auto operand : operands) {
+ AppendOperand(operand);
+ }
+}
+
+HloInstructionProto HloConcatenateInstruction::ToProto() const {
+ HloInstructionProto proto = HloInstruction::ToProto();
+ for (int64 dimension : dimensions_) {
+ proto.add_dimensions(dimension);
+ }
+ return proto;
+}
+
+std::vector<string> HloConcatenateInstruction::ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const {
+ return {StrCat("dimensions={", Join(dimensions(), ","), "}")};
+}
+
+bool HloConcatenateInstruction::IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const {
+ const auto& casted_other =
+ static_cast<const HloConcatenateInstruction&>(other);
+ return dimensions() == casted_other.dimensions();
+}
+
+std::unique_ptr<HloInstruction>
+HloConcatenateInstruction::CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const {
+ return MakeUnique<HloConcatenateInstruction>(shape, new_operands,
+ dimensions(0));
+}
+
+HloReduceInstruction::HloReduceInstruction(
+ const Shape& shape, HloInstruction* arg, HloInstruction* init_value,
+ tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
+ HloComputation* reduce_computation)
+ : HloInstruction(HloOpcode::kReduce, shape),
+ dimensions_(dimensions_to_reduce.begin(), dimensions_to_reduce.end()) {
+ AppendOperand(arg);
+ AppendOperand(init_value);
+ AppendComputation(reduce_computation);
+}
+
+HloInstructionProto HloReduceInstruction::ToProto() const {
+ HloInstructionProto proto = HloInstruction::ToProto();
+ for (int64 dimension : dimensions_) {
+ proto.add_dimensions(dimension);
+ }
+ return proto;
+}
+
+std::vector<string> HloReduceInstruction::ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const {
+ return {StrCat("dimensions={", Join(dimensions(), ","), "}")};
+}
+
+bool HloReduceInstruction::IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const {
+ const auto& casted_other = static_cast<const HloReduceInstruction&>(other);
+ // Reduction results are determined by the reduction dimension and the
+ // reduction computation.
+ return dimensions() == casted_other.dimensions() &&
+ eq_computations(to_apply(), casted_other.to_apply());
+}
+
+std::unique_ptr<HloInstruction> HloReduceInstruction::CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const {
+ CHECK_EQ(new_operands.size(), 2);
+ return MakeUnique<HloReduceInstruction>(
+ shape, new_operands[0], new_operands[1], dimensions(), to_apply());
+}
+
+HloTransposeInstruction::HloTransposeInstruction(
+ const Shape& shape, HloInstruction* operand,
+ tensorflow::gtl::ArraySlice<int64> dimensions)
+ : HloInstruction(HloOpcode::kTranspose, shape),
+ dimensions_(dimensions.begin(), dimensions.end()) {
+ CHECK_EQ(shape.dimensions().size(), dimensions.size());
+ CHECK_EQ(shape.dimensions().size(), operand->shape().dimensions().size());
+ CHECK(std::equal(operand->shape().dimensions().begin(),
+ operand->shape().dimensions().end(),
+ Permute(dimensions, shape.dimensions()).begin()))
+ << "shape: " << ShapeUtil::HumanString(shape)
+ << ", operand->shape(): " << ShapeUtil::HumanString(shape)
+ << ", dimensions: {" << Join(dimensions, ", ") << "}";
+ AppendOperand(operand);
+}
+
+bool HloTransposeInstruction::IsRank2Transpose() const {
+ return dimensions() == std::vector<int64>({1, 0}) &&
+ shape().dimensions_size() == 2 &&
+ std::equal(shape().dimensions().begin(), shape().dimensions().end(),
+ operand(0)->shape().dimensions().rbegin());
+}
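+
+// Illustrative example: transposing an f32[3,5] operand with dimensions={1,0}
+// into an f32[5,3] result satisfies IsRank2Transpose(); any other permutation
+// or rank does not.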
+
+HloInstructionProto HloTransposeInstruction::ToProto() const {
+ HloInstructionProto proto = HloInstruction::ToProto();
+ for (int64 dimension : dimensions_) {
+ proto.add_dimensions(dimension);
+ }
+ return proto;
+}
+
+std::vector<string> HloTransposeInstruction::ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const {
+ return {StrCat("dimensions={", Join(dimensions(), ","), "}")};
+}
+
+bool HloTransposeInstruction::IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const {
+ const auto& casted_other = static_cast<const HloTransposeInstruction&>(other);
+ return dimensions() == casted_other.dimensions();
+}
+
+std::unique_ptr<HloInstruction>
+HloTransposeInstruction::CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const {
+ CHECK_EQ(new_operands.size(), 1);
+ return MakeUnique<HloTransposeInstruction>(shape, new_operands[0],
+ dimensions());
+}
+
+HloBroadcastInstruction::HloBroadcastInstruction(
+ const Shape& shape, HloInstruction* operand,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimension)
+ : HloInstruction(HloOpcode::kBroadcast, shape),
+ dimensions_(broadcast_dimension.begin(), broadcast_dimension.end()) {
+ AppendOperand(operand);
+}
+
+HloInstructionProto HloBroadcastInstruction::ToProto() const {
+ HloInstructionProto proto = HloInstruction::ToProto();
+ for (int64 dimension : dimensions_) {
+ proto.add_dimensions(dimension);
+ }
+ return proto;
+}
+
+std::vector<string> HloBroadcastInstruction::ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const {
+ return {StrCat("dimensions={", Join(dimensions(), ","), "}")};
+}
+
+bool HloBroadcastInstruction::IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const {
+ const auto& casted_other = static_cast<const HloBroadcastInstruction&>(other);
+ return dimensions() == casted_other.dimensions();
+}
+
+std::unique_ptr<HloInstruction>
+HloBroadcastInstruction::CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const {
+ CHECK_EQ(new_operands.size(), 1);
+ return MakeUnique<HloBroadcastInstruction>(shape, new_operands[0],
+ dimensions());
+}
+
+HloMapInstruction::HloMapInstruction(
+ const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ HloComputation* map_computation,
+ tensorflow::gtl::ArraySlice<HloInstruction*> static_operands)
+ : HloInstruction(HloOpcode::kMap, shape) {
+ CHECK(static_operands.empty()) << "static_operands not yet supported";
+ for (auto operand : operands) {
+ AppendOperand(operand);
+ }
+ AppendComputation(map_computation);
+ // TODO(b/65689298) Remove code below once Map is generalized to accept
+ // arbitrary map dimensions.
+ dimensions_.resize(ShapeUtil::Rank(shape));
+ std::iota(dimensions_.begin(), dimensions_.end(), 0);
+}
+
+HloInstructionProto HloMapInstruction::ToProto() const {
+ HloInstructionProto proto = HloInstruction::ToProto();
+ for (int64 dimension : dimensions_) {
+ proto.add_dimensions(dimension);
+ }
+ return proto;
+}
+
+bool HloMapInstruction::IsElementwise() const {
+ if (!dimensions().empty()) {
+ // Check that the map is executed in elementwise compatible dimensions.
+ if (dimensions().size() != shape().dimensions_size()) {
+ return false;
+ }
+ for (int i = 0; i < dimensions().size(); ++i) {
+ if (dimensions()[i] != i) {
+ return false;
+ }
+ }
+ }
+ return true;
+}
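+
+// Sketch of the intent: the constructor above fills dimensions_ with
+// 0..rank-1, so a map over an f32[10,10] operand has dimensions={0,1} and is
+// reported as elementwise; a map over only a strict subset of the output
+// dimensions would not be.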
+
+std::vector<string> HloMapInstruction::ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const {
+ return {StrCat("dimensions={", Join(dimensions(), ","), "}")};
+}
+
+bool HloMapInstruction::IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const {
+ return eq_computations(to_apply(), other.to_apply());
+}
+
+std::unique_ptr<HloInstruction> HloMapInstruction::CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const {
+ return MakeUnique<HloMapInstruction>(shape, new_operands, to_apply());
+}
+
+HloSliceInstruction::HloSliceInstruction(
+ const Shape& shape, HloInstruction* operand,
+ tensorflow::gtl::ArraySlice<int64> start_indices,
+ tensorflow::gtl::ArraySlice<int64> limit_indices,
+ tensorflow::gtl::ArraySlice<int64> strides)
+ : HloInstruction(HloOpcode::kSlice, shape),
+ slice_starts_(start_indices.begin(), start_indices.end()),
+ slice_limits_(limit_indices.begin(), limit_indices.end()),
+ slice_strides_(strides.begin(), strides.end()) {
+ AppendOperand(operand);
+ // For backward compatibility with old serialized computations: if there are
+ // no strides, assume all strides are 1.
+ // TODO(b/63317920): remove this code.
+ if (slice_strides_.empty()) {
+ slice_strides_ = std::vector<int64>(start_indices.size(), 1LL);
+ }
+}
+
+HloInstructionProto HloSliceInstruction::ToProto() const {
+ HloInstructionProto proto = HloInstruction::ToProto();
+ for (int i = 0; i < slice_starts_.size(); ++i) {
+ auto* slice_dimension = proto.add_slice_dimensions();
+ slice_dimension->set_start(slice_starts_[i]);
+ slice_dimension->set_limit(slice_limits_[i]);
+ slice_dimension->set_stride(slice_strides_[i]);
+ }
+ return proto;
+}
+
+std::vector<string> HloSliceInstruction::ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const {
+ std::vector<string> bounds;
+ bounds.reserve(slice_starts_.size());
+ const bool omit_stride =
+ std::all_of(slice_strides_.begin(), slice_strides_.end(),
+ [](int64 stride) { return stride == 1; });
+ for (int i = 0; i < slice_starts_.size(); ++i) {
+ string stride_str = omit_stride ? "" : StrCat(":", slice_strides_[i]);
+ bounds.push_back(
+ StrCat("[", slice_starts_[i], ":", slice_limits_[i], stride_str, "]"));
+ }
+ return {StrCat("slice={", Join(bounds, ", "), "}")};
+}
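+
+// Example of the printed form (illustrative values): starts={0,2},
+// limits={10,8}, strides={1,2} renders as slice={[0:10:1], [2:8:2]}; the
+// ":stride" suffix is omitted only when every stride equals 1.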
+
+bool HloSliceInstruction::IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const {
+ const auto& other_slice = static_cast<const HloSliceInstruction&>(other);
+ return slice_starts_ == other_slice.slice_starts_ &&
+ slice_limits_ == other_slice.slice_limits_ &&
+ slice_strides_ == other_slice.slice_strides_;
+}
+
+std::unique_ptr<HloInstruction> HloSliceInstruction::CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const {
+ CHECK_EQ(new_operands.size(), 1);
+ return MakeUnique<HloSliceInstruction>(shape, new_operands[0], slice_starts_,
+ slice_limits_, slice_strides_);
+}
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h
new file mode 100644
index 0000000000..18e786d8b6
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_instructions.h
@@ -0,0 +1,438 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// All HloInstruction subclasses are put in this file.
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_
+
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+
+namespace xla {
+
+class HloBatchNormInstruction : public HloInstruction {
+ public:
+ // Returns the feature_index field associated with the instruction, i.e. the
+ // index of the feature dimension.
+ int64 feature_index() const { return feature_index_; }
+
+ // Returns the epsilon value associated with the instruction. This is a small
+ // number added to the variance to avoid a divide-by-zero error.
+ float epsilon() const { return epsilon_; }
+
+ // Returns a serialized representation of this instruction.
+ HloInstructionProto ToProto() const override;
+
+ protected:
+ explicit HloBatchNormInstruction(HloOpcode opcode, const Shape& shape,
+ HloInstruction* operand,
+ HloInstruction* scale, float epsilon,
+ int64 feature_index);
+
+ private:
+ std::vector<string> ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const override;
+ bool IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const override;
+ // A small float value added to the variance to avoid a divide-by-zero error.
+ float epsilon_ = 0.0f;
+
+ // An integer value representing the index of the feature dimension.
+ int64 feature_index_ = -1;
+};
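+
+// Note: the three batch-norm subclasses below share epsilon/feature_index
+// through this base class and differ only in their operand lists (3, 5, and 5
+// operands for training, inference, and grad respectively).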
+
+class HloBatchNormTrainingInstruction : public HloBatchNormInstruction {
+ public:
+ explicit HloBatchNormTrainingInstruction(const Shape& shape,
+ HloInstruction* operand,
+ HloInstruction* scale,
+ HloInstruction* offset,
+ float epsilon, int64 feature_index);
+
+ private:
+ // Implementation for non-common logic of CloneWithNewOperands.
+ std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const override;
+};
+
+class HloBatchNormInferenceInstruction : public HloBatchNormInstruction {
+ public:
+ explicit HloBatchNormInferenceInstruction(
+ const Shape& shape, HloInstruction* operand, HloInstruction* scale,
+ HloInstruction* offset, HloInstruction* mean, HloInstruction* variance,
+ float epsilon, int64 feature_index);
+
+ private:
+ // Implementation for non-common logic of CloneWithNewOperands.
+ std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const override;
+};
+
+class HloBatchNormGradInstruction : public HloBatchNormInstruction {
+ public:
+ explicit HloBatchNormGradInstruction(
+ const Shape& shape, HloInstruction* operand, HloInstruction* scale,
+ HloInstruction* mean, HloInstruction* variance,
+ HloInstruction* grad_output, float epsilon, int64 feature_index);
+
+ private:
+ // Implementation for non-common logic of CloneWithNewOperands.
+ std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const override;
+};
+
+class HloFftInstruction : public HloInstruction {
+ public:
+ explicit HloFftInstruction(const Shape& shape, HloInstruction* operand,
+ FftType fft_type,
+ tensorflow::gtl::ArraySlice<int64> fft_length);
+ FftType fft_type() const { return fft_type_; }
+
+ const std::vector<int64>& fft_length() const { return fft_length_; }
+
+ // Returns a serialized representation of this instruction.
+ HloInstructionProto ToProto() const override;
+
+ private:
+ std::vector<string> ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const override;
+ bool IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const override;
+
+ // Implementation for non-common logic of CloneWithNewOperands.
+ std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const override;
+
+ // Describes FFT type for an FFT instruction.
+ FftType fft_type_ = FftType::FFT;
+
+ // Indicates the FFT length for an FFT instruction.
+ std::vector<int64> fft_length_;
+};
+
+class HloSendRecvInstruction : public HloInstruction {
+ public:
+ // Returns the channel id associated with the instruction. The id is
+ // shared between each Send/Recv pair and is globally unique to identify each
+ // channel.
+ int64 channel_id() const { return channel_id_; }
+
+ // Returns a serialized representation of this instruction.
+ HloInstructionProto ToProto() const override;
+
+ protected:
+ explicit HloSendRecvInstruction(HloOpcode opcode, const Shape& shape,
+ int64 channel_id);
+
+ private:
+ std::vector<string> ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const override;
+ bool IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const override;
+ // Represents a unique identifier for each Send/Recv instruction pair.
+ int64 channel_id_;
+};
+
+class HloSendInstruction : public HloSendRecvInstruction {
+ public:
+ explicit HloSendInstruction(HloInstruction* operand, int64 channel_id);
+
+ private:
+ // Implementation for non-common logic of CloneWithNewOperands.
+ std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const override;
+};
+
+class HloSendDoneInstruction : public HloSendRecvInstruction {
+ public:
+ explicit HloSendDoneInstruction(HloSendInstruction* operand);
+
+ private:
+ // Implementation for non-common logic of CloneWithNewOperands.
+ std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const override;
+};
+
+class HloRecvInstruction : public HloSendRecvInstruction {
+ public:
+ explicit HloRecvInstruction(const Shape& shape, int64 channel_id);
+
+ private:
+ // Implementation for non-common logic of CloneWithNewOperands.
+ std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const override;
+};
+
+class HloRecvDoneInstruction : public HloSendRecvInstruction {
+ public:
+ explicit HloRecvDoneInstruction(HloRecvInstruction* operand);
+
+ private:
+ // Implementation for non-common logic of CloneWithNewOperands.
+ std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const override;
+};
+
+class HloReverseInstruction : public HloInstruction {
+ public:
+ explicit HloReverseInstruction(const Shape& shape, HloInstruction* operand,
+ tensorflow::gtl::ArraySlice<int64> dimensions);
+ // Returns the dimension sizes or numbers associated with this instruction.
+ const std::vector<int64>& dimensions() const override { return dimensions_; }
+ int64 dimensions(int64 index) const override { return dimensions()[index]; }
+ // Returns a serialized representation of this instruction.
+ HloInstructionProto ToProto() const override;
+
+ private:
+ std::vector<string> ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const override;
+ bool IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const override;
+ // Implementation for non-common logic of CloneWithNewOperands.
+ std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const override;
+
+ std::vector<int64> dimensions_;
+};
+
+class HloConcatenateInstruction : public HloInstruction {
+ public:
+ explicit HloConcatenateInstruction(
+ const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ int64 dimension);
+ // Returns the dimension sizes or numbers associated with this instruction.
+ const std::vector<int64>& dimensions() const override { return dimensions_; }
+ int64 dimensions(int64 index) const override { return dimensions()[index]; }
+ // Accessor for the dimension in which a concatenate HLO should occur.
+ int64 concatenate_dimension() const { return dimensions(0); }
+ // Returns a serialized representation of this instruction.
+ HloInstructionProto ToProto() const override;
+
+ private:
+ std::vector<string> ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const override;
+ bool IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const override;
+ // Implementation for non-common logic of CloneWithNewOperands.
+ std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const override;
+
+ std::vector<int64> dimensions_;
+};
+
+class HloReduceInstruction : public HloInstruction {
+ public:
+ explicit HloReduceInstruction(
+ const Shape& shape, HloInstruction* arg, HloInstruction* init_value,
+ tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
+ HloComputation* reduce_computation);
+ // Returns the dimension sizes or numbers associated with this instruction.
+ const std::vector<int64>& dimensions() const override { return dimensions_; }
+ int64 dimensions(int64 index) const override { return dimensions()[index]; }
+ // Returns a serialized representation of this instruction.
+ HloInstructionProto ToProto() const override;
+
+ private:
+ std::vector<string> ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const override;
+ bool IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const override;
+ // Implementation for non-common logic of CloneWithNewOperands.
+ std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const override;
+
+ std::vector<int64> dimensions_;
+};
+
+class HloTransposeInstruction : public HloInstruction {
+ public:
+ explicit HloTransposeInstruction(
+ const Shape& shape, HloInstruction* operand,
+ tensorflow::gtl::ArraySlice<int64> dimensions);
+ // Returns the dimension sizes or numbers associated with this instruction.
+ const std::vector<int64>& dimensions() const override { return dimensions_; }
+ int64 dimensions(int64 index) const override { return dimensions()[index]; }
+ // Returns whether this instruction does a rank-2 transposition.
+ bool IsRank2Transpose() const;
+ // Returns a serialized representation of this instruction.
+ HloInstructionProto ToProto() const override;
+
+ private:
+ std::vector<string> ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const override;
+ bool IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const override;
+ // Implementation for non-common logic of CloneWithNewOperands.
+ std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const override;
+
+ std::vector<int64> dimensions_;
+};
+
+class HloBroadcastInstruction : public HloInstruction {
+ public:
+ explicit HloBroadcastInstruction(
+ const Shape& shape, HloInstruction* operand,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimension);
+ // Returns the dimension sizes or numbers associated with this instruction.
+ const std::vector<int64>& dimensions() const override { return dimensions_; }
+ int64 dimensions(int64 index) const override { return dimensions()[index]; }
+ // Returns a serialized representation of this instruction.
+ HloInstructionProto ToProto() const override;
+
+ private:
+ std::vector<string> ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const override;
+ bool IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const override;
+ // Implementation for non-common logic of CloneWithNewOperands.
+ std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const override;
+
+ std::vector<int64> dimensions_;
+};
+
+class HloMapInstruction : public HloInstruction {
+ public:
+ explicit HloMapInstruction(
+ const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ HloComputation* map_computation,
+ tensorflow::gtl::ArraySlice<HloInstruction*> static_operands = {});
+ // Returns the dimension sizes or numbers associated with this instruction.
+ const std::vector<int64>& dimensions() const override { return dimensions_; }
+ int64 dimensions(int64 index) const override { return dimensions()[index]; }
+ // Returns a serialized representation of this instruction.
+ HloInstructionProto ToProto() const override;
+
+ // Returns true if this instruction is elementwise on all of its operands.
+ bool IsElementwise() const override;
+
+ private:
+ std::vector<string> ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const override;
+ bool IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const override;
+ // Implementation for non-common logic of CloneWithNewOperands.
+ std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const override;
+
+ std::vector<int64> dimensions_;
+};
+
+class HloSliceInstruction : public HloInstruction {
+ public:
+ explicit HloSliceInstruction(const Shape& shape, HloInstruction* operand,
+ tensorflow::gtl::ArraySlice<int64> start_indices,
+ tensorflow::gtl::ArraySlice<int64> limit_indices,
+ tensorflow::gtl::ArraySlice<int64> strides);
+
+ HloInstructionProto ToProto() const override;
+
+ // Returns the start index in the given dimension for a slice node.
+ int64 slice_starts(int64 dimension) const { return slice_starts_[dimension]; }
+ const std::vector<int64>& slice_starts() const { return slice_starts_; }
+
+ // Returns the (exclusive) limit index in the given dimension for a slice
+ // node.
+ int64 slice_limits(int64 dimension) const { return slice_limits_[dimension]; }
+ const std::vector<int64>& slice_limits() const { return slice_limits_; }
+
+ // Returns the stride in the given dimension for a slice node.
+ int64 slice_strides(int64 dimension) const {
+ return slice_strides_[dimension];
+ }
+ const std::vector<int64>& slice_strides() const { return slice_strides_; }
+
+ // Returns the flag that describes whether a slice must be lowered into an
+ // offset into the original operand.
+ bool IsInPlaceSlice() const { return is_in_place_slice_; }
+
+ // Sets and returns the flag that describes whether a slice must be lowered
+ // into an offset into the original operand.
+ bool SetIsInPlaceSlice(bool value) {
+ is_in_place_slice_ = value;
+ return value;
+ }
+
+ private:
+ std::vector<string> ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const override;
+ bool IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const override;
+ // Implementation for non-common logic of CloneWithNewOperands.
+ std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const override;
+
+ // Describes the [begin, end) index range for a slice.
+ std::vector<int64> slice_starts_;
+ std::vector<int64> slice_limits_;
+ std::vector<int64> slice_strides_;
+
+ // Describes whether the slice can be lowered to an offset into the operand.
+ bool is_in_place_slice_ = false;
+};
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_
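The HloSliceInstruction introduced above stores per-dimension start, (exclusive) limit, and stride values. As a standalone illustration (not part of the patch), the output extent of such a half-open strided range is ceil((limit - start) / stride):

    #include <cstdint>
    #include <iostream>
    #include <vector>

    // Extent of the half-open strided range [start, limit) with the given stride.
    int64_t SliceExtent(int64_t start, int64_t limit, int64_t stride) {
      return (limit - start + stride - 1) / stride;  // ceiling division
    }

    int main() {
      const std::vector<int64_t> starts = {0, 2};
      const std::vector<int64_t> limits = {4, 10};
      const std::vector<int64_t> strides = {1, 3};
      for (size_t d = 0; d < starts.size(); ++d) {
        std::cout << "dim " << d << ": extent "
                  << SliceExtent(starts[d], limits[d], strides[d]) << "\n";
      }
      // dim 0: extent 4 (indices 0..3); dim 1: extent 3 (indices 2, 5, 8).
      return 0;
    }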
diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc
index e63424c2df..9c59374b4a 100644
--- a/tensorflow/compiler/xla/service/hlo_module.cc
+++ b/tensorflow/compiler/xla/service/hlo_module.cc
@@ -32,15 +32,6 @@ limitations under the License.
namespace xla {
-HloModule::HloModule(const string& name,
- const VersionedComputationHandle& entry_computation_handle,
- const HloModuleConfig& config)
- : name_(NameUniquer::GetSanitizedName(name)),
- config_(config),
- has_entry_computation_handle_(true),
- entry_computation_handle_(entry_computation_handle),
- unique_id_(next_unique_module_id_++) {}
-
HloModule::HloModule(const string& name, const HloModuleConfig& config)
: name_(NameUniquer::GetSanitizedName(name)),
config_(config),
@@ -234,8 +225,7 @@ HloModuleProto HloModule::ToProto() const {
/* static */
StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
- const HloModuleProto& proto, const HloModuleConfig& module_config,
- const VersionedComputationHandle& entry_computation_handle) {
+ const HloModuleProto& proto, const HloModuleConfig& module_config) {
// The ProgramShape in the passed in module config must match the shapes of
// the entry parameters and root.
TF_RET_CHECK(proto.has_program_shape())
@@ -287,8 +277,7 @@ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
}
TF_RET_CHECK(entry != nullptr);
- auto module = MakeUnique<HloModule>(proto.name(), entry_computation_handle,
- module_config);
+ auto module = MakeUnique<HloModule>(proto.name(), module_config);
// Sort the computations in the proto id's order.
std::sort(computations.begin(), computations.end(),
@@ -401,7 +390,7 @@ HloInstruction* HloModule::OutlineExpressionFromComputation(
// as a parameter in the new function.
arguments.push_back(old_operand);
*operand_slot = builder.AddInstruction(HloInstruction::CreateParameter(
- parameter_count, old_operand->shape(), ""));
+ parameter_count, old_operand->shape(), "p"));
++parameter_count;
}
TF_CHECK_OK(
@@ -525,8 +514,6 @@ std::vector<HloComputation*> HloModule::MakeNonfusionComputations() const {
std::unique_ptr<HloModule> HloModule::Clone(const string& suffix) const {
VLOG(1) << "Cloning module :" << name_ << " --> " << suffix << "\n";
auto module = MakeUnique<HloModule>(name_ + "-" + suffix, config_);
- module->entry_computation_handle_ = entry_computation_handle_;
- module->has_entry_computation_handle_ = has_entry_computation_handle_;
HloCloneContext context(module.get(), suffix);
auto cloned_computation = entry_computation_->Clone(suffix, &context);
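With the VersionedComputationHandle parameter gone, deserializing a module now needs only the proto and the module config. A hedged sketch of an updated call site, assuming the in-tree headers; the wrapper name is illustrative:

    #include <memory>

    #include "tensorflow/compiler/xla/service/hlo_module.h"

    // Illustrative wrapper around the new two-argument CreateFromProto.
    xla::StatusOr<std::unique_ptr<xla::HloModule>> DeserializeModule(
        const xla::HloModuleProto& proto, const xla::HloModuleConfig& config) {
      // No VersionedComputationHandle is threaded through anymore.
      return xla::HloModule::CreateFromProto(proto, config);
    }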
diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h
index c93c74d34a..757e65bda2 100644
--- a/tensorflow/compiler/xla/service/hlo_module.h
+++ b/tensorflow/compiler/xla/service/hlo_module.h
@@ -31,7 +31,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
-#include "tensorflow/compiler/xla/service/versioned_computation_handle.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
@@ -57,10 +56,6 @@ namespace xla {
// attached to.
class HloModule {
public:
- HloModule(const string& name,
- const VersionedComputationHandle& entry_computation_handle,
- const HloModuleConfig& config);
-
// Constructor without a versioned computation handle. This constructor should
// only be used for HloModules used outside of the XLA service (eg
// tests). The versioned handle is used by the service in the compilation
@@ -126,10 +121,6 @@ class HloModule {
return config_.device_entry_computation_layout();
}
- const VersionedComputationHandle& entry_computation_handle() const {
- return entry_computation_handle_;
- }
-
// Gets the computations in this module.
//
// Returns a view of HloComputation*s, so you can iterate over this in the
@@ -188,9 +179,7 @@ class HloModule {
// Convert an HloModule to or from a proto.
HloModuleProto ToProto() const;
static StatusOr<std::unique_ptr<HloModule>> CreateFromProto(
- const HloModuleProto& proto, const HloModuleConfig& module_config,
- const VersionedComputationHandle& entry_computation_handle =
- VersionedComputationHandle());
+ const HloModuleProto& proto, const HloModuleConfig& module_config);
// Creates and returns an HloModuleConfig with an appropriate program shape
// for the HLO module in the given proto.
@@ -264,10 +253,6 @@ class HloModule {
mutable std::mt19937_64 rng_{42};
mutable tensorflow::mutex rng_mutex_;
- // Versioned handle of the entry computation of the module.
- bool has_entry_computation_handle_ = false;
- VersionedComputationHandle entry_computation_handle_;
-
// Unique name generator for computation and instruction names, which are
// unique per module.
NameUniquer computation_name_uniquer_{/*separator=*/"."};
diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h
index 1fe06ee0c0..a35546f5f4 100644
--- a/tensorflow/compiler/xla/service/hlo_opcode.h
+++ b/tensorflow/compiler/xla/service/hlo_opcode.h
@@ -81,6 +81,7 @@ namespace xla {
V(kFusion, "fusion", kHloOpcodeIsVariadic) \
V(kGather, "gather") \
V(kGe, "greater-than-or-equal-to", kHloOpcodeIsComparison) \
+ V(kGenerateToken, "generate-token", kHloOpcodeIsVariadic) \
V(kGetTupleElement, "get-tuple-element") \
V(kGt, "greater-than", kHloOpcodeIsComparison) \
V(kHostCompute, "host-compute") \
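The opcode list above is an X-macro: each V(...) entry is expanded by different macros into the enum value, its string name, and optional property flags such as kHloOpcodeIsVariadic. A standalone, simplified sketch of that pattern (the real HLO_OPCODE_LIST machinery is more involved):

    #include <cstdio>

    // Simplified stand-in for HLO_OPCODE_LIST: each entry names an enum value
    // and its printable string.
    #define EXAMPLE_OPCODE_LIST(V)             \
      V(kGetTupleElement, "get-tuple-element") \
      V(kGenerateToken, "generate-token")

    enum class ExampleOpcode {
    #define DECLARE_ENUM(enum_name, opcode_name) enum_name,
      EXAMPLE_OPCODE_LIST(DECLARE_ENUM)
    #undef DECLARE_ENUM
    };

    const char* ExampleOpcodeString(ExampleOpcode opcode) {
      switch (opcode) {
    #define OPCODE_CASE(enum_name, opcode_name) \
      case ExampleOpcode::enum_name:            \
        return opcode_name;
        EXAMPLE_OPCODE_LIST(OPCODE_CASE)
    #undef OPCODE_CASE
      }
      return "unknown";
    }

    int main() {
      std::printf("%s\n", ExampleOpcodeString(ExampleOpcode::kGenerateToken));
      return 0;
    }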
diff --git a/tensorflow/compiler/xla/service/hlo_opcode_test.cc b/tensorflow/compiler/xla/service/hlo_opcode_test.cc
index cd2ce5c69f..774345124b 100644
--- a/tensorflow/compiler/xla/service/hlo_opcode_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_opcode_test.cc
@@ -58,6 +58,7 @@ TEST(HloOpcodeTest, OpcodeProperties) {
case HloOpcode::kConcatenate:
case HloOpcode::kFusion:
case HloOpcode::kMap:
+ case HloOpcode::kGenerateToken:
case HloOpcode::kTuple:
EXPECT_TRUE(HloOpcodeIsVariadic(opcode));
break;
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index 3eadedfe1f..4aa4406292 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -606,6 +606,14 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
HloInstruction::CreateReshape(shape, operands[0]));
break;
}
+ case HloOpcode::kGenerateToken: {
+ if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
+ return false;
+ }
+ instruction = builder->AddInstruction(
+ HloInstruction::CreateGenerateToken(operands));
+ break;
+ }
case HloOpcode::kTuple: {
if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
return false;
@@ -777,6 +785,9 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
optional<HloComputation*> to_apply;
attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
&to_apply};
+ optional<std::vector<tensorflow::int64>> dimensions;
+ attrs["dimensions"] = {/*required=*/false, AttrTy::kBracedInt64List,
+ &dimensions};
if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
return false;
}
@@ -1137,7 +1148,12 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
HloOpcodeString(opcode)));
}
- instruction->set_name(name);
+ instruction->SetAndSanitizeName(name);
+ if (instruction->name() != name) {
+ return Error(name_loc,
+ StrCat("illegal instruction name: ", name,
+ "; suggest renaming to: ", instruction->name()));
+ }
// Add shared attributes like metadata to the instruction, if they were seen.
if (sharding) {
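The parser change above sanitizes the user-supplied instruction name and fails when sanitization alters it, suggesting the sanitized form instead. A standalone sketch of that check (the replacement rule here is illustrative, not NameUniquer's actual policy):

    #include <cctype>
    #include <iostream>
    #include <string>

    // Toy sanitizer: replace anything outside [A-Za-z0-9_.-] with '_'.
    std::string SanitizeName(std::string name) {
      for (char& c : name) {
        if (!std::isalnum(static_cast<unsigned char>(c)) && c != '_' &&
            c != '.' && c != '-') {
          c = '_';
        }
      }
      return name;
    }

    // Mirrors the parser's pattern: accept the name only if sanitizing it is a
    // no-op, otherwise report the suggested rename.
    bool CheckInstructionName(const std::string& name, std::string* error) {
      const std::string sanitized = SanitizeName(name);
      if (sanitized != name) {
        *error = "illegal instruction name: " + name +
                 "; suggest renaming to: " + sanitized;
        return false;
      }
      return true;
    }

    int main() {
      std::string error;
      if (!CheckInstructionName("add%1", &error)) {
        std::cout << error << "\n";
      }
      return 0;
    }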
diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc
index 08068dc504..1c5a47c875 100644
--- a/tensorflow/compiler/xla/service/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc
@@ -765,7 +765,7 @@ add_F32.v3 {
ENTRY MapBinaryAdder.v3 {
param0 = f32[4]{0} parameter(0)
param1 = f32[4]{0} parameter(1)
- ROOT map = f32[4]{0} map(param0, param1), to_apply=add_F32.v3
+ ROOT map = f32[4]{0} map(param0, param1), dimensions={0}, to_apply=add_F32.v3
}
)"
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
index 39b85de0f1..bd1d9935bd 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
@@ -71,6 +71,20 @@ bool IsRematerializable(const HloInstruction* instruction) {
}
}
+// Checks whether an instruction can be rematerialized, by first consulting
+// the remat_able cache and, on a miss, calling the IsRematerializable() API
+// and caching the result.
+bool CanBeRematerialized(
+ const HloInstruction* instruction,
+ tensorflow::gtl::FlatMap<const HloInstruction*, bool>* remat_able) {
+ auto it = remat_able->find(instruction);
+ if (it != remat_able->end()) {
+ return it->second;
+ }
+ bool rematerializable = IsRematerializable(instruction);
+ (*remat_able)[instruction] = rematerializable;
+ return rematerializable;
+}
+
// Type holding a unique identifier for each Buffer object.
using BufferId = int64;
using BufferIdList = tensorflow::gtl::InlinedVector<BufferId, 3>;
@@ -843,9 +857,10 @@ int64 RematerializationCost(const HloInstruction* instruction,
// candidate which reduce memory use at the program point of the current
// instruction as indicated by memory_tracker. nullptr is returned if no
// candidate can be found.
-Item* PickRematerializationCandidate(const MemoryUsageTracker& memory_tracker,
- const InstructionList& instruction_list,
- int64 memory_limit_bytes) {
+Item* PickRematerializationCandidate(
+ const MemoryUsageTracker& memory_tracker,
+ const InstructionList& instruction_list, int64 memory_limit_bytes,
+ tensorflow::gtl::FlatMap<const HloInstruction*, bool>* remat_able) {
Item* best_item = nullptr;
int64 best_cost = 0;
@@ -869,8 +884,7 @@ Item* PickRematerializationCandidate(const MemoryUsageTracker& memory_tracker,
<< " is excluded from rematerialization";
continue;
}
-
- if (!IsRematerializable(candidate)) {
+ if (!CanBeRematerialized(candidate, remat_able)) {
VLOG(5) << "candidate " << candidate->name()
<< " not viable: is not rematerializable";
continue;
@@ -974,6 +988,9 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
// blacklist.
tensorflow::gtl::FlatSet<const HloInstruction*> remat_move_instructions;
+ // The map from instructions to their rematerializable status.
+ tensorflow::gtl::FlatMap<const HloInstruction*, bool> remat_able;
+
// The peak memory of the computation at any point in the instruction
// sequence.
int64 peak_memory = memory_tracker.memory_usage();
@@ -1011,7 +1028,7 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
<< ", limit is " << HumanReadableNumBytes(memory_limit_bytes);
Item* best_item = PickRematerializationCandidate(
- memory_tracker, instruction_list, memory_limit_bytes);
+ memory_tracker, instruction_list, memory_limit_bytes, &remat_able);
if (best_item == nullptr) {
VLOG(3) << "Unable to find rematerialization candidate at program "
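CanBeRematerialized() above is a plain memoization wrapper: consult the cache, and only on a miss compute and store the answer. A standalone sketch of the same pattern, with std::unordered_map standing in for tensorflow::gtl::FlatMap:

    #include <iostream>
    #include <unordered_map>

    // Stand-in for the expensive per-instruction check (IsRematerializable).
    bool ExpensiveCheck(int key) { return key % 2 == 0; }

    bool CachedCheck(int key, std::unordered_map<int, bool>* cache) {
      auto it = cache->find(key);
      if (it != cache->end()) {
        return it->second;  // cache hit: no recomputation
      }
      const bool result = ExpensiveCheck(key);
      (*cache)[key] = result;
      return result;
    }

    int main() {
      std::unordered_map<int, bool> cache;
      std::cout << CachedCheck(4, &cache) << " " << CachedCheck(4, &cache) << "\n";
      return 0;
    }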
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc
index 83de54f3fa..e81334d5a8 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc
@@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
namespace xla {
namespace {
@@ -40,7 +41,8 @@ class HloRematerializationTest : public HloTestBase {
// Creates and returns a computation which can benefit from
// rematerialization. The computation looks like:
//
- // F32[] %param = {...}
+ // F32[1] %param = {...}
+ // F32[] %reshape = reshape(F32[], param)
// F32[1024] %bcast = broadcast(%param)
// F32[1024] %negate = negate(%bcast)
// F32[2048] %concat_1 = concat({%negate, %negate})
@@ -57,9 +59,11 @@ class HloRematerializationTest : public HloTestBase {
const string& suffix = "") {
auto builder = HloComputation::Builder(TestName() + suffix);
auto param = builder.AddInstruction(
- HloInstruction::CreateParameter(0, scalar_shape_, "param"));
+ HloInstruction::CreateParameter(0, vec1_shape_, "param"));
+ auto reshape = builder.AddInstruction(
+ HloInstruction::CreateReshape(scalar_shape_, param));
auto bcast = builder.AddInstruction(
- HloInstruction::CreateBroadcast(vec1024_shape_, param, {}));
+ HloInstruction::CreateBroadcast(vec1024_shape_, reshape, {}));
auto negate = builder.AddInstruction(
HloInstruction::CreateUnary(vec1024_shape_, HloOpcode::kNegate, bcast));
auto concat_1 = builder.AddInstruction(HloInstruction::CreateConcatenate(
@@ -100,9 +104,11 @@ class HloRematerializationTest : public HloTestBase {
const string& suffix = "") {
auto builder = HloComputation::Builder(TestName() + suffix);
auto param = builder.AddInstruction(
- HloInstruction::CreateParameter(0, scalar_shape_, "param"));
+ HloInstruction::CreateParameter(0, vec1_shape_, "param"));
+ auto reshape = builder.AddInstruction(
+ HloInstruction::CreateReshape(scalar_shape_, param));
auto bcast = builder.AddInstruction(
- HloInstruction::CreateBroadcast(vec1024_shape_, param, {}));
+ HloInstruction::CreateBroadcast(vec1024_shape_, reshape, {}));
auto slice_1 = builder.AddInstruction(
HloInstruction::CreateSlice(vec1_shape_, bcast, /*start_indices=*/{0},
/*limit_indices=*/{1},
@@ -135,6 +141,15 @@ class HloRematerializationTest : public HloTestBase {
return ShapeUtil::ByteSizeOf(shape, sizeof(void*));
}
+ StatusOr<bool> RunHloRematerialization(
+ int64 memory_limit_bytes, HloModule* module,
+ SequentialHloOrdering::HloModuleSequence* sequence) {
+ TF_EXPECT_OK(verifier().Run(module).status());
+ return HloRematerialization::RematerializeAndSchedule(
+ ByteSizeOf, memory_limit_bytes, module, DefaultMemoryScheduler,
+ sequence);
+ }
+
// Various shapes used in the canned computations.
const Shape scalar_shape_ = ShapeUtil::MakeShape(xla::F32, {});
const Shape vec1_shape_ = ShapeUtil::MakeShape(xla::F32, {1});
@@ -158,11 +173,9 @@ TEST_F(HloRematerializationTest, SingleComputation) {
SequentialHloOrdering::HloModuleSequence sequence;
// Computation requires 16KB without rematerialization, but uses only 12KB
// with rematerialization so pick a memory limit between these values (14KB).
- TF_ASSERT_OK_AND_ASSIGN(bool changed,
- HloRematerialization::RematerializeAndSchedule(
- ByteSizeOf,
- /*memory_limit_bytes=*/14 * 1024, module.get(),
- DefaultMemoryScheduler, &sequence));
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization(
+ /*memory_limit_bytes=*/14 * 1024,
+ module.get(), &sequence));
EXPECT_TRUE(changed);
// Root should not have changed.
@@ -188,18 +201,16 @@ TEST_F(HloRematerializationTest, SingleComputationNoRematerialization) {
HloComputation* computation =
module->AddEntryComputation(MakeRematerializableComputation());
- EXPECT_EQ(computation->instruction_count(), 7);
+ EXPECT_EQ(computation->instruction_count(), 8);
SequentialHloOrdering::HloModuleSequence sequence;
- TF_ASSERT_OK_AND_ASSIGN(bool changed,
- HloRematerialization::RematerializeAndSchedule(
- ByteSizeOf,
- /*memory_limit_bytes=*/20 * 1024, module.get(),
- DefaultMemoryScheduler, &sequence));
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization(
+ /*memory_limit_bytes=*/20 * 1024,
+ module.get(), &sequence));
// No instructions should have been materialized.
EXPECT_FALSE(changed);
- EXPECT_EQ(computation->instruction_count(), 7);
+ EXPECT_EQ(computation->instruction_count(), 8);
}
// Test rematerialization of a computation which calls another computation via a
@@ -225,23 +236,21 @@ TEST_F(HloRematerializationTest, RematerializeAroundWhile) {
module->AddEntryComputation(MakeRematerializableWhileComputation(
while_cond, /*while_body=*/body_computation));
- EXPECT_EQ(entry_computation->instruction_count(), 6);
- EXPECT_EQ(body_computation->instruction_count(), 7);
+ EXPECT_EQ(entry_computation->instruction_count(), 7);
+ EXPECT_EQ(body_computation->instruction_count(), 8);
// The body computation uses 16KB and the entry computation uses 2KB at the
// while so the peak memory use of the module is 18KB. Set the memory limit a
// bit lower (17KB) to force rematerialization of the entry computation.
SequentialHloOrdering::HloModuleSequence sequence;
- TF_ASSERT_OK_AND_ASSIGN(bool changed,
- HloRematerialization::RematerializeAndSchedule(
- ByteSizeOf,
- /*memory_limit_bytes=*/17 * 1024, module.get(),
- DefaultMemoryScheduler, &sequence));
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization(
+ /*memory_limit_bytes=*/17 * 1024,
+ module.get(), &sequence));
EXPECT_TRUE(changed);
// Only the entry computation should have a rematerialized instruction added.
- EXPECT_EQ(entry_computation->instruction_count(), 7);
- EXPECT_EQ(body_computation->instruction_count(), 7);
+ EXPECT_EQ(entry_computation->instruction_count(), 8);
+ EXPECT_EQ(body_computation->instruction_count(), 8);
}
// Test rematerialization of a computation which calls another computation via a
@@ -264,20 +273,18 @@ TEST_F(HloRematerializationTest, RematerializeEntryAndWhileBody) {
module->AddEntryComputation(MakeRematerializableWhileComputation(
while_cond, /*while_body=*/body_computation));
- EXPECT_EQ(entry_computation->instruction_count(), 6);
- EXPECT_EQ(body_computation->instruction_count(), 7);
+ EXPECT_EQ(entry_computation->instruction_count(), 7);
+ EXPECT_EQ(body_computation->instruction_count(), 8);
SequentialHloOrdering::HloModuleSequence sequence;
- TF_ASSERT_OK_AND_ASSIGN(bool changed,
- HloRematerialization::RematerializeAndSchedule(
- ByteSizeOf,
- /*memory_limit_bytes=*/15 * 1024, module.get(),
- DefaultMemoryScheduler, &sequence));
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization(
+ /*memory_limit_bytes=*/15 * 1024,
+ module.get(), &sequence));
EXPECT_TRUE(changed);
- // Both computations should have a rematerialized instruction added.
- EXPECT_EQ(entry_computation->instruction_count(), 7);
- EXPECT_EQ(body_computation->instruction_count(), 8);
+ // Both computations should have rematerialized instructions added.
+ EXPECT_EQ(entry_computation->instruction_count(), 9);
+ EXPECT_EQ(body_computation->instruction_count(), 9);
}
// Test rematerialization of a doubly nested computation. All computations
@@ -303,24 +310,22 @@ TEST_F(HloRematerializationTest, RematerializeNestedComputations) {
module->AddEntryComputation(MakeRematerializableWhileComputation(
while_cond, /*while_body=*/middle_computation));
- EXPECT_EQ(entry_computation->instruction_count(), 6);
- EXPECT_EQ(middle_computation->instruction_count(), 6);
- EXPECT_EQ(inner_computation->instruction_count(), 7);
+ EXPECT_EQ(entry_computation->instruction_count(), 7);
+ EXPECT_EQ(middle_computation->instruction_count(), 7);
+ EXPECT_EQ(inner_computation->instruction_count(), 8);
// If all computations are maximally rematerialized then peak memory usage is
// ~12K so pick something slightly larger.
SequentialHloOrdering::HloModuleSequence sequence;
- TF_ASSERT_OK_AND_ASSIGN(bool changed,
- HloRematerialization::RematerializeAndSchedule(
- ByteSizeOf,
- /*memory_limit_bytes=*/13 * 1024, module.get(),
- DefaultMemoryScheduler, &sequence));
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization(
+ /*memory_limit_bytes=*/13 * 1024,
+ module.get(), &sequence));
EXPECT_TRUE(changed);
- // All computations should have a rematerialized instruction added.
- EXPECT_EQ(entry_computation->instruction_count(), 7);
- EXPECT_EQ(middle_computation->instruction_count(), 7);
- EXPECT_EQ(inner_computation->instruction_count(), 8);
+ // All computations should have rematerialized instructions added.
+ EXPECT_EQ(entry_computation->instruction_count(), 9);
+ EXPECT_EQ(middle_computation->instruction_count(), 9);
+ EXPECT_EQ(inner_computation->instruction_count(), 9);
}
TEST_F(HloRematerializationTest, RngNotRematerialized) {
@@ -382,10 +387,9 @@ TEST_F(HloRematerializationTest, RngNotRematerialized) {
// parameter and output) and 20KB (peak memory possible with
// rematerialization).
TF_ASSERT_OK_AND_ASSIGN(
- bool changed, HloRematerialization::RematerializeAndSchedule(
- ByteSizeOf,
+ bool changed, RunHloRematerialization(
/*memory_limit_bytes=*/4 * ByteSizeOf(vec1024_shape_),
- module.get(), DefaultMemoryScheduler, &sequence));
+ module.get(), &sequence));
EXPECT_TRUE(changed);
// The rng should not have been rematerialized.
EXPECT_EQ(count_rngs(entry_computation), 1);
@@ -476,11 +480,9 @@ TEST_F(HloRematerializationTest, InstructionRematerializedMultipleTimes) {
// Pick a memory limit some where between 24KB (initial peak memory including
// parameter and output) and 20KB (peak memory possible with
// rematerialization).
- TF_ASSERT_OK_AND_ASSIGN(bool changed,
- HloRematerialization::RematerializeAndSchedule(
- ByteSizeOf,
- /*memory_limit_bytes=*/22 * 1024, module.get(),
- DefaultMemoryScheduler, &sequence));
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization(
+ /*memory_limit_bytes=*/22 * 1024,
+ module.get(), &sequence));
EXPECT_TRUE(changed);
// The broadcast should have been rematerialized 3 times.
@@ -573,11 +575,9 @@ TEST_P(IndirectUseTest, IndirectUseNotRematerialized) {
// Pick a memory limit some where between 24KB (initial peak memory including
// parameter and output) and 20KB (peak memory possible with
// rematerialization).
- TF_ASSERT_OK_AND_ASSIGN(bool changed,
- HloRematerialization::RematerializeAndSchedule(
- ByteSizeOf,
- /*memory_limit_bytes=*/22 * 1024, module.get(),
- DefaultMemoryScheduler, &sequence));
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization(
+ /*memory_limit_bytes=*/22 * 1024,
+ module.get(), &sequence));
// Rematerialization should only occur if the rematerializable instruction has
// no indirect uses.
if (indirectly_used) {
diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc
index 58224ef870..4fbb7f69ac 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding.cc
@@ -141,6 +141,20 @@ StatusOr<ShapeTree<HloSharding>> HloSharding::AsShapeTree(
}
}
+StatusOr<HloSharding> HloSharding::GetTupleSharding(const Shape& shape) const {
+ if (IsTuple()) {
+ // TODO(b/109903108): An empty tuple has one leaf for ShapeTree, while it
+ // has zero leaves for ShapeUtil. This needs cleanup.
+ int64 shape_leaves =
+ ShapeUtil::IsEmptyTuple(shape) ? 1 : ShapeUtil::GetLeafCount(shape);
+ TF_RET_CHECK(shape_leaves == tuple_elements_.size())
+ << "Shape " << ShapeUtil::HumanString(shape) << " has " << shape_leaves
+ << " leaf nodes while this sharding has " << tuple_elements_.size();
+ return *this;
+ }
+ return Tuple(ShapeTree<HloSharding>(shape, *this));
+}
+
StatusOr<int64> HloSharding::UniqueDevice() const {
if (IsTuple()) {
if (tuple_elements_.empty()) {
@@ -389,6 +403,19 @@ HloSharding HloSharding::GetSubSharding(const Shape& shape,
: sub_shape_tree.element(ShapeIndex({}));
}
+tensorflow::gtl::optional<HloSharding> HloSharding::ExtractSingleSharding()
+ const {
+ if (!IsTuple()) {
+ return *this;
+ }
+ for (int64 i = 1; i < tuple_elements_.size(); ++i) {
+ if (tuple_elements_[0] != tuple_elements_[i]) {
+ return tensorflow::gtl::optional<HloSharding>();
+ }
+ }
+ return tuple_elements_.front();
+}
+
std::ostream& operator<<(std::ostream& out, const HloSharding& sharding) {
out << sharding.ToString();
return out;
diff --git a/tensorflow/compiler/xla/service/hlo_sharding.h b/tensorflow/compiler/xla/service/hlo_sharding.h
index f4a0fb626f..0a213311b4 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding.h
+++ b/tensorflow/compiler/xla/service/hlo_sharding.h
@@ -72,8 +72,7 @@ class HloSharding {
// elements for every leaf shape contained in the tuple.
static HloSharding Tuple(const ShapeTree<HloSharding>& sub_shardings) {
std::vector<HloSharding> flattened_list;
- flattened_list.reserve(
- std::distance(sub_shardings.leaf_begin(), sub_shardings.leaf_end()));
+ flattened_list.reserve(sub_shardings.leaf_count());
for (const auto& index_to_sharding : sub_shardings.leaves()) {
flattened_list.push_back(index_to_sharding.second);
}
@@ -172,6 +171,18 @@ class HloSharding {
// REQUIRES: IsTuple()
HloSharding GetSubSharding(const Shape& shape, const ShapeIndex& index) const;
+ // If the current sharding is a tuple sharding, returns itself as the result.
+ // Otherwise returns a tuple sharding for the input shape, with every leaf
+ // set to this object's sharding.
+ StatusOr<HloSharding> GetTupleSharding(const Shape& shape) const;
+
+ // Extracts the sharding that is common to all elements of the current
+ // sharding. If the current sharding is not a tuple sharding, it is returned
+ // as is. If it is a tuple and all the tuple elements are identical, the
+ // common element is returned. Otherwise the returned optional holds no
+ // value.
+ tensorflow::gtl::optional<HloSharding> ExtractSingleSharding() const;
+
bool operator==(const HloSharding& other) const {
return replicated_ == other.replicated_ && maximal_ == other.maximal_ &&
ShapeUtil::Compatible(tile_shape_, other.tile_shape_) &&
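A hedged sketch of how the two new HloSharding helpers might be used, assuming only the declarations added above; the wrapper names are illustrative:

    #include "tensorflow/compiler/xla/service/hlo_sharding.h"

    // True if the sharding reduces to one common per-element sharding (always
    // true for a non-tuple sharding, true for a tuple only if all leaves match).
    bool HasSingleCommonSharding(const xla::HloSharding& sharding) {
      return static_cast<bool>(sharding.ExtractSingleSharding());
    }

    // Normalizes to a tuple sharding for `shape`: tuple shardings are returned
    // unchanged (after a leaf-count check), non-tuple shardings are broadcast
    // over every leaf of `shape`.
    xla::StatusOr<xla::HloSharding> AsTupleSharding(
        const xla::HloSharding& sharding, const xla::Shape& shape) {
      return sharding.GetTupleSharding(shape);
    }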
diff --git a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc
index 82cff2a4b7..7b4b071af4 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc
@@ -31,32 +31,22 @@ struct PassThrough {
HloInstruction* operand = nullptr;
};
-void SetDeviceSharding(HloInstruction* instruction, int64 device) {
- VLOG(4) << " " << instruction->name() << " to device " << device;
- instruction->set_device_sharding(device);
-}
-
-tensorflow::gtl::optional<int64> ShardingUniqueDevice(
- const HloSharding& sharding) {
- if (sharding.IsTileMaximal()) {
- auto device = sharding.UniqueDevice();
- if (device.ok()) {
- return device.ValueOrDie();
- }
- }
- return tensorflow::gtl::optional<int64>();
+void SetSingleSharding(HloInstruction* instruction,
+ const HloSharding& sharding) {
+ VLOG(4) << " " << instruction->name() << " to " << sharding;
+ instruction->set_single_sharding(sharding);
}
bool ShardingMatches(const HloSharding& sharding1,
const HloSharding& sharding2) {
- auto device1 = ShardingUniqueDevice(sharding1);
- if (device1) {
- auto device2 = ShardingUniqueDevice(sharding2);
- if (device2) {
- return *device1 == *device2;
+ auto single_sharding1 = sharding1.ExtractSingleSharding();
+ if (single_sharding1) {
+ auto single_sharding2 = sharding2.ExtractSingleSharding();
+ if (single_sharding2) {
+ return *single_sharding1 == single_sharding2;
}
}
- // Anything which is not tile maximal with unique device, gets a full sharding
+ // Anything which is not unique across all elements gets a full sharding
// compare.
return sharding1 == sharding2;
}
@@ -119,21 +109,21 @@ Status FixupPassThroughDomainLinks(const DomainMetadata::Domain& domain,
std::unique_ptr<HloSharding> CloneShardingForDomain(
const HloSharding& sharding) {
- auto device = ShardingUniqueDevice(sharding);
- if (!device) {
+ auto single_sharding = sharding.ExtractSingleSharding();
+ if (!single_sharding) {
return MakeUnique<HloSharding>(sharding);
}
- return MakeUnique<HloSharding>(HloSharding::AssignDevice(*device));
+ return MakeUnique<HloSharding>(*single_sharding);
}
-Status ApplyDomainDeviceSharding(const DomainMetadata::Domain& domain,
- int64 device) {
- VLOG(4) << "Applying device " << device << " sharding";
+Status ApplyDomainSingleSharding(const DomainMetadata::Domain& domain,
+ const HloSharding& sharding) {
+ VLOG(4) << "Applying " << sharding << " sharding";
for (HloInstruction* instruction : domain.instructions) {
// We only change instructions without sharding, since otherwise we might
// mess up with eventual HLO passes which has knowledge of it.
if (!instruction->has_sharding()) {
- SetDeviceSharding(instruction, device);
+ SetSingleSharding(instruction, sharding);
} else {
VLOG(4) << " " << instruction->name() << " already has sharding "
<< instruction->sharding();
@@ -186,12 +176,15 @@ StatusOr<int64> ApplyDomainShardingPass(const DomainMetadata::Domain& domain,
const HloSharding* tuple_sharding =
GetOperandSharding(tuple, domain, sharding);
if (tuple_sharding != nullptr) {
- TF_RET_CHECK(tuple_sharding->IsTuple()) << tuple->ToString();
- HloSharding sub_sharding = tuple_sharding->GetSubSharding(
- tuple->shape(), {instruction->tuple_index()});
- VLOG(4) << " " << instruction->name() << " to sharding "
- << sub_sharding;
- instruction->set_sharding(sub_sharding);
+ if (tuple_sharding->IsTuple()) {
+ HloSharding sub_sharding = tuple_sharding->GetSubSharding(
+ tuple->shape(), {instruction->tuple_index()});
+ VLOG(4) << " " << instruction->name() << " to sharding "
+ << sub_sharding;
+ instruction->set_sharding(sub_sharding);
+ } else {
+ SetSingleSharding(instruction, *tuple_sharding);
+ }
++assigned;
}
} else if (instruction->opcode() == HloOpcode::kTuple) {
@@ -242,12 +235,12 @@ StatusOr<int64> ApplyDomainShardingPass(const DomainMetadata::Domain& domain,
Status ApplyDomainSharding(const DomainMetadata::Domain& domain,
const HloSharding& sharding) {
- auto device = ShardingUniqueDevice(sharding);
- if (device) {
- // Shortcut the simple case. We have a unique device sharding, so we call
- // the ApplyDomainDeviceSharding() API which will apply array or tuple
- // shaped device sharding to the domain instructions.
- return ApplyDomainDeviceSharding(domain, *device);
+ auto single_sharding = sharding.ExtractSingleSharding();
+ if (single_sharding) {
+ // Shortcut the simple case. We have a unique sharding, so we call
+ // the ApplyDomainSingleSharding() API which will apply array or tuple
+ // shaped sharding to the domain instructions.
+ return ApplyDomainSingleSharding(domain, *single_sharding);
}
VLOG(1) << "Assigning non-trivial sharding " << sharding;
for (;;) {
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index 9cfd8a9bf7..9034073cc8 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -426,6 +426,14 @@ Status ShapeVerifier::HandleGather(HloInstruction* gather) {
gather->gather_dimension_numbers(), gather->gather_window_bounds()));
}
+Status ShapeVerifier::HandleGenerateToken(HloInstruction* token) {
+ std::vector<const Shape*> operand_shapes;
+ for (const HloInstruction* operand : token->operands()) {
+ operand_shapes.push_back(&operand->shape());
+ }
+ return CheckShape(token, ShapeInference::InferTokenShape(operand_shapes));
+}
+
Status ShapeVerifier::CheckShape(const HloInstruction* instruction,
const Shape& inferred_shape) {
// If allow_mixed_precision_ is false, check if there are operands with
@@ -791,6 +799,46 @@ Status HloVerifier::CheckElementwiseInstruction(HloInstruction* instruction) {
return Status::OK();
}
+namespace {
+
+// Returns true if the given Shape has a TOKEN shape as any subshape.
+bool ShapeContainsToken(const Shape& shape) {
+ bool contains_token = false;
+ ShapeUtil::ForEachSubshape(
+ shape, [&contains_token](const Shape& subshape, const ShapeIndex&) {
+ if (ShapeUtil::IsToken(subshape)) {
+ contains_token = true;
+ }
+ });
+ return contains_token;
+}
+
+// Verifies that all types entering and exiting the entry computation are
+// legal. For example, TOKEN types have no Literal representation and cannot be
+// on the interface of the entry computation (parameters and root instruction).
+Status VerifyEntryAndExitShapes(const HloModule& module) {
+ for (int i = 0; i < module.entry_computation()->num_parameters(); ++i) {
+ HloInstruction* param =
+ module.entry_computation()->parameter_instruction(i);
+ if (ShapeContainsToken(param->shape())) {
+ return InternalError(
+ "Entry parameter %d is or contains a token shape: %s", i,
+ ShapeUtil::HumanString(param->shape()).c_str());
+ }
+ }
+ if (ShapeContainsToken(
+ module.entry_computation()->root_instruction()->shape())) {
+ return InternalError(
+ "Entry root is or contains a token shape: %s",
+ ShapeUtil::HumanString(
+ module.entry_computation()->root_instruction()->shape())
+ .c_str());
+ }
+ return Status::OK();
+}
+
+} // namespace
+
StatusOr<bool> HloVerifier::Run(HloModule* module) {
TF_RETURN_IF_ERROR(VerifyHloStructure(module));
@@ -851,6 +899,8 @@ StatusOr<bool> HloVerifier::Run(HloModule* module) {
TF_RETURN_IF_ERROR(computation->Accept(shape_verifier.get()));
}
+ TF_RETURN_IF_ERROR(VerifyEntryAndExitShapes(*module));
+
return false;
}
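ShapeContainsToken() above is a "does any subshape satisfy a predicate" walk. A standalone sketch of the same idea, with a toy recursive shape in place of xla::Shape and ShapeUtil::ForEachSubshape:

    #include <functional>
    #include <iostream>
    #include <vector>

    struct ToyShape {
      bool is_token = false;
      std::vector<ToyShape> tuple_elements;  // empty for non-tuple shapes
    };

    void ForEachSubshape(const ToyShape& shape,
                         const std::function<void(const ToyShape&)>& fn) {
      fn(shape);
      for (const ToyShape& element : shape.tuple_elements) {
        ForEachSubshape(element, fn);
      }
    }

    bool ContainsToken(const ToyShape& shape) {
      bool contains_token = false;
      ForEachSubshape(shape, [&contains_token](const ToyShape& subshape) {
        if (subshape.is_token) contains_token = true;
      });
      return contains_token;
    }

    int main() {
      const ToyShape tuple{false, {ToyShape{false, {}}, ToyShape{true, {}}}};
      std::cout << ContainsToken(tuple) << "\n";  // prints 1
      return 0;
    }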
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h
index 1392a78097..7283b3e7dc 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.h
+++ b/tensorflow/compiler/xla/service/hlo_verifier.h
@@ -81,6 +81,7 @@ class ShapeVerifier : public DfsHloVisitor {
HloInstruction* batch_norm_inference) override;
Status HandleBatchNormGrad(HloInstruction* batch_norm_grad) override;
Status HandleGather(HloInstruction* gather) override;
+ Status HandleGenerateToken(HloInstruction* token) override;
Status FinishVisit(HloInstruction*) override { return Status::OK(); }
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc
index 429c850343..abedb4063d 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/instruction_fusion.cc
@@ -96,6 +96,7 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) {
case HloOpcode::kShiftRightLogical:
case HloOpcode::kSlice:
case HloOpcode::kSubtract:
+ case HloOpcode::kGenerateToken:
case HloOpcode::kTranspose:
case HloOpcode::kTuple:
return false;
diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc
index 23d2d4e87d..1f6e3c829f 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc
@@ -15,53 +15,57 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h"
-#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
namespace xla {
-void KernelSupportLibrary::For(
+Status KernelSupportLibrary::For(
tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end,
llvm::Value* step,
- const std::function<void(llvm::Value*, bool)>& for_body_generator) {
- If(ir_builder_->CreateICmpSLT(start, end), [&]() {
- for_body_generator(start, /*is_first_iteration=*/true);
- For(name, ir_builder_->CreateAdd(start, step), end, step,
- [&](llvm::Value* iv) { for_body_generator(iv, false); });
+ const std::function<Status(llvm::Value*, bool)>& for_body_generator) {
+ return If(ir_builder_->CreateICmpSLT(start, end), [&]() -> Status {
+ TF_RETURN_IF_ERROR(for_body_generator(start, /*is_first_iteration=*/true));
+ return For(name, ir_builder_->CreateAdd(start, step), end, step,
+ [&](llvm::Value* iv) { return for_body_generator(iv, false); });
});
}
-void KernelSupportLibrary::For(
+Status KernelSupportLibrary::For(
tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end,
llvm::Value* step, bool peel_first_iteration,
- const std::function<void(llvm::Value*, llvm::Value*)>& for_body_generator) {
+ const std::function<Status(llvm::Value*, llvm::Value*)>&
+ for_body_generator) {
if (peel_first_iteration) {
- For(name, start, end, step, true,
- [&](llvm::Value* indvar, bool is_first_iteration) {
- for_body_generator(indvar, ir_builder_->getInt1(is_first_iteration));
- });
+ return For(name, start, end, step, true,
+ [&](llvm::Value* indvar, bool is_first_iteration) -> Status {
+ return for_body_generator(
+ indvar, ir_builder_->getInt1(is_first_iteration));
+ });
} else {
std::unique_ptr<llvm_ir::ForLoop> loop = llvm_ir::ForLoop::EmitForLoop(
name, start, end, step, ir_builder_,
- /*prevent_unrolling=*/prevent_unrolling_,
+ /*unroll_mode=*/unroll_mode_,
/*prevent_vectorization=*/prevent_vectorization_);
ir_builder_->SetInsertPoint(&loop->GetBodyBasicBlock()->back());
- for_body_generator(loop->GetIndVarValue(),
- /*is_first_iteration=*/ir_builder_->CreateICmpEQ(
- loop->GetIndVarValue(), start));
+ TF_RETURN_IF_ERROR(
+ for_body_generator(loop->GetIndVarValue(),
+ /*is_first_iteration=*/ir_builder_->CreateICmpEQ(
+ loop->GetIndVarValue(), start)));
llvm_ir::SetToLastInsertPoint(loop->GetExitBasicBlock(), ir_builder_);
+ return Status::OK();
}
}
-void KernelSupportLibrary::If(
- llvm::Value* condition, const std::function<void()>& true_block_generator,
- const std::function<void()>& false_block_generator) {
+Status KernelSupportLibrary::If(
+ llvm::Value* condition, const std::function<Status()>& true_block_generator,
+ const std::function<Status()>& false_block_generator) {
llvm_ir::LlvmIfData if_data =
llvm_ir::EmitIfThenElse(condition, "", ir_builder_);
ir_builder_->SetInsertPoint(&if_data.true_block->back());
- true_block_generator();
+ TF_RETURN_IF_ERROR(true_block_generator());
ir_builder_->SetInsertPoint(&if_data.false_block->back());
- false_block_generator();
+ TF_RETURN_IF_ERROR(false_block_generator());
llvm_ir::SetToLastInsertPoint(if_data.after_block, ir_builder_);
+ return Status::OK();
}
void KernelSupportLibrary::EmitAndCallOutlinedKernel(
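The kernel_support_library change converts the loop/if helpers from void body callbacks to Status-returning callbacks, so the first failure propagates to the caller (void-returning wrappers keep the old convention for call sites that cannot fail). A standalone sketch of that error-propagating shape, with a toy Status standing in for tensorflow::Status:

    #include <functional>
    #include <iostream>
    #include <string>

    struct Status {
      std::string message;  // empty means OK
      bool ok() const { return message.empty(); }
      static Status OK() { return Status{}; }
    };

    // Runs `body(i)` for i in [start, end) and stops at the first error.
    Status For(int start, int end, const std::function<Status(int)>& body) {
      for (int i = start; i < end; ++i) {
        Status s = body(i);
        if (!s.ok()) return s;  // analogous to TF_RETURN_IF_ERROR
      }
      return Status::OK();
    }

    int main() {
      Status s = For(0, 10, [](int i) {
        if (i == 3) return Status{"failed at iteration 3"};
        return Status::OK();
      });
      std::cout << (s.ok() ? "ok" : s.message) << "\n";
      return 0;
    }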
diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h
index 64b935bbf1..e17c649e52 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h
@@ -21,6 +21,7 @@ limitations under the License.
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Value.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/core/lib/core/stringpiece.h"
@@ -30,13 +31,14 @@ namespace xla {
class KernelSupportLibrary {
public:
// `ir_builder` is the llvm::IRBuilder instance used to generate LLVM IR.
- // If `prevent_unrolling` is true then unrolling is explicitly disabled on
- // every loop generated by this instance of KernelSupportLibrary.
- explicit KernelSupportLibrary(llvm::IRBuilder<>* ir_builder,
- bool prevent_unrolling = true,
- bool prevent_vectorization = true)
+ // `unroll_mode` specifies the desired LLVM unrolling behavior for every loop
+ // generated by this instance of KernelSupportLibrary.
+ explicit KernelSupportLibrary(
+ llvm::IRBuilder<>* ir_builder,
+ llvm_ir::UnrollMode unroll_mode = llvm_ir::UnrollMode::kNoUnroll,
+ bool prevent_vectorization = true)
: ir_builder_(ir_builder),
- prevent_unrolling_(prevent_unrolling),
+ unroll_mode_(unroll_mode),
prevent_vectorization_(prevent_vectorization) {}
// Generates the following control flow structure:
@@ -46,19 +48,41 @@ class KernelSupportLibrary {
// for (i64 i = `start` + `step`; i s< `end`; i += `step`)
// `for_body_generator(/*ind_var=*/,i, /*is_first_iteration=*/false)`;
// }
- void For(
+ Status For(
+ tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end,
+ llvm::Value* step,
+ const std::function<Status(llvm::Value* ind_var,
+ bool is_first_iteration)>& for_body_generator);
+
+ void ForReturnVoid(
tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end,
llvm::Value* step,
const std::function<void(llvm::Value* ind_var, bool is_first_iteration)>&
- for_body_generator);
+ for_body_generator) {
+ CHECK_EQ(Status::OK(),
+ For(name, start, end, step,
+ [&](llvm::Value* ind_var, bool is_first_iteration) -> Status {
+ for_body_generator(ind_var, is_first_iteration);
+ return Status::OK();
+ }));
+ }
+
+ Status For(tensorflow::StringPiece name, int64 start, int64 end, int64 step,
+ const std::function<Status(llvm::Value* ind_var,
+ bool is_first_iteration)>&
+ for_body_generator) {
+ return For(name, /*start=*/ir_builder_->getInt64(start),
+ /*end=*/ir_builder_->getInt64(end),
+ /*step=*/ir_builder_->getInt64(step), for_body_generator);
+ }
- void For(
+ void ForReturnVoid(
tensorflow::StringPiece name, int64 start, int64 end, int64 step,
const std::function<void(llvm::Value* ind_var, bool is_first_iteration)>&
for_body_generator) {
- For(name, /*start=*/ir_builder_->getInt64(start),
- /*end=*/ir_builder_->getInt64(end),
- /*step=*/ir_builder_->getInt64(step), for_body_generator);
+ ForReturnVoid(name, /*start=*/ir_builder_->getInt64(start),
+ /*end=*/ir_builder_->getInt64(end),
+ /*step=*/ir_builder_->getInt64(step), for_body_generator);
}
// Generates the following control flow structure if `peel_first_iteration` is
@@ -75,46 +99,101 @@ class KernelSupportLibrary {
// for (i64 i = `start`; i s< `end`; i += `step`)
// `for_body_generator(/*ind_var=*/,i,
// /*is_first_iteration=*/,(i != `start`))`;
- void For(tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end,
- llvm::Value* step, bool peel_first_iteration,
- const std::function<void(llvm::Value* ind_var,
- llvm::Value* is_first_iteration)>&
+ Status For(tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end,
+ llvm::Value* step, bool peel_first_iteration,
+ const std::function<Status(llvm::Value* ind_var,
+ llvm::Value* is_first_iteration)>&
+ for_body_generator);
+
+ void ForReturnVoid(tensorflow::StringPiece name, llvm::Value* start,
+ llvm::Value* end, llvm::Value* step,
+ bool peel_first_iteration,
+ const std::function<void(llvm::Value* ind_var,
+ llvm::Value* is_first_iteration)>&
+ for_body_generator) {
+ TF_CHECK_OK(For(
+ name, start, end, step, peel_first_iteration,
+ [&](llvm::Value* ind_var, llvm::Value* is_first_iteration) -> Status {
+ for_body_generator(ind_var, is_first_iteration);
+ return Status::OK();
+ }));
+ }
+
+ Status For(tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end,
+ int64 step, bool peel_first_iteration,
+ const std::function<Status(llvm::Value* ind_var,
+ llvm::Value* is_first_iteration)>&
+ for_body_generator) {
+ return For(name, /*start=*/start, /*end=*/end,
+ /*step=*/ir_builder_->getInt64(step), peel_first_iteration,
for_body_generator);
+ }
+
+ void ForReturnVoid(tensorflow::StringPiece name, llvm::Value* start,
+ llvm::Value* end, int64 step, bool peel_first_iteration,
+ const std::function<void(llvm::Value* ind_var,
+ llvm::Value* is_first_iteration)>&
+ for_body_generator) {
+ ForReturnVoid(name, /*start=*/start, /*end=*/end,
+ /*step=*/ir_builder_->getInt64(step), peel_first_iteration,
+ for_body_generator);
+ }
- void For(tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end,
- int64 step, bool peel_first_iteration,
- const std::function<void(llvm::Value* ind_var,
- llvm::Value* is_first_iteration)>&
- for_body_generator) {
- For(name, /*start=*/start, /*end=*/end,
- /*step=*/ir_builder_->getInt64(step), peel_first_iteration,
- for_body_generator);
+ Status For(
+ tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end,
+ llvm::Value* step,
+ const std::function<Status(llvm::Value* ind_var)>& for_body_generator) {
+ return For(name, start, end, step,
+ /*peel_first_iteration=*/false,
+ [&](llvm::Value* indvar, llvm::Value*) -> Status {
+ return for_body_generator(indvar);
+ });
}
- void For(
+ void ForReturnVoid(
tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end,
llvm::Value* step,
const std::function<void(llvm::Value* ind_var)>& for_body_generator) {
- For(name, start, end, step,
- /*peel_first_iteration=*/false,
- [&](llvm::Value* indvar, llvm::Value*) { for_body_generator(indvar); });
+ ForReturnVoid(name, start, end, step,
+ /*peel_first_iteration=*/false,
+ [&](llvm::Value* indvar, llvm::Value*) {
+ return for_body_generator(indvar);
+ });
+ }
+
+ Status For(
+ tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end,
+ int64 step,
+ const std::function<Status(llvm::Value* ind_var)>& for_body_generator) {
+ return For(name, start, end, ir_builder_->getInt64(step),
+ /*peel_first_iteration=*/false,
+ [&](llvm::Value* indvar, llvm::Value*) -> Status {
+ return for_body_generator(indvar);
+ });
}
- void For(
+ void ForReturnVoid(
tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end,
int64 step,
const std::function<void(llvm::Value* ind_var)>& for_body_generator) {
- For(name, start, end, ir_builder_->getInt64(step),
- /*peel_first_iteration=*/false,
- [&](llvm::Value* indvar, llvm::Value*) { for_body_generator(indvar); });
+ ForReturnVoid(name, start, end, ir_builder_->getInt64(step),
+ for_body_generator);
+ }
+
+ Status For(
+ tensorflow::StringPiece name, int64 start, int64 end, int64 step,
+ const std::function<Status(llvm::Value* ind_var)>& for_body_generator) {
+ return For(name, /*start=*/ir_builder_->getInt64(start),
+ /*end=*/ir_builder_->getInt64(end),
+ /*step=*/ir_builder_->getInt64(step), for_body_generator);
}
- void For(
+ void ForReturnVoid(
tensorflow::StringPiece name, int64 start, int64 end, int64 step,
const std::function<void(llvm::Value* ind_var)>& for_body_generator) {
- For(name, /*start=*/ir_builder_->getInt64(start),
- /*end=*/ir_builder_->getInt64(end),
- /*step=*/ir_builder_->getInt64(step), for_body_generator);
+ ForReturnVoid(name, /*start=*/ir_builder_->getInt64(start),
+ /*end=*/ir_builder_->getInt64(end),
+ /*step=*/ir_builder_->getInt64(step), for_body_generator);
}
// Generates the following control flow structure:
@@ -123,9 +202,25 @@ class KernelSupportLibrary {
// `true_block_generator()`;
// else
// `false_block_generator()`;
- void If(llvm::Value* condition,
- const std::function<void()>& true_block_generator,
- const std::function<void()>& false_block_generator = []() {});
+ Status If(llvm::Value* condition,
+ const std::function<Status()>& true_block_generator,
+ const std::function<Status()>& false_block_generator =
+ []() -> Status { return Status::OK(); });
+
+ void IfReturnVoid(llvm::Value* condition,
+ const std::function<void()>& true_block_generator,
+ const std::function<void()>& false_block_generator = []() {
+ }) {
+ TF_CHECK_OK(If(condition,
+ [&]() {
+ true_block_generator();
+ return Status::OK();
+ },
+ [&]() {
+ false_block_generator();
+ return Status::OK();
+ }));
+ }
using ArgumentVector = tensorflow::gtl::ArraySlice<llvm::Value*>;
@@ -183,7 +278,7 @@ class KernelSupportLibrary {
private:
llvm::IRBuilder<>* ir_builder_;
- bool prevent_unrolling_;
+ llvm_ir::UnrollMode unroll_mode_;
bool prevent_vectorization_;
};
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc
index 497b48ff22..9f867014fb 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc
@@ -34,7 +34,7 @@ namespace llvm_ir {
ForLoop::ForLoop(tensorflow::StringPiece prefix, tensorflow::StringPiece suffix,
llvm::Value* start_index, llvm::Value* end_index,
- llvm::Value* step, bool prevent_unrolling,
+ llvm::Value* step, UnrollMode unroll_mode,
bool prevent_vectorization)
: prefix_(std::string(prefix)),
suffix_(std::string(suffix)),
@@ -42,15 +42,15 @@ ForLoop::ForLoop(tensorflow::StringPiece prefix, tensorflow::StringPiece suffix,
end_index_(end_index),
step_(step),
insert_before_bb_(nullptr),
- prevent_unrolling_(prevent_unrolling),
+ unroll_mode_(unroll_mode),
prevent_vectorization_(prevent_vectorization) {}
/* static */ std::unique_ptr<ForLoop> ForLoop::EmitForLoop(
tensorflow::StringPiece prefix, llvm::Value* start_index,
llvm::Value* end_index, llvm::Value* step, llvm::IRBuilder<>* ir_builder,
- bool prevent_unrolling, bool prevent_vectorization) {
+ UnrollMode unroll_mode, bool prevent_vectorization) {
std::unique_ptr<ForLoop> loop(new ForLoop(prefix, /*suffix=*/"", start_index,
- end_index, step, prevent_unrolling,
+ end_index, step, unroll_mode,
prevent_vectorization));
loop->Emit(ir_builder);
return loop;
@@ -147,11 +147,12 @@ void ForLoop::Emit(llvm::IRBuilder<>* ir_builder) {
std::vector<llvm::Metadata*> ForLoop::GetLoopMetadata(
llvm::IRBuilder<>* ir_builder) {
const char* const kLlvmLoopUnrollDisableMDName = "llvm.loop.unroll.disable";
+ const char* const kLlvmLoopUnrollFullMDName = "llvm.loop.unroll.full";
const char* const kLlvmLoopVectorizeMDName = "llvm.loop.vectorize.enable";
llvm::LLVMContext* ctx = &start_index_->getContext();
std::vector<llvm::Metadata*> result;
- if (prevent_unrolling_) {
+ if (unroll_mode_ == xla::llvm_ir::UnrollMode::kNoUnroll) {
result.push_back(llvm::MDNode::get(
*ctx, {llvm::MDString::get(*ctx, kLlvmLoopUnrollDisableMDName)}));
}
@@ -162,6 +163,10 @@ std::vector<llvm::Metadata*> ForLoop::GetLoopMetadata(
llvm::ConstantAsMetadata::get(ir_builder->getFalse())}));
}
+ if (unroll_mode_ == xla::llvm_ir::UnrollMode::kFullyUnroll) {
+ result.push_back(llvm::MDNode::get(
+ *ctx, {llvm::MDString::get(*ctx, kLlvmLoopUnrollFullMDName)}));
+ }
return result;
}
@@ -178,25 +183,25 @@ llvm::BasicBlock* ForLoop::CreateLoopBB(tensorflow::StringPiece name,
std::unique_ptr<ForLoop> ForLoopNest::AddLoop(tensorflow::StringPiece suffix,
llvm::Value* start_index,
llvm::Value* end_index,
- bool prevent_unrolling,
+ UnrollMode unroll_mode,
bool prevent_vectorization) {
return AddLoop(suffix, start_index, end_index, ir_builder_->getInt64(1),
- prevent_unrolling, prevent_vectorization);
+ unroll_mode, prevent_vectorization);
}
std::unique_ptr<ForLoop> ForLoopNest::AddLoop(tensorflow::StringPiece suffix,
llvm::Value* start_index,
llvm::Value* end_index,
llvm::Value* stride,
- bool prevent_unrolling,
+ UnrollMode unroll_mode,
bool prevent_vectorization) {
if (inner_loop_body_bb_ != nullptr) {
// Create this loop inside the previous one.
ir_builder_->SetInsertPoint(&*inner_loop_body_bb_->getFirstInsertionPt());
}
std::unique_ptr<ForLoop> loop(new ForLoop(
- /*prefix=*/name_, suffix, start_index, end_index, stride,
- prevent_unrolling, prevent_vectorization));
+ /*prefix=*/name_, suffix, start_index, end_index, stride, unroll_mode,
+ prevent_vectorization));
loop->Emit(ir_builder_);
if (outer_loop_preheader_bb_ == nullptr) {
@@ -215,23 +220,23 @@ std::unique_ptr<ForLoop> ForLoopNest::AddLoop(tensorflow::StringPiece suffix,
std::unique_ptr<ForLoop> ForLoopNest::AddLoop(int64 start_index,
int64 end_index,
tensorflow::StringPiece suffix,
- bool prevent_unrolling,
+ UnrollMode unroll_mode,
bool prevent_vectorization) {
CHECK_LE(start_index, end_index);
return AddLoop(suffix, ir_builder_->getInt64(start_index),
- ir_builder_->getInt64(end_index), prevent_unrolling,
+ ir_builder_->getInt64(end_index), unroll_mode,
prevent_vectorization);
}
std::unique_ptr<ForLoop> ForLoopNest::AddLoop(int64 start_index,
int64 end_index, int64 stride,
tensorflow::StringPiece suffix,
- bool prevent_unrolling,
+ UnrollMode unroll_mode,
bool prevent_vectorization) {
CHECK_LE(start_index, end_index);
return AddLoop(suffix, ir_builder_->getInt64(start_index),
ir_builder_->getInt64(end_index),
- ir_builder_->getInt64(stride), prevent_unrolling,
+ ir_builder_->getInt64(stride), unroll_mode,
prevent_vectorization);
}
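With the UnrollMode enum, callers pick per-loop unrolling metadata instead of a single prevent_unrolling flag: kNoUnroll emits llvm.loop.unroll.disable, kFullyUnroll emits llvm.loop.unroll.full, and kDefaultUnroll emits neither. A hedged sketch of requesting full unrolling through ForLoopNest::AddLoop, assuming the declarations in this patch; the function name is illustrative:

    #include <memory>

    #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"

    // Emits a short fixed-trip-count loop and asks LLVM to unroll it fully.
    void EmitFullyUnrolledInnerLoop(xla::llvm_ir::ForLoopNest* loop_nest) {
      std::unique_ptr<xla::llvm_ir::ForLoop> loop = loop_nest->AddLoop(
          /*start_index=*/0, /*end_index=*/4, /*suffix=*/"inner",
          xla::llvm_ir::UnrollMode::kFullyUnroll);
      // The loop body would be emitted at loop->GetBodyBasicBlock() here.
      (void)loop;
    }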
diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h
index d915f95db1..4e403cd994 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h
@@ -34,6 +34,12 @@ limitations under the License.
namespace xla {
namespace llvm_ir {
+enum class UnrollMode {
+ kDefaultUnroll,
+ kFullyUnroll,
+ kNoUnroll,
+};
+
// A class for constructing a for-loop in LLVM IR.
class ForLoop {
public:
@@ -69,12 +75,13 @@ class ForLoop {
// LLVM IR. If non-empty, it is prepended to the name of the induction
// variable value and each basic block created for the loop.
//
- // If `prevent_unrolling` is true then emit metadata that directs LLVM to not
- // unroll the generated loop.
+ // `unroll_mode` specifies the desired LLVM unrolling behavior for the
+ // generated loop.
static std::unique_ptr<ForLoop> EmitForLoop(
tensorflow::StringPiece prefix, llvm::Value* start_index,
llvm::Value* end_index, llvm::Value* step, llvm::IRBuilder<>* ir_builder,
- bool prevent_unrolling = false, bool prevent_vectorization = false);
+ UnrollMode unroll_mode = llvm_ir::UnrollMode::kDefaultUnroll,
+ bool prevent_vectorization = false);
// The names of the blocks follow LLVM's conventions. Control flow amongst the
// blocks for the example C code looks like:
@@ -128,7 +135,7 @@ class ForLoop {
ForLoop(tensorflow::StringPiece prefix, tensorflow::StringPiece suffix,
llvm::Value* start_index, llvm::Value* end_index, llvm::Value* step,
- bool prevent_unrolling, bool prevent_vectorization);
+ UnrollMode unroll_mode, bool prevent_vectorization);
// Emit the loop at the insert point of the builder.
void Emit(llvm::IRBuilder<>* ir_builder);
@@ -161,7 +168,7 @@ class ForLoop {
llvm::BasicBlock* body_bb_;
llvm::BasicBlock* exit_bb_;
llvm::Value* indvar_;
- bool prevent_unrolling_;
+ UnrollMode unroll_mode_;
bool prevent_vectorization_;
TF_DISALLOW_COPY_AND_ASSIGN(ForLoop);
@@ -182,34 +189,34 @@ class ForLoopNest {
// Adds a loop to the nest. If no loop has been added yet then emit a loop at
// the current insert point of the given builder. If one or more loops have
- // been added then emit loop inside the body of the last added loop. If
- // prevent_unrolling is true, then metadata is emitting directing LLVM to not
- // unroll this loop.
- std::unique_ptr<ForLoop> AddLoop(tensorflow::StringPiece suffix,
- llvm::Value* start_index,
- llvm::Value* end_index, llvm::Value* stride,
- bool prevent_unrolling = false,
- bool prevent_vectorization = false);
+ // been added then emit the loop inside the body of the last added loop.
+ // `unroll_mode` is used to emit metadata that controls LLVM unrolling.
+ std::unique_ptr<ForLoop> AddLoop(
+ tensorflow::StringPiece suffix, llvm::Value* start_index,
+ llvm::Value* end_index, llvm::Value* stride,
+ UnrollMode unroll_mode = xla::llvm_ir::UnrollMode::kDefaultUnroll,
+ bool prevent_vectorization = false);
// Like the above, except that it defaults to a stride of one.
- std::unique_ptr<ForLoop> AddLoop(tensorflow::StringPiece suffix,
- llvm::Value* start_index,
- llvm::Value* end_index,
- bool prevent_unrolling = false,
- bool prevent_vectorization = false);
+ std::unique_ptr<ForLoop> AddLoop(
+ tensorflow::StringPiece suffix, llvm::Value* start_index,
+ llvm::Value* end_index,
+ UnrollMode unroll_mode = xla::llvm_ir::UnrollMode::kDefaultUnroll,
+ bool prevent_vectorization = false);
// A convenient wrapper of the other flavor of AddLoop. The given start and
// end index are constant.
- std::unique_ptr<ForLoop> AddLoop(int64 start_index, int64 end_index,
- int64 stride, tensorflow::StringPiece suffix,
- bool prevent_unrolling = false,
- bool prevent_vectorization = false);
+ std::unique_ptr<ForLoop> AddLoop(
+ int64 start_index, int64 end_index, int64 stride,
+ tensorflow::StringPiece suffix,
+ UnrollMode unroll_mode = xla::llvm_ir::UnrollMode::kDefaultUnroll,
+ bool prevent_vectorization = false);
// Like the above, except that it defaults to a stride of one.
- std::unique_ptr<ForLoop> AddLoop(int64 start_index, int64 end_index,
- tensorflow::StringPiece suffix,
- bool prevent_unrolling = false,
- bool prevent_vectorization = false);
+ std::unique_ptr<ForLoop> AddLoop(
+ int64 start_index, int64 end_index, tensorflow::StringPiece suffix,
+ UnrollMode unroll_mode = xla::llvm_ir::UnrollMode::kDefaultUnroll,
+ bool prevent_vectorization = false);
// Add loops to iterate through the indices within the specified
// shape. The returned index collects the induction variables of the
diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc
index 1d9c9e0678..296d04d436 100644
--- a/tensorflow/compiler/xla/service/local_service.cc
+++ b/tensorflow/compiler/xla/service/local_service.cc
@@ -30,7 +30,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
-#include "tensorflow/compiler/xla/service/versioned_computation_handle.h"
#include "tensorflow/compiler/xla/shape_layout.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.cc b/tensorflow/compiler/xla/service/multi_output_fusion.cc
new file mode 100644
index 0000000000..29f787b86b
--- /dev/null
+++ b/tensorflow/compiler/xla/service/multi_output_fusion.cc
@@ -0,0 +1,342 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/multi_output_fusion.h"
+
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/core/lib/gtl/flatmap.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+StatusOr<bool> MultiOutputFusion::Run(HloModule* module) {
+ bool changed = false;
+
+ for (auto* computation : module->MakeNonfusionComputations()) {
+ computation_ = computation;
+ reachability_ = computation_->ComputeReachability();
+ candidates_.clear();
+ candidates_index_.clear();
+ all_fusion_candidates_.clear();
+
+ int64 index = 0;
+ for (auto it : computation_->MakeInstructionPostOrder()) {
+ candidates_.emplace_back(it);
+ InsertOrDie(&candidates_index_, it, index++);
+ }
+
+ // Create the initial candidate list for each Node.
+ for (auto& node : candidates_) {
+ HloInstruction* instruction = node.hlo;
+ int64 instruction_id = get_candidate_id(instruction);
+ FusionCandidate& instr_node = candidates_[instruction_id];
+ if (!IsFusible(instruction)) {
+ continue;
+ }
+ all_fusion_candidates_.push_back(instruction);
+
+ std::vector<HloInstruction*> candidates;
+ tensorflow::gtl::FlatSet<HloInstruction*> candidates_set;
+ VLOG(10) << "Looking at instruction: " << instruction->name();
+ for (auto operand : instruction->operands()) {
+ // Filter out the non-interesting instructions -- they
+ // will not generate savings.
+ if (!IsProfitableOperand(operand)) {
+ VLOG(10) << "Operand not profitable: " << operand->name();
+ continue;
+ }
+ VLOG(10) << "Operand profitable: " << operand->name();
+ for (auto user : operand->users()) {
+ VLOG(10) << "User: " << user->name();
+ if (user == instruction || !IsFusible(user)) {
+ VLOG(10) << "User is not fusible, or is the instruction itself: "
+ << user->name();
+ continue;
+ }
+ int64 user_id = get_candidate_id(user);
+ if (is_connected(instruction, user)) {
+ VLOG(10) << "User is connected: " << user->name();
+ continue;
+ }
+ if (instruction_id < user_id &&
+ user->opcode() == HloOpcode::kFusion) {
+ VLOG(10) << "User ID for user: " << user->name() << " is "
+ << user_id << " which is higher than " << instruction_id;
+ continue;
+ }
+ if (!LegalToFuse(instruction, user)) {
+ VLOG(10) << "User not legal to fuse: " << user->name();
+ continue;
+ }
+ if (candidates_set.insert(user).second) {
+ VLOG(10) << "User added to candidate list: " << user->name();
+ candidates.push_back(user);
+ }
+ }
+ }
+
+ // Iterate over candidates rather than candidates_set to avoid
+ // nondeterminism.
+ for (auto candidate : candidates) {
+ int64 profit = GetProfit(instruction, candidate);
+ if (profit > 0) {
+ FusionCandidate& candidate_node =
+ candidates_[get_candidate_id(candidate)];
+ instr_node.fusibles.emplace_back(candidate, profit);
+ candidate_node.fusibles.emplace_back(instruction, profit);
+ worklist_.emplace(instruction, candidate, profit);
+ }
+ }
+ }
+ if (Perform()) {
+ changed = true;
+ }
+ }
+ return changed;
+}
+
+HloInstruction* MultiOutputFusion::Fuse(HloInstruction* instr1,
+ HloInstruction* instr2) {
+ HloInstruction* remaining = instr1;
+ HloInstruction* fused = instr2;
+ // Make sure that if only one of the instructions is a fusion, or if only one
+ // of the instructions is a multi-output fusion, it's what will be fused into.
+ //
+ // An invariant is that no bitcast nodes will show up in the middle of a
+ // fusion node. This invariant must hold in order for us to lower it. Given
+ // that, we require that during multi-output fusion, a fusion node ending
+ // with a bitcast preserves its structure as a nested fusion instead of
+ // being merged and flattened.
+ if (fused->opcode() == HloOpcode::kFusion &&
+ fused->fused_expression_root()->opcode() != HloOpcode::kBitcast) {
+ std::swap(remaining, fused);
+ }
+ if (fused->IsMultiOutputFusion()) {
+ std::swap(remaining, fused);
+ }
+
+ if (fused->opcode() == HloOpcode::kFusion &&
+ fused->fused_expression_root()->opcode() != HloOpcode::kBitcast) {
+ remaining->MergeFusionInstructionIntoMultiOutput(fused);
+ } else {
+ if (remaining->opcode() == HloOpcode::kFusion &&
+ remaining->fused_expression_root()->opcode() == HloOpcode::kBitcast) {
+ auto parent_computation = remaining->parent();
+ // Create a nested fusion node.
+ auto remaining_nested_fused =
+ parent_computation->AddInstruction(HloInstruction::CreateFusion(
+ remaining->shape(), HloInstruction::FusionKind::kLoop,
+ remaining));
+ TF_CHECK_OK(parent_computation->ReplaceInstruction(
+ remaining, remaining_nested_fused));
+ remaining = remaining_nested_fused;
+ }
+ remaining->FuseInstructionIntoMultiOutput(fused);
+ }
+
+ return remaining;
+}
+
+void MultiOutputFusion::Update(HloInstruction* instr1, HloInstruction* instr2) {
+ HloInstruction* fusion = instr1;
+ HloInstruction* fused = instr2;
+ if (is_fused(instr1)) {
+ fusion = instr2;
+ fused = instr1;
+ }
+
+ // Insert the newly created instruction (if any), to candidates_.
+ for (auto use : fusion->users()) {
+ if (candidates_index_.find(use) == candidates_index_.end()) {
+ int64 index = candidates_.size();
+ candidates_.emplace_back(use);
+ InsertOrDie(&candidates_index_, use, index++);
+ }
+ }
+ FusionCandidate& fusion_node = candidates_[get_candidate_id(fusion)];
+ FusionCandidate& fused_node = candidates_[get_candidate_id(fused)];
+
+ // Update the reachability graph.
+ UpdateReachability(fusion, fused, all_fusion_candidates_,
+ [this](HloInstruction* instr) { return is_fused(instr); });
+
+ // Update the fusible list for fusion. Variable new_fusibles keeps
+ // track of the new or changed entries.
+ std::vector<std::pair<HloInstruction*, int64>> new_fusibles;
+ tensorflow::gtl::FlatSet<HloInstruction*> in_list;
+ auto it = fusion_node.fusibles.begin();
+ while (it != fusion_node.fusibles.end()) {
+ HloInstruction* instr = it->first;
+ if (is_fused(instr) || is_connected(fusion, instr)) {
+ it = fusion_node.fusibles.erase(it);
+ continue;
+ }
+ in_list.insert(instr);
+ int64 profit = GetProfit(instr, fusion);
+ if (profit > it->second) {
+ it->second = profit;
+ new_fusibles.emplace_back(instr, profit);
+ }
+ ++it;
+ }
+
+ // fused_node has been fused into fusion_node. Take the fusion candidates
+ // (fusibles) from fused_node and add them to fusion_node's list. Filter
+ // out those fusibles that are no longer valid (or are already in the list).
+ for (const auto& it : fused_node.fusibles) {
+ HloInstruction* instr = it.first;
+ if (instr == fusion || is_fused(instr) || is_connected(fusion, instr)) {
+ continue;
+ }
+ if (in_list.count(instr) > 0) {
+ continue;
+ }
+ int64 profit = GetProfit(instr, fusion);
+ fusion_node.fusibles.emplace_back(instr, profit);
+ new_fusibles.emplace_back(instr, profit);
+ }
+ fused_node.fusibles.clear();
+
+ // Update the worklist_.
+ for (auto it : new_fusibles) {
+ worklist_.emplace(fusion, it.first, it.second);
+ }
+}
+
+bool MultiOutputFusion::LegalToFuse(HloInstruction* instr1,
+ HloInstruction* instr2) {
+ if (instr1 == instr2) {
+ return false;
+ }
+ if (instr1->opcode() != HloOpcode::kFusion) {
+ return false;
+ }
+
+ // Fusing nodes with 0 users makes no sense and the rest of the implementation
+ // doesn't support it either.
+ if (instr1->user_count() == 0 || instr2->user_count() == 0) {
+ return false;
+ }
+
+ // Check whether any user of a multi-output fusion is not a get-tuple-element.
+ // If so, we bail out because the transformation assumes that all such users
+ // are get-tuple-elements.
+ auto multioutput_user_is_not_gte = [](HloInstruction* instr) {
+ if (!instr->IsMultiOutputFusion()) {
+ return false;
+ }
+ for (auto user : instr->users()) {
+ if (user->opcode() != HloOpcode::kGetTupleElement) {
+ return true;
+ }
+ }
+ return false;
+ };
+ if (multioutput_user_is_not_gte(instr1) ||
+ multioutput_user_is_not_gte(instr2)) {
+ return false;
+ }
+
+ if (is_connected(instr1, instr2)) {
+ return false;
+ }
+ if (!ShapesCompatibleForFusion(instr1, instr2)) {
+ return false;
+ }
+
+ return true;
+}
+
+void MultiOutputFusion::UpdateReachability(
+ HloInstruction* instr1, HloInstruction* instr2,
+ tensorflow::gtl::ArraySlice<HloInstruction*> instrs_to_update,
+ const std::function<bool(HloInstruction*)>& skip) {
+ for (auto instr : instrs_to_update) {
+ if (skip != nullptr && skip(instr)) {
+ continue;
+ }
+ if (reachability_->IsReachable(instr2, instr) &&
+ reachability_->IsReachable(instr1, instr)) {
+ // If a candidate was already reachable by both, no update needed.
+ continue;
+ }
+ if (reachability_->IsReachable(instr2, instr)) {
+ reachability_->FastSetReachabilityToUnion({instr, instr1}, instr);
+ }
+ if (reachability_->IsReachable(instr1, instr)) {
+ reachability_->FastSetReachabilityToUnion({instr, instr2}, instr);
+ }
+ }
+}
+
+bool MultiOutputFusion::Perform() {
+ bool changed = false;
+ // Pick the top candidate from queue and try to merge.
+ while (!worklist_.empty()) {
+ if (fuel_ <= 0) {
+ VLOG(2) << "No fusing: run out of fuel.";
+ break;
+ }
+ ToBeFused candidate = worklist_.top();
+ worklist_.pop();
+
+ HloInstruction* instr1 = candidate.instr1;
+ HloInstruction* instr2 = candidate.instr2;
+
+ if (is_fused(instr1) || is_fused(instr2)) {
+ continue;
+ }
+
+ VLOG(1) << "Considering candidate profit_score=" << candidate.score
+ << "\n\t\tinstr1 = " << instr1->ToString()
+ << "\n\t\tinstr2 = " << instr2->ToString();
+
+ if (LegalToFuse(instr1, instr2)) {
+ VLOG(1) << "Fuse!";
+ VLOG(2) << "Before multi_output_fusion:";
+ VLOG(2) << "instr1: " << instr1->ToString();
+ VLOG(2) << "\n"
+ << instr1->fused_instructions_computation()->ToString(
+ HloPrintOptions().set_indent_amount(1));
+ VLOG(2) << "instr2: " << instr2->ToString();
+ if (instr2->opcode() == HloOpcode::kFusion) {
+ VLOG(2) << "\n"
+ << instr2->fused_instructions_computation()->ToString(
+ HloPrintOptions().set_indent_amount(1));
+ }
+ HloInstruction* ret = Fuse(instr1, instr2);
+ set_is_fused(ret == instr1 ? instr2 : instr1);
+ Update(instr1, instr2);
+ changed = true;
+ VLOG(2) << "After fusion, \t this: " << ret->name() << "\n"
+ << ret->fused_instructions_computation()->ToString(
+ HloPrintOptions().set_indent_amount(1));
+ auto users = ret->users();
+ --fuel_;
+ }
+ }
+ if (DoProducerConsumerMultiOutputFusion(computation_)) {
+ changed = true;
+ }
+ return changed;
+}
+
+bool MultiOutputFusion::DoProducerConsumerMultiOutputFusion(
+ HloComputation* /*computation*/) {
+ return false;
+}
+} // namespace xla
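
MultiOutputFusion above is an abstract pass: legality and profitability come from a backend subclass through the virtual hooks declared in the header that follows. A rough sketch of such a subclass, with an illustrative cost model that is not part of this change (a real backend would substitute its own heuristics):

    #include "tensorflow/compiler/xla/service/hlo_instruction.h"
    #include "tensorflow/compiler/xla/service/hlo_opcode.h"
    #include "tensorflow/compiler/xla/service/multi_output_fusion.h"
    #include "tensorflow/compiler/xla/shape_util.h"

    namespace xla {

    // Hypothetical subclass: only fusion instructions are candidates, shapes
    // must be compatible, and profit is the byte size of every non-scalar
    // operand read by both instructions (it would be read once after fusion).
    class ExampleMultiOutputFusion : public MultiOutputFusion {
     public:
      explicit ExampleMultiOutputFusion(int64 fuel) : MultiOutputFusion(fuel) {}

      bool ShapesCompatibleForFusion(HloInstruction* instr1,
                                     HloInstruction* instr2) override {
        return ShapeUtil::Compatible(instr1->shape(), instr2->shape());
      }

      bool IsFusible(HloInstruction* instr) override {
        return instr->opcode() == HloOpcode::kFusion;
      }

      bool IsProfitableOperand(HloInstruction* instr) override {
        // Scalars are cheap to re-read; only array operands count as savings.
        return !ShapeUtil::IsScalar(instr->shape());
      }

      int64 GetProfit(HloInstruction* instr1, HloInstruction* instr2) override {
        int64 profit = 0;
        for (auto operand : instr1->operands()) {
          if (IsProfitableOperand(operand) && instr2->IsUserOf(operand)) {
            profit += ShapeUtil::ByteSizeOf(operand->shape());
          }
        }
        return profit;
      }
    };

    }  // namespace xla
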
diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.h b/tensorflow/compiler/xla/service/multi_output_fusion.h
new file mode 100644
index 0000000000..cfdf83cfe8
--- /dev/null
+++ b/tensorflow/compiler/xla/service/multi_output_fusion.h
@@ -0,0 +1,160 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_MULTI_OUTPUT_FUSION_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_MULTI_OUTPUT_FUSION_H_
+
+#include <queue>
+#include <vector>
+
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+
+namespace xla {
+
+// This class implements the fusing of sibling fusion instructions that share
+// common operands.
+// It constructs the following associated data structures.
+// (1) candidates_: stores the instruction and the set of instructions it can
+// fuse to.
+// (2) candidates_index_: maps instruction to id.
+// (3) reachability_: reachability map in this computation.
+// (4) all_fusion_candidates_: the vector of candidate instructions.
+// (5) worklist_: a priority queue that contains pairs of instructions to be
+// fused and their fusion profit scores.
+//
+// Function Perform() applies the optimization. It picks the most profitable
+// pair in the worklist_, checks whether it is legal to fuse, and fuses the
+// pair. After fusion, it updates the associated structures such as
+// reachability_, candidates_ and worklist_.
+// Note that the reachability map is updated based on the original computation.
+// This works because the reachability is monotonically increasing with
+// instruction fusion.
+class MultiOutputFusion : public HloPassInterface {
+ public:
+ MultiOutputFusion(int64 fuel) : fuel_(fuel) {}
+
+ tensorflow::StringPiece name() const override {
+ return "multi_output_fusion";
+ }
+
+ // Run multi-output fusion on the given module. Returns whether the module
+ // was changed.
+ StatusOr<bool> Run(HloModule* module) override;
+
+ protected:
+ // Main entry for the optimization. Returns true if the optimization happens.
+ bool Perform();
+
+ // Test whether instr1 and instr2 have compatible shapes that can be legally
+ // fused.
+ virtual bool ShapesCompatibleForFusion(HloInstruction* instr1,
+ HloInstruction* instr2) = 0;
+
+ // Whether the instruction is a candidate for fusion.
+ virtual bool IsFusible(HloInstruction* instr) = 0;
+
+ // This function estimates the savings by merging instr1 and instr2 into one
+ // multi-output fusion instruction.
+ virtual int64 GetProfit(HloInstruction* instr1, HloInstruction* instr2) = 0;
+
+ // Whether fusing the instruction can reduce cost.
+ virtual bool IsProfitableOperand(HloInstruction* instr) = 0;
+
+ // Test if it's legal to fuse instr1 and instr2 into one fusion instruction.
+ virtual bool LegalToFuse(HloInstruction* instr1, HloInstruction* instr2);
+
+ // Update the reachability map after fusing instr1 and instr2.
+ void UpdateReachability(
+ HloInstruction* instr1, HloInstruction* instr2,
+ tensorflow::gtl::ArraySlice<HloInstruction*> instrs_to_update,
+ const std::function<bool(HloInstruction*)>& skip = nullptr);
+
+ // Hook for multi-output fusion along producer-consumer edges.
+ // Returns whether any instructions were fused.
+ //
+ // TODO(b/80420762): Perform producer-consumer multi-output fusion in
+ // InstructionFusion instead.
+ virtual bool DoProducerConsumerMultiOutputFusion(HloComputation* computation);
+
+ private:
+ // Fuse HloInstructions instr1 and instr2 and return the fused instruction.
+ // The other instruction is removed from its parent computation.
+ HloInstruction* Fuse(HloInstruction* instr1, HloInstruction* instr2);
+
+ // Update the internal data structures after instr1 and instr2 are fused into
+ // one fusion instruction.
+ void Update(HloInstruction* instr1, HloInstruction* instr2);
+
+ // Optimization fuel is a compiler debugging technique that makes an
+ // optimization pass stop what it is doing after having made N changes to the
+ // program, where N is the fuel. By varying N, this can be used to find the
+ // first single change that makes a test fail.
+ int64 fuel_;
+
+ // Computation for the pass.
+ HloComputation* computation_;
+
+ // An internal data structure for each instruction in current computation.
+ // When an instruction is removed, member 'hlo' is set to nullptr.
+ struct FusionCandidate {
+ HloInstruction* hlo;
+ std::list<std::pair<HloInstruction*, int64>> fusibles;
+ explicit FusionCandidate(HloInstruction* hlo) : hlo(hlo) {}
+ };
+ std::vector<FusionCandidate> candidates_;
+
+ // A map from an instruction to its index in candidates_.
+ tensorflow::gtl::FlatMap<HloInstruction*, int> candidates_index_;
+
+ // The reachability map of current computation.
+ std::unique_ptr<HloReachabilityMap> reachability_;
+
+ // This stores all the candidate instructions in current computation.
+ std::vector<HloInstruction*> all_fusion_candidates_;
+
+ // The pair of candidates to be fused and the profit score.
+ struct ToBeFused {
+ HloInstruction* instr1;
+ HloInstruction* instr2;
+ int64 score;
+ ToBeFused(HloInstruction* instr1, HloInstruction* instr2, int64 score)
+ : instr1(instr1), instr2(instr2), score(score) {}
+ bool operator<(const ToBeFused& rhs) const { return score < rhs.score; }
+ };
+ std::priority_queue<ToBeFused> worklist_;
+
+ int64 get_candidate_id(HloInstruction* instr) {
+ return FindOrDie(candidates_index_, instr);
+ }
+
+ bool is_fused(HloInstruction* instr) {
+ return candidates_[get_candidate_id(instr)].hlo == nullptr;
+ }
+
+ void set_is_fused(HloInstruction* instr) {
+ candidates_[get_candidate_id(instr)].hlo = nullptr;
+ }
+
+ bool is_connected(HloInstruction* instr1, HloInstruction* instr2) {
+ return reachability_->IsConnected(instr1, instr2);
+ }
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_MULTI_OUTPUT_FUSION_H_
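
If a backend wires such a subclass into its compilation pipeline, the fuel argument becomes the debugging knob described in the header comment above. A small, hypothetical sketch, assuming HloPassPipeline::AddPass and the ExampleMultiOutputFusion class sketched earlier:

    #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"

    // A very large budget effectively means "fuse every profitable pair";
    // lowering it to a small N stops the pass after N fusions, which helps
    // bisect a miscompile down to a single fusion decision.
    void AddMultiOutputFusionPass(xla::HloPassPipeline* pipeline,
                                  xla::int64 fuel) {
      pipeline->AddPass<xla::ExampleMultiOutputFusion>(fuel);
    }
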
diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h
index d64b2b4d0a..8748a4c144 100644
--- a/tensorflow/compiler/xla/service/service.h
+++ b/tensorflow/compiler/xla/service/service.h
@@ -26,14 +26,12 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/allocation_tracker.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/channel_tracker.h"
-#include "tensorflow/compiler/xla/service/compilation_cache.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/execution_tracker.h"
#include "tensorflow/compiler/xla/service/hlo_execution_profile.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
-#include "tensorflow/compiler/xla/service/versioned_computation_handle.h"
#include "tensorflow/compiler/xla/service_interface.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
@@ -297,9 +295,6 @@ class Service : public ServiceInterface {
// Tracks asynchronously launched executions via the API.
ExecutionTracker execution_tracker_;
- // Cache containing previously built Executables.
- CompilationCache compilation_cache_;
-
// Backend to compile and execute computations on.
std::unique_ptr<Backend> execute_backend_;
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index d624f548b1..bd98e86b08 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -44,129 +44,6 @@ namespace xla {
namespace {
-// Return the UnaryOperation proto enum value associated with the given HLO
-// opcode.
-UnaryOperation OpcodeToUnaryOperation(HloOpcode opcode) {
- switch (opcode) {
- case HloOpcode::kAbs:
- return UNOP_ABS;
- case HloOpcode::kCeil:
- return UNOP_CEIL;
- case HloOpcode::kClz:
- return UNOP_CLZ;
- case HloOpcode::kCos:
- return UNOP_COS;
- case HloOpcode::kExp:
- return UNOP_EXP;
- case HloOpcode::kExpm1:
- return UNOP_EXPM1;
- case HloOpcode::kFloor:
- return UNOP_FLOOR;
- case HloOpcode::kImag:
- return UNOP_IMAG;
- case HloOpcode::kIsFinite:
- return UNOP_IS_FINITE;
- case HloOpcode::kLog:
- return UNOP_LOG;
- case HloOpcode::kLog1p:
- return UNOP_LOG1P;
- case HloOpcode::kNot:
- return UNOP_NOT;
- case HloOpcode::kNegate:
- return UNOP_NEGATE;
- case HloOpcode::kReal:
- return UNOP_REAL;
- case HloOpcode::kRoundNearestAfz:
- return UNOP_ROUND_NEAREST_AFZ;
- case HloOpcode::kSign:
- return UNOP_SIGN;
- case HloOpcode::kSin:
- return UNOP_SIN;
- case HloOpcode::kSort:
- return UNOP_SORT;
- case HloOpcode::kTanh:
- return UNOP_TANH;
- default:
- LOG(FATAL) << "Unhandled opcode for conversion to unary operation: "
- << opcode;
- }
-}
-
-// Return the BinaryOperation proto enum value associated with the given HLO
-// opcode.
-BinaryOperation OpcodeToBinaryOperation(HloOpcode opcode) {
- switch (opcode) {
- case HloOpcode::kAtan2:
- return BINOP_ATAN2;
- case HloOpcode::kComplex:
- return BINOP_COMPLEX;
- case HloOpcode::kMultiply:
- return BINOP_MUL;
- case HloOpcode::kAdd:
- return BINOP_ADD;
- case HloOpcode::kSubtract:
- return BINOP_SUB;
- case HloOpcode::kDivide:
- return BINOP_DIV;
- case HloOpcode::kEq:
- return BINOP_EQ;
- case HloOpcode::kGe:
- return BINOP_GE;
- case HloOpcode::kGt:
- return BINOP_GT;
- case HloOpcode::kLe:
- return BINOP_LE;
- case HloOpcode::kLt:
- return BINOP_LT;
- case HloOpcode::kNe:
- return BINOP_NE;
- case HloOpcode::kMaximum:
- return BINOP_MAX;
- case HloOpcode::kMinimum:
- return BINOP_MIN;
- case HloOpcode::kPower:
- return BINOP_POW;
- case HloOpcode::kRemainder:
- return BINOP_REM;
- case HloOpcode::kOr:
- return BINOP_OR;
- case HloOpcode::kAnd:
- return BINOP_AND;
- case HloOpcode::kShiftLeft:
- return BINOP_SHIFT_LEFT;
- case HloOpcode::kShiftRightArithmetic:
- return BINOP_SHIFT_RIGHT_ARITHMETIC;
- case HloOpcode::kShiftRightLogical:
- return BINOP_SHIFT_RIGHT_LOGICAL;
- default:
- LOG(FATAL) << "unhandled opcode " << opcode;
- }
-}
-
-// Return the TernaryOperation proto enum value associated with the given HLO
-// opcode.
-TernaryOperation OpcodeToTernaryOperation(HloOpcode opcode) {
- switch (opcode) {
- case HloOpcode::kClamp:
- return TRIOP_CLAMP;
- case HloOpcode::kSelect:
- return TRIOP_SELECT;
- default:
- LOG(FATAL) << "unhandled opcode " << opcode;
- }
-}
-
-// Return the VariadicOperation proto enum value associated with the given HLO
-// opcode.
-VariadicOperation OpcodeToVariadicOperation(HloOpcode opcode) {
- switch (opcode) {
- case HloOpcode::kTuple:
- return VAROP_TUPLE;
- default:
- LOG(FATAL) << "unhandled opcode " << opcode;
- }
-}
-
// Returns true if no element is present in slice more than once.
bool AllUnique(tensorflow::gtl::ArraySlice<int64> slice) {
return std::set<int64>(slice.begin(), slice.end()).size() == slice.size();
@@ -321,84 +198,81 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
return shape;
}
- return InferUnaryOpShape(OpcodeToUnaryOperation(opcode), shape);
-}
+ TF_RETURN_IF_ERROR(
+ ExpectNotTupleOrOpaque(shape, "operand of unary operation"));
-/* static */ StatusOr<Shape> ShapeInference::InferUnaryOpShape(
- UnaryOperation operation, const Shape& arg) {
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(arg, "operand of unary operation"));
-
- TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(arg));
- switch (operation) {
- case UNOP_FLOOR:
- case UNOP_CEIL:
- if (!ShapeUtil::ElementIsFloating(arg)) {
+ TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape));
+ switch (opcode) {
+ case HloOpcode::kFloor:
+ case HloOpcode::kCeil:
+ if (!ShapeUtil::ElementIsFloating(shape)) {
return InvalidArgument(
"Expected element type in shape to be floating for floor/ceil "
"operation; got %s.",
- PrimitiveType_Name(arg.element_type()).c_str());
+ PrimitiveType_Name(shape.element_type()).c_str());
}
- return arg;
- case UNOP_COS:
- case UNOP_SIN:
- case UNOP_EXP:
- case UNOP_EXPM1:
- case UNOP_LOG:
- case UNOP_LOG1P:
- case UNOP_TANH:
- if (!ShapeUtil::ElementIsFloating(arg) &&
- !ShapeUtil::ElementIsComplex(arg)) {
+ return shape;
+ case HloOpcode::kCos:
+ case HloOpcode::kSin:
+ case HloOpcode::kExp:
+ case HloOpcode::kExpm1:
+ case HloOpcode::kLog:
+ case HloOpcode::kLog1p:
+ case HloOpcode::kTanh:
+ if (!ShapeUtil::ElementIsFloating(shape) &&
+ !ShapeUtil::ElementIsComplex(shape)) {
return InvalidArgument(
"Expected element type in shape to be floating or complex for "
"sin/cos/exp/log/tanh operation; got %s.",
- PrimitiveType_Name(arg.element_type()).c_str());
+ PrimitiveType_Name(shape.element_type()).c_str());
}
- return arg;
- case UNOP_REAL:
- case UNOP_IMAG:
- if (!ShapeUtil::ElementIsComplex(arg)) {
+ return shape;
+ case HloOpcode::kReal:
+ case HloOpcode::kImag:
+ if (!ShapeUtil::ElementIsComplex(shape)) {
return InvalidArgument(
"Expected element type in shape to be complex for real/imag "
"operation; got %s.",
- PrimitiveType_Name(arg.element_type()).c_str());
+ PrimitiveType_Name(shape.element_type()).c_str());
}
- return ShapeUtil::ChangeElementType(arg, F32);
- case UNOP_ABS:
- if (ShapeUtil::ElementIsComplex(arg)) {
+ return ShapeUtil::ChangeElementType(shape, F32);
+ case HloOpcode::kAbs:
+ if (ShapeUtil::ElementIsComplex(shape)) {
return ShapeUtil::ChangeElementType(
- arg, primitive_util::ComplexComponentType(arg.element_type()));
+ shape, primitive_util::ComplexComponentType(shape.element_type()));
}
- return arg;
- case UNOP_CLZ:
- case UNOP_NEGATE:
- case UNOP_ROUND_NEAREST_AFZ:
- case UNOP_SIGN:
- case UNOP_SORT:
- return arg;
-
- case UNOP_NOT:
- if (arg.element_type() != PRED &&
- !primitive_util::IsIntegralType(arg.element_type())) {
+ return shape;
+ case HloOpcode::kClz:
+ case HloOpcode::kNegate:
+ case HloOpcode::kRoundNearestAfz:
+ case HloOpcode::kSign:
+ case HloOpcode::kSort:
+ return shape;
+
+ case HloOpcode::kNot:
+ if (shape.element_type() != PRED &&
+ !primitive_util::IsIntegralType(shape.element_type())) {
return InvalidArgument(
"Expected pred or an integral element type in argument to Not "
"operation; got %s.",
- PrimitiveType_Name(arg.element_type()).c_str());
+ PrimitiveType_Name(shape.element_type()).c_str());
}
- return arg;
+ return shape;
- case UNOP_IS_FINITE:
- if (!ShapeUtil::ElementIsFloating(arg)) {
+ case HloOpcode::kIsFinite:
+ if (!ShapeUtil::ElementIsFloating(shape)) {
return InvalidArgument(
- "Expected element type in shape to be floating point for IsFinite "
+ "Expected element type in shape to be floating "
+ "point for IsFinite "
"operation; got %s.",
- PrimitiveType_Name(arg.element_type()).c_str());
+ PrimitiveType_Name(shape.element_type()).c_str());
}
- return ShapeUtil::ChangeElementType(arg, PRED);
+ return ShapeUtil::ChangeElementType(shape, PRED);
default:
return InvalidArgument(
"Unknown operation for unary shape inference: \"%s\".",
- UnaryOperation_Name(operation).c_str());
+ HloOpcodeString(opcode).c_str());
}
}
@@ -463,6 +337,17 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
return ShapeUtil::MakeShape(element_type, new_dimensions);
}
+/* static */ StatusOr<Shape> ShapeInference::InferTokenShape(
+ tensorflow::gtl::ArraySlice<const Shape*> arg_shapes) {
+ for (const Shape* arg_shape : arg_shapes) {
+ if (arg_shape->element_type() != TOKEN) {
+ return InvalidArgument(
+ "Operands of token instructions must be TOKEN types.");
+ }
+ }
+ return ShapeUtil::MakeTokenShape();
+}
+
/* static */ StatusOr<Shape> ShapeInference::InferConvertShape(
const Shape& operand_shape, PrimitiveType new_element_type) {
auto old_element_type = operand_shape.element_type();
@@ -768,8 +653,9 @@ Status ValidateDotDimensionNumbers(
}
/* static */ StatusOr<Shape>
-ShapeInference::InferDegenerateDimensionBroadcastShape(
- BinaryOperation operation, const Shape& lhs, const Shape& rhs) {
+ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
+ const Shape& lhs,
+ const Shape& rhs) {
TF_RET_CHECK(ShapeUtil::Rank(lhs) == ShapeUtil::Rank(rhs));
// The shapes have to be compatible. That is, if some dimension d has a
@@ -787,7 +673,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
} else {
return InvalidArgument(
"Binary op %s with incompatible shapes: %s and %s.",
- BinaryOperation_Name(operation).c_str(),
+ HloOpcodeString(operation).c_str(),
ShapeUtil::HumanString(lhs).c_str(),
ShapeUtil::HumanString(rhs).c_str());
}
@@ -797,8 +683,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
}
/* static */ StatusOr<Shape> ShapeInference::InferInDimBroadcastShape(
- BinaryOperation operation, const Shape& smaller_shape,
- const Shape& larger_shape,
+ const Shape& smaller_shape, const Shape& larger_shape,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
if (broadcast_dimensions.empty() && !ShapeUtil::IsScalar(smaller_shape)) {
// Reject "magic" inference for binops on different shapes, requiring
@@ -899,7 +784,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
}
/* static */ StatusOr<Shape> ShapeInference::InferElementwiseBinaryOpShape(
- BinaryOperation operation, const Shape& lhs, const Shape& rhs,
+ HloOpcode operation, const Shape& lhs, const Shape& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
TF_RETURN_IF_ERROR(
ExpectNotTupleOrOpaque(lhs, "lhs of elementwise binary operation"));
@@ -909,8 +794,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) {
return InvalidArgument(
"Binary op %s with different element types: %s and %s.",
- BinaryOperation_Name(operation).c_str(),
- ShapeUtil::HumanString(lhs).c_str(),
+ HloOpcodeString(operation).c_str(), ShapeUtil::HumanString(lhs).c_str(),
ShapeUtil::HumanString(rhs).c_str());
}
@@ -943,10 +827,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
ShapeUtil::Rank(lhs) > ShapeUtil::Rank(rhs) ? rhs : lhs;
// After InDim broadcasting, perform degenerate dimensions broadcasting.
- TF_ASSIGN_OR_RETURN(
- Shape indim_broadcast_shape,
- InferInDimBroadcastShape(operation, smaller_shape, larger_shape,
- broadcast_dimensions));
+ TF_ASSIGN_OR_RETURN(Shape indim_broadcast_shape,
+ InferInDimBroadcastShape(smaller_shape, larger_shape,
+ broadcast_dimensions));
return InferDegenerateDimensionBroadcastShape(
operation, indim_broadcast_shape, larger_shape);
@@ -955,51 +838,44 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
/* static */ StatusOr<Shape> ShapeInference::InferBinaryOpShape(
HloOpcode opcode, const HloInstruction* lhs, const HloInstruction* rhs) {
- return InferBinaryOpShape(OpcodeToBinaryOperation(opcode), lhs->shape(),
- rhs->shape(), /*broadcast_dimensions=*/{});
+ return InferBinaryOpShape(opcode, lhs->shape(), rhs->shape(),
+ /*broadcast_dimensions=*/{});
}
/* static */ StatusOr<Shape> ShapeInference::InferBinaryOpShape(
HloOpcode opcode, const Shape& lhs, const Shape& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
- return InferBinaryOpShape(OpcodeToBinaryOperation(opcode), lhs, rhs,
- broadcast_dimensions);
-}
-
-/* static */ StatusOr<Shape> ShapeInference::InferBinaryOpShape(
- BinaryOperation operation, const Shape& lhs, const Shape& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
VLOG(2) << tensorflow::strings::Printf(
"inferring shape for <%s>(%s, %s) with broadcast_dimensions={%s}",
- BinaryOperation_Name(operation).c_str(),
- ShapeUtil::HumanString(lhs).c_str(), ShapeUtil::HumanString(rhs).c_str(),
+ HloOpcodeString(opcode).c_str(), ShapeUtil::HumanString(lhs).c_str(),
+ ShapeUtil::HumanString(rhs).c_str(),
Join(broadcast_dimensions, ", ").c_str());
TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs));
TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs));
TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
lhs, tensorflow::strings::StrCat("lhs of binary operation ",
- BinaryOperation_Name(operation))));
+ HloOpcodeString(opcode))));
TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
rhs, tensorflow::strings::StrCat("rhs of binary operation ",
- BinaryOperation_Name(operation))));
- switch (operation) {
- case BINOP_MAX:
- case BINOP_MIN:
- case BINOP_SUB:
- case BINOP_ADD:
- case BINOP_ATAN2:
- case BINOP_POW:
- case BINOP_DIV:
- case BINOP_REM:
- case BINOP_MUL:
- case BINOP_SHIFT_LEFT:
- case BINOP_SHIFT_RIGHT_ARITHMETIC:
- case BINOP_SHIFT_RIGHT_LOGICAL:
- return InferElementwiseBinaryOpShape(operation, lhs, rhs,
+ HloOpcodeString(opcode))));
+ switch (opcode) {
+ case HloOpcode::kMaximum:
+ case HloOpcode::kMinimum:
+ case HloOpcode::kSubtract:
+ case HloOpcode::kAdd:
+ case HloOpcode::kAtan2:
+ case HloOpcode::kPower:
+ case HloOpcode::kDivide:
+ case HloOpcode::kRemainder:
+ case HloOpcode::kMultiply:
+ case HloOpcode::kShiftLeft:
+ case HloOpcode::kShiftRightArithmetic:
+ case HloOpcode::kShiftRightLogical:
+ return InferElementwiseBinaryOpShape(opcode, lhs, rhs,
broadcast_dimensions);
- case BINOP_COMPLEX: {
+ case HloOpcode::kComplex: {
if (!ShapeUtil::ElementIsFloating(lhs)) {
return InvalidArgument(
"Expected element type in shape to be floating for complex compose "
@@ -1007,7 +883,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
PrimitiveType_Name(lhs.element_type()).c_str());
}
TF_ASSIGN_OR_RETURN(const Shape& shape,
- InferElementwiseBinaryOpShape(operation, lhs, rhs,
+ InferElementwiseBinaryOpShape(opcode, lhs, rhs,
broadcast_dimensions));
if (lhs.element_type() == F32 && rhs.element_type() == F32) {
return ShapeUtil::ChangeElementType(shape, C64);
@@ -1015,8 +891,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
return Unimplemented("Complex component type is not implemented.");
}
}
- case BINOP_AND:
- case BINOP_OR:
+ case HloOpcode::kAnd:
+ case HloOpcode::kOr:
if (lhs.element_type() != PRED &&
!primitive_util::IsIntegralType(lhs.element_type())) {
return InvalidArgument(
@@ -1024,24 +900,24 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
"got %s.",
PrimitiveType_Name(lhs.element_type()).c_str());
}
- return InferElementwiseBinaryOpShape(operation, lhs, rhs,
+ return InferElementwiseBinaryOpShape(opcode, lhs, rhs,
broadcast_dimensions);
- case BINOP_EQ:
- case BINOP_GE:
- case BINOP_GT:
- case BINOP_LE:
- case BINOP_LT:
- case BINOP_NE: {
+ case HloOpcode::kEq:
+ case HloOpcode::kGe:
+ case HloOpcode::kGt:
+ case HloOpcode::kLe:
+ case HloOpcode::kLt:
+ case HloOpcode::kNe: {
TF_ASSIGN_OR_RETURN(const Shape& shape,
- InferElementwiseBinaryOpShape(operation, lhs, rhs,
+ InferElementwiseBinaryOpShape(opcode, lhs, rhs,
broadcast_dimensions));
return ShapeUtil::ChangeElementType(shape, PRED);
}
default:
return Unimplemented(
"Binary op shape inference: %s; lhs: %s; rhs: %s is not implemented.",
- BinaryOperation_Name(operation).c_str(),
- lhs.ShortDebugString().c_str(), rhs.ShortDebugString().c_str());
+ HloOpcodeString(opcode).c_str(), lhs.ShortDebugString().c_str(),
+ rhs.ShortDebugString().c_str());
}
}
@@ -1053,23 +929,17 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
/* static */ StatusOr<Shape> ShapeInference::InferTernaryOpShape(
HloOpcode opcode, const Shape& lhs, const Shape& rhs, const Shape& ehs) {
- return InferTernaryOpShape(OpcodeToTernaryOperation(opcode), lhs, rhs, ehs);
-}
-
-/* static */ StatusOr<Shape> ShapeInference::InferTernaryOpShape(
- TernaryOperation operation, const Shape& lhs, const Shape& rhs,
- const Shape& ehs) {
TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs));
TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs));
TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(ehs));
- switch (operation) {
- case TRIOP_CLAMP:
+ switch (opcode) {
+ case HloOpcode::kClamp:
return InferClampShape(lhs, rhs, ehs);
- case TRIOP_SELECT:
+ case HloOpcode::kSelect:
return InferSelectShape(lhs, rhs, ehs);
default:
return InvalidArgument("Unknown operation %s.",
- TernaryOperation_Name(operation).c_str());
+ HloOpcodeString(opcode).c_str());
}
}
@@ -1086,18 +956,11 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
/* static */ StatusOr<Shape> ShapeInference::InferVariadicOpShape(
HloOpcode opcode,
tensorflow::gtl::ArraySlice<const Shape*> operand_shapes) {
- return InferVariadicOpShape(OpcodeToVariadicOperation(opcode),
- operand_shapes);
-}
-
-/* static */ StatusOr<Shape> ShapeInference::InferVariadicOpShape(
- VariadicOperation operation,
- tensorflow::gtl::ArraySlice<const Shape*> operand_shapes) {
for (const Shape* shape : operand_shapes) {
TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(*shape));
}
- switch (operation) {
- case VAROP_TUPLE: {
+ switch (opcode) {
+ case HloOpcode::kTuple: {
Shape result = ShapeUtil::MakeTupleShape({});
for (const Shape* shape : operand_shapes) {
ShapeUtil::AppendShapeToTuple(*shape, &result);
@@ -1106,7 +969,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
}
default:
return InvalidArgument("Unknown operation %s.",
- VariadicOperation_Name(operation).c_str());
+ HloOpcodeString(opcode).c_str());
}
}
diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h
index 9da2c99b41..f1f7b50902 100644
--- a/tensorflow/compiler/xla/service/shape_inference.h
+++ b/tensorflow/compiler/xla/service/shape_inference.h
@@ -46,8 +46,6 @@ class ShapeInference {
public:
// Infers the shape produced by applying the given unary operation to the
// given input shape.
- static StatusOr<Shape> InferUnaryOpShape(UnaryOperation operation,
- const Shape& arg);
static StatusOr<Shape> InferUnaryOpShape(HloOpcode opcode,
const Shape& shape);
static StatusOr<Shape> InferUnaryOpShape(HloOpcode opcode,
@@ -56,9 +54,6 @@ class ShapeInference {
// Infers the shape produced by applying the given binary operation to the
// given input shapes.
static StatusOr<Shape> InferBinaryOpShape(
- BinaryOperation operation, const Shape& lhs, const Shape& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
- static StatusOr<Shape> InferBinaryOpShape(
HloOpcode opcode, const Shape& lhs, const Shape& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
static StatusOr<Shape> InferBinaryOpShape(HloOpcode opcode,
@@ -67,9 +62,6 @@ class ShapeInference {
// Infers the shape produced by applying the given ternary operation to the
// given input shapes.
- static StatusOr<Shape> InferTernaryOpShape(TernaryOperation operation,
- const Shape& lhs, const Shape& rhs,
- const Shape& ehs);
static StatusOr<Shape> InferTernaryOpShape(HloOpcode opcode, const Shape& lhs,
const Shape& rhs,
const Shape& ehs);
@@ -81,9 +73,6 @@ class ShapeInference {
// Infers the shape produced by applying the given variadic operation to the
// given input operand shapes.
static StatusOr<Shape> InferVariadicOpShape(
- VariadicOperation operation,
- tensorflow::gtl::ArraySlice<const Shape*> operand_shapes);
- static StatusOr<Shape> InferVariadicOpShape(
HloOpcode opcode,
tensorflow::gtl::ArraySlice<const Shape*> operand_shapes);
static StatusOr<Shape> InferVariadicOpShape(
@@ -227,6 +216,13 @@ class ShapeInference {
static StatusOr<Shape> InferConcatOpShape(
tensorflow::gtl::ArraySlice<const Shape*> arg_shapes, int64 dimension);
+ // Infers the shape produced by a kGenerateToken operation. Trivially this
+ // shape is always a TOKEN shape. However, ShapeInference serves two purposes:
+ // inferring shapes and checking operand shapes. This method verifies that the
+ // operand shapes are all TOKENs.
+ static StatusOr<Shape> InferTokenShape(
+ tensorflow::gtl::ArraySlice<const Shape*> arg_shapes);
+
// Helper that validates the given operand shape can be converted to the
// target output_shape via a convert instruction -- the requirement is that
// the shape is identical except for the element type.
@@ -279,7 +275,7 @@ class ShapeInference {
// the LHS and a single element in the RHS to produce a single output element,
// even in the presence of broadcasting of one of the operands over the other.
static StatusOr<Shape> InferElementwiseBinaryOpShape(
- BinaryOperation operation, const Shape& lhs, const Shape& rhs,
+ HloOpcode operation, const Shape& lhs, const Shape& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
// Helper for inferring the shape of Clamp ops.
@@ -295,7 +291,7 @@ class ShapeInference {
// dimension broadcasting (a dimension of size 1 in one operand is broadcast
// up to match the size of the dimension in the other operand).
static StatusOr<Shape> InferDegenerateDimensionBroadcastShape(
- BinaryOperation operation, const Shape& lhs, const Shape& rhs);
+ HloOpcode operation, const Shape& lhs, const Shape& rhs);
// Helper for inferring shapes of binary operations using "InDim"
// broadcasting. This is the broadcasting used in the *InDim binary operations
@@ -303,8 +299,7 @@ class ShapeInference {
// lower-rank shape than larger_shape. Returns the shape that the
// smaller_shape is broadcast to.
static StatusOr<Shape> InferInDimBroadcastShape(
- BinaryOperation operation, const Shape& smaller_shape,
- const Shape& larger_shape,
+ const Shape& smaller_shape, const Shape& larger_shape,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
TF_DISALLOW_COPY_AND_ASSIGN(ShapeInference);
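
With the proto-enum overloads removed, callers drive shape inference directly with HloOpcode. A brief sketch of the new call style; the shapes and the InferExampleShapes wrapper here are illustrative only:

    #include "tensorflow/compiler/xla/service/hlo_opcode.h"
    #include "tensorflow/compiler/xla/service/shape_inference.h"
    #include "tensorflow/compiler/xla/shape_util.h"
    #include "tensorflow/compiler/xla/status_macros.h"

    namespace xla {

    StatusOr<Shape> InferExampleShapes() {
      const Shape vec = ShapeUtil::MakeShape(F32, {16});
      const Shape scalar = ShapeUtil::MakeShape(F32, {});

      // Unary and binary inference now take the HLO opcode directly instead
      // of the removed UnaryOperation/BinaryOperation proto enums.
      TF_ASSIGN_OR_RETURN(Shape negated, ShapeInference::InferUnaryOpShape(
                                             HloOpcode::kNegate, vec));
      TF_ASSIGN_OR_RETURN(Shape sum,
                          ShapeInference::InferBinaryOpShape(
                              HloOpcode::kAdd, negated, scalar,
                              /*broadcast_dimensions=*/{}));

      // The new token helper only verifies that every operand is a TOKEN.
      const Shape token = ShapeUtil::MakeTokenShape();
      TF_ASSIGN_OR_RETURN(Shape token_shape,
                          ShapeInference::InferTokenShape({&token}));
      (void)token_shape;
      return sum;
    }

    }  // namespace xla
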
diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc
index 0e61994a78..6d017dffe2 100644
--- a/tensorflow/compiler/xla/service/shape_inference_test.cc
+++ b/tensorflow/compiler/xla/service/shape_inference_test.cc
@@ -101,8 +101,8 @@ class SelectAndScatterShapeInferenceTest : public ShapeInferenceTest {
TEST_F(ShapeInferenceTest, UnaryNegateMatrix) {
Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64});
- auto inferred_status = ShapeInference::InferUnaryOpShape(
- UnaryOperation::UNOP_NEGATE, matrix_shape);
+ auto inferred_status =
+ ShapeInference::InferUnaryOpShape(HloOpcode::kNegate, matrix_shape);
ASSERT_IS_OK(inferred_status.status());
ASSERT_TRUE(ShapeUtil::Equal(matrix_shape, inferred_status.ValueOrDie()));
}
@@ -110,14 +110,14 @@ TEST_F(ShapeInferenceTest, UnaryNegateMatrix) {
TEST_F(ShapeInferenceTest, SelectScalarPredBetweenTuples) {
Shape tuple = ShapeUtil::MakeTupleShape({s32_, f32_});
auto inferred_status = ShapeInference::InferTernaryOpShape(
- TernaryOperation::TRIOP_SELECT, pred_, tuple, tuple);
+ HloOpcode::kSelect, pred_, tuple, tuple);
ASSERT_IS_OK(inferred_status.status());
ASSERT_TRUE(ShapeUtil::Equal(tuple, inferred_status.ValueOrDie()));
}
TEST_F(ShapeInferenceTest, SelectScalarPredBetweenArrays) {
auto inferred_status = ShapeInference::InferTernaryOpShape(
- TernaryOperation::TRIOP_SELECT, pred_, matrix_64_48_, matrix_64_48_);
+ HloOpcode::kSelect, pred_, matrix_64_48_, matrix_64_48_);
ASSERT_IS_OK(inferred_status.status());
ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie()));
}
@@ -125,34 +125,34 @@ TEST_F(ShapeInferenceTest, SelectScalarPredBetweenArrays) {
TEST_F(ShapeInferenceTest, SelectArrayPredBetweenArrays) {
auto predarray = ShapeUtil::MakeShape(PRED, {64, 48});
auto inferred_status = ShapeInference::InferTernaryOpShape(
- TernaryOperation::TRIOP_SELECT, predarray, matrix_64_48_, matrix_64_48_);
+ HloOpcode::kSelect, predarray, matrix_64_48_, matrix_64_48_);
ASSERT_IS_OK(inferred_status.status());
ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie()));
}
TEST_F(ShapeInferenceTest, SelectBadShapes) {
auto inferred_status_error1 = ShapeInference::InferTernaryOpShape(
- TernaryOperation::TRIOP_SELECT, pred_, matrix_64_48_, matrix_32_64_);
+ HloOpcode::kSelect, pred_, matrix_64_48_, matrix_32_64_);
ASSERT_FALSE(inferred_status_error1.ok());
ASSERT_THAT(inferred_status_error1.status().error_message(),
HasSubstr("Operands to select must be the same shape"));
auto inferred_status_error2 = ShapeInference::InferTernaryOpShape(
- TernaryOperation::TRIOP_SELECT, s32_, matrix_64_48_, matrix_64_48_);
+ HloOpcode::kSelect, s32_, matrix_64_48_, matrix_64_48_);
ASSERT_FALSE(inferred_status_error2.ok());
ASSERT_THAT(inferred_status_error2.status().error_message(),
HasSubstr("pred operand must have PRED"));
auto inferred_status_error3 = ShapeInference::InferTernaryOpShape(
- TernaryOperation::TRIOP_SELECT, ShapeUtil::MakeShape(PRED, {64}),
- matrix_64_48_, matrix_64_48_);
+ HloOpcode::kSelect, ShapeUtil::MakeShape(PRED, {64}), matrix_64_48_,
+ matrix_64_48_);
ASSERT_FALSE(inferred_status_error3.ok());
ASSERT_THAT(inferred_status_error3.status().error_message(),
HasSubstr("with non-scalar predicate with dimensionality"));
// Tuples have a TUPLE element type and cannot be the pred of a select.
auto inferred_status_error4 = ShapeInference::InferTernaryOpShape(
- TernaryOperation::TRIOP_SELECT, ShapeUtil::MakeTupleShape({pred_, pred_}),
+ HloOpcode::kSelect, ShapeUtil::MakeTupleShape({pred_, pred_}),
ShapeUtil::MakeTupleShape({f32_, f32_}),
ShapeUtil::MakeTupleShape({f32_, f32_}));
ASSERT_FALSE(inferred_status_error4.ok());
@@ -162,102 +162,98 @@ TEST_F(ShapeInferenceTest, SelectBadShapes) {
TEST_F(ShapeInferenceTest, ClampAllMatrix) {
auto inferred_status = ShapeInference::InferTernaryOpShape(
- TernaryOperation::TRIOP_CLAMP, matrix_64_48_, matrix_64_48_,
- matrix_64_48_);
+ HloOpcode::kClamp, matrix_64_48_, matrix_64_48_, matrix_64_48_);
ASSERT_IS_OK(inferred_status.status());
ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie()));
}
TEST_F(ShapeInferenceTest, ClampAllScalar) {
- auto inferred_status = ShapeInference::InferTernaryOpShape(
- TernaryOperation::TRIOP_CLAMP, f32_, f32_, f32_);
+ auto inferred_status =
+ ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, f32_, f32_, f32_);
ASSERT_IS_OK(inferred_status.status());
ASSERT_TRUE(ShapeUtil::Equal(f32_, inferred_status.ValueOrDie()));
}
TEST_F(ShapeInferenceTest, ClampMinScalar) {
auto inferred_status = ShapeInference::InferTernaryOpShape(
- TernaryOperation::TRIOP_CLAMP, f32_, matrix_64_48_, matrix_64_48_);
+ HloOpcode::kClamp, f32_, matrix_64_48_, matrix_64_48_);
ASSERT_IS_OK(inferred_status.status());
ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie()));
}
TEST_F(ShapeInferenceTest, ClampMaxScalar) {
auto inferred_status = ShapeInference::InferTernaryOpShape(
- TernaryOperation::TRIOP_CLAMP, matrix_64_48_, matrix_64_48_, f32_);
+ HloOpcode::kClamp, matrix_64_48_, matrix_64_48_, f32_);
ASSERT_IS_OK(inferred_status.status());
ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie()));
}
TEST_F(ShapeInferenceTest, ClampOperandScalar) {
auto inferred_status = ShapeInference::InferTernaryOpShape(
- TernaryOperation::TRIOP_CLAMP, matrix_64_48_, f32_, matrix_64_48_);
+ HloOpcode::kClamp, matrix_64_48_, f32_, matrix_64_48_);
ASSERT_IS_OK(inferred_status.status());
ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie()));
}
TEST_F(ShapeInferenceTest, ClampMinMatrix) {
auto inferred_status = ShapeInference::InferTernaryOpShape(
- TernaryOperation::TRIOP_CLAMP, matrix_64_48_, f32_, f32_);
+ HloOpcode::kClamp, matrix_64_48_, f32_, f32_);
ASSERT_IS_OK(inferred_status.status());
ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie()));
}
TEST_F(ShapeInferenceTest, ClampMaxMatrix) {
auto inferred_status = ShapeInference::InferTernaryOpShape(
- TernaryOperation::TRIOP_CLAMP, f32_, f32_, matrix_64_48_);
+ HloOpcode::kClamp, f32_, f32_, matrix_64_48_);
ASSERT_IS_OK(inferred_status.status());
ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie()));
}
TEST_F(ShapeInferenceTest, ClampOperandMatrix) {
auto inferred_status = ShapeInference::InferTernaryOpShape(
- TernaryOperation::TRIOP_CLAMP, f32_, matrix_64_48_, f32_);
+ HloOpcode::kClamp, f32_, matrix_64_48_, f32_);
ASSERT_IS_OK(inferred_status.status());
ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie()));
}
TEST_F(ShapeInferenceTest, ClampBadShapes) {
// Type mismatch
- ASSERT_FALSE(ShapeInference::InferTernaryOpShape(
- TernaryOperation::TRIOP_CLAMP, s32_, f32_, f32_)
- .ok());
- ASSERT_FALSE(ShapeInference::InferTernaryOpShape(
- TernaryOperation::TRIOP_CLAMP, f32_, s32_, f32_)
- .ok());
- ASSERT_FALSE(ShapeInference::InferTernaryOpShape(
- TernaryOperation::TRIOP_CLAMP, f32_, f32_, s32_)
- .ok());
- // Dimension mismatch
ASSERT_FALSE(
- ShapeInference::InferTernaryOpShape(TernaryOperation::TRIOP_CLAMP,
- vector_64_, vector_32_, vector_32_)
+ ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, s32_, f32_, f32_)
.ok());
ASSERT_FALSE(
- ShapeInference::InferTernaryOpShape(TernaryOperation::TRIOP_CLAMP,
- vector_32_, vector_64_, vector_32_)
+ ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, f32_, s32_, f32_)
.ok());
ASSERT_FALSE(
- ShapeInference::InferTernaryOpShape(TernaryOperation::TRIOP_CLAMP,
- vector_32_, vector_32_, vector_64_)
+ ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, f32_, f32_, s32_)
.ok());
- // Dimension mismatch, where one operand is a scalar
+ // Dimension mismatch
ASSERT_FALSE(ShapeInference::InferTernaryOpShape(
- TernaryOperation::TRIOP_CLAMP, vector_64_, vector_32_, f32_)
+ HloOpcode::kClamp, vector_64_, vector_32_, vector_32_)
.ok());
ASSERT_FALSE(ShapeInference::InferTernaryOpShape(
- TernaryOperation::TRIOP_CLAMP, vector_64_, f32_, vector_32_)
+ HloOpcode::kClamp, vector_32_, vector_64_, vector_32_)
.ok());
ASSERT_FALSE(ShapeInference::InferTernaryOpShape(
- TernaryOperation::TRIOP_CLAMP, f32_, vector_64_, vector_32_)
+ HloOpcode::kClamp, vector_32_, vector_32_, vector_64_)
+ .ok());
+ // Dimension mismatch, where one operand is a scalar
+ ASSERT_FALSE(ShapeInference::InferTernaryOpShape(HloOpcode::kClamp,
+ vector_64_, vector_32_, f32_)
+ .ok());
+ ASSERT_FALSE(ShapeInference::InferTernaryOpShape(HloOpcode::kClamp,
+ vector_64_, f32_, vector_32_)
+ .ok());
+ ASSERT_FALSE(ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, f32_,
+ vector_64_, vector_32_)
.ok());
}
TEST_F(ShapeInferenceTest, Complex) {
auto complex_shape = [&](const Shape& lhs, const Shape& rhs,
const tensorflow::gtl::ArraySlice<int64>& bcast) {
- return ShapeInference::InferBinaryOpShape(BinaryOperation::BINOP_COMPLEX,
- lhs, rhs, bcast);
+ return ShapeInference::InferBinaryOpShape(HloOpcode::kComplex, lhs, rhs,
+ bcast);
};
// Inputs must be FP.
ASSERT_FALSE(complex_shape(s32_, s32_, {}).ok());
@@ -292,8 +288,8 @@ TEST_F(ShapeInferenceTest, Complex) {
}
TEST_F(ShapeInferenceTest, VariadicOpTuplify) {
- StatusOr<Shape> result = ShapeInference::InferVariadicOpShape(
- VariadicOperation::VAROP_TUPLE, {&s32_, &f32_});
+ StatusOr<Shape> result =
+ ShapeInference::InferVariadicOpShape(HloOpcode::kTuple, {&s32_, &f32_});
ASSERT_IS_OK(result.status());
ASSERT_TRUE(ShapeUtil::Equal(result.ValueOrDie(),
ShapeUtil::MakeTupleShape({s32_, f32_})));
@@ -804,8 +800,8 @@ TEST_F(ShapeInferenceTest, InferConstIndexShape) {
TEST_F(ShapeInferenceTest, InferPowShape) {
auto ten_floats = ShapeUtil::MakeShape(F32, {10});
- auto inferred_status =
- ShapeInference::InferBinaryOpShape(BINOP_POW, ten_floats, f32_, {});
+ auto inferred_status = ShapeInference::InferBinaryOpShape(
+ HloOpcode::kPower, ten_floats, f32_, {});
ASSERT_IS_OK(inferred_status.status());
ASSERT_TRUE(ShapeUtil::Equal(ten_floats, inferred_status.ValueOrDie()));
}
@@ -813,7 +809,7 @@ TEST_F(ShapeInferenceTest, InferPowShape) {
TEST_F(ShapeInferenceTest, InferCompareShapeEq) {
auto ten_floats = ShapeUtil::MakeShape(F32, {10});
auto inferred_status =
- ShapeInference::InferBinaryOpShape(BINOP_EQ, ten_floats, f32_, {});
+ ShapeInference::InferBinaryOpShape(HloOpcode::kEq, ten_floats, f32_, {});
ASSERT_IS_OK(inferred_status.status());
ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}),
inferred_status.ValueOrDie()));
@@ -822,7 +818,7 @@ TEST_F(ShapeInferenceTest, InferCompareShapeEq) {
TEST_F(ShapeInferenceTest, InferCompareShapeGe) {
auto ten_floats = ShapeUtil::MakeShape(F32, {10});
auto inferred_status =
- ShapeInference::InferBinaryOpShape(BINOP_GE, ten_floats, f32_, {});
+ ShapeInference::InferBinaryOpShape(HloOpcode::kGe, ten_floats, f32_, {});
ASSERT_IS_OK(inferred_status.status());
ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}),
inferred_status.ValueOrDie()));
@@ -831,7 +827,7 @@ TEST_F(ShapeInferenceTest, InferCompareShapeGe) {
TEST_F(ShapeInferenceTest, InferCompareShapeGt) {
auto ten_floats = ShapeUtil::MakeShape(F32, {10});
auto inferred_status =
- ShapeInference::InferBinaryOpShape(BINOP_GT, ten_floats, f32_, {});
+ ShapeInference::InferBinaryOpShape(HloOpcode::kGt, ten_floats, f32_, {});
ASSERT_IS_OK(inferred_status.status());
ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}),
inferred_status.ValueOrDie()));
@@ -840,7 +836,7 @@ TEST_F(ShapeInferenceTest, InferCompareShapeGt) {
TEST_F(ShapeInferenceTest, InferCompareShapeLe) {
auto ten_floats = ShapeUtil::MakeShape(F32, {10});
auto inferred_status =
- ShapeInference::InferBinaryOpShape(BINOP_LE, ten_floats, f32_, {});
+ ShapeInference::InferBinaryOpShape(HloOpcode::kLe, ten_floats, f32_, {});
ASSERT_IS_OK(inferred_status.status());
ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}),
inferred_status.ValueOrDie()));
@@ -849,7 +845,7 @@ TEST_F(ShapeInferenceTest, InferCompareShapeLe) {
TEST_F(ShapeInferenceTest, InferCompareShapeLt) {
auto ten_floats = ShapeUtil::MakeShape(F32, {10});
auto inferred_status =
- ShapeInference::InferBinaryOpShape(BINOP_LT, ten_floats, f32_, {});
+ ShapeInference::InferBinaryOpShape(HloOpcode::kLt, ten_floats, f32_, {});
ASSERT_IS_OK(inferred_status.status());
ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}),
inferred_status.ValueOrDie()));
@@ -858,7 +854,7 @@ TEST_F(ShapeInferenceTest, InferCompareShapeLt) {
TEST_F(ShapeInferenceTest, InferCompareShapeNe) {
auto ten_floats = ShapeUtil::MakeShape(F32, {10});
auto inferred_status =
- ShapeInference::InferBinaryOpShape(BINOP_NE, ten_floats, f32_, {});
+ ShapeInference::InferBinaryOpShape(HloOpcode::kNe, ten_floats, f32_, {});
ASSERT_IS_OK(inferred_status.status());
ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}),
inferred_status.ValueOrDie()));
@@ -1111,22 +1107,22 @@ TEST_F(ShapeInferenceTest, BinOpBroadcastMatrixVector) {
const Shape vec8 = ShapeUtil::MakeShape(F32, {8});
const Shape vec16 = ShapeUtil::MakeShape(F32, {16});
- auto inferred_status_match = ShapeInference::InferBinaryOpShape(
- BinaryOperation::BINOP_ADD, mat, vec8, {1});
+ auto inferred_status_match =
+ ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, mat, vec8, {1});
ASSERT_IS_OK(inferred_status_match.status());
ASSERT_TRUE(ShapeUtil::Equal(inferred_status_match.ValueOrDie(), mat));
- auto inferred_status_mismatch = ShapeInference::InferBinaryOpShape(
- BinaryOperation::BINOP_ADD, mat, vec8, {0});
+ auto inferred_status_mismatch =
+ ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, mat, vec8, {0});
ASSERT_FALSE(inferred_status_mismatch.ok());
- inferred_status_match = ShapeInference::InferBinaryOpShape(
- BinaryOperation::BINOP_ADD, mat, vec16, {0});
+ inferred_status_match =
+ ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, mat, vec16, {0});
ASSERT_IS_OK(inferred_status_match.status());
ASSERT_TRUE(ShapeUtil::Equal(inferred_status_match.ValueOrDie(), mat));
- inferred_status_mismatch = ShapeInference::InferBinaryOpShape(
- BinaryOperation::BINOP_ADD, mat, vec16, {1});
+ inferred_status_mismatch =
+ ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, mat, vec16, {1});
ASSERT_FALSE(inferred_status_mismatch.ok());
}
@@ -1138,17 +1134,17 @@ TEST_F(ShapeInferenceTest, BinOpBroadcastCubeMatrix) {
const Shape matrix16_8 = ShapeUtil::MakeShape(F32, {16, 8});
auto inferred_status_match = ShapeInference::InferBinaryOpShape(
- BinaryOperation::BINOP_ADD, cube, matrix8_4, {1, 2});
+ HloOpcode::kAdd, cube, matrix8_4, {1, 2});
ASSERT_IS_OK(inferred_status_match.status());
ASSERT_TRUE(ShapeUtil::Equal(inferred_status_match.ValueOrDie(), cube));
inferred_status_match = ShapeInference::InferBinaryOpShape(
- BinaryOperation::BINOP_ADD, cube, matrix16_4, {0, 2});
+ HloOpcode::kAdd, cube, matrix16_4, {0, 2});
ASSERT_IS_OK(inferred_status_match.status());
ASSERT_TRUE(ShapeUtil::Equal(inferred_status_match.ValueOrDie(), cube));
inferred_status_match = ShapeInference::InferBinaryOpShape(
- BinaryOperation::BINOP_ADD, cube, matrix16_8, {0, 1});
+ HloOpcode::kAdd, cube, matrix16_8, {0, 1});
ASSERT_IS_OK(inferred_status_match.status());
ASSERT_TRUE(ShapeUtil::Equal(inferred_status_match.ValueOrDie(), cube));
}
@@ -1162,43 +1158,43 @@ TEST_F(ShapeInferenceTest, BinOpBroadcastBadDimension) {
const Shape matrix8_8 = ShapeUtil::MakeShape(F32, {8, 8});
// "magical" broadcast rejected
- auto inferred_status_error1 = ShapeInference::InferBinaryOpShape(
- BinaryOperation::BINOP_ADD, tensor, vec8, {});
+ auto inferred_status_error1 =
+ ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, tensor, vec8, {});
ASSERT_FALSE(inferred_status_error1.ok());
ASSERT_THAT(inferred_status_error1.status().error_message(),
HasSubstr("Automatic"));
// broadcast_dimension out of bounds for tensor's rank
- auto inferred_status_error2 = ShapeInference::InferBinaryOpShape(
- BinaryOperation::BINOP_ADD, tensor, vec8, {3});
+ auto inferred_status_error2 =
+ ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, tensor, vec8, {3});
ASSERT_FALSE(inferred_status_error2.ok());
ASSERT_THAT(inferred_status_error2.status().error_message(),
ContainsRegex("Broadcast dimension number .* too large"));
// broadcast_dimension doesn't match corresponding dimension
- auto inferred_status_error3 = ShapeInference::InferBinaryOpShape(
- BinaryOperation::BINOP_ADD, tensor, vec8, {0});
+ auto inferred_status_error3 =
+ ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, tensor, vec8, {0});
ASSERT_FALSE(inferred_status_error3.ok());
ASSERT_THAT(inferred_status_error3.status().error_message(),
HasSubstr("Broadcast dimension 0 mismatch"));
// broadcast_dimensions list too long
auto inferred_status_error4 = ShapeInference::InferBinaryOpShape(
- BinaryOperation::BINOP_ADD, tensor, matrix8_4, {0, 1, 2});
+ HloOpcode::kAdd, tensor, matrix8_4, {0, 1, 2});
ASSERT_FALSE(inferred_status_error4.ok());
ASSERT_THAT(inferred_status_error4.status().error_message(),
HasSubstr("broadcast_dimensions has to match"));
// there's a dimension above the rank of the tensor
auto inferred_status_error5 = ShapeInference::InferBinaryOpShape(
- BinaryOperation::BINOP_ADD, tensor, matrix8_4, {3, 0});
+ HloOpcode::kAdd, tensor, matrix8_4, {3, 0});
ASSERT_FALSE(inferred_status_error5.ok());
ASSERT_THAT(inferred_status_error5.status().error_message(),
ContainsRegex("dimension number .* too large"));
// broadcasting dimensions don't match in this order
auto inferred_status_error6 = ShapeInference::InferBinaryOpShape(
- BinaryOperation::BINOP_ADD, tensor, matrix8_4, {2, 1});
+ HloOpcode::kAdd, tensor, matrix8_4, {2, 1});
ASSERT_FALSE(inferred_status_error6.ok());
ASSERT_THAT(inferred_status_error6.status().error_message(),
HasSubstr("dimension 0 mismatch"));
@@ -1207,13 +1203,13 @@ TEST_F(ShapeInferenceTest, BinOpBroadcastBadDimension) {
// in a proper (strictly increasing) order, even if the lower-rank array
// matches the higher-rank array in many different ways.
auto inferred_status_error7 = ShapeInference::InferBinaryOpShape(
- BinaryOperation::BINOP_ADD, tensor8_8_8, matrix8_8, {0, 0});
+ HloOpcode::kAdd, tensor8_8_8, matrix8_8, {0, 0});
ASSERT_FALSE(inferred_status_error7.ok());
ASSERT_THAT(inferred_status_error7.status().error_message(),
HasSubstr("dimensions order is wrong"));
auto inferred_status_error8 = ShapeInference::InferBinaryOpShape(
- BinaryOperation::BINOP_ADD, tensor8_8_8, matrix8_8, {1, 0});
+ HloOpcode::kAdd, tensor8_8_8, matrix8_8, {1, 0});
ASSERT_FALSE(inferred_status_error8.ok());
ASSERT_THAT(inferred_status_error8.status().error_message(),
HasSubstr("dimensions order is wrong"));
diff --git a/tensorflow/compiler/xla/service/transpose_folding_test.cc b/tensorflow/compiler/xla/service/transpose_folding_test.cc
index 3139801ea3..cccb8f2fbb 100644
--- a/tensorflow/compiler/xla/service/transpose_folding_test.cc
+++ b/tensorflow/compiler/xla/service/transpose_folding_test.cc
@@ -176,7 +176,7 @@ TEST_F(TransposeFoldingTest, FuseDotWithConstantOperands) {
HloComputation* entry_computation =
module->AddEntryComputation(builder.Build(mul));
HloInstruction* call = module->OutlineExpressionFromComputation(
- {add, sub, mul}, "", entry_computation);
+ {add, sub, mul}, "entry", entry_computation);
EXPECT_EQ(call, entry_computation->root_instruction());
HloComputation* callee_computation = call->to_apply();
// The arguments to the call should be const1, const2, and const3.
diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc
index bb634e6573..eb6d1ada6b 100644
--- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc
+++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc
@@ -723,15 +723,16 @@ bool TuplePointsToAnalysis::CanShareOperandBufferWithUser(
return false;
}
if (user->opcode() == HloOpcode::kFusion) {
- if (user->fusion_kind() == HloInstruction::FusionKind::kLoop &&
- user->fused_expression_root()->opcode() ==
- HloOpcode::kDynamicUpdateSlice) {
- // Loop fusion with kDynamicUpdateSlice fused root.
- //
- // Returns true iff there is exactly one use of 'operand' at shape index
- // 'operand_index', and this singleton use is the fused root at operand
- // index 0.
- return HasUniqueFusedUseOfOperandAt(operand, operand_index, user, 0);
+ if (user->fusion_kind() == HloInstruction::FusionKind::kLoop) {
+ if (user->fused_expression_root()->opcode() ==
+ HloOpcode::kDynamicUpdateSlice) {
+ // Loop fusion with kDynamicUpdateSlice fused root.
+ //
+ // Returns true iff there is exactly one use of 'operand' at shape index
+ // 'operand_index', and this singleton use is the fused root at operand
+ // index 0.
+ return HasUniqueFusedUseOfOperandAt(operand, operand_index, user, 0);
+ }
} else if (user->fusion_kind() == HloInstruction::FusionKind::kOutput &&
user->fused_expression_root()->opcode() == HloOpcode::kAdd) {
// Output fusion with kAdd fused root.
@@ -789,8 +790,12 @@ bool TuplePointsToAnalysis::CanShareOperandBufferWithUser(
return param_uses.size() == 1 && param_uses[0].first == callee_root &&
callee_root->IsElementwiseOnOperand(param_uses[0].second);
}
- // Check if 'user' is element-wise.
- return user->IsElementwise();
+ // Loop fusions that contain transposing copies won't reach here as they have
+ // different layouts, which fails the check at the beginning of this function.
+ //
+ // Multi-output fusion will fail the check here as tuples are not considered
+ // an elementwise operation.
+ return user->IsElementwiseOnOperand(user->operand_index(operand));
}
} // namespace xla
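
The behavioural core of the tuple_points_to_analysis.cc change is that buffer sharing is now decided per operand rather than per instruction. Below is a toy standalone sketch of that distinction; ToyUser and CanShareWith are made-up names for illustration only and mirror the switch from user->IsElementwise() to user->IsElementwiseOnOperand(operand_index) that the new LoopFusionWithElementwiseOperand test exercises.

#include <cassert>
#include <vector>

// Toy stand-in for an HLO user: elementwise_on[i] records whether operand i
// is read point-for-point by the user.
struct ToyUser {
  std::vector<bool> elementwise_on;
  bool CanShareWith(int operand_index) const {
    // Only the operand that is actually consumed elementwise may reuse its
    // buffer for the output.
    return elementwise_on[operand_index];
  }
};

int main() {
  // As in the new test: param0 feeds the fused add directly (sharable),
  // param1 goes through a broadcast first (not sharable).
  ToyUser fusion{{true, false}};
  assert(fusion.CanShareWith(0));
  assert(!fusion.CanShareWith(1));
  return 0;
}
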
diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
index f558316b05..5734f28407 100644
--- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
@@ -1148,5 +1148,30 @@ TEST_F(CanShareOperandBufferWithUserTest, CallToComputationWithFusionRoot) {
call, {}));
}
+TEST_F(CanShareOperandBufferWithUserTest, LoopFusionWithElementwiseOperand) {
+ Shape full_shape = ShapeUtil::MakeShape(F32, {16, 32});
+ Shape broadcast_shape = ShapeUtil::MakeShape(F32, {16});
+
+ auto builder = HloComputation::Builder(TestName() + "_fusion");
+ auto param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, full_shape, "full"));
+ auto param1 = builder.AddInstruction(
+ HloInstruction::CreateParameter(1, broadcast_shape, "small"));
+ auto broadcast = builder.AddInstruction(
+ HloInstruction::CreateBroadcast(full_shape, param1, {0}));
+ auto add = builder.AddInstruction(HloInstruction::CreateBinary(
+ full_shape, HloOpcode::kAdd, param0, broadcast));
+
+ BuildModule(builder.Build());
+ auto fusion = computation_->CreateFusionInstruction(
+ {add, broadcast}, HloInstruction::FusionKind::kLoop);
+ RunAnalysis();
+
+ EXPECT_TRUE(points_to_analysis_->CanShareOperandBufferWithUser(param0, {},
+ fusion, {}));
+ EXPECT_FALSE(points_to_analysis_->CanShareOperandBufferWithUser(param1, {},
+ fusion, {}));
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/tuple_simplifier.cc b/tensorflow/compiler/xla/service/tuple_simplifier.cc
index d668855084..77bdcc9de0 100644
--- a/tensorflow/compiler/xla/service/tuple_simplifier.cc
+++ b/tensorflow/compiler/xla/service/tuple_simplifier.cc
@@ -30,10 +30,17 @@ limitations under the License.
namespace xla {
+TupleSimplifier::TupleSimplifier(bool exclude_entry_computation)
+ : exclude_entry_computation_(exclude_entry_computation) {}
+
StatusOr<bool> TupleSimplifier::Run(HloModule* module) {
// Initially add all GTE and Tuple instructions to the worklist.
std::queue<HloInstruction*> worklist;
for (auto* computation : module->computations()) {
+ if (exclude_entry_computation_ &&
+ computation == module->entry_computation()) {
+ continue;
+ }
for (auto* instruction : computation->instructions()) {
if (instruction->opcode() == HloOpcode::kTuple ||
instruction->opcode() == HloOpcode::kGetTupleElement) {
@@ -69,7 +76,6 @@ StatusOr<bool> TupleSimplifier::Run(HloModule* module) {
// Tuple
//
HloInstruction* top_tuple = nullptr;
- HloInstruction* first_gte = nullptr;
bool can_simplify = true;
for (int64 operand_number = 0;
operand_number < instruction->operand_count(); ++operand_number) {
@@ -79,17 +85,10 @@ StatusOr<bool> TupleSimplifier::Run(HloModule* module) {
can_simplify = false;
break;
}
- if (first_gte == nullptr) {
- first_gte = operand;
- } else if (!first_gte->has_compatible_sharding(operand)) {
- can_simplify = false;
- break;
- }
if (top_tuple == nullptr) {
top_tuple = operand->mutable_operand(0);
if (!ShapeUtil::Compatible(top_tuple->shape(),
- instruction->shape()) ||
- !instruction->has_compatible_sharding(top_tuple)) {
+ instruction->shape())) {
can_simplify = false;
break;
}
@@ -118,14 +117,12 @@ StatusOr<bool> TupleSimplifier::Run(HloModule* module) {
HloInstruction* element_source =
instruction->mutable_operand(0)->mutable_operand(
instruction->tuple_index());
- if (instruction->has_compatible_sharding(element_source)) {
- changed = true;
- TF_RETURN_IF_ERROR(instruction->ReplaceAllUsesWith(element_source));
- for (HloInstruction* user : element_source->users()) {
- if (user->opcode() == HloOpcode::kTuple ||
- user->opcode() == HloOpcode::kGetTupleElement) {
- worklist.push(user);
- }
+ changed = true;
+ TF_RETURN_IF_ERROR(instruction->ReplaceAllUsesWith(element_source));
+ for (HloInstruction* user : element_source->users()) {
+ if (user->opcode() == HloOpcode::kTuple ||
+ user->opcode() == HloOpcode::kGetTupleElement) {
+ worklist.push(user);
}
}
}
diff --git a/tensorflow/compiler/xla/service/tuple_simplifier.h b/tensorflow/compiler/xla/service/tuple_simplifier.h
index e5e9b10b5b..7509501883 100644
--- a/tensorflow/compiler/xla/service/tuple_simplifier.h
+++ b/tensorflow/compiler/xla/service/tuple_simplifier.h
@@ -27,13 +27,20 @@ namespace xla {
// the module.
class TupleSimplifier : public HloPassInterface {
public:
- TupleSimplifier() {}
+ TupleSimplifier() : TupleSimplifier(/*exclude_entry_computation=*/false) {}
+ explicit TupleSimplifier(bool exclude_entry_computation);
~TupleSimplifier() override {}
tensorflow::StringPiece name() const override { return "tuple-simplifier"; }
// Run tuple simplification on the given computation. Returns whether the
// computation was changed.
StatusOr<bool> Run(HloModule* module) override;
+
+ private:
+ // When set, this pipeline stage will perform optimization of all computations
+ // apart from the module's entry computation. This is used by Graphcore's
+ // backend.
+ bool exclude_entry_computation_;
};
} // namespace xla
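
For context, a minimal sketch of how the new constructor might be wired into a pass pipeline. This assumes the standard HloPassPipeline::AddPass interface; AddTupleSimplification is a hypothetical helper, not something introduced by this patch.

#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
#include "tensorflow/compiler/xla/service/tuple_simplifier.h"

namespace xla {

void AddTupleSimplification(HloPassPipeline* pipeline) {
  // Default behaviour: simplify tuples in every computation, including the
  // module's entry computation.
  pipeline->AddPass<TupleSimplifier>();
  // New behaviour: skip the entry computation, as Graphcore's backend does.
  pipeline->AddPass<TupleSimplifier>(/*exclude_entry_computation=*/true);
}

}  // namespace xla

In practice a pipeline would pick one of the two forms; both are shown here only to contrast the default with the new opt-out.
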
diff --git a/tensorflow/compiler/xla/service/tuple_simplifier_test.cc b/tensorflow/compiler/xla/service/tuple_simplifier_test.cc
index ca9ae91281..d3635eae81 100644
--- a/tensorflow/compiler/xla/service/tuple_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/tuple_simplifier_test.cc
@@ -42,6 +42,12 @@ class TupleSimplifierTest : public HloTestBase {
TF_ASSERT_OK(changed_status.status());
EXPECT_EQ(change_expected, changed_status.ValueOrDie());
}
+ void Run(HloModule* module, bool change_expected, bool exclude_entry) {
+ TupleSimplifier simplifier(exclude_entry);
+ auto changed_status = simplifier.Run(module);
+ TF_ASSERT_OK(changed_status.status());
+ EXPECT_EQ(change_expected, changed_status.ValueOrDie());
+ }
const Shape scalar_shape_ = ShapeUtil::MakeShape(F32, {});
const Shape tuple_shape_ = ShapeUtil::MakeTupleShape(
@@ -211,5 +217,76 @@ TEST_F(TupleSimplifierTest, IncompatibleTuples) {
EXPECT_THAT(computation->root_instruction(), tuple);
}
+TEST_F(TupleSimplifierTest, CanExcludeEntryComputation) {
+ // Verify that the entry computation can be excluded from simplification.
+ auto module = CreateNewModule();
+
+ HloInstruction* p0;
+ HloInstruction* p1;
+ HloComputation* c0;
+ HloComputation* c1;
+ HloComputation* entry;
+
+ {
+ HloComputation::Builder builder(TestName() + "_1");
+ p0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, tuple_shape_, "param"));
+ HloInstruction* gte0 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(scalar_shape_, p0, 0));
+ HloInstruction* gte1 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(scalar_shape_, p0, 1));
+ HloInstruction* gte2 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(scalar_shape_, p0, 2));
+
+ builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1, gte2}));
+
+ c0 = module->AddEmbeddedComputation(builder.Build());
+ }
+ {
+ HloComputation::Builder builder(TestName() + "_2");
+ p1 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, tuple_shape_, "param"));
+ HloInstruction* gte0 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(scalar_shape_, p1, 0));
+ HloInstruction* gte1 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(scalar_shape_, p1, 1));
+ HloInstruction* gte2 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(scalar_shape_, p1, 2));
+
+ builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1, gte2}));
+
+ c1 = module->AddEmbeddedComputation(builder.Build());
+ }
+ {
+ HloComputation::Builder builder(TestName() + "_Entry");
+ HloInstruction* tuple_param = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, tuple_shape_, "param"));
+ HloInstruction* call0 = builder.AddInstruction(
+ HloInstruction::CreateCall(tuple_shape_, {tuple_param}, c0));
+ HloInstruction* call1 = builder.AddInstruction(
+ HloInstruction::CreateCall(tuple_shape_, {tuple_param}, c1));
+ HloInstruction* gte0 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(scalar_shape_, call0, 0));
+ HloInstruction* gte1 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(scalar_shape_, call1, 1));
+ HloInstruction* tuple0 =
+ builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1}));
+ HloInstruction* gte2 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(scalar_shape_, tuple0, 0));
+ HloInstruction* gte3 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(scalar_shape_, tuple0, 1));
+
+ builder.AddInstruction(HloInstruction::CreateTuple({gte2, gte3}));
+
+ entry = module->AddEntryComputation(builder.Build());
+ }
+
+ Run(module.get(), /*change_expected=*/true, /*exclude_entry=*/true);
+
+ EXPECT_THAT(c0->root_instruction(), p0);
+ EXPECT_THAT(c1->root_instruction(), p1);
+ EXPECT_THAT(entry->instruction_count(), 9);
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/versioned_computation_handle.cc b/tensorflow/compiler/xla/service/versioned_computation_handle.cc
deleted file mode 100644
index a693c4695f..0000000000
--- a/tensorflow/compiler/xla/service/versioned_computation_handle.cc
+++ /dev/null
@@ -1,32 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/compiler/xla/service/versioned_computation_handle.h"
-
-#include "tensorflow/core/lib/strings/strcat.h"
-
-namespace xla {
-
-string VersionedComputationHandle::ToString() const {
- return tensorflow::strings::StrCat(handle.handle(), ":v", version);
-}
-
-std::ostream& operator<<(std::ostream& out,
- const VersionedComputationHandle& versioned_handle) {
- out << versioned_handle.ToString();
- return out;
-}
-
-} // namespace xla
diff --git a/tensorflow/compiler/xla/service/versioned_computation_handle.h b/tensorflow/compiler/xla/service/versioned_computation_handle.h
deleted file mode 100644
index 5732a56caf..0000000000
--- a/tensorflow/compiler/xla/service/versioned_computation_handle.h
+++ /dev/null
@@ -1,55 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_VERSIONED_COMPUTATION_HANDLE_H_
-#define TENSORFLOW_COMPILER_XLA_SERVICE_VERSIONED_COMPUTATION_HANDLE_H_
-
-#include <ostream>
-
-#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/compiler/xla/xla_data.pb.h"
-
-namespace xla {
-
-// A data structure encapsulating a ComputationHandle and version value of that
-// computation. This object is used to unambiguously refer to a particular
-// computation in the service.
-struct VersionedComputationHandle {
- // A version value unambiguously specifying the state of the computation at a
- // particular point in time as it is being built. This value is the
- // ComputationDataHandle of the current root instruction.
- using Version = int64;
-
- ComputationHandle handle;
- Version version;
-
- string ToString() const;
- bool operator==(const VersionedComputationHandle& other) const {
- return (handle.handle() == other.handle.handle()) &&
- (version == other.version);
- }
- bool operator<(const VersionedComputationHandle& other) const {
- return ((handle.handle() < other.handle.handle()) ||
- ((handle.handle() == other.handle.handle()) &&
- (version < other.version)));
- }
-};
-
-std::ostream& operator<<(std::ostream& out,
- const VersionedComputationHandle& versioned_handle);
-
-} // namespace xla
-
-#endif // TENSORFLOW_COMPILER_XLA_SERVICE_VERSIONED_COMPUTATION_HANDLE_H_
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index 7f6bbe6f87..e7e0a19db0 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -1204,6 +1204,22 @@ xla_test(
)
xla_test(
+ name = "token_hlo_test",
+ srcs = ["token_hlo_test.cc"],
+ tags = [
+ "enable_for_xla_interpreter",
+ ],
+ deps = [
+ ":client_library_test_base",
+ "//tensorflow/compiler/xla/service:hlo_verifier",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
name = "call_test",
srcs = ["call_test.cc"],
tags = [
diff --git a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc
index 34c86e007b..3a0f51fc66 100644
--- a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc
+++ b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc
@@ -671,7 +671,7 @@ XLA_TEST_F(BroadcastSimpleTest, InvalidInDimensionBroadcasting) {
auto result_status = Execute(&b, {});
EXPECT_FALSE(result_status.ok());
EXPECT_THAT(result_status.status().error_message(),
- HasSubstr("op BINOP_ADD with incompatible shapes"));
+ HasSubstr("op add with incompatible shapes"));
}
XLA_TEST_F(BroadcastSimpleTest, InvalidDegenerateBroadcasting) {
@@ -684,7 +684,7 @@ XLA_TEST_F(BroadcastSimpleTest, InvalidDegenerateBroadcasting) {
auto result_status = Execute(&b, {});
EXPECT_FALSE(result_status.ok());
EXPECT_THAT(result_status.status().error_message(),
- HasSubstr("op BINOP_ADD with incompatible shapes"));
+ HasSubstr("op add with incompatible shapes"));
}
} // namespace
diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc
index 947959beb1..346bb3a399 100644
--- a/tensorflow/compiler/xla/tests/convolution_test.cc
+++ b/tensorflow/compiler/xla/tests/convolution_test.cc
@@ -47,9 +47,9 @@ class ConvolutionTest : public ClientLibraryTestBase {
#if XLA_TEST_BACKEND_GPU
// XLA:GPU sometimes uses FFT convolution which isn't as precise as spatial
// convolution. So relax the absolute error threshold.
- ErrorSpec error_spec_ = ErrorSpec(1e-2);
+ ErrorSpec error_spec_ = ErrorSpec(1e-2, 1e-4);
#else
- ErrorSpec error_spec_ = ErrorSpec(1e-4);
+ ErrorSpec error_spec_ = ErrorSpec(1e-4, 1e-4);
#endif
};
diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc
index 08ed826c80..242cc5db11 100644
--- a/tensorflow/compiler/xla/tests/hlo_test_base.cc
+++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc
@@ -94,8 +94,7 @@ HloTestBase::HloTestBase(se::Platform* test_platform,
/* static */
std::unique_ptr<HloModule> HloTestBase::CreateNewModule(const string& name) {
- return MakeUnique<HloModule>(name, VersionedComputationHandle(),
- GetModuleConfigForTest());
+ return MakeUnique<HloModule>(name, GetModuleConfigForTest());
}
/*static*/ DebugOptions HloTestBase::GetDebugOptionsForTest() {
diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h
index eb3a2ea76a..249da87f48 100644
--- a/tensorflow/compiler/xla/tests/hlo_test_base.h
+++ b/tensorflow/compiler/xla/tests/hlo_test_base.h
@@ -66,6 +66,15 @@ namespace xla {
//
// For a more detailed example, see "../tests/sample_text_test.cc".
class HloTestBase : public ::testing::Test {
+ public:
+ // Creates a new HLO module for a test. The module created will have
+ // TestName() for its name; it will also automatically populate its debug
+ // options from command-line flags. If you want a fresh HloModule object and
+ // then add HloComputations to it, it's recommended to use this method in your
+ // tests.
+ static std::unique_ptr<HloModule> CreateNewModule(
+ const string& name = TestName());
+
protected:
// This uses the interpreter backend as the reference backend and
// automatically finds another supported backend as the test backend. If the
@@ -80,14 +89,6 @@ class HloTestBase : public ::testing::Test {
~HloTestBase() override {}
- // Creates a new HLO module for a test. The module created will have
- // TestName() for its name; it will also automatically populate its debug
- // options from command-line flags. If you want a fresh HloModule object and
- // then add HloComputations to it, it's recommended to use this method in your
- // tests.
- static std::unique_ptr<HloModule> CreateNewModule(
- const string& name = TestName());
-
// Populates debug options from command-line flags and adjusts the options for
// testing. It is recommended to use this when you need to pass in
// DebugOptions, e.g. when creating a module from a string or a file.
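
With CreateNewModule now part of HloTestBase's public interface, a test can build and verify a module with only calls that already appear elsewhere in this diff. A hedged sketch; ExampleHloTest and the constant value are illustrative names, not part of the patch.

class ExampleHloTest : public HloTestBase {};

XLA_TEST_F(ExampleHloTest, BuildsAModule) {
  // CreateNewModule() names the module after TestName() and picks up debug
  // options from command-line flags.
  std::unique_ptr<HloModule> module = CreateNewModule();
  auto builder = HloComputation::Builder(TestName());
  builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateR0<int32>(1)));
  module->AddEntryComputation(builder.Build());
  EXPECT_IS_OK(HloVerifier().Run(module.get()).status());
}
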
diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc
index c8a05c2e9e..22c664d142 100644
--- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc
+++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc
@@ -41,14 +41,17 @@ void HloVerifiedTestBase::TearDown() {
<< "TearDown called more than once; it should be called exactly once.";
tear_down_called_ = true;
if (module_) {
- VerifyModule();
+ VerifyModule(module_.get());
+ }
+ for (int i = 0; i < modules_.size(); ++i) {
+ VerifyModule(modules_.at(i).get());
}
HloTestBase::TearDown();
}
-void HloVerifiedTestBase::VerifyModule() {
- HloVerifier verifier;
- xla::StatusOr<bool> mutated = verifier.Run(module_.get());
+void HloVerifiedTestBase::VerifyModule(HloModule* module) {
+ HloVerifier verifier(/*allow_mixed_precision=*/true);
+ xla::StatusOr<bool> mutated = verifier.Run(module);
if (!mutated.ok()) {
ADD_FAILURE() << "HloVerifier failed: " << mutated.status();
} else {
@@ -59,15 +62,20 @@ void HloVerifiedTestBase::VerifyModule() {
HloModule& HloVerifiedTestBase::module() {
if (!module_) {
- module_ = CreateNewModule();
+ module_ = HloTestBase::CreateNewModule();
}
return *module_;
}
+HloModule* HloVerifiedTestBase::CreateNewModule(const string& name) {
+ modules_.emplace_back(HloTestBase::CreateNewModule());
+ return modules_.back().get();
+}
+
void HloVerifiedTestBase::ParseAndVerifyModule(
tensorflow::StringPiece hlo_text) {
CHECK(!module_) << "Called ParseModule when test already has a module.";
TF_ASSERT_OK_AND_ASSIGN(module_, ParseHloString(hlo_text));
- VerifyModule();
+ VerifyModule(module_.get());
}
} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h
index e5bb14a883..5b59cc77f6 100644
--- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h
+++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h
@@ -52,11 +52,23 @@ class HloVerifiedTestBase : public HloTestBase {
shape_verifier_ = std::move(shape_verifier);
}
+ // Creates a new module for a test, and stores it in modules_ so it can be
+ // verified. Intentionally hides HloTestBase::CreateNewModule, to prevent
+ // creation of unverified modules.
+ HloModule* CreateNewModule(const string& name = TestName());
+
+ // It is confusing to store modules created by module() and CreateNewModule()
+ // in different fields, but it allows us to migrate tests to
+ // HloVerifiedTestBase more easily and therefore to verify more modules.
+ // See b/80488902.
private:
- std::unique_ptr<HloModule> module_; // Lazily populated. Access via module().
+ // Lazily populated. Access via module().
+ std::unique_ptr<HloModule> module_;
+ // Populated by calls to CreateNewModule.
+ std::vector<std::unique_ptr<HloModule>> modules_;
std::unique_ptr<ShapeVerifier> shape_verifier_;
bool tear_down_called_ = false;
- void VerifyModule();
+ static void VerifyModule(HloModule* module);
};
} // namespace xla
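
The contrast with HloTestBase is ownership: here CreateNewModule returns a raw pointer stored in modules_, and every stored module is run through HloVerifier in TearDown. A short sketch of the intended usage, assuming a hypothetical fixture name ExampleVerifiedTest.

class ExampleVerifiedTest : public HloVerifiedTestBase {};

XLA_TEST_F(ExampleVerifiedTest, ModuleIsVerifiedOnTearDown) {
  HloModule* module = CreateNewModule();  // owned by the fixture's modules_
  auto builder = HloComputation::Builder(TestName());
  builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateR0<int32>(7)));
  module->AddEntryComputation(builder.Build());
  // No explicit check needed: TearDown() runs the verifier over every module
  // created through CreateNewModule().
}
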
diff --git a/tensorflow/compiler/xla/tests/llvm_compiler_test.cc b/tensorflow/compiler/xla/tests/llvm_compiler_test.cc
index 2f46ee0be2..082bc34136 100644
--- a/tensorflow/compiler/xla/tests/llvm_compiler_test.cc
+++ b/tensorflow/compiler/xla/tests/llvm_compiler_test.cc
@@ -124,8 +124,7 @@ class LLVMCompilerTest : public ::testing::Test {
static std::unique_ptr<HloModule> CreateNewModule() {
HloModuleConfig config;
config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags());
- return MakeUnique<HloModule>(TestName(), VersionedComputationHandle(),
- config);
+ return MakeUnique<HloModule>(TestName(), config);
}
};
diff --git a/tensorflow/compiler/xla/tests/map_test.cc b/tensorflow/compiler/xla/tests/map_test.cc
index 7df45bebeb..3975e91257 100644
--- a/tensorflow/compiler/xla/tests/map_test.cc
+++ b/tensorflow/compiler/xla/tests/map_test.cc
@@ -488,10 +488,9 @@ TEST_F(MapTest, MapOperantionWithBuildError) {
StatusOr<XlaComputation> computation_status = builder.Build();
ASSERT_TRUE(!computation_status.ok());
- EXPECT_THAT(
- computation_status.status().ToString(),
- ::testing::HasSubstr("error from: ErrorAdd: Binary op BINOP_ADD with "
- "different element types: f32[] and u16[]"));
+ EXPECT_THAT(computation_status.status().ToString(),
+ ::testing::HasSubstr("error from: ErrorAdd: Binary op add with "
+ "different element types: f32[] and u16[]"));
}
// MapTest disables inline and algsimp. MapTestWithFullOpt runs all
diff --git a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc
index 7bfc8eb546..41f723edf1 100644
--- a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc
+++ b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc
@@ -380,5 +380,139 @@ XLA_TEST_F(MultiOutputFusionTest,
Literal::CreateR1<float>({66, 138}))));
}
+XLA_TEST_F(MultiOutputFusionTest,
+ DISABLED_ON_CPU(MultiOutputReduceFusionMinorWithExtraOutput)) {
+ const string testcase = tensorflow::strings::StrCat(kScalarOps, R"(
+ fused_reduce {
+ p0 = f32[2,2,2]{2,1,0} parameter(0)
+ c0 = f32[] constant(0)
+ r1 = f32[2,2]{1,0} reduce(p0, c0), dimensions={2}, to_apply=Add
+ mul = f32[2,2,2]{2,1,0} multiply(p0, p0)
+ c1 = f32[] constant(5)
+ r2 = f32[2,2]{1,0} reduce(mul, c1), dimensions={2}, to_apply=Max
+ ROOT tuple = (f32[2,2,2]{2,1,0}, f32[2,2]{1,0}, f32[2,2]{1,0})
+ tuple(p0, r1, r2)
+ }
+
+ ENTRY reduce {
+ p = f32[2,2,2]{2,1,0} parameter(0)
+ ROOT fusion = (f32[2,2,2]{2,1,0}, f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(p),
+ kind=kInput, calls=fused_reduce
+ })");
+ auto module =
+ HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
+ .ValueOrDie();
+ auto param = Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
+ TF_ASSERT_OK_AND_ASSIGN(auto result,
+ Execute(std::move(module), {param.get()}));
+ EXPECT_TRUE(LiteralTestUtil::Equal(
+ *result,
+ *Literal::MakeTupleOwned(
+ Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}),
+ Literal::CreateR2<float>({{3, 7}, {11, 15}}),
+ Literal::CreateR2<float>({{5, 16}, {36, 64}}))));
+}
+
+XLA_TEST_F(MultiOutputFusionTest,
+ DISABLED_ON_CPU(MultiOutputReduceFusionMajorWithExtraOutput)) {
+ const string testcase = tensorflow::strings::StrCat(kScalarOps, R"(
+ fused_reduce {
+ p0 = f32[2,2,2]{2,1,0} parameter(0)
+ c0 = f32[] constant(0)
+ r1 = f32[2,2]{1,0} reduce(p0, c0), dimensions={0}, to_apply=Add
+ mul = f32[2,2,2]{2,1,0} multiply(p0, p0)
+ c1 = f32[] constant(5)
+ r2 = f32[2,2]{1,0} reduce(mul, c1), dimensions={0}, to_apply=Max
+ ROOT tuple = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}, f32[2,2]{1,0})
+ tuple(r1, mul, r2)
+ }
+
+ ENTRY reduce {
+ p = f32[2,2,2]{2,1,0} parameter(0)
+ ROOT fusion = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}, f32[2,2]{1,0}) fusion(p),
+ kind=kInput, calls=fused_reduce
+ })");
+ auto module =
+ HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
+ .ValueOrDie();
+ auto param = Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
+ TF_ASSERT_OK_AND_ASSIGN(auto result,
+ Execute(std::move(module), {param.get()}));
+ EXPECT_TRUE(LiteralTestUtil::Equal(
+ *result,
+ *Literal::MakeTupleOwned(
+ Literal::CreateR2<float>({{6, 8}, {10, 12}}),
+ Literal::CreateR3<float>({{{1, 4}, {9, 16}}, {{25, 36}, {49, 64}}}),
+ Literal::CreateR2<float>({{25, 36}, {49, 64}}))));
+}
+
+XLA_TEST_F(MultiOutputFusionTest,
+ DISABLED_ON_CPU(MultiOutputReduceFusionScalarWithExtraOutput)) {
+ const string testcase = tensorflow::strings::StrCat(kScalarOps, R"(
+ fused_reduce {
+ p0 = f32[2,2,2]{2,1,0} parameter(0)
+ c0 = f32[] constant(0)
+ r1 = f32[2]{0} reduce(p0, c0), dimensions={0,2}, to_apply=Add
+ mul = f32[2,2,2]{2,1,0} multiply(p0, p0)
+ c1 = f32[] constant(5)
+ mul2 = f32[2,2,2]{2,1,0} multiply(p0, c1)
+ ROOT tuple = (f32[2]{0}, f32[2,2,2]{2,1,0}, f32[2,2,2]{2,1,0})
+ tuple(r1, mul, mul2)
+ }
+
+ ENTRY reduce {
+ p = f32[2,2,2]{2,1,0} parameter(0)
+ ROOT fusion = (f32[2]{0}, f32[2,2,2]{2,1,0}, f32[2,2,2]{2,1,0}) fusion(p),
+ kind=kInput, calls=fused_reduce
+ })");
+ auto module =
+ HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
+ .ValueOrDie();
+ auto param = Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
+ TF_ASSERT_OK_AND_ASSIGN(auto result,
+ Execute(std::move(module), {param.get()}));
+ EXPECT_TRUE(LiteralTestUtil::Equal(
+ *result,
+ *Literal::MakeTupleOwned(
+ Literal::CreateR1<float>({14, 22}),
+ Literal::CreateR3<float>({{{1, 4}, {9, 16}}, {{25, 36}, {49, 64}}}),
+ Literal::CreateR3<float>(
+ {{{5, 10}, {15, 20}}, {{25, 30}, {35, 40}}}))));
+}
+
+XLA_TEST_F(MultiOutputFusionTest,
+ DISABLED_ON_CPU(MultiOutputReduceFusionNonConstInit)) {
+ const string testcase = tensorflow::strings::StrCat(kScalarOps, R"(
+ fused_reduce {
+ p0 = f32[2,2,2]{2,1,0} parameter(0)
+ init1 = f32[] parameter(1)
+ init2 = f32[] parameter(2)
+ r1 = f32[2,2]{1,0} reduce(p0, init1), dimensions={2}, to_apply=Add
+ r2 = f32[2,2]{1,0} reduce(p0, init2), dimensions={2}, to_apply=Max
+ ROOT tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(r1, r2)
+ }
+
+ ENTRY reduce {
+ p = f32[2,2,2]{2,1,0} parameter(0)
+ i = f32[] parameter(1)
+ j = f32[] parameter(2)
+ ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(p, i, j), kind=kInput,
+ calls=fused_reduce
+ })");
+ auto module =
+ HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
+ .ValueOrDie();
+ auto param = Literal::CreateR3<float>({{{0, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
+ auto init1 = Literal::CreateR0<float>(5);
+ auto init2 = Literal::CreateR0<float>(6);
+ TF_ASSERT_OK_AND_ASSIGN(
+ auto result,
+ Execute(std::move(module), {param.get(), init1.get(), init2.get()}));
+ EXPECT_TRUE(LiteralTestUtil::Equal(
+ *result, *Literal::MakeTupleOwned(
+ Literal::CreateR2<float>({{167, 172}, {176, 180}}),
+ Literal::CreateR2<float>({{6, 6}, {6, 8}}))));
+}
+
} // namespace
} // namespace xla
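
As a sanity check on the literals expected by MultiOutputReduceFusionMinorWithExtraOutput above, the standalone snippet below recomputes r1 (sum over the minor dimension, init 0) and r2 (max of p0*p0 over the same dimension, init 5) for the test input; it is a plain C++ reimplementation of the HLO text, not XLA code.

#include <algorithm>
#include <cstdio>

int main() {
  const float p0[2][2][2] = {{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}};
  for (int i = 0; i < 2; ++i) {
    for (int j = 0; j < 2; ++j) {
      float r1 = 0, r2 = 5;  // init values c0=0 and c1=5 from the HLO text
      for (int k = 0; k < 2; ++k) {
        r1 += p0[i][j][k];
        r2 = std::max(r2, p0[i][j][k] * p0[i][j][k]);
      }
      std::printf("r1[%d][%d]=%g r2[%d][%d]=%g\n", i, j, r1, i, j, r2);
    }
  }
  // Prints 3, 7, 11, 15 for r1 and 5, 16, 36, 64 for r2, matching the
  // expected literals {{3, 7}, {11, 15}} and {{5, 16}, {36, 64}}.
  return 0;
}
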
diff --git a/tensorflow/compiler/xla/tests/token_hlo_test.cc b/tensorflow/compiler/xla/tests/token_hlo_test.cc
new file mode 100644
index 0000000000..4585244ce8
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/token_hlo_test.cc
@@ -0,0 +1,124 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <array>
+
+#include "tensorflow/compiler/xla/service/hlo_verifier.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+class TokenHloTest : public HloTestBase {};
+
+// TODO(b/79770375): Compile, not just verify, the HLO module once the backends
+// support kGenerateToken.
+XLA_TEST_F(TokenHloTest, SingleTokenInstruction) {
+ std::unique_ptr<HloModule> module = CreateNewModule();
+ auto builder = HloComputation::Builder(TestName());
+ builder.AddInstruction(HloInstruction::CreateGenerateToken({}));
+ builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<int32>(42)));
+
+ module->AddEntryComputation(builder.Build());
+ EXPECT_IS_OK(HloVerifier().Run(module.get()).status());
+}
+
+XLA_TEST_F(TokenHloTest, TokenTree) {
+ std::unique_ptr<HloModule> module = CreateNewModule();
+ auto builder = HloComputation::Builder(TestName());
+ auto token0 = builder.AddInstruction(HloInstruction::CreateGenerateToken({}));
+ auto token1 = builder.AddInstruction(HloInstruction::CreateGenerateToken({}));
+ auto token2 = builder.AddInstruction(HloInstruction::CreateGenerateToken({}));
+ builder.AddInstruction(
+ HloInstruction::CreateGenerateToken({token0, token0, token1, token2}));
+ builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<int32>(42)));
+
+ module->AddEntryComputation(builder.Build());
+ EXPECT_IS_OK(HloVerifier().Run(module.get()).status());
+}
+
+XLA_TEST_F(TokenHloTest, InvalidTokenShapedEntryParameter) {
+ std::unique_ptr<HloModule> module = CreateNewModule();
+ auto builder = HloComputation::Builder(TestName());
+ builder.AddInstruction(
+ HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "p0"));
+ builder.AddInstruction(
+ HloInstruction::CreateParameter(1, ShapeUtil::MakeTokenShape(), "p1"));
+ builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<int32>(42)));
+ module->AddEntryComputation(builder.Build());
+
+ Status status = HloVerifier().Run(module.get()).status();
+ ASSERT_IS_NOT_OK(status);
+ EXPECT_THAT(
+ status.error_message(),
+ ::testing::HasSubstr("Entry parameter 1 is or contains a token shape"));
+}
+
+XLA_TEST_F(TokenHloTest, InvalidTupleTokenShapedEntryParameter) {
+ std::unique_ptr<HloModule> module = CreateNewModule();
+ auto builder = HloComputation::Builder(TestName());
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 0,
+ ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeShape(F32, {1, 2, 3}), ShapeUtil::MakeTokenShape()}),
+ "param"));
+ module->AddEntryComputation(builder.Build());
+
+ Status status = HloVerifier().Run(module.get()).status();
+ ASSERT_IS_NOT_OK(status);
+ EXPECT_THAT(
+ status.error_message(),
+ ::testing::HasSubstr("Entry parameter 0 is or contains a token shape"));
+}
+
+XLA_TEST_F(TokenHloTest, InvalidTokenRoot) {
+ std::unique_ptr<HloModule> module = CreateNewModule();
+ auto builder = HloComputation::Builder(TestName());
+ builder.AddInstruction(HloInstruction::CreateGenerateToken({}));
+ module->AddEntryComputation(builder.Build());
+
+ Status status = HloVerifier().Run(module.get()).status();
+ ASSERT_IS_NOT_OK(status);
+ EXPECT_THAT(status.error_message(),
+ ::testing::HasSubstr("Entry root is or contains a token shape"));
+}
+
+XLA_TEST_F(TokenHloTest, InvalidOperandToTokenInstruction) {
+ std::unique_ptr<HloModule> module = CreateNewModule();
+ auto builder = HloComputation::Builder(TestName());
+ auto param = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "p0"));
+ builder.AddInstruction(HloInstruction::CreateGenerateToken({param}));
+ builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<int32>(123)));
+ module->AddEntryComputation(builder.Build());
+
+ Status status = HloVerifier().Run(module.get()).status();
+ ASSERT_IS_NOT_OK(status);
+ EXPECT_THAT(status.error_message(),
+ ::testing::HasSubstr(
+ "Operands of token instructions must be TOKEN types"));
+}
+
+} // namespace
+} // namespace xla
diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD
index ff5340ee3f..e4a052c8f1 100644
--- a/tensorflow/compiler/xla/tools/BUILD
+++ b/tensorflow/compiler/xla/tools/BUILD
@@ -85,6 +85,7 @@ cc_library(
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/lib:testing",
+ "//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/service:hlo_proto",
"//tensorflow/compiler/xla/service/gpu:infeed_manager",
"//tensorflow/compiler/xla/tests:test_utils",
diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc
index be094b7890..f7574e0b1c 100644
--- a/tensorflow/compiler/xla/tools/replay_computation.cc
+++ b/tensorflow/compiler/xla/tools/replay_computation.cc
@@ -24,6 +24,9 @@ limitations under the License.
// passing --use_fake_data on the command line. If the real data is available
// in the proto and --use_fake_data is false, the real data is used.
//
+// Input can be a binary HloSnapshot proto, a binary HloProto proto, or a
+// textual HLO string.
+//
// The output format is:
//
// file_path: computation_name :: type:literal_str
@@ -43,6 +46,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/gpu/infeed_manager.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
@@ -195,25 +199,45 @@ StatusOr<Literal> ReplayComputation(const HloSnapshot& module,
return std::move(*result_literal);
}
+StatusOr<HloSnapshot> ParseInputFile(const string& filename,
+ const Options& opts) {
+ tensorflow::Env* env = tensorflow::Env::Default();
+ HloSnapshot snapshot;
+ if (tensorflow::ReadBinaryProto(env, filename, &snapshot).ok()) {
+ return snapshot;
+ }
+ CHECK(opts.use_fake_data)
+ << "Without --use_fake_data, you must pass an HloSnapshot -- HloProto "
+ "and textual HLO don't carry real data.";
+ fprintf(stderr, "%s: is not HloSnapshot. Trying HloProto.\n",
+ filename.c_str());
+
+ if (tensorflow::ReadBinaryProto(env, filename, snapshot.mutable_hlo()).ok()) {
+ return snapshot;
+ }
+ fprintf(stderr, "%s: is not HloProto. Trying HLO text.\n", filename.c_str());
+ string contents;
+ TF_RETURN_IF_ERROR(tensorflow::ReadFileToString(env, filename, &contents));
+ StatusOr<std::unique_ptr<HloModule>> module = ParseHloString(contents);
+ if (module.ok()) {
+ *snapshot.mutable_hlo()->mutable_hlo_module() =
+ module.ValueOrDie()->ToProto();
+ return snapshot;
+ }
+ fprintf(stderr, "%s: is not HLO text. Nothing left to try.\n",
+ filename.c_str());
+ return InvalidArgument("Could not parse %s.", filename.c_str());
+}
+
int RealMain(tensorflow::gtl::ArraySlice<char*> args, const Options& opts) {
LocalClient* client = ClientLibrary::LocalClientOrDie();
- tensorflow::Env* env = tensorflow::Env::Default();
int exit_status = EXIT_SUCCESS;
for (char* arg : args) {
- HloSnapshot snapshot;
- auto status = tensorflow::ReadBinaryProto(env, arg, &snapshot);
- if (!status.ok()) {
- fprintf(stderr, "%s: is not HloSnapshot. Trying HloProto.\n", arg);
- status = tensorflow::ReadBinaryProto(env, arg, snapshot.mutable_hlo());
- if (!status.ok()) {
- fprintf(stderr, "%s: is not HloSnapshot or HloProto: %s.\n", arg,
- status.ToString().c_str());
- continue;
- }
- CHECK(opts.use_fake_data)
- << "HloProto input must be handled with --use_fake_data";
+ StatusOr<HloSnapshot> maybe_snapshot = ParseInputFile(arg, opts);
+ if (!maybe_snapshot.ok()) {
+ continue;
}
-
+ HloSnapshot snapshot = std::move(maybe_snapshot).ValueOrDie();
StatusOr<Literal> result_status = ReplayComputation(snapshot, client, opts);
if (!result_status.ok()) {
fprintf(stderr, "%s: error: %s\n", arg,
diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto
index 53ba120d21..6f07e4606b 100644
--- a/tensorflow/compiler/xla/xla.proto
+++ b/tensorflow/compiler/xla/xla.proto
@@ -225,14 +225,6 @@ message ExecutionOptions {
repeated DeviceHandle device_handles = 5;
}
-message SnapshotComputationRequest {
- ComputationHandle computation = 1;
-}
-
-message LoadComputationSnapshotResponse {
- ComputationHandle computation = 1;
-}
-
message GetDeviceHandlesRequest {
int64 device_count = 1;
}
@@ -291,11 +283,6 @@ message ResetDeviceRequest {
message ResetDeviceResponse {
}
-message ComputationStatsRequest {
- ComputationHandle computation = 1;
- DebugOptions debug_options = 2;
-}
-
message ComputationGraphStatsRequest {
HloModuleProto computation = 1;
DebugOptions debug_options = 2;
@@ -305,14 +292,6 @@ message ComputationStatsResponse {
ComputationStats stats = 1;
}
-message ComputationRequest {
- string name = 1;
-}
-
-message ComputationResponse {
- ComputationHandle computation = 1;
-}
-
message CreateChannelHandleRequest {
}
@@ -327,24 +306,6 @@ message UnregisterRequest {
message UnregisterResponse {
}
-message SetReturnValueRequest {
- ComputationHandle computation = 1;
- ComputationDataHandle operand = 2;
-}
-
-message SetReturnValueResponse {
-}
-
-message ExecuteRequest {
- reserved 3, 4;
-
- ComputationHandle computation = 1;
- repeated GlobalDataHandle arguments = 2;
-
- // Options that affect how XLA compiles and runs code to service this request.
- ExecutionOptions execution_options = 5;
-}
-
message ExecuteGraphRequest {
HloModuleProto computation = 1;
repeated GlobalDataHandle arguments = 2;
@@ -353,10 +314,6 @@ message ExecuteGraphRequest {
ExecutionOptions execution_options = 3;
}
-message ExecuteParallelRequest {
- repeated ExecuteRequest requests = 1;
-}
-
message ExecuteGraphParallelRequest {
repeated ExecuteGraphRequest requests = 1;
}
@@ -370,21 +327,6 @@ message ExecuteParallelResponse {
repeated ExecuteResponse responses = 1;
}
-message ExecuteAsyncRequest {
- reserved 3, 4;
-
- ComputationHandle computation = 1;
- repeated GlobalDataHandle arguments = 2;
-
- // Options that affect how XLA compiles and runs code to service this request.
- ExecutionOptions execution_options = 6;
-}
-
-message ExecuteAsyncResponse {
- // A handle to the execution launched asynchronously.
- ExecutionHandle execution = 1;
-}
-
message WaitForExecutionRequest {
ExecutionHandle execution = 1;
}
@@ -394,31 +336,13 @@ message WaitForExecutionResponse {
ExecutionProfile profile = 2;
}
-message IsConstantRequest {
- ComputationHandle computation = 1;
- ComputationDataHandle operand = 2;
- int64 num_parameters = 3;
-}
-
-message IsConstantResponse {
- bool is_constant = 1;
-}
-
-message ComputeConstantRequest {
- ComputationHandle computation = 1;
- ComputationDataHandle operand = 2;
- Layout output_layout = 3;
- repeated LiteralProto parameters = 4;
-}
-
message ComputeConstantGraphRequest {
HloModuleProto computation = 1;
Layout output_layout = 2;
}
message ComputeConstantResponse {
- // A LiteralProto is returned directly for this request, instead of a
- // ComputationDataHandle.
+ // A LiteralProto is returned directly for this request.
LiteralProto literal = 1;
}
@@ -460,14 +384,6 @@ message LoadDataResponse {
int64 nanoseconds = 5;
}
-message SpecializeRequest {
- ComputationHandle computation = 1;
- repeated GlobalDataHandle arguments = 2;
-}
-
-message SpecializeResponse {
-}
-
message GetShapeRequest {
GlobalDataHandle data = 1;
}
@@ -476,14 +392,6 @@ message GetShapeResponse {
Shape shape = 1;
}
-message GetComputationShapeRequest {
- ComputationHandle computation = 1;
-}
-
-message GetComputationShapeResponse {
- ProgramShape program_shape = 1;
-}
-
message UnpackRequest {
GlobalDataHandle data = 1;
}
diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto
index 6bdfb0179c..0af73e8a93 100644
--- a/tensorflow/compiler/xla/xla_data.proto
+++ b/tensorflow/compiler/xla/xla_data.proto
@@ -276,12 +276,6 @@ message ExecutionProfile {
int64 compute_and_transfer_time_ns = 5;
}
-// Handle given to a user that represents a computation that the user builds up
-// before execution.
-message ComputationHandle {
- int64 handle = 1;
-}
-
// Handle given to a user that represents an execution that the user launched
// asynchronously on the device.
message ExecutionHandle {
@@ -295,13 +289,6 @@ message GlobalDataHandle {
int64 handle = 1;
}
-// Handle given to a user that represents a data result in a computation.
-// This is used to pass to subsequent computations that depends upon the data as
-// an operand.
-message ComputationDataHandle {
- int64 handle = 1;
-}
-
// Handle given to a user that represents a replicated virtual device. Each
// replicated device represents N physical devices for execution where N is the
// number of replicas.
@@ -441,44 +428,6 @@ message GatherDimensionNumbers {
int64 index_vector_dim = 4;
}
-// Operation requests that are all collected as a tagged union with a oneof
-// field in OpRequest.
-
-message ConstantRequest {
- LiteralProto literal = 2;
-}
-
-message GetTupleElementRequest {
- ComputationDataHandle operand = 2;
- int64 index = 3;
-}
-
-message SliceRequest {
- ComputationDataHandle operand = 2;
- repeated int64 start_indices = 3;
- repeated int64 limit_indices = 4;
- repeated int64 strides = 5;
-}
-
-message DynamicSliceRequest {
- // Operand from which to slice at dynamic 'start_indices'.
- ComputationDataHandle operand = 2;
- // Dynamically computed 'start_indices' for slice operation.
- ComputationDataHandle start_indices = 3;
- // Slice sizes for each dimension (note that indices calculations are computed
- // modulo dimension sizes to avoid out-of-bound array accesses).
- repeated int64 slice_sizes = 4;
-}
-
-message DynamicUpdateSliceRequest {
- // Operand on which slice 'update' is to be applied.
- ComputationDataHandle operand = 2;
- // The slice update to apply to 'operand'.
- ComputationDataHandle update = 3;
- // Dynamically computed start indices for the update slice operation.
- ComputationDataHandle start_indices = 4;
-}
-
message ConvolutionDimensionNumbers {
// The number of the dimension that represents batch in the input.
int64 input_batch_dimension = 7;
@@ -516,13 +465,6 @@ message ConvolutionDimensionNumbers {
// Next = 13
};
-message ConvolveRequest {
- ComputationDataHandle lhs = 2;
- ComputationDataHandle rhs = 3; // This is the filter/kernel.
- Window window = 4; // Describes the filter/kernel.
- ConvolutionDimensionNumbers dimension_numbers = 5;
-}
-
enum FftType {
FFT = 0; // Forward FFT; complex in, complex out.
IFFT = 1; // Inverse FFT; complex in, complex out.
@@ -531,56 +473,6 @@ enum FftType {
// fft_length real out
}
-message FftRequest {
- FftType fft_type = 1;
- repeated int64 fft_length = 2; // Multivalent for higher-order FFT.
- ComputationDataHandle operand = 3;
-}
-
-message InfeedRequest {
- // The shape of the data returned by reading the device's infeed buffer.
- Shape shape = 2;
-
- // Additional infeed configuration for the backend.
- bytes config = 3;
-}
-
-message OutfeedRequest {
- // The shape of the data returned by reading the device's outfeed buffer.
- Shape shape = 1;
-
- // Operand to the Outfeed. Supports tuple.
- ComputationDataHandle operand = 2;
-
- // Backend-specific information for how to perform the outfeed.
- bytes outfeed_config = 3;
-}
-
-message CallRequest {
- ComputationHandle to_apply = 2;
- repeated ComputationDataHandle operands = 3;
-}
-
-message CustomCallRequest {
- string call_target_name = 2;
- repeated ComputationDataHandle operands = 3;
- Shape shape = 4;
-}
-
-message HostComputeRequest {
- // Operand to the HostCompute. Supports tuple.
- repeated ComputationDataHandle operands = 1;
-
- // Name used to identify HostSend/Recv channels.
- string channel_name = 2;
-
- // Cost estimate in nanoseconds.
- int64 cost_estimate_ns = 3;
-
- // The shape of any data returned by host.
- Shape shape = 4;
-}
-
message DotDimensionNumbers {
// The dimension numbers that represent the 'lhs' contracting dimensions.
repeated int64 lhs_contracting_dimensions = 1;
@@ -592,297 +484,6 @@ message DotDimensionNumbers {
repeated int64 rhs_batch_dimensions = 4;
};
-message DotRequest {
- ComputationDataHandle lhs = 2;
- ComputationDataHandle rhs = 3;
- DotDimensionNumbers dimension_numbers = 4;
-}
-
-message MapRequest {
- repeated ComputationDataHandle operands = 2;
- ComputationHandle to_apply = 3;
- repeated ComputationDataHandle static_operands = 4;
- // The dimensions over which to map.
- // Example mapping a Dot operation along the batch dimension 0:
- // operand0.shape = [2, 2, 2], operand1.shape = [2,2,3]
- // Map({operand0, operand1}, Dot, {0})
- repeated int64 dimensions = 5;
-}
-
-message ReduceRequest {
- // Operand to the reduction.
- ComputationDataHandle operand = 2;
-
- // Initial value for the reduction. This must be consistent with the result
- // shape of to_apply.
- ComputationDataHandle init_value = 3;
-
- // The dimensions to reduce over.
- repeated int64 dimensions = 4;
-
- // The computation to apply in the reduction.
- ComputationHandle to_apply = 5;
-}
-
-message ReduceWindowRequest {
- ComputationDataHandle operand = 2;
- ComputationDataHandle init_value = 3;
- Window window = 4;
- ComputationHandle to_apply = 5;
-}
-
-message BatchNormTrainingRequest {
- ComputationDataHandle operand = 1;
- ComputationDataHandle scale = 2;
- ComputationDataHandle offset = 3;
- float epsilon = 4;
- int64 feature_index = 5;
-}
-
-message BatchNormInferenceRequest {
- ComputationDataHandle operand = 1;
- ComputationDataHandle scale = 2;
- ComputationDataHandle offset = 3;
- ComputationDataHandle mean = 4;
- ComputationDataHandle variance = 5;
- float epsilon = 6;
- int64 feature_index = 7;
-}
-
-message BatchNormGradRequest {
- ComputationDataHandle operand = 1;
- ComputationDataHandle scale = 2;
- ComputationDataHandle mean = 3;
- ComputationDataHandle variance = 4;
- ComputationDataHandle grad_output = 5;
- float epsilon = 6;
- int64 feature_index = 7;
-}
-
-message CrossReplicaSumRequest {
- ComputationDataHandle operand = 2;
-}
-
-message SelectAndScatterRequest {
- // Operand array on which the windows slide.
- ComputationDataHandle operand = 2;
-
- // Source array for the data to scatter.
- ComputationDataHandle source = 3;
-
- // Initial scalar value for each element in the output.
- ComputationDataHandle init_value = 4;
-
- // Window configuration.
- Window window = 5;
-
- // Binary function used to select an element from each window.
- ComputationHandle select = 6;
-
- // Binary function used to combine each scattered value from source with the
- // current output value at the selected location.
- ComputationHandle scatter = 7;
-}
-
-message ReverseRequest {
- ComputationDataHandle operand = 2;
- repeated int64 dimensions = 3;
-}
-
-message BroadcastRequest {
- ComputationDataHandle operand = 2;
- repeated int64 broadcast_sizes = 3;
-}
-
-message PadRequest {
- ComputationDataHandle operand = 2;
- ComputationDataHandle padding_value = 3;
- PaddingConfig padding_config = 4;
-}
-
-message ReshapeRequest {
- ComputationDataHandle operand = 2;
-
- // The dimension order for collapse (from fastest-changing to slowest).
- repeated int64 dimensions = 3;
-
- // The new dimension sizes (from dimension 0 to n-1).
- repeated int64 new_sizes = 4;
-}
-
-message TransposeRequest {
- ComputationDataHandle operand = 2;
-
- // The permutation of the operand's dimensions (in the range 0 to n-1).
- repeated int64 dimensions = 3;
-}
-
-message ParameterRequest {
- Shape shape = 2;
- int64 parameter = 3;
- string name = 4;
-}
-
-message GetLocalShapeRequest {
- ComputationHandle computation = 1;
- ComputationDataHandle operand = 2;
-}
-
-message GetLocalShapeResponse {
- Shape shape = 1;
-}
-
-message TraceRequest {
- string tag = 2;
- ComputationDataHandle operand = 3;
-}
-
-message ConvertRequest {
- ComputationDataHandle operand = 2;
- PrimitiveType new_element_type = 3;
-}
-
-message ConcatenateRequest {
- repeated ComputationDataHandle operands = 2;
- // The dimension in which we concatenate; e.g. if you had dimension arrays of
- // [4, 1] and [5, 1], you'd concatenate in dimension 0 to produce a [9, 1].
- // Attempting to concatenate those in dimension 1 would produce an error, as
- // 4 != 5 (and there is no ragged array support).
- int64 dimension = 3;
-}
-
-message ConditionalRequest {
- ComputationDataHandle predicate = 2;
- ComputationDataHandle true_operand = 3;
- ComputationHandle true_computation = 4;
- ComputationDataHandle false_operand = 5;
- ComputationHandle false_computation = 6;
-}
-
-message WhileRequest {
- ComputationHandle condition = 2;
- ComputationHandle body = 3;
- ComputationDataHandle init = 4;
-}
-
-enum UnaryOperation {
- UNOP_INVALID = 0;
-
- // Elementwise, logical negation on booleans and bitwise negation on ints.
- UNOP_NOT = 1;
-
- // Elementwise, computes e^x.
- UNOP_EXP = 2;
-
- // Elementwise, computes -x.
- UNOP_NEGATE = 3;
-
- // Puts the elements in the operand into sorted order.
- UNOP_SORT = 4;
-
- // Elementwise, computes tanh(x).
- UNOP_TANH = 5;
-
- // Elementwise, computes the natural logarithm of x.
- UNOP_LOG = 6;
-
- // Elementwise, computes the floor of x.
- UNOP_FLOOR = 7;
-
- // Elementwise, computes the ceil of x.
- UNOP_CEIL = 8;
-
- // Elementwise, computes the abs of x.
- UNOP_ABS = 9;
-
- // Elementwise, computes the sign of x.
- UNOP_SIGN = 10;
-
- // Elementwise, tests if values are finite (not NaN or inf)
- UNOP_IS_FINITE = 11;
-
- // Elementwise, computes the cosine of x.
- UNOP_COS = 12;
-
- // Elementwise, computes the sine of x.
- UNOP_SIN = 13;
-
- // Elementwise, rounds x to nearest integral value, rounding half-way cases
- // away from zero.
- UNOP_ROUND_NEAREST_AFZ = 14;
-
- // Elementwise, extract real component of complex x.
- UNOP_REAL = 15;
-
- // Elementwise, extract real component of complex x.
- UNOP_IMAG = 16;
-
- // Elementwise, computes clz(x).
- UNOP_CLZ = 17;
-
- // Elementwise, computes exp(x)-1.
- UNOP_EXPM1 = 18;
-
- // Elementwise, computes log(x+1).
- UNOP_LOG1P = 19;
-}
-
-message UnaryOpRequest {
- UnaryOperation unop = 2;
- ComputationDataHandle operand = 3;
-}
-
-enum BinaryOperation {
- BINOP_INVALID = 0;
-
- // Arithmetic operations.
- BINOP_ADD = 1;
- BINOP_DIV = 2;
- BINOP_MUL = 3;
- BINOP_SUB = 4;
-
- // Comparison operators.
- BINOP_EQ = 5;
- BINOP_GE = 6;
- BINOP_GT = 7;
- BINOP_LE = 8;
- BINOP_LT = 9;
- BINOP_NE = 10;
-
- // Element-wise maximum.
- BINOP_MAX = 14;
-
- // Element-wise minimum.
- BINOP_MIN = 15;
-
- // Raises the left-hand-side to the right-hand-side power.
- BINOP_POW = 16;
-
- // Remainder operation.
- BINOP_REM = 17;
-
- // Element-wise, logical operators on booleans and bitwise operators on ints.
- BINOP_AND = 18;
- BINOP_OR = 19;
-
- BINOP_SHIFT_LEFT = 20;
- BINOP_SHIFT_RIGHT_ARITHMETIC = 21;
- BINOP_SHIFT_RIGHT_LOGICAL = 22;
-
- // Complex from real, imag.
- BINOP_COMPLEX = 23;
-
- // Computes the 4-quadrant arctangent of the y, x input arguments.
- BINOP_ATAN2 = 24;
-}
-
-message BinaryOpRequest {
- BinaryOperation binop = 2;
- ComputationDataHandle lhs = 3;
- ComputationDataHandle rhs = 4;
- repeated int64 broadcast_dimensions = 5;
-}
-
enum RandomDistribution {
RNG_INVALID = 0;
@@ -897,67 +498,6 @@ enum RandomDistribution {
// Next: 4
}
-message RngRequest {
- RandomDistribution distribution = 2;
- repeated ComputationDataHandle parameter = 3;
- Shape shape = 4;
-}
-
-enum TernaryOperation {
- TRIOP_INVALID = 0;
-
- // Given a predicate and two operands, selects operand0 if the predicate is
- // true and operand1 if the predicate is false.
- TRIOP_SELECT = 1;
-
- // Given a min, max and an operand returns the operand if between min and max,
- // else returns min if operand is less than min or max if operand is greater
- // than max.
- TRIOP_CLAMP = 3;
-}
-
-message TernaryOpRequest {
- TernaryOperation triop = 2;
- ComputationDataHandle lhs = 3;
- ComputationDataHandle rhs = 4;
- ComputationDataHandle ehs = 5;
-}
-
-enum VariadicOperation {
- VAROP_INVALID = 0;
-
- // Creates a tuple from its operands.
- VAROP_TUPLE = 1;
-}
-
-message VariadicOpRequest {
- VariadicOperation varop = 2;
- repeated ComputationDataHandle operands = 3;
-}
-
-message ReducePrecisionRequest {
- ComputationDataHandle operand = 1;
- int32 exponent_bits = 2;
- int32 mantissa_bits = 3;
-}
-
-message SendRequest {
- ComputationDataHandle operand = 1;
- ChannelHandle channel_handle = 2;
-}
-
-message RecvRequest {
- Shape shape = 1;
- ChannelHandle channel_handle = 2;
-}
-
-message GatherRequest {
- ComputationDataHandle input = 1;
- ComputationDataHandle gather_indices = 2;
- GatherDimensionNumbers dimension_numbers = 3;
- repeated int64 window_bounds = 4;
-}
-
message OpSharding {
enum Type {
// This sharding is replicated across all devices (implies maximal,
@@ -988,59 +528,3 @@ message OpSharding {
// to.
repeated OpSharding tuple_shardings = 5;
}
-
-message OpRequest {
- ComputationHandle computation = 1;
- OpMetadata metadata = 33;
- OpSharding sharding = 40;
-
- oneof op {
- BinaryOpRequest binary_op_request = 2;
- BroadcastRequest broadcast_request = 3;
- CallRequest call_request = 4;
- ConcatenateRequest concatenate_request = 5;
- ConstantRequest constant_request = 6;
- ConvertRequest convert_request = 7;
- ConvolveRequest convolve_request = 8;
- CrossReplicaSumRequest cross_replica_sum_request = 9;
- CustomCallRequest custom_call_request = 10;
- DotRequest dot_request = 43;
- DynamicSliceRequest dynamic_slice_request = 11;
- DynamicUpdateSliceRequest dynamic_update_slice_request = 12;
- GetTupleElementRequest get_tuple_element_request = 13;
- InfeedRequest infeed_request = 14;
- MapRequest map_request = 15;
- PadRequest pad_request = 16;
- ParameterRequest parameter_request = 17;
- ReducePrecisionRequest reduce_precision_request = 36;
- ReduceRequest reduce_request = 18;
- ReduceWindowRequest reduce_window_request = 19;
- ReshapeRequest reshape_request = 20;
- ReverseRequest reverse_request = 21;
- RngRequest rng_request = 22;
- SelectAndScatterRequest select_and_scatter_request = 23;
- SliceRequest slice_request = 24;
- TernaryOpRequest ternary_op_request = 25;
- TraceRequest trace_request = 26;
- TransposeRequest transpose_request = 34;
- UnaryOpRequest unary_op_request = 27;
- VariadicOpRequest variadic_op_request = 28;
- WhileRequest while_request = 29;
- SendRequest send_request = 30;
- RecvRequest recv_request = 31;
- OutfeedRequest outfeed_request = 32;
- BatchNormTrainingRequest batch_norm_training_request = 35;
- BatchNormGradRequest batch_norm_grad_request = 37;
- BatchNormInferenceRequest batch_norm_inference_request = 38;
- FftRequest fft_request = 41;
- ConvertRequest bitcast_convert_request = 42;
- ConditionalRequest conditional_request = 44;
- HostComputeRequest host_compute_request = 45;
- GatherRequest gather_request = 46;
- // Next: 47
- }
-}
-
-message OpResponse {
- ComputationDataHandle output = 1;
-}
diff --git a/tensorflow/contrib/autograph/__init__.py b/tensorflow/contrib/autograph/__init__.py
index 79d73af980..dbdbad8f4c 100644
--- a/tensorflow/contrib/autograph/__init__.py
+++ b/tensorflow/contrib/autograph/__init__.py
@@ -30,6 +30,8 @@ from tensorflow.contrib.autograph.impl.api import do_not_convert
from tensorflow.contrib.autograph.impl.api import RunMode
from tensorflow.contrib.autograph.impl.api import to_code
from tensorflow.contrib.autograph.impl.api import to_graph
+from tensorflow.contrib.autograph.impl.directives import set_element_type
+from tensorflow.contrib.autograph.impl.directives import set_loop_options
from tensorflow.contrib.autograph.impl.special_functions import stack
from tensorflow.contrib.autograph.pyct.transformer import AutographParseError
from tensorflow.python.util.all_util import remove_undocumented
@@ -42,8 +44,11 @@ _allowed_symbols = [
'do_not_convert',
'to_code',
'to_graph',
- # Special functions and overloaded operators
+ # Overloaded operators
'operators',
+ # Special functions and directives
+ 'set_element_type',
+ 'set_loop_options',
'stack',
# Exceptions
'AutographParseError',
diff --git a/tensorflow/contrib/autograph/converters/BUILD b/tensorflow/contrib/autograph/converters/BUILD
index 8f9bffa55e..284ad84be5 100644
--- a/tensorflow/contrib/autograph/converters/BUILD
+++ b/tensorflow/contrib/autograph/converters/BUILD
@@ -31,6 +31,7 @@ py_library(
"name_scopes.py",
"side_effect_guards.py",
"single_return.py",
+ "slices.py",
],
srcs_version = "PY2AND3",
visibility = ["//tensorflow:__subpackages__"],
@@ -208,3 +209,14 @@ py_test(
"//tensorflow/python:client_testlib",
],
)
+
+py_test(
+ name = "slices_test",
+ srcs = ["slices_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":test_lib",
+ "//tensorflow/contrib/autograph/pyct",
+ "//tensorflow/python:client_testlib",
+ ],
+)
diff --git a/tensorflow/contrib/autograph/converters/lists.py b/tensorflow/contrib/autograph/converters/lists.py
index b49521b2c3..c15dfff9e8 100644
--- a/tensorflow/contrib/autograph/converters/lists.py
+++ b/tensorflow/contrib/autograph/converters/lists.py
@@ -33,82 +33,193 @@ from __future__ import print_function
import gast
from tensorflow.contrib.autograph.pyct import anno
+from tensorflow.contrib.autograph.pyct import parser
from tensorflow.contrib.autograph.pyct import templates
from tensorflow.contrib.autograph.pyct import transformer
-from tensorflow.python.framework import dtypes
+from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno
+
+
+# Tags for local state.
+POP_USES = 'pop_uses'
class ListTransformer(transformer.Base):
"""Converts lists and related operations to their TF counterpart."""
- def _empty_list(self, node):
- if not anno.hasanno(node, 'element_type'):
- raise NotImplementedError(
- 'type inference for empty lists is not yet supported; '
- 'use set_element_type(<list>, <dtype>) to continue')
- dtype = anno.getanno(node, 'element_type')
- if not isinstance(dtype, dtypes.DType):
- # TODO(mdan): Allow non-TF dtypes?
- # That would be consistent with the dynamic dispatch pattern, but
- # we must make sure that doesn't become confusing.
- raise NotImplementedError('element type "%s" not yet supported' % dtype)
-
- dtype_name = dtype.name
- # TODO(mdan): Does it ever make sense not to use tensor lists?
+ def visit_List(self, node):
+ node = self.generic_visit(node)
template = """
- tf.TensorArray(tf.dtype_name, size=0, dynamic_size=True)
+ ag__.new_list(elements)
"""
- return templates.replace_as_expression(template, dtype_name=dtype_name)
+ return templates.replace_as_expression(template, elements=node)
- def _pre_populated_list(self, node):
- raise NotImplementedError('pre-populated lists')
+ def _replace_append_call(self, node):
+ assert len(node.args) == 1
+ assert isinstance(node.func, gast.Attribute)
+ template = """
+ target = ag__.list_append(target, element)
+ """
+ return templates.replace(
+ template,
+ target=node.func.value,
+ element=node.args[0])
+
+ def _replace_pop_call(self, node):
+ # Expressions that use pop() are converted to a statement + expression.
+ #
+ # For example:
+ #
+ # print(target.pop())
+ #
+ # ... is converted to:
+ #
+ # target, target_pop = ag__.list_pop(target)
+ # print(target_pop)
+ #
+ # Here, we just generate the variable name and swap it in,
+ # and _generate_pop_operation will handle the rest.
+ #
+ # Multiple uses of pop() are allowed:
+ #
+ # print(target.pop(), target.pop())
+ # print(target.pop().pop())
+ #
+ assert isinstance(node.func, gast.Attribute)
+ scope = anno.getanno(node, NodeAnno.ARGS_SCOPE)
+ target_node = node.func.value
+
+ # Attempt to use a related name if we can get one. Otherwise use something
+ # generic.
+ if anno.hasanno(target_node, anno.Basic.QN):
+ target_name = anno.getanno(target_node, anno.Basic.QN).ssf()
+ else:
+ target_name = 'list'
+ pop_var_name = self.context.namer.new_symbol(target_name, scope.referenced)
+
+ pop_uses = self.get_local(POP_USES, [])
+ pop_uses.append((node, pop_var_name))
+ self.set_local(POP_USES, pop_uses)
+
+ return templates.replace_as_expression('var_name', var_name=pop_var_name)
+
+ def _replace_stack_call(self, node):
+ assert len(node.args) == 1
+ dtype = anno.getanno(
+ node.args[0],
+ 'element_type',
+ default=templates.replace_as_expression('None'))
+ template = """
+ ag__.list_stack(
+ target,
+ opts=ag__.ListStackOpts(
+ element_dtype=dtype,
+ original_call=orig_call))
+ """
+ return templates.replace_as_expression(
+ template,
+ dtype=dtype,
+ target=node.args[0],
+ orig_call=node.func)
- def visit_Expr(self, node):
+ def visit_Call(self, node):
node = self.generic_visit(node)
- if isinstance(node.value, gast.Call):
- call_node = node.value
-
- if not anno.hasanno(call_node.func, anno.Basic.QN):
- return node
- qn = anno.getanno(call_node.func, anno.Basic.QN)
-
- if qn.qn[-1] == 'append' and (len(call_node.args) == 1):
- template = """
- target = ag__.utils.dynamic_list_append(target, element)
- """
- node = templates.replace(
- template,
- target=qn.parent.ast(),
- element=call_node.args[0])
+
+ # TODO(mdan): This is insufficient if target is a function argument.
+ # In the case of function arguments, we need to add the list to the
+ # function's return value, because it is being modified.
+ # TODO(mdan): Checking just the name is brittle; can it be improved?
+ if isinstance(node.func, gast.Attribute):
+ func_name = node.func.attr
+ if func_name == 'append' and (len(node.args) == 1):
+ node = self._replace_append_call(node)
+ elif func_name == 'pop' and (len(node.args) <= 1):
+ node = self._replace_pop_call(node)
+ elif func_name == 'stack' and (len(node.args) == 1):
+ node = self._replace_stack_call(node)
+
return node
- def _replace_list_constructors(self, targets, values):
- for target in targets:
- if (isinstance(target, (gast.Tuple, gast.List)) and
- isinstance(values, (gast.Tuple, gast.List))):
- n_targets = len(target.elts)
- for i in range(n_targets):
- target_el, value_el = target.elts[i], values.elts[i]
- values.elts[i] = self._replace_list_constructors(
- (target_el,), value_el)
- return values
- if isinstance(values, gast.List):
- if values.elts:
- return self._pre_populated_list(values)
- else:
- return self._empty_list(values)
- return values
-
- def visit_Assign(self, node):
- node = self.generic_visit(node)
+ def _generate_pop_operation(self, original_call_node, pop_var_name):
+ assert isinstance(original_call_node.func, gast.Attribute)
+
+ if original_call_node.args:
+ pop_element = original_call_node.args[0]
+ else:
+ pop_element = parser.parse_expression('None')
+ # The call will be something like "target.pop()", and the dtype is hooked to
+ # target, hence the func.value.
+ dtype = anno.getanno(
+ original_call_node.func.value,
+ 'element_type',
+ default=templates.replace_as_expression('None'))
+ shape = anno.getanno(
+ original_call_node.func.value,
+ 'element_shape',
+ default=templates.replace_as_expression('None'))
+
+ template = """
+ target, pop_var_name = ag__.list_pop(
+ target, element,
+ opts=ag__.ListPopOpts(element_dtype=dtype, element_shape=shape))
+ """
+ return templates.replace(
+ template,
+ target=original_call_node.func.value,
+ pop_var_name=pop_var_name,
+ element=pop_element,
+ dtype=dtype,
+ shape=shape)
+
+ def _postprocess_statement(self, node):
+ """Inserts any separate pop() calls that node may use."""
+ pop_uses = self.get_local(POP_USES, None)
+ if pop_uses:
+ replacements = []
+ for original_call_node, pop_var_name in pop_uses:
+ replacements.extend(
+ self._generate_pop_operation(original_call_node, pop_var_name))
+ replacements.append(node)
+ node = replacements
+ self.exit_local_scope()
+ return node, None
+
+ # TODO(mdan): Should we have a generic visit_block instead?
+ # Right now it feels that a visit_block would add too much magic that's
+ # hard to follow.
+
+ def _visit_and_process_block(self, block):
+ return self.visit_block(
+ block,
+ before_visit=self.enter_local_scope,
+ after_visit=self._postprocess_statement)
+
+ def visit_FunctionDef(self, node):
+ node.args = self.generic_visit(node.args)
+ node.decorator_list = self.visit_block(node.decorator_list)
+ node.body = self._visit_and_process_block(node.body)
+ return node
+
+ def visit_For(self, node):
+ node.target = self.visit(node.target)
+ node.body = self._visit_and_process_block(node.body)
+ node.orelse = self._visit_and_process_block(node.orelse)
+ return node
+
+ def visit_While(self, node):
+ node.test = self.visit(node.test)
+ node.body = self._visit_and_process_block(node.body)
+ node.orelse = self._visit_and_process_block(node.orelse)
+ return node
+
+ def visit_If(self, node):
+ node.test = self.visit(node.test)
+ node.body = self._visit_and_process_block(node.body)
+ node.orelse = self._visit_and_process_block(node.orelse)
+ return node
- # Only convert lists when they are assigned to a variable, e.g.:
- # l = []
- # TODO(mdan): A similar pattern exists in type_info.py
- # We should add a generic "unpack_assignment" function to the base
- # transformer, that has the same effect as applying some logic to the SSA
- # form.
- node.value = self._replace_list_constructors(node.targets, node.value)
+ def visit_With(self, node):
+ node.items = self.visit_block(node.items)
+ node.body = self._visit_and_process_block(node.body)
return node
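To make the new conversion concrete, here is a rough before/after sketch of what this transformer produces. It only illustrates the templates above; the ag__.new_list, ag__.list_append, ag__.list_pop and ag__.ListPopOpts names come from those templates, while the generated variable name l_pop is hypothetical.

# Original user code:
def f():
  l = [1]
  l.append(2)
  print(l.pop())

# Approximate converted form (sketch only; dtype/shape default to None when
# no set_element_type annotation is present):
def f():
  l = ag__.new_list([1])
  l = ag__.list_append(l, 2)
  l, l_pop = ag__.list_pop(
      l, None, opts=ag__.ListPopOpts(element_dtype=None, element_shape=None))
  print(l_pop)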
diff --git a/tensorflow/contrib/autograph/converters/lists_test.py b/tensorflow/contrib/autograph/converters/lists_test.py
index 74c6dc64f1..9f18ab9f44 100644
--- a/tensorflow/contrib/autograph/converters/lists_test.py
+++ b/tensorflow/contrib/autograph/converters/lists_test.py
@@ -22,74 +22,126 @@ from tensorflow.contrib.autograph import utils
from tensorflow.contrib.autograph.converters import converter_test_base
from tensorflow.contrib.autograph.converters import lists
from tensorflow.python.framework import dtypes
-from tensorflow.python.ops import tensor_array_ops
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import list_ops
from tensorflow.python.platform import test
class ListTest(converter_test_base.TestCase):
- def test_empty_annotated_list(self):
+ def test_empty_list(self):
def test_fn():
- l = []
- utils.set_element_type(l, dtypes.int32)
- l.append(1)
- return l
+ return []
- node = self.parse_and_analyze(test_fn, {'dtypes': dtypes, 'utils': utils})
+ node = self.parse_and_analyze(test_fn, {})
node = lists.transform(node, self.ctx)
- with self.compiled(node, tensor_array_ops.TensorArray,
- dtypes.int32) as result:
- # TODO(mdan): Attach these additional modules automatically.
- result.utils = utils
- result.dtypes = dtypes
+ with self.compiled(node) as result:
+ tl = result.test_fn()
+ # Empty tensor lists cannot be evaluated or stacked.
+ self.assertTrue(isinstance(tl, ops.Tensor))
+ self.assertEqual(tl.dtype, dtypes.variant)
+
+ def test_initialized_list(self):
+
+ def test_fn():
+ return [1, 2, 3]
+
+ node = self.parse_and_analyze(test_fn, {})
+ node = lists.transform(node, self.ctx)
+
+ with self.compiled(node) as result:
with self.test_session() as sess:
- self.assertAllEqual([1], sess.run(result.test_fn().stack()))
+ tl = result.test_fn()
+ r = list_ops.tensor_list_stack(tl, dtypes.int32)
+ self.assertAllEqual(sess.run(r), [1, 2, 3])
- def test_empty_annotated_lists_unpacked(self):
+ def test_list_append(self):
def test_fn():
- l, m = [], []
- utils.set_element_type(l, dtypes.int32)
- utils.set_element_type(m, dtypes.int32)
- l.append(1)
- m.append(2)
- return l, m
+ l = [1]
+ l.append(2)
+ l.append(3)
+ return l
- node = self.parse_and_analyze(test_fn, {'dtypes': dtypes, 'utils': utils})
+ node = self.parse_and_analyze(test_fn, {})
node = lists.transform(node, self.ctx)
- with self.compiled(node, tensor_array_ops.TensorArray,
- dtypes.int32) as result:
+ with self.compiled(node) as result:
+ with self.test_session() as sess:
+ tl = result.test_fn()
+ r = list_ops.tensor_list_stack(tl, dtypes.int32)
+ self.assertAllEqual(sess.run(r), [1, 2, 3])
+
+ def test_list_pop(self):
+
+ def test_fn():
+ l = [1, 2, 3]
+ utils.set_element_type(l, dtypes.int32, ())
+ s = l.pop()
+ return s, l
+
+ node = self.parse_and_analyze(
+ test_fn,
+ {
+ 'utils': utils,
+ 'dtypes': dtypes
+ },
+ include_type_analysis=True,
+ )
+ node = lists.transform(node, self.ctx)
+
+ with self.compiled(node) as result:
result.utils = utils
result.dtypes = dtypes
with self.test_session() as sess:
- res_l, res_m = result.test_fn()
- self.assertEqual([1], sess.run(res_l.stack()))
- self.assertEqual([2], sess.run(res_m.stack()))
+ ts, tl = result.test_fn()
+ r = list_ops.tensor_list_stack(tl, dtypes.int32)
+ self.assertAllEqual(sess.run(r), [1, 2])
+ self.assertAllEqual(sess.run(ts), 3)
+
+ def test_double_list_pop(self):
- def test_empty_annotated_lists_list_unpacked(self):
+ def test_fn(l):
+ s = l.pop().pop()
+ return s
+
+ node = self.parse_and_analyze(test_fn, {})
+ node = lists.transform(node, self.ctx)
+
+ with self.compiled(node) as result:
+ test_input = [1, 2, [1, 2, 3]]
+ # TODO(mdan): Pass a list of lists of tensors when we fully support that.
+ # For now, we pass a regular Python list of lists just to verify that
+ # the two pop calls are sequenced properly.
+ self.assertAllEqual(result.test_fn(test_input), 3)
+
+ def test_list_stack(self):
+
+ tf = None # Will be replaced with a mock.
def test_fn():
- [l, m] = [], []
+ l = [1, 2, 3]
utils.set_element_type(l, dtypes.int32)
- utils.set_element_type(m, dtypes.int32)
- l.append(1)
- m.append(2)
- return l, m
-
- node = self.parse_and_analyze(test_fn, {'dtypes': dtypes, 'utils': utils})
+ return tf.stack(l)
+
+ node = self.parse_and_analyze(
+ test_fn,
+ {
+ 'utils': utils,
+ 'dtypes': dtypes
+ },
+ include_type_analysis=True,
+ )
node = lists.transform(node, self.ctx)
- with self.compiled(node, tensor_array_ops.TensorArray,
- dtypes.int32) as result:
+ with self.compiled(node, array_ops.stack, dtypes.int32) as result:
result.utils = utils
result.dtypes = dtypes
with self.test_session() as sess:
- res_l, res_m = result.test_fn()
- self.assertEqual([1], sess.run(res_l.stack()))
- self.assertEqual([2], sess.run(res_m.stack()))
+ self.assertAllEqual(sess.run(result.test_fn()), [1, 2, 3])
if __name__ == '__main__':
diff --git a/tensorflow/contrib/autograph/converters/slices.py b/tensorflow/contrib/autograph/converters/slices.py
new file mode 100644
index 0000000000..85aeda9c41
--- /dev/null
+++ b/tensorflow/contrib/autograph/converters/slices.py
@@ -0,0 +1,83 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Converter for slice operations."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gast
+
+from tensorflow.contrib.autograph.pyct import anno
+from tensorflow.contrib.autograph.pyct import templates
+from tensorflow.contrib.autograph.pyct import transformer
+
+
+class SliceTransformer(transformer.Base):
+ """Converts slicing operations to their TF counterpart.
+
+ Currently, relying on the default slice operator that Tensor uses is
+ insufficient, because TensorArray and tensor lists use dedicated index read
+ and write functions.
+ """
+
+ def _process_single_assignment(self, target, value):
+ if not isinstance(target, gast.Subscript):
+ return None
+
+ template = """
+ target = ag__.set_item(target, key, item)
+ """
+ return templates.replace(
+ template, target=target.value, key=target.slice, item=value)
+
+ def visit_Assign(self, node):
+ node = self.generic_visit(node)
+ # TODO(mdan): Support unpackings and multiple assignments.
+ if len(node.targets) != 1:
+ raise NotImplementedError('multiple assignment')
+ replacement = self._process_single_assignment(node.targets[0], node.value)
+ if replacement is not None:
+ return replacement
+ return node
+
+ def visit_Subscript(self, node):
+ node = self.generic_visit(node)
+ if not isinstance(node.slice, gast.Index):
+ # TODO(mdan): It might make more sense to wave them through.
+ raise NotImplementedError('non-index slice')
+
+ if not isinstance(node.ctx, gast.Load):
+ # Index writes are handled at a higher level, one at which the rvalue is
+ # also available.
+ return node
+
+ dtype = anno.getanno(
+ node.value,
+ 'element_type',
+ default=templates.replace_as_expression('None'))
+
+ template = """
+ ag__.get_item(
+ target,
+ key,
+ opts=ag__.GetItemOpts(element_dtype=dtype))
+ """
+ return templates.replace_as_expression(
+ template, target=node.value, key=node.slice, dtype=dtype)
+
+
+def transform(node, context):
+ return SliceTransformer(context).visit(node)
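As a rough sketch of this transformer's effect: an index read becomes a call to ag__.get_item, and an index assignment becomes a rebinding through ag__.set_item, matching the templates above. The variable names below are placeholders.

# Original user code:
x = l[i]
l[i] = y

# Approximate converted form (sketch only):
x = ag__.get_item(l, i, opts=ag__.GetItemOpts(element_dtype=None))
l = ag__.set_item(l, i, y)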
diff --git a/tensorflow/contrib/autograph/converters/slices_test.py b/tensorflow/contrib/autograph/converters/slices_test.py
new file mode 100644
index 0000000000..6c2d7e1ea1
--- /dev/null
+++ b/tensorflow/contrib/autograph/converters/slices_test.py
@@ -0,0 +1,59 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for slices module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.autograph import utils
+from tensorflow.contrib.autograph.converters import converter_test_base
+from tensorflow.contrib.autograph.converters import slices
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import list_ops
+from tensorflow.python.platform import test
+
+
+class SliceTest(converter_test_base.TestCase):
+
+ def test_index_access(self):
+
+ def test_fn(l):
+ utils.set_element_type(l, dtypes.int32)
+ return l[1]
+
+ node = self.parse_and_analyze(
+ test_fn,
+ {
+ 'utils': utils,
+ 'dtypes': dtypes
+ },
+ include_type_analysis=True,
+ )
+ node = slices.transform(node, self.ctx)
+
+ with self.compiled(node, dtypes.int32) as result:
+ result.utils = utils
+ result.dtypes = dtypes
+ with self.test_session() as sess:
+ tl = list_ops.tensor_list_from_tensor(
+ [1, 2], element_shape=constant_op.constant([], dtype=dtypes.int32))
+ y = result.test_fn(tl)
+ self.assertEqual(2, sess.run(y))
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/autograph/impl/BUILD b/tensorflow/contrib/autograph/impl/BUILD
index 91ae0b9b82..02f16ae187 100644
--- a/tensorflow/contrib/autograph/impl/BUILD
+++ b/tensorflow/contrib/autograph/impl/BUILD
@@ -20,6 +20,7 @@ py_library(
"api.py",
"config.py",
"conversion.py",
+ "directives.py",
"naming.py",
"special_functions.py",
],
diff --git a/tensorflow/contrib/autograph/impl/conversion.py b/tensorflow/contrib/autograph/impl/conversion.py
index 55a30dc127..7802bbbe27 100644
--- a/tensorflow/contrib/autograph/impl/conversion.py
+++ b/tensorflow/contrib/autograph/impl/conversion.py
@@ -38,6 +38,7 @@ from tensorflow.contrib.autograph.converters import logical_expressions
from tensorflow.contrib.autograph.converters import name_scopes
from tensorflow.contrib.autograph.converters import side_effect_guards
from tensorflow.contrib.autograph.converters import single_return
+from tensorflow.contrib.autograph.converters import slices
from tensorflow.contrib.autograph.impl import config
from tensorflow.contrib.autograph.impl import naming
from tensorflow.contrib.autograph.pyct import ast_util
@@ -371,6 +372,8 @@ def node_to_graph(node, ctx, nocompile_decorators):
# TODO(mdan): Clean this up.
# Some intermediate analyses are not required, and some comments got orphaned.
+ # TODO(mdan): We may assume all converters require analysis to be re-done.
+
# Past this point, line numbers are no longer accurate so we ignore the
# source.
# TODO(mdan): Is it feasible to reconstruct intermediate source code?
@@ -393,6 +396,8 @@ def node_to_graph(node, ctx, nocompile_decorators):
node = _static_analysis_pass(node, ctx)
node = lists.transform(node, ctx)
+ node = _static_analysis_pass(node, ctx)
+ node = slices.transform(node, ctx)
node = builtin_functions.transform(node, ctx)
node = _static_analysis_pass(node, ctx)
diff --git a/tensorflow/contrib/autograph/impl/directives.py b/tensorflow/contrib/autograph/impl/directives.py
new file mode 100644
index 0000000000..aabe5d9939
--- /dev/null
+++ b/tensorflow/contrib/autograph/impl/directives.py
@@ -0,0 +1,68 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Directives are special no-op functions that serve as compilation markers.
+
+They provide static information such as type hints, as well as compilation and
+TensorFlow overrides.
+
+These serve as annotations in the compiled code, allowing the user some control
+over the compilation process. They have no functional role at runtime.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+UNSPECIFIED = object()
+
+
+def set_element_type(entity, dtype, shape=UNSPECIFIED):
+ """Indicates that the entity is expected hold items of specified type/shape.
+
+ The staged TensorFlow ops will reflect and assert this data type; if the
+ entity is not staged, the annotation is ignored.
+
+ Args:
+ entity: The entity to annotate.
+ dtype: TensorFlow dtype value to assert for entity.
+ shape: Optional shape to assert for entity.
+ """
+ del entity
+ del dtype
+ del shape
+
+
+def set_loop_options(
+ parallel_iterations=UNSPECIFIED,
+ back_prop=UNSPECIFIED,
+ swap_memory=UNSPECIFIED,
+ maximum_iterations=UNSPECIFIED):
+ """Specifies additional arguments to be passed to the enclosing while_loop.
+
+ The parameters apply only to the immediately enclosing loop, and only if that
+ loop is staged as a TF while_loop; otherwise they have no effect.
+
+ Args:
+ parallel_iterations: See tf.while_loop.
+ back_prop: See tf.while_loop.
+ swap_memory: See tf.while_loop.
+ maximum_iterations: See tf.while_loop.
+ """
+ del parallel_iterations
+ del back_prop
+ del swap_memory
+ del maximum_iterations
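A minimal usage sketch, assuming the tensorflow.contrib.autograph exports added in this change; exactly where the directives may appear and whether the converters pick them up depends on the conversion configuration, so treat this as illustrative only.

from tensorflow.contrib import autograph
from tensorflow.python.framework import dtypes

def collect(n):
  l = []
  # Hint the element type (and optionally shape) of the staged tensor list.
  autograph.set_element_type(l, dtypes.int32)
  for i in range(n):
    # If this loop is staged as a tf.while_loop, forward these options to it.
    autograph.set_loop_options(back_prop=False)
    l.append(i)
  # stack() is the special function exported alongside the directives.
  return autograph.stack(l)

# The directives are no-ops when collect() runs as plain Python; they only
# carry information for the conversion step, e.g. autograph.to_graph(collect).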
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/type_info.py b/tensorflow/contrib/autograph/pyct/static_analysis/type_info.py
index d6555dc7e0..7d1e65c958 100644
--- a/tensorflow/contrib/autograph/pyct/static_analysis/type_info.py
+++ b/tensorflow/contrib/autograph/pyct/static_analysis/type_info.py
@@ -17,8 +17,8 @@
This analyzer uses known live values to further infer object types. This
may include for instance constructed objects and object member functions.
-In addition, the analyzer will also process annotations for TF (staged) type
-annotations.
+In addition, the analyzer also handles user annotations made in the code (for
+example, the autograph.set_element_type function).
Requires annotations generated by LiveValuesResolver.
"""
@@ -44,6 +44,7 @@ from __future__ import print_function
import gast
from tensorflow.contrib.autograph.pyct import anno
+from tensorflow.contrib.autograph.pyct import parser
from tensorflow.contrib.autograph.pyct import transformer
from tensorflow.python.util import tf_inspect
@@ -159,12 +160,10 @@ class TypeInfoResolver(transformer.Base):
# a = b
# then for future references to `a` we should have definition = `b`
definition = self.scope.getval(qn)
- if anno.hasanno(definition, 'type'):
- anno.setanno(node, 'type', anno.getanno(definition, 'type'))
- anno.setanno(node, 'type_fqn', anno.getanno(definition, 'type_fqn'))
- if anno.hasanno(definition, 'element_type'):
- anno.setanno(node, 'element_type',
- anno.getanno(definition, 'element_type'))
+ anno.copyanno(definition, node, 'type')
+ anno.copyanno(definition, node, 'type_fqn')
+ anno.copyanno(definition, node, 'element_type')
+ anno.copyanno(definition, node, 'element_shape')
return node
def _process_variable_assignment(self, target, value):
@@ -211,23 +210,20 @@ class TypeInfoResolver(transformer.Base):
if (anno.getanno(node.func, 'live_val') is
self.context.type_annotation_func):
- if len(node.args) != 2:
- raise ValueError('"%s" must have exactly two parameters'
+ if len(node.args) < 2 or len(node.args) > 3:
+ raise ValueError('"%s" must have either two or three parameters'
% self.context.type_annotation_func)
- target_arg, type_arg = node.args
+ if len(node.args) == 2:
+ target_arg, type_arg = node.args
+ shape_arg = parser.parse_expression('None')
+ else:
+ target_arg, type_arg, shape_arg = node.args
if not anno.hasanno(target_arg, anno.Basic.QN):
raise ValueError('the first argument of "%s" must be a symbol'
% self.context.type_annotation_func)
- if isinstance(type_arg, gast.Str):
- element_type = type_arg.s
- elif isinstance(type_arg, gast.Num):
- element_type = type_arg.n
- else:
- if not anno.hasanno(type_arg, 'live_val'):
- raise ValueError(
- 'the second argument of "%s" must be statically resolvable' %
- self.context.type_annotation_func)
- element_type = anno.getanno(type_arg, 'live_val')
+ # TODO(mdan): This is vulnerable to symbol renaming.
+ element_type = type_arg
+ element_shape = shape_arg
target_symbol = anno.getanno(target_arg, anno.Basic.QN)
# Find the definition of this symbol and annotate it with the given
@@ -235,7 +231,9 @@ class TypeInfoResolver(transformer.Base):
# to receive the same type annotation.
definition = self.scope.getval(target_symbol)
anno.setanno(node, 'element_type', element_type)
+ anno.setanno(node, 'element_shape', element_shape)
anno.setanno(definition, 'element_type', element_type)
+ anno.setanno(definition, 'element_shape', element_shape)
# TODO(mdan): Should we update references between definition and here?
return self.generic_visit(node)
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py b/tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py
index 95cbf5ca79..484562f294 100644
--- a/tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py
+++ b/tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py
@@ -187,14 +187,14 @@ class TypeInfoResolverTest(test.TestCase):
def test_fn():
f = []
- f = utils.set_element_type(f, Foo)
+ f = utils.set_element_type(f, Foo, (1, 2, 3))
return f
node = self._parse_and_analyze(test_fn, {'Foo': Foo, 'utils': utils})
f_def = node.body[0].body[0].value
- self.assertEqual(anno.getanno(f_def, 'element_type'), Foo)
+ self.assertEqual(anno.getanno(f_def, 'element_type').id, 'Foo')
f_ref = node.body[0].body[1].value
- self.assertEqual(anno.getanno(f_ref, 'element_type'), Foo)
+ self.assertEqual(anno.getanno(f_ref, 'element_type').id, 'Foo')
def test_type_annotation_args(self):
@@ -207,7 +207,7 @@ class TypeInfoResolverTest(test.TestCase):
node = self._parse_and_analyze(test_fn, {'Foo': Foo, 'utils': utils})
f_ref = node.body[0].body[1].value
- self.assertEqual(anno.getanno(f_ref, 'element_type'), Foo)
+ self.assertEqual(anno.getanno(f_ref, 'element_type').id, 'Foo')
def test_nested_unpacking(self):
@@ -223,9 +223,9 @@ class TypeInfoResolverTest(test.TestCase):
node = self._parse_and_analyze(test_fn, {'Foo': Foo, 'Bar': Bar})
a, b, c = node.body[0].body[1].value.elts
- self.assertEquals(Foo, anno.getanno(a, 'type'))
- self.assertEquals(Bar, anno.getanno(b, 'type'))
- self.assertEquals(Foo, anno.getanno(c, 'type'))
+ self.assertEquals(anno.getanno(a, 'type'), Foo)
+ self.assertEquals(anno.getanno(b, 'type'), Bar)
+ self.assertEquals(anno.getanno(c, 'type'), Foo)
self.assertFalse(anno.hasanno(a, 'live_val'))
self.assertFalse(anno.hasanno(b, 'live_val'))
self.assertFalse(anno.hasanno(c, 'live_val'))
@@ -242,8 +242,8 @@ class TypeInfoResolverTest(test.TestCase):
node = self._parse_and_analyze(test_fn, {'utils': utils})
a, b = node.body[0].body[2].body[2].value.elts
- self.assertEquals(1, anno.getanno(a, 'element_type'))
- self.assertEquals(2, anno.getanno(b, 'element_type'))
+ self.assertEquals(anno.getanno(a, 'element_type').n, 1)
+ self.assertEquals(anno.getanno(b, 'element_type').n, 2)
self.assertFalse(anno.hasanno(a, 'type'))
self.assertFalse(anno.hasanno(b, 'type'))
self.assertFalse(anno.hasanno(a, 'live_val'))
diff --git a/tensorflow/contrib/autograph/pyct/templates.py b/tensorflow/contrib/autograph/pyct/templates.py
index baf7923fff..9c479ebc2f 100644
--- a/tensorflow/contrib/autograph/pyct/templates.py
+++ b/tensorflow/contrib/autograph/pyct/templates.py
@@ -239,8 +239,13 @@ def replace_as_expression(template, **replacements):
raise ValueError(
'single expression expected; for more general templates use replace')
node = replacement[0]
- if not isinstance(node, gast.Expr):
- raise ValueError(
- 'the template is expected to generate an expression node; instead '
- 'found %s' % node)
- return node.value
+ node = qual_names.resolve(node)
+
+ if isinstance(node, gast.Expr):
+ return node.value
+ elif isinstance(node, gast.Name):
+ return node
+
+ raise ValueError(
+ 'the template is expected to generate an expression or a name node;'
+ ' instead found %s' % node)
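A small sketch of the two forms replace_as_expression now accepts; the template strings and replacement values here are hypothetical.

from tensorflow.contrib.autograph.pyct import templates

# An expression template still yields the underlying expression node
# (a gast.Call in this case).
call_node = templates.replace_as_expression('fn(arg)', fn='my_fn', arg='x')

# A bare name template now yields a gast.Name node instead of raising.
name_node = templates.replace_as_expression('var_name', var_name='tmp_result')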
diff --git a/tensorflow/contrib/batching/__init__.py b/tensorflow/contrib/batching/__init__.py
index 44fa5f42a7..1e503a097a 100644
--- a/tensorflow/contrib/batching/__init__.py
+++ b/tensorflow/contrib/batching/__init__.py
@@ -14,6 +14,7 @@
# ==============================================================================
"""Ops and modules related to batch.
+@@batch_function_v1
@@batch_function
"""
from __future__ import absolute_import
diff --git a/tensorflow/contrib/batching/python/ops/batch_ops.py b/tensorflow/contrib/batching/python/ops/batch_ops.py
index 921d6917a4..012a51f711 100644
--- a/tensorflow/contrib/batching/python/ops/batch_ops.py
+++ b/tensorflow/contrib/batching/python/ops/batch_ops.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_batch_ops
# go/tf-wildcard-import
@@ -102,6 +103,74 @@ def batch_function(num_batch_threads,
Returns:
The decorated function will return the unbatched computation output Tensors.
"""
+
+ def decorator(fn): # pylint: disable=missing-docstring
+
+ def decorated(*args): # pylint: disable=missing-docstring
+ types = [arg.dtype for arg in args]
+
+ @function.Defun(*types)
+ def computation(*computation_args):
+ return fn(*computation_args)
+
+ with ops.name_scope("batch") as name:
+ for a in args:
+ if not isinstance(a, ops.Tensor):
+ raise ValueError("All arguments to functions decorated with "
+ "`batch_function` are supposed to be Tensors; "
+ "found %s" % repr(a))
+ for inp in computation.captured_inputs:
+ print("inp: %s" % inp)
+ for op in inp.consumers():
+ print("op: %s" % op)
+ return gen_batch_ops.batch_function(
+ num_batch_threads=num_batch_threads,
+ max_batch_size=max_batch_size,
+ batch_timeout_micros=batch_timeout_micros,
+ allowed_batch_sizes=allowed_batch_sizes,
+ max_enqueued_batches=max_enqueued_batches,
+ shared_name=name,
+ f=computation,
+ in_tensors=list(args),
+ captured_tensors=computation.captured_inputs,
+ Tout=[o.type for o in computation.definition.signature.output_arg])
+
+ return decorated
+
+ return decorator
+
+
+def batch_function_v1(num_batch_threads,
+ max_batch_size,
+ batch_timeout_micros,
+ allowed_batch_sizes=None,
+ grad_timeout_micros=60 * 1000 * 1000,
+ unbatch_timeout_micros=60 * 1000 * 1000,
+ max_enqueued_batches=10):
+ """Batches the computation done by the decorated function.
+
+ This is the older version of batch_function(). Please use batch_function()
+ instead of this one.
+
+ Args:
+ num_batch_threads: Number of scheduling threads for processing batches
+ of work. Determines the number of batches processed in parallel.
+ max_batch_size: Batch sizes will never be bigger than this.
+ batch_timeout_micros: Maximum number of microseconds to wait before
+ outputting an incomplete batch.
+ allowed_batch_sizes: Optional list of allowed batch sizes. If left empty,
+ does nothing. Otherwise, supplies a list of batch sizes, causing the op
+ to pad batches up to one of those sizes. The entries must increase
+ monotonically, and the final entry must equal max_batch_size.
+ grad_timeout_micros: The timeout to use for the gradient. See the
+ documentation of the unbatch op for more details. Defaults to 60s.
+ unbatch_timeout_micros: The timeout to use for unbatching. See the
+ documentation of the unbatch op for more details. Defaults to 60s.
+ max_enqueued_batches: The maximum depth of the batch queue. Defaults to 10.
+
+ Returns:
+ The decorated function will return the unbatched computation output Tensors.
+ """
def decorator(f): # pylint: disable=missing-docstring
def decorated(*args):
with ops.name_scope("batch") as name:
diff --git a/tensorflow/contrib/batching/python/ops/batch_ops_test.py b/tensorflow/contrib/batching/python/ops/batch_ops_test.py
index ea8339334f..7846814546 100644
--- a/tensorflow/contrib/batching/python/ops/batch_ops_test.py
+++ b/tensorflow/contrib/batching/python/ops/batch_ops_test.py
@@ -188,12 +188,62 @@ class BatchOpsTest(test.TestCase):
self.assertEqual(thread_results[0], [2])
self.assertEqual(main_results[0], [3])
+ def testBasicUnbatchV1Decorated(self):
+ """Tests that the batch_function_v1 decorator works."""
+ with self.test_session() as sess:
+ @batch_ops.batch_function_v1(1, 10, 100000)
+ def computation(in_t):
+ return in_t + 1
+
+ inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
+ result = computation(inp)
+ thread_results = []
+
+ def worker():
+ thread_results.extend(sess.run([result], feed_dict={inp: [1]}))
+
+ worker_thread = threading.Thread(target=worker)
+ worker_thread.start()
+ main_results = sess.run([result], feed_dict={inp: [2]})
+ worker_thread.join()
+ self.assertEqual(thread_results[0], [2])
+ self.assertEqual(main_results[0], [3])
+
def testBasicUnbatchDecorated(self):
"""Tests that the batch_function decorator works."""
with self.test_session() as sess:
+ # TODO(apassos): Removing this line causes test flakiness! This should
+ # ideally be investigated.
+ default_inp = array_ops.placeholder_with_default(2, shape=[]) # pylint: disable=unused-variable
+
@batch_ops.batch_function(1, 10, 100000)
def computation(in_t):
return in_t + 1
+
+ inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
+ result = computation(inp)
+ thread_results = []
+
+ def worker():
+ thread_results.extend(sess.run([result], feed_dict={inp: [1]}))
+
+ worker_thread = threading.Thread(target=worker)
+ worker_thread.start()
+ main_results = sess.run([result], feed_dict={inp: [2]})
+ worker_thread.join()
+ self.assertEqual(thread_results[0], [2])
+ self.assertEqual(main_results[0], [3])
+
+ def testBatchDecoratedWithCapturedInput(self):
+ """Tests that the batch_function decorator works."""
+ with self.test_session() as sess:
+ captured_inp0 = array_ops.placeholder_with_default(2, shape=[])
+ captured_inp1 = array_ops.placeholder_with_default(1, shape=[])
+
+ @batch_ops.batch_function(1, 10, 100000)
+ def computation(in_t):
+ return in_t + captured_inp0 - captured_inp1
+
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
result = computation(inp)
thread_results = []
diff --git a/tensorflow/contrib/bayesflow/python/ops/monte_carlo.py b/tensorflow/contrib/bayesflow/python/ops/monte_carlo.py
index 5770bcdd70..68fa415eea 100644
--- a/tensorflow/contrib/bayesflow/python/ops/monte_carlo.py
+++ b/tensorflow/contrib/bayesflow/python/ops/monte_carlo.py
@@ -12,10 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Monte Carlo integration and helpers.
-
-See the @{$python/contrib.bayesflow.monte_carlo} guide.
-"""
+"""Monte Carlo integration and helpers."""
from __future__ import absolute_import
from __future__ import division
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py
index 89d0d611d2..9c36c30221 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py
@@ -41,7 +41,8 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator):
feature_engineering_fn=None,
logits_modifier_function=None,
center_bias=True,
- use_core_libs=False):
+ use_core_libs=False,
+ output_leaf_index=False):
"""Initializes a GradientBoostedDecisionTreeClassifier estimator instance.
Args:
@@ -66,6 +67,16 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator):
the bias.
use_core_libs: Whether feature columns and loss are from the core (as
opposed to contrib) version of tensorflow.
+ output_leaf_index: Whether to output leaf indices along with predictions
+ during inference. The leaf node indices are available in the predictions
+ dict under the key 'leaf_index'. It is a Tensor of rank 2 with shape
+ [batch_size, num_trees].
+ For example,
+ result_iter = classifier.predict(...)
+ for result_dict in result_iter:
+ # access leaf index list by result_dict["leaf_index"]
+ # which contains one leaf index per tree
+
Raises:
ValueError: If learner_config is not valid.
"""
@@ -74,7 +85,9 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator):
# supports second order derivative.
def loss_fn(labels, logits, weights=None):
result = losses.per_example_maxent_loss(
- labels=labels, logits=logits, weights=weights,
+ labels=labels,
+ logits=logits,
+ weights=weights,
num_classes=n_classes)
return math_ops.reduce_mean(result[0])
else:
@@ -102,6 +115,7 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator):
'center_bias': center_bias,
'logits_modifier_function': logits_modifier_function,
'use_core_libs': use_core_libs,
+ 'output_leaf_index': output_leaf_index,
},
model_dir=model_dir,
config=config,
@@ -124,7 +138,8 @@ class GradientBoostedDecisionTreeRegressor(estimator.Estimator):
feature_engineering_fn=None,
logits_modifier_function=None,
center_bias=True,
- use_core_libs=False):
+ use_core_libs=False,
+ output_leaf_index=False):
"""Initializes a GradientBoostedDecisionTreeRegressor estimator instance.
Args:
@@ -151,6 +166,13 @@ class GradientBoostedDecisionTreeRegressor(estimator.Estimator):
the bias.
use_core_libs: Whether feature columns and loss are from the core (as
opposed to contrib) version of tensorflow.
+ output_leaf_index: Whether to output leaf indices along with predictions
+ during inference. The leaf node indices are available in the predictions
+ dict under the key 'leaf_index'. For example,
+ result_dict = classifier.predict(...)
+ for example_prediction_result in result_dict:
+ # access leaf index list by example_prediction_result["leaf_index"]
+ # which contains one leaf index per tree
"""
head = head_lib.regression_head(
label_name=label_name,
@@ -173,6 +195,7 @@ class GradientBoostedDecisionTreeRegressor(estimator.Estimator):
'logits_modifier_function': logits_modifier_function,
'center_bias': center_bias,
'use_core_libs': use_core_libs,
+ 'output_leaf_index': output_leaf_index,
},
model_dir=model_dir,
config=config,
@@ -197,7 +220,8 @@ class GradientBoostedDecisionTreeEstimator(estimator.Estimator):
feature_engineering_fn=None,
logits_modifier_function=None,
center_bias=True,
- use_core_libs=False):
+ use_core_libs=False,
+ output_leaf_index=False):
"""Initializes a GradientBoostedDecisionTreeEstimator estimator instance.
Args:
@@ -220,6 +244,13 @@ class GradientBoostedDecisionTreeEstimator(estimator.Estimator):
the bias.
use_core_libs: Whether feature columns and loss are from the core (as
opposed to contrib) version of tensorflow.
+ output_leaf_index: Whether to output leaf indices along with predictions
+ during inference. The leaf node indices are available in the predictions
+ dict under the key 'leaf_index'. For example,
+ result_dict = classifier.predict(...)
+ for example_prediction_result in result_dict:
+ # access leaf index list by example_prediction_result["leaf_index"]
+ # which contains one leaf index per tree
"""
super(GradientBoostedDecisionTreeEstimator, self).__init__(
model_fn=model.model_builder,
@@ -233,6 +264,7 @@ class GradientBoostedDecisionTreeEstimator(estimator.Estimator):
'logits_modifier_function': logits_modifier_function,
'center_bias': center_bias,
'use_core_libs': use_core_libs,
+ 'output_leaf_index': output_leaf_index,
},
model_dir=model_dir,
config=config,
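A usage sketch mirroring the new test below; learner_config, the input functions and the feature column are placeholders set up as in that test.

classifier = estimator.GradientBoostedDecisionTreeClassifier(
    learner_config=learner_config,
    num_trees=1,
    examples_per_layer=3,
    feature_columns=[contrib_feature_column.real_valued_column("x")],
    output_leaf_index=True)

classifier.fit(input_fn=train_input_fn, steps=15)
for prediction_dict in classifier.predict(input_fn=eval_input_fn):
  leaf_ids = prediction_dict["leaf_index"]  # one leaf id per tree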
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py
index 0d58317bd5..75ef1b0500 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py
@@ -68,6 +68,28 @@ class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase):
classifier.evaluate(input_fn=_eval_input_fn, steps=1)
classifier.export(self._export_dir_base)
+ def testThatLeafIndexIsInPredictions(self):
+ learner_config = learner_pb2.LearnerConfig()
+ learner_config.num_classes = 2
+ learner_config.constraints.max_tree_depth = 1
+ model_dir = tempfile.mkdtemp()
+ config = run_config.RunConfig()
+
+ classifier = estimator.GradientBoostedDecisionTreeClassifier(
+ learner_config=learner_config,
+ num_trees=1,
+ examples_per_layer=3,
+ model_dir=model_dir,
+ config=config,
+ feature_columns=[contrib_feature_column.real_valued_column("x")],
+ output_leaf_index=True)
+
+ classifier.fit(input_fn=_train_input_fn, steps=15)
+ result_iter = classifier.predict(input_fn=_eval_input_fn)
+ for prediction_dict in result_iter:
+ self.assertTrue("leaf_index" in prediction_dict)
+ self.assertTrue("logits" in prediction_dict)
+
def testFitAndEvaluateDontThrowExceptionWithCoreForEstimator(self):
learner_config = learner_pb2.LearnerConfig()
learner_config.num_classes = 2
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/model.py b/tensorflow/contrib/boosted_trees/estimator_batch/model.py
index 15ab6d8145..1ee8911989 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/model.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/model.py
@@ -63,6 +63,8 @@ def model_builder(features, labels, mode, params, config):
num_trees = params["num_trees"]
use_core_libs = params["use_core_libs"]
logits_modifier_function = params["logits_modifier_function"]
+ output_leaf_index = params["output_leaf_index"]
+
if features is None:
raise ValueError("At least one feature must be specified.")
@@ -96,7 +98,8 @@ def model_builder(features, labels, mode, params, config):
feature_columns=feature_columns,
logits_dimension=head.logits_dimension,
features=training_features,
- use_core_columns=use_core_libs)
+ use_core_columns=use_core_libs,
+ output_leaf_index=output_leaf_index)
with ops.name_scope("gbdt", "gbdt_optimizer"):
predictions_dict = gbdt_model.predict(mode)
logits = predictions_dict["predictions"]
@@ -127,6 +130,9 @@ def model_builder(features, labels, mode, params, config):
labels=labels,
train_op_fn=_train_op_fn,
logits=logits)
+ if output_leaf_index and gbdt_batch.LEAF_INDEX in predictions_dict:
+ model_fn_ops.predictions[gbdt_batch.LEAF_INDEX] = predictions_dict[
+ gbdt_batch.LEAF_INDEX]
if num_trees:
if center_bias:
num_trees += 1
diff --git a/tensorflow/contrib/boosted_trees/kernels/prediction_ops.cc b/tensorflow/contrib/boosted_trees/kernels/prediction_ops.cc
index b3fe38614e..9493c1a139 100644
--- a/tensorflow/contrib/boosted_trees/kernels/prediction_ops.cc
+++ b/tensorflow/contrib/boosted_trees/kernels/prediction_ops.cc
@@ -59,6 +59,7 @@ const char* kApplyDropoutAttributeName = "apply_dropout";
const char* kApplyAveragingAttributeName = "apply_averaging";
const char* kDropoutInfoOutputTensorName = "drop_out_tree_indices_weights";
const char* kPredictionsTensorName = "predictions";
+const char* kLeafIndexTensorName = "leaf_index";
void CalculateTreesToInclude(
const boosted_trees::trees::DecisionTreeEnsembleConfig& config,
@@ -170,15 +171,22 @@ class GradientTreesPredictionOp : public OpKernel {
core::ScopedUnref unref_me(ensemble_resource);
if (use_locking_) {
tf_shared_lock l(*ensemble_resource->get_mutex());
- DoCompute(context, ensemble_resource);
+ DoCompute(context, ensemble_resource,
+ /*return_output_leaf_index=*/false);
} else {
- DoCompute(context, ensemble_resource);
+ DoCompute(context, ensemble_resource,
+ /*return_output_leaf_index=*/false);
}
}
- private:
- void DoCompute(OpKernelContext* context,
- DecisionTreeEnsembleResource* ensemble_resource) {
+ protected:
+ // return_output_leaf_index indicates whether to output the leaf index as
+ // part of the prediction. This class only invokes DoCompute with the value
+ // false; the subclass GradientTreesPredictionVerboseOp invokes it with true.
+ virtual void DoCompute(OpKernelContext* context,
+ DecisionTreeEnsembleResource* ensemble_resource,
+ const bool return_output_leaf_index) {
// Read dense float features list;
OpInputList dense_float_features_list;
OP_REQUIRES_OK(context, TensorUtils::ReadDenseFloatFeatures(
@@ -267,6 +275,14 @@ class GradientTreesPredictionOp : public OpKernel {
&output_predictions_t));
auto output_predictions = output_predictions_t->matrix<float>();
+ // Allocate output leaf index matrix.
+ Tensor* output_leaf_index_t = nullptr;
+ if (return_output_leaf_index) {
+ OP_REQUIRES_OK(context, context->allocate_output(
+ kLeafIndexTensorName,
+ {batch_size, ensemble_resource->num_trees()},
+ &output_leaf_index_t));
+ }
// Run predictor.
thread::ThreadPool* const worker_threads =
context->device()->tensorflow_cpu_worker_threads()->workers;
@@ -288,11 +304,13 @@ class GradientTreesPredictionOp : public OpKernel {
i, weight * (num_ensembles - i + start_averaging) / num_ensembles);
}
MultipleAdditiveTrees::Predict(adjusted, trees_to_include, batch_features,
- worker_threads, output_predictions);
+ worker_threads, output_predictions,
+ output_leaf_index_t);
} else {
MultipleAdditiveTrees::Predict(
ensemble_resource->decision_tree_ensemble(), trees_to_include,
- batch_features, worker_threads, output_predictions);
+ batch_features, worker_threads, output_predictions,
+ output_leaf_index_t);
}
// Output dropped trees and original weights.
@@ -302,7 +320,6 @@ class GradientTreesPredictionOp : public OpKernel {
{2, static_cast<int64>(dropped_trees.size())},
&output_dropout_info_t));
auto output_dropout_info = output_dropout_info_t->matrix<float>();
-
for (int32 i = 0; i < dropped_trees.size(); ++i) {
output_dropout_info(0, i) = dropped_trees[i];
output_dropout_info(1, i) = original_weights[i];
@@ -326,6 +343,27 @@ class GradientTreesPredictionOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("GradientTreesPrediction").Device(DEVICE_CPU),
GradientTreesPredictionOp);
+// GradientTreesPredictionVerboseOp is derived from GradientTreesPredictionOp
+// and has an additional rank 2 output tensor containing, for each tree, the
+// leaf id in which each instance ended up.
+class GradientTreesPredictionVerboseOp : public GradientTreesPredictionOp {
+ public:
+ explicit GradientTreesPredictionVerboseOp(OpKernelConstruction* const context)
+ : GradientTreesPredictionOp(context) {}
+
+ protected:
+ void DoCompute(OpKernelContext* context,
+ DecisionTreeEnsembleResource* ensemble_resource,
+ bool return_output_leaf_index) override {
+ GradientTreesPredictionOp::DoCompute(context, ensemble_resource,
+ /*return_output_leaf_index=*/true);
+ }
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("GradientTreesPredictionVerbose").Device(DEVICE_CPU),
+ GradientTreesPredictionVerboseOp);
+
class GradientTreesPartitionExamplesOp : public OpKernel {
public:
explicit GradientTreesPartitionExamplesOp(OpKernelConstruction* const context)
diff --git a/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.cc b/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.cc
index 43b00d4c6d..c9223afeab 100644
--- a/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.cc
+++ b/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.cc
@@ -26,7 +26,8 @@ void MultipleAdditiveTrees::Predict(
const std::vector<int32>& trees_to_include,
const boosted_trees::utils::BatchFeatures& features,
tensorflow::thread::ThreadPool* const worker_threads,
- tensorflow::TTypes<float>::Matrix output_predictions) {
+ tensorflow::TTypes<float>::Matrix output_predictions,
+ Tensor* const output_leaf_index) {
// Zero out predictions as the model is additive.
output_predictions.setZero();
@@ -38,8 +39,13 @@ void MultipleAdditiveTrees::Predict(
// Lambda for doing a block of work.
auto update_predictions = [&config, &features, &trees_to_include,
- &output_predictions](int64 start, int64 end) {
+ &output_predictions,
+ &output_leaf_index](int64 start, int64 end) {
auto examples_iterable = features.examples_iterable(start, end);
+ Tensor dummy_tensor(DT_INT32, TensorShape({1, 1}));
+ tensorflow::TTypes<int>::Matrix output_leaf_index_mat =
+ output_leaf_index != nullptr ? output_leaf_index->matrix<int>()
+ : dummy_tensor.matrix<int>();
for (const auto& example : examples_iterable) {
for (const int32 tree_idx : trees_to_include) {
const boosted_trees::trees::DecisionTreeConfig& tree =
@@ -47,6 +53,10 @@ void MultipleAdditiveTrees::Predict(
const float tree_weight = config.tree_weights(tree_idx);
const int leaf_idx = trees::DecisionTree::Traverse(tree, 0, example);
QCHECK(leaf_idx >= 0) << "Invalid tree: " << tree.DebugString();
+        // Record the leaf index if the caller requested it.
+ if (output_leaf_index != nullptr) {
+ output_leaf_index_mat(example.example_idx, tree_idx) = leaf_idx;
+ }
const auto& leaf_node = tree.nodes(leaf_idx);
QCHECK(leaf_node.has_leaf())
<< "Invalid leaf node: " << leaf_node.DebugString();
diff --git a/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.h b/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.h
index cc3dc226cd..940531c4ba 100644
--- a/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.h
+++ b/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.h
@@ -33,12 +33,17 @@ class MultipleAdditiveTrees {
public:
// Predict runs tree ensemble on the given batch and updates
// output predictions accordingly, for the given list of trees.
+  // output_leaf_index is a pointer to a rank 2 tensor. If it is not nullptr,
+  // this method fills it with the id of the leaf, in each tree, in which each
+  // instance from 'features' ended up. The tensor's shape is
+  // [num_examples, num_trees].
static void Predict(
const boosted_trees::trees::DecisionTreeEnsembleConfig& config,
const std::vector<int32>& trees_to_include,
const boosted_trees::utils::BatchFeatures& features,
tensorflow::thread::ThreadPool* const worker_threads,
- tensorflow::TTypes<float>::Matrix output_predictions);
+ tensorflow::TTypes<float>::Matrix output_predictions,
+ Tensor* const output_leaf_index);
};
} // namespace models
diff --git a/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees_test.cc b/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees_test.cc
index 4ca18bedb1..462a9ac86f 100644
--- a/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees_test.cc
+++ b/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees_test.cc
@@ -62,7 +62,8 @@ TEST_F(MultipleAdditiveTreesTest, Empty) {
tensorflow::thread::ThreadPool threads(tensorflow::Env::Default(), "test",
kNumThreadsSingleThreaded);
MultipleAdditiveTrees::Predict(tree_ensemble_config, {}, batch_features_,
- &threads, output_matrix);
+ &threads, output_matrix,
+ /*output_leaf_index=*/nullptr);
EXPECT_EQ(0, output_matrix(0, 0));
EXPECT_EQ(0, output_matrix(1, 0));
}
@@ -99,17 +100,38 @@ TEST_F(MultipleAdditiveTreesTest, SingleClass) {
// Normal case.
{
MultipleAdditiveTrees::Predict(tree_ensemble_config, {0, 1},
- batch_features_, &threads, output_matrix);
+ batch_features_, &threads, output_matrix,
+ /*output_leaf_index=*/nullptr);
EXPECT_FLOAT_EQ(-0.2f, output_matrix(0, 0)); // -0.4 (bias) + 0.2 (leaf 2).
EXPECT_FLOAT_EQ(0.5f, output_matrix(1, 0)); // -0.4 (bias) + 0.9 (leaf 1).
}
+ // Normal case with leaf node.
+ {
+      // Allocate the output leaf index tensor. There are 2 examples and
+      // 2 trees, so the expected shape of the leaf index output is
+      // [2, 2] (num_examples x num_trees).
+ Tensor output_leaf_index_tensor(DT_INT32, TensorShape({2, 2}));
+ MultipleAdditiveTrees::Predict(tree_ensemble_config, {0, 1},
+ batch_features_, &threads, output_matrix,
+ &output_leaf_index_tensor);
+ EXPECT_FLOAT_EQ(-0.2f, output_matrix(0, 0)); // -0.4 (bias) + 0.2 (leaf 2).
+ EXPECT_FLOAT_EQ(0.5f, output_matrix(1, 0)); // -0.4 (bias) + 0.9 (leaf 1).
+ EXPECT_FLOAT_EQ(0, output_leaf_index_tensor.matrix<int>()(
+ 0, 0)); // 1st leaf for the first example
+ EXPECT_FLOAT_EQ(0, output_leaf_index_tensor.matrix<int>()(
+ 1, 0)); // 1st leaf for the second example
+ EXPECT_FLOAT_EQ(2, output_leaf_index_tensor.matrix<int>()(
+ 0, 1)); // 2nd leaf for the first example
+ EXPECT_FLOAT_EQ(1, output_leaf_index_tensor.matrix<int>()(
+ 1, 1)); // 2nd leaf for the second example
+ }
// Weighted case
{
DecisionTreeEnsembleConfig weighted = tree_ensemble_config;
weighted.set_tree_weights(0, 6.0);
weighted.set_tree_weights(1, 3.2);
MultipleAdditiveTrees::Predict(weighted, {0, 1}, batch_features_, &threads,
- output_matrix);
+ output_matrix, nullptr);
// -0.4 (bias) + 0.2 (leaf 2).
EXPECT_FLOAT_EQ(-0.4f * 6 + 0.2 * 3.2, output_matrix(0, 0));
// -0.4 (bias) + 0.9 (leaf 1).
@@ -118,21 +140,21 @@ TEST_F(MultipleAdditiveTreesTest, SingleClass) {
// Drop first tree.
{
MultipleAdditiveTrees::Predict(tree_ensemble_config, {1}, batch_features_,
- &threads, output_matrix);
+ &threads, output_matrix, nullptr);
EXPECT_FLOAT_EQ(0.2f, output_matrix(0, 0)); // 0.2 (leaf 2).
EXPECT_FLOAT_EQ(0.9f, output_matrix(1, 0)); // 0.9 (leaf 1).
}
// Drop second tree.
{
MultipleAdditiveTrees::Predict(tree_ensemble_config, {0}, batch_features_,
- &threads, output_matrix);
+ &threads, output_matrix, nullptr);
EXPECT_FLOAT_EQ(-0.4f, output_matrix(0, 0)); // -0.4 (bias).
EXPECT_FLOAT_EQ(-0.4f, output_matrix(1, 0)); // -0.4 (bias).
}
// Drop all trees.
{
MultipleAdditiveTrees::Predict(tree_ensemble_config, {}, batch_features_,
- &threads, output_matrix);
+ &threads, output_matrix, nullptr);
EXPECT_FLOAT_EQ(0.0, output_matrix(0, 0));
EXPECT_FLOAT_EQ(0.0, output_matrix(1, 0));
}
@@ -172,7 +194,8 @@ TEST_F(MultipleAdditiveTreesTest, MultiClass) {
// Normal case.
{
MultipleAdditiveTrees::Predict(tree_ensemble_config, {0, 1},
- batch_features_, &threads, output_matrix);
+ batch_features_, &threads, output_matrix,
+ nullptr);
EXPECT_FLOAT_EQ(-0.4f, output_matrix(0, 0)); // -0.4 (bias)
EXPECT_FLOAT_EQ(-0.5f, output_matrix(0, 1)); // -0.7 (bias) + 0.2 (leaf 2)
EXPECT_FLOAT_EQ(0.5f, output_matrix(1, 0)); // -0.4 (bias) + 0.9 (leaf 1)
@@ -184,7 +207,7 @@ TEST_F(MultipleAdditiveTreesTest, MultiClass) {
weighted.set_tree_weights(0, 6.0);
weighted.set_tree_weights(1, 3.2);
MultipleAdditiveTrees::Predict(weighted, {0, 1}, batch_features_, &threads,
- output_matrix);
+ output_matrix, nullptr);
// bias
EXPECT_FLOAT_EQ(-0.4f * 6, output_matrix(0, 0));
// bias + leaf 2
@@ -197,7 +220,7 @@ TEST_F(MultipleAdditiveTreesTest, MultiClass) {
// Dropout first tree.
{
MultipleAdditiveTrees::Predict(tree_ensemble_config, {1}, batch_features_,
- &threads, output_matrix);
+ &threads, output_matrix, nullptr);
EXPECT_FLOAT_EQ(0.0, output_matrix(0, 0));
EXPECT_FLOAT_EQ(0.2f, output_matrix(0, 1)); // 0.2 (leaf 2)
EXPECT_FLOAT_EQ(0.9f, output_matrix(1, 0)); // 0.9 (leaf 2)
@@ -206,7 +229,7 @@ TEST_F(MultipleAdditiveTreesTest, MultiClass) {
// Dropout second tree.
{
MultipleAdditiveTrees::Predict(tree_ensemble_config, {0}, batch_features_,
- &threads, output_matrix);
+ &threads, output_matrix, nullptr);
EXPECT_FLOAT_EQ(-0.4f, output_matrix(0, 0)); // -0.4 (bias)
EXPECT_FLOAT_EQ(-0.7f, output_matrix(0, 1)); // -0.7 (bias)
EXPECT_FLOAT_EQ(-0.4f, output_matrix(1, 0)); // -0.4 (bias)
@@ -215,7 +238,7 @@ TEST_F(MultipleAdditiveTreesTest, MultiClass) {
// Drop both trees.
{
MultipleAdditiveTrees::Predict(tree_ensemble_config, {}, batch_features_,
- &threads, output_matrix);
+ &threads, output_matrix, nullptr);
EXPECT_FLOAT_EQ(0.0f, output_matrix(0, 0));
EXPECT_FLOAT_EQ(0.0f, output_matrix(0, 1));
EXPECT_FLOAT_EQ(0.0f, output_matrix(1, 0));
@@ -258,7 +281,8 @@ TEST_F(MultipleAdditiveTreesTest, DenseLeaves) {
// Normal case.
{
MultipleAdditiveTrees::Predict(tree_ensemble_config, {0, 1},
- batch_features_, &threads, output_matrix);
+ batch_features_, &threads, output_matrix,
+ nullptr);
EXPECT_FLOAT_EQ(-0.2f, output_matrix(0, 0)); // -0.4 (tree1) + 0.2 (leaf 2)
EXPECT_FLOAT_EQ(-0.4f, output_matrix(0, 1)); // -0.7 (tree1) + 0.3 (leaf 2)
EXPECT_FLOAT_EQ(3.4f, output_matrix(0, 2)); // 3.0 -(tree1) + 0.4 (leaf 2)
diff --git a/tensorflow/contrib/boosted_trees/ops/prediction_ops.cc b/tensorflow/contrib/boosted_trees/ops/prediction_ops.cc
index d66f645f62..6491d58794 100644
--- a/tensorflow/contrib/boosted_trees/ops/prediction_ops.cc
+++ b/tensorflow/contrib/boosted_trees/ops/prediction_ops.cc
@@ -40,6 +40,24 @@ static Status ApplyGradientTreesPredictionShapeFn(InferenceContext* c) {
return Status::OK();
}
+static Status ApplyGradientTreesPredictionVerboseShapeFn(InferenceContext* c) {
+ string learner_config_str;
+ c->GetAttr("learner_config", &learner_config_str).IgnoreError();
+ LearnerConfig learner_config;
+ ParseProtoUnlimited(&learner_config, learner_config_str);
+
+ bool reduce_dim;
+ c->GetAttr("reduce_dim", &reduce_dim).IgnoreError();
+ // Sets the shape of the output as a matrix.
+ c->set_output(0, {c->Matrix(InferenceContext::kUnknownDim,
+ reduce_dim ? learner_config.num_classes() - 1
+ : learner_config.num_classes())});
+ c->set_output(1, {c->UnknownShape()});
+ c->set_output(2, {c->Matrix(InferenceContext::kUnknownDim,
+ InferenceContext::kUnknownDim)});
+ return Status::OK();
+}
+
REGISTER_OP("GradientTreesPrediction")
.Attr("learner_config: string")
.Attr("num_dense_float_features: int >= 0")
@@ -90,6 +108,58 @@ drop_out_tree_indices_weights: Tensor of Rank 2 containing dropped trees indices
and original weights of those trees during prediction.
)doc");
+REGISTER_OP("GradientTreesPredictionVerbose")
+ .Attr("learner_config: string")
+ .Attr("num_dense_float_features: int >= 0")
+ .Attr("num_sparse_float_features: int >= 0")
+ .Attr("num_sparse_int_features: int >= 0")
+ .Attr("use_locking: bool = false")
+ .Attr("apply_dropout: bool")
+ .Attr("apply_averaging: bool")
+ .Attr("center_bias: bool")
+ .Attr("reduce_dim: bool")
+ .Input("tree_ensemble_handle: resource")
+ .Input("seed: int64")
+ .Input("dense_float_features: num_dense_float_features * float")
+ .Input("sparse_float_feature_indices: num_sparse_float_features * int64")
+ .Input("sparse_float_feature_values: num_sparse_float_features * float")
+ .Input("sparse_float_feature_shapes: num_sparse_float_features * int64")
+ .Input("sparse_int_feature_indices: num_sparse_int_features * int64")
+ .Input("sparse_int_feature_values: num_sparse_int_features * int64")
+ .Input("sparse_int_feature_shapes: num_sparse_int_features * int64")
+ .Output("predictions: float")
+ .Output("drop_out_tree_indices_weights: float")
+ .Output("leaf_index: int32")
+ .SetShapeFn(ApplyGradientTreesPredictionVerboseShapeFn)
+ .Doc(R"doc(
+Runs multiple additive regression forest predictors on input instances,
+computes the final prediction for each class, and also outputs, for each
+tree in the ensemble, the leaf id that every instance reached.
+
+learner_config: Config for the learner of type LearnerConfig proto. Prediction
+ops for now uses only LearningRateDropoutDrivenConfig config from the learner.
+num_dense_float_features: Number of dense float features.
+num_sparse_float_features: Number of sparse float features.
+num_sparse_int_features: Number of sparse int features.
+use_locking: Whether to use locking.
+seed: random seed to be used for dropout.
+reduce_dim: whether to reduce the dimension (legacy impl) or not.
+apply_dropout: whether to apply dropout during prediction.
+apply_averaging: whether averaging of tree ensembles should take place. If set
+to true, will be based on AveragingConfig from learner_config.
+tree_ensemble_handle: The handle to the tree ensemble.
+dense_float_features: Rank 2 Tensors containing dense float feature values.
+sparse_float_feature_indices: Rank 2 Tensors containing sparse float indices.
+sparse_float_feature_values: Rank 1 Tensors containing sparse float values.
+sparse_float_feature_shapes: Rank 1 Tensors containing sparse float shapes.
+sparse_int_feature_indices: Rank 2 Tensors containing sparse int indices.
+sparse_int_feature_values: Rank 1 Tensors containing sparse int values.
+sparse_int_feature_shapes: Rank 1 Tensors containing sparse int shapes.
+predictions: Rank 2 Tensor containing predictions per example per class.
+drop_out_tree_indices_weights: Tensor of Rank 2 containing dropped trees indices and original weights of those trees during prediction.
+leaf_index: Rank 2 tensor containing, for each tree, the leaf id in which each instance ended up.
+)doc");
+
REGISTER_OP("GradientTreesPartitionExamples")
.Attr("num_dense_float_features: int >= 0")
.Attr("num_sparse_float_features: int >= 0")
diff --git a/tensorflow/contrib/boosted_trees/python/ops/prediction_ops.py b/tensorflow/contrib/boosted_trees/python/ops/prediction_ops.py
index 58f0d36b0f..7f6e55ae58 100644
--- a/tensorflow/contrib/boosted_trees/python/ops/prediction_ops.py
+++ b/tensorflow/contrib/boosted_trees/python/ops/prediction_ops.py
@@ -21,4 +21,5 @@ from __future__ import print_function
from tensorflow.contrib.boosted_trees.python.ops import boosted_trees_ops_loader
from tensorflow.contrib.boosted_trees.python.ops.gen_prediction_ops import gradient_trees_partition_examples
from tensorflow.contrib.boosted_trees.python.ops.gen_prediction_ops import gradient_trees_prediction
+from tensorflow.contrib.boosted_trees.python.ops.gen_prediction_ops import gradient_trees_prediction_verbose
# pylint: enable=unused-import
diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
index 5dd2e0c7f2..47698d45c8 100644
--- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
+++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
@@ -58,6 +58,7 @@ NUM_LAYERS_ATTEMPTED = "num_layers"
NUM_TREES_ATTEMPTED = "num_trees"
NUM_USED_HANDLERS = "num_used_handlers"
USED_HANDLERS_MASK = "used_handlers_mask"
+LEAF_INDEX = "leaf_index"
_FEATURE_NAME_TEMPLATE = "%s_%d"
@@ -71,18 +72,24 @@ def _get_column_by_index(tensor, indices):
return array_ops.reshape(array_ops.gather(p_flat, i_flat), [shape[0], -1])
-def _make_predictions_dict(stamp, logits, partition_ids, ensemble_stats,
- used_handlers):
+def _make_predictions_dict(stamp,
+ logits,
+ partition_ids,
+ ensemble_stats,
+ used_handlers,
+ leaf_index=None):
"""Returns predictions for the given logits and n_classes.
Args:
stamp: The ensemble stamp.
- logits: A rank 2 `Tensor` with shape [batch_size, n_classes - 1].
- that contains predictions when no dropout was applied.
+    logits: A rank 2 `Tensor` with shape [batch_size, n_classes - 1] that
+      contains predictions when no dropout was applied.
partition_ids: A rank 1 `Tensor` with shape [batch_size].
ensemble_stats: A TreeEnsembleStatsOp result tuple.
used_handlers: A TreeEnsembleUsedHandlerOp result tuple of an int and a
- boolean mask..
+ boolean mask.
+    leaf_index: A rank 2 `Tensor` with shape [batch_size, num_trees] that
+      contains the id of the leaf each example reached in every tree.
Returns:
A dict of predictions.
@@ -95,6 +102,8 @@ def _make_predictions_dict(stamp, logits, partition_ids, ensemble_stats,
result[NUM_TREES_ATTEMPTED] = ensemble_stats.attempted_trees
result[NUM_USED_HANDLERS] = used_handlers.num_used_handlers
result[USED_HANDLERS_MASK] = used_handlers.used_handlers_mask
+ if leaf_index is not None:
+ result[LEAF_INDEX] = leaf_index
return result
@@ -268,7 +277,8 @@ class GradientBoostedDecisionTreeModel(object):
features,
logits_dimension,
feature_columns=None,
- use_core_columns=False):
+ use_core_columns=False,
+ output_leaf_index=False):
"""Construct a new GradientBoostedDecisionTreeModel function.
Args:
@@ -276,13 +286,15 @@ class GradientBoostedDecisionTreeModel(object):
num_ps_replicas: Number of parameter server replicas, can be 0.
ensemble_handle: A handle to the ensemble variable.
center_bias: Whether to center the bias before growing trees.
- examples_per_layer: Number of examples to accumulate before growing
- a tree layer. It can also be a function that computes the number of
- examples based on the depth of the layer that's being built.
+ examples_per_layer: Number of examples to accumulate before growing a tree
+ layer. It can also be a function that computes the number of examples
+ based on the depth of the layer that's being built.
learner_config: A learner config.
features: `dict` of `Tensor` objects.
logits_dimension: An int, the dimension of logits.
feature_columns: A list of feature columns.
+      output_leaf_index: A boolean indicating whether to add the leaf index to
+        the predictions dictionary.
Raises:
ValueError: if inputs are not valid.
@@ -359,6 +371,7 @@ class GradientBoostedDecisionTreeModel(object):
self._learner_config.multi_class_strategy ==
learner_pb2.LearnerConfig.TREE_PER_CLASS and
learner_config.num_classes == 2)
+ self._output_leaf_index = output_leaf_index
def _predict_and_return_dict(self, ensemble_handle, ensemble_stamp, mode):
"""Runs prediction and returns a dictionary of the prediction results.
@@ -388,22 +401,44 @@ class GradientBoostedDecisionTreeModel(object):
# Make sure ensemble stats run. This will check that the ensemble has
# the right stamp.
with ops.control_dependencies(ensemble_stats):
- predictions, _ = prediction_ops.gradient_trees_prediction(
- ensemble_handle,
- seed,
- self._dense_floats,
- self._sparse_float_indices,
- self._sparse_float_values,
- self._sparse_float_shapes,
- self._sparse_int_indices,
- self._sparse_int_values,
- self._sparse_int_shapes,
- learner_config=self._learner_config_serialized,
- apply_dropout=apply_dropout,
- apply_averaging=mode != learn.ModeKeys.TRAIN,
- use_locking=True,
- center_bias=self._center_bias,
- reduce_dim=self._reduce_dim)
+ leaf_index = None
+ # Only used in infer (predict), not used in train and eval.
+ if self._output_leaf_index and mode == learn.ModeKeys.INFER:
+ predictions, _, leaf_index = (
+ prediction_ops).gradient_trees_prediction_verbose(
+ ensemble_handle,
+ seed,
+ self._dense_floats,
+ self._sparse_float_indices,
+ self._sparse_float_values,
+ self._sparse_float_shapes,
+ self._sparse_int_indices,
+ self._sparse_int_values,
+ self._sparse_int_shapes,
+ learner_config=self._learner_config_serialized,
+ apply_dropout=apply_dropout,
+ apply_averaging=mode != learn.ModeKeys.TRAIN,
+ use_locking=True,
+ center_bias=self._center_bias,
+ reduce_dim=self._reduce_dim)
+ else:
+ leaf_index = None
+ predictions, _ = prediction_ops.gradient_trees_prediction(
+ ensemble_handle,
+ seed,
+ self._dense_floats,
+ self._sparse_float_indices,
+ self._sparse_float_values,
+ self._sparse_float_shapes,
+ self._sparse_int_indices,
+ self._sparse_int_values,
+ self._sparse_int_shapes,
+ learner_config=self._learner_config_serialized,
+ apply_dropout=apply_dropout,
+ apply_averaging=mode != learn.ModeKeys.TRAIN,
+ use_locking=True,
+ center_bias=self._center_bias,
+ reduce_dim=self._reduce_dim)
partition_ids = prediction_ops.gradient_trees_partition_examples(
ensemble_handle,
self._dense_floats,
@@ -416,7 +451,7 @@ class GradientBoostedDecisionTreeModel(object):
use_locking=True)
return _make_predictions_dict(ensemble_stamp, predictions, partition_ids,
- ensemble_stats, used_handlers)
+ ensemble_stats, used_handlers, leaf_index)
def predict(self, mode):
"""Returns predictions given the features and mode.
diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py
index 289fb195db..e3d4397fad 100644
--- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py
+++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py
@@ -19,18 +19,15 @@ from __future__ import division
from __future__ import print_function
from google.protobuf import text_format
-
from tensorflow.contrib import layers
from tensorflow.contrib.boosted_trees.proto import learner_pb2
from tensorflow.contrib.boosted_trees.proto import tree_config_pb2
from tensorflow.contrib.boosted_trees.python.ops import model_ops
from tensorflow.contrib.boosted_trees.python.training.functions import gbdt_batch
from tensorflow.contrib.boosted_trees.python.utils import losses
-
-from tensorflow.python.feature_column import feature_column_lib as core_feature_column
from tensorflow.contrib.layers.python.layers import feature_column as feature_column_lib
from tensorflow.contrib.learn.python.learn.estimators import model_fn
-
+from tensorflow.python.feature_column import feature_column_lib as core_feature_column
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
@@ -782,6 +779,118 @@ class GbdtTest(test_util.TensorFlowTestCase):
[[0.25], [0.25], [0.25], [0.25]])
self.assertAllClose(predictions_dict["partition_ids"], [0, 0, 0, 0])
+ def testPredictFnWithLeafIndexAdvancedLeft(self):
+ """Tests the predict function with output leaf ids."""
+ with self.test_session() as sess:
+ # Create ensemble with one bias node.
+ ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
+ text_format.Merge(
+ """
+ trees {
+ nodes {
+ dense_float_binary_split {
+ threshold: 1.0
+ left_id: 1
+ right_id: 2
+ }
+ node_metadata {
+ gain: 0
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: 0.25
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: 0.15
+ }
+ }
+ }
+ }
+ trees {
+ nodes {
+ dense_float_binary_split {
+ threshold: 0.99
+ left_id: 1
+ right_id: 2
+ }
+ node_metadata {
+              gain: 0
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: 0.25
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: 0.23
+ }
+ }
+ }
+ }
+ tree_weights: 1.0
+ tree_weights: 1.0
+ tree_metadata {
+ num_tree_weight_updates: 1
+ num_layers_grown: 1
+ is_finalized: true
+ }
+ tree_metadata {
+ num_tree_weight_updates: 1
+ num_layers_grown: 1
+ is_finalized: true
+ }""", ensemble_config)
+ ensemble_handle = model_ops.tree_ensemble_variable(
+ stamp_token=3,
+ tree_ensemble_config=ensemble_config.SerializeToString(),
+ name="tree_ensemble")
+ resources.initialize_resources(resources.shared_resources()).run()
+ learner_config = learner_pb2.LearnerConfig()
+ learner_config.learning_rate_tuner.fixed.learning_rate = 0.1
+ learner_config.num_classes = 2
+ learner_config.regularization.l1 = 0
+ learner_config.regularization.l2 = 0
+ learner_config.constraints.max_tree_depth = 1
+ learner_config.constraints.min_node_weight = 0
+ features = {}
+ features["dense_float"] = array_ops.constant(
+ [[0.0], [1.0], [1.1], [2.0]], dtype=dtypes.float32)
+ gbdt_model = gbdt_batch.GradientBoostedDecisionTreeModel(
+ is_chief=False,
+ num_ps_replicas=0,
+ center_bias=True,
+ ensemble_handle=ensemble_handle,
+ examples_per_layer=1,
+ learner_config=learner_config,
+ logits_dimension=1,
+ features=features,
+ output_leaf_index=True)
+
+ # Create predict op.
+ mode = model_fn.ModeKeys.INFER
+ predictions_dict = sess.run(gbdt_model.predict(mode))
+ self.assertEquals(predictions_dict["ensemble_stamp"], 3)
+      # Here is how the expected prediction values are computed:
+ # 0.5 = 0.25 + 0.25
+ # 0.48 = 0.25 + 0.23
+ # 0.38 = 0.15 + 0.23
+ # 0.38 = 0.15 + 0.23
+ self.assertAllClose(predictions_dict["predictions"],
+ [[0.5], [0.48], [0.38], [0.38]])
+ self.assertAllClose(predictions_dict["partition_ids"], [0, 0, 0, 0])
+ self.assertAllClose(predictions_dict["leaf_index"],
+ [[1, 1], [1, 2], [2, 2], [2, 2]])
+
def testTrainFnMulticlassFullHessian(self):
"""Tests the GBDT train for multiclass full hessian."""
with self.test_session() as sess:
diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
index a5a9630a4a..3a1d90e77d 100644
--- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
+++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
@@ -256,6 +256,10 @@ class TPUClusterResolver(ClusterResolver):
request = self._service.projects().locations().nodes().get(name=full_name)
response = request.execute()
+ if 'state' in response and response['state'] != 'READY':
+ raise RuntimeError('TPU "%s" is not yet ready; state: "%s"' %
+ (self._tpu, response['state']))
+
if 'health' in response and response['health'] != 'HEALTHY':
raise RuntimeError('TPU "%s" is unhealthy: "%s"' % (self._tpu,
response['health']))
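
The added check means cluster_spec() now fails fast for a Cloud TPU that exists but is still provisioning. A hedged usage sketch; the TPU name and retry policy are illustrative and not part of this change, and the constructor normally needs project/zone or metadata-server access:

    import time
    from tensorflow.contrib.cluster_resolver.python.training.tpu_cluster_resolver import TPUClusterResolver

    resolver = TPUClusterResolver(tpu='test-tpu-1')
    while True:
      try:
        cluster_spec = resolver.cluster_spec()
        break
      except RuntimeError as err:
        # Raised when the node reports a state other than READY (for example
        # CREATING) or an unhealthy status, per the checks above.
        print('TPU not usable yet, retrying: %s' % err)
        time.sleep(30)
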
diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py
index 5fac55fd02..86e9d9ddad 100644
--- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py
+++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py
@@ -158,6 +158,50 @@ class TPUClusterResolverTest(test.TestCase):
"""
self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)
+ @mock.patch.object(TPUClusterResolver, '_requestComputeMetadata',
+ mock_request_compute_metadata)
+ def testUnhealthyCloudTpu(self):
+ tpu_map = {
+ 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
+ 'ipAddress': '10.1.2.3',
+ 'port': '8470',
+ 'health': 'UNHEALTHY'
+ }
+ }
+
+ tpu_cluster_resolver = TPUClusterResolver(
+ project=None,
+ zone=None,
+ tpu='test-tpu-1',
+ coordinator_name=None,
+ credentials=None,
+ service=self.mock_service_client(tpu_map=tpu_map))
+
+ with self.assertRaises(RuntimeError):
+ tpu_cluster_resolver.cluster_spec()
+
+ @mock.patch.object(TPUClusterResolver, '_requestComputeMetadata',
+ mock_request_compute_metadata)
+ def testNotReadyCloudTpu(self):
+ tpu_map = {
+ 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
+ 'ipAddress': '10.1.2.3',
+ 'port': '8470',
+ 'state': 'CREATING'
+ }
+ }
+
+ tpu_cluster_resolver = TPUClusterResolver(
+ project=None,
+ zone=None,
+ tpu='test-tpu-1',
+ coordinator_name=None,
+ credentials=None,
+ service=self.mock_service_client(tpu_map=tpu_map))
+
+ with self.assertRaises(RuntimeError):
+ tpu_cluster_resolver.cluster_spec()
+
def testSimpleSuccessfulRetrieval(self):
tpu_map = {
'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake
index 1959ad028a..9244604489 100755
--- a/tensorflow/contrib/cmake/tf_python.cmake
+++ b/tensorflow/contrib/cmake/tf_python.cmake
@@ -756,6 +756,8 @@ add_custom_command(
"${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/tools/api/generator/create_python_api.py"
"--root_init_template=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/api_template.__init__.py"
"--apidir=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow"
+ "--package=tensorflow.python"
+ "--apiname=tensorflow"
"${api_init_list_file}"
COMMENT "Generating __init__.py files for Python API."
@@ -765,7 +767,49 @@ add_custom_command(
add_custom_target(tf_python_api SOURCES ${api_init_files})
add_dependencies(tf_python_api tf_python_ops)
+# TODO(mikecase): This can be removed once tf.estimator is moved
+# out of TensorFlow.
+########################################################
+# Generate API __init__.py files for tf.estimator.
+########################################################
+
+# Parse tensorflow/tools/api/generator/api_gen.bzl to get the list of generated files.
+FILE(READ ${tensorflow_source_dir}/tensorflow/tools/api/generator/api_gen.bzl api_generator_BUILD_text)
+STRING(REGEX MATCH "# BEGIN GENERATED ESTIMATOR FILES.*# END GENERATED ESTIMATOR FILES" api_init_files_text ${api_generator_BUILD_text})
+string(REPLACE "# BEGIN GENERATED ESTIMATOR FILES" "" api_init_files_text ${api_init_files_text})
+string(REPLACE "# END GENERATED ESTIMATOR FILES" "" api_init_files_text ${api_init_files_text})
+string(REPLACE "," ";" api_init_files_list ${api_init_files_text})
+
+set(api_init_files "")
+foreach(api_init_file ${api_init_files_list})
+ string(STRIP "${api_init_file}" api_init_file)
+ if(api_init_file)
+ string(REPLACE "\"" "" api_init_file "${api_init_file}") # Remove quotes
+ list(APPEND api_init_files "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/estimator/api/${api_init_file}")
+ endif()
+endforeach(api_init_file)
+set(estimator_api_init_list_file "${tensorflow_source_dir}/estimator_api_init_files_list.txt")
+file(WRITE "${estimator_api_init_list_file}" "${api_init_files}")
+
+# Run create_python_api.py to generate __init__.py files.
+add_custom_command(
+ OUTPUT ${api_init_files}
+ DEPENDS tf_python_ops tf_python_copy_scripts_to_destination pywrap_tensorflow_internal tf_python_touchup_modules tf_extension_ops
+
+ # Run create_python_api.py to generate API init files.
+ COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}/tf_python ${PYTHON_EXECUTABLE}
+ "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/tools/api/generator/create_python_api.py"
+ "--apidir=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/estimator/api"
+ "--package=tensorflow.python.estimator"
+ "--apiname=estimator"
+ "${estimator_api_init_list_file}"
+
+  COMMENT "Generating __init__.py files for the estimator Python API."
+ WORKING_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/tf_python"
+)
+add_custom_target(estimator_python_api SOURCES ${api_init_files})
+add_dependencies(estimator_python_api tf_python_ops)
############################################################
# Build a PIP package containing the TensorFlow runtime.
############################################################
@@ -776,6 +820,7 @@ add_dependencies(tf_python_build_pip_package
tf_python_touchup_modules
tf_python_ops
tf_python_api
+ estimator_python_api
tf_extension_ops)
# Fix-up Python files that were not included by the add_python_module() macros.
diff --git a/tensorflow/contrib/control_flow/BUILD b/tensorflow/contrib/control_flow/BUILD
index 746b5b5b5e..e8036d63ae 100644
--- a/tensorflow/contrib/control_flow/BUILD
+++ b/tensorflow/contrib/control_flow/BUILD
@@ -20,13 +20,16 @@ py_library(
srcs = ["python/cond_v2.py"],
srcs_version = "PY2AND3",
deps = [
+ "//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:c_api_util",
"//tensorflow/python:framework_ops",
"//tensorflow/python:function",
+ "//tensorflow/python:function_def_to_graph",
"//tensorflow/python:functional_ops_gen",
"//tensorflow/python:gradients",
"//tensorflow/python:pywrap_tensorflow",
+ "//tensorflow/python:util",
],
)
@@ -42,7 +45,9 @@ tf_py_test(
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework",
+ "//tensorflow/python:framework_ops",
"//tensorflow/python:gradients",
+ "//tensorflow/python:training",
],
grpc_enabled = True,
)
diff --git a/tensorflow/contrib/control_flow/python/cond_v2.py b/tensorflow/contrib/control_flow/python/cond_v2.py
index 90c678d0f6..9ffad9caa9 100644
--- a/tensorflow/contrib/control_flow/python/cond_v2.py
+++ b/tensorflow/contrib/control_flow/python/cond_v2.py
@@ -26,10 +26,12 @@ from __future__ import print_function
from tensorflow.python import pywrap_tensorflow as c_api
from tensorflow.python.framework import c_api_util
from tensorflow.python.framework import function
+from tensorflow.python.framework import function_def_to_graph
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_functional_ops
from tensorflow.python.ops import gradients_impl
+from tensorflow.python.util import compat
# NOTE(skyewm): TensorFlow uses protected class methods and fields to signify
@@ -78,28 +80,21 @@ def cond_v2(pred, true_fn, false_fn, name="cond"):
_create_new_tf_function(false_graph),
name=scope)
- # TODO(b/79883549): if we could make Graphs from FunctionDefs, we wouldn't
- # need this extra state. Requiring extra state also prevents the ability to
- # take the gradient of deserialized If ops.
- tensors[0].op._true_graph = true_graph
- tensors[0].op._false_graph = false_graph
-
return tensors[:num_cond_outputs]
@ops.RegisterGradient("If")
def _IfGrad(op, *grads): # pylint: disable=invalid-name
"""The gradient of an If op produced by cond_v2."""
- true_graph = op._true_graph
- false_graph = op._false_graph
+ true_graph, false_graph = _get_func_graphs(op)
# Create grad functions that compute the gradient of the true/false forward
# graphs. These functions will capture tensors from the forward pass
# functions.
true_grad_graph = _create_grad_func(
- true_graph, grads, "%sgrad" % true_graph.name)
+ true_graph, grads, _get_grad_fn_name(true_graph))
false_grad_graph = _create_grad_func(
- false_graph, grads, "%sgrad" % false_graph.name)
+ false_graph, grads, _get_grad_fn_name(false_graph))
assert ([t.dtype for t in true_grad_graph.outputs] ==
[t.dtype for t in false_grad_graph.outputs])
@@ -136,13 +131,35 @@ def _IfGrad(op, *grads): # pylint: disable=invalid-name
op.inputs[0], grad_inputs, [t.dtype for t in true_grad_graph.outputs],
_create_new_tf_function(true_grad_graph),
_create_new_tf_function(false_grad_graph))
- tensors[0].op._true_graph = true_grad_graph
- tensors[0].op._false_graph = false_grad_graph
# The predicate has no gradient.
return [None] + tensors[:num_grad_outputs]
+def _get_func_graphs(if_op):
+ """Returns `_FuncGraph`s for the input op branches.
+
+ Args:
+ if_op: The _If Operation.
+
+ Returns:
+ A 2-tuple of the `_FuncGraph`s of the then_branch and else_branch.
+ """
+ def _get_func_graph_for_branch(branch_name):
+ extra_inputs = if_op.inputs[1:] # First input is pred.
+ input_shapes = [t.shape for t in extra_inputs]
+ func_name = if_op.get_attr(branch_name).name
+ fdef = if_op.graph._get_function(func_name).definition
+ func_graph = function_def_to_graph.function_def_to_graph(fdef, input_shapes)
+ func_graph.extra_inputs = extra_inputs
+ func_graph.extra_args = func_graph.inputs
+ func_graph._captured = dict(zip(extra_inputs, func_graph.inputs))
+ return func_graph
+
+ return (_get_func_graph_for_branch("then_branch"),
+ _get_func_graph_for_branch("else_branch"))
+
+
def _grad_fn(func_graph, grads):
"""The gradient function for each conditional branch.
@@ -242,10 +259,9 @@ def _create_new_tf_function(func_graph):
Returns:
The name of the new TF_Function.
"""
- func_graph.name = "%s_" % func_graph.name
c_func = c_api.TF_GraphToFunction_wrapper(
func_graph._c_graph,
- func_graph.name,
+ compat.as_str(func_graph.name),
False, # append_hash_to_fn_name
None, # opers
[t._as_tf_output() for t in func_graph.inputs],
@@ -253,9 +269,15 @@ def _create_new_tf_function(func_graph):
[],
None, # opts
None) # description
- c_func = c_api_util.ScopedTFFunction(c_func)
- c_api.TF_GraphCopyFunction(
- ops.get_default_graph()._c_graph, c_func.func, None)
+ _ = c_api_util.ScopedTFFunction(c_func)
+
+ # TODO(b/109833212): this sucks, we're serializing the TF_Function*,
+ # deserializing it into a Python FunctionDef, then reserializing it to create
+ # a new TF_Function that we add to the graph.
+ fdef = function.function_def_from_tf_function(c_func)
+ defined_func = function._from_definition(fdef)
+ defined_func.add_to_graph(ops.get_default_graph())
+
return func_graph.name
@@ -381,6 +403,19 @@ def _create_dummy_params(func_graph, template_tensors):
for t in template_tensors]
+def _get_grad_fn_name(func_graph):
+ """Returns a unique name to use for the grad function of `func_graph`."""
+ name = "%s_grad" % func_graph.name
+
+ base_name = name
+ counter = 1
+  while ops.get_default_graph()._is_function(name):
+ name = "%s_%s" % (base_name, counter)
+ counter += 1
+
+ return name
+
+
def _check_same_outputs(true_graph, false_graph):
"""Raises an error if true_graph and false_graph have different outputs."""
true_output_types = [t.dtype for t in true_graph.outputs]
diff --git a/tensorflow/contrib/control_flow/python/cond_v2_test.py b/tensorflow/contrib/control_flow/python/cond_v2_test.py
index 166002ca7f..338601aa2c 100644
--- a/tensorflow/contrib/control_flow/python/cond_v2_test.py
+++ b/tensorflow/contrib/control_flow/python/cond_v2_test.py
@@ -22,11 +22,13 @@ from __future__ import print_function
from tensorflow.contrib.control_flow.python import cond_v2
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
+from tensorflow.python.training import saver
class NewCondTest(test.TestCase):
@@ -79,8 +81,22 @@ class NewCondTest(test.TestCase):
self._testCond(true_fn, false_fn, [x, y])
self._testCond(true_fn, false_fn, [y])
+ def testNoInputs(self):
+ pred = array_ops.placeholder(dtypes.bool, name="pred")
+
+ def true_fn():
+ return constant_op.constant(1.0)
+
+ def false_fn():
+ return constant_op.constant(2.0)
+
+ out = cond_v2.cond_v2(pred, true_fn, false_fn)
+
+ with self.test_session() as sess:
+ self.assertEqual(sess.run(out, {pred: True}), [1.0])
+ self.assertEqual(sess.run(out, {pred: False}), [2.0])
+
def testSecondDerivative(self):
- self.skipTest("b/109758172")
pred = array_ops.placeholder(dtypes.bool, name="pred")
x = constant_op.constant(3.0, name="x")
@@ -109,6 +125,47 @@ class NewCondTest(test.TestCase):
# d2[x]/dx2 = 0
self.assertEqual(false_val, [0.0])
+ def testGradientOfDeserializedCond(self):
+ with ops.Graph().as_default():
+ pred = array_ops.placeholder(dtypes.bool, name="pred")
+ x = constant_op.constant(3.0, name="x")
+ ops.add_to_collection("x", x)
+
+ def true_fn():
+ return math_ops.pow(x, 3)
+
+ def false_fn():
+ return x
+
+ ops.add_to_collection("pred", pred)
+ cond = cond_v2.cond_v2(pred, true_fn, false_fn, name="cond")
+ for c in cond:
+ ops.add_to_collection("cond", c)
+ meta_graph = saver.export_meta_graph()
+
+ with ops.Graph().as_default() as g:
+ saver.import_meta_graph(meta_graph)
+ x = ops.get_collection("x")[0]
+ pred = ops.get_collection("pred")[0]
+ cond = ops.get_collection("cond")
+ cond_grad = gradients_impl.gradients(cond, [x], name="cond_grad")
+ cond_grad_grad = gradients_impl.gradients(
+ cond_grad, [x], name="cond_grad_grad")
+ with self.test_session(graph=g) as sess:
+ # d[x^3]/dx = 3x^2
+ true_val = sess.run(cond_grad, {pred: True})
+ self.assertEqual(true_val, [27.0])
+ # d[x]/dx = 1
+ false_val = sess.run(cond_grad, {pred: False})
+ self.assertEqual(false_val, [1.0])
+
+ true_val = sess.run(cond_grad_grad, {pred: True})
+ # d2[x^3]/dx2 = 6x
+ self.assertEqual(true_val, [18.0])
+ false_val = sess.run(cond_grad_grad, {pred: False})
+ # d2[x]/dx2 = 0
+ self.assertEqual(false_val, [0.0])
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/data/kernels/csv_dataset_op.cc b/tensorflow/contrib/data/kernels/csv_dataset_op.cc
index e88ad3dc32..4657807785 100644
--- a/tensorflow/contrib/data/kernels/csv_dataset_op.cc
+++ b/tensorflow/contrib/data/kernels/csv_dataset_op.cc
@@ -236,7 +236,7 @@ class CSVDatasetOp : public DatasetOpKernel {
size_t num_parsed = 0;
size_t num_selected_parsed = 0;
- Status result = Status::OK();
+ Status result;
while (!end_of_record) { // Read till we reach \n, \r or EOF
bool include =
@@ -329,6 +329,7 @@ class CSVDatasetOp : public DatasetOpKernel {
size_t start = pos_;
pos_++; // Starting quotation mark
+ Status parse_result;
while (true) { // Each iter reads 1 char, filling buffer if necessary
if (pos_ >= buffer_.size()) {
Status s = SaveAndFillBuffer(&earlier_pieces, &start, include);
@@ -351,8 +352,9 @@ class CSVDatasetOp : public DatasetOpKernel {
if (errors::IsOutOfRange(s)) {
// This was the last field. We are done
*end_of_record = true;
- return QuotedFieldToOutput(ctx, StringPiece(), out_tensors,
- earlier_pieces, include);
+ parse_result.Update(QuotedFieldToOutput(
+ ctx, StringPiece(), out_tensors, earlier_pieces, include));
+ return parse_result;
} else if (!s.ok()) {
return s;
}
@@ -361,20 +363,24 @@ class CSVDatasetOp : public DatasetOpKernel {
char next = buffer_[pos_];
pos_++;
if (next == dataset()->delim_) {
- return QuotedFieldToOutput(
+ parse_result.Update(QuotedFieldToOutput(
ctx, StringPiece(&buffer_[start], pos_ - 1 - start),
- out_tensors, earlier_pieces, include);
+ out_tensors, earlier_pieces, include));
+ return parse_result;
} else if (next == '\n' || next == '\r') {
*end_of_record = true;
- Status s = QuotedFieldToOutput(
+ parse_result.Update(QuotedFieldToOutput(
ctx, StringPiece(&buffer_[start], pos_ - 1 - start),
- out_tensors, earlier_pieces, include);
+ out_tensors, earlier_pieces, include));
if (next == '\r') SkipNewLineIfNecessary();
- return s;
+ return parse_result;
} else if (next != '"') {
- return errors::InvalidArgument(
- "Quote inside a string has to be escaped by another quote");
+ // Take note of the error, but keep going to end of field.
+ include = false; // So we don't get funky errors when trying to
+ // unescape the quotes.
+ parse_result.Update(errors::InvalidArgument(
+ "Quote inside a string has to be escaped by another quote"));
}
} else {
@@ -454,6 +460,8 @@ class CSVDatasetOp : public DatasetOpKernel {
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
std::vector<Piece> earlier_pieces;
size_t start = pos_;
+ Status parse_result;
+
while (true) { // Each iter reads 1 char, filling buffer if necessary
if (pos_ >= buffer_.size()) {
Status s = SaveAndFillBuffer(&earlier_pieces, &start, include);
@@ -461,9 +469,10 @@ class CSVDatasetOp : public DatasetOpKernel {
if (errors::IsOutOfRange(s)) {
// Whatever we have is the last field of the last record
*end_of_record = true;
- return UnquotedFieldToOutput(
+ parse_result.Update(UnquotedFieldToOutput(
ctx, StringPiece(&buffer_[start], pos_ - start), out_tensors,
- earlier_pieces, include);
+ earlier_pieces, include));
+ return parse_result;
} else if (!s.ok()) {
return s; // Surface all other errors to caller
}
@@ -472,66 +481,33 @@ class CSVDatasetOp : public DatasetOpKernel {
char ch = buffer_[pos_];
if (ch == dataset()->delim_) {
- Status s = UnquotedFieldToOutput(
+ parse_result.Update(UnquotedFieldToOutput(
ctx, StringPiece(&buffer_[start], pos_ - start), out_tensors,
- earlier_pieces, include);
+ earlier_pieces, include));
pos_++;
- return s;
+ return parse_result;
}
if (ch == '\n' || ch == '\r') {
// need special case to skip over first \n of record if the line
// breaks are \r\n
- Status s = UnquotedFieldToOutput(
+ parse_result.Update(UnquotedFieldToOutput(
ctx, StringPiece(&buffer_[start], pos_ - start), out_tensors,
- earlier_pieces, include);
+ earlier_pieces, include));
*end_of_record = true;
pos_++;
if (ch == '\r') SkipNewLineIfNecessary();
- return s;
+ return parse_result;
}
if (dataset()->use_quote_delim_ && ch == '"') {
- // Advance pos_ to the next field anyway so that we can ignore
- // errors gracefully if required. The caller of this will be able to
- // call ParseOneField and continue with the rest of the record.
- AdvanceToNextField(end_of_record);
- return errors::InvalidArgument(
- "Unquoted fields cannot have quotes inside");
+ // Take note of the error, but keep going to end of field.
+ parse_result.Update(errors::InvalidArgument(
+ "Unquoted fields cannot have quotes inside"));
}
// Otherwise, go to next character
pos_++;
}
}
- // Advances pos_ to the start of the next field, as delimited by delim,
- // CRLF, or EOF, ignoring errors, and not keeping track of characters in
- // the current field.
- void AdvanceToNextField(bool* end_of_record)
- EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- while (true) {
- if (pos_ >= buffer_.size()) {
- Status s = FillBuffer(&buffer_);
- pos_ = 0;
- if (!s.ok()) {
- *end_of_record = true;
- return;
- }
- }
-
- char ch = buffer_[pos_];
- pos_++;
-
- if (ch == dataset()->delim_) {
- return;
- }
-
- if (ch == '\n' || ch == '\r') {
- *end_of_record = true;
- if (ch == '\r') SkipNewLineIfNecessary();
- return;
- }
- }
- }
-
Status FillBuffer(string* result) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
result->clear();
Status s = input_stream_->ReadNBytes(dataset()->buffer_size_, result);
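
Because a quoting error is now recorded in parse_result while the parser still advances to the end of the field and record, a pipeline can combine CsvDataset with ignore_errors() to drop only the malformed lines, as the new tests below exercise. A sketch, with `filenames` assumed to point at existing CSV files:

    from tensorflow.contrib.data.python.ops import error_ops
    from tensorflow.contrib.data.python.ops import readers

    record_defaults = [['']] * 3
    dataset = readers.CsvDataset(filenames, record_defaults=record_defaults)
    # Rows such as 'a,b,"c"d"' now surface InvalidArgument only after the whole
    # record has been consumed, so ignore_errors() can skip them cleanly.
    dataset = dataset.apply(error_ops.ignore_errors())
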
diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD
index ba707d8d6e..be834d7dfd 100644
--- a/tensorflow/contrib/data/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/BUILD
@@ -330,6 +330,26 @@ py_test(
],
)
+py_library(
+ name = "reader_dataset_ops_test_base",
+ testonly = 1,
+ srcs = [
+ "reader_dataset_ops_test_base.py",
+ ],
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:private"],
+ deps = [
+ "//tensorflow/contrib/data/python/ops:readers",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:lib",
+ "//tensorflow/python:parsing_ops",
+ "//tensorflow/python:util",
+ "//tensorflow/python/data/ops:readers",
+ ],
+)
+
py_test(
name = "reader_dataset_ops_test",
size = "medium",
@@ -339,8 +359,8 @@ py_test(
tags = ["no_pip"],
deps = [
":dataset_serialization_test",
+ ":reader_dataset_ops_test_base",
"//tensorflow/contrib/data/python/ops:readers",
- "//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
@@ -352,6 +372,7 @@ py_test(
"//tensorflow/python:string_ops",
"//tensorflow/python:util",
"//tensorflow/python/data/ops:iterator_ops",
+ "//tensorflow/python/data/ops:readers",
"//third_party/py/numpy",
],
)
@@ -441,6 +462,7 @@ py_test(
tags = ["no_pip"],
deps = [
":dataset_serialization_test",
+ "//tensorflow/contrib/data/python/ops:iterator_ops",
"//tensorflow/contrib/data/python/ops:shuffle_ops",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
@@ -448,6 +470,7 @@ py_test(
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:framework_ops",
+ "//tensorflow/python:training",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/ops:iterator_ops",
"//third_party/py/numpy",
@@ -478,10 +501,15 @@ py_test(
tags = ["no_pip"],
deps = [
":dataset_serialization_test",
+ ":reader_dataset_ops_test_base",
"//tensorflow/contrib/data/python/ops:stats_ops",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
"//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
],
)
diff --git a/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py
index 74b90ec7d1..97b5e94165 100644
--- a/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py
@@ -162,9 +162,28 @@ class CsvDatasetOpTest(test.TestCase):
expected_err_re='Unquoted fields cannot have quotes inside',
record_defaults=record_defaults)
+ def testCsvDataset_errWithUnescapedQuotes(self):
+ record_defaults = [['']] * 3
+ inputs = [['"a"b","c","d"']]
+ self._test_dataset(
+ inputs,
+ expected_err_re=
+ 'Quote inside a string has to be escaped by another quote',
+ record_defaults=record_defaults)
+
+ def testCsvDataset_ignoreErrWithUnescapedQuotes(self):
+ record_defaults = [['']] * 3
+ inputs = [['1,"2"3",4', '1,"2"3",4",5,5', 'a,b,"c"d"', 'e,f,g']]
+ filenames = self.setup_files(inputs)
+ with ops.Graph().as_default() as g:
+ with self.test_session(graph=g) as sess:
+ dataset = readers.CsvDataset(filenames, record_defaults=record_defaults)
+ dataset = dataset.apply(error_ops.ignore_errors())
+ self._verify_output_or_err(sess, dataset, [['e', 'f', 'g']])
+
def testCsvDataset_ignoreErrWithUnquotedQuotes(self):
record_defaults = [['']] * 3
- inputs = [['1,2"3,4', 'a,b,c"d', 'e,f,g']]
+ inputs = [['1,2"3,4', 'a,b,c"d', '9,8"7,6,5', 'e,f,g']]
filenames = self.setup_files(inputs)
with ops.Graph().as_default() as g:
with self.test_session(graph=g) as sess:
diff --git a/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py b/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py
index 78ecce8f7d..393f08850b 100644
--- a/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py
+++ b/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py
@@ -467,7 +467,8 @@ class DatasetSerializationTestBase(test.TestCase):
ckpt_saved=False,
init_before_restore=False,
sparse_tensors=False,
- verify_exhausted=True):
+ verify_exhausted=True,
+ save_checkpoint_at_end=True):
"""Generates elements from input dataset while stopping at break points.
Produces `num_outputs` outputs and saves the state of the iterator in the
@@ -490,6 +491,10 @@ class DatasetSerializationTestBase(test.TestCase):
sparse_tensors: Whether dataset is built from SparseTensor(s).
verify_exhausted: Whether to verify that the iterator has been exhausted
after producing `num_outputs` elements.
+      save_checkpoint_at_end: Whether to save a checkpoint after producing all
+        outputs. If False, checkpoints are saved at each break point but not at
+        the end. Note that checkpoints overwrite each other, so only a single
+        checkpoint is ever available. Defaults to True.
Returns:
A list of `num_outputs` items.
@@ -526,8 +531,9 @@ class DatasetSerializationTestBase(test.TestCase):
if i == len(break_points) and verify_exhausted:
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next_op)
- self._save(sess, saver)
- ckpt_saved = True
+ if save_checkpoint_at_end or i < len(break_points):
+ self._save(sess, saver)
+ ckpt_saved = True
return outputs
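
The new save_checkpoint_at_end flag lets a serialization test restore only from a break-point checkpoint rather than from the final state. A hedged sketch, assuming the helper shown above is the test base's gen_outputs method (its name sits outside this hunk) and that ds_fn builds the dataset under test:

    outputs = self.gen_outputs(
        ds_fn,
        break_points=[5, 10],
        num_outputs=20,
        # Save at the break points only; the final iterator state is not
        # checkpointed.
        save_checkpoint_at_end=False)
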
diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py
index e0237198b7..3b07ef290b 100644
--- a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py
@@ -24,9 +24,8 @@ import zlib
import numpy as np
from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base
+from tensorflow.contrib.data.python.kernel_tests import reader_dataset_ops_test_base
from tensorflow.contrib.data.python.ops import readers
-from tensorflow.core.example import example_pb2
-from tensorflow.core.example import feature_pb2
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.ops import readers as core_readers
from tensorflow.python.framework import constant_op
@@ -280,163 +279,8 @@ def _interleave(iterators, cycle_length):
num_open -= 1
-class ReadBatchFeaturesTest(test.TestCase):
-
- def setUp(self):
- super(ReadBatchFeaturesTest, self).setUp()
- self._num_files = 2
- self._num_records = 7
- self.test_filenames = self._createFiles()
-
- def _read_batch_features(self,
- filenames,
- num_epochs,
- batch_size,
- reader_num_threads=1,
- parser_num_threads=1,
- shuffle=False,
- shuffle_seed=None,
- drop_final_batch=False):
- self.filenames = filenames
- self.num_epochs = num_epochs
- self.batch_size = batch_size
-
- return readers.make_batched_features_dataset(
- file_pattern=self.filenames,
- batch_size=self.batch_size,
- features={
- "file": parsing_ops.FixedLenFeature([], dtypes.int64),
- "record": parsing_ops.FixedLenFeature([], dtypes.int64),
- "keywords": parsing_ops.VarLenFeature(dtypes.string)
- },
- reader=core_readers.TFRecordDataset,
- num_epochs=self.num_epochs,
- shuffle=shuffle,
- shuffle_seed=shuffle_seed,
- reader_num_threads=reader_num_threads,
- parser_num_threads=parser_num_threads,
- drop_final_batch=drop_final_batch).make_one_shot_iterator(
- ).get_next()
-
- def _record(self, f, r):
- example = example_pb2.Example(
- features=feature_pb2.Features(
- feature={
- "file":
- feature_pb2.Feature(
- int64_list=feature_pb2.Int64List(value=[f])),
- "record":
- feature_pb2.Feature(
- int64_list=feature_pb2.Int64List(value=[r])),
- "keywords":
- feature_pb2.Feature(
- bytes_list=feature_pb2.BytesList(
- value=self._get_keywords(f, r)))
- }))
- return example.SerializeToString()
-
- def _get_keywords(self, f, r):
- num_keywords = 1 + (f + r) % 2
- keywords = []
- for index in range(num_keywords):
- keywords.append(compat.as_bytes("keyword%d" % index))
- return keywords
-
- def _createFiles(self):
- filenames = []
- for i in range(self._num_files):
- fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i)
- filenames.append(fn)
- writer = python_io.TFRecordWriter(fn)
- for j in range(self._num_records):
- writer.write(self._record(i, j))
- writer.close()
- return filenames
-
- def _run_actual_batch(self, outputs, sess):
- file_op = outputs["file"]
- keywords_indices_op = outputs["keywords"].indices
- keywords_values_op = outputs["keywords"].values
- keywords_dense_shape_op = outputs["keywords"].dense_shape
- record_op = outputs["record"]
- return sess.run([
- file_op, keywords_indices_op, keywords_values_op,
- keywords_dense_shape_op, record_op
- ])
-
- def _next_actual_batch(self, sess):
- return self._run_actual_batch(self.outputs, sess)
-
- def _next_expected_batch(self,
- file_indices,
- batch_size,
- num_epochs,
- cycle_length=1):
-
- def _next_record(file_indices):
- for j in file_indices:
- for i in range(self._num_records):
- yield j, i
-
- def _next_record_interleaved(file_indices, cycle_length):
- return _interleave([_next_record([i]) for i in file_indices],
- cycle_length)
-
- file_batch = []
- keywords_batch_indices = []
- keywords_batch_values = []
- keywords_batch_max_len = 0
- record_batch = []
- batch_index = 0
- for _ in range(num_epochs):
- if cycle_length == 1:
- next_records = _next_record(file_indices)
- else:
- next_records = _next_record_interleaved(file_indices, cycle_length)
- for record in next_records:
- f = record[0]
- r = record[1]
- file_batch.append(f)
- record_batch.append(r)
- keywords = self._get_keywords(f, r)
- keywords_batch_values.extend(keywords)
- keywords_batch_indices.extend(
- [[batch_index, i] for i in range(len(keywords))])
- batch_index += 1
- keywords_batch_max_len = max(keywords_batch_max_len, len(keywords))
- if len(file_batch) == batch_size:
- yield [
- file_batch, keywords_batch_indices, keywords_batch_values,
- [batch_size, keywords_batch_max_len], record_batch
- ]
- file_batch = []
- keywords_batch_indices = []
- keywords_batch_values = []
- keywords_batch_max_len = 0
- record_batch = []
- batch_index = 0
- if file_batch:
- yield [
- file_batch, keywords_batch_indices, keywords_batch_values,
- [len(file_batch), keywords_batch_max_len], record_batch
- ]
-
- def _verify_records(self,
- sess,
- batch_size,
- file_index=None,
- num_epochs=1,
- interleave_cycle_length=1):
- if file_index is not None:
- file_indices = [file_index]
- else:
- file_indices = range(self._num_files)
-
- for expected_batch in self._next_expected_batch(
- file_indices, batch_size, num_epochs, interleave_cycle_length):
- actual_batch = self._next_actual_batch(sess)
- for i in range(len(expected_batch)):
- self.assertAllEqual(expected_batch[i], actual_batch[i])
+class ReadBatchFeaturesTest(
+ reader_dataset_ops_test_base.ReadBatchFeaturesTestBase):
def testRead(self):
for batch_size in [1, 2]:
@@ -444,33 +288,33 @@ class ReadBatchFeaturesTest(test.TestCase):
with ops.Graph().as_default() as g:
with self.test_session(graph=g) as sess:
# Basic test: read from file 0.
- self.outputs = self._read_batch_features(
+ self.outputs = self.make_batch_feature(
filenames=self.test_filenames[0],
num_epochs=num_epochs,
- batch_size=batch_size)
- self._verify_records(sess, batch_size, 0, num_epochs=num_epochs)
+ batch_size=batch_size).make_one_shot_iterator().get_next()
+ self.verify_records(sess, batch_size, 0, num_epochs=num_epochs)
with self.assertRaises(errors.OutOfRangeError):
self._next_actual_batch(sess)
with ops.Graph().as_default() as g:
with self.test_session(graph=g) as sess:
# Basic test: read from file 1.
- self.outputs = self._read_batch_features(
+ self.outputs = self.make_batch_feature(
filenames=self.test_filenames[1],
num_epochs=num_epochs,
- batch_size=batch_size)
- self._verify_records(sess, batch_size, 1, num_epochs=num_epochs)
+ batch_size=batch_size).make_one_shot_iterator().get_next()
+ self.verify_records(sess, batch_size, 1, num_epochs=num_epochs)
with self.assertRaises(errors.OutOfRangeError):
self._next_actual_batch(sess)
with ops.Graph().as_default() as g:
with self.test_session(graph=g) as sess:
# Basic test: read from both files.
- self.outputs = self._read_batch_features(
+ self.outputs = self.make_batch_feature(
filenames=self.test_filenames,
num_epochs=num_epochs,
- batch_size=batch_size)
- self._verify_records(sess, batch_size, num_epochs=num_epochs)
+ batch_size=batch_size).make_one_shot_iterator().get_next()
+ self.verify_records(sess, batch_size, num_epochs=num_epochs)
with self.assertRaises(errors.OutOfRangeError):
self._next_actual_batch(sess)
@@ -504,18 +348,18 @@ class ReadBatchFeaturesTest(test.TestCase):
# Test that shuffling with same seed produces the same result.
with ops.Graph().as_default() as g:
with self.test_session(graph=g) as sess:
- outputs1 = self._read_batch_features(
+ outputs1 = self.make_batch_feature(
filenames=self.test_filenames[0],
num_epochs=num_epochs,
batch_size=batch_size,
shuffle=True,
- shuffle_seed=5)
- outputs2 = self._read_batch_features(
+ shuffle_seed=5).make_one_shot_iterator().get_next()
+ outputs2 = self.make_batch_feature(
filenames=self.test_filenames[0],
num_epochs=num_epochs,
batch_size=batch_size,
shuffle=True,
- shuffle_seed=5)
+ shuffle_seed=5).make_one_shot_iterator().get_next()
for _ in range(total_records // batch_size):
batch1 = self._run_actual_batch(outputs1, sess)
batch2 = self._run_actual_batch(outputs2, sess)
@@ -525,18 +369,18 @@ class ReadBatchFeaturesTest(test.TestCase):
# Test that shuffling with different seeds produces a different order.
with ops.Graph().as_default() as g:
with self.test_session(graph=g) as sess:
- outputs1 = self._read_batch_features(
+ outputs1 = self.make_batch_feature(
filenames=self.test_filenames[0],
num_epochs=num_epochs,
batch_size=batch_size,
shuffle=True,
- shuffle_seed=5)
- outputs2 = self._read_batch_features(
+ shuffle_seed=5).make_one_shot_iterator().get_next()
+ outputs2 = self.make_batch_feature(
filenames=self.test_filenames[0],
num_epochs=num_epochs,
batch_size=batch_size,
shuffle=True,
- shuffle_seed=15)
+ shuffle_seed=15).make_one_shot_iterator().get_next()
all_equal = True
for _ in range(total_records // batch_size):
batch1 = self._run_actual_batch(outputs1, sess)
@@ -552,13 +396,14 @@ class ReadBatchFeaturesTest(test.TestCase):
for parser_num_threads in [2, 4]:
with ops.Graph().as_default() as g:
with self.test_session(graph=g) as sess:
- self.outputs = self._read_batch_features(
+ self.outputs = self.make_batch_feature(
filenames=self.test_filenames,
num_epochs=num_epochs,
batch_size=batch_size,
reader_num_threads=reader_num_threads,
- parser_num_threads=parser_num_threads)
- self._verify_records(
+ parser_num_threads=parser_num_threads).make_one_shot_iterator(
+ ).get_next()
+ self.verify_records(
sess,
batch_size,
num_epochs=num_epochs,
@@ -571,11 +416,11 @@ class ReadBatchFeaturesTest(test.TestCase):
for num_epochs in [1, 10]:
with ops.Graph().as_default():
# Basic test: read from file 0.
- self.outputs = self._read_batch_features(
+ self.outputs = self.make_batch_feature(
filenames=self.test_filenames[0],
num_epochs=num_epochs,
batch_size=batch_size,
- drop_final_batch=True)
+ drop_final_batch=True).make_one_shot_iterator().get_next()
for _, tensor in self.outputs.items():
if isinstance(tensor, ops.Tensor): # Guard against SparseTensor.
self.assertEqual(tensor.shape[0], batch_size)
diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py
new file mode 100644
index 0000000000..805a7c7b73
--- /dev/null
+++ b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py
@@ -0,0 +1,218 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the experimental input pipeline ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from tensorflow.contrib.data.python.ops import readers
+from tensorflow.core.example import example_pb2
+from tensorflow.core.example import feature_pb2
+from tensorflow.python.data.ops import readers as core_readers
+from tensorflow.python.framework import dtypes
+from tensorflow.python.lib.io import python_io
+from tensorflow.python.ops import parsing_ops
+from tensorflow.python.platform import test
+from tensorflow.python.util import compat
+
+
+class ReadBatchFeaturesTestBase(test.TestCase):
+ """Base class for setting up and testing `make_batched_feature_dataset`."""
+
+ def setUp(self):
+ super(ReadBatchFeaturesTestBase, self).setUp()
+ self._num_files = 2
+ self._num_records = 7
+ self.test_filenames = self._createFiles()
+
+ def make_batch_feature(self,
+ filenames,
+ num_epochs,
+ batch_size,
+ reader_num_threads=1,
+ parser_num_threads=1,
+ shuffle=False,
+ shuffle_seed=None,
+ drop_final_batch=False):
+ self.filenames = filenames
+ self.num_epochs = num_epochs
+ self.batch_size = batch_size
+
+ return readers.make_batched_features_dataset(
+ file_pattern=self.filenames,
+ batch_size=self.batch_size,
+ features={
+ "file": parsing_ops.FixedLenFeature([], dtypes.int64),
+ "record": parsing_ops.FixedLenFeature([], dtypes.int64),
+ "keywords": parsing_ops.VarLenFeature(dtypes.string)
+ },
+ reader=core_readers.TFRecordDataset,
+ num_epochs=self.num_epochs,
+ shuffle=shuffle,
+ shuffle_seed=shuffle_seed,
+ reader_num_threads=reader_num_threads,
+ parser_num_threads=parser_num_threads,
+ drop_final_batch=drop_final_batch)
+
+ def _record(self, f, r):
+ example = example_pb2.Example(
+ features=feature_pb2.Features(
+ feature={
+ "file":
+ feature_pb2.Feature(
+ int64_list=feature_pb2.Int64List(value=[f])),
+ "record":
+ feature_pb2.Feature(
+ int64_list=feature_pb2.Int64List(value=[r])),
+ "keywords":
+ feature_pb2.Feature(
+ bytes_list=feature_pb2.BytesList(
+ value=self._get_keywords(f, r)))
+ }))
+ return example.SerializeToString()
+
+ def _get_keywords(self, f, r):
+ num_keywords = 1 + (f + r) % 2
+ keywords = []
+ for index in range(num_keywords):
+ keywords.append(compat.as_bytes("keyword%d" % index))
+ return keywords
+
+ def _sum_keywords(self, num_files):
+ sum_keywords = 0
+ for i in range(num_files):
+ for j in range(self._num_records):
+ sum_keywords += 1 + (i + j) % 2
+ return sum_keywords
+
+ def _createFiles(self):
+ filenames = []
+ for i in range(self._num_files):
+ fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i)
+ filenames.append(fn)
+ writer = python_io.TFRecordWriter(fn)
+ for j in range(self._num_records):
+ writer.write(self._record(i, j))
+ writer.close()
+ return filenames
+
+ def _run_actual_batch(self, outputs, sess):
+ file_op = outputs["file"]
+ keywords_indices_op = outputs["keywords"].indices
+ keywords_values_op = outputs["keywords"].values
+ keywords_dense_shape_op = outputs["keywords"].dense_shape
+ record_op = outputs["record"]
+ return sess.run([
+ file_op, keywords_indices_op, keywords_values_op,
+ keywords_dense_shape_op, record_op
+ ])
+
+ def _next_actual_batch(self, sess):
+ return self._run_actual_batch(self.outputs, sess)
+
+ def _interleave(self, iterators, cycle_length):
+ pending_iterators = iterators
+ open_iterators = []
+ num_open = 0
+ for i in range(cycle_length):
+ if pending_iterators:
+ open_iterators.append(pending_iterators.pop(0))
+ num_open += 1
+
+ while num_open:
+ for i in range(min(cycle_length, len(open_iterators))):
+ if open_iterators[i] is None:
+ continue
+ try:
+ yield next(open_iterators[i])
+ except StopIteration:
+ if pending_iterators:
+ open_iterators[i] = pending_iterators.pop(0)
+ else:
+ open_iterators[i] = None
+ num_open -= 1
+
+ def _next_expected_batch(self,
+ file_indices,
+ batch_size,
+ num_epochs,
+ cycle_length=1):
+
+ def _next_record(file_indices):
+ for j in file_indices:
+ for i in range(self._num_records):
+ yield j, i
+
+ def _next_record_interleaved(file_indices, cycle_length):
+ return self._interleave([_next_record([i]) for i in file_indices],
+ cycle_length)
+
+ file_batch = []
+ keywords_batch_indices = []
+ keywords_batch_values = []
+ keywords_batch_max_len = 0
+ record_batch = []
+ batch_index = 0
+ for _ in range(num_epochs):
+ if cycle_length == 1:
+ next_records = _next_record(file_indices)
+ else:
+ next_records = _next_record_interleaved(file_indices, cycle_length)
+ for record in next_records:
+ f = record[0]
+ r = record[1]
+ file_batch.append(f)
+ record_batch.append(r)
+ keywords = self._get_keywords(f, r)
+ keywords_batch_values.extend(keywords)
+ keywords_batch_indices.extend(
+ [[batch_index, i] for i in range(len(keywords))])
+ batch_index += 1
+ keywords_batch_max_len = max(keywords_batch_max_len, len(keywords))
+ if len(file_batch) == batch_size:
+ yield [
+ file_batch, keywords_batch_indices, keywords_batch_values,
+ [batch_size, keywords_batch_max_len], record_batch
+ ]
+ file_batch = []
+ keywords_batch_indices = []
+ keywords_batch_values = []
+ keywords_batch_max_len = 0
+ record_batch = []
+ batch_index = 0
+ if file_batch:
+ yield [
+ file_batch, keywords_batch_indices, keywords_batch_values,
+ [len(file_batch), keywords_batch_max_len], record_batch
+ ]
+
+ def verify_records(self,
+ sess,
+ batch_size,
+ file_index=None,
+ num_epochs=1,
+ interleave_cycle_length=1):
+ if file_index is not None:
+ file_indices = [file_index]
+ else:
+ file_indices = range(self._num_files)
+
+ for expected_batch in self._next_expected_batch(
+ file_indices, batch_size, num_epochs, interleave_cycle_length):
+ actual_batch = self._next_actual_batch(sess)
+ for i in range(len(expected_batch)):
+ self.assertAllEqual(expected_batch[i], actual_batch[i])
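The expected-batch bookkeeping above relies on the round-robin interleaving implemented by _interleave. As a minimal standalone sketch (illustrative only, not part of the patch), the ordering it models for cycle_length open iterators looks like this:

def interleave(iterators, cycle_length):
  """Yield items round-robin from up to `cycle_length` open iterators."""
  pending = list(iterators)
  open_its = []
  while pending and len(open_its) < cycle_length:
    open_its.append(pending.pop(0))
  while open_its:
    for i in range(len(open_its)):
      try:
        yield next(open_its[i])
      except StopIteration:
        # Replace the exhausted iterator, or drop it after this pass.
        open_its[i] = pending.pop(0) if pending else None
    open_its = [it for it in open_its if it is not None]

# Two files of three records each, read with cycle_length=2:
records = list(interleave(
    [iter([(0, r) for r in range(3)]), iter([(1, r) for r in range(3)])], 2))
assert records == [(0, 0), (1, 0), (0, 1), (1, 1), (0, 2), (1, 2)]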
diff --git a/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py
index bcc644c097..1b67a33f04 100644
--- a/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py
@@ -20,11 +20,13 @@ from __future__ import print_function
import numpy as np
from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base
+from tensorflow.contrib.data.python.ops import iterator_ops as contrib_iterator_ops
from tensorflow.contrib.data.python.ops import shuffle_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.platform import test
+from tensorflow.python.training import saver as saver_lib
class ShuffleDatasetSerializationTest(
@@ -50,26 +52,100 @@ class ShuffleDatasetSerializationTest(
num_repeats = 5
num_outputs = range_limit * num_repeats
buffer_sizes = [1, 3, 8, 10, 25, 50]
- reshuffle_each_iteration = False
# pylint: disable=cell-var-from-loop
# pylint: disable=g-long-lambda
- for buffer_size in buffer_sizes:
- self.run_core_tests(
- lambda: self._build_shuffle_dataset(
+ for reshuffle_each_iteration in [True, False]:
+ for buffer_size in buffer_sizes:
+ self.run_core_tests(
+ lambda: self._build_shuffle_dataset(
+ range_limit=range_limit,
+ num_repeats=num_repeats,
+ buffer_size=buffer_size,
+ seed=seed,
+ reshuffle_each_iteration=reshuffle_each_iteration),
+ lambda: self._build_shuffle_dataset(
+ range_limit=range_limit,
+ num_repeats=num_repeats,
+ buffer_size=buffer_size,
+ seed=10,
+ reshuffle_each_iteration=reshuffle_each_iteration),
+ num_outputs)
+ # pylint: enable=cell-var-from-loop
+ # pylint: enable=g-long-lambda
+
+ def testNonDeterministicSeeding(self):
+
+ range_limit = 10
+ num_repeats = 5
+ num_outputs = range_limit * num_repeats
+ buffer_sizes = [1, 3, 8, 10, 25, 50]
+ for reshuffle_each_iteration in [True, False]:
+ for buffer_size in buffer_sizes:
+
+ def ds_fn():
+ # pylint: disable=cell-var-from-loop
+ return self._build_shuffle_dataset(
range_limit=range_limit,
num_repeats=num_repeats,
buffer_size=buffer_size,
- seed=seed,
- reshuffle_each_iteration=reshuffle_each_iteration),
- lambda: self._build_shuffle_dataset(
+ seed=None, # Iterator seeds are generated non-deterministically.
+ reshuffle_each_iteration=reshuffle_each_iteration)
+ # pylint: enable=cell-var-from-loop
+
+          # We checkpoint the initial state of the Dataset so that we can
+          # restore the same seeds in the next run. Since the seeding is
+          # non-deterministic, the dataset would otherwise be initialized with
+          # different seeds on each run.
+ expected = self.gen_outputs(
+ ds_fn,
+ break_points=[0],
+ num_outputs=num_outputs,
+ ckpt_saved=False,
+ verify_exhausted=False,
+ save_checkpoint_at_end=False)
+ actual = self.gen_outputs(
+ ds_fn,
+ break_points=self.gen_break_points(num_outputs),
+ num_outputs=num_outputs,
+ ckpt_saved=True,
+ verify_exhausted=False)
+ self.match(expected, actual)
+
+ def testMultipleIterators(self):
+ range_limit = 10
+ num_repeats = 5
+ num_outputs = range_limit * num_repeats
+ buffer_sizes = [1, 3, 8, 10, 25, 50]
+
+ for reshuffle_each_iteration in [True, False]:
+ for buffer_size in buffer_sizes:
+
+ def ds_fn():
+ # pylint: disable=cell-var-from-loop
+ return self._build_shuffle_dataset(
range_limit=range_limit,
num_repeats=num_repeats,
buffer_size=buffer_size,
- seed=10,
- reshuffle_each_iteration=reshuffle_each_iteration),
- num_outputs)
- # pylint: enable=cell-var-from-loop
- # pylint: enable=g-long-lambda
+ seed=None, # Iterator seeds are generated non-deterministically.
+ reshuffle_each_iteration=reshuffle_each_iteration)
+ # pylint: enable=cell-var-from-loop
+
+ with ops.Graph().as_default() as g:
+ ds = ds_fn()
+ iterators = [ds.make_one_shot_iterator(), ds.make_one_shot_iterator()]
+ get_next_ops = [it.get_next() for it in iterators]
+ saveables = [
+ contrib_iterator_ops.make_saveable_from_iterator(it)
+ for it in iterators
+ ]
+ for saveable in saveables:
+ ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
+ saver = saver_lib.Saver(allow_empty=True)
+ with self.test_session(graph=g) as sess:
+ self._save(sess, saver)
+ expected = [sess.run(get_next_ops) for _ in range(num_outputs)]
+ self._restore(saver, sess)
+ actual = [sess.run(get_next_ops) for _ in range(num_outputs)]
+ self.match(expected, actual)
class ShuffleAndRepeatTest(
diff --git a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py
index 5c74ed6ae7..17b6644759 100644
--- a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
import numpy as np
from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base
+from tensorflow.contrib.data.python.kernel_tests import reader_dataset_ops_test_base
from tensorflow.contrib.data.python.ops import stats_ops
from tensorflow.core.framework import summary_pb2
from tensorflow.python.data.ops import dataset_ops
@@ -29,7 +30,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class StatsDatasetTest(test.TestCase):
+class StatsDatasetTestBase(test.TestCase):
def _assertSummaryHasCount(self, summary_str, tag, expected_value):
summary_proto = summary_pb2.Summary()
@@ -49,6 +50,9 @@ class StatsDatasetTest(test.TestCase):
return
self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto))
+
+class StatsDatasetTest(StatsDatasetTestBase):
+
def testBytesProduced(self):
stats_aggregator = stats_ops.StatsAggregator()
dataset = dataset_ops.Dataset.range(100).map(
@@ -193,6 +197,45 @@ class StatsDatasetTest(test.TestCase):
self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 200.0)
+class FeatureStatsDatasetTest(
+ StatsDatasetTestBase,
+ reader_dataset_ops_test_base.ReadBatchFeaturesTestBase):
+
+ def testFeaturesStats(self):
+ num_epochs = 5
+ total_records = num_epochs * self._num_records
+ batch_size = 2
+ stats_aggregator = stats_ops.StatsAggregator()
+ dataset = self.make_batch_feature(
+ filenames=self.test_filenames[0],
+ num_epochs=num_epochs,
+ batch_size=batch_size,
+ shuffle=True,
+ shuffle_seed=5,
+ drop_final_batch=True).apply(
+ stats_ops.set_stats_aggregator(stats_aggregator))
+ iterator = dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+ summary_t = stats_aggregator.get_summary()
+
+ with self.test_session() as sess:
+ sess.run(iterator.initializer)
+ for _ in range(total_records // batch_size):
+ sess.run(next_element)
+
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+ self._assertSummaryHasCount(
+ sess.run(summary_t), "record_stats:features", total_records)
+ self._assertSummaryHasCount(
+ sess.run(summary_t), "record_stats:feature-values", total_records)
+ self._assertSummaryHasSum(
+ sess.run(summary_t), "record_stats:features", total_records * 3)
+ self._assertSummaryHasSum(
+ sess.run(summary_t), "record_stats:feature-values",
+ self._sum_keywords(1) * num_epochs + 2 * total_records)
+
+
class StatsDatasetSerializationTest(
dataset_serialization_test_base.DatasetSerializationTestBase):
diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD
index 086661adb7..33b7a75046 100644
--- a/tensorflow/contrib/data/python/ops/BUILD
+++ b/tensorflow/contrib/data/python/ops/BUILD
@@ -96,8 +96,10 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":batching",
+ ":gen_dataset_ops",
":interleave_ops",
":shuffle_ops",
+ ":stats_ops",
"//tensorflow/python:constant_op",
"//tensorflow/python:dataset_ops_gen",
"//tensorflow/python:dtypes",
@@ -106,12 +108,12 @@ py_library(
"//tensorflow/python:math_ops",
"//tensorflow/python:parsing_ops",
"//tensorflow/python:platform",
- "//tensorflow/python:sparse_tensor",
"//tensorflow/python:string_ops",
"//tensorflow/python:tensor_shape",
"//tensorflow/python:util",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/ops:readers",
+ "//tensorflow/python/data/util:convert",
"//tensorflow/python/data/util:nest",
"//third_party/py/numpy",
],
@@ -142,6 +144,7 @@ py_library(
"//tensorflow/python:tensor_shape",
"//tensorflow/python:tensor_util",
"//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:convert",
"//tensorflow/python/data/util:nest",
"//tensorflow/python/data/util:sparse",
],
diff --git a/tensorflow/contrib/data/python/ops/batching.py b/tensorflow/contrib/data/python/ops/batching.py
index b9393de4e9..17256eb972 100644
--- a/tensorflow/contrib/data/python/ops/batching.py
+++ b/tensorflow/contrib/data/python/ops/batching.py
@@ -19,6 +19,7 @@ from __future__ import print_function
from tensorflow.contrib.framework import with_shape
from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.util import convert
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import sparse
from tensorflow.python.framework import dtypes
@@ -29,6 +30,7 @@ from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.util import deprecation
def dense_to_sparse_batch(batch_size, row_shape):
@@ -218,6 +220,8 @@ def filter_irregular_batches(batch_size):
return _apply_fn
+@deprecation.deprecated(
+ None, "Use `tf.data.Dataset.batch(..., drop_remainder=True)`.")
def batch_and_drop_remainder(batch_size):
"""A batching transformation that omits the final small batch (if present).
@@ -250,12 +254,16 @@ def batch_and_drop_remainder(batch_size):
def _apply_fn(dataset):
"""Function from `Dataset` to `Dataset` that applies the transformation."""
+ # TODO(jsimsa): Switch to using `batch(..., drop_remainder=True)` any time
+ # after 6/30/2018.
batched = dataset.batch(batch_size)
return filter_irregular_batches(batch_size)(batched)
return _apply_fn
+@deprecation.deprecated(
+ None, "Use `tf.data.Dataset.padded_batch(..., drop_remainder=True)`.")
def padded_batch_and_drop_remainder(batch_size,
padded_shapes,
padding_values=None):
@@ -284,6 +292,8 @@ def padded_batch_and_drop_remainder(batch_size,
def _apply_fn(dataset):
"""Function from `Dataset` to `Dataset` that applies the transformation."""
+ # TODO(jsimsa): Switch to using `padded_batch(..., drop_remainder=True)`
+ # any time after 6/30/2018.
batched = dataset.padded_batch(
batch_size, padded_shapes=padded_shapes, padding_values=padding_values)
return filter_irregular_batches(batch_size)(batched)
@@ -309,7 +319,7 @@ class DenseToSparseBatchDataset(dataset_ops.Dataset):
return gen_dataset_ops.dense_to_sparse_batch_dataset(
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
self._batch_size,
- row_shape=dataset_ops._partial_shape_to_tensor(self._row_shape), # pylint: disable=protected-access
+ row_shape=convert.partial_shape_to_tensor(self._row_shape),
output_shapes=nest.flatten(
sparse.as_dense_shapes(self.output_shapes, self.output_classes)),
output_types=nest.flatten(
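The deprecations added above point at the core `drop_remainder` argument. A rough before/after sketch, assuming a TensorFlow build in which `Dataset.batch` already accepts `drop_remainder` (the replacement named in the deprecation messages):

import tensorflow as tf

dataset = tf.data.Dataset.range(10)

# Deprecated contrib form targeted by this change:
#   dataset.apply(tf.contrib.data.batch_and_drop_remainder(4))
# Core replacement named in the deprecation message:
batched = dataset.batch(4, drop_remainder=True)
# Yields [0..3] and [4..7]; the final partial batch [8, 9] is dropped.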
diff --git a/tensorflow/contrib/data/python/ops/grouping.py b/tensorflow/contrib/data/python/ops/grouping.py
index ea229b5b27..520f784228 100644
--- a/tensorflow/contrib/data/python/ops/grouping.py
+++ b/tensorflow/contrib/data/python/ops/grouping.py
@@ -300,6 +300,7 @@ class GroupByReducerDataset(dataset_ops.Dataset):
raise ValueError(
"`key_func` must return a single tf.int64 tensor. "
"Got type=%s and shape=%s" % (ret.dtype, ret.get_shape()))
+ dataset_ops._warn_if_collections("tf.contrib.data.group_by_reducer()") # pylint: disable=protected-access
return ret
self._key_func = tf_key_func
@@ -327,6 +328,8 @@ class GroupByReducerDataset(dataset_ops.Dataset):
self._state_types = nest.pack_sequence_as(
ret, [t.dtype for t in nest.flatten(ret)])
+ dataset_ops._warn_if_collections("tf.contrib.data.group_by_reducer()") # pylint: disable=protected-access
+
# Serialize any sparse tensors.
ret = nest.pack_sequence_as(
ret, [t for t in nest.flatten(sparse.serialize_sparse_tensors(ret))])
@@ -398,6 +401,8 @@ class GroupByReducerDataset(dataset_ops.Dataset):
nest.pack_sequence_as(self._state_types,
[t.dtype for t in flat_new_state])))
+ dataset_ops._warn_if_collections("tf.contrib.data.group_by_reducer()") # pylint: disable=protected-access
+
# Serialize any sparse tensors.
ret = nest.pack_sequence_as(
ret,
@@ -464,6 +469,8 @@ class GroupByReducerDataset(dataset_ops.Dataset):
self._output_types = nest.pack_sequence_as(
ret, [t.dtype for t in nest.flatten(ret)])
+ dataset_ops._warn_if_collections("tf.contrib.data.group_by_reducer()") # pylint: disable=protected-access
+
# Serialize any sparse tensors.
ret = nest.pack_sequence_as(
ret, [t for t in nest.flatten(sparse.serialize_sparse_tensors(ret))])
@@ -525,6 +532,7 @@ class GroupByWindowDataset(dataset_ops.Dataset):
if window_size.dtype != dtypes.int64:
raise ValueError(
"`window_size_func` must return a single tf.int64 tensor.")
+ dataset_ops._warn_if_collections("tf.contrib.data.group_by_window()") # pylint: disable=protected-access
return window_size
self._window_size_func = tf_window_size_func
@@ -557,6 +565,7 @@ class GroupByWindowDataset(dataset_ops.Dataset):
ret = ops.convert_to_tensor(ret, dtype=dtypes.int64)
if ret.dtype != dtypes.int64:
raise ValueError("`key_func` must return a single tf.int64 tensor.")
+ dataset_ops._warn_if_collections("tf.contrib.data.group_by_window()") # pylint: disable=protected-access
return ret
self._key_func = tf_key_func
@@ -580,6 +589,7 @@ class GroupByWindowDataset(dataset_ops.Dataset):
self._output_classes = output_dataset.output_classes
self._output_types = output_dataset.output_types
self._output_shapes = output_dataset.output_shapes
+ dataset_ops._warn_if_collections("tf.contrib.data.group_by_window()") # pylint: disable=protected-access
return output_dataset._as_variant_tensor() # pylint: disable=protected-access
self._reduce_func = tf_reduce_func
diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py
index f938153f5f..83095c7ba1 100644
--- a/tensorflow/contrib/data/python/ops/readers.py
+++ b/tensorflow/contrib/data/python/ops/readers.py
@@ -26,6 +26,7 @@ from tensorflow.contrib.data.python.ops import batching
from tensorflow.contrib.data.python.ops import gen_dataset_ops as contrib_gen_dataset_ops
from tensorflow.contrib.data.python.ops import interleave_ops
from tensorflow.contrib.data.python.ops import shuffle_ops
+from tensorflow.contrib.data.python.ops import stats_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import readers as core_readers
from tensorflow.python.data.util import convert
@@ -754,6 +755,8 @@ def make_batched_features_dataset(file_pattern,
dataset = _maybe_shuffle_and_repeat(
dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed)
+ dataset = dataset.apply(stats_ops.feature_stats("record_stats"))
+
if drop_final_batch:
dataset = dataset.apply(batching.batch_and_drop_remainder(batch_size))
else:
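For orientation, a condensed call to the updated `make_batched_features_dataset`, modeled on `make_batch_feature` in the new test base; the file pattern and feature keys here are placeholders:

import tensorflow as tf

from tensorflow.contrib.data.python.ops import readers
from tensorflow.python.data.ops import readers as core_readers

dataset = readers.make_batched_features_dataset(
    file_pattern="tf_record.*.txt",  # placeholder pattern
    batch_size=2,
    features={
        "file": tf.FixedLenFeature([], tf.int64),
        "record": tf.FixedLenFeature([], tf.int64),
        "keywords": tf.VarLenFeature(tf.string),
    },
    reader=core_readers.TFRecordDataset,
    num_epochs=1)
# With this change the pipeline also records "record_stats" statistics
# before batching.
next_element = dataset.make_one_shot_iterator().get_next()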
diff --git a/tensorflow/contrib/data/python/ops/scan_ops.py b/tensorflow/contrib/data/python/ops/scan_ops.py
index e911ad0fa0..9909ca8d9d 100644
--- a/tensorflow/contrib/data/python/ops/scan_ops.py
+++ b/tensorflow/contrib/data/python/ops/scan_ops.py
@@ -148,6 +148,8 @@ class _ScanDataset(dataset_ops.Dataset):
self._output_types = nest.pack_sequence_as(
output_value, [t.dtype for t in nest.flatten(output_value)])
+ dataset_ops._warn_if_collections("tf.contrib.data.scan()") # pylint: disable=protected-access
+
# Serialize any sparse tensors.
new_state = nest.pack_sequence_as(new_state, [
t for t in nest.flatten(sparse.serialize_sparse_tensors(new_state))
diff --git a/tensorflow/contrib/data/python/ops/stats_ops.py b/tensorflow/contrib/data/python/ops/stats_ops.py
index 3cbaab5aff..8c30202ba7 100644
--- a/tensorflow/contrib/data/python/ops/stats_ops.py
+++ b/tensorflow/contrib/data/python/ops/stats_ops.py
@@ -176,6 +176,27 @@ def latency_stats(tag):
return _apply_fn
+def feature_stats(tag):
+ """Records the features stats from `Example` records of the input dataset.
+
+ To consume the statistics, associate a `StatsAggregator` with the output
+ dataset.
+
+ Args:
+ tag: String. All statistics recorded by the returned transformation will be
+ associated with the given `tag`.
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ @{tf.data.Dataset.apply}.
+ """
+
+ def _apply_fn(dataset):
+ return _StatsDataset(dataset, gen_dataset_ops.feature_stats_dataset, tag)
+
+ return _apply_fn
+
+
class _StatsDataset(dataset_ops.Dataset):
"""A `Dataset` that acts as an identity, and also records statistics."""
diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD
index a91c54153f..b572512bbb 100644
--- a/tensorflow/contrib/distribute/python/BUILD
+++ b/tensorflow/contrib/distribute/python/BUILD
@@ -148,6 +148,7 @@ py_library(
],
deps = [
":mirrored_strategy",
+ ":multi_worker_strategy",
":one_device_strategy",
":tpu_strategy",
"//tensorflow/contrib/optimizer_v2:training",
@@ -311,7 +312,6 @@ cuda_py_test(
tags = [
"multi_and_single_gpu",
"no_pip",
- "noguitar", # TODO(b/109653107): test is flaky.
],
)
@@ -447,8 +447,10 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":values",
+ "//tensorflow/contrib/all_reduce:all_reduce_py",
"//tensorflow/contrib/nccl:nccl_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
],
@@ -496,6 +498,7 @@ cuda_py_test(
additional_deps = [
":combinations",
":cross_tower_ops",
+ ":multi_worker_test_base",
":values",
"@absl_py//absl/testing:parameterized",
"//tensorflow/python:array_ops",
@@ -505,6 +508,7 @@ cuda_py_test(
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:test",
],
+ shard_count = 15,
tags = [
"multi_and_single_gpu",
"no_pip",
@@ -582,6 +586,7 @@ cuda_py_test(
],
tags = [
"multi_and_single_gpu",
+ "noguitar",
"notsan",
],
)
diff --git a/tensorflow/contrib/distribute/python/combinations.py b/tensorflow/contrib/distribute/python/combinations.py
index 98e7228f24..ba03b14deb 100644
--- a/tensorflow/contrib/distribute/python/combinations.py
+++ b/tensorflow/contrib/distribute/python/combinations.py
@@ -47,6 +47,7 @@ from absl.testing import parameterized
import six
from tensorflow.contrib.distribute.python import mirrored_strategy as mirrored_lib
+from tensorflow.contrib.distribute.python import multi_worker_strategy
from tensorflow.contrib.distribute.python import one_device_strategy as one_device_lib
from tensorflow.contrib.distribute.python import tpu_strategy as tpu_lib
from tensorflow.contrib.optimizer_v2 import adam as adam_v2
@@ -338,6 +339,34 @@ mirrored_strategy_with_two_gpus = NamedDistribution(
["/gpu:0", "/gpu:1"], prefetch_on_device=False),
required_gpus=2)
+multi_worker_strategy_with_cpu = NamedDistribution(
+ "MultiWorkerCPU",
+ lambda: multi_worker_strategy.MultiWorkerMirroredStrategy(
+ cluster={
+ "worker": [
+ "/job:worker/replica:0/task:0", "/job:worker/replica:0/task:1"
+ ]
+ },
+ num_gpus_per_worker=0), 0)
+multi_worker_strategy_with_one_gpu = NamedDistribution(
+ "MultiWorker1GPU",
+ lambda: multi_worker_strategy.MultiWorkerMirroredStrategy(
+ cluster={
+ "worker": [
+ "/job:worker/replica:0/task:0", "/job:worker/replica:0/task:1"
+ ]
+ },
+ num_gpus_per_worker=1), 1)
+multi_worker_strategy_with_two_gpus = NamedDistribution(
+ "MultiWorker2GPUs",
+ lambda: multi_worker_strategy.MultiWorkerMirroredStrategy(
+ cluster={
+ "worker": [
+ "/job:worker/replica:0/task:0", "/job:worker/replica:0/task:1"
+ ]
+ },
+ num_gpus_per_worker=2), 2)
+
adam_optimizer_v1_fn = NamedObject(
"AdamV1", lambda: adam.AdamOptimizer(0.2, epsilon=1))
gradient_descent_optimizer_v1_fn = NamedObject(
diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops.py b/tensorflow/contrib/distribute/python/cross_tower_ops.py
index a411b880e8..f8ae8b9712 100644
--- a/tensorflow/contrib/distribute/python/cross_tower_ops.py
+++ b/tensorflow/contrib/distribute/python/cross_tower_ops.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import collections
import six
from tensorflow.contrib.distribute.python import cross_tower_utils
@@ -234,7 +235,13 @@ class ReductionToOneDeviceCrossTowerOps(CrossTowerOps):
def _group_value_by_device(per_device_values):
"""Group values into sublists by their devices.
- This grouping is needed to call the all-reduce library.
+ This grouping is needed to call the all-reduce library because it expects a
+ list of the following form:
+ [(grad0_gpu0, v0_gpu0), (grad1_gpu0, v1_gpu0), (grad2_gpu0, v2_gpu0) ...
+ (grad0_gpu1, v0_gpu1), (grad1_gpu1, v1_gpu1), (grad2_gpu1, v2_gpu1) ...
+     (grad0_gpu2, v0_gpu2), (grad1_gpu2, v1_gpu2), (grad2_gpu2, v2_gpu2) ...
+ ...
+ ]
Args:
     per_device_values: a list of PerDevice objects.
@@ -322,7 +329,17 @@ class ConcatAndSplitPacker(object):
# TODO(zhengxq): it is also possible to optimize away all the concat
# as well.
num_splits = self.num_packs
- total_grad_size = array_ops.size(concat_grads)
+
+      # Computing the total size with array_ops.size produces a dynamic
+      # value and can drop static shape information. So if all gradient
+      # shapes are statically defined, we sum the static sizes instead.
+ # TODO(yuefengz): move this logic to array_ops.size.
+ if all([g.shape.is_fully_defined() for g, _ in tower_grads_and_vars]):
+ total_grad_size = sum(
+ [g.shape.num_elements() for g, _ in tower_grads_and_vars])
+ else:
+ total_grad_size = array_ops.size(concat_grads)
+
split_size = total_grad_size // num_splits
split_size_last = total_grad_size - split_size * (num_splits - 1)
split_sizes = [split_size] * (num_splits - 1) + [split_size_last]
@@ -412,6 +429,31 @@ class AggregateSmallTensorPacker(object):
self.packing)
+def _pack_tensors(device_grads,
+ num_packs=0,
+ agg_small_grads_max_bytes=0,
+ agg_small_grads_max_group=0):
+ """Pack tensors if specified."""
+ if num_packs > 0:
+ tensor_packer = ConcatAndSplitPacker(num_packs)
+ device_grad_packs = tensor_packer.pack(device_grads)
+ elif agg_small_grads_max_bytes > 0 and agg_small_grads_max_group > 0:
+ tensor_packer = AggregateSmallTensorPacker(agg_small_grads_max_bytes,
+ agg_small_grads_max_group)
+ device_grad_packs = tensor_packer.pack(device_grads)
+ else:
+ tensor_packer = None
+ device_grad_packs = device_grads
+ return device_grad_packs, tensor_packer
+
+
+def _unpack_tensors(reduced, tensor_packer=None):
+ """Unpack tensors if they are packed before all-reduce."""
+ if tensor_packer:
+ return tensor_packer.unpack(reduced)
+ return reduced
+
+
class AllReduceCrossTowerOps(CrossTowerOps):
"""Reduction using all reduce."""
@@ -440,10 +482,10 @@ class AllReduceCrossTowerOps(CrossTowerOps):
       agg_small_grads_max_group: see above.
"""
- self.all_reduce_alg = all_reduce_alg
- self.num_packs = num_packs
- self.agg_small_grads_max_bytes = agg_small_grads_max_bytes
- self.agg_small_grads_max_group = agg_small_grads_max_group
+ self._all_reduce_alg = all_reduce_alg
+ self._num_packs = num_packs
+ self._agg_small_grads_max_bytes = agg_small_grads_max_bytes
+ self._agg_small_grads_max_group = agg_small_grads_max_group
super(AllReduceCrossTowerOps, self).__init__()
def _reduce(self, method_string, per_device_value, destinations):
@@ -485,37 +527,24 @@ class AllReduceCrossTowerOps(CrossTowerOps):
def _batch_all_reduce(self, method_string, per_device_values):
"""All reduce algorithm in a batch."""
+ logging.info(
+ "batch_all_reduce invoked for batches size = %d with "
+ "algorithm = %s, num_packs = %d, agg_small_grads_max_bytes = %d and "
+ "agg_small_grads_max_group = %d", len(per_device_values),
+ self._all_reduce_alg, self._num_packs, self._agg_small_grads_max_bytes,
+ self._agg_small_grads_max_group)
destinations = per_device_values[0].devices
grouped = _group_value_by_device(per_device_values)
- if self.num_packs > 0:
- logging.info(
- "batch_all_reduce invoked for batches size = %d with "
- "algorithm = %s and num_packs = %d", len(per_device_values),
- self.all_reduce_alg, self.num_packs)
- tensor_packer = ConcatAndSplitPacker(self.num_packs)
- device_grad_packs = tensor_packer.pack(grouped)
- elif (self.agg_small_grads_max_bytes > 0 and
- self.agg_small_grads_max_group > 0):
- logging.info(
- "batch_all_reduce invoked for batches size = %d with "
- "algorithm = %s, agg_small_grads_max_bytes = %d and "
- "agg_small_grads_max_group = %d", len(per_device_values),
- self.all_reduce_alg, self.agg_small_grads_max_bytes,
- self.agg_small_grads_max_group)
- tensor_packer = AggregateSmallTensorPacker(
- self.agg_small_grads_max_bytes, self.agg_small_grads_max_group)
- device_grad_packs = tensor_packer.pack(grouped)
- else:
- logging.info(
- "batch_all_reduce invoked for batches size = %d with algorithm = %s",
- len(per_device_values), self.all_reduce_alg)
- tensor_packer = None
- device_grad_packs = grouped
+
+ device_grad_packs, self._tensor_packer = _pack_tensors(
+ grouped, self._num_packs, self._agg_small_grads_max_bytes,
+ self._agg_small_grads_max_group)
# The actual aggregation of the repacked gradients. Note that they are
# sharded among different aggregation trees. So it is important to strike
# the balance on num_splits.
- if self.all_reduce_alg == "nccl":
+ if self._all_reduce_alg == "nccl":
+ # TODO(yuefengz): merge this into the all-reduce library.
reduced = cross_tower_utils.aggregate_gradients_using_nccl(
device_grad_packs)
else:
@@ -525,13 +554,137 @@ class AllReduceCrossTowerOps(CrossTowerOps):
cross_tower_utils.aggregate_gradients_using_hierarchical_copy(
destinations, device_grad_packs))
- if tensor_packer:
- reduced = tensor_packer.unpack(reduced)
-
+ reduced = _unpack_tensors(reduced, self._tensor_packer)
return _ungroup_and_make_mirrored(reduced, per_device_values[0].devices,
method_string)
+AllReduceSpecTuple = collections.namedtuple("AllReduceSpecTuple",
+ "alg shards limit")
+
+
+class MultiWorkerAllReduce(AllReduceCrossTowerOps):
+ """All-reduce algorithms for distributed TensorFlow."""
+
+ def __init__(self,
+ worker_devices,
+ num_gpus_per_worker,
+ all_reduce_spec=("pscpu/pscpu", 2, -1),
+ num_packs=0,
+ agg_small_grads_max_bytes=0,
+ agg_small_grads_max_group=10):
+ """Initialize the all-reduce algorithm.
+
+ Args:
+ worker_devices: a list of device strings for workers participating in
+ all-reduce.
+ num_gpus_per_worker: number of GPU devices per worker.
+ all_reduce_spec: a tuple or a named tuple or a list of tuples specifying
+ the all-reduce algorithm.
+ 1. The first element of a tuple is the name of the all-reduce algorithm.
+ Valid algorithm names are: "nccl", "nccl/xring", "nccl/rechd",
+ "nccl/pscpu", "xring", "pscpu", "psgpu", "pscpu/pscpu". Algorithms with
+        a "/" are hierarchical: two all-reduces are executed, the first
+        aggregating tensors within each worker and the second aggregating
+        across workers.
+ 2. The second element of a tuple is the number of shards when doing
+        all-reduce. If its value is M, each tensor is split into M shards
+        after packing, M parallel all-reduces are performed, and the shards
+        are finally concatenated back into a complete tensor.
+ 3. The third element is the maximum size of tensors that will be
+ applicable for the algorithm specified by the first element. For
+ example, if all_reduce_spec=[("nccl", 2, 1024), ("pscpu/pscpu", 2, -1)],
+        tensors no larger than 1024 bytes are reduced with a 2-shard "nccl"
+        all-reduce and all other tensors with a 2-shard "pscpu/pscpu"
+        algorithm. The third elements should be in increasing order across
+        tuples and end with -1, which indicates infinity.
+ num_packs: see AllReduceCrossTowerOps.
+ agg_small_grads_max_bytes: see AllReduceCrossTowerOps.
+ agg_small_grads_max_group: see AllReduceCrossTowerOps.
+ """
+ self._worker_devices = worker_devices
+ self._num_gpus_per_worker = num_gpus_per_worker
+ super(MultiWorkerAllReduce, self).__init__(
+ num_packs=num_packs,
+ agg_small_grads_max_bytes=agg_small_grads_max_bytes,
+ agg_small_grads_max_group=agg_small_grads_max_group)
+
+ def validate_and_complete_spec(spec):
+ """Validate and complete the all-reduce spec."""
+ # TODO(yuefengz): support namedtuple.
+ if not isinstance(spec, tuple):
+ raise ValueError(
+ "A tuple is expected for all-reduce spec: %r" % all_reduce_spec)
+ if not spec or len(spec) > 3:
+        raise ValueError(
+            "The all-reduce spec tuple must have 1 to 3 elements: %r" % spec)
+ if len(spec) == 1:
+ return AllReduceSpecTuple(spec[0], 1, -1)
+ elif len(spec) == 2:
+ return AllReduceSpecTuple(spec[0], spec[1], -1)
+ else:
+ return AllReduceSpecTuple(*spec)
+
+ self._all_reduce_spec = []
+ if isinstance(all_reduce_spec, six.string_types):
+ self._all_reduce_spec.append(AllReduceSpecTuple(all_reduce_spec, 1, -1))
+ elif isinstance(all_reduce_spec, tuple):
+ self._all_reduce_spec.append(validate_and_complete_spec(all_reduce_spec))
+ elif isinstance(all_reduce_spec, list):
+ self._all_reduce_spec = [
+ validate_and_complete_spec(spec) for spec in all_reduce_spec
+ ]
+
+ def _batch_all_reduce(self, method_string, per_device_values):
+ """All reduce algorithm in a batch."""
+ logging.info(
+ "distributed batch_all_reduce invoked for batches size = %d with "
+ "allreduce_spec = %r, num_packs = %d, agg_small_grads_max_bytes = %d "
+ "and agg_small_grads_max_group = %d", len(per_device_values),
+ self._all_reduce_spec, self._num_packs, self._agg_small_grads_max_bytes,
+ self._agg_small_grads_max_group)
+
+ destinations = sorted(per_device_values[0].devices)
+ device_grads = _group_value_by_device(per_device_values)
+
+ # The all reduce library requires fully defined shapes.
+ # TODO(yuefengz): when tensor sharding is not needed, static shapes are not
+ # required as well.
+ for device_grad in device_grads:
+ for grad, _ in device_grad:
+ if not grad.shape.is_fully_defined():
+ raise ValueError("Shape is unknown for node %r" % grad)
+
+ remaining_grads = device_grads
+ aggregated_grads = []
+ for spec_tuple in self._all_reduce_spec:
+ if spec_tuple.limit < 0:
+ this_grads = remaining_grads
+ remaining_grads = []
+ else:
+ (this_grads, remaining_grads) = cross_tower_utils.split_grads_by_size(
+ spec_tuple.limit, remaining_grads)
+ if this_grads:
+ device_grad_packs, self._tensor_packer = _pack_tensors(
+ this_grads, self._num_packs, self._agg_small_grads_max_bytes,
+ self._agg_small_grads_max_group)
+ range_agg_grads = cross_tower_utils.sum_gradients_all_reduce(
+ self._worker_devices, device_grad_packs, len(self._worker_devices),
+ spec_tuple.alg, spec_tuple.shards, range(self._num_gpus_per_worker))
+ range_agg_grads = _unpack_tensors(range_agg_grads, self._tensor_packer)
+
+ if not aggregated_grads:
+ aggregated_grads = range_agg_grads
+ else:
+ assert len(aggregated_grads) == len(range_agg_grads)
+ for i in range(len(aggregated_grads)):
+ aggregated_grads[i] += range_agg_grads[i]
+ assert not remaining_grads
+
+ return _ungroup_and_make_mirrored(aggregated_grads, destinations,
+ method_string)
+
+
_dgx1_links = [[1, 2, 3, 4], [0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7],
[0, 5, 6, 7], [1, 4, 6, 7], [2, 4, 5, 7], [3, 4, 5, 6]]
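To make the `all_reduce_spec` tiers of `MultiWorkerAllReduce` concrete, here is an illustrative-only routing sketch; the real code splits (gradient, variable) pairs with `cross_tower_utils.split_grads_by_size`, but tier selection follows the same limits:

import collections

AllReduceSpecTuple = collections.namedtuple("AllReduceSpecTuple",
                                            "alg shards limit")

spec = [AllReduceSpecTuple("nccl", 2, 1024),
        AllReduceSpecTuple("pscpu/pscpu", 2, -1)]


def pick_alg(tensor_size, spec):
  # Tiers are listed with increasing limits and end with -1 (infinity).
  for tier in spec:
    if tier.limit < 0 or tensor_size <= tier.limit:
      return tier.alg
  raise ValueError("all_reduce_spec must end with a -1 catch-all tier")


assert pick_alg(512, spec) == "nccl"
assert pick_alg(4096, spec) == "pscpu/pscpu"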
diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py b/tensorflow/contrib/distribute/python/cross_tower_ops_test.py
index 2a26632608..fed5505d92 100644
--- a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py
+++ b/tensorflow/contrib/distribute/python/cross_tower_ops_test.py
@@ -24,6 +24,7 @@ from absl.testing import parameterized
from tensorflow.contrib.distribute.python import combinations
from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib
+from tensorflow.contrib.distribute.python import multi_worker_test_base
from tensorflow.contrib.distribute.python import values as value_lib
from tensorflow.python.eager import context
from tensorflow.python.eager import test
@@ -75,7 +76,7 @@ def _make_mirrored_indexed_slices(devices, values, indices, dense_shape):
_cpu_device = "/device:CPU:0"
-class CrossTowerOpsTest(test.TestCase, parameterized.TestCase):
+class CrossTowerOpsTestBase(test.TestCase, parameterized.TestCase):
def _assert_indexed_slices_equal(self, left, right):
self.assertIsInstance(left, ops.IndexedSlices)
@@ -94,7 +95,7 @@ class CrossTowerOpsTest(test.TestCase, parameterized.TestCase):
self.assertEqual(type(left), type(right))
self.assertEqual(left.devices, right.devices)
if isinstance(list(left._index.values())[0], ops.IndexedSlices):
- for (d, v) in left._index.iteritems():
+ for (d, v) in left._index.items():
self._assert_indexed_slices_equal(v, right._index[d])
elif context.executing_eagerly():
self.assertEqual([v.numpy() for v in left._index.values()],
@@ -104,51 +105,7 @@ class CrossTowerOpsTest(test.TestCase, parameterized.TestCase):
self.assertEqual(
sess.run(list(left._index.values())), list(right._index.values()))
- # TODO(yuefengz): decouple the num_gpus check from distribution in
- # combinations module so that we can pass in devices instead of a distribution
- # strategy.
- reduction_to_one_combinations = combinations.combine(
- cross_tower_ops=[
- combinations.NamedObject(
- "DefaultReductionToOneDeviceCrossTowerOps",
- cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps()),
- combinations.NamedObject(
- "ReductionToCPUDeviceCrossTowerOps",
- cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps(
- reduce_to_device=_cpu_device)),
- combinations.NamedObject(
- "AccumulateNCrossTowerOp",
- cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps(
- accumulation_fn=math_ops.accumulate_n)),
- ],
- distribution=[
- combinations.one_device_strategy,
- combinations.mirrored_strategy_with_gpu_and_cpu,
- combinations.mirrored_strategy_with_two_gpus
- ],
- mode=["graph", "eager"])
- allreduce_combinations = combinations.combine(
- cross_tower_ops=[
- combinations.NamedObject(
- "AllReduce",
- cross_tower_ops_lib.AllReduceCrossTowerOps("nccl", 1, 0, 0)),
- combinations.NamedObject(
- "HierarchicalCopy",
- cross_tower_ops_lib.AllReduceCrossTowerOps(
- "hierarchical_copy", 8, 0, 0)),
- combinations.NamedObject(
- "AllReduceNoGradientRepacking",
- cross_tower_ops_lib.AllReduceCrossTowerOps("nccl", 0, 0, 0)),
- combinations.NamedObject(
- "HierarchicalCopyAggregateSmallTensors",
- cross_tower_ops_lib.AllReduceCrossTowerOps(
- "hierarchical_copy", 0, 100, 10))
- ],
- distribution=[combinations.mirrored_strategy_with_two_gpus],
- mode=["graph", "eager"])
-
- @combinations.generate(reduction_to_one_combinations + allreduce_combinations)
- def testReductionAndBroadcast(self, cross_tower_ops, distribution):
+ def _testReductionAndBroadcast(self, cross_tower_ops, distribution):
devices = distribution.worker_devices
values = [constant_op.constant(float(d)) for d in range(len(devices))]
@@ -208,20 +165,70 @@ class CrossTowerOpsTest(test.TestCase, parameterized.TestCase):
cross_tower_ops.broadcast(constant_op.constant(1.), destinations),
_fake_mirrored(1., destinations))
+
+class SingleWorkerCrossTowerOpsTest(CrossTowerOpsTestBase):
+ # TODO(yuefengz): decouple the num_gpus check from distribution in
+ # combinations module so that we can pass in devices instead of a distribution
+ # strategy.
+ reduction_to_one_combinations = combinations.combine(
+ cross_tower_ops=[
+ combinations.NamedObject(
+ "DefaultReductionToOneDeviceCrossTowerOps",
+ cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps()),
+ combinations.NamedObject(
+ "ReductionToCPUDeviceCrossTowerOps",
+ cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps(
+ reduce_to_device=_cpu_device)),
+ combinations.NamedObject(
+ "AccumulateNCrossTowerOp",
+ cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps(
+ accumulation_fn=math_ops.accumulate_n)),
+ ],
+ distribution=[
+ combinations.one_device_strategy,
+ combinations.mirrored_strategy_with_gpu_and_cpu,
+ combinations.mirrored_strategy_with_two_gpus
+ ],
+ mode=["graph", "eager"])
+ allreduce_combinations = combinations.combine(
+ cross_tower_ops=[
+ combinations.NamedObject(
+ "AllReduce",
+ cross_tower_ops_lib.AllReduceCrossTowerOps("nccl", 1, 0, 0)),
+ combinations.NamedObject(
+ "HierarchicalCopy",
+ cross_tower_ops_lib.AllReduceCrossTowerOps(
+ "hierarchical_copy", 8, 0, 0)),
+ combinations.NamedObject(
+ "AllReduceNoGradientRepacking",
+ cross_tower_ops_lib.AllReduceCrossTowerOps("nccl", 0, 0, 0)),
+ combinations.NamedObject(
+ "HierarchicalCopyAggregateSmallTensors",
+ cross_tower_ops_lib.AllReduceCrossTowerOps(
+ "hierarchical_copy", 0, 100, 10))
+ ],
+ distribution=[combinations.mirrored_strategy_with_two_gpus],
+ mode=["graph", "eager"])
+
+ @combinations.generate(reduction_to_one_combinations + allreduce_combinations)
+ def testReductionAndBroadcast(self, cross_tower_ops, distribution):
+ with distribution.scope():
+ self._testReductionAndBroadcast(cross_tower_ops, distribution)
+
def testChooseAlgorithm(self):
device_links = [[1, 2, 3, 4], [0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7],
[0, 5, 6, 7], [1, 4, 6, 7], [2, 4, 5, 7], [3, 4, 5, 6]]
result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links)
self.assertIsInstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps)
- self.assertEqual(result.all_reduce_alg, "hierarchical_copy")
- self.assertEqual(result.num_packs, 8)
+ self.assertEqual(result._all_reduce_alg, "hierarchical_copy")
+ self.assertEqual(result._num_packs, 8)
# if there are only 4 devices
device_links = [[1, 2, 3, 4], [0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7]]
result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links)
self.assertIsInstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps)
- self.assertEqual(result.all_reduce_alg, "nccl")
- self.assertEqual(result.num_packs, 1)
+ self.assertEqual(result._all_reduce_alg, "nccl")
+ self.assertEqual(result._num_packs, 1)
# if devices links contain each device itself
device_links = [[0, 1, 2, 3, 4], [0, 1, 2, 3, 5], [0, 1, 2, 3, 6],
@@ -229,16 +236,16 @@ class CrossTowerOpsTest(test.TestCase, parameterized.TestCase):
[2, 4, 5, 6, 7], [3, 4, 5, 6, 7]]
result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links)
self.assertIsInstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps)
- self.assertEqual(result.all_reduce_alg, "hierarchical_copy")
- self.assertEqual(result.num_packs, 8)
+ self.assertEqual(result._all_reduce_alg, "hierarchical_copy")
+ self.assertEqual(result._num_packs, 8)
# if not dgx1-like links
device_links = [[0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7], [0, 5, 6, 7],
[1, 4, 6, 7], [2, 4, 5, 7], [3, 4, 5, 6], [1, 2, 3, 4]]
result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links)
self.assertIsInstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps)
- self.assertEqual(result.all_reduce_alg, "nccl")
- self.assertEqual(result.num_packs, 1)
+ self.assertEqual(result._all_reduce_alg, "nccl")
+ self.assertEqual(result._num_packs, 1)
@combinations.generate(combinations.combine(
mode=["graph", "eager"],
@@ -316,5 +323,44 @@ class CrossTowerOpsTest(test.TestCase, parameterized.TestCase):
self._assert_values_equal(total_mirrored_without_dups, result)
+class MultiWorkerCrossTowerOpsTest(multi_worker_test_base.MultiWorkerTestBase,
+ CrossTowerOpsTestBase):
+
+ worker_devices = [
+ "/job:worker/replica:0/task:0", "/job:worker/replica:0/task:1"
+ ]
+ multi_worker_allreduce_combinations = combinations.combine(
+ cross_tower_ops=[
+ combinations.NamedObject(
+ "MultiWorkerAllReduce",
+ cross_tower_ops_lib.MultiWorkerAllReduce(
+ worker_devices, 2, ("pscpu/pscpu", 2, -1), 0, 0, 0)),
+ combinations.NamedObject(
+ "MultiWorkerAllReducePack",
+ cross_tower_ops_lib.MultiWorkerAllReduce(
+ worker_devices, 2, ("pscpu/pscpu", 2, -1), 1, 0, 0)),
+ combinations.NamedObject(
+ "MultiWorkerAllReduceAggregation",
+ cross_tower_ops_lib.MultiWorkerAllReduce(
+ worker_devices, 2, ("pscpu/pscpu", 2, -1), 0, 100, 10)),
+ combinations.NamedObject(
+ "MultiWorkerAllReduceMultipleSpecs",
+ cross_tower_ops_lib.MultiWorkerAllReduce(
+ worker_devices, 2, [("pscpu/pscpu", 2, 100),
+ ("xring", 2, -1)], 0, 0, 0)),
+ ],
+ distribution=[
+ combinations.multi_worker_strategy_with_cpu,
+ combinations.multi_worker_strategy_with_one_gpu,
+ combinations.multi_worker_strategy_with_two_gpus
+ ],
+ mode=["graph"])
+
+ @combinations.generate(multi_worker_allreduce_combinations)
+ def testReductionAndBroadcast(self, cross_tower_ops, distribution):
+ with distribution.scope():
+ self._testReductionAndBroadcast(cross_tower_ops, distribution)
+
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/distribute/python/cross_tower_utils.py b/tensorflow/contrib/distribute/python/cross_tower_utils.py
index 137fabf4c7..2bb088e704 100644
--- a/tensorflow/contrib/distribute/python/cross_tower_utils.py
+++ b/tensorflow/contrib/distribute/python/cross_tower_utils.py
@@ -21,6 +21,7 @@ from __future__ import print_function
import collections as pycoll
from tensorflow.contrib import nccl
+from tensorflow.contrib.all_reduce.python import all_reduce
from tensorflow.contrib.distribute.python import values as value_lib
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@@ -158,6 +159,148 @@ def aggregate_single_gradient_using_copy(grad_and_vars, use_mean,
return (grad, v), None
+def group_device_names(devices, group_size):
+ """Group device names into groups of group_size.
+
+ Args:
+ devices: a list of canonical device strings.
+ group_size: integer which is equal to or greater than 1.
+
+ Returns:
+ list of lists of devices, where each inner list is group_size long,
+ and each device appears at least once in an inner list. If
+ len(devices) % group_size == 0 then each device will appear exactly once.
+
+ Raises:
+ ValueError: if group_size > len(devices)
+ """
+ num_devices = len(devices)
+ if group_size > num_devices:
+ raise ValueError(
+ 'only %d devices, but group_size=%d' % (num_devices, group_size))
+ num_groups = (
+ num_devices // group_size + (1 if (num_devices % group_size != 0) else 0))
+ groups = [[] for i in range(num_groups)]
+ for i in range(num_groups * group_size):
+ groups[i % num_groups].append(devices[i % num_devices])
+ return groups
+
+
+def split_grads_by_size(threshold_size, device_grads):
+ """Break gradients into two sets according to tensor size.
+
+ Args:
+ threshold_size: int size cutoff for small vs large tensor.
+ device_grads: List of lists of (gradient, variable) tuples. The outer
+ list is over devices. The inner list is over individual gradients.
+
+ Returns:
+ small_grads: Subset of device_grads where shape is <= threshold_size
+ elements.
+ large_grads: Subset of device_grads where shape is > threshold_size
+ elements.
+ """
+ small_grads = []
+ large_grads = []
+ for dl in device_grads:
+ small_dl = []
+ large_dl = []
+ for (g, v) in dl:
+ tensor_size = g.get_shape().num_elements()
+ if tensor_size <= threshold_size:
+ small_dl.append([g, v])
+ else:
+ large_dl.append([g, v])
+ if small_dl:
+ small_grads.append(small_dl)
+ if large_dl:
+ large_grads.append(large_dl)
+ return small_grads, large_grads
+
+
+def sum_grad_and_var_all_reduce(grad_and_vars,
+ num_workers,
+ alg,
+ gpu_indices,
+ aux_devices=None,
+ num_shards=1):
+ """Apply all-reduce algorithm over specified gradient tensors."""
+ with ops.name_scope('allreduce'):
+ # Note that each grad_and_vars looks like the following:
+ # ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN))
+ scaled_grads = [g for g, _ in grad_and_vars]
+ if alg == 'nccl':
+ summed_grads = nccl.all_sum(scaled_grads)
+ elif alg == 'xring':
+ summed_grads = all_reduce.build_ring_all_reduce(
+ scaled_grads, num_workers, num_shards, gpu_indices, math_ops.add)
+ elif alg == 'nccl/xring':
+ summed_grads = all_reduce.build_nccl_then_ring(scaled_grads, num_shards,
+ math_ops.add)
+ elif alg == 'nccl/rechd':
+ summed_grads = all_reduce.build_nccl_then_recursive_hd(
+ scaled_grads, math_ops.add)
+ elif alg == 'nccl/pscpu':
+ summed_grads = all_reduce.build_nccl_then_shuffle(
+ scaled_grads, aux_devices, math_ops.add, math_ops.add_n)
+ elif alg == 'pscpu/pscpu':
+ second_gather_devices = aux_devices[:num_shards]
+ summed_grads = all_reduce.build_shuffle_then_shuffle(
+ scaled_grads, aux_devices, second_gather_devices, math_ops.add_n)
+ elif alg in ['pscpu', 'psgpu']:
+ summed_grads = all_reduce.build_shuffle_all_reduce(
+ scaled_grads, aux_devices, math_ops.add_n)
+ else:
+      raise ValueError('unsupported all_reduce alg: %s' % alg)
+
+ result = []
+ for (_, v), g in zip(grad_and_vars, summed_grads):
+ result.append([g, v])
+ return result
+
+
+def sum_gradients_all_reduce(dev_prefixes, tower_grads, num_workers, alg,
+ num_shards, gpu_indices):
+ """Apply all-reduce algorithm over specified gradient tensors.
+
+ Args:
+ dev_prefixes: list of prefix strings to use to generate PS device names.
+ tower_grads: the gradients to reduce.
+    num_workers: number of worker processes across the entire job.
+    alg: the all-reduce algorithm to apply.
+    num_shards: alg-specific sharding factor.
+    gpu_indices: indices of local GPUs, in an order usable for ring-reduce.
+
+ Returns:
+ list of reduced tensors
+ """
+ alg_contains_shuffle = any([n in alg for n in ['pscpu', 'psgpu']])
+ is_hierarchical = '/' in alg
+ if 'pscpu' in alg:
+ aux_devices = [prefix + '/cpu:0' for prefix in dev_prefixes]
+ elif 'psgpu' in alg:
+ aux_devices = [
+ prefix + '/gpu:%d' % i
+ for i in range(len(gpu_indices))
+ for prefix in dev_prefixes
+ ]
+ else:
+ aux_devices = ['/job:localhost/cpu:0']
+ # Auxiliary devices for hierarchical all-reduces.
+ aux_device_groups = group_device_names(
+ aux_devices, num_shards if alg_contains_shuffle else 1)
+ group_index = 0
+ reduced_gv_list = []
+ for grad_and_vars in zip(*tower_grads):
+ reduced_gv_list.append(
+ sum_grad_and_var_all_reduce(
+ grad_and_vars, num_workers, alg, gpu_indices, aux_devices
+ if is_hierarchical else aux_device_groups[group_index], num_shards))
+ group_index = (group_index + 1) % len(aux_device_groups)
+ new_tower_grads = [list(x) for x in zip(*reduced_gv_list)]
+ return new_tower_grads
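As a sketch of the transposition that drives the loop above, `zip(*tower_grads)` regroups the per-device lists so that each iteration sees every device's gradient for a single variable. The names below are hypothetical placeholders:

```python
# Hypothetical per-device (gradient, variable) lists; "g0_d1" reads as
# "gradient of variable 0 computed on device 1".
tower_grads = [
    [('g0_d0', 'v0'), ('g1_d0', 'v1')],   # device 0
    [('g0_d1', 'v0'), ('g1_d1', 'v1')],   # device 1
]
per_variable = list(zip(*tower_grads))
# per_variable[0] == (('g0_d0', 'v0'), ('g0_d1', 'v0'))  -> reduced together
# per_variable[1] == (('g1_d0', 'v1'), ('g1_d1', 'v1'))  -> reduced together
```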
+
+
def extract_ranges(index_list, range_size_limit=32):
"""Extract consecutive ranges and singles from index_list.
@@ -330,7 +473,7 @@ def unpack_small_tensors(tower_grads, packing):
for dev_idx, gv_list in enumerate(tower_grads):
gv_list = list(gv_list)
new_gv_list = gv_list[num_packed:]
- for i in xrange(0, num_packed):
+ for i in range(num_packed):
k = '%d:%d' % (dev_idx, i)
gpt = packing[k]
gv = unpack_grad_tuple(gv_list[i], gpt)
diff --git a/tensorflow/contrib/distribute/python/multi_worker_strategy.py b/tensorflow/contrib/distribute/python/multi_worker_strategy.py
index a552b370eb..0f21a42732 100644
--- a/tensorflow/contrib/distribute/python/multi_worker_strategy.py
+++ b/tensorflow/contrib/distribute/python/multi_worker_strategy.py
@@ -121,7 +121,7 @@ class MultiWorkerMirroredStrategy(MirroredStrategy):
worker: [device_util.canonicalize(worker, '/device:CPU:0')]
for worker in self._workers
}
- self._devices = nest.flatten(self._worker_device_map.values())
+ self._devices = nest.flatten(self._worker_device_map)
super(MultiWorkerMirroredStrategy, self).__init__(
devices=self._devices, prefetch_on_device=prefetch_on_device)
diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD
index 23d9dbcd91..51f7028566 100644
--- a/tensorflow/contrib/distributions/BUILD
+++ b/tensorflow/contrib/distributions/BUILD
@@ -941,6 +941,25 @@ cuda_py_test(
)
cuda_py_test(
+ name = "fill_triangular_test",
+ size = "small",
+ srcs = ["python/kernel_tests/bijectors/fill_triangular_test.py"],
+ additional_deps = [
+ ":bijectors_py",
+ ":distributions_py",
+ "//third_party/py/numpy",
+ "@six_archive//:six",
+ "//tensorflow/contrib/linalg:linalg_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+cuda_py_test(
name = "gumbel_test",
size = "small",
srcs = ["python/kernel_tests/bijectors/gumbel_test.py"],
@@ -1119,6 +1138,25 @@ cuda_py_test(
)
cuda_py_test(
+ name = "scale_tril_test",
+ size = "small",
+ srcs = ["python/kernel_tests/bijectors/scale_tril_test.py"],
+ additional_deps = [
+ ":bijectors_py",
+ ":distributions_py",
+ "//third_party/py/numpy",
+ "@six_archive//:six",
+ "//tensorflow/contrib/linalg:linalg_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+cuda_py_test(
name = "sigmoid_test",
size = "small",
srcs = ["python/kernel_tests/bijectors/sigmoid_test.py"],
@@ -1236,6 +1274,25 @@ cuda_py_test(
)
cuda_py_test(
+ name = "transform_diagonal_test",
+ size = "small",
+ srcs = ["python/kernel_tests/bijectors/transform_diagonal_test.py"],
+ additional_deps = [
+ ":bijectors_py",
+ ":distributions_py",
+ "//third_party/py/numpy",
+ "@six_archive//:six",
+ "//tensorflow/contrib/linalg:linalg_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+cuda_py_test(
name = "weibull_test",
size = "small",
srcs = ["python/kernel_tests/bijectors/weibull_test.py"],
diff --git a/tensorflow/contrib/distributions/__init__.py b/tensorflow/contrib/distributions/__init__.py
index 802538ba97..5cec93c4df 100644
--- a/tensorflow/contrib/distributions/__init__.py
+++ b/tensorflow/contrib/distributions/__init__.py
@@ -13,8 +13,6 @@
# limitations under the License.
# ==============================================================================
"""Classes representing statistical distributions and ops for working with them.
-
-See the @{$python/contrib.distributions} guide.
"""
from __future__ import absolute_import
from __future__ import division
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/fill_triangular_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/fill_triangular_test.py
new file mode 100644
index 0000000000..caeaf2a0c6
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/fill_triangular_test.py
@@ -0,0 +1,98 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for FillTriangular bijector."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.distributions.python.ops import bijectors
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+
+
+class FillTriangularBijectorTest(test.TestCase):
+ """Tests the correctness of the FillTriangular bijector."""
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testBijector(self):
+ x = np.float32(np.array([1., 2., 3.]))
+ y = np.float32(np.array([[3., 0.],
+ [2., 1.]]))
+
+ b = bijectors.FillTriangular()
+
+ y_ = self.evaluate(b.forward(x))
+ self.assertAllClose(y, y_)
+
+ x_ = self.evaluate(b.inverse(y))
+ self.assertAllClose(x, x_)
+
+ fldj = self.evaluate(b.forward_log_det_jacobian(x, event_ndims=1))
+ self.assertAllClose(fldj, 0.)
+
+ ildj = self.evaluate(b.inverse_log_det_jacobian(y, event_ndims=2))
+ self.assertAllClose(ildj, 0.)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testShape(self):
+ x_shape = tensor_shape.TensorShape([5, 4, 6])
+ y_shape = tensor_shape.TensorShape([5, 4, 3, 3])
+
+ b = bijectors.FillTriangular(validate_args=True)
+
+ x = array_ops.ones(shape=x_shape, dtype=dtypes.float32)
+ y_ = b.forward(x)
+ self.assertAllEqual(y_.shape.as_list(), y_shape.as_list())
+ x_ = b.inverse(y_)
+ self.assertAllEqual(x_.shape.as_list(), x_shape.as_list())
+
+ y_shape_ = b.forward_event_shape(x_shape)
+ self.assertAllEqual(y_shape_.as_list(), y_shape.as_list())
+ x_shape_ = b.inverse_event_shape(y_shape)
+ self.assertAllEqual(x_shape_.as_list(), x_shape.as_list())
+
+ y_shape_tensor = self.evaluate(
+ b.forward_event_shape_tensor(x_shape.as_list()))
+ self.assertAllEqual(y_shape_tensor, y_shape.as_list())
+ x_shape_tensor = self.evaluate(
+ b.inverse_event_shape_tensor(y_shape.as_list()))
+ self.assertAllEqual(x_shape_tensor, x_shape.as_list())
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testShapeError(self):
+
+ b = bijectors.FillTriangular(validate_args=True)
+
+ x_shape_bad = tensor_shape.TensorShape([5, 4, 7])
+ with self.assertRaisesRegexp(ValueError, "is not a triangular number"):
+ b.forward_event_shape(x_shape_bad)
+ with self.assertRaisesOpError("is not a triangular number"):
+ self.evaluate(b.forward_event_shape_tensor(x_shape_bad.as_list()))
+
+ y_shape_bad = tensor_shape.TensorShape([5, 4, 3, 2])
+ with self.assertRaisesRegexp(ValueError, "Matrix must be square"):
+ b.inverse_event_shape(y_shape_bad)
+ with self.assertRaisesOpError("Matrix must be square"):
+ self.evaluate(b.inverse_event_shape_tensor(y_shape_bad.as_list()))
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/scale_tril_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/scale_tril_test.py
new file mode 100644
index 0000000000..566a7b3dff
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/scale_tril_test.py
@@ -0,0 +1,69 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for ScaleTriL bijector."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.distributions.python.ops import bijectors
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import test
+
+
+class ScaleTriLBijectorTest(test.TestCase):
+ """Tests the correctness of the ScaleTriL bijector."""
+
+ def setUp(self):
+ self._rng = np.random.RandomState(42)
+
+ def testComputesCorrectValues(self):
+ shift = 1.61803398875
+ x = np.float32(np.array([-1, .5, 2]))
+ y = np.float32(np.array([[np.exp(2) + shift, 0.],
+ [.5, np.exp(-1) + shift]]))
+
+ b = bijectors.ScaleTriL(diag_bijector=bijectors.Exp(),
+ diag_shift=shift)
+
+ y_ = self.evaluate(b.forward(x))
+ self.assertAllClose(y, y_)
+
+ x_ = self.evaluate(b.inverse(y))
+ self.assertAllClose(x, x_)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testInvertible(self):
+
+ # Generate random inputs from an unconstrained space, with
+ # event size 6 to specify 3x3 triangular matrices.
+ batch_shape = [2, 1]
+ x = np.float32(np.random.randn(*(batch_shape + [6])))
+ b = bijectors.ScaleTriL(diag_bijector=bijectors.Softplus(),
+ diag_shift=3.14159)
+ y = self.evaluate(b.forward(x))
+ self.assertAllEqual(y.shape, batch_shape + [3, 3])
+
+ x_ = self.evaluate(b.inverse(y))
+ self.assertAllClose(x, x_)
+
+ fldj = self.evaluate(b.forward_log_det_jacobian(x, event_ndims=1))
+ ildj = self.evaluate(b.inverse_log_det_jacobian(y, event_ndims=2))
+ self.assertAllClose(fldj, -ildj)
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sinh_arcsinh_bijector_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sinh_arcsinh_bijector_test.py
index 45760a29ee..795f1993ba 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sinh_arcsinh_bijector_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sinh_arcsinh_bijector_test.py
@@ -151,16 +151,24 @@ class SinhArcsinhBijectorTest(test.TestCase):
self.assertAllClose(y, bijector.forward(x).eval(), rtol=1e-4, atol=0.)
self.assertAllClose(x, bijector.inverse(y).eval(), rtol=1e-4, atol=0.)
- # Do the numpy calculation in float128 to avoid inf/nan.
- y_float128 = np.float128(y)
- self.assertAllClose(
- np.log(np.cosh(
- np.arcsinh(y_float128) / tailweight - skewness) / np.sqrt(
- y_float128**2 + 1)) -
- np.log(tailweight),
- bijector.inverse_log_det_jacobian(y, event_ndims=0).eval(),
- rtol=1e-4,
- atol=0.)
+ # On IBM PPC systems, longdouble (np.float128) is the same as double except
+ # that it can carry more precision. A double, being 8 bytes, cannot hold the
+ # square of the maximum float64 value (also 8 bytes), so the test below fails
+ # with an overflow to inf. The check below avoids that error by skipping the
+ # square calculation and the corresponding assert.
+
+ if np.amax(y) <= np.sqrt(np.finfo(np.float128).max) and \
+ np.fabs(np.amin(y)) <= np.sqrt(np.fabs(np.finfo(np.float128).min)):
+
+ # Do the numpy calculation in float128 to avoid inf/nan.
+ y_float128 = np.float128(y)
+ self.assertAllClose(
+ np.log(np.cosh(
+ np.arcsinh(y_float128) / tailweight - skewness) / np.sqrt(
+ y_float128**2 + 1)) -
+ np.log(tailweight),
+ bijector.inverse_log_det_jacobian(y, event_ndims=0).eval(),
+ rtol=1e-4,
+ atol=0.)
self.assertAllClose(
-bijector.inverse_log_det_jacobian(y, event_ndims=0).eval(),
bijector.forward_log_det_jacobian(x, event_ndims=0).eval(),
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/transform_diagonal_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/transform_diagonal_test.py
new file mode 100644
index 0000000000..6428a68702
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/transform_diagonal_test.py
@@ -0,0 +1,66 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for TransformDiagonal bijector."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.distributions.python.ops import bijectors
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import test
+
+
+class TransformDiagonalBijectorTest(test.TestCase):
+ """Tests correctness of the TransformDiagonal bijector."""
+
+ def setUp(self):
+ self._rng = np.random.RandomState(42)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testBijector(self):
+ x = np.float32(np.random.randn(3, 4, 4))
+
+ y = x.copy()
+ for i in range(x.shape[0]):
+ np.fill_diagonal(y[i, :, :], np.exp(np.diag(x[i, :, :])))
+
+ exp = bijectors.Exp()
+ b = bijectors.TransformDiagonal(diag_bijector=exp)
+
+ y_ = self.evaluate(b.forward(x))
+ self.assertAllClose(y, y_)
+
+ x_ = self.evaluate(b.inverse(y))
+ self.assertAllClose(x, x_)
+
+ fldj = self.evaluate(b.forward_log_det_jacobian(x, event_ndims=2))
+ ildj = self.evaluate(b.inverse_log_det_jacobian(y, event_ndims=2))
+ self.assertAllEqual(
+ fldj,
+ self.evaluate(exp.forward_log_det_jacobian(
+ np.array([np.diag(x_mat) for x_mat in x]),
+ event_ndims=1)))
+ self.assertAllEqual(
+ ildj,
+ self.evaluate(exp.inverse_log_det_jacobian(
+ np.array([np.diag(y_mat) for y_mat in y]),
+ event_ndims=1)))
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py b/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py
index 4965381ef3..e141f8b5c6 100644
--- a/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py
@@ -24,6 +24,7 @@
@@CholeskyOuterProduct
@@ConditionalBijector
@@Exp
+@@FillTriangular
@@Gumbel
@@Identity
@@Inline
@@ -36,12 +37,14 @@
@@PowerTransform
@@RealNVP
@@Reshape
+@@ScaleTriL
@@Sigmoid
@@SinhArcsinh
@@SoftmaxCentered
@@Softplus
@@Softsign
@@Square
+@@TransformDiagonal
@@Weibull
@@masked_autoregressive_default_template
@@ -64,6 +67,7 @@ from tensorflow.contrib.distributions.python.ops.bijectors.chain import *
from tensorflow.contrib.distributions.python.ops.bijectors.cholesky_outer_product import *
from tensorflow.contrib.distributions.python.ops.bijectors.conditional_bijector import *
from tensorflow.contrib.distributions.python.ops.bijectors.exp import *
+from tensorflow.contrib.distributions.python.ops.bijectors.fill_triangular import *
from tensorflow.contrib.distributions.python.ops.bijectors.gumbel import *
from tensorflow.contrib.distributions.python.ops.bijectors.inline import *
from tensorflow.contrib.distributions.python.ops.bijectors.invert import *
@@ -75,12 +79,14 @@ from tensorflow.contrib.distributions.python.ops.bijectors.permute import *
from tensorflow.contrib.distributions.python.ops.bijectors.power_transform import *
from tensorflow.contrib.distributions.python.ops.bijectors.real_nvp import *
from tensorflow.contrib.distributions.python.ops.bijectors.reshape import *
+from tensorflow.contrib.distributions.python.ops.bijectors.scale_tril import *
from tensorflow.contrib.distributions.python.ops.bijectors.sigmoid import *
from tensorflow.contrib.distributions.python.ops.bijectors.sinh_arcsinh import *
from tensorflow.contrib.distributions.python.ops.bijectors.softmax_centered import *
from tensorflow.contrib.distributions.python.ops.bijectors.softplus import *
from tensorflow.contrib.distributions.python.ops.bijectors.softsign import *
from tensorflow.contrib.distributions.python.ops.bijectors.square import *
+from tensorflow.contrib.distributions.python.ops.bijectors.transform_diagonal import *
from tensorflow.python.ops.distributions.bijector import *
from tensorflow.python.ops.distributions.identity_bijector import Identity
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/fill_triangular.py b/tensorflow/contrib/distributions/python/ops/bijectors/fill_triangular.py
new file mode 100644
index 0000000000..7b06325ead
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/fill_triangular.py
@@ -0,0 +1,148 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""FillTriangular bijector."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops.distributions import bijector
+from tensorflow.python.ops.distributions import util as dist_util
+
+
+__all__ = [
+ "FillTriangular",
+]
+
+
+class FillTriangular(bijector.Bijector):
+ """Transforms vectors to triangular.
+
+ Triangular matrix elements are filled in a clockwise spiral.
+
+ Given input with shape `batch_shape + [d]`, produces output with
+ shape `batch_shape + [n, n]`, where
+ `n = (-1 + sqrt(1 + 8 * d))/2`.
+ This follows by solving the quadratic equation
+ `d = 1 + 2 + ... + n = n * (n + 1)/2`.
+
+ #### Example
+
+ ```python
+ b = tfb.FillTriangular(upper=False)
+ b.forward([1, 2, 3, 4, 5, 6])
+ # ==> [[4, 0, 0],
+ # [6, 5, 0],
+ # [3, 2, 1]]
+
+ b = tfb.FillTriangular(upper=True)
+ b.forward([1, 2, 3, 4, 5, 6])
+ # ==> [[1, 2, 3],
+ # [0, 5, 6],
+ # [0, 0, 4]]
+
+ ```
+ """
+
+ def __init__(self,
+ upper=False,
+ validate_args=False,
+ name="fill_triangular"):
+ """Instantiates the `FillTriangular` bijector.
+
+ Args:
+ upper: Python `bool` representing whether output matrix should be upper
+ triangular (`True`) or lower triangular (`False`, default).
+ validate_args: Python `bool` indicating whether arguments should be
+ checked for correctness.
+ name: Python `str` name given to ops managed by this object.
+ """
+ self._upper = upper
+ super(FillTriangular, self).__init__(
+ forward_min_event_ndims=1,
+ inverse_min_event_ndims=2,
+ validate_args=validate_args,
+ name=name)
+
+ def _forward(self, x):
+ return dist_util.fill_triangular(x, upper=self._upper)
+
+ def _inverse(self, y):
+ return dist_util.fill_triangular_inverse(y, upper=self._upper)
+
+ def _forward_log_det_jacobian(self, x):
+ return array_ops.zeros_like(x[..., 0])
+
+ def _inverse_log_det_jacobian(self, y):
+ return array_ops.zeros_like(y[..., 0, 0])
+
+ def _forward_event_shape(self, input_shape):
+ batch_shape, d = input_shape[:-1], input_shape[-1].value
+ if d is None:
+ n = None
+ else:
+ n = vector_size_to_square_matrix_size(d, self.validate_args)
+ return batch_shape.concatenate([n, n])
+
+ def _inverse_event_shape(self, output_shape):
+ batch_shape, n1, n2 = (output_shape[:-2],
+ output_shape[-2].value,
+ output_shape[-1].value)
+ if n1 is None or n2 is None:
+ m = None
+ elif n1 != n2:
+ raise ValueError("Matrix must be square. (saw [{}, {}])".format(n1, n2))
+ else:
+ m = n1 * (n1 + 1) / 2
+ return batch_shape.concatenate([m])
+
+ def _forward_event_shape_tensor(self, input_shape_tensor):
+ batch_shape, d = input_shape_tensor[:-1], input_shape_tensor[-1]
+ n = vector_size_to_square_matrix_size(d, self.validate_args)
+ return array_ops.concat([batch_shape, [n, n]], axis=0)
+
+ def _inverse_event_shape_tensor(self, output_shape_tensor):
+ batch_shape, n = output_shape_tensor[:-2], output_shape_tensor[-1]
+ if self.validate_args:
+ is_square_matrix = check_ops.assert_equal(
+ n, output_shape_tensor[-2], message="Matrix must be square.")
+ with ops.control_dependencies([is_square_matrix]):
+ n = array_ops.identity(n)
+ d = math_ops.cast(n * (n + 1) / 2, output_shape_tensor.dtype)
+ return array_ops.concat([batch_shape, [d]], axis=0)
+
+
+def vector_size_to_square_matrix_size(d, validate_args, name=None):
+ """Convert a vector size to a matrix size."""
+ if isinstance(d, (float, int, np.generic, np.ndarray)):
+ n = (-1 + np.sqrt(1 + 8 * d)) / 2.
+ if float(int(n)) != n:
+ raise ValueError("Vector length is not a triangular number.")
+ return int(n)
+ else:
+ with ops.name_scope(name, "vector_size_to_square_matrix_size", [d]) as name:
+ n = (-1. + math_ops.sqrt(1 + 8. * math_ops.to_float(d))) / 2.
+ if validate_args:
+ with ops.control_dependencies([check_ops.assert_equal(
+ math_ops.to_float(math_ops.to_int32(n)), n,
+ message="Vector length is not a triangular number")]):
+ n = array_ops.identity(n)
+ return math_ops.cast(n, d.dtype)
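A quick numeric check of the triangular-number inversion used by `vector_size_to_square_matrix_size` above (NumPy only, illustrative):

```python
import numpy as np

for d, expected_n in [(1, 1), (3, 2), (6, 3), (10, 4)]:
  n = (-1 + np.sqrt(1 + 8 * d)) / 2.
  assert float(int(n)) == n and int(n) == expected_n

n = (-1 + np.sqrt(1 + 8 * 7)) / 2.   # d = 7 is not a triangular number
assert float(int(n)) != n            # n ~= 3.27, so validation would fail
```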
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/scale_tril.py b/tensorflow/contrib/distributions/python/ops/bijectors/scale_tril.py
new file mode 100644
index 0000000000..96bd242c63
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/scale_tril.py
@@ -0,0 +1,114 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""ScaleTriL bijector."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.distributions.python.ops.bijectors import affine_scalar
+from tensorflow.contrib.distributions.python.ops.bijectors import chain
+from tensorflow.contrib.distributions.python.ops.bijectors import fill_triangular
+from tensorflow.contrib.distributions.python.ops.bijectors import softplus
+from tensorflow.contrib.distributions.python.ops.bijectors import transform_diagonal
+
+__all__ = [
+ "ScaleTriL",
+]
+
+
+class ScaleTriL(chain.Chain):
+ """Transforms unconstrained vectors to TriL matrices with positive diagonal.
+
+ This is implemented as a simple `tfb.Chain` of `tfb.FillTriangular`
+ followed by `tfb.TransformDiagonal`, and provided mostly as a
+ convenience. The default setup is somewhat opinionated, using a
+ Softplus transformation followed by a small shift (`1e-5`) which
+ attempts to avoid numerical issues from zeros on the diagonal.
+
+ #### Examples
+
+ ```python
+ tfb = tf.contrib.distributions.bijectors
+ b = tfb.ScaleTriL(
+ diag_bijector=tfb.Exp(),
+ diag_shift=None)
+ b.forward(x=[0., 0., 0.])
+ # Result: [[1., 0.],
+ # [0., 1.]]
+ b.inverse(y=[[1., 0],
+ [.5, 2]])
+ # Result: [log(2), .5, log(1)]
+
+ # Define a distribution over PSD matrices of shape `[3, 3]`,
+ # with `1 + 2 + 3 = 6` degrees of freedom.
+ dist = tfd.TransformedDistribution(
+ tfd.Normal(tf.zeros(6), tf.ones(6)),
+ tfb.Chain([tfb.CholeskyOuterProduct(), tfb.ScaleTriL()]))
+
+ # Using an identity transformation, ScaleTriL is equivalent to
+ # tfb.FillTriangular.
+ b = tfb.ScaleTriL(
+ diag_bijector=tfb.Identity(),
+ diag_shift=None)
+
+ # For greater control over initialization, one can manually encode
+ # pre- and post- shifts inside of `diag_bijector`.
+ b = tfb.ScaleTriL(
+ diag_bijector=tfb.Chain([
+ tfb.AffineScalar(shift=1e-3),
+ tfb.Softplus(),
+ tfb.AffineScalar(shift=0.5413)]), # softplus_inverse(1.)
+ # = log(expm1(1.)) = 0.5413
+ diag_shift=None)
+ ```
+ """
+
+ def __init__(self,
+ diag_bijector=None,
+ diag_shift=1e-5,
+ validate_args=False,
+ name="scale_tril"):
+ """Instantiates the `ScaleTriL` bijector.
+
+ Args:
+ diag_bijector: `Bijector` instance, used to transform the output diagonal
+ to be positive.
+ Default value: `None` (i.e., `tfb.Softplus()`).
+ diag_shift: Float value broadcastable and added to all diagonal entries
+ after applying the `diag_bijector`. Setting a positive
+ value forces the output diagonal entries to be positive, but
+ prevents inverting the transformation for matrices with
+ diagonal entries less than this value.
+ Default value: `1e-5`.
+ validate_args: Python `bool` indicating whether arguments should be
+ checked for correctness.
+ Default value: `False` (i.e., arguments are not validated).
+ name: Python `str` name given to ops managed by this object.
+ Default value: `scale_tril`.
+ """
+
+ if diag_bijector is None:
+ diag_bijector = softplus.Softplus(validate_args=validate_args)
+
+ if diag_shift is not None:
+ diag_bijector = chain.Chain([affine_scalar.AffineScalar(shift=diag_shift),
+ diag_bijector])
+
+ super(ScaleTriL, self).__init__(
+ [transform_diagonal.TransformDiagonal(diag_bijector=diag_bijector),
+ fill_triangular.FillTriangular()],
+ validate_args=validate_args,
+ name=name)
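A NumPy-only sketch of what the default chain above computes for a length-3 input (softplus on the diagonal plus the `1e-5` shift); the fill order matches the `FillTriangular` layout used in the tests:

```python
import numpy as np

def softplus(z):
  return np.log1p(np.exp(z))

x = np.array([1., 2., 3.])            # d = 3 -> a 2x2 lower-triangular matrix
tril = np.array([[x[2], 0.],          # FillTriangular: [1, 2, 3] -> [[3, 0],
                 [x[1], x[0]]])       #                              [2, 1]]
y = tril.copy()
np.fill_diagonal(y, softplus(np.diag(tril)) + 1e-5)
# y is lower triangular with a strictly positive diagonal.
```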
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/transform_diagonal.py b/tensorflow/contrib/distributions/python/ops/bijectors/transform_diagonal.py
new file mode 100644
index 0000000000..65669fc2bf
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/transform_diagonal.py
@@ -0,0 +1,102 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""TransformDiagonal bijector."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops.distributions import bijector
+
+__all__ = [
+ "TransformDiagonal",
+]
+
+
+class TransformDiagonal(bijector.Bijector):
+ """Applies a Bijector to the diagonal of a matrix.
+
+ #### Example
+
+ ```python
+ b = tfb.TransformDiagonal(diag_bijector=tfb.Exp())
+
+ b.forward([[1., 0.],
+ [0., 1.]])
+ # ==> [[2.718, 0.],
+ # [0., 2.718]]
+ ```
+
+ """
+
+ def __init__(self,
+ diag_bijector,
+ validate_args=False,
+ name="transform_diagonal"):
+ """Instantiates the `TransformDiagonal` bijector.
+
+ Args:
+ diag_bijector: `Bijector` instance used to transform the diagonal.
+ validate_args: Python `bool` indicating whether arguments should be
+ checked for correctness.
+ name: Python `str` name given to ops managed by this object.
+ """
+ self._diag_bijector = diag_bijector
+ super(TransformDiagonal, self).__init__(
+ forward_min_event_ndims=2,
+ inverse_min_event_ndims=2,
+ validate_args=validate_args,
+ name=name)
+
+ def _forward(self, x):
+ diag = self._diag_bijector.forward(array_ops.matrix_diag_part(x))
+ return array_ops.matrix_set_diag(x, diag)
+
+ def _inverse(self, y):
+ diag = self._diag_bijector.inverse(array_ops.matrix_diag_part(y))
+ return array_ops.matrix_set_diag(y, diag)
+
+ def _forward_log_det_jacobian(self, x):
+ # We formulate the Jacobian with respect to the flattened matrices
+ # `vec(x)` and `vec(y)`. Suppose for notational convenience that
+ # the first `n` entries of `vec(x)` are the diagonal of `x`, and
+ # the remaining `n**2-n` entries are the off-diagonals in
+ # arbitrary order. Then the Jacobian is a block-diagonal matrix,
+ # with the Jacobian of the diagonal bijector in the first block,
+ # and the identity Jacobian for the remaining entries (since this
+ # bijector acts as the identity on non-diagonal entries):
+ #
+ # J_vec(x) (vec(y)) =
+ # -------------------------------
+ # | J_diag(x) (diag(y)) 0 | n entries
+ # | |
+ # | 0 I | n**2-n entries
+ # -------------------------------
+ # n n**2-n
+ #
+ # Since the log-det of the second (identity) block is zero, the
+ # overall log-det-jacobian is just the log-det of first block,
+ # from the diagonal bijector.
+ #
+ # Note that for elementwise operations (exp, softplus, etc) the
+ # first block of the Jacobian will itself be a diagonal matrix,
+ # but our implementation does not require this to be true.
+ return self._diag_bijector.forward_log_det_jacobian(
+ array_ops.matrix_diag_part(x), event_ndims=1)
+
+ def _inverse_log_det_jacobian(self, y):
+ return self._diag_bijector.inverse_log_det_jacobian(
+ array_ops.matrix_diag_part(y), event_ndims=1)
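As a numeric illustration of the block-diagonal argument in the comment above: with `diag_bijector = Exp`, whose per-element forward log-det is the input itself, the forward log-det-Jacobian of `TransformDiagonal` reduces to the sum of the diagonal of `x`:

```python
import numpy as np

x = np.array([[0.5, 7.0],
              [-2.0, -1.5]])
fldj = np.sum(np.diag(x))   # only the n diagonal entries contribute: -1.0
```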
diff --git a/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc.py b/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc.py
index 98b4ce1b26..729d8525fa 100644
--- a/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc.py
+++ b/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc.py
@@ -57,11 +57,6 @@ class Dynamics(tf.keras.Model):
self.eps = tfe.Variable(
initial_value=eps, name="eps", dtype=tf.float32, trainable=True)
- # TODO(lxuechen): Remove this after model.add_weight is in place
- self.vars_not_in_layers = [self.eps]
- self.vars_not_in_layers += self.position_fn.vars_not_in_layers
- self.vars_not_in_layers += self.momentum_fn.vars_not_in_layers
-
def apply_transition(self, position):
"""Propose a new state and perform the accept or reject step."""
@@ -290,86 +285,35 @@ class Dynamics(tf.keras.Model):
return grad
-# Defining loss and grads for training
-def compute_loss(x, dynamics, scale=.1, eps=1e-4):
- """Compute loss defined in equation (8)."""
-
- z = tf.random_normal(tf.shape(x))
- x_, _, x_accept_prob, x_out = dynamics.apply_transition(x)
- z_, _, z_accept_prob, _ = dynamics.apply_transition(z)
-
- # Add eps for numerical stability; following released impl
- x_loss = tf.reduce_sum((x - x_)**2, axis=1) * x_accept_prob + eps
- z_loss = tf.reduce_sum((z - z_)**2, axis=1) * z_accept_prob + eps
-
- loss = tf.reduce_mean(
- (1. / x_loss + 1. / z_loss) * scale - (x_loss + z_loss) / scale, axis=0)
-
- return loss, x_out
-
-
-def loss_and_grads(x, dynamics):
- """Obtain loss value and gradients."""
-
- with tf.GradientTape() as tape:
- loss_val, x_out = compute_loss(x, dynamics)
-
- vars_ = dynamics.variables + dynamics.vars_not_in_layers
- grads = tape.gradient(loss_val, vars_)
-
- return loss_val, grads, x_out
-
-
-def warmup(dynamics, optimizer, n_iters=1, n_samples=200):
- """Warmup optimization to reduce overhead."""
-
- samples = tf.random_normal(
- shape=[n_samples, dynamics.x_dim], dtype=tf.float32)
-
- for _ in range(n_iters):
- _, grads, samples = loss_and_grads(samples, dynamics)
- vars_ = dynamics.variables + dynamics.vars_not_in_layers
- optimizer.apply_gradients(zip(grads, vars_))
-
-
-def fit(dynamics,
- optimizer,
- n_samples=200,
- n_iters=5000,
- verbose=True,
- logdir=None):
- """Fit L2HMC sampler with given log-likelihood function."""
-
- if logdir:
- summary_writer = tf.contrib.summary.create_file_writer(logdir)
+# Examples of unnormalized log density/probabilities
+def get_scg_energy_fn():
+ """Get energy function for 2d strongly correlated Gaussian."""
- samples = tf.random_normal(
- shape=[n_samples, dynamics.x_dim], dtype=tf.float32)
+ # Avoid recreating tf constants on each invocation of gradients
+ mu = tf.constant([0., 0.])
+ sigma = tf.constant([[50.05, -49.95], [-49.95, 50.05]])
+ sigma_inv = tf.matrix_inverse(sigma)
- tf.train.get_or_create_global_step()
- for i in range(n_iters):
- loss, grads, samples = loss_and_grads(samples, dynamics)
- # TODO(lxuechen): Proper learning rate decay
- grads_ = [grad * .96**(i // 1000) for grad in grads]
- vars_ = dynamics.variables + dynamics.vars_not_in_layers
- optimizer.apply_gradients(
- zip(grads_, vars_), global_step=tf.train.get_global_step())
+ def energy(x):
+ """Unnormalized log density/energy of 2d strongly correlated Gaussian."""
- if verbose:
- print("Iteration %d: loss %.4f" % (i, loss))
+ xmmu = x - mu
+ return .5 * tf.diag_part(
+ tf.matmul(tf.matmul(xmmu, sigma_inv), tf.transpose(xmmu)))
- if logdir:
- with summary_writer.as_default():
- with tf.contrib.summary.always_record_summaries():
- tf.contrib.summary.scalar("loss", loss)
+ return energy
-def get_scg_energy_fn():
+def get_multivariate_gaussian_energy_fn(x_dim=2):
"""Get energy function for 2d strongly correlated Gaussian."""
- # Avoid recreating tf constants on each invocation of gradients
- mu = tf.constant([0., 0.])
- sigma = tf.constant([[50.05, -49.95], [-49.95, 50.05]])
+ mu = tf.random_normal(shape=[x_dim])
+ # Lower-triangularize and make the diagonal positive
+ l = tf.sigmoid(
+ tf.matrix_band_part(tf.random_normal(shape=[x_dim, x_dim]), -1, 0))
+ # Exploit Cholesky decomposition
+ sigma = tf.matmul(l, tf.transpose(l))
+ sigma *= 100. # Small covariance causes extreme numerical instability
sigma_inv = tf.matrix_inverse(sigma)
def energy(x):
diff --git a/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc_test.py b/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc_test.py
index 522a7c9380..e33b4cae4c 100644
--- a/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc_test.py
+++ b/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc_test.py
@@ -32,16 +32,83 @@ def get_default_hparams():
n_samples=200,
n_steps=10,
eps=.1,
- n_iters=5,
- learning_rate=.001,
- n_warmup_iters=1)
+ n_iters=10,
+ learning_rate=.0003,
+ n_warmup_iters=3)
+
+
+# Relevant functions for benchmarking
+def compute_loss(dynamics, x, scale=.1, eps=1e-4):
+ """Compute loss defined in equation (8)."""
+
+ z = tf.random_normal(tf.shape(x))
+ x_, _, x_accept_prob, x_out = dynamics.apply_transition(x)
+ z_, _, z_accept_prob, _ = dynamics.apply_transition(z)
+
+ # Add eps for numerical stability; following released impl
+ x_loss = tf.reduce_sum((x - x_)**2, axis=1) * x_accept_prob + eps
+ z_loss = tf.reduce_sum((z - z_)**2, axis=1) * z_accept_prob + eps
+
+ loss = tf.reduce_mean(
+ (1. / x_loss + 1. / z_loss) * scale - (x_loss + z_loss) / scale, axis=0)
+
+ return loss, x_out
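For reference, a NumPy sketch of the scalar combination in the loss above, given hypothetical per-sample squared jumps and acceptance probabilities:

```python
import numpy as np

scale, eps = .1, 1e-4
sq_jump_x = np.array([0.8, 1.2])   # hypothetical ||x - x_||^2 per sample
accept_x = np.array([0.9, 0.5])    # hypothetical acceptance probabilities
sq_jump_z = np.array([0.4, 0.7])
accept_z = np.array([0.6, 0.8])

x_loss = sq_jump_x * accept_x + eps
z_loss = sq_jump_z * accept_z + eps
loss = np.mean((1. / x_loss + 1. / z_loss) * scale - (x_loss + z_loss) / scale)
```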
+
+
+def loss_and_grads(dynamics, x, loss_fn=compute_loss):
+ """Obtain loss value and gradients."""
+
+ with tf.GradientTape() as tape:
+ loss_val, x_out = loss_fn(dynamics, x)
+ grads = tape.gradient(loss_val, dynamics.variables)
+
+ return loss_val, grads, x_out
+
+
+def warmup(dynamics, optimizer, n_iters=1, n_samples=200, loss_fn=compute_loss):
+ """Warmup optimization to reduce overhead."""
+
+ samples = tf.random_normal(
+ shape=[n_samples, dynamics.x_dim], dtype=tf.float32)
+
+ for _ in range(n_iters):
+ _, grads, samples = loss_and_grads(dynamics, samples, loss_fn=loss_fn)
+ optimizer.apply_gradients(zip(grads, dynamics.variables))
+
+
+def fit(dynamics,
+ samples,
+ optimizer,
+ loss_fn=compute_loss,
+ n_iters=5000,
+ verbose=True,
+ logdir=None,
+ decay_lr=True):
+ """Fit L2HMC sampler with given log-likelihood function."""
+
+ if logdir:
+ summary_writer = tf.contrib.summary.create_file_writer(logdir)
+
+ for i in range(n_iters):
+ loss, grads, samples = loss_and_grads(dynamics, samples, loss_fn=loss_fn)
+ # TODO(lxuechen): Proper learning rate decay
+ if decay_lr:
+ grads = [grad * .96**(i // 1000) for grad in grads]
+ optimizer.apply_gradients(zip(grads, dynamics.variables))
+ if verbose:
+ print("Iteration %d: loss %.4f" % (i, loss))
+
+ if logdir:
+ with summary_writer.as_default():
+ with tf.contrib.summary.always_record_summaries():
+ tf.contrib.summary.scalar("loss", loss)
class L2hmcTest(tf.test.TestCase):
"""Unit tests for l2hmc in both eager and graph mode."""
- def testComputeLoss(self):
- """Testing function l2hmc.compute_loss in both graph and eager mode."""
+ def test_apply_transition(self):
+ """Testing function `Dynamics.apply_transition` in graph and eager mode."""
# Eager mode testing
hparams = get_default_hparams()
@@ -51,12 +118,12 @@ class L2hmcTest(tf.test.TestCase):
n_steps=hparams.n_steps,
eps=hparams.eps)
samples = tf.random_normal(shape=[hparams.n_samples, hparams.x_dim])
- loss, x_out = l2hmc.compute_loss(samples, dynamics)
+ x_, v_, x_accept_prob, x_out = dynamics.apply_transition(samples)
- # Check shape and numerical stability
+ self.assertEqual(x_.shape, v_.shape)
self.assertEqual(x_out.shape, samples.shape)
- self.assertEqual(loss.shape, [])
- self.assertAllClose(loss.numpy(), loss.numpy(), rtol=1e-5)
+ self.assertEqual(x_.shape, x_out.shape)
+ self.assertEqual(x_accept_prob.shape, (hparams.n_samples,))
# Graph mode testing
with tf.Graph().as_default():
@@ -66,65 +133,49 @@ class L2hmcTest(tf.test.TestCase):
n_steps=hparams.n_steps,
eps=hparams.eps)
x = tf.placeholder(tf.float32, shape=[None, hparams.x_dim])
- loss, x_out = l2hmc.compute_loss(x, dynamics)
+ x_, v_, x_accept_prob, x_out = dynamics.apply_transition(x)
samples = npr.normal(size=[hparams.n_samples, hparams.x_dim])
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
- loss_np, x_out_np = sess.run([loss, x_out], feed_dict={x: samples})
+ np_x_, np_v_, np_x_accept_prob, np_x_out = sess.run(
+ [x_, v_, x_accept_prob, x_out], feed_dict={x: samples})
- # Check shape and numerical stability
- self.assertEqual(x_out_np.shape, samples.shape)
- self.assertEqual(loss_np.shape, ())
- self.assertAllClose(loss_np, loss_np, rtol=1e-5)
+ self.assertEqual(np_x_.shape, np_v_.shape)
+ self.assertEqual(samples.shape, np_x_out.shape)
+ self.assertEqual(np_x_.shape, np_x_out.shape)
+ self.assertEqual(np_x_accept_prob.shape, (hparams.n_samples,))
class L2hmcBenchmark(tf.test.Benchmark):
"""Eager and graph benchmarks for l2hmc."""
- def benchmarkEagerL2hmc(self):
- """Benchmark Eager performance."""
-
- hparams = get_default_hparams()
- dynamics = l2hmc.Dynamics(
- x_dim=hparams.x_dim,
- loglikelihood_fn=l2hmc.get_scg_energy_fn(),
- n_steps=hparams.n_steps,
- eps=hparams.eps)
- # TODO(lxuechen): Add learning rate decay
- optimizer = tf.train.AdamOptimizer(learning_rate=hparams.learning_rate)
-
- # Warmup to reduce initialization effect when timing
- l2hmc.warmup(dynamics, optimizer, n_iters=hparams.n_warmup_iters)
+ def _get_energy_fn(self):
+ """Get specific energy function according to FLAGS."""
- # Time
- start_time = time.time()
- l2hmc.fit(
- dynamics,
- optimizer,
- n_samples=hparams.n_samples,
- n_iters=hparams.n_iters)
- wall_time = time.time() - start_time
- examples_per_sec = hparams.n_samples / wall_time
+ if FLAGS.energy_fn == "scg":
+ energy_fn = l2hmc.get_scg_energy_fn()
+ elif FLAGS.energy_fn == "multivariate_gaussian":
+ energy_fn = l2hmc.get_multivariate_gaussian_energy_fn(x_dim=FLAGS.x_dim)
+ else:
+ raise ValueError("No such energy function %s" % FLAGS.energy_fn)
- self.report_benchmark(
- name="eager_train_%s" % ("gpu" if tfe.num_gpus() > 0 else "cpu"),
- iters=hparams.n_iters,
- extras={"examples_per_sec": examples_per_sec},
- wall_time=wall_time)
+ return energy_fn
- def benchmarkGraphL2hmc(self):
+ def benchmark_graph(self):
"""Benchmark Graph performance."""
hparams = get_default_hparams()
+ tf.reset_default_graph()
with tf.Graph().as_default():
+ energy_fn = self._get_energy_fn()
dynamics = l2hmc.Dynamics(
x_dim=hparams.x_dim,
- loglikelihood_fn=l2hmc.get_scg_energy_fn(),
+ loglikelihood_fn=energy_fn,
n_steps=hparams.n_steps,
eps=hparams.eps)
x = tf.placeholder(tf.float32, shape=[None, hparams.x_dim])
- loss, x_out = l2hmc.compute_loss(x, dynamics)
+ loss, x_out = compute_loss(dynamics, x)
global_step = tf.Variable(0., name="global_step", trainable=False)
learning_rate = tf.train.exponential_decay(
@@ -138,14 +189,15 @@ class L2hmcBenchmark(tf.test.Benchmark):
# Warmup to reduce initialization effect when timing
samples = npr.normal(size=[hparams.n_samples, hparams.x_dim])
for _ in range(hparams.n_warmup_iters):
- samples, _, _, _ = sess.run(
+ _, _, _, _ = sess.run(
[x_out, loss, train_op, learning_rate], feed_dict={x: samples})
- # Time
+ # Training
start_time = time.time()
- for _ in range(hparams.n_iters):
- samples, _, _, _ = sess.run(
+ for i in range(hparams.n_iters):
+ samples, loss_np, _, _ = sess.run(
[x_out, loss, train_op, learning_rate], feed_dict={x: samples})
+ print("Iteration %d: loss %.4f" % (i, loss_np))
wall_time = time.time() - start_time
examples_per_sec = hparams.n_samples / wall_time
@@ -156,7 +208,57 @@ class L2hmcBenchmark(tf.test.Benchmark):
extras={"examples_per_sec": examples_per_sec},
wall_time=wall_time)
+ def benchmark_eager(self):
+ self._benchmark_eager()
+
+ def benchmark_eager_defun(self):
+ self._benchmark_eager(defun=True)
+
+ def _benchmark_eager(self, defun=False):
+ """Benchmark Eager performance."""
+
+ hparams = get_default_hparams()
+ energy_fn = self._get_energy_fn()
+ dynamics = l2hmc.Dynamics(
+ x_dim=hparams.x_dim,
+ loglikelihood_fn=energy_fn,
+ n_steps=hparams.n_steps,
+ eps=hparams.eps)
+ optimizer = tf.train.AdamOptimizer(learning_rate=hparams.learning_rate)
+ loss_fn = tfe.defun(compute_loss) if defun else compute_loss
+
+ # Warmup to reduce initialization effect when timing
+ warmup(dynamics, optimizer, n_iters=hparams.n_warmup_iters, loss_fn=loss_fn)
+
+ # Training
+ samples = tf.random_normal(
+ shape=[hparams.n_samples, hparams.x_dim], dtype=tf.float32)
+ start_time = time.time()
+ fit(dynamics,
+ samples,
+ optimizer,
+ loss_fn=loss_fn,
+ n_iters=hparams.n_iters,
+ decay_lr=True)
+ wall_time = time.time() - start_time
+ examples_per_sec = hparams.n_samples / wall_time
+
+ self.report_benchmark(
+ name="eager_train_%s%s" % ("gpu" if tf.test.is_gpu_available() else
+ "cpu", "_defun" if defun else ""),
+ iters=hparams.n_iters,
+ extras={"examples_per_sec": examples_per_sec},
+ wall_time=wall_time)
+
+ del dynamics
+ del loss_fn
+
if __name__ == "__main__":
+ tf.flags.DEFINE_string("energy_fn", "scg",
+ ("The energy function/unnormalized log-probability. "
+ "Either be `scg` or `multivariate_gaussian`"))
+ tf.flags.DEFINE_integer("x_dim", 2, "Dimensionality of observation space.")
+ FLAGS = tf.flags.FLAGS
tf.enable_eager_execution()
tf.test.main()
diff --git a/tensorflow/contrib/eager/python/examples/l2hmc/neural_nets.py b/tensorflow/contrib/eager/python/examples/l2hmc/neural_nets.py
index c902e1f1f4..e230ad5e25 100644
--- a/tensorflow/contrib/eager/python/examples/l2hmc/neural_nets.py
+++ b/tensorflow/contrib/eager/python/examples/l2hmc/neural_nets.py
@@ -57,8 +57,6 @@ class GenericNet(tf.keras.Model):
initial_value=tf.zeros([1, x_dim]),
name='coeff_transformation',
trainable=True)
- # TODO(lxuechen): Remove this after model.add_weight is in place
- self.vars_not_in_layers = [self.coeff_scale, self.coeff_transformation]
def call(self, inputs):
v, x, t = inputs
diff --git a/tensorflow/contrib/eager/python/examples/notebooks/4_high_level.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/4_high_level.ipynb
index 4fe3a0e3f3..5749f22ac5 100644
--- a/tensorflow/contrib/eager/python/examples/notebooks/4_high_level.ipynb
+++ b/tensorflow/contrib/eager/python/examples/notebooks/4_high_level.ipynb
@@ -68,7 +68,7 @@
"# simply construct the object. Most layers take as a first argument the number\n",
"# of output dimensions / channels.\n",
"layer = tf.keras.layers.Dense(100)\n",
- "# The number of input dimensionss is often unnecessary, as it can be inferred\n",
+ "# The number of input dimensions is often unnecessary, as it can be inferred\n",
"# the first time the layer is used, but it can be provided if you want to \n",
"# specify it manually, which is useful in some complex models.\n",
"layer = tf.keras.layers.Dense(10, input_shape=(None, 5))"
@@ -267,7 +267,7 @@
" * `build`, where you know the shapes of the input tensors and can do the rest of the initialization\n",
" * `call`, where you do the forward computation\n",
"\n",
- "Note that you don't have to wait until `build` is called to create your variables, you can also create them in `__init__`. However, the advantage of creating them in `build` is that it enables late variable creation based on the shape of the inputs the layer will operate on. On the other hand, creating variables in `__init__` would mean that shapes requires to create the variables will need to be explicitly specified."
+ "Note that you don't have to wait until `build` is called to create your variables, you can also create them in `__init__`. However, the advantage of creating them in `build` is that it enables late variable creation based on the shape of the inputs the layer will operate on. On the other hand, creating variables in `__init__` would mean that shapes required to create the variables will need to be explicitly specified."
]
},
{
diff --git a/tensorflow/contrib/layers/__init__.py b/tensorflow/contrib/layers/__init__.py
index 00f03a111a..bc33596935 100644
--- a/tensorflow/contrib/layers/__init__.py
+++ b/tensorflow/contrib/layers/__init__.py
@@ -19,6 +19,8 @@ See the @{$python/contrib.layers} guide.
@@avg_pool2d
@@avg_pool3d
@@batch_norm
+@@convolution
+@@convolution1d
@@convolution2d
@@convolution3d
@@conv2d_in_plane
diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py
index b7194ae333..b6d63c9640 100644
--- a/tensorflow/contrib/layers/python/layers/layers.py
+++ b/tensorflow/contrib/layers/python/layers/layers.py
@@ -57,10 +57,10 @@ from tensorflow.python.training import moving_averages
__all__ = [
'avg_pool2d', 'avg_pool3d', 'batch_norm', 'bias_add', 'conv2d', 'conv3d',
'conv2d_in_plane', 'conv2d_transpose', 'conv3d_transpose', 'convolution',
- 'convolution2d', 'convolution2d_in_plane', 'convolution2d_transpose',
- 'convolution3d', 'convolution3d_transpose', 'dense_to_sparse',
- 'dropout', 'elu', 'flatten', 'fully_connected', 'GDN', 'gdn',
- 'images_to_sequence', 'layer_norm', 'linear', 'pool', 'max_pool2d',
+ 'convolution1d', 'convolution2d', 'convolution2d_in_plane',
+ 'convolution2d_transpose', 'convolution3d', 'convolution3d_transpose',
+ 'dense_to_sparse', 'dropout', 'elu', 'flatten', 'fully_connected', 'GDN',
+ 'gdn', 'images_to_sequence', 'layer_norm', 'linear', 'pool', 'max_pool2d',
'max_pool3d', 'one_hot_encoding', 'relu', 'relu6', 'repeat',
'scale_gradient', 'separable_conv2d', 'separable_convolution2d',
'sequence_to_images', 'softmax', 'spatial_softmax', 'stack', 'unit_norm',
diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py
index 56e9194ceb..c5c7269b1f 100644
--- a/tensorflow/contrib/layers/python/layers/layers_test.py
+++ b/tensorflow/contrib/layers/python/layers/layers_test.py
@@ -1312,6 +1312,29 @@ class ConvolutionInPlaneTest(test.TestCase):
self.assertAllClose(result, expected, rtol=1e-5, atol=1e-5)
+ def testConv1dShape(self):
+ width = 7
+ with self.test_session():
+ images = random_ops.random_uniform((5, width, 3), seed=1)
+ output = layers_lib.convolution1d(images, 32, 3)
+ self.assertEqual(output.op.name, 'Conv/Relu')
+ self.assertListEqual(output.get_shape().as_list(), [5, width, 32])
+
+ def testConvInferSpatialDims(self):
+ depth, height, width = 7, 9, 11
+ with self.test_session():
+ images = np.random.uniform(size=(5, width, 4)).astype(np.float32)
+ output = layers_lib.convolution(images, 32, [3])
+ self.assertListEqual(output.get_shape().as_list(), [5, width, 32])
+ images = np.random.uniform(size=(5, height, width, 4)).astype(np.float32)
+ output = layers_lib.convolution(images, 32, [3, 3])
+ self.assertListEqual(output.get_shape().as_list(), [5, height, width, 32])
+ images = np.random.uniform(size=(5, depth, height, width,
+ 4)).astype(np.float32)
+ output = layers_lib.convolution(images, 32, [3, 3, 3])
+ self.assertListEqual(output.get_shape().as_list(),
+ [5, depth, height, width, 32])
+
class DenseToSparseTest(test.TestCase):
diff --git a/tensorflow/contrib/learn/python/learn/experiment.py b/tensorflow/contrib/learn/python/learn/experiment.py
index 541da90617..f8a3709ee5 100644
--- a/tensorflow/contrib/learn/python/learn/experiment.py
+++ b/tensorflow/contrib/learn/python/learn/experiment.py
@@ -505,7 +505,7 @@ class Experiment(object):
eval_result = None
last_warning_time = 0
while (not predicate_fn or predicate_fn(
- eval_result, checkpoint_path=previous_path if eval_result else None)):
+ eval_result, checkpoint_path=previous_path)):
# Exit if we have already reached number of steps to train.
if self._has_training_stopped(eval_result):
logging.info("Exiting continuous eval, global_step=%s >= "
diff --git a/tensorflow/contrib/learn/python/learn/experiment_test.py b/tensorflow/contrib/learn/python/learn/experiment_test.py
index d10927a0cd..fb16c94c29 100644
--- a/tensorflow/contrib/learn/python/learn/experiment_test.py
+++ b/tensorflow/contrib/learn/python/learn/experiment_test.py
@@ -500,7 +500,7 @@ class ExperimentTest(test.TestCase):
noop_hook = _NoopHook()
def _predicate_fn(eval_result, checkpoint_path):
- self.assertEqual(not eval_result,
+ self.assertEqual(eval_result is None,
checkpoint_path is None)
return est.eval_count < 3 # pylint: disable=cell-var-from-loop
diff --git a/tensorflow/contrib/lite/build_def.bzl b/tensorflow/contrib/lite/build_def.bzl
index aa6a60dc9e..612813caee 100644
--- a/tensorflow/contrib/lite/build_def.bzl
+++ b/tensorflow/contrib/lite/build_def.bzl
@@ -204,6 +204,7 @@ def generated_test_models():
"conv",
"depthwiseconv",
"div",
+ "equal",
"exp",
"expand_dims",
"floor",
@@ -219,6 +220,7 @@ def generated_test_models():
"less_equal",
"local_response_norm",
"log_softmax",
+ "log",
"lstm",
"max_pool",
"maximum",
@@ -226,6 +228,7 @@ def generated_test_models():
"minimum",
"mul",
"neg",
+ "not_equal",
"pad",
"padv2",
# "prelu",
diff --git a/tensorflow/contrib/lite/builtin_ops.h b/tensorflow/contrib/lite/builtin_ops.h
index fc6fdd6eef..f3b2ac77fb 100644
--- a/tensorflow/contrib/lite/builtin_ops.h
+++ b/tensorflow/contrib/lite/builtin_ops.h
@@ -96,6 +96,9 @@ typedef enum {
kTfLiteBuiltinSparseToDense = 68,
kTfLiteBuiltinTile = 69,
kTfLiteBuiltinExpandDims = 70,
+ kTfLiteBuiltinEqual = 71,
+ kTfLiteBuiltinNotEqual = 72,
+ kTfLiteBuiltinLog = 73,
} TfLiteBuiltinOperator;
#ifdef __cplusplus
diff --git a/tensorflow/contrib/lite/examples/label_image/BUILD b/tensorflow/contrib/lite/examples/label_image/BUILD
index 9322e186a2..c61445114e 100644
--- a/tensorflow/contrib/lite/examples/label_image/BUILD
+++ b/tensorflow/contrib/lite/examples/label_image/BUILD
@@ -53,19 +53,18 @@ cc_library(
],
)
-# TODO(ahentz): Test disabled as it has a memory leek from read_bmp
-# cc_test(
-# name = "label_image_test",
-# srcs = [
-# "get_top_n.h",
-# "get_top_n_impl.h",
-# "label_image_test.cc",
-# ],
-# data = [
-# "testdata/grace_hopper.bmp",
-# ],
-# deps = [
-# ":bitmap_helpers",
-# "//testing/base/public:gunit",
-# ],
-# )
+cc_test(
+ name = "label_image_test",
+ srcs = [
+ "get_top_n.h",
+ "get_top_n_impl.h",
+ "label_image_test.cc",
+ ],
+ data = [
+ "testdata/grace_hopper.bmp",
+ ],
+ deps = [
+ ":bitmap_helpers",
+ "@com_google_googletest//:gtest",
+ ],
+)
diff --git a/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.cc b/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.cc
index 0b38cd38c8..2735d1f5ea 100644
--- a/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.cc
+++ b/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.cc
@@ -28,8 +28,9 @@ limitations under the License.
namespace tflite {
namespace label_image {
-uint8_t* decode_bmp(const uint8_t* input, int row_size, uint8_t* const output,
- int width, int height, int channels, bool top_down) {
+std::vector<uint8_t> decode_bmp(const uint8_t* input, int row_size, int width,
+ int height, int channels, bool top_down) {
+ std::vector<uint8_t> output(height * width * channels);
for (int i = 0; i < height; i++) {
int src_pos;
int dst_pos;
@@ -66,12 +67,11 @@ uint8_t* decode_bmp(const uint8_t* input, int row_size, uint8_t* const output,
}
}
}
-
return output;
}
-uint8_t* read_bmp(const std::string& input_bmp_name, int* width, int* height,
- int* channels, Settings* s) {
+std::vector<uint8_t> read_bmp(const std::string& input_bmp_name, int* width,
+ int* height, int* channels, Settings* s) {
int begin, end;
std::ifstream file(input_bmp_name, std::ios::in | std::ios::binary);
@@ -87,14 +87,15 @@ uint8_t* read_bmp(const std::string& input_bmp_name, int* width, int* height,
if (s->verbose) LOG(INFO) << "len: " << len << "\n";
- const uint8_t* img_bytes = new uint8_t[len];
+ std::vector<uint8_t> img_bytes(len);
file.seekg(0, std::ios::beg);
- file.read((char*)img_bytes, len);
+ file.read(reinterpret_cast<char*>(img_bytes.data()), len);
const int32_t header_size =
- *(reinterpret_cast<const int32_t*>(img_bytes + 10));
- *width = *(reinterpret_cast<const int32_t*>(img_bytes + 18));
- *height = *(reinterpret_cast<const int32_t*>(img_bytes + 22));
- const int32_t bpp = *(reinterpret_cast<const int32_t*>(img_bytes + 28));
+ *(reinterpret_cast<const int32_t*>(img_bytes.data() + 10));
+ *width = *(reinterpret_cast<const int32_t*>(img_bytes.data() + 18));
+ *height = *(reinterpret_cast<const int32_t*>(img_bytes.data() + 22));
+ const int32_t bpp =
+ *(reinterpret_cast<const int32_t*>(img_bytes.data() + 28));
*channels = bpp / 8;
if (s->verbose)
@@ -110,10 +111,9 @@ uint8_t* read_bmp(const std::string& input_bmp_name, int* width, int* height,
bool top_down = (*height < 0);
// Decode image, allocating tensor once the image size is known
- uint8_t* output = new uint8_t[abs(*height) * *width * *channels];
const uint8_t* bmp_pixels = &img_bytes[header_size];
- return decode_bmp(bmp_pixels, row_size, output, *width, abs(*height),
- *channels, top_down);
+ return decode_bmp(bmp_pixels, row_size, *width, abs(*height), *channels,
+ top_down);
}
} // namespace label_image
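The fixed offsets read in `read_bmp` above come from the standard BMP file header; a hedged Python sketch of the same parse (the file name is a hypothetical local copy):

```python
import struct

with open('grace_hopper.bmp', 'rb') as f:   # hypothetical local file
  data = f.read()

pixel_offset = struct.unpack_from('<i', data, 10)[0]   # called header_size above
width        = struct.unpack_from('<i', data, 18)[0]
height       = struct.unpack_from('<i', data, 22)[0]
bpp          = struct.unpack_from('<H', data, 28)[0]   # bits per pixel
channels     = bpp // 8
```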
diff --git a/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.h b/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.h
index 97343dde6b..5fc75b1f72 100644
--- a/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.h
+++ b/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.h
@@ -22,8 +22,8 @@ limitations under the License.
namespace tflite {
namespace label_image {
-uint8_t* read_bmp(const std::string& input_bmp_name, int* width, int* height,
- int* channels, Settings* s);
+std::vector<uint8_t> read_bmp(const std::string& input_bmp_name, int* width,
+ int* height, int* channels, Settings* s);
template <class T>
void resize(T* out, uint8_t* in, int image_height, int image_width,
diff --git a/tensorflow/contrib/lite/examples/label_image/label_image.cc b/tensorflow/contrib/lite/examples/label_image/label_image.cc
index 966fcd2a31..86d7d1cc4a 100644
--- a/tensorflow/contrib/lite/examples/label_image/label_image.cc
+++ b/tensorflow/contrib/lite/examples/label_image/label_image.cc
@@ -138,8 +138,8 @@ void RunInference(Settings* s) {
int image_width = 224;
int image_height = 224;
int image_channels = 3;
- uint8_t* in = read_bmp(s->input_bmp_name, &image_width, &image_height,
- &image_channels, s);
+ std::vector<uint8_t> in = read_bmp(s->input_bmp_name, &image_width,
+ &image_height, &image_channels, s);
int input = interpreter->inputs()[0];
if (s->verbose) LOG(INFO) << "input: " << input << "\n";
@@ -168,12 +168,12 @@ void RunInference(Settings* s) {
switch (interpreter->tensor(input)->type) {
case kTfLiteFloat32:
s->input_floating = true;
- resize<float>(interpreter->typed_tensor<float>(input), in, image_height,
- image_width, image_channels, wanted_height, wanted_width,
- wanted_channels, s);
+ resize<float>(interpreter->typed_tensor<float>(input), in.data(),
+ image_height, image_width, image_channels, wanted_height,
+ wanted_width, wanted_channels, s);
break;
case kTfLiteUInt8:
- resize<uint8_t>(interpreter->typed_tensor<uint8_t>(input), in,
+ resize<uint8_t>(interpreter->typed_tensor<uint8_t>(input), in.data(),
image_height, image_width, image_channels, wanted_height,
wanted_width, wanted_channels, s);
break;
diff --git a/tensorflow/contrib/lite/examples/label_image/label_image_test.cc b/tensorflow/contrib/lite/examples/label_image/label_image_test.cc
index ce35483f76..de7de21f77 100644
--- a/tensorflow/contrib/lite/examples/label_image/label_image_test.cc
+++ b/tensorflow/contrib/lite/examples/label_image/label_image_test.cc
@@ -27,20 +27,20 @@ namespace label_image {
TEST(LabelImageTest, GraceHopper) {
std::string lena_file =
- "tensorflow/contrib/lite/examples/label_image/testdata/grace_hopper.bmp";
+ "tensorflow/contrib/lite/examples/label_image/testdata/"
+ "grace_hopper.bmp";
int height, width, channels;
Settings s;
- uint8_t *data;
-
- data = read_bmp(lena_file, &width, &height, &channels, &s);
+ std::vector<uint8_t> input =
+ read_bmp(lena_file, &width, &height, &channels, &s);
ASSERT_EQ(height, 606);
ASSERT_EQ(width, 517);
ASSERT_EQ(channels, 3);
- uint8_t *out = new uint8_t[606 * 517 * 3];
- downsize<uint8_t>(out, data, 606, 517, 3, 214, 214, 3, &s);
- ASSERT_EQ(out[0], 0x15);
- ASSERT_EQ(out[214 * 214 * 3 - 1], 0x12);
+ std::vector<uint8_t> output(606 * 517 * 3);
+ resize<uint8_t>(output.data(), input.data(), 606, 517, 3, 214, 214, 3, &s);
+ ASSERT_EQ(output[0], 0x15);
+ ASSERT_EQ(output[214 * 214 * 3 - 1], 0x11);
}
TEST(LabelImageTest, GetTopN) {
diff --git a/tensorflow/contrib/lite/g3doc/ops_versioning.md b/tensorflow/contrib/lite/g3doc/ops_versioning.md
new file mode 100644
index 0000000000..bd2f797e6c
--- /dev/null
+++ b/tensorflow/contrib/lite/g3doc/ops_versioning.md
@@ -0,0 +1,206 @@
+# TensorFlow Lite Ops Versioning
+
+This document describes TensorFlow Lite's op versioning schema. Op
+versioning enables developers to add new functionality and parameters to
+existing ops. In addition, it guarantees the following:
+
+* Backward compatibility: A new TensorFlow Lite implementation should
+ handle an old model file.
+* Forward compatibility: An old TensorFlow Lite implementation should
+ handle a new model file produced by a new version of TOCO, as long as no new
+ features are used.
+* Forward incompatibility detection: If an old TensorFlow Lite implementation
+ reads a new model that contains a new version of an op which isn't
+ supported, it should report the error.
+
+## Example: Adding Dilation into Convolution
+
+The remainder of this document explains op versioning in TFLite by showing how
+to add dilation parameters to the convolution operation.
+
+Knowledge of dilation is not required to understand this document. Note that:
+
+* 2 new integer parameters will be added: `dilation_width_factor` and
+ `dilation_height_factor`.
+* Old convolution kernels that don't support dilation are equivalent to
+ setting the dilation factors to 1.
+
+### Change FlatBuffer Schema
+
+To add new parameters to an op, change the options table in
+`lite/schema/schema.fbs`.
+
+For example, the options table of convolution looks like this:
+
+```
+table Conv2DOptions {
+ padding:Padding;
+ stride_w:int;
+ stride_h:int;
+ fused_activation_function:ActivationFunctionType;
+}
+```
+
+When adding new parameters:
+
+* Add comments indicating which parameters are supported by which version.
+* When the new implementation reads the default values for newly added
+ parameters, it should behave exactly the same as the old implementation.
+
+After the new parameters are added, the table looks like this:
+
+```
+table Conv2DOptions {
+ // Parameters supported by version 1:
+ padding:Padding;
+ stride_w:int;
+ stride_h:int;
+ fused_activation_function:ActivationFunctionType;
+
+ // Parameters supported by version 2:
+ dilation_width_factor:int = 1;
+ dilation_height_factor:int = 1;
+}
+```
+
+### Change C Structures and Kernel Implementation
+
+In TensorFlow Lite, the kernel implementation is decoupled from the
+FlatBuffer definition. The kernels read their parameters from C structures
+defined in `lite/builtin_op_data.h`.
+
+The original convolution parameters are as follows:
+
+```
+typedef struct {
+ TfLitePadding padding;
+ int stride_width;
+ int stride_height;
+ TfLiteFusedActivation activation;
+} TfLiteConvParams;
+```
+
+As with the FlatBuffer schema, add comments indicating which parameters are
+supported starting from which version. The result is shown below:
+
+```
+typedef struct {
+  // Parameters supported by version 1:
+  TfLitePadding padding;
+  int stride_width;
+ int stride_height;
+ TfLiteFusedActivation activation;
+
+ // Parameters supported by version 2:
+ int dilation_width_factor;
+ int dilation_height_factor;
+} TfLiteConvParams;
+```
+
+Please also change the kernel implementation to read the newly added parameters
+from the C structures; a simplified sketch is shown below.
+
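+The sketch below illustrates how a kernel's `Eval` function might read the new
+fields from the C structure. It is a simplified, hypothetical example, not the
+actual `conv.cc` code:
+
+```
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+  // The parsed parameters are attached to the node as builtin_data.
+  auto* params = reinterpret_cast<TfLiteConvParams*>(node->builtin_data);
+
+  // Old models without dilation fields get the default value 1, so the
+  // version 1 behavior is preserved.
+  const int dilation_width = params->dilation_width_factor;
+  const int dilation_height = params->dilation_height_factor;
+
+  // ... pass dilation_width / dilation_height to the convolution routine ...
+  return kTfLiteOk;
+}
+```
+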
+### Change the FlatBuffer Reading Code
+
+The logic that reads the FlatBuffer and produces the C structure is in `lite/model.cc`.
+
+Update the file to handle the new parameters, as shown below:
+
+```
+case BuiltinOperator_CONV_2D: {
+ TfLiteConvParams* params = MallocPOD<TfLiteConvParams>();
+ if (auto* conv_params = op->builtin_options_as_Conv2DOptions()) {
+ params->padding = parse_padding(conv_params->padding());
+ params->stride_width = conv_params->stride_w();
+ params->stride_height = conv_params->stride_h();
+ params->activation =
+ parse_activation(conv_params->fused_activation_function());
+ params->dilation_width_factor = conv_params->dilation_width_factor();
+ params->dilation_height_factor = conv_params->dilation_height_factor();
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+}
+```
+
+It's not required to check the op version here. When the new implementation
+reads an old model file where dilation factors are missing, it will use 1 as
+the default value, and the new kernel will work consistently with the old
+kernel.
+
+### Change Kernel Registration
+
+The `MutableOpResolver` class (defined in `lite/op_resolver.h`) provides a few
+functions to register op kernels. The minimum and maximum versions are 1 by
+default:
+
+```
+void AddBuiltin(tflite::BuiltinOperator op, TfLiteRegistration* registration,
+ int min_version = 1, int max_version = 1);
+void AddCustom(const char* name, TfLiteRegistration* registration,
+ int min_version = 1, int max_version = 1);
+```
+
+The built-in ops are registered in `lite/kernels/register.cc`. In this example,
+we implemented a new op kernel which can handle `Conv2D` versions 1 and 2, so we
+need to change this line:
+
+```
+AddBuiltin(BuiltinOperator_CONV_2D, Register_CONV_2D());
+```
+
+to:
+
+```
+AddBuiltin(BuiltinOperator_CONV_2D, Register_CONV_2D(), 1, 2);
+```
+
+### Change TOCO TFLite exporter
+
+The last step is to make TOCO populate the minimum version that's required to
+execute the op. In this example, it means:
+
+* Populate version=1 when dilation factors are all 1.
+* Populate version=2 otherwise.
+
+To do this, override the `GetVersion` function for the operator class in
+`lite/toco/tflite/operator.cc`.
+
+For ops with only one version, the `GetVersion` function is defined as:
+
+```
+int GetVersion(const Operator& op) const override { return 1; }
+```
+
+When supporting multiple versions, check the parameters and determine the
+version for the op, as shown in the following example:
+
+```
+int GetVersion(const Operator& op) const override {
+ const auto& conv_op = static_cast<const ConvOperator&>(op);
+ if (conv_op.dilation_width_factor != 1 ||
+ conv_op.dilation_height_factor != 1) {
+ return 2;
+ }
+ return 1;
+}
+```
+
+### Delegation Implementation
+
+TensorFlow Lite provides a delegation API which enables delegating ops to
+hardware backends. In the delegate's `Prepare` function, check whether the
+version of every node is supported by the delegation code.
+
+```
+const int kMinVersion = 1;
+TfLiteNode* node;
+TfLiteRegistration* registration;
+context->GetNodeAndRegistration(context, node_index, &node, &registration);
+
+if (registration->version > kMinVersion) {
+ // Reject the node if the version isn't supported.
+}
+```
+
+This is required even if the delegation only supports version 1 ops, so that the
+delegation can detect incompatibility when it encounters a higher-version op.
+
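+Below is a minimal, illustrative sketch of such a check inside a delegate's
+`Prepare` callback. The execution-plan loop and the names used here are
+assumptions for illustration, not an exact API reference:
+
+```
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteDelegate* delegate) {
+  // Highest op version this (hypothetical) delegate can handle.
+  const int kMaxSupportedVersion = 1;
+  TfLiteIntArray* plan = nullptr;
+  TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &plan));
+  for (int i = 0; i < plan->size; ++i) {
+    const int node_index = plan->data[i];
+    TfLiteNode* node = nullptr;
+    TfLiteRegistration* registration = nullptr;
+    TF_LITE_ENSURE_STATUS(context->GetNodeAndRegistration(
+        context, node_index, &node, &registration));
+    if (registration->version > kMaxSupportedVersion) {
+      // Skip (do not delegate) nodes whose version isn't supported.
+      continue;
+    }
+    // ... collect node_index into the list of nodes to delegate ...
+  }
+  return kTfLiteOk;
+}
+```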
diff --git a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md
index b2f6444e9e..965273f0f0 100644
--- a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md
+++ b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md
@@ -95,11 +95,7 @@ Here is a list of TensorFlow operations that are usually removed from the graph:
* [tf.divide](https://www.tensorflow.org/api_docs/python/tf/divide)
* [tf.fake_quant_with_min_max_args](https://www.tensorflow.org/api_docs/python/tf/fake_quant_with_min_max_args)
* [tf.fake_quant_with_min_max_vars](https://www.tensorflow.org/api_docs/python/tf/fake_quant_with_min_max_vars)
-* [tf.greater](https://www.tensorflow.org/api_docs/python/tf/greater)
-* [tf.greater_equal](https://www.tensorflow.org/api_docs/python/tf/greater_equal)
* [tf.identity](https://www.tensorflow.org/api_docs/python/tf/identity)
-* [tf.less](https://www.tensorflow.org/api_docs/python/tf/less)
-* [tf.less_equal](https://www.tensorflow.org/api_docs/python/tf/less_equal)
* [tf.maximum](https://www.tensorflow.org/api_docs/python/tf/maximum)
* [tf.minimum](https://www.tensorflow.org/api_docs/python/tf/minimum)
* [tf.multiply](https://www.tensorflow.org/api_docs/python/tf/multiply)
@@ -257,6 +253,19 @@ Options {
}
```
+**EQUAL**
+
+```
+Inputs {
+ 0: a tensor
+ 1: a tensor
+}
+Outputs {
+ 0: a tensor of type bool, true whenever an element of the first tensor is
+ equal to the corresponding element of the second tensor.
+}
+```
+
**EXP**
```
@@ -420,6 +429,17 @@ Outputs {
}
```
+**LOG**
+
+```
+Inputs {
+ 0: a tensor
+}
+Outputs {
+ 0: a tensor equivalent to log(input)
+}
+```
+
**LOG_SOFTMAX**
```
@@ -503,6 +523,19 @@ Options {
}
```
+**NOT_EQUAL**
+
+```
+Inputs {
+ 0: a tensor
+ 1: a tensor
+}
+Outputs {
+ 0: a tensor of type bool, true whenever an element of the first tensor is not
+ equal to the corresponding element of the second tensor.
+}
+```
+
**RELU**
```
diff --git a/tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc
index 005dca0253..9e9387da86 100644
--- a/tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc
+++ b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc
@@ -43,31 +43,27 @@ size_t writeOneDimensionalArray(JNIEnv* env, jobject object, TfLiteType type,
}
switch (type) {
case kTfLiteFloat32: {
- jfloatArray a = static_cast<jfloatArray>(array);
- jfloat* values = env->GetFloatArrayElements(a, nullptr);
- memcpy(dst, values, to_copy);
- env->ReleaseFloatArrayElements(a, values, JNI_ABORT);
+ jfloatArray float_array = static_cast<jfloatArray>(array);
+ jfloat* float_dst = static_cast<jfloat*>(dst);
+ env->GetFloatArrayRegion(float_array, 0, num_elements, float_dst);
return to_copy;
}
case kTfLiteInt32: {
- jintArray a = static_cast<jintArray>(array);
- jint* values = env->GetIntArrayElements(a, nullptr);
- memcpy(dst, values, to_copy);
- env->ReleaseIntArrayElements(a, values, JNI_ABORT);
+ jintArray int_array = static_cast<jintArray>(array);
+ jint* int_dst = static_cast<jint*>(dst);
+ env->GetIntArrayRegion(int_array, 0, num_elements, int_dst);
return to_copy;
}
case kTfLiteInt64: {
- jlongArray a = static_cast<jlongArray>(array);
- jlong* values = env->GetLongArrayElements(a, nullptr);
- memcpy(dst, values, to_copy);
- env->ReleaseLongArrayElements(a, values, JNI_ABORT);
+ jlongArray long_array = static_cast<jlongArray>(array);
+ jlong* long_dst = static_cast<jlong*>(dst);
+ env->GetLongArrayRegion(long_array, 0, num_elements, long_dst);
return to_copy;
}
case kTfLiteUInt8: {
- jbyteArray a = static_cast<jbyteArray>(array);
- jbyte* values = env->GetByteArrayElements(a, nullptr);
- memcpy(dst, values, to_copy);
- env->ReleaseByteArrayElements(a, values, JNI_ABORT);
+ jbyteArray byte_array = static_cast<jbyteArray>(array);
+ jbyte* byte_dst = static_cast<jbyte*>(dst);
+ env->GetByteArrayRegion(byte_array, 0, num_elements, byte_dst);
return to_copy;
}
default: {
diff --git a/tensorflow/contrib/lite/kernels/comparisons.cc b/tensorflow/contrib/lite/kernels/comparisons.cc
index 3b81062cd4..f678f48fa5 100644
--- a/tensorflow/contrib/lite/kernels/comparisons.cc
+++ b/tensorflow/contrib/lite/kernels/comparisons.cc
@@ -23,6 +23,7 @@ namespace tflite {
namespace ops {
namespace builtin {
namespace comparisons {
+namespace {
constexpr int kInputTensor1 = 0;
constexpr int kInputTensor2 = 1;
@@ -67,6 +68,57 @@ TfLiteStatus ComparisonPrepare(TfLiteContext* context, TfLiteNode* node) {
GetTensorData<type>(input2), GetTensorDims(input2), \
GetTensorData<bool>(output), GetTensorDims(output));
+TfLiteStatus EqualEval(TfLiteContext* context, TfLiteNode* node) {
+ const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
+ const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+ bool requires_broadcast = !HaveSameShapes(input1, input2);
+ // TODO(renjieliu): Support quantized data.
+ switch (input1->type) {
+ case kTfLiteFloat32:
+ TF_LITE_COMPARISON(float, Equal, requires_broadcast);
+ break;
+ case kTfLiteInt32:
+ TF_LITE_COMPARISON(int32_t, Equal, requires_broadcast);
+ break;
+ case kTfLiteInt64:
+ TF_LITE_COMPARISON(int64_t, Equal, requires_broadcast);
+ break;
+ default:
+ context->ReportError(context,
+ "Does not support type %d, requires float|int",
+ input1->type);
+ return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
+
+// TODO(renjieliu): Refactor the logic to avoid duplications.
+TfLiteStatus NotEqualEval(TfLiteContext* context, TfLiteNode* node) {
+ const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
+ const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+ bool requires_broadcast = !HaveSameShapes(input1, input2);
+ // TODO(renjieliu): Support quantized data.
+ switch (input1->type) {
+ case kTfLiteFloat32:
+ TF_LITE_COMPARISON(float, NotEqual, requires_broadcast);
+ break;
+ case kTfLiteInt32:
+ TF_LITE_COMPARISON(int32_t, NotEqual, requires_broadcast);
+ break;
+ case kTfLiteInt64:
+ TF_LITE_COMPARISON(int64_t, NotEqual, requires_broadcast);
+ break;
+ default:
+ context->ReportError(context,
+ "Does not support type %d, requires float|int",
+ input1->type);
+ return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
+
TfLiteStatus GreaterEval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
@@ -167,8 +219,22 @@ TfLiteStatus LessEqualEval(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
+} // namespace
} // namespace comparisons
+TfLiteRegistration* Register_EQUAL() {
+ static TfLiteRegistration r = {
+ nullptr, nullptr, comparisons::ComparisonPrepare, comparisons::EqualEval};
+ return &r;
+}
+
+TfLiteRegistration* Register_NOT_EQUAL() {
+ static TfLiteRegistration r = {nullptr, nullptr,
+ comparisons::ComparisonPrepare,
+ comparisons::NotEqualEval};
+ return &r;
+}
+
TfLiteRegistration* Register_GREATER() {
static TfLiteRegistration r = {nullptr, nullptr,
comparisons::ComparisonPrepare,
diff --git a/tensorflow/contrib/lite/kernels/comparisons_test.cc b/tensorflow/contrib/lite/kernels/comparisons_test.cc
index 835d238d36..bb02e1c812 100644
--- a/tensorflow/contrib/lite/kernels/comparisons_test.cc
+++ b/tensorflow/contrib/lite/kernels/comparisons_test.cc
@@ -21,18 +21,17 @@ limitations under the License.
namespace tflite {
namespace {
-using ::testing::ElementsAreArray;
+using ::testing::ElementsAre;
-class GreaterOpModel : public SingleOpModel {
+class ComparisonOpModel : public SingleOpModel {
public:
- GreaterOpModel(std::initializer_list<int> input1_shape,
- std::initializer_list<int> input2_shape,
- TensorType input_type) {
+ ComparisonOpModel(std::initializer_list<int> input1_shape,
+ std::initializer_list<int> input2_shape,
+ TensorType input_type, BuiltinOperator op) {
input1_ = AddInput(input_type);
input2_ = AddInput(input_type);
output_ = AddOutput(TensorType_BOOL);
- SetBuiltinOp(BuiltinOperator_GREATER, BuiltinOptions_GreaterOptions,
- CreateGreaterOptions(builder_).Union());
+ ConfigureBuiltinOp(op);
BuildInterpreter({input1_shape, input2_shape});
}
@@ -46,245 +45,313 @@ class GreaterOpModel : public SingleOpModel {
int input1_;
int input2_;
int output_;
+
+ void ConfigureBuiltinOp(BuiltinOperator op) {
+ switch (op) {
+ case BuiltinOperator_EQUAL: {
+ SetBuiltinOp(op, BuiltinOptions_EqualOptions,
+ CreateEqualOptions(builder_).Union());
+ break;
+ }
+ case BuiltinOperator_NOT_EQUAL: {
+ SetBuiltinOp(op, BuiltinOptions_NotEqualOptions,
+ CreateNotEqualOptions(builder_).Union());
+ break;
+ }
+ case BuiltinOperator_GREATER: {
+ SetBuiltinOp(op, BuiltinOptions_GreaterOptions,
+ CreateGreaterOptions(builder_).Union());
+ break;
+ }
+ case BuiltinOperator_GREATER_EQUAL: {
+ SetBuiltinOp(op, BuiltinOptions_GreaterEqualOptions,
+ CreateGreaterEqualOptions(builder_).Union());
+ break;
+ }
+ case BuiltinOperator_LESS: {
+ SetBuiltinOp(op, BuiltinOptions_LessOptions,
+ CreateLessOptions(builder_).Union());
+ break;
+ }
+ case BuiltinOperator_LESS_EQUAL: {
+ SetBuiltinOp(op, BuiltinOptions_LessEqualOptions,
+ CreateLessEqualOptions(builder_).Union());
+ break;
+ }
+ default: { FAIL() << "We shouldn't get here."; }
+ }
+ }
};
-TEST(ComparisonsTest, GreaterFloat) {
- GreaterOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32);
+TEST(ComparisonsTest, EqualFloat) {
+ ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32,
+ BuiltinOperator_EQUAL);
model.PopulateTensor<float>(model.input1(), {0.1, 0.9, 0.7, 0.3});
model.PopulateTensor<float>(model.input2(), {0.1, 0.2, 0.6, 0.5});
model.Invoke();
- EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, true, false}));
- EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+ EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, false, false));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
}
-TEST(ComparisonsTest, GreaterInt) {
- GreaterOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32);
+TEST(ComparisonsTest, EqualInt) {
+ ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32,
+ BuiltinOperator_EQUAL);
model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3});
model.PopulateTensor<int>(model.input2(), {1, 2, 7, 5});
model.Invoke();
- EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, false, false}));
- EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+ EXPECT_THAT(model.GetOutput(), ElementsAre(false, false, true, false));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
}
-TEST(ComparisonsTest, GreaterBroadcast) {
- GreaterOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32);
+TEST(ComparisonsTest, EqualBroadcast) {
+ ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32,
+ BuiltinOperator_EQUAL);
model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3});
model.PopulateTensor<int>(model.input2(), {7});
model.Invoke();
- EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, false, false}));
- EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+ EXPECT_THAT(model.GetOutput(), ElementsAre(false, false, true, false));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
}
-TEST(ComparisonsTest, GreaterBroadcastTwoD) {
- GreaterOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32);
+TEST(ComparisonsTest, EqualBroadcastTwoD) {
+ ComparisonOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32,
+ BuiltinOperator_EQUAL);
model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3, 2, 4, 2, 8});
model.PopulateTensor<int>(model.input2(), {7, 1, 2, 4});
model.Invoke();
- EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, true, false,
- false, true, false, true}));
- EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 4}));
+ EXPECT_THAT(model.GetOutput(), ElementsAre(false, false, false, false, false,
+ false, true, false));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 2, 4));
}
-class GreaterEqualOpModel : public SingleOpModel {
- public:
- GreaterEqualOpModel(std::initializer_list<int> input1_shape,
- std::initializer_list<int> input2_shape,
- TensorType input_type) {
- input1_ = AddInput(input_type);
- input2_ = AddInput(input_type);
- output_ = AddOutput(TensorType_BOOL);
- SetBuiltinOp(BuiltinOperator_GREATER_EQUAL,
- BuiltinOptions_GreaterEqualOptions,
- CreateGreaterEqualOptions(builder_).Union());
- BuildInterpreter({input1_shape, input2_shape});
- }
+TEST(ComparisonsTest, NotEqualFloat) {
+ ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32,
+ BuiltinOperator_NOT_EQUAL);
+ model.PopulateTensor<float>(model.input1(), {0.1, 0.9, 0.7, 0.3});
+ model.PopulateTensor<float>(model.input2(), {0.1, 0.2, 0.6, 0.5});
+ model.Invoke();
- int input1() { return input1_; }
- int input2() { return input2_; }
+ EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, true, true));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
+}
- std::vector<bool> GetOutput() { return ExtractVector<bool>(output_); }
- std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
+TEST(ComparisonsTest, NotEqualInt) {
+ ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32,
+ BuiltinOperator_NOT_EQUAL);
+ model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3});
+ model.PopulateTensor<int>(model.input2(), {1, 2, 7, 5});
+ model.Invoke();
- private:
- int input1_;
- int input2_;
- int output_;
-};
+ EXPECT_THAT(model.GetOutput(), ElementsAre(true, true, false, true));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
+}
+
+TEST(ComparisonsTest, NotEqualBroadcast) {
+ ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32,
+ BuiltinOperator_NOT_EQUAL);
+ model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3});
+ model.PopulateTensor<int>(model.input2(), {7});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAre(true, true, false, true));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
+}
+
+TEST(ComparisonsTest, NotEqualBroadcastTwoD) {
+ ComparisonOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32,
+ BuiltinOperator_NOT_EQUAL);
+ model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3, 2, 4, 2, 8});
+ model.PopulateTensor<int>(model.input2(), {7, 1, 2, 4});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(),
+ ElementsAre(true, true, true, true, true, true, false, true));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 2, 4));
+}
+
+TEST(ComparisonsTest, GreaterFloat) {
+ ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32,
+ BuiltinOperator_GREATER);
+ model.PopulateTensor<float>(model.input1(), {0.1, 0.9, 0.7, 0.3});
+ model.PopulateTensor<float>(model.input2(), {0.1, 0.2, 0.6, 0.5});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, true, false));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
+}
+
+TEST(ComparisonsTest, GreaterInt) {
+ ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32,
+ BuiltinOperator_GREATER);
+ model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3});
+ model.PopulateTensor<int>(model.input2(), {1, 2, 7, 5});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, false, false));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
+}
+
+TEST(ComparisonsTest, GreaterBroadcast) {
+ ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32,
+ BuiltinOperator_GREATER);
+ model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3});
+ model.PopulateTensor<int>(model.input2(), {7});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, false, false));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
+}
+
+TEST(ComparisonsTest, GreaterBroadcastTwoD) {
+ ComparisonOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32,
+ BuiltinOperator_GREATER);
+ model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3, 2, 4, 2, 8});
+ model.PopulateTensor<int>(model.input2(), {7, 1, 2, 4});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(),
+ ElementsAre(false, true, true, false, false, true, false, true));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 2, 4));
+}
TEST(ComparisonsTest, GreaterEqualFloat) {
- GreaterEqualOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32);
+ ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32,
+ BuiltinOperator_GREATER_EQUAL);
model.PopulateTensor<float>(model.input1(), {0.1, 0.9, 0.7, 0.3});
model.PopulateTensor<float>(model.input2(), {0.1, 0.2, 0.6, 0.5});
model.Invoke();
- EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, true, true, false}));
- EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+ EXPECT_THAT(model.GetOutput(), ElementsAre(true, true, true, false));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
}
TEST(ComparisonsTest, GreaterEqualInt) {
- GreaterEqualOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32);
+ ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32,
+ BuiltinOperator_GREATER_EQUAL);
model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3});
model.PopulateTensor<int>(model.input2(), {1, 2, 7, 5});
model.Invoke();
- EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, true, false}));
- EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+ EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, true, false));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
}
TEST(ComparisonsTest, GreaterEqualBroadcast) {
- GreaterEqualOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32);
+ ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32,
+ BuiltinOperator_GREATER_EQUAL);
model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3});
model.PopulateTensor<int>(model.input2(), {7});
model.Invoke();
- EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, true, false}));
- EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+ EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, true, false));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
}
TEST(ComparisonsTest, GreaterEqualBroadcastTwoD) {
- GreaterEqualOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32);
+ ComparisonOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32,
+ BuiltinOperator_GREATER_EQUAL);
model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3, 2, 4, 2, 8});
model.PopulateTensor<int>(model.input2(), {7, 1, 2, 4});
model.Invoke();
- EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, true, false,
- false, true, true, true}));
- EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 4}));
+ EXPECT_THAT(model.GetOutput(),
+ ElementsAre(false, true, true, false, false, true, true, true));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 2, 4));
}
-class LessOpModel : public SingleOpModel {
- public:
- LessOpModel(std::initializer_list<int> input1_shape,
- std::initializer_list<int> input2_shape, TensorType input_type) {
- input1_ = AddInput(input_type);
- input2_ = AddInput(input_type);
- output_ = AddOutput(TensorType_BOOL);
- SetBuiltinOp(BuiltinOperator_LESS, BuiltinOptions_LessOptions,
- CreateLessOptions(builder_).Union());
- BuildInterpreter({input1_shape, input2_shape});
- }
-
- int input1() { return input1_; }
- int input2() { return input2_; }
-
- std::vector<bool> GetOutput() { return ExtractVector<bool>(output_); }
- std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
-
- private:
- int input1_;
- int input2_;
- int output_;
-};
TEST(ComparisonsTest, LessFloat) {
- LessOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32);
+ ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32,
+ BuiltinOperator_LESS);
model.PopulateTensor<float>(model.input1(), {0.1, 0.9, 0.7, 0.3});
model.PopulateTensor<float>(model.input2(), {0.1, 0.2, 0.6, 0.5});
model.Invoke();
- EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, false, false, true}));
- EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+ EXPECT_THAT(model.GetOutput(), ElementsAre(false, false, false, true));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
}
TEST(ComparisonsTest, LessInt) {
- LessOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32);
+ ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32,
+ BuiltinOperator_LESS);
model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3});
model.PopulateTensor<int>(model.input2(), {1, 2, 6, 5});
model.Invoke();
- EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, false, true}));
- EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+ EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, false, true));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
}
TEST(ComparisonsTest, LessBroadcast) {
- LessOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32);
+ ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32,
+ BuiltinOperator_LESS);
model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3});
model.PopulateTensor<int>(model.input2(), {7});
model.Invoke();
- EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, false, true}));
- EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+ EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, false, true));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
}
TEST(ComparisonsTest, LessBroadcastTwoD) {
- LessOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32);
+ ComparisonOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32,
+ BuiltinOperator_LESS);
model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3, 2, 4, 6, 8});
model.PopulateTensor<int>(model.input2(), {7, 1, 2, 4});
model.Invoke();
- EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, false, true,
- true, false, false, false}));
- EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 4}));
+ EXPECT_THAT(model.GetOutput(),
+ ElementsAre(true, false, false, true, true, false, false, false));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 2, 4));
}
-class LessEqualOpModel : public SingleOpModel {
- public:
- LessEqualOpModel(std::initializer_list<int> input1_shape,
- std::initializer_list<int> input2_shape,
- TensorType input_type) {
- input1_ = AddInput(input_type);
- input2_ = AddInput(input_type);
- output_ = AddOutput(TensorType_BOOL);
- SetBuiltinOp(BuiltinOperator_LESS_EQUAL, BuiltinOptions_LessEqualOptions,
- CreateLessEqualOptions(builder_).Union());
- BuildInterpreter({input1_shape, input2_shape});
- }
-
- int input1() { return input1_; }
- int input2() { return input2_; }
-
- std::vector<bool> GetOutput() { return ExtractVector<bool>(output_); }
- std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
-
- private:
- int input1_;
- int input2_;
- int output_;
-};
-
TEST(ComparisonsTest, LessEqualFloat) {
- LessEqualOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32);
+ ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32,
+ BuiltinOperator_LESS_EQUAL);
model.PopulateTensor<float>(model.input1(), {0.1, 0.9, 0.7, 0.3});
model.PopulateTensor<float>(model.input2(), {0.1, 0.2, 0.6, 0.5});
model.Invoke();
- EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, false, true}));
- EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+ EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, false, true));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
}
TEST(ComparisonsTest, LessEqualInt) {
- LessEqualOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32);
+ ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32,
+ BuiltinOperator_LESS_EQUAL);
model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3});
model.PopulateTensor<int>(model.input2(), {1, 2, 7, 5});
model.Invoke();
- EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, true, true}));
- EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+ EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, true, true));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
}
TEST(ComparisonsTest, LessEqualBroadcast) {
- LessEqualOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32);
+ ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32,
+ BuiltinOperator_LESS_EQUAL);
model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3});
model.PopulateTensor<int>(model.input2(), {7});
model.Invoke();
- EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, true, true}));
- EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+ EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, true, true));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
}
TEST(ComparisonsTest, LessEqualBroadcastTwoD) {
- LessEqualOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32);
+ ComparisonOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32,
+ BuiltinOperator_LESS_EQUAL);
model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3, 2, 4, 2, 8});
model.PopulateTensor<int>(model.input2(), {7, 1, 2, 4});
model.Invoke();
- EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, false, true,
- true, false, true, false}));
- EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 4}));
+ EXPECT_THAT(model.GetOutput(),
+ ElementsAre(true, false, false, true, true, false, true, false));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 2, 4));
}
} // namespace
diff --git a/tensorflow/contrib/lite/kernels/conv.cc b/tensorflow/contrib/lite/kernels/conv.cc
index ee42e5cdc8..747c8a62c0 100644
--- a/tensorflow/contrib/lite/kernels/conv.cc
+++ b/tensorflow/contrib/lite/kernels/conv.cc
@@ -134,7 +134,9 @@ static TfLiteStatus AllocateTemporaryTensorsIfRequired(TfLiteContext* context,
// optimized_ops.h, in order to avoid a DCHECK(!im2col_data).
data->need_im2col =
(params->stride_width != 1 || params->stride_height != 1 ||
- filter_width != 1 || filter_height != 1);
+ params->dilation_width_factor != 1 ||
+ params->dilation_height_factor != 1 || filter_width != 1 ||
+ filter_height != 1);
// If we're using the optimized multithreaded EigenTensor implementation of
// convolution, it expects the filter weights to be transposed compared to
// the normal TF Lite buffer format. Typical TF Lite weights are
diff --git a/tensorflow/contrib/lite/kernels/elementwise.cc b/tensorflow/contrib/lite/kernels/elementwise.cc
index 0bd5046950..98c21ce9d3 100644
--- a/tensorflow/contrib/lite/kernels/elementwise.cc
+++ b/tensorflow/contrib/lite/kernels/elementwise.cc
@@ -23,7 +23,7 @@ namespace ops {
namespace builtin {
namespace elementwise {
-TfLiteStatus SinPrepare(TfLiteContext* context, TfLiteNode* node) {
+TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input = GetInput(context, node, 0);
@@ -35,7 +35,8 @@ TfLiteStatus SinPrepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteIntArrayCopy(input->dims));
}
-TfLiteStatus SinEval(TfLiteContext* context, TfLiteNode* node) {
+inline TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node,
+ float float_func(float)) {
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
switch (input->type) {
@@ -44,7 +45,7 @@ TfLiteStatus SinEval(TfLiteContext* context, TfLiteNode* node) {
const float* in = GetTensorData<float>(input);
const float* in_end = in + elements;
float* out = output->data.f;
- for (; in < in_end; in++, out++) *out = std::sin(*in);
+ for (; in < in_end; in++, out++) *out = float_func(*in);
return kTfLiteOk;
}
default: {
@@ -55,14 +56,28 @@ TfLiteStatus SinEval(TfLiteContext* context, TfLiteNode* node) {
}
}
+TfLiteStatus SinEval(TfLiteContext* context, TfLiteNode* node) {
+ return Eval(context, node, std::sin);
+}
+
+TfLiteStatus LogEval(TfLiteContext* context, TfLiteNode* node) {
+ return Eval(context, node, std::log);
+}
+
} // namespace elementwise
TfLiteRegistration* Register_SIN() {
- static TfLiteRegistration r = {nullptr, nullptr, elementwise::SinPrepare,
+ static TfLiteRegistration r = {nullptr, nullptr, elementwise::GenericPrepare,
elementwise::SinEval};
return &r;
}
+TfLiteRegistration* Register_LOG() {
+ static TfLiteRegistration r = {nullptr, nullptr, elementwise::GenericPrepare,
+ elementwise::LogEval};
+ return &r;
+}
+
} // namespace builtin
} // namespace ops
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/elementwise_test.cc b/tensorflow/contrib/lite/kernels/elementwise_test.cc
index 412ffb04b9..10e88d5a31 100644
--- a/tensorflow/contrib/lite/kernels/elementwise_test.cc
+++ b/tensorflow/contrib/lite/kernels/elementwise_test.cc
@@ -24,12 +24,13 @@ namespace {
using ::testing::ElementsAreArray;
-class SinOpModel : public SingleOpModel {
+class ElementWiseOpModel : public SingleOpModel {
public:
- SinOpModel(std::initializer_list<int> input_shape) {
+ ElementWiseOpModel(BuiltinOperator op,
+ std::initializer_list<int> input_shape) {
input_ = AddInput(TensorType_FLOAT32);
output_ = AddOutput(TensorType_FLOAT32);
- SetBuiltinOp(BuiltinOperator_SIN, BuiltinOptions_NONE, 0);
+ SetBuiltinOp(op, BuiltinOptions_NONE, 0);
BuildInterpreter({input_shape});
}
@@ -42,7 +43,7 @@ class SinOpModel : public SingleOpModel {
};
TEST(ElementWise, Sin) {
- SinOpModel m({1, 1, 4, 1});
+ ElementWiseOpModel m(BuiltinOperator_SIN, {1, 1, 4, 1});
m.PopulateTensor<float>(m.input(), {0, 3.1415926, -3.1415926, 1});
m.Invoke();
EXPECT_THAT(m.ExtractVector<float>(m.output()),
@@ -50,6 +51,15 @@ TEST(ElementWise, Sin) {
EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 1, 4, 1}));
}
+TEST(ElementWise, Log) {
+ ElementWiseOpModel m(BuiltinOperator_LOG, {1, 1, 4, 1});
+ m.PopulateTensor<float>(m.input(), {1, 3.1415926, 1, 1});
+ m.Invoke();
+ EXPECT_THAT(m.ExtractVector<float>(m.output()),
+ ElementsAreArray(ArrayFloatNear({0, 1.14473, 0, 0})));
+ EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 1, 4, 1}));
+}
+
} // namespace
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/embedding_lookup.cc b/tensorflow/contrib/lite/kernels/embedding_lookup.cc
index 7539c0b30d..9410bead5e 100644
--- a/tensorflow/contrib/lite/kernels/embedding_lookup.cc
+++ b/tensorflow/contrib/lite/kernels/embedding_lookup.cc
@@ -24,7 +24,8 @@ limitations under the License.
// Output:
// Output.dim[0] == Tensor[0].dim[0], num of lookups
// Output.dim[1] == Tensor[1].dim[1], num of items per row
-// Each item in output is a raw bytes copy of corresponding item in input.
+// Each item in output is a raw bytes copy of the corresponding item in input,
+// or a dequantized value in the case of a uint8 input.
// When indices are out of bound, the ops will not succeed.
//
@@ -69,11 +70,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
return context->ResizeTensor(context, output, outputSize);
}
-TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
- TfLiteTensor* output = GetOutput(context, node, 0);
- const TfLiteTensor* lookup = GetInput(context, node, 0);
- const TfLiteTensor* value = GetInput(context, node, 1);
-
+TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node,
+ const TfLiteTensor* lookup, const TfLiteTensor* value,
+ TfLiteTensor* output) {
const int row_size = SizeOfDimension(value, 0);
const int row_bytes = value->bytes / row_size;
@@ -91,6 +90,52 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
+TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node,
+ const TfLiteTensor* lookup, const TfLiteTensor* value,
+ TfLiteTensor* output) {
+ const int row_size = SizeOfDimension(value, 0);
+ const double scaling_factor = 1.0 / value->params.scale;
+
+ // col_size after we flatten tensor into 2D.
+ int col_size = 1;
+ for (int i = 1; i < NumDimensions(value); i++) {
+ col_size *= SizeOfDimension(value, i);
+ }
+
+ for (int i = 0; i < SizeOfDimension(lookup, 0); i++) {
+ int idx = lookup->data.i32[i];
+ if (idx >= row_size || idx < 0) {
+ context->ReportError(context, "Embedding Lookup: index out of bounds.");
+ return kTfLiteError;
+ } else {
+ // Dequantize embedding values.
+ // TODO(alanchiao): refactor scalar multiply into separate function
+ // for ease of adding a neon equivalent if ever necessary.
+ for (int j = 0; j < col_size; j++) {
+ output->data.f[j + i * col_size] =
+ value->data.uint8[j + idx * col_size] * scaling_factor;
+ }
+ }
+ }
+
+ return kTfLiteOk;
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ const TfLiteTensor* lookup = GetInput(context, node, 0);
+ const TfLiteTensor* value = GetInput(context, node, 1);
+ TfLiteTensor* output = GetOutput(context, node, 0);
+ switch (value->type) {
+ case kTfLiteFloat32:
+ return EvalFloat(context, node, lookup, value, output);
+ case kTfLiteUInt8:
+ return EvalHybrid(context, node, lookup, value, output);
+ default:
+ context->ReportError(context, "Type not currently supported.");
+ return kTfLiteError;
+ }
+}
+
} // namespace embedding_lookup
TfLiteRegistration* Register_EMBEDDING_LOOKUP() {
diff --git a/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc b/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc
index 9b501878f1..04657fd863 100644
--- a/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc
+++ b/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc
@@ -7,13 +7,14 @@ You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
+distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License
+for the specific language governing permissions and limitations under the
+License.
==============================================================================*/
// Unit test for TFLite Lookup op.
+#include <initializer_list>
#include <iomanip>
#include <vector>
@@ -29,12 +30,13 @@ namespace {
using ::testing::ElementsAreArray;
-class EmbeddingLookupOpModel : public SingleOpModel {
+class BaseEmbeddingLookupOpModel : public SingleOpModel {
public:
- EmbeddingLookupOpModel(std::initializer_list<int> index_shape,
- std::initializer_list<int> weight_shape) {
+ BaseEmbeddingLookupOpModel(std::initializer_list<int> index_shape,
+ std::initializer_list<int> weight_shape,
+ TensorType weight_type = TensorType_FLOAT32) {
input_ = AddInput(TensorType_INT32);
- weight_ = AddInput(TensorType_FLOAT32);
+ weight_ = AddInput(weight_type);
output_ = AddOutput(TensorType_FLOAT32);
SetBuiltinOp(BuiltinOperator_EMBEDDING_LOOKUP, BuiltinOptions_NONE, 0);
BuildInterpreter({index_shape, weight_shape});
@@ -44,6 +46,18 @@ class EmbeddingLookupOpModel : public SingleOpModel {
PopulateTensor(input_, data);
}
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+
+ protected:
+ int input_;
+ int weight_;
+ int output_;
+};
+
+class EmbeddingLookupOpModel : public BaseEmbeddingLookupOpModel {
+ public:
+ using BaseEmbeddingLookupOpModel::BaseEmbeddingLookupOpModel;
+
void Set3DWeightMatrix(const std::function<float(int, int, int)>& function) {
TfLiteTensor* tensor = interpreter_->tensor(weight_);
int rows = tensor->dims->data[0];
@@ -57,20 +71,25 @@ class EmbeddingLookupOpModel : public SingleOpModel {
}
}
}
+};
- std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+class HybridEmbeddingLookupOpModel : public BaseEmbeddingLookupOpModel {
+ public:
+ HybridEmbeddingLookupOpModel(std::initializer_list<int> index_shape,
+ std::initializer_list<int> weight_shape)
+ : BaseEmbeddingLookupOpModel(index_shape, weight_shape,
+ TensorType_UINT8) {}
- private:
- int input_;
- int weight_;
- int output_;
+ void SetWeight(std::initializer_list<float> data) {
+ SymmetricQuantizeAndPopulate(weight_, data);
+ }
};
// TODO(ahentz): write more tests that exercise the details of the op, such as
// lookup errors and variable input shapes.
TEST(EmbeddingLookupOpTest, SimpleTest) {
EmbeddingLookupOpModel m({3}, {3, 2, 4});
- m.PopulateTensor<int>(0, {1, 0, 2});
+ m.SetInput({1, 0, 2});
m.Set3DWeightMatrix(
[](int i, int j, int k) { return i + j / 10.0f + k / 100.0f; });
@@ -84,6 +103,69 @@ TEST(EmbeddingLookupOpTest, SimpleTest) {
})));
}
+TEST(HybridEmbeddingLookupHybridOpTest, Simple2DTest) {
+ HybridEmbeddingLookupOpModel m({3}, {3, 8});
+ m.SetInput({1, 0, 2});
+ m.SetWeight({
+ 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0
+ 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1
+ 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2
+ });
+
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutput(),
+ ElementsAreArray(ArrayFloatNear(
+ {
+ 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1
+ 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0
+ 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2
+ },
+ 7.41e-03)));
+}
+
+TEST(HybridEmbeddingLookupHybridOpTest, Simple3DTest) {
+ HybridEmbeddingLookupOpModel m({3}, {3, 2, 4});
+ m.SetInput({1, 0, 2});
+ m.SetWeight({
+ 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0
+ 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1
+ 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2
+ });
+
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutput(),
+ ElementsAreArray(ArrayFloatNear(
+ {
+ 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1
+ 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0
+ 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2
+ },
+ 7.41e-03)));
+}
+
+TEST(HybridEmbeddingLookupHybridOpTest, Simple4DTest) {
+ HybridEmbeddingLookupOpModel m({3}, {3, 2, 2, 2});
+ m.SetInput({1, 0, 2});
+ m.SetWeight({
+ 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0
+ 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1
+ 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2
+ });
+
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutput(),
+ ElementsAreArray(ArrayFloatNear(
+ {
+ 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1
+ 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0
+ 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2
+ },
+ 7.41e-03)));
+}
+
} // namespace
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/fully_connected.cc b/tensorflow/contrib/lite/kernels/fully_connected.cc
index 989920622d..5a0524bec6 100644
--- a/tensorflow/contrib/lite/kernels/fully_connected.cc
+++ b/tensorflow/contrib/lite/kernels/fully_connected.cc
@@ -105,7 +105,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
const int batch_size = input_size / filter->dims->data[1];
const int num_units = filter->dims->data[0];
- TF_LITE_ASSERT_EQ(input_size, batch_size * filter->dims->data[1]);
+ TF_LITE_ENSURE_EQ(context, input_size, batch_size * filter->dims->data[1]);
if (bias) {
TF_LITE_ENSURE_EQ(context, NumElements(bias), SizeOfDimension(filter, 0));
}
diff --git a/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc b/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc
index 6e62183975..36c25388e8 100644
--- a/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc
+++ b/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc
@@ -350,7 +350,7 @@ void LstmStep(
for (int b = 0; b < n_batch; ++b) {
product_scaling_factors[b] =
- scaling_factors[b] * input_to_cell_weights_scale;
+ scaling_factors[b] * input_to_output_weights_scale;
}
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
input_to_output_weights_ptr, n_cell, n_input, quantized_input_ptr_batch,
@@ -409,7 +409,7 @@ void LstmStep(
}
// Save quantization and matmul computation for all zero input.
- const bool is_cell_state_all_zeros =
+ bool is_cell_state_all_zeros =
tensor_utils::IsZeroVector(cell_state_ptr, n_batch * n_cell);
// For each batch and cell: update input gate.
@@ -455,6 +455,8 @@ void LstmStep(
params->cell_clip, cell_state_ptr);
}
+ is_cell_state_all_zeros =
+ tensor_utils::IsZeroVector(cell_state_ptr, n_batch * n_cell);
// For each batch and cell: update the output gate.
if (use_peephole && !is_cell_state_all_zeros) {
VectorMultiply(cell_to_output_weights_ptr, n_cell,
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
index 0ce781db59..d2bee2cd70 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
@@ -6289,8 +6289,8 @@ inline void TransposeConv(const float* input_data, const Dims<4>& input_dims,
// To optimize, start by using the conv code with transposed weights for the
// case of stride_height = stride_width = 1.
const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 3);
- const int output_depth = MatchingArraySize(filter_dims, 0, output_dims, 0);
+ const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 0);
+ const int output_depth = MatchingArraySize(filter_dims, 3, output_dims, 0);
const int input_height = ArraySize(input_dims, 2);
const int input_width = ArraySize(input_dims, 1);
const int filter_height = ArraySize(filter_dims, 2);
@@ -6337,8 +6337,8 @@ inline void TransposeConv(const float* input_data, const Dims<4>& input_dims,
float input_value = input_data[Offset(input_dims, in_channel,
in_x, in_y, batch)];
float filter_value =
- filter_data[Offset(filter_dims, out_channel, filter_x,
- filter_y, in_channel)];
+ filter_data[Offset(filter_dims, in_channel, filter_x,
+ filter_y, out_channel)];
output_data[Offset(output_dims, out_channel, out_x, out_y,
batch)] += input_value * filter_value;
}
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
index 4781bbc70a..5bd518f6d6 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
@@ -3810,8 +3810,8 @@ inline void TransposeConv(const float* input_data, const Dims<4>& input_dims,
int pad_height, float* output_data,
const Dims<4>& output_dims) {
const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 3);
- const int output_depth = MatchingArraySize(filter_dims, 0, output_dims, 0);
+ const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 0);
+ const int output_depth = MatchingArraySize(filter_dims, 3, output_dims, 0);
const int input_height = ArraySize(input_dims, 2);
const int input_width = ArraySize(input_dims, 1);
const int filter_height = ArraySize(filter_dims, 2);
@@ -3851,8 +3851,8 @@ inline void TransposeConv(const float* input_data, const Dims<4>& input_dims,
float input_value = input_data[Offset(input_dims, in_channel,
in_x, in_y, batch)];
float filter_value =
- filter_data[Offset(filter_dims, out_channel, filter_x,
- filter_y, in_channel)];
+ filter_data[Offset(filter_dims, in_channel, filter_x,
+ filter_y, out_channel)];
output_data[Offset(output_dims, out_channel, out_x, out_y,
batch)] += input_value * filter_value;
}
@@ -3866,6 +3866,16 @@ inline void TransposeConv(const float* input_data, const Dims<4>& input_dims,
}
template <typename T>
+inline bool EqualFn(T lhs, T rhs) {
+ return lhs == rhs;
+}
+
+template <typename T>
+inline bool NotEqualFn(T lhs, T rhs) {
+ return lhs != rhs;
+}
+
+template <typename T>
inline bool GreaterFn(T lhs, T rhs) {
return lhs > rhs;
}
@@ -4028,6 +4038,8 @@ inline void BroadcastComparison(int left_shift, const T* input1_data,
input2_offset, input2_multiplier, \
input2_shift, output_data, output_dims); \
}
+TFLITE_COMPARISON_OP(Equal);
+TFLITE_COMPARISON_OP(NotEqual);
TFLITE_COMPARISON_OP(Greater);
TFLITE_COMPARISON_OP(GreaterEqual);
TFLITE_COMPARISON_OP(Less);
diff --git a/tensorflow/contrib/lite/kernels/internal/types.h b/tensorflow/contrib/lite/kernels/internal/types.h
index 0c7fb7a76a..1086c5b092 100644
--- a/tensorflow/contrib/lite/kernels/internal/types.h
+++ b/tensorflow/contrib/lite/kernels/internal/types.h
@@ -142,6 +142,22 @@ class RuntimeShape {
};
};
+// Converts inference-style shape to legacy tflite::Dims<4>.
+inline tflite::Dims<4> ToRuntimeDims(const tflite::RuntimeShape& array_shape) {
+ tflite::Dims<4> result;
+ const int dimensions_count = array_shape.DimensionsCount();
+ TFLITE_CHECK_LE(dimensions_count, 4);
+ int cum_prod = 1;
+ for (int i = 0; i < 4; i++) {
+ const int new_dim =
+ (i < dimensions_count) ? array_shape.Dims(dimensions_count - 1 - i) : 1;
+ result.sizes[i] = new_dim;
+ result.strides[i] = cum_prod;
+ cum_prod *= new_dim;
+ }
+ return result;
+}
+
// Gets next index to iterate through a multidimensional array.
inline bool NextIndex(const int num_dims, const int* dims, int* current) {
TFLITE_DCHECK_GT(num_dims, 0);
diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc
index 184b02dcec..7bb28d4de7 100644
--- a/tensorflow/contrib/lite/kernels/register.cc
+++ b/tensorflow/contrib/lite/kernels/register.cc
@@ -73,6 +73,7 @@ TfLiteRegistration* Register_SQUEEZE();
TfLiteRegistration* Register_STRIDED_SLICE();
TfLiteRegistration* Register_EXP();
TfLiteRegistration* Register_TOPK_V2();
+TfLiteRegistration* Register_LOG();
TfLiteRegistration* Register_LOG_SOFTMAX();
TfLiteRegistration* Register_CAST();
TfLiteRegistration* Register_DEQUANTIZE();
@@ -93,6 +94,8 @@ TfLiteRegistration* Register_SIN();
TfLiteRegistration* Register_TRANSPOSE_CONV();
TfLiteRegistration* Register_EXPAND_DIMS();
TfLiteRegistration* Register_SPARSE_TO_DENSE();
+TfLiteRegistration* Register_EQUAL();
+TfLiteRegistration* Register_NOT_EQUAL();
BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_RELU, Register_RELU());
@@ -148,6 +151,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_STRIDED_SLICE, Register_STRIDED_SLICE());
AddBuiltin(BuiltinOperator_EXP, Register_EXP());
AddBuiltin(BuiltinOperator_TOPK_V2, Register_TOPK_V2());
+ AddBuiltin(BuiltinOperator_LOG, Register_LOG());
AddBuiltin(BuiltinOperator_LOG_SOFTMAX, Register_LOG_SOFTMAX());
AddBuiltin(BuiltinOperator_CAST, Register_CAST());
AddBuiltin(BuiltinOperator_DEQUANTIZE, Register_DEQUANTIZE());
@@ -168,6 +172,8 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_TILE, Register_TILE());
AddBuiltin(BuiltinOperator_EXPAND_DIMS, Register_EXPAND_DIMS());
AddBuiltin(BuiltinOperator_SPARSE_TO_DENSE, Register_SPARSE_TO_DENSE());
+ AddBuiltin(BuiltinOperator_EQUAL, Register_EQUAL());
+ AddBuiltin(BuiltinOperator_NOT_EQUAL, Register_NOT_EQUAL());
// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
// custom ops aren't always included by default.
diff --git a/tensorflow/contrib/lite/kernels/transpose_conv.cc b/tensorflow/contrib/lite/kernels/transpose_conv.cc
index 3c99661029..e83b1ec987 100644
--- a/tensorflow/contrib/lite/kernels/transpose_conv.cc
+++ b/tensorflow/contrib/lite/kernels/transpose_conv.cc
@@ -79,7 +79,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Ensure that weights and inputs have the same channel dimension.
// Note: TOCO will reorder weights in the following format: OHWI.
TF_LITE_ENSURE_EQ(context, SizeOfDimension(input, 3),
- SizeOfDimension(weights, 0));
+ SizeOfDimension(weights, 3));
if (!IsConstantTensor(output_shape)) {
SetTensorToDynamic(output);
diff --git a/tensorflow/contrib/lite/kernels/transpose_conv_test.cc b/tensorflow/contrib/lite/kernels/transpose_conv_test.cc
index 52be089349..55df897180 100644
--- a/tensorflow/contrib/lite/kernels/transpose_conv_test.cc
+++ b/tensorflow/contrib/lite/kernels/transpose_conv_test.cc
@@ -88,10 +88,10 @@ TEST(TransposeConvOpModelTest, SimpleTest) {
// And filter value is derived by:
// filter = tf.reshape(tf.transpose(filter, perm=[3, 0, 1, 2]), shape=[18, 1])
TEST(TransposeConvOpModelTest, TwoFiltersTest) {
- TransposeConvOpModel m({1, 4, 4, 2}, {2, 3, 3, 1}, Padding_SAME, 1, 1);
+ TransposeConvOpModel m({1, 4, 4, 2}, {1, 3, 3, 2}, Padding_SAME, 1, 1);
m.PopulateTensor<int>(m.output_shape(), {1, 4, 4, 1});
- m.PopulateTensor<float>(m.filter(), {1, 3, 5, 7, 9, 11, 13, 15, 17, 2, 4, 6,
- 8, 10, 12, 14, 16, 18});
+ m.PopulateTensor<float>(m.filter(), {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
+ 13, 14, 15, 16, 17, 18});
m.PopulateTensor<float>(
m.input(),
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
@@ -117,10 +117,10 @@ TEST(TransposeConvOpModelTest, TwoFiltersTest) {
// And filter value is derived by:
// filter = tf.reshape(tf.transpose(filter, perm=[3, 0, 1, 2]), shape=[1, 18])
TEST(TransposeConvOpModelTest, PaddingValidTest) {
- TransposeConvOpModel m({1, 4, 4, 2}, {2, 3, 3, 1}, Padding_VALID, 1, 1);
+ TransposeConvOpModel m({1, 4, 4, 2}, {1, 3, 3, 2}, Padding_VALID, 1, 1);
m.PopulateTensor<int>(m.output_shape(), {1, 6, 6, 1});
- m.PopulateTensor<float>(m.filter(), {1, 3, 5, 7, 9, 11, 13, 15, 17, 2, 4, 6,
- 8, 10, 12, 14, 16, 18});
+ m.PopulateTensor<float>(m.filter(), {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
+ 13, 14, 15, 16, 17, 18});
m.PopulateTensor<float>(
m.input(),
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
@@ -171,10 +171,10 @@ TEST(TransposeConvOpModelTest, StrideValidTest) {
// [1, 2, 2, 1 ],
// "VALID")
TEST(TransposeConvOpModelTest, MultiChannelTest) {
- TransposeConvOpModel m({1, 2, 2, 1}, {1, 3, 3, 2}, Padding_VALID, 2, 2);
+ TransposeConvOpModel m({1, 2, 2, 1}, {2, 3, 3, 1}, Padding_VALID, 2, 2);
m.PopulateTensor<int>(m.output_shape(), {1, 5, 5, 2});
- m.PopulateTensor<float>(m.filter(), {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
- 13, 14, 15, 16, 17, 18});
+ m.PopulateTensor<float>(m.filter(), {1, 3, 5, 7, 9, 11, 13, 15, 17, 2, 4, 6,
+ 8, 10, 12, 14, 16, 18});
m.PopulateTensor<float>(m.input(), {1, 2, 3, 4});
m.Invoke();
diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc
index 8d8d74adfb..4fb1ada9fd 100644
--- a/tensorflow/contrib/lite/model.cc
+++ b/tensorflow/contrib/lite/model.cc
@@ -357,6 +357,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
case BuiltinOperator_FLOOR:
case BuiltinOperator_NEG:
case BuiltinOperator_SIN:
+ case BuiltinOperator_LOG:
break;
case BuiltinOperator_CAST: {
TfLiteCastParams* params = MallocPOD<TfLiteCastParams>();
@@ -689,6 +690,8 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
case BuiltinOperator_GREATER_EQUAL:
case BuiltinOperator_LESS:
case BuiltinOperator_LESS_EQUAL:
+ case BuiltinOperator_EQUAL:
+ case BuiltinOperator_NOT_EQUAL:
case BuiltinOperator_SELECT: {
break;
}
diff --git a/tensorflow/contrib/lite/nnapi_delegate.cc b/tensorflow/contrib/lite/nnapi_delegate.cc
index d27ab0c033..99cb40e967 100644
--- a/tensorflow/contrib/lite/nnapi_delegate.cc
+++ b/tensorflow/contrib/lite/nnapi_delegate.cc
@@ -490,10 +490,13 @@ void AddOpsAndParams(tflite::Interpreter* interpreter,
case tflite::BuiltinOperator_SELECT:
case tflite::BuiltinOperator_SLICE:
case tflite::BuiltinOperator_SIN:
+ case tflite::BuiltinOperator_LOG:
case tflite::BuiltinOperator_TRANSPOSE_CONV:
case tflite::BuiltinOperator_TILE:
case tflite::BuiltinOperator_EXPAND_DIMS:
case tflite::BuiltinOperator_SPARSE_TO_DENSE:
+ case tflite::BuiltinOperator_EQUAL:
+ case tflite::BuiltinOperator_NOT_EQUAL:
FATAL("Op code %d is currently not delegated to NNAPI", builtin);
nn_op_type = -1; // set to invalid
break;
diff --git a/tensorflow/contrib/lite/profiling/profile_summarizer.cc b/tensorflow/contrib/lite/profiling/profile_summarizer.cc
index 6f2c9cd2b3..45388b500c 100644
--- a/tensorflow/contrib/lite/profiling/profile_summarizer.cc
+++ b/tensorflow/contrib/lite/profiling/profile_summarizer.cc
@@ -85,11 +85,18 @@ OperatorDetails GetOperatorDetails(const tflite::Interpreter& interpreter,
return details;
}
+tensorflow::StatSummarizerOptions GetProfileSummarizerOptions() {
+ auto options = tensorflow::StatSummarizerOptions();
+ options.show_summary = true;
+ options.show_memory = false;
+ return options;
+}
+
} // namespace
ProfileSummarizer::ProfileSummarizer()
- : stats_calculator_(new ::tensorflow::StatsCalculator(
- tensorflow::StatSummarizerOptions())) {}
+ : stats_calculator_(
+ new ::tensorflow::StatsCalculator(GetProfileSummarizerOptions())) {}
void ProfileSummarizer::ProcessProfiles(
const std::vector<const ProfileEvent*>& profile_stats,
diff --git a/tensorflow/contrib/lite/python/BUILD b/tensorflow/contrib/lite/python/BUILD
index 7e6ff6c0a8..27909a9458 100644
--- a/tensorflow/contrib/lite/python/BUILD
+++ b/tensorflow/contrib/lite/python/BUILD
@@ -57,8 +57,9 @@ py_library(
":interpreter",
":lite_constants",
":op_hint",
- "//tensorflow/contrib/saved_model:saved_model_py",
"//tensorflow/python:graph_util",
+ "//tensorflow/python/saved_model:constants",
+ "//tensorflow/python/saved_model:loader",
"//tensorflow/python/tools:freeze_graph_lib",
],
)
diff --git a/tensorflow/contrib/lite/python/convert.py b/tensorflow/contrib/lite/python/convert.py
index 0819475240..fce8ffb54a 100644
--- a/tensorflow/contrib/lite/python/convert.py
+++ b/tensorflow/contrib/lite/python/convert.py
@@ -123,7 +123,10 @@ def toco_convert(input_data,
drop_control_dependency=True,
reorder_across_fake_quant=False,
allow_custom_ops=False,
- change_concat_input_ranges=False):
+ change_concat_input_ranges=False,
+ quantize_weights=False,
+ dump_graphviz_dir=None,
+ dump_graphviz_video=False):
"""Convert a model using TOCO from `input_format` to `output_format`.
Typically this is to convert from TensorFlow GraphDef to TFLite, in which
@@ -143,10 +146,9 @@ def toco_convert(input_data,
`{TENSORFLOW_GRAPHDEF}`. (default TENSORFLOW_GRAPHDEF)
output_format: Output file format. Currently must be `{TFLITE,
GRAPHVIZ_DOT}`. (default TFLITE)
- quantized_input_stats: Dict of strings representing input tensor names
- mapped to tuple of integers representing the mean and standard deviation
- of the training data (e.g., {"foo" : (0., 1.)}). Only need if
- `inference_type` is `QUANTIZED_UINT8`. (default None)
+ quantized_input_stats: List of tuples of integers representing the mean and
+ standard deviation. Each tuple maps to the corresponding input tensor.
+ Only needed if `inference_type` is `QUANTIZED_UINT8`. (default None)
default_ranges_stats: Tuple of integers representing (min, max) range values
for all arrays without a specified range. Intended for experimenting with
quantization via "dummy quantization". (default None)
@@ -158,14 +160,24 @@ def toco_convert(input_data,
nodes is preventing graph transformations necessary to convert the graph.
Results in a graph that differs from the quantized training graph,
potentially causing differing arithmetic behavior. (default False)
- change_concat_input_ranges: Boolean to change behavior of min/max ranges for
- inputs and outputs of the concat operator for quantized models. Changes
- the ranges of concat operator overlap when true. (default False)
allow_custom_ops: Boolean indicating whether to allow custom operations.
When false any unknown operation is an error. When true, custom ops are
created for any op that is unknown. The developer will need to provide
these to the TensorFlow Lite runtime with a custom resolver.
(default False)
+ change_concat_input_ranges: Boolean to change behavior of min/max ranges for
+ inputs and outputs of the concat operator for quantized models. Changes
+ the ranges of concat operator overlap when true. (default False)
+ quantize_weights: Boolean indicating whether to store weights as quantized
+ weights followed by dequantize operations. Computation is still done in
+ float, but quantizing the weights reduces model size (at the cost of
+ accuracy and latency). (default False)
+ dump_graphviz_dir: Full filepath of folder to dump the graphs, as GraphViz
+ .dot files, at various stages of processing. Preferred over
+ --output_format=GRAPHVIZ_DOT so that the requested output file format is
+ preserved. (default None)
+ dump_graphviz_video: Boolean indicating whether to dump the graph after
+ every graph transformation. (default False)
Returns:
The converted data. For example if TFLite was the destination, then
@@ -185,9 +197,13 @@ def toco_convert(input_data,
toco.drop_control_dependency = drop_control_dependency
toco.reorder_across_fake_quant = reorder_across_fake_quant
toco.allow_custom_ops = allow_custom_ops
+ toco.quantize_weights = quantize_weights
if default_ranges_stats:
toco.default_ranges_min = default_ranges_stats[0]
toco.default_ranges_max = default_ranges_stats[1]
+ if dump_graphviz_dir:
+ toco.dump_graphviz_dir = dump_graphviz_dir
+ toco.dump_graphviz_include_video = dump_graphviz_video
model = _model_flags_pb2.ModelFlags()
model.change_concat_input_ranges = change_concat_input_ranges
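With the docstring change above, the low-level toco_convert entry point documents quantized_input_stats as a plain list of (mean, std_dev) tuples, one per input tensor, while the higher-level converter keeps the name-keyed dict form. A hedged sketch of calling it; `sess`, `img`, and `out` are assumed to exist already and are not part of this change:

```python
from tensorflow.contrib.lite.python import convert
from tensorflow.contrib.lite.python import lite_constants

# `sess`, `img`, and `out` are illustrative; one (mean, std_dev) tuple is
# passed per input tensor when converting to a quantized model.
tflite_model = convert.toco_convert(
    sess.graph_def, [img], [out],
    inference_type=lite_constants.QUANTIZED_UINT8,
    quantized_input_stats=[(128., 127.)])
open("converted_model.tflite", "wb").write(tflite_model)
```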
diff --git a/tensorflow/contrib/lite/python/convert_saved_model.py b/tensorflow/contrib/lite/python/convert_saved_model.py
index 5dad49f1ed..1553464b9f 100644
--- a/tensorflow/contrib/lite/python/convert_saved_model.py
+++ b/tensorflow/contrib/lite/python/convert_saved_model.py
@@ -19,13 +19,12 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.lite.python.convert import tensor_name
-from tensorflow.contrib.saved_model.python.saved_model import reader
-from tensorflow.contrib.saved_model.python.saved_model import signature_def_utils
from tensorflow.core.framework import types_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import graph_util as tf_graph_util
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.saved_model import constants
from tensorflow.python.saved_model import loader
@@ -58,21 +57,8 @@ def _get_meta_graph_def(saved_model_dir, tag_set):
Raises:
ValueError: No valid MetaGraphDef for given tag_set.
"""
- saved_model = reader.read_saved_model(saved_model_dir)
- tag_sets = []
- result_meta_graph_def = None
- for meta_graph_def in saved_model.meta_graphs:
- meta_graph_tag_set = set(meta_graph_def.meta_info_def.tags)
- tag_sets.append(meta_graph_tag_set)
- if meta_graph_tag_set == tag_set:
- result_meta_graph_def = meta_graph_def
- logging.info("The given saved_model contains the following tags: %s",
- tag_sets)
- if result_meta_graph_def is not None:
- return result_meta_graph_def
- else:
- raise ValueError("No valid MetaGraphDef for this tag_set '{}'. Possible "
- "values are '{}'. ".format(tag_set, tag_sets))
+ with session.Session(graph=ops.Graph()) as sess:
+ return loader.load(sess, tag_set, saved_model_dir)
def _get_signature_def(meta_graph, signature_key):
@@ -97,9 +83,7 @@ def _get_signature_def(meta_graph, signature_key):
raise ValueError("No '{}' in the SavedModel\'s SignatureDefs. Possible "
"values are '{}'.".format(signature_key,
",".join(signature_def_keys)))
- signature_def = signature_def_utils.get_signature_def_by_key(
- meta_graph, signature_key)
- return signature_def
+ return signature_def_map[signature_key]
def _get_inputs_outputs(signature_def):
@@ -247,6 +231,7 @@ def freeze_saved_model(saved_model_dir, input_arrays, input_shapes,
ValueError:
SavedModel doesn't contain a MetaGraphDef identified by tag_set.
signature_key is not in the MetaGraphDef.
+ assets/ directory is in the MetaGraphDef.
input_shapes does not match the length of input_arrays.
input_arrays or output_arrays are not valid.
"""
@@ -255,9 +240,13 @@ def freeze_saved_model(saved_model_dir, input_arrays, input_shapes,
signature_def = _get_signature_def(meta_graph, signature_key)
inputs, outputs = _get_inputs_outputs(signature_def)
+ # Check SavedModel for assets directory.
+ collection_def = meta_graph.collection_def
+ if constants.ASSETS_KEY in collection_def:
+ raise ValueError("SavedModels with assets/ directory are not supported.")
+
graph = ops.Graph()
with session.Session(graph=graph) as sess:
- # TODO(nupurgarg): Throw ValueError if SavedModel has assets/ directory.
loader.load(sess, meta_graph.meta_info_def.tags, saved_model_dir)
# Gets input and output tensors.
diff --git a/tensorflow/contrib/lite/python/lite.py b/tensorflow/contrib/lite/python/lite.py
index eee8faa2dc..3f644c8acd 100644
--- a/tensorflow/contrib/lite/python/lite.py
+++ b/tensorflow/contrib/lite/python/lite.py
@@ -94,6 +94,16 @@ class TocoConverter(object):
created for any op that is unknown. The developer will need to provide
these to the TensorFlow Lite runtime with a custom resolver.
(default False)
+ quantize_weights: Boolean indicating whether to store weights as quantized
+ weights followed by dequantize operations. Computation is still done in
+ float, but quantizing the weights reduces model size (at the cost of
+ accuracy and latency). (default False)
+ dump_graphviz_dir: Full filepath of folder to dump the graphs, as GraphViz
+ .dot files, at various stages of processing. Preferred over
+ --output_format=GRAPHVIZ_DOT so that the requested output file format is
+ preserved. (default None)
+ dump_graphviz_video: Boolean indicating whether to dump the graph after
+ every graph transformation. (default False)
Example usage:
@@ -135,6 +145,9 @@ class TocoConverter(object):
self.reorder_across_fake_quant = False
self.change_concat_input_ranges = False
self.allow_custom_ops = False
+ self.quantize_weights = False
+ self.dump_graphviz_dir = None
+ self.dump_graphviz_video = False
@classmethod
def from_session(cls, sess, input_tensors, output_tensors):
@@ -210,7 +223,7 @@ class TocoConverter(object):
# Check if graph is frozen.
if not _is_frozen_graph(sess):
- raise ValueError("Please freeze the graph using freeze_graph.py")
+ raise ValueError("Please freeze the graph using freeze_graph.py.")
# Create TocoConverter class.
return cls(sess.graph_def, input_tensors, output_tensors)
@@ -310,7 +323,10 @@ class TocoConverter(object):
drop_control_dependency=self.drop_control_dependency,
reorder_across_fake_quant=self.reorder_across_fake_quant,
change_concat_input_ranges=self.change_concat_input_ranges,
- allow_custom_ops=self.allow_custom_ops)
+ allow_custom_ops=self.allow_custom_ops,
+ quantize_weights=self.quantize_weights,
+ dump_graphviz_dir=self.dump_graphviz_dir,
+ dump_graphviz_video=self.dump_graphviz_video)
return result
def get_input_arrays(self):
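Since the new TocoConverter attributes default to off, a converter built from a session only picks up weight quantization and GraphViz dumps when they are set explicitly before convert(). A short sketch under that assumption (the toy graph is illustrative, mirroring the pattern used in lite_test.py below):

```python
import tempfile

import tensorflow as tf
from tensorflow.contrib.lite.python import lite

# Toy graph; any convertible graph behaves the same way.
in_tensor = tf.placeholder(shape=[1, 16, 16, 3], dtype=tf.float32)
out_tensor = in_tensor + in_tensor
sess = tf.Session()

converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor])
converter.quantize_weights = True                  # store weights quantized, compute in float
converter.dump_graphviz_dir = tempfile.mkdtemp()   # write .dot files at each stage
converter.dump_graphviz_video = True               # also dump after every transformation
tflite_model = converter.convert()
open("model.tflite", "wb").write(tflite_model)
```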
diff --git a/tensorflow/contrib/lite/python/lite_test.py b/tensorflow/contrib/lite/python/lite_test.py
index 019a3a5f69..8c9d2c1651 100644
--- a/tensorflow/contrib/lite/python/lite_test.py
+++ b/tensorflow/contrib/lite/python/lite_test.py
@@ -25,9 +25,11 @@ from tensorflow.contrib.lite.python import lite
from tensorflow.contrib.lite.python import lite_constants
from tensorflow.contrib.lite.python.interpreter import Interpreter
from tensorflow.python.client import session
+from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
@@ -218,6 +220,7 @@ class FromSessionTest(test_util.TensorFlowTestCase):
self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
self.assertEqual((0., 0.), output_details[0]['quantization'])
+ # TODO(nupurgarg): Verify value of contents in GraphViz.
def testGraphviz(self):
in_tensor = array_ops.placeholder(
shape=[1, 16, 16, 3], dtype=dtypes.float32)
@@ -230,6 +233,39 @@ class FromSessionTest(test_util.TensorFlowTestCase):
graphviz_output = converter.convert()
self.assertTrue(graphviz_output)
+ # TODO(nupurgarg): Verify value of contents in GraphViz.
+ def testDumpGraphviz(self):
+ in_tensor = array_ops.placeholder(
+ shape=[1, 16, 16, 3], dtype=dtypes.float32)
+ out_tensor = in_tensor + in_tensor
+ sess = session.Session()
+
+ # Convert model and ensure model is not None.
+ converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor])
+ graphviz_dir = self.get_temp_dir()
+ converter.dump_graphviz_dir = graphviz_dir
+ tflite_model = converter.convert()
+ self.assertTrue(tflite_model)
+
+ # Ensure interpreter is able to allocate and check graphviz data.
+ interpreter = Interpreter(model_content=tflite_model)
+ interpreter.allocate_tensors()
+
+ num_items_graphviz = len(os.listdir(graphviz_dir))
+ self.assertTrue(num_items_graphviz)
+
+ # Convert model and ensure model is not None.
+ converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor])
+ graphviz_dir = self.get_temp_dir()
+ converter.dump_graphviz_dir = graphviz_dir
+ converter.dump_graphviz_video = True
+ tflite_model = converter.convert()
+ self.assertTrue(tflite_model)
+
+ # Ensure graphviz folder has more data after using video flag.
+ num_items_graphviz_video = len(os.listdir(graphviz_dir))
+ self.assertTrue(num_items_graphviz_video > num_items_graphviz)
+
def testInferenceInputType(self):
in_tensor = array_ops.placeholder(shape=[1, 16, 16, 3], dtype=dtypes.uint8)
out_tensor = in_tensor + in_tensor
@@ -291,6 +327,36 @@ class FromSessionTest(test_util.TensorFlowTestCase):
self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
self.assertTrue(output_details[0]['quantization'][0] > 0) # scale
+ def testQuantizeWeights(self):
+ np.random.seed(0)
+ # We need the tensor to have more than 1024 elements for quantize_weights
+ # to kick in. Thus, the [33, 33] shape.
+ in_tensor_1 = array_ops.placeholder(
+ shape=[33, 33], dtype=dtypes.float32, name='inputA')
+ in_tensor_2 = constant_op.constant(
+ np.random.uniform(low=-10., high=10., size=(33, 33)),
+ shape=[33, 33],
+ dtype=dtypes.float32,
+ name='inputB')
+ out_tensor = math_ops.matmul(in_tensor_1, in_tensor_2, name='output')
+ sess = session.Session()
+
+ # Convert float model.
+ float_converter = lite.TocoConverter.from_session(sess, [in_tensor_1],
+ [out_tensor])
+ float_tflite = float_converter.convert()
+ self.assertTrue(float_tflite)
+
+ # Convert quantized weights model.
+ quantized_weights_converter = lite.TocoConverter.from_session(
+ sess, [in_tensor_1], [out_tensor])
+ quantized_weights_converter.quantize_weights = True
+ quantized_weights_tflite = quantized_weights_converter.convert()
+ self.assertTrue(quantized_weights_tflite)
+
+ # Ensure that the quantized weights tflite model is smaller.
+ self.assertTrue(len(quantized_weights_tflite) < len(float_tflite))
+
class FromFrozenGraphFile(test_util.TensorFlowTestCase):
@@ -369,7 +435,7 @@ class FromFrozenGraphFile(test_util.TensorFlowTestCase):
with self.assertRaises(ValueError) as error:
lite.TocoConverter.from_frozen_graph(graph_def_file, ['Placeholder'],
['add'])
- self.assertEqual('Please freeze the graph using freeze_graph.py',
+ self.assertEqual('Please freeze the graph using freeze_graph.py.',
str(error.exception))
def testPbtxt(self):
diff --git a/tensorflow/contrib/lite/python/tflite_convert.py b/tensorflow/contrib/lite/python/tflite_convert.py
index 6d77626a4b..32ad84ec3c 100644
--- a/tensorflow/contrib/lite/python/tflite_convert.py
+++ b/tensorflow/contrib/lite/python/tflite_convert.py
@@ -114,7 +114,7 @@ def _convert_model(flags):
"--input_arrays must be present when specifying "
"--std_dev_values and --mean_values with multiple input "
"tensors in order to map between names and "
- "values".format(",".join(input_arrays)))
+ "values.".format(",".join(input_arrays)))
converter.quantized_input_stats = dict(zip(input_arrays, quant_stats))
if flags.default_ranges_min and flags.default_ranges_max:
converter.default_ranges_stats = (flags.default_ranges_min,
@@ -128,6 +128,12 @@ def _convert_model(flags):
converter.change_concat_input_ranges = flags.change_concat_input_ranges
if flags.allow_custom_ops:
converter.allow_custom_ops = flags.allow_custom_ops
+ if flags.quantize_weights:
+ converter.quantize_weights = flags.quantize_weights
+ if flags.dump_graphviz_dir:
+ converter.dump_graphviz_dir = flags.dump_graphviz_dir
+ if flags.dump_graphviz_video:
+ converter.dump_graphviz_video = flags.dump_graphviz_video
# Convert model.
output_data = converter.convert()
@@ -159,8 +165,12 @@ def _check_flags(flags, unparsed):
output = ""
for flag in unparsed:
output += _get_message_unparsed(flag, "--input_file", "--graph_def_file")
+ output += _get_message_unparsed(flag, "--savedmodel_directory",
+ "--saved_model_dir")
output += _get_message_unparsed(flag, "--std_value", "--std_dev_values")
output += _get_message_unparsed(flag, "--batch_size", "--input_shapes")
+ output += _get_message_unparsed(flag, "--dump_graphviz",
+ "--dump_graphviz_dir")
if output:
raise ValueError(output)
@@ -217,17 +227,17 @@ def run_main(_):
# Model format flags.
parser.add_argument(
"--output_format",
- type=str,
+ type=str.upper,
choices=["TFLITE", "GRAPHVIZ_DOT"],
help="Output file format.")
parser.add_argument(
"--inference_type",
- type=str,
+ type=str.upper,
choices=["FLOAT", "QUANTIZED_UINT8"],
help="Target data type of arrays in the output file.")
parser.add_argument(
"--inference_input_type",
- type=str,
+ type=str.upper,
choices=["FLOAT", "QUANTIZED_UINT8"],
help=("Target data type of input arrays. Allows for a different type for "
"input arrays in the case of quantization."))
@@ -282,6 +292,12 @@ def run_main(_):
help=("Default value for max bound of min/max range values used for all "
"arrays without a specified range, Intended for experimenting with "
"quantization via \"dummy quantization\". (default None)"))
+ parser.add_argument(
+ "--quantize_weights",
+ type=bool,
+ help=("Store float weights as quantized weights followed by dequantize "
+ "operations. Inference is still done in FLOAT, but reduces model "
+ "size (at the cost of accuracy and latency)."))
# Graph manipulation flags.
parser.add_argument(
@@ -314,6 +330,20 @@ def run_main(_):
"provide these to the TensorFlow Lite runtime with a custom "
"resolver. (default False)"))
+ # Logging flags.
+ parser.add_argument(
+ "--dump_graphviz_dir",
+ type=str,
+ help=("Full filepath of folder to dump the graphs at various stages of "
+ "processing GraphViz .dot files. Preferred over --output_format="
+ "GRAPHVIZ_DOT in order to keep the requirements of the output "
+ "file."))
+ parser.add_argument(
+ "--dump_graphviz_video",
+ action="store_true",
+ help=("Boolean indicating whether to dump the graph after every graph "
+ "transformation"))
+
tflite_flags, unparsed = parser.parse_known_args(args=sys.argv[1:])
try:
_check_flags(tflite_flags, unparsed)
diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs
index 7dbb36c864..ee5208df14 100644
--- a/tensorflow/contrib/lite/schema/schema.fbs
+++ b/tensorflow/contrib/lite/schema/schema.fbs
@@ -148,6 +148,9 @@ enum BuiltinOperator : byte {
SPARSE_TO_DENSE = 68,
TILE = 69,
EXPAND_DIMS = 70,
+ EQUAL = 71,
+ NOT_EQUAL = 72,
+ LOG = 73,
}
// Options for the builtin operators.
@@ -204,6 +207,8 @@ union BuiltinOptions {
SparseToDenseOptions,
TileOptions,
ExpandDimsOptions,
+ EqualOptions,
+ NotEqualOptions,
}
enum Padding : byte { SAME, VALID }
@@ -478,6 +483,12 @@ table SparseToDenseOptions {
validate_indices:bool;
}
+table EqualOptions {
+}
+
+table NotEqualOptions {
+}
+
// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
// builtin, or a string if the operator is custom.
table OperatorCode {
diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h
index b1beb39b28..887e47ed1e 100755
--- a/tensorflow/contrib/lite/schema/schema_generated.h
+++ b/tensorflow/contrib/lite/schema/schema_generated.h
@@ -187,6 +187,12 @@ struct ExpandDimsOptionsT;
struct SparseToDenseOptions;
struct SparseToDenseOptionsT;
+struct EqualOptions;
+struct EqualOptionsT;
+
+struct NotEqualOptions;
+struct NotEqualOptionsT;
+
struct OperatorCode;
struct OperatorCodeT;
@@ -317,11 +323,14 @@ enum BuiltinOperator {
BuiltinOperator_SPARSE_TO_DENSE = 68,
BuiltinOperator_TILE = 69,
BuiltinOperator_EXPAND_DIMS = 70,
+ BuiltinOperator_EQUAL = 71,
+ BuiltinOperator_NOT_EQUAL = 72,
+ BuiltinOperator_LOG = 73,
BuiltinOperator_MIN = BuiltinOperator_ADD,
- BuiltinOperator_MAX = BuiltinOperator_EXPAND_DIMS
+ BuiltinOperator_MAX = BuiltinOperator_LOG
};
-inline BuiltinOperator (&EnumValuesBuiltinOperator())[70] {
+inline BuiltinOperator (&EnumValuesBuiltinOperator())[73] {
static BuiltinOperator values[] = {
BuiltinOperator_ADD,
BuiltinOperator_AVERAGE_POOL_2D,
@@ -392,7 +401,10 @@ inline BuiltinOperator (&EnumValuesBuiltinOperator())[70] {
BuiltinOperator_TRANSPOSE_CONV,
BuiltinOperator_SPARSE_TO_DENSE,
BuiltinOperator_TILE,
- BuiltinOperator_EXPAND_DIMS
+ BuiltinOperator_EXPAND_DIMS,
+ BuiltinOperator_EQUAL,
+ BuiltinOperator_NOT_EQUAL,
+ BuiltinOperator_LOG
};
return values;
}
@@ -470,6 +482,9 @@ inline const char **EnumNamesBuiltinOperator() {
"SPARSE_TO_DENSE",
"TILE",
"EXPAND_DIMS",
+ "EQUAL",
+ "NOT_EQUAL",
+ "LOG",
nullptr
};
return names;
@@ -534,11 +549,13 @@ enum BuiltinOptions {
BuiltinOptions_SparseToDenseOptions = 50,
BuiltinOptions_TileOptions = 51,
BuiltinOptions_ExpandDimsOptions = 52,
+ BuiltinOptions_EqualOptions = 53,
+ BuiltinOptions_NotEqualOptions = 54,
BuiltinOptions_MIN = BuiltinOptions_NONE,
- BuiltinOptions_MAX = BuiltinOptions_ExpandDimsOptions
+ BuiltinOptions_MAX = BuiltinOptions_NotEqualOptions
};
-inline BuiltinOptions (&EnumValuesBuiltinOptions())[53] {
+inline BuiltinOptions (&EnumValuesBuiltinOptions())[55] {
static BuiltinOptions values[] = {
BuiltinOptions_NONE,
BuiltinOptions_Conv2DOptions,
@@ -592,7 +609,9 @@ inline BuiltinOptions (&EnumValuesBuiltinOptions())[53] {
BuiltinOptions_TransposeConvOptions,
BuiltinOptions_SparseToDenseOptions,
BuiltinOptions_TileOptions,
- BuiltinOptions_ExpandDimsOptions
+ BuiltinOptions_ExpandDimsOptions,
+ BuiltinOptions_EqualOptions,
+ BuiltinOptions_NotEqualOptions
};
return values;
}
@@ -652,6 +671,8 @@ inline const char **EnumNamesBuiltinOptions() {
"SparseToDenseOptions",
"TileOptions",
"ExpandDimsOptions",
+ "EqualOptions",
+ "NotEqualOptions",
nullptr
};
return names;
@@ -874,6 +895,14 @@ template<> struct BuiltinOptionsTraits<ExpandDimsOptions> {
static const BuiltinOptions enum_value = BuiltinOptions_ExpandDimsOptions;
};
+template<> struct BuiltinOptionsTraits<EqualOptions> {
+ static const BuiltinOptions enum_value = BuiltinOptions_EqualOptions;
+};
+
+template<> struct BuiltinOptionsTraits<NotEqualOptions> {
+ static const BuiltinOptions enum_value = BuiltinOptions_NotEqualOptions;
+};
+
struct BuiltinOptionsUnion {
BuiltinOptions type;
void *value;
@@ -1321,6 +1350,22 @@ struct BuiltinOptionsUnion {
return type == BuiltinOptions_ExpandDimsOptions ?
reinterpret_cast<const ExpandDimsOptionsT *>(value) : nullptr;
}
+ EqualOptionsT *AsEqualOptions() {
+ return type == BuiltinOptions_EqualOptions ?
+ reinterpret_cast<EqualOptionsT *>(value) : nullptr;
+ }
+ const EqualOptionsT *AsEqualOptions() const {
+ return type == BuiltinOptions_EqualOptions ?
+ reinterpret_cast<const EqualOptionsT *>(value) : nullptr;
+ }
+ NotEqualOptionsT *AsNotEqualOptions() {
+ return type == BuiltinOptions_NotEqualOptions ?
+ reinterpret_cast<NotEqualOptionsT *>(value) : nullptr;
+ }
+ const NotEqualOptionsT *AsNotEqualOptions() const {
+ return type == BuiltinOptions_NotEqualOptions ?
+ reinterpret_cast<const NotEqualOptionsT *>(value) : nullptr;
+ }
};
bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type);
@@ -4781,6 +4826,86 @@ inline flatbuffers::Offset<SparseToDenseOptions> CreateSparseToDenseOptions(
flatbuffers::Offset<SparseToDenseOptions> CreateSparseToDenseOptions(flatbuffers::FlatBufferBuilder &_fbb, const SparseToDenseOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+struct EqualOptionsT : public flatbuffers::NativeTable {
+ typedef EqualOptions TableType;
+ EqualOptionsT() {
+ }
+};
+
+struct EqualOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef EqualOptionsT NativeTableType;
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ verifier.EndTable();
+ }
+ EqualOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(EqualOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<EqualOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const EqualOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct EqualOptionsBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ explicit EqualOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ EqualOptionsBuilder &operator=(const EqualOptionsBuilder &);
+ flatbuffers::Offset<EqualOptions> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<EqualOptions>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<EqualOptions> CreateEqualOptions(
+ flatbuffers::FlatBufferBuilder &_fbb) {
+ EqualOptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+flatbuffers::Offset<EqualOptions> CreateEqualOptions(flatbuffers::FlatBufferBuilder &_fbb, const EqualOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
+struct NotEqualOptionsT : public flatbuffers::NativeTable {
+ typedef NotEqualOptions TableType;
+ NotEqualOptionsT() {
+ }
+};
+
+struct NotEqualOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef NotEqualOptionsT NativeTableType;
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ verifier.EndTable();
+ }
+ NotEqualOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(NotEqualOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<NotEqualOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const NotEqualOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct NotEqualOptionsBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ explicit NotEqualOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ NotEqualOptionsBuilder &operator=(const NotEqualOptionsBuilder &);
+ flatbuffers::Offset<NotEqualOptions> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<NotEqualOptions>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<NotEqualOptions> CreateNotEqualOptions(
+ flatbuffers::FlatBufferBuilder &_fbb) {
+ NotEqualOptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+flatbuffers::Offset<NotEqualOptions> CreateNotEqualOptions(flatbuffers::FlatBufferBuilder &_fbb, const NotEqualOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
struct OperatorCodeT : public flatbuffers::NativeTable {
typedef OperatorCode TableType;
BuiltinOperator builtin_code;
@@ -5068,6 +5193,12 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
const ExpandDimsOptions *builtin_options_as_ExpandDimsOptions() const {
return builtin_options_type() == BuiltinOptions_ExpandDimsOptions ? static_cast<const ExpandDimsOptions *>(builtin_options()) : nullptr;
}
+ const EqualOptions *builtin_options_as_EqualOptions() const {
+ return builtin_options_type() == BuiltinOptions_EqualOptions ? static_cast<const EqualOptions *>(builtin_options()) : nullptr;
+ }
+ const NotEqualOptions *builtin_options_as_NotEqualOptions() const {
+ return builtin_options_type() == BuiltinOptions_NotEqualOptions ? static_cast<const NotEqualOptions *>(builtin_options()) : nullptr;
+ }
const flatbuffers::Vector<uint8_t> *custom_options() const {
return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS);
}
@@ -5302,6 +5433,14 @@ template<> inline const ExpandDimsOptions *Operator::builtin_options_as<ExpandDi
return builtin_options_as_ExpandDimsOptions();
}
+template<> inline const EqualOptions *Operator::builtin_options_as<EqualOptions>() const {
+ return builtin_options_as_EqualOptions();
+}
+
+template<> inline const NotEqualOptions *Operator::builtin_options_as<NotEqualOptions>() const {
+ return builtin_options_as_NotEqualOptions();
+}
+
struct OperatorBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
@@ -7196,6 +7335,52 @@ inline flatbuffers::Offset<SparseToDenseOptions> CreateSparseToDenseOptions(flat
_validate_indices);
}
+inline EqualOptionsT *EqualOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = new EqualOptionsT();
+ UnPackTo(_o, _resolver);
+ return _o;
+}
+
+inline void EqualOptions::UnPackTo(EqualOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+}
+
+inline flatbuffers::Offset<EqualOptions> EqualOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const EqualOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreateEqualOptions(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<EqualOptions> CreateEqualOptions(flatbuffers::FlatBufferBuilder &_fbb, const EqualOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const EqualOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ return tflite::CreateEqualOptions(
+ _fbb);
+}
+
+inline NotEqualOptionsT *NotEqualOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = new NotEqualOptionsT();
+ UnPackTo(_o, _resolver);
+ return _o;
+}
+
+inline void NotEqualOptions::UnPackTo(NotEqualOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+}
+
+inline flatbuffers::Offset<NotEqualOptions> NotEqualOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const NotEqualOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreateNotEqualOptions(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<NotEqualOptions> CreateNotEqualOptions(flatbuffers::FlatBufferBuilder &_fbb, const NotEqualOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const NotEqualOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ return tflite::CreateNotEqualOptions(
+ _fbb);
+}
+
inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new OperatorCodeT();
UnPackTo(_o, _resolver);
@@ -7590,6 +7775,14 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob
auto ptr = reinterpret_cast<const ExpandDimsOptions *>(obj);
return verifier.VerifyTable(ptr);
}
+ case BuiltinOptions_EqualOptions: {
+ auto ptr = reinterpret_cast<const EqualOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_NotEqualOptions: {
+ auto ptr = reinterpret_cast<const NotEqualOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
default: return false;
}
}
@@ -7816,6 +8009,14 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c
auto ptr = reinterpret_cast<const ExpandDimsOptions *>(obj);
return ptr->UnPack(resolver);
}
+ case BuiltinOptions_EqualOptions: {
+ auto ptr = reinterpret_cast<const EqualOptions *>(obj);
+ return ptr->UnPack(resolver);
+ }
+ case BuiltinOptions_NotEqualOptions: {
+ auto ptr = reinterpret_cast<const NotEqualOptions *>(obj);
+ return ptr->UnPack(resolver);
+ }
default: return nullptr;
}
}
@@ -8030,6 +8231,14 @@ inline flatbuffers::Offset<void> BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff
auto ptr = reinterpret_cast<const ExpandDimsOptionsT *>(value);
return CreateExpandDimsOptions(_fbb, ptr, _rehasher).Union();
}
+ case BuiltinOptions_EqualOptions: {
+ auto ptr = reinterpret_cast<const EqualOptionsT *>(value);
+ return CreateEqualOptions(_fbb, ptr, _rehasher).Union();
+ }
+ case BuiltinOptions_NotEqualOptions: {
+ auto ptr = reinterpret_cast<const NotEqualOptionsT *>(value);
+ return CreateNotEqualOptions(_fbb, ptr, _rehasher).Union();
+ }
default: return 0;
}
}
@@ -8244,6 +8453,14 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL
value = new ExpandDimsOptionsT(*reinterpret_cast<ExpandDimsOptionsT *>(u.value));
break;
}
+ case BuiltinOptions_EqualOptions: {
+ value = new EqualOptionsT(*reinterpret_cast<EqualOptionsT *>(u.value));
+ break;
+ }
+ case BuiltinOptions_NotEqualOptions: {
+ value = new NotEqualOptionsT(*reinterpret_cast<NotEqualOptionsT *>(u.value));
+ break;
+ }
default:
break;
}
@@ -8511,6 +8728,16 @@ inline void BuiltinOptionsUnion::Reset() {
delete ptr;
break;
}
+ case BuiltinOptions_EqualOptions: {
+ auto ptr = reinterpret_cast<EqualOptionsT *>(value);
+ delete ptr;
+ break;
+ }
+ case BuiltinOptions_NotEqualOptions: {
+ auto ptr = reinterpret_cast<NotEqualOptionsT *>(value);
+ delete ptr;
+ break;
+ }
default: break;
}
value = nullptr;
diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py
index 351187f520..f5e25784fa 100644
--- a/tensorflow/contrib/lite/testing/generate_examples.py
+++ b/tensorflow/contrib/lite/testing/generate_examples.py
@@ -2165,6 +2165,74 @@ def make_arg_max_tests(zip_path):
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+def make_equal_tests(zip_path):
+ """Make a set of tests to do equal."""
+
+ test_parameters = [{
+ "input_dtype": [tf.float32, tf.int32, tf.int64],
+ "input_shape_pair": [([1, 1, 1, 3], [1, 1, 1, 3]),
+ ([2, 3, 4, 5], [2, 3, 4, 5]), ([2, 3, 3], [2, 3]),
+ ([5, 5], [1]), ([10], [2, 4, 10])],
+ }]
+
+ def build_graph(parameters):
+ """Build the equal op testing graph."""
+ input_value1 = tf.placeholder(
+ dtype=parameters["input_dtype"],
+ name="input1",
+ shape=parameters["input_shape_pair"][0])
+ input_value2 = tf.placeholder(
+ dtype=parameters["input_dtype"],
+ name="input2",
+ shape=parameters["input_shape_pair"][1])
+ out = tf.equal(input_value1, input_value2)
+ return [input_value1, input_value2], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input_value1 = create_tensor_data(parameters["input_dtype"],
+ parameters["input_shape_pair"][0])
+ input_value2 = create_tensor_data(parameters["input_dtype"],
+ parameters["input_shape_pair"][1])
+ return [input_value1, input_value2], sess.run(
+ outputs, feed_dict=dict(zip(inputs, [input_value1, input_value2])))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
+def make_not_equal_tests(zip_path):
+ """Make a set of tests to do not equal."""
+
+ test_parameters = [{
+ "input_dtype": [tf.float32, tf.int32, tf.int64],
+ "input_shape_pair": [([1, 1, 1, 3], [1, 1, 1, 3]),
+ ([2, 3, 4, 5], [2, 3, 4, 5]), ([2, 3, 3], [2, 3]),
+ ([5, 5], [1]), ([10], [2, 4, 10])],
+ }]
+
+ def build_graph(parameters):
+ """Build the not euqal op testing graph."""
+ input_value1 = tf.placeholder(
+ dtype=parameters["input_dtype"],
+ name="input1",
+ shape=parameters["input_shape_pair"][0])
+ input_value2 = tf.placeholder(
+ dtype=parameters["input_dtype"],
+ name="input2",
+ shape=parameters["input_shape_pair"][1])
+ out = tf.not_equal(input_value1, input_value2)
+ return [input_value1, input_value2], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input_value1 = create_tensor_data(parameters["input_dtype"],
+ parameters["input_shape_pair"][0])
+ input_value2 = create_tensor_data(parameters["input_dtype"],
+ parameters["input_shape_pair"][1])
+ return [input_value1, input_value2], sess.run(
+ outputs, feed_dict=dict(zip(inputs, [input_value1, input_value2])))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
def make_greater_tests(zip_path):
"""Make a set of tests to do greater."""
@@ -2352,30 +2420,44 @@ def make_neg_tests(zip_path):
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
-def make_sin_tests(zip_path):
- """Make a set of tests to do sin."""
+def _make_elementwise_tests(op):
+ """Make a set of tests to do element-wise operations."""
- test_parameters = [{
- "input_dtype": [tf.float32],
- "input_shape": [[1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
- }]
+ def f(zip_path):
+ """Actual function that generates examples."""
+ test_parameters = [{
+ "input_dtype": [tf.float32],
+ "input_shape": [[1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
+ }]
- def build_graph(parameters):
- """Build the sin op testing graph."""
- input_value = tf.placeholder(
- dtype=parameters["input_dtype"],
- name="input1",
- shape=parameters["input_shape"])
- out = tf.sin(input_value)
- return [input_value], [out]
+ def build_graph(parameters):
+ """Build the sin op testing graph."""
+ input_value = tf.placeholder(
+ dtype=parameters["input_dtype"],
+ name="input1",
+ shape=parameters["input_shape"])
+ out = op(input_value)
+ return [input_value], [out]
- def build_inputs(parameters, sess, inputs, outputs):
- input_value = create_tensor_data(parameters["input_dtype"],
- parameters["input_shape"])
- return [input_value], sess.run(
- outputs, feed_dict={inputs[0]: input_value})
+ def build_inputs(parameters, sess, inputs, outputs):
+ input_value = create_tensor_data(parameters["input_dtype"],
+ parameters["input_shape"])
+ return [input_value], sess.run(
+ outputs, feed_dict={inputs[0]: input_value})
- make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+ return f
+
+
+def make_sin_tests(zip_path):
+ """Make a set of tests to do sin."""
+ return _make_elementwise_tests(tf.sin)(zip_path)
+
+
+def make_log_tests(zip_path):
+ """Make a set of tests to do log."""
+ return _make_elementwise_tests(tf.log)(zip_path)
def make_where_tests(zip_path):
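Because _make_elementwise_tests returns the generator function, wiring up another single-input float op is a two-line wrapper. A hypothetical example (tf.sqrt is not registered by this change; it only illustrates the pattern):

```python
def make_sqrt_tests(zip_path):
  """Make a set of tests to do sqrt (hypothetical, not part of this change)."""
  return _make_elementwise_tests(tf.sqrt)(zip_path)
```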
diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.cc b/tensorflow/contrib/lite/toco/export_tensorflow.cc
index 99f0c81a1b..c7c80ab21c 100644
--- a/tensorflow/contrib/lite/toco/export_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/export_tensorflow.cc
@@ -494,7 +494,7 @@ void ConvertTransposeConvOperator(const Model& model,
const auto& weights_array = model.GetArray(weights_array_name);
CHECK(weights_array.buffer->type == ArrayDataType::kFloat);
ConvertFloatTensorConst(model, weights_array_name, AxesOrder::kOHWI,
- AxesOrder::kHWIO, tensorflow_graph);
+ AxesOrder::kHWOI, tensorflow_graph);
auto& strides = (*conv2d_op->mutable_attr())["strides"];
strides.mutable_list()->add_i(1);
strides.mutable_list()->add_i(src_op.stride_height);
@@ -1938,6 +1938,10 @@ void ConvertOperator(const Model& model, const Operator& src_op,
ConvertRandomUniformOperator(
model, static_cast<const RandomUniformOperator&>(src_op),
tensorflow_graph);
+ } else if (src_op.type == OperatorType::kTensorFlowEqual) {
+ ConvertComparisonOperator(model, src_op, "Equal", tensorflow_graph);
+ } else if (src_op.type == OperatorType::kTensorFlowNotEqual) {
+ ConvertComparisonOperator(model, src_op, "NotEqual", tensorflow_graph);
} else if (src_op.type == OperatorType::kTensorFlowGreater) {
ConvertComparisonOperator(model, src_op, "Greater", tensorflow_graph);
} else if (src_op.type == OperatorType::kTensorFlowGreaterEqual) {
diff --git a/tensorflow/contrib/lite/toco/g3doc/python_api.md b/tensorflow/contrib/lite/toco/g3doc/python_api.md
index 5071361bfd..a7841a6855 100644
--- a/tensorflow/contrib/lite/toco/g3doc/python_api.md
+++ b/tensorflow/contrib/lite/toco/g3doc/python_api.md
@@ -138,7 +138,8 @@ out = tf.fake_quant_with_min_max_args(val, min=0., max=1., name="output")
with tf.Session() as sess:
converter = tf.contrib.lite.TocoConverter.from_session(sess, [img], [out])
converter.inference_type = tf.contrib.lite.constants.QUANTIZED_UINT8
- converter.quantized_input_stats = {"img" : (0., 1.)} # mean, std_dev
+ input_arrays = converter.get_input_arrays()
+ converter.quantized_input_stats = {input_arrays[0] : (0., 1.)} # mean, std_dev
tflite_model = converter.convert()
open("converted_model.tflite", "wb").write(tflite_model)
```
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc b/tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc
index 076415ece8..8ca2cd66ac 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc
@@ -46,8 +46,9 @@ bool CreateIm2colArrays::Run(Model* model, std::size_t op_index) {
const int kheight = weights_shape.dims(1);
const int kwidth = weights_shape.dims(2);
if (kwidth == 1 && kheight == 1 && conv_op->stride_width == 1 &&
- conv_op->stride_height == 1) {
- // 1x1 unstrided conv does not need an im2col array.
+ conv_op->stride_height == 1 && conv_op->dilation_width_factor == 1 &&
+ conv_op->dilation_height_factor == 1) {
+ // 1x1 unstrided undilated conv does not need an im2col array.
return false;
}
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
index 64096fb069..92d283ca2c 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
@@ -60,6 +60,8 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
case OperatorType::kTensorFlowLessEqual:
case OperatorType::kTensorFlowGreater:
case OperatorType::kTensorFlowGreaterEqual:
+ case OperatorType::kTensorFlowEqual:
+ case OperatorType::kTensorFlowNotEqual:
// These operators unconditionally produce bool outputs
SetDataTypeForAllOutputs(model, op, ArrayDataType::kBool);
break;
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
index adb241da32..170a499d4e 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
@@ -278,7 +278,7 @@ void ProcessTransposeConvOperator(Model* model, TransposeConvOperator* op) {
<< "TransposeConv input shape must have 4 dimensions. Input \""
<< op->inputs[TransposeConvOperator::WEIGHTS] << "\" had shape "
<< toco::ShapeToString(weights_shape) << ".";
- CHECK_EQ(input_shape.dims(3), weights_shape.dims(0))
+ CHECK_EQ(input_shape.dims(3), weights_shape.dims(3))
<< "Input shape depth and weight depth do not agree";
// Set the output shape according to the specified output shape.
@@ -1563,6 +1563,8 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
case OperatorType::kTensorFlowMaximum:
case OperatorType::kTensorFlowMinimum:
case OperatorType::kTensorFlowGreaterEqual:
+ case OperatorType::kTensorFlowEqual:
+ case OperatorType::kTensorFlowNotEqual:
ProcessSimpleBinaryOperator(model, op);
break;
case OperatorType::kAddN:
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc
index 142841fcc4..ab24c4f996 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc
@@ -60,7 +60,7 @@ bool SupportsQuantization(const Operator& op) {
type == OperatorType::kTensorFlowGreaterEqual ||
type == OperatorType::kTensorFlowLess ||
type == OperatorType::kTensorFlowLessEqual ||
- type == OperatorType::kSelect;
+ type == OperatorType::kSelect || type == OperatorType::kArgMax;
}
const MinMax& GetOrComputeMinMax(Model* model, const string& array_name) {
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc
index 323d25a5bd..c1c2997c6b 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc
@@ -48,6 +48,12 @@ limitations under the License.
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/public/version.h"
+#define TOCO_RETURN_IF_ERROR(...) \
+ do { \
+ const ::toco::port::Status _status = (__VA_ARGS__); \
+ if (!_status.ok()) return _status; \
+ } while (0)
+
using tensorflow::AttrValue;
using tensorflow::DT_BOOL;
using tensorflow::DT_FLOAT;
@@ -130,6 +136,37 @@ const AttrValue::ListValue& GetListAttr(const NodeDef& node,
return attr.list();
}
+Status CheckOptionalAttr(const NodeDef& node, const string& attr_name,
+ const string& expected_value) {
+ if (HasAttr(node, attr_name)) {
+ const string& value = GetStringAttr(node, attr_name);
+ if (value != expected_value) {
+ return Status(false, "Unexpected value for attribute '" + attr_name +
+ "'. Expected '" + expected_value + "'");
+ }
+ }
+ return Status::OK();
+}
+Status CheckOptionalAttr(const NodeDef& node, const string& attr_name,
+ const tensorflow::DataType& expected_value) {
+ if (HasAttr(node, attr_name)) {
+ const tensorflow::DataType& value = GetDataTypeAttr(node, attr_name);
+ if (value != expected_value) {
+ return Status(false, "Unexpected value for attribute '" + attr_name +
+ "'. Expected '" +
+ tensorflow::DataType_Name(expected_value) + "'");
+ }
+ }
+ return Status::OK();
+}
+
+template <typename T1, typename T2>
+Status ExpectValue(const T1& v1, const T2& v2, const string& description) {
+ if (v1 == v2) return Status::OK();
+ return Status(false, absl::StrCat("Unexpected ", description, ": got ", v1,
+ ", expected ", v2));
+}
+
ArrayDataType ConvertDataType(tensorflow::DataType dtype) {
if (dtype == DT_UINT8)
return ArrayDataType::kUint8;
@@ -466,18 +503,16 @@ Status ConvertConstOperator(const NodeDef& node,
return status;
}
-void ConvertConvOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
+Status ConvertConvOperator(const NodeDef& node,
+ const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
CHECK_EQ(node.op(), "Conv2D");
CheckInputsCount(node, tf_import_flags, 2);
// We only support NHWC, which is the default data_format.
// So if data_format is not defined, we're all good.
- if (HasAttr(node, "data_format")) {
- CHECK_EQ(GetStringAttr(node, "data_format"), "NHWC");
- }
- CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT);
+ TOCO_RETURN_IF_ERROR(CheckOptionalAttr(node, "data_format", "NHWC"));
+ TOCO_RETURN_IF_ERROR(CheckOptionalAttr(node, "T", DT_FLOAT));
const auto& input_name = node.input(0);
const auto& weights_name = node.input(1);
@@ -502,27 +537,27 @@ void ConvertConvOperator(const NodeDef& node,
auto* conv = new ConvOperator;
conv->inputs = {input_name, reordered_weights_name};
conv->outputs = {node.name()};
+ TOCO_RETURN_IF_ERROR(
+ Status(HasAttr(node, "strides"), "Missing attribute 'strides'"));
const auto& strides = GetListAttr(node, "strides");
- CHECK_EQ(strides.i_size(), 4);
- CHECK_EQ(strides.i(0), 1);
- CHECK_EQ(strides.i(3), 1);
+ TOCO_RETURN_IF_ERROR(ExpectValue(strides.i_size(), 4, "number of strides"));
+ TOCO_RETURN_IF_ERROR(ExpectValue(strides.i(0), 1, "strides(0)"));
+ TOCO_RETURN_IF_ERROR(ExpectValue(strides.i(3), 1, "strides(3)"));
conv->stride_height = strides.i(1);
conv->stride_width = strides.i(2);
if (HasAttr(node, "dilations")) {
const auto& dilations = GetListAttr(node, "dilations");
- CHECK_EQ(dilations.i_size(), 4);
- CHECK_EQ(dilations.i(0), 1)
- << "Can only import Conv ops with dilation along the height (1st) or "
- "width (2nd) axis. TensorFlow op \""
- << node.name() << "\" had dilations:[ " << dilations.i(0) << ", "
- << dilations.i(1) << ", " << dilations.i(2) << ", " << dilations.i(3)
- << "].";
- CHECK_EQ(dilations.i(3), 1)
- << "Can only import Conv ops with dilation along the height (1st) or "
- "width (2nd) axis. TensorFlow op \""
- << node.name() << "\" had dilations:[ " << dilations.i(0) << ", "
- << dilations.i(1) << ", " << dilations.i(2) << ", " << dilations.i(3)
- << "].";
+ TOCO_RETURN_IF_ERROR(
+ ExpectValue(dilations.i_size(), 4, "number of dilations"));
+ if (dilations.i(0) != 1 || dilations.i(3) != 1) {
+ return Status(
+ false, absl::StrCat(
+ "Can only import Conv ops with dilation along the height "
+ "(1st) or width (2nd) axis. TensorFlow op \"",
+ node.name(), "\" had dilations:[ ", dilations.i(0), ", ",
+ dilations.i(1), ", ", dilations.i(2), ", ", dilations.i(3),
+ "]."));
+ }
conv->dilation_height_factor = dilations.i(1);
conv->dilation_width_factor = dilations.i(2);
} else {
@@ -535,9 +570,11 @@ void ConvertConvOperator(const NodeDef& node,
} else if (padding == "VALID") {
conv->padding.type = PaddingType::kValid;
} else {
- LOG(FATAL) << "Bad padding (only SAME and VALID are supported)";
+ return Status(false, "Bad padding (only SAME and VALID are supported)");
}
model->operators.emplace_back(conv);
+
+ return Status::OK();
}
void ConvertDepthwiseConvOperator(const NodeDef& node,
@@ -1408,11 +1445,13 @@ void ConvertTransposeConvOperator(const NodeDef& node,
if (existing_transpose) {
CHECK(existing_transpose->type == OperatorType::kTranspose);
} else {
- // Transpose weights from HWIO order to OHWI order, which is more efficient
- // for computation
+ // Transpose weights from HWOI order to OHWI order, which is more efficient
+ // for computation. (Note that TensorFlow considers the order as HWIO
+ // because they consider this a backward conv, inverting the sense of
+ // input/output.)
TransposeOperator* transpose = new TransposeOperator;
string perm_array = CreateConstArray<ArrayDataType::kInt32>(
- model, node.name() + "_transpose_perm", {3, 0, 1, 2});
+ model, node.name() + "_transpose_perm", {2, 0, 1, 3});
transpose->inputs = {weights_name, perm_array};
transpose->outputs = {transposed_weights_name};
model->operators.emplace_back(transpose);
@@ -1722,7 +1761,7 @@ Status ImportTensorFlowNode(const tensorflow::NodeDef& node,
if (node.op() == "Const") {
return ConvertConstOperator(node, tf_import_flags, model);
} else if (node.op() == "Conv2D") {
- ConvertConvOperator(node, tf_import_flags, model);
+ return ConvertConvOperator(node, tf_import_flags, model);
} else if (node.op() == "Conv2DBackpropInput") {
ConvertTransposeConvOperator(node, tf_import_flags, model);
} else if (node.op() == "DepthwiseConv2dNative") {
@@ -1904,10 +1943,18 @@ Status ImportTensorFlowNode(const tensorflow::NodeDef& node,
ConvertRandomUniform(node, tf_import_flags, model);
} else if (node.op() == "Sin") {
ConvertSimpleOperator<SinOperator, 1>(node, tf_import_flags, model);
+ } else if (node.op() == "Log") {
+ ConvertSimpleOperator<LogOperator, 1>(node, tf_import_flags, model);
} else if (node.op() == "Select") {
ConvertSimpleOperator<SelectOperator, 3>(node, tf_import_flags, model);
} else if (node.op() == "SparseToDense") {
ConvertSparseToDenseOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Equal") {
+ ConvertSimpleOperator<TensorFlowEqualOperator, 2>(node, tf_import_flags,
+ model);
+ } else if (node.op() == "NotEqual") {
+ ConvertSimpleOperator<TensorFlowNotEqualOperator, 2>(node, tf_import_flags,
+ model);
} else {
ConvertUnsupportedOperator(node, tf_import_flags, model);
}
diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h
index 1a4f87e363..2f43adb07b 100644
--- a/tensorflow/contrib/lite/toco/model.h
+++ b/tensorflow/contrib/lite/toco/model.h
@@ -136,6 +136,8 @@ enum class OperatorType {
kReorderAxes,
kSelect,
kSparseToDense,
+ kTensorFlowEqual,
+ kTensorFlowNotEqual,
};
// Helper to deal with TensorFlow arrays using a different ordering of
@@ -153,6 +155,7 @@ enum class AxesOrder {
k1HWO, // Our standard for DepthwiseConv weights
kHWIM, // TensorFlow DepthwiseConv weights
kNHWC, // TensorFlow activations
+ kHWOI, // TensorFlow back-prop conv weights
};
// The type of the scalars in an array.
@@ -1358,6 +1361,22 @@ struct TensorFlowGreaterEqualOperator : Operator {
: Operator(OperatorType::kTensorFlowGreaterEqual) {}
};
+// TensorFlow Equal equivalent. Refer to TensorFlow documentation for
+// details.
+// Not fully supported, just a placeholder to handle TensorFlow graphs and
+// support graph transformations to other operator types by matching sub-graphs.
+// Typically, this is only used as an input to an Assert node, so it can be
+// removed as an unused node when we drop Assert nodes.
+struct TensorFlowEqualOperator : Operator {
+ TensorFlowEqualOperator() : Operator(OperatorType::kTensorFlowEqual) {}
+};
+
+// TensorFlow Not Equal equivalent. Refer to TensorFlow documentation for
+// details.
+struct TensorFlowNotEqualOperator : Operator {
+ TensorFlowNotEqualOperator() : Operator(OperatorType::kTensorFlowNotEqual) {}
+};
+
// Global max reduction: computes the max of all of entries in the input array.
// Thus the output is "0-dimensional": it consists of a single scalar value.
//
@@ -1625,8 +1644,8 @@ struct SparseToDenseOperator : Operator {
// be used for the transient array at hand. The 'start' and 'end' values are
// offsets from the start of the workspace buffer, expressed in bytes.
struct Alloc {
- int start = 0;
- int end = 0;
+ int64 start = 0;
+ int64 end = 0;
};
inline bool operator<(const Alloc& a, const Alloc& b) {
diff --git a/tensorflow/contrib/lite/toco/model_cmdline_flags.cc b/tensorflow/contrib/lite/toco/model_cmdline_flags.cc
index 0f104d5e2d..4c9f1aa4b0 100644
--- a/tensorflow/contrib/lite/toco/model_cmdline_flags.cc
+++ b/tensorflow/contrib/lite/toco/model_cmdline_flags.cc
@@ -48,7 +48,7 @@ bool ParseModelFlagsFromCommandLineFlags(
"that information from the input file."),
Flag("input_arrays", parsed_flags.input_arrays.bind(),
parsed_flags.input_arrays.default_value(),
- "Names of the output arrays, comma-separated. If not specified, "
+ "Names of the input arrays, comma-separated. If not specified, "
"will try to read that information from the input file."),
Flag("output_array", parsed_flags.output_array.bind(),
parsed_flags.output_array.default_value(),
diff --git a/tensorflow/contrib/lite/toco/python/BUILD b/tensorflow/contrib/lite/toco/python/BUILD
index a954f1d6ba..93fe756a55 100644
--- a/tensorflow/contrib/lite/toco/python/BUILD
+++ b/tensorflow/contrib/lite/toco/python/BUILD
@@ -12,6 +12,7 @@ cc_library(
deps = [
"//tensorflow/contrib/lite/toco:model_flags_proto_cc",
"//tensorflow/contrib/lite/toco:toco_flags_proto_cc",
+ "//tensorflow/contrib/lite/toco:toco_graphviz_dump_options",
"//tensorflow/contrib/lite/toco:toco_port",
"//tensorflow/contrib/lite/toco:toco_tooling",
"//tensorflow/core:lib",
diff --git a/tensorflow/contrib/lite/toco/python/toco_python_api.cc b/tensorflow/contrib/lite/toco/python/toco_python_api.cc
index 5b1db852b4..d93e104038 100644
--- a/tensorflow/contrib/lite/toco/python/toco_python_api.cc
+++ b/tensorflow/contrib/lite/toco/python/toco_python_api.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/toco/model_flags.pb.h"
#include "tensorflow/contrib/lite/toco/python/toco_python_api.h"
#include "tensorflow/contrib/lite/toco/toco_flags.pb.h"
+#include "tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h"
#include "tensorflow/contrib/lite/toco/toco_port.h"
#include "tensorflow/contrib/lite/toco/toco_tooling.h"
#include "tensorflow/contrib/lite/toco/toco_types.h"
@@ -62,7 +63,7 @@ PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw,
std::string input_contents_txt = ConvertArg(input_contents_txt_raw, &error);
if (error) return nullptr;
- // Use toco to produce new outputs
+ // Use TOCO to produce new outputs.
toco::ModelFlags model_flags;
if (!model_flags.ParseFromString(model_flags_proto_txt)) {
LOG(FATAL) << "Model proto failed to parse." << std::endl;
@@ -71,6 +72,16 @@ PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw,
if (!toco_flags.ParseFromString(toco_flags_proto_txt)) {
LOG(FATAL) << "Toco proto failed to parse." << std::endl;
}
+
+ auto& dump_options = *GraphVizDumpOptions::singleton();
+ if (toco_flags.has_dump_graphviz_dir()) {
+ dump_options.dump_graphviz = toco_flags.dump_graphviz_dir();
+ }
+ if (toco_flags.has_dump_graphviz_include_video()) {
+ dump_options.dump_graphviz_video = toco_flags.dump_graphviz_include_video();
+ }
+
+ // Convert model.
std::unique_ptr<toco::Model> model =
toco::Import(toco_flags, model_flags, input_contents_txt);
toco::Transform(toco_flags, model.get());
diff --git a/tensorflow/contrib/lite/toco/tflite/export.cc b/tensorflow/contrib/lite/toco/tflite/export.cc
index 5daa703c80..a2d753657b 100644
--- a/tensorflow/contrib/lite/toco/tflite/export.cc
+++ b/tensorflow/contrib/lite/toco/tflite/export.cc
@@ -316,6 +316,7 @@ void Export(
auto op_codes = ExportOperatorCodes(model, ops_by_type, operators_map,
&builder, &error_summary);
const string fake_quant_operation_name = "FAKE_QUANT";
+
if (error_summary.count(fake_quant_operation_name) != 0) {
LOG(ERROR)
<< fake_quant_operation_name
@@ -327,6 +328,21 @@ void Export(
error_summary.erase(fake_quant_operation_name);
}
if (!allow_custom_ops && !error_summary.empty()) {
+    // Remove ExpandDims and ReorderAxes from the unimplemented list unless
+    // they make up the entire list. Both ops are removed during graph
+    // transformations. However, if an op earlier in the model is
+    // unimplemented, the graph transformation is unable to run because the
+    // output shape is not defined, which causes unnecessary confusion during
+    // model conversion.
+ std::set<string> error_summary_final;
+ for (const auto& op_type : error_summary) {
+ if (op_type != "ReorderAxes" && op_type != "ExpandDims") {
+ error_summary_final.insert(op_type);
+ }
+ }
+ if (error_summary_final.empty()) {
+ error_summary_final = error_summary;
+ }
+
LOG(QFATAL)
<< "Some of the operators in the model are not supported by "
"the standard TensorFlow Lite runtime. If you have a custom "
@@ -334,7 +350,7 @@ void Export(
"--allow_custom_ops, or by setting allow_custom_ops=True "
"when calling tf.contrib.lite.toco_convert(). Here is a list "
"of operators for which you will need custom implementations: "
- << absl::StrJoin(error_summary, ", ") << ".";
+ << absl::StrJoin(error_summary_final, ", ") << ".";
}
auto ops =
diff --git a/tensorflow/contrib/lite/toco/tflite/import.cc b/tensorflow/contrib/lite/toco/tflite/import.cc
index c0e7ab2ef5..1be7cf07a7 100644
--- a/tensorflow/contrib/lite/toco/tflite/import.cc
+++ b/tensorflow/contrib/lite/toco/tflite/import.cc
@@ -113,15 +113,34 @@ void ImportOperators(
<< operators_table.size();
}
string opname = operators_table.at(index);
+
+ // Find and use the appropriate operator deserialization factory.
+ std::unique_ptr<Operator> new_op = nullptr;
if (ops_by_name.count(opname) == 0) {
- LOG(FATAL) << "Op '" << opname << "' not supported";
+ string effective_opname = "TENSORFLOW_UNSUPPORTED";
+ if (ops_by_name.count(effective_opname) == 0) {
+ LOG(FATAL) << "Internal logic error: TENSORFLOW_UNSUPPORTED not found.";
+ }
+ new_op = ops_by_name.at(effective_opname)
+ ->Deserialize(input_op->builtin_options(),
+ input_op->custom_options());
+ if (TensorFlowUnsupportedOperator* unsupported_op =
+ dynamic_cast<TensorFlowUnsupportedOperator*>(new_op.get())) {
+ unsupported_op->tensorflow_op = opname;
+ // TODO(b/109932940): Remove this when quantized is removed.
+ // For now, we assume all ops are quantized.
+ unsupported_op->quantized = true;
+ } else {
+ LOG(FATAL) << "Expected a TensorFlowUnsupportedOperator";
+ }
+ } else {
+ new_op = ops_by_name.at(opname)->Deserialize(input_op->builtin_options(),
+ input_op->custom_options());
}
-
- auto new_op = ops_by_name.at(opname)->Deserialize(
- input_op->builtin_options(), input_op->custom_options());
model->operators.emplace_back(new_op.release());
auto* op = model->operators.back().get();
+ // Make sure all the inputs and outputs are hooked up.
auto inputs = input_op->inputs();
for (int i = 0; i < inputs->Length(); i++) {
auto input_index = inputs->Get(i);
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc
index a8518adefc..7490ab960b 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator.cc
@@ -1112,12 +1112,18 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() {
"LESS", OperatorType::kTensorFlowLess));
ops.emplace_back(new SimpleOperator<TensorFlowLessEqualOperator>(
"LESS_EQUAL", OperatorType::kTensorFlowLessEqual));
+ ops.emplace_back(new SimpleOperator<TensorFlowEqualOperator>(
+ "EQUAL", OperatorType::kTensorFlowEqual));
+ ops.emplace_back(new SimpleOperator<TensorFlowNotEqualOperator>(
+ "NOT_EQUAL", OperatorType::kTensorFlowNotEqual));
ops.emplace_back(new SimpleOperator<NegOperator>("NEG", OperatorType::kNeg));
ops.emplace_back(
new SimpleOperator<SelectOperator>("SELECT", OperatorType::kSelect));
ops.emplace_back(
new SimpleOperator<SliceOperator>("SLICE", OperatorType::kSlice));
+ // Element-wise operator
ops.emplace_back(new SimpleOperator<SinOperator>("SIN", OperatorType::kSin));
+ ops.emplace_back(new SimpleOperator<LogOperator>("LOG", OperatorType::kLog));
return ops;
}
diff --git a/tensorflow/contrib/lite/toco/tflite/operator_test.cc b/tensorflow/contrib/lite/toco/tflite/operator_test.cc
index d63c99a5f9..e3144ad63e 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator_test.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator_test.cc
@@ -119,6 +119,11 @@ TEST_F(OperatorTest, SimpleOperators) {
CheckSimpleOperator<SelectOperator>("SELECT", OperatorType::kSelect);
CheckSimpleOperator<SliceOperator>("SLICE", OperatorType::kSlice);
CheckSimpleOperator<SinOperator>("SIN", OperatorType::kSin);
+ CheckSimpleOperator<TensorFlowEqualOperator>("EQUAL",
+ OperatorType::kTensorFlowEqual);
+ CheckSimpleOperator<TensorFlowNotEqualOperator>(
+ "NOT_EQUAL", OperatorType::kTensorFlowNotEqual);
+ CheckSimpleOperator<LogOperator>("LOG", OperatorType::kLog);
}
TEST_F(OperatorTest, BuiltinAdd) {
diff --git a/tensorflow/contrib/lite/toco/toco_flags.proto b/tensorflow/contrib/lite/toco/toco_flags.proto
index 4fe57879fb..ad4e94ded9 100644
--- a/tensorflow/contrib/lite/toco/toco_flags.proto
+++ b/tensorflow/contrib/lite/toco/toco_flags.proto
@@ -174,4 +174,13 @@ message TocoFlags {
// Computation is still done in float, but reduces model size (at the cost of
// accuracy and latency).
optional bool quantize_weights = 20 [default = false];
+
+  // Full filepath of the folder to which the graphs at various stages of
+  // processing are dumped as GraphViz .dot files. Preferred over
+  // --output_format=GRAPHVIZ_DOT since it keeps the requirements of the
+  // output file intact.
+ optional string dump_graphviz_dir = 24;
+
+ // Boolean indicating whether to dump the graph after every graph
+ // transformation.
+ optional bool dump_graphviz_include_video = 25;
}
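
A hypothetical Python sketch of setting the two new flags above before running
a conversion; only the field names come from this proto, while the module path
and the rest of the call sequence are assumptions:

```python
# Assumed module path; only the two field names are taken from the proto above.
from tensorflow.contrib.lite.toco import toco_flags_pb2

toco_flags = toco_flags_pb2.TocoFlags()
toco_flags.dump_graphviz_dir = "/tmp/toco_dumps"   # write GraphViz .dot files here
toco_flags.dump_graphviz_include_video = True      # dump after every transformation
```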
diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc
index fe7bed885d..810718f610 100644
--- a/tensorflow/contrib/lite/toco/tooling_util.cc
+++ b/tensorflow/contrib/lite/toco/tooling_util.cc
@@ -394,6 +394,8 @@ const char* OperatorTypeName(OperatorType type) {
HANDLE_OPERATORTYPENAME_CASE(DynamicStitch)
HANDLE_OPERATORTYPENAME_CASE(Select)
HANDLE_OPERATORTYPENAME_CASE(SparseToDense)
+ HANDLE_OPERATORTYPENAME_CASE(TensorFlowEqual)
+ HANDLE_OPERATORTYPENAME_CASE(TensorFlowNotEqual)
default:
LOG(FATAL) << "Unhandled op type";
#undef HANDLE_OPERATORTYPENAME_CASE
@@ -1863,18 +1865,15 @@ void GetShuffleShape(AxesOrder input_axes_order, AxesOrder output_axes_order,
output_axes_order == AxesOrder::kHWIO) {
// 3210 <- 3210
// HWIO <- OHWI
- (*shuffle)[0] = 1;
- (*shuffle)[1] = 2;
- (*shuffle)[2] = 3;
- (*shuffle)[3] = 0;
+ *shuffle = {1, 2, 3, 0};
} else if (input_axes_order == AxesOrder::kHWIO &&
output_axes_order == AxesOrder::kOHWI) {
// 3210 <- 3210
// OHWI <- HWIO
- (*shuffle)[0] = 3;
- (*shuffle)[1] = 0;
- (*shuffle)[2] = 1;
- (*shuffle)[3] = 2;
+ *shuffle = {3, 0, 1, 2};
+ } else if (input_axes_order == AxesOrder::kOHWI &&
+ output_axes_order == AxesOrder::kHWOI) {
+ *shuffle = {1, 2, 0, 3};
} else {
LOG(FATAL) << "Bad shuffle";
}
@@ -2020,6 +2019,8 @@ int AxesCount(AxesOrder axes_order) {
return 4;
case AxesOrder::kNHWC:
return 4;
+ case AxesOrder::kHWOI:
+ return 4;
default:
LOG(FATAL) << "Bad AxesOrder";
return 0;
diff --git a/tensorflow/contrib/lite/tools/benchmark/BUILD b/tensorflow/contrib/lite/tools/benchmark/BUILD
index 4824a4dbde..f918010e2b 100644
--- a/tensorflow/contrib/lite/tools/benchmark/BUILD
+++ b/tensorflow/contrib/lite/tools/benchmark/BUILD
@@ -5,6 +5,8 @@ package(default_visibility = [
licenses(["notice"]) # Apache 2.0
load("//tensorflow/contrib/lite:special_rules.bzl", "tflite_portable_test_suite")
+load("//tensorflow/contrib/lite:build_def.bzl", "tflite_linkopts")
+load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts")
common_copts = ["-Wall"]
@@ -14,7 +16,7 @@ cc_binary(
"benchmark_main.cc",
"logging.h",
],
- copts = common_copts,
+ copts = tflite_copts() + common_copts,
linkopts = select({
"//tensorflow:android": [
"-pie",
@@ -58,6 +60,7 @@ cc_library(
],
hdrs = ["benchmark_tflite_model.h"],
copts = common_copts,
+ linkopts = tflite_linkopts(),
deps = [
":benchmark_model_lib",
"//tensorflow/contrib/lite:framework",
diff --git a/tensorflow/contrib/lite/tools/benchmark/README.md b/tensorflow/contrib/lite/tools/benchmark/README.md
index e6f333aa5b..c10826afff 100644
--- a/tensorflow/contrib/lite/tools/benchmark/README.md
+++ b/tensorflow/contrib/lite/tools/benchmark/README.md
@@ -46,8 +46,6 @@ adb shell /data/local/tmp/benchmark_model \
--graph=/data/local/tmp/mobilenet_quant_v1_224.tflite \
--input_layer="Placeholder" \
--input_layer_shape="1,224,224,3" \
- --input_layer_type="uint8" \
- --output_layer="MobilenetV1/Predictions/Reshape_1" \
--num_threads=4
```
@@ -66,8 +64,6 @@ bazel-bin/tensorflow/contrib/lite/tools/benchmark/benchmark_model \
--graph=mobilenet_quant_v1_224.tflite \
--input_layer="Placeholder" \
--input_layer_shape="1,224,224,3" \
- --input_layer_type="uint8" \
- --output_layer="MobilenetV1/Predictions/Reshape_1" \
--num_threads=4
```
@@ -93,80 +89,66 @@ This compiles TFLite with profiling enabled, now you can run the benchmark binar
============================== Run Order ==============================
[node type] [start] [first] [avg ms] [%] [cdf%] [mem KB] [times called] [Name]
- CONV_2D 0.000 9.132 9.132 0.121% 0.121% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_0/Relu6]
- DEPTHWISE_CONV_2D 9.135 3.280 3.280 0.043% 0.165% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_1_depthwise/Relu6]
- CONV_2D 12.419 6.877 6.877 0.091% 0.256% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_1_pointwise/Relu6]
- DEPTHWISE_CONV_2D 19.299 1.708 1.708 0.023% 0.278% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_2_depthwise/Relu6]
- CONV_2D 21.012 4.162 4.162 0.055% 0.334% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_2_pointwise/Relu6]
- DEPTHWISE_CONV_2D 25.177 3.520 3.520 0.047% 0.380% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_3_depthwise/Relu6]
- CONV_2D 28.701 10.218 10.218 0.136% 0.516% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_3_pointwise/Relu6]
- DEPTHWISE_CONV_2D 38.922 0.827 0.827 0.011% 0.527% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_4_depthwise/Relu6]
- CONV_2D 39.752 1.401 1.401 0.019% 0.545% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_4_pointwise/Relu6]
- DEPTHWISE_CONV_2D 41.156 1.290 1.290 0.017% 0.563% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_5_depthwise/Relu6]
- CONV_2D 42.448 5.995 5.995 0.080% 0.642% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_5_pointwise/Relu6]
- DEPTHWISE_CONV_2D 48.445 0.409 0.409 0.005% 0.647% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_6_depthwise/Relu6]
- CONV_2D 48.856 6.167 6.167 0.082% 0.729% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_6_pointwise/Relu6]
- DEPTHWISE_CONV_2D 55.026 0.629 0.629 0.008% 0.738% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_7_depthwise/Relu6]
- CONV_2D 55.656 6.464 6.464 0.086% 0.823% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_7_pointwise/Relu6]
- DEPTHWISE_CONV_2D 62.124 0.647 0.647 0.009% 0.832% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_8_depthwise/Relu6]
- CONV_2D 62.774 14.666 14.666 0.195% 1.026% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_8_pointwise/Relu6]
- DEPTHWISE_CONV_2D 77.444 0.635 0.635 0.008% 1.035% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_9_depthwise/Relu6]
- CONV_2D 78.081 7.186 7.186 0.095% 1.130% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_9_pointwise/Relu6]
- DEPTHWISE_CONV_2D 85.270 0.646 0.646 0.009% 1.139% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_10_depthwise/Relu6]
- CONV_2D 85.918 9.529 9.529 0.126% 1.265% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_10_pointwise/Relu6]
- DEPTHWISE_CONV_2D 95.451 0.628 0.628 0.008% 1.273% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_11_depthwise/Relu6]
- CONV_2D 96.081 2.077 2.077 0.028% 1.301% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_11_pointwise/Relu6]
- DEPTHWISE_CONV_2D 98.162 0.168 0.168 0.002% 1.303% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_12_depthwise/Relu6]
- CONV_2D 98.332 1.007 1.007 0.013% 1.317% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_12_pointwise/Relu6]
- DEPTHWISE_CONV_2D 99.342 0.288 0.288 0.004% 1.320% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_13_depthwise/Relu6]
- CONV_2D 99.632 8.197 8.197 0.109% 1.429% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_13_pointwise/Relu6]
- AVERAGE_POOL_2D 107.832 0.045 0.045 0.001% 1.430% 0.000 0 [MobilenetV1/Logits/AvgPool_1a/AvgPool]
- CONV_2D 107.878 0.325 0.325 0.004% 1.434% 0.000 0 [MobilenetV1/Logits/Conv2d_1c_1x1/BiasAdd]
- RESHAPE 108.206 0.003 0.003 0.000% 1.434% 0.000 0 [MobilenetV1/Predictions/Reshape]
- SOFTMAX 108.211 0.038 0.038 0.001% 1.434% 0.000 0 [MobilenetV1/Predictions/Softmax]
+ CONV_2D 0.000 4.269 4.269 0.107% 0.107% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_0/Relu6]
+ DEPTHWISE_CONV_2D 4.270 2.150 2.150 0.054% 0.161% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_1_depthwise/Relu6]
+ CONV_2D 6.421 6.107 6.107 0.153% 0.314% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_1_pointwise/Relu6]
+ DEPTHWISE_CONV_2D 12.528 1.366 1.366 0.034% 0.348% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_2_depthwise/Relu6]
+ CONV_2D 13.895 4.195 4.195 0.105% 0.454% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_2_pointwise/Relu6]
+ DEPTHWISE_CONV_2D 18.091 1.260 1.260 0.032% 0.485% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_3_depthwise/Relu6]
+ CONV_2D 19.352 6.652 6.652 0.167% 0.652% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_3_pointwise/Relu6]
+ DEPTHWISE_CONV_2D 26.005 0.698 0.698 0.018% 0.670% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_4_depthwise/Relu6]
+ CONV_2D 26.703 3.344 3.344 0.084% 0.754% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_4_pointwise/Relu6]
+ DEPTHWISE_CONV_2D 30.047 0.646 0.646 0.016% 0.770% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_5_depthwise/Relu6]
+ CONV_2D 30.694 5.800 5.800 0.145% 0.915% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_5_pointwise/Relu6]
+ DEPTHWISE_CONV_2D 36.495 0.331 0.331 0.008% 0.924% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_6_depthwise/Relu6]
+ CONV_2D 36.826 2.838 2.838 0.071% 0.995% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_6_pointwise/Relu6]
+ DEPTHWISE_CONV_2D 39.665 0.439 0.439 0.011% 1.006% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_7_depthwise/Relu6]
+ CONV_2D 40.105 5.293 5.293 0.133% 1.139% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_7_pointwise/Relu6]
+ DEPTHWISE_CONV_2D 45.399 0.352 0.352 0.009% 1.147% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_8_depthwise/Relu6]
+ CONV_2D 45.752 5.322 5.322 0.133% 1.281% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_8_pointwise/Relu6]
+ DEPTHWISE_CONV_2D 51.075 0.357 0.357 0.009% 1.290% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_9_depthwise/Relu6]
+ CONV_2D 51.432 5.693 5.693 0.143% 1.433% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_9_pointwise/Relu6]
+ DEPTHWISE_CONV_2D 57.126 0.366 0.366 0.009% 1.442% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_10_depthwise/Relu6]
+ CONV_2D 57.493 5.472 5.472 0.137% 1.579% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_10_pointwise/Relu6]
+ DEPTHWISE_CONV_2D 62.966 0.364 0.364 0.009% 1.588% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_11_depthwise/Relu6]
+ CONV_2D 63.330 5.404 5.404 0.136% 1.724% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_11_pointwise/Relu6]
+ DEPTHWISE_CONV_2D 68.735 0.155 0.155 0.004% 1.728% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_12_depthwise/Relu6]
+ CONV_2D 68.891 2.970 2.970 0.074% 1.802% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_12_pointwise/Relu6]
+ DEPTHWISE_CONV_2D 71.862 0.206 0.206 0.005% 1.807% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_13_depthwise/Relu6]
+ CONV_2D 72.069 5.888 5.888 0.148% 1.955% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_13_pointwise/Relu6]
+ AVERAGE_POOL_2D 77.958 0.036 0.036 0.001% 1.956% 0.000 0 [MobilenetV1/Logits/AvgPool_1a/AvgPool]
+ CONV_2D 77.994 1.445 1.445 0.036% 1.992% 0.000 0 [MobilenetV1/Logits/Conv2d_1c_1x1/BiasAdd]
+ RESHAPE 79.440 0.002 0.002 0.000% 1.992% 0.000 0 [MobilenetV1/Predictions/Reshape]
+ SOFTMAX 79.443 0.029 0.029 0.001% 1.993% 0.000 0 [MobilenetV1/Predictions/Softmax]
============================== Top by Computation Time ==============================
[node type] [start] [first] [avg ms] [%] [cdf%] [mem KB] [times called] [Name]
- CONV_2D 62.774 14.666 14.666 0.195% 0.195% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_8_pointwise/Relu6]
- CONV_2D 28.701 10.218 10.218 0.136% 0.330% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_3_pointwise/Relu6]
- CONV_2D 85.918 9.529 9.529 0.126% 0.456% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_10_pointwise/Relu6]
- CONV_2D 0.000 9.132 9.132 0.121% 0.578% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_0/Relu6]
- CONV_2D 99.632 8.197 8.197 0.109% 0.686% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_13_pointwise/Relu6]
- CONV_2D 78.081 7.186 7.186 0.095% 0.782% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_9_pointwise/Relu6]
- CONV_2D 12.419 6.877 6.877 0.091% 0.873% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_1_pointwise/Relu6]
- CONV_2D 55.656 6.464 6.464 0.086% 0.958% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_7_pointwise/Relu6]
- CONV_2D 48.856 6.167 6.167 0.082% 1.040% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_6_pointwise/Relu6]
- CONV_2D 42.448 5.995 5.995 0.080% 1.120% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_5_pointwise/Relu6]
-
-============================== Top by Memory Use ==============================
- [node type] [start] [first] [avg ms] [%] [cdf%] [mem KB] [times called] [Name]
- SOFTMAX 108.211 0.038 0.038 0.001% 0.001% 0.000 0 [MobilenetV1/Predictions/Softmax]
- RESHAPE 108.206 0.003 0.003 0.000% 0.001% 0.000 0 [MobilenetV1/Predictions/Reshape]
- CONV_2D 78.081 7.186 7.186 0.095% 0.096% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_9_pointwise/Relu6]
- DEPTHWISE_CONV_2D 77.444 0.635 0.635 0.008% 0.104% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_9_depthwise/Relu6]
- CONV_2D 62.774 14.666 14.666 0.195% 0.299% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_8_pointwise/Relu6]
- DEPTHWISE_CONV_2D 62.124 0.647 0.647 0.009% 0.307% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_8_depthwise/Relu6]
- CONV_2D 55.656 6.464 6.464 0.086% 0.393% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_7_pointwise/Relu6]
- DEPTHWISE_CONV_2D 55.026 0.629 0.629 0.008% 0.401% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_7_depthwise/Relu6]
- CONV_2D 48.856 6.167 6.167 0.082% 0.483% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_6_pointwise/Relu6]
- DEPTHWISE_CONV_2D 48.445 0.409 0.409 0.005% 0.489% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_6_depthwise/Relu6]
+ CONV_2D 19.352 6.652 6.652 0.167% 0.167% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_3_pointwise/Relu6]
+ CONV_2D 6.421 6.107 6.107 0.153% 0.320% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_1_pointwise/Relu6]
+ CONV_2D 72.069 5.888 5.888 0.148% 0.468% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_13_pointwise/Relu6]
+ CONV_2D 30.694 5.800 5.800 0.145% 0.613% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_5_pointwise/Relu6]
+ CONV_2D 51.432 5.693 5.693 0.143% 0.756% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_9_pointwise/Relu6]
+ CONV_2D 57.493 5.472 5.472 0.137% 0.893% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_10_pointwise/Relu6]
+ CONV_2D 63.330 5.404 5.404 0.136% 1.029% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_11_pointwise/Relu6]
+ CONV_2D 45.752 5.322 5.322 0.133% 1.162% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_8_pointwise/Relu6]
+ CONV_2D 40.105 5.293 5.293 0.133% 1.295% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_7_pointwise/Relu6]
+ CONV_2D 0.000 4.269 4.269 0.107% 1.402% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_0/Relu6]
Number of nodes executed: 31
============================== Summary by node type ==============================
[Node type] [count] [avg ms] [avg %] [cdf %] [mem KB] [times called]
- CONV_2D 15 1.861 86.679% 86.679% 0.000 0
- DEPTHWISE_CONV_2D 13 0.286 13.321% 100.000% 0.000 0
+ CONV_2D 15 1.406 89.270% 89.270% 0.000 0
+ DEPTHWISE_CONV_2D 13 0.169 10.730% 100.000% 0.000 0
SOFTMAX 1 0.000 0.000% 100.000% 0.000 0
RESHAPE 1 0.000 0.000% 100.000% 0.000 0
AVERAGE_POOL_2D 1 0.000 0.000% 100.000% 0.000 0
-Timings (microseconds): count=50 first=108164 curr=128308 min=102850 max=197072 avg=150805 std=24368
+Timings (microseconds): count=50 first=79449 curr=81350 min=77385 max=88213 avg=79732 std=1929
Memory (bytes): count=0
31 nodes observed
-Average inference timings in us: Warmup: 135310, Init: 12123, no stats: 150988
-
+Average inference timings in us: Warmup: 83235, Init: 38467, no stats: 79760.9
```
diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc
index 2e5b866273..5f803cec19 100644
--- a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc
+++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc
@@ -123,29 +123,11 @@ void FillRandomString(tflite::DynamicBuffer* buffer,
}
}
-TfLiteType TfLiteTypeFromString(const string& input_layer_type) {
- if (input_layer_type == "string")
- return kTfLiteString;
- else if (input_layer_type == "float")
- return kTfLiteFloat32;
- else if (input_layer_type == "uint8")
- return kTfLiteUInt8;
- else if (input_layer_type == "int32")
- return kTfLiteInt32;
- else if (input_layer_type == "int64")
- return kTfLiteInt64;
- else
- return kTfLiteNoType;
-}
-
bool PopulateInputLayerInfo(
const string& names_string, const string& shapes_string,
- const string& types_string, const string& values_string,
std::vector<BenchmarkTfLiteModel::InputLayerInfo>* info) {
std::vector<std::string> names = Split(names_string, ',');
std::vector<std::string> shapes = Split(shapes_string, ':');
- std::vector<std::string> types = Split(types_string, ',');
- std::vector<std::string> values = Split(values_string, ':');
if (names.size() != shapes.size()) {
TFLITE_LOG(ERROR) << "The number of items in"
@@ -158,17 +140,6 @@ bool PopulateInputLayerInfo(
<< " --input_layer_shape=1,224,224,4:1,20";
return false;
}
- if (names.size() != types.size()) {
- TFLITE_LOG(ERROR) << "The number of items in"
- << " --input_layer_type (" << types_string << ", with "
- << types.size() << " items)"
- << " must match the number of items in"
- << " --input_layer (" << names_string << ", with "
- << names.size() << " items)."
- << " For example --input_layer=input1,input2"
- << " --input_layer_type=float,int";
- return false;
- }
for (int i = 0; i < names.size(); ++i) {
info->push_back(BenchmarkTfLiteModel::InputLayerInfo());
@@ -176,10 +147,6 @@ bool PopulateInputLayerInfo(
input.name = names[i];
- input.data_type = TfLiteTypeFromString(types[i]);
- TFLITE_BENCHMARK_CHECK(input.data_type != kTfLiteNoType)
- << types[i] << " was an invalid type";
-
TFLITE_BENCHMARK_CHECK(SplitAndParse(shapes[i], ',', &input.shape))
<< "Incorrect size string specified: " << shapes[i];
for (int dim : input.shape) {
@@ -190,12 +157,6 @@ bool PopulateInputLayerInfo(
return false;
}
}
-
- if (i < values.size()) {
- TFLITE_BENCHMARK_CHECK(
- SplitAndParse(values[i], ',', &input.initialization_values))
- << "Incorrect initialization values string specified: " << values[i];
- }
}
return true;
@@ -209,10 +170,6 @@ std::vector<Flag> BenchmarkTfLiteModel::GetFlags() {
Flag("graph", &graph, "graph file name"),
Flag("input_layer", &input_layer_string, "input layer names"),
Flag("input_layer_shape", &input_layer_shape_string, "input layer shape"),
- Flag("input_layer_type", &input_layer_type_string, "input layer type"),
- Flag("input_layer_values", &input_layer_values_string,
- "values to initialize the inputs with"),
- Flag("output_layer", &output_layer_string, "output layer name"),
Flag("use_nnapi", &use_nnapi, "use nnapi api")};
flags.insert(flags.end(), specific_flags.begin(), specific_flags.end());
@@ -224,8 +181,6 @@ void BenchmarkTfLiteModel::LogFlags() {
TFLITE_LOG(INFO) << "Graph: [" << graph << "]";
TFLITE_LOG(INFO) << "Input layers: [" << input_layer_string << "]";
TFLITE_LOG(INFO) << "Input shapes: [" << input_layer_shape_string << "]";
- TFLITE_LOG(INFO) << "Input types: [" << input_layer_type_string << "]";
- TFLITE_LOG(INFO) << "Output layers: [" << output_layer_string << "]";
TFLITE_LOG(INFO) << "Use nnapi : [" << use_nnapi << "]";
}
@@ -236,8 +191,7 @@ bool BenchmarkTfLiteModel::ValidateFlags() {
return false;
}
return PopulateInputLayerInfo(input_layer_string, input_layer_shape_string,
- input_layer_type_string,
- input_layer_values_string, &inputs);
+ &inputs);
}
uint64_t BenchmarkTfLiteModel::ComputeInputBytes() {
@@ -293,8 +247,6 @@ void BenchmarkTfLiteModel::Init() {
TFLITE_BENCHMARK_CHECK_EQ(t->name, input.name)
<< "Tensor # " << i << " is named " << t->name << " but flags call it "
<< input.name;
- TFLITE_BENCHMARK_CHECK_EQ(t->type, input.data_type)
- << "Could not match the type of input tensor " << t->name;
}
// Resize all non-string tensors.
diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h
index e70f6de1bf..ffb93da964 100644
--- a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h
+++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h
@@ -64,10 +64,7 @@ class BenchmarkTfLiteModel : public BenchmarkModel {
struct InputLayerInfo {
std::string name;
- TfLiteType data_type;
std::vector<int> shape;
- // Note that initialization_values is currently unused.
- std::vector<float> initialization_values;
};
private:
@@ -78,7 +75,6 @@ class BenchmarkTfLiteModel : public BenchmarkModel {
std::string input_layer_type_string;
std::string input_layer_shape_string;
std::string input_layer_values_string;
- std::string output_layer_string;
std::vector<InputLayerInfo> inputs;
bool use_nnapi;
ProfilingListener profiling_listener_;
diff --git a/tensorflow/contrib/lite/tools/benchmark/command_line_flags_test.cc b/tensorflow/contrib/lite/tools/benchmark/command_line_flags_test.cc
index 9a931d5ddd..620d61b027 100644
--- a/tensorflow/contrib/lite/tools/benchmark/command_line_flags_test.cc
+++ b/tensorflow/contrib/lite/tools/benchmark/command_line_flags_test.cc
@@ -134,9 +134,9 @@ TEST(CommandLineFlagsTest, UsageString) {
std::string some_name = "something";
// Don't test float in this case, because precision is hard to predict and
// match against, and we don't want a flakey test.
- const string tool_name = "some_tool_name";
- string usage = Flags::Usage(tool_name + " <flags>",
- {Flag("some_int", &some_int, "some int"),
+ const std::string tool_name = "some_tool_name";
+ std::string usage = Flags::Usage(
+ tool_name + " <flags>", {Flag("some_int", &some_int, "some int"),
Flag("some_int64", &some_int64, "some int64"),
Flag("some_switch", &some_switch, "some switch"),
Flag("some_name", &some_name, "some name")});
diff --git a/tensorflow/contrib/lite/tools/verifier_test.cc b/tensorflow/contrib/lite/tools/verifier_test.cc
index ce8a7857d2..ad7d59ecb4 100644
--- a/tensorflow/contrib/lite/tools/verifier_test.cc
+++ b/tensorflow/contrib/lite/tools/verifier_test.cc
@@ -41,7 +41,7 @@ class TfLiteFlatbufferModelBuilder {
}
TfLiteFlatbufferModelBuilder(const std::vector<BuiltinOperator>& builtin_ops,
- const std::vector<string>& custom_ops) {
+ const std::vector<std::string>& custom_ops) {
buffers_.push_back(
CreateBuffer(builder_, builder_.CreateVector(std::vector<uint8_t>{})));
@@ -194,8 +194,8 @@ TEST(VerifyModel, TensorBufferIsNotValid) {
/*operators=*/0, builder.CreateString("Main"))});
auto buffers = builder.CreateVector(std::vector<Offset<Buffer>>{
- CreateBuffer(builder,
- builder.CreateVector(std::vector<uint8>{1, 2, 3, 4, 5, 6})),
+ CreateBuffer(builder, builder.CreateVector(
+ std::vector<uint8_t>{1, 2, 3, 4, 5, 6})),
});
auto model = CreateModel(builder, TFLITE_SCHEMA_VERSION, /*operator_codes=*/0,
diff --git a/tensorflow/contrib/lookup/lookup_ops_test.py b/tensorflow/contrib/lookup/lookup_ops_test.py
index 5d4682ec9f..5a080cceab 100644
--- a/tensorflow/contrib/lookup/lookup_ops_test.py
+++ b/tensorflow/contrib/lookup/lookup_ops_test.py
@@ -24,6 +24,7 @@ import six
from tensorflow.contrib import lookup
from tensorflow.python.client import session
+from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
@@ -1396,15 +1397,22 @@ class KeyValueTensorInitializerTest(test.TestCase):
class IndexTableFromTensor(test.TestCase):
+ @test_util.run_in_graph_and_eager_modes()
def test_index_table_from_tensor_with_tensor_init(self):
- with self.test_session():
+ table = lookup.index_table_from_tensor(
+ mapping=("brain", "salad", "surgery"), num_oov_buckets=1)
+
+ if not context.executing_eagerly():
+ with self.assertRaises(errors_impl.OpError):
+ self.evaluate(table.lookup(
+ constant_op.constant(("salad", "surgery", "tarkus"))))
+ else:
+ # Reinitializing a table in eager should work.
table = lookup.index_table_from_tensor(
mapping=("brain", "salad", "surgery"), num_oov_buckets=1)
- ids = table.lookup(constant_op.constant(("salad", "surgery", "tarkus")))
-
- self.assertRaises(errors_impl.OpError, ids.eval)
- lookup_ops.tables_initializer().run()
- self.assertAllEqual((1, 2, 3), ids.eval())
+ self.evaluate(lookup_ops.tables_initializer())
+ ids = table.lookup(constant_op.constant(("salad", "surgery", "tarkus")))
+ self.assertAllEqual((1, 2, 3), self.evaluate(ids))
def test_int32_index_table_from_tensor_with_tensor_init(self):
with self.test_session():
diff --git a/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer.py b/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer.py
index e4e5ccc334..ef34f7bf7b 100644
--- a/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer.py
+++ b/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer.py
@@ -26,26 +26,32 @@ from tensorflow.python.training import optimizer
class LossScaleOptimizer(optimizer.Optimizer):
+ # TODO(jamesqin): move mixed precision training explanation to __init__
+ # docstring.
"""An optimizer that applies loss scaling in backprop.
- This class is useful for mixed precision training on GPUs (or other potential
- accelerators), which is an approach to improve compute throughput without loss
- of model quality.
-
- The commmon configuration of mixed precision models is the following:
- * variables are kept in high precision (e.g. float32).
- * computations are done in lower precision (e.g. float16). variables are
- casted to lower precision before they're used.
- * (in training), final gradients are casted back to variable precision and get
- applied.
-
- Because computations happen in lower precision, gradients in the backprop pass
- might underflow in the smaller dynamic range, causing a model to converge at a
- suboptimal level. This optimizer multiplies the loss by a factor before
- backprop starts to prevent underflow. Before gradients are applied, they are
- casted to higher precision and down-scaled by the same factor, so
- mathematically the variable updates are no different from regular
- same-precision training.
+ This class is useful for "mixed precision training" on GPUs (or other
+ potential accelerators), an approach to improve compute throughput without
+ compromising model quality.
+
+ The canonical way to perform mixed precision training is the following:
+ * Model variables are kept in high precision (e.g. float32).
+  * Computations are done in lower precision (e.g. float16), which benefits
+    from hardware-accelerated speedups. Variables are cast to lower precision
+    before they're used.
+  * Final gradients are cast back to the high-precision dtype, then used to
+    update variables.
+
+  The side effect of performing computation in lower precision is a smaller
+  numerical range. During backprop, small gradients might underflow in the
+  reduced numerical range, causing a model to converge at a suboptimal level.
+
+  To prevent underflow, this optimizer multiplies the loss by a factor before
+  backprop starts. Consequently, the gradients are linearly scaled up by the
+  same factor, thus not falling into the underflow zone. After that, to
+  preserve the correctness of backprop, the gradients are down-scaled by the
+  same factor, cast to the (higher) variable precision, then applied to the
+  variables.
See [Nvidia's manual on mixed precision training](
https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html)
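
A minimal sketch of the fixed loss-scaling idea described in the docstring
above; `loss`, `variables`, and `opt` are placeholders assumed to be defined
elsewhere, and this is not the LossScaleOptimizer API itself:

```python
import tensorflow as tf

# Scale the loss so small fp16 gradients do not underflow during backprop,
# then unscale the gradients before applying them; the update is
# mathematically unchanged.
loss_scale = 128.0  # assumed fixed scale factor
scaled_grads = tf.gradients(loss * loss_scale, variables)
grads = [tf.cast(g, tf.float32) / loss_scale for g in scaled_grads]
train_op = opt.apply_gradients(zip(grads, variables))
```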
diff --git a/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc b/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc
index 99485322c6..f80f5652af 100644
--- a/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc
+++ b/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc
@@ -18,7 +18,7 @@ limitations under the License.
// Initiates a TPU profiling on the TPUProfiler service at service_addr,
// receives and dumps the profile data to a tensorboard log directory.
-#include "grpc++/grpc++.h"
+#include "grpcpp/grpcpp.h"
#include <cstdio>
#include <ctime>
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py
index 71a5012691..1c482950e6 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu.py
@@ -623,6 +623,11 @@ def split_compile_and_replicate(computation,
vscope.set_use_resource(saved_use_resource)
+  # If the computation returns `None`, add a `no_op` here so that when the
+  # user fetches the `no_op` returned by this function, the TPUExecute node
+  # will be triggered.
+ if outputs is None:
+ outputs = (control_flow_ops.no_op(),)
# If the computation only returned one value, makes it a tuple.
if not isinstance(outputs, (list, tuple)):
outputs = (outputs,)
diff --git a/tensorflow/contrib/training/python/training/tensor_queue_dataset.py b/tensorflow/contrib/training/python/training/tensor_queue_dataset.py
index 409aba817c..a2444934bc 100644
--- a/tensorflow/contrib/training/python/training/tensor_queue_dataset.py
+++ b/tensorflow/contrib/training/python/training/tensor_queue_dataset.py
@@ -18,6 +18,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.util import convert
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import sparse
from tensorflow.python.framework import dtypes
@@ -45,14 +46,14 @@ class _PrependFromQueueAndPaddedBatchDataset(dataset_ops.Dataset):
self._input_dataset = input_dataset
self._batch_size = ops.convert_to_tensor(
batch_size, dtype=dtypes.int64, name="batch_size")
- # pylint: disable=protected-access
if padded_shapes is None:
self._padded_shapes = nest.map_structure(
- dataset_ops._partial_shape_to_tensor, input_dataset.output_shapes)
+ convert.partial_shape_to_tensor, input_dataset.output_shapes)
else:
self._padded_shapes = nest.map_structure_up_to(
- input_dataset.output_shapes, dataset_ops._partial_shape_to_tensor,
+ input_dataset.output_shapes, convert.partial_shape_to_tensor,
padded_shapes)
+ # pylint: disable=protected-access
padding_values = (
padding_values if padding_values is not None else
dataset_ops._default_padding(input_dataset))
diff --git a/tensorflow/contrib/verbs/grpc_verbs_service.cc b/tensorflow/contrib/verbs/grpc_verbs_service.cc
index 742f946c95..af29abd91f 100644
--- a/tensorflow/contrib/verbs/grpc_verbs_service.cc
+++ b/tensorflow/contrib/verbs/grpc_verbs_service.cc
@@ -15,9 +15,9 @@ limitations under the License.
#ifdef TENSORFLOW_USE_VERBS
-#include "grpc++/alarm.h"
-#include "grpc++/grpc++.h"
-#include "grpc++/server_builder.h"
+#include "grpcpp/alarm.h"
+#include "grpcpp/grpcpp.h"
+#include "grpcpp/server_builder.h"
#include "tensorflow/contrib/verbs/grpc_verbs_service.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
diff --git a/tensorflow/contrib/verbs/grpc_verbs_service_impl.cc b/tensorflow/contrib/verbs/grpc_verbs_service_impl.cc
index 991f9a9d8b..4da7b59c69 100644
--- a/tensorflow/contrib/verbs/grpc_verbs_service_impl.cc
+++ b/tensorflow/contrib/verbs/grpc_verbs_service_impl.cc
@@ -15,14 +15,14 @@ limitations under the License.
#include "tensorflow/contrib/verbs/grpc_verbs_service_impl.h"
-#include "grpc++/impl/codegen/async_stream.h"
-#include "grpc++/impl/codegen/async_unary_call.h"
-#include "grpc++/impl/codegen/channel_interface.h"
-#include "grpc++/impl/codegen/client_unary_call.h"
-#include "grpc++/impl/codegen/method_handler_impl.h"
-#include "grpc++/impl/codegen/rpc_service_method.h"
-#include "grpc++/impl/codegen/service_type.h"
-#include "grpc++/impl/codegen/sync_stream.h"
+#include "grpcpp/impl/codegen/async_stream.h"
+#include "grpcpp/impl/codegen/async_unary_call.h"
+#include "grpcpp/impl/codegen/channel_interface.h"
+#include "grpcpp/impl/codegen/client_unary_call.h"
+#include "grpcpp/impl/codegen/method_handler_impl.h"
+#include "grpcpp/impl/codegen/rpc_service_method.h"
+#include "grpcpp/impl/codegen/service_type.h"
+#include "grpcpp/impl/codegen/sync_stream.h"
namespace tensorflow {
diff --git a/tensorflow/contrib/verbs/grpc_verbs_service_impl.h b/tensorflow/contrib/verbs/grpc_verbs_service_impl.h
index 1f0f10517e..abe5e08b07 100644
--- a/tensorflow/contrib/verbs/grpc_verbs_service_impl.h
+++ b/tensorflow/contrib/verbs/grpc_verbs_service_impl.h
@@ -16,14 +16,14 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_GRPC_VERBS_SERVICE_IMPL_H_
#define TENSORFLOW_CONTRIB_GRPC_VERBS_SERVICE_IMPL_H_
-#include "grpc++/impl/codegen/async_stream.h"
-#include "grpc++/impl/codegen/async_unary_call.h"
-#include "grpc++/impl/codegen/proto_utils.h"
-#include "grpc++/impl/codegen/rpc_method.h"
-#include "grpc++/impl/codegen/service_type.h"
-#include "grpc++/impl/codegen/status.h"
-#include "grpc++/impl/codegen/stub_options.h"
-#include "grpc++/impl/codegen/sync_stream.h"
+#include "grpcpp/impl/codegen/async_stream.h"
+#include "grpcpp/impl/codegen/async_unary_call.h"
+#include "grpcpp/impl/codegen/proto_utils.h"
+#include "grpcpp/impl/codegen/rpc_method.h"
+#include "grpcpp/impl/codegen/service_type.h"
+#include "grpcpp/impl/codegen/status.h"
+#include "grpcpp/impl/codegen/stub_options.h"
+#include "grpcpp/impl/codegen/sync_stream.h"
#include "tensorflow/contrib/verbs/verbs_service.pb.h"
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 5de59eaef7..2e4df72edc 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -2636,6 +2636,7 @@ CORE_CPU_LIB_HEADERS = CORE_CPU_BASE_HDRS + [
"common_runtime/dma_helper.h",
"common_runtime/eigen_thread_pool.h",
"common_runtime/executor.h",
+ "common_runtime/executor_factory.h",
"common_runtime/graph_optimizer.h",
"common_runtime/local_device.h",
"common_runtime/lower_if_op.h",
@@ -2685,6 +2686,7 @@ tf_cuda_library(
"common_runtime/device_resolver_local.cc",
"common_runtime/device_set.cc",
"common_runtime/executor.cc",
+ "common_runtime/executor_factory.cc",
"common_runtime/function.cc",
"common_runtime/graph_optimizer.cc",
"common_runtime/graph_runner.cc",
diff --git a/tensorflow/core/api_def/base_api/api_def_BatchDatasetV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_BatchDatasetV2.pbtxt
new file mode 100644
index 0000000000..0c5b1eb45a
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_BatchDatasetV2.pbtxt
@@ -0,0 +1,18 @@
+op {
+ graph_op_name: "BatchDatasetV2"
+ visibility: HIDDEN
+ in_arg {
+ name: "batch_size"
+ description: <<END
+A scalar representing the number of elements to accumulate in a batch.
+END
+ }
+ in_arg {
+ name: "drop_remainder"
+ description: <<END
+A scalar representing whether the last batch should be dropped in case its size
+is smaller than desired.
+END
+ }
+ summary: "Creates a dataset that batches `batch_size` elements from `input_dataset`."
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_FeatureStatsDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_FeatureStatsDataset.pbtxt
new file mode 100644
index 0000000000..ffd01ba5cc
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_FeatureStatsDataset.pbtxt
@@ -0,0 +1,3 @@
+op {
+ graph_op_name: "FeatureStatsDataset"
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_PaddedBatchDatasetV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_PaddedBatchDatasetV2.pbtxt
new file mode 100644
index 0000000000..9fefc0c418
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_PaddedBatchDatasetV2.pbtxt
@@ -0,0 +1,35 @@
+op {
+ graph_op_name: "PaddedBatchDatasetV2"
+ visibility: HIDDEN
+ in_arg {
+ name: "batch_size"
+ description: <<END
+A scalar representing the number of elements to accumulate in a
+batch.
+END
+ }
+ in_arg {
+ name: "drop_remainder"
+ description: <<END
+A scalar representing whether the last batch should be dropped in case its size
+is smaller than desired.
+END
+ }
+ in_arg {
+ name: "padded_shapes"
+ description: <<END
+A list of int64 tensors representing the desired padded shapes
+of the corresponding output components. These shapes may be partially
+specified, using `-1` to indicate that a particular dimension should be
+padded to the maximum size of all batch elements.
+END
+ }
+ in_arg {
+ name: "padding_values"
+ description: <<END
+A list of scalars containing the padding value to use for
+each of the outputs.
+END
+ }
+ summary: "Creates a dataset that batches and pads `batch_size` elements from the input."
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ResourceScatterNdAdd.pbtxt b/tensorflow/core/api_def/base_api/api_def_ResourceScatterNdAdd.pbtxt
new file mode 100644
index 0000000000..3b3a274df5
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ResourceScatterNdAdd.pbtxt
@@ -0,0 +1,69 @@
+op {
+ graph_op_name: "ResourceScatterNdAdd"
+ in_arg {
+ name: "ref"
+ description: <<END
+A resource handle. Must be from a VarHandleOp.
+END
+ }
+ in_arg {
+ name: "indices"
+ description: <<END
+A Tensor. Must be one of the following types: int32, int64.
+A tensor of indices into ref.
+END
+ }
+ in_arg {
+ name: "updates"
+ description: <<END
+A Tensor. Must have the same type as ref. A tensor of
+values to add to ref.
+END
+ }
+ attr {
+ name: "use_locking"
+ description: <<END
+An optional bool. Defaults to True. If True, the assignment will
+be protected by a lock; otherwise the behavior is undefined,
+but may exhibit less contention.
+END
+ }
+ summary: "Adds sparse `updates` to individual values or slices within a given"
+ description: <<END
+variable according to `indices`.
+
+`ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
+
+`indices` must be an integer tensor containing indices into `ref`.
+It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
+
+The innermost dimension of `indices` (with length `K`) corresponds to
+indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
+dimension of `ref`.
+
+`updates` is a `Tensor` of rank `Q-1+P-K` with shape:
+
+```
+[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
+```
+
+For example, say we want to add 4 scattered elements to a rank-1 tensor with
+8 elements. In Python, that update would look like this:
+
+```python
+ ref = tfe.Variable([1, 2, 3, 4, 5, 6, 7, 8])
+ indices = tf.constant([[4], [3], [1] ,[7]])
+ updates = tf.constant([9, 10, 11, 12])
+ update = tf.scatter_nd_add(ref, indices, updates)
+ with tf.Session() as sess:
+    print(sess.run(update))
+```
+
+The resulting update to ref would look like this:
+
+    [1, 13, 3, 14, 14, 6, 7, 20]
+
+See @{tf.scatter_nd} for more details about how to make updates to
+slices.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_TensorArrayGradWithShape.pbtxt b/tensorflow/core/api_def/base_api/api_def_TensorArrayGradWithShape.pbtxt
new file mode 100644
index 0000000000..dd37b94ffa
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_TensorArrayGradWithShape.pbtxt
@@ -0,0 +1,40 @@
+op {
+ graph_op_name: "TensorArrayGradWithShape"
+ endpoint {
+ name: "TensorArrayGradWithShape"
+ }
+ in_arg {
+ name: "handle"
+ description: <<END
+The handle to the forward TensorArray.
+END
+ }
+ in_arg {
+ name: "flow_in"
+ description: <<END
+A float scalar that enforces proper chaining of operations.
+END
+ }
+ in_arg {
+ name: "shape_to_prepend"
+ description: <<END
+An int32 vector representing a shape. Elements in the gradient accumulator will
+have a shape that is this shape_to_prepend value concatenated with the shape of
+the elements in the TensorArray corresponding to the input handle.
+END
+ }
+ attr {
+ name: "source"
+ description: <<END
+The gradient source string, used to decide which gradient TensorArray
+to return.
+END
+ }
+ summary: "Creates a TensorArray for storing multiple gradients of values in the given handle."
+ description: <<END
+Similar to TensorArrayGradV3. However, it creates an accumulator with an
+expanded shape compared to the input TensorArray whose gradient is being
+computed. This enables multiple gradients for the same TensorArray to be
+calculated using the same accumulator.
+END
+}
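
As a concrete (hypothetical) reading of `shape_to_prepend`: if each element of
the forward TensorArray has shape `[24, 64]` and `shape_to_prepend` is `[2]`,
each element of the gradient accumulator has shape `[2, 24, 64]`, so two
gradients of the same TensorArray can be accumulated in the same structure.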
diff --git a/tensorflow/core/api_def/python_api/api_def_FeatureStatsDataset.pbtxt b/tensorflow/core/api_def/python_api/api_def_FeatureStatsDataset.pbtxt
new file mode 100644
index 0000000000..7f721f4fb7
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_FeatureStatsDataset.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "FeatureStatsDataset"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_ResourceScatterNdAdd.pbtxt b/tensorflow/core/api_def/python_api/api_def_ResourceScatterNdAdd.pbtxt
new file mode 100644
index 0000000000..ffef3ab522
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_ResourceScatterNdAdd.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "ResourceScatterNdAdd"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_ScatterNdAdd.pbtxt b/tensorflow/core/api_def/python_api/api_def_ScatterNdAdd.pbtxt
new file mode 100644
index 0000000000..f6c8af5c33
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_ScatterNdAdd.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "ScatterNdAdd"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_TensorArrayGradWithShape.pbtxt b/tensorflow/core/api_def/python_api/api_def_TensorArrayGradWithShape.pbtxt
new file mode 100644
index 0000000000..5d76c112a0
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_TensorArrayGradWithShape.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "TensorArrayGradWithShape"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/common_runtime/build_graph_options.cc b/tensorflow/core/common_runtime/build_graph_options.cc
index a9dc6ca6cd..00f7a8e645 100644
--- a/tensorflow/core/common_runtime/build_graph_options.cc
+++ b/tensorflow/core/common_runtime/build_graph_options.cc
@@ -32,6 +32,9 @@ string BuildGraphOptions::DebugString() const {
for (auto& s : callable_options.target()) {
strings::StrAppend(&rv, s, ", ");
}
+ if (collective_graph_key != kNoCollectiveGraphKey) {
+ strings::StrAppend(&rv, "\ncollective_graph_key: ", collective_graph_key);
+ }
return rv;
}
diff --git a/tensorflow/core/common_runtime/build_graph_options.h b/tensorflow/core/common_runtime/build_graph_options.h
index 5ca170e922..3d0f242ea5 100644
--- a/tensorflow/core/common_runtime/build_graph_options.h
+++ b/tensorflow/core/common_runtime/build_graph_options.h
@@ -31,6 +31,9 @@ struct BuildGraphOptions {
// TODO(mrry): Remove this when the distributed runtime supports Arg/Retval.
bool use_function_convention = false;
+ static const int64 kNoCollectiveGraphKey = 0;
+ int64 collective_graph_key = kNoCollectiveGraphKey;
+
string DebugString() const;
};
diff --git a/tensorflow/core/common_runtime/collective_executor_mgr.cc b/tensorflow/core/common_runtime/collective_executor_mgr.cc
index e07829b286..4f03a5e13a 100644
--- a/tensorflow/core/common_runtime/collective_executor_mgr.cc
+++ b/tensorflow/core/common_runtime/collective_executor_mgr.cc
@@ -25,11 +25,11 @@ namespace tensorflow {
CollectiveExecutorMgr::CollectiveExecutorMgr(
const ConfigProto& config, const DeviceMgr* dev_mgr,
- DeviceResolverInterface* dev_resolver,
- ParamResolverInterface* param_resolver)
+ std::unique_ptr<DeviceResolverInterface> dev_resolver,
+ std::unique_ptr<ParamResolverInterface> param_resolver)
: dev_mgr_(dev_mgr),
- dev_resolver_(dev_resolver),
- param_resolver_(param_resolver) {}
+ dev_resolver_(std::move(dev_resolver)),
+ param_resolver_(std::move(param_resolver)) {}
CollectiveExecutorMgr::~CollectiveExecutorMgr() {
for (auto iter : executor_table_) {
@@ -45,9 +45,7 @@ CollectiveExecutor* CollectiveExecutorMgr::FindOrCreate(int64 step_id) {
if (it != executor_table_.end()) {
ce = it->second;
} else {
- CollectiveRemoteAccessLocal* rma = new CollectiveRemoteAccessLocal(
- dev_mgr_, dev_resolver_.get(), step_id);
- ce = new BaseCollectiveExecutor(this, rma, step_id, dev_mgr_);
+ ce = Create(step_id);
executor_table_[step_id] = ce;
}
ce->Ref();
@@ -55,6 +53,12 @@ CollectiveExecutor* CollectiveExecutorMgr::FindOrCreate(int64 step_id) {
return ce;
}
+CollectiveExecutor* CollectiveExecutorMgr::Create(int64 step_id) {
+ CollectiveRemoteAccessLocal* rma =
+ new CollectiveRemoteAccessLocal(dev_mgr_, dev_resolver_.get(), step_id);
+ return new BaseCollectiveExecutor(this, rma, step_id, dev_mgr_);
+}
+
void CollectiveExecutorMgr::Cleanup(int64 step_id) {
CollectiveExecutor* ce = nullptr;
{
diff --git a/tensorflow/core/common_runtime/collective_executor_mgr.h b/tensorflow/core/common_runtime/collective_executor_mgr.h
index 4b42e2b4d1..9de6ab8968 100644
--- a/tensorflow/core/common_runtime/collective_executor_mgr.h
+++ b/tensorflow/core/common_runtime/collective_executor_mgr.h
@@ -25,8 +25,8 @@ class DeviceMgr;
class CollectiveExecutorMgr : public CollectiveExecutorMgrInterface {
public:
CollectiveExecutorMgr(const ConfigProto& config, const DeviceMgr* dev_mgr,
- DeviceResolverInterface* dev_resolver,
- ParamResolverInterface* param_resolver);
+ std::unique_ptr<DeviceResolverInterface> dev_resolver,
+ std::unique_ptr<ParamResolverInterface> param_resolver);
virtual ~CollectiveExecutorMgr();
@@ -56,11 +56,16 @@ class CollectiveExecutorMgr : public CollectiveExecutorMgrInterface {
void RetireStepId(int64 graph_key, int64 step_id) override {}
protected:
+  // Called by FindOrCreate when the table entry does not yet exist.
+ virtual CollectiveExecutor* Create(int64 step_id);
+
const DeviceMgr* dev_mgr_;
std::unique_ptr<DeviceResolverInterface> dev_resolver_;
std::unique_ptr<ParamResolverInterface> param_resolver_;
CollectiveRemoteAccess* remote_access_;
string task_name_;
+
+ private:
mutex exec_mu_;
// Map from step_id to CollectiveExecutor
gtl::FlatMap<int64, CollectiveExecutor*> executor_table_ GUARDED_BY(exec_mu_);
diff --git a/tensorflow/core/common_runtime/collective_executor_mgr_test.cc b/tensorflow/core/common_runtime/collective_executor_mgr_test.cc
index 34c9163d6a..91994c5731 100644
--- a/tensorflow/core/common_runtime/collective_executor_mgr_test.cc
+++ b/tensorflow/core/common_runtime/collective_executor_mgr_test.cc
@@ -40,10 +40,13 @@ class CollectiveExecutorMgrTest : public ::testing::Test {
device_count->insert({"CPU", NUM_DEVS});
TF_CHECK_OK(DeviceFactory::AddDevices(options, task_name, &devices_));
device_mgr_.reset(new DeviceMgr(devices_));
- DeviceResolverLocal* drl = new DeviceResolverLocal(device_mgr_.get());
- cme_.reset(new CollectiveExecutorMgr(
- cp, device_mgr_.get(), drl,
- new CollectiveParamResolverLocal(device_mgr_.get(), drl, task_name)));
+ std::unique_ptr<DeviceResolverInterface> drl(
+ new DeviceResolverLocal(device_mgr_.get()));
+ std::unique_ptr<ParamResolverInterface> prl(
+ new CollectiveParamResolverLocal(device_mgr_.get(), drl.get(),
+ task_name));
+ cme_.reset(new CollectiveExecutorMgr(cp, device_mgr_.get(), std::move(drl),
+ std::move(prl)));
}
std::unique_ptr<CollectiveExecutorMgr> cme_;
diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local.h b/tensorflow/core/common_runtime/collective_param_resolver_local.h
index 3a871f962d..43c404f2ec 100644
--- a/tensorflow/core/common_runtime/collective_param_resolver_local.h
+++ b/tensorflow/core/common_runtime/collective_param_resolver_local.h
@@ -201,7 +201,7 @@ class CollectiveParamResolverLocal : public ParamResolverInterface {
LOCKS_EXCLUDED(irec->out_mu);
const DeviceMgr* dev_mgr_;
- DeviceResolverInterface* dev_resolver_;
+ DeviceResolverInterface* dev_resolver_; // Not owned.
string task_name_;
mutex group_mu_;
gtl::FlatMap<int32, std::unique_ptr<GroupRec>> group_table_
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc
index 07c1eafedc..5cef93c605 100644
--- a/tensorflow/core/common_runtime/direct_session.cc
+++ b/tensorflow/core/common_runtime/direct_session.cc
@@ -450,11 +450,13 @@ Status DirectSession::RunInternal(int64 step_id, const RunOptions& run_options,
// Set up for collectives if the RunOption declares a key.
if (run_options.experimental().collective_graph_key() > 0) {
if (!collective_executor_mgr_) {
- DeviceResolverLocal* drl = new DeviceResolverLocal(device_mgr_.get());
+ std::unique_ptr<DeviceResolverInterface> drl(
+ new DeviceResolverLocal(device_mgr_.get()));
+ std::unique_ptr<ParamResolverInterface> cprl(
+ new CollectiveParamResolverLocal(device_mgr_.get(), drl.get(),
+ "/job:localhost/replica:0/task:0"));
collective_executor_mgr_.reset(new CollectiveExecutorMgr(
- options_.config, device_mgr_.get(), drl,
- new CollectiveParamResolverLocal(device_mgr_.get(), drl,
- "/job:localhost/replica:0/task:0")));
+ options_.config, device_mgr_.get(), std::move(drl), std::move(cprl)));
}
run_state.collective_executor.reset(new CollectiveExecutor::Handle(
collective_executor_mgr_->FindOrCreate(step_id), true /*inherit_ref*/));
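For context, the DirectSession change above only builds the CollectiveExecutorMgr lazily when a run actually declares a collective graph key. A hedged sketch of how a caller might opt in, assuming an already-created Session* named `session` and using the experimental RunOptions field this change reads ("collective_reduce:0" is an illustrative fetch name, not from the commit):

RunOptions run_options;
// Any key > 0 makes RunInternal construct the collective machinery and attach
// a per-step CollectiveExecutor handle; 0 (the default) leaves it disabled.
run_options.mutable_experimental()->set_collective_graph_key(1);

std::vector<Tensor> outputs;
RunMetadata run_metadata;
Status s = session->Run(run_options, {}, {"collective_reduce:0"}, {}, &outputs,
                        &run_metadata);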
diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device.cc b/tensorflow/core/common_runtime/eager/kernel_and_device.cc
index 2a43a31c02..b410ea175b 100644
--- a/tensorflow/core/common_runtime/eager/kernel_and_device.cc
+++ b/tensorflow/core/common_runtime/eager/kernel_and_device.cc
@@ -79,6 +79,7 @@ Status KernelAndDevice::Run(std::vector<Tensor>* input_tensors,
params.function_library = flib_;
params.slice_reader_cache = &slice_reader_cache_;
params.rendezvous = rendez_;
+ params.cancellation_manager = &cm_;
if (stats != nullptr) {
params.track_allocations = true;
}
diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device.h b/tensorflow/core/common_runtime/eager/kernel_and_device.h
index f78d197fd5..c41a0972b1 100644
--- a/tensorflow/core/common_runtime/eager/kernel_and_device.h
+++ b/tensorflow/core/common_runtime/eager/kernel_and_device.h
@@ -22,6 +22,7 @@ limitations under the License.
#include <unordered_map>
#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
@@ -76,6 +77,11 @@ class KernelAndDevice {
const DataTypeVector& output_dtypes() { return output_dtypes_; }
private:
+ // TODO(apassos) Consider a shared cancellation manager. Note that this
+ // cancellation manager is not useful to actually cancel anything, and is
+ // provided here only for the few kernels which can't handle one being
+ // missing.
+ CancellationManager cm_;
std::unique_ptr<OpKernel> kernel_;
Device* device_;
FunctionLibraryRuntime* flib_;
diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc
index 585d777e81..f7f2cdc14f 100644
--- a/tensorflow/core/common_runtime/executor.cc
+++ b/tensorflow/core/common_runtime/executor.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/core/common_runtime/costmodel_manager.h"
+#include "tensorflow/core/common_runtime/executor_factory.h"
#include "tensorflow/core/common_runtime/pending_counts.h"
#include "tensorflow/core/common_runtime/step_stats_collector.h"
#include "tensorflow/core/framework/allocation_description.pb.h"
@@ -2764,4 +2765,30 @@ Status CreateNonCachedKernel(Device* device, FunctionLibraryRuntime* flib,
void DeleteNonCachedKernel(OpKernel* kernel) { delete kernel; }
+namespace {
+
+class DefaultExecutorRegistrar {
+ public:
+ DefaultExecutorRegistrar() {
+ Factory* factory = new Factory;
+ ExecutorFactory::Register("", factory);
+ ExecutorFactory::Register("DEFAULT", factory);
+ }
+
+ private:
+ class Factory : public ExecutorFactory {
+ Status NewExecutor(const LocalExecutorParams& params,
+ std::unique_ptr<const Graph> graph,
+ std::unique_ptr<Executor>* out_executor) override {
+ Executor* ret = nullptr;
+ TF_RETURN_IF_ERROR(NewLocalExecutor(params, std::move(graph), &ret));
+ out_executor->reset(ret);
+ return Status::OK();
+ }
+ };
+};
+static DefaultExecutorRegistrar registrar;
+
+} // namespace
+
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/executor_factory.cc b/tensorflow/core/common_runtime/executor_factory.cc
new file mode 100644
index 0000000000..ee7c7c3a73
--- /dev/null
+++ b/tensorflow/core/common_runtime/executor_factory.cc
@@ -0,0 +1,85 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/common_runtime/executor_factory.h"
+
+#include <unordered_map>
+
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+namespace {
+
+static mutex executor_factory_lock(LINKER_INITIALIZED);
+
+typedef std::unordered_map<string, ExecutorFactory*> ExecutorFactories;
+ExecutorFactories* executor_factories() {
+ static ExecutorFactories* factories = new ExecutorFactories;
+ return factories;
+}
+
+} // namespace
+
+void ExecutorFactory::Register(const string& executor_type,
+ ExecutorFactory* factory) {
+ mutex_lock l(executor_factory_lock);
+ if (!executor_factories()->insert({executor_type, factory}).second) {
+ LOG(FATAL) << "Two executor factories are being registered "
+ << "under " << executor_type;
+ }
+}
+
+namespace {
+const string RegisteredFactoriesErrorMessageLocked()
+ SHARED_LOCKS_REQUIRED(executor_factory_lock) {
+ std::vector<string> factory_types;
+ for (const auto& executor_factory : *executor_factories()) {
+ factory_types.push_back(executor_factory.first);
+ }
+ return strings::StrCat("Registered factories are {",
+ str_util::Join(factory_types, ", "), "}.");
+}
+} // namespace
+
+Status ExecutorFactory::GetFactory(const string& executor_type,
+ ExecutorFactory** out_factory) {
+ tf_shared_lock l(executor_factory_lock);
+
+ auto iter = executor_factories()->find(executor_type);
+ if (iter == executor_factories()->end()) {
+ return errors::NotFound(
+ "No executor factory registered for the given executor type: ",
+ executor_type, " ", RegisteredFactoriesErrorMessageLocked());
+ }
+
+ *out_factory = iter->second;
+ return Status::OK();
+}
+
+Status NewExecutor(const string& executor_type,
+ const LocalExecutorParams& params,
+ std::unique_ptr<const Graph> graph,
+ std::unique_ptr<Executor>* out_executor) {
+ ExecutorFactory* factory = nullptr;
+ TF_RETURN_IF_ERROR(ExecutorFactory::GetFactory(executor_type, &factory));
+ return factory->NewExecutor(params, std::move(graph), out_executor);
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/executor_factory.h b/tensorflow/core/common_runtime/executor_factory.h
new file mode 100644
index 0000000000..f81bb080eb
--- /dev/null
+++ b/tensorflow/core/common_runtime/executor_factory.h
@@ -0,0 +1,51 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EXECUTOR_FACTORY_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_EXECUTOR_FACTORY_H_
+
+#include <string>
+
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+class Executor;
+class Graph;
+struct LocalExecutorParams;
+
+class ExecutorFactory {
+ public:
+ virtual Status NewExecutor(const LocalExecutorParams& params,
+ std::unique_ptr<const Graph> graph,
+ std::unique_ptr<Executor>* out_executor) = 0;
+ virtual ~ExecutorFactory() {}
+
+ static void Register(const string& executor_type, ExecutorFactory* factory);
+ static Status GetFactory(const string& executor_type,
+ ExecutorFactory** out_factory);
+};
+
+Status NewExecutor(const string& executor_type,
+ const LocalExecutorParams& params,
+ std::unique_ptr<const Graph> graph,
+ std::unique_ptr<Executor>* out_executor);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_EXECUTOR_FACTORY_H_
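The new ExecutorFactory registry is keyed by a string and resolved at function-instantiation time via NewExecutor(). A hedged sketch of registering a hypothetical factory, mirroring the registrar idiom used by DefaultExecutorRegistrar in executor.cc above ("TRACING" and the class names are illustrative):

namespace tensorflow {
namespace {

// Delegates to the default local executor; a real factory could wrap it with
// tracing, profiling, or a custom scheduling policy.
class TracingExecutorFactory : public ExecutorFactory {
  Status NewExecutor(const LocalExecutorParams& params,
                     std::unique_ptr<const Graph> graph,
                     std::unique_ptr<Executor>* out_executor) override {
    Executor* raw = nullptr;
    TF_RETURN_IF_ERROR(NewLocalExecutor(params, std::move(graph), &raw));
    out_executor->reset(raw);
    return Status::OK();
  }
};

// Static registration, same idiom as DefaultExecutorRegistrar.
struct TracingExecutorRegistrar {
  TracingExecutorRegistrar() {
    ExecutorFactory::Register("TRACING", new TracingExecutorFactory);
  }
};
static TracingExecutorRegistrar tracing_registrar;

}  // namespace
}  // namespace tensorflow

A function instantiated with InstantiateOptions::executor_type set to "TRACING" would then run through this factory, exactly as the DUMMY factory in function_test.cc below demonstrates.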
diff --git a/tensorflow/core/common_runtime/executor_test.cc b/tensorflow/core/common_runtime/executor_test.cc
index 8cb1567852..7697103faf 100644
--- a/tensorflow/core/common_runtime/executor_test.cc
+++ b/tensorflow/core/common_runtime/executor_test.cc
@@ -464,12 +464,12 @@ BENCHMARK(BM_executor)->ArgPair(1024, 1024);
static void BM_FeedInputFetchOutput(int iters) {
Graph* g = new Graph(OpRegistry::Global());
// z = x + y: x and y are provided as benchmark inputs. z is the
- // output of the benchmark. Conceptually, the caller is "a", the
- // benchmark is "b".
- Node* x = test::graph::Recv(g, "x", "float", "a", 1, "b");
- Node* y = test::graph::Recv(g, "y", "float", "a", 1, "b");
+ // output of the benchmark. Conceptually, the caller is ALICE, the
+ // benchmark is BOB.
+ Node* x = test::graph::Recv(g, "x", "float", ALICE, 1, BOB);
+ Node* y = test::graph::Recv(g, "y", "float", ALICE, 1, BOB);
Node* sum = test::graph::Add(g, x, y);
- Node* z = test::graph::Send(g, sum, "z", "b", 1, "a");
+ Node* z = test::graph::Send(g, sum, "z", BOB, 1, ALICE);
Tensor val(DT_FLOAT, TensorShape({}));
val.scalar<float>()() = 3.14;
#ifdef PLATFORM_GOOGLE
diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc
index 5d9be70522..68d37ddbcd 100644
--- a/tensorflow/core/common_runtime/function.cc
+++ b/tensorflow/core/common_runtime/function.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/executor.h"
+#include "tensorflow/core/common_runtime/executor_factory.h"
#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/common_runtime/memory_types.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
@@ -215,6 +216,7 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
const FunctionLibraryDefinition* overlay_lib = nullptr; // Not owned.
FunctionBody* func_graph = nullptr;
Executor* exec = nullptr;
+ string executor_type;
~Item() {
delete this->func_graph;
@@ -549,6 +551,7 @@ Status FunctionLibraryRuntimeImpl::Instantiate(
item->func_graph = fbody;
item->overlay_lib = options.overlay_lib;
item->instantiation_counter = 1;
+ item->executor_type = options.executor_type;
items_.emplace(next_handle_, std::unique_ptr<Item>(item));
next_handle_++;
}
@@ -623,10 +626,12 @@ void PruneFunctionBody(Graph* g) {
Status FunctionLibraryRuntimeImpl::CreateItem(Handle handle, Item** item) {
const FunctionBody* fbody;
const FunctionLibraryDefinition* lib_def;
+ string executor_type;
{
mutex_lock l(mu_);
fbody = (*item)->func_graph;
lib_def = (*item)->overlay_lib;
+ executor_type = (*item)->executor_type;
}
if (!lib_def) {
lib_def = base_lib_def_;
@@ -656,17 +661,14 @@ Status FunctionLibraryRuntimeImpl::CreateItem(Handle handle, Item** item) {
DeleteNonCachedKernel(kernel);
};
Graph* graph = g.get();
- Executor* exec;
- TF_RETURN_IF_ERROR(NewLocalExecutor(params, std::move(g), &exec));
-
+ std::unique_ptr<Executor> exec;
+ TF_RETURN_IF_ERROR(NewExecutor(executor_type, params, std::move(g), &exec));
{
// Guard item since it is already inserted in items_.
mutex_lock l(mu_);
- if ((*item)->exec) {
- delete exec;
- } else {
+ if ((*item)->exec == nullptr) {
(*item)->graph = graph;
- (*item)->exec = exec;
+ (*item)->exec = exec.release();
}
}
return Status::OK();
diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc
index f4f5198396..1e837e9a7e 100644
--- a/tensorflow/core/common_runtime/function_test.cc
+++ b/tensorflow/core/common_runtime/function_test.cc
@@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/executor.h"
+#include "tensorflow/core/common_runtime/executor_factory.h"
#include "tensorflow/core/common_runtime/function_testlib.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/common_runtime/step_stats_collector.h"
@@ -531,6 +532,69 @@ TEST_F(FunctionLibraryRuntimeTest, StateHandle) {
}
}
+namespace {
+class DummyExecutorRegistrar {
+ public:
+ DummyExecutorRegistrar() {
+ ExecutorFactory::Register("DUMMY", new Factory());
+ }
+
+ private:
+ class Factory : public ExecutorFactory {
+ Status NewExecutor(const LocalExecutorParams& params,
+ std::unique_ptr<const Graph> graph,
+ std::unique_ptr<Executor>* out_executor) override {
+ return errors::Internal("This is a dummy.");
+ }
+ };
+};
+static DummyExecutorRegistrar registrar;
+} // namespace
+
+TEST_F(FunctionLibraryRuntimeTest, ExecutorFactory) {
+ Init({test::function::XTimesTwo()});
+
+ auto x = test::AsTensor<float>({1, 2, 3, 4});
+ Tensor y;
+
+ // Test that the default executor works.
+ {
+ FunctionLibraryRuntime::InstantiateOptions options;
+ options.executor_type = "";
+ TF_CHECK_OK(InstantiateAndRun(flr0_, "XTimesTwo", {{"T", DT_FLOAT}},
+ options, {x}, {&y}));
+ test::ExpectTensorEqual<float>(y, test::AsTensor<float>({2, 4, 6, 8}));
+ }
+
+ // Test the explicit registration for the default executor.
+ {
+ FunctionLibraryRuntime::InstantiateOptions options;
+ options.executor_type = "DEFAULT";
+ TF_CHECK_OK(InstantiateAndRun(flr0_, "XTimesTwo", {{"T", DT_FLOAT}},
+ options, {x}, {&y}));
+ test::ExpectTensorEqual<float>(y, test::AsTensor<float>({2, 4, 6, 8}));
+ }
+
+ // Test that a non-default executor factory can be invoked.
+ {
+ FunctionLibraryRuntime::InstantiateOptions options;
+ options.executor_type = "DUMMY";
+ HasError(InstantiateAndRun(flr0_, "XTimesTwo", {{"T", DT_FLOAT}}, options,
+ {x}, {&y}),
+ "Internal: This is a dummy.");
+ }
+
+ // Test that non-existent executor types trigger an error.
+ {
+ FunctionLibraryRuntime::InstantiateOptions options;
+ options.executor_type = "UNKNOWN_EXECUTOR";
+ HasError(InstantiateAndRun(flr0_, "XTimesTwo", {{"T", DT_FLOAT}}, options,
+ {x}, {&y}),
+ "Not found: No executor factory registered for the given executor "
+ "type: UNKNOWN_EXECUTOR");
+ }
+}
+
TEST_F(FunctionLibraryRuntimeTest, ExpandInlineFunctions) {
Init({test::function::XTimesTwo(), test::function::XTimesFour(),
test::function::XTimes16()});
@@ -803,7 +867,7 @@ TEST_F(FunctionLibraryRuntimeTest, OptimizeGraph) {
Scope s = Scope::NewRootScope();
auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
auto x4_x2_scale = ops::Const<float>(
- s.WithOpName("x4/x2/scale/_12__cf__6")
+ s.WithOpName("x4/x2/scale/_12__cf__10")
.WithDevice("/job:localhost/replica:0/task:0/device:CPU:0"),
2.0f);
auto x4_x2_y = ops::Mul(s.WithOpName("x4/x2/y"), x, x4_x2_scale);
@@ -913,7 +977,7 @@ TEST_F(FunctionLibraryRuntimeTest, Error_NotFound) {
"Not found: Function Foo is not defined.");
}
-TEST_F(FunctionLibraryRuntimeTest, Error_InstantiaionError) {
+TEST_F(FunctionLibraryRuntimeTest, Error_InstantiationError) {
auto bad_x_times_two = FDH::Define(
// Name
"XTimesTwo",
@@ -1009,13 +1073,13 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_XTimesTwo) {
auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
auto func0 = ops::_Arg(s.WithOpName("Func/_0"), DT_FLOAT, 1);
auto scale = ops::Const(
- s.WithOpName("scale/_6__cf__11")
+ s.WithOpName("scale/_6__cf__15")
.WithDevice("/job:localhost/replica:0/task:0/device:CPU:0"),
2.0f);
auto func1_gx = ops::Mul(s.WithOpName("Func/_1/gx"), func0, scale);
auto func1_sx = ops::Shape(s.WithOpName("Func/_1/sx"), x);
auto const0 = ops::Const(
- s.WithOpName("Func/_1/sy/_5__cf__10")
+ s.WithOpName("Func/_1/sy/_5__cf__14")
.WithDevice("/job:localhost/replica:0/task:0/device:CPU:0"),
0, {0});
auto func1_rx = ops::internal::BroadcastGradientArgs(
diff --git a/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc b/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc
index 7de1b80e2d..1f585a8c24 100644
--- a/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc
+++ b/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/common_runtime/executor_factory.h"
#include "tensorflow/core/common_runtime/local_device.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -43,7 +44,7 @@ namespace test {
// TODO(hongm): Convert `g` and `init` to using std::unique_ptr.
Benchmark::Benchmark(const string& device, Graph* g,
const SessionOptions* options, Graph* init,
- Rendezvous* rendez) {
+ Rendezvous* rendez, const char* executor_type) {
SessionOptions default_options;
if (!options) {
options = &default_options;
@@ -86,23 +87,26 @@ Benchmark::Benchmark(const string& device, Graph* g,
};
if (init) {
- Executor* init_exec;
- TF_CHECK_OK(
- NewLocalExecutor(params, std::unique_ptr<Graph>(init), &init_exec));
+ std::unique_ptr<Executor> init_exec;
+ TF_CHECK_OK(NewExecutor(executor_type, params, std::unique_ptr<Graph>(init),
+ &init_exec));
Executor::Args args;
args.rendezvous = rendez_;
args.runner = runner;
TF_CHECK_OK(init_exec->Run(args));
- delete init_exec;
}
- TF_CHECK_OK(NewLocalExecutor(params, std::unique_ptr<Graph>(g), &exec_));
+ TF_CHECK_OK(
+ NewExecutor(executor_type, params, std::unique_ptr<Graph>(g), &exec_));
}
Benchmark::~Benchmark() {
if (device_) {
rendez_->Unref();
- delete exec_;
+ // We delete `exec_` before `device_` because the `exec_` destructor may
+ // run kernel destructors that may attempt to access state borrowed from
+ // `device_`, such as the resource manager.
+ exec_.reset();
delete device_;
delete pool_;
}
diff --git a/tensorflow/core/common_runtime/kernel_benchmark_testlib.h b/tensorflow/core/common_runtime/kernel_benchmark_testlib.h
index 3a7b3a5ace..995a15a299 100644
--- a/tensorflow/core/common_runtime/kernel_benchmark_testlib.h
+++ b/tensorflow/core/common_runtime/kernel_benchmark_testlib.h
@@ -39,7 +39,7 @@ class Benchmark {
// "init", and one reference on "rendez" (if not null).
Benchmark(const string& device, Graph* g,
const SessionOptions* options = nullptr, Graph* init = nullptr,
- Rendezvous* rendez = nullptr);
+ Rendezvous* rendez = nullptr, const char* executor_type = "");
~Benchmark();
// Executes the graph for "iters" times.
@@ -57,7 +57,7 @@ class Benchmark {
thread::ThreadPool* pool_ = nullptr;
Device* device_ = nullptr;
Rendezvous* rendez_ = nullptr;
- Executor* exec_ = nullptr;
+ std::unique_ptr<Executor> exec_;
TF_DISALLOW_COPY_AND_ASSIGN(Benchmark);
};
diff --git a/tensorflow/core/common_runtime/mkl_cpu_allocator.h b/tensorflow/core/common_runtime/mkl_cpu_allocator.h
index 245320c896..29f702699f 100644
--- a/tensorflow/core/common_runtime/mkl_cpu_allocator.h
+++ b/tensorflow/core/common_runtime/mkl_cpu_allocator.h
@@ -29,7 +29,9 @@ limitations under the License.
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/mem.h"
+#ifndef DO_NOT_USE_ML
#include "i_malloc.h"
+#endif
#ifdef _WIN32
typedef unsigned int uint;
@@ -97,14 +99,14 @@ class MklCPUAllocator : public VisitableAllocator {
VLOG(1) << "MklCPUAllocator: Setting max_mem_bytes: " << max_mem_bytes;
allocator_ = new BFCAllocator(new MklSubAllocator, max_mem_bytes,
kAllowGrowth, kName);
-
+#ifndef DO_NOT_USE_ML
// For redirecting all allocations from MKL to this allocator
// From: http://software.intel.com/en-us/node/528565
i_malloc = MallocHook;
i_calloc = CallocHook;
i_realloc = ReallocHook;
i_free = FreeHook;
-
+#endif
return Status::OK();
}
diff --git a/tensorflow/core/debug/debug_grpc_testlib.h b/tensorflow/core/debug/debug_grpc_testlib.h
index 58361bf78f..8d3c9ff575 100644
--- a/tensorflow/core/debug/debug_grpc_testlib.h
+++ b/tensorflow/core/debug/debug_grpc_testlib.h
@@ -19,7 +19,7 @@ limitations under the License.
#include <atomic>
#include <unordered_set>
-#include "grpc++/grpc++.h"
+#include "grpcpp/grpcpp.h"
#include "tensorflow/core/debug/debug_io_utils.h"
#include "tensorflow/core/debug/debug_service.grpc.pb.h"
#include "tensorflow/core/framework/tensor.h"
diff --git a/tensorflow/core/debug/debug_io_utils.cc b/tensorflow/core/debug/debug_io_utils.cc
index 03a011f79e..9e8002d490 100644
--- a/tensorflow/core/debug/debug_io_utils.cc
+++ b/tensorflow/core/debug/debug_io_utils.cc
@@ -23,7 +23,7 @@ limitations under the License.
#include <vector>
#ifndef PLATFORM_WINDOWS
-#include "grpc++/create_channel.h"
+#include "grpcpp/create_channel.h"
#else
// winsock2.h is used in grpc, so Ws2_32.lib is needed
#pragma comment(lib, "Ws2_32.lib")
diff --git a/tensorflow/core/distributed_runtime/BUILD b/tensorflow/core/distributed_runtime/BUILD
index ead698d787..9032823e17 100644
--- a/tensorflow/core/distributed_runtime/BUILD
+++ b/tensorflow/core/distributed_runtime/BUILD
@@ -145,9 +145,11 @@ tf_cc_test(
deps = [
":session_mgr",
":worker_env",
+ "//tensorflow/core:framework",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
"//tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr",
],
)
@@ -226,6 +228,17 @@ tf_cc_test(
],
)
+cc_library(
+ name = "cancellable_call",
+ hdrs = ["cancellable_call.h"],
+ deps = [
+ ":call_options",
+ ":worker_cache",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ ],
+)
+
tf_cc_test(
name = "tensor_coding_test",
size = "small",
@@ -392,6 +405,7 @@ cc_library(
hdrs = ["master_env.h"],
deps = [
":worker_cache",
+ "//tensorflow/core:framework",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:session_options",
],
@@ -453,10 +467,45 @@ cc_library(
)
cc_library(
+ name = "rpc_collective_executor_mgr",
+ srcs = ["rpc_collective_executor_mgr.cc"],
+ hdrs = ["rpc_collective_executor_mgr.h"],
+ deps = [
+ ":base_rendezvous_mgr",
+ ":collective_param_resolver_distributed",
+ ":collective_rma_distributed",
+ ":device_resolver_distributed",
+ ":worker_cache",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:worker_proto_cc",
+ ],
+)
+
+tf_cc_test(
+ name = "rpc_collective_executor_mgr_test",
+ srcs = ["rpc_collective_executor_mgr_test.cc"],
+ deps = [
+ ":collective_param_resolver_distributed",
+ ":device_resolver_distributed",
+ ":rpc_collective_executor_mgr",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:session_options",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+cc_library(
name = "collective_rma_distributed",
srcs = ["collective_rma_distributed.cc"],
hdrs = ["collective_rma_distributed.h"],
deps = [
+ ":cancellable_call",
":worker_cache",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
@@ -492,6 +541,7 @@ cc_library(
hdrs = ["collective_param_resolver_distributed.h"],
deps = [
":call_options",
+ ":cancellable_call",
":device_resolver_distributed",
":worker_cache",
"//tensorflow/core:core_cpu_internal",
diff --git a/tensorflow/core/distributed_runtime/cancellable_call.h b/tensorflow/core/distributed_runtime/cancellable_call.h
new file mode 100644
index 0000000000..05089c7d15
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/cancellable_call.h
@@ -0,0 +1,65 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_CANCELLABLE_CALL_H_
+#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_CANCELLABLE_CALL_H_
+
+#include <string>
+#include "tensorflow/core/distributed_runtime/call_options.h"
+#include "tensorflow/core/distributed_runtime/worker_cache.h"
+#include "tensorflow/core/framework/cancellation.h"
+#include "tensorflow/core/platform/mutex.h"
+
+namespace tensorflow {
+
+// Supports client side cancellation of WorkerInterface calls via
+// registration with a CancellationManager.
+class CancellableCall {
+ public:
+ CancellableCall(CancellationManager* cancel_mgr, const string& remote_worker,
+ WorkerCacheInterface* wc)
+ : cancel_mgr_(cancel_mgr),
+ remote_worker_(remote_worker),
+ wc_(wc),
+ wi_(wc_->CreateWorker(remote_worker_)) {}
+
+ virtual ~CancellableCall() { wc_->ReleaseWorker(remote_worker_, wi_); }
+
+ virtual void IssueCall(const StatusCallback& done) = 0;
+
+ void Start(const StatusCallback& done) {
+ CancellationToken token = cancel_mgr_->get_cancellation_token();
+ const bool not_yet_cancelled = cancel_mgr_->RegisterCallback(
+ token, [this, token]() { opts_.StartCancel(); });
+ if (not_yet_cancelled) {
+ IssueCall([this, token, done](const Status& s) {
+ cancel_mgr_->DeregisterCallback(token);
+ done(s);
+ });
+ } else {
+ done(errors::Cancelled("RPC Request was cancelled"));
+ }
+ }
+
+ protected:
+ mutable mutex mu_;
+ CancellationManager* const cancel_mgr_; // Not owned
+ const string remote_worker_;
+ WorkerCacheInterface* const wc_; // Not owned
+ WorkerInterface* const wi_; // Owned by wc_, must be released.
+ CallOptions opts_;
+};
+
+} // namespace tensorflow
+#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_CANCELLABLE_CALL_H_
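CancellableCall fixes the cancellation bookkeeping (token registration in Start(), opts_.StartCancel() on cancellation, worker release in the destructor) and leaves only the actual RPC to the subclass. A sketch in the spirit of the CompleteGroupCall/RecvBufCall subclasses that this change migrates onto the shared header; the request/response members and the exact WorkerInterface signature are assumptions for illustration:

class CompleteGroupCallSketch : public CancellableCall {
 public:
  CompleteGroupCallSketch(CancellationManager* cancel_mgr,
                          const string& remote_worker,
                          WorkerCacheInterface* wc)
      : CancellableCall(cancel_mgr, remote_worker, wc) {}

  void IssueCall(const StatusCallback& done) override {
    // opts_ is the CallOptions owned by the base class; Start() arranges for
    // opts_.StartCancel() to fire if the CancellationManager is triggered,
    // which aborts this in-flight RPC.
    wi_->CompleteGroupAsync(&opts_, &req_, &resp_, done);
  }

  CompleteGroupRequest req_;
  CompleteGroupResponse resp_;
};

A caller constructs the object, fills req_, and invokes Start(done); deregistration of the cancellation token and release of the worker are handled by the base class.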
diff --git a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc
index 7a93b54eae..612ac14e22 100644
--- a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc
+++ b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc
@@ -14,55 +14,13 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h"
-#include "tensorflow/core/distributed_runtime/call_options.h"
+#include "tensorflow/core/distributed_runtime/cancellable_call.h"
#include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/protobuf/config.pb.h"
-// TODO(tucker): When we're ready to enable collectives this const will
-// transition to a settable config member.
-static const char FLAGS_collective_group_leader[] =
- "/job:worker/replica:0/task:0";
-
namespace tensorflow {
namespace {
-// Supports client side cancellation of WorkerInterface calls via
-// registration with a CancellationManager. Note that ParamResolverInterface
-// calls are done on behalf of an Op execution which needs to abort if the
-// step in which it executes is cancelled.
-class CancellableCall {
- public:
- CancellableCall(CancellationManager* cancel_mgr, const string& remote_worker,
- WorkerCacheInterface* wc)
- : cancel_mgr_(cancel_mgr), remote_worker_(remote_worker), wc_(wc) {
- wi_ = wc_->CreateWorker(remote_worker_);
- }
- virtual ~CancellableCall() { wc_->ReleaseWorker(remote_worker_, wi_); }
-
- virtual void IssueCall(const StatusCallback& done) = 0;
-
- void Start(const StatusCallback& done) {
- CancellationToken token = cancel_mgr_->get_cancellation_token();
- const bool not_yet_cancelled = cancel_mgr_->RegisterCallback(
- token, [this, token]() { opts_.StartCancel(); });
- if (not_yet_cancelled) {
- IssueCall([this, token, done](const Status& s) {
- cancel_mgr_->DeregisterCallback(token);
- done(s);
- });
- } else {
- done(errors::Cancelled("RPC Request was cancelled"));
- }
- }
-
- protected:
- mutable mutex mu_;
- CancellationManager* cancel_mgr_; // Not owned
- const string remote_worker_;
- WorkerCacheInterface* wc_; // Not owned
- WorkerInterface* wi_; // Owned by wc_, must be released.
- CallOptions opts_;
-};
class CompleteGroupCall : public CancellableCall {
public:
@@ -126,9 +84,9 @@ CollectiveParamResolverDistributed::CollectiveParamResolverDistributed(
const string& task_name)
: CollectiveParamResolverLocal(dev_mgr, dev_resolver, task_name),
worker_cache_(worker_cache),
- group_leader_(task_name == FLAGS_collective_group_leader
+ group_leader_(task_name == config.experimental().collective_group_leader()
? ""
- : FLAGS_collective_group_leader) {}
+ : config.experimental().collective_group_leader()) {}
void CollectiveParamResolverDistributed::CompleteParamsAsync(
const string& device, CollectiveParams* cp, CancellationManager* cancel_mgr,
diff --git a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc
index 95a010286d..4eed856759 100644
--- a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc
+++ b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc
@@ -147,10 +147,9 @@ class DeviceResDistTest : public ::testing::Test {
ConfigProto config;
for (int w = 0; w < num_workers; ++w) {
string name = strings::StrCat("/job:worker/replica:0/task:", w);
- // TODO(tucker): When config option becomes available, set here.
- // if (w == 0) {
- // config.set_collective_group_leader(name);
- // }
+ if (w == 0) {
+ config.mutable_experimental()->set_collective_group_leader(name);
+ }
DefineWorker(config, name, device_type, num_devices);
}
}
diff --git a/tensorflow/core/distributed_runtime/collective_rma_distributed.cc b/tensorflow/core/distributed_runtime/collective_rma_distributed.cc
index c15878bfd3..d4c47cab49 100644
--- a/tensorflow/core/distributed_runtime/collective_rma_distributed.cc
+++ b/tensorflow/core/distributed_runtime/collective_rma_distributed.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/process_util.h"
+#include "tensorflow/core/distributed_runtime/cancellable_call.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/platform/protobuf_internal.h"
#include "tensorflow/core/protobuf/transport_options.pb.h"
@@ -28,45 +29,6 @@ namespace tensorflow {
namespace {
-// Supports client side cancellation of WorkerInterface calls via
-// registration with a CancellationManager.
-//
-// TODO(tucker): Maybe unify this with CancellableCall in
-// collective_param_resolver_distributed.cc.
-class CancellableCall {
- public:
- CancellableCall(CancellationManager* cancel_mgr, const string& remote_worker,
- WorkerCacheInterface* wc)
- : cancel_mgr_(cancel_mgr), remote_worker_(remote_worker), wc_(wc) {
- wi_ = wc_->CreateWorker(remote_worker_);
- }
- virtual ~CancellableCall() { wc_->ReleaseWorker(remote_worker_, wi_); }
-
- virtual void IssueCall(const StatusCallback& done) = 0;
-
- void Start(const StatusCallback& done) {
- CancellationToken token = cancel_mgr_->get_cancellation_token();
- const bool not_yet_cancelled = cancel_mgr_->RegisterCallback(
- token, [this, token]() { opts_.StartCancel(); });
- if (not_yet_cancelled) {
- IssueCall([this, token, done](const Status& s) {
- cancel_mgr_->DeregisterCallback(token);
- done(s);
- });
- } else {
- done(errors::Cancelled("RPC Request was cancelled"));
- }
- }
-
- protected:
- mutable mutex mu_;
- CancellationManager* cancel_mgr_; // Not owned
- const string remote_worker_;
- WorkerCacheInterface* wc_; // Not owned
- WorkerInterface* wi_; // Owned by wc_, must be released.
- CallOptions opts_;
-};
-
class RecvBufCall : public CancellableCall {
public:
RecvBufCall(int64 step_id, const string& peer_device, const string& peer_task,
@@ -119,7 +81,7 @@ void CollectiveRemoteAccessDistributed::RecvFromPeer(
};
State* state = new State;
- // Logic to be executed on the RecvBufferAsync callback.
+ // Logic to be executed on the RecvBufAsync callback.
auto recv_buf_callback = [this, state, peer_task, to_device, to_alloc_attr,
to_device_ctx, to_tensor, done](const Status& s) {
if (s.ok()) {
diff --git a/tensorflow/core/distributed_runtime/graph_mgr.cc b/tensorflow/core/distributed_runtime/graph_mgr.cc
index 8447c55bf4..e2f13df19f 100644
--- a/tensorflow/core/distributed_runtime/graph_mgr.cc
+++ b/tensorflow/core/distributed_runtime/graph_mgr.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include <vector>
+#include "tensorflow/core/common_runtime/build_graph_options.h"
#include "tensorflow/core/common_runtime/constant_folding.h"
#include "tensorflow/core/common_runtime/debugger_state_interface.h"
#include "tensorflow/core/common_runtime/device.h"
@@ -30,6 +31,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/step_stats_collector.h"
#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
#include "tensorflow/core/framework/cancellation.h"
+#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
@@ -118,9 +120,11 @@ Status GraphMgr::DecorateAndPublishGraphForDebug(
Status GraphMgr::InitItem(const string& session, const GraphDef& gdef,
const GraphOptions& graph_options,
const DebugOptions& debug_options,
+ int64 collective_graph_key,
DistributedFunctionLibraryRuntime* cluster_flr,
Item* item) {
item->session = session;
+ item->collective_graph_key = collective_graph_key;
item->lib_def.reset(
new FunctionLibraryDefinition(OpRegistry::Global(), gdef.library()));
@@ -280,11 +284,12 @@ Status GraphMgr::InitItem(const string& session, const GraphDef& gdef,
Status GraphMgr::Register(const string& session, const GraphDef& gdef,
const GraphOptions& graph_options,
const DebugOptions& debug_options,
+ int64 collective_graph_key,
DistributedFunctionLibraryRuntime* cluster_flr,
string* handle) {
Item* item = new Item;
- Status s =
- InitItem(session, gdef, graph_options, debug_options, cluster_flr, item);
+ Status s = InitItem(session, gdef, graph_options, debug_options,
+ collective_graph_key, cluster_flr, item);
if (!s.ok()) {
item->Unref();
return s;
@@ -415,7 +420,12 @@ void GraphMgr::ExecuteAsync(const string& handle, const int64 step_id,
RemoteRendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
Status s = rendezvous->Initialize(session);
-
+ CollectiveExecutor::Handle* ce_handle =
+ item->collective_graph_key != BuildGraphOptions::kNoCollectiveGraphKey
+ ? new CollectiveExecutor::Handle(
+ worker_env_->collective_executor_mgr->FindOrCreate(step_id),
+ true)
+ : nullptr;
// Sends values specified by the caller.
if (s.ok()) {
std::vector<string> keys;
@@ -431,22 +441,25 @@ void GraphMgr::ExecuteAsync(const string& handle, const int64 step_id,
if (!s.ok()) {
done(s);
+ delete ce_handle;
item->Unref();
rendezvous->Unref();
return;
}
- StartParallelExecutors(handle, step_id, item, rendezvous, collector,
- cost_graph, cancellation_manager,
- [item, rendezvous, done](const Status& s) {
+ StartParallelExecutors(handle, step_id, item, rendezvous, ce_handle,
+ collector, cost_graph, cancellation_manager,
+ [item, rendezvous, ce_handle, done](const Status& s) {
done(s);
rendezvous->Unref();
item->Unref();
+ delete ce_handle;
});
}
void GraphMgr::StartParallelExecutors(const string& handle, int64 step_id,
Item* item, Rendezvous* rendezvous,
+ CollectiveExecutor::Handle* ce_handle,
StepStatsCollector* collector,
CostGraphDef* cost_graph,
CancellationManager* cancellation_manager,
@@ -471,6 +484,7 @@ void GraphMgr::StartParallelExecutors(const string& handle, int64 step_id,
args.step_id = ++next_id_;
}
args.rendezvous = rendezvous;
+ args.collective_executor = ce_handle ? ce_handle->get() : nullptr;
args.cancellation_manager = cancellation_manager;
args.stats_collector = collector;
args.step_container = step_container;
diff --git a/tensorflow/core/distributed_runtime/graph_mgr.h b/tensorflow/core/distributed_runtime/graph_mgr.h
index cc35264b8f..5196046c19 100644
--- a/tensorflow/core/distributed_runtime/graph_mgr.h
+++ b/tensorflow/core/distributed_runtime/graph_mgr.h
@@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/message_wrappers.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/framework/cancellation.h"
+#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/cost_graph.pb.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/lib/core/refcount.h"
@@ -75,7 +76,7 @@ class GraphMgr {
// reference to cluster_flr to do cross process function calls.
Status Register(const string& session, const GraphDef& gdef,
const GraphOptions& graph_options,
- const DebugOptions& debug_options,
+ const DebugOptions& debug_options, int64 collective_graph_key,
DistributedFunctionLibraryRuntime* cluster_flr,
string* handle);
@@ -138,6 +139,8 @@ class GraphMgr {
// Used to deregister a cost model when cost model is required in graph
// manager.
GraphMgr* graph_mgr;
+
+ int64 collective_graph_key;
};
const WorkerEnv* worker_env_; // Not owned.
@@ -161,6 +164,7 @@ class GraphMgr {
void StartParallelExecutors(const string& handle, int64 step_id, Item* item,
Rendezvous* rendezvous,
+ CollectiveExecutor::Handle* ce_handle,
StepStatsCollector* collector,
CostGraphDef* cost_graph,
CancellationManager* cancellation_manager,
@@ -175,7 +179,7 @@ class GraphMgr {
Status InitItem(const string& session, const GraphDef& gdef,
const GraphOptions& graph_options,
- const DebugOptions& debug_options,
+ const DebugOptions& debug_options, int64 collective_graph_key,
DistributedFunctionLibraryRuntime* cluster_flr, Item* item);
Status DecorateAndPublishGraphForDebug(const DebugOptions& debug_options,
diff --git a/tensorflow/core/distributed_runtime/master_env.h b/tensorflow/core/distributed_runtime/master_env.h
index 16f4d93c8b..da26c42aca 100644
--- a/tensorflow/core/distributed_runtime/master_env.h
+++ b/tensorflow/core/distributed_runtime/master_env.h
@@ -26,6 +26,7 @@ limitations under the License.
namespace tensorflow {
+class CollectiveExecutorMgrInterface;
class Device;
class DeviceSet;
class Env;
@@ -90,6 +91,10 @@ struct MasterEnv {
std::function<Status(const WorkerCacheFactoryOptions&,
WorkerCacheInterface**)>
worker_cache_factory;
+
+ // Generates per-step CollectiveExecutors and has access to utilities
+ // supporting collective operations.
+ CollectiveExecutorMgrInterface* collective_executor_mgr = nullptr;
};
} // end namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc
index e29bb76ddf..d34ca53f73 100644
--- a/tensorflow/core/distributed_runtime/master_session.cc
+++ b/tensorflow/core/distributed_runtime/master_session.cc
@@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/distributed_runtime/worker_interface.h"
#include "tensorflow/core/framework/allocation_description.pb.h"
+#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/cost_graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
@@ -69,6 +70,7 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
bool is_partial, WorkerCacheInterface* worker_cache,
bool should_deregister)
: session_handle_(handle),
+ bg_opts_(bopts),
client_graph_(std::move(cg)),
session_opts_(session_opts),
is_partial_(is_partial),
@@ -100,6 +102,8 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
const CallableOptions& callable_options() { return callable_opts_; }
+ const BuildGraphOptions& build_graph_options() { return bg_opts_; }
+
std::unique_ptr<ProfileHandler> GetProfileHandler(uint64 step,
int64 execution_count,
const RunOptions& ropts) {
@@ -225,6 +229,7 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
private:
const string session_handle_;
+ const BuildGraphOptions bg_opts_;
const std::unique_ptr<ClientGraph> client_graph_;
const SessionOptions session_opts_;
const bool is_partial_;
@@ -444,6 +449,7 @@ Status MasterSession::ReffedClientGraph::DoRegisterPartitions(
*c->req.mutable_graph_options() = session_opts_.config.graph_options();
*c->req.mutable_debug_options() =
callable_opts_.run_options().debug_options();
+ c->req.set_collective_graph_key(bg_opts_.collective_graph_key);
VLOG(2) << "Register " << c->req.graph_def().DebugString();
auto cb = [c, &done](const Status& s) {
c->status = s;
@@ -1065,6 +1071,9 @@ void BuildBuildGraphOptions(const RunStepRequestWrapper& req,
*callable_opts->mutable_run_options()->mutable_debug_options() =
req.options().debug_options();
}
+
+ opts->collective_graph_key =
+ req.options().experimental().collective_graph_key();
}
void BuildBuildGraphOptions(const PartialRunSetupRequest& req,
@@ -1102,6 +1111,10 @@ uint64 HashBuildGraphOptions(const BuildGraphOptions& opts) {
h = Hash64(watch_summary.c_str(), watch_summary.size(), h);
}
+ if (opts.collective_graph_key != BuildGraphOptions::kNoCollectiveGraphKey) {
+ h = Hash64Combine(opts.collective_graph_key, h);
+ }
+
return h;
}
@@ -1118,6 +1131,9 @@ string BuildGraphOptionsString(const BuildGraphOptions& opts) {
for (const string& name : opts.callable_options.fetch()) {
strings::StrAppend(&buf, " FeE: ", name);
}
+ if (opts.collective_graph_key != BuildGraphOptions::kNoCollectiveGraphKey) {
+ strings::StrAppend(&buf, "\nGK: ", opts.collective_graph_key);
+ }
strings::StrAppend(&buf, "\n");
return buf;
}
@@ -1430,11 +1446,35 @@ void MasterSession::ClearRunsTable(std::vector<ReffedClientGraph*>* to_unref,
rcg_map->clear();
}
-namespace {
-uint64 MakeStepId() {
- return (random::New64() & ((1uLL << 56) - 1)) | (1uLL << 56);
+uint64 MasterSession::NewStepId(int64 graph_key) {
+ if (graph_key == BuildGraphOptions::kNoCollectiveGraphKey) {
+ // StepId must leave the most-significant 7 bits empty for future use.
+ return random::New64() & (((1uLL << 56) - 1) | (1uLL << 56));
+ } else {
+ uint64 step_id = env_->collective_executor_mgr->NextStepId(graph_key);
+ int32 retry_count = 0;
+ while (step_id == CollectiveExecutor::kInvalidId) {
+ Notification note;
+ Status status;
+ env_->collective_executor_mgr->RefreshStepIdSequenceAsync(
+ graph_key, [&status, &note](const Status& s) {
+ status = s;
+ note.Notify();
+ });
+ note.WaitForNotification();
+ if (!status.ok()) {
+ LOG(ERROR) << "Bad status from "
+ "collective_executor_mgr->RefreshStepIdSequence: "
+ << status << ". Retrying.";
+ int64 delay_micros = std::min(60000000LL, 1000000LL * ++retry_count);
+ Env::Default()->SleepForMicroseconds(delay_micros);
+ } else {
+ step_id = env_->collective_executor_mgr->NextStepId(graph_key);
+ }
+ }
+ return step_id;
+ }
}
-} // namespace
Status MasterSession::PartialRunSetup(const PartialRunSetupRequest* req,
PartialRunSetupResponse* resp) {
@@ -1456,15 +1496,13 @@ Status MasterSession::PartialRunSetup(const PartialRunSetupRequest* req,
// Prepare.
BuildGraphOptions opts;
BuildBuildGraphOptions(*req, &opts);
- int64 count;
+ int64 count = 0;
TF_RETURN_IF_ERROR(StartStep(opts, true, &rcg, &count));
- // Keeps the highest 8 bits 0x01: we reserve some bits of the
- // step_id for future use.
- const uint64 step_id = MakeStepId();
- TRACEPRINTF("stepid %llu", step_id);
rcg->Ref();
- RunState* run_state = new RunState(inputs, outputs, rcg, step_id, count);
+ RunState* run_state =
+ new RunState(inputs, outputs, rcg,
+ NewStepId(BuildGraphOptions::kNoCollectiveGraphKey), count);
{
mutex_lock l(mu_);
partial_runs_.emplace(
@@ -1566,6 +1604,13 @@ Status MasterSession::DoPartialRun(CallOptions* opts,
}
run_state = it->second.get();
}
+ // CollectiveOps are not supported in partial runs.
+ if (req.options().experimental().collective_graph_key() !=
+ BuildGraphOptions::kNoCollectiveGraphKey) {
+ return errors::InvalidArgument(
+ "PartialRun does not support Collective ops. collective_graph_key "
+ "must be kNoCollectiveGraphKey.");
+ }
// If this is the first partial run, initialize the PerStepState.
if (!run_state->step_started) {
@@ -1743,7 +1788,11 @@ Status MasterSession::PostRunCleanup(MasterSession::ReffedClientGraph* rcg,
Status s = run_status;
if (s.ok()) {
pss->end_micros = Env::Default()->NowMicros();
-
+ if (rcg->build_graph_options().collective_graph_key !=
+ BuildGraphOptions::kNoCollectiveGraphKey) {
+ env_->collective_executor_mgr->RetireStepId(
+ rcg->build_graph_options().collective_graph_key, step_id);
+ }
// Schedule post-processing and cleanup to be done asynchronously.
rcg->ProcessStats(step_id, pss, ph.get(), run_options, out_run_metadata);
} else if (errors::IsCancelled(s)) {
@@ -1801,7 +1850,7 @@ Status MasterSession::DoRunWithLocalExecution(
// Keeps the highest 8 bits 0x01: we reserve some bits of the
// step_id for future use.
- const uint64 step_id = MakeStepId();
+ uint64 step_id = NewStepId(bgopts.collective_graph_key);
TRACEPRINTF("stepid %llu", step_id);
std::unique_ptr<ProfileHandler> ph;
@@ -1865,9 +1914,8 @@ Status MasterSession::DoRunCallable(CallOptions* opts, ReffedClientGraph* rcg,
// Prepare.
int64 count = rcg->get_and_increment_execution_count();
- // Keeps the highest 8 bits 0x01: we reserve some bits of the
- // step_id for future use.
- const uint64 step_id = MakeStepId();
+ const uint64 step_id =
+ NewStepId(rcg->build_graph_options().collective_graph_key);
TRACEPRINTF("stepid %llu", step_id);
const RunOptions& run_options = rcg->callable_options().run_options();
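As a worked note on the NewStepId masking introduced above (a sketch, not code from the commit): for non-collective runs the id is a random draw with the top seven bits cleared, keeping them reserved, while collective runs take ids from the per-graph-key sequence owned by the collective executor manager.

uint64 raw = random::New64();
// ((1uLL << 56) - 1) covers bits 0..55 and | (1uLL << 56) adds bit 56, so the
// mask keeps the 57 low bits and forces bits 57..63 (the most-significant 7)
// to zero.
uint64 mask = ((1uLL << 56) - 1) | (1uLL << 56);
uint64 step_id = raw & mask;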
diff --git a/tensorflow/core/distributed_runtime/master_session.h b/tensorflow/core/distributed_runtime/master_session.h
index ec34e20b79..449a6d3e3c 100644
--- a/tensorflow/core/distributed_runtime/master_session.h
+++ b/tensorflow/core/distributed_runtime/master_session.h
@@ -141,6 +141,8 @@ class MasterSession : public core::RefCounted {
std::atomic<int64> partial_run_handle_counter_ = {0};
+ uint64 NewStepId(int64 graph_key);
+
mutex mu_;
std::unique_ptr<GraphExecutionState> execution_state_ GUARDED_BY(mu_);
int64 graph_version_;
@@ -175,6 +177,7 @@ class MasterSession : public core::RefCounted {
std::unordered_map<string, bool> pending_outputs; // true if fetched
ReffedClientGraph* rcg = nullptr;
uint64 step_id;
+ int64 collective_graph_key;
int64 count = 0;
PerStepState pss;
std::unique_ptr<ProfileHandler> ph;
diff --git a/tensorflow/core/distributed_runtime/master_test.cc b/tensorflow/core/distributed_runtime/master_test.cc
index 0826a90860..62b18a45b1 100644
--- a/tensorflow/core/distributed_runtime/master_test.cc
+++ b/tensorflow/core/distributed_runtime/master_test.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include <map>
#include <memory>
-#include "grpc++/grpc++.h"
+#include "grpcpp/grpcpp.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h"
diff --git a/tensorflow/core/distributed_runtime/rpc/BUILD b/tensorflow/core/distributed_runtime/rpc/BUILD
index 4b2747f26d..2eadfcde54 100644
--- a/tensorflow/core/distributed_runtime/rpc/BUILD
+++ b/tensorflow/core/distributed_runtime/rpc/BUILD
@@ -274,11 +274,14 @@ cc_library(
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
+ "//tensorflow/core/distributed_runtime:collective_param_resolver_distributed",
+ "//tensorflow/core/distributed_runtime:device_resolver_distributed",
"//tensorflow/core/distributed_runtime:graph_mgr",
"//tensorflow/core/distributed_runtime:local_master",
"//tensorflow/core/distributed_runtime:master",
"//tensorflow/core/distributed_runtime:master_env",
"//tensorflow/core/distributed_runtime:master_session",
+ "//tensorflow/core/distributed_runtime:rpc_collective_executor_mgr",
"//tensorflow/core/distributed_runtime:server_lib",
"//tensorflow/core/distributed_runtime:session_mgr",
"//tensorflow/core/distributed_runtime:worker_env",
diff --git a/tensorflow/core/distributed_runtime/rpc/eager/eager_grpc_server_lib.h b/tensorflow/core/distributed_runtime/rpc/eager/eager_grpc_server_lib.h
index f5dc4c831d..9b863ccee5 100644
--- a/tensorflow/core/distributed_runtime/rpc/eager/eager_grpc_server_lib.h
+++ b/tensorflow/core/distributed_runtime/rpc/eager/eager_grpc_server_lib.h
@@ -74,7 +74,7 @@ class EagerGrpcServer : public GrpcServer {
this->eager_service_.reset(
new eager::GrpcEagerServiceImpl(worker_env, server_builder));
},
- nullptr));
+ nullptr, nullptr));
worker_session_ = WorkerSession::CreateWithBorrowedDeviceMgr(
"", worker_name_,
diff --git a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.cc b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.cc
index 4786c43ee2..b23466037f 100644
--- a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.cc
+++ b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.cc
@@ -15,7 +15,7 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.h"
-#include "grpc++/generic/generic_stub.h"
+#include "grpcpp/generic/generic_stub.h"
#include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_state.h"
diff --git a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.cc b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.cc
index 3fd7deaa86..39ab6856c5 100644
--- a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.cc
+++ b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.cc
@@ -15,14 +15,14 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.h"
-#include "grpc++/impl/codegen/async_stream.h"
-#include "grpc++/impl/codegen/async_unary_call.h"
-#include "grpc++/impl/codegen/channel_interface.h"
-#include "grpc++/impl/codegen/client_unary_call.h"
-#include "grpc++/impl/codegen/method_handler_impl.h"
-#include "grpc++/impl/codegen/rpc_service_method.h"
-#include "grpc++/impl/codegen/service_type.h"
-#include "grpc++/impl/codegen/sync_stream.h"
+#include "grpcpp/impl/codegen/async_stream.h"
+#include "grpcpp/impl/codegen/async_unary_call.h"
+#include "grpcpp/impl/codegen/channel_interface.h"
+#include "grpcpp/impl/codegen/client_unary_call.h"
+#include "grpcpp/impl/codegen/method_handler_impl.h"
+#include "grpcpp/impl/codegen/rpc_service_method.h"
+#include "grpcpp/impl/codegen/service_type.h"
+#include "grpcpp/impl/codegen/sync_stream.h"
namespace tensorflow {
namespace eager {
diff --git a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.h b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.h
index d7b192ac85..66458186ad 100644
--- a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.h
+++ b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.h
@@ -16,14 +16,14 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_EAGER_GRPC_EAGER_SERVICE_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_EAGER_GRPC_EAGER_SERVICE_H_
-#include "grpc++/impl/codegen/async_stream.h"
-#include "grpc++/impl/codegen/async_unary_call.h"
-#include "grpc++/impl/codegen/proto_utils.h"
-#include "grpc++/impl/codegen/rpc_method.h"
-#include "grpc++/impl/codegen/service_type.h"
-#include "grpc++/impl/codegen/status.h"
-#include "grpc++/impl/codegen/stub_options.h"
-#include "grpc++/impl/codegen/sync_stream.h"
+#include "grpcpp/impl/codegen/async_stream.h"
+#include "grpcpp/impl/codegen/async_unary_call.h"
+#include "grpcpp/impl/codegen/proto_utils.h"
+#include "grpcpp/impl/codegen/rpc_method.h"
+#include "grpcpp/impl/codegen/service_type.h"
+#include "grpcpp/impl/codegen/status.h"
+#include "grpcpp/impl/codegen/stub_options.h"
+#include "grpcpp/impl/codegen/sync_stream.h"
#include "tensorflow/core/protobuf/eager_service.pb.h"
diff --git a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h
index 65550caf64..e94aedf535 100644
--- a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h
+++ b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h
@@ -16,9 +16,9 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_EAGER_GRPC_EAGER_SERVICE_IMPL_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_EAGER_GRPC_EAGER_SERVICE_IMPL_H_
-#include "grpc++/alarm.h"
-#include "grpc++/completion_queue.h"
-#include "grpc++/server_builder.h"
+#include "grpcpp/alarm.h"
+#include "grpcpp/completion_queue.h"
+#include "grpcpp/server_builder.h"
#include "tensorflow/core/distributed_runtime/eager/eager_service_impl.h"
#include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_call.h"
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_call.h b/tensorflow/core/distributed_runtime/rpc/grpc_call.h
index ecad1274cc..90666def60 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_call.h
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_call.h
@@ -20,9 +20,9 @@ limitations under the License.
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
-#include "grpc++/grpc++.h"
-#include "grpc++/impl/codegen/service_type.h"
-#include "grpc++/server_builder.h"
+#include "grpcpp/grpcpp.h"
+#include "grpcpp/impl/codegen/service_type.h"
+#include "grpcpp/server_builder.h"
namespace tensorflow {
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_channel.cc b/tensorflow/core/distributed_runtime/rpc/grpc_channel.cc
index 613188244f..0ebc084cb6 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_channel.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_channel.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include <map>
#include <unordered_map>
-#include "grpc++/create_channel.h"
+#include "grpcpp/create_channel.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_channel.h b/tensorflow/core/distributed_runtime/rpc/grpc_channel.h
index 48b9d958aa..4861cdb691 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_channel.h
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_channel.h
@@ -22,7 +22,7 @@ limitations under the License.
#include <string>
#include <vector>
-#include "grpc++/grpc++.h"
+#include "grpcpp/grpcpp.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h b/tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h
index d367b83ee7..6e7f5dbd13 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h
@@ -16,7 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CLIENT_CQ_TAG_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CLIENT_CQ_TAG_H_
-#include "grpc++/grpc++.h"
+#include "grpcpp/grpcpp.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/lib/core/status.h"
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc b/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc
index e025e555dd..127dea2882 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc
@@ -30,8 +30,8 @@ limitations under the License.
// RunGraph on workers.
#include "tensorflow/core/distributed_runtime/rpc/grpc_master_service.h"
-#include "grpc++/alarm.h"
-#include "grpc++/server_builder.h"
+#include "grpcpp/alarm.h"
+#include "grpcpp/server_builder.h"
#include "tensorflow/core/distributed_runtime/master.h"
#include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h"
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.cc b/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.cc
index 85adfd2c76..770a0fcf14 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.cc
@@ -15,14 +15,14 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h"
-#include "grpc++/impl/codegen/async_stream.h"
-#include "grpc++/impl/codegen/async_unary_call.h"
-#include "grpc++/impl/codegen/channel_interface.h"
-#include "grpc++/impl/codegen/client_unary_call.h"
-#include "grpc++/impl/codegen/method_handler_impl.h"
-#include "grpc++/impl/codegen/rpc_service_method.h"
-#include "grpc++/impl/codegen/service_type.h"
-#include "grpc++/impl/codegen/sync_stream.h"
+#include "grpcpp/impl/codegen/async_stream.h"
+#include "grpcpp/impl/codegen/async_unary_call.h"
+#include "grpcpp/impl/codegen/channel_interface.h"
+#include "grpcpp/impl/codegen/client_unary_call.h"
+#include "grpcpp/impl/codegen/method_handler_impl.h"
+#include "grpcpp/impl/codegen/rpc_service_method.h"
+#include "grpcpp/impl/codegen/service_type.h"
+#include "grpcpp/impl/codegen/sync_stream.h"
namespace tensorflow {
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h b/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h
index 8f1b589698..751f2633e7 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h
@@ -16,14 +16,14 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_MASTER_SERVICE_IMPL_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_MASTER_SERVICE_IMPL_H_
-#include "grpc++/impl/codegen/async_stream.h"
-#include "grpc++/impl/codegen/async_unary_call.h"
-#include "grpc++/impl/codegen/proto_utils.h"
-#include "grpc++/impl/codegen/rpc_method.h"
-#include "grpc++/impl/codegen/service_type.h"
-#include "grpc++/impl/codegen/status.h"
-#include "grpc++/impl/codegen/stub_options.h"
-#include "grpc++/impl/codegen/sync_stream.h"
+#include "grpcpp/impl/codegen/async_stream.h"
+#include "grpcpp/impl/codegen/async_unary_call.h"
+#include "grpcpp/impl/codegen/proto_utils.h"
+#include "grpcpp/impl/codegen/rpc_method.h"
+#include "grpcpp/impl/codegen/service_type.h"
+#include "grpcpp/impl/codegen/status.h"
+#include "grpcpp/impl/codegen/stub_options.h"
+#include "grpcpp/impl/codegen/sync_stream.h"
#include "tensorflow/core/protobuf/master.pb.h"
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc b/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc
index 1acf1fb4fc..6008462d04 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc
@@ -17,8 +17,8 @@ limitations under the License.
#include <utility>
-#include "grpc++/generic/generic_stub.h"
-#include "grpc++/grpc++.h"
+#include "grpcpp/generic/generic_stub.h"
+#include "grpcpp/grpcpp.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h"
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
index e5ffb4ed2f..43dbe20836 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
@@ -19,14 +19,16 @@ limitations under the License.
#include <limits>
#include <memory>
-#include "grpc++/grpc++.h"
-#include "grpc++/security/credentials.h"
-#include "grpc++/server_builder.h"
#include "grpc/support/alloc.h"
+#include "grpcpp/grpcpp.h"
+#include "grpcpp/security/credentials.h"
+#include "grpcpp/server_builder.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/process_util.h"
+#include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h"
+#include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
#include "tensorflow/core/distributed_runtime/graph_mgr.h"
#include "tensorflow/core/distributed_runtime/local_master.h"
#include "tensorflow/core/distributed_runtime/master.h"
@@ -38,6 +40,7 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h"
#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
+#include "tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h"
#include "tensorflow/core/distributed_runtime/server_lib.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/framework/op.h"
@@ -106,6 +109,7 @@ GrpcServer::~GrpcServer() {
Status GrpcServer::Init(
ServiceInitFunction service_func,
const RendezvousMgrCreationFunction& rendezvous_mgr_func,
+ const CollectiveMgrCreationFunction& collective_mgr_func,
const WorkerCreationFunction& worker_func,
const StatsPublisherFactory& stats_factory) {
mutex_lock l(mu_);
@@ -204,6 +208,26 @@ Status GrpcServer::Init(
WorkerCacheFactory(worker_cache_factory_options, &worker_cache));
CHECK_NE(nullptr, worker_cache);
+ if (collective_mgr_func) {
+ worker_env_.collective_executor_mgr =
+ collective_mgr_func(config, &worker_env_, worker_cache);
+ if (!worker_env_.collective_executor_mgr) {
+ return errors::Internal(
+ "collective_mgr_func did not return CollectiveExecutorMgr");
+ }
+ } else {
+ std::unique_ptr<DeviceResolverDistributed> dev_resolver(
+ new DeviceResolverDistributed(worker_env_.device_mgr, worker_cache,
+ default_worker_name));
+ std::unique_ptr<CollectiveParamResolverDistributed> param_resolver(
+ new CollectiveParamResolverDistributed(config, worker_env_.device_mgr,
+ dev_resolver.get(), worker_cache,
+ default_worker_name));
+ worker_env_.collective_executor_mgr = new RpcCollectiveExecutorMgr(
+ config, worker_env_.device_mgr, std::move(dev_resolver),
+ std::move(param_resolver), worker_cache, default_worker_name);
+ }
+
// Set up worker environment.
worker_env_.session_mgr = new SessionMgr(
&worker_env_, SessionMgr::WorkerNameFromServerDef(server_def_),
@@ -246,18 +270,21 @@ Status GrpcServer::Init(
Status GrpcServer::Init(
ServiceInitFunction service_func,
const RendezvousMgrCreationFunction& rendezvous_mgr_func,
+ const CollectiveMgrCreationFunction& collective_mgr_func,
const WorkerCreationFunction& worker_func) {
- return Init(std::move(service_func), rendezvous_mgr_func, worker_func,
- CreateNoOpStatsPublisher);
+ return Init(std::move(service_func), rendezvous_mgr_func, collective_mgr_func,
+ worker_func, CreateNoOpStatsPublisher);
}
Status GrpcServer::Init(
ServiceInitFunction service_func,
- const RendezvousMgrCreationFunction& rendezvous_mgr_func) {
- return Init(service_func, rendezvous_mgr_func, nullptr);
+ const RendezvousMgrCreationFunction& rendezvous_mgr_func,
+ const CollectiveMgrCreationFunction& collective_mgr_func) {
+ return Init(std::move(service_func), rendezvous_mgr_func, collective_mgr_func,
+ nullptr);
}
-Status GrpcServer::Init() { return Init(nullptr, nullptr, nullptr); }
+Status GrpcServer::Init() { return Init(nullptr, nullptr, nullptr, nullptr); }
Status GrpcServer::ParseChannelSpec(const WorkerCacheFactoryOptions& options,
GrpcChannelSpec* channel_spec) {
@@ -403,7 +430,7 @@ Status GrpcServer::Create(const ServerDef& server_def, Env* env,
std::unique_ptr<GrpcServer> ret(
new GrpcServer(server_def, env == nullptr ? Env::Default() : env));
ServiceInitFunction service_func = nullptr;
- TF_RETURN_IF_ERROR(ret->Init(service_func, NewRpcRendezvousMgr));
+ TF_RETURN_IF_ERROR(ret->Init(service_func, NewRpcRendezvousMgr, nullptr));
*out_server = std::move(ret);
return Status::OK();
}
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h
index 0122df178a..ca9946cafc 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h
@@ -18,8 +18,8 @@ limitations under the License.
#include <memory>
-#include "grpc++/grpc++.h"
-#include "grpc++/security/credentials.h"
+#include "grpcpp/grpcpp.h"
+#include "grpcpp/security/credentials.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/common_runtime/stats_publisher_interface.h"
@@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/server_lib.h"
#include "tensorflow/core/distributed_runtime/session_mgr.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
+#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/platform/env.h"
@@ -41,6 +42,11 @@ class Master;
typedef std::function<RendezvousMgrInterface*(const WorkerEnv*)>
RendezvousMgrCreationFunction;
+// function that creates a CollectiveExecutorMgr.
+typedef std::function<CollectiveExecutorMgrInterface*(
+ const ConfigProto&, const WorkerEnv*, WorkerCacheInterface*)>
+ CollectiveMgrCreationFunction;
+
// function that registers a service to the server. The service needs to
// be registered before builder.BuildAndStart().
typedef std::function<void(const WorkerEnv*, ::grpc::ServerBuilder*)>
@@ -71,15 +77,18 @@ class GrpcServer : public ServerInterface {
protected:
Status Init(ServiceInitFunction service_func,
const RendezvousMgrCreationFunction& rendezvous_mgr_func,
+ const CollectiveMgrCreationFunction& collective_mgr_func,
const WorkerCreationFunction& worker_func,
const StatsPublisherFactory& stats_factory);
Status Init(ServiceInitFunction service_func,
const RendezvousMgrCreationFunction& rendezvous_mgr_func,
+ const CollectiveMgrCreationFunction& collective_mgr_func,
const WorkerCreationFunction& worker_func);
Status Init(ServiceInitFunction service_func,
- const RendezvousMgrCreationFunction& rendezvous_mgr_func);
+ const RendezvousMgrCreationFunction& rendezvous_mgr_func,
+ const CollectiveMgrCreationFunction& collective_mgr_func);
Status Init();
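
As an illustrative sketch (not part of this patch), a caller could supply the new CollectiveMgrCreationFunction like below; MyCollectiveMgr is a hypothetical CollectiveExecutorMgrInterface implementation. When the function is left null, Init() instead builds the default RpcCollectiveExecutorMgr as shown in grpc_server_lib.cc above.

// Hypothetical example; MyCollectiveMgr is assumed to implement
// CollectiveExecutorMgrInterface.
CollectiveMgrCreationFunction collective_mgr_func =
    [](const ConfigProto& config, const WorkerEnv* env,
       WorkerCacheInterface* worker_cache) -> CollectiveExecutorMgrInterface* {
  // Init() treats a null result as an internal error, so a non-null manager
  // must be returned whenever the function itself is provided.
  return new MyCollectiveMgr(config, env->device_mgr, worker_cache);
};
// A subclass would then pass it through, e.g.:
//   Init(service_func, rendezvous_mgr_func, collective_mgr_func, worker_func);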
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_state.h b/tensorflow/core/distributed_runtime/rpc/grpc_state.h
index 59dbb7ae04..61c5bc285f 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_state.h
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_state.h
@@ -18,8 +18,8 @@ limitations under the License.
#include <utility>
-#include "grpc++/generic/generic_stub.h"
-#include "grpc++/grpc++.h"
+#include "grpcpp/generic/generic_stub.h"
+#include "grpcpp/grpcpp.h"
#include "tensorflow/core/distributed_runtime/call_options.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h"
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding.cc b/tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding.cc
index e51894b4c7..d0684f1833 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding.cc
@@ -14,8 +14,8 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding.h"
-#include "grpc++/support/byte_buffer.h"
-#include "grpc++/support/slice.h"
+#include "grpcpp/support/byte_buffer.h"
+#include "grpcpp/support/slice.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding_test.cc b/tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding_test.cc
index 71f69e9024..7cace573e8 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding_test.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding_test.cc
@@ -15,8 +15,8 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding.h"
-#include "grpc++/support/byte_buffer.h"
-#include "grpc++/support/slice.h"
+#include "grpcpp/support/byte_buffer.h"
+#include "grpcpp/support/slice.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server.cc b/tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server.cc
index f247322bc4..e52b257411 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server.cc
@@ -16,9 +16,9 @@ limitations under the License.
#include <iostream>
#include <vector>
-#include "grpc++/grpc++.h"
-#include "grpc++/security/credentials.h"
-#include "grpc++/server_builder.h"
+#include "grpcpp/grpcpp.h"
+#include "grpcpp/security/credentials.h"
+#include "grpcpp/server_builder.h"
#include "tensorflow/core/distributed_runtime/server_lib.h"
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_testlib_server.cc b/tensorflow/core/distributed_runtime/rpc/grpc_testlib_server.cc
index e718db251c..33cbadda0a 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_testlib_server.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_testlib_server.cc
@@ -15,9 +15,9 @@ limitations under the License.
#include <vector>
-#include "grpc++/grpc++.h"
-#include "grpc++/security/credentials.h"
-#include "grpc++/server_builder.h"
+#include "grpcpp/grpcpp.h"
+#include "grpcpp/security/credentials.h"
+#include "grpcpp/server_builder.h"
#include "tensorflow/core/distributed_runtime/server_lib.h"
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_util.h b/tensorflow/core/distributed_runtime/rpc/grpc_util.h
index 4b58781b54..45259aa2ec 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_util.h
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_util.h
@@ -18,9 +18,9 @@ limitations under the License.
#include <memory>
-#include "grpc++/grpc++.h"
-#include "grpc++/impl/codegen/proto_utils.h"
-#include "grpc++/support/byte_buffer.h"
+#include "grpcpp/grpcpp.h"
+#include "grpcpp/impl/codegen/proto_utils.h"
+#include "grpcpp/support/byte_buffer.h"
#include "tensorflow/core/distributed_runtime/tensor_coding.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc
index aa9304a033..61f5369617 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc
@@ -17,8 +17,8 @@ limitations under the License.
#include <deque>
-#include "grpc++/alarm.h"
-#include "grpc++/server_builder.h"
+#include "grpcpp/alarm.h"
+#include "grpcpp/server_builder.h"
#include "tensorflow/core/common_runtime/buf_rendezvous.h"
#include "tensorflow/core/common_runtime/device.h"
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.cc b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.cc
index 38cc2b81d3..72b5e77f1c 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.cc
@@ -15,14 +15,14 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h"
-#include "grpc++/impl/codegen/async_stream.h"
-#include "grpc++/impl/codegen/async_unary_call.h"
-#include "grpc++/impl/codegen/channel_interface.h"
-#include "grpc++/impl/codegen/client_unary_call.h"
-#include "grpc++/impl/codegen/method_handler_impl.h"
-#include "grpc++/impl/codegen/rpc_service_method.h"
-#include "grpc++/impl/codegen/service_type.h"
-#include "grpc++/impl/codegen/sync_stream.h"
+#include "grpcpp/impl/codegen/async_stream.h"
+#include "grpcpp/impl/codegen/async_unary_call.h"
+#include "grpcpp/impl/codegen/channel_interface.h"
+#include "grpcpp/impl/codegen/client_unary_call.h"
+#include "grpcpp/impl/codegen/method_handler_impl.h"
+#include "grpcpp/impl/codegen/rpc_service_method.h"
+#include "grpcpp/impl/codegen/service_type.h"
+#include "grpcpp/impl/codegen/sync_stream.h"
namespace tensorflow {
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h
index da270835bd..7915c3aafd 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h
@@ -16,15 +16,15 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_SERVICE_IMPL_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_SERVICE_IMPL_H_
-#include "grpc++/impl/codegen/async_stream.h"
-#include "grpc++/impl/codegen/async_unary_call.h"
-#include "grpc++/impl/codegen/proto_utils.h"
-#include "grpc++/impl/codegen/rpc_method.h"
-#include "grpc++/impl/codegen/service_type.h"
-#include "grpc++/impl/codegen/status.h"
-#include "grpc++/impl/codegen/stub_options.h"
-#include "grpc++/impl/codegen/sync_stream.h"
-#include "grpc++/support/byte_buffer.h"
+#include "grpcpp/impl/codegen/async_stream.h"
+#include "grpcpp/impl/codegen/async_unary_call.h"
+#include "grpcpp/impl/codegen/proto_utils.h"
+#include "grpcpp/impl/codegen/rpc_method.h"
+#include "grpcpp/impl/codegen/service_type.h"
+#include "grpcpp/impl/codegen/status.h"
+#include "grpcpp/impl/codegen/stub_options.h"
+#include "grpcpp/impl/codegen/sync_stream.h"
+#include "grpcpp/support/byte_buffer.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/distributed_runtime/tensor_coding.h"
diff --git a/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.cc b/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.cc
new file mode 100644
index 0000000000..5eeed6e382
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.cc
@@ -0,0 +1,142 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h"
+
+#include "tensorflow/core/common_runtime/base_collective_executor.h"
+#include "tensorflow/core/common_runtime/collective_executor_mgr.h"
+#include "tensorflow/core/common_runtime/collective_rma_local.h"
+#include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h"
+#include "tensorflow/core/distributed_runtime/collective_rma_distributed.h"
+#include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
+#include "tensorflow/core/distributed_runtime/worker_cache.h"
+#include "tensorflow/core/lib/random/random.h"
+
+namespace tensorflow {
+
+RpcCollectiveExecutorMgr::RpcCollectiveExecutorMgr(
+ const ConfigProto& config, const DeviceMgr* dev_mgr,
+ std::unique_ptr<DeviceResolverDistributed> dev_resolver,
+ std::unique_ptr<CollectiveParamResolverDistributed> param_resolver,
+ WorkerCacheInterface* worker_cache, const string& task_name)
+ : CollectiveExecutorMgr(config, dev_mgr, std::move(dev_resolver),
+ std::move(param_resolver)),
+ worker_cache_(worker_cache),
+ task_name_(task_name) {
+ group_leader_ = (task_name == config.experimental().collective_group_leader())
+ ? ""
+ : config.experimental().collective_group_leader();
+}
+
+RpcCollectiveExecutorMgr::~RpcCollectiveExecutorMgr() {
+ for (auto it : sequence_table_) {
+ delete it.second;
+ }
+}
+
+CollectiveExecutor* RpcCollectiveExecutorMgr::Create(int64 step_id) {
+ CollectiveRemoteAccessDistributed* rma =
+ new CollectiveRemoteAccessDistributed(dev_mgr_, dev_resolver_.get(),
+ worker_cache_, step_id);
+ return new BaseCollectiveExecutor(this, rma, step_id, dev_mgr_);
+}
+
+namespace {
+// StepId must leave the most-significant 7 bits empty for future use.
+static const int64 kStepIdMask = (((1uLL << 56) - 1) | (1uLL << 56));
+
+int64 NewRandomStepId() {
+ int64 step_id = random::New64();
+ // Leave the most-significant 7 bits clear for future use.
+ step_id &= kStepIdMask;
+ return step_id;
+}
+} // namespace
+
+void RpcCollectiveExecutorMgr::RefreshStepIdSequenceAsync(
+ int64 graph_key, const StatusCallback& done) {
+ if (group_leader_.empty()) {
+ mutex_lock l(sequence_mu_);
+ GraphKeySequence* gks = nullptr;
+ auto it = sequence_table_.find(graph_key);
+ if (it == sequence_table_.end()) {
+ gks = new GraphKeySequence(graph_key);
+ sequence_table_[graph_key] = gks;
+ } else {
+ gks = it->second;
+ }
+ gks->next_step_id_ = NewRandomStepId();
+ done(Status::OK());
+ } else {
+ WorkerInterface* wi = worker_cache_->CreateWorker(group_leader_);
+ GetStepSequenceRequest* req = new GetStepSequenceRequest;
+ GetStepSequenceResponse* resp = new GetStepSequenceResponse;
+ req->add_graph_key(graph_key);
+ wi->GetStepSequenceAsync(
+ req, resp, [this, req, resp, done](const Status& s) {
+ if (!s.ok()) {
+ LOG(ERROR) << "Bad response [" << s
+ << "] from GetStepSequenceAsync call to "
+ << group_leader_;
+ done(s);
+ } else {
+ done(UpdateStepSequences(*resp));
+ }
+ delete req;
+ delete resp;
+ });
+ }
+}
+
+Status RpcCollectiveExecutorMgr::UpdateStepSequences(
+ const GetStepSequenceResponse& resp) {
+ mutex_lock l(sequence_mu_);
+ for (const StepSequence& ss : resp.step_sequence()) {
+ GraphKeySequence* gks = nullptr;
+ auto it = sequence_table_.find(ss.graph_key());
+ if (it == sequence_table_.end()) {
+ gks = new GraphKeySequence(ss.graph_key());
+ sequence_table_[ss.graph_key()] = gks;
+ } else {
+ gks = it->second;
+ }
+ gks->next_step_id_ = ss.next_step_id();
+ }
+ return Status::OK();
+}
+
+int64 RpcCollectiveExecutorMgr::NextStepId(int64 graph_key) {
+ mutex_lock l(sequence_mu_);
+ auto it = sequence_table_.find(graph_key);
+ if (it != sequence_table_.end()) {
+ return it->second->next_step_id_;
+ }
+ return CollectiveExecutor::kInvalidId;
+}
+
+void RpcCollectiveExecutorMgr::RetireStepId(int64 graph_key, int64 step_id) {
+ mutex_lock l(sequence_mu_);
+ auto it = sequence_table_.find(graph_key);
+ if (it != sequence_table_.end()) {
+ if (step_id == it->second->next_step_id_) {
+ it->second->next_step_id_ = (it->second->next_step_id_ + 1) & kStepIdMask;
+ } else {
+ it->second->next_step_id_ = CollectiveExecutor::kInvalidId;
+ }
+ } else {
+ LOG(ERROR) << "Failed to find graph_key " << graph_key << " to retire.";
+ }
+}
+
+} // namespace tensorflow
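
A standalone illustration (not part of the patch) of the step-id masking used above: kStepIdMask keeps bits 0..56, so NewRandomStepId() always clears the top 7 bits and RetireStepId() advances within that range, wrapping at the mask boundary.

#include <cassert>
#include <cstdint>

int main() {
  const uint64_t kStepIdMask = ((1ull << 56) - 1) | (1ull << 56);
  uint64_t raw = 0xFFFFFFFFFFFFFFFFull;         // worst-case 64-bit random value
  uint64_t step_id = raw & kStepIdMask;         // top 7 bits are cleared
  assert((step_id >> 57) == 0);
  uint64_t next = (step_id + 1) & kStepIdMask;  // successor wraps at the boundary
  assert(next == 0);
  return 0;
}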
diff --git a/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h b/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h
new file mode 100644
index 0000000000..e9f3f0ebe8
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h
@@ -0,0 +1,79 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_COLLECTIVE_EXECUTOR_MGR_H_
+#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_COLLECTIVE_EXECUTOR_MGR_H_
+
+#include "tensorflow/core/common_runtime/collective_executor_mgr.h"
+#include "tensorflow/core/framework/collective.h"
+
+namespace tensorflow {
+class CollectiveParamResolverDistributed;
+class ConfigProto;
+class DeviceMgr;
+class DeviceResolverDistributed;
+class WorkerCacheInterface;
+class StepSequenceRequest;
+class StepSequenceResponse;
+
+// An implementation of CollectiveExecutorMgr for a distributed environment
+// that uses WorkerInterface::RecvBufAsync to route data transfers over RPCs.
+//
+// In some execution environments it may be possible to implement a
+// higher-performance solution and use it in place of this class.
+class RpcCollectiveExecutorMgr : public CollectiveExecutorMgr {
+ public:
+ RpcCollectiveExecutorMgr(
+ const ConfigProto& config, const DeviceMgr* dev_mgr,
+ std::unique_ptr<DeviceResolverDistributed> dev_resolver,
+ std::unique_ptr<CollectiveParamResolverDistributed> param_resolver,
+ WorkerCacheInterface* worker_cache, const string& task_name);
+
+ virtual ~RpcCollectiveExecutorMgr();
+
+ void RefreshStepIdSequenceAsync(int64 graph_key,
+ const StatusCallback& done) override;
+
+ int64 NextStepId(int64 graph_key) override;
+
+ void RetireStepId(int64 graph_key, int64 step_id) override;
+
+ protected:
+ CollectiveExecutor* Create(int64 step_id) override;
+
+ WorkerCacheInterface* const worker_cache_; // Not owned.
+ const string task_name_;
+ string group_leader_;
+ friend class RpcCollectiveExecutorMgrTest;
+
+ private:
+ Status UpdateStepSequences(const GetStepSequenceResponse& resp);
+
+ // This class maintains the step_id sequencing for a single
+ // collective_graph_key.
+ struct GraphKeySequence {
+ explicit GraphKeySequence(int64 k)
+ : graph_key_(k), next_step_id_(CollectiveExecutor::kInvalidId) {}
+
+ const int64 graph_key_;
+ int64 next_step_id_;
+ };
+
+ mutex sequence_mu_;
+ gtl::FlatMap<int64, GraphKeySequence*> sequence_table_
+ GUARDED_BY(sequence_mu_);
+};
+
+} // namespace tensorflow
+#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_COLLECTIVE_EXECUTOR_MGR_H_
diff --git a/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr_test.cc b/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr_test.cc
new file mode 100644
index 0000000000..37b83d82be
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr_test.cc
@@ -0,0 +1,124 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <stdlib.h>
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h"
+#include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
+#include "tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h"
+#include "tensorflow/core/framework/collective.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/public/session_options.h"
+
+namespace tensorflow {
+#define NUM_DEVS 3
+
+class RpcCollectiveExecutorMgrTest : public ::testing::Test {
+ protected:
+ RpcCollectiveExecutorMgrTest() {
+ string task_name = "/job:localhost/replica:0/task:0";
+ SessionOptions options;
+ options.config.mutable_experimental()->set_collective_group_leader(
+ task_name);
+ WorkerCacheInterface* worker_cache = nullptr;
+ auto* device_count = options.config.mutable_device_count();
+ device_count->insert({"CPU", NUM_DEVS});
+ TF_CHECK_OK(DeviceFactory::AddDevices(options, task_name, &devices_));
+ device_mgr_.reset(new DeviceMgr(devices_));
+ std::unique_ptr<DeviceResolverDistributed> dr(new DeviceResolverDistributed(
+ device_mgr_.get(), worker_cache, task_name));
+ std::unique_ptr<CollectiveParamResolverDistributed> cpr(
+ new CollectiveParamResolverDistributed(options.config,
+ device_mgr_.get(), dr.get(),
+ worker_cache, task_name));
+ // This CME is the group leader.
+ cme_.reset(new RpcCollectiveExecutorMgr(options.config, device_mgr_.get(),
+ std::move(dr), std::move(cpr),
+ worker_cache, task_name));
+ }
+
+ std::unique_ptr<RpcCollectiveExecutorMgr> cme_;
+ std::vector<Device*> devices_;
+ std::unique_ptr<DeviceMgr> device_mgr_;
+};
+
+TEST_F(RpcCollectiveExecutorMgrTest, FindOrCreate) {
+ CollectiveExecutor::Handle* h =
+ new CollectiveExecutor::Handle(cme_->FindOrCreate(1), true);
+ EXPECT_TRUE(h->get());
+ CollectiveExecutor::Handle* h2 =
+ new CollectiveExecutor::Handle(cme_->FindOrCreate(1), true);
+ EXPECT_EQ(h->get(), h2->get());
+ CollectiveExecutor* ce = h->get();
+ delete h;
+ delete h2;
+ CollectiveExecutor* ce2 = cme_->FindOrCreate(1);
+ EXPECT_EQ(ce, ce2);
+ ce2->Unref();
+ cme_->Cleanup(1);
+}
+
+TEST_F(RpcCollectiveExecutorMgrTest, NextStepId) {
+ int64 x = cme_->NextStepId(7);
+ EXPECT_EQ(x, CollectiveExecutor::kInvalidId);
+ // Calling Refresh should generate a valid id.
+ {
+ Notification note;
+ Status status;
+ cme_->RefreshStepIdSequenceAsync(7,
+ [this, &status, &note](const Status& s) {
+ status = s;
+ note.Notify();
+ });
+ EXPECT_TRUE(status.ok());
+ }
+ x = cme_->NextStepId(7);
+ EXPECT_NE(x, CollectiveExecutor::kInvalidId);
+ // Should keep returning same number.
+ EXPECT_EQ(x, cme_->NextStepId(7));
+ EXPECT_EQ(x, cme_->NextStepId(7));
+ // Retire on a different graph_key should have no effect.
+ cme_->RetireStepId(6, x);
+ EXPECT_EQ(x, cme_->NextStepId(7));
+ // Retire on same graph_key should advance.
+ cme_->RetireStepId(7, x);
+ int64 y = cme_->NextStepId(7);
+ EXPECT_EQ((x + 1) & (((1uLL << 56) - 1) | (1uLL << 56)), y);
+ // Calling refresh should jump to a different point in the random space.
+ {
+ Notification note;
+ Status status;
+ cme_->RefreshStepIdSequenceAsync(7,
+ [this, &status, &note](const Status& s) {
+ status = s;
+ note.Notify();
+ });
+
+ note.WaitForNotification();
+ EXPECT_TRUE(status.ok());
+ }
+ int64 z = cme_->NextStepId(7);
+ // z should not be equal to or a successor of y.
+ EXPECT_NE(y, z);
+ EXPECT_GT(llabs(y - z), 3);
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/worker.cc b/tensorflow/core/distributed_runtime/worker.cc
index 4e6500fbc6..1ea19c48f0 100644
--- a/tensorflow/core/distributed_runtime/worker.cc
+++ b/tensorflow/core/distributed_runtime/worker.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/collective_executor_mgr.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/process_util.h"
+#include "tensorflow/core/common_runtime/scoped_allocator_mgr.h"
#include "tensorflow/core/common_runtime/step_stats_collector.h"
#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
#include "tensorflow/core/distributed_runtime/tensor_coding.h"
@@ -72,7 +73,8 @@ void Worker::RegisterGraphAsync(const RegisterGraphRequest* request,
s = session->graph_mgr->Register(
request->session_handle(), request->graph_def(),
request->graph_options(), request->debug_options(),
- session->cluster_flr.get(), response->mutable_graph_handle());
+ request->collective_graph_key(), session->cluster_flr.get(),
+ response->mutable_graph_handle());
}
done(s);
}
@@ -315,6 +317,12 @@ void Worker::CleanupGraphAsync(const CleanupGraphRequest* request,
if (env_->collective_executor_mgr) {
env_->collective_executor_mgr->Cleanup(step_id);
}
+ for (Device* d : env_->local_devices) {
+ ScopedAllocatorMgr* sam = d->GetScopedAllocatorMgr();
+ if (sam) {
+ sam->Cleanup(step_id);
+ }
+ }
done(Status::OK());
}
diff --git a/tensorflow/core/framework/cost_graph.proto b/tensorflow/core/framework/cost_graph.proto
index 19d765cd32..cc6bc84d69 100644
--- a/tensorflow/core/framework/cost_graph.proto
+++ b/tensorflow/core/framework/cost_graph.proto
@@ -69,6 +69,9 @@ message CostGraphDef {
// Ids of the control inputs for this node.
repeated int32 control_input = 8;
+
+ // Are the costs inaccurate?
+ bool inaccurate = 17;
}
repeated Node node = 1;
}
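
A minimal sketch of how a consumer of the estimated cost graph might use the new field (the surrounding code is assumed, not part of this change):

// Skip or down-weight nodes whose costs were only coarsely estimated.
for (const CostGraphDef::Node& node : cost_graph.node()) {
  if (node.inaccurate()) {
    VLOG(1) << "Cost for " << node.name() << " is a rough estimate.";
  }
}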
diff --git a/tensorflow/core/framework/device_base.h b/tensorflow/core/framework/device_base.h
index ec26d92a61..b59ced869d 100644
--- a/tensorflow/core/framework/device_base.h
+++ b/tensorflow/core/framework/device_base.h
@@ -186,6 +186,10 @@ class DeviceBase {
virtual ScopedAllocatorMgr* GetScopedAllocatorMgr() const { return nullptr; }
+ const bool has_eigen_cpu_device() const {
+ return (eigen_cpu_device_ != nullptr);
+ }
+
virtual const Eigen::ThreadPoolDevice* eigen_cpu_device() {
CHECK(eigen_cpu_device_ != nullptr);
return eigen_cpu_device_;
diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc
index 647c66099c..88d9d65f5a 100644
--- a/tensorflow/core/framework/function.cc
+++ b/tensorflow/core/framework/function.cc
@@ -815,6 +815,10 @@ string Canonicalize(const string& funcname, AttrSlice attrs,
entries.push_back(
strings::StrCat("_state_handle", "=", options.state_handle));
}
+ if (!options.executor_type.empty()) {
+ entries.push_back(
+ strings::StrCat("_executor_type", "=", options.executor_type));
+ }
std::sort(entries.begin(), entries.end());
return strings::StrCat(funcname, "[", str_util::Join(entries, ","), "]");
}
diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h
index 872906756a..8e607b927c 100644
--- a/tensorflow/core/framework/function.h
+++ b/tensorflow/core/framework/function.h
@@ -450,6 +450,12 @@ class FunctionLibraryRuntime {
// state (in stateful kernels); and two functions with different
// values for `state_handle` will have independent state.
string state_handle;
+
+ // This interface is EXPERIMENTAL and subject to change.
+ //
+ // Instantiates the function using an executor of the given type. If empty,
+ // the default TensorFlow executor will be used.
+ string executor_type;
};
typedef uint64 Handle;
virtual Status Instantiate(const string& function_name, AttrSlice attrs,
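
A minimal usage sketch of the new option; the function name and executor type below are hypothetical, and `flr` is assumed to be a FunctionLibraryRuntime*.

FunctionLibraryRuntime::InstantiateOptions options;
options.executor_type = "MY_CUSTOM_EXECUTOR";  // hypothetical executor name
FunctionLibraryRuntime::Handle handle;
TF_RETURN_IF_ERROR(flr->Instantiate("MyFunc", AttrSlice(), options, &handle));
// Canonicalize() now folds "_executor_type=MY_CUSTOM_EXECUTOR" into the
// canonical key, so instantiations that differ only in executor_type are
// keyed separately.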
diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc
index b05a9df7c1..a0f449d64f 100644
--- a/tensorflow/core/framework/op_kernel.cc
+++ b/tensorflow/core/framework/op_kernel.cc
@@ -13,12 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#define EIGEN_USE_THREADS
#include "tensorflow/core/framework/op_kernel.h"
#include <unordered_map>
#include <utility>
#include <vector>
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/graph.pb_text.h"
@@ -40,6 +42,7 @@ limitations under the License.
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
@@ -270,6 +273,19 @@ OpKernelContext::OpKernelContext(Params* params, int num_outputs)
if (params_->record_tensor_accesses) {
referenced_tensors_.Init();
}
+ if (params->device->has_eigen_cpu_device()) {
+ int64 block_size = -1, output_size = -1, num_threads = 1;
+ const Eigen::ThreadPoolDevice* thread_pool =
+ params_->device->eigen_cpu_device();
+ AttrSlice attributes(op_kernel().def());
+ if (GetNodeAttr(attributes, "_block_size", &block_size) == Status::OK() &&
+ GetNodeAttr(attributes, "_output_size", &output_size) == Status::OK()) {
+ num_threads = std::min(Eigen::divup(output_size, block_size),
+ static_cast<int64>(thread_pool->numThreads()));
+ eigen_cpu_device_ = MakeUnique<Eigen::ThreadPoolDevice>(
+ thread_pool->getPool(), num_threads);
+ }
+ }
}
OpKernelContext::~OpKernelContext() {
@@ -1120,6 +1136,16 @@ void LogAllRegisteredKernels() {
}
}
+std::vector<KernelDef> GetAllRegisteredKernels() {
+ const KernelRegistry* const typed_registry = GlobalKernelRegistryTyped();
+ std::vector<KernelDef> kernels;
+ kernels.reserve(typed_registry->size());
+ for (const auto& p : *typed_registry) {
+ kernels.emplace_back(p.second.def);
+ }
+ return kernels;
+}
+
string KernelsRegisteredForOp(StringPiece op_name) {
string ret;
for (const auto& key_registration : *GlobalKernelRegistryTyped()) {
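
A worked example (standalone, with hypothetical attribute values) of the thread-count computation introduced in the OpKernelContext constructor above:

#include <algorithm>
#include <cstdint>
#include <iostream>

int64_t divup(int64_t x, int64_t y) { return (x + y - 1) / y; }  // like Eigen::divup

int main() {
  int64_t block_size = 256, output_size = 1000, pool_threads = 8;
  // min(ceil(output_size / block_size), threads in the device's pool)
  int64_t num_threads = std::min(divup(output_size, block_size), pool_threads);
  std::cout << num_threads << "\n";  // prints 4: the context gets a 4-thread device
  return 0;
}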
diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h
index f577664709..d307078e63 100644
--- a/tensorflow/core/framework/op_kernel.h
+++ b/tensorflow/core/framework/op_kernel.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_FRAMEWORK_OP_KERNEL_H_
-#define TENSORFLOW_FRAMEWORK_OP_KERNEL_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_OP_KERNEL_H_
+#define TENSORFLOW_CORE_FRAMEWORK_OP_KERNEL_H_
#include <functional>
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/control_flow.h"
#include "tensorflow/core/framework/device_base.h"
+#include "tensorflow/core/framework/kernel_def.pb.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h" // TODO(b/62899350): Remove
@@ -1003,6 +1004,7 @@ class OpKernelContext {
// OpKernels can use these eigen devices to carry out their
// numerical computation.
const Eigen::ThreadPoolDevice& eigen_cpu_device() const {
+ if (eigen_cpu_device_ != nullptr) return *eigen_cpu_device_;
return *device()->eigen_cpu_device();
}
const Eigen::GpuDevice& eigen_gpu_device() const {
@@ -1138,6 +1140,7 @@ class OpKernelContext {
mutable mutex mu_; // mutable so const accessors can acquire the lock
gtl::InlinedVector<WrappedAllocator, 4> wrapped_allocators_ GUARDED_BY(mu_);
gtl::InlinedVector<TensorValue, 4> outputs_;
+ std::unique_ptr<Eigen::ThreadPoolDevice> eigen_cpu_device_;
// Constructed only if <params->record_tensor_accesses>.
ManualConstructor<UniqueTensorReferences> referenced_tensors_ GUARDED_BY(mu_);
@@ -1303,6 +1306,9 @@ Status FindKernelDef(const DeviceType& device_type, const NodeDef& node_def,
// missing kernel errors.
void LogAllRegisteredKernels();
+// Gets a vector of all registered kernels.
+std::vector<KernelDef> GetAllRegisteredKernels();
+
namespace kernel_factory {
class OpKernelRegistrar {
@@ -1572,4 +1578,4 @@ inline void OpOutputList::set_ref(int i, mutex* mu, Tensor* tensor_for_ref) {
} // namespace tensorflow
-#endif // TENSORFLOW_FRAMEWORK_OP_KERNEL_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_OP_KERNEL_H_
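
A small sketch of the new accessor (logging statement assumed for illustration):

// Enumerate every registered kernel and the device type it targets.
for (const KernelDef& kernel : GetAllRegisteredKernels()) {
  LOG(INFO) << kernel.op() << " registered for " << kernel.device_type();
}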
diff --git a/tensorflow/core/framework/op_kernel_test.cc b/tensorflow/core/framework/op_kernel_test.cc
index bcd409e5c5..50319ca576 100644
--- a/tensorflow/core/framework/op_kernel_test.cc
+++ b/tensorflow/core/framework/op_kernel_test.cc
@@ -964,5 +964,27 @@ void BM_SelectInputRange(int iters) {
BENCHMARK(BM_ConcatInputRange);
BENCHMARK(BM_SelectInputRange);
+TEST(RegisteredKernels, CanCallGetAllRegisteredKernels) {
+ auto all_registered_kernels = GetAllRegisteredKernels();
+ auto has_name_test1 = [](const KernelDef& k) { return k.op() == "Test1"; };
+
+ // Verify we can find the "Test1" op registered above
+ auto test1_it = std::find_if(all_registered_kernels.begin(),
+ all_registered_kernels.end(), has_name_test1);
+ ASSERT_NE(test1_it, all_registered_kernels.end());
+ EXPECT_EQ(test1_it->device_type(), "CPU");
+
+ // Verify there was just one kernel
+ ++test1_it;
+ EXPECT_EQ(
+ std::find_if(test1_it, all_registered_kernels.end(), has_name_test1),
+ all_registered_kernels.end());
+}
+
+// Simple test just to check we can call LogAllRegisteredKernels
+TEST(RegisteredKernels, CanLogAllRegisteredKernels) {
+ tensorflow::LogAllRegisteredKernels();
+}
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/core/graph/control_flow.cc b/tensorflow/core/graph/control_flow.cc
index 30ff19cd7e..fea25560d8 100644
--- a/tensorflow/core/graph/control_flow.cc
+++ b/tensorflow/core/graph/control_flow.cc
@@ -24,8 +24,8 @@ limitations under the License.
namespace tensorflow {
-Status BuildControlFlowInfo(const Graph* g,
- std::vector<ControlFlowInfo>* info) {
+Status BuildControlFlowInfo(const Graph* g, std::vector<ControlFlowInfo>* info,
+ std::vector<string>* unreachable_nodes) {
info->clear();
info->resize(g->num_node_ids());
@@ -114,6 +114,13 @@ Status BuildControlFlowInfo(const Graph* g,
}
}
}
+ if (unreachable_nodes) {
+ for (const Node* node : g->op_nodes()) {
+ if (!parent_nodes[node->id()]) {
+ unreachable_nodes->push_back(node->name());
+ }
+ }
+ }
return Status::OK();
}
diff --git a/tensorflow/core/graph/control_flow.h b/tensorflow/core/graph/control_flow.h
index 79e2be0d4b..8605d57c14 100644
--- a/tensorflow/core/graph/control_flow.h
+++ b/tensorflow/core/graph/control_flow.h
@@ -33,11 +33,15 @@ struct ControlFlowInfo {
// Clear and populate `info` with each node's frame and the level it belongs to.
// We check the well-formedness of the graph: All inputs to a node must come
// from the same frame and have the same "static" iteration level.
+// If `unreachable_nodes` is set, return names of nodes unreachable from the
+// source node. We cannot build ControlFlowInfo for such nodes. They might be
+// pruned later.
//
// NOTE(yuanbyu): For now, we require all sends/recvs have iteration level 0.
// This essentially means there can't be multiple serial Nexts in an iteration,
// which all sane front-ends should satisfy.
-Status BuildControlFlowInfo(const Graph* g, std::vector<ControlFlowInfo>* info);
+Status BuildControlFlowInfo(const Graph* g, std::vector<ControlFlowInfo>* info,
+ std::vector<string>* unreachable_nodes = nullptr);
} // namespace tensorflow
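
A minimal sketch of the extended call (assumes `graph` is an existing tensorflow::Graph):

std::vector<ControlFlowInfo> cf_info;
std::vector<string> unreachable_nodes;
TF_RETURN_IF_ERROR(BuildControlFlowInfo(&graph, &cf_info, &unreachable_nodes));
for (const string& name : unreachable_nodes) {
  VLOG(1) << "Unreachable from source (may be pruned later): " << name;
}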
diff --git a/tensorflow/core/grappler/costs/analytical_cost_estimator.cc b/tensorflow/core/grappler/costs/analytical_cost_estimator.cc
index c8ba4dfbda..a60e3c7a9f 100644
--- a/tensorflow/core/grappler/costs/analytical_cost_estimator.cc
+++ b/tensorflow/core/grappler/costs/analytical_cost_estimator.cc
@@ -98,6 +98,7 @@ Status AnalyticalCostEstimator::PredictCosts(const GraphDef& optimized_graph,
node_costs.compute_time.asMicroSeconds().count());
cost_node->set_memory_time(
node_costs.memory_time.asMicroSeconds().count());
+ cost_node->set_inaccurate(node_costs.inaccurate);
for (const auto& output : op_context.op_info.outputs()) {
auto output_info = cost_node->add_output_info();
output_info->set_dtype(output.dtype());
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
index b8e337582c..b994d26397 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
@@ -45,6 +45,7 @@ constexpr char kIdentityN[] = "IdentityN";
constexpr char kRefIdentity[] = "RefIdentity";
constexpr char kNoOp[] = "NoOp";
constexpr char kReshape[] = "Reshape";
+constexpr char kSqueeze[] = "Squeeze";
constexpr char kRecv[] = "_Recv";
constexpr char kSend[] = "_Send";
constexpr char kBatchMatMul[] = "BatchMatMul";
@@ -232,6 +233,7 @@ OpLevelCostEstimator::OpLevelCostEstimator() {
{kStopGradient, wrap(&OpLevelCostEstimator::PredictIdentity)},
{kPreventGradient, wrap(&OpLevelCostEstimator::PredictIdentity)},
{kReshape, wrap(&OpLevelCostEstimator::PredictIdentity)},
+ {kSqueeze, wrap(&OpLevelCostEstimator::PredictIdentity)},
{kRecv, wrap(&OpLevelCostEstimator::PredictIdentity)},
{kSend, wrap(&OpLevelCostEstimator::PredictIdentity)},
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index 5ed73eec50..2073c2968b 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -792,6 +792,9 @@ tf_cc_test(
name = "scoped_allocator_optimizer_test",
size = "small",
srcs = ["scoped_allocator_optimizer_test.cc"],
+ tags = [
+ "nomsan",
+ ],
deps = [
":scoped_allocator_optimizer",
"//tensorflow/cc:cc_ops",
diff --git a/tensorflow/core/grappler/optimizers/data/BUILD b/tensorflow/core/grappler/optimizers/data/BUILD
index 121de1e089..08fc9d84da 100644
--- a/tensorflow/core/grappler/optimizers/data/BUILD
+++ b/tensorflow/core/grappler/optimizers/data/BUILD
@@ -68,10 +68,45 @@ tf_cc_test(
)
cc_library(
+ name = "shuffle_and_repeat_fusion",
+ srcs = ["shuffle_and_repeat_fusion.cc"],
+ hdrs = [
+ "shuffle_and_repeat_fusion.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":graph_utils",
+ "//tensorflow/core:lib",
+ "//tensorflow/core/grappler:graph_view",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler:op_types",
+ "//tensorflow/core/grappler:utils",
+ "//tensorflow/core/grappler/clusters:cluster",
+ "//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
+ "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
+ ] + tf_protos_all(),
+)
+
+tf_cc_test(
+ name = "shuffle_and_repeat_fusion_test",
+ srcs = ["shuffle_and_repeat_fusion_test.cc"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":graph_utils",
+ ":shuffle_and_repeat_fusion",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core/grappler:grappler_item",
+ ],
+)
+
+cc_library(
name = "data",
visibility = ["//visibility:public"],
deps = [
":map_and_batch_fusion",
+ ":shuffle_and_repeat_fusion",
],
alwayslink = 1,
)
diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.cc b/tensorflow/core/grappler/optimizers/data/graph_utils.cc
index df12de37da..aece142f7a 100644
--- a/tensorflow/core/grappler/optimizers/data/graph_utils.cc
+++ b/tensorflow/core/grappler/optimizers/data/graph_utils.cc
@@ -28,6 +28,8 @@ namespace grappler {
namespace graph_utils {
namespace {
+constexpr char kConstOpName[] = "Const";
+
int FindNodeWithPredicate(const std::function<bool(const NodeDef&)>& predicate,
const GraphDef& graph) {
for (int i = 0; i < graph.node_size(); ++i) {
@@ -68,9 +70,8 @@ Status AddScalarConstNodeHelper(
DataType dtype, const std::function<void(TensorProto*)>& add_value,
GraphDef* graph, NodeDef** result) {
NodeDef* node = graph->add_node();
- const string& name = strings::StrCat("Const/_", graph->node_size());
- node->set_name(name);
- node->set_op("Const");
+ node->set_op(kConstOpName);
+ SetUniqueName(kConstOpName, graph, node);
(*node->mutable_attr())["dtype"].set_type(dtype);
std::unique_ptr<tensorflow::TensorProto> tensor =
tensorflow::MakeUnique<tensorflow::TensorProto>();
@@ -94,7 +95,7 @@ Status AddNode(const string& name, const string& op,
if (!name.empty()) {
node->set_name(name);
} else {
- node->set_name(strings::StrCat(op, "/_", graph->node_size()));
+ SetUniqueName(op, graph, node);
}
node->set_op(op);
for (const string& input : inputs) {
@@ -212,6 +213,14 @@ int FindNodeWithOp(const string& op, const GraphDef& graph) {
[op](const NodeDef& node) { return node.op() == op; }, graph);
}
+void SetUniqueName(const string& op, GraphDef* graph, NodeDef* node) {
+ int id = graph->node_size();
+ while (ContainsNodeWithName(strings::StrCat(op, "/_", id), *graph)) {
+ ++id;
+ }
+ node->set_name(strings::StrCat(op, "/_", id));
+}
+
} // end namespace graph_utils
} // end namespace grappler
} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.h b/tensorflow/core/grappler/optimizers/data/graph_utils.h
index b40ca44d78..3d2467031f 100644
--- a/tensorflow/core/grappler/optimizers/data/graph_utils.h
+++ b/tensorflow/core/grappler/optimizers/data/graph_utils.h
@@ -74,6 +74,10 @@ int FindNodeWithName(const string& name, const GraphDef& graph);
// exists.
int FindNodeWithOp(const string& op, const GraphDef& graph);
+// Sets the node name using the op name as a prefix while guaranteeing the name
+// is unique across the graph.
+void SetUniqueName(const string& op, GraphDef* graph, NodeDef* node);
+
} // end namespace graph_utils
} // end namespace grappler
} // end namespace tensorflow
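
A short sketch of the naming scheme SetUniqueName() implements (`graph` is assumed to be a GraphDef*); the test below exercises the same behavior through AddNode().

NodeDef* node = graph->add_node();
node->set_op("Const");
// Produces a name of the form "Const/_<id>", bumping <id> past any existing
// node with the same name, e.g. "Const/_7".
graph_utils::SetUniqueName("Const", graph, node);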
diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc b/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc
index b34726044e..00f66c9bc1 100644
--- a/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc
@@ -136,6 +136,21 @@ TEST_F(GraphUtilsTest, FindNodeWithOp) {
EXPECT_EQ(FindNodeWithOp("OpA", graph), -1);
}
+TEST_F(GraphUtilsTest, SetUniqueName) {
+ GraphDef graph;
+
+ NodeDef* node1;
+ TF_EXPECT_OK(AddNode("", "A", {}, {}, &graph, &node1));
+ NodeDef* node2;
+ TF_EXPECT_OK(AddNode("", "A", {}, {}, &graph, &node2));
+ EXPECT_NE(node1->name(), node2->name());
+
+ TF_EXPECT_OK(DeleteNodes({node1->name()}, &graph));
+ NodeDef* node3;
+ TF_EXPECT_OK(AddNode("", "A", {}, {}, &graph, &node3));
+ EXPECT_NE(node2->name(), node3->name());
+}
+
} // namespace
} // namespace graph_utils
} // namespace grappler
diff --git a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc
index 290326ab75..1e8cbb9784 100644
--- a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc
@@ -28,6 +28,11 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
+namespace {
+
+constexpr char kFusedOpName[] = "MapAndBatchDatasetV2";
+
+} // namespace
Status MapAndBatchFusion::Optimize(Cluster* cluster, const GrapplerItem& item,
GraphDef* output) {
@@ -35,25 +40,24 @@ Status MapAndBatchFusion::Optimize(Cluster* cluster, const GrapplerItem& item,
GraphView graph(output);
std::set<string> nodes_to_delete;
for (const NodeDef& node : item.graph.node()) {
- if (node.op() != "BatchDataset") {
+ if (node.op() != "BatchDataset" && node.op() != "BatchDatasetV2") {
continue;
}
- // Use a more descriptive variable name now that we now the node type.
- NodeDef batch_node(node);
+ // Use a more descriptive variable name now that we know the node type.
+ const NodeDef batch_node(node);
GraphView::InputPort input_port = graph.GetInputPort(batch_node.name(), 0);
NodeDef* node2 = graph.GetRegularFanin(input_port).node;
if (node2->op() != "MapDataset" && node2->op() != "ParallelMapDataset") {
continue;
}
- // Use a more descriptive variable name now that we now the node type.
- NodeDef* map_node = node2;
- NodeDef* new_node = output->mutable_node()->Add();
- new_node->set_op("MapAndBatchDatasetV2");
- new_node->set_name(
- strings::StrCat("MapAndBatchDatasetV2/_", output->node_size()));
+ NodeDef* new_node = output->add_node();
+ new_node->set_op(kFusedOpName);
+ graph_utils::SetUniqueName(kFusedOpName, output, new_node);
+ // Use a more descriptive variable name now that we know the node type.
+ NodeDef* map_node = node2;
// Set the `input` input argument.
new_node->add_input(map_node->input(0));
@@ -89,7 +93,9 @@ Status MapAndBatchFusion::Optimize(Cluster* cluster, const GrapplerItem& item,
}
// Set the `drop_remainder` input argument.
- {
+ if (batch_node.op() == "BatchDatasetV2") {
+ new_node->add_input(batch_node.input(2));
+ } else {
NodeDef* tmp;
TF_RETURN_IF_ERROR(
graph_utils::AddScalarConstNode<bool>(false, output, &tmp));
diff --git a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.h b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.h
index a5a4d91df6..2c64831105 100644
--- a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.h
+++ b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.h
@@ -23,13 +23,13 @@ namespace grappler {
class MapAndBatchFusion : public CustomGraphOptimizer {
public:
- MapAndBatchFusion() {}
- ~MapAndBatchFusion() override {}
+ MapAndBatchFusion() = default;
+ ~MapAndBatchFusion() override = default;
string name() const override { return "map_and_batch_fusion"; };
- Status Init(const tensorflow::RewriterConfig_CustomGraphOptimizer* config =
- nullptr) override {
+ Status Init(
+ const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override {
return Status::OK();
}
diff --git a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion_test.cc b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion_test.cc
index 8c7498dc5d..3c1d8d5359 100644
--- a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion_test.cc
@@ -112,6 +112,95 @@ TEST(MapAndBatchFusionTest, FuseMapAndBatchNodesIntoOne) {
batch_node->attr().at("output_types")));
}
+TEST(MapAndBatchFusionTest, FuseMapAndBatchV2NodesIntoOne) {
+ GrapplerItem item;
+ GraphDef *graph = &item.graph;
+ NodeDef *start_node;
+ TF_ASSERT_OK(graph_utils::AddScalarConstNode<int64>(0, graph, &start_node));
+ NodeDef *stop_node;
+ TF_ASSERT_OK(graph_utils::AddScalarConstNode<int64>(10, graph, &stop_node));
+ NodeDef *step_node;
+ TF_ASSERT_OK(graph_utils::AddScalarConstNode<int64>(1, graph, &step_node));
+
+ std::vector<string> range_inputs(3);
+ range_inputs[0] = start_node->name();
+ range_inputs[1] = stop_node->name();
+ range_inputs[2] = step_node->name();
+ std::vector<std::pair<string, AttrValue>> range_attrs;
+ NodeDef *range_node;
+ TF_ASSERT_OK(graph_utils::AddNode("", "RangeDataset", range_inputs,
+ range_attrs, graph, &range_node));
+ NodeDef *captured_input_node;
+ TF_ASSERT_OK(graph_utils::AddScalarConstNode<StringPiece>(
+ "hello", graph, &captured_input_node));
+
+ NodeDef *map_node;
+ {
+ std::vector<string> map_inputs(2);
+ map_inputs[0] = range_node->name();
+ map_inputs[1] = captured_input_node->name();
+ std::vector<std::pair<string, AttrValue>> map_attrs(2);
+ AttrValue f_attr;
+ SetAttrValue("f", &f_attr);
+ map_attrs[0] = std::make_pair("f", f_attr);
+ AttrValue args_attr;
+ SetAttrValue("Targuments", &args_attr);
+ map_attrs[1] = std::make_pair("Targuments", args_attr);
+ TF_ASSERT_OK(graph_utils::AddNode("", "MapDataset", map_inputs, map_attrs,
+ graph, &map_node));
+ }
+
+ NodeDef *batch_size_node;
+ TF_ASSERT_OK(
+ graph_utils::AddScalarConstNode<int64>(5, graph, &batch_size_node));
+ NodeDef *drop_remainder_node;
+ TF_ASSERT_OK(
+ graph_utils::AddScalarConstNode<bool>(true, graph, &drop_remainder_node));
+ NodeDef *batch_node;
+ {
+ std::vector<string> batch_inputs(3);
+ batch_inputs[0] = map_node->name();
+ batch_inputs[1] = batch_size_node->name();
+ batch_inputs[2] = drop_remainder_node->name();
+ std::vector<std::pair<string, AttrValue>> batch_attrs(2);
+ AttrValue shapes_attr;
+ SetAttrValue("output_shapes", &shapes_attr);
+ batch_attrs[0] = std::make_pair("output_shapes", shapes_attr);
+ AttrValue types_attr;
+ SetAttrValue("output_types", &types_attr);
+ batch_attrs[1] = std::make_pair("output_types", types_attr);
+ TF_ASSERT_OK(graph_utils::AddNode("", "BatchDatasetV2", batch_inputs,
+ batch_attrs, graph, &batch_node));
+ }
+
+ MapAndBatchFusion optimizer;
+ GraphDef output;
+ TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ EXPECT_FALSE(graph_utils::ContainsNodeWithName(map_node->name(), output));
+ EXPECT_FALSE(graph_utils::ContainsNodeWithName(batch_node->name(), output));
+ EXPECT_TRUE(graph_utils::ContainsNodeWithOp("MapAndBatchDatasetV2", output));
+ NodeDef map_and_batch_node =
+ output.node(graph_utils::FindNodeWithOp("MapAndBatchDatasetV2", output));
+ EXPECT_EQ(map_and_batch_node.input_size(), 5);
+ EXPECT_EQ(map_and_batch_node.input(0), map_node->input(0));
+ EXPECT_EQ(map_and_batch_node.input(1), map_node->input(1));
+ EXPECT_EQ(map_and_batch_node.input(2), batch_node->input(1));
+ NodeDef num_parallel_calls_node = output.node(
+ graph_utils::FindNodeWithName(map_and_batch_node.input(3), output));
+ EXPECT_EQ(num_parallel_calls_node.attr().at("value").tensor().int64_val(0),
+ 1);
+ EXPECT_EQ(map_and_batch_node.input(4), batch_node->input(2));
+ EXPECT_TRUE(AreAttrValuesEqual(map_and_batch_node.attr().at("f"),
+ map_node->attr().at("f")));
+ EXPECT_TRUE(AreAttrValuesEqual(map_and_batch_node.attr().at("Targuments"),
+ map_node->attr().at("Targuments")));
+ EXPECT_TRUE(AreAttrValuesEqual(map_and_batch_node.attr().at("output_shapes"),
+ batch_node->attr().at("output_shapes")));
+ EXPECT_TRUE(AreAttrValuesEqual(map_and_batch_node.attr().at("output_types"),
+ batch_node->attr().at("output_types")));
+}
+
TEST(MapAndBatchFusionTest, FuseParallelMapAndBatchNodesIntoOne) {
GrapplerItem item;
GraphDef *graph = &item.graph;
@@ -204,10 +293,9 @@ TEST(MapAndBatchFusionTest, FuseParallelMapAndBatchNodesIntoOne) {
}
TEST(MapAndBatchFusionTest, NoChange) {
- std::vector<std::pair<string, AttrValue>> empty_attributes;
-
GrapplerItem item;
GraphDef *graph = &item.graph;
+
NodeDef *start_node;
TF_ASSERT_OK(graph_utils::AddScalarConstNode<int64>(0, graph, &start_node));
NodeDef *stop_node;
@@ -219,9 +307,27 @@ TEST(MapAndBatchFusionTest, NoChange) {
range_inputs[0] = start_node->name();
range_inputs[1] = stop_node->name();
range_inputs[2] = step_node->name();
+ std::vector<std::pair<string, AttrValue>> range_attrs;
NodeDef *range_node;
TF_ASSERT_OK(graph_utils::AddNode("", "RangeDataset", range_inputs,
- empty_attributes, graph, &range_node));
+ range_attrs, graph, &range_node));
+
+ NodeDef *batch_size_node;
+ TF_ASSERT_OK(
+ graph_utils::AddScalarConstNode<int64>(5, graph, &batch_size_node));
+ std::vector<string> batch_inputs(2);
+ batch_inputs[0] = range_node->name();
+ batch_inputs[1] = batch_size_node->name();
+ std::vector<std::pair<string, AttrValue>> batch_attrs(2);
+ AttrValue shapes_attr;
+ SetAttrValue("output_shapes", &shapes_attr);
+ batch_attrs[0] = std::make_pair("output_shapes", shapes_attr);
+ AttrValue types_attr;
+ SetAttrValue("output_types", &types_attr);
+ batch_attrs[1] = std::make_pair("output_types", types_attr);
+ NodeDef *batch_node;
+ TF_ASSERT_OK(graph_utils::AddNode("", "BatchDataset", batch_inputs,
+ batch_attrs, graph, &batch_node));
MapAndBatchFusion optimizer;
GraphDef output;
diff --git a/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc b/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc
new file mode 100644
index 0000000000..0df73b33ed
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc
@@ -0,0 +1,112 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.h"
+
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/grappler/clusters/cluster.h"
+#include "tensorflow/core/grappler/graph_view.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/op_types.h"
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
+#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/platform/protobuf.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+
+constexpr char kFusedOpName[] = "ShuffleAndRepeatDataset";
+
+} // namespace
+
+Status ShuffleAndRepeatFusion::Optimize(Cluster* cluster,
+ const GrapplerItem& item,
+ GraphDef* output) {
+ *output = item.graph;
+ GraphView graph(output);
+ std::set<string> nodes_to_delete;
+ for (const NodeDef& node : item.graph.node()) {
+ if (node.op() != "RepeatDataset") {
+ continue;
+ }
+
+ // Use a more descriptive variable name now that we know the node type.
+ const NodeDef repeat_node(node);
+ GraphView::InputPort input_port = graph.GetInputPort(repeat_node.name(), 0);
+ NodeDef* node2 = graph.GetRegularFanin(input_port).node;
+ if (node2->op() != "ShuffleDataset") {
+ continue;
+ }
+
+ NodeDef* new_node = output->add_node();
+ new_node->set_op(kFusedOpName);
+ graph_utils::SetUniqueName(kFusedOpName, output, new_node);
+
+ // Use a more descriptive variable name now that we know the node type.
+ NodeDef* shuffle_node = node2;
+
+ // Set the `input` input argument.
+ new_node->add_input(shuffle_node->input(0));
+
+ // Set the `buffer_size` input argument.
+ new_node->add_input(shuffle_node->input(1));
+
+ // Set the `seed` input argument.
+ new_node->add_input(shuffle_node->input(2));
+
+ // Set the `seed2` input argument.
+ new_node->add_input(shuffle_node->input(3));
+
+ // Set the `count` input argument.
+ new_node->add_input(repeat_node.input(1));
+
+ // Set `output_types` and `output_shapes` attributes.
+ for (auto key : {"output_shapes", "output_types"}) {
+ (*new_node->mutable_attr())[key] = repeat_node.attr().at(key);
+ }
+
+ // Mark the `Shuffle` and `Repeat` nodes for removal.
+ nodes_to_delete.insert(shuffle_node->name());
+ nodes_to_delete.insert(repeat_node.name());
+
+ // Update the inputs of the nodes that consume the output of the `Repeat`
+ // node so that they read from `ShuffleAndRepeat` instead.
+ GraphView::OutputPort output_port =
+ graph.GetOutputPort(repeat_node.name(), 0);
+ auto fanout = graph.GetFanout(output_port);
+ for (auto it = fanout.begin(); it != fanout.end(); ++it) {
+ NodeDef* node = it->node;
+ node->set_input(0, new_node->name());
+ }
+ }
+ TF_RETURN_IF_ERROR(graph_utils::DeleteNodes(nodes_to_delete, output));
+ return Status::OK();
+}
+
+void ShuffleAndRepeatFusion::Feedback(Cluster* cluster,
+ const GrapplerItem& item,
+ const GraphDef& optimize_output,
+ double result) {
+ // no-op
+}
+
+REGISTER_GRAPH_OPTIMIZER_AS(ShuffleAndRepeatFusion,
+ "shuffle_and_repeat_fusion");
+
+} // end namespace grappler
+} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.h b/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.h
new file mode 100644
index 0000000000..c8fa53edce
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.h
@@ -0,0 +1,46 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_SHUFFLE_AND_REPEAT_FUSION_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_SHUFFLE_AND_REPEAT_FUSION_H_
+
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
+
+namespace tensorflow {
+namespace grappler {
+
+class ShuffleAndRepeatFusion : public CustomGraphOptimizer {
+ public:
+ ShuffleAndRepeatFusion() = default;
+ ~ShuffleAndRepeatFusion() override = default;
+
+ string name() const override { return "shuffle_and_repeat_fusion"; };
+
+ Status Init(
+ const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override {
+ return Status::OK();
+ }
+
+ Status Optimize(Cluster* cluster, const GrapplerItem& item,
+ GraphDef* output) override;
+
+ void Feedback(Cluster* cluster, const GrapplerItem& item,
+ const GraphDef& optimize_output, double result) override;
+};
+
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_SHUFFLE_AND_REPEAT_FUSION_H_
diff --git a/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion_test.cc b/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion_test.cc
new file mode 100644
index 0000000000..e89675efb7
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion_test.cc
@@ -0,0 +1,149 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.h"
+
+#include "tensorflow/core/framework/attr_value_util.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+
+TEST(ShuffleAndRepeatFusionTest, FuseShuffleAndRepeatNodesIntoOne) {
+ GrapplerItem item;
+ GraphDef *graph = &item.graph;
+
+ std::vector<std::pair<string, AttrValue>> common_attrs(2);
+ AttrValue shapes_attr;
+ SetAttrValue("output_shapes", &shapes_attr);
+ common_attrs[0] = std::make_pair("output_shapes", shapes_attr);
+ AttrValue types_attr;
+ SetAttrValue("output_types", &types_attr);
+ common_attrs[1] = std::make_pair("output_types", types_attr);
+
+ NodeDef *start_node;
+ TF_ASSERT_OK(graph_utils::AddScalarConstNode<int64>(0, graph, &start_node));
+ NodeDef *stop_node;
+ TF_ASSERT_OK(graph_utils::AddScalarConstNode<int64>(10, graph, &stop_node));
+ NodeDef *step_node;
+ TF_ASSERT_OK(graph_utils::AddScalarConstNode<int64>(1, graph, &step_node));
+
+ std::vector<string> range_inputs(3);
+ range_inputs[0] = start_node->name();
+ range_inputs[1] = stop_node->name();
+ range_inputs[2] = step_node->name();
+ NodeDef *range_node;
+ TF_ASSERT_OK(graph_utils::AddNode("", "RangeDataset", range_inputs,
+ common_attrs, graph, &range_node));
+
+ NodeDef *buffer_size_node;
+ TF_ASSERT_OK(
+ graph_utils::AddScalarConstNode<int64>(128, graph, &buffer_size_node));
+ NodeDef *seed_node;
+ TF_ASSERT_OK(graph_utils::AddScalarConstNode<int64>(-1, graph, &seed_node));
+ NodeDef *seed2_node;
+ TF_ASSERT_OK(graph_utils::AddScalarConstNode<int64>(-1, graph, &seed2_node));
+ std::vector<string> shuffle_inputs(4);
+ shuffle_inputs[0] = range_node->name();
+ shuffle_inputs[1] = buffer_size_node->name();
+ shuffle_inputs[2] = seed_node->name();
+ shuffle_inputs[3] = seed2_node->name();
+ NodeDef *shuffle_node;
+ TF_ASSERT_OK(graph_utils::AddNode("", "ShuffleDataset", shuffle_inputs,
+ common_attrs, graph, &shuffle_node));
+
+ NodeDef *count_node;
+ TF_ASSERT_OK(graph_utils::AddScalarConstNode<int64>(-1, graph, &count_node));
+ std::vector<string> repeat_inputs(2);
+ repeat_inputs[0] = shuffle_node->name();
+ repeat_inputs[1] = count_node->name();
+ NodeDef *repeat_node;
+ TF_ASSERT_OK(graph_utils::AddNode("", "RepeatDataset", repeat_inputs,
+ common_attrs, graph, &repeat_node));
+
+ ShuffleAndRepeatFusion optimizer;
+ GraphDef output;
+ TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ EXPECT_FALSE(graph_utils::ContainsNodeWithName(shuffle_node->name(), output));
+ EXPECT_FALSE(graph_utils::ContainsNodeWithName(repeat_node->name(), output));
+ EXPECT_TRUE(
+ graph_utils::ContainsNodeWithOp("ShuffleAndRepeatDataset", output));
+ NodeDef shuffle_and_repeat_node = output.node(
+ graph_utils::FindNodeWithOp("ShuffleAndRepeatDataset", output));
+ EXPECT_EQ(shuffle_and_repeat_node.input_size(), 5);
+ EXPECT_EQ(shuffle_and_repeat_node.input(0), shuffle_node->input(0));
+ EXPECT_EQ(shuffle_and_repeat_node.input(1), shuffle_node->input(1));
+ EXPECT_EQ(shuffle_and_repeat_node.input(2), shuffle_node->input(2));
+ EXPECT_EQ(shuffle_and_repeat_node.input(3), shuffle_node->input(3));
+ EXPECT_EQ(shuffle_and_repeat_node.input(4), repeat_node->input(1));
+ EXPECT_TRUE(
+ AreAttrValuesEqual(shuffle_and_repeat_node.attr().at("output_shapes"),
+ repeat_node->attr().at("output_shapes")));
+ EXPECT_TRUE(
+ AreAttrValuesEqual(shuffle_and_repeat_node.attr().at("output_types"),
+ repeat_node->attr().at("output_types")));
+}
+
+TEST(ShuffleAndRepeatFusionTest, NoChange) {
+ GrapplerItem item;
+ GraphDef *graph = &item.graph;
+
+ std::vector<std::pair<string, AttrValue>> common_attrs(2);
+ AttrValue shapes_attr;
+ SetAttrValue("output_shapes", &shapes_attr);
+ common_attrs[0] = std::make_pair("output_shapes", shapes_attr);
+ AttrValue types_attr;
+ SetAttrValue("output_types", &types_attr);
+ common_attrs[1] = std::make_pair("output_types", types_attr);
+
+ NodeDef *start_node;
+ TF_ASSERT_OK(graph_utils::AddScalarConstNode<int64>(0, graph, &start_node));
+ NodeDef *stop_node;
+ TF_ASSERT_OK(graph_utils::AddScalarConstNode<int64>(10, graph, &stop_node));
+ NodeDef *step_node;
+ TF_ASSERT_OK(graph_utils::AddScalarConstNode<int64>(1, graph, &step_node));
+
+ std::vector<string> range_inputs(3);
+ range_inputs[0] = start_node->name();
+ range_inputs[1] = stop_node->name();
+ range_inputs[2] = step_node->name();
+ NodeDef *range_node;
+ TF_ASSERT_OK(graph_utils::AddNode("", "RangeDataset", range_inputs,
+ common_attrs, graph, &range_node));
+
+ NodeDef *count_node;
+ TF_ASSERT_OK(graph_utils::AddScalarConstNode<int64>(-1, graph, &count_node));
+ std::vector<string> repeat_inputs(2);
+ repeat_inputs[0] = range_node->name();
+ repeat_inputs[1] = count_node->name();
+ NodeDef *repeat_node;
+ TF_ASSERT_OK(graph_utils::AddNode("", "RepeatDataset", repeat_inputs,
+ common_attrs, graph, &repeat_node));
+
+ ShuffleAndRepeatFusion optimizer;
+ GraphDef output;
+ TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ EXPECT_TRUE(graph_utils::Compare(*graph, output));
+}
+
+} // namespace
+} // namespace grappler
+} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer.cc b/tensorflow/core/grappler/optimizers/layout_optimizer.cc
index e08ab1eb67..3251e7cb10 100644
--- a/tensorflow/core/grappler/optimizers/layout_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/layout_optimizer.cc
@@ -499,6 +499,7 @@ class NodeProcessor : public GraphProcessor {
UpdateAttrDataFormat();
UpdateAttrKSize();
UpdateAttrStrides();
+ UpdateAttrDilations();
UpdateAttrShape();
TF_RETURN_IF_ERROR(AddLayoutTransposeToInputs());
TF_RETURN_IF_ERROR(AddLayoutTransposeToOutputs());
@@ -742,6 +743,13 @@ class NodeProcessor : public GraphProcessor {
}
}
+ void UpdateAttrDilations() {
+ if (node_->attr().find("dilations") != node_->attr().end()) {
+ auto list = node_->mutable_attr()->at("dilations").mutable_list();
+ UpdateTuple(list);
+ }
+ }
+
void UpdateAttrDataFormat() {
if (node_->attr().find("data_format") != node_->attr().end()) {
if (node_->attr().at("data_format").s().compare("NHWC") == 0) {
diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc b/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc
index dad49cd74f..20e47c1b26 100644
--- a/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc
@@ -87,12 +87,13 @@ class LayoutOptimizerTest : public GrapplerTest {
Output SimpleConv2DBackpropInput(tensorflow::Scope* s, int input_size,
int filter_size, const string& padding) {
- return SimpleConv2DBackpropInput(s, input_size, filter_size, padding, true);
+ return SimpleConv2DBackpropInput(s, input_size, filter_size, padding, true,
+ true);
}
Output SimpleConv2DBackpropInput(tensorflow::Scope* s, int input_size,
int filter_size, const string& padding,
- bool const_input_size) {
+ bool const_input_size, bool dilated) {
int batch_size = 128;
int input_height = input_size;
int input_width = input_size;
@@ -123,14 +124,18 @@ class LayoutOptimizerTest : public GrapplerTest {
Output conv_backprop_input;
Output input_sizes_i =
ops::Identity(s->WithOpName("InputSizesIdentity"), input_sizes);
+ ops::Conv2DBackpropInput::Attrs attrs;
+ if (dilated) {
+ attrs = attrs.Dilations({1, 2, 2, 1});
+ }
if (const_input_size) {
conv_backprop_input = ops::Conv2DBackpropInput(
s->WithOpName("Conv2DBackpropInput"), input_sizes, filter, output,
- {1, stride, stride, 1}, padding);
+ {1, stride, stride, 1}, padding, attrs);
} else {
conv_backprop_input = ops::Conv2DBackpropInput(
s->WithOpName("Conv2DBackpropInput"), input_sizes_i, filter, output,
- {1, stride, stride, 1}, padding);
+ {1, stride, stride, 1}, padding, attrs);
}
return conv_backprop_input;
}
@@ -216,7 +221,7 @@ TEST_F(LayoutOptimizerTest, Conv2DBackpropInput) {
TEST_F(LayoutOptimizerTest, Conv2DBackpropInputNonConstInputSizes) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- auto conv = SimpleConv2DBackpropInput(&s, 7, 2, "SAME", false);
+ auto conv = SimpleConv2DBackpropInput(&s, 7, 2, "SAME", false, false);
Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
GrapplerItem item;
TF_CHECK_OK(s.ToGraphDef(&item.graph));
diff --git a/tensorflow/core/grappler/optimizers/remapper.cc b/tensorflow/core/grappler/optimizers/remapper.cc
index 622fb134a1..03e36a7b9c 100644
--- a/tensorflow/core/grappler/optimizers/remapper.cc
+++ b/tensorflow/core/grappler/optimizers/remapper.cc
@@ -201,7 +201,7 @@ Status Remapper::Optimize(Cluster* /*cluster*/, const GrapplerItem& item,
}
}
if (optimizable) {
- VLOG(1)<< "Optimizing fused batch norm node " << node.DebugString();
+ VLOG(1) << "Optimizing fused batch norm node " << node.DebugString();
AddBatchNormNodes(optimized_graph, node);
continue;
}
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index c7c7879714..5e4c8a78b0 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -2812,6 +2812,9 @@ tf_kernel_library(
srcs = [] + if_mkl([
"mkl_batch_matmul_op.cc",
]),
+ # Override EIGEN_STRONG_INLINE to inline when --define=override_eigen_strong_inline=true,
+ # to avoid long compile times. See https://github.com/tensorflow/tensorflow/issues/10521
+ copts = if_override_eigen_strong_inline(["/DEIGEN_STRONG_INLINE=inline"]),
prefix = "batch_matmul_op",
deps = MATH_DEPS + if_mkl([
"//third_party/mkl:intel_binary_blob",
@@ -2879,6 +2882,9 @@ tf_kernel_library(
"mkl_matmul_op.cc",
]),
hdrs = ["matmul_op.h"],
+ # Override EIGEN_STRONG_INLINE to inline when --define=override_eigen_strong_inline=true,
+ # to avoid long compile times. See https://github.com/tensorflow/tensorflow/issues/10521
+ copts = if_override_eigen_strong_inline(["/DEIGEN_STRONG_INLINE=inline"]),
defines = select({
":xsmm": [
"TENSORFLOW_USE_LIBXSMM",
@@ -3248,8 +3254,7 @@ tf_kernel_library(
"//conditions:default": [],
}),
# Override EIGEN_STRONG_INLINE to inline when --define=override_eigen_strong_inline=true,
- # So that it doesn't take 20 minutes to compile conv_grad_ops_3d.cc and conv_ops_3d.cc
- # on Windows. See https://github.com/tensorflow/tensorflow/issues/10521
+ # to avoid long compile times. See https://github.com/tensorflow/tensorflow/issues/10521
copts = if_override_eigen_strong_inline(["/DEIGEN_STRONG_INLINE=inline"]),
defines = select({
":xsmm_convolutions": [
@@ -3395,6 +3400,9 @@ tf_kernel_library(
tf_kernel_library(
name = "lrn_op",
+ # Override EIGEN_STRONG_INLINE to inline when --define=override_eigen_strong_inline=true,
+ # to avoid long compile times. See https://github.com/tensorflow/tensorflow/issues/10521
+ copts = if_override_eigen_strong_inline(["/DEIGEN_STRONG_INLINE=inline"]),
prefix = "lrn_op",
deps = NN_DEPS,
)
diff --git a/tensorflow/core/kernels/batch_matmul_op_complex.cc b/tensorflow/core/kernels/batch_matmul_op_complex.cc
index 96216764fd..b77c80c01f 100644
--- a/tensorflow/core/kernels/batch_matmul_op_complex.cc
+++ b/tensorflow/core/kernels/batch_matmul_op_complex.cc
@@ -17,7 +17,7 @@ limitations under the License.
namespace tensorflow {
-#if !defined(INTEL_MKL)
+#if !defined(INTEL_MKL) || defined(DO_NOT_USE_ML)
TF_CALL_complex64(REGISTER_BATCH_MATMUL_CPU);
TF_CALL_complex128(REGISTER_BATCH_MATMUL_CPU);
#endif
diff --git a/tensorflow/core/kernels/batch_matmul_op_real.cc b/tensorflow/core/kernels/batch_matmul_op_real.cc
index 87a0795f2f..fe259c1634 100644
--- a/tensorflow/core/kernels/batch_matmul_op_real.cc
+++ b/tensorflow/core/kernels/batch_matmul_op_real.cc
@@ -21,7 +21,7 @@ limitations under the License.
namespace tensorflow {
-#if !defined(INTEL_MKL)
+#if !defined(INTEL_MKL) || defined(DO_NOT_USE_ML)
TF_CALL_float(REGISTER_BATCH_MATMUL_CPU);
TF_CALL_double(REGISTER_BATCH_MATMUL_CPU);
#endif
diff --git a/tensorflow/core/kernels/boosted_trees/stats_ops.cc b/tensorflow/core/kernels/boosted_trees/stats_ops.cc
index 53bdd482cb..48afd3fbf3 100644
--- a/tensorflow/core/kernels/boosted_trees/stats_ops.cc
+++ b/tensorflow/core/kernels/boosted_trees/stats_ops.cc
@@ -255,7 +255,7 @@ class BoostedTreesMakeStatsSummaryOp : public OpKernel {
// node_ids
const Tensor* node_ids_t;
OP_REQUIRES_OK(context, context->input("node_ids", &node_ids_t));
- const auto node_ids = node_ids_t->flat<int32>();
+ const auto node_ids = node_ids_t->vec<int32>();
// gradients
const Tensor* gradients_t;
OP_REQUIRES_OK(context, context->input("gradients", &gradients_t));
@@ -270,46 +270,34 @@ class BoostedTreesMakeStatsSummaryOp : public OpKernel {
&bucketized_features_list));
// Infer batch size.
const int64 batch_size = node_ids_t->dim_size(0);
- // Allocate output stats tensor (Rank 4).
- Tensor* output_stats_summary_t = nullptr;
- OP_REQUIRES_OK(context, context->allocate_output(
- "stats_summary",
- {num_features_, max_splits_, num_buckets_, 2},
- &output_stats_summary_t));
- auto output_stats_summary = output_stats_summary_t->flat<float>();
- EIGEN_STATIC_ASSERT(
- (static_cast<int>(decltype(output_stats_summary)::Layout) ==
- static_cast<int>(Eigen::RowMajor)),
- THIS_METHOD_IS_ONLY_FOR_ROW_MAJOR_MATRICES);
- const int shift_per_node = num_buckets_ * 2;
- const int shift_per_feature = shift_per_node * max_splits_;
- const int32 max_index = num_features_ * shift_per_feature;
- // We use double to sum the gradients and hessians, due to possible
- // precision loss when summing small float values.
- std::vector<double> res(max_index, 0);
+ // Allocate temporary stats tensor (Rank 4).
+ Tensor temp_stats_double_t;
+ OP_REQUIRES_OK(context, context->allocate_temp(
+ DT_DOUBLE,
+ {num_features_, max_splits_, num_buckets_, 2},
+ &temp_stats_double_t));
+ auto temp_stats_double = temp_stats_double_t.tensor<double, 4>();
+ temp_stats_double.setZero();
// Partition by node, and then bucketize.
- int feature_idx = 0;
- int feature_shift = 0;
- for (const Tensor& tensor : bucketized_features_list) {
- const auto& features = tensor.flat<int32>();
+ for (int feature_idx = 0; feature_idx < num_features_; ++feature_idx) {
+ const auto& features = bucketized_features_list[feature_idx].vec<int32>();
for (int i = 0; i < batch_size; ++i) {
const int32 node = node_ids(i);
const int32 bucket = features(i);
- // Calculate the index in the flattened vector for
- // [feature_idx][node][bucket][0].
- const int index = feature_shift + node * shift_per_node + bucket * 2;
- res[index] += gradients(i, 0);
- res[index + 1] += hessians(i, 0);
+ temp_stats_double(feature_idx, node, bucket, 0) += gradients(i, 0);
+ temp_stats_double(feature_idx, node, bucket, 1) += hessians(i, 0);
}
- ++feature_idx;
- feature_shift += shift_per_feature;
- }
- // Copy over the results.
- for (int i = 0; i < max_index; ++i) {
- output_stats_summary(i) = res[i];
}
+
+ // Copy temp tensor over to output tensor.
+ Tensor* output_stats_summary_t = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output(
+ "stats_summary", temp_stats_double_t.shape(),
+ &output_stats_summary_t));
+ output_stats_summary_t->tensor<float, 4>() =
+ temp_stats_double.template cast<float>();
}
private:
diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD
index da330e742e..6d2a04aa25 100644
--- a/tensorflow/core/kernels/data/BUILD
+++ b/tensorflow/core/kernels/data/BUILD
@@ -358,6 +358,7 @@ tf_kernel_library(
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
],
)
diff --git a/tensorflow/core/kernels/data/batch_dataset_op.cc b/tensorflow/core/kernels/data/batch_dataset_op.cc
index 9a83c16f33..58b86f2a08 100644
--- a/tensorflow/core/kernels/data/batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/batch_dataset_op.cc
@@ -27,7 +27,8 @@ namespace {
class BatchDatasetOp : public UnaryDatasetOpKernel {
public:
explicit BatchDatasetOp(OpKernelConstruction* ctx)
- : UnaryDatasetOpKernel(ctx) {}
+ : UnaryDatasetOpKernel(ctx),
+ op_version_(ctx->def().op() == "BatchDataset" ? 1 : 2) {}
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override {
@@ -38,14 +39,24 @@ class BatchDatasetOp : public UnaryDatasetOpKernel {
ctx, batch_size > 0,
errors::InvalidArgument("Batch size must be greater than zero."));
- *output = new Dataset(ctx, batch_size, input);
+ bool drop_remainder = false;
+ if (op_version_ > 1) {
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<bool>(ctx, "drop_remainder",
+ &drop_remainder));
+ }
+
+ *output = new Dataset(ctx, batch_size, drop_remainder, input);
}
private:
class Dataset : public GraphDatasetBase {
public:
- Dataset(OpKernelContext* ctx, int64 batch_size, const DatasetBase* input)
- : GraphDatasetBase(ctx), batch_size_(batch_size), input_(input) {
+ Dataset(OpKernelContext* ctx, int64 batch_size, bool drop_remainder,
+ const DatasetBase* input)
+ : GraphDatasetBase(ctx),
+ batch_size_(batch_size),
+ drop_remainder_(drop_remainder),
+ input_(input) {
input_->Ref();
// NOTE(mrry): Currently we implement "batch up to" semantics. If
@@ -54,8 +65,13 @@ class BatchDatasetOp : public UnaryDatasetOpKernel {
const auto& input_shapes = input_->output_shapes();
output_shapes_.reserve(input_shapes.size());
for (const auto& input_shape : input_shapes) {
- output_shapes_.emplace_back(
- PartialTensorShape({-1}).Concatenate(input_shape));
+ if (drop_remainder_) {
+ output_shapes_.emplace_back(
+ PartialTensorShape({batch_size_}).Concatenate(input_shape));
+ } else {
+ output_shapes_.emplace_back(
+ PartialTensorShape({-1}).Concatenate(input_shape));
+ }
}
}
@@ -86,8 +102,10 @@ class BatchDatasetOp : public UnaryDatasetOpKernel {
TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_graph_node));
Node* batch_size = nullptr;
TF_RETURN_IF_ERROR(b->AddScalar(batch_size_, &batch_size));
- TF_RETURN_IF_ERROR(
- b->AddDataset(this, {input_graph_node, batch_size}, output));
+ Node* drop_remainder = nullptr;
+ TF_RETURN_IF_ERROR(b->AddScalar(drop_remainder_, &drop_remainder));
+ TF_RETURN_IF_ERROR(b->AddDataset(
+ this, {input_graph_node, batch_size, drop_remainder}, output));
return Status::OK();
}
@@ -133,6 +151,12 @@ class BatchDatasetOp : public UnaryDatasetOpKernel {
return Status::OK();
}
+ if (dataset()->drop_remainder_ &&
+ batch_elements.size() < dataset()->batch_size_) {
+ *end_of_sequence = true;
+ return Status::OK();
+ }
+
// Copy the retrieved batch elements into one output tensor
// per tuple component.
// NOTE(mrry): If the input or output sizes are statically
@@ -201,14 +225,20 @@ class BatchDatasetOp : public UnaryDatasetOpKernel {
};
const int64 batch_size_;
+ const bool drop_remainder_;
const DatasetBase* const input_;
std::vector<PartialTensorShape> output_shapes_;
};
+
+ const int op_version_;
};
REGISTER_KERNEL_BUILDER(Name("BatchDataset").Device(DEVICE_CPU),
BatchDatasetOp);
+REGISTER_KERNEL_BUILDER(Name("BatchDatasetV2").Device(DEVICE_CPU),
+ BatchDatasetOp);
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc
index 9d9e74adba..d71cac4ebc 100644
--- a/tensorflow/core/kernels/data/iterator_ops.cc
+++ b/tensorflow/core/kernels/data/iterator_ops.cc
@@ -782,7 +782,7 @@ class OneShotIteratorOp : public AsyncOpKernel {
return;
}
}
- ProduceOutput(ctx, std::move(done));
+ ProduceOutput(ctx, done);
}
private:
@@ -803,9 +803,9 @@ class OneShotIteratorOp : public AsyncOpKernel {
}
for (auto&& ctx_done : callbacks_to_run) {
- ProduceOutput(ctx_done.first, std::move(ctx_done.second));
+ ProduceOutput(ctx_done.first, ctx_done.second);
}
- ProduceOutput(ctx, std::move(done));
+ ProduceOutput(ctx, done);
}
Status TryInit(OpKernelContext* ctx, IteratorResource** iterator,
diff --git a/tensorflow/core/kernels/data/padded_batch_dataset_op.cc b/tensorflow/core/kernels/data/padded_batch_dataset_op.cc
index d9e43ace39..59cbdb655d 100644
--- a/tensorflow/core/kernels/data/padded_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/padded_batch_dataset_op.cc
@@ -28,7 +28,8 @@ namespace {
class PaddedBatchDatasetOp : public UnaryDatasetOpKernel {
public:
explicit PaddedBatchDatasetOp(OpKernelConstruction* ctx)
- : UnaryDatasetOpKernel(ctx) {}
+ : UnaryDatasetOpKernel(ctx),
+ op_version_(ctx->def().op() == "PaddedBatchDataset" ? 1 : 2) {}
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override {
@@ -39,6 +40,12 @@ class PaddedBatchDatasetOp : public UnaryDatasetOpKernel {
ctx, batch_size > 0,
errors::InvalidArgument("Batch size must be greater than zero."));
+ bool drop_remainder = false;
+ if (op_version_ > 1) {
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<bool>(ctx, "drop_remainder",
+ &drop_remainder));
+ }
+
OpInputList padded_shape_tensors;
OP_REQUIRES_OK(ctx,
ctx->input_list("padded_shapes", &padded_shape_tensors));
@@ -85,18 +92,20 @@ class PaddedBatchDatasetOp : public UnaryDatasetOpKernel {
padding_values.push_back(tensor::DeepCopy(padding_value_t));
}
- *output = new Dataset(ctx, batch_size, std::move(padded_shapes),
- std::move(padding_values), input);
+ *output =
+ new Dataset(ctx, batch_size, drop_remainder, std::move(padded_shapes),
+ std::move(padding_values), input);
}
private:
class Dataset : public GraphDatasetBase {
public:
- Dataset(OpKernelContext* ctx, int64 batch_size,
+ Dataset(OpKernelContext* ctx, int64 batch_size, bool drop_remainder,
std::vector<PartialTensorShape> padded_shapes,
std::vector<Tensor> padding_values, const DatasetBase* input)
: GraphDatasetBase(ctx),
batch_size_(batch_size),
+ drop_remainder_(drop_remainder),
padded_shapes_(std::move(padded_shapes)),
padding_values_(std::move(padding_values)),
input_(input) {
@@ -112,8 +121,13 @@ class PaddedBatchDatasetOp : public UnaryDatasetOpKernel {
const auto& input_shapes = input_->output_shapes();
output_shapes_.reserve(input_shapes.size());
for (size_t i = 0; i < input_shapes.size(); ++i) {
- output_shapes_.push_back(
- PartialTensorShape({-1}).Concatenate(padded_shapes_[i]));
+ if (drop_remainder_) {
+ output_shapes_.push_back(
+ PartialTensorShape({batch_size_}).Concatenate(padded_shapes_[i]));
+ } else {
+ output_shapes_.push_back(
+ PartialTensorShape({-1}).Concatenate(padded_shapes_[i]));
+ }
}
}
@@ -166,16 +180,19 @@ class PaddedBatchDatasetOp : public UnaryDatasetOpKernel {
padding_values.emplace_back(node);
}
+ Node* drop_remainder = nullptr;
+ TF_RETURN_IF_ERROR(b->AddScalar(drop_remainder_, &drop_remainder));
+
AttrValue output_types;
b->BuildAttrValue(output_dtypes(), &output_types);
AttrValue N;
b->BuildAttrValue<int64>(padded_shapes_.size(), &N);
- TF_RETURN_IF_ERROR(
- b->AddDataset(this, {{0, input_graph_node}, {1, batch_size}},
- {{2, padded_shapes}, {3, padding_values}},
- {{"Toutput_types", output_types}, {"N", N}}, output));
+ TF_RETURN_IF_ERROR(b->AddDataset(
+ this, {{0, input_graph_node}, {1, batch_size}, {4, drop_remainder}},
+ {{2, padded_shapes}, {3, padding_values}},
+ {{"Toutput_types", output_types}, {"N", N}}, output));
return Status::OK();
}
@@ -226,6 +243,12 @@ class PaddedBatchDatasetOp : public UnaryDatasetOpKernel {
return Status::OK();
}
+ if (dataset()->drop_remainder_ &&
+ batch_elements.size() < dataset()->batch_size_) {
+ *end_of_sequence = true;
+ return Status::OK();
+ }
+
// Copy the retrieved batch elements into one output tensor
// per tuple component.
// NOTE(mrry): If the input or output sizes are statically
@@ -341,16 +364,22 @@ class PaddedBatchDatasetOp : public UnaryDatasetOpKernel {
};
const int64 batch_size_;
+ const bool drop_remainder_;
const std::vector<PartialTensorShape> padded_shapes_;
const std::vector<Tensor> padding_values_;
const DatasetBase* const input_;
std::vector<PartialTensorShape> output_shapes_;
};
+
+ const int op_version_;
};
REGISTER_KERNEL_BUILDER(Name("PaddedBatchDataset").Device(DEVICE_CPU),
PaddedBatchDatasetOp);
+REGISTER_KERNEL_BUILDER(Name("PaddedBatchDatasetV2").Device(DEVICE_CPU),
+ PaddedBatchDatasetOp);
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/shuffle_dataset_op.cc b/tensorflow/core/kernels/data/shuffle_dataset_op.cc
index 3438199ebd..b859295fa4 100644
--- a/tensorflow/core/kernels/data/shuffle_dataset_op.cc
+++ b/tensorflow/core/kernels/data/shuffle_dataset_op.cc
@@ -61,10 +61,12 @@ class ShuffleDatasetOpBase : public UnaryDatasetOpKernel {
}
protected:
- class Iterator : public DatasetIterator<ShuffleDatasetBase> {
+ template <class T>
+ class Iterator : public DatasetIterator<T> {
public:
- explicit Iterator(const Params& params, int64 seed, int64 seed2)
- : DatasetIterator<ShuffleDatasetBase>(params),
+ explicit Iterator(const typename DatasetIterator<T>::Params& params,
+ int64 seed, int64 seed2)
+ : DatasetIterator<T>(params),
input_impl_(nullptr),
seed_(seed),
seed2_(seed2),
@@ -85,26 +87,28 @@ class ShuffleDatasetOpBase : public UnaryDatasetOpKernel {
bool first_call = false;
if (!input_impl_ && epoch_ == 0) {
first_call = true;
- TF_RETURN_IF_ERROR(
- dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
+ TF_RETURN_IF_ERROR(this->dataset()->input_->MakeIterator(
+ ctx, this->prefix(), &input_impl_));
}
- while (input_impl_ && num_elements_ < dataset()->buffer_size_) {
+ while (input_impl_ && num_elements_ < this->dataset()->buffer_size_) {
if (ctx->env()->NowMicros() >
((num_log_entries + 1) * kLogIntervalMicros) + start_micros) {
num_log_entries++;
LOG(INFO) << "Filling up shuffle buffer (this may take a while): "
- << num_elements_ << " of " << dataset()->buffer_size_;
+ << num_elements_ << " of "
+ << this->dataset()->buffer_size_;
}
std::vector<Tensor> input_element;
bool end_of_input_sequence = false;
- while (dataset()->count_ == -1 || epoch_ < dataset()->count_) {
+ while (this->dataset()->count_ == -1 ||
+ epoch_ < this->dataset()->count_) {
TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, &input_element,
&end_of_input_sequence));
if (!end_of_input_sequence) {
first_call = false;
break;
}
- if (first_call && dataset()->count_ == -1) {
+ if (first_call && this->dataset()->count_ == -1) {
// If the first call to GetNext() fails because the end
// of sequence has been reached, we terminate the
// iteration immediately. (Otherwise, this iterator
@@ -115,11 +119,11 @@ class ShuffleDatasetOpBase : public UnaryDatasetOpKernel {
epoch_++;
int64 n = slices_.back()->end;
slices_.emplace_back(new Slice{n, n});
- TF_RETURN_IF_ERROR(
- dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
+ TF_RETURN_IF_ERROR(this->dataset()->input_->MakeIterator(
+ ctx, this->prefix(), &input_impl_));
}
if (!end_of_input_sequence) {
- buffer_[slices_.back()->end % dataset()->buffer_size_] =
+ buffer_[slices_.back()->end % this->dataset()->buffer_size_] =
std::move(input_element);
num_elements_++;
slices_.back()->end++;
@@ -144,10 +148,11 @@ class ShuffleDatasetOpBase : public UnaryDatasetOpKernel {
int64 offset =
Random() % (slices_.front()->end - slices_.front()->start);
int64 index =
- (slices_.front()->start + offset) % dataset()->buffer_size_;
+ (slices_.front()->start + offset) % this->dataset()->buffer_size_;
*out_tensors = std::move(buffer_[index]);
- std::swap(buffer_[index],
- buffer_[slices_.front()->start % dataset()->buffer_size_]);
+ std::swap(
+ buffer_[index],
+ buffer_[slices_.front()->start % this->dataset()->buffer_size_]);
slices_.front()->start++;
num_elements_--;
} else {
@@ -160,40 +165,44 @@ class ShuffleDatasetOpBase : public UnaryDatasetOpKernel {
protected:
Status SaveInternal(IteratorStateWriter* writer) override {
mutex_lock l(mu_);
-
// Save state needed to restore the random number generators.
- TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("num_random_samples"),
- num_random_samples_));
+ TF_RETURN_IF_ERROR(writer->WriteScalar(
+ this->full_name("num_random_samples"), num_random_samples_));
+ TF_RETURN_IF_ERROR(writer->WriteScalar(this->full_name("seed"), seed_));
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(this->full_name("seed2"), seed2_));
// Save input iterator if it hasn't been exhausted else write
// "end_of_input_sequence".
if (!input_impl_) {
- TF_RETURN_IF_ERROR(
- writer->WriteScalar(full_name("end_of_input_sequence"), ""));
+ TF_RETURN_IF_ERROR(writer->WriteScalar(
+ this->full_name("end_of_input_sequence"), ""));
} else {
- TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
+ TF_RETURN_IF_ERROR(this->SaveParent(writer, input_impl_));
}
// Save the epoch counter, buffer, and buffer slices.
- TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("epoch"), epoch_));
- TF_RETURN_IF_ERROR(
- writer->WriteScalar(full_name("num_elements"), num_elements_));
TF_RETURN_IF_ERROR(
- writer->WriteScalar(full_name("slices_size"), slices_.size()));
+ writer->WriteScalar(this->full_name("epoch"), epoch_));
+ TF_RETURN_IF_ERROR(writer->WriteScalar(this->full_name("num_elements"),
+ num_elements_));
+ TF_RETURN_IF_ERROR(writer->WriteScalar(this->full_name("slices_size"),
+ slices_.size()));
for (size_t i = 0; i < slices_.size(); ++i) {
TF_RETURN_IF_ERROR(writer->WriteScalar(
- full_name(strings::StrCat("slices_start_", i)),
+ this->full_name(strings::StrCat("slices_start_", i)),
slices_[i]->start));
TF_RETURN_IF_ERROR(writer->WriteScalar(
- full_name(strings::StrCat("slices_end_", i)), slices_[i]->end));
+ this->full_name(strings::StrCat("slices_end_", i)),
+ slices_[i]->end));
for (size_t j = slices_[i]->start; j < slices_[i]->end; ++j) {
- size_t index = j % dataset()->buffer_size_;
+ size_t index = j % this->dataset()->buffer_size_;
TF_RETURN_IF_ERROR(writer->WriteScalar(
- full_name(strings::StrCat("buffer_", index, "_size")),
+ this->full_name(strings::StrCat("buffer_", index, "_size")),
buffer_[index].size()));
for (size_t k = 0; k < buffer_[index].size(); ++k) {
TF_RETURN_IF_ERROR(writer->WriteTensor(
- full_name(strings::StrCat("buffer_", index, "_", k)),
+ this->full_name(strings::StrCat("buffer_", index, "_", k)),
buffer_[index][k]));
}
}
@@ -205,51 +214,54 @@ class ShuffleDatasetOpBase : public UnaryDatasetOpKernel {
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
mutex_lock l(mu_);
-
// Restore the random number generators.
- TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("num_random_samples"),
- &num_random_samples_));
+ TF_RETURN_IF_ERROR(reader->ReadScalar(
+ this->full_name("num_random_samples"), &num_random_samples_));
+ TF_RETURN_IF_ERROR(reader->ReadScalar(this->full_name("seed"), &seed_));
+ TF_RETURN_IF_ERROR(
+ reader->ReadScalar(this->full_name("seed2"), &seed2_));
ResetRngs();
// Restore the input iterator if it wasn't already exhausted.
- if (!reader->Contains(full_name("end_of_input_sequence"))) {
- TF_RETURN_IF_ERROR(
- dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
- TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
+ if (!reader->Contains(this->full_name("end_of_input_sequence"))) {
+ TF_RETURN_IF_ERROR(this->dataset()->input_->MakeIterator(
+ ctx, this->prefix(), &input_impl_));
+ TF_RETURN_IF_ERROR(this->RestoreParent(ctx, reader, input_impl_));
} else {
input_impl_.reset();
}
// Restore the epoch counter, buffer, and buffer slices.
- TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("epoch"), &epoch_));
TF_RETURN_IF_ERROR(
- reader->ReadScalar(full_name("num_elements"), &num_elements_));
+ reader->ReadScalar(this->full_name("epoch"), &epoch_));
+ TF_RETURN_IF_ERROR(reader->ReadScalar(this->full_name("num_elements"),
+ &num_elements_));
size_t slices_size;
{
int64 temp;
TF_RETURN_IF_ERROR(
- reader->ReadScalar(full_name("slices_size"), &temp));
+ reader->ReadScalar(this->full_name("slices_size"), &temp));
slices_size = static_cast<size_t>(temp);
}
- buffer_.reset(new std::vector<Tensor>[dataset()->buffer_size_]);
+ buffer_.reset(new std::vector<Tensor>[this->dataset()->buffer_size_]);
for (size_t i = 0; i < slices_size; ++i) {
int64 start;
TF_RETURN_IF_ERROR(reader->ReadScalar(
- full_name(strings::StrCat("slices_start_", i)), &start));
+ this->full_name(strings::StrCat("slices_start_", i)), &start));
int64 end;
TF_RETURN_IF_ERROR(reader->ReadScalar(
- full_name(strings::StrCat("slices_end_", i)), &end));
+ this->full_name(strings::StrCat("slices_end_", i)), &end));
slices_.emplace_back(new Slice{start, end});
for (size_t j = start; j < end; ++j) {
- size_t index = j % dataset()->buffer_size_;
+ size_t index = j % this->dataset()->buffer_size_;
int64 list_size;
TF_RETURN_IF_ERROR(reader->ReadScalar(
- full_name(strings::StrCat("buffer_", index, "_size")),
+ this->full_name(strings::StrCat("buffer_", index, "_size")),
&list_size));
buffer_[index] = std::vector<Tensor>(list_size);
for (int k = 0; k < list_size; ++k) {
TF_RETURN_IF_ERROR(reader->ReadTensor(
- full_name(strings::StrCat("buffer_", index, "_", k)),
+ this->full_name(strings::StrCat("buffer_", index, "_", k)),
&buffer_[index][k]));
}
}
@@ -289,8 +301,8 @@ class ShuffleDatasetOpBase : public UnaryDatasetOpKernel {
mutex mu_;
std::unique_ptr<std::vector<Tensor>[]> buffer_ GUARDED_BY(mu_);
std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
- const int64 seed_ GUARDED_BY(mu_);
- const int64 seed2_ GUARDED_BY(mu_);
+ int64 seed_ GUARDED_BY(mu_);
+ int64 seed2_ GUARDED_BY(mu_);
int64 epoch_ GUARDED_BY(mu_);
int64 num_elements_ GUARDED_BY(mu_);
std::deque<std::unique_ptr<Slice>> slices_ GUARDED_BY(mu_);
@@ -360,6 +372,7 @@ class ShuffleDatasetOp : public ShuffleDatasetOpBase {
generator_(&parent_generator_) {}
string DebugString() const override {
+ mutex_lock l(mu_);
return strings::StrCat("ShuffleDatasetOp(", buffer_size_, ", ", seed_,
", ", seed2_, ")::ReshufflingDataset");
}
@@ -370,38 +383,96 @@ class ShuffleDatasetOp : public ShuffleDatasetOpBase {
int64 iterator_seed2;
{
mutex_lock l(mu_);
- iterator_seed = generator_();
- iterator_seed2 = generator_();
+ iterator_seed = Random();
+ iterator_seed2 = Random();
}
- return std::unique_ptr<IteratorBase>(new ShuffleDatasetBase::Iterator(
- {this, strings::StrCat(prefix, "::Shuffle")}, iterator_seed,
- iterator_seed2));
+ return std::unique_ptr<IteratorBase>(
+ new Iterator({this, strings::StrCat(prefix, "::Shuffle")},
+ iterator_seed, iterator_seed2));
}
protected:
+ class Iterator : public ShuffleDatasetBase::Iterator<ReshufflingDataset> {
+ public:
+ explicit Iterator(const Params& params, int64 seed, int64 seed2)
+ : ShuffleDatasetBase::Iterator<ReshufflingDataset>(params, seed,
+ seed2) {}
+
+ protected:
+ Status SaveInternal(IteratorStateWriter* writer) override {
+ mutex_lock l(dataset()->mu_);
+
+ // Save RNG state of Dataset.
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(full_name("ds_num_random_samples"),
+ dataset()->num_random_samples_));
+
+ // Save the Iterator.
+ return ShuffleDatasetBase::Iterator<ReshufflingDataset>::SaveInternal(
+ writer);
+ }
+
+ Status RestoreInternal(IteratorContext* ctx,
+ IteratorStateReader* reader) override {
+ mutex_lock l(dataset()->mu_);
+
+ // Restore RNG state of Dataset.
+ TF_RETURN_IF_ERROR(
+ reader->ReadScalar(full_name("ds_num_random_samples"),
+ &dataset()->num_random_samples_));
+ dataset()->ResetRngs();
+
+ // Restore the Iterator.
+ return ShuffleDatasetBase::Iterator<
+ ReshufflingDataset>::RestoreInternal(ctx, reader);
+ }
+ };
+
Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
Node** output) const override {
- return errors::Unimplemented(
- "Checkpointing ShufflingDataset with reshuffle_each_iteration=true "
- "is not supported.\n"
- "If you have a ds.shuffle(buffer_size).repeat(count) in your input "
- "pipeline, replace it with "
- "ds.apply(tf.contrib.data.shuffle_and_repeat(buffer_size, count)).\n"
- "If you iterate over your dataset once, change shuffle(buffer_size) "
- "to shuffle(buffer_size, reshuffle_each_iteration=False).\n"
- "If you are using Dataset.list_files(pattern), change it to "
- "Dataset.list_files(pattern, shuffle=False) and manually shuffle "
- "the list of files using shuffle_and_repeat as above or using "
- "ds.shuffle with reshuffle_each_iteration=False.");
+ mutex_lock l(mu_);
+ Node* input_graph_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_graph_node));
+ Node* buffer_size = nullptr;
+ Node* seed = nullptr;
+ Node* seed2 = nullptr;
+ AttrValue reshuffle_each_iteration;
+
+ TF_RETURN_IF_ERROR(b->AddScalar(buffer_size_, &buffer_size));
+ TF_RETURN_IF_ERROR(b->AddScalar(seed_, &seed));
+ TF_RETURN_IF_ERROR(b->AddScalar(seed2_, &seed2));
+ b->BuildAttrValue(true, &reshuffle_each_iteration);
+ TF_RETURN_IF_ERROR(b->AddDataset(
+ this, {input_graph_node, buffer_size, seed, seed2}, // Inputs
+ {std::make_pair("reshuffle_each_iteration",
+ reshuffle_each_iteration)}, // Attrs
+ output));
+ return Status::OK();
}
private:
- const int64 seed_;
- const int64 seed2_;
+ random::SingleSampleAdapter<random::PhiloxRandom>::ResultType Random() const
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ num_random_samples_++;
+ auto out = generator_();
+ return out;
+ }
+
+ void ResetRngs() const EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ // Reset the generators based on the current seeds.
+ parent_generator_ = random::PhiloxRandom(seed_, seed2_);
+ generator_ =
+ random::SingleSampleAdapter<random::PhiloxRandom>(&parent_generator_);
+ generator_.Skip(num_random_samples_);
+ }
+
+ mutable int64 seed_ GUARDED_BY(mu_);
+ mutable int64 seed2_ GUARDED_BY(mu_);
mutable mutex mu_;
mutable random::PhiloxRandom parent_generator_ GUARDED_BY(mu_);
mutable random::SingleSampleAdapter<random::PhiloxRandom> generator_
GUARDED_BY(mu_);
+ mutable int64 num_random_samples_ GUARDED_BY(mu_) = 0;
};
// A dataset that uses the same fixed seed for all iterators created from it.
@@ -421,8 +492,9 @@ class ShuffleDatasetOp : public ShuffleDatasetOpBase {
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
- return std::unique_ptr<IteratorBase>(new ShuffleDatasetBase::Iterator(
- {this, strings::StrCat(prefix, "::Shuffle")}, seed_, seed2_));
+ return std::unique_ptr<IteratorBase>(
+ new ShuffleDatasetBase::Iterator<ShuffleDatasetBase>(
+ {this, strings::StrCat(prefix, "::Shuffle")}, seed_, seed2_));
}
protected:
@@ -504,9 +576,10 @@ class ShuffleAndRepeatDatasetOp : public ShuffleDatasetOpBase {
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
- return std::unique_ptr<IteratorBase>(new ShuffleDatasetBase::Iterator(
- {this, strings::StrCat(prefix, "::ShuffleAndRepeat")}, seed_,
- seed2_));
+ return std::unique_ptr<IteratorBase>(
+ new ShuffleDatasetBase::Iterator<ShuffleDatasetBase>(
+ {this, strings::StrCat(prefix, "::ShuffleAndRepeat")}, seed_,
+ seed2_));
}
protected:
diff --git a/tensorflow/core/kernels/data/stats_dataset_ops.cc b/tensorflow/core/kernels/data/stats_dataset_ops.cc
index 7370a24b38..3e0a6ae049 100644
--- a/tensorflow/core/kernels/data/stats_dataset_ops.cc
+++ b/tensorflow/core/kernels/data/stats_dataset_ops.cc
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include "tensorflow/core/example/example.pb.h"
+#include "tensorflow/core/example/feature.pb.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/stats_aggregator.h"
#include "tensorflow/core/framework/tensor.h"
@@ -234,6 +236,189 @@ class BytesProducedStatsDatasetOp : public UnaryDatasetOpKernel {
};
};
+class FeatureStatsDatasetOp : public UnaryDatasetOpKernel {
+ public:
+ explicit FeatureStatsDatasetOp(OpKernelConstruction* ctx)
+ : UnaryDatasetOpKernel(ctx) {}
+
+ void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
+ DatasetBase** output) override {
+ string tag;
+ OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "tag", &tag));
+ OP_REQUIRES(ctx, input->output_dtypes()[0] == DT_STRING,
+ errors::InvalidArgument("FeatureStatsDataset only supports "
+ "input with a single `tf.string` "
+ "component."));
+ *output = new Dataset(ctx, input, std::move(tag));
+ }
+
+ private:
+ class Dataset : public GraphDatasetBase {
+ public:
+ explicit Dataset(OpKernelContext* ctx, const DatasetBase* input, string tag)
+ : GraphDatasetBase(ctx), input_(input), tag_(std::move(tag)) {
+ input_->Ref();
+ }
+
+ ~Dataset() override { input_->Unref(); }
+
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
+ const string& prefix) const override {
+ return std::unique_ptr<IteratorBase>(new Iterator(
+ {this, strings::StrCat(prefix, "::FeatureStatsDataset")}));
+ }
+
+ const DataTypeVector& output_dtypes() const override {
+ return input_->output_dtypes();
+ }
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ return input_->output_shapes();
+ }
+
+ string DebugString() const override {
+ return "FeatureStatsDatasetOp::Dataset";
+ }
+
+ protected:
+ Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
+ Node** output) const override {
+ Node* input_node;
+ TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_node));
+ Node* tag_node;
+ TF_RETURN_IF_ERROR(b->AddScalar(tag_, &tag_node));
+ TF_RETURN_IF_ERROR(b->AddDataset(this, {input_node, tag_node}, output));
+ return Status::OK();
+ }
+
+ private:
+ class Iterator : public DatasetIterator<Dataset> {
+ public:
+ explicit Iterator(const Params& params)
+ : DatasetIterator<Dataset>(params) {}
+
+ Status Initialize(IteratorContext* ctx) override {
+ return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ }
+
+ Status GetNextInternal(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override {
+ tf_shared_lock l(mu_);
+ Status s = input_impl_->GetNext(ctx, out_tensors, end_of_sequence);
+ auto stats_aggregator = ctx->stats_aggregator();
+ if (stats_aggregator && s.ok() && !*end_of_sequence) {
+ for (const Tensor& t : *out_tensors) {
+ auto record_t = t.flat<string>();
+ Example example;
+ // TODO(shivaniagrawal): redundant parsing here; potential solutions
+ // to improve performance are to a) have a potential
+ // ParseExampleDataset and collect stats from there, and b) make
+ // changes to parse_example() so that it returns stats as well.
+ for (int i = 0; i < record_t.size(); ++i) {
+ if (example.ParseFromString(record_t(i))) {
+ AddStatsFeatures(example, stats_aggregator);
+ } else {
+ SequenceExample sequence_example;
+ if (sequence_example.ParseFromString(record_t(i))) {
+ AddStatsFeatures(sequence_example, stats_aggregator);
+ }
+ }
+ }
+ }
+ }
+ return s;
+ }
+
+ // TODO(shivaniagrawal): Add features/feature-values to streamz metrics.
+ int AddStatsFeatureValues(const Feature& feature) {
+ int feature_values_list_size = 0;
+ switch (feature.kind_case()) {
+ case Feature::kBytesList: {
+ feature_values_list_size = feature.bytes_list().value().size();
+ break;
+ }
+ case Feature::kFloatList: {
+ feature_values_list_size = feature.float_list().value().size();
+ break;
+ }
+ case Feature::kInt64List: {
+ feature_values_list_size = feature.int64_list().value().size();
+ break;
+ }
+ case Feature::KIND_NOT_SET:
+ break;
+ }
+ return feature_values_list_size;
+ }
+
+ void AddStatsFeatures(
+ const Example& example,
+ const std::shared_ptr<StatsAggregator>& stats_aggregator) {
+ stats_aggregator->AddToHistogram(
+ strings::StrCat(dataset()->tag_, ":features"),
+ {static_cast<double>(example.features().feature().size())});
+
+ int feature_values_list_size_sum = 0;
+ for (const auto& feature : example.features().feature()) {
+ feature_values_list_size_sum += AddStatsFeatureValues(feature.second);
+ }
+ stats_aggregator->AddToHistogram(
+ strings::StrCat(dataset()->tag_, ":feature-values"),
+ {static_cast<double>(feature_values_list_size_sum)});
+ }
+
+ void AddStatsFeatures(
+ const SequenceExample& example,
+ const std::shared_ptr<StatsAggregator>& stats_aggregator) {
+ stats_aggregator->AddToHistogram(
+ strings::StrCat(dataset()->tag_, ":features"),
+ {static_cast<double>(
+ example.context().feature().size() +
+ example.feature_lists().feature_list().size())});
+
+ int feature_values_list_size_sum = 0;
+ for (const auto& feature : example.context().feature()) {
+ feature_values_list_size_sum += AddStatsFeatureValues(feature.second);
+ }
+
+ for (const auto& feature_list :
+ example.feature_lists().feature_list()) {
+ for (const auto& feature : feature_list.second.feature()) {
+ feature_values_list_size_sum += AddStatsFeatureValues(feature);
+ }
+ }
+
+ stats_aggregator->AddToHistogram(
+ strings::StrCat(dataset()->tag_, ":feature-values"),
+ {static_cast<double>(feature_values_list_size_sum)});
+ }
+
+ protected:
+ Status SaveInternal(IteratorStateWriter* writer) override {
+ mutex_lock l(mu_);
+ TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
+ return Status::OK();
+ }
+
+ Status RestoreInternal(IteratorContext* ctx,
+ IteratorStateReader* reader) override {
+ mutex_lock l(mu_);
+ TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
+ return Status::OK();
+ }
+
+ private:
+ mutex mu_;
+ std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
+ };
+
+ const DatasetBase* const input_;
+ const string tag_;
+ };
+};
+
+REGISTER_KERNEL_BUILDER(Name("FeatureStatsDataset").Device(DEVICE_CPU),
+ FeatureStatsDatasetOp);
REGISTER_KERNEL_BUILDER(Name("LatencyStatsDataset").Device(DEVICE_CPU),
LatencyStatsDatasetOp);
REGISTER_KERNEL_BUILDER(Name("BytesProducedStatsDataset").Device(DEVICE_CPU),
diff --git a/tensorflow/core/kernels/functional_ops.cc b/tensorflow/core/kernels/functional_ops.cc
index e0d594fa25..519c475332 100644
--- a/tensorflow/core/kernels/functional_ops.cc
+++ b/tensorflow/core/kernels/functional_ops.cc
@@ -152,7 +152,7 @@ class IfOp : public AsyncOpKernel {
: kernel_(kernel),
ctx_(ctx),
cond_(cond),
- done_(done),
+ done_(std::move(done)),
lib_(CHECK_NOTNULL(ctx_->function_library())) {
SetRunOptions(ctx_, &opts_, true /* always_collect_stats */);
for (int i = 1; i < ctx_->num_inputs(); ++i) {
@@ -174,9 +174,9 @@ class IfOp : public AsyncOpKernel {
s = SetOutputs(kernel_, ctx_, rets_);
}
ctx_->SetStatus(s);
- auto done = done_;
+ DoneCallback captured_done(std::move(done_));
delete this;
- done();
+ captured_done();
});
}
@@ -184,7 +184,7 @@ class IfOp : public AsyncOpKernel {
IfOp* const kernel_;
OpKernelContext* const ctx_;
const bool cond_;
- const DoneCallback done_;
+ DoneCallback done_;
FunctionLibraryRuntime* const lib_;
FunctionLibraryRuntime::Options opts_;
TensorVec args_;
@@ -257,7 +257,7 @@ class WhileOp : public AsyncOpKernel {
ctx_(ctx),
cond_handle_(cond_handle),
body_handle_(body_handle),
- done_(done),
+ done_(std::move(done)),
lib_(CHECK_NOTNULL(ctx_->function_library())) {
SetRunOptions(ctx_, &opts_, false /* always_collect_stats */);
for (int i = 0; i < ctx_->num_inputs(); ++i) {
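
The change above moves the DoneCallback into the state object and then moves it back out into a local before `delete this`, so the callback is neither copied nor invoked through a freed object. A small standalone sketch of that pattern, assuming a plain std::function callback rather than the real DoneCallback type:

```
#include <functional>
#include <iostream>
#include <utility>

// Sketch of the "move the callback out before delete this" pattern: the
// callback must outlive the state object, so it is moved into a local
// before the object destroys itself.
class State {
 public:
  explicit State(std::function<void()> done) : done_(std::move(done)) {}

  void Finish() {
    std::function<void()> captured_done(std::move(done_));
    delete this;      // `this` is gone; only the local survives.
    captured_done();  // Safe: does not touch any member of the deleted object.
  }

 private:
  ~State() = default;  // Heap-only: must be destroyed via Finish().
  std::function<void()> done_;
};

int main() {
  auto* s = new State([] { std::cout << "done\n"; });
  s->Finish();
  return 0;
}
```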
diff --git a/tensorflow/core/kernels/matmul_op.cc b/tensorflow/core/kernels/matmul_op.cc
index f9c15ce6d7..fc3b3d3445 100644
--- a/tensorflow/core/kernels/matmul_op.cc
+++ b/tensorflow/core/kernels/matmul_op.cc
@@ -551,7 +551,8 @@ struct MatMulFunctor<SYCLDevice, T> {
.Label("cublas"), \
MatMulOp<GPUDevice, T, true /* cublas */>)
-#if defined(INTEL_MKL)
+#if defined(INTEL_MKL) && !defined(DO_NOT_USE_ML)
+
// MKL does not support half and int32 types for matrix-multiplication, so
// register the kernel to use default Eigen based implementations for these
// types. Registration for NO-LABEL version is in mkl_matmul_op.cc
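
Several files in this commit change the guard from `#ifdef INTEL_MKL` to `#if defined(INTEL_MKL) && !defined(DO_NOT_USE_ML)`. A runnable sketch of how that guard behaves when DO_NOT_USE_ML is defined (the macros are set inline here purely for illustration):

```
#include <cstdio>

// Illustrative only: in a real build these come from the build configuration.
#define INTEL_MKL 1
#define DO_NOT_USE_ML 1

int main() {
#if defined(INTEL_MKL) && !defined(DO_NOT_USE_ML)
  std::puts("legacy MKL-ML kernel registered");
#else
  std::puts("MKL-ML kernel skipped; MKL-DNN/Eigen path used");
#endif
  return 0;
}
```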
diff --git a/tensorflow/core/kernels/mkl_aggregate_ops.cc b/tensorflow/core/kernels/mkl_aggregate_ops.cc
index b539b00009..4ad858e4a9 100644
--- a/tensorflow/core/kernels/mkl_aggregate_ops.cc
+++ b/tensorflow/core/kernels/mkl_aggregate_ops.cc
@@ -24,15 +24,16 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/logging.h"
-#include "mkl_dnn.h"
-#include "mkl_dnn_types.h"
-#include "tensorflow/core/util/mkl_util.h"
#ifndef INTEL_MKL_ML
#include "mkldnn.hpp"
using mkldnn::stream;
using mkldnn::sum;
+#else
+#include "mkl_dnn.h"
+#include "mkl_dnn_types.h"
#endif
+#include "tensorflow/core/util/mkl_util.h"
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
@@ -333,7 +334,7 @@ class MklAddNOp : public OpKernel {
if (!input1_in_mkl_format && src1_dims_size == 0) {
Tensor* dst_tensor = nullptr;
- MklShape mkl_shape_dst;
+ MklDnnShape mkl_shape_dst;
mkl_shape_dst.SetMklTensor(false);
AllocateOutputSetMklShape(ctx, output_idx, &dst_tensor,
src1_tensor.shape(), mkl_shape_dst);
@@ -347,7 +348,7 @@ class MklAddNOp : public OpKernel {
if (!input1_in_mkl_format && !input2_in_mkl_format) {
if (src1_tensor.shape().num_elements() == 0) {
Tensor* dst_tensor = nullptr;
- MklShape mkl_shape_dst;
+ MklDnnShape mkl_shape_dst;
mkl_shape_dst.SetMklTensor(false);
AllocateOutputSetMklShape(ctx, output_idx, &dst_tensor,
src1_tensor.shape(), mkl_shape_dst);
diff --git a/tensorflow/core/kernels/mkl_batch_matmul_op.cc b/tensorflow/core/kernels/mkl_batch_matmul_op.cc
index 723b445a75..45328b03d6 100644
--- a/tensorflow/core/kernels/mkl_batch_matmul_op.cc
+++ b/tensorflow/core/kernels/mkl_batch_matmul_op.cc
@@ -25,7 +25,7 @@ limitations under the License.
#define EIGEN_USE_THREADS
-#if defined(INTEL_MKL)
+#if defined(INTEL_MKL) && !defined(DO_NOT_USE_ML)
#include <vector>
#include "mkl_cblas.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
diff --git a/tensorflow/core/kernels/mkl_concat_op.cc b/tensorflow/core/kernels/mkl_concat_op.cc
index a9b952095d..31d1b949ef 100644
--- a/tensorflow/core/kernels/mkl_concat_op.cc
+++ b/tensorflow/core/kernels/mkl_concat_op.cc
@@ -27,16 +27,17 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"
-#include "mkl_dnn.h"
-#include "mkl_dnn_types.h"
-#include "tensorflow/core/util/mkl_util.h"
#ifndef INTEL_MKL_ML
#include "mkldnn.hpp"
using mkldnn::concat;
using mkldnn::stream;
+#else
+#include "mkl_dnn.h"
+#include "mkl_dnn_types.h"
#endif
+#include "tensorflow/core/util/mkl_util.h"
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
diff --git a/tensorflow/core/kernels/mkl_conv_grad_bias_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_bias_ops.cc
index a6698a1a07..f857be6c32 100644
--- a/tensorflow/core/kernels/mkl_conv_grad_bias_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_grad_bias_ops.cc
@@ -39,8 +39,10 @@ limitations under the License.
#include "tensorflow/core/util/use_cudnn.h"
#include "tensorflow/core/util/work_sharder.h"
+#ifdef INTEL_MKL_ML
#include "mkl_dnn.h"
#include "mkl_dnn_types.h"
+#endif
#include "tensorflow/core/util/mkl_util.h"
namespace tensorflow {
diff --git a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
index e0706568b1..356eed8b67 100644
--- a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
@@ -38,9 +38,6 @@ limitations under the License.
#include "tensorflow/core/util/use_cudnn.h"
#include "tensorflow/core/util/work_sharder.h"
-#include "mkl_dnn.h"
-#include "mkl_dnn_types.h"
-#include "tensorflow/core/util/mkl_util.h"
#ifndef INTEL_MKL_ML
#include "mkldnn.hpp"
@@ -49,8 +46,13 @@ using mkldnn::convolution_backward_weights;
using mkldnn::memory;
using mkldnn::prop_kind;
using mkldnn::stream;
+#else
+#include "mkl_dnn.h"
+#include "mkl_dnn_types.h"
#endif
+#include "tensorflow/core/util/mkl_util.h"
+
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
diff --git a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
index d203c04934..21b18f9119 100644
--- a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
@@ -23,8 +23,10 @@ limitations under the License.
#define EIGEN_USE_THREADS
#include <algorithm>
#include <vector>
+#ifdef INTEL_MKL_ML
#include "mkl_dnn.h"
#include "mkl_dnn_types.h"
+#endif
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
diff --git a/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc b/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc
index 62aafa7930..3fe660cf96 100644
--- a/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc
+++ b/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc
@@ -21,21 +21,21 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/util/tensor_format.h"
-#include "mkl_dnn.h"
-#include "mkl_dnn_types.h"
-#include "tensorflow/core/util/mkl_util.h"
#ifndef INTEL_MKL_ML
#include "mkldnn.hpp"
-
using mkldnn::batch_normalization_backward;
using mkldnn::batch_normalization_forward;
using mkldnn::prop_kind;
using mkldnn::stream;
using mkldnn::use_global_stats;
using mkldnn::use_scale_shift;
+#else
+#include "mkl_dnn.h"
+#include "mkl_dnn_types.h"
#endif
+#include "tensorflow/core/util/mkl_util.h"
// TODO(inteltf) Address comments from PR 8968.
namespace tensorflow {
diff --git a/tensorflow/core/kernels/mkl_identity_op.cc b/tensorflow/core/kernels/mkl_identity_op.cc
index 6c027f8e72..b02cc5384c 100644
--- a/tensorflow/core/kernels/mkl_identity_op.cc
+++ b/tensorflow/core/kernels/mkl_identity_op.cc
@@ -24,8 +24,10 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
+#ifdef INTEL_MKL_ML
#include "mkl_dnn.h"
#include "mkl_dnn_types.h"
+#endif
#include "tensorflow/core/util/mkl_util.h"
#ifndef INTEL_MKL_ML
diff --git a/tensorflow/core/kernels/mkl_input_conversion_op.cc b/tensorflow/core/kernels/mkl_input_conversion_op.cc
index cda1402b03..dc4da33a06 100644
--- a/tensorflow/core/kernels/mkl_input_conversion_op.cc
+++ b/tensorflow/core/kernels/mkl_input_conversion_op.cc
@@ -369,8 +369,8 @@ class MklInputConversionOp : public OpKernel {
MklToTfOp<Device, T>::ConvertMklToTf(this, context, data_format_str,
op_data_type, has_avx512f_,
kInputIndex_1);
- SetDummyMklShapeOutput(context, kInputIndex_0);
- SetDummyMklShapeOutput(context, kInputIndex_1);
+ SetDummyMklDnnShapeOutput(context, kInputIndex_0);
+ SetDummyMklDnnShapeOutput(context, kInputIndex_1);
return;
}
@@ -439,11 +439,11 @@ class MklInputConversionOp : public OpKernel {
tensor_out, &net);
if(!reordered) {
// This is the case that the TF tensor has the same shape and format of
- // mkl tensor. However, tf_tensor can not be simply forwarded to the output
- // tensor since mkl data tensor is always one dimensional tensor.
- // Tensor::CopyFrom shares the buffer of the other tensor while set its shape
- // to the other tensor.
- tensor_out->CopyFrom(*tf_tensor, tensor_out->shape());
+          // mkl tensor. However, tf_tensor cannot simply be forwarded to the
+          // output tensor, since the mkl data tensor is always a one-dimensional
+          // tensor. Tensor::CopyFrom shares the buffer of the other tensor while
+          // setting its shape to that of the other tensor.
+ CHECK(tensor_out->CopyFrom(*tf_tensor, tensor_out->shape()));
}
else
stream(stream::kind::eager).submit(net).wait();
@@ -458,7 +458,7 @@ class MklInputConversionOp : public OpKernel {
MklToTfOp<Device, T>::ConvertMklToTf(this, context, data_format_str,
op_data_type, has_avx512f_,
mkl_tensor_index);
- SetDummyMklShapeOutput(context, mkl_tensor_index);
+ SetDummyMklDnnShapeOutput(context, mkl_tensor_index);
// The tensor in TF format passes through
ForwardTfTensorInToOut(context, tf_tensor_index, tf_tensor_index);
diff --git a/tensorflow/core/kernels/mkl_lrn_op.cc b/tensorflow/core/kernels/mkl_lrn_op.cc
index eef254cdad..dfe50e6a7f 100644
--- a/tensorflow/core/kernels/mkl_lrn_op.cc
+++ b/tensorflow/core/kernels/mkl_lrn_op.cc
@@ -22,8 +22,6 @@ limitations under the License.
#define EIGEN_USE_THREADS
#include <vector>
-#include "mkl_dnn.h"
-#include "mkl_dnn_types.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
@@ -31,7 +29,6 @@ limitations under the License.
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/util/mkl_util.h"
#include "tensorflow/core/util/tensor_format.h"
#if !defined(IS_MOBILE_PLATFORM)
@@ -45,8 +42,13 @@ using mkldnn::lrn_backward;
using mkldnn::lrn_forward;
using mkldnn::prop_kind;
using mkldnn::stream;
+#else
+#include "mkl_dnn.h"
+#include "mkl_dnn_types.h"
#endif
+#include "tensorflow/core/util/mkl_util.h"
+
namespace tensorflow {
namespace {
@@ -1236,7 +1238,7 @@ class MklLRNGradOp : public OpKernel {
auto activations = orig_output_tensor.shaped<T, 2>({nodes * batch, depth});
Tensor* output_dnn_data;
- MklShape mkl_output_mkl_shape;
+ MklDnnShape mkl_output_mkl_shape;
mkl_output_mkl_shape.SetMklTensor(false);
mkl_output_mkl_shape.SetDimensions(4);
AllocateOutputSetMklShape(context, kIdxOutput, &output_dnn_data,
diff --git a/tensorflow/core/kernels/mkl_matmul_op.cc b/tensorflow/core/kernels/mkl_matmul_op.cc
index dfa6cecc9b..62c0404891 100644
--- a/tensorflow/core/kernels/mkl_matmul_op.cc
+++ b/tensorflow/core/kernels/mkl_matmul_op.cc
@@ -23,7 +23,7 @@ limitations under the License.
// and when it is undefined at build time, this file becomes an empty
// compilation unit
-#if defined(INTEL_MKL)
+#if defined(INTEL_MKL) && !defined(DO_NOT_USE_ML)
#include "mkl_cblas.h"
#include "tensorflow/core/framework/op.h"
diff --git a/tensorflow/core/kernels/mkl_relu_op.cc b/tensorflow/core/kernels/mkl_relu_op.cc
index 1ed43834dd..78abbdb730 100644
--- a/tensorflow/core/kernels/mkl_relu_op.cc
+++ b/tensorflow/core/kernels/mkl_relu_op.cc
@@ -23,9 +23,6 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "mkl_dnn.h"
-#include "mkl_dnn_types.h"
-#include "tensorflow/core/util/mkl_util.h"
#ifndef INTEL_MKL_ML
#include "mkldnn.hpp"
@@ -38,7 +35,11 @@ using mkldnn::prop_kind;
using mkldnn::relu_backward;
using mkldnn::relu_forward;
using mkldnn::stream;
+#else
+#include "mkl_dnn.h"
+#include "mkl_dnn_types.h"
#endif
+#include "tensorflow/core/util/mkl_util.h"
namespace tensorflow {
diff --git a/tensorflow/core/kernels/mkl_reshape_op.cc b/tensorflow/core/kernels/mkl_reshape_op.cc
index 2cfde1f6fd..c44a6f3477 100644
--- a/tensorflow/core/kernels/mkl_reshape_op.cc
+++ b/tensorflow/core/kernels/mkl_reshape_op.cc
@@ -24,15 +24,17 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
-#include "mkl_dnn.h"
-#include "mkl_dnn_types.h"
-#include "tensorflow/core/util/mkl_util.h"
#ifndef INTEL_MKL_ML
#include "mkldnn.hpp"
using mkldnn::stream;
+#else
+#include "mkl_dnn.h"
+#include "mkl_dnn_types.h"
#endif
+#include "tensorflow/core/util/mkl_util.h"
+
namespace tensorflow {
using CPUDevice = Eigen::ThreadPoolDevice;
template <typename Device, typename T>
@@ -250,7 +252,7 @@ class MklReshapeOp : public OpKernel {
memory::primitive_desc(output_tf_md, cpu_engine);
Tensor* output_tensor = nullptr;
- MklShape mkl_shape_output;
+ MklDnnShape mkl_shape_output;
mkl_shape_output.SetMklTensor(false);
// We allocate output tensor in the shape expected by Reshape.
AllocateOutputSetMklShape(context, kOutputSlotIdx, &output_tensor,
diff --git a/tensorflow/core/kernels/mkl_softmax_op.cc b/tensorflow/core/kernels/mkl_softmax_op.cc
index f79e18cff2..638392954e 100644
--- a/tensorflow/core/kernels/mkl_softmax_op.cc
+++ b/tensorflow/core/kernels/mkl_softmax_op.cc
@@ -25,8 +25,6 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/tensor_format.h"
-#include "mkldnn.h"
-#include "mkldnn_types.h"
#include "tensorflow/core/util/mkl_util.h"
#include "mkldnn.hpp"
diff --git a/tensorflow/core/kernels/mkl_tfconv_op.h b/tensorflow/core/kernels/mkl_tfconv_op.h
index 4120f013ac..7e8ed1b1d6 100644
--- a/tensorflow/core/kernels/mkl_tfconv_op.h
+++ b/tensorflow/core/kernels/mkl_tfconv_op.h
@@ -32,8 +32,10 @@ limitations under the License.
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/tensor_format.h"
+#ifdef INTEL_MKL_ML
#include "mkl_dnn.h"
#include "mkl_dnn_types.h"
+#endif
#include "tensorflow/core/util/mkl_util.h"
#ifndef INTEL_MKL_ML
diff --git a/tensorflow/core/kernels/mkl_transpose_op.cc b/tensorflow/core/kernels/mkl_transpose_op.cc
index 3f07b317c4..b180c2ff20 100644
--- a/tensorflow/core/kernels/mkl_transpose_op.cc
+++ b/tensorflow/core/kernels/mkl_transpose_op.cc
@@ -15,7 +15,7 @@ limitations under the License.
// See docs in ../ops/array_ops.cc.
-#ifdef INTEL_MKL
+#if defined(INTEL_MKL) && !defined(DO_NOT_USE_ML)
#define EIGEN_USE_THREADS
#include "mkl_trans.h"
diff --git a/tensorflow/core/kernels/scatter_nd_op.cc b/tensorflow/core/kernels/scatter_nd_op.cc
index ff38026ac7..e1fc2ea128 100644
--- a/tensorflow/core/kernels/scatter_nd_op.cc
+++ b/tensorflow/core/kernels/scatter_nd_op.cc
@@ -143,14 +143,10 @@ class ScatterNdUpdateOp : public OpKernel {
void Compute(OpKernelContext* c) override {
if (dtype_ == DT_RESOURCE) {
- if (use_exclusive_lock_) {
- Var* v;
- OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
- mutex_lock m(*v->mu());
- DoCompute(c);
- } else {
- DoCompute(c);
- }
+ Var* v;
+ OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
+ mutex_lock m(*v->mu());
+ DoCompute(c);
} else if (use_exclusive_lock_) {
// If we're here, it means the input type is a ref.
DCHECK(IsRefType(c->input_dtype(0)));
@@ -176,13 +172,7 @@ class ScatterNdUpdateOp : public OpKernel {
Var* v;
OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
Tensor* t = v->tensor();
- if (!use_exclusive_lock_) {
- // We're not holding the lock in the outer scope so need it here.
- mutex_lock m(*v->mu());
- OP_REQUIRES_OK(c, PrepareToUpdateVariable<Device, T>(c, t));
- } else {
- OP_REQUIRES_OK(c, PrepareToUpdateVariable<Device, T>(c, t));
- }
+ OP_REQUIRES_OK(c, PrepareToUpdateVariable<Device, T>(c, t));
params = *t;
params_shape = params.shape();
} else if (IsRefType(c->input_dtype(0))) {
@@ -260,7 +250,9 @@ class ScatterNdUpdateOp : public OpKernel {
REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdNonAliasingAdd", \
scatter_nd_op::UpdateOp::ADD); \
REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdSub", \
- scatter_nd_op::UpdateOp::SUB);
+ scatter_nd_op::UpdateOp::SUB); \
+ REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL( \
+ type, dev, "ResourceScatterNdAdd", scatter_nd_op::UpdateOp::ADD);
#define REGISTER_SCATTER_ND(type, dev) \
REGISTER_SCATTER_ND_KERNEL(type, dev, "ScatterNd");
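
After this change the DT_RESOURCE path always takes the variable's mutex before updating, instead of only when `use_exclusive_lock_` is set. A standalone sketch of that unconditional-lock pattern, with a hypothetical FakeVar in place of the TF Var resource:

```
#include <mutex>

// FakeVar stands in for the TF Var resource; the point is that the lock is
// taken unconditionally before the scatter update.
struct FakeVar {
  std::mutex mu;
  long long value = 0;
};

void ScatterAddLikeUpdate(FakeVar* v, long long delta) {
  std::lock_guard<std::mutex> l(v->mu);  // unconditional lock
  v->value += delta;
}

int main() {
  FakeVar v;
  ScatterAddLikeUpdate(&v, 3);
  ScatterAddLikeUpdate(&v, 4);
  return v.value == 7 ? 0 : 1;
}
```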
diff --git a/tensorflow/core/kernels/tensor_array.cc b/tensorflow/core/kernels/tensor_array.cc
index 7b85ff2ea4..765467bc1e 100644
--- a/tensorflow/core/kernels/tensor_array.cc
+++ b/tensorflow/core/kernels/tensor_array.cc
@@ -81,7 +81,8 @@ TF_CALL_complex128(TENSOR_ARRAY_SET_ZERO_GPU);
std::atomic<int64> TensorArray::tensor_array_counter{0};
-Status TensorArray::CopyShapesFrom(TensorArray* rhs) {
+Status TensorArray::CopyShapesFrom(TensorArray* rhs,
+ const TensorShape* shape_to_prepend) {
mutex_lock l(mu_);
mutex_lock l_rhs(rhs->mu_);
TF_RETURN_IF_ERROR(LockedReturnIfClosed());
@@ -97,7 +98,12 @@ Status TensorArray::CopyShapesFrom(TensorArray* rhs) {
if (!rhs->tensors_[i].written) continue;
// Copy the shape over.
- tensors_[i].shape = rhs->tensors_[i].shape;
+ if (shape_to_prepend) {
+ tensors_[i].shape = *shape_to_prepend;
+ tensors_[i].shape.AppendShape(rhs->tensors_[i].shape);
+ } else {
+ tensors_[i].shape = rhs->tensors_[i].shape;
+ }
// Mark as written. Reads will know that if written is true and
// read is false, and cleared is false, to return zeros of the
// appropriate shape. Future aggregating writes will only use the shape
diff --git a/tensorflow/core/kernels/tensor_array.h b/tensorflow/core/kernels/tensor_array.h
index 90b71e370c..68fab85770 100644
--- a/tensorflow/core/kernels/tensor_array.h
+++ b/tensorflow/core/kernels/tensor_array.h
@@ -325,13 +325,15 @@ class TensorArray : public ResourceBase {
bool HasIdenticalElementShapes() const { return identical_element_shapes_; }
// Copy the TensorShapes from another TensorArray into this one.
+  // If `shape_to_prepend` is set, expands the rank of the copied shape by
+  // prepending the passed-in shape prefix to the shape values in `rhs`.
// The sizes of the two TensorArrays must match and this one
// may not have any entries filled in. This performs a "soft copy",
// essentially filling the current TensorArray with virtual
// zero-tensors, which will be replaced by future aggregate writes,
// or instantiated by future reads. Requires a non-const pointer
// to the rhs to access its mutex.
- Status CopyShapesFrom(TensorArray* rhs);
+ Status CopyShapesFrom(TensorArray* rhs, const TensorShape* shape_to_prepend);
// Clear the TensorArray, including any Tensor references, and mark as closed.
void ClearAndMarkClosed() {
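
The new `shape_to_prepend` parameter makes each copied element shape the given prefix followed by the source element shape. A sketch of that composition with plain vectors standing in for TensorShape:

```
#include <cstdint>
#include <vector>

// Each copied element shape is the prefix shape followed by the source
// element shape; a null prefix means "copy unchanged".
using Shape = std::vector<int64_t>;

Shape PrependShape(const Shape* prefix, const Shape& rhs) {
  if (prefix == nullptr) return rhs;
  Shape out = *prefix;
  out.insert(out.end(), rhs.begin(), rhs.end());
  return out;
}

int main() {
  Shape prefix = {4};      // e.g. a leading "accumulated gradients" dimension
  Shape element = {2, 3};
  Shape with = PrependShape(&prefix, element);     // {4, 2, 3}
  Shape without = PrependShape(nullptr, element);  // {2, 3}
  return (with.size() == 3 && without.size() == 2) ? 0 : 1;
}
```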
diff --git a/tensorflow/core/kernels/tensor_array_ops.cc b/tensorflow/core/kernels/tensor_array_ops.cc
index ef9748b1aa..37803ec775 100644
--- a/tensorflow/core/kernels/tensor_array_ops.cc
+++ b/tensorflow/core/kernels/tensor_array_ops.cc
@@ -264,7 +264,10 @@ REGISTER_GPU(bfloat16);
#endif // GOOGLE_CUDA
// GRADIENT *******************************************************************
-
+// Note that this op may have an optional third input. If present, it
+// represents a shape value. It indicates that the element shape of this
+// gradient array is that shape value concatenated with the element shape of
+// the original tensor array. See TensorArrayGradWithShape.
class TensorArrayGradOp : public TensorArrayCreationOp {
public:
explicit TensorArrayGradOp(OpKernelConstruction* context)
@@ -325,18 +328,38 @@ class TensorArrayGradOp : public TensorArrayCreationOp {
"previous write? Gradient calculation is impossible when multiple "
"writes are performed to the same index.");
}
+ TensorShape shape_to_prepend;
+ auto element_shape = PartialTensorShape();
+ if (ctx->num_inputs() > 2) {
+ TF_RETURN_IF_ERROR(
+ ctx->op_kernel().MakeShape(ctx->input(2), &shape_to_prepend));
+ auto ta_element_shape = tensor_array->ElemShape();
+ if (!ta_element_shape.unknown_rank()) {
+ std::vector<int64> dims;
+ for (auto dim : shape_to_prepend) {
+ dims.push_back(dim.size);
+ }
+ for (auto dim : ta_element_shape) {
+ dims.push_back(dim.size);
+ }
+ TF_RETURN_IF_ERROR(TensorShapeUtils::MakeShape(
+ gtl::ArraySlice<int64>(dims), &element_shape));
+ }
+ } else {
+ element_shape = tensor_array->ElemShape();
+ }
const auto key = strings::StrCat(output_handle(0), output_handle(1));
auto creator = [this, key, tensor_array, array_size, marked_size,
- tensor_array_output_handle,
+ element_shape, shape_to_prepend, tensor_array_output_handle,
output_handle](TensorArray** ret) -> Status {
*ret = new TensorArray(
key, tensor_array->ElemType(), *tensor_array_output_handle,
- array_size, tensor_array->ElemShape(),
- tensor_array->HasIdenticalElementShapes(), false /* dynamic_size */,
- true /* multiple_writes_aggregate */, true /* is_grad */,
- marked_size /* marked_size */, true /* close_after_read */);
- return (*ret)->CopyShapesFrom(tensor_array);
+ array_size, element_shape, tensor_array->HasIdenticalElementShapes(),
+ false /* dynamic_size */, true /* multiple_writes_aggregate */,
+ true /* is_grad */, marked_size /* marked_size */,
+ true /* close_after_read */);
+ return (*ret)->CopyShapesFrom(tensor_array, &shape_to_prepend);
};
Status s = rm->LookupOrCreate<TensorArray>(
@@ -361,7 +384,8 @@ REGISTER_KERNEL_BUILDER(Name("TensorArrayGradV2").Device(DEVICE_CPU),
TensorArrayGradOp);
REGISTER_KERNEL_BUILDER(Name("TensorArrayGradV3").Device(DEVICE_CPU),
TensorArrayGradOp);
-
+REGISTER_KERNEL_BUILDER(Name("TensorArrayGradWithShape").Device(DEVICE_CPU),
+ TensorArrayGradOp);
REGISTER_KERNEL_BUILDER(Name("TensorArrayGrad")
.Device(DEVICE_GPU)
.HostMemory("handle")
@@ -377,6 +401,12 @@ REGISTER_KERNEL_BUILDER(Name("TensorArrayGradV3")
.HostMemory("handle")
.HostMemory("grad_handle"),
TensorArrayGradOp);
+REGISTER_KERNEL_BUILDER(Name("TensorArrayGradWithShape")
+ .Device(DEVICE_GPU)
+ .HostMemory("handle")
+ .HostMemory("shape_to_prepend")
+ .HostMemory("grad_handle"),
+ TensorArrayGradOp);
// WRITE **********************************************************************
diff --git a/tensorflow/core/kernels/transpose_op.cc b/tensorflow/core/kernels/transpose_op.cc
index 7177ad7888..886b3e7492 100644
--- a/tensorflow/core/kernels/transpose_op.cc
+++ b/tensorflow/core/kernels/transpose_op.cc
@@ -218,7 +218,7 @@ Status ConjugateTransposeCpuOp::DoTranspose(OpKernelContext* ctx,
perm, out);
}
-#ifdef INTEL_MKL
+#if defined(INTEL_MKL) && !defined(DO_NOT_USE_ML)
#define REGISTER(T) \
REGISTER_KERNEL_BUILDER(Name("Transpose") \
.Device(DEVICE_CPU) \
diff --git a/tensorflow/core/kernels/transpose_op.h b/tensorflow/core/kernels/transpose_op.h
index ae67592d04..709b0a92e9 100644
--- a/tensorflow/core/kernels/transpose_op.h
+++ b/tensorflow/core/kernels/transpose_op.h
@@ -42,7 +42,7 @@ class TransposeCpuOp : public TransposeOp {
gtl::ArraySlice<int32> perm, Tensor* out) override;
};
-#ifdef INTEL_MKL
+#if defined(INTEL_MKL) && !defined(DO_NOT_USE_ML)
class MklTransposeCpuOp : public TransposeOp {
public:
explicit MklTransposeCpuOp(OpKernelConstruction* ctx) : TransposeOp(ctx) {}
@@ -85,7 +85,7 @@ class ConjugateTransposeCpuOp : public TransposeOp {
bool IsConjugate() const override { return true; }
};
-#ifdef INTEL_MKL
+#if defined(INTEL_MKL) && !defined(DO_NOT_USE_ML)
class MklConjugateTransposeCpuOp : public TransposeOp {
public:
explicit MklConjugateTransposeCpuOp(OpKernelConstruction* ctx)
diff --git a/tensorflow/core/lib/io/random_inputstream.cc b/tensorflow/core/lib/io/random_inputstream.cc
index 09336e79cd..e85367df9c 100644
--- a/tensorflow/core/lib/io/random_inputstream.cc
+++ b/tensorflow/core/lib/io/random_inputstream.cc
@@ -45,16 +45,8 @@ Status RandomAccessInputStream::ReadNBytes(int64 bytes_to_read,
result->resize(data.size());
if (s.ok() || errors::IsOutOfRange(s)) {
pos_ += data.size();
- } else {
- return s;
}
- // If the amount of data we read is less than what we wanted, we return an
- // out of range error. We need to catch this explicitly since file_->Read()
- // would not do so if at least 1 byte is read (b/30839063).
- if (data.size() < bytes_to_read) {
- return errors::OutOfRange("reached end of file");
- }
- return Status::OK();
+ return s;
}
// To limit memory usage, the default implementation of SkipNBytes() only reads
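
With this change a short read no longer gets converted into a synthesized OutOfRange error; the position still advances by the bytes actually read and the underlying status is returned as-is. A simplified standalone model of the revised contract (FakeStream and the Code enum are illustrative, not the real Status API):

```
#include <string>

// Callers should look at result->size() rather than assume exactly
// `bytes_to_read` bytes were returned.
enum class Code { kOk, kOutOfRange };

struct FakeStream {
  std::string contents;
  size_t pos = 0;

  Code ReadNBytes(size_t bytes_to_read, std::string* result) {
    size_t available = contents.size() - pos;
    size_t n = bytes_to_read < available ? bytes_to_read : available;
    *result = contents.substr(pos, n);
    pos += n;  // position advances by the bytes actually read
    return n == bytes_to_read ? Code::kOk : Code::kOutOfRange;
  }
};

int main() {
  FakeStream s{"hello"};
  std::string out;
  Code c = s.ReadNBytes(8, &out);  // short read: 5 bytes, out-of-range status
  return (c == Code::kOutOfRange && out.size() == 5 && s.pos == 5) ? 0 : 1;
}
```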
diff --git a/tensorflow/core/lib/strings/numbers_test.cc b/tensorflow/core/lib/strings/numbers_test.cc
index 0f22dac262..5b595f9847 100644
--- a/tensorflow/core/lib/strings/numbers_test.cc
+++ b/tensorflow/core/lib/strings/numbers_test.cc
@@ -289,12 +289,9 @@ TEST(safe_strtof, Float) {
EXPECT_FALSE(safe_strtof("-infinity is awesome", &result));
- // Make sure we exit cleanly if the string is not terminated
+ // Make sure we exit cleanly if the string is too long
char test_str[2 * kFastToBufferSize];
for (int i = 0; i < 2 * kFastToBufferSize; ++i) test_str[i] = 'a';
- EXPECT_FALSE(safe_strtof(test_str, &result));
-
- // Make sure we exit cleanly if the string is too long
test_str[kFastToBufferSize + 1] = '\0';
EXPECT_FALSE(safe_strtof(test_str, &result));
@@ -330,12 +327,9 @@ TEST(safe_strtod, Double) {
EXPECT_EQ(0.1234567890123, result);
EXPECT_FALSE(safe_strtod("0.1234567890123abc", &result));
- // Make sure we exit cleanly if the string is not terminated
+ // Make sure we exit cleanly if the string is too long
char test_str[2 * kFastToBufferSize];
for (int i = 0; i < 2 * kFastToBufferSize; ++i) test_str[i] = 'a';
- EXPECT_FALSE(safe_strtod(test_str, &result));
-
- // Make sure we exit cleanly if the string is too long
test_str[kFastToBufferSize + 1] = '\0';
EXPECT_FALSE(safe_strtod(test_str, &result));
diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
index 16e9b2e02e..8f8c90ee97 100644
--- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
@@ -8721,6 +8721,37 @@ op {
}
}
op {
+ name: "BatchDatasetV2"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "batch_size"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "drop_remainder"
+ type: DT_BOOL
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
name: "BatchFFT"
input_arg {
name: "input"
@@ -22113,6 +22144,33 @@ op {
is_stateful: true
}
op {
+ name: "FeatureStatsDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "tag"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
name: "Fill"
input_arg {
name: "dims"
@@ -35791,6 +35849,52 @@ op {
}
}
op {
+ name: "PaddedBatchDatasetV2"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "batch_size"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "padded_shapes"
+ type: DT_INT64
+ number_attr: "N"
+ }
+ input_arg {
+ name: "padding_values"
+ type_list_attr: "Toutput_types"
+ }
+ input_arg {
+ name: "drop_remainder"
+ type: DT_BOOL
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "Toutput_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "N"
+ type: "int"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
name: "PaddingFIFOQueue"
output_arg {
name: "handle"
@@ -48134,6 +48238,43 @@ op {
is_stateful: true
}
op {
+ name: "ResourceScatterNdAdd"
+ input_arg {
+ name: "ref"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "indices"
+ type_attr: "Tindices"
+ }
+ input_arg {
+ name: "updates"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ }
+ attr {
+ name: "Tindices"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ attr {
+ name: "use_locking"
+ type: "bool"
+ default_value {
+ b: true
+ }
+ }
+ is_stateful: true
+}
+op {
name: "ResourceScatterNdUpdate"
input_arg {
name: "ref"
@@ -69458,6 +69599,34 @@ op {
is_stateful: true
}
op {
+ name: "TensorArrayGradWithShape"
+ input_arg {
+ name: "handle"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "flow_in"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "shape_to_prepend"
+ type: DT_INT32
+ }
+ output_arg {
+ name: "grad_handle"
+ type: DT_RESOURCE
+ }
+ output_arg {
+ name: "flow_out"
+ type: DT_FLOAT
+ }
+ attr {
+ name: "source"
+ type: "string"
+ }
+ is_stateful: true
+}
+op {
name: "TensorArrayPack"
input_arg {
name: "handle"
diff --git a/tensorflow/core/ops/data_flow_ops.cc b/tensorflow/core/ops/data_flow_ops.cc
index 3112f35da4..eed0bce174 100644
--- a/tensorflow/core/ops/data_flow_ops.cc
+++ b/tensorflow/core/ops/data_flow_ops.cc
@@ -608,6 +608,50 @@ REGISTER_OP("TensorArrayGradV3")
return Status::OK();
});
+REGISTER_OP("TensorArrayGradWithShape")
+ .Input("handle: resource")
+ .Input("flow_in: float")
+ .Input("shape_to_prepend: int32")
+ .Output("grad_handle: resource")
+ .Output("flow_out: float")
+ .Attr("source: string")
+ .SetIsStateful()
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle handle;
+ DimensionHandle unused_dim;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
+ TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));
+ c->set_output(0, c->Vector(2));
+ c->set_output(1, c->Scalar());
+ auto* shape_and_type = c->input_handle_shapes_and_types(0);
+ if (shape_and_type) {
+ auto input_shape = (*shape_and_type)[0].shape;
+ auto dtype = (*shape_and_type)[0].dtype;
+      // Note that shape_to_prepend is a rank-1 Tensor representing a shape.
+      // The size of dimension 0 is the number of dimensions we need to add to
+      // the output shape.
+ int64 prepend_rank = c->Value(c->Dim(c->input(2), 0));
+ if (c->RankKnown(input_shape) &&
+ prepend_rank != InferenceContext::kUnknownDim) {
+ int32 input_rank = c->Rank(input_shape);
+ std::vector<DimensionHandle> dims;
+ dims.reserve(prepend_rank + input_rank);
+ for (int i = 0; i < prepend_rank; ++i) {
+ dims.push_back(c->UnknownDim());
+ }
+ for (int i = 0; i < input_rank; ++i) {
+ dims.push_back(c->Dim(input_shape, i));
+ }
+ c->set_output_handle_shapes_and_types(0,
+ {{c->MakeShape(dims), dtype}});
+ } else {
+ c->set_output_handle_shapes_and_types(0,
+ {{c->UnknownShape(), dtype}});
+ }
+ }
+ return Status::OK();
+ });
+
REGISTER_OP("TensorArrayWriteV3")
.Input("handle: resource")
.Input("index: int32")
diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc
index c9d1cdf412..9dca5f53ce 100644
--- a/tensorflow/core/ops/dataset_ops.cc
+++ b/tensorflow/core/ops/dataset_ops.cc
@@ -166,6 +166,18 @@ REGISTER_OP("LatencyStatsDataset")
return shape_inference::ScalarShape(c);
});
+REGISTER_OP("FeatureStatsDataset")
+ .Input("input_dataset: variant")
+ .Input("tag: string")
+ .Output("handle: variant")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle tag_shape;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &tag_shape));
+ return shape_inference::ScalarShape(c);
+ });
+
REGISTER_OP("SetStatsAggregatorDataset")
.Input("input_dataset: variant")
.Input("stats_aggregator: resource")
@@ -363,6 +375,22 @@ REGISTER_OP("BatchDataset")
return shape_inference::ScalarShape(c);
});
+REGISTER_OP("BatchDatasetV2")
+ .Input("input_dataset: variant")
+ .Input("batch_size: int64")
+ .Input("drop_remainder: bool")
+ .Output("handle: variant")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle unused;
+ // batch_size should be a scalar.
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+ // drop_remainder should be a scalar.
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
+ return shape_inference::ScalarShape(c);
+ });
+
// TODO(mrry): move SlideDataset to contrib in the future.
REGISTER_OP("SlideDataset")
.Input("input_dataset: variant")
@@ -379,6 +407,10 @@ REGISTER_OP("SlideDataset")
return shape_inference::ScalarShape(c);
});
+// TODO(mrry): Validate that `padded_shapes` are all vectors, the lengths of
+// `output_types` and `output_shapes` are `N`, the `output_shapes` are (as far
+// as possible to tell statically) compatible with `padded_shapes`, and that
+// `padding_values` are all scalars.
REGISTER_OP("PaddedBatchDataset")
.Input("input_dataset: variant")
.Input("batch_size: int64")
@@ -388,17 +420,32 @@ REGISTER_OP("PaddedBatchDataset")
.Attr("Toutput_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.Attr("N: int >= 1")
- .SetShapeFn(shape_inference::ScalarShape); // TODO(mrry): Validate that
- // `padded_shapes` are all
- // vectors, the lengths of
- // `output_types` and
- // `output_shapes` are `N`,
- // the `output_shapes` are (as
- // far as possible to tell
- // statically) compatible with
- // `padded_shapes`, and
- // that `padding_values` are
- // all scalars.
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle unused;
+ // batch_size should be a scalar.
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+ return shape_inference::ScalarShape(c);
+ });
+
+REGISTER_OP("PaddedBatchDatasetV2")
+ .Input("input_dataset: variant")
+ .Input("batch_size: int64")
+ .Input("padded_shapes: N * int64")
+ .Input("padding_values: Toutput_types")
+ .Input("drop_remainder: bool")
+ .Output("handle: variant")
+ .Attr("Toutput_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .Attr("N: int >= 1")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle unused;
+ // batch_size should be a scalar.
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+ // drop_remainder should be a scalar.
+ TF_RETURN_IF_ERROR(
+ c->WithRank(c->input(c->num_inputs() - 1), 0, &unused));
+ return shape_inference::ScalarShape(c);
+ });
REGISTER_OP("DenseToSparseBatchDataset")
.Input("input_dataset: variant")
diff --git a/tensorflow/core/ops/functional_ops.cc b/tensorflow/core/ops/functional_ops.cc
index a6cc4b60e5..88553dff93 100644
--- a/tensorflow/core/ops/functional_ops.cc
+++ b/tensorflow/core/ops/functional_ops.cc
@@ -82,7 +82,7 @@ REGISTER_OP("If")
.Input("input: Tin")
.Output("output: Tout")
.Attr("Tcond: type")
- .Attr("Tin: list(type)")
+ .Attr("Tin: list(type) >= 0")
.Attr("Tout: list(type)")
.Attr("then_branch: func")
.Attr("else_branch: func")
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index 7df43663c9..d3f3e87dfd 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -3005,6 +3005,37 @@ op {
}
}
op {
+ name: "BatchDatasetV2"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "batch_size"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "drop_remainder"
+ type: DT_BOOL
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
name: "BatchFFT"
input_arg {
name: "input"
@@ -10270,6 +10301,33 @@ op {
is_stateful: true
}
op {
+ name: "FeatureStatsDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "tag"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
name: "Fill"
input_arg {
name: "dims"
@@ -17463,6 +17521,52 @@ op {
}
}
op {
+ name: "PaddedBatchDatasetV2"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "batch_size"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "padded_shapes"
+ type: DT_INT64
+ number_attr: "N"
+ }
+ input_arg {
+ name: "padding_values"
+ type_list_attr: "Toutput_types"
+ }
+ input_arg {
+ name: "drop_remainder"
+ type: DT_BOOL
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "Toutput_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "N"
+ type: "int"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
name: "PaddingFIFOQueue"
output_arg {
name: "handle"
@@ -23632,6 +23736,43 @@ op {
is_stateful: true
}
op {
+ name: "ResourceScatterNdAdd"
+ input_arg {
+ name: "ref"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "indices"
+ type_attr: "Tindices"
+ }
+ input_arg {
+ name: "updates"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ }
+ attr {
+ name: "Tindices"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ attr {
+ name: "use_locking"
+ type: "bool"
+ default_value {
+ b: true
+ }
+ }
+ is_stateful: true
+}
+op {
name: "ResourceScatterNdUpdate"
input_arg {
name: "ref"
@@ -32376,6 +32517,34 @@ op {
is_stateful: true
}
op {
+ name: "TensorArrayGradWithShape"
+ input_arg {
+ name: "handle"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "flow_in"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "shape_to_prepend"
+ type: DT_INT32
+ }
+ output_arg {
+ name: "grad_handle"
+ type: DT_RESOURCE
+ }
+ output_arg {
+ name: "flow_out"
+ type: DT_FLOAT
+ }
+ attr {
+ name: "source"
+ type: "string"
+ }
+ is_stateful: true
+}
+op {
name: "TensorArrayPack"
input_arg {
name: "handle"
diff --git a/tensorflow/core/ops/state_ops.cc b/tensorflow/core/ops/state_ops.cc
index 664f52452e..aa975cb77b 100644
--- a/tensorflow/core/ops/state_ops.cc
+++ b/tensorflow/core/ops/state_ops.cc
@@ -222,6 +222,15 @@ REGISTER_OP("ResourceScatterNdUpdate")
.Attr("use_locking: bool = true")
.SetShapeFn(shape_inference::ScatterNdUpdateShape);
+REGISTER_OP("ResourceScatterNdAdd")
+ .Input("ref: resource")
+ .Input("indices: Tindices")
+ .Input("updates: T")
+ .Attr("T: type")
+ .Attr("Tindices: {int32, int64}")
+ .Attr("use_locking: bool = true")
+ .SetShapeFn(shape_inference::ScatterNdUpdateShape);
+
REGISTER_OP("ScatterNdAdd")
.Input("ref: Ref(T)")
.Input("indices: Tindices")
diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h
index 522a9d84fd..cb1fd09dbb 100644
--- a/tensorflow/core/public/version.h
+++ b/tensorflow/core/public/version.h
@@ -19,12 +19,12 @@ limitations under the License.
// TensorFlow uses semantic versioning, see http://semver.org/.
#define TF_MAJOR_VERSION 1
-#define TF_MINOR_VERSION 8
+#define TF_MINOR_VERSION 9
#define TF_PATCH_VERSION 0
// TF_VERSION_SUFFIX is non-empty for pre-releases (e.g. "-alpha", "-alpha.1",
// "-beta", "-rc", "-rc.1")
-#define TF_VERSION_SUFFIX ""
+#define TF_VERSION_SUFFIX "-rc0"
#define TF_STR_HELPER(x) #x
#define TF_STR(x) TF_STR_HELPER(x)
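
The bump above changes the computed release string from "1.8.0" to "1.9.0-rc0". A self-contained sketch of how the stringizing helpers compose the pieces; the TF_VERSION_STRING definition below is reproduced locally for illustration and is an assumption about the rest of the header:

```
#include <cstring>

// Local copies of the macros, for a standalone demonstration only.
#define TF_MAJOR_VERSION 1
#define TF_MINOR_VERSION 9
#define TF_PATCH_VERSION 0
#define TF_VERSION_SUFFIX "-rc0"

#define TF_STR_HELPER(x) #x
#define TF_STR(x) TF_STR_HELPER(x)

// Assumed composition: string literals concatenate to "1.9.0-rc0".
#define TF_VERSION_STRING                                     \
  (TF_STR(TF_MAJOR_VERSION) "." TF_STR(TF_MINOR_VERSION) "." \
   TF_STR(TF_PATCH_VERSION) TF_VERSION_SUFFIX)

int main() { return std::strcmp(TF_VERSION_STRING, "1.9.0-rc0") == 0 ? 0 : 1; }
```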
diff --git a/tensorflow/core/util/exec_on_stall_test.cc b/tensorflow/core/util/exec_on_stall_test.cc
index df8118d611..42e66a7e84 100644
--- a/tensorflow/core/util/exec_on_stall_test.cc
+++ b/tensorflow/core/util/exec_on_stall_test.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/core/util/exec_on_stall.h"
#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
@@ -32,14 +33,24 @@ Chunk* NewChunk(int stall_seconds, std::function<void()> f) {
}
TEST(ExecuteOnStallTest, BothWays) {
- bool a_triggered = false;
- bool b_triggered = false;
- Chunk* a = NewChunk(1, [&a_triggered]() { a_triggered = true; });
- Chunk* b = NewChunk(1, [&b_triggered]() { b_triggered = true; });
+ mutex mu;
+ bool a_triggered(false);
+ bool b_triggered(false);
+ Chunk* a = NewChunk(1, [&mu, &a_triggered]() {
+ mutex_lock l(mu);
+ a_triggered = true;
+ });
+ Chunk* b = NewChunk(1, [&mu, &b_triggered]() {
+ mutex_lock l(mu);
+ b_triggered = true;
+ });
delete a;
Env::Default()->SleepForMicroseconds(2000000);
- EXPECT_FALSE(a_triggered);
- EXPECT_TRUE(b_triggered);
+ {
+ mutex_lock l(mu);
+ EXPECT_FALSE(a_triggered);
+ EXPECT_TRUE(b_triggered);
+ }
delete b;
}
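
The test fix guards the flags with a mutex because the stall callbacks run on another thread, so unguarded reads and writes would be a data race under a race detector. A minimal standalone illustration of the same discipline:

```
#include <mutex>
#include <thread>

// A flag written from a background thread (the stall callback) and read from
// the test thread must be guarded by the same mutex on both sides.
int main() {
  std::mutex mu;
  bool triggered = false;

  std::thread callback([&] {
    std::lock_guard<std::mutex> l(mu);
    triggered = true;
  });
  callback.join();

  bool seen;
  {
    std::lock_guard<std::mutex> l(mu);
    seen = triggered;
  }
  return seen ? 0 : 1;
}
```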
diff --git a/tensorflow/core/util/mkl_util.h b/tensorflow/core/util/mkl_util.h
index 230b4278ca..90b6533690 100644
--- a/tensorflow/core/util/mkl_util.h
+++ b/tensorflow/core/util/mkl_util.h
@@ -22,10 +22,13 @@ limitations under the License.
#include <unordered_map>
#include <utility>
+#ifdef INTEL_MKL_ML
#include "mkl_dnn.h"
#include "mkl_dnn_types.h"
#include "mkl_service.h"
#include "mkl_trans.h"
+#endif
+
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
@@ -39,6 +42,7 @@ limitations under the License.
#ifndef INTEL_MKL_ML
#include "mkldnn.hpp"
+#include "tensorflow/core/lib/core/stringpiece.h"
using mkldnn::engine;
using mkldnn::memory;
@@ -51,11 +55,12 @@ using mkldnn::reorder;
typedef unsigned int uint;
#endif
-// The file contains a number of utility classes and functions used by MKL
-// enabled kernels
namespace tensorflow {
+// The file contains a number of utility classes and functions used by MKL
+// enabled kernels
+
// This class encapsulates all the meta data that is associated with an MKL
// tensor. A tensor is an MKL tensor if it was created as the result of an
// MKL operation, and did not go through a conversion to a standard
@@ -71,6 +76,7 @@ typedef enum {
Dim_I = 1
} MklDnnDims;
+#ifdef INTEL_MKL_ML
class MklShape {
public:
MklShape() {}
@@ -331,7 +337,7 @@ class MklShape {
nullptr; // TF dimension corresponding to this MKL dimension
};
-#ifndef INTEL_MKL_ML
+#else
// Forward decl
TensorFormat MklDnnDataFormatToTFDataFormat(memory::format format);
@@ -664,12 +670,14 @@ class MklDnnShape {
// List of MklShape objects. Used in Concat/Split layers.
-typedef std::vector<MklShape> MklShapeList;
#ifndef INTEL_MKL_ML
typedef std::vector<MklDnnShape> MklDnnShapeList;
+#else
+typedef std::vector<MklShape> MklShapeList;
#endif
+#ifdef INTEL_MKL_ML
// Check if all tensors specified by MklShapes are MKL tensors.
inline bool AreAllMklTensors(const MklShapeList& shapes) {
for (auto& s : shapes) {
@@ -680,7 +688,6 @@ inline bool AreAllMklTensors(const MklShapeList& shapes) {
return true;
}
-#ifdef INTEL_MKL_ML
template <typename T>
inline Tensor ConvertMklToTF(OpKernelContext* context, const Tensor& mkl_tensor,
const MklShape& mkl_shape) {
@@ -753,6 +760,7 @@ inline Tensor ConvertMklToTF(OpKernelContext* context, const Tensor& mkl_tensor,
#endif
// Get the MKL shape from the second string tensor
+#ifdef INTEL_MKL_ML
inline void GetMklShape(OpKernelContext* ctext, int n, MklShape* mklshape) {
mklshape->DeSerializeMklShape(
ctext->input(GetTensorMetaDataIndex(n, ctext->num_inputs()))
@@ -763,8 +771,7 @@ inline void GetMklShape(OpKernelContext* ctext, int n, MklShape* mklshape) {
.size() *
sizeof(uint8));
}
-
-#ifndef INTEL_MKL_ML
+#else
inline void GetMklShape(OpKernelContext* ctext, int n, MklDnnShape* mklshape) {
mklshape->DeSerializeMklDnnShape(
ctext->input(GetTensorMetaDataIndex(n, ctext->num_inputs()))
@@ -838,6 +845,7 @@ inline TensorShape GetTfShape(OpKernelContext* context, size_t input_idx) {
}
#endif
+#ifdef INTEL_MKL_ML
// Allocate the second output tensor that will contain
// the MKL shape serialized
inline void AllocateOutputSetMklShape(OpKernelContext* ctext, int n,
@@ -853,7 +861,7 @@ inline void AllocateOutputSetMklShape(OpKernelContext* ctext, int n,
second_tensor->flat<uint8>().size() * sizeof(uint8));
}
-#ifndef INTEL_MKL_ML
+#else
// Allocate the second output tensor that will contain
// the MKL shape serialized
inline void AllocateOutputSetMklShape(OpKernelContext* ctext, int n,
@@ -870,6 +878,7 @@ inline void AllocateOutputSetMklShape(OpKernelContext* ctext, int n,
}
#endif
+#ifdef INTEL_MKL_ML
// Allocate the output tensor, create a second output tensor that will contain
// the MKL shape serialized
inline void AllocateOutputSetMklShape(OpKernelContext* ctext, int n,
@@ -890,7 +899,7 @@ inline void AllocateOutputSetMklShape(OpKernelContext* ctext, int n,
second_tensor->flat<uint8>().size() * sizeof(uint8));
}
-#ifndef INTEL_MKL_ML
+#else
// Allocate the output tensor, create a second output tensor that will contain
// the MKL shape serialized
inline void AllocateOutputSetMklShape(OpKernelContext* ctext, int n,
@@ -925,8 +934,7 @@ inline void AllocTmpBuffer(OpKernelContext* context, Tensor* tensor_out,
tf_shape, tensor_out));
*buf_out = static_cast<void*>(tensor_out->flat<T>().data());
}
-#endif
-
+#else
inline void AllocTmpBuffer(OpKernelContext* context, Tensor* tensor_out,
dnnLayout_t lt_buff, void** buf_out) {
TensorShape tf_shape;
@@ -940,6 +948,7 @@ inline void AllocTmpBuffer(OpKernelContext* context, Tensor* tensor_out,
*buf_out = static_cast<void*>(tensor_out->flat<float>().data());
}
+#endif
template <typename T>
inline void AllocTmpBuffer(OpKernelContext* context, Tensor* tensor_out,
TensorShape tf_shape) {
@@ -963,6 +972,7 @@ inline void GetStridesFromSizes(TensorFormat data_format, size_t* strides,
}
}
+#ifdef INTEL_MKL_ML
inline void MklSizesToTFSizes(OpKernelContext* context,
TensorFormat data_format_,
const MklShape& mkl_shape,
@@ -988,6 +998,7 @@ inline void MklSizesToTFSizes(OpKernelContext* context,
OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(sizes, tf_shape));
}
+#endif
inline int32 GetMklTensorDimIndex(char dimension) {
switch (dimension) {
@@ -1005,12 +1016,14 @@ inline int32 GetMklTensorDimIndex(char dimension) {
}
}
+#ifdef INTEL_MKL_ML
inline int64 GetMklTensorDim(const MklShape& mkl_shape, char dimension) {
int index = GetMklTensorDimIndex(dimension);
CHECK(index >= 0 && index < mkl_shape.GetDimension())
<< "Invalid index from the dimension: " << index << ", " << dimension;
return mkl_shape.dim_size(index);
}
+#endif
inline void CopyMklTensorInToOut(OpKernelContext* context, int idx_in,
int idx_out) {
@@ -1130,6 +1143,14 @@ inline void ForwardMklTensorInToOut(OpKernelContext* context, int idx_in,
}
#ifndef INTEL_MKL_ML
+// Set a dummy MKLDNN shape (called when the output is in TF format)
+inline void SetDummyMklDnnShapeOutput(OpKernelContext* context,
+ uint32 idx_data_out) {
+ MklDnnShape mkl_shape_output;
+ mkl_shape_output.SetMklTensor(false);
+ AllocateOutputSetMklShape(context, idx_data_out, mkl_shape_output);
+}
+
inline void ForwardMklTensorInToOutWithMklShape(OpKernelContext* context,
int idx_in, int idx_out,
const MklDnnShape& mkl_shape) {
@@ -1165,6 +1186,7 @@ inline void ForwardMklMetaDataInToOut(OpKernelContext* context,
}
}
+#ifdef INTEL_MKL_ML
// Set a dummy MKL shape (called when the output is in TF format)
inline void SetDummyMklShapeOutput(OpKernelContext* context,
uint32 idx_data_out) {
@@ -1172,8 +1194,6 @@ inline void SetDummyMklShapeOutput(OpKernelContext* context,
mkl_shape_output.SetMklTensor(false);
AllocateOutputSetMklShape(context, idx_data_out, mkl_shape_output);
}
-
-#ifdef INTEL_MKL_ML
// We don't need these functions in MKLDNN. We have defined equality operator
// on MklDnnShape class directly.
@@ -1243,7 +1263,6 @@ inline bool MklCompareShapes(const TensorShape* input_shape_0,
return true;
}
-#endif
// These functions do not compile with MKL-DNN since mkl.h is missing.
// We may need to remove them later.
@@ -1281,6 +1300,7 @@ inline void MklNCHWToNHWC(const Tensor& input, Tensor** output) {
}
}
+#endif
// -------------------------------------------------------------------
#ifndef INTEL_MKL_ML
@@ -1846,10 +1866,7 @@ class FactoryKeyCreator {
~FactoryKeyCreator() {}
- void AddAsKey(const string &str) {
- auto buffer = reinterpret_cast<const char *>(str.c_str());
- Append(buffer, str.length());
- }
+ void AddAsKey(const string& str) { Append(str); }
void AddAsKey(const mkldnn::memory::dims &dims) {
for (unsigned int i = 0; i < dims.size(); i++) {
@@ -1860,7 +1877,7 @@ class FactoryKeyCreator {
template <typename T>
void AddAsKey(const T data) {
auto buffer = reinterpret_cast<const char *>(&data);
- Append(buffer, sizeof(T));
+ Append(StringPiece(buffer, sizeof(T)));
}
std::string GetKey() {
@@ -1871,8 +1888,8 @@ class FactoryKeyCreator {
string key_;
const char delimiter = 'x';
const int kMaxKeyLength = 256;
- void Append(const char* data, int len) {
- key_.append(data, len);
+ void Append(StringPiece s) {
+ key_.append(s.ToString());
key_.append(1, delimiter);
}
};
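
The FactoryKeyCreator change funnels both the string and POD overloads through a single view-based Append. A standalone sketch of that shape, with std::string_view standing in for tensorflow::StringPiece:

```
#include <string>
#include <string_view>

// One Append(view) serves both the string overload and the POD overload.
class KeyBuilder {
 public:
  void AddAsKey(const std::string& str) { Append(str); }

  template <typename T>
  void AddAsKey(const T& data) {
    Append(std::string_view(reinterpret_cast<const char*>(&data), sizeof(T)));
  }

  const std::string& GetKey() const { return key_; }

 private:
  void Append(std::string_view s) {
    key_.append(s.data(), s.size());
    key_.push_back(delimiter_);
  }
  std::string key_;
  char delimiter_ = 'x';
};

int main() {
  KeyBuilder b;
  b.AddAsKey(std::string("conv2d"));
  b.AddAsKey(64);  // POD key component, appended as raw bytes
  return b.GetKey().empty() ? 1 : 0;
}
```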
diff --git a/tensorflow/core/util/sparse/group_iterator.h b/tensorflow/core/util/sparse/group_iterator.h
index c0fce207e7..fb70318078 100644
--- a/tensorflow/core/util/sparse/group_iterator.h
+++ b/tensorflow/core/util/sparse/group_iterator.h
@@ -78,7 +78,10 @@ class GroupIterable {
typedef gtl::ArraySlice<int64> VarDimArray;
GroupIterable(Tensor ix, Tensor vals, int dims, const VarDimArray& group_dims)
- : ix_(ix), vals_(vals), dims_(dims), group_dims_(group_dims) {}
+ : ix_(ix),
+ vals_(vals),
+ dims_(dims),
+ group_dims_(group_dims.begin(), group_dims.end()) {}
class IteratorStep;
@@ -127,7 +130,7 @@ class GroupIterable {
Tensor ix_;
Tensor vals_;
const int dims_;
- const VarDimArray group_dims_;
+ const gtl::InlinedVector<int64, 8> group_dims_;
};
// Implementation of Group::values<T>()
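
Storing `group_dims_` in an owned InlinedVector instead of keeping the caller's ArraySlice removes a potential dangling view when the caller's buffer goes away. A standalone sketch of the copy-on-construct fix (a pointer/length pair here stands in for ArraySlice):

```
#include <cstdint>
#include <vector>

// A non-owning view member can dangle if the caller's buffer is destroyed,
// so the constructor copies into an owned container instead.
struct GroupDimsOwning {
  GroupDimsOwning(const int64_t* data, size_t len)
      : dims(data, data + len) {}  // copy: safe regardless of caller lifetime
  std::vector<int64_t> dims;
};

GroupDimsOwning Make() {
  std::vector<int64_t> temp = {0, 1};  // destroyed when Make() returns
  return GroupDimsOwning(temp.data(), temp.size());
}

int main() {
  GroupDimsOwning g = Make();
  return (g.dims.size() == 2 && g.dims[1] == 1) ? 0 : 1;
}
```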
diff --git a/tensorflow/docs_src/api_guides/python/contrib.bayesflow.monte_carlo.md b/tensorflow/docs_src/api_guides/python/contrib.bayesflow.monte_carlo.md
deleted file mode 100644
index 74fe4a323a..0000000000
--- a/tensorflow/docs_src/api_guides/python/contrib.bayesflow.monte_carlo.md
+++ /dev/null
@@ -1,50 +0,0 @@
-# BayesFlow Monte Carlo (contrib)
-[TOC]
-
-Monte Carlo integration and helpers.
-
-## Background
-
-Monte Carlo integration refers to the practice of estimating an expectation with
-a sample mean. For example, given random variable Z in \\(R^k\\) with density `p`,
-the expectation of function `f` can be approximated like:
-
-$$E_p[f(Z)] = \int f(z) p(z) dz$$
-$$ ~ S_n
- := n^{-1} \sum_{i=1}^n f(z_i), z_i\ iid\ samples\ from\ p.$$
-
-If \\(E_p[|f(Z)|] < infinity\\), then \\(S_n\\) --> \\(E_p[f(Z)]\\) by the strong law of large
-numbers. If \\(E_p[f(Z)^2] < infinity\\), then \\(S_n\\) is asymptotically normal with
-variance \\(Var[f(Z)] / n\\).
-
-Practitioners of Bayesian statistics often find themselves wanting to estimate
-\\(E_p[f(Z)]\\) when the distribution `p` is known only up to a constant. For
-example, the joint distribution `p(z, x)` may be known, but the evidence
-\\(p(x) = \int p(z, x) dz\\) may be intractable. In that case, a parameterized
-distribution family \\(q_\lambda(z)\\) may be chosen, and the optimal \\(\lambda\\) is the
-one minimizing the KL divergence between \\(q_\lambda(z)\\) and
-\\(p(z | x)\\). We only know `p(z, x)`, but that is sufficient to find \\(\lambda\\).
-
-
-## Log-space evaluation and subtracting the maximum
-
-Care must be taken when the random variable lives in a high dimensional space.
-For example, the naive importance sample estimate \\(E_q[f(Z) p(Z) / q(Z)]\\)
-involves the ratio of two terms \\(p(Z) / q(Z)\\), each of which must have tails
-dropping off faster than \\(O(|z|^{-(k + 1)})\\) in order to have finite integral.
-This ratio would often be zero or infinity up to numerical precision.
-
-For that reason, we write
-
-$$Log E_q[ f(Z) p(Z) / q(Z) ]$$
-$$ = Log E_q[ \exp\{Log[f(Z)] + Log[p(Z)] - Log[q(Z)] - C\} ] + C,$$ where
-$$C := Max[ Log[f(Z)] + Log[p(Z)] - Log[q(Z)] ].$$
-
-The maximum value of the exponentiated term will be 0.0, and the expectation
-can be evaluated in a stable manner.
-
-## Ops
-
-* @{tf.contrib.bayesflow.monte_carlo.expectation}
-* @{tf.contrib.bayesflow.monte_carlo.expectation_importance_sampler}
-* @{tf.contrib.bayesflow.monte_carlo.expectation_importance_sampler_logspace}
diff --git a/tensorflow/docs_src/api_guides/python/contrib.distributions.bijectors.md b/tensorflow/docs_src/api_guides/python/contrib.distributions.bijectors.md
deleted file mode 100644
index e169897f31..0000000000
--- a/tensorflow/docs_src/api_guides/python/contrib.distributions.bijectors.md
+++ /dev/null
@@ -1,32 +0,0 @@
-# Random variable transformations (contrib)
-[TOC]
-
-Bijector Ops.
-
-An API for invertible, differentiable transformations of random variables.
-
-## Background
-
-Differentiable, bijective transformations of continuous random variables alter
-the calculations made in the cumulative/probability distribution functions and
-sample function. This module provides a standard interface for making these
-manipulations.
-
-For more details and examples, see the `Bijector` docstring.
-
-To apply a `Bijector`, use `distributions.TransformedDistribution`.
-
-## Bijectors
-
-* @{tf.contrib.distributions.bijectors.Affine}
-* @{tf.contrib.distributions.bijectors.AffineLinearOperator}
-* @{tf.contrib.distributions.bijectors.Bijector}
-* @{tf.contrib.distributions.bijectors.Chain}
-* @{tf.contrib.distributions.bijectors.CholeskyOuterProduct}
-* @{tf.contrib.distributions.bijectors.Exp}
-* @{tf.contrib.distributions.bijectors.Identity}
-* @{tf.contrib.distributions.bijectors.Inline}
-* @{tf.contrib.distributions.bijectors.Invert}
-* @{tf.contrib.distributions.bijectors.PowerTransform}
-* @{tf.contrib.distributions.bijectors.SoftmaxCentered}
-* @{tf.contrib.distributions.bijectors.Softplus}
diff --git a/tensorflow/docs_src/api_guides/python/contrib.distributions.md b/tensorflow/docs_src/api_guides/python/contrib.distributions.md
deleted file mode 100644
index 533d7dac13..0000000000
--- a/tensorflow/docs_src/api_guides/python/contrib.distributions.md
+++ /dev/null
@@ -1,83 +0,0 @@
-# Statistical Distributions (contrib)
-[TOC]
-
-Classes representing statistical distributions and ops for working with them.
-
-## Classes for statistical distributions
-
-Classes that represent batches of statistical distributions. Each class is
-initialized with parameters that define the distributions.
-
-## Base classes
-
-* @{tf.contrib.distributions.ReparameterizationType}
-* @{tf.contrib.distributions.Distribution}
-
-## Univariate (scalar) distributions
-
-* @{tf.contrib.distributions.Binomial}
-* @{tf.contrib.distributions.Bernoulli}
-* @{tf.contrib.distributions.Beta}
-* @{tf.contrib.distributions.Categorical}
-* @{tf.contrib.distributions.Chi2}
-* @{tf.contrib.distributions.Chi2WithAbsDf}
-* @{tf.contrib.distributions.Exponential}
-* @{tf.contrib.distributions.Gamma}
-* @{tf.contrib.distributions.InverseGamma}
-* @{tf.contrib.distributions.Laplace}
-* @{tf.contrib.distributions.LaplaceWithSoftplusScale}
-* @{tf.contrib.distributions.Normal}
-* @{tf.contrib.distributions.NormalWithSoftplusScale}
-* @{tf.contrib.distributions.Poisson}
-* @{tf.contrib.distributions.StudentT}
-* @{tf.contrib.distributions.StudentTWithAbsDfSoftplusScale}
-* @{tf.contrib.distributions.Uniform}
-
-## Multivariate distributions
-
-### Multivariate normal
-
-* @{tf.contrib.distributions.MultivariateNormalDiag}
-* @{tf.contrib.distributions.MultivariateNormalTriL}
-* @{tf.contrib.distributions.MultivariateNormalDiagPlusLowRank}
-* @{tf.contrib.distributions.MultivariateNormalDiagWithSoftplusScale}
-
-### Other multivariate distributions
-
-* @{tf.contrib.distributions.Dirichlet}
-* @{tf.contrib.distributions.DirichletMultinomial}
-* @{tf.contrib.distributions.Multinomial}
-* @{tf.contrib.distributions.WishartCholesky}
-* @{tf.contrib.distributions.WishartFull}
-
-### Multivariate Utilities
-
-* @{tf.contrib.distributions.matrix_diag_transform}
-
-## Transformed distributions
-
-* @{tf.contrib.distributions.TransformedDistribution}
-* @{tf.contrib.distributions.QuantizedDistribution}
-
-## Mixture Models
-
-* @{tf.contrib.distributions.Mixture}
-
-## Posterior inference with conjugate priors
-
-Functions that transform conjugate prior/likelihood pairs to distributions
-representing the posterior or posterior predictive.
-
-## Normal likelihood with conjugate prior
-
-* @{tf.contrib.distributions.normal_conjugates_known_scale_posterior}
-* @{tf.contrib.distributions.normal_conjugates_known_scale_predictive}
-
-## Kullback-Leibler Divergence
-
-* @{tf.contrib.distributions.kl_divergence}
-* @{tf.contrib.distributions.RegisterKL}
-
-## Utilities
-
-* @{tf.contrib.distributions.softplus_inverse}
diff --git a/tensorflow/docs_src/get_started/eager.md b/tensorflow/docs_src/get_started/eager.md
index f08ac74425..bbb25e20c6 100644
--- a/tensorflow/docs_src/get_started/eager.md
+++ b/tensorflow/docs_src/get_started/eager.md
@@ -1,3 +1,3 @@
# Get Started with Eager Execution
-[Colab notebook](https://colab.research.google.com/github/tensorflow/models/blob/r1.8.0/samples/core/get_started/eager.ipynb)
+[Colab notebook](https://colab.research.google.com/github/tensorflow/models/blob/r1.9.0/samples/core/get_started/eager.ipynb)
diff --git a/tensorflow/docs_src/install/install_c.md b/tensorflow/docs_src/install/install_c.md
index 1abd840ab3..2901848745 100644
--- a/tensorflow/docs_src/install/install_c.md
+++ b/tensorflow/docs_src/install/install_c.md
@@ -38,7 +38,7 @@ enable TensorFlow for C:
OS="linux" # Change to "darwin" for macOS
TARGET_DIRECTORY="/usr/local"
curl -L \
- "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-${TF_TYPE}-${OS}-x86_64-1.8.0.tar.gz" |
+ "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-${TF_TYPE}-${OS}-x86_64-1.9.0-rc0.tar.gz" |
sudo tar -C $TARGET_DIRECTORY -xz
The `tar` command extracts the TensorFlow C library into the `lib`
diff --git a/tensorflow/docs_src/install/install_go.md b/tensorflow/docs_src/install/install_go.md
index 52a2a3f8a6..55bc0f64e7 100644
--- a/tensorflow/docs_src/install/install_go.md
+++ b/tensorflow/docs_src/install/install_go.md
@@ -38,7 +38,7 @@ steps to install this library and enable TensorFlow for Go:
TF_TYPE="cpu" # Change to "gpu" for GPU support
TARGET_DIRECTORY='/usr/local'
curl -L \
- "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-${TF_TYPE}-$(go env GOOS)-x86_64-1.8.0.tar.gz" |
+ "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-${TF_TYPE}-$(go env GOOS)-x86_64-1.9.0-rc0.tar.gz" |
sudo tar -C $TARGET_DIRECTORY -xz
The `tar` command extracts the TensorFlow C library into the `lib`
diff --git a/tensorflow/docs_src/install/install_java.md b/tensorflow/docs_src/install/install_java.md
index 1256fb99c4..637231da12 100644
--- a/tensorflow/docs_src/install/install_java.md
+++ b/tensorflow/docs_src/install/install_java.md
@@ -36,7 +36,7 @@ following to the project's `pom.xml` to use the TensorFlow Java APIs:
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow</artifactId>
- <version>1.8.0</version>
+ <version>1.9.0-rc0</version>
</dependency>
```
@@ -65,7 +65,7 @@ As an example, these steps will create a Maven project that uses TensorFlow:
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow</artifactId>
- <version>1.8.0</version>
+ <version>1.9.0-rc0</version>
</dependency>
</dependencies>
</project>
@@ -124,12 +124,12 @@ instead:
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>libtensorflow</artifactId>
- <version>1.8.0</version>
+ <version>1.9.0-rc0</version>
</dependency>
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>libtensorflow_jni_gpu</artifactId>
- <version>1.8.0</version>
+ <version>1.9.0-rc0</version>
</dependency>
```
@@ -148,7 +148,7 @@ refer to the simpler instructions above instead.
Take the following steps to install TensorFlow for Java on Linux or macOS:
1. Download
- [libtensorflow.jar](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-1.8.0.jar),
+ [libtensorflow.jar](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-1.9.0-rc0.jar),
which is the TensorFlow Java Archive (JAR).
2. Decide whether you will run TensorFlow for Java on CPU(s) only or with
@@ -167,7 +167,7 @@ Take the following steps to install TensorFlow for Java on Linux or macOS:
OS=$(uname -s | tr '[:upper:]' '[:lower:]')
mkdir -p ./jni
curl -L \
- "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-${TF_TYPE}-${OS}-x86_64-1.8.0.tar.gz" |
+ "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-${TF_TYPE}-${OS}-x86_64-1.9.0-rc0.tar.gz" |
tar -xz -C ./jni
### Install on Windows
@@ -175,13 +175,13 @@ Take the following steps to install TensorFlow for Java on Linux or macOS:
Take the following steps to install TensorFlow for Java on Windows:
1. Download
- [libtensorflow.jar](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-1.8.0.jar),
+ [libtensorflow.jar](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-1.9.0-rc0.jar),
which is the TensorFlow Java Archive (JAR).
2. Download the following Java Native Interface (JNI) file appropriate for
- [TensorFlow for Java on Windows](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-cpu-windows-x86_64-1.8.0.zip).
+ [TensorFlow for Java on Windows](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-cpu-windows-x86_64-1.9.0-rc0.zip).
3. Extract this .zip file.
-
+__Note__: The native library (`tensorflow_jni.dll`) requires `msvcp140.dll` at runtime, which is included in the [Visual C++ 2015 Redistributable](https://www.microsoft.com/en-us/download/details.aspx?id=48145) package.
### Validate the installation
@@ -227,7 +227,7 @@ must be part of your `classpath`. For example, you can include the
downloaded `.jar` in your `classpath` by using the `-cp` compilation flag
as follows:
-<pre><b>javac -cp libtensorflow-1.8.0.jar HelloTF.java</b></pre>
+<pre><b>javac -cp libtensorflow-1.9.0-rc0.jar HelloTF.java</b></pre>
### Running
@@ -241,11 +241,11 @@ two files are available to the JVM:
For example, the following command line executes the `HelloTF` program on Linux
and macOS X:
-<pre><b>java -cp libtensorflow-1.8.0.jar:. -Djava.library.path=./jni HelloTF</b></pre>
+<pre><b>java -cp libtensorflow-1.9.0-rc0.jar:. -Djava.library.path=./jni HelloTF</b></pre>
And the following command line executes the `HelloTF` program on Windows:
-<pre><b>java -cp libtensorflow-1.8.0.jar;. -Djava.library.path=jni HelloTF</b></pre>
+<pre><b>java -cp libtensorflow-1.9.0-rc0.jar;. -Djava.library.path=jni HelloTF</b></pre>
If the program prints <tt>Hello from <i>version</i></tt>, you've successfully
installed TensorFlow for Java and are ready to use the API. If the program
diff --git a/tensorflow/docs_src/install/install_linux.md b/tensorflow/docs_src/install/install_linux.md
index 7b56b6a508..c8d706cf3c 100644
--- a/tensorflow/docs_src/install/install_linux.md
+++ b/tensorflow/docs_src/install/install_linux.md
@@ -436,7 +436,7 @@ Take the following steps to install TensorFlow in an Anaconda environment:
<pre>
(tensorflow)$ <b>pip install --ignore-installed --upgrade \
- https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.8.0-cp34-cp34m-linux_x86_64.whl</b></pre>
+ https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.9.0rc0-cp34-cp34m-linux_x86_64.whl</b></pre>
<a name="ValidateYourInstallation"></a>
## Validate your installation
@@ -682,14 +682,14 @@ This section documents the relevant values for Linux installations.
CPU only:
<pre>
-https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.8.0-cp27-none-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.9.0rc0-cp27-none-linux_x86_64.whl
</pre>
GPU support:
<pre>
-https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.8.0-cp27-none-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.9.0rc0-cp27-none-linux_x86_64.whl
</pre>
Note that GPU support requires the NVIDIA hardware and software described in
@@ -701,14 +701,14 @@ Note that GPU support requires the NVIDIA hardware and software described in
CPU only:
<pre>
-https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.8.0-cp34-cp34m-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.9.0rc0-cp34-cp34m-linux_x86_64.whl
</pre>
GPU support:
<pre>
-https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.8.0-cp34-cp34m-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.9.0rc0-cp34-cp34m-linux_x86_64.whl
</pre>
Note that GPU support requires the NVIDIA hardware and software described in
@@ -720,14 +720,14 @@ Note that GPU support requires the NVIDIA hardware and software described in
CPU only:
<pre>
-https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.8.0-cp35-cp35m-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.9.0rc0-cp35-cp35m-linux_x86_64.whl
</pre>
GPU support:
<pre>
-https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.8.0-cp35-cp35m-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.9.0rc0-cp35-cp35m-linux_x86_64.whl
</pre>
@@ -739,14 +739,14 @@ Note that GPU support requires the NVIDIA hardware and software described in
CPU only:
<pre>
-https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.8.0-cp36-cp36m-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.9.0rc0-cp36-cp36m-linux_x86_64.whl
</pre>
GPU support:
<pre>
-https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.8.0-cp36-cp36m-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.9.0rc0-cp36-cp36m-linux_x86_64.whl
</pre>
diff --git a/tensorflow/docs_src/install/install_mac.md b/tensorflow/docs_src/install/install_mac.md
index 29a867a9e3..9d01271c5a 100644
--- a/tensorflow/docs_src/install/install_mac.md
+++ b/tensorflow/docs_src/install/install_mac.md
@@ -119,7 +119,7 @@ Take the following steps to install TensorFlow with Virtualenv:
TensorFlow in the active Virtualenv is as follows:
<pre> $ <b>pip3 install --upgrade \
- https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.8.0-py3-none-any.whl</b></pre>
+ https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.9.0rc0-py3-none-any.whl</b></pre>
If you encounter installation problems, see
[Common Installation Problems](#common-installation-problems).
@@ -242,7 +242,7 @@ take the following steps:
issue the following command:
<pre> $ <b>sudo pip3 install --upgrade \
- https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.8.0-py3-none-any.whl</b> </pre>
+ https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.9.0rc0-py3-none-any.whl</b> </pre>
If the preceding command fails, see
[installation problems](#common-installation-problems).
@@ -350,7 +350,7 @@ Take the following steps to install TensorFlow in an Anaconda environment:
TensorFlow for Python 2.7:
<pre> (<i>targetDirectory</i>)$ <b>pip install --ignore-installed --upgrade \
- https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.8.0-py2-none-any.whl</b></pre>
+ https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.9.0rc0-py2-none-any.whl</b></pre>
<a name="ValidateYourInstallation"></a>
@@ -522,7 +522,7 @@ The value you specify depends on your Python version.
<pre>
-https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.8.0-py2-none-any.whl
+https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.9.0rc0-py2-none-any.whl
</pre>
@@ -530,5 +530,5 @@ https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.8.0-py2-none-any.
<pre>
-https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.8.0-py3-none-any.whl
+https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.9.0rc0-py3-none-any.whl
</pre>
diff --git a/tensorflow/docs_src/install/install_sources.md b/tensorflow/docs_src/install/install_sources.md
index cc29074757..dc6c1e36fc 100644
--- a/tensorflow/docs_src/install/install_sources.md
+++ b/tensorflow/docs_src/install/install_sources.md
@@ -328,10 +328,10 @@ Invoke `pip install` to install that pip package.
The filename of the `.whl` file depends on your platform.
For example, the following command will install the pip package
-for TensorFlow 1.8.0 on Linux:
+for TensorFlow 1.9.0rc0 on Linux:
<pre>
-$ <b>sudo pip install /tmp/tensorflow_pkg/tensorflow-1.8.0-py2-none-any.whl</b>
+$ <b>sudo pip install /tmp/tensorflow_pkg/tensorflow-1.9.0rc0-py2-none-any.whl</b>
</pre>
## Validate your installation
@@ -433,6 +433,8 @@ Stack Overflow and specify the `tensorflow` tag.
**Linux**
<table>
<tr><th>Version:</th><th>CPU/GPU:</th><th>Python Version:</th><th>Compiler:</th><th>Build Tools:</th><th>cuDNN:</th><th>CUDA:</th></tr>
+<tr><td>tensorflow-1.9.0</td><td>CPU</td><td>2.7, 3.3-3.6</td><td>GCC 4.8</td><td>Bazel 0.11.0</td><td>N/A</td><td>N/A</td></tr>
+<tr><td>tensorflow_gpu-1.9.0</td><td>GPU</td><td>2.7, 3.3-3.6</td><td>GCC 4.8</td><td>Bazel 0.11.0</td><td>7</td><td>9</td></tr>
<tr><td>tensorflow-1.8.0</td><td>CPU</td><td>2.7, 3.3-3.6</td><td>GCC 4.8</td><td>Bazel 0.10.0</td><td>N/A</td><td>N/A</td></tr>
<tr><td>tensorflow_gpu-1.8.0</td><td>GPU</td><td>2.7, 3.3-3.6</td><td>GCC 4.8</td><td>Bazel 0.9.0</td><td>7</td><td>9</td></tr>
<tr><td>tensorflow-1.7.0</td><td>CPU</td><td>2.7, 3.3-3.6</td><td>GCC 4.8</td><td>Bazel 0.10.0</td><td>N/A</td><td>N/A</td></tr>
@@ -456,6 +458,7 @@ Stack Overflow and specify the `tensorflow` tag.
**Mac**
<table>
<tr><th>Version:</th><th>CPU/GPU:</th><th>Python Version:</th><th>Compiler:</th><th>Build Tools:</th><th>cuDNN:</th><th>CUDA:</th></tr>
+<tr><td>tensorflow-1.9.0</td><td>CPU</td><td>2.7, 3.3-3.6</td><td>Clang from xcode</td><td>Bazel 0.11.0</td><td>N/A</td><td>N/A</td></tr>
<tr><td>tensorflow-1.8.0</td><td>CPU</td><td>2.7, 3.3-3.6</td><td>Clang from xcode</td><td>Bazel 0.10.1</td><td>N/A</td><td>N/A</td></tr>
<tr><td>tensorflow-1.7.0</td><td>CPU</td><td>2.7, 3.3-3.6</td><td>Clang from xcode</td><td>Bazel 0.10.1</td><td>N/A</td><td>N/A</td></tr>
<tr><td>tensorflow-1.6.0</td><td>CPU</td><td>2.7, 3.3-3.6</td><td>Clang from xcode</td><td>Bazel 0.8.1</td><td>N/A</td><td>N/A</td></tr>
@@ -472,6 +475,8 @@ Stack Overflow and specify the `tensorflow` tag.
**Windows**
<table>
<tr><th>Version:</th><th>CPU/GPU:</th><th>Python Version:</th><th>Compiler:</th><th>Build Tools:</th><th>cuDNN:</th><th>CUDA:</th></tr>
+<tr><td>tensorflow-1.9.0</td><td>CPU</td><td>3.5-3.6</td><td>MSVC 2015 update 3</td><td>Cmake v3.6.3</td><td>N/A</td><td>N/A</td></tr>
+<tr><td>tensorflow_gpu-1.9.0</td><td>GPU</td><td>3.5-3.6</td><td>MSVC 2015 update 3</td><td>Cmake v3.6.3</td><td>7</td><td>9</td></tr>
<tr><td>tensorflow-1.8.0</td><td>CPU</td><td>3.5-3.6</td><td>MSVC 2015 update 3</td><td>Cmake v3.6.3</td><td>N/A</td><td>N/A</td></tr>
<tr><td>tensorflow_gpu-1.8.0</td><td>GPU</td><td>3.5-3.6</td><td>MSVC 2015 update 3</td><td>Cmake v3.6.3</td><td>7</td><td>9</td></tr>
<tr><td>tensorflow-1.7.0</td><td>CPU</td><td>3.5-3.6</td><td>MSVC 2015 update 3</td><td>Cmake v3.6.3</td><td>N/A</td><td>N/A</td></tr>
diff --git a/tensorflow/docs_src/mobile/tflite/index.md b/tensorflow/docs_src/mobile/tflite/index.md
index 5622034827..3d1733024e 100644
--- a/tensorflow/docs_src/mobile/tflite/index.md
+++ b/tensorflow/docs_src/mobile/tflite/index.md
@@ -37,8 +37,9 @@ a custom (less-dynamic) memory allocator to ensure minimal load, initialization,
and execution latency.
TensorFlow Lite provides an interface to leverage hardware acceleration, if
-available on the device. It does so via the Android Neural Networks library,
-released as part of Android O-MR1.
+available on the device. It does so via the
+[Android Neural Networks API](https://developer.android.com/ndk/guides/neuralnetworks/index.html),
+available on Android 8.1 (API level 27) and higher.
## Why do we need a new mobile-specific library?
@@ -116,6 +117,10 @@ following:
Wear](https://research.googleblog.com/2017/02/on-device-machine-intelligence.html)
to all first-party and third-party apps.
+ Also see the complete list of
+ [TensorFlow Lite's supported models](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/g3doc/models.md),
+ including the model sizes, performance numbers, and downloadable model files.
+
- Quantized versions of the MobileNet model, which runs faster than the
non-quantized (float) version on CPU.
@@ -131,10 +136,10 @@ compatibility with this release.
## Getting Started
We recommend you try out TensorFlow Lite with the pre-tested models indicated
-above. If you have an existing mode, you will need to test whether your model is
-compatible with both the converter and the supported operator set. To test your
-model, see the [documentation on
-GitHub](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite).
+above. If you have an existing model, you will need to test whether your model
+is compatible with both the converter and the supported operator set. To test
+your model, see the
+[documentation on GitHub](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite).
### Retrain Inception-V3 or MobileNet for a custom data set
diff --git a/tensorflow/docs_src/programmers_guide/debugger.md b/tensorflow/docs_src/programmers_guide/debugger.md
index 6bd941886d..fc845c68f4 100644
--- a/tensorflow/docs_src/programmers_guide/debugger.md
+++ b/tensorflow/docs_src/programmers_guide/debugger.md
@@ -33,8 +33,9 @@ and [`inf`s](https://en.wikipedia.org/wiki/Infinity), a frequently-encountered
type of bug in TensorFlow model development.
The following example is for users who use the low-level
[`Session`](https://www.tensorflow.org/api_docs/python/tf/Session) API of
-TensorFlow. A later section of this document describes how to use **tfdbg**
-with a higher-level API, namely `Estimator`s.
+TensorFlow. Later sections of this document describe how to use **tfdbg**
+with higher-level APIs of TensorFlow, including `tf.estimator`,
+`tf.keras` / `keras` and `tf.contrib.slim`.
To *observe* such an issue, run the following command without the debugger (the
source code can be found
[here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/debug/examples/debug_mnist.py)):
@@ -477,20 +478,31 @@ for more details.
## Debugging Keras Models with TFDBG
-To use TFDBG with [Keras](https://keras.io/), let the Keras backend use
-a TFDBG-wrapped Session object. For example, to use the CLI wrapper:
+To use TFDBG with
+[tf.keras](https://www.tensorflow.org/api_docs/python/tf/keras),
+let the Keras backend use a TFDBG-wrapped Session object. For example, to use
+the CLI wrapper:
``` python
import tensorflow as tf
-from keras import backend as keras_backend
from tensorflow.python import debug as tf_debug
-keras_backend.set_session(tf_debug.LocalCLIDebugWrapperSession(tf.Session()))
+tf.keras.backend.set_session(tf_debug.LocalCLIDebugWrapperSession(tf.Session()))
# Define your keras model, called "model".
-model.fit(...) # This will break into the TFDBG CLI.
+
+# Calls to `fit()`, `evaluate()` and `predict()` methods will break into the
+# TFDBG CLI.
+model.fit(...)
+model.evaluate(...)
+model.predict(...)
```
+With minor modification, the preceding code example also works for the
+[non-TensorFlow version of Keras](https://keras.io/) running against a
+TensorFlow backend. You just need to replace `tf.keras.backend` with
+`keras.backend`.
+
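+The following minimal sketch illustrates that variant (it assumes the
+standalone `keras` package is installed and configured to use its TensorFlow
+backend):
+
+``` python
+import tensorflow as tf
+import keras
+from tensorflow.python import debug as tf_debug
+
+# Standalone Keras exposes the same backend hook for setting the session:
+keras.backend.set_session(tf_debug.LocalCLIDebugWrapperSession(tf.Session()))
+```
+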
## Debugging tf-slim with TFDBG
TFDBG supports debugging of training and evaluation with
diff --git a/tensorflow/docs_src/programmers_guide/keras.md b/tensorflow/docs_src/programmers_guide/keras.md
index 6a9df12a25..c6aca7ebf4 100644
--- a/tensorflow/docs_src/programmers_guide/keras.md
+++ b/tensorflow/docs_src/programmers_guide/keras.md
@@ -1,334 +1,304 @@
# Keras
-## What's Keras?
-
-Keras is a high-level API specification for building and training deep learning
-models, suitable for fast prototyping, advanced research, and production.
-It offers three key advantages:
-
-- **User friendliness.** Keras follows best practices for reducing
- cognitive load: it offers consistent & simple interfaces,
- it minimizes the number of user actions required for common use cases,
- and it provides clear and actionable feedback upon user error.
-- **Modularity and composability.** A Keras model is composed of
- fully-configurable building blocks that can be plugged together
- with as few restrictions as possible -- like Lego bricks.
-- **Easy extensibility.** You can easily write your own building blocks
- (such as new layers, new loss functions, new models where you write
- the forward pass from scratch). This allows for total expressiveness,
- making Keras suitable for advanced research.
-
-
-## What's tf.keras?
-
-`tf.keras` is TensorFlow's implementation of the Keras API specification, that
-serves as the TensorFlow high-level API: it's how you build models in TensorFlow.
-`tf.keras` seamlessly integrates with the rest of the TensorFlow API
-(such as `tf.data` input pipelines), bringing you the full power and flexibility
-of TensorFlow through an easy-to-use interface.
-
-You can import `tf.keras` via:
+Keras is a high-level API to build and train deep learning models. It's used for
+fast prototyping, advanced research, and production, with three key advantages:
+
+- *User friendly*<br>
+ Keras has a simple, consistent interface optimized for common use cases. It
+ provides clear and actionable feedback for user errors.
+- *Modular and composable*<br>
+ Keras models are made by connecting configurable building blocks together,
+ with few restrictions.
+- *Easy to extend*<br> Write custom building blocks to express new ideas for
+ research. Create new layers, loss functions, and develop state-of-the-art
+ models.
+
+## Import tf.keras
+
+`tf.keras` is TensorFlow's implementation of the
+[Keras API specification](https://keras.io){:.external}. This is a high-level
+API to build and train models that includes first-class support for
+TensorFlow-specific functionality, such as [eager execution](#eager_execution),
+`tf.data` pipelines, and [Estimators](/programmers_guide/estimators).
+`tf.keras` makes TensorFlow easier to use without sacrificing flexibility and
+performance.
+
+To get started, import `tf.keras` as part of your TensorFlow program setup:
```python
+import tensorflow as tf
from tensorflow import keras
```
-What follows is a quick introduction to the basics of `tf.keras`.
+`tf.keras` can run any Keras-compatible code, but keep in mind:
+* The `tf.keras` version in the latest TensorFlow release might not be the same
+ as the latest `keras` version from PyPI. Check `tf.keras.__version__`.
+* When [saving a model's weights](#weights_only), `tf.keras` defaults to the
+ [checkpoint format](/get_started/checkpoints). Pass `save_format='h5'` to use
+ HDF5.
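+
+As a minimal, self-contained sketch of both points above (the `h5py` package
+is assumed to be installed for the HDF5 save):
+
+```python
+import tensorflow as tf
+from tensorflow import keras
+
+# Compare the bundled tf.keras version against any `keras` release from PyPI.
+print(tf.keras.__version__)
+
+# Weights default to the TensorFlow checkpoint format; pass save_format='h5'
+# to write HDF5 instead.
+model = keras.Sequential([keras.layers.Dense(4, input_shape=(8,))])
+model.save_weights('my_model.h5', save_format='h5')
+```
+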
-## Table of contents
+## Build a simple model
-- [Getting started: the Sequential model](#getting-started-the-sequential-model)
-- [Configuring layers](#configuring-layers)
-- [Configuring training](#configuring-training)
-- [Training and evaluation](#training-and-evaluation)
-- [Building advanced models: the functional API](#building-advanced-models-the-functional-api)
-- [Building fully-customizable research models: the Model subclassing API](#building-fully-customizable-research-models-the-model-subclassing-api)
-- [Callbacks](#callbacks)
-- [Saving and serialization](#saving-and-serialization)
-- [Developing custom layers](#developing-custom-layers)
-- [Eager execution](#eager-execution)
-- [Further reading](#further-reading)
-- [FAQ](#faq)
+### Sequential model
+In Keras, you assemble *layers* to build *models*. A model is (usually) a graph
+of layers. The most common type of model is a stack of layers: the
+`tf.keras.Sequential` model.
----
-
-## Getting started: the Sequential model
-
-In `tf.keras`, you're assembling together **layers** to build **models**.
-A model is generally a graph of layers.
-The most common type of model is just a stack of layers: the `Sequential` class.
-
-Here's how to build a simple fully-connected network (multi-layer perceptron):
+To build a simple, fully-connected network (i.e. multi-layer perceptron):
```python
-from tensorflow import keras
-from tensorflow.keras import layers
-
model = keras.Sequential()
-# This adds to the model a densely-connected layer with 64 units:
-model.add(Dense(64, activation='relu'))
-# Another one:
-model.add(Dense(64, activation='relu'))
-# This adds a softmax layer with 10 output units:
-model.add(Dense(10, activation='softmax'))
+# Adds a densely-connected layer with 64 units to the model:
+model.add(keras.layers.Dense(64, activation='relu'))
+# Add another:
+model.add(keras.layers.Dense(64, activation='relu'))
+# Add a softmax layer with 10 output units:
+model.add(keras.layers.Dense(10, activation='softmax'))
```
----
-
-## Configuring layers
-
-Each layer may have unique constructor arguments, but some common arguments include:
+### Configure the layers
-- `activation`: the activation function to be used.
- It could be specified by name, as a string (for built-in functions)
- or as a callable object. By default, no activation is applied.
-- `kernel_initializer` and `bias_initializer`: the initialization schemes to use
- to create the layer's weights (kernel and bias).
- Likewise, they may be passed either by name or by specifying a callable.
- By default, the "Glorot uniform" initializer is used.
-- `kernel_regularizer` and `bias_regularizer`: the regularization schemes to
- apply to the layer's weights (kernel and bias), such as L1
- or L2 regularization. By default, no regularization is applied.
+There are many `tf.keras.layers` available with some common constructor
+parameters:
+* `activation`: Set the activation function for the layer. This parameter is
+ specified by the name of a built-in function or as a callable object. By
+ default, no activation is applied.
+* `kernel_initializer` and `bias_initializer`: The initialization schemes
+ that create the layer's weights (kernel and bias). This parameter is a name or
+ a callable object. This defaults to the `"Glorot uniform"` initializer.
+* `kernel_regularizer` and `bias_regularizer`: The regularization schemes
+ that apply to the layer's weights (kernel and bias), such as L1 or L2
+ regularization. By default, no regularization is applied.
-### Examples
+The following instantiates `tf.keras.layers.Dense` layers using constructor
+arguments:
```python
-import tensorflow as tf
-from tensorflow.keras.layers import Dense
-from tensorflow.keras import regularizers
-from tensorflow.keras import initializers
-
-# A sigmoid layer:
-Dense(64, activation='sigmoid')
-# Another way to define the same sigmoid layer:
-Dense(64, activation=tf.sigmoid)
-
-# A linear layer with L1 regularization of factor 0.01
-# applied to the kernel matrix:
-Dense(64, kernel_regularizer=regularizers.l1(0.01))
-# A linear layer with L2 regularization of factor 0.01
-# applied to the bias vector:
-Dense(64, bias_regularizer=regularizers.l2(0.01))
+# Create a sigmoid layer:
+keras.layers.Dense(64, activation='sigmoid')
+# Or:
+keras.layers.Dense(64, activation=tf.sigmoid)
+
+# A linear layer with L1 regularization of factor 0.01 applied to the kernel matrix:
+keras.layers.Dense(64, kernel_regularizer=keras.regularizers.l1(0.01))
+# A linear layer with L2 regularization of factor 0.01 applied to the bias vector:
+keras.layers.Dense(64, bias_regularizer=keras.regularizers.l2(0.01))
# A linear layer with a kernel initialized to a random orthogonal matrix:
-Dense(64, kernel_initializer='orthogonal')
+keras.layers.Dense(64, kernel_initializer='orthogonal')
# A linear layer with a bias vector initialized to 2.0s:
-Dense(64, bias_initializer=initializers.constant(2.0))
+keras.layers.Dense(64, bias_initializer=keras.initializers.constant(2.0))
```
----
+## Train and evaluate
-## Configuring training
+### Set up training
-Once your model looks good, configure its learning process by calling `compile`:
+After the model is constructed, configure its learning process by calling the
+`compile` method:
```python
-import tensorflow as tf
-
model.compile(optimizer=tf.train.AdamOptimizer(0.001),
loss='categorical_crossentropy',
metrics=['accuracy'])
```
-There are three key arguments that you need to specify:
+`tf.keras.Model.compile` takes three important arguments:
-- An `optimizer`: this object specifies the training procedure.
- We recommend that you pass instances of optimizers from the `tf.train` module
- (such as [`AdamOptimizer`](https://www.tensorflow.org/api_docs/python/tf/train/AdamOptimizer),
- [`RMSPropOptimizer`](https://www.tensorflow.org/api_docs/python/tf/train/RMSPropOptimizer),
- or [`GradientDescentOptimizer`](https://www.tensorflow.org/api_docs/python/tf/train/GradientDescentOptimizer)).
-- A `loss` function to minimize: this specifies the optimization objective.
- Common choices include mean square error (`mse`), `categorical_crossentropy`
- and `binary_crossentropy`. Loss functions may be specified by name
- or by passing a callable (e.g. from the `tf.keras.losses` module).
-- Some `metrics` to monitor during training: again, you can pass these as either
- string names or callables (e.g. from the `tf.keras.metrics` module).
+* `optimizer`: This object specifies the training procedure. Pass it optimizer
+ instances from the `tf.train` module, such as
+ [`AdamOptimizer`](/api_docs/python/tf/train/AdamOptimizer),
+ [`RMSPropOptimizer`](/api_docs/python/tf/train/RMSPropOptimizer), or
+ [`GradientDescentOptimizer`](/api_docs/python/tf/train/GradientDescentOptimizer).
+* `loss`: The function to minimize during optimization. Common choices include
+ mean square error (`mse`), `categorical_crossentropy`, and
+ `binary_crossentropy`. Loss functions are specified by name or by
+ passing a callable object from the `tf.keras.losses` module.
+* `metrics`: Used to monitor training. These are string names or callables from
+ the `tf.keras.metrics` module.
-
-### Examples
+The following shows a few examples of configuring a model for training:
```python
-# Configures a model to do mean-squared error regression.
+# Configure a model for mean-squared error regression.
model.compile(optimizer=tf.train.AdamOptimizer(0.01),
- loss='mse', # mean squared error
+ loss='mse', # mean squared error
metrics=['mae']) # mean absolute error
-```
-```python
-# Configures a model to do categorical classification.
+
+# Configure a model for categorical classification.
model.compile(optimizer=tf.train.RMSPropOptimizer(0.01),
- loss=tf.keras.losses.categorical_crossentropy,
- metrics=[tf.keras.metrics.categorical_accuracy])
+ loss=keras.losses.categorical_crossentropy,
+ metrics=[keras.metrics.categorical_accuracy])
```
----
-
-## Training and evaluation
+### Input NumPy data
-### From Numpy data
-
-When running locally on small datasets, the easiest way to do training and
-evaluation is to pass data to your model as Numpy arrays of inputs and targets.
-You can "fit" your model to some training data using the `model.fit()` method:
+For small datasets, use in-memory [NumPy](https://www.numpy.org/){:.external}
+arrays to train and evaluate a model. The model is "fit" to the training data
+using the `fit` method:
```python
import numpy as np
-data = np.random.random(shape=(1000, 32))
-targets = np.random.random(shape=(1000, 10))
+data = np.random.random((1000, 32))
+labels = np.random.random((1000, 10))
-model.fit(data, targets, epochs=10, batch_size=32)
+model.fit(data, labels, epochs=10, batch_size=32)
```
-Here are some key arguments you can pass to the `fit` method:
-
-- `epochs`: Training is structured into **epochs**. An epoch is one iteration
- over the entire input data (which is done in smaller batches).
-- `batch_size`: when passing Numpy data, the model will slice the data into
- smaller batches and iterate over these batches during training.
- This integer specifies the size of each batch
- (the last batch may be smaller if the total number of samples is not
- divisible by the batch size).
-- `validation_data`: when prototyping a model, you want to be able to quickly
- monitor its performance on some validation data.
- When you pass this argument (it expects a tuple of inputs and targets),
- the model will display the loss and metrics in inference mode on the data
- you passed, at the end of each epoch.
+`tf.keras.Model.fit` takes three important arguments:
+
+* `epochs`: Training is structured into *epochs*. An epoch is one iteration over
+ the entire input data (this is done in smaller batches).
+* `batch_size`: When passed NumPy data, the model slices the data into smaller
+ batches and iterates over these batches during training. This integer
+ specifies the size of each batch. Be aware that the last batch may be smaller
+ if the total number of samples is not divisible by the batch size.
+* `validation_data`: When prototyping a model, you want to easily monitor its
+ performance on some validation data. Passing this argument—a tuple of inputs
+ and labels—allows the model to display the loss and metrics in inference mode
+ for the passed data, at the end of each epoch.
Here's an example using `validation_data`:
```python
import numpy as np
-data = np.random.random(shape=(1000, 32))
-targets = np.random.random(shape=(1000, 10))
+data = np.random.random((1000, 32))
+labels = np.random.random((1000, 10))
-val_data = np.random.random(shape=(100, 32))
-val_targets = np.random.random(shape=(100, 10))
+val_data = np.random.random((100, 32))
+val_labels = np.random.random((100, 10))
-model.fit(data, targets, epochs=10, batch_size=32,
- validation_data=(val_data, val_targets))
+model.fit(data, labels, epochs=10, batch_size=32,
+ validation_data=(val_data, val_labels))
```
-### From tf.data datasets
+### Input tf.data datasets
-When you need to scale to large datasets or multi-device training,
-training from Numpy arrays in memory will not be ideal.
-In such cases, you should use [the `tf.data` API](https://www.tensorflow.org/programmers_guide/datasets).
-You can pass a `tf.data.Dataset` instance to the `fit` method:
+Use the [Datasets API](/programmers_guide/datasets) to scale to large datasets
+or multi-device training. Pass a `tf.data.Dataset` instance to the `fit`
+method:
```python
-import tensorflow as tf
-
# Instantiates a toy dataset instance:
-dataset = tf.data.Dataset.from_tensor_slices((data, targets)).batch(32)
+dataset = tf.data.Dataset.from_tensor_slices((data, labels))
+dataset = dataset.batch(32)
+dataset = dataset.repeat()
# Don't forget to specify `steps_per_epoch` when calling `fit` on a dataset.
model.fit(dataset, epochs=10, steps_per_epoch=30)
```
-When doing so, the dataset itself will yield batches of data,
-so the model does not need to be passed `batch_size` information.
-Instead, the model needs to know for how many steps (or batches of data)
-it should run at each epoch.
-You specify this with the `steps_per_epoch` argument: it's the number of
-training steps the model will run before moving on the next epoch.
+Here, the `fit` method uses the `steps_per_epoch` argument—this is the number of
+training steps the model runs before it moves to the next epoch. Since the
+`Dataset` yields batches of data, this snippet does not require a `batch_size`.
-You can also pass datasets for validation:
+Datasets can also be used for validation:
```python
-dataset = tf.data.Dataset.from_tensor_slices((data, targets)).batch(32)
-val_dataset = tf.data.Dataset.from_tensor_slices((val_data, val_targets)).batch(32)
+dataset = tf.data.Dataset.from_tensor_slices((data, labels))
+dataset = dataset.batch(32).repeat()
-model.fit(dataset, epochs=10, steps_per_epoch=30, validation_data=val_dataset, validation_steps=3)
+val_dataset = tf.data.Dataset.from_tensor_slices((val_data, val_labels))
+val_dataset = val_dataset.batch(32).repeat()
+
+model.fit(dataset, epochs=10, steps_per_epoch=30,
+ validation_data=val_dataset,
+ validation_steps=3)
```
### Evaluate and predict
-In addition, you get access to the following methods
-(both with Numpy data and dataset instances):
+The `tf.keras.Model.evaluate` and `tf.keras.Model.predict` methods can use NumPy
+data and a `tf.data.Dataset`.
-- `model.evaluate(x, y, batch_size=32)` or `model.evaluate(dataset, steps=30)`
- will return the inference-mode loss and metrics for the data provided.
-- `model.predict(x, y, batch_size=32)` or `model.predict(dataset, steps=30)`
- will return the output(s) of the last layer(s) in inference on the data
- provided, as Numpy array(s).
+To *evaluate* the inference-mode loss and metrics for the data provided:
----
+```python
+model.evaluate(x, y, batch_size=32)
-## Building advanced models: the functional API
+model.evaluate(dataset, steps=30)
+```
-The `Sequential` model cannot represent arbitrary models -- only simple stacks
-of layers. If you need to use more complex model topologies,
-such as multi-input models, multi-output models,
-models with a same layer called several times (shared layers),
-or models with non-sequential data flows (e.g. residual connections),
-you can use the 'functional API'.
+And to *predict* the output of the last layer in inference for the data provided,
+as a NumPy array:
-Here's how it works:
+```python
+model.predict(x, batch_size=32)
-- A layer instance is callable (on a tensor), and it returns a tensor.
-- Input tensor(s) and output tensor(s) can then be used to define a `Model` instance.
-- Such a model can be trained just like the `Sequential` model.
+model.predict(dataset, steps=30)
+```
-Here's a basic example showing the same model we previously defined,
-built using the functional API:
+## Build advanced models
-```python
-from tensorflow import keras
-from tensorflow.keras import layers
+### Functional API
-# This returns a placeholder tensor:
-inputs = keras.Input(shape=(784,))
+The `tf.keras.Sequential` model is a simple stack of layers that cannot
+represent arbitrary models. Use the
+[Keras functional API](https://keras.io/getting-started/functional-api-guide/){:.external}
+to build complex model topologies such as:
+
+* Multi-input models,
+* Multi-output models,
+* Models with shared layers (the same layer called several times),
+* Models with non-sequential data flows (e.g. residual connections).
+
+Building a model with the functional API works like this:
+
+1. A layer instance is callable and returns a tensor.
+2. Input tensors and output tensors are used to define a `tf.keras.Model`
+ instance.
+3. This model is trained just like the `Sequential` model.
+
+The following example uses the functional API to build a simple, fully-connected
+network:
+
+```python
+inputs = keras.Input(shape=(32,)) # Returns a placeholder tensor
# A layer instance is callable on a tensor, and returns a tensor.
-x = layers.Dense(64, activation='relu')(inputs)
-x = layers.Dense(64, activation='relu')(x)
-predictions = layers.Dense(10, activation='softmax')(x)
+x = keras.layers.Dense(64, activation='relu')(inputs)
+x = keras.layers.Dense(64, activation='relu')(x)
+predictions = keras.layers.Dense(10, activation='softmax')(x)
-# Instantiates the model given inputs and outputs.
+# Instantiate the model given inputs and outputs.
model = keras.Model(inputs=inputs, outputs=predictions)
-# The "compile" step specifies the training configuration.
-model.compile(optimizer='rmsprop',
+# The compile step specifies the training configuration.
+model.compile(optimizer=tf.train.RMSPropOptimizer(0.001),
loss='categorical_crossentropy',
metrics=['accuracy'])
-# Trains for 5 epochs.
+# Trains for 5 epochs
model.fit(data, labels, batch_size=32, epochs=5)
```
-This API enables you to create models with multiple inputs and outputs,
-and to "share" layers across different inputs
-(i.e. to reuse a same instance multiple times).
-For examples of these use cases,
-please see [this guide to the functional API in Keras](https://keras.io/getting-started/functional-api-guide/).
+### Model subclassing
----
+Build a fully-customizable model by subclassing `tf.keras.Model` and defining
+your own forward pass. Create layers in the `__init__` method and set them as
+attributes of the class instance. Define the forward pass in the `call` method.
-## Building fully-customizable research models: the Model subclassing API
+Model subclassing is particularly useful when
+[eager execution](/programmers_guide/eager) is enabled since the forward pass
+can be written imperatively.
-Besides `Sequential` and the functional API, one last, more flexible way to
-define models is to directly subclass the `Model` class and define your own
-forward pass manually.
+Key Point: Use the right API for the job. While model subclassing offers
+flexibility, it comes at a cost of greater complexity and more opportunities for
+user errors. If possible, prefer the functional API.
-In this API, you instante layers in `__init__` and set them as attribute of the
-class instance. Then you specify the forward pass in `call`.
-This API is particularly valuable when using TensorFlow with [eager execution](https://www.tensorflow.org/programmers_guide/eager),
-since eager execution allows you to write your forward pass in an
-imperative fashion (as if you were writing Numpy code, for instance).
+The following example shows a subclassed `tf.keras.Model` using a custom forward
+pass:
```python
-import tensorflow as tf
-from tensorflow import keras
-
-
class MyModel(keras.Model):
- def __init__(self, num_classes=2):
+ def __init__(self, num_classes=10):
super(MyModel, self).__init__(name='my_model')
self.num_classes = num_classes
# Define your layers here.
@@ -351,10 +321,10 @@ class MyModel(keras.Model):
# Instantiates the subclassed model.
-model = MyModel(num_classes=2)
+model = MyModel(num_classes=10)
-# The "compile" step specifies the training configuration.
-model.compile(optimizer='rmsprop',
+# The compile step specifies the training configuration.
+model.compile(optimizer=tf.train.RMSPropOptimizer(0.001),
loss='categorical_crossentropy',
metrics=['accuracy'])
@@ -362,353 +332,291 @@ model.compile(optimizer='rmsprop',
model.fit(data, labels, batch_size=32, epochs=5)
```
-**Remember:** use the right API for the right job.
-Using the `Model` subclassing API offers more flexibility,
-but at the cost of greater complexity and a larger potential user error surface.
-Prefer using the functional API when possible.
----
+### Custom layers
-## Callbacks
+Create a custom layer by subclassing `tf.keras.layers.Layer` and implementing
+the following methods:
-Callbacks are objects that you can pass to your model that customize and extend
-its behavior during training.
-There are callbacks for saving checkpoints of your model at regular intervals
-(`tf.keras.callbacks.ModelCheckpoint`),
-to dynamically change the learning rate (`tf.keras.callbacks.LearningRateScheduler`)
-or to interrupt training when validation performance has stopped improving
-(`tf.keras.callbacks.EarlyStopping`).
-You can also use a callback to monitor your model's behavior using
-[TensorBoard](https://www.tensorflow.org/programmers_guide/summaries_and_tensorboard)
-(`tf.keras.callbacks.TensorBoard`).
-You can also write your own custom callbacks.
-
-Different built-in callback are found in `tf.keras.callbacks`.
-You use them by passing a `Callback` instance to `fit`:
+* `build`: Create the weights of the layer. Add weights with the `add_weight`
+ method.
+* `call`: Define the forward pass.
+* `compute_output_shape`: Specify how to compute the output shape of the layer
+ given the input shape.
+* Optionally, a layer can be serialized by implementing the `get_config` method
+ and the `from_config` class method.
+
+Here's an example of a custom layer that implements a `matmul` of an input with
+a kernel matrix:
```python
-from tensorflow import keras
+class MyLayer(keras.layers.Layer):
+
+ def __init__(self, output_dim, **kwargs):
+ self.output_dim = output_dim
+ super(MyLayer, self).__init__(**kwargs)
+
+ def build(self, input_shape):
+ shape = tf.TensorShape((input_shape[1], self.output_dim))
+ # Create a trainable weight variable for this layer.
+ self.kernel = self.add_weight(name='kernel',
+ shape=shape,
+ initializer='uniform',
+ trainable=True)
+ # Be sure to call this at the end
+ super(MyLayer, self).build(input_shape)
-callbacks = [
- # Interrupt training if `val_loss` stops improving for over 2 epochs
- keras.callbacks.EarlyStopping(patience=2, monitor='val_loss'),
- # Write TensorBoard logs to `./logs` directory
- keras.callbacks.TensorBoard(log_dir='./logs')
-]
-model.fit(data, labels, batch_size=32, epochs=5, callbacks=callbacks)
-```
+ def call(self, inputs):
+ return tf.matmul(inputs, self.kernel)
----
+ def compute_output_shape(self, input_shape):
+ shape = tf.TensorShape(input_shape).as_list()
+ shape[-1] = self.output_dim
+ return tf.TensorShape(shape)
-## Saving and serialization
+ def get_config(self):
+ base_config = super(MyLayer, self).get_config()
+ base_config['output_dim'] = self.output_dim
+ return base_config
-### Weights-only saving
+ @classmethod
+ def from_config(cls, config):
+ return cls(**config)
-You can save the weight values of a model via `model.save_weights(filepath)`:
-```python
-# Saves weights to a SavedModel file.
-model.save_weights('my_model')
+# Create a model using the custom layer
+model = keras.Sequential([MyLayer(10),
+ keras.layers.Activation('softmax')])
-# Restores the model's state
-# (this requires a model that has the same architecture).
-model.load_weights('my_model')
+# The compile step specifies the training configuration
+model.compile(optimizer=tf.train.RMSPropOptimizer(0.001),
+ loss='categorical_crossentropy',
+ metrics=['accuracy'])
+
+# Trains for 5 epochs.
+model.fit(data, labels, batch_size=32, epochs=5)
```
-By default, this saves the weight in the TensorFlow
-[`SavedModel`](https://www.tensorflow.org/programmers_guide/saved_model) format.
-You could also save them in the Keras HDF5 format
-(which is the default in the multi-backend implementation of Keras):
-```python
-# Saves weights to a HDF5 file.
-model.save_weights('my_model.h5', format='h5')
+## Callbacks
-# Restores the model's state.
-model.load_weights('my_model.h5')
-```
+A callback is an object passed to a model to customize and extend its behavior
+during training. You can write your own custom callback, or use the built-in
+`tf.keras.callbacks` that include:
-### Configuration-only saving (serialization)
+* `tf.keras.callbacks.ModelCheckpoint`: Save checkpoints of your model at
+ regular intervals.
+* `tf.keras.callbacks.LearningRateScheduler`: Dynamically change the learning
+ rate.
+* `tf.keras.callbacks.EarlyStopping`: Interrupt training when validation
+ performance has stopped improving.
+* `tf.keras.callbacks.TensorBoard`: Monitor the model's behavior using
+ [TensorBoard](/programmers_guide/summaries_and_tensorboard).
-You can also save the model's configuration
-(its architecture, without any weight values),
-which allows you to recreate the same model later (freshly initialized) even if
-you don't have the code that defined it anymore.
-Two possible serialization formats are JSON and YAML:
+To use a `tf.keras.callbacks.Callback`, pass it to the model's `fit` method:
```python
-from tensorflow.keras import models
-
-# Serializes a model to JSON.
-json_string = model.to_json()
-# Recreates the model (freshly initialized).
-fresh_model = models.from_json(json_string)
-
-# Serializes a model to YAML.
-yaml_string = model.to_yaml()
-# Recreates the model.
-fresh_model = models.from_yaml(yaml_string)
+callbacks = [
+ # Interrupt training if `val_loss` stops improving for over 2 epochs
+ keras.callbacks.EarlyStopping(patience=2, monitor='val_loss'),
+ # Write TensorBoard logs to `./logs` directory
+ keras.callbacks.TensorBoard(log_dir='./logs')
+]
+model.fit(data, labels, batch_size=32, epochs=5, callbacks=callbacks,
+          validation_data=(val_data, val_labels))
```
-Note that this feature is not available with subclassed models,
-because they are simply not serializable:
-their architecture is defined as Python code
-(the body of the `call` method of the model).
-### Whole-model saving
+## Save and restore
-Finally, you can also save a model wholesale, to a file that will contain both
-the weight values, the model's configuration,
-and even the optimizer's configuration.
-The allows you to checkpoint a model and resume training later --
-from the exact same state -- even if you don't have access to the original code.
+### Weights only
-```python
-from tensorflow.keras import models
+Save and load the weights of a model using `tf.keras.Model.save_weights`:
-model.save('my_model.h5')
+```python
+# Save weights to a TensorFlow Checkpoint file
+model.save_weights('./my_model')
-# Recreates the exact same model, complete with weights and optimizer.
-model = models.load_model('my_model.h5')
+# Restore the model's state,
+# this requires a model with the same architecture.
+model.load_weights('my_model')
```
----
-
-## Developing custom layers
-
-You can write your own custom layers by subclassing the class
-`tf.keras.layers.Layer`. You will need to implement the following three methods:
-
-- `build`: Creates the weights of the layer.
- Weights should be added via the `add_weight` method.
-- `call`: Specifies the forward pass.
-- `compute_output_shape`: Specifies how to compute the output shape of the layer
- given the input shape.
-
-Optionally, you may also implement the method `get_config()` and the
-class method `from_config()` if you want your layer to be serializable.
-
-Here's a simple example of a custom layer that implements a `matmul`
-of an input with a kernel matrix:
+By default, this saves the model's weights in the
+[TensorFlow checkpoint](/get_started/checkpoints) file format. Weights can also
+be saved to the Keras HDF5 format (the default for the multi-backend
+implementation of Keras):
```python
-import tensorflow as tf
-from tensorflow.keras import layers
-
-class MyLayer(layers.Layer):
-
- def __init__(self, output_dim, **kwargs):
- self.output_dim = output_dim
- super(MyLayer, self).__init__(**kwargs)
-
- def build(self, input_shape):
- # Create a trainable weight variable for this layer.
- self.kernel = self.add_weight(name='kernel',
- shape=(input_shape[1], self.output_dim),
- initializer='uniform',
- trainable=True)
- # Be sure to call this at the end
- super(MyLayer, self).build(input_shape)
-
- def call(self, inputs):
- return tf.matmul(inputs, self.kernel)
-
- def compute_output_shape(self, input_shape):
- shape = tf.TensorShape(input_shape).as_list()
- shape[-1] = self.output_dim
- return tf.TensorShape(shape)
-
- def get_config(self):
- base_config = super(MyLayer, self).get_config()
- base_config['output_dim'] = self.output_dim
-
- @classmethod
- def from_config(cls, config):
- return cls(**config)
-```
+# Save weights to a HDF5 file
+model.save_weights('my_model.h5', save_format='h5')
----
-
-## Eager execution
+# Restore the model's state
+model.load_weights('my_model.h5')
+```
-[Eager execution](https://www.tensorflow.org/programmers_guide/eager)
-is a way to write TensorFlow code imperatively.
-All three `tf.keras` model-building APIs
-(`Sequential`, the functional API `Model(inputs, outputs)`,
-and the subclassing API `MyModel(Model)`) are compatible with eager execution.
-When using `Sequential` or the functional API, it makes no difference to the
-user experience whether the model is executing eagerly or not.
-Eager execution is most beneficial when used with the `Model` subclassing API,
-or when prototyping a custom layer -- that is to say, in APIs that require you
-to *write a forward pass as code*, rather than in APIs that allow you to create
-models by assembling together existing layers.
+### Configuration only
-While the same training and evaluating APIs presented in this guide work
-as usual with eager execution, you can in addition
-write custom training loops using the eager `GradientTape`
-and define-by-run autodifferentiation:
+A model's configuration can be saved—this serializes the model architecture
+without any weights. A saved configuration can recreate and initialize the same
+model, even without the code that defined the original model. Keras supports
+JSON and YAML serialization formats:
```python
-import tensorflow as tf
-from tensorflow.contrib import eager as tfe
-
-# This call begins the eager execution session.
-tf.enable_eager_execution()
-
-model = ... # Defines a Keras model (we recommend Model subclassing in this case).
-dataset = ... # Defines a `tf.data` dataset.
+# Serialize a model to JSON format
+json_string = model.to_json()
-optimizer = tf.train.AdamOptimizer(0.01)
+# Recreate the model (freshly initialized)
+fresh_model = keras.models.model_from_json(json_string)
-for data, labels in dataset:
- # Runs the forward pass and loss computation under a `GradientTape` scope,
- # which will record all operations in order to prepare for the backward pass.
- with tfe.GradientTape() as tape:
- predictions = model(data)
- loss = loss_function(labels, predictions)
+# Serializes a model to YAML format
+yaml_string = model.to_yaml()
- # Runs the backward pass manually using the operations recorded
- # by the gradient tape.
- grads = tape.gradient(loss, model.trainable_weights)
- optimizer.apply_gradients(zip(grads, model.trainable_weights),
- global_step=tf.train.get_or_create_global_step())
+# Recreate the model
+fresh_model = keras.models.model_from_yaml(yaml_string)
```
----
+Caution: Subclassed models are not serializable because their architecture is
+defined by the Python code in the body of the `call` method.
-## Further reading
-### Documentation
+### Entire model
-- [tf.keras documentation](https://www.tensorflow.org/api_docs/python/tf/keras)
-- [keras.io](https://keras.io/)
+The entire model can be saved to a file that contains the weight values, the
+model's configuration, and even the optimizer's configuration. This allows you
+to checkpoint a model and resume training later—from the exact same
+state—without access to the original code.
-### tf.keras tutorials and examples
-
-- [Fashion-MNIST with tf.Keras](https://medium.com/tensorflow/hello-deep-learning-fashion-mnist-with-keras-50fcff8cd74a)
-- [Predicting the price of wine with the Keras Functional API and TensorFlow](
- https://medium.com/tensorflow/predicting-the-price-of-wine-with-the-keras-functional-api-and-tensorflow-a95d1c2c1b03)
+```python
+# Create a trivial model
+model = keras.Sequential([
+ keras.layers.Dense(10, activation='softmax', input_shape=(32,)),
+ keras.layers.Dense(10, activation='softmax')
+])
+model.compile(optimizer='rmsprop',
+ loss='categorical_crossentropy',
+ metrics=['accuracy'])
+model.fit(data, labels, batch_size=32, epochs=5)
----
+# Save entire model to a HDF5 file
+model.save('my_model.h5')
-## FAQ
+# Recreate the exact same model, including weights and optimizer.
+model = keras.models.load_model('my_model.h5')
+```
-### What are the differences between tf.keras and the multi-backend Keras implementation?
-`tf.keras` includes first-class support for important TensorFlow-specific
-functionality not found in other Keras implementations, in particular:
+## Eager execution
-- Support for eager execution.
-- Support for the `tf.data` API.
-- Integration with the
- [`tf.estimator` API](https://www.tensorflow.org/programmers_guide/estimators),
- via `tf.keras.estimator.model_to_estimator`.
+[Eager execution](/programmers_guide/eager) is an imperative programming
+environment that evaluates operations immediately. This is not required for
+Keras, but is supported by `tf.keras` and useful for inspecting your program and
+debugging.
-In terms of API differences: `tf.keras` is a full implementation of the
-Keras API, so any code targeting the Keras API will run on `tf.keras`.
-However, keep in mind that:
+All of the `tf.keras` model-building APIs are compatible with eager execution.
+And while the `Sequential` and functional APIs can be used, eager execution
+especially benefits *model subclassing* and building *custom layers*—the APIs
+that require you to write the forward pass as code (instead of the APIs that
+create models by assembling existing layers).
-- The `tf.keras` API version in the latest TensorFlow release might not be the
- same as the latest `keras` version from PyPI.
- Check out `tf.keras.__version__` if in doubt.
-- In `tf.keras`, the default file format saved by `model.save_weights` is the
- TensorFlow `SavedModel` format.
- To use HDF5, you can pass the `format='h5'` argument.
+See the [eager execution guide](/programmers_guide/eager#build_a_model) for
+examples of using Keras models with custom training loops and `tf.GradientTape`.
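+
+As a minimal sketch (assuming eager execution has been enabled with
+`tf.enable_eager_execution()`, and reusing a `model` and a finite `dataset`
+like those built in the earlier sections):
+
+```python
+loss_function = keras.losses.categorical_crossentropy
+optimizer = tf.train.AdamOptimizer(0.01)
+
+for inputs, labels in dataset:
+  # Record the forward pass so gradients can be computed for the backward pass.
+  with tf.GradientTape() as tape:
+    predictions = model(inputs)
+    loss = tf.reduce_mean(loss_function(labels, predictions))
+  # Compute and apply gradients to the model's trainable weights.
+  grads = tape.gradient(loss, model.trainable_weights)
+  optimizer.apply_gradients(zip(grads, model.trainable_weights),
+                            global_step=tf.train.get_or_create_global_step())
+```
+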
-### What is the relationship between tf.keras and tf.estimator?
+## Distribution
-The [`tf.estimator` API](https://www.tensorflow.org/programmers_guide/estimators)
-is a high-level TensorFlow API for training "estimator" models,
-in particular in distributed settings.
-This API targets industry use cases, such as distributed training
-on large datasets with a focus on eventually exporting a production model.
+### Estimators
-If you have a `tf.keras` model that would like to train with the `tf.estimator`
-API, you can convert your model to an `Estimator` object via the
-`model_to_estimator` utility](https://www.tensorflow.org/programmers_guide/estimators#creating_estimators_from_keras_models):
+The [Estimators](/programmers_guide/estimators) API is used for training models
+for distributed environments. This targets industry use cases such as
+distributed training on large datasets that can export a model for production.
+A `tf.keras.Model` can be trained with the `tf.estimator` API by converting the
+model to an `tf.estimator.Estimator` object with
+`tf.keras.estimator.model_to_estimator`. See
+[Creating Estimators from Keras models](/programmers_guide/estimators#creating_estimators_from_keras_models).
```python
-estimator = tf.keras.estimator.model_to_estimator(model)
-```
+model = keras.Sequential([keras.layers.Dense(10, activation='softmax'),
+                          keras.layers.Dense(10, activation='softmax')])
-When using `model_to_estimator`, enabling eager execution is helpful for
-developing and debugging your `input_fn`
-(as it allows you to easily print your data).
+model.compile(optimizer=tf.train.RMSPropOptimizer(0.001),
+ loss='categorical_crossentropy',
+ metrics=['accuracy'])
+
+estimator = keras.estimator.model_to_estimator(model)
+```
+Note: Enable [eager execution](/programmers_guide/eager) for debugging
+[Estimator input functions](/programmers_guide/premade_estimators#create_input_functions)
+and inspecting data.
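+
+For example, with eager execution enabled you can call an input function
+directly and look at the data it produces. This is only an illustrative sketch;
+it assumes `numpy` is imported as `np` and uses `tf.contrib.eager.Iterator` to
+step through the dataset eagerly:
+
+```python
+tf.enable_eager_execution()
+
+def input_fn():
+  x = np.random.random((64, 10)).astype(np.float32)
+  y = np.random.randint(2, size=(64, 1))
+  return tf.data.Dataset.from_tensor_slices((x, y)).batch(8)
+
+# Pull the first batch the Estimator would receive and print it.
+features, labels = next(tf.contrib.eager.Iterator(input_fn()))
+print(features.shape, labels[:3])
+```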
-### How can I run tf.keras models on multiple GPUs?
+### Multiple GPUs
-You can run tf.keras models on multiple GPUs using the
-[`DistributionStrategy API`](https://www.tensorflow.org/versions/master/api_docs/python/tf/contrib/distribute/DistributionStrategy).
-The `DistributionStrategy` API allow you to distribute training on multiple GPUs
-with almost no changes to your existing code.
+`tf.keras` models can run on multiple GPUs using
+`tf.contrib.distribute.DistributionStrategy`. This API provides distributed
+training on multiple GPUs with almost no changes to existing code.
-Currently [`MirroredStrategy`](https://www.tensorflow.org/versions/master/api_docs/python/tf/contrib/distribute/MirroredStrategy)
-is the only supported strategy.
-`MirroredStrategy` allows you to do in-graph replication with synchronous
-training using all-reduce on a single machine.
-To use `DistributionStrategy` with a `tf.keras` model,
-you can use the `model_to_estimator` utility to convert a `tf.keras` model to
-an `Estimator` and then train the estimator.
+Currently, `tf.contrib.distribute.MirroredStrategy` is the only supported
+distribution strategy. `MirroredStrategy` does in-graph replication with
+synchronous training using all-reduce on a single machine. To use
+`DistributionStrategy` with Keras, convert the `tf.keras.Model` to a
+`tf.estimator.Estimator` with `tf.keras.estimator.model_to_estimator`, then
+train the estimator.
-Here is a simple example of distributing a `tf.keras` model across multiple GPUs
-on a single machine.
+The following example distributes a `tf.keras.Model` across multiple GPUs on a
+single machine.
-Let's first define a simple model:
+First, define a simple model:
```python
-model = tf.keras.Sequential()
-model.add(tf.keras.layers.Dense(16, activation='relu', input_shape=(10,)))
-model.add(tf.keras.layers.Dense(1, activation='sigmoid'))
+model = keras.Sequential()
+model.add(keras.layers.Dense(16, activation='relu', input_shape=(10,)))
+model.add(keras.layers.Dense(1, activation='sigmoid'))
+
optimizer = tf.train.GradientDescentOptimizer(0.2)
+
model.compile(loss='binary_crossentropy', optimizer=optimizer)
model.summary()
```
-Let's use `model_to_estimator` to create an `Estimator` instance from the
-`tf.keras` model defined above.
+Convert the Keras model to a `tf.estimator.Estimator` instance:
```python
-keras_estimator = tf.keras.estimator.model_to_estimator(
- keras_model=model,
- config=config,
- model_dir='/tmp/model_dir')
+keras_estimator = keras.estimator.model_to_estimator(
+ keras_model=model,
+ config=config,
+ model_dir='/tmp/model_dir')
```
-We'll use `tf.data.Datasets` to define our input pipeline.
-Our `input_fn` returns a `tf.data.Dataset` object that we then use to distribute
-the data across multiple devices with each device processing
+Define an *input pipeline*. The `input_fn` returns a `tf.data.Dataset` object
+used to distribute the data across multiple devices—with each device processing
a slice of the input batch.
```python
def input_fn():
- x = np.random.random((1024, 10))
- y = np.random.randint(2, size=(1024, 1))
- x = tf.cast(x, tf.float32)
- dataset = tf.data.Dataset.from_tensor_slices((x, y))
- dataset = dataset.repeat(10)
- dataset = dataset.batch(32)
- return dataset
+ x = np.random.random((1024, 10))
+ y = np.random.randint(2, size=(1024, 1))
+ x = tf.cast(x, tf.float32)
+ dataset = tf.data.Dataset.from_tensor_slices((x, y))
+ dataset = dataset.repeat(10)
+ dataset = dataset.batch(32)
+ return dataset
```
-The next step is to create a `RunConfig` and set the train_distribute argument
-to the new `MirroredStrategy` instance.
-You can specify a list of devices or the `num_gpus` argument when creating
-a `MirroredStrategy` instance.
-Not specifying any arguments defaults to using all the available GPUs like we do
-in this example.
+Next, create a `tf.estimator.RunConfig` and set the `train_distribute` argument
+to the `tf.contrib.distribute.MirroredStrategy` instance. When creating
+`MirroredStrategy`, you can specify a list of devices or set the `num_gpus`
+argument. If neither is specified, all available GPUs are used, as in the
+following example:
```python
strategy = tf.contrib.distribute.MirroredStrategy()
config = tf.estimator.RunConfig(train_distribute=strategy)
```
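+
+Alternatively, a fixed number of GPUs can be requested with the `num_gpus`
+argument mentioned above (an illustrative sketch; adjust the count to your
+hardware):
+
+```python
+# Mirror the model across exactly two GPUs instead of all available devices.
+strategy = tf.contrib.distribute.MirroredStrategy(num_gpus=2)
+config = tf.estimator.RunConfig(train_distribute=strategy)
+```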
-Call train on the `Estimator` instance providing the `input_fn` and `steps`
-arguments as input:
+Finally, train the `Estimator` instance by providing the `input_fn` and `steps`
+arguments:
```python
keras_estimator.train(input_fn=input_fn, steps=10)
diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go
index 016f572c3f..d2593587e6 100644
--- a/tensorflow/go/op/wrappers.go
+++ b/tensorflow/go/op/wrappers.go
@@ -2892,6 +2892,104 @@ func DiagPart(scope *Scope, input tf.Output) (diagonal tf.Output) {
return op.Output(0)
}
+// Gives a guarantee to the TF runtime that the input tensor is a constant.
+//
+// The runtime is then free to make optimizations based on this.
+//
+// Only accepts value typed tensors as inputs and rejects resource variable handles
+// as input.
+//
+// Returns the input tensor without modification.
+func GuaranteeConst(scope *Scope, input tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "GuaranteeConst",
+ Input: []tf.Input{
+ input,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Splits a tensor into `num_split` tensors along one dimension.
+//
+// Arguments:
+// value: The tensor to split.
+// size_splits: list containing the sizes of each output tensor along the split
+// dimension. Must sum to the dimension of value along split_dim.
+// Can contain one -1 indicating that dimension is to be inferred.
+// axis: 0-D. The dimension along which to split. Must be in the range
+// `[-rank(value), rank(value))`.
+//
+//
+// Returns Tensors whose shape matches that of `value`
+// except along `axis`, where their sizes are
+// `size_splits[i]`.
+func SplitV(scope *Scope, value tf.Output, size_splits tf.Output, axis tf.Output, num_split int64) (output []tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"num_split": num_split}
+ opspec := tf.OpSpec{
+ Type: "SplitV",
+ Input: []tf.Input{
+ value, size_splits, axis,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ if scope.Err() != nil {
+ return
+ }
+ var idx int
+ var err error
+ if output, idx, err = makeOutputList(op, idx, "output"); err != nil {
+ scope.UpdateErr("SplitV", err)
+ return
+ }
+ return output
+}
+
+// Splits a tensor into `num_split` tensors along one dimension.
+//
+// Arguments:
+// axis: 0-D. The dimension along which to split. Must be in the range
+// `[-rank(value), rank(value))`.
+// value: The tensor to split.
+// num_split: The number of ways to split. Must evenly divide
+// `value.shape[split_dim]`.
+//
+// Returns They are identically shaped tensors, whose shape matches that of `value`
+// except along `axis`, where their sizes are
+// `values.shape[split_dim] / num_split`.
+func Split(scope *Scope, axis tf.Output, value tf.Output, num_split int64) (output []tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"num_split": num_split}
+ opspec := tf.OpSpec{
+ Type: "Split",
+ Input: []tf.Input{
+ axis, value,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ if scope.Err() != nil {
+ return
+ }
+ var idx int
+ var err error
+ if output, idx, err = makeOutputList(op, idx, "output"); err != nil {
+ scope.UpdateErr("Split", err)
+ return
+ }
+ return output
+}
+
// Creates a sequence of numbers.
//
// This operation creates a sequence of numbers that begins at `start` and
@@ -7457,6 +7555,36 @@ func RandomPoisson(scope *Scope, shape tf.Output, rate tf.Output, optional ...Ra
return op.Output(0)
}
+// Returns the element-wise sum of a list of tensors.
+//
+// `tf.accumulate_n_v2` performs the same operation as `tf.add_n`, but does not
+// wait for all of its inputs to be ready before beginning to sum. This can
+// save memory if inputs are ready at different times, since minimum temporary
+// storage is proportional to the output size rather than the inputs size.
+//
+// Unlike the original `accumulate_n`, `accumulate_n_v2` is differentiable.
+//
+// Returns a `Tensor` of same shape and type as the elements of `inputs`.
+//
+// Arguments:
+// inputs: A list of `Tensor` objects, each with same shape and type.
+// shape: Shape of elements of `inputs`.
+func AccumulateNV2(scope *Scope, inputs []tf.Output, shape tf.Shape) (sum tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"shape": shape}
+ opspec := tf.OpSpec{
+ Type: "AccumulateNV2",
+ Input: []tf.Input{
+ tf.OutputList(inputs),
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// DepthwiseConv2dNativeBackpropFilterAttr is an optional argument to DepthwiseConv2dNativeBackpropFilter.
type DepthwiseConv2dNativeBackpropFilterAttr func(optionalAttr)
@@ -7527,6 +7655,69 @@ func DepthwiseConv2dNativeBackpropFilter(scope *Scope, input tf.Output, filter_s
return op.Output(0)
}
+// Returns immutable tensor from memory region.
+//
+// The current implementation memmaps the tensor from a file.
+//
+// Arguments:
+// dtype: Type of the returned tensor.
+// shape: Shape of the returned tensor.
+// memory_region_name: Name of readonly memory region used by the tensor, see
+// NewReadOnlyMemoryRegionFromFile in tensorflow::Env.
+func ImmutableConst(scope *Scope, dtype tf.DataType, shape tf.Shape, memory_region_name string) (tensor tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"dtype": dtype, "shape": shape, "memory_region_name": memory_region_name}
+ opspec := tf.OpSpec{
+ Type: "ImmutableConst",
+
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// StringJoinAttr is an optional argument to StringJoin.
+type StringJoinAttr func(optionalAttr)
+
+// StringJoinSeparator sets the optional separator attribute to value.
+//
+// value: string, an optional join separator.
+// If not specified, defaults to ""
+func StringJoinSeparator(value string) StringJoinAttr {
+ return func(m optionalAttr) {
+ m["separator"] = value
+ }
+}
+
+// Joins the strings in the given list of string tensors into one tensor;
+//
+// with the given separator (default is an empty separator).
+//
+// Arguments:
+// inputs: A list of string tensors. The tensors must all have the same shape,
+// or be scalars. Scalars may be mixed in; these will be broadcast to the shape
+// of non-scalar inputs.
+func StringJoin(scope *Scope, inputs []tf.Output, optional ...StringJoinAttr) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "StringJoin",
+ Input: []tf.Input{
+ tf.OutputList(inputs),
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// LRNGradAttr is an optional argument to LRNGrad.
type LRNGradAttr func(optionalAttr)
@@ -8326,95 +8517,6 @@ func RestoreSlice(scope *Scope, file_pattern tf.Output, tensor_name tf.Output, s
return op.Output(0)
}
-// ImagAttr is an optional argument to Imag.
-type ImagAttr func(optionalAttr)
-
-// ImagTout sets the optional Tout attribute to value.
-// If not specified, defaults to DT_FLOAT
-func ImagTout(value tf.DataType) ImagAttr {
- return func(m optionalAttr) {
- m["Tout"] = value
- }
-}
-
-// Returns the imaginary part of a complex number.
-//
-// Given a tensor `input` of complex numbers, this operation returns a tensor of
-// type `float` that is the imaginary part of each element in `input`. All
-// elements in `input` must be complex numbers of the form \\(a + bj\\), where *a*
-// is the real part and *b* is the imaginary part returned by this operation.
-//
-// For example:
-//
-// ```
-// # tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j]
-// tf.imag(input) ==> [4.75, 5.75]
-// ```
-func Imag(scope *Scope, input tf.Output, optional ...ImagAttr) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "Imag",
- Input: []tf.Input{
- input,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// ComplexAttr is an optional argument to Complex.
-type ComplexAttr func(optionalAttr)
-
-// ComplexTout sets the optional Tout attribute to value.
-// If not specified, defaults to DT_COMPLEX64
-func ComplexTout(value tf.DataType) ComplexAttr {
- return func(m optionalAttr) {
- m["Tout"] = value
- }
-}
-
-// Converts two real numbers to a complex number.
-//
-// Given a tensor `real` representing the real part of a complex number, and a
-// tensor `imag` representing the imaginary part of a complex number, this
-// operation returns complex numbers elementwise of the form \\(a + bj\\), where
-// *a* represents the `real` part and *b* represents the `imag` part.
-//
-// The input tensors `real` and `imag` must have the same shape.
-//
-// For example:
-//
-// ```
-// # tensor 'real' is [2.25, 3.25]
-// # tensor `imag` is [4.75, 5.75]
-// tf.complex(real, imag) ==> [[2.25 + 4.75j], [3.25 + 5.75j]]
-// ```
-func Complex(scope *Scope, real tf.Output, imag tf.Output, optional ...ComplexAttr) (out tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "Complex",
- Input: []tf.Input{
- real, imag,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Divides sparse updates into the variable referenced by `resource`.
//
// This operation computes
@@ -8456,6 +8558,23 @@ func ResourceScatterDiv(scope *Scope, resource tf.Output, indices tf.Output, upd
return scope.AddOperation(opspec)
}
+// Mutually reduces multiple tensors of identical type and shape.
+func CollectiveReduce(scope *Scope, input tf.Output, group_size int64, group_key int64, instance_key int64, merge_op string, final_op string, subdiv_offsets []int64) (data tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"group_size": group_size, "group_key": group_key, "instance_key": instance_key, "merge_op": merge_op, "final_op": final_op, "subdiv_offsets": subdiv_offsets}
+ opspec := tf.OpSpec{
+ Type: "CollectiveReduce",
+ Input: []tf.Input{
+ input,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// StatelessRandomNormalAttr is an optional argument to StatelessRandomNormal.
type StatelessRandomNormalAttr func(optionalAttr)
@@ -11174,63 +11293,6 @@ func DeserializeSparse(scope *Scope, serialized_sparse tf.Output, dtype tf.DataT
return op.Output(0), op.Output(1), op.Output(2)
}
-// ResourceApplyRMSPropAttr is an optional argument to ResourceApplyRMSProp.
-type ResourceApplyRMSPropAttr func(optionalAttr)
-
-// ResourceApplyRMSPropUseLocking sets the optional use_locking attribute to value.
-//
-// value: If `True`, updating of the var, ms, and mom tensors is protected
-// by a lock; otherwise the behavior is undefined, but may exhibit less
-// contention.
-// If not specified, defaults to false
-func ResourceApplyRMSPropUseLocking(value bool) ResourceApplyRMSPropAttr {
- return func(m optionalAttr) {
- m["use_locking"] = value
- }
-}
-
-// Update '*var' according to the RMSProp algorithm.
-//
-// Note that in dense implementation of this algorithm, ms and mom will
-// update even if the grad is zero, but in this sparse implementation, ms
-// and mom will not update in iterations during which the grad is zero.
-//
-// mean_square = decay * mean_square + (1-decay) * gradient ** 2
-// Delta = learning_rate * gradient / sqrt(mean_square + epsilon)
-//
-// ms <- rho * ms_{t-1} + (1-rho) * grad * grad
-// mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon)
-// var <- var - mom
-//
-// Arguments:
-// var_: Should be from a Variable().
-// ms: Should be from a Variable().
-// mom: Should be from a Variable().
-// lr: Scaling factor. Must be a scalar.
-// rho: Decay rate. Must be a scalar.
-//
-// epsilon: Ridge term. Must be a scalar.
-// grad: The gradient.
-//
-// Returns the created operation.
-func ResourceApplyRMSProp(scope *Scope, var_ tf.Output, ms tf.Output, mom tf.Output, lr tf.Output, rho tf.Output, momentum tf.Output, epsilon tf.Output, grad tf.Output, optional ...ResourceApplyRMSPropAttr) (o *tf.Operation) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "ResourceApplyRMSProp",
- Input: []tf.Input{
- var_, ms, mom, lr, rho, momentum, epsilon, grad,
- },
- Attrs: attrs,
- }
- return scope.AddOperation(opspec)
-}
-
// ResourceScatterNdUpdateAttr is an optional argument to ResourceScatterNdUpdate.
type ResourceScatterNdUpdateAttr func(optionalAttr)
@@ -11759,23 +11821,6 @@ func TensorArrayGatherV3(scope *Scope, handle tf.Output, indices tf.Output, flow
return op.Output(0)
}
-// Mutually reduces multiple tensors of identical type and shape.
-func CollectiveReduce(scope *Scope, input tf.Output, group_size int64, group_key int64, instance_key int64, merge_op string, final_op string, subdiv_offsets []int64) (data tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"group_size": group_size, "group_key": group_key, "instance_key": instance_key, "merge_op": merge_op, "final_op": final_op, "subdiv_offsets": subdiv_offsets}
- opspec := tf.OpSpec{
- Type: "CollectiveReduce",
- Input: []tf.Input{
- input,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// This op consumes a lock created by `MutexLock`.
//
// This op exists to consume a tensor created by `MutexLock` (other than
@@ -11877,81 +11922,6 @@ func RestoreV2(scope *Scope, prefix tf.Output, tensor_names tf.Output, shape_and
return tensors
}
-// Creates a dataset that skips `count` elements from the `input_dataset`.
-//
-// Arguments:
-//
-// count: A scalar representing the number of elements from the `input_dataset`
-// that should be skipped. If count is -1, skips everything.
-//
-//
-func SkipDataset(scope *Scope, input_dataset tf.Output, count tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes}
- opspec := tf.OpSpec{
- Type: "SkipDataset",
- Input: []tf.Input{
- input_dataset, count,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Computes the maximum along segments of a tensor.
-//
-// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
-// segments.
-//
-// Computes a tensor such that
-// \\(output_i = \max_j(data_j)\\) where `max` is over `j` such
-// that `segment_ids[j] == i`.
-//
-// If the max is empty for a given segment ID `i`, `output[i] = 0`.
-//
-// <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
-// <img style="width:100%" src="https://www.tensorflow.org/images/SegmentMax.png" alt>
-// </div>
-//
-// Arguments:
-//
-// segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s
-// first dimension. Values should be sorted and can be repeated.
-//
-// Returns Has same shape as data, except for dimension 0 which
-// has size `k`, the number of segments.
-func SegmentMax(scope *Scope, data tf.Output, segment_ids tf.Output) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "SegmentMax",
- Input: []tf.Input{
- data, segment_ids,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Computes hyperbolic tangent of `x` element-wise.
-func Tanh(scope *Scope, x tf.Output) (y tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "Tanh",
- Input: []tf.Input{
- x,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Receives a tensor value broadcast from another device.
func CollectiveBcastRecv(scope *Scope, T tf.DataType, group_size int64, group_key int64, instance_key int64, shape tf.Shape) (data tf.Output) {
if scope.Err() != nil {
@@ -13665,6 +13635,170 @@ func TopK(scope *Scope, input tf.Output, k int64, optional ...TopKAttr) (values
return op.Output(0), op.Output(1)
}
+// ComplexAttr is an optional argument to Complex.
+type ComplexAttr func(optionalAttr)
+
+// ComplexTout sets the optional Tout attribute to value.
+// If not specified, defaults to DT_COMPLEX64
+func ComplexTout(value tf.DataType) ComplexAttr {
+ return func(m optionalAttr) {
+ m["Tout"] = value
+ }
+}
+
+// Converts two real numbers to a complex number.
+//
+// Given a tensor `real` representing the real part of a complex number, and a
+// tensor `imag` representing the imaginary part of a complex number, this
+// operation returns complex numbers elementwise of the form \\(a + bj\\), where
+// *a* represents the `real` part and *b* represents the `imag` part.
+//
+// The input tensors `real` and `imag` must have the same shape.
+//
+// For example:
+//
+// ```
+// # tensor 'real' is [2.25, 3.25]
+// # tensor `imag` is [4.75, 5.75]
+// tf.complex(real, imag) ==> [[2.25 + 4.75j], [3.25 + 5.75j]]
+// ```
+func Complex(scope *Scope, real tf.Output, imag tf.Output, optional ...ComplexAttr) (out tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "Complex",
+ Input: []tf.Input{
+ real, imag,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// ImagAttr is an optional argument to Imag.
+type ImagAttr func(optionalAttr)
+
+// ImagTout sets the optional Tout attribute to value.
+// If not specified, defaults to DT_FLOAT
+func ImagTout(value tf.DataType) ImagAttr {
+ return func(m optionalAttr) {
+ m["Tout"] = value
+ }
+}
+
+// Returns the imaginary part of a complex number.
+//
+// Given a tensor `input` of complex numbers, this operation returns a tensor of
+// type `float` that is the imaginary part of each element in `input`. All
+// elements in `input` must be complex numbers of the form \\(a + bj\\), where *a*
+// is the real part and *b* is the imaginary part returned by this operation.
+//
+// For example:
+//
+// ```
+// # tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j]
+// tf.imag(input) ==> [4.75, 5.75]
+// ```
+func Imag(scope *Scope, input tf.Output, optional ...ImagAttr) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "Imag",
+ Input: []tf.Input{
+ input,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Computes the maximum along segments of a tensor.
+//
+// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
+// segments.
+//
+// Computes a tensor such that
+// \\(output_i = \max_j(data_j)\\) where `max` is over `j` such
+// that `segment_ids[j] == i`.
+//
+// If the max is empty for a given segment ID `i`, `output[i] = 0`.
+//
+// <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
+// <img style="width:100%" src="https://www.tensorflow.org/images/SegmentMax.png" alt>
+// </div>
+//
+// Arguments:
+//
+// segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s
+// first dimension. Values should be sorted and can be repeated.
+//
+// Returns Has same shape as data, except for dimension 0 which
+// has size `k`, the number of segments.
+func SegmentMax(scope *Scope, data tf.Output, segment_ids tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "SegmentMax",
+ Input: []tf.Input{
+ data, segment_ids,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Computes hyperbolic tangent of `x` element-wise.
+func Tanh(scope *Scope, x tf.Output) (y tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "Tanh",
+ Input: []tf.Input{
+ x,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Creates a dataset that skips `count` elements from the `input_dataset`.
+//
+// Arguments:
+//
+// count: A scalar representing the number of elements from the `input_dataset`
+// that should be skipped. If count is -1, skips everything.
+//
+//
+func SkipDataset(scope *Scope, input_dataset tf.Output, count tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes}
+ opspec := tf.OpSpec{
+ Type: "SkipDataset",
+ Input: []tf.Input{
+ input_dataset, count,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Compute the Hurwitz zeta function \\(\zeta(x, q)\\).
//
// The Hurwitz zeta function is defined as:
@@ -13894,6 +14028,42 @@ func ResourceApplyCenteredRMSProp(scope *Scope, var_ tf.Output, mg tf.Output, ms
return scope.AddOperation(opspec)
}
+// Computes the gradient for the inverse of `x` wrt its input.
+//
+// Specifically, `grad = -dy * y*y`, where `y = 1/x`, and `dy`
+// is the corresponding input gradient.
+func ReciprocalGrad(scope *Scope, y tf.Output, dy tf.Output) (z tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "ReciprocalGrad",
+ Input: []tf.Input{
+ y, dy,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Returns the min of x and y (i.e. x < y ? x : y) element-wise.
+//
+// *NOTE*: `Minimum` supports broadcasting. More about broadcasting
+// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
+func Minimum(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "Minimum",
+ Input: []tf.Input{
+ x, y,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// RealAttr is an optional argument to Real.
type RealAttr func(optionalAttr)
@@ -16287,6 +16457,63 @@ func MutableDenseHashTableV2(scope *Scope, empty_key tf.Output, value_dtype tf.D
return op.Output(0)
}
+// ResourceApplyRMSPropAttr is an optional argument to ResourceApplyRMSProp.
+type ResourceApplyRMSPropAttr func(optionalAttr)
+
+// ResourceApplyRMSPropUseLocking sets the optional use_locking attribute to value.
+//
+// value: If `True`, updating of the var, ms, and mom tensors is protected
+// by a lock; otherwise the behavior is undefined, but may exhibit less
+// contention.
+// If not specified, defaults to false
+func ResourceApplyRMSPropUseLocking(value bool) ResourceApplyRMSPropAttr {
+ return func(m optionalAttr) {
+ m["use_locking"] = value
+ }
+}
+
+// Update '*var' according to the RMSProp algorithm.
+//
+// Note that in dense implementation of this algorithm, ms and mom will
+// update even if the grad is zero, but in this sparse implementation, ms
+// and mom will not update in iterations during which the grad is zero.
+//
+// mean_square = decay * mean_square + (1-decay) * gradient ** 2
+// Delta = learning_rate * gradient / sqrt(mean_square + epsilon)
+//
+// ms <- rho * ms_{t-1} + (1-rho) * grad * grad
+// mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon)
+// var <- var - mom
+//
+// Arguments:
+// var_: Should be from a Variable().
+// ms: Should be from a Variable().
+// mom: Should be from a Variable().
+// lr: Scaling factor. Must be a scalar.
+// rho: Decay rate. Must be a scalar.
+//
+// epsilon: Ridge term. Must be a scalar.
+// grad: The gradient.
+//
+// Returns the created operation.
+func ResourceApplyRMSProp(scope *Scope, var_ tf.Output, ms tf.Output, mom tf.Output, lr tf.Output, rho tf.Output, momentum tf.Output, epsilon tf.Output, grad tf.Output, optional ...ResourceApplyRMSPropAttr) (o *tf.Operation) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "ResourceApplyRMSProp",
+ Input: []tf.Input{
+ var_, ms, mom, lr, rho, momentum, epsilon, grad,
+ },
+ Attrs: attrs,
+ }
+ return scope.AddOperation(opspec)
+}
+
// Returns element-wise remainder of division. This emulates C semantics in that
//
// the result here is consistent with a truncating divide. E.g. `truncate(x / y) *
@@ -17560,69 +17787,6 @@ func DeserializeManySparse(scope *Scope, serialized_sparse tf.Output, dtype tf.D
return op.Output(0), op.Output(1), op.Output(2)
}
-// StringJoinAttr is an optional argument to StringJoin.
-type StringJoinAttr func(optionalAttr)
-
-// StringJoinSeparator sets the optional separator attribute to value.
-//
-// value: string, an optional join separator.
-// If not specified, defaults to ""
-func StringJoinSeparator(value string) StringJoinAttr {
- return func(m optionalAttr) {
- m["separator"] = value
- }
-}
-
-// Joins the strings in the given list of string tensors into one tensor;
-//
-// with the given separator (default is an empty separator).
-//
-// Arguments:
-// inputs: A list of string tensors. The tensors must all have the same shape,
-// or be scalars. Scalars may be mixed in; these will be broadcast to the shape
-// of non-scalar inputs.
-func StringJoin(scope *Scope, inputs []tf.Output, optional ...StringJoinAttr) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "StringJoin",
- Input: []tf.Input{
- tf.OutputList(inputs),
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Returns immutable tensor from memory region.
-//
-// The current implementation memmaps the tensor from a file.
-//
-// Arguments:
-// dtype: Type of the returned tensor.
-// shape: Shape of the returned tensor.
-// memory_region_name: Name of readonly memory region used by the tensor, see
-// NewReadOnlyMemoryRegionFromFile in tensorflow::Env.
-func ImmutableConst(scope *Scope, dtype tf.DataType, shape tf.Shape, memory_region_name string) (tensor tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"dtype": dtype, "shape": shape, "memory_region_name": memory_region_name}
- opspec := tf.OpSpec{
- Type: "ImmutableConst",
-
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Inverse real-valued fast Fourier transform.
//
// Computes the inverse 1-dimensional discrete Fourier transform of a real-valued
@@ -19195,88 +19359,58 @@ func SquaredDifference(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
return op.Output(0)
}
-// Forwards the input to the output.
-//
-// This operator represents the loop termination condition used by the
-// "pivot" switches of a loop.
-//
-// Arguments:
-// input: A boolean scalar, representing the branch predicate of the Switch op.
-//
-// Returns The same tensor as `input`.
-func LoopCond(scope *Scope, input tf.Output) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "LoopCond",
- Input: []tf.Input{
- input,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
+// RandomGammaAttr is an optional argument to RandomGamma.
+type RandomGammaAttr func(optionalAttr)
-// Computes the gradient for the inverse of `x` wrt its input.
+// RandomGammaSeed sets the optional seed attribute to value.
//
-// Specifically, `grad = -dy * y*y`, where `y = 1/x`, and `dy`
-// is the corresponding input gradient.
-func ReciprocalGrad(scope *Scope, y tf.Output, dy tf.Output) (z tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "ReciprocalGrad",
- Input: []tf.Input{
- y, dy,
- },
+// value: If either `seed` or `seed2` are set to be non-zero, the random number
+// generator is seeded by the given seed. Otherwise, it is seeded by a
+// random seed.
+// If not specified, defaults to 0
+func RandomGammaSeed(value int64) RandomGammaAttr {
+ return func(m optionalAttr) {
+ m["seed"] = value
}
- op := scope.AddOperation(opspec)
- return op.Output(0)
}
-// Returns the min of x and y (i.e. x < y ? x : y) element-wise.
+// RandomGammaSeed2 sets the optional seed2 attribute to value.
//
-// *NOTE*: `Minimum` supports broadcasting. More about broadcasting
-// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
-func Minimum(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "Minimum",
- Input: []tf.Input{
- x, y,
- },
+// value: A second seed to avoid seed collision.
+// If not specified, defaults to 0
+func RandomGammaSeed2(value int64) RandomGammaAttr {
+ return func(m optionalAttr) {
+ m["seed2"] = value
}
- op := scope.AddOperation(opspec)
- return op.Output(0)
}
-// Returns the element-wise sum of a list of tensors.
-//
-// `tf.accumulate_n_v2` performs the same operation as `tf.add_n`, but does not
-// wait for all of its inputs to be ready before beginning to sum. This can
-// save memory if inputs are ready at different times, since minimum temporary
-// storage is proportional to the output size rather than the inputs size.
-//
-// Unlike the original `accumulate_n`, `accumulate_n_v2` is differentiable.
+// Outputs random values from the Gamma distribution(s) described by alpha.
//
-// Returns a `Tensor` of same shape and type as the elements of `inputs`.
+// This op uses the algorithm by Marsaglia et al. to acquire samples via
+// transformation-rejection from pairs of uniform and normal random variables.
+// See http://dl.acm.org/citation.cfm?id=358414
//
// Arguments:
-// inputs: A list of `Tensor` objects, each with same shape and type.
-// shape: Shape of elements of `inputs`.
-func AccumulateNV2(scope *Scope, inputs []tf.Output, shape tf.Shape) (sum tf.Output) {
+// shape: 1-D integer tensor. Shape of independent samples to draw from each
+// distribution described by the shape parameters given in alpha.
+// alpha: A tensor in which each scalar is a "shape" parameter describing the
+// associated gamma distribution.
+//
+// Returns A tensor with shape `shape + shape(alpha)`. Each slice
+// `[:, ..., :, i0, i1, ...iN]` contains the samples drawn for
+// `alpha[i0, i1, ...iN]`. The dtype of the output matches the dtype of alpha.
+func RandomGamma(scope *Scope, shape tf.Output, alpha tf.Output, optional ...RandomGammaAttr) (output tf.Output) {
if scope.Err() != nil {
return
}
- attrs := map[string]interface{}{"shape": shape}
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
opspec := tf.OpSpec{
- Type: "AccumulateNV2",
+ Type: "RandomGamma",
Input: []tf.Input{
- tf.OutputList(inputs),
+ shape, alpha,
},
Attrs: attrs,
}
@@ -19332,60 +19466,24 @@ func QuantizeDownAndShrinkRange(scope *Scope, input tf.Output, input_min tf.Outp
return op.Output(0), op.Output(1), op.Output(2)
}
-// RandomGammaAttr is an optional argument to RandomGamma.
-type RandomGammaAttr func(optionalAttr)
-
-// RandomGammaSeed sets the optional seed attribute to value.
-//
-// value: If either `seed` or `seed2` are set to be non-zero, the random number
-// generator is seeded by the given seed. Otherwise, it is seeded by a
-// random seed.
-// If not specified, defaults to 0
-func RandomGammaSeed(value int64) RandomGammaAttr {
- return func(m optionalAttr) {
- m["seed"] = value
- }
-}
-
-// RandomGammaSeed2 sets the optional seed2 attribute to value.
-//
-// value: A second seed to avoid seed collision.
-// If not specified, defaults to 0
-func RandomGammaSeed2(value int64) RandomGammaAttr {
- return func(m optionalAttr) {
- m["seed2"] = value
- }
-}
-
-// Outputs random values from the Gamma distribution(s) described by alpha.
+// Forwards the input to the output.
//
-// This op uses the algorithm by Marsaglia et al. to acquire samples via
-// transformation-rejection from pairs of uniform and normal random variables.
-// See http://dl.acm.org/citation.cfm?id=358414
+// This operator represents the loop termination condition used by the
+// "pivot" switches of a loop.
//
// Arguments:
-// shape: 1-D integer tensor. Shape of independent samples to draw from each
-// distribution described by the shape parameters given in alpha.
-// alpha: A tensor in which each scalar is a "shape" parameter describing the
-// associated gamma distribution.
+// input: A boolean scalar, representing the branch predicate of the Switch op.
//
-// Returns A tensor with shape `shape + shape(alpha)`. Each slice
-// `[:, ..., :, i0, i1, ...iN]` contains the samples drawn for
-// `alpha[i0, i1, ...iN]`. The dtype of the output matches the dtype of alpha.
-func RandomGamma(scope *Scope, shape tf.Output, alpha tf.Output, optional ...RandomGammaAttr) (output tf.Output) {
+// Returns The same tensor as `input`.
+func LoopCond(scope *Scope, input tf.Output) (output tf.Output) {
if scope.Err() != nil {
return
}
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
opspec := tf.OpSpec{
- Type: "RandomGamma",
+ Type: "LoopCond",
Input: []tf.Input{
- shape, alpha,
+ input,
},
- Attrs: attrs,
}
op := scope.AddOperation(opspec)
return op.Output(0)
@@ -25031,6 +25129,41 @@ func LatencyStatsDataset(scope *Scope, input_dataset tf.Output, tag tf.Output, o
return op.Output(0)
}
+// Runs multiple additive regression ensemble predictors on input instances and
+//
+// computes the update to cached logits. It is designed to be used during training.
+// It traverses the trees starting from cached tree id and cached node id and
+// calculates the updates to be pushed to the cache.
+//
+// Arguments:
+//
+// cached_tree_ids: Rank 1 Tensor containing cached tree ids which is the starting
+// tree of prediction.
+// cached_node_ids: Rank 1 Tensor containing cached node id which is the starting
+// node of prediction.
+// bucketized_features: A list of rank 1 Tensors containing bucket id for each
+// feature.
+// logits_dimension: scalar, dimension of the logits, to be used for partial logits
+// shape.
+//
+// Returns Rank 2 Tensor containing logits update (with respect to cached
+// values stored) for each example. Rank 1 Tensor containing new tree ids for each example. Rank 1 Tensor containing new node ids in the new tree_ids.
+func BoostedTreesTrainingPredict(scope *Scope, tree_ensemble_handle tf.Output, cached_tree_ids tf.Output, cached_node_ids tf.Output, bucketized_features []tf.Output, logits_dimension int64) (partial_logits tf.Output, tree_ids tf.Output, node_ids tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"logits_dimension": logits_dimension}
+ opspec := tf.OpSpec{
+ Type: "BoostedTreesTrainingPredict",
+ Input: []tf.Input{
+ tree_ensemble_handle, cached_tree_ids, cached_node_ids, tf.OutputList(bucketized_features),
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0), op.Output(1), op.Output(2)
+}
+
// MapSizeAttr is an optional argument to MapSize.
type MapSizeAttr func(optionalAttr)
@@ -29790,41 +29923,6 @@ func BoostedTreesDeserializeEnsemble(scope *Scope, tree_ensemble_handle tf.Outpu
return scope.AddOperation(opspec)
}
-// Runs multiple additive regression ensemble predictors on input instances and
-//
-// computes the update to cached logits. It is designed to be used during training.
-// It traverses the trees starting from cached tree id and cached node id and
-// calculates the updates to be pushed to the cache.
-//
-// Arguments:
-//
-// cached_tree_ids: Rank 1 Tensor containing cached tree ids which is the starting
-// tree of prediction.
-// cached_node_ids: Rank 1 Tensor containing cached node id which is the starting
-// node of prediction.
-// bucketized_features: A list of rank 1 Tensors containing bucket id for each
-// feature.
-// logits_dimension: scalar, dimension of the logits, to be used for partial logits
-// shape.
-//
-// Returns Rank 2 Tensor containing logits update (with respect to cached
-// values stored) for each example.Rank 1 Tensor containing new tree ids for each example.Rank 1 Tensor containing new node ids in the new tree_ids.
-func BoostedTreesTrainingPredict(scope *Scope, tree_ensemble_handle tf.Output, cached_tree_ids tf.Output, cached_node_ids tf.Output, bucketized_features []tf.Output, logits_dimension int64) (partial_logits tf.Output, tree_ids tf.Output, node_ids tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"logits_dimension": logits_dimension}
- opspec := tf.OpSpec{
- Type: "BoostedTreesTrainingPredict",
- Input: []tf.Input{
- tree_ensemble_handle, cached_tree_ids, cached_node_ids, tf.OutputList(bucketized_features),
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0), op.Output(1), op.Output(2)
-}
-
// Elementwise computes the bitwise AND of `x` and `y`.
//
// The result will have those bits set, that are set in both `x` and `y`. The
@@ -30612,101 +30710,3 @@ func ConcatOffset(scope *Scope, concat_dim tf.Output, shape []tf.Output) (offset
}
return offset
}
-
-// Splits a tensor into `num_split` tensors along one dimension.
-//
-// Arguments:
-// axis: 0-D. The dimension along which to split. Must be in the range
-// `[-rank(value), rank(value))`.
-// value: The tensor to split.
-// num_split: The number of ways to split. Must evenly divide
-// `value.shape[split_dim]`.
-//
-// Returns They are identically shaped tensors, whose shape matches that of `value`
-// except along `axis`, where their sizes are
-// `values.shape[split_dim] / num_split`.
-func Split(scope *Scope, axis tf.Output, value tf.Output, num_split int64) (output []tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"num_split": num_split}
- opspec := tf.OpSpec{
- Type: "Split",
- Input: []tf.Input{
- axis, value,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- if scope.Err() != nil {
- return
- }
- var idx int
- var err error
- if output, idx, err = makeOutputList(op, idx, "output"); err != nil {
- scope.UpdateErr("Split", err)
- return
- }
- return output
-}
-
-// Splits a tensor into `num_split` tensors along one dimension.
-//
-// Arguments:
-// value: The tensor to split.
-// size_splits: list containing the sizes of each output tensor along the split
-// dimension. Must sum to the dimension of value along split_dim.
-// Can contain one -1 indicating that dimension is to be inferred.
-// axis: 0-D. The dimension along which to split. Must be in the range
-// `[-rank(value), rank(value))`.
-//
-//
-// Returns Tensors whose shape matches that of `value`
-// except along `axis`, where their sizes are
-// `size_splits[i]`.
-func SplitV(scope *Scope, value tf.Output, size_splits tf.Output, axis tf.Output, num_split int64) (output []tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"num_split": num_split}
- opspec := tf.OpSpec{
- Type: "SplitV",
- Input: []tf.Input{
- value, size_splits, axis,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- if scope.Err() != nil {
- return
- }
- var idx int
- var err error
- if output, idx, err = makeOutputList(op, idx, "output"); err != nil {
- scope.UpdateErr("SplitV", err)
- return
- }
- return output
-}
-
-// Gives a guarantee to the TF runtime that the input tensor is a constant.
-//
-// The runtime is then free to make optimizations based on this.
-//
-// Only accepts value typed tensors as inputs and rejects resource variable handles
-// as input.
-//
-// Returns the input tensor without modification.
-func GuaranteeConst(scope *Scope, input tf.Output) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "GuaranteeConst",
- Input: []tf.Input{
- input,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
diff --git a/tensorflow/java/maven/libtensorflow/pom.xml b/tensorflow/java/maven/libtensorflow/pom.xml
index 08cc860f57..38e87b1639 100644
--- a/tensorflow/java/maven/libtensorflow/pom.xml
+++ b/tensorflow/java/maven/libtensorflow/pom.xml
@@ -6,7 +6,7 @@
<parent>
<groupId>org.tensorflow</groupId>
<artifactId>parentpom</artifactId>
- <version>1.8.0</version>
+ <version>1.9.0-rc0</version>
<relativePath>../</relativePath>
</parent>
<artifactId>libtensorflow</artifactId>
diff --git a/tensorflow/java/maven/libtensorflow_jni/pom.xml b/tensorflow/java/maven/libtensorflow_jni/pom.xml
index fcc7eacc33..36c984e280 100644
--- a/tensorflow/java/maven/libtensorflow_jni/pom.xml
+++ b/tensorflow/java/maven/libtensorflow_jni/pom.xml
@@ -6,7 +6,7 @@
<parent>
<groupId>org.tensorflow</groupId>
<artifactId>parentpom</artifactId>
- <version>1.8.0</version>
+ <version>1.9.0-rc0</version>
<relativePath>../</relativePath>
</parent>
<artifactId>libtensorflow_jni</artifactId>
diff --git a/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml b/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml
index 3d22d86a49..4c846de05a 100644
--- a/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml
+++ b/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml
@@ -6,7 +6,7 @@
<parent>
<groupId>org.tensorflow</groupId>
<artifactId>parentpom</artifactId>
- <version>1.8.0</version>
+ <version>1.9.0-rc0</version>
<relativePath>../</relativePath>
</parent>
<artifactId>libtensorflow_jni_gpu</artifactId>
diff --git a/tensorflow/java/maven/pom.xml b/tensorflow/java/maven/pom.xml
index 0a09a5ea7c..f2a0a97eae 100644
--- a/tensorflow/java/maven/pom.xml
+++ b/tensorflow/java/maven/pom.xml
@@ -6,7 +6,7 @@
<modelVersion>4.0.0</modelVersion>
<groupId>org.tensorflow</groupId>
<artifactId>parentpom</artifactId>
- <version>1.8.0</version>
+ <version>1.9.0-rc0</version>
<packaging>pom</packaging>
<url>https://www.tensorflow.org</url>
diff --git a/tensorflow/java/maven/proto/pom.xml b/tensorflow/java/maven/proto/pom.xml
index 77ec6a0ddb..eb0a952c7d 100644
--- a/tensorflow/java/maven/proto/pom.xml
+++ b/tensorflow/java/maven/proto/pom.xml
@@ -6,7 +6,7 @@
<parent>
<groupId>org.tensorflow</groupId>
<artifactId>parentpom</artifactId>
- <version>1.8.0</version>
+ <version>1.9.0-rc0</version>
<relativePath>../</relativePath>
</parent>
<artifactId>proto</artifactId>
@@ -16,7 +16,7 @@
<dependency>
<groupId>com.google.protobuf</groupId>
<artifactId>protobuf-java</artifactId>
- <version>3.3.1</version>
+ <version>3.5.1</version>
</dependency>
</dependencies>
diff --git a/tensorflow/java/maven/run_inside_container.sh b/tensorflow/java/maven/run_inside_container.sh
index 6136ccfdfb..bf19c09b1d 100644
--- a/tensorflow/java/maven/run_inside_container.sh
+++ b/tensorflow/java/maven/run_inside_container.sh
@@ -31,7 +31,7 @@ if [[ "${TF_VERSION}" == *"-SNAPSHOT" ]]; then
# Bintray does not allow snapshots.
DEPLOY_BINTRAY="false"
fi
-PROTOC_RELEASE_URL="https://github.com/google/protobuf/releases/download/v3.3.0/protoc-3.3.0-linux-x86_64.zip"
+PROTOC_RELEASE_URL="https://github.com/google/protobuf/releases/download/v3.5.1/protoc-3.5.1-linux-x86_64.zip"
if [[ "${DEPLOY_BINTRAY}" != "true" && "${DEPLOY_OSSRH}" != "true" ]]; then
echo "Must deploy to at least one of Bintray or OSSRH" >&2
exit 2
diff --git a/tensorflow/java/maven/tensorflow/pom.xml b/tensorflow/java/maven/tensorflow/pom.xml
index 0df1f28149..48668a47f2 100644
--- a/tensorflow/java/maven/tensorflow/pom.xml
+++ b/tensorflow/java/maven/tensorflow/pom.xml
@@ -6,7 +6,7 @@
<parent>
<groupId>org.tensorflow</groupId>
<artifactId>parentpom</artifactId>
- <version>1.8.0</version>
+ <version>1.9.0-rc0</version>
<relativePath>../</relativePath>
</parent>
<artifactId>tensorflow</artifactId>
diff --git a/tensorflow/java/src/gen/cc/op_generator.cc b/tensorflow/java/src/gen/cc/op_generator.cc
index debd95fc62..9b171f66ec 100644
--- a/tensorflow/java/src/gen/cc/op_generator.cc
+++ b/tensorflow/java/src/gen/cc/op_generator.cc
@@ -376,9 +376,6 @@ void GenerateOp(const OpSpec& op, const EndpointSpec& endpoint,
}
}
// op annotations
- op_class.add_annotation(
- Annotation::Create("Generated", "javax.annotation")
- .attributes("value = \"TensorFlow Java Op Generator\""));
if (endpoint.deprecated()) {
op_class.add_annotation(Annotation::Create("Deprecated"));
string explanation;
@@ -415,8 +412,12 @@ void GenerateOp(const OpSpec& op, const EndpointSpec& endpoint,
SourceFileWriter writer(op_file.get());
std::list<Type> dependencies;
CollectOpDependencies(op, mode, &dependencies);
- writer.Write(kLicense).EndLine().BeginType(op_class, PUBLIC | FINAL,
- &dependencies, &op_javadoc);
+ writer.Write(kLicense)
+ .EndLine()
+ .Write("// This class has been generated, DO NOT EDIT!")
+ .EndLine()
+ .EndLine()
+ .BeginType(op_class, PUBLIC | FINAL, &dependencies, &op_javadoc);
if (!op.optional_attributes().empty()) {
RenderOptionsClass(op, op_class, &writer);
}
diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py
index 5507d011bb..648e35cdf2 100644
--- a/tensorflow/python/client/session.py
+++ b/tensorflow/python/client/session.py
@@ -619,21 +619,12 @@ class BaseSession(SessionInterface):
self._config = None
self._add_shapes = False
- # pylint: disable=protected-access
- # We cache _USE_C_API's value because some test cases will create a session
- # with _USE_C_API = False but set it back to True before calling close().
- self._created_with_new_api = ops._USE_C_API
- # pylint: enable=protected-access
-
self._session = None
opts = tf_session.TF_NewSessionOptions(target=self._target, config=config)
try:
- if self._created_with_new_api:
- # pylint: disable=protected-access
- self._session = tf_session.TF_NewSession(self._graph._c_graph, opts)
- # pylint: enable=protected-access
- else:
- self._session = tf_session.TF_NewDeprecatedSession(opts)
+ # pylint: disable=protected-access
+ self._session = tf_session.TF_NewSession(self._graph._c_graph, opts)
+ # pylint: enable=protected-access
finally:
tf_session.TF_DeleteSessionOptions(opts)
@@ -660,11 +651,7 @@ class BaseSession(SessionInterface):
Returns:
A list of devices in the session.
"""
- if self._created_with_new_api:
- raw_device_list = tf_session.TF_SessionListDevices(self._session)
- else:
- raw_device_list = tf_session.TF_DeprecatedSessionListDevices(
- self._session)
+ raw_device_list = tf_session.TF_SessionListDevices(self._session)
device_list = []
size = tf_session.TF_DeviceListCount(raw_device_list)
for i in range(size):
@@ -684,16 +671,9 @@ class BaseSession(SessionInterface):
tf.errors.OpError: Or one of its subclasses if an error occurs while
closing the TensorFlow session.
"""
- if self._created_with_new_api:
- if self._session and not self._closed:
- self._closed = True
- tf_session.TF_CloseSession(self._session)
-
- else:
- with self._extend_lock:
- if self._opened and not self._closed:
- self._closed = True
- tf_session.TF_CloseDeprecatedSession(self._session)
+ if self._session and not self._closed:
+ self._closed = True
+ tf_session.TF_CloseSession(self._session)
def __del__(self):
# cleanly ignore all exceptions
@@ -703,10 +683,7 @@ class BaseSession(SessionInterface):
pass
if self._session is not None:
try:
- if self._created_with_new_api:
- tf_session.TF_DeleteSession(self._session)
- else:
- tf_session.TF_DeleteDeprecatedSession(self._session)
+ tf_session.TF_DeleteSession(self._session)
except AttributeError:
# At shutdown, `c_api_util` or `tf_session` may have been garbage
# collected, causing the above method calls to fail. In this case,
@@ -1005,12 +982,9 @@ class BaseSession(SessionInterface):
try:
subfeed_t = self.graph.as_graph_element(
subfeed, allow_tensor=True, allow_operation=False)
- if self._created_with_new_api:
- # pylint: disable=protected-access
- feed_list.append(subfeed_t._as_tf_output())
- # pylint: enable=protected-access
- else:
- feed_list.append(compat.as_bytes(subfeed_t.name))
+ # pylint: disable=protected-access
+ feed_list.append(subfeed_t._as_tf_output())
+ # pylint: enable=protected-access
except Exception as e:
e.message = ('Cannot interpret feed_list key as Tensor: ' + e.message)
e.args = (e.message,)
@@ -1023,22 +997,13 @@ class BaseSession(SessionInterface):
# Set up a graph with feeds and fetches for partial run.
def _setup_fn(session, feed_list, fetch_list, target_list):
self._extend_graph()
- if self._created_with_new_api:
- return tf_session.TF_SessionPRunSetup_wrapper(
- session, feed_list, fetch_list, target_list)
- else:
- with errors.raise_exception_on_not_ok_status() as status:
- return tf_session.TF_PRunSetup(session, feed_list, fetch_list,
- target_list, status)
+ return tf_session.TF_SessionPRunSetup_wrapper(
+ session, feed_list, fetch_list, target_list)
- if self._created_with_new_api:
- # pylint: disable=protected-access
- final_fetches = [t._as_tf_output() for t in fetch_handler.fetches()]
- final_targets = [op._c_op for op in fetch_handler.targets()]
- # pylint: enable=protected-access
- else:
- final_fetches = _name_list(fetch_handler.fetches())
- final_targets = _name_list(fetch_handler.targets())
+ # pylint: disable=protected-access
+ final_fetches = [t._as_tf_output() for t in fetch_handler.fetches()]
+ final_targets = [op._c_op for op in fetch_handler.targets()]
+ # pylint: enable=protected-access
return self._do_call(_setup_fn, self._session, feed_list, final_fetches,
final_targets)
@@ -1196,14 +1161,10 @@ class BaseSession(SessionInterface):
# Create a fetch handler to take care of the structure of fetches.
fetch_handler = _FetchHandler(self._graph, fetches, {})
- if self._created_with_new_api:
- # pylint: disable=protected-access
- fetch_list = [t._as_tf_output() for t in fetch_handler.fetches()]
- target_list = [op._c_op for op in fetch_handler.targets()]
- # pylint: enable=protected-access
- else:
- fetch_list = _name_list(fetch_handler.fetches())
- target_list = _name_list(fetch_handler.targets())
+ # pylint: disable=protected-access
+ fetch_list = [t._as_tf_output() for t in fetch_handler.fetches()]
+ target_list = [op._c_op for op in fetch_handler.targets()]
+ # pylint: enable=protected-access
def _callable_template_with_options_and_metadata(fetch_list,
target_list,
@@ -1289,16 +1250,11 @@ class BaseSession(SessionInterface):
Raises:
tf.errors.OpError: Or one of its subclasses on error.
"""
- if self._created_with_new_api:
- # pylint: disable=protected-access
- feeds = dict((t._as_tf_output(), v) for t, v in feed_dict.items())
- fetches = [t._as_tf_output() for t in fetch_list]
- targets = [op._c_op for op in target_list]
- # pylint: enable=protected-access
- else:
- feeds = dict((compat.as_bytes(t.name), v) for t, v in feed_dict.items())
- fetches = _name_list(fetch_list)
- targets = _name_list(target_list)
+ # pylint: disable=protected-access
+ feeds = dict((t._as_tf_output(), v) for t, v in feed_dict.items())
+ fetches = [t._as_tf_output() for t in fetch_list]
+ targets = [op._c_op for op in target_list]
+ # pylint: enable=protected-access
def _run_fn(feed_dict, fetch_list, target_list, options, run_metadata):
# Ensure any changes to the graph are reflected in the runtime.
@@ -1335,22 +1291,8 @@ class BaseSession(SessionInterface):
raise type(e)(node_def, op, message)
def _extend_graph(self):
- if self._created_with_new_api:
- with self._graph._lock: # pylint: disable=protected-access
- tf_session.ExtendSession(self._session)
- else:
- # Ensure any changes to the graph are reflected in the runtime.
- with self._extend_lock:
- if self._graph.version > self._current_version:
- # pylint: disable=protected-access
- graph_def, self._current_version = self._graph._as_graph_def(
- from_version=self._current_version, add_shapes=self._add_shapes)
- # pylint: enable=protected-access
-
- with errors.raise_exception_on_not_ok_status() as status:
- tf_session.TF_ExtendGraph(self._session,
- graph_def.SerializeToString(), status)
- self._opened = True
+ with self._graph._lock: # pylint: disable=protected-access
+ tf_session.ExtendSession(self._session)
# The threshold to run garbage collection to delete dead tensors.
_DEAD_HANDLES_THRESHOLD = 10
@@ -1403,24 +1345,13 @@ class BaseSession(SessionInterface):
def _call_tf_sessionrun(self, options, feed_dict, fetch_list, target_list,
run_metadata):
- if self._created_with_new_api:
- return tf_session.TF_SessionRun_wrapper(
- self._session, options, feed_dict, fetch_list, target_list,
- run_metadata)
- else:
- with errors.raise_exception_on_not_ok_status() as status:
- return tf_session.TF_Run(
- self._session, options, feed_dict, fetch_list, target_list,
- status, run_metadata)
+ return tf_session.TF_SessionRun_wrapper(
+ self._session, options, feed_dict, fetch_list, target_list,
+ run_metadata)
def _call_tf_sessionprun(self, handle, feed_dict, fetch_list):
- if self._created_with_new_api:
- return tf_session.TF_SessionPRun_wrapper(
- self._session, handle, feed_dict, fetch_list)
- else:
- with errors.raise_exception_on_not_ok_status() as status:
- return tf_session.TF_PRun(
- self._session, handle, feed_dict, fetch_list, status)
+ return tf_session.TF_SessionPRun_wrapper(
+ self._session, handle, feed_dict, fetch_list)
# pylint: disable=protected-access
class _Callable(object):
@@ -1433,12 +1364,8 @@ class BaseSession(SessionInterface):
compat.as_bytes(callable_options.SerializeToString()))
try:
with errors.raise_exception_on_not_ok_status() as status:
- if session._created_with_new_api:
- self._handle = tf_session.TF_SessionMakeCallable(
- session._session, options_ptr, status)
- else:
- self._handle = tf_session.TF_DeprecatedSessionMakeCallable(
- session._session, options_ptr, status)
+ self._handle = tf_session.TF_SessionMakeCallable(
+ session._session, options_ptr, status)
finally:
tf_session.TF_DeleteBuffer(options_ptr)
@@ -1446,12 +1373,8 @@ class BaseSession(SessionInterface):
# TODO(b/74355905): Support argument and return value nested structures,
# and tensor-like objects such as SparseTensors.
with errors.raise_exception_on_not_ok_status() as status:
- if self._session._created_with_new_api:
- return tf_session.TF_SessionRunCallable(
- self._session._session, self._handle, args, status, None)
- else:
- return tf_session.TF_DeprecatedSessionRunCallable(
- self._session._session, self._handle, args, status, None)
+ return tf_session.TF_SessionRunCallable(
+ self._session._session, self._handle, args, status, None)
def __del__(self):
# NOTE(mrry): It is possible that `self._session.__del__()` could be
@@ -1459,12 +1382,8 @@ class BaseSession(SessionInterface):
# will be `None`.
if self._handle is not None and self._session._session is not None:
with errors.raise_exception_on_not_ok_status() as status:
- if self._session._created_with_new_api:
- tf_session.TF_SessionReleaseCallable(
- self._session._session, self._handle, status)
- else:
- tf_session.TF_DeprecatedSessionReleaseCallable(
- self._session._session, self._handle, status)
+ tf_session.TF_SessionReleaseCallable(
+ self._session._session, self._handle, status)
# pylint: enable=protected-access
# TODO(b/74355905): Reimplement `Session.make_callable()` using this method
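For orientation: the _Callable helper above wraps TF_SessionMakeCallable / TF_SessionRunCallable, and the TODO notes that the public `Session.make_callable()` API is slated to be reimplemented on top of it. A minimal usage sketch of that public API (illustrative only, not part of this patch):

import tensorflow as tf

x = tf.placeholder(tf.float32, shape=[])
y = x * 2.0

with tf.Session() as sess:
  # make_callable returns a function for repeatedly running the same fetches,
  # avoiding the per-call fetch/feed validation done by Session.run().
  double = sess.make_callable(y, feed_list=[x])
  print(double(3.0))  # 6.0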
diff --git a/tensorflow/python/data/kernel_tests/BUILD b/tensorflow/python/data/kernel_tests/BUILD
index c8fabc4363..e86c2f6993 100644
--- a/tensorflow/python/data/kernel_tests/BUILD
+++ b/tensorflow/python/data/kernel_tests/BUILD
@@ -15,6 +15,7 @@ tf_py_test(
size = "small",
srcs = ["batch_dataset_op_test.py"],
additional_deps = [
+ "@absl_py//absl/testing:parameterized",
"//third_party/py/numpy",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
diff --git a/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py b/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py
index bd80b9dbf5..50bb0837b7 100644
--- a/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py
@@ -18,8 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import math
-
+from absl.testing import parameterized
import numpy as np
from tensorflow.python.data.ops import dataset_ops
@@ -35,73 +34,83 @@ from tensorflow.python.platform import test
from tensorflow.python.util import compat
-class BatchDatasetTest(test.TestCase):
+class BatchDatasetTest(test.TestCase, parameterized.TestCase):
+
+ @parameterized.named_parameters(
+ ('even', 28, 14, False),
+ ('uneven_with_remainder', 28, 15, False),
+ ('uneven_without_remainder', 28, 15, True),
+ ('empty', 0, 14, False),
+ )
+ def testBatchDataset(self, count, batch_size, drop_remainder):
+ """Tests the batch dataset logic for various input configurations.
+
+ Args:
+ count: the number of input elements
+ batch_size: the batch size
+      drop_remainder: whether the final smaller batch should be dropped if the
+        batch size does not evenly divide the number of input elements
+ """
- def testBatchDataset(self):
- """Test an dataset that maps a TF function across its input elements."""
# The pipeline is TensorSliceDataset -> MapDataset(square_3) ->
# RepeatDataset(count) -> BatchDataset(batch_size).
components = (np.arange(7),
np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
np.array(37.0) * np.arange(7))
- count = array_ops.placeholder(dtypes.int64, shape=[])
- batch_size = array_ops.placeholder(dtypes.int64, shape=[])
+ count_t = array_ops.placeholder(dtypes.int64, shape=[])
+ batch_size_t = array_ops.placeholder(dtypes.int64, shape=[])
+ drop_remainder_t = array_ops.placeholder(dtypes.bool, shape=[])
def _map_fn(x, y, z):
return math_ops.square(x), math_ops.square(y), math_ops.square(z)
iterator = (
dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn)
- .repeat(count).batch(batch_size).make_initializable_iterator())
+ .repeat(count).batch(batch_size,
+ drop_remainder).make_initializable_iterator())
init_op = iterator.initializer
get_next = iterator.get_next()
- self.assertEqual([[None] + list(c.shape[1:]) for c in components],
+ if drop_remainder:
+ dim0 = batch_size
+ else:
+ dim0 = None
+ self.assertEqual([[dim0] + list(c.shape[1:]) for c in components],
[t.shape.as_list() for t in get_next])
with self.test_session() as sess:
- # Batch of a finite input, where the batch_size divides the
- # total number of elements.
- sess.run(init_op, feed_dict={count: 28, batch_size: 14})
- num_batches = (28 * 7) // 14
- for i in range(num_batches):
+ sess.run(
+ init_op,
+ feed_dict={
+ count_t: count,
+ batch_size_t: batch_size,
+ drop_remainder_t: drop_remainder
+ })
+ num_full_batches = (count * 7) // batch_size
+ for i in range(num_full_batches):
result = sess.run(get_next)
for component, result_component in zip(components, result):
- for j in range(14):
- self.assertAllEqual(component[(i * 14 + j) % 7]**2,
+ for j in range(batch_size):
+ self.assertAllEqual(component[(i * batch_size + j) % 7]**2,
result_component[j])
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Batch of a finite input, where the batch_size does not
- # divide the total number of elements.
- sess.run(init_op, feed_dict={count: 14, batch_size: 8})
-
- # We expect (num_batches - 1) full-sized batches.
- num_batches = int(math.ceil((14 * 7) / 8))
- for i in range(num_batches - 1):
+ if not drop_remainder and (count * 7) % batch_size > 0:
result = sess.run(get_next)
for component, result_component in zip(components, result):
- for j in range(8):
- self.assertAllEqual(component[(i * 8 + j) % 7]**2,
- result_component[j])
- result = sess.run(get_next)
- for component, result_component in zip(components, result):
- for j in range((14 * 7) % 8):
- self.assertAllEqual(component[((num_batches - 1) * 8 + j) % 7]**2,
- result_component[j])
+ for j in range((count * 7) % batch_size):
+ self.assertAllEqual(
+ component[(num_full_batches * batch_size + j) % 7]**2,
+ result_component[j])
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
- # Batch of an empty input should fail straight away.
- sess.run(init_op, feed_dict={count: 0, batch_size: 8})
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
+ def testBatchDatasetInvalidBatchSize(self):
+ iterator = (dataset_ops.Dataset.range(10).batch(0).make_one_shot_iterator())
+ get_next = iterator.get_next()
- # Empty batch should be an initialization time error.
+ with self.test_session() as sess:
with self.assertRaises(errors.InvalidArgumentError):
- sess.run(init_op, feed_dict={count: 14, batch_size: 0})
+ sess.run(get_next)
def assertSparseValuesEqual(self, a, b):
self.assertAllEqual(a.indices, b.indices)
@@ -210,66 +219,108 @@ class BatchDatasetTest(test.TestCase):
r'First element had shape \[3\] and element 2 had shape \[4\].'):
sess.run(next_element)
- def testPaddedBatchDataset(self):
- seq_lens = array_ops.placeholder(dtypes.int32, shape=[None])
- padded_shape = array_ops.placeholder(dtypes.int64, shape=[1])
+
+def _random_seq_lens(count):
+ return np.random.randint(20, size=(count,)).astype(np.int32)
+
+
+class PaddedBatchDatasetTest(test.TestCase, parameterized.TestCase):
+
+ @parameterized.named_parameters(
+ ('default_padding', _random_seq_lens(32), 4, [-1], False),
+ ('constant_padding', _random_seq_lens(32), 4, [25], False),
+ ('uneven_with_remainder', _random_seq_lens(34), 4, [-1], False),
+ ('uneven_without_remainder', _random_seq_lens(34), 4, [-1], True),
+ )
+ def testPaddedBatchDataset(self, seq_lens, batch_size, padded_shapes,
+ drop_remainder):
+ """Tests the padded batch dataset logic for various input configurations.
+
+ Args:
+ seq_lens: the input sequence lengths
+ batch_size: the batch size
+ padded_shapes: the padded shapes to use
+      drop_remainder: whether the final smaller batch should be dropped if the
+        batch size does not evenly divide the number of input elements
+ """
+
+ seq_lens_t = array_ops.placeholder(dtypes.int32, shape=[None])
+ batch_size_t = array_ops.placeholder(dtypes.int64, shape=[])
+ padded_shapes_t = array_ops.placeholder(dtypes.int64, shape=[1])
+ drop_remainder_t = array_ops.placeholder(dtypes.bool, shape=[])
iterator = (
- dataset_ops.Dataset.from_tensor_slices(seq_lens)
+ dataset_ops.Dataset.from_tensor_slices(seq_lens_t)
.map(lambda x: array_ops.fill([x], x)).padded_batch(
- 4, padded_shapes=padded_shape).make_initializable_iterator())
+ batch_size=batch_size_t,
+ drop_remainder=drop_remainder_t,
+ padded_shapes=padded_shapes_t).make_initializable_iterator())
init_op = iterator.initializer
get_next = iterator.get_next()
with self.test_session() as sess:
- # Test with random sequence lengths, and max padding.
- random_seq_lens = np.random.randint(20, size=(32,)).astype(np.int32)
sess.run(
- init_op, feed_dict={
- padded_shape: [-1],
- seq_lens: random_seq_lens
+ init_op,
+ feed_dict={
+ seq_lens_t: seq_lens,
+ batch_size_t: batch_size,
+ padded_shapes_t: padded_shapes,
+ drop_remainder_t: drop_remainder,
})
- for i in range(8):
+
+ num_full_batches = len(seq_lens) // batch_size
+
+ for i in range(num_full_batches):
result = sess.run(get_next)
- padded_len = np.max(result)
- self.assertEqual((4, padded_len), result.shape)
- for j in range(4):
- seq_len = random_seq_lens[(i * 4) + j]
+ padded_len = padded_shapes[0]
+ if padded_len is None or padded_len == -1:
+ padded_len = np.max(result)
+ self.assertEqual((batch_size, padded_len), result.shape)
+ for j in range(batch_size):
+ seq_len = seq_lens[(i * batch_size) + j]
self.assertAllEqual(result[j, :seq_len], [seq_len] * seq_len)
- self.assertAllEqual(result[j, seq_len:], [0] * (padded_len - seq_len))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
+ self.assertAllEqual(result[j, seq_len:],
+ [0] * (padded_len - seq_len))
- # Test with random sequence lengths, and constant padding.
- sess.run(
- init_op, feed_dict={
- padded_shape: [25],
- seq_lens: random_seq_lens
- })
- for i in range(8):
+ if not drop_remainder and len(seq_lens) % batch_size > 0:
result = sess.run(get_next)
- self.assertEqual((4, 25), result.shape)
- for j in range(4):
- seq_len = random_seq_lens[(i * 4) + j]
+ padded_len = np.max(result)
+ self.assertEqual((len(seq_lens) % batch_size, padded_len),
+ result.shape)
+ for j in range(len(seq_lens) % batch_size):
+ seq_len = seq_lens[num_full_batches * batch_size + j]
self.assertAllEqual(result[j, :seq_len], [seq_len] * seq_len)
- self.assertAllEqual(result[j, seq_len:], [0] * (25 - seq_len))
+ self.assertAllEqual(result[j, seq_len:],
+ [0] * (padded_len - seq_len))
+
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
- # Test correct handling of empty tensors.
- sess.run(init_op, feed_dict={padded_shape: [-1], seq_lens: [0, 0, 0, 0]})
+ def testPaddedBatchShortPadding(self):
+ iterator = (
+ dataset_ops.Dataset.from_tensor_slices([6, 5, 5, 5, 5])
+ .map(lambda x: array_ops.fill([x], x)).padded_batch(
+ batch_size=4, padded_shapes=[5]).make_one_shot_iterator())
+ get_next = iterator.get_next()
+
+ with self.test_session() as sess:
+ with self.assertRaises(errors.DataLossError):
+ sess.run(get_next)
+
+ def testPaddedBatchEmptyTensors(self):
+ iterator = (
+ dataset_ops.Dataset.from_tensor_slices([0, 0, 0, 0])
+ .map(lambda x: array_ops.fill([x], x)).padded_batch(
+ batch_size=4, padded_shapes=[-1]).make_one_shot_iterator())
+ get_next = iterator.get_next()
+
+ with self.test_session() as sess:
result = sess.run(get_next)
self.assertAllEqual([[], [], [], []], result)
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
- # Test error handling with constant sequence lengths, and
- # too-short padding.
- sess.run(init_op, feed_dict={padded_shape: [5], seq_lens: [6, 5, 5, 5]})
- with self.assertRaises(errors.DataLossError):
- result = sess.run(get_next)
-
def testPaddedBatchDatasetNonDefaultPadding(self):
seq_lens = array_ops.placeholder(dtypes.int32, shape=[None])
padded_shape = array_ops.placeholder(dtypes.int64, shape=[1])
@@ -371,6 +422,44 @@ class BatchDatasetTest(test.TestCase):
with self.assertRaises(TypeError):
_ = dataset_ops.Dataset.range(10).map(_map_fn).padded_batch(10)
+ def testPaddedBatchShapeError(self):
+ with self.assertRaisesRegexp(
+ ValueError, r'The padded shape \(1,\) is not compatible with the '
+ r'corresponding input component shape \(\).'):
+ _ = dataset_ops.Dataset.range(10).padded_batch(5, padded_shapes=[1])
+
+ with self.assertRaisesRegexp(
+ ValueError, r'The padded shape \(1,\) is not compatible with the '
+ r'corresponding input component shape \(3,\).'):
+ _ = dataset_ops.Dataset.from_tensors([1, 2, 3]).padded_batch(
+ 5, padded_shapes=[1])
+
+ with self.assertRaisesRegexp(
+ ValueError, r'Padded shape .* must be a 1-D tensor '
+ r'of tf.int64 values, but its shape was \(2, 2\).'):
+ _ = dataset_ops.Dataset.from_tensors([1, 2, 3]).padded_batch(
+ 5, padded_shapes=[[1, 1], [1, 1]])
+
+ with self.assertRaisesRegexp(
+ TypeError, r'Padded shape .* must be a 1-D tensor '
+ r'of tf.int64 values, but its element type was float32.'):
+ _ = dataset_ops.Dataset.from_tensors([1, 2, 3]).padded_batch(
+ 5, padded_shapes=constant_op.constant([1., 2., 3.]))
+
+ with self.assertRaisesRegexp(
+ ValueError, r'The padded shape \(1,\) is not compatible with the '
+ r'corresponding input component shape \(\).'):
+ shape_as_tensor = constant_op.constant([1], dtype=dtypes.int64)
+ _ = dataset_ops.Dataset.range(10).padded_batch(
+ 5, padded_shapes=shape_as_tensor)
+
+ with self.assertRaisesRegexp(
+ ValueError, r'The padded shape \(\?, \?\) is not compatible with the '
+ r'corresponding input component shape \(\).'):
+ shape_as_tensor = array_ops.placeholder(dtypes.int64, shape=[2])
+ _ = dataset_ops.Dataset.range(10).padded_batch(
+ 5, padded_shapes=shape_as_tensor)
+
if __name__ == '__main__':
test.main()
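For reference, the batching arithmetic exercised by the parameterized cases above: the pipeline yields 7 elements per repeat, so `count * 7` elements in total (the helper below is illustrative only, not part of this change):

def expected_batches(count, batch_size, drop_remainder):
  total = count * 7
  num_full_batches = total // batch_size
  remainder = total % batch_size
  # drop_remainder=True discards the trailing partial batch.
  num_batches = num_full_batches + (1 if remainder and not drop_remainder else 0)
  return num_full_batches, remainder, num_batches

# 'uneven_with_remainder' (28, 15, False): 196 elements -> 13 full batches of 15
# plus a final batch of 1.
assert expected_batches(28, 15, False) == (13, 1, 14)
# 'uneven_without_remainder' (28, 15, True): the final batch of 1 is dropped.
assert expected_batches(28, 15, True) == (13, 1, 13)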
diff --git a/tensorflow/python/data/kernel_tests/map_dataset_op_test.py b/tensorflow/python/data/kernel_tests/map_dataset_op_test.py
index 1ad0b9de5e..768d4ac82c 100644
--- a/tensorflow/python/data/kernel_tests/map_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/map_dataset_op_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
from collections import namedtuple
import threading
import time
+import warnings
import numpy as np
@@ -638,6 +639,26 @@ class MapDatasetTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
+ def testWarnOnLookupTable(self):
+ def collecting_function(x):
+ _ = lookup_ops.HashTable(
+ lookup_ops.KeyValueTensorInitializer([], []), 0.0, name="t1")
+ return x
+
+ warnings.simplefilter("always")
+ with warnings.catch_warnings(record=True) as w:
+ _ = dataset_ops.Dataset.range(10).map(collecting_function)
+ # NOTE(mrry): Python 3 prints other warnings in addition to the one we are
+ # testing, so we search for the expected warning.
+ self.assertGreaterEqual(len(w), 1)
+ found_warning = False
+ for warning in w:
+ if ("Creating lookup tables inside a function passed to Dataset.map() is "
+ "not supported." in str(warning)):
+ found_warning = True
+ break
+ self.assertTrue(found_warning)
+
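The warning asserted above recommends creating the table outside the mapped function and capturing it. A minimal sketch of that recommended pattern, using the same lookup API the test uses (illustrative only):

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.ops import lookup_ops

# Create the table once, outside the function passed to Dataset.map(), and
# capture it inside the function; this avoids the warning above.
table = lookup_ops.HashTable(
    lookup_ops.KeyValueTensorInitializer(["a", "b"], [0, 1]), -1)

def lookup_fn(key):
  return table.lookup(key)

dataset = dataset_ops.Dataset.from_tensor_slices(["a", "c"]).map(lookup_fn)
# The table initializer still needs to be run in the session before iterating.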
class MapDatasetBenchmark(test.Benchmark):
diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py
index 5f17444797..672ce014f6 100644
--- a/tensorflow/python/data/ops/dataset_ops.py
+++ b/tensorflow/python/data/ops/dataset_ops.py
@@ -19,6 +19,7 @@ from __future__ import print_function
import abc
import threading
+import warnings
import numpy as np
import six
@@ -32,6 +33,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
+from tensorflow.python.framework import smart_cond
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
@@ -790,7 +792,7 @@ class Dataset(object):
return self._enumerate().filter(filter_fn).map(lambda _, elem: elem)
- def batch(self, batch_size):
+ def batch(self, batch_size, drop_remainder=False):
"""Combines consecutive elements of this dataset into batches.
NOTE: If the number of elements (`N`) in this dataset is not an exact
@@ -802,13 +804,21 @@ class Dataset(object):
Args:
batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
consecutive elements of this dataset to combine in a single batch.
+ drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing
+        whether the last batch should be dropped in the case it has fewer than
+ `batch_size` elements; the default behavior is not to drop the smaller
+ batch.
Returns:
Dataset: A `Dataset`.
"""
- return BatchDataset(self, batch_size)
+ return BatchDataset(self, batch_size, drop_remainder)
- def padded_batch(self, batch_size, padded_shapes, padding_values=None):
+ def padded_batch(self,
+ batch_size,
+ padded_shapes,
+ padding_values=None,
+ drop_remainder=False):
"""Combines consecutive elements of this dataset into padded batches.
This transformation combines multiple consecutive elements of the input
@@ -851,11 +861,16 @@ class Dataset(object):
`tf.Tensor`, representing the padding values to use for the
respective components. Defaults are `0` for numeric types and
the empty string for string types.
+ drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing
+        whether the last batch should be dropped in the case it has fewer than
+ `batch_size` elements; the default behavior is not to drop the smaller
+ batch.
Returns:
Dataset: A `Dataset`.
"""
- return PaddedBatchDataset(self, batch_size, padded_shapes, padding_values)
+ return PaddedBatchDataset(self, batch_size, padded_shapes, padding_values,
+ drop_remainder)
def map(self, map_func, num_parallel_calls=None):
"""Maps `map_func` across this dataset.
@@ -1654,21 +1669,34 @@ class SkipDataset(Dataset):
class BatchDataset(Dataset):
"""A `Dataset` that batches contiguous elements from its input."""
- def __init__(self, input_dataset, batch_size):
+ def __init__(self, input_dataset, batch_size, drop_remainder):
"""See `Dataset.batch()` for details."""
super(BatchDataset, self).__init__()
self._input_dataset = input_dataset
self._batch_size = ops.convert_to_tensor(
batch_size, dtype=dtypes.int64, name="batch_size")
+ self._drop_remainder = ops.convert_to_tensor(
+ drop_remainder, dtype=dtypes.bool, name="drop_remainder")
def _as_variant_tensor(self):
- return gen_dataset_ops.batch_dataset(
- self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
- batch_size=self._batch_size,
- output_shapes=nest.flatten(
- sparse.as_dense_shapes(self.output_shapes, self.output_classes)),
- output_types=nest.flatten(
- sparse.as_dense_types(self.output_types, self.output_classes)))
+ # TODO(jsimsa): Switch to using v2 only any time after 6/30/2018.
+ if smart_cond.smart_constant_value(self._drop_remainder) is False:
+ return gen_dataset_ops.batch_dataset(
+ self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
+ batch_size=self._batch_size,
+ output_shapes=nest.flatten(
+ sparse.as_dense_shapes(self.output_shapes, self.output_classes)),
+ output_types=nest.flatten(
+ sparse.as_dense_types(self.output_types, self.output_classes)))
+ else:
+ return gen_dataset_ops.batch_dataset_v2(
+ self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
+ batch_size=self._batch_size,
+ drop_remainder=self._drop_remainder,
+ output_shapes=nest.flatten(
+ sparse.as_dense_shapes(self.output_shapes, self.output_classes)),
+ output_types=nest.flatten(
+ sparse.as_dense_types(self.output_types, self.output_classes)))
@property
def output_classes(self):
@@ -1678,7 +1706,9 @@ class BatchDataset(Dataset):
def output_shapes(self):
input_shapes = self._input_dataset.output_shapes
return nest.pack_sequence_as(input_shapes, [
- tensor_shape.vector(None).concatenate(s)
+ tensor_shape.vector(
+ tensor_util.constant_value(self._batch_size) if smart_cond.
+ smart_constant_value(self._drop_remainder) else None).concatenate(s)
for s in nest.flatten(self._input_dataset.output_shapes)
])
@@ -1687,20 +1717,77 @@ class BatchDataset(Dataset):
return self._input_dataset.output_types
-def _partial_shape_to_tensor(shape_like):
+def _is_padded_shape_compatible_with(padded_shape, input_component_shape):
+ """Returns `True` if `input_component_shape` can be padded to `padded_shape`.
+
+ Args:
+ padded_shape: A `tf.TensorShape`.
+ input_component_shape: A `tf.TensorShape`.
+
+ Returns:
+ `True` if `input_component_shape` can be padded to `padded_shape`, otherwise
+ `False`.
+ """
+
+ if padded_shape.dims is None or input_component_shape.dims is None:
+ return True
+ if len(padded_shape.dims) != len(input_component_shape.dims):
+ return False
+ for padded_dim, input_dim in zip(
+ padded_shape.dims, input_component_shape.dims):
+ if (padded_dim.value is not None and input_dim.value is not None
+ and padded_dim.value < input_dim.value):
+ return False
+ return True
+
+
+def _padded_shape_to_tensor(padded_shape, input_component_shape):
+ """Converts `padded_shape` to a `tf.Tensor` representing that shape.
+
+ Args:
+ padded_shape: A shape-like object, which may be a `tf.TensorShape`, a Python
+ sequence, or a 1-D `tf.Tensor` of `tf.int64` elements.
+ input_component_shape: A `tf.TensorShape`, with which `padded_shape` must
+ be compatible.
+
+ Returns:
+ A 1-D `tf.Tensor` of `tf.int64` elements, representing `padded_shape`.
+
+ Raises:
+ ValueError: If `padded_shape` is not a shape or not compatible with
+ `input_component_shape`.
+ TypeError: If `padded_shape` is not convertible to a `tf.int64` tensor.
+ """
try:
- # First attempt to convert the input to a shape, and return the
- # "canonical" tensor representation, which uses `-1` in place of
- # `None`.
- shape_like = tensor_shape.as_shape(shape_like)
- return ops.convert_to_tensor(
- [dim if dim is not None else -1 for dim in shape_like.as_list()],
- dtype=dtypes.int64)
+ # Try to convert the `padded_shape` to a `tf.TensorShape`
+ padded_shape_as_shape = tensor_shape.as_shape(padded_shape)
+ # We will return the "canonical" tensor representation, which uses
+ # `-1` in place of `None`.
+ ret = ops.convert_to_tensor(
+ [dim if dim is not None else -1
+ for dim in padded_shape_as_shape.as_list()], dtype=dtypes.int64)
except (TypeError, ValueError):
# The argument was not trivially convertible to a
# `tf.TensorShape`, so fall back on the conversion to tensor
# machinery.
- return ops.convert_to_tensor(shape_like, dtype=dtypes.int64)
+ ret = ops.convert_to_tensor(padded_shape, preferred_dtype=dtypes.int64)
+ if ret.shape.dims is not None and len(ret.shape.dims) != 1:
+ raise ValueError(
+ "Padded shape %s must be a 1-D tensor of tf.int64 values, but its "
+ "shape was %s." % (padded_shape, ret.shape))
+ if ret.dtype != dtypes.int64:
+ raise TypeError(
+ "Padded shape %s must be a 1-D tensor of tf.int64 values, but its "
+ "element type was %s." % (padded_shape, ret.dtype.name))
+ padded_shape_as_shape = tensor_util.constant_value_as_shape(ret)
+
+ if not _is_padded_shape_compatible_with(padded_shape_as_shape,
+ input_component_shape):
+ raise ValueError("The padded shape %s is not compatible with the "
+ "corresponding input component shape %s."
+ % (padded_shape_as_shape, input_component_shape))
+
+ return ret
def _padding_value_to_tensor(value, output_type):
@@ -1742,7 +1829,8 @@ def _default_padding(input_dataset):
class PaddedBatchDataset(Dataset):
"""A `Dataset` that batches and pads contiguous elements from its input."""
- def __init__(self, input_dataset, batch_size, padded_shapes, padding_values):
+ def __init__(self, input_dataset, batch_size, padded_shapes, padding_values,
+ drop_remainder):
"""See `Dataset.batch()` for details."""
super(PaddedBatchDataset, self).__init__()
if sparse.any_sparse(input_dataset.output_classes):
@@ -1755,23 +1843,51 @@ class PaddedBatchDataset(Dataset):
padding_values = (
padding_values
if padding_values is not None else _default_padding(input_dataset))
- self._padded_shapes = nest.map_structure_up_to(
- input_dataset.output_shapes, _partial_shape_to_tensor, padded_shapes)
+
+ flat_padded_shapes = nest.flatten_up_to(input_dataset.output_shapes,
+ padded_shapes)
+
+ flat_padded_shapes_as_tensors = []
+
+ for input_component_shape, padded_shape in zip(
+ nest.flatten(input_dataset.output_shapes), flat_padded_shapes):
+ flat_padded_shapes_as_tensors.append(
+ _padded_shape_to_tensor(padded_shape, input_component_shape))
+
+ self._padded_shapes = nest.pack_sequence_as(input_dataset.output_shapes,
+ flat_padded_shapes_as_tensors)
+
self._padding_values = nest.map_structure_up_to(
input_dataset.output_shapes, _padding_value_to_tensor, padding_values,
input_dataset.output_types)
+ self._drop_remainder = ops.convert_to_tensor(
+ drop_remainder, dtype=dtypes.bool, name="drop_remainder")
def _as_variant_tensor(self):
- return gen_dataset_ops.padded_batch_dataset(
- self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
- batch_size=self._batch_size,
- padded_shapes=[
- ops.convert_to_tensor(s, dtype=dtypes.int64)
- for s in nest.flatten(self._padded_shapes)
- ],
- padding_values=nest.flatten(self._padding_values),
- output_shapes=nest.flatten(
- sparse.as_dense_shapes(self.output_shapes, self.output_classes)))
+ # TODO(jsimsa): Switch to using v2 only any time after 6/30/2018.
+ if smart_cond.smart_constant_value(self._drop_remainder) is False:
+ return gen_dataset_ops.padded_batch_dataset(
+ self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
+ batch_size=self._batch_size,
+ padded_shapes=[
+ ops.convert_to_tensor(s, dtype=dtypes.int64)
+ for s in nest.flatten(self._padded_shapes)
+ ],
+ padding_values=nest.flatten(self._padding_values),
+ output_shapes=nest.flatten(
+ sparse.as_dense_shapes(self.output_shapes, self.output_classes)))
+ else:
+ return gen_dataset_ops.padded_batch_dataset_v2(
+ self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
+ batch_size=self._batch_size,
+ padded_shapes=[
+ ops.convert_to_tensor(s, dtype=dtypes.int64)
+ for s in nest.flatten(self._padded_shapes)
+ ],
+ padding_values=nest.flatten(self._padding_values),
+ drop_remainder=self._drop_remainder,
+ output_shapes=nest.flatten(
+ sparse.as_dense_shapes(self.output_shapes, self.output_classes)))
@property
def output_classes(self):
@@ -1781,8 +1897,10 @@ class PaddedBatchDataset(Dataset):
def output_shapes(self):
def _padded_shape_to_batch_shape(s):
- return tensor_shape.vector(None).concatenate(
- tensor_util.constant_value_as_shape(s))
+ return tensor_shape.vector(
+ tensor_util.constant_value(self._batch_size) if smart_cond.
+ smart_constant_value(self._drop_remainder) else None).concatenate(
+ tensor_util.constant_value_as_shape(s))
return nest.map_structure(_padded_shape_to_batch_shape, self._padded_shapes)
@@ -1796,6 +1914,24 @@ def _should_unpack_args(args):
return type(args) is tuple # pylint: disable=unidiomatic-typecheck
+def _warn_if_collections(transformation_name):
+  """Prints a warning if the current graph uses common graph collections.
+
+ NOTE(mrry): Currently a warning is only generated for lookup tables. Any
+ variables created will be automatically hoisted out to the outermost scope
+ using `init_scope()`. Some collections (such as for control-flow contexts)
+ are benign and should not generate a warning.
+
+ Args:
+ transformation_name: A human-readable name for the transformation.
+ """
+ if ops.get_default_graph().get_collection(ops.GraphKeys.TABLE_INITIALIZERS):
+ warnings.warn("Creating lookup tables inside a function passed to %s is not"
+ " supported. Create each table outside the function, and "
+ "capture it inside the function to use it."
+ % transformation_name)
+
+
class MapDataset(Dataset):
"""A `Dataset` that maps a function over elements in its input."""
@@ -1855,6 +1991,8 @@ class MapDataset(Dataset):
self._output_types = nest.pack_sequence_as(
ret, [t.dtype for t in nest.flatten(ret)])
+ _warn_if_collections("Dataset.map()")
+
# Serialize any sparse tensors.
ret = nest.pack_sequence_as(
ret, [t for t in nest.flatten(sparse.serialize_sparse_tensors(ret))])
@@ -1943,6 +2081,8 @@ class FlatMapDataset(Dataset):
if not isinstance(dataset, Dataset):
raise TypeError("`map_func` must return a `Dataset` object.")
+ _warn_if_collections(self._transformation_name())
+
self._output_classes = dataset.output_classes
self._output_types = dataset.output_types
self._output_shapes = dataset.output_shapes
@@ -1974,6 +2114,9 @@ class FlatMapDataset(Dataset):
def output_types(self):
return self._output_types
+ def _transformation_name(self):
+ return "Dataset.flat_map()"
+
class InterleaveDataset(FlatMapDataset):
"""A `Dataset` that maps a function over its input and interleaves the result.
@@ -1999,6 +2142,9 @@ class InterleaveDataset(FlatMapDataset):
output_shapes=nest.flatten(
sparse.as_dense_shapes(self.output_shapes, self.output_classes)))
+ def _transformation_name(self):
+ return "Dataset.interleave()"
+
class FilterDataset(Dataset):
"""A `Dataset` that filters its input according to a predicate function."""
@@ -2033,6 +2179,8 @@ class FilterDataset(Dataset):
ret.shape.is_compatible_with(tensor_shape.scalar())):
raise ValueError("`predicate` must return a scalar boolean tensor.")
+ _warn_if_collections("Dataset.filter()")
+
return ret
self._predicate = tf_predicate
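Taken together, the dataset_ops.py changes above add a `drop_remainder` argument to `batch()` and `padded_batch()`, fall back to the original ops only when `drop_remainder` is statically `False`, and make the leading output dimension static when the remainder is dropped and the batch size is constant. A minimal usage sketch (illustrative only, not part of this patch):

from tensorflow.python.data.ops import dataset_ops

ds = dataset_ops.Dataset.range(10)

# Default behavior keeps the final short batch: batches of shape (3,), (3,),
# (3,), (1,); the static leading dimension stays unknown.
print(ds.batch(3).output_shapes)                       # (?,)

# With a constant drop_remainder=True, BatchDatasetV2 is used, the final batch
# of 1 is discarded, and the leading dimension becomes the constant batch size.
print(ds.batch(3, drop_remainder=True).output_shapes)  # (3,)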
diff --git a/tensorflow/python/data/ops/readers.py b/tensorflow/python/data/ops/readers.py
index a73a8b5cdc..6a72ed380f 100644
--- a/tensorflow/python/data/ops/readers.py
+++ b/tensorflow/python/data/ops/readers.py
@@ -156,6 +156,9 @@ class ParallelInterleaveDataset(dataset_ops.InterleaveDataset):
sparse.as_dense_shapes(self.output_shapes, self.output_classes)))
# pylint: enable=protected-access
+ def _transformation_name(self):
+ return "tf.contrib.data.parallel_interleave()"
+
@tf_export("data.TFRecordDataset")
class TFRecordDataset(dataset_ops.Dataset):
diff --git a/tensorflow/python/data/util/BUILD b/tensorflow/python/data/util/BUILD
index 0fc32d51b9..5fcc62b60b 100644
--- a/tensorflow/python/data/util/BUILD
+++ b/tensorflow/python/data/util/BUILD
@@ -70,6 +70,7 @@ py_library(
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
+ "//tensorflow/python:tensor_shape",
],
)
diff --git a/tensorflow/python/data/util/convert.py b/tensorflow/python/data/util/convert.py
index eeb1d700f3..99b3300900 100644
--- a/tensorflow/python/data/util/convert.py
+++ b/tensorflow/python/data/util/convert.py
@@ -20,6 +20,7 @@ from __future__ import print_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
def optional_param_to_tensor(argument_name,
@@ -32,3 +33,39 @@ def optional_param_to_tensor(argument_name,
else:
return constant_op.constant(
argument_default, dtype=argument_dtype, name=argument_name)
+
+
+def partial_shape_to_tensor(shape_like):
+ """Returns a @{tf.Tensor} that represents the given shape.
+
+ Args:
+ shape_like: A value that can be converted to a @{tf.TensorShape} or a
+ @{tf.Tensor}.
+
+ Returns:
+ A 1-D `tf.Tensor` of `tf.int64` elements representing the given shape, where
+ `-1` is substituted for any unknown dimensions.
+ """
+ try:
+ # First attempt to convert the input to a shape, and return the
+ # "canonical" tensor representation, which uses `-1` in place of
+ # `None`.
+ shape_like = tensor_shape.as_shape(shape_like)
+ return ops.convert_to_tensor(
+ [dim if dim is not None else -1 for dim in shape_like.as_list()],
+ dtype=dtypes.int64)
+ except (TypeError, ValueError):
+ # The argument was not trivially convertible to a
+ # `tf.TensorShape`, so fall back on the conversion to tensor
+ # machinery.
+ ret = ops.convert_to_tensor(shape_like, preferred_dtype=dtypes.int64)
+ if ret.shape.dims is not None and len(ret.shape.dims) != 1:
+ raise ValueError("The given shape %s must be a 1-D tensor of tf.int64 "
+ "values, but the shape was %s."
+ % (shape_like, ret.shape))
+ if ret.dtype != dtypes.int64:
+ raise TypeError("The given shape %s must be a 1-D tensor of tf.int64 "
+ "values, but the element type was %s."
+ % (shape_like, ret.dtype.name))
+
+ return ret
diff --git a/tensorflow/python/data/util/convert_test.py b/tensorflow/python/data/util/convert_test.py
index 2cb6488070..6a67093e48 100644
--- a/tensorflow/python/data/util/convert_test.py
+++ b/tensorflow/python/data/util/convert_test.py
@@ -19,7 +19,9 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.data.util import convert
+from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import tensor_shape
from tensorflow.python.platform import test
from tensorflow.python.util import compat
@@ -48,6 +50,77 @@ class ConvertTest(test.TestCase):
with self.test_session() as sess:
self.assertEqual(compat.as_bytes("value"), sess.run(resp))
+ def testPartialShapeToTensorKnownDimension(self):
+ with self.test_session() as sess:
+ self.assertAllEqual([1], sess.run(convert.partial_shape_to_tensor(
+ tensor_shape.TensorShape([1]))))
+ self.assertAllEqual([1], sess.run(convert.partial_shape_to_tensor((1,))))
+ self.assertAllEqual([1], sess.run(convert.partial_shape_to_tensor([1])))
+ self.assertAllEqual([1], sess.run(convert.partial_shape_to_tensor(
+ constant_op.constant([1], dtype=dtypes.int64))))
+
+ def testPartialShapeToTensorUnknownDimension(self):
+ with self.test_session() as sess:
+ self.assertAllEqual([-1], sess.run(convert.partial_shape_to_tensor(
+ tensor_shape.TensorShape([None]))))
+ self.assertAllEqual([-1], sess.run(convert.partial_shape_to_tensor(
+ (None,))))
+ self.assertAllEqual([-1], sess.run(convert.partial_shape_to_tensor(
+ [None])))
+ self.assertAllEqual([-1], sess.run(convert.partial_shape_to_tensor(
+ [-1])))
+ self.assertAllEqual([-1], sess.run(convert.partial_shape_to_tensor(
+ constant_op.constant([-1], dtype=dtypes.int64))))
+
+ with self.assertRaisesRegexp(
+ ValueError, r"The given shape .* must be a 1-D tensor of tf.int64 "
+ r"values, but the shape was \(2, 2\)."):
+ convert.partial_shape_to_tensor(constant_op.constant(
+ [[1, 1], [1, 1]], dtype=dtypes.int64))
+
+ with self.assertRaisesRegexp(
+ TypeError, r"The given shape .* must be a 1-D tensor of tf.int64 "
+ r"values, but the element type was float32."):
+ convert.partial_shape_to_tensor(constant_op.constant([1., 1.]))
+
+ def testPartialShapeToTensorMultipleDimensions(self):
+ with self.test_session() as sess:
+ self.assertAllEqual([3, 6], sess.run(convert.partial_shape_to_tensor(
+ tensor_shape.TensorShape([3, 6]))))
+ self.assertAllEqual([3, 6], sess.run(convert.partial_shape_to_tensor(
+ (3, 6))))
+ self.assertAllEqual([3, 6], sess.run(convert.partial_shape_to_tensor(
+ [3, 6])))
+ self.assertAllEqual([3, 6], sess.run(convert.partial_shape_to_tensor(
+ constant_op.constant([3, 6], dtype=dtypes.int64))))
+
+ self.assertAllEqual([3, -1], sess.run(convert.partial_shape_to_tensor(
+ tensor_shape.TensorShape([3, None]))))
+ self.assertAllEqual([3, -1], sess.run(convert.partial_shape_to_tensor(
+ (3, None))))
+ self.assertAllEqual([3, -1], sess.run(convert.partial_shape_to_tensor(
+ [3, None])))
+ self.assertAllEqual([3, -1], sess.run(convert.partial_shape_to_tensor(
+ constant_op.constant([3, -1], dtype=dtypes.int64))))
+
+ self.assertAllEqual([-1, -1], sess.run(convert.partial_shape_to_tensor(
+ tensor_shape.TensorShape([None, None]))))
+ self.assertAllEqual([-1, -1], sess.run(convert.partial_shape_to_tensor(
+ (None, None))))
+ self.assertAllEqual([-1, -1], sess.run(convert.partial_shape_to_tensor(
+ [None, None])))
+ self.assertAllEqual([-1, -1], sess.run(convert.partial_shape_to_tensor(
+ constant_op.constant([-1, -1], dtype=dtypes.int64))))
+
+ def testPartialShapeToTensorScalar(self):
+ with self.test_session() as sess:
+ self.assertAllEqual([], sess.run(convert.partial_shape_to_tensor(
+ tensor_shape.TensorShape([]))))
+ self.assertAllEqual([], sess.run(convert.partial_shape_to_tensor(())))
+ self.assertAllEqual([], sess.run(convert.partial_shape_to_tensor([])))
+ self.assertAllEqual([], sess.run(convert.partial_shape_to_tensor(
+ constant_op.constant([], dtype=dtypes.int64))))
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index 559063d6ae..03393bcd46 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -409,7 +409,15 @@ class GraphModeFunction(object):
backward_outputs, in_gradients, output_shapes, attrs=self._attrs)
def _backprop_call(self, args):
- """Calls the wrapped function and records the result on a tape."""
+ """Calls the wrapped function and records the result on a tape.
+
+  Results are only recorded on a tape if the function has outputs.
+
+ Args:
+ args: The tensor inputs to the function.
+ Returns:
+ The call output.
+ """
all_args = args + self._extra_inputs
signature = self._forward_fdef.signature
ctx = context.context()
@@ -420,6 +428,8 @@ class GraphModeFunction(object):
inputs=all_args,
attrs=None,
ctx=ctx)
+ if not outputs:
+ return None
else:
g = ops.get_default_graph()
g._add_function(self._forward_fdef) # pylint: disable=protected-access
@@ -431,8 +441,9 @@ class GraphModeFunction(object):
name="FunctionCall",
compute_shapes=False)
outputs = op.outputs
- outputs = [outputs] if isinstance(
- outputs, (ops.Tensor, type(None))) else list(outputs)
+ if not outputs:
+ return op
+ outputs = [outputs] if isinstance(outputs, ops.Tensor) else list(outputs)
for i, s in enumerate(self._output_shapes):
outputs[i].set_shape(s)
real_outputs = outputs[:len(self._returns)]
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index f53d6c2608..cfdbe5f079 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -349,6 +349,23 @@ class FunctionTest(test.TestCase):
g(constant_op.constant(1.0))
+ def testNestedDefunWithNoOutputAndTapedInput(self):
+ three = resource_variable_ops.ResourceVariable(3.0, name='v')
+
+ @function.defun
+ def f(x):
+ # This function intentionally takes a taped variable as input,
+ # but does not return any values
+ math_ops.add(x, three)
+
+ @function.defun
+ def g(x):
+ tape.watch_variable(x)
+ y = math_ops.add(x, three)
+ f(y)
+
+ g(three)
+
def testGradientTensorConversionWithDefun(self):
three = resource_variable_ops.ResourceVariable(3.0, name='v')
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc
index e3ce0ef9d0..52b3268903 100644
--- a/tensorflow/python/eager/pywrap_tfe_src.cc
+++ b/tensorflow/python/eager/pywrap_tfe_src.cc
@@ -873,22 +873,6 @@ static tensorflow::DataType FastTensorDtype(PyObject* tensor) {
return static_cast<tensorflow::DataType>(id);
}
-static tensorflow::int64 FastHandleId(PyObject* variable) {
- PyObject* handle = PyObject_GetAttrString(variable, "handle");
- if (handle == nullptr) {
- return -1;
- }
- tensorflow::int64 id = FastTensorId(handle);
- Py_DECREF(handle);
- return id;
-}
-
-struct CompareByHandleId {
- bool operator()(PyObject* lhs, PyObject* rhs) {
- return FastHandleId(lhs) < FastHandleId(rhs);
- }
-};
-
class GradientTape
: public tensorflow::eager::GradientTape<PyObject, PyBackwardFunction> {
public:
@@ -897,35 +881,63 @@ class GradientTape
persistent) {}
virtual ~GradientTape() {
- for (PyObject* v : watched_variables_) {
- Py_DECREF(v);
+ for (const IdAndVariable& v : watched_variables_) {
+ Py_DECREF(v.variable);
}
}
void WatchVariable(PyObject* v) {
- auto insert_result = watched_variables_.insert(v);
- if (insert_result.second) {
- // Only increment the reference count if we aren't already watching this
- // variable.
- Py_INCREF(v);
- }
- PyObject* handle = PyObject_GetAttrString(v, "handle");
+ tensorflow::Safe_PyObjectPtr handle(PyObject_GetAttrString(v, "handle"));
if (handle == nullptr) {
return;
}
- tensorflow::int64 id = FastTensorId(handle);
- Py_DECREF(handle);
+ tensorflow::int64 id = FastTensorId(handle.get());
+
if (!PyErr_Occurred()) {
this->Watch(id);
}
+
+ tensorflow::mutex_lock l(watched_variables_mu_);
+ auto insert_result = watched_variables_.emplace(id, v);
+
+ if (insert_result.second) {
+ // Only increment the reference count if we aren't already watching this
+ // variable.
+ Py_INCREF(v);
+ }
}
- const std::set<PyObject*, CompareByHandleId> WatchedVariables() {
- return watched_variables_;
+ PyObject* GetVariablesAsPyTuple() {
+ tensorflow::mutex_lock l(watched_variables_mu_);
+ PyObject* result = PyTuple_New(watched_variables_.size());
+ Py_ssize_t pos = 0;
+ for (const IdAndVariable& id_and_variable : watched_variables_) {
+ PyTuple_SET_ITEM(result, pos++, id_and_variable.variable);
+ Py_INCREF(id_and_variable.variable);
+ }
+ return result;
}
private:
- std::set<PyObject*, CompareByHandleId> watched_variables_;
+  // We store an IdAndVariable in the set since the set needs to be locked
+  // during insert, but should not call back into Python during insert to avoid
+  // deadlocking with the GIL.
+ struct IdAndVariable {
+ tensorflow::int64 id;
+ PyObject* variable;
+
+ IdAndVariable(tensorflow::int64 id, PyObject* variable)
+ : id(id), variable(variable) {}
+ };
+ struct CompareById {
+ bool operator()(const IdAndVariable& lhs, const IdAndVariable& rhs) {
+ return lhs.id < rhs.id;
+ }
+ };
+
+ tensorflow::mutex watched_variables_mu_;
+ std::set<IdAndVariable, CompareById> watched_variables_
+ GUARDED_BY(watched_variables_mu_);
};
typedef struct {
@@ -1217,15 +1229,7 @@ void TFE_Py_TapeSetWatchVariable(PyObject* variable) {
}
PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape) {
- const auto& watched_variables =
- reinterpret_cast<TFE_Py_Tape*>(tape)->tape->WatchedVariables();
- PyObject* result = PyTuple_New(watched_variables.size());
- Py_ssize_t pos = 0;
- for (PyObject* variable : watched_variables) {
- PyTuple_SET_ITEM(result, pos++, variable);
- Py_INCREF(variable);
- }
- return result;
+ return reinterpret_cast<TFE_Py_Tape*>(tape)->tape->GetVariablesAsPyTuple();
}
namespace {
diff --git a/tensorflow/python/estimator/BUILD b/tensorflow/python/estimator/BUILD
index d538c6c415..9e716e81f4 100644
--- a/tensorflow/python/estimator/BUILD
+++ b/tensorflow/python/estimator/BUILD
@@ -12,6 +12,10 @@ py_library(
name = "estimator_py",
srcs = ["estimator_lib.py"],
srcs_version = "PY2AND3",
+ visibility = [
+ "//tensorflow:__pkg__",
+ "//tensorflow:internal",
+ ],
deps = [
":baseline",
":boosted_trees",
@@ -971,7 +975,10 @@ py_test(
size = "large",
srcs = ["keras_test.py"],
srcs_version = "PY2AND3",
- tags = ["notsan"],
+ tags = [
+ "no_windows",
+ "notsan",
+ ],
deps = [
":keras",
"//tensorflow/core:protos_all_py",
diff --git a/tensorflow/python/estimator/api/BUILD b/tensorflow/python/estimator/api/BUILD
new file mode 100644
index 0000000000..cddee9b8f3
--- /dev/null
+++ b/tensorflow/python/estimator/api/BUILD
@@ -0,0 +1,17 @@
+package(
+ default_visibility = [
+ "//tensorflow:internal",
+ ],
+)
+
+licenses(["notice"]) # Apache 2.0
+
+load("//tensorflow/tools/api/generator:api_gen.bzl", "gen_api_init_files")
+load("//tensorflow/tools/api/generator:api_gen.bzl", "ESTIMATOR_API_INIT_FILES")
+
+gen_api_init_files(
+ name = "estimator_python_api_gen",
+ api_name = "estimator",
+ output_files = ESTIMATOR_API_INIT_FILES,
+ package = "tensorflow.python.estimator",
+)
diff --git a/tensorflow/python/estimator/canned/baseline.py b/tensorflow/python/estimator/canned/baseline.py
index 980c057372..3c6816cb03 100644
--- a/tensorflow/python/estimator/canned/baseline.py
+++ b/tensorflow/python/estimator/canned/baseline.py
@@ -59,7 +59,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops.losses import losses
from tensorflow.python.training import training_util
-from tensorflow.python.util.tf_export import tf_export
+from tensorflow.python.util.tf_export import estimator_export
# The default learning rate of 0.3 is a historical artifact of the initial
# implementation, but seems a reasonable choice.
@@ -174,7 +174,7 @@ def _baseline_model_fn(features, labels, mode, head, optimizer,
train_op_fn=train_op_fn)
-@tf_export('estimator.BaselineClassifier')
+@estimator_export('estimator.BaselineClassifier')
class BaselineClassifier(estimator.Estimator):
"""A classifier that can establish a simple baseline.
@@ -277,7 +277,7 @@ class BaselineClassifier(estimator.Estimator):
config=config)
-@tf_export('estimator.BaselineRegressor')
+@estimator_export('estimator.BaselineRegressor')
class BaselineRegressor(estimator.Estimator):
"""A regressor that can establish a simple baseline.
diff --git a/tensorflow/python/estimator/canned/boosted_trees.py b/tensorflow/python/estimator/canned/boosted_trees.py
index 4e6010a162..6b54f51ca6 100644
--- a/tensorflow/python/estimator/canned/boosted_trees.py
+++ b/tensorflow/python/estimator/canned/boosted_trees.py
@@ -39,7 +39,7 @@ from tensorflow.python.summary import summary
from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training_util
-from tensorflow.python.util.tf_export import tf_export
+from tensorflow.python.util.tf_export import estimator_export
# TODO(nponomareva): Reveal pruning params here.
_TreeHParams = collections.namedtuple('TreeHParams', [
@@ -712,7 +712,7 @@ def _create_regression_head(label_dimension, weight_column=None):
# pylint: enable=protected-access
-@tf_export('estimator.BoostedTreesClassifier')
+@estimator_export('estimator.BoostedTreesClassifier')
class BoostedTreesClassifier(estimator.Estimator):
"""A Classifier for Tensorflow Boosted Trees models."""
@@ -830,7 +830,7 @@ class BoostedTreesClassifier(estimator.Estimator):
model_fn=_model_fn, model_dir=model_dir, config=config)
-@tf_export('estimator.BoostedTreesRegressor')
+@estimator_export('estimator.BoostedTreesRegressor')
class BoostedTreesRegressor(estimator.Estimator):
"""A Regressor for Tensorflow Boosted Trees models."""
diff --git a/tensorflow/python/estimator/canned/dnn.py b/tensorflow/python/estimator/canned/dnn.py
index 1feac36f35..b924ad5df4 100644
--- a/tensorflow/python/estimator/canned/dnn.py
+++ b/tensorflow/python/estimator/canned/dnn.py
@@ -32,7 +32,7 @@ from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops.losses import losses
from tensorflow.python.summary import summary
-from tensorflow.python.util.tf_export import tf_export
+from tensorflow.python.util.tf_export import estimator_export
# The default learning rate of 0.05 is a historical artifact of the initial
# implementation, but seems a reasonable choice.
@@ -201,7 +201,7 @@ def _dnn_model_fn(features,
logits=logits)
-@tf_export('estimator.DNNClassifier')
+@estimator_export('estimator.DNNClassifier')
class DNNClassifier(estimator.Estimator):
"""A classifier for TensorFlow DNN models.
@@ -353,7 +353,7 @@ class DNNClassifier(estimator.Estimator):
warm_start_from=warm_start_from)
-@tf_export('estimator.DNNRegressor')
+@estimator_export('estimator.DNNRegressor')
class DNNRegressor(estimator.Estimator):
"""A regressor for TensorFlow DNN models.
diff --git a/tensorflow/python/estimator/canned/dnn_linear_combined.py b/tensorflow/python/estimator/canned/dnn_linear_combined.py
index 95efc0a028..64d81c46ce 100644
--- a/tensorflow/python/estimator/canned/dnn_linear_combined.py
+++ b/tensorflow/python/estimator/canned/dnn_linear_combined.py
@@ -37,7 +37,7 @@ from tensorflow.python.summary import summary
from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.training import sync_replicas_optimizer
from tensorflow.python.training import training_util
-from tensorflow.python.util.tf_export import tf_export
+from tensorflow.python.util.tf_export import estimator_export
# The default learning rates are a historical artifact of the initial
# implementation.
@@ -225,7 +225,7 @@ def _dnn_linear_combined_model_fn(features,
logits=logits)
-@tf_export('estimator.DNNLinearCombinedClassifier')
+@estimator_export('estimator.DNNLinearCombinedClassifier')
class DNNLinearCombinedClassifier(estimator.Estimator):
"""An estimator for TensorFlow Linear and DNN joined classification models.
@@ -406,7 +406,7 @@ class DNNLinearCombinedClassifier(estimator.Estimator):
warm_start_from=warm_start_from)
-@tf_export('estimator.DNNLinearCombinedRegressor')
+@estimator_export('estimator.DNNLinearCombinedRegressor')
class DNNLinearCombinedRegressor(estimator.Estimator):
"""An estimator for TensorFlow Linear and DNN joined models for regression.
diff --git a/tensorflow/python/estimator/canned/linear.py b/tensorflow/python/estimator/canned/linear.py
index 81657f0c01..705fc3ce06 100644
--- a/tensorflow/python/estimator/canned/linear.py
+++ b/tensorflow/python/estimator/canned/linear.py
@@ -33,7 +33,7 @@ from tensorflow.python.ops import variable_scope
from tensorflow.python.ops.losses import losses
from tensorflow.python.summary import summary
from tensorflow.python.training import ftrl
-from tensorflow.python.util.tf_export import tf_export
+from tensorflow.python.util.tf_export import estimator_export
# The default learning rate of 0.2 is a historical artifact of the initial
@@ -164,7 +164,7 @@ def _linear_model_fn(features, labels, mode, head, feature_columns, optimizer,
logits=logits)
-@tf_export('estimator.LinearClassifier')
+@estimator_export('estimator.LinearClassifier')
class LinearClassifier(estimator.Estimator):
"""Linear classifier model.
@@ -317,7 +317,7 @@ class LinearClassifier(estimator.Estimator):
warm_start_from=warm_start_from)
-@tf_export('estimator.LinearRegressor')
+@estimator_export('estimator.LinearRegressor')
class LinearRegressor(estimator.Estimator):
"""An estimator for TensorFlow Linear regression problems.
diff --git a/tensorflow/python/estimator/canned/parsing_utils.py b/tensorflow/python/estimator/canned/parsing_utils.py
index 74e5e5a1be..1ae0f1e9f7 100644
--- a/tensorflow/python/estimator/canned/parsing_utils.py
+++ b/tensorflow/python/estimator/canned/parsing_utils.py
@@ -23,10 +23,10 @@ import six
from tensorflow.python.feature_column import feature_column as fc
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import parsing_ops
-from tensorflow.python.util.tf_export import tf_export
+from tensorflow.python.util.tf_export import estimator_export
-@tf_export('estimator.classifier_parse_example_spec')
+@estimator_export('estimator.classifier_parse_example_spec')
def classifier_parse_example_spec(feature_columns,
label_key,
label_dtype=dtypes.int64,
@@ -166,7 +166,7 @@ def classifier_parse_example_spec(feature_columns,
return parsing_spec
-@tf_export('estimator.regressor_parse_example_spec')
+@estimator_export('estimator.regressor_parse_example_spec')
def regressor_parse_example_spec(feature_columns,
label_key,
label_dtype=dtypes.float32,
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py
index 4f57a4ef79..41c25f1c73 100644
--- a/tensorflow/python/estimator/estimator.py
+++ b/tensorflow/python/estimator/estimator.py
@@ -66,14 +66,14 @@ from tensorflow.python.util import compat
from tensorflow.python.util import compat_internal
from tensorflow.python.util import function_utils
from tensorflow.python.util import nest
-from tensorflow.python.util.tf_export import tf_export
+from tensorflow.python.util.tf_export import estimator_export
_VALID_MODEL_FN_ARGS = set(
['features', 'labels', 'mode', 'params', 'self', 'config'])
-@tf_export('estimator.Estimator')
+@estimator_export('estimator.Estimator')
class Estimator(object):
"""Estimator class to train and evaluate TensorFlow models.
@@ -566,7 +566,8 @@ class Estimator(object):
allowed_overrides = set([
'_call_input_fn', '_create_global_step',
'_convert_train_steps_to_hooks', '_convert_eval_steps_to_hooks',
- '_tf_api_names', '_validate_features_in_predict_input',
+ '_tf_api_names', '_estimator_api_names', '_estimator_api_constants',
+ '_validate_features_in_predict_input',
'_call_model_fn', '_add_meta_graph_for_mode'
])
estimator_members = set([m for m in Estimator.__dict__.keys()
@@ -893,11 +894,14 @@ class Estimator(object):
estimator_spec.scaffold.local_init_op or
monitored_session.Scaffold.default_local_init_op())
- saver_for_restore = estimator_spec.scaffold.saver or saver.Saver(
- sharded=True)
+      # This saver is used both to restore variables now and to save out the
+      # metagraph below. This ensures that any custom saver stored with the
+      # Scaffold is passed through to the SavedModel for restoring later.
+ graph_saver = estimator_spec.scaffold.saver or saver.Saver(sharded=True)
try:
- saver_for_restore.restore(session, checkpoint_path)
+ graph_saver.restore(session, checkpoint_path)
except errors.NotFoundError as e:
msg = ('Could not load all requested variables from the checkpoint. '
'Please make sure your model_fn does not expect variables '
@@ -918,7 +922,8 @@ class Estimator(object):
assets_collection=ops.get_collection(
ops.GraphKeys.ASSET_FILEPATHS),
strip_default_attrs=strip_default_attrs,
- legacy_init_op=local_init_op)
+ legacy_init_op=local_init_op,
+ saver=graph_saver)
if save_variables:
builder.add_meta_graph_and_variables(
@@ -1630,11 +1635,12 @@ def _has_dataset_or_queue_runner(maybe_tensor):
# Now, check queue.
return ops.get_default_graph().get_collection(ops.GraphKeys.QUEUE_RUNNERS)
+
VocabInfo = warm_starting_util.VocabInfo # pylint: disable=invalid-name
-tf_export('estimator.VocabInfo', allow_multiple_exports=True)(VocabInfo)
+estimator_export('estimator.VocabInfo')(VocabInfo)
-@tf_export('estimator.WarmStartSettings')
+@estimator_export('estimator.WarmStartSettings')
class WarmStartSettings(
collections.namedtuple('WarmStartSettings', [
'ckpt_to_initialize_from',
diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py
index 9c0d0f7390..a43b820f32 100644
--- a/tensorflow/python/estimator/estimator_test.py
+++ b/tensorflow/python/estimator/estimator_test.py
@@ -100,6 +100,11 @@ def check_eventfile_for_keyword(keyword, dir_):
return any(summaries_with_matching_keyword(keyword, dir_))
+def get_mock_saver():
+ real_saver = saver.Saver()
+ return test.mock.Mock(wraps=real_saver, saver_def=real_saver.saver_def)
+
+
class EstimatorInheritanceConstraintTest(test.TestCase):
"""Tests that sub classes cannot override methods of Estimator."""
@@ -1295,9 +1300,7 @@ class EstimatorEvaluateTest(test.TestCase):
def _model_fn_scaffold(features, labels, mode):
_, _ = features, labels
variables.Variable(1., name='weight')
- real_saver = saver.Saver()
- self.mock_saver = test.mock.Mock(
- wraps=real_saver, saver_def=real_saver.saver_def)
+ self.mock_saver = get_mock_saver()
return model_fn_lib.EstimatorSpec(
mode=mode,
predictions=constant_op.constant([[1.]]),
@@ -1819,9 +1822,7 @@ class EstimatorPredictTest(test.TestCase):
def _model_fn_scaffold(features, labels, mode):
_, _ = features, labels
variables.Variable(1., name='weight')
- real_saver = saver.Saver()
- self.mock_saver = test.mock.Mock(
- wraps=real_saver, saver_def=real_saver.saver_def)
+ self.mock_saver = get_mock_saver()
return model_fn_lib.EstimatorSpec(
mode=mode,
predictions=constant_op.constant([[1.]]),
@@ -2315,8 +2316,8 @@ class EstimatorExportTest(test.TestCase):
graph_ops = [x.name for x in graph.get_operations()]
self.assertTrue('input_example_tensor' in graph_ops)
self.assertTrue('ParseExample/ParseExample' in graph_ops)
- # Note that the SavedModel builder replaced the Saver with a new one
- self.assertTrue('save_1/LookupTableImportV2' in graph_ops)
+ # The original saver is used to restore variables
+ self.assertTrue('save/LookupTableImportV2' in graph_ops)
# Clean up.
gfile.DeleteRecursively(tmpdir)
@@ -2481,9 +2482,7 @@ class EstimatorExportTest(test.TestCase):
def _model_fn_scaffold(features, labels, mode):
_, _ = features, labels
variables.Variable(1., name='weight')
- real_saver = saver.Saver()
- self.mock_saver = test.mock.Mock(
- wraps=real_saver, saver_def=real_saver.saver_def)
+ self.mock_saver = get_mock_saver()
scores = constant_op.constant([3.])
return model_fn_lib.EstimatorSpec(
mode=mode,
@@ -2506,19 +2505,24 @@ class EstimatorExportTest(test.TestCase):
est.export_savedmodel(export_dir_base, serving_input_receiver_fn)
self.assertTrue(self.mock_saver.restore.called)
+ self.assertTrue(self.mock_saver.export_meta_graph.called)
+ self.assertTrue(self.mock_saver.save.called)
def test_scaffold_is_used_for_saver_multiple_modes(self):
tmpdir = tempfile.mkdtemp()
+ savers = {'predict_saver': None, 'train_saver': None}
def _model_fn_scaffold(features, labels, mode):
_, _ = features, labels
variables.Variable(1., name='weight')
- real_saver = saver.Saver()
- self.mock_saver = test.mock.Mock(
- wraps=real_saver, saver_def=real_saver.saver_def)
+
scores = constant_op.constant([3.])
if mode == model_fn_lib.ModeKeys.PREDICT:
- scaffold = training.Scaffold(saver=self.mock_saver)
+ savers['predict_saver'] = get_mock_saver()
+ scaffold = training.Scaffold(saver=savers['predict_saver'])
+ elif mode == model_fn_lib.ModeKeys.TRAIN:
+ savers['train_saver'] = get_mock_saver()
+ scaffold = training.Scaffold(saver=savers['train_saver'])
else:
scaffold = training.Scaffold()
return model_fn_lib.EstimatorSpec(
@@ -2542,7 +2546,13 @@ class EstimatorExportTest(test.TestCase):
compat.as_bytes(tmpdir), compat.as_bytes('export'))
est._export_all_saved_models(export_dir_base, input_receiver_fn_map)
- self.assertTrue(self.mock_saver.restore.called)
+ self.assertTrue(savers['train_saver'].restore.called)
+ self.assertEqual(savers['train_saver'].export_meta_graph.call_count, 1)
+ self.assertEqual(savers['train_saver'].save.call_count, 1)
+
+ self.assertTrue(savers['predict_saver'].restore.called)
+ self.assertEqual(savers['predict_saver'].export_meta_graph.call_count, 1)
+ self.assertEqual(savers['predict_saver'].save.call_count, 0)
def test_scaffold_is_used_for_local_init(self):
tmpdir = tempfile.mkdtemp()
diff --git a/tensorflow/python/estimator/export/export.py b/tensorflow/python/estimator/export/export.py
index ff19a0a7f4..010c0f3f59 100644
--- a/tensorflow/python/estimator/export/export.py
+++ b/tensorflow/python/estimator/export/export.py
@@ -34,7 +34,7 @@ from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import signature_def_utils
from tensorflow.python.util import compat
-from tensorflow.python.util.tf_export import tf_export
+from tensorflow.python.util.tf_export import estimator_export
_SINGLE_FEATURE_DEFAULT_NAME = 'feature'
_SINGLE_RECEIVER_DEFAULT_NAME = 'input'
@@ -93,7 +93,7 @@ def _check_tensor_key(name, error_label='feature'):
raise ValueError('{} keys must be strings: {}.'.format(error_label, name))
-@tf_export('estimator.export.ServingInputReceiver')
+@estimator_export('estimator.export.ServingInputReceiver')
class ServingInputReceiver(
collections.namedtuple(
'ServingInputReceiver',
@@ -161,7 +161,7 @@ class ServingInputReceiver(
receiver_tensors_alternatives=receiver_tensors_alternatives)
-@tf_export('estimator.export.TensorServingInputReceiver')
+@estimator_export('estimator.export.TensorServingInputReceiver')
class TensorServingInputReceiver(
collections.namedtuple(
'TensorServingInputReceiver',
@@ -263,7 +263,7 @@ class SupervisedInputReceiver(
receiver_tensors=receiver_tensors)
-@tf_export('estimator.export.build_parsing_serving_input_receiver_fn')
+@estimator_export('estimator.export.build_parsing_serving_input_receiver_fn')
def build_parsing_serving_input_receiver_fn(feature_spec,
default_batch_size=None):
"""Build a serving_input_receiver_fn expecting fed tf.Examples.
@@ -313,7 +313,7 @@ def _placeholders_from_receiver_tensors_dict(input_vals,
}
-@tf_export('estimator.export.build_raw_serving_input_receiver_fn')
+@estimator_export('estimator.export.build_raw_serving_input_receiver_fn')
def build_raw_serving_input_receiver_fn(features, default_batch_size=None):
"""Build a serving_input_receiver_fn expecting feature Tensors.
diff --git a/tensorflow/python/estimator/export/export_output.py b/tensorflow/python/estimator/export/export_output.py
index d387ea2940..6c26d29985 100644
--- a/tensorflow/python/estimator/export/export_output.py
+++ b/tensorflow/python/estimator/export/export_output.py
@@ -26,10 +26,10 @@ import six
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.saved_model import signature_def_utils
-from tensorflow.python.util.tf_export import tf_export
+from tensorflow.python.util.tf_export import estimator_export
-@tf_export('estimator.export.ExportOutput')
+@estimator_export('estimator.export.ExportOutput')
class ExportOutput(object):
"""Represents an output of a model that can be served.
@@ -100,7 +100,7 @@ class ExportOutput(object):
return output_dict
-@tf_export('estimator.export.ClassificationOutput')
+@estimator_export('estimator.export.ClassificationOutput')
class ClassificationOutput(ExportOutput):
"""Represents the output of a classification head.
@@ -169,7 +169,7 @@ class ClassificationOutput(ExportOutput):
examples, self.classes, self.scores)
-@tf_export('estimator.export.RegressionOutput')
+@estimator_export('estimator.export.RegressionOutput')
class RegressionOutput(ExportOutput):
"""Represents the output of a regression head."""
@@ -202,7 +202,7 @@ class RegressionOutput(ExportOutput):
return signature_def_utils.regression_signature_def(examples, self.value)
-@tf_export('estimator.export.PredictOutput')
+@estimator_export('estimator.export.PredictOutput')
class PredictOutput(ExportOutput):
"""Represents the output of a generic prediction head.
diff --git a/tensorflow/python/estimator/exporter.py b/tensorflow/python/estimator/exporter.py
index 766ea23f2a..b18212cfcd 100644
--- a/tensorflow/python/estimator/exporter.py
+++ b/tensorflow/python/estimator/exporter.py
@@ -28,10 +28,10 @@ from tensorflow.python.framework import errors_impl
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging
from tensorflow.python.summary import summary_iterator
-from tensorflow.python.util.tf_export import tf_export
+from tensorflow.python.util.tf_export import estimator_export
-@tf_export('estimator.Exporter')
+@estimator_export('estimator.Exporter')
class Exporter(object):
"""A class representing a type of model export."""
@@ -172,7 +172,7 @@ def _verify_compare_fn_args(compare_fn):
(compare_fn, non_valid_args))
-@tf_export('estimator.BestExporter')
+@estimator_export('estimator.BestExporter')
class BestExporter(Exporter):
"""This class exports the serving graph and checkpoints of the best models.
@@ -367,7 +367,7 @@ class BestExporter(Exporter):
return best_eval_result
-@tf_export('estimator.FinalExporter')
+@estimator_export('estimator.FinalExporter')
class FinalExporter(Exporter):
"""This class exports the serving graph and checkpoints in the end.
@@ -418,7 +418,7 @@ class FinalExporter(Exporter):
is_the_final_export)
-@tf_export('estimator.LatestExporter')
+@estimator_export('estimator.LatestExporter')
class LatestExporter(Exporter):
"""This class regularly exports the serving graph and checkpoints.
diff --git a/tensorflow/python/estimator/inputs/numpy_io.py b/tensorflow/python/estimator/inputs/numpy_io.py
index eefc7c712d..a6cefdece2 100644
--- a/tensorflow/python/estimator/inputs/numpy_io.py
+++ b/tensorflow/python/estimator/inputs/numpy_io.py
@@ -24,7 +24,7 @@ import numpy as np
from six import string_types
from tensorflow.python.estimator.inputs.queues import feeding_functions
-from tensorflow.python.util.tf_export import tf_export
+from tensorflow.python.util.tf_export import estimator_export
# Key name to pack the target into dict of `features`. See
# `_get_unique_target_key` for details.
@@ -87,7 +87,7 @@ def _validate_and_convert_features(x):
return ordered_dict_data
-@tf_export('estimator.inputs.numpy_input_fn')
+@estimator_export('estimator.inputs.numpy_input_fn')
def numpy_input_fn(x,
y=None,
batch_size=128,
diff --git a/tensorflow/python/estimator/inputs/pandas_io.py b/tensorflow/python/estimator/inputs/pandas_io.py
index 1ed6ed4d84..57f8e5fd6a 100644
--- a/tensorflow/python/estimator/inputs/pandas_io.py
+++ b/tensorflow/python/estimator/inputs/pandas_io.py
@@ -21,7 +21,7 @@ from __future__ import print_function
import numpy as np
from tensorflow.python.estimator.inputs.queues import feeding_functions
-from tensorflow.python.util.tf_export import tf_export
+from tensorflow.python.util.tf_export import estimator_export
try:
# pylint: disable=g-import-not-at-top
@@ -35,7 +35,7 @@ except ImportError:
HAS_PANDAS = False
-@tf_export('estimator.inputs.pandas_input_fn')
+@estimator_export('estimator.inputs.pandas_input_fn')
def pandas_input_fn(x,
y=None,
batch_size=128,
diff --git a/tensorflow/python/estimator/model_fn.py b/tensorflow/python/estimator/model_fn.py
index 3edf9fe940..c60c7f63ba 100644
--- a/tensorflow/python/estimator/model_fn.py
+++ b/tensorflow/python/estimator/model_fn.py
@@ -32,10 +32,10 @@ from tensorflow.python.saved_model import tag_constants
from tensorflow.python.training import monitored_session
from tensorflow.python.training import session_run_hook
from tensorflow.python.util import nest
-from tensorflow.python.util.tf_export import tf_export
+from tensorflow.python.util.tf_export import estimator_export
-@tf_export('estimator.ModeKeys')
+@estimator_export('estimator.ModeKeys')
class ModeKeys(object):
"""Standard names for model modes.
@@ -62,7 +62,7 @@ EXPORT_TAG_MAP = {
}
-@tf_export('estimator.EstimatorSpec')
+@estimator_export('estimator.EstimatorSpec')
class EstimatorSpec(
collections.namedtuple('EstimatorSpec', [
'mode', 'predictions', 'loss', 'train_op', 'eval_metric_ops',
diff --git a/tensorflow/python/estimator/run_config.py b/tensorflow/python/estimator/run_config.py
index c7707be839..b948ce96e0 100644
--- a/tensorflow/python/estimator/run_config.py
+++ b/tensorflow/python/estimator/run_config.py
@@ -29,7 +29,7 @@ from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import server_lib
from tensorflow.python.util import compat_internal
from tensorflow.python.util import function_utils
-from tensorflow.python.util.tf_export import tf_export
+from tensorflow.python.util.tf_export import estimator_export
_USE_DEFAULT = object()
@@ -296,7 +296,7 @@ class TaskType(object):
EVALUATOR = 'evaluator'
-@tf_export('estimator.RunConfig')
+@estimator_export('estimator.RunConfig')
class RunConfig(object):
"""This class specifies the configurations for an `Estimator` run."""
diff --git a/tensorflow/python/estimator/training.py b/tensorflow/python/estimator/training.py
index fb6a68b4f7..1572af579b 100644
--- a/tensorflow/python/estimator/training.py
+++ b/tensorflow/python/estimator/training.py
@@ -35,7 +35,7 @@ from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import server_lib
from tensorflow.python.training import session_run_hook
from tensorflow.python.util import compat
-from tensorflow.python.util.tf_export import tf_export
+from tensorflow.python.util.tf_export import estimator_export
_MAX_DELAY_SECS = 60
_DELAY_SECS_PER_WORKER = 5
@@ -115,7 +115,7 @@ def _is_google_env():
return tf_config.get(_ENVIRONMENT_KEY) == _ENVIRONMENT_GOOGLE_VALUE
-@tf_export('estimator.TrainSpec')
+@estimator_export('estimator.TrainSpec')
class TrainSpec(
collections.namedtuple('TrainSpec', ['input_fn', 'max_steps', 'hooks'])):
"""Configuration for the "train" part for the `train_and_evaluate` call.
@@ -167,7 +167,7 @@ class TrainSpec(
cls, input_fn=input_fn, max_steps=max_steps, hooks=hooks)
-@tf_export('estimator.EvalSpec')
+@estimator_export('estimator.EvalSpec')
class EvalSpec(
collections.namedtuple('EvalSpec', [
'input_fn', 'steps', 'name', 'hooks', 'exporters', 'start_delay_secs',
@@ -263,7 +263,7 @@ class EvalSpec(
throttle_secs=throttle_secs)
-@tf_export('estimator.train_and_evaluate')
+@estimator_export('estimator.train_and_evaluate')
def train_and_evaluate(estimator, train_spec, eval_spec):
"""Train and evaluate the `estimator`.
diff --git a/tensorflow/python/feature_column/feature_column.py b/tensorflow/python/feature_column/feature_column.py
index 59801efc26..af2ead9b84 100644
--- a/tensorflow/python/feature_column/feature_column.py
+++ b/tensorflow/python/feature_column/feature_column.py
@@ -1782,9 +1782,7 @@ class _EmbeddingColumnLayer(base.Layer):
Args:
embedding_shape: Shape of the embedding variable used for lookup.
initializer: A variable initializer function to be used in embedding
- variable initialization. If not specified, defaults to
- `tf.truncated_normal_initializer` with mean `0.0` and standard deviation
- `1/sqrt(dimension)`.
+ variable initialization.
weight_collections: A list of collection names to which the Variable will
be added. Note that, variables will also be added to collections
`tf.GraphKeys.GLOBAL_VARIABLES` and `ops.GraphKeys.MODEL_VARIABLES`.
diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py
index 79ee57355d..82ecba310b 100644
--- a/tensorflow/python/framework/function.py
+++ b/tensorflow/python/framework/function.py
@@ -1172,3 +1172,13 @@ _DTYPE_TO_STR = {
dtypes.qint32: "qi32",
dtypes.bfloat16: "b16"
}
+
+
+def function_def_from_tf_function(c_func):
+ """Converts a SWIG-wrapped TF_Function* to a FunctionDef proto."""
+ with c_api_util.tf_buffer() as buf:
+ c_api.TF_FunctionToFunctionDef(c_func, buf)
+ data = c_api.TF_GetBuffer(buf)
+ fdef = function_pb2.FunctionDef()
+ fdef.ParseFromString(compat.as_bytes(data))
+ return fdef
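The helper added above follows the usual C-API serialization route: write the function into a TF_Buffer, read the bytes back, and parse them into a FunctionDef. The result is an ordinary protobuf, so it round-trips like any other message; a small self-contained sketch:

    from tensorflow.core.framework import function_pb2

    fdef = function_pb2.FunctionDef()
    fdef.signature.name = 'ExampleFn'   # hypothetical name, for illustration
    data = fdef.SerializeToString()     # same bytes a TF_Buffer would hold

    restored = function_pb2.FunctionDef()
    restored.ParseFromString(data)      # mirrors the ParseFromString call above
    assert restored.signature.name == 'ExampleFn'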
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index 0c06d9aa41..4a6146e0a6 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -321,32 +321,6 @@ def NCHWToNHWC(input_tensor):
return [input_tensor[a] for a in new_axes[ndims]]
-# TODO(skyewm): remove this eventually
-# pylint: disable=protected-access
-def _use_c_api_wrapper(fn, use_c_api, *args, **kwargs):
- prev_value = ops._USE_C_API
- ops._USE_C_API = use_c_api
- try:
- # Reset the default graph so it has the C API enabled. We call
- # reset_default_graph() instead of creating a new default Graph context to
- # make this robust to tests that call reset_default_graph(), which requires
- # that the current default graph isn't nested.
- ops.reset_default_graph()
- fn(*args, **kwargs)
- finally:
- ops._USE_C_API = prev_value
- # Make sure default graph reflects prev_value in case next test doesn't call
- # reset_default_graph().
- ops.reset_default_graph()
-
-
-# pylint: disable=protected-access
-
-
-def c_api_and_cuda_enabled():
- return ops._USE_C_API and IsGoogleCudaEnabled()
-
-
def skip_if(condition):
"""Skips the decorated function if condition is or evaluates to True.
@@ -372,46 +346,6 @@ def skip_if(condition):
return real_skip_if
-# TODO(skyewm): remove this eventually
-def disable_c_api(fn):
- """Decorator for disabling the C API on a test.
-
- Note this disables the C API after running the test class's setup/teardown
- methods.
-
- Args:
- fn: the function to be wrapped
-
- Returns:
- The wrapped function
- """
-
- def wrapper(*args, **kwargs):
- _use_c_api_wrapper(fn, False, *args, **kwargs)
-
- return wrapper
-
-
-# TODO(skyewm): remove this eventually
-def enable_c_api(fn):
- """Decorator for enabling the C API on a test.
-
- Note this enables the C API after running the test class's setup/teardown
- methods.
-
- Args:
- fn: the function to be wrapped
-
- Returns:
- The wrapped function
- """
-
- def wrapper(*args, **kwargs):
- _use_c_api_wrapper(fn, True, *args, **kwargs)
-
- return wrapper
-
-
def enable_c_shapes(fn):
"""Decorator for enabling C shapes on a test.
@@ -425,46 +359,19 @@ def enable_c_shapes(fn):
The wrapped function
"""
+ # pylint: disable=protected-access
def wrapper(*args, **kwargs):
prev_value = ops._USE_C_SHAPES
- # Only use C shapes if the C API is already enabled.
- ops._USE_C_SHAPES = ops._USE_C_API
+ ops._USE_C_SHAPES = True
try:
fn(*args, **kwargs)
finally:
ops._USE_C_SHAPES = prev_value
+ # pylint: enable=protected-access
return wrapper
-# This decorator is a hacky way to run all the test methods in a decorated
-# class with and without C API enabled.
-# TODO(iga): Remove this and its uses once we switch to using C API by default.
-def with_c_api(cls):
- """Adds methods that call original methods but with C API enabled.
-
- Note this enables the C API in new methods after running the test class's
- setup method. This can be a problem if some objects are created in it
- before the C API is enabled.
-
- Args:
- cls: class to decorate
-
- Returns:
- cls with new test methods added
- """
- # If the C API is already enabled, don't do anything. Some tests break if the
- # same test is run twice, so this allows us to turn on the C API by default
- # without breaking these tests.
- if ops._USE_C_API:
- return cls
-
- for name, value in cls.__dict__.copy().items():
- if callable(value) and name.startswith("test"):
- setattr(cls, name + "WithCApi", enable_c_api(value))
- return cls
-
-
def with_c_shapes(cls):
"""Adds methods that call original methods but with C API shapes enabled.
diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py
index af3d1fa33d..2a4a1c861c 100644
--- a/tensorflow/python/keras/backend.py
+++ b/tensorflow/python/keras/backend.py
@@ -4242,7 +4242,11 @@ def local_conv1d(inputs, kernel, kernel_size, strides, data_format=None):
"""Apply 1D conv with un-shared weights.
Arguments:
- inputs: 3D tensor with shape: (batch_size, steps, input_dim)
+ inputs: 3D tensor with shape:
+ (batch_size, steps, input_dim)
+ if data_format is "channels_last" or
+ (batch_size, input_dim, steps)
+ if data_format is "channels_first".
kernel: the unshared weight for convolution,
with shape (output_length, feature_dim, filters)
kernel_size: a tuple of a single integer,
@@ -4272,11 +4276,20 @@ def local_conv1d(inputs, kernel, kernel_size, strides, data_format=None):
xs = []
for i in range(output_length):
slice_length = slice(i * stride, i * stride + kernel_size[0])
- xs.append(reshape(inputs[:, slice_length, :], (1, -1, feature_dim)))
+ if data_format == 'channels_first':
+ xs.append(reshape(inputs[:, :, slice_length], (1, -1, feature_dim)))
+ else:
+ xs.append(reshape(inputs[:, slice_length, :], (1, -1, feature_dim)))
+
x_aggregate = concatenate(xs, axis=0)
# Shape: `(output_length, batch_size, filters)`.
output = batch_dot(x_aggregate, kernel)
- return permute_dimensions(output, (1, 0, 2))
+
+ if data_format == 'channels_first':
+ output = permute_dimensions(output, (1, 2, 0))
+ else:
+ output = permute_dimensions(output, (1, 0, 2))
+ return output
def local_conv2d(inputs,
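The channels_first branch added to local_conv1d slices the time axis in its new position but keeps the flattened patch size the same, so the (output_length, feature_dim, filters) kernel described in the docstring still applies; the output length uses the standard 'valid' formula. A small NumPy-only sketch of the shapes involved (values chosen arbitrarily):

    import numpy as np

    batch, steps, channels, filters = 2, 5, 3, 4
    kernel_size, stride = 2, 1
    output_length = (steps - kernel_size + stride) // stride  # == 4 ('valid')

    x_last = np.zeros((batch, steps, channels))   # channels_last layout
    x_first = np.transpose(x_last, (0, 2, 1))     # channels_first layout

    # One unshared kernel per output position.
    kernel = np.zeros((output_length, kernel_size * channels, filters))

    # Per-position patches flatten to the same size in either layout.
    patch_last = x_last[:, 0:kernel_size, :].reshape(batch, -1)
    patch_first = x_first[:, :, 0:kernel_size].reshape(batch, -1)
    assert patch_last.shape == patch_first.shape == (batch, kernel_size * channels)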
diff --git a/tensorflow/python/keras/backend_test.py b/tensorflow/python/keras/backend_test.py
index 58df263a4f..53e30e0e4a 100644
--- a/tensorflow/python/keras/backend_test.py
+++ b/tensorflow/python/keras/backend_test.py
@@ -810,6 +810,53 @@ class BackendNNOpsTest(test.TestCase):
padding='same', data_format='channels_last')
self.assertEqual(y.get_shape().as_list(), [10, 5, 5])
+ def test_local_conv1d_channels_dim(self):
+ input_length = 5
+ input_dim = 3
+ batch_size = 2
+
+ inputs = np.random.normal(0, 1, (batch_size, input_dim, input_length))
+ inputs_cf = keras.backend.variable(inputs)
+
+ filters = 4
+ for kernel_size in [(1,), (2,), (3,)]:
+ for strides in [(1,), (2,), (3,)]:
+ output_length = (input_length - kernel_size[0]
+ + strides[0]) // strides[0]
+
+ kernel_shape = (output_length, kernel_size[0] * input_dim, filters)
+ kernel = np.random.normal(0, 1, (output_length,
+ input_dim,
+ kernel_size[0],
+ filters))
+ kernel_cf = np.reshape(kernel, kernel_shape)
+ kernel_cf = keras.backend.variable(kernel_cf)
+
+ conv_cf = keras.backend.local_conv1d(inputs_cf,
+ kernel_cf,
+ kernel_size,
+ strides,
+ 'channels_first')
+
+ inputs_cl = np.transpose(inputs, (0, 2, 1))
+ inputs_cl = keras.backend.variable(inputs_cl)
+
+ kernel_cl = np.reshape(np.transpose(kernel, (0, 2, 1, 3)),
+ kernel_shape)
+ kernel_cl = keras.backend.variable(kernel_cl)
+
+ conv_cl = keras.backend.local_conv1d(inputs_cl,
+ kernel_cl,
+ kernel_size,
+ strides,
+ 'channels_last')
+ with self.test_session():
+ conv_cf = keras.backend.eval(conv_cf)
+ conv_cl = keras.backend.eval(conv_cl)
+
+ self.assertAllCloseAccordingToType(conv_cf,
+ np.transpose(conv_cl, (0, 2, 1)))
+
def test_conv2d(self):
val = np.random.random((10, 4, 10, 10))
x = keras.backend.variable(val)
diff --git a/tensorflow/python/keras/datasets/boston_housing.py b/tensorflow/python/keras/datasets/boston_housing.py
index 8c043638c0..4c4cab8c08 100644
--- a/tensorflow/python/keras/datasets/boston_housing.py
+++ b/tensorflow/python/keras/datasets/boston_housing.py
@@ -39,9 +39,10 @@ def load_data(path='boston_housing.npz', test_split=0.2, seed=113):
Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
"""
assert 0 <= test_split < 1
+ origin_folder = 'https://storage.googleapis.com/tensorflow/tf-keras-datasets/'
path = get_file(
path,
- origin='https://s3.amazonaws.com/keras-datasets/boston_housing.npz',
+ origin=origin_folder + 'boston_housing.npz',
file_hash=
'f553886a1f8d56431e820c5b82552d9d95cfcb96d1e678153f8839538947dff5')
f = np.load(path)
diff --git a/tensorflow/python/keras/datasets/fashion_mnist.py b/tensorflow/python/keras/datasets/fashion_mnist.py
index 45e27aad34..3f4c6c7413 100644
--- a/tensorflow/python/keras/datasets/fashion_mnist.py
+++ b/tensorflow/python/keras/datasets/fashion_mnist.py
@@ -33,9 +33,15 @@ def load_data():
Returns:
Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
+
+ License:
+ The copyright for Fashion-MNIST is held by Zalando SE.
+ Fashion-MNIST is licensed under the [MIT license](
+ https://github.com/zalandoresearch/fashion-mnist/blob/master/LICENSE).
+
"""
dirname = os.path.join('datasets', 'fashion-mnist')
- base = 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/'
+ base = 'https://storage.googleapis.com/tensorflow/tf-keras-datasets/'
files = [
'train-labels-idx1-ubyte.gz', 'train-images-idx3-ubyte.gz',
't10k-labels-idx1-ubyte.gz', 't10k-images-idx3-ubyte.gz'
diff --git a/tensorflow/python/keras/datasets/imdb.py b/tensorflow/python/keras/datasets/imdb.py
index 411b3e8635..b73b024162 100644
--- a/tensorflow/python/keras/datasets/imdb.py
+++ b/tensorflow/python/keras/datasets/imdb.py
@@ -77,9 +77,10 @@ def load_data(path='imdb.npz',
if kwargs:
raise TypeError('Unrecognized keyword arguments: ' + str(kwargs))
+ origin_folder = 'https://storage.googleapis.com/tensorflow/tf-keras-datasets/'
path = get_file(
path,
- origin='https://s3.amazonaws.com/text-datasets/imdb.npz',
+ origin=origin_folder + 'imdb.npz',
file_hash='599dadb1135973df5b59232a0e9a887c')
with np.load(path) as f:
x_train, labels_train = f['x_train'], f['y_train']
@@ -140,9 +141,10 @@ def get_word_index(path='imdb_word_index.json'):
Returns:
The word index dictionary.
"""
+ origin_folder = 'https://storage.googleapis.com/tensorflow/tf-keras-datasets/'
path = get_file(
path,
- origin='https://s3.amazonaws.com/text-datasets/imdb_word_index.json',
+ origin=origin_folder + 'imdb_word_index.json',
file_hash='bfafd718b763782e994055a2d397834f')
with open(path) as f:
return json.load(f)
diff --git a/tensorflow/python/keras/datasets/mnist.py b/tensorflow/python/keras/datasets/mnist.py
index 631189731a..03564accc7 100644
--- a/tensorflow/python/keras/datasets/mnist.py
+++ b/tensorflow/python/keras/datasets/mnist.py
@@ -34,10 +34,18 @@ def load_data(path='mnist.npz'):
Returns:
Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
+
+ License:
+ Yann LeCun and Corinna Cortes hold the copyright of the MNIST dataset,
+ which is a derivative work from the original NIST datasets. The MNIST
+ dataset is made available under the terms of the
+ [Creative Commons Attribution-Share Alike 3.0 license](
+ https://creativecommons.org/licenses/by-sa/3.0/).
"""
+ origin_folder = 'https://storage.googleapis.com/tensorflow/tf-keras-datasets/'
path = get_file(
path,
- origin='https://s3.amazonaws.com/img-datasets/mnist.npz',
+ origin=origin_folder + 'mnist.npz',
file_hash='8a61469f7ea1b51cbae51d4f78837e45')
f = np.load(path)
x_train, y_train = f['x_train'], f['y_train']
diff --git a/tensorflow/python/keras/datasets/reuters.py b/tensorflow/python/keras/datasets/reuters.py
index b070ba8d12..2120b4b242 100644
--- a/tensorflow/python/keras/datasets/reuters.py
+++ b/tensorflow/python/keras/datasets/reuters.py
@@ -75,9 +75,10 @@ def load_data(path='reuters.npz',
if kwargs:
raise TypeError('Unrecognized keyword arguments: ' + str(kwargs))
+ origin_folder = 'https://storage.googleapis.com/tensorflow/tf-keras-datasets/'
path = get_file(
path,
- origin='https://s3.amazonaws.com/text-datasets/reuters.npz',
+ origin=origin_folder + 'reuters.npz',
file_hash='87aedbeb0cb229e378797a632c1997b6')
with np.load(path) as f:
xs, labels = f['x'], f['y']
@@ -124,9 +125,10 @@ def get_word_index(path='reuters_word_index.json'):
Returns:
The word index dictionary.
"""
+ origin_folder = 'https://storage.googleapis.com/tensorflow/tf-keras-datasets/'
path = get_file(
path,
- origin='https://s3.amazonaws.com/text-datasets/reuters_word_index.json',
+ origin=origin_folder + 'reuters_word_index.json',
file_hash='4d44cc38712099c9e383dc6e5f11a921')
f = open(path)
data = json.load(f)
diff --git a/tensorflow/python/keras/engine/network.py b/tensorflow/python/keras/engine/network.py
index f53898d8e3..427efaaf11 100644
--- a/tensorflow/python/keras/engine/network.py
+++ b/tensorflow/python/keras/engine/network.py
@@ -20,6 +20,7 @@ from __future__ import division
from __future__ import print_function
import copy
+import functools
import json
import os
import weakref
@@ -1300,7 +1301,11 @@ class Network(base_layer.Layer):
with h5py.File(filepath, 'w') as f:
saving.save_weights_to_hdf5_group(f, self.layers)
else:
- self._checkpointable_saver.save(filepath)
+ if context.executing_eagerly():
+ session = None
+ else:
+ session = backend.get_session()
+ self._checkpointable_saver.save(filepath, session=session)
def load_weights(self, filepath, by_name=False):
"""Loads all layer weights, either from a TensorFlow or an HDF5 weight file.
@@ -1360,7 +1365,8 @@ class Network(base_layer.Layer):
'loading TensorFlow-formatted weights (got by_name=True to '
'load_weights).')
if not context.executing_eagerly():
- finalizer = status.run_restore_ops
+ session = backend.get_session()
+ finalizer = functools.partial(status.run_restore_ops, session=session)
if self.built:
finalizer()
else:
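The network.py change above threads an explicit session through weight saving and restoring so that neither depends on a default session being installed (the new test_no_default_session in saving_test.py exercises exactly that case). The functools.partial call is just pre-binding the session argument; a tiny self-contained sketch:

    import functools

    def run_restore_ops(session=None):
        # Stand-in for the checkpoint status's run_restore_ops.
        return 'ran restore ops in %s' % (session or 'the default session')

    session = 'an explicit session'  # stands in for backend.get_session()
    finalizer = functools.partial(run_restore_ops, session=session)
    print(finalizer())  # -> 'ran restore ops in an explicit session'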
diff --git a/tensorflow/python/keras/engine/saving_test.py b/tensorflow/python/keras/engine/saving_test.py
index 4352c5cb18..7e82db028b 100644
--- a/tensorflow/python/keras/engine/saving_test.py
+++ b/tensorflow/python/keras/engine/saving_test.py
@@ -428,26 +428,27 @@ class TestWholeModelSaving(test.TestCase):
os.remove(fname)
def test_saving_lambda_numpy_array_arguments(self):
- if h5py is None:
- self.skipTest('h5py required to run this test')
+ with self.test_session():
+ if h5py is None:
+ self.skipTest('h5py required to run this test')
- mean = np.random.random((4, 2, 3))
- std = np.abs(np.random.random((4, 2, 3))) + 1e-5
- inputs = keras.layers.Input(shape=(4, 2, 3))
- output = keras.layers.Lambda(lambda image, mu, std: (image - mu) / std,
- arguments={'mu': mean, 'std': std})(inputs)
- model = keras.models.Model(inputs, output)
- model.compile(loss='mse', optimizer='sgd', metrics=['acc'])
+ mean = np.random.random((4, 2, 3))
+ std = np.abs(np.random.random((4, 2, 3))) + 1e-5
+ inputs = keras.layers.Input(shape=(4, 2, 3))
+ output = keras.layers.Lambda(lambda image, mu, std: (image - mu) / std,
+ arguments={'mu': mean, 'std': std})(inputs)
+ model = keras.models.Model(inputs, output)
+ model.compile(loss='mse', optimizer='sgd', metrics=['acc'])
- fd, fname = tempfile.mkstemp('.h5')
- keras.models.save_model(model, fname)
+ fd, fname = tempfile.mkstemp('.h5')
+ keras.models.save_model(model, fname)
- model = keras.models.load_model(fname)
- os.close(fd)
- os.remove(fname)
+ model = keras.models.load_model(fname)
+ os.close(fd)
+ os.remove(fname)
- self.assertAllClose(mean, model.layers[1].arguments['mu'])
- self.assertAllClose(std, model.layers[1].arguments['std'])
+ self.assertAllClose(mean, model.layers[1].arguments['mu'])
+ self.assertAllClose(std, model.layers[1].arguments['std'])
def test_saving_model_with_long_layer_names(self):
if h5py is None:
@@ -604,6 +605,25 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase):
# Indirectly tests that the user is prompted
model.save_weights(prefix, save_format='tensorflow', overwrite=False)
+ def test_no_default_session(self):
+ with ops.Graph().as_default():
+ self.assertFalse(ops.get_default_session())
+ data = np.random.random((1000, 32)).astype(np.float32)
+ labels = np.random.random((1000, 10)).astype(np.float32)
+
+ model = keras.models.Sequential([
+ keras.layers.Dense(10, activation='softmax'),
+ keras.layers.Dense(10, activation='softmax')])
+
+ model.compile(optimizer=training_module.RMSPropOptimizer(0.001),
+ loss='categorical_crossentropy',
+ metrics=['accuracy'])
+
+ model.fit(data, labels)
+ fname = os.path.join(self.get_temp_dir(), 'weights', 'ckpt')
+ model.save_weights(fname)
+ model.load_weights(fname)
+
def test_no_graph_pollution(self):
with context.graph_mode():
graph = ops.Graph()
diff --git a/tensorflow/python/keras/engine/sequential.py b/tensorflow/python/keras/engine/sequential.py
index 52e29b0ffa..3ca8fdd326 100644
--- a/tensorflow/python/keras/engine/sequential.py
+++ b/tensorflow/python/keras/engine/sequential.py
@@ -222,11 +222,16 @@ class Sequential(Model):
for layer in self._layers:
x = layer(x)
self.outputs = [x]
+ # Make sure that the model's input shape will be preserved during
+ # serialization.
+ if self._layers:
+ self._layers[0]._batch_input_shape = batch_shape
if self.inputs:
self._init_graph_network(self.inputs, self.outputs, name=self.name)
self.built = True
- self._track_layers(self._layers)
+ if self._layers:
+ self._track_layers(self._layers)
def predict_proba(self, x, batch_size=32, verbose=0):
"""Generates class probability predictions for the input samples.
diff --git a/tensorflow/python/keras/engine/sequential_test.py b/tensorflow/python/keras/engine/sequential_test.py
index 69a288e69b..cdaf9162de 100644
--- a/tensorflow/python/keras/engine/sequential_test.py
+++ b/tensorflow/python/keras/engine/sequential_test.py
@@ -209,6 +209,30 @@ class TestSequential(test.TestCase):
x2 = model.predict(val_a)
assert np.abs(np.sum(x1 - x2)) > 1e-5
+ def test_sequential_deferred_build_serialization(self):
+ num_hidden = 5
+ input_dim = 3
+ batch_size = 5
+ num_classes = 2
+
+ model = keras.models.Sequential()
+ # We don't specify the input shape.
+ model.add(keras.layers.Dense(num_hidden))
+ model.add(keras.layers.Dense(num_classes))
+ model.compile(loss='mse', optimizer=rmsprop.RMSPropOptimizer(1e-3))
+ self.assertFalse(model.built)
+
+ x = np.random.random((batch_size, input_dim))
+ y = np.random.random((batch_size, num_classes))
+ model.train_on_batch(x, y)
+ self.assertTrue(model.built)
+
+ config = model.get_config()
+ new_model = keras.models.Sequential.from_config(config)
+ self.assertTrue(new_model.built)
+ self.assertEqual(len(model.layers), 2)
+ self.assertEqual(len(model.weights), 4)
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py
index aca63f822b..fce6cbdb7a 100644
--- a/tensorflow/python/keras/engine/training.py
+++ b/tensorflow/python/keras/engine/training.py
@@ -1011,14 +1011,16 @@ class Model(Network):
# to keep track of number of inputs and outputs and their ndim.
if isinstance(inputs, (list, tuple)):
if tensor_util.is_tensor(inputs[0]):
- dummy_output_values = self.call(inputs)
+ dummy_output_values = self.call(
+ training_utils.cast_if_floating_dtype(inputs))
else:
dummy_output_values = self.call(
[ops.convert_to_tensor(v, dtype=K.floatx()) for v in inputs])
dummy_input_values = list(inputs)
else:
if tensor_util.is_tensor(inputs):
- dummy_output_values = self.call(inputs)
+ dummy_output_values = self.call(
+ training_utils.cast_if_floating_dtype(inputs))
else:
dummy_output_values = self.call(
ops.convert_to_tensor(inputs, dtype=K.floatx()))
@@ -1619,7 +1621,10 @@ class Model(Network):
# Validate and standardize user data.
inputs, _, _ = self._standardize_user_data(x)
if context.executing_eagerly():
- if not isinstance(inputs, iterator_ops.EagerIterator):
+ if (isinstance(x, iterator_ops.EagerIterator) or
+ (isinstance(x, dataset_ops.Dataset) and context.executing_eagerly())):
+ inputs = training_utils.cast_if_floating_dtype(inputs)
+ else:
inputs = [
ops.convert_to_tensor(val, dtype=K.floatx()) for val in inputs
]
diff --git a/tensorflow/python/keras/engine/training_eager.py b/tensorflow/python/keras/engine/training_eager.py
index 081e46aa66..e8838cd3bc 100644
--- a/tensorflow/python/keras/engine/training_eager.py
+++ b/tensorflow/python/keras/engine/training_eager.py
@@ -255,6 +255,8 @@ def iterator_fit_loop(model,
# Validate and standardize data.
x, y, sample_weights = model._standardize_user_data(
x, y, class_weight=class_weight)
+ x = training_utils.cast_if_floating_dtype(x)
+ y = training_utils.cast_if_floating_dtype(y)
if sample_weights:
sample_weights = [
ops.convert_to_tensor(val, dtype=backend.floatx())
@@ -471,6 +473,8 @@ def iterator_test_loop(model, inputs, steps, verbose=0):
# Validate and standardize data.
x, y, sample_weights = model._standardize_user_data(x, y)
+ x = training_utils.cast_if_floating_dtype(x)
+ y = training_utils.cast_if_floating_dtype(y)
# Calculate model output, loss values.
loss_outs, loss, loss_metrics = _model_loss(
@@ -501,11 +505,11 @@ def iterator_test_loop(model, inputs, steps, verbose=0):
if verbose == 1:
progbar.update(step_index + 1)
- for i in range(len(outs)):
- outs[i] /= num_samples
- if len(outs) == 1:
- return outs[0]
- return outs
+ for i in range(len(outs)):
+ outs[i] /= num_samples
+ if len(outs) == 1:
+ return outs[0]
+ return outs
def batch_test_loop(model,
@@ -639,6 +643,7 @@ def iterator_predict_loop(model, inputs, steps, verbose=0):
# Validate and standardize data.
x, _, _ = model._standardize_user_data(x)
+ x = training_utils.cast_if_floating_dtype(x)
if model._expects_training_arg:
batch_outs = model.call(x[0] if len(x) == 1 else x, training=False)
@@ -814,7 +819,10 @@ def train_on_batch(model, inputs, targets, sample_weights=None):
Returns:
total loss and the loss associated with each output.
"""
- if len(inputs) and not tensor_util.is_tensor(inputs[0]):
+ if len(inputs) and tensor_util.is_tensor(inputs[0]):
+ inputs = training_utils.cast_if_floating_dtype(inputs)
+ targets = training_utils.cast_if_floating_dtype(targets)
+ else:
inputs = [
ops.convert_to_tensor(val, dtype=backend.floatx()) for val in inputs
]
@@ -849,7 +857,10 @@ def test_on_batch(model, inputs, targets, sample_weights=None):
Returns:
total loss, loss and metrics associated with each output.
"""
- if len(inputs) and not tensor_util.is_tensor(inputs[0]):
+ if len(inputs) and tensor_util.is_tensor(inputs[0]):
+ inputs = training_utils.cast_if_floating_dtype(inputs)
+ targets = training_utils.cast_if_floating_dtype(targets)
+ else:
inputs = [
ops.convert_to_tensor(val, dtype=backend.floatx()) for val in inputs
]
diff --git a/tensorflow/python/keras/engine/training_eager_test.py b/tensorflow/python/keras/engine/training_eager_test.py
index d9446fd437..1571a7782a 100644
--- a/tensorflow/python/keras/engine/training_eager_test.py
+++ b/tensorflow/python/keras/engine/training_eager_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
import numpy as np
+from tensorflow.python.data.ops import dataset_ops
from tensorflow.python import keras
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util as tf_test_util
@@ -402,6 +403,24 @@ class TrainingTest(test.TestCase):
model.train_on_batch(inputs, targets)
model.test_on_batch(inputs, targets)
+ def test_generator_methods(self):
+ model = keras.Sequential()
+ model.add(keras.layers.Dense(4, input_shape=(3,)))
+ optimizer = RMSPropOptimizer(learning_rate=0.001)
+ model.compile(optimizer, 'mse', metrics=['mae'])
+
+ x = np.random.random((10, 3))
+ y = np.random.random((10, 4))
+
+ def iterator():
+ while True:
+ yield x, y
+
+ model.fit_generator(iterator(), steps_per_epoch=3, epochs=1)
+ model.evaluate_generator(iterator(), steps=3)
+ out = model.predict_generator(iterator(), steps=3)
+ self.assertEqual(out.shape, (30, 4))
+
class LossWeightingTest(test.TestCase):
@@ -670,6 +689,59 @@ class CorrectnessTest(test.TestCase):
outs = model.evaluate(x, y)
self.assertEqual(outs[1], 0.)
+ @tf_test_util.run_in_graph_and_eager_modes()
+ def test_loss_correctness_with_iterator(self):
+ # Test that training loss is the same in eager and graph
+ # (by comparing it to a reference value in a deterministic case)
+ model = keras.Sequential()
+ model.add(
+ keras.layers.Dense(
+ 3, activation='relu', input_dim=4, kernel_initializer='ones'))
+ model.add(
+ keras.layers.Dense(2, activation='softmax', kernel_initializer='ones'))
+ model.compile(
+ loss='sparse_categorical_crossentropy',
+ optimizer=RMSPropOptimizer(learning_rate=0.001))
+ x = np.ones((100, 4), dtype=np.float32)
+ np.random.seed(123)
+ y = np.random.randint(0, 1, size=(100, 1))
+ dataset = dataset_ops.Dataset.from_tensor_slices((x, y))
+ dataset = dataset.repeat(100)
+ dataset = dataset.batch(10)
+ iterator = dataset.make_one_shot_iterator()
+ history = model.fit(iterator, epochs=1, steps_per_epoch=10)
+ self.assertEqual(np.around(history.history['loss'][-1], decimals=4), 0.6173)
+
+ @tf_test_util.run_in_graph_and_eager_modes()
+ def test_metrics_correctness_with_iterator(self):
+ model = keras.Sequential()
+ model.add(
+ keras.layers.Dense(
+ 8, activation='relu', input_dim=4, kernel_initializer='ones'))
+ model.add(
+ keras.layers.Dense(1, activation='sigmoid', kernel_initializer='ones'))
+ model.compile(
+ loss='binary_crossentropy',
+ metrics=['accuracy'],
+ optimizer=RMSPropOptimizer(learning_rate=0.001))
+ np.random.seed(123)
+ x = np.random.randint(10, size=(100, 4)).astype(np.float32)
+ y = np.random.randint(2, size=(100, 1)).astype(np.float32)
+ dataset = dataset_ops.Dataset.from_tensor_slices((x, y))
+ dataset = dataset.batch(10)
+ iterator = dataset.make_one_shot_iterator()
+ outs = model.evaluate(iterator, steps=10)
+ self.assertEqual(np.around(outs[1], decimals=1), 0.5)
+
+ y = np.zeros((100, 1), dtype=np.float32)
+ dataset = dataset_ops.Dataset.from_tensor_slices((x, y))
+ dataset = dataset.repeat(100)
+ dataset = dataset.batch(10)
+ iterator = dataset.make_one_shot_iterator()
+ outs = model.evaluate(iterator, steps=10)
+ self.assertEqual(outs[1], 0.)
+
+
if __name__ == '__main__':
ops.enable_eager_execution()
test.main()
diff --git a/tensorflow/python/keras/engine/training_test.py b/tensorflow/python/keras/engine/training_test.py
index 5c02d36382..a1ab720189 100644
--- a/tensorflow/python/keras/engine/training_test.py
+++ b/tensorflow/python/keras/engine/training_test.py
@@ -129,8 +129,10 @@ class TrainingTest(test.TestCase):
{
'input_a': input_a_np,
'input_b': input_b_np
- }, {'dense': output_d_np,
- 'dropout': output_e_np},
+ }, {
+ 'dense': output_d_np,
+ 'dropout': output_e_np
+ },
epochs=1,
batch_size=5,
verbose=0)
@@ -138,8 +140,10 @@ class TrainingTest(test.TestCase):
{
'input_a': input_a_np,
'input_b': input_b_np
- }, {'dense': output_d_np,
- 'dropout': output_e_np},
+ }, {
+ 'dense': output_d_np,
+ 'dropout': output_e_np
+ },
epochs=1,
batch_size=5,
verbose=1)
@@ -147,8 +151,10 @@ class TrainingTest(test.TestCase):
{
'input_a': input_a_np,
'input_b': input_b_np
- }, {'dense': output_d_np,
- 'dropout': output_e_np},
+ }, {
+ 'dense': output_d_np,
+ 'dropout': output_e_np
+ },
validation_data=({
'input_a': input_a_np,
'input_b': input_b_np
@@ -162,8 +168,10 @@ class TrainingTest(test.TestCase):
model.train_on_batch({
'input_a': input_a_np,
'input_b': input_b_np
- }, {'dense': output_d_np,
- 'dropout': output_e_np})
+ }, {
+ 'dense': output_d_np,
+ 'dropout': output_e_np
+ })
# Test with lists for loss, metrics
loss = ['mae', 'mse']
@@ -285,16 +293,20 @@ class TrainingTest(test.TestCase):
{
'input_a': input_a_np,
'input_b': input_b_np
- }, {'dense': output_d_np,
- 'dropout': output_e_np},
+ }, {
+ 'dense': output_d_np,
+ 'dropout': output_e_np
+ },
batch_size=5,
verbose=0)
model.evaluate(
{
'input_a': input_a_np,
'input_b': input_b_np
- }, {'dense': output_d_np,
- 'dropout': output_e_np},
+ }, {
+ 'dense': output_d_np,
+ 'dropout': output_e_np
+ },
batch_size=5,
verbose=1)
@@ -349,9 +361,11 @@ class TrainingTest(test.TestCase):
with self.test_session():
test_inputs = [
- scipy_sparse.random(6, 3, density=0.25).tocsr() for _ in range(2)]
+ scipy_sparse.random(6, 3, density=0.25).tocsr() for _ in range(2)
+ ]
test_outputs = [
- scipy_sparse.random(6, i, density=0.25).tocsr() for i in range(3, 5)]
+ scipy_sparse.random(6, i, density=0.25).tocsr() for i in range(3, 5)
+ ]
in1 = keras.layers.Input(shape=(3,))
in2 = keras.layers.Input(shape=(3,))
out1 = keras.layers.Dropout(0.5, name='dropout')(in1)
@@ -1721,8 +1735,8 @@ class TestTrainingWithDatasetIterators(test.TestCase):
metrics = ['mae']
model.compile(optimizer, loss, metrics=metrics)
- inputs = np.zeros((10, 3), dtype=np.float32)
- targets = np.zeros((10, 4), dtype=np.float32)
+ inputs = np.zeros((10, 3))
+ targets = np.zeros((10, 4))
dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
dataset = dataset.repeat(100)
dataset = dataset.batch(10)
@@ -1786,8 +1800,8 @@ class TestTrainingWithDatasetIterators(test.TestCase):
metrics = ['mae']
model.compile(optimizer, loss, metrics=metrics)
- inputs = np.zeros((10, 3), dtype=np.float32)
- targets = np.zeros((10, 4), dtype=np.float32)
+ inputs = np.zeros((10, 3))
+ targets = np.zeros((10, 4))
dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
dataset = dataset.repeat(100)
dataset = dataset.batch(10)
@@ -1811,8 +1825,8 @@ class TestTrainingWithDatasetIterators(test.TestCase):
metrics = ['mae']
model.compile(optimizer, loss, metrics=metrics)
- inputs = np.zeros((10, 3), dtype=np.float32)
- targets = np.zeros((10, 4), dtype=np.float32)
+ inputs = np.zeros((10, 3))
+ targets = np.zeros((10, 4))
dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
dataset = dataset.repeat(2)
dataset = dataset.batch(10)
@@ -1838,8 +1852,8 @@ class TestTrainingWithDataset(test.TestCase):
metrics = ['mae']
model.compile(optimizer, loss, metrics=metrics)
- inputs = np.zeros((10, 3), dtype=np.float32)
- targets = np.zeros((10, 4), dtype=np.float32)
+ inputs = np.zeros((10, 3))
+ targets = np.zeros((10, 4))
dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
dataset = dataset.repeat(100)
dataset = dataset.batch(10)
@@ -1865,8 +1879,8 @@ class TestTrainingWithDataset(test.TestCase):
metrics = ['mae']
model.compile(optimizer, loss, metrics=metrics)
- inputs = np.zeros((10, 3), dtype=np.float32)
- targets = np.zeros((10, 4), dtype=np.float32)
+ inputs = np.zeros((10, 3))
+ targets = np.zeros((10, 4))
dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
dataset = dataset.repeat(100)
dataset = dataset.batch(10)
@@ -1928,8 +1942,8 @@ class TestTrainingWithDataset(test.TestCase):
model.compile(optimizer, loss)
# User forgets to batch the dataset
- inputs = np.zeros((10, 3), dtype=np.float32)
- targets = np.zeros((10, 4), dtype=np.float32)
+ inputs = np.zeros((10, 3))
+ targets = np.zeros((10, 4))
dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
dataset = dataset.repeat(100)
@@ -1938,8 +1952,8 @@ class TestTrainingWithDataset(test.TestCase):
model.train_on_batch(dataset)
# Wrong input shape
- inputs = np.zeros((10, 5), dtype=np.float32)
- targets = np.zeros((10, 4), dtype=np.float32)
+ inputs = np.zeros((10, 5))
+ targets = np.zeros((10, 4))
dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
dataset = dataset.repeat(100)
dataset = dataset.batch(10)
diff --git a/tensorflow/python/keras/engine/training_utils.py b/tensorflow/python/keras/engine/training_utils.py
index b93f999444..728a2b493b 100644
--- a/tensorflow/python/keras/engine/training_utils.py
+++ b/tensorflow/python/keras/engine/training_utils.py
@@ -553,6 +553,10 @@ def standardize_weights(y,
def has_symbolic_tensors(ls):
if context.executing_eagerly():
return False
+ return has_tensors(ls)
+
+
+def has_tensors(ls):
if isinstance(ls, (list, tuple)):
return any(tensor_util.is_tensor(v) for v in ls)
return tensor_util.is_tensor(ls)
@@ -692,3 +696,29 @@ def check_steps_argument(input_data, steps, steps_name):
input_type=input_type_str, steps_name=steps_name))
return True
return False
+
+
+def cast_if_floating_dtype(x):
+ """Casts the given data tensors to the default floating point type.
+
+ Casts only if the input is already a floating point type.
+ Args:
+ x: tensor or list/tuple of tensors.
+
+ Returns:
+ Converted input.
+
+ Raises:
+ RuntimeError: if the input does not contain any tensors.
+ """
+ if not has_tensors(x):
+ raise RuntimeError(
+ 'Please provide tensors for casting, got: {x}'.format(x=x))
+
+ if isinstance(x, (list, tuple)):
+ return [
+ math_ops.cast(val, dtype=K.floatx())
+ if tensor_util.is_tensor(val) and val.dtype.is_floating else val
+ for val in x
+ ]
+ return math_ops.cast(x, dtype=K.floatx()) if x.dtype.is_floating else x
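cast_if_floating_dtype leaves integer tensors untouched and casts floating tensors to Keras's default float type (float32 unless reconfigured). A hedged usage sketch of the same logic under TF 1.x eager execution:

    import numpy as np
    import tensorflow as tf

    tf.enable_eager_execution()  # TF 1.x-era API, matching this codebase

    x64 = tf.constant(np.ones((2, 3), dtype=np.float64))
    ids = tf.constant([1, 2, 3], dtype=tf.int32)

    # Mirrors what cast_if_floating_dtype does for a list input.
    casted = [tf.cast(t, tf.keras.backend.floatx()) if t.dtype.is_floating else t
              for t in [x64, ids]]
    assert casted[0].dtype == tf.float32 and casted[1].dtype == tf.int32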
diff --git a/tensorflow/python/keras/layers/convolutional.py b/tensorflow/python/keras/layers/convolutional.py
index ce1c84e98d..9ea341139e 100644
--- a/tensorflow/python/keras/layers/convolutional.py
+++ b/tensorflow/python/keras/layers/convolutional.py
@@ -151,21 +151,23 @@ class Conv(Layer):
input_dim = int(input_shape[channel_axis])
kernel_shape = self.kernel_size + (input_dim, self.filters)
- self.kernel = self.add_variable(name='kernel',
- shape=kernel_shape,
- initializer=self.kernel_initializer,
- regularizer=self.kernel_regularizer,
- constraint=self.kernel_constraint,
- trainable=True,
- dtype=self.dtype)
+ self.kernel = self.add_weight(
+ name='kernel',
+ shape=kernel_shape,
+ initializer=self.kernel_initializer,
+ regularizer=self.kernel_regularizer,
+ constraint=self.kernel_constraint,
+ trainable=True,
+ dtype=self.dtype)
if self.use_bias:
- self.bias = self.add_variable(name='bias',
- shape=(self.filters,),
- initializer=self.bias_initializer,
- regularizer=self.bias_regularizer,
- constraint=self.bias_constraint,
- trainable=True,
- dtype=self.dtype)
+ self.bias = self.add_weight(
+ name='bias',
+ shape=(self.filters,),
+ initializer=self.bias_initializer,
+ regularizer=self.bias_regularizer,
+ constraint=self.bias_constraint,
+ trainable=True,
+ dtype=self.dtype)
else:
self.bias = None
self.input_spec = InputSpec(ndim=self.rank + 2,
@@ -720,21 +722,23 @@ class Conv2DTranspose(Conv2D):
self.input_spec = InputSpec(ndim=4, axes={channel_axis: input_dim})
kernel_shape = self.kernel_size + (self.filters, input_dim)
- self.kernel = self.add_variable(name='kernel',
- shape=kernel_shape,
- initializer=self.kernel_initializer,
- regularizer=self.kernel_regularizer,
- constraint=self.kernel_constraint,
- trainable=True,
- dtype=self.dtype)
+ self.kernel = self.add_weight(
+ name='kernel',
+ shape=kernel_shape,
+ initializer=self.kernel_initializer,
+ regularizer=self.kernel_regularizer,
+ constraint=self.kernel_constraint,
+ trainable=True,
+ dtype=self.dtype)
if self.use_bias:
- self.bias = self.add_variable(name='bias',
- shape=(self.filters,),
- initializer=self.bias_initializer,
- regularizer=self.bias_regularizer,
- constraint=self.bias_constraint,
- trainable=True,
- dtype=self.dtype)
+ self.bias = self.add_weight(
+ name='bias',
+ shape=(self.filters,),
+ initializer=self.bias_initializer,
+ regularizer=self.bias_regularizer,
+ constraint=self.bias_constraint,
+ trainable=True,
+ dtype=self.dtype)
else:
self.bias = None
self.built = True
@@ -961,7 +965,7 @@ class Conv3DTranspose(Conv3D):
kernel_shape = self.kernel_size + (self.filters, input_dim)
self.input_spec = InputSpec(ndim=5, axes={channel_axis: input_dim})
- self.kernel = self.add_variable(
+ self.kernel = self.add_weight(
'kernel',
shape=kernel_shape,
initializer=self.kernel_initializer,
@@ -970,7 +974,7 @@ class Conv3DTranspose(Conv3D):
trainable=True,
dtype=self.dtype)
if self.use_bias:
- self.bias = self.add_variable(
+ self.bias = self.add_weight(
'bias',
shape=(self.filters,),
initializer=self.bias_initializer,
@@ -1222,7 +1226,7 @@ class SeparableConv(Conv):
pointwise_kernel_shape = (
1,) * self.rank + (self.depth_multiplier * input_dim, self.filters)
- self.depthwise_kernel = self.add_variable(
+ self.depthwise_kernel = self.add_weight(
name='depthwise_kernel',
shape=depthwise_kernel_shape,
initializer=self.depthwise_initializer,
@@ -1230,7 +1234,7 @@ class SeparableConv(Conv):
constraint=self.depthwise_constraint,
trainable=True,
dtype=self.dtype)
- self.pointwise_kernel = self.add_variable(
+ self.pointwise_kernel = self.add_weight(
name='pointwise_kernel',
shape=pointwise_kernel_shape,
initializer=self.pointwise_initializer,
@@ -1239,13 +1243,14 @@ class SeparableConv(Conv):
trainable=True,
dtype=self.dtype)
if self.use_bias:
- self.bias = self.add_variable(name='bias',
- shape=(self.filters,),
- initializer=self.bias_initializer,
- regularizer=self.bias_regularizer,
- constraint=self.bias_constraint,
- trainable=True,
- dtype=self.dtype)
+ self.bias = self.add_weight(
+ name='bias',
+ shape=(self.filters,),
+ initializer=self.bias_initializer,
+ regularizer=self.bias_regularizer,
+ constraint=self.bias_constraint,
+ trainable=True,
+ dtype=self.dtype)
else:
self.bias = None
self.built = True
diff --git a/tensorflow/python/keras/layers/core.py b/tensorflow/python/keras/layers/core.py
index db0c220380..f60064ed63 100644
--- a/tensorflow/python/keras/layers/core.py
+++ b/tensorflow/python/keras/layers/core.py
@@ -906,21 +906,23 @@ class Dense(Layer):
'should be defined. Found `None`.')
self.input_spec = InputSpec(min_ndim=2,
axes={-1: input_shape[-1].value})
- self.kernel = self.add_variable('kernel',
- shape=[input_shape[-1].value, self.units],
- initializer=self.kernel_initializer,
- regularizer=self.kernel_regularizer,
- constraint=self.kernel_constraint,
- dtype=self.dtype,
- trainable=True)
+ self.kernel = self.add_weight(
+ 'kernel',
+ shape=[input_shape[-1].value, self.units],
+ initializer=self.kernel_initializer,
+ regularizer=self.kernel_regularizer,
+ constraint=self.kernel_constraint,
+ dtype=self.dtype,
+ trainable=True)
if self.use_bias:
- self.bias = self.add_variable('bias',
- shape=[self.units,],
- initializer=self.bias_initializer,
- regularizer=self.bias_regularizer,
- constraint=self.bias_constraint,
- dtype=self.dtype,
- trainable=True)
+ self.bias = self.add_weight(
+ 'bias',
+ shape=[self.units,],
+ initializer=self.bias_initializer,
+ regularizer=self.bias_regularizer,
+ constraint=self.bias_constraint,
+ dtype=self.dtype,
+ trainable=True)
else:
self.bias = None
self.built = True
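The add_variable → add_weight rename recurs through the Keras layers touched by this patch; both create layer-owned variables, with add_weight being the canonical Keras spelling. A minimal custom layer using add_weight in build(), as a sketch (the Scale layer itself is hypothetical):

    import tensorflow as tf

    class Scale(tf.keras.layers.Layer):
        # Hypothetical layer: multiplies its input by a learned per-feature scale.

        def build(self, input_shape):
            self.scale = self.add_weight(
                name='scale',
                shape=(int(input_shape[-1]),),
                initializer='ones',
                trainable=True)
            self.built = True

        def call(self, inputs):
            return inputs * self.scale

Calling Scale() on a (batch, 3) input would create a length-3 'scale' weight on first use.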
diff --git a/tensorflow/python/keras/layers/local.py b/tensorflow/python/keras/layers/local.py
index 46c18b763e..f222ea3083 100644
--- a/tensorflow/python/keras/layers/local.py
+++ b/tensorflow/python/keras/layers/local.py
@@ -62,6 +62,16 @@ class LocallyConnected1D(Layer):
any `dilation_rate` value != 1.
padding: Currently only supports `"valid"` (case-insensitive).
`"same"` may be supported in the future.
+ data_format: A string,
+ one of `channels_last` (default) or `channels_first`.
+ The ordering of the dimensions in the inputs.
+ `channels_last` corresponds to inputs with shape
+ `(batch, length, channels)` while `channels_first`
+ corresponds to inputs with shape
+ `(batch, channels, length)`.
+ It defaults to the `image_data_format` value found in your
+ Keras config file at `~/.keras/keras.json`.
+ If you never set it, then it will be "channels_last".
activation: Activation function to use.
If you don't specify anything, no activation is applied
(ie. "linear" activation: `a(x) = x`).
@@ -122,12 +132,16 @@ class LocallyConnected1D(Layer):
@tf_utils.shape_type_conversion
def build(self, input_shape):
- input_dim = input_shape[2]
+ if self.data_format == 'channels_first':
+ input_dim, input_length = input_shape[1], input_shape[2]
+ else:
+ input_dim, input_length = input_shape[2], input_shape[1]
+
if input_dim is None:
raise ValueError('Axis 2 of input should be fully-defined. '
'Found shape:', input_shape)
output_length = conv_utils.conv_output_length(
- input_shape[1], self.kernel_size[0], self.padding, self.strides[0])
+ input_length, self.kernel_size[0], self.padding, self.strides[0])
self.kernel_shape = (output_length, self.kernel_size[0] * input_dim,
self.filters)
self.kernel = self.add_weight(
@@ -145,19 +159,33 @@ class LocallyConnected1D(Layer):
constraint=self.bias_constraint)
else:
self.bias = None
- self.input_spec = InputSpec(ndim=3, axes={2: input_dim})
+
+ if self.data_format == 'channels_first':
+ self.input_spec = InputSpec(ndim=3, axes={1: input_dim})
+ else:
+ self.input_spec = InputSpec(ndim=3, axes={-1: input_dim})
self.built = True
@tf_utils.shape_type_conversion
def compute_output_shape(self, input_shape):
- length = conv_utils.conv_output_length(input_shape[1], self.kernel_size[0],
+ if self.data_format == 'channels_first':
+ input_length = input_shape[2]
+ else:
+ input_length = input_shape[1]
+
+ length = conv_utils.conv_output_length(input_length, self.kernel_size[0],
self.padding, self.strides[0])
- return (input_shape[0], length, self.filters)
+
+ if self.data_format == 'channels_first':
+ return (input_shape[0], self.filters, length)
+ elif self.data_format == 'channels_last':
+ return (input_shape[0], length, self.filters)
def call(self, inputs):
- output = K.local_conv1d(inputs, self.kernel, self.kernel_size, self.strides)
+ output = K.local_conv1d(inputs, self.kernel, self.kernel_size,
+ self.strides, self.data_format)
if self.use_bias:
- output = K.bias_add(output, self.bias)
+ output = K.bias_add(output, self.bias, data_format=self.data_format)
if self.activation is not None:
output = self.activation(output)
return output
@@ -172,6 +200,8 @@ class LocallyConnected1D(Layer):
self.strides,
'padding':
self.padding,
+ 'data_format':
+ self.data_format,
'activation':
activations.serialize(self.activation),
'use_bias':
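With data_format now part of the layer's config, channels-first inputs can be fed directly; a short usage sketch, assuming the new data_format keyword described in the docstring above (shapes chosen arbitrarily):

    import numpy as np
    import tensorflow as tf

    # channels_first: (batch, channels, length)
    x = np.random.random((2, 3, 9)).astype('float32')

    layer = tf.keras.layers.LocallyConnected1D(
        filters=4, kernel_size=3, data_format='channels_first')
    y = layer(tf.constant(x))

    # 'valid' padding, stride 1: output length = 9 - 3 + 1 = 7
    print(y.shape.as_list())  # [2, 4, 7] -- filters become the channel axis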
diff --git a/tensorflow/python/keras/layers/local_test.py b/tensorflow/python/keras/layers/local_test.py
index 90ae1719e1..9123d449af 100644
--- a/tensorflow/python/keras/layers/local_test.py
+++ b/tensorflow/python/keras/layers/local_test.py
@@ -40,16 +40,17 @@ class LocallyConnectedLayersTest(test.TestCase):
for strides in [1]:
if padding == 'same' and strides != 1:
continue
-
- testing_utils.layer_test(
- keras.layers.LocallyConnected1D,
- kwargs={
- 'filters': filters,
- 'kernel_size': filter_length,
- 'padding': padding,
- 'strides': strides
- },
- input_shape=(num_samples, num_steps, input_dim))
+ for data_format in ['channels_first', 'channels_last']:
+ testing_utils.layer_test(
+ keras.layers.LocallyConnected1D,
+ kwargs={
+ 'filters': filters,
+ 'kernel_size': filter_length,
+ 'padding': padding,
+ 'strides': strides,
+ 'data_format': data_format
+ },
+ input_shape=(num_samples, num_steps, input_dim))
def test_locallyconnected_1d_regularization(self):
num_samples = 2
@@ -57,35 +58,39 @@ class LocallyConnectedLayersTest(test.TestCase):
input_dim = 5
filter_length = 3
filters = 4
- kwargs = {
- 'filters': filters,
- 'kernel_size': filter_length,
- 'kernel_regularizer': 'l2',
- 'bias_regularizer': 'l2',
- 'activity_regularizer': 'l2',
- }
-
- with self.test_session():
- layer = keras.layers.LocallyConnected1D(**kwargs)
- layer.build((num_samples, num_steps, input_dim))
- self.assertEqual(len(layer.losses), 2)
- layer(
- keras.backend.variable(np.ones((num_samples, num_steps, input_dim))))
- self.assertEqual(len(layer.losses), 3)
-
- k_constraint = keras.constraints.max_norm(0.01)
- b_constraint = keras.constraints.max_norm(0.01)
- kwargs = {
- 'filters': filters,
- 'kernel_size': filter_length,
- 'kernel_constraint': k_constraint,
- 'bias_constraint': b_constraint,
- }
- with self.test_session():
- layer = keras.layers.LocallyConnected1D(**kwargs)
- layer.build((num_samples, num_steps, input_dim))
- self.assertEqual(layer.kernel.constraint, k_constraint)
- self.assertEqual(layer.bias.constraint, b_constraint)
+ for data_format in ['channels_first', 'channels_last']:
+ kwargs = {
+ 'filters': filters,
+ 'kernel_size': filter_length,
+ 'kernel_regularizer': 'l2',
+ 'bias_regularizer': 'l2',
+ 'activity_regularizer': 'l2',
+ 'data_format': data_format
+ }
+
+ with self.test_session():
+ layer = keras.layers.LocallyConnected1D(**kwargs)
+ layer.build((num_samples, num_steps, input_dim))
+ self.assertEqual(len(layer.losses), 2)
+ layer(
+ keras.backend.variable(np.ones((num_samples,
+ num_steps,
+ input_dim))))
+ self.assertEqual(len(layer.losses), 3)
+
+ k_constraint = keras.constraints.max_norm(0.01)
+ b_constraint = keras.constraints.max_norm(0.01)
+ kwargs = {
+ 'filters': filters,
+ 'kernel_size': filter_length,
+ 'kernel_constraint': k_constraint,
+ 'bias_constraint': b_constraint,
+ }
+ with self.test_session():
+ layer = keras.layers.LocallyConnected1D(**kwargs)
+ layer.build((num_samples, num_steps, input_dim))
+ self.assertEqual(layer.kernel.constraint, k_constraint)
+ self.assertEqual(layer.bias.constraint, b_constraint)
@tf_test_util.run_in_graph_and_eager_modes()
def test_locallyconnected_2d(self):
diff --git a/tensorflow/python/keras/layers/normalization.py b/tensorflow/python/keras/layers/normalization.py
index 7743d00c0f..ff51eadee9 100644
--- a/tensorflow/python/keras/layers/normalization.py
+++ b/tensorflow/python/keras/layers/normalization.py
@@ -183,7 +183,7 @@ class BatchNormalization(Layer):
def _add_tower_local_variable(self, *args, **kwargs):
tower_context = distribute_lib.get_tower_context()
with tower_context.tower_local_var_scope('mean'):
- return self.add_variable(*args, **kwargs)
+ return self.add_weight(*args, **kwargs)
def build(self, input_shape):
input_shape = tensor_shape.TensorShape(input_shape)
@@ -276,7 +276,7 @@ class BatchNormalization(Layer):
self.axis[idx] = x + 1 # Account for added dimension
if self.scale:
- self.gamma = self.add_variable(
+ self.gamma = self.add_weight(
name='gamma',
shape=param_shape,
dtype=param_dtype,
@@ -291,7 +291,7 @@ class BatchNormalization(Layer):
1.0, dtype=param_dtype, shape=param_shape)
if self.center:
- self.beta = self.add_variable(
+ self.beta = self.add_weight(
name='beta',
shape=param_shape,
dtype=param_dtype,
diff --git a/tensorflow/python/kernel_tests/fifo_queue_test.py b/tensorflow/python/kernel_tests/fifo_queue_test.py
index ce73e7ad3e..14a336c688 100644
--- a/tensorflow/python/kernel_tests/fifo_queue_test.py
+++ b/tensorflow/python/kernel_tests/fifo_queue_test.py
@@ -31,6 +31,7 @@ from tensorflow.python.framework import dtypes as dtypes_lib
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
@@ -125,12 +126,21 @@ class FIFOQueueTest(test.TestCase):
q.enqueue_many([[1, 2, 3, 4], [[1, 1], [2, 2], [3, 3], [4, 4]]]).run()
self.assertEqual(4, q.size().eval())
+ @test_util.run_in_graph_and_eager_modes()
def testMultipleDequeues(self):
- with self.test_session() as session:
- q = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()])
- q.enqueue_many([[1, 2, 3]]).run()
- a, b, c = session.run([q.dequeue(), q.dequeue(), q.dequeue()])
- self.assertAllEqual(set([1, 2, 3]), set([a, b, c]))
+ q = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()])
+ self.evaluate(q.enqueue_many([[1, 2, 3]]))
+ a, b, c = self.evaluate([q.dequeue(), q.dequeue(), q.dequeue()])
+ self.assertAllEqual(set([1, 2, 3]), set([a, b, c]))
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testQueuesDontShare(self):
+ q = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()])
+ self.evaluate(q.enqueue(1))
+ q2 = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()])
+ self.evaluate(q2.enqueue(2))
+ self.assertAllEqual(self.evaluate(q2.dequeue()), 2)
+ self.assertAllEqual(self.evaluate(q.dequeue()), 1)
def testEnqueueDictWithoutNames(self):
with self.test_session():
diff --git a/tensorflow/python/kernel_tests/linalg_grad_test.py b/tensorflow/python/kernel_tests/linalg_grad_test.py
index 7d367a9275..6f401358a2 100644
--- a/tensorflow/python/kernel_tests/linalg_grad_test.py
+++ b/tensorflow/python/kernel_tests/linalg_grad_test.py
@@ -177,6 +177,12 @@ if __name__ == '__main__':
MatrixUnaryFunctorGradientTest, 'MatrixDeterminantGradient', name,
_GetMatrixUnaryFunctorGradientTest(linalg_ops.matrix_determinant,
dtype, shape))
+ _AddTest(
+ MatrixUnaryFunctorGradientTest, 'LogMatrixDeterminantGradient',
+ name,
+ _GetMatrixUnaryFunctorGradientTest(
+ lambda x: linalg_ops.log_matrix_determinant(x)[1],
+ dtype, shape))
# Tests for gradients of matrix_solve_ls
for dtype in np.float32, np.float64:
diff --git a/tensorflow/python/kernel_tests/py_func_test.py b/tensorflow/python/kernel_tests/py_func_test.py
index b59e3dd7e7..4239151070 100644
--- a/tensorflow/python/kernel_tests/py_func_test.py
+++ b/tensorflow/python/kernel_tests/py_func_test.py
@@ -27,6 +27,7 @@ from six.moves import queue
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.python.client import session as session_lib
+from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import function
from tensorflow.python.framework import constant_op
@@ -35,6 +36,7 @@ from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import script_ops
@@ -532,8 +534,7 @@ class PyFuncTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes()
def testEagerReturningVariableRaisesError(self):
def return_variable():
- variable = resource_variable_ops.ResourceVariable(0.0)
- return variable
+ return resource_variable_ops.ResourceVariable(0.0)
with self.assertRaisesRegexp(errors.UnknownError,
"Attempting to return a variable"):
@@ -541,6 +542,80 @@ class PyFuncTest(test.TestCase):
return_variable, inp=[], Tout=dtypes.float32)
self.evaluate(output)
+ @test_util.run_in_graph_and_eager_modes()
+ def testEagerGradientTape(self):
+
+ def f(x):
+ return x**2
+
+ x = constant_op.constant(3.0)
+ with backprop.GradientTape() as tape:
+ tape.watch(x)
+ y = script_ops.eager_py_func(f, inp=[x], Tout=dtypes.float32)
+ dy_dx = tape.gradient(y, x)
+ self.assertEqual(self.evaluate(dy_dx), 6.0)
+
+ def testEagerGradientGraph(self):
+
+ def f(x):
+ return x**2
+
+ x = constant_op.constant(3.0)
+ y = script_ops.eager_py_func(f, inp=[x], Tout=dtypes.float32)
+ dy_dx = gradients_impl.gradients(y, x)[0]
+ self.assertEqual(self.evaluate(dy_dx), 6.0)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testEagerGradientTapeMultipleArgs(self):
+
+ def f(x, y):
+ return x**2 + y**2
+
+ x = constant_op.constant(3.0)
+ y = constant_op.constant(4.0)
+ with backprop.GradientTape() as tape:
+ tape.watch(x)
+ tape.watch(y)
+ z = script_ops.eager_py_func(f, inp=[x, y], Tout=dtypes.float32)
+
+ dz_dx, dz_dy = tape.gradient(z, [x, y])
+ self.assertEqual(self.evaluate(dz_dx), 6.0)
+ self.assertEqual(self.evaluate(dz_dy), 8.0)
+
+ def testEagerGradientGraphMultipleArgs(self):
+
+ def f(x, y):
+ return x**2 + y**2
+
+ x = constant_op.constant(3.0)
+ y = constant_op.constant(4.0)
+ z = script_ops.eager_py_func(f, inp=[x, y], Tout=dtypes.float32)
+
+ dz_dx, dz_dy = gradients_impl.gradients(z, [x, y])
+ self.assertEqual(self.evaluate(dz_dx), 6.0)
+ self.assertEqual(self.evaluate(dz_dy), 8.0)
+
+ def testEagerGradientGraphLogHuber(self):
+
+ def log_huber(x, m):
+ if math_ops.abs(x) <= m:
+ return x**2
+ else:
+ return m**2 * (1 - 2 * math_ops.log(m) + math_ops.log(x**2))
+
+ x = array_ops.placeholder(dtypes.float32)
+ m = array_ops.placeholder(dtypes.float32)
+
+ y = script_ops.eager_py_func(
+ func=log_huber, inp=[x, m], Tout=dtypes.float32)
+ dy_dx = gradients_impl.gradients(y, x)[0]
+
+ with self.test_session() as sess:
+ # Takes the first branch of log_huber.
+ y, dy_dx = sess.run([y, dy_dx], feed_dict={x: 1.0, m: 2.0})
+ self.assertEqual(y, 1.0)
+ self.assertEqual(dy_dx, 2.0)
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
index 00d517e64e..82e0d153c2 100644
--- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py
+++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
@@ -822,6 +822,16 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
state_ops.scatter_add(v, [1], [3])
self.assertAllEqual([1.0, 5.0], v.numpy())
+ def testScatterNdAddStateOps(self):
+ with context.eager_mode():
+ v = resource_variable_ops.ResourceVariable(
+ [1, 1, 1, 1, 1, 1, 1, 1], dtype=dtypes.float32, name="add")
+ indices = constant_op.constant([[4], [3], [1], [7]], dtype=dtypes.int32)
+ updates = constant_op.constant([9, 10, 11, 12], dtype=dtypes.float32)
+ expected = np.array([1, 12, 1, 11, 10, 1, 1, 13])
+ state_ops.scatter_nd_add(v, indices, updates)
+ self.assertAllClose(expected, v.numpy())
+
def testScatterUpdateCast(self):
with context.eager_mode():
v = resource_variable_ops.ResourceVariable([1.0, 2.0], name="update")
diff --git a/tensorflow/python/kernel_tests/tensor_array_ops_test.py b/tensorflow/python/kernel_tests/tensor_array_ops_test.py
index c0b36f143d..ea06357804 100644
--- a/tensorflow/python/kernel_tests/tensor_array_ops_test.py
+++ b/tensorflow/python/kernel_tests/tensor_array_ops_test.py
@@ -26,11 +26,13 @@ from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import gen_data_flow_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import init_ops
@@ -549,6 +551,58 @@ class TensorArrayTest(test.TestCase):
dtypes.complex64, dtypes.complex128):
self._testTensorArrayWriteGradientAddMultipleAdds(dtype)
+ def testTensorArrayGradWithShapeKnownElementShape(self):
+ with self.test_session(use_gpu=True) as sess:
+ ta = tensor_array_ops.TensorArray(
+ size=3,
+ dtype=dtypes.float32,
+ element_shape=tensor_shape.TensorShape([2, 3]))
+ handle, flow = data_flow_ops.tensor_array_grad_with_shape(
+ handle=ta.handle,
+ flow_in=ta.flow,
+ shape_to_prepend=tensor_shape.TensorShape([4, 5]),
+ source="source")
+ ta_grad = tensor_array_ops.TensorArray(
+ dtypes.float32, handle=handle, flow=flow)
+ value = array_ops.placeholder(dtypes.float32)
+ ta_grad = ta_grad.write(0, value)
+ read_value = ta_grad.read(0)
+
+ # Make sure shape inference worked.
+ self.assertAllEqual([None, None, 2, 3], read_value.shape.as_list())
+ # Writing with wrong shape should not work.
+ with self.assertRaisesRegexp(errors.InvalidArgumentError,
+ "Could not write to TensorArray"):
+ fed_value = np.random.random([2, 3])
+ sess.run(read_value, feed_dict={value: fed_value})
+ # Writing with correct shape should work.
+ fed_value = np.random.random([4, 5, 2, 3])
+ self.assertAllClose(fed_value,
+ sess.run(read_value, feed_dict={value: fed_value}))
+
+ def testTensorArrayGradWithShapeUnknownElementShape(self):
+ with self.test_session(use_gpu=True) as sess:
+ ta = tensor_array_ops.TensorArray(
+ size=3, dtype=dtypes.float32,
+ element_shape=None) # Note that element_shape is unknown
+ handle, flow = data_flow_ops.tensor_array_grad_with_shape(
+ handle=ta.handle,
+ flow_in=ta.flow,
+ shape_to_prepend=tensor_shape.TensorShape([4, 5]),
+ source="source")
+ ta_grad = tensor_array_ops.TensorArray(
+ dtypes.float32, handle=handle, flow=flow)
+ value = array_ops.placeholder(dtypes.float32)
+ ta_grad = ta_grad.write(0, value)
+ read_value = ta_grad.read(0)
+
+ # Make sure shape inference worked.
+ self.assertIsNone(read_value.shape.ndims)
+ # Write with some shape and check read value.
+ fed_value = np.random.random([4, 5, 7])
+ self.assertAllClose(fed_value,
+ sess.run(read_value, feed_dict={value: fed_value}))
+
@test_util.run_in_graph_and_eager_modes()
def testMultiTensorArray(self):
with self.test_session(use_gpu=True):
diff --git a/tensorflow/python/lib/core/ndarray_tensor.cc b/tensorflow/python/lib/core/ndarray_tensor.cc
index 9df38d464c..ec1ba7b8f7 100644
--- a/tensorflow/python/lib/core/ndarray_tensor.cc
+++ b/tensorflow/python/lib/core/ndarray_tensor.cc
@@ -312,6 +312,40 @@ Status GetPyArrayDescrForTensor(const TF_Tensor* tensor,
return Status::OK();
}
+
+inline void FastMemcpy(void* dst, const void* src, size_t size) {
+ // clang-format off
+ switch (size) {
+ // Most compilers will generate inline code for fixed sizes,
+ // which is significantly faster for small copies.
+ case 1: memcpy(dst, src, 1); break;
+ case 2: memcpy(dst, src, 2); break;
+ case 3: memcpy(dst, src, 3); break;
+ case 4: memcpy(dst, src, 4); break;
+ case 5: memcpy(dst, src, 5); break;
+ case 6: memcpy(dst, src, 6); break;
+ case 7: memcpy(dst, src, 7); break;
+ case 8: memcpy(dst, src, 8); break;
+ case 9: memcpy(dst, src, 9); break;
+ case 10: memcpy(dst, src, 10); break;
+ case 11: memcpy(dst, src, 11); break;
+ case 12: memcpy(dst, src, 12); break;
+ case 13: memcpy(dst, src, 13); break;
+ case 14: memcpy(dst, src, 14); break;
+ case 15: memcpy(dst, src, 15); break;
+ case 16: memcpy(dst, src, 16); break;
+#if defined(PLATFORM_GOOGLE) || defined(PLATFORM_POSIX) && \
+ !defined(IS_MOBILE_PLATFORM)
+ // On Linux, memmove appears to be faster than memcpy for
+ // large sizes, strangely enough.
+ default: memmove(dst, src, size); break;
+#else
+ default: memcpy(dst, src, size); break;
+#endif
+ }
+ // clang-format on
+}
+
} // namespace
// Converts the given TF_Tensor to a numpy ndarray.
@@ -362,8 +396,8 @@ Status TF_TensorToPyArray(Safe_TF_TensorPtr tensor, PyObject** out_ndarray) {
" bytes but TF_Tensor was ",
TF_TensorByteSize(tensor.get()), " bytes");
} else {
- memcpy(PyArray_DATA(py_array), TF_TensorData(tensor.get()),
- PyArray_NBYTES(py_array));
+ FastMemcpy(PyArray_DATA(py_array), TF_TensorData(tensor.get()),
+ PyArray_NBYTES(py_array));
}
// PyArray_Return turns rank 0 arrays into numpy scalars
@@ -377,7 +411,7 @@ Status PyArrayToTF_Tensor(PyObject* ndarray, Safe_TF_TensorPtr* out_tensor) {
// Make sure we dereference this array object in case of error, etc.
Safe_PyObjectPtr array_safe(make_safe(
- PyArray_FromAny(ndarray, nullptr, 0, 0, NPY_ARRAY_CARRAY, nullptr)));
+ PyArray_FromAny(ndarray, nullptr, 0, 0, NPY_ARRAY_CARRAY_RO, nullptr)));
if (!array_safe) return errors::InvalidArgument("Not a ndarray.");
PyArrayObject* array = reinterpret_cast<PyArrayObject*>(array_safe.get());
diff --git a/tensorflow/python/ops/custom_gradient.py b/tensorflow/python/ops/custom_gradient.py
index d934f27cb9..ca24f11054 100644
--- a/tensorflow/python/ops/custom_gradient.py
+++ b/tensorflow/python/ops/custom_gradient.py
@@ -89,7 +89,7 @@ def custom_gradient(f):
operations in `f` to `x`.
- `grad_fn` is a function with the signature `g(*grad_ys)` which returns
a list of `Tensor`s - the derivatives of `Tensor`s in `y` with respect
- to the `Tensor`s in `x. `grad_ys` is a `Tensor` or sequence of
+ to the `Tensor`s in `x`. `grad_ys` is a `Tensor` or sequence of
`Tensor`s the same size as `y` holding the initial value gradients for
each `Tensor` in `y`. If `f` uses `Variable`s (that are not part of the
inputs), i.e. through `get_variable`, then `grad_fn` should have
diff --git a/tensorflow/python/ops/data_flow_ops.py b/tensorflow/python/ops/data_flow_ops.py
index 62c5adc385..abf597ca55 100644
--- a/tensorflow/python/ops/data_flow_ops.py
+++ b/tensorflow/python/ops/data_flow_ops.py
@@ -35,6 +35,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_data_flow_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_data_flow_ops import *
@@ -129,11 +130,6 @@ class QueueBase(object):
@{tf.RandomShuffleQueue} for concrete
implementations of this class, and instructions on how to create
them.
-
- @compatibility(eager)
- Queues are not compatible with eager execution. Instead, please
- use `tf.data` to get data into your model.
- @end_compatibility
"""
def __init__(self, dtypes, shapes, names, queue_ref):
@@ -157,12 +153,7 @@ class QueueBase(object):
Raises:
ValueError: If one of the arguments is invalid.
- RuntimeError: If eager execution is enabled.
"""
- if context.executing_eagerly():
- raise RuntimeError(
- "Queues are not supported when eager execution is enabled. "
- "Instead, please use tf.data to get data into your model.")
self._dtypes = dtypes
if shapes is not None:
if len(shapes) != len(dtypes):
@@ -179,6 +170,8 @@ class QueueBase(object):
self._queue_ref = queue_ref
if context.executing_eagerly():
self._name = context.context().scope_name
+ self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
+ queue_ref, None)
else:
self._name = self._queue_ref.op.name.split("/")[-1]
@@ -605,6 +598,11 @@ class QueueBase(object):
else:
return gen_data_flow_ops.queue_size(self._queue_ref, name=name)
+def _shared_name(shared_name):
+ if context.executing_eagerly():
+ return str(ops.uid())
+ return shared_name
+
@tf_export("RandomShuffleQueue")
class RandomShuffleQueue(QueueBase):
@@ -612,11 +610,6 @@ class RandomShuffleQueue(QueueBase):
See @{tf.QueueBase} for a description of the methods on
this class.
-
- @compatibility(eager)
- Queues are not compatible with eager execution. Instead, please
- use `tf.data` to get data into your model.
- @end_compatibility
"""
def __init__(self,
@@ -690,7 +683,7 @@ class RandomShuffleQueue(QueueBase):
min_after_dequeue=min_after_dequeue,
seed=seed1,
seed2=seed2,
- shared_name=shared_name,
+ shared_name=_shared_name(shared_name),
name=name)
super(RandomShuffleQueue, self).__init__(dtypes, shapes, names, queue_ref)
@@ -702,11 +695,6 @@ class FIFOQueue(QueueBase):
See @{tf.QueueBase} for a description of the methods on
this class.
-
- @compatibility(eager)
- Queues are not compatible with eager execution. Instead, please
- use `tf.data` to get data into your model.
- @end_compatibility
"""
def __init__(self,
@@ -752,7 +740,7 @@ class FIFOQueue(QueueBase):
component_types=dtypes,
shapes=shapes,
capacity=capacity,
- shared_name=shared_name,
+ shared_name=_shared_name(shared_name),
name=name)
super(FIFOQueue, self).__init__(dtypes, shapes, names, queue_ref)
@@ -767,11 +755,6 @@ class PaddingFIFOQueue(QueueBase):
See @{tf.QueueBase} for a description of the methods on
this class.
-
- @compatibility(eager)
- Queues are not compatible with eager execution. Instead, please
- use `tf.data` to get data into your model.
- @end_compatibility
"""
def __init__(self,
@@ -831,7 +814,7 @@ class PaddingFIFOQueue(QueueBase):
component_types=dtypes,
shapes=shapes,
capacity=capacity,
- shared_name=shared_name,
+ shared_name=_shared_name(shared_name),
name=name)
super(PaddingFIFOQueue, self).__init__(dtypes, shapes, names, queue_ref)
@@ -843,11 +826,6 @@ class PriorityQueue(QueueBase):
See @{tf.QueueBase} for a description of the methods on
this class.
-
- @compatibility(eager)
- Queues are not compatible with eager execution. Instead, please
- use `tf.data` to get data into your model.
- @end_compatibility
"""
def __init__(self,
@@ -899,7 +877,7 @@ class PriorityQueue(QueueBase):
component_types=types,
shapes=shapes,
capacity=capacity,
- shared_name=shared_name,
+ shared_name=_shared_name(shared_name),
name=name)
priority_dtypes = [_dtypes.int64] + types
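The removed `@compatibility(eager)` blocks and the new `_shared_name` helper mean the standard queues can now be used under eager execution, with each instance getting a unique shared name so otherwise-identical queues do not alias each other. A minimal sketch of the behaviour the new FIFO tests check (assuming eager execution is enabled at program start):

```python
import tensorflow as tf

tf.enable_eager_execution()

q = tf.FIFOQueue(10, [tf.int32], shapes=[()])
q.enqueue_many([[1, 2, 3]])
print([int(q.dequeue()) for _ in range(3)])  # [1, 2, 3]

# A second queue built with identical arguments gets its own shared_name,
# so it does not share state with the first one.
q2 = tf.FIFOQueue(10, [tf.int32], shapes=[()])
q2.enqueue(4)
print(int(q2.dequeue()))  # 4
```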
diff --git a/tensorflow/python/ops/gradients_test.py b/tensorflow/python/ops/gradients_test.py
index 6891501ae1..d81c756f1c 100644
--- a/tensorflow/python/ops/gradients_test.py
+++ b/tensorflow/python/ops/gradients_test.py
@@ -83,7 +83,6 @@ def _OpsBetween(to_ops, from_ops):
return between_ops
-@test_util.with_c_api
class GradientsTest(test_util.TensorFlowTestCase):
def _OpNames(self, op_list):
diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py
index 4a32f2351b..95d05cd4d1 100644
--- a/tensorflow/python/ops/image_ops_impl.py
+++ b/tensorflow/python/ops/image_ops_impl.py
@@ -1556,13 +1556,13 @@ def is_jpeg(contents, name=None):
@tf_export('image.decode_image')
-def decode_image(contents, channels=None, name=None):
+def decode_image(contents, channels=None, dtype=dtypes.uint8, name=None):
"""Convenience function for `decode_bmp`, `decode_gif`, `decode_jpeg`,
and `decode_png`.
Detects whether an image is a BMP, GIF, JPEG, or PNG, and performs the
- appropriate operation to convert the input bytes `string` into a `Tensor` of
- type `uint8`.
+ appropriate operation to convert the input bytes `string` into a `Tensor`
+ of type `dtype`.
Note: `decode_gif` returns a 4-D array `[num_frames, height, width, 3]`, as
opposed to `decode_bmp`, `decode_jpeg` and `decode_png`, which return 3-D
@@ -1574,10 +1574,11 @@ def decode_image(contents, channels=None, name=None):
contents: 0-D `string`. The encoded image bytes.
channels: An optional `int`. Defaults to `0`. Number of color channels for
the decoded image.
+ dtype: The desired DType of the returned `Tensor`.
name: A name for the operation (optional)
Returns:
- `Tensor` with type `uint8` with shape `[height, width, num_channels]` for
+ `Tensor` with type `dtype` and shape `[height, width, num_channels]` for
BMP, JPEG, and PNG images and shape `[num_frames, height, width, 3]` for
GIF images.
@@ -1601,7 +1602,7 @@ def decode_image(contents, channels=None, name=None):
channels_msg = 'Channels must be in (None, 0, 3) when decoding BMP images'
assert_channels = control_flow_ops.Assert(good_channels, [channels_msg])
with ops.control_dependencies([assert_decode, assert_channels]):
- return gen_image_ops.decode_bmp(contents)
+ return convert_image_dtype(gen_image_ops.decode_bmp(contents), dtype)
def _gif():
# Create assert to make sure that channels is not set to 1
@@ -1614,7 +1615,7 @@ def decode_image(contents, channels=None, name=None):
channels_msg = 'Channels must be in (None, 0, 3) when decoding GIF images'
assert_channels = control_flow_ops.Assert(good_channels, [channels_msg])
with ops.control_dependencies([assert_channels]):
- return gen_image_ops.decode_gif(contents)
+ return convert_image_dtype(gen_image_ops.decode_gif(contents), dtype)
def check_gif():
# Create assert op to check that bytes are GIF decodable
@@ -1623,7 +1624,11 @@ def decode_image(contents, channels=None, name=None):
def _png():
"""Decodes a PNG image."""
- return gen_image_ops.decode_png(contents, channels)
+ return convert_image_dtype(
+ gen_image_ops.decode_png(contents, channels,
+ dtype=dtypes.uint8
+ if dtype == dtypes.uint8
+ else dtypes.uint16), dtype)
def check_png():
"""Checks if an image is PNG."""
@@ -1639,7 +1644,8 @@ def decode_image(contents, channels=None, name=None):
'images')
assert_channels = control_flow_ops.Assert(good_channels, [channels_msg])
with ops.control_dependencies([assert_channels]):
- return gen_image_ops.decode_jpeg(contents, channels)
+ return convert_image_dtype(
+ gen_image_ops.decode_jpeg(contents, channels), dtype)
# Decode normal JPEG images (start with \xff\xd8\xff\xe0)
# as well as JPEG images with EXIF data (start with \xff\xd8\xff\xe1).
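Each decode branch above now runs through `convert_image_dtype`, so callers can request a float or 16-bit image directly from `decode_image`. A minimal sketch (the file path is a placeholder):

```python
import tensorflow as tf

contents = tf.read_file('photo.png')  # placeholder path; PNG/JPEG/GIF/BMP all work
image_u8 = tf.image.decode_image(contents)                     # uint8 in [0, 255]
image_f32 = tf.image.decode_image(contents, dtype=tf.float32)  # float32 in [0, 1]

with tf.Session() as sess:
  u8, f32 = sess.run([image_u8, image_f32])
  print(u8.dtype, f32.dtype, f32.max())  # uint8 float32 <= 1.0
```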
diff --git a/tensorflow/python/ops/image_ops_test.py b/tensorflow/python/ops/image_ops_test.py
index d50ff3fb60..ae45037c17 100644
--- a/tensorflow/python/ops/image_ops_test.py
+++ b/tensorflow/python/ops/image_ops_test.py
@@ -3888,5 +3888,88 @@ class SobelEdgesTest(test_util.TensorFlowTestCase):
self.assertAllClose(expected_batch, actual_sobel)
+class DecodeImageTest(test_util.TensorFlowTestCase):
+
+ def testJpegUint16(self):
+ with self.test_session(use_gpu=True) as sess:
+ base = "tensorflow/core/lib/jpeg/testdata"
+ jpeg0 = io_ops.read_file(os.path.join(base, "jpeg_merge_test1.jpg"))
+ image0 = image_ops.decode_image(jpeg0, dtype=dtypes.uint16)
+ image1 = image_ops.convert_image_dtype(image_ops.decode_jpeg(jpeg0),
+ dtypes.uint16)
+ image0, image1 = sess.run([image0, image1])
+ self.assertAllEqual(image0, image1)
+
+ def testPngUint16(self):
+ with self.test_session(use_gpu=True) as sess:
+ base = "tensorflow/core/lib/png/testdata"
+ png0 = io_ops.read_file(os.path.join(base, "lena_rgba.png"))
+ image0 = image_ops.decode_image(png0, dtype=dtypes.uint16)
+ image1 = image_ops.convert_image_dtype(
+ image_ops.decode_png(png0, dtype=dtypes.uint16), dtypes.uint16)
+ image0, image1 = sess.run([image0, image1])
+ self.assertAllEqual(image0, image1)
+
+ def testGifUint16(self):
+ with self.test_session(use_gpu=True) as sess:
+ base = "tensorflow/core/lib/gif/testdata"
+ gif0 = io_ops.read_file(os.path.join(base, "scan.gif"))
+ image0 = image_ops.decode_image(gif0, dtype=dtypes.uint16)
+ image1 = image_ops.convert_image_dtype(image_ops.decode_gif(gif0),
+ dtypes.uint16)
+ image0, image1 = sess.run([image0, image1])
+ self.assertAllEqual(image0, image1)
+
+ def testBmpUint16(self):
+ with self.test_session(use_gpu=True) as sess:
+ base = "tensorflow/core/lib/bmp/testdata"
+ bmp0 = io_ops.read_file(os.path.join(base, "lena.bmp"))
+ image0 = image_ops.decode_image(bmp0, dtype=dtypes.uint16)
+ image1 = image_ops.convert_image_dtype(image_ops.decode_bmp(bmp0),
+ dtypes.uint16)
+ image0, image1 = sess.run([image0, image1])
+ self.assertAllEqual(image0, image1)
+
+ def testJpegFloat32(self):
+ with self.test_session(use_gpu=True) as sess:
+ base = "tensorflow/core/lib/jpeg/testdata"
+ jpeg0 = io_ops.read_file(os.path.join(base, "jpeg_merge_test1.jpg"))
+ image0 = image_ops.decode_image(jpeg0, dtype=dtypes.float32)
+ image1 = image_ops.convert_image_dtype(image_ops.decode_jpeg(jpeg0),
+ dtypes.float32)
+ image0, image1 = sess.run([image0, image1])
+ self.assertAllEqual(image0, image1)
+
+ def testPngFloat32(self):
+ with self.test_session(use_gpu=True) as sess:
+ base = "tensorflow/core/lib/png/testdata"
+ png0 = io_ops.read_file(os.path.join(base, "lena_rgba.png"))
+ image0 = image_ops.decode_image(png0, dtype=dtypes.float32)
+ image1 = image_ops.convert_image_dtype(
+ image_ops.decode_png(png0, dtype=dtypes.uint16), dtypes.float32)
+ image0, image1 = sess.run([image0, image1])
+ self.assertAllEqual(image0, image1)
+
+ def testGifFloat32(self):
+ with self.test_session(use_gpu=True) as sess:
+ base = "tensorflow/core/lib/gif/testdata"
+ gif0 = io_ops.read_file(os.path.join(base, "scan.gif"))
+ image0 = image_ops.decode_image(gif0, dtype=dtypes.float32)
+ image1 = image_ops.convert_image_dtype(image_ops.decode_gif(gif0),
+ dtypes.float32)
+ image0, image1 = sess.run([image0, image1])
+ self.assertAllEqual(image0, image1)
+
+ def testBmpFloat32(self):
+ with self.test_session(use_gpu=True) as sess:
+ base = "tensorflow/core/lib/bmp/testdata"
+ bmp0 = io_ops.read_file(os.path.join(base, "lena.bmp"))
+ image0 = image_ops.decode_image(bmp0, dtype=dtypes.float32)
+ image1 = image_ops.convert_image_dtype(image_ops.decode_bmp(bmp0),
+ dtypes.float32)
+ image0, image1 = sess.run([image0, image1])
+ self.assertAllEqual(image0, image1)
+
+
if __name__ == "__main__":
googletest.main()
diff --git a/tensorflow/python/ops/linalg_grad.py b/tensorflow/python/ops/linalg_grad.py
index 3cbbf3412a..b6b98d5c86 100644
--- a/tensorflow/python/ops/linalg_grad.py
+++ b/tensorflow/python/ops/linalg_grad.py
@@ -55,6 +55,17 @@ def _MatrixDeterminantGrad(op, grad):
return multipliers * a_adj_inv
+@ops.RegisterGradient("LogMatrixDeterminant")
+def _LogMatrixDeterminantGrad(op, _, grad_b):
+ """Gradient for LogMatrixDeterminant."""
+ a = op.inputs[0]
+ c = op.outputs[1]
+ a_adj_inv = linalg_ops.matrix_inverse(a, adjoint=True)
+ multipliers = array_ops.reshape(
+ grad_b, array_ops.concat([array_ops.shape(c), [1, 1]], 0))
+ return multipliers * a_adj_inv
+
+
@ops.RegisterGradient("Cholesky")
def _CholeskyGrad(op, grad):
"""Gradient for Cholesky."""
diff --git a/tensorflow/python/ops/lookup_ops.py b/tensorflow/python/ops/lookup_ops.py
index 0e547689cc..fb51fbc626 100644
--- a/tensorflow/python/ops/lookup_ops.py
+++ b/tensorflow/python/ops/lookup_ops.py
@@ -366,6 +366,10 @@ class KeyValueTensorInitializer(TableInitializerBase):
with ops.name_scope(
self._name, values=(table.table_ref, self._keys,
self._values)) as scope:
+ if context.executing_eagerly():
+ # Ensure a unique name when eager execution is enabled to avoid spurious
+ # sharing issues.
+ scope += str(ops.uid())
init_op = gen_lookup_ops.initialize_table_v2(
table.table_ref, self._keys, self._values, name=scope)
ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op)
@@ -1108,6 +1112,10 @@ def index_table_from_tensor(vocabulary_list,
shared_name = ""
with ops.name_scope(None, "hash_table") as hash_table_scope:
+ if context.executing_eagerly():
+ # Ensure a unique name when eager execution is enabled to avoid spurious
+ # sharing issues.
+ shared_name += str(ops.uid())
table_keys = math_ops.to_int64(keys) if keys.dtype.is_integer else keys
init = KeyValueTensorInitializer(
table_keys,
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index b7e3de7e85..a141f1e2e0 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -370,7 +370,7 @@ def erf(x, name=None):
"""Computes the Gauss error function of `x` element-wise.
Args:
- x: A `Tensor` of `SparseTensor`. Must be one of the following types: `half`,
+ x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
`float32`, `float64`.
name: A name for the operation (optional).
@@ -2225,8 +2225,8 @@ def sigmoid(x, name=None):
Returns:
A Tensor with the same type as `x`.
- @compatibility(numpy)
- Equivalent to np.scipy.special.expit
+ @compatibility(scipy)
+ Equivalent to scipy.special.expit
@end_compatibility
"""
with ops.name_scope(name, "Sigmoid", [x]) as name:
diff --git a/tensorflow/python/ops/script_ops.py b/tensorflow/python/ops/script_ops.py
index 16c73213d5..cc23d0d133 100644
--- a/tensorflow/python/ops/script_ops.py
+++ b/tensorflow/python/ops/script_ops.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-
"""Script Language Operators. See the @{$python/script_ops} guide."""
# pylint: disable=g-bad-name
@@ -30,30 +29,54 @@ import numpy as np
import six
from tensorflow.python import pywrap_tensorflow
+from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
+from tensorflow.python.framework import constant_op
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_script_ops
from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export
+# Map from EagerPyFunc token to tuple (tape, eager args, eager outputs);
+# used for differentiation.
+tape_cache = {}
+
class EagerFunc(object):
"""A wrapper for a function owned by an EagerPyFunc."""
- def __init__(self, func, Tout):
+ def __init__(self, func, Tout, is_grad_func):
"""Constructs an EagerFunc.
Args:
func: The function to wrap.
Tout: A list of datatypes for the output; an empty list if the output is
None.
+ is_grad_func: Whether this EagerFunc is the gradient of another
+ EagerPyFunc.
"""
self._func = func
self._out_dtypes = Tout
+ self._is_grad_func = is_grad_func
def _convert(self, value, dtype):
+ """Converts `value` to a tensor of type `dtype`, with error checking.
+
+ Args:
+ value: The tensor to convert.
+ dtype: The desired dtype.
+
+ Returns:
+ A tensor of type `dtype`, or a zeros tensor if value is None and
+      this function is in fact a gradient function.
+
+ Raises:
+ RuntimeError: if `value` is a variable.
+ """
+
if isinstance(value, resource_variable_ops.ResourceVariable):
raise RuntimeError(
"Attempting to return a variable from an eagerly executed py_func. "
@@ -61,22 +84,40 @@ class EagerFunc(object):
"be returned; to return the value of a variable, make sure to obtain "
"the Tensor backing it by calling `.read_value()` on the variable in "
"question: %s" % value)
+ if value is None and self._is_grad_func:
+ # Gradient functions may legitimately return a list that contains
+      # both Tensors and Python Nones. Unfortunately this breaks the
+ # OpKernel, so for now we replace None objects with zeros, which is
+ # mathematically correct but will prevent short-circuiting gradient
+ # computations.
+ #
+ # TODO(akshayka): Make it possible to return a list of both Tensors and
+ # Nones from an EagerPyFunc.
+ return constant_op.constant(0.0, dtype=dtype)
return ops.convert_to_tensor(value, dtype=dtype)
- def __call__(self, on_gpu, args):
+ def __call__(self, on_gpu, token, args):
"""Passes `args` to `self._func`, which is executed eagerly."""
+
with context.eager_mode():
- ret = self._func(*args)
- maybe_copy_to_gpu = lambda x: x if not on_gpu else x.gpu()
- if isinstance(ret, (tuple, list)):
- return [
- maybe_copy_to_gpu(self._convert(x, dtype=dtype))
- for (x, dtype) in zip(ret, self._out_dtypes)
- ]
- elif ret is None:
- return ret
- else:
- return maybe_copy_to_gpu(self._convert(ret, dtype=self._out_dtypes[0]))
+ with backprop.GradientTape() as tape:
+ for tensor in args:
+ tape.watch(tensor)
+ ret = self._func(*args)
+ # NB: The tape needs to watch copies across devices.
+ maybe_copy_to_gpu = lambda x: x if not on_gpu else x.gpu()
+ if isinstance(ret, (tuple, list)):
+ outputs = [
+ maybe_copy_to_gpu(self._convert(x, dtype=dtype))
+ for (x, dtype) in zip(ret, self._out_dtypes)
+ ]
+ elif ret is None:
+ outputs = None
+ else:
+ outputs = maybe_copy_to_gpu(
+ self._convert(ret, dtype=self._out_dtypes[0]))
+ tape_cache[compat.as_bytes(token)] = (tape, args, outputs)
+ return outputs
class FuncRegistry(object):
@@ -153,7 +194,14 @@ class FuncRegistry(object):
if func is None:
raise ValueError("callback %s is not found" % token)
if isinstance(func, EagerFunc):
- return func(on_gpu, args)
+ # NB: Different invocations of the same py_func will share the same
+ # token, and the entries they stash in the tape_cache will collide.
+ # In practice, when executing a graph, this should only happen if
+ # the py_func is in a while_loop whose iterations are run in parallel
+ # or if the graph is being driven by concurrent session.run() calls.
+ #
+ # TODO(akshayka): Key the tape cache in a thread-safe way.
+ return func(on_gpu, token, args)
else:
ret = func(*args)
# Strings seem to lead to a memory leak here if they're not wrapped in a
@@ -184,7 +232,8 @@ _py_funcs = FuncRegistry()
pywrap_tensorflow.InitializePyTrampoline(_py_funcs)
-def _internal_py_func(func, inp, Tout, stateful=None, eager=False, name=None):
+def _internal_py_func(func, inp, Tout, stateful=None, eager=False,
+ is_grad_func=False, name=None):
"""See documentation for py_func and eager_py_func."""
is_list_or_tuple = False
@@ -194,7 +243,7 @@ def _internal_py_func(func, inp, Tout, stateful=None, eager=False, name=None):
Tout = [Tout]
if eager:
- func = EagerFunc(func, Tout)
+ func = EagerFunc(func, Tout, is_grad_func)
token = _py_funcs.insert(func)
# We tie the registered function's lifetime with the current default graph,
@@ -231,34 +280,55 @@ def _internal_py_func(func, inp, Tout, stateful=None, eager=False, name=None):
return result if is_list_or_tuple else result[0]
+# TODO(akshayka): Implement higher-order derivatives.
+@ops.RegisterGradient("EagerPyFunc")
+def _EagerPyFuncGrad(op, dy):
+ """Computes the gradient of an EagerPyFunc."""
+
+ token = op.get_attr("token")
+
+ def eagerly_executed_grad(dy):
+ tape, eager_inputs, eager_outputs = tape_cache.pop(compat.as_bytes(token))
+ return tape.gradient(eager_outputs, eager_inputs, output_gradients=dy)
+
+ with ops.control_dependencies(op.outputs):
+ return _internal_py_func(
+ func=eagerly_executed_grad,
+ inp=[dy] if isinstance(dy, ops.Tensor) else dy,
+ Tout=[tensor.dtype for tensor in op.inputs],
+ eager=True, is_grad_func=True)
+
+
def eager_py_func(func, inp, Tout, name=None):
"""Wraps a python function into a TensorFlow op that executes it eagerly.
This function allows expressing computations in a TensorFlow graph as
Python functions. In particular, it wraps a Python function `func`
- in a TensorFlow operation that executes it with eager exeuction enabled. As a
- consequence, `tf.contrib.eager.py_func` makes it possible to express control
- flow using Python constructs (`if`, `while`, `for`, etc.), instead of
- TensorFlow control flow constructs (@{tf.cond}, @{tf.while_loop}). For
- example, you might use `tf.contrib.eager.py_func` to implement the log huber
- function:
+ in a once-differentiable TensorFlow operation that executes it with eager
+  execution enabled. As a consequence, `tf.contrib.eager.py_func` makes it
+ possible to express control flow using Python constructs (`if`, `while`,
+ `for`, etc.), instead of TensorFlow control flow constructs (@{tf.cond},
+ @{tf.while_loop}). For example, you might use `tf.contrib.eager.py_func` to
+ implement the log huber function:
```python
def log_huber(x, m):
if tf.abs(x) <= m:
- return x ** 2
+ return x**2
else:
- return m ** 2 * (1 - 2 * tf.log(m) + tf.log(x ** 2))
+ return m**2 * (1 - 2 * tf.log(m) + tf.log(x**2))
x = tf.placeholder(tf.float32)
m = tf.placeholder(tf.float32)
y = tf.contrib.eager.py_func(func=log_huber, inp=[x, m], Tout=tf.float32)
+ dy_dx = tf.gradients(y, x)[0]
with tf.Session() as sess:
# The session executes `log_huber` eagerly. Given the feed values below,
- # it will take the second branch, so `output` evaluates to 7.24372.
- output = sess.run(y, feed_dict={x: 3.0, m: 2.0})
+ # it will take the first branch, so `y` evaluates to 1.0 and
+ # `dy_dx` evaluates to 2.0.
+ y, dy_dx = sess.run([y, dy_dx], feed_dict={x: 1.0, m: 2.0})
```
You can also use `tf.contrib.eager.py_func` to debug your models at runtime
@@ -277,10 +347,6 @@ def eager_py_func(func, inp, Tout, name=None):
that take Tensors as inputs, execute TensorFlow operations in their bodies,
and return Tensors as outputs.
- `tf.contrib.eager.py_func` is not differentiable, though a gradient may be
- implemented in the future; if you would like to differentiate through it,
- please file an issue on Github.
-
Like @{tf.py_func}, `tf.contrib.eager.py_func` has the following limitations
with respect to serialization and distribution:
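The same tape cache also serves eager execution: the `GradientTape` recorded inside `EagerFunc.__call__` is what `_EagerPyFuncGrad` replays. A minimal sketch mirroring the new tests (function and values are illustrative):

```python
import tensorflow as tf
import tensorflow.contrib.eager as tfe

tf.enable_eager_execution()

def square(x):
  return x ** 2

x = tf.constant(3.0)
with tfe.GradientTape() as tape:
  tape.watch(x)
  y = tfe.py_func(square, inp=[x], Tout=tf.float32)
dy_dx = tape.gradient(y, x)
print(float(dy_dx))  # 6.0
```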
diff --git a/tensorflow/python/ops/special_math_ops_test.py b/tensorflow/python/ops/special_math_ops_test.py
index d7c3a7e8dc..6118b54293 100644
--- a/tensorflow/python/ops/special_math_ops_test.py
+++ b/tensorflow/python/ops/special_math_ops_test.py
@@ -285,8 +285,8 @@ class EinsumTest(test.TestCase):
correct_value = np.einsum(axes, *input_vals)
err = np.abs(correct_value - output_value).max()
- print(axes, err)
- assert err < 1e-8
+ # print(axes, err)
+ self.assertLess(err, 1e-8)
def test_input_is_placeholder(self):
with ops.Graph().as_default():
@@ -298,8 +298,7 @@ class EinsumTest(test.TestCase):
m0: [[1, 2, 3]],
m1: [[2], [1], [1]],
}
- np.testing.assert_almost_equal([[7]], sess.run(
- out, feed_dict=feed_dict))
+ self.assertAllClose([[7]], sess.run(out, feed_dict=feed_dict))
with ops.Graph().as_default():
m0 = array_ops.placeholder(dtypes.int32, shape=(None, 3))
@@ -310,7 +309,7 @@ class EinsumTest(test.TestCase):
m0: [[1, 2, 3]],
m1: [2, 1, 1],
}
- np.testing.assert_almost_equal([7], sess.run(out, feed_dict=feed_dict))
+ self.assertAllClose([7], sess.run(out, feed_dict=feed_dict))
# Tests for placeholders which have two or more None values
with ops.Graph().as_default():
@@ -322,8 +321,7 @@ class EinsumTest(test.TestCase):
m0: [[[1, 2]]],
m1: [[3], [2]],
}
- np.testing.assert_almost_equal([[[7]]],
- sess.run(out, feed_dict=feed_dict))
+ self.assertAllClose([[[7]]], sess.run(out, feed_dict=feed_dict))
with ops.Graph().as_default():
m0 = array_ops.placeholder(dtypes.int32, shape=(2, 1))
@@ -334,8 +332,7 @@ class EinsumTest(test.TestCase):
m0: [[3], [2]],
m1: [[[1, 2]]],
}
- np.testing.assert_almost_equal([[[7]]],
- sess.run(out, feed_dict=feed_dict))
+ self.assertAllClose([[[7]]], sess.run(out, feed_dict=feed_dict))
with ops.Graph().as_default():
m0 = array_ops.placeholder(dtypes.int32, shape=(None, None, 2))
@@ -346,8 +343,7 @@ class EinsumTest(test.TestCase):
m0: [[[1, 2]]],
m1: [3, 2],
}
- np.testing.assert_almost_equal([[7]], sess.run(
- out, feed_dict=feed_dict))
+ self.assertAllClose([[7]], sess.run(out, feed_dict=feed_dict))
with ops.Graph().as_default():
m0 = array_ops.placeholder(dtypes.int32, shape=(None, 2, None, 2))
@@ -358,8 +354,7 @@ class EinsumTest(test.TestCase):
m0: [[[[1, 2]], [[2, 1]]]],
m1: [[3, 2]],
}
- np.testing.assert_almost_equal([[[7, 8]]],
- sess.run(out, feed_dict=feed_dict))
+ self.assertAllClose([[[7, 8]]], sess.run(out, feed_dict=feed_dict))
if __name__ == '__main__':
diff --git a/tensorflow/python/ops/state_ops.py b/tensorflow/python/ops/state_ops.py
index 94d7458ec8..08b7cda73b 100644
--- a/tensorflow/python/ops/state_ops.py
+++ b/tensorflow/python/ops/state_ops.py
@@ -338,7 +338,6 @@ def scatter_nd_update(ref, indices, updates, use_locking=True, name=None):
Args:
ref: A Variable.
indices: A `Tensor`. Must be one of the following types: `int32`, `int64`.
- A Tensor. Must be one of the following types: int32, int64.
A tensor of indices into ref.
updates: A `Tensor`. Must have the same type as `ref`.
A Tensor. Must have the same type as ref. A tensor of updated
@@ -355,10 +354,9 @@ def scatter_nd_update(ref, indices, updates, use_locking=True, name=None):
if ref.dtype._is_ref_dtype:
return gen_state_ops.scatter_nd_update(
ref, indices, updates, use_locking, name)
- with ops.control_dependencies([gen_state_ops.resource_scatter_nd_update(
- ref.handle, indices, ops.convert_to_tensor(updates, dtype=ref.dtype),
- use_locking, name)]):
- return ref.read_value()
+ return ref._lazy_read(gen_state_ops.resource_scatter_nd_update( # pylint: disable=protected-access
+ ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype),
+ name=name))
@tf_export("scatter_add")
@@ -411,3 +409,67 @@ def scatter_add(ref, indices, updates, use_locking=False, name=None):
return ref._lazy_read(gen_resource_variable_ops.resource_scatter_add( # pylint: disable=protected-access
ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype),
name=name))
+
+
+@tf_export("scatter_nd_add")
+def scatter_nd_add(ref, indices, updates, use_locking=False, name=None):
+ r"""Applies sparse addition to individual values or slices in a Variable.
+
+ `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
+
+ `indices` must be integer tensor, containing indices into `ref`.
+ It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
+
+ The innermost dimension of `indices` (with length `K`) corresponds to
+ indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
+ dimension of `ref`.
+
+ `updates` is `Tensor` of rank `Q-1+P-K` with shape:
+
+ ```
+ [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
+ ```
+
+  For example, say we want to add 4 scattered elements to a rank-1 tensor with
+ 8 elements. In Python, that update would look like this:
+
+ ```python
+ ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
+  indices = tf.constant([[4], [3], [1], [7]])
+ updates = tf.constant([9, 10, 11, 12])
+ add = tf.scatter_nd_add(ref, indices, updates)
+ with tf.Session() as sess:
+    print(sess.run(add))
+ ```
+
+ The resulting update to ref would look like this:
+
+ [1, 13, 3, 14, 14, 6, 7, 20]
+
+ See @{tf.scatter_nd} for more details about how to make updates to
+ slices.
+
+ Args:
+ ref: A mutable `Tensor`. Must be one of the following types: `float32`,
+ `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`,
+ `qint8`, `quint8`, `qint32`, `bfloat16`, `uint16`, `complex128`, `half`,
+ `uint32`, `uint64`. A mutable Tensor. Should be from a Variable node.
+ indices: A `Tensor`. Must be one of the following types: `int32`, `int64`.
+ A tensor of indices into ref.
+ updates: A `Tensor`. Must have the same type as `ref`.
+ A tensor of updated values to add to ref.
+    use_locking: An optional `bool`. Defaults to `False`. If `True`, the
+      addition will be protected by a lock; otherwise the behavior is
+      undefined, but may exhibit less contention.
+ name: A name for the operation (optional).
+
+ Returns:
+ A mutable `Tensor`. Has the same type as `ref`.
+ """
+ if ref.dtype._is_ref_dtype:
+ return gen_state_ops.scatter_nd_add(
+ ref, indices, updates, use_locking, name)
+ return ref._lazy_read(gen_state_ops.resource_scatter_nd_add( # pylint: disable=protected-access
+ ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype),
+ name=name))
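Because the resource branch returns through `_lazy_read`, the new op also composes with eager resource variables, which is exactly what the `testScatterNdAddStateOps` case earlier in this diff exercises. A minimal sketch (assuming a build that exports `tf.scatter_nd_add`):

```python
import tensorflow as tf

tf.enable_eager_execution()

v = tf.Variable([1., 1., 1., 1., 1., 1., 1., 1.])
indices = tf.constant([[4], [3], [1], [7]])
updates = tf.constant([9., 10., 11., 12.])
tf.scatter_nd_add(v, indices, updates)
print(v.numpy())  # [ 1. 12.  1. 11. 10.  1.  1. 13.]
```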
diff --git a/tensorflow/python/ops/tensor_array_grad.py b/tensorflow/python/ops/tensor_array_grad.py
index 1f70d69548..d341349804 100644
--- a/tensorflow/python/ops/tensor_array_grad.py
+++ b/tensorflow/python/ops/tensor_array_grad.py
@@ -34,6 +34,7 @@ ops.NotDifferentiable("TensorArrayCloseV2")
ops.NotDifferentiable("TensorArrayV3")
ops.NotDifferentiable("TensorArrayGradV3")
+ops.NotDifferentiable("TensorArrayGradWithShape")
ops.NotDifferentiable("TensorArraySizeV3")
ops.NotDifferentiable("TensorArrayCloseV3")
diff --git a/tensorflow/python/saved_model/BUILD b/tensorflow/python/saved_model/BUILD
index 2609a5d222..81786fbf43 100644
--- a/tensorflow/python/saved_model/BUILD
+++ b/tensorflow/python/saved_model/BUILD
@@ -149,6 +149,7 @@ py_test(
"//tensorflow/python:saver_test_utils",
"//tensorflow/python:state_ops",
"//tensorflow/python:test_ops",
+ "//tensorflow/python:training",
"//tensorflow/python:util",
"//tensorflow/python:variables",
],
diff --git a/tensorflow/python/saved_model/builder_impl.py b/tensorflow/python/saved_model/builder_impl.py
index 24a13c0f33..e58be804c2 100644
--- a/tensorflow/python/saved_model/builder_impl.py
+++ b/tensorflow/python/saved_model/builder_impl.py
@@ -270,6 +270,18 @@ class SavedModelBuilder(object):
self._add_train_op(train_op)
+ def _maybe_create_saver(self, saver=None):
+ """Creates a sharded saver if one does not already exist."""
+ if not saver:
+ # Initialize a saver to generate a sharded output for all saveables in the
+ # current scope.
+ saver = tf_saver.Saver(
+ variables._all_saveable_objects(), # pylint: disable=protected-access
+ sharded=True,
+ write_version=saver_pb2.SaverDef.V2,
+ allow_empty=True)
+ return saver
+
def add_meta_graph(self,
tags,
signature_def_map=None,
@@ -277,7 +289,8 @@ class SavedModelBuilder(object):
legacy_init_op=None,
clear_devices=False,
main_op=None,
- strip_default_attrs=False):
+ strip_default_attrs=False,
+ saver=None):
# pylint: disable=line-too-long
"""Adds the current meta graph to the SavedModel.
@@ -302,6 +315,9 @@ class SavedModelBuilder(object):
strip_default_attrs: Boolean. If `True`, default-valued attributes will be
removed from the NodeDefs. For a detailed guide, see
[Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
+ saver: An instance of tf.train.Saver that will be used to export the
+ metagraph. If None, a sharded Saver that restores all variables will
+ be used.
Raises:
AssertionError: If the variables for the SavedModel have not been saved
@@ -320,18 +336,11 @@ class SavedModelBuilder(object):
# Add assets and ops
self._add_collections(assets_collection, legacy_init_op, main_op, None)
- # Initialize a saver to generate a sharded output for all saveables in the
- # current scope.
- saver = tf_saver.Saver(
- variables._all_saveable_objects(), # pylint: disable=protected-access
- sharded=True,
- write_version=saver_pb2.SaverDef.V2,
- allow_empty=True)
+ saver = self._maybe_create_saver(saver)
# The graph almost certainly previously contained at least one Saver, and
# possibly several (e.g. one for loading a pretrained embedding, and another
- # for the model weights). However, a *new* Saver was just created that
- # includes all of the variables. Removing the preexisting ones was the
+ # for the model weights). Removing the preexisting ones was the
# motivation for the clear_extraneous_savers option, but it turns out that
# there are edge cases where that option breaks the graph. Until that is
# resolved, we just leave the option set to False for now.
@@ -350,7 +359,8 @@ class SavedModelBuilder(object):
legacy_init_op=None,
clear_devices=False,
main_op=None,
- strip_default_attrs=False):
+ strip_default_attrs=False,
+ saver=None):
# pylint: disable=line-too-long
"""Adds the current meta graph to the SavedModel and saves variables.
@@ -377,6 +387,9 @@ class SavedModelBuilder(object):
strip_default_attrs: Boolean. If `True`, default-valued attributes will be
removed from the NodeDefs. For a detailed guide, see
[Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
+ saver: An instance of tf.train.Saver that will be used to export the
+ metagraph and save variables. If None, a sharded Saver that restores
+ all variables will be used.
"""
# pylint: enable=line-too-long
@@ -403,13 +416,7 @@ class SavedModelBuilder(object):
compat.as_text(variables_dir),
compat.as_text(constants.VARIABLES_FILENAME))
- # Initialize a saver to generate a sharded output for all saveables in the
- # current scope.
- saver = tf_saver.Saver(
- variables._all_saveable_objects(), # pylint: disable=protected-access
- sharded=True,
- write_version=saver_pb2.SaverDef.V2,
- allow_empty=True)
+ saver = self._maybe_create_saver(saver)
# Save the variables. Also, disable writing the checkpoint state proto. The
# file is not used during SavedModel loading. In addition, since a
@@ -421,8 +428,7 @@ class SavedModelBuilder(object):
# The graph almost certainly previously contained at least one Saver, and
# possibly several (e.g. one for loading a pretrained embedding, and another
- # for the model weights). However, a *new* Saver was just created that
- # includes all of the variables. Removing the preexisting ones was the
+ # for the model weights). Removing the preexisting ones was the
# motivation for the clear_extraneous_savers option, but it turns out that
# there are edge cases where that option breaks the graph. Until that is
# resolved, we just leave the option set to False for now.
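With the saver creation factored into `_maybe_create_saver`, callers can thread their own `tf.train.Saver` through both `add_meta_graph` and `add_meta_graph_and_variables`. A minimal sketch (export path and names are illustrative):

```python
import tensorflow as tf

export_dir = '/tmp/custom_saver_model'  # illustrative path
builder = tf.saved_model.builder.SavedModelBuilder(export_dir)

graph = tf.Graph()
with graph.as_default(), tf.Session(graph=graph) as sess:
  tf.Variable(1, name='v1')
  sess.run(tf.global_variables_initializer())
  my_saver = tf.train.Saver(name='my_saver')
  # The custom saver writes the variables checkpoint and its SaverDef ends up
  # in the exported MetaGraphDef (restore op 'my_saver/restore_all').
  builder.add_meta_graph_and_variables(sess, ['serve'], saver=my_saver)

builder.save()
```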
diff --git a/tensorflow/python/saved_model/saved_model_test.py b/tensorflow/python/saved_model/saved_model_test.py
index 7302c77ad5..effb38283b 100644
--- a/tensorflow/python/saved_model/saved_model_test.py
+++ b/tensorflow/python/saved_model/saved_model_test.py
@@ -44,6 +44,7 @@ from tensorflow.python.saved_model import main_op
from tensorflow.python.saved_model import signature_def_utils
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.training import saver_test_utils
+from tensorflow.python.training import training
from tensorflow.python.util import compat
SAVED_MODEL_PATH = ("cc/saved_model/testdata/half_plus_two/00000123")
@@ -1122,6 +1123,80 @@ class SavedModelTest(test.TestCase):
self.assertEqual(b"k1", v1.keys().eval())
self.assertEqual(3.0, v1.values().eval())
+ def testCustomSaver(self):
+ export_dir = self._get_export_dir("test_custom_saver")
+ builder = saved_model_builder.SavedModelBuilder(export_dir)
+
+ with self.test_session(graph=ops.Graph()) as sess:
+ variables.Variable(1, name="v1")
+ sess.run(variables.global_variables_initializer())
+ custom_saver = training.Saver(name="my_saver")
+ builder.add_meta_graph_and_variables(sess, ["tag"], saver=custom_saver)
+
+ # Save the SavedModel to disk.
+ builder.save()
+
+ with ops.Graph().as_default() as graph:
+ with self.test_session(graph=graph) as sess:
+ saved_graph = loader.load(sess, ["tag"], export_dir)
+ graph_ops = [x.name for x in graph.get_operations()]
+ self.assertTrue("my_saver/restore_all" in graph_ops)
+ self.assertFalse("save/restore_all" in graph_ops)
+ self.assertEqual(
+ saved_graph.saver_def.restore_op_name, "my_saver/restore_all")
+
+ def testNoCustomSaver(self):
+ export_dir = self._get_export_dir("test_no_custom_saver")
+ builder = saved_model_builder.SavedModelBuilder(export_dir)
+
+ with self.test_session(graph=ops.Graph()) as sess:
+ variables.Variable(1, name="v1")
+ sess.run(variables.global_variables_initializer())
+ training.Saver(name="my_saver")
+ builder.add_meta_graph_and_variables(sess, ["tag"])
+
+ # Save the SavedModel to disk.
+ builder.save()
+
+ with ops.Graph().as_default() as graph:
+ with self.test_session(graph=graph) as sess:
+ saved_graph = loader.load(sess, ["tag"], export_dir)
+ graph_ops = [x.name for x in graph.get_operations()]
+ self.assertTrue("my_saver/restore_all" in graph_ops)
+ self.assertTrue("save/restore_all" in graph_ops)
+ self.assertEqual(
+ saved_graph.saver_def.restore_op_name, "save/restore_all")
+
+ def testMultipleCustomSavers(self):
+ export_dir = self._get_export_dir("test_multiple_custom_savers")
+ builder = saved_model_builder.SavedModelBuilder(export_dir)
+
+ with self.test_session(graph=ops.Graph()) as sess:
+ variables.Variable(1, name="v1")
+ sess.run(variables.global_variables_initializer())
+ builder.add_meta_graph_and_variables(sess, ["tag_0"])
+
+ saver_1 = training.Saver()
+ builder.add_meta_graph(["tag_1"], saver=saver_1)
+
+ saver_2 = training.Saver()
+ builder.add_meta_graph(["tag_2"], saver=saver_2)
+
+ # Save the SavedModel to disk.
+ builder.save()
+
+ def _validate_custom_saver(tag_name, saver_name):
+ with ops.Graph().as_default() as graph:
+ with self.test_session(graph=graph) as sess:
+ saved_graph = loader.load(sess, [tag_name], export_dir)
+ self.assertEqual(
+ saved_graph.saver_def.restore_op_name,
+ saver_name)
+
+ _validate_custom_saver("tag_0", "save/restore_all")
+ _validate_custom_saver("tag_1", "save_1/restore_all")
+ _validate_custom_saver("tag_2", "save_2/restore_all")
+
def testClearDevices(self):
export_dir = self._get_export_dir("test_clear_devices")
builder = saved_model_builder.SavedModelBuilder(export_dir)
diff --git a/tensorflow/python/training/adadelta.py b/tensorflow/python/training/adadelta.py
index c08e3cca00..95eca76496 100644
--- a/tensorflow/python/training/adadelta.py
+++ b/tensorflow/python/training/adadelta.py
@@ -46,6 +46,13 @@ class AdadeltaOptimizer(optimizer.Optimizer):
use_locking: If `True` use locks for update operations.
name: Optional name prefix for the operations created when applying
gradients. Defaults to "Adadelta".
+
+ @compatibility(eager)
+ When eager execution is enabled, `learning_rate`, `rho`, and `epsilon` can
+ each be a callable that takes no arguments and returns the actual value to
+ use. This can be useful for changing these values across different
+ invocations of optimizer functions.
+ @end_compatibility
"""
super(AdadeltaOptimizer, self).__init__(use_locking, name)
self._lr = learning_rate
@@ -63,9 +70,13 @@ class AdadeltaOptimizer(optimizer.Optimizer):
self._zeros_slot(v, "accum_update", self._name)
def _prepare(self):
- self._lr_t = ops.convert_to_tensor(self._lr, name="lr")
- self._rho_t = ops.convert_to_tensor(self._rho, name="rho")
- self._epsilon_t = ops.convert_to_tensor(self._epsilon, name="epsilon")
+ lr = self._call_if_callable(self._lr)
+ rho = self._call_if_callable(self._rho)
+ epsilon = self._call_if_callable(self._epsilon)
+
+ self._lr_t = ops.convert_to_tensor(lr, name="lr")
+ self._rho_t = ops.convert_to_tensor(rho, name="rho")
+ self._epsilon_t = ops.convert_to_tensor(epsilon, name="epsilon")
def _apply_dense(self, grad, var):
accum = self.get_slot(var, "accum")
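Routing the hyperparameters through `_call_if_callable` in `_prepare` lets them be supplied as zero-argument callables, which is mainly useful under eager execution, where the values are re-read on each `apply_gradients` call. A minimal sketch (values illustrative):

```python
import tensorflow as tf
import tensorflow.contrib.eager as tfe

tf.enable_eager_execution()

# Callables are re-evaluated every time the optimizer prepares its tensors,
# so the effective hyperparameters can change between apply_gradients calls.
opt = tf.train.AdadeltaOptimizer(learning_rate=lambda: 0.1,
                                 rho=lambda: 0.95,
                                 epsilon=lambda: 1e-8)

v = tfe.Variable([1.0, 2.0])
with tfe.GradientTape() as tape:
  loss = tf.reduce_sum(v * v)
grads = tape.gradient(loss, [v])
opt.apply_gradients(zip(grads, [v]))
```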
diff --git a/tensorflow/python/training/adadelta_test.py b/tensorflow/python/training/adadelta_test.py
index 50f435236b..2678016d24 100644
--- a/tensorflow/python/training/adadelta_test.py
+++ b/tensorflow/python/training/adadelta_test.py
@@ -20,8 +20,10 @@ from __future__ import print_function
import numpy as np
+from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
@@ -32,44 +34,52 @@ from tensorflow.python.training import adadelta
class AdadeltaOptimizerTest(test.TestCase):
- def doTestBasic(self, use_resource=False):
+ def doTestBasic(self, use_resource=False, use_callable_params=False):
num_updates = 4 # number of ADADELTA steps to perform
for dtype in [dtypes.half, dtypes.float32]:
for grad in [0.2, 0.1, 0.01]:
for lr in [1.0, 0.5, 0.1]:
- with self.test_session():
- var0_init = [1.0, 2.0]
- var1_init = [3.0, 4.0]
- if use_resource:
- var0 = resource_variable_ops.ResourceVariable(
- var0_init, dtype=dtype)
- var1 = resource_variable_ops.ResourceVariable(
- var1_init, dtype=dtype)
- else:
- var0 = variables.Variable(var0_init, dtype=dtype)
- var1 = variables.Variable(var1_init, dtype=dtype)
-
- grads = constant_op.constant([grad, grad], dtype=dtype)
-
- accum = 0.0
- accum_update = 0.0
-
- # ADADELTA gradient optimizer
- rho = 0.95
- epsilon = 1e-8
- adadelta_opt = adadelta.AdadeltaOptimizer(lr, rho, epsilon)
+ var0_init = [1.0, 2.0]
+ var1_init = [3.0, 4.0]
+ if use_resource:
+ var0 = resource_variable_ops.ResourceVariable(
+ var0_init, dtype=dtype)
+ var1 = resource_variable_ops.ResourceVariable(
+ var1_init, dtype=dtype)
+ else:
+ var0 = variables.Variable(var0_init, dtype=dtype)
+ var1 = variables.Variable(var1_init, dtype=dtype)
+
+ grads = constant_op.constant([grad, grad], dtype=dtype)
+
+ accum = 0.0
+ accum_update = 0.0
+
+ # ADADELTA gradient optimizer
+ rho = 0.95
+ epsilon = 1e-8
+ if use_callable_params:
+ adadelta_opt = adadelta.AdadeltaOptimizer(
+ learning_rate=lambda: lr, # pylint: disable=cell-var-from-loop
+ rho=lambda: rho, # pylint: disable=cell-var-from-loop
+ epsilon=lambda: epsilon) # pylint: disable=cell-var-from-loop
+ else:
+ adadelta_opt = adadelta.AdadeltaOptimizer(
+ learning_rate=lr, rho=rho, epsilon=epsilon)
+ if not context.executing_eagerly():
adadelta_update = adadelta_opt.apply_gradients(
zip([grads, grads], [var0, var1]))
+ self.evaluate(variables.global_variables_initializer())
+ # TODO(lxuechen): This is hard to test in eager mode,
+ # since the optimizer is not fully initialized until the first
+ # call to `apply_gradients`
opt_vars = adadelta_opt.variables()
self.assertStartsWith(opt_vars[0].name, var0._shared_name)
self.assertStartsWith(opt_vars[1].name, var0._shared_name)
self.assertStartsWith(opt_vars[2].name, var1._shared_name)
self.assertStartsWith(opt_vars[3].name, var1._shared_name)
self.assertEqual(4, len(opt_vars))
-
- variables.global_variables_initializer().run()
-
# Assign slots
slot = [None] * 2
slot_update = [None] * 2
@@ -91,36 +101,42 @@ class AdadeltaOptimizerTest(test.TestCase):
self.assertEquals(slot_update[1].get_shape(), var1.get_shape())
self.assertFalse(slot_update[1] in variables.trainable_variables())
- # Fetch params to validate initial values
- self.assertAllClose(var0_init, var0.eval())
- self.assertAllClose(var1_init, var1.eval())
-
- update = [None] * num_updates
- tot_update = 0
- for step in range(num_updates):
- # Run adadelta update for comparison
- adadelta_update.run()
-
- # Perform initial update without previous accum values
- accum = accum * rho + (grad**2) * (1 - rho)
- update[step] = (np.sqrt(accum_update + epsilon) *
- (1. / np.sqrt(accum + epsilon)) * grad)
- accum_update = (accum_update * rho + (update[step]**2) *
- (1.0 - rho))
- tot_update += update[step] * lr
+ # Fetch params to validate initial values
+ self.assertAllClose(var0_init, self.evaluate(var0))
+ self.assertAllClose(var1_init, self.evaluate(var1))
+ update = [None] * num_updates
+ tot_update = 0
+ for step in range(num_updates):
+ # Run adadelta update for comparison
+ if not context.executing_eagerly():
+ self.evaluate(adadelta_update)
+ else:
+ adadelta_opt.apply_gradients(zip([grads, grads], [var0, var1]))
+
+ # Perform initial update without previous accum values
+ accum = accum * rho + (grad**2) * (1 - rho)
+ update[step] = (
+ np.sqrt(accum_update + epsilon) *
+ (1. / np.sqrt(accum + epsilon)) * grad)
+ accum_update = (
+ accum_update * rho + (update[step]**2) * (1.0 - rho))
+ tot_update += update[step] * lr
+
+ if not context.executing_eagerly():
# Check that the accumulators have been updated
+ # TODO(lxuechen): This is hard to test in eager mode
for slot_idx in range(2):
self.assertAllCloseAccordingToType(
np.array([accum, accum], dtype=dtype.as_numpy_dtype()),
- slot[slot_idx].eval(),
+ self.evaluate(slot[slot_idx]),
rtol=1e-5)
self.assertAllCloseAccordingToType(
np.array(
[accum_update, accum_update],
dtype=dtype.as_numpy_dtype()),
- slot_update[slot_idx].eval(),
+ self.evaluate(slot_update[slot_idx]),
rtol=1e-5)
# Check that the parameters have been updated
@@ -128,22 +144,28 @@ class AdadeltaOptimizerTest(test.TestCase):
np.array(
[var0_init[0] - tot_update, var0_init[1] - tot_update],
dtype=dtype.as_numpy_dtype()),
- var0.eval(),
+ self.evaluate(var0),
rtol=1e-5)
self.assertAllCloseAccordingToType(
np.array(
[var1_init[0] - tot_update, var1_init[1] - tot_update],
dtype=dtype.as_numpy_dtype()),
- var1.eval(),
+ self.evaluate(var1),
rtol=1e-5)
def testBasic(self):
- self.doTestBasic(use_resource=False)
+ with self.test_session():
+ self.doTestBasic(use_resource=False)
+ @test_util.run_in_graph_and_eager_modes(reset_test=True)
def testResourceBasic(self):
self.doTestBasic(use_resource=True)
+ def testBasicCallableParams(self):
+ with context.eager_mode():
+ self.doTestBasic(use_resource=True, use_callable_params=True)
+
def testMinimizeSparseResourceVariable(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
with self.test_session():
diff --git a/tensorflow/python/training/adagrad.py b/tensorflow/python/training/adagrad.py
index deb4e6f546..6778f3c735 100644
--- a/tensorflow/python/training/adagrad.py
+++ b/tensorflow/python/training/adagrad.py
@@ -51,6 +51,13 @@ class AdagradOptimizer(optimizer.Optimizer):
Raises:
ValueError: If the `initial_accumulator_value` is invalid.
+
+ @compatibility(eager)
+ When eager execution is enabled, `learning_rate` can be a callable that
+ takes no arguments and returns the actual value to use. This can be useful
+  for changing this value across different invocations of optimizer
+ functions.
+ @end_compatibility
"""
if initial_accumulator_value <= 0.0:
raise ValueError("initial_accumulator_value must be positive: %s" %
@@ -78,8 +85,9 @@ class AdagradOptimizer(optimizer.Optimizer):
"accumulator", self._name)
def _prepare(self):
- self._learning_rate_tensor = ops.convert_to_tensor(self._learning_rate,
- name="learning_rate")
+ learning_rate = self._call_if_callable(self._learning_rate)
+ self._learning_rate_tensor = ops.convert_to_tensor(
+ learning_rate, name="learning_rate")
def _apply_dense(self, grad, var):
acc = self.get_slot(var, "accumulator")
diff --git a/tensorflow/python/training/adagrad_test.py b/tensorflow/python/training/adagrad_test.py
index 15b007b46d..c9aec33d09 100644
--- a/tensorflow/python/training/adagrad_test.py
+++ b/tensorflow/python/training/adagrad_test.py
@@ -20,9 +20,11 @@ from __future__ import print_function
import numpy as np
+from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
@@ -34,40 +36,63 @@ from tensorflow.python.training import adagrad
class AdagradOptimizerTest(test.TestCase):
- def doTestBasic(self, use_locking=False, use_resource=False):
+ def doTestBasic(self,
+ use_locking=False,
+ use_resource=False,
+ use_callable_params=False):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
- if use_resource:
- var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
- var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
- else:
- var0 = variables.Variable([1.0, 2.0], dtype=dtype)
- var1 = variables.Variable([3.0, 4.0], dtype=dtype)
- grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
- grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
- ada_opt = adagrad.AdagradOptimizer(
- 3.0, initial_accumulator_value=0.1, use_locking=use_locking)
+ if use_resource:
+ var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
+ var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
+ else:
+ var0 = variables.Variable([1.0, 2.0], dtype=dtype)
+ var1 = variables.Variable([3.0, 4.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
+ grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
+
+ learning_rate = lambda: 3.0
+ if not use_callable_params:
+ learning_rate = learning_rate()
+
+ ada_opt = adagrad.AdagradOptimizer(
+ learning_rate, initial_accumulator_value=0.1, use_locking=use_locking)
+
+ if not context.executing_eagerly():
ada_update = ada_opt.apply_gradients(
zip([grads0, grads1], [var0, var1]))
- variables.global_variables_initializer().run()
- # Fetch params to validate initial values
- self.assertAllClose([1.0, 2.0], var0.eval())
- self.assertAllClose([3.0, 4.0], var1.eval())
- # Run 3 steps of adagrad
- for _ in range(3):
- ada_update.run()
- # Validate updated params
- self.assertAllCloseAccordingToType(
- np.array([-1.6026098728179932, -0.6026098728179932]), var0.eval())
- self.assertAllCloseAccordingToType(
- np.array([2.715679168701172, 3.715679168701172]), var1.eval())
+ self.evaluate(variables.global_variables_initializer())
+
+ # Fetch params to validate initial values
+ v0_val, v1_val = self.evaluate([var0, var1])
+ self.assertAllClose([1.0, 2.0], v0_val)
+ self.assertAllClose([3.0, 4.0], v1_val)
+
+ # Run 3 steps of adagrad
+ for _ in range(3):
+ if not context.executing_eagerly():
+ self.evaluate(ada_update)
+ else:
+ ada_opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+
+ # Validate updated params
+ v0_val, v1_val = self.evaluate([var0, var1])
+ self.assertAllCloseAccordingToType(
+ np.array([-1.6026098728179932, -0.6026098728179932]), v0_val)
+ self.assertAllCloseAccordingToType(
+ np.array([2.715679168701172, 3.715679168701172]), v1_val)
def testBasic(self):
self.doTestBasic(use_locking=False)
+ @test_util.run_in_graph_and_eager_modes(reset_test=True)
def testBasicResource(self):
self.doTestBasic(use_locking=False, use_resource=True)
+ def testBasicCallableParams(self):
+ with context.eager_mode():
+ self.doTestBasic(
+ use_locking=False, use_resource=True, use_callable_params=True)
+
def testBasicLocked(self):
self.doTestBasic(use_locking=True)
diff --git a/tensorflow/python/training/adam.py b/tensorflow/python/training/adam.py
index 6fa3ff6658..b65c88e972 100644
--- a/tensorflow/python/training/adam.py
+++ b/tensorflow/python/training/adam.py
@@ -85,6 +85,13 @@ class AdamOptimizer(optimizer.Optimizer):
use_locking: If True use locks for update operations.
name: Optional name for the operations created when applying gradients.
Defaults to "Adam".
+
+ @compatibility(eager)
+ When eager execution is enabled, `learning_rate`, `beta1`, `beta2`, and
+ `epsilon` can each be a callable that takes no arguments and returns the
+ actual value to use. This can be useful for changing these values across
+ different invocations of optimizer functions.
+ @end_compatibility
"""
super(AdamOptimizer, self).__init__(use_locking, name)
self._lr = learning_rate
@@ -128,10 +135,15 @@ class AdamOptimizer(optimizer.Optimizer):
self._zeros_slot(v, "v", self._name)
def _prepare(self):
- self._lr_t = ops.convert_to_tensor(self._lr, name="learning_rate")
- self._beta1_t = ops.convert_to_tensor(self._beta1, name="beta1")
- self._beta2_t = ops.convert_to_tensor(self._beta2, name="beta2")
- self._epsilon_t = ops.convert_to_tensor(self._epsilon, name="epsilon")
+ lr = self._call_if_callable(self._lr)
+ beta1 = self._call_if_callable(self._beta1)
+ beta2 = self._call_if_callable(self._beta2)
+ epsilon = self._call_if_callable(self._epsilon)
+
+ self._lr_t = ops.convert_to_tensor(lr, name="learning_rate")
+ self._beta1_t = ops.convert_to_tensor(beta1, name="beta1")
+ self._beta2_t = ops.convert_to_tensor(beta2, name="beta2")
+ self._epsilon_t = ops.convert_to_tensor(epsilon, name="epsilon")
def _apply_dense(self, grad, var):
m = self.get_slot(var, "m")
diff --git a/tensorflow/python/training/adam_test.py b/tensorflow/python/training/adam_test.py
index bc68f24c6f..ccdc7e384d 100644
--- a/tensorflow/python/training/adam_test.py
+++ b/tensorflow/python/training/adam_test.py
@@ -150,7 +150,7 @@ class AdamOptimizerTest(test.TestCase):
self.assertAllClose(aggregated_update_var.eval(),
repeated_index_update_var.eval())
- def doTestBasic(self, use_resource=False):
+ def doTestBasic(self, use_resource=False, use_callable_params=False):
for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
with self.test_session(graph=ops.Graph()):
# Initialize variables for numpy implementation.
@@ -171,7 +171,17 @@ class AdamOptimizerTest(test.TestCase):
grads0 = constant_op.constant(grads0_np)
grads1 = constant_op.constant(grads1_np)
- opt = adam.AdamOptimizer()
+ learning_rate = lambda: 0.001
+ beta1 = lambda: 0.9
+ beta2 = lambda: 0.999
+ epsilon = lambda: 1e-8
+ if not use_callable_params:
+ learning_rate = learning_rate()
+ beta1 = beta1()
+ beta2 = beta2()
+ epsilon = epsilon()
+
+      opt = adam.AdamOptimizer(learning_rate=learning_rate, beta1=beta1,
+                               beta2=beta2, epsilon=epsilon)
update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
opt_variables = opt.variables()
beta1_power, beta2_power = opt._get_beta_accumulators()
@@ -221,6 +231,10 @@ class AdamOptimizerTest(test.TestCase):
def testResourceBasic(self):
self.doTestBasic(use_resource=True)
+ def testBasicCallableParams(self):
+ with context.eager_mode():
+ self.doTestBasic(use_resource=True, use_callable_params=True)
+
def testTensorLearningRate(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
with self.test_session():
diff --git a/tensorflow/python/training/distribute.py b/tensorflow/python/training/distribute.py
index ab8b37bb65..7cd175f25b 100644
--- a/tensorflow/python/training/distribute.py
+++ b/tensorflow/python/training/distribute.py
@@ -946,7 +946,7 @@ class DistributionStrategy(object):
return control_flow_ops.group(value, name=name)
# Special handling for the common case of one op.
v, = value
- if isinstance(v, ops.Tensor):
+ if hasattr(v, "op"):
v = v.op
return v
diff --git a/tensorflow/python/training/gradient_descent.py b/tensorflow/python/training/gradient_descent.py
index a07ad19a6e..ef50f6315d 100644
--- a/tensorflow/python/training/gradient_descent.py
+++ b/tensorflow/python/training/gradient_descent.py
@@ -18,7 +18,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
@@ -41,6 +40,13 @@ class GradientDescentOptimizer(optimizer.Optimizer):
use_locking: If True use locks for update operations.
name: Optional name prefix for the operations created when applying
gradients. Defaults to "GradientDescent".
+
+ @compatibility(eager)
+ When eager execution is enabled, `learning_rate` can be a callable that
+ takes no arguments and returns the actual value to use. This can be useful
+  for changing this value across different invocations of optimizer
+ functions.
+ @end_compatibility
"""
super(GradientDescentOptimizer, self).__init__(use_locking, name)
self._learning_rate = learning_rate
@@ -71,7 +77,6 @@ class GradientDescentOptimizer(optimizer.Optimizer):
return var.scatter_sub(delta, use_locking=self._use_locking)
def _prepare(self):
- if not context.executing_eagerly() or not isinstance(
- self._learning_rate_tensor, ops.EagerTensor):
- self._learning_rate_tensor = ops.convert_to_tensor(self._learning_rate,
- name="learning_rate")
+ learning_rate = self._call_if_callable(self._learning_rate)
+ self._learning_rate_tensor = ops.convert_to_tensor(
+ learning_rate, name="learning_rate")
diff --git a/tensorflow/python/training/gradient_descent_test.py b/tensorflow/python/training/gradient_descent_test.py
index f89a9c5838..b304e92421 100644
--- a/tensorflow/python/training/gradient_descent_test.py
+++ b/tensorflow/python/training/gradient_descent_test.py
@@ -83,6 +83,32 @@ class GradientDescentOptimizerTest(test.TestCase):
self.assertAllCloseAccordingToType([3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01],
var1.eval())
+ def testBasicCallableParams(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.test_session():
+ var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
+ var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
+ grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
+ lr = lambda: 3.0
+ sgd_op = gradient_descent.GradientDescentOptimizer(lr).apply_gradients(
+ zip([grads0, grads1], [var0, var1]))
+ # TODO(apassos) calling initialize_resources on all resources here
+ # doesn't work because the sessions and graph are reused across unit
+ # tests and this would mean trying to reinitialize variables. Figure out
+ # a long-term solution for this.
+ resources.initialize_resources([var0, var1]).run()
+ # Fetch params to validate initial values
+ self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval())
+ self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval())
+ # Run 1 step of sgd
+ sgd_op.run()
+ # Validate updated params
+ self.assertAllCloseAccordingToType([1.0 - 3.0 * 0.1, 2.0 - 3.0 * 0.1],
+ var0.eval())
+ self.assertAllCloseAccordingToType([3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01],
+ var1.eval())
+
def testMinimizeResourceVariable(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
with self.test_session():
diff --git a/tensorflow/python/training/momentum.py b/tensorflow/python/training/momentum.py
index bd9fa79d8f..cb3ec6f053 100644
--- a/tensorflow/python/training/momentum.py
+++ b/tensorflow/python/training/momentum.py
@@ -61,8 +61,8 @@ class MomentumOptimizer(optimizer.Optimizer):
variable(s) track the values called `theta_t + mu*v_t` in the paper.
@compatibility(eager)
- When eager execution is enabled, learning_rate and momentum can each be a
- callable that takes no arguments and returns the actual value to use. This
+ When eager execution is enabled, `learning_rate` and `momentum` can each be
+ a callable that takes no arguments and returns the actual value to use. This
can be useful for changing these values across different invocations of
optimizer functions.
@end_compatibility
diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py
index a9287a0f0d..cae29eea93 100644
--- a/tensorflow/python/training/optimizer.py
+++ b/tensorflow/python/training/optimizer.py
@@ -1211,3 +1211,7 @@ class Optimizer(
self._deferred_slot_restorations.setdefault(
slot_name, {}).setdefault(variable_key, []).append(
slot_variable_position)
+
+ def _call_if_callable(self, param):
+ """Call the function if param is callable."""
+ return param() if callable(param) else param
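The helper itself is tiny; a standalone illustration of its semantics (the function here is a local sketch, not the Optimizer method):

    def _call_if_callable(param):
      """Resolve a hyperparameter that may be a zero-argument callable."""
      return param() if callable(param) else param

    assert _call_if_callable(0.1) == 0.1          # plain value passes through
    assert _call_if_callable(lambda: 0.1) == 0.1  # callable is invoked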
diff --git a/tensorflow/python/training/rmsprop.py b/tensorflow/python/training/rmsprop.py
index 341b970c92..f38c9861d6 100644
--- a/tensorflow/python/training/rmsprop.py
+++ b/tensorflow/python/training/rmsprop.py
@@ -92,6 +92,13 @@ class RMSPropOptimizer(optimizer.Optimizer):
computation and memory. Defaults to False.
name: Optional name prefix for the operations created when applying
gradients. Defaults to "RMSProp".
+
+ @compatibility(eager)
+ When eager execution is enabled, `learning_rate`, `decay`, `momentum`, and
+ `epsilon` can each be a callable that takes no arguments and returns the
+ actual value to use. This can be useful for changing these values across
+ different invocations of optimizer functions.
+ @end_compatibility
"""
super(RMSPropOptimizer, self).__init__(use_locking, name)
self._learning_rate = learning_rate
@@ -120,12 +127,15 @@ class RMSPropOptimizer(optimizer.Optimizer):
self._zeros_slot(v, "momentum", self._name)
def _prepare(self):
- self._learning_rate_tensor = ops.convert_to_tensor(
- self._learning_rate, name="learning_rate")
- self._decay_tensor = ops.convert_to_tensor(self._decay, name="decay")
- self._momentum_tensor = ops.convert_to_tensor(
- self._momentum, name="momentum")
- self._epsilon_tensor = ops.convert_to_tensor(self._epsilon, name="epsilon")
+ lr = self._call_if_callable(self._learning_rate)
+ decay = self._call_if_callable(self._decay)
+ momentum = self._call_if_callable(self._momentum)
+ epsilon = self._call_if_callable(self._epsilon)
+
+ self._learning_rate_tensor = ops.convert_to_tensor(lr, name="learning_rate")
+ self._decay_tensor = ops.convert_to_tensor(decay, name="decay")
+ self._momentum_tensor = ops.convert_to_tensor(momentum, name="momentum")
+ self._epsilon_tensor = ops.convert_to_tensor(epsilon, name="epsilon")
def _apply_dense(self, grad, var):
rms = self.get_slot(var, "rms")
diff --git a/tensorflow/python/training/rmsprop_test.py b/tensorflow/python/training/rmsprop_test.py
index ee5385596c..6043327384 100644
--- a/tensorflow/python/training/rmsprop_test.py
+++ b/tensorflow/python/training/rmsprop_test.py
@@ -24,6 +24,7 @@ import math
import numpy as np
+from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@@ -141,7 +142,7 @@ class RMSPropOptimizerTest(test.TestCase):
self.assertAllClose([3.0, 4.0], var1.eval())
# Run 4 steps of RMSProp
- for t in range(1, 5):
+ for _ in range(1, 5):
update.run()
var0_np, mg0_np, rms0_np, mom0_np = self._rmsprop_update_numpy(
@@ -261,7 +262,7 @@ class RMSPropOptimizerTest(test.TestCase):
self.assertAllClose([3.0, 4.0], var1.eval())
# Run 4 steps of RMSProp
- for t in range(1, 5):
+ for _ in range(1, 5):
update.run()
var0_np, mg0_np, rms0_np, mom0_np = self._sparse_rmsprop_update_numpy(
@@ -444,6 +445,55 @@ class RMSPropOptimizerTest(test.TestCase):
(0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 2e-5)))
]), var1.eval())
+ def testCallableParams(self):
+ with context.eager_mode():
+ for dtype in [dtypes.half, dtypes.float32]:
+ var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
+ var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
+ grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
+
+ learning_rate = lambda: 2.0
+ decay = lambda: 0.9
+ momentum = lambda: 0.0
+ epsilon = lambda: 1.0
+ opt = rmsprop.RMSPropOptimizer(learning_rate, decay, momentum, epsilon)
+
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], self.evaluate(var0))
+ self.assertAllClose([3.0, 4.0], self.evaluate(var1))
+      # Step 1: the rms accumulators were 1. So we should see a normal
+ # update: v -= grad * learning_rate
+ opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ # Check the parameters.
+ self.assertAllCloseAccordingToType(
+ np.array([
+ 1.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1.0)),
+ 2.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1.0))
+ ]), self.evaluate(var0))
+ self.assertAllCloseAccordingToType(
+ np.array([
+ 3.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1.0)),
+ 4.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1.0))
+ ]), self.evaluate(var1))
+ # Step 2: the root mean square accumulators contain the previous update.
+ opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ # Check the parameters.
+ self.assertAllCloseAccordingToType(
+ np.array([
+ 1.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1.0)) -
+ (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001 + 1.0)),
+ 2.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1.0)) -
+ (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001 + 1.0))
+ ]), self.evaluate(var0))
+ self.assertAllCloseAccordingToType(
+ np.array([
+ 3.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1.0)) -
+ (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 1e-5 + 1.0)),
+ 4.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1.0)) -
+ (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 1e-5 + 1.0))
+ ]), self.evaluate(var1))
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/util/tf_export.py b/tensorflow/python/util/tf_export.py
index bf3961c692..e154ffb68a 100644
--- a/tensorflow/python/util/tf_export.py
+++ b/tensorflow/python/util/tf_export.py
@@ -41,17 +41,35 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import collections
+import functools
import sys
from tensorflow.python.util import tf_decorator
+ESTIMATOR_API_NAME = 'estimator'
+TENSORFLOW_API_NAME = 'tensorflow'
+
+_Attributes = collections.namedtuple(
+ 'ExportedApiAttributes', ['names', 'constants'])
+
+# Attribute values must be unique to each API.
+API_ATTRS = {
+ TENSORFLOW_API_NAME: _Attributes(
+ '_tf_api_names',
+ '_tf_api_constants'),
+ ESTIMATOR_API_NAME: _Attributes(
+ '_estimator_api_names',
+ '_estimator_api_constants')
+}
+
class SymbolAlreadyExposedError(Exception):
"""Raised when adding API names to symbol that already has API names."""
pass
-class tf_export(object): # pylint: disable=invalid-name
+class api_export(object): # pylint: disable=invalid-name
"""Provides ways to export symbols to the TensorFlow API."""
def __init__(self, *args, **kwargs):
@@ -63,15 +81,12 @@ class tf_export(object): # pylint: disable=invalid-name
overrides: List of symbols that this is overriding
(those overrided api exports will be removed). Note: passing overrides
has no effect on exporting a constant.
- allow_multiple_exports: Allows exporting the same symbol multiple
- times with multiple `tf_export` usages. Prefer however, to list all
- of the exported names in a single `tf_export` usage when possible.
-
+ api_name: Name of the API you want to generate (e.g. `tensorflow` or
+ `estimator`). Default is `tensorflow`.
"""
self._names = args
+ self._api_name = kwargs.get('api_name', TENSORFLOW_API_NAME)
self._overrides = kwargs.get('overrides', [])
- self._allow_multiple_exports = kwargs.get(
- 'allow_multiple_exports', False)
def __call__(self, func):
"""Calls this decorator.
@@ -86,25 +101,24 @@ class tf_export(object): # pylint: disable=invalid-name
SymbolAlreadyExposedError: Raised when a symbol already has API names
and kwarg `allow_multiple_exports` not set.
"""
+ api_names_attr = API_ATTRS[self._api_name].names
+
# Undecorate overridden names
for f in self._overrides:
_, undecorated_f = tf_decorator.unwrap(f)
- del undecorated_f._tf_api_names # pylint: disable=protected-access
+ delattr(undecorated_f, api_names_attr)
_, undecorated_func = tf_decorator.unwrap(func)
# Check for an existing api. We check if attribute name is in
# __dict__ instead of using hasattr to verify that subclasses have
# their own _tf_api_names as opposed to just inheriting it.
- if '_tf_api_names' in undecorated_func.__dict__:
- if self._allow_multiple_exports:
- undecorated_func._tf_api_names += self._names # pylint: disable=protected-access
- else:
- raise SymbolAlreadyExposedError(
- 'Symbol %s is already exposed as %s.' %
- (undecorated_func.__name__, undecorated_func._tf_api_names)) # pylint: disable=protected-access
- else:
- undecorated_func._tf_api_names = self._names # pylint: disable=protected-access
+ if api_names_attr in undecorated_func.__dict__:
+ raise SymbolAlreadyExposedError(
+ 'Symbol %s is already exposed as %s.' %
+ (undecorated_func.__name__, getattr(
+ undecorated_func, api_names_attr))) # pylint: disable=protected-access
+ setattr(undecorated_func, api_names_attr, self._names)
return func
def export_constant(self, module_name, name):
@@ -126,8 +140,12 @@ class tf_export(object): # pylint: disable=invalid-name
name: (string) Current constant name.
"""
module = sys.modules[module_name]
- if not hasattr(module, '_tf_api_constants'):
- module._tf_api_constants = [] # pylint: disable=protected-access
+ if not hasattr(module, API_ATTRS[self._api_name].constants):
+ setattr(module, API_ATTRS[self._api_name].constants, [])
# pylint: disable=protected-access
- module._tf_api_constants.append((self._names, name))
+ getattr(module, API_ATTRS[self._api_name].constants).append(
+ (self._names, name))
+
+tf_export = functools.partial(api_export, api_name=TENSORFLOW_API_NAME)
+estimator_export = functools.partial(tf_export, api_name=ESTIMATOR_API_NAME)
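Under the new scheme each API writes to its own attribute pair, so a symbol can carry TensorFlow and Estimator export names independently. A hedged sketch of how the two decorators might be stacked; the function name and export paths are invented for illustration:

    from tensorflow.python.util.tf_export import estimator_export, tf_export

    @estimator_export('estimator.hypothetical_input_fn')
    @tf_export('data.hypothetical_input_fn')
    def hypothetical_input_fn():
      pass

    # tf_export records the names in _tf_api_names, estimator_export in
    # _estimator_api_names, so the decorators do not collide and neither
    # raises SymbolAlreadyExposedError.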
diff --git a/tensorflow/python/util/tf_export_test.py b/tensorflow/python/util/tf_export_test.py
index ace3f054ba..b9e26ecb33 100644
--- a/tensorflow/python/util/tf_export_test.py
+++ b/tensorflow/python/util/tf_export_test.py
@@ -128,13 +128,6 @@ class ValidateExportTest(test.TestCase):
with self.assertRaises(tf_export.SymbolAlreadyExposedError):
export_decorator(_test_function)
- def testEAllowMultipleExports(self):
- _test_function._tf_api_names = ['name1', 'name2']
- tf_export.tf_export('nameRed', 'nameBlue', allow_multiple_exports=True)(
- _test_function)
- self.assertEquals(['name1', 'name2', 'nameRed', 'nameBlue'],
- _test_function._tf_api_names)
-
def testOverridesFunction(self):
_test_function2._tf_api_names = ['abc']
diff --git a/tensorflow/python/util/util.cc b/tensorflow/python/util/util.cc
index 0dd406aa4e..c79d8a8445 100644
--- a/tensorflow/python/util/util.cc
+++ b/tensorflow/python/util/util.cc
@@ -33,6 +33,8 @@ namespace {
PyObject* CollectionsSequenceType = nullptr;
PyTypeObject* SparseTensorValueType = nullptr;
+const int kMaxItemsInCache = 1024;
+
bool WarnedThatSetIsNotSequence = false;
bool IsString(PyObject* o) {
@@ -196,11 +198,14 @@ int IsSequenceHelper(PyObject* o) {
// NOTE: This is never decref'd, but we don't want the type to get deleted
// as long as it is in the map. This should not be too much of a
// leak, as there should only be a relatively small number of types in the
- // map, and an even smaller number that are eligible for decref.
- Py_INCREF(type);
+ // map, and an even smaller number that are eligible for decref. As a
+ // precaution, we limit the size of the map to 1024.
{
mutex_lock l(g_type_to_sequence_map);
- type_to_sequence_map->insert({type, is_sequence});
+ if (type_to_sequence_map->size() < kMaxItemsInCache) {
+ Py_INCREF(type);
+ type_to_sequence_map->insert({type, is_sequence});
+ }
}
return is_sequence;
diff --git a/tensorflow/security/index.md b/tensorflow/security/index.md
index 44f51ad07b..ea39e17ab2 100644
--- a/tensorflow/security/index.md
+++ b/tensorflow/security/index.md
@@ -4,7 +4,7 @@ We regularly publish security advisories about using TensorFlow.
*Note*: In conjunction with these security advisories, we strongly encourage
TensorFlow users to read and understand TensorFlow's security model as outlined
-in [https://github.com/tensorflow/tensorflow/blob/master/SECURITY.md](SECURITY.md).
+in [SECURITY.md](https://github.com/tensorflow/tensorflow/blob/master/SECURITY.md).
| Advisory Number | Type | Versions affected | Reported by | Additional Information |
|-----------------|--------------------|:-----------------:|-----------------------|-----------------------------|
@@ -14,5 +14,5 @@ in [https://github.com/tensorflow/tensorflow/blob/master/SECURITY.md](SECURITY.m
| [TFSA-2018-003](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/security/advisory/tfsa-2018-003.md) | TensorFlow Lite TOCO FlatBuffer Parsing Vulnerability | <= 1.7 | Blade Team of Tencent | |
| [TFSA-2018-002](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/security/advisory/tfsa-2018-002.md) | GIF File Parsing Null Pointer Dereference Error | <= 1.5 | Blade Team of Tencent | |
| [TFSA-2018-001](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/security/advisory/tfsa-2018-001.md) | BMP File Parser Out-of-bounds Read | <= 1.6 | Blade Team of Tencent | |
-| - | Out Of Bounds Read | <=1.4 | Blade Team of Tencent | [issue report](https://github.com/tensorflow/tensorflow/issues/14959) |
+| - | Out Of Bounds Read | <= 1.4 | Blade Team of Tencent | [issue report](https://github.com/tensorflow/tensorflow/issues/14959) |
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc
index f6564df0d0..48afc06e32 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.cc
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc
@@ -2291,9 +2291,7 @@ class CudnnEnvVar {
// algorithm through an env-var "TF_ENABLE_FFT_TILING_FORWARD=1".
struct FftTilingForward {
static constexpr const char* kName = "TF_ENABLE_FFT_TILING_FORWARD";
- // TODO(csigg): Enabling this algo causes XLA test failures, for example in
- // platforms/xla/tests/internal:convolution_test_gpu. See b/80018418.
- static constexpr bool kDefaultFlag = false; // CUDNN_VERSION >= 7000;
+ static constexpr bool kDefaultFlag = CUDNN_VERSION >= 7000;
};
// A helper struct to decide whether to enable the WINOGRAD_NONFUSED algorithms.
@@ -2426,6 +2424,33 @@ port::Status CudnnSupport::DoConvolveImpl(
}
}
+ // Report an error if we might be hitting a cuDNN bug that accesses illegal
+ // memory. See nvbugs/2138754, b/80018418.
+ SE_RETURN_IF_ERROR([&] {
+ if (algo_desc.algo_id() != CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING) {
+ return port::Status::OK();
+ }
+ if (input_descriptor.ndims() < 3) {
+ return port::Status::OK();
+ }
+ // Checks that a*b is within the valid range (as provided by NVIDIA).
+ auto check_sizes = [](size_t a, size_t b) {
+ if ((a * b * 4608 - 1) >> 31 == 0) {
+ return port::Status::OK();
+ }
+ return port::Status(
+ port::error::FAILED_PRECONDITION,
+ "This configuration potentially accesses illegal memory.");
+ };
+ SE_RETURN_IF_ERROR(check_sizes(input_descriptor.feature_map_count(),
+ output_descriptor.feature_map_count()));
+ SE_RETURN_IF_ERROR(check_sizes(input_descriptor.count(),
+ input_descriptor.feature_map_count()));
+ SE_RETURN_IF_ERROR(check_sizes(input_descriptor.count(),
+ output_descriptor.feature_map_count()));
+ return port::Status::OK();
+ }());
+
RETURN_IF_CUDNN_ERROR(cudnnConvolutionForward(
cudnn.handle(),
/*alpha=*/alpha, /*srcDesc=*/input_nd.handle(),
@@ -3192,6 +3217,34 @@ port::Status CudnnSupport::DoConvolveBackwardFilterImpl(
}
}
+ // Report an error if we might be hitting a cuDNN bug that produces incorrect
+ // results. See nvbugs/2072856
+ SE_RETURN_IF_ERROR([&] {
+ if (algo_desc.algo_id() != CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING) {
+ return port::Status::OK();
+ }
+ if (output_descriptor.height() > 1 && output_descriptor.width() > 1) {
+ return port::Status::OK();
+ }
+ int convolution_size = output_descriptor.height() > 1
+ ? filter_descriptor.input_filter_height()
+ : filter_descriptor.input_filter_width();
+ if (convolution_size <= 32) {
+ return port::Status::OK();
+ }
+ cudnnConvolutionMode_t convolution_mode;
+ cudnnDataType_t compute_type;
+ RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionNdDescriptor(
+ conv.handle(), 0, nullptr, nullptr, nullptr, nullptr, &convolution_mode,
+ &compute_type));
+ if (convolution_mode != CUDNN_CONVOLUTION) {
+ return port::Status::OK();
+ }
+ return port::Status(
+ port::error::FAILED_PRECONDITION,
+ "This configuration potentially produces incorrect results.");
+ }());
+
RETURN_IF_CUDNN_ERROR(cudnnConvolutionBackwardFilter(
cudnn.handle(),
/*alpha=*/alpha,
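The check_sizes guard in the forward-convolution hunk above rejects any pair whose product would push a*b*4608 past 2^31, i.e. a*b may not exceed 2**31 // 4608 = 466033. A small Python restatement of that bound, for intuition only (the constant comes from the guard; the helper name is ours):

    def within_cudnn_fft_tiling_bound(a, b):
      """Mirror of the guard above: a * b * 4608 must stay within 2**31."""
      return (a * b * 4608 - 1) >> 31 == 0

    assert within_cudnn_fft_tiling_bound(64, 128)         # ~3.8e7, well inside
    assert not within_cudnn_fft_tiling_bound(1024, 1024)  # ~4.8e9, exceeds 2**31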
diff --git a/tensorflow/tools/api/generator/api_gen.bzl b/tensorflow/tools/api/generator/api_gen.bzl
index fe3e4d1434..41713a94ec 100644
--- a/tensorflow/tools/api/generator/api_gen.bzl
+++ b/tensorflow/tools/api/generator/api_gen.bzl
@@ -11,9 +11,6 @@ TENSORFLOW_API_INIT_FILES = [
"distributions/__init__.py",
"distributions/bijectors/__init__.py",
"errors/__init__.py",
- "estimator/__init__.py",
- "estimator/export/__init__.py",
- "estimator/inputs/__init__.py",
"feature_column/__init__.py",
"gfile/__init__.py",
"graph_util/__init__.py",
@@ -91,6 +88,16 @@ TENSORFLOW_API_INIT_FILES = [
# END GENERATED FILES
]
+# keep sorted
+ESTIMATOR_API_INIT_FILES = [
+ # BEGIN GENERATED ESTIMATOR FILES
+ "__init__.py",
+ "estimator/__init__.py",
+ "estimator/export/__init__.py",
+ "estimator/inputs/__init__.py",
+ # END GENERATED ESTIMATOR FILES
+]
+
# Creates a genrule that generates a directory structure with __init__.py
# files that import all exported modules (i.e. modules with tf_export
# decorators).
@@ -110,7 +117,9 @@ TENSORFLOW_API_INIT_FILES = [
def gen_api_init_files(name,
output_files=TENSORFLOW_API_INIT_FILES,
root_init_template=None,
- srcs=[]):
+ srcs=[],
+ api_name="tensorflow",
+ package="tensorflow.python"):
root_init_template_flag = ""
if root_init_template:
root_init_template_flag = "--root_init_template=$(location " + root_init_template + ")"
@@ -119,7 +128,8 @@ def gen_api_init_files(name,
outs = output_files,
cmd = (
"$(location //tensorflow/tools/api/generator:create_python_api) " +
- root_init_template_flag + " --apidir=$(@D) $(OUTS)"),
+ root_init_template_flag + " --apidir=$(@D) --apiname=" + api_name + " --package=" + package + " $(OUTS)"),
srcs = srcs,
tools = ["//tensorflow/tools/api/generator:create_python_api"],
+ visibility = ["//tensorflow:__pkg__"],
)
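With the new api_name and package attributes, an Estimator-specific API tree can be generated from the same macro. A hypothetical BUILD invocation, sketched in Starlark with an illustrative target name and package (not part of this patch):

    load("//tensorflow/tools/api/generator:api_gen.bzl",
         "ESTIMATOR_API_INIT_FILES", "gen_api_init_files")

    gen_api_init_files(
        name = "estimator_python_api_gen",
        api_name = "estimator",
        package = "tensorflow.python.estimator",
        output_files = ESTIMATOR_API_INIT_FILES,
    )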
diff --git a/tensorflow/tools/api/generator/create_python_api.py b/tensorflow/tools/api/generator/create_python_api.py
index 9f210ad42b..bc8c0a7efe 100644
--- a/tensorflow/tools/api/generator/create_python_api.py
+++ b/tensorflow/tools/api/generator/create_python_api.py
@@ -25,10 +25,10 @@ import os
import sys
from tensorflow.python.util import tf_decorator
+from tensorflow.python.util import tf_export
+API_ATTRS = tf_export.API_ATTRS
-_API_CONSTANTS_ATTR = '_tf_api_constants'
-_API_NAMES_ATTR = '_tf_api_names'
_DEFAULT_PACKAGE = 'tensorflow.python'
_GENFILES_DIR_SUFFIX = 'genfiles/'
_SYMBOLS_TO_SKIP_EXPLICITLY = {
@@ -159,12 +159,13 @@ __all__.remove('print_function')
return module_text_map
-def get_api_init_text(package):
+def get_api_init_text(package, api_name):
"""Get a map from destination module to __init__.py code for that module.
Args:
package: Base python package containing python with target tf_export
decorators.
+ api_name: API you want to generate (e.g. `tensorflow` or `estimator`).
Returns:
A dictionary where
@@ -192,7 +193,7 @@ def get_api_init_text(package):
attr = getattr(module, module_contents_name)
# If attr is _tf_api_constants attribute, then add the constants.
- if module_contents_name == _API_CONSTANTS_ATTR:
+ if module_contents_name == API_ATTRS[api_name].constants:
for exports, value in attr:
for export in exports:
names = export.split('.')
@@ -201,15 +202,12 @@ def get_api_init_text(package):
-1, dest_module, module.__name__, value, names[-1])
continue
- try:
- _, attr = tf_decorator.unwrap(attr)
- except Exception as e:
- print('5555: %s %s' % (module, module_contents_name), file=sys.stderr)
- raise e
+ _, attr = tf_decorator.unwrap(attr)
# If attr is a symbol with _tf_api_names attribute, then
# add import for it.
- if hasattr(attr, '__dict__') and _API_NAMES_ATTR in attr.__dict__:
- for export in attr._tf_api_names: # pylint: disable=protected-access
+ if (hasattr(attr, '__dict__') and
+ API_ATTRS[api_name].names in attr.__dict__):
+ for export in getattr(attr, API_ATTRS[api_name].names): # pylint: disable=protected-access
names = export.split('.')
dest_module = '.'.join(names[:-1])
module_code_builder.add_import(
@@ -246,7 +244,7 @@ def get_module(dir_path, relative_to_dir):
relative_to_dir: Get module relative to this directory.
Returns:
- module that corresponds to the given directory.
+ Name of module that corresponds to the given directory.
"""
dir_path = dir_path[len(relative_to_dir):]
# Convert path separators to '/' for easier parsing below.
@@ -255,7 +253,7 @@ def get_module(dir_path, relative_to_dir):
def create_api_files(
- output_files, package, root_init_template, output_dir):
+ output_files, package, root_init_template, output_dir, api_name):
"""Creates __init__.py files for the Python API.
Args:
@@ -267,6 +265,7 @@ def create_api_files(
"#API IMPORTS PLACEHOLDER" comment in the template file will be replaced
with imports.
output_dir: output API root directory.
+ api_name: API you want to generate (e.g. `tensorflow` or `estimator`).
Raises:
ValueError: if an output file is not under api/ directory,
@@ -283,7 +282,7 @@ def create_api_files(
os.makedirs(os.path.dirname(file_path))
open(file_path, 'a').close()
- module_text_map = get_api_init_text(package)
+ module_text_map = get_api_init_text(package, api_name)
# Add imports to output files.
missing_output_files = []
@@ -334,6 +333,10 @@ def main():
help='Directory where generated output files are placed. '
'gendir should be a prefix of apidir. Also, apidir '
'should be a prefix of every directory in outputs.')
+ parser.add_argument(
+ '--apiname', required=True, type=str,
+ choices=API_ATTRS.keys(),
+ help='The API you want to generate.')
args = parser.parse_args()
@@ -347,8 +350,8 @@ def main():
# Populate `sys.modules` with modules containing tf_export().
importlib.import_module(args.package)
- create_api_files(
- outputs, args.package, args.root_init_template, args.apidir)
+ create_api_files(outputs, args.package, args.root_init_template,
+ args.apidir, args.apiname)
if __name__ == '__main__':
diff --git a/tensorflow/tools/api/generator/create_python_api_test.py b/tensorflow/tools/api/generator/create_python_api_test.py
index 986340cf6d..651ec9d040 100644
--- a/tensorflow/tools/api/generator/create_python_api_test.py
+++ b/tensorflow/tools/api/generator/create_python_api_test.py
@@ -57,7 +57,8 @@ class CreatePythonApiTest(test.TestCase):
def testFunctionImportIsAdded(self):
imports = create_python_api.get_api_init_text(
- package=create_python_api._DEFAULT_PACKAGE)
+ package=create_python_api._DEFAULT_PACKAGE,
+ api_name='tensorflow')
expected_import = (
'from tensorflow.python.test_module '
'import test_op as test_op1')
@@ -73,7 +74,8 @@ class CreatePythonApiTest(test.TestCase):
def testClassImportIsAdded(self):
imports = create_python_api.get_api_init_text(
- package=create_python_api._DEFAULT_PACKAGE)
+ package=create_python_api._DEFAULT_PACKAGE,
+ api_name='tensorflow')
expected_import = ('from tensorflow.python.test_module '
'import TestClass')
self.assertTrue(
@@ -82,7 +84,8 @@ class CreatePythonApiTest(test.TestCase):
def testConstantIsAdded(self):
imports = create_python_api.get_api_init_text(
- package=create_python_api._DEFAULT_PACKAGE)
+ package=create_python_api._DEFAULT_PACKAGE,
+ api_name='tensorflow')
expected = ('from tensorflow.python.test_module '
'import _TEST_CONSTANT')
self.assertTrue(expected in str(imports),
diff --git a/tensorflow/tools/api/golden/tensorflow.data.-dataset.pbtxt b/tensorflow/tools/api/golden/tensorflow.data.-dataset.pbtxt
index 8e7e945ed1..834f0954d5 100644
--- a/tensorflow/tools/api/golden/tensorflow.data.-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.data.-dataset.pbtxt
@@ -24,7 +24,7 @@ tf_class {
}
member_method {
name: "batch"
- argspec: "args=[\'self\', \'batch_size\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'False\'], "
}
member_method {
name: "cache"
@@ -80,7 +80,7 @@ tf_class {
}
member_method {
name: "padded_batch"
- argspec: "args=[\'self\', \'batch_size\', \'padded_shapes\', \'padding_values\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ argspec: "args=[\'self\', \'batch_size\', \'padded_shapes\', \'padding_values\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'False\'], "
}
member_method {
name: "prefetch"
diff --git a/tensorflow/tools/api/golden/tensorflow.data.-fixed-length-record-dataset.pbtxt b/tensorflow/tools/api/golden/tensorflow.data.-fixed-length-record-dataset.pbtxt
index 5cfb2fd2f0..4d854a4cee 100644
--- a/tensorflow/tools/api/golden/tensorflow.data.-fixed-length-record-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.data.-fixed-length-record-dataset.pbtxt
@@ -25,7 +25,7 @@ tf_class {
}
member_method {
name: "batch"
- argspec: "args=[\'self\', \'batch_size\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'False\'], "
}
member_method {
name: "cache"
@@ -81,7 +81,7 @@ tf_class {
}
member_method {
name: "padded_batch"
- argspec: "args=[\'self\', \'batch_size\', \'padded_shapes\', \'padding_values\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ argspec: "args=[\'self\', \'batch_size\', \'padded_shapes\', \'padding_values\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'False\'], "
}
member_method {
name: "prefetch"
diff --git a/tensorflow/tools/api/golden/tensorflow.data.-t-f-record-dataset.pbtxt b/tensorflow/tools/api/golden/tensorflow.data.-t-f-record-dataset.pbtxt
index 3327e5b274..601f095a60 100644
--- a/tensorflow/tools/api/golden/tensorflow.data.-t-f-record-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.data.-t-f-record-dataset.pbtxt
@@ -25,7 +25,7 @@ tf_class {
}
member_method {
name: "batch"
- argspec: "args=[\'self\', \'batch_size\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'False\'], "
}
member_method {
name: "cache"
@@ -81,7 +81,7 @@ tf_class {
}
member_method {
name: "padded_batch"
- argspec: "args=[\'self\', \'batch_size\', \'padded_shapes\', \'padding_values\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ argspec: "args=[\'self\', \'batch_size\', \'padded_shapes\', \'padding_values\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'False\'], "
}
member_method {
name: "prefetch"
diff --git a/tensorflow/tools/api/golden/tensorflow.data.-text-line-dataset.pbtxt b/tensorflow/tools/api/golden/tensorflow.data.-text-line-dataset.pbtxt
index 9d59375282..587829a4c0 100644
--- a/tensorflow/tools/api/golden/tensorflow.data.-text-line-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.data.-text-line-dataset.pbtxt
@@ -25,7 +25,7 @@ tf_class {
}
member_method {
name: "batch"
- argspec: "args=[\'self\', \'batch_size\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'False\'], "
}
member_method {
name: "cache"
@@ -81,7 +81,7 @@ tf_class {
}
member_method {
name: "padded_batch"
- argspec: "args=[\'self\', \'batch_size\', \'padded_shapes\', \'padding_values\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ argspec: "args=[\'self\', \'batch_size\', \'padded_shapes\', \'padding_values\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'False\'], "
}
member_method {
name: "prefetch"
diff --git a/tensorflow/tools/api/golden/tensorflow.image.pbtxt b/tensorflow/tools/api/golden/tensorflow.image.pbtxt
index 87543e374b..32fb9183e6 100644
--- a/tensorflow/tools/api/golden/tensorflow.image.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.image.pbtxt
@@ -54,7 +54,7 @@ tf_module {
}
member_method {
name: "decode_image"
- argspec: "args=[\'contents\', \'channels\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ argspec: "args=[\'contents\', \'channels\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'uint8\'>\", \'None\'], "
}
member_method {
name: "decode_jpeg"
diff --git a/tensorflow/tools/api/golden/tensorflow.saved_model.builder.-saved-model-builder.pbtxt b/tensorflow/tools/api/golden/tensorflow.saved_model.builder.-saved-model-builder.pbtxt
index ca8e5884b1..83bd703540 100644
--- a/tensorflow/tools/api/golden/tensorflow.saved_model.builder.-saved-model-builder.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.saved_model.builder.-saved-model-builder.pbtxt
@@ -8,11 +8,11 @@ tf_class {
}
member_method {
name: "add_meta_graph"
- argspec: "args=[\'self\', \'tags\', \'signature_def_map\', \'assets_collection\', \'legacy_init_op\', \'clear_devices\', \'main_op\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\', \'False\'], "
+ argspec: "args=[\'self\', \'tags\', \'signature_def_map\', \'assets_collection\', \'legacy_init_op\', \'clear_devices\', \'main_op\', \'strip_default_attrs\', \'saver\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\', \'False\', \'None\'], "
}
member_method {
name: "add_meta_graph_and_variables"
- argspec: "args=[\'self\', \'sess\', \'tags\', \'signature_def_map\', \'assets_collection\', \'legacy_init_op\', \'clear_devices\', \'main_op\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\', \'False\'], "
+ argspec: "args=[\'self\', \'sess\', \'tags\', \'signature_def_map\', \'assets_collection\', \'legacy_init_op\', \'clear_devices\', \'main_op\', \'strip_default_attrs\', \'saver\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\', \'False\', \'None\'], "
}
member_method {
name: "save"
diff --git a/tensorflow/tools/ci_build/builds/pip.sh b/tensorflow/tools/ci_build/builds/pip.sh
index 5fa75e1d61..883bb93647 100755
--- a/tensorflow/tools/ci_build/builds/pip.sh
+++ b/tensorflow/tools/ci_build/builds/pip.sh
@@ -322,6 +322,10 @@ create_activate_virtualenv_and_install_tensorflow() {
pip install -v ${PIP_FLAGS} ${WHL_PATH} || \
die "pip install (forcing to reinstall tensorflow) FAILED"
echo "Successfully installed pip package ${TF_WHEEL_PATH}"
+
+ # Force downgrade setuptools.
+ pip install --upgrade setuptools==39.1.0
+
}
################################################################################
diff --git a/tensorflow/tools/ci_build/copy_binary.py b/tensorflow/tools/ci_build/copy_binary.py
index 420d390d2b..148526492d 100755
--- a/tensorflow/tools/ci_build/copy_binary.py
+++ b/tensorflow/tools/ci_build/copy_binary.py
@@ -32,7 +32,8 @@ import shutil
import tempfile
import zipfile
-TF_NIGHTLY_REGEX = r"(.+)tf_nightly(|_gpu)-(\d\.\d\.\d.dev[\d]{0,8})-(.+)\.whl"
+TF_NIGHTLY_REGEX = (r"(.+)tf_nightly(|_gpu)-(\d\.[\d]{1,2}"
+                    r"\.\d.dev[\d]{0,8})-(.+)\.whl")
BINARY_STRING_TEMPLATE = "%s-%s-%s.whl"
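The widened regex now tolerates a one- or two-digit minor version. A quick sanity check, assuming a hypothetical nightly wheel path (the file name below is illustrative):

    import re

    TF_NIGHTLY_REGEX = (r"(.+)tf_nightly(|_gpu)-(\d\.[\d]{1,2}"
                        r"\.\d.dev[\d]{0,8})-(.+)\.whl")

    path = "/tmp/pkg/tf_nightly_gpu-1.10.0.dev20180612-cp36-cp36m-linux_x86_64.whl"
    match = re.search(TF_NIGHTLY_REGEX, path)
    assert match is not None
    assert match.group(3) == "1.10.0.dev20180612"  # two-digit minor now matches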
diff --git a/tensorflow/tools/ci_build/linux/mkl/basic-mkl-test.sh b/tensorflow/tools/ci_build/linux/mkl/basic-mkl-test.sh
new file mode 100755
index 0000000000..10a09a415a
--- /dev/null
+++ b/tensorflow/tools/ci_build/linux/mkl/basic-mkl-test.sh
@@ -0,0 +1,29 @@
+#!/usr/bin/env bash
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+#
+# Usage: basic_mkl_test.sh
+
+# Helper function to traverse directories up until given file is found.
+function upsearch () {
+ test / == "$PWD" && return || \
+ test -e "$1" && echo "$PWD" && return || \
+ cd .. && upsearch "$1"
+}
+
+# Set up WORKSPACE.
+WORKSPACE="${WORKSPACE:-$(upsearch WORKSPACE)}"
+
+BUILD_TAG=mkl-ci-test CI_BUILD_USER_FORCE_BADNAME=yes ${WORKSPACE}/tensorflow/tools/ci_build/ci_build.sh cpu tensorflow/tools/ci_build/linux/cpu/run_mkl.sh
diff --git a/tensorflow/tools/docker/Dockerfile.devel b/tensorflow/tools/docker/Dockerfile.devel
index 406d134699..57a491255e 100644
--- a/tensorflow/tools/docker/Dockerfile.devel
+++ b/tensorflow/tools/docker/Dockerfile.devel
@@ -76,7 +76,7 @@ RUN mkdir /bazel && \
# Download and build TensorFlow.
WORKDIR /tensorflow
-RUN git clone --branch=r1.8 --depth=1 https://github.com/tensorflow/tensorflow.git .
+RUN git clone --branch=r1.9 --depth=1 https://github.com/tensorflow/tensorflow.git .
# TODO(craigcitro): Don't install the pip package, since it makes it
# more difficult to experiment with local changes. Instead, just add
diff --git a/tensorflow/tools/docker/Dockerfile.devel-cpu-mkl b/tensorflow/tools/docker/Dockerfile.devel-cpu-mkl
index a6cd44ced1..6796ad70e5 100644
--- a/tensorflow/tools/docker/Dockerfile.devel-cpu-mkl
+++ b/tensorflow/tools/docker/Dockerfile.devel-cpu-mkl
@@ -3,7 +3,7 @@ FROM tensorflow/tensorflow:latest-devel
LABEL maintainer="Clayne Robison<clayne.b.robison@intel.com>"
# These arguments are parameterized. Use --build-args to override.
-ARG TF_BRANCH=r1.8
+ARG TF_BRANCH=r1.9
ARG WHL_DIR=/whl
RUN apt-get update && apt-get install -y --no-install-recommends \
diff --git a/tensorflow/tools/docker/Dockerfile.devel-gpu b/tensorflow/tools/docker/Dockerfile.devel-gpu
index e4dcce9cdd..204b5b4dba 100644
--- a/tensorflow/tools/docker/Dockerfile.devel-gpu
+++ b/tensorflow/tools/docker/Dockerfile.devel-gpu
@@ -85,7 +85,7 @@ RUN mkdir /bazel && \
# Download and build TensorFlow.
WORKDIR /tensorflow
-RUN git clone --branch=r1.8 --depth=1 https://github.com/tensorflow/tensorflow.git .
+RUN git clone --branch=r1.9 --depth=1 https://github.com/tensorflow/tensorflow.git .
# Configure the build for our CUDA configuration.
ENV CI_BUILD_PYTHON python
diff --git a/tensorflow/tools/graph_transforms/fold_constants_lib.cc b/tensorflow/tools/graph_transforms/fold_constants_lib.cc
index 85660f94a8..f858411876 100644
--- a/tensorflow/tools/graph_transforms/fold_constants_lib.cc
+++ b/tensorflow/tools/graph_transforms/fold_constants_lib.cc
@@ -117,6 +117,31 @@ Status ReplaceSendRecvs(const GraphDef& original_graph_def,
return Status::OK();
}
+Status RewriteInputsAsPlaceholders(const TransformFuncContext& context,
+ GraphDef* graph_def) {
+ std::unordered_set<string> input_names;
+ for (const string& input_name : context.input_names) {
+ input_names.insert(ParseTensorName(input_name).first.ToString());
+ }
+
+ for (NodeDef& node : *graph_def->mutable_node()) {
+ if (input_names.find(node.name()) == input_names.end()) {
+ continue;
+ }
+ if (node.op() == "PlaceholderWithDefault") {
+ node.set_op("Placeholder");
+ node.clear_input();
+ } else if (node.op() != "Placeholder") {
+ return errors::InvalidArgument(
+ "Input '", node.name(),
+ "' was expected to be a Placeholder or PlaceholderWithDefault op, "
+ "but was ",
+ node.op());
+ }
+ }
+ return Status::OK();
+}
+
Status RemoveUnusedNodes(const GraphDef& input_graph_def,
const TransformFuncContext& context,
GraphDef* output_graph_def) {
@@ -165,6 +190,7 @@ Status RemoveUnusedNodes(const GraphDef& input_graph_def,
input_graph_def,
[&](const NodeDef& node) { return used_nodes.count(node.name()) > 0; },
output_graph_def);
+ TF_RETURN_IF_ERROR(RewriteInputsAsPlaceholders(context, output_graph_def));
return Status::OK();
}
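RewriteInputsAsPlaceholders only touches nodes named as graph inputs: a PlaceholderWithDefault is downgraded to a plain Placeholder with its inputs cleared, and any other non-Placeholder op is rejected. A rough Python rendering of that rule, using plain dicts as stand-ins for NodeDef protos (purely illustrative, not the graph_transforms API):

    def rewrite_inputs_as_placeholders(input_names, nodes):
      for node in nodes:
        if node["name"] not in input_names:
          continue
        if node["op"] == "PlaceholderWithDefault":
          node["op"] = "Placeholder"  # drop the default-value input path
          node["input"] = []
        elif node["op"] != "Placeholder":
          raise ValueError(
              "Input '%s' was expected to be a Placeholder or "
              "PlaceholderWithDefault op, but was %s"
              % (node["name"], node["op"]))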
diff --git a/tensorflow/tools/graph_transforms/fold_constants_test.cc b/tensorflow/tools/graph_transforms/fold_constants_test.cc
index a082399a87..dcdc3c2906 100644
--- a/tensorflow/tools/graph_transforms/fold_constants_test.cc
+++ b/tensorflow/tools/graph_transforms/fold_constants_test.cc
@@ -330,48 +330,6 @@ class ConstantFoldingTest : public ::testing::Test {
EXPECT_EQ(0, node_map.count("unused"));
}
- void TestRemoveUnusedNodesMultipleOutputs() {
- using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
- auto root = tensorflow::Scope::NewRootScope();
-
- // a b
- // \ /
- // shape_n
- // \ /
- // c
- auto a = Placeholder(root.WithOpName("a"), DT_FLOAT);
- auto b = Placeholder(root.WithOpName("b"), DT_FLOAT);
- auto shape_n = ShapeN(root.WithOpName("shape_n"), {Output(a), Output(b)});
- auto c = Add(root.WithOpName("c"), shape_n[0], shape_n[1]);
-
- GraphDef graph_def;
- TF_ASSERT_OK(root.ToGraphDef(&graph_def));
- GraphDef result_graph_def;
- TF_ASSERT_OK(graph_transforms::RemoveUnusedNodes(
- graph_def, {{shape_n[0].name()}, {"c"}}, &result_graph_def));
-
- // Only one output of shape_n node is fed input. Hence the graph search
- // should propagate to inputs of shape_n. Nothing to remove here.
- std::map<string, const NodeDef*> node_map;
- graph_transforms::MapNamesToNodes(result_graph_def, &node_map);
- EXPECT_EQ(1, node_map.count("a"));
- EXPECT_EQ(1, node_map.count("b"));
- EXPECT_EQ(1, node_map.count("c"));
-
- result_graph_def.Clear();
- TF_ASSERT_OK(graph_transforms::RemoveUnusedNodes(
- graph_def, {{shape_n[0].name(), shape_n[1].name()}, {"c"}},
- &result_graph_def));
-
- // Both outputs of the shape_n node are fed as inputs, so shape_n is effectively
- // dead and its inputs should be removed.
- node_map.clear();
- graph_transforms::MapNamesToNodes(result_graph_def, &node_map);
- EXPECT_EQ(0, node_map.count("a"));
- EXPECT_EQ(0, node_map.count("b"));
- EXPECT_EQ(1, node_map.count("c"));
- }
-
void TestMaxConstantSizeInBytes() {
auto root = tensorflow::Scope::NewRootScope();
@@ -431,10 +389,6 @@ TEST_F(ConstantFoldingTest, TestReplaceSendRecvsPrefixNames) {
TEST_F(ConstantFoldingTest, TestRemoveUnusedNodes) { TestRemoveUnusedNodes(); }
-TEST_F(ConstantFoldingTest, TestRemoveUnusedNodesMultipleOutputs) {
- TestRemoveUnusedNodesMultipleOutputs();
-}
-
TEST_F(ConstantFoldingTest, TestMaxConstantSizeInBytes) {
TestMaxConstantSizeInBytes();
}
diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD
index e113565f45..9d4148c07f 100644
--- a/tensorflow/tools/pip_package/BUILD
+++ b/tensorflow/tools/pip_package/BUILD
@@ -59,6 +59,7 @@ COMMON_PIP_DEPS = [
"//tensorflow/contrib/autograph/converters:converters",
"//tensorflow/contrib/autograph/converters:test_lib",
"//tensorflow/contrib/autograph/impl:impl",
+ "//tensorflow/contrib/autograph/operators:operators",
"//tensorflow/contrib/autograph/pyct:pyct",
"//tensorflow/contrib/autograph/pyct/static_analysis:static_analysis",
"//tensorflow/contrib/boosted_trees:boosted_trees_pip",
diff --git a/tensorflow/tools/pip_package/setup.py b/tensorflow/tools/pip_package/setup.py
index d25a9e77b1..97f625e7e9 100644
--- a/tensorflow/tools/pip_package/setup.py
+++ b/tensorflow/tools/pip_package/setup.py
@@ -45,7 +45,7 @@ DOCLINES = __doc__.split('\n')
# This version string is semver compatible, but incompatible with pip.
# For pip, we will remove all '-' characters from this string, and use the
# result for pip.
-_VERSION = '1.8.0'
+_VERSION = '1.9.0-rc0'
REQUIRED_PACKAGES = [
'absl-py >= 0.1.6',
@@ -54,6 +54,7 @@ REQUIRED_PACKAGES = [
'numpy >= 1.13.3',
'six >= 1.10.0',
'protobuf >= 3.4.0',
+ 'setuptools <= 39.1.0',
'tensorboard >= 1.8.0, < 1.9.0',
'termcolor >= 1.1.0',
]
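The version bump above leans on the comment kept in context: _VERSION stays semver-style ('1.9.0-rc0'), and the pip-facing version is derived by stripping every '-' character. A short Python sketch of that derivation, shown only to illustrate the comment (not code from setup.py):

    # semver pre-release -> pip/PEP 440 pre-release
    _VERSION = '1.9.0-rc0'
    pip_version = _VERSION.replace('-', '')
    assert pip_version == '1.9.0rc0'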
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index 8483a4d767..3935992a00 100644
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -107,13 +107,12 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
tf_http_archive(
name = "eigen_archive",
urls = [
- "https://mirror.bazel.build/bitbucket.org/eigen/eigen/get/6913f0cf7d06.tar.gz",
- "https://bitbucket.org/eigen/eigen/get/6913f0cf7d06.tar.gz",
+ "https://mirror.bazel.build/bitbucket.org/eigen/eigen/get/267806ed9b4f.tar.gz",
+ "https://bitbucket.org/eigen/eigen/get/267806ed9b4f.tar.gz",
],
- sha256 = "791b836cacd03e20bae5bdd25f1c4a5505a0a9975ba94a61eb4e2631fbd1d53a",
- strip_prefix = "eigen-eigen-6913f0cf7d06",
+ sha256 = "ade57357093463cab9e4e51cd5749c81483a75451b1471a3ebc73f9c1d14043b",
+ strip_prefix = "eigen-eigen-267806ed9b4f",
build_file = clean_dep("//third_party:eigen.BUILD"),
- patch_file = clean_dep("//third_party:eigen_fix_cuda_compilation.patch")
)
tf_http_archive(
@@ -452,11 +451,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
tf_http_archive(
name = "llvm",
urls = [
- "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/40c66c3d40377cf85640b3a35e6ec5c5b1cbc41f.tar.gz",
- "https://github.com/llvm-mirror/llvm/archive/40c66c3d40377cf85640b3a35e6ec5c5b1cbc41f.tar.gz",
+ "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/582e5dd5553e3089fef97f9ab5a3f063e0160fa9.tar.gz",
+ "https://github.com/llvm-mirror/llvm/archive/582e5dd5553e3089fef97f9ab5a3f063e0160fa9.tar.gz",
],
- sha256 = "6f782a0d2e9d7946bdf20807e0fcd8f5eaed8afd93bdd610cdefbe9435ca551f",
- strip_prefix = "llvm-40c66c3d40377cf85640b3a35e6ec5c5b1cbc41f",
+ sha256 = "9a0e63469ae5a546e0c84b778955f0febabfc8497d312324546ec7d0db68430e",
+ strip_prefix = "llvm-582e5dd5553e3089fef97f9ab5a3f063e0160fa9",
build_file = clean_dep("//third_party/llvm:llvm.BUILD"),
)
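Each tf_http_archive entry in workspace.bzl pins a mirror URL, an upstream URL, a sha256, and a strip_prefix, so the Eigen and LLVM bumps above update all of them together (the Eigen entry also drops eigen_fix_cuda_compilation.patch, deleted at the end of this change). A small Python sketch of how the sha256 for such an entry could be recomputed when updating it, assuming only the standard library and the upstream URL from the new Eigen entry:

    import hashlib
    import urllib.request

    url = 'https://bitbucket.org/eigen/eigen/get/267806ed9b4f.tar.gz'
    data = urllib.request.urlopen(url).read()    # download the pinned tarball
    print(hashlib.sha256(data).hexdigest())      # value for the sha256 = "..." field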
diff --git a/third_party/clang_toolchain/download_clang.bzl b/third_party/clang_toolchain/download_clang.bzl
index 02d2b78067..a203245005 100644
--- a/third_party/clang_toolchain/download_clang.bzl
+++ b/third_party/clang_toolchain/download_clang.bzl
@@ -35,18 +35,18 @@ def download_clang(repo_ctx, out_folder):
# Latest CLANG_REVISION and CLANG_SUB_REVISION of Chromium's release
# can be found in https://chromium.googlesource.com/chromium/src/tools/clang/+/master/scripts/update.py
- CLANG_REVISION = '332335'
+ CLANG_REVISION = '332838'
CLANG_SUB_REVISION = 1
package_version = '%s-%s' % (CLANG_REVISION, CLANG_SUB_REVISION)
checksums = {
'Linux_x64':
- '5c234e0bc43b2386984ac34ac9c200c35686f2f7fa5ded0db031055bbc7f3e52',
+ 'b9ef55de7500778f366039dbe62d1632074a3ef3673022eabf4e59d405730968',
'Mac':
- '69b94f16d261c0922c3853cdad768776f454dece2948363f1c4e20bc2ddbf95d',
+ '30d808512763c98cecf15f7bb654d845de3e8d065a95f5c5b6b3459254cc98d6',
'Win':
- '76c8897abf032f3e23598275517da60090f53cf35b673481f41fa98752d1ad37',
+ '277e799a190b22727c26b09986c0cedbd667a189f425318f421addf6a21ca4bd',
}
platform_folder = _get_platform_folder(repo_ctx.os.name)
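The Clang toolchain bump updates CLANG_REVISION and all three per-platform checksums in lockstep; package_version is simply '<revision>-<sub_revision>' and the checksum is selected by the platform folder. A minimal Python sketch of that selection (revision and checksums copied from the updated table; hard-coding 'Linux_x64' stands in for _get_platform_folder(repo_ctx.os.name) and is an assumption for illustration):

    CLANG_REVISION = '332838'
    CLANG_SUB_REVISION = 1
    package_version = '%s-%s' % (CLANG_REVISION, CLANG_SUB_REVISION)   # '332838-1'

    checksums = {
        'Linux_x64': 'b9ef55de7500778f366039dbe62d1632074a3ef3673022eabf4e59d405730968',
        'Mac': '30d808512763c98cecf15f7bb654d845de3e8d065a95f5c5b6b3459254cc98d6',
        'Win': '277e799a190b22727c26b09986c0cedbd667a189f425318f421addf6a21ca4bd',
    }

    platform_folder = 'Linux_x64'                # stand-in for the repo_ctx lookup
    expected_sha256 = checksums[platform_folder]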
diff --git a/third_party/eigen_fix_cuda_compilation.patch b/third_party/eigen_fix_cuda_compilation.patch
deleted file mode 100644
index b921a7c31d..0000000000
--- a/third_party/eigen_fix_cuda_compilation.patch
+++ /dev/null
@@ -1,38 +0,0 @@
-diff --git a/Eigen/src/Core/ProductEvaluators.h b/Eigen/src/Core/ProductEvaluators.h
---- a/Eigen/src/Core/ProductEvaluators.h
-+++ b/Eigen/src/Core/ProductEvaluators.h
-@@ -137,7 +137,7 @@ struct Assignment<DstXprType, Product<Lh
- typename enable_if<(Options==DefaultProduct || Options==AliasFreeProduct)>::type>
- {
- typedef Product<Lhs,Rhs,Options> SrcXprType;
-- static EIGEN_STRONG_INLINE
-+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
- void run(DstXprType &dst, const SrcXprType &src, const internal::assign_op<Scalar,Scalar> &)
- {
- Index dstRows = src.rows();
-@@ -390,7 +390,7 @@ struct generic_product_impl<Lhs,Rhs,Dens
- typedef typename Product<Lhs,Rhs>::Scalar Scalar;
-
- template<typename Dst>
-- static EIGEN_STRONG_INLINE void evalTo(Dst& dst, const Lhs& lhs, const Rhs& rhs)
-+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalTo(Dst& dst, const Lhs& lhs, const Rhs& rhs)
- {
- // Same as: dst.noalias() = lhs.lazyProduct(rhs);
- // but easier on the compiler side
-@@ -398,14 +398,14 @@ struct generic_product_impl<Lhs,Rhs,Dens
- }
-
- template<typename Dst>
-- static EIGEN_STRONG_INLINE void addTo(Dst& dst, const Lhs& lhs, const Rhs& rhs)
-+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void addTo(Dst& dst, const Lhs& lhs, const Rhs& rhs)
- {
- // dst.noalias() += lhs.lazyProduct(rhs);
- call_assignment_no_alias(dst, lhs.lazyProduct(rhs), internal::add_assign_op<typename Dst::Scalar,Scalar>());
- }
-
- template<typename Dst>
-- static EIGEN_STRONG_INLINE void subTo(Dst& dst, const Lhs& lhs, const Rhs& rhs)
-+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void subTo(Dst& dst, const Lhs& lhs, const Rhs& rhs)
- {
- // dst.noalias() -= lhs.lazyProduct(rhs);
- call_assignment_no_alias(dst, lhs.lazyProduct(rhs), internal::sub_assign_op<typename Dst::Scalar,Scalar>());