author    Yifei Feng <yifeif@google.com>  2018-05-08 17:14:55 -0700
committer Yifei Feng <yifeif@google.com>  2018-05-08 17:14:55 -0700
commit    24c9174f84be94043e58ac4536295a3d44d82678 (patch)
tree      92f6cfd82d9ad2c295ec8a45bd7df8d5b5d6ee0f
parent    c0fb9413914d983cad2ea6bb4997033a1f0dd722 (diff)
parent    14d5f219f33b1ab8e0a67b84d97204d046adb91f (diff)
Merge commit for internal changes
-rw-r--r--configure.py4
-rw-r--r--tensorflow/c/c_api_experimental.cc48
-rw-r--r--tensorflow/c/c_api_experimental.h23
-rw-r--r--tensorflow/c/eager/tape.h36
-rw-r--r--tensorflow/compiler/aot/tests/BUILD14
-rw-r--r--tensorflow/compiler/aot/tests/make_test_graphs.py29
-rw-r--r--tensorflow/compiler/aot/tests/test_graph_tfcond.config.pbtxt20
-rw-r--r--tensorflow/compiler/aot/tests/tfcompile_test.cc26
-rw-r--r--tensorflow/compiler/jit/BUILD24
-rw-r--r--tensorflow/compiler/jit/create_xla_launch_op.cc207
-rw-r--r--tensorflow/compiler/jit/create_xla_launch_op.h35
-rw-r--r--tensorflow/compiler/jit/create_xla_launch_op_test.cc145
-rw-r--r--tensorflow/compiler/jit/kernels/xla_launch_op.cc90
-rw-r--r--tensorflow/compiler/jit/kernels/xla_launch_op.h51
-rw-r--r--tensorflow/compiler/jit/xla_compile_on_demand_op.cc3
-rw-r--r--tensorflow/compiler/jit/xla_launch_util.cc18
-rw-r--r--tensorflow/compiler/jit/xla_launch_util.h15
-rw-r--r--tensorflow/compiler/jit/xla_tensor.h6
-rw-r--r--tensorflow/compiler/tests/BUILD8
-rw-r--r--tensorflow/compiler/tests/eager_test.py112
-rw-r--r--tensorflow/compiler/tests/stateless_random_ops_test.py9
-rw-r--r--tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc3
-rw-r--r--tensorflow/compiler/xla/BUILD1
-rw-r--r--tensorflow/compiler/xla/python/BUILD3
-rw-r--r--tensorflow/compiler/xla/python/local_computation_builder.cc315
-rw-r--r--tensorflow/compiler/xla/python/local_computation_builder.h206
-rw-r--r--tensorflow/compiler/xla/python/local_computation_builder.i53
-rw-r--r--tensorflow/compiler/xla/python/xla_client.py362
-rw-r--r--tensorflow/compiler/xla/service/BUILD21
-rw-r--r--tensorflow/compiler/xla/service/algebraic_simplifier.cc141
-rw-r--r--tensorflow/compiler/xla/service/algebraic_simplifier_test.cc203
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_compiler.cc3
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc104
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc10
-rw-r--r--tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc87
-rw-r--r--tensorflow/compiler/xla/service/cpu/dot_op_emitter.h17
-rw-r--r--tensorflow/compiler/xla/service/cpu/ir_emitter.cc55
-rw-r--r--tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc2
-rw-r--r--tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc40
-rw-r--r--tensorflow/compiler/xla/service/gpu/gemm_thunk.cc16
-rw-r--r--tensorflow/compiler/xla/service/gpu/gemm_thunk.h10
-rw-r--r--tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc30
-rw-r--r--tensorflow/compiler/xla/service/gpu/instruction_fusion.cc13
-rw-r--r--tensorflow/compiler/xla/service/gpu/instruction_fusion.h2
-rw-r--r--tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc56
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc10
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc78
-rw-r--r--tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc28
-rw-r--r--tensorflow/compiler/xla/service/hlo.proto3
-rw-r--r--tensorflow/compiler/xla/service/hlo_computation.cc52
-rw-r--r--tensorflow/compiler/xla/service/hlo_computation.h20
-rw-r--r--tensorflow/compiler/xla/service/hlo_creation_utils.cc2
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator.cc2101
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator.h48
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h2102
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_bfloat16.cc22
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_bool.cc22
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_complex64.cc22
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_double.cc22
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_float.cc22
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_half.cc22
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_int32.cc22
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_int64.cc22
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_int8.cc22
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_uint32.cc22
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_uint64.cc22
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_uint8.cc22
-rw-r--r--tensorflow/compiler/xla/service/hlo_graph_dumper.cc45
-rw-r--r--tensorflow/compiler/xla/service/hlo_graph_dumper.h5
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.cc129
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.h62
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction_test.cc11
-rw-r--r--tensorflow/compiler/xla/service/hlo_matchers.cc38
-rw-r--r--tensorflow/compiler/xla/service/hlo_matchers.h46
-rw-r--r--tensorflow/compiler/xla/service/hlo_matchers_test.cc37
-rw-r--r--tensorflow/compiler/xla/service/hlo_module.cc12
-rw-r--r--tensorflow/compiler/xla/service/hlo_module.h19
-rw-r--r--tensorflow/compiler/xla/service/hlo_pass_pipeline.cc10
-rw-r--r--tensorflow/compiler/xla/service/hlo_scheduling.cc26
-rw-r--r--tensorflow/compiler/xla/service/hlo_scheduling.h6
-rw-r--r--tensorflow/compiler/xla/service/hlo_verifier.cc71
-rw-r--r--tensorflow/compiler/xla/service/human_readable_profile_builder.h9
-rw-r--r--tensorflow/compiler/xla/service/instruction_fusion.cc23
-rw-r--r--tensorflow/compiler/xla/service/instruction_fusion_test.cc29
-rw-r--r--tensorflow/compiler/xla/service/liveness_util.cc22
-rw-r--r--tensorflow/compiler/xla/service/liveness_util_test.cc42
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/ir_array.cc22
-rw-r--r--tensorflow/compiler/xla/service/name_uniquer.cc2
-rw-r--r--tensorflow/compiler/xla/service/pattern_matcher.h34
-rw-r--r--tensorflow/compiler/xla/service/pattern_matcher_test.cc23
-rw-r--r--tensorflow/compiler/xla/service/shape_inference.cc4
-rw-r--r--tensorflow/compiler/xla/service/transpose_folding.cc54
-rw-r--r--tensorflow/compiler/xla/service/transpose_folding_test.cc219
-rw-r--r--tensorflow/compiler/xla/service_interface.h1
-rw-r--r--tensorflow/compiler/xla/statusor.h11
-rw-r--r--tensorflow/compiler/xla/statusor_test.cc8
-rw-r--r--tensorflow/compiler/xla/tests/BUILD3
-rw-r--r--tensorflow/compiler/xla/tests/client_library_test_base.cc29
-rw-r--r--tensorflow/compiler/xla/tests/client_library_test_base.h29
-rw-r--r--tensorflow/compiler/xla/tests/dot_operation_test.cc245
-rw-r--r--tensorflow/compiler/xla/tools/parser/hlo_parser.cc40
-rw-r--r--tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc24
-rw-r--r--tensorflow/contrib/autograph/pyct/static_analysis/BUILD14
-rw-r--r--tensorflow/contrib/autograph/pyct/static_analysis/cfg.py431
-rw-r--r--tensorflow/contrib/autograph/pyct/static_analysis/cfg_test.py252
-rw-r--r--tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py13
-rw-r--r--tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py2
-rw-r--r--tensorflow/contrib/checkpoint/__init__.py2
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py44
-rw-r--r--tensorflow/contrib/data/python/ops/batching.py47
-rw-r--r--tensorflow/contrib/distribute/python/combinations.py4
-rw-r--r--tensorflow/contrib/distribute/python/examples/simple_tfkeras_example.py33
-rw-r--r--tensorflow/contrib/distribute/python/minimize_loss_test.py115
-rw-r--r--tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py33
-rw-r--r--tensorflow/contrib/distribute/python/single_loss_example.py20
-rw-r--r--tensorflow/contrib/distribute/python/tpu_strategy.py13
-rw-r--r--tensorflow/contrib/distribute/python/values.py101
-rw-r--r--tensorflow/contrib/distributions/BUILD7
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py10
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/mvn_full_covariance_test.py31
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/mvn_tril_test.py39
-rw-r--r--tensorflow/contrib/distributions/python/ops/bijectors/chain.py44
-rw-r--r--tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py55
-rw-r--r--tensorflow/contrib/eager/python/tfe_test.py6
-rw-r--r--tensorflow/contrib/estimator/BUILD38
-rw-r--r--tensorflow/contrib/estimator/__init__.py3
-rw-r--r--tensorflow/contrib/estimator/python/estimator/export.py216
-rw-r--r--tensorflow/contrib/estimator/python/estimator/export_test.py391
-rw-r--r--tensorflow/contrib/estimator/python/estimator/head.py26
-rw-r--r--tensorflow/contrib/layers/BUILD2
-rw-r--r--tensorflow/contrib/learn/BUILD1
-rw-r--r--tensorflow/contrib/learn/python/learn/experiment.py13
-rw-r--r--tensorflow/contrib/lite/RELEASE.md8
-rw-r--r--tensorflow/contrib/lite/builtin_op_data.h3
-rw-r--r--tensorflow/contrib/lite/builtin_ops.h5
-rw-r--r--tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md53
-rw-r--r--tensorflow/contrib/lite/interpreter.h37
-rw-r--r--tensorflow/contrib/lite/java/BUILD27
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/res/layout-land/fragment_camera2_basic.xml92
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/res/layout/fragment_camera2_basic.xml79
-rw-r--r--tensorflow/contrib/lite/java/ovic/README.md93
-rw-r--r--tensorflow/contrib/lite/java/ovic/demo/app/AndroidManifest.xml48
-rw-r--r--tensorflow/contrib/lite/java/ovic/demo/app/BUILD29
-rw-r--r--tensorflow/contrib/lite/java/ovic/demo/app/OvicBenchmarker.java (renamed from tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicBenchmarker.java)6
-rw-r--r--tensorflow/contrib/lite/java/ovic/demo/app/OvicBenchmarkerActivity.java247
-rw-r--r--tensorflow/contrib/lite/java/ovic/demo/app/build.gradle58
-rw-r--r--tensorflow/contrib/lite/java/ovic/demo/app/res/drawable-mdpi/ic_launcher.pngbin0 -> 2381 bytes
-rw-r--r--tensorflow/contrib/lite/java/ovic/demo/app/res/drawable-xhdpi/ic_launcher.pngbin0 -> 5201 bytes
-rw-r--r--tensorflow/contrib/lite/java/ovic/demo/app/res/drawable/start_button_color.xml39
-rw-r--r--tensorflow/contrib/lite/java/ovic/demo/app/res/layout/activity_main.xml54
-rw-r--r--tensorflow/contrib/lite/java/ovic/demo/app/res/values/dimens.xml20
-rw-r--r--tensorflow/contrib/lite/java/ovic/demo/app/res/values/strings.xml22
-rw-r--r--tensorflow/contrib/lite/java/ovic/demo/build.gradle23
-rw-r--r--tensorflow/contrib/lite/java/ovic/demo/gradle.properties17
-rw-r--r--tensorflow/contrib/lite/java/ovic/demo/gradle/wrapper/gradle-wrapper.jarbin0 -> 53636 bytes
-rw-r--r--tensorflow/contrib/lite/java/ovic/demo/gradle/wrapper/gradle-wrapper.properties6
-rwxr-xr-xtensorflow/contrib/lite/java/ovic/demo/gradlew160
-rw-r--r--tensorflow/contrib/lite/java/ovic/demo/gradlew.bat90
-rw-r--r--tensorflow/contrib/lite/java/ovic/demo/settings.gradle1
-rw-r--r--tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicClassifier.java4
-rw-r--r--tensorflow/contrib/lite/java/ovic/src/test/java/org/tensorflow/ovic/OvicClassifierTest.java12
-rw-r--r--tensorflow/contrib/lite/java/ovic/src/testdata/BUILD19
-rw-r--r--tensorflow/contrib/lite/kernels/BUILD18
-rw-r--r--tensorflow/contrib/lite/kernels/comparisons.cc160
-rw-r--r--tensorflow/contrib/lite/kernels/comparisons_test.cc207
-rw-r--r--tensorflow/contrib/lite/kernels/internal/BUILD8
-rw-r--r--tensorflow/contrib/lite/kernels/internal/common.h14
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h144
-rw-r--r--tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h264
-rw-r--r--tensorflow/contrib/lite/kernels/pad.cc110
-rw-r--r--tensorflow/contrib/lite/kernels/pad_test.cc368
-rw-r--r--tensorflow/contrib/lite/kernels/register.cc10
-rw-r--r--tensorflow/contrib/lite/kernels/select.cc125
-rw-r--r--tensorflow/contrib/lite/kernels/select_test.cc143
-rw-r--r--tensorflow/contrib/lite/kernels/test_util.cc82
-rw-r--r--tensorflow/contrib/lite/kernels/test_util.h85
-rw-r--r--tensorflow/contrib/lite/model.cc9
-rw-r--r--tensorflow/contrib/lite/nnapi_delegate.cc9
-rw-r--r--tensorflow/contrib/lite/schema/schema.fbs25
-rw-r--r--[-rwxr-xr-x]tensorflow/contrib/lite/schema/schema_generated.h572
-rw-r--r--tensorflow/contrib/lite/testing/BUILD5
-rw-r--r--tensorflow/contrib/lite/testing/generate_examples.py189
-rw-r--r--tensorflow/contrib/lite/testing/generated_examples_zip_test.cc9
-rw-r--r--tensorflow/contrib/lite/toco/BUILD1
-rw-r--r--tensorflow/contrib/lite/toco/export_tensorflow.cc71
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h1
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc11
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc65
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/quantize.cc10
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_normalization.cc6
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_padv2_attributes.cc55
-rw-r--r--tensorflow/contrib/lite/toco/import_tensorflow.cc30
-rw-r--r--tensorflow/contrib/lite/toco/model.h37
-rw-r--r--tensorflow/contrib/lite/toco/tflite/operator.cc26
-rw-r--r--tensorflow/contrib/lite/toco/tflite/operator_test.cc1
-rw-r--r--tensorflow/contrib/lite/toco/toco_tooling.cc1
-rw-r--r--tensorflow/contrib/lite/toco/tooling_util.cc8
-rw-r--r--tensorflow/contrib/lite/toco/types.proto3
-rw-r--r--tensorflow/contrib/model_pruning/python/pruning.py30
-rw-r--r--tensorflow/contrib/model_pruning/python/pruning_utils.py51
-rw-r--r--tensorflow/contrib/model_pruning/python/pruning_utils_test.py62
-rw-r--r--tensorflow/contrib/tpu/BUILD1
-rw-r--r--tensorflow/contrib/tpu/ops/replication_ops.cc4
-rw-r--r--tensorflow/contrib/tpu/proto/BUILD10
-rw-r--r--tensorflow/contrib/tpu/proto/compilation_result.proto13
-rw-r--r--tensorflow/contrib/tpu/python/tpu/session_support.py34
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu.py70
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu_estimator.py26
-rw-r--r--tensorflow/core/api_def/base_api/api_def_MapAndBatchDataset.pbtxt35
-rw-r--r--tensorflow/core/api_def/base_api/api_def_MapAndBatchDatasetV2.pbtxt54
-rw-r--r--tensorflow/core/common_runtime/executor.cc26
-rw-r--r--tensorflow/core/common_runtime/function.cc10
-rw-r--r--tensorflow/core/common_runtime/profile_handler.h16
-rw-r--r--tensorflow/core/common_runtime/shape_refiner.cc126
-rw-r--r--tensorflow/core/common_runtime/shape_refiner.h14
-rw-r--r--tensorflow/core/common_runtime/shape_refiner_test.cc100
-rw-r--r--tensorflow/core/framework/api_def.proto6
-rw-r--r--tensorflow/core/framework/dataset.cc19
-rw-r--r--tensorflow/core/framework/dataset.h6
-rw-r--r--tensorflow/core/framework/function.cc2
-rw-r--r--tensorflow/core/framework/node_def_builder.cc17
-rw-r--r--tensorflow/core/framework/node_def_util.cc6
-rw-r--r--tensorflow/core/framework/op_def_builder.cc4
-rw-r--r--tensorflow/core/framework/op_gen_lib.cc5
-rw-r--r--tensorflow/core/framework/op_gen_lib_test.cc1
-rw-r--r--tensorflow/core/framework/op_kernel.cc2
-rw-r--r--tensorflow/core/framework/resource_mgr.h18
-rw-r--r--tensorflow/core/framework/shape_inference.cc29
-rw-r--r--tensorflow/core/framework/shape_inference.h7
-rw-r--r--tensorflow/core/framework/shape_inference_testutil.h2
-rw-r--r--tensorflow/core/graph/graph.cc2
-rw-r--r--tensorflow/core/graph/graph_constructor.cc10
-rw-r--r--tensorflow/core/graph/graph_constructor_test.cc2
-rw-r--r--tensorflow/core/graph/graph_def_builder.cc4
-rw-r--r--tensorflow/core/graph/graph_def_builder.h2
-rw-r--r--tensorflow/core/graph/graph_partition.cc2
-rw-r--r--tensorflow/core/graph/node_builder.cc2
-rw-r--r--tensorflow/core/graph/while_context.cc2
-rw-r--r--tensorflow/core/grappler/costs/op_level_cost_estimator.cc70
-rw-r--r--tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc26
-rw-r--r--tensorflow/core/grappler/optimizers/BUILD1
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer.h2
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc55
-rw-r--r--tensorflow/core/grappler/optimizers/constant_folding.cc203
-rw-r--r--tensorflow/core/grappler/optimizers/constant_folding.h5
-rw-r--r--tensorflow/core/grappler/optimizers/function_optimizer.cc140
-rw-r--r--tensorflow/core/grappler/optimizers/function_optimizer_test.cc142
-rw-r--r--tensorflow/core/grappler/optimizers/loop_optimizer.cc140
-rw-r--r--tensorflow/core/grappler/optimizers/loop_optimizer.h1
-rw-r--r--tensorflow/core/grappler/optimizers/loop_optimizer_test.cc107
-rw-r--r--tensorflow/core/grappler/optimizers/meta_optimizer_test.cc19
-rw-r--r--tensorflow/core/grappler/utils.h4
-rw-r--r--tensorflow/core/grappler/utils/functions.cc57
-rw-r--r--tensorflow/core/grappler/utils/functions.h21
-rw-r--r--tensorflow/core/grappler/utils/functions_test.cc38
-rw-r--r--tensorflow/core/kernels/data/dataset_utils.cc12
-rw-r--r--tensorflow/core/kernels/data/dataset_utils.h2
-rw-r--r--tensorflow/core/kernels/data/iterator_ops.cc2
-rw-r--r--tensorflow/core/kernels/data/map_and_batch_dataset_op.cc773
-rw-r--r--tensorflow/core/kernels/resource_variable_ops.cc9
-rw-r--r--tensorflow/core/kernels/scatter_functor_gpu.cu.cc2
-rw-r--r--tensorflow/core/ops/compat/ops_history.v1.pbtxt48
-rw-r--r--tensorflow/core/ops/dataset_ops.cc13
-rw-r--r--tensorflow/core/ops/ops.pbtxt48
-rw-r--r--tensorflow/core/platform/default/build_config.bzl8
-rw-r--r--tensorflow/core/platform/default/mutex.h4
-rw-r--r--tensorflow/docs_src/deploy/index.md4
-rw-r--r--tensorflow/docs_src/performance/xla/operation_semantics.md4
-rw-r--r--tensorflow/docs_src/programmers_guide/embedding.md2
-rw-r--r--tensorflow/go/op/wrappers.go262
-rw-r--r--tensorflow/python/BUILD3
-rw-r--r--tensorflow/python/debug/examples/debug_tflearn_iris.py2
-rw-r--r--tensorflow/python/eager/backprop.py5
-rw-r--r--tensorflow/python/eager/backprop_test.py10
-rw-r--r--tensorflow/python/eager/function.py143
-rw-r--r--tensorflow/python/eager/function_test.py15
-rw-r--r--tensorflow/python/eager/pywrap_tensor.cc6
-rw-r--r--tensorflow/python/eager/pywrap_tensor.h1
-rw-r--r--tensorflow/python/eager/pywrap_tfe_src.cc62
-rw-r--r--tensorflow/python/estimator/BUILD2
-rw-r--r--tensorflow/python/estimator/canned/dnn.py68
-rw-r--r--tensorflow/python/estimator/canned/dnn_testing_utils.py20
-rw-r--r--tensorflow/python/estimator/canned/head.py216
-rw-r--r--tensorflow/python/estimator/canned/head_test.py92
-rw-r--r--tensorflow/python/estimator/estimator.py346
-rw-r--r--tensorflow/python/estimator/estimator_test.py336
-rw-r--r--tensorflow/python/estimator/export/export.py325
-rw-r--r--tensorflow/python/estimator/export/export_output.py223
-rw-r--r--tensorflow/python/estimator/export/export_output_test.py110
-rw-r--r--tensorflow/python/estimator/export/export_test.py253
-rw-r--r--tensorflow/python/estimator/model_fn.py59
-rw-r--r--tensorflow/python/framework/function.py9
-rw-r--r--tensorflow/python/framework/ops.py17
-rwxr-xr-xtensorflow/python/keras/BUILD5
-rw-r--r--tensorflow/python/keras/_impl/keras/engine/base_layer.py2
-rw-r--r--tensorflow/python/keras/_impl/keras/engine/network.py6
-rw-r--r--tensorflow/python/keras/_impl/keras/engine/sequential_test.py39
-rw-r--r--tensorflow/python/keras/_impl/keras/engine/training.py234
-rw-r--r--tensorflow/python/keras/_impl/keras/engine/training_arrays.py4
-rw-r--r--tensorflow/python/keras/_impl/keras/engine/training_eager.py932
-rw-r--r--tensorflow/python/keras/_impl/keras/engine/training_test.py96
-rw-r--r--tensorflow/python/keras/_impl/keras/engine/training_utils.py91
-rw-r--r--tensorflow/python/keras/_impl/keras/model_subclassing_test.py39
-rw-r--r--tensorflow/python/kernel_tests/control_flow_ops_py_test.py17
-rw-r--r--tensorflow/python/kernel_tests/conv2d_transpose_test.py5
-rw-r--r--tensorflow/python/kernel_tests/distributions/util_test.py26
-rw-r--r--tensorflow/python/kernel_tests/linalg/BUILD5
-rw-r--r--tensorflow/python/kernel_tests/list_ops_test.py1
-rw-r--r--tensorflow/python/kernel_tests/resource_variable_ops_test.py9
-rw-r--r--tensorflow/python/kernel_tests/tensor_array_ops_test.py1
-rw-r--r--tensorflow/python/ops/check_ops.py30
-rw-r--r--tensorflow/python/ops/control_flow_ops.py7
-rw-r--r--tensorflow/python/ops/distributions/bijector_impl.py49
-rw-r--r--tensorflow/python/ops/distributions/util.py24
-rw-r--r--tensorflow/python/ops/gradients_impl.py42
-rw-r--r--tensorflow/python/ops/gradients_test.py15
-rw-r--r--tensorflow/python/ops/tensor_array_ops.py196
-rw-r--r--tensorflow/python/saved_model/builder_impl.py54
-rw-r--r--tensorflow/python/saved_model/constants.py6
-rw-r--r--tensorflow/python/saved_model/saved_model_test.py90
-rw-r--r--tensorflow/python/saved_model/signature_constants.py6
-rw-r--r--tensorflow/python/saved_model/signature_def_utils.py2
-rw-r--r--tensorflow/python/saved_model/signature_def_utils_impl.py56
-rw-r--r--tensorflow/python/saved_model/signature_def_utils_test.py95
-rw-r--r--tensorflow/python/saved_model/tag_constants.py5
-rw-r--r--tensorflow/python/training/checkpointable.py30
-rw-r--r--tensorflow/python/training/checkpointable_test.py10
-rw-r--r--tensorflow/python/training/checkpointable_utils.py7
-rw-r--r--tensorflow/python/training/distribute.py17
-rw-r--r--tensorflow/stream_executor/cuda/cuda_activation.cc6
-rw-r--r--tensorflow/stream_executor/cuda/cuda_activation.h3
-rw-r--r--tensorflow/stream_executor/cuda/cuda_dnn.cc1466
-rw-r--r--tensorflow/stream_executor/cuda/cuda_dnn.h51
-rw-r--r--tensorflow/tools/benchmark/benchmark_model.cc4
-rw-r--r--tensorflow/tools/benchmark/benchmark_model_test.cc55
-rwxr-xr-xtensorflow/tools/ci_build/update_version.py2
-rw-r--r--tensorflow/workspace.bzl20
-rw-r--r--third_party/clang_toolchain/download_clang.bzl8
-rw-r--r--third_party/png_fix_rpi.patch16
-rw-r--r--third_party/tflite_ovic_testdata.BUILD12
341 files changed, 17489 insertions, 6704 deletions
diff --git a/configure.py b/configure.py
index fe15bfc1a4..6d9aba61bb 100644
--- a/configure.py
+++ b/configure.py
@@ -845,8 +845,8 @@ def reformat_version_sequence(version_str, sequence_count):
def set_tf_cuda_version(environ_cp):
"""Set CUDA_TOOLKIT_PATH and TF_CUDA_VERSION."""
ask_cuda_version = (
- 'Please specify the CUDA SDK version you want to use, '
- 'e.g. 7.0. [Leave empty to default to CUDA %s]: ') % _DEFAULT_CUDA_VERSION
+ 'Please specify the CUDA SDK version you want to use. '
+ '[Leave empty to default to CUDA %s]: ') % _DEFAULT_CUDA_VERSION
for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS):
# Configure the Cuda SDK version to use.
diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc
index 82dbd3cdbc..95b04f9058 100644
--- a/tensorflow/c/c_api_experimental.cc
+++ b/tensorflow/c/c_api_experimental.cc
@@ -8407,3 +8407,51 @@ TF_Tensor* TF_DequeueNamedTensor(TF_Session* session, int tensor_id,
}
return ret;
}
+
+void TF_EnqueueNamedTensor(TF_Session* session, int tensor_id,
+ TF_Tensor* tensor, TF_Status* status) {
+ assert(session);
+ {
+ tensorflow::mutex_lock c(session->graph->mu);
+ if (VLOG_IS_ON(1)) {
+ VLOG(1) << "Enqueuing named tensor with id " << tensor_id
+ << ", with input graph: "
+ << session->graph->graph.ToGraphDefDebug().DebugString();
+ tensorflow::Tensor internal_tensor;
+ if (tensorflow::TF_TensorToTensor(tensor, &internal_tensor).ok()) {
+ VLOG(1) << "Enqueu'ing tensor content: "
+ << internal_tensor.DebugString();
+ }
+ }
+ }
+
+ TF_Operation* enqueue_op = TF_GraphOperationByName(
+ session->graph,
+ tensorflow::strings::StrCat("fifo_queue_enqueue_", tensor_id).c_str());
+ if (enqueue_op == nullptr) {
+ status->status = tensorflow::errors::Internal(
+ "Unable to find the enqueue node in the TF graph.");
+ return;
+ }
+
+ TF_Operation* placeholder_op = TF_GraphOperationByName(
+ session->graph,
+ tensorflow::strings::StrCat("arg_tensor_enqueue_", tensor_id).c_str());
+ if (placeholder_op == nullptr) {
+ status->status = tensorflow::errors::Internal(
+ "Unable to find the placeholder node as input to enqueue in the TF "
+ "graph.");
+ return;
+ }
+
+ VLOG(1) << "Running the enqueue op";
+ TF_Output input{placeholder_op, 0};
+ TF_SessionRun(session, /*run_options*/ nullptr,
+ // input related parameters
+ /*inputs*/ &input, /*input_values*/ &tensor, /*ninputs*/ 1,
+ // output related parameters
+ /*outputs*/ nullptr, /*output_values*/ nullptr, /*noutputs*/ 0,
+ /*targets*/ &enqueue_op, /*ntargets*/ 1,
+ /*run_metadata*/ nullptr, status);
+ VLOG(1) << "Enqueuing is done.";
+}
diff --git a/tensorflow/c/c_api_experimental.h b/tensorflow/c/c_api_experimental.h
index e6757c065f..20bdace40f 100644
--- a/tensorflow/c/c_api_experimental.h
+++ b/tensorflow/c/c_api_experimental.h
@@ -87,8 +87,11 @@ TF_CAPI_EXPORT extern TF_Operation* TF_MakeFileBasedIteratorGetNextWithDatasets(
unsigned char is_mnist, TF_Status* status);
// On success, dequeues a tensor from a TF-managed FifoQueue given by
-// `tensor_id`, associated with `session`. Caller must call TF_DeleteTensor()
-// over the returned tensor. If the queue is empty, this call is blocked.
+// `tensor_id`, associated with `session`. There must be a graph node named
+// "fifo_queue_dequeue_<tensor_id>", to be executed by this API call.
+
+// Caller must call TF_DeleteTensor() over the returned tensor. If the queue is
+// empty, this call is blocked.
//
// Tensors are enqueued via the corresponding TF enqueue op.
// TODO(hongm): Add support for `timeout_ms`.
@@ -96,6 +99,22 @@ TF_CAPI_EXPORT extern TF_Tensor* TF_DequeueNamedTensor(TF_Session* session,
int tensor_id,
TF_Status* status);
+// On success, enqueues `tensor` into a TF-managed FifoQueue given by
+// `tensor_id`, associated with `session`. There must be a graph node named
+// "fifo_queue_enqueue_<tensor_id>", to be executed by this API call. It reads
+// from a placeholder node "arg_tensor_enqueue_<tensor_id>".
+//
+// `tensor` is still owned by the caller. This call will be blocked if the queue
+// has reached its capacity, and will be unblocked when the queued tensors again
+// drop below the capacity due to dequeuing.
+//
+// Tensors are dequeued via the corresponding TF dequeue op.
+// TODO(hongm): Add support for `timeout_ms`.
+TF_CAPI_EXPORT extern void TF_EnqueueNamedTensor(TF_Session* session,
+ int tensor_id,
+ TF_Tensor* tensor,
+ TF_Status* status);
+
#ifdef __cplusplus
} /* end extern "C" */
#endif
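
A minimal caller-side sketch of the enqueue/dequeue pair declared above. It is not part of this commit; it assumes a TF_Session* whose graph already contains the "fifo_queue_enqueue_0", "arg_tensor_enqueue_0" and "fifo_queue_dequeue_0" nodes, and the tensor id 0, the float payload and the RoundTripExample name are illustrative only.

#include <cstdint>

#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_experimental.h"

// Hypothetical helper; the session's graph is assumed to define the queue
// nodes for tensor_id 0 as described in the header comments above.
void RoundTripExample(TF_Session* session) {
  TF_Status* status = TF_NewStatus();

  // Build a one-element float tensor to enqueue; the caller keeps ownership.
  int64_t dims[1] = {1};
  TF_Tensor* in = TF_AllocateTensor(TF_FLOAT, dims, 1, sizeof(float));
  static_cast<float*>(TF_TensorData(in))[0] = 42.0f;

  // Blocks while the queue is at capacity.
  TF_EnqueueNamedTensor(session, /*tensor_id=*/0, in, status);
  if (TF_GetCode(status) == TF_OK) {
    // Blocks while the queue is empty; the caller owns the returned tensor.
    TF_Tensor* out = TF_DequeueNamedTensor(session, /*tensor_id=*/0, status);
    if (out != nullptr) TF_DeleteTensor(out);
  }

  TF_DeleteTensor(in);
  TF_DeleteStatus(status);
}
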
diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h
index 8026076b9e..e9ed3395c4 100644
--- a/tensorflow/c/eager/tape.h
+++ b/tensorflow/c/eager/tape.h
@@ -130,13 +130,15 @@ class GradientTape {
}
}
- bool ShouldRecord(gtl::ArraySlice<int64> tensor_ids);
+ bool ShouldRecord(gtl::ArraySlice<int64> tensor_ids,
+ gtl::ArraySlice<tensorflow::DataType> dtypes);
void Watch(int64 tensor_id);
void RecordOperation(const string& op_type,
gtl::ArraySlice<TapeTensor> output_tensors,
gtl::ArraySlice<int64> input_tensor_id,
+ gtl::ArraySlice<tensorflow::DataType> input_dtypes,
BackwardFunction* backward_function,
const std::function<void()>& backward_function_deleter);
@@ -170,12 +172,30 @@ class GradientTape {
// Template instantiations here
+inline bool IsDtypeTrainable(DataType dtype) {
+ switch (dtype) {
+ case DT_HALF:
+ case DT_BFLOAT16:
+ case DT_FLOAT:
+ case DT_DOUBLE:
+ case DT_COMPLEX64:
+ case DT_COMPLEX128:
+ case DT_RESOURCE:
+ case DT_VARIANT:
+ return true;
+ default:
+ return false;
+ }
+}
+
template <typename Gradient, typename BackwardFunction>
bool GradientTape<Gradient, BackwardFunction>::ShouldRecord(
- gtl::ArraySlice<int64> tensor_ids) {
- for (int64 i : tensor_ids) {
- if (tensor_tape_.find(i) != tensor_tape_.end()) {
- return true;
+ gtl::ArraySlice<int64> tensor_ids,
+ gtl::ArraySlice<tensorflow::DataType> dtypes) {
+ CHECK_EQ(tensor_ids.size(), dtypes.size());
+ for (int i = 0; i < tensor_ids.size(); ++i) {
+ if (tensor_tape_.find(tensor_ids[i]) != tensor_tape_.end()) {
+ return IsDtypeTrainable(dtypes[i]);
}
}
return false;
@@ -189,9 +209,11 @@ void GradientTape<Gradient, BackwardFunction>::Watch(int64 tensor_id) {
template <typename Gradient, typename BackwardFunction>
void GradientTape<Gradient, BackwardFunction>::RecordOperation(
const string& op_type, gtl::ArraySlice<TapeTensor> output_tensors,
- gtl::ArraySlice<int64> input_tensor_id, BackwardFunction* backward_function,
+ gtl::ArraySlice<int64> input_tensor_id,
+ gtl::ArraySlice<tensorflow::DataType> input_dtypes,
+ BackwardFunction* backward_function,
const std::function<void()>& backward_function_deleter) {
- if (!ShouldRecord(input_tensor_id)) {
+ if (!ShouldRecord(input_tensor_id, input_dtypes)) {
backward_function_deleter();
return;
}
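
A standalone sketch, in plain C++ rather than the real tape classes, of the behavioral change above: ShouldRecord now pairs every tensor id with its dtype, and a watched tensor triggers recording only when that dtype is trainable. The Dtype enum and the trainable subset here are illustrative stand-ins for tensorflow::DataType and IsDtypeTrainable.

#include <cassert>
#include <cstdint>
#include <set>
#include <vector>

// Stand-in for tensorflow::DataType with only a few example values.
enum class Dtype { kFloat, kHalf, kInt32, kResource };

// Mirrors IsDtypeTrainable, restricted to the example enum.
bool IsTrainable(Dtype d) {
  return d == Dtype::kFloat || d == Dtype::kHalf || d == Dtype::kResource;
}

// Mirrors the new GradientTape::ShouldRecord: ids and dtypes are paired,
// and the first watched id decides the answer based on its dtype.
bool ShouldRecord(const std::set<int64_t>& watched,
                  const std::vector<int64_t>& ids,
                  const std::vector<Dtype>& dtypes) {
  for (size_t i = 0; i < ids.size(); ++i) {
    if (watched.count(ids[i]) != 0) return IsTrainable(dtypes[i]);
  }
  return false;
}

int main() {
  assert(!ShouldRecord({7}, {7}, {Dtype::kInt32}));  // watched, not trainable
  assert(ShouldRecord({7}, {7}, {Dtype::kFloat}));   // watched and trainable
  assert(!ShouldRecord({7}, {3}, {Dtype::kFloat}));  // not watched at all
  return 0;
}
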
diff --git a/tensorflow/compiler/aot/tests/BUILD b/tensorflow/compiler/aot/tests/BUILD
index 222e26810a..fd2cf2b67d 100644
--- a/tensorflow/compiler/aot/tests/BUILD
+++ b/tensorflow/compiler/aot/tests/BUILD
@@ -15,6 +15,7 @@ test_suite(
":test_graph_tfadd_with_ckpt_saver_test",
":test_graph_tfadd_with_ckpt_test",
":test_graph_tfassert_eq_test",
+ ":test_graph_tfcond_test",
":test_graph_tffunction_test",
":test_graph_tfgather_test",
":test_graph_tfmatmul_test",
@@ -55,6 +56,7 @@ genrule(
"test_graph_tfadd_with_ckpt_saver.pb",
"test_graph_tfadd_with_ckpt_saver.saver",
"test_graph_tfassert_eq.pb",
+ "test_graph_tfcond.pb",
"test_graph_tffunction.pb",
"test_graph_tfgather.pb",
"test_graph_tfmatmul.pb",
@@ -119,6 +121,17 @@ tf_library(
)
tf_library(
+ name = "test_graph_tfcond",
+ testonly = 1,
+ config = "test_graph_tfcond.config.pbtxt",
+ cpp_class = "CondComp",
+ graph = "test_graph_tfcond.pb",
+ tags = [
+ "manual",
+ ],
+)
+
+tf_library(
name = "test_graph_tffunction",
testonly = 1,
config = "test_graph_tffunction.config.pbtxt",
@@ -194,6 +207,7 @@ tf_cc_test(
":test_graph_tfadd_with_ckpt",
":test_graph_tfadd_with_ckpt_saver",
":test_graph_tfassert_eq",
+ ":test_graph_tfcond",
":test_graph_tffunction",
":test_graph_tfgather",
":test_graph_tfmatmul",
diff --git a/tensorflow/compiler/aot/tests/make_test_graphs.py b/tensorflow/compiler/aot/tests/make_test_graphs.py
index 67767f55da..9ec7df163b 100644
--- a/tensorflow/compiler/aot/tests/make_test_graphs.py
+++ b/tensorflow/compiler/aot/tests/make_test_graphs.py
@@ -78,6 +78,22 @@ def tfadd_with_ckpt_saver(out_dir):
f.write(saver.as_saver_def().SerializeToString())
+def tfassert_eq(_):
+ x = array_ops.placeholder(dtypes.int32, name='x_hold')
+ y = array_ops.placeholder(dtypes.int32, name='y_hold')
+ control_flow_ops.Assert(
+ math_ops.equal(x, y), ['Expected x == y.'], name='assert_eq')
+ math_ops.add(x, math_ops.negative(y), name='x_y_diff')
+
+
+def tfcond(_):
+ p = array_ops.placeholder(dtypes.bool, name='p_hold')
+ x = array_ops.placeholder(dtypes.int32, name='x_hold')
+ y = array_ops.placeholder(dtypes.int32, name='y_hold')
+ z = control_flow_ops.cond(p, lambda: x, lambda: y)
+ array_ops.identity(z, name='result')
+
+
def tfgather(_):
params = array_ops.placeholder(dtypes.float32, name='params')
indices = array_ops.placeholder(dtypes.int32, name='indices')
@@ -126,14 +142,6 @@ def tfsplits(_):
array_ops.identity(y, name='result')
-def tfassert_eq(_):
- x = array_ops.placeholder(dtypes.int32, name='x_hold')
- y = array_ops.placeholder(dtypes.int32, name='y_hold')
- control_flow_ops.Assert(
- math_ops.equal(x, y), ['Expected x == y.'], name='assert_eq')
- math_ops.add(x, math_ops.negative(y), name='x_y_diff')
-
-
def write_graph(build_graph, out_dir):
"""Build a graph using build_graph and write it out."""
g = ops.Graph()
@@ -148,12 +156,13 @@ def main(_):
write_graph(tfadd, FLAGS.out_dir)
write_graph(tfadd_with_ckpt, FLAGS.out_dir)
write_graph(tfadd_with_ckpt_saver, FLAGS.out_dir)
+ write_graph(tfassert_eq, FLAGS.out_dir)
+ write_graph(tfcond, FLAGS.out_dir)
+ write_graph(tffunction, FLAGS.out_dir)
write_graph(tfgather, FLAGS.out_dir)
write_graph(tfmatmul, FLAGS.out_dir)
write_graph(tfmatmulandadd, FLAGS.out_dir)
- write_graph(tffunction, FLAGS.out_dir)
write_graph(tfsplits, FLAGS.out_dir)
- write_graph(tfassert_eq, FLAGS.out_dir)
if __name__ == '__main__':
diff --git a/tensorflow/compiler/aot/tests/test_graph_tfcond.config.pbtxt b/tensorflow/compiler/aot/tests/test_graph_tfcond.config.pbtxt
new file mode 100644
index 0000000000..94a01ad4ab
--- /dev/null
+++ b/tensorflow/compiler/aot/tests/test_graph_tfcond.config.pbtxt
@@ -0,0 +1,20 @@
+# Text form of tensorflow.tf2xla.Config proto.
+feed {
+ id { node_name: "p_hold" }
+ shape {}
+}
+feed {
+ id { node_name: "x_hold" }
+ shape {
+ dim { size: 1 }
+ }
+}
+feed {
+ id { node_name: "y_hold" }
+ shape {
+ dim { size: 1 }
+ }
+}
+fetch {
+ id { node_name: "result" }
+}
diff --git a/tensorflow/compiler/aot/tests/tfcompile_test.cc b/tensorflow/compiler/aot/tests/tfcompile_test.cc
index 27ba42b31f..309a991fc1 100644
--- a/tensorflow/compiler/aot/tests/tfcompile_test.cc
+++ b/tensorflow/compiler/aot/tests/tfcompile_test.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt_saver.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfassert_eq.h"
+#include "tensorflow/compiler/aot/tests/test_graph_tfcond.h"
#include "tensorflow/compiler/aot/tests/test_graph_tffunction.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfgather.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmul.h"
@@ -150,6 +151,31 @@ TEST(TFCompileTest, AddWithCkptSaver) {
EXPECT_EQ(add_const.result0_data(), add_const.results()[0]);
}
+TEST(TFCompileTest, Cond) {
+ CondComp cond;
+ EXPECT_EQ(cond.arg0_data(), cond.args()[0]);
+ EXPECT_EQ(cond.arg1_data(), cond.args()[1]);
+ EXPECT_EQ(cond.arg2_data(), cond.args()[2]);
+ cond.arg1() = 10;
+ cond.arg2() = 20;
+ {
+ cond.arg0() = true;
+ const int32 expected_result = cond.arg1();
+ EXPECT_TRUE(cond.Run());
+ EXPECT_EQ(cond.result0(), expected_result);
+ EXPECT_EQ(cond.result0_data()[0], expected_result);
+ EXPECT_EQ(cond.result0_data(), cond.results()[0]);
+ }
+ {
+ cond.arg0() = false;
+ const int32 expected_result = cond.arg2();
+ EXPECT_TRUE(cond.Run());
+ EXPECT_EQ(cond.result0(), expected_result);
+ EXPECT_EQ(cond.result0_data()[0], expected_result);
+ EXPECT_EQ(cond.result0_data(), cond.results()[0]);
+ }
+}
+
TEST(TFCompileTest, Gather) {
GatherComp gather;
EXPECT_EQ(gather.arg0_data(), gather.args()[0]);
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index 07136d6a74..a6b3ce394c 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -261,6 +261,7 @@ cc_library(
name = "create_xla_launch_op",
srcs = [
"create_xla_launch_op.cc",
+ "create_xla_launch_op.h",
],
deps = [
":common",
@@ -270,6 +271,29 @@ cc_library(
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "@com_google_absl//absl/memory",
+ ],
+ alwayslink = 1,
+)
+
+tf_cc_test(
+ name = "create_xla_launch_op_test",
+ srcs = [
+ "create_xla_launch_op.h",
+ "create_xla_launch_op_test.cc",
+ ],
+ deps = [
+ ":create_xla_launch_op",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:session_options",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ "@com_google_absl//absl/memory",
],
)
diff --git a/tensorflow/compiler/jit/create_xla_launch_op.cc b/tensorflow/compiler/jit/create_xla_launch_op.cc
index 18d901323f..f35e916eb9 100644
--- a/tensorflow/compiler/jit/create_xla_launch_op.cc
+++ b/tensorflow/compiler/jit/create_xla_launch_op.cc
@@ -12,7 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include "tensorflow/compiler/jit/create_xla_launch_op.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/kernels/xla_launch_op.h"
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
@@ -25,78 +27,189 @@ limitations under the License.
namespace tensorflow {
namespace {
-// Givens a NodeDef 'ndef' and the function library runtime 'flr', if
-// 'ndef' is a call to a compilable function defined in 'flr', returns OK
-// and fills in 'kernel' with a XlaLaunchOp kernel which computes the
-// node. Otherwise, returns a non-OK.
+// Utility which searches for values in a sorted list by scanning over it once.
+// No matter how many times ScanForValue is called, the list is scanned at most
+// once. However, if a call to ScanForValue skips over a value, that value is
+// not revisited in future calls to ScanForValue, so callers must take
+// care to order their calls.
//
-// This routine is here so that FunctionLibraryRuntime can jit a
-// specific function call as requested.
-Status CreateXlaLaunchOp(FunctionLibraryRuntime* flr, const NodeDef& ndef,
- std::unique_ptr<OpKernel>* kernel) {
- bool xla_compile = false;
- if (!flr->GetFunctionLibraryDefinition()
- ->GetAttr(ndef, kXlaCompileAttr, &xla_compile)
- .ok() ||
- !xla_compile) {
- // Not marked as _XlaCompile=true.
- return errors::InvalidArgument("No ", kXlaCompileAttr, " for ", ndef.op());
+// Useful for merging multiple sorted lists in O(n) time.
+class SinglePassSearch {
+ public:
+ // Creates a SinglePassSearch object that can be used to search in `values`.
+ // Does not take ownership of `values`. `values` must outlive this.
+ // `values` must be sorted.
+ explicit SinglePassSearch(const std::vector<int>* values)
+ : current_index_(0), values_(values) {}
+
+ // Scans forward in the vector looking for "value", updating the internal
+ // position in to the vector.
+ // Returns true iff the vector contains the given value at or after current
+ // position.
+ // Not thread-safe.
+ bool ScanForValue(int value) {
+ while (current_index_ < values_->size() &&
+ (*values_)[current_index_] <= value) {
+ if ((*values_)[current_index_] == value) {
+ current_index_++;
+ return true;
+ }
+ current_index_++;
+ }
+ return false;
}
- // Make sure that kernels have been registered on the JIT device.
- XlaOpRegistry::RegisterCompilationKernels();
- if (!IsCompilable(flr, ndef)) {
- // ndef is calling a function that XLA can't compile.
- return errors::InvalidArgument("Not compilable: ", ndef.ShortDebugString());
+
+ private:
+ int current_index_;
+ const std::vector<int>* values_;
+};
+
+Status CompilationRequested(const FunctionLibraryRuntime& flr,
+ const NodeDef& node_def) {
+ bool xla_compile = false;
+ // Check if op is marked _XlaCompile=true.
+ Status status = flr.GetFunctionLibraryDefinition()->GetAttr(
+ node_def, kXlaCompileAttr, &xla_compile);
+ if (!status.ok() || !xla_compile) {
+ if (VLOG_IS_ON(3)) {
+ if (!status.ok()) {
+ VLOG(3) << "No " << kXlaCompileAttr << " attr defined for "
+ << node_def.op() << ". status=" << status.ToString();
+ } else {
+ VLOG(3) << node_def.op() << " is explicitly marked not to be compiled";
+ }
+ }
+ return Status(error::INVALID_ARGUMENT, "");
}
+ return Status::OK();
+}
+
+// Given a FunctionLibraryRuntime and a NodeDef calling a function in the
+// runtime, returns this function's body in `fbody` as well as the indices
+// of its constant and resource arguments.
+// `fbody` is owned by `flr`.
+// `constant_arg_indices` and `resource_arg_indices` should be empty vectors.
+// They are sorted in ascending order on this function's return.
+Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
+ const NodeDef& node_def,
+ const FunctionBody** fbody,
+ std::vector<int>* constant_arg_indices,
+ std::vector<int>* resource_arg_indices) {
FunctionLibraryRuntime::Handle handle;
- // If ndef is not instantiable, e.g., the function does not exist,
+ // If node_def is not instantiable, e.g., the function does not exist,
// simply bail out.
TF_RETURN_IF_ERROR(
- flr->Instantiate(ndef.op(), AttrSlice(&ndef.attr()), &handle));
- const FunctionBody* fbody = flr->GetFunctionBody(handle);
- CHECK(fbody); // Can't be nullptr since we just instantiated it.
- std::vector<bool> const_args(fbody->arg_types.size());
+ flr->Instantiate(node_def.op(), AttrSlice(&node_def.attr()), &handle));
+ *fbody = flr->GetFunctionBody(handle);
+ CHECK(*fbody); // Can't be nullptr since we just instantiated it.
+ const DataTypeVector& arg_types = (*fbody)->arg_types;
+ std::vector<bool> const_args(arg_types.size());
// If we can't analyze the const args. Bail out.
- TF_RETURN_IF_ERROR(BackwardsConstAnalysis(*(fbody->graph), &const_args));
+ TF_RETURN_IF_ERROR(BackwardsConstAnalysis(*((*fbody)->graph), &const_args));
for (int i = 0; i < const_args.size(); ++i) {
if (const_args[i]) {
- // There is a const arg. Bail out.
- return errors::InvalidArgument("Const arg: ", i, " in ",
- DebugString(fbody->fdef));
+ constant_arg_indices->push_back(i);
+ }
+ }
+
+ // There can be hundreds of resource variables. Reserve the space for them.
+ // We don't reserve for constants above as they are usually few.
+ resource_arg_indices->reserve(arg_types.size());
+ for (int i = 0; i < arg_types.size(); ++i) {
+ if (arg_types[i] == DT_RESOURCE) {
+ resource_arg_indices->push_back(i);
}
}
- NodeDef launch_def;
- launch_def.set_name(ndef.name());
- launch_def.set_op("_XlaLaunch");
- launch_def.set_device(flr->device()->name());
- AddNodeAttr("Tconstants", DataTypeVector{}, &launch_def);
- AddNodeAttr("Nresources", 0, &launch_def);
- AddNodeAttr("Targs", fbody->arg_types, &launch_def);
- AddNodeAttr("Tresults", fbody->ret_types, &launch_def);
- NameAttrList func;
- func.set_name(ndef.op());
- *(func.mutable_attr()) = ndef.attr();
- AddNodeAttr("function", func, &launch_def);
-
- // TODO(b/32387911): Handles the host memory types across function
- // calls properly. For now, we assume all inputs and outputs are on
- // the device memory.
+ return Status::OK();
+}
+
+} // namespace
+
+Status CreateXlaLaunchOp(FunctionLibraryRuntime* flr, const NodeDef& node_def,
+ std::unique_ptr<OpKernel>* kernel) {
+ TF_RETURN_IF_ERROR(CompilationRequested(*flr, node_def));
+
+ VLOG(3) << "Creating XlaLaunchOp for " << node_def.DebugString();
+
+ // Make sure that kernels have been registered on the JIT device.
+ XlaOpRegistry::RegisterCompilationKernels();
+ if (!IsCompilable(flr, node_def)) {
+ // node_def is calling a function that XLA can't compile.
+ return errors::InvalidArgument("Not compilable: ",
+ node_def.ShortDebugString());
+ }
+
+ // Get function body, constant args, and resource args.
+ const FunctionBody* fbody = nullptr;
+ std::vector<int> constant_arg_indices;
+ std::vector<int> resource_arg_indices;
+ TF_RETURN_IF_ERROR(GetBodyAndConstantsAndResources(
+ flr, node_def, &fbody, &constant_arg_indices, &resource_arg_indices));
+
+ // Set input and output memory types.
MemoryTypeVector input_memory_types(fbody->arg_types.size(), DEVICE_MEMORY);
+ // These indices are used only for optimization purposes. They allow us
+ // to loop over constant_arg_indices and resource_arg_indices only once
+ // while iterating over all the function arguments checking if it is a
+ // resource or a constant.
+ // The reason we optimized this code is because functions can have a lot of
+ // captured arguments. For example, the backward pass of ResNet50 takes in all
+ // 214 variables and a similar number of activations.
+ SinglePassSearch constants_search(&constant_arg_indices);
+ SinglePassSearch resources_search(&resource_arg_indices);
+ for (int i = 0; i < fbody->arg_types.size(); ++i) {
+ if (resources_search.ScanForValue(i) || constants_search.ScanForValue(i)) {
+ // Compile-time constants and resource handles are expected to be in
+ // host memory.
+ input_memory_types[i] = HOST_MEMORY;
+ }
+ }
+ // One might wonder about the case where a compile-time constant argument
+ // (which must be in host memory) is also used as an input into an op,
+ // e.g. Add, that expects its inputs in device memory. Here is how it
+ // works now.
+ // First, what do we mean by "op expects an input in XYZ memory"?
+ // There are two types of "ops" here: the tf2xla kernel and the HLO
+ // computation it builds. The tf2xla kernel needs to retrieve the actual
+ // numeric value of the compile-time constant tensors, so it really expects
+ // them to be in host memory. However, for other inputs, it refers to them
+ // using xla::ComputationDataHandle, which is just a symbolic handle that
+ // xla::ComputationBuilder assigns. How does this handle get assigned for
+ // constant arguments? Even constant arguments get an _Arg node in the graph
+ // instantiated for Function compilation. The tf2xla kernel for constant _Arg
+ // nodes takes the constant value, converts it to XlaLiteral, and feeds it
+ // to xla::ComputationBuilder.ConstantLiteral, which returns the handle. This
+ // constant XlaLiteral is included in the HLO graph, and subsequently, in
+ // the actual executable, which is copied to the device before being
+ // executed. Thus, when this executable runs, the constant is available in
+ // device memory.
+
+ // The XlaLaunch kernel keeps all outputs (including constants, which it
+ // copies) in device memory.
MemoryTypeVector output_memory_types(fbody->ret_types.size(), DEVICE_MEMORY);
+ // Create the kernel.
+ NameAttrList function;
+ function.set_name(node_def.op());
+ *(function.mutable_attr()) = node_def.attr();
+
Device* dev = flr->device();
Status s;
OpKernelConstruction construction(
DeviceType(dev->device_type()), dev,
- dev->GetAllocator(AllocatorAttributes()), &launch_def,
+ dev->GetAllocator(AllocatorAttributes()), &node_def,
&fbody->fdef.signature(), flr, fbody->arg_types, input_memory_types,
fbody->ret_types, output_memory_types, flr->graph_def_version(), &s);
- kernel->reset(new XlaLocalLaunchOp(&construction));
+
+ *kernel = absl::make_unique<XlaLocalLaunchBase>(
+ &construction, constant_arg_indices, resource_arg_indices, function);
return s;
}
+namespace {
+
bool RegisterLaunchOpCreator() {
RegisterDefaultCustomKernelCreator(CreateXlaLaunchOp);
return true;
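
A self-contained illustration of how CreateXlaLaunchOp uses SinglePassSearch to decide, in one forward pass, which argument indices go to host memory. The class body is a minimal copy of the helper added above; the argument index sets are hypothetical example values.

#include <iostream>
#include <vector>

// Minimal copy of the SinglePassSearch helper added in this commit.
class SinglePassSearch {
 public:
  explicit SinglePassSearch(const std::vector<int>* values)
      : current_index_(0), values_(values) {}

  // Callers must probe values in non-decreasing order.
  bool ScanForValue(int value) {
    while (current_index_ < values_->size() &&
           (*values_)[current_index_] <= value) {
      if ((*values_)[current_index_] == value) {
        ++current_index_;
        return true;
      }
      ++current_index_;
    }
    return false;
  }

 private:
  size_t current_index_;
  const std::vector<int>* values_;
};

int main() {
  // Hypothetical layout: constants at arg indices {0, 2}, resources at {3, 5}.
  const std::vector<int> constant_arg_indices = {0, 2};
  const std::vector<int> resource_arg_indices = {3, 5};
  SinglePassSearch constants_search(&constant_arg_indices);
  SinglePassSearch resources_search(&resource_arg_indices);
  for (int i = 0; i < 6; ++i) {
    const bool host = resources_search.ScanForValue(i) ||
                      constants_search.ScanForValue(i);
    std::cout << "arg " << i << ": "
              << (host ? "HOST_MEMORY" : "DEVICE_MEMORY") << "\n";
  }
  return 0;
}
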
diff --git a/tensorflow/compiler/jit/create_xla_launch_op.h b/tensorflow/compiler/jit/create_xla_launch_op.h
new file mode 100644
index 0000000000..98a22e3515
--- /dev/null
+++ b/tensorflow/compiler/jit/create_xla_launch_op.h
@@ -0,0 +1,35 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_COMPILER_JIT_CREATE_XLA_LAUNCH_OP_H_
+#define TENSORFLOW_COMPILER_JIT_CREATE_XLA_LAUNCH_OP_H_
+
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+
+class FunctionLibraryRuntime;
+class OpKernel;
+
+// Given a NodeDef 'node_def' and the function library runtime 'flr', if
+// 'node_def' is a call to a compilable function defined in 'flr', returns OK
+// and fills in 'kernel' with a XlaLaunchOp kernel which computes the
+// node. Otherwise, returns a non-OK.
+Status CreateXlaLaunchOp(FunctionLibraryRuntime* flr, const NodeDef& node_def,
+ std::unique_ptr<OpKernel>* kernel);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_JIT_CREATE_XLA_LAUNCH_OP_H_
diff --git a/tensorflow/compiler/jit/create_xla_launch_op_test.cc b/tensorflow/compiler/jit/create_xla_launch_op_test.cc
new file mode 100644
index 0000000000..bcd5e75c7e
--- /dev/null
+++ b/tensorflow/compiler/jit/create_xla_launch_op_test.cc
@@ -0,0 +1,145 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/create_xla_launch_op.h"
+
+#include "absl/memory/memory.h"
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/framework/function_testlib.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/public/session_options.h"
+#include "tensorflow/core/public/version.h"
+
+namespace tensorflow {
+
+NodeDef ToNodeDef(const string& text) {
+ NodeDef node_def;
+ EXPECT_TRUE(protobuf::TextFormat::MergeFromString(text, &node_def));
+ return node_def;
+}
+
+// Create a FunctionDef that takes one resource and one regular param
+FunctionDef XTimesY() {
+ return FunctionDefHelper::Define(
+ // Name
+ "XTimesY",
+ // Args
+ {"x: float", "y: resource"},
+ // Return values
+ {"z: float"},
+ // Attr def
+ {},
+ // Nodes
+ {
+ {{"y0"}, "ReadVariableOp", {"y"}, {{"dtype", DT_FLOAT}}},
+ {{"z"}, "Mul", {"x", "y0"}, {{"T", DT_FLOAT}}},
+ });
+}
+
+class CreateXlaLaunchOpTest : public ::testing::Test {
+ protected:
+ void Init(const std::vector<FunctionDef>& flib) {
+ SessionOptions options;
+ auto* device_count = options.config.mutable_device_count();
+ device_count->insert({"CPU", 1});
+ TF_CHECK_OK(DeviceFactory::AddDevices(
+ options, "/job:localhost/replica:0/task:0", &devices_));
+
+ FunctionDefLibrary proto;
+ for (const auto& fdef : flib) {
+ *(proto.add_function()) = fdef;
+ }
+ lib_def_ = absl::make_unique<FunctionLibraryDefinition>(
+ OpRegistry::Global(), proto);
+ OptimizerOptions opts;
+ device_mgr_ = absl::make_unique<DeviceMgr>(devices_);
+ pflr_ = absl::make_unique<ProcessFunctionLibraryRuntime>(
+ device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(),
+ opts, /*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr);
+ flr_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0");
+ }
+
+ FunctionLibraryRuntime* flr_;
+ std::vector<Device*> devices_;
+ std::unique_ptr<DeviceMgr> device_mgr_;
+ std::unique_ptr<FunctionLibraryDefinition> lib_def_;
+ std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
+
+ std::unique_ptr<OpKernel> kernel_;
+};
+
+AttrValue BoolAttr(bool b) {
+ AttrValue v;
+ v.set_b(b);
+ return v;
+}
+
+TEST_F(CreateXlaLaunchOpTest, OneFloatOneResourceArgument) {
+ FunctionDef fdef = XTimesY();
+ (*fdef.mutable_attr())["_XlaCompile"] = BoolAttr(true);
+ Init({fdef});
+
+ Status status = CreateXlaLaunchOp(
+ flr_, ToNodeDef(R"pb(
+ name: 'XTimesY' op: 'XTimesY' input: 'a' input: 'b'
+ )pb"), &kernel_);
+ ASSERT_TRUE(status.ok()) << status.ToString();
+
+ EXPECT_EQ("XTimesY", kernel_->name());
+ EXPECT_EQ("XTimesY", kernel_->type_string());
+
+ EXPECT_EQ(2, kernel_->num_inputs());
+ EXPECT_EQ(DT_FLOAT, kernel_->input_type(0));
+ EXPECT_EQ(DT_RESOURCE, kernel_->input_type(1));
+ EXPECT_EQ(DEVICE_MEMORY, kernel_->input_memory_types()[0]);
+ EXPECT_EQ(HOST_MEMORY, kernel_->input_memory_types()[1]);
+
+ EXPECT_EQ(1, kernel_->num_outputs());
+ EXPECT_EQ(DT_FLOAT, kernel_->output_type(0));
+ EXPECT_EQ(DEVICE_MEMORY, kernel_->output_memory_types()[0]);
+}
+
+TEST_F(CreateXlaLaunchOpTest, FailsIfXlaCompileAttrNotSet) {
+ FunctionDef fdef = XTimesY();
+ Init({fdef});
+
+ Status status = CreateXlaLaunchOp(flr_, ToNodeDef(R"proto(
+ name: 'XTimesY'
+ op: 'XTimesY'
+ input: 'a'
+ input: 'b'
+ )proto"), &kernel_);
+ EXPECT_TRUE(errors::IsInvalidArgument(status)) << status.ToString();
+}
+
+TEST_F(CreateXlaLaunchOpTest, FailsIfXlaCompileAttrIsSetToFalse) {
+ FunctionDef fdef = XTimesY();
+ (*fdef.mutable_attr())["_XlaCompile"] = BoolAttr(false);
+ Init({fdef});
+
+ Status status = CreateXlaLaunchOp(flr_, ToNodeDef(R"proto(
+ name: 'XTimesY'
+ op: 'XTimesY'
+ input: 'a'
+ input: 'b'
+ )proto"), &kernel_);
+ EXPECT_TRUE(errors::IsInvalidArgument(status)) << status.ToString();
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc
index 049d170fa4..86a9fd3b8e 100644
--- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc
+++ b/tensorflow/compiler/jit/kernels/xla_launch_op.cc
@@ -39,15 +39,15 @@ limitations under the License.
namespace tensorflow {
-XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx)
- : OpKernel(ctx), device_type_(ctx->device_type()) {
- const NameAttrList* func;
- OP_REQUIRES_OK(ctx, ctx->GetAttr("function", &func));
- function_ = *func;
- DataTypeVector constant_types;
- OP_REQUIRES_OK(ctx, ctx->GetAttr("Tconstants", &constant_types));
- num_constant_args_ = constant_types.size();
- OP_REQUIRES_OK(ctx, ctx->GetAttr("Nresources", &num_resource_args_));
+XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx,
+ const std::vector<int>& constants,
+ const std::vector<int>& resources,
+ const NameAttrList& function)
+ : OpKernel(ctx),
+ constants_(constants),
+ resources_(resources),
+ device_type_(ctx->device_type()),
+ function_(function) {
if (device_type_ == DeviceType(DEVICE_CPU)) {
platform_id_ = se::host::kHostPlatformId;
} else if (device_type_ == DeviceType(DEVICE_GPU)) {
@@ -57,8 +57,8 @@ XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx)
}
}
-Status XlaLocalLaunchOp::BuildCompilationCache(OpKernelContext* ctx,
- XlaCompilationCache** cache) {
+Status XlaLocalLaunchBase::BuildCompilationCache(OpKernelContext* ctx,
+ XlaCompilationCache** cache) {
const XlaDevice::Metadata* metadata;
Status s = XlaDevice::GetMetadata(ctx, &metadata);
if (s.ok()) {
@@ -90,8 +90,8 @@ Status XlaLocalLaunchOp::BuildCompilationCache(OpKernelContext* ctx,
return Status::OK();
}
-void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
- VLOG(1) << "XlaLocalLaunchOp::Compute "
+void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
+ VLOG(1) << "XlaLocalLaunchOpBase::Compute "
<< Canonicalize(function_.name(), AttrSlice(&function_.attr()));
// We store information about the JIT-compiled XLA computation
// in the ResourceMgr.
@@ -124,7 +124,7 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
}
std::map<int, OptionalTensor> variables =
- SnapshotResourceVariables(ctx, num_resource_args_);
+ SnapshotResourceVariables(ctx, resources_);
xla::LocalClient* client = static_cast<xla::LocalClient*>(cache->client());
@@ -161,7 +161,7 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
xla::LocalExecutable* executable;
std::map<int, Tensor> constant_args;
- for (int i = 0; i < num_constant_args_; ++i) {
+ for (int i : constants_) {
constant_args.insert({i, ctx->input(i)});
}
OP_REQUIRES_OK(ctx, cache->Compile(options, function_, constant_args,
@@ -170,8 +170,8 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
VLOG(1) << "Executing XLA Computation...";
- XlaComputationLaunchContext launch_context(
- num_resource_args_, client, xla_allocator, allocate_xla_tensors);
+ XlaComputationLaunchContext launch_context(client, xla_allocator,
+ allocate_xla_tensors);
launch_context.PopulateInputs(ctx, kernel, variables);
// Execute the computation.
@@ -194,6 +194,62 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
VLOG(1) << "Done";
}
+namespace {
+
+// OP_REQUIRES_OK_RETURN is the same as OP_REQUIRES_OK except that,
+// in the error case, it returns RET instead of void.
+#define OP_REQUIRES_OK_RETURN(CTX, RET, ...) \
+ do { \
+ ::tensorflow::Status _s(__VA_ARGS__); \
+ if (!TF_PREDICT_TRUE(_s.ok())) { \
+ (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \
+ return RET; \
+ } \
+ } while (0)
+
+// Helper static functions to construct parameters for
+// XlaLocalLaunchBase constructor from OpKernelConstruction.
+std::vector<int> ConstantsVector(OpKernelConstruction* ctx) {
+ DataTypeVector constant_types;
+ OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
+ ctx->GetAttr("Tconstants", &constant_types));
+ std::vector<int> constants(constant_types.size());
+ std::iota(constants.begin(), constants.end(), 0);
+ return constants;
+}
+
+std::vector<int> ResourcesVector(OpKernelConstruction* ctx) {
+ DataTypeVector constant_types;
+ OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
+ ctx->GetAttr("Tconstants", &constant_types));
+
+ DataTypeVector arg_types;
+ OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
+ ctx->GetAttr("Targs", &arg_types));
+
+ int num_resources;
+ OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
+ ctx->GetAttr("Nresources", &num_resources));
+
+ std::vector<int> resources(num_resources);
+ std::iota(resources.begin(), resources.end(),
+ constant_types.size() + arg_types.size());
+ return resources;
+}
+
+NameAttrList FunctionAttr(OpKernelConstruction* ctx) {
+ const NameAttrList* func;
+ OP_REQUIRES_OK_RETURN(ctx, NameAttrList(), ctx->GetAttr("function", &func));
+ return *func;
+}
+
+#undef OP_REQUIRES_OK_RETURN
+} // namespace
+
+XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx)
+ : XlaLocalLaunchBase(ctx, ConstantsVector(ctx), ResourcesVector(ctx),
+ FunctionAttr(ctx)) {}
+
XlaLocalLaunchOp::~XlaLocalLaunchOp() {
VLOG(1) << "XlaLocalLaunchOp destroyed";
}
diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.h b/tensorflow/compiler/jit/kernels/xla_launch_op.h
index 8f8e646f0f..8dfc4b382d 100644
--- a/tensorflow/compiler/jit/kernels/xla_launch_op.h
+++ b/tensorflow/compiler/jit/kernels/xla_launch_op.h
@@ -26,6 +26,41 @@ limitations under the License.
namespace tensorflow {
+// XlaLocalLaunchBase is almost the same as XlaLocalLaunchOp.
+// The only difference is that it does not require arguments to follow
+// the "constants, then regular args, then resources" order.
+// It takes vectors of constant and resource argument indices explicitly.
+// It does not have a corresponding OpDef because it is never present
+// in a GraphDef.
+// Currently, it is used by the eager runtime. FunctionLibraryRuntime creates
+// this kernel when asked to create a kernel for an XLA-compiled function.
+class XlaLocalLaunchBase : public OpKernel {
+ public:
+ XlaLocalLaunchBase(OpKernelConstruction* ctx,
+ const std::vector<int>& constants,
+ const std::vector<int>& resources,
+ const NameAttrList& function);
+ XlaLocalLaunchBase(const XlaLocalLaunchBase&) = delete;
+ XlaLocalLaunchBase& operator=(const XlaLocalLaunchBase&) = delete;
+ ~XlaLocalLaunchBase() override = default;
+
+ void Compute(OpKernelContext* ctx) override;
+
+ protected:
+ // Builds an XlaCompilationCache suitable for the current device.
+ Status BuildCompilationCache(OpKernelContext* ctx,
+ XlaCompilationCache** cache);
+
+ // Indices of compile-time constant inputs.
+ std::vector<int> constants_;
+ // Indices of resource inputs.
+ std::vector<int> resources_;
+
+ DeviceType device_type_;
+ NameAttrList function_;
+ se::Platform::Id platform_id_;
+};
+
// XlaLocalLaunchOp is used to replace a region of the TensorFlow graph
// which will be compiled and executed using XLA. The XlaLocalLaunchOp is
// responsible for handling interactions with the TensorFlow executor.
@@ -35,26 +70,12 @@ namespace tensorflow {
// XlaLocalLaunchOp uses xla::LocalClient::Compile() and
// xla::LocalExecutable::Run(), and passes arguments into/out of XLA in device
// memory.
-class XlaLocalLaunchOp : public OpKernel {
+class XlaLocalLaunchOp : public XlaLocalLaunchBase {
public:
explicit XlaLocalLaunchOp(OpKernelConstruction* ctx);
~XlaLocalLaunchOp() override;
- void Compute(OpKernelContext* ctx) override;
-
private:
- // Builds a XlaCompilationCache class suitable for the current device.
- Status BuildCompilationCache(OpKernelContext* ctx,
- XlaCompilationCache** compiler);
-
- DeviceType device_type_;
- NameAttrList function_;
- int num_constant_args_;
- // Number of resource variable arguments.
- int num_resource_args_;
-
- se::Platform::Id platform_id_;
-
TF_DISALLOW_COPY_AND_ASSIGN(XlaLocalLaunchOp);
};
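
To make the base/derived split above concrete, here is a generic standalone sketch (plain C++, no TensorFlow types, names are hypothetical) of the same pattern: the base class receives the argument layout explicitly, while a derived class reconstructs it from the fixed "constants, then regular args, then resources" ordering, roughly mirroring XlaLocalLaunchBase and XlaLocalLaunchOp.

#include <iostream>
#include <numeric>
#include <utility>
#include <vector>

class LaunchBase {
 public:
  LaunchBase(std::vector<int> constants, std::vector<int> resources)
      : constants_(std::move(constants)), resources_(std::move(resources)) {}
  void Describe() const {
    std::cout << constants_.size() << " constant input(s), "
              << resources_.size() << " resource input(s)\n";
  }

 private:
  std::vector<int> constants_;
  std::vector<int> resources_;
};

// Derived kernel: assumes "constants, then regular args, then resources".
class OrderedLaunch : public LaunchBase {
 public:
  OrderedLaunch(int num_constants, int num_args, int num_resources)
      : LaunchBase(MakeRange(0, num_constants),
                   MakeRange(num_constants + num_args, num_resources)) {}

 private:
  static std::vector<int> MakeRange(int start, int count) {
    std::vector<int> v(count);
    std::iota(v.begin(), v.end(), start);
    return v;
  }
};

int main() {
  OrderedLaunch op(/*num_constants=*/2, /*num_args=*/3, /*num_resources=*/1);
  op.Describe();  // prints: 2 constant input(s), 1 resource input(s)
  return 0;
}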
diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
index 60458f6f33..6b83cf67ff 100644
--- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
+++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
@@ -48,13 +48,12 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
const XlaCompiler::CompilationResult* result,
xla::LocalExecutable* executable) {
std::map<int, OptionalTensor> variables = GetVariables(ctx);
- int64 num_resource_args = variables.size();
xla::LocalClient* client = metadata.client();
// Builds an XLA allocator for the device.
XlaComputationLaunchContext launch_context(
- num_resource_args, client, client->backend().memory_allocator(), true);
+ client, client->backend().memory_allocator(), true);
launch_context.PopulateInputs(ctx, result, variables);
diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc
index 33e53612b9..0223f97a03 100644
--- a/tensorflow/compiler/jit/xla_launch_util.cc
+++ b/tensorflow/compiler/jit/xla_launch_util.cc
@@ -38,14 +38,13 @@ using xla::ScopedShapedBuffer;
using xla::ShapedBuffer;
} // anonymous namespace
-std::map<int, OptionalTensor> SnapshotResourceVariables(OpKernelContext* ctx,
- int num_variables) {
+std::map<int, OptionalTensor> SnapshotResourceVariables(
+ OpKernelContext* ctx, const std::vector<int>& variables) {
std::map<int, OptionalTensor> snapshot;
- int first_variable = ctx->num_inputs() - num_variables;
- for (int i = 0; i < num_variables; ++i) {
+ for (int i : variables) {
Var* variable = nullptr;
- ResourceHandle handle = HandleFromInput(ctx, first_variable + i);
- OptionalTensor& tensor = snapshot[first_variable + i];
+ ResourceHandle handle = HandleFromInput(ctx, i);
+ OptionalTensor& tensor = snapshot[i];
if (LookupResource(ctx, handle, &variable).ok()) {
tf_shared_lock lock(*variable->mu());
tensor.name = handle.name();
@@ -112,10 +111,9 @@ ScopedShapedBuffer ExtractSubShapedBuffer(
using internal::ExtractSubShapedBuffer;
XlaComputationLaunchContext::XlaComputationLaunchContext(
- int64 num_resource_args, xla::LocalClient* client,
- xla::DeviceMemoryAllocator* xla_allocator, bool allocate_xla_tensors)
- : num_resource_args_(num_resource_args),
- client_(client),
+ xla::LocalClient* client, xla::DeviceMemoryAllocator* xla_allocator,
+ bool allocate_xla_tensors)
+ : client_(client),
xla_allocator_(xla_allocator),
allocate_xla_tensors_(allocate_xla_tensors) {}
diff --git a/tensorflow/compiler/jit/xla_launch_util.h b/tensorflow/compiler/jit/xla_launch_util.h
index 38291b0bd4..a2431253f8 100644
--- a/tensorflow/compiler/jit/xla_launch_util.h
+++ b/tensorflow/compiler/jit/xla_launch_util.h
@@ -31,15 +31,17 @@ limitations under the License.
namespace tensorflow {
class XlaAllocator;
-// Takes a snapshot of the values of resource variable arguments, which are
-// the last `num_variables` arguments. We snapshot tensors that back
+// Takes a snapshot of the values of resource variable arguments, whose
+// indices are specified in the `variables` argument. We snapshot tensors that back
// resource variables since concurrent updates may modify the shape, and it is
// important that the shapes used for compilation match the true shapes of the
// buffers.
//
-// Returns a map of TensorFlow argument index to resource variable.
-std::map<int, OptionalTensor> SnapshotResourceVariables(OpKernelContext* ctx,
- int num_variables);
+// Returns a map of TensorFlow argument index to resource variable. If a
+// resource variable is not initialized, the corresponding OptionalTensor
+// will have its `present` field set to false.
+std::map<int, OptionalTensor> SnapshotResourceVariables(
+ OpKernelContext* ctx, const std::vector<int>& variables);
// Adapter class that wraps a Tensorflow allocator as an XLA allocator.
// Assumes that the Tensorflow allocator permits asynchronous deallocation:
@@ -72,7 +74,7 @@ class XlaComputationLaunchContext {
// Create a new launch context. 'allocate_xla_tensors' is true if allocated
// output tensors and variables are always XlaTensors. If false they are
// assumed to be "normal" device pointers.
- XlaComputationLaunchContext(int64 num_resource_args, xla::LocalClient* client,
+ XlaComputationLaunchContext(xla::LocalClient* client,
xla::DeviceMemoryAllocator* xla_allocator,
bool allocate_xla_tensors);
@@ -92,7 +94,6 @@ class XlaComputationLaunchContext {
const std::vector<xla::ShapedBuffer*>& arguments() const { return arg_ptrs_; }
private:
- int64 num_resource_args_;
xla::LocalClient* client_;
xla::DeviceMemoryAllocator* xla_allocator_;
bool allocate_xla_tensors_;
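
As a rough standalone sketch (plain C++; the Snapshot struct is a hypothetical stand-in for OptionalTensor) of the interface change to SnapshotResourceVariables: instead of assuming the resources are the trailing `num_variables` inputs, the caller now names the exact input indices, so the snapshot no longer depends on argument ordering.

#include <iostream>
#include <map>
#include <vector>

// Hypothetical stand-in for OptionalTensor: only records the input index here.
struct Snapshot {
  int input_index;
};

// Index-based interface: the caller lists the resource inputs explicitly.
std::map<int, Snapshot> SnapshotByIndex(const std::vector<int>& variable_indices) {
  std::map<int, Snapshot> snapshot;
  for (int i : variable_indices) {
    snapshot.insert({i, Snapshot{i}});
  }
  return snapshot;
}

int main() {
  // Resource variables at inputs 1 and 4, interleaved with other arguments.
  for (const auto& entry : SnapshotByIndex({1, 4})) {
    std::cout << "snapshot taken for input " << entry.first << "\n";
  }
  return 0;
}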
diff --git a/tensorflow/compiler/jit/xla_tensor.h b/tensorflow/compiler/jit/xla_tensor.h
index 922a918973..6b29c82ec1 100644
--- a/tensorflow/compiler/jit/xla_tensor.h
+++ b/tensorflow/compiler/jit/xla_tensor.h
@@ -54,7 +54,7 @@ class XlaTensor {
// Some Tensors can have complex on-device shapes, including tuple shapes. To
// manage the memory for these tensors a ShapedBuffer may be required.
- // Return true if this TensorInfo contains a ShapedBuffer.
+ // Return true if this XlaTensor contains a ShapedBuffer.
bool has_shaped_buffer() const { return shaped_buffer_ != nullptr; }
// Return the contained ShapedBuffer.
// REQUIRES: has_shaped_buffer()
@@ -62,7 +62,7 @@ class XlaTensor {
CHECK(has_shaped_buffer());
return *shaped_buffer_;
}
- // Mutates the TensorInfo to set the ShapedBuffer.
+ // Mutates the XlaTensor to set the ShapedBuffer.
void set_shaped_buffer(xla::ScopedShapedBuffer shaped_buffer) {
shaped_buffer_ =
xla::MakeUnique<xla::ScopedShapedBuffer>(std::move(shaped_buffer));
@@ -72,7 +72,7 @@ class XlaTensor {
// in on-demand mode to avoid re-copying values from the device if we know the
// host value already.
- // Return true if this TensorInfo contains a host tensor.
+ // Return true if this XlaTensor contains a host tensor.
bool has_host_tensor() const { return host_tensor_ != nullptr; }
// Return the contained host tensor.
// REQUIRES: has_host_tensor()
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index a94b298f87..9791792f29 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -300,6 +300,10 @@ tf_xla_py_test(
name = "extract_image_patches_op_test",
size = "small",
srcs = ["extract_image_patches_op_test.py"],
+ tags = [
+ "manual",
+ "notap",
+ ],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
@@ -323,7 +327,11 @@ tf_xla_py_test(
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:layers",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:nn",
"//tensorflow/python:platform_test",
+ "//tensorflow/python/eager:function",
],
)
diff --git a/tensorflow/compiler/tests/eager_test.py b/tensorflow/compiler/tests/eager_test.py
index bdd0185dfe..5ab1585f8c 100644
--- a/tensorflow/compiler/tests/eager_test.py
+++ b/tensorflow/compiler/tests/eager_test.py
@@ -24,10 +24,16 @@ from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
+from tensorflow.python.eager import function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.layers import convolutional
+from tensorflow.python.layers import pooling
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import googletest
@@ -43,7 +49,7 @@ class EagerTest(XLATestCase):
def testExecuteListOutputLen0(self):
with self.test_scope():
- empty = constant_op.constant([], dtype=dtypes.int32)
+ empty = constant_op.constant([], dtype=dtypes.float32)
result = array_ops.unstack(empty, 0)
self.assertTrue(isinstance(result, list))
self.assertEqual(0, len(result))
@@ -51,7 +57,7 @@ class EagerTest(XLATestCase):
def testExecuteListOutputLen1(self):
with self.test_scope():
split_dim = constant_op.constant(1)
- value = constant_op.constant([[0, 1, 2], [3, 4, 5]])
+ value = constant_op.constant([[0., 1., 2.], [3., 4., 5.]])
result = array_ops.split(value, 1, axis=split_dim)
self.assertTrue(isinstance(result, list))
self.assertEqual(1, len(result))
@@ -60,7 +66,7 @@ class EagerTest(XLATestCase):
def testExecuteListOutputLen3(self):
with self.test_scope():
split_dim = constant_op.constant(1)
- value = constant_op.constant([[0, 1, 2], [3, 4, 5]])
+ value = constant_op.constant([[0., 1., 2.], [3., 4., 5.]])
result = array_ops.split(value, 3, axis=split_dim)
self.assertTrue(isinstance(result, list))
self.assertEqual(3, len(result))
@@ -131,7 +137,105 @@ class EagerTest(XLATestCase):
self.assertEqual(2., grads[0][0].numpy())
-if __name__ == "__main__":
+class EagerFunctionTest(XLATestCase):
+
+ def testBasic(self):
+ with self.test_scope():
+ matmul = function.defun(math_ops.matmul, compiled=True)
+ t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
+ sq = matmul(t, t, transpose_a=True)
+ self.assertAllEqual(sq.numpy().reshape(-1), [10, 14, 14, 20])
+
+ def testConv(self):
+ if 'GPU' in self.device:
+ # TODO(b/32333178)
+ self.skipTest('Current implementation of RandomStandardNormal kernel '
+ 'is very slow on GPU, and has been blacklisted.')
+ with self.test_scope():
+ data_format = 'channels_last'
+ conv = convolutional.Conv2D(
+ filters=1, kernel_size=2, padding='VALID',
+ data_format=data_format, activation=nn_ops.relu,
+ kernel_initializer=init_ops.ones_initializer(),
+ bias_initializer=init_ops.zeros_initializer())
+ pool = pooling.MaxPooling2D(2, 2, data_format=data_format)
+
+ def model(x):
+ x = conv(x)
+ return pool(x)
+ model = function.defun(model, compiled=True)
+
+ x = array_ops.ones([1, 4, 4, 1])
+ y = model(x)
+ self.assertAllEqual(y.numpy(), [[[[4.]]]])
+
+ def testReadVariable(self):
+ with self.test_scope():
+ v = resource_variable_ops.ResourceVariable(1.0)
+
+ @function.defun(compiled=True)
+ def f():
+ return v.read_value()
+
+ var = f()
+ self.assertEqual(1.0, var.numpy())
+
+ def testUpdateVariable(self):
+ with self.test_scope():
+ v = resource_variable_ops.ResourceVariable(1.0)
+
+ def f(v):
+ v.assign_add(1.0)
+ return v
+
+ f = function.defun(f, compiled=True)
+
+ var = f(v)
+ self.assertEqual(2.0, var.numpy())
+
+ def testAllArgumentKinds(self):
+ """Test a complex function that takes different argument kinds.
+
+    The tf2xla machinery that translates, compiles, and runs defuns
+    classifies arguments into three kinds: compile-time constants, regular
+    tensors, and resources. This test creates a function with a mix of all
+    these kinds, and the order of its arguments is intentionally mixed up.
+
+    It also covers the case where the same argument is both a compile-time
+    constant and an input to an operation that normally expects its inputs
+    to be in device memory (addition, in this case).
+ """
+ with self.test_scope():
+ def foo(c1, r1, v1, c2, v2, r2):
+ # c1 and c2 are compile-time constants
+ # r1 and r2 are regular tensors
+ # v1 and v2 are resource variables
+ a = c1 + r1
+ b = math_ops.cast(c2, dtypes.float32) + v2
+ c = array_ops.slice(v1, c1, c2)
+ d = r2 * v2
+ return a, b, c, d
+
+ foo = function.defun(foo, compiled=True)
+
+ c1 = [0, 0]
+ c2 = array_ops.ones([2], dtype=dtypes.int32)
+
+ r1 = array_ops.ones([2])
+ r2 = [[2., 2.], [3., 3.]]
+
+ v1 = resource_variable_ops.ResourceVariable([[1., 2.], [3., 4.]])
+ v2 = resource_variable_ops.ResourceVariable([[10., 20.], [30., 40.]])
+
+ a, b, c, d = foo(c1, r1, v1, c2, v2, r2)
+
+ self.assertAllEqual([1, 1], a.numpy())
+ self.assertAllEqual([[11., 21.], [31., 41.]], b.numpy())
+ self.assertAllEqual([[1.]], c.numpy())
+ self.assertAllEqual([[20., 40.], [90., 120.]], d.numpy())
+
+
+if __name__ == '__main__':
ops.enable_eager_execution(
config=config_pb2.ConfigProto(log_device_placement=True))
googletest.main()
diff --git a/tensorflow/compiler/tests/stateless_random_ops_test.py b/tensorflow/compiler/tests/stateless_random_ops_test.py
index 4336ebdbd1..b6f8390a45 100644
--- a/tensorflow/compiler/tests/stateless_random_ops_test.py
+++ b/tensorflow/compiler/tests/stateless_random_ops_test.py
@@ -86,6 +86,15 @@ class StatelessRandomOpsTest(XLATestCase):
# seed were not fixed.
self.assertTrue(self._chi_squared(y, 10) < 16.92)
+ def testRandomNormalIsFinite(self):
+ with self.test_session() as sess, self.test_scope():
+ for dtype in self._random_types():
+ seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
+      x = stateless.stateless_random_normal(
+          shape=[10000], seed=seed_t, dtype=dtype)
+ y = sess.run(x, {seed_t: [0x12345678, 0xabcdef12]})
+ self.assertTrue(np.all(np.isfinite(y)))
+
def _normal_cdf(self, x):
"""Cumulative distribution function for a standard normal distribution."""
return 0.5 + 0.5 * np.vectorize(math.erf)(x / math.sqrt(2))
diff --git a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc
index 6340c22518..a99d4ddc7c 100644
--- a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc
@@ -255,7 +255,8 @@ class StatelessRandomNormalOp : public XlaOpKernel {
seed_shape.DebugString()));
xla::XlaOp seed = ctx->Input(1);
xla::XlaBuilder* builder = ctx->builder();
- auto uniform = RandomUniform(builder, seed, shape, -1.0, 1.0);
+ auto uniform =
+ RandomUniform(builder, seed, shape, std::nextafter(-1.0f, 0.0f), 1.0);
// Convert uniform distribution to normal distribution by computing
// sqrt(2) * erfinv(x)
auto normal = builder->Mul(builder->ConstantR0<float>(std::sqrt(2.0)),
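
Presumably the std::nextafter change above nudges the uniform distribution's lower bound just above -1.0 so that the erfinv-based conversion to a normal distribution never sees exactly -1 (where erfinv diverges), which is what the new testRandomNormalIsFinite test exercises. A tiny standalone illustration (plain C++, no XLA) of what std::nextafter returns:

#include <cmath>
#include <iomanip>
#include <iostream>

int main() {
  const float closed_lo = -1.0f;
  const float open_lo = std::nextafter(-1.0f, 0.0f);  // next float above -1.0f
  std::cout << std::setprecision(10);
  std::cout << "closed lower bound: " << closed_lo << "\n";
  std::cout << "nextafter(-1, 0):   " << open_lo << "\n";
  std::cout << "strictly greater:   " << std::boolalpha
            << (open_lo > closed_lo) << "\n";
  return 0;
}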
diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD
index 1af9cb6d2a..dbf14f32bc 100644
--- a/tensorflow/compiler/xla/BUILD
+++ b/tensorflow/compiler/xla/BUILD
@@ -99,6 +99,7 @@ cc_library(
hdrs = ["service_interface.h"],
visibility = [":friends"],
deps = [
+ ":xla_data_proto",
":xla_proto",
"//tensorflow/core:lib",
],
diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD
index ecb87bd889..932cce943f 100644
--- a/tensorflow/compiler/xla/python/BUILD
+++ b/tensorflow/compiler/xla/python/BUILD
@@ -49,9 +49,10 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:client_library",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:executable_build_options",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla/service:shaped_buffer",
"//tensorflow/core:framework_lite",
"//tensorflow/core:lib",
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc
index 044458164f..df262c97bf 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.cc
+++ b/tensorflow/compiler/xla/python/local_computation_builder.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/python/local_computation_builder.h"
#include "tensorflow/compiler/xla/executable_run_options.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/default/thread_annotations.h"
@@ -248,7 +249,7 @@ LocalShapedBuffer* CompiledLocalComputation::ExecuteWithShapedBuffers(
return new LocalShapedBuffer(std::move(result_buffer));
}
-LocalComputation::LocalComputation(Computation computation)
+LocalComputation::LocalComputation(XlaComputation computation)
: computation_(std::move(computation)) {}
StatusOr<CompiledLocalComputation*> LocalComputation::Compile(
@@ -271,7 +272,7 @@ StatusOr<CompiledLocalComputation*> LocalComputation::Compile(
return new CompiledLocalComputation(std::move(local_executable));
}
-const Computation& LocalComputation::computation() const {
+const XlaComputation& LocalComputation::computation() const {
return computation_;
}
@@ -281,8 +282,12 @@ StatusOr<Shape> LocalComputation::GetReturnValueShape() const {
return std::move(*program_shape.mutable_result());
}
+LocalOp::LocalOp(const XlaOp& op) : op_(op) {}
+
+const XlaOp& LocalOp::op() const { return op_; }
+
LocalComputationBuilder::LocalComputationBuilder(const string& computation_name)
- : builder_(GetOrCreateLocalClient(), computation_name) {}
+ : builder_(computation_name) {}
void LocalComputationBuilder::SetOpMetadata(const OpMetadata& metadata) {
builder_.SetOpMetadata(metadata);
@@ -291,19 +296,21 @@ void LocalComputationBuilder::SetOpMetadata(const OpMetadata& metadata) {
void LocalComputationBuilder::ClearOpMetadata() { builder_.ClearOpMetadata(); }
StatusOr<LocalComputation*> LocalComputationBuilder::Build() {
- TF_ASSIGN_OR_RETURN(Computation computation, builder_.Build());
+ TF_ASSIGN_OR_RETURN(XlaComputation computation, builder_.Build());
return new LocalComputation(std::move(computation));
}
-ComputationDataHandle LocalComputationBuilder::Parameter(int64 parameter_number,
- const Shape& shape,
- const string& name) {
+LocalOp LocalComputationBuilder::Parameter(int64 parameter_number,
+ const Shape& shape,
+ const string& name) {
return builder_.Parameter(parameter_number, shape, name);
}
std::unique_ptr<Shape> LocalComputationBuilder::GetShape(
- const ComputationDataHandle& operand) {
- return builder_.GetShape(operand).ConsumeValueOrDie();
+ const LocalOp& operand) {
+ auto result = MakeUnique<Shape>();
+ *result = builder_.GetShape(operand.op()).ValueOrDie();
+ return result;
}
StatusOr<Shape> LocalComputationBuilder::GetReturnValueShape() {
@@ -311,222 +318,236 @@ StatusOr<Shape> LocalComputationBuilder::GetReturnValueShape() {
return program_shape.result();
}
-ComputationDataHandle LocalComputationBuilder::Infeed(const Shape& shape) {
+LocalOp LocalComputationBuilder::Infeed(const Shape& shape) {
return builder_.Infeed(shape);
}
-void LocalComputationBuilder::Outfeed(const ComputationDataHandle& operand,
+void LocalComputationBuilder::Outfeed(const LocalOp& operand,
const Shape& shape,
const string& outfeed_config) {
- builder_.Outfeed(operand, shape, outfeed_config);
+ builder_.Outfeed(operand.op(), shape, outfeed_config);
}
-ComputationDataHandle LocalComputationBuilder::ConstantLiteral(
- const Literal& literal) {
+LocalOp LocalComputationBuilder::ConstantLiteral(const Literal& literal) {
return builder_.ConstantLiteral(literal);
}
-ComputationDataHandle LocalComputationBuilder::Broadcast(
- const ComputationDataHandle& operand,
+LocalOp LocalComputationBuilder::Broadcast(
+ const LocalOp& operand,
tensorflow::gtl::ArraySlice<int64> broadcast_sizes) {
- return builder_.Broadcast(operand, broadcast_sizes);
+ return builder_.Broadcast(operand.op(), broadcast_sizes);
}
-ComputationDataHandle LocalComputationBuilder::Pad(
- const ComputationDataHandle& operand,
- const ComputationDataHandle& padding_value,
- const PaddingConfig& padding_config) {
- return builder_.Pad(operand, padding_value, padding_config);
+LocalOp LocalComputationBuilder::Pad(const LocalOp& operand,
+ const LocalOp& padding_value,
+ const PaddingConfig& padding_config) {
+ return builder_.Pad(operand.op(), padding_value.op(), padding_config);
}
-ComputationDataHandle LocalComputationBuilder::Reshape(
- const ComputationDataHandle& operand,
- tensorflow::gtl::ArraySlice<int64> dimensions,
+LocalOp LocalComputationBuilder::Reshape(
+ const LocalOp& operand, tensorflow::gtl::ArraySlice<int64> dimensions,
tensorflow::gtl::ArraySlice<int64> new_sizes) {
- return builder_.Reshape(operand, dimensions, new_sizes);
+ return builder_.Reshape(operand.op(), dimensions, new_sizes);
}
-ComputationDataHandle LocalComputationBuilder::Collapse(
- const ComputationDataHandle& operand,
- tensorflow::gtl::ArraySlice<int64> dimensions) {
- return builder_.Collapse(operand, dimensions);
+LocalOp LocalComputationBuilder::Collapse(
+ const LocalOp& operand, tensorflow::gtl::ArraySlice<int64> dimensions) {
+ return builder_.Collapse(operand.op(), dimensions);
}
-ComputationDataHandle LocalComputationBuilder::CrossReplicaSum(
- const ComputationDataHandle& operand) {
- return builder_.CrossReplicaSum(operand);
+LocalOp LocalComputationBuilder::CrossReplicaSum(const LocalOp& operand) {
+ return builder_.CrossReplicaSum(operand.op());
}
-ComputationDataHandle LocalComputationBuilder::Slice(
- const ComputationDataHandle& operand,
- tensorflow::gtl::ArraySlice<int64> start_indices,
+LocalOp LocalComputationBuilder::Slice(
+ const LocalOp& operand, tensorflow::gtl::ArraySlice<int64> start_indices,
tensorflow::gtl::ArraySlice<int64> limit_indices,
tensorflow::gtl::ArraySlice<int64> strides) {
- return builder_.Slice(operand, start_indices, limit_indices, strides);
+ return builder_.Slice(operand.op(), start_indices, limit_indices, strides);
}
-ComputationDataHandle LocalComputationBuilder::SliceInDim(
- const ComputationDataHandle& operand, int64 start_index, int64 limit_index,
- int64 stride, int64 dimno) {
- return builder_.SliceInDim(operand, start_index, limit_index, stride, dimno);
+LocalOp LocalComputationBuilder::SliceInDim(const LocalOp& operand,
+ int64 start_index,
+ int64 limit_index, int64 stride,
+ int64 dimno) {
+ return builder_.SliceInDim(operand.op(), start_index, limit_index, stride,
+ dimno);
}
-ComputationDataHandle LocalComputationBuilder::DynamicSlice(
- const ComputationDataHandle& operand,
- const ComputationDataHandle& start_indices,
+LocalOp LocalComputationBuilder::DynamicSlice(
+ const LocalOp& operand, const LocalOp& start_indices,
tensorflow::gtl::ArraySlice<int64> slice_sizes) {
- return builder_.DynamicSlice(operand, start_indices, slice_sizes);
+ return builder_.DynamicSlice(operand.op(), start_indices.op(), slice_sizes);
}
-ComputationDataHandle LocalComputationBuilder::DynamicUpdateSlice(
- const ComputationDataHandle& operand, const ComputationDataHandle& update,
- const ComputationDataHandle& start_indices) {
- return builder_.DynamicUpdateSlice(operand, update, start_indices);
+LocalOp LocalComputationBuilder::DynamicUpdateSlice(
+ const LocalOp& operand, const LocalOp& update,
+ const LocalOp& start_indices) {
+ return builder_.DynamicUpdateSlice(operand.op(), update.op(),
+ start_indices.op());
}
-ComputationDataHandle LocalComputationBuilder::ConcatInDim(
- tensorflow::gtl::ArraySlice<ComputationDataHandle> operands,
- int64 dimension) {
- return builder_.ConcatInDim(operands, dimension);
+LocalOp LocalComputationBuilder::ConcatInDim(
+ tensorflow::gtl::ArraySlice<LocalOp> operands, int64 dimension) {
+ std::vector<XlaOp> xla_ops;
+ xla_ops.reserve(operands.size());
+ for (const auto& op : operands) {
+ xla_ops.push_back(op.op());
+ }
+ return builder_.ConcatInDim(xla_ops, dimension);
}
-ComputationDataHandle
-LocalComputationBuilder::SelectAndScatterWithGeneralPadding(
- const ComputationDataHandle& operand, const LocalComputation& select,
+LocalOp LocalComputationBuilder::SelectAndScatterWithGeneralPadding(
+ const LocalOp& operand, const LocalComputation& select,
tensorflow::gtl::ArraySlice<int64> window_dimensions,
tensorflow::gtl::ArraySlice<int64> window_strides,
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
- const ComputationDataHandle& source,
- const ComputationDataHandle& init_value, const LocalComputation& scatter) {
+ const LocalOp& source, const LocalOp& init_value,
+ const LocalComputation& scatter) {
return builder_.SelectAndScatterWithGeneralPadding(
- operand, select.computation(), window_dimensions, window_strides, padding,
- source, init_value, scatter.computation());
+ operand.op(), select.computation(), window_dimensions, window_strides,
+ padding, source.op(), init_value.op(), scatter.computation());
}
-ComputationDataHandle LocalComputationBuilder::Tuple(
- tensorflow::gtl::ArraySlice<ComputationDataHandle> elements) {
- return builder_.Tuple(elements);
+LocalOp LocalComputationBuilder::Tuple(
+ tensorflow::gtl::ArraySlice<LocalOp> elements) {
+ std::vector<XlaOp> xla_ops;
+ xla_ops.reserve(elements.size());
+ for (const auto& op : elements) {
+ xla_ops.push_back(op.op());
+ }
+
+ return builder_.Tuple(xla_ops);
}
-ComputationDataHandle LocalComputationBuilder::GetTupleElement(
- const ComputationDataHandle& tuple_data, int64 index) {
- return builder_.GetTupleElement(tuple_data, index);
+LocalOp LocalComputationBuilder::GetTupleElement(const LocalOp& tuple_data,
+ int64 index) {
+ return builder_.GetTupleElement(tuple_data.op(), index);
}
-ComputationDataHandle LocalComputationBuilder::Dot(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs) {
- return builder_.Dot(lhs, rhs);
+LocalOp LocalComputationBuilder::Dot(const LocalOp& lhs, const LocalOp& rhs) {
+ return builder_.Dot(lhs.op(), rhs.op());
}
-ComputationDataHandle LocalComputationBuilder::DotGeneral(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
+LocalOp LocalComputationBuilder::DotGeneral(
+ const LocalOp& lhs, const LocalOp& rhs,
const DotDimensionNumbers& dimension_numbers) {
- return builder_.DotGeneral(lhs, rhs, dimension_numbers);
+ return builder_.DotGeneral(lhs.op(), rhs.op(), dimension_numbers);
}
-ComputationDataHandle LocalComputationBuilder::ConvGeneralDilated(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
+LocalOp LocalComputationBuilder::ConvGeneralDilated(
+ const LocalOp& lhs, const LocalOp& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides,
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
tensorflow::gtl::ArraySlice<int64> lhs_dilation,
tensorflow::gtl::ArraySlice<int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers) {
- return builder_.ConvGeneralDilated(lhs, rhs, window_strides, padding,
- lhs_dilation, rhs_dilation,
+ return builder_.ConvGeneralDilated(lhs.op(), rhs.op(), window_strides,
+ padding, lhs_dilation, rhs_dilation,
dimension_numbers);
}
-ComputationDataHandle LocalComputationBuilder::ConvertElementType(
- const ComputationDataHandle& operand, PrimitiveType new_element_type) {
- return builder_.ConvertElementType(operand, new_element_type);
+LocalOp LocalComputationBuilder::ConvertElementType(
+ const LocalOp& operand, PrimitiveType new_element_type) {
+ return builder_.ConvertElementType(operand.op(), new_element_type);
}
-ComputationDataHandle LocalComputationBuilder::Call(
+LocalOp LocalComputationBuilder::Call(
const LocalComputation& local_computation,
- tensorflow::gtl::ArraySlice<ComputationDataHandle> operands) {
- return builder_.Call(local_computation.computation(), operands);
+ tensorflow::gtl::ArraySlice<LocalOp> operands) {
+ std::vector<XlaOp> xla_ops;
+ xla_ops.reserve(operands.size());
+ for (const auto& op : operands) {
+ xla_ops.push_back(op.op());
+ }
+ return builder_.Call(local_computation.computation(), xla_ops);
}
-ComputationDataHandle LocalComputationBuilder::Transpose(
- const ComputationDataHandle& operand,
- tensorflow::gtl::ArraySlice<int64> permutation) {
- return builder_.Transpose(operand, permutation);
+LocalOp LocalComputationBuilder::Transpose(
+ const LocalOp& operand, tensorflow::gtl::ArraySlice<int64> permutation) {
+ return builder_.Transpose(operand.op(), permutation);
}
-ComputationDataHandle LocalComputationBuilder::Rev(
- const ComputationDataHandle& operand,
- tensorflow::gtl::ArraySlice<int64> dimensions) {
- return builder_.Rev(operand, dimensions);
+LocalOp LocalComputationBuilder::Rev(
+ const LocalOp& operand, tensorflow::gtl::ArraySlice<int64> dimensions) {
+ return builder_.Rev(operand.op(), dimensions);
}
-ComputationDataHandle LocalComputationBuilder::Map(
- tensorflow::gtl::ArraySlice<ComputationDataHandle> operands,
+LocalOp LocalComputationBuilder::Map(
+ tensorflow::gtl::ArraySlice<LocalOp> operands,
const LocalComputation& local_computation,
tensorflow::gtl::ArraySlice<int64> dimensions,
- tensorflow::gtl::ArraySlice<ComputationDataHandle> static_operands) {
- return builder_.Map(operands, local_computation.computation(), dimensions,
- static_operands);
+ tensorflow::gtl::ArraySlice<LocalOp> static_operands) {
+ std::vector<XlaOp> xla_ops;
+ xla_ops.reserve(operands.size());
+ for (const auto& op : operands) {
+ xla_ops.push_back(op.op());
+ }
+
+ std::vector<XlaOp> static_xla_ops;
+ static_xla_ops.reserve(static_operands.size());
+ for (const auto& op : static_operands) {
+ static_xla_ops.push_back(op.op());
+ }
+
+ return builder_.Map(xla_ops, local_computation.computation(), dimensions,
+ static_xla_ops);
}
-ComputationDataHandle LocalComputationBuilder::Reduce(
- const ComputationDataHandle& operand,
- const ComputationDataHandle& init_value,
+LocalOp LocalComputationBuilder::Reduce(
+ const LocalOp& operand, const LocalOp& init_value,
const LocalComputation& local_computation,
tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce) {
- return builder_.Reduce(operand, init_value, local_computation.computation(),
- dimensions_to_reduce);
+ return builder_.Reduce(operand.op(), init_value.op(),
+ local_computation.computation(), dimensions_to_reduce);
}
-ComputationDataHandle LocalComputationBuilder::ReduceWindowWithGeneralPadding(
- const ComputationDataHandle& operand,
- const ComputationDataHandle& init_value,
+LocalOp LocalComputationBuilder::ReduceWindowWithGeneralPadding(
+ const LocalOp& operand, const LocalOp& init_value,
const LocalComputation& local_computation,
tensorflow::gtl::ArraySlice<int64> window_dimensions,
tensorflow::gtl::ArraySlice<int64> window_strides,
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding) {
return builder_.ReduceWindowWithGeneralPadding(
- operand, init_value, local_computation.computation(), window_dimensions,
- window_strides, padding);
+ operand.op(), init_value.op(), local_computation.computation(),
+ window_dimensions, window_strides, padding);
}
-ComputationDataHandle LocalComputationBuilder::RngNormal(
- const ComputationDataHandle& mu, const ComputationDataHandle& sigma,
- const Shape& shape) {
- return builder_.RngNormal(mu, sigma, shape);
+LocalOp LocalComputationBuilder::RngNormal(const LocalOp& mu,
+ const LocalOp& sigma,
+ const Shape& shape) {
+ return builder_.RngNormal(mu.op(), sigma.op(), shape);
}
-ComputationDataHandle LocalComputationBuilder::RngUniform(
- const ComputationDataHandle& a, const ComputationDataHandle& b,
- const Shape& shape) {
- return builder_.RngUniform(a, b, shape);
+LocalOp LocalComputationBuilder::RngUniform(const LocalOp& a, const LocalOp& b,
+ const Shape& shape) {
+ return builder_.RngUniform(a.op(), b.op(), shape);
}
-ComputationDataHandle LocalComputationBuilder::While(
- const LocalComputation& condition, const LocalComputation& body,
- const ComputationDataHandle& init) {
- return builder_.While(condition.computation(), body.computation(), init);
+LocalOp LocalComputationBuilder::While(const LocalComputation& condition,
+ const LocalComputation& body,
+ const LocalOp& init) {
+ return builder_.While(condition.computation(), body.computation(), init.op());
}
-ComputationDataHandle LocalComputationBuilder::Conditional(
- const ComputationDataHandle& predicate,
- const ComputationDataHandle& true_operand,
- const LocalComputation& true_computation,
- const ComputationDataHandle& false_operand,
+LocalOp LocalComputationBuilder::Conditional(
+ const LocalOp& predicate, const LocalOp& true_operand,
+ const LocalComputation& true_computation, const LocalOp& false_operand,
const LocalComputation& false_computation) {
- return builder_.Conditional(predicate, true_operand,
- true_computation.computation(), false_operand,
- false_computation.computation());
+ return builder_.Conditional(
+ predicate.op(), true_operand.op(), true_computation.computation(),
+ false_operand.op(), false_computation.computation());
}
-StatusOr<bool> LocalComputationBuilder::IsConstant(
- const ComputationDataHandle& operand, int64 num_parameters) {
- return builder_.IsConstant(operand, num_parameters);
+StatusOr<bool> LocalComputationBuilder::IsConstant(const LocalOp& operand) {
+ return builder_.IsConstant(operand.op());
}
-StatusOr<std::unique_ptr<Literal>> LocalComputationBuilder::ComputeConstant(
- const ComputationDataHandle& operand, const Layout* output_layout,
- tensorflow::gtl::ArraySlice<Literal> parameters) {
- return builder_.ComputeConstant(operand, output_layout, parameters);
+StatusOr<LocalComputation*> LocalComputationBuilder::BuildConstantSubGraph(
+ const LocalOp& operand) {
+ TF_ASSIGN_OR_RETURN(XlaComputation computation,
+ builder_.BuildConstantSubGraph(operand.op()));
+ return new LocalComputation(std::move(computation));
}
#define _FORWARD(method_name, return_sig, args_sig, args) \
@@ -534,23 +555,19 @@ StatusOr<std::unique_ptr<Literal>> LocalComputationBuilder::ComputeConstant(
return builder_.method_name args; \
}
-#define _FORWARD_UNOP(method_name) \
- _FORWARD(method_name, ComputationDataHandle, \
- (const ComputationDataHandle& operand), (operand))
-
-#define _FORWARD_BINOP(method_name) \
- _FORWARD( \
- method_name, ComputationDataHandle, \
- (const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, \
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions), \
- (lhs, rhs, broadcast_dimensions))
-
-#define _FORWARD_TRIOP(method_name) \
- _FORWARD( \
- method_name, ComputationDataHandle, \
- (const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, \
- const ComputationDataHandle& ehs), \
- (lhs, rhs, ehs))
+#define _FORWARD_UNOP(method_name) \
+ _FORWARD(method_name, LocalOp, (const LocalOp& operand), (operand.op()))
+
+#define _FORWARD_BINOP(method_name) \
+ _FORWARD(method_name, LocalOp, \
+ (const LocalOp& lhs, const LocalOp& rhs, \
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions), \
+ (lhs.op(), rhs.op(), broadcast_dimensions))
+
+#define _FORWARD_TRIOP(method_name) \
+ _FORWARD(method_name, LocalOp, \
+ (const LocalOp& lhs, const LocalOp& rhs, const LocalOp& ehs), \
+ (lhs.op(), rhs.op(), ehs.op()))
_FORWARD_TRIOP(Select)
_FORWARD_TRIOP(Clamp)
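
The ConcatInDim, Tuple, Call, and Map changes above all repeat the same unwrap loop from LocalOp back to XlaOp before forwarding to the builder. A generic standalone sketch (plain C++, stand-in types, not the XLA API) of that thin-wrapper pattern:

#include <iostream>
#include <vector>

struct Handle {  // stand-in for xla::XlaOp
  int id;
};

class Wrapper {  // stand-in for xla::swig::LocalOp
 public:
  explicit Wrapper(Handle h) : h_(h) {}
  const Handle& op() const { return h_; }

 private:
  Handle h_;
};

// Unwraps a sequence of wrappers into the underlying handles, mirroring the
// loop the builder methods above run before calling into XlaBuilder.
std::vector<Handle> Unwrap(const std::vector<Wrapper>& ops) {
  std::vector<Handle> raw;
  raw.reserve(ops.size());
  for (const Wrapper& w : ops) {
    raw.push_back(w.op());
  }
  return raw;
}

int main() {
  std::vector<Wrapper> ops = {Wrapper(Handle{1}), Wrapper(Handle{2})};
  for (const Handle& h : Unwrap(ops)) {
    std::cout << "handle " << h.id << "\n";
  }
  return 0;
}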
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h
index 5ec097846a..a06b85b4ea 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.h
+++ b/tensorflow/compiler/xla/python/local_computation_builder.h
@@ -17,9 +17,10 @@ limitations under the License.
#define TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_COMPUTATION_BUILDER_H_
#include "tensorflow/compiler/xla/client/client_library.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/executable_build_options.h"
#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
@@ -97,25 +98,37 @@ class CompiledLocalComputation {
std::unique_ptr<LocalExecutable> executable_;
};
-// Wraps a Computation produced by a LocalComputationBuilder. The
+// Wraps an XlaComputation produced by a LocalComputationBuilder. The
// Compile method compiles the computation to a (local) executable via
// the client library's local client. This class is intended to be
// made available to Python via SWIG.
class LocalComputation {
public:
- LocalComputation(Computation computation);
+ LocalComputation(XlaComputation computation);
StatusOr<CompiledLocalComputation*> Compile(
const std::vector<Shape>& argument_shapes,
const ExecutableBuildOptions* build_options);
- const Computation& computation() const;
+ const XlaComputation& computation() const;
// Returns the return-value shape for this computation.
StatusOr<Shape> GetReturnValueShape() const;
private:
- Computation computation_;
+ XlaComputation computation_;
+};
+
+// Wraps an XlaOp produced by a LocalComputationBuilder. This class is intended
+// to be made available to Python via SWIG.
+class LocalOp {
+ public:
+ LocalOp(const XlaOp& op);
+
+ const XlaOp& op() const;
+
+ private:
+ XlaOp op_;
};
// Wraps the ComputationBuilder API in order to:
@@ -135,166 +148,137 @@ class LocalComputationBuilder {
// Returns an owned LocalComputation to the caller on success.
StatusOr<LocalComputation*> Build();
- ComputationDataHandle Parameter(int64 parameter_number, const Shape& shape,
- const string& name);
+ LocalOp Parameter(int64 parameter_number, const Shape& shape,
+ const string& name);
- std::unique_ptr<Shape> GetShape(const ComputationDataHandle& operand);
+ std::unique_ptr<Shape> GetShape(const LocalOp& operand);
// Returns the shape of the current return value for the computation.
StatusOr<Shape> GetReturnValueShape();
- ComputationDataHandle Infeed(const Shape& shape);
+ LocalOp Infeed(const Shape& shape);
- void Outfeed(const ComputationDataHandle& operand, const Shape& shape,
+ void Outfeed(const LocalOp& operand, const Shape& shape,
const string& outfeed_config);
- ComputationDataHandle ConstantLiteral(const Literal& literal);
+ LocalOp ConstantLiteral(const Literal& literal);
- ComputationDataHandle Broadcast(
- const ComputationDataHandle& operand,
- tensorflow::gtl::ArraySlice<int64> broadcast_sizes);
+ LocalOp Broadcast(const LocalOp& operand,
+ tensorflow::gtl::ArraySlice<int64> broadcast_sizes);
- ComputationDataHandle Pad(const ComputationDataHandle& operand,
- const ComputationDataHandle& padding_value,
- const PaddingConfig& padding_config);
+ LocalOp Pad(const LocalOp& operand, const LocalOp& padding_value,
+ const PaddingConfig& padding_config);
- ComputationDataHandle Reshape(const ComputationDataHandle& operand,
- tensorflow::gtl::ArraySlice<int64> dimensions,
- tensorflow::gtl::ArraySlice<int64> new_sizes);
+ LocalOp Reshape(const LocalOp& operand,
+ tensorflow::gtl::ArraySlice<int64> dimensions,
+ tensorflow::gtl::ArraySlice<int64> new_sizes);
- ComputationDataHandle Collapse(const ComputationDataHandle& operand,
- tensorflow::gtl::ArraySlice<int64> dimensions);
+ LocalOp Collapse(const LocalOp& operand,
+ tensorflow::gtl::ArraySlice<int64> dimensions);
- ComputationDataHandle CrossReplicaSum(const ComputationDataHandle& operand);
+ LocalOp CrossReplicaSum(const LocalOp& operand);
- ComputationDataHandle Slice(const ComputationDataHandle& operand,
- tensorflow::gtl::ArraySlice<int64> start_indices,
- tensorflow::gtl::ArraySlice<int64> limit_indices,
- tensorflow::gtl::ArraySlice<int64> strides);
+ LocalOp Slice(const LocalOp& operand,
+ tensorflow::gtl::ArraySlice<int64> start_indices,
+ tensorflow::gtl::ArraySlice<int64> limit_indices,
+ tensorflow::gtl::ArraySlice<int64> strides);
- ComputationDataHandle SliceInDim(const ComputationDataHandle& operand,
- int64 start_index, int64 limit_index,
- int64 stride, int64 dimno);
+ LocalOp SliceInDim(const LocalOp& operand, int64 start_index,
+ int64 limit_index, int64 stride, int64 dimno);
- ComputationDataHandle DynamicSlice(
- const ComputationDataHandle& operand,
- const ComputationDataHandle& start_indices,
- tensorflow::gtl::ArraySlice<int64> slice_sizes);
+ LocalOp DynamicSlice(const LocalOp& operand, const LocalOp& start_indices,
+ tensorflow::gtl::ArraySlice<int64> slice_sizes);
- ComputationDataHandle DynamicUpdateSlice(
- const ComputationDataHandle& operand, const ComputationDataHandle& update,
- const ComputationDataHandle& start_indices);
+ LocalOp DynamicUpdateSlice(const LocalOp& operand, const LocalOp& update,
+ const LocalOp& start_indices);
- ComputationDataHandle ConcatInDim(
- tensorflow::gtl::ArraySlice<ComputationDataHandle> operands,
- int64 dimension);
+ LocalOp ConcatInDim(tensorflow::gtl::ArraySlice<LocalOp> operands,
+ int64 dimension);
- ComputationDataHandle SelectAndScatterWithGeneralPadding(
- const ComputationDataHandle& operand, const LocalComputation& select,
+ LocalOp SelectAndScatterWithGeneralPadding(
+ const LocalOp& operand, const LocalComputation& select,
tensorflow::gtl::ArraySlice<int64> window_dimensions,
tensorflow::gtl::ArraySlice<int64> window_strides,
tensorflow::gtl::ArraySlice<std::pair<int64, int64> > padding,
- const ComputationDataHandle& source,
- const ComputationDataHandle& init_value, const LocalComputation& scatter);
+ const LocalOp& source, const LocalOp& init_value,
+ const LocalComputation& scatter);
- ComputationDataHandle Tuple(
- tensorflow::gtl::ArraySlice<ComputationDataHandle> elements);
+ LocalOp Tuple(tensorflow::gtl::ArraySlice<LocalOp> elements);
- ComputationDataHandle GetTupleElement(const ComputationDataHandle& tuple_data,
- int64 index);
+ LocalOp GetTupleElement(const LocalOp& tuple_data, int64 index);
- ComputationDataHandle Dot(const ComputationDataHandle& lhs,
- const ComputationDataHandle& rhs);
+ LocalOp Dot(const LocalOp& lhs, const LocalOp& rhs);
- ComputationDataHandle DotGeneral(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- const DotDimensionNumbers& dimension_numbers);
+ LocalOp DotGeneral(const LocalOp& lhs, const LocalOp& rhs,
+ const DotDimensionNumbers& dimension_numbers);
- ComputationDataHandle ConvGeneralDilated(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
+ LocalOp ConvGeneralDilated(
+ const LocalOp& lhs, const LocalOp& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides,
tensorflow::gtl::ArraySlice<std::pair<int64, int64> > padding,
tensorflow::gtl::ArraySlice<int64> lhs_dilation,
tensorflow::gtl::ArraySlice<int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers);
- ComputationDataHandle ConvertElementType(const ComputationDataHandle& operand,
- PrimitiveType new_element_type);
+ LocalOp ConvertElementType(const LocalOp& operand,
+ PrimitiveType new_element_type);
- ComputationDataHandle Call(
- const LocalComputation& local_computation,
- tensorflow::gtl::ArraySlice<ComputationDataHandle> operands);
+ LocalOp Call(const LocalComputation& local_computation,
+ tensorflow::gtl::ArraySlice<LocalOp> operands);
- ComputationDataHandle Transpose(
- const ComputationDataHandle& operand,
- tensorflow::gtl::ArraySlice<int64> permutation);
+ LocalOp Transpose(const LocalOp& operand,
+ tensorflow::gtl::ArraySlice<int64> permutation);
- ComputationDataHandle Rev(const ComputationDataHandle& operand,
- tensorflow::gtl::ArraySlice<int64> dimensions);
+ LocalOp Rev(const LocalOp& operand,
+ tensorflow::gtl::ArraySlice<int64> dimensions);
- ComputationDataHandle Map(
- tensorflow::gtl::ArraySlice<ComputationDataHandle> operands,
- const LocalComputation& local_computation,
- tensorflow::gtl::ArraySlice<int64> dimensions,
- tensorflow::gtl::ArraySlice<ComputationDataHandle> static_operands);
+ LocalOp Map(tensorflow::gtl::ArraySlice<LocalOp> operands,
+ const LocalComputation& local_computation,
+ tensorflow::gtl::ArraySlice<int64> dimensions,
+ tensorflow::gtl::ArraySlice<LocalOp> static_operands);
- ComputationDataHandle Reduce(
- const ComputationDataHandle& operand,
- const ComputationDataHandle& init_value,
- const LocalComputation& local_computation,
- tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce);
+ LocalOp Reduce(const LocalOp& operand, const LocalOp& init_value,
+ const LocalComputation& local_computation,
+ tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce);
- ComputationDataHandle ReduceWindowWithGeneralPadding(
- const ComputationDataHandle& operand,
- const ComputationDataHandle& init_value,
+ LocalOp ReduceWindowWithGeneralPadding(
+ const LocalOp& operand, const LocalOp& init_value,
const LocalComputation& local_computation,
tensorflow::gtl::ArraySlice<int64> window_dimensions,
tensorflow::gtl::ArraySlice<int64> window_strides,
tensorflow::gtl::ArraySlice<std::pair<int64, int64> > padding);
- ComputationDataHandle RngNormal(const ComputationDataHandle& mu,
- const ComputationDataHandle& sigma,
- const Shape& shape);
+ LocalOp RngNormal(const LocalOp& mu, const LocalOp& sigma,
+ const Shape& shape);
- ComputationDataHandle RngUniform(const ComputationDataHandle& a,
- const ComputationDataHandle& b,
- const Shape& shape);
+ LocalOp RngUniform(const LocalOp& a, const LocalOp& b, const Shape& shape);
- ComputationDataHandle While(const LocalComputation& condition,
- const LocalComputation& body,
- const ComputationDataHandle& init);
+ LocalOp While(const LocalComputation& condition, const LocalComputation& body,
+ const LocalOp& init);
- ComputationDataHandle Conditional(const ComputationDataHandle& predicate,
- const ComputationDataHandle& true_operand,
- const LocalComputation& true_computation,
- const ComputationDataHandle& false_operand,
- const LocalComputation& false_computation);
+ LocalOp Conditional(const LocalOp& predicate, const LocalOp& true_operand,
+ const LocalComputation& true_computation,
+ const LocalOp& false_operand,
+ const LocalComputation& false_computation);
- StatusOr<bool> IsConstant(const ComputationDataHandle& operand,
- int64 num_parameters);
+ StatusOr<bool> IsConstant(const LocalOp& operand);
- StatusOr<std::unique_ptr<Literal> > ComputeConstant(
- const ComputationDataHandle& operand, const Layout* output_layout,
- tensorflow::gtl::ArraySlice<Literal> parameters);
+ StatusOr<LocalComputation*> BuildConstantSubGraph(const LocalOp& operand);
#define _FORWARD(method_name, return_sig, args_sig) \
return_sig method_name args_sig;
-#define _FORWARD_UNOP(method_name) \
- _FORWARD(method_name, ComputationDataHandle, \
- (const ComputationDataHandle& operand))
+#define _FORWARD_UNOP(method_name) \
+ _FORWARD(method_name, LocalOp, (const LocalOp& operand))
-#define _FORWARD_BINOP(method_name) \
- _FORWARD( \
- method_name, ComputationDataHandle, \
- (const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, \
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions))
+#define _FORWARD_BINOP(method_name) \
+ _FORWARD(method_name, LocalOp, \
+ (const LocalOp& lhs, const LocalOp& rhs, \
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions))
-#define _FORWARD_TRIOP(method_name) \
- _FORWARD( \
- method_name, ComputationDataHandle, \
- (const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, \
- const ComputationDataHandle& ehs))
+#define _FORWARD_TRIOP(method_name) \
+ _FORWARD(method_name, LocalOp, \
+ (const LocalOp& lhs, const LocalOp& rhs, const LocalOp& ehs))
_FORWARD_TRIOP(Select)
_FORWARD_TRIOP(Clamp)
@@ -338,7 +322,7 @@ class LocalComputationBuilder {
#undef _FORWARD_TRIOP
private:
- ComputationBuilder builder_;
+ XlaBuilder builder_;
};
// Functions for freeing resources from the Python side.
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i
index b8cce5a5f7..04c56bbba9 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.i
+++ b/tensorflow/compiler/xla/python/local_computation_builder.i
@@ -22,9 +22,8 @@ limitations under the License.
//
// C++ Python
// -------------------------------------+---------------------------------------
-// ComputationDataHandle <-> int
// ArraySlice<int64> <- sequence of int
-// ArraySlice<ComputationDataHandle> <- sequence of int
+// ArraySlice<LocalOp> <- sequence of LocalOp
// Literal <-> (nested tuple of) numpy ndarray
// std::vector<Literal> <- sequence of (nested tuple of) ndarray
// Shape -> pair holding (dtype, dimensions)
@@ -91,12 +90,9 @@ limitations under the License.
// One central reason for the Python-side indirection is that the
// Python-side objects produced by the typemaps in this file are
// further packaged up by xla_client before being passed on. For
-// instance, xla_client wraps the long produced for a C++
-// ComputationDataHandle in a Python ComputationDataHandle proto,
-// rather than exposing a raw long outside of the client. Similarly,
-// the Python pair produced for a C++ Shape is further wrapped in a
-// Python class (xla_client.Shape) so as not to expose the raw pair
-// externally.
+// instance, the Python pair produced for a C++ Shape is further
+// wrapped in a Python class (xla_client.Shape) so as not to expose
+// the raw pair externally.
//
// Other SWIG object wrappers (e.g. of LocalComputation) are further
// wrapped by xla_client in order to set up a custom destructor that
@@ -124,6 +120,7 @@ using namespace xla;
using namespace xla::swig;
namespace xla {
+
namespace swig {
bool GetIntAttr(PyObject* o, const char* field, int64* result) {
@@ -177,21 +174,6 @@ bool HandleStringAttribute(PyObject* o,
tensorflow::ImportNumpy();
%}
-// ComputationDataHandle
-
-%typemap(in) const ComputationDataHandle& (ComputationDataHandle temp) {
- const int64 handle = numpy::PyIntOrPyLongToLong($input);
- if (handle == -1 && PyErr_Occurred()) {
- SWIG_fail;
- }
- temp.set_handle(handle);
- $1 = &temp;
-}
-
-%typemap(out) ComputationDataHandle {
- $result = numpy::LongToPyIntOrPyLong($1.handle());
-}
-
%typemap(out) StatusOr<xla::swig::CompiledLocalComputation*> {
if ($1.ok()) {
auto* value = $1.ValueOrDie();
@@ -301,33 +283,23 @@ tensorflow::ImportNumpy();
$1 = temps;
}
-// ComputationDataHandle
+// ArraySlice<LocalOp>
-%typemap(in) tensorflow::gtl::ArraySlice<ComputationDataHandle>
- (std::vector<ComputationDataHandle> temps) {
+%typemap(in) tensorflow::gtl::ArraySlice<xla::swig::LocalOp>(
+ std::vector<LocalOp> temps) {
if (!PySequence_Check($input)) {
PyErr_SetString(PyExc_TypeError, "Argument is not a sequence");
SWIG_fail;
}
const int size = PySequence_Size($input);
- temps.resize(size);
for (int i = 0; i < size; ++i) {
PyObject* o = PySequence_GetItem($input, i);
- PyObject* py_int = numpy::PyNumberToPyInt(o);
- if (!py_int) {
- PyErr_SetString(
- PyExc_TypeError,
- "Argument sequence element cannot be converted to int");
- SWIG_fail;
- }
- const int64 handle = numpy::PyIntOrPyLongToLong(py_int);
- if (handle == -1 && PyErr_Occurred()) {
- Py_DECREF(py_int);
- Py_DECREF(o);
+ LocalOp* op;
+ if ((SWIG_ConvertPtr(o, (void**)&op, $descriptor(xla::swig::LocalOp*),
+ SWIG_POINTER_EXCEPTION)) == -1) {
SWIG_fail;
}
- temps[i].set_handle(handle);
- Py_DECREF(py_int);
+ temps.push_back(*op);
Py_DECREF(o);
}
$1 = temps;
@@ -934,6 +906,7 @@ tensorflow::ImportNumpy();
%unignore xla::swig::LocalComputation;
%unignore xla::swig::LocalComputation::Compile;
%unignore xla::swig::LocalComputation::GetReturnValueShape;
+%unignore xla::swig::LocalOp;
%unignore xla::swig::LocalComputationBuilder;
%unignore xla::swig::LocalComputationBuilder::LocalComputationBuilder;
%unignore xla::swig::LocalComputationBuilder::Build;
diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py
index f6809b6b87..1d5b75d1be 100644
--- a/tensorflow/compiler/xla/python/xla_client.py
+++ b/tensorflow/compiler/xla/python/xla_client.py
@@ -335,20 +335,6 @@ def _wrap_shape(shape_info):
return Shape.array_shape(dtype, dims)
-def _wrap_data_handle(handle):
- cdh = xla_data_pb2.ComputationDataHandle()
- cdh.handle = handle
- return cdh
-
-
-def _unwrap_data_handle(handle_proto):
- return handle_proto.handle
-
-
-def _unwrap_data_handles(handle_protos):
- return [_unwrap_data_handle(cdh) for cdh in handle_protos]
-
-
def require_numpy_array_layout(value):
if isinstance(value, tuple):
return tuple(require_numpy_array_layout(x) for x in value)
@@ -535,9 +521,9 @@ class ComputationBuilder(object):
queue for subsequent use in the computation.
Returns:
- A ComputationDataHandle message.
+ A LocalOp.
"""
- return _wrap_data_handle(self._client.Infeed(shape))
+ return self._client.Infeed(shape)
def Outfeed(self, operand):
"""Enqueues an outfeed op onto the computation.
@@ -545,9 +531,7 @@ class ComputationBuilder(object):
Outfeed operations enqueue data, using the given operand, onto the XLA
outfeed queue for subsequent dequeue via the client API.
"""
- self._client.Outfeed(
- _unwrap_data_handle(operand), self.GetShape(operand),
- ''.encode('utf-8'))
+ self._client.Outfeed(operand, self.GetShape(operand), ''.encode('utf-8'))
def Constant(self, value):
"""Enqueues a constant op onto the computation.
@@ -557,10 +541,10 @@ class ComputationBuilder(object):
to one of the supported types.
Returns:
- A ComputationDataHandle message.
+ A LocalOp.
"""
value = require_numpy_array_layout(value)
- return _wrap_data_handle(self._client.ConstantLiteral(value))
+ return self._client.ConstantLiteral(value)
def ConstantF32Scalar(self, value):
"""Convenience method to enqueue a scalar F32 constant op.
@@ -569,7 +553,7 @@ class ComputationBuilder(object):
value: a floating-point number.
Returns:
- A ComputationDataHandle message.
+ A LocalOp.
"""
return self.Constant(np.array(value, dtype=np.float32))
@@ -580,7 +564,7 @@ class ComputationBuilder(object):
value: a floating-point number.
Returns:
- A ComputationDataHandle message.
+ A LocalOp.
"""
return self.Constant(np.array(value, dtype=np.float64))
@@ -591,7 +575,7 @@ class ComputationBuilder(object):
value: a floating-point number.
Returns:
- A ComputationDataHandle message.
+ A LocalOp.
"""
return self.Constant(np.array(value, dtype=np.int32))
@@ -602,7 +586,7 @@ class ComputationBuilder(object):
value: a floating-point number.
Returns:
- A ComputationDataHandle message.
+ A LocalOp.
"""
return self.Constant(np.array(value, dtype=np.int64))
@@ -613,7 +597,7 @@ class ComputationBuilder(object):
value: a boolean value.
Returns:
- A ComputationDataHandle message.
+ A LocalOp.
"""
return self.Constant(np.array(value, dtype=np.bool))
@@ -629,15 +613,14 @@ class ComputationBuilder(object):
parameters, use it for *all* parameters to avoid clashes.
Returns:
- A ComputationDataHandle message.
+ A LocalOp.
"""
if name is None:
name = ''
if parameter_num is None:
parameter_num = next(self._parameter_numbering)
- return _wrap_data_handle(
- self._client.Parameter(parameter_num, shape, name.encode('utf8')))
+ return self._client.Parameter(parameter_num, shape, name.encode('utf8'))
def ParameterFromNumpy(self, value, name=None, parameter_num=None):
"""Enqueues a Parameter op onto the computation.
@@ -649,7 +632,7 @@ class ComputationBuilder(object):
parameter_num: as in ParameterWithShape.
Returns:
- A ComputationDataHandle message.
+ A LocalOp.
"""
return self.ParameterWithShape(
Shape.from_pyval(value), name=name, parameter_num=parameter_num)
@@ -658,14 +641,13 @@ class ComputationBuilder(object):
"""Enqueues a broadcast operation onto the computation.
Args:
- operand: the operand ComputationDataHandle to broadcast.
+ operand: the operand LocalOp to broadcast.
sizes: an iterable of broadcast sizes.
Returns:
- A ComputationDataHandle representing the added broadcast op.
+ A LocalOp representing the added broadcast op.
"""
- return _wrap_data_handle(
- self._client.Broadcast(_unwrap_data_handle(operand), sizes))
+ return self._client.Broadcast(operand, sizes)
def Concatenate(self, operands, dimension):
"""Enqueues a concatenate operation onto the computation.
@@ -675,10 +657,9 @@ class ComputationBuilder(object):
dimension: the dimension in which to perform the concatenation.
Returns:
- A ComputationDataHandle representing the added concatenate op.
+ A LocalOp representing the added concatenate op.
"""
- return _wrap_data_handle(
- self._client.ConcatInDim(_unwrap_data_handles(operands), dimension))
+ return self._client.ConcatInDim(operands, dimension)
def ConvertElementType(self, operand, new_element_type):
"""Enqueues an element type conversion operation onto the computation.
@@ -688,14 +669,12 @@ class ComputationBuilder(object):
new_element_type: the target primitive type.
Returns:
- A ComputationDataHandle representing the added conversion op.
+ A LocalOp representing the added conversion op.
"""
- return _wrap_data_handle(
- self._client.ConvertElementType(
- _unwrap_data_handle(operand), new_element_type))
+ return self._client.ConvertElementType(operand, new_element_type)
def GetShape(self, operand):
- return _wrap_shape(self._client.GetShape(_unwrap_data_handle(operand)))
+ return _wrap_shape(self._client.GetShape(operand))
def GetReturnValueShape(self):
return _wrap_shape(self._client.GetReturnValueShape())
@@ -707,40 +686,35 @@ class ComputationBuilder(object):
"""Enqueues a Pad operation onto the computation.
Args:
- operand: ComputationDataHandle representing the array to pad.
- padding_value: ComputationDataHandle representing the scalar pad value.
+ operand: LocalOp representing the array to pad.
+ padding_value: LocalOp representing the scalar pad value.
padding_config: either an xla_data_pb2.PaddingConfig or a list of integer
triples (edge_padding_low, edge_padding_high, interior_padding)
representing the configuration of the padding operation.
Returns:
- A ComputationDataHandle representing the added Pad op.
+ A LocalOp representing the added Pad op.
"""
if not isinstance(padding_config, xla_data_pb2.PaddingConfig):
padding_config = GetPaddingConfigFromTriples(padding_config)
- return _wrap_data_handle(
- self._client.Pad(_unwrap_data_handle(operand),
- _unwrap_data_handle(padding_value),
- padding_config))
+ return self._client.Pad(operand, padding_value, padding_config)
def Reshape(self, operand, dimensions, new_sizes):
"""Enqueues a reshape op onto the computation.
Args:
- operand: ComputationDataHandle representing the array to be reshaped.
+ operand: LocalOp representing the array to be reshaped.
dimensions: sequence of integers encoding the order in which dimensions
are collapsed or None, in which case dimensions are flattened in order.
new_sizes: sequence of integers encoding the new dimension sizes (shape).
Returns:
- A ComputationDataHandle representing the added Reshape op.
+ A LocalOp representing the added Reshape op.
"""
if dimensions is None:
ndim = len(self.GetShape(operand).dimensions())
dimensions = tuple(range(ndim))
- return _wrap_data_handle(
- self._client.Reshape(
- _unwrap_data_handle(operand), dimensions, new_sizes))
+ return self._client.Reshape(operand, dimensions, new_sizes)
def CrossReplicaSum(self, operand):
"""CrossReplicaSum op.
@@ -749,67 +723,56 @@ class ComputationBuilder(object):
operand: the operand to sum across replica instances.
Returns:
- A ComputationDataHandle that has the sum of the value among all replicas.
+ A LocalOp that has the sum of the value among all replicas.
"""
- return _wrap_data_handle(
- self._client.CrossReplicaSum(_unwrap_data_handle(operand)))
+ return self._client.CrossReplicaSum(operand)
def Collapse(self, operand, dimensions):
"""Collapse op."""
- return _wrap_data_handle(
- self._client.Collapse(_unwrap_data_handle(operand), dimensions))
+ return self._client.Collapse(operand, dimensions)
def Trans(self, operand):
"""Specialized matrix transpose op."""
- return _wrap_data_handle(
- self._client.Transpose(_unwrap_data_handle(operand), [1, 0]))
+ return self._client.Transpose(operand, [1, 0])
def Transpose(self, operand, permutation):
"""Transpose op."""
- return _wrap_data_handle(
- self._client.Transpose(_unwrap_data_handle(operand), permutation))
+ return self._client.Transpose(operand, permutation)
def Rev(self, operand, dimensions):
"""Rev op."""
- return _wrap_data_handle(
- self._client.Rev(_unwrap_data_handle(operand), dimensions))
+ return self._client.Rev(operand, dimensions)
def Clamp(self, min, operand, max): # pylint: disable=redefined-builtin
"""Clamp op."""
- return _wrap_data_handle(
- self._client.Clamp(_unwrap_data_handle(min),
- _unwrap_data_handle(operand),
- _unwrap_data_handle(max)))
+ return self._client.Clamp(min, operand, max)
def SelectAndScatter(self, operand, select, window_dimensions, window_strides,
padding, source, init_value, scatter):
"""Select and scatter op, used by the gradient of ReduceWindow.
Args:
- operand: ComputationDataHandle for array of dimension N and type T over
+ operand: LocalOp for array of dimension N and type T over
which the windows slide.
select: Computation of type (T, T) -> Pred to apply to the elements of
each window to indicate which element is selected.
window_dimensions: sequence of N integers for dimensions of the window.
window_strides: sequence of N integers for the strides of the window.
padding: PaddingType representing either 'SAME' or 'VALID ' padding.
- source: ComputationDataHandle for array of type T with values to scatter.
- init_value: ComputationDataHandle of scalar type T for initial out value.
+ source: LocalOp for array of type T with values to scatter.
+ init_value: LocalOp of scalar type T for initial out value.
scatter: Computation of type (T, T) -> T to apply to each scatter source
element with its destination element.
Returns:
- A ComputationDataHandle representing the added SelectAndScatter op.
+ A LocalOp representing the added SelectAndScatter op.
"""
pads = _convert_padding_type_to_pad_values(
padding, self.GetShape(operand).dimensions(),
window_dimensions, window_strides)
- return _wrap_data_handle(
- self._client.SelectAndScatterWithGeneralPadding(
- _unwrap_data_handle(operand), select.c_local_computation,
- window_dimensions, window_strides, pads,
- _unwrap_data_handle(source), _unwrap_data_handle(init_value),
- scatter.c_local_computation))
+ return self._client.SelectAndScatterWithGeneralPadding(
+ operand, select.c_local_computation, window_dimensions, window_strides,
+ pads, source, init_value, scatter.c_local_computation)
def Select(self, pred, on_true, on_false):
"""Element-wise selection op.
@@ -817,17 +780,13 @@ class ComputationBuilder(object):
Constructs an output array from elements of two input arrays, based on the
values of a predicate array.
"""
- return _wrap_data_handle(
- self._client.Select(
- _unwrap_data_handle(pred),
- _unwrap_data_handle(on_true),
- _unwrap_data_handle(on_false)))
+ return self._client.Select(pred, on_true, on_false)
def Slice(self, operand, start_indices, limit_indices, strides=None):
"""Enqueues a slice operation onto the computation.
Args:
- operand: ComputationDataHandle for the N dimensional array to be sliced.
+ operand: LocalOp for the N dimensional array to be sliced.
start_indices: iterable of N integers containing the starting indices of
the slice for each dimension.
limit_indices: iterable of N integers containing the ending indices
@@ -836,207 +795,177 @@ class ComputationBuilder(object):
each dimension.
Returns:
- A ComputationDataHandle representing the added Slice op.
+ A LocalOp representing the added Slice op.
"""
if strides is None:
start_indices = list(start_indices)
strides = [1] * len(start_indices)
- return _wrap_data_handle(
- self._client.Slice(
- _unwrap_data_handle(operand), start_indices, limit_indices,
- strides))
+ return self._client.Slice(operand, start_indices, limit_indices, strides)
def SliceInDim(self, operand, start_index, limit_index, stride, dimno):
"""Enqueues a slice-in-dimension operation onto the computation.
Args:
- operand: ComputationDataHandle for the N dimensional array to be sliced.
+ operand: LocalOp for the N dimensional array to be sliced.
start_index: an integer containing the start index of the slice.
limit_index: an integer containing the end index of the slice.
stride: an integer containing the stride size for the slice.
dimno: an integer indicating the dimension along which to slice.
Returns:
- A ComputationDataHandle representing the added Slice op.
+ A LocalOp representing the added Slice op.
"""
- return _wrap_data_handle(
- self._client.SliceInDim(
- _unwrap_data_handle(operand), start_index, limit_index, stride,
- dimno))
+ return self._client.SliceInDim(operand, start_index, limit_index, stride,
+ dimno)
def DynamicSlice(self, operand, start_indices, slice_sizes):
"""Enqueues a slice op with dynamic start indices onto the computation.
Args:
- operand: ComputationDataHandle for the N dimensional array to be sliced.
- start_indices: ComputationDataHandle for the 1D array of N integers
+ operand: LocalOp for the N dimensional array to be sliced.
+ start_indices: LocalOp for the 1D array of N integers
containing the starting indices of the slice.
slice_sizes: iterable of N integers containing the slice sizes in each
dimension.
Returns:
- A ComputationDataHandle representing the added DynamicSlice op.
+ A LocalOp representing the added DynamicSlice op.
"""
- return _wrap_data_handle(
- self._client.DynamicSlice(
- _unwrap_data_handle(operand),
- _unwrap_data_handle(start_indices),
- slice_sizes))
+ return self._client.DynamicSlice(operand, start_indices, slice_sizes)
def DynamicUpdateSlice(self, operand, update, start_indices):
"""Enqueues a dynamic update slice operation onto the computation.
Args:
- operand: ComputationDataHandle for the N dimensional array to be updated.
+ operand: LocalOp for the N dimensional array to be updated.
update: N dimensional array comprising the slice update.
start_indices: Rank-1 array of N integers comprising the starting indices
of the slice along each dimension.
Returns:
- A ComputationDataHandle representing the added DynamicUpdateSlice op.
+ A LocalOp representing the added DynamicUpdateSlice op.
"""
- return _wrap_data_handle(
- self._client.DynamicUpdateSlice(
- _unwrap_data_handle(operand),
- _unwrap_data_handle(update),
- _unwrap_data_handle(start_indices)))
+ return self._client.DynamicUpdateSlice(operand, update, start_indices)
def Tuple(self, *ops):
"""Enqueues a tuple operation onto the computation.
Args:
- ops: a sequence of tuple operands (each a ComputationDataHandle).
+ ops: a sequence of tuple operands (each a LocalOp).
Returns:
- A ComputationDataHandle representing the added Tuple op.
+ A LocalOp representing the added Tuple op.
"""
- return _wrap_data_handle(self._client.Tuple(_unwrap_data_handles(ops)))
+ return self._client.Tuple(ops)
def GetTupleElement(self, tup, index):
"""Enqueues a 'get tuple element' operation onto the computation.
Args:
- tup: the tuple operand (a ComputationDataHandle).
+ tup: the tuple operand (a LocalOp).
index: numeric index to select from the tuple.
Returns:
- A ComputationDataHandle representing the added GetTupleElement op.
+ A LocalOp representing the added GetTupleElement op.
"""
- return _wrap_data_handle(
- self._client.GetTupleElement(_unwrap_data_handle(tup), index))
+ return self._client.GetTupleElement(tup, index)
def Call(self, computation_to_apply, operands):
"""Enqueues a call operation onto the computation.
Args:
computation_to_apply: a Computation object.
- operands: an iterable of ComputationDataHandle. The number and types of
+ operands: an iterable of LocalOp. The number and types of
operands must match the arity of computation_to_apply.
Returns:
- A ComputationDataHandle representing the added call op.
+ A LocalOp representing the added call op.
"""
- return _wrap_data_handle(
- self._client.Call(computation_to_apply.c_local_computation,
- _unwrap_data_handles(operands)))
+ return self._client.Call(computation_to_apply.c_local_computation, operands)
def Map(self, operands, computation_to_apply, dimensions, static_operands=()):
"""Enqueues a map operation onto the computation.
Args:
- operands: an iterable of ComputationDataHandle.
+ operands: an iterable of LocalOp.
computation_to_apply: a Computation object.
dimensions: dimensions over which to apply map the function.
static_operands: auxiliary arguments passed to the applied computation.
Returns:
- A ComputationDataHandle representing the added Map op.
+ A LocalOp representing the added Map op.
"""
- return _wrap_data_handle(
- self._client.Map(
- _unwrap_data_handles(operands),
- computation_to_apply.c_local_computation,
- dimensions,
- _unwrap_data_handles(static_operands)))
+ return self._client.Map(operands, computation_to_apply.c_local_computation,
+ dimensions, static_operands)
def Reduce(self, operand, init_value, computation_to_apply, dimensions):
"""Enqueues a reduction operation onto the computation.
Args:
- operand: reduction operand (ComputationDataHandle).
- init_value: reduction initial value (ComputationDataHandle).
+ operand: reduction operand (LocalOp).
+ init_value: reduction initial value (LocalOp).
computation_to_apply: a Computation object - binary reduction function.
dimensions: sequence of dimensions (integers) to reduce on.
Returns:
- A ComputationDataHandle representing the added Reduce op.
+ A LocalOp representing the added Reduce op.
"""
- return _wrap_data_handle(
- self._client.Reduce(
- _unwrap_data_handle(operand),
- _unwrap_data_handle(init_value),
- computation_to_apply.c_local_computation,
- dimensions))
+ return self._client.Reduce(operand, init_value,
+ computation_to_apply.c_local_computation,
+ dimensions)
def ReduceWindow(self, operand, init_value, computation_to_apply,
window_dimensions, window_strides, padding):
"""Enqueues a windowed reduction operation onto the computation.
Args:
- operand: reduction operand (ComputationDataHandle).
- init_value: reduction initial value (ComputationDataHandle).
+ operand: reduction operand (LocalOp).
+ init_value: reduction initial value (LocalOp).
computation_to_apply: a binary reduction function (Computation).
window_dimensions: dimensions of window (sequence of integers).
window_strides: strides for window (sequence of integers).
padding: PaddingType representing either 'SAME' or 'VALID' padding.
Returns:
- A ComputationDataHandle representing the added ReduceWindow op.
+ A LocalOp representing the added ReduceWindow op.
"""
pads = _convert_padding_type_to_pad_values(
padding, self.GetShape(operand).dimensions(), window_dimensions,
window_strides)
- return _wrap_data_handle(
- self._client.ReduceWindowWithGeneralPadding(
- _unwrap_data_handle(operand),
- _unwrap_data_handle(init_value),
- computation_to_apply.c_local_computation,
- window_dimensions, window_strides, pads))
+ return self._client.ReduceWindowWithGeneralPadding(
+ operand, init_value, computation_to_apply.c_local_computation,
+ window_dimensions, window_strides, pads)
def RngNormal(self, mu, sigma, dims):
"""Enqueues an RngNormal operation onto the computation.
Args:
- mu: A ComputationDataHandle to an F32 scalar specifying the mean.
- sigma: A ComputationDataHandle to an F32 scalar specifying the standard
+ mu: A LocalOp to an F32 scalar specifying the mean.
+ sigma: A LocalOp to an F32 scalar specifying the standard
deviation.
dims: A 1D array-like of nonnegative integers specifying the dimensions.
- Returns: a ComputationDataHandle to the generated array of F32 values.
+ Returns: a LocalOp to the generated array of F32 values.
"""
shape = Shape.array_shape(self.GetShape(mu).element_type(), dims)
- return _wrap_data_handle(
- self._client.RngNormal(
- _unwrap_data_handle(mu), _unwrap_data_handle(sigma), shape))
+ return self._client.RngNormal(mu, sigma, shape)
def RngUniform(self, a, b, dims):
"""Enqueues an RngUniform operation onto the computation.
Args:
- a: a ComputationDataHandle to an F32, S32, or U32 scalar (consistent with
+ a: a LocalOp to an F32, S32, or U32 scalar (consistent with
the type of b) specifying the low end of the interval [a, b) over which
values are generated.
- b: a ComputationDataHandle to an F32, S32, or U32 scalar (consistent with
+ b: a LocalOp to an F32, S32, or U32 scalar (consistent with
the type of a) specifying the high end of the interval [a, b) over which
values are generated.
dims: A 1D array-like of nonnegative integers specifying the dimensions.
- Returns: a ComputationDataHandle to the generated array of values with the
+ Returns: a LocalOp to the generated array of values with the
same numeric type (F32, S32, or U32) as the arguments a and b.
"""
shape = Shape.array_shape(self.GetShape(a).element_type(), dims)
- return _wrap_data_handle(
- self._client.RngUniform(
- _unwrap_data_handle(a), _unwrap_data_handle(b), shape))
+ return self._client.RngUniform(a, b, shape)
def While(self, cond, body, init):
"""Enqueues a While operation onto the computation.
@@ -1044,112 +973,105 @@ class ComputationBuilder(object):
Args:
cond: a Computation for the loop condition, which has type T -> PRED
body: a Computation for the loop body, which has type T -> T
- init: a ComputationDataHandle for the initial parameter, which has type T
+ init: a LocalOp for the initial parameter, which has type T
- Returns: a ComputationDataHandle representing the While operation.
+ Returns: a LocalOp representing the While operation.
"""
- return _wrap_data_handle(
- self._client.While(cond.c_local_computation,
- body.c_local_computation,
- _unwrap_data_handle(init)))
+ return self._client.While(cond.c_local_computation,
+ body.c_local_computation, init)
def Conditional(self, pred, true_operand, true_computation, false_operand,
false_computation):
"""Enqueues a Conditional operation onto the computation.
Args:
- predicate: a ComputationDataHandle to test, which has scalar type PRED
- true_operand: a ComputationDataHandle of type T_0
+ pred: a LocalOp to test, which has scalar type PRED
+ true_operand: a LocalOp of type T_0
true_computation: a Computation to apply to true_operand, type T_0 -> S
false_operand: a ComputationDatahandle of type T_1
false_computation: a Computation to apply to false_operand, type T_1 -> S
- Returns: a ComputationDataHandle representing the Conditional operation.
+ Returns: a LocalOp representing the Conditional operation.
"""
- return _wrap_data_handle(
- self._client.Conditional(
- _unwrap_data_handle(pred), _unwrap_data_handle(true_operand),
- true_computation.c_local_computation,
- _unwrap_data_handle(false_operand),
- false_computation.c_local_computation))
+ return self._client.Conditional(
+ pred, true_operand, true_computation.c_local_computation, false_operand,
+ false_computation.c_local_computation)
- def IsConstant(self, operand, num_parameters=0):
- """Enqueues an IsConstant operation onto the computation.
+ def IsConstant(self, operand):
+ """Checks whether the given operand is a compile-time constant.
Args:
operand: a ComputationDataHandle to test.
- num_parameters: optional int, number of computation parameters to treat as
- constant (default 0).
Returns: bool indicating whether `operand` is a compile-time constant,
- meaning its value does not depend on parameters with index greater than or
- equal to `num_parameters`.
+ meaning its value does not depend on any parameters, or on stateful
+ operators such as `RngNormal` or `Infeed`.
+ """
+ return self._client.IsConstant(operand)
+
+ def BuildConstantSubGraph(self, operand):
+ """Builds a constant subgraph.
+
+ Args:
+ operand: a LocalOp; the root of the constant subgraph.
+ Returns: a LocalComputation rooted on the given `operand`, which must be a
+ compile-time constant.
"""
- return self._client.IsConstant(_unwrap_data_handle(operand), num_parameters)
+ return self._client.BuildConstantSubGraph(operand)
def Dot(self, lhs, rhs):
"""Enqueues a dot operation onto the computation.
Args:
- lhs: ComputationDataHandle for the rank 1 or rank 2 left-hand-side array.
- rhs: ComputationDataHandle for the rank 1 or rank 2 right-hand-side array.
+ lhs: LocalOp for the rank 1 or rank 2 left-hand-side array.
+ rhs: LocalOp for the rank 1 or rank 2 right-hand-side array.
- Returns: a ComputationDataHandle representing the Dot operation.
+ Returns: a LocalOp representing the Dot operation.
"""
- return _wrap_data_handle(
- self._client.Dot(_unwrap_data_handle(lhs), _unwrap_data_handle(rhs)))
+ return self._client.Dot(lhs, rhs)
def DotGeneral(self, lhs, rhs, dimension_numbers):
"""Enqueues a general dot operation onto the computation.
Args:
- lhs: ComputationDataHandle for the left-hand-side array.
- rhs: ComputationDataHandle for the right-hand-side array.
+ lhs: LocalOp for the left-hand-side array.
+ rhs: LocalOp for the right-hand-side array.
dimension_numbers: either an xla_data_pb2.DotDimensionNumbers or a nested
tuple ((lhs_contract, rhs_contract), (lhs_batch, rhs_batch)) of lists of
integers representing the dimensions to treat as contracting dimensions
and batch dimensions on each input operand.
- Returns: a ComputationDataHandle representing the DotGeneral operation.
+ Returns: a LocalOp representing the DotGeneral operation.
"""
if not isinstance(dimension_numbers, xla_data_pb2.DotDimensionNumbers):
dimension_numbers = GetDotDimensionsFromLists(dimension_numbers)
- return _wrap_data_handle(
- self._client.DotGeneral(
- _unwrap_data_handle(lhs), _unwrap_data_handle(rhs),
- dimension_numbers))
+ return self._client.DotGeneral(lhs, rhs, dimension_numbers)
def Conv(self, lhs, rhs, window_strides, padding):
"""Enqueues a Conv operation onto the computation.
Args:
- lhs: ComputationDataHandle for the rank N+2 array of inputs.
- rhs: ComputationDataHandle for the rank N+2 array of kernel weights.
+ lhs: LocalOp for the rank N+2 array of inputs.
+ rhs: LocalOp for the rank N+2 array of kernel weights.
window_strides: length-N array-like of integer kernel strides.
padding: PaddingType representing either 'SAME' or 'VALID' padding.
- Returns: a ComputationDataHandle representing the Conv operation.
+ Returns: a LocalOp representing the Conv operation.
"""
pads = _convert_padding_type_to_pad_values(
padding, self.GetShape(lhs).dimensions()[2:],
self.GetShape(rhs).dimensions()[2:], window_strides)
dimension_numbers = self._GetConvDimensionNumbers(len(window_strides))
- return _wrap_data_handle(
- self._client.ConvGeneralDilated(_unwrap_data_handle(lhs),
- _unwrap_data_handle(rhs),
- window_strides,
- pads,
- (),
- (),
- dimension_numbers))
+ return self._client.ConvGeneralDilated(lhs, rhs, window_strides, pads, (),
+ (), dimension_numbers)
def ConvWithGeneralPadding(self, lhs, rhs, window_strides, padding,
lhs_dilation, rhs_dilation):
"""Enqueues a ConvWithGeneralPadding operation onto the computation.
Args:
- lhs: ComputationDataHandle for the rank N+2 array of inputs.
- rhs: ComputationDataHandle for the rank N+2 array of kernel weights.
+ lhs: LocalOp for the rank N+2 array of inputs.
+ rhs: LocalOp for the rank N+2 array of kernel weights.
window_strides: length-N array-like of kernel strides.
padding: length-N array-like of pairs of integers of (low, high) padding.
lhs_dilation: length-N array-like of dilation factors.
@@ -1159,14 +1081,9 @@ class ComputationBuilder(object):
A ComputationdataHandle representing the added ConvWithGeneralPadding op.
"""
dimension_numbers = self._GetConvDimensionNumbers(len(window_strides))
- return _wrap_data_handle(
- self._client.ConvGeneralDilated(_unwrap_data_handle(lhs),
- _unwrap_data_handle(rhs),
- window_strides,
- padding,
- lhs_dilation,
- rhs_dilation,
- dimension_numbers))
+ return self._client.ConvGeneralDilated(lhs, rhs, window_strides, padding,
+ lhs_dilation, rhs_dilation,
+ dimension_numbers)
def _GetConvDimensionNumbers(self, num_spatial_dims):
"""Create ConvolutionDimensionNumbers proto for convolutions."""
@@ -1196,15 +1113,14 @@ def _forward_methods_to_local_builder():
"""Generate a forwarding method that wraps/unwraps data handles."""
def forward(self, *args, **kwargs):
- unwrapped_args = [_unwrap_data_handle(arg) for arg in args]
+ arg_list = list(args)
- if is_binop and len(unwrapped_args) < 3:
- unwrapped_args.append(kwargs.get('broadcast_dimensions', ()))
+ if is_binop and len(arg_list) < 3:
+ arg_list.append(kwargs.get('broadcast_dimensions', ()))
- return _wrap_data_handle(
- target_method(
- self._client, # pylint: disable=protected-access
- *unwrapped_args))
+ return target_method(
+ self._client, # pylint: disable=protected-access
+ *arg_list)
return forward
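
For orientation, a minimal sketch of what this xla_client.py rework looks like to a caller: builder methods now return opaque LocalOp objects and accept them directly as operands, with no _wrap_data_handle/_unwrap_data_handle round trip. Only methods that appear in this diff are used below; the builder constructor taking a name, the Build() call, and Mul being among the forwarded binary ops are assumptions about parts of the file not shown here.

# Illustrative sketch only; see the assumptions noted above.
import numpy as np
from tensorflow.compiler.xla.python import xla_client

builder = xla_client.ComputationBuilder("scale_and_dot")    # assumed constructor
x = builder.ParameterFromNumpy(
    np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32))   # returns a LocalOp
a = builder.ConstantF32Scalar(2.0)                          # returns a LocalOp
scaled = builder.Mul(a, x)       # assumed forwarded binary op; LocalOps in/out
result = builder.Dot(scaled, x)  # LocalOp operands are passed straight through
computation = builder.Build()    # assumed; wraps the underlying LocalComputation
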
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 9c362d8cad..aa3a6261e0 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -26,6 +26,7 @@ xla_proto_library(
xla_proto_library(
name = "hlo_proto",
srcs = ["hlo.proto"],
+ visibility = ["//visibility:public"],
deps = ["//tensorflow/compiler/xla:xla_data_proto"],
)
@@ -200,7 +201,22 @@ tf_cc_test(
cc_library(
name = "hlo_evaluator",
- srcs = ["hlo_evaluator.cc"],
+ srcs = [
+ "hlo_evaluator.cc",
+ "hlo_evaluator_typed_visitor.h",
+ "hlo_evaluator_typed_visitor_bfloat16.cc",
+ "hlo_evaluator_typed_visitor_bool.cc",
+ "hlo_evaluator_typed_visitor_complex64.cc",
+ "hlo_evaluator_typed_visitor_double.cc",
+ "hlo_evaluator_typed_visitor_float.cc",
+ "hlo_evaluator_typed_visitor_half.cc",
+ "hlo_evaluator_typed_visitor_int32.cc",
+ "hlo_evaluator_typed_visitor_int64.cc",
+ "hlo_evaluator_typed_visitor_int8.cc",
+ "hlo_evaluator_typed_visitor_uint32.cc",
+ "hlo_evaluator_typed_visitor_uint64.cc",
+ "hlo_evaluator_typed_visitor_uint8.cc",
+ ],
hdrs = ["hlo_evaluator.h"],
deps = [
":hlo",
@@ -370,6 +386,7 @@ tf_cc_test(
":hlo_matchers",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "//tensorflow/compiler/xla/tools/parser:hlo_parser",
],
)
@@ -2467,6 +2484,7 @@ tf_cc_test(
srcs = ["transpose_folding_test.cc"],
deps = [
":hlo",
+ ":hlo_matchers",
":shape_inference",
":transpose_folding",
"//tensorflow/compiler/xla:literal_util",
@@ -2478,6 +2496,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/service/gpu:ir_emission_utils",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "//tensorflow/compiler/xla/tools/parser:hlo_parser",
"//tensorflow/core:lib",
],
)
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index 8e785de68c..4ec79a0244 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -291,6 +291,8 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
const Shape& dot_shape, HloInstruction* lhs, int64 lhs_contracting_dim,
HloInstruction* rhs, int64 rhs_contracting_dim, bool swapped);
+ StatusOr<HloInstruction*> OptimizeDotOfGather(HloInstruction* dot);
+
// Current HloComputation instance the AlgebraicSimplifierVisitor is
// traversing.
HloComputation* computation_;
@@ -912,6 +914,134 @@ StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfConcatHelper(
return add_result;
}
+StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfGather(
+ HloInstruction* dot) {
+ const DotDimensionNumbers& dnums = dot->dot_dimension_numbers();
+ if (dnums.lhs_contracting_dimensions_size() != 1 ||
+ dnums.rhs_contracting_dimensions_size() != 1 ||
+ dnums.lhs_batch_dimensions_size() != 0 ||
+ dnums.rhs_batch_dimensions_size() != 0 ||
+ dot->shape().dimensions_size() != 2) { // dot output 2D
+ VLOG(10) << "DotOfGather: Can only optimize 2D, non-batch dot operations.";
+ return nullptr;
+ }
+
+ // Optimize either dot(DS(ctA), ctB) or dot(ctB, DS(ctA)).
+ // Currently a Gather is a DynamicSlice.
+ auto is_dynamic_slice_constant_combination =
+ [](HloInstruction* a, HloInstruction* b, int a_contracting_dimension) {
+ // First operand is a DynamicSlice(Constant).
+ if (a->opcode() != HloOpcode::kDynamicSlice) {
+ return false;
+ }
+ auto* dynamic_slice_op = a->operand(0);
+ if (dynamic_slice_op->opcode() != HloOpcode::kConstant) {
+ return false;
+ }
+ // Second operand is a Constant.
+ if (b->opcode() != HloOpcode::kConstant) {
+ return false;
+ }
+ // The DynamicSlice output is a vector.
+ const Shape& dynamic_slice_shape = a->shape();
+ if (dynamic_slice_shape.dimensions(1 - a_contracting_dimension) != 1) {
+ return false;
+ }
+ // Constant size is the same before and after slice in the contracting
+ // dimension, otherwise we either must precompute for all possible slice
+ // indices or dot is invalid.
+ const Shape& dynamic_slice_op_shape = dynamic_slice_op->shape();
+ if (dynamic_slice_op_shape.dimensions(a_contracting_dimension) !=
+ dynamic_slice_shape.dimensions(a_contracting_dimension)) {
+ return false;
+ }
+ return true;
+ };
+
+ HloInstruction* lhs = dot->mutable_operand(0);
+ HloInstruction* rhs = dot->mutable_operand(1);
+ int lhs_contracting_dimension = dnums.lhs_contracting_dimensions(0);
+ int rhs_contracting_dimension = dnums.rhs_contracting_dimensions(0);
+
+ if (!is_dynamic_slice_constant_combination(
+ lhs, rhs, /*a_contracting_dimension=*/lhs_contracting_dimension) &&
+ !is_dynamic_slice_constant_combination(
+ rhs, lhs, /*a_contracting_dimension=*/rhs_contracting_dimension)) {
+ VLOG(10) << "DotOfGather: Can only optimize dot(DS(ctA), ctB) or "
+ "dot(ctB, DS(ctA)), where the two constants have equal "
+ "contracting dimensions.";
+ return nullptr;
+ }
+
+ // LHS is DynamicSlice:
+ // input: dot(DS(ctA), ctB)
+ // where DS(ctA) = DS({M x K}, {start, 0}, {1, K}) and ctB = {K x N}.
+ // => input dimensions: dot({1 x K}, {K x N}) => {1 x N}.
+ // output: DS(dot(ctA, ctB))
+ // => output dimensions: DS ({M x N}, {start, 0}, {1, N}) => {1 x N}.
+
+ // RHS is DynamicSlice:
+ // input: dot(ctA, DS(ctB))
+ // where ctA = {M x K} and DS(ctB) = DS({K x N}, {0, start}, {K, 1}).
+ // => input dimensions: dot({M x K}, {K x 1}) => {M x 1}.
+ // output: DS(dot(ctA, ctB))
+ // => output dimensions: DS ({M x N}, {0, start}, {M, 1}) => {M x 1}.
+
+ bool lhs_is_dynamic_slice = lhs->opcode() == HloOpcode::kDynamicSlice;
+
+ // ctA:
+ HloInstruction* left_operand =
+ lhs_is_dynamic_slice ? lhs->mutable_operand(0) : lhs;
+ // ctB:
+ HloInstruction* right_operand =
+ lhs_is_dynamic_slice ? rhs : rhs->mutable_operand(0);
+ // Build ctA x ctB.
+ const int m = left_operand->shape().dimensions(1 - lhs_contracting_dimension);
+ const int n =
+ right_operand->shape().dimensions(1 - rhs_contracting_dimension);
+ auto memoized_shape = ShapeUtil::MakeShape(F32, {m, n});
+ auto* memoized_inst = computation_->AddInstruction(HloInstruction::CreateDot(
+ memoized_shape, left_operand, right_operand, dnums));
+ // Get pair {start, 0} or {0, start}.
+ HloInstruction* original_start_indices =
+ lhs_is_dynamic_slice ? lhs->mutable_operand(1) : rhs->mutable_operand(1);
+ // Position of start:
+ int index_of_non_zero_start = lhs_is_dynamic_slice
+ ? 1 - lhs_contracting_dimension
+ : 1 - rhs_contracting_dimension;
+ // Position of zero:
+ int index_of_zero_start = 1 - index_of_non_zero_start;
+
+ // Slice out start and 0 components and reorder if necessary.
+ auto indices_type = original_start_indices->shape().element_type();
+ Shape s_shape = ShapeUtil::MakeShape(indices_type, {1});
+ Shape d_shape = ShapeUtil::MakeShape(indices_type, {2});
+ HloInstruction* non_zero_start =
+ computation_->AddInstruction(HloInstruction::CreateSlice(
+ s_shape, original_start_indices, {index_of_non_zero_start},
+ {index_of_non_zero_start + 1}, {1}));
+ HloInstruction* zero_start =
+ computation_->AddInstruction(HloInstruction::CreateSlice(
+ s_shape, original_start_indices, {index_of_zero_start},
+ {index_of_zero_start + 1}, {1}));
+ HloInstruction* new_start_indices =
+ lhs_is_dynamic_slice
+ ? computation_->AddInstruction(HloInstruction::CreateConcatenate(
+ d_shape, {non_zero_start, zero_start}, 0))
+ : computation_->AddInstruction(HloInstruction::CreateConcatenate(
+ d_shape, {zero_start, non_zero_start}, 0));
+
+ // Build DynamicSlice(ctA x ctB).
+ const int new_slice_m = lhs_is_dynamic_slice ? 1 : m;
+ const int new_slice_n = lhs_is_dynamic_slice ? n : 1;
+ auto* memoized_lookup =
+ computation_->AddInstruction(HloInstruction::CreateDynamicSlice(
+ dot->shape(), memoized_inst, new_start_indices,
+ {new_slice_m, new_slice_n}));
+
+ return memoized_lookup;
+}
+
Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) {
HloInstruction *lhs, *rhs;
CHECK(Match(dot, m::Dot(m::Op(&lhs), m::Op(&rhs))));
@@ -941,6 +1071,17 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) {
return ReplaceInstruction(dot, dot_of_concat_optimized);
}
+ // Simplify dot(ConstA, Gather(Index, ConstB)) to:
+ // Gather(Index, dot*(ConstA, ConstB)), where dot* is an appropriately
+ // batched version of dot.
+ TF_ASSIGN_OR_RETURN(HloInstruction * dot_of_gather_optimized,
+ OptimizeDotOfGather(dot));
+ if (dot_of_gather_optimized) {
+ VLOG(10) << "Replaced dot(constA, gather(i, constB)) with "
+ "gather(i, dot*(constA, constB))";
+ return ReplaceInstruction(dot, dot_of_gather_optimized);
+ }
+
if (enable_dot_strength_reduction_ && !is_layout_sensitive_) {
TF_ASSIGN_OR_RETURN(bool did_strength_reduction,
HandleDotStrengthReduction(dot));
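
The rewrite above hinges on a simple identity for 2D, single-contracting-dimension dots: dynamic-slicing one constant operand before the dot produces the same values as performing the full constant-by-constant dot and slicing the result afterwards. A small NumPy check of that identity, purely illustrative and independent of the XLA classes in this file:

# NumPy illustration of the DotOfGather rewrite:
#   dot(DynamicSlice(ctA), ctB) == DynamicSlice(dot(ctA, ctB))
import numpy as np

m, k, n, start = 5, 4, 3, 2
ct_a = np.random.rand(m, k).astype(np.float32)   # stands in for constant ctA
ct_b = np.random.rand(k, n).astype(np.float32)   # stands in for constant ctB

# Slice one row of ctA, then multiply: {1 x K} dot {K x N} => {1 x N}.
sliced_then_dotted = ct_a[start:start + 1, :].dot(ct_b)
# Multiply the full constants, then slice the same row: {M x N} => {1 x N}.
dotted_then_sliced = ct_a.dot(ct_b)[start:start + 1, :]

assert np.allclose(sliced_then_dotted, dotted_then_sliced)
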
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
index d0c99bf818..4e082877c7 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -2963,5 +2963,208 @@ TEST_F(AlgebraicSimplifierTest, DynamicUpdateSliceZeroUpdate) {
INSTANTIATE_TEST_CASE_P(DotOfConcatSimplificationTestInstantiation,
DotOfConcatSimplificationTest,
::testing::ValuesIn(kDotOfConcatTestSpecs));
+
+struct DotOfGatherTestSpec {
+ int64 m;
+ int64 k;
+ int64 n;
+ int s; // start index for dynamic slice on the non-contracting dimension
+ int64 lcd; // left contracting dimension
+ int64 rcd; // right contracting dimension
+ bool neg; // is negative testcase
+};
+
+class DotOfGatherSimplificationTest
+ : public HloVerifiedTestBase,
+ public ::testing::WithParamInterface<DotOfGatherTestSpec> {};
+
+// input: dot(DS(ctA), ctB)
+// where DS(ctA) = DS({M x K}, {s, 0}, {1, K}) and ctB = {K x N}.
+// => input dimensions: dot({1 x K}, {K x N}) => {1 x N}.
+// output: DS(dot(ctA, ctB))
+// => output dimensions: DS ({M x N}, {s, 0}, {1, N}) => {1 x N}.
+TEST_P(DotOfGatherSimplificationTest, ConstantRHS) {
+ HloComputation::Builder builder(TestName());
+
+ DotOfGatherTestSpec spec = GetParam();
+
+ ASSERT_LE(spec.s, spec.m);
+
+ // For negative tests, increase k of the dynamic slice argument to prevent the
+ // optimization (constants ctA, ctB must have equal contracting dimensions).
+ int64 k_increase = spec.neg ? 5 : 0;
+ int64 lhs_rows = (spec.lcd == 0) ? (spec.k + k_increase) : spec.m;
+ int64 lhs_cols = (spec.lcd == 0) ? spec.m : (spec.k + k_increase);
+ Shape lhs_shape = ShapeUtil::MakeShape(F32, {lhs_rows, lhs_cols});
+ auto* lhs = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR2F32Linspace(
+ /*from=*/10.0, /*to=*/10000.0, /*rows=*/lhs_rows,
+ /*cols=*/lhs_cols)));
+
+ int32 start_row = (spec.lcd == 0) ? 0 : spec.s;
+ int32 start_col = (spec.lcd == 0) ? spec.s : 0;
+ const auto start_indices =
+ builder.AddInstruction(HloInstruction::CreateConstant(
+ Literal::CreateR1<int32>({start_row, start_col})));
+ int64 slice_row_size = (spec.lcd == 0) ? spec.k : 1;
+ int64 slice_col_size = (spec.lcd == 0) ? 1 : spec.k;
+ Shape ds_shape = ShapeUtil::MakeShape(F32, {slice_row_size, slice_col_size});
+ auto* ds = builder.AddInstruction(HloInstruction::CreateDynamicSlice(
+ ds_shape, lhs, start_indices, {slice_row_size, slice_col_size}));
+
+ int64 rhs_rows = (spec.rcd == 0) ? spec.k : spec.n;
+ int64 rhs_cols = (spec.rcd == 0) ? spec.n : spec.k;
+ Shape rhs_shape = ShapeUtil::MakeShape(F32, {rhs_rows, rhs_cols});
+ auto* rhs = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR2F32Linspace(
+ /*from=*/10.0, /*to=*/10000.0, /*rows=*/rhs_rows,
+ /*cols=*/rhs_cols)));
+
+ DotDimensionNumbers dot_dnums;
+ dot_dnums.add_lhs_contracting_dimensions(spec.lcd);
+ dot_dnums.add_rhs_contracting_dimensions(spec.rcd);
+
+ int64 dot_row_size = 1;
+ int64 dot_col_size = spec.n;
+ Shape dot_shape = ShapeUtil::MakeShape(F32, {dot_row_size, dot_col_size});
+ builder.AddInstruction(
+ HloInstruction::CreateDot(dot_shape, ds, rhs, dot_dnums));
+
+ auto computation = module().AddEntryComputation(builder.Build());
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(&module()));
+ ASSERT_TRUE(run_successful);
+ EXPECT_TRUE(
+ ShapeUtil::Equal(computation->root_instruction()->shape(), dot_shape));
+
+ if (spec.neg) {
+ EXPECT_NE(computation->root_instruction()->opcode(),
+ HloOpcode::kDynamicSlice);
+ } else {
+ EXPECT_THAT(computation->root_instruction(),
+ op::DynamicSlice(op::Dot(op::Constant(), op::Constant()),
+ op::Concatenate()));
+ }
+}
+
+// input: dot(ctA, DS(ctB))
+// where ctA = {M x K} and DS(ctB) = DS({K x N}, {0, s}, {K, 1}).
+// => input dimensions: dot({M x K}, {K x 1}) => {M x 1}.
+// output: DS(dot(ctA, ctB))
+// => output dimensions: DS ({M x N}, {0, s}, {M, 1}) => {M x 1}.
+TEST_P(DotOfGatherSimplificationTest, ConstantLHS) {
+ HloComputation::Builder builder(TestName());
+
+ DotOfGatherTestSpec spec = GetParam();
+
+ ASSERT_LE(spec.s, spec.n);
+
+ int64 lhs_rows = (spec.lcd == 0) ? spec.k : spec.m;
+ int64 lhs_cols = (spec.lcd == 0) ? spec.m : spec.k;
+ Shape lhs_shape = ShapeUtil::MakeShape(F32, {lhs_rows, lhs_cols});
+ auto* lhs = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR2F32Linspace(
+ /*from=*/10.0, /*to=*/10000.0, /*rows=*/lhs_rows,
+ /*cols=*/lhs_cols)));
+
+ // For negative tests, increase k of the dynamic slice argument to prevent the
+ // optimization (constants ctA, ctB must have equal contracting dimensions).
+ int64 k_increase = spec.neg ? 5 : 0;
+ int64 rhs_rows = (spec.rcd == 0) ? (spec.k + k_increase) : spec.n;
+ int64 rhs_cols = (spec.rcd == 0) ? spec.n : (spec.k + k_increase);
+ Shape rhs_shape = ShapeUtil::MakeShape(F32, {rhs_rows, rhs_cols});
+ auto* rhs = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR2F32Linspace(
+ /*from=*/10.0, /*to=*/10000.0, /*rows=*/rhs_rows,
+ /*cols=*/rhs_cols)));
+
+ int32 start_row = (spec.rcd == 0) ? 0 : spec.s;
+ int32 start_col = (spec.rcd == 0) ? spec.s : 0;
+ const auto start_indices =
+ builder.AddInstruction(HloInstruction::CreateConstant(
+ Literal::CreateR1<int32>({start_row, start_col})));
+ int64 slice_row_size = (spec.rcd == 0) ? spec.k : 1;
+ int64 slice_col_size = (spec.rcd == 0) ? 1 : spec.k;
+ Shape ds_shape = ShapeUtil::MakeShape(F32, {slice_row_size, slice_col_size});
+ auto* ds = builder.AddInstruction(HloInstruction::CreateDynamicSlice(
+ ds_shape, rhs, start_indices, {slice_row_size, slice_col_size}));
+
+ DotDimensionNumbers dot_dnums;
+ dot_dnums.add_lhs_contracting_dimensions(spec.lcd);
+ dot_dnums.add_rhs_contracting_dimensions(spec.rcd);
+
+ int64 dot_row_size = spec.m;
+ int64 dot_col_size = 1;
+ Shape dot_shape = ShapeUtil::MakeShape(F32, {dot_row_size, dot_col_size});
+ builder.AddInstruction(
+ HloInstruction::CreateDot(dot_shape, lhs, ds, dot_dnums));
+
+ auto computation = module().AddEntryComputation(builder.Build());
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(&module()));
+ ASSERT_TRUE(run_successful);
+ EXPECT_TRUE(
+ ShapeUtil::Equal(computation->root_instruction()->shape(), dot_shape));
+
+ if (spec.neg) {
+ EXPECT_NE(computation->root_instruction()->opcode(),
+ HloOpcode::kDynamicSlice);
+ } else {
+ EXPECT_THAT(computation->root_instruction(),
+ op::DynamicSlice(op::Dot(op::Constant(), op::Constant()),
+ op::Concatenate()));
+ }
+}
+
+std::vector<DotOfGatherTestSpec> DotOfGatherPositiveNegativeTests() {
+ std::vector<DotOfGatherTestSpec> positives = {
+ // "Classical dot", i.e. matrix multiply:
+ {/*m=*/10, /*k=*/10, /*n=*/5, /*s=*/0, /*lcd=*/1, /*rcd=*/0,
+ /*neg=*/false},
+ {/*m=*/20, /*k=*/20, /*n=*/3, /*s=*/2, /*lcd=*/1, /*rcd=*/0,
+ /*neg=*/false},
+ {/*m=*/10, /*k=*/3, /*n=*/10, /*s=*/9, /*lcd=*/1, /*rcd=*/0,
+ /*neg=*/false},
+ // Note: testing for m=1 and n=1 is unnecessary, as this optimizes to
+ // dot(ct, ct) before DotOfGather optimization kicks in.
+ // Contract on rows:
+ {/*m=*/10, /*k=*/10, /*n=*/5, /*s=*/0, /*lcd=*/0, /*rcd=*/0,
+ /*neg=*/false},
+ {/*m=*/20, /*k=*/20, /*n=*/3, /*s=*/2, /*lcd=*/0, /*rcd=*/0,
+ /*neg=*/false},
+ {/*m=*/10, /*k=*/3, /*n=*/10, /*s=*/9, /*lcd=*/0, /*rcd=*/0,
+ /*neg=*/false},
+ // Reverse matrix multiply:
+ {/*m=*/10, /*k=*/10, /*n=*/5, /*s=*/0, /*lcd=*/0, /*rcd=*/1,
+ /*neg=*/false},
+ {/*m=*/20, /*k=*/20, /*n=*/3, /*s=*/2, /*lcd=*/0, /*rcd=*/1,
+ /*neg=*/false},
+ {/*m=*/10, /*k=*/3, /*n=*/10, /*s=*/9, /*lcd=*/0, /*rcd=*/1,
+ /*neg=*/false},
+ // Contract on columns:
+ {/*m=*/10, /*k=*/10, /*n=*/5, /*s=*/0, /*lcd=*/1, /*rcd=*/1,
+ /*neg=*/false},
+ {/*m=*/20, /*k=*/20, /*n=*/3, /*s=*/2, /*lcd=*/1, /*rcd=*/1,
+ /*neg=*/false},
+ {/*m=*/10, /*k=*/3, /*n=*/10, /*s=*/9, /*lcd=*/1, /*rcd=*/1,
+ /*neg=*/false},
+ };
+ std::vector<DotOfGatherTestSpec> all;
+ for (int i = 0; i < positives.size(); i++) {
+ DotOfGatherTestSpec positive_test = positives[i];
+ all.push_back(positive_test);
+ DotOfGatherTestSpec negative_test = positive_test;
+ negative_test.neg = true;
+ all.push_back(negative_test);
+ }
+ return all;
+}
+
+INSTANTIATE_TEST_CASE_P(
+ DotOfGatherSimplificationTestInstantiation, DotOfGatherSimplificationTest,
+ ::testing::ValuesIn(DotOfGatherPositiveNegativeTests()));
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index 91ed6e427a..3d2e24ca14 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -535,7 +535,8 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
// and reduced memory usage (as compared to using DependencyHloOrdering).
TF_ASSIGN_OR_RETURN(
SequentialHloOrdering::HloModuleSequence module_sequence,
- CreateMemoryMinimizingSequence(*module, BufferSizeBytesFunction()));
+ CreateMemoryMinimizingSequence(*module, BufferSizeBytesFunction(),
+ DFSMemoryScheduler));
// Run buffer analysis on the HLO graph. This analysis figures out which
// temporary buffers are required to run the computation.
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
index a98e85a151..46fe060817 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
@@ -158,37 +158,95 @@ TEST_F(InstructionFusionTest, DotOperationFusion_ElementReuse) {
EXPECT_EQ(dot, computation->root_instruction());
}
-TEST_F(InstructionFusionTest, DotOperationFusion_TransposeFusion) {
- HloComputation::Builder builder(TestName());
- HloInstruction* arg0 = builder.AddInstruction(HloInstruction::CreateParameter(
- 0, ShapeUtil::MakeShape(F32, {1, 256}), "arg0"));
- HloInstruction* arg1 = builder.AddInstruction(HloInstruction::CreateParameter(
- 1, ShapeUtil::MakeShape(F32, {1024, 256}), "arg1"));
+TEST_F(InstructionFusionTest, DotOperationFusion_TransposeFusion_RHS) {
+ string hlo_string = R"(
+HloModule DotOperationFusion_TransposeFusion
- HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary(
- ShapeUtil::MakeShape(S32, {1024, 256}), HloOpcode::kExp, arg1));
- HloInstruction* transpose1 =
- builder.AddInstruction(HloInstruction::CreateTranspose(
- ShapeUtil::MakeShape(S32, {256, 1024}), exp1, {1, 0}));
- builder.AddInstruction(
- MakeDot(ShapeUtil::MakeShape(F32, {1, 1024}), arg0, transpose1));
+ENTRY DotOperationFusion_TransposeFusion {
+ arg0 = f32[1,256] parameter(0)
+ arg1 = f32[1024,256] parameter(1)
+ exponential = s32[1024,256] exponential(arg1)
+ transpose = s32[256,1024] transpose(exponential), dimensions={1,0}
+ ROOT dot = f32[1,1024] dot(arg0, transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ tools::Parse(hlo_string));
+ HloComputation* computation = module->entry_computation();
- auto module = CreateNewModule();
- auto computation = module->AddEntryComputation(builder.Build());
TransposeFolding transpose_folding(
[](const HloInstruction& dot,
const TransposeFolding::OperandIndices& candidate_operands) {
return candidate_operands;
},
TransposeFolding::NeverFoldTranspose);
- EXPECT_TRUE(transpose_folding.Run(module.get()).ValueOrDie());
- EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kFusion);
- EXPECT_EQ(computation->root_instruction()->fusion_kind(),
- HloInstruction::FusionKind::kTransposeDot);
- EXPECT_FALSE(CpuInstructionFusion().Run(module.get()).ValueOrDie());
- EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kFusion);
- EXPECT_EQ(computation->root_instruction()->fusion_kind(),
- HloInstruction::FusionKind::kTransposeDot);
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, transpose_folding.Run(module.get()));
+ ASSERT_TRUE(changed);
+ ASSERT_THAT(computation->root_instruction(),
+ op::Dot(op::Parameter(0), op::Exp(op::Parameter(1)),
+ /*lhs_contracting_dim=*/1, /*rhs_contracting_dim=*/1));
+}
+
+TEST_F(InstructionFusionTest, DotOperationFusion_TransposeFusion_LHS) {
+ string hlo_string = R"(
+HloModule DotOperationFusion_TransposeFusion
+
+ENTRY DotOperationFusion_TransposeFusion {
+ arg0 = f32[256,1] parameter(0)
+ arg1 = f32[256,1024] parameter(1)
+ transpose = s32[1,256] transpose(arg0), dimensions={1,0}
+ exponential = s32[256,1024] exponential(arg1)
+ ROOT dot = f32[1,1024] dot(transpose, exponential), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ tools::Parse(hlo_string));
+ HloComputation* computation = module->entry_computation();
+
+ TransposeFolding transpose_folding(
+ [](const HloInstruction& dot,
+ const TransposeFolding::OperandIndices& candidate_operands) {
+ return candidate_operands;
+ },
+ TransposeFolding::NeverFoldTranspose);
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, transpose_folding.Run(module.get()));
+ ASSERT_TRUE(changed);
+ ASSERT_THAT(computation->root_instruction(),
+ op::Dot(op::Parameter(0), op::Exp(op::Parameter(1)),
+ /*lhs_contracting_dim=*/0, /*rhs_contracting_dim=*/0));
+}
+
+TEST_F(InstructionFusionTest,
+ DotOperationFusion_TransposeFusion_LHS_NonDefault) {
+ string hlo_string = R"(
+HloModule DotOperationFusion_TransposeFusion
+
+ENTRY DotOperationFusion_TransposeFusion {
+ arg0 = f32[1,256] parameter(0)
+ arg1 = f32[256,1024] parameter(1)
+ transpose = s32[256,1] transpose(arg0), dimensions={1,0}
+ exponential = s32[256,1024] exponential(arg1)
+ ROOT dot = f32[1,1024] dot(transpose, exponential), lhs_contracting_dims={0}, rhs_contracting_dims={0}
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ tools::Parse(hlo_string));
+ HloComputation* computation = module->entry_computation();
+
+ TransposeFolding transpose_folding(
+ [](const HloInstruction& dot,
+ const TransposeFolding::OperandIndices& candidate_operands) {
+ return candidate_operands;
+ },
+ TransposeFolding::NeverFoldTranspose);
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, transpose_folding.Run(module.get()));
+ ASSERT_TRUE(changed);
+ ASSERT_THAT(computation->root_instruction(),
+ op::Dot(op::Parameter(0), op::Exp(op::Parameter(1)),
+ /*lhs_contracting_dim=*/1, /*rhs_contracting_dim=*/0));
}
class OpcodeFusionTest : public InstructionFusionTest {
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc
index e8117377e6..6c642080c3 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc
@@ -139,13 +139,9 @@ Status CpuLayoutAssignment::AddBackendConstraints(
Shape lhs_shape(RowMajorShape(lhs_instruction->shape()));
TF_RETURN_IF_ERROR(constraints->SetOperandLayout(lhs_shape, dot, 0));
- // dot is a kDot or a kTransposeDot fusion node. In the latter case, if
- // it represents X @ X, it may have just one operand.
- if (dot->operand_count() > 1) {
- const HloInstruction* rhs_instruction = dot->operand(1);
- Shape rhs_shape(RowMajorShape(rhs_instruction->shape()));
- TF_RETURN_IF_ERROR(constraints->SetOperandLayout(rhs_shape, dot, 1));
- }
+ const HloInstruction* rhs_instruction = dot->operand(1);
+ Shape rhs_shape(RowMajorShape(rhs_instruction->shape()));
+ TF_RETURN_IF_ERROR(constraints->SetOperandLayout(rhs_shape, dot, 1));
// Set layouts of the instructions' shapes.
TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(output_shape, dot));
diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
index 801c523908..8db4a0650d 100644
--- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
@@ -522,16 +522,16 @@ void RowMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue(
} // namespace
-DotOpEmitter::DotOpEmitter(
- const HloInstruction& dot, bool transpose_lhs, bool transpose_rhs,
- const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& lhs_array,
- const llvm_ir::IrArray& rhs_array, const llvm_ir::IrArray* addend_array,
- llvm::Value* executable_run_options_value, llvm::IRBuilder<>* ir_builder,
- const HloModuleConfig& hlo_module_config,
- const TargetMachineFeatures& target_machine_features)
+DotOpEmitter::DotOpEmitter(const HloInstruction& dot,
+ const llvm_ir::IrArray& target_array,
+ const llvm_ir::IrArray& lhs_array,
+ const llvm_ir::IrArray& rhs_array,
+ const llvm_ir::IrArray* addend_array,
+ llvm::Value* executable_run_options_value,
+ llvm::IRBuilder<>* ir_builder,
+ const HloModuleConfig& hlo_module_config,
+ const TargetMachineFeatures& target_machine_features)
: dot_(dot),
- transpose_lhs_(transpose_lhs),
- transpose_rhs_(transpose_rhs),
target_array_(target_array),
lhs_array_(lhs_array),
rhs_array_(rhs_array),
@@ -542,18 +542,18 @@ DotOpEmitter::DotOpEmitter(
target_machine_features_(target_machine_features) {}
/* static */ tensorflow::Status DotOpEmitter::EmitDotOperation(
- const HloInstruction& dot, bool transpose_lhs, bool transpose_rhs,
- const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& lhs_array,
- const llvm_ir::IrArray& rhs_array, const llvm_ir::IrArray* addend_array,
+ const HloInstruction& dot, const llvm_ir::IrArray& target_array,
+ const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array,
+ const llvm_ir::IrArray* addend_array,
llvm::Value* executable_run_options_value, llvm::IRBuilder<>* ir_builder,
const HloModuleConfig& hlo_module_config,
const TargetMachineFeatures& target_machine_features) {
PrimitiveType type = target_array.GetShape().element_type();
TF_RET_CHECK(F16 == type || F32 == type || F64 == type || C64 == type);
- DotOpEmitter dot_emitter(dot, transpose_lhs, transpose_rhs, target_array,
- lhs_array, rhs_array, addend_array,
- executable_run_options_value, ir_builder,
- hlo_module_config, target_machine_features);
+ DotOpEmitter dot_emitter(dot, target_array, lhs_array, rhs_array,
+ addend_array, executable_run_options_value,
+ ir_builder, hlo_module_config,
+ target_machine_features);
return dot_emitter.Emit();
}
@@ -578,7 +578,7 @@ bool DotOpEmitter::EmitLlvmIrDotIfProfitable() {
if (mat_mult_dims.m == 1) {
bool rhs_effectively_row_major =
- transpose_rhs_ ^ !mat_mult_dims.rhs_column_major;
+ mat_mult_dims.rhs_non_canonical ^ !mat_mult_dims.rhs_column_major;
if (rhs_effectively_row_major) {
k = mat_mult_dims.k;
m = mat_mult_dims.n;
@@ -594,7 +594,7 @@ bool DotOpEmitter::EmitLlvmIrDotIfProfitable() {
if (mat_mult_dims.n == 1) {
bool lhs_effectively_column_major =
- transpose_lhs_ ^ mat_mult_dims.lhs_column_major;
+ mat_mult_dims.lhs_non_canonical ^ mat_mult_dims.lhs_column_major;
if (lhs_effectively_column_major) {
m = mat_mult_dims.m;
k = mat_mult_dims.k;
@@ -741,16 +741,10 @@ tensorflow::Status DotOpEmitter::Emit() {
// Reduce along dimension 0 of the LHS and 1 of the RHS. Vectors are a special
// case where the reduction dimension is 0 for both LHS and RHS. This results
// in a vector dot product producing a scalar.
- int64 lhs_reduction_dimension = 0;
- if (ShapeUtil::Rank(lhs_shape) >= 2) {
- lhs_reduction_dimension =
- ShapeUtil::GetDimensionNumber(lhs_shape, transpose_lhs_ ? -2 : -1);
- }
- int64 rhs_reduction_dimension = 0;
- if (ShapeUtil::Rank(rhs_shape) >= 2) {
- rhs_reduction_dimension =
- ShapeUtil::GetDimensionNumber(rhs_shape, transpose_rhs_ ? -1 : -2);
- }
+ int64 lhs_reduction_dimension =
+ dot_.dot_dimension_numbers().lhs_contracting_dimensions(0);
+ int64 rhs_reduction_dimension =
+ dot_.dot_dimension_numbers().rhs_contracting_dimensions(0);
// Verify the reduction dimension in the two operands are the same size.
TF_RET_CHECK(lhs_shape.dimensions(lhs_reduction_dimension) ==
@@ -986,8 +980,8 @@ tensorflow::Status DotOpEmitter::EmitCallToRuntime() {
const llvm_ir::IrArray* lhs = &lhs_array_;
const llvm_ir::IrArray* rhs = &rhs_array_;
- bool transpose_lhs = transpose_lhs_;
- bool transpose_rhs = transpose_rhs_;
+ bool transpose_lhs = mat_mult_dims.lhs_non_canonical;
+ bool transpose_rhs = mat_mult_dims.rhs_non_canonical;
if (!mat_mult_dims.lhs_column_major) {
std::swap(mat_mult_dims.m, mat_mult_dims.n);
@@ -1015,12 +1009,16 @@ DotOpEmitter::MatMultDims DotOpEmitter::GetMatMultDims() const {
const Shape& lhs_shape = lhs_array_.GetShape();
const Shape& rhs_shape = rhs_array_.GetShape();
-
- return {lhs_shape.dimensions(transpose_lhs_ ? 1 : 0),
- lhs_shape.dimensions(transpose_lhs_ ? 0 : 1),
- rhs_shape.dimensions(transpose_rhs_ ? 0 : 1),
- LayoutUtil::Minor(lhs_shape.layout(), 0) == 0,
- LayoutUtil::Minor(rhs_shape.layout(), 0) == 0};
+ const DotDimensionNumbers& dim_nums = dot_.dot_dimension_numbers();
+
+ return {
+ /*m=*/lhs_shape.dimensions(1 - dim_nums.lhs_contracting_dimensions(0)),
+ /*k=*/lhs_shape.dimensions(dim_nums.lhs_contracting_dimensions(0)),
+ /*n=*/rhs_shape.dimensions(1 - dim_nums.rhs_contracting_dimensions(0)),
+ /*lhs_column_major=*/LayoutUtil::Minor(lhs_shape.layout(), 0) == 0,
+ /*lhs_non_canonical=*/dim_nums.lhs_contracting_dimensions(0) == 0,
+ /*rhs_column_major=*/LayoutUtil::Minor(rhs_shape.layout(), 0) == 0,
+ /*rhs_non_canonical=*/dim_nums.rhs_contracting_dimensions(0) == 1};
}
llvm_ir::IrArray::Index DotOpEmitter::EmitOperandArrayLoopNest(
@@ -1090,27 +1088,16 @@ bool PotentiallyImplementedAsEigenDot(const HloInstruction& hlo) {
// If gemm can accept the operand shapes, use it rather than a custom
// kernel.
if (AreValidGemmShapes(lhs_shape, rhs_shape, hlo.shape())) {
+ const DotDimensionNumbers& dim_numbers = hlo.dot_dimension_numbers();
// The size of the reduction dimension should match. The shape inference
// guarantees this invariant, so the check here is for programming
// errors.
- CHECK_EQ(lhs_shape.dimensions(1), rhs_shape.dimensions(0));
+ CHECK_EQ(lhs_shape.dimensions(dim_numbers.lhs_contracting_dimensions(0)),
+ rhs_shape.dimensions(dim_numbers.rhs_contracting_dimensions(0)));
return true;
}
}
- if (hlo.opcode() == HloOpcode::kFusion &&
- hlo.fusion_kind() == HloInstruction::FusionKind::kTransposeDot &&
- hlo.fused_expression_root()->opcode() == HloOpcode::kDot) {
- auto* dot = hlo.fused_expression_root();
- const Shape& lhs_shape = dot->operand(0)->shape();
- const Shape& rhs_shape = dot->operand(1)->shape();
- if (ShapeUtil::HasZeroElements(lhs_shape) ||
- ShapeUtil::HasZeroElements(rhs_shape)) {
- return false;
- }
- return true;
- }
-
return false;
}
diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h
index 47e0924334..a20bf2f9db 100644
--- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h
@@ -56,16 +56,15 @@ class DotOpEmitter {
// dot(`lhs_array`, `rhs_array`). A non-null `addend_array` is only supported
// for Matrix-vector products.
static tensorflow::Status EmitDotOperation(
- const HloInstruction& dot, bool transpose_lhs, bool transpose_rhs,
- const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& lhs_array,
- const llvm_ir::IrArray& rhs_array, const llvm_ir::IrArray* addend_array,
+ const HloInstruction& dot, const llvm_ir::IrArray& target_array,
+ const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array,
+ const llvm_ir::IrArray* addend_array,
llvm::Value* executable_run_options_value, llvm::IRBuilder<>* ir_builder,
const HloModuleConfig& hlo_module_config,
const TargetMachineFeatures& target_machine_features);
private:
- DotOpEmitter(const HloInstruction& dot, bool transpose_lhs,
- bool transpose_rhs, const llvm_ir::IrArray& target_array,
+ DotOpEmitter(const HloInstruction& dot, const llvm_ir::IrArray& target_array,
const llvm_ir::IrArray& lhs_array,
const llvm_ir::IrArray& rhs_array,
const llvm_ir::IrArray* addend_array,
@@ -114,8 +113,14 @@ class DotOpEmitter {
    // True if the LHS matrix is column major.
bool lhs_column_major;
+ // True if the LHS contraction dimension is not 1.
+ bool lhs_non_canonical;
+
    // True if the RHS matrix is column major.
bool rhs_column_major;
+
+ // True if the RHS contraction dimension is not 0.
+ bool rhs_non_canonical;
};
// Get the MatMultDims instance for the dot product this DotOpEmitter
@@ -132,8 +137,6 @@ class DotOpEmitter {
}
const HloInstruction& dot_;
- const bool transpose_lhs_;
- const bool transpose_rhs_;
const llvm_ir::IrArray& target_array_;
const llvm_ir::IrArray& lhs_array_;
const llvm_ir::IrArray& rhs_array_;
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index 6347ee2a2a..55e5aa5063 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -827,13 +827,6 @@ Status IrEmitter::HandleDot(HloInstruction* dot) {
"Dot with multiple contracting dimensions not implemented.");
}
- if (dnums.lhs_contracting_dimensions(0) !=
- std::min(lhs->shape().dimensions_size() - 1, 1) ||
- dnums.rhs_contracting_dimensions(0) != 0) {
- return Unimplemented(
- "Dot with non-standard contracting dimensions not implemented.");
- }
-
llvm_ir::IrArray lhs_array(GetIrArrayFor(lhs));
llvm_ir::IrArray rhs_array(GetIrArrayFor(rhs));
@@ -850,8 +843,7 @@ Status IrEmitter::HandleDot(HloInstruction* dot) {
// Dot operation is complicated so we delegate to a helper class.
return DotOpEmitter::EmitDotOperation(
- *dot, /*transpose_lhs=*/false, /*transpose_rhs=*/false, target_array,
- lhs_array, rhs_array, /*addend_array=*/nullptr,
+ *dot, target_array, lhs_array, rhs_array, /*addend_array=*/nullptr,
GetExecutableRunOptionsArgument(), &ir_builder_, hlo_module_config_,
target_machine_features_);
}
@@ -2086,44 +2078,7 @@ static const HloInstruction* StripTranspose(const HloInstruction& hlo) {
Status IrEmitter::HandleFusion(HloInstruction* fusion) {
auto* root = fusion->fused_expression_root();
- if (fusion->fusion_kind() == HloInstruction::FusionKind::kTransposeDot) {
- DCHECK(root->opcode() == HloOpcode::kDot);
- const HloInstruction* lhs_parameter = StripTranspose(*root->operand(0));
- const HloInstruction* rhs_parameter = StripTranspose(*root->operand(1));
- DCHECK(lhs_parameter->opcode() == HloOpcode::kParameter &&
- rhs_parameter->opcode() == HloOpcode::kParameter);
- const HloInstruction* lhs =
- fusion->operand(lhs_parameter->parameter_number());
- const HloInstruction* rhs =
- fusion->operand(rhs_parameter->parameter_number());
-
- TF_RETURN_IF_ERROR(ElementTypesSameAndSupported(
- /*instruction=*/*root, /*operands=*/{lhs, rhs},
- /*supported_types=*/{F16, F32, F64}));
-
- llvm_ir::IrArray lhs_array(GetIrArrayFor(lhs));
- llvm_ir::IrArray rhs_array(GetIrArrayFor(rhs));
-
- Shape target_shape = fusion->shape();
- TF_RETURN_IF_ERROR(EmitTargetAddressForOp(fusion));
- llvm_ir::IrArray target_array = GetIrArrayFor(fusion);
- VLOG(2) << "HandleFusion kTransposeDot: ";
- VLOG(2) << " lhs operand: "
- << llvm_ir::DumpToString(*lhs_array.GetBasePointer());
- VLOG(2) << " rhs operand: "
- << llvm_ir::DumpToString(*rhs_array.GetBasePointer());
- VLOG(2) << " target: "
- << llvm_ir::DumpToString(*target_array.GetBasePointer());
-
- // Dot operation is complicated so we delegate to a helper class.
- TF_RETURN_IF_ERROR(DotOpEmitter::EmitDotOperation(
- *root, root->operand(0)->IsRank2Transpose(),
- root->operand(1)->IsRank2Transpose(), target_array, lhs_array,
- rhs_array, /*addend_array=*/nullptr, GetExecutableRunOptionsArgument(),
- &ir_builder_, hlo_module_config_, target_machine_features_));
- return Status::OK();
- } else if (llvm_ir::CanEmitFusedDynamicUpdateSliceInPlace(fusion,
- assignment_)) {
+ if (llvm_ir::CanEmitFusedDynamicUpdateSliceInPlace(fusion, assignment_)) {
VLOG(3) << "HandleFusion FusedDynamicUpdateSliceInPlace";
CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_);
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(fusion));
@@ -2166,9 +2121,9 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) {
GetIrArrayFor(fusion->operand(addend_param_number)));
TF_RETURN_IF_ERROR(DotOpEmitter::EmitDotOperation(
- *dot, /*transpose_lhs=*/false, /*transpose_rhs=*/false, target_array,
- lhs_array, rhs_array, &addend_array, GetExecutableRunOptionsArgument(),
- &ir_builder_, hlo_module_config_, target_machine_features_));
+ *dot, target_array, lhs_array, rhs_array, &addend_array,
+ GetExecutableRunOptionsArgument(), &ir_builder_, hlo_module_config_,
+ target_machine_features_));
return Status::OK();
} else {
return Unimplemented("Fusion kind not implemented on CPU");
diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc
index fb28280fad..47e8405ff2 100644
--- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc
+++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc
@@ -127,7 +127,7 @@ int64 ParallelTaskAssignment::GetTargetParallelTaskCount(
// Currently, we do not assign parallel tasks to instructions with at least
// one of the following properties:
// *) Internal threading (library calls to kConv, kDot, kFft, kCustomCall).
- // *) Emit custom loops (kSelectAndScatter, FusionKind::kTransposeDot).
+ // *) Emit custom loops (kSelectAndScatter).
// *) Operations that are not thread safe (like infeed and rng).
// *) Tuple-shaped.
// TODO(b/27458679) Parallelize instructions which are skipped here.
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
index c4c56c5692..41ee45f55f 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
@@ -197,22 +197,42 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
// We don't put any data in these buffers, because (in theory, anyway) the
// speed of a conv isn't affected by the data being convolved.
ScratchAllocator input_output_allocator(device_ordinal, allocator);
- se::port::StatusOr<DeviceMemoryBase> input_buf =
+ StatusOr<DeviceMemoryBase> maybe_input_buf =
input_output_allocator.AllocateBytes(&stream,
ShapeUtil::ByteSizeOf(input_shape));
- se::port::StatusOr<DeviceMemoryBase> filter_buf =
+ StatusOr<DeviceMemoryBase> maybe_filter_buf =
input_output_allocator.AllocateBytes(&stream,
ShapeUtil::ByteSizeOf(filter_shape));
- se::port::StatusOr<DeviceMemoryBase> output_buf =
+ StatusOr<DeviceMemoryBase> maybe_output_buf =
input_output_allocator.AllocateBytes(&stream,
ShapeUtil::ByteSizeOf(output_shape));
- if (!input_buf.ok() || !filter_buf.ok() || !output_buf.ok()) {
+ if (!maybe_input_buf.ok() || !maybe_filter_buf.ok() ||
+ !maybe_output_buf.ok()) {
LOG(WARNING)
<< "Couldn't allocate space for input/filter/output of convolution "
<< instr->ToString() << ". Falling back to default algorithm.";
return nullopt;
}
+ DeviceMemoryBase input_buf = maybe_input_buf.ValueOrDie();
+ DeviceMemoryBase filter_buf = maybe_filter_buf.ValueOrDie();
+ DeviceMemoryBase output_buf = maybe_output_buf.ValueOrDie();
+
+ // Although we don't have evidence this matters, zero out the buffers before
+ // autotuning. It's conceivable that using uninitialized memory as the inputs
+ // might affect performance if e.g. the inputs contain denormals, and this is
+ // easy enough.
+ if (!stream.ThenMemZero(&input_buf, input_buf.size())
+ .ThenMemZero(&filter_buf, filter_buf.size())
+ .ThenMemZero(&output_buf, output_buf.size())
+ .BlockHostUntilDone()
+ .ok()) {
+ LOG(WARNING)
+ << "Couldn't zero out input/filter/output buffer for convolution "
+ << instr->ToString() << ". Falling back to default algorithm.";
+ return nullopt;
+ }
+
const bool use_winograd_nonfused = ShouldIncludeWinogradNonfusedAlgo(
input_shape, output_shape, dnums, stream_exec_);
se::dnn::ProfileResult best_result;
@@ -225,12 +245,12 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
VLOG(3) << "Trying algorithm " << AlgorithmToString(alg) << " for "
<< instr->ToString();
- bool launch_ok = RunCudnnConvolution(
- kind, input_shape, filter_shape, output_shape,
- input_buf.ValueOrDie(), filter_buf.ValueOrDie(),
- output_buf.ValueOrDie(), &scratch_allocator, window,
- dnums, AlgorithmConfig(alg), &stream, &profile_result)
- .ok();
+ bool launch_ok =
+ RunCudnnConvolution(kind, input_shape, filter_shape, output_shape,
+ input_buf, filter_buf, output_buf,
+ &scratch_allocator, window, dnums,
+ AlgorithmConfig(alg), &stream, &profile_result)
+ .ok();
if (launch_ok && profile_result.is_valid()) {
int64 scratch_bytes_used = scratch_allocator.TotalAllocatedBytes();
diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc
index 0ec12f52d8..f996fe486d 100644
--- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc
@@ -221,8 +221,7 @@ GemmThunk::GemmThunk(const BufferAllocation::Slice& lhs_buffer,
const BufferAllocation::Slice& rhs_buffer,
const BufferAllocation::Slice& output_buffer,
const Shape& lhs_shape, const Shape& rhs_shape,
- const Shape& output_shape, bool transpose_lhs,
- bool transpose_rhs, double alpha,
+ const Shape& output_shape, double alpha,
const HloInstruction* hlo_instruction)
: Thunk(Kind::kGemm, hlo_instruction),
lhs_buffer_(lhs_buffer),
@@ -231,8 +230,6 @@ GemmThunk::GemmThunk(const BufferAllocation::Slice& lhs_buffer,
lhs_shape_(lhs_shape),
rhs_shape_(rhs_shape),
output_shape_(output_shape),
- transpose_lhs_(transpose_lhs),
- transpose_rhs_(transpose_rhs),
alpha_(alpha) {}
tensorflow::Status GemmThunk::ExecuteOnStream(
@@ -284,10 +281,13 @@ tensorflow::Status GemmThunk::ExecuteOnStream(
shape.dimensions(!is_row_major));
};
- const MatrixDescriptor lhs_descriptor =
- make_descriptor(lhs_data, lhs_shape_, transpose_lhs_);
- const MatrixDescriptor rhs_descriptor =
- make_descriptor(rhs_data, rhs_shape_, transpose_rhs_);
+ const DotDimensionNumbers& dim_nums =
+ hlo_instruction()->dot_dimension_numbers();
+
+ const MatrixDescriptor lhs_descriptor = make_descriptor(
+ lhs_data, lhs_shape_, dim_nums.lhs_contracting_dimensions(0) == 0);
+ const MatrixDescriptor rhs_descriptor = make_descriptor(
+ rhs_data, rhs_shape_, dim_nums.rhs_contracting_dimensions(0) == 1);
// Dispatches to a regular cublas gemm, a gemm-with-algorithm, or attempts to
// autotune this gemm to figure out the best algorithm.
diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.h b/tensorflow/compiler/xla/service/gpu/gemm_thunk.h
index a18f425bc3..f42cbf9e94 100644
--- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.h
@@ -35,15 +35,13 @@ namespace gpu {
class GemmThunk : public Thunk {
public:
// Constructs a thunk that computes "output = (lhs <dot> rhs) * alpha" using
- // BLAS gemm. transpose_lhs and transpose_rhs indicate whether gemm should
- // transpose the lhs and rhs operand. hlo_instruction is as in Thunk. alpha is
- // a constant.
+ // BLAS gemm. hlo_instruction is as in Thunk. alpha is a constant.
GemmThunk(const BufferAllocation::Slice& lhs_buffer,
const BufferAllocation::Slice& rhs_buffer,
const BufferAllocation::Slice& output_buffer,
const Shape& lhs_shape, const Shape& rhs_shape,
- const Shape& output_shape, bool transpose_lhs, bool transpose_rhs,
- double alpha, const HloInstruction* hlo_instruction);
+ const Shape& output_shape, double alpha,
+ const HloInstruction* hlo_instruction);
GemmThunk(const GemmThunk&) = delete;
GemmThunk& operator=(const GemmThunk&) = delete;
@@ -69,8 +67,6 @@ class GemmThunk : public Thunk {
const Shape rhs_shape_;
const Shape output_shape_;
- const bool transpose_lhs_;
- const bool transpose_rhs_;
const double alpha_;
// Maps device names (StreamExecutor::DeviceDescription::name()) to autotune
diff --git a/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc b/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc
index ece9fa04dc..6436abc06c 100644
--- a/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc
@@ -65,9 +65,9 @@ TEST_F(HloScheduleTest, SequentialMatMul) {
HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter(
/*parameter_number=*/2, f32_2x2_, /*name=*/"z"));
HloInstruction* dot1 = builder.AddInstruction(
- HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, x, y));
+ HloInstruction::CreateCanonicalDot(f32_2x2_, x, y));
HloInstruction* dot2 = builder.AddInstruction(
- HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, dot1, z));
+ HloInstruction::CreateCanonicalDot(f32_2x2_, dot1, z));
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build(dot2));
@@ -193,11 +193,11 @@ TEST_F(HloScheduleTest, ConcurrentMatMul) {
HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(
/*parameter_number=*/1, f32_2x2_, /*name=*/"y"));
HloInstruction* dot1 = builder.AddInstruction(
- HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, x, y));
+ HloInstruction::CreateCanonicalDot(f32_2x2_, x, y));
HloInstruction* dot2 = builder.AddInstruction(
- HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, y, x));
+ HloInstruction::CreateCanonicalDot(f32_2x2_, y, x));
HloInstruction* add = builder.AddInstruction(
- HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, dot1, dot2));
+ HloInstruction::CreateCanonicalDot(f32_2x2_, dot1, dot2));
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build(add));
@@ -259,24 +259,24 @@ TEST_F(HloScheduleTest, LatticeMatMul) {
params.push_back(builder.AddInstruction(HloInstruction::CreateParameter(
i, f32_2x2_, /*name=*/tensorflow::strings::Printf("param%d", i))));
}
- HloInstruction* d00 = builder.AddInstruction(HloInstruction::CreateBinary(
- f32_2x2_, HloOpcode::kDot, params[2], params[3]));
+ HloInstruction* d00 = builder.AddInstruction(
+ HloInstruction::CreateCanonicalDot(f32_2x2_, params[2], params[3]));
HloInstruction* d10 = builder.AddInstruction(
- HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, params[1], d00));
+ HloInstruction::CreateCanonicalDot(f32_2x2_, params[1], d00));
HloInstruction* d11 = builder.AddInstruction(
- HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d00, params[4]));
+ HloInstruction::CreateCanonicalDot(f32_2x2_, d00, params[4]));
HloInstruction* d20 = builder.AddInstruction(
- HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, params[0], d10));
+ HloInstruction::CreateCanonicalDot(f32_2x2_, params[0], d10));
HloInstruction* d21 = builder.AddInstruction(
- HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d10, d11));
+ HloInstruction::CreateCanonicalDot(f32_2x2_, d10, d11));
HloInstruction* d22 = builder.AddInstruction(
- HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d11, params[5]));
+ HloInstruction::CreateCanonicalDot(f32_2x2_, d11, params[5]));
HloInstruction* d30 = builder.AddInstruction(
- HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d20, d21));
+ HloInstruction::CreateCanonicalDot(f32_2x2_, d20, d21));
HloInstruction* d31 = builder.AddInstruction(
- HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d21, d22));
+ HloInstruction::CreateCanonicalDot(f32_2x2_, d21, d22));
HloInstruction* d40 = builder.AddInstruction(
- HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d30, d31));
+ HloInstruction::CreateCanonicalDot(f32_2x2_, d30, d31));
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build(d40));
diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc
index 85ecbe8fdb..c5eb721185 100644
--- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc
@@ -48,6 +48,19 @@ bool IsFusile(const HloInstruction& hlo) {
} // namespace
+/*static*/ bool GpuInstructionFusion::IsExpensive(
+ const HloInstruction& instruction) {
+ switch (instruction.opcode()) {
+ // We say that floating-point division is cheap on the GPU.
+ case HloOpcode::kDivide:
+ return !ShapeUtil::ElementIsFloating(instruction.shape()) &&
+ InstructionFusion::IsExpensive(instruction);
+
+ default:
+ return InstructionFusion::IsExpensive(instruction);
+ }
+}
+
bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer,
int64 operand_index) {
HloInstruction* producer = consumer->mutable_operand(operand_index);
diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.h b/tensorflow/compiler/xla/service/gpu/instruction_fusion.h
index bb2990e6df..9fb06b0a24 100644
--- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.h
+++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.h
@@ -27,6 +27,8 @@ class GpuInstructionFusion : public InstructionFusion {
explicit GpuInstructionFusion(bool may_duplicate)
: InstructionFusion(GpuInstructionFusion::IsExpensive, may_duplicate) {}
+ static bool IsExpensive(const HloInstruction& instruction);
+
bool ShouldFuse(HloInstruction* consumer, int64 operand_index) override;
HloInstruction::FusionKind ChooseKind(
diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc
index 4b231c449f..6c9a805ad6 100644
--- a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc
@@ -253,5 +253,61 @@ TEST_F(InstructionFusionTest, DotOutputFusion) {
op::Dot(op::Parameter(), op::Transpose(op::Parameter()))));
}
+// Compute sum(1/p0), where p0 has type f32, twice. Check that the division is
+// duplicated and fused into both reduces.
+TEST_F(InstructionFusionTest, FloatingPointDivIsCheap) {
+ auto module = tools::Parse(R"(
+ HloModule test_module
+ Add {
+ lhs = f32[] parameter(0)
+ rhs = f32[] parameter(1)
+ ROOT add = f32[] add(lhs, rhs)
+ }
+ ENTRY TestComputation {
+ zero = f32[] constant(0)
+ one = f32[] constant(1)
+ p0 = f32[100] parameter(0)
+ recip = f32[100] divide(one, p0)
+ sum1 = f32[] reduce(recip, zero), dimensions={0}, to_apply=Add
+ sum2 = f32[] reduce(recip, zero), dimensions={0}, to_apply=Add
+ ROOT root = (f32[], f32[]) tuple(sum1, sum2)
+ })")
+ .ValueOrDie();
+
+ EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true)
+ .Run(module.get())
+ .ValueOrDie());
+
+ HloInstruction* root = module->entry_computation()->root_instruction();
+ EXPECT_THAT(root, op::Tuple(op::Fusion(), op::Fusion()));
+}
+
+// Compute sum(100/p0), where p0 has type s32, twice. Check that the division
+// is *not* duplicated and fused into both reduces, because we say that integer
+// division is not cheap.
+TEST_F(InstructionFusionTest, IntegerDivIsNotCheap) {
+ auto module = tools::Parse(R"(
+ HloModule test_module
+ Add {
+ lhs = s32[] parameter(0)
+ rhs = s32[] parameter(1)
+ ROOT add = s32[] add(lhs, rhs)
+ }
+ ENTRY TestComputation {
+ zero = s32[] constant(0)
+ one_hundred = s32[] constant(100)
+ p0 = s32[100] parameter(0)
+ recip = s32[100] divide(one_hundred, p0)
+ sum1 = s32[] reduce(recip, zero), dimensions={0}, to_apply=Add
+ sum2 = s32[] reduce(recip, zero), dimensions={0}, to_apply=Add
+ ROOT mul = (s32[], s32[]) tuple(sum1, sum2)
+ })")
+ .ValueOrDie();
+
+ EXPECT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true)
+ .Run(module.get())
+ .ValueOrDie());
+}
+
} // namespace gpu
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
index 532d436ee8..96199035b9 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
@@ -78,18 +78,14 @@ bool ImplementedAsGemm(const HloInstruction& hlo) {
// The size of the reduction dimension should match. The shape inference
// guarantees this invariant, so the check here is for programming
// errors.
- CHECK_EQ(lhs_shape.dimensions(1), rhs_shape.dimensions(0));
+ const DotDimensionNumbers& dim_numbers = hlo.dot_dimension_numbers();
+ CHECK_EQ(lhs_shape.dimensions(dim_numbers.lhs_contracting_dimensions(0)),
+ rhs_shape.dimensions(dim_numbers.rhs_contracting_dimensions(0)));
return true;
}
}
if (hlo.opcode() == HloOpcode::kFusion &&
- hlo.fusion_kind() == HloInstruction::FusionKind::kTransposeDot &&
- hlo.fused_expression_root()->opcode() == HloOpcode::kDot) {
- return true;
- }
-
- if (hlo.opcode() == HloOpcode::kFusion &&
hlo.fusion_kind() == HloInstruction::FusionKind::kOutput &&
hlo.fused_expression_root()->opcode() == HloOpcode::kMultiply) {
// Try to find the dot inside the output fusion node.
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index 9f37235d32..83d90296df 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -2206,65 +2206,37 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildGemmThunk(
lhs->shape(), // The shape of LHS.
rhs->shape(), // The shape of RHS.
inst->shape(), // The shape of the output.
- false, // Do not transpose LHS.
- false, // Do not transpose RHS.
1.0, // alpha.
inst);
}
if (inst->opcode() == HloOpcode::kFusion) {
- if (inst->fusion_kind() == HloInstruction::FusionKind::kOutput) {
- const HloInstruction* mul = inst->fused_expression_root();
- const HloInstruction* dot = mul->operand(0);
- const HloInstruction* alpha = mul->operand(1);
- if (dot->opcode() != HloOpcode::kDot) {
- std::swap(dot, alpha);
- }
- DCHECK(dot->opcode() == HloOpcode::kDot);
- const HloInstruction* lhs_parameter = StripTranspose(*dot->operand(0));
- const HloInstruction* rhs_parameter = StripTranspose(*dot->operand(1));
- DCHECK(lhs_parameter->opcode() == HloOpcode::kParameter &&
- rhs_parameter->opcode() == HloOpcode::kParameter);
- const HloInstruction* lhs =
- inst->operand(lhs_parameter->parameter_number());
- const HloInstruction* rhs =
- inst->operand(rhs_parameter->parameter_number());
-
- return MakeUnique<GemmThunk>(
- GetAllocationSlice(*lhs), // The buffer assigned to LHS.
- GetAllocationSlice(*rhs), // The buffer assigned to RHS.
- GetAllocationSlice(*mul), // The output buffer.
- lhs->shape(), // The shape of LHS.
- rhs->shape(), // The shape of RHS.
- inst->shape(), // The shape of the output.
- dot->operand(0)->IsRank2Transpose(), // Transpose LHS.
- dot->operand(1)->IsRank2Transpose(), // Transpose RHS.
- alpha->literal().Get<double>({0}), // alpha.
- inst);
- } else {
- const HloInstruction* dot = inst->fused_expression_root();
- DCHECK(dot->opcode() == HloOpcode::kDot);
- const HloInstruction* lhs_parameter = StripTranspose(*dot->operand(0));
- const HloInstruction* rhs_parameter = StripTranspose(*dot->operand(1));
- DCHECK(lhs_parameter->opcode() == HloOpcode::kParameter &&
- rhs_parameter->opcode() == HloOpcode::kParameter);
- const HloInstruction* lhs =
- inst->operand(lhs_parameter->parameter_number());
- const HloInstruction* rhs =
- inst->operand(rhs_parameter->parameter_number());
-
- return MakeUnique<GemmThunk>(
- GetAllocationSlice(*lhs), // The buffer assigned to LHS.
- GetAllocationSlice(*rhs), // The buffer assigned to RHS.
- GetAllocationSlice(*inst), // The output buffer.
- lhs->shape(), // The shape of LHS.
- rhs->shape(), // The shape of RHS.
- inst->shape(), // The shape of the output.
- dot->operand(0)->IsRank2Transpose(), // Transpose LHS.
- dot->operand(1)->IsRank2Transpose(), // Transpose RHS.
- 1.0, // Alpha.
- inst);
+ CHECK_EQ(inst->fusion_kind(), HloInstruction::FusionKind::kOutput);
+ const HloInstruction* mul = inst->fused_expression_root();
+ const HloInstruction* dot = mul->operand(0);
+ const HloInstruction* alpha = mul->operand(1);
+ if (dot->opcode() != HloOpcode::kDot) {
+ std::swap(dot, alpha);
}
+ DCHECK(dot->opcode() == HloOpcode::kDot);
+ const HloInstruction* lhs_parameter = StripTranspose(*dot->operand(0));
+ const HloInstruction* rhs_parameter = StripTranspose(*dot->operand(1));
+ DCHECK(lhs_parameter->opcode() == HloOpcode::kParameter &&
+ rhs_parameter->opcode() == HloOpcode::kParameter);
+ const HloInstruction* lhs =
+ inst->operand(lhs_parameter->parameter_number());
+ const HloInstruction* rhs =
+ inst->operand(rhs_parameter->parameter_number());
+
+ return MakeUnique<GemmThunk>(
+ GetAllocationSlice(*lhs), // The buffer assigned to LHS.
+ GetAllocationSlice(*rhs), // The buffer assigned to RHS.
+ GetAllocationSlice(*mul), // The output buffer.
+ lhs->shape(), // The shape of LHS.
+ rhs->shape(), // The shape of RHS.
+ inst->shape(), // The shape of the output.
+ alpha->literal().Get<double>({0}), // alpha.
+ inst);
}
LOG(FATAL) << "Cannot build a GemmThunk for " << inst->ToString();
diff --git a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc
index 8c98956f1a..b42767dfd5 100644
--- a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc
@@ -41,9 +41,9 @@ TEST_F(StreamAssignmentTest, SequentialMatMul) {
HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter(
/*parameter_number=*/2, f32_2x2_, /*name=*/"z"));
HloInstruction* dot1 = builder.AddInstruction(
- HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, x, y));
+ HloInstruction::CreateCanonicalDot(f32_2x2_, x, y));
HloInstruction* dot2 = builder.AddInstruction(
- HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, dot1, z));
+ HloInstruction::CreateCanonicalDot(f32_2x2_, dot1, z));
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build(dot2));
@@ -60,9 +60,9 @@ TEST_F(StreamAssignmentTest, ConcurrentMatMul) {
HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(
/*parameter_number=*/1, f32_2x2_, /*name=*/"y"));
HloInstruction* dot1 = builder.AddInstruction(
- HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, x, y));
+ HloInstruction::CreateCanonicalDot(f32_2x2_, x, y));
HloInstruction* dot2 = builder.AddInstruction(
- HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, y, x));
+ HloInstruction::CreateCanonicalDot(f32_2x2_, y, x));
HloInstruction* add = builder.AddInstruction(
HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, dot1, dot2));
@@ -91,24 +91,24 @@ TEST_F(StreamAssignmentTest, LatticeMatMul) {
params.push_back(builder.AddInstruction(HloInstruction::CreateParameter(
i, f32_2x2_, /*name=*/tensorflow::strings::Printf("param%d", i))));
}
- HloInstruction* d00 = builder.AddInstruction(HloInstruction::CreateBinary(
- f32_2x2_, HloOpcode::kDot, params[2], params[3]));
+ HloInstruction* d00 = builder.AddInstruction(
+ HloInstruction::CreateCanonicalDot(f32_2x2_, params[2], params[3]));
HloInstruction* d10 = builder.AddInstruction(
- HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, params[1], d00));
+ HloInstruction::CreateCanonicalDot(f32_2x2_, params[1], d00));
HloInstruction* d11 = builder.AddInstruction(
- HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d00, params[4]));
+ HloInstruction::CreateCanonicalDot(f32_2x2_, d00, params[4]));
HloInstruction* d20 = builder.AddInstruction(
- HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, params[0], d10));
+ HloInstruction::CreateCanonicalDot(f32_2x2_, params[0], d10));
HloInstruction* d21 = builder.AddInstruction(
- HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d10, d11));
+ HloInstruction::CreateCanonicalDot(f32_2x2_, d10, d11));
HloInstruction* d22 = builder.AddInstruction(
- HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d11, params[5]));
+ HloInstruction::CreateCanonicalDot(f32_2x2_, d11, params[5]));
HloInstruction* d30 = builder.AddInstruction(
- HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d20, d21));
+ HloInstruction::CreateCanonicalDot(f32_2x2_, d20, d21));
HloInstruction* d31 = builder.AddInstruction(
- HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d21, d22));
+ HloInstruction::CreateCanonicalDot(f32_2x2_, d21, d22));
HloInstruction* d40 = builder.AddInstruction(
- HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d30, d31));
+ HloInstruction::CreateCanonicalDot(f32_2x2_, d30, d31));
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build(d40));
diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto
index aa6860880b..1f7c1cffd3 100644
--- a/tensorflow/compiler/xla/service/hlo.proto
+++ b/tensorflow/compiler/xla/service/hlo.proto
@@ -147,6 +147,9 @@ message HloInstructionProto {
repeated int64 called_computation_ids = 38;
xla.OpSharding sharding = 40;
+
+ // Backend configuration for the instruction. Has backend-specific meaning.
+ string backend_config = 43;
}
// Serialization of HloComputation.
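Since backend_config is an ordinary proto3 string field, the generated C++ accessors are the standard protobuf ones; a minimal sketch of writing and reading it is shown below. The payload string is purely hypothetical, as its format is backend-specific by design.

#include <string>
#include "tensorflow/compiler/xla/service/hlo.pb.h"  // generated from hlo.proto

void BackendConfigSketch() {
  xla::HloInstructionProto proto;
  // Opaque, backend-defined payload; this JSON-like string is only an example.
  proto.set_backend_config("{\"algorithm\": 0}");
  const std::string& config = proto.backend_config();
  (void)config;  // A real backend would parse and act on this string.
}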
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc
index 594413e88f..17e43c3cb8 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation.cc
@@ -347,6 +347,11 @@ std::list<HloComputation*> HloComputation::MakeEmbeddedComputationsList()
// To avoid special handling of this computation, cast away const of
// 'this'. 'this' is immediately removed from the post order after
// construction.
+ //
+ // TODO(b/78350259): This violates const-correctness: although the original
+ // computation is not returned, we still retrieve non-const computations from
+ // a const one. Consider avoiding const for HloComputation as well, or review
+ // XLA for const-correctness of non-HloInstruction* types like this.
ComputeComputationPostOrder(const_cast<HloComputation*>(this), &visited,
&post_order);
@@ -723,18 +728,25 @@ Status HloComputation::Accept(
return this->Accept(&visitor);
}
-std::unique_ptr<HloComputation> HloComputation::Clone(const string& suffix,
- HloModule* module) {
+std::unique_ptr<HloComputation> HloComputation::Clone(
+ const string& suffix, HloModule* module,
+ HloInstruction::CloneMap* clone_map) {
return CloneWithReplacements(
/*replacements=*/std::unordered_map<const HloInstruction*,
std::unique_ptr<HloInstruction>>(),
- module, suffix);
+ module, clone_map, suffix);
}
std::unique_ptr<HloComputation> HloComputation::CloneWithReplacements(
std::unordered_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
replacements,
- HloModule* module, const string& suffix) {
+ HloModule* module, HloInstruction::CloneMap* clone_map,
+ const string& suffix) {
+ HloInstruction::CloneMap local_clone_map;
+ if (clone_map == nullptr) {
+ clone_map = &local_clone_map;
+ }
+
// Look up instr in the replacements map, and return either the replacement,
// or instr, if the replacement isn't present.
//
@@ -756,24 +768,19 @@ std::unique_ptr<HloComputation> HloComputation::CloneWithReplacements(
}
}
- std::unordered_map<HloInstruction*, HloInstruction*> clone_map;
std::vector<std::unique_ptr<HloInstruction>> instructions;
std::unique_ptr<HloInstruction> new_instr = nullptr;
for (auto instr : postorder) {
std::vector<HloInstruction*> new_operands;
for (auto operand : instr->operands()) {
auto replaced_operand = replace(operand);
- // If replaced_operand is null, that means 'replacements' asked us not to
- // include operand in the new computation. But we can't do that, because
- // operand is used by instr.
CHECK_NE(replaced_operand, nullptr)
- << "replacements map tried to eliminate a used instruction "
- << operand->ToString() << ", used by " << instr->ToString();
- new_operands.push_back(FindOrDie(clone_map, replaced_operand));
+ << "Replacements map specifies to leave out " << operand->ToString()
+ << ", but it is used by " << instr->ToString() << ".";
+ new_operands.push_back(FindOrDie(*clone_map, replaced_operand));
}
- new_instr =
- instr->CloneWithNewOperands(instr->shape(), new_operands, module);
- InsertOrDie(&clone_map, instr, new_instr.get());
+ new_instr = instr->CloneWithNewOperands(instr->shape(), new_operands,
+ module, clone_map);
instructions.push_back(std::move(new_instr));
}
Builder builder(name() + "." + suffix);
@@ -781,27 +788,24 @@ std::unique_ptr<HloComputation> HloComputation::CloneWithReplacements(
builder.AddInstruction(std::move(instr));
}
auto result = builder.Build(
- /*root_instruction=*/FindOrDie(clone_map, replace(root_instruction())));
+ /*root_instruction=*/FindOrDie(*clone_map, replace(root_instruction())));
// Clone control dependencies.
for (auto instr : postorder) {
- HloInstruction* new_instr = FindOrDie(clone_map, instr);
+ HloInstruction* new_instr = FindOrDie(*clone_map, instr);
for (auto successor : instr->control_successors()) {
auto replaced_successor = replace(successor);
-
- // successor may not be in clone_map, because it might have been
- // removed by the replacements map.
- if (replaced_successor == nullptr) {
- continue;
- }
+ CHECK_NE(replaced_successor, nullptr)
+ << "Replacements map specifies to leave out " << successor->ToString()
+ << ", but it is control-depended-on by " << instr->ToString() << ".";
TF_CHECK_OK(new_instr->AddControlDependencyTo(
- FindOrDie(clone_map, replaced_successor)));
+ FindOrDie(*clone_map, replaced_successor)));
}
}
// We cloned the elements of 'replacements', so they're all going to be
- // destroyed. HloInstructions need to be detached from their operands before
+ // destroyed. HloInstructions need to be detached from their operands before
// they're destroyed, otherwise they stick around in the operands' users lists
// and cause use-after-frees.
for (auto& kv : replacements) {
diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h
index 9d3f6e9a2c..9898355625 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.h
+++ b/tensorflow/compiler/xla/service/hlo_computation.h
@@ -291,11 +291,17 @@ class HloComputation {
const std::function<Status(const HloInstruction*)>& visitor_func) const;
// Returns a deep copy of this computation including all instructions.
- // If the module pointer is not nullptr, it will be the module where
- // the cloned computations will be added to (in order to support deep
- // cloning).
- std::unique_ptr<HloComputation> Clone(const string& suffix = "clone",
- HloModule* module = nullptr);
+ //
+ // If the module pointer is not nullptr, then the cloned computations will be
+ // added to this module in order to support deep cloning. Otherwise the module
+ // of the computation is used.
+ //
+ // If clone_map is not nullptr, then each original instruction that is cloned
+ // will be inserted into it, mapped to its clone. clone_map should not already
+ // contain any of the instructions to clone.
+ std::unique_ptr<HloComputation> Clone(
+ const string& suffix = "clone", HloModule* module = nullptr,
+ HloInstruction::CloneMap* clone_map = nullptr);
// Like Clone(), but if an instruction is present in replacement_map, we use
// the map's value to replace that instruction in the cloned computation.
@@ -305,7 +311,9 @@ class HloComputation {
std::unique_ptr<HloComputation> CloneWithReplacements(
std::unordered_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
replacements,
- HloModule* module = nullptr, const string& suffix = "clone");
+ HloModule* module = nullptr,
+ HloInstruction::CloneMap* clone_map = nullptr,
+ const string& suffix = "clone");
// Returns true if the given instruction can be removed from the computation.
// Parameter instructions cannot be removed without violating invariants of
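A minimal usage sketch of the new clone_map parameter, assuming an existing HloComputation* and HloModule* from the surrounding codebase and that CloneMap iterates as an instruction-to-instruction map, as the comment above describes; this is illustrative, not code from the change itself.

void CloneWithMapSketch(HloComputation* computation, HloModule* module) {
  HloInstruction::CloneMap clone_map;
  std::unique_ptr<HloComputation> cloned =
      computation->Clone(/*suffix=*/"clone", module, &clone_map);
  // clone_map now maps each original instruction to its clone, so callers can
  // translate references (e.g. schedules) onto the new computation without
  // re-walking both graphs.
  for (const auto& entry : clone_map) {
    VLOG(2) << entry.first->name() << " -> " << entry.second->name();
  }
}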
diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.cc b/tensorflow/compiler/xla/service/hlo_creation_utils.cc
index 9a89888480..ed3b654851 100644
--- a/tensorflow/compiler/xla/service/hlo_creation_utils.cc
+++ b/tensorflow/compiler/xla/service/hlo_creation_utils.cc
@@ -269,7 +269,7 @@ StatusOr<HloInstruction*> BroadcastZeros(
StatusOr<std::unique_ptr<HloComputation>> CreateComputationWithSignature(
ArraySlice<const Shape*> domain, const Shape& range,
tensorflow::StringPiece name) {
- HloComputation::Builder b(name.ToString());
+ HloComputation::Builder b{std::string(name)};
int64 param_idx = 0;
for (const Shape* param_shape : domain) {
b.AddInstruction(HloInstruction::CreateParameter(
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc
index 1071f5b184..e7425c8ba7 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc
@@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_query.h"
@@ -42,7 +43,6 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
-#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/types.h"
@@ -53,19 +53,6 @@ namespace {
using tensorflow::gtl::ArraySlice;
using tensorflow::gtl::FlatSet;
-using tensorflow::gtl::optional;
-
-template <typename T>
-struct is_complex_t : public std::false_type {};
-
-template <>
-struct is_complex_t<complex64> : public std::true_type {};
-
-template <typename T>
-struct is_complex64_t : public std::false_type {};
-
-template <>
-struct is_complex64_t<complex64> : public std::true_type {};
template <typename OperandT>
StatusOr<std::unique_ptr<Literal>> Compare(const Shape& shape, HloOpcode opcode,
@@ -147,2092 +134,48 @@ StatusOr<std::unique_ptr<Literal>> Compare<complex64>(
return std::move(result);
}
-template <typename ReturnT, typename NativeT>
-StatusOr<std::unique_ptr<Literal>> ElementWiseUnaryOpImpl(
- HloInstruction* instruction,
- const std::function<ReturnT(NativeT)>& unary_op,
- const Literal& operand_literal) {
- const auto shape = instruction->shape();
- const auto* operand = instruction->operand(0);
-
- // TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast is
- // removed.
- if (!ShapeUtil::SameDimensions(shape, operand->shape())) {
- return Unimplemented(
- "Implicit broadcasting is currently unsupported in HLO evaluator "
- "Shape Mismatch: %s vs %s",
- ShapeUtil::HumanString(shape).c_str(),
- ShapeUtil::HumanString(operand->shape()).c_str());
- }
-
- auto result = Literal::CreateFromShape(shape);
-
- TF_RETURN_IF_ERROR(
- result->Populate<ReturnT>([&](ArraySlice<int64> multi_index) {
- return unary_op(operand_literal.Get<NativeT>(multi_index));
- }));
- return std::move(result);
-}
-
-// For one particular placement of a window in a base shape (the placement is
-// represented as `window_count_index`), iterates inside the window. Translates
-// the window index into base index. If the base index is within bound, call `f`
-// with the base index.
-void IterateThroughWindow(
- const Shape& window_shape, const Window& window, const Shape& base_shape,
- const ArraySlice<int64>& window_count_index,
- const std::function<void(const std::vector<int64>&)>& f) {
- const int64 rank = ShapeUtil::Rank(base_shape);
- DimensionVector window_index(rank);
- std::fill(window_index.begin(), window_index.end(), 0);
- do {
- std::vector<int64> base_index(rank);
- bool out_of_bound = false;
- for (int64 i = 0; i < rank; ++i) {
- base_index[i] = window_count_index[i] * window.dimensions(i).stride() +
- window_index[i] - window.dimensions(i).padding_low();
- if (base_index[i] < 0 || base_index[i] >= base_shape.dimensions(i)) {
- out_of_bound = true;
- break;
- }
- }
- if (!out_of_bound) {
- f(base_index);
- }
- } while (IndexUtil::BumpIndices(window_shape, &window_index));
-}
-
-// Creates a vector of multipliers which can be used to create a linear index
-// into shape.
-//
-// Given the multidimensional index {i1, ..., iN} and
-// M = MakeDimMultipliers(shape), the corresponding linear index LI is simply
-//
-// LI = i1 * M[1] + i2 * M[2] + ... + iN * M[N].
-//
-// This lets you calculate LI given the multidimensional indices in any order.
-DimensionVector MakeDimMultipliers(const Shape& shape) {
- DimensionVector v(ShapeUtil::Rank(shape));
- int64 scale = 1;
- for (auto dim : LayoutUtil::MinorToMajor(shape)) {
- v[dim] = scale;
- scale *= shape.dimensions(dim);
- }
- return v;
-}
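As a concrete worked example of the multiplier scheme described above (illustrative numbers only): for a shape with dimensions {2, 3} and minor-to-major order {1, 0} (row major), the loop assigns M[1] = 1 and then M[0] = 3, so the multidimensional index {i0=1, i1=2} maps to LI = 1*3 + 2*1 = 5, i.e. the last element of the second row.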
-
} // namespace
-template <typename ReturnT, typename ElementwiseT>
-class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
- public:
- explicit TypedVisitor(HloEvaluator* p) : parent_(p) {}
-
- // The following higher-order functions convert a function with ElementwiseT
- // to a function with ReturnT.
- std::function<ReturnT(ReturnT)> ConvertUnaryFunction(
- const std::function<ElementwiseT(ElementwiseT)>& unary_op) {
- return [&unary_op](ReturnT arg) {
- return static_cast<ReturnT>(unary_op(static_cast<ElementwiseT>(arg)));
- };
- }
- std::function<ReturnT(ReturnT, ReturnT)> ConvertBinaryFunction(
- const std::function<ElementwiseT(ElementwiseT, ElementwiseT)>&
- binary_op) {
- return [&binary_op](ReturnT arg1, ReturnT arg2) {
- return static_cast<ReturnT>(binary_op(static_cast<ElementwiseT>(arg1),
- static_cast<ElementwiseT>(arg2)));
- };
- }
- std::function<ReturnT(ReturnT, ReturnT, ReturnT)> ConvertTernaryFunction(
- const std::function<ElementwiseT(ElementwiseT, ElementwiseT,
- ElementwiseT)>& ternary_op) {
- return [&ternary_op](ReturnT arg1, ReturnT arg2, ReturnT arg3) {
- return static_cast<ReturnT>(ternary_op(static_cast<ElementwiseT>(arg1),
- static_cast<ElementwiseT>(arg2),
- static_cast<ElementwiseT>(arg3)));
- };
- }
-
- Status DefaultAction(HloInstruction* hlo_instruction) override {
- return Unimplemented("unhandled HLO ops for HloEvaluator: %s.",
- HloOpcodeString(hlo_instruction->opcode()).c_str());
- }
-
- // TODO(b/35950897): many of the stl functions used in the handlers are not
- // overloaded for every XLA primitive types.
-
- template <typename NativeT,
- typename std::enable_if<std::is_unsigned<NativeT>::value>::type* =
- nullptr>
- Status HandleAbs(HloInstruction* abs) {
- TF_ASSIGN_OR_RETURN(parent_->evaluated_[abs],
- ElementWiseUnaryOp(abs, [](NativeT elem_operand) {
- return elem_operand;
- }));
- return Status::OK();
- }
-
- template <
- typename NativeT,
- typename std::enable_if<std::is_signed<NativeT>::value>::type* = nullptr>
- Status HandleAbs(HloInstruction* abs) {
- TF_ASSIGN_OR_RETURN(parent_->evaluated_[abs],
- ElementWiseUnaryOp(abs, [](NativeT elem_operand) {
- return std::abs(elem_operand);
- }));
- return Status::OK();
- }
-
- template <
- typename NativeT,
- typename std::enable_if<is_complex64_t<NativeT>::value>::type* = nullptr>
- Status HandleAbs(HloInstruction* abs) {
- const Literal& operand_literal =
- parent_->GetEvaluatedLiteralFor(abs->operand(0));
- TF_ASSIGN_OR_RETURN(
- parent_->evaluated_[abs],
- (ElementWiseUnaryOpImpl<float, NativeT>(
- abs, [](NativeT elem_operand) { return std::abs(elem_operand); },
- operand_literal)));
-
- return Status::OK();
- }
-
- Status HandleAbs(HloInstruction* abs) override {
- // If the operand is of C64 type, the return type of abs will be F32.
- // However, ElementwiseT would still be the return type, F32, and thus
- // specifying the ElementwiseT explicitly as C64 is needed below.
- if (abs->operand(0)->shape().element_type() == C64) {
- return HandleAbs<complex64>(abs);
- }
- return HandleAbs<ElementwiseT>(abs);
- }
-
- template <
- typename NativeT,
- typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
- Status HandleRound(HloInstruction* round) {
- TF_ASSIGN_OR_RETURN(
- parent_->evaluated_[round],
- ElementWiseUnaryOp(round, [](ElementwiseT elem_operand) {
- return std::round(elem_operand);
- }));
- return Status::OK();
- }
-
- template <
- typename NativeT,
- typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
- Status HandleRound(HloInstruction* round) {
- return InvalidArgument("Unsupported type for Round");
- }
-
- Status HandleRound(HloInstruction* round) override {
- return HandleRound<ReturnT>(round);
- }
-
- Status HandleBroadcast(HloInstruction* broadcast) override {
- parent_->evaluated_[broadcast] =
- Literal::CreateFromShape(broadcast->shape());
- auto output = parent_->evaluated_[broadcast].get();
- const Literal& operand_to_broadcast =
- parent_->GetEvaluatedLiteralFor(broadcast->operand(0));
- std::vector<int64> broadcast_indices(
- ShapeUtil::Rank(broadcast->operand(0)->shape()), 0);
-
- TF_RET_CHECK(broadcast->dimensions().size() ==
- ShapeUtil::Rank(operand_to_broadcast.shape()))
- << "broadcast dimensions is of size: " << broadcast->dimensions().size()
- << " and rank of operand_to_broadcast is: "
- << ShapeUtil::Rank(operand_to_broadcast.shape());
- // Checks that operand's dimensions are the same as the broadcast's
- // dimensions along the dimensions to be broadcasted.
- for (int64 i = 0; i < broadcast->dimensions().size(); ++i) {
- TF_RET_CHECK(broadcast->shape().dimensions(broadcast->dimensions(i)) ==
- operand_to_broadcast.shape().dimensions(i));
- }
-
- return output->Populate<ReturnT>([&](ArraySlice<int64> multi_index) {
- for (int64 i = 0; i < broadcast->dimensions().size(); ++i) {
- broadcast_indices[i] = multi_index[broadcast->dimensions(i)];
- }
- return operand_to_broadcast.Get<ReturnT>(broadcast_indices);
- });
- }
-
- template <
- typename NativeT,
- typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
- Status HandleCeil(HloInstruction* ceil) {
- TF_ASSIGN_OR_RETURN(parent_->evaluated_[ceil],
- ElementWiseUnaryOp(ceil, [](ElementwiseT elem_operand) {
- return std::ceil(elem_operand);
- }));
- return Status::OK();
- }
-
- template <
- typename NativeT,
- typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
- Status HandleCeil(HloInstruction* ceil) {
- return InvalidArgument("Unsupported type for Ceil");
- }
-
- Status HandleCeil(HloInstruction* ceil) override {
- return HandleCeil<ReturnT>(ceil);
- }
-
- Status HandleConvert(HloInstruction* convert) override {
- const HloInstruction* operand = convert->operand(0);
- TF_RET_CHECK(ShapeUtil::SameDimensions(operand->shape(), convert->shape()));
- TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> result,
- parent_->GetEvaluatedLiteralFor(operand).Convert(
- convert->shape().element_type()));
-
- if (LayoutUtil::LayoutsInShapesEqual(result->shape(), convert->shape())) {
- parent_->evaluated_[convert] = std::move(result);
- } else {
- parent_->evaluated_[convert] =
- result->Relayout(convert->shape().layout());
- }
- return Status::OK();
- }
-
- Status HandleBitcastConvert(HloInstruction* convert) override {
- const HloInstruction* operand = convert->operand(0);
- TF_RET_CHECK(ShapeUtil::SameDimensions(operand->shape(), convert->shape()));
- TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> result,
- parent_->GetEvaluatedLiteralFor(operand).BitcastConvert(
- convert->shape().element_type()));
-
- if (LayoutUtil::LayoutsInShapesEqual(result->shape(), convert->shape())) {
- parent_->evaluated_[convert] = std::move(result);
- } else {
- parent_->evaluated_[convert] =
- result->Relayout(convert->shape().layout());
- }
- return Status::OK();
- }
-
- Status HandleExp(HloInstruction* exp) override {
- TF_ASSIGN_OR_RETURN(parent_->evaluated_[exp],
- ElementWiseUnaryOp(exp, [](ElementwiseT elem_operand) {
- return std::exp(elem_operand);
- }));
- return Status::OK();
- }
-
- template <
- typename NativeT,
- typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
- Status HandleFloor(HloInstruction* floor) {
- TF_ASSIGN_OR_RETURN(
- parent_->evaluated_[floor],
- ElementWiseUnaryOp(floor, [](ElementwiseT elem_operand) {
- return std::floor(elem_operand);
- }));
- return Status::OK();
- }
-
- template <
- typename NativeT,
- typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
- Status HandleFloor(HloInstruction* floor) {
- return InvalidArgument("Unsupported type for Floor");
- }
-
- Status HandleFloor(HloInstruction* floor) override {
- return HandleFloor<ReturnT>(floor);
- }
-
- Status HandleLog(HloInstruction* log) override {
- TF_ASSIGN_OR_RETURN(parent_->evaluated_[log],
- ElementWiseUnaryOp(log, [](ElementwiseT elem_operand) {
- return std::log(elem_operand);
- }));
- return Status::OK();
- }
-
- template <typename NativeT,
- typename std::enable_if<
- std::is_integral<NativeT>::value &&
- !std::is_same<NativeT, bool>::value>::type* = nullptr>
- Status HandleNot(HloInstruction* not_) {
- TF_ASSIGN_OR_RETURN(parent_->evaluated_[not_],
- ElementWiseUnaryOp(not_, [](ElementwiseT elem_operand) {
- return ~elem_operand;
- }));
- return Status::OK();
- }
-
- template <typename NativeT, typename std::enable_if<std::is_floating_point<
- NativeT>::value>::type* = nullptr>
- Status HandleNot(HloInstruction* not_) {
- TF_ASSIGN_OR_RETURN(parent_->evaluated_[not_],
- ElementWiseUnaryOp(not_, [](ElementwiseT elem_operand) {
- return !elem_operand;
- }));
- return Status::OK();
- }
-
- template <typename NativeT,
- typename std::enable_if<std::is_same<NativeT, bool>::value>::type* =
- nullptr>
- Status HandleNot(HloInstruction* not_) {
- TF_ASSIGN_OR_RETURN(parent_->evaluated_[not_],
- ElementWiseUnaryOp(not_, [](ElementwiseT elem_operand) {
- return !elem_operand;
- }));
- return Status::OK();
- }
-
- template <
- typename NativeT,
- typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
- Status HandleNot(HloInstruction* not_) {
- return InvalidArgument("Unsupported type for Not");
- }
-
- Status HandleNot(HloInstruction* not_) override {
- return HandleNot<ElementwiseT>(not_);
- }
-
- template <typename NativeT,
- typename std::enable_if<
- std::is_signed<NativeT>::value &&
- !std::is_floating_point<NativeT>::value>::type* = nullptr>
- Status HandleNegate(HloInstruction* negate) {
- using type = typename std::make_unsigned<NativeT>::type;
- TF_ASSIGN_OR_RETURN(
- parent_->evaluated_[negate],
- ElementWiseUnaryOp(negate, [](ElementwiseT elem_operand) {
- return NativeT(-type(elem_operand));
- }));
- return Status::OK();
- }
-
- template <typename NativeT,
- typename std::enable_if<
- !std::is_signed<NativeT>::value ||
- std::is_floating_point<NativeT>::value>::type* = nullptr>
- Status HandleNegate(HloInstruction* negate) {
- TF_ASSIGN_OR_RETURN(
- parent_->evaluated_[negate],
- ElementWiseUnaryOp(
- negate, [](ElementwiseT elem_operand) { return -elem_operand; }));
- return Status::OK();
- }
-
- Status HandleNegate(HloInstruction* negate) override {
- return HandleNegate<ReturnT>(negate);
- }
-
- template <
- typename NativeT,
- typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
- Status HandleSign(HloInstruction* sign) {
- TF_ASSIGN_OR_RETURN(parent_->evaluated_[sign],
- ElementWiseUnaryOp(sign, [](ElementwiseT elem_operand) {
- return (ElementwiseT(0) < elem_operand) -
- (elem_operand < ElementwiseT(0));
- }));
- return Status::OK();
- }
-
- template <
- typename NativeT,
- typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
- Status HandleSign(HloInstruction* sign) {
- TF_ASSIGN_OR_RETURN(parent_->evaluated_[sign],
- ElementWiseUnaryOp(sign, [](ElementwiseT elem_operand) {
- auto abs_val = std::abs(elem_operand);
- return 0 == abs_val ? ElementwiseT(0)
- : elem_operand / abs_val;
- }));
- return Status::OK();
- }
-
- Status HandleSign(HloInstruction* sign) override {
- return HandleSign<ReturnT>(sign);
- }
-
- template <typename NativeT, typename std::enable_if<std::is_floating_point<
- NativeT>::value>::type* = nullptr>
- Status HandleAtan2(HloInstruction* atan2) {
- TF_ASSIGN_OR_RETURN(parent_->evaluated_[atan2],
- ElementWiseBinaryOp(atan2, [](ElementwiseT lhs_elem,
- ElementwiseT rhs_elem) {
- return std::atan2(lhs_elem, rhs_elem);
- }));
- return Status::OK();
- }
-
- template <typename NativeT, typename std::enable_if<!std::is_floating_point<
- NativeT>::value>::type* = nullptr>
- Status HandleAtan2(HloInstruction* atan2) {
- return InvalidArgument("Unsupported type for Atan2");
- }
-
- Status HandleAtan2(HloInstruction* atan2) override {
- return HandleAtan2<ElementwiseT>(atan2);
- }
-
- Status HandleTanh(HloInstruction* tanh) override {
- TF_ASSIGN_OR_RETURN(parent_->evaluated_[tanh],
- ElementWiseUnaryOp(tanh, [](ElementwiseT elem_operand) {
- return std::tanh(elem_operand);
- }));
- return Status::OK();
- }
-
- template <typename NativeT,
- typename std::enable_if<
- std::is_signed<NativeT>::value &&
- !std::is_floating_point<NativeT>::value>::type* = nullptr>
- Status HandleMultiply(HloInstruction* multiply) {
- using type = typename std::make_unsigned<NativeT>::type;
- TF_ASSIGN_OR_RETURN(
- parent_->evaluated_[multiply],
- ElementWiseBinaryOp(multiply,
- [](ElementwiseT lhs_elem, ElementwiseT rhs_elem) {
- return NativeT(type(lhs_elem) * type(rhs_elem));
- }));
- return Status::OK();
- }
-
- template <
- typename NativeT,
- typename std::enable_if<std::is_unsigned<NativeT>::value ||
- std::is_floating_point<NativeT>::value ||
- is_complex_t<NativeT>::value>::type* = nullptr>
- Status HandleMultiply(HloInstruction* multiply) {
- TF_ASSIGN_OR_RETURN(
- parent_->evaluated_[multiply],
- ElementWiseBinaryOp(multiply,
- [](ElementwiseT lhs_elem, ElementwiseT rhs_elem) {
- return lhs_elem * rhs_elem;
- }));
- return Status::OK();
- }
-
- Status HandleMultiply(HloInstruction* multiply) override {
- return HandleMultiply<ElementwiseT>(multiply);
- }
-
- Status HandleSubtract(HloInstruction* subtract) override {
- TF_ASSIGN_OR_RETURN(
- parent_->evaluated_[subtract],
- ElementWiseBinaryOp(subtract,
- [](ElementwiseT lhs_elem, ElementwiseT rhs_elem) {
- return lhs_elem - rhs_elem;
- }));
- return Status::OK();
- }
-
- Status HandleAdd(HloInstruction* add) override {
- TF_ASSIGN_OR_RETURN(parent_->evaluated_[add],
- ElementWiseBinaryOp(add, [](ElementwiseT lhs_elem,
- ElementwiseT rhs_elem) {
- return lhs_elem + rhs_elem;
- }));
- return Status::OK();
- }
-
- Status HandleDivide(HloInstruction* divide) override {
- TF_ASSIGN_OR_RETURN(parent_->evaluated_[divide],
- ElementWiseBinaryOp(divide, [](ElementwiseT lhs_elem,
- ElementwiseT rhs_elem) {
- return lhs_elem / rhs_elem;
- }));
- return Status::OK();
- }
-
- template <typename NativeT,
- typename std::enable_if<std::is_integral<NativeT>::value>::type* =
- nullptr>
- Status HandleMaximum(HloInstruction* maximum) {
- TF_ASSIGN_OR_RETURN(
- parent_->evaluated_[maximum],
- ElementWiseBinaryOp(maximum, [](ElementwiseT lhs, ElementwiseT rhs) {
- return std::max(lhs, rhs);
- }));
- return Status::OK();
- }
-
- template <typename NativeT, typename std::enable_if<std::is_floating_point<
- NativeT>::value>::type* = nullptr>
- Status HandleMaximum(HloInstruction* maximum) {
- TF_ASSIGN_OR_RETURN(
- parent_->evaluated_[maximum],
- ElementWiseBinaryOp(maximum, [](ElementwiseT lhs, ElementwiseT rhs) {
- return ((lhs >= rhs) || std::isnan(lhs)) ? lhs : rhs;
- }));
- return Status::OK();
- }
-
- template <
- typename NativeT,
- typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
- Status HandleMaximum(HloInstruction* maximum) {
- return InvalidArgument("Unsupported type for Maximum");
- }
-
- Status HandleMaximum(HloInstruction* maximum) override {
- return HandleMaximum<ElementwiseT>(maximum);
- }
-
- template <typename NativeT,
- typename std::enable_if<std::is_integral<NativeT>::value>::type* =
- nullptr>
- Status HandleMinimum(HloInstruction* minimum) {
- TF_ASSIGN_OR_RETURN(parent_->evaluated_[minimum],
- ElementWiseBinaryOp(minimum, [](ElementwiseT lhs_el,
- ElementwiseT rhs_el) {
- return std::min(lhs_el, rhs_el);
- }));
- return Status::OK();
- }
-
- template <typename NativeT, typename std::enable_if<std::is_floating_point<
- NativeT>::value>::type* = nullptr>
- Status HandleMinimum(HloInstruction* minimum) {
- TF_ASSIGN_OR_RETURN(
- parent_->evaluated_[minimum],
- ElementWiseBinaryOp(minimum, [](ElementwiseT lhs_el,
- ElementwiseT rhs_el) {
- return ((lhs_el <= rhs_el) || std::isnan(lhs_el)) ? lhs_el : rhs_el;
- }));
- return Status::OK();
- }
-
- template <
- typename NativeT,
- typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
- Status HandleMinimum(HloInstruction* minimum) {
- return InvalidArgument("Unsupported type for Minimum");
- }
-
- Status HandleMinimum(HloInstruction* minimum) override {
- return HandleMinimum<ElementwiseT>(minimum);
- }
-
- Status HandlePower(HloInstruction* power) override {
- TF_ASSIGN_OR_RETURN(parent_->evaluated_[power],
- ElementWiseBinaryOp(power, [](ElementwiseT lhs_el,
- ElementwiseT rhs_el) {
- return std::pow(lhs_el, rhs_el);
- }));
- return Status::OK();
- }
-
- template <
- typename NativeT,
- typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
- Status HandleRemainder(HloInstruction* remainder) {
- TF_ASSIGN_OR_RETURN(parent_->evaluated_[remainder],
- ElementWiseBinaryOp(remainder, [](ElementwiseT lhs_el,
- ElementwiseT rhs_el) {
- return std::fmod(lhs_el, rhs_el);
- }));
- return Status::OK();
- }
-
- template <
- typename NativeT,
- typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
- Status HandleRemainder(HloInstruction* remainder) {
- return InvalidArgument("Unsupported type for Remainder");
- }
-
- Status HandleRemainder(HloInstruction* remainder) override {
- return HandleRemainder<ElementwiseT>(remainder);
- }
-
- template <typename NativeT,
- typename std::enable_if<std::is_integral<NativeT>::value>::type* =
- nullptr>
- Status HandleAnd(HloInstruction* and_) {
- TF_ASSIGN_OR_RETURN(
- parent_->evaluated_[and_],
- ElementWiseBinaryOp(and_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) {
- return lhs_el & rhs_el;
- }));
- return Status::OK();
- }
-
- template <typename NativeT, typename std::enable_if<std::is_floating_point<
- NativeT>::value>::type* = nullptr>
- Status HandleAnd(HloInstruction* and_) {
- TF_ASSIGN_OR_RETURN(
- parent_->evaluated_[and_],
- ElementWiseBinaryOp(and_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) {
- return lhs_el && rhs_el;
- }));
- return Status::OK();
- }
-
- template <
- typename NativeT,
- typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
- Status HandleAnd(HloInstruction* and_) {
- return InvalidArgument("Unsupported type for And");
- }
-
- Status HandleAnd(HloInstruction* and_) override {
- return HandleAnd<ElementwiseT>(and_);
- }
-
- template <typename NativeT,
- typename std::enable_if<std::is_integral<NativeT>::value>::type* =
- nullptr>
- Status HandleOr(HloInstruction* or_) {
- TF_ASSIGN_OR_RETURN(
- parent_->evaluated_[or_],
- ElementWiseBinaryOp(or_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) {
- return lhs_el | rhs_el;
- }));
- return Status::OK();
- }
-
- template <typename NativeT, typename std::enable_if<std::is_floating_point<
- NativeT>::value>::type* = nullptr>
- Status HandleOr(HloInstruction* or_) {
- TF_ASSIGN_OR_RETURN(
- parent_->evaluated_[or_],
- ElementWiseBinaryOp(or_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) {
- return lhs_el || rhs_el;
- }));
- return Status::OK();
- }
-
- template <
- typename NativeT,
- typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
- Status HandleOr(HloInstruction* or_) {
- return InvalidArgument("Unsupported type for Or");
- }
-
- Status HandleOr(HloInstruction* or_) override {
- return HandleOr<ElementwiseT>(or_);
- }
-
- template <typename NativeT,
- typename std::enable_if<
- std::is_integral<NativeT>::value &&
- !std::is_same<NativeT, bool>::value>::type* = nullptr>
- Status HandleShiftLeft(HloInstruction* shl) {
- TF_ASSIGN_OR_RETURN(
- parent_->evaluated_[shl],
- ElementWiseBinaryOp(shl, [](NativeT lhs_elem, NativeT rhs_elem) {
- return IsShiftOutOfBounds<NativeT>(rhs_elem) ? 0
- : (lhs_elem << rhs_elem);
- }));
- return Status::OK();
- }
-
- template <typename NativeT,
- typename std::enable_if<!std::is_integral<NativeT>::value ||
- std::is_same<NativeT, bool>::value>::type* =
- nullptr>
- Status HandleShiftLeft(HloInstruction*) {
- return InvalidArgument("Unsupported type for ShiftLeft");
- }
-
- Status HandleShiftLeft(HloInstruction* shl) override {
- return HandleShiftLeft<ElementwiseT>(shl);
- }
- template <typename NativeT,
- typename std::enable_if<
- std::is_integral<NativeT>::value &&
- !std::is_same<NativeT, bool>::value>::type* = nullptr>
- Status HandleShiftRightArithmetic(HloInstruction* shr) {
- typedef typename std::make_signed<NativeT>::type SignedT;
- TF_ASSIGN_OR_RETURN(
- parent_->evaluated_[shr],
- ElementWiseBinaryOp(shr, [](NativeT lhs_elem, NativeT rhs_elem) {
- SignedT lhs_signed = static_cast<SignedT>(lhs_elem);
- if (IsShiftOutOfBounds<NativeT>(rhs_elem)) {
- return lhs_signed < 0 ? static_cast<SignedT>(-1) : 0;
- } else {
- return lhs_signed >> rhs_elem;
- }
- }));
- return Status::OK();
- }
-
- template <typename NativeT,
- typename std::enable_if<!std::is_integral<NativeT>::value ||
- std::is_same<NativeT, bool>::value>::type* =
- nullptr>
- Status HandleShiftRightArithmetic(HloInstruction*) {
- return InvalidArgument("Unsupported type for ShiftRightArithmetic");
- }
-
- Status HandleShiftRightArithmetic(HloInstruction* shra) override {
- return HandleShiftRightArithmetic<ElementwiseT>(shra);
- }
-
- template <typename NativeT,
- typename std::enable_if<
- std::is_integral<NativeT>::value &&
- !std::is_same<NativeT, bool>::value>::type* = nullptr>
- Status HandleShiftRightLogical(HloInstruction* shr) {
- typedef typename std::make_unsigned<NativeT>::type UnsignedT;
- TF_ASSIGN_OR_RETURN(
- parent_->evaluated_[shr],
- ElementWiseBinaryOp(shr, [](NativeT lhs_elem, NativeT rhs_elem) {
-          // If the shift amount is greater than or equal to the number of
-          // bits, return 0.
- if (IsShiftOutOfBounds<NativeT>(rhs_elem)) {
- return static_cast<NativeT>(0);
- }
- return static_cast<NativeT>(static_cast<UnsignedT>(lhs_elem) >>
- rhs_elem);
- }));
- return Status::OK();
- }
-
- template <typename NativeT,
- typename std::enable_if<!std::is_integral<NativeT>::value ||
- std::is_same<NativeT, bool>::value>::type* =
- nullptr>
- Status HandleShiftRightLogical(HloInstruction*) {
- return InvalidArgument("Unsupported type for ShiftRightLogical");
- }
-
- Status HandleShiftRightLogical(HloInstruction* shrl) override {
- return HandleShiftRightLogical<ElementwiseT>(shrl);
- }
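A note on the three shift handlers above: a shift count that is at least the operand's bit width is never forwarded to the C++ `<<`/`>>` operators (that would be undefined behaviour); the result is instead forced to 0, or to all-ones for a negative arithmetic right shift. A minimal standalone sketch of the same rule on plain integers (not the evaluator code; it assumes the usual two's-complement arithmetic shift):

#include <climits>
#include <cstdint>
#include <iostream>
#include <type_traits>

// Same predicate as the IsShiftOutOfBounds helper later in this diff: the
// count is out of bounds once it reaches the bit width of the operand type.
template <typename T>
bool ShiftCountOutOfBounds(T rhs) {
  using U = typename std::make_unsigned<T>::type;
  return static_cast<U>(rhs) >= sizeof(T) * CHAR_BIT;
}

int32_t ShiftRightArithmetic(int32_t lhs, int32_t rhs) {
  // Saturate instead of invoking undefined behaviour for large counts.
  if (ShiftCountOutOfBounds(rhs)) return lhs < 0 ? -1 : 0;
  return lhs >> rhs;
}

int main() {
  std::cout << ShiftRightArithmetic(-16, 2) << "\n";   // -4
  std::cout << ShiftRightArithmetic(-16, 40) << "\n";  // -1: out-of-bounds count
}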
-
- template <
- typename NativeT,
- typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
- Status HandleClamp(HloInstruction* clamp) {
- std::function<ElementwiseT(ElementwiseT, ElementwiseT, ElementwiseT)>
- clamp_op = [](ElementwiseT low, ElementwiseT value, ElementwiseT high) {
- return std::fmin(high, std::fmax(value, low));
- };
- TF_ASSIGN_OR_RETURN(
- parent_->evaluated_[clamp],
- ElementwiseTernaryOp(clamp,
- std::move(ConvertTernaryFunction(clamp_op))));
- return Status::OK();
- }
-
- template <
- typename NativeT,
- typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
- Status HandleClamp(HloInstruction*) {
- return InvalidArgument("Unsupported type for Clamp");
- }
-
- Status HandleClamp(HloInstruction* clamp) override {
- return HandleClamp<ElementwiseT>(clamp);
- }
-
- Status HandleSelect(HloInstruction* select) override {
- CHECK(!ShapeUtil::IsScalar(select->operand(0)->shape()));
- CHECK(!ShapeUtil::IsTuple(select->shape()));
- std::function<ReturnT(bool, ReturnT, ReturnT)> select_op =
- [](bool pred, ReturnT on_true, ReturnT on_false) {
- if (pred) {
- return on_true;
- }
- return on_false;
- };
- TF_ASSIGN_OR_RETURN(parent_->evaluated_[select],
- ElementwiseTernaryOp(select, std::move(select_op)));
- return Status::OK();
- }
-
- Status HandleReverse(HloInstruction* reverse) override {
- const auto result_shape = reverse->shape();
- const auto reverse_dimensions = reverse->dimensions();
-
- auto operand = reverse->operand(0);
- TF_ASSIGN_OR_RETURN(auto inferred_return_shape,
- ShapeInference::InferReverseShape(operand->shape(),
- reverse_dimensions));
-
- TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape))
- << "return shape set to: " << ShapeUtil::HumanString(result_shape)
- << " but is inferred to be: "
- << ShapeUtil::HumanString(inferred_return_shape);
-
- const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand);
- auto result = Literal::CreateFromShape(result_shape);
-
- TF_RETURN_IF_ERROR(
- result->Populate<ReturnT>([&](ArraySlice<int64> out_index) {
- std::vector<int64> from_index(out_index.begin(), out_index.end());
- for (const int64 dim : reverse_dimensions) {
- from_index[dim] = result_shape.dimensions(dim) - 1 - out_index[dim];
- }
- return operand_literal.Get<ReturnT>(from_index);
- }));
-
- parent_->evaluated_[reverse] = std::move(result);
- return Status::OK();
- }
-
- Status HandleConvolution(HloInstruction* conv) override {
- auto lhs = conv->operand(0);
- auto rhs = conv->operand(1);
- const auto& window = conv->window();
- const Shape& result_shape = conv->shape();
- const Shape& lhs_shape = lhs->shape();
- const Shape& rhs_shape = rhs->shape();
-
- TF_CHECK_OK(ShapeUtil::ValidateShape(lhs_shape));
- TF_CHECK_OK(ShapeUtil::ValidateShape(rhs_shape));
- CHECK(ShapeUtil::IsArray(lhs_shape));
- CHECK(ShapeUtil::IsArray(rhs_shape));
- CHECK(ShapeUtil::SameElementType(lhs_shape, rhs_shape));
- CHECK(ShapeUtil::SameElementType(lhs_shape, result_shape));
-
- const auto& dnums = conv->convolution_dimension_numbers();
- const int64 num_spatial_dims = dnums.output_spatial_dimensions_size();
- CHECK_EQ(num_spatial_dims, dnums.input_spatial_dimensions_size());
- CHECK_EQ(num_spatial_dims, dnums.kernel_spatial_dimensions_size());
- CHECK_GE(num_spatial_dims, 0);
- CHECK_EQ(window.dimensions_size(), num_spatial_dims);
-
- const auto lhs_rank = ShapeUtil::Rank(lhs_shape);
- const auto rhs_rank = ShapeUtil::Rank(rhs_shape);
-
- CHECK_EQ(num_spatial_dims + 2, lhs_rank);
- CHECK_EQ(num_spatial_dims + 2, rhs_rank);
-
- TF_ASSIGN_OR_RETURN(auto inferred_return_shape,
- ShapeInference::InferConvolveShape(lhs_shape, rhs_shape,
- window, dnums));
- CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape))
- << "return shape set to: " << ShapeUtil::HumanString(result_shape)
- << " but is inferred to be: "
- << ShapeUtil::HumanString(inferred_return_shape);
-
- const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs);
- const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs);
-
- std::vector<int64> window_dimension_sizes;
- for (auto i : dnums.kernel_spatial_dimensions()) {
- window_dimension_sizes.push_back(ShapeUtil::GetDimension(rhs_shape, i));
- }
-
- const Shape& window_shape =
- ShapeUtil::MakeShape(rhs_shape.element_type(), window_dimension_sizes);
-
- DimensionVector lhs_dim_multipliers = MakeDimMultipliers(lhs_shape);
- DimensionVector rhs_dim_multipliers = MakeDimMultipliers(rhs_shape);
-
- auto lhs_literal_data = lhs_literal.data<ReturnT>();
- auto rhs_literal_data = rhs_literal.data<ReturnT>();
-
- auto func = [&window_shape, &dnums, &lhs_shape, &rhs_shape, &window,
- &lhs_dim_multipliers, &rhs_dim_multipliers, lhs_literal_data,
- rhs_literal_data](ArraySlice<int64> out_index) {
- // Dimension number applicable for input (lhs).
- const int64 input_batch_dim = dnums.input_batch_dimension();
- const int64 input_z_dim = dnums.input_feature_dimension();
- // Dimension number applicable for kernel (rhs).
- const int64 kernel_input_z_dim = dnums.kernel_input_feature_dimension();
- const int64 kernel_output_z_dim = dnums.kernel_output_feature_dimension();
- // Dimension number applicable for output.
- const int64 output_batch_dim = dnums.output_batch_dimension();
- const int64 output_z_dim = dnums.output_feature_dimension();
-
- const int64 z_size = ShapeUtil::GetDimension(lhs_shape, input_z_dim);
-
- ElementwiseT result_val = static_cast<ElementwiseT>(0);
- DimensionVector rhs_spatial_index(dnums.kernel_spatial_dimensions_size(),
- 0);
-
- // Convolve input feature with kernel.
- do {
- for (int64 iz = 0; iz < z_size; ++iz) {
- int64 lhs_linear_index = 0;
- lhs_linear_index += out_index[output_batch_dim] *
- lhs_dim_multipliers[input_batch_dim];
- lhs_linear_index += iz * lhs_dim_multipliers[input_z_dim];
-
- int64 rhs_linear_index = 0;
- rhs_linear_index += out_index[output_z_dim] *
- rhs_dim_multipliers[kernel_output_z_dim];
- rhs_linear_index += iz * rhs_dim_multipliers[kernel_input_z_dim];
-
- // Find corresponding spatial dimension index for input (lhs).
- for (int64 ki = 0; ki < rhs_spatial_index.size(); ++ki) {
- // Spatial dimension number for input (lhs) and output.
- const int64 input_spatial_dim = dnums.input_spatial_dimensions(ki);
- const int64 output_spatial_dim =
- dnums.output_spatial_dimensions(ki);
-
- // Calculate lhs (input) index without taking base dilation into
- // account.
- const auto& window_dim = window.dimensions(ki);
- const int64 undilated_index =
- out_index[output_spatial_dim] * window_dim.stride() -
- window_dim.padding_low() +
- rhs_spatial_index[ki] * window_dim.window_dilation();
- // Skip if the lhs (input) index is to be dilated. As an
- // optimization, skip this mod if there's no dilation.
- if (window_dim.base_dilation() > 1 &&
- undilated_index % window_dim.base_dilation() != 0) {
- goto cnt;
- }
-
- // Calculate the actual lhs (input) index after dilation. As an
- // optimization, skip this integer divide if there's no dilation.
- int64 lhs_spatial_index;
- if (window_dim.base_dilation() > 1) {
- lhs_spatial_index = undilated_index / window_dim.base_dilation();
- } else {
- lhs_spatial_index = undilated_index;
- }
- lhs_linear_index +=
- lhs_spatial_index * lhs_dim_multipliers[input_spatial_dim];
-
- // Skip if input index is not in bounds.
- if (!(lhs_spatial_index >= 0 &&
- lhs_spatial_index <
- lhs_shape.dimensions(input_spatial_dim))) {
- goto cnt;
- }
-
- rhs_linear_index +=
- (window_dim.window_reversal()
- ? ((window_dim.size() - 1) - rhs_spatial_index[ki])
- : rhs_spatial_index[ki]) *
- rhs_dim_multipliers[dnums.kernel_spatial_dimensions(ki)];
- }
-
- result_val +=
- static_cast<ElementwiseT>(lhs_literal_data[lhs_linear_index]) *
- static_cast<ElementwiseT>(rhs_literal_data[rhs_linear_index]);
- }
- cnt : {}
- } while (IndexUtil::BumpIndices(window_shape, &rhs_spatial_index));
-
- return static_cast<ReturnT>(result_val);
- };
-
- auto result = Literal::CreateFromShape(result_shape);
- TF_RETURN_IF_ERROR(result->PopulateParallel<ReturnT>(func));
-
- parent_->evaluated_[conv] = std::move(result);
- return Status::OK();
- }
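The index arithmetic in HandleConvolution above (stride, low padding, window dilation, base dilation, bounds check) can be hard to follow inline. The standalone sketch below reproduces just that mapping for a single spatial dimension, with hypothetical parameter names; it illustrates the formula and is not the evaluator code itself:

#include <cstdint>
#include <iostream>

// Returns the input index for a given output index and kernel tap, or -1 if
// the tap lands in padding or on a hole introduced by base dilation.
// Mirrors the undilated_index / base_dilation logic in HandleConvolution.
int64_t LhsSpatialIndex(int64_t out_index, int64_t kernel_index,
                        int64_t stride, int64_t padding_low,
                        int64_t window_dilation, int64_t base_dilation,
                        int64_t input_size) {
  const int64_t undilated =
      out_index * stride - padding_low + kernel_index * window_dilation;
  if (base_dilation > 1 && undilated % base_dilation != 0) return -1;  // hole
  const int64_t in_index =
      base_dilation > 1 ? undilated / base_dilation : undilated;
  if (in_index < 0 || in_index >= input_size) return -1;  // padding
  return in_index;
}

int main() {
  // stride 2, pad_low 1, no dilation, input size 5.
  std::cout << LhsSpatialIndex(0, 0, 2, 1, 1, 1, 5) << "\n";  // -1 (padding)
  std::cout << LhsSpatialIndex(1, 1, 2, 1, 1, 1, 5) << "\n";  // 2
}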
-
- Status HandleDot(HloInstruction* dot) override {
- auto lhs = dot->operand(0);
- auto rhs = dot->operand(1);
- CHECK(ShapeUtil::IsArray(dot->shape()));
- CHECK(ShapeUtil::IsArray(lhs->shape()));
- CHECK(ShapeUtil::IsArray(rhs->shape()));
-
- const auto& dnums = dot->dot_dimension_numbers();
-
- const auto lhs_rank = ShapeUtil::Rank(lhs->shape());
- const auto rhs_rank = ShapeUtil::Rank(rhs->shape());
-
- CHECK(ShapeUtil::SameElementType(lhs->shape(), rhs->shape()));
- CHECK(ShapeUtil::SameElementType(lhs->shape(), dot->shape()));
-
-    // There must be exactly one contracting dimension for both lhs and rhs.
- CHECK_EQ(dnums.lhs_contracting_dimensions_size(), 1);
- CHECK_EQ(dnums.rhs_contracting_dimensions_size(), 1);
- const int64 lhs_contracting_dimension = dnums.lhs_contracting_dimensions(0);
- const int64 rhs_contracting_dimension = dnums.rhs_contracting_dimensions(0);
- // Contracted dimension sizes must be the same.
- CHECK_EQ(lhs->shape().dimensions(lhs_contracting_dimension),
- rhs->shape().dimensions(rhs_contracting_dimension))
- << "lhs contracted dimension: "
- << lhs->shape().dimensions(lhs_contracting_dimension)
- << " rhs contracted dimension: "
- << rhs->shape().dimensions(rhs_contracting_dimension);
- const int64 contracted_dimension_size =
- lhs->shape().dimensions(lhs_contracting_dimension);
-
- const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs);
- const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs);
-
- auto result = Literal::CreateFromShape(dot->shape());
-
- CHECK_EQ(dnums.lhs_batch_dimensions_size(),
- dnums.rhs_batch_dimensions_size());
-
- std::vector<int64> lhs_non_contracting_dims;
- for (int64 i = 0; i < lhs_rank; i++) {
- if (i != lhs_contracting_dimension) {
- lhs_non_contracting_dims.push_back(i);
- }
- }
-
- std::vector<int64> rhs_non_batch_non_contracting_dims;
- FlatSet<int64> batch_dims_set(dnums.rhs_batch_dimensions().begin(),
- dnums.rhs_batch_dimensions().end());
- for (int64 i = 0; i < rhs_rank; i++) {
- if (i != rhs_contracting_dimension && batch_dims_set.count(i) == 0) {
- rhs_non_batch_non_contracting_dims.push_back(i);
- }
- }
-
- const int64 batch_dim_size = dnums.lhs_batch_dimensions_size();
- const int64 lhs_non_contracting_size = lhs_non_contracting_dims.size();
-
- DimensionVector lhs_index(lhs_rank);
- DimensionVector rhs_index(rhs_rank);
- TF_RETURN_IF_ERROR(
- result->Populate<ReturnT>([&](ArraySlice<int64> result_index) {
- ElementwiseT result_val = static_cast<ElementwiseT>(0);
-
- // Find the corresponding non-contracting indices for lhs and rhs.
- //
-          // For `result_index`, its batch dimension, if it exists, will be at
-          // the same dimension as the batch dimension of lhs and rhs. More
-          // specifically:
-          // - For lhs, the non-contracting dimensions, including the batch
-          //   dimension, have the same index as the `result_index`.
- // - For rhs, the batch dimension is set separately from other
- // non-contracting dimensions, since these other non-contracting
- // dimensions in rhs follow the non-contracting dimensions of lhs in
- // the resulting index.
- //
- // As an example, for a resulting index:
- // result_index [result_batch, result_x, result_y]
-          // the corresponding lhs and rhs indices are:
-          //   lhs [result_batch, lhs_non_contracting_dim, contracting_dim]
- // rhs [result_batch, contracting_dim, rhs_non_contracting_dim]
- // `result_x` is only affected by the lhs_non_contracting_dim and
- // likewise `result_y` only depends on rhs_non_contracting_dim.
- //
- // so we can look up the lhs and rhs indices by:
- //
- // lhs:
- // batch index is the same as `result_batch`.
- // non-contracting dimension is the same as
- // result_index[lhs_non_contracting_dim]
- // rhs:
- // batch index: the same as `result_batch`.
- // non-contracting dimension index: *not* the same as
-          //       result_index[rhs_non_contracting_dim], since the
- // non-contracting dimensions of lhs are included in the
- // result_index first. Instead, the non_contracting_dim of rhs must
- // be calculated as following:
- // lhs_non_contracting_dimensions_size +
- // (rhs_non_batch_non_contracting_dim - batch_dim_size) - 1
- //
-          // Note that (rhs_non_batch_non_contracting_dim - batch_dim_size) is
-          // the index offset into the result_index that only depends on
-          // the non-batch, non-contracting dimensions of rhs. The -1 at the
-          // end translates a size into an index.
- for (auto i : lhs_non_contracting_dims) {
- lhs_index[i] = result_index[i];
- }
- for (auto i : dnums.rhs_batch_dimensions()) {
- rhs_index[i] = result_index[i];
- }
- for (auto i : rhs_non_batch_non_contracting_dims) {
- const int64 rhs_non_batch_non_contracting_dim =
- lhs_non_contracting_size + (i - batch_dim_size) - 1;
- rhs_index[i] = result_index[rhs_non_batch_non_contracting_dim];
- }
-
- // Accumulates resulting product along the contracted dimension.
- for (int64 i = 0; i < contracted_dimension_size; ++i) {
- lhs_index[lhs_contracting_dimension] = i;
- rhs_index[rhs_contracting_dimension] = i;
-
- result_val +=
- static_cast<ElementwiseT>(lhs_literal.Get<ReturnT>(lhs_index)) *
- static_cast<ElementwiseT>(rhs_literal.Get<ReturnT>(rhs_index));
- }
-
- return static_cast<ReturnT>(result_val);
- }));
-
- parent_->evaluated_[dot] = std::move(result);
- return Status::OK();
- }
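The long comment inside HandleDot describes how a result index splits into batch, lhs non-contracting and rhs non-contracting parts. For the common batched-matmul layout lhs[B, M, K] * rhs[B, K, N] this reduces to the familiar triple loop below; the sketch uses plain row-major vectors and is only meant to make the index mapping concrete, not to mirror the evaluator code:

#include <iostream>
#include <vector>

int main() {
  // lhs[B=1][M=2][K=3], rhs[B=1][K=3][N=2]; contract over K.
  const int B = 1, M = 2, K = 3, N = 2;
  std::vector<double> lhs = {1, 2, 3, 4, 5, 6};  // row-major [1][2][3]
  std::vector<double> rhs = {1, 0, 0, 1, 1, 1};  // row-major [1][3][2]
  std::vector<double> out(B * M * N, 0.0);
  for (int b = 0; b < B; ++b)
    for (int m = 0; m < M; ++m)
      for (int n = 0; n < N; ++n)
        for (int k = 0; k < K; ++k)  // accumulate along the contracted dim
          out[(b * M + m) * N + n] +=
              lhs[(b * M + m) * K + k] * rhs[(b * K + k) * N + n];
  for (double v : out) std::cout << v << " ";  // 4 5 10 11
  std::cout << "\n";
}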
-
- Status HandlePad(HloInstruction* pad) override {
- CHECK(!ShapeUtil::IsTuple(pad->operand(0)->shape()));
- // Padding value must be scalar.
- CHECK(ShapeUtil::IsScalar(pad->operand(1)->shape()));
- CHECK_EQ(ShapeUtil::Rank(pad->operand(0)->shape()),
- pad->padding_config().dimensions_size());
-
- TF_ASSIGN_OR_RETURN(auto inferred_return_shape,
- ShapeInference::InferPadShape(
- /*operand_shape=*/pad->operand(0)->shape(),
- /*padding_value_shape=*/pad->operand(1)->shape(),
- /*padding_config=*/pad->padding_config()));
- CHECK(ShapeUtil::Compatible(pad->shape(), inferred_return_shape))
- << "return shape is set to: " << ShapeUtil::HumanString(pad->shape())
-        << " but is inferred to be: "
- << ShapeUtil::HumanString(inferred_return_shape);
-
-    // Create a new literal of the padded shape, filled with the padding value.
- ReturnT scalar =
- parent_->GetEvaluatedLiteralFor(pad->operand(1)).Get<ReturnT>({});
- auto result = Literal::CreateFromShape(pad->shape());
- TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
- [&scalar](ArraySlice<int64> multi_index) { return scalar; }));
-
- const Literal& evaluated_operand =
- parent_->GetEvaluatedLiteralFor(pad->operand(0));
-
- std::vector<int64> input_index(ShapeUtil::Rank(evaluated_operand.shape()),
- 0);
- std::vector<int64> target_index(ShapeUtil::Rank(result->shape()), 0);
-
-    // Loop through each element of the operand, assigning each to the
-    // corresponding index of the resulting padded literal.
- const PaddingConfig& pad_config = pad->padding_config();
-
- auto func = [&](ArraySlice<int64> input_index) {
- for (auto i = 0; i < input_index.size(); ++i) {
-        // Interior padding occurs logically before edge padding, so in the
-        // case of negative edge padding, elements are removed from the
-        // interior-padded operand.
- target_index[i] =
- pad_config.dimensions(i).edge_padding_low() +
- input_index[i] * (pad_config.dimensions(i).interior_padding() + 1);
-
-        // Account for negative low and high padding: skip the assignment if
-        // any target index is out of range.
- if (!(target_index[i] >= 0 &&
- target_index[i] < pad->shape().dimensions(i))) {
- return true;
- }
- }
- result->Set<ReturnT>(target_index,
- evaluated_operand.Get<ReturnT>(input_index));
- return true;
- };
-
- std::vector<int64> zero_base(evaluated_operand.shape().dimensions_size(),
- 0);
- std::vector<int64> step(evaluated_operand.shape().dimensions_size(), 1);
-
- ShapeUtil::ForEachIndex(
- evaluated_operand.shape(), zero_base,
- AsInt64Slice(evaluated_operand.shape().dimensions()), step, func);
-
- parent_->evaluated_[pad] = std::move(result);
- return Status::OK();
- }
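HandlePad above reduces padding to one formula per dimension, target = edge_padding_low + input_index * (interior_padding + 1), plus a bounds check that also covers negative edge padding. Worked numbers for a single dimension, using a hypothetical padding config:

#include <iostream>
#include <vector>

int main() {
  // One dimension: operand {10, 20, 30}, pad value 0,
  // edge_padding_low = 1, interior_padding = 2, edge_padding_high = 1.
  const std::vector<int> operand = {10, 20, 30};
  const int low = 1, interior = 2, high = 1;
  const int out_size =
      low + high + static_cast<int>(operand.size()) +
      interior * (static_cast<int>(operand.size()) - 1);  // 1+1+3+4 = 9
  std::vector<int> result(out_size, 0);                   // fill with pad value
  for (int i = 0; i < static_cast<int>(operand.size()); ++i) {
    const int target = low + i * (interior + 1);          // 1, 4, 7
    if (target >= 0 && target < out_size) result[target] = operand[i];
  }
  for (int v : result) std::cout << v << " ";  // 0 10 0 0 20 0 0 30 0
  std::cout << "\n";
}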
-
- Status HandleDynamicSlice(HloInstruction* dynamic_slice) override {
- auto operand = dynamic_slice->operand(0);
- auto start_indices = dynamic_slice->operand(1);
- auto result_shape = dynamic_slice->shape();
- TF_ASSIGN_OR_RETURN(auto inferred_return_shape,
- ShapeInference::InferDynamicSliceShape(
- operand->shape(), start_indices->shape(),
- dynamic_slice->dynamic_slice_sizes()));
- TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape))
- << "return shape is set to: " << ShapeUtil::HumanString(result_shape)
-        << " but is inferred to be: "
- << ShapeUtil::HumanString(inferred_return_shape);
- TF_RET_CHECK(
- primitive_util::IsIntegralType(start_indices->shape().element_type()));
-
- const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand);
- const Literal& start_indices_literal =
- parent_->GetEvaluatedLiteralFor(start_indices);
-
- switch (start_indices->shape().element_type()) {
- case S32: {
- TF_ASSIGN_OR_RETURN(
- parent_->evaluated_[dynamic_slice],
- DynamicSlice<int32>(operand_literal, start_indices_literal,
- result_shape));
- } break;
- case S64: {
- TF_ASSIGN_OR_RETURN(
- parent_->evaluated_[dynamic_slice],
- DynamicSlice<int64>(operand_literal, start_indices_literal,
- result_shape));
- } break;
- case U32: {
- TF_ASSIGN_OR_RETURN(
- parent_->evaluated_[dynamic_slice],
- DynamicSlice<uint32>(operand_literal, start_indices_literal,
- result_shape));
- } break;
- case U64: {
- TF_ASSIGN_OR_RETURN(
- parent_->evaluated_[dynamic_slice],
- DynamicSlice<uint64>(operand_literal, start_indices_literal,
- result_shape));
- } break;
- default:
- LOG(FATAL) << "HandleDynamicSlice: unhandled primitive type for "
- "start_indices: "
- << PrimitiveType_Name(start_indices->shape().element_type());
- }
-
- return Status::OK();
- }
-
- Status HandleDynamicUpdateSlice(
- HloInstruction* dynamic_update_slice) override {
- auto operand = dynamic_update_slice->operand(0);
- auto update = dynamic_update_slice->operand(1);
- auto start_indices = dynamic_update_slice->operand(2);
- auto result_shape = dynamic_update_slice->shape();
- TF_ASSIGN_OR_RETURN(
- auto inferred_return_shape,
- ShapeInference::InferDynamicUpdateSliceShape(
- operand->shape(), update->shape(), start_indices->shape()));
- TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape))
- << "return shape is set to: " << ShapeUtil::HumanString(result_shape)
-        << " but is inferred to be: "
- << ShapeUtil::HumanString(inferred_return_shape);
- TF_RET_CHECK(
- primitive_util::IsIntegralType(start_indices->shape().element_type()));
- TF_RET_CHECK(ShapeUtil::Compatible(result_shape, operand->shape()));
-
- const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand);
- const Literal& update_literal = parent_->GetEvaluatedLiteralFor(update);
- const Literal& start_indices_literal =
- parent_->GetEvaluatedLiteralFor(start_indices);
-
- switch (start_indices->shape().element_type()) {
- case S32: {
- TF_ASSIGN_OR_RETURN(
- parent_->evaluated_[dynamic_update_slice],
- DynamicUpdateSlice<int32>(operand_literal, update_literal,
- start_indices_literal));
- } break;
- case S64: {
- TF_ASSIGN_OR_RETURN(
- parent_->evaluated_[dynamic_update_slice],
- DynamicUpdateSlice<int64>(operand_literal, update_literal,
- start_indices_literal));
- } break;
- case U32: {
- TF_ASSIGN_OR_RETURN(
- parent_->evaluated_[dynamic_update_slice],
- DynamicUpdateSlice<uint32>(operand_literal, update_literal,
- start_indices_literal));
- } break;
- case U64: {
- TF_ASSIGN_OR_RETURN(
- parent_->evaluated_[dynamic_update_slice],
- DynamicUpdateSlice<uint64>(operand_literal, update_literal,
- start_indices_literal));
- } break;
- default:
- LOG(FATAL) << "HandleDynamicUpdateSlice: unhandled primitive type for "
- "start_indices: "
- << PrimitiveType_Name(start_indices->shape().element_type());
- }
-
- return Status::OK();
- }
-
- template <typename NativeT>
- StatusOr<std::unique_ptr<Literal>> MapImpl(HloInstruction* map) {
- auto operands = map->operands();
- HloComputation* computation = map->to_apply();
-
- auto result = Literal::CreateFromShape(map->shape());
-
- HloEvaluator embedded_evaluator(parent_->max_loop_iterations_);
- TF_RETURN_IF_ERROR(
- result->Populate<ReturnT>([&](ArraySlice<int64> multi_index) {
- std::vector<std::unique_ptr<Literal>> arg_literals;
- arg_literals.reserve(operands.size());
-
- // Construct scalar literal parameters to be passed to the map
- // computation.
- for (auto operand : operands) {
- const Literal& arg_literal =
- parent_->GetEvaluatedLiteralFor(operand);
-
- auto curr_val = arg_literal.Get<NativeT>(multi_index);
- auto curr_val_literal = Literal::CreateR0<NativeT>(curr_val);
-
- arg_literals.push_back(std::move(curr_val_literal));
- }
-
- std::unique_ptr<Literal> computed_result =
- embedded_evaluator
- .Evaluate<std::unique_ptr<Literal>>(*computation,
- arg_literals)
- .ConsumeValueOrDie();
-              // Clear visit states so that we can use the evaluator again on
-              // the same computation.
- embedded_evaluator.ResetVisitStates();
-
- return computed_result->Get<ReturnT>({});
- }));
- return std::move(result);
- }
-
- Status HandleMap(HloInstruction* map) override {
- switch (map->operand(0)->shape().element_type()) {
- case PRED: {
- TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<bool>(map));
- break;
- }
- case U8: {
- TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<uint8>(map));
- break;
- }
- case U32: {
- TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<uint32>(map));
- break;
- }
- case U64: {
- TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<uint64>(map));
- break;
- }
- case S8: {
- TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<int8>(map));
- break;
- }
- case S32: {
- TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<int32>(map));
- break;
- }
- case S64: {
- TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<int64>(map));
- break;
- }
- case F16: {
- TF_ASSIGN_OR_RETURN(parent_->evaluated_[map],
- MapImpl<Eigen::half>(map));
- break;
- }
- case F32: {
- TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<float>(map));
- break;
- }
- case F64: {
- TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<double>(map));
- break;
- }
- case C64: {
- TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<complex64>(map));
- break;
- }
- default:
- LOG(FATAL) << "HandleMap: unhandled primitive type for "
- "input operand: "
- << PrimitiveType_Name(
- map->operand(0)->shape().element_type());
- }
-
- return Status::OK();
- }
-
- Status HandleReduce(HloInstruction* reduce) override {
- auto arg = reduce->operand(0);
- auto init_value = reduce->operand(1);
- ArraySlice<int64> dimensions(reduce->dimensions());
- HloComputation* function = reduce->to_apply();
- TF_RET_CHECK(ShapeUtil::Rank(reduce->shape()) ==
- ShapeUtil::Rank(arg->shape()) - dimensions.size());
- TF_ASSIGN_OR_RETURN(auto inferred_return_shape,
- ShapeInference::InferReduceShape(
- /*arg=*/arg->shape(),
- /*init_value=*/init_value->shape(),
- /*dimensions_to_reduce=*/dimensions,
- /*to_apply=*/function->ComputeProgramShape()));
- TF_RET_CHECK(ShapeUtil::Compatible(reduce->shape(), inferred_return_shape))
- << "return shape is set to: " << ShapeUtil::HumanString(reduce->shape())
-        << " but is inferred to be: "
- << ShapeUtil::HumanString(inferred_return_shape);
-
- const Literal& arg_literal = parent_->GetEvaluatedLiteralFor(arg);
- VLOG(3) << "HandleReduce arg_literal: " << arg_literal.ToString();
- const Literal& init_literal = parent_->GetEvaluatedLiteralFor(init_value);
- VLOG(3) << "HandleReduce init_literal: " << init_literal.ToString();
- TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape()));
- auto init_scalar = init_literal.Get<ReturnT>({});
-
- auto result = Literal::CreateFromShape(reduce->shape());
-
- const auto arg_dimensions = AsInt64Slice(arg_literal.shape().dimensions());
- std::vector<int64> arg_dim_steps(arg_dimensions.size());
- std::vector<int64> arg_dim_counts(arg_dimensions.size());
- for (const int64 dim : dimensions) {
- arg_dim_steps[dim] = 1;
- arg_dim_counts[dim] = arg_dimensions[dim];
- }
-
- // Map each dimension in the result to a dimension in arg that isn't
- // being reduced.
- std::vector<int64> result_to_arg_index;
- for (int64 i = 0; i < arg_dimensions.size(); ++i) {
- if (arg_dim_steps[i] == 0) {
- result_to_arg_index.push_back(i);
- }
- }
-
- HloEvaluator embedded_evaluator(parent_->max_loop_iterations_);
-    // For each element of the result, compute and assign its value.
- TF_RETURN_IF_ERROR(
- result->Populate<ReturnT>([&](ArraySlice<int64> multi_index) {
- ReturnT result_val = init_scalar;
-
- std::vector<int64> base(arg_dimensions.size());
- for (int64 i = 0; i < multi_index.size(); ++i) {
- base[result_to_arg_index[i]] = multi_index[i];
- }
-
- // When the reduction is addition of floats, accumulate in a double
- // for better precision. Also, avoid creating Literals for the
- // intermediate results; it's much faster.
- if (ShapeUtil::ElementIsFloating(init_literal.shape()) &&
- IsScalarAdd(function)) {
- double computed_result = 0;
- auto func = [&](ArraySlice<int64> input_index) {
- computed_result += arg_literal.Get<float>(input_index);
- return true;
- };
- ShapeUtil::ForEachIndex(arg_literal.shape(), base, arg_dim_counts,
- arg_dim_steps, func);
- return static_cast<ReturnT>(computed_result);
- }
- auto func = [&](ArraySlice<int64> input_index) {
- auto curr_val = arg_literal.Get<ReturnT>(input_index);
-
- // Evaluate computation with specified literal operands.
- auto curr_val_literal = Literal::CreateR0<ReturnT>(curr_val);
- auto result_val_literal = Literal::CreateR0<ReturnT>(result_val);
- std::vector<const Literal*> args = {result_val_literal.get(),
- curr_val_literal.get()};
-
- std::unique_ptr<Literal> computed_result =
- embedded_evaluator.Evaluate<const Literal*>(*function, args)
- .ConsumeValueOrDie();
- // Clear visit states so that we can use the evaluator again on
- // the same computation.
- embedded_evaluator.ResetVisitStates();
- // Assign computed result to result_val.
- result_val = computed_result->Get<ReturnT>({});
- return true;
- };
- // Computes one element of the result, reducing all dimensions that
- // contribute to that element.
- ShapeUtil::ForEachIndex(arg_literal.shape(), base, arg_dim_counts,
- arg_dim_steps, func);
- return result_val;
- }));
-
- parent_->evaluated_[reduce] = std::move(result);
- return Status::OK();
- }
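The fast path in HandleReduce recognizes a reduction whose computation is a scalar add of its two parameters and, for floating-point inputs, accumulates in a double rather than re-evaluating the embedded computation per element. A small standalone illustration of why the wider accumulator is worth the special case (the exact drift is platform- and compiler-dependent):

#include <cstdio>

int main() {
  // Summing 10 million copies of 0.1f: a float accumulator drifts noticeably
  // once the running sum dwarfs each addend, while a double accumulator stays
  // close to the exact value of 1e6.
  float acc_f = 0.0f;
  double acc_d = 0.0;
  for (int i = 0; i < 10000000; ++i) {
    acc_f += 0.1f;
    acc_d += 0.1f;  // same f32 inputs, wider accumulator
  }
  std::printf("float accumulator:  %.1f\n", acc_f);
  std::printf("double accumulator: %.1f\n", acc_d);
}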
-
- bool IsScalarAdd(HloComputation* computation) {
- HloInstruction* instruction = computation->root_instruction();
- if (instruction->opcode() == HloOpcode::kAdd &&
- computation->num_parameters() == 2) {
- const HloInstruction* lhs = instruction->operand(0);
- const HloInstruction* rhs = instruction->operand(1);
- return lhs->opcode() == HloOpcode::kParameter &&
- ShapeUtil::IsScalar(lhs->shape()) &&
- rhs->opcode() == HloOpcode::kParameter &&
- ShapeUtil::IsScalar(rhs->shape()) && lhs != rhs;
- }
- return false;
- }
-
- Status HandleSelectAndScatter(HloInstruction* select_and_scatter) override {
- auto operand = select_and_scatter->operand(0);
- auto source = select_and_scatter->operand(1);
- const Window& window = select_and_scatter->window();
-
- const Literal& init_literal =
- parent_->GetEvaluatedLiteralFor(select_and_scatter->operand(2));
- TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape()));
- auto init_scalar = init_literal.Get<ReturnT>({});
-
- auto result = Literal::CreateFromShape(select_and_scatter->shape());
-
- // Initialize result array with the init value.
- TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
- [&](ArraySlice<int64> output_index) { return init_scalar; }));
-
- std::vector<int64> window_dimension_sizes;
- for (const auto& window_dimension : window.dimensions()) {
- window_dimension_sizes.push_back(window_dimension.size());
- }
- const Shape window_shape = ShapeUtil::MakeShape(
- operand->shape().element_type(), window_dimension_sizes);
-
- HloComputation* select = select_and_scatter->select();
- HloComputation* scatter = select_and_scatter->scatter();
-
- const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand);
- const Literal& source_literal = parent_->GetEvaluatedLiteralFor(source);
-
- int64 rank = ShapeUtil::Rank(operand_literal.shape());
-
- HloEvaluator embedded_evaluator(parent_->max_loop_iterations_);
- DimensionVector source_index(rank);
-
- std::fill(source_index.begin(), source_index.end(), 0);
- do {
- // For each element in `source`, we place a window in `operand`. For each
- // window placement, we iterate inside the window twice:
- //
-      // 1. Find the selected index by applying the `select` function to all
-      // elements. E.g., if the `select` function is GreaterEqual, the first
-      // iteration through the window finds the biggest value and returns its
-      // index.
-      //
-      // 2. Using the selected index, scatter the value from `source` to the
-      // result. We do this by iterating through the window and comparing each
-      // index with the selected index.
- optional<ReturnT> selected_val;
- optional<std::vector<int64>> selected_index;
-
- IterateThroughWindow(
- window_shape, window, operand_literal.shape(), source_index,
- [&](const std::vector<int64>& operand_index) {
- auto curr_val = operand_literal.Get<ReturnT>(operand_index);
- if (!selected_val) {
- selected_val = curr_val;
- selected_index = operand_index;
- }
- const auto curr_val_literal = Literal::CreateR0<ReturnT>(curr_val);
- const auto selected_val_literal =
- Literal::CreateR0<ReturnT>(*selected_val);
-
- const std::vector<const Literal*> args = {
- selected_val_literal.get(), curr_val_literal.get()};
- std::unique_ptr<Literal> computed_result =
- embedded_evaluator.Evaluate<const Literal*>(*select, args)
- .ConsumeValueOrDie();
- bool selected = !computed_result->Get<bool>({});
- if (selected) {
- selected_val = curr_val;
- selected_index = operand_index;
- }
- embedded_evaluator.ResetVisitStates();
- });
-
- IterateThroughWindow(
- window_shape, window, operand_literal.shape(), source_index,
- [&](const std::vector<int64>& operand_index) {
- if (std::equal(operand_index.begin(), operand_index.end(),
- selected_index->begin())) {
- auto source = source_literal.Get<ReturnT>(source_index);
- auto scattered = result->Get<ReturnT>(operand_index);
- const auto source_literal = Literal::CreateR0<ReturnT>(source);
- const auto scattered_literal =
- Literal::CreateR0<ReturnT>(scattered);
-
- const std::vector<const Literal*> args = {
- source_literal.get(), scattered_literal.get()};
- std::unique_ptr<Literal> computed_result =
- embedded_evaluator.Evaluate<const Literal*>(*scatter, args)
- .ConsumeValueOrDie();
- result->Set(operand_index, computed_result->Get<ReturnT>({}));
-                  // Clear visit states so that we can use the evaluator again
-                  // on the same computation.
- embedded_evaluator.ResetVisitStates();
- }
- });
- } while (IndexUtil::BumpIndices(source->shape(), &source_index));
-
- parent_->evaluated_[select_and_scatter] = std::move(result);
- return Status::OK();
- }
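The two passes described in the comment inside HandleSelectAndScatter are easier to see in one dimension. The sketch below uses a hypothetical window of size 2 with stride 2, a greater-or-equal select (pick the window's argmax, first on ties) and an add scatter; it is a simplified illustration, not the evaluator's general N-dimensional code:

#include <iostream>
#include <vector>

int main() {
  const std::vector<float> operand = {1, 4, 3, 2, 5, 0};
  const std::vector<float> source = {10, 20, 30};   // one value per window
  const int window = 2, stride = 2;
  std::vector<float> result(operand.size(), 0.0f);  // init value 0
  for (int w = 0; w < static_cast<int>(source.size()); ++w) {
    // Pass 1: select the argmax inside the window (select = GreaterEqual).
    int selected = w * stride;
    for (int k = 1; k < window; ++k)
      if (operand[w * stride + k] > operand[selected]) selected = w * stride + k;
    // Pass 2: scatter the source value onto the selected index (scatter = add).
    result[selected] += source[w];
  }
  for (float v : result) std::cout << v << " ";  // 0 10 20 0 30 0
  std::cout << "\n";
}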
-
- Status HandleReduceWindow(HloInstruction* reduce_window) override {
- auto operand = reduce_window->operand(0);
- const Window& window = reduce_window->window();
- HloComputation* function = reduce_window->to_apply();
- TF_ASSIGN_OR_RETURN(
- auto inferred_return_shape,
- ShapeInference::InferReduceWindowShape(
- /*operand_shape=*/reduce_window->operand(0)->shape(),
- /*init_value=*/reduce_window->operand(1)->shape(), window,
- /*to_apply_shape=*/function->ComputeProgramShape()));
- TF_RET_CHECK(
- ShapeUtil::Compatible(reduce_window->shape(), inferred_return_shape))
- << "return shape is set to: "
- << ShapeUtil::HumanStringWithLayout(reduce_window->shape())
-        << " but is inferred to be: "
- << ShapeUtil::HumanStringWithLayout(inferred_return_shape);
-
- const Literal& operand_literal =
- parent_->GetEvaluatedLiteralFor(reduce_window->operand(0));
- VLOG(3) << "HandleReduceWindow arg_literal: " << operand_literal.ToString();
- const Literal& init_literal =
- parent_->GetEvaluatedLiteralFor(reduce_window->operand(1));
- VLOG(3) << "HandleReduceWindow init_literal: " << init_literal.ToString();
- TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape()));
- auto init_scalar = init_literal.Get<ReturnT>({});
-
- auto result = Literal::CreateFromShape(reduce_window->shape());
-
- // Creates a Shape object from window, for iteration below.
- std::vector<int64> window_dimension_sizes;
- for (const auto& window_dimension : window.dimensions()) {
- window_dimension_sizes.push_back(window_dimension.size());
- }
- const Shape window_shape = ShapeUtil::MakeShape(
- operand->shape().element_type(), window_dimension_sizes);
-
- DimensionVector window_index(window.dimensions_size());
- DimensionVector operand_index(ShapeUtil::Rank(operand_literal.shape()));
-
- HloEvaluator embedded_evaluator(parent_->max_loop_iterations_);
-    // For each element of the result, compute and assign its value.
- TF_RETURN_IF_ERROR(
- result->Populate<ReturnT>([&](ArraySlice<int64> output_index) {
- ReturnT result_val = init_scalar;
-
- std::fill(window_index.begin(), window_index.end(), 0);
- std::fill(operand_index.begin(), operand_index.end(), 0);
-
- IterateThroughWindow(
- window_shape, window, operand_literal.shape(), output_index,
- [&](const std::vector<int64>& operand_index) {
- auto curr_val = operand_literal.Get<ReturnT>(operand_index);
-
- // Evaluate computation with specified literal operands.
- const auto curr_val_literal =
- Literal::CreateR0<ReturnT>(curr_val);
- const auto result_val_literal =
- Literal::CreateR0<ReturnT>(result_val);
- const std::vector<const Literal*> args = {
- result_val_literal.get(), curr_val_literal.get()};
- std::unique_ptr<Literal> computed_result =
- embedded_evaluator.Evaluate<const Literal*>(*function, args)
- .ConsumeValueOrDie();
-
-                // Clear visit states so that we can use the evaluator again
-                // on the same computation.
- embedded_evaluator.ResetVisitStates();
-
- result_val = computed_result->Get<ReturnT>({});
- });
-
- return result_val;
- }));
-
- parent_->evaluated_[reduce_window] = std::move(result);
- return Status::OK();
- }
-
- Status HandleSlice(HloInstruction* slice) override {
- auto operand = slice->operand(0);
- const Shape& shape = slice->shape();
- TF_ASSIGN_OR_RETURN(auto inferred_return_shape,
- ShapeInference::InferSliceShape(
- operand->shape(), slice->slice_starts(),
- slice->slice_limits(), slice->slice_strides()));
- TF_RET_CHECK(ShapeUtil::Compatible(shape, inferred_return_shape))
- << "return shape set to: " << ShapeUtil::HumanString(shape)
- << " but is inferred to be: "
- << ShapeUtil::HumanString(inferred_return_shape);
-
- const int64 rank = ShapeUtil::Rank(operand->shape());
- const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand);
- auto func = [&](ArraySlice<int64> out_index) {
- DimensionVector operand_index(rank);
- for (int64 i = 0; i < rank; ++i) {
- operand_index[i] =
- slice->slice_starts(i) + out_index[i] * slice->slice_strides(i);
- }
- return operand_literal.Get<ReturnT>(operand_index);
- };
-
- auto result = Literal::CreateFromDimensions(
- shape.element_type(), AsInt64Slice(shape.dimensions()));
- TF_RETURN_IF_ERROR(result->Populate<ReturnT>(func));
- parent_->evaluated_[slice] = std::move(result);
- return Status::OK();
- }
-
- // Enable CLZ only for int32 and uint32.
- template <
- typename NativeT,
- typename std::enable_if<
- (std::is_floating_point<NativeT>::value ||
- std::is_integral<NativeT>::value || is_complex_t<NativeT>::value) &&
- !(std::is_same<NativeT, uint32>::value ||
- std::is_same<NativeT, int32>::value)>::type* = nullptr>
- Status HandleClz(HloInstruction* clz) {
- return InvalidArgument("Unsupported type for Clz");
- }
-
- template <typename NativeT,
- typename std::enable_if<
- std::is_same<NativeT, uint32>::value ||
- std::is_same<NativeT, int32>::value>::type* = nullptr>
- Status HandleClz(HloInstruction* clz) {
- TF_ASSIGN_OR_RETURN(parent_->evaluated_[clz],
- ElementWiseUnaryOp(clz, [](ElementwiseT elem_operand) {
- return 31 - tensorflow::Log2Floor(elem_operand);
- }));
- return Status::OK();
- }
-
- Status HandleClz(HloInstruction* clz) override {
- return HandleClz<ElementwiseT>(clz);
- }
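For the supported 32-bit types the Clz handler computes count-leading-zeros as 31 - Log2Floor(x). A standalone check of that identity, with a plain loop standing in for tensorflow::Log2Floor (assumed here to return the floor of log2, and -1 for zero):

#include <cstdint>
#include <iostream>

// Floor of log2 for a 32-bit value; returns -1 for 0.
int Log2Floor(uint32_t x) {
  int r = -1;
  while (x) { x >>= 1; ++r; }
  return r;
}

int Clz32(uint32_t x) { return 31 - Log2Floor(x); }

int main() {
  std::cout << Clz32(1) << "\n";            // 31
  std::cout << Clz32(0x80000000u) << "\n";  // 0
  std::cout << Clz32(255) << "\n";          // 24
}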
-
- template <typename NativeT, typename std::enable_if<std::is_floating_point<
- NativeT>::value>::type* = nullptr>
- Status HandleSin(HloInstruction* sin) {
- TF_ASSIGN_OR_RETURN(parent_->evaluated_[sin],
- ElementWiseUnaryOp(sin, [](ElementwiseT elem_operand) {
- return std::sin(elem_operand);
- }));
- return Status::OK();
- }
-
- template <
- typename NativeT,
- typename std::enable_if<std::is_integral<NativeT>::value ||
- is_complex_t<NativeT>::value>::type* = nullptr>
- Status HandleSin(HloInstruction* sin) {
- return InvalidArgument("Unsupported type for Sin");
- }
-
- Status HandleSin(HloInstruction* sin) override {
- return HandleSin<ElementwiseT>(sin);
- }
-
- template <typename NativeT, typename std::enable_if<std::is_floating_point<
- NativeT>::value>::type* = nullptr>
- Status HandleCos(HloInstruction* cos) {
- TF_ASSIGN_OR_RETURN(parent_->evaluated_[cos],
- ElementWiseUnaryOp(cos, [](ElementwiseT elem_operand) {
- return std::cos(elem_operand);
- }));
- return Status::OK();
- }
-
- template <
- typename NativeT,
- typename std::enable_if<std::is_integral<NativeT>::value ||
- is_complex_t<NativeT>::value>::type* = nullptr>
- Status HandleCos(HloInstruction* cos) {
- return InvalidArgument("Unsupported type for Cos");
- }
-
- Status HandleCos(HloInstruction* cos) override {
- return HandleCos<ElementwiseT>(cos);
- }
-
- template <typename NativeT, typename std::enable_if<std::is_same<
- float, NativeT>::value>::type* = nullptr>
- Status HandleReducePrecision(HloInstruction* reduce_precision) {
- TF_ASSIGN_OR_RETURN(
- parent_->evaluated_[reduce_precision],
- ElementWiseUnaryOp(reduce_precision, [reduce_precision](
- ElementwiseT elem) {
- uint32_t value_as_int = tensorflow::bit_cast<uint32_t>(elem);
- const uint32_t mantissa_bits = reduce_precision->mantissa_bits();
- const uint32_t exponent_bits = reduce_precision->exponent_bits();
-
- // Code is based on the CPU/GPU implementation in LLVM-emitting code.
- //
- // Bits in float type:
- // mantissa : bits [0:22]
- // exponent : bits [23:30]
- // sign : bits [31]
- if (mantissa_bits < 23) {
- const uint32_t last_mantissa_bit_mask = 1u << (23 - mantissa_bits);
-
- // Compute rounding bias for round-to-nearest with ties to even.
- // This is equal to a base value of 0111... plus one bit if the last
- // remaining mantissa bit is 1.
- const uint32_t base_rounding_bias =
- (last_mantissa_bit_mask >> 1) - 1;
- const uint32_t x_last_mantissa_bit =
- (value_as_int & last_mantissa_bit_mask) >> (23 - mantissa_bits);
- const uint32_t x_rounding_bias =
- x_last_mantissa_bit + base_rounding_bias;
-
- // Add rounding bias, and mask out truncated bits. Note that the
- // case where adding the rounding bias overflows into the exponent
- // bits is correct; the non-masked mantissa bits will all be zero,
- // and the exponent will be incremented by one.
- const uint32_t truncation_mask = ~(last_mantissa_bit_mask - 1);
- value_as_int = value_as_int + x_rounding_bias;
- value_as_int = value_as_int & truncation_mask;
- }
- if (exponent_bits < 8) {
- // Masks for f32 values.
- const uint32_t f32_sign_bit_mask = 1u << 31;
- const uint32_t f32_exp_bits_mask = 0xffu << 23;
-
-            // An exponent of 2^(n-1)-1 -- that is, 0111... with the zero in
-            // the most-significant bit -- is equal to 1.0f for all exponent
-            // sizes. Adding 2^(n-1)-1 to this gives us the highest
-            // non-infinite exponent for a bit size of n, and subtracting
-            // 2^(n-1)-1 from this gives us the lowest exponent (corresponding
-            // to 0.0f).
- //
- // Thus, the f32 exponent corresponding to the highest non-infinite
- // exponent for a bit size of n is (2^7-1) + 2^(n-1)-1, and the f32
- // exponent corresponding to the lowest exponent for a bit size of n
- // is (2^7-1) - 2^(n-1)-1.
- //
-            // Note that we have already checked that exponent_bits >= 1.
- const uint32_t f32_exponent_bias = (1 << 7) - 1;
- const uint32_t reduced_exponent_bias =
- (1 << (exponent_bits - 1)) - 1;
- const uint32_t reduced_max_exponent =
- f32_exponent_bias + reduced_exponent_bias;
- const uint32_t reduced_min_exponent =
- f32_exponent_bias - reduced_exponent_bias;
-
- // Do we overflow or underflow?
- const uint32_t x_exponent = value_as_int & f32_exp_bits_mask;
- const bool x_overflows = x_exponent > (reduced_max_exponent << 23);
- const bool x_underflows =
- x_exponent <= (reduced_min_exponent << 23);
-
- // Compute appropriately-signed values of zero and infinity.
- const uint32_t x_signed_zero = value_as_int & f32_sign_bit_mask;
- const uint32_t x_signed_inf = x_signed_zero | f32_exp_bits_mask;
-
- // Force to zero or infinity if overflow or underflow. (Note that
- // this truncates all denormal values to zero, rather than rounding
- // them.)
- value_as_int = x_overflows ? x_signed_inf : value_as_int;
- value_as_int = x_underflows ? x_signed_zero : value_as_int;
- }
-
- float reduced_result = tensorflow::bit_cast<float>(value_as_int);
- if (std::isnan(elem)) {
- reduced_result = mantissa_bits > 0
- ? elem
- : std::numeric_limits<float>::infinity();
- }
- return reduced_result;
- }));
- return Status::OK();
- }
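The mantissa branch of HandleReducePrecision is a compact piece of bit manipulation; the standalone sketch below isolates it for a single f32 value, assuming 0 < mantissa_bits < 23 and omitting the exponent clamping and NaN handling shown above:

#include <cstdint>
#include <cstdio>
#include <cstring>

// Round an f32 mantissa down to `mantissa_bits` bits, round-to-nearest-even.
// Mirrors the mantissa branch of HandleReducePrecision only.
float ReduceMantissa(float x, int mantissa_bits) {
  uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  const uint32_t last_bit_mask = 1u << (23 - mantissa_bits);
  const uint32_t base_bias = (last_bit_mask >> 1) - 1;
  const uint32_t last_bit = (bits & last_bit_mask) >> (23 - mantissa_bits);
  bits += last_bit + base_bias;   // rounding bias (ties to even)
  bits &= ~(last_bit_mask - 1);   // truncate the dropped bits
  float out;
  std::memcpy(&out, &bits, sizeof(out));
  return out;
}

int main() {
  // 1 + 2^-10 is representable with 10 mantissa bits, so it survives intact.
  std::printf("%.10f\n", ReduceMantissa(1.0009765625f, 10));
  // 0.1f needs the full mantissa, so its low-order bits are rounded away.
  std::printf("%.10f\n", ReduceMantissa(0.1f, 10));
}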
-
- template <typename NativeT, typename std::enable_if<std::is_same<
- double, NativeT>::value>::type* = nullptr>
- Status HandleReducePrecision(HloInstruction* reduce_precision) {
- return InvalidArgument("Double not supported for reduce precision");
- }
-
- template <
- typename NativeT,
- typename std::enable_if<std::is_integral<NativeT>::value ||
- is_complex_t<NativeT>::value>::type* = nullptr>
- Status HandleReducePrecision(HloInstruction* reduce_precision) {
- return InvalidArgument("Unsupported type for reduce precision");
- }
-
- Status HandleReducePrecision(HloInstruction* reduce_precision) override {
- return HandleReducePrecision<ElementwiseT>(reduce_precision);
- }
-
- private:
- template <typename IndexT>
- StatusOr<std::unique_ptr<Literal>> DynamicSlice(
- const Literal& operand_literal, const Literal& start_indices_literal,
- const Shape& result_shape) {
- auto start_indices_typed = start_indices_literal.data<IndexT>();
- std::vector<int64> start(start_indices_typed.begin(),
- start_indices_typed.end());
-
- std::vector<int64> operand_indices(start.size());
-
- auto result = Literal::CreateFromShape(result_shape);
- TF_RETURN_IF_ERROR(
- result->Populate<ReturnT>([&](ArraySlice<int64> multi_index) {
- for (int64 i = 0; i < operand_indices.size(); ++i) {
- CHECK_GE(multi_index[i] + start[i], 0);
- // Mod is only used here to be consistent with the existing
- // backends' behavior.
- operand_indices[i] = (multi_index[i] + start[i]) %
- operand_literal.shape().dimensions(i);
- }
-
- auto result = operand_literal.Get<ReturnT>(operand_indices);
- return result;
- }));
-
- return std::move(result);
- }
-
- template <typename IndexT>
- StatusOr<std::unique_ptr<Literal>> DynamicUpdateSlice(
- const Literal& operand_literal, const Literal& update_literal,
- const Literal& start_indices_literal) {
- auto result = operand_literal.CloneToUnique();
- auto start_indices_typed = start_indices_literal.data<IndexT>();
- const auto rank = ShapeUtil::Rank(result->shape());
- std::vector<int64> start(rank, 0);
- for (int64 i = 0; i < rank; ++i) {
-      // All other implementations currently wrap around the index, so this
-      // should do so as well.
- start[i] = (start_indices_typed[i] % result->shape().dimensions(i));
- start[i] += (start[i] < 0) * result->shape().dimensions(i);
- }
- std::vector<int64> result_index(rank, 0);
-
- auto func = [&](ArraySlice<int64> update_index) {
- std::transform(update_index.begin(), update_index.end(), start.begin(),
- result_index.begin(), std::plus<int64>());
- // Same as above, wrap-around only to match other implementations'
- // semantics.
- std::transform(result_index.begin(), result_index.end(),
- result->shape().dimensions().begin(), result_index.begin(),
- std::modulus<int64>());
- result->Set<ReturnT>(result_index,
- update_literal.Get<ReturnT>(update_index));
- return true;
- };
-
- std::vector<int64> base(update_literal.shape().dimensions_size(), 0);
- std::vector<int64> step(update_literal.shape().dimensions_size(), 1);
- ShapeUtil::ForEachIndex(update_literal.shape(), base,
- AsInt64Slice(update_literal.shape().dimensions()),
- step, func);
-
- return std::move(result);
- }
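Both DynamicSlice and DynamicUpdateSlice above wrap indices modulo the operand dimensions to stay consistent with the other backends at the time. A quick numeric illustration of that wrap-around for a one-dimensional update:

#include <iostream>
#include <vector>

int main() {
  // Update {9, 9} into an operand of length 4 starting at index 3: with the
  // wrap-around semantics in the diff, indices 3 and (3 + 1) % 4 == 0 are
  // overwritten.
  std::vector<int> operand = {0, 1, 2, 3};
  const std::vector<int> update = {9, 9};
  const int start = 3 % static_cast<int>(operand.size());
  for (int i = 0; i < static_cast<int>(update.size()); ++i) {
    const int idx = (start + i) % static_cast<int>(operand.size());
    operand[idx] = update[i];
  }
  for (int v : operand) std::cout << v << " ";  // 9 1 2 9
  std::cout << "\n";
}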
-
- StatusOr<std::unique_ptr<Literal>> ElementWiseUnaryOp(
- HloInstruction* instruction,
- const std::function<ElementwiseT(ElementwiseT)>& unary_op) {
- const Literal& operand_literal =
- parent_->GetEvaluatedLiteralFor(instruction->operand(0));
- TF_ASSIGN_OR_RETURN(
- auto result_literal,
- (ElementWiseUnaryOpImpl<ReturnT, ReturnT>(
- instruction, ConvertUnaryFunction(unary_op), operand_literal)));
-
- return std::move(result_literal);
- }
-
- StatusOr<std::unique_ptr<Literal>> ElementWiseBinaryOp(
- HloInstruction* instruction,
- const std::function<ElementwiseT(ElementwiseT, ElementwiseT)>&
- binary_op) {
- const auto shape = instruction->shape();
- const auto* lhs = instruction->operand(0);
- const auto* rhs = instruction->operand(1);
-
- // TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast
- // is removed.
- if (!(ShapeUtil::SameDimensions(shape, rhs->shape()) &&
- ShapeUtil::SameDimensions(lhs->shape(), rhs->shape()))) {
- return Unimplemented(
- "Implicit broadcasting is currently unsupported in HLO evaluator "
- "Shape Mismatch: %s vs %s vs %s: ",
- ShapeUtil::HumanString(shape).c_str(),
- ShapeUtil::HumanString(lhs->shape()).c_str(),
- ShapeUtil::HumanString(rhs->shape()).c_str());
- }
-
- const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs);
- const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs);
-
- auto result = Literal::CreateFromShape(shape);
-
- TF_RETURN_IF_ERROR(
- result->Populate<ReturnT>([&](ArraySlice<int64> multi_index) {
- return ConvertBinaryFunction(binary_op)(
- lhs_literal.Get<ReturnT>(multi_index),
- rhs_literal.Get<ReturnT>(multi_index));
- }));
- return std::move(result);
- }
-
- template <typename LhsType, typename RhsType, typename EhsType>
- StatusOr<std::unique_ptr<Literal>> ElementwiseTernaryOp(
- HloInstruction* instruction,
- const std::function<ReturnT(LhsType, RhsType, EhsType)>& ternary_op) {
- const auto shape = instruction->shape();
- const auto* lhs = instruction->operand(0);
- const auto* rhs = instruction->operand(1);
- const auto* ehs = instruction->operand(2);
-
- // TODO(b/35950897, b/27796129): add DCHECK back once implicit
- // broadcast is removed.
- if (!(ShapeUtil::SameDimensions(shape, lhs->shape()) &&
- ShapeUtil::SameDimensions(lhs->shape(), rhs->shape()) &&
- ShapeUtil::SameDimensions(rhs->shape(), ehs->shape()))) {
- return Unimplemented(
- "Implicit broadcasting is currently unsupported in HLO evaluator "
- "Shape Mismatch: %s vs %s vs %s vs %s: ",
- ShapeUtil::HumanString(shape).c_str(),
- ShapeUtil::HumanString(lhs->shape()).c_str(),
- ShapeUtil::HumanString(rhs->shape()).c_str(),
- ShapeUtil::HumanString(ehs->shape()).c_str());
- }
-
- const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs);
- const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs);
- const Literal& ehs_literal = parent_->GetEvaluatedLiteralFor(ehs);
-
- auto result = Literal::CreateFromShape(shape);
-
- TF_RETURN_IF_ERROR(
- result->Populate<ReturnT>([&](ArraySlice<int64> multi_index) {
- return ternary_op(lhs_literal.Get<LhsType>(multi_index),
- rhs_literal.Get<RhsType>(multi_index),
- ehs_literal.Get<EhsType>(multi_index));
- }));
-
- return std::move(result);
- }
-
- template <typename NativeT>
- static bool IsShiftOutOfBounds(NativeT rhs) {
- typedef typename std::make_unsigned<NativeT>::type UnsignedT;
- UnsignedT lhs_size_unsigned = sizeof(NativeT) * CHAR_BIT;
- UnsignedT rhs_unsigned = static_cast<UnsignedT>(rhs);
- return rhs_unsigned >= lhs_size_unsigned;
- }
-
- HloEvaluator* parent_;
-}; // class HloEvaluator::TypedVisitor
HloEvaluator::HloEvaluator(int64 max_loop_iterations)
: max_loop_iterations_(max_loop_iterations) {
- typed_visitors_[PRED] = MakeUnique<TypedVisitor<bool>>(this);
- typed_visitors_[U8] = MakeUnique<TypedVisitor<uint8>>(this);
+ typed_visitors_[PRED] = MakeUnique<HloEvaluatorTypedVisitor<bool>>(this);
+ typed_visitors_[U8] = MakeUnique<HloEvaluatorTypedVisitor<uint8>>(this);
typed_visitors_[U16] = MakeUnique<FunctionVisitor>([](HloInstruction*) {
return Unimplemented(
- "HloEvaluator::TypedVisitor: unhandled primitive type: U16.");
+ "HloEvaluator::HloEvaluatorTypedVisitor: unhandled primitive type: "
+ "U16.");
});
- typed_visitors_[U32] = MakeUnique<TypedVisitor<uint32>>(this);
- typed_visitors_[U64] = MakeUnique<TypedVisitor<uint64>>(this);
- typed_visitors_[S8] = MakeUnique<TypedVisitor<int8>>(this);
+ typed_visitors_[U32] = MakeUnique<HloEvaluatorTypedVisitor<uint32>>(this);
+ typed_visitors_[U64] = MakeUnique<HloEvaluatorTypedVisitor<uint64>>(this);
+ typed_visitors_[S8] = MakeUnique<HloEvaluatorTypedVisitor<int8>>(this);
typed_visitors_[S16] = MakeUnique<FunctionVisitor>([](HloInstruction*) {
return Unimplemented(
- "HloEvaluator::TypedVisitor: unhandled primitive type: S16.");
+ "HloEvaluator::HloEvaluatorTypedVisitor: unhandled primitive type: "
+ "S16.");
});
- typed_visitors_[S32] = MakeUnique<TypedVisitor<int32>>(this);
- typed_visitors_[S64] = MakeUnique<TypedVisitor<int64>>(this);
- typed_visitors_[F16] = MakeUnique<TypedVisitor<Eigen::half, float>>(this);
- typed_visitors_[F32] = MakeUnique<TypedVisitor<float>>(this);
- typed_visitors_[F64] = MakeUnique<TypedVisitor<double>>(this);
- typed_visitors_[C64] = MakeUnique<TypedVisitor<complex64>>(this);
+ typed_visitors_[S32] = MakeUnique<HloEvaluatorTypedVisitor<int32>>(this);
+ typed_visitors_[S64] = MakeUnique<HloEvaluatorTypedVisitor<int64>>(this);
+ typed_visitors_[F16] =
+ MakeUnique<HloEvaluatorTypedVisitor<Eigen::half, float>>(this);
+ typed_visitors_[F32] = MakeUnique<HloEvaluatorTypedVisitor<float>>(this);
+ typed_visitors_[F64] = MakeUnique<HloEvaluatorTypedVisitor<double>>(this);
+ typed_visitors_[C64] = MakeUnique<HloEvaluatorTypedVisitor<complex64>>(this);
// Most of the evaluator computations we use don't support BF16 (e.g.,
// std::ceil, std::tanh). To make evaluator work with BF16, we set all
// elementwise computations to be done in F32 and do BF16<->F32 conversion
// around the input and the output of the computations.
- typed_visitors_[BF16] = MakeUnique<TypedVisitor<bfloat16, float>>(this);
+ typed_visitors_[BF16] =
+ MakeUnique<HloEvaluatorTypedVisitor<bfloat16, float>>(this);
typed_visitors_[TUPLE] = MakeUnique<FunctionVisitor>([](HloInstruction*) {
return Unimplemented(
- "HloEvaluator::TypedVistor: unhandled primitive type: TUPLE.");
+ "HloEvaluatorTypedVisitor: unhandled primitive type: TUPLE.");
});
typed_visitors_[OPAQUE] = MakeUnique<FunctionVisitor>([](HloInstruction*) {
return Unimplemented(
- "HloEvaluator::TypedVisitor: unhandled primitive type: OPAQUE.");
+ "HloEvaluatorTypedVisitor: unhandled primitive type: OPAQUE.");
});
}
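
For context, the constructor above fills typed_visitors_ with one visitor per primitive type and maps unsupported types to a FunctionVisitor that reports Unimplemented. A minimal standalone sketch of that dispatch-table pattern (hypothetical, simplified class and enum names, not the real XLA types):

    #include <functional>
    #include <iostream>
    #include <map>
    #include <memory>
    #include <string>

    // Hypothetical, simplified stand-ins for the real visitor classes.
    struct VisitorBase {
      virtual ~VisitorBase() = default;
      virtual std::string Visit() = 0;
    };

    template <typename NativeT>
    struct TypedVisitorSketch : VisitorBase {
      std::string Visit() override { return "evaluating with a typed visitor"; }
    };

    // Stand-in for FunctionVisitor: wraps a callback for unsupported types.
    struct FunctionVisitorSketch : VisitorBase {
      explicit FunctionVisitorSketch(std::function<std::string()> fn) : fn_(fn) {}
      std::string Visit() override { return fn_(); }
      std::function<std::string()> fn_;
    };

    enum PrimitiveTypeSketch { kPred, kF32, kU16 };

    int main() {
      // Mirrors typed_visitors_: one entry per primitive type, with types the
      // evaluator does not support mapped to an "unimplemented" callback.
      std::map<PrimitiveTypeSketch, std::unique_ptr<VisitorBase>> visitors;
      visitors[kPred] = std::make_unique<TypedVisitorSketch<bool>>();
      visitors[kF32] = std::make_unique<TypedVisitorSketch<float>>();
      visitors[kU16] = std::make_unique<FunctionVisitorSketch>(
          [] { return std::string("unhandled primitive type: U16"); });

      std::cout << visitors[kF32]->Visit() << "\n";
      std::cout << visitors[kU16]->Visit() << "\n";
      return 0;
    }
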
@@ -3034,7 +977,7 @@ Status HloEvaluator::HandleSelect(HloInstruction* select) {
// If predicate is of scalar type, no element-wise selection would be needed.
// This would also handle output array of tuple types as the DefaultAction
- // would go through the TypedVisitor which doesn't handle tuples.
+ // would go through the HloEvaluatorTypedVisitor which doesn't handle tuples.
if (ShapeUtil::IsScalar(pred.shape())) {
if (pred.Get<bool>({})) {
evaluated_[select] = on_true.CloneToUnique();
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h
index c0dcee0c3e..cc5676ea7b 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.h
@@ -109,19 +109,16 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
substitutions);
protected:
- // Templated DfsHloVisitor. Typically ReturnT here indicates the resulting
- // literal type of each evaluated Handle* method of a TypedVisitor.
- // There are however a few notable exceptions to this rule, notably:
- // - HandleCompare and HandleIsFinite: where the resulting literal type is
- // always boolean.
- // These operations are handled outside of the parent HloEvaluator handlers
- // instead of from within TypedVisitor.
+ // Make HloEvaluatorTypedVisitor a friend because it is logically part of this
+ // class.
//
- // Type params:
- // - ReturnT: The type of input and output of each operation.
- // - ElementwiseT: The type in which internal computation are done.
- template <typename ReturnT, typename ElementwiseT = ReturnT>
- class TypedVisitor;
+ // A straightforward implementation would be to make it a nested class
+ // declared and defined in hlo_evaluator.cc. Instead, HloEvaluatorTypedVisitor
+ // lives as a separate class with its own header because its template gets
+ // instantiated many times and we want to use extern templates to shard out
+ // the compilation of those instantiations across multiple cc files.
+ template <typename ReturnT, typename ElementwiseT>
+ friend class HloEvaluatorTypedVisitor;
// Wraps around instruction handling to infer types before dispatching to
// the corresponding typed Visitor.
@@ -169,6 +166,33 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
Status HandleSelect(HloInstruction* select) override;
private:
+ template <typename ReturnT, typename NativeT>
+ static StatusOr<std::unique_ptr<Literal>> ElementWiseUnaryOpImpl(
+ HloInstruction* instruction,
+ const std::function<ReturnT(NativeT)>& unary_op,
+ const Literal& operand_literal) {
+ const auto shape = instruction->shape();
+ const auto* operand = instruction->operand(0);
+
+ // TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast is
+ // removed.
+ if (!ShapeUtil::SameDimensions(shape, operand->shape())) {
+ return Unimplemented(
+ "Implicit broadcasting is currently unsupported in HLO evaluator "
+ "Shape Mismatch: %s vs %s",
+ ShapeUtil::HumanString(shape).c_str(),
+ ShapeUtil::HumanString(operand->shape()).c_str());
+ }
+
+ auto result = Literal::CreateFromShape(shape);
+
+ TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
+ [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
+ return unary_op(operand_literal.Get<NativeT>(multi_index));
+ }));
+ return std::move(result);
+ }
+
// Returns the already-evaluated literal result for the instruction.
// A Constant instruction is considered evaluated and its literal will be
// returned directly without looking up the cache.
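
The comment added above says HloEvaluatorTypedVisitor lives in its own header so that extern templates can shard its instantiations across several .cc files. A minimal, self-contained sketch of that idiom (hypothetical Widget class and file names, not the actual TensorFlow build layout):

    // widget.h (hypothetical): a heavy template whose instantiations we shard.
    template <typename T>
    struct Widget {
      T Twice(T x) const { return x + x; }
    };

    // Suppress implicit instantiation in every translation unit that includes
    // this header; exactly one shard .cc provides each definition instead.
    extern template struct Widget<float>;
    extern template struct Widget<double>;

    // widget_float.cc (hypothetical shard):
    //   #include "widget.h"
    //   template struct Widget<float>;   // pays the compile cost for <float> only
    //
    // widget_double.cc (hypothetical shard):
    //   #include "widget.h"
    //   template struct Widget<double>;  // pays the compile cost for <double> only
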
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
new file mode 100644
index 0000000000..f1cb363478
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
@@ -0,0 +1,2102 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_
+
+#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
+#include "tensorflow/compiler/xla/service/shape_inference.h"
+#include "tensorflow/core/lib/core/casts.h"
+#include "tensorflow/core/lib/gtl/optional.h"
+
+namespace xla {
+
+// TODO(b/79274244): We'd like these type traits to live inside of
+// HloEvaluatorTypedVisitor so they don't pollute namespace xla, but that
+// crashes clang in the frontend.
+//
+// Anyway this is relatively safe as-is because hlo_evaluator_typed_visitor.h is
+// a "private" header that's not exposed outside of hlo_evaluator.cc.
+template <typename T>
+using is_complex_t = std::is_same<T, complex64>;
+template <typename T>
+using is_complex64_t = std::is_same<T, complex64>;
+
+// Templated DfsHloVisitor for use by HloEvaluator.
+//
+// Typically ReturnT here indicates the resulting literal type of each evaluated
+// Handle* method of a TypedVisitor. There are however a few notable exceptions
+// to this rule, notably:
+// - HandleCompare and HandleIsFinite: where the resulting literal type is
+// always boolean.
+// These operations are handled outside of the parent HloEvaluator handlers
+// instead of from within TypedVisitor.
+//
+// Type params:
+// - ReturnT: The type of input and output of each operation.
+// - ElementwiseT: The type in which internal computations are done.
+//
+// This is logically a private part of HloEvaluator. It lives in this header
+// file rather than in hlo_evaluator.cc because we use extern templates and a
+// bunch of independent cc files to speed up compiling the many instantiations
+// of this class.
+template <typename ReturnT, typename ElementwiseT = ReturnT>
+class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
+ public:
+ explicit HloEvaluatorTypedVisitor(HloEvaluator* p) : parent_(p) {}
+
+ // The following higher-order functions convert a function with ElementwiseT
+ // to a function with ReturnT.
+ std::function<ReturnT(ReturnT)> ConvertUnaryFunction(
+ const std::function<ElementwiseT(ElementwiseT)>& unary_op) {
+ return [&unary_op](ReturnT arg) {
+ return static_cast<ReturnT>(unary_op(static_cast<ElementwiseT>(arg)));
+ };
+ }
+ std::function<ReturnT(ReturnT, ReturnT)> ConvertBinaryFunction(
+ const std::function<ElementwiseT(ElementwiseT, ElementwiseT)>&
+ binary_op) {
+ return [&binary_op](ReturnT arg1, ReturnT arg2) {
+ return static_cast<ReturnT>(binary_op(static_cast<ElementwiseT>(arg1),
+ static_cast<ElementwiseT>(arg2)));
+ };
+ }
+ std::function<ReturnT(ReturnT, ReturnT, ReturnT)> ConvertTernaryFunction(
+ const std::function<ElementwiseT(ElementwiseT, ElementwiseT,
+ ElementwiseT)>& ternary_op) {
+ return [&ternary_op](ReturnT arg1, ReturnT arg2, ReturnT arg3) {
+ return static_cast<ReturnT>(ternary_op(static_cast<ElementwiseT>(arg1),
+ static_cast<ElementwiseT>(arg2),
+ static_cast<ElementwiseT>(arg3)));
+ };
+ }
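
These Convert*Function helpers are what allow ReturnT (the storage type, e.g. bfloat16 or Eigen::half) to differ from ElementwiseT (the compute type, e.g. float): the wrapped lambda runs in ElementwiseT and the result is cast back to ReturnT. A minimal sketch of the same wrapping, using float as the storage type and double as the compute type (stand-ins, not the XLA types):

    #include <functional>
    #include <iostream>

    // Stand-ins: ReturnT = float (storage type), ElementwiseT = double (compute type).
    using ReturnT = float;
    using ElementwiseT = double;

    std::function<ReturnT(ReturnT, ReturnT)> ConvertBinary(
        std::function<ElementwiseT(ElementwiseT, ElementwiseT)> binary_op) {
      return [binary_op](ReturnT a, ReturnT b) {
        // Compute in the wider ElementwiseT, then narrow back to ReturnT.
        return static_cast<ReturnT>(binary_op(static_cast<ElementwiseT>(a),
                                              static_cast<ElementwiseT>(b)));
      };
    }

    int main() {
      auto add = ConvertBinary([](ElementwiseT x, ElementwiseT y) { return x + y; });
      std::cout << add(1.5f, 2.25f) << "\n";  // prints 3.75
      return 0;
    }
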
+
+ Status DefaultAction(HloInstruction* hlo_instruction) override {
+ return Unimplemented("unhandled HLO ops for HloEvaluator: %s.",
+ HloOpcodeString(hlo_instruction->opcode()).c_str());
+ }
+
+ // TODO(b/35950897): many of the stl functions used in the handlers are not
+ // overloaded for every XLA primitive type.
+
+ template <typename NativeT,
+ typename std::enable_if<std::is_unsigned<NativeT>::value>::type* =
+ nullptr>
+ Status HandleAbs(HloInstruction* abs) {
+ TF_ASSIGN_OR_RETURN(parent_->evaluated_[abs],
+ ElementWiseUnaryOp(abs, [](NativeT elem_operand) {
+ return elem_operand;
+ }));
+ return Status::OK();
+ }
+
+ template <
+ typename NativeT,
+ typename std::enable_if<std::is_signed<NativeT>::value>::type* = nullptr>
+ Status HandleAbs(HloInstruction* abs) {
+ TF_ASSIGN_OR_RETURN(parent_->evaluated_[abs],
+ ElementWiseUnaryOp(abs, [](NativeT elem_operand) {
+ return std::abs(elem_operand);
+ }));
+ return Status::OK();
+ }
+
+ template <
+ typename NativeT,
+ typename std::enable_if<is_complex64_t<NativeT>::value>::type* = nullptr>
+ Status HandleAbs(HloInstruction* abs) {
+ const Literal& operand_literal =
+ parent_->GetEvaluatedLiteralFor(abs->operand(0));
+ TF_ASSIGN_OR_RETURN(
+ parent_->evaluated_[abs],
+ (HloEvaluator::ElementWiseUnaryOpImpl<float, NativeT>(
+ abs, [](NativeT elem_operand) { return std::abs(elem_operand); },
+ operand_literal)));
+
+ return Status::OK();
+ }
+
+ Status HandleAbs(HloInstruction* abs) override {
+ // If the operand is of C64 type, the return type of abs will be F32.
+ // However, ElementwiseT would then also be F32 (the return type), so the
+ // template argument must be specified explicitly as C64 below.
+ if (abs->operand(0)->shape().element_type() == C64) {
+ return HandleAbs<complex64>(abs);
+ }
+ return HandleAbs<ElementwiseT>(abs);
+ }
+
+ template <
+ typename NativeT,
+ typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
+ Status HandleRound(HloInstruction* round) {
+ TF_ASSIGN_OR_RETURN(
+ parent_->evaluated_[round],
+ ElementWiseUnaryOp(round, [](ElementwiseT elem_operand) {
+ return std::round(elem_operand);
+ }));
+ return Status::OK();
+ }
+
+ template <
+ typename NativeT,
+ typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
+ Status HandleRound(HloInstruction* round) {
+ return InvalidArgument("Unsupported type for Round");
+ }
+
+ Status HandleRound(HloInstruction* round) override {
+ return HandleRound<ReturnT>(round);
+ }
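
Handlers such as HandleRound above pick an implementation via std::enable_if: one overload exists only for non-complex element types, another only for complex types (where it reports the op as unsupported), and the virtual override forwards to whichever overload is viable. A small self-contained sketch of that dispatch pattern, with a hypothetical Frobnicate op:

    #include <complex>
    #include <iostream>
    #include <type_traits>

    template <typename T>
    using is_complex_like = std::is_same<T, std::complex<float>>;

    struct Visitor {
      // Chosen when T is not complex: the op is actually performed.
      template <typename T,
                typename std::enable_if<!is_complex_like<T>::value>::type* = nullptr>
      const char* Frobnicate() { return "ok"; }

      // Chosen when T is complex: report the op as unsupported.
      template <typename T,
                typename std::enable_if<is_complex_like<T>::value>::type* = nullptr>
      const char* Frobnicate() { return "unsupported for complex"; }
    };

    int main() {
      Visitor v;
      std::cout << v.Frobnicate<float>() << "\n";                // ok
      std::cout << v.Frobnicate<std::complex<float>>() << "\n";  // unsupported
      return 0;
    }
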
+
+ Status HandleBroadcast(HloInstruction* broadcast) override {
+ parent_->evaluated_[broadcast] =
+ Literal::CreateFromShape(broadcast->shape());
+ auto output = parent_->evaluated_[broadcast].get();
+ const Literal& operand_to_broadcast =
+ parent_->GetEvaluatedLiteralFor(broadcast->operand(0));
+ std::vector<int64> broadcast_indices(
+ ShapeUtil::Rank(broadcast->operand(0)->shape()), 0);
+
+ TF_RET_CHECK(broadcast->dimensions().size() ==
+ ShapeUtil::Rank(operand_to_broadcast.shape()))
+ << "broadcast dimensions is of size: " << broadcast->dimensions().size()
+ << " and rank of operand_to_broadcast is: "
+ << ShapeUtil::Rank(operand_to_broadcast.shape());
+ // Checks that the operand's dimensions match the broadcast output's
+ // dimensions along the dimensions being broadcast.
+ for (int64 i = 0; i < broadcast->dimensions().size(); ++i) {
+ TF_RET_CHECK(broadcast->shape().dimensions(broadcast->dimensions(i)) ==
+ operand_to_broadcast.shape().dimensions(i));
+ }
+
+ return output->Populate<ReturnT>(
+ [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
+ for (int64 i = 0; i < broadcast->dimensions().size(); ++i) {
+ broadcast_indices[i] = multi_index[broadcast->dimensions(i)];
+ }
+ return operand_to_broadcast.Get<ReturnT>(broadcast_indices);
+ });
+ }
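
HandleBroadcast maps each output position back to an operand position by reading only the output coordinates named in broadcast->dimensions(). For example (hypothetical shapes), broadcasting a [3] operand to a [2,3] output with dimensions={1} reads operand element output_index[1] at every output position; a sketch of just that index mapping:

    #include <cstdint>
    #include <iostream>
    #include <vector>

    int main() {
      // Hypothetical example: operand shape [3], output shape [2,3],
      // broadcast dimensions = {1} (operand dim 0 maps to output dim 1).
      const std::vector<int64_t> broadcast_dimensions = {1};
      const std::vector<int64_t> output_index = {1, 2};  // some output position

      // Mirror of the Populate lambda: pick out the mapped output coordinates.
      std::vector<int64_t> operand_index(broadcast_dimensions.size(), 0);
      for (size_t i = 0; i < broadcast_dimensions.size(); ++i) {
        operand_index[i] = output_index[broadcast_dimensions[i]];
      }
      std::cout << "operand index: " << operand_index[0] << "\n";  // prints 2
      return 0;
    }
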
+
+ template <
+ typename NativeT,
+ typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
+ Status HandleCeil(HloInstruction* ceil) {
+ TF_ASSIGN_OR_RETURN(parent_->evaluated_[ceil],
+ ElementWiseUnaryOp(ceil, [](ElementwiseT elem_operand) {
+ return std::ceil(elem_operand);
+ }));
+ return Status::OK();
+ }
+
+ template <
+ typename NativeT,
+ typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
+ Status HandleCeil(HloInstruction* ceil) {
+ return InvalidArgument("Unsupported type for Ceil");
+ }
+
+ Status HandleCeil(HloInstruction* ceil) override {
+ return HandleCeil<ReturnT>(ceil);
+ }
+
+ Status HandleConvert(HloInstruction* convert) override {
+ const HloInstruction* operand = convert->operand(0);
+ TF_RET_CHECK(ShapeUtil::SameDimensions(operand->shape(), convert->shape()));
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> result,
+ parent_->GetEvaluatedLiteralFor(operand).Convert(
+ convert->shape().element_type()));
+
+ if (LayoutUtil::LayoutsInShapesEqual(result->shape(), convert->shape())) {
+ parent_->evaluated_[convert] = std::move(result);
+ } else {
+ parent_->evaluated_[convert] =
+ result->Relayout(convert->shape().layout());
+ }
+ return Status::OK();
+ }
+
+ Status HandleBitcastConvert(HloInstruction* convert) override {
+ const HloInstruction* operand = convert->operand(0);
+ TF_RET_CHECK(ShapeUtil::SameDimensions(operand->shape(), convert->shape()));
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> result,
+ parent_->GetEvaluatedLiteralFor(operand).BitcastConvert(
+ convert->shape().element_type()));
+
+ if (LayoutUtil::LayoutsInShapesEqual(result->shape(), convert->shape())) {
+ parent_->evaluated_[convert] = std::move(result);
+ } else {
+ parent_->evaluated_[convert] =
+ result->Relayout(convert->shape().layout());
+ }
+ return Status::OK();
+ }
+
+ Status HandleExp(HloInstruction* exp) override {
+ TF_ASSIGN_OR_RETURN(parent_->evaluated_[exp],
+ ElementWiseUnaryOp(exp, [](ElementwiseT elem_operand) {
+ return std::exp(elem_operand);
+ }));
+ return Status::OK();
+ }
+
+ template <
+ typename NativeT,
+ typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
+ Status HandleFloor(HloInstruction* floor) {
+ TF_ASSIGN_OR_RETURN(
+ parent_->evaluated_[floor],
+ ElementWiseUnaryOp(floor, [](ElementwiseT elem_operand) {
+ return std::floor(elem_operand);
+ }));
+ return Status::OK();
+ }
+
+ template <
+ typename NativeT,
+ typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
+ Status HandleFloor(HloInstruction* floor) {
+ return InvalidArgument("Unsupported type for Floor");
+ }
+
+ Status HandleFloor(HloInstruction* floor) override {
+ return HandleFloor<ReturnT>(floor);
+ }
+
+ Status HandleLog(HloInstruction* log) override {
+ TF_ASSIGN_OR_RETURN(parent_->evaluated_[log],
+ ElementWiseUnaryOp(log, [](ElementwiseT elem_operand) {
+ return std::log(elem_operand);
+ }));
+ return Status::OK();
+ }
+
+ template <typename NativeT,
+ typename std::enable_if<
+ std::is_integral<NativeT>::value &&
+ !std::is_same<NativeT, bool>::value>::type* = nullptr>
+ Status HandleNot(HloInstruction* not_) {
+ TF_ASSIGN_OR_RETURN(parent_->evaluated_[not_],
+ ElementWiseUnaryOp(not_, [](ElementwiseT elem_operand) {
+ return ~elem_operand;
+ }));
+ return Status::OK();
+ }
+
+ template <typename NativeT, typename std::enable_if<std::is_floating_point<
+ NativeT>::value>::type* = nullptr>
+ Status HandleNot(HloInstruction* not_) {
+ TF_ASSIGN_OR_RETURN(parent_->evaluated_[not_],
+ ElementWiseUnaryOp(not_, [](ElementwiseT elem_operand) {
+ return !elem_operand;
+ }));
+ return Status::OK();
+ }
+
+ template <typename NativeT,
+ typename std::enable_if<std::is_same<NativeT, bool>::value>::type* =
+ nullptr>
+ Status HandleNot(HloInstruction* not_) {
+ TF_ASSIGN_OR_RETURN(parent_->evaluated_[not_],
+ ElementWiseUnaryOp(not_, [](ElementwiseT elem_operand) {
+ return !elem_operand;
+ }));
+ return Status::OK();
+ }
+
+ template <
+ typename NativeT,
+ typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
+ Status HandleNot(HloInstruction* not_) {
+ return InvalidArgument("Unsupported type for Not");
+ }
+
+ Status HandleNot(HloInstruction* not_) override {
+ return HandleNot<ElementwiseT>(not_);
+ }
+
+ template <typename NativeT,
+ typename std::enable_if<
+ std::is_signed<NativeT>::value &&
+ !std::is_floating_point<NativeT>::value>::type* = nullptr>
+ Status HandleNegate(HloInstruction* negate) {
+ using type = typename std::make_unsigned<NativeT>::type;
+ TF_ASSIGN_OR_RETURN(
+ parent_->evaluated_[negate],
+ ElementWiseUnaryOp(negate, [](ElementwiseT elem_operand) {
+ return NativeT(-type(elem_operand));
+ }));
+ return Status::OK();
+ }
+
+ template <typename NativeT,
+ typename std::enable_if<
+ !std::is_signed<NativeT>::value ||
+ std::is_floating_point<NativeT>::value>::type* = nullptr>
+ Status HandleNegate(HloInstruction* negate) {
+ TF_ASSIGN_OR_RETURN(
+ parent_->evaluated_[negate],
+ ElementWiseUnaryOp(
+ negate, [](ElementwiseT elem_operand) { return -elem_operand; }));
+ return Status::OK();
+ }
+
+ Status HandleNegate(HloInstruction* negate) override {
+ return HandleNegate<ReturnT>(negate);
+ }
+
+ template <
+ typename NativeT,
+ typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
+ Status HandleSign(HloInstruction* sign) {
+ TF_ASSIGN_OR_RETURN(parent_->evaluated_[sign],
+ ElementWiseUnaryOp(sign, [](ElementwiseT elem_operand) {
+ return (ElementwiseT(0) < elem_operand) -
+ (elem_operand < ElementwiseT(0));
+ }));
+ return Status::OK();
+ }
+
+ template <
+ typename NativeT,
+ typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
+ Status HandleSign(HloInstruction* sign) {
+ TF_ASSIGN_OR_RETURN(parent_->evaluated_[sign],
+ ElementWiseUnaryOp(sign, [](ElementwiseT elem_operand) {
+ auto abs_val = std::abs(elem_operand);
+ return 0 == abs_val ? ElementwiseT(0)
+ : elem_operand / abs_val;
+ }));
+ return Status::OK();
+ }
+
+ Status HandleSign(HloInstruction* sign) override {
+ return HandleSign<ReturnT>(sign);
+ }
+
+ template <typename NativeT, typename std::enable_if<std::is_floating_point<
+ NativeT>::value>::type* = nullptr>
+ Status HandleAtan2(HloInstruction* atan2) {
+ TF_ASSIGN_OR_RETURN(parent_->evaluated_[atan2],
+ ElementWiseBinaryOp(atan2, [](ElementwiseT lhs_elem,
+ ElementwiseT rhs_elem) {
+ return std::atan2(lhs_elem, rhs_elem);
+ }));
+ return Status::OK();
+ }
+
+ template <typename NativeT, typename std::enable_if<!std::is_floating_point<
+ NativeT>::value>::type* = nullptr>
+ Status HandleAtan2(HloInstruction* atan2) {
+ return InvalidArgument("Unsupported type for Atan2");
+ }
+
+ Status HandleAtan2(HloInstruction* atan2) override {
+ return HandleAtan2<ElementwiseT>(atan2);
+ }
+
+ Status HandleTanh(HloInstruction* tanh) override {
+ TF_ASSIGN_OR_RETURN(parent_->evaluated_[tanh],
+ ElementWiseUnaryOp(tanh, [](ElementwiseT elem_operand) {
+ return std::tanh(elem_operand);
+ }));
+ return Status::OK();
+ }
+
+ template <typename NativeT,
+ typename std::enable_if<
+ std::is_signed<NativeT>::value &&
+ !std::is_floating_point<NativeT>::value>::type* = nullptr>
+ Status HandleMultiply(HloInstruction* multiply) {
+ using type = typename std::make_unsigned<NativeT>::type;
+ TF_ASSIGN_OR_RETURN(
+ parent_->evaluated_[multiply],
+ ElementWiseBinaryOp(multiply,
+ [](ElementwiseT lhs_elem, ElementwiseT rhs_elem) {
+ return NativeT(type(lhs_elem) * type(rhs_elem));
+ }));
+ return Status::OK();
+ }
+
+ template <
+ typename NativeT,
+ typename std::enable_if<std::is_unsigned<NativeT>::value ||
+ std::is_floating_point<NativeT>::value ||
+ is_complex_t<NativeT>::value>::type* = nullptr>
+ Status HandleMultiply(HloInstruction* multiply) {
+ TF_ASSIGN_OR_RETURN(
+ parent_->evaluated_[multiply],
+ ElementWiseBinaryOp(multiply,
+ [](ElementwiseT lhs_elem, ElementwiseT rhs_elem) {
+ return lhs_elem * rhs_elem;
+ }));
+ return Status::OK();
+ }
+
+ Status HandleMultiply(HloInstruction* multiply) override {
+ return HandleMultiply<ElementwiseT>(multiply);
+ }
+
+ Status HandleSubtract(HloInstruction* subtract) override {
+ TF_ASSIGN_OR_RETURN(
+ parent_->evaluated_[subtract],
+ ElementWiseBinaryOp(subtract,
+ [](ElementwiseT lhs_elem, ElementwiseT rhs_elem) {
+ return lhs_elem - rhs_elem;
+ }));
+ return Status::OK();
+ }
+
+ Status HandleAdd(HloInstruction* add) override {
+ TF_ASSIGN_OR_RETURN(parent_->evaluated_[add],
+ ElementWiseBinaryOp(add, [](ElementwiseT lhs_elem,
+ ElementwiseT rhs_elem) {
+ return lhs_elem + rhs_elem;
+ }));
+ return Status::OK();
+ }
+
+ Status HandleDivide(HloInstruction* divide) override {
+ TF_ASSIGN_OR_RETURN(parent_->evaluated_[divide],
+ ElementWiseBinaryOp(divide, [](ElementwiseT lhs_elem,
+ ElementwiseT rhs_elem) {
+ return lhs_elem / rhs_elem;
+ }));
+ return Status::OK();
+ }
+
+ template <typename NativeT,
+ typename std::enable_if<std::is_integral<NativeT>::value>::type* =
+ nullptr>
+ Status HandleMaximum(HloInstruction* maximum) {
+ TF_ASSIGN_OR_RETURN(
+ parent_->evaluated_[maximum],
+ ElementWiseBinaryOp(maximum, [](ElementwiseT lhs, ElementwiseT rhs) {
+ return std::max(lhs, rhs);
+ }));
+ return Status::OK();
+ }
+
+ template <typename NativeT, typename std::enable_if<std::is_floating_point<
+ NativeT>::value>::type* = nullptr>
+ Status HandleMaximum(HloInstruction* maximum) {
+ TF_ASSIGN_OR_RETURN(
+ parent_->evaluated_[maximum],
+ ElementWiseBinaryOp(maximum, [](ElementwiseT lhs, ElementwiseT rhs) {
+ return ((lhs >= rhs) || std::isnan(lhs)) ? lhs : rhs;
+ }));
+ return Status::OK();
+ }
+
+ template <
+ typename NativeT,
+ typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
+ Status HandleMaximum(HloInstruction* maximum) {
+ return InvalidArgument("Unsupported type for Maximum");
+ }
+
+ Status HandleMaximum(HloInstruction* maximum) override {
+ return HandleMaximum<ElementwiseT>(maximum);
+ }
+
+ template <typename NativeT,
+ typename std::enable_if<std::is_integral<NativeT>::value>::type* =
+ nullptr>
+ Status HandleMinimum(HloInstruction* minimum) {
+ TF_ASSIGN_OR_RETURN(parent_->evaluated_[minimum],
+ ElementWiseBinaryOp(minimum, [](ElementwiseT lhs_el,
+ ElementwiseT rhs_el) {
+ return std::min(lhs_el, rhs_el);
+ }));
+ return Status::OK();
+ }
+
+ template <typename NativeT, typename std::enable_if<std::is_floating_point<
+ NativeT>::value>::type* = nullptr>
+ Status HandleMinimum(HloInstruction* minimum) {
+ TF_ASSIGN_OR_RETURN(
+ parent_->evaluated_[minimum],
+ ElementWiseBinaryOp(minimum, [](ElementwiseT lhs_el,
+ ElementwiseT rhs_el) {
+ return ((lhs_el <= rhs_el) || std::isnan(lhs_el)) ? lhs_el : rhs_el;
+ }));
+ return Status::OK();
+ }
+
+ template <
+ typename NativeT,
+ typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
+ Status HandleMinimum(HloInstruction* minimum) {
+ return InvalidArgument("Unsupported type for Minimum");
+ }
+
+ Status HandleMinimum(HloInstruction* minimum) override {
+ return HandleMinimum<ElementwiseT>(minimum);
+ }
+
+ Status HandlePower(HloInstruction* power) override {
+ TF_ASSIGN_OR_RETURN(parent_->evaluated_[power],
+ ElementWiseBinaryOp(power, [](ElementwiseT lhs_el,
+ ElementwiseT rhs_el) {
+ return std::pow(lhs_el, rhs_el);
+ }));
+ return Status::OK();
+ }
+
+ template <
+ typename NativeT,
+ typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
+ Status HandleRemainder(HloInstruction* remainder) {
+ TF_ASSIGN_OR_RETURN(parent_->evaluated_[remainder],
+ ElementWiseBinaryOp(remainder, [](ElementwiseT lhs_el,
+ ElementwiseT rhs_el) {
+ return std::fmod(lhs_el, rhs_el);
+ }));
+ return Status::OK();
+ }
+
+ template <
+ typename NativeT,
+ typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
+ Status HandleRemainder(HloInstruction* remainder) {
+ return InvalidArgument("Unsupported type for Remainder");
+ }
+
+ Status HandleRemainder(HloInstruction* remainder) override {
+ return HandleRemainder<ElementwiseT>(remainder);
+ }
+
+ template <typename NativeT,
+ typename std::enable_if<std::is_integral<NativeT>::value>::type* =
+ nullptr>
+ Status HandleAnd(HloInstruction* and_) {
+ TF_ASSIGN_OR_RETURN(
+ parent_->evaluated_[and_],
+ ElementWiseBinaryOp(and_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) {
+ return lhs_el & rhs_el;
+ }));
+ return Status::OK();
+ }
+
+ template <typename NativeT, typename std::enable_if<std::is_floating_point<
+ NativeT>::value>::type* = nullptr>
+ Status HandleAnd(HloInstruction* and_) {
+ TF_ASSIGN_OR_RETURN(
+ parent_->evaluated_[and_],
+ ElementWiseBinaryOp(and_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) {
+ return lhs_el && rhs_el;
+ }));
+ return Status::OK();
+ }
+
+ template <
+ typename NativeT,
+ typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
+ Status HandleAnd(HloInstruction* and_) {
+ return InvalidArgument("Unsupported type for And");
+ }
+
+ Status HandleAnd(HloInstruction* and_) override {
+ return HandleAnd<ElementwiseT>(and_);
+ }
+
+ template <typename NativeT,
+ typename std::enable_if<std::is_integral<NativeT>::value>::type* =
+ nullptr>
+ Status HandleOr(HloInstruction* or_) {
+ TF_ASSIGN_OR_RETURN(
+ parent_->evaluated_[or_],
+ ElementWiseBinaryOp(or_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) {
+ return lhs_el | rhs_el;
+ }));
+ return Status::OK();
+ }
+
+ template <typename NativeT, typename std::enable_if<std::is_floating_point<
+ NativeT>::value>::type* = nullptr>
+ Status HandleOr(HloInstruction* or_) {
+ TF_ASSIGN_OR_RETURN(
+ parent_->evaluated_[or_],
+ ElementWiseBinaryOp(or_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) {
+ return lhs_el || rhs_el;
+ }));
+ return Status::OK();
+ }
+
+ template <
+ typename NativeT,
+ typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
+ Status HandleOr(HloInstruction* or_) {
+ return InvalidArgument("Unsupported type for Or");
+ }
+
+ Status HandleOr(HloInstruction* or_) override {
+ return HandleOr<ElementwiseT>(or_);
+ }
+
+ template <typename NativeT,
+ typename std::enable_if<
+ std::is_integral<NativeT>::value &&
+ !std::is_same<NativeT, bool>::value>::type* = nullptr>
+ Status HandleShiftLeft(HloInstruction* shl) {
+ TF_ASSIGN_OR_RETURN(
+ parent_->evaluated_[shl],
+ ElementWiseBinaryOp(shl, [](NativeT lhs_elem, NativeT rhs_elem) {
+ return IsShiftOutOfBounds<NativeT>(rhs_elem) ? 0
+ : (lhs_elem << rhs_elem);
+ }));
+ return Status::OK();
+ }
+
+ template <typename NativeT,
+ typename std::enable_if<!std::is_integral<NativeT>::value ||
+ std::is_same<NativeT, bool>::value>::type* =
+ nullptr>
+ Status HandleShiftLeft(HloInstruction*) {
+ return InvalidArgument("Unsupported type for ShiftLeft");
+ }
+
+ Status HandleShiftLeft(HloInstruction* shl) override {
+ return HandleShiftLeft<ElementwiseT>(shl);
+ }
+ template <typename NativeT,
+ typename std::enable_if<
+ std::is_integral<NativeT>::value &&
+ !std::is_same<NativeT, bool>::value>::type* = nullptr>
+ Status HandleShiftRightArithmetic(HloInstruction* shr) {
+ typedef typename std::make_signed<NativeT>::type SignedT;
+ TF_ASSIGN_OR_RETURN(
+ parent_->evaluated_[shr],
+ ElementWiseBinaryOp(shr, [](NativeT lhs_elem, NativeT rhs_elem) {
+ SignedT lhs_signed = static_cast<SignedT>(lhs_elem);
+ if (IsShiftOutOfBounds<NativeT>(rhs_elem)) {
+ return lhs_signed < 0 ? static_cast<SignedT>(-1) : 0;
+ } else {
+ return lhs_signed >> rhs_elem;
+ }
+ }));
+ return Status::OK();
+ }
+
+ template <typename NativeT,
+ typename std::enable_if<!std::is_integral<NativeT>::value ||
+ std::is_same<NativeT, bool>::value>::type* =
+ nullptr>
+ Status HandleShiftRightArithmetic(HloInstruction*) {
+ return InvalidArgument("Unsupported type for ShiftRightArithmetic");
+ }
+
+ Status HandleShiftRightArithmetic(HloInstruction* shra) override {
+ return HandleShiftRightArithmetic<ElementwiseT>(shra);
+ }
+
+ template <typename NativeT,
+ typename std::enable_if<
+ std::is_integral<NativeT>::value &&
+ !std::is_same<NativeT, bool>::value>::type* = nullptr>
+ Status HandleShiftRightLogical(HloInstruction* shr) {
+ typedef typename std::make_unsigned<NativeT>::type UnsignedT;
+ TF_ASSIGN_OR_RETURN(
+ parent_->evaluated_[shr],
+ ElementWiseBinaryOp(shr, [](NativeT lhs_elem, NativeT rhs_elem) {
+ // If the shift amount is greater than or equal to the bit width, return 0.
+ if (IsShiftOutOfBounds<NativeT>(rhs_elem)) {
+ return static_cast<NativeT>(0);
+ }
+ return static_cast<NativeT>(static_cast<UnsignedT>(lhs_elem) >>
+ rhs_elem);
+ }));
+ return Status::OK();
+ }
+
+ template <typename NativeT,
+ typename std::enable_if<!std::is_integral<NativeT>::value ||
+ std::is_same<NativeT, bool>::value>::type* =
+ nullptr>
+ Status HandleShiftRightLogical(HloInstruction*) {
+ return InvalidArgument("Unsupported type for ShiftRightLogical");
+ }
+
+ Status HandleShiftRightLogical(HloInstruction* shrl) override {
+ return HandleShiftRightLogical<ElementwiseT>(shrl);
+ }
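
All three shift handlers guard with IsShiftOutOfBounds, carried over from the old TypedVisitor (see the removed definition earlier in this diff): a shift count is out of bounds when, reinterpreted as unsigned, it is at least the bit width of the operand type. So shifting an int32 by 32 or by -1 yields 0 for ShiftLeft and ShiftRightLogical, and 0 or -1 for ShiftRightArithmetic depending on the sign of the lhs. A small standalone copy of the check:

    #include <climits>
    #include <cstdint>
    #include <iostream>
    #include <type_traits>

    template <typename NativeT>
    bool IsShiftOutOfBounds(NativeT rhs) {
      using UnsignedT = typename std::make_unsigned<NativeT>::type;
      const UnsignedT lhs_bits = sizeof(NativeT) * CHAR_BIT;
      return static_cast<UnsignedT>(rhs) >= lhs_bits;
    }

    int main() {
      std::cout << IsShiftOutOfBounds<int32_t>(31) << "\n";  // 0: in bounds
      std::cout << IsShiftOutOfBounds<int32_t>(32) << "\n";  // 1: out of bounds
      std::cout << IsShiftOutOfBounds<int32_t>(-1) << "\n";  // 1: wraps to 2^32-1
      return 0;
    }
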
+
+ template <
+ typename NativeT,
+ typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
+ Status HandleClamp(HloInstruction* clamp) {
+ std::function<ElementwiseT(ElementwiseT, ElementwiseT, ElementwiseT)>
+ clamp_op = [](ElementwiseT low, ElementwiseT value, ElementwiseT high) {
+ return std::fmin(high, std::fmax(value, low));
+ };
+ TF_ASSIGN_OR_RETURN(
+ parent_->evaluated_[clamp],
+ ElementwiseTernaryOp(clamp,
+ std::move(ConvertTernaryFunction(clamp_op))));
+ return Status::OK();
+ }
+
+ template <
+ typename NativeT,
+ typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
+ Status HandleClamp(HloInstruction*) {
+ return InvalidArgument("Unsupported type for Clamp");
+ }
+
+ Status HandleClamp(HloInstruction* clamp) override {
+ return HandleClamp<ElementwiseT>(clamp);
+ }
+
+ Status HandleSelect(HloInstruction* select) override {
+ CHECK(!ShapeUtil::IsScalar(select->operand(0)->shape()));
+ CHECK(!ShapeUtil::IsTuple(select->shape()));
+ std::function<ReturnT(bool, ReturnT, ReturnT)> select_op =
+ [](bool pred, ReturnT on_true, ReturnT on_false) {
+ if (pred) {
+ return on_true;
+ }
+ return on_false;
+ };
+ TF_ASSIGN_OR_RETURN(parent_->evaluated_[select],
+ ElementwiseTernaryOp(select, std::move(select_op)));
+ return Status::OK();
+ }
+
+ Status HandleReverse(HloInstruction* reverse) override {
+ const auto result_shape = reverse->shape();
+ const auto reverse_dimensions = reverse->dimensions();
+
+ auto operand = reverse->operand(0);
+ TF_ASSIGN_OR_RETURN(auto inferred_return_shape,
+ ShapeInference::InferReverseShape(operand->shape(),
+ reverse_dimensions));
+
+ TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape))
+ << "return shape set to: " << ShapeUtil::HumanString(result_shape)
+ << " but is inferred to be: "
+ << ShapeUtil::HumanString(inferred_return_shape);
+
+ const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand);
+ auto result = Literal::CreateFromShape(result_shape);
+
+ TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
+ [&](tensorflow::gtl::ArraySlice<int64> out_index) {
+ std::vector<int64> from_index(out_index.begin(), out_index.end());
+ for (const int64 dim : reverse_dimensions) {
+ from_index[dim] = result_shape.dimensions(dim) - 1 - out_index[dim];
+ }
+ return operand_literal.Get<ReturnT>(from_index);
+ }));
+
+ parent_->evaluated_[reverse] = std::move(result);
+ return Status::OK();
+ }
+
+ Status HandleConvolution(HloInstruction* conv) override {
+ auto lhs = conv->operand(0);
+ auto rhs = conv->operand(1);
+ const auto& window = conv->window();
+ const Shape& result_shape = conv->shape();
+ const Shape& lhs_shape = lhs->shape();
+ const Shape& rhs_shape = rhs->shape();
+
+ TF_CHECK_OK(ShapeUtil::ValidateShape(lhs_shape));
+ TF_CHECK_OK(ShapeUtil::ValidateShape(rhs_shape));
+ CHECK(ShapeUtil::IsArray(lhs_shape));
+ CHECK(ShapeUtil::IsArray(rhs_shape));
+ CHECK(ShapeUtil::SameElementType(lhs_shape, rhs_shape));
+ CHECK(ShapeUtil::SameElementType(lhs_shape, result_shape));
+
+ const auto& dnums = conv->convolution_dimension_numbers();
+ const int64 num_spatial_dims = dnums.output_spatial_dimensions_size();
+ CHECK_EQ(num_spatial_dims, dnums.input_spatial_dimensions_size());
+ CHECK_EQ(num_spatial_dims, dnums.kernel_spatial_dimensions_size());
+ CHECK_GE(num_spatial_dims, 0);
+ CHECK_EQ(window.dimensions_size(), num_spatial_dims);
+
+ const auto lhs_rank = ShapeUtil::Rank(lhs_shape);
+ const auto rhs_rank = ShapeUtil::Rank(rhs_shape);
+
+ CHECK_EQ(num_spatial_dims + 2, lhs_rank);
+ CHECK_EQ(num_spatial_dims + 2, rhs_rank);
+
+ TF_ASSIGN_OR_RETURN(auto inferred_return_shape,
+ ShapeInference::InferConvolveShape(lhs_shape, rhs_shape,
+ window, dnums));
+ CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape))
+ << "return shape set to: " << ShapeUtil::HumanString(result_shape)
+ << " but is inferred to be: "
+ << ShapeUtil::HumanString(inferred_return_shape);
+
+ const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs);
+ const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs);
+
+ std::vector<int64> window_dimension_sizes;
+ for (auto i : dnums.kernel_spatial_dimensions()) {
+ window_dimension_sizes.push_back(ShapeUtil::GetDimension(rhs_shape, i));
+ }
+
+ const Shape& window_shape =
+ ShapeUtil::MakeShape(rhs_shape.element_type(), window_dimension_sizes);
+
+ DimensionVector lhs_dim_multipliers = MakeDimMultipliers(lhs_shape);
+ DimensionVector rhs_dim_multipliers = MakeDimMultipliers(rhs_shape);
+
+ auto lhs_literal_data = lhs_literal.data<ReturnT>();
+ auto rhs_literal_data = rhs_literal.data<ReturnT>();
+
+ auto func = [&window_shape, &dnums, &lhs_shape, &rhs_shape, &window,
+ &lhs_dim_multipliers, &rhs_dim_multipliers, lhs_literal_data,
+ rhs_literal_data](
+ tensorflow::gtl::ArraySlice<int64> out_index) {
+ // Dimension number applicable for input (lhs).
+ const int64 input_batch_dim = dnums.input_batch_dimension();
+ const int64 input_z_dim = dnums.input_feature_dimension();
+ // Dimension number applicable for kernel (rhs).
+ const int64 kernel_input_z_dim = dnums.kernel_input_feature_dimension();
+ const int64 kernel_output_z_dim = dnums.kernel_output_feature_dimension();
+ // Dimension number applicable for output.
+ const int64 output_batch_dim = dnums.output_batch_dimension();
+ const int64 output_z_dim = dnums.output_feature_dimension();
+
+ const int64 z_size = ShapeUtil::GetDimension(lhs_shape, input_z_dim);
+
+ ElementwiseT result_val = static_cast<ElementwiseT>(0);
+ DimensionVector rhs_spatial_index(dnums.kernel_spatial_dimensions_size(),
+ 0);
+
+ // Convolve input feature with kernel.
+ do {
+ for (int64 iz = 0; iz < z_size; ++iz) {
+ int64 lhs_linear_index = 0;
+ lhs_linear_index += out_index[output_batch_dim] *
+ lhs_dim_multipliers[input_batch_dim];
+ lhs_linear_index += iz * lhs_dim_multipliers[input_z_dim];
+
+ int64 rhs_linear_index = 0;
+ rhs_linear_index += out_index[output_z_dim] *
+ rhs_dim_multipliers[kernel_output_z_dim];
+ rhs_linear_index += iz * rhs_dim_multipliers[kernel_input_z_dim];
+
+ // Find corresponding spatial dimension index for input (lhs).
+ for (int64 ki = 0; ki < rhs_spatial_index.size(); ++ki) {
+ // Spatial dimension number for input (lhs) and output.
+ const int64 input_spatial_dim = dnums.input_spatial_dimensions(ki);
+ const int64 output_spatial_dim =
+ dnums.output_spatial_dimensions(ki);
+
+ // Calculate lhs (input) index without taking base dilation into
+ // account.
+ const auto& window_dim = window.dimensions(ki);
+ const int64 undilated_index =
+ out_index[output_spatial_dim] * window_dim.stride() -
+ window_dim.padding_low() +
+ rhs_spatial_index[ki] * window_dim.window_dilation();
+ // Skip if the lhs (input) index falls in a base-dilation gap. As an
+ // optimization, skip this mod if there's no dilation.
+ if (window_dim.base_dilation() > 1 &&
+ undilated_index % window_dim.base_dilation() != 0) {
+ goto cnt;
+ }
+
+ // Calculate the actual lhs (input) index after dilation. As an
+ // optimization, skip this integer divide if there's no dilation.
+ int64 lhs_spatial_index;
+ if (window_dim.base_dilation() > 1) {
+ lhs_spatial_index = undilated_index / window_dim.base_dilation();
+ } else {
+ lhs_spatial_index = undilated_index;
+ }
+ lhs_linear_index +=
+ lhs_spatial_index * lhs_dim_multipliers[input_spatial_dim];
+
+ // Skip if input index is not in bounds.
+ if (!(lhs_spatial_index >= 0 &&
+ lhs_spatial_index <
+ lhs_shape.dimensions(input_spatial_dim))) {
+ goto cnt;
+ }
+
+ rhs_linear_index +=
+ (window_dim.window_reversal()
+ ? ((window_dim.size() - 1) - rhs_spatial_index[ki])
+ : rhs_spatial_index[ki]) *
+ rhs_dim_multipliers[dnums.kernel_spatial_dimensions(ki)];
+ }
+
+ result_val +=
+ static_cast<ElementwiseT>(lhs_literal_data[lhs_linear_index]) *
+ static_cast<ElementwiseT>(rhs_literal_data[rhs_linear_index]);
+ }
+ cnt : {}
+ } while (IndexUtil::BumpIndices(window_shape, &rhs_spatial_index));
+
+ return static_cast<ReturnT>(result_val);
+ };
+
+ auto result = Literal::CreateFromShape(result_shape);
+ TF_RETURN_IF_ERROR(result->PopulateParallel<ReturnT>(func));
+
+ parent_->evaluated_[conv] = std::move(result);
+ return Status::OK();
+ }
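
The innermost loop above derives, for each kernel spatial position, the input coordinate as out_index * stride - padding_low + kernel_index * window_dilation, then (when base_dilation > 1) skips positions that fall between dilated input elements and divides to get the real input index. A worked numeric sketch with hypothetical window parameters:

    #include <cstdint>
    #include <iostream>

    int main() {
      // Hypothetical 1-D window: stride 2, padding_low 1, window_dilation 2,
      // base_dilation 1, input spatial size 8.
      const int64_t stride = 2, padding_low = 1, window_dilation = 2;
      const int64_t base_dilation = 1, input_size = 8;
      const int64_t out_index = 3, kernel_index = 1;

      // Same formula as the evaluator's inner loop.
      const int64_t undilated =
          out_index * stride - padding_low + kernel_index * window_dilation;
      // undilated = 3*2 - 1 + 1*2 = 7
      const bool skip_dilated =
          base_dilation > 1 && undilated % base_dilation != 0;
      const int64_t lhs_index =
          base_dilation > 1 ? undilated / base_dilation : undilated;
      const bool in_bounds = lhs_index >= 0 && lhs_index < input_size;

      std::cout << "lhs index " << lhs_index << (skip_dilated ? " (skipped)" : "")
                << (in_bounds ? " in bounds" : " out of bounds") << "\n";
      return 0;
    }
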
+
+ Status HandleDot(HloInstruction* dot) override {
+ auto lhs = dot->operand(0);
+ auto rhs = dot->operand(1);
+ CHECK(ShapeUtil::IsArray(dot->shape()));
+ CHECK(ShapeUtil::IsArray(lhs->shape()));
+ CHECK(ShapeUtil::IsArray(rhs->shape()));
+
+ const auto& dnums = dot->dot_dimension_numbers();
+
+ const auto lhs_rank = ShapeUtil::Rank(lhs->shape());
+ const auto rhs_rank = ShapeUtil::Rank(rhs->shape());
+
+ CHECK(ShapeUtil::SameElementType(lhs->shape(), rhs->shape()));
+ CHECK(ShapeUtil::SameElementType(lhs->shape(), dot->shape()));
+
+ // There must be exactly one contracting dimension for both lhs and rhs.
+ CHECK_EQ(dnums.lhs_contracting_dimensions_size(), 1);
+ CHECK_EQ(dnums.rhs_contracting_dimensions_size(), 1);
+ const int64 lhs_contracting_dimension = dnums.lhs_contracting_dimensions(0);
+ const int64 rhs_contracting_dimension = dnums.rhs_contracting_dimensions(0);
+ // Contracted dimension sizes must be the same.
+ CHECK_EQ(lhs->shape().dimensions(lhs_contracting_dimension),
+ rhs->shape().dimensions(rhs_contracting_dimension))
+ << "lhs contracted dimension: "
+ << lhs->shape().dimensions(lhs_contracting_dimension)
+ << " rhs contracted dimension: "
+ << rhs->shape().dimensions(rhs_contracting_dimension);
+ const int64 contracted_dimension_size =
+ lhs->shape().dimensions(lhs_contracting_dimension);
+
+ const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs);
+ const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs);
+
+ auto result = Literal::CreateFromShape(dot->shape());
+
+ CHECK_EQ(dnums.lhs_batch_dimensions_size(),
+ dnums.rhs_batch_dimensions_size());
+
+ std::vector<int64> lhs_non_contracting_dims;
+ for (int64 i = 0; i < lhs_rank; i++) {
+ if (i != lhs_contracting_dimension) {
+ lhs_non_contracting_dims.push_back(i);
+ }
+ }
+
+ std::vector<int64> rhs_non_batch_non_contracting_dims;
+ tensorflow::gtl::FlatSet<int64> batch_dims_set(
+ dnums.rhs_batch_dimensions().begin(),
+ dnums.rhs_batch_dimensions().end());
+ for (int64 i = 0; i < rhs_rank; i++) {
+ if (i != rhs_contracting_dimension && batch_dims_set.count(i) == 0) {
+ rhs_non_batch_non_contracting_dims.push_back(i);
+ }
+ }
+
+ const int64 batch_dim_size = dnums.lhs_batch_dimensions_size();
+ const int64 lhs_non_contracting_size = lhs_non_contracting_dims.size();
+
+ DimensionVector lhs_index(lhs_rank);
+ DimensionVector rhs_index(rhs_rank);
+ TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
+ [&](tensorflow::gtl::ArraySlice<int64> result_index) {
+ ElementwiseT result_val = static_cast<ElementwiseT>(0);
+
+ // Find the corresponding non-contracting indices for lhs and rhs.
+ //
+ // For `result_index`, its batch dimension, if it exists, will be at the
+ // same dimension as the batch dimension of lhs and rhs. More
+ // specifically:
+ // - For lhs, the non-contracting dimensions, including the batch
+ // dimension have the same index as the `result_index`.
+ // - For rhs, the batch dimension is set separately from other
+ // non-contracting dimensions, since these other non-contracting
+ // dimensions in rhs follow the non-contracting dimensions of lhs in
+ // the resulting index.
+ //
+ // As an example, for a resulting index:
+ // result_index [result_batch, result_x, result_y]
+ // the corresponding lhs and rhs indices are:
+ // lhs [result_batch, lhs_non_contracting_dim, contracting_dim]
+ // rhs [result_batch, contracting_dim, rhs_non_contracting_dim]
+ // `result_x` is only affected by the lhs_non_contracting_dim and
+ // likewise `result_y` only depends on rhs_non_contracting_dim.
+ //
+ // so we can look up the lhs and rhs indices by:
+ //
+ // lhs:
+ // batch index is the same as `result_batch`.
+ // non-contracting dimension is the same as
+ // result_index[lhs_non_contracting_dim]
+ // rhs:
+ // batch index: the same as `result_batch`.
+ // non-contracting dimension index: *not* the same as
+ // result_index[rhs_non_contracting_dim], since the
+ // non-contracting dimensions of lhs are included in the
+ // result_index first. Instead, the non_contracting_dim of rhs must
+ // be calculated as follows:
+ // lhs_non_contracting_dimensions_size +
+ // (rhs_non_batch_non_contracting_dim - batch_dim_size) - 1
+ //
+ // Note that (rhs_non_batch_non_contracting_dim - batch_dim_size) is
+ // the index offset to the result_index that only depends on
+ // the non_batch and non-contracting dimensions of rhs. -1 at the
+ // end translates size to index.
+ for (auto i : lhs_non_contracting_dims) {
+ lhs_index[i] = result_index[i];
+ }
+ for (auto i : dnums.rhs_batch_dimensions()) {
+ rhs_index[i] = result_index[i];
+ }
+ for (auto i : rhs_non_batch_non_contracting_dims) {
+ const int64 rhs_non_batch_non_contracting_dim =
+ lhs_non_contracting_size + (i - batch_dim_size) - 1;
+ rhs_index[i] = result_index[rhs_non_batch_non_contracting_dim];
+ }
+
+ // Accumulates resulting product along the contracted dimension.
+ for (int64 i = 0; i < contracted_dimension_size; ++i) {
+ lhs_index[lhs_contracting_dimension] = i;
+ rhs_index[rhs_contracting_dimension] = i;
+
+ result_val +=
+ static_cast<ElementwiseT>(lhs_literal.Get<ReturnT>(lhs_index)) *
+ static_cast<ElementwiseT>(rhs_literal.Get<ReturnT>(rhs_index));
+ }
+
+ return static_cast<ReturnT>(result_val);
+ }));
+
+ parent_->evaluated_[dot] = std::move(result);
+ return Status::OK();
+ }
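
The long comment above reduces, for the batched case, to a single offset computation for the rhs non-batch non-contracting dimensions. A worked sketch under hypothetical shapes lhs=[B,M,K], rhs=[B,K,N], result=[B,M,N] (so batch_dim_size = 1 and lhs_non_contracting_size = 2):

    #include <cstdint>
    #include <iostream>
    #include <vector>

    int main() {
      // Hypothetical dot: lhs [B, M, K] x rhs [B, K, N] -> result [B, M, N].
      // lhs contracting dim = 2, rhs contracting dim = 1, batch dims = {0},
      // so lhs_non_contracting_dims = {0, 1} and batch_dim_size = 1.
      const int64_t lhs_non_contracting_size = 2;
      const int64_t batch_dim_size = 1;
      const std::vector<int64_t> rhs_non_batch_non_contracting_dims = {2};  // N

      const std::vector<int64_t> result_index = {4, 7, 9};  // some (b, m, n)
      std::vector<int64_t> rhs_index(3, 0);

      rhs_index[0] = result_index[0];  // the batch dimension copies straight across
      for (int64_t i : rhs_non_batch_non_contracting_dims) {
        // Same offset formula as the comment in HandleDot above.
        const int64_t result_dim =
            lhs_non_contracting_size + (i - batch_dim_size) - 1;
        rhs_index[i] = result_index[result_dim];
      }
      // rhs_index[1] (the contracting dimension) is what the accumulation loops over.
      std::cout << "rhs index: [" << rhs_index[0] << ", k, " << rhs_index[2] << "]\n";
      return 0;  // prints: rhs index: [4, k, 9]
    }
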
+
+ Status HandlePad(HloInstruction* pad) override {
+ CHECK(!ShapeUtil::IsTuple(pad->operand(0)->shape()));
+ // Padding value must be scalar.
+ CHECK(ShapeUtil::IsScalar(pad->operand(1)->shape()));
+ CHECK_EQ(ShapeUtil::Rank(pad->operand(0)->shape()),
+ pad->padding_config().dimensions_size());
+
+ TF_ASSIGN_OR_RETURN(auto inferred_return_shape,
+ ShapeInference::InferPadShape(
+ /*operand_shape=*/pad->operand(0)->shape(),
+ /*padding_value_shape=*/pad->operand(1)->shape(),
+ /*padding_config=*/pad->padding_config()));
+ CHECK(ShapeUtil::Compatible(pad->shape(), inferred_return_shape))
+ << "return shape is set to: " << ShapeUtil::HumanString(pad->shape())
+ << "but is inferred to be: "
+ << ShapeUtil::HumanString(inferred_return_shape);
+
+ // Create new HLO of padded shape with padding value.
+ ReturnT scalar =
+ parent_->GetEvaluatedLiteralFor(pad->operand(1)).Get<ReturnT>({});
+ auto result = Literal::CreateFromShape(pad->shape());
+ TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
+ [&scalar](tensorflow::gtl::ArraySlice<int64> multi_index) {
+ return scalar;
+ }));
+
+ const Literal& evaluated_operand =
+ parent_->GetEvaluatedLiteralFor(pad->operand(0));
+
+ std::vector<int64> input_index(ShapeUtil::Rank(evaluated_operand.shape()),
+ 0);
+ std::vector<int64> target_index(ShapeUtil::Rank(result->shape()), 0);
+
+ // Loop through each element of the operand, assigning it to the
+ // corresponding index of the resulting padded literal.
+ const PaddingConfig& pad_config = pad->padding_config();
+
+ auto func = [&](tensorflow::gtl::ArraySlice<int64> input_index) {
+ for (auto i = 0; i < input_index.size(); ++i) {
+ // Interior padding occurs logically before edge padding, so in the case
+ // of negative edge padding, elements are removed from the
+ // interior-padded operand.
+ target_index[i] =
+ pad_config.dimensions(i).edge_padding_low() +
+ input_index[i] * (pad_config.dimensions(i).interior_padding() + 1);
+
+ // Account for negative low and high padding: skip the assignment if
+ // any target index is out of range.
+ if (!(target_index[i] >= 0 &&
+ target_index[i] < pad->shape().dimensions(i))) {
+ return true;
+ }
+ }
+ result->Set<ReturnT>(target_index,
+ evaluated_operand.Get<ReturnT>(input_index));
+ return true;
+ };
+
+ std::vector<int64> zero_base(evaluated_operand.shape().dimensions_size(),
+ 0);
+ std::vector<int64> step(evaluated_operand.shape().dimensions_size(), 1);
+
+ ShapeUtil::ForEachIndex(
+ evaluated_operand.shape(), zero_base,
+ AsInt64Slice(evaluated_operand.shape().dimensions()), step, func);
+
+ parent_->evaluated_[pad] = std::move(result);
+ return Status::OK();
+ }
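
HandlePad places operand element input_index at target position edge_padding_low + input_index * (interior_padding + 1); with negative edge padding that target can fall outside the padded shape, in which case the element is simply dropped. A worked sketch with a hypothetical one-dimensional padding config:

    #include <cstdint>
    #include <iostream>

    int main() {
      // Hypothetical 1-D padding config: edge_padding_low = -1,
      // interior_padding = 2, padded dimension size = 9.
      const int64_t edge_padding_low = -1, interior_padding = 2, padded_size = 9;

      for (int64_t input_index = 0; input_index < 4; ++input_index) {
        // Same formula as HandlePad above.
        const int64_t target =
            edge_padding_low + input_index * (interior_padding + 1);
        const bool kept = target >= 0 && target < padded_size;
        std::cout << "input " << input_index << " -> target " << target
                  << (kept ? "" : " (dropped)") << "\n";
      }
      // input 0 -> target -1 (dropped); input 1 -> 2; input 2 -> 5; input 3 -> 8.
      return 0;
    }
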
+
+ Status HandleDynamicSlice(HloInstruction* dynamic_slice) override {
+ auto operand = dynamic_slice->operand(0);
+ auto start_indices = dynamic_slice->operand(1);
+ auto result_shape = dynamic_slice->shape();
+ TF_ASSIGN_OR_RETURN(auto inferred_return_shape,
+ ShapeInference::InferDynamicSliceShape(
+ operand->shape(), start_indices->shape(),
+ dynamic_slice->dynamic_slice_sizes()));
+ TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape))
+ << "return shape is set to: " << ShapeUtil::HumanString(result_shape)
+ << "but is inferred to be: "
+ << ShapeUtil::HumanString(inferred_return_shape);
+ TF_RET_CHECK(
+ primitive_util::IsIntegralType(start_indices->shape().element_type()));
+
+ const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand);
+ const Literal& start_indices_literal =
+ parent_->GetEvaluatedLiteralFor(start_indices);
+
+ switch (start_indices->shape().element_type()) {
+ case S32: {
+ TF_ASSIGN_OR_RETURN(
+ parent_->evaluated_[dynamic_slice],
+ DynamicSlice<int32>(operand_literal, start_indices_literal,
+ result_shape));
+ } break;
+ case S64: {
+ TF_ASSIGN_OR_RETURN(
+ parent_->evaluated_[dynamic_slice],
+ DynamicSlice<int64>(operand_literal, start_indices_literal,
+ result_shape));
+ } break;
+ case U32: {
+ TF_ASSIGN_OR_RETURN(
+ parent_->evaluated_[dynamic_slice],
+ DynamicSlice<uint32>(operand_literal, start_indices_literal,
+ result_shape));
+ } break;
+ case U64: {
+ TF_ASSIGN_OR_RETURN(
+ parent_->evaluated_[dynamic_slice],
+ DynamicSlice<uint64>(operand_literal, start_indices_literal,
+ result_shape));
+ } break;
+ default:
+ LOG(FATAL) << "HandleDynamicSlice: unhandled primitive type for "
+ "start_indices: "
+ << PrimitiveType_Name(start_indices->shape().element_type());
+ }
+
+ return Status::OK();
+ }
+
+ Status HandleDynamicUpdateSlice(
+ HloInstruction* dynamic_update_slice) override {
+ auto operand = dynamic_update_slice->operand(0);
+ auto update = dynamic_update_slice->operand(1);
+ auto start_indices = dynamic_update_slice->operand(2);
+ auto result_shape = dynamic_update_slice->shape();
+ TF_ASSIGN_OR_RETURN(
+ auto inferred_return_shape,
+ ShapeInference::InferDynamicUpdateSliceShape(
+ operand->shape(), update->shape(), start_indices->shape()));
+ TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape))
+ << "return shape is set to: " << ShapeUtil::HumanString(result_shape)
+ << "but is inferred to be: "
+ << ShapeUtil::HumanString(inferred_return_shape);
+ TF_RET_CHECK(
+ primitive_util::IsIntegralType(start_indices->shape().element_type()));
+ TF_RET_CHECK(ShapeUtil::Compatible(result_shape, operand->shape()));
+
+ const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand);
+ const Literal& update_literal = parent_->GetEvaluatedLiteralFor(update);
+ const Literal& start_indices_literal =
+ parent_->GetEvaluatedLiteralFor(start_indices);
+
+ switch (start_indices->shape().element_type()) {
+ case S32: {
+ TF_ASSIGN_OR_RETURN(
+ parent_->evaluated_[dynamic_update_slice],
+ DynamicUpdateSlice<int32>(operand_literal, update_literal,
+ start_indices_literal));
+ } break;
+ case S64: {
+ TF_ASSIGN_OR_RETURN(
+ parent_->evaluated_[dynamic_update_slice],
+ DynamicUpdateSlice<int64>(operand_literal, update_literal,
+ start_indices_literal));
+ } break;
+ case U32: {
+ TF_ASSIGN_OR_RETURN(
+ parent_->evaluated_[dynamic_update_slice],
+ DynamicUpdateSlice<uint32>(operand_literal, update_literal,
+ start_indices_literal));
+ } break;
+ case U64: {
+ TF_ASSIGN_OR_RETURN(
+ parent_->evaluated_[dynamic_update_slice],
+ DynamicUpdateSlice<uint64>(operand_literal, update_literal,
+ start_indices_literal));
+ } break;
+ default:
+ LOG(FATAL) << "HandleDynamicUpdateSlice: unhandled primitive type for "
+ "start_indices: "
+ << PrimitiveType_Name(start_indices->shape().element_type());
+ }
+
+ return Status::OK();
+ }
+
+ template <typename NativeT>
+ StatusOr<std::unique_ptr<Literal>> MapImpl(HloInstruction* map) {
+ auto operands = map->operands();
+ HloComputation* computation = map->to_apply();
+
+ auto result = Literal::CreateFromShape(map->shape());
+
+ HloEvaluator embedded_evaluator(parent_->max_loop_iterations_);
+ TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
+ [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
+ std::vector<std::unique_ptr<Literal>> arg_literals;
+ arg_literals.reserve(operands.size());
+
+ // Construct scalar literal parameters to be passed to the map
+ // computation.
+ for (auto operand : operands) {
+ const Literal& arg_literal =
+ parent_->GetEvaluatedLiteralFor(operand);
+
+ auto curr_val = arg_literal.Get<NativeT>(multi_index);
+ auto curr_val_literal = Literal::CreateR0<NativeT>(curr_val);
+
+ arg_literals.push_back(std::move(curr_val_literal));
+ }
+
+ std::unique_ptr<Literal> computed_result =
+ embedded_evaluator
+ .Evaluate<std::unique_ptr<Literal>>(*computation,
+ arg_literals)
+ .ConsumeValueOrDie();
+ // Clear visit states so that we can use the evaluator again on
+ // the same computation.
+ embedded_evaluator.ResetVisitStates();
+
+ return computed_result->Get<ReturnT>({});
+ }));
+ return std::move(result);
+ }
+
+ Status HandleMap(HloInstruction* map) override {
+ switch (map->operand(0)->shape().element_type()) {
+ case PRED: {
+ TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<bool>(map));
+ break;
+ }
+ case U8: {
+ TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<uint8>(map));
+ break;
+ }
+ case U32: {
+ TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<uint32>(map));
+ break;
+ }
+ case U64: {
+ TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<uint64>(map));
+ break;
+ }
+ case S8: {
+ TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<int8>(map));
+ break;
+ }
+ case S32: {
+ TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<int32>(map));
+ break;
+ }
+ case S64: {
+ TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<int64>(map));
+ break;
+ }
+ case F16: {
+ TF_ASSIGN_OR_RETURN(parent_->evaluated_[map],
+ MapImpl<Eigen::half>(map));
+ break;
+ }
+ case F32: {
+ TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<float>(map));
+ break;
+ }
+ case F64: {
+ TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<double>(map));
+ break;
+ }
+ case C64: {
+ TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<complex64>(map));
+ break;
+ }
+ default:
+ LOG(FATAL) << "HandleMap: unhandled primitive type for "
+ "input operand: "
+ << PrimitiveType_Name(
+ map->operand(0)->shape().element_type());
+ }
+
+ return Status::OK();
+ }
+
+ Status HandleReduce(HloInstruction* reduce) override {
+ auto arg = reduce->operand(0);
+ auto init_value = reduce->operand(1);
+ tensorflow::gtl::ArraySlice<int64> dimensions(reduce->dimensions());
+ HloComputation* function = reduce->to_apply();
+ TF_RET_CHECK(ShapeUtil::Rank(reduce->shape()) ==
+ ShapeUtil::Rank(arg->shape()) - dimensions.size());
+ TF_ASSIGN_OR_RETURN(auto inferred_return_shape,
+ ShapeInference::InferReduceShape(
+ /*arg=*/arg->shape(),
+ /*init_value=*/init_value->shape(),
+ /*dimensions_to_reduce=*/dimensions,
+ /*to_apply=*/function->ComputeProgramShape()));
+ TF_RET_CHECK(ShapeUtil::Compatible(reduce->shape(), inferred_return_shape))
+ << "return shape is set to: " << ShapeUtil::HumanString(reduce->shape())
+ << "but is inferred to be: "
+ << ShapeUtil::HumanString(inferred_return_shape);
+
+ const Literal& arg_literal = parent_->GetEvaluatedLiteralFor(arg);
+ VLOG(3) << "HandleReduce arg_literal: " << arg_literal.ToString();
+ const Literal& init_literal = parent_->GetEvaluatedLiteralFor(init_value);
+ VLOG(3) << "HandleReduce init_literal: " << init_literal.ToString();
+ TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape()));
+ auto init_scalar = init_literal.Get<ReturnT>({});
+
+ auto result = Literal::CreateFromShape(reduce->shape());
+
+ const auto arg_dimensions = AsInt64Slice(arg_literal.shape().dimensions());
+ std::vector<int64> arg_dim_steps(arg_dimensions.size());
+ std::vector<int64> arg_dim_counts(arg_dimensions.size());
+ for (const int64 dim : dimensions) {
+ arg_dim_steps[dim] = 1;
+ arg_dim_counts[dim] = arg_dimensions[dim];
+ }
+
+ // Map each dimension in the result to a dimension in arg that isn't
+ // being reduced.
+ std::vector<int64> result_to_arg_index;
+ for (int64 i = 0; i < arg_dimensions.size(); ++i) {
+ if (arg_dim_steps[i] == 0) {
+ result_to_arg_index.push_back(i);
+ }
+ }
+
+ HloEvaluator embedded_evaluator(parent_->max_loop_iterations_);
+ // For each element of the result, calculate and assign the reduced value.
+ TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
+ [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
+ ReturnT result_val = init_scalar;
+
+ std::vector<int64> base(arg_dimensions.size());
+ for (int64 i = 0; i < multi_index.size(); ++i) {
+ base[result_to_arg_index[i]] = multi_index[i];
+ }
+
+ // When the reduction is addition of floats, accumulate in a double
+ // for better precision. Also, avoid creating Literals for the
+ // intermediate results; it's much faster.
+ if (ShapeUtil::ElementIsFloating(init_literal.shape()) &&
+ IsScalarAdd(function)) {
+ double computed_result = 0;
+ auto func = [&](tensorflow::gtl::ArraySlice<int64> input_index) {
+ computed_result += arg_literal.Get<float>(input_index);
+ return true;
+ };
+ ShapeUtil::ForEachIndex(arg_literal.shape(), base, arg_dim_counts,
+ arg_dim_steps, func);
+ return static_cast<ReturnT>(computed_result);
+ }
+ auto func = [&](tensorflow::gtl::ArraySlice<int64> input_index) {
+ auto curr_val = arg_literal.Get<ReturnT>(input_index);
+
+ // Evaluate computation with specified literal operands.
+ auto curr_val_literal = Literal::CreateR0<ReturnT>(curr_val);
+ auto result_val_literal = Literal::CreateR0<ReturnT>(result_val);
+ std::vector<const Literal*> args = {result_val_literal.get(),
+ curr_val_literal.get()};
+
+ std::unique_ptr<Literal> computed_result =
+ embedded_evaluator.Evaluate<const Literal*>(*function, args)
+ .ConsumeValueOrDie();
+ // Clear visit states so that we can use the evaluator again on
+ // the same computation.
+ embedded_evaluator.ResetVisitStates();
+ // Assign computed result to result_val.
+ result_val = computed_result->Get<ReturnT>({});
+ return true;
+ };
+ // Computes one element of the result, reducing all dimensions that
+ // contribute to that element.
+ ShapeUtil::ForEachIndex(arg_literal.shape(), base, arg_dim_counts,
+ arg_dim_steps, func);
+ return result_val;
+ }));
+
+ parent_->evaluated_[reduce] = std::move(result);
+ return Status::OK();
+ }
+
+ bool IsScalarAdd(HloComputation* computation) {
+ HloInstruction* instruction = computation->root_instruction();
+ if (instruction->opcode() == HloOpcode::kAdd &&
+ computation->num_parameters() == 2) {
+ const HloInstruction* lhs = instruction->operand(0);
+ const HloInstruction* rhs = instruction->operand(1);
+ return lhs->opcode() == HloOpcode::kParameter &&
+ ShapeUtil::IsScalar(lhs->shape()) &&
+ rhs->opcode() == HloOpcode::kParameter &&
+ ShapeUtil::IsScalar(rhs->shape()) && lhs != rhs;
+ }
+ return false;
+ }
+
+ Status HandleSelectAndScatter(HloInstruction* select_and_scatter) override {
+ auto operand = select_and_scatter->operand(0);
+ auto source = select_and_scatter->operand(1);
+ const Window& window = select_and_scatter->window();
+
+ const Literal& init_literal =
+ parent_->GetEvaluatedLiteralFor(select_and_scatter->operand(2));
+ TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape()));
+ auto init_scalar = init_literal.Get<ReturnT>({});
+
+ auto result = Literal::CreateFromShape(select_and_scatter->shape());
+
+ // Initialize result array with the init value.
+ TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
+ [&](tensorflow::gtl::ArraySlice<int64> output_index) {
+ return init_scalar;
+ }));
+
+ std::vector<int64> window_dimension_sizes;
+ for (const auto& window_dimension : window.dimensions()) {
+ window_dimension_sizes.push_back(window_dimension.size());
+ }
+ const Shape window_shape = ShapeUtil::MakeShape(
+ operand->shape().element_type(), window_dimension_sizes);
+
+ HloComputation* select = select_and_scatter->select();
+ HloComputation* scatter = select_and_scatter->scatter();
+
+ const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand);
+ const Literal& source_literal = parent_->GetEvaluatedLiteralFor(source);
+
+ int64 rank = ShapeUtil::Rank(operand_literal.shape());
+
+ HloEvaluator embedded_evaluator(parent_->max_loop_iterations_);
+ DimensionVector source_index(rank);
+
+ std::fill(source_index.begin(), source_index.end(), 0);
+ do {
+ // For each element in `source`, we place a window in `operand`. For each
+ // window placement, we iterate inside the window twice:
+ //
+ // 1. Find the selected index by applying the `select` function to all
+ // elements. E.g., if the `select` function is GreaterEqual, the first
+ // iteration through the window finds the biggest value and returns its
+ // index.
+ //
+ // 2. Using the selected index, scatter the value from `source` to the
+ // result. We do this by iterating through the window and comparing each
+ // index with the selected index.
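+ //
+ // Illustrative example (assumed values, not from this change): for a 1-D
+ // operand {1, 9, 3, 7}, a window of size 2 with stride 2, a GreaterEqual
+ // `select`, an Add `scatter`, source {10, 20}, and init 0, the first
+ // window selects index 1 (value 9), the second selects index 3 (value 7),
+ // and the result is {0, 10, 0, 20}.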
+ tensorflow::gtl::optional<ReturnT> selected_val;
+ tensorflow::gtl::optional<std::vector<int64>> selected_index;
+
+ IterateThroughWindow(
+ window_shape, window, operand_literal.shape(), source_index,
+ [&](const std::vector<int64>& operand_index) {
+ auto curr_val = operand_literal.Get<ReturnT>(operand_index);
+ if (!selected_val) {
+ selected_val = curr_val;
+ selected_index = operand_index;
+ }
+ const auto curr_val_literal = Literal::CreateR0<ReturnT>(curr_val);
+ const auto selected_val_literal =
+ Literal::CreateR0<ReturnT>(*selected_val);
+
+ const std::vector<const Literal*> args = {
+ selected_val_literal.get(), curr_val_literal.get()};
+ std::unique_ptr<Literal> computed_result =
+ embedded_evaluator.Evaluate<const Literal*>(*select, args)
+ .ConsumeValueOrDie();
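+ // The `select` computation returns true when the previously selected
+ // value should be kept, so a false result means the current element
+ // becomes the new selection.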
+ bool selected = !computed_result->Get<bool>({});
+ if (selected) {
+ selected_val = curr_val;
+ selected_index = operand_index;
+ }
+ embedded_evaluator.ResetVisitStates();
+ });
+
+ IterateThroughWindow(
+ window_shape, window, operand_literal.shape(), source_index,
+ [&](const std::vector<int64>& operand_index) {
+ if (std::equal(operand_index.begin(), operand_index.end(),
+ selected_index->begin())) {
+ auto source = source_literal.Get<ReturnT>(source_index);
+ auto scattered = result->Get<ReturnT>(operand_index);
+ const auto source_literal = Literal::CreateR0<ReturnT>(source);
+ const auto scattered_literal =
+ Literal::CreateR0<ReturnT>(scattered);
+
+ const std::vector<const Literal*> args = {
+ source_literal.get(), scattered_literal.get()};
+ std::unique_ptr<Literal> computed_result =
+ embedded_evaluator.Evaluate<const Literal*>(*scatter, args)
+ .ConsumeValueOrDie();
+ result->Set(operand_index, computed_result->Get<ReturnT>({}));
+ // Clear visit states so that we can use the evaluator again
+ // on the same computation.
+ embedded_evaluator.ResetVisitStates();
+ }
+ });
+ } while (IndexUtil::BumpIndices(source->shape(), &source_index));
+
+ parent_->evaluated_[select_and_scatter] = std::move(result);
+ return Status::OK();
+ }
+
+ Status HandleReduceWindow(HloInstruction* reduce_window) override {
+ auto operand = reduce_window->operand(0);
+ const Window& window = reduce_window->window();
+ HloComputation* function = reduce_window->to_apply();
+ TF_ASSIGN_OR_RETURN(
+ auto inferred_return_shape,
+ ShapeInference::InferReduceWindowShape(
+ /*operand_shape=*/reduce_window->operand(0)->shape(),
+ /*init_value=*/reduce_window->operand(1)->shape(), window,
+ /*to_apply_shape=*/function->ComputeProgramShape()));
+ TF_RET_CHECK(
+ ShapeUtil::Compatible(reduce_window->shape(), inferred_return_shape))
+ << "return shape is set to: "
+ << ShapeUtil::HumanStringWithLayout(reduce_window->shape())
+ << " but is inferred to be: "
+ << ShapeUtil::HumanStringWithLayout(inferred_return_shape);
+
+ const Literal& operand_literal =
+ parent_->GetEvaluatedLiteralFor(reduce_window->operand(0));
+ VLOG(3) << "HandleReduceWindow arg_literal: " << operand_literal.ToString();
+ const Literal& init_literal =
+ parent_->GetEvaluatedLiteralFor(reduce_window->operand(1));
+ VLOG(3) << "HandleReduceWindow init_literal: " << init_literal.ToString();
+ TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape()));
+ auto init_scalar = init_literal.Get<ReturnT>({});
+
+ auto result = Literal::CreateFromShape(reduce_window->shape());
+
+ // Creates a Shape object from window, for iteration below.
+ std::vector<int64> window_dimension_sizes;
+ for (const auto& window_dimension : window.dimensions()) {
+ window_dimension_sizes.push_back(window_dimension.size());
+ }
+ const Shape window_shape = ShapeUtil::MakeShape(
+ operand->shape().element_type(), window_dimension_sizes);
+
+ DimensionVector window_index(window.dimensions_size());
+ DimensionVector operand_index(ShapeUtil::Rank(operand_literal.shape()));
+
+ HloEvaluator embedded_evaluator(parent_->max_loop_iterations_);
+ // For each resulting dimension, calculate and assign computed value.
+ TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
+ [&](tensorflow::gtl::ArraySlice<int64> output_index) {
+ ReturnT result_val = init_scalar;
+
+ std::fill(window_index.begin(), window_index.end(), 0);
+ std::fill(operand_index.begin(), operand_index.end(), 0);
+
+ IterateThroughWindow(
+ window_shape, window, operand_literal.shape(), output_index,
+ [&](const std::vector<int64>& operand_index) {
+ auto curr_val = operand_literal.Get<ReturnT>(operand_index);
+
+ // Evaluate computation with specified literal operands.
+ const auto curr_val_literal =
+ Literal::CreateR0<ReturnT>(curr_val);
+ const auto result_val_literal =
+ Literal::CreateR0<ReturnT>(result_val);
+ const std::vector<const Literal*> args = {
+ result_val_literal.get(), curr_val_literal.get()};
+ std::unique_ptr<Literal> computed_result =
+ embedded_evaluator.Evaluate<const Literal*>(*function, args)
+ .ConsumeValueOrDie();
+
+ // Clear visit states so that we can use the evaluator again
+ // on the same computation.
+ embedded_evaluator.ResetVisitStates();
+
+ result_val = computed_result->Get<ReturnT>({});
+ });
+
+ return result_val;
+ }));
+
+ parent_->evaluated_[reduce_window] = std::move(result);
+ return Status::OK();
+ }
+
+ Status HandleSlice(HloInstruction* slice) override {
+ auto operand = slice->operand(0);
+ const Shape& shape = slice->shape();
+ TF_ASSIGN_OR_RETURN(auto inferred_return_shape,
+ ShapeInference::InferSliceShape(
+ operand->shape(), slice->slice_starts(),
+ slice->slice_limits(), slice->slice_strides()));
+ TF_RET_CHECK(ShapeUtil::Compatible(shape, inferred_return_shape))
+ << "return shape set to: " << ShapeUtil::HumanString(shape)
+ << " but is inferred to be: "
+ << ShapeUtil::HumanString(inferred_return_shape);
+
+ const int64 rank = ShapeUtil::Rank(operand->shape());
+ const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand);
+ auto func = [&](tensorflow::gtl::ArraySlice<int64> out_index) {
+ DimensionVector operand_index(rank);
+ for (int64 i = 0; i < rank; ++i) {
+ operand_index[i] =
+ slice->slice_starts(i) + out_index[i] * slice->slice_strides(i);
+ }
+ return operand_literal.Get<ReturnT>(operand_index);
+ };
+
+ auto result = Literal::CreateFromDimensions(
+ shape.element_type(), AsInt64Slice(shape.dimensions()));
+ TF_RETURN_IF_ERROR(result->Populate<ReturnT>(func));
+ parent_->evaluated_[slice] = std::move(result);
+ return Status::OK();
+ }
+
+ // Enable CLZ only for int32 and uint32.
+ template <
+ typename NativeT,
+ typename std::enable_if<
+ (std::is_floating_point<NativeT>::value ||
+ std::is_integral<NativeT>::value || is_complex_t<NativeT>::value) &&
+ !(std::is_same<NativeT, uint32>::value ||
+ std::is_same<NativeT, int32>::value)>::type* = nullptr>
+ Status HandleClz(HloInstruction* clz) {
+ return InvalidArgument("Unsupported type for Clz");
+ }
+
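+ // Illustrative examples: Clz(uint32{1}) == 31 and Clz(uint32{1u << 31}) == 0,
+ // since Log2Floor returns the bit index of the highest set bit.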
+ template <typename NativeT,
+ typename std::enable_if<
+ std::is_same<NativeT, uint32>::value ||
+ std::is_same<NativeT, int32>::value>::type* = nullptr>
+ Status HandleClz(HloInstruction* clz) {
+ TF_ASSIGN_OR_RETURN(parent_->evaluated_[clz],
+ ElementWiseUnaryOp(clz, [](ElementwiseT elem_operand) {
+ return 31 - tensorflow::Log2Floor(elem_operand);
+ }));
+ return Status::OK();
+ }
+
+ Status HandleClz(HloInstruction* clz) override {
+ return HandleClz<ElementwiseT>(clz);
+ }
+
+ template <typename NativeT, typename std::enable_if<std::is_floating_point<
+ NativeT>::value>::type* = nullptr>
+ Status HandleSin(HloInstruction* sin) {
+ TF_ASSIGN_OR_RETURN(parent_->evaluated_[sin],
+ ElementWiseUnaryOp(sin, [](ElementwiseT elem_operand) {
+ return std::sin(elem_operand);
+ }));
+ return Status::OK();
+ }
+
+ template <
+ typename NativeT,
+ typename std::enable_if<std::is_integral<NativeT>::value ||
+ is_complex_t<NativeT>::value>::type* = nullptr>
+ Status HandleSin(HloInstruction* sin) {
+ return InvalidArgument("Unsupported type for Sin");
+ }
+
+ Status HandleSin(HloInstruction* sin) override {
+ return HandleSin<ElementwiseT>(sin);
+ }
+
+ template <typename NativeT, typename std::enable_if<std::is_floating_point<
+ NativeT>::value>::type* = nullptr>
+ Status HandleCos(HloInstruction* cos) {
+ TF_ASSIGN_OR_RETURN(parent_->evaluated_[cos],
+ ElementWiseUnaryOp(cos, [](ElementwiseT elem_operand) {
+ return std::cos(elem_operand);
+ }));
+ return Status::OK();
+ }
+
+ template <
+ typename NativeT,
+ typename std::enable_if<std::is_integral<NativeT>::value ||
+ is_complex_t<NativeT>::value>::type* = nullptr>
+ Status HandleCos(HloInstruction* cos) {
+ return InvalidArgument("Unsupported type for Cos");
+ }
+
+ Status HandleCos(HloInstruction* cos) override {
+ return HandleCos<ElementwiseT>(cos);
+ }
+
+ template <typename NativeT, typename std::enable_if<std::is_same<
+ float, NativeT>::value>::type* = nullptr>
+ Status HandleReducePrecision(HloInstruction* reduce_precision) {
+ TF_ASSIGN_OR_RETURN(
+ parent_->evaluated_[reduce_precision],
+ ElementWiseUnaryOp(reduce_precision, [reduce_precision](
+ ElementwiseT elem) {
+ uint32_t value_as_int = tensorflow::bit_cast<uint32_t>(elem);
+ const uint32_t mantissa_bits = reduce_precision->mantissa_bits();
+ const uint32_t exponent_bits = reduce_precision->exponent_bits();
+
+ // Code is based on the CPU/GPU implementation in LLVM-emitting code.
+ //
+ // Bits in float type:
+ // mantissa : bits [0:22]
+ // exponent : bits [23:30]
+ // sign : bits [31]
+ if (mantissa_bits < 23) {
+ const uint32_t last_mantissa_bit_mask = 1u << (23 - mantissa_bits);
+
+ // Compute rounding bias for round-to-nearest with ties to even.
+ // This is equal to a base value of 0111... plus one bit if the last
+ // remaining mantissa bit is 1.
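+ //
+ // For example (illustrative), with mantissa_bits = 10 the last kept
+ // mantissa bit mask is 1u << 13, so base_rounding_bias is 0x0FFF and a
+ // value whose last kept mantissa bit is 1 gets a bias of 0x1000.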
+ const uint32_t base_rounding_bias =
+ (last_mantissa_bit_mask >> 1) - 1;
+ const uint32_t x_last_mantissa_bit =
+ (value_as_int & last_mantissa_bit_mask) >> (23 - mantissa_bits);
+ const uint32_t x_rounding_bias =
+ x_last_mantissa_bit + base_rounding_bias;
+
+ // Add rounding bias, and mask out truncated bits. Note that the
+ // case where adding the rounding bias overflows into the exponent
+ // bits is correct; the non-masked mantissa bits will all be zero,
+ // and the exponent will be incremented by one.
+ const uint32_t truncation_mask = ~(last_mantissa_bit_mask - 1);
+ value_as_int = value_as_int + x_rounding_bias;
+ value_as_int = value_as_int & truncation_mask;
+ }
+ if (exponent_bits < 8) {
+ // Masks for f32 values.
+ const uint32_t f32_sign_bit_mask = 1u << 31;
+ const uint32_t f32_exp_bits_mask = 0xffu << 23;
+
+ // An exponent of 2^(n-1)-1 -- that is, 0111... with the zero in the
+ // most-significant bit -- is equal to 1.0f for all exponent sizes.
+ // Adding 2^(n-1)-1 to this gives us the highest non-infinite
+ // exponent for a bit size of n, and subtracting 2^(n-1)-1 from
+ // this gives us the lowest exponent (corresponding to 0.0f).
+ //
+ // Thus, the f32 exponent corresponding to the highest non-infinite
+ // exponent for a bit size of n is (2^7-1) + 2^(n-1)-1, and the f32
+ // exponent corresponding to the lowest exponent for a bit size of n
+ // is (2^7-1) - 2^(n-1)-1.
+ //
+ // Note that we have already checked that exponent_bits >= 1.
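+ //
+ // For example (illustrative), with exponent_bits = 5 the reduced
+ // exponent bias is 15, so reduced_max_exponent = 127 + 15 = 142 and
+ // reduced_min_exponent = 127 - 15 = 112.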
+ const uint32_t f32_exponent_bias = (1 << 7) - 1;
+ const uint32_t reduced_exponent_bias =
+ (1 << (exponent_bits - 1)) - 1;
+ const uint32_t reduced_max_exponent =
+ f32_exponent_bias + reduced_exponent_bias;
+ const uint32_t reduced_min_exponent =
+ f32_exponent_bias - reduced_exponent_bias;
+
+ // Do we overflow or underflow?
+ const uint32_t x_exponent = value_as_int & f32_exp_bits_mask;
+ const bool x_overflows = x_exponent > (reduced_max_exponent << 23);
+ const bool x_underflows =
+ x_exponent <= (reduced_min_exponent << 23);
+
+ // Compute appropriately-signed values of zero and infinity.
+ const uint32_t x_signed_zero = value_as_int & f32_sign_bit_mask;
+ const uint32_t x_signed_inf = x_signed_zero | f32_exp_bits_mask;
+
+ // Force to zero or infinity if overflow or underflow. (Note that
+ // this truncates all denormal values to zero, rather than rounding
+ // them.)
+ value_as_int = x_overflows ? x_signed_inf : value_as_int;
+ value_as_int = x_underflows ? x_signed_zero : value_as_int;
+ }
+
+ float reduced_result = tensorflow::bit_cast<float>(value_as_int);
+ if (std::isnan(elem)) {
+ reduced_result = mantissa_bits > 0
+ ? elem
+ : std::numeric_limits<float>::infinity();
+ }
+ return reduced_result;
+ }));
+ return Status::OK();
+ }
+
+ template <typename NativeT, typename std::enable_if<std::is_same<
+ double, NativeT>::value>::type* = nullptr>
+ Status HandleReducePrecision(HloInstruction* reduce_precision) {
+ return InvalidArgument("Double not supported for reduce precision");
+ }
+
+ template <
+ typename NativeT,
+ typename std::enable_if<std::is_integral<NativeT>::value ||
+ is_complex_t<NativeT>::value>::type* = nullptr>
+ Status HandleReducePrecision(HloInstruction* reduce_precision) {
+ return InvalidArgument("Unsupported type for reduce precision");
+ }
+
+ Status HandleReducePrecision(HloInstruction* reduce_precision) override {
+ return HandleReducePrecision<ElementwiseT>(reduce_precision);
+ }
+
+ private:
+ // Creates a vector of multipliers which can be used to create a linear index
+ // into shape.
+ //
+ // Given the multidimensional index {i1, ..., iN} and
+ // M = MakeDimMultipliers(shape), the corresponding linear index LI is simply
+ //
+ // LI = i1 * M[1] + i2 * M[2] + ... + iN * M[N].
+ //
+ // This lets you calculate LI given the multidimensional indices in any order.
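+ //
+ // For example (illustrative), for a {2, 3} shape with a row-major layout
+ // the multipliers are {3, 1}, so the index {1, 2} maps to
+ // LI = 1 * 3 + 2 * 1 = 5.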
+ static DimensionVector MakeDimMultipliers(const Shape& shape) {
+ DimensionVector v(ShapeUtil::Rank(shape));
+ int64 scale = 1;
+ for (auto dim : LayoutUtil::MinorToMajor(shape)) {
+ v[dim] = scale;
+ scale *= shape.dimensions(dim);
+ }
+ return v;
+ }
+
+ // For one particular placement of a window in a base shape (the placement is
+ // represented as `window_count_index`), iterates inside the window.
+ // Translates the window index into a base index. If the base index is
+ // within bounds, calls `f` with the base index.
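+ //
+ // For example (illustrative), for a 1-D base shape of size 4 and a window
+ // of size 2 with stride 2 and no padding, window_count_index {1} calls `f`
+ // with base indices {2} and {3}.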
+ static void IterateThroughWindow(
+ const Shape& window_shape, const Window& window, const Shape& base_shape,
+ const tensorflow::gtl::ArraySlice<int64>& window_count_index,
+ const std::function<void(const std::vector<int64>&)>& f) {
+ const int64 rank = ShapeUtil::Rank(base_shape);
+ DimensionVector window_index(rank);
+ std::fill(window_index.begin(), window_index.end(), 0);
+ do {
+ std::vector<int64> base_index(rank);
+ bool out_of_bound = false;
+ for (int64 i = 0; i < rank; ++i) {
+ base_index[i] = window_count_index[i] * window.dimensions(i).stride() +
+ window_index[i] - window.dimensions(i).padding_low();
+ if (base_index[i] < 0 || base_index[i] >= base_shape.dimensions(i)) {
+ out_of_bound = true;
+ break;
+ }
+ }
+ if (!out_of_bound) {
+ f(base_index);
+ }
+ } while (IndexUtil::BumpIndices(window_shape, &window_index));
+ }
+
+ template <typename IndexT>
+ StatusOr<std::unique_ptr<Literal>> DynamicSlice(
+ const Literal& operand_literal, const Literal& start_indices_literal,
+ const Shape& result_shape) {
+ auto start_indices_typed = start_indices_literal.data<IndexT>();
+ std::vector<int64> start(start_indices_typed.begin(),
+ start_indices_typed.end());
+
+ std::vector<int64> operand_indices(start.size());
+
+ auto result = Literal::CreateFromShape(result_shape);
+ TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
+ [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
+ for (int64 i = 0; i < operand_indices.size(); ++i) {
+ CHECK_GE(multi_index[i] + start[i], 0);
+ // Mod is only used here to be consistent with the existing
+ // backends' behavior.
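+ // For example (illustrative), with an operand dimension of size 4,
+ // start index 3, and output index 2, the wrapped operand index is
+ // (2 + 3) % 4 = 1.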
+ operand_indices[i] = (multi_index[i] + start[i]) %
+ operand_literal.shape().dimensions(i);
+ }
+
+ auto result = operand_literal.Get<ReturnT>(operand_indices);
+ return result;
+ }));
+
+ return std::move(result);
+ }
+
+ template <typename IndexT>
+ StatusOr<std::unique_ptr<Literal>> DynamicUpdateSlice(
+ const Literal& operand_literal, const Literal& update_literal,
+ const Literal& start_indices_literal) {
+ auto result = operand_literal.CloneToUnique();
+ auto start_indices_typed = start_indices_literal.data<IndexT>();
+ const auto rank = ShapeUtil::Rank(result->shape());
+ std::vector<int64> start(rank, 0);
+ for (int64 i = 0; i < rank; ++i) {
+ // All other implementations currently wrap-around the index, so this
+ // should do so as well.
+ start[i] = (start_indices_typed[i] % result->shape().dimensions(i));
+ start[i] += (start[i] < 0) * result->shape().dimensions(i);
+ }
+ std::vector<int64> result_index(rank, 0);
+
+ auto func = [&](tensorflow::gtl::ArraySlice<int64> update_index) {
+ std::transform(update_index.begin(), update_index.end(), start.begin(),
+ result_index.begin(), std::plus<int64>());
+ // Same as above, wrap-around only to match other implementations'
+ // semantics.
+ std::transform(result_index.begin(), result_index.end(),
+ result->shape().dimensions().begin(), result_index.begin(),
+ std::modulus<int64>());
+ result->Set<ReturnT>(result_index,
+ update_literal.Get<ReturnT>(update_index));
+ return true;
+ };
+
+ std::vector<int64> base(update_literal.shape().dimensions_size(), 0);
+ std::vector<int64> step(update_literal.shape().dimensions_size(), 1);
+ ShapeUtil::ForEachIndex(update_literal.shape(), base,
+ AsInt64Slice(update_literal.shape().dimensions()),
+ step, func);
+
+ return std::move(result);
+ }
+
+ StatusOr<std::unique_ptr<Literal>> ElementWiseUnaryOp(
+ HloInstruction* instruction,
+ const std::function<ElementwiseT(ElementwiseT)>& unary_op) {
+ const Literal& operand_literal =
+ parent_->GetEvaluatedLiteralFor(instruction->operand(0));
+ TF_ASSIGN_OR_RETURN(
+ auto result_literal,
+ (HloEvaluator::ElementWiseUnaryOpImpl<ReturnT, ReturnT>(
+ instruction, ConvertUnaryFunction(unary_op), operand_literal)));
+
+ return std::move(result_literal);
+ }
+
+ StatusOr<std::unique_ptr<Literal>> ElementWiseBinaryOp(
+ HloInstruction* instruction,
+ const std::function<ElementwiseT(ElementwiseT, ElementwiseT)>&
+ binary_op) {
+ const auto shape = instruction->shape();
+ const auto* lhs = instruction->operand(0);
+ const auto* rhs = instruction->operand(1);
+
+ // TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast
+ // is removed.
+ if (!(ShapeUtil::SameDimensions(shape, rhs->shape()) &&
+ ShapeUtil::SameDimensions(lhs->shape(), rhs->shape()))) {
+ return Unimplemented(
+ "Implicit broadcasting is currently unsupported in HLO evaluator. "
+ "Shape Mismatch: %s vs %s vs %s",
+ ShapeUtil::HumanString(shape).c_str(),
+ ShapeUtil::HumanString(lhs->shape()).c_str(),
+ ShapeUtil::HumanString(rhs->shape()).c_str());
+ }
+
+ const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs);
+ const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs);
+
+ auto result = Literal::CreateFromShape(shape);
+
+ TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
+ [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
+ return ConvertBinaryFunction(binary_op)(
+ lhs_literal.Get<ReturnT>(multi_index),
+ rhs_literal.Get<ReturnT>(multi_index));
+ }));
+ return std::move(result);
+ }
+
+ template <typename LhsType, typename RhsType, typename EhsType>
+ StatusOr<std::unique_ptr<Literal>> ElementwiseTernaryOp(
+ HloInstruction* instruction,
+ const std::function<ReturnT(LhsType, RhsType, EhsType)>& ternary_op) {
+ const auto shape = instruction->shape();
+ const auto* lhs = instruction->operand(0);
+ const auto* rhs = instruction->operand(1);
+ const auto* ehs = instruction->operand(2);
+
+ // TODO(b/35950897, b/27796129): add DCHECK back once implicit
+ // broadcast is removed.
+ if (!(ShapeUtil::SameDimensions(shape, lhs->shape()) &&
+ ShapeUtil::SameDimensions(lhs->shape(), rhs->shape()) &&
+ ShapeUtil::SameDimensions(rhs->shape(), ehs->shape()))) {
+ return Unimplemented(
+ "Implicit broadcasting is currently unsupported in HLO evaluator. "
+ "Shape Mismatch: %s vs %s vs %s vs %s",
+ ShapeUtil::HumanString(shape).c_str(),
+ ShapeUtil::HumanString(lhs->shape()).c_str(),
+ ShapeUtil::HumanString(rhs->shape()).c_str(),
+ ShapeUtil::HumanString(ehs->shape()).c_str());
+ }
+
+ const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs);
+ const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs);
+ const Literal& ehs_literal = parent_->GetEvaluatedLiteralFor(ehs);
+
+ auto result = Literal::CreateFromShape(shape);
+
+ TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
+ [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
+ return ternary_op(lhs_literal.Get<LhsType>(multi_index),
+ rhs_literal.Get<RhsType>(multi_index),
+ ehs_literal.Get<EhsType>(multi_index));
+ }));
+
+ return std::move(result);
+ }
+
+ template <typename NativeT>
+ static bool IsShiftOutOfBounds(NativeT rhs) {
+ typedef typename std::make_unsigned<NativeT>::type UnsignedT;
+ UnsignedT lhs_size_unsigned = sizeof(NativeT) * CHAR_BIT;
+ UnsignedT rhs_unsigned = static_cast<UnsignedT>(rhs);
+ return rhs_unsigned >= lhs_size_unsigned;
+ }
+
+ HloEvaluator* parent_;
+};
+
+// These extern templates prevent users of this class from implicitly
+// instantiating it. We explicitly instantiate this class in the various
+// hlo_evaluator_typed_visitor*.cc files.
+extern template class HloEvaluatorTypedVisitor<bool>;
+extern template class HloEvaluatorTypedVisitor<uint8>;
+extern template class HloEvaluatorTypedVisitor<uint32>;
+extern template class HloEvaluatorTypedVisitor<uint64>;
+extern template class HloEvaluatorTypedVisitor<int8>;
+extern template class HloEvaluatorTypedVisitor<int32>;
+extern template class HloEvaluatorTypedVisitor<int64>;
+extern template class HloEvaluatorTypedVisitor<Eigen::half, float>;
+extern template class HloEvaluatorTypedVisitor<float>;
+extern template class HloEvaluatorTypedVisitor<double>;
+extern template class HloEvaluatorTypedVisitor<complex64>;
+extern template class HloEvaluatorTypedVisitor<bfloat16, float>;
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_bfloat16.cc b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_bfloat16.cc
new file mode 100644
index 0000000000..39c352dfb9
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_bfloat16.cc
@@ -0,0 +1,22 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h"
+
+#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
+
+namespace xla {
+template class HloEvaluatorTypedVisitor<bfloat16, float>;
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_bool.cc b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_bool.cc
new file mode 100644
index 0000000000..289b40fa06
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_bool.cc
@@ -0,0 +1,22 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h"
+
+#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
+
+namespace xla {
+template class HloEvaluatorTypedVisitor<bool>;
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_complex64.cc b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_complex64.cc
new file mode 100644
index 0000000000..9cb4eb921f
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_complex64.cc
@@ -0,0 +1,22 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h"
+
+#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
+
+namespace xla {
+template class HloEvaluatorTypedVisitor<complex64>;
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_double.cc b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_double.cc
new file mode 100644
index 0000000000..5e6252fbf8
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_double.cc
@@ -0,0 +1,22 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h"
+
+#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
+
+namespace xla {
+template class HloEvaluatorTypedVisitor<double>;
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_float.cc b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_float.cc
new file mode 100644
index 0000000000..ee793ae77b
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_float.cc
@@ -0,0 +1,22 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h"
+
+#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
+
+namespace xla {
+template class HloEvaluatorTypedVisitor<float>;
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_half.cc b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_half.cc
new file mode 100644
index 0000000000..038d9d39e4
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_half.cc
@@ -0,0 +1,22 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h"
+
+#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
+
+namespace xla {
+template class HloEvaluatorTypedVisitor<Eigen::half, float>;
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_int32.cc b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_int32.cc
new file mode 100644
index 0000000000..b1952ca619
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_int32.cc
@@ -0,0 +1,22 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h"
+
+#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
+
+namespace xla {
+template class HloEvaluatorTypedVisitor<int32>;
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_int64.cc b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_int64.cc
new file mode 100644
index 0000000000..0cbaffb40b
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_int64.cc
@@ -0,0 +1,22 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h"
+
+#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
+
+namespace xla {
+template class HloEvaluatorTypedVisitor<int64>;
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_int8.cc b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_int8.cc
new file mode 100644
index 0000000000..6f4bf2a392
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_int8.cc
@@ -0,0 +1,22 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h"
+
+#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
+
+namespace xla {
+template class HloEvaluatorTypedVisitor<int8>;
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_uint32.cc b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_uint32.cc
new file mode 100644
index 0000000000..10235447e0
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_uint32.cc
@@ -0,0 +1,22 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h"
+
+#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
+
+namespace xla {
+template class HloEvaluatorTypedVisitor<uint32>;
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_uint64.cc b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_uint64.cc
new file mode 100644
index 0000000000..8abeaa6ffc
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_uint64.cc
@@ -0,0 +1,22 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h"
+
+#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
+
+namespace xla {
+template class HloEvaluatorTypedVisitor<uint64>;
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_uint8.cc b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_uint8.cc
new file mode 100644
index 0000000000..6dabd1c176
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_uint8.cc
@@ -0,0 +1,22 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h"
+
+#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
+
+namespace xla {
+template class HloEvaluatorTypedVisitor<uint8>;
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
index bb4db89f0a..b6b0387672 100644
--- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
+++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
@@ -322,11 +322,13 @@ class HloDotDumper {
public:
HloDotDumper(const HloComputation* computation, tensorflow::StringPiece label,
const DebugOptions& debug_options, bool show_metadata,
- const HloExecutionProfile* profile, NodeFilter filter)
+ bool show_backend_config, const HloExecutionProfile* profile,
+ NodeFilter filter)
: computation_(computation),
- label_(label.ToString()),
+ label_(std::string(label)),
debug_options_(debug_options),
show_metadata_(show_metadata),
+ show_backend_config_(show_backend_config),
profile_(profile),
filter_(std::move(filter)) {}
@@ -365,6 +367,7 @@ class HloDotDumper {
string GetInstructionNodeShape(const HloInstruction* instr);
string GetInstructionNodeLabel(const HloInstruction* instr);
string GetInstructionNodeMetadata(const HloInstruction* instr);
+ string GetInstructionNodeBackendConfig(const HloInstruction* instr);
string GetInstructionNodeExtraInfo(const HloInstruction* instr);
string GetInstructionNodeInlinedOperands(const HloInstruction* instr);
void AddInstructionIncomingEdges(const HloInstruction* instr);
@@ -393,6 +396,7 @@ class HloDotDumper {
const string label_; // overall name for the graph
const DebugOptions& debug_options_;
const bool show_metadata_;
+ const bool show_backend_config_;
const HloExecutionProfile* profile_; // may be null
const NodeFilter filter_;
@@ -611,6 +615,10 @@ tooltip = " ";
if (!extra_info.empty()) {
StrAppend(&subcomp_label, "<br/>", extra_info);
}
+ string node_backend_config = GetInstructionNodeBackendConfig(parent_instr);
+ if (!node_backend_config.empty()) {
+ StrAppend(&subcomp_label, "<br/>", node_backend_config);
+ }
bool highlight = filter_.Highlight(parent_instr);
const char* fillcolor;
@@ -765,6 +773,7 @@ string HloDotDumper::DumpInstruction(const HloInstruction* instr) {
string node_shape = GetInstructionNodeShape(instr);
string node_label = GetInstructionNodeLabel(instr);
string node_metadata = GetInstructionNodeMetadata(instr);
+ string node_backend_config = GetInstructionNodeBackendConfig(instr);
string extra_info = GetInstructionNodeExtraInfo(instr);
string inlined_constants = GetInstructionNodeInlinedOperands(instr);
string trivial_subcomputation = GetInstructionTrivialComputationStr(instr);
@@ -782,8 +791,8 @@ string HloDotDumper::DumpInstruction(const HloInstruction* instr) {
}
// Build the text that will be displayed inside the node.
string node_body = node_label;
- for (const string& s :
- {trivial_subcomputation, node_metadata, extra_info, inlined_constants}) {
+ for (const string& s : {trivial_subcomputation, node_metadata,
+ node_backend_config, extra_info, inlined_constants}) {
if (!s.empty()) {
StrAppend(&node_body, "<br/>", s);
}
@@ -1078,6 +1087,15 @@ string HloDotDumper::GetInstructionNodeMetadata(const HloInstruction* instr) {
return Join(lines, "<br/>");
}
+string HloDotDumper::GetInstructionNodeBackendConfig(
+ const HloInstruction* instr) {
+ if (!show_backend_config_ || instr->backend_config().empty()) {
+ return "";
+ }
+
+ return StrCat("backend_config=\"", instr->backend_config(), "\"");
+}
+
string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) {
std::vector<string> lines;
@@ -1404,7 +1422,7 @@ string ExportGraph(const string& graph,
string DumpGraph(const HloComputation& computation, const string& label,
const DebugOptions& debug_options,
const HloExecutionProfile* hlo_execution_profile,
- bool show_metadata) {
+ bool show_metadata, bool show_backend_config) {
GraphRendererInterface::GraphKind graph_kind;
string graph;
if (debug_options.xla_hlo_dump_as_graphdef()) {
@@ -1414,9 +1432,10 @@ string DumpGraph(const HloComputation& computation, const string& label,
&graph));
graph_kind = GraphRendererInterface::TF_GRAPHDEF;
} else {
- graph = HloDotDumper(&computation, label, debug_options, show_metadata,
- hlo_execution_profile, NodeFilter())
- .Dump();
+ graph =
+ HloDotDumper(&computation, label, debug_options, show_metadata,
+ show_backend_config, hlo_execution_profile, NodeFilter())
+ .Dump();
graph_kind = GraphRendererInterface::DOT_GRAPH;
}
@@ -1427,15 +1446,15 @@ string DumpGraph(const HloComputation& computation, const string& label,
}
string DumpNeighborhoodAround(const HloInstruction& node, int radius,
- bool show_metadata) {
+ bool show_metadata, bool show_backend_config) {
auto debug_options = node.GetModule()->config().debug_options();
string label =
StrCat("Neighborhood of ", radius, " nodes around ", node.name());
NodeFilter filter = MakeNodeFilter(&node, radius);
- string graph =
- HloDotDumper(node.parent(), label, debug_options, show_metadata,
- /*profile=*/nullptr, filter)
- .Dump();
+ string graph = HloDotDumper(node.parent(), label, debug_options,
+ show_metadata, show_backend_config,
+ /*profile=*/nullptr, filter)
+ .Dump();
return ExportGraph(graph, GraphRendererInterface::DOT_GRAPH, debug_options);
}
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.h b/tensorflow/compiler/xla/service/hlo_graph_dumper.h
index 2704aae1e3..fc8e1468ac 100644
--- a/tensorflow/compiler/xla/service/hlo_graph_dumper.h
+++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.h
@@ -56,7 +56,7 @@ string MaybeDumpHloModule(const HloModule& module, const string& label,
string DumpGraph(const HloComputation& computation, const string& label,
const DebugOptions& debug_options,
const HloExecutionProfile* hlo_execution_profile = nullptr,
- bool show_metadata = false);
+ bool show_metadata = false, bool show_backend_config = false);
// Like DumpGraph, but renders only nodes "near" the given node in the graph.
//
@@ -64,7 +64,8 @@ string DumpGraph(const HloComputation& computation, const string& label,
// (roughly) corresponds to the max distance a node may be from the primary node
// before it's omitted from the graph.
string DumpNeighborhoodAround(const HloInstruction& node, int radius,
- bool show_metadata = false);
+ bool show_metadata = false,
+ bool show_backend_config = false);
// Dumps the HloModule::ToString() as a file into the provided directory path
// suffixed with the provided label.
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index a714d0e114..857cd39adb 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -109,6 +109,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
instruction->name_ = proto.name();
instruction->metadata_ = proto.metadata();
+ instruction->set_backend_config(proto.backend_config());
if (proto.has_literal()) {
TF_ASSIGN_OR_RETURN(instruction->literal_,
Literal::CreateFromProto(proto.literal()));
@@ -437,7 +438,7 @@ HloInstruction::CreateCrossReplicaSum(
<< "Outfeed shape " << shape << " must be compatible with operand shape "
<< operand->shape();
instruction->AppendOperand(operand);
- instruction->outfeed_config_ = outfeed_config.ToString();
+ instruction->outfeed_config_ = std::string(outfeed_config);
instruction->outfeed_shape_ = shape;
return instruction;
}
@@ -792,23 +793,11 @@ HloInstruction::CreateBroadcastSequence(
return instruction;
}
-// We put the fusion kind into the instruction's name for transpose-dot fusions,
-// since those fusions are really just describing a type of dot rather than
-// generating a novel computation.
-static string FusionNodeName(HloInstruction::FusionKind fusion_kind) {
- switch (fusion_kind) {
- case HloInstruction::FusionKind::kTransposeDot:
- return "dot_fusion";
- default:
- return "fusion";
- }
-}
-
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFusion(
const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root) {
auto instruction = WrapUnique(new HloInstruction(HloOpcode::kFusion, shape));
instruction->fusion_kind_ = fusion_kind;
- instruction->name_ = FusionNodeName(fusion_kind);
+ instruction->name_ = "fusion";
instruction->set_parent(fused_root->parent());
instruction->set_metadata(fused_root->metadata());
instruction->CloneAndFuseInternal(fused_root);
@@ -824,7 +813,7 @@ static string FusionNodeName(HloInstruction::FusionKind fusion_kind) {
instruction->AppendOperand(operand);
}
instruction->fusion_kind_ = fusion_kind;
- instruction->name_ = FusionNodeName(fusion_kind);
+ instruction->name_ = "fusion";
instruction->called_computations_.push_back(fusion_computation);
fusion_computation->SetFusionInstruction(instruction.get());
return instruction;
@@ -1167,7 +1156,7 @@ bool HloInstruction::HasSideEffect() const {
for (auto operand : operands) {
instruction->AppendOperand(operand);
}
- instruction->custom_call_target_ = custom_call_target.ToString();
+ instruction->custom_call_target_ = std::string(custom_call_target);
return instruction;
}
@@ -1179,7 +1168,7 @@ bool HloInstruction::HasSideEffect() const {
for (auto operand : operands) {
instruction->AppendOperand(operand);
}
- instruction->channel_name_ = channel_name.ToString();
+ instruction->channel_name_ = std::string(channel_name);
instruction->cost_estimate_ns_ = cost_estimate_ns;
return instruction;
}
@@ -1231,12 +1220,15 @@ bool HloInstruction::HasSideEffect() const {
std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
const Shape& shape,
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
- HloModule* module) const {
+ HloModule* module, CloneMap* clone_map) const {
VLOG(3) << "CloneWithNewOperands:\n " << ToString();
VLOG(3) << " new operands:";
for (const HloInstruction* new_operand : new_operands) {
VLOG(3) << " %" << new_operand->name();
}
+ if (module == nullptr) {
+ module = GetModule();
+ }
std::unique_ptr<HloInstruction> clone;
@@ -1342,7 +1334,8 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
break;
case HloOpcode::kFft:
CHECK_EQ(new_operands.size(), 1);
- return CreateFft(shape, new_operands[0], fft_type_, fft_length_);
+ clone = CreateFft(shape, new_operands[0], fft_type_, fft_length_);
+ break;
case HloOpcode::kCrossReplicaSum:
clone = CreateCrossReplicaSum(shape, new_operands);
break;
@@ -1415,9 +1408,15 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
case HloOpcode::kConstant:
clone = CreateConstant(literal_->CloneToUnique());
break;
- case HloOpcode::kFusion:
- clone = CloneFusionWithNewOperands(shape, new_operands, module);
+ case HloOpcode::kFusion: {
+ CHECK_NE(module, nullptr);
+ auto new_fused_computation = module->AddEmbeddedComputation(
+ fused_instructions_computation()->Clone("clone", module, clone_map));
+ clone = CreateFusion(/*shape=*/shape, /*fusion_kind=*/fusion_kind(),
+ /*operands=*/new_operands,
+ /*fusion_computation=*/new_fused_computation);
break;
+ }
case HloOpcode::kParameter:
clone = CreateParameter(parameter_number_, shape, name_);
break;
@@ -1481,15 +1480,19 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
}
SetupDerivedInstruction(clone.get());
clone->set_parent(parent_);
+ clone->set_backend_config(backend_config());
+ if (clone_map != nullptr) {
+ InsertOrDie(clone_map, this, clone.get());
+ }
return clone;
}
HloInstruction::~HloInstruction() {}
-std::unique_ptr<HloInstruction> HloInstruction::Clone(const string& suffix,
- HloModule* module) const {
+std::unique_ptr<HloInstruction> HloInstruction::Clone(
+ const string& suffix, HloModule* module, CloneMap* clone_map) const {
std::unique_ptr<HloInstruction> clone =
- CloneWithNewOperands(shape_, operands_, module);
+ CloneWithNewOperands(shape_, operands_, module, clone_map);
if (suffix.empty()) {
clone->name_ = name();
} else {
@@ -1526,71 +1529,6 @@ std::unique_ptr<HloInstruction> HloInstruction::Clone(const string& suffix,
return clone;
}
-std::unique_ptr<HloInstruction> HloInstruction::CloneFusionWithNewOperands(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
- HloModule* module) const {
- CHECK_EQ(opcode_, HloOpcode::kFusion);
- CHECK(parent() != nullptr);
-
- auto new_instruction =
- WrapUnique(new HloInstruction(HloOpcode::kFusion, shape));
- // Add the operands to our new fusion instruction.
- for (HloInstruction* new_operand : operands) {
- new_instruction->AppendOperand(new_operand);
- }
- // Clone all the fused instructions for the new fusion instruction.
- HloInstructionMap<HloInstruction*> old_to_new;
- std::list<std::unique_ptr<HloInstruction>> new_fused_instructions;
- // Create the list of fused parameters by mapping through the cloned,
- // fused instructions.
- for (HloInstruction* old_fused_parameter :
- fused_instructions_computation()->parameter_instructions()) {
- new_fused_instructions.push_back(
- old_fused_parameter->Clone("clone", module));
- HloInstruction* new_fusion_parameter = new_fused_instructions.back().get();
- InsertOrDie(&old_to_new, old_fused_parameter, new_fusion_parameter);
- }
- for (auto old_fused_instruction :
- fused_instructions_computation()->MakeInstructionPostOrder()) {
- if (old_fused_instruction->opcode() == HloOpcode::kParameter) {
- FindOrDie(old_to_new, old_fused_instruction);
- continue;
- }
- std::vector<HloInstruction*> new_operands;
- for (int64 operand_idx = 0;
- operand_idx < old_fused_instruction->operand_count(); ++operand_idx) {
- HloInstruction* old_operand =
- old_fused_instruction->mutable_operand(operand_idx);
- new_operands.push_back(FindOrDie(old_to_new, old_operand));
- }
- new_fused_instructions.push_back(
- old_fused_instruction->CloneWithNewOperands(
- old_fused_instruction->shape(), new_operands, module));
- HloInstruction* new_fused_instruction = new_fused_instructions.back().get();
- new_fused_instruction->set_parent(parent_);
- InsertOrDie(&old_to_new, old_fused_instruction, new_fused_instruction);
- }
- new_instruction->fusion_kind_ = fusion_kind_;
- auto computation_builder = HloComputation::Builder(
- fused_instructions_computation()->name() + ".clone",
- new_instruction.get());
- // We iterated the fusion instructions in reverse post order which means
- // that we must reverse our new list of fusion instructions.
- for (auto new_fused_instruction_iter = new_fused_instructions.rbegin();
- new_fused_instruction_iter != new_fused_instructions.rend();
- ++new_fused_instruction_iter) {
- computation_builder.AddInstruction(std::move(*new_fused_instruction_iter));
- }
- if (module == nullptr) {
- module = GetModule();
- }
- auto fused_root_ = fused_expression_root();
- new_instruction->called_computations_.push_back(
- CHECK_NOTNULL(module)->AddEmbeddedComputation(
- computation_builder.Build(FindOrDie(old_to_new, fused_root_))));
- return new_instruction;
-}
-
std::pair<const HloInstruction*, ShapeIndex>
HloInstruction::LatestNonGteAncestorAndIndex() const {
const HloInstruction* hlo = this;
@@ -2172,6 +2110,9 @@ string HloInstruction::ToString(const HloPrintOptions& options) const {
!metadata_.source_file().empty())) {
StrAppend(&result, ", metadata={", xla::OpMetadataToString(metadata_), "}");
}
+ if (options.print_backend_config() && !backend_config().empty()) {
+ StrAppend(&result, ", backend_config=\"", CEscape(backend_config()), "\"");
+ }
return result;
}
@@ -2357,6 +2298,7 @@ std::vector<string> HloInstruction::ExtraAttributesToString(
extra.push_back(
StrCat("custom_call_target=\"", CEscape(custom_call_target_), "\""));
}
+
return extra;
}
@@ -2386,6 +2328,7 @@ HloInstructionProto HloInstruction::ToProto() const {
}
*proto.mutable_metadata() = metadata_;
+ proto.set_backend_config(backend_config());
if (literal_ != nullptr) {
*proto.mutable_literal() = literal_->ToProto();
}
@@ -2487,8 +2430,6 @@ string HloInstruction::ToCategory() const {
return "input fusion";
case FusionKind::kOutput:
return "output fusion";
- case FusionKind::kTransposeDot:
- return "dot";
case FusionKind::kCustom:
return "custom fusion";
}
@@ -2971,6 +2912,7 @@ Status HloInstruction::AcceptOrdered(
continue;
}
+ // TODO(b/78350259): Eliminate const laundering.
HloInstruction* instruction =
const_cast<HloInstruction*>(const_instruction);
@@ -3270,8 +3212,6 @@ string ToString(HloInstruction::FusionKind kind) {
return "kInput";
case HloInstruction::FusionKind::kOutput:
return "kOutput";
- case HloInstruction::FusionKind::kTransposeDot:
- return "kTransposeDot";
case HloInstruction::FusionKind::kCustom:
return "kCustom";
}
@@ -3288,9 +3228,6 @@ StatusOr<HloInstruction::FusionKind> StringToFusionKind(
if (kind_name == "kOutput") {
return HloInstruction::FusionKind::kOutput;
}
- if (kind_name == "kTransposeDot") {
- return HloInstruction::FusionKind::kTransposeDot;
- }
if (kind_name == "kCustom") {
return HloInstruction::FusionKind::kCustom;
}
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index a5e9aecb9e..14be58d069 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -66,6 +66,7 @@ class HloPrintOptions {
: print_large_constants_(false),
print_subcomputation_references_(true),
print_metadata_(true),
+ print_backend_config_(true),
compact_operands_(false),
print_operand_shape_(true),
print_program_shape_(true),
@@ -77,6 +78,7 @@ class HloPrintOptions {
.set_print_large_constants(true)
.set_print_subcomputation_references(true)
.set_print_metadata(false)
+ .set_print_backend_config(false)
.set_print_operand_shape(false)
.set_print_program_shape(false)
.set_print_percent(false);
@@ -99,12 +101,18 @@ class HloPrintOptions {
return *this;
}
- // If true, metatdata will be printed.
+ // If true, metadata will be printed.
HloPrintOptions& set_print_metadata(bool value) {
print_metadata_ = value;
return *this;
}
+ // If true, backend_config will be printed.
+ HloPrintOptions& set_print_backend_config(bool value) {
+ print_backend_config_ = value;
+ return *this;
+ }
+
// If true, operands' shapes will be printed.
HloPrintOptions& set_print_operand_shape(bool value) {
print_operand_shape_ = value;
@@ -141,6 +149,7 @@ class HloPrintOptions {
return print_subcomputation_references_;
}
bool print_metadata() const { return print_metadata_; }
+ bool print_backend_config() const { return print_backend_config_; }
bool compact_operands() const { return compact_operands_; }
bool print_operand_shape() const { return print_operand_shape_; }
bool print_program_shape() const { return print_program_shape_; }
@@ -151,6 +160,7 @@ class HloPrintOptions {
bool print_large_constants_;
bool print_subcomputation_references_;
bool print_metadata_;
+ bool print_backend_config_;
bool compact_operands_;
bool print_operand_shape_;
bool print_program_shape_;
@@ -167,7 +177,6 @@ class HloInstruction {
kOutput, // Op's output is fused into the op itself.
// REQUIRES: At least one operand buffer must be able
// to alias the output buffer.
- kTransposeDot, // Fused into a dot with transposed operands.
kCustom, // Custom category for backend-specific fusions that
// do not match any of the more specific ones.
};
@@ -643,6 +652,8 @@ class HloInstruction {
// Detaches an instruction from its operands. That is, remove the instruction
// from each operand's user set. This should only be called prior to
// deallocating the instruction.
+ //
+ // TODO(b/78305363): Make this automatic when deleting an instruction.
void DetachFromOperands();
// Performs a postorder DFS visit using this node as the root. If
@@ -1157,23 +1168,30 @@ class HloInstruction {
// Precondition: opcode() == HloOpcode::kRng
RandomDistribution random_distribution() const;
+ // See documentation for Clone().
+ using CloneMap = std::unordered_map<const HloInstruction*, HloInstruction*>;
+
// Clones the HLO instruction. The clone will have the same opcode, shape, and
// operands. After creation the clone has no uses. "this" (the instruction
// cloned from) is not changed. Suffix is the string to append to the name of
- // the instruction to form the name of the cloned instruction. If the module
- // pointer is not nullptr, it will be the module where the cloned computations
- // will be added to (in order to support deep cloning). Ignores the control
- // predecessors and successors of this HLO instruction.
+ // the instruction to form the name of the cloned instruction. Ignores the
+ // control predecessors and successors of this HLO instruction.
+ //
+ // If the module pointer is not nullptr, then any cloned computations will be
+ // added to this module in order to support deep cloning. Otherwise the module
+ // of the instruction is used.
+ //
+ // If clone_map is not nullptr, then each original instruction that is cloned
+ // will be inserted into it, mapped to its clone. clone_map should not already
+ // contain any of the instructions to clone.
std::unique_ptr<HloInstruction> Clone(const string& suffix = "clone",
- HloModule* module = nullptr) const;
+ HloModule* module = nullptr,
+ CloneMap* clone_map = nullptr) const;
- // Clones the HLO instruction as above but with new shape and operands. If
- // the module pointer is not nullptr, it will be the module where the cloned
- // computations will be added to (in order to support deep cloning). Ignores
- // the control predecessors and successors of this HLO instruction.
+ // Clones the HLO instruction as above but with new shape and operands.
std::unique_ptr<HloInstruction> CloneWithNewOperands(
const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
- HloModule* module = nullptr) const;
+ HloModule* module = nullptr, CloneMap* clone_map = nullptr) const;
// Returns the computations this instruction directly calls (if any).
const std::vector<HloComputation*>& called_computations() const {
@@ -1245,7 +1263,7 @@ class HloInstruction {
// Gets/sets the string identifier for this instruction.
const string& name() const { return name_; }
- void set_name(tensorflow::StringPiece name) { name_ = name.ToString(); }
+ void set_name(tensorflow::StringPiece name) { name_ = std::string(name); }
// Use the given NameUniquer to select a unique name for the instruction based
// on the instruction's existing name.
@@ -1262,6 +1280,19 @@ class HloInstruction {
// if no id has been assigned yet).
int unique_id() const { return unique_id_; }
+ // Returns the backend-specific configuration for how a backend should compile
+ // this HLO. The meaning of the field is backend specific. Not for use before
+ // or during general HLO optimization, since HLO optimizations do not preserve
+ // this field and they cannot interpret it due to its meaning being backend
+ // specific.
+ //
+ // TODO(b/78194644): Introduce structured configuration format as per
+ // go/xla-heuristics.
+ const string& backend_config() const { return backend_config_; }
+ void set_backend_config(string backend_config) {
+ backend_config_ = std::move(backend_config);
+ }
+
// Sets the debug metadata for this instruction.
void set_metadata(const OpMetadata& metadata) { metadata_ = metadata; }
const OpMetadata& metadata() const { return metadata_; }
@@ -1283,6 +1314,7 @@ class HloInstruction {
// Get/Set the number of partitions per outer dimension (in order, starting
// with outer-most dimension first). Currently used by the parallel cpu
// backend to partition HLOs into parallel tasks.
+ //
// TODO(b/62783254) Replace these methods with a more general way to
// annotate HLOs with backend-specific information.
const std::vector<int64>& outer_dimension_partitions() const {
@@ -1510,6 +1542,10 @@ class HloInstruction {
// The string representation of the infeed configuration.
string infeed_config_;
+ // The backend-specific configuration for how a backend should compile this
+ // HLO. See the documentation on backend_config().
+ string backend_config_;
+
// String identifier for instruction.
string name_;
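
As a usage illustration (not part of this diff), here is a minimal sketch of how a backend might combine the new backend_config field, the print_backend_config option, and the CloneMap hook. The helper function and the config string are hypothetical:

// Hypothetical helper, sketched against the interfaces added above.
void AnnotateAndClone(HloInstruction* instr, HloModule* module) {
  // Attach an opaque backend-specific string; HLO passes neither read nor
  // preserve this field, so it should be set only after general optimization.
  instr->set_backend_config("{\"algorithm\":3}");

  // Include the config in a debug dump.
  HloPrintOptions options;
  options.set_print_backend_config(true);
  LOG(INFO) << instr->ToString(options);

  // Deep-clone while recording which original instruction maps to which clone.
  HloInstruction::CloneMap clone_map;
  std::unique_ptr<HloInstruction> clone =
      instr->Clone("clone", module, &clone_map);
  // clone_map now maps each cloned original instruction to its clone.
}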
diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
index 5b65b1152c..909cdc0b62 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
@@ -1102,7 +1102,7 @@ TEST_F(HloInstructionTest, CloneOfFusionPreservesShape) {
auto module = CreateNewModule();
auto* computation = module->AddEntryComputation(builder.Build());
HloInstruction* fusion = computation->CreateFusionInstruction(
- {dot, reshape}, HloInstruction::FusionKind::kTransposeDot);
+ {dot, reshape}, HloInstruction::FusionKind::kLoop);
auto fusion2 = fusion->Clone();
const HloInstruction* root = fusion->fused_expression_root();
@@ -1169,7 +1169,7 @@ TEST_F(HloInstructionTest, NestedFusionEquality) {
auto computation = module->AddEntryComputation(builder.Build());
auto nested_fusion = computation->CreateFusionInstruction(
- {dot, b_t}, HloInstruction::FusionKind::kTransposeDot);
+ {dot, b_t}, HloInstruction::FusionKind::kLoop);
auto fusion = computation->CreateFusionInstruction(
{add, nested_fusion}, HloInstruction::FusionKind::kOutput);
@@ -1246,13 +1246,6 @@ TEST_F(HloInstructionTest, Stringification) {
auto module = CreateNewModule();
auto* computation = module->AddEntryComputation(builder.Build());
- HloInstruction* fusion = computation->CreateFusionInstruction(
- {dot, reshape}, HloInstruction::FusionKind::kTransposeDot);
-
- EXPECT_EQ(
- fusion->ToString(options),
- "%dot_fusion = f32[5,20]{1,0} fusion(f32[5,10]{1,0} %x, "
- "f32[20,10]{1,0} %y), kind=kTransposeDot, calls=%fused_computation");
HloInstruction* loop = builder.AddInstruction(
HloInstruction::CreateWhile(sout, computation, computation, x));
diff --git a/tensorflow/compiler/xla/service/hlo_matchers.cc b/tensorflow/compiler/xla/service/hlo_matchers.cc
index 69deac263e..7e4b883435 100644
--- a/tensorflow/compiler/xla/service/hlo_matchers.cc
+++ b/tensorflow/compiler/xla/service/hlo_matchers.cc
@@ -17,10 +17,13 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/core/lib/strings/str_util.h"
namespace xla {
namespace testing {
+using ::tensorflow::str_util::Join;
+
bool HloMatcher::MatchAndExplain(
const HloInstruction* instruction,
::testing::MatchResultListener* listener) const {
@@ -195,6 +198,41 @@ void HloShardingMatcher::DescribeTo(std::ostream* os) const {
}
}
+bool HloDotWithContractingDimsMatcher::MatchAndExplain(
+ const HloInstruction* instruction,
+ ::testing::MatchResultListener* listener) const {
+ if (!HloMatcher::MatchAndExplain(instruction, listener)) {
+ return false;
+ }
+
+ const DotDimensionNumbers& dim_nums = instruction->dot_dimension_numbers();
+ if (dim_nums.lhs_contracting_dimensions_size() != 1 ||
+ dim_nums.lhs_contracting_dimensions(0) != lhs_contracting_dim_) {
+ *listener << instruction->ToString()
+ << " has wrong lhs_contracting_dimensions (got {"
+ << Join(dim_nums.lhs_contracting_dimensions(), ",") << "} want {"
+ << lhs_contracting_dim_ << "})";
+ return false;
+ }
+
+ if (dim_nums.rhs_contracting_dimensions_size() != 1 ||
+ dim_nums.rhs_contracting_dimensions(0) != rhs_contracting_dim_) {
+ *listener << instruction->ToString()
+ << " has wrong rhs_contracting_dimensions (got {"
+ << Join(dim_nums.rhs_contracting_dimensions(), ",") << "} want {"
+ << rhs_contracting_dim_ << "})";
+ return false;
+ }
+
+ return true;
+}
+
+void HloDotWithContractingDimsMatcher::DescribeTo(std::ostream* os) const {
+ HloMatcher::DescribeTo(os);
+ *os << " with lhs_contracting_dims={" << lhs_contracting_dim_
+ << "} and rhs_contracting_dims={" << rhs_contracting_dim_ << "}";
+}
+
} // namespace testing
void PrintTo(const HloInstruction* inst, ::std::ostream* os) {
diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h
index 5175736a25..c33bdadf1c 100644
--- a/tensorflow/compiler/xla/service/hlo_matchers.h
+++ b/tensorflow/compiler/xla/service/hlo_matchers.h
@@ -131,6 +131,27 @@ class HloShardingMatcher
tensorflow::gtl::optional<HloSharding> sharding_;
};
+// Matches a Dot HLO instruction with specific LHS and RHS contracting
+// dimensions.
+class HloDotWithContractingDimsMatcher : public HloMatcher {
+ public:
+ explicit HloDotWithContractingDimsMatcher(
+ ::testing::Matcher<const HloInstruction*> lhs,
+ ::testing::Matcher<const HloInstruction*> rhs, int64 lhs_contracting_dim,
+ int64 rhs_contracting_dim)
+ : HloMatcher(HloOpcode::kDot, /*operands=*/{lhs, rhs}),
+ lhs_contracting_dim_(lhs_contracting_dim),
+ rhs_contracting_dim_(rhs_contracting_dim) {}
+
+ bool MatchAndExplain(const HloInstruction* instruction,
+ ::testing::MatchResultListener* listener) const override;
+ void DescribeTo(std::ostream* os) const override;
+
+ private:
+ int64 lhs_contracting_dim_;
+ int64 rhs_contracting_dim_;
+};
+
// HloInstruction* matchers for opcode and operands. Example:
// namespace op = xla::opcode_matchers;
// EXPECT_THAT(instruction,
@@ -158,7 +179,6 @@ HLO_MATCHER(Convolution);
HLO_MATCHER(Copy);
HLO_MATCHER(CrossReplicaSum);
HLO_MATCHER(Divide);
-HLO_MATCHER(Dot);
HLO_MATCHER(DynamicSlice);
HLO_MATCHER(DynamicUpdateSlice);
HLO_MATCHER(Eq);
@@ -310,6 +330,30 @@ inline ::testing::Matcher<const ::xla::HloInstruction*> NoSharding() {
new ::xla::testing::HloShardingMatcher(tensorflow::gtl::nullopt));
}
+inline ::testing::Matcher<const ::xla::HloInstruction*> Dot(
+ ::testing::Matcher<const HloInstruction*> lhs_matcher,
+ ::testing::Matcher<const HloInstruction*> rhs_matcher) {
+ return ::testing::MakeMatcher(new ::xla::testing::HloMatcher(
+ ::xla::HloOpcode::kDot, {lhs_matcher, rhs_matcher}));
+}
+
+// Matches a Dot HLO instruction if it has exactly one lhs contracting dimension
+// equal to `lhs_contracting_dim` and exactly one rhs contracting dimension
+// equal to `rhs_contracting_dim`.
+//
+// Currently the HLO verifier rejects Dot operations with more than one
+// contracting dimension (even though we can represent these in the
+// DotDimensionNumbers proto), so there is no need to generalize this to support
+// multiple contracting dimensions.
+inline ::testing::Matcher<const ::xla::HloInstruction*> Dot(
+ ::testing::Matcher<const HloInstruction*> lhs_matcher,
+ ::testing::Matcher<const HloInstruction*> rhs_matcher,
+ int64 lhs_contracting_dim, int64 rhs_contracting_dim) {
+ return ::testing::MakeMatcher(
+ new ::xla::testing::HloDotWithContractingDimsMatcher(
+ lhs_matcher, rhs_matcher, lhs_contracting_dim, rhs_contracting_dim));
+}
+
#undef HLO_MATCHER
} // namespace opcode_matchers
diff --git a/tensorflow/compiler/xla/service/hlo_matchers_test.cc b/tensorflow/compiler/xla/service/hlo_matchers_test.cc
index f2463060b7..016cc01e33 100644
--- a/tensorflow/compiler/xla/service/hlo_matchers_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_matchers_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
namespace op = xla::testing::opcode_matchers;
using ::testing::_;
@@ -165,5 +166,41 @@ TEST(HloMatchersTest, ShardingMatcher) {
"has incorrect sharding (expected: {maximal device=0})");
}
+TEST(HloMatchersTest, DotMatcher) {
+ string hlo_string = R"(
+HloModule DotOperationFusion_TransposeFusion
+
+ENTRY DotOperationFusion_TransposeFusion {
+ arg0 = f32[1,256] parameter(0)
+ arg1 = f32[256,1024] parameter(1)
+ ROOT dot = f32[1,1024] dot(arg0, arg1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ tools::Parse(hlo_string));
+ HloInstruction* root = module->entry_computation()->root_instruction();
+
+ EXPECT_THAT(root, op::Dot(op::Parameter(0), op::Parameter(1),
+ /*lhs_contracting_dim=*/1,
+ /*rhs_contracting_dim=*/0));
+
+ EXPECT_THAT(
+ Explain(root, op::Dot(op::Parameter(0), op::Parameter(1),
+ /*lhs_contracting_dim=*/0,
+ /*rhs_contracting_dim=*/0)),
+ "%dot = f32[1,1024]{1,0} dot(f32[1,256]{1,0} %arg0, f32[256,1024]{1,0} "
+ "%arg1), lhs_contracting_dims={1}, rhs_contracting_dims={0} has wrong "
+ "lhs_contracting_dimensions (got {1} want {0})");
+
+ EXPECT_THAT(
+ Explain(root, op::Dot(op::Parameter(0), op::Parameter(1),
+ /*lhs_contracting_dim=*/1,
+ /*rhs_contracting_dim=*/1)),
+ "%dot = f32[1,1024]{1,0} dot(f32[1,256]{1,0} %arg0, f32[256,1024]{1,0} "
+ "%arg1), lhs_contracting_dims={1}, rhs_contracting_dims={0} has wrong "
+ "rhs_contracting_dimensions (got {0} want {1})");
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc
index c7a7192867..5308fb5848 100644
--- a/tensorflow/compiler/xla/service/hlo_module.cc
+++ b/tensorflow/compiler/xla/service/hlo_module.cc
@@ -46,6 +46,18 @@ HloModule::HloModule(const string& name, const HloModuleConfig& config)
config_(config),
unique_id_(next_unique_module_id_++) {}
+StatusOr<HloInstruction*> HloModule::LaunderConstInstructionFromModule(
+ const HloInstruction* hlo) {
+ if (hlo == nullptr) {
+ return nullptr;
+ }
+
+ TF_RET_CHECK(hlo->GetModule() == this);
+
+ // TODO(b/78350259): Eliminate const laundering.
+ return const_cast<HloInstruction*>(hlo);
+}
+
HloComputation* HloModule::AddComputationInternal(
std::unique_ptr<HloComputation> computation, bool is_entry,
bool uniquify_names) {
diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h
index f9674df812..1604a72612 100644
--- a/tensorflow/compiler/xla/service/hlo_module.h
+++ b/tensorflow/compiler/xla/service/hlo_module.h
@@ -217,6 +217,25 @@ class HloModule {
// the lifetime of this process.
int unique_id() const { return unique_id_; }
+ // Returns a non-const version of the passed-in const HloInstruction*. This is
+ // safe on the argument that, given a non-const module, you can access all of
+ // its instructions as non-const.
+ //
+ // Returns an error if the passed-in instruction is not from this module,
+ // except that it is allowed to pass in a null pointer.
+ //
+ // TODO(b/78350259): Eliminate const laundering. The argument above is not
+ // reliable since at any time someone could add or discover a way for a
+ // non-const module to transitively contain a const HloInstruction. The
+ // reliable way to do this would be to create a const laundering map from a
+ // module, mapping each encountered HloInstruction to its non-const version
+ // and then look up each instruction in need of laundering in that map, but
+ // this is much more expensive and complicated. This returns a Status instead
+ // of doing a CHECK-failure in part to make it strongly apparent that this is
+ // something that can fail.
+ StatusOr<HloInstruction*> LaunderConstInstructionFromModule(
+ const HloInstruction* hlo);
+
private:
HloComputation* AddComputationInternal(
std::unique_ptr<HloComputation> computation, bool is_entry,
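
A minimal sketch (hypothetical caller, not part of this diff) of how the laundering helper is meant to be used when an analysis hands back const pointers but the caller holds a non-const module:

// Hypothetical pass step: rename an instruction that was obtained as a
// const pointer from an analysis result.
Status RenameInstruction(HloModule* module, const HloInstruction* hlo) {
  TF_ASSIGN_OR_RETURN(HloInstruction* mutable_hlo,
                      module->LaunderConstInstructionFromModule(hlo));
  mutable_hlo->set_name("renamed");
  return Status::OK();
}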
diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc
index 5120775737..d8f1ab916b 100644
--- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc
+++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc
@@ -90,7 +90,7 @@ StatusOr<bool> HloPassPipeline::Run(HloModule* module) {
return Status::OK();
};
- string prefix = name().ToString() + ": pipeline start";
+ string prefix = std::string(name()) + ": pipeline start";
bool changed = false;
string message;
TF_RETURN_IF_ERROR(
@@ -98,12 +98,12 @@ StatusOr<bool> HloPassPipeline::Run(HloModule* module) {
const string xla_dump_per_pass_hlo_proto_to =
module->config().debug_options().xla_dump_per_pass_hlo_proto_to();
if (!xla_dump_per_pass_hlo_proto_to.empty()) {
- DumpModuleProto(*module, xla_dump_per_pass_hlo_proto_to, name().ToString(),
- "pipeline_start");
+ DumpModuleProto(*module, xla_dump_per_pass_hlo_proto_to,
+ std::string(name()), "pipeline_start");
}
for (auto& pass : passes_) {
- if (disabled_passes.count(pass->name().ToString()) > 0) {
+ if (disabled_passes.count(std::string(pass->name())) > 0) {
VLOG(1) << " Skipping HLO pass " << pass->name()
<< ", disabled by --xla_disable_hlo_passes";
continue;
@@ -121,7 +121,7 @@ StatusOr<bool> HloPassPipeline::Run(HloModule* module) {
run_invariant_checkers(StrCat("after running pass: ", pass->name())));
if (!xla_dump_per_pass_hlo_proto_to.empty()) {
DumpModuleProto(*module, xla_dump_per_pass_hlo_proto_to,
- name().ToString(), pass->name().ToString());
+ std::string(name()), std::string(pass->name()));
}
changed |= changed_this_pass;
diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.cc b/tensorflow/compiler/xla/service/hlo_scheduling.cc
index 1a767628f6..23ace5afea 100644
--- a/tensorflow/compiler/xla/service/hlo_scheduling.cc
+++ b/tensorflow/compiler/xla/service/hlo_scheduling.cc
@@ -430,6 +430,15 @@ StatusOr<std::vector<const HloInstruction*>> ListMemoryScheduler(
return ListScheduler::Run(computation, points_to_analysis, size_function);
}
+StatusOr<std::vector<const HloInstruction*>> PostOrderMemoryScheduler(
+ const HloComputation& computation,
+ const TuplePointsToAnalysis& points_to_analysis,
+ const LogicalBuffer::SizeFunction& size_function) {
+ const auto& post_order = computation.MakeInstructionPostOrder();
+ return std::vector<const HloInstruction*>{post_order.begin(),
+ post_order.end()};
+}
+
StatusOr<std::vector<const HloInstruction*>> DefaultMemoryScheduler(
const HloComputation& computation,
const TuplePointsToAnalysis& points_to_analysis,
@@ -459,7 +468,22 @@ StatusOr<std::vector<const HloInstruction*>> DefaultMemoryScheduler(
size_function));
VLOG(2) << "Min-memory dfs sequence: " << HumanReadableNumBytes(dfs_memory);
- if (list_memory <= dfs_memory) {
+ TF_ASSIGN_OR_RETURN(
+ std::vector<const HloInstruction*> post_order_sequence,
+ PostOrderMemoryScheduler(computation, points_to_analysis, size_function));
+ TF_ASSIGN_OR_RETURN(
+ const int64 post_order_memory,
+ MinimumMemoryForComputation(computation, post_order_sequence,
+ points_to_analysis, size_function));
+ VLOG(2) << "Min-memory post order sequence: "
+ << HumanReadableNumBytes(post_order_memory);
+
+ if (post_order_memory < std::min(list_memory, dfs_memory)) {
+ VLOG(2) << "Chose min-memory post_order sequence: "
+ << HumanReadableNumBytes(post_order_memory);
+ return post_order_sequence;
+
+ } else if (list_memory <= dfs_memory) {
VLOG(2) << "Chose min-memory list sequence: "
<< HumanReadableNumBytes(list_memory);
return list_sequence;
diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.h b/tensorflow/compiler/xla/service/hlo_scheduling.h
index 068e68383d..fcb006f818 100644
--- a/tensorflow/compiler/xla/service/hlo_scheduling.h
+++ b/tensorflow/compiler/xla/service/hlo_scheduling.h
@@ -55,6 +55,12 @@ StatusOr<std::vector<const HloInstruction*>> DFSMemoryScheduler(
const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function);
+// Naive post-order scheduler that simply returns the instructions in DFS post order.
+StatusOr<std::vector<const HloInstruction*>> PostOrderMemoryScheduler(
+ const HloComputation& computation,
+ const TuplePointsToAnalysis& points_to_analysis,
+ const LogicalBuffer::SizeFunction& size_function);
+
// The default scheduling algorithm. Runs the list scheduler, the DFS
// scheduler, and the post-order scheduler, and chooses whichever returns the
// lowest min-memory, not accounting for fragmentation.
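
For illustration (hypothetical helper, not part of this diff), the new scheduler can be driven exactly like the existing ones, and its peak memory measured the same way DefaultMemoryScheduler now does for all three candidate orderings; MinimumMemoryForComputation is the same helper used in the .cc hunk above:

// Measures the minimum memory needed when running the computation in naive
// post order; useful for comparing against the list and DFS schedules.
StatusOr<int64> PostOrderPeakMemory(
    const HloComputation& computation,
    const TuplePointsToAnalysis& points_to_analysis,
    const LogicalBuffer::SizeFunction& size_function) {
  TF_ASSIGN_OR_RETURN(std::vector<const HloInstruction*> sequence,
                      PostOrderMemoryScheduler(computation, points_to_analysis,
                                               size_function));
  return MinimumMemoryForComputation(computation, sequence, points_to_analysis,
                                     size_function);
}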
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index 8a30cbf9cd..096ebb7946 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -116,7 +116,7 @@ Status ShapeVerifier::HandleOutfeed(HloInstruction* outfeed) {
// produces no HLO value in the graph.
if (!ShapeUtil::Compatible(outfeed->outfeed_shape(),
outfeed->operand(0)->shape())) {
- return InvalidArgument(
+ return InternalError(
"Expected outfeed to have shape compatible with operand's shape %s, "
"actual shape is %s:\n%s",
ShapeUtil::HumanString(outfeed->operand(0)->shape()).c_str(),
@@ -200,7 +200,7 @@ Status ShapeVerifier::HandleTranspose(HloInstruction* transpose) {
transpose->operand(0)->shape(), transpose->dimensions()));
}
-Status ShapeVerifier::HandleParameter(HloInstruction*) {
+Status ShapeVerifier::HandleParameter(HloInstruction* hlo) {
return tensorflow::Status::OK();
}
@@ -410,7 +410,7 @@ Status CheckMixedPrecisionOperands(const HloInstruction* instruction) {
if (fp_type == PRIMITIVE_TYPE_INVALID) {
fp_type = subshape.element_type();
} else if (fp_type != subshape.element_type()) {
- return FailedPrecondition(
+ return InternalError(
"Seen floating point types of different precisions in "
"%s, but mixed precision is disallowed.",
instruction->ToString().c_str());
@@ -490,7 +490,7 @@ Status ShapeVerifier::CheckShape(const HloInstruction* instruction,
}
}
if (!compatible) {
- return InvalidArgument(
+ return InternalError(
"Expected instruction to have shape compatible with %s, actual "
"shape is %s:\n%s",
ShapeUtil::HumanString(inferred_shape).c_str(),
@@ -541,7 +541,7 @@ Status ShapeVerifier::CheckVariadicShape(const HloInstruction* instruction) {
Status ShapeVerifier::CheckSameChannel(const HloInstruction* instr1,
const HloInstruction* instr2) {
if (instr1->channel_id() != instr2->channel_id()) {
- return FailedPrecondition(
+ return InternalError(
"Expected to have the same channel id, actual channel ids are: %s "
"(%lld), %s (%lld)",
instr1->ToString().c_str(), instr1->channel_id(),
@@ -571,22 +571,22 @@ string ComputationsToString(
Status VerifyHloStructure(HloModule* module) {
for (const HloComputation* computation : module->computations()) {
if (computation->parent() == nullptr) {
- return FailedPrecondition("Computation %s has a null parent pointer",
- computation->name().c_str());
+ return InternalError("Computation %s has a null parent pointer",
+ computation->name().c_str());
}
if (computation->parent() != module) {
- return FailedPrecondition(
+ return InternalError(
"Computation %s parent() does not point to parent module",
computation->name().c_str());
}
for (const HloInstruction* instruction : computation->instructions()) {
if (instruction->parent() == nullptr) {
- return FailedPrecondition("Instruction %s has a null parent pointer",
- instruction->name().c_str());
+ return InternalError("Instruction %s has a null parent pointer",
+ instruction->name().c_str());
}
if (instruction->parent() != computation) {
- return FailedPrecondition(
+ return InternalError(
"Instruction %s parent() does not point to parent computation",
instruction->name().c_str());
}
@@ -602,7 +602,7 @@ Status VerifyHloStructure(HloModule* module) {
for (int i = 0; i < instruction->operand_count(); ++i) {
const HloInstruction* operand = instruction->operand(i);
if (operand->parent() != instruction->parent()) {
- return FailedPrecondition(
+ return InternalError(
"Operand %d (%s) of instruction %s is in a different "
"computation: %s vs %s",
i, operand->name().c_str(), instruction->name().c_str(),
@@ -619,7 +619,7 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const {
// The parent fusion instruction of the fusion computation must be 'fusion'.
HloComputation* fused_computation = fusion->fused_instructions_computation();
if (fusion != fused_computation->FusionInstruction()) {
- return FailedPrecondition(
+ return InternalError(
"Instruction of fused computation does not match expected instruction "
"%s.",
fusion->ToString().c_str());
@@ -635,37 +635,37 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const {
for (auto* instruction : fused_computation->instructions()) {
if (fused_root == instruction) {
if (root_owned) {
- return FailedPrecondition("Root appears more than once in %s.",
- fusion->ToString().c_str());
+ return InternalError("Root appears more than once in %s.",
+ fusion->ToString().c_str());
}
root_owned = true;
}
for (int i = 0; i < fused_parameters.size(); ++i) {
if (fused_parameters[i] == instruction) {
if (parameter_owned[i]) {
- return FailedPrecondition("Parameter appears more than once in %s.",
- fusion->ToString().c_str());
+ return InternalError("Parameter appears more than once in %s.",
+ fusion->ToString().c_str());
}
parameter_owned[i] = true;
}
}
}
if (!root_owned) {
- return FailedPrecondition("Root not found in computation of %s.",
- fusion->ToString().c_str());
+ return InternalError("Root not found in computation of %s.",
+ fusion->ToString().c_str());
}
// Make sure all the parameter_owned entries are set
for (int i = 0; i < parameter_owned.size(); i++) {
if (!parameter_owned[i]) {
- return FailedPrecondition("Parameter %d not found in computation of %s.",
- i, fusion->ToString().c_str());
+ return InternalError("Parameter %d not found in computation of %s.", i,
+ fusion->ToString().c_str());
}
}
// Fused root must have no users.
if (fused_root->user_count() != 0) {
- return FailedPrecondition("Root of %s may not have users.",
- fusion->ToString().c_str());
+ return InternalError("Root of %s may not have users.",
+ fusion->ToString().c_str());
}
// All uses of fused instructions must be in the fusion computation, and every
@@ -674,13 +674,13 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const {
fusion->fused_instructions_computation()->instructions()) {
if (instruction != fused_root) {
if (instruction->user_count() == 0) {
- return FailedPrecondition(
- "Non-root instruction %s in %s must have users.",
- instruction->ToString().c_str(), fusion->ToString().c_str());
+ return InternalError("Non-root instruction %s in %s must have users.",
+ instruction->ToString().c_str(),
+ fusion->ToString().c_str());
}
for (auto& user : instruction->users()) {
if (fused_computation != user->parent()) {
- return FailedPrecondition(
+ return InternalError(
"Non-root instruction %s in %s may not have external users.",
instruction->ToString().c_str(), fusion->ToString().c_str());
}
@@ -695,34 +695,33 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const {
for (auto fused_param : fused_parameters) {
int64 param_no = fused_param->parameter_number();
if (param_no < 0) {
- return FailedPrecondition(
- "Unexpected negative parameter number %lld in %s.", param_no,
- fusion->ToString().c_str());
+ return InternalError("Unexpected negative parameter number %lld in %s.",
+ param_no, fusion->ToString().c_str());
}
if (param_no >= fused_parameters.size()) {
- return FailedPrecondition(
+ return InternalError(
"Unexpected parameter number %lld in %s: higher then number of "
"parameters %lu.",
param_no, fusion->ToString().c_str(), fused_parameters.size());
}
if (parameter_numbers[param_no]) {
- return FailedPrecondition(
+ return InternalError(
"Did not expect parameter number %lld more than once in %s.",
param_no, fusion->ToString().c_str());
}
parameter_numbers[param_no] = true;
if (!ShapeUtil::Compatible(fused_param->shape(),
fusion->operand(param_no)->shape())) {
- return FailedPrecondition(
+ return InternalError(
"Shape mismatch between parameter number %lld and its operand in %s.",
param_no, fusion->ToString().c_str());
}
}
- // Make sure all the parameter_numbers entries were seen
+ // Make sure all the parameter_numbers entries were seen.
for (int i = 0; i < parameter_numbers.size(); i++) {
if (!parameter_numbers[i]) {
- return FailedPrecondition("Did not see parameter number %d in %s.", i,
- fusion->ToString().c_str());
+ return InternalError("Did not see parameter number %d in %s.", i,
+ fusion->ToString().c_str());
}
}
diff --git a/tensorflow/compiler/xla/service/human_readable_profile_builder.h b/tensorflow/compiler/xla/service/human_readable_profile_builder.h
index fc24acd271..fb36d3a0d6 100644
--- a/tensorflow/compiler/xla/service/human_readable_profile_builder.h
+++ b/tensorflow/compiler/xla/service/human_readable_profile_builder.h
@@ -32,7 +32,7 @@ class HumanReadableProfileBuilder {
explicit HumanReadableProfileBuilder(tensorflow::StringPiece computation_name,
int64 total_cycles,
double clock_rate_ghz)
- : computation_name_(computation_name.ToString()),
+ : computation_name_(std::string(computation_name)),
total_cycles_(total_cycles),
clock_rate_ghz_(clock_rate_ghz) {
CHECK_GE(clock_rate_ghz, 1e-9);
@@ -47,9 +47,10 @@ class HumanReadableProfileBuilder {
tensorflow::StringPiece category, int64 cycles, int64 flop_count,
int64 transcendental_count, int64 bytes_accessed,
float optimal_seconds) {
- op_infos_.push_back(
- {op_name.ToString(), short_name.ToString(), category.ToString(), cycles,
- flop_count, transcendental_count, bytes_accessed, optimal_seconds});
+ op_infos_.push_back({std::string(op_name), std::string(short_name),
+ std::string(category), cycles, flop_count,
+ transcendental_count, bytes_accessed,
+ optimal_seconds});
}
// Gets the human-readable profile.
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc
index dc1a39e9fa..6bb2ca19fe 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/instruction_fusion.cc
@@ -28,6 +28,25 @@ limitations under the License.
#include "tensorflow/core/platform/logging.h"
namespace xla {
+namespace {
+// These nodes can always be duplicated into consumers, even if
+// InstructionFusion::may_duplicate_ is false.
+//
+// In general these should be nodes that get *cheaper* the more they're
+// duplicated (and fused into consumers).
+//
+// TODO(jlebar): Duplicating instructions when we have a variable called "may
+// duplicate" that's equal to false is not pretty.
+bool IsAlwaysDuplicable(const HloInstruction& instruction) {
+ // We are always willing to duplicate a widening type-conversion instruction
+ // if it means we can fuse the convert into a consumer. This allows the
+ // consumer to read less memory, which is almost always a performance win.
+ return instruction.opcode() == HloOpcode::kConvert &&
+ ShapeUtil::ByteSizeOf(instruction.operand(0)->shape()) <
+ ShapeUtil::ByteSizeOf(instruction.shape());
+}
+} // namespace
+
/*static*/ bool InstructionFusion::IsExpensive(
const HloInstruction& instruction) {
switch (instruction.opcode()) {
@@ -418,9 +437,11 @@ HloInstruction* InstructionFusion::Fuse(HloInstruction* producer,
bool InstructionFusion::ShouldFuse(HloInstruction* consumer,
int64 operand_index) {
HloInstruction* producer = consumer->mutable_operand(operand_index);
+
// Cost condition: don't duplicate expensive instructions.
if (FusionWouldDuplicate(*producer, *consumer) &&
- (is_expensive_(*producer) || !may_duplicate_)) {
+ (!may_duplicate_ || is_expensive_(*producer)) &&
+ !IsAlwaysDuplicable(*producer)) {
return false;
}
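
To make the "widening" criterion concrete, a small sketch (illustrative only, hypothetical helper) of the byte-size comparison IsAlwaysDuplicable performs for an f16 -> f32 convert:

void CheckWideningConvertExample() {
  // A 100-element f16 operand occupies 200 bytes while the f32 result occupies
  // 400 bytes, so such a convert reads less than it writes and is always
  // duplicable; a narrowing f32 -> f16 convert would fail this check.
  Shape narrow = ShapeUtil::MakeShape(F16, {100});
  Shape wide = ShapeUtil::MakeShape(F32, {100});
  CHECK_LT(ShapeUtil::ByteSizeOf(narrow), ShapeUtil::ByteSizeOf(wide));
}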
diff --git a/tensorflow/compiler/xla/service/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/instruction_fusion_test.cc
index e78b99a80c..6dd8fa1ab0 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/instruction_fusion_test.cc
@@ -21,6 +21,8 @@ limitations under the License.
namespace xla {
+namespace op = xla::testing::opcode_matchers;
+
using InstructionFusionTest = HloTestBase;
TEST_F(InstructionFusionTest, PotentialBitcastReshapeOfParameterUnfused) {
@@ -124,7 +126,7 @@ TEST_F(InstructionFusionTest, FuseCheapNonDuplicatableOps) {
EXPECT_EQ(Count(*module, HloOpcode::kFusion), 1) << module->ToString();
// Make sure the add hasn't been duplicated.
- EXPECT_EQ(Count(*module, HloOpcode::kFusion), 1) << module->ToString();
+ EXPECT_EQ(Count(*module, HloOpcode::kAdd), 1) << module->ToString();
}
TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusableRecursively) {
@@ -291,4 +293,29 @@ TEST_F(InstructionFusionTest, AllowEffectiveUnaryDuplication) {
.ValueOrDie());
}
+TEST_F(InstructionFusionTest,
+ WideningConvertsAreAlwaysDuplicableIntoConsumers) {
+ auto module = tools::Parse(R"(
+ HloModule test_module
+ ENTRY Test {
+ p0 = f16[100] parameter(0)
+ c = f32[100] convert(p0)
+ add = f32[100] add(c, c)
+ ROOT mul = f32[100] multiply(c, c)
+ })")
+ .ValueOrDie();
+
+ // The convert should be fused into the add and mul, even though may_duplicate
+ // is false, because it's always beneficial to fuse/duplicate widening
+ // converts into consumers.
+ EXPECT_TRUE(
+ InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/false)
+ .Run(module.get())
+ .ValueOrDie())
+ << module->ToString();
+
+ HloInstruction* root = module->entry_computation()->root_instruction();
+ EXPECT_THAT(root, op::Fusion(op::Parameter()));
+}
+
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/liveness_util.cc b/tensorflow/compiler/xla/service/liveness_util.cc
index 68c99256a2..79dfd1e409 100644
--- a/tensorflow/compiler/xla/service/liveness_util.cc
+++ b/tensorflow/compiler/xla/service/liveness_util.cc
@@ -173,9 +173,9 @@ bool HasUniqueFusedUseOfOperandAt(
// (2) Is a loop fusion instruction where the only use of 'operand' at 'index'
// in the set 'user.fused_instructions' is a DynamicUpdateSlice fused root
// at operand 0. Or...
-// (3) Is a kDot -> kAdd (or fused kTransposeDot -> kAdd) output fusion
-// instruction where the only use of 'operand' at 'index' in the set
-// 'user.fused_instructions' is a kAdd fused root at operand 0 or 1. Or...
+// (3) Is a kDot -> kAdd output fusion instruction where the only use of
+// 'operand' at 'index' in the set 'user.fused_instructions' is a kAdd fused
+// root at operand 0 or 1. Or...
// (4) The 'user' of 'operand' is DynamicUpdateSlice or While at operand index
// 0.
//
@@ -209,17 +209,13 @@ bool CanShareOperandBufferWithUser(
user->fused_expression_root()->opcode() == HloOpcode::kAdd) {
// Output fusion with kAdd fused root.
- // Check if one operand of kAdd fused root is either kDot, or nested
- // kFusion of kind kTransposeDot.
+ // Check if one operand of kAdd fused root is kDot or kConvolution.
auto* add = user->fused_expression_root();
auto add_operand_it =
std::find_if(add->operands().begin(), add->operands().end(),
[&](HloInstruction* operand) {
return operand->opcode() == HloOpcode::kConvolution ||
- operand->opcode() == HloOpcode::kDot ||
- (operand->opcode() == HloOpcode::kFusion &&
- operand->fusion_kind() ==
- HloInstruction::FusionKind::kTransposeDot);
+ operand->opcode() == HloOpcode::kDot;
});
if (add_operand_it == add->operands().end()) {
return false;
@@ -314,17 +310,13 @@ bool CanShareOperandBufferWithUser(HloInstruction* operand,
user->fused_expression_root()->opcode() == HloOpcode::kAdd) {
// Output fusion with kAdd fused root.
- // Check if one operand of kAdd fused root is either kDot, or nested
- // kFusion of kind kTransposeDot.
+ // Check if one operand of kAdd fused root is kDot or kConvolution.
auto* add = user->fused_expression_root();
auto add_operand_it =
std::find_if(add->operands().begin(), add->operands().end(),
[&](HloInstruction* operand) {
return operand->opcode() == HloOpcode::kConvolution ||
- operand->opcode() == HloOpcode::kDot ||
- (operand->opcode() == HloOpcode::kFusion &&
- operand->fusion_kind() ==
- HloInstruction::FusionKind::kTransposeDot);
+ operand->opcode() == HloOpcode::kDot;
});
if (add_operand_it == add->operands().end()) {
return false;
diff --git a/tensorflow/compiler/xla/service/liveness_util_test.cc b/tensorflow/compiler/xla/service/liveness_util_test.cc
index f8b309488e..c01b52df62 100644
--- a/tensorflow/compiler/xla/service/liveness_util_test.cc
+++ b/tensorflow/compiler/xla/service/liveness_util_test.cc
@@ -303,48 +303,6 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) {
*dataflow_analysis_));
}
-TEST_F(CanShareOperandBufferWithUserTest, FusedTransposeDotAdd) {
- auto builder = HloComputation::Builder(TestName());
- Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2});
-
- auto a = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{1.0, 0.0}, {0.0, 1.0}})));
- auto b = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{2.0, 2.0}, {2.0, 2.0}})));
- auto b_t = builder.AddInstruction(
- HloInstruction::CreateTranspose(data_shape, b, {1, 0}));
-
- DotDimensionNumbers dot_dnums;
- dot_dnums.add_lhs_contracting_dimensions(1);
- dot_dnums.add_rhs_contracting_dimensions(0);
- auto dot = builder.AddInstruction(
- HloInstruction::CreateDot(data_shape, a, b_t, dot_dnums));
-
- auto one = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
- auto add_operand = builder.AddInstruction(
- HloInstruction::CreateBroadcast(data_shape, one, {1}));
-
- auto add = builder.AddInstruction(HloInstruction::CreateBinary(
- data_shape, HloOpcode::kAdd, dot, add_operand));
-
- BuildModule(builder.Build());
-
- auto nested_fusion = computation_->CreateFusionInstruction(
- {dot, b_t}, HloInstruction::FusionKind::kTransposeDot);
-
- auto fusion = computation_->CreateFusionInstruction(
- {add, nested_fusion}, HloInstruction::FusionKind::kOutput);
- RunAnalysis();
-
- // Output fused transpose-dot-add should be share buffer with 'add_operand'.
- EXPECT_TRUE(CanShareOperandBufferWithUser(add_operand, {}, fusion, {},
- *points_to_analysis_));
-
- EXPECT_TRUE(CanShareOperandBufferWithUser(add_operand, {}, fusion, {},
- *dataflow_analysis_));
-}
-
TEST_F(CanShareOperandBufferWithUserTest, OutputFusionCantAliasOperandBuffer) {
auto builder = HloComputation::Builder(TestName());
Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2});
diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc
index 3312a88844..7323abeb20 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc
@@ -333,18 +333,7 @@ llvm::Value* IrArray::EmitArrayElementAddress(
}
CHECK_EQ(index.size(), ShapeUtil::Rank(*shape_));
- std::vector<llvm::Value*> actual_index;
- bool is_implicit_broadcast = false;
- // We perform broadcasting when the operand shape has dimension(s) of size
- // 1. In this case we fix the index value for that dimension to zero. This
- // effectively broadcasts along this dimension.
- for (int64 i = 0; i < index.size(); ++i) {
- auto dim = shape_->dimensions(i);
- actual_index.push_back(dim == 1 ? ir_builder->getInt64(0) : index[i]);
- is_implicit_broadcast |= dim == 1;
- }
-
- if (!is_implicit_broadcast && index.LinearValidOnShape(*shape_)) {
+ if (index.LinearValidOnShape(*shape_)) {
llvm::Module* module =
ir_builder->GetInsertBlock()->getParent()->getParent();
return ir_builder->CreateInBoundsGEP(
@@ -354,6 +343,15 @@ llvm::Value* IrArray::EmitArrayElementAddress(
{index.linear()}, llvm_ir::AsStringRef(name));
}
+ std::vector<llvm::Value*> actual_index;
+ for (int64 i = 0; i < index.size(); ++i) {
+ // When dimension i is of size 1, LLVM can often deduce on its own that
+ // index[i] must be 0, but replacing it with a constant 0 here still lets
+ // LLVM produce better code in some cases.
+ auto dim = shape_->dimensions(i);
+ actual_index.push_back(dim == 1 ? ir_builder->getInt64(0) : index[i]);
+ }
+
// "base_ptr_" has the type of "<ir_type_for_its_shape>*"
// (e.g. [3 x [2 x float]]*). Therefore, the address of the indexed element
// should be computed by
diff --git a/tensorflow/compiler/xla/service/name_uniquer.cc b/tensorflow/compiler/xla/service/name_uniquer.cc
index f74bcb0b79..3a6a7c25f4 100644
--- a/tensorflow/compiler/xla/service/name_uniquer.cc
+++ b/tensorflow/compiler/xla/service/name_uniquer.cc
@@ -53,7 +53,7 @@ NameUniquer::NameUniquer(const string& separator) {
}
string NameUniquer::GetUniqueName(tensorflow::StringPiece prefix) {
- string root = GetSanitizedName(prefix.empty() ? "name" : prefix.ToString());
+ string root = GetSanitizedName(prefix.empty() ? "name" : std::string(prefix));
// Strip away numeric suffix (if any). Only recognize separator if it is in
// the middle of the name.
diff --git a/tensorflow/compiler/xla/service/pattern_matcher.h b/tensorflow/compiler/xla/service/pattern_matcher.h
index 586f6ef7a9..d3bc47e61e 100644
--- a/tensorflow/compiler/xla/service/pattern_matcher.h
+++ b/tensorflow/compiler/xla/service/pattern_matcher.h
@@ -702,6 +702,30 @@ class HloInstructionPatternOperandImpl {
HloInstructionPattern<OperandType, OperandImpl> operand_;
};
+// An HloInstructionPattern implementation that matches only if the instruction
+// is a fusion node with a particular kind.
+template <typename Previous>
+class HloInstructionPatternFusionKindImpl {
+ public:
+ explicit constexpr HloInstructionPatternFusionKindImpl(
+ const Previous& previous, ::xla::HloInstruction::FusionKind kind)
+ : previous_(previous), kind_(kind) {}
+
+ bool Match(const ::xla::HloInstruction* inst) const {
+ return previous_.Match(inst) && inst->opcode() == HloOpcode::kFusion &&
+ inst->fusion_kind() == kind_;
+ }
+
+ bool Match(::xla::HloInstruction* inst) const {
+ return previous_.Match(inst) && inst->opcode() == HloOpcode::kFusion &&
+ inst->fusion_kind() == kind_;
+ }
+
+ private:
+ Previous previous_;
+ ::xla::HloInstruction::FusionKind kind_;
+};
+
// A pattern that matches HloInstructions.
template <typename HloInstructionType, typename Impl>
class HloInstructionPattern {
@@ -807,6 +831,16 @@ class HloInstructionPattern {
matched_inst_);
}
+ // Modifies the pattern to match only if the instruction is a fusion node with
+ // the given kind.
+ constexpr HloInstructionPattern<HloInstructionType,
+ HloInstructionPatternFusionKindImpl<Impl>>
+ WithFusionKind(HloInstruction::FusionKind kind) const {
+ return HloInstructionPattern<HloInstructionType,
+ HloInstructionPatternFusionKindImpl<Impl>>(
+ HloInstructionPatternFusionKindImpl<Impl>(impl_, kind), matched_inst_);
+ }
+
private:
Impl impl_;
HloInstructionType** matched_inst_;
diff --git a/tensorflow/compiler/xla/service/pattern_matcher_test.cc b/tensorflow/compiler/xla/service/pattern_matcher_test.cc
index c88157c312..204e8c9920 100644
--- a/tensorflow/compiler/xla/service/pattern_matcher_test.cc
+++ b/tensorflow/compiler/xla/service/pattern_matcher_test.cc
@@ -170,5 +170,28 @@ TEST(PatternMatcherTest, TupleShape) {
Match(&tuple_shape, match::Shape().WithSubshape({0, 0}, match::Shape())));
}
+TEST(PatternMatcherTest, FusionKind) {
+ constexpr char kModuleStr[] = R"(
+ HloModule test_module
+
+ fused_computation {
+ ROOT fp0 = f32[] parameter(0)
+ }
+
+ ENTRY while.v11 {
+ p0 = f32[] parameter(0)
+ ROOT fusion = f32[] fusion(p0), kind=kLoop, calls=fused_computation
+ })";
+ TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, tools::Parse(kModuleStr));
+
+ auto* root = hlo_module->entry_computation()->root_instruction();
+ EXPECT_TRUE(Match(
+ root, match::Op().WithFusionKind(HloInstruction::FusionKind::kLoop)));
+ EXPECT_FALSE(Match(
+ root, match::Op().WithFusionKind(HloInstruction::FusionKind::kInput)));
+ EXPECT_FALSE(Match(root->operand(0), match::Op().WithFusionKind(
+ HloInstruction::FusionKind::kLoop)));
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index 48b2922e77..c493547d9e 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -172,11 +172,11 @@ tensorflow::Status ExpectNotTupleOrOpaque(const Shape& shape,
tensorflow::StringPiece op_type) {
if (ShapeUtil::IsTuple(shape)) {
return InvalidArgument("Expected non-tuple argument for %s, but got %s.",
- op_type.ToString().c_str(),
+ std::string(op_type).c_str(),
ShapeUtil::HumanString(shape).c_str());
} else if (ShapeUtil::IsOpaque(shape)) {
return InvalidArgument("Expected non-opaque argument for %s, but got %s.",
- op_type.ToString().c_str(),
+ std::string(op_type).c_str(),
ShapeUtil::HumanString(shape).c_str());
} else {
return tensorflow::Status::OK();
diff --git a/tensorflow/compiler/xla/service/transpose_folding.cc b/tensorflow/compiler/xla/service/transpose_folding.cc
index 3efd38ce0d..f7a5512fec 100644
--- a/tensorflow/compiler/xla/service/transpose_folding.cc
+++ b/tensorflow/compiler/xla/service/transpose_folding.cc
@@ -35,7 +35,8 @@ TransposeFolding::OperandIndices CanFoldOperandsIntoDot(
const HloInstruction& dot,
const TransposeFolding::TransposableGemmOperandsFn&
transposable_gemm_operands) {
- if (HloOpcode::kDot != dot.opcode()) {
+ if (HloOpcode::kDot != dot.opcode() ||
+ dot.dot_dimension_numbers().lhs_batch_dimensions_size() != 0) {
return {};
}
@@ -44,6 +45,8 @@ TransposeFolding::OperandIndices CanFoldOperandsIntoDot(
auto& operand = *dot.operand(i);
if (operand.IsRank2Transpose()) {
operand_set.push_back(i);
+ } else if (ShapeUtil::Rank(operand.shape()) != 2) {
+ return {};
}
}
@@ -74,23 +77,39 @@ using InstructionOperandsPair =
// Folds the operands of `dot` that are foldable transposes into `dot` itself.
-//
-// Returns whether the module is changed.
-bool FoldTransposeIntoDot(InstructionOperandsPair pair) {
- auto* dot = pair.first;
- std::vector<HloInstruction*> instructions_to_fuse(1, dot);
- for (const int64 operand_index : pair.second) {
- instructions_to_fuse.push_back(dot->mutable_operand(operand_index));
- }
-
- // Early-exit if no operands are foldable.
- if (instructions_to_fuse.size() == 1) {
- return false;
+Status FoldTransposeIntoDot(InstructionOperandsPair pair) {
+ HloInstruction* dot = pair.first;
+
+ DotDimensionNumbers new_dim_numbers = dot->dot_dimension_numbers();
+ HloInstruction* new_lhs = dot->mutable_operand(0);
+ HloInstruction* new_rhs = dot->mutable_operand(1);
+
+ CHECK_EQ(new_dim_numbers.lhs_batch_dimensions_size(), 0);
+ CHECK_EQ(new_dim_numbers.rhs_batch_dimensions_size(), 0);
+ CHECK_EQ(new_dim_numbers.lhs_contracting_dimensions_size(), 1);
+ CHECK_EQ(new_dim_numbers.rhs_contracting_dimensions_size(), 1);
+
+ for (int64 operand_index : pair.second) {
+ // We've checked that there aren't any batch dimensions and that the inputs
+ // are rank 2, and shape inference guarantees that there is exactly one
+ // contracting dimension.
+ if (operand_index == 0) {
+ CHECK_EQ(new_lhs->opcode(), HloOpcode::kTranspose);
+ new_dim_numbers.set_lhs_contracting_dimensions(
+ 0, 1 - new_dim_numbers.lhs_contracting_dimensions(0));
+ new_lhs = new_lhs->mutable_operand(0);
+ } else {
+ CHECK_EQ(operand_index, 1);
+ CHECK_EQ(new_rhs->opcode(), HloOpcode::kTranspose);
+ new_dim_numbers.set_rhs_contracting_dimensions(
+ 0, 1 - new_dim_numbers.rhs_contracting_dimensions(0));
+ new_rhs = new_rhs->mutable_operand(0);
+ }
}
- dot->parent()->CreateFusionInstruction(
- instructions_to_fuse, HloInstruction::FusionKind::kTransposeDot);
- return true;
+ std::unique_ptr<HloInstruction> new_dot = HloInstruction::CreateDot(
+ dot->shape(), new_lhs, new_rhs, new_dim_numbers);
+ return dot->parent()->ReplaceWithNewInstruction(dot, std::move(new_dot));
}
// Folds the operands of `convolution` that are foldable transposes.
@@ -205,7 +224,8 @@ StatusOr<bool> TransposeFolding::Run(HloModule* module) {
bool changed = false;
for (InstructionOperandsPair& pair : foldable_dots) {
- changed |= FoldTransposeIntoDot(pair);
+ TF_RETURN_IF_ERROR(FoldTransposeIntoDot(pair));
+ changed = true;
}
for (InstructionOperandsPair& pair : foldable_convolutions) {
changed |= FoldTransposeIntoConvolution(pair);
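
The `1 - d` rewrite above relies on the operands being rank 2: transposing a rank-2 array swaps its two dimensions, so the single contracting dimension simply toggles. A sketch of the dimension-number update with hypothetical values (compare the FoldDotTranspose test below):

void FoldedDimExample() {
  // dot(x, transpose(y)) contracting x's dim 1 with the transpose's dim 0
  // becomes dot(x, y) contracting y's dim 1 instead.
  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(1);
  dnums.add_rhs_contracting_dimensions(0);
  dnums.set_rhs_contracting_dimensions(
      0, 1 - dnums.rhs_contracting_dimensions(0));
  CHECK_EQ(dnums.rhs_contracting_dimensions(0), 1);
}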
diff --git a/tensorflow/compiler/xla/service/transpose_folding_test.cc b/tensorflow/compiler/xla/service/transpose_folding_test.cc
index 0319109f7f..f73f1227aa 100644
--- a/tensorflow/compiler/xla/service/transpose_folding_test.cc
+++ b/tensorflow/compiler/xla/service/transpose_folding_test.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
@@ -31,9 +32,12 @@ limitations under the License.
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/logging.h"
+namespace op = xla::testing::opcode_matchers;
+
namespace xla {
namespace {
@@ -54,83 +58,102 @@ class TransposeFoldingTest : public HloTestBase {
};
TEST_F(TransposeFoldingTest, FoldDotTranspose) {
- auto builder = HloComputation::Builder("entry_computation");
- HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
- /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {2, 3}),
- /*name=*/"x"));
- HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(
- /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {2, 3}),
- /*name=*/"y"));
- HloInstruction* transpose_y =
- builder.AddInstruction(HloInstruction::CreateTranspose(
- ShapeUtil::MakeShape(F32, {3, 2}), y, {1, 0}));
- DotDimensionNumbers dot_dnums;
- dot_dnums.add_lhs_contracting_dimensions(1);
- dot_dnums.add_rhs_contracting_dimensions(0);
- HloInstruction* dot = builder.AddInstruction(
- HloInstruction::CreateDot(ShapeUtil::MakeShape(F32, {2, 2}), /*lhs=*/x,
- /*rhs=*/transpose_y, dot_dnums));
+ string hlo_string = R"(
+HloModule FoldDotTranspose
+
+ENTRY entry_computation {
+ x = f32[2,3]{1,0} parameter(0)
+ y = f32[2,3]{1,0} parameter(1)
+ transpose = f32[3,2]{1,0} transpose(y), dimensions={1,0}
+ ROOT dot = f32[2,2]{1,0} dot(x, transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ tools::Parse(hlo_string));
- auto module = CreateNewModule("test_module");
- HloComputation* entry_computation =
- module->AddEntryComputation(builder.Build(dot));
FoldTranspose(module.get());
- // Instructions after folding: x, y, and the fusion.
- std::unordered_set<HloInstruction*> instruction_set(
- entry_computation->instructions().begin(),
- entry_computation->instructions().end());
- CHECK_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation.";
- CHECK_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation.";
- CHECK_EQ(1, instruction_set.size())
- << "entry_computation should contain exactly 3 instructions.";
- HloInstruction* fusion = *instruction_set.begin();
- EXPECT_EQ(HloOpcode::kFusion, fusion->opcode());
+ EXPECT_THAT(module->entry_computation()->root_instruction(),
+ op::Dot(op::Parameter(0), op::Parameter(1),
+ /*lhs_contracting_dim=*/1, /*rhs_contracting_dim=*/1));
+}
+
+TEST_F(TransposeFoldingTest, DontFoldTransposeOfBatchDim) {
+ string hlo_string = R"(
+HloModule FoldDotTranspose
- // The fusion instruction should contain two parameters, one transpose and
- // one dot.
- EXPECT_EQ(4, fusion->fused_instruction_count());
+ENTRY entry_computation {
+ x = f32[2,3] parameter(0)
+ y = f32[3,2] parameter(1)
+ transpose = f32[2,3] transpose(y), dimensions={1,0}
+ ROOT dot = f32[2] dot(x, transpose), lhs_batch_dims={0}, rhs_batch_dims={0}, lhs_contracting_dims={1}, rhs_contracting_dims={1}
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ tools::Parse(hlo_string));
+
+ TransposeFolding transpose_folding(
+ [](const HloInstruction& dot,
+ const TransposeFolding::OperandIndices& candidate_operands) {
+ return candidate_operands;
+ },
+ [](const HloInstruction& convolution,
+ const TransposeFolding::OperandIndices& candidate_operands) {
+ return candidate_operands;
+ });
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, transpose_folding.Run(module.get()));
+ EXPECT_FALSE(changed);
+}
+
+TEST_F(TransposeFoldingTest, DontFoldTransposeOfRank1Dot) {
+ string hlo_string = R"(
+HloModule FoldDotTranspose
+
+ENTRY entry_computation {
+ x = f32[3] parameter(0)
+ y = f32[3,2] parameter(1)
+ transpose = f32[2,3] transpose(y), dimensions={1,0}
+ ROOT dot = f32[2] dot(x, transpose), lhs_batch_dims={}, rhs_batch_dims={0}, lhs_contracting_dims={0}, rhs_contracting_dims={1}
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ tools::Parse(hlo_string));
+
+ TransposeFolding transpose_folding(
+ [](const HloInstruction& dot,
+ const TransposeFolding::OperandIndices& candidate_operands) {
+ return candidate_operands;
+ },
+ [](const HloInstruction& convolution,
+ const TransposeFolding::OperandIndices& candidate_operands) {
+ return candidate_operands;
+ });
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, transpose_folding.Run(module.get()));
+ EXPECT_FALSE(changed);
}
TEST_F(TransposeFoldingTest, FoldDotTransposeConstant) {
- auto builder = HloComputation::Builder("entry_computation");
- // 2x1
- HloInstruction* const0 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR2<float>({{1}, {2}})));
- // 3x2
- HloInstruction* const1 =
- builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{1, 2}, {3, 4}, {5, 6}})));
- HloInstruction* transpose0 =
- builder.AddInstruction(HloInstruction::CreateTranspose(
- ShapeUtil::MakeShape(F32, {1, 2}), const0, {1, 0}));
- HloInstruction* transpose1 =
- builder.AddInstruction(HloInstruction::CreateTranspose(
- ShapeUtil::MakeShape(F32, {2, 3}), const1, {1, 0}));
- DotDimensionNumbers dot_dnums;
- dot_dnums.add_lhs_contracting_dimensions(1);
- dot_dnums.add_rhs_contracting_dimensions(0);
- HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
- ShapeUtil::MakeShape(F32, {1, 3}),
- /*lhs=*/transpose0, /*rhs=*/transpose1, dot_dnums));
+ string hlo_string = R"(
+HloModule FoldDotTransposeConstant
+
+ENTRY entry_computation {
+ constant = f32[2,1]{1,0} constant(f32[2,1] { { 1 }, { 2 } })
+ transpose = f32[1,2]{1,0} transpose(constant), dimensions={1,0}
+ constant.1 = f32[3,2]{1,0} constant(f32[3,2] { { 1, 2 }, { 3, 4 }, { 5, 6 } })
+ transpose.1 = f32[2,3]{1,0} transpose(constant.1), dimensions={1,0}
+ ROOT dot = f32[1,3]{1,0} dot(transpose, transpose.1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ tools::Parse(hlo_string));
- auto module = CreateNewModule("test_module");
- HloComputation* entry_computation =
- module->AddEntryComputation(builder.Build(dot));
FoldTranspose(module.get());
- for (auto* instruction : entry_computation->instructions()) {
- if (instruction->opcode() == HloOpcode::kFusion) {
- CHECK_EQ(2, instruction->operand_count());
- EXPECT_EQ(const0, instruction->operand(0));
- EXPECT_EQ(const1, instruction->operand(1));
- }
- }
-
- // The created fusion instruction should contain two parameters, two
- // transposes (one for each parameter) and one dot.
- EXPECT_EQ(5,
- entry_computation->root_instruction()->fused_instruction_count());
+ EXPECT_THAT(module->entry_computation()->root_instruction(),
+ op::Dot(op::Constant(), op::Constant(),
+ /*lhs_contracting_dim=*/0, /*rhs_contracting_dim=*/1));
}
TEST_F(TransposeFoldingTest, FuseDotWithConstantOperands) {
@@ -164,50 +187,32 @@ TEST_F(TransposeFoldingTest, FuseDotWithConstantOperands) {
EXPECT_EQ(6, callee_computation->instruction_count());
}
-TEST_F(TransposeFoldingTest, FoldDotTransposeInWhile) {
- auto builder = HloComputation::Builder("entry_computation");
- HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
- /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {2, 3}),
- /*name=*/"x"));
- HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(
- /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {2, 3}),
- /*name=*/"y"));
- HloInstruction* transpose_y =
- builder.AddInstruction(HloInstruction::CreateTranspose(
- ShapeUtil::MakeShape(F32, {3, 2}), y, {1, 0}));
- DotDimensionNumbers dot_dnums;
- dot_dnums.add_lhs_contracting_dimensions(1);
- dot_dnums.add_rhs_contracting_dimensions(0);
- HloInstruction* dot = builder.AddInstruction(
- HloInstruction::CreateDot(ShapeUtil::MakeShape(F32, {2, 2}), /*lhs=*/x,
- /*rhs=*/transpose_y, dot_dnums));
-
- auto module = CreateNewModule("test_module");
- HloComputation* entry_computation =
- module->AddEntryComputation(builder.Build(dot));
+TEST_F(TransposeFoldingTest, FoldDotTransposeInCall) {
+ string hlo_string = R"(
+HloModule FoldDotTransposeInCall
- HloInstruction* call = module->OutlineExpressionFromComputation(
- {transpose_y, dot}, "outlined", entry_computation);
+callee {
+ name.0 = f32[2,3]{1,0} parameter(0)
+ name.1 = f32[2,3]{1,0} parameter(1)
+ transpose.clone = f32[3,2]{1,0} transpose(name.0), dimensions={1,0}
+ ROOT dot.clone = f32[2,2]{1,0} dot(name.1, transpose.clone), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+ENTRY entry_computation {
+ y = f32[2,3]{1,0} parameter(1)
+ x = f32[2,3]{1,0} parameter(0)
+ ROOT call = f32[2,2]{1,0} call(y, x), to_apply=callee
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ tools::Parse(hlo_string));
FoldTranspose(module.get());
- // Instructions after folding: x, y, and the fusion.
- std::unordered_set<HloInstruction*> instruction_set(
- entry_computation->instructions().begin(),
- entry_computation->instructions().end());
- CHECK_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation.";
- CHECK_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation.";
- CHECK_EQ(1, instruction_set.erase(call))
- << "call is not in entry_computation.";
- CHECK(instruction_set.empty())
- << "entry_computation should contain exactly 3 instructions.";
- HloInstruction* fusion =
- call->called_computations().front()->root_instruction();
- EXPECT_EQ(HloOpcode::kFusion, fusion->opcode());
-
- // The fusion instruction should contain two parameters, one transpose and
- // one dot.
- EXPECT_EQ(4, fusion->fused_instruction_count());
+ const HloComputation* callee = module->GetComputationWithName("callee");
+ ASSERT_NE(callee, nullptr);
+ EXPECT_THAT(callee->root_instruction(),
+ op::Dot(op::Parameter(1), op::Parameter(0),
+ /*lhs_contracting_dim=*/1, /*rhs_contracting_dim=*/1));
}
// Test that a two dimension swap of the kernel gets folded into convolution.
diff --git a/tensorflow/compiler/xla/service_interface.h b/tensorflow/compiler/xla/service_interface.h
index 5b44c26b7c..4f64fe8f83 100644
--- a/tensorflow/compiler/xla/service_interface.h
+++ b/tensorflow/compiler/xla/service_interface.h
@@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_XLA_SERVICE_INTERFACE_H_
#include "tensorflow/compiler/xla/xla.pb.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
namespace xla {
diff --git a/tensorflow/compiler/xla/statusor.h b/tensorflow/compiler/xla/statusor.h
index cccbce5fc8..0e1387c939 100644
--- a/tensorflow/compiler/xla/statusor.h
+++ b/tensorflow/compiler/xla/statusor.h
@@ -13,13 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-// StatusOr<T> is the union of a Status object and a T
-// object. StatusOr models the concept of an object that is either a
-// usable value, or an error Status explaining why such a value is
-// not present. To this end, StatusOr<T> does not allow its Status
-// value to be Status::OK. Furthermore, the value of a StatusOr<T*>
-// must not be null. This is enforced by a debug check in most cases,
-// but even when it is not, clients must not set the value to null.
+// StatusOr<T> is the union of a Status object and a T object. StatusOr models
+// the concept of an object that is either a value, or an error Status
+// explaining why such a value is not present. To this end, StatusOr<T> does not
+// allow its Status value to be Status::OK.
//
// The primary use-case for StatusOr<T> is as the return value of a
// function which may fail.
diff --git a/tensorflow/compiler/xla/statusor_test.cc b/tensorflow/compiler/xla/statusor_test.cc
index f9d25945bc..7d76370e85 100644
--- a/tensorflow/compiler/xla/statusor_test.cc
+++ b/tensorflow/compiler/xla/statusor_test.cc
@@ -75,6 +75,14 @@ TEST(StatusOr, ElementType) {
static_assert(std::is_same<StatusOr<char>::element_type, char>(), "");
}
+TEST(StatusOr, NullPointerStatusOr) {
+  // As a very special case, a StatusOr holding a plain null pointer used to
+  // be an error. Test that it no longer is.
+ StatusOr<int*> null_status(nullptr);
+ EXPECT_TRUE(null_status.ok());
+ EXPECT_EQ(null_status.ValueOrDie(), nullptr);
+}
+
TEST(StatusOr, TestNoDefaultConstructorInitialization) {
// Explicitly initialize it with an error code.
StatusOr<NoDefaultConstructor> statusor(tensorflow::errors::Cancelled(""));
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index 0571ff5055..b982cf0dbc 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -1867,7 +1867,10 @@ xla_test(
xla_test(
name = "local_client_execute_test",
+ # TODO(b/79375911): Test times out in LLVM at normal size.
+ size = "large",
srcs = ["local_client_execute_test.cc"],
+ shard_count = 30,
tags = ["optonly"],
deps = [
"//tensorflow/compiler/xla:literal_util",
diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc
index c09e7eaf2b..41f9a5f666 100644
--- a/tensorflow/compiler/xla/tests/client_library_test_base.cc
+++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc
@@ -565,4 +565,33 @@ XlaOp ClientLibraryTestBase::CreateConstantFromLiteral(const Literal& literal,
use_bfloat16_ ? *LiteralTestUtil::ConvertF32ToBF16(literal) : literal);
}
+std::unique_ptr<GlobalData>
+ClientLibraryTestBase::CreateParameterAndTransferLiteral(int64 parameter_number,
+ const Literal& literal,
+ const string& name,
+ XlaBuilder* builder,
+ XlaOp* data_handle) {
+ return CreateParameterAndTransferLiteral(parameter_number, literal, name,
+ nullptr, builder, data_handle);
+}
+
+std::unique_ptr<GlobalData>
+ClientLibraryTestBase::CreateParameterAndTransferLiteral(
+ int64 parameter_number, const Literal& literal, const string& name,
+ const DeviceHandle* device_handle, XlaBuilder* builder,
+ XlaOp* data_handle) {
+ const Literal* param_literal = &literal;
+ std::unique_ptr<Literal> converted_literal;
+ if (use_bfloat16_) {
+ converted_literal = LiteralTestUtil::ConvertF32ToBF16(literal);
+ param_literal = converted_literal.get();
+ }
+ std::unique_ptr<GlobalData> data =
+ client_->TransferToServer(*param_literal, device_handle)
+ .ConsumeValueOrDie();
+ *data_handle =
+ builder->Parameter(parameter_number, param_literal->shape(), name);
+ return data;
+}
+
} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h
index e58979a303..16e838e60f 100644
--- a/tensorflow/compiler/xla/tests/client_library_test_base.h
+++ b/tensorflow/compiler/xla/tests/client_library_test_base.h
@@ -616,35 +616,6 @@ std::unique_ptr<Array2D<NativeT>> ClientLibraryTestBase::CreatePseudorandomR2(
return result;
}
-std::unique_ptr<GlobalData>
-ClientLibraryTestBase::CreateParameterAndTransferLiteral(int64 parameter_number,
- const Literal& literal,
- const string& name,
- XlaBuilder* builder,
- XlaOp* data_handle) {
- return CreateParameterAndTransferLiteral(parameter_number, literal, name,
- nullptr, builder, data_handle);
-}
-
-std::unique_ptr<GlobalData>
-ClientLibraryTestBase::CreateParameterAndTransferLiteral(
- int64 parameter_number, const Literal& literal, const string& name,
- const DeviceHandle* device_handle, XlaBuilder* builder,
- XlaOp* data_handle) {
- const Literal* param_literal = &literal;
- std::unique_ptr<Literal> converted_literal;
- if (use_bfloat16_) {
- converted_literal = LiteralTestUtil::ConvertF32ToBF16(literal);
- param_literal = converted_literal.get();
- }
- std::unique_ptr<GlobalData> data =
- client_->TransferToServer(*param_literal, device_handle)
- .ConsumeValueOrDie();
- *data_handle =
- builder->Parameter(parameter_number, param_literal->shape(), name);
- return data;
-}
-
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_TESTS_CLIENT_LIBRARY_TEST_BASE_H_
diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc
index 6b3efba4f8..efa5aed2d1 100644
--- a/tensorflow/compiler/xla/tests/dot_operation_test.cc
+++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc
@@ -798,5 +798,250 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64,
this->error_spec_);
}
+TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstRHSClassicMM) {
+ std::unique_ptr<Array2D<float>> constant_lhs_array(new Array2D<float>(
+ {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}}));
+ std::unique_ptr<Array2D<float>> constant_rhs_array(
+ new Array2D<float>({{1.0, 2.0, 3.0},
+ {4.0, 5.0, 6.0},
+ {7.0, 8.0, 9.0},
+ {9.0, 8.0, 7.0},
+ {6.0, 5.0, 4.0},
+ {3.0, 2.0, 1.0}}));
+ // Dot result to slice from: {{114, 105, 96}, {96, 105, 114}}
+
+ XlaBuilder builder(TestName());
+ auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array);
+ auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array);
+ auto start_constant = builder.ConstantR1<int32>({1, 0});
+ auto dynamic_slice =
+ builder.DynamicSlice(lhs_constant, start_constant, {1, 6});
+
+ DotDimensionNumbers dot_dnums;
+ dot_dnums.add_lhs_contracting_dimensions(1);
+ dot_dnums.add_rhs_contracting_dimensions(0);
+ auto result = builder.DotGeneral(dynamic_slice, rhs_constant, dot_dnums);
+
+ Array2D<float> expected({{96.0, 105.0, 114.0}});
+ ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSClassicMM) {
+ std::unique_ptr<Array2D<float>> constant_lhs_array(new Array2D<float>(
+ {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}}));
+ std::unique_ptr<Array2D<float>> constant_rhs_array(
+ new Array2D<float>({{1.0, 2.0, 3.0},
+ {4.0, 5.0, 6.0},
+ {7.0, 8.0, 9.0},
+ {9.0, 8.0, 7.0},
+ {6.0, 5.0, 4.0},
+ {3.0, 2.0, 1.0}}));
+ // Dot result to slice from: {{114, 105, 96}, {96, 105, 114}}
+
+ XlaBuilder builder(TestName());
+ auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array);
+ auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array);
+ auto start_constant = builder.ConstantR1<int32>({0, 1});
+ auto dynamic_slice =
+ builder.DynamicSlice(rhs_constant, start_constant, {6, 1});
+
+ DotDimensionNumbers dot_dnums;
+ dot_dnums.add_lhs_contracting_dimensions(1);
+ dot_dnums.add_rhs_contracting_dimensions(0);
+ auto result = builder.DotGeneral(lhs_constant, dynamic_slice, dot_dnums);
+
+ Array2D<float> expected({{105.0}, {105.0}});
+ ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
+}
+
+// TODO (b/69062148) Enable when Dot implements general contracting dimensions.
+TEST_F(DotOperationTest,
+ DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER(
+ DotOfGatherOptimizationWithConstRHSReverseMM)))) {
+ std::unique_ptr<Array2D<float>> constant_lhs_array(
+ new Array2D<float>({{1.0, 2.0, 3.0},
+ {4.0, 5.0, 6.0},
+ {7.0, 8.0, 9.0},
+ {9.0, 8.0, 7.0},
+ {6.0, 5.0, 4.0},
+ {3.0, 2.0, 1.0}}));
+ std::unique_ptr<Array2D<float>> constant_rhs_array(new Array2D<float>(
+ {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}}));
+ // Dot result to slice from: {{114, 96}, {105, 105}, {96, 114}}
+
+ XlaBuilder builder(TestName());
+ auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array);
+ auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array);
+ auto start_constant = builder.ConstantR1<int32>({0, 1});
+ auto dynamic_slice =
+ builder.DynamicSlice(lhs_constant, start_constant, {6, 1});
+
+ DotDimensionNumbers dot_dnums;
+ dot_dnums.add_lhs_contracting_dimensions(0);
+ dot_dnums.add_rhs_contracting_dimensions(1);
+ auto result = builder.DotGeneral(dynamic_slice, rhs_constant, dot_dnums);
+
+ Array2D<float> expected({{105.0, 105.0}});
+ ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
+}
+
+// TODO (b/69062148) Enable when Dot implements general contracting dimensions.
+TEST_F(DotOperationTest,
+ DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER(
+ DotOfGatherOptimizationWithConstLHSReverseMM)))) {
+ std::unique_ptr<Array2D<float>> constant_lhs_array(
+ new Array2D<float>({{1.0, 2.0, 3.0},
+ {4.0, 5.0, 6.0},
+ {7.0, 8.0, 9.0},
+ {9.0, 8.0, 7.0},
+ {6.0, 5.0, 4.0},
+ {3.0, 2.0, 1.0}}));
+ std::unique_ptr<Array2D<float>> constant_rhs_array(new Array2D<float>(
+ {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}}));
+ // Dot result to slice from: {{114, 96}, {105, 105}, {96, 114}}
+
+ XlaBuilder builder(TestName());
+ auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array);
+ auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array);
+ auto start_constant = builder.ConstantR1<int32>({1, 0});
+ auto dynamic_slice =
+ builder.DynamicSlice(rhs_constant, start_constant, {1, 6});
+
+ DotDimensionNumbers dot_dnums;
+ dot_dnums.add_lhs_contracting_dimensions(0);
+ dot_dnums.add_rhs_contracting_dimensions(1);
+ auto result = builder.DotGeneral(lhs_constant, dynamic_slice, dot_dnums);
+
+ Array2D<float> expected({{96.0}, {105.0}, {114.0}});
+ ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
+}
+
+// TODO (b/69062148) Enable when Dot implements general contracting dimensions.
+TEST_F(DotOperationTest,
+ DISABLED_ON_CPU(DISABLED_ON_GPU(
+ DISABLED_ON_INTERPRETER(DotOfGatherOptimizationWithConstRHSRows)))) {
+ std::unique_ptr<Array2D<float>> constant_lhs_array(
+ new Array2D<float>({{1.0, 2.0},
+ {3.0, 4.0},
+ {5.0, 6.0},
+ {6.0, 5.0},
+ {4.0, 3.0},
+ {2.0, 1.0}}));
+ std::unique_ptr<Array2D<float>> constant_rhs_array(
+ new Array2D<float>({{1.0, 2.0, 3.0},
+ {4.0, 5.0, 6.0},
+ {7.0, 8.0, 9.0},
+ {9.0, 8.0, 7.0},
+ {6.0, 5.0, 4.0},
+ {3.0, 2.0, 1.0}}));
+ // Dot result to slice from: {{132, 129, 126}, {126, 129, 132}}
+
+ XlaBuilder builder(TestName());
+ auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array);
+ auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array);
+ auto start_constant = builder.ConstantR1<int32>({0, 1});
+ auto dynamic_slice =
+ builder.DynamicSlice(lhs_constant, start_constant, {6, 1});
+
+ DotDimensionNumbers dot_dnums;
+ dot_dnums.add_lhs_contracting_dimensions(0);
+ dot_dnums.add_rhs_contracting_dimensions(0);
+ auto result = builder.DotGeneral(dynamic_slice, rhs_constant, dot_dnums);
+
+ Array2D<float> expected({{126.0, 129.0, 132.0}});
+ ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
+}
+
+// TODO (b/69062148) Enable when Dot implements general contracting dimensions.
+TEST_F(DotOperationTest,
+ DISABLED_ON_CPU(DISABLED_ON_GPU(
+ DISABLED_ON_INTERPRETER(DotOfGatherOptimizationWithConstLHSRows)))) {
+ std::unique_ptr<Array2D<float>> constant_lhs_array(
+ new Array2D<float>({{1.0, 2.0},
+ {3.0, 4.0},
+ {5.0, 6.0},
+ {6.0, 5.0},
+ {4.0, 3.0},
+ {2.0, 1.0}}));
+ std::unique_ptr<Array2D<float>> constant_rhs_array(
+ new Array2D<float>({{1.0, 2.0, 3.0},
+ {4.0, 5.0, 6.0},
+ {7.0, 8.0, 9.0},
+ {9.0, 8.0, 7.0},
+ {6.0, 5.0, 4.0},
+ {3.0, 2.0, 1.0}}));
+ // Dot result to slice from: {{132, 129, 126}, {126, 129, 132}}
+
+ XlaBuilder builder(TestName());
+ auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array);
+ auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array);
+ auto start_constant = builder.ConstantR1<int32>({0, 1});
+ auto dynamic_slice =
+ builder.DynamicSlice(rhs_constant, start_constant, {6, 1});
+
+ DotDimensionNumbers dot_dnums;
+ dot_dnums.add_lhs_contracting_dimensions(0);
+ dot_dnums.add_rhs_contracting_dimensions(0);
+ auto result = builder.DotGeneral(lhs_constant, dynamic_slice, dot_dnums);
+
+ Array2D<float> expected({{129.0}, {129.0}});
+ ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
+}
+
+// TODO (b/69062148) Enable when Dot implements general contracting dimensions.
+TEST_F(DotOperationTest,
+ DISABLED_ON_CPU(DISABLED_ON_GPU(
+ DISABLED_ON_INTERPRETER(DotOfGatherOptimizationWithConstRHSCols)))) {
+ std::unique_ptr<Array2D<float>> constant_lhs_array(new Array2D<float>(
+ {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}}));
+ std::unique_ptr<Array2D<float>> constant_rhs_array(
+ new Array2D<float>({{1.0, 2.0, 3.0, 4.0, 5.0, 6.0},
+ {7.0, 8.0, 9.0, 9.0, 8.0, 7.0},
+ {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}}));
+ // Dot result to slice from: {{91, 168, 56}, {56, 168, 91}}
+
+ XlaBuilder builder(TestName());
+ auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array);
+ auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array);
+ auto start_constant = builder.ConstantR1<int32>({1, 0});
+ auto dynamic_slice =
+ builder.DynamicSlice(lhs_constant, start_constant, {1, 6});
+
+ DotDimensionNumbers dot_dnums;
+ dot_dnums.add_lhs_contracting_dimensions(1);
+ dot_dnums.add_rhs_contracting_dimensions(1);
+ auto result = builder.DotGeneral(dynamic_slice, rhs_constant, dot_dnums);
+
+ Array2D<float> expected({{56.0, 168.0, 91.0}});
+ ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
+}
+
+// TODO (b/69062148) Enable when Dot implements general contracting dimensions.
+TEST_F(DotOperationTest,
+ DISABLED_ON_CPU(DISABLED_ON_GPU(
+ DISABLED_ON_INTERPRETER(DotOfGatherOptimizationWithConstLHSCols)))) {
+ std::unique_ptr<Array2D<float>> constant_lhs_array(new Array2D<float>(
+ {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}}));
+ std::unique_ptr<Array2D<float>> constant_rhs_array(
+ new Array2D<float>({{1.0, 2.0, 3.0, 4.0, 5.0, 6.0},
+ {7.0, 8.0, 9.0, 9.0, 8.0, 7.0},
+ {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}}));
+ // Dot result to slice from: {{91, 168, 56}, {56, 168, 91}}
+
+ XlaBuilder builder(TestName());
+ auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array);
+ auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array);
+ auto start_constant = builder.ConstantR1<int32>({1, 0});
+ auto dynamic_slice =
+ builder.DynamicSlice(rhs_constant, start_constant, {1, 6});
+
+ DotDimensionNumbers dot_dnums;
+ dot_dnums.add_lhs_contracting_dimensions(1);
+ dot_dnums.add_rhs_contracting_dimensions(1);
+ auto result = builder.DotGeneral(lhs_constant, dynamic_slice, dot_dnums);
+
+ Array2D<float> expected({{168.0}, {168.0}});
+ ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
+}
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
index 3a945fb3b1..156a06c596 100644
--- a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
+++ b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
@@ -30,6 +30,7 @@ namespace {
using tensorflow::StringPiece;
using tensorflow::gtl::optional;
+using tensorflow::str_util::Join;
using tensorflow::str_util::Split;
using tensorflow::str_util::SplitAndParseAsInts;
using tensorflow::strings::Printf;
@@ -53,7 +54,7 @@ class HloParser {
std::unique_ptr<HloModule> ConsumeHloModule() { return std::move(module_); }
// Returns the error information.
- string GetError() const { return tensorflow::str_util::Join(error_, "\n"); }
+ string GetError() const { return Join(error_, "\n"); }
private:
// ParseXXX returns false if an error occurred.
@@ -245,7 +246,7 @@ bool HloParser::Error(LocTy loc, StringPiece msg) {
error_lines.push_back(std::string(lexer_.GetLine(loc)));
error_lines.push_back(col == 0 ? "" : StrCat(string(col - 1, ' '), "^"));
- error_.push_back(tensorflow::str_util::Join(error_lines, "\n"));
+ error_.push_back(Join(error_lines, "\n"));
VLOG(1) << "Error: " << error_.back();
return false;
}
@@ -439,6 +440,10 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
optional<OpMetadata> metadata;
attrs["metadata"] = {/*required=*/false, AttrTy::kMetadata, &metadata};
+ optional<string> backend_config;
+ attrs["backend_config"] = {/*required=*/false, AttrTy::kString,
+ &backend_config};
+
HloInstruction* instruction;
switch (opcode) {
case HloOpcode::kParameter: {
@@ -1093,8 +1098,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
instruction->set_name(name);
- // Add common attrs (sharding, control predecessors) to the instruction, if
- // they were seen.
+ // Add shared attributes like metadata to the instruction, if they were seen.
if (sharding) {
instruction->set_sharding(
HloSharding::FromProto(sharding.value()).ValueOrDie());
@@ -1111,6 +1115,9 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
if (metadata) {
instruction->set_metadata(*metadata);
}
+ if (backend_config) {
+ instruction->set_backend_config(std::move(*backend_config));
+ }
return AddInstruction(name, instruction, name_loc);
} // NOLINT(readability/fn_size)
@@ -1488,11 +1495,10 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal,
std::vector<int64> elems_seen_until_dim(elems_seen_per_dim.begin(),
elems_seen_per_dim.begin() + dim);
return StrCat("[",
- tensorflow::str_util::Join(
- elems_seen_until_dim, ",",
- [](string* out, const int64& num_elems) {
- tensorflow::strings::StrAppend(out, num_elems - 1);
- }),
+ Join(elems_seen_until_dim, ",",
+ [](string* out, const int64& num_elems) {
+ tensorflow::strings::StrAppend(out, num_elems - 1);
+ }),
"]");
};
do {
@@ -1680,7 +1686,7 @@ bool HloParser::ParseSparseLiteralHelper(std::unique_ptr<Literal>* literal,
return Error(
index_loc,
StrCat("invalid multi-dimension index for shape with rank ", rank,
- ": [", tensorflow::str_util::Join(index, ", "), "]"));
+ ": [", Join(index, ", "), "]"));
}
}
if (!ParseToken(TokKind::kColon,
@@ -1848,7 +1854,19 @@ bool HloParser::ParseAttributeHelper(
}
auto attr_it = attrs.find(name);
if (attr_it == attrs.end()) {
- return Error(loc, Printf("unexpected attribute %s", name.c_str()));
+ string allowed_attrs;
+ if (attrs.empty()) {
+ allowed_attrs = "No attributes are allowed here.";
+ } else {
+ allowed_attrs = StrCat(
+ "Allowed attributes: ",
+ Join(attrs, ", ",
+ [&](string* out, const std::pair<string, AttrConfig>& kv) {
+ StrAppend(out, kv.first);
+ }));
+ }
+ return Error(loc, Printf("unexpected attribute \"%s\". %s", name.c_str(),
+ allowed_attrs.c_str()));
}
AttrTy attr_type = attr_it->second.attr_type;
void* attr_out_ptr = attr_it->second.result;
diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc
index 4e085bc89c..e100d8cda1 100644
--- a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc
@@ -65,7 +65,7 @@ ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
R"(HloModule constant_pred_module
ENTRY %constant_pred () -> pred[] {
- ROOT %constant = pred[] constant(true), metadata={op_type="const" op_name="\"it\'s not a problem\n" source_file="path/to/test.cc" source_line=68}
+ ROOT %constant = pred[] constant(true), metadata={op_type="const" op_name="\"it\'s not a problem\n" source_file="path/to/test.cc" source_line=68}, backend_config="foo\" bar"
}
)"
@@ -81,13 +81,14 @@ ENTRY %constant_s32 () -> s32[] {
)"
},
-// f32 constant, but the value is not a decimal
+// f32 constant, but the value is not a decimal and there is a backend
+// configuration
{
"ConstantF32",
R"(HloModule ConstantF32_module
ENTRY %ConstantF32.v4 () -> f32[] {
- ROOT %constant = f32[] constant(42)
+ ROOT %constant = f32[] constant(42), backend_config="this is a configuration"
}
)"
@@ -1013,6 +1014,19 @@ ENTRY %SelectScalarS32True.v4 () -> s32[] {
// but the constant names will not be exactly the same.
}
+TEST_F(HloParserTest, ConfigurationField) {
+ const string original = R"(HloModule AModule
+ENTRY %configuration_test() -> s32[] {
+ %constant = s32[] constant(42), backend_config="foo bar"
+})";
+ auto result = Parse(original);
+ TF_ASSERT_OK(result.status());
+ EXPECT_EQ("foo bar", result.ValueOrDie()
+ ->entry_computation()
+ ->root_instruction()
+ ->backend_config());
+}
+
TEST_F(HloParserTest, LiteralDimensionsMismatch_1) {
const string original = R"(HloModule some_2_module
@@ -1092,7 +1106,7 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2
%input = f32[1,2,1]{2,1,0} parameter(0)
%copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input)
%filter = f32[1,1,1]{2,1,0} parameter(1)
- ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), sharding={maximal device=1}, dim_labels=b0f_0io->b0f, window={pad=1_1 size=2}
+ ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), sharding={maximal device=1}, backend_config="foo", dim_labels=b0f_0io->b0f, window={pad=1_1 size=2}
}
)";
@@ -1138,7 +1152,7 @@ ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] {
)";
ExpectHasSubstr(Parse(original).status().error_message(),
- "unexpected attribute calls");
+ "unexpected attribute \"calls\"");
}
TEST_F(HloParserTest, MissingAttribute) {
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/BUILD b/tensorflow/contrib/autograph/pyct/static_analysis/BUILD
index 83f3bafc42..8064a967cd 100644
--- a/tensorflow/contrib/autograph/pyct/static_analysis/BUILD
+++ b/tensorflow/contrib/autograph/pyct/static_analysis/BUILD
@@ -19,6 +19,7 @@ py_library(
srcs = [
"activity.py",
"annos.py",
+ "cfg.py",
"live_values.py",
"type_info.py",
],
@@ -44,6 +45,19 @@ py_test(
)
py_test(
+ name = "cfg_test",
+ srcs = ["cfg_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_windows"],
+ deps = [
+ ":static_analysis",
+ "//tensorflow/contrib/autograph/pyct",
+ "//tensorflow/python:client_testlib",
+ "@gast_archive//:gast",
+ ],
+)
+
+py_test(
name = "live_values_test",
srcs = ["live_values_test.py"],
srcs_version = "PY2AND3",
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/cfg.py b/tensorflow/contrib/autograph/pyct/static_analysis/cfg.py
new file mode 100644
index 0000000000..230e4cc0f3
--- /dev/null
+++ b/tensorflow/contrib/autograph/pyct/static_analysis/cfg.py
@@ -0,0 +1,431 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Control flow graph analysis.
+
+Given a Python AST we construct a control flow graph, with edges both to the
+next and to the previous statements, so the graph can easily be walked in both
+directions. Each node contains the AST of its statement. Forward or backward
+dataflow analyses can then be run over this CFG.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from collections import namedtuple
+import functools
+import operator
+
+import gast
+
+from tensorflow.contrib.autograph.pyct import anno
+from tensorflow.contrib.autograph.pyct.static_analysis import activity
+
+
+class CfgNode(object):
+ """A node in the CFG."""
+ __slots__ = ['next', 'value', 'prev']
+
+ def __init__(self, value):
+ self.next = set()
+ self.prev = set()
+ self.value = value
+
+
+class Cfg(namedtuple('Cfg', ['entry', 'exit'])):
+ """A Control Flow Graph.
+
+ Each statement is represented as a node. For control flow statements such
+ as conditionals and loops, the conditional itself is a node which either
+ branches or cycles, respectively.
+
+ Attributes:
+ entry: The entry node, which contains the `gast.arguments` node of the
+ function definition.
+ exit: The exit node. This node is special because it has no value (i.e. no
+ corresponding AST node). This is because Python functions can have
+ multiple return statements.
+ """
+ pass
+
+
+class CfgBuilder(gast.NodeVisitor):
+ """Construct a control flow graph.
+
+ Construct a CFG starting from a FunctionDef node.
+
+ Usage:
+ cfg_obj = CfgBuilder().build_cfg(fndef_node)
+ """
+
+ def __init__(self):
+ # The current leaves of the CFG
+ self.current_leaves = []
+ # TODO(alexbw): generalize to break, return, continue, yield, etc.
+ # A stack of lists, tracking continue statements
+ self.continue_ = []
+ # A stack of lists tracking break nodes
+ self.break_ = []
+
+ def set_current_leaves(self, cfg_node):
+ """Link this cfg_node to the current leaves.
+
+ This is the central function for building the CFG. It links the current
+ head cfg_nodes to the passed cfg_node. It then resets the head to the
+ passed cfg_node.
+
+ Args:
+ cfg_node: A CfgNode instance.
+ """
+ for head in self.current_leaves:
+ head.next.add(cfg_node)
+ # While we're linking the CFG forward, add backlinks
+ cfg_node.prev.add(head)
+ self.current_leaves = [cfg_node]
+
+ def build_cfg(self, node):
+ """Build a CFG for a function.
+
+ Implementation of building a CFG for dataflow analysis. See, e.g.:
+ https://www.seas.harvard.edu/courses/cs252/2011sp/slides/Lec02-Dataflow.pdf
+
+ Args:
+ node: A function definition the body of which to analyze.
+ Returns:
+ A CFG object.
+ Raises:
+ TypeError: If the input is not a function definition.
+ """
+ if not isinstance(node, gast.FunctionDef):
+ raise TypeError('input must be a function definition')
+ entry_cfg_node = CfgNode(node.args)
+ self.current_leaves = [entry_cfg_node]
+ self.visit_statements(node.body)
+ exit_cfg_node = CfgNode(None)
+ self.set_current_leaves(exit_cfg_node)
+ return Cfg(entry_cfg_node, exit_cfg_node)
+
+ def visit_statements(self, nodes):
+ for node in nodes:
+ # Check for control flow
+ if isinstance(node, (gast.For, gast.While, gast.If, gast.Try, gast.Break,
+ gast.Continue, gast.With)):
+ self.visit(node)
+ else:
+ expr = CfgNode(node)
+ self.set_current_leaves(expr)
+
+ def generic_visit(self, node):
+ raise ValueError('unknown control flow')
+
+ def visit_If(self, node):
+ # TODO(alexbw): change this to use immutable tuples instead of lists
+ # The current head will hold the conditional
+ test = CfgNode(node.test)
+ self.set_current_leaves(test)
+ # Handle the body
+ self.visit_statements(node.body)
+ body_exit = self.current_leaves
+ self.current_leaves = []
+ self.current_leaves.append(test)
+ # Handle the orelse
+ self.visit_statements(node.orelse)
+ self.current_leaves.extend(body_exit)
+
+ def visit_While(self, node):
+ test = CfgNode(node.test)
+ self.set_current_leaves(test)
+ # Start a new level of nesting
+ self.break_.append([])
+ self.continue_.append([])
+ # Handle the body
+ self.visit_statements(node.body)
+ self.current_leaves.extend(self.continue_.pop())
+ self.set_current_leaves(test)
+ # Handle the orelse
+ self.visit_statements(node.orelse)
+ # The break statements and the test go to the next node
+ self.current_leaves.extend(self.break_.pop())
+
+ def visit_For(self, node):
+ iter_ = CfgNode(node.iter)
+ self.set_current_leaves(iter_)
+ self.break_.append([])
+ self.continue_.append([])
+ self.visit_statements(node.body)
+ self.current_leaves.extend(self.continue_.pop())
+ self.set_current_leaves(iter_)
+ self.current_leaves.extend(self.break_.pop())
+
+ def visit_Break(self, node):
+ self.break_[-1].extend(self.current_leaves)
+ self.current_leaves[:] = []
+
+ def visit_Continue(self, node):
+ self.continue_[-1].extend(self.current_leaves)
+ self.current_leaves[:] = []
+
+ def visit_Try(self, node):
+ self.visit_statements(node.body)
+ body = self.current_leaves
+ handlers = []
+ for handler in node.handlers:
+ self.current_leaves = body[:]
+ self.visit_statements(handler.body)
+ handlers.extend(self.current_leaves)
+ self.current_leaves = body
+ self.visit_statements(node.orelse)
+ self.current_leaves = handlers + self.current_leaves
+ self.visit_statements(node.finalbody)
+
+ def visit_With(self, node):
+ for item in node.items:
+ self.set_current_leaves(CfgNode(item))
+ self.visit_statements(node.body)
+
+
+# TODO(alexbw): once CFG analysis occurs at a block level,
+# this extra class will not be necessary
+class PropagateAnalysis(gast.NodeVisitor):
+ """Port analysis annotations from statements to their enclosing blocks."""
+
+ def __init__(self, analysis):
+ self.transfer_fn = analysis.transfer_fn
+ self.in_label = analysis.in_label
+ self.out_label = analysis.out_label
+ super(PropagateAnalysis, self).__init__()
+
+ def visit_If(self, node):
+ # Depth-first.
+ self.generic_visit(node)
+ incoming = anno.getanno(node.body[0], self.in_label)
+ incoming |= anno.getanno(node.test, self.in_label)
+ outgoing = anno.getanno(node.body[-1], self.out_label)
+ outgoing |= anno.getanno(node.test, self.out_label)
+ if node.orelse:
+ orelse_outgoing = anno.getanno(node.orelse[-1], self.out_label)
+ outgoing = self.transfer_fn(outgoing, orelse_outgoing)
+ anno.setanno(node, self.in_label, incoming)
+ anno.setanno(node, self.out_label, outgoing)
+
+ def visit_For(self, node):
+ self.generic_visit(node)
+ incoming = set(anno.getanno(node.body[0], self.in_label))
+ incoming -= set((anno.getanno(node.target, anno.Basic.QN),))
+ outgoing = anno.getanno(node.body[-1], self.out_label)
+ if node.orelse:
+ orelse_outgoing = anno.getanno(node.orelse[-1], self.out_label)
+ outgoing = self.transfer_fn(outgoing, orelse_outgoing)
+ anno.setanno(node, self.in_label, frozenset(incoming))
+ anno.setanno(node, self.out_label, outgoing)
+
+ def visit_While(self, node):
+ self.generic_visit(node)
+ incoming = anno.getanno(node.body[0], self.in_label)
+ incoming |= anno.getanno(node.test, self.in_label)
+ outgoing = anno.getanno(node.body[-1], self.out_label)
+ if node.orelse:
+ orelse_outgoing = anno.getanno(node.orelse[-1], self.out_label)
+ outgoing = self.transfer_fn(outgoing, orelse_outgoing)
+ anno.setanno(node, self.in_label, incoming)
+ anno.setanno(node, self.out_label, outgoing)
+
+ def visit_With(self, node):
+ self.generic_visit(node)
+ incoming = anno.getanno(node.body[0], self.in_label)
+ for item in node.items:
+ incoming |= anno.getanno(item, self.in_label)
+ outgoing = anno.getanno(node.body[-1], self.out_label)
+ anno.setanno(node, self.in_label, incoming)
+ anno.setanno(node, self.out_label, outgoing)
+
+
+# TODO(alexbw): Abstract the CFG walking machinery into a superclass
+# which is parameterized on which fields it selects when walking.
+# TODO(alexbw): Abstract the application of dataflow analysis
+class Forward(object):
+ """Forward analysis on CFG.
+
+ Args:
+ label: A name for this analysis, e.g. 'active' for activity analysis. The
+ AST nodes in the CFG will be given annotations '<label>_in', '<label>_out',
+ '<label>_gen' and '<label>_kill', which contain the incoming values, the
+ outgoing values, the values generated by the statement, and the values
+ deleted by the statement, respectively.
+ transfer_fn: Either the AND or the OR operator. With AND the analysis is a
+ forward must analysis (a value is carried forward only if it appears on
+ all incoming paths); with OR it is a forward may analysis (the union of
+ the incoming values is taken).
+ """
+
+ def __init__(self, label, context, transfer_fn=operator.or_):
+ self.transfer_fn = transfer_fn
+ self.context = context
+ self.out_label = label + '_out'
+ self.in_label = label + '_in'
+ self.gen_label = label + '_gen'
+ self.kill_label = label + '_kill'
+
+ # TODO(alexbw): see if we can simplify by visiting breadth-first
+ def visit(self, node):
+ """Depth-first walking the CFG, applying dataflow information propagtion."""
+ # node.value is None only for the exit CfgNode.
+ if not node.value:
+ return
+
+ if anno.hasanno(node.value, self.out_label):
+ before = hash(anno.getanno(node.value, self.out_label))
+ else:
+ before = None
+ preds = [
+ anno.getanno(pred.value, self.out_label)
+ for pred in node.prev
+ if anno.hasanno(pred.value, self.out_label)
+ ]
+ if preds:
+ incoming = functools.reduce(self.transfer_fn, preds[1:], preds[0])
+ else:
+ incoming = frozenset()
+ anno.setanno(node.value, self.in_label, incoming)
+ gen, kill = self.get_gen_kill(node, incoming)
+ anno.setanno(node.value, self.gen_label, gen)
+ anno.setanno(node.value, self.kill_label, kill)
+ anno.setanno(node.value, self.out_label, (incoming - kill) | gen)
+
+ if hash(anno.getanno(node.value, self.out_label)) != before:
+ for succ in node.next:
+ self.visit(succ)
+
+ def get_gen_kill(self, cfg_node, incoming):
+ """Calculate Gen and Kill properties of a CFG node in dataflow analysis.
+
+ A function which takes the CFG node as well as a set of incoming
+ values. It must return the set of values newly generated by the statement
+ as well as the set of values deleted (killed) by it.
+
+ Args:
+ cfg_node: A CfgNode instance.
+ incoming: The set of values flowing into this node.
+ """
+ raise NotImplementedError()
+
+
+class Backward(Forward):
+ """Backward analysis on CFG."""
+
+ def visit(self, cfg_node):
+ # cfg_node.value is None for the exit node, which will be visited only once
+ if not cfg_node.value:
+ for pred in cfg_node.prev:
+ self.visit(pred)
+ return
+
+ if anno.hasanno(cfg_node.value, self.in_label):
+ before = hash(anno.getanno(cfg_node.value, self.in_label))
+ else:
+ before = None
+ succs = [
+ anno.getanno(succ.value, self.in_label)
+ for succ in cfg_node.next
+ if anno.hasanno(succ.value, self.in_label)
+ ]
+ if succs:
+ incoming = functools.reduce(self.transfer_fn, succs[1:], succs[0])
+ else:
+ incoming = frozenset()
+ anno.setanno(cfg_node.value, self.out_label, incoming)
+ gen, kill = self.get_gen_kill(cfg_node, incoming)
+ anno.setanno(cfg_node.value, self.gen_label, gen)
+ anno.setanno(cfg_node.value, self.kill_label, kill)
+ anno.setanno(cfg_node.value, self.in_label, (incoming - kill) | gen)
+ if hash(anno.getanno(cfg_node.value, self.in_label)) != before:
+ for pred in cfg_node.prev:
+ self.visit(pred)
+
+
+def run_analyses(node, analyses):
+ """Perform dataflow analysis on all functions within an AST.
+
+ Args:
+ node: An AST node on which to run dataflow analysis.
+ analyses: Either an instance of the Forward or Backward dataflow analysis
+ class, or a list or tuple of them.
+
+ Returns:
+ node: The node, but now with annotations on the AST nodes containing the
+ results of the dataflow analyses.
+ """
+ if not isinstance(analyses, (tuple, list)):
+ analyses = (analyses,)
+ for analysis in analyses:
+ if not isinstance(analysis, (Forward, Backward)):
+ raise TypeError('not a valid forward analysis object')
+
+ for child_node in gast.walk(node):
+ if isinstance(child_node, gast.FunctionDef):
+ cfg_obj = CfgBuilder().build_cfg(child_node)
+ for analysis in analyses:
+ if isinstance(analysis, Backward):
+ analysis.visit(cfg_obj.exit)
+ elif isinstance(analysis, Forward):
+ analysis.visit(cfg_obj.entry)
+ for analysis in analyses:
+ PropagateAnalysis(analysis).visit(node)
+ return node
+
+
+class Liveness(Backward):
+ """Perform a liveness analysis.
+
+ Each statement is annotated with a set of variables that may be used
+ later in the program.
+ """
+
+ def __init__(self, context):
+ super(Liveness, self).__init__('live', context)
+
+ def get_gen_kill(self, node, _):
+ gen = activity.get_read(node.value, self.context)
+ kill = activity.get_updated(node.value, self.context)
+ return gen, kill
+
+
+class ReachingDefinitions(Forward):
+ """Perform reaching definition analysis.
+
+ Each statement is annotated with a set of (variable, definition) pairs.
+ """
+
+ def __init__(self, context):
+ super(ReachingDefinitions, self).__init__('definitions', context)
+
+ def get_gen_kill(self, node, incoming):
+ definitions = activity.get_updated(node.value, self.context)
+ gen = frozenset((id_, node.value) for id_ in definitions)
+ kill = frozenset(def_ for def_ in incoming if def_[0] in definitions)
+ return gen, kill
+
+
+class Defined(Forward):
+ """Perform defined variable analysis.
+
+ Each statement is annotated with a set of variables which are guaranteed to
+ be defined at that point.
+ """
+
+ def __init__(self, context):
+ super(Defined, self).__init__('defined', context, transfer_fn=operator.and_)
+
+ def get_gen_kill(self, node, _):
+ gen = activity.get_updated(node.value, self.context)
+ return gen, frozenset()
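
The new cfg module above is self-contained enough to exercise directly. Below is a minimal sketch (not part of the commit; the toy function and its source string are illustrative) that builds a CFG for a small function with CfgBuilder and walks it from the entry node. Running the bundled analyses through run_analyses additionally requires an autograph EntityContext, as cfg_test.py below demonstrates.

import gast

from tensorflow.contrib.autograph.pyct.static_analysis import cfg

# gast.parse returns a Module; its first body element is the FunctionDef
# that CfgBuilder.build_cfg expects.
source = 'def f(x):\n  y = x + 1\n  return y\n'
fn_node = gast.parse(source).body[0]

cfg_obj = cfg.CfgBuilder().build_cfg(fn_node)

# The entry node wraps the function's arguments; every other node wraps one
# statement, with edges kept in both directions (next and prev). The exit
# node has value None.
node = cfg_obj.entry
while node is not None:
  print(type(node.value).__name__ if node.value is not None else '<exit>')
  node = next(iter(node.next)) if node.next else None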
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/cfg_test.py b/tensorflow/contrib/autograph/pyct/static_analysis/cfg_test.py
new file mode 100644
index 0000000000..af7eaf30e8
--- /dev/null
+++ b/tensorflow/contrib/autograph/pyct/static_analysis/cfg_test.py
@@ -0,0 +1,252 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for cfg module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import functools
+
+import gast
+
+from tensorflow.contrib.autograph.pyct import anno
+from tensorflow.contrib.autograph.pyct import context
+from tensorflow.contrib.autograph.pyct import parser
+from tensorflow.contrib.autograph.pyct import qual_names
+from tensorflow.contrib.autograph.pyct.static_analysis import cfg
+from tensorflow.python.platform import test
+
+
+class CFGTest(test.TestCase):
+
+ def _parse_and_analyze(self, test_fn, namespace, arg_types=None):
+ arg_types = arg_types or {}
+ node, source = parser.parse_entity(test_fn)
+ ctx = context.EntityContext(
+ namer=None,
+ source_code=source,
+ source_file=None,
+ namespace=namespace,
+ arg_values=None,
+ arg_types=arg_types,
+ owner_type=None,
+ recursive=True)
+ node = qual_names.resolve(node)
+ return node, ctx
+
+ def _check_anno_matches(self, node, anno_name, var_names):
+ if isinstance(var_names, str):
+ var_names = (var_names,)
+ qual_vars = set()
+ for var_name in var_names:
+ if isinstance(var_name, str):
+ if '[' in var_name or ']' in var_name:
+ raise ValueError('Annotation matching not supported with subscript.')
+ if '.' not in var_name:
+ qual_vars.add(qual_names.QN(var_name))
+ else:
+ attrs = var_name.split('.')
+ this_qn = functools.reduce(qual_names.QN, attrs[1:],
+ qual_names.QN(attrs[0]))
+ qual_vars.add(this_qn)
+ self.assertEqual(anno.getanno(node, anno_name), qual_vars)
+
+ def test_reaching(self):
+
+ def f(x):
+ print(x)
+ while True:
+ x = x
+ x = x
+ return x
+
+ node, ctx = self._parse_and_analyze(f, {})
+ cfg.run_analyses(node, cfg.ReachingDefinitions(ctx))
+ body = node.body[0].body
+ # Only the argument reaches the expression
+ def_in = anno.getanno(body[0], 'definitions_in')
+ # One element, x, from arguments
+ self.assertEqual(set(type(d[1]) for d in def_in), set((gast.arguments,)))
+
+ while_body = body[1].body
+ def_in = anno.getanno(while_body[0], 'definitions_in')
+ # One definition, two possible sources.
+ # - One from an assignment (if the loop is entered)
+ # - The other from the arguments (if loop is not entered)
+ self.assertEqual(
+ set(type(d[1]) for d in def_in), set((gast.arguments, gast.Assign)))
+
+ def_in = anno.getanno(while_body[1], 'definitions_in')
+ # If we've reached this line, the only reaching definition of x is the
+ # Assign node in the previous line.
+ self.assertEqual(set(type(d[1]) for d in def_in), set((gast.Assign,)))
+
+ def_in = anno.getanno(body[2], 'definitions_in')
+ # Same situation as while_body[0]
+ self.assertEqual(
+ set(type(d[1]) for d in def_in), set((gast.arguments, gast.Assign)))
+
+ def test_defined(self):
+
+ def f(x):
+ if x:
+ y = 2 # pylint: disable=unused-variable
+ return x
+
+ node, ctx = self._parse_and_analyze(f, {})
+ cfg.run_analyses(node, cfg.Defined(ctx))
+ body = node.body[0].body
+ # Only x is guaranteed to be defined at the end.
+ self._check_anno_matches(body[1], 'defined_in', 'x')
+ # At the end of the if body, both x and y are defined.
+ if_body = body[0].body
+ self._check_anno_matches(if_body[0], 'defined_out', ('x', 'y'))
+
+ # TODO(alexbw): b/73926938 split this test up
+ def test_live(self):
+
+ def get_live_annotated_fnbody(f):
+ node, ctx = self._parse_and_analyze(f, {})
+ cfg.run_analyses(node, cfg.Liveness(ctx))
+ body = node.body[0].body
+ return body
+
+ def f1(x):
+ a = g(x) # pylint: disable=undefined-variable
+ b = h(a) # pylint: disable=undefined-variable, unused-variable
+ return x
+
+ def f2(x, a): # pylint: disable=unused-argument
+ if a > 0: # x should not be live
+ x = 0
+ if a > 1:
+ x = 1
+ else:
+ x = 2
+
+ def f3(x, a):
+ if a > 0: # x and a should be live
+ x = 0
+ if a > 1: # x and a should be live_in
+ x = 1
+ return x # x should be live
+
+ def f4(x, a):
+ if a > 0: # x should be live
+ x = 0
+ x += 1
+
+ def f5(x, a):
+ if a > 0: # x.y should be live
+ x.y = 0
+ return x.y
+
+ def f6(x):
+ return x # should this cause x.* to be live?
+
+ def f7(x, n):
+ for i in range(n):
+ x += i
+ return x
+
+ def f8(x, f):
+ with f:
+ x += 1
+
+ body = get_live_annotated_fnbody(f1)
+ self._check_anno_matches(body[1], 'live_in', ('a', 'h', 'x'))
+ self._check_anno_matches(body[2], 'live_in', ('x'))
+ self._check_anno_matches(body[0], 'live_in', ('g', 'h', 'x'))
+ self._check_anno_matches(body[2], 'live_out', ())
+
+ body = get_live_annotated_fnbody(f2)
+ self._check_anno_matches(body[0], 'live_in', ('a'))
+ self._check_anno_matches(body[1], 'live_in', ('a'))
+
+ body = get_live_annotated_fnbody(f3)
+ self._check_anno_matches(body[0], 'live_in', ('a', 'x'))
+ self._check_anno_matches(body[1], 'live_in', ('a', 'x'))
+ self._check_anno_matches(body[2], 'live_in', ('x'))
+
+ body = get_live_annotated_fnbody(f4)
+ self._check_anno_matches(body[0], 'live_in', ('x', 'a'))
+ self._check_anno_matches(body[1], 'live_in', ('x'))
+
+ body = get_live_annotated_fnbody(f5)
+ self._check_anno_matches(body[0], 'live_in', ('x', 'x.y', 'a'))
+
+ body = get_live_annotated_fnbody(f6)
+ self._check_anno_matches(body[0], 'live_in', ('x'))
+
+ body = get_live_annotated_fnbody(f7)
+ self._check_anno_matches(body[0], 'live_in', ('x', 'n', 'range'))
+ self._check_anno_matches(body[1], 'live_in', ('x'))
+
+ body = get_live_annotated_fnbody(f8)
+ self._check_anno_matches(body[0], 'live_in', ('f', 'x'))
+
+ def test_node_equality(self):
+ node_a = gast.parse('y = x').body[0]
+ node_b = gast.parse('y = x').body[0]
+ self.assertNotEqual(node_a, node_b)
+
+ def test_nested_functions_defined(self):
+
+ def f(x):
+ y = x * 2
+
+ def g(z):
+ return z + y
+
+ return g(x)
+
+ node, ctx = self._parse_and_analyze(f, {})
+ cfg.run_analyses(node, cfg.Defined(ctx))
+
+ body = node.body[0].body
+ self.assertEqual(
+ anno.getanno(body[2], 'defined_in'),
+ frozenset(map(qual_names.QN, ('g', 'x', 'y'))))
+
+ # TODO(alexbw): CFG analysis doesn't currently cross FunctionDef boundaries.
+ # NOTE: 'z' is easy to find, but 'y' is not identified as
+ # defined, because CFG analysis is applied to each function separately.
+ # fndef_body = body[1].body
+ # self.assertEqual(
+ # anno.getanno(fndef_body[0], 'defined_in'),
+ # frozenset(map(qual_names.QN, ('z', 'y'))))
+
+ def test_nested_functions_dont_leak_definitions(self):
+
+ def f(x):
+ print(x)
+
+ def g():
+ y = 2
+ return y
+
+ return g() # y is not defined here
+
+ node, ctx = self._parse_and_analyze(f, {})
+ cfg.run_analyses(node, cfg.Defined(ctx))
+ body = node.body[0].body
+ self.assertEqual(
+ anno.getanno(body[2], 'defined_in'),
+ frozenset(map(qual_names.QN, ('x', 'g'))))
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py
index 9d6cc9245a..f06b73c00d 100644
--- a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py
+++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py
@@ -501,11 +501,18 @@ def sparse_make_stats_update(
example_partition_ids)
# Compute aggregate stats for each partition.
+ # Since unsorted_segment_sum can be numerically unstable, use a 64-bit
+ # operation.
+ gradients64 = math_ops.cast(gradients, dtypes.float64)
+ hessians64 = math_ops.cast(hessians, dtypes.float64)
per_partition_gradients = math_ops.unsorted_segment_sum(
- gradients, mapped_partitions, array_ops.size(unique_partitions))
+ gradients64, mapped_partitions, array_ops.size(unique_partitions))
per_partition_hessians = math_ops.unsorted_segment_sum(
- hessians, mapped_partitions, array_ops.size(unique_partitions))
-
+ hessians64, mapped_partitions, array_ops.size(unique_partitions))
+ per_partition_gradients = math_ops.cast(per_partition_gradients,
+ dtypes.float32)
+ per_partition_hessians = math_ops.cast(per_partition_hessians,
+ dtypes.float32)
# Prepend a bias feature per partition that accumulates the stats for all
# examples in that partition.
bias_feature_ids = array_ops.fill(
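
The cast to float64 and back is a common trick for taming accumulation error in large segment sums. As a standalone sketch of the same pattern (assuming TensorFlow 1.x graph mode; the values below are made up to show the cancellation):

import tensorflow as tf

values = tf.constant([1e-3, 1e8, -1e8, 2e-3], dtype=tf.float32)
segment_ids = tf.constant([0, 0, 0, 1])

# Summing directly in float32 loses the 1e-3 term next to 1e8; accumulating
# in float64 and casting back preserves it.
values64 = tf.cast(values, tf.float64)
sums64 = tf.unsorted_segment_sum(values64, segment_ids, num_segments=2)
sums = tf.cast(sums64, tf.float32)

with tf.Session() as sess:
  print(sess.run(sums))  # approximately [0.001, 0.002]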
diff --git a/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py b/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py
index 1b184d296b..50cc00afdc 100644
--- a/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py
+++ b/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py
@@ -187,7 +187,7 @@ class QuantileAccumulator(saver.BaseSaverBuilder.SaveableObject):
stamp_token: Expected current token.
next_stamp_token: Next value for the token.
Returns:
- A list of quantiles or approximate boundaries.
+ The flush operation.
"""
return gen_quantile_ops.quantile_accumulator_flush(
quantile_accumulator_handle=self._quantile_accumulator_handle,
diff --git a/tensorflow/contrib/checkpoint/__init__.py b/tensorflow/contrib/checkpoint/__init__.py
index d2c30f1215..e529b25b3c 100644
--- a/tensorflow/contrib/checkpoint/__init__.py
+++ b/tensorflow/contrib/checkpoint/__init__.py
@@ -19,6 +19,7 @@ For creating and managing dependencies:
@@CheckpointableObjectGraph
@@dot_graph_from_checkpoint
@@object_metadata
+@@NoDependency
@@split_dependency
"""
@@ -29,6 +30,7 @@ from __future__ import print_function
from tensorflow.contrib.checkpoint.python.split_dependency import split_dependency
from tensorflow.contrib.checkpoint.python.visualize import dot_graph_from_checkpoint
from tensorflow.core.protobuf.checkpointable_object_graph_pb2 import CheckpointableObjectGraph
+from tensorflow.python.training.checkpointable import NoDependency
from tensorflow.python.training.checkpointable_utils import object_metadata
from tensorflow.python.util.all_util import remove_undocumented
diff --git a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
index 6588fd04ac..2568b899d7 100644
--- a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
@@ -427,7 +427,9 @@ class BatchDatasetTest(test.TestCase):
self.assertEqual([None], dataset.output_shapes[1][0].as_list())
self.assertEqual([None, 30], dataset.output_shapes[1][1].as_list())
- def _testMapAndBatchDatasetHelper(self, num_parallel_batches=1):
+ def _testMapAndBatchDatasetHelper(self,
+ num_parallel_calls=None,
+ num_parallel_batches=None):
"""Test a dataset that maps a TF function across its input elements."""
# The pipeline is TensorSliceDataset ->
# RepeatDataset(count) -> MapAndBatchDataset(square_3, batch_size).
@@ -446,6 +448,7 @@ class BatchDatasetTest(test.TestCase):
batching.map_and_batch(
map_func=_map_fn,
batch_size=batch_size,
+ num_parallel_calls=num_parallel_calls,
num_parallel_batches=num_parallel_batches))
.make_initializable_iterator())
init_op = iterator.initializer
@@ -497,12 +500,18 @@ class BatchDatasetTest(test.TestCase):
with self.assertRaises(errors.InvalidArgumentError):
sess.run(init_op, feed_dict={count: 14, batch_size: 0})
- def testMapAndBatchDataset(self):
+ def testMapAndBatch(self):
return self._testMapAndBatchDatasetHelper()
- def testMapAndBatchDatasetWithParallelBatching(self):
+ def testMapAndBatchWithParallelBatches(self):
return self._testMapAndBatchDatasetHelper(num_parallel_batches=10)
+ def testMapAndBatchWithSequentialCalls(self):
+ return self._testMapAndBatchDatasetHelper(num_parallel_calls=1)
+
+ def testMapAndBatchWithParallelCalls(self):
+ return self._testMapAndBatchDatasetHelper(num_parallel_calls=2)
+
def _testMapAndBatchPartialBatchHelper(self, drop_remainder=False):
iterator = (
dataset_ops.Dataset.range(10).apply(
@@ -682,7 +691,7 @@ class UnbatchDatasetSerializationTest(
class MapAndBatchDatasetSerializationTest(
dataset_serialization_test_base.DatasetSerializationTestBase):
- def testSerializationCore(self):
+ def testNumParallelBatches(self):
range_size = 11
num_repeats = 2
batch_size = 5
@@ -709,6 +718,33 @@ class MapAndBatchDatasetSerializationTest(
self.run_core_tests(lambda: build_ds(10, True), lambda: build_ds(15, True),
num_outputs_drop_remainder)
+ def testNumParallelCalls(self):
+ range_size = 11
+ num_repeats = 2
+ batch_size = 5
+ total_outputs = range_size * num_repeats
+ num_outputs_drop_remainder = total_outputs // batch_size
+ num_outputs_keep_remainder = int(math.ceil(total_outputs / batch_size))
+ num_parallel_calls = 7
+
+ def build_ds(range_start, drop_remainder=False):
+
+ def _map_fn(x):
+ return math_ops.square(x)
+
+ return dataset_ops.Dataset.range(
+ range_start, range_start + range_size).repeat(num_repeats).apply(
+ batching.map_and_batch(
+ map_func=_map_fn,
+ batch_size=batch_size,
+ num_parallel_calls=num_parallel_calls,
+ drop_remainder=drop_remainder))
+
+ self.run_core_tests(lambda: build_ds(10), lambda: build_ds(15),
+ num_outputs_keep_remainder)
+ self.run_core_tests(lambda: build_ds(10, True), lambda: build_ds(15, True),
+ num_outputs_drop_remainder)
+
class PaddedBatchDatasetSerializationTest(
dataset_serialization_test_base.DatasetSerializationTestBase):
diff --git a/tensorflow/contrib/data/python/ops/batching.py b/tensorflow/contrib/data/python/ops/batching.py
index 42ec2b0b01..b9393de4e9 100644
--- a/tensorflow/contrib/data/python/ops/batching.py
+++ b/tensorflow/contrib/data/python/ops/batching.py
@@ -466,14 +466,14 @@ def assert_element_shape(expected_shapes):
class _MapAndBatchDataset(dataset_ops.MapDataset):
"""A `Dataset` that maps a function over a batch of elements."""
- def __init__(self, input_dataset, map_func, batch_size, num_parallel_batches,
+ def __init__(self, input_dataset, map_func, batch_size, num_parallel_calls,
drop_remainder):
"""See `Dataset.map()` for details."""
super(_MapAndBatchDataset, self).__init__(input_dataset, map_func)
self._batch_size_t = ops.convert_to_tensor(
batch_size, dtype=dtypes.int64, name="batch_size")
- self._num_parallel_batches_t = ops.convert_to_tensor(
- num_parallel_batches, dtype=dtypes.int64, name="num_parallel_batches")
+ self._num_parallel_calls_t = ops.convert_to_tensor(
+ num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls")
self._drop_remainder_t = ops.convert_to_tensor(
drop_remainder, dtype=dtypes.bool, name="drop_remainder")
@@ -483,12 +483,12 @@ class _MapAndBatchDataset(dataset_ops.MapDataset):
def _as_variant_tensor(self):
# pylint: disable=protected-access
input_resource = self._input_dataset._as_variant_tensor()
- return gen_dataset_ops.map_and_batch_dataset(
+ return gen_dataset_ops.map_and_batch_dataset_v2(
input_resource,
self._map_func.captured_inputs,
f=self._map_func,
batch_size=self._batch_size_t,
- num_parallel_batches=self._num_parallel_batches_t,
+ num_parallel_calls=self._num_parallel_calls_t,
drop_remainder=self._drop_remainder_t,
output_types=nest.flatten(
sparse.as_dense_types(self.output_types, self.output_classes)),
@@ -511,8 +511,9 @@ class _MapAndBatchDataset(dataset_ops.MapDataset):
def map_and_batch(map_func,
batch_size,
- num_parallel_batches=1,
- drop_remainder=False):
+ num_parallel_batches=None,
+ drop_remainder=False,
+ num_parallel_calls=None):
"""Fused implementation of `map` and `batch`.
Maps `map_func` across `batch_size` consecutive elements of this dataset
@@ -528,21 +529,37 @@ def map_and_batch(map_func,
nested structure of tensors.
batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
consecutive elements of this dataset to combine in a single batch.
- num_parallel_batches: A `tf.int64` scalar `tf.Tensor`, representing the
- number of batches to create in parallel. On one hand, higher values can
- help mitigate the effect of stragglers. On the other hand, higher values
- can increase contention if CPU is scarce.
- drop_remainder: A `tf.bool` scalar `tf.Tensor`, representing whether the
- last batch should be dropped in case its size is smaller than desired;
- the default behavior is not to drop the smaller batch.
+ num_parallel_batches: (Optional.) A `tf.int64` scalar `tf.Tensor`,
+ representing the number of batches to create in parallel. On one hand,
+ higher values can help mitigate the effect of stragglers. On the other
+ hand, higher values can increase contention if CPU is scarce.
+ drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing
+ whether the last batch should be dropped in case its size is smaller than
+ desired; the default behavior is not to drop the smaller batch.
+ num_parallel_calls: (Optional.) A `tf.int32` scalar `tf.Tensor`,
+ representing the number of elements to process in parallel. If not
+ specified, `batch_size * num_parallel_batches` elements will be
+ processed in parallel.
Returns:
A `Dataset` transformation function, which can be passed to
@{tf.data.Dataset.apply}.
+
+ Raises:
+ ValueError: If both `num_parallel_batches` and `num_parallel_calls` are
+ specified.
"""
+ if num_parallel_batches is None and num_parallel_calls is None:
+ num_parallel_calls = batch_size
+ elif num_parallel_batches is not None and num_parallel_calls is None:
+ num_parallel_calls = batch_size * num_parallel_batches
+ elif num_parallel_batches is not None and num_parallel_calls is not None:
+ raise ValueError("The `num_parallel_batches` and `num_parallel_calls` "
+ "arguments are mutually exclusive.")
+
def _apply_fn(dataset):
return _MapAndBatchDataset(dataset, map_func, batch_size,
- num_parallel_batches, drop_remainder)
+ num_parallel_calls, drop_remainder)
return _apply_fn
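The hunk above re-parameterizes `map_and_batch` around `num_parallel_calls` while keeping `num_parallel_batches` as a mutually exclusive legacy argument. A minimal usage sketch of the new surface, assuming an illustrative dataset and map function that are not part of this commit:
```python
# Hedged sketch of the new map_and_batch surface introduced above; the dataset
# and map function here are illustrative placeholders.
import tensorflow as tf
from tensorflow.contrib.data.python.ops import batching

dataset = tf.data.Dataset.range(100)

# New style: parallelism expressed in elements processed concurrently.
ds_new = dataset.apply(
    batching.map_and_batch(lambda x: x * x, batch_size=10,
                           num_parallel_calls=4))

# Legacy style: parallelism expressed in batches; internally converted to
# batch_size * num_parallel_batches parallel calls.
ds_old = dataset.apply(
    batching.map_and_batch(lambda x: x * x, batch_size=10,
                           num_parallel_batches=2))

# Supplying both arguments raises the ValueError added in this change.
```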
diff --git a/tensorflow/contrib/distribute/python/combinations.py b/tensorflow/contrib/distribute/python/combinations.py
index 946310aa6f..45d191127e 100644
--- a/tensorflow/contrib/distribute/python/combinations.py
+++ b/tensorflow/contrib/distribute/python/combinations.py
@@ -265,6 +265,10 @@ class NamedDistribution(object):
one_device_strategy = NamedDistribution(
"OneDeviceCPU", one_device_strategy.OneDeviceStrategy("/cpu:0"),
None)
+tpu_strategy_single_iteration = NamedDistribution(
+ "TPUSingleIteration",
+ tpu_strategy.TPUStrategy(iterations_per_step=1),
+ required_tpu=True)
tpu_strategy = NamedDistribution(
"TPU", tpu_strategy.TPUStrategy(), required_tpu=True)
mirrored_strategy_with_gpu_and_cpu = NamedDistribution(
diff --git a/tensorflow/contrib/distribute/python/examples/simple_tfkeras_example.py b/tensorflow/contrib/distribute/python/examples/simple_tfkeras_example.py
index b87224251c..2b05884b9b 100644
--- a/tensorflow/contrib/distribute/python/examples/simple_tfkeras_example.py
+++ b/tensorflow/contrib/distribute/python/examples/simple_tfkeras_example.py
@@ -12,11 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""An example tf.keras model that is trained using MirroredStrategy."""
+"""An example of training tf.keras Model using MirroredStrategy."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from sys import argv
+
+import sys
+
import numpy as np
import tensorflow as tf
@@ -33,30 +35,37 @@ def input_fn():
def main(args):
if len(args) < 2:
- print('You must specify model_dir for checkpoints such as'
- ' /tmp/tfkeras_example./')
+ print('You must specify model_dir for checkpoints such as'
+ ' /tmp/tfkeras_example/.')
return
- print('Using %s to store checkpoints.' % args[1])
-
- strategy = tf.contrib.distribute.MirroredStrategy(
- ['/device:GPU:0', '/device:GPU:1'])
- config = tf.estimator.RunConfig(train_distribute=strategy)
- optimizer = tf.train.GradientDescentOptimizer(0.2)
+ model_dir = args[1]
+ print('Using %s to store checkpoints.' % model_dir)
+ # Define tf.keras Model.
model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(16, activation='relu', input_shape=(10,)))
model.add(tf.keras.layers.Dense(1, activation='sigmoid'))
+ # Compile tf.keras Model.
+ optimizer = tf.train.GradientDescentOptimizer(0.2)
model.compile(loss='binary_crossentropy', optimizer=optimizer)
model.summary()
tf.keras.backend.set_learning_phase(True)
+
+ # Define a DistributionStrategy and convert the tf.keras Model to a
+ # tf.Estimator that utilizes the DistributionStrategy.
+ strategy = tf.contrib.distribute.MirroredStrategy(
+ ['/device:GPU:0', '/device:GPU:1'])
+ config = tf.estimator.RunConfig(train_distribute=strategy)
keras_estimator = tf.keras.estimator.model_to_estimator(
- keras_model=model, config=config, model_dir=args[1])
+ keras_model=model, config=config, model_dir=model_dir)
+ # Train and evaluate the tf.Estimator.
keras_estimator.train(input_fn=input_fn, steps=10)
eval_result = keras_estimator.evaluate(input_fn=input_fn)
print('Eval result: {}'.format(eval_result))
+
if __name__ == '__main__':
- tf.app.run(argv=argv)
+ tf.app.run(argv=sys.argv)
diff --git a/tensorflow/contrib/distribute/python/minimize_loss_test.py b/tensorflow/contrib/distribute/python/minimize_loss_test.py
index e134fe34e1..d2054715f1 100644
--- a/tensorflow/contrib/distribute/python/minimize_loss_test.py
+++ b/tensorflow/contrib/distribute/python/minimize_loss_test.py
@@ -44,13 +44,16 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
combinations.distributions_and_v1_optimizers(),
combinations.combine(mode=["graph"], use_callable_loss=[True, False])
+ combinations.combine(mode=["eager"], use_callable_loss=[True]),
- combinations.combine(is_tpu=[False])) +
- combinations.combine(
- distribution=[combinations.tpu_strategy],
- optimizer_fn=[combinations.adam_optimizer_v1_fn],
- mode=["graph"],
- use_callable_loss=[False],
- is_tpu=[True]))
+ combinations.combine(is_tpu=[False])) + combinations.combine(
+ distribution=[combinations.tpu_strategy],
+ optimizer_fn=[
+ combinations.adam_optimizer_v1_fn,
+ # TODO(isaprykin): Make Adam v2 work with while_loops
+ # and TPUs.
+ ],
+ mode=["graph"],
+ use_callable_loss=[False],
+ is_tpu=[True]))
def testTrainNetwork(self, distribution, optimizer_fn, use_callable_loss,
is_tpu):
with distribution.scope():
@@ -101,7 +104,8 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
distribution=[combinations.tpu_strategy],
optimizer_fn=[
combinations.adam_optimizer_v1_fn,
- combinations.gradient_descent_optimizer_v1_fn
+ combinations.gradient_descent_optimizer_v1_fn,
+ combinations.gradient_descent_optimizer_v2_fn,
],
mode=["graph"],
is_tpu=[True]))
@@ -171,13 +175,28 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
set(created_variables))
@combinations.generate(
- combinations.times(combinations.distributions_and_v1_optimizers(),
- combinations.combine(
- mode=["graph", "eager"],
- momentum=[0.8, 0.9, 0.99],
- renorm=[False, True])))
+ combinations.times(
+ combinations.combine(momentum=[0.8, 0.9, 0.99], renorm=[False, True]),
+ combinations.times(
+ combinations.distributions_and_v1_optimizers(),
+ combinations.combine(
+ mode=["graph", "eager"],
+ is_tpu=[False],
+ # TODO(isaprykin): Allow False here. Currently subsequent
+ # towers will re-execute UPDATE_OPS of previous towers.
+ update_ops_in_cross_tower_mode=[True])) +
+ combinations.combine(
+ distribution=[combinations.tpu_strategy_single_iteration],
+ optimizer_fn=[
+ combinations.gradient_descent_optimizer_v1_fn,
+ combinations.gradient_descent_optimizer_v2_fn
+ ],
+ mode=["graph"],
+ is_tpu=[True],
+ update_ops_in_cross_tower_mode=[False])))
def testTrainNetworkWithBatchNorm(self, distribution, optimizer_fn, momentum,
- renorm):
+ renorm, is_tpu,
+ update_ops_in_cross_tower_mode):
"""Verifies that moving mean updates are reduced across towers."""
with distribution.scope():
num_towers = len(distribution.worker_devices)
@@ -185,7 +204,8 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
optimizer_fn,
batch_per_epoch=num_towers,
momentum=momentum,
- renorm=renorm)
+ renorm=renorm,
+ update_ops_in_tower_mode=not update_ops_in_cross_tower_mode)
      # Disable prefetching since that makes the input on each device
      # non-deterministic, and this test relies on specific input being
@@ -196,16 +216,18 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
dataset_fn).make_one_shot_iterator()
def run_step():
- return control_flow_ops.group(
- distribution.unwrap(
- distribution.call_for_each_tower(
- model_fn,
- iterator.get_next(),
- run_concurrently=batchnorm.built)) +
- ops.get_collection(ops.GraphKeys.UPDATE_OPS))
+ fetches = distribution.unwrap(
+ distribution.call_for_each_tower(
+ model_fn, iterator.get_next(),
+ run_concurrently=batchnorm.built))
+ if update_ops_in_cross_tower_mode:
+ fetches += ops.get_collection(ops.GraphKeys.UPDATE_OPS)
+ return control_flow_ops.group(fetches)
if not context.executing_eagerly():
with self.test_session() as sess:
+ if is_tpu:
+ sess.run(tpu.initialize_system())
run_step = sess.make_callable(run_step())
self.evaluate(variables_lib.global_variables_initializer())
@@ -229,22 +251,40 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
expected_moving_mean - averaged_batch_mean(i)) * (1.0 - momentum))
self.assertNear(expected_moving_means[i], moving_means[i], 0.0001)
+ if is_tpu:
+ with self.test_session() as sess:
+ sess.run(tpu.shutdown_system())
+
@combinations.generate(
combinations.times(
combinations.combine(
- distribution=[combinations.one_device_strategy,
- combinations.mirrored_strategy_with_gpu_and_cpu,
- combinations.mirrored_strategy_with_two_gpus],
- optimizer_fn=[combinations.gradient_descent_optimizer_v1_fn,
- combinations.gradient_descent_optimizer_v2_fn],
- loss_reduction=[losses_impl.Reduction.SUM,
- losses_impl.Reduction.MEAN,
- losses_impl.Reduction.SUM_OVER_BATCH_SIZE,
- losses_impl.Reduction.SUM_OVER_NONZERO_WEIGHTS]),
- combinations.combine(mode=["graph"], use_callable_loss=[True, False])
- + combinations.combine(mode=["eager"], use_callable_loss=[True])))
+ optimizer_fn=[
+ combinations.gradient_descent_optimizer_v1_fn,
+ combinations.gradient_descent_optimizer_v2_fn
+ ],
+ loss_reduction=[
+ losses_impl.Reduction.SUM, losses_impl.Reduction.MEAN,
+ losses_impl.Reduction.SUM_OVER_BATCH_SIZE,
+ losses_impl.Reduction.SUM_OVER_NONZERO_WEIGHTS
+ ]),
+ combinations.times(
+ combinations.combine(
+ distribution=[
+ combinations.one_device_strategy,
+ combinations.mirrored_strategy_with_gpu_and_cpu,
+ combinations.mirrored_strategy_with_two_gpus
+ ],
+ is_tpu=[False]),
+ combinations.combine(
+ mode=["graph"], use_callable_loss=[True, False]) +
+ combinations.combine(mode=["eager"], use_callable_loss=[True])) +
+ combinations.combine(
+ distribution=[combinations.tpu_strategy_single_iteration],
+ is_tpu=[True],
+ mode=["graph"],
+ use_callable_loss=[True, False])))
def testMeanVsSum(self, distribution, optimizer_fn, loss_reduction,
- use_callable_loss):
+ use_callable_loss, is_tpu):
with distribution.scope():
all_vars = []
@@ -280,12 +320,13 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
if not context.executing_eagerly():
with self.test_session() as sess:
+ if is_tpu:
+ sess.run(tpu.initialize_system())
run_step = sess.make_callable(run_step())
self.evaluate(variables_lib.global_variables_initializer())
run_step()
- self.assertEqual(distribution.num_towers, len(all_vars))
v = all_vars[0]
self.assertTrue(all([v is vi for vi in all_vars[1:]]))
weight = numpy.squeeze(self.evaluate(distribution.fetch(v)))
@@ -312,6 +353,10 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
# One of the mean loss reductions.
self.assertNear(weight, 2 + 10.6, 0.0001)
+ if is_tpu:
+ with self.test_session() as sess:
+ sess.run(tpu.shutdown_system())
+
if __name__ == "__main__":
test.main()
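The TPU test variants added above wrap graph execution in explicit TPU system setup and teardown. A small sketch of that bracketing pattern, assuming a TPU-backed session target and a placeholder computation rather than the tests' actual models:
```python
# Hedged sketch of the initialize/shutdown bracketing used by the is_tpu cases;
# running it requires a session connected to a TPU worker.
import tensorflow as tf
from tensorflow.contrib import tpu

with tf.Graph().as_default():
  result = tf.constant(1.0) + tf.constant(2.0)  # placeholder computation
  with tf.Session() as sess:
    sess.run(tpu.initialize_system())  # bring the TPU system up first
    sess.run(result)
    sess.run(tpu.shutdown_system())    # tear it down once the test is done
```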
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
index 6c5c055070..3635bd2e34 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
@@ -370,22 +370,27 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
expected_sum = 0.0
expected_mean = 0.0
for i, d in enumerate(dist.worker_devices):
- # Test access within a device scope, should see different values.
- with ops.device(d):
- v_sum_value = self.evaluate(ret_v_sum.read_value())
- v_mean_value = self.evaluate(ret_v_mean.read_value())
- expected = i + 3.0
- self.assertEqual(expected, v_sum_value)
- expected_sum += expected
- expected = i * 6.0
- self.assertEqual(expected, v_mean_value)
- expected_mean += expected
-
- # fetch() should return the value you get by applying the
- # reduction across all towers.
- self.assertEqual(expected_sum, self.evaluate(dist.fetch(ret_v_sum)))
+ # Should see different values on different devices.
+ v_sum_value = self.evaluate(ret_v_sum.get(d).read_value())
+ v_mean_value = self.evaluate(ret_v_mean.get(d).read_value())
+ expected = i + 3.0
+ self.assertEqual(expected, v_sum_value)
+ expected_sum += expected
+ expected = i * 6.0
+ self.assertEqual(expected, v_mean_value)
+ expected_mean += expected
expected_mean /= len(dist.worker_devices)
+
+ # Without get(device), should return the value you get by
+ # applying the reduction across all towers (whether you use
+ # fetch(), get(), or nothing).
+ self.assertEqual(expected_sum, self.evaluate(dist.fetch(ret_v_sum)))
self.assertEqual(expected_mean, self.evaluate(dist.fetch(ret_v_mean)))
+ self.assertEqual(expected_sum, self.evaluate(ret_v_sum.get()))
+ self.assertEqual(expected_mean, self.evaluate(ret_v_mean.get()))
+ if not context.executing_eagerly():
+ self.assertEqual(expected_sum, self.evaluate(ret_v_sum))
+ self.assertEqual(expected_mean, self.evaluate(ret_v_mean))
# NOTE(priyag): Names and name scopes are ignored in eager, hence we are not
# testing this in eager mode.
diff --git a/tensorflow/contrib/distribute/python/single_loss_example.py b/tensorflow/contrib/distribute/python/single_loss_example.py
index 0db0b59fca..d1fdb3279c 100644
--- a/tensorflow/contrib/distribute/python/single_loss_example.py
+++ b/tensorflow/contrib/distribute/python/single_loss_example.py
@@ -22,6 +22,7 @@ from tensorflow.contrib.data.python.ops import batching
from tensorflow.contrib.distribute.python import step_fn
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import ops
from tensorflow.python.layers import core
from tensorflow.python.layers import normalization
from tensorflow.python.ops import array_ops
@@ -59,7 +60,7 @@ def minimize_loss_example(optimizer_fn,
# TODO(isaprykin): map_and_batch with drop_remainder causes shapes to be
# fully defined for TPU. Remove this when XLA supports dynamic shapes.
return dataset.apply(
- batching.map_and_batch(lambda x: x, batch_size=2, drop_remainder=True))
+ batching.map_and_batch(lambda x: x, batch_size=1, drop_remainder=True))
# An Optimizer instance is created either outside or inside model_fn.
outer_optimizer = None
@@ -68,11 +69,10 @@ def minimize_loss_example(optimizer_fn,
layer = core.Dense(1, use_bias=use_bias)
- def model_fn(xs):
+ def model_fn(x):
"""A very simple model written by the user."""
def loss_fn():
- x = math_ops.reduce_mean(xs, keepdims=True)
y = array_ops.reshape(layer(x), []) - constant_op.constant(1.)
return y * y
@@ -89,7 +89,8 @@ def minimize_loss_example(optimizer_fn,
def batchnorm_example(optimizer_fn,
batch_per_epoch=1,
momentum=0.9,
- renorm=False):
+ renorm=False,
+ update_ops_in_tower_mode=False):
"""Example of non-distribution-aware legacy code with batch normalization."""
def dataset_fn():
@@ -103,12 +104,19 @@ def batchnorm_example(optimizer_fn,
optimizer = optimizer_fn()
batchnorm = normalization.BatchNormalization(
renorm=renorm, momentum=momentum, fused=False)
+ layer = core.Dense(1, use_bias=False)
def model_fn(x):
+ """A model that uses batchnorm."""
def loss_fn():
- y = math_ops.reduce_sum(batchnorm(x, training=True), axis=1)
- loss = math_ops.reduce_mean(y - constant_op.constant(1.))
+ y = batchnorm(x, training=True)
+ with ops.control_dependencies(
+ ops.get_collection(ops.GraphKeys.UPDATE_OPS)
+ if update_ops_in_tower_mode else []):
+ loss = math_ops.reduce_mean(
+ math_ops.reduce_sum(layer(y)) - constant_op.constant(1.))
+ # `x` and `y` will be fetched by the gradient computation, but not `loss`.
return loss
# Callable loss.
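The `batchnorm_example` change above makes the loss depend on `UPDATE_OPS` via `control_dependencies` when updates should run in tower mode. A short sketch of that general graph-mode pattern, using a standalone batch-norm layer rather than the example's model:
```python
# Hedged sketch of the control-dependency pattern: batch-norm moving statistics
# live in UPDATE_OPS and only run if something depends on them.
import tensorflow as tf

x = tf.random_normal([8, 4])
y = tf.layers.batch_normalization(x, training=True)  # registers UPDATE_OPS
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
  loss = tf.reduce_mean(tf.reduce_sum(y))

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(loss)  # evaluating loss now also updates moving_mean/variance
```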
diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py
index a7e4fe80f3..75441786a6 100644
--- a/tensorflow/contrib/distribute/python/tpu_strategy.py
+++ b/tensorflow/contrib/distribute/python/tpu_strategy.py
@@ -33,7 +33,6 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.util import nest
-# TODO(isaprykin): Consider whether inheriting is really appropriate.
class TPUStrategy(one_device_strategy.OneDeviceStrategy):
"""Experimental TPU distribution strategy implementation."""
@@ -73,7 +72,6 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
def infeed_input(i):
"""Get input, split it and then enqueue."""
iteration_inputs = [f.get(i) for f in feeds()]
-
infeed_inputs = [[inputs_per_core[core_id]
for inputs_per_core in iteration_inputs]
for core_id in range(self._num_cores_per_host)]
@@ -117,3 +115,14 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
iterate_on_tpu, [], num_shards=self._num_cores_per_host)
return control_flow_ops.group(tpu_result, enqueue_ops)
+
+ def _reduce(self, method_string, value, destinations):
+ del destinations # TPU is graph mode only. Rely on implicit Send/Recv.
+ if method_string == 'mean':
+ # TODO(jhseu): Revisit once we support model-parallelism.
+ value *= (1. / self._num_cores_per_host)
+ return tpu_ops.cross_replica_sum(value)
+
+ @property
+ def num_towers(self):
+ return self._num_cores_per_host
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py
index aaf177d07e..759f3c3599 100644
--- a/tensorflow/contrib/distribute/python/values.py
+++ b/tensorflow/contrib/distribute/python/values.py
@@ -34,6 +34,7 @@ from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
from tensorflow.python.training import checkpointable
from tensorflow.python.training import device_util
from tensorflow.python.training import distribute as distribute_lib
@@ -60,7 +61,7 @@ class DistributedValues(object):
else:
device = distribute_lib.get_update_device()
if device is None:
- device = device_util.current()
+ return self._get_cross_tower()
device = device_util.canonicalize(device)
try:
return self._index[device]
@@ -231,12 +232,6 @@ class DistributedVariable(DistributedDelegate):
self._primary_var.op.type)
return self.get().op
- def _as_graph_element(self):
- # pylint: disable=protected-access
- if distribute_lib.get_cross_tower_context():
- return self._primary_var._as_graph_element()
- return self.get()._as_graph_element()
-
def _should_act_as_resource_variable(self):
"""Pass resource_variable_ops.is_resource_variable check."""
pass
@@ -320,6 +315,18 @@ class MirroredVariable(DistributedVariable, Mirrored,
def assign(self, *args, **kwargs):
return self.get(device=_get_update_device()).assign(*args, **kwargs)
+ def _get_cross_tower(self):
+ device = device_util.canonicalize(device_util.current())
+ if device in self._index:
+ return array_ops.identity(self._index[device])
+ return array_ops.identity(self._primary_var)
+
+ def _as_graph_element(self):
+ # pylint: disable=protected-access
+ if distribute_lib.get_cross_tower_context():
+ return self._primary_var._as_graph_element()
+ return self.get()._as_graph_element()
+
def _gather_saveables_for_checkpoint(self):
"""Overrides CheckpointableBase method.
@@ -364,6 +371,12 @@ class _TowerLocalSaveable(saver.BaseSaverBuilder.SaveableObject):
for d, v in six.iteritems(self._tower_local_variable._index)]) # pylint: disable=protected-access
+def _assert_tower_context():
+ if not distribute_lib.get_tower_context():
+ raise RuntimeError(
+ "Tower-local variables may only be assigned in a tower context.")
+
+
class TowerLocalVariable(DistributedVariable, PerDevice,
checkpointable.CheckpointableBase):
"""Holds a map from device to variables whose values are reduced on save."""
@@ -374,18 +387,35 @@ class TowerLocalVariable(DistributedVariable, PerDevice,
super(TowerLocalVariable, self).__init__(index)
def assign_sub(self, *args, **kwargs):
+ _assert_tower_context()
return self.get().assign_sub(*args, **kwargs)
def assign_add(self, *args, **kwargs):
+ _assert_tower_context()
return self.get().assign_add(*args, **kwargs)
def assign(self, *args, **kwargs):
+ _assert_tower_context()
return self.get().assign(*args, **kwargs)
@property
def reduce_method(self):
return self._reduce_method
+ def _get_cross_tower(self):
+ all_components = tuple(self._index.values())
+ # TODO(josh11b): Use a strategy-specific method.
+ total = math_ops.add_n(all_components)
+ if self._reduce_method == "mean":
+ return total * (1./ len(all_components))
+ return total
+
+ def _as_graph_element(self):
+ # pylint: disable=protected-access
+ if distribute_lib.get_cross_tower_context():
+ return self._get_cross_tower()
+ return self.get()._as_graph_element()
+
def _gather_saveables_for_checkpoint(self):
"""Overrides CheckpointableBase method.
@@ -672,11 +702,12 @@ class MultiWorkerDataset(object):
return MultiWorkerDataIterator(iterators, self._worker_device_map)
-class PerIteration(object):
- """Holds input for multiple iterations at once."""
+class _PerKey(object):
+  """Holds data indexed by key."""
- def __init__(self, index):
- self._index = index
+ def __init__(self, *index):
+ # pylint: disable=protected-access
+ self._index = list(index)
def get(self, iteration):
return array_ops.gather(self._index, iteration)
@@ -687,6 +718,24 @@ class PerIteration(object):
def get_dtype(self):
return self._index[-1][-1].dtype
+ def __str__(self):
+ return "%s:%s" % (self.__class__.__name__, self._index)
+
+ def __repr__(self):
+ return "%s(%r)" % (self.__class__.__name__, self._index)
+
+
+class PerIteration(_PerKey):
+ """Holds input for multiple iterations at once."""
+
+ def __init__(self, *index):
+ # pylint: disable=protected-access
+ super(PerIteration, self).__init__(*[batch._index for batch in index])
+
+
+class Batches(_PerKey):
+ pass
+
class MultiIterator(object):
"""Iterator that returns results of multiple get_next()s."""
@@ -697,11 +746,31 @@ class MultiIterator(object):
self._batches_per_iteration = batches_per_iteration
def get_next(self, name=None):
- return PerIteration([[
- self._dataset_iterator.get_next(name=name)
- for _ in range(self._batches_per_iteration)
- ]
- for _ in range(self._iterations)])
+ """Return PerIteration with `iterations x batches_per_iteration` inputs."""
+ data = []
+ for _ in range(self._batches_per_iteration):
+ batch = []
+ for _ in range(self._iterations):
+ batch.append(self._dataset_iterator.get_next(name=name))
+ data.append(batch)
+
+ # Here is an example. Suppose each get_next returns a tuple of two tensors.
+ # For 3 `iterations` and 2 `batches_per_iteration`, the `data` is:
+ # [[(a,z), (b,y), (c,x)], [(A,Z), (B,Y), (C,X)]]
+ #
+ # After the first `map_structure` it gets transformed to:
+ # [(Batches(a, A), Batches(z, Z)),
+ # (Batches(b, B), Batches(y, Y)),
+ # (Batches(c, C), Batches(x, X))]
+ #
+ # After the second `map_structure` it gets transformed to a tuple of:
+ # (PerIteration([Batches(a, A), Batches(b, B), Batches(c, C)]),
+ # PerIteration([Batches(z, Z), Batches(y, Y), Batches(x, X)]))
+
+ data = nest.map_structure(Batches, *data)
+ data = nest.map_structure(PerIteration, *data)
+
+ return data
@property
def initializer(self):
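The rewritten `MultiIterator.get_next` above collects a `batches_per_iteration x iterations` grid and transposes it with two `nest.map_structure` passes, as its inline comment walks through. The following plain-Python sketch (no TensorFlow; `zip` stands in for `map_structure`) reproduces the same transposition on toy tuples:
```python
# Hedged, TF-free illustration of the transposition in get_next above; each
# get_next() result is modeled as a (feature, label) tuple of strings.
data = [[('a', 'z'), ('b', 'y'), ('c', 'x')],   # batch 0, iterations 0..2
        [('A', 'Z'), ('B', 'Y'), ('C', 'X')]]   # batch 1, iterations 0..2

# First pass: group the same iteration across batches ("Batches").
per_iteration = [tuple(zip(*pair)) for pair in zip(*data)]
# -> [(('a','A'), ('z','Z')), (('b','B'), ('y','Y')), (('c','C'), ('x','X'))]

# Second pass: group the same tuple component across iterations ("PerIteration").
per_component = tuple(zip(*per_iteration))
# -> ((('a','A'), ('b','B'), ('c','C')), (('z','Z'), ('y','Y'), ('x','X')))
print(per_iteration)
print(per_component)
```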
diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD
index fad613155d..a1d56066b4 100644
--- a/tensorflow/contrib/distributions/BUILD
+++ b/tensorflow/contrib/distributions/BUILD
@@ -372,6 +372,7 @@ cuda_py_test(
"//tensorflow/python:random_ops",
"//tensorflow/python:variables",
],
+ shard_count = 4,
)
cuda_py_test(
@@ -459,7 +460,7 @@ cuda_py_test(
cuda_py_test(
name = "batch_reshape_test",
- size = "small",
+ size = "medium",
srcs = ["python/kernel_tests/batch_reshape_test.py"],
additional_deps = [
":distributions_py",
@@ -578,7 +579,7 @@ cuda_py_test(
cuda_py_test(
name = "wishart_test",
- size = "small",
+ size = "medium",
srcs = ["python/kernel_tests/wishart_test.py"],
additional_deps = [
":distributions_py",
@@ -866,7 +867,7 @@ cuda_py_test(
cuda_py_test(
name = "batch_normalization_test",
- size = "small",
+ size = "medium",
srcs = ["python/kernel_tests/bijectors/batch_normalization_test.py"],
additional_deps = [
":bijectors_py",
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py
index ca20442c39..dc45114b1c 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py
@@ -26,6 +26,7 @@ from tensorflow.contrib.distributions.python.ops.bijectors.exp import Exp
from tensorflow.contrib.distributions.python.ops.bijectors.softmax_centered import SoftmaxCentered
from tensorflow.contrib.distributions.python.ops.bijectors.softplus import Softplus
from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import array_ops
from tensorflow.python.ops.distributions import bijector
from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency
from tensorflow.python.platform import test
@@ -188,6 +189,15 @@ class ChainBijectorTest(test.TestCase):
-np.log(6, dtype=np.float32) - np.sum(x),
self.evaluate(chain.inverse_log_det_jacobian(y, event_ndims=1)))
+ def testChainIldjWithPlaceholder(self):
+ chain = Chain((Exp(), Exp()))
+ samples = array_ops.placeholder(
+ dtype=np.float32, shape=[None, 10], name="samples")
+ ildj = chain.inverse_log_det_jacobian(samples, event_ndims=0)
+ self.assertTrue(ildj is not None)
+ with self.test_session():
+ ildj.eval({samples: np.zeros([2, 10], np.float32)})
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/mvn_full_covariance_test.py b/tensorflow/contrib/distributions/python/kernel_tests/mvn_full_covariance_test.py
index 7435bcbc68..b003526392 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/mvn_full_covariance_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/mvn_full_covariance_test.py
@@ -131,8 +131,8 @@ class MultivariateNormalFullCovarianceTest(test.TestCase):
return mu, sigma
def testKLBatch(self):
- batch_shape = (2,)
- event_shape = (3,)
+ batch_shape = [2]
+ event_shape = [3]
with self.test_session():
mu_a, sigma_a = self._random_mu_and_sigma(batch_shape, event_shape)
mu_b, sigma_b = self._random_mu_and_sigma(batch_shape, event_shape)
@@ -156,6 +156,33 @@ class MultivariateNormalFullCovarianceTest(test.TestCase):
self.assertAllClose(expected_kl_0, kl_v[0])
self.assertAllClose(expected_kl_1, kl_v[1])
+ def testKLBatchBroadcast(self):
+ batch_shape = [2]
+ event_shape = [3]
+ with self.test_session():
+ mu_a, sigma_a = self._random_mu_and_sigma(batch_shape, event_shape)
+ # No batch shape.
+ mu_b, sigma_b = self._random_mu_and_sigma([], event_shape)
+ mvn_a = ds.MultivariateNormalFullCovariance(
+ loc=mu_a,
+ covariance_matrix=sigma_a,
+ validate_args=True)
+ mvn_b = ds.MultivariateNormalFullCovariance(
+ loc=mu_b,
+ covariance_matrix=sigma_b,
+ validate_args=True)
+
+ kl = ds.kl_divergence(mvn_a, mvn_b)
+ self.assertEqual(batch_shape, kl.get_shape())
+
+ kl_v = kl.eval()
+ expected_kl_0 = _compute_non_batch_kl(mu_a[0, :], sigma_a[0, :, :],
+ mu_b, sigma_b)
+ expected_kl_1 = _compute_non_batch_kl(mu_a[1, :], sigma_a[1, :, :],
+ mu_b, sigma_b)
+ self.assertAllClose(expected_kl_0, kl_v[0])
+ self.assertAllClose(expected_kl_1, kl_v[1])
+
def _compute_non_batch_kl(mu_a, sigma_a, mu_b, sigma_b):
"""Non-batch KL for N(mu_a, sigma_a), N(mu_b, sigma_b)."""
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/mvn_tril_test.py b/tensorflow/contrib/distributions/python/kernel_tests/mvn_tril_test.py
index 685f32883d..b556d06123 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/mvn_tril_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/mvn_tril_test.py
@@ -235,8 +235,8 @@ class MultivariateNormalTriLTest(test.TestCase):
return mu, sigma
def testKLNonBatch(self):
- batch_shape = ()
- event_shape = (2,)
+ batch_shape = []
+ event_shape = [2]
with self.test_session():
mu_a, sigma_a = self._random_mu_and_sigma(batch_shape, event_shape)
mu_b, sigma_b = self._random_mu_and_sigma(batch_shape, event_shape)
@@ -257,8 +257,8 @@ class MultivariateNormalTriLTest(test.TestCase):
self.assertAllClose(expected_kl, kl_v)
def testKLBatch(self):
- batch_shape = (2,)
- event_shape = (3,)
+ batch_shape = [2]
+ event_shape = [3]
with self.test_session():
mu_a, sigma_a = self._random_mu_and_sigma(batch_shape, event_shape)
mu_b, sigma_b = self._random_mu_and_sigma(batch_shape, event_shape)
@@ -282,9 +282,36 @@ class MultivariateNormalTriLTest(test.TestCase):
self.assertAllClose(expected_kl_0, kl_v[0])
self.assertAllClose(expected_kl_1, kl_v[1])
+ def testKLBatchBroadcast(self):
+ batch_shape = [2]
+ event_shape = [3]
+ with self.test_session():
+ mu_a, sigma_a = self._random_mu_and_sigma(batch_shape, event_shape)
+ # No batch shape.
+ mu_b, sigma_b = self._random_mu_and_sigma([], event_shape)
+ mvn_a = ds.MultivariateNormalTriL(
+ loc=mu_a,
+ scale_tril=np.linalg.cholesky(sigma_a),
+ validate_args=True)
+ mvn_b = ds.MultivariateNormalTriL(
+ loc=mu_b,
+ scale_tril=np.linalg.cholesky(sigma_b),
+ validate_args=True)
+
+ kl = ds.kl_divergence(mvn_a, mvn_b)
+ self.assertEqual(batch_shape, kl.get_shape())
+
+ kl_v = kl.eval()
+ expected_kl_0 = _compute_non_batch_kl(mu_a[0, :], sigma_a[0, :, :],
+ mu_b, sigma_b)
+ expected_kl_1 = _compute_non_batch_kl(mu_a[1, :], sigma_a[1, :, :],
+ mu_b, sigma_b)
+ self.assertAllClose(expected_kl_0, kl_v[0])
+ self.assertAllClose(expected_kl_1, kl_v[1])
+
def testKLTwoIdenticalDistributionsIsZero(self):
- batch_shape = (2,)
- event_shape = (3,)
+ batch_shape = [2]
+ event_shape = [3]
with self.test_session():
mu_a, sigma_a = self._random_mu_and_sigma(batch_shape, event_shape)
mvn_a = ds.MultivariateNormalTriL(
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/chain.py b/tensorflow/contrib/distributions/python/ops/bijectors/chain.py
index 85ad23e413..b158a51bb0 100644
--- a/tensorflow/contrib/distributions/python/ops/bijectors/chain.py
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/chain.py
@@ -20,10 +20,9 @@ from __future__ import print_function
import itertools
-from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import bijector
@@ -36,15 +35,6 @@ def _use_static_shape(input_tensor, ndims):
return input_tensor.shape.is_fully_defined() and isinstance(ndims, int)
-def _maybe_get_event_ndims_statically(event_ndims):
- static_event_ndims = (event_ndims if isinstance(event_ndims, int)
- else tensor_util.constant_value(event_ndims))
- if static_event_ndims is not None:
- return static_event_ndims
-
- return event_ndims
-
-
def _compute_min_event_ndims(bijector_list, compute_forward=True):
"""Computes the min_event_ndims associated with the give list of bijectors.
@@ -238,13 +228,13 @@ class Chain(bijector.Bijector):
return y
def _inverse_log_det_jacobian(self, y, **kwargs):
- ildj = constant_op.constant(
- 0., dtype=y.dtype.base_dtype, name="inverse_log_det_jacobian")
+ y = ops.convert_to_tensor(y, name="y")
+ ildj = math_ops.cast(0., dtype=y.dtype.base_dtype)
if not self.bijectors:
return ildj
- event_ndims = _maybe_get_event_ndims_statically(
+ event_ndims = self._maybe_get_event_ndims_statically(
self.inverse_min_event_ndims)
if _use_static_shape(y, event_ndims):
@@ -258,11 +248,12 @@ class Chain(bijector.Bijector):
if _use_static_shape(y, event_ndims):
event_shape = b.inverse_event_shape(event_shape)
- event_ndims = _maybe_get_event_ndims_statically(event_shape.ndims)
+ event_ndims = self._maybe_get_event_ndims_statically(
+ event_shape.ndims)
else:
event_shape = b.inverse_event_shape_tensor(event_shape)
- event_ndims = _maybe_get_event_ndims_statically(
- array_ops.rank(event_shape))
+ event_ndims = self._maybe_get_event_ndims_statically(
+ array_ops.size(event_shape))
y = b.inverse(y, **kwargs.get(b.name, {}))
return ildj
@@ -274,13 +265,12 @@ class Chain(bijector.Bijector):
def _forward_log_det_jacobian(self, x, **kwargs):
x = ops.convert_to_tensor(x, name="x")
- fldj = constant_op.constant(
- 0., dtype=x.dtype, name="inverse_log_det_jacobian")
+ fldj = math_ops.cast(0., dtype=x.dtype.base_dtype)
if not self.bijectors:
return fldj
- event_ndims = _maybe_get_event_ndims_statically(
+ event_ndims = self._maybe_get_event_ndims_statically(
self.forward_min_event_ndims)
if _use_static_shape(x, event_ndims):
@@ -293,13 +283,21 @@ class Chain(bijector.Bijector):
x, event_ndims=event_ndims, **kwargs.get(b.name, {}))
if _use_static_shape(x, event_ndims):
event_shape = b.forward_event_shape(event_shape)
- event_ndims = _maybe_get_event_ndims_statically(event_shape.ndims)
+ event_ndims = self._maybe_get_event_ndims_statically(event_shape.ndims)
else:
event_shape = b.forward_event_shape_tensor(event_shape)
- event_ndims = _maybe_get_event_ndims_statically(
- array_ops.rank(event_shape))
+ event_ndims = self._maybe_get_event_ndims_statically(
+ array_ops.size(event_shape))
x = b.forward(x, **kwargs.get(b.name, {}))
return fldj
+ def _maybe_get_event_ndims_statically(self, event_ndims):
+ event_ndims_ = super(Chain, self)._maybe_get_event_ndims_statically(
+ event_ndims)
+ if event_ndims_ is None:
+ return event_ndims
+ return event_ndims_
+
+
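The Chain hunks above replace `array_ops.rank(event_shape)` with `array_ops.size(event_shape)` when the event shape is only known dynamically. Since `event_shape` is a 1-D tensor of dimension sizes, its rank is always 1, while its size is the number of event dimensions that is actually wanted. A tiny sketch of the distinction:
```python
# Hedged illustration of rank() vs size() on a shape tensor.
import tensorflow as tf

event_shape = tf.constant([2, 10])       # shape tensor for a [2, 10] event
with tf.Session() as sess:
  print(sess.run(tf.rank(event_shape)))  # 1 -- rank of the shape tensor itself
  print(sess.run(tf.size(event_shape)))  # 2 -- number of event dimensions
```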
diff --git a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py
index 8517a3bf7b..b8f352d5f5 100644
--- a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py
+++ b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py
@@ -36,9 +36,7 @@ def device_and_data_format():
'channels_last')
-def random_batch(batch_size, device_and_format=None):
- _, data_format = device_and_format or device_and_data_format()
-
+def random_batch(batch_size, data_format):
shape = (3, 224, 224) if data_format == 'channels_first' else (224, 224, 3)
shape = (batch_size,) + shape
@@ -70,7 +68,7 @@ class ResNet50Test(tf.test.TestCase):
if defun:
model.call = tfe.defun(model.call)
with tf.device(device), tfe.execution_mode(execution_mode):
- images, _ = random_batch(2)
+ images, _ = random_batch(2, data_format)
output = model(images, training=False)
tfe.async_wait()
self.assertEqual((2, 1000), output.shape)
@@ -91,7 +89,7 @@ class ResNet50Test(tf.test.TestCase):
device, data_format = device_and_data_format()
model = resnet50.ResNet50(data_format, include_top=False)
with tf.device(device):
- images, _ = random_batch(2)
+ images, _ = random_batch(2, data_format)
output = model(images, training=False)
output_shape = ((2, 2048, 1, 1)
if data_format == 'channels_first' else (2, 1, 1, 2048))
@@ -101,7 +99,7 @@ class ResNet50Test(tf.test.TestCase):
device, data_format = device_and_data_format()
model = resnet50.ResNet50(data_format, include_top=False, pooling='avg')
with tf.device(device):
- images, _ = random_batch(2)
+ images, _ = random_batch(2, data_format)
output = model(images, training=False)
self.assertEqual((2, 2048), output.shape)
@@ -115,7 +113,7 @@ class ResNet50Test(tf.test.TestCase):
name='t0').as_default(), tf.contrib.summary.always_record_summaries():
with tf.device(device), tfe.execution_mode(execution_mode):
optimizer = tf.train.GradientDescentOptimizer(0.1)
- images, labels = random_batch(2)
+ images, labels = random_batch(2, data_format)
train_one_step(model, images, labels, optimizer)
self.assertEqual(320, len(model.variables))
tfe.async_wait()
@@ -134,7 +132,7 @@ class ResNet50Test(tf.test.TestCase):
model = resnet50.ResNet50(data_format)
optimizer = tf.train.GradientDescentOptimizer(0.1)
with tf.device(device):
- images, labels = random_batch(2)
+ images, labels = random_batch(2, data_format)
gc.disable()
# Warm up. Note that this first run does create significant amounts of
# garbage to be collected. The hope is that this is a build-only effect,
@@ -202,18 +200,18 @@ class ResNet50Benchmarks(tf.test.Benchmark):
# which forces a sync. This is a roundabout way, yes.
tf.constant(1.).cpu()
- def _benchmark_eager_apply(self, label, defun=False, execution_mode=None,
- device_and_format=None):
+ def _benchmark_eager_apply(self, label, device_and_format, defun=False,
+ execution_mode=None, compiled=False):
with tfe.execution_mode(execution_mode):
- device, data_format = device_and_format or device_and_data_format()
+ device, data_format = device_and_format
model = resnet50.ResNet50(data_format)
if defun:
- model.call = tfe.defun(model.call)
+ model.call = tfe.defun(model.call, compiled=compiled)
batch_size = 64
num_burn = 5
num_iters = 30
with tf.device(device):
- images, _ = random_batch(batch_size, device_and_format)
+ images, _ = random_batch(batch_size, data_format)
for _ in xrange(num_burn):
model(images, training=False).cpu()
if execution_mode:
@@ -227,30 +225,34 @@ class ResNet50Benchmarks(tf.test.Benchmark):
self._report(label, start, num_iters, device, batch_size, data_format)
def benchmark_eager_apply_sync(self):
- self._benchmark_eager_apply('eager_apply', defun=False)
+ self._benchmark_eager_apply('eager_apply', device_and_data_format(),
+ defun=False)
def benchmark_eager_apply_async(self):
self._benchmark_eager_apply(
- 'eager_apply_async', defun=False, execution_mode=tfe.ASYNC)
+ 'eager_apply_async', device_and_data_format(), defun=False,
+ execution_mode=tfe.ASYNC)
def benchmark_eager_apply_with_defun(self):
- self._benchmark_eager_apply('eager_apply_with_defun', defun=True)
+ self._benchmark_eager_apply('eager_apply_with_defun',
+ device_and_data_format(), defun=True)
def _benchmark_eager_train(self,
label,
make_iterator,
+ device_and_format,
defun=False,
execution_mode=None,
- device_and_format=None):
+ compiled=False):
with tfe.execution_mode(execution_mode):
- device, data_format = device_and_format or device_and_data_format()
+ device, data_format = device_and_format
for batch_size in self._train_batch_sizes():
- (images, labels) = random_batch(batch_size, device_and_format)
+ (images, labels) = random_batch(batch_size, data_format)
num_burn = 3
num_iters = 10
model = resnet50.ResNet50(data_format)
if defun:
- model.call = tfe.defun(model.call)
+ model.call = tfe.defun(model.call, compiled=compiled)
optimizer = tf.train.GradientDescentOptimizer(0.1)
with tf.device(device):
@@ -273,18 +275,21 @@ class ResNet50Benchmarks(tf.test.Benchmark):
self._report(label, start, num_iters, device, batch_size, data_format)
def benchmark_eager_train_sync(self):
- self._benchmark_eager_train('eager_train', MockIterator, defun=False)
+ self._benchmark_eager_train('eager_train', MockIterator,
+ device_and_data_format(), defun=False)
def benchmark_eager_train_async(self):
self._benchmark_eager_train(
'eager_train_async',
MockIterator,
+ device_and_data_format(),
defun=False,
execution_mode=tfe.ASYNC)
def benchmark_eager_train_with_defun(self):
self._benchmark_eager_train(
- 'eager_train_with_defun', MockIterator, defun=True)
+ 'eager_train_with_defun', MockIterator,
+ device_and_data_format(), defun=True)
def benchmark_eager_train_datasets(self):
@@ -294,7 +299,8 @@ class ResNet50Benchmarks(tf.test.Benchmark):
return tfe.Iterator(ds)
self._benchmark_eager_train(
- 'eager_train_dataset', make_iterator, defun=False)
+ 'eager_train_dataset', make_iterator,
+ device_and_data_format(), defun=False)
def benchmark_eager_train_datasets_with_defun(self):
@@ -304,7 +310,8 @@ class ResNet50Benchmarks(tf.test.Benchmark):
return tfe.Iterator(ds)
self._benchmark_eager_train(
- 'eager_train_dataset_with_defun', make_iterator, defun=True)
+ 'eager_train_dataset_with_defun', make_iterator,
+ device_and_data_format(), defun=True)
if __name__ == '__main__':
diff --git a/tensorflow/contrib/eager/python/tfe_test.py b/tensorflow/contrib/eager/python/tfe_test.py
index e80ccbb74d..db50b33af2 100644
--- a/tensorflow/contrib/eager/python/tfe_test.py
+++ b/tensorflow/contrib/eager/python/tfe_test.py
@@ -57,7 +57,7 @@ class TFETest(test_util.TensorFlowTestCase):
return math_ops.multiply(x, x)
grad = tfe.gradients_function(square)
- self.assertEquals([6], [x.numpy() for x in grad(3)])
+ self.assertEquals([6], [x.numpy() for x in grad(3.)])
def testGradOfGrad(self):
@@ -66,7 +66,7 @@ class TFETest(test_util.TensorFlowTestCase):
grad = tfe.gradients_function(square)
gradgrad = tfe.gradients_function(lambda x: grad(x)[0])
- self.assertEquals([2], [x.numpy() for x in gradgrad(3)])
+ self.assertEquals([2], [x.numpy() for x in gradgrad(3.)])
def testCustomGrad(self):
@@ -80,7 +80,7 @@ class TFETest(test_util.TensorFlowTestCase):
return y, grad_fn
grad = tfe.gradients_function(f)
- self.assertEquals([12], [x.numpy() for x in grad(3)])
+ self.assertEquals([12], [x.numpy() for x in grad(3.)])
def testGPU(self):
if tfe.num_gpus() <= 0:
diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD
index 571e2e3a5d..e9a68801ef 100644
--- a/tensorflow/contrib/estimator/BUILD
+++ b/tensorflow/contrib/estimator/BUILD
@@ -17,6 +17,7 @@ py_library(
":boosted_trees",
":dnn",
":dnn_linear_combined",
+ ":export",
":extenders",
":head",
":linear",
@@ -181,6 +182,43 @@ py_test(
)
py_library(
+ name = "export",
+ srcs = [
+ "python/estimator/export.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python/estimator:model_fn",
+ ],
+)
+
+py_test(
+ name = "export_test",
+ size = "medium",
+ srcs = ["python/estimator/export_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["notsan"], # b/62863147
+ deps = [
+ ":export",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:metrics",
+ "//tensorflow/python:parsing_ops",
+ "//tensorflow/python:session",
+ "//tensorflow/python:state_ops",
+ "//tensorflow/python:training",
+ "//tensorflow/python:util",
+ "//tensorflow/python:variables",
+ "//tensorflow/python/estimator",
+ "//tensorflow/python/estimator:export_export",
+ "//tensorflow/python/estimator:export_output",
+ "//tensorflow/python/estimator:model_fn",
+ "//tensorflow/python/saved_model:loader",
+ "//tensorflow/python/saved_model:tag_constants",
+ ],
+)
+
+py_library(
name = "head",
srcs = [
"python/estimator/head.py",
diff --git a/tensorflow/contrib/estimator/__init__.py b/tensorflow/contrib/estimator/__init__.py
index d43b3ea6bf..ec502f86dd 100644
--- a/tensorflow/contrib/estimator/__init__.py
+++ b/tensorflow/contrib/estimator/__init__.py
@@ -22,6 +22,7 @@ from __future__ import print_function
from tensorflow.contrib.estimator.python.estimator.boosted_trees import *
from tensorflow.contrib.estimator.python.estimator.dnn import *
from tensorflow.contrib.estimator.python.estimator.dnn_linear_combined import *
+from tensorflow.contrib.estimator.python.estimator.export import *
from tensorflow.contrib.estimator.python.estimator.extenders import *
from tensorflow.contrib.estimator.python.estimator.head import *
from tensorflow.contrib.estimator.python.estimator.linear import *
@@ -56,6 +57,8 @@ _allowed_symbols = [
'TowerOptimizer',
'RNNClassifier',
'RNNEstimator',
+ 'export_saved_model_for_mode',
+ 'export_all_saved_models',
]
remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
diff --git a/tensorflow/contrib/estimator/python/estimator/export.py b/tensorflow/contrib/estimator/python/estimator/export.py
new file mode 100644
index 0000000000..e7e366a3f2
--- /dev/null
+++ b/tensorflow/contrib/estimator/python/estimator/export.py
@@ -0,0 +1,216 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Wrapper for methods to export train/eval graphs from Estimator."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.estimator import model_fn as model_fn_lib
+
+
+def export_saved_model_for_mode(
+ estimator, export_dir_base, input_receiver_fn,
+ assets_extra=None,
+ as_text=False,
+ checkpoint_path=None,
+ strip_default_attrs=False,
+ mode=model_fn_lib.ModeKeys.PREDICT):
+ # pylint: disable=line-too-long
+ """Exports a single train/eval/predict graph as a SavedModel.
+
+ For a detailed guide, see
+ @{$saved_model#using_savedmodel_with_estimators$Using SavedModel with Estimators}.
+
+ Sample usage:
+ ```python
+ classifier = tf.estimator.LinearClassifier(
+ feature_columns=[age, language])
+ classifier.train(input_fn=input_fn, steps=1000)
+
+ feature_spec = {
+ 'age': tf.placeholder(dtype=tf.int64),
+ 'language': array_ops.placeholder(dtype=tf.string)
+ }
+ label_spec = tf.placeholder(dtype=dtypes.int64)
+
+ train_rcvr_fn = tf.contrib.estimator.build_raw_supervised_input_receiver_fn(
+ feature_spec, label_spec)
+
+ export_dir = tf.contrib.estimator.export_saved_model_for_mode(
+ classifier,
+ export_dir_base='my_model/',
+ input_receiver_fn=train_rcvr_fn,
+ mode=model_fn_lib.ModeKeys.TRAIN)
+
+ # export_dir is a timestamped directory with the SavedModel, which
+ # can be used for serving, analysis with TFMA, or directly loaded in.
+ with ops.Graph().as_default() as graph:
+ with session.Session(graph=graph) as sess:
+ loader.load(sess, [tag_constants.TRAINING], export_dir)
+ ...
+ ```
+
+ This method takes an input_receiver_fn and mode. For the mode passed in,
+ this method builds a new graph by calling the input_receiver_fn to obtain
+ feature and label `Tensor`s. Next, this method calls the `Estimator`'s
+ model_fn in the passed mode to generate the model graph based on
+ those features and labels, and restores the given checkpoint
+ (or, lacking that, the most recent checkpoint) into the graph.
+ Finally, it creates a timestamped export directory below the
+ export_dir_base, and writes a `SavedModel` into it containing
+ the `MetaGraphDef` for the given mode and its associated signatures.
+
+ For prediction, the exported `MetaGraphDef` will provide one `SignatureDef`
+ for each element of the export_outputs dict returned from the model_fn,
+ named using the same keys. One of these keys is always
+ signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY, indicating which
+ signature will be served when a serving request does not specify one.
+ For each signature, the outputs are provided by the corresponding
+ `ExportOutput`s, and the inputs are always the input receivers provided by
+ the serving_input_receiver_fn.
+
+ For training and evaluation, the train_op is stored in an extra collection,
+ and loss, metrics, and predictions are included in a SignatureDef for the
+ mode in question.
+
+ Extra assets may be written into the SavedModel via the assets_extra
+ argument. This should be a dict, where each key gives a destination path
+ (including the filename) relative to the assets.extra directory. The
+ corresponding value gives the full path of the source file to be copied.
+ For example, the simple case of copying a single file without renaming it
+ is specified as `{'my_asset_file.txt': '/path/to/my_asset_file.txt'}`.
+
+ Args:
+ estimator: an instance of tf.estimator.Estimator
+ export_dir_base: A string containing a directory in which to create
+ timestamped subdirectories containing exported SavedModels.
+ input_receiver_fn: a function that takes no argument and
+ returns the appropriate subclass of `InputReceiver`.
+ assets_extra: A dict specifying how to populate the assets.extra directory
+ within the exported SavedModel, or `None` if no extra assets are needed.
+ as_text: whether to write the SavedModel proto in text format.
+ checkpoint_path: The checkpoint path to export. If `None` (the default),
+ the most recent checkpoint found within the model directory is chosen.
+ strip_default_attrs: Boolean. If `True`, default-valued attributes will be
+ removed from the NodeDefs. For a detailed guide, see
+ [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
+    mode: tf.estimator.ModeKeys value indicating which mode will be exported.
+
+ Returns:
+ The string path to the exported directory.
+
+ Raises:
+ ValueError: if input_receiver_fn is None, no export_outputs
+ are provided, or no checkpoint can be found.
+ """
+ # pylint: enable=line-too-long
+
+ # pylint: disable=protected-access
+ return estimator._export_saved_model_for_mode(
+ export_dir_base, input_receiver_fn,
+ assets_extra=assets_extra,
+ as_text=as_text,
+ checkpoint_path=checkpoint_path,
+ strip_default_attrs=strip_default_attrs,
+ mode=mode)
+ # pylint: enable=protected-access
+
+
+def export_all_saved_models(
+ estimator, export_dir_base, input_receiver_fn_map,
+ assets_extra=None,
+ as_text=False,
+ checkpoint_path=None,
+ strip_default_attrs=False):
+ # pylint: disable=line-too-long
+ """Exports requested train/eval/predict graphs as separate SavedModels.
+
+ This is a wrapper around export_saved_model_for_mode that accepts
+ multiple modes simultaneously and creates directories for each under
+ export_dir_base. See `Estimator.export_saved_model_for_mode` for
+ further details as to how the export works for each mode.
+
+ Sample usage:
+ ```python
+ classifier = tf.estimator.LinearClassifier(
+ feature_columns=[age, language])
+ classifier.train(input_fn=input_fn)
+
+ feature_spec = {
+ 'age': tf.placeholder(dtype=tf.int64),
+ 'language': array_ops.placeholder(dtype=tf.string)
+ }
+ label_spec = tf.placeholder(dtype=dtypes.int64)
+
+ train_rcvr_fn = tf.contrib.estimator.build_raw_supervised_input_receiver_fn(
+ feature_spec, label_spec)
+
+ serve_rcvr_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(
+ feature_spec)
+
+ rcvr_fn_map = {
+ model_fn_lib.ModeKeys.TRAIN: train_rcvr_fn,
+ model_fn_lib.ModeKeys.PREDICT: serve_rcvr_fn,
+ }
+
+ export_dirs = tf.contrib.estimator.export_all_saved_models(
+ classifier,
+ export_dir_base='my_model/',
+ input_receiver_fn_map=rcvr_fn_map)
+
+ # export_dirs is a dict of directories with SavedModels, which
+ # can be used for serving, analysis with TFMA, or directly loaded in.
+ with ops.Graph().as_default() as graph:
+ with session.Session(graph=graph) as sess:
+ loader.load(sess, [tag_constants.TRAINING],
+ export_dirs[tf.estimator.ModeKeys.TRAIN])
+ ...
+ ```
+
+ Args:
+ estimator: an instance of tf.estimator.Estimator
+ export_dir_base: A string containing a directory in which to create
+ timestamped subdirectories containing exported SavedModels.
+ input_receiver_fn_map: dict of tf.estimator.ModeKeys to input_receiver_fn
+ mappings, where the input_receiver_fn is a function that takes no
+ argument and returns the appropriate subclass of `InputReceiver`.
+ assets_extra: A dict specifying how to populate the assets.extra directory
+ within the exported SavedModel, or `None` if no extra assets are needed.
+ as_text: whether to write the SavedModel proto in text format.
+ checkpoint_path: The checkpoint path to export. If `None` (the default),
+ the most recent checkpoint found within the model directory is chosen.
+ strip_default_attrs: Boolean. If `True`, default-valued attributes will be
+ removed from the NodeDefs. For a detailed guide, see
+ [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
+
+ Returns:
+ A dict of tf.estimator.ModeKeys value to string path for each exported
+ directory.
+
+ Raises:
+ ValueError: if any input_receiver_fn is None, no export_outputs
+ are provided, or no checkpoint can be found.
+ """
+ # pylint: enable=line-too-long
+
+ # pylint: disable=protected-access
+ return estimator._export_all_saved_models(
+ export_dir_base, input_receiver_fn_map,
+ assets_extra=assets_extra,
+ as_text=as_text,
+ checkpoint_path=checkpoint_path,
+ strip_default_attrs=strip_default_attrs)
+ # pylint: enable=protected-access
diff --git a/tensorflow/contrib/estimator/python/estimator/export_test.py b/tensorflow/contrib/estimator/python/estimator/export_test.py
new file mode 100644
index 0000000000..89d02582e1
--- /dev/null
+++ b/tensorflow/contrib/estimator/python/estimator/export_test.py
@@ -0,0 +1,391 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for contrib wrapping of export_saved_model_for_mode functionality.
+
+These are direct copies of the tests included in core, with import locations
+changed. These should be removed when the functionality in core is part of the
+public API.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import tempfile
+
+from tensorflow.contrib.estimator.python.estimator import export as contrib_export
+from tensorflow.python.client import session
+from tensorflow.python.estimator import estimator
+from tensorflow.python.estimator import model_fn as model_fn_lib
+from tensorflow.python.estimator.export import export
+from tensorflow.python.estimator.export import export_output
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import metrics as metrics_lib
+from tensorflow.python.ops import parsing_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import gfile
+from tensorflow.python.platform import test
+from tensorflow.python.saved_model import loader
+from tensorflow.python.saved_model import tag_constants
+from tensorflow.python.training import training
+from tensorflow.python.util import compat
+
+
+def _model_fn_for_export_tests(features, labels, mode):
+ _, _ = features, labels
+ variables.Variable(1., name='weight')
+ scores = constant_op.constant([3.])
+ classes = constant_op.constant(['wumpus'])
+ update_global_step = state_ops.assign_add(training.get_global_step(), 1)
+ with ops.control_dependencies([update_global_step]):
+ train_op = constant_op.constant(2.)
+ return model_fn_lib.EstimatorSpec(
+ mode,
+ predictions=constant_op.constant(10.),
+ loss=constant_op.constant(1.),
+ train_op=train_op,
+ export_outputs={
+ 'test': export_output.ClassificationOutput(scores, classes)})
+
+
+def _x_y_input_fn():
+ return ({'x': constant_op.constant([[1], [1]]),
+ 'y': constant_op.constant([[2], [2]])},
+ constant_op.constant([[1], [1]]))
+
+
+def _model_fn_with_x_y(features, labels, mode):
+ _ = labels
+ variables.Variable(1., name='weight')
+ scores = constant_op.constant([3.])
+ classes = constant_op.constant(['wumpus'])
+ if mode == model_fn_lib.ModeKeys.PREDICT:
+ variables.Variable(36., name='name_collision')
+ return model_fn_lib.EstimatorSpec(
+ mode,
+ predictions=constant_op.constant(10.),
+ export_outputs={
+ 'test': export_output.ClassificationOutput(scores, classes)})
+ else:
+ prefix = 'eval_' if mode == model_fn_lib.ModeKeys.EVAL else ''
+
+ multiplied = math_ops.multiply(
+ features['x'], features['y'], name='{}multiplied'.format(prefix))
+ metrics = {'mean': metrics_lib.mean(features['x'] - features['y'],
+ name='{}mean'.format(prefix))}
+ variables.Variable(1., name='later_var')
+ variables.Variable(3., name='name_collision')
+ return model_fn_lib.EstimatorSpec(
+ mode,
+ predictions=multiplied,
+ loss=constant_op.constant(1.),
+ train_op=state_ops.assign_add(training.get_global_step(), 1),
+ eval_metric_ops=metrics)
+
+
+def _get_serving_input_receiver_fn():
+ feature_spec = {'x': parsing_ops.VarLenFeature(dtype=dtypes.int64),
+ 'y': parsing_ops.VarLenFeature(dtype=dtypes.int64)}
+ return export.build_parsing_serving_input_receiver_fn(feature_spec)
+
+
+def _get_supervised_input_receiver_fn():
+ feature_spec = {
+ 'x': array_ops.placeholder(
+ dtype=dtypes.int64, shape=(2, 1), name='feature_x'),
+ 'y': array_ops.placeholder(
+ dtype=dtypes.int64, shape=(2, 1), name='feature_y')
+ }
+ label_spec = array_ops.placeholder(
+ dtype=dtypes.float32, shape=[1], name='truth')
+
+ return export.build_raw_supervised_input_receiver_fn(
+ feature_spec, label_spec)
+
+
+class EstimatorExportTest(test.TestCase):
+
+ def test_export_saved_model_train(self):
+ self._test_export_saved_model_for_mode(
+ _get_supervised_input_receiver_fn(), model_fn_lib.ModeKeys.TRAIN)
+
+ def test_export_saved_model_eval(self):
+ self._test_export_saved_model_for_mode(
+ _get_supervised_input_receiver_fn(), model_fn_lib.ModeKeys.EVAL)
+
+ def test_export_saved_model_predict(self):
+ self._test_export_saved_model_for_mode(
+ _get_serving_input_receiver_fn(), model_fn_lib.ModeKeys.PREDICT)
+
+ def _test_export_saved_model_for_mode(self, input_receiver_fn, mode):
+ tmpdir = tempfile.mkdtemp()
+ est = estimator.Estimator(model_fn=_model_fn_for_export_tests)
+ est.train(input_fn=_x_y_input_fn, steps=1)
+
+ # Perform the export.
+ export_dir_base = os.path.join(
+ compat.as_bytes(tmpdir), compat.as_bytes('export'))
+ export_dir = contrib_export.export_saved_model_for_mode(
+ est, export_dir_base, input_receiver_fn, mode=mode)
+
+ # Check that all the files are in the right places.
+ self.assertTrue(gfile.Exists(export_dir_base))
+ self._validate_exported_files(export_dir)
+
+ # Restore, to validate that the export was well-formed.
+ tag_set = model_fn_lib.EXPORT_TAG_MAP[mode]
+ with ops.Graph().as_default() as graph:
+ with session.Session(graph=graph) as sess:
+ loader.load(sess, tag_set, export_dir)
+ graph_ops = [x.name for x in graph.get_operations()]
+ self.assertFalse('name_collision_1' in graph_ops)
+ self.assertTrue('weight' in graph_ops)
+
+ # Clean up.
+ gfile.DeleteRecursively(tmpdir)
+
+ def test_export_all_saved_models_proto_roundtrip_receiver_map(self):
+ input_receiver_fn_map = {
+ model_fn_lib.ModeKeys.PREDICT: _get_serving_input_receiver_fn()
+ }
+ export_dirs, tmpdir = self._test_export_all_saved_models(
+ input_receiver_fn_map)
+
+ self.assertEqual(len(export_dirs), 1)
+ # Restore, to validate that the export was well-formed.
+ export_dir = export_dirs[model_fn_lib.ModeKeys.PREDICT]
+ with ops.Graph().as_default() as graph:
+ with session.Session(graph=graph) as sess:
+ loader.load(sess, [tag_constants.SERVING], export_dir)
+ graph_ops = [x.name for x in graph.get_operations()]
+ self.assertTrue('input_example_tensor' in graph_ops)
+ self.assertTrue('ParseExample/ParseExample' in graph_ops)
+ self.assertFalse('feature_x' in graph_ops)
+ self.assertTrue('weight' in graph_ops)
+
+ # Clean up.
+ gfile.DeleteRecursively(tmpdir)
+
+ def test_export_all_saved_models_proto_roundtrip_train_only(self):
+ input_receiver_fn_map = {
+ model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(),
+ }
+ export_dirs, tmpdir = self._test_export_all_saved_models(
+ input_receiver_fn_map)
+
+ self.assertEqual(len(export_dirs), 1)
+ # Restore, to validate that the export was well-formed.
+ export_dir = export_dirs[model_fn_lib.ModeKeys.TRAIN]
+ with ops.Graph().as_default() as graph:
+ with session.Session(graph=graph) as sess:
+ loader.load(sess, [tag_constants.TRAINING], export_dir)
+ graph_ops = [x.name for x in graph.get_operations()]
+ self.assertTrue('multiplied' in graph_ops)
+ self.assertTrue('mean/update_op' in graph_ops)
+ self.assertFalse('eval_multiplied' in graph_ops)
+ self.assertTrue('feature_x' in graph_ops)
+ self.assertTrue('weight' in graph_ops)
+
+ # Clean up.
+ gfile.DeleteRecursively(tmpdir)
+
+ def test_export_all_saved_models_proto_roundtrip_eval_only(self):
+ input_receiver_fn_map = {
+ model_fn_lib.ModeKeys.EVAL: _get_supervised_input_receiver_fn()
+ }
+ export_dirs, tmpdir = self._test_export_all_saved_models(
+ input_receiver_fn_map)
+
+ self.assertEqual(len(export_dirs), 1)
+ # Restore, to validate that the export was well-formed.
+ export_dir = export_dirs[model_fn_lib.ModeKeys.EVAL]
+ with ops.Graph().as_default() as graph:
+ with session.Session(graph=graph) as sess:
+ loader.load(sess, [tag_constants.EVAL], export_dir)
+ graph_ops = [x.name for x in graph.get_operations()]
+ self.assertTrue('eval_multiplied' in graph_ops)
+ self.assertTrue('eval_mean/value' in graph_ops)
+ self.assertFalse('multiplied' in graph_ops)
+ self.assertTrue('feature_x' in graph_ops)
+ self.assertTrue('weight' in graph_ops)
+
+ # Clean up.
+ gfile.DeleteRecursively(tmpdir)
+
+ def test_export_all_saved_models_proto_roundtrip_no_serving(self):
+ input_receiver_fn_map = {
+ model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(),
+ model_fn_lib.ModeKeys.EVAL: _get_supervised_input_receiver_fn()
+ }
+ export_dirs, tmpdir = self._test_export_all_saved_models(
+ input_receiver_fn_map)
+
+ self.assertEqual(len(export_dirs), 2)
+ # Restore, to validate that the export was well-formed.
+ export_dir = export_dirs[model_fn_lib.ModeKeys.TRAIN]
+ with ops.Graph().as_default() as graph:
+ with session.Session(graph=graph) as sess:
+ loader.load(sess, [tag_constants.TRAINING], export_dir)
+ graph_ops = [x.name for x in graph.get_operations()]
+ self.assertTrue('multiplied' in graph_ops)
+ self.assertFalse('eval_multiplied' in graph_ops)
+ self.assertTrue('feature_x' in graph_ops)
+ self.assertTrue('weight' in graph_ops)
+ export_dir = export_dirs[model_fn_lib.ModeKeys.EVAL]
+ with ops.Graph().as_default() as graph:
+ with session.Session(graph=graph) as sess:
+ loader.load(sess, [tag_constants.EVAL], export_dir)
+ graph_ops = [x.name for x in graph.get_operations()]
+ self.assertTrue('eval_multiplied' in graph_ops)
+ self.assertFalse('multiplied' in graph_ops)
+ # TODO(karmel): is this the desired behavior when names are shared?
+ self.assertTrue('feature_x_1' in graph_ops)
+ self.assertTrue('feature_y_1' in graph_ops)
+ self.assertTrue('weight' in graph_ops)
+
+ # Clean up.
+ gfile.DeleteRecursively(tmpdir)
+
+ def test_export_all_saved_models_proto_roundtrip_three_defs(self):
+ input_receiver_fn_map = {
+ model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(),
+ model_fn_lib.ModeKeys.EVAL: _get_supervised_input_receiver_fn(),
+ model_fn_lib.ModeKeys.PREDICT: _get_serving_input_receiver_fn()
+ }
+ export_dirs, tmpdir = self._test_export_all_saved_models(
+ input_receiver_fn_map)
+
+ # Restore, to validate that the export was well-formed.
+ for mode, tag_set in model_fn_lib.EXPORT_TAG_MAP.items():
+ export_dir = export_dirs[mode]
+ with ops.Graph().as_default() as graph:
+ with session.Session(graph=graph) as sess:
+ loader.load(sess, tag_set, export_dir)
+ graph_ops = [x.name for x in graph.get_operations()]
+ self.assertTrue('global_step/Assign' in graph_ops)
+ self.assertTrue('global_step/Initializer/zeros' in graph_ops)
+ self.assertTrue('weight' in graph_ops)
+
+ # Clean up.
+ gfile.DeleteRecursively(tmpdir)
+
+ def test_export_all_saved_models_proto_roundtrip_all_vars(self):
+ input_receiver_fn_map = {
+ model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(),
+ model_fn_lib.ModeKeys.PREDICT: _get_serving_input_receiver_fn()
+ }
+ export_dirs, tmpdir = self._test_export_all_saved_models(
+ input_receiver_fn_map)
+
+ export_dir = export_dirs[model_fn_lib.ModeKeys.TRAIN]
+ with ops.Graph().as_default() as graph:
+ with session.Session(graph=graph) as sess:
+ loader.load(sess, [tag_constants.TRAINING], export_dir)
+ graph_ops = [x.name for x in graph.get_operations()]
+ self.assertTrue('later_var' in graph_ops)
+ self.assertTrue('weight' in graph_ops)
+
+ export_dir = export_dirs[model_fn_lib.ModeKeys.PREDICT]
+ with ops.Graph().as_default() as graph:
+ with session.Session(graph=graph) as sess:
+ loader.load(sess, [tag_constants.SERVING], export_dir)
+ graph_ops = [x.name for x in graph.get_operations()]
+ self.assertFalse('later_var' in graph_ops)
+ self.assertTrue('weight' in graph_ops)
+
+ # Clean up.
+ gfile.DeleteRecursively(tmpdir)
+
+ def test_export_all_saved_models_name_collision(self):
+ input_receiver_fn_map = {
+ model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(),
+ model_fn_lib.ModeKeys.PREDICT: _get_serving_input_receiver_fn()
+ }
+ export_dirs, tmpdir = self._test_export_all_saved_models(
+ input_receiver_fn_map)
+
+ export_dir = export_dirs[model_fn_lib.ModeKeys.TRAIN]
+ with ops.Graph().as_default() as graph:
+ with session.Session(graph=graph) as sess:
+ loader.load(sess, [tag_constants.TRAINING], export_dir)
+ graph_ops = [x.name for x in graph.get_operations()]
+ self.assertTrue('name_collision' in graph_ops)
+ self.assertFalse('name_collision_1' in graph_ops)
+ collection_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
+ self.assertEqual(3, collection_vars[-1].eval())
+
+ export_dir = export_dirs[model_fn_lib.ModeKeys.PREDICT]
+ with ops.Graph().as_default() as graph:
+ with session.Session(graph=graph) as sess:
+ loader.load(sess, [tag_constants.SERVING], export_dir)
+ graph_ops = [x.name for x in graph.get_operations()]
+ self.assertTrue('name_collision' in graph_ops)
+ self.assertFalse('name_collision_1' in graph_ops)
+ collection_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
+ # This is a non-obvious detail: when we load the estimator spec
+ # for predict, name_collision gets set to 36. However, we then restore
+ # from checkpoint, which should overwrite that var and make it the 3
+ # from training. In practice, this would not be a good way to write
+ # a model_fn, but leaving this check in for now to ensure consistency
+ # with what would happen given our current order of spec, then
+ # checkpoint.
+ self.assertEqual(3, collection_vars[-1].eval())
+
+ # Clean up.
+ gfile.DeleteRecursively(tmpdir)
+
+ def _test_export_all_saved_models(self, input_receiver_fn_map):
+ tmpdir = tempfile.mkdtemp()
+ est = estimator.Estimator(model_fn=_model_fn_with_x_y)
+ est.train(input_fn=_x_y_input_fn, steps=1)
+
+ # Perform the export.
+ export_dir_base = os.path.join(
+ compat.as_bytes(tmpdir), compat.as_bytes('export'))
+ export_dirs = contrib_export.export_all_saved_models(
+ est, export_dir_base, input_receiver_fn_map)
+
+ # Check that all the files are in the right places.
+ self.assertTrue(gfile.Exists(export_dir_base))
+
+ for _, export_dir in export_dirs.items():
+ self._validate_exported_files(export_dir)
+
+ return export_dirs, tmpdir
+
+ def _validate_exported_files(self, export_dir):
+ self.assertTrue(gfile.Exists(export_dir))
+ self.assertTrue(gfile.Exists(os.path.join(
+ compat.as_bytes(export_dir),
+ compat.as_bytes('saved_model.pb'))))
+ self.assertTrue(gfile.Exists(os.path.join(
+ compat.as_bytes(export_dir),
+ compat.as_bytes('variables'))))
+ self.assertTrue(gfile.Exists(os.path.join(
+ compat.as_bytes(export_dir),
+ compat.as_bytes('variables/variables.index'))))
+ self.assertTrue(gfile.Exists(os.path.join(
+ compat.as_bytes(export_dir),
+ compat.as_bytes('variables/variables.data-00000-of-00001'))))
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/estimator/python/estimator/head.py b/tensorflow/contrib/estimator/python/estimator/head.py
index 5d19bf4714..109fdd3883 100644
--- a/tensorflow/contrib/estimator/python/estimator/head.py
+++ b/tensorflow/contrib/estimator/python/estimator/head.py
@@ -560,10 +560,10 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access
weights=weights,
processed_labels=processed_labels)
- def create_estimator_spec(
+ def _create_tpu_estimator_spec(
self, features, mode, logits, labels=None, optimizer=None,
train_op_fn=None, regularization_losses=None):
- """Returns an `EstimatorSpec`.
+ """Returns an `model_fn._TPUEstimatorSpec`.
Args:
features: Input `dict` of `Tensor` or `SparseTensor` objects.
@@ -586,7 +586,7 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access
`loss_reduction=SUM_OVER_NONZERO_WEIGHTS` when creating the head to
avoid scaling errors.
Returns:
- `EstimatorSpec`.
+ `model_fn._TPUEstimatorSpec`.
Raises:
ValueError: If both `train_op_fn` and `optimizer` are `None` in TRAIN
mode, or if both are set.
@@ -606,7 +606,7 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access
classifier_output = head_lib._classification_output( # pylint:disable=protected-access
scores=probabilities, n_classes=self._n_classes,
label_vocabulary=self._label_vocabulary)
- return model_fn.EstimatorSpec(
+ return model_fn._TPUEstimatorSpec( # pylint:disable=protected-access
mode=model_fn.ModeKeys.PREDICT,
predictions=predictions,
export_outputs={
@@ -629,16 +629,18 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access
# Eval.
if mode == model_fn.ModeKeys.EVAL:
- return model_fn.EstimatorSpec(
+ return model_fn._TPUEstimatorSpec( # pylint:disable=protected-access
mode=model_fn.ModeKeys.EVAL,
predictions=predictions,
loss=regularized_training_loss,
- eval_metric_ops=self._eval_metric_ops(
- labels=processed_labels,
- probabilities=probabilities,
- weights=weights,
- unreduced_loss=unreduced_loss,
- regularization_loss=regularization_loss))
+ eval_metrics=head_lib._create_eval_metrics_tuple( # pylint:disable=protected-access
+ self._eval_metric_ops, {
+ 'labels': processed_labels,
+ 'probabilities': probabilities,
+ 'weights': weights,
+ 'unreduced_loss': unreduced_loss,
+ 'regularization_loss': regularization_loss,
+ }))
# Train.
if optimizer is not None:
@@ -672,7 +674,7 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access
summary.scalar(
head_lib._summary_key(self._name, keys.LOSS_REGULARIZATION), # pylint:disable=protected-access
regularization_loss)
- return model_fn.EstimatorSpec(
+ return model_fn._TPUEstimatorSpec( # pylint:disable=protected-access
mode=model_fn.ModeKeys.TRAIN,
predictions=predictions,
loss=regularized_training_loss,
diff --git a/tensorflow/contrib/layers/BUILD b/tensorflow/contrib/layers/BUILD
index d5b3b279a1..7355a403ae 100644
--- a/tensorflow/contrib/layers/BUILD
+++ b/tensorflow/contrib/layers/BUILD
@@ -381,7 +381,7 @@ py_test(
py_test(
name = "rev_block_lib_test",
- size = "small",
+ size = "medium",
srcs = ["python/layers/rev_block_lib_test.py"],
srcs_version = "PY2AND3",
deps = [
diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD
index 3b053cd4c6..4a360711f8 100644
--- a/tensorflow/contrib/learn/BUILD
+++ b/tensorflow/contrib/learn/BUILD
@@ -485,6 +485,7 @@ py_test(
name = "state_saving_rnn_estimator_test",
size = "medium",
srcs = ["python/learn/estimators/state_saving_rnn_estimator_test.py"],
+ shard_count = 4,
srcs_version = "PY2AND3",
tags = ["noasan"],
deps = [
diff --git a/tensorflow/contrib/learn/python/learn/experiment.py b/tensorflow/contrib/learn/python/learn/experiment.py
index 3744abd860..dfc6a393d0 100644
--- a/tensorflow/contrib/learn/python/learn/experiment.py
+++ b/tensorflow/contrib/learn/python/learn/experiment.py
@@ -468,10 +468,15 @@ class Experiment(object):
on which that evaluation was based.
At the beginning of evaluation, the passed `eval_results` will be None
so it's expected that the predicate function handles that gracefully.
- When `predicate_fn` is not specified, continuous eval will run in an
- infinite loop (if `train_steps` is None). or exit once global step
- reaches `train_steps`.
-
+ Continuous eval behavior under different conditions:
+ * When `predicate_fn` is specified:
+ + if `train_steps` is None, run until `predicate_fn` returns False.
+ + if `train_steps` is specified, run until either global step
+ reaches `train_steps` or `predicate_fn` returns False.
+ * When `predicate_fn` is not specified:
+ + if `train_steps` is None, run in an infinite loop.
+ + if `train_steps` is specified, run until global step reaches
+ `train_steps`.
export: Whether to export from this step. Default is 'True'.
Raises:
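
To make the conditions above concrete, here is a minimal sketch (not part of this change) of a `predicate_fn` that stops continuous evaluation once the evaluated loss drops below a threshold; the `'loss'` key and the threshold value are illustrative assumptions.

```python
# Illustrative predicate_fn for Experiment.continuous_eval.
def keep_evaluating(eval_results):
  # eval_results is None at the start of evaluation; keep going in that case.
  if eval_results is None:
    return True
  # Continue only while the evaluated loss stays at or above the threshold.
  return eval_results.get('loss', float('inf')) >= 0.05
```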
diff --git a/tensorflow/contrib/lite/RELEASE.md b/tensorflow/contrib/lite/RELEASE.md
new file mode 100644
index 0000000000..8fd63d5cee
--- /dev/null
+++ b/tensorflow/contrib/lite/RELEASE.md
@@ -0,0 +1,8 @@
+# Release 0.1.7
+
+* TensorFlow Lite 0.1.7 is based on tag `tflite-v0.1.7` (git commit
+ fa1db5eb0da85b5baccc2a46d534fdeb3bb473d0).
+* To reproduce the iOS library, you need to cherry-pick git commit
+ f1f1d5172fe5bfeaeb2cf657ffc43ba744187bee to fix a dependency issue.
+* The code is based on a TensorFlow 1.8.0 release candidate and is very close
+ to the TensorFlow 1.8.0 release.
diff --git a/tensorflow/contrib/lite/builtin_op_data.h b/tensorflow/contrib/lite/builtin_op_data.h
index 4910c89eae..35cf43dd32 100644
--- a/tensorflow/contrib/lite/builtin_op_data.h
+++ b/tensorflow/contrib/lite/builtin_op_data.h
@@ -162,6 +162,9 @@ typedef struct {
} TfLitePadParams;
typedef struct {
+} TfLitePadV2Params;
+
+typedef struct {
// TODO(ahentz): We can't have dynamic data in this struct, at least not yet.
// For now we will fix the maximum possible number of dimensions.
int shape[8];
diff --git a/tensorflow/contrib/lite/builtin_ops.h b/tensorflow/contrib/lite/builtin_ops.h
index 962a7a8970..a038acf284 100644
--- a/tensorflow/contrib/lite/builtin_ops.h
+++ b/tensorflow/contrib/lite/builtin_ops.h
@@ -85,6 +85,11 @@ typedef enum {
kTfLiteBuiltinMinimum = 57,
kTfLiteBuiltinLess = 58,
kTfLiteBuiltinNeg = 59,
+ kTfLiteBuiltinPadv2 = 60,
+ kTfLiteBuiltinGreater = 61,
+ kTfLiteBuiltinGreaterEqual = 62,
+ kTfLiteBuiltinLessEqual = 63,
+ kTfLiteBuiltinSelect = 64,
} TfLiteBuiltinOperator;
#ifdef __cplusplus
diff --git a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md
index 0051ee84ec..f45fcceb2e 100644
--- a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md
+++ b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md
@@ -281,6 +281,32 @@ Options {
}
```
+**GREATER**
+
+```
+Inputs {
+ 0: a tensor
+ 1: a tensor
+}
+Outputs {
+ 0: a tensor of type bool, true whenever an element of the first tensor is
+ greater than the corresponding element of the second tensor.
+}
+```
+
+**GREATER_EQUAL**
+
+```
+Inputs {
+ 0: a tensor
+ 1: a tensor
+}
+Outputs {
+ 0: a tensor of type bool, true whenever an element of the first tensor is
+ greater than or equal to the corresponding element of the second tensor.
+}
+```
+
**L2_NORMALIZATION**
```
@@ -325,6 +351,19 @@ Outputs {
}
```
+**LESS_EQUAL**
+
+```
+Inputs {
+ 0: a tensor
+ 1: a tensor
+}
+Outputs {
+ 0: a tensor of type bool, true whenever an element of the first tensor is less
+ than or equal to the corresponding element of the second tensor.
+}
+```
+
**LOCAL_RESPONSE_NORMALIZATION**
```
@@ -600,6 +639,20 @@ Outputs {
}
```
+**SELECT**
+
+```
+Inputs {
+ 0: tensor
+ 1: tensor
+ 2: tensor
+}
+Outputs {
+ 0: a tensor whose elements are taken from 'tensor 1' where the corresponding
+ element of 'tensor 0' is true, and from 'tensor 2' where it is false.
+}
+```
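
For intuition, the SELECT semantics above match elementwise selection as in `numpy.where`; this illustrative sketch is independent of the TensorFlow Lite runtime.

```python
import numpy as np

condition = np.array([True, False, True])  # plays the role of tensor 0
if_true = np.array([1, 2, 3])              # tensor 1
if_false = np.array([10, 20, 30])          # tensor 2

# Take elements from if_true where condition holds, otherwise from if_false.
print(np.where(condition, if_true, if_false))  # [ 1 20  3]
```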
+
And these are TensorFlow Lite operations that are present but not ready for
custom models yet:
diff --git a/tensorflow/contrib/lite/interpreter.h b/tensorflow/contrib/lite/interpreter.h
index 1074f64263..0450e86ae7 100644
--- a/tensorflow/contrib/lite/interpreter.h
+++ b/tensorflow/contrib/lite/interpreter.h
@@ -201,7 +201,7 @@ class Interpreter {
// Overrides execution plan. This bounds checks indices sent in.
TfLiteStatus SetExecutionPlan(const std::vector<int>& new_plan);
- // Get a tensor data structure.
+ // Get a mutable tensor data structure.
// TODO(aselle): Create a safe ArrayHandle interface to avoid exposing this
// read/write access to structure
TfLiteTensor* tensor(int tensor_index) {
@@ -210,9 +210,14 @@ class Interpreter {
return &context_.tensors[tensor_index];
}
+ // Get an immutable tensor data structure.
+ const TfLiteTensor* tensor(int tensor_index) const {
+ if (tensor_index >= context_.tensors_size || tensor_index < 0)
+ return nullptr;
+ return &context_.tensors[tensor_index];
+ }
+
// Get a pointer to an operation and registration data structure if in bounds.
- // TODO(aselle): Create a safe ArrayHandle interface to avoid exposing this
- // read/write access to structure
const std::pair<TfLiteNode, TfLiteRegistration>* node_and_registration(
int node_index) const {
if (node_index >= nodes_and_registration_.size() || node_index < 0)
@@ -220,7 +225,8 @@ class Interpreter {
return &nodes_and_registration_[node_index];
}
- // Perform a checked cast to the appropriate tensor type.
+ // Perform a checked cast to the appropriate tensor type (mutable pointer
+ // version).
template <class T>
T* typed_tensor(int tensor_index) {
if (TfLiteTensor* tensor_ptr = tensor(tensor_index)) {
@@ -231,6 +237,18 @@ class Interpreter {
return nullptr;
}
+ // Perform a checked cast to the appropriate tensor type (immutable pointer
+ // version).
+ template <class T>
+ const T* typed_tensor(int tensor_index) const {
+ if (const TfLiteTensor* tensor_ptr = tensor(tensor_index)) {
+ if (tensor_ptr->type == typeToTfLiteType<T>()) {
+ return reinterpret_cast<const T*>(tensor_ptr->data.raw);
+ }
+ }
+ return nullptr;
+ }
+
// Return a pointer into the data of a given input tensor. The given index
// must be between 0 and inputs().size().
template <class T>
@@ -238,13 +256,20 @@ class Interpreter {
return typed_tensor<T>(inputs_[index]);
}
- // Return a pointer into the data of a given output tensor. The given index
- // must be between 0 and outputs().size().
+ // Return a mutable pointer into the data of a given output tensor. The given
+ // index must be between 0 and outputs().size().
template <class T>
T* typed_output_tensor(int index) {
return typed_tensor<T>(outputs_[index]);
}
+ // Return an immutable pointer into the data of a given output tensor. The
+ // given index must be between 0 and outputs().size().
+ template <class T>
+ const T* typed_output_tensor(int index) const {
+ return typed_tensor<T>(outputs_[index]);
+ }
+
// Change the dimensionality of a given tensor. Note, this is only acceptable
// for tensor indices that are inputs.
// Returns status of failure or success.
diff --git a/tensorflow/contrib/lite/java/BUILD b/tensorflow/contrib/lite/java/BUILD
index 1dda55b8ed..1e57922603 100644
--- a/tensorflow/contrib/lite/java/BUILD
+++ b/tensorflow/contrib/lite/java/BUILD
@@ -46,12 +46,27 @@ android_library(
],
)
-java_library(
+android_library(
name = "ovicbenchmarkerlib",
srcs = [
"ovic/src/main/java/org/tensorflow/ovic/OvicClassifier.java",
"ovic/src/main/java/org/tensorflow/ovic/OvicSingleImageResult.java",
],
+ manifest = "AndroidManifest.xml",
+ visibility = ["//visibility:public"],
+ deps = [
+ ":tensorflowlite",
+ "//tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite:testhelper",
+ "@org_checkerframework_qual",
+ ],
+)
+
+java_library(
+ name = "ovicbenchmarkerlib_java",
+ srcs = [
+ "ovic/src/main/java/org/tensorflow/ovic/OvicClassifier.java",
+ "ovic/src/main/java/org/tensorflow/ovic/OvicSingleImageResult.java",
+ ],
javacopts = JAVACOPTS,
visibility = ["//visibility:public"],
deps = [
@@ -170,18 +185,14 @@ java_test(
size = "medium",
srcs = ["ovic/src/test/java/org/tensorflow/ovic/OvicClassifierTest.java"],
data = [
- "ovic/src/testdata/float_model.lite",
- "ovic/src/testdata/labels.txt",
- "ovic/src/testdata/low_res_model.lite",
- "ovic/src/testdata/quantized_model.lite",
- "ovic/src/testdata/test_image_128.jpg",
- "ovic/src/testdata/test_image_224.jpg",
+ "//tensorflow/contrib/lite/java/ovic/src/testdata:labels.txt",
+ "//tensorflow/contrib/lite/java/ovic/src/testdata:ovic_testdata",
],
javacopts = JAVACOPTS,
test_class = "org.tensorflow.ovic.OvicClassifierTest",
visibility = ["//visibility:public"],
deps = [
- ":ovicbenchmarkerlib",
+ ":ovicbenchmarkerlib_java",
"@com_google_truth",
"@junit",
],
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/layout-land/fragment_camera2_basic.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/layout-land/fragment_camera2_basic.xml
index 20f520814d..ef8a9e0845 100644
--- a/tensorflow/contrib/lite/java/demo/app/src/main/res/layout-land/fragment_camera2_basic.xml
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/layout-land/fragment_camera2_basic.xml
@@ -13,51 +13,55 @@
See the License for the specific language governing permissions and
limitations under the License.
-->
-<RelativeLayout xmlns:android="http://schemas.android.com/apk/res/android"
- xmlns:app="http://schemas.android.com/apk/res-auto"
+
+<LinearLayout
+ xmlns:android="http://schemas.android.com/apk/res/android"
android:layout_width="match_parent"
- android:layout_height="match_parent">
+ android:layout_height="match_parent"
+ android:background="#bb7700"
+ android:orientation="horizontal">
+
+ <com.example.android.tflitecamerademo.AutoFitTextureView
+ android:id="@+id/texture"
+ android:layout_width="0dp"
+ android:layout_height="match_parent"
+ android:layout_weight=".8"/>
+
+ <LinearLayout
+ android:layout_width="0dp"
+ android:layout_height="match_parent"
+ android:layout_weight=".2"
+ android:orientation="vertical">
+
+ <ImageView
+ android:id="@+id/logoview"
+ android:layout_width="wrap_content"
+ android:layout_height="wrap_content"
+ android:scaleType="centerInside"
+ android:src="@drawable/logo"/>
- <LinearLayout
+ <ToggleButton
+ android:id="@+id/button"
+ android:layout_width="match_parent"
+ android:layout_height="wrap_content"
+ android:textOff="@string/tflite"
+ android:textOn="@string/nnapi"/>
+ <NumberPicker
+ android:id="@+id/np"
+ android:layout_width="wrap_content"
+ android:layout_height="47dp"
+ android:layout_gravity="center_horizontal"
+ android:visibility="visible"/>
+
+ <TextView
+ android:id="@+id/text"
+ android:textStyle="bold"
android:layout_width="match_parent"
android:layout_height="match_parent"
- android:background="#bb7700"
- android:orientation="horizontal"
- android:weightSum="100">
-
- <LinearLayout
- android:layout_width="match_parent"
- android:layout_height="match_parent"
- android:layout_weight="30"
- android:orientation="vertical">
-
- <com.example.android.tflitecamerademo.AutoFitTextureView
- android:id="@+id/texture"
- android:layout_width="match_parent"
- android:layout_height="match_parent"
- android:layout_weight="100" />
-
- <ImageView
- android:id="@+id/logoview"
- android:layout_width="match_parent"
- android:layout_height="wrap_content"
- android:layout_weight="100"
- android:scaleType="centerCrop"
- android:src="@drawable/logo" />
-
- </LinearLayout>
-
- <TextView
- android:id="@+id/text"
- android:layout_width="match_parent"
- android:layout_height="match_parent"
- android:layout_weight="70"
- android:paddingLeft="5dp"
- android:paddingTop="20dp"
- android:textColor="#FFF"
- android:textSize="20sp"
- android:textStyle="bold" />
-
- </LinearLayout>
-
-</RelativeLayout>
+ android:paddingTop="20dp"
+ android:textColor="#FFF"
+ android:textSize="20sp"/>
+
+ </LinearLayout>
+</LinearLayout>
+
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/layout/fragment_camera2_basic.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/layout/fragment_camera2_basic.xml
index d12435d5ab..72a229ecdb 100644
--- a/tensorflow/contrib/lite/java/demo/app/src/main/res/layout/fragment_camera2_basic.xml
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/layout/fragment_camera2_basic.xml
@@ -15,45 +15,47 @@
-->
<RelativeLayout xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:app="http://schemas.android.com/apk/res-auto"
- xmlns:tools="http://schemas.android.com/tools"
android:layout_width="match_parent"
- android:layout_height="match_parent">
+ android:layout_height="match_parent"
+ android:background="#bb7700">
- <LinearLayout
+ <com.example.android.tflitecamerademo.AutoFitTextureView
+ android:id="@+id/texture"
android:layout_width="match_parent"
android:layout_height="match_parent"
- android:orientation="vertical"
- android:weightSum="60">
-
- <FrameLayout
- android:id="@+id/control"
- android:layout_width="match_parent"
- android:layout_height="match_parent"
- android:layout_alignParentBottom="true"
- android:layout_alignParentStart="true"
- android:layout_weight="60"
- android:background="#cc7700"
- android:paddingLeft="20dp"
- android:paddingStart="20dp">
-
- </FrameLayout>
+ android:layout_weight="1" />
- <com.example.android.tflitecamerademo.AutoFitTextureView
- android:id="@+id/texture"
+ <LinearLayout
android:layout_width="wrap_content"
android:layout_height="wrap_content"
+ android:layout_alignParentBottom="true"
+ android:layout_alignParentEnd="false"
android:layout_alignParentStart="true"
- android:layout_alignParentLeft="true"
- android:layout_alignParentTop="true" />
+ android:layout_alignParentTop="false"
+ android:background="#bb7700"
+ android:orientation="vertical"
+ android:weightSum="100">
+
+ <ImageView
+ android:id="@+id/logoview2"
+ android:layout_width="wrap_content"
+ android:layout_height="wrap_content"
+ android:layout_weight="30"
+ android:scaleType="fitStart"
+ android:src="@drawable/logo" />
<TextView
android:id="@+id/text"
android:layout_width="match_parent"
- android:layout_height="match_parent"
- android:layout_weight="20"
+ android:layout_height="wrap_content"
+ android:layout_alignParentBottom="true"
+ android:layout_alignParentEnd="true"
+ android:layout_alignParentRight="true"
+ android:layout_weight="30"
android:textColor="#FFF"
android:textSize="20sp"
android:textStyle="bold" />
+
</LinearLayout>
<RelativeLayout
@@ -83,33 +85,4 @@
android:layout_below="@+id/button"
android:visibility="visible" />
</RelativeLayout>
-
- <RelativeLayout
- android:id="@+id/control2"
- android:layout_width="match_parent"
- android:layout_height="135dp"
- android:layout_alignParentLeft="true"
- android:layout_alignParentStart="true"
- android:layout_alignTop="@+id/control"
- android:layout_marginLeft="300dp"
- android:layout_marginStart="300dp"
- android:background="@color/control_background">
-
- <ToggleButton
- android:id="@+id/button"
- android:textOff="@string/tflite"
- android:textOn="@string/nnapi"
- android:layout_width="wrap_content"
- android:layout_height="wrap_content"
- android:layout_alignParentLeft="true"
- android:layout_alignParentStart="true" />
-
- <NumberPicker
- android:id="@+id/np"
- android:layout_width="wrap_content"
- android:layout_height="wrap_content"
- android:layout_below="@+id/button"
- android:visibility="visible" />
- </RelativeLayout>
-
</RelativeLayout>
diff --git a/tensorflow/contrib/lite/java/ovic/README.md b/tensorflow/contrib/lite/java/ovic/README.md
index 76c33838bf..77799b3569 100644
--- a/tensorflow/contrib/lite/java/ovic/README.md
+++ b/tensorflow/contrib/lite/java/ovic/README.md
@@ -6,7 +6,7 @@ This folder contains building code for track one of the [Low Power ImageNet Reco
Follow the steps [here](https://www.tensorflow.org/mobile/tflite/demo_android) to install Tensorflow, Bazel, and the Android NDK and SDK.
-## To test the benchmarker:
+## Test the benchmarker:
The testing utilities help developers (you) make sure that submissions in TfLite format will be processed as expected by the competition's benchmarking system.
@@ -37,7 +37,7 @@ unzip -j /tmp/ovic.zip -d tensorflow/contrib/lite/java/ovic/src/testdata/
You can run test with Bazel as below. This helps to ensure that the installation is correct.
```sh
-bazel test --cxxopt=--std=c++11 //tensorflow/contrib/lite/java:OvicClassifierTest --test_output=all
+bazel test --cxxopt=--std=c++11 //tensorflow/contrib/lite/java:OvicClassifierTest --cxxopt=-Wno-all --test_output=all
```
### Test your submissions
@@ -56,28 +56,83 @@ cp /tmp/my_model.lite tensorflow/contrib/lite/java/ovic/src/testdata/
The test images can be found at `tensorflow/contrib/lite/java/ovic/src/testdata/test_image_*.jpg`. You may reuse these images if your image resolutions are 128x128 or 224x224.
-* Add your model and test image to the BUILD rule:
+* Add your model and test image to the BUILD rule at `tensorflow/contrib/lite/java/ovic/src/testdata/BUILD`:
```JSON
-java_test(
- name = "OvicClassifierTest",
- size = "medium",
- srcs = ["ovic/src/test/java/org/tensorflow/ovic/OvicClassifierTest.java"],
- data = [
- "ovic/src/testdata/float_model.lite",
- "ovic/src/testdata/labels.txt",
- "ovic/src/testdata/low_res_model.lite",
- "ovic/src/testdata/quantized_model.lite",
- "ovic/src/testdata/test_image_128.jpg",
- "ovic/src/testdata/test_image_224.jpg",
- "ovic/src/testdata/my_model.lite", # <--- Your submission.
- "ovic/src/testdata/my_test_image.jpg", # <--- Your test image.
- ],
- ...
+filegroup(
+ name = "ovic_testdata",
+ srcs = [
+ "@tflite_ovic_testdata//:float_model.lite",
+ "@tflite_ovic_testdata//:low_res_model.lite",
+ "@tflite_ovic_testdata//:quantized_model.lite",
+ "@tflite_ovic_testdata//:test_image_128.jpg",
+ "@tflite_ovic_testdata//:test_image_224.jpg"
+ "my_model.lite", # <--- Your submission.
+ "my_test_image.jpg", # <--- Your test image.
+ ],
+ ...
```
* Modify `OvicClassifierTest.java` to test your model.
-Change `TEST_IMAGE_PATH` to `testdata/my_test_image.jpg`. If your model runs inference in floating point, change `FLOAT_MODEL_PATH` to `testdata/my_model.lite`. If your model runs [quantized inference](https://www.tensorflow.org/performance/quantization), change `QUANTIZED_MODEL_PATH` to `testdata/my_model.lite`.
+Change `TEST_IMAGE_PATH` to `my_test_image.jpg`. Change either `FLOAT_MODEL_PATH` or `QUANTIZED_MODEL_PATH` to `my_model.lite` depending on whether your model runs inference in float or [8-bit](https://www.tensorflow.org/performance/quantization).
Now you can run the bazel tests to catch any runtime issues with the submission.
+
+Note: Please make sure that your submission passes the test. If a submission fails the test, it will not be processed by the submission server.
+
+## Measure on-device latency
+
+We provide two ways to measure the on-device latency of your submission. The first is through our competition server, which is reliable and repeatable but limited to a few trials per day. The second is through the benchmarker APK, which requires a device and may not be as accurate as the server, but has a fast turnaround and no access limitations. We recommend that participants use the benchmarker APK for early development and reserve the competition server for evaluating promising submissions.
+
+### Running the benchmarker app
+
+Make sure that you have followed instructions in [Test your submissions](#test-your-submissions) to add your model to the testdata folder and to the corresponding build rules.
+
+Modify `tensorflow/contrib/lite/java/ovic/demo/app/OvicBenchmarkerActivity.java`:
+
+* Add your model to the benchmarker APK by changing `MODEL_PATH` and `TEST_IMAGE_PATH` below to point to your submission and test image.
+
+```
+ private static final String TEST_IMAGE_PATH = "my_test_image.jpg";
+ private static final String MODEL_PATH = "my_model.lite";
+```
+
+* Adjust the benchmark parameters when needed:
+
+You can change the length of each experiment and the processor affinity below. `BIG_CORE_MASK` is an integer whose binary encoding represents the set of cores to use. This number is phone-specific. For example, the Pixel 2 has 8 cores: the 4 little cores are represented by the 4 least significant bits, and the 4 big cores by the 4 most significant bits. Therefore a mask value of 16, or `00010000` in binary, binds to only the first big core. The mask 32, or `00100000` in binary, uses the second big core and should deliver results identical to mask 16, because the big cores are interchangeable. A sketch of how such a mask can be computed follows the constants below.
+
+```
+ /** Wall time for each benchmarking experiment. */
+ private static final double WALL_TIME = 3000;
+ /** Maximum number of iterations in each benchmarking experiment. */
+ private static final int MAX_ITERATIONS = 100;
+ /** Mask for binding to a single big core. Pixel 1 (4), Pixel 2 (16). */
+ private static final int BIG_CORE_MASK = 16;
+```
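
As a quick aid (not part of the app), a core mask can be computed by setting one bit per core index; the 8-core layout with big cores at indices 4-7 is an assumption matching the Pixel 2 example above.

```python
# Illustrative helper: build a CPU affinity mask from core indices.
def core_mask(core_indices):
  mask = 0
  for core in core_indices:
    mask |= 1 << core  # bit i corresponds to core i
  return mask

print(core_mask([4]))           # 16  -> 0b00010000, first big core only
print(core_mask([5]))           # 32  -> 0b00100000, second big core only
print(core_mask([4, 5, 6, 7]))  # 240 -> all four big cores
```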
+
+Note: You'll need ROOT access to the phone to change processor affinity.
+
+* Build and install the app.
+
+```
+bazel build -c opt --cxxopt=--std=c++11 --cxxopt=-Wno-all //tensorflow/contrib/lite/java/ovic/demo/app:ovic_benchmarker_binary
+adb install -r bazel-bin/tensorflow/contrib/lite/java/ovic/demo/app/ovic_benchmarker_binary.apk
+```
+
+Start the app and tap the dark-green `Start` button. The button should turn bright green, signaling that the experiment is running. The benchmarking results will be displayed after roughly the `WALL_TIME` you specified above. For example:
+
+```
+my_model.lite: Average latency=158.6ms after 20 runs.
+```
+
+### Sample latencies
+
+Note: The benchmarking results can vary considerably depending on the background processes running on the phone. A few things that help stabilize the app's readings are placing the phone on a cooling plate, restarting the phone, and turning off internet access.
+
+| Model | Pixel 1 latency (ms) | Pixel 2 latency (ms) |
+| -------------------- |:---------------------:| --------------------:|
+| float_model.lite | 120 | 155 |
+| quantized_model.lite | 85 | 74 |
+| low_res_model.lite | 4.2 | 4.0 |
+
diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/AndroidManifest.xml b/tensorflow/contrib/lite/java/ovic/demo/app/AndroidManifest.xml
new file mode 100644
index 0000000000..55f2961fd7
--- /dev/null
+++ b/tensorflow/contrib/lite/java/ovic/demo/app/AndroidManifest.xml
@@ -0,0 +1,48 @@
+<?xml version="1.0" encoding="utf-8"?>
+<!--
+ Copyright 2018 The Android Open Source Project
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+-->
+
+<manifest xmlns:android="http://schemas.android.com/apk/res/android"
+ package="ovic.demo.app"
+ android:versionCode="1"
+ android:versionName="1.0" >
+
+ <uses-sdk
+ android:minSdkVersion="19"
+ android:targetSdkVersion="21" />
+
+ <uses-permission android:name="android.permission.READ_EXTERNAL_STORAGE" />
+ <uses-permission android:name="android.permission.WRITE_EXTERNAL_STORAGE" />
+ <uses-permission android:name="android.permission.READ_PHONE_STATE" />
+
+ <application
+ android:allowBackup="true"
+ android:icon="@drawable/ic_launcher"
+ android:largeHeap="true"
+ android:label="@string/app_name">
+ <activity
+ android:name="ovic.demo.app.OvicBenchmarkerActivity"
+ android:label="@string/app_name"
+ android:screenOrientation="portrait">
+
+ <intent-filter>
+ <action android:name="android.intent.action.MAIN" />
+ <category android:name="android.intent.category.LAUNCHER" />
+ </intent-filter>
+ </activity>
+ </application>
+
+</manifest>
diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/BUILD b/tensorflow/contrib/lite/java/ovic/demo/app/BUILD
new file mode 100644
index 0000000000..47101ff574
--- /dev/null
+++ b/tensorflow/contrib/lite/java/ovic/demo/app/BUILD
@@ -0,0 +1,29 @@
+# Sample app for OVIC benchmarking.
+licenses(["notice"]) # Apache 2.0
+
+android_binary(
+ name = "ovic_benchmarker_binary",
+ srcs = [
+ "OvicBenchmarker.java",
+ "OvicBenchmarkerActivity.java",
+ ],
+ assets = [
+ "//tensorflow/contrib/lite/java/ovic/src/testdata:ovic_testdata",
+ "//tensorflow/contrib/lite/java/ovic/src/testdata:labels.txt",
+ ],
+ assets_dir = "",
+ custom_package = "ovic.demo.app",
+ manifest = "AndroidManifest.xml",
+ nocompress_extensions = [
+ ".lite",
+ ".tflite",
+ ],
+ resource_files = glob(["res/**"]),
+ tags = ["manual"],
+ deps = [
+ "//tensorflow/contrib/lite/java:ovicbenchmarkerlib",
+ "//tensorflow/contrib/lite/java:tensorflowlite",
+ "@androidsdk//com.android.support:support-v13-25.2.0",
+ "@androidsdk//com.android.support:support-v4-25.2.0",
+ ],
+)
diff --git a/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicBenchmarker.java b/tensorflow/contrib/lite/java/ovic/demo/app/OvicBenchmarker.java
index d0102883e6..113ab74a20 100644
--- a/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicBenchmarker.java
+++ b/tensorflow/contrib/lite/java/ovic/demo/app/OvicBenchmarker.java
@@ -1,4 +1,4 @@
-/*Copyright 2018 Google LLC
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-package org.tensorflow.ovic;
+package ovic.demo.app;
import android.graphics.Bitmap;
import android.os.SystemClock;
@@ -22,6 +22,8 @@ import java.io.InputStream;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.MappedByteBuffer;
+import org.tensorflow.ovic.OvicClassifier;
+import org.tensorflow.ovic.OvicSingleImageResult;
/**
* Class that benchmarks image classifier models.
diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/OvicBenchmarkerActivity.java b/tensorflow/contrib/lite/java/ovic/demo/app/OvicBenchmarkerActivity.java
new file mode 100644
index 0000000000..59457c308a
--- /dev/null
+++ b/tensorflow/contrib/lite/java/ovic/demo/app/OvicBenchmarkerActivity.java
@@ -0,0 +1,247 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+package ovic.demo.app;
+
+import android.app.Activity;
+import android.content.res.AssetFileDescriptor;
+import android.content.res.AssetManager;
+import android.graphics.Bitmap;
+import android.graphics.BitmapFactory;
+import android.os.Bundle;
+import android.os.Process;
+import android.os.SystemClock;
+import android.util.Log;
+import android.view.View;
+import android.widget.TextView;
+import java.io.BufferedReader;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileReader;
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.MappedByteBuffer;
+import java.nio.channels.FileChannel;
+import java.text.DecimalFormat;
+import org.tensorflow.ovic.OvicSingleImageResult;
+
+/** Class that benchmarks image classifier models. */
+public class OvicBenchmarkerActivity extends Activity {
+ /** Tag for the {@link Log}. */
+ private static final String TAG = "OvicBenchmarkerActivity";
+
+ /** Name of the label file stored in Assets. */
+ private static final String LABEL_PATH = "labels.txt";
+
+ private static final String TEST_IMAGE_PATH = "test_image_224.jpg";
+ private static final String MODEL_PATH = "float_model.lite";
+ /**
+ * Each button press will launch a benchmarking experiment. The experiment stops when either the
+ * total native latency reaches WALL_TIME or the number of iterations reaches MAX_ITERATIONS,
+ * whichever comes first.
+ */
+ /** Wall time for each benchmarking experiment. */
+ private static final double WALL_TIME = 3000;
+ /** Maximum number of iterations in each benchmarking experiment. */
+ private static final int MAX_ITERATIONS = 100;
+ /** Mask for binding to a single big core. Pixel 1 (4), Pixel 2 (16). */
+ private static final int BIG_CORE_MASK = 16;
+ /** Amount of time in milliseconds to wait for affinity to set. */
+ private static final int WAIT_TIME_FOR_AFFINITY = 1000;
+
+ /* The model to be benchmarked. */
+ private MappedByteBuffer model = null;
+ private InputStream labelInputStream = null;
+ private OvicBenchmarker benchmarker;
+ /** Inference result of each iteration. */
+ OvicSingleImageResult iterResult = null;
+
+ private TextView textView = null;
+ // private Button startButton = null;
+ private static final DecimalFormat df2 = new DecimalFormat(".##");
+
+ @Override
+ protected void onCreate(Bundle savedInstanceState) {
+ super.onCreate(savedInstanceState);
+ setContentView(R.layout.activity_main);
+
+ // TextView used to display the progress, for information purposes only.
+ textView = (TextView) findViewById(R.id.textView);
+ }
+
+ private Bitmap loadTestBitmap() throws IOException {
+ InputStream imageStream = getAssets().open(TEST_IMAGE_PATH);
+ return BitmapFactory.decodeStream(imageStream);
+ }
+
+ public void initializeTest() throws IOException {
+ Log.i(TAG, "Initializing benchmarker.");
+ benchmarker = new OvicBenchmarker(WALL_TIME);
+ AssetManager am = getAssets();
+ AssetFileDescriptor fileDescriptor = am.openFd(MODEL_PATH);
+ FileInputStream modelInputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
+ FileChannel fileChannel = modelInputStream.getChannel();
+ long startOffset = fileDescriptor.getStartOffset();
+ long declaredLength = fileDescriptor.getDeclaredLength();
+ model = fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
+ labelInputStream = am.open(LABEL_PATH);
+ }
+
+ public Boolean doTestIteration() throws IOException, InterruptedException {
+ if (benchmarker == null) {
+ throw new RuntimeException("Benchmarker has not been initialized.");
+ }
+ if (benchmarker.shouldStop()) {
+ return false;
+ }
+ if (!benchmarker.readyToTest()) {
+ Log.i(TAG, "getting ready to test.");
+ benchmarker.getReadyToTest(labelInputStream, model);
+ if (!benchmarker.readyToTest()) {
+ throw new RuntimeException("Failed to get the benchmarker ready.");
+ }
+ }
+ Log.i(TAG, "Going to do test iter.");
+ // Start testing.
+ Bitmap testImageBitmap = loadTestBitmap();
+ iterResult = benchmarker.doTestIteration(testImageBitmap);
+ testImageBitmap.recycle();
+ if (iterResult == null) {
+ throw new RuntimeException("Inference failed to produce a result.");
+ }
+ Log.i(TAG, iterResult.toString());
+ return true;
+ }
+
+ public void startPressed(View view) throws IOException {
+ Log.i(TAG, "Start pressed");
+ try {
+ initializeTest();
+ } catch (IOException e) {
+ Log.e(TAG, "Can't initialize benchmarker.", e);
+ throw e;
+ }
+ String displayText = "";
+ try {
+ setProcessorAffinity(BIG_CORE_MASK);
+ } catch (IOException e) {
+ Log.e(TAG, e.getMessage());
+ displayText = e.getMessage() + "\n";
+ }
+ Log.i(TAG, "Successfully initialized benchmarker.");
+ int testIter = 0;
+ Boolean iterSuccess = false;
+ double totalLatency = 0.0f;
+ while (testIter < MAX_ITERATIONS) {
+ try {
+ iterSuccess = doTestIteration();
+ } catch (IOException e) {
+ Log.e(TAG, "Error during iteration " + testIter);
+ throw e;
+ } catch (InterruptedException e) {
+ Log.e(TAG, "Interrupted at iteration " + testIter);
+ }
+ if (!iterSuccess) {
+ break;
+ }
+ testIter++;
+ totalLatency += (double) iterResult.latency;
+ }
+ Log.i(TAG, "Benchmarking finished");
+
+ if (textView != null) {
+ if (testIter > 0) {
+ textView.setText(
+ displayText
+ + MODEL_PATH
+ + ": Average latency="
+ + df2.format(totalLatency / testIter)
+ + "ms after "
+ + testIter
+ + " runs.");
+ } else {
+ textView.setText("Benchmarker failed to run on more than one images.");
+ }
+ }
+ }
+
+ private static void setProcessorAffinity(int mask) throws IOException {
+ int myPid = Process.myPid();
+ Log.i(TAG, String.format("Setting processor affinity to 0x%02x", mask));
+
+ String command = String.format("taskset -a -p %x %d", mask, myPid);
+ try {
+ Runtime.getRuntime().exec(command).waitFor();
+ } catch (InterruptedException e) {
+ throw new IOException("Interrupted: " + e);
+ }
+
+ // Make sure the setting took effect: poll for up to a second to confirm the change; fail otherwise.
+ long startTimeMs = SystemClock.elapsedRealtime();
+ while (true) {
+ int readBackMask = readCpusAllowedMask();
+ if (readBackMask == mask) {
+ Log.i(TAG, String.format("Successfully set affinity to 0x%02x", mask));
+ return;
+ }
+ if (SystemClock.elapsedRealtime() > startTimeMs + WAIT_TIME_FOR_AFFINITY) {
+ throw new IOException(
+ String.format(
+ "Core-binding failed: affinity set to 0x%02x but read back as 0x%02x\n"
+ + "please root device.",
+ mask, readBackMask));
+ }
+
+ try {
+ Thread.sleep(50);
+ } catch (InterruptedException e) {
+ // Ignore an interrupted sleep; we will sleep again, and the mask comparison is the final cross-check.
+ }
+ }
+ }
+
+ public static int readCpusAllowedMask() throws IOException {
+ // Read the CPU affinity mask of the current process from /proc/self/status.
+ final String pathname = "/proc/self/status";
+ final String resultPrefix = "Cpus_allowed:";
+ File file = new File(pathname);
+ String line = "<NO LINE READ>";
+ String allowedCPU = "";
+ Integer allowedMask = null;
+ BufferedReader bufReader = null;
+ try {
+ bufReader = new BufferedReader(new FileReader(file));
+ while ((line = bufReader.readLine()) != null) {
+ if (line.startsWith(resultPrefix)) {
+ allowedMask = Integer.valueOf(line.substring(resultPrefix.length()).trim(), 16);
+ allowedCPU = bufReader.readLine();
+ break;
+ }
+ }
+ } catch (RuntimeException e) {
+ throw new IOException(
+ "Invalid number in " + pathname + " line: \"" + line + "\": " + e.getMessage());
+ } finally {
+ if (bufReader != null) {
+ bufReader.close();
+ }
+ }
+ if (allowedMask == null) {
+ throw new IOException(pathname + " missing " + resultPrefix + " line");
+ }
+ Log.i(TAG, allowedCPU);
+ return allowedMask;
+ }
+}
diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/build.gradle b/tensorflow/contrib/lite/java/ovic/demo/app/build.gradle
new file mode 100644
index 0000000000..c5d19bad89
--- /dev/null
+++ b/tensorflow/contrib/lite/java/ovic/demo/app/build.gradle
@@ -0,0 +1,58 @@
+apply plugin: 'com.android.application'
+
+android {
+ compileSdkVersion 26
+ buildToolsVersion "26.0.1"
+ defaultConfig {
+ applicationId "android.example.com.ovicbenchmarker"
+ minSdkVersion 15
+ targetSdkVersion 26
+ versionCode 1
+ versionName "1.0"
+ testInstrumentationRunner "android.support.test.runner.AndroidJUnitRunner"
+
+ // Remove this block.
+ jackOptions {
+ enabled true
+ }
+ }
+ lintOptions {
+ abortOnError false
+ }
+ buildTypes {
+ release {
+ minifyEnabled false
+ proguardFiles getDefaultProguardFile('proguard-android.txt'), 'proguard-rules.pro'
+ }
+ }
+ aaptOptions {
+ noCompress "lite", "tflite"
+ }
+
+ compileOptions {
+ sourceCompatibility JavaVersion.VERSION_1_8
+ targetCompatibility JavaVersion.VERSION_1_8
+ }
+}
+
+repositories {
+ maven {
+ url 'https://google.bintray.com/tensorflow'
+ }
+}
+
+dependencies {
+ compile fileTree(dir: 'libs', include: ['*.jar'])
+ androidTestCompile('com.android.support.test.espresso:espresso-core:2.2.2', {
+ exclude group: 'com.android.support', module: 'support-annotations'
+ })
+ compile 'com.android.support:appcompat-v7:25.2.0'
+ compile 'com.android.support.constraint:constraint-layout:1.0.2'
+ compile 'com.android.support:design:25.2.0'
+ compile 'com.android.support:support-annotations:25.3.1'
+ compile 'com.android.support:support-v13:25.2.0'
+
+ compile 'org.tensorflow:tensorflow-lite:+'
+
+ testCompile 'junit:junit:4.12'
+}
diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/res/drawable-mdpi/ic_launcher.png b/tensorflow/contrib/lite/java/ovic/demo/app/res/drawable-mdpi/ic_launcher.png
new file mode 100644
index 0000000000..715d1b6d69
--- /dev/null
+++ b/tensorflow/contrib/lite/java/ovic/demo/app/res/drawable-mdpi/ic_launcher.png
Binary files differ
diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/res/drawable-xhdpi/ic_launcher.png b/tensorflow/contrib/lite/java/ovic/demo/app/res/drawable-xhdpi/ic_launcher.png
new file mode 100644
index 0000000000..9beff0885f
--- /dev/null
+++ b/tensorflow/contrib/lite/java/ovic/demo/app/res/drawable-xhdpi/ic_launcher.png
Binary files differ
diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/res/drawable/start_button_color.xml b/tensorflow/contrib/lite/java/ovic/demo/app/res/drawable/start_button_color.xml
new file mode 100644
index 0000000000..93f5c6a016
--- /dev/null
+++ b/tensorflow/contrib/lite/java/ovic/demo/app/res/drawable/start_button_color.xml
@@ -0,0 +1,39 @@
+<?xml version="1.0" encoding="utf-8"?>
+<!--
+ Copyright 2018 The Android Open Source Project
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+-->
+<selector
+ xmlns:android="http://schemas.android.com/apk/res/android">
+ <item
+ android:state_enabled="false">
+ <shape android:shape="rectangle">
+ <solid android:color="#808080"/>
+ </shape>
+ </item>
+ <item
+ android:state_enabled="true"
+ android:state_pressed="true">
+ <shape android:shape="rectangle">
+ <solid android:color="#44ff44"/>
+ </shape>
+ </item>
+ <item
+ android:state_enabled="true"
+ android:state_pressed="false">
+ <shape android:shape="rectangle" >
+ <solid android:color="#227f22"/>
+ </shape>
+ </item>
+</selector>
diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/res/layout/activity_main.xml b/tensorflow/contrib/lite/java/ovic/demo/app/res/layout/activity_main.xml
new file mode 100644
index 0000000000..e9d83bae54
--- /dev/null
+++ b/tensorflow/contrib/lite/java/ovic/demo/app/res/layout/activity_main.xml
@@ -0,0 +1,54 @@
+<?xml version="1.0" encoding="utf-8"?>
+<!--
+ Copyright 2018 The Android Open Source Project
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+-->
+<RelativeLayout
+ xmlns:android="http://schemas.android.com/apk/res/android"
+ xmlns:tools="http://schemas.android.com/tools"
+ android:layout_width="match_parent"
+ android:layout_height="match_parent"
+ android:paddingBottom="@dimen/activity_vertical_margin"
+ android:paddingLeft="@dimen/activity_horizontal_margin"
+ android:paddingRight="@dimen/activity_horizontal_margin"
+ android:paddingTop="@dimen/activity_vertical_margin"
+ tools:context="ovic.demo.app.OvicBenchmarkerActivity">
+
+ <TextView
+ android:layout_width="wrap_content"
+ android:layout_height="wrap_content"
+ android:text="@string/initial_status_msg"
+ android:id="@+id/textView"
+ android:layout_above="@+id/button_start"
+ android:layout_alignParentTop="true"/>
+
+ <Button
+ android:layout_width="wrap_content"
+ android:layout_height="wrap_content"
+ android:text="@string/start_label"
+ android:id="@id/button_start"
+ android:layout_alignParentBottom="true"
+ android:layout_alignParentLeft="true"
+ android:background="@drawable/start_button_color"
+ android:padding="10dp"
+ android:layout_marginRight="30dp"
+ android:layout_marginLeft="100dp"
+ android:layout_marginTop="10dp"
+ android:foreground="#000000"
+ android:textColor="#ffffff"
+ android:enabled="true"
+ style="?android:attr/buttonBarButtonStyle"
+ android:onClick="startPressed"/>
+
+</RelativeLayout>
diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/res/values/dimens.xml b/tensorflow/contrib/lite/java/ovic/demo/app/res/values/dimens.xml
new file mode 100644
index 0000000000..250b581430
--- /dev/null
+++ b/tensorflow/contrib/lite/java/ovic/demo/app/res/values/dimens.xml
@@ -0,0 +1,20 @@
+<?xml version="1.0" encoding="utf-8"?>
+<!--
+ Copyright 2018 The Android Open Source Project
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+-->
+<resources>
+ <dimen name="activity_vertical_margin">20dp</dimen>
+ <dimen name="activity_horizontal_margin">16dp</dimen>
+</resources>
diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/res/values/strings.xml b/tensorflow/contrib/lite/java/ovic/demo/app/res/values/strings.xml
new file mode 100644
index 0000000000..d26beb1d27
--- /dev/null
+++ b/tensorflow/contrib/lite/java/ovic/demo/app/res/values/strings.xml
@@ -0,0 +1,22 @@
+<?xml version="1.0" encoding="utf-8"?>
+<!--
+ Copyright 2018 The Android Open Source Project
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+-->
+<resources>
+ <string name="app_name" translatable="false">Benchmarker</string>
+
+ <string name="start_label" translatable="false">Start</string>
+ <string name="initial_status_msg" translatable="false"> Press start to run the benchmarks.</string>
+</resources>
diff --git a/tensorflow/contrib/lite/java/ovic/demo/build.gradle b/tensorflow/contrib/lite/java/ovic/demo/build.gradle
new file mode 100644
index 0000000000..b78a0b86c9
--- /dev/null
+++ b/tensorflow/contrib/lite/java/ovic/demo/build.gradle
@@ -0,0 +1,23 @@
+// Top-level build file where you can add configuration options common to all sub-projects/modules.
+
+buildscript {
+ repositories {
+ jcenter()
+ }
+ dependencies {
+ classpath 'com.android.tools.build:gradle:2.3.1'
+
+ // NOTE: Do not place your application dependencies here; they belong
+ // in the individual module build.gradle files
+ }
+}
+
+allprojects {
+ repositories {
+ jcenter()
+ }
+}
+
+task clean(type: Delete) {
+ delete rootProject.buildDir
+}
diff --git a/tensorflow/contrib/lite/java/ovic/demo/gradle.properties b/tensorflow/contrib/lite/java/ovic/demo/gradle.properties
new file mode 100644
index 0000000000..aac7c9b461
--- /dev/null
+++ b/tensorflow/contrib/lite/java/ovic/demo/gradle.properties
@@ -0,0 +1,17 @@
+# Project-wide Gradle settings.
+
+# IDE (e.g. Android Studio) users:
+# Gradle settings configured through the IDE *will override*
+# any settings specified in this file.
+
+# For more details on how to configure your build environment visit
+# http://www.gradle.org/docs/current/userguide/build_environment.html
+
+# Specifies the JVM arguments used for the daemon process.
+# The setting is particularly useful for tweaking memory settings.
+org.gradle.jvmargs=-Xmx1536m
+
+# When configured, Gradle will run in incubating parallel mode.
+# This option should only be used with decoupled projects. More details, visit
+# http://www.gradle.org/docs/current/userguide/multi_project_builds.html#sec:decoupled_projects
+# org.gradle.parallel=true
diff --git a/tensorflow/contrib/lite/java/ovic/demo/gradle/wrapper/gradle-wrapper.jar b/tensorflow/contrib/lite/java/ovic/demo/gradle/wrapper/gradle-wrapper.jar
new file mode 100644
index 0000000000..13372aef5e
--- /dev/null
+++ b/tensorflow/contrib/lite/java/ovic/demo/gradle/wrapper/gradle-wrapper.jar
Binary files differ
diff --git a/tensorflow/contrib/lite/java/ovic/demo/gradle/wrapper/gradle-wrapper.properties b/tensorflow/contrib/lite/java/ovic/demo/gradle/wrapper/gradle-wrapper.properties
new file mode 100644
index 0000000000..fa7a38a0e4
--- /dev/null
+++ b/tensorflow/contrib/lite/java/ovic/demo/gradle/wrapper/gradle-wrapper.properties
@@ -0,0 +1,6 @@
+#Thu Sep 28 09:01:41 PDT 2017
+distributionBase=GRADLE_USER_HOME
+distributionPath=wrapper/dists
+zipStoreBase=GRADLE_USER_HOME
+zipStorePath=wrapper/dists
+distributionUrl=https\://services.gradle.org/distributions/gradle-3.3-all.zip
diff --git a/tensorflow/contrib/lite/java/ovic/demo/gradlew b/tensorflow/contrib/lite/java/ovic/demo/gradlew
new file mode 100755
index 0000000000..9d82f78915
--- /dev/null
+++ b/tensorflow/contrib/lite/java/ovic/demo/gradlew
@@ -0,0 +1,160 @@
+#!/usr/bin/env bash
+
+##############################################################################
+##
+## Gradle start up script for UN*X
+##
+##############################################################################
+
+# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
+DEFAULT_JVM_OPTS=""
+
+APP_NAME="Gradle"
+APP_BASE_NAME=`basename "$0"`
+
+# Use the maximum available, or set MAX_FD != -1 to use that value.
+MAX_FD="maximum"
+
+warn ( ) {
+ echo "$*"
+}
+
+die ( ) {
+ echo
+ echo "$*"
+ echo
+ exit 1
+}
+
+# OS specific support (must be 'true' or 'false').
+cygwin=false
+msys=false
+darwin=false
+case "`uname`" in
+ CYGWIN* )
+ cygwin=true
+ ;;
+ Darwin* )
+ darwin=true
+ ;;
+ MINGW* )
+ msys=true
+ ;;
+esac
+
+# Attempt to set APP_HOME
+# Resolve links: $0 may be a link
+PRG="$0"
+# Need this for relative symlinks.
+while [ -h "$PRG" ] ; do
+ ls=`ls -ld "$PRG"`
+ link=`expr "$ls" : '.*-> \(.*\)$'`
+ if expr "$link" : '/.*' > /dev/null; then
+ PRG="$link"
+ else
+ PRG=`dirname "$PRG"`"/$link"
+ fi
+done
+SAVED="`pwd`"
+cd "`dirname \"$PRG\"`/" >/dev/null
+APP_HOME="`pwd -P`"
+cd "$SAVED" >/dev/null
+
+CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar
+
+# Determine the Java command to use to start the JVM.
+if [ -n "$JAVA_HOME" ] ; then
+ if [ -x "$JAVA_HOME/jre/sh/java" ] ; then
+ # IBM's JDK on AIX uses strange locations for the executables
+ JAVACMD="$JAVA_HOME/jre/sh/java"
+ else
+ JAVACMD="$JAVA_HOME/bin/java"
+ fi
+ if [ ! -x "$JAVACMD" ] ; then
+ die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME
+
+Please set the JAVA_HOME variable in your environment to match the
+location of your Java installation."
+ fi
+else
+ JAVACMD="java"
+ which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
+
+Please set the JAVA_HOME variable in your environment to match the
+location of your Java installation."
+fi
+
+# Increase the maximum file descriptors if we can.
+if [ "$cygwin" = "false" -a "$darwin" = "false" ] ; then
+ MAX_FD_LIMIT=`ulimit -H -n`
+ if [ $? -eq 0 ] ; then
+ if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then
+ MAX_FD="$MAX_FD_LIMIT"
+ fi
+ ulimit -n $MAX_FD
+ if [ $? -ne 0 ] ; then
+ warn "Could not set maximum file descriptor limit: $MAX_FD"
+ fi
+ else
+ warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT"
+ fi
+fi
+
+# For Darwin, add options to specify how the application appears in the dock
+if $darwin; then
+ GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\""
+fi
+
+# For Cygwin, switch paths to Windows format before running java
+if $cygwin ; then
+ APP_HOME=`cygpath --path --mixed "$APP_HOME"`
+ CLASSPATH=`cygpath --path --mixed "$CLASSPATH"`
+ JAVACMD=`cygpath --unix "$JAVACMD"`
+
+ # We build the pattern for arguments to be converted via cygpath
+ ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null`
+ SEP=""
+ for dir in $ROOTDIRSRAW ; do
+ ROOTDIRS="$ROOTDIRS$SEP$dir"
+ SEP="|"
+ done
+ OURCYGPATTERN="(^($ROOTDIRS))"
+ # Add a user-defined pattern to the cygpath arguments
+ if [ "$GRADLE_CYGPATTERN" != "" ] ; then
+ OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)"
+ fi
+ # Now convert the arguments - kludge to limit ourselves to /bin/sh
+ i=0
+ for arg in "$@" ; do
+ CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -`
+ CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option
+
+ if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition
+ eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"`
+ else
+ eval `echo args$i`="\"$arg\""
+ fi
+ i=$((i+1))
+ done
+ case $i in
+ (0) set -- ;;
+ (1) set -- "$args0" ;;
+ (2) set -- "$args0" "$args1" ;;
+ (3) set -- "$args0" "$args1" "$args2" ;;
+ (4) set -- "$args0" "$args1" "$args2" "$args3" ;;
+ (5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;;
+ (6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;;
+ (7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;;
+ (8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;;
+ (9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;;
+ esac
+fi
+
+# Split up the JVM_OPTS and GRADLE_OPTS values into an array, following the shell quoting and substitution rules
+function splitJvmOpts() {
+ JVM_OPTS=("$@")
+}
+eval splitJvmOpts $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS
+JVM_OPTS[${#JVM_OPTS[*]}]="-Dorg.gradle.appname=$APP_BASE_NAME"
+
+exec "$JAVACMD" "${JVM_OPTS[@]}" -classpath "$CLASSPATH" org.gradle.wrapper.GradleWrapperMain "$@"
diff --git a/tensorflow/contrib/lite/java/ovic/demo/gradlew.bat b/tensorflow/contrib/lite/java/ovic/demo/gradlew.bat
new file mode 100644
index 0000000000..8a0b282aa6
--- /dev/null
+++ b/tensorflow/contrib/lite/java/ovic/demo/gradlew.bat
@@ -0,0 +1,90 @@
+@if "%DEBUG%" == "" @echo off
+@rem ##########################################################################
+@rem
+@rem Gradle startup script for Windows
+@rem
+@rem ##########################################################################
+
+@rem Set local scope for the variables with windows NT shell
+if "%OS%"=="Windows_NT" setlocal
+
+@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
+set DEFAULT_JVM_OPTS=
+
+set DIRNAME=%~dp0
+if "%DIRNAME%" == "" set DIRNAME=.
+set APP_BASE_NAME=%~n0
+set APP_HOME=%DIRNAME%
+
+@rem Find java.exe
+if defined JAVA_HOME goto findJavaFromJavaHome
+
+set JAVA_EXE=java.exe
+%JAVA_EXE% -version >NUL 2>&1
+if "%ERRORLEVEL%" == "0" goto init
+
+echo.
+echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
+echo.
+echo Please set the JAVA_HOME variable in your environment to match the
+echo location of your Java installation.
+
+goto fail
+
+:findJavaFromJavaHome
+set JAVA_HOME=%JAVA_HOME:"=%
+set JAVA_EXE=%JAVA_HOME%/bin/java.exe
+
+if exist "%JAVA_EXE%" goto init
+
+echo.
+echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME%
+echo.
+echo Please set the JAVA_HOME variable in your environment to match the
+echo location of your Java installation.
+
+goto fail
+
+:init
+@rem Get command-line arguments, handling Windows variants
+
+if not "%OS%" == "Windows_NT" goto win9xME_args
+if "%@eval[2+2]" == "4" goto 4NT_args
+
+:win9xME_args
+@rem Slurp the command line arguments.
+set CMD_LINE_ARGS=
+set _SKIP=2
+
+:win9xME_args_slurp
+if "x%~1" == "x" goto execute
+
+set CMD_LINE_ARGS=%*
+goto execute
+
+:4NT_args
+@rem Get arguments from the 4NT Shell from JP Software
+set CMD_LINE_ARGS=%$
+
+:execute
+@rem Setup the command line
+
+set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar
+
+@rem Execute Gradle
+"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %CMD_LINE_ARGS%
+
+:end
+@rem End local scope for the variables with windows NT shell
+if "%ERRORLEVEL%"=="0" goto mainEnd
+
+:fail
+rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of
+rem the _cmd.exe /c_ return code!
+if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1
+exit /b 1
+
+:mainEnd
+if "%OS%"=="Windows_NT" endlocal
+
+:omega
diff --git a/tensorflow/contrib/lite/java/ovic/demo/settings.gradle b/tensorflow/contrib/lite/java/ovic/demo/settings.gradle
new file mode 100644
index 0000000000..e7b4def49c
--- /dev/null
+++ b/tensorflow/contrib/lite/java/ovic/demo/settings.gradle
@@ -0,0 +1 @@
+include ':app'
diff --git a/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicClassifier.java b/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicClassifier.java
index b2dfd8f2e7..4cf51bb0fa 100644
--- a/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicClassifier.java
+++ b/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicClassifier.java
@@ -67,7 +67,7 @@ public class OvicClassifier {
});
/** Initializes an {@code OvicClassifier}. */
- OvicClassifier(InputStream labelInputStream, MappedByteBuffer model)
+ public OvicClassifier(InputStream labelInputStream, MappedByteBuffer model)
throws IOException, RuntimeException {
if (model == null) {
throw new RuntimeException("Input model is empty.");
@@ -106,7 +106,7 @@ public class OvicClassifier {
/** Classifies a {@link ByteBuffer} image. */
// @throws RuntimeException if model is uninitialized.
- OvicSingleImageResult classifyByteBuffer(ByteBuffer imgData) throws RuntimeException {
+ public OvicSingleImageResult classifyByteBuffer(ByteBuffer imgData) {
if (tflite == null) {
throw new RuntimeException(TAG + ": ImageNet classifier has not been initialized; Failed.");
}
diff --git a/tensorflow/contrib/lite/java/ovic/src/test/java/org/tensorflow/ovic/OvicClassifierTest.java b/tensorflow/contrib/lite/java/ovic/src/test/java/org/tensorflow/ovic/OvicClassifierTest.java
index 098ed8ceba..56f3e7604a 100644
--- a/tensorflow/contrib/lite/java/ovic/src/test/java/org/tensorflow/ovic/OvicClassifierTest.java
+++ b/tensorflow/contrib/lite/java/ovic/src/test/java/org/tensorflow/ovic/OvicClassifierTest.java
@@ -45,17 +45,17 @@ public final class OvicClassifierTest {
private ByteBuffer lowResTestImage = null;
private OvicSingleImageResult testResult = null;
private static final String LABELS_PATH =
- "third_party/tensorflow/contrib/lite/java/ovic/src/testdata/labels.txt";
+ "tensorflow/contrib/lite/java/ovic/src/testdata/labels.txt";
private static final String QUANTIZED_MODEL_PATH =
- "third_party/tensorflow/contrib/lite/java/ovic/src/testdata/quantized_model.lite";
+ "external/tflite_ovic_testdata/quantized_model.lite";
private static final String LOW_RES_MODEL_PATH =
- "third_party/tensorflow/contrib/lite/java/ovic/src/testdata/low_res_model.lite";
+ "external/tflite_ovic_testdata/low_res_model.lite";
private static final String FLOAT_MODEL_PATH =
- "third_party/tensorflow/contrib/lite/java/ovic/src/testdata/float_model.lite";
+ "external/tflite_ovic_testdata/float_model.lite";
private static final String TEST_IMAGE_PATH =
- "third_party/tensorflow/contrib/lite/java/ovic/src/testdata/test_image_224.jpg";
+ "external/tflite_ovic_testdata/test_image_224.jpg";
private static final String TEST_LOW_RES_IMAGE_PATH =
- "third_party/tensorflow/contrib/lite/java/ovic/src/testdata/test_image_128.jpg";
+ "external/tflite_ovic_testdata/test_image_128.jpg";
private static final int TEST_IMAGE_GROUNDTRUTH = 653; // "military uniform"
@Before
diff --git a/tensorflow/contrib/lite/java/ovic/src/testdata/BUILD b/tensorflow/contrib/lite/java/ovic/src/testdata/BUILD
new file mode 100644
index 0000000000..1021ea30dd
--- /dev/null
+++ b/tensorflow/contrib/lite/java/ovic/src/testdata/BUILD
@@ -0,0 +1,19 @@
+# Testdata for OVIC benchmarker demo App and tests.
+licenses(["notice"]) # Apache 2.0
+
+filegroup(
+ name = "ovic_testdata",
+ srcs = [
+ "@tflite_ovic_testdata//:float_model.lite",
+ "@tflite_ovic_testdata//:low_res_model.lite",
+ "@tflite_ovic_testdata//:quantized_model.lite",
+ "@tflite_ovic_testdata//:test_image_128.jpg",
+ "@tflite_ovic_testdata//:test_image_224.jpg",
+ ],
+ visibility = ["//visibility:public"],
+)
+
+exports_files(
+ ["labels.txt"],
+ visibility = ["//visibility:public"],
+)
diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD
index feab18b5c2..79e3c9f266 100644
--- a/tensorflow/contrib/lite/kernels/BUILD
+++ b/tensorflow/contrib/lite/kernels/BUILD
@@ -164,6 +164,7 @@ cc_library(
"register.cc",
"reshape.cc",
"resize_bilinear.cc",
+ "select.cc",
"skip_gram.cc",
"space_to_batch_nd.cc",
"space_to_depth.cc",
@@ -870,6 +871,23 @@ tf_cc_test(
],
)
+tf_cc_test(
+ name = "select_test",
+ size = "small",
+ srcs = [
+ "select_test.cc",
+ ],
+ tags = [
+ "tflite_not_portable_ios",
+ ],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
filegroup(
name = "all_files",
srcs = glob(
diff --git a/tensorflow/contrib/lite/kernels/comparisons.cc b/tensorflow/contrib/lite/kernels/comparisons.cc
index 87c413cb98..2885ce032b 100644
--- a/tensorflow/contrib/lite/kernels/comparisons.cc
+++ b/tensorflow/contrib/lite/kernels/comparisons.cc
@@ -28,7 +28,7 @@ constexpr int kInputTensor1 = 0;
constexpr int kInputTensor2 = 1;
constexpr int kOutputTensor = 0;
-TfLiteStatus LessPrepare(TfLiteContext* context, TfLiteNode* node) {
+TfLiteStatus ComparisonPrepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
@@ -56,61 +56,139 @@ TfLiteStatus LessPrepare(TfLiteContext* context, TfLiteNode* node) {
return context->ResizeTensor(context, output, output_size);
}
-TfLiteStatus LessEval(TfLiteContext* context, TfLiteNode* node) {
+#define TF_LITE_COMPARISON(type, opname, requires_broadcast) \
+ requires_broadcast \
+ ? reference_ops::Broadcast##opname( \
+ GetTensorData<type>(input1), GetTensorDims(input1), \
+ GetTensorData<type>(input2), GetTensorDims(input2), \
+ GetTensorData<bool>(output), GetTensorDims(output)) \
+ : reference_ops::opname( \
+ GetTensorData<type>(input1), GetTensorDims(input1), \
+ GetTensorData<type>(input2), GetTensorDims(input2), \
+ GetTensorData<bool>(output), GetTensorDims(output));
+
+TfLiteStatus GreaterEval(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+ bool requires_broadcast = !HaveSameShapes(input1, input2);
+ // TODO(renjieliu): Support quantized data.
+ switch (input1->type) {
+ case kTfLiteFloat32:
+ TF_LITE_COMPARISON(float, Greater, requires_broadcast);
+ break;
+ case kTfLiteInt32:
+ TF_LITE_COMPARISON(int32_t, Greater, requires_broadcast);
+ break;
+ case kTfLiteInt64:
+ TF_LITE_COMPARISON(int64_t, Greater, requires_broadcast);
+ break;
+ default:
+ context->ReportError(context,
+ "Does not support type other than float|int");
+ return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
+TfLiteStatus GreaterEqualEval(TfLiteContext* context, TfLiteNode* node) {
+ TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
+ TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
bool requires_broadcast = !HaveSameShapes(input1, input2);
+ // TODO(renjieliu): Support quantized data.
+ switch (input1->type) {
+ case kTfLiteFloat32:
+ TF_LITE_COMPARISON(float, GreaterEqual, requires_broadcast);
+ break;
+ case kTfLiteInt32:
+ TF_LITE_COMPARISON(int32_t, GreaterEqual, requires_broadcast);
+ break;
+ case kTfLiteInt64:
+ TF_LITE_COMPARISON(int64_t, GreaterEqual, requires_broadcast);
+ break;
+ default:
+ context->ReportError(context,
+ "Does not support type other than float|int");
+ return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
-#define TF_LITE_LESS(type, opname) \
- reference_ops::opname(GetTensorData<type>(input1), GetTensorDims(input1), \
- GetTensorData<type>(input2), GetTensorDims(input2), \
- GetTensorData<bool>(output), GetTensorDims(output));
+TfLiteStatus LessEval(TfLiteContext* context, TfLiteNode* node) {
+ TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
+ TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+ bool requires_broadcast = !HaveSameShapes(input1, input2);
+ // TODO(renjieliu): Support quantized data.
+ switch (input1->type) {
+ case kTfLiteFloat32:
+ TF_LITE_COMPARISON(float, Less, requires_broadcast);
+ break;
+ case kTfLiteInt32:
+ TF_LITE_COMPARISON(int32_t, Less, requires_broadcast);
+ break;
+ case kTfLiteInt64:
+ TF_LITE_COMPARISON(int64_t, Less, requires_broadcast);
+ break;
+ default:
+ context->ReportError(context,
+ "Does not support type other than float|int");
+ return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
+TfLiteStatus LessEqualEval(TfLiteContext* context, TfLiteNode* node) {
+ TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
+ TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+ bool requires_broadcast = !HaveSameShapes(input1, input2);
// TODO(renjieliu): Support quantized data.
- if (requires_broadcast) {
- switch (input1->type) {
- case kTfLiteFloat32:
- TF_LITE_LESS(float, BroadcastLess);
- break;
- case kTfLiteInt32:
- TF_LITE_LESS(int32_t, BroadcastLess);
- break;
- case kTfLiteInt64:
- TF_LITE_LESS(int64_t, BroadcastLess);
- break;
- default:
- context->ReportError(context,
- "Does not support type other than float|int");
- return kTfLiteError;
- }
- } else {
- switch (input1->type) {
- case kTfLiteFloat32:
- TF_LITE_LESS(float, Less);
- break;
- case kTfLiteInt32:
- TF_LITE_LESS(int32_t, Less);
- break;
- case kTfLiteInt64:
- TF_LITE_LESS(int64_t, Less);
- break;
- default:
- context->ReportError(context,
- "Does not support type other than float|int");
- return kTfLiteError;
- }
+ switch (input1->type) {
+ case kTfLiteFloat32:
+ TF_LITE_COMPARISON(float, LessEqual, requires_broadcast);
+ break;
+ case kTfLiteInt32:
+ TF_LITE_COMPARISON(int32_t, LessEqual, requires_broadcast);
+ break;
+ case kTfLiteInt64:
+ TF_LITE_COMPARISON(int64_t, LessEqual, requires_broadcast);
+ break;
+ default:
+ context->ReportError(context,
+ "Does not support type other than float|int");
+ return kTfLiteError;
}
-#undef TF_LITE_LESS
return kTfLiteOk;
}
} // namespace comparisons
+TfLiteRegistration* Register_GREATER() {
+ static TfLiteRegistration r = {nullptr, nullptr,
+ comparisons::ComparisonPrepare,
+ comparisons::GreaterEval};
+ return &r;
+}
+
+TfLiteRegistration* Register_GREATER_EQUAL() {
+ static TfLiteRegistration r = {nullptr, nullptr,
+ comparisons::ComparisonPrepare,
+ comparisons::GreaterEqualEval};
+ return &r;
+}
+
TfLiteRegistration* Register_LESS() {
- static TfLiteRegistration r = {nullptr, nullptr, comparisons::LessPrepare,
- comparisons::LessEval};
+ static TfLiteRegistration r = {
+ nullptr, nullptr, comparisons::ComparisonPrepare, comparisons::LessEval};
+ return &r;
+}
+
+TfLiteRegistration* Register_LESS_EQUAL() {
+ static TfLiteRegistration r = {nullptr, nullptr,
+ comparisons::ComparisonPrepare,
+ comparisons::LessEqualEval};
return &r;
}
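
The hunk above replaces the Less-only kernel with one shared ComparisonPrepare plus a per-operator Eval, and the TF_LITE_COMPARISON macro selects the broadcast or element-wise reference kernel through a ternary over two void calls. Below is a minimal, standalone sketch of that dispatch shape, assuming nothing from the TF Lite headers; every name in it (sketch::EvalLess, SKETCH_COMPARISON, ...) is illustrative, not from the commit.

#include <cstddef>
#include <cstdio>
#include <vector>

namespace sketch {

// Stand-ins for reference_ops::Less and reference_ops::BroadcastLess.
void Less(const std::vector<int>& a, const std::vector<int>& b,
          std::vector<bool>* out) {
  for (std::size_t i = 0; i < a.size(); ++i) (*out)[i] = a[i] < b[i];
}
void BroadcastLess(const std::vector<int>& a, const std::vector<int>& b,
                   std::vector<bool>* out) {
  // Broadcast a single right-hand value across the left-hand buffer.
  for (std::size_t i = 0; i < a.size(); ++i) (*out)[i] = a[i] < b[0];
}

// Same shape as TF_LITE_COMPARISON: a ternary over two void calls.
#define SKETCH_COMPARISON(opname, requires_broadcast, a, b, out) \
  (requires_broadcast) ? Broadcast##opname(a, b, out) : opname(a, b, out);

void EvalLess(const std::vector<int>& a, const std::vector<int>& b,
              std::vector<bool>* out) {
  const bool requires_broadcast = a.size() != b.size();
  SKETCH_COMPARISON(Less, requires_broadcast, a, b, out);
}

}  // namespace sketch

int main() {
  std::vector<int> a = {-1, 9, 7, 3};
  std::vector<int> b = {7};  // one element, so the broadcast path is taken
  std::vector<bool> out(a.size());
  sketch::EvalLess(a, b, &out);
  for (bool v : out) std::printf("%d ", v ? 1 : 0);  // 1 0 0 1
  std::printf("\n");
  return 0;
}

Sharing one Prepare is what lets the four Register_* functions at the end of the file differ only in the Eval pointer they hand to TfLiteRegistration.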
diff --git a/tensorflow/contrib/lite/kernels/comparisons_test.cc b/tensorflow/contrib/lite/kernels/comparisons_test.cc
index da2d7f8589..835d238d36 100644
--- a/tensorflow/contrib/lite/kernels/comparisons_test.cc
+++ b/tensorflow/contrib/lite/kernels/comparisons_test.cc
@@ -23,6 +23,139 @@ namespace {
using ::testing::ElementsAreArray;
+class GreaterOpModel : public SingleOpModel {
+ public:
+ GreaterOpModel(std::initializer_list<int> input1_shape,
+ std::initializer_list<int> input2_shape,
+ TensorType input_type) {
+ input1_ = AddInput(input_type);
+ input2_ = AddInput(input_type);
+ output_ = AddOutput(TensorType_BOOL);
+ SetBuiltinOp(BuiltinOperator_GREATER, BuiltinOptions_GreaterOptions,
+ CreateGreaterOptions(builder_).Union());
+ BuildInterpreter({input1_shape, input2_shape});
+ }
+
+ int input1() { return input1_; }
+ int input2() { return input2_; }
+
+ std::vector<bool> GetOutput() { return ExtractVector<bool>(output_); }
+ std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
+
+ private:
+ int input1_;
+ int input2_;
+ int output_;
+};
+
+TEST(ComparisonsTest, GreaterFloat) {
+ GreaterOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32);
+ model.PopulateTensor<float>(model.input1(), {0.1, 0.9, 0.7, 0.3});
+ model.PopulateTensor<float>(model.input2(), {0.1, 0.2, 0.6, 0.5});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, true, false}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+}
+
+TEST(ComparisonsTest, GreaterInt) {
+ GreaterOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32);
+ model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3});
+ model.PopulateTensor<int>(model.input2(), {1, 2, 7, 5});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, false, false}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+}
+
+TEST(ComparisonsTest, GreaterBroadcast) {
+ GreaterOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32);
+ model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3});
+ model.PopulateTensor<int>(model.input2(), {7});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, false, false}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+}
+
+TEST(ComparisonsTest, GreaterBroadcastTwoD) {
+ GreaterOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32);
+ model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3, 2, 4, 2, 8});
+ model.PopulateTensor<int>(model.input2(), {7, 1, 2, 4});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, true, false,
+ false, true, false, true}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 4}));
+}
+
+class GreaterEqualOpModel : public SingleOpModel {
+ public:
+ GreaterEqualOpModel(std::initializer_list<int> input1_shape,
+ std::initializer_list<int> input2_shape,
+ TensorType input_type) {
+ input1_ = AddInput(input_type);
+ input2_ = AddInput(input_type);
+ output_ = AddOutput(TensorType_BOOL);
+ SetBuiltinOp(BuiltinOperator_GREATER_EQUAL,
+ BuiltinOptions_GreaterEqualOptions,
+ CreateGreaterEqualOptions(builder_).Union());
+ BuildInterpreter({input1_shape, input2_shape});
+ }
+
+ int input1() { return input1_; }
+ int input2() { return input2_; }
+
+ std::vector<bool> GetOutput() { return ExtractVector<bool>(output_); }
+ std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
+
+ private:
+ int input1_;
+ int input2_;
+ int output_;
+};
+
+TEST(ComparisonsTest, GreaterEqualFloat) {
+ GreaterEqualOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32);
+ model.PopulateTensor<float>(model.input1(), {0.1, 0.9, 0.7, 0.3});
+ model.PopulateTensor<float>(model.input2(), {0.1, 0.2, 0.6, 0.5});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, true, true, false}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+}
+
+TEST(ComparisonsTest, GreaterEqualInt) {
+ GreaterEqualOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32);
+ model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3});
+ model.PopulateTensor<int>(model.input2(), {1, 2, 7, 5});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, true, false}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+}
+
+TEST(ComparisonsTest, GreaterEqualBroadcast) {
+ GreaterEqualOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32);
+ model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3});
+ model.PopulateTensor<int>(model.input2(), {7});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, true, false}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+}
+
+TEST(ComparisonsTest, GreaterEqualBroadcastTwoD) {
+ GreaterEqualOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32);
+ model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3, 2, 4, 2, 8});
+ model.PopulateTensor<int>(model.input2(), {7, 1, 2, 4});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, true, false,
+ false, true, true, true}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 4}));
+}
+
class LessOpModel : public SingleOpModel {
public:
LessOpModel(std::initializer_list<int> input1_shape,
@@ -47,7 +180,7 @@ class LessOpModel : public SingleOpModel {
int output_;
};
-TEST(ArgMaxOpTest, LessFloat) {
+TEST(ComparisonsTest, LessFloat) {
LessOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32);
model.PopulateTensor<float>(model.input1(), {0.1, 0.9, 0.7, 0.3});
model.PopulateTensor<float>(model.input2(), {0.1, 0.2, 0.6, 0.5});
@@ -57,7 +190,7 @@ TEST(ArgMaxOpTest, LessFloat) {
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
}
-TEST(ArgMaxOpTest, LessInt) {
+TEST(ComparisonsTest, LessInt) {
LessOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32);
model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3});
model.PopulateTensor<int>(model.input2(), {1, 2, 6, 5});
@@ -67,7 +200,7 @@ TEST(ArgMaxOpTest, LessInt) {
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
}
-TEST(ArgMaxOpTest, LessBroadcast) {
+TEST(ComparisonsTest, LessBroadcast) {
LessOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32);
model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3});
model.PopulateTensor<int>(model.input2(), {7});
@@ -77,7 +210,7 @@ TEST(ArgMaxOpTest, LessBroadcast) {
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
}
-TEST(ArgMaxOpTest, LessBroadcastTwoD) {
+TEST(ComparisonsTest, LessBroadcastTwoD) {
LessOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32);
model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3, 2, 4, 6, 8});
model.PopulateTensor<int>(model.input2(), {7, 1, 2, 4});
@@ -88,6 +221,72 @@ TEST(ArgMaxOpTest, LessBroadcastTwoD) {
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 4}));
}
+class LessEqualOpModel : public SingleOpModel {
+ public:
+ LessEqualOpModel(std::initializer_list<int> input1_shape,
+ std::initializer_list<int> input2_shape,
+ TensorType input_type) {
+ input1_ = AddInput(input_type);
+ input2_ = AddInput(input_type);
+ output_ = AddOutput(TensorType_BOOL);
+ SetBuiltinOp(BuiltinOperator_LESS_EQUAL, BuiltinOptions_LessEqualOptions,
+ CreateLessEqualOptions(builder_).Union());
+ BuildInterpreter({input1_shape, input2_shape});
+ }
+
+ int input1() { return input1_; }
+ int input2() { return input2_; }
+
+ std::vector<bool> GetOutput() { return ExtractVector<bool>(output_); }
+ std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
+
+ private:
+ int input1_;
+ int input2_;
+ int output_;
+};
+
+TEST(ComparisonsTest, LessEqualFloat) {
+ LessEqualOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32);
+ model.PopulateTensor<float>(model.input1(), {0.1, 0.9, 0.7, 0.3});
+ model.PopulateTensor<float>(model.input2(), {0.1, 0.2, 0.6, 0.5});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, false, true}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+}
+
+TEST(ComparisonsTest, LessEqualInt) {
+ LessEqualOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32);
+ model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3});
+ model.PopulateTensor<int>(model.input2(), {1, 2, 7, 5});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, true, true}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+}
+
+TEST(ComparisonsTest, LessEqualBroadcast) {
+ LessEqualOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32);
+ model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3});
+ model.PopulateTensor<int>(model.input2(), {7});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, true, true}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+}
+
+TEST(ComparisonsTest, LessEqualBroadcastTwoD) {
+ LessEqualOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32);
+ model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3, 2, 4, 2, 8});
+ model.PopulateTensor<int>(model.input2(), {7, 1, 2, 4});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, false, true,
+ true, false, true, false}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 4}));
+}
+
} // namespace
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/internal/BUILD b/tensorflow/contrib/lite/kernels/internal/BUILD
index df29172f83..d8340d426a 100644
--- a/tensorflow/contrib/lite/kernels/internal/BUILD
+++ b/tensorflow/contrib/lite/kernels/internal/BUILD
@@ -5,6 +5,7 @@ package(default_visibility = [
licenses(["notice"]) # Apache 2.0
load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts")
+load("//tensorflow/contrib/lite:special_rules.bzl", "tflite_portable_test_suite")
tflite_deps_intel = [
"@arm_neon_2_x86_sse",
@@ -157,6 +158,7 @@ cc_library(
":quantization_util",
":strided_slice_logic",
":types",
+ ":reference_base",
":round",
"//third_party/eigen3",
"@gemmlowp",
@@ -386,6 +388,9 @@ cc_library(
":armv7a": [
":neon_tensor_utils",
],
+ ":haswell": [
+ ":neon_tensor_utils",
+ ],
":ios_armv7": [
":neon_tensor_utils",
],
@@ -424,6 +429,7 @@ cc_test(
"//conditions:default": [],
}),
linkstatic = 1,
+ tags = ["tflite_not_portable_ios"],
deps = [
":tensor_utils",
"//tensorflow/contrib/lite:builtin_op_data",
@@ -458,3 +464,5 @@ cc_test(
)
exports_files(["optimized/eigen_tensor_reduced_instantiations_oss.h"])
+
+tflite_portable_test_suite()
diff --git a/tensorflow/contrib/lite/kernels/internal/common.h b/tensorflow/contrib/lite/kernels/internal/common.h
index 18601df22c..ede95dfee0 100644
--- a/tensorflow/contrib/lite/kernels/internal/common.h
+++ b/tensorflow/contrib/lite/kernels/internal/common.h
@@ -113,6 +113,20 @@ inline int32 MultiplyByQuantizedMultiplier(int32 x, int32 quantized_multiplier,
right_shift);
}
+template <typename T>
+int CountLeadingZeros(T integer_input) {
+ static_assert(std::is_unsigned<T>::value,
+ "Only unsigned integer types handled.");
+ const T one_in_leading_positive = static_cast<T>(1)
+ << (std::numeric_limits<T>::digits - 1);
+ int leading_zeros = 0;
+ while (integer_input < one_in_leading_positive) {
+ integer_input <<= 1;
+ ++leading_zeros;
+ }
+ return leading_zeros;
+}
+
} // namespace tflite
#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMMON_H_
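
The CountLeadingZeros helper added above shifts the value left until its top bit is set, so the loop only terminates for nonzero inputs, and the static_assert restricts it to unsigned types. A small self-contained check of that behaviour, using a local copy of the helper purely for illustration (the expected values in the comments follow from the code shown above):

#include <cstdint>
#include <cstdio>
#include <limits>
#include <type_traits>

template <typename T>
int CountLeadingZeros(T integer_input) {
  static_assert(std::is_unsigned<T>::value,
                "Only unsigned integer types handled.");
  const T one_in_leading_positive = static_cast<T>(1)
                                    << (std::numeric_limits<T>::digits - 1);
  int leading_zeros = 0;
  while (integer_input < one_in_leading_positive) {
    integer_input <<= 1;
    ++leading_zeros;
  }
  return leading_zeros;
}

int main() {
  std::printf("%d\n", CountLeadingZeros<uint32_t>(1u));            // 31
  std::printf("%d\n", CountLeadingZeros<uint32_t>(0x80000000u));   // 0
  std::printf("%d\n", CountLeadingZeros<uint8_t>(uint8_t{0x10}));  // 3
  // A zero input would never terminate with this loop, so callers are
  // expected to handle the zero case before calling in.
  return 0;
}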
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
index fd14cb23ea..580d208beb 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
@@ -31,6 +31,7 @@ limitations under the License.
#include "public/gemmlowp.h"
#include "tensorflow/contrib/lite/kernels/internal/common.h"
#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/round.h"
#include "tensorflow/contrib/lite/kernels/internal/strided_slice_logic.h"
#include "tensorflow/contrib/lite/kernels/internal/types.h"
@@ -38,6 +39,16 @@ limitations under the License.
namespace tflite {
namespace optimized_ops {
+// Unoptimized reference ops:
+using reference_ops::BroadcastGreater;
+using reference_ops::BroadcastGreaterEqual;
+using reference_ops::BroadcastLess;
+using reference_ops::BroadcastLessEqual;
+using reference_ops::Greater;
+using reference_ops::GreaterEqual;
+using reference_ops::Less;
+using reference_ops::LessEqual;
+
// Make a local VectorMap typedef allowing to map a float array
// as a Eigen vector expression. The std::conditional here is to
// construct the suitable Eigen type for the constness of the
@@ -5851,10 +5862,26 @@ inline void BatchToSpaceND(const T* input_data, const Dims<4>& input_dims,
}
template <typename T>
-inline void Pad(const T* input_data, const Dims<4>& input_dims,
- const std::vector<int>& left_paddings,
- const std::vector<int>& right_paddings, T* output_data,
- const Dims<4>& output_dims, const int32_t pad_value) {
+void TypedMemset(void* ptr, T value, size_t num) {
+ // Optimization for common cases where memset() will suffice.
+ if (value == 0 || std::is_same<T, uint8_t>::value) {
+ memset(ptr, value, num * sizeof(T));
+ } else {
+ // Default implementation for cases where memset() will not preserve the
+ // bytes, e.g., typically when sizeof(T) > sizeof(uint8_t).
+ char* pos = static_cast<char*>(ptr);
+ for (size_t i = 0; i < num; ++i) {
+ memcpy(pos, &value, sizeof(T));
+ pos = pos + sizeof(T);
+ }
+ }
+}
+
+template <typename T>
+inline void PadV2(const T* input_data, const Dims<4>& input_dims,
+ const std::vector<int>& left_paddings,
+ const std::vector<int>& right_paddings, T* output_data,
+ const Dims<4>& output_dims, const T pad_value) {
gemmlowp::ScopedProfilingLabel label("Pad");
TFLITE_DCHECK_EQ(left_paddings.size(), 4);
TFLITE_DCHECK_EQ(right_paddings.size(), 4);
@@ -5877,27 +5904,28 @@ inline void Pad(const T* input_data, const Dims<4>& input_dims,
const int input_depth = ArraySize(input_dims, 0);
if (left_b_padding != 0) {
- memset(output_data, pad_value,
- left_b_padding * output_height * output_width * output_depth *
- sizeof(T));
+ TypedMemset<T>(
+ output_data, pad_value,
+ left_b_padding * output_height * output_width * output_depth);
}
for (int out_b = left_b_padding; out_b < output_batch - right_b_padding;
++out_b) {
if (left_h_padding != 0) {
- memset(output_data + Offset(output_dims, 0, 0, 0, out_b), pad_value,
- left_h_padding * output_width * output_depth * sizeof(T));
+ TypedMemset<T>(output_data + Offset(output_dims, 0, 0, 0, out_b),
+ pad_value, left_h_padding * output_width * output_depth);
}
for (int out_h = left_h_padding; out_h < output_height - right_h_padding;
++out_h) {
if (left_w_padding != 0) {
- memset(output_data + Offset(output_dims, 0, 0, out_h, out_b), pad_value,
- left_w_padding * output_depth * sizeof(T));
+ TypedMemset<T>(output_data + Offset(output_dims, 0, 0, out_h, out_b),
+ pad_value, left_w_padding * output_depth);
}
for (int out_w = left_w_padding; out_w < output_width - right_w_padding;
++out_w) {
if (left_d_padding != 0) {
- memset(output_data + Offset(output_dims, 0, out_w, out_h, out_b),
- pad_value, left_d_padding * sizeof(T));
+ TypedMemset<T>(
+ output_data + Offset(output_dims, 0, out_w, out_h, out_b),
+ pad_value, left_d_padding);
}
T* out = output_data +
@@ -5908,35 +5936,46 @@ inline void Pad(const T* input_data, const Dims<4>& input_dims,
memcpy(out, in, input_depth * sizeof(T));
if (right_d_padding != 0) {
- memset(
+ TypedMemset<T>(
output_data + Offset(output_dims, output_depth - right_d_padding,
out_w, out_h, out_b),
- pad_value, right_d_padding * sizeof(T));
+ pad_value, right_d_padding);
}
}
if (right_w_padding != 0) {
- memset(
+ TypedMemset<T>(
output_data + Offset(output_dims, 0, output_width - right_w_padding,
out_h, out_b),
- pad_value, right_w_padding * output_depth * sizeof(T));
+ pad_value, right_w_padding * output_depth);
}
}
if (right_h_padding != 0) {
- memset(output_data + Offset(output_dims, 0, 0,
- output_height - right_h_padding, out_b),
- pad_value,
- right_h_padding * output_width * output_depth * sizeof(T));
+ TypedMemset<T>(
+ output_data +
+ Offset(output_dims, 0, 0, output_height - right_h_padding, out_b),
+ pad_value, right_h_padding * output_width * output_depth);
}
}
if (right_b_padding != 0) {
- memset(output_data +
- Offset(output_dims, 0, 0, 0, output_batch - right_b_padding),
- 0,
- right_b_padding * output_height * output_width * output_depth *
- sizeof(T));
+ TypedMemset<T>(
+ output_data +
+ Offset(output_dims, 0, 0, 0, output_batch - right_b_padding),
+ pad_value,
+ right_b_padding * output_height * output_width * output_depth);
}
}
+// Legacy Pad() method that casts an int32_t to T before padding.
+template <typename T>
+inline void Pad(const T* input_data, const Dims<4>& input_dims,
+ const std::vector<int>& left_paddings,
+ const std::vector<int>& right_paddings, T* output_data,
+ const Dims<4>& output_dims, const int32_t pad_value) {
+ const T converted_pad_value = static_cast<T>(pad_value);
+ PadV2<T>(input_data, input_dims, left_paddings, right_paddings, output_data,
+ output_dims, converted_pad_value);
+}
+
template <typename T>
inline void Pad(const T* input_data, const Dims<4>& input_dims,
const std::vector<int>& left_paddings,
@@ -6279,6 +6318,59 @@ inline void TransposeConv(const float* input_data, const Dims<4>& input_dims,
}
}
+// UNOPTIMIZED COPY of Select from reference_ops.h.
+template <typename D, typename T>
+inline void Select(const D* input_condition_data,
+ const Dims<4>& input_condition_dims, const T* input_x_data,
+ const Dims<4>& input_x_dims, const T* input_y_data,
+ const Dims<4>& input_y_dims, T* output_data,
+ const Dims<4>& output_dims) {
+ const int64_t batches =
+ MatchingArraySize(input_condition_dims, 3, input_x_dims, 3, input_y_dims,
+ 3, output_dims, 3);
+ const int64_t height =
+ MatchingArraySize(input_condition_dims, 2, input_x_dims, 2, input_y_dims,
+ 2, output_dims, 2);
+ const int64_t width = MatchingArraySize(input_condition_dims, 1, input_x_dims,
+ 1, input_y_dims, 1, output_dims, 1);
+ const int64_t depth = MatchingArraySize(input_condition_dims, 0, input_x_dims,
+ 0, input_y_dims, 0, output_dims, 0);
+
+ const int64_t num_elements = batches * height * width * depth;
+ for (int64_t i = 0; i < num_elements; ++i) {
+ output_data[i] =
+ input_condition_data[i] ? input_x_data[i] : input_y_data[i];
+ }
+}
+
+// UNOPTIMIZED COPY of RankOneSelect from reference_ops.h.
+template <typename D, typename T>
+inline void RankOneSelect(const D* input_condition_data,
+ const Dims<4>& input_condition_dims,
+ const T* input_x_data, const Dims<4>& input_x_dims,
+ const T* input_y_data, const Dims<4>& input_y_dims,
+ T* output_data, const Dims<4>& output_dims) {
+ const int64_t rank = ArraySize(input_condition_dims, 0);
+
+ const int64_t batches =
+ MatchingArraySize(input_x_dims, 3, input_y_dims, 3, output_dims, 3);
+ const int64_t height =
+ MatchingArraySize(input_x_dims, 2, input_y_dims, 2, output_dims, 2);
+ const int64_t width =
+ MatchingArraySize(input_x_dims, 1, input_y_dims, 1, output_dims, 1);
+ const int64_t depth =
+ MatchingArraySize(input_x_dims, 0, input_y_dims, 0, output_dims, 0);
+
+ TFLITE_DCHECK_EQ(rank, batches);
+
+ int64_t offset = 0;
+ int64_t size = depth * height * width;
+ for (int64_t i = 0; i < rank; i++) {
+ const T* input_data = input_condition_data[i] ? input_x_data : input_y_data;
+ memcpy(output_data + offset, input_data + offset, size * sizeof(T));
+    offset += size;
+  }
+}
+
} // namespace optimized_ops
} // namespace tflite
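
The PadV2 changes above swap raw memset() calls for the TypedMemset() helper because memset can only replicate a single byte: a non-zero, multi-byte pad value (a float constant, say) cannot be written that way. The helper keeps the memset fast path for zero fills and uint8 data and otherwise writes the value element by element with memcpy. A short standalone sketch of that behaviour, again with a local copy of the helper just for illustration:

#include <cstdint>
#include <cstdio>
#include <cstring>
#include <type_traits>
#include <vector>

template <typename T>
void TypedMemset(void* ptr, T value, std::size_t num) {
  // memset is enough when every byte of the fill pattern is identical.
  if (value == 0 || std::is_same<T, uint8_t>::value) {
    memset(ptr, value, num * sizeof(T));
  } else {
    // Otherwise copy the full element once per slot.
    char* pos = static_cast<char*>(ptr);
    for (std::size_t i = 0; i < num; ++i) {
      memcpy(pos, &value, sizeof(T));
      pos += sizeof(T);
    }
  }
}

int main() {
  std::vector<float> buf(4, 0.0f);
  TypedMemset<float>(buf.data(), 1.5f, buf.size());
  for (float v : buf) std::printf("%g ", v);  // 1.5 1.5 1.5 1.5
  std::printf("\n");
  // No single-byte memset pattern could have produced 1.5f here.
  return 0;
}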
diff --git a/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc b/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc
index 3e9a3c29ee..2d74b3d384 100644
--- a/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc
+++ b/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc
@@ -167,6 +167,7 @@ TEST(QuantizationUtilTest, ChooseQuantizationParamsZeroPointOnMinBoundary) {
EXPECT_EQ(qp.zero_point, 0);
}
+#ifdef GTEST_HAS_DEATH_TEST
TEST(QuantizationUtilTest, ChooseQuantizationParamsZeroNotInRange) {
// Assumption is that zero is within the range.
EXPECT_DEATH(ChooseQuantizationParams<uint8>(10.0, 30.0), "");
@@ -176,6 +177,7 @@ TEST(QuantizationUtilTest, ChooseQuantizationParamsEmptyRangePositive) {
// Assumption is that zero is within the range.
EXPECT_DEATH(ChooseQuantizationParams<uint8>(30.0, 30.0), "");
}
+#endif // GTEST_HAS_DEATH_TEST
TEST(QuantizationUtilTest, ChooseQuantizationParamsEmptyRangeZero) {
QuantizationParams qp = ChooseQuantizationParams<uint8>(0.0, 0.0);
@@ -189,6 +191,7 @@ TEST(QuantizationUtilTest, ChooseQuantizationParamsZeroPointOnMaxBoundary) {
EXPECT_EQ(qp.zero_point, 255);
}
+#ifdef GTEST_HAS_DEATH_TEST
TEST(QuantizationUtilTest, ChooseQuantizationParamsInvalidRange) {
EXPECT_DEATH(ChooseQuantizationParams<uint8>(10.0, -30.0), "");
}
@@ -261,6 +264,7 @@ TEST(QuantizationUtilTest, PreprocessSoftmaxScaling) {
EXPECT_THAT(quantize(2.0, 16.0, 5), Pair(2147483647, 31));
EXPECT_THAT(quantize(2.0, 8.0, 5), Pair(1073741824, 31));
}
+#endif // GTEST_HAS_DEATH_TEST
TEST(QuantizationUtilTest, CalculateInputRadius) {
EXPECT_EQ(CalculateInputRadius(4, 27), 15);
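
The #ifdef GTEST_HAS_DEATH_TEST guards added above keep the EXPECT_DEATH tests out of builds where googletest cannot run death tests (they work by spawning a child process, which not every target platform supports). A minimal sketch of the same guard pattern, assuming only googletest plus its standard gtest_main; the helper and test names here are made up for illustration:

#include <cstdio>
#include <cstdlib>

#include <gtest/gtest.h>

// A helper that aborts on bad input, analogous to the range checks the
// quantization death tests above exercise.
static void CheckPositive(int x) {
  if (x <= 0) {
    std::fprintf(stderr, "x must be positive\n");
    std::abort();
  }
}

#ifdef GTEST_HAS_DEATH_TEST
TEST(GuardSketchTest, NonPositiveInputDies) {
  // Runs the statement in a child process and checks that it terminates
  // with the given message on stderr.
  EXPECT_DEATH(CheckPositive(0), "must be positive");
}
#endif  // GTEST_HAS_DEATH_TEST

TEST(GuardSketchTest, PositiveInputPasses) {
  CheckPositive(1);  // Should return normally.
  SUCCEED();
}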
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
index 798b55abc7..e2978cfd67 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
@@ -35,35 +35,6 @@ limitations under the License.
namespace tflite {
namespace reference_ops {
-inline int32 MultiplyByQuantizedMultiplierSmallerThanOne(
- int32 x, int32 quantized_multiplier, int right_shift) {
- using gemmlowp::RoundingDivideByPOT;
- using gemmlowp::SaturatingRoundingDoublingHighMul;
- return RoundingDivideByPOT(
- SaturatingRoundingDoublingHighMul(x, quantized_multiplier), right_shift);
-}
-
-inline int32 MultiplyByQuantizedMultiplierGreaterThanOne(
- int32 x, int32 quantized_multiplier, int left_shift) {
- using gemmlowp::SaturatingRoundingDoublingHighMul;
- return SaturatingRoundingDoublingHighMul(x * (1 << left_shift),
- quantized_multiplier);
-}
-
-template <typename T>
-int CountLeadingZeros(T integer_input) {
- static_assert(std::is_unsigned<T>::value,
- "Only unsigned integer types handled.");
- const T one_in_leading_positive = static_cast<T>(1)
- << (std::numeric_limits<T>::digits - 1);
- int leading_zeros = 0;
- while (integer_input < one_in_leading_positive) {
- integer_input <<= 1;
- ++leading_zeros;
- }
- return leading_zeros;
-}
-
// DO NOT USE THIS STRUCT FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING ELEMENT-WISE
// BROADCASTING.
//
@@ -3158,10 +3129,10 @@ inline void BatchToSpaceND(const T* input_data, const Dims<4>& input_dims,
}
template <typename T>
-inline void Pad(const T* input_data, const Dims<4>& input_dims,
- const std::vector<int>& left_paddings,
- const std::vector<int>& right_paddings, T* output_data,
- const Dims<4>& output_dims, const int32_t pad_value) {
+inline void PadV2(const T* input_data, const Dims<4>& input_dims,
+ const std::vector<int>& left_paddings,
+ const std::vector<int>& right_paddings, T* output_data,
+ const Dims<4>& output_dims, const T pad_value) {
TFLITE_DCHECK_EQ(left_paddings.size(), 4);
TFLITE_DCHECK_EQ(right_paddings.size(), 4);
@@ -3194,7 +3165,7 @@ inline void Pad(const T* input_data, const Dims<4>& input_dims,
out_w >= output_width - right_w_padding ||
out_d < left_d_padding ||
out_d >= output_depth - right_d_padding) {
- *out_ptr++ = static_cast<T>(pad_value);
+ *out_ptr++ = pad_value;
} else {
*out_ptr++ = *in_ptr++;
}
@@ -3204,6 +3175,17 @@ inline void Pad(const T* input_data, const Dims<4>& input_dims,
}
}
+// Legacy Pad() method that casts an int32_t to T before padding.
+template <typename T>
+inline void Pad(const T* input_data, const Dims<4>& input_dims,
+ const std::vector<int>& left_paddings,
+ const std::vector<int>& right_paddings, T* output_data,
+ const Dims<4>& output_dims, const int32_t pad_value) {
+ const T converted_pad_value = static_cast<T>(pad_value);
+ PadV2<T>(input_data, input_dims, left_paddings, right_paddings, output_data,
+ output_dims, converted_pad_value);
+}
+
template <typename T>
inline void Pad(const T* input_data, const Dims<4>& input_dims,
const std::vector<int>& left_paddings,
@@ -3603,17 +3585,29 @@ inline void TransposeConv(const float* input_data, const Dims<4>& input_dims,
}
template <typename T>
-inline void Less(int64_t num_elements, const T* input1, const T* input2,
- bool* output) {
- for (int64_t i = 0; i < num_elements; ++i) {
- output[i] = input1[i] < input2[i];
- }
+inline bool GreaterFn(T lhs, T rhs) {
+ return lhs > rhs;
+}
+template <typename T>
+inline bool GreaterEqualFn(T lhs, T rhs) {
+ return lhs >= rhs;
+}
+template <typename T>
+inline bool LessFn(T lhs, T rhs) {
+ return lhs < rhs;
+}
+template <typename T>
+inline bool LessEqualFn(T lhs, T rhs) {
+ return lhs <= rhs;
}
template <typename T>
-inline void Less(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, const Dims<4>& input2_dims,
- bool* output_data, const Dims<4>& output_dims) {
+using ComparisonFn = bool (*)(T, T);
+
+template <typename T, ComparisonFn<T> F>
+inline void Comparison(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ bool* output_data, const Dims<4>& output_dims) {
const int64_t batches =
MatchingArraySize(input1_dims, 3, input2_dims, 3, output_dims, 3);
const int64_t height =
@@ -3622,31 +3616,201 @@ inline void Less(const T* input1_data, const Dims<4>& input1_dims,
MatchingArraySize(input1_dims, 1, input2_dims, 1, output_dims, 1);
const int64_t depth =
MatchingArraySize(input1_dims, 0, input2_dims, 0, output_dims, 0);
- Less(batches * height * width * depth, input1_data, input2_data, output_data);
+ for (int64_t i = 0; i < batches * height * width * depth; ++i) {
+ output_data[i] = F(input1_data[i], input2_data[i]);
+ }
}
-template <typename T1, typename T2>
-inline void BroadcastLess(T1* input1_data, const Dims<4>& input1_dims,
- T2* input2_data, const Dims<4>& input2_dims,
- bool* output_data, const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("BroadcastLess");
+template <typename T, ComparisonFn<T> F>
+inline void Comparison(int left_shift, const T* input1_data,
+ const Dims<4>& input1_dims, int32 input1_offset,
+ int32 input1_multiplier, int input1_shift,
+ const T* input2_data, const Dims<4>& input2_dims,
+ int32 input2_offset, int32 input2_multiplier,
+ int input2_shift, bool* output_data,
+ const Dims<4>& output_dims) {
+ const int64_t batches =
+ MatchingArraySize(input1_dims, 3, input2_dims, 3, output_dims, 3);
+ const int64_t height =
+ MatchingArraySize(input1_dims, 2, input2_dims, 2, output_dims, 2);
+ const int64_t width =
+ MatchingArraySize(input1_dims, 1, input2_dims, 1, output_dims, 1);
+ const int64_t depth =
+ MatchingArraySize(input1_dims, 0, input2_dims, 0, output_dims, 0);
+ for (int64_t i = 0; i < batches * height * width * depth; ++i) {
+ const int32 input1_val = input1_offset + input1_data[i];
+ const int32 input2_val = input2_offset + input2_data[i];
+ const int32 shifted_input1_val = input1_val * (1 << left_shift);
+ const int32 shifted_input2_val = input2_val * (1 << left_shift);
+ const int32 scaled_input1_val = MultiplyByQuantizedMultiplierSmallerThanOne(
+ shifted_input1_val, input1_multiplier, input1_shift);
+ const int32 scaled_input2_val = MultiplyByQuantizedMultiplierSmallerThanOne(
+ shifted_input2_val, input2_multiplier, input2_shift);
+ output_data[i] = F(scaled_input1_val, scaled_input2_val);
+ }
+}
+
+template <typename T, ComparisonFn<T> F>
+inline void BroadcastComparison(const T* input1_data,
+ const Dims<4>& input1_dims,
+ const T* input2_data,
+ const Dims<4>& input2_dims, bool* output_data,
+ const Dims<4>& output_dims) {
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
+ for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
+ for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
+ for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
+ for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
+ output_data[Offset(output_dims, c, x, y, b)] =
+ F(input1_data[SubscriptToIndex(desc1, c, x, y, b)],
+ input2_data[SubscriptToIndex(desc2, c, x, y, b)]);
+ }
+ }
+ }
+ }
+}
+template <typename T, ComparisonFn<T> F>
+inline void BroadcastComparison(int left_shift, const T* input1_data,
+ const Dims<4>& input1_dims, int32 input1_offset,
+ int32 input1_multiplier, int input1_shift,
+ const T* input2_data,
+ const Dims<4>& input2_dims, int32 input2_offset,
+ int32 input2_multiplier, int input2_shift,
+ bool* output_data, const Dims<4>& output_dims) {
+ NdArrayDesc<4> desc1;
+ NdArrayDesc<4> desc2;
+ NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
+ const int32 input1_val =
+ input1_offset + input1_data[SubscriptToIndex(desc1, c, x, y, b)];
+ const int32 input2_val =
+ input2_offset + input2_data[SubscriptToIndex(desc2, c, x, y, b)];
+ const int32 shifted_input1_val = input1_val * (1 << left_shift);
+ const int32 shifted_input2_val = input2_val * (1 << left_shift);
+ const int32 scaled_input1_val =
+ MultiplyByQuantizedMultiplierSmallerThanOne(
+ shifted_input1_val, input1_multiplier, input1_shift);
+ const int32 scaled_input2_val =
+ MultiplyByQuantizedMultiplierSmallerThanOne(
+ shifted_input2_val, input2_multiplier, input2_shift);
output_data[Offset(output_dims, c, x, y, b)] =
- input1_data[SubscriptToIndex(desc1, c, x, y, b)] <
- input2_data[SubscriptToIndex(desc2, c, x, y, b)];
+ F(scaled_input1_val, scaled_input2_val);
}
}
}
}
}
+#define TFLITE_COMPARISON_OP(name) \
+ template <typename T> \
+ inline void name(const T* input1_data, const Dims<4>& input1_dims, \
+ const T* input2_data, const Dims<4>& input2_dims, \
+ bool* output_data, const Dims<4>& output_dims) { \
+ gemmlowp::ScopedProfilingLabel label(#name); \
+ Comparison<T, name##Fn>(input1_data, input1_dims, input2_data, \
+ input2_dims, output_data, output_dims); \
+ } \
+ template <typename T> \
+ inline void name( \
+ int left_shift, const T* input1_data, const Dims<4>& input1_dims, \
+ int32 input1_offset, int32 input1_multiplier, int input1_shift, \
+ const T* input2_data, const Dims<4>& input2_dims, int32 input2_offset, \
+ int32 input2_multiplier, int input2_shift, bool* output_data, \
+ const Dims<4>& output_dims) { \
+ gemmlowp::ScopedProfilingLabel label(#name "/8bit"); \
+ BroadcastComparison<T, name##Fn>(left_shift, input1_data, input1_dims, \
+ input1_offset, input1_multiplier, \
+ input1_shift, input2_data, input2_dims, \
+ input2_offset, input2_multiplier, \
+ input2_shift, output_data, output_dims); \
+ } \
+ template <typename T> \
+ inline void Broadcast##name( \
+ const T* input1_data, const Dims<4>& input1_dims, const T* input2_data, \
+ const Dims<4>& input2_dims, bool* output_data, \
+ const Dims<4>& output_dims) { \
+ gemmlowp::ScopedProfilingLabel label("Broadcast" #name); \
+ BroadcastComparison<T, name##Fn>(input1_data, input1_dims, input2_data, \
+ input2_dims, output_data, output_dims); \
+ } \
+ template <typename T> \
+ inline void Broadcast##name( \
+ int left_shift, const T* input1_data, const Dims<4>& input1_dims, \
+ int32 input1_offset, int32 input1_multiplier, int input1_shift, \
+ const T* input2_data, const Dims<4>& input2_dims, int32 input2_offset, \
+ int32 input2_multiplier, int input2_shift, bool* output_data, \
+ const Dims<4>& output_dims) { \
+ gemmlowp::ScopedProfilingLabel label("Broadcast" #name "/8bit"); \
+ BroadcastComparison<T, name##Fn>(left_shift, input1_data, input1_dims, \
+ input1_offset, input1_multiplier, \
+ input1_shift, input2_data, input2_dims, \
+ input2_offset, input2_multiplier, \
+ input2_shift, output_data, output_dims); \
+ }
+TFLITE_COMPARISON_OP(Greater);
+TFLITE_COMPARISON_OP(GreaterEqual);
+TFLITE_COMPARISON_OP(Less);
+TFLITE_COMPARISON_OP(LessEqual);
+#undef TFLITE_COMPARISON_OP
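+// For illustration, the macro above expands each name into an overload set
+// such as
+//   Greater(input1_data, input1_dims, input2_data, input2_dims,
+//           output_data, output_dims);
+// for the plain path, plus quantized and Broadcast* variants that also take
+// the offset/multiplier/shift parameters shown above.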
+
+template <typename D, typename T>
+inline void Select(const D* input_condition_data,
+ const Dims<4>& input_condition_dims, const T* input_x_data,
+ const Dims<4>& input_x_dims, const T* input_y_data,
+ const Dims<4>& input_y_dims, T* output_data,
+ const Dims<4>& output_dims) {
+ const int64_t batches =
+ MatchingArraySize(input_condition_dims, 3, input_x_dims, 3, input_y_dims,
+ 3, output_dims, 3);
+ const int64_t height =
+ MatchingArraySize(input_condition_dims, 2, input_x_dims, 2, input_y_dims,
+ 2, output_dims, 2);
+ const int64_t width = MatchingArraySize(input_condition_dims, 1, input_x_dims,
+ 1, input_y_dims, 1, output_dims, 1);
+ const int64_t depth = MatchingArraySize(input_condition_dims, 0, input_x_dims,
+ 0, input_y_dims, 0, output_dims, 0);
+
+ const int64_t num_elements = batches * height * width * depth;
+ for (int64_t i = 0; i < num_elements; ++i) {
+ output_data[i] =
+ input_condition_data[i] ? input_x_data[i] : input_y_data[i];
+ }
+}
+
+template <typename D, typename T>
+inline void RankOneSelect(const D* input_condition_data,
+ const Dims<4>& input_condition_dims,
+ const T* input_x_data, const Dims<4>& input_x_dims,
+ const T* input_y_data, const Dims<4>& input_y_dims,
+ T* output_data, const Dims<4>& output_dims) {
+ const int64_t rank = ArraySize(input_condition_dims, 0);
+
+ const int64_t batches =
+ MatchingArraySize(input_x_dims, 3, input_y_dims, 3, output_dims, 3);
+ const int64_t height =
+ MatchingArraySize(input_x_dims, 2, input_y_dims, 2, output_dims, 2);
+ const int64_t width =
+ MatchingArraySize(input_x_dims, 1, input_y_dims, 1, output_dims, 1);
+ const int64_t depth =
+ MatchingArraySize(input_x_dims, 0, input_y_dims, 0, output_dims, 0);
+
+ TFLITE_DCHECK_EQ(rank, batches);
+
+ int64_t offset = 0;
+ int64_t size = depth * height * width;
+ for (int64_t i = 0; i < rank; i++) {
+ const T* input_data = input_condition_data[i] ? input_x_data : input_y_data;
+ memcpy(output_data + offset, input_data + offset, size * sizeof(T));
+ offset += size;
+ }
+}
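+// Sketch of the two select flavors above: Select picks per element, while
+// RankOneSelect treats the condition as one flag per batch and copies whole
+// depth*width*height slices. E.g., a condition of {false, true} over 2x1x2x1
+// inputs takes batch 0 from input_y_data and batch 1 from input_x_data.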
+
} // namespace reference_ops
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/pad.cc b/tensorflow/contrib/lite/kernels/pad.cc
index 4f9449a225..9e1e4658e9 100644
--- a/tensorflow/contrib/lite/kernels/pad.cc
+++ b/tensorflow/contrib/lite/kernels/pad.cc
@@ -37,9 +37,15 @@ struct PadContext {
PadContext(TfLiteContext* context, TfLiteNode* node) {
input = GetInput(context, node, 0);
paddings = GetInput(context, node, 1);
+ if (NumInputs(node) == 3) {
+ constant_values = GetOptionalInputTensor(context, node, 2);
+ } else {
+ constant_values = nullptr;
+ }
output = GetOutput(context, node, 0);
dims = NumDimensions(input);
}
+ TfLiteTensor* constant_values;
TfLiteTensor* input;
TfLiteTensor* paddings;
TfLiteTensor* output;
@@ -76,11 +82,15 @@ TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
}
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
- TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
+ TF_LITE_ENSURE(context, NumInputs(node) == 2 || NumInputs(node) == 3);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
PadContext op_context(context, node);
TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type);
+ if (op_context.constant_values != nullptr) {
+ TF_LITE_ENSURE_EQ(context, op_context.input->type,
+ op_context.constant_values->type);
+ }
// TODO(nupurgarg): Our current implementations rely on the inputs being 4D.
TF_LITE_ENSURE_EQ(context, op_context.dims, 4);
@@ -98,6 +108,11 @@ template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
PadContext op_context(context, node);
+ if (op_context.constant_values != nullptr) {
+ // Ensure that constant_values is a scalar.
+ TF_LITE_ENSURE_EQ(context, NumElements(op_context.constant_values), 1);
+ }
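+  // The optional constant_values scalar is consumed below, per input type;
+  // when it is absent the pad value defaults to 0 (or to the output
+  // zero_point on the quantized uint8 path).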
+
// Resize the output tensor if the output tensor is dynamic.
if (IsDynamicTensor(op_context.output)) {
TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
@@ -119,48 +134,70 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
after_padding.push_back(paddings_data[idx * 2 + 1]);
}
-#define TF_LITE_PAD(type, scalar, pad_value) \
- type::Pad(GetTensorData<scalar>(op_context.input), \
- GetTensorDims(op_context.input), before_padding, after_padding, \
- GetTensorData<scalar>(op_context.output), \
- GetTensorDims(op_context.output), pad_value)
+#define TF_LITE_PAD(type, scalar, pad_value) \
+ type::PadV2(GetTensorData<scalar>(op_context.input), \
+ GetTensorDims(op_context.input), before_padding, after_padding, \
+ GetTensorData<scalar>(op_context.output), \
+ GetTensorDims(op_context.output), pad_value)
switch (op_context.input->type) {
- case kTfLiteFloat32:
+ case kTfLiteFloat32: {
+ float pad_value = op_context.constant_values == nullptr
+ ? 0.f
+ : *GetTensorData<float>(op_context.constant_values);
if (kernel_type == kReference) {
- TF_LITE_PAD(reference_ops, float, 0);
+ TF_LITE_PAD(reference_ops, float, pad_value);
} else if (kernel_type == kGenericOptimized) {
- TF_LITE_PAD(optimized_ops, float, 0);
+ TF_LITE_PAD(optimized_ops, float, pad_value);
+ }
+ } break;
+ case kTfLiteUInt8: {
+ uint8_t pad_value;
+ if (op_context.constant_values == nullptr) {
+ // Quantized Pad requires that 0 is represented in the quantized
+ // range.
+ TF_LITE_ENSURE(context, op_context.output->params.zero_point >=
+ std::numeric_limits<uint8_t>::min());
+ TF_LITE_ENSURE(context, op_context.output->params.zero_point <=
+ std::numeric_limits<uint8_t>::max());
+ pad_value = static_cast<uint8_t>(op_context.output->params.zero_point);
+ } else {
+ // Quantized Pad requires that 'constant_values' is represented in the
+ // same quantized range as the input and output tensors.
+ TF_LITE_ENSURE_EQ(context, op_context.output->params.zero_point,
+ op_context.constant_values->params.zero_point);
+ TF_LITE_ENSURE_EQ(context, op_context.output->params.scale,
+ op_context.constant_values->params.scale);
+ pad_value = *GetTensorData<uint8_t>(op_context.constant_values);
}
- break;
- case kTfLiteUInt8:
- // Quantized Pad requires that 0 is represented in the quantized range.
- TF_LITE_ENSURE(context, op_context.output->params.zero_point >=
- std::numeric_limits<uint8_t>::min());
- TF_LITE_ENSURE(context, op_context.output->params.zero_point <=
- std::numeric_limits<uint8_t>::max());
if (kernel_type == kReference) {
- TF_LITE_PAD(reference_ops, uint8_t,
- op_context.output->params.zero_point);
+ TF_LITE_PAD(reference_ops, uint8_t, pad_value);
} else if (kernel_type == kGenericOptimized) {
- TF_LITE_PAD(optimized_ops, uint8_t,
- op_context.output->params.zero_point);
+ TF_LITE_PAD(optimized_ops, uint8_t, pad_value);
}
- break;
- case kTfLiteInt32:
+ } break;
+ case kTfLiteInt32: {
+ int32_t pad_value =
+ op_context.constant_values == nullptr
+ ? 0
+ : *GetTensorData<int32_t>(op_context.constant_values);
if (kernel_type == kReference) {
- TF_LITE_PAD(reference_ops, int32_t, 0);
+ TF_LITE_PAD(reference_ops, int32_t, pad_value);
} else if (kernel_type == kGenericOptimized) {
- TF_LITE_PAD(optimized_ops, int32_t, 0);
+ TF_LITE_PAD(optimized_ops, int32_t, pad_value);
}
- break;
- case kTfLiteInt64:
+ } break;
+ case kTfLiteInt64: {
+ int64_t pad_value =
+ op_context.constant_values == nullptr
+ ? 0L
+ : *GetTensorData<int64_t>(op_context.constant_values);
if (kernel_type == kReference) {
- TF_LITE_PAD(reference_ops, int64_t, 0);
+ TF_LITE_PAD(reference_ops, int64_t, pad_value);
} else if (kernel_type == kGenericOptimized) {
- TF_LITE_PAD(optimized_ops, int64_t, 0);
+ TF_LITE_PAD(optimized_ops, int64_t, pad_value);
}
- break;
+ } break;
default:
context->ReportError(context, "Type is currently not supported by Pad.");
return kTfLiteError;
@@ -185,6 +222,21 @@ TfLiteRegistration* Register_PAD_GENERIC_OPT() {
TfLiteRegistration* Register_PAD() { return Register_PAD_GENERIC_OPT(); }
+// Also register Pad as PadV2.
+TfLiteRegistration* Register_PADV2_REF() {
+ static TfLiteRegistration r = {nullptr, nullptr, pad::Prepare,
+ pad::Eval<pad::kReference>};
+ return &r;
+}
+
+TfLiteRegistration* Register_PADV2_GENERIC_OPT() {
+ static TfLiteRegistration r = {nullptr, nullptr, pad::Prepare,
+ pad::Eval<pad::kGenericOptimized>};
+ return &r;
+}
+
+TfLiteRegistration* Register_PADV2() { return Register_PADV2_GENERIC_OPT(); }
+
} // namespace builtin
} // namespace ops
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/pad_test.cc b/tensorflow/contrib/lite/kernels/pad_test.cc
index c06237e572..f8b9064fbb 100644
--- a/tensorflow/contrib/lite/kernels/pad_test.cc
+++ b/tensorflow/contrib/lite/kernels/pad_test.cc
@@ -24,21 +24,26 @@ namespace {
using ::testing::ElementsAreArray;
using ::testing::Matcher;
+template <typename T>
class PadOpModel : public SingleOpModel {
public:
- void SetInput(std::initializer_list<float> data) {
- PopulateTensor<float>(input_, data);
+ void SetInput(std::initializer_list<T> data) {
+ PopulateTensor<T>(input_, data);
}
void SetQuantizedInput(std::initializer_list<float> data) {
QuantizeAndPopulate<uint8_t>(input_, data);
}
+ void SetQuantizedPadValue(float data) {
+ QuantizeAndPopulate<uint8_t>(constant_values_, {data});
+ }
+
void SetPaddings(std::initializer_list<int> paddings) {
PopulateTensor<int>(paddings_, paddings);
}
- std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+ std::vector<T> GetOutput() { return ExtractVector<T>(output_); }
std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
std::vector<float> GetDequantizedOutput() {
@@ -50,6 +55,59 @@ class PadOpModel : public SingleOpModel {
int input_;
int output_;
int paddings_;
+ int constant_values_;
+};
+
+namespace {
+
+// Returns the corresponding TensorType given the type T.
+template <typename T>
+TensorType GetTensorType() {
+ if (std::is_same<T, float>::value) return TensorType_FLOAT32;
+ if (std::is_same<T, int32_t>::value) return TensorType_INT32;
+ if (std::is_same<T, uint8_t>::value) return TensorType_UINT8;
+  return TensorType_MIN;  // Sentinel fallback for unsupported types.
+}
+
+} // namespace
+
+// Test case where paddings is a const tensor. T is the data type of the
+// input, output and constant_values tensors.
+template <typename T>
+class PadV2OpConstModel : public PadOpModel<T> {
+ public:
+ PadV2OpConstModel(const TensorData& input,
+ std::initializer_list<int> paddings_shape,
+ std::initializer_list<int> paddings, T constant_values,
+ const TensorData& output) {
+ this->input_ = this->AddInput(input);
+ this->paddings_ =
+ this->AddConstInput(TensorType_INT32, paddings, paddings_shape);
+ this->constant_values_ =
+ this->AddConstInput(GetTensorType<T>(), {constant_values}, {1});
+
+ this->output_ = this->AddOutput(output);
+
+ this->SetBuiltinOp(BuiltinOperator_PADV2, BuiltinOptions_PadV2Options,
+ CreatePadV2Options(this->builder_).Union());
+ this->BuildInterpreter({input.shape});
+ }
+
+ PadV2OpConstModel(const TensorData& input,
+ std::initializer_list<int> paddings_shape,
+ std::initializer_list<int> paddings,
+ const TensorData& constant_values,
+ const TensorData& output) {
+ this->input_ = this->AddInput(input);
+ this->paddings_ =
+ this->AddConstInput(TensorType_INT32, paddings, paddings_shape);
+ this->constant_values_ = this->AddInput(constant_values);
+
+ this->output_ = this->AddOutput(output);
+
+ this->SetBuiltinOp(BuiltinOperator_PADV2, BuiltinOptions_PadV2Options,
+ CreatePadV2Options(this->builder_).Union());
+ this->BuildInterpreter({input.shape});
+ }
};
// Tests case where paddings is a const tensor.
@@ -58,7 +116,7 @@ class PadOpModel : public SingleOpModel {
// PadOpDynamicModel m(input_shape, paddings_shape, paddings_data);
// m.SetInput(input_data);
// m.Invoke();
-class PadOpConstModel : public PadOpModel {
+class PadOpConstModel : public PadOpModel<float> {
public:
PadOpConstModel(const TensorData& input,
std::initializer_list<int> paddings_shape,
@@ -66,6 +124,7 @@ class PadOpConstModel : public PadOpModel {
const TensorData& output) {
input_ = AddInput(input);
paddings_ = AddConstInput(TensorType_INT32, paddings, paddings_shape);
+ constant_values_ = AddNullInput();
output_ = AddOutput(output);
SetBuiltinOp(BuiltinOperator_PAD, BuiltinOptions_PadOptions,
@@ -75,19 +134,52 @@ class PadOpConstModel : public PadOpModel {
};
// Test case where paddings is a non-const tensor.
+template <typename T>
+class PadV2OpDynamicModel : public PadOpModel<T> {
+ public:
+ PadV2OpDynamicModel(const TensorData& input,
+ std::initializer_list<int> paddings_shape,
+ T constant_values, const TensorData& output) {
+ this->input_ = this->AddInput(input);
+ this->paddings_ = this->AddInput(TensorType_INT32);
+ this->constant_values_ =
+ this->AddConstInput(GetTensorType<T>(), {constant_values}, {1});
+ this->output_ = this->AddOutput(output);
+
+ this->SetBuiltinOp(BuiltinOperator_PADV2, BuiltinOptions_PadV2Options,
+ CreatePadV2Options(this->builder_).Union());
+ this->BuildInterpreter({input.shape, paddings_shape});
+ }
+ PadV2OpDynamicModel(const TensorData& input,
+ std::initializer_list<int> paddings_shape,
+ const TensorData& constant_values,
+ const TensorData& output) {
+ this->input_ = this->AddInput(input);
+ this->paddings_ = this->AddInput(TensorType_INT32);
+ this->constant_values_ = this->AddInput(constant_values);
+ this->output_ = this->AddOutput(output);
+
+ this->SetBuiltinOp(BuiltinOperator_PADV2, BuiltinOptions_PadV2Options,
+ CreatePadV2Options(this->builder_).Union());
+ this->BuildInterpreter({input.shape, paddings_shape});
+ }
+};
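+// Example usage (sketch):
+//   PadV2OpDynamicModel<float> m({TensorType_FLOAT32, {1, 2, 2, 1}}, {4, 2},
+//                                0.0, {TensorType_FLOAT32});
+//   m.SetInput({1, 2, 3, 4});
+//   m.SetPaddings({0, 0, 1, 1, 1, 1, 0, 0});
+//   m.Invoke();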
+
+// Test case where paddings is a non-const tensor.
//
// Example usage is as follows:
// PadOpDynamicModel m(input_shape, paddings_shape);
// m.SetInput(input_data);
// m.SetPaddings(paddings_data);
// m.Invoke();
-class PadOpDynamicModel : public PadOpModel {
+class PadOpDynamicModel : public PadOpModel<float> {
public:
PadOpDynamicModel(const TensorData& input,
std::initializer_list<int> paddings_shape,
const TensorData& output) {
input_ = AddInput(input);
paddings_ = AddInput(TensorType_INT32);
+ constant_values_ = AddNullInput();
output_ = AddOutput(output);
SetBuiltinOp(BuiltinOperator_PAD, BuiltinOptions_PadOptions,
@@ -237,6 +329,272 @@ TEST_F(QuantizedPadOpTest, AdvancedDynamicTest) {
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 7, 1}));
}
+TEST(PadV2OpTest, TooManyDimensions) {
+ EXPECT_DEATH(PadV2OpConstModel<float>(
+ {TensorType_FLOAT32, {1, 2, 3, 4, 5, 6, 7, 8, 9}}, {9, 2},
+ {1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9}, 0.0,
+ {TensorType_FLOAT32}),
+ "dims != 4");
+}
+
+TEST(PadV2OpTest, UnequalDimensions) {
+ EXPECT_DEATH(
+ PadV2OpConstModel<float>({TensorType_FLOAT32, {1, 1, 2, 1}}, {3, 2},
+ {1, 1, 2, 2, 3, 3}, 0.0, {TensorType_FLOAT32}),
+ "3 != 4");
+}
+
+TEST(PadV2OpTest, InvalidPadValue) {
+ EXPECT_DEATH(PadV2OpConstModel<float>({TensorType_FLOAT32, {1, 1, 2, 1}},
+ {4, 2}, {0, 0, 1, -1, 2, -1, 0, 0}, 0.0,
+ {TensorType_FLOAT32}),
+ "Pad value has to be greater than equal to 0.");
+}
+
+TEST(PadV2OpTest, SimpleConstTest) {
+  // Padding is specified as four (before, after) pairs, one per dimension
+  // (i.e. {{0, 0}, {1, 1}, {1, 1}, {0, 0}}).
+ PadV2OpConstModel<float> m({TensorType_FLOAT32, {1, 2, 2, 1}}, {4, 2},
+ {0, 0, 1, 1, 1, 1, 0, 0}, 0.0,
+ {TensorType_FLOAT32});
+ m.SetInput({1, 2, 3, 4});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 0, 0, 0, 0, 1, 2, 0, 0, 3, 4,
+ 0, 0, 0, 0, 0}));
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
+}
+
+TEST(PadV2OpTest, SimpleConstFloat32ValuedTest) {
+  // Padding is specified as four (before, after) pairs, one per dimension
+  // (i.e. {{0, 0}, {1, 1}, {1, 1}, {0, 0}}).
+ PadV2OpConstModel<float> m({TensorType_FLOAT32, {1, 2, 2, 1}}, {4, 2},
+ {0, 0, 1, 1, 1, 1, 0, 0}, 5, {TensorType_FLOAT32});
+ m.SetInput({1, 2, 3, 4});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({5, 5, 5, 5, 5, 1, 2, 5, 5, 3, 4,
+ 5, 5, 5, 5, 5}));
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
+}
+
+TEST(PadV2OpTest, Simple4DConstFloat32ValuedTest) {
+  // Padding is specified as four (before, after) pairs, one per dimension
+  // (here {{0, 1}, {0, 0}, {0, 0}, {0, 1}}).
+ PadV2OpConstModel<float> m({TensorType_FLOAT32, {1, 1, 2, 1}}, {4, 2},
+ {0, 1, 0, 0, 0, 0, 0, 1}, 5, {TensorType_FLOAT32});
+ m.SetInput({3, 3});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 5, 3, 5, 5, 5, 5, 5}));
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 1, 2, 2}));
+}
+
+TEST(PadV2OpTest, SimpleConstInt32ValuedTest) {
+  // Padding is specified as four (before, after) pairs, one per dimension
+  // (i.e. {{0, 0}, {1, 1}, {1, 1}, {0, 0}}).
+ PadV2OpConstModel<int32_t> m({TensorType_INT32, {1, 2, 2, 1}}, {4, 2},
+ {0, 0, 1, 1, 1, 1, 0, 0}, 5, {TensorType_INT32});
+ m.SetInput({1, 2, 3, 4});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({5, 5, 5, 5, 5, 1, 2, 5, 5, 3, 4,
+ 5, 5, 5, 5, 5}));
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
+}
+
+TEST(PadV2OpTest, SimpleDynamicTest) {
+ PadV2OpDynamicModel<float> m({TensorType_FLOAT32, {1, 2, 2, 1}}, {4, 2}, 0.0,
+ {TensorType_FLOAT32});
+ m.SetInput({1, 2, 3, 4});
+ m.SetPaddings({0, 0, 1, 1, 1, 1, 0, 0});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 0, 0, 0, 0, 1, 2, 0, 0, 3, 4,
+ 0, 0, 0, 0, 0}));
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
+}
+
+TEST(PadV2OpTest, SimpleDynamicValuedTest) {
+ PadV2OpDynamicModel<float> m({TensorType_FLOAT32, {1, 2, 2, 1}}, {4, 2}, 5,
+ {TensorType_FLOAT32});
+ m.SetInput({1, 2, 3, 4});
+ m.SetPaddings({0, 0, 1, 1, 1, 1, 0, 0});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({5, 5, 5, 5, 5, 1, 2, 5, 5, 3, 4,
+ 5, 5, 5, 5, 5}));
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
+}
+
+TEST(PadV2OpTest, AdvancedConstTest) {
+ PadV2OpConstModel<float> m({TensorType_FLOAT32, {1, 2, 3, 1}}, {4, 2},
+ {0, 0, 0, 2, 1, 3, 0, 0}, 0, {TensorType_FLOAT32});
+ m.SetInput({1, 2, 3, 4, 5, 6});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(),
+ ElementsAreArray({0, 1, 2, 3, 0, 0, 0, 0, 4, 5, 6, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}));
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 7, 1}));
+}
+
+TEST(PadV2OpTest, AdvancedDynamicTest) {
+ PadV2OpDynamicModel<float> m({TensorType_FLOAT32, {1, 2, 3, 1}}, {4, 2}, 0,
+ {TensorType_FLOAT32});
+ m.SetInput({1, 2, 3, 4, 5, 6});
+ m.SetPaddings({0, 0, 0, 2, 1, 3, 0, 0});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(),
+ ElementsAreArray({0, 1, 2, 3, 0, 0, 0, 0, 4, 5, 6, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}));
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 7, 1}));
+}
+
+class QuantizedPadV2OpTest : public ::testing::Test {
+ protected:
+ std::vector<Matcher<float>> DequantizedArrayNear(
+ const std::vector<float>& values, const float min, const float max) {
+ const float quantization_tolerance = (max - min) / 255.0;
+ return ArrayFloatNear(values, quantization_tolerance);
+ }
+};
+
+TEST_F(QuantizedPadV2OpTest, ZeroNotInQuantizationRange) {
+ // The test_util and actual quantization code currently ensure that the range
+ // must include zero, but if that ever changes, this test will catch it.
+ EXPECT_DEATH(
+ PadV2OpConstModel<float> m({TensorType_UINT8, {1, 2, 2, 1}, 1.0, 2.0},
+ {4, 2}, {0, 0, 1, 1, 1, 1, 0, 0}, 0,
+ {TensorType_UINT8, {}, 1.0, 2.0}),
+ ".*Check failed: f_min <= 0.*");
+}
+
+TEST_F(QuantizedPadV2OpTest, SimpleConstTest) {
+  // Padding is specified as four (before, after) pairs, one per dimension
+  // (i.e. {{0, 0}, {1, 1}, {1, 1}, {0, 0}}).
+ PadV2OpConstModel<uint8_t> m({TensorType_UINT8, {1, 2, 2, 1}, -1.0, 1.0},
+ {4, 2}, {0, 0, 1, 1, 1, 1, 0, 0},
+ {TensorType_UINT8, {1}, -1.0, 1.0},
+ {TensorType_UINT8, {}, -1.0, 1.0});
+ m.SetQuantizedInput({-0.8, 0.2, 0.9, 0.7});
+ m.SetQuantizedPadValue(0);
+ m.Invoke();
+ EXPECT_THAT(m.GetDequantizedOutput(),
+ ElementsAreArray(DequantizedArrayNear(
+ {0, 0, 0, 0, 0, -0.8, 0.2, 0, 0, 0.9, 0.7, 0, 0, 0, 0, 0},
+ -1.0, 1.0)));
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
+}
+
+TEST_F(QuantizedPadV2OpTest, SimpleDynamicTest) {
+ PadV2OpDynamicModel<uint8_t> m({TensorType_UINT8, {1, 2, 2, 1}, -1.0, 1.0},
+ {4, 2}, {TensorType_UINT8, {1}, -1.0, 1.0},
+ {TensorType_UINT8, {}, -1.0, 1.0});
+ m.SetQuantizedInput({-0.8, 0.2, 0.9, 0.7});
+ m.SetQuantizedPadValue(0);
+ m.SetPaddings({0, 0, 1, 1, 1, 1, 0, 0});
+ m.Invoke();
+ EXPECT_THAT(m.GetDequantizedOutput(),
+ ElementsAreArray(DequantizedArrayNear(
+ {0, 0, 0, 0, 0, -0.8, 0.2, 0, 0, 0.9, 0.7, 0, 0, 0, 0, 0},
+ -1.0, 1.0)));
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
+}
+
+TEST_F(QuantizedPadV2OpTest, AdvancedConstTest) {
+ PadV2OpConstModel<uint8_t> m({TensorType_UINT8, {1, 2, 3, 1}, -1.0, 1.0},
+ {4, 2}, {0, 0, 0, 2, 1, 3, 0, 0},
+ {TensorType_UINT8, {1}, -1.0, 1.0},
+ {TensorType_UINT8, {}, -1.0, 1.0});
+ m.SetQuantizedInput({-0.8, 0.2, 0.9, 0.7, 0.1, -0.3});
+ m.SetQuantizedPadValue(0);
+ m.Invoke();
+ EXPECT_THAT(m.GetDequantizedOutput(),
+ ElementsAreArray(DequantizedArrayNear(
+ {0, -0.8, 0.2, 0.9, 0, 0, 0, 0, 0.7, 0.1, -0.3, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
+ -1.0, 1.0)));
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 7, 1}));
+}
+
+TEST_F(QuantizedPadV2OpTest, AdvancedDynamicTest) {
+ PadV2OpDynamicModel<uint8_t> m({TensorType_UINT8, {1, 2, 3, 1}, -1.0, 1.0},
+ {4, 2}, {TensorType_UINT8, {1}, -1.0, 1.0},
+ {TensorType_UINT8, {}, -1.0, 1.0});
+ m.SetQuantizedInput({-0.8, 0.2, 0.9, 0.7, 0.1, -0.3});
+ m.SetQuantizedPadValue(0);
+ m.SetPaddings({0, 0, 0, 2, 1, 3, 0, 0});
+ m.Invoke();
+ EXPECT_THAT(m.GetDequantizedOutput(),
+ ElementsAreArray(DequantizedArrayNear(
+ {0, -0.8, 0.2, 0.9, 0, 0, 0, 0, 0.7, 0.1, -0.3, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
+ -1.0, 1.0)));
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 7, 1}));
+}
+
+TEST_F(QuantizedPadV2OpTest, SimpleConstValuedTest) {
+  // Padding is specified as four (before, after) pairs, one per dimension
+  // (i.e. {{0, 0}, {1, 1}, {1, 1}, {0, 0}}).
+ PadV2OpConstModel<uint8_t> m({TensorType_UINT8, {1, 2, 2, 1}, -1.0, 1.0},
+ {4, 2}, {0, 0, 1, 1, 1, 1, 0, 0},
+ {TensorType_UINT8, {1}, -1.0, 1.0},
+ {TensorType_UINT8, {}, -1.0, 1.0});
+ m.SetQuantizedInput({-0.8, 0.2, 0.9, 0.7});
+ m.SetQuantizedPadValue(-0.5);
+ m.Invoke();
+ EXPECT_THAT(m.GetDequantizedOutput(),
+ ElementsAreArray(DequantizedArrayNear(
+ {-0.5, -0.5, -0.5, -0.5, -0.5, -0.8, 0.2, -0.5, -0.5, 0.9,
+ 0.7, -0.5, -0.5, -0.5, -0.5, -0.5},
+ -1.0, 1.0)));
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
+}
+
+TEST_F(QuantizedPadV2OpTest, SimpleDynamicValuedTest) {
+ PadV2OpDynamicModel<uint8_t> m({TensorType_UINT8, {1, 2, 2, 1}, -1.0, 1.0},
+ {4, 2}, {TensorType_UINT8, {1}, -1.0, 1.0},
+ {TensorType_UINT8, {}, -1.0, 1.0});
+ m.SetQuantizedInput({-0.8, 0.2, 0.9, 0.7});
+ m.SetQuantizedPadValue(-0.5);
+ m.SetPaddings({0, 0, 1, 1, 1, 1, 0, 0});
+ m.Invoke();
+ EXPECT_THAT(m.GetDequantizedOutput(),
+ ElementsAreArray(DequantizedArrayNear(
+ {-0.5, -0.5, -0.5, -0.5, -0.5, -0.8, 0.2, -0.5, -0.5, 0.9,
+ 0.7, -0.5, -0.5, -0.5, -0.5, -0.5},
+ -1.0, 1.0)));
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
+}
+
+TEST_F(QuantizedPadV2OpTest, AdvancedConstValuedTest) {
+ PadV2OpConstModel<uint8_t> m({TensorType_UINT8, {1, 2, 3, 1}, -1.0, 1.0},
+ {4, 2}, {0, 0, 0, 2, 1, 3, 0, 0},
+ {TensorType_UINT8, {1}, -1.0, 1.0},
+ {TensorType_UINT8, {}, -1.0, 1.0});
+ m.SetQuantizedInput({-0.8, 0.2, 0.9, 0.7, 0.1, -0.3});
+ m.SetQuantizedPadValue(-0.5);
+ m.Invoke();
+ EXPECT_THAT(m.GetDequantizedOutput(),
+ ElementsAreArray(DequantizedArrayNear(
+ {-0.5, -0.8, 0.2, 0.9, -0.5, -0.5, -0.5, -0.5, 0.7, 0.1,
+ -0.3, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5,
+ -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5},
+ -1.0, 1.0)));
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 7, 1}));
+}
+
+TEST_F(QuantizedPadV2OpTest, AdvancedDynamicValuedTest) {
+ PadV2OpDynamicModel<uint8_t> m({TensorType_UINT8, {1, 2, 3, 1}, -1.0, 1.0},
+ {4, 2}, {TensorType_UINT8, {1}, -1.0, 1.0},
+ {TensorType_UINT8, {}, -1.0, 1.0});
+ m.SetQuantizedInput({-0.8, 0.2, 0.9, 0.7, 0.1, -0.3});
+ m.SetQuantizedPadValue(-0.5);
+ m.SetPaddings({0, 0, 0, 2, 1, 3, 0, 0});
+ m.Invoke();
+ EXPECT_THAT(m.GetDequantizedOutput(),
+ ElementsAreArray(DequantizedArrayNear(
+ {-0.5, -0.8, 0.2, 0.9, -0.5, -0.5, -0.5, -0.5, 0.7, 0.1,
+ -0.3, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5,
+ -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5},
+ -1.0, 1.0)));
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 7, 1}));
+}
+
} // namespace
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc
index 29ea718a96..5df35aac62 100644
--- a/tensorflow/contrib/lite/kernels/register.cc
+++ b/tensorflow/contrib/lite/kernels/register.cc
@@ -60,6 +60,7 @@ TfLiteRegistration* Register_LSTM();
TfLiteRegistration* Register_BIDIRECTIONAL_SEQUENCE_LSTM();
TfLiteRegistration* Register_UNIDIRECTIONAL_SEQUENCE_LSTM();
TfLiteRegistration* Register_PAD();
+TfLiteRegistration* Register_PADV2();
TfLiteRegistration* Register_RESHAPE();
TfLiteRegistration* Register_RESIZE_BILINEAR();
TfLiteRegistration* Register_SKIP_GRAM();
@@ -79,9 +80,13 @@ TfLiteRegistration* Register_PRELU();
TfLiteRegistration* Register_MAXIMUM();
TfLiteRegistration* Register_MINIMUM();
TfLiteRegistration* Register_ARG_MAX();
+TfLiteRegistration* Register_GREATER();
+TfLiteRegistration* Register_GREATER_EQUAL();
TfLiteRegistration* Register_LESS();
+TfLiteRegistration* Register_LESS_EQUAL();
TfLiteRegistration* Register_FLOOR();
TfLiteRegistration* Register_NEG();
+TfLiteRegistration* Register_SELECT();
BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_RELU, Register_RELU());
@@ -121,6 +126,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM,
Register_UNIDIRECTIONAL_SEQUENCE_LSTM());
AddBuiltin(BuiltinOperator_PAD, Register_PAD());
+ AddBuiltin(BuiltinOperator_PADV2, Register_PADV2());
AddBuiltin(BuiltinOperator_RESHAPE, Register_RESHAPE());
AddBuiltin(BuiltinOperator_RESIZE_BILINEAR, Register_RESIZE_BILINEAR());
AddBuiltin(BuiltinOperator_SKIP_GRAM, Register_SKIP_GRAM());
@@ -142,9 +148,13 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_MAXIMUM, Register_MAXIMUM());
AddBuiltin(BuiltinOperator_MINIMUM, Register_MINIMUM());
AddBuiltin(BuiltinOperator_ARG_MAX, Register_ARG_MAX());
+ AddBuiltin(BuiltinOperator_GREATER, Register_GREATER());
+ AddBuiltin(BuiltinOperator_GREATER_EQUAL, Register_GREATER_EQUAL());
AddBuiltin(BuiltinOperator_LESS, Register_LESS());
+ AddBuiltin(BuiltinOperator_LESS_EQUAL, Register_LESS_EQUAL());
AddBuiltin(BuiltinOperator_FLOOR, Register_FLOOR());
AddBuiltin(BuiltinOperator_NEG, Register_NEG());
+ AddBuiltin(BuiltinOperator_SELECT, Register_SELECT());
// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
// custom ops aren't always included by default.
diff --git a/tensorflow/contrib/lite/kernels/select.cc b/tensorflow/contrib/lite/kernels/select.cc
new file mode 100644
index 0000000000..029ad9a709
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/select.cc
@@ -0,0 +1,125 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+#include "tensorflow/contrib/lite/string_util.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace select {
+
+constexpr int kInputTensorCondition = 0;
+constexpr int kInputTensorX = 1;
+constexpr int kInputTensorY = 2;
+constexpr int kOutputTensor = 0;
+
+TfLiteStatus SelectPrepare(TfLiteContext* context, TfLiteNode* node) {
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+ TfLiteTensor* input_condition =
+ GetInput(context, node, kInputTensorCondition);
+ TfLiteTensor* input_x = GetInput(context, node, kInputTensorX);
+ TfLiteTensor* input_y = GetInput(context, node, kInputTensorY);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+  // The condition input must be bool.
+ TF_LITE_ENSURE(context, input_condition->type == kTfLiteBool);
+
+  // The value inputs must have the same type and shape.
+ TF_LITE_ENSURE_EQ(context, input_x->type, input_y->type);
+ TF_LITE_ENSURE(context, HaveSameShapes(input_x, input_y));
+ output->type = input_x->type;
+
+  // input_condition must either have the same shape as input_x, or be rank 1
+  // and match input_x over the first dimension.
+ bool same_shape = HaveSameShapes(input_condition, input_x);
+ if (!same_shape && NumDimensions(input_condition) == 1) {
+ same_shape =
+ SizeOfDimension(input_condition, 0) == SizeOfDimension(input_x, 0);
+ }
+
+ TF_LITE_ENSURE(context, same_shape);
+
+ TfLiteIntArray* output_size = TfLiteIntArrayCopy(input_x->dims);
+ return context->ResizeTensor(context, output, output_size);
+}
+
+TfLiteStatus SelectEval(TfLiteContext* context, TfLiteNode* node) {
+ TfLiteTensor* input_condition =
+ GetInput(context, node, kInputTensorCondition);
+ TfLiteTensor* input_x = GetInput(context, node, kInputTensorX);
+ TfLiteTensor* input_y = GetInput(context, node, kInputTensorY);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ bool is_rank_one = !HaveSameShapes(input_condition, input_x);
+
+#define TF_LITE_SELECT(type, op) \
+ reference_ops::op(GetTensorData<bool>(input_condition), \
+ GetTensorDims(input_condition), \
+ GetTensorData<type>(input_x), GetTensorDims(input_x), \
+ GetTensorData<type>(input_y), GetTensorDims(input_y), \
+ GetTensorData<type>(output), GetTensorDims(output));
+
+#define TF_LITE_SWITCH(type, op) \
+ switch (type) { \
+ case kTfLiteBool: \
+ TF_LITE_SELECT(bool, op); \
+ break; \
+ case kTfLiteFloat32: \
+ TF_LITE_SELECT(float, op); \
+ break; \
+ case kTfLiteUInt8: \
+ TF_LITE_SELECT(uint8_t, op); \
+ break; \
+ case kTfLiteInt32: \
+ TF_LITE_SELECT(int32_t, op); \
+ break; \
+ case kTfLiteInt64: \
+ TF_LITE_SELECT(int64_t, op); \
+ break; \
+ default: \
+      context->ReportError(context,                                      \
+                           "Select only supports bool, float, and int types."); \
+ return kTfLiteError; \
+ }
+
+ if (is_rank_one) {
+ TF_LITE_SWITCH(input_x->type, RankOneSelect);
+ } else {
+ TF_LITE_SWITCH(input_x->type, Select);
+ }
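+  // Note: TF_LITE_SWITCH dispatches on the value tensor type; the condition
+  // tensor is always read as bool, matching the check in SelectPrepare.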
+
+#undef TF_LITE_SELECT
+#undef TF_LITE_SWITCH
+ return kTfLiteOk;
+}
+
+} // namespace select
+
+TfLiteRegistration* Register_SELECT() {
+ static TfLiteRegistration r = {nullptr, nullptr, select::SelectPrepare,
+ select::SelectEval};
+ return &r;
+}
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/select_test.cc b/tensorflow/contrib/lite/kernels/select_test.cc
new file mode 100644
index 0000000000..cfe24a5fc9
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/select_test.cc
@@ -0,0 +1,143 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+class SelectOpModel : public SingleOpModel {
+ public:
+ SelectOpModel(std::initializer_list<int> input1_shape,
+ std::initializer_list<int> input2_shape,
+ std::initializer_list<int> input3_shape,
+ TensorType input_type) {
+ input1_ = AddInput(TensorType_BOOL);
+ input2_ = AddInput(input_type);
+ input3_ = AddInput(input_type);
+ output_ = AddOutput(input_type);
+ SetBuiltinOp(BuiltinOperator_SELECT, BuiltinOptions_SelectOptions,
+ CreateSelectOptions(builder_).Union());
+ BuildInterpreter({input1_shape, input2_shape, input3_shape});
+ }
+
+ int input1() { return input1_; }
+ int input2() { return input2_; }
+ int input3() { return input3_; }
+
+ template <typename T>
+ std::vector<T> GetOutput() {
+ return ExtractVector<T>(output_);
+ }
+
+ std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
+
+ private:
+ int input1_;
+ int input2_;
+ int input3_;
+ int output_;
+};
+
+TEST(SelectOpTest, SelectBool) {
+ SelectOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, {1, 1, 1, 4},
+ TensorType_BOOL);
+
+ model.PopulateTensor<bool>(model.input1(), {true, false, true, false});
+ model.PopulateTensor<bool>(model.input2(), {false, false, false, false});
+ model.PopulateTensor<bool>(model.input3(), {true, true, true, true});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput<bool>(),
+ ElementsAreArray({false, true, false, true}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+}
+
+TEST(SelectOpTest, SelectFloat) {
+ SelectOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, {1, 1, 1, 4},
+ TensorType_FLOAT32);
+
+ model.PopulateTensor<bool>(model.input1(), {true, false, true, false});
+ model.PopulateTensor<float>(model.input2(), {0.1, 0.2, 0.3, 0.4});
+ model.PopulateTensor<float>(model.input3(), {0.5, 0.6, 0.7, 0.8});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput<float>(), ElementsAreArray({0.1, 0.6, 0.3, 0.8}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+}
+
+TEST(SelectOpTest, SelectUInt8) {
+ SelectOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, {1, 1, 1, 4},
+ TensorType_UINT8);
+
+ model.PopulateTensor<bool>(model.input1(), {false, true, false, false});
+ model.PopulateTensor<uint8>(model.input2(), {1, 2, 3, 4});
+ model.PopulateTensor<uint8>(model.input3(), {5, 6, 7, 8});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput<uint8>(), ElementsAreArray({5, 2, 7, 8}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+}
+
+TEST(SelectOpTest, SelectInt32) {
+ SelectOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, {1, 1, 1, 4},
+ TensorType_INT32);
+
+ model.PopulateTensor<bool>(model.input1(), {false, true, false, false});
+ model.PopulateTensor<int32>(model.input2(), {1, 2, 3, 4});
+ model.PopulateTensor<int32>(model.input3(), {5, 6, 7, 8});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput<int32>(), ElementsAreArray({5, 2, 7, 8}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+}
+
+TEST(SelectOpTest, RankOneSelectInt32) {
+ SelectOpModel model({2}, {2, 1, 2, 1}, {2, 1, 2, 1}, TensorType_INT32);
+
+ model.PopulateTensor<bool>(model.input1(), {false, true});
+ model.PopulateTensor<int32>(model.input2(), {1, 2, 3, 4});
+ model.PopulateTensor<int32>(model.input3(), {5, 6, 7, 8});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput<int32>(), ElementsAreArray({5, 6, 3, 4}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 1, 2, 1}));
+}
+
+TEST(SelectOpTest, RankZeroSelectInt32) {
+ SelectOpModel model({1}, {1, 2, 2, 1}, {1, 2, 2, 1}, TensorType_INT32);
+
+ model.PopulateTensor<bool>(model.input1(), {false});
+ model.PopulateTensor<int32>(model.input2(), {1, 2, 3, 4});
+ model.PopulateTensor<int32>(model.input3(), {5, 6, 7, 8});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput<int32>(), ElementsAreArray({5, 6, 7, 8}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 2, 2, 1}));
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/test_util.cc b/tensorflow/contrib/lite/kernels/test_util.cc
index 0bb28b50b2..5a6c85e97e 100644
--- a/tensorflow/contrib/lite/kernels/test_util.cc
+++ b/tensorflow/contrib/lite/kernels/test_util.cc
@@ -22,23 +22,6 @@ namespace tflite {
using ::testing::FloatNear;
using ::testing::Matcher;
-namespace {
-template <typename T>
-std::pair<float, int32_t> QuantizationParams(float f_min, float f_max) {
- // These are required by many quantized operations.
- CHECK_LE(f_min, 0);
- CHECK_GE(f_max, 0);
- T q_min = std::numeric_limits<T>::min();
- T q_max = std::numeric_limits<T>::max();
- float range = q_max - q_min;
- float scale = (f_max - f_min) / range;
- int32_t zero_point = std::min(
- q_max,
- std::max(q_min, static_cast<T>(std::round(q_min - f_min / scale))));
- return {scale, zero_point};
-}
-} // namespace
-
std::vector<Matcher<float>> ArrayFloatNear(const std::vector<float>& values,
float max_abs_error) {
std::vector<Matcher<float>> matchers;
@@ -49,69 +32,8 @@ std::vector<Matcher<float>> ArrayFloatNear(const std::vector<float>& values,
return matchers;
}
-int SingleOpModel::AddTensor(TensorData t, std::initializer_list<int> data) {
- int id = tensors_.size();
-
- // This is slightly different depending on whether we are adding a
- // quantized or a regular tensor.
- bool is_quantized = (t.min != 0 || t.max != 0 || t.scale != 0);
-
- flatbuffers::Offset<QuantizationParameters> q_params = 0;
-
- if (is_quantized) {
- if (t.min != 0 || t.max != 0) {
- if (t.type == TensorType_UINT8) {
- std::tie(t.scale, t.zero_point) =
- QuantizationParams<uint8_t>(t.min, t.max);
- } else if (t.type == TensorType_INT32) {
- std::tie(t.scale, t.zero_point) =
- QuantizationParams<int32_t>(t.min, t.max);
- } else {
- LOG(FATAL) << "No support for the requested quantized type";
- }
- t.min = 0;
- t.max = 0;
- }
-
- q_params = CreateQuantizationParameters(
- builder_, /*min=*/0, /*max=*/0, builder_.CreateVector<float>({t.scale}),
- builder_.CreateVector<int64_t>({t.zero_point}));
- }
-
- int buffer_id = 0;
- if (data.size()) {
- // Initialize buffers list with empty buffer to allow for non-const tensors.
- if (buffers_.empty()) {
- buffers_.push_back(CreateBuffer(builder_, builder_.CreateVector({})));
- }
-
- // Add data as a Buffer to buffers list.
- buffer_id = buffers_.size();
- auto data_buffer =
- builder_.CreateVector(reinterpret_cast<const uint8_t*>(data.begin()),
- sizeof(int) * data.size());
- buffers_.push_back(CreateBuffer(builder_, data_buffer));
- }
-
- tensors_.push_back(CreateTensor(builder_, builder_.CreateVector<int>(t.shape),
- t.type, /*buffer=*/buffer_id,
- /*name=*/0, q_params));
-
- tensor_data_[id] = t;
-
- return id;
-}
-
int SingleOpModel::AddInput(const TensorData& t) {
- int id = AddTensor(t, {});
- inputs_.push_back(id);
- return id;
-}
-
-int SingleOpModel::AddConstInput(TensorType type,
- std::initializer_list<int> data,
- std::initializer_list<int> shape) {
- int id = AddTensor(TensorData{type, shape}, data);
+ int id = AddTensor<float>(t, {});
inputs_.push_back(id);
return id;
}
@@ -123,7 +45,7 @@ int SingleOpModel::AddNullInput() {
}
int SingleOpModel::AddOutput(const TensorData& t) {
- int id = AddTensor(t, {});
+ int id = AddTensor<float>(t, {});
outputs_.push_back(id);
return id;
}
diff --git a/tensorflow/contrib/lite/kernels/test_util.h b/tensorflow/contrib/lite/kernels/test_util.h
index 6fb6fe27eb..6a9fdf1112 100644
--- a/tensorflow/contrib/lite/kernels/test_util.h
+++ b/tensorflow/contrib/lite/kernels/test_util.h
@@ -116,9 +116,14 @@ class SingleOpModel {
int AddInput(TensorType type) { return AddInput(TensorData{type}); }
int AddInput(const TensorData& t);
- // Add a Tensor containing const data and return the tensor id.
- int AddConstInput(TensorType type, std::initializer_list<int> data,
- std::initializer_list<int> shape);
+  // Add a Tensor containing const data of type T and return the tensor id.
+ template <typename T>
+ int AddConstInput(TensorType type, std::initializer_list<T> data,
+ std::initializer_list<int> shape) {
+ int id = AddTensor(TensorData{type, shape}, data);
+ inputs_.push_back(id);
+ return id;
+ }
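+  // For example (sketch), a 4x2 paddings tensor can be added with
+  //   AddConstInput<int>(TensorType_INT32, {0, 0, 1, 1, 1, 1, 0, 0}, {4, 2});
+  // and a scalar pad value with
+  //   AddConstInput<float>(TensorType_FLOAT32, {5.0f}, {1});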
// Add a null input tensor (optional input) and return kOptionalTensor.
int AddNullInput();
@@ -224,7 +229,79 @@ class SingleOpModel {
std::unique_ptr<OpResolver> resolver_;
private:
- int AddTensor(TensorData t, std::initializer_list<int> data);
+ // TODO(gavinbelson): sync this method with
+ // //tensorflow/contrib/lite/kernels/internal/quantization_util.h?l=31
+ template <typename T>
+ std::pair<float, int32_t> QuantizationParams(float f_min, float f_max) {
+ // These are required by many quantized operations.
+ CHECK_LE(f_min, 0);
+ CHECK_GE(f_max, 0);
+ T q_min = std::numeric_limits<T>::min();
+ T q_max = std::numeric_limits<T>::max();
+ float range = q_max - q_min;
+ float scale = (f_max - f_min) / range;
+ int32_t zero_point = std::min(
+ q_max,
+ std::max(q_min, static_cast<T>(std::round(q_min - f_min / scale))));
+ return {scale, zero_point};
+ }
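+  // Worked example (assuming uint8, f_min = -1.0, f_max = 3.0): range = 255,
+  // scale = 4.0 / 255 ~ 0.0157, and zero_point = round(0 + 1.0 / scale) =
+  // round(63.75) = 64, so real 0.0 maps to the quantized value 64.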
+
+ template <typename T>
+ int AddTensor(TensorData t, std::initializer_list<T> data) {
+ int id = tensors_.size();
+
+ // This is slightly different depending on whether we are adding a
+ // quantized or a regular tensor.
+ bool is_quantized = (t.min != 0 || t.max != 0 || t.scale != 0);
+
+ flatbuffers::Offset<QuantizationParameters> q_params = 0;
+
+ if (is_quantized) {
+ if (t.min != 0 || t.max != 0) {
+ if (t.type == TensorType_UINT8) {
+ std::tie(t.scale, t.zero_point) =
+ QuantizationParams<uint8_t>(t.min, t.max);
+ } else if (t.type == TensorType_INT32) {
+ std::tie(t.scale, t.zero_point) =
+ QuantizationParams<int32_t>(t.min, t.max);
+ } else {
+ LOG(FATAL) << "No support for the requested quantized type";
+ }
+ t.min = 0;
+ t.max = 0;
+ }
+
+ q_params = CreateQuantizationParameters(
+ builder_, /*min=*/0, /*max=*/0,
+ builder_.CreateVector<float>({t.scale}),
+ builder_.CreateVector<int64_t>({t.zero_point}));
+ }
+
+ int buffer_id = 0;
+ if (data.size()) {
+ // Initialize buffers list with empty buffer to allow for non-const
+ // tensors.
+ if (buffers_.empty()) {
+ buffers_.push_back(CreateBuffer(builder_, builder_.CreateVector({})));
+ }
+
+ // Add data as a Buffer to buffers list.
+ buffer_id = buffers_.size();
+ auto data_buffer =
+ builder_.CreateVector(reinterpret_cast<const uint8_t*>(data.begin()),
+ sizeof(T) * data.size());
+ buffers_.push_back(CreateBuffer(builder_, data_buffer));
+ }
+
+ tensors_.push_back(CreateTensor(builder_,
+ builder_.CreateVector<int>(t.shape), t.type,
+ /*buffer=*/buffer_id,
+ /*name=*/0, q_params));
+
+ tensor_data_[id] = t;
+
+ return id;
+ }
std::map<int, TensorData> tensor_data_;
std::vector<int32_t> inputs_;
diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc
index 590f042e21..e89036ce73 100644
--- a/tensorflow/contrib/lite/model.cc
+++ b/tensorflow/contrib/lite/model.cc
@@ -569,6 +569,9 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
case BuiltinOperator_PAD: {
break;
}
+ case BuiltinOperator_PADV2: {
+ break;
+ }
case BuiltinOperator_RESHAPE: {
auto* params = MallocPOD<TfLiteReshapeParams>();
if (auto* schema_params = op->builtin_options_as_ReshapeOptions()) {
@@ -669,7 +672,11 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
*builtin_data = reinterpret_cast<void*>(params);
break;
}
- case BuiltinOperator_LESS: {
+ case BuiltinOperator_GREATER:
+ case BuiltinOperator_GREATER_EQUAL:
+ case BuiltinOperator_LESS:
+ case BuiltinOperator_LESS_EQUAL:
+ case BuiltinOperator_SELECT: {
break;
}
case BuiltinOperator_DELEGATE: {
diff --git a/tensorflow/contrib/lite/nnapi_delegate.cc b/tensorflow/contrib/lite/nnapi_delegate.cc
index 6eac18c4f5..eb451397bd 100644
--- a/tensorflow/contrib/lite/nnapi_delegate.cc
+++ b/tensorflow/contrib/lite/nnapi_delegate.cc
@@ -61,6 +61,10 @@ NNAPIAllocation::~NNAPIAllocation() {
}
NNAPIDelegate::~NNAPIDelegate() {
+ if (nn_compiled_model_) {
+ ANeuralNetworksCompilation_free(nn_compiled_model_);
+ nn_compiled_model_ = nullptr;
+ }
if (nn_model_) {
ANeuralNetworksModel_free(nn_model_);
nn_model_ = nullptr;
@@ -347,6 +351,7 @@ void AddOpsAndParams(tflite::Interpreter* interpreter,
case tflite::BuiltinOperator_L2_NORMALIZATION:
case tflite::BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION:
case tflite::BuiltinOperator_PAD:
+ case tflite::BuiltinOperator_PADV2:
case tflite::BuiltinOperator_RESIZE_BILINEAR:
case tflite::BuiltinOperator_CALL:
case tflite::BuiltinOperator_SKIP_GRAM:
@@ -371,8 +376,12 @@ void AddOpsAndParams(tflite::Interpreter* interpreter,
case tflite::BuiltinOperator_MAXIMUM:
case tflite::BuiltinOperator_MINIMUM:
case tflite::BuiltinOperator_ARG_MAX:
+ case tflite::BuiltinOperator_GREATER:
+ case tflite::BuiltinOperator_GREATER_EQUAL:
case tflite::BuiltinOperator_LESS:
+ case tflite::BuiltinOperator_LESS_EQUAL:
case tflite::BuiltinOperator_NEG:
+ case tflite::BuiltinOperator_SELECT:
FATAL("Op code %d is currently not delegated to NNAPI", builtin);
nn_op_type = -1; // set to invalid
break;
diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs
index 5d89f7be62..2f5c39e7d7 100644
--- a/tensorflow/contrib/lite/schema/schema.fbs
+++ b/tensorflow/contrib/lite/schema/schema.fbs
@@ -137,6 +137,11 @@ enum BuiltinOperator : byte {
MINIMUM = 57,
LESS = 58,
NEG = 59,
+ PADV2 = 60,
+ GREATER = 61,
+ GREATER_EQUAL = 62,
+ LESS_EQUAL = 63,
+ SELECT = 64,
}
// Options for the builtin operators.
@@ -183,6 +188,11 @@ union BuiltinOptions {
ArgMaxOptions,
LessOptions,
NegOptions,
+ PadV2Options,
+ GreaterOptions,
+ GreaterEqualOptions,
+ LessEqualOptions,
+ SelectOptions,
}
enum Padding : byte { SAME, VALID }
@@ -316,6 +326,9 @@ table CallOptions {
table PadOptions {
}
+table PadV2Options {
+}
+
table ReshapeOptions {
new_shape:[int];
}
@@ -405,12 +418,24 @@ table ArgMaxOptions {
output_type : TensorType;
}
+table GreaterOptions {
+}
+
+table GreaterEqualOptions {
+}
+
table LessOptions {
}
+table LessEqualOptions {
+}
+
table NegOptions {
}
+table SelectOptions {
+}
+
// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
// builtin, or a string if the operator is custom.
table OperatorCode {
diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h
index c172f77aa9..a2f0c8cdd2 100755..100644
--- a/tensorflow/contrib/lite/schema/schema_generated.h
+++ b/tensorflow/contrib/lite/schema/schema_generated.h
@@ -88,6 +88,9 @@ struct CallOptionsT;
struct PadOptions;
struct PadOptionsT;
+struct PadV2Options;
+struct PadV2OptionsT;
+
struct ReshapeOptions;
struct ReshapeOptionsT;
@@ -151,12 +154,24 @@ struct MaximumMinimumOptionsT;
struct ArgMaxOptions;
struct ArgMaxOptionsT;
+struct GreaterOptions;
+struct GreaterOptionsT;
+
+struct GreaterEqualOptions;
+struct GreaterEqualOptionsT;
+
struct LessOptions;
struct LessOptionsT;
+struct LessEqualOptions;
+struct LessEqualOptionsT;
+
struct NegOptions;
struct NegOptionsT;
+struct SelectOptions;
+struct SelectOptionsT;
+
struct OperatorCode;
struct OperatorCodeT;
@@ -276,11 +291,16 @@ enum BuiltinOperator {
BuiltinOperator_MINIMUM = 57,
BuiltinOperator_LESS = 58,
BuiltinOperator_NEG = 59,
+ BuiltinOperator_PADV2 = 60,
+ BuiltinOperator_GREATER = 61,
+ BuiltinOperator_GREATER_EQUAL = 62,
+ BuiltinOperator_LESS_EQUAL = 63,
+ BuiltinOperator_SELECT = 64,
BuiltinOperator_MIN = BuiltinOperator_ADD,
- BuiltinOperator_MAX = BuiltinOperator_NEG
+ BuiltinOperator_MAX = BuiltinOperator_SELECT
};
-inline BuiltinOperator (&EnumValuesBuiltinOperator())[59] {
+inline BuiltinOperator (&EnumValuesBuiltinOperator())[64] {
static BuiltinOperator values[] = {
BuiltinOperator_ADD,
BuiltinOperator_AVERAGE_POOL_2D,
@@ -340,7 +360,12 @@ inline BuiltinOperator (&EnumValuesBuiltinOperator())[59] {
BuiltinOperator_ARG_MAX,
BuiltinOperator_MINIMUM,
BuiltinOperator_LESS,
- BuiltinOperator_NEG
+ BuiltinOperator_NEG,
+ BuiltinOperator_PADV2,
+ BuiltinOperator_GREATER,
+ BuiltinOperator_GREATER_EQUAL,
+ BuiltinOperator_LESS_EQUAL,
+ BuiltinOperator_SELECT
};
return values;
}
@@ -407,6 +432,11 @@ inline const char **EnumNamesBuiltinOperator() {
"MINIMUM",
"LESS",
"NEG",
+ "PADV2",
+ "GREATER",
+ "GREATER_EQUAL",
+ "LESS_EQUAL",
+ "SELECT",
nullptr
};
return names;
@@ -461,11 +491,16 @@ enum BuiltinOptions {
BuiltinOptions_ArgMaxOptions = 40,
BuiltinOptions_LessOptions = 41,
BuiltinOptions_NegOptions = 42,
+ BuiltinOptions_PadV2Options = 43,
+ BuiltinOptions_GreaterOptions = 44,
+ BuiltinOptions_GreaterEqualOptions = 45,
+ BuiltinOptions_LessEqualOptions = 46,
+ BuiltinOptions_SelectOptions = 47,
BuiltinOptions_MIN = BuiltinOptions_NONE,
- BuiltinOptions_MAX = BuiltinOptions_NegOptions
+ BuiltinOptions_MAX = BuiltinOptions_SelectOptions
};
-inline BuiltinOptions (&EnumValuesBuiltinOptions())[43] {
+inline BuiltinOptions (&EnumValuesBuiltinOptions())[48] {
static BuiltinOptions values[] = {
BuiltinOptions_NONE,
BuiltinOptions_Conv2DOptions,
@@ -509,7 +544,12 @@ inline BuiltinOptions (&EnumValuesBuiltinOptions())[43] {
BuiltinOptions_MaximumMinimumOptions,
BuiltinOptions_ArgMaxOptions,
BuiltinOptions_LessOptions,
- BuiltinOptions_NegOptions
+ BuiltinOptions_NegOptions,
+ BuiltinOptions_PadV2Options,
+ BuiltinOptions_GreaterOptions,
+ BuiltinOptions_GreaterEqualOptions,
+ BuiltinOptions_LessEqualOptions,
+ BuiltinOptions_SelectOptions
};
return values;
}
@@ -559,6 +599,11 @@ inline const char **EnumNamesBuiltinOptions() {
"ArgMaxOptions",
"LessOptions",
"NegOptions",
+ "PadV2Options",
+ "GreaterOptions",
+ "GreaterEqualOptions",
+ "LessEqualOptions",
+ "SelectOptions",
nullptr
};
return names;
@@ -741,6 +786,26 @@ template<> struct BuiltinOptionsTraits<NegOptions> {
static const BuiltinOptions enum_value = BuiltinOptions_NegOptions;
};
+template<> struct BuiltinOptionsTraits<PadV2Options> {
+ static const BuiltinOptions enum_value = BuiltinOptions_PadV2Options;
+};
+
+template<> struct BuiltinOptionsTraits<GreaterOptions> {
+ static const BuiltinOptions enum_value = BuiltinOptions_GreaterOptions;
+};
+
+template<> struct BuiltinOptionsTraits<GreaterEqualOptions> {
+ static const BuiltinOptions enum_value = BuiltinOptions_GreaterEqualOptions;
+};
+
+template<> struct BuiltinOptionsTraits<LessEqualOptions> {
+ static const BuiltinOptions enum_value = BuiltinOptions_LessEqualOptions;
+};
+
+template<> struct BuiltinOptionsTraits<SelectOptions> {
+ static const BuiltinOptions enum_value = BuiltinOptions_SelectOptions;
+};
+
struct BuiltinOptionsUnion {
BuiltinOptions type;
void *value;
@@ -1108,6 +1173,46 @@ struct BuiltinOptionsUnion {
return type == BuiltinOptions_NegOptions ?
reinterpret_cast<const NegOptionsT *>(value) : nullptr;
}
+ PadV2OptionsT *AsPadV2Options() {
+ return type == BuiltinOptions_PadV2Options ?
+ reinterpret_cast<PadV2OptionsT *>(value) : nullptr;
+ }
+ const PadV2OptionsT *AsPadV2Options() const {
+ return type == BuiltinOptions_PadV2Options ?
+ reinterpret_cast<const PadV2OptionsT *>(value) : nullptr;
+ }
+ GreaterOptionsT *AsGreaterOptions() {
+ return type == BuiltinOptions_GreaterOptions ?
+ reinterpret_cast<GreaterOptionsT *>(value) : nullptr;
+ }
+ const GreaterOptionsT *AsGreaterOptions() const {
+ return type == BuiltinOptions_GreaterOptions ?
+ reinterpret_cast<const GreaterOptionsT *>(value) : nullptr;
+ }
+ GreaterEqualOptionsT *AsGreaterEqualOptions() {
+ return type == BuiltinOptions_GreaterEqualOptions ?
+ reinterpret_cast<GreaterEqualOptionsT *>(value) : nullptr;
+ }
+ const GreaterEqualOptionsT *AsGreaterEqualOptions() const {
+ return type == BuiltinOptions_GreaterEqualOptions ?
+ reinterpret_cast<const GreaterEqualOptionsT *>(value) : nullptr;
+ }
+ LessEqualOptionsT *AsLessEqualOptions() {
+ return type == BuiltinOptions_LessEqualOptions ?
+ reinterpret_cast<LessEqualOptionsT *>(value) : nullptr;
+ }
+ const LessEqualOptionsT *AsLessEqualOptions() const {
+ return type == BuiltinOptions_LessEqualOptions ?
+ reinterpret_cast<const LessEqualOptionsT *>(value) : nullptr;
+ }
+ SelectOptionsT *AsSelectOptions() {
+ return type == BuiltinOptions_SelectOptions ?
+ reinterpret_cast<SelectOptionsT *>(value) : nullptr;
+ }
+ const SelectOptionsT *AsSelectOptions() const {
+ return type == BuiltinOptions_SelectOptions ?
+ reinterpret_cast<const SelectOptionsT *>(value) : nullptr;
+ }
};
bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type);
@@ -2873,6 +2978,46 @@ inline flatbuffers::Offset<PadOptions> CreatePadOptions(
flatbuffers::Offset<PadOptions> CreatePadOptions(flatbuffers::FlatBufferBuilder &_fbb, const PadOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+struct PadV2OptionsT : public flatbuffers::NativeTable {
+ typedef PadV2Options TableType;
+ PadV2OptionsT() {
+ }
+};
+
+struct PadV2Options FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef PadV2OptionsT NativeTableType;
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ verifier.EndTable();
+ }
+ PadV2OptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(PadV2OptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<PadV2Options> Pack(flatbuffers::FlatBufferBuilder &_fbb, const PadV2OptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct PadV2OptionsBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ explicit PadV2OptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ PadV2OptionsBuilder &operator=(const PadV2OptionsBuilder &);
+ flatbuffers::Offset<PadV2Options> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<PadV2Options>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<PadV2Options> CreatePadV2Options(
+ flatbuffers::FlatBufferBuilder &_fbb) {
+ PadV2OptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+flatbuffers::Offset<PadV2Options> CreatePadV2Options(flatbuffers::FlatBufferBuilder &_fbb, const PadV2OptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
struct ReshapeOptionsT : public flatbuffers::NativeTable {
typedef ReshapeOptions TableType;
std::vector<int32_t> new_shape;
@@ -3995,6 +4140,86 @@ inline flatbuffers::Offset<ArgMaxOptions> CreateArgMaxOptions(
flatbuffers::Offset<ArgMaxOptions> CreateArgMaxOptions(flatbuffers::FlatBufferBuilder &_fbb, const ArgMaxOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+struct GreaterOptionsT : public flatbuffers::NativeTable {
+ typedef GreaterOptions TableType;
+ GreaterOptionsT() {
+ }
+};
+
+struct GreaterOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef GreaterOptionsT NativeTableType;
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ verifier.EndTable();
+ }
+ GreaterOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(GreaterOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<GreaterOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const GreaterOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct GreaterOptionsBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ explicit GreaterOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ GreaterOptionsBuilder &operator=(const GreaterOptionsBuilder &);
+ flatbuffers::Offset<GreaterOptions> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<GreaterOptions>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<GreaterOptions> CreateGreaterOptions(
+ flatbuffers::FlatBufferBuilder &_fbb) {
+ GreaterOptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+flatbuffers::Offset<GreaterOptions> CreateGreaterOptions(flatbuffers::FlatBufferBuilder &_fbb, const GreaterOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
+struct GreaterEqualOptionsT : public flatbuffers::NativeTable {
+ typedef GreaterEqualOptions TableType;
+ GreaterEqualOptionsT() {
+ }
+};
+
+struct GreaterEqualOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef GreaterEqualOptionsT NativeTableType;
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ verifier.EndTable();
+ }
+ GreaterEqualOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(GreaterEqualOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<GreaterEqualOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const GreaterEqualOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct GreaterEqualOptionsBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ explicit GreaterEqualOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ GreaterEqualOptionsBuilder &operator=(const GreaterEqualOptionsBuilder &);
+ flatbuffers::Offset<GreaterEqualOptions> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<GreaterEqualOptions>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<GreaterEqualOptions> CreateGreaterEqualOptions(
+ flatbuffers::FlatBufferBuilder &_fbb) {
+ GreaterEqualOptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+flatbuffers::Offset<GreaterEqualOptions> CreateGreaterEqualOptions(flatbuffers::FlatBufferBuilder &_fbb, const GreaterEqualOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
struct LessOptionsT : public flatbuffers::NativeTable {
typedef LessOptions TableType;
LessOptionsT() {
@@ -4035,6 +4260,46 @@ inline flatbuffers::Offset<LessOptions> CreateLessOptions(
flatbuffers::Offset<LessOptions> CreateLessOptions(flatbuffers::FlatBufferBuilder &_fbb, const LessOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+struct LessEqualOptionsT : public flatbuffers::NativeTable {
+ typedef LessEqualOptions TableType;
+ LessEqualOptionsT() {
+ }
+};
+
+struct LessEqualOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef LessEqualOptionsT NativeTableType;
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ verifier.EndTable();
+ }
+ LessEqualOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(LessEqualOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<LessEqualOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const LessEqualOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct LessEqualOptionsBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ explicit LessEqualOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ LessEqualOptionsBuilder &operator=(const LessEqualOptionsBuilder &);
+ flatbuffers::Offset<LessEqualOptions> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<LessEqualOptions>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<LessEqualOptions> CreateLessEqualOptions(
+ flatbuffers::FlatBufferBuilder &_fbb) {
+ LessEqualOptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+flatbuffers::Offset<LessEqualOptions> CreateLessEqualOptions(flatbuffers::FlatBufferBuilder &_fbb, const LessEqualOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
struct NegOptionsT : public flatbuffers::NativeTable {
typedef NegOptions TableType;
NegOptionsT() {
@@ -4075,6 +4340,46 @@ inline flatbuffers::Offset<NegOptions> CreateNegOptions(
flatbuffers::Offset<NegOptions> CreateNegOptions(flatbuffers::FlatBufferBuilder &_fbb, const NegOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+struct SelectOptionsT : public flatbuffers::NativeTable {
+ typedef SelectOptions TableType;
+ SelectOptionsT() {
+ }
+};
+
+struct SelectOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef SelectOptionsT NativeTableType;
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ verifier.EndTable();
+ }
+ SelectOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(SelectOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<SelectOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const SelectOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct SelectOptionsBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ explicit SelectOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ SelectOptionsBuilder &operator=(const SelectOptionsBuilder &);
+ flatbuffers::Offset<SelectOptions> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<SelectOptions>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<SelectOptions> CreateSelectOptions(
+ flatbuffers::FlatBufferBuilder &_fbb) {
+ SelectOptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+flatbuffers::Offset<SelectOptions> CreateSelectOptions(flatbuffers::FlatBufferBuilder &_fbb, const SelectOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
struct OperatorCodeT : public flatbuffers::NativeTable {
typedef OperatorCode TableType;
BuiltinOperator builtin_code;
@@ -4318,6 +4623,21 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
const NegOptions *builtin_options_as_NegOptions() const {
return builtin_options_type() == BuiltinOptions_NegOptions ? static_cast<const NegOptions *>(builtin_options()) : nullptr;
}
+ const PadV2Options *builtin_options_as_PadV2Options() const {
+ return builtin_options_type() == BuiltinOptions_PadV2Options ? static_cast<const PadV2Options *>(builtin_options()) : nullptr;
+ }
+ const GreaterOptions *builtin_options_as_GreaterOptions() const {
+ return builtin_options_type() == BuiltinOptions_GreaterOptions ? static_cast<const GreaterOptions *>(builtin_options()) : nullptr;
+ }
+ const GreaterEqualOptions *builtin_options_as_GreaterEqualOptions() const {
+ return builtin_options_type() == BuiltinOptions_GreaterEqualOptions ? static_cast<const GreaterEqualOptions *>(builtin_options()) : nullptr;
+ }
+ const LessEqualOptions *builtin_options_as_LessEqualOptions() const {
+ return builtin_options_type() == BuiltinOptions_LessEqualOptions ? static_cast<const LessEqualOptions *>(builtin_options()) : nullptr;
+ }
+ const SelectOptions *builtin_options_as_SelectOptions() const {
+ return builtin_options_type() == BuiltinOptions_SelectOptions ? static_cast<const SelectOptions *>(builtin_options()) : nullptr;
+ }
const flatbuffers::Vector<uint8_t> *custom_options() const {
return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS);
}
@@ -4512,6 +4832,26 @@ template<> inline const NegOptions *Operator::builtin_options_as<NegOptions>() c
return builtin_options_as_NegOptions();
}
+template<> inline const PadV2Options *Operator::builtin_options_as<PadV2Options>() const {
+ return builtin_options_as_PadV2Options();
+}
+
+template<> inline const GreaterOptions *Operator::builtin_options_as<GreaterOptions>() const {
+ return builtin_options_as_GreaterOptions();
+}
+
+template<> inline const GreaterEqualOptions *Operator::builtin_options_as<GreaterEqualOptions>() const {
+ return builtin_options_as_GreaterEqualOptions();
+}
+
+template<> inline const LessEqualOptions *Operator::builtin_options_as<LessEqualOptions>() const {
+ return builtin_options_as_LessEqualOptions();
+}
+
+template<> inline const SelectOptions *Operator::builtin_options_as<SelectOptions>() const {
+ return builtin_options_as_SelectOptions();
+}
+
struct OperatorBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
@@ -5572,6 +5912,29 @@ inline flatbuffers::Offset<PadOptions> CreatePadOptions(flatbuffers::FlatBufferB
_fbb);
}
+inline PadV2OptionsT *PadV2Options::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = new PadV2OptionsT();
+ UnPackTo(_o, _resolver);
+ return _o;
+}
+
+inline void PadV2Options::UnPackTo(PadV2OptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+}
+
+inline flatbuffers::Offset<PadV2Options> PadV2Options::Pack(flatbuffers::FlatBufferBuilder &_fbb, const PadV2OptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreatePadV2Options(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<PadV2Options> CreatePadV2Options(flatbuffers::FlatBufferBuilder &_fbb, const PadV2OptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const PadV2OptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ return tflite::CreatePadV2Options(
+ _fbb);
+}
+
inline ReshapeOptionsT *ReshapeOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new ReshapeOptionsT();
UnPackTo(_o, _resolver);
@@ -6115,6 +6478,52 @@ inline flatbuffers::Offset<ArgMaxOptions> CreateArgMaxOptions(flatbuffers::FlatB
_output_type);
}
+inline GreaterOptionsT *GreaterOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = new GreaterOptionsT();
+ UnPackTo(_o, _resolver);
+ return _o;
+}
+
+inline void GreaterOptions::UnPackTo(GreaterOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+}
+
+inline flatbuffers::Offset<GreaterOptions> GreaterOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const GreaterOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreateGreaterOptions(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<GreaterOptions> CreateGreaterOptions(flatbuffers::FlatBufferBuilder &_fbb, const GreaterOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const GreaterOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ return tflite::CreateGreaterOptions(
+ _fbb);
+}
+
+inline GreaterEqualOptionsT *GreaterEqualOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = new GreaterEqualOptionsT();
+ UnPackTo(_o, _resolver);
+ return _o;
+}
+
+inline void GreaterEqualOptions::UnPackTo(GreaterEqualOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+}
+
+inline flatbuffers::Offset<GreaterEqualOptions> GreaterEqualOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const GreaterEqualOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreateGreaterEqualOptions(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<GreaterEqualOptions> CreateGreaterEqualOptions(flatbuffers::FlatBufferBuilder &_fbb, const GreaterEqualOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const GreaterEqualOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ return tflite::CreateGreaterEqualOptions(
+ _fbb);
+}
+
inline LessOptionsT *LessOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new LessOptionsT();
UnPackTo(_o, _resolver);
@@ -6138,6 +6547,29 @@ inline flatbuffers::Offset<LessOptions> CreateLessOptions(flatbuffers::FlatBuffe
_fbb);
}
+inline LessEqualOptionsT *LessEqualOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = new LessEqualOptionsT();
+ UnPackTo(_o, _resolver);
+ return _o;
+}
+
+inline void LessEqualOptions::UnPackTo(LessEqualOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+}
+
+inline flatbuffers::Offset<LessEqualOptions> LessEqualOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const LessEqualOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreateLessEqualOptions(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<LessEqualOptions> CreateLessEqualOptions(flatbuffers::FlatBufferBuilder &_fbb, const LessEqualOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const LessEqualOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ return tflite::CreateLessEqualOptions(
+ _fbb);
+}
+
inline NegOptionsT *NegOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new NegOptionsT();
UnPackTo(_o, _resolver);
@@ -6161,6 +6593,29 @@ inline flatbuffers::Offset<NegOptions> CreateNegOptions(flatbuffers::FlatBufferB
_fbb);
}
+inline SelectOptionsT *SelectOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = new SelectOptionsT();
+ UnPackTo(_o, _resolver);
+ return _o;
+}
+
+inline void SelectOptions::UnPackTo(SelectOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+}
+
+inline flatbuffers::Offset<SelectOptions> SelectOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const SelectOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreateSelectOptions(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<SelectOptions> CreateSelectOptions(flatbuffers::FlatBufferBuilder &_fbb, const SelectOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const SelectOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ return tflite::CreateSelectOptions(
+ _fbb);
+}
+
inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new OperatorCodeT();
UnPackTo(_o, _resolver);
@@ -6512,6 +6967,26 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob
auto ptr = reinterpret_cast<const NegOptions *>(obj);
return verifier.VerifyTable(ptr);
}
+ case BuiltinOptions_PadV2Options: {
+ auto ptr = reinterpret_cast<const PadV2Options *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_GreaterOptions: {
+ auto ptr = reinterpret_cast<const GreaterOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_GreaterEqualOptions: {
+ auto ptr = reinterpret_cast<const GreaterEqualOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_LessEqualOptions: {
+ auto ptr = reinterpret_cast<const LessEqualOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_SelectOptions: {
+ auto ptr = reinterpret_cast<const SelectOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
default: return false;
}
}
@@ -6698,6 +7173,26 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c
auto ptr = reinterpret_cast<const NegOptions *>(obj);
return ptr->UnPack(resolver);
}
+ case BuiltinOptions_PadV2Options: {
+ auto ptr = reinterpret_cast<const PadV2Options *>(obj);
+ return ptr->UnPack(resolver);
+ }
+ case BuiltinOptions_GreaterOptions: {
+ auto ptr = reinterpret_cast<const GreaterOptions *>(obj);
+ return ptr->UnPack(resolver);
+ }
+ case BuiltinOptions_GreaterEqualOptions: {
+ auto ptr = reinterpret_cast<const GreaterEqualOptions *>(obj);
+ return ptr->UnPack(resolver);
+ }
+ case BuiltinOptions_LessEqualOptions: {
+ auto ptr = reinterpret_cast<const LessEqualOptions *>(obj);
+ return ptr->UnPack(resolver);
+ }
+ case BuiltinOptions_SelectOptions: {
+ auto ptr = reinterpret_cast<const SelectOptions *>(obj);
+ return ptr->UnPack(resolver);
+ }
default: return nullptr;
}
}
@@ -6872,6 +7367,26 @@ inline flatbuffers::Offset<void> BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff
auto ptr = reinterpret_cast<const NegOptionsT *>(value);
return CreateNegOptions(_fbb, ptr, _rehasher).Union();
}
+ case BuiltinOptions_PadV2Options: {
+ auto ptr = reinterpret_cast<const PadV2OptionsT *>(value);
+ return CreatePadV2Options(_fbb, ptr, _rehasher).Union();
+ }
+ case BuiltinOptions_GreaterOptions: {
+ auto ptr = reinterpret_cast<const GreaterOptionsT *>(value);
+ return CreateGreaterOptions(_fbb, ptr, _rehasher).Union();
+ }
+ case BuiltinOptions_GreaterEqualOptions: {
+ auto ptr = reinterpret_cast<const GreaterEqualOptionsT *>(value);
+ return CreateGreaterEqualOptions(_fbb, ptr, _rehasher).Union();
+ }
+ case BuiltinOptions_LessEqualOptions: {
+ auto ptr = reinterpret_cast<const LessEqualOptionsT *>(value);
+ return CreateLessEqualOptions(_fbb, ptr, _rehasher).Union();
+ }
+ case BuiltinOptions_SelectOptions: {
+ auto ptr = reinterpret_cast<const SelectOptionsT *>(value);
+ return CreateSelectOptions(_fbb, ptr, _rehasher).Union();
+ }
default: return 0;
}
}
@@ -7046,6 +7561,26 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL
value = new NegOptionsT(*reinterpret_cast<NegOptionsT *>(u.value));
break;
}
+ case BuiltinOptions_PadV2Options: {
+ value = new PadV2OptionsT(*reinterpret_cast<PadV2OptionsT *>(u.value));
+ break;
+ }
+ case BuiltinOptions_GreaterOptions: {
+ value = new GreaterOptionsT(*reinterpret_cast<GreaterOptionsT *>(u.value));
+ break;
+ }
+ case BuiltinOptions_GreaterEqualOptions: {
+ value = new GreaterEqualOptionsT(*reinterpret_cast<GreaterEqualOptionsT *>(u.value));
+ break;
+ }
+ case BuiltinOptions_LessEqualOptions: {
+ value = new LessEqualOptionsT(*reinterpret_cast<LessEqualOptionsT *>(u.value));
+ break;
+ }
+ case BuiltinOptions_SelectOptions: {
+ value = new SelectOptionsT(*reinterpret_cast<SelectOptionsT *>(u.value));
+ break;
+ }
default:
break;
}
@@ -7263,6 +7798,31 @@ inline void BuiltinOptionsUnion::Reset() {
delete ptr;
break;
}
+ case BuiltinOptions_PadV2Options: {
+ auto ptr = reinterpret_cast<PadV2OptionsT *>(value);
+ delete ptr;
+ break;
+ }
+ case BuiltinOptions_GreaterOptions: {
+ auto ptr = reinterpret_cast<GreaterOptionsT *>(value);
+ delete ptr;
+ break;
+ }
+ case BuiltinOptions_GreaterEqualOptions: {
+ auto ptr = reinterpret_cast<GreaterEqualOptionsT *>(value);
+ delete ptr;
+ break;
+ }
+ case BuiltinOptions_LessEqualOptions: {
+ auto ptr = reinterpret_cast<LessEqualOptionsT *>(value);
+ delete ptr;
+ break;
+ }
+ case BuiltinOptions_SelectOptions: {
+ auto ptr = reinterpret_cast<SelectOptionsT *>(value);
+ delete ptr;
+ break;
+ }
default: break;
}
value = nullptr;
diff --git a/tensorflow/contrib/lite/testing/BUILD b/tensorflow/contrib/lite/testing/BUILD
index 211de63d58..f89c0d28d3 100644
--- a/tensorflow/contrib/lite/testing/BUILD
+++ b/tensorflow/contrib/lite/testing/BUILD
@@ -33,9 +33,12 @@ gen_zipped_test_files(
"fused_batch_norm.zip",
"gather.zip",
"global_batch_norm.zip",
+ "greater.zip",
+ "greater_equal.zip",
"l2_pool.zip",
"l2norm.zip",
"less.zip",
+ "less_equal.zip",
"local_response_norm.zip",
"log_softmax.zip",
"max_pool.zip",
@@ -45,6 +48,7 @@ gen_zipped_test_files(
"mul.zip",
"neg.zip",
"pad.zip",
+ "padv2.zip",
"relu.zip",
"relu1.zip",
"relu6.zip",
@@ -60,6 +64,7 @@ gen_zipped_test_files(
"sub.zip",
"topk.zip",
"transpose.zip",
+ "where.zip",
],
)
diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py
index 926bb3f121..f7cc7da900 100644
--- a/tensorflow/contrib/lite/testing/generate_examples.py
+++ b/tensorflow/contrib/lite/testing/generate_examples.py
@@ -1391,6 +1391,60 @@ def make_pad_tests(zip_path):
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+def make_padv2_tests(zip_path):
+ """Make a set of tests to do padv2."""
+
+ # TODO(nupurgarg): Add test for tf.uint8.
+ test_parameters = [
+ {
+ "dtype": [tf.int32, tf.int64, tf.float32],
+ "input_shape": [[1, 1, 2, 1], [2, 1, 1, 1]],
+ "paddings": [[[0, 0], [0, 1], [2, 3], [0, 0]], [[0, 1], [0, 0],
+ [0, 0], [2, 3]]],
+ "constant_paddings": [True, False],
+ "constant_values": [0, 2],
+ },
+ # Non-4D use case.
+ {
+ "dtype": [tf.int32, tf.int64, tf.float32],
+ "input_shape": [[1, 2], [0, 1, 2]],
+ "paddings": [[[0, 1], [2, 3]]],
+ "constant_paddings": [True, False],
+ "constant_values": [0, 2],
+ },
+ ]
+
+ def build_graph(parameters):
+ """Build a pad graph given `parameters`."""
+ input_tensor = tf.placeholder(
+ dtype=parameters["dtype"],
+ name="input",
+ shape=parameters["input_shape"])
+
+ # Get paddings as either a placeholder or constants.
+ if parameters["constant_paddings"]:
+ paddings = parameters["paddings"]
+ input_tensors = [input_tensor]
+ else:
+ shape = [len(parameters["paddings"]), 2]
+ paddings = tf.placeholder(dtype=tf.int32, name="padding", shape=shape)
+ input_tensors = [input_tensor, paddings]
+
+ out = tf.pad(input_tensor, paddings=paddings,
+ constant_values=parameters["constant_values"])
+ return input_tensors, [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ values = [
+ create_tensor_data(parameters["dtype"], parameters["input_shape"])
+ ]
+ if not parameters["constant_paddings"]:
+ values.append(np.array(parameters["paddings"]))
+ return values, sess.run(outputs, feed_dict=dict(zip(inputs, values)))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
def make_reshape_tests(zip_path):
"""Make a set of tests to do reshape."""
@@ -2001,6 +2055,74 @@ def make_arg_max_tests(zip_path):
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+def make_greater_tests(zip_path):
+ """Make a set of tests to do greater."""
+
+ test_parameters = [{
+ "input_dtype": [tf.float32, tf.int32, tf.int64],
+ "input_shape_pair": [([1, 1, 1, 3], [1, 1, 1, 3]),
+ ([2, 3, 4, 5], [2, 3, 4, 5]), ([2, 3, 3], [2, 3]),
+ ([5, 5], [1]), ([10], [2, 4, 10])],
+ }]
+
+ def build_graph(parameters):
+ """Build the greater op testing graph."""
+ input_value1 = tf.placeholder(
+ dtype=parameters["input_dtype"],
+ name="input1",
+ shape=parameters["input_shape_pair"][0])
+ input_value2 = tf.placeholder(
+ dtype=parameters["input_dtype"],
+ name="input2",
+ shape=parameters["input_shape_pair"][1])
+ out = tf.greater(input_value1, input_value2)
+ return [input_value1, input_value2], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input_value1 = create_tensor_data(parameters["input_dtype"],
+ parameters["input_shape_pair"][0])
+ input_value2 = create_tensor_data(parameters["input_dtype"],
+ parameters["input_shape_pair"][1])
+ return [input_value1, input_value2], sess.run(
+ outputs, feed_dict=dict(zip(inputs, [input_value1, input_value2])))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
+def make_greater_equal_tests(zip_path):
+ """Make a set of tests to do greater_equal."""
+
+ test_parameters = [{
+ "input_dtype": [tf.float32, tf.int32, tf.int64],
+ "input_shape_pair": [([1, 1, 1, 3], [1, 1, 1, 3]),
+ ([2, 3, 4, 5], [2, 3, 4, 5]), ([2, 3, 3], [2, 3]),
+ ([5, 5], [1]), ([10], [2, 4, 10])],
+ }]
+
+ def build_graph(parameters):
+ """Build the greater_equal op testing graph."""
+ input_value1 = tf.placeholder(
+ dtype=parameters["input_dtype"],
+ name="input1",
+ shape=parameters["input_shape_pair"][0])
+ input_value2 = tf.placeholder(
+ dtype=parameters["input_dtype"],
+ name="input2",
+ shape=parameters["input_shape_pair"][1])
+ out = tf.greater_equal(input_value1, input_value2)
+ return [input_value1, input_value2], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input_value1 = create_tensor_data(parameters["input_dtype"],
+ parameters["input_shape_pair"][0])
+ input_value2 = create_tensor_data(parameters["input_dtype"],
+ parameters["input_shape_pair"][1])
+ return [input_value1, input_value2], sess.run(
+ outputs, feed_dict=dict(zip(inputs, [input_value1, input_value2])))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
def make_less_tests(zip_path):
"""Make a set of tests to do less."""
@@ -2035,6 +2157,40 @@ def make_less_tests(zip_path):
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+def make_less_equal_tests(zip_path):
+ """Make a set of tests to do less_equal."""
+
+ test_parameters = [{
+ "input_dtype": [tf.float32, tf.int32, tf.int64],
+ "input_shape_pair": [([1, 1, 1, 3], [1, 1, 1, 3]),
+ ([2, 3, 4, 5], [2, 3, 4, 5]), ([2, 3, 3], [2, 3]),
+ ([5, 5], [1]), ([10], [2, 4, 10])],
+ }]
+
+ def build_graph(parameters):
+ """Build the less_equal op testing graph."""
+ input_value1 = tf.placeholder(
+ dtype=parameters["input_dtype"],
+ name="input1",
+ shape=parameters["input_shape_pair"][0])
+ input_value2 = tf.placeholder(
+ dtype=parameters["input_dtype"],
+ name="input2",
+ shape=parameters["input_shape_pair"][1])
+ out = tf.less_equal(input_value1, input_value2)
+ return [input_value1, input_value2], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input_value1 = create_tensor_data(parameters["input_dtype"],
+ parameters["input_shape_pair"][0])
+ input_value2 = create_tensor_data(parameters["input_dtype"],
+ parameters["input_shape_pair"][1])
+ return [input_value1, input_value2], sess.run(
+ outputs, feed_dict=dict(zip(inputs, [input_value1, input_value2])))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
def make_floor_tests(zip_path):
"""Make a set of tests to do floor."""
@@ -2086,10 +2242,41 @@ def make_neg_tests(zip_path):
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+def make_where_tests(zip_path):
+ """Make a set of tests to do where."""
+
+ test_parameters = [{
+ "input_dtype": [tf.float32, tf.int32],
+ "input_shape_set": [([1, 2, 3, 4], [1, 2, 3, 4]),],
+ }]
+
+ def build_graph(parameters):
+ """Build the where op testing graph."""
+ input_value1 = tf.placeholder(
+ dtype=parameters["input_dtype"],
+ name="input2",
+ shape=parameters["input_shape_set"][0])
+ input_value2 = tf.placeholder(
+ dtype=parameters["input_dtype"],
+ name="input3",
+ shape=parameters["input_shape_set"][1])
+ less = tf.less(input_value1, input_value2)
+ out = tf.where(less, input_value1, input_value2)
+ return [input_value1, input_value2], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input_value1 = create_tensor_data(parameters["input_dtype"],
+ parameters["input_shape_set"][0])
+ input_value2 = create_tensor_data(parameters["input_dtype"],
+ parameters["input_shape_set"][1])
+ return [input_value1, input_value2], sess.run(
+ outputs, feed_dict=dict(zip(inputs, [input_value1, input_value2])))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
# Toco binary path provided by the generate rule.
bin_path = None
-
def main(unused_args):
global bin_path
def mkdir_if_not_exist(x):
diff --git a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
index 0673a3bb46..49762bdfe7 100644
--- a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
+++ b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
@@ -54,9 +54,11 @@ std::map<string, string> kBrokenTests = {
{R"(^\/div.*int32)", "68808744"},
{R"(^\/sub.*int32)", "68808744"},
- // Pad only supports 4D tensors.
+ // Pad and PadV2 only support 4D tensors.
{R"(^\/pad.*,input_shape=\[.,.\],paddings=\[\[.,.\],\[.,.\]\])",
"70527055"},
+ {R"(^\/padv2.*,input_shape=\[.,.\],paddings=\[\[.,.\],\[.,.\]\])",
+ "70527055"},
// L2Norm only supports tensors with 4D or fewer.
{R"(^\/l2normdim=.*,epsilon=.*,input_shape=\[.,.,.,.,.*\])", "67963684"},
@@ -256,9 +258,12 @@ INSTANTIATE_TESTS(fully_connected)
INSTANTIATE_TESTS(fused_batch_norm)
INSTANTIATE_TESTS(gather)
INSTANTIATE_TESTS(global_batch_norm)
+INSTANTIATE_TESTS(greater)
+INSTANTIATE_TESTS(greater_equal)
INSTANTIATE_TESTS(l2_pool)
INSTANTIATE_TESTS(l2norm)
INSTANTIATE_TESTS(less)
+INSTANTIATE_TESTS(less_equal)
INSTANTIATE_TESTS(local_response_norm)
INSTANTIATE_TESTS(log_softmax)
INSTANTIATE_TESTS(max_pool)
@@ -268,6 +273,7 @@ INSTANTIATE_TESTS(minimum)
INSTANTIATE_TESTS(mul)
INSTANTIATE_TESTS(neg)
INSTANTIATE_TESTS(pad)
+INSTANTIATE_TESTS(padv2)
// INSTANTIATE_TESTS(prelu)
INSTANTIATE_TESTS(relu)
INSTANTIATE_TESTS(relu1)
@@ -283,6 +289,7 @@ INSTANTIATE_TESTS(squeeze)
INSTANTIATE_TESTS(strided_slice)
INSTANTIATE_TESTS(sub)
INSTANTIATE_TESTS(transpose)
+INSTANTIATE_TESTS(where)
} // namespace testing
} // namespace tflite
diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD
index ce0a74724a..01ce0d9db2 100644
--- a/tensorflow/contrib/lite/toco/BUILD
+++ b/tensorflow/contrib/lite/toco/BUILD
@@ -280,6 +280,7 @@ cc_library(
"graph_transformations/resolve_mean_attributes.cc",
"graph_transformations/resolve_multiply_by_zero.cc",
"graph_transformations/resolve_pad_attributes.cc",
+ "graph_transformations/resolve_padv2_attributes.cc",
"graph_transformations/resolve_reorder_axes.cc",
"graph_transformations/resolve_reshape_attributes.cc",
"graph_transformations/resolve_slice_attributes.cc",
diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.cc b/tensorflow/contrib/lite/toco/export_tensorflow.cc
index 99ccfaea64..f5157149af 100644
--- a/tensorflow/contrib/lite/toco/export_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/export_tensorflow.cc
@@ -1492,6 +1492,37 @@ void ConvertPadOperator(const Model& model, const PadOperator& src_op,
shape->add_dim()->set_size(2);
}
+void ConvertPadV2Operator(const Model& model, const PadV2Operator& src_op,
+ GraphDef* tensorflow_graph) {
+ auto* new_op = tensorflow_graph->add_node();
+ new_op->set_op("PadV2");
+ new_op->set_name(src_op.outputs[0]);
+ CHECK_EQ(src_op.inputs.size(), 2);
+ *new_op->add_input() = src_op.inputs[0];
+ *new_op->add_input() = src_op.inputs[1];
+ *new_op->add_input() = src_op.inputs[2];
+
+ const auto params_type = GetTensorFlowDataType(model, src_op.inputs[0]);
+ (*new_op->mutable_attr())["T"].set_type(params_type);
+
+ // Create the params tensor.
+ auto* params_op = tensorflow_graph->add_node();
+ params_op->set_op("Const");
+ params_op->set_name(src_op.inputs[1]);
+ (*params_op->mutable_attr())["dtype"].set_type(DT_INT32);
+ auto* tensor = (*params_op->mutable_attr())["value"].mutable_tensor();
+ tensor->set_dtype(DT_INT32);
+
+ CHECK_EQ(src_op.left_padding.size(), src_op.right_padding.size());
+ for (int i = 0; i < src_op.left_padding.size(); ++i) {
+ tensor->add_int_val(src_op.left_padding[i]);
+ tensor->add_int_val(src_op.right_padding[i]);
+ }
+ auto* shape = tensor->mutable_tensor_shape();
+ shape->add_dim()->set_size(src_op.left_padding.size());
+ shape->add_dim()->set_size(2);
+}
+
void CreateSliceInput(const string& input_name, const std::vector<int>& values,
GraphDef* tensorflow_graph) {
auto* params_op = tensorflow_graph->add_node();
@@ -1643,6 +1674,19 @@ void ConvertTensorFlowMaximumOperator(const Model& model,
(*sub_op->mutable_attr())["T"].set_type(data_type);
}
+void ConvertSelectOperator(const Model& model, const SelectOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ auto* sub_op = tensorflow_graph->add_node();
+ sub_op->set_op("Select");
+ sub_op->set_name(src_op.outputs[0]);
+ CHECK_EQ(src_op.inputs.size(), 3);
+ *sub_op->add_input() = src_op.inputs[0];
+ *sub_op->add_input() = src_op.inputs[1];
+ *sub_op->add_input() = src_op.inputs[2];
+ const auto data_type = GetTensorFlowDataType(model, src_op.inputs[1]);
+ (*sub_op->mutable_attr())["T"].set_type(data_type);
+}
+
void ConvertTopKV2Operator(const Model& model, const TopKV2Operator& src_op,
GraphDef* tensorflow_graph) {
auto* topk_op = tensorflow_graph->add_node();
@@ -1671,6 +1715,19 @@ void ConvertRandomUniformOperator(const Model& model,
(*new_op->mutable_attr())["seed2"].set_i(src_op.seed2);
}
+void ConvertComparisonOperator(const Model& model, const Operator& src_op,
+ const char* op_name,
+ GraphDef* tensorflow_graph) {
+ auto* comparison_op = tensorflow_graph->add_node();
+ comparison_op->set_op(op_name);
+ comparison_op->set_name(src_op.outputs[0]);
+ CHECK_EQ(src_op.inputs.size(), 2);
+ *comparison_op->add_input() = src_op.inputs[0];
+ *comparison_op->add_input() = src_op.inputs[1];
+ const auto data_type = GetTensorFlowDataType(model, src_op.inputs[0]);
+ (*comparison_op->mutable_attr())["T"].set_type(data_type);
+}
+
void ConvertOperator(const Model& model, const Operator& src_op,
GraphDef* tensorflow_graph) {
if (src_op.fused_activation_function != FusedActivationFunctionType::kNone) {
@@ -1795,6 +1852,9 @@ void ConvertOperator(const Model& model, const Operator& src_op,
} else if (src_op.type == OperatorType::kPad) {
ConvertPadOperator(model, static_cast<const PadOperator&>(src_op),
tensorflow_graph);
+ } else if (src_op.type == OperatorType::kPadV2) {
+ ConvertPadV2Operator(model, static_cast<const PadV2Operator&>(src_op),
+ tensorflow_graph);
} else if (src_op.type == OperatorType::kStridedSlice) {
ConvertStridedSliceOperator(
model, static_cast<const StridedSliceOperator&>(src_op),
@@ -1859,6 +1919,17 @@ void ConvertOperator(const Model& model, const Operator& src_op,
ConvertRandomUniformOperator(
model, static_cast<const RandomUniformOperator&>(src_op),
tensorflow_graph);
+ } else if (src_op.type == OperatorType::kTensorFlowGreater) {
+ ConvertComparisonOperator(model, src_op, "Greater", tensorflow_graph);
+ } else if (src_op.type == OperatorType::kTensorFlowGreaterEqual) {
+ ConvertComparisonOperator(model, src_op, "GreaterEqual", tensorflow_graph);
+ } else if (src_op.type == OperatorType::kTensorFlowLess) {
+ ConvertComparisonOperator(model, src_op, "Less", tensorflow_graph);
+ } else if (src_op.type == OperatorType::kTensorFlowLessEqual) {
+ ConvertComparisonOperator(model, src_op, "LessEqual", tensorflow_graph);
+ } else if (src_op.type == OperatorType::kSelect) {
+ ConvertSelectOperator(model, static_cast<const SelectOperator&>(src_op),
+ tensorflow_graph);
} else {
LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(src_op.type);
}
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
index 72ffd51db4..4e3ea72182 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
+++ b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
@@ -174,6 +174,7 @@ DECLARE_GRAPH_TRANSFORMATION(UnrollBatchMatMul)
DECLARE_GRAPH_TRANSFORMATION(ResolveSpaceToBatchNDAttributes)
DECLARE_GRAPH_TRANSFORMATION(ResolveBatchToSpaceNDAttributes)
DECLARE_GRAPH_TRANSFORMATION(ResolvePadAttributes)
+DECLARE_GRAPH_TRANSFORMATION(ResolvePadV2Attributes)
DECLARE_GRAPH_TRANSFORMATION(ResolveStridedSliceAttributes)
DECLARE_GRAPH_TRANSFORMATION(ResolveSliceAttributes)
DECLARE_GRAPH_TRANSFORMATION(ResolveMeanAttributes)
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
index c1cf79f626..6342cf3e8a 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
@@ -152,6 +152,17 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
// Yield on ExpandDim until it is converted to Reshape
return false;
}
+ case OperatorType::kSelect: {
+ // Select produces outputs with the same type as its 2nd input.
+ CHECK_EQ(op->inputs.size(), 3);
+ const ArrayDataType data_type_x =
+ model->GetArray(op->inputs[1]).data_type;
+ const ArrayDataType data_type_y =
+ model->GetArray(op->inputs[2]).data_type;
+ CHECK(data_type_x == data_type_y);
+ SetDataTypeForAllOutputs(model, op, data_type_x);
+ break;
+ }
default: {
// These operators produce outputs with the same type as their 1st input
CHECK_GT(op->inputs.size(), 0);
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
index 4923f83d91..52b739c5e2 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
@@ -499,8 +499,8 @@ void ProcessTensorFlowReshapeOperator(Model* model,
<< op->outputs[0] << "\". Are your input shapes correct?";
}
-void ProcessSimpleOperator(Model* model, Operator* op) {
- const auto& input_array = model->GetArray(op->inputs[0]);
+void ProcessSimpleOperator(Model* model, Operator* op, int input_index) {
+ const auto& input_array = model->GetArray(op->inputs[input_index]);
// Yield until input dims have been resolved.
if (!input_array.has_shape()) {
return;
@@ -529,6 +529,21 @@ void ProcessSimpleBinaryOperator(Model* model, Operator* op) {
&output_array);
}
+void ProcessSelectOperator(Model* model, SelectOperator* op) {
+ // Yield until all input dims have been resolved.
+ for (const auto& input : op->inputs) {
+ const auto& input_array = model->GetArray(input);
+ if (!input_array.has_shape()) {
+ return;
+ }
+ }
+
+ // Select's output shape matches that of its second and third inputs.
+ const auto& input1_array = model->GetArray(op->inputs[1]);
+ auto& output_array = model->GetArray(op->outputs[0]);
+ output_array.copy_shape(input1_array.shape());
+}
+
void ProcessAddNOperator(Model* model, Operator* op) {
// Yield until all input dims have been resolved.
//
@@ -670,8 +685,7 @@ void ProcessConcatenationOperator(Model* model, ConcatenationOperator* op) {
const auto& first_input_array = model->GetArray(op->inputs[0]);
output_array.copy_shape(first_input_array.shape());
// Negative axis means the count starts at the back of the dims().
- int axis = op->axis;
- if (axis < 0) axis += first_input_array.shape().dims().size();
+ if (op->axis < 0) op->axis += first_input_array.shape().dims().size();
// Determine the concat size, and enforce that all inputs have
// the same dimensions count.
int concat_size = 0;
@@ -684,14 +698,14 @@ void ProcessConcatenationOperator(Model* model, ConcatenationOperator* op) {
CHECK_EQ(input_array.shape().dimensions_count(),
output_array.shape().dimensions_count());
const std::vector<int>& input_dims = input_array.shape().dims();
- CHECK_LT(axis, input_dims.size());
- concat_size += input_dims[axis];
+ CHECK_LT(op->axis, input_dims.size());
+ concat_size += input_dims[op->axis];
}
// Write out the concat_size on the output array shape.
auto& output_shape = *output_array.mutable_shape();
auto& output_dims = *output_shape.mutable_dims();
- CHECK_LT(axis, output_shape.dimensions_count());
- output_dims[axis] = concat_size;
+ CHECK_LT(op->axis, output_shape.dimensions_count());
+ output_dims[op->axis] = concat_size;
}
void ProcessRangeOperator(Model* model, RangeOperator* op) {
@@ -1147,6 +1161,32 @@ void ProcessPadOperator(Model* model, PadOperator* op) {
output_array.copy_shape(output_shape);
}
+void ProcessPadV2Operator(Model* model, PadV2Operator* op) {
+ CHECK_EQ(op->inputs.size(), 3);
+ CHECK_EQ(op->outputs.size(), 1);
+
+ const auto& input_array = model->GetArray(op->inputs[0]);
+
+ // Yield until input dims have been resolved.
+ if (!input_array.has_shape()) return;
+
+ if (op->left_padding.empty()) return;
+ CHECK_EQ(op->left_padding.size(), op->right_padding.size());
+
+ auto& output_array = model->GetArray(op->outputs[0]);
+ if (output_array.has_shape()) return;
+
+ Shape output_shape = input_array.shape();
+ std::vector<int>& dims = *output_shape.mutable_dims();
+ CHECK_EQ(op->left_padding.size(), dims.size());
+
+ for (int i = 0; i < op->left_padding.size(); ++i) {
+ dims[i] += op->left_padding[i] + op->right_padding[i];
+ }
+
+ output_array.copy_shape(output_shape);
+}
+
void ProcessRankOperator(Model* model, RankOperator* op) {
CHECK_GE(op->inputs.size(), 1);
CHECK_EQ(op->outputs.size(), 1);
@@ -1474,7 +1514,7 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
case OperatorType::kCast:
case OperatorType::kFloor:
case OperatorType::kExp:
- ProcessSimpleOperator(model, op);
+ ProcessSimpleOperator(model, op, 0);
break;
case OperatorType::kGather:
ProcessGatherOperator(model, static_cast<GatherOperator*>(op));
@@ -1545,7 +1585,9 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
case OperatorType::kMean:
ProcessTensorFlowReductionOperator(model, op);
break;
-
+ case OperatorType::kSelect:
+ ProcessSelectOperator(model, static_cast<SelectOperator*>(op));
+ break;
case OperatorType::kSlice:
ProcessSliceOperator(model, static_cast<SliceOperator*>(op));
break;
@@ -1629,6 +1671,9 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
case OperatorType::kPad:
ProcessPadOperator(model, static_cast<PadOperator*>(op));
break;
+ case OperatorType::kPadV2:
+ ProcessPadV2Operator(model, static_cast<PadV2Operator*>(op));
+ break;
case OperatorType::kStridedSlice:
ProcessStridedSliceOperator(model,
static_cast<StridedSliceOperator*>(op));
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc
index 347302c7a5..a1ca7371c8 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc
@@ -48,13 +48,18 @@ bool SupportsQuantization(const Operator& op) {
type == OperatorType::kLogSoftmax ||
type == OperatorType::kTensorFlowSplit || type == OperatorType::kSub ||
type == OperatorType::kSqueeze || type == OperatorType::kPad ||
+ type == OperatorType::kPadV2 ||
type == OperatorType::kTensorFlowReshape ||
type == OperatorType::kTanh || type == OperatorType::kMul ||
type == OperatorType::kSpaceToDepth ||
type == OperatorType::kStridedSlice ||
type == OperatorType::kDepthToSpace ||
type == OperatorType::kLstmCell || type == OperatorType::kGather ||
- type == OperatorType::kTranspose || type == OperatorType::kMean;
+ type == OperatorType::kTranspose || type == OperatorType::kMean ||
+ type == OperatorType::kTensorFlowGreater ||
+ type == OperatorType::kTensorFlowGreaterEqual ||
+ type == OperatorType::kTensorFlowLess ||
+ type == OperatorType::kTensorFlowLessEqual;
}
const MinMax& GetOrComputeMinMax(Model* model, const string& array_name) {
@@ -256,8 +261,7 @@ bool ChooseHardcodedQuantizationForOperatorOutput(
IsExactlyRepresentable(0., *quantized_data_type, *quantization_params));
return true;
}
- if ((op.type == OperatorType::kLogistic) ||
- (op.type == OperatorType::kSoftmax)) {
+ if (op.type == OperatorType::kLogistic || op.type == OperatorType::kSoftmax) {
// Logistic and Softmax have range: [0, 1].
//
// For Logistic, 0.5 should be exactly representable, as implementations
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_normalization.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_normalization.cc
index 2b3ee36ad1..8f2c1f8162 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_normalization.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_normalization.cc
@@ -134,9 +134,9 @@ bool ResolveBatchNormalization::Run(Model* model, std::size_t op_index) {
}
// Remove the old param arrays
- model->EraseArray(bn_op->inputs[1]);
- model->EraseArray(bn_op->inputs[2]);
- model->EraseArray(bn_op->inputs[3]);
+ DeleteArrayIfUsedOnce(bn_op->inputs[1], model);
+ DeleteArrayIfUsedOnce(bn_op->inputs[2], model);
+ DeleteArrayIfUsedOnce(bn_op->inputs[3], model);
// Remove the old operator
DCHECK_EQ(bn_it->get(), bn_op);
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_padv2_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_padv2_attributes.cc
new file mode 100644
index 0000000000..ebb023e342
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_padv2_attributes.cc
@@ -0,0 +1,55 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool ResolvePadV2Attributes::Run(Model* model, std::size_t op_index) {
+ const auto pad_it = model->operators.begin() + op_index;
+ auto* pad_op = pad_it->get();
+ if (pad_op->type != OperatorType::kPadV2) return false;
+
+ auto* op = static_cast<PadV2Operator*>(pad_op);
+ if (!op->left_padding.empty()) return false;
+
+ CHECK_EQ(op->inputs.size(), 3);
+ if (!IsConstantParameterArray(*model, op->inputs[1])) return false;
+
+ const auto& array = model->GetArray(op->inputs[1]);
+ if (!array.has_shape()) return false;
+
+ const std::vector<int>& dims = array.shape().dims();
+ CHECK_EQ(dims.size(), 2);
+
+ std::vector<int> buffer = array.GetBuffer<ArrayDataType::kInt32>().data;
+
+ for (int i = 0; i < dims[0]; ++i) {
+ op->left_padding.push_back(buffer[i * 2]);
+ op->right_padding.push_back(buffer[i * 2 + 1]);
+ }
+
+ // TODO(dkalenichenko): Delete the extra input?
+
+ return true;
+}
+} // namespace toco
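As a quick illustration of the attribute resolution above: the [n, 2] paddings array is stored row-major, so even offsets in the flattened buffer become left_padding and odd offsets become right_padding. A minimal sketch in plain Python (illustrative only, not part of this commit):

    # paddings has shape [n, 2]; flattened row-major it interleaves (before, after) pairs.
    paddings = [[0, 1], [2, 3]]                      # n = 2 dimensions
    buffer = [v for pair in paddings for v in pair]  # [0, 1, 2, 3]
    left_padding = [buffer[2 * i] for i in range(len(paddings))]       # [0, 2]
    right_padding = [buffer[2 * i + 1] for i in range(len(paddings))]  # [1, 3]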
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc
index 8efe6ab7b9..1eef173afe 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc
@@ -925,6 +925,19 @@ void ConvertPadOperator(const NodeDef& node,
model->operators.emplace_back(op);
}
+void ConvertPadV2Operator(const NodeDef& node,
+ const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
+ CHECK_EQ(node.op(), "PadV2");
+ CheckInputsCount(node, tf_import_flags, 3);
+ auto* op = new PadV2Operator;
+ op->inputs.push_back(node.input(0));
+ op->inputs.push_back(node.input(1));
+ op->inputs.push_back(node.input(2));
+ op->outputs.push_back(node.name());
+ model->operators.emplace_back(op);
+}
+
void ConvertShapeOperator(const NodeDef& node,
const TensorFlowImportFlags& tf_import_flags,
Model* model) {
@@ -1331,6 +1344,19 @@ void ConvertUnsupportedOperator(const NodeDef& node,
}
}
+void ConvertSelectOperator(const NodeDef& node,
+ const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
+ CheckInputsCount(node, tf_import_flags, 3);
+
+ auto* op = new SelectOperator;
+ for (const auto& input : node.input()) {
+ op->inputs.push_back(input);
+ }
+ op->outputs.push_back(node.name());
+ model->operators.emplace_back(op);
+}
+
void ConvertStridedSliceOperator(const NodeDef& node,
const TensorFlowImportFlags& tf_import_flags,
Model* model) {
@@ -2169,6 +2195,8 @@ Status ImportTensorFlowNode(const tensorflow::NodeDef& node,
ConvertMergeOperator(node, tf_import_flags, model);
} else if (node.op() == "Pad") {
ConvertPadOperator(node, tf_import_flags, model);
+ } else if (node.op() == "PadV2") {
+ ConvertPadV2Operator(node, tf_import_flags, model);
} else if (node.op() == "StridedSlice") {
ConvertStridedSliceOperator(node, tf_import_flags, model);
} else if (node.op() == "Shape") {
@@ -2239,6 +2267,8 @@ Status ImportTensorFlowNode(const tensorflow::NodeDef& node,
ConvertDynamicStitchOperator(node, tf_import_flags, model);
} else if (node.op() == "RandomUniform") {
ConvertRandomUniform(node, tf_import_flags, model);
+ } else if (node.op() == "Select") {
+ ConvertSelectOperator(node, tf_import_flags, model);
} else {
ConvertUnsupportedOperator(node, tf_import_flags, model);
}
diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h
index 482cc71d8b..47f8db5978 100644
--- a/tensorflow/contrib/lite/toco/model.h
+++ b/tensorflow/contrib/lite/toco/model.h
@@ -82,6 +82,7 @@ enum class OperatorType {
kStack,
kBatchToSpaceND,
kPad,
+ kPadV2,
kStridedSlice,
kSlice,
kSqueeze,
@@ -132,6 +133,7 @@ enum class OperatorType {
// instead of being given as plain constant arrays. So we need to insert
// special nodes in the graph to shuffle axes.
kReorderAxes,
+ kSelect,
};
// Helper to deal with TensorFlow arrays using a different ordering of
@@ -825,6 +827,29 @@ struct PadOperator : Operator {
std::vector<int> right_padding;
};
+// PaddingV2 operator. Pads a tensor with the given constant value.
+//
+// Inputs:
+// inputs[0]: required: the input array
+// inputs[1]: required: the padding array
+// inputs[2]: required: the scalar constant_values
+//
+// This operation pads input according to the paddings and constant_values you
+// specify. paddings is an integer tensor with shape [Dn, 2], where Dn is the
+// rank of input. For each dimension D of input, paddings[D, 0] indicates how
+// many padding values to add before the contents of input in that dimension,
+// and paddings[D, 1] indicates how many padding values to add after the
+// contents of input in that dimension. constant_values is a scalar tensor of
+// the same type as input that indicates the value to use for padding input.
+//
+// TensorFlow equivalent: PadV2
+struct PadV2Operator : Operator {
+ PadV2Operator() : Operator(OperatorType::kPadV2) {}
+
+ std::vector<int> left_padding;
+ std::vector<int> right_padding;
+};
+
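To make the paddings/constant_values layout concrete, here is a minimal TF 1.x-style Python sketch (illustrative only, not part of this struct; it reuses the tf.pad call already exercised by generate_examples.py):

    import numpy as np
    import tensorflow as tf  # assumes the TF 1.x API used elsewhere in this change

    x = tf.placeholder(tf.float32, shape=[2, 3], name="input")
    paddings = tf.constant([[1, 0],   # dimension 0: pad 1 before, 0 after
                            [0, 2]])  # dimension 1: pad 0 before, 2 after
    y = tf.pad(x, paddings, constant_values=5.0)  # imported as PadV2

    with tf.Session() as sess:
        out = sess.run(y, feed_dict={x: np.ones((2, 3), np.float32)})
        print(out.shape)  # (3, 5): each dimension grows by before + after padding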
// Strided slice operator.
//
// Inputs:
@@ -1063,6 +1088,18 @@ struct NegOperator : Operator {
NegOperator() : Operator(OperatorType::kNeg) {}
};
+// Element-wise select operator choosing elements from inputs[1] or inputs[2].
+//
+// Inputs:
+// inputs[0]: required: boolean mask per index
+// inputs[1]: required: tensor of values if true
+// inputs[2]: required: tensor of values if false
+//
+// TensorFlow equivalent: Select
+struct SelectOperator : Operator {
+ SelectOperator() : Operator(OperatorType::kSelect) {}
+};
+
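The element-wise selection this models can be sketched with the TF 1.x Python API (illustrative only, not part of this commit); a three-argument tf.where over same-shaped tensors is what the importer maps to Select:

    import tensorflow as tf  # assumes the TF 1.x API used elsewhere in this change

    cond = tf.constant([[True, False], [False, True]])  # inputs[0]: boolean mask
    a = tf.constant([[1, 2], [3, 4]])                    # inputs[1]: taken where mask is True
    b = tf.constant([[10, 20], [30, 40]])                # inputs[2]: taken where mask is False
    out = tf.where(cond, a, b)

    with tf.Session() as sess:
        print(sess.run(out))  # [[ 1 20]
                              #  [30  4]]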
// Element-wise reciprocal-square-root (x^-0.5) operator.
//
// Inputs:
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc
index e18ae805c0..90e24aa104 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator.cc
@@ -465,6 +465,21 @@ class Pad : public BuiltinOperator<PadOperator, ::tflite::PadOptions,
TocoOperator* op) const override {}
};
+class PadV2 : public BuiltinOperator<PadV2Operator, ::tflite::PadV2Options,
+ ::tflite::BuiltinOptions_PadV2Options> {
+ public:
+ using BuiltinOperator::BuiltinOperator;
+
+ flatbuffers::Offset<TfLiteOptions> WriteOptions(
+ const TocoOperator& op,
+ flatbuffers::FlatBufferBuilder* builder) const override {
+ return ::tflite::CreatePadV2Options(*builder);
+ }
+
+ void ReadOptions(const TfLiteOptions& options,
+ TocoOperator* op) const override {}
+};
+
class Reshape
: public BuiltinOperator<TensorFlowReshapeOperator,
::tflite::ReshapeOptions,
@@ -832,6 +847,8 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() {
OperatorType::kMaxPool));
ops.emplace_back(new Mul(::tflite::BuiltinOperator_MUL, OperatorType::kMul));
ops.emplace_back(new Pad(::tflite::BuiltinOperator_PAD, OperatorType::kPad));
+ ops.emplace_back(
+ new PadV2(::tflite::BuiltinOperator_PADV2, OperatorType::kPadV2));
ops.emplace_back(new Reshape(::tflite::BuiltinOperator_RESHAPE,
OperatorType::kTensorFlowReshape));
ops.emplace_back(
@@ -898,9 +915,18 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() {
"MAXIMUM", OperatorType::kTensorFlowMaximum));
ops.emplace_back(new SimpleOperator<TensorFlowMinimumOperator>(
"MINIMUM", OperatorType::kTensorFlowMinimum));
+ ops.emplace_back(new SimpleOperator<TensorFlowGreaterOperator>(
+ "GREATER", OperatorType::kTensorFlowGreater));
+ ops.emplace_back(new SimpleOperator<TensorFlowGreaterEqualOperator>(
+ "GREATER_EQUAL", OperatorType::kTensorFlowGreaterEqual));
ops.emplace_back(new SimpleOperator<TensorFlowLessOperator>(
"LESS", OperatorType::kTensorFlowLess));
+ ops.emplace_back(new SimpleOperator<TensorFlowLessEqualOperator>(
+ "LESS_EQUAL", OperatorType::kTensorFlowLessEqual));
ops.emplace_back(new SimpleOperator<NegOperator>("NEG", OperatorType::kNeg));
+ ops.emplace_back(
+ new SimpleOperator<SelectOperator>("SELECT", OperatorType::kSelect));
+
return ops;
}
} // namespace
diff --git a/tensorflow/contrib/lite/toco/tflite/operator_test.cc b/tensorflow/contrib/lite/toco/tflite/operator_test.cc
index 2b6c32b07c..a4fff9974a 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator_test.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator_test.cc
@@ -116,6 +116,7 @@ TEST_F(OperatorTest, SimpleOperators) {
CheckSimpleOperator<TensorFlowLessOperator>("LESS",
OperatorType::kTensorFlowLess);
CheckSimpleOperator<NegOperator>("NEG", OperatorType::kNeg);
+ CheckSimpleOperator<SelectOperator>("SELECT", OperatorType::kSelect);
}
TEST_F(OperatorTest, BuiltinAdd) {
diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc
index 6973b22c5a..58c99051bd 100644
--- a/tensorflow/contrib/lite/toco/toco_tooling.cc
+++ b/tensorflow/contrib/lite/toco/toco_tooling.cc
@@ -106,6 +106,7 @@ void MakeGeneralGraphTransformationsSet(
transformations->Add(new ResolveSpaceToBatchNDAttributes);
transformations->Add(new ResolveBatchToSpaceNDAttributes);
transformations->Add(new ResolvePadAttributes);
+ transformations->Add(new ResolvePadV2Attributes);
transformations->Add(new ResolveStridedSliceAttributes);
transformations->Add(new ResolveSliceAttributes);
transformations->Add(new ResolveMeanAttributes);
diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc
index 86ee1f3761..1f56fe5c83 100644
--- a/tensorflow/contrib/lite/toco/tooling_util.cc
+++ b/tensorflow/contrib/lite/toco/tooling_util.cc
@@ -143,6 +143,10 @@ int CountOpsWithInput(const Model& model, const string& array_name) {
for (auto& input : op->inputs) {
if (input == array_name) {
count++;
+ // Breaking here is important: some graphs have ops that use the
+ // same array as more than one of their inputs, and in that case
+ // we want it counted only once.
+ break;
}
}
}
@@ -352,6 +356,7 @@ const char* OperatorTypeName(OperatorType type) {
HANDLE_OPERATORTYPENAME_CASE(TensorFlowMinimum)
HANDLE_OPERATORTYPENAME_CASE(Neg)
HANDLE_OPERATORTYPENAME_CASE(Pad)
+ HANDLE_OPERATORTYPENAME_CASE(PadV2)
HANDLE_OPERATORTYPENAME_CASE(StridedSlice)
HANDLE_OPERATORTYPENAME_CASE(Stack)
HANDLE_OPERATORTYPENAME_CASE(Range)
@@ -386,6 +391,7 @@ const char* OperatorTypeName(OperatorType type) {
HANDLE_OPERATORTYPENAME_CASE(Exp)
HANDLE_OPERATORTYPENAME_CASE(DynamicPartition)
HANDLE_OPERATORTYPENAME_CASE(DynamicStitch)
+ HANDLE_OPERATORTYPENAME_CASE(Select)
default:
LOG(FATAL) << "Unhandled op type";
#undef HANDLE_OPERATORTYPENAME_CASE
@@ -2092,6 +2098,8 @@ ArrayDataType ConvertIODataTypeToArrayDataType(IODataType type) {
return ArrayDataType::kInt32;
case INT64:
return ArrayDataType::kInt64;
+ case BOOL:
+ return ArrayDataType::kBool;
default:
return ArrayDataType::kNone;
}
diff --git a/tensorflow/contrib/lite/toco/types.proto b/tensorflow/contrib/lite/toco/types.proto
index 03bd6150bc..421667a83c 100644
--- a/tensorflow/contrib/lite/toco/types.proto
+++ b/tensorflow/contrib/lite/toco/types.proto
@@ -37,4 +37,7 @@ enum IODataType {
// Int16, quantized
QUANTIZED_INT16 = 6;
+
+ // Boolean
+ BOOL = 7;
}
diff --git a/tensorflow/contrib/model_pruning/python/pruning.py b/tensorflow/contrib/model_pruning/python/pruning.py
index ea6032e588..4b7af18b33 100644
--- a/tensorflow/contrib/model_pruning/python/pruning.py
+++ b/tensorflow/contrib/model_pruning/python/pruning.py
@@ -396,14 +396,19 @@ class Pruning(object):
self._block_pooling_function)
with ops.name_scope(weights.op.name + '_pruning_ops'):
- abs_weights = math_ops.abs(
- array_ops.reshape(weights, [
- 1,
- squeezed_weights.get_shape()[0],
- squeezed_weights.get_shape()[1], 1
- ]))
+ abs_weights = math_ops.abs(squeezed_weights)
+
pool_window = [self._block_dim[0], self._block_dim[1]]
- pooled_weights = nn_ops.pool(
+ pool_fn = pruning_utils.factorized_pool
+
+ if not self._spec.use_tpu:
+ pool_fn = nn_ops.pool
+ abs_weights = array_ops.reshape(
+ abs_weights,
+ [1, abs_weights.get_shape()[0],
+ abs_weights.get_shape()[1], 1])
+
+ pooled_weights = pool_fn(
abs_weights,
window_shape=pool_window,
pooling_type=self._block_pooling_function,
@@ -411,19 +416,18 @@ class Pruning(object):
padding='SAME',
name=weights.op.name + '_pooled')
+ if pooled_weights.get_shape().ndims != 2:
+ pooled_weights = array_ops.squeeze(pooled_weights)
+
smoothed_threshold, new_mask = self._update_mask(pooled_weights,
threshold)
-
- reshaped_mask = array_ops.reshape(
- new_mask,
- [pooled_weights.get_shape()[1],
- pooled_weights.get_shape()[2]])
updated_mask = pruning_utils.kronecker_product(
- reshaped_mask, array_ops.ones(self._block_dim))
+ new_mask, array_ops.ones(self._block_dim))
sliced_mask = array_ops.slice(
updated_mask, [0, 0],
[squeezed_weights.get_shape()[0],
squeezed_weights.get_shape()[1]])
+
return smoothed_threshold, array_ops.reshape(sliced_mask,
array_ops.shape(weights))
diff --git a/tensorflow/contrib/model_pruning/python/pruning_utils.py b/tensorflow/contrib/model_pruning/python/pruning_utils.py
index 56d3dcef20..ef6c6a3f5d 100644
--- a/tensorflow/contrib/model_pruning/python/pruning_utils.py
+++ b/tensorflow/contrib/model_pruning/python/pruning_utils.py
@@ -29,6 +29,7 @@ from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
@@ -221,6 +222,56 @@ def compute_cdf(values, value_range, **kwargs):
return math_ops.div(cdf, math_ops.reduce_max(cdf))
+def factorized_pool(input_tensor,
+ window_shape,
+ pooling_type,
+ strides,
+ padding,
+ name=None):
+ """Performs m x n pooling through a combination of 1xm and 1xn pooling.
+
+ Args:
+ input_tensor: Input tensor. Must be rank 2
+ window_shape: Pooling window shape
+ pooling_type: Either 'MAX' or 'AVG'
+ strides: The stride of the pooling window
+ padding: 'SAME' or 'VALID'.
+ name: Name of the op
+
+ Returns:
+ A rank 2 tensor containing the pooled output
+
+ Raises:
+ ValueError: if the input tensor is not rank 2
+ """
+ if input_tensor.get_shape().ndims != 2:
+ raise ValueError('factorized_pool() accepts tensors of rank 2 only')
+
+ [height, width] = input_tensor.get_shape()
+ with ops.name_scope(name, 'factorized_pool'):
+ input_tensor_aligned = array_ops.reshape(
+ input_tensor, [1, 1, height, width],
+ name=input_tensor.op.name + '_aligned')
+
+ height_pooling = nn_ops.pool(
+ input_tensor_aligned,
+ window_shape=[1, window_shape[0]],
+ pooling_type=pooling_type,
+ strides=[1, strides[0]],
+ padding=padding)
+ swap_height_width = array_ops.transpose(height_pooling, perm=[0, 1, 3, 2])
+
+ width_pooling = nn_ops.pool(
+ swap_height_width,
+ window_shape=[1, window_shape[1]],
+ pooling_type=pooling_type,
+ strides=[1, strides[1]],
+ padding=padding)
+
+ return array_ops.squeeze(
+ array_ops.transpose(width_pooling, perm=[0, 1, 3, 2]))
+
+
def determine_partitioned_axis(partitioned_variable):
partitioned_axis = 0
concatenated_variable_shape = partitioned_variable.get_shape()
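For reference, a minimal usage sketch of the new factorized_pool helper (a TF 1.x graph/session environment is assumed, and the variable name and shape are illustrative): a 2x4 AVG pool over a rank-2 [64, 128] weight matrix yields a [32, 32] result, matching what nn_ops.pool would produce on the equivalent reshaped 4-D input, as the new tests below verify.

    import tensorflow as tf
    from tensorflow.contrib.model_pruning.python import pruning_utils

    # factorized_pool() requires a rank-2 input.
    weights = tf.get_variable("weights_sketch", shape=[64, 128])
    abs_weights = tf.abs(weights)

    # A 2x4 AVG pool, performed internally as a 1x2 pass followed by a 1x4 pass.
    pooled = pruning_utils.factorized_pool(
        abs_weights,
        window_shape=[2, 4],
        pooling_type="AVG",
        strides=[2, 4],
        padding="SAME")

    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      print(sess.run(pooled).shape)  # (32, 32)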
diff --git a/tensorflow/contrib/model_pruning/python/pruning_utils_test.py b/tensorflow/contrib/model_pruning/python/pruning_utils_test.py
index 10e1dd0a8e..ccde5b4e8a 100644
--- a/tensorflow/contrib/model_pruning/python/pruning_utils_test.py
+++ b/tensorflow/contrib/model_pruning/python/pruning_utils_test.py
@@ -22,8 +22,10 @@ import numpy as np
from tensorflow.contrib.model_pruning.python import pruning_utils
from tensorflow.python.framework import constant_op
+from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
@@ -31,6 +33,30 @@ from tensorflow.python.platform import test
class PruningUtilsTest(test.TestCase):
+ def _compare_cdf(self, values):
+ abs_values = math_ops.abs(values)
+ max_value = math_ops.reduce_max(abs_values)
+ with self.test_session():
+ variables.global_variables_initializer().run()
+ cdf_from_histogram = pruning_utils.compute_cdf_from_histogram(
+ abs_values, [0.0, max_value], nbins=pruning_utils._NBINS)
+ cdf = pruning_utils.compute_cdf(abs_values, [0.0, max_value])
+ self.assertAllEqual(cdf.eval(), cdf_from_histogram.eval())
+
+ def _compare_pooling_methods(self, weights, pooling_kwargs):
+ with self.test_session():
+ variables.global_variables_initializer().run()
+ pooled_weights_tf = array_ops.squeeze(
+ nn_ops.pool(
+ array_ops.reshape(
+ weights,
+ [1, weights.get_shape()[0],
+ weights.get_shape()[1], 1]), **pooling_kwargs))
+ pooled_weights_factorized_pool = pruning_utils.factorized_pool(
+ weights, **pooling_kwargs)
+ self.assertAllClose(pooled_weights_tf.eval(),
+ pooled_weights_factorized_pool.eval())
+
def testHistogram(self):
width = 10
height = 10
@@ -59,27 +85,35 @@ class PruningUtilsTest(test.TestCase):
self.assertAllEqual(len(norm_cdf_val), nbins)
self.assertAllEqual(expected_cdf, norm_cdf_val)
- def _compare_cdf(self, values):
- abs_values = math_ops.abs(values)
- max_value = math_ops.reduce_max(abs_values)
- with self.test_session():
- variables.global_variables_initializer().run()
- cdf_from_histogram = pruning_utils.compute_cdf_from_histogram(
- abs_values, [0.0, max_value], nbins=pruning_utils._NBINS)
- cdf = pruning_utils.compute_cdf(abs_values, [0.0, max_value])
- return cdf.eval(), cdf_from_histogram.eval()
-
def testCDFEquivalence2D(self):
width = 100
height = 100
weights = variable_scope.get_variable("weights", shape=[width, height])
- cdf_val, cdf_from_histogram_val = self._compare_cdf(weights)
- self.assertAllEqual(cdf_val, cdf_from_histogram_val)
+ self._compare_cdf(weights)
def testCDFEquivalence4D(self):
weights = variable_scope.get_variable("weights", shape=[5, 5, 128, 128])
- cdf_val, cdf_from_histogram_val = self._compare_cdf(weights)
- self.assertAllEqual(cdf_val, cdf_from_histogram_val)
+ self._compare_cdf(weights)
+
+ def testFactorizedAvgPool(self):
+ weights = variable_scope.get_variable("weights", shape=[1024, 2048])
+ pooling_kwargs = {
+ "window_shape": [2, 4],
+ "pooling_type": "AVG",
+ "strides": [2, 4],
+ "padding": "SAME"
+ }
+ self._compare_pooling_methods(weights, pooling_kwargs)
+
+ def testFactorizedMaxPool(self):
+ weights = variable_scope.get_variable("weights", shape=[1024, 2048])
+ pooling_kwargs = {
+ "window_shape": [2, 4],
+ "pooling_type": "MAX",
+ "strides": [2, 4],
+ "padding": "SAME"
+ }
+ self._compare_pooling_methods(weights, pooling_kwargs)
if __name__ == "__main__":
diff --git a/tensorflow/contrib/tpu/BUILD b/tensorflow/contrib/tpu/BUILD
index 0bdf6f64c9..f84ff1bfe9 100644
--- a/tensorflow/contrib/tpu/BUILD
+++ b/tensorflow/contrib/tpu/BUILD
@@ -181,6 +181,7 @@ py_library(
":datasets",
":profiler",
":tpu_py",
+ "//tensorflow/contrib/tpu/proto:compilation_result_proto_py",
"//tensorflow/contrib/tpu/proto:topology_proto_py",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
diff --git a/tensorflow/contrib/tpu/ops/replication_ops.cc b/tensorflow/contrib/tpu/ops/replication_ops.cc
index 3bdf7c2f83..defed00537 100644
--- a/tensorflow/contrib/tpu/ops/replication_ops.cc
+++ b/tensorflow/contrib/tpu/ops/replication_ops.cc
@@ -64,6 +64,10 @@ REGISTER_OP("TPUReplicatedOutput")
"Operator that connects the output of an N-way replicated TPU "
"computation to N separate outputs.");
+REGISTER_OP("TPUCompilationResult")
+ .Output("output: string")
+ .SetShapeFn(shape_inference::ScalarShape);
+
REGISTER_OP("TPUReplicate")
.Attr("computation: func")
.Attr("num_replicas: int >= 1")
diff --git a/tensorflow/contrib/tpu/proto/BUILD b/tensorflow/contrib/tpu/proto/BUILD
index fcfbbe1a21..7ecb36852c 100644
--- a/tensorflow/contrib/tpu/proto/BUILD
+++ b/tensorflow/contrib/tpu/proto/BUILD
@@ -21,3 +21,13 @@ tf_proto_library(
cc_api_version = 2,
visibility = ["//visibility:public"],
)
+
+tf_proto_library(
+ name = "compilation_result_proto",
+ srcs = [
+ "compilation_result.proto",
+ ],
+ cc_api_version = 2,
+ protodeps = ["//tensorflow/core:protos_all"],
+ visibility = ["//visibility:public"],
+)
diff --git a/tensorflow/contrib/tpu/proto/compilation_result.proto b/tensorflow/contrib/tpu/proto/compilation_result.proto
new file mode 100644
index 0000000000..cf52897de3
--- /dev/null
+++ b/tensorflow/contrib/tpu/proto/compilation_result.proto
@@ -0,0 +1,13 @@
+syntax = "proto3";
+
+option cc_enable_arenas = true;
+package tensorflow.tpu;
+
+import "tensorflow/core/lib/core/error_codes.proto";
+
+// Describes the result of a TPU compilation.
+message CompilationResultProto {
+  // The error code and message, if any, returned during compilation.
+ error.Code status_code = 1;
+ string status_error_message = 2;
+}
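As a hedged sketch of how this message might be consumed on the client side (the generated module name compilation_result_pb2 follows the usual proto naming convention but is an assumption here, as are the session and op names in the comments): the new TPUCompilationResult op emits a serialized CompilationResultProto, which can be parsed to inspect the compilation status.

    from tensorflow.contrib.tpu.proto import compilation_result_pb2

    def parse_compilation_result(serialized_proto):
      """Parses the string emitted by the TPUCompilationResult op."""
      result = compilation_result_pb2.CompilationResultProto()
      result.ParseFromString(serialized_proto)
      return result

    # Typical use, with `sess` an open tf.Session and `compile_op` the first
    # element returned by split_compile_and_replicate() (see tpu.py below):
    #   result = parse_compilation_result(sess.run(compile_op))
    #   if result.status_error_message:
    #     print('TPU compilation failed:', result.status_error_message)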
diff --git a/tensorflow/contrib/tpu/python/tpu/session_support.py b/tensorflow/contrib/tpu/python/tpu/session_support.py
index 3455e0b4a6..faf677a81d 100644
--- a/tensorflow/contrib/tpu/python/tpu/session_support.py
+++ b/tensorflow/contrib/tpu/python/tpu/session_support.py
@@ -28,6 +28,7 @@ from tensorflow.core.protobuf import config_pb2
from tensorflow.core.util import event_pb2
from tensorflow.python.client import session as session_lib
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import tf_logging as logging
@@ -78,6 +79,15 @@ class WorkerHeartbeatManager(object):
return WorkerHeartbeatManager(session, devices, heartbeat_ops,
request_placeholder)
+ def heartbeat_supported(self):
+ """Returns True if heartbeat operations are supported on all workers."""
+ try:
+ # Send ping to verify worker has heartbeat support.
+ self.ping()
+ return True
+ except errors.InvalidArgumentError as _:
+ return False
+
def configure(self, message):
"""Configure heartbeat manager for all devices.
@@ -106,7 +116,7 @@ class WorkerHeartbeatManager(object):
event_pb2.WorkerHeartbeatResponse.FromString(res_pb)
for res_pb in results
]
- logging.info('Results: %s', parsed_results)
+ logging.debug('Ping results: %s', parsed_results)
return parsed_results
def lame_workers(self):
@@ -189,7 +199,9 @@ class WatchdogManager(threading.Thread):
self._running = False
self._graph = ops.Graph()
self._session = session_lib.Session(
- target=session.sess_str, graph=self._graph)
+ target=session.sess_str,
+ graph=self._graph,
+ )
with self._graph.as_default():
if devices is None:
@@ -249,6 +261,7 @@ class GracefulShutdownHook(session_run_hook.SessionRunHook):
self._graph = ops.Graph()
self._workers = None
self._session = None
+ self._heartbeat_supported = False
def after_create_session(self, training_session, coord): # pylint: disable=unused-argument
# N.B. We have to pull the global step here to avoid it being unavailable
@@ -264,10 +277,16 @@ class GracefulShutdownHook(session_run_hook.SessionRunHook):
target=training_session.sess_str, graph=self._graph)
self._workers = WorkerHeartbeatManager.from_devices(
self._session, all_worker_devices(self._session))
-
- self._workers.configure(
- event_pb2.WorkerHeartbeatRequest(
- shutdown_mode=event_pb2.WAIT_FOR_COORDINATOR))
+ self._heartbeat_supported = self._workers.heartbeat_supported()
+ if self._heartbeat_supported:
+ self._workers.configure(
+ event_pb2.WorkerHeartbeatRequest(
+ shutdown_mode=event_pb2.WAIT_FOR_COORDINATOR))
+ else:
+ logging.warn(
+ 'Worker heartbeats not supported by all workers. No failure '
+ 'handling will be enabled.'
+ )
def saver(self):
if self._saver:
@@ -286,6 +305,9 @@ class GracefulShutdownHook(session_run_hook.SessionRunHook):
def after_run(self, run_context, run_values):
del run_values
+ if not self._heartbeat_supported:
+ return
+
lame_workers = self._workers.lame_workers()
if lame_workers:
logging.info('ShutdownHook: lame workers found: %s', lame_workers)
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py
index 7b8786304c..c8f24ed01d 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu.py
@@ -58,6 +58,7 @@ _NOT_IMPLEMENTED_OPS = set([
_MAX_WARNING_LINES = 5
_TPU_REPLICATE_ATTR = "_tpu_replicate"
+_TPU_COMPILATION_STATUS_ATTR = "_tpu_compilation_status"
_OUTSIDE_COMPILATION_ATTR = "_xla_outside_compilation"
@@ -385,6 +386,45 @@ def replicate(computation,
ValueError: If the number of inputs per replica does not match
the number of formal parameters to `computation`.
"""
+ return split_compile_and_replicate(computation, inputs, infeed_queue,
+ device_assignment, name)[1]
+
+
+def split_compile_and_replicate(computation,
+ inputs=None,
+ infeed_queue=None,
+ device_assignment=None,
+ name=None):
+ """Builds graph operators that runs compilation and replicated computation.
+
+  This is a lower-level interface than replicate that returns the compile op
+  and the execute outputs separately. In the generated graph the compile op feeds into
+ the execute op and no additional compilation is incurred when running the
+ compile op before the execute op. The compile op returns additional
+ information about the compilation but does not return the compiled program.
+
+ Args:
+ computation: A Python function that builds the computation to replicate.
+ inputs: A list of lists of input tensors or `None` (equivalent to
+ `[[]]`), indexed by `[replica_num][input_num]`. All replicas must
+ have the same number of inputs.
+ infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple
+ of arguments as inputs to computation.
+ device_assignment: If not `None`, a `DeviceAssignment` describing the
+ mapping between logical cores in the computation with physical cores in
+ the TPU topology. Uses a default device assignment if `None`. The
+ `DeviceAssignment` may be omitted if each replica of the computation uses
+ only one core, and there is either only one replica, or the number of
+ replicas is equal to the number of cores in the TPU system.
+ name: (Deprecated) Does nothing.
+ Returns:
+    A two-element list: the first element is the compile op and the second is a
+    list of output tensors, indexed by `[replica_num][output_num]`.
+ Raises:
+ ValueError: If all replicas do not have equal numbers of input tensors.
+ ValueError: If the number of inputs per replica does not match
+ the number of formal parameters to `computation`.
+ """
del name
inputs = [[]] if inputs is None else inputs
@@ -456,8 +496,8 @@ def replicate(computation,
computation_inputs.append(
tpu_ops.tpu_replicated_input(replicas, name="input{}".format(i)))
- context = TPUReplicateContext(
- name=graph.unique_name("cluster"), num_replicas=num_replicas)
+ cluster_name = graph.unique_name("cluster")
+ context = TPUReplicateContext(name=cluster_name, num_replicas=num_replicas)
try:
context.Enter()
@@ -516,8 +556,7 @@ def replicate(computation,
# Separates the returned Operations and Tensors.
output_operations = [o for o in outputs if isinstance(o, ops.Operation)]
- output_tensors = [o for o in outputs
- if not isinstance(o, ops.Operation)]
+ output_tensors = [o for o in outputs if not isinstance(o, ops.Operation)]
if outputs != output_tensors + output_operations:
raise ValueError(
@@ -550,22 +589,33 @@ def replicate(computation,
name="output{}".format(i))
for i in xrange(output_arity)]
+ with ops.control_dependencies([metadata]):
+ compile_status = tpu_ops.tpu_compilation_result()
+ op = compile_status.op
+ attr_value = attr_value_pb2.AttrValue(s=compat.as_bytes(cluster_name))
+ op._set_attr(_TPU_COMPILATION_STATUS_ATTR, attr_value) # pylint: disable=protected-access
+
with ops.control_dependencies(output_operations):
if output_arity == 0:
# Returns a list of NoOps dependent on the replication Op, indexed by
# [replica_num].
return [
- control_flow_ops.no_op(name="shard_%d" % i)
- for i in range(num_replicas)
+ compile_status, [
+ control_flow_ops.no_op(name="shard_%d" % i)
+ for i in range(num_replicas)
+ ]
]
else:
# Wraps the outputs in identity operators so the names of any possible
# `fetch` nodes are preserved by the replication rewrite.
return [
- [array_ops.identity(outputs[out][replica],
- name="output_%d_shard_%d" % (out, replica))
- for out in xrange(output_arity)]
- for replica in xrange(num_replicas)
+ compile_status, [[
+ array_ops.identity(
+ outputs[out][replica],
+ name="output_%d_shard_%d" % (out, replica))
+ for out in xrange(output_arity)
+ ]
+ for replica in xrange(num_replicas)]
]
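To make the new return contract concrete, here is a minimal, hedged usage sketch (the computation, inputs, and direct module import are illustrative assumptions, and a real program additionally needs an initialized TPU system):

    import tensorflow as tf
    from tensorflow.contrib.tpu.python.tpu import tpu

    def computation(x):
      return x * 2.0

    # One replica with a single input tensor.
    inputs = [[tf.constant(1.0)]]
    compile_op, replica_outputs = tpu.split_compile_and_replicate(
        computation, inputs=inputs)

    # compile_op can be run first to surface compilation failures early; since
    # it feeds into the execute op, evaluating replica_outputs[0] afterwards
    # does not trigger a second compilation.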
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
index a69bfa9a20..a624eceed9 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
@@ -175,17 +175,7 @@ class _SIGNAL(object):
STOP = -2
-class TPUEstimatorSpec(
- collections.namedtuple('TPUEstimatorSpec', [
- 'mode',
- 'predictions',
- 'loss',
- 'train_op',
- 'eval_metrics',
- 'export_outputs',
- 'scaffold_fn',
- 'host_call'
- ])):
+class TPUEstimatorSpec(model_fn_lib._TPUEstimatorSpec): # pylint: disable=protected-access
"""Ops and objects returned from a `model_fn` and passed to `TPUEstimator`.
See `EstimatorSpec` for `mode`, 'predictions, 'loss', 'train_op', and
@@ -1156,7 +1146,7 @@ class _ModelFnWrapper(object):
self._call_model_fn(features, labels))
loss, train_op = estimator_spec.loss, estimator_spec.train_op
- if isinstance(estimator_spec, TPUEstimatorSpec):
+ if isinstance(estimator_spec, model_fn_lib._TPUEstimatorSpec): # pylint: disable=protected-access
captured_scaffold_fn.capture(estimator_spec.scaffold_fn)
else:
captured_scaffold_fn.capture(None)
@@ -1165,8 +1155,8 @@ class _ModelFnWrapper(object):
# outfeed.
with ops.control_dependencies([train_op]):
host_call_outfeed_ops = []
- if (isinstance(estimator_spec, TPUEstimatorSpec) and
- estimator_spec.host_call is not None):
+ if (isinstance(estimator_spec, model_fn_lib._TPUEstimatorSpec) # pylint: disable=protected-access
+ and estimator_spec.host_call is not None):
host_call.record({'host_call': estimator_spec.host_call})
host_call_outfeed_ops = host_call.create_enqueue_op()
with ops.control_dependencies(host_call_outfeed_ops):
@@ -1209,7 +1199,7 @@ class _ModelFnWrapper(object):
features, labels = inputs.features_and_labels()
tpu_estimator_spec = self._call_model_fn(features, labels)
- if not isinstance(tpu_estimator_spec, TPUEstimatorSpec):
+ if not isinstance(tpu_estimator_spec, model_fn_lib._TPUEstimatorSpec): # pylint: disable=protected-access
raise RuntimeError(
'estimator_spec used by TPU evaluation must have type'
'`TPUEstimatorSpec`. Got {}'.format(type(tpu_estimator_spec)))
@@ -1254,7 +1244,7 @@ class _ModelFnWrapper(object):
tpu_estimator_spec = self._call_model_fn(
features, labels, is_export_mode=False)
- if not isinstance(tpu_estimator_spec, TPUEstimatorSpec):
+ if not isinstance(tpu_estimator_spec, model_fn_lib._TPUEstimatorSpec): # pylint: disable=protected-access
raise RuntimeError(
'estimator_spec used by TPU prediction must have type'
'`TPUEstimatorSpec`. Got {}'.format(type(tpu_estimator_spec)))
@@ -1316,7 +1306,7 @@ class _ModelFnWrapper(object):
estimator_spec = self._model_fn(features=features, **kwargs)
if (self._ctx.is_running_on_cpu(is_export_mode) and
- isinstance(estimator_spec, TPUEstimatorSpec)):
+ isinstance(estimator_spec, model_fn_lib._TPUEstimatorSpec)): # pylint: disable=protected-access
# The estimator_spec will be passed to `Estimator` directly, which expects
# type `EstimatorSpec`.
return estimator_spec.as_estimator_spec()
@@ -1325,7 +1315,7 @@ class _ModelFnWrapper(object):
def _verify_estimator_spec(self, estimator_spec):
"""Validates the estimator_spec."""
- if isinstance(estimator_spec, TPUEstimatorSpec):
+ if isinstance(estimator_spec, model_fn_lib._TPUEstimatorSpec): # pylint: disable=protected-access
return estimator_spec
err_msg = '{} returned by EstimatorSpec is not supported in TPUEstimator.'
diff --git a/tensorflow/core/api_def/base_api/api_def_MapAndBatchDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_MapAndBatchDataset.pbtxt
index bf544703de..e230c51edf 100644
--- a/tensorflow/core/api_def/base_api/api_def_MapAndBatchDataset.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_MapAndBatchDataset.pbtxt
@@ -1,5 +1,19 @@
op {
graph_op_name: "MapAndBatchDataset"
+ visibility: HIDDEN
+ in_arg {
+ name: "input_dataset"
+ description: <<END
+A variant tensor representing the input dataset.
+END
+ }
+ in_arg {
+ name: "other_arguments"
+ description: <<END
+A list of tensors, typically values that were captured when building a closure
+for `f`.
+END
+ }
in_arg {
name: "batch_size"
description: <<END
@@ -11,13 +25,26 @@ END
in_arg {
name: "num_parallel_batches"
description: <<END
-A scalar representing the number of batches to create in
-parallel. Processing multiple batches in parallel benefits workloads prone to
-stragglers.
+A scalar representing the number of batches to create in parallel. Processing
+multiple batches in parallel benefits workloads prone to stragglers.
+END
+ }
+ in_arg {
+ name: "drop_remainder"
+ description: <<END
+A scalar representing whether the last batch should be dropped in case its size
+is smaller than desired.
+END
+ }
+ attr {
+ name: "f"
+ description: <<END
+A function to apply to the outputs of `input_dataset`.
END
}
- summary: "Creates a dataset that applies `f` to the outputs of `input_dataset` and then"
+ summary: "Creates a dataset that fuses mapping with batching."
description: <<END
+Creates a dataset that applies `f` to the outputs of `input_dataset` and then
batches `batch_size` of them.
Unlike a "MapDataset", which applies `f` sequentially, this dataset invokes up
diff --git a/tensorflow/core/api_def/base_api/api_def_MapAndBatchDatasetV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_MapAndBatchDatasetV2.pbtxt
new file mode 100644
index 0000000000..81ef92cae0
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_MapAndBatchDatasetV2.pbtxt
@@ -0,0 +1,54 @@
+op {
+ graph_op_name: "MapAndBatchDatasetV2"
+ visibility: HIDDEN
+ in_arg {
+ name: "input_dataset"
+ description: <<END
+A variant tensor representing the input dataset.
+END
+ }
+ in_arg {
+ name: "other_arguments"
+ description: <<END
+A list of tensors, typically values that were captured when building a closure
+for `f`.
+END
+ }
+ in_arg {
+ name: "batch_size"
+ description: <<END
+A scalar representing the number of elements to accumulate in a
+batch. It determines the number of concurrent invocations of `f` that process
+elements from `input_dataset` in parallel.
+END
+ }
+ in_arg {
+ name: "num_parallel_calls"
+ description: <<END
+A scalar representing the maximum number of parallel invocations of the `map_fn`
+function. Applying the `map_fn` on consecutive input elements in parallel has
+the potential to improve input pipeline throughput.
+END
+ }
+ in_arg {
+ name: "drop_remainder"
+ description: <<END
+A scalar representing whether the last batch should be dropped in case its size
+is smaller than desired.
+END
+ }
+ attr {
+ name: "f"
+ description: <<END
+A function to apply to the outputs of `input_dataset`.
+END
+ }
+ summary: "Creates a dataset that fuses mapping with batching."
+ description: <<END
+Creates a dataset that applies `f` to the outputs of `input_dataset` and then
+batches `batch_size` of them.
+
+Unlike a "MapDataset", which applies `f` sequentially, this dataset invokes up
+to `batch_size * num_parallel_batches` copies of `f` in parallel.
+END
+}
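For orientation, this fused behavior is normally reached through the tf.contrib.data.map_and_batch transformation rather than the raw op; a hedged usage sketch (the argument names follow the contrib signature of this era and are not part of this diff):

    import tensorflow as tf

    dataset = tf.data.Dataset.range(100)
    # Fuses the map and batch steps; up to batch_size * num_parallel_batches
    # invocations of the map function may run concurrently.
    dataset = dataset.apply(
        tf.contrib.data.map_and_batch(
            map_func=lambda x: x * 2,
            batch_size=16,
            num_parallel_batches=4))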
diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc
index e389eb9b2a..7d63626b95 100644
--- a/tensorflow/core/common_runtime/executor.cc
+++ b/tensorflow/core/common_runtime/executor.cc
@@ -272,9 +272,9 @@ struct NodeItem {
// (uint8 is enough for DataType).
// EdgeInfo out_edges[num_out_edges];
// AllocatorAttributes output_attr[num_outputs];
+ // int forward_from[num_outputs];
// uint8 input_type[num_inputs];
// uint8 output_type[num_outputs];
- // int forward_from[num_outputs];
// Return pointer to variable length section.
char* var() const {
@@ -289,22 +289,20 @@ struct NodeItem {
return reinterpret_cast<AllocatorAttributes*>(var() + sizeof(EdgeInfo) *
num_output_edges);
}
+ int* forward_from_base() const {
+ return reinterpret_cast<int*>(var() + sizeof(EdgeInfo) * num_output_edges +
+ sizeof(AllocatorAttributes) * num_outputs);
+ }
uint8* input_type_base() const {
- return reinterpret_cast<uint8*>(var() +
- sizeof(EdgeInfo) * num_output_edges +
- sizeof(AllocatorAttributes) * num_outputs);
+ return reinterpret_cast<uint8*>(
+ var() + sizeof(EdgeInfo) * num_output_edges +
+ sizeof(AllocatorAttributes) * num_outputs + sizeof(int) * num_outputs);
}
uint8* output_type_base() const {
return reinterpret_cast<uint8*>(
var() + sizeof(EdgeInfo) * num_output_edges +
- sizeof(AllocatorAttributes) * num_outputs + sizeof(uint8) * num_inputs);
- }
-
- int* forward_from_base() const {
- return reinterpret_cast<int*>(var() + sizeof(EdgeInfo) * num_output_edges +
- sizeof(AllocatorAttributes) * num_outputs +
- sizeof(uint8) * num_inputs +
- sizeof(uint8) * num_outputs);
+ sizeof(AllocatorAttributes) * num_outputs + sizeof(int) * num_outputs +
+ sizeof(uint8) * num_inputs);
}
TF_DISALLOW_COPY_AND_ASSIGN(NodeItem);
@@ -481,9 +479,9 @@ size_t GraphView::NodeItemBytes(const Node* n) {
sizeof(NodeItem) // Fixed
+ num_output_edges * sizeof(EdgeInfo) // output_edges[...]
+ num_outputs * sizeof(AllocatorAttributes) // output_attr[...]
+ + num_outputs * sizeof(int) // forward_from[num_outputs]
+ num_inputs * sizeof(uint8) // input_type[num_inputs]
- + num_outputs * sizeof(uint8) // output_type[num_outputs]
- + num_outputs * sizeof(int); // forward_from[num_outputs]
+ + num_outputs * sizeof(uint8); // output_type[num_outputs]
static constexpr size_t kItemAlignment = sizeof(NodeItem*);
static_assert(kItemAlignment % alignof(NodeItem) == 0,
"NodeItem must be aligned with kItemAlignment");
diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc
index a6f637b488..bf05f6f1d9 100644
--- a/tensorflow/core/common_runtime/function.cc
+++ b/tensorflow/core/common_runtime/function.cc
@@ -795,16 +795,16 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
};
}
- if (run_opts.runner == nullptr) {
- run_opts.runner = &default_runner_;
- }
- DCHECK(run_opts.runner != nullptr);
-
if (!parent_->IsInstantiatedOnDevice(device_name_, handle)) {
parent_->Run(run_opts, handle, args, rets, done);
return;
}
+ if (run_opts.runner == nullptr) {
+ run_opts.runner = &default_runner_;
+ }
+ DCHECK(run_opts.runner != nullptr);
+
Executor::Args* exec_args = new Executor::Args;
// Inherit the step_id from the caller.
exec_args->step_id = run_opts.step_id;
diff --git a/tensorflow/core/common_runtime/profile_handler.h b/tensorflow/core/common_runtime/profile_handler.h
index 9d31b1aecb..391dc8c198 100644
--- a/tensorflow/core/common_runtime/profile_handler.h
+++ b/tensorflow/core/common_runtime/profile_handler.h
@@ -29,22 +29,6 @@ class ProfileHandler {
ProfileHandler() {}
virtual ~ProfileHandler() {}
- // Records that a miscellaneous activity occurred in the current step.
- //
- // Implementations of this method must be thread-safe.
- //
- // Args:
- // - device: The device on which the activity occurred.
- // - start: The time at which the activity started.
- // - limit: The time at which the activity finished.
- // - label: A label for the op, which may be used in visualization.
- // - op_type: A type string for the op, which may be used in visualization.
- // - details: A details string, which may be used in visualization.
- // from time "start" to "limit" with "op_type" and "details".
- virtual void RecordActivity(const string& device, Microseconds start,
- Microseconds limit, StringPiece label,
- StringPiece op_type, StringPiece details) = 0;
-
// Records that a single Op was executed in the current step.
//
// Implementations of this method must be thread-safe.
diff --git a/tensorflow/core/common_runtime/shape_refiner.cc b/tensorflow/core/common_runtime/shape_refiner.cc
index 06dbe04986..fa4d1eda62 100644
--- a/tensorflow/core/common_runtime/shape_refiner.cc
+++ b/tensorflow/core/common_runtime/shape_refiner.cc
@@ -232,13 +232,12 @@ Status ShapeRefiner::AddNode(const Node* node) {
input_nodes[e->dst_input()] = input;
input_shapes[e->dst_input()] = c->output(e->src_output());
- // Only propagate handle data of edges which are carrying resource handles.
- if (e->src()->output_type(e->src_output()) == DT_RESOURCE) {
- const auto* in_v = c->output_handle_shapes_and_types(e->src_output());
- if (in_v != nullptr) {
- input_handle_shapes_and_types[e->dst_input()].reset(
- new std::vector<ShapeAndType>(*in_v));
- }
+ const auto* in_v = c->output_handle_shapes_and_types(e->src_output());
+ if (in_v != nullptr) {
+ DataType input_type = e->src()->output_type(e->src_output());
+ DCHECK(input_type == DT_RESOURCE || input_type == DT_VARIANT);
+ input_handle_shapes_and_types[e->dst_input()].reset(
+ new std::vector<ShapeAndType>(*in_v));
}
}
@@ -422,6 +421,28 @@ Status ShapeRefiner::EvaluateConstantTensorForEdge(const Node* node,
kMaxTensorSize, disable_constant_propagation_);
}
+Status ShapeRefiner::EvaluateConstantIntScalarEdge(const Node* node,
+ int dst_idx, bool* evaluated,
+ int64* result) {
+ Tensor scalar;
+ TF_RETURN_IF_ERROR(
+ EvaluateConstantTensorForEdge(node, dst_idx, evaluated, &scalar));
+ if (*evaluated) {
+ DCHECK_EQ(scalar.NumElements(), 1)
+ << "EvaluateConstantIntScalarEdge called on non-scalar edge: "
+ << scalar.NumElements();
+ if (scalar.dtype() == DT_INT32) {
+ *result = scalar.scalar<int32>()();
+ } else {
+ DCHECK_EQ(scalar.dtype(), DT_INT64)
+ << "EvaluateConstantIntScalarEdge called on non-integer edge: "
+ << scalar.dtype();
+ *result = scalar.scalar<int64>()();
+ }
+ }
+ return Status::OK();
+}
+
Status ShapeRefiner::ConstantPartialShape(InferenceContext* target_context,
const Node* node, int dst_idx,
ShapeHandle* result) {
@@ -472,19 +493,11 @@ Status ShapeRefiner::ConstantPartialShape(InferenceContext* target_context,
std::vector<DimensionHandle> dims;
// Pack is concatenating its input scalars to form the shape tensor vector.
for (int i = 0; i < src_context->num_inputs(); ++i) {
- Tensor scalar;
- bool evaluated = false;
- TF_RETURN_IF_ERROR(EvaluateConstantTensorForEdge(input_edge->src(), i,
- &evaluated, &scalar));
+ int64 size;
+ bool evaluated;
+ TF_RETURN_IF_ERROR(EvaluateConstantIntScalarEdge(input_edge->src(), i,
+ &evaluated, &size));
if (evaluated) {
- int64 size;
- if (scalar.dtype() == DT_INT32) {
- size = scalar.scalar<int32>()();
- } else if (scalar.dtype() == DT_INT64) {
- size = scalar.scalar<int64>()();
- } else {
- return errors::InvalidArgument("Pack input must be int32 or int64");
- }
dims.push_back(size < 0 ? target_context->UnknownDim()
: target_context->MakeDim(size));
} else {
@@ -514,6 +527,9 @@ Status ShapeRefiner::ConstantPartialShape(InferenceContext* target_context,
TF_RETURN_IF_ERROR(
target_context->Concatenate(*result, sub_result, result));
}
+ } else if (src_op == "StridedSlice") {
+ TF_RETURN_IF_ERROR(
+ PartialStridedSliceShape(input_edge->src(), src_context, result));
} else {
Tensor t;
bool evaluated = false;
@@ -525,6 +541,78 @@ Status ShapeRefiner::ConstantPartialShape(InferenceContext* target_context,
return Status::OK();
}
+Status ShapeRefiner::PartialStridedSliceShape(Node* slice_node,
+ InferenceContext* ctx,
+ ShapeHandle* result) {
+ // Only attempt to evaluate if begin/end/strides all are scalars.
+ for (int i = 1; i <= 3; ++i) {
+ ShapeHandle input_shape = ctx->input(i);
+ if (ctx->Value(ctx->Dim(input_shape, 0)) != 1) {
+ *result = ctx->UnknownShape();
+ return Status::OK();
+ }
+ }
+
+ int begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask;
+ TF_RETURN_IF_ERROR(
+ GetNodeAttr(slice_node->attrs(), "begin_mask", &begin_mask));
+ TF_RETURN_IF_ERROR(GetNodeAttr(slice_node->attrs(), "end_mask", &end_mask));
+ TF_RETURN_IF_ERROR(
+ GetNodeAttr(slice_node->attrs(), "ellipsis_mask", &ellipsis_mask));
+ TF_RETURN_IF_ERROR(
+ GetNodeAttr(slice_node->attrs(), "new_axis_mask", &new_axis_mask));
+ TF_RETURN_IF_ERROR(
+ GetNodeAttr(slice_node->attrs(), "shrink_axis_mask", &shrink_axis_mask));
+
+ // Only attempt to evaluate if there are no special masks set (note that we
+ // can handle begin/end_mask == 1).
+ if (!(begin_mask == 0 || begin_mask == 1) ||
+ !(end_mask == 0 || end_mask == 1) || ellipsis_mask != 0 ||
+ new_axis_mask != 0 || shrink_axis_mask != 0) {
+ *result = ctx->UnknownShape();
+ return Status::OK();
+ }
+
+ bool evaluated;
+ int64 begin;
+ if (begin_mask == 1) {
+ begin = 0;
+ } else {
+ TF_RETURN_IF_ERROR(
+ EvaluateConstantIntScalarEdge(slice_node, 1, &evaluated, &begin));
+ if (!evaluated) {
+ *result = ctx->UnknownShape();
+ return Status::OK();
+ }
+ }
+
+ int64 end;
+ if (end_mask == 1) {
+ end = std::numeric_limits<int64>::max();
+ } else {
+ TF_RETURN_IF_ERROR(
+ EvaluateConstantIntScalarEdge(slice_node, 2, &evaluated, &end));
+ if (!evaluated) {
+ *result = ctx->UnknownShape();
+ return Status::OK();
+ }
+ }
+
+ int64 stride;
+ TF_RETURN_IF_ERROR(
+ EvaluateConstantIntScalarEdge(slice_node, 3, &evaluated, &stride));
+ if (!evaluated) {
+ *result = ctx->UnknownShape();
+ return Status::OK();
+ }
+
+ // Apply stride to input interpreted as a partial shape.
+ ShapeHandle input;
+ TF_RETURN_IF_ERROR(ConstantPartialShape(ctx, slice_node, 0, &input));
+ TF_RETURN_IF_ERROR(ctx->Subshape(input, begin, end, stride, result));
+ return Status::OK();
+}
+
Status ShapeRefiner::RunShapeFn(const Node* node,
const OpRegistrationData* op_reg_data,
ExtendedInferenceContext* ec) {
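The graph pattern this targets is slicing the output of a Shape op and feeding the result where a shape is expected; a hedged Python illustration (the shapes are illustrative, and the stated inference outcome describes the intent of this change rather than a guaranteed frontend behavior):

    import tensorflow as tf

    x = tf.placeholder(tf.float32, shape=[None, 3, None, 5])
    # Statically, tf.shape(x)[1:4] is the partial shape [3, ?, 5].
    sliced_shape = tf.shape(x)[1:4]
    y = tf.fill(sliced_shape, 0.0)
    # With StridedSlice handled in ConstantPartialShape, the refiner can treat
    # sliced_shape as the partial shape [3, ?, 5] instead of falling back to a
    # fully unknown shape for y.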
diff --git a/tensorflow/core/common_runtime/shape_refiner.h b/tensorflow/core/common_runtime/shape_refiner.h
index d49c4373f0..9c96dcbc20 100644
--- a/tensorflow/core/common_runtime/shape_refiner.h
+++ b/tensorflow/core/common_runtime/shape_refiner.h
@@ -215,9 +215,18 @@ class ShapeRefiner {
bool keep_nested_shapes,
ExtendedInferenceContext* outer_context);
+ // Attempts to evaluate the 'dst_idx'-th input to 'node'. If the input edge
+ // value can be evaluated, 'evaluated' is set to true and the value returned
+ // in 'result'. Otherwise 'evaluated' is set to false.
Status EvaluateConstantTensorForEdge(const Node* node, int dst_idx,
bool* evaluated, Tensor* result);
+ // Wrapper around EvaluateConstantTensorForEdge for scalar int32/int64 input
+ // tensors. The caller is responsible for checking that the specified edge is
+ // scalar and int32 or int64.
+ Status EvaluateConstantIntScalarEdge(const Node* node, int dst_idx,
+ bool* evaluated, int64* result);
+
// This function tries to materialize as much information about the 'node''s
// dst_idx input as a statically computable shape, and the result may be
// partially known, depending on what is statically inferable.
@@ -243,6 +252,11 @@ class ShapeRefiner {
const Node* node, int dst_idx,
shape_inference::ShapeHandle* result);
+ // Implementation of ConstantPartialShape for StridedSlice nodes.
+ Status PartialStridedSliceShape(Node* slice_node,
+ shape_inference::InferenceContext* ctx,
+ shape_inference::ShapeHandle* result);
+
Status RunShapeFn(const Node* node, const OpRegistrationData* op_reg_data,
ExtendedInferenceContext* ec);
diff --git a/tensorflow/core/common_runtime/shape_refiner_test.cc b/tensorflow/core/common_runtime/shape_refiner_test.cc
index f48638afc0..8b9657eec8 100644
--- a/tensorflow/core/common_runtime/shape_refiner_test.cc
+++ b/tensorflow/core/common_runtime/shape_refiner_test.cc
@@ -60,6 +60,39 @@ class ShapeRefinerTest : public ::testing::Test {
}
static constexpr int64 kMaxTensorSize = ShapeRefiner::kMaxTensorSize;
+
+ void TestStridedSlice(const PartialTensorShape& input_shape, int begin,
+ int end, int stride, const char* expected,
+ int begin_mask = 0, int end_mask = 0,
+ int ellipsis_mask = 0) {
+ Scope root = Scope::DisabledShapeInferenceScope();
+ auto placeholder =
+ ops::Placeholder(root, DT_INT32, ops::Placeholder::Shape(input_shape));
+ auto input = ops::Shape(root, placeholder);
+ auto begin_op = ops::Const(root, {begin});
+ auto end_op = ops::Const(root, {end});
+ auto stride_op = ops::Const(root, {stride});
+ auto slice = ops::StridedSlice(root, input, begin_op, end_op, stride_op,
+ ops::StridedSlice::BeginMask(begin_mask)
+ .EndMask(end_mask)
+ .EllipsisMask(ellipsis_mask));
+ Node* result;
+ TF_ASSERT_OK(NodeBuilder("test", "TensorAsShapeInt32")
+ .Input(slice.node())
+ .Finalize(root.graph(), &result));
+
+ ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
+ TF_ASSERT_OK(m.AddNode(placeholder.node()));
+ TF_ASSERT_OK(m.AddNode(input.node()));
+ TF_ASSERT_OK(m.AddNode(begin_op.node()));
+ TF_ASSERT_OK(m.AddNode(end_op.node()));
+ TF_ASSERT_OK(m.AddNode(stride_op.node()));
+ TF_ASSERT_OK(m.AddNode(slice.node()));
+ TF_ASSERT_OK(m.AddNode(result));
+
+ shape_inference::InferenceContext* ctx = m.GetContext(result);
+ EXPECT_EQ(ctx->DebugString(ctx->output(0)), expected);
+ }
};
namespace {
@@ -1156,6 +1189,73 @@ TEST_F(ShapeRefinerTest, ConstantValueAsShape_ConcatInvalidDimValue) {
m.AddNode(result).error_message());
}
+TEST_F(ShapeRefinerTest, ConstantValueAsShape_StridedSlice) {
+ TestStridedSlice(
+ /*input_shape=*/{1, -1, 3, -1, 5},
+ /*begin=*/2,
+ /*end=*/5,
+ /*stride=*/1,
+ /*expected=*/"[3,?,5]");
+}
+
+TEST_F(ShapeRefinerTest, ConstantValueAsShape_StridedSliceNegativeStride) {
+ // clang-format off
+ TestStridedSlice(
+ /*input_shape=*/{1, -1, 3, -1, 5},
+ /*begin=*/10,
+ /*end=*/0,
+ /*stride=*/-1,
+ /*expected=*/"[5,?,3,?]");
+ // clang-format on
+}
+
+TEST_F(ShapeRefinerTest, ConstantValueAsShape_StridedSliceMasks) {
+ TestStridedSlice(
+ /*input_shape=*/{1, -1, 3, -1, 5},
+ /*begin=*/3,
+ /*end=*/4,
+ /*stride=*/1,
+ /*expected=*/"[1,?,3,?,5]",
+ /*begin_mask=*/1,
+ /*end_mask=*/1);
+}
+
+TEST_F(ShapeRefinerTest, ConstantValueAsShape_StridedSliceInvalidMask) {
+ TestStridedSlice(
+ /*input_shape=*/{1, -1, 3},
+ /*begin=*/2,
+ /*end=*/3,
+ /*stride=*/1,
+ /*expected=*/"[?,?,?]",
+ /*begin_mask=*/0,
+ /*end_mask=*/0,
+ /*ellipsis_mask=*/1);
+}
+
+TEST_F(ShapeRefinerTest, ConstantValueAsShape_StridedSliceMulti) {
+ Scope root = Scope::DisabledShapeInferenceScope();
+ auto input = ops::Placeholder(root, DT_INT32);
+ auto begin = ops::Const(root, {0, 0});
+ auto end = ops::Const(root, {2, 2});
+ auto stride = ops::Const(root, {1, 1});
+ auto slice = ops::StridedSlice(root, input, begin, end, stride);
+ Node* result;
+ TF_ASSERT_OK(NodeBuilder("test", "TensorAsShapeInt32")
+ .Input(slice.node())
+ .Finalize(root.graph(), &result));
+
+ ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
+ TF_ASSERT_OK(m.AddNode(input.node()));
+ TF_ASSERT_OK(m.AddNode(begin.node()));
+ TF_ASSERT_OK(m.AddNode(end.node()));
+ TF_ASSERT_OK(m.AddNode(stride.node()));
+ TF_ASSERT_OK(m.AddNode(slice.node()));
+ TF_ASSERT_OK(m.AddNode(result));
+
+ shape_inference::InferenceContext* ctx = m.GetContext(result);
+ EXPECT_EQ(ctx->DebugString(ctx->output(0)), "?");
+}
+
namespace {
// Dummy op to test ShapeRefiner util functions
diff --git a/tensorflow/core/framework/api_def.proto b/tensorflow/core/framework/api_def.proto
index cce02d84b2..3f8dd272e7 100644
--- a/tensorflow/core/framework/api_def.proto
+++ b/tensorflow/core/framework/api_def.proto
@@ -56,8 +56,10 @@ message ApiDef {
// use a snake_case convention instead of CamelCase.
string name = 1;
- // First GraphDef version at which the op is disallowed.
- int32 deprecation_version = 2;
+ // If this endpoint is deprecated, set deprecation_message to a
+ // message that should be logged when the endpoint is used.
+  // The message should indicate an alternative endpoint to use, if any.
+ string deprecation_message = 2;
}
repeated Endpoint endpoint = 3;
diff --git a/tensorflow/core/framework/dataset.cc b/tensorflow/core/framework/dataset.cc
index 4145ef7bc9..62a9d5751d 100644
--- a/tensorflow/core/framework/dataset.cc
+++ b/tensorflow/core/framework/dataset.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/node_builder.h"
@@ -269,4 +270,22 @@ const char GraphDatasetBase::kDatasetGraphKey[] = "_DATASET_GRAPH";
const char GraphDatasetBase::kDatasetGraphOutputNodeKey[] =
"_DATASET_GRAPH_OUTPUT_NODE";
+namespace dataset {
+
+IteratorContext MakeIteratorContext(OpKernelContext* ctx) {
+ IteratorContext::Params params;
+ params.env = ctx->env();
+ params.runner = *(ctx->runner());
+ params.lib = ctx->function_library();
+ // Note: must use reinterpret_cast because function.h forward-declares Device.
+ DeviceBase* device =
+ reinterpret_cast<DeviceBase*>(ctx->function_library()->device());
+ params.allocator_getter = [device](AllocatorAttributes attrs) {
+ return device->GetAllocator(attrs);
+ };
+ return IteratorContext(params);
+}
+
+} // namespace dataset
+
} // namespace tensorflow
diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h
index 775d9f6eb6..8624af9bf5 100644
--- a/tensorflow/core/framework/dataset.h
+++ b/tensorflow/core/framework/dataset.h
@@ -619,6 +619,12 @@ Status GetDatasetFromVariantTensor(const Tensor& tensor,
// The ownership of `dataset` is transferred to `tensor`.
Status StoreDatasetInVariantTensor(DatasetBase* dataset, Tensor* tensor);
+namespace dataset {
+
+IteratorContext MakeIteratorContext(OpKernelContext* ctx);
+
+} // namespace dataset
+
} // namespace tensorflow
#endif // TENSORFLOW_CORE_FRAMEWORK_DATASET_H_
diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc
index bdc1af9fda..647c66099c 100644
--- a/tensorflow/core/framework/function.cc
+++ b/tensorflow/core/framework/function.cc
@@ -504,7 +504,7 @@ string Print(const NodeDef& n) {
std::vector<string> dep;
for (StringPiece s : n.input()) {
if (str_util::ConsumePrefix(&s, "^")) {
- dep.push_back(s.ToString());
+ dep.push_back(std::string(s));
} else {
dat.push_back(s);
}
diff --git a/tensorflow/core/framework/node_def_builder.cc b/tensorflow/core/framework/node_def_builder.cc
index f9cf6ce873..8e00bfe4f8 100644
--- a/tensorflow/core/framework/node_def_builder.cc
+++ b/tensorflow/core/framework/node_def_builder.cc
@@ -24,22 +24,23 @@ limitations under the License.
namespace tensorflow {
NodeDefBuilder::NodeOut::NodeOut(StringPiece n, int i, DataType dt)
- : node(n.ToString()), index(i), data_type(dt) {}
+ : node(std::string(n)), index(i), data_type(dt) {}
NodeDefBuilder::NodeOut::NodeOut() {
// uninitialized, call Reset() before use.
}
void NodeDefBuilder::NodeOut::Reset(StringPiece n, int i, DataType dt) {
- node = n.ToString();
+ node = std::string(n);
index = i;
data_type = dt;
}
NodeDefBuilder::NodeDefBuilder(StringPiece name, StringPiece op_name,
const OpRegistryInterface* op_registry) {
- node_def_.set_name(name.ToString());
- const Status status = op_registry->LookUpOpDef(op_name.ToString(), &op_def_);
+ node_def_.set_name(std::string(name));
+ const Status status =
+ op_registry->LookUpOpDef(std::string(op_name), &op_def_);
if (status.ok()) {
Initialize();
} else {
@@ -50,7 +51,7 @@ NodeDefBuilder::NodeDefBuilder(StringPiece name, StringPiece op_name,
NodeDefBuilder::NodeDefBuilder(StringPiece name, const OpDef* op_def)
: op_def_(op_def) {
- node_def_.set_name(name.ToString());
+ node_def_.set_name(std::string(name));
Initialize();
}
@@ -170,7 +171,7 @@ void NodeDefBuilder::AddInput(StringPiece src_node, int src_index) {
} else if (src_index > 0) {
node_def_.add_input(strings::StrCat(src_node, ":", src_index));
} else {
- node_def_.add_input(src_node.ToString());
+ node_def_.add_input(std::string(src_node));
}
}
@@ -193,12 +194,12 @@ void NodeDefBuilder::VerifyInputRef(const OpDef::ArgDef* input_arg,
}
NodeDefBuilder& NodeDefBuilder::ControlInput(StringPiece src_node) {
- control_inputs_.push_back(src_node.ToString());
+ control_inputs_.push_back(std::string(src_node));
return *this;
}
NodeDefBuilder& NodeDefBuilder::Device(StringPiece device_spec) {
- node_def_.set_device(device_spec.ToString());
+ node_def_.set_device(std::string(device_spec));
return *this;
}
diff --git a/tensorflow/core/framework/node_def_util.cc b/tensorflow/core/framework/node_def_util.cc
index bad92ca9b3..5798333dfe 100644
--- a/tensorflow/core/framework/node_def_util.cc
+++ b/tensorflow/core/framework/node_def_util.cc
@@ -245,7 +245,7 @@ DEFINE_GET_ATTR(NameAttrList, func, "func", emplace_back, v, ;);
#undef DEFINE_GET_ATTR
bool HasNodeAttr(const NodeDef& node_def, StringPiece attr_name) {
- return node_def.attr().find(attr_name.ToString()) != node_def.attr().end();
+ return node_def.attr().find(std::string(attr_name)) != node_def.attr().end();
}
static const string& kEmptyString = *new string();
@@ -639,7 +639,7 @@ Status AttachDef(const Status& status, const Node& node) {
void AddNodeAttr(StringPiece name, const AttrValue& value, NodeDef* node_def) {
node_def->mutable_attr()->insert(
- AttrValueMap::value_type(name.ToString(), value));
+ AttrValueMap::value_type(std::string(name), value));
}
#define ADD_NODE_ATTR(T) \
@@ -677,7 +677,7 @@ ADD_NODE_ATTR(gtl::ArraySlice<NameAttrList>)
#undef ADD_NODE_ATTR
void AddAttr(StringPiece name, const AttrValue& value, AttrValueMap* map) {
- map->insert(AttrValueMap::value_type(name.ToString(), value));
+ map->insert(AttrValueMap::value_type(std::string(name), value));
}
#define ADD_ATTR(T) \
diff --git a/tensorflow/core/framework/op_def_builder.cc b/tensorflow/core/framework/op_def_builder.cc
index 403bd0b5e2..91eb6c0672 100644
--- a/tensorflow/core/framework/op_def_builder.cc
+++ b/tensorflow/core/framework/op_def_builder.cc
@@ -527,7 +527,7 @@ void FinalizeDoc(const string& text, OpDef* op_def,
} // namespace
OpDefBuilder::OpDefBuilder(StringPiece op_name) {
- op_def()->set_name(op_name.ToString()); // NOLINT
+ op_def()->set_name(std::string(op_name)); // NOLINT
}
OpDefBuilder& OpDefBuilder::Attr(StringPiece spec) {
@@ -584,7 +584,7 @@ OpDefBuilder& OpDefBuilder::Deprecated(int version, StringPiece explanation) {
} else {
OpDeprecation* deprecation = op_def()->mutable_deprecation();
deprecation->set_version(version);
- deprecation->set_explanation(explanation.ToString());
+ deprecation->set_explanation(std::string(explanation));
}
return *this;
}
diff --git a/tensorflow/core/framework/op_gen_lib.cc b/tensorflow/core/framework/op_gen_lib.cc
index 7f23272871..3d7920a6e2 100644
--- a/tensorflow/core/framework/op_gen_lib.cc
+++ b/tensorflow/core/framework/op_gen_lib.cc
@@ -185,7 +185,7 @@ static bool FindMultiline(StringPiece line, size_t colon, string* end) {
while (str_util::ConsumePrefix(&line, " ")) {
}
if (str_util::ConsumePrefix(&line, "<<")) {
- *end = line.ToString();
+ *end = std::string(line);
return true;
}
return false;
@@ -306,9 +306,6 @@ void InitApiDefFromOpDef(const OpDef& op_def, ApiDef* api_def) {
auto* endpoint = api_def->add_endpoint();
endpoint->set_name(op_def.name());
- if (op_def.has_deprecation()) {
- endpoint->set_deprecation_version(op_def.deprecation().version());
- }
for (const auto& op_in_arg : op_def.input_arg()) {
auto* api_in_arg = api_def->add_in_arg();
diff --git a/tensorflow/core/framework/op_gen_lib_test.cc b/tensorflow/core/framework/op_gen_lib_test.cc
index 857b1c8dbc..e0e77c7449 100644
--- a/tensorflow/core/framework/op_gen_lib_test.cc
+++ b/tensorflow/core/framework/op_gen_lib_test.cc
@@ -189,7 +189,6 @@ TEST(OpGenLibTest, ApiDefInitializedFromOpDef) {
visibility: VISIBLE
endpoint {
name: "testop"
- deprecation_version: 123
}
in_arg {
name: "arg_a"
diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc
index ca91d68f79..c71bcb26ab 100644
--- a/tensorflow/core/framework/op_kernel.cc
+++ b/tensorflow/core/framework/op_kernel.cc
@@ -923,7 +923,7 @@ void OpKernelContext::clear_recorded_memory() {
struct KernelRegistration {
KernelRegistration(const KernelDef& d, StringPiece c,
kernel_factory::OpKernelRegistrar::Factory f)
- : def(d), kernel_class_name(c.ToString()), factory(f) {}
+ : def(d), kernel_class_name(std::string(c)), factory(f) {}
const KernelDef def;
const string kernel_class_name;
const kernel_factory::OpKernelRegistrar::Factory factory;
diff --git a/tensorflow/core/framework/resource_mgr.h b/tensorflow/core/framework/resource_mgr.h
index c84ea3b034..3cc17e1ca6 100644
--- a/tensorflow/core/framework/resource_mgr.h
+++ b/tensorflow/core/framework/resource_mgr.h
@@ -338,6 +338,9 @@ class ResourceHandleOp : public OpKernel {
private:
string container_;
string name_;
+ mutex mutex_;
+ Tensor resource_ GUARDED_BY(mutex_);
+ std::atomic<bool> initialized_{false};
};
// Registers a kernel for an op which produces a handle to a resource of the
@@ -511,10 +514,17 @@ ResourceHandleOp<T>::ResourceHandleOp(OpKernelConstruction* context)
template <typename T>
void ResourceHandleOp<T>::Compute(OpKernelContext* ctx) {
- Tensor* output = nullptr;
- OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output));
- output->scalar<ResourceHandle>()() =
- MakeResourceHandle<T>(ctx, container_, name_);
+ if (!initialized_.load()) {
+ mutex_lock ml(mutex_);
+ AllocatorAttributes attr;
+ attr.set_on_host(true);
+ OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_RESOURCE, TensorShape({}),
+ &resource_, attr));
+ resource_.scalar<ResourceHandle>()() =
+ MakeResourceHandle<T>(ctx, container_, name_);
+ initialized_.store(true);
+ }
+ ctx->set_output(0, resource_);
}
} // end namespace tensorflow
diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc
index 2b995e8b5e..3185875e3b 100644
--- a/tensorflow/core/framework/shape_inference.cc
+++ b/tensorflow/core/framework/shape_inference.cc
@@ -605,10 +605,16 @@ Status InferenceContext::Subshape(ShapeHandle s, int64 start,
return Subshape(s, start, std::numeric_limits<int64>::max() /* end */, out);
}
-Status InferenceContext::Subshape(ShapeHandle s, int64 start_in, int64 end_in,
+Status InferenceContext::Subshape(ShapeHandle s, int64 start, int64 end,
ShapeHandle* out) {
- int64 start = start_in;
- int64 end = end_in;
+ return Subshape(s, start, end, 1 /* stride */, out);
+}
+
+Status InferenceContext::Subshape(ShapeHandle s, int64 start, int64 end,
+ int64 stride, ShapeHandle* out) {
+ int64 start_in = start;
+ int64 end_in = end;
+
const int32 rank = Rank(s);
if (start == 0 && ((RankKnown(s) && end >= rank) ||
end == std::numeric_limits<int64>::max())) {
@@ -621,6 +627,9 @@ Status InferenceContext::Subshape(ShapeHandle s, int64 start_in, int64 end_in,
if (start > rank) start = rank;
if (end > rank) end = rank;
+
+ if (stride < 0 && start == rank) --start;
+
if (start < 0) {
start = rank + start;
if (start < 0) {
@@ -638,16 +647,24 @@ Status InferenceContext::Subshape(ShapeHandle s, int64 start_in, int64 end_in,
", for shape with rank ", rank);
}
}
- if (start > end) {
+ if (stride > 0 && start > end) {
*out = nullptr;
return errors::InvalidArgument(
"Subshape must have computed start <= end, but is ", start, " and ",
end, " (computed from start ", start_in, " and end ", end_in,
" over shape with rank ", rank, ")");
+ } else if (stride < 0 && start < end) {
+ *out = nullptr;
+ return errors::InvalidArgument(
+ "Subshape must have computed start >= end since stride is negative, "
+ "but is ",
+ start, " and ", end, " (computed from start ", start_in, " and end ",
+ end_in, " over shape with rank ", rank, " and stride", stride, ")");
}
+
std::vector<DimensionHandle> dims;
- dims.reserve(end - start);
- for (int i = start; i < end; ++i) {
+ dims.reserve((end - start) / stride);
+ for (int i = start; stride > 0 ? i < end : i > end; i += stride) {
dims.push_back(Dim(s, i));
}
return ReturnCreatedShape(dims, out);
diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h
index 9431a62abe..3f3729dcf9 100644
--- a/tensorflow/core/framework/shape_inference.h
+++ b/tensorflow/core/framework/shape_inference.h
@@ -434,6 +434,13 @@ class InferenceContext {
Status Subshape(ShapeHandle s, int64 start, int64 end,
ShapeHandle* out) TF_MUST_USE_RESULT;
+ // Returns in <*out> a sub-shape of <s>, with dimensions [start:end:stride].
+ // <start> and <end> can be negative, to index from the end of the shape.
+ // <start> and <end> are set to the rank of <s> if > rank of <s>.
+  // <stride> can be negative, to traverse <s> in reverse order.
+ Status Subshape(ShapeHandle s, int64 start, int64 end, int64 stride,
+ ShapeHandle* out) TF_MUST_USE_RESULT;
+
// Returns in <*out> the result of appending the dimensions of <s2> to those
// of <s1>.
Status Concatenate(ShapeHandle s1, ShapeHandle s2,
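The new stride argument mirrors Python's extended slicing over the list of dimensions, which is also what the new shape_refiner tests exercise; a small pure-Python sketch of that analogy (using None for unknown dims):

    # The shape [1,?,3,?,5] written as a Python list of dimensions.
    dims = [1, None, 3, None, 5]

    # Subshape(s, 2, 5, 1)  ->  [3,?,5]
    print(dims[2:5:1])    # [3, None, 5]

    # Subshape(s, 10, 0, -1)  ->  [5,?,3,?]; the start clamps to the rank and
    # then steps backwards, excluding index 0.
    print(dims[10:0:-1])  # [5, None, 3, None]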
diff --git a/tensorflow/core/framework/shape_inference_testutil.h b/tensorflow/core/framework/shape_inference_testutil.h
index 2a99af7659..f6656b3b45 100644
--- a/tensorflow/core/framework/shape_inference_testutil.h
+++ b/tensorflow/core/framework/shape_inference_testutil.h
@@ -32,7 +32,7 @@ class Tensor;
struct ShapeInferenceTestOp {
typedef std::pair<string, DataType> ShapeAndType;
- explicit ShapeInferenceTestOp(StringPiece name) : name(name.ToString()) {}
+ explicit ShapeInferenceTestOp(StringPiece name) : name(std::string(name)) {}
string name;
NodeDef node_def;
std::vector<const Tensor*> input_tensors;
diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc
index eeb6c60f71..71d0637dc2 100644
--- a/tensorflow/core/graph/graph.cc
+++ b/tensorflow/core/graph/graph.cc
@@ -695,7 +695,7 @@ Status Graph::AddWhileContext(StringPiece frame_name,
std::vector<OutputTensor> body_outputs,
WhileContext** result) {
auto pair = while_ctxs_.insert(std::pair<string, WhileContext>(
- frame_name.ToString(),
+ std::string(frame_name),
WhileContext(frame_name, std::move(enter_nodes), std::move(exit_nodes),
cond_output, std::move(body_inputs),
std::move(body_outputs))));
diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc
index c678283fce..2fd32c0bd4 100644
--- a/tensorflow/core/graph/graph_constructor.cc
+++ b/tensorflow/core/graph/graph_constructor.cc
@@ -489,7 +489,7 @@ Status GraphConstructor::InitFromEdges() {
num_control_edges++;
} else {
TensorId id(ParseTensorName(input_name));
- if (next_iteration_nodes_.find(id.first.ToString()) !=
+ if (next_iteration_nodes_.find(std::string(id.first)) !=
next_iteration_nodes_.end()) {
has_loop_back_edge = true;
}
@@ -811,7 +811,7 @@ void GraphConstructor::UniquifyNames(
// We require that UniquifyNames() is called on all NodeDefs in topological
// order. This guarantees that node_def's inputs will already be uniquified
// if necessary.
- auto iter = uniquified_names_.find(id.first.ToString());
+ auto iter = uniquified_names_.find(std::string(id.first));
if (iter == uniquified_names_.end()) continue;
id.first = iter->second;
node_def->set_input(i, id.ToString());
@@ -830,7 +830,7 @@ void GraphConstructor::UpdateUniquifiedColocationNames() {
for (int i = 0; i < coloc_values.size(); ++i) {
StringPiece val(coloc_values[i]);
if (str_util::ConsumePrefix(&val, kColocationGroupPrefix)) {
- const auto& name_pair = uniquified_names_.find(val.ToString());
+ const auto& name_pair = uniquified_names_.find(std::string(val));
if (name_pair == uniquified_names_.end()) continue;
updated = true;
coloc_values[i] =
@@ -856,7 +856,7 @@ bool GraphConstructor::NameExistsInGraphDef(StringPiece name) {
}
string GraphConstructor::FindUniqueName(StringPiece original_name) {
- string name = original_name.ToString();
+ string name = std::string(original_name);
int count = 0;
// Check that any generated names don't collide with imported NodeDefs (as
// well as nodes in g_).
@@ -989,7 +989,7 @@ Status GraphConstructor::Convert() {
src_node->num_outputs(), " outputs");
}
- inputs.emplace_back(id.first.ToString(), src_node, src_index);
+ inputs.emplace_back(std::string(id.first), src_node, src_index);
}
if (has_data_back_edge && !IsMerge(*node_def)) {
diff --git a/tensorflow/core/graph/graph_constructor_test.cc b/tensorflow/core/graph/graph_constructor_test.cc
index b513778de9..c54b4fa269 100644
--- a/tensorflow/core/graph/graph_constructor_test.cc
+++ b/tensorflow/core/graph/graph_constructor_test.cc
@@ -157,7 +157,7 @@ class GraphConstructorTest : public ::testing::Test {
}
StringPiece loc(value[0]);
return str_util::ConsumePrefix(&loc, kColocationGroupPrefix)
- ? loc.ToString()
+ ? std::string(loc)
: "";
}
diff --git a/tensorflow/core/graph/graph_def_builder.cc b/tensorflow/core/graph/graph_def_builder.cc
index 7a58347bd1..dd84c4f7c7 100644
--- a/tensorflow/core/graph/graph_def_builder.cc
+++ b/tensorflow/core/graph/graph_def_builder.cc
@@ -44,12 +44,12 @@ GraphDefBuilder::Options GraphDefBuilder::Options::WithControlInputs(
}
GraphDefBuilder::Options GraphDefBuilder::Options::WithNameImpl(
StringPiece name) {
- name_ = name.ToString();
+ name_ = std::string(name);
return *this;
}
GraphDefBuilder::Options GraphDefBuilder::Options::WithDeviceImpl(
StringPiece device) {
- device_ = device.ToString();
+ device_ = std::string(device);
return *this;
}
GraphDefBuilder::Options GraphDefBuilder::Options::WithControlInputImpl(
diff --git a/tensorflow/core/graph/graph_def_builder.h b/tensorflow/core/graph/graph_def_builder.h
index 776a74c6d8..0d6aae4355 100644
--- a/tensorflow/core/graph/graph_def_builder.h
+++ b/tensorflow/core/graph/graph_def_builder.h
@@ -128,7 +128,7 @@ class GraphDefBuilder {
Options WithControlInputsImpl(gtl::ArraySlice<Node*> control_inputs);
template <class T>
Options WithAttrImpl(StringPiece name, T&& value) {
- attrs_.emplace_back(name.ToString(), AttrValue());
+ attrs_.emplace_back(std::string(name), AttrValue());
SetAttrValue(std::forward<T>(value), &attrs_.back().second);
return *this;
}
diff --git a/tensorflow/core/graph/graph_partition.cc b/tensorflow/core/graph/graph_partition.cc
index 877e4f1b44..1b1941f9c1 100644
--- a/tensorflow/core/graph/graph_partition.cc
+++ b/tensorflow/core/graph/graph_partition.cc
@@ -785,7 +785,7 @@ Status TopologicalSortNodesWithTimePriority(
for (int n = 0; n < gdef->node_size(); ++n) {
const NodeDef* ndef = &gdef->node(n);
for (int i = 0; i < ndef->input_size(); ++i) {
- node_to_output_nodes[ParseTensorName(ndef->input(i)).first.ToString()]
+ node_to_output_nodes[std::string(ParseTensorName(ndef->input(i)).first)]
.push_back(ndef);
}
int64 start_time;
diff --git a/tensorflow/core/graph/node_builder.cc b/tensorflow/core/graph/node_builder.cc
index 114962c0e4..03f3bbd663 100644
--- a/tensorflow/core/graph/node_builder.cc
+++ b/tensorflow/core/graph/node_builder.cc
@@ -30,7 +30,7 @@ NodeBuilder::NodeOut::NodeOut(Node* n, int32 i) // NOLINT(runtime/explicit)
dt(SafeGetOutput(node, i, &error)) {}
NodeBuilder::NodeOut::NodeOut(StringPiece n, int32 i, DataType t)
- : node(nullptr), error(false), name(n.ToString()), index(i), dt(t) {}
+ : node(nullptr), error(false), name(std::string(n)), index(i), dt(t) {}
NodeBuilder::NodeOut::NodeOut()
: node(nullptr), error(true), index(0), dt(DT_FLOAT) {}
diff --git a/tensorflow/core/graph/while_context.cc b/tensorflow/core/graph/while_context.cc
index 10a2b67f37..1b38aac35d 100644
--- a/tensorflow/core/graph/while_context.cc
+++ b/tensorflow/core/graph/while_context.cc
@@ -23,7 +23,7 @@ WhileContext::WhileContext(StringPiece frame_name,
OutputTensor cond_output,
std::vector<OutputTensor> body_inputs,
std::vector<OutputTensor> body_outputs)
- : frame_name_(frame_name.ToString()),
+ : frame_name_(std::string(frame_name)),
enter_nodes_(std::move(enter_nodes)),
exit_nodes_(std::move(exit_nodes)),
cond_output_(cond_output),
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
index b35873ce38..2542fa2d67 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
@@ -27,10 +27,16 @@ namespace grappler {
constexpr int kOpsPerMac = 2;
constexpr char kConst[] = "Const";
+constexpr char kGuaranteeConst[] = "GuaranteeConst";
constexpr char kConv2d[] = "Conv2D";
constexpr char kConv2dBackpropFilter[] = "Conv2DBackpropFilter";
constexpr char kConv2dBackpropInput[] = "Conv2DBackpropInput";
constexpr char kFusedConv2dBiasActivation[] = "FusedConv2DBiasActivation";
+constexpr char kDepthwiseConv2dNative[] = "DepthwiseConv2dNative";
+constexpr char kDepthwiseConv2dNativeBackpropFilter[] =
+ "DepthwiseConv2dNativeBackpropFilter";
+constexpr char kDepthwiseConv2dNativeBackpropInput[] =
+ "DepthwiseConv2dNativeBackpropInput";
constexpr char kMatMul[] = "MatMul";
constexpr char kSparseMatMul[] = "SparseMatMul";
constexpr char kPlaceholder[] = "Placeholder";
@@ -200,11 +206,20 @@ OpLevelCostEstimator::OpLevelCostEstimator() {
wrap(&OpLevelCostEstimator::PredictConv2DBackpropInput)},
{kFusedConv2dBiasActivation,
wrap(&OpLevelCostEstimator::PredictFusedConv2DBiasActivation)},
+ // Reuse Conv2D for DepthwiseConv2dNative because the calculation is the
+ // same, although the actual meanings of the parameters are different. See
+ // the comments in PredictConv2D and related functions.
+ {kDepthwiseConv2dNative, wrap(&OpLevelCostEstimator::PredictConv2D)},
+ {kDepthwiseConv2dNativeBackpropFilter,
+ wrap(&OpLevelCostEstimator::PredictConv2DBackpropFilter)},
+ {kDepthwiseConv2dNativeBackpropInput,
+ wrap(&OpLevelCostEstimator::PredictConv2DBackpropInput)},
{kMatMul, wrap(&OpLevelCostEstimator::PredictMatMul)},
{kSparseMatMul, wrap(&OpLevelCostEstimator::PredictMatMul)},
{kBatchMatMul, wrap(&OpLevelCostEstimator::PredictBatchMatMul)},
{kNoOp, wrap(&OpLevelCostEstimator::PredictNoOp)},
+ {kGuaranteeConst, wrap(&OpLevelCostEstimator::PredictNoOp)},
{kGather, wrap(&OpLevelCostEstimator::PredictGatherOrSlice)},
{kGatherV2, wrap(&OpLevelCostEstimator::PredictGatherOrSlice)},
@@ -537,18 +552,30 @@ OpLevelCostEstimator::ConvolutionDimensionsFromInputs(
int64 OpLevelCostEstimator::CountConv2DOperations(
const OpInfo& op_features, ConvolutionDimensions* conv_info,
bool* found_unknown_shapes) const {
- if (op_features.op() != kConv2d) {
- LOG(ERROR) << "Invalid Operation";
- return 0;
- }
+ DCHECK(op_features.op() == kConv2d ||
+ op_features.op() == kDepthwiseConv2dNative)
+ << "Invalid Operation: not Conv2D nor DepthwiseConv2dNative";
+
ConvolutionDimensions conv_dims = ConvolutionDimensionsFromInputs(
op_features.inputs(0).shape(), op_features.inputs(1).shape(), op_features,
found_unknown_shapes);
+ // In DepthwiseConv2dNative, conv_dims.oz is actually the channel depth
+ // multiplier; the effective output channel depth oz_effective is
+ // conv_dims.iz * conv_dims.oz. Thus # ops = N x H x W x oz_effective x 2RS.
+ // Compare to Conv2D, where # ops = N x H x W x iz x oz x 2RS; with
+ // oz = oz_effective, Conv2D_ops / Depthwise_conv2d_native_ops = iz.
int64 ops = conv_dims.batch;
ops *= conv_dims.ox * conv_dims.oy;
ops *= conv_dims.kx * conv_dims.ky;
- ops *= conv_dims.iz * conv_dims.oz;
+ if (op_features.op() == kConv2d) {
+ ops *= conv_dims.iz * conv_dims.oz;
+ } else {
+ // Ensure the output tensor dims are correct for DepthwiseConv2dNative,
+ // even though the op count is computed the same way as for Conv2D.
+ conv_dims.oz *= conv_dims.iz;
+ ops *= conv_dims.oz;
+ }
ops *= kOpsPerMac;
if (conv_info != nullptr) {
@@ -795,7 +822,10 @@ int64 OpLevelCostEstimator::CountConv2DBackpropInputOperations(
bool* found_unknown_shapes) const {
int64 ops = 0;
- DCHECK_EQ(kConv2dBackpropInput, op_features.op());
+ DCHECK(op_features.op() == kConv2dBackpropInput ||
+ op_features.op() == kDepthwiseConv2dNativeBackpropInput)
+ << "Invalid Operation: not kConv2dBackpropInput nor "
+ "kDepthwiseConv2dNativeBackpropInput";
if (op_features.inputs_size() < 2) {
*found_unknown_shapes = true;
@@ -828,10 +858,15 @@ int64 OpLevelCostEstimator::CountConv2DBackpropInputOperations(
ops = conv_dims.batch;
ops *= conv_dims.ox * conv_dims.oy;
ops *= conv_dims.kx * conv_dims.ky;
- ops *= conv_dims.iz * conv_dims.oz;
- ops *= kOpsPerMac;
+ if (op_features.op() == kConv2dBackpropInput) {
+ ops *= conv_dims.iz * conv_dims.oz;
+ } else {
+ // conv_dims always uses the forward-path definition regardless of op.
+ conv_dims.oz *= conv_dims.iz;
+ ops *= conv_dims.oz;
+ }
- VLOG(1) << "Operations for Conv2DBackpropInput " << ops;
+ VLOG(1) << "Operations for " << op_features.op() << " " << ops;
if (returned_conv_dims != nullptr) {
*returned_conv_dims = conv_dims;
@@ -843,7 +878,11 @@ int64 OpLevelCostEstimator::CountConv2DBackpropFilterOperations(
const OpInfo& op_features, ConvolutionDimensions* returned_conv_dims,
bool* found_unknown_shapes) const {
int64 ops = 0;
- DCHECK_EQ(kConv2dBackpropFilter, op_features.op());
+
+ DCHECK(op_features.op() == kConv2dBackpropFilter ||
+ op_features.op() == kDepthwiseConv2dNativeBackpropFilter)
+ << "Invalid Operation: not kConv2dBackpropFilter nor "
+ "kDepthwiseConv2dNativeBackpropFilter";
TensorShapeProto filter_shape;
bool shape_found = false;
@@ -875,10 +914,15 @@ int64 OpLevelCostEstimator::CountConv2DBackpropFilterOperations(
ops = conv_dims.batch;
ops *= conv_dims.ox * conv_dims.oy;
ops *= conv_dims.kx * conv_dims.ky;
- ops *= conv_dims.iz * conv_dims.oz;
- ops *= kOpsPerMac;
+ if (op_features.op() == kConv2dBackpropFilter) {
+ ops *= conv_dims.iz * conv_dims.oz;
+ } else {
+ // conv_dims always uses the forward-path definition regardless of op.
+ conv_dims.oz *= conv_dims.iz;
+ ops *= conv_dims.oz;
+ }
- VLOG(1) << "Operations for Conv2DBackpropFilter" << ops;
+ VLOG(1) << "Operations for " << op_features.op() << " " << ops;
if (returned_conv_dims != nullptr) {
*returned_conv_dims = conv_dims;
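Aside (not part of the patch): the depthwise op count introduced above can be sanity-checked with a small standalone sketch. ConvDims and CountOps below are hypothetical, simplified stand-ins for ConvolutionDimensions and CountConv2DOperations, and the dimensions are only loosely modeled on the new test (assuming a same-size output); the point is just that a Conv2D whose oz equals iz times the channel multiplier costs iz times more than the depthwise op.

#include <cstdint>
#include <iostream>

constexpr int64_t kOpsPerMac = 2;  // one multiply plus one add per MAC

// Hypothetical, simplified stand-in for the estimator's ConvolutionDimensions.
struct ConvDims {
  int64_t batch, ox, oy, kx, ky, iz, oz;
};

// Conv2D:                ops = N * H * W * R * S * iz * oz * 2
// DepthwiseConv2dNative: oz is the channel multiplier, so the effective
// output depth is iz * oz and ops = N * H * W * R * S * (iz * oz) * 2.
int64_t CountOps(ConvDims d, bool depthwise) {
  int64_t ops = d.batch * d.ox * d.oy * d.kx * d.ky;
  if (depthwise) d.oz *= d.iz;  // oz becomes the effective output depth
  ops *= depthwise ? d.oz : d.iz * d.oz;
  return ops * kOpsPerMac;
}

int main() {
  // Shapes loosely modeled on the DepthwiseConv2dNativeExecutionTime test:
  // batch 16, 19x19 output, 5x5 kernel, 48 input channels, multiplier 3.
  std::cout << CountOps({16, 19, 19, 5, 5, 48, 3}, true) << "\n";
  // An equivalent Conv2D with oz = 48 * 3 = 144 costs iz (= 48) times more.
  std::cout << CountOps({16, 19, 19, 5, 5, 48, 144}, false) << "\n";
  return 0;
}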
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc
index 13ea43bed6..b2c021b73a 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc
@@ -128,6 +128,23 @@ OpContext DescribeConvolution(int batch, int ix, int iy, int iz1, int iz2,
return op_context;
}
+// DescribeDepthwiseConv2dNative constructs an OpContext for a
+// DepthwiseConv2dNative op applied to an input
+// tensor with shape (batch, ix, iy, iz1) and a kernel tensor with shape
+// (kx, ky, iz2, cm), where cm is the channel multiplier.
+
+OpContext DescribeDepthwiseConv2dNative(int batch, int ix, int iy, int iz1,
+ int iz2, int kx, int ky, int cm) {
+ OpContext op_context;
+ SetCpuDevice(&op_context.op_info);
+ op_context.op_info.set_op("DepthwiseConv2dNative");
+
+ DescribeTensor4D(batch, ix, iy, iz1, op_context.op_info.add_inputs());
+ DescribeTensor4D(kx, ky, iz2, cm, op_context.op_info.add_inputs());
+
+ return op_context;
+}
+
// DescribeFusedConv2DBiasActivation constructs an OpContext for a
// FusedConv2DBiasActivation applied to a convolution input tensor with shape
// (batch, ix, iy, iz1), a kernel tensor with shape (kx, ky, iz2, oz), a
@@ -505,6 +522,15 @@ TEST_F(OpLevelCostEstimatorTest, Conv2DExecutionTime) {
EXPECT_FALSE(cost.inaccurate);
}
+TEST_F(OpLevelCostEstimatorTest, DepthwiseConv2dNativeExecutionTime) {
+ auto cost =
+ PredictCosts(DescribeDepthwiseConv2dNative(16, 19, 19, 48, 48, 5, 5, 3));
+ EXPECT_EQ(Costs::Duration(112340), cost.memory_time);
+ EXPECT_EQ(Costs::Duration(4158720), cost.compute_time);
+ EXPECT_EQ(Costs::Duration(4271060), cost.execution_time);
+ EXPECT_FALSE(cost.inaccurate);
+}
+
TEST_F(OpLevelCostEstimatorTest, DummyExecutionTime) {
auto cost = PredictCosts(DescribeBinaryOp("Dummy", 1000, 1));
EXPECT_EQ(Costs::Duration(2000), cost.memory_time);
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index 5b5e1e024e..900dfa95c5 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -604,6 +604,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
+ "//tensorflow/core/grappler:graph_view",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:op_types",
"//tensorflow/core/grappler:utils",
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
index 3f9feac55f..1f6f563687 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
@@ -65,7 +65,7 @@ class ArithmeticOptimizer : public GraphOptimizer {
bool remove_redundant_bitcast = true;
bool remove_redundant_cast = true;
bool remove_negation = true;
- bool hoist_cwise_unary_chains = true;
+ bool hoist_cwise_unary_chains = false;
bool convert_sqrt_div_to_rsqrt_mul = false;
bool remove_idempotent = true;
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index e109e66633..067adb359c 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -696,6 +696,9 @@ TEST_F(ArithmeticOptimizerTest, HoistFactorDiv) {
item.fetch = {"id"};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+ EXPECT_EQ(1, tensors_expected.size());
+
ArithmeticOptimizer optimizer;
EnableOnlyHoistCommonFactor(&optimizer);
@@ -734,6 +737,13 @@ TEST_F(ArithmeticOptimizerTest, HoistFactorDiv) {
EXPECT_EQ("id", id_node->name());
EXPECT_EQ(HoistDivName("add"), id_node->input(0));
}
+ auto tensors = EvaluateNodes(output, item.fetch);
+ EXPECT_EQ(1, tensors.size());
+ if (use_ints) {
+ test::ExpectTensorEqual<int32>(tensors_expected[0], tensors[0]);
+ } else {
+ test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
+ }
}
}
}
@@ -1156,6 +1166,11 @@ TEST_F(ArithmeticOptimizerTest, RemoveIdentityTransposesMultipleOutputs) {
item.fetch = {"outputs"};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({8, 12, 28, 28}));
+ item.feed = {{"inputs", x_t}};
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
+ EXPECT_EQ(1, tensors_expected.size());
+
GraphDef output;
ArithmeticOptimizer optimizer;
EnableOnlyRemoveIdentityTranspose(&optimizer);
@@ -1168,6 +1183,10 @@ TEST_F(ArithmeticOptimizerTest, RemoveIdentityTransposesMultipleOutputs) {
EXPECT_EQ(node.input(2), "Split:2");
}
}
+
+ auto tensors = EvaluateNodes(output, item.fetch, item.feed);
+ EXPECT_EQ(1, tensors.size());
+ test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
TEST_F(ArithmeticOptimizerTest, RemoveTransposesWithControlDependency) {
@@ -1184,6 +1203,11 @@ TEST_F(ArithmeticOptimizerTest, RemoveTransposesWithControlDependency) {
item.fetch = {"outputs"};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 3}));
+ item.feed = {{"Placeholder", x_t}};
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
+ EXPECT_EQ(1, tensors_expected.size());
+
GraphDef output;
ArithmeticOptimizer optimizer;
EnableOnlyRemoveIdentityTranspose(&optimizer);
@@ -1194,6 +1218,10 @@ TEST_F(ArithmeticOptimizerTest, RemoveTransposesWithControlDependency) {
EXPECT_EQ(2, outputs_node->input_size());
EXPECT_EQ(outputs_node->input(0), "outputs_const");
EXPECT_EQ(outputs_node->input(1), "^Placeholder");
+
+ auto tensors = EvaluateNodes(output, item.fetch, item.feed);
+ EXPECT_EQ(1, tensors.size());
+ test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
TEST_F(ArithmeticOptimizerTest, NotRemoveTransposes) {
@@ -1440,6 +1468,11 @@ TEST_F(ArithmeticOptimizerTest, CombineBitcasts) {
item.fetch = {"outputs"};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ auto x_t = GenerateRandomTensor<DT_UINT8>(TensorShape({2, 3}));
+ item.feed = {{"inputs", x_t}};
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
+ EXPECT_EQ(1, tensors_expected.size());
+
GraphDef output;
ArithmeticOptimizer optimizer;
EnableOnlyRemoveRedundantBitcast(&optimizer);
@@ -1451,6 +1484,10 @@ TEST_F(ArithmeticOptimizerTest, CombineBitcasts) {
EXPECT_EQ(3, output.node_size());
EXPECT_EQ(1, CountOpNodes(output, "Bitcast"));
EXPECT_TRUE(IsNodesDirectlyConnected(node_map, "inputs", "bc2"));
+
+ auto tensors = EvaluateNodes(output, item.fetch, item.feed);
+ EXPECT_EQ(1, tensors.size());
+ test::ExpectTensorEqual<int8>(tensors_expected[0], tensors[0]);
}
TEST_F(ArithmeticOptimizerTest, CombineAndRemoveBitcasts) {
@@ -1465,6 +1502,11 @@ TEST_F(ArithmeticOptimizerTest, CombineAndRemoveBitcasts) {
item.fetch = {"outputs"};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ auto x_t = GenerateRandomTensor<DT_INT8>(TensorShape({2, 3}));
+ item.feed = {{"inputs", x_t}};
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
+ EXPECT_EQ(1, tensors_expected.size());
+
GraphDef output;
ArithmeticOptimizer optimizer;
EnableOnlyRemoveRedundantBitcast(&optimizer);
@@ -1476,6 +1518,10 @@ TEST_F(ArithmeticOptimizerTest, CombineAndRemoveBitcasts) {
EXPECT_EQ(2, output.node_size());
EXPECT_EQ(0, CountOpNodes(output, "Bitcast"));
EXPECT_TRUE(IsNodesDirectlyConnected(node_map, "inputs", "outputs"));
+
+ auto tensors = EvaluateNodes(output, item.fetch, item.feed);
+ EXPECT_EQ(1, tensors.size());
+ test::ExpectTensorEqual<int8>(tensors_expected[0], tensors[0]);
}
TEST_F(ArithmeticOptimizerTest, RemoveRedundantCast) {
@@ -1489,6 +1535,11 @@ TEST_F(ArithmeticOptimizerTest, RemoveRedundantCast) {
item.fetch = {"outputs"};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ auto x_t = GenerateRandomTensor<DT_INT8>(TensorShape({2, 3}));
+ item.feed = {{"inputs", x_t}};
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
+ EXPECT_EQ(1, tensors_expected.size());
+
GraphDef output;
ArithmeticOptimizer optimizer;
EnableOnlyRemoveRedundantCast(&optimizer);
@@ -1500,6 +1551,10 @@ TEST_F(ArithmeticOptimizerTest, RemoveRedundantCast) {
EXPECT_EQ(2, output.node_size());
EXPECT_EQ(0, CountOpNodes(output, "Cast"));
EXPECT_TRUE(IsNodesDirectlyConnected(node_map, "inputs", "outputs"));
+
+ auto tensors = EvaluateNodes(output, item.fetch, item.feed);
+ EXPECT_EQ(1, tensors.size());
+ test::ExpectTensorEqual<int8>(tensors_expected[0], tensors[0]);
}
TEST_F(ArithmeticOptimizerTest, AddOpsRewrite_AddOpsOfIdenticalShape) {
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc
index 47d8827686..e6a74dbdcd 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -2370,115 +2370,124 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph,
}
}
- // Partial constant folding for Concat which is not commutative, so
- // we have to preserve order and can only push consecutive runs of constant
- // inputs into sub-nodes.
- if (IsConcat(*node) && num_non_control_inputs > 3 &&
- node->name().rfind("_partial_split_") == string::npos) {
- int axis_arg = -1;
- int begin = 0;
- int end = num_non_control_inputs;
- if (node->op() == "Concat") {
- begin = 1;
- axis_arg = 0;
- } else if (node->op() == "ConcatV2") {
- end = num_non_control_inputs - 1;
- axis_arg = num_non_control_inputs - 1;
- } else {
- continue;
- }
+ if (PartialConcatConstFolding(optimized_graph, properties, node)) {
+ graph_modified_ = true;
+ continue;
+ }
+ }
- const NodeDef* axis_arg_node =
- node_map_->GetNode(NodeName(node->input(axis_arg)));
- if (axis_arg_node == nullptr || !IsReallyConstant(*axis_arg_node)) {
- // We cannot constant fold Concat unless we the axis argument is
- // constant. Skip node.
- continue;
- }
+ return Status::OK();
+}
- // We search for consecutive runs of constant inputs in the range
- // [begin:end[ and push then down into child nodes.
- std::vector<std::pair<int, int>> constant_input_runs;
- int first = begin;
- int last = begin;
- while (last < end) {
- while (first < end && !IsReallyConstant(*node_map_->GetNode(
- NodeName(node->input(first))))) {
- ++first;
- }
- // Invariant: node[first] is constant || first >= end.
- last = first + 1;
- while (last < end && IsReallyConstant(*node_map_->GetNode(
- NodeName(node->input(last))))) {
- ++last;
- }
- // Invariant: node[last] is not constant || last >= end
- // Discard intervals shorter than 2 elements.
- if (first < end && (last - first) > 1) {
- constant_input_runs.emplace_back(first, last);
- }
- first = last;
+bool ConstantFolding::PartialConcatConstFolding(GraphDef* optimized_graph,
+ GraphProperties* properties,
+ NodeDef* node) {
+ // Partial constant folding for Concat, which is not commutative: we have
+ // to preserve order and can only push consecutive runs of constant
+ // inputs into sub-nodes.
+ const int num_non_control_inputs = NumNonControlInputs(*node);
+ if (IsConcat(*node) && num_non_control_inputs > 3 &&
+ node->name().rfind("_partial_split_") == string::npos) {
+ int axis_arg = -1;
+ int begin = 0;
+ int end = num_non_control_inputs;
+ if (node->op() == "Concat") {
+ begin = 1;
+ axis_arg = 0;
+ } else if (node->op() == "ConcatV2") {
+ end = num_non_control_inputs - 1;
+ axis_arg = num_non_control_inputs - 1;
+ } else {
+ return false;
+ }
+
+ const NodeDef* axis_arg_node =
+ node_map_->GetNode(NodeName(node->input(axis_arg)));
+ if (axis_arg_node == nullptr || !IsReallyConstant(*axis_arg_node)) {
+ // We cannot constant fold Concat unless the axis argument is
+ // constant. Skip this node.
+ return false;
+ }
+
+ // We search for consecutive runs of constant inputs in the range
+ // [begin:end[ and push them down into child nodes.
+ std::vector<std::pair<int, int>> constant_input_runs;
+ int first = begin;
+ int last = begin;
+ while (last < end) {
+ while (first < end && !IsReallyConstant(*node_map_->GetNode(
+ NodeName(node->input(first))))) {
+ ++first;
+ }
+ // Invariant: node[first] is constant || first >= end.
+ last = first + 1;
+ while (last < end && IsReallyConstant(*node_map_->GetNode(
+ NodeName(node->input(last))))) {
+ ++last;
}
+ // Invariant: node[last] is not constant || last >= end
+ // Discard intervals shorter than 2 elements.
+ if (first < end && (last - first) > 1) {
+ constant_input_runs.emplace_back(first, last);
+ }
+ first = last;
+ }
- // Skip if all inputs are constant, and let constant folding take over.
- if (constant_input_runs.size() == 1 &&
- constant_input_runs[0].first == begin &&
- constant_input_runs[0].second == end) {
- continue;
+ // Skip if all inputs are constant, and let constant folding take over.
+ if (constant_input_runs.size() == 1 &&
+ constant_input_runs[0].first == begin &&
+ constant_input_runs[0].second == end) {
+ return false;
+ }
+ std::set<int> inputs_to_delete;
+ for (auto interval : constant_input_runs) {
+ // Push the constant inputs in the interval to a child node that can be
+ // constant folded.
+ const string new_node_name = OptimizedNodeName(
+ *node, strings::StrCat("_partial_split_", interval.first));
+ if (node_map_->NodeExists(new_node_name)) {
+ break;
}
- std::set<int> inputs_to_delete;
- for (auto interval : constant_input_runs) {
- // Push the constant inputs in the interval to a child node than can be
- // constant folded.
- const string new_node_name = OptimizedNodeName(
- *node, strings::StrCat("_partial_split_", interval.first));
- if (node_map_->NodeExists(new_node_name)) {
- break;
- }
- NodeDef* added_node = optimized_graph->add_node();
- *added_node = *node;
- added_node->set_name(new_node_name);
- node_map_->AddNode(added_node->name(), added_node);
- added_node->clear_input();
- for (int i = interval.first; i < interval.second; ++i) {
- added_node->add_input(node->input(i));
- node_map_->UpdateOutput(NodeName(node->input(i)), node->name(),
- added_node->name());
- if (i != interval.first) {
- inputs_to_delete.insert(i);
- }
+ NodeDef* added_node = optimized_graph->add_node();
+ *added_node = *node;
+ added_node->set_name(new_node_name);
+ node_map_->AddNode(added_node->name(), added_node);
+ added_node->clear_input();
+ for (int i = interval.first; i < interval.second; ++i) {
+ added_node->add_input(node->input(i));
+ node_map_->UpdateOutput(NodeName(node->input(i)), node->name(),
+ added_node->name());
+ if (i != interval.first) {
+ inputs_to_delete.insert(i);
}
- added_node->add_input(node->input(axis_arg));
- (*added_node->mutable_attr())["N"].set_i(interval.second -
- interval.first);
- node_map_->AddOutput(NodeName(node->input(axis_arg)),
- added_node->name());
-
- // Overwrite the first constant input with the result of the added
- // child node.
- node->set_input(interval.first, added_node->name());
- node_map_->AddOutput(added_node->name(), node->name());
}
- if (!constant_input_runs.empty()) {
- graph_modified_ = true;
- if (!inputs_to_delete.empty()) {
- // Fix up the inputs to the original node.
- std::vector<string> tmp(node->input().begin(), node->input().end());
- node->clear_input();
- for (int i = 0; i < tmp.size(); ++i) {
- if (inputs_to_delete.find(i) == inputs_to_delete.end()) {
- node->add_input(tmp[i]);
- }
+ added_node->add_input(node->input(axis_arg));
+ (*added_node->mutable_attr())["N"].set_i(interval.second -
+ interval.first);
+ node_map_->AddOutput(NodeName(node->input(axis_arg)), added_node->name());
+
+ // Overwrite the first constant input with the result of the added
+ // child node.
+ node->set_input(interval.first, added_node->name());
+ node_map_->AddOutput(added_node->name(), node->name());
+ }
+ if (!constant_input_runs.empty()) {
+ if (!inputs_to_delete.empty()) {
+ // Fix up the inputs to the original node.
+ std::vector<string> tmp(node->input().begin(), node->input().end());
+ node->clear_input();
+ for (int i = 0; i < tmp.size(); ++i) {
+ if (inputs_to_delete.find(i) == inputs_to_delete.end()) {
+ node->add_input(tmp[i]);
}
- (*node->mutable_attr())["N"].set_i(node->input_size() - 1);
- properties->ClearInputProperties(node->name());
}
- continue;
+ (*node->mutable_attr())["N"].set_i(node->input_size() - 1);
+ properties->ClearInputProperties(node->name());
}
+ return true;
}
}
-
- return Status::OK();
+ return false;
}
Status ConstantFolding::RunOptimizationPass(Cluster* cluster,
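Aside (not part of the patch): the core of PartialConcatConstFolding is the run-finding loop moved above. FindConstantRuns below is a hypothetical, self-contained sketch of that loop operating on a plain per-input is-constant flag instead of NodeDefs; runs shorter than two inputs are discarded, exactly as in the optimizer.

#include <iostream>
#include <utility>
#include <vector>

// Collect maximal runs of >= 2 constant value inputs within [begin, end).
std::vector<std::pair<int, int>> FindConstantRuns(
    const std::vector<bool>& is_const, int begin, int end) {
  std::vector<std::pair<int, int>> runs;
  int first = begin;
  while (first < end) {
    while (first < end && !is_const[first]) ++first;  // skip non-const inputs
    int last = first;
    while (last < end && is_const[last]) ++last;      // extend the const run
    if (last - first > 1) runs.emplace_back(first, last);
    first = last;
  }
  return runs;
}

int main() {
  // Value inputs of a ConcatV2: const, const, var, const, var, const, const.
  std::vector<bool> is_const = {true, true, false, true, false, true, true};
  for (const auto& run :
       FindConstantRuns(is_const, 0, static_cast<int>(is_const.size()))) {
    std::cout << "[" << run.first << ", " << run.second << ")\n";  // [0, 2) and [5, 7)
  }
  return 0;
}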
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.h b/tensorflow/core/grappler/optimizers/constant_folding.h
index a694f1721a..2096576538 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.h
+++ b/tensorflow/core/grappler/optimizers/constant_folding.h
@@ -101,6 +101,11 @@ class ConstantFolding : public GraphOptimizer {
Status RunOptimizationPass(Cluster* cluster, const GrapplerItem& item,
GraphDef* output);
+ // Applies partial constant folding for Concat, which is not commutative.
+ // Returns true if the transformation was applied successfully.
+ bool PartialConcatConstFolding(GraphDef* optimized_graph,
+ GraphProperties* properties, NodeDef* node);
+
// Points to an externally provided device or to owned_device_;
RewriterConfig::Toggle opt_level_;
DeviceBase* cpu_device_;
diff --git a/tensorflow/core/grappler/optimizers/function_optimizer.cc b/tensorflow/core/grappler/optimizers/function_optimizer.cc
index 1bec9086f7..a44e1ee7f9 100644
--- a/tensorflow/core/grappler/optimizers/function_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/function_optimizer.cc
@@ -14,10 +14,13 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/optimizers/function_optimizer.h"
+
#include <unordered_map>
+
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
+#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/graph_def_util.h"
@@ -74,6 +77,73 @@ string UniqueSpecializedFunctionName(const FunctionDef& func,
return unique_name;
}
+// Specialized function instantiation type parameters, body parameters, and
+// const inputs.
+struct FunctionSpecializationSignature {
+ string func_name;
+ std::unordered_map<string, DataType> type_parameters;
+ std::unordered_map<string, AttrValue> body_parameters;
+ std::unordered_map<int, string> const_inputs;
+
+ bool operator==(const FunctionSpecializationSignature& other) const {
+ bool equals = func_name == other.func_name &&
+ type_parameters == other.type_parameters &&
+ const_inputs == other.const_inputs;
+
+ if (!equals) return false;
+
+ // Equality is not defined for AttrValue.
+ if (body_parameters.size() != other.body_parameters.size()) return false;
+
+ for (const auto& lhs : body_parameters) {
+ auto it = other.body_parameters.find(lhs.first);
+ if (it == other.body_parameters.end()) return false;
+ if (!AreAttrValuesEqual(lhs.second, (*it).second)) return false;
+ }
+
+ return true;
+ }
+
+ struct Hash {
+ uint64 operator()(FunctionSpecializationSignature const& s) const {
+ uint64 h = Hash64(s.func_name);
+
+ // Use std::map for deterministic iteration order.
+
+ std::map<string, DataType> types(s.type_parameters.begin(),
+ s.type_parameters.end());
+ for (const auto& pair : types) {
+ AttrValue attr_value;
+ attr_value.set_type(pair.second);
+ h = Hash64Combine(Hash64(pair.first), h);
+ h = Hash64Combine(AttrValueHash(attr_value), h);
+ }
+
+ std::map<string, AttrValue> body(s.body_parameters.begin(),
+ s.body_parameters.end());
+ for (const auto& pair : body) {
+ h = Hash64Combine(Hash64(pair.first), h);
+ h = Hash64Combine(AttrValueHash(pair.second), h);
+ }
+
+ std::map<int, string> inputs(s.const_inputs.begin(),
+ s.const_inputs.end());
+ for (const auto& pair : inputs) {
+ h = Hash64Combine(std::hash<int>()(pair.first), h);
+ h = Hash64Combine(Hash64(pair.second), h);
+ }
+
+ return h;
+ }
+ };
+};
+
+struct FunctionSpecialization {
+ string specialized_func_name;
+ std::unordered_set<string> const_inputs;
+ std::unordered_set<string> control_deps;
+};
+
class FunctionOptimizerContext {
public:
explicit FunctionOptimizerContext(RewriterConfig::Toggle opt_level,
@@ -108,6 +178,16 @@ class FunctionOptimizerContext {
return gtl::FindWithDefault(inlined_functions_, name, nullptr);
}
+ const FunctionSpecialization* FindFunctionSpecialization(
+ const FunctionSpecializationSignature& sig) const {
+ return gtl::FindOrNull(specialized_functions_, sig);
+ }
+
+ void AddSpecializedFunction(const FunctionSpecializationSignature& sig,
+ const FunctionSpecialization& specialized_func) {
+ specialized_functions_.emplace(sig, specialized_func);
+ }
+
private:
void InitializeTrulyConstNodes(const GrapplerItem& item) {
std::unordered_set<string> feed_nodes;
@@ -148,6 +228,12 @@ class FunctionOptimizerContext {
// Nodes that are Const and not in feed.
std::unordered_map<string, const NodeDef*> truly_const_nodes_;
+ // Specialized functions.
+ std::unordered_map<FunctionSpecializationSignature,
+ const FunctionSpecialization,
+ FunctionSpecializationSignature::Hash>
+ specialized_functions_;
+
TF_DISALLOW_COPY_AND_ASSIGN(FunctionOptimizerContext);
};
@@ -303,14 +389,34 @@ void RemovePushedDownConstInputs(const std::unordered_set<string>& const_inputs,
for (const string& ctrl : control_deps) {
if (existing_control_deps.find(ctrl) == existing_control_deps.end()) {
- VLOG(3) << "Forward control dependency to function caller node: input="
- << ctrl;
+ VLOG(3) << "Forward control dependency: input=" << ctrl;
specialized_func_node->add_input(ctrl);
}
}
}
}
+Status InitializeFunctionSpecializationSignature(
+ const NodeDef& func_node, const FunctionDef& func,
+ const AttrValueMap& func_attr, const FunctionOptimizerContext& ctx,
+ FunctionSpecializationSignature* sig) {
+ sig->func_name = func.signature().name();
+
+ TF_RETURN_IF_ERROR(
+ InstantiationTypeParameters(func, func_attr, &sig->type_parameters));
+ TF_RETURN_IF_ERROR(
+ InstantiationBodyParameters(func, func_attr, &sig->body_parameters));
+
+ for (int i = 0; i < func_node.input_size(); ++i) {
+ const string& input = func_node.input(i);
+ if (ctx.IsTrulyConst(input)) {
+ sig->const_inputs.emplace(i, input);
+ }
+ }
+
+ return Status::OK();
+}
+
Status SpecializeFunction(const NodeDef& func_node, const FunctionDef& func,
FunctionOptimizerContext* ctx,
GraphDef* optimized_graph) {
@@ -320,6 +426,32 @@ Status SpecializeFunction(const NodeDef& func_node, const FunctionDef& func,
const std::unordered_map<string, AttrValue> func_attr(
func_node.attr().begin(), func_node.attr().end());
+ FunctionSpecializationSignature signature;
+ TF_RETURN_IF_ERROR(InitializeFunctionSpecializationSignature(
+ func_node, func, func_attr, *ctx, &signature));
+
+ // Check if the function was already specialized for an identical context.
+ const FunctionSpecialization* already_specialized =
+ ctx->FindFunctionSpecialization(signature);
+
+ if (already_specialized) {
+ VLOG(2) << "Function was already specialized in identical context: "
+ "specialized_name="
+ << already_specialized->specialized_func_name;
+
+ // Add a function call node for the specialized function.
+ NodeDef* specialized_func_node = optimized_graph->add_node();
+ *specialized_func_node = func_node;
+ specialized_func_node->set_op(already_specialized->specialized_func_name);
+
+ RemovePushedDownConstInputs(already_specialized->const_inputs,
+ already_specialized->control_deps,
+ specialized_func_node);
+
+ return Status::OK();
+ }
+
+ // Add a new specialized function definition to the library.
const auto& flib = ctx->function_library();
// Make a GrapplerFunctionItem and convert it back to FunctionDef after
@@ -358,6 +490,10 @@ Status SpecializeFunction(const NodeDef& func_node, const FunctionDef& func,
// Update specialized node to remove inputs for pushed down consts.
RemovePushedDownConstInputs(const_inputs, control_deps,
specialized_func_node);
+
+ ctx->AddSpecializedFunction(
+ signature, {specialized_func_name, const_inputs, control_deps});
+
return Status::OK();
}
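Aside (not part of the patch): the specialization cache added above hinges on a composite key with custom equality and hashing. The sketch below shows the same pattern with plain standard-library types; SpecKey and HashCombine are hypothetical stand-ins for FunctionSpecializationSignature and Hash64Combine, and the sorted std::map keeps the hash deterministic just as the Hash functor above does.

#include <cstdint>
#include <functional>
#include <map>
#include <string>
#include <unordered_map>

struct SpecKey {
  std::string func_name;
  std::map<std::string, int> type_parameters;  // sorted => deterministic hash

  bool operator==(const SpecKey& other) const {
    return func_name == other.func_name &&
           type_parameters == other.type_parameters;
  }

  struct Hash {
    static uint64_t HashCombine(uint64_t a, uint64_t b) {
      return a ^ (b + 0x9e3779b97f4a7c15ULL + (a << 6) + (a >> 2));
    }
    uint64_t operator()(const SpecKey& key) const {
      uint64_t h = std::hash<std::string>()(key.func_name);
      for (const auto& p : key.type_parameters) {
        h = HashCombine(h, std::hash<std::string>()(p.first));
        h = HashCombine(h, std::hash<int>()(p.second));
      }
      return h;
    }
  };
};

int main() {
  std::unordered_map<SpecKey, std::string, SpecKey::Hash> specialized;
  SpecKey key{"MyMul", {{"T", 1 /* DT_FLOAT stand-in */}}};
  specialized.emplace(key, "MyMul_specialized_for_mul_1");
  // A second call node with an identical context reuses the cached name.
  auto it = specialized.find(key);
  return it == specialized.end() ? 1 : 0;
}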
diff --git a/tensorflow/core/grappler/optimizers/function_optimizer_test.cc b/tensorflow/core/grappler/optimizers/function_optimizer_test.cc
index 147a264421..0aaf57e947 100644
--- a/tensorflow/core/grappler/optimizers/function_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/function_optimizer_test.cc
@@ -718,5 +718,147 @@ TEST_F(FunctionOptimizerTest, SpecializeFunction_PushDownConstInput) {
test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
}
+TEST_F(FunctionOptimizerTest, SpecializeFunction_OncePerUniqueContext) {
+ using test::function::NDef;
+
+ FunctionOptimizer optimizer(RewriterConfig::DEFAULT);
+
+ // Mark MyMul as noinline.
+ FunctionDef mul_func = FunctionDefHelper::Create(
+ "MyMul", {"x:T", "y:T"}, {"z:T"}, {"T: {float, int32}"},
+ {{{"output"}, "Mul", {"x", "y"}, {{"T", "$T"}}}},
+ /* Mapping between function returns and function node outputs. */
+ {{"z", "output:z:0"}});
+ (*mul_func.mutable_attr())["_noinline"].set_b(true);
+ std::vector<FunctionDef> function_library = {mul_func};
+
+ const Tensor kTwo = test::AsScalar<float>(2.0);
+ const Tensor kThree = test::AsScalar<float>(3.0);
+
+ GrapplerItem item;
+ item.graph = test::function::GDef(
+ {NDef("init", "NoOp", {}, {}, kDevice),
+
+ // Float placeholders.
+ NDef("xf", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
+ NDef("yf", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
+
+ // Int32 placeholders.
+ NDef("xi", "Placeholder", {}, {{"dtype", DT_INT32}}, kDevice),
+ NDef("yi", "Placeholder", {}, {{"dtype", DT_INT32}}, kDevice),
+
+ // Consts. Control inputs have to be attached to specialized func calls.
+ NDef("two", "Const", {"^init", "^xf"},
+ {{"dtype", DT_FLOAT}, {"value", kTwo}}, kDevice),
+ NDef("three", "Const", {"^init", "^xf"},
+ {{"dtype", DT_FLOAT}, {"value", kThree}}, kDevice),
+
+ // Specialization #1: DT_FLOAT type parameter.
+ NDef("mul_1", "MyMul", {"xf", "yf"}, {{"T", DT_FLOAT}}, kDevice),
+ NDef("mul_2", "MyMul", {"yf", "xf"}, {{"T", DT_FLOAT}}, kDevice),
+
+ // Specialization #2: DT_INT32 type parameter.
+ NDef("mul_3", "MyMul", {"xi", "yi"}, {{"T", DT_INT32}}, kDevice),
+
+ // Specialization #3: DT_FLOAT type parameter + const input kTwo.
+ NDef("mul_4", "MyMul", {"xf", "two"}, {{"T", DT_FLOAT}}, kDevice),
+ NDef("mul_5", "MyMul", {"yf", "two"}, {{"T", DT_FLOAT}}, kDevice),
+
+ // Specialization #4: DT_FLOAT type parameter + const input kThree.
+ NDef("mul_6", "MyMul", {"three", "xf"}, {{"T", DT_FLOAT}}, kDevice)},
+ function_library);
+
+ GraphDef output;
+ TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ // Make sure that MyMul was specialized once per unique context.
+ EXPECT_EQ(4, output.library().function_size());
+
+ // And graph nodes calling specialized functions.
+ int count = 0;
+ for (const NodeDef& node : output.node()) {
+ if (node.name() == "mul_1" && count++) {
+ EXPECT_EQ("MyMul_specialized_for_mul_1", node.op());
+ ASSERT_EQ(2, node.input_size());
+ EXPECT_EQ("xf", node.input(0));
+ EXPECT_EQ("yf", node.input(1));
+
+ } else if (node.name() == "mul_2" && count++) {
+ EXPECT_EQ("MyMul_specialized_for_mul_1", node.op());
+ ASSERT_EQ(2, node.input_size());
+ EXPECT_EQ("yf", node.input(0));
+ EXPECT_EQ("xf", node.input(1));
+
+ } else if (node.name() == "mul_3" && count++) {
+ EXPECT_EQ("MyMul_specialized_for_mul_3", node.op());
+ ASSERT_EQ(2, node.input_size());
+ EXPECT_EQ("xi", node.input(0));
+ EXPECT_EQ("yi", node.input(1));
+
+ } else if (node.name() == "mul_4" && count++) {
+ EXPECT_EQ("MyMul_specialized_for_mul_4", node.op());
+ ASSERT_EQ(2, node.input_size());
+ EXPECT_EQ("xf", node.input(0));
+ EXPECT_EQ("^init", node.input(1));
+
+ } else if (node.name() == "mul_5" && count++) {
+ EXPECT_EQ("MyMul_specialized_for_mul_4", node.op());
+ ASSERT_EQ(3, node.input_size());
+ EXPECT_EQ("yf", node.input(0));
+ EXPECT_EQ("^init", node.input(1));
+ EXPECT_EQ("^xf", node.input(2));
+
+ } else if (node.name() == "mul_6" && count++) {
+ EXPECT_EQ("MyMul_specialized_for_mul_6", node.op());
+ ASSERT_EQ(2, node.input_size());
+ EXPECT_EQ("xf", node.input(0));
+ EXPECT_EQ("^init", node.input(1));
+ }
+ }
+ EXPECT_EQ(6, count);
+
+ // And that graph evaluation yields the same result.
+ Tensor pi = test::AsScalar<float>(3.14f);
+ Tensor four = test::AsScalar<int32>(4);
+ item.fetch = {"mul_1", "mul_2", "mul_3", "mul_4", "mul_5", "mul_6"};
+ item.feed = {{"xf", pi}, {"yf", pi}, {"xi", four}, {"yi", four}};
+
+ auto tensors_expected = EvaluateFetchNodes(item);
+ GrapplerItem optimized(item, std::move(output));
+ auto tensors = EvaluateFetchNodes(optimized);
+
+ test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
+ test::ExpectTensorEqual<float>(tensors_expected[1], tensors[1]);
+ test::ExpectTensorEqual<int32>(tensors_expected[2], tensors[2]);
+ test::ExpectTensorEqual<float>(tensors_expected[3], tensors[3]);
+ test::ExpectTensorEqual<float>(tensors_expected[4], tensors[4]);
+ test::ExpectTensorEqual<float>(tensors_expected[5], tensors[5]);
+}
+
+TEST_F(FunctionOptimizerTest, PruningUselessLibraryFunctions) {
+ using test::function::NDef;
+ FunctionOptimizer optimizer(RewriterConfig::DEFAULT);
+ DisableFunctionSpecialization(&optimizer);
+ auto func = test::function::XTimesTwo();
+ (*func.mutable_attr())["_noinline"].set_b(true);
+ GrapplerItem item;
+ item.graph = test::function::GDef(
+ {NDef("x", "Placeholder", {}, {{"dtype", DT_FLOAT}}, "/device:CPU:0"),
+ NDef("y", "XTimesTwo", {"x"}, {{"T", DT_FLOAT}}, "/device:CPU:0"),
+ NDef("z", "Identity", {"y"}, {{"T", DT_FLOAT}}, "/device:CPU:0")},
+ // FunctionLib
+ {
+ func,
+ test::function::XTimesTwoInt32(),
+ test::function::XTimes16(),
+ });
+ GraphDef output;
+ Status status = optimizer.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
+
+ EXPECT_EQ(output.library().function().size(), 1);
+ EXPECT_EQ(output.library().function(0).signature().name(), "XTimesTwo");
+}
+
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/loop_optimizer.cc b/tensorflow/core/grappler/optimizers/loop_optimizer.cc
index 5adc5b9227..7d3520febc 100644
--- a/tensorflow/core/grappler/optimizers/loop_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/loop_optimizer.cc
@@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/grappler/graph_view.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/constant_folding.h"
@@ -504,6 +505,140 @@ Status RemoveStackOps(const std::unordered_set<string>& nodes_to_preserve,
return Status::OK();
}
+Status RemoveDeadBranches(const std::unordered_set<string>& nodes_to_preserve,
+ GraphDef* optimized_graph) {
+ std::unordered_set<const NodeDef*> dead_nodes;
+ std::unordered_map<NodeDef*, std::set<int>> dead_merge_inputs;
+ // TODO(bsteiner): also rewrite switches as identity. For now we just
+ // record them.
+ std::unordered_set<GraphView::OutputPort, GraphView::HashPort>
+ identity_switches;
+
+ GraphView view(optimized_graph);
+ for (const NodeDef& node : optimized_graph->node()) {
+ if (!IsSwitch(node)) {
+ continue;
+ }
+ if (nodes_to_preserve.find(node.name()) != nodes_to_preserve.end()) {
+ continue;
+ }
+ GraphView::InputPort ctrl_port(&node, 1);
+ GraphView::OutputPort ctrl_node = view.GetRegularFanin(ctrl_port);
+ if (!IsConstant(*ctrl_node.node)) {
+ continue;
+ }
+ Tensor selector;
+ CHECK(selector.FromProto(ctrl_node.node->attr().at("value").tensor()));
+ const int dead_fanout = selector.scalar<bool>()() ? 0 : 1;
+ GraphView::OutputPort dead(const_cast<NodeDef*>(&node), dead_fanout);
+ identity_switches.insert(dead);
+
+ SetVector<GraphView::InputPort, GraphView::HashPort> zombie_inputs;
+ for (const GraphView::InputPort& port : view.GetFanout(dead)) {
+ if (dead_nodes.find(port.node) == dead_nodes.end()) {
+ zombie_inputs.PushBack(port);
+ }
+ }
+ // If we encounter a single node that must be preserved in the fanout of the
+ // switch node, we need to preserve the entire switch fanout: we therefore
+ // work on a local copy that only gets committed to the master copy once the
+ // whole fanout has been explored.
+ std::unordered_set<const NodeDef*> local_dead_nodes = dead_nodes;
+ std::unordered_map<NodeDef*, std::set<int>> local_dead_merge_inputs =
+ dead_merge_inputs;
+ bool found_node_to_preserve = false;
+ while (!found_node_to_preserve && !zombie_inputs.Empty()) {
+ GraphView::InputPort dead = zombie_inputs.PopBack();
+ if (nodes_to_preserve.find(dead.node->name()) !=
+ nodes_to_preserve.end()) {
+ found_node_to_preserve = true;
+ break;
+ }
+
+ if (local_dead_nodes.find(dead.node) != local_dead_nodes.end()) {
+ continue;
+ }
+
+ if (IsMerge(*dead.node)) {
+ const int fanout = dead.node->attr().at("N").i();
+ if (fanout > 2) {
+ // This never happens in practice, so we'll just skip these to
+ // simplify the code for now.
+ found_node_to_preserve = true;
+ break;
+ }
+ GraphView::OutputPort value_index(dead.node, 1);
+ const std::unordered_set<GraphView::InputPort, GraphView::HashPort>&
+ index_fanout = view.GetFanout(value_index);
+ if (!index_fanout.empty()) {
+ // The 2nd output (that indicates which input is propagated) is
+ // connected. This never happens in practice, so we'll just skip this
+ // case to simplify the code for now.
+ found_node_to_preserve = true;
+ break;
+ }
+
+ bool fully_dead = false;
+ if (dead.port_id < 0) {
+ // If the control dependency never gets triggered the merge will also
+ // never get triggered.
+ local_dead_nodes.insert(dead.node);
+ fully_dead = true;
+ } else {
+ local_dead_merge_inputs[dead.node].insert(dead.port_id);
+ if (local_dead_merge_inputs[dead.node].size() ==
+ dead.node->attr().at("N").i()) {
+ fully_dead = true;
+ }
+ if (fully_dead) {
+ local_dead_nodes.insert(dead.node);
+ for (const GraphView::InputPort& port :
+ view.GetFanouts(*dead.node, true)) {
+ zombie_inputs.PushBack(port);
+ }
+ }
+ }
+ } else {
+ if (local_dead_nodes.insert(dead.node).second) {
+ for (const GraphView::InputPort& dead_fanout :
+ view.GetFanouts(*dead.node, true)) {
+ zombie_inputs.PushBack(dead_fanout);
+ }
+ }
+ }
+ }
+ if (!found_node_to_preserve) {
+ std::swap(dead_nodes, local_dead_nodes);
+ std::swap(dead_merge_inputs, local_dead_merge_inputs);
+ }
+ }
+
+ int last = optimized_graph->node_size() - 1;
+ for (int i = optimized_graph->node_size() - 1; i >= 0; --i) {
+ NodeDef* node = optimized_graph->mutable_node(i);
+ if (dead_nodes.find(node) != dead_nodes.end()) {
+ optimized_graph->mutable_node()->SwapElements(i, last);
+ last--;
+ }
+ }
+ optimized_graph->mutable_node()->DeleteSubrange(last + 1, dead_nodes.size());
+
+ for (const auto& itr : dead_merge_inputs) {
+ NodeDef* dead_node = itr.first;
+ if (dead_nodes.find(dead_node) != dead_nodes.end()) {
+ // The node has been pruned since all its inputs are dead.
+ continue;
+ }
+ const std::set<int>& dead_inputs = itr.second;
+ for (int index : dead_inputs) {
+ dead_node->mutable_input()->DeleteSubrange(index, 1);
+ }
+ dead_node->set_op("Identity");
+ dead_node->mutable_attr()->erase("N");
+ }
+ return Status::OK();
+}
+
} // namespace
Status LoopOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
@@ -517,6 +652,11 @@ Status LoopOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
if (options_.enable_stack_push_removal) {
TF_RETURN_IF_ERROR(RemoveStackOps(item.NodesToPreserve(), optimized_graph));
}
+ if (opt_level_ == RewriterConfig::AGGRESSIVE &&
+ options_.enable_dead_branch_removal) {
+ TF_RETURN_IF_ERROR(
+ RemoveDeadBranches(item.NodesToPreserve(), optimized_graph));
+ }
return Status::OK();
}
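Aside (not part of the patch): RemoveDeadBranches is essentially a worklist traversal from the dead output of a Switch whose predicate is a constant. The toy sketch below shows only that traversal on an integer adjacency list; it is a hypothetical simplification that ignores the Merge special-casing and the GraphView port types handled above.

#include <iostream>
#include <queue>
#include <set>
#include <vector>

// Mark every node transitively reachable from dead_root as dead, unless a
// node that must be preserved is reached, in which case nothing is pruned.
std::set<int> PropagateDead(const std::vector<std::vector<int>>& fanout,
                            int dead_root, const std::set<int>& preserve) {
  std::set<int> dead;
  std::queue<int> work;
  work.push(dead_root);
  while (!work.empty()) {
    int n = work.front();
    work.pop();
    if (preserve.count(n)) return {};       // keep the whole branch
    if (!dead.insert(n).second) continue;   // already visited
    for (int next : fanout[n]) work.push(next);
  }
  return dead;
}

int main() {
  // 0: dead Switch output, 1: Square, 2: downstream consumer, 3: fetch node.
  std::vector<std::vector<int>> fanout = {{1}, {2}, {}, {}};
  std::set<int> dead = PropagateDead(fanout, 0, /*preserve=*/{3});
  std::cout << dead.size() << " nodes pruned\n";  // prints "3 nodes pruned"
  return 0;
}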
diff --git a/tensorflow/core/grappler/optimizers/loop_optimizer.h b/tensorflow/core/grappler/optimizers/loop_optimizer.h
index 764506f7c1..85b8e65543 100644
--- a/tensorflow/core/grappler/optimizers/loop_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/loop_optimizer.h
@@ -54,6 +54,7 @@ class LoopOptimizer : public GraphOptimizer {
struct LoopOptimizerOptions {
bool enable_loop_invariant_node_motion = false;
bool enable_stack_push_removal = true;
+ bool enable_dead_branch_removal = true;
static LoopOptimizerOptions Default(RewriterConfig::Toggle opt_level) {
LoopOptimizerOptions options;
diff --git a/tensorflow/core/grappler/optimizers/loop_optimizer_test.cc b/tensorflow/core/grappler/optimizers/loop_optimizer_test.cc
index 10ec544424..6fd177b710 100644
--- a/tensorflow/core/grappler/optimizers/loop_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/loop_optimizer_test.cc
@@ -589,5 +589,112 @@ TEST_F(LoopOptimizerTest, RemovePushWithoutMatchingPop) {
}
}
+TEST_F(LoopOptimizerTest, RemoveDeadBranches) {
+ Scope scope = Scope::NewRootScope();
+ Output v_in = ops::Variable(scope.WithOpName("v_in"), {3}, DT_FLOAT);
+
+ Output ctrl1 = ops::Const(scope.WithOpName("ctrl1"), false, TensorShape({}));
+ ops::Switch s1(scope.WithOpName("switch1"), v_in, ctrl1);
+ Output square1 = ops::Square(scope.WithOpName("square1"), s1.output_false);
+ Output sqrt1 = ops::Sqrt(scope.WithOpName("sqrt1"), s1.output_true);
+
+ Output ctrl2 = ops::Const(scope.WithOpName("ctrl2"), true, TensorShape({}));
+ ops::Switch s2(scope.WithOpName("switch2"), v_in, ctrl2);
+ Output square2 = ops::Square(scope.WithOpName("square2"), s2.output_false);
+ Output sqrt2 = ops::Sqrt(scope.WithOpName("sqrt2"), s2.output_true);
+
+ Output ctrl3 = ops::Const(scope.WithOpName("ctrl3"), false, TensorShape({}));
+ ops::Switch s3(scope.WithOpName("switch3"), v_in, ctrl3);
+ Output square3 = ops::Square(scope.WithOpName("square3"), s3.output_false);
+ Output sqrt3 = ops::Sqrt(scope.WithOpName("sqrt3"), s3.output_true);
+
+ Output ctrl4 = ops::Const(scope.WithOpName("ctrl4"), false, TensorShape({}));
+ ops::Switch s4(scope.WithOpName("switch4"), v_in, ctrl4);
+ Output square4 = ops::Square(scope.WithOpName("square4"), s4.output_false);
+ Output sqrt4 = ops::Sqrt(scope.WithOpName("sqrt4"), s4.output_true);
+
+ ops::Merge m1(scope.WithOpName("m1"), {square1, sqrt1});
+ ops::Merge m2(scope.WithOpName("m2"), {v_in, square1});
+ ops::Merge m3(scope.WithOpName("m3"), {v_in, sqrt1});
+ ops::Merge m4(scope.WithOpName("m4"), {square1, sqrt2});
+ ops::Merge m5(scope.WithOpName("m5"), {square2, sqrt1});
+ ops::Merge m6(scope.WithOpName("m6").WithControlDependencies(sqrt2),
+ {v_in, square1});
+ ops::Merge m7(scope.WithOpName("m7").WithControlDependencies(sqrt1),
+ {v_in, square1});
+
+ ops::Switch s5(scope.WithOpName("switch5"), v_in, ctrl1);
+ Output id1 = ops::Identity(scope.WithOpName("id1"), s5.output_false);
+ Output id2 = ops::Identity(scope.WithOpName("id2"), s5.output_true);
+ ops::Merge m8(scope.WithOpName("m8"), {id1, id2});
+
+ ops::Switch s6(scope.WithOpName("switch6"), v_in, ctrl1);
+ Output id3 = ops::Identity(scope.WithOpName("id3"), s6.output_false);
+ Output id4 = ops::Identity(scope.WithOpName("id4"), s6.output_true);
+ ops::Merge m9(scope.WithOpName("m9"), {id3, id4});
+
+ GrapplerItem item;
+ item.fetch.push_back("m8");
+ item.fetch.push_back("id4");
+
+ TF_CHECK_OK(scope.ToGraphDef(&item.graph));
+
+ LoopOptimizer optimizer(RewriterConfig::AGGRESSIVE);
+ GraphDef output;
+ Status status = optimizer.Optimize(nullptr, item, &output);
+ TF_CHECK_OK(status);
+
+ for (const NodeDef& node : output.node()) {
+ // These nodes should have been pruned
+ EXPECT_NE("Square1", node.name());
+ EXPECT_NE("Sqrt2", node.name());
+ EXPECT_NE("m5", node.name());
+ EXPECT_NE("m7", node.name());
+
+ if (node.name() == "m1") {
+ // sqrt1 is dead
+ EXPECT_EQ("Identity", node.op());
+ EXPECT_EQ(1, node.input_size());
+ EXPECT_EQ("square1", node.input(0));
+ } else if (node.name() == "m2") {
+ // both inputs are alive
+ EXPECT_EQ("Merge", node.op());
+ EXPECT_EQ(2, node.input_size());
+ EXPECT_EQ("v_in", node.input(0));
+ EXPECT_EQ("square1", node.input(1));
+ } else if (node.name() == "m3") {
+ // sqrt1 is dead
+ EXPECT_EQ("Identity", node.op());
+ EXPECT_EQ(1, node.input_size());
+ EXPECT_EQ("v_in", node.input(0));
+ } else if (node.name() == "m4") {
+ // both inputs are alive
+ EXPECT_EQ("Merge", node.op());
+ EXPECT_EQ(2, node.input_size());
+ EXPECT_EQ("square1", node.input(0));
+ EXPECT_EQ("sqrt2", node.input(1));
+ } else if (node.name() == "m6") {
+ // both inputs are alive and the control dependency can get triggered
+ EXPECT_EQ("Merge", node.op());
+ EXPECT_EQ(3, node.input_size());
+ EXPECT_EQ("v_in", node.input(0));
+ EXPECT_EQ("square1", node.input(1));
+ EXPECT_EQ("^sqrt2", node.input(2));
+ } else if (node.name() == "m8") {
+ // The node is to be preserved because of a fetch
+ EXPECT_EQ("Merge", node.op());
+ EXPECT_EQ(2, node.input_size());
+ EXPECT_EQ("id1", node.input(0));
+ EXPECT_EQ("id2", node.input(1));
+ } else if (node.name() == "m9") {
+ // The node is to be preserved because of a fetch
+ EXPECT_EQ("Merge", node.op());
+ EXPECT_EQ(2, node.input_size());
+ EXPECT_EQ("id3", node.input(0));
+ EXPECT_EQ("id4", node.input(1));
+ }
+ }
+}
+
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc
index 887a988af9..8247cce339 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc
@@ -163,30 +163,28 @@ TEST_F(MetaOptimizerTest, OptimizeFunctionLibrary) {
output.library());
// Specialized and optimized functions should be added to the graph.
- EXPECT_EQ(6, optimized_flib.num_functions());
+ EXPECT_EQ(5, optimized_flib.num_functions());
// MyQuadratic should be specialized once:
// 0. 'quadratic' node in the main graph
const string optimized_0 = "MyQuadratic_specialized_for_quadratic";
// MySquare should be specialized and optimized for 3 instantiations:
- // 1. 'square' node in the main graph
- // 2. 'square' node in the MyQuadratic specialization
- // 3. 'quadratic' node in the MyQuadratic specialization
+ // 1. 'square' node in the main graph
+ // 2. 'square' node in the MyQuadratic specialization
+ // 3*. 'quadratic' node in the MyQuadratic specialization, which
+ // has an identical instantiation context to #2
const string optimized_1 = "MySquare_specialized_for_square";
const string optimized_2 = "MySquare_specialized_for_square_1";
- const string optimized_3 = "MySquare_specialized_for_quadratic";
const FunctionDef* optimized_func_0 = optimized_flib.Find(optimized_0);
const FunctionDef* optimized_func_1 = optimized_flib.Find(optimized_1);
const FunctionDef* optimized_func_2 = optimized_flib.Find(optimized_2);
- const FunctionDef* optimized_func_3 = optimized_flib.Find(optimized_3);
ASSERT_NE(optimized_func_0, nullptr);
ASSERT_NE(optimized_func_1, nullptr);
ASSERT_NE(optimized_func_2, nullptr);
- ASSERT_NE(optimized_func_3, nullptr);
// Graph should call optimized function.
int count = 0;
@@ -205,13 +203,14 @@ TEST_F(MetaOptimizerTest, OptimizeFunctionLibrary) {
if (node.name() == "square" && count++) {
EXPECT_EQ(optimized_2, node.op());
} else if (node.name() == "quadratic" && count++) {
- EXPECT_EQ(optimized_3, node.op());
+ // Share specialized function with the 'square' node.
+ EXPECT_EQ(optimized_2, node.op());
}
}
EXPECT_EQ(2, count);
- const std::vector<const FunctionDef*> optimized_funcs = {
- optimized_func_1, optimized_func_1, optimized_func_3};
+ const std::vector<const FunctionDef*> optimized_funcs = {optimized_func_1,
+ optimized_func_2};
// MyMul should be inlined into all optimized versions of MySquare.
for (const FunctionDef* optimized_func : optimized_funcs) {
diff --git a/tensorflow/core/grappler/utils.h b/tensorflow/core/grappler/utils.h
index b87ae05546..1c6fef59ea 100644
--- a/tensorflow/core/grappler/utils.h
+++ b/tensorflow/core/grappler/utils.h
@@ -65,7 +65,7 @@ class NodeMap {
// A vector with a set. The set stores the same elements as the vector, and
// quickly answers whether a value is in the vector. Duplicated elements are not
// allowed for now.
-template <class T>
+template <class T, class Hash = std::hash<T>>
class SetVector {
public:
// Returns false if value already existed in the set, true otherwise.
@@ -91,7 +91,7 @@ class SetVector {
void Reserve(int64 size) { vector_.reserve(size); }
private:
- std::unordered_set<T> set_;
+ std::unordered_set<T, Hash> set_;
std::vector<T> vector_;
};
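Aside (not part of the patch): the new Hash template parameter exists so SetVector can key on types such as GraphView ports that have no std::hash specialization. Below is a minimal, hypothetical stand-in showing the same shape of API; Port and HashPort are illustrative, not the grappler types.

#include <string>
#include <unordered_set>
#include <vector>

// Vector-backed set: preserves insertion order and rejects duplicates.
template <class T, class Hash = std::hash<T>>
class SetVector {
 public:
  bool PushBack(const T& value) {
    if (!set_.insert(value).second) return false;  // duplicate, not added
    vector_.push_back(value);
    return true;
  }
  T PopBack() {
    T back = vector_.back();
    set_.erase(back);
    vector_.pop_back();
    return back;
  }
  bool Empty() const { return vector_.empty(); }

 private:
  std::unordered_set<T, Hash> set_;
  std::vector<T> vector_;
};

struct Port {
  const void* node;
  int port_id;
  bool operator==(const Port& o) const {
    return node == o.node && port_id == o.port_id;
  }
};
struct HashPort {
  size_t operator()(const Port& p) const {
    return std::hash<const void*>()(p.node) ^ std::hash<int>()(p.port_id);
  }
};

int main() {
  SetVector<Port, HashPort> worklist;
  Port p{nullptr, 0};
  worklist.PushBack(p);
  return worklist.PushBack(p) ? 1 : 0;  // duplicate insert is rejected
}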
diff --git a/tensorflow/core/grappler/utils/functions.cc b/tensorflow/core/grappler/utils/functions.cc
index 79b823fa2d..34603f9869 100644
--- a/tensorflow/core/grappler/utils/functions.cc
+++ b/tensorflow/core/grappler/utils/functions.cc
@@ -417,6 +417,63 @@ bool IsParametrized(const FunctionDef& func) {
return HasParametrizedType(func) || HasParametrizedBody(func);
}
+Status InstantiationTypeParameters(
+ const FunctionDef& func, const AttrValueMap& func_instantiation_attr,
+ std::unordered_map<string, DataType>* type_parameters) {
+ if (!type_parameters->empty()) {
+ return errors::InvalidArgument("Type parameters output map must be empty");
+ }
+
+ GrapplerFunctionItemInstantiation instantiation(&func_instantiation_attr);
+
+ const auto resolve_type_attr = [&](const OpDef::ArgDef& arg) {
+ // Check if it's an unknown and still-unresolved type.
+ if (arg.type() == DT_INVALID &&
+ type_parameters->find(arg.type_attr()) == type_parameters->end()) {
+ DataType data_type;
+ TF_RETURN_IF_ERROR(instantiation.GetArgType(arg, &data_type));
+ type_parameters->insert({arg.type_attr(), data_type});
+ }
+ return Status::OK();
+ };
+
+ for (const auto& input : func.signature().input_arg())
+ TF_RETURN_IF_ERROR(resolve_type_attr(input));
+ for (const auto& output : func.signature().output_arg())
+ TF_RETURN_IF_ERROR(resolve_type_attr(output));
+
+ return Status::OK();
+}
+
+Status InstantiationBodyParameters(
+ const FunctionDef& func, const AttrValueMap& func_instantiation_attr,
+ std::unordered_map<string, AttrValue>* body_parameters) {
+ if (!body_parameters->empty()) {
+ return errors::InvalidArgument("Body parameters output map must be empty");
+ }
+
+ for (const NodeDef& func_body_node : func.node_def()) {
+ for (auto& attr : func_body_node.attr()) {
+ const string& placeholder = attr.second.placeholder();
+
+ if (placeholder.empty() ||
+ body_parameters->find(placeholder) != body_parameters->end()) {
+ continue;
+ }
+
+ auto it = func_instantiation_attr.find(placeholder);
+ if (it != func_instantiation_attr.end()) {
+ body_parameters->emplace(placeholder, it->second);
+ } else {
+ return errors::InvalidArgument("Can't resolve placeholder: ",
+ placeholder);
+ }
+ }
+ }
+
+ return Status::OK();
+}
+
Status MakeGrapplerFunctionItem(const FunctionDef& func,
const AttrValueMap& func_instantiation_attr,
const FunctionLibraryDefinition& flib,
diff --git a/tensorflow/core/grappler/utils/functions.h b/tensorflow/core/grappler/utils/functions.h
index d9d71b80eb..4641bf5252 100644
--- a/tensorflow/core/grappler/utils/functions.h
+++ b/tensorflow/core/grappler/utils/functions.h
@@ -191,6 +191,19 @@ bool HasParametrizedBody(const FunctionDef& func);
// Check if function has parametrized type or body.
bool IsParametrized(const FunctionDef& func);
+// Resolve function instantiation type parameters from the attributes of the
+// caller node. Return error if type can't be resolved.
+Status InstantiationTypeParameters(
+ const FunctionDef& func, const AttrValueMap& func_instantiation_attr,
+ std::unordered_map<string, DataType>* type_parameters);
+
+// Resolve function instantiation body parameters (values for the function body
+// attr placeholders) from the attributes of the caller node. Return error if
+// a placeholder can't be resolved.
+Status InstantiationBodyParameters(
+ const FunctionDef& func, const AttrValueMap& func_instantiation_attr,
+ std::unordered_map<string, AttrValue>* body_parameters);
+
// Register GrapplerFunctionItem input arg expansion and function body outputs
// in the GrapplerFunctionConnectivity. Use function library definition to
// lookup function body nodes output names and ranges.
@@ -205,10 +218,10 @@ Status ReplaceInputWithConst(const NodeDef& input_const, int input_position,
// Make a GrapplerFunctionItem from the function definition and function
// instantiation attributes (caller node attributes). Returns error if the given
// function def cannot be converted (e.g. not all attributes are defined).
-Status MakeGrapplerFunctionItem(
- const FunctionDef& func,
- const std::unordered_map<string, AttrValue>& func_instantiation_attr,
- const FunctionLibraryDefinition& flib, GrapplerFunctionItem* item);
+Status MakeGrapplerFunctionItem(const FunctionDef& func,
+ const AttrValueMap& func_instantiation_attr,
+ const FunctionLibraryDefinition& flib,
+ GrapplerFunctionItem* item);
// Make a GrapplerFunction item from the function definition. Function must be
// fully defined (no type or body parametrization).
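
The new `InstantiationTypeParameters` helper resolves each parameterized argument's concrete `DataType` from the caller node's attributes. A minimal, self-contained sketch of that resolution loop, using plain std containers as assumed stand-ins for `FunctionDef`/`AttrValue` (the real code goes through `GrapplerFunctionItemInstantiation`, as in the .cc change above):

```cpp
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

// Simplified stand-ins: DT_INVALID marks an arg whose concrete type comes
// from a type attribute (e.g. "A") supplied by the caller node.
enum DataType { DT_INVALID, DT_FLOAT, DT_INT32, DT_DOUBLE };

struct ArgDef {
  DataType type;          // DT_INVALID if parameterized.
  std::string type_attr;  // Name of the type parameter, e.g. "A".
};

// Mirrors the resolution loop: for every still-unresolved arg, look up its
// type attribute in the caller-provided instantiation attributes.
bool ResolveTypeParameters(
    const std::vector<ArgDef>& args,
    const std::unordered_map<std::string, DataType>& instantiation_attr,
    std::unordered_map<std::string, DataType>* type_parameters) {
  for (const ArgDef& arg : args) {
    if (arg.type != DT_INVALID) continue;                  // Already concrete.
    if (type_parameters->count(arg.type_attr)) continue;   // Already resolved.
    auto it = instantiation_attr.find(arg.type_attr);
    if (it == instantiation_attr.end()) return false;      // Can't resolve.
    (*type_parameters)[arg.type_attr] = it->second;
  }
  return true;
}

int main() {
  std::vector<ArgDef> inputs = {
      {DT_INVALID, "A"}, {DT_INVALID, "B"}, {DT_FLOAT, ""}};
  std::unordered_map<std::string, DataType> attrs = {{"A", DT_FLOAT},
                                                     {"B", DT_INT32}};
  std::unordered_map<std::string, DataType> type_parameters;
  bool ok = ResolveTypeParameters(inputs, attrs, &type_parameters);
  std::cout << ok << " resolved=" << type_parameters.size() << "\n";  // 1 resolved=2
  return 0;
}
```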
diff --git a/tensorflow/core/grappler/utils/functions_test.cc b/tensorflow/core/grappler/utils/functions_test.cc
index fa6fec70ff..15d8437438 100644
--- a/tensorflow/core/grappler/utils/functions_test.cc
+++ b/tensorflow/core/grappler/utils/functions_test.cc
@@ -54,6 +54,44 @@ TEST_F(FunctionsTest, IsParametrized) {
EXPECT_FALSE(IsParametrized(non_parametrized_func));
}
+TEST_F(FunctionsTest, InstantiationParameters) {
+ // The function definition is intentionally invalid; only the type/body
+ // parameters matter for this test.
+ FunctionDef func = FunctionDefHelper::Create(
+ "ParametrizedFunc",
+ /* inputs */
+ {"input1:A", "input2:B", "input3:float"},
+ /* outputs */
+ {"output1: A", "output2:C"},
+ /* type parameters */
+ {"A: {float, double}", "B: {float, int32}", "C: {float, double}"},
+ /* function body */
+ {{{"output"}, "FakeOp", {"input1", "input2"}, {{"key", "$key"}}}},
+ /* Mapping between function returns and function node outputs. */
+ {{"x", "cx:output:0"}, {"y", "cy:output:0"}});
+
+ std::unordered_map<string, AttrValue> func_instantiation_attr;
+ func_instantiation_attr["key"].set_s("key-value");
+ func_instantiation_attr["A"].set_type(DT_FLOAT);
+ func_instantiation_attr["B"].set_type(DT_INT32);
+ func_instantiation_attr["C"].set_type(DT_DOUBLE);
+
+ std::unordered_map<string, DataType> type_parameters;
+ TF_EXPECT_OK(InstantiationTypeParameters(func, func_instantiation_attr,
+ &type_parameters));
+
+ ASSERT_EQ(3, type_parameters.size());
+ EXPECT_EQ(DT_FLOAT, type_parameters["A"]);
+ EXPECT_EQ(DT_INT32, type_parameters["B"]);
+ EXPECT_EQ(DT_DOUBLE, type_parameters["C"]);
+
+ std::unordered_map<string, AttrValue> body_parameters;
+ TF_EXPECT_OK(InstantiationBodyParameters(func, func_instantiation_attr,
+ &body_parameters));
+
+ ASSERT_EQ(1, body_parameters.size());
+ EXPECT_EQ("key-value", body_parameters["key"].s());
+}
+
TEST_F(FunctionsTest, GrapplerFunctionConnectivity_ExpandFunctionDefInput) {
GrapplerFunctionConnectivity connectivity;
diff --git a/tensorflow/core/kernels/data/dataset_utils.cc b/tensorflow/core/kernels/data/dataset_utils.cc
index 67ddb52d57..c608f9e1c6 100644
--- a/tensorflow/core/kernels/data/dataset_utils.cc
+++ b/tensorflow/core/kernels/data/dataset_utils.cc
@@ -46,18 +46,6 @@ Status MakeIteratorFromInputElement(
return Status::OK();
}
-IteratorContext MakeIteratorContext(OpKernelContext* ctx) {
- IteratorContext::Params params;
- params.env = ctx->env();
- params.runner = *(ctx->runner());
- params.lib = ctx->function_library();
- DeviceBase* device = ctx->function_library()->device();
- params.allocator_getter = [device](AllocatorAttributes attrs) {
- return device->GetAllocator(attrs);
- };
- return IteratorContext(params);
-}
-
} // namespace dataset
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/dataset_utils.h b/tensorflow/core/kernels/data/dataset_utils.h
index e5ca71dd99..6c4191c2be 100644
--- a/tensorflow/core/kernels/data/dataset_utils.h
+++ b/tensorflow/core/kernels/data/dataset_utils.h
@@ -28,8 +28,6 @@ Status MakeIteratorFromInputElement(
int64 thread_index, CapturedFunction* captured_func, StringPiece prefix,
std::unique_ptr<IteratorBase>* out_iterator);
-IteratorContext MakeIteratorContext(OpKernelContext* ctx);
-
} // namespace dataset
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc
index a2f6c5fe2c..b6bf0ecd09 100644
--- a/tensorflow/core/kernels/data/iterator_ops.cc
+++ b/tensorflow/core/kernels/data/iterator_ops.cc
@@ -1051,7 +1051,7 @@ class DeserializeIteratorOp : public OpKernel {
IteratorResource* iterator_resource;
OP_REQUIRES_OK(
ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource));
-
+ core::ScopedUnref unref_iterator(iterator_resource);
Variant variant = ctx->input(1).scalar<Variant>()();
auto* wrapper = variant.get<IteratorStateVariant>();
OP_REQUIRES(ctx, wrapper != nullptr,
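
`LookupResource` hands back the iterator resource with an extra reference, and the added `core::ScopedUnref` releases it on every exit path, including the `OP_REQUIRES` early returns. A minimal sketch of that RAII pattern, with hypothetical `RefCounted`/`ScopedUnref` stand-ins rather than TF's actual headers:

```cpp
#include <cassert>

// Hypothetical stand-ins for the ref-counted resource pattern.
class RefCounted {
 public:
  void Ref() { ++refs_; }
  void Unref() {
    if (--refs_ == 0) delete this;
  }
  int refs() const { return refs_; }

 protected:
  virtual ~RefCounted() {}

 private:
  int refs_ = 1;
};

class ScopedUnref {
 public:
  explicit ScopedUnref(RefCounted* obj) : obj_(obj) {}
  ~ScopedUnref() {
    if (obj_ != nullptr) obj_->Unref();
  }
  ScopedUnref(const ScopedUnref&) = delete;
  ScopedUnref& operator=(const ScopedUnref&) = delete;

 private:
  RefCounted* obj_;
};

struct IteratorResource : RefCounted {};

// Lookup hands out an extra reference, as LookupResource does in TF.
void Lookup(IteratorResource* stored, IteratorResource** out) {
  stored->Ref();
  *out = stored;
}

void Compute(IteratorResource* stored, bool fail_early) {
  IteratorResource* resource = nullptr;
  Lookup(stored, &resource);
  ScopedUnref unref(resource);  // Released on every return path below.
  if (fail_early) return;       // Early exit no longer leaks a reference.
  // ... use `resource` ...
}

int main() {
  IteratorResource* r = new IteratorResource;
  Compute(r, /*fail_early=*/true);
  Compute(r, /*fail_early=*/false);
  assert(r->refs() == 1);  // Only the owner's original reference remains.
  r->Unref();
  return 0;
}
```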
diff --git a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
index c9551fbf16..729b615e56 100644
--- a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
@@ -14,6 +14,8 @@ limitations under the License.
==============================================================================*/
#define EIGEN_USE_THREADS
+#include <utility>
+
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
@@ -21,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/kernels/data/dataset.h"
#include "tensorflow/core/kernels/inplace_ops_functor.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/tracing.h"
@@ -36,7 +39,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
public:
explicit MapAndBatchDatasetOp(OpKernelConstruction* ctx)
: UnaryDatasetOpKernel(ctx),
- graph_def_version_(ctx->graph_def_version()) {
+ graph_def_version_(ctx->graph_def_version()),
+ op_version_(ctx->def().op() == "MapAndBatchDataset" ? 1 : 2) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
@@ -59,12 +63,29 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
ctx, batch_size > 0,
errors::InvalidArgument("batch_size must be greater than zero."));
- int64 num_parallel_batches;
- OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "num_parallel_batches",
- &num_parallel_batches));
- OP_REQUIRES(ctx, num_parallel_batches > 0,
- errors::InvalidArgument(
- "num_parallel_batches must be greater than zero."));
+ int64 num_parallel_calls;
+ switch (op_version_) {
+ case 1:
+ int64 num_parallel_batches;
+ OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "num_parallel_batches",
+ &num_parallel_batches));
+ num_parallel_calls = num_parallel_batches * batch_size;
+ OP_REQUIRES(ctx, num_parallel_batches > 0,
+ errors::InvalidArgument(
+ "num_parallel_batches must be greater than zero."));
+ break;
+ case 2:
+ OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "num_parallel_calls",
+ &num_parallel_calls));
+ OP_REQUIRES(ctx, num_parallel_calls > 0,
+ errors::InvalidArgument(
+ "num_parallel_calls must be greater than zero."));
+ break;
+ default:
+ OP_REQUIRES(ctx, false,
+ errors::Unimplemented("Unsupported operation version %d.",
+ op_version_));
+ }
bool drop_remainder;
OP_REQUIRES_OK(ctx,
@@ -74,7 +95,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
OP_REQUIRES_OK(ctx, CapturedFunction::Create(
func_, std::move(other_arguments), &captured_func));
- *output = new Dataset(ctx, input, batch_size, num_parallel_batches,
+ *output = new Dataset(ctx, input, batch_size, num_parallel_calls,
drop_remainder, output_types_, output_shapes_, func_,
std::move(captured_func), &ctx->eigen_cpu_device());
}
@@ -83,7 +104,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
class Dataset : public GraphDatasetBase {
public:
Dataset(OpKernelContext* ctx, const DatasetBase* input, int64 batch_size,
- int64 num_parallel_batches, bool drop_remainder,
+ int64 num_parallel_calls, bool drop_remainder,
const DataTypeVector& output_types,
const std::vector<PartialTensorShape>& output_shapes,
const NameAttrList& func,
@@ -92,7 +113,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
: GraphDatasetBase(ctx),
input_(input),
batch_size_(batch_size),
- num_parallel_batches_(num_parallel_batches),
+ num_parallel_calls_(num_parallel_calls),
drop_remainder_(drop_remainder),
output_types_(output_types),
output_shapes_(output_shapes),
@@ -128,9 +149,9 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_graph_node));
Node* batch_size_node;
TF_RETURN_IF_ERROR(b->AddScalar(batch_size_, &batch_size_node));
- Node* num_parallel_batches_node;
+ Node* num_parallel_calls_node;
TF_RETURN_IF_ERROR(
- b->AddScalar(num_parallel_batches_, &num_parallel_batches_node));
+ b->AddScalar(num_parallel_calls_, &num_parallel_calls_node));
Node* drop_remainder_node;
TF_RETURN_IF_ERROR(b->AddScalar(drop_remainder_, &drop_remainder_node));
@@ -153,7 +174,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
this,
{std::make_pair(0, input_graph_node),
std::make_pair(2, batch_size_node),
- std::make_pair(3, num_parallel_batches_node),
+ std::make_pair(3, num_parallel_calls_node),
std::make_pair(4, drop_remainder_node)}, // Single tensor inputs.
{std::make_pair(1, other_arguments)}, // Tensor list inputs.
{std::make_pair("f", f),
@@ -168,129 +189,54 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
explicit Iterator(const Params& params)
: DatasetIterator<Dataset>(params),
input_impl_(params.dataset->input_->MakeIterator(params.prefix)),
- invocation_results_(params.dataset->batch_size_ *
- params.dataset->num_parallel_batches_),
- batch_results_(params.dataset->num_parallel_batches_) {}
+ batch_results_((params.dataset->num_parallel_calls_ +
+ params.dataset->batch_size_ - 1) /
+ params.dataset->batch_size_) {
+ for (int i = 0; i < batch_results_.size(); ++i) {
+ batch_results_[i].Initialize(params.dataset->batch_size_);
+ }
+ }
~Iterator() override {
- // TODO(mrry): Replace this cancellation logic with a
- // CancellationManager. The syntax would be more heavyweight,
- // but it would be possible to thread a cancellation manager
- // through the IteratorContext to upstream,
- // potentially-blocking iterators, when we add these.
mutex_lock l(mu_);
- if (current_batch_index_ != -1) {
- for (size_t batch_index = 0;
- batch_index < dataset()->num_parallel_batches_; ++batch_index) {
- int64 num_elements;
- WaitForBatch(batch_index, &num_elements).IgnoreError();
- // Deallocate tensors allocated for the output.
- batch_results_[batch_index].output.clear();
- }
+ // Cancel the runner thread.
+ cancelled_ = true;
+ cond_var_.notify_all();
+ // Wait for all in-flight calls to complete.
+ while (num_calls_ > 0) {
+ cond_var_.wait(l);
}
}
- // TODO(jsimsa): Implement and profile the following alternative design:
- //
- // 0. Set the number of in-flight batches and invocations independently
- // (though obviously the max number of in-flight invocations must be <
- // batch_size * num_parallel_batches). Maintain a current producing batch
- // index and offset.
- // 1. Issue invocations in order of batch and offset, as you do currently.
- // 2. When an invocation finishes, increment the current producing batch
- // and offset. If that invocation would start a new batch and give more
- // than num_parallel_batches in-flight, block; else start the new
- // invocation into that location.
- // 3. When a GetNext() call arrives, block until there's a full batch.
- // Before returning the batch, if the number of pending invocations is
- // less than the max, issue that number of invocations.
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
mutex_lock l(mu_);
-
- // One-time initialization.
- if (current_batch_index_ == -1) {
- current_batch_index_ = 0;
- for (size_t i = 0; i < dataset()->num_parallel_batches_; ++i) {
- StartInvocationBatch(ctx, i);
- }
- }
-
- int64 num_elements = 0;
- Status status = WaitForBatch(current_batch_index_, &num_elements);
- if (num_elements == 0) {
- *end_of_sequence = true;
- return Status::OK();
- }
- if (!status.ok()) {
- // Deallocate tensors allocated for the output.
- batch_results_[current_batch_index_].output.clear();
- } else {
- if (num_elements < dataset()->batch_size_) {
- if (dataset()->drop_remainder_) {
- // Deallocate tensors allocated for the output.
- batch_results_[current_batch_index_].output.clear();
- *end_of_sequence = true;
- return Status::OK();
- }
- const std::vector<Tensor>& output =
- batch_results_[current_batch_index_].output;
- for (size_t i = 0; i < output.size(); ++i) {
- TensorShape component_shape(
- batch_results_[current_batch_index_].output[i].shape());
- component_shape.set_dim(0, num_elements);
- AllocatorAttributes attr;
- attr.set_gpu_compatible(true);
- Tensor component(ctx->allocator(attr), output[i].dtype(),
- component_shape);
- TF_RETURN_IF_ERROR(
- CopyPartialBatch(&component, output[i], num_elements));
- out_tensors->emplace_back(std::move(component));
- }
- // Deallocate tensors allocated for the output.
- batch_results_[current_batch_index_].output.clear();
- } else {
- *out_tensors =
- std::move(batch_results_[current_batch_index_].output);
- }
- *end_of_sequence = false;
- }
- StartInvocationBatch(ctx, current_batch_index_);
- current_batch_index_ =
- (current_batch_index_ + 1) % dataset()->num_parallel_batches_;
- return status;
+ EnsureRunnerThreadStarted(ctx);
+ BatchResult* result = &batch_results_[ComputeIndex(input_batch_)];
+ WaitForBatch(result, &l);
+ return ProcessBatch(ctx, result, out_tensors, end_of_sequence);
}
protected:
Status SaveInternal(IteratorStateWriter* writer) override {
mutex_lock l(mu_);
- if (current_batch_index_ == -1) {
- // Iterator has not been used. Nothing to save.
- return Status::OK();
+ // Wait for all in-flight calls to complete.
+ while (num_calls_ > 0) {
+ cond_var_.wait(l);
}
- TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("current_batch_index"),
- current_batch_index_));
+ CHECK_EQ(num_calls_, 0);
TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
- // Wait for the map_fn dispatches made in `InvokeFunctionLocked` to
- // finish. This may delay saving a checkpoint by a bit but keeps the
- // code clean and also saves us from checkpointing the state of the
- // `BlockingCounter`.
- std::vector<int64> num_elements(batch_results_.size());
- for (size_t i = 0; i < batch_results_.size(); i++) {
- WaitForBatch(i, &num_elements[i]).IgnoreError();
- }
-
- TF_RETURN_IF_ERROR(writer->WriteScalar(
- full_name("invocation_results_size"), invocation_results_.size()));
- for (size_t i = 0; i < invocation_results_.size(); ++i) {
- TF_RETURN_IF_ERROR(WriteInvocationResultLocked(writer, i));
- }
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(full_name("call_counter"), call_counter_));
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(full_name("input_batch"), input_batch_));
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(full_name("output_batch"), output_batch_));
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("batch_results_size"),
batch_results_.size()));
for (size_t i = 0; i < batch_results_.size(); ++i) {
- TF_RETURN_IF_ERROR(
- WriteBatchResultLocked(writer, i, num_elements[i]));
+ TF_RETURN_IF_ERROR(WriteBatchResult(writer, i));
}
return Status::OK();
}
@@ -298,70 +244,136 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
mutex_lock l(mu_);
- if (!reader->Contains(full_name("current_batch_index"))) {
- // Iterator was never used so nothing to restore.
- return Status::OK();
- }
- {
- int64 temp;
- TF_RETURN_IF_ERROR(
- reader->ReadScalar(full_name("current_batch_index"), &temp));
- current_batch_index_ = static_cast<int32>(temp);
- if (current_batch_index_ != temp) {
- return errors::Internal("Invalid value for current_batch_index ",
- temp);
- }
- }
TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
- size_t invocation_results_size;
- {
- int64 temp;
- TF_RETURN_IF_ERROR(
- reader->ReadScalar(full_name("invocation_results_size"), &temp));
- invocation_results_size = static_cast<size_t>(temp);
- if (invocation_results_size != temp) {
- return errors::Internal(
- "Invalid value for invocation_results_size ", temp);
- }
- }
- CHECK_EQ(invocation_results_.size(), invocation_results_size);
- for (size_t i = 0; i < invocation_results_size; ++i) {
- TF_RETURN_IF_ERROR(ReadInvocationResultLocked(reader, i));
- }
- size_t batch_results_size;
- {
- int64 temp;
- TF_RETURN_IF_ERROR(
- reader->ReadScalar(full_name("batch_results_size"), &temp));
- batch_results_size = static_cast<size_t>(temp);
- if (batch_results_size != temp) {
- return errors::Internal("Invalid value for batch_results_size ",
- temp);
- }
- }
+ TF_RETURN_IF_ERROR(
+ reader->ReadScalar(full_name("call_counter"), &call_counter_));
+ TF_RETURN_IF_ERROR(
+ reader->ReadScalar(full_name("input_batch"), &input_batch_));
+ TF_RETURN_IF_ERROR(
+ reader->ReadScalar(full_name("output_batch"), &output_batch_));
+ int64 batch_results_size;
+ TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("batch_results_size"),
+ &batch_results_size));
CHECK_EQ(batch_results_.size(), batch_results_size);
- for (size_t i = 0; i < batch_results_size; ++i) {
- TF_RETURN_IF_ERROR(ReadBatchResultLocked(ctx, reader, i));
+ for (int i = 0; i < batch_results_size; ++i) {
+ TF_RETURN_IF_ERROR(ReadBatchResult(ctx, reader, i));
}
return Status::OK();
}
private:
struct BatchResult {
- mutex mu ACQUIRED_AFTER(mu_);
- bool output_allocated GUARDED_BY(mu);
+ mutex mu;
+ bool end_of_input GUARDED_BY(mu);
+ int64 num_elements GUARDED_BY(mu);
std::vector<Tensor> output;
- std::unique_ptr<BlockingCounter> counter;
+ bool output_allocated GUARDED_BY(mu);
+ Status status GUARDED_BY(mu);
+ // Used for coordination between the main thread and the callback
+ // threads. In particular, the main thread will wait for the value
+ // of `num_calls` to reach zero before processing the batch result.
+ condition_variable cond_var; // access guarded by owner's mutex
+ // Counts the number of outstanding calls for this batch.
+ int64 num_calls; // access guarded by owner's mutex
+
+ void Initialize(int64 batch_size) {
+ mutex_lock l(mu);
+ end_of_input = false;
+ num_calls = batch_size;
+ num_elements = 0;
+ output_allocated = false;
+ status = Status::OK();
+ }
+
+ void UpdateStatus(const Status& s) {
+ mutex_lock l(mu);
+ status.Update(s);
+ }
};
- struct InvocationResult {
- Status status;
+ void Callback(const std::shared_ptr<IteratorContext>& ctx,
+ BatchResult* result, std::vector<Tensor>* return_values,
+ int64 offset, const Status& status) {
+ std::unique_ptr<std::vector<Tensor>> cleanup_retvals(return_values);
+ result->UpdateStatus(status);
+ if (status.ok()) {
+ EnsureOutputAllocated(ctx, result, return_values);
+ for (size_t i = 0; i < return_values->size(); ++i) {
+ const Tensor& tensor = return_values->at(i);
+ Tensor* batch = &(result->output)[i];
+ if (tensor.NumElements() !=
+ (batch->NumElements() / batch->dim_size(0))) {
+ TensorShape batch_shape = batch->shape();
+ batch_shape.RemoveDim(0);
+ result->UpdateStatus(errors::InvalidArgument(
+ "Cannot add tensor to the batch: number of elements does not "
+ "match. Shapes are: [tensor]: ",
+ tensor.shape().DebugString(),
+ ", [batch]: ", batch_shape.DebugString()));
+ break;
+ }
+ // TODO(mrry): Add a version of DoParallelConcat that allows us to
+ // move `tensor` where possible, to speed up string tensor batching.
+ Status copy_status = ::tensorflow::functor::DoParallelConcat(
+ *dataset()->device_, tensor, offset, batch);
+ if (!copy_status.ok()) {
+ result->UpdateStatus(copy_status);
+ break;
+ }
+ }
+ }
+ {
+ mutex_lock l(result->mu);
+ result->num_elements++;
+ }
+ {
+ mutex_lock l(mu_);
+ CallCompleted(result);
+ }
+ }
+
+ void CallCompleted(BatchResult* result) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ num_calls_--;
+ cond_var_.notify_all();
+ result->num_calls--;
+ result->cond_var.notify_all();
+ }
+
+ void CallFunction(std::shared_ptr<IteratorContext> ctx,
+ BatchResult* result, int64 offset) {
+ // Get the next input element.
+ std::vector<Tensor> input_element;
bool end_of_input;
- std::vector<Tensor> return_values;
- };
+ Status status =
+ input_impl_->GetNext(ctx.get(), &input_element, &end_of_input);
+ {
+ mutex_lock l(mu_);
+ mutex_lock l2(result->mu);
+ result->end_of_input = result->end_of_input || end_of_input;
+ result->status.Update(status);
+ if (result->end_of_input || !result->status.ok()) {
+ CallCompleted(result);
+ return;
+ }
+ }
- int64 ComputeInvocationIndex(int64 batch_index, int64 offset) {
- return batch_index * dataset()->batch_size_ + offset;
+ // Call `captured_func_(input_element)`, using `Callback` to store the
+ // result in `result`.
+ (*ctx->runner())(std::bind(
+ [this, result, offset](std::shared_ptr<IteratorContext> ctx,
+ std::vector<Tensor> input_element) {
+ std::vector<Tensor>* return_values = new std::vector<Tensor>();
+ dataset()->captured_func_->RunAsync(
+ ctx.get(), std::move(input_element), return_values,
+ [this, ctx, result, return_values, offset](Status status) {
+ Callback(ctx, result, return_values, offset, status);
+ });
+ },
+ ctx, std::move(input_element)));
+ }
+
+ int64 ComputeIndex(int64 n) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ return n % batch_results_.size();
}
Status CopyPartialBatch(Tensor* output, const Tensor& value,
@@ -387,253 +399,140 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
return Status::OK();
}
- void EnsureOutputAllocated(IteratorContext* ctx,
- BatchResult* batch_result,
- const std::vector<Tensor>& return_values) {
- mutex_lock l(batch_result->mu);
- if (batch_result->output_allocated) {
+ void EnsureRunnerThreadStarted(IteratorContext* ctx)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ if (!runner_thread_) {
+ std::shared_ptr<IteratorContext> ctx_copy(new IteratorContext(*ctx));
+ runner_thread_.reset(ctx->env()->StartThread(
+ {}, "runner_thread",
+ std::bind(&Iterator::RunnerThread, this, ctx_copy)));
+ }
+ }
+
+ void EnsureOutputAllocated(const std::shared_ptr<IteratorContext>& ctx,
+ BatchResult* result,
+ const std::vector<Tensor>* return_values) {
+ mutex_lock l(result->mu);
+ if (result->output_allocated) {
return;
}
- const size_t num_components = return_values.size();
+ const size_t num_components = return_values->size();
for (size_t i = 0; i < num_components; ++i) {
TensorShape component_shape({dataset()->batch_size_});
- component_shape.AppendShape(return_values[i].shape());
+ component_shape.AppendShape(return_values->at(i).shape());
AllocatorAttributes attr;
attr.set_gpu_compatible(true);
- Tensor component(ctx->allocator(attr), return_values[i].dtype(),
+ Tensor component(ctx->allocator(attr), return_values->at(i).dtype(),
component_shape);
- batch_result->output.emplace_back(std::move(component));
+ result->output.emplace_back(std::move(component));
}
- batch_result->output_allocated = true;
+ result->output_allocated = true;
}
- void InvokeFunctionLocked(IteratorContext* ctx, int64 batch_index,
- int64 offset) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- size_t index = ComputeInvocationIndex(batch_index, offset);
- InvocationResult* result = &invocation_results_[index];
- BatchResult* batch_result = &batch_results_[batch_index];
-
- // Get the next input element.
- std::vector<Tensor> input_element;
- result->status =
- input_impl_->GetNext(ctx, &input_element, &result->end_of_input);
- if (result->end_of_input || !result->status.ok()) {
- batch_result->counter->DecrementCount();
- return;
+ Status ProcessBatch(IteratorContext* ctx, BatchResult* result,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ auto cleanup =
+ gtl::MakeCleanup([this, result]() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ result->Initialize(dataset()->batch_size_);
+ input_batch_++;
+ });
+ mutex_lock l(result->mu);
+ if (result->num_elements == 0) {
+ *end_of_sequence = true;
+ return Status::OK();
}
- // Call `captured_func_(input_element)`, store the result in
- // `result->return_values`, and notify `batch_result->counter`
- // to unblock a consumer.
- (*ctx->runner())(std::bind(
- [this, result, batch_result, offset](
- IteratorContext* ctx, std::vector<Tensor> input_element) {
- dataset()->captured_func_->RunAsync(
- ctx, std::move(input_element), &result->return_values,
- [this, ctx, result, batch_result, offset](Status ret_status) {
- result->status.Update(ret_status);
- if (ret_status.ok()) {
- EnsureOutputAllocated(ctx, batch_result,
- result->return_values);
- const size_t num_components =
- result->return_values.size();
- for (size_t i = 0; i < num_components; ++i) {
- const Tensor& tensor = result->return_values[i];
- Tensor* batch = &(batch_result->output)[i];
- if (tensor.NumElements() !=
- (batch->NumElements() / batch->dim_size(0))) {
- TensorShape batch_shape = batch->shape();
- batch_shape.RemoveDim(0);
- result->status.Update(errors::InvalidArgument(
- "Cannot add tensor to the batch: number of "
- "elements does not match. Shapes are: [tensor]: ",
- tensor.shape().DebugString(),
- ", [batch]: ", batch_shape.DebugString()));
- break;
- }
- // TODO(mrry): Add a version of DoParallelConcat that
- // allows us to move `tensor` where possible, to speed
- // up string tensor batching.
- Status copy_status =
- ::tensorflow::functor::DoParallelConcat(
- *dataset()->device_, tensor, offset, batch);
- if (!copy_status.ok()) {
- result->status.Update(copy_status);
- break;
- }
- }
- }
- delete ctx;
- // NOTE(mrry): We clear the return values here to release
- // any memory associated with them and to paralellize the
- // destruction of the tensors (which can be surprisingly
- // expensive for map functions with large numbers of return
- // values).
- result->return_values.clear();
- batch_result->counter->DecrementCount();
- });
- },
- new IteratorContext(*ctx), std::move(input_element)));
- }
-
- void StartInvocationBatch(IteratorContext* ctx, int64 batch_index)
- EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- tracing::ScopedActivity activity(strings::StrCat(prefix(), "::Start"));
- // Initialize batch result.
- {
- mutex_lock l(batch_results_[batch_index].mu);
- batch_results_[batch_index].output_allocated = false;
- batch_results_[batch_index].counter.reset(
- new BlockingCounter(dataset()->batch_size_));
- }
- // Initialize invocation results.
- for (size_t i = 0; i < dataset()->batch_size_; ++i) {
- size_t index = ComputeInvocationIndex(batch_index, i);
- InvocationResult* result = &invocation_results_[index];
- // Reset the state of `result`; `result->return_values` was cleared
- // when the previous invocation completed.
- result->end_of_input = false;
- result->status = Status::OK();
- }
- // Start individual invocations.
- for (size_t i = 0; i < dataset()->batch_size_; ++i) {
- InvokeFunctionLocked(ctx, batch_index, i);
+ if (!result->status.ok()) {
+ // Deallocate tensors allocated for the output.
+ result->output.clear();
+ } else {
+ if (result->num_elements < dataset()->batch_size_) {
+ if (dataset()->drop_remainder_) {
+ // Deallocate tensors allocated for the output.
+ result->output.clear();
+ *end_of_sequence = true;
+ return Status::OK();
+ }
+ const std::vector<Tensor>& output = result->output;
+ for (size_t i = 0; i < output.size(); ++i) {
+ TensorShape component_shape(result->output[i].shape());
+ component_shape.set_dim(0, result->num_elements);
+ AllocatorAttributes attr;
+ attr.set_gpu_compatible(true);
+ Tensor component(ctx->allocator(attr), output[i].dtype(),
+ component_shape);
+ TF_RETURN_IF_ERROR(CopyPartialBatch(&component, output[i],
+ result->num_elements));
+ out_tensors->emplace_back(std::move(component));
+ }
+ // Deallocate tensors allocated for the output.
+ result->output.clear();
+ } else {
+ *out_tensors = std::move(result->output);
+ }
+ *end_of_sequence = false;
}
+ cond_var_.notify_all();
+ return result->status;
}
- Status WaitForBatch(int64 batch_index, int64* num_elements)
- EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- tracing::ScopedActivity activity(strings::StrCat(prefix(), "::Wait"));
- batch_results_[batch_index].counter->Wait();
- Status status = Status::OK();
- for (size_t i = 0; i < dataset()->batch_size_; ++i, ++*num_elements) {
- size_t index = ComputeInvocationIndex(batch_index, i);
- InvocationResult* result = &invocation_results_[index];
- if (result->end_of_input) {
- VLOG(3) << "end of input encountered at element[" << i << "]: ";
- return Status::OK();
- }
- if (!result->status.ok()) {
- VLOG(3) << "failed to process element[" << i
- << "]: " << result->status;
- status.Update(result->status);
+ void RunnerThread(const std::shared_ptr<IteratorContext>& ctx) {
+ mutex_lock l(mu_);
+ while (true) {
+ while (!cancelled_ &&
+ (num_calls_ == dataset()->num_parallel_calls_ ||
+ (output_batch_ - input_batch_ == batch_results_.size()))) {
+ cond_var_.wait(l);
}
- }
- return status;
- }
- Status WriteInvocationResultLocked(IteratorStateWriter* writer,
- size_t index)
- EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- const InvocationResult& result = invocation_results_[index];
- string prefix = strings::StrCat("invocation_results_", index);
- TF_RETURN_IF_ERROR(WriteStatusLocked(
- writer, full_name(strings::StrCat(prefix, "_status")),
- result.status));
- if (result.end_of_input) {
- TF_RETURN_IF_ERROR(writer->WriteScalar(
- full_name(strings::StrCat(prefix, "_end_of_input")), ""));
- }
- TF_RETURN_IF_ERROR(writer->WriteScalar(
- full_name(strings::StrCat(prefix, "_return_values_size")),
- result.return_values.size()));
- for (size_t i = 0; i < result.return_values.size(); i++) {
- TF_RETURN_IF_ERROR(writer->WriteTensor(
- full_name(strings::StrCat(prefix, "_return_values_", i)),
- result.return_values[i]));
- }
- return Status::OK();
- }
+ if (cancelled_) {
+ return;
+ }
- Status ReadInvocationResultLocked(IteratorStateReader* reader,
- size_t index)
- EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- InvocationResult* result = &invocation_results_[index];
- string prefix = strings::StrCat("invocation_results_", index);
- TF_RETURN_IF_ERROR(ReadStatusLocked(
- reader, full_name(strings::StrCat(prefix, "_status")),
- &result->status));
- result->end_of_input = reader->Contains(
- full_name(strings::StrCat(prefix, "_end_of_input")));
- size_t return_values_size;
- {
- int64 temp;
- TF_RETURN_IF_ERROR(reader->ReadScalar(
- full_name(strings::StrCat(prefix, "_return_values_size")),
- &temp));
- return_values_size = static_cast<size_t>(temp);
- if (temp != return_values_size) {
- return errors::Internal("Invalid value for return_values_size ",
- return_values_size);
+ while (num_calls_ < dataset()->num_parallel_calls_ &&
+ (output_batch_ - input_batch_ < batch_results_.size())) {
+ BatchResult* result = &batch_results_[ComputeIndex(output_batch_)];
+ int64 offset = call_counter_++ % dataset()->batch_size_;
+ num_calls_++;
+ mu_.unlock();
+ CallFunction(ctx, result, offset);
+ mu_.lock();
+ if (offset + 1 == dataset()->batch_size_) {
+ // Done scheduling calls for the current batch.
+ output_batch_++;
+ }
}
}
- result->return_values.reserve(return_values_size);
- for (size_t i = 0; i < return_values_size; i++) {
- result->return_values.emplace_back();
- TF_RETURN_IF_ERROR(reader->ReadTensor(
- full_name(strings::StrCat(prefix, "_return_values_", i)),
- &result->return_values.back()));
- }
- return Status::OK();
}
- Status WriteBatchResultLocked(IteratorStateWriter* writer, size_t index,
- int64 num_elements)
+ void WaitForBatch(BatchResult* result, mutex_lock* l)
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- const BatchResult& result = batch_results_[index];
- string prefix = strings::StrCat("batch_results_", index);
- {
- mutex_lock l(batch_results_[index].mu);
- if (result.output_allocated) {
- TF_RETURN_IF_ERROR(writer->WriteScalar(
- full_name(strings::StrCat(prefix, "_output_allocated")), ""));
- }
+ while (result->num_calls > 0) {
+ result->cond_var.wait(*l);
}
- TF_RETURN_IF_ERROR(writer->WriteScalar(
- full_name(strings::StrCat(prefix, "_output_size")),
- result.output.size()));
- for (size_t i = 0; i < result.output.size(); i++) {
- // If the batch is not full, we only store the first
- // `num_elements` values. The rest of the batch tensor is
- // *uninitialized* and accessing that will raise msan errors.
- if (num_elements < dataset()->batch_size_) {
- TF_RETURN_IF_ERROR(writer->WriteTensor(
- full_name(strings::StrCat(prefix, "_output_", i)),
- result.output[i].Slice(0, num_elements)));
- } else {
- TF_RETURN_IF_ERROR(writer->WriteTensor(
- full_name(strings::StrCat(prefix, "_output_", i)),
- result.output[i]));
- }
- }
- return Status::OK();
}
- Status ReadBatchResultLocked(IteratorContext* ctx,
- IteratorStateReader* reader, size_t index)
- EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ Status ReadBatchResult(IteratorContext* ctx, IteratorStateReader* reader,
+ size_t index) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
BatchResult* result = &batch_results_[index];
string prefix = strings::StrCat("batch_results_", index);
- {
- mutex_lock l(batch_results_[index].mu);
- result->output_allocated = reader->Contains(
- full_name(strings::StrCat(prefix, "_output_allocated")));
- // Simulate that the batch was fully generated.
- batch_results_[index].counter.reset(new BlockingCounter(0));
- }
- size_t output_size;
- {
- int64 temp;
- TF_RETURN_IF_ERROR(reader->ReadScalar(
- full_name(strings::StrCat(prefix, "_output_size")), &temp));
- output_size = static_cast<size_t>(temp);
- if (temp != output_size) {
- return errors::Internal("Invalid value for output_size ",
- output_size);
- }
- }
+ mutex_lock l(result->mu);
+ result->end_of_input = reader->Contains(
+ full_name(strings::StrCat(prefix, "_end_of_input")));
+ TF_RETURN_IF_ERROR(
+ reader->ReadScalar(full_name(strings::StrCat(prefix, "_num_calls")),
+ &result->num_calls));
+ TF_RETURN_IF_ERROR(reader->ReadScalar(
+ full_name(strings::StrCat(prefix, "_num_elements")),
+ &result->num_elements));
+ result->output_allocated = reader->Contains(
+ full_name(strings::StrCat(prefix, "_output_allocated")));
+ int64 output_size;
+ TF_RETURN_IF_ERROR(reader->ReadScalar(
+ full_name(strings::StrCat(prefix, "_output_size")), &output_size));
result->output.reserve(output_size);
- for (size_t i = 0; i < output_size; i++) {
+ for (int i = 0; i < output_size; i++) {
Tensor t;
TF_RETURN_IF_ERROR(reader->ReadTensor(
full_name(strings::StrCat(prefix, "_output_", i)), &t));
@@ -653,25 +552,13 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
result->output.emplace_back(std::move(t));
}
}
+ TF_RETURN_IF_ERROR(ReadStatus(
+ reader, strings::StrCat(prefix, "_status"), &result->status));
return Status::OK();
}
- Status WriteStatusLocked(IteratorStateWriter* writer,
- const string& prefix, const Status& status)
- EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- TF_RETURN_IF_ERROR(
- writer->WriteScalar(full_name(strings::StrCat(prefix, "_code")),
- static_cast<int64>(status.code())));
- if (!status.ok()) {
- TF_RETURN_IF_ERROR(
- writer->WriteScalar(full_name(strings::StrCat(prefix, "_msg")),
- status.error_message()));
- }
- return Status::OK();
- }
-
- Status ReadStatusLocked(IteratorStateReader* reader, const string& prefix,
- Status* status) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ Status ReadStatus(IteratorStateReader* reader, const string& prefix,
+ Status* status) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
int64 code_int;
TF_RETURN_IF_ERROR(reader->ReadScalar(
full_name(strings::StrCat(prefix, "_code")), &code_int));
@@ -687,17 +574,89 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
}
return Status::OK();
}
+
+ Status WriteBatchResult(IteratorStateWriter* writer, size_t index)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ BatchResult* result = &batch_results_[index];
+ string prefix = strings::StrCat("batch_results_", index);
+ mutex_lock l(result->mu);
+ if (result->end_of_input) {
+ TF_RETURN_IF_ERROR(writer->WriteScalar(
+ full_name(strings::StrCat(prefix, "_end_of_input")), ""));
+ }
+ TF_RETURN_IF_ERROR(writer->WriteScalar(
+ full_name(strings::StrCat(prefix, "_num_calls")),
+ result->num_calls));
+ TF_RETURN_IF_ERROR(writer->WriteScalar(
+ full_name(strings::StrCat(prefix, "_num_elements")),
+ result->num_elements));
+ if (result->output_allocated) {
+ TF_RETURN_IF_ERROR(writer->WriteScalar(
+ full_name(strings::StrCat(prefix, "_output_allocated")), ""));
+ }
+ TF_RETURN_IF_ERROR(writer->WriteScalar(
+ full_name(strings::StrCat(prefix, "_output_size")),
+ result->output.size()));
+ for (int i = 0; i < result->output.size(); i++) {
+ // If the batch is not full, we only store the first `num_elements`
+ // values. The rest of the batch tensor is *uninitialized* and
+ // accessing that will raise msan errors.
+ if (result->num_elements < dataset()->batch_size_) {
+ TF_RETURN_IF_ERROR(writer->WriteTensor(
+ full_name(strings::StrCat(prefix, "_output_", i)),
+ result->output[i].Slice(0, result->num_elements)));
+ } else {
+ TF_RETURN_IF_ERROR(writer->WriteTensor(
+ full_name(strings::StrCat(prefix, "_output_", i)),
+ result->output[i]));
+ }
+ }
+ TF_RETURN_IF_ERROR(WriteStatus(
+ writer, strings::StrCat(prefix, "_status"), result->status));
+ return Status::OK();
+ }
+
+ Status WriteStatus(IteratorStateWriter* writer, const string& prefix,
+ const Status& status) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(full_name(strings::StrCat(prefix, "_code")),
+ static_cast<int64>(status.code())));
+ if (!status.ok()) {
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(full_name(strings::StrCat(prefix, "_msg")),
+ status.error_message()));
+ }
+ return Status::OK();
+ }
+
mutex mu_;
- int32 current_batch_index_ GUARDED_BY(mu_) = -1;
- const std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
- std::vector<InvocationResult> invocation_results_ GUARDED_BY(mu_);
+ // Used for coordination between the main thread, the runner thread, and
+ // the callback threads. In particular, the runner thread should only
+ // schedule new calls when the number of in-flight calls is less than the
+ // user specified level of parallelism and there are slots available in
+ // the `batch_results_` buffer.
+ condition_variable cond_var_;
+ // Counts the number of function calls currently in flight across all
+ // batches.
+ int64 num_calls_ GUARDED_BY(mu_) = 0;
+ // Counts the total number of calls.
+ int64 call_counter_ GUARDED_BY(mu_) = 0;
+ const std::unique_ptr<IteratorBase> input_impl_;
+ // Identifies the next batch to be read by the caller.
+ int64 input_batch_ GUARDED_BY(mu_) = 0;
+ // Identifies the next batch to create.
+ int64 output_batch_ GUARDED_BY(mu_) = 0;
+ // Circular buffer for storing the (intermediate) batch results. When
+ // using `input_batch_` and `output_batch_` to index into the buffer,
+ // their value should be interpreted modulo the size of the buffer.
std::vector<BatchResult> batch_results_ GUARDED_BY(mu_);
+ std::unique_ptr<Thread> runner_thread_ GUARDED_BY(mu_);
+ bool cancelled_ GUARDED_BY(mu_) = false;
};
const DatasetBase* const input_;
const NameAttrList func_;
const int64 batch_size_;
- const int64 num_parallel_batches_;
+ const int64 num_parallel_calls_;
const bool drop_remainder_;
const DataTypeVector output_types_;
const std::vector<PartialTensorShape> output_shapes_;
@@ -707,6 +666,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
};
const int graph_def_version_;
+ const int op_version_;
DataTypeVector output_types_;
std::vector<PartialTensorShape> output_shapes_;
NameAttrList func_;
@@ -715,6 +675,9 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
REGISTER_KERNEL_BUILDER(Name("MapAndBatchDataset").Device(DEVICE_CPU),
MapAndBatchDatasetOp);
+REGISTER_KERNEL_BUILDER(Name("MapAndBatchDatasetV2").Device(DEVICE_CPU),
+ MapAndBatchDatasetOp);
+
} // namespace
} // namespace tensorflow
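
In the rewritten kernel, parallelism is expressed in individual function calls rather than whole batches: for the V1 op the attribute converts as `num_parallel_calls = num_parallel_batches * batch_size`, and the iterator keeps `ceil(num_parallel_calls / batch_size)` `BatchResult` slots in a circular buffer indexed modulo its size. A small sketch of that arithmetic, with assumed example values:

```cpp
#include <cstdint>
#include <iostream>

int main() {
  const int64_t batch_size = 32;

  // V1 op: parallelism expressed in whole batches.
  const int64_t num_parallel_batches = 2;
  const int64_t num_parallel_calls = num_parallel_batches * batch_size;  // 64

  // Number of in-flight BatchResult slots (ceiling division, as in the
  // Iterator constructor): enough batches to absorb all in-flight calls.
  const int64_t batch_results =
      (num_parallel_calls + batch_size - 1) / batch_size;  // 2

  // Indexing into the circular buffer is modulo its size (ComputeIndex).
  const int64_t output_batch = 5;
  std::cout << num_parallel_calls << " " << batch_results << " "
            << output_batch % batch_results << "\n";  // 64 2 1
  return 0;
}
```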
diff --git a/tensorflow/core/kernels/resource_variable_ops.cc b/tensorflow/core/kernels/resource_variable_ops.cc
index a8bcc7f7dc..03cc414905 100644
--- a/tensorflow/core/kernels/resource_variable_ops.cc
+++ b/tensorflow/core/kernels/resource_variable_ops.cc
@@ -703,6 +703,8 @@ TF_CALL_REAL_NUMBER_TYPES(REGISTER_SCATTER_MINMAX_CPU);
REGISTER_SCATTER_KERNEL(string, CPU, "ResourceScatterUpdate",
scatter_op::UpdateOp::ASSIGN);
+REGISTER_SCATTER_KERNEL(bool, CPU, "ResourceScatterUpdate",
+ scatter_op::UpdateOp::ASSIGN);
REGISTER_SCATTER_KERNEL(Variant, CPU, "ResourceScatterUpdate",
scatter_op::UpdateOp::ASSIGN);
@@ -728,6 +730,13 @@ REGISTER_KERNEL_BUILDER(Name("ResourceScatterUpdate")
REGISTER_KERNEL_BUILDER(Name("ResourceScatterUpdate")
.Device(DEVICE_GPU)
.HostMemory("resource")
+ .TypeConstraint<bool>("dtype")
+ .TypeConstraint<int32>("Tindices"),
+ ResourceScatterUpdateOp<GPUDevice, bool, int32,
+ scatter_op::UpdateOp::ASSIGN>)
+REGISTER_KERNEL_BUILDER(Name("ResourceScatterUpdate")
+ .Device(DEVICE_GPU)
+ .HostMemory("resource")
.HostMemory("indices")
.TypeConstraint<Variant>("dtype")
.TypeConstraint<int64>("Tindices"),
diff --git a/tensorflow/core/kernels/scatter_functor_gpu.cu.cc b/tensorflow/core/kernels/scatter_functor_gpu.cu.cc
index 59911bf0d2..bdc878594a 100644
--- a/tensorflow/core/kernels/scatter_functor_gpu.cu.cc
+++ b/tensorflow/core/kernels/scatter_functor_gpu.cu.cc
@@ -42,6 +42,8 @@ typedef Eigen::GpuDevice GPUDevice;
DEFINE_GPU_SPECS(float);
DEFINE_GPU_SPECS(double);
+DEFINE_GPU_SPECS_OP(bool, int32, scatter_op::UpdateOp::ASSIGN);
+DEFINE_GPU_SPECS_OP(bool, int64, scatter_op::UpdateOp::ASSIGN);
// TODO(b/27222123): The following fails to compile due to lack of support for
// fp16.
// TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPECS);
diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
index 3db00d8180..6880ceb505 100644
--- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
@@ -28130,6 +28130,54 @@ op {
}
}
op {
+ name: "MapAndBatchDatasetV2"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "other_arguments"
+ type_list_attr: "Targuments"
+ }
+ input_arg {
+ name: "batch_size"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "num_parallel_calls"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "drop_remainder"
+ type: DT_BOOL
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "f"
+ type: "func"
+ }
+ attr {
+ name: "Targuments"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
name: "MapClear"
attr {
name: "capacity"
diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc
index 73174c184c..576946eddd 100644
--- a/tensorflow/core/ops/dataset_ops.cc
+++ b/tensorflow/core/ops/dataset_ops.cc
@@ -208,6 +208,19 @@ REGISTER_OP("MapAndBatchDataset")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(shape_inference::ScalarShape);
+REGISTER_OP("MapAndBatchDatasetV2")
+ .Input("input_dataset: variant")
+ .Input("other_arguments: Targuments")
+ .Input("batch_size: int64")
+ .Input("num_parallel_calls: int64")
+ .Input("drop_remainder: bool")
+ .Output("handle: variant")
+ .Attr("f: func")
+ .Attr("Targuments: list(type) >= 0")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(shape_inference::ScalarShape);
+
REGISTER_OP("PrefetchDataset")
.Input("input_dataset: variant")
.Input("buffer_size: int64")
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index 7156440b46..d741598b19 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -13940,6 +13940,54 @@ op {
}
}
op {
+ name: "MapAndBatchDatasetV2"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "other_arguments"
+ type_list_attr: "Targuments"
+ }
+ input_arg {
+ name: "batch_size"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "num_parallel_calls"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "drop_remainder"
+ type: DT_BOOL
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "f"
+ type: "func"
+ }
+ attr {
+ name: "Targuments"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
name: "MapClear"
attr {
name: "capacity"
diff --git a/tensorflow/core/platform/default/build_config.bzl b/tensorflow/core/platform/default/build_config.bzl
index 107c38114b..f6e09ef094 100644
--- a/tensorflow/core/platform/default/build_config.bzl
+++ b/tensorflow/core/platform/default/build_config.bzl
@@ -335,6 +335,8 @@ def tf_proto_library_cc(name, srcs = [], has_services = None,
name = cc_name,
deps = cc_deps + ["@protobuf_archive//:protobuf_headers"] +
if_static([name + "_cc_impl"]),
+ testonly = testonly,
+ visibility = visibility,
)
native.cc_library(
name = cc_name + "_impl",
@@ -378,8 +380,10 @@ def tf_proto_library_py(name, srcs=[], protodeps=[], deps=[], visibility=[],
)
native.py_library(
name = py_name,
- deps = py_deps + ["@protobuf_archive//:protobuf_python"])
-
+ deps = py_deps + ["@protobuf_archive//:protobuf_python"],
+ testonly = testonly,
+ visibility = visibility,
+ )
return
py_proto_library(
diff --git a/tensorflow/core/platform/default/mutex.h b/tensorflow/core/platform/default/mutex.h
index a12d92795e..89e57d58a0 100644
--- a/tensorflow/core/platform/default/mutex.h
+++ b/tensorflow/core/platform/default/mutex.h
@@ -77,9 +77,7 @@ class SCOPED_LOCKABLE mutex_lock {
// Manually nulls out the source to prevent double-free.
// (std::move does not null the source pointer by default.)
- explicit mutex_lock(mutex_lock&& ml) noexcept : mu_(ml.mu_) {
- ml.mu_ = nullptr;
- }
+ mutex_lock(mutex_lock&& ml) noexcept : mu_(ml.mu_) { ml.mu_ = nullptr; }
~mutex_lock() UNLOCK_FUNCTION() {
if (mu_ != nullptr) {
mu_->unlock();
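
Dropping `explicit` from the move constructor matters because copy-initialization from an rvalue (for example `mutex_lock l = std::move(other);`, or returning a lock by value before C++17) only considers non-explicit constructors. A minimal sketch with a hypothetical `Lock` type:

```cpp
#include <mutex>
#include <utility>

// Hypothetical moveable lock, mirroring the shape of mutex_lock.
class Lock {
 public:
  explicit Lock(std::mutex& mu) : mu_(&mu) { mu_->lock(); }

  // Non-explicit move constructor: nulls the source so only the destination
  // unlocks. If this were marked `explicit`, the copy-initialization
  // `Lock moved = std::move(held);` in main would not compile.
  Lock(Lock&& other) noexcept : mu_(other.mu_) { other.mu_ = nullptr; }

  ~Lock() {
    if (mu_ != nullptr) mu_->unlock();
  }

  Lock(const Lock&) = delete;
  Lock& operator=(const Lock&) = delete;

 private:
  std::mutex* mu_;
};

std::mutex global_mu;

int main() {
  Lock held(global_mu);
  Lock moved = std::move(held);  // Ownership transfers; `held` no longer unlocks.
  return 0;                      // `moved` unlocks exactly once here.
}
```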
diff --git a/tensorflow/docs_src/deploy/index.md b/tensorflow/docs_src/deploy/index.md
index 61edba04b4..3322004189 100644
--- a/tensorflow/docs_src/deploy/index.md
+++ b/tensorflow/docs_src/deploy/index.md
@@ -15,3 +15,7 @@ the following documents:
out-of-the-box integration with TensorFlow models.
[Source code for TensorFlow Serving](https://github.com/tensorflow/serving)
is available on GitHub.
+
+[TensorFlow Extended (TFX)](/tfx) is an end-to-end machine learning platform for
+TensorFlow. TFX was implemented at Google; we've open sourced some of its
+libraries, with the rest of the system to come.
diff --git a/tensorflow/docs_src/performance/xla/operation_semantics.md b/tensorflow/docs_src/performance/xla/operation_semantics.md
index f530fe1206..21e4c71a60 100644
--- a/tensorflow/docs_src/performance/xla/operation_semantics.md
+++ b/tensorflow/docs_src/performance/xla/operation_semantics.md
@@ -1049,8 +1049,8 @@ For a more intuitive description, see the "Informal Description" section below.
: : : from. :
|`gather_indices` | `ComputationDataHandle` | Tensor containing the starting |
: : : indices of the slices we're :
-: : : we're stitching together into :
-: : : the output tensor. :
+: : : stitching together into the :
+: : : output tensor. :
|`index_vector_dim` | `int64` | The dimension in |
: : : `gather_indices` that contains :
: : : the starting indices. :
diff --git a/tensorflow/docs_src/programmers_guide/embedding.md b/tensorflow/docs_src/programmers_guide/embedding.md
index d5703e0737..8a98367dfb 100644
--- a/tensorflow/docs_src/programmers_guide/embedding.md
+++ b/tensorflow/docs_src/programmers_guide/embedding.md
@@ -238,7 +238,7 @@ row doesn't have to be filled, as shown below.
</tr>
</table>
-Follow [this link]("https://www.tensorflow.org/images/embedding-mnist.mp4" )
+Follow [this link](https://www.tensorflow.org/images/embedding-mnist.mp4)
to see a fun example of thumbnail images in the Embedding Projector.
diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go
index 2f1be51ada..70a271bd2e 100644
--- a/tensorflow/go/op/wrappers.go
+++ b/tensorflow/go/op/wrappers.go
@@ -2544,6 +2544,72 @@ func EditDistance(scope *Scope, hypothesis_indices tf.Output, hypothesis_values
return op.Output(0)
}
+// Reverses specific dimensions of a tensor.
+//
+// Given a `tensor`, and a `bool` tensor `dims` representing the dimensions
+// of `tensor`, this operation reverses each dimension i of `tensor` where
+// `dims[i]` is `True`.
+//
+// `tensor` can have up to 8 dimensions. The number of dimensions
+// of `tensor` must equal the number of elements in `dims`. In other words:
+//
+// `rank(tensor) = size(dims)`
+//
+// For example:
+//
+// ```
+// # tensor 't' is [[[[ 0, 1, 2, 3],
+// # [ 4, 5, 6, 7],
+// # [ 8, 9, 10, 11]],
+// # [[12, 13, 14, 15],
+// # [16, 17, 18, 19],
+// # [20, 21, 22, 23]]]]
+// # tensor 't' shape is [1, 2, 3, 4]
+//
+// # 'dims' is [False, False, False, True]
+// reverse(t, dims) ==> [[[[ 3, 2, 1, 0],
+// [ 7, 6, 5, 4],
+// [ 11, 10, 9, 8]],
+// [[15, 14, 13, 12],
+// [19, 18, 17, 16],
+// [23, 22, 21, 20]]]]
+//
+// # 'dims' is [False, True, False, False]
+// reverse(t, dims) ==> [[[[12, 13, 14, 15],
+// [16, 17, 18, 19],
+// [20, 21, 22, 23]
+// [[ 0, 1, 2, 3],
+// [ 4, 5, 6, 7],
+// [ 8, 9, 10, 11]]]]
+//
+// # 'dims' is [False, False, True, False]
+// reverse(t, dims) ==> [[[[8, 9, 10, 11],
+// [4, 5, 6, 7],
+// [0, 1, 2, 3]]
+// [[20, 21, 22, 23],
+// [16, 17, 18, 19],
+// [12, 13, 14, 15]]]]
+// ```
+//
+// Arguments:
+// tensor: Up to 8-D.
+// dims: 1-D. The dimensions to reverse.
+//
+// Returns The same shape as `tensor`.
+func Reverse(scope *Scope, tensor tf.Output, dims tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "Reverse",
+ Input: []tf.Input{
+ tensor, dims,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Clips tensor values to a specified min and max.
//
// Given a tensor `t`, this operation returns a tensor of the same type and
@@ -2796,71 +2862,6 @@ func Asin(scope *Scope, x tf.Output) (y tf.Output) {
return op.Output(0)
}
-// SparseToDenseAttr is an optional argument to SparseToDense.
-type SparseToDenseAttr func(optionalAttr)
-
-// SparseToDenseValidateIndices sets the optional validate_indices attribute to value.
-//
-// value: If true, indices are checked to make sure they are sorted in
-// lexicographic order and that there are no repeats.
-// If not specified, defaults to true
-func SparseToDenseValidateIndices(value bool) SparseToDenseAttr {
- return func(m optionalAttr) {
- m["validate_indices"] = value
- }
-}
-
-// Converts a sparse representation into a dense tensor.
-//
-// Builds an array `dense` with shape `output_shape` such that
-//
-// ```
-// # If sparse_indices is scalar
-// dense[i] = (i == sparse_indices ? sparse_values : default_value)
-//
-// # If sparse_indices is a vector, then for each i
-// dense[sparse_indices[i]] = sparse_values[i]
-//
-// # If sparse_indices is an n by d matrix, then for each i in [0, n)
-// dense[sparse_indices[i][0], ..., sparse_indices[i][d-1]] = sparse_values[i]
-// ```
-//
-// All other values in `dense` are set to `default_value`. If `sparse_values` is a
-// scalar, all sparse indices are set to this single value.
-//
-// Indices should be sorted in lexicographic order, and indices must not
-// contain any repeats. If `validate_indices` is true, these properties
-// are checked during execution.
-//
-// Arguments:
-// sparse_indices: 0-D, 1-D, or 2-D. `sparse_indices[i]` contains the complete
-// index where `sparse_values[i]` will be placed.
-// output_shape: 1-D. Shape of the dense output tensor.
-// sparse_values: 1-D. Values corresponding to each row of `sparse_indices`,
-// or a scalar value to be used for all sparse indices.
-// default_value: Scalar value to set for indices not specified in
-// `sparse_indices`.
-//
-// Returns Dense output tensor of shape `output_shape`.
-func SparseToDense(scope *Scope, sparse_indices tf.Output, output_shape tf.Output, sparse_values tf.Output, default_value tf.Output, optional ...SparseToDenseAttr) (dense tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "SparseToDense",
- Input: []tf.Input{
- sparse_indices, output_shape, sparse_values, default_value,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Computes the sum along sparse segments of a tensor.
//
// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
@@ -6469,72 +6470,6 @@ func SparseFillEmptyRows(scope *Scope, indices tf.Output, values tf.Output, dens
return op.Output(0), op.Output(1), op.Output(2), op.Output(3)
}
-// Reverses specific dimensions of a tensor.
-//
-// Given a `tensor`, and a `bool` tensor `dims` representing the dimensions
-// of `tensor`, this operation reverses each dimension i of `tensor` where
-// `dims[i]` is `True`.
-//
-// `tensor` can have up to 8 dimensions. The number of dimensions
-// of `tensor` must equal the number of elements in `dims`. In other words:
-//
-// `rank(tensor) = size(dims)`
-//
-// For example:
-//
-// ```
-// # tensor 't' is [[[[ 0, 1, 2, 3],
-// # [ 4, 5, 6, 7],
-// # [ 8, 9, 10, 11]],
-// # [[12, 13, 14, 15],
-// # [16, 17, 18, 19],
-// # [20, 21, 22, 23]]]]
-// # tensor 't' shape is [1, 2, 3, 4]
-//
-// # 'dims' is [False, False, False, True]
-// reverse(t, dims) ==> [[[[ 3, 2, 1, 0],
-// [ 7, 6, 5, 4],
-// [ 11, 10, 9, 8]],
-// [[15, 14, 13, 12],
-// [19, 18, 17, 16],
-// [23, 22, 21, 20]]]]
-//
-// # 'dims' is [False, True, False, False]
-// reverse(t, dims) ==> [[[[12, 13, 14, 15],
-// [16, 17, 18, 19],
-// [20, 21, 22, 23]
-// [[ 0, 1, 2, 3],
-// [ 4, 5, 6, 7],
-// [ 8, 9, 10, 11]]]]
-//
-// # 'dims' is [False, False, True, False]
-// reverse(t, dims) ==> [[[[8, 9, 10, 11],
-// [4, 5, 6, 7],
-// [0, 1, 2, 3]]
-// [[20, 21, 22, 23],
-// [16, 17, 18, 19],
-// [12, 13, 14, 15]]]]
-// ```
-//
-// Arguments:
-// tensor: Up to 8-D.
-// dims: 1-D. The dimensions to reverse.
-//
-// Returns The same shape as `tensor`.
-func Reverse(scope *Scope, tensor tf.Output, dims tf.Output) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "Reverse",
- Input: []tf.Input{
- tensor, dims,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// BiasAddGradAttr is an optional argument to BiasAddGrad.
type BiasAddGradAttr func(optionalAttr)
@@ -24884,6 +24819,71 @@ func DecodeJSONExample(scope *Scope, json_examples tf.Output) (binary_examples t
return op.Output(0)
}
+// SparseToDenseAttr is an optional argument to SparseToDense.
+type SparseToDenseAttr func(optionalAttr)
+
+// SparseToDenseValidateIndices sets the optional validate_indices attribute to value.
+//
+// value: If true, indices are checked to make sure they are sorted in
+// lexicographic order and that there are no repeats.
+// If not specified, defaults to true
+func SparseToDenseValidateIndices(value bool) SparseToDenseAttr {
+ return func(m optionalAttr) {
+ m["validate_indices"] = value
+ }
+}
+
+// Converts a sparse representation into a dense tensor.
+//
+// Builds an array `dense` with shape `output_shape` such that
+//
+// ```
+// # If sparse_indices is scalar
+// dense[i] = (i == sparse_indices ? sparse_values : default_value)
+//
+// # If sparse_indices is a vector, then for each i
+// dense[sparse_indices[i]] = sparse_values[i]
+//
+// # If sparse_indices is an n by d matrix, then for each i in [0, n)
+// dense[sparse_indices[i][0], ..., sparse_indices[i][d-1]] = sparse_values[i]
+// ```
+//
+// All other values in `dense` are set to `default_value`. If `sparse_values` is a
+// scalar, all sparse indices are set to this single value.
+//
+// Indices should be sorted in lexicographic order, and indices must not
+// contain any repeats. If `validate_indices` is true, these properties
+// are checked during execution.
+//
+// Arguments:
+// sparse_indices: 0-D, 1-D, or 2-D. `sparse_indices[i]` contains the complete
+// index where `sparse_values[i]` will be placed.
+// output_shape: 1-D. Shape of the dense output tensor.
+// sparse_values: 1-D. Values corresponding to each row of `sparse_indices`,
+// or a scalar value to be used for all sparse indices.
+// default_value: Scalar value to set for indices not specified in
+// `sparse_indices`.
+//
+// Returns Dense output tensor of shape `output_shape`.
+func SparseToDense(scope *Scope, sparse_indices tf.Output, output_shape tf.Output, sparse_values tf.Output, default_value tf.Output, optional ...SparseToDenseAttr) (dense tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "SparseToDense",
+ Input: []tf.Input{
+ sparse_indices, output_shape, sparse_values, default_value,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
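The docstring above gives the scatter rule in pseudocode. Below is a small NumPy sketch of the n-by-d `sparse_indices` case (the function and variable names are invented for illustration and are not part of the generated binding):

```python
import numpy as np

def sparse_to_dense_sketch(sparse_indices, output_shape, sparse_values, default_value):
  # Mirrors the docstring: start from default_value, then scatter each value
  # at its complete index; a scalar sparse_values is broadcast to all indices.
  dense = np.full(output_shape, default_value)
  for i, idx in enumerate(sparse_indices):
    value = sparse_values if np.isscalar(sparse_values) else sparse_values[i]
    dense[tuple(idx)] = value
  return dense

print(sparse_to_dense_sketch([[0, 1], [2, 3]], (3, 4), [5, 6], 0))
# [[0 5 0 0]
#  [0 0 0 0]
#  [0 0 0 6]]
```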
// Computes the grayscale dilation of 4-D `input` and 3-D `filter` tensors.
//
// The `input` tensor has shape `[batch, in_height, in_width, depth]` and the
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 087b89b125..a865e8ca75 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -1762,6 +1762,7 @@ py_library(
":logging_ops_gen",
":math_ops",
":platform",
+ ":resource_variable_ops_gen",
":sparse_tensor",
":tensor_array_ops",
":tf_should_use",
@@ -4134,7 +4135,7 @@ cuda_py_test(
py_test(
name = "saver_large_variable_test",
- size = "small",
+ size = "medium",
srcs = ["training/saver_large_variable_test.py"],
srcs_version = "PY2AND3",
tags = [
diff --git a/tensorflow/python/debug/examples/debug_tflearn_iris.py b/tensorflow/python/debug/examples/debug_tflearn_iris.py
index 00090b21fe..7cbaae46b4 100644
--- a/tensorflow/python/debug/examples/debug_tflearn_iris.py
+++ b/tensorflow/python/debug/examples/debug_tflearn_iris.py
@@ -140,7 +140,7 @@ def main(_):
# Make predictions, using tfdbg hook.
predict_results = classifier.predict(test_input_fn, hooks=hooks)
- print("A prediction result: %s" % predict_results.next())
+ print("A prediction result: %s" % next(predict_results))
if __name__ == "__main__":
diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py
index d04b004451..967c128280 100644
--- a/tensorflow/python/eager/backprop.py
+++ b/tensorflow/python/eager/backprop.py
@@ -358,6 +358,8 @@ def gradients_function(f, params=None):
assert y_grad.numpy() == (2 ** 3) - 2 * 2 * 3
```
+ Note that only tensors with real or complex dtypes are differentiable.
+
Args:
f: function to be differentiated. If `f` returns a scalar, this scalar will
be differentiated. If `f` returns a tensor or list of tensors, by default
@@ -700,6 +702,9 @@ class GradientTape(object):
dz_dx = g.gradient(z, x) # 108.0 (4*x^3 at x = 3)
dy_dx = g.gradient(y, x) # 6.0
del g # Drop the reference to the tape
+ ```
+
+ Note that only tensors with real or complex dtypes are differentiable.
"""
def __init__(self, persistent=False):
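Both notes added above point at the same restriction: gradients are only recorded for tensors with real or complex dtypes. A minimal sketch of what that looks like in 1.x-era eager mode, assuming the `tf.contrib.eager` alias for `gradients_function` used in the docstring example:

```python
import tensorflow as tf
import tensorflow.contrib.eager as tfe

tf.enable_eager_execution()

def f(x):
  return x + x

grad_f = tfe.gradients_function(f)
print(grad_f(tf.constant(1.0))[0])  # 2.0: float inputs are differentiable
print(grad_f(tf.constant(1))[0])    # None: integer inputs are not recorded
```

The `testGradientInteger` case added in backprop_test.py below asserts exactly this `None` result for an integer tensor.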
diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py
index 8d9959fe20..be674487f1 100644
--- a/tensorflow/python/eager/backprop_test.py
+++ b/tensorflow/python/eager/backprop_test.py
@@ -124,6 +124,14 @@ class BackpropTest(test.TestCase):
grad_fn = backprop.gradients_function(f)
self.assertAllEqual(2., grad_fn(1., dy=2.)[0])
+ def testGradientInteger(self):
+
+ def f(x):
+ return x + x
+
+ int_tensor = constant_op.constant(1)
+ self.assertEqual(backprop.gradients_function(f)(int_tensor)[0], None)
+
def testErrors(self):
@custom_gradient.custom_gradient
@@ -753,7 +761,7 @@ class BackpropTest(test.TestCase):
return result, grad
x = resource_variable_ops.ResourceVariable(
- initial_value=3, name='X.' + self.id())
+ initial_value=3., name='X.' + self.id())
def f():
return my_square(x)
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index 741bd2ac9c..b478b6b0db 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -23,6 +23,7 @@ import collections
import numpy as np
+from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import function_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import context
@@ -102,13 +103,15 @@ class CapturingGraph(ops.Graph):
def clear_resource_control_flow_state(self):
self._last_op_using_resource_tensor = {}
- def maybe_capture_tensor(self, tensor):
+ def capture(self, tensor, name=None):
if isinstance(tensor, ops.EagerTensor):
- return capture_value(
- self.captures, tensor, tensor.dtype, str(ops.uid()))
+ if name is None:
+ name = str(ops.uid())
+ return capture_value(self.captures, tensor, tensor.dtype, name)
if tensor.graph is not self:
- return capture_value(
- self.captures, tensor, tensor.dtype, tensor.op.name)
+ if name is None:
+ name = tensor.op.name
+ return capture_value(self.captures, tensor, tensor.dtype, name)
return tensor
def create_op(
@@ -126,7 +129,7 @@ class CapturingGraph(ops.Graph):
# forward the resources such as Identity and Switch can cause serialization
# to fail.
for i, inp in enumerate(inputs):
- inputs[i] = self.maybe_capture_tensor(inp)
+ inputs[i] = self.capture(inp)
return super(CapturingGraph, self).create_op(
op_type, inputs, dtypes, input_types, name, attrs, op_def,
compute_shapes, compute_device)
@@ -225,7 +228,7 @@ def _inference_name(n):
class _EagerDefinedFunction(object):
"""Function object with the interface of tf _DefinedFunction."""
- def __init__(self, name, graph, operations, inputs, outputs):
+ def __init__(self, name, graph, operations, inputs, outputs, attrs):
"""Initializes an eager defined function.
Args:
@@ -235,6 +238,7 @@ class _EagerDefinedFunction(object):
which will be in the function
inputs: the tensors in the graph to be used as inputs to the function
outputs: the tensors in the graph which will be outputs to the function
+ attrs: dict mapping names of attributes to their AttrValue values
"""
fn = pywrap_tensorflow.TF_GraphToFunction_wrapper(
graph._c_graph, # pylint: disable=protected-access
@@ -246,6 +250,14 @@ class _EagerDefinedFunction(object):
[],
None,
compat.as_str(""))
+
+ for name, attr_value in attrs.items():
+ serialized = attr_value.SerializeToString()
+ # TODO(iga): this creates and deletes a new TF_Status for every attr.
+ # It might be worth creating a convenient way to re-use status.
+ pywrap_tensorflow.TF_FunctionSetAttrValueProto(
+ fn, compat.as_str(name), serialized)
+
    # TODO(apassos) avoid creating a FunctionDef (especially to grab the
    # signature), but also in general it's nice not to depend on it.
with c_api_util.tf_buffer() as buffer_:
@@ -287,25 +299,6 @@ def _flatten(sequence):
class GraphModeFunction(object):
"""Callable object representing a graph-mode function.
-
- Args:
- name: str the name of the created function
- input_placeholders: list of placeholder values (tensors) to feed when
- calling the wrapped function.
- extra_inputs: Tensor inputs this function definition closed over which
- are passed as arguments. Need to track so gradients are supported
- correctly.
- graph: the Graph from which the operations will be pulled. Used as
- a context when computing gradients.
- operations: the subset of Operations in the graph used in the function
- definition.
- outputs: a flat list of the Tensors in the graph used as outputs to the
- function
- func_outputs: a possibly nested python object which will be returned by
- this function. The Tensors in this structure will be replaced by their
- corresponding values in outputs.
- output_shapes: List of shapes of all tensors in outputs
- variables: (optional) List of variables to watch during function execution.
"""
def __init__(self,
@@ -317,9 +310,36 @@ class GraphModeFunction(object):
outputs,
func_outputs,
output_shapes,
- variables=None):
+ variables=None,
+ attrs=None):
+ """Initialize a GraphModeFunction.
+
+ Args:
+ name: str the name of the created function
+ input_placeholders: list of placeholder values (tensors) to feed when
+ calling the wrapped function.
+ extra_inputs: Tensor inputs this function definition closed over which
+ are passed as arguments. Need to track so gradients are supported
+ correctly.
+ graph: the Graph from which the operations will be pulled. Used as
+ a context when computing gradients.
+ operations: the subset of Operations in the graph used in the function
+ definition.
+ outputs: a flat list of the Tensors in the graph used as outputs to the
+ function
+ func_outputs: a possibly nested python object which will be returned by
+ this function. The Tensors in this structure will be replaced by their
+ corresponding values in outputs.
+ output_shapes: List of shapes of all tensors in outputs
+ variables: (optional) List of variables to watch during function
+ execution.
+ attrs: (optional) dict mapping names of attributes to their AttrValue
+ values. Attributes in `attrs` will be included in this function's
+ definition.
+ """
+ self._attrs = attrs or {}
defined_function = _EagerDefinedFunction(
- name, graph, operations, input_placeholders, outputs)
+ name, graph, operations, input_placeholders, outputs, self._attrs)
if len(input_placeholders) != len(defined_function.signature.input_arg):
raise ValueError("Internal error: invalid lengths. %s %s" % (
len(input_placeholders), len(defined_function.signature.input_arg)))
@@ -372,7 +392,7 @@ class GraphModeFunction(object):
forward_name = _forward_name(self._func_name)
self._forward_fdef = _EagerDefinedFunction(
forward_name, self._graph, self._ops, self._input_placeholders,
- filtered_outputs + captures)
+ filtered_outputs + captures, self._attrs)
all_inputs = self._out_grad_placeholders + captures
# Excluding input ops from the body as we do not intend to execute these
# operations when the function is executed.
@@ -386,7 +406,7 @@ class GraphModeFunction(object):
bname = _backward_name(self._func_name)
self._backward_function = GraphModeFunction(
bname, all_inputs, [], self._graph, function_def_ops,
- backward_outputs, in_gradients, output_shapes)
+ backward_outputs, in_gradients, output_shapes, attrs=self._attrs)
def _backprop_call(self, args):
"""Calls the wrapped function and records the result on a tape."""
@@ -560,7 +580,7 @@ def _get_defun_inputs(args):
return nest.pack_sequence_as(args, ret)
-def _defun_internal(name, func, args, kwds):
+def _defun_internal(name, func, compiled, args, kwds):
"""Defines and returns graph-mode version of func."""
graph_key = ops.get_default_graph()._graph_key # pylint: disable=protected-access
with context.graph_mode():
@@ -598,7 +618,7 @@ def _defun_internal(name, func, args, kwds):
# call to convert_to_tensor, so we manually capture all such tensors.
outputs_list = _flatten(func_outputs)
func_def_outputs = [
- tmp_graph.maybe_capture_tensor(x) for x in outputs_list
+ tmp_graph.capture(x) for x in outputs_list
if x is not None
]
@@ -625,9 +645,14 @@ def _defun_internal(name, func, args, kwds):
for f in tmp_graph._functions.values(): # pylint: disable=protected-access
# TODO(ashankar): What about the gradient registry?
_register(f._c_func.func) # pylint: disable=protected-access
+
+ attrs = {}
+ if compiled:
+ attrs["_XlaCompile"] = attr_value_pb2.AttrValue(b=True)
+
return GraphModeFunction(
fname, all_inputs, extra_inputs, tmp_graph, operations, func_def_outputs,
- func_outputs, output_shapes, variables)
+ func_outputs, output_shapes, variables, attrs)
# Defun uses this instead of Tensor as a cache key. Using dtype because
@@ -669,7 +694,7 @@ def _register(fn):
# TODO(apassos): better error messages for non-hashable arguments.
-def named_defun(func, name):
+def named_defun(func, name, compiled=False):
"""Defines a function with a given name.
See the documentation for `defun` for more information on the semantics of the
@@ -678,6 +703,7 @@ def named_defun(func, name):
Args:
func: the function to be wrapped.
name: the name given to it.
+ compiled: if true, the framework will attempt to compile func with XLA.
Returns:
the wrapped function.
@@ -694,13 +720,13 @@ def named_defun(func, name):
if cache_key not in arguments_to_functions:
arguments_to_functions[cache_key] = _defun_internal(
- name, func, args, kwds)
+ name, func, compiled, args, kwds)
return arguments_to_functions[cache_key](*args)
return decorated
-def defun(func):
+def defun(func=None, compiled=False):
"""Decorator to compile func into graph_mode.
`defun` converts a function that constructs a TensorFlow graph into a function
@@ -743,18 +769,45 @@ def defun(func):
```
Args:
- func: function to be compiled.
+ func: function to be compiled. If `func` is None, returns a
+ decorator that can be invoked with a single argument - `func`. The
+ end result is equivalent to providing all the arguments up front.
+ In other words, defun(compiled=True)(func) is equivalent to
+ defun(func, compiled=True). The former allows the following use case:
+ @tfe.defun(compiled=True)
+ def foo(...):
+ ...
+ compiled: If True, an attempt to compile `func` with XLA will be made.
+      If it fails, the function will be run normally. Experimental.
+ Currently, supported only for execution on TPUs.
Returns:
- A callable that will execute the compiled function (and return zero
- or more `tf.Tensor` objects).
+ If `func` is not None, returns callable that will execute the compiled
+ function (and return zero or more `tf.Tensor` objects).
+ If `func` is None, returns a decorator that, when invoked with a single
+ `func` argument, returns a callable equivalent to the case above.
"""
# TODO(apassos): deal with captured global state. Deal with control flow.
- try:
- name = func.__name__
- except AttributeError:
- name = "function"
- return tf_decorator.make_decorator(func, named_defun(func, name))
+ def decorated(function):
+ try:
+ name = function.__name__
+ except AttributeError:
+ name = "function"
+ return tf_decorator.make_decorator(
+ function, named_defun(function, name, compiled=compiled))
+
+ # This code path is for the `foo = tfe.defun(foo, ...)` use case
+ if func is not None:
+ return decorated(func)
+
+ # This code path is for the
+ #
+ # @tfe.defun(...)
+ # def foo(...):
+ # ...
+ #
+ # use case, which is equivalent to `foo = tfe.defun(...)(foo)`
+ return decorated
def make_defun_op(func, *args, **kwds):
@@ -806,7 +859,7 @@ def make_defun_op(func, *args, **kwds):
name = func.__name__
if any(isinstance(x, ops.EagerTensor) for x in kwds.values()):
raise ValueError("Tensor keyword arguments are not supported.")
- return _defun_internal(name, func, args, kwds)
+ return _defun_internal(name, func, False, args, kwds)
class AutomaticControlDependencies(object):
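The new `compiled` flag threads through `named_defun` and `_defun_internal` down to the `_XlaCompile` attribute set on the generated function. A short usage sketch of the two equivalent spellings from the docstring, assuming the `tfe.defun` alias in `tf.contrib.eager`; per the docstring, if compilation is not possible the function simply runs uncompiled:

```python
import tensorflow as tf
import tensorflow.contrib.eager as tfe

tf.enable_eager_execution()

# Decorator-with-arguments form: defun(compiled=True) returns a decorator.
@tfe.defun(compiled=True)
def double(x):
  return x + x

# Direct form: defun(func, compiled=True) wraps the function immediately.
def triple(x):
  return 3.0 * x
triple = tfe.defun(triple, compiled=True)

x = tf.constant(2.0)
print(double(x))  # 4.0
print(triple(x))  # 6.0
```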
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index 185f6d981c..f53d6c2608 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -771,6 +771,21 @@ class AutomaticControlDependenciesTest(test.TestCase):
self.assertAllEqual(val.eval(feed_dict={p: False}), 10.0)
self.assertAllEqual(val.eval(feed_dict={p: True}), 20.0)
+ def testDefunWhileLoopWithCapturedLoopVars(self):
+ n = 3
+ x = constant_op.constant(list(range(n)))
+
+ @function.defun
+ def loop():
+ c = lambda i, x: i < n
+ b = lambda i, x: (i + 1, x + 1)
+ i, out = control_flow_ops.while_loop(c, b, (0, x))
+ return i, out
+
+ i, out = loop()
+ self.assertEqual(int(i), 3)
+ self.assertAllEqual(out, [3, 4, 5])
+
def testDecorator(self):
with context.graph_mode(), self.test_session():
v = resource_variable_ops.ResourceVariable(1.0)
diff --git a/tensorflow/python/eager/pywrap_tensor.cc b/tensorflow/python/eager/pywrap_tensor.cc
index b5b4e394e3..b3aadd55ce 100644
--- a/tensorflow/python/eager/pywrap_tensor.cc
+++ b/tensorflow/python/eager/pywrap_tensor.cc
@@ -650,6 +650,12 @@ tensorflow::int64 EagerTensor_id(const PyObject* tensor) {
return reinterpret_cast<const EagerTensor*>(tensor)->id;
}
+tensorflow::DataType EagerTensor_dtype(const PyObject* tensor) {
+ CHECK(EagerTensor_CheckExact(tensor));
+ return static_cast<tensorflow::DataType>(TFE_TensorHandleDataType(
+ reinterpret_cast<const EagerTensor*>(tensor)->handle));
+}
+
PyObject* TFE_Py_InitEagerTensor(PyObject* base_class) {
if (!PyType_Check(base_class)) {
PyErr_SetString(
diff --git a/tensorflow/python/eager/pywrap_tensor.h b/tensorflow/python/eager/pywrap_tensor.h
index 63ab1ed84d..88982b0c85 100644
--- a/tensorflow/python/eager/pywrap_tensor.h
+++ b/tensorflow/python/eager/pywrap_tensor.h
@@ -21,6 +21,7 @@ limitations under the License.
bool EagerTensor_CheckExact(const PyObject* o);
tensorflow::int64 EagerTensor_id(const PyObject* tensor);
+tensorflow::DataType EagerTensor_dtype(const PyObject* tensor);
namespace tensorflow {
TFE_TensorHandle* ConvertToEagerTensor(PyObject* value, PyObject* dtype);
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc
index 4ecba1a46b..48a5b21dc7 100644
--- a/tensorflow/python/eager/pywrap_tfe_src.cc
+++ b/tensorflow/python/eager/pywrap_tfe_src.cc
@@ -843,6 +843,24 @@ static tensorflow::int64 FastTensorId(PyObject* tensor) {
return id;
}
+static tensorflow::DataType FastTensorDtype(PyObject* tensor) {
+ if (EagerTensor_CheckExact(tensor)) {
+ return EagerTensor_dtype(tensor);
+ }
+ PyObject* dtype_field = PyObject_GetAttrString(tensor, "dtype");
+ if (dtype_field == nullptr) {
+ return tensorflow::DT_INVALID;
+ }
+ PyObject* enum_field = PyObject_GetAttrString(dtype_field, "_type_enum");
+ Py_DECREF(dtype_field);
+  if (enum_field == nullptr) {
+ return tensorflow::DT_INVALID;
+ }
+ tensorflow::int64 id = MakeInt(enum_field);
+ Py_DECREF(enum_field);
+ return static_cast<tensorflow::DataType>(id);
+}
+
class GradientTape
: public tensorflow::eager::GradientTape<PyObject, PyObject> {
public:
@@ -1053,15 +1071,18 @@ PyObject* TFE_Py_TapeSetShouldRecord(PyObject* tensors) {
// TODO(apassos) consider not building a list and changing the API to check
// each tensor individually.
std::vector<tensorflow::int64> tensor_ids;
+ std::vector<tensorflow::DataType> dtypes;
tensor_ids.reserve(len);
+ dtypes.reserve(len);
for (int i = 0; i < len; ++i) {
PyObject* item = PySequence_Fast_GET_ITEM(seq, i);
tensor_ids.push_back(FastTensorId(item));
+ dtypes.push_back(FastTensorDtype(item));
}
Py_DECREF(seq);
auto tape_set = *tape_set_ptr;
for (TFE_Py_Tape* tape : tape_set) {
- if (tape->tape->ShouldRecord(tensor_ids)) {
+ if (tape->tape->ShouldRecord(tensor_ids, dtypes)) {
Py_RETURN_TRUE;
}
}
@@ -1169,9 +1190,27 @@ PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape) {
}
namespace {
-void TapeSetRecordOperation(PyObject* op_type, PyObject* output_tensors,
- const std::vector<tensorflow::int64>& input_ids,
- PyObject* backward_function) {
+std::vector<tensorflow::DataType> MakeTensorDtypeList(PyObject* tensors) {
+ PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
+ if (seq == nullptr) {
+ return {};
+ }
+ int len = PySequence_Fast_GET_SIZE(seq);
+ std::vector<tensorflow::DataType> list;
+ list.reserve(len);
+ for (int i = 0; i < len; ++i) {
+ PyObject* tensor = PySequence_Fast_GET_ITEM(seq, i);
+ list.push_back(FastTensorDtype(tensor));
+ }
+ Py_DECREF(seq);
+ return list;
+}
+
+void TapeSetRecordOperation(
+ PyObject* op_type, PyObject* output_tensors,
+ const std::vector<tensorflow::int64>& input_ids,
+ const std::vector<tensorflow::DataType>& input_dtypes,
+ PyObject* backward_function) {
std::vector<tensorflow::eager::TapeTensor> output_info;
PyObject* seq = PySequence_Fast(output_tensors,
"expected a sequence of integer tensor ids");
@@ -1206,7 +1245,7 @@ void TapeSetRecordOperation(PyObject* op_type, PyObject* output_tensors,
for (TFE_Py_Tape* tape : SafeTapeSet()) {
Py_INCREF(backward_function);
tape->tape->RecordOperation(
- op_type_str, output_info, input_ids, backward_function,
+ op_type_str, output_info, input_ids, input_dtypes, backward_function,
[backward_function]() { Py_DECREF(backward_function); });
}
}
@@ -1221,7 +1260,11 @@ void TFE_Py_TapeSetRecordOperation(PyObject* op_type, PyObject* output_tensors,
std::vector<tensorflow::int64> input_ids = MakeTensorIDList(input_tensors);
if (PyErr_Occurred()) return;
- TapeSetRecordOperation(op_type, output_tensors, input_ids, backward_function);
+ std::vector<tensorflow::DataType> input_dtypes =
+ MakeTensorDtypeList(input_tensors);
+ if (PyErr_Occurred()) return;
+ TapeSetRecordOperation(op_type, output_tensors, input_ids, input_dtypes,
+ backward_function);
}
void TFE_Py_TapeSetDeleteTrace(tensorflow::int64 tensor_id) {
@@ -1710,10 +1753,12 @@ PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs,
PyObject* results, PyObject* name) {
std::vector<tensorflow::int64> input_ids = MakeTensorIDList(inputs);
if (PyErr_Occurred()) return nullptr;
+ std::vector<tensorflow::DataType> input_dtypes = MakeTensorDtypeList(inputs);
+ if (PyErr_Occurred()) return nullptr;
bool should_record = false;
for (TFE_Py_Tape* tape : SafeTapeSet()) {
- if (tape->tape->ShouldRecord(input_ids)) {
+ if (tape->tape->ShouldRecord(input_ids, input_dtypes)) {
should_record = true;
break;
}
@@ -1744,7 +1789,8 @@ PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs,
Py_DECREF(callback_args);
if (backward_function == nullptr) return nullptr;
- TapeSetRecordOperation(op_name, results, input_ids, backward_function);
+ TapeSetRecordOperation(op_name, results, input_ids, input_dtypes,
+ backward_function);
Py_DECREF(backward_function);
diff --git a/tensorflow/python/estimator/BUILD b/tensorflow/python/estimator/BUILD
index 56dec1eaa1..2d9a084bc6 100644
--- a/tensorflow/python/estimator/BUILD
+++ b/tensorflow/python/estimator/BUILD
@@ -91,6 +91,7 @@ py_library(
"//tensorflow/python:training",
"//tensorflow/python:util",
"//tensorflow/python/saved_model:signature_constants",
+ "//tensorflow/python/saved_model:tag_constants",
"@six_archive//:six",
],
)
@@ -488,6 +489,7 @@ py_library(
py_test(
name = "estimator_test",
srcs = ["estimator_test.py"],
+ shard_count = 4,
srcs_version = "PY2AND3",
tags = ["notsan"], # b/67510291
deps = [
diff --git a/tensorflow/python/estimator/canned/dnn.py b/tensorflow/python/estimator/canned/dnn.py
index 973a6ec747..e7fbf8eb72 100644
--- a/tensorflow/python/estimator/canned/dnn.py
+++ b/tensorflow/python/estimator/canned/dnn.py
@@ -154,6 +154,59 @@ def _dnn_model_fn(features,
Raises:
ValueError: If features has the wrong type.
"""
+ tpu_estimator_spec = _tpu_dnn_model_fn(
+ features=features,
+ labels=labels,
+ mode=mode,
+ head=head,
+ hidden_units=hidden_units,
+ feature_columns=feature_columns,
+ optimizer=optimizer,
+ activation_fn=activation_fn,
+ dropout=dropout,
+ input_layer_partitioner=input_layer_partitioner,
+ config=config)
+ return tpu_estimator_spec.as_estimator_spec()
+
+
+def _tpu_dnn_model_fn(features,
+ labels,
+ mode,
+ head,
+ hidden_units,
+ feature_columns,
+ optimizer='Adagrad',
+ activation_fn=nn.relu,
+ dropout=None,
+ input_layer_partitioner=None,
+ config=None):
+ """Deep Neural Net model_fn for TPUEstimator.
+
+ Args:
+ features: dict of `Tensor`.
+ labels: `Tensor` of shape [batch_size, 1] or [batch_size] labels of
+ dtype `int32` or `int64` in the range `[0, n_classes)`.
+ mode: Defines whether this is training, evaluation or prediction.
+ See `ModeKeys`.
+ head: A `head_lib._Head` instance.
+ hidden_units: Iterable of integer number of hidden units per layer.
+ feature_columns: Iterable of `feature_column._FeatureColumn` model inputs.
+ optimizer: String, `tf.Optimizer` object, or callable that creates the
+ optimizer to use for training. If not specified, will use the Adagrad
+ optimizer with a default learning rate of 0.05.
+ activation_fn: Activation function applied to each layer.
+ dropout: When not `None`, the probability we will drop out a given
+ coordinate.
+ input_layer_partitioner: Partitioner for input layer. Defaults
+ to `min_max_variable_partitioner` with `min_slice_size` 64 << 20.
+ config: `RunConfig` object to configure the runtime settings.
+
+ Returns:
+ A `model_fn.TPUEstimatorSpec` instance.
+
+ Raises:
+ ValueError: If features has the wrong type.
+ """
if not isinstance(features, dict):
raise ValueError('features should be a dictionary of `Tensor`s. '
'Given type: {}'.format(type(features)))
@@ -182,7 +235,7 @@ def _dnn_model_fn(features,
input_layer_partitioner=input_layer_partitioner)
logits = logit_fn(features=features, mode=mode)
- return head.create_estimator_spec(
+ return head._create_tpu_estimator_spec( # pylint: disable=protected-access
features=features,
mode=mode,
labels=labels,
@@ -320,17 +373,8 @@ class DNNClassifier(estimator.Estimator):
loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how
to reduce training loss over batch. Defaults to `SUM`.
"""
- if n_classes == 2:
- head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( # pylint: disable=protected-access
- weight_column=weight_column,
- label_vocabulary=label_vocabulary,
- loss_reduction=loss_reduction)
- else:
- head = head_lib._multi_class_head_with_softmax_cross_entropy_loss( # pylint: disable=protected-access
- n_classes, weight_column=weight_column,
- label_vocabulary=label_vocabulary,
- loss_reduction=loss_reduction)
-
+ head = head_lib._binary_logistic_or_multi_class_head( # pylint: disable=protected-access
+ n_classes, weight_column, label_vocabulary, loss_reduction)
def _model_fn(features, labels, mode, config):
"""Call the defined shared _dnn_model_fn."""
return _dnn_model_fn(
diff --git a/tensorflow/python/estimator/canned/dnn_testing_utils.py b/tensorflow/python/estimator/canned/dnn_testing_utils.py
index 62b13c3200..06a648777f 100644
--- a/tensorflow/python/estimator/canned/dnn_testing_utils.py
+++ b/tensorflow/python/estimator/canned/dnn_testing_utils.py
@@ -134,7 +134,7 @@ def mock_head(testcase, hidden_units, logits_dimension, expected_logits):
hidden_weights_names + hidden_biases_names +
[LOGITS_WEIGHTS_NAME + '/part_0:0', LOGITS_BIASES_NAME + '/part_0:0'])
- def _create_estimator_spec(
+ def _create_tpu_estimator_spec(
features, mode, logits, labels, train_op_fn=None, optimizer=None):
del features, labels # Not used.
trainable_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
@@ -149,19 +149,29 @@ def mock_head(testcase, hidden_units, logits_dimension, expected_logits):
train_op = train_op_fn(loss)
elif optimizer is not None:
train_op = optimizer.minimize(loss, global_step=None)
- return model_fn.EstimatorSpec(
+ return model_fn._TPUEstimatorSpec(
mode=mode, loss=loss, train_op=train_op)
elif mode == model_fn.ModeKeys.EVAL:
- return model_fn.EstimatorSpec(mode=mode, loss=array_ops.identity(loss))
+ return model_fn._TPUEstimatorSpec(
+ mode=mode, loss=array_ops.identity(loss))
elif mode == model_fn.ModeKeys.PREDICT:
- return model_fn.EstimatorSpec(
+ return model_fn._TPUEstimatorSpec(
mode=mode, predictions={'logits': array_ops.identity(logits)})
else:
testcase.fail('Invalid mode: {}'.format(mode))
+ def _create_estimator_spec(
+ features, mode, logits, labels, train_op_fn=None, optimizer=None):
+ tpu_spec = _create_tpu_estimator_spec(
+ features, mode, logits, labels, train_op_fn, optimizer)
+ return tpu_spec.as_estimator_spec()
+
head = test.mock.NonCallableMagicMock(spec=head_lib._Head)
head.logits_dimension = logits_dimension
- head.create_estimator_spec = test.mock.MagicMock(wraps=_create_estimator_spec)
+ head._create_tpu_estimator_spec = test.mock.MagicMock(
+ wraps=_create_tpu_estimator_spec)
+ head.create_estimator_spec = test.mock.MagicMock(
+ wraps=_create_estimator_spec)
return head
diff --git a/tensorflow/python/estimator/canned/head.py b/tensorflow/python/estimator/canned/head.py
index 48f448d7f5..232637314d 100644
--- a/tensorflow/python/estimator/canned/head.py
+++ b/tensorflow/python/estimator/canned/head.py
@@ -32,6 +32,7 @@ from tensorflow.python.feature_column import feature_column as feature_column_li
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
@@ -69,6 +70,35 @@ def _summary_key(head_name, val):
return '%s/%s' % (val, head_name) if head_name else val
+def _create_eval_metrics_tuple(fn, kwargs):
+ """Creates TPU eval metrics tuple.
+
+ Helper function to make eval_metric tuple (eval_metric_fn, fn_kwargs) used
+ by `TPUEstimator`. TPUEstimator requires that `eval_metric_fn` take
+  exclusively Tensor arguments. This helper creates such a function from a
+  more generic function that can take both Tensor and non-Tensor arguments.
+
+ Args:
+    fn: An eval_metric_fn that takes both Tensor and non-Tensor arguments.
+ This function must return a dict of form
+ {'metric name': (metric_tensor, eval_op)}
+ kwargs: Dict of arguments for `fn`.
+
+ Returns:
+ `eval_metric` tuple that can be passed to a `model_fn._TPUEstimatorSpec`.
+ """
+ tensor_kwargs = {}
+ nontensor_kwargs = {}
+ for k, v in six.iteritems(kwargs):
+ if tensor_util.is_tensor(v):
+ tensor_kwargs[k] = v
+ else:
+ nontensor_kwargs[k] = v
+ def _fn(**tensors):
+ return fn(**dict(nontensor_kwargs, **tensors))
+ return (_fn, tensor_kwargs)
+
+
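A standalone sketch of the tensor/non-tensor split the helper above performs; the names below are invented for illustration, and the `isinstance` check stands in for the `tensor_util.is_tensor` call used by the real helper:

```python
import tensorflow as tf

def create_eval_metrics_tuple_sketch(fn, kwargs):
  # Tensors become the fn_kwargs that TPUEstimator feeds back in; everything
  # else is bound up front via a closure, so _fn takes only Tensor arguments.
  tensor_kwargs = {k: v for k, v in kwargs.items() if isinstance(v, tf.Tensor)}
  nontensor_kwargs = {k: v for k, v in kwargs.items() if k not in tensor_kwargs}

  def _fn(**tensors):
    return fn(**dict(nontensor_kwargs, **tensors))

  return _fn, tensor_kwargs

def metric_fn(labels, predictions, prefix):
  # `prefix` is a plain Python string and must not cross the TPU boundary.
  return {prefix + '/accuracy': tf.metrics.accuracy(labels, predictions)}

eval_fn, tensor_args = create_eval_metrics_tuple_sketch(
    metric_fn,
    {'labels': tf.constant([1, 0]),
     'predictions': tf.constant([1, 1]),
     'prefix': 'eval'})
metrics = eval_fn(**tensor_args)  # {'eval/accuracy': (value_op, update_op)}
```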
class _Head(object):
"""Interface for the head/top of a model.
@@ -174,7 +204,6 @@ class _Head(object):
# TODO(b/65403806): By default, collect regularization_losses from
# GraphKeys.REGULARIZATION_LOSSES collection.
- @abc.abstractmethod
def create_estimator_spec(
self, features, mode, logits, labels=None, optimizer=None,
train_op_fn=None, regularization_losses=None):
@@ -203,7 +232,47 @@ class _Head(object):
Returns:
`EstimatorSpec`.
"""
- raise NotImplementedError('Calling an abstract method.')
+ try:
+ tpu_estimator_spec = (
+ self._create_tpu_estimator_spec(
+ features, mode, logits, labels, optimizer, train_op_fn,
+ regularization_losses))
+ return tpu_estimator_spec.as_estimator_spec()
+ except NotImplementedError:
+ # Not all subclasses of _Head will have implemented
+ # _create_tpu_estimator_spec. If it is implemented, we can use it to
+ # create our `EstimatorSpec` here.
+ raise NotImplementedError(
+ 'Subclasses of _Head must implement `create_estimator_spec()` or '
+ '_create_tpu_estimator_spec().')
+
+ def _create_tpu_estimator_spec(
+ self, features, mode, logits, labels=None, optimizer=None,
+ train_op_fn=None, regularization_losses=None):
+ """Returns `model_fn._TPUEstimatorSpec` that a model_fn can return.
+
+ Args:
+ features: Input `dict` of `Tensor` or `SparseTensor` objects.
+ mode: Estimator's `ModeKeys`.
+ logits: logits `Tensor` to be used by the head.
+ labels: Labels `Tensor`, or `dict` of same.
+ optimizer: `Optimizer` instance to optimize the loss in TRAIN mode.
+ Namely, sets `train_op = optimizer.minimize(loss, global_step)`, which
+ updates variables and increments `global_step`.
+ train_op_fn: Function that takes a scalar loss `Tensor` and returns an op
+ to optimize the model with the loss in TRAIN mode. Used if `optimizer`
+ is `None`. Exactly one of `train_op_fn` and `optimizer` must be set in
+ TRAIN mode. None is allowed in other modes. If you want to optimize loss
+ yourself you can pass `lambda _: tf.no_op()` and then use
+ EstimatorSpec.loss to compute and apply gradients.
+ regularization_losses: A list of additional scalar losses to be added to
+ the training loss, such as regularization losses.
+
+ Returns:
+      A `model_fn._TPUEstimatorSpec` instance.
+ """
+ raise NotImplementedError(
+ 'TPUEstimatorSpec not available for this model head.')
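A head subclass therefore only needs to override `_create_tpu_estimator_spec`; the base-class `create_estimator_spec` above converts the result for the non-TPU path. A minimal sketch of that contract, mirroring the `_HeadWithTPUSupport` test case added further down in this diff and using the same private module paths as head_test.py:

```python
from tensorflow.python.estimator import model_fn
from tensorflow.python.estimator.canned import head as head_lib
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes

class _TpuOnlyHead(head_lib._Head):
  """Overrides only the TPU spec; create_estimator_spec comes from the base."""

  def name(self):
    return 'TpuOnlyHead'

  def logits_dimension(self):
    return 1

  def create_loss(self, features, mode, logits, labels):
    return None

  def _create_tpu_estimator_spec(self, features, mode, logits, labels=None,
                                 optimizer=None, train_op_fn=None,
                                 regularization_losses=None):
    return model_fn._TPUEstimatorSpec(
        mode=model_fn.ModeKeys.EVAL,
        loss=constant_op.constant(0.0, dtype=dtypes.float32))

# _TpuOnlyHead().create_estimator_spec(features=None, mode=None, logits=None)
# returns a plain EstimatorSpec via as_estimator_spec(), with no TPU-specific
# code in the subclass.
```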
def _check_dense_labels_match_logits_and_reshape(
@@ -702,10 +771,10 @@ class _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head):
weights=weights,
processed_labels=label_ids)
- def create_estimator_spec(
+ def _create_tpu_estimator_spec(
self, features, mode, logits, labels=None, optimizer=None,
train_op_fn=None, regularization_losses=None):
- """Returns an `EstimatorSpec`.
+ """Returns a `model_fn._TPUEstimatorSpec`.
Args:
features: Input `dict` of `Tensor` or `SparseTensor` objects.
@@ -727,7 +796,7 @@ class _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head):
`loss_reduction=SUM_OVER_NONZERO_WEIGHTS` when creating the head to
avoid scaling errors.
Returns:
- `EstimatorSpec`.
+ A `model_fn._TPUEstimatorSpec` instance.
Raises:
ValueError: If both `train_op_fn` and `optimizer` are `None` in TRAIN
mode, or if both are set.
@@ -761,7 +830,7 @@ class _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head):
classifier_output = _classification_output(
scores=probabilities, n_classes=self._n_classes,
label_vocabulary=self._label_vocabulary)
- return model_fn.EstimatorSpec(
+ return model_fn._TPUEstimatorSpec( # pylint: disable=protected-access
mode=model_fn.ModeKeys.PREDICT,
predictions=predictions,
export_outputs={
@@ -781,16 +850,17 @@ class _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head):
regularized_training_loss = training_loss
# Eval.
if mode == model_fn.ModeKeys.EVAL:
- return model_fn.EstimatorSpec(
+ return model_fn._TPUEstimatorSpec( # pylint: disable=protected-access
mode=model_fn.ModeKeys.EVAL,
predictions=predictions,
loss=regularized_training_loss,
- eval_metric_ops=self._eval_metric_ops(
- labels=label_ids,
- class_ids=class_ids,
- weights=weights,
- unreduced_loss=unreduced_loss,
- regularization_loss=regularization_loss))
+ eval_metrics=_create_eval_metrics_tuple(self._eval_metric_ops, {
+ 'labels': label_ids,
+ 'class_ids': class_ids,
+ 'weights': weights,
+ 'unreduced_loss': unreduced_loss,
+ 'regularization_loss': regularization_loss
+ }))
# Train.
if optimizer is not None:
@@ -824,7 +894,7 @@ class _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head):
summary.scalar(
_summary_key(self._name, keys.LOSS_REGULARIZATION),
regularization_loss)
- return model_fn.EstimatorSpec(
+ return model_fn._TPUEstimatorSpec( # pylint: disable=protected-access
mode=model_fn.ModeKeys.TRAIN,
predictions=predictions,
loss=regularized_training_loss,
@@ -1060,7 +1130,7 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head):
weights=weights,
processed_labels=labels)
- def create_estimator_spec(
+ def _create_tpu_estimator_spec(
self, features, mode, logits, labels=None, optimizer=None,
train_op_fn=None, regularization_losses=None):
"""Returns an `EstimatorSpec`.
@@ -1122,7 +1192,7 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head):
classifier_output = _classification_output(
scores=probabilities, n_classes=2,
label_vocabulary=self._label_vocabulary)
- return model_fn.EstimatorSpec(
+ return model_fn._TPUEstimatorSpec( # pylint: disable=protected-access
mode=model_fn.ModeKeys.PREDICT,
predictions=predictions,
export_outputs={
@@ -1146,18 +1216,22 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head):
# Eval.
if mode == model_fn.ModeKeys.EVAL:
- return model_fn.EstimatorSpec(
+ return model_fn._TPUEstimatorSpec( # pylint: disable=protected-access
mode=model_fn.ModeKeys.EVAL,
predictions=predictions,
loss=regularized_training_loss,
- eval_metric_ops=self._eval_metric_ops(
- labels=processed_labels,
- logits=logits,
- logistic=logistic,
- class_ids=class_ids,
- weights=weights,
- unreduced_loss=unreduced_loss,
- regularization_loss=regularization_loss))
+ eval_metrics=_create_eval_metrics_tuple(
+ self._eval_metric_ops,
+ {
+ 'labels': processed_labels,
+ 'logits': logits,
+ 'logistic': logistic,
+ 'class_ids': class_ids,
+ 'weights': weights,
+ 'unreduced_loss': unreduced_loss,
+ 'regularization_loss': regularization_loss
+ }
+ ))
# Train.
if optimizer is not None:
@@ -1190,7 +1264,7 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head):
summary.scalar(
_summary_key(self._name, keys.LOSS_REGULARIZATION),
regularization_loss)
- return model_fn.EstimatorSpec(
+ return model_fn._TPUEstimatorSpec( # pylint: disable=protected-access
mode=model_fn.ModeKeys.TRAIN,
predictions=predictions,
loss=regularized_training_loss,
@@ -1322,7 +1396,25 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head):
weights=weights,
processed_labels=labels)
- def create_estimator_spec(
+ def _eval_metric_ops(self, weights, unreduced_loss, regularization_loss):
+ """Returns the Eval metric ops."""
+ keys = metric_keys.MetricKeys
+ # Estimator already adds a metric for loss.
+ eval_metric_ops = {
+ _summary_key(self._name, keys.LOSS_MEAN):
+ metrics_lib.mean(
+ values=unreduced_loss,
+ weights=weights)
+ }
+ if regularization_loss is not None:
+ regularization_loss_key = _summary_key(
+ self._name, keys.LOSS_REGULARIZATION)
+ eval_metric_ops[regularization_loss_key] = metrics_lib.mean(
+ values=regularization_loss,
+ name=keys.LOSS_REGULARIZATION)
+ return eval_metric_ops
+
+ def _create_tpu_estimator_spec(
self, features, mode, logits, labels=None, optimizer=None,
train_op_fn=None, regularization_losses=None):
"""Returns an `EstimatorSpec`.
@@ -1348,7 +1440,7 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head):
`loss_reduction=SUM_OVER_NONZERO_WEIGHTS` when creating the head to
avoid scaling errors.
Returns:
- `EstimatorSpec`.
+ A `model_fn._TPUEstimatorSpec` instance.
Raises:
ValueError: If both `train_op_fn` and `optimizer` are `None` in TRAIN
mode, or if both are set.
@@ -1369,7 +1461,7 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head):
if mode == model_fn.ModeKeys.PREDICT:
regression_output = export_output.RegressionOutput(
value=predicted_value)
- return model_fn.EstimatorSpec(
+ return model_fn._TPUEstimatorSpec( # pylint: disable=protected-access
mode=model_fn.ModeKeys.PREDICT,
predictions=predictions,
export_outputs={
@@ -1390,25 +1482,18 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head):
# Eval.
if mode == model_fn.ModeKeys.EVAL:
- keys = metric_keys.MetricKeys
- # Estimator already adds a metric for loss.
- eval_metric_ops = {
- _summary_key(self._name, keys.LOSS_MEAN):
- metrics_lib.mean(
- values=unreduced_loss,
- weights=weights)
- }
- if regularization_loss is not None:
- regularization_loss_key = _summary_key(
- self._name, keys.LOSS_REGULARIZATION)
- eval_metric_ops[regularization_loss_key] = metrics_lib.mean(
- values=regularization_loss,
- name=keys.LOSS_REGULARIZATION)
- return model_fn.EstimatorSpec(
+ return model_fn._TPUEstimatorSpec( # pylint: disable=protected-access
mode=model_fn.ModeKeys.EVAL,
predictions=predictions,
loss=regularized_training_loss,
- eval_metric_ops=eval_metric_ops)
+ eval_metrics=_create_eval_metrics_tuple(
+ self._eval_metric_ops,
+ {
+ 'weights': weights,
+ 'unreduced_loss': unreduced_loss,
+ 'regularization_loss': regularization_loss,
+ }
+ ))
# Train.
if optimizer is not None:
@@ -1441,7 +1526,7 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head):
summary.scalar(
_summary_key(self._name, keys.LOSS_REGULARIZATION),
regularization_loss)
- return model_fn.EstimatorSpec(
+ return model_fn._TPUEstimatorSpec( # pylint: disable=protected-access
mode=model_fn.ModeKeys.TRAIN,
predictions=predictions,
loss=regularized_training_loss,
@@ -1478,3 +1563,42 @@ def _weights(features, weight_column):
raise ValueError('Weight column should be castable to float. '
'Given dtype: {}'.format(weights.dtype))
return math_ops.to_float(weights, name='weights')
+
+
+def _binary_logistic_or_multi_class_head(
+ n_classes, weight_column, label_vocabulary, loss_reduction):
+ """Creates either binary or multi-class head.
+
+ Args:
+ n_classes: Number of label classes.
+ weight_column: A string or a `_NumericColumn` created by
+ `tf.feature_column.numeric_column` defining feature column representing
+ weights. It is used to down weight or boost examples during training. It
+ will be multiplied by the loss of the example. If it is a string, it is
+ used as a key to fetch weight tensor from the `features`. If it is a
+ `_NumericColumn`, raw tensor is fetched by key `weight_column.key`,
+ then weight_column.normalizer_fn is applied on it to get weight tensor.
+ label_vocabulary: A list of strings represents possible label values. If
+ given, labels must be string type and have any value in
+ `label_vocabulary`. If it is not given, that means labels are
+ already encoded as integer or float within [0, 1] for `n_classes=2` and
+ encoded as integer values in {0, 1,..., n_classes-1} for `n_classes`>2 .
+ Also there will be errors if vocabulary is not provided and labels are
+ string.
+ loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how
+ to reduce training loss over batch. Defaults to `SUM`.
+
+ Returns:
+ `head._Head` instance.
+ """
+ if n_classes == 2:
+ head = _binary_logistic_head_with_sigmoid_cross_entropy_loss(
+ weight_column=weight_column,
+ label_vocabulary=label_vocabulary,
+ loss_reduction=loss_reduction)
+ else:
+ head = _multi_class_head_with_softmax_cross_entropy_loss(
+ n_classes, weight_column=weight_column,
+ label_vocabulary=label_vocabulary,
+ loss_reduction=loss_reduction)
+ return head
diff --git a/tensorflow/python/estimator/canned/head_test.py b/tensorflow/python/estimator/canned/head_test.py
index 32a6339936..ecca3e8b0d 100644
--- a/tensorflow/python/estimator/canned/head_test.py
+++ b/tensorflow/python/estimator/canned/head_test.py
@@ -86,6 +86,98 @@ def _sigmoid(logits):
return 1 / (1 + np.exp(-logits))
+class CreateEstimatorSpecTest(test.TestCase):
+
+ class _HeadWithTPUSupport(head_lib._Head):
+ """Head that overrides _create_tpu_estimator_spec."""
+
+ def name(self):
+ return 'HeadWithTPUSupport'
+
+ def logits_dimension(self):
+ return None
+
+ def create_loss(self, features, mode, logits, labels):
+ return None
+
+ def _create_tpu_estimator_spec(self, features, mode, logits, labels=None,
+ optimizer=None, train_op_fn=None,
+ regularization_losses=None):
+ return model_fn._TPUEstimatorSpec(
+ mode=model_fn.ModeKeys.EVAL,
+ loss=constant_op.constant(0.0, dtype=dtypes.float32))
+
+ class _HeadWithOutTPUSupport(head_lib._Head):
+ """Head that overrides create_estimator_spec."""
+
+ def name(self):
+ return 'HeadWithOutTPUSupport'
+
+ def logits_dimension(self):
+ return None
+
+ def create_loss(self, features, mode, logits, labels):
+ return None
+
+ def create_estimator_spec(self, features, mode, logits, labels=None,
+ optimizer=None, train_op_fn=None,
+ regularization_losses=None):
+ return model_fn.EstimatorSpec(
+ mode=model_fn.ModeKeys.EVAL,
+ loss=constant_op.constant(0.0, dtype=dtypes.float32))
+
+ class _InvalidHead(head_lib._Head):
+ """Head that overrides neither estimator_spec functions."""
+
+ def name(self):
+ return 'InvalidHead'
+
+ def logits_dimension(self):
+ return None
+
+ def create_loss(self, features, mode, logits, labels):
+ return None
+
+ def test_head_override_tpu_estimator_spec(self):
+ """Test for `_Head` that overrides _create_tpu_estimator_spec."""
+ head = self._HeadWithTPUSupport()
+
+ tpu_spec = head._create_tpu_estimator_spec(
+ features=None, mode=None, logits=None)
+ self.assertTrue(isinstance(tpu_spec, model_fn._TPUEstimatorSpec))
+ est_spec = head.create_estimator_spec(
+ features=None, mode=None, logits=None)
+ self.assertTrue(isinstance(est_spec, model_fn.EstimatorSpec))
+
+ def test_head_override_estimator_spec(self):
+ """Test for `_Head` that overrides create_estimator_spec."""
+ head = self._HeadWithOutTPUSupport()
+
+ with self.assertRaisesRegexp(
+ NotImplementedError,
+ 'TPUEstimatorSpec not available for this model head.'):
+ _ = head._create_tpu_estimator_spec(
+ features=None, mode=None, logits=None)
+ est_spec = head.create_estimator_spec(
+ features=None, mode=None, logits=None)
+ self.assertTrue(isinstance(est_spec, model_fn.EstimatorSpec))
+
+ def test_invalid_head_class(self):
+ head = self._InvalidHead()
+
+ with self.assertRaisesRegexp(
+ NotImplementedError,
+ 'TPUEstimatorSpec not available for this model head.'):
+ _ = head._create_tpu_estimator_spec(
+ features=None, mode=None, logits=None)
+ with self.assertRaisesRegexp(
+ NotImplementedError,
+ r'Subclasses of _Head must implement `create_estimator_spec\(\)` or '
+ r'_create_tpu_estimator_spec\(\).'):
+ _ = head.create_estimator_spec(
+ features=None, mode=None, logits=None)
+
+
class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
def setUp(self):
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py
index cc8023a5e7..64457eb1ff 100644
--- a/tensorflow/python/estimator/estimator.py
+++ b/tensorflow/python/estimator/estimator.py
@@ -37,9 +37,8 @@ from tensorflow.python.eager import context
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.estimator import run_config
from tensorflow.python.estimator import util
-from tensorflow.python.estimator.export.export import build_all_signature_defs
-from tensorflow.python.estimator.export.export import get_temp_export_dir
-from tensorflow.python.estimator.export.export import get_timestamped_export_dir
+from tensorflow.python.estimator.export import export as export_helpers
+from tensorflow.python.estimator.export import export_output
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.ops import array_ops
@@ -51,7 +50,6 @@ from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import builder as saved_model_builder
from tensorflow.python.saved_model import constants
-from tensorflow.python.saved_model import tag_constants
from tensorflow.python.summary import summary
from tensorflow.python.summary.writer import writer_cache
from tensorflow.python.training import device_setter
@@ -609,73 +607,283 @@ class Estimator(object):
are provided, or no checkpoint can be found.
"""
# pylint: enable=line-too-long
+ return self._export_saved_model_for_mode(
+ export_dir_base,
+ serving_input_receiver_fn,
+ assets_extra=assets_extra,
+ as_text=as_text,
+ checkpoint_path=checkpoint_path,
+ strip_default_attrs=strip_default_attrs,
+ mode=model_fn_lib.ModeKeys.PREDICT)
+
+ def _export_all_saved_models(
+ self, export_dir_base, input_receiver_fn_map,
+ assets_extra=None,
+ as_text=False,
+ checkpoint_path=None,
+ strip_default_attrs=False):
+ # pylint: disable=line-too-long
+ """Exports requested train/eval/predict graphs as separate SavedModels.
+
+ This is a wrapper around export_saved_model_for_mode that accepts
+ multiple modes simultaneously and creates directories for each under
+ export_dir_base. See `Estimator.export_saved_model_for_mode` for
+ further details as to how the export works for each mode.
+
+ See tf.contrib.estimator.export_all_saved_models for the currently
+ exposed version of this function.
+
+ Args:
+ export_dir_base: A string containing a directory in which to create
+ timestamped subdirectories containing exported SavedModels.
+ input_receiver_fn_map: dict of tf.estimator.ModeKeys to input_receiver_fn
+ mappings, where the input_receiver_fn is a function that takes no
+ argument and returns the appropriate subclass of `InputReceiver`.
+ assets_extra: A dict specifying how to populate the assets.extra directory
+ within the exported SavedModel, or `None` if no extra assets are needed.
+ as_text: whether to write the SavedModel proto in text format.
+ checkpoint_path: The checkpoint path to export. If `None` (the default),
+ the most recent checkpoint found within the model directory is chosen.
+ strip_default_attrs: Boolean. If `True`, default-valued attributes will be
+ removed from the NodeDefs. For a detailed guide, see
+ [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
+
+ Returns:
+ A dict of tf.estimator.ModeKeys value to string path for each exported
+ directory.
+
+ Raises:
+ ValueError: if any input_receiver_fn is None, no export_outputs
+ are provided, or no checkpoint can be found.
+ """
+ # pylint: enable=line-too-long
+ # TODO(b/65561022): Consider allowing multiple input_receiver_fns per mode.
+ exported = {}
+ for mode, input_receiver_fn in input_receiver_fn_map.items():
+ export_mode_dir = os.path.join(
+ compat.as_bytes(export_dir_base),
+ compat.as_bytes(mode))
+ gfile.MakeDirs(export_mode_dir)
+
+ exported_path = self._export_saved_model_for_mode(
+ export_mode_dir,
+ input_receiver_fn,
+ assets_extra=assets_extra,
+ as_text=as_text,
+ checkpoint_path=checkpoint_path,
+ strip_default_attrs=strip_default_attrs,
+ mode=mode)
+
+ exported[mode] = exported_path
+
+ return exported
+
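A hedged sketch of the mode-keyed map this method iterates over; the estimator instance, export directory, and the TRAIN/EVAL receiver entries are placeholders, and the public wrapper at this point lives in tf.contrib.estimator as noted in the docstring:

```python
import tensorflow as tf

def serving_input_receiver_fn():
  # PREDICT-mode receiver: features only, no labels.
  features = {'x': tf.placeholder(tf.float32, [None, 1], name='x')}
  return tf.estimator.export.ServingInputReceiver(features, features)

input_receiver_fn_map = {
    tf.estimator.ModeKeys.PREDICT: serving_input_receiver_fn,
    # TRAIN/EVAL entries would use receivers that also expose `labels`,
    # which _add_meta_graph_and_variables_for_mode picks up via getattr.
}

# Assuming `estimator` is an already-trained tf.estimator.Estimator:
# exported = estimator._export_all_saved_models(
#     '/tmp/export_base', input_receiver_fn_map)
# `exported` then maps each ModeKeys value to its timestamped SavedModel dir.
```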
+ def _export_saved_model_for_mode(
+ self, export_dir_base, input_receiver_fn,
+ assets_extra=None,
+ as_text=False,
+ checkpoint_path=None,
+ strip_default_attrs=False,
+ mode=model_fn_lib.ModeKeys.PREDICT):
+ # pylint: disable=line-too-long
+ """Exports a single train/eval/predict graph as a SavedModel.
+
+ For a detailed guide, see
+ @{$saved_model#using_savedmodel_with_estimators$Using SavedModel with Estimators}.
+
+ See tf.contrib.estimator.export_saved_model_for_mode for the currently
+ exposed version of this function.
+
+ This method takes an input_receiver_fn and mode. For the mode passed in,
+ this method builds a new graph by calling the input_receiver_fn to obtain
+ feature and label `Tensor`s. Next, this method calls the `Estimator`'s
+ model_fn in the passed mode to generate the model graph based on
+ those features and labels, and restores the given checkpoint
+ (or, lacking that, the most recent checkpoint) into the graph.
+ Finally, it creates a timestamped export directory below the
+ export_dir_base, and writes a `SavedModel` into it containing
+ the `MetaGraphDef` for the given mode and its associated signatures.
+
+ For prediction, the exported `MetaGraphDef` will provide one `SignatureDef`
+ for each element of the export_outputs dict returned from the model_fn,
+ named using the same keys. One of these keys is always
+ signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY, indicating which
+ signature will be served when a serving request does not specify one.
+ For each signature, the outputs are provided by the corresponding
+ `ExportOutput`s, and the inputs are always the input receivers provided by
+ the serving_input_receiver_fn.
+
+ For training and evaluation, the train_op is stored in an extra collection,
+ and loss, metrics, and predictions are included in a SignatureDef for the
+ mode in question.
+
+ Extra assets may be written into the SavedModel via the assets_extra
+ argument. This should be a dict, where each key gives a destination path
+ (including the filename) relative to the assets.extra directory. The
+ corresponding value gives the full path of the source file to be copied.
+ For example, the simple case of copying a single file without renaming it
+ is specified as `{'my_asset_file.txt': '/path/to/my_asset_file.txt'}`.
+
+ Args:
+ export_dir_base: A string containing a directory in which to create
+ timestamped subdirectories containing exported SavedModels.
+ input_receiver_fn: a function that takes no argument and
+ returns the appropriate subclass of `InputReceiver`.
+ assets_extra: A dict specifying how to populate the assets.extra directory
+ within the exported SavedModel, or `None` if no extra assets are needed.
+ as_text: whether to write the SavedModel proto in text format.
+ checkpoint_path: The checkpoint path to export. If `None` (the default),
+ the most recent checkpoint found within the model directory is chosen.
+ strip_default_attrs: Boolean. If `True`, default-valued attributes will be
+ removed from the NodeDefs. For a detailed guide, see
+ [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
+      mode: tf.estimator.ModeKeys value indicating which mode will be exported.
+
+ Returns:
+ The string path to the exported directory.
+
+ Raises:
+ ValueError: if input_receiver_fn is None, no export_outputs
+ are provided, or no checkpoint can be found.
+ """
+ # pylint: enable=line-too-long
with context.graph_mode():
- if serving_input_receiver_fn is None:
- raise ValueError('serving_input_receiver_fn must be defined.')
+ if not input_receiver_fn:
+ raise ValueError('An input_receiver_fn must be defined.')
- with ops.Graph().as_default() as g:
- self._create_and_assert_global_step(g)
- random_seed.set_random_seed(self._config.tf_random_seed)
- serving_input_receiver = serving_input_receiver_fn()
+ if not checkpoint_path:
+ # Locate the latest checkpoint
+ checkpoint_path = saver.latest_checkpoint(self._model_dir)
+ if not checkpoint_path:
+ raise ValueError("Couldn't find trained model at %s." % self._model_dir)
- # Call the model_fn and collect the export_outputs.
- estimator_spec = self._call_model_fn(
- features=serving_input_receiver.features,
- labels=None,
- mode=model_fn_lib.ModeKeys.PREDICT,
- config=self.config)
-
- # Build the SignatureDefs from receivers and all outputs
- signature_def_map = build_all_signature_defs(
- serving_input_receiver.receiver_tensors,
- estimator_spec.export_outputs,
- serving_input_receiver.receiver_tensors_alternatives)
-
- if not checkpoint_path:
- # Locate the latest checkpoint
- checkpoint_path = saver.latest_checkpoint(self._model_dir)
- if not checkpoint_path:
- raise ValueError(
- "Couldn't find trained model at %s." % self._model_dir)
-
- export_dir = get_timestamped_export_dir(export_dir_base)
- temp_export_dir = get_temp_export_dir(export_dir)
-
- # TODO(soergel): Consider whether MonitoredSession makes sense here
- with tf_session.Session(config=self._session_config) as session:
-
- saver_for_restore = estimator_spec.scaffold.saver or saver.Saver(
- sharded=True)
- saver_for_restore.restore(session, checkpoint_path)
-
- local_init_op = (
- estimator_spec.scaffold.local_init_op or
- monitored_session.Scaffold.default_local_init_op())
-
- # Perform the export
- builder = saved_model_builder.SavedModelBuilder(temp_export_dir)
- builder.add_meta_graph_and_variables(
- session, [tag_constants.SERVING],
- signature_def_map=signature_def_map,
- assets_collection=ops.get_collection(
- ops.GraphKeys.ASSET_FILEPATHS),
- legacy_init_op=local_init_op,
- strip_default_attrs=strip_default_attrs)
- builder.save(as_text)
-
- # Add the extra assets
- if assets_extra:
- assets_extra_path = os.path.join(compat.as_bytes(temp_export_dir),
- compat.as_bytes('assets.extra'))
- for dest_relative, source in assets_extra.items():
- dest_absolute = os.path.join(compat.as_bytes(assets_extra_path),
- compat.as_bytes(dest_relative))
- dest_path = os.path.dirname(dest_absolute)
- gfile.MakeDirs(dest_path)
- gfile.Copy(source, dest_absolute)
-
- gfile.Rename(temp_export_dir, export_dir)
- return export_dir
+ export_dir = export_helpers.get_timestamped_export_dir(export_dir_base)
+ temp_export_dir = export_helpers.get_temp_export_dir(export_dir)
+
+ builder = saved_model_builder.SavedModelBuilder(temp_export_dir)
+
+ self._add_meta_graph_and_variables_for_mode(
+ builder, input_receiver_fn, checkpoint_path,
+ strip_default_attrs, mode)
+
+ builder.save(as_text)
+
+ # Add the extra assets
+ if assets_extra:
+ assets_extra_path = os.path.join(compat.as_bytes(temp_export_dir),
+ compat.as_bytes('assets.extra'))
+ for dest_relative, source in assets_extra.items():
+ dest_absolute = os.path.join(compat.as_bytes(assets_extra_path),
+ compat.as_bytes(dest_relative))
+ dest_path = os.path.dirname(dest_absolute)
+ gfile.MakeDirs(dest_path)
+ gfile.Copy(source, dest_absolute)
+
+ gfile.Rename(temp_export_dir, export_dir)
+ return export_dir
+
+ def _add_meta_graph_and_variables_for_mode(
+ self, builder, input_receiver_fn, checkpoint_path, strip_default_attrs,
+ mode=model_fn_lib.ModeKeys.PREDICT):
+ # pylint: disable=line-too-long
+ """Loads variables and adds them along with a MetaGraphDef for saving.
+
+ Args:
+ builder: instance of SavedModelBuilder that will be used for saving.
+ input_receiver_fn: a function that takes no argument and
+ returns the appropriate subclass of `InputReceiver`.
+ checkpoint_path: The checkpoint path to export. If `None` (the default),
+ the most recent checkpoint found within the model directory is chosen.
+ strip_default_attrs: Boolean. If `True`, default-valued attributes will be
+ removed from the NodeDefs. For a detailed guide, see
+ [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
+ mode: tf.estimator.ModeKeys value indicating which mode will be exported.
+ """
+ # pylint: enable=line-too-long
+ with ops.Graph().as_default() as g:
+ self._create_and_assert_global_step(g)
+ random_seed.set_random_seed(self._config.tf_random_seed)
+
+ input_receiver = input_receiver_fn()
+
+ # Call the model_fn and collect the export_outputs.
+ estimator_spec = self._call_model_fn(
+ features=input_receiver.features,
+ labels=getattr(input_receiver, 'labels', None),
+ mode=mode,
+ config=self.config)
+
+ export_outputs = self._get_export_outputs_for_spec(estimator_spec)
+
+ # Build the SignatureDefs from receivers and all outputs
+ signature_def_map = export_helpers.build_all_signature_defs(
+ input_receiver.receiver_tensors,
+ export_outputs,
+ getattr(input_receiver, 'receiver_tensors_alternatives', None),
+ serving_only=(mode == model_fn_lib.ModeKeys.PREDICT))
+
+ with tf_session.Session(config=self._session_config) as session:
+
+ export_tags = model_fn_lib.EXPORT_TAG_MAP[mode]
+
+ local_init_op = (
+ estimator_spec.scaffold.local_init_op or
+ monitored_session.Scaffold.default_local_init_op())
+
+ saver_for_restore = estimator_spec.scaffold.saver or saver.Saver(
+ sharded=True)
+ saver_for_restore.restore(session, checkpoint_path)
+
+ # We add the train op explicitly for now, so that we don't have to
+ # change the Builder public interface. Note that this is a no-op
+ # for prediction, where train_op is None.
+ builder._add_train_op(estimator_spec.train_op) # pylint: disable=protected-access
+
+ builder.add_meta_graph_and_variables(
+ session,
+ tags=export_tags,
+ signature_def_map=signature_def_map,
+ assets_collection=ops.get_collection(
+ ops.GraphKeys.ASSET_FILEPATHS),
+ strip_default_attrs=strip_default_attrs,
+ legacy_init_op=local_init_op)
+
+ def _get_export_outputs_for_spec(self, estimator_spec):
+ """Given an EstimatorSpec, determine what our export outputs should be.
+
+ EstimatorSpecs contain export_outputs that are used for serving, but for
+ training and eval graphs, we must wrap the tensors of interest in
+ appropriate ExportOutput objects.
+
+ Args:
+ estimator_spec: EstimatorSpec object that will be exported.
+
+ Returns:
+ a dict mapping export_output_name to ExportOutput object.
+
+ Raises:
+ ValueError: if an appropriate ExportOutput cannot be found for the
+ passed EstimatorSpec.mode
+ """
+ mode = estimator_spec.mode
+ if mode == model_fn_lib.ModeKeys.PREDICT:
+ outputs = estimator_spec.export_outputs
+ else:
+ if mode == model_fn_lib.ModeKeys.TRAIN:
+ output_class = export_output.TrainOutput
+ elif mode == model_fn_lib.ModeKeys.EVAL:
+ output_class = export_output.EvalOutput
+ else:
+ raise ValueError(
+ 'Export output type not found for mode: {}'.format(mode))
+
+ export_out = output_class(
+ loss=estimator_spec.loss,
+ predictions=estimator_spec.predictions,
+ metrics=estimator_spec.eval_metric_ops)
+ outputs = {mode: export_out}
+
+ return outputs
def _get_features_from_input_fn(self, input_fn, mode):
"""Extracts the `features` from return values of `input_fn`."""
@@ -1544,3 +1752,5 @@ def _get_default_warm_start_settings(warm_start_from):
else:
raise ValueError('warm_start_from must be a string or a WarmStartSettings, '
'instead got {}'.format(type(warm_start_from)))
+
+
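For orientation before the test changes below, here is a rough usage sketch of the per-mode export path added above. The model_fn, input_fn, and receiver-fn names are hypothetical, and both export methods are private/experimental:

    est = estimator.Estimator(model_fn=my_model_fn)
    est.train(input_fn=my_train_input_fn, steps=1)

    # Export a single mode under its tag set (TRAIN/EVAL/PREDICT)...
    export_dir = est._export_saved_model_for_mode(
        '/tmp/export', my_supervised_input_receiver_fn,
        mode=model_fn_lib.ModeKeys.EVAL)

    # ...or several modes at once, keyed by ModeKeys.
    export_dirs = est._export_all_saved_models(
        '/tmp/export', {
            model_fn_lib.ModeKeys.TRAIN: my_supervised_input_receiver_fn,
            model_fn_lib.ModeKeys.PREDICT: my_serving_input_receiver_fn,
        })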
diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py
index 76b45b7f57..02088e5134 100644
--- a/tensorflow/python/estimator/estimator_test.py
+++ b/tensorflow/python/estimator/estimator_test.py
@@ -1865,6 +1865,41 @@ def _model_fn_for_export_tests(features, labels, mode):
'test': export_output.ClassificationOutput(scores, classes)})
+def _x_y_input_fn():
+ return ({'x': constant_op.constant([[1], [1]]),
+ 'y': constant_op.constant([[2], [2]])},
+ constant_op.constant([[1], [1]]))
+
+
+def _model_fn_with_x_y(features, labels, mode):
+ _ = labels
+ variables.Variable(1., name='weight')
+ scores = constant_op.constant([3.])
+ classes = constant_op.constant(['wumpus'])
+ if mode == model_fn_lib.ModeKeys.PREDICT:
+ variables.Variable(36., name='name_collision')
+ return model_fn_lib.EstimatorSpec(
+ mode,
+ predictions=constant_op.constant(10.),
+ export_outputs={
+ 'test': export_output.ClassificationOutput(scores, classes)})
+ else:
+ prefix = 'eval_' if mode == model_fn_lib.ModeKeys.EVAL else ''
+
+ multiplied = math_ops.multiply(
+ features['x'], features['y'], name='{}multiplied'.format(prefix))
+ metrics = {'mean': metrics_lib.mean(features['x'] - features['y'],
+ name='{}mean'.format(prefix))}
+ variables.Variable(1., name='later_var')
+ variables.Variable(3., name='name_collision')
+ return model_fn_lib.EstimatorSpec(
+ mode,
+ predictions=multiplied,
+ loss=constant_op.constant(1.),
+ train_op=state_ops.assign_add(training.get_global_step(), 1),
+ eval_metric_ops=metrics)
+
+
def _model_fn_with_saveables_for_export_tests(features, labels, mode):
_, _ = features, labels
table = saver_test_utils.CheckpointedOp(name='v2')
@@ -1881,21 +1916,41 @@ def _model_fn_with_saveables_for_export_tests(features, labels, mode):
'test': export_output.PredictOutput({'prediction': prediction})})
+def _get_serving_input_receiver_fn():
+ feature_spec = {'x': parsing_ops.VarLenFeature(dtype=dtypes.int64),
+ 'y': parsing_ops.VarLenFeature(dtype=dtypes.int64)}
+ return export.build_parsing_serving_input_receiver_fn(feature_spec)
+
+
+def _get_supervised_input_receiver_fn():
+ feature_spec = {
+ 'x': array_ops.placeholder(
+ dtype=dtypes.int64, shape=(2, 1), name='feature_x'),
+ 'y': array_ops.placeholder(
+ dtype=dtypes.int64, shape=(2, 1), name='feature_y')
+ }
+ label_spec = array_ops.placeholder(
+ dtype=dtypes.float32, shape=[1], name='truth')
+
+ return export.build_raw_supervised_input_receiver_fn(feature_spec, label_spec)
+
+
_VOCAB_FILE_CONTENT = 'emerson\nlake\npalmer\n'
_EXTRA_FILE_CONTENT = 'kermit\npiggy\nralph\n'
class EstimatorExportTest(test.TestCase):
- def test_export_savedmodel_proto_roundtrip(self):
- tmpdir = tempfile.mkdtemp()
- est = estimator.Estimator(model_fn=_model_fn_for_export_tests)
- est.train(input_fn=dummy_input_fn, steps=1)
+ def test_export_savedmodel_proto_roundtrip_raw_receiver(self):
feature_spec = {'x': parsing_ops.VarLenFeature(dtype=dtypes.int64),
'y': parsing_ops.VarLenFeature(dtype=dtypes.int64)}
serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
feature_spec)
+ tmpdir = tempfile.mkdtemp()
+ est = estimator.Estimator(model_fn=_model_fn_for_export_tests)
+ est.train(input_fn=dummy_input_fn, steps=1)
+
# Perform the export.
export_dir_base = os.path.join(
compat.as_bytes(tmpdir), compat.as_bytes('export'))
@@ -1904,6 +1959,266 @@ class EstimatorExportTest(test.TestCase):
# Check that all the files are in the right places.
self.assertTrue(gfile.Exists(export_dir_base))
+ self._validate_exported_files(export_dir)
+
+ # Restore, to validate that the export was well-formed.
+ with ops.Graph().as_default() as graph:
+ with session.Session(graph=graph) as sess:
+ loader.load(sess, [tag_constants.SERVING], export_dir)
+ graph_ops = [x.name for x in graph.get_operations()]
+ self.assertTrue('input_example_tensor' in graph_ops)
+ self.assertTrue('ParseExample/ParseExample' in graph_ops)
+ self.assertTrue('weight' in graph_ops)
+
+ def test_export_saved_model_train(self):
+ self._test_export_saved_model_for_mode(
+ _get_supervised_input_receiver_fn(), model_fn_lib.ModeKeys.TRAIN)
+
+ def test_export_saved_model_eval(self):
+ self._test_export_saved_model_for_mode(
+ _get_supervised_input_receiver_fn(), model_fn_lib.ModeKeys.EVAL)
+
+ def test_export_saved_model_predict(self):
+ self._test_export_saved_model_for_mode(
+ _get_serving_input_receiver_fn(), model_fn_lib.ModeKeys.PREDICT)
+
+ def _test_export_saved_model_for_mode(self, input_receiver_fn, mode):
+ tmpdir = tempfile.mkdtemp()
+ est = estimator.Estimator(model_fn=_model_fn_for_export_tests)
+ est.train(input_fn=_x_y_input_fn, steps=1)
+
+ # Perform the export.
+ export_dir_base = os.path.join(
+ compat.as_bytes(tmpdir), compat.as_bytes('export'))
+ export_dir = est._export_saved_model_for_mode(
+ export_dir_base, input_receiver_fn, mode=mode)
+
+ # Check that all the files are in the right places.
+ self.assertTrue(gfile.Exists(export_dir_base))
+ self._validate_exported_files(export_dir)
+
+ # Restore, to validate that the export was well-formed.
+ tag_set = model_fn_lib.EXPORT_TAG_MAP[mode]
+ with ops.Graph().as_default() as graph:
+ with session.Session(graph=graph) as sess:
+ loader.load(sess, tag_set, export_dir)
+ graph_ops = [x.name for x in graph.get_operations()]
+ self.assertFalse('name_collision_1' in graph_ops)
+ self.assertTrue('weight' in graph_ops)
+
+ # Clean up.
+ gfile.DeleteRecursively(tmpdir)
+
+ def test_export_all_saved_models_proto_roundtrip_receiver_map(self):
+ input_receiver_fn_map = {
+ model_fn_lib.ModeKeys.PREDICT: _get_serving_input_receiver_fn()
+ }
+ export_dirs, tmpdir = self._test_export_all_saved_models(
+ input_receiver_fn_map)
+
+ self.assertEqual(len(export_dirs), 1)
+ # Restore, to validate that the export was well-formed.
+ export_dir = export_dirs[model_fn_lib.ModeKeys.PREDICT]
+ with ops.Graph().as_default() as graph:
+ with session.Session(graph=graph) as sess:
+ loader.load(sess, [tag_constants.SERVING], export_dir)
+ graph_ops = [x.name for x in graph.get_operations()]
+ self.assertTrue('input_example_tensor' in graph_ops)
+ self.assertTrue('ParseExample/ParseExample' in graph_ops)
+ self.assertFalse('feature_x' in graph_ops)
+ self.assertTrue('weight' in graph_ops)
+
+ # Clean up.
+ gfile.DeleteRecursively(tmpdir)
+
+ def test_export_all_saved_models_proto_roundtrip_train_only(self):
+ input_receiver_fn_map = {
+ model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(),
+ }
+ export_dirs, tmpdir = self._test_export_all_saved_models(
+ input_receiver_fn_map)
+
+ self.assertEqual(len(export_dirs), 1)
+ # Restore, to validate that the export was well-formed.
+ export_dir = export_dirs[model_fn_lib.ModeKeys.TRAIN]
+ with ops.Graph().as_default() as graph:
+ with session.Session(graph=graph) as sess:
+ loader.load(sess, [tag_constants.TRAINING], export_dir)
+ graph_ops = [x.name for x in graph.get_operations()]
+ self.assertTrue('multiplied' in graph_ops)
+ self.assertTrue('mean/update_op' in graph_ops)
+ self.assertFalse('eval_multiplied' in graph_ops)
+ self.assertTrue('feature_x' in graph_ops)
+ self.assertTrue('weight' in graph_ops)
+
+ # Clean up.
+ gfile.DeleteRecursively(tmpdir)
+
+ def test_export_all_saved_models_proto_roundtrip_eval_only(self):
+ input_receiver_fn_map = {
+ model_fn_lib.ModeKeys.EVAL: _get_supervised_input_receiver_fn()
+ }
+ export_dirs, tmpdir = self._test_export_all_saved_models(
+ input_receiver_fn_map)
+
+ self.assertEqual(len(export_dirs), 1)
+ # Restore, to validate that the export was well-formed.
+ export_dir = export_dirs[model_fn_lib.ModeKeys.EVAL]
+ with ops.Graph().as_default() as graph:
+ with session.Session(graph=graph) as sess:
+ loader.load(sess, [tag_constants.EVAL], export_dir)
+ graph_ops = [x.name for x in graph.get_operations()]
+ self.assertTrue('eval_multiplied' in graph_ops)
+ self.assertTrue('eval_mean/value' in graph_ops)
+ self.assertFalse('multiplied' in graph_ops)
+ self.assertTrue('feature_x' in graph_ops)
+ self.assertTrue('weight' in graph_ops)
+
+ # Clean up.
+ gfile.DeleteRecursively(tmpdir)
+
+ def test_export_all_saved_models_proto_roundtrip_no_serving(self):
+ input_receiver_fn_map = {
+ model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(),
+ model_fn_lib.ModeKeys.EVAL: _get_supervised_input_receiver_fn()
+ }
+ export_dirs, tmpdir = self._test_export_all_saved_models(
+ input_receiver_fn_map)
+
+ self.assertEqual(len(export_dirs), 2)
+ # Restore, to validate that the export was well-formed.
+ export_dir = export_dirs[model_fn_lib.ModeKeys.TRAIN]
+ with ops.Graph().as_default() as graph:
+ with session.Session(graph=graph) as sess:
+ loader.load(sess, [tag_constants.TRAINING], export_dir)
+ graph_ops = [x.name for x in graph.get_operations()]
+ self.assertTrue('multiplied' in graph_ops)
+ self.assertFalse('eval_multiplied' in graph_ops)
+ self.assertTrue('feature_x' in graph_ops)
+ self.assertTrue('weight' in graph_ops)
+ export_dir = export_dirs[model_fn_lib.ModeKeys.EVAL]
+ with ops.Graph().as_default() as graph:
+ with session.Session(graph=graph) as sess:
+ loader.load(sess, [tag_constants.EVAL], export_dir)
+ graph_ops = [x.name for x in graph.get_operations()]
+ self.assertTrue('eval_multiplied' in graph_ops)
+ self.assertFalse('multiplied' in graph_ops)
+ # TODO(karmel): is this the desired behavior when names are shared?
+ self.assertTrue('feature_x_1' in graph_ops)
+ self.assertTrue('feature_y_1' in graph_ops)
+ self.assertTrue('weight' in graph_ops)
+
+ # Clean up.
+ gfile.DeleteRecursively(tmpdir)
+
+ def test_export_all_saved_models_proto_roundtrip_three_defs(self):
+ input_receiver_fn_map = {
+ model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(),
+ model_fn_lib.ModeKeys.EVAL: _get_supervised_input_receiver_fn(),
+ model_fn_lib.ModeKeys.PREDICT: _get_serving_input_receiver_fn()
+ }
+ export_dirs, tmpdir = self._test_export_all_saved_models(
+ input_receiver_fn_map)
+
+ # Restore, to validate that the export was well-formed.
+ for mode, tag_set in model_fn_lib.EXPORT_TAG_MAP.items():
+ export_dir = export_dirs[mode]
+ with ops.Graph().as_default() as graph:
+ with session.Session(graph=graph) as sess:
+ loader.load(sess, tag_set, export_dir)
+ graph_ops = [x.name for x in graph.get_operations()]
+ self.assertTrue('global_step/Assign' in graph_ops)
+ self.assertTrue('global_step/Initializer/zeros' in graph_ops)
+ self.assertTrue('weight' in graph_ops)
+
+ # Clean up.
+ gfile.DeleteRecursively(tmpdir)
+
+ def test_export_all_saved_models_proto_roundtrip_all_vars(self):
+ input_receiver_fn_map = {
+ model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(),
+ model_fn_lib.ModeKeys.PREDICT: _get_serving_input_receiver_fn()
+ }
+ export_dirs, tmpdir = self._test_export_all_saved_models(
+ input_receiver_fn_map)
+
+ export_dir = export_dirs[model_fn_lib.ModeKeys.TRAIN]
+ with ops.Graph().as_default() as graph:
+ with session.Session(graph=graph) as sess:
+ loader.load(sess, [tag_constants.TRAINING], export_dir)
+ graph_ops = [x.name for x in graph.get_operations()]
+ self.assertTrue('later_var' in graph_ops)
+ self.assertTrue('weight' in graph_ops)
+
+ export_dir = export_dirs[model_fn_lib.ModeKeys.PREDICT]
+ with ops.Graph().as_default() as graph:
+ with session.Session(graph=graph) as sess:
+ loader.load(sess, [tag_constants.SERVING], export_dir)
+ graph_ops = [x.name for x in graph.get_operations()]
+ self.assertFalse('later_var' in graph_ops)
+ self.assertTrue('weight' in graph_ops)
+
+ # Clean up.
+ gfile.DeleteRecursively(tmpdir)
+
+ def test_export_all_saved_models_name_collision(self):
+ input_receiver_fn_map = {
+ model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(),
+ model_fn_lib.ModeKeys.PREDICT: _get_serving_input_receiver_fn()
+ }
+ export_dirs, tmpdir = self._test_export_all_saved_models(
+ input_receiver_fn_map)
+
+ export_dir = export_dirs[model_fn_lib.ModeKeys.TRAIN]
+ with ops.Graph().as_default() as graph:
+ with session.Session(graph=graph) as sess:
+ loader.load(sess, [tag_constants.TRAINING], export_dir)
+ graph_ops = [x.name for x in graph.get_operations()]
+ self.assertTrue('name_collision' in graph_ops)
+ self.assertFalse('name_collision_1' in graph_ops)
+ collection_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
+ self.assertEqual(3, collection_vars[-1].eval())
+
+ export_dir = export_dirs[model_fn_lib.ModeKeys.PREDICT]
+ with ops.Graph().as_default() as graph:
+ with session.Session(graph=graph) as sess:
+ loader.load(sess, [tag_constants.SERVING], export_dir)
+ graph_ops = [x.name for x in graph.get_operations()]
+ self.assertTrue('name_collision' in graph_ops)
+ self.assertFalse('name_collision_1' in graph_ops)
+ collection_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
+ # This is a non-obvious detail: when we load the estimator spec
+ # for predict, name_collision gets set to 36. However, we then restore
+ # from checkpoint, which should overwrite that var and make it the 3
+ # from training. In practice, this would not be a good way to write
+ # a model_fn, but leaving this check in for now to ensure consistency
+ # with what would happen given our current order of spec, then
+ # checkpoint.
+ self.assertEqual(3, collection_vars[-1].eval())
+
+ # Clean up.
+ gfile.DeleteRecursively(tmpdir)
+
+ def _test_export_all_saved_models(self, input_receiver_fn_map):
+ tmpdir = tempfile.mkdtemp()
+ est = estimator.Estimator(model_fn=_model_fn_with_x_y)
+ est.train(input_fn=_x_y_input_fn, steps=1)
+
+ # Perform the export.
+ export_dir_base = os.path.join(
+ compat.as_bytes(tmpdir), compat.as_bytes('export'))
+ export_dirs = est._export_all_saved_models(
+ export_dir_base, input_receiver_fn_map)
+
+ # Check that all the files are in the right places.
+ self.assertTrue(gfile.Exists(export_dir_base))
+
+ for _, export_dir in export_dirs.items():
+ self._validate_exported_files(export_dir)
+
+ return export_dirs, tmpdir
+
+ def _validate_exported_files(self, export_dir):
self.assertTrue(gfile.Exists(export_dir))
self.assertTrue(gfile.Exists(os.path.join(
compat.as_bytes(export_dir),
@@ -1918,18 +2233,6 @@ class EstimatorExportTest(test.TestCase):
compat.as_bytes(export_dir),
compat.as_bytes('variables/variables.data-00000-of-00001'))))
- # Restore, to validate that the export was well-formed.
- with ops.Graph().as_default() as graph:
- with session.Session(graph=graph) as sess:
- loader.load(sess, [tag_constants.SERVING], export_dir)
- graph_ops = [x.name for x in graph.get_operations()]
- self.assertTrue('input_example_tensor' in graph_ops)
- self.assertTrue('ParseExample/ParseExample' in graph_ops)
- self.assertTrue('weight' in graph_ops)
-
- # Clean up.
- gfile.DeleteRecursively(tmpdir)
-
def test_export_savedmodel_with_saveables_proto_roundtrip(self):
tmpdir = tempfile.mkdtemp()
est = estimator.Estimator(
@@ -2485,5 +2788,6 @@ class EstimatorIntegrationTest(test.TestCase):
serving_input_receiver_fn)
self.assertTrue(gfile.Exists(export_dir))
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/estimator/export/export.py b/tensorflow/python/estimator/export/export.py
index 41c1f5a2e2..9aafb56679 100644
--- a/tensorflow/python/estimator/export/export.py
+++ b/tensorflow/python/estimator/export/export.py
@@ -40,6 +40,60 @@ from tensorflow.python.util.tf_export import tf_export
_SINGLE_FEATURE_DEFAULT_NAME = 'feature'
_SINGLE_RECEIVER_DEFAULT_NAME = 'input'
+_SINGLE_LABEL_DEFAULT_NAME = 'label'
+
+
+def _wrap_and_check_receiver_tensors(receiver_tensors):
+ """Ensure that receiver_tensors is a dict of str to Tensor mappings.
+
+ Args:
+ receiver_tensors: dict of str to Tensors, or a single Tensor.
+
+ Returns:
+ dict of str to Tensors; this is the original dict if one was passed, or
+ the original tensor wrapped in a dictionary.
+
+ Raises:
+ ValueError: if receiver_tensors is None, or has non-string keys,
+ or non-Tensor values
+ """
+ if receiver_tensors is None:
+ raise ValueError('receiver_tensors must be defined.')
+ if not isinstance(receiver_tensors, dict):
+ receiver_tensors = {_SINGLE_RECEIVER_DEFAULT_NAME: receiver_tensors}
+ for name, tensor in receiver_tensors.items():
+ _check_tensor_key(name, error_label='receiver_tensors')
+ _check_tensor(tensor, name, error_label='receiver_tensor')
+ return receiver_tensors
+
+
+def _check_tensor(tensor, name, error_label='feature'):
+ """Check that passed `tensor` is a Tensor or SparseTensor."""
+ if not (isinstance(tensor, ops.Tensor)
+ or isinstance(tensor, sparse_tensor.SparseTensor)):
+ fmt_name = ' {}'.format(name) if name else ''
+ value_error = ValueError(
+ '{}{} must be a Tensor or SparseTensor.'.format(error_label, fmt_name))
+ # NOTE(ericmc): This if-else block is a specific carve-out for
+ # LabeledTensor, which has a `.tensor` attribute and which is
+ # convertible to tf.Tensor via ops.convert_to_tensor.
+ # Allowing all types convertible to tf.Tensor is considered by soergel@
+ # to be too permissive.
+ # TODO(soergel): accept any type convertible to Tensor,
+ # as in cl/193238295 snapshot #6.
+ if hasattr(tensor, 'tensor'):
+ try:
+ ops.convert_to_tensor(tensor)
+ except TypeError:
+ raise value_error
+ else:
+ raise value_error
+
+
+def _check_tensor_key(name, error_label='feature'):
+ if not isinstance(name, six.string_types):
+ raise ValueError(
+ '{} keys must be strings: {}.'.format(error_label, name))
@tf_export('estimator.export.ServingInputReceiver')
@@ -51,16 +105,18 @@ class ServingInputReceiver(collections.namedtuple(
The expected return values are:
features: A `Tensor`, `SparseTensor`, or dict of string to `Tensor` or
`SparseTensor`, specifying the features to be passed to the model.
- receiver_tensors: a `Tensor`, or a dict of string to `Tensor`, specifying
- input nodes where this receiver expects to be fed by default. Typically,
- this is a single placeholder expecting serialized `tf.Example` protos.
+ receiver_tensors: A `Tensor`, `SparseTensor`, or dict of string to `Tensor`
+ or `SparseTensor`, specifying input nodes where this receiver expects to
+ be fed by default. Typically, this is a single placeholder expecting
+ serialized `tf.Example` protos.
receiver_tensors_alternatives: a dict of string to additional
- groups of receiver tensors, each of which may be a `Tensor` or a dict of
- string to `Tensor`. These named receiver tensor alternatives generate
- additional serving signatures, which may be used to feed inputs at
- different points within the input receiver subgraph. A typical usage is
- to allow feeding raw feature `Tensor`s *downstream* of the
- tf.parse_example() op. Defaults to None.
+ groups of receiver tensors, each of which may be a `Tensor`,
+      `SparseTensor`, or dict of string to `Tensor` or `SparseTensor`.
+ These named receiver tensor alternatives generate additional serving
+ signatures, which may be used to feed inputs at different points within
+ the input receiver subgraph. A typical usage is to allow feeding raw
+ feature `Tensor`s *downstream* of the tf.parse_example() op.
+ Defaults to None.
"""
def __new__(cls, features, receiver_tensors,
@@ -70,36 +126,10 @@ class ServingInputReceiver(collections.namedtuple(
if not isinstance(features, dict):
features = {_SINGLE_FEATURE_DEFAULT_NAME: features}
for name, tensor in features.items():
- if not isinstance(name, six.string_types):
- raise ValueError('feature keys must be strings: {}.'.format(name))
- if not (isinstance(tensor, ops.Tensor)
- or isinstance(tensor, sparse_tensor.SparseTensor)):
- value_error = ValueError(
- 'feature {} must be a Tensor or SparseTensor.'.format(name))
- # NOTE(ericmc): This if-else block is a specific carve-out for
- # LabeledTensor, which has a `.tensor` attribute and which is
- # convertible to tf.Tensor via ops.convert_to_tensor.
- # Allowing all types convertible to tf.Tensor is considered by soergel@
- # to be too permissive.
- if hasattr(tensor, 'tensor'):
- try:
- ops.convert_to_tensor(tensor)
- except TypeError:
- raise value_error
- else:
- raise value_error
-
- if receiver_tensors is None:
- raise ValueError('receiver_tensors must be defined.')
- if not isinstance(receiver_tensors, dict):
- receiver_tensors = {_SINGLE_RECEIVER_DEFAULT_NAME: receiver_tensors}
- for name, tensor in receiver_tensors.items():
- if not isinstance(name, six.string_types):
- raise ValueError(
- 'receiver_tensors keys must be strings: {}.'.format(name))
- if not isinstance(tensor, ops.Tensor):
- raise ValueError(
- 'receiver_tensor {} must be a Tensor.'.format(name))
+ _check_tensor_key(name)
+ _check_tensor(tensor, name)
+
+ receiver_tensors = _wrap_and_check_receiver_tensors(receiver_tensors)
if receiver_tensors_alternatives is not None:
if not isinstance(receiver_tensors_alternatives, dict):
@@ -115,14 +145,9 @@ class ServingInputReceiver(collections.namedtuple(
receiver_tensors_alternatives[alternative_name] = (
receiver_tensors_alt)
for name, tensor in receiver_tensors_alt.items():
- if not isinstance(name, six.string_types):
- raise ValueError(
- 'receiver_tensors keys must be strings: {}.'.format(name))
- if not (isinstance(tensor, ops.Tensor)
- or isinstance(tensor, sparse_tensor.SparseTensor)):
- raise ValueError(
- 'receiver_tensor {} must be a Tensor or SparseTensor.'.format(
- name))
+ _check_tensor_key(name, error_label='receiver_tensors_alternative')
+ _check_tensor(
+ tensor, name, error_label='receiver_tensors_alternative')
return super(ServingInputReceiver, cls).__new__(
cls,
@@ -155,25 +180,25 @@ class TensorServingInputReceiver(collections.namedtuple(
The expected return values are:
features: A single `Tensor` or `SparseTensor`, representing the feature
to be passed to the model.
- receiver_tensors: a `Tensor`, or a dict of string to `Tensor`, specifying
- input nodes where this receiver expects to be fed by default. Typically,
- this is a single placeholder expecting serialized `tf.Example` protos.
+ receiver_tensors: A `Tensor`, `SparseTensor`, or dict of string to `Tensor`
+ or `SparseTensor`, specifying input nodes where this receiver expects to
+ be fed by default. Typically, this is a single placeholder expecting
+ serialized `tf.Example` protos.
receiver_tensors_alternatives: a dict of string to additional
- groups of receiver tensors, each of which may be a `Tensor` or a dict of
- string to `Tensor`. These named receiver tensor alternatives generate
- additional serving signatures, which may be used to feed inputs at
- different points within the input receiver subgraph. A typical usage is
- to allow feeding raw feature `Tensor`s *downstream* of the
- tf.parse_example() op. Defaults to None.
+ groups of receiver tensors, each of which may be a `Tensor`,
+      `SparseTensor`, or dict of string to `Tensor` or `SparseTensor`.
+ These named receiver tensor alternatives generate additional serving
+ signatures, which may be used to feed inputs at different points within
+ the input receiver subgraph. A typical usage is to allow feeding raw
+ feature `Tensor`s *downstream* of the tf.parse_example() op.
+ Defaults to None.
"""
def __new__(cls, features, receiver_tensors,
receiver_tensors_alternatives=None):
if features is None:
raise ValueError('features must be defined.')
- if not (isinstance(features, ops.Tensor)
- or isinstance(features, sparse_tensor.SparseTensor)):
- raise ValueError('feature must be a Tensor or SparseTensor.')
+ _check_tensor(features, None)
receiver = ServingInputReceiver(
features=features,
@@ -187,6 +212,49 @@ class TensorServingInputReceiver(collections.namedtuple(
receiver_tensors_alternatives=receiver.receiver_tensors_alternatives)
+class SupervisedInputReceiver(collections.namedtuple(
+ 'SupervisedInputReceiver',
+ ['features', 'labels', 'receiver_tensors'])):
+ """A return type for a training_input_receiver_fn or eval_input_receiver_fn.
+
+ This differs from a ServingInputReceiver in that (1) this receiver expects
+ a set of labels to be passed in with features, and (2) this receiver does
+ not support receiver_tensors_alternatives, which are primarily used for
+ serving.
+
+ The expected return values are:
+ features: A `Tensor`, `SparseTensor`, or dict of string to `Tensor` or
+ `SparseTensor`, specifying the features to be passed to the model.
+ labels: A `Tensor`, `SparseTensor`, or dict of string to `Tensor` or
+ `SparseTensor`, specifying the labels to be passed to the model.
+ receiver_tensors: A `Tensor`, `SparseTensor`, or dict of string to `Tensor`
+ or `SparseTensor`, specifying input nodes where this receiver expects to
+ be fed by default. Typically, this is a single placeholder expecting
+ serialized `tf.Example` protos.
+
+ """
+
+ def __new__(cls, features, labels, receiver_tensors):
+ # Both features and labels can be dicts or raw tensors.
+ for input_vals, error_label in ((features, 'feature'), (labels, 'label')):
+ if input_vals is None:
+ raise ValueError('{}s must be defined.'.format(error_label))
+ if isinstance(input_vals, dict):
+ for name, tensor in input_vals.items():
+ _check_tensor_key(name, error_label=error_label)
+ _check_tensor(tensor, name, error_label=error_label)
+ else:
+ _check_tensor(input_vals, None, error_label=error_label)
+
+ receiver_tensors = _wrap_and_check_receiver_tensors(receiver_tensors)
+
+ return super(SupervisedInputReceiver, cls).__new__(
+ cls,
+ features=features,
+ labels=labels,
+ receiver_tensors=receiver_tensors)
+
+
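A minimal construction sketch for the new receiver type, mirroring the tests in export_test.py further down; the placeholder names are illustrative:

    features = {'x': array_ops.placeholder(dtypes.int64, shape=(2, 1), name='x')}
    labels = array_ops.placeholder(dtypes.float32, shape=[1], name='truth')
    receiver_tensors = {'examples': array_ops.placeholder(dtypes.string)}
    receiver = SupervisedInputReceiver(features, labels, receiver_tensors)
    # features and labels may each be a dict or a single (Sparse)Tensor; a bare
    # Tensor passed as receiver_tensors is wrapped under the default key 'input'.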
@tf_export('estimator.export.build_parsing_serving_input_receiver_fn')
def build_parsing_serving_input_receiver_fn(feature_spec,
default_batch_size=None):
@@ -216,6 +284,23 @@ def build_parsing_serving_input_receiver_fn(feature_spec,
return serving_input_receiver_fn
+def _placeholder_from_tensor(t, default_batch_size=None):
+ shape_list = t.get_shape().as_list()
+ shape_list[0] = default_batch_size
+ shape = tensor_shape.TensorShape(shape_list)
+
+ # Reuse the feature tensor's op name (t.op.name) for the placeholder,
+ # excluding the index from the tensor's name (t.name):
+ # t.name = "%s:%d" % (t.op.name, t._value_index)
+ return array_ops.placeholder(dtype=t.dtype, shape=shape, name=t.op.name)
+
+
+def _placeholders_from_receiver_tensors_dict(
+ input_vals, default_batch_size=None):
+ return {name: _placeholder_from_tensor(t, default_batch_size)
+ for name, t in input_vals.items()}
+
+
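A rough illustration of the helper's behavior; the tensor is illustrative, and, like the receiver fns below, the placeholder is created in a fresh graph so the original op name is free:

    t = constant_op.constant([[1], [2]], name='feature_x')  # int32, shape (2, 1)
    with ops.Graph().as_default():
      ph = _placeholder_from_tensor(t, default_batch_size=None)
      # ph is an int32 placeholder named 'feature_x' with shape (None, 1):
      # the batch dimension is relaxed and the feature's op name is reused.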
@tf_export('estimator.export.build_raw_serving_input_receiver_fn')
def build_raw_serving_input_receiver_fn(features, default_batch_size=None):
"""Build a serving_input_receiver_fn expecting feature Tensors.
@@ -233,17 +318,9 @@ def build_raw_serving_input_receiver_fn(features, default_batch_size=None):
"""
def serving_input_receiver_fn():
"""A serving_input_receiver_fn that expects features to be fed directly."""
- receiver_tensors = {}
- for name, t in features.items():
- shape_list = t.get_shape().as_list()
- shape_list[0] = default_batch_size
- shape = tensor_shape.TensorShape(shape_list)
-
- # Reuse the feature tensor's op name (t.op.name) for the placeholder,
- # excluding the index from the tensor's name (t.name):
- # t.name = "%s:%d" % (t.op.name, t._value_index)
- receiver_tensors[name] = array_ops.placeholder(
- dtype=t.dtype, shape=shape, name=t.op.name)
+ receiver_tensors = _placeholders_from_receiver_tensors_dict(
+ features, default_batch_size)
+
# TODO(b/34885899): remove the unnecessary copy
# The features provided are simply the placeholders, but we defensively copy
# the dict because it may be mutated.
@@ -252,13 +329,100 @@ def build_raw_serving_input_receiver_fn(features, default_batch_size=None):
return serving_input_receiver_fn
+def build_raw_supervised_input_receiver_fn(
+ features, labels, default_batch_size=None):
+ """Build a supervised_input_receiver_fn for raw features and labels.
+
+  This function creates placeholders that mirror the provided features and
+  labels, and wraps them in a supervised_input_receiver_fn with the
+  expectation that the features and labels appear precisely as the model_fn
+  expects them. Features and labels can therefore be dicts of tensors, or raw
+  tensors.
+
+ Args:
+ features: a dict of string to `Tensor` or `Tensor`.
+ labels: a dict of string to `Tensor` or `Tensor`.
+ default_batch_size: the number of query examples expected per batch.
+ Leave unset for variable batch size (recommended).
+
+ Returns:
+ A supervised_input_receiver_fn.
+
+ Raises:
+ ValueError: if features and labels have overlapping keys.
+ """
+ # Check for overlapping keys before beginning.
+ try:
+ feat_keys = features.keys()
+ except AttributeError:
+ feat_keys = [_SINGLE_RECEIVER_DEFAULT_NAME]
+ try:
+ label_keys = labels.keys()
+ except AttributeError:
+ label_keys = [_SINGLE_LABEL_DEFAULT_NAME]
+
+ overlap_keys = set(feat_keys) & set(label_keys)
+ if overlap_keys:
+ raise ValueError('Features and labels must have distinct keys. '
+ 'Found overlapping keys: {}'.format(overlap_keys))
+
+ def supervised_input_receiver_fn():
+ """A receiver_fn that expects pass-through features and labels."""
+ if not isinstance(features, dict):
+ features_cp = _placeholder_from_tensor(features, default_batch_size)
+ receiver_features = {_SINGLE_RECEIVER_DEFAULT_NAME: features_cp}
+ else:
+ receiver_features = _placeholders_from_receiver_tensors_dict(
+ features, default_batch_size)
+ features_cp = receiver_features
+
+ if not isinstance(labels, dict):
+ labels_cp = _placeholder_from_tensor(labels, default_batch_size)
+ receiver_labels = {_SINGLE_LABEL_DEFAULT_NAME: labels_cp}
+ else:
+ receiver_labels = _placeholders_from_receiver_tensors_dict(
+ labels, default_batch_size)
+ labels_cp = receiver_labels
+
+ receiver_tensors = dict(receiver_features)
+ receiver_tensors.update(receiver_labels)
+ return SupervisedInputReceiver(features_cp, labels_cp, receiver_tensors)
+
+ return supervised_input_receiver_fn
+
+
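A usage sketch matching the tests in export_test.py below; the tensors are illustrative:

    features = {'x': constant_op.constant([[1], [1]]),
                'y': constant_op.constant([[2], [2]])}
    labels = constant_op.constant([[1], [1]])
    receiver_fn = build_raw_supervised_input_receiver_fn(features, labels)
    with ops.Graph().as_default():
      receiver = receiver_fn()
      # receiver.features is a dict of placeholders keyed 'x' and 'y';
      # receiver.labels is a single placeholder, exposed in
      # receiver.receiver_tensors under the default key 'label'.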
### Below utilities are specific to SavedModel exports.
def build_all_signature_defs(receiver_tensors,
export_outputs,
- receiver_tensors_alternatives=None):
- """Build `SignatureDef`s for all export outputs."""
+ receiver_tensors_alternatives=None,
+ serving_only=True):
+ """Build `SignatureDef`s for all export outputs.
+
+ Args:
+ receiver_tensors: a `Tensor`, or a dict of string to `Tensor`, specifying
+ input nodes where this receiver expects to be fed by default. Typically,
+ this is a single placeholder expecting serialized `tf.Example` protos.
+ export_outputs: a dict of ExportOutput instances, each of which has
+ an as_signature_def instance method that will be called to retrieve
+ the signature_def for all export output tensors.
+ receiver_tensors_alternatives: a dict of string to additional
+ groups of receiver tensors, each of which may be a `Tensor` or a dict of
+ string to `Tensor`. These named receiver tensor alternatives generate
+ additional serving signatures, which may be used to feed inputs at
+ different points within the input receiver subgraph. A typical usage is
+ to allow feeding raw feature `Tensor`s *downstream* of the
+ tf.parse_example() op. Defaults to None.
+ serving_only: boolean; if true, resulting signature defs will only include
+ valid serving signatures. If false, all requested signatures will be
+ returned.
+
+ Returns:
+    a dict of SignatureDefs representing all passed args.
+
+ Raises:
+ ValueError: if export_outputs is not a dict
+ """
if not isinstance(receiver_tensors, dict):
receiver_tensors = {_SINGLE_RECEIVER_DEFAULT_NAME: receiver_tensors}
if export_outputs is None or not isinstance(export_outputs, dict):
@@ -293,17 +457,24 @@ def build_all_signature_defs(receiver_tensors,
_log_signature_report(signature_def_map, excluded_signatures)
# The above calls to export_output.as_signature_def should return only
- # valid signatures; if there is a validity problem, they raise ValueError,
- # which we ignore above. Consequently the call to is_valid_signature here
- # should not remove anything else; it's just an extra sanity check.
- return {k: v for k, v in signature_def_map.items()
- if signature_def_utils.is_valid_signature(v)}
+  # valid signatures; if there is a validity problem, they raise a ValueError,
+  # in which case we exclude that signature from signature_def_map above.
+  # The is_valid_signature call below is an extra sanity check that the
+  # remaining signatures are valid for serving. We skip it for training and
+  # eval signatures, which are not intended for serving.
+ if serving_only:
+ signature_def_map = {k: v for k, v in signature_def_map.items()
+ if signature_def_utils.is_valid_signature(v)}
+ return signature_def_map
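A condensed sketch of the serving_only switch, following test_build_all_signature_defs_serving_only below (names as imported in that test module; receiver_tensor and output are illustrative):

    receiver_tensor = {'input': array_ops.placeholder(dtypes.string)}
    output = constant_op.constant([1.])
    export_outputs = {
        signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
            export_output.PredictOutput(outputs=output),
        'train': export_output.TrainOutput(loss=output),
    }
    # Default (serving_only=True): only the valid serving signature is kept.
    serving_defs = build_all_signature_defs(receiver_tensor, export_outputs)
    # serving_only=False additionally keeps the 'train' signature.
    all_defs = build_all_signature_defs(
        receiver_tensor, export_outputs, serving_only=False)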
_FRIENDLY_METHOD_NAMES = {
signature_constants.CLASSIFY_METHOD_NAME: 'Classify',
signature_constants.REGRESS_METHOD_NAME: 'Regress',
signature_constants.PREDICT_METHOD_NAME: 'Predict',
+ signature_constants.SUPERVISED_TRAIN_METHOD_NAME: 'Train',
+ signature_constants.SUPERVISED_EVAL_METHOD_NAME: 'Eval',
}
diff --git a/tensorflow/python/estimator/export/export_output.py b/tensorflow/python/estimator/export/export_output.py
index 87b964be37..d387ea2940 100644
--- a/tensorflow/python/estimator/export/export_output.py
+++ b/tensorflow/python/estimator/export/export_output.py
@@ -38,6 +38,8 @@ class ExportOutput(object):
__metaclass__ = abc.ABCMeta
+ _SEPARATOR_CHAR = '/'
+
@abc.abstractmethod
def as_signature_def(self, receiver_tensors):
"""Generate a SignatureDef proto for inclusion in a MetaGraphDef.
@@ -51,6 +53,52 @@ class ExportOutput(object):
"""
pass
+ def _check_output_key(self, key, error_label):
+ # For multi-head models, the key can be a tuple.
+ if isinstance(key, tuple):
+ key = self._SEPARATOR_CHAR.join(key)
+
+ if not isinstance(key, six.string_types):
+ raise ValueError(
+ '{} output key must be a string; got {}.'.format(error_label, key))
+ return key
+
+ def _wrap_and_check_outputs(
+ self, outputs, single_output_default_name, error_label=None):
+ """Wraps raw tensors as dicts and checks type.
+
+ Note that we create a new dict here so that we can overwrite the keys
+ if necessary.
+
+ Args:
+ outputs: A `Tensor` or a dict of string to `Tensor`.
+ single_output_default_name: A string key for use in the output dict
+ if the provided `outputs` is a raw tensor.
+      error_label: descriptive string for use in error messages. If None,
+        single_output_default_name will be used.
+
+ Returns:
+ A dict of tensors
+
+ Raises:
+ ValueError: if the outputs dict keys are not strings or tuples of strings
+ or the values are not Tensors.
+ """
+ if not isinstance(outputs, dict):
+ outputs = {single_output_default_name: outputs}
+
+ output_dict = {}
+ for key, value in outputs.items():
+ error_name = error_label or single_output_default_name
+ key = self._check_output_key(key, error_name)
+ if not isinstance(value, ops.Tensor):
+ raise ValueError(
+ '{} output value must be a Tensor; got {}.'.format(
+ error_name, value))
+
+ output_dict[key] = value
+ return output_dict
+
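A small sketch of what the shared wrapper produces, as used by the subclasses below; the tensor and the enclosing ExportOutput subclass are assumed:

    # inside an ExportOutput subclass:
    t = constant_op.constant([1.])
    self._wrap_and_check_outputs(t, 'output')
    # -> {'output': t}            (a bare value is keyed by the default name)
    self._wrap_and_check_outputs({('head1', 'scores'): t}, 'output')
    # -> {'head1/scores': t}      (tuple keys are joined with '/')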
@tf_export('estimator.export.ClassificationOutput')
class ClassificationOutput(ExportOutput):
@@ -154,9 +202,6 @@ class RegressionOutput(ExportOutput):
return signature_def_utils.regression_signature_def(examples, self.value)
-_SINGLE_OUTPUT_DEFAULT_NAME = 'output'
-
-
@tf_export('estimator.export.PredictOutput')
class PredictOutput(ExportOutput):
"""Represents the output of a generic prediction head.
@@ -165,6 +210,7 @@ class PredictOutput(ExportOutput):
Named outputs must be provided as a dict from string to `Tensor`,
"""
+ _SINGLE_OUTPUT_DEFAULT_NAME = 'output'
def __init__(self, outputs):
"""Constructor for PredictOutput.
@@ -177,16 +223,9 @@ class PredictOutput(ExportOutput):
ValueError: if the outputs is not dict, or any of its keys are not
strings, or any of its values are not `Tensor`s.
"""
- if not isinstance(outputs, dict):
- outputs = {_SINGLE_OUTPUT_DEFAULT_NAME: outputs}
- for key, value in outputs.items():
- if not isinstance(key, six.string_types):
- raise ValueError(
- 'Prediction output key must be a string; got {}.'.format(key))
- if not isinstance(value, ops.Tensor):
- raise ValueError(
- 'Prediction output value must be a Tensor; got {}.'.format(value))
- self._outputs = outputs
+
+ self._outputs = self._wrap_and_check_outputs(
+ outputs, self._SINGLE_OUTPUT_DEFAULT_NAME, error_label='Prediction')
@property
def outputs(self):
@@ -195,3 +234,161 @@ class PredictOutput(ExportOutput):
def as_signature_def(self, receiver_tensors):
return signature_def_utils.predict_signature_def(receiver_tensors,
self.outputs)
+
+
+class _SupervisedOutput(ExportOutput):
+ """Represents the output of a supervised training or eval process."""
+ __metaclass__ = abc.ABCMeta
+
+ LOSS_NAME = 'loss'
+ PREDICTIONS_NAME = 'predictions'
+ METRICS_NAME = 'metrics'
+
+ METRIC_VALUE_SUFFIX = 'value'
+ METRIC_UPDATE_SUFFIX = 'update_op'
+
+ _loss = None
+ _predictions = None
+ _metrics = None
+
+ def __init__(self, loss=None, predictions=None, metrics=None):
+    """Constructor for SupervisedOutput (i.e., Train or Eval output).
+
+ Args:
+ loss: dict of Tensors or single Tensor representing calculated loss.
+ predictions: dict of Tensors or single Tensor representing model
+ predictions.
+ metrics: dict of (metric_value, update_op) tuples, or a single tuple.
+ metric_value must be a Tensor, and update_op must be a Tensor or Op.
+
+ Raises:
+ ValueError: if any of the outputs' dict keys are not strings or tuples of
+ strings or the values are not Tensors (or Operations in the case of
+ update_op).
+ """
+
+ if loss is not None:
+ loss_dict = self._wrap_and_check_outputs(loss, self.LOSS_NAME)
+ self._loss = self._prefix_output_keys(loss_dict, self.LOSS_NAME)
+ if predictions is not None:
+ pred_dict = self._wrap_and_check_outputs(
+ predictions, self.PREDICTIONS_NAME)
+ self._predictions = self._prefix_output_keys(
+ pred_dict, self.PREDICTIONS_NAME)
+ if metrics is not None:
+ self._metrics = self._wrap_and_check_metrics(metrics)
+
+ def _prefix_output_keys(self, output_dict, output_name):
+    """Prepend output_name to the output_dict keys if not already present.
+
+ This produces predictable prefixes for the pre-determined outputs
+ of SupervisedOutput.
+
+ Args:
+ output_dict: dict of string to Tensor, assumed valid.
+ output_name: prefix string to prepend to existing keys.
+
+ Returns:
+ dict with updated keys and existing values.
+ """
+
+ new_outputs = {}
+ for key, val in output_dict.items():
+ key = self._prefix_key(key, output_name)
+ new_outputs[key] = val
+ return new_outputs
+
+ def _prefix_key(self, key, output_name):
+ if key.find(output_name) != 0:
+ key = output_name + self._SEPARATOR_CHAR + key
+ return key
+
+ def _wrap_and_check_metrics(self, metrics):
+ """Handle the saving of metrics.
+
+ Metrics is either a tuple of (value, update_op), or a dict of such tuples.
+ Here, we separate out the tuples and create a dict with names to tensors.
+
+ Args:
+ metrics: dict of (metric_value, update_op) tuples, or a single tuple.
+
+ Returns:
+ dict of output_names to tensors
+
+ Raises:
+ ValueError: if the dict key is not a string, or the metric values or ops
+ are not tensors.
+ """
+ if not isinstance(metrics, dict):
+ metrics = {self.METRICS_NAME: metrics}
+
+ outputs = {}
+ for key, (metric_val, metric_op) in metrics.items():
+ key = self._check_output_key(key, self.METRICS_NAME)
+ key = self._prefix_key(key, self.METRICS_NAME)
+
+ val_name = key + self._SEPARATOR_CHAR + self.METRIC_VALUE_SUFFIX
+ op_name = key + self._SEPARATOR_CHAR + self.METRIC_UPDATE_SUFFIX
+ if not isinstance(metric_val, ops.Tensor):
+ raise ValueError(
+ '{} output value must be a Tensor; got {}.'.format(
+ key, metric_val))
+ if (not isinstance(metric_op, ops.Tensor) and
+ not isinstance(metric_op, ops.Operation)):
+ raise ValueError(
+ '{} update_op must be a Tensor or Operation; got {}.'.format(
+ key, metric_op))
+ outputs[val_name] = metric_val
+ outputs[op_name] = metric_op
+
+ return outputs
+
+ @property
+ def loss(self):
+ return self._loss
+
+ @property
+ def predictions(self):
+ return self._predictions
+
+ @property
+ def metrics(self):
+ return self._metrics
+
+ @abc.abstractmethod
+ def _get_signature_def_fn(self):
+ """Returns a function that produces a SignatureDef given desired outputs."""
+ pass
+
+ def as_signature_def(self, receiver_tensors):
+ signature_def_fn = self._get_signature_def_fn()
+ return signature_def_fn(
+ receiver_tensors, self.loss, self.predictions, self.metrics)
+
+
+class TrainOutput(_SupervisedOutput):
+ """Represents the output of a supervised training process.
+
+ This class generates the appropriate signature def for exporting
+ training output by type-checking and wrapping loss, predictions, and metrics
+ values.
+ """
+
+ def _get_signature_def_fn(self):
+ return signature_def_utils.supervised_train_signature_def
+
+
+class EvalOutput(_SupervisedOutput):
+ """Represents the output of a supervised eval process.
+
+ This class generates the appropriate signature def for exporting
+ eval output by type-checking and wrapping loss, predictions, and metrics
+ values.
+ """
+
+ def _get_signature_def_fn(self):
+ return signature_def_utils.supervised_eval_signature_def
+
+
+
+
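A compact usage sketch of the new supervised outputs, following the tests below; the tensors are illustrative:

    loss = {'my_loss': constant_op.constant([0])}
    predictions = {'output1': constant_op.constant(['foo'])}
    metrics = {'accuracy': (constant_op.constant([0]), constant_op.constant([10]))}

    train_out = TrainOutput(loss, predictions, metrics)
    # Keys are prefixed predictably: 'loss/my_loss', 'predictions/output1',
    # 'metrics/accuracy/value' and 'metrics/accuracy/update_op'.
    receiver = {'features': constant_op.constant(100, shape=(100, 2))}
    sig_def = train_out.as_signature_def(receiver)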
diff --git a/tensorflow/python/estimator/export/export_output_test.py b/tensorflow/python/estimator/export/export_output_test.py
index 7090e53d80..b21ba91b0f 100644
--- a/tensorflow/python/estimator/export/export_output_test.py
+++ b/tensorflow/python/estimator/export/export_output_test.py
@@ -225,5 +225,115 @@ class ExportOutputTest(test.TestCase):
})
+class MockSupervisedOutput(export_output_lib._SupervisedOutput):
+ """So that we can test the abstract class methods directly."""
+
+ def _get_signature_def_fn(self):
+ pass
+
+
+class SupervisedOutputTest(test.TestCase):
+
+ def test_supervised_outputs_valid(self):
+ """Tests that no errors are raised when provided outputs are valid."""
+ loss = {"my_loss": constant_op.constant([0])}
+ predictions = {u"output1": constant_op.constant(["foo"])}
+ metrics = {"metrics": (constant_op.constant([0]),
+ constant_op.constant([10])),
+ "metrics2": (constant_op.constant([0]),
+ constant_op.constant([10]))}
+
+ outputter = MockSupervisedOutput(loss, predictions, metrics)
+ self.assertEqual(outputter.loss["loss/my_loss"], loss["my_loss"])
+ self.assertEqual(
+ outputter.predictions["predictions/output1"], predictions["output1"])
+ self.assertEqual(outputter.metrics["metrics/value"], metrics["metrics"][0])
+ self.assertEqual(
+ outputter.metrics["metrics2/update_op"], metrics["metrics2"][1])
+
+ # Single Tensor is OK too
+ outputter = MockSupervisedOutput(
+ loss["my_loss"], predictions["output1"], metrics["metrics"])
+ self.assertEqual(outputter.loss, {"loss": loss["my_loss"]})
+ self.assertEqual(
+ outputter.predictions, {"predictions": predictions["output1"]})
+ self.assertEqual(outputter.metrics["metrics/value"], metrics["metrics"][0])
+
+ def test_supervised_outputs_none(self):
+ outputter = MockSupervisedOutput(
+ constant_op.constant([0]), None, None)
+ self.assertEqual(len(outputter.loss), 1)
+ self.assertEqual(outputter.predictions, None)
+ self.assertEqual(outputter.metrics, None)
+
+ def test_supervised_outputs_invalid(self):
+ with self.assertRaisesRegexp(ValueError, "predictions output value must"):
+ MockSupervisedOutput(constant_op.constant([0]), [3], None)
+ with self.assertRaisesRegexp(ValueError, "loss output value must"):
+ MockSupervisedOutput("str", None, None)
+ with self.assertRaisesRegexp(ValueError, "metrics output value must"):
+ MockSupervisedOutput(None, None, (15.3, 4))
+ with self.assertRaisesRegexp(ValueError, "loss output key must"):
+ MockSupervisedOutput({25: "Tensor"}, None, None)
+
+ def test_supervised_outputs_tuples(self):
+    """Tests that tuple keys are joined with '/' in the output keys."""
+ loss = {("my", "loss"): constant_op.constant([0])}
+ predictions = {(u"output1", "2"): constant_op.constant(["foo"])}
+ metrics = {("metrics", "twice"): (constant_op.constant([0]),
+ constant_op.constant([10]))}
+
+ outputter = MockSupervisedOutput(loss, predictions, metrics)
+ self.assertEqual(set(outputter.loss.keys()), set(["loss/my/loss"]))
+ self.assertEqual(set(outputter.predictions.keys()),
+ set(["predictions/output1/2"]))
+ self.assertEqual(set(outputter.metrics.keys()),
+ set(["metrics/twice/value", "metrics/twice/update_op"]))
+
+ def test_supervised_outputs_no_prepend(self):
+    """Tests that keys already carrying the prefix are not prefixed again."""
+ loss = {"loss": constant_op.constant([0])}
+ predictions = {u"predictions": constant_op.constant(["foo"])}
+ metrics = {u"metrics": (constant_op.constant([0]),
+ constant_op.constant([10]))}
+
+ outputter = MockSupervisedOutput(loss, predictions, metrics)
+ self.assertEqual(set(outputter.loss.keys()), set(["loss"]))
+ self.assertEqual(set(outputter.predictions.keys()), set(["predictions"]))
+ self.assertEqual(set(outputter.metrics.keys()),
+ set(["metrics/value", "metrics/update_op"]))
+
+ def test_train_signature_def(self):
+ loss = {"my_loss": constant_op.constant([0])}
+ predictions = {u"output1": constant_op.constant(["foo"])}
+ metrics = {"metrics": (constant_op.constant([0]),
+ constant_op.constant([10]))}
+
+ outputter = export_output_lib.TrainOutput(loss, predictions, metrics)
+
+ receiver = {u"features": constant_op.constant(100, shape=(100, 2)),
+ "labels": constant_op.constant(100, shape=(100, 1))}
+ sig_def = outputter.as_signature_def(receiver)
+
+ self.assertTrue("loss/my_loss" in sig_def.outputs)
+ self.assertTrue("metrics/value" in sig_def.outputs)
+ self.assertTrue("predictions/output1" in sig_def.outputs)
+ self.assertTrue("features" in sig_def.inputs)
+
+ def test_eval_signature_def(self):
+ loss = {"my_loss": constant_op.constant([0])}
+ predictions = {u"output1": constant_op.constant(["foo"])}
+
+ outputter = export_output_lib.EvalOutput(loss, predictions, None)
+
+ receiver = {u"features": constant_op.constant(100, shape=(100, 2)),
+ "labels": constant_op.constant(100, shape=(100, 1))}
+ sig_def = outputter.as_signature_def(receiver)
+
+ self.assertTrue("loss/my_loss" in sig_def.outputs)
+ self.assertFalse("metrics/value" in sig_def.outputs)
+ self.assertTrue("predictions/output1" in sig_def.outputs)
+ self.assertTrue("features" in sig_def.inputs)
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/estimator/export/export_test.py b/tensorflow/python/estimator/export/export_test.py
index c203be7dac..0af587f2a8 100644
--- a/tensorflow/python/estimator/export/export_test.py
+++ b/tensorflow/python/estimator/export/export_test.py
@@ -54,7 +54,7 @@ ops.register_tensor_conversion_function(LabeledTensorMock,
_convert_labeled_tensor_mock_to_tensor)
-class ExportTest(test_util.TensorFlowTestCase):
+class ServingInputReceiverTest(test_util.TensorFlowTestCase):
def test_serving_input_receiver_constructor(self):
"""Tests that no errors are raised when input is expected."""
@@ -161,6 +161,165 @@ class ExportTest(test_util.TensorFlowTestCase):
with self.assertRaises(ValueError):
_ = export.ServingInputReceiver(feature, receiver_tensor)
+
+class SupervisedInputReceiverTest(test_util.TensorFlowTestCase):
+
+ def test_input_receiver_constructor(self):
+ """Tests that no errors are raised when input is expected."""
+ features = {
+ "feature0": constant_op.constant([0]),
+ u"feature1": constant_op.constant([1]),
+ "feature2": sparse_tensor.SparseTensor(
+ indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
+ }
+ labels = {
+ "classes": constant_op.constant([0] * 100),
+ }
+
+ receiver_tensors = {
+ "example0": array_ops.placeholder(dtypes.string, name="example0"),
+ u"example1": array_ops.placeholder(dtypes.string, name="example1"),
+ }
+ export.SupervisedInputReceiver(features, labels, receiver_tensors)
+
+ def test_input_receiver_raw_values(self):
+    """Tests that raw tensors are accepted as features and labels."""
+ features = {
+ "feature0": constant_op.constant([0]),
+ u"feature1": constant_op.constant([1]),
+ "feature2": sparse_tensor.SparseTensor(
+ indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
+ }
+
+ labels = {
+ "classes": constant_op.constant([0] * 100),
+ }
+
+ receiver_tensors = {
+ "example0": array_ops.placeholder(dtypes.string, name="example0"),
+ u"example1": array_ops.placeholder(dtypes.string, name="example1"),
+ }
+ rec = export.SupervisedInputReceiver(
+ features["feature2"], labels, receiver_tensors)
+ self.assertIsInstance(rec.features, sparse_tensor.SparseTensor)
+
+ rec = export.SupervisedInputReceiver(
+ features, labels["classes"], receiver_tensors)
+ self.assertIsInstance(rec.labels, ops.Tensor)
+
+ def test_input_receiver_features_invalid(self):
+ features = constant_op.constant([0] * 100)
+ labels = constant_op.constant([0])
+ receiver_tensors = {
+ "example0": array_ops.placeholder(dtypes.string, name="example0"),
+ u"example1": array_ops.placeholder(dtypes.string, name="example1"),
+ }
+
+ with self.assertRaisesRegexp(ValueError, "features must be defined"):
+ export.SupervisedInputReceiver(
+ features=None,
+ labels=labels,
+ receiver_tensors=receiver_tensors)
+
+ with self.assertRaisesRegexp(ValueError, "feature keys must be strings"):
+ export.SupervisedInputReceiver(
+ features={1: constant_op.constant([1])},
+ labels=labels,
+ receiver_tensors=receiver_tensors)
+
+ with self.assertRaisesRegexp(ValueError, "label keys must be strings"):
+ export.SupervisedInputReceiver(
+ features=features,
+ labels={1: constant_op.constant([1])},
+ receiver_tensors=receiver_tensors)
+
+ with self.assertRaisesRegexp(
+ ValueError, "feature feature1 must be a Tensor or SparseTensor"):
+ export.SupervisedInputReceiver(
+ features={"feature1": [1]},
+ labels=labels,
+ receiver_tensors=receiver_tensors)
+
+ with self.assertRaisesRegexp(
+ ValueError, "feature must be a Tensor or SparseTensor"):
+ export.SupervisedInputReceiver(
+ features=[1],
+ labels=labels,
+ receiver_tensors=receiver_tensors)
+
+ with self.assertRaisesRegexp(
+ ValueError, "label must be a Tensor or SparseTensor"):
+ export.SupervisedInputReceiver(
+ features=features,
+ labels=100,
+ receiver_tensors=receiver_tensors)
+
+ def test_input_receiver_receiver_tensors_invalid(self):
+ features = {
+ "feature0": constant_op.constant([0]),
+ u"feature1": constant_op.constant([1]),
+ "feature2": sparse_tensor.SparseTensor(
+ indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
+ }
+ labels = constant_op.constant([0])
+
+ with self.assertRaisesRegexp(
+ ValueError, "receiver_tensors must be defined"):
+ export.SupervisedInputReceiver(
+ features=features,
+ labels=labels,
+ receiver_tensors=None)
+
+ with self.assertRaisesRegexp(
+ ValueError, "receiver_tensors keys must be strings"):
+ export.SupervisedInputReceiver(
+ features=features,
+ labels=labels,
+ receiver_tensors={
+ 1: array_ops.placeholder(dtypes.string, name="example0")})
+
+ with self.assertRaisesRegexp(
+ ValueError, "receiver_tensor example1 must be a Tensor"):
+ export.SupervisedInputReceiver(
+ features=features,
+ labels=labels,
+ receiver_tensors={"example1": [1]})
+
+ def test_single_feature_single_receiver(self):
+ feature = constant_op.constant(5)
+ label = constant_op.constant(5)
+ receiver_tensor = array_ops.placeholder(dtypes.string)
+ input_receiver = export.SupervisedInputReceiver(
+ feature, label, receiver_tensor)
+
+ # single receiver is automatically named
+ receiver_key, = input_receiver.receiver_tensors.keys()
+ self.assertEqual("input", receiver_key)
+
+ def test_multi_feature_single_receiver(self):
+ features = {"foo": constant_op.constant(5),
+ "bar": constant_op.constant(6)}
+ labels = {"value": constant_op.constant(5)}
+ receiver_tensor = array_ops.placeholder(dtypes.string)
+ _ = export.SupervisedInputReceiver(features, labels, receiver_tensor)
+
+ def test_multi_feature_multi_receiver(self):
+ features = {"foo": constant_op.constant(5),
+ "bar": constant_op.constant(6)}
+ labels = {"value": constant_op.constant(5)}
+ receiver_tensors = {"baz": array_ops.placeholder(dtypes.int64),
+ "qux": array_ops.placeholder(dtypes.float32)}
+ _ = export.SupervisedInputReceiver(features, labels, receiver_tensors)
+
+ def test_feature_labeled_tensor(self):
+ feature = LabeledTensorMock()
+ label = constant_op.constant(5)
+ receiver_tensor = array_ops.placeholder(dtypes.string)
+ _ = export.SupervisedInputReceiver(feature, label, receiver_tensor)
+
+
+class ExportTest(test_util.TensorFlowTestCase):
+
def test_build_parsing_serving_input_receiver_fn(self):
feature_spec = {"int_feature": parsing_ops.VarLenFeature(dtypes.int64),
"float_feature": parsing_ops.VarLenFeature(dtypes.float32)}
@@ -237,6 +396,69 @@ class ExportTest(test_util.TensorFlowTestCase):
dtypes.int32,
serving_input_receiver.receiver_tensors["feature_2"].dtype)
+ def test_build_raw_supervised_input_receiver_fn(self):
+ features = {"feature_1": constant_op.constant(["hello"]),
+ "feature_2": constant_op.constant([42])}
+ labels = {"foo": constant_op.constant([5]),
+ "bar": constant_op.constant([6])}
+ input_receiver_fn = export.build_raw_supervised_input_receiver_fn(
+ features, labels)
+ with ops.Graph().as_default():
+ input_receiver = input_receiver_fn()
+ self.assertEqual(set(["feature_1", "feature_2"]),
+ set(input_receiver.features.keys()))
+ self.assertEqual(set(["foo", "bar"]),
+ set(input_receiver.labels.keys()))
+ self.assertEqual(set(["feature_1", "feature_2", "foo", "bar"]),
+ set(input_receiver.receiver_tensors.keys()))
+ self.assertEqual(
+ dtypes.string, input_receiver.receiver_tensors["feature_1"].dtype)
+ self.assertEqual(
+ dtypes.int32, input_receiver.receiver_tensors["feature_2"].dtype)
+
+ def test_build_raw_supervised_input_receiver_fn_raw_tensors(self):
+ features = {"feature_1": constant_op.constant(["hello"]),
+ "feature_2": constant_op.constant([42])}
+ labels = {"foo": constant_op.constant([5]),
+ "bar": constant_op.constant([6])}
+ input_receiver_fn1 = export.build_raw_supervised_input_receiver_fn(
+ features["feature_1"], labels)
+ input_receiver_fn2 = export.build_raw_supervised_input_receiver_fn(
+ features["feature_1"], labels["foo"])
+ with ops.Graph().as_default():
+ input_receiver = input_receiver_fn1()
+ self.assertIsInstance(input_receiver.features, ops.Tensor)
+ self.assertEqual(set(["foo", "bar"]),
+ set(input_receiver.labels.keys()))
+ self.assertEqual(set(["input", "foo", "bar"]),
+ set(input_receiver.receiver_tensors.keys()))
+
+ input_receiver = input_receiver_fn2()
+ self.assertIsInstance(input_receiver.features, ops.Tensor)
+ self.assertIsInstance(input_receiver.labels, ops.Tensor)
+ self.assertEqual(set(["input", "label"]),
+ set(input_receiver.receiver_tensors.keys()))
+
+ def test_build_raw_supervised_input_receiver_fn_batch_size(self):
+ features = {"feature_1": constant_op.constant(["hello"]),
+ "feature_2": constant_op.constant([42])}
+ labels = {"foo": constant_op.constant([5]),
+ "bar": constant_op.constant([6])}
+ input_receiver_fn = export.build_raw_supervised_input_receiver_fn(
+ features, labels, default_batch_size=10)
+ with ops.Graph().as_default():
+ input_receiver = input_receiver_fn()
+ self.assertEqual([10], input_receiver.receiver_tensors["feature_1"].shape)
+ self.assertEqual([10], input_receiver.features["feature_1"].shape)
+
+ def test_build_raw_supervised_input_receiver_fn_overlapping_keys(self):
+ features = {"feature_1": constant_op.constant(["hello"]),
+ "feature_2": constant_op.constant([42])}
+ labels = {"feature_1": constant_op.constant([5]),
+ "bar": constant_op.constant([6])}
+ with self.assertRaises(ValueError):
+ export.build_raw_supervised_input_receiver_fn(features, labels)
+
def test_build_all_signature_defs_without_receiver_alternatives(self):
receiver_tensor = array_ops.placeholder(dtypes.string)
output_1 = constant_op.constant([1.])
@@ -404,6 +626,35 @@ class ExportTest(test_util.TensorFlowTestCase):
self.assertTrue(int(time_1) < int(time_2))
self.assertTrue(int(time_2) < int(time_3))
+ def test_build_all_signature_defs_serving_only(self):
+ receiver_tensor = {"input": array_ops.placeholder(dtypes.string)}
+ output_1 = constant_op.constant([1.])
+ export_outputs = {
+ signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
+ export_output.PredictOutput(outputs=output_1),
+ "train": export_output.TrainOutput(loss=output_1),
+ }
+
+ signature_defs = export.build_all_signature_defs(
+ receiver_tensor, export_outputs)
+
+ expected_signature_defs = {
+ "serving_default": signature_def_utils.predict_signature_def(
+ receiver_tensor, {"output": output_1})
+ }
+
+ self.assertDictEqual(expected_signature_defs, signature_defs)
+
+ signature_defs = export.build_all_signature_defs(
+ receiver_tensor, export_outputs, serving_only=False)
+
+ expected_signature_defs.update({
+ "train": signature_def_utils.supervised_train_signature_def(
+ receiver_tensor, loss={"loss": output_1})
+ })
+
+ self.assertDictEqual(expected_signature_defs, signature_defs)
+
class TensorServingReceiverTest(test_util.TensorFlowTestCase):
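For context on the new API exercised by these tests, a minimal usage sketch follows. It assumes the internal export module path (tensorflow.python.estimator.export.export) and uses purely illustrative feature/label values; it is not part of this change.

from tensorflow.python.estimator.export import export
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops

# Illustrative raw features and labels (shapes/dtypes are placeholders only).
features = {"feature_1": constant_op.constant(["hello"]),
            "feature_2": constant_op.constant([42])}
labels = {"foo": constant_op.constant([5])}

input_receiver_fn = export.build_raw_supervised_input_receiver_fn(
    features, labels, default_batch_size=None)

with ops.Graph().as_default():
  receiver = input_receiver_fn()
  # Receiver tensors are keyed by the feature and label names, following the
  # behavior asserted in the tests above.
  print(sorted(receiver.receiver_tensors.keys()))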
diff --git a/tensorflow/python/estimator/model_fn.py b/tensorflow/python/estimator/model_fn.py
index 8111ab564c..3edf9fe940 100644
--- a/tensorflow/python/estimator/model_fn.py
+++ b/tensorflow/python/estimator/model_fn.py
@@ -28,6 +28,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.saved_model import signature_constants
+from tensorflow.python.saved_model import tag_constants
from tensorflow.python.training import monitored_session
from tensorflow.python.training import session_run_hook
from tensorflow.python.util import nest
@@ -53,6 +54,13 @@ class ModeKeys(object):
LOSS_METRIC_KEY = 'loss'
AVERAGE_LOSS_METRIC_KEY = 'average_loss'
+# Mapping of the modes to appropriate tag_constants that are used for saving.
+EXPORT_TAG_MAP = {
+ ModeKeys.PREDICT: [tag_constants.SERVING],
+ ModeKeys.TRAIN: [tag_constants.TRAINING],
+ ModeKeys.EVAL: [tag_constants.EVAL],
+}
+
@tf_export('estimator.EstimatorSpec')
class EstimatorSpec(
@@ -326,6 +334,57 @@ class EstimatorSpec(
return EstimatorSpec(*new_fields)
+class _TPUEstimatorSpec(collections.namedtuple('TPUEstimatorSpec', [
+ 'mode',
+ 'predictions',
+ 'loss',
+ 'train_op',
+ 'eval_metrics',
+ 'export_outputs',
+ 'scaffold_fn',
+ 'host_call'])):
+ """Ops and objects returned from a `model_fn` and passed to `TPUEstimator`.
+
+ This is a simplified implementation of `tf.contrib.tpu.TPUEstimatorSpec`. See
+ tensorflow/contrib/tpu/python/tpu/tpu_estimator.py for more detailed
+ documentation.
+ """
+
+ def __new__(cls,
+ mode,
+ predictions=None,
+ loss=None,
+ train_op=None,
+ eval_metrics=None,
+ export_outputs=None,
+ scaffold_fn=None,
+ host_call=None):
+ """Creates a `_TPUEstimatorSpec` instance."""
+ return super(_TPUEstimatorSpec, cls).__new__(cls,
+ mode=mode,
+ predictions=predictions,
+ loss=loss,
+ train_op=train_op,
+ eval_metrics=eval_metrics,
+ export_outputs=export_outputs,
+ scaffold_fn=scaffold_fn,
+ host_call=host_call)
+
+ def as_estimator_spec(self):
+ """Creates an equivalent `EstimatorSpec` used by CPU train/eval."""
+ if not self.eval_metrics:
+ eval_metric_ops = None
+ else:
+ metric_fn, tensors = self.eval_metrics
+ eval_metric_ops = metric_fn(**tensors)
+ return EstimatorSpec(mode=self.mode,
+ predictions=self.predictions,
+ loss=self.loss,
+ train_op=self.train_op,
+ eval_metric_ops=eval_metric_ops,
+ export_outputs=self.export_outputs)
+
+
def _check_is_tensor_or_operation(x, name):
if not (isinstance(x, ops.Operation) or isinstance(x, ops.Tensor)):
raise TypeError('{} must be Operation or Tensor, given: {}'.format(name, x))
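As a quick illustration of the EXPORT_TAG_MAP added above, the sketch below picks SavedModel tags for a given mode. The fallback to serving tags for unknown modes is an illustrative choice, not part of this change.

from tensorflow.python.estimator.model_fn import EXPORT_TAG_MAP, ModeKeys
from tensorflow.python.saved_model import tag_constants

def tags_for_mode(mode):
  # Look up the SavedModel tags for a mode; default to serving tags
  # (illustrative fallback only).
  return EXPORT_TAG_MAP.get(mode, [tag_constants.SERVING])

print(tags_for_mode(ModeKeys.TRAIN))  # the tag list for training exports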
diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py
index e7f9e590af..f82e94b1a3 100644
--- a/tensorflow/python/framework/function.py
+++ b/tensorflow/python/framework/function.py
@@ -696,7 +696,7 @@ class _FuncGraph(ops.Graph):
return super(_FuncGraph, self).create_op(op_type, inputs, data_types,
**kwargs)
- def capture(self, tensor):
+ def capture(self, tensor, name=None):
"""Adds the given tensor to this graph and returns the captured tensor."""
if tensor in self._captured:
# Captured already.
@@ -704,15 +704,16 @@ class _FuncGraph(ops.Graph):
elif self._capture_by_value:
return self._add_tensor_and_parents(tensor)
else:
- return self._capture_tensor_as_extra_input(tensor)
+ return self._capture_tensor_as_extra_input(tensor, name)
- def _capture_tensor_as_extra_input(self, tensor):
+ def _capture_tensor_as_extra_input(self, tensor, name=None):
# Substitute with a placeholder.
self.extra_inputs.append(tensor)
# Hoist the new input placeholder out of any control flow context
# we're currently in.
with ops.control_dependencies(None):
- ph = array_ops.placeholder(tensor.dtype, shape=tensor.get_shape())
+ ph = array_ops.placeholder(
+ tensor.dtype, shape=tensor.get_shape(), name=name)
# pylint: disable=protected-access
if ops._USE_C_SHAPES:
handle_data = c_api.GetResourceHandleShapeAndType(tensor.graph._c_graph,
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index 2209e8e21a..de3bf0032b 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -1057,13 +1057,19 @@ def internal_convert_to_tensor(value,
"""
if ctx is None: ctx = context.context()
- if ctx.executing_eagerly():
- # Fast path for EagerTensors that don't need any conversion.
- if isinstance(value, EagerTensor):
+ if isinstance(value, EagerTensor):
+ if ctx.executing_eagerly():
+ # Fast path for EagerTensors that don't need any conversion.
# Note that we don't check that value's dtype matches the dtype
# argument. We expect that the C runtime will do that checking
# when we execute the kernel.
return value
+ else:
+ graph = get_default_graph()
+ if not graph.building_function:
+ raise RuntimeError("Attempting to capture an EagerTensor without "
+ "building a function.")
+ return graph.capture(value, name=name)
if dtype is not None:
dtype = dtypes.as_dtype(dtype)
@@ -1251,7 +1257,10 @@ def internal_convert_to_tensor_or_indexed_slices(value,
Raises:
ValueError: If `dtype` does not match the element type of `value`.
"""
- if isinstance(value, _TensorLike):
+ if isinstance(value, EagerTensor) and not context.executing_eagerly():
+ return internal_convert_to_tensor(
+ value, dtype=dtype, name=name, as_ref=as_ref)
+ elif isinstance(value, _TensorLike):
if dtype and not dtypes.as_dtype(dtype).is_compatible_with(value.dtype):
raise ValueError(
"Tensor conversion requested dtype %s for Tensor with dtype %s: %r" %
diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD
index 1b66f58939..523eb67935 100755
--- a/tensorflow/python/keras/BUILD
+++ b/tensorflow/python/keras/BUILD
@@ -395,7 +395,7 @@ py_test(
py_test(
name = "resnet50_test",
- size = "small",
+ size = "medium",
srcs = ["_impl/keras/applications/resnet50_test.py"],
srcs_version = "PY2AND3",
deps = [
@@ -563,7 +563,7 @@ py_test(
py_test(
name = "normalization_test",
- size = "small",
+ size = "medium",
srcs = ["_impl/keras/layers/normalization_test.py"],
srcs_version = "PY2AND3",
tags = ["notsan"],
@@ -604,6 +604,7 @@ py_test(
name = "lstm_test",
size = "medium",
srcs = ["_impl/keras/layers/lstm_test.py"],
+ shard_count = 4,
srcs_version = "PY2AND3",
tags = [
"noasan", # times out b/63678675
diff --git a/tensorflow/python/keras/_impl/keras/engine/base_layer.py b/tensorflow/python/keras/_impl/keras/engine/base_layer.py
index 3af4eaabe9..16ee2952b2 100644
--- a/tensorflow/python/keras/_impl/keras/engine/base_layer.py
+++ b/tensorflow/python/keras/_impl/keras/engine/base_layer.py
@@ -1658,7 +1658,7 @@ class DeferredTensor(object):
"""Tensor-like object used to build graphs of layers in Eager mode.
When calling a layer on a DeferredTensor, the layer will not perform any
- computation and will simply perfom shape inference to return new
+ computation and will simply perform shape inference to return new
DeferredTensors with appropriate shape information. Thus DeferredTensor
behaves like a graph-mode Tensor when manipulated by layers.
"""
diff --git a/tensorflow/python/keras/_impl/keras/engine/network.py b/tensorflow/python/keras/_impl/keras/engine/network.py
index 3197d49fce..b7fab6e974 100644
--- a/tensorflow/python/keras/_impl/keras/engine/network.py
+++ b/tensorflow/python/keras/_impl/keras/engine/network.py
@@ -318,6 +318,9 @@ class Network(base_layer.Layer):
layer, name='layer-%d' % layer_index, overwrite=True)
def __setattr__(self, name, value):
+ no_dependency = isinstance(value, checkpointable.NoDependency)
+ if no_dependency:
+ value = value.value
if isinstance(value, (base_layer.Layer, Network)):
try:
is_graph_network = self._is_graph_network
@@ -332,7 +335,8 @@ class Network(base_layer.Layer):
# In subclassed models, legacy layers (tf.layers) must always use
# resource variables.
value._use_resource_variables = True
- if isinstance(value, checkpointable.CheckpointableBase):
+ if (not no_dependency
+ and isinstance(value, checkpointable.CheckpointableBase)):
# Layer (and therefore Network/Model) inherit from CheckpointableBase
# rather than Checkpointable, which means there is no Checkpointable
# __setattr__ override (it would be a performance issue for functional
diff --git a/tensorflow/python/keras/_impl/keras/engine/sequential_test.py b/tensorflow/python/keras/_impl/keras/engine/sequential_test.py
index 8aba16aef3..a90ad131a5 100644
--- a/tensorflow/python/keras/_impl/keras/engine/sequential_test.py
+++ b/tensorflow/python/keras/_impl/keras/engine/sequential_test.py
@@ -20,8 +20,11 @@ from __future__ import print_function
import numpy as np
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.eager import context
from tensorflow.python.framework import test_util as tf_test_util
from tensorflow.python.keras._impl import keras
+from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
from tensorflow.python.training import rmsprop
@@ -75,7 +78,7 @@ class TestSequential(test.TestCase):
model.pop()
@tf_test_util.run_in_graph_and_eager_modes()
- def test_sequential_deferred_build(self):
+ def test_sequential_deferred_build_with_np_arrays(self):
num_hidden = 5
input_dim = 3
batch_size = 5
@@ -100,6 +103,40 @@ class TestSequential(test.TestCase):
self.assertEqual(len(model.weights), 2 * 2)
@tf_test_util.run_in_graph_and_eager_modes()
+ def test_sequential_deferred_build_with_dataset_iterators(self):
+ if not context.executing_eagerly():
+ # TODO(psv/fchollet): Add support for this use case in graph mode.
+ return
+ num_hidden = 5
+ input_dim = 3
+ num_classes = 2
+ num_samples = 50
+ steps_per_epoch = 10
+
+ model = keras.models.Sequential()
+ # We don't specify the input shape.
+ model.add(keras.layers.Dense(num_hidden))
+ model.add(keras.layers.Dense(num_classes))
+ model.compile(loss='mse', optimizer=rmsprop.RMSPropOptimizer(1e-3))
+ self.assertEqual(len(model.layers), 2)
+ self.assertEqual(len(model.weights), 0)
+ self.assertFalse(model.built)
+
+ x = array_ops.ones((num_samples, input_dim))
+ y = array_ops.zeros((num_samples, num_classes))
+ dataset = dataset_ops.Dataset.from_tensor_slices((x, y))
+ dataset = dataset.repeat(100)
+ dataset = dataset.batch(10)
+ iterator = dataset.make_one_shot_iterator()
+
+ model.fit(iterator, epochs=1, steps_per_epoch=steps_per_epoch)
+ self.assertTrue(model.built)
+ self.assertEqual(model.inputs[0].get_shape().as_list(), [None, input_dim])
+ self.assertEqual(model.outputs[0].get_shape().as_list(),
+ [None, num_classes])
+ self.assertEqual(len(model.weights), 2 * 2)
+
+ @tf_test_util.run_in_graph_and_eager_modes()
def test_invalid_use_cases(self):
# Added objects must be layer instances
with self.assertRaises(TypeError):
diff --git a/tensorflow/python/keras/_impl/keras/engine/training.py b/tensorflow/python/keras/_impl/keras/engine/training.py
index 5f9b3e8c7d..c7623d2b52 100644
--- a/tensorflow/python/keras/_impl/keras/engine/training.py
+++ b/tensorflow/python/keras/_impl/keras/engine/training.py
@@ -18,11 +18,13 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import weakref
import numpy as np
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.eager import context
+from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras._impl.keras import backend as K
@@ -106,6 +108,11 @@ class Model(Network):
```
"""
+ def __init__(self, *args, **kwargs):
+ super(Model, self).__init__(*args, **kwargs)
+ # Create a cache for iterator get_next op.
+ self._iterator_get_next = weakref.WeakKeyDictionary()
+
def compile(self,
optimizer,
loss=None,
@@ -623,12 +630,23 @@ class Model(Network):
**kwargs)
self._post_build_cleanup()
+ def _get_iterator_get_next_tensors(self, iterator):
+ get_next_op = self._iterator_get_next.get(iterator, None)
+ if get_next_op is None:
+ get_next_op = iterator.get_next()
+ self._iterator_get_next[iterator] = get_next_op
+ return get_next_op
+
def _standardize_user_data(self,
x,
y=None,
sample_weight=None,
class_weight=None,
- batch_size=None):
+ batch_size=None,
+ check_steps=False,
+ steps_name='steps',
+ steps=None,
+ validation_split=0):
"""Runs validation checks on input and target data passed by the user.
Also standardizes the data to lists of arrays, in order.
@@ -660,6 +678,16 @@ class Model(Network):
to, as conveyed by `y`.
batch_size: Integer batch size. If provided, it is used to run additional
validation checks on stateful models.
+ check_steps: boolean, True if the validity of `steps` should be checked,
+ False otherwise. For example, when standardizing one batch of data for
+ the train_on_batch/predict_on_batch/test_on_batch APIs, no `steps` value
+ is required, so its validity should not be checked in these cases.
+ steps_name: The public API's parameter name for `steps`.
+ steps: Integer or `None`. Total number of steps (batches of samples) to
+ execute.
+ validation_split: Float between 0 and 1.
+ Fraction of the training data to be used as validation data.
Returns:
A tuple of 3 lists: input arrays, target arrays, sample-weight arrays.
@@ -671,33 +699,54 @@ class Model(Network):
ValueError: In case of invalid user-provided data.
RuntimeError: If the model was never compiled.
"""
- # First, we build/compile the model on the fly if necessary.
if isinstance(x, dataset_ops.Dataset):
raise ValueError('You passed a `Dataset` instance to your model (%s), '
'which is not supported. Instead, pass an `Iterator`, '
'which you can obtain e.g. via '
'`dataset.make_one_shot_iterator()` (the exact method '
'to use will depend on your specific dataset).' % x)
- if isinstance(x, iterator_ops.Iterator):
- if y is not None:
- raise ValueError('You passed a dataset iterator (%s) as input `x` to '
- 'your model. In that case, you should not specify '
- 'a target (`y`) argument, since the dataset iterator '
- 'generates both input data and target data. '
- 'Received: %s' % (x, y))
- if not context.executing_eagerly():
- x, y = x.get_next()
- # TODO(fchollet): handle case of `get_next` not returning 2 tensors?
- else:
- # TODO(psv): implement this. The way to support it will be to typecheck
- # for `iterator` before `_standardize_user_data` is called and redirect
- # to new training/eval functions in `training_eager.py`. The model
- # may need to get built using the specs of the data from the first batch
- # drawn from the iterator.
- raise ValueError('Dataset iterators are not supported '
- 'with eager execution yet.')
+ # Validates `steps` argument based on x's type.
+ if check_steps:
+ training_utils.check_steps_argument(x, steps, steps_name)
+
+ is_x_eager_iterator = isinstance(x, iterator_ops.EagerIterator)
+ is_x_iterator = isinstance(x, iterator_ops.Iterator)
+
+ # Validate user inputs when data is given as a dataset iterator.
+ if is_x_iterator or is_x_eager_iterator:
+ training_utils.validate_iterator_input(x, y, sample_weight,
+ validation_split)
+
+ # For eager iterators, when we have to process multiple batches of samples,
+ # we will standardize the data when we actually loop over iterator and get
+ # the batches. For now, we just return the iterator as is.
+ if is_x_eager_iterator and steps is not None:
+ return x, y, sample_weight
+
+ # If input data is a dataset iterator in graph mode or if it is an eager
+ # iterator and only one batch of samples is required, we fetch the data
+ # tensors from the iterator and then standardize them.
+ if is_x_iterator or is_x_eager_iterator:
+ try:
+ if is_x_iterator:
+ next_element = self._get_iterator_get_next_tensors(x)
+ else:
+ next_element = x.get_next()
+ except errors.OutOfRangeError:
+ raise RuntimeError('Your dataset iterator ran out of data; '
+ 'make sure that your dataset can generate the '
+ 'required number of samples.')
+
+ if not isinstance(next_element, (list, tuple)) or len(next_element) != 2:
+ raise ValueError('Please provide data as a list or tuple of 2 elements '
+ ' - input and target pair. Received %s' % next_element)
+ x, y = next_element
+
+ # First, we build/compile the model on the fly if necessary.
all_inputs = []
+ is_build_called = False
+ is_compile_called = False
if not self.built:
# We need to use `x` to set the model inputs.
# We type-check that `x` and `y` are either single arrays
@@ -720,6 +769,7 @@ class Model(Network):
# If values, then in symbolic-mode placeholders will be created
# to match the value shapes.
if not self.inputs:
+ is_build_called = True
self._set_inputs(x)
if y is not None:
@@ -736,6 +786,7 @@ class Model(Network):
raise ValueError('Please provide as model targets either a single '
'array or a list of arrays. '
'You passed: y=' + str(y))
+ all_inputs += list(y)
elif isinstance(y, dict):
raise ValueError('Please do not pass a dictionary as model targets.')
else:
@@ -743,14 +794,10 @@ class Model(Network):
raise ValueError('Please provide as model targets either a single '
'array or a list of arrays. '
'You passed: y=' + str(y))
+ all_inputs.append(y)
# Typecheck that all inputs are *either* value *or* symbolic.
# TODO(fchollet): this check could be removed in Eager mode?
- if y is not None:
- if isinstance(y, (list, tuple)):
- all_inputs += list(y)
- else:
- all_inputs.append(y)
if any(tensor_util.is_tensor(v) for v in all_inputs):
if not all(tensor_util.is_tensor(v) for v in all_inputs):
raise ValueError('Do not pass inputs that mix Numpy arrays and '
@@ -764,17 +811,22 @@ class Model(Network):
if not isinstance(y, (list, tuple)):
y = [y]
target_tensors = [v for v in y if tensor_util.is_tensor(v)]
+ is_compile_called = True
self.compile(optimizer=self.optimizer,
loss=self.loss,
metrics=self.metrics,
loss_weights=self.loss_weights,
target_tensors=target_tensors)
- # If `x` and `y` were all symbolic, then no model should not be fed any
- # inputs and targets.
+ # In graph mode, if we have just set inputs and targets as symbolic tensors
+ # by invoking build and compile on the model respectively, we do not have to
+ # feed anything to the model. The model already has the input and target data
+ # as part of the graph.
# Note: in this case, `any` and `all` are equivalent since we disallow
# mixed symbolic/value inputs.
- if any(tensor_util.is_tensor(v) for v in all_inputs):
+ if (not context.executing_eagerly() and is_build_called and
+ is_compile_called and
+ any(tensor_util.is_tensor(v) for v in all_inputs)):
return [], [], []
# What follows is input validation and standardization to list format,
@@ -904,7 +956,12 @@ class Model(Network):
if isinstance(inputs, list):
assert len(inputs) == 1
inputs = inputs[0]
- self.build(input_shape=(None,) + inputs.shape[1:])
+
+ if tensor_util.is_tensor(inputs):
+ input_shape = (None,) + tuple(inputs.get_shape().as_list()[1:])
+ else:
+ input_shape = (None,) + inputs.shape[1:]
+ self.build(input_shape=input_shape)
elif context.executing_eagerly():
self._eager_set_inputs(inputs)
else:
@@ -931,12 +988,18 @@ class Model(Network):
# On-the-fly setting of model inputs/outputs as DeferredTensors,
# to keep track of number of inputs and outputs and their ndim.
if isinstance(inputs, (list, tuple)):
- dummy_output_values = self.call(
- [ops.convert_to_tensor(v, dtype=K.floatx()) for v in inputs])
+ if tensor_util.is_tensor(inputs[0]):
+ dummy_output_values = self.call(inputs)
+ else:
+ dummy_output_values = self.call(
+ [ops.convert_to_tensor(v, dtype=K.floatx()) for v in inputs])
dummy_input_values = list(inputs)
else:
- dummy_output_values = self.call(
- ops.convert_to_tensor(inputs, dtype=K.floatx()))
+ if tensor_util.is_tensor(inputs):
+ dummy_output_values = self.call(inputs)
+ else:
+ dummy_output_values = self.call(
+ ops.convert_to_tensor(inputs, dtype=K.floatx()))
dummy_input_values = [inputs]
if isinstance(dummy_output_values, (list, tuple)):
dummy_output_values = list(dummy_output_values)
@@ -1071,7 +1134,7 @@ class Model(Network):
batch_size: Integer or `None`.
Number of samples per gradient update.
If unspecified, `batch_size` will default to 32.
- Do not specify the `batch_size` is your data is in the
+ Do not specify the `batch_size` if your data is in the
form of symbolic tensors or dataset iterators (since they generate
batches).
epochs: Integer. Number of epochs to train the model.
@@ -1094,7 +1157,8 @@ class Model(Network):
the loss and any model metrics
on this data at the end of each epoch.
The validation data is selected from the last samples
- in the `x` and `y` data provided, before shuffling.
+ in the `x` and `y` data provided, before shuffling. This argument is
+ not supported when `x` is a dataset iterator.
validation_data: Data on which to evaluate
the loss and any model metrics at the end of each epoch.
The model will not be trained on this data.
@@ -1124,7 +1188,8 @@ class Model(Network):
`(samples, sequence_length)`,
to apply a different weight to every timestep of every sample.
In this case you should make sure to specify
- `sample_weight_mode="temporal"` in `compile()`.
+ `sample_weight_mode="temporal"` in `compile()`. This argument is not
+ supported when `x` is a dataset iterator.
initial_epoch: Integer.
Epoch at which to start training
(useful for resuming a previous training run).
@@ -1165,21 +1230,23 @@ class Model(Network):
epochs = kwargs.pop('nb_epoch')
if kwargs:
raise TypeError('Unrecognized keyword arguments: ' + str(kwargs))
- if x is None and y is None and steps_per_epoch is None:
- raise ValueError('If fitting from data tensors, '
- 'you should specify the `steps_per_epoch` '
- 'argument.')
- # Validate user data.
+ # Validate and standardize user data.
x, y, sample_weights = self._standardize_user_data(
x,
y,
sample_weight=sample_weight,
class_weight=class_weight,
- batch_size=batch_size)
+ batch_size=batch_size,
+ check_steps=True,
+ steps_name='steps_per_epoch',
+ steps=steps_per_epoch,
+ validation_split=validation_split)
+
# Prepare validation data.
if validation_data:
- if isinstance(validation_data, iterator_ops.Iterator):
+ if (isinstance(validation_data, iterator_ops.Iterator) or
+ isinstance(validation_data, iterator_ops.EagerIterator)):
val_x = validation_data
val_y = None
val_sample_weight = None
@@ -1196,11 +1263,13 @@ class Model(Network):
'or alternatively it could be a dataset iterator. However we '
'received `validation_data=%s`' % validation_data)
+ # Validate and standardize validation data.
val_x, val_y, val_sample_weights = self._standardize_user_data(
val_x,
val_y,
sample_weight=val_sample_weight,
- batch_size=batch_size)
+ batch_size=batch_size,
+ steps=validation_steps)
elif validation_split and 0. < validation_split < 1.:
if training_utils.has_symbolic_tensors(x):
@@ -1229,6 +1298,7 @@ class Model(Network):
inputs=x,
targets=y,
sample_weights=sample_weights,
+ class_weight=class_weight,
batch_size=batch_size,
epochs=epochs,
verbose=verbose,
@@ -1300,7 +1370,8 @@ class Model(Network):
`(samples, sequence_length)`,
to apply a different weight to every timestep of every sample.
In this case you should make sure to specify
- `sample_weight_mode="temporal"` in `compile()`.
+ `sample_weight_mode="temporal"` in `compile()`. This argument is not
+ supported when `x` is a dataset iterator.
steps: Integer or `None`.
Total number of steps (batches of samples)
before declaring the evaluation round finished.
@@ -1318,17 +1389,16 @@ class Model(Network):
# Backwards compatibility.
if batch_size is None and steps is None:
batch_size = 32
- if x is None and y is None and steps is None:
- raise ValueError('If evaluating from data tensors, '
- 'you should specify the `steps` '
- 'argument.')
- # Validate user data.
+ # Validate and standardize user data.
x, y, sample_weights = self._standardize_user_data(
x,
y,
sample_weight=sample_weight,
- batch_size=batch_size)
+ batch_size=batch_size,
+ check_steps=True,
+ steps_name='steps',
+ steps=steps)
if context.executing_eagerly():
return training_eager.test_loop(
@@ -1345,7 +1415,12 @@ class Model(Network):
Computation is done in batches.
Arguments:
- x: Input samples, as Numpy array(s) or tensor(s).
+ x: Input samples. It could be:
+ - A Numpy array (or array-like), or a list of arrays
+ (in case the model has multiple inputs).
+ - A TensorFlow tensor, or a list of tensors
+ (in case the model has multiple inputs).
+ - A `tf.data` dataset iterator.
batch_size: Integer or `None`.
Number of samples per gradient update.
If unspecified, `batch_size` will default to 32.
@@ -1369,11 +1444,10 @@ class Model(Network):
# Backwards compatibility.
if batch_size is None and steps is None:
batch_size = 32
- if x is None and steps is None:
- raise ValueError('If predicting from data tensors, '
- 'you should specify the `steps` '
- 'argument.')
- x, _, _ = self._standardize_user_data(x)
+
+ # Validate and standardize user data.
+ x, _, _ = self._standardize_user_data(
+ x, check_steps=True, steps_name='steps', steps=steps)
if context.executing_eagerly():
return training_eager.predict_loop(
@@ -1406,7 +1480,9 @@ class Model(Network):
with shape (samples, sequence_length),
to apply a different weight to every timestep of every sample.
In this case you should make sure to specify
- sample_weight_mode="temporal" in compile().
+ sample_weight_mode="temporal" in compile(). This argument is not
+ supported when `x` is a dataset iterator.
+
class_weight: Optional dictionary mapping
class indices (integers) to
a weight (float) to apply to the model's loss for the samples
@@ -1424,11 +1500,9 @@ class Model(Network):
Raises:
ValueError: In case of invalid user-provided arguments.
"""
+ # Validate and standardize user data.
x, y, sample_weights = self._standardize_user_data(
- x,
- y,
- sample_weight=sample_weight,
- class_weight=class_weight)
+ x, y, sample_weight=sample_weight, class_weight=class_weight)
if context.executing_eagerly():
outputs = training_eager.train_on_batch(
@@ -1470,7 +1544,8 @@ class Model(Network):
with shape (samples, sequence_length),
to apply a different weight to every timestep of every sample.
In this case you should make sure to specify
- sample_weight_mode="temporal" in compile().
+ sample_weight_mode="temporal" in compile(). This argument is not
+ supported when `x` is a dataset iterator.
Returns:
Scalar test loss (if the model has a single output and no metrics)
@@ -1481,6 +1556,7 @@ class Model(Network):
Raises:
ValueError: In case of invalid user-provided arguments.
"""
+ # Validate and standardize user data.
x, y, sample_weights = self._standardize_user_data(
x, y, sample_weight=sample_weight)
@@ -1503,23 +1579,34 @@ class Model(Network):
"""Returns predictions for a single batch of samples.
Arguments:
- x: Input samples, as Numpy array(s) or tensor(s).
+ x: Input data. It could be:
+ - A Numpy array (or array-like), or a list of arrays
+ (in case the model has multiple inputs).
+ - A TensorFlow tensor, or a list of tensors
+ (in case the model has multiple inputs).
+ - A `tf.data` dataset iterator.
Returns:
Numpy array(s) of predictions.
+ Raises:
+ ValueError: In case of mismatch between given number of inputs and
+ expectations of the model.
"""
- x, _, _ = self._standardize_user_data(x)
-
+ # Validate and standardize user data.
+ inputs, _, _ = self._standardize_user_data(x)
if context.executing_eagerly():
- inputs = [ops.convert_to_tensor(val, dtype=K.floatx()) for val in x]
+ if not isinstance(inputs, iterator_ops.EagerIterator):
+ inputs = [
+ ops.convert_to_tensor(val, dtype=K.floatx()) for val in inputs
+ ]
return self(inputs) # pylint: disable=not-callable
if not context.executing_eagerly():
if self.uses_learning_phase and not isinstance(K.learning_phase(), int):
- ins = x + [0]
+ ins = inputs + [0]
else:
- ins = x
+ ins = inputs
self._make_predict_function()
outputs = self.predict_function(ins)
@@ -1631,8 +1718,7 @@ class Model(Network):
steps_per_epoch=10000, epochs=10)
```
Raises:
- ValueError: In case the generator yields
- data in an invalid format.
+ ValueError: In case the generator yields data in an invalid format.
"""
if not self.built and not self._is_graph_network:
raise NotImplementedError(
@@ -1697,8 +1783,7 @@ class Model(Network):
ValueError: in case of invalid arguments.
Raises:
- ValueError: In case the generator yields
- data in an invalid format.
+ ValueError: In case the generator yields data in an invalid format.
"""
if not self.built and not self._is_graph_network:
raise NotImplementedError(
@@ -1751,8 +1836,7 @@ class Model(Network):
Numpy array(s) of predictions.
Raises:
- ValueError: In case the generator yields
- data in an invalid format.
+ ValueError: In case the generator yields data in an invalid format.
"""
if not self.built and not self._is_graph_network:
raise NotImplementedError(
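The training.py changes above, together with the eager-mode loops added in training_eager.py below, let fit/evaluate/predict take a dataset iterator directly under eager execution. A minimal sketch mirroring the sequential_test case above; the tf.keras / tf.data public paths are assumed equivalents of the internal modules used in the tests.

import tensorflow as tf

tf.enable_eager_execution()

model = tf.keras.models.Sequential([
    tf.keras.layers.Dense(5),
    tf.keras.layers.Dense(2),
])
model.compile(loss='mse', optimizer=tf.train.RMSPropOptimizer(1e-3))

x = tf.ones((50, 3))
y = tf.zeros((50, 2))
dataset = tf.data.Dataset.from_tensor_slices((x, y)).repeat(100).batch(10)
iterator = dataset.make_one_shot_iterator()

# `steps_per_epoch` is required for iterator inputs; `batch_size` must not be
# passed, since the iterator already yields batches.
model.fit(iterator, epochs=1, steps_per_epoch=10)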
diff --git a/tensorflow/python/keras/_impl/keras/engine/training_arrays.py b/tensorflow/python/keras/_impl/keras/engine/training_arrays.py
index 4164cae864..12e74ef51d 100644
--- a/tensorflow/python/keras/_impl/keras/engine/training_arrays.py
+++ b/tensorflow/python/keras/_impl/keras/engine/training_arrays.py
@@ -108,8 +108,8 @@ def fit_loop(model,
do_validation = False
if val_inputs:
do_validation = True
- if verbose and inputs and hasattr(inputs[0], 'shape') and hasattr(
- val_inputs[0], 'shape'):
+ if (steps_per_epoch is None and verbose and inputs and
+ hasattr(inputs[0], 'shape') and hasattr(val_inputs[0], 'shape')):
print('Train on %d samples, validate on %d samples' %
(inputs[0].shape[0], val_inputs[0].shape[0]))
if validation_steps:
diff --git a/tensorflow/python/keras/_impl/keras/engine/training_eager.py b/tensorflow/python/keras/_impl/keras/engine/training_eager.py
index b9c99b2222..3617eb281a 100644
--- a/tensorflow/python/keras/_impl/keras/engine/training_eager.py
+++ b/tensorflow/python/keras/_impl/keras/engine/training_eager.py
@@ -23,7 +23,9 @@ import copy
import numpy as np
+from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.eager.backprop import GradientTape
+from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras._impl.keras import backend
@@ -177,6 +179,550 @@ def _model_loss(model, inputs, targets, sample_weights=None, training=False):
return outs, total_loss, loss_metrics
+def iterator_fit_loop(model,
+ inputs,
+ class_weight,
+ steps_per_epoch,
+ callback_model,
+ out_labels,
+ epoch_logs,
+ val_inputs=None,
+ val_targets=None,
+ val_sample_weights=None,
+ epochs=1,
+ verbose=1,
+ callbacks=None,
+ callback_metrics=None,
+ validation_steps=None,
+ do_validation=False):
+ """Fit function for eager execution when input is given as dataset iterator.
+
+ Updates the given epoch logs.
+
+ Arguments:
+ model: Instance of the `Model`.
+ inputs: Input dataset iterator.
+ class_weight: Optional class-weight array to weight the importance of
+ samples in `inputs` based on the class they belong to, as conveyed by
+ the targets from the `inputs` iterator.
+ steps_per_epoch: Total number of steps (batches of samples)
+ before declaring one epoch finished and starting the
+ next epoch.
+ callback_model: Instance of `Model` to callback.
+ out_labels: Output labels generated from model metric names.
+ epoch_logs: Dictionary of logs from every epoch.
+ val_inputs: Input data for validation.
+ val_targets: Target data for validation.
+ val_sample_weights: Sample weight data for validation.
+ epochs: Number of times to iterate over the data.
+ verbose: Verbosity mode, 0, 1 or 2.
+ callbacks: List of callbacks to be called during training.
+ callback_metrics: List of strings, the display names of the metrics
+ passed to the callbacks. They should be the concatenation of the list
+ of display names of the outputs of `f` and the list of display names
+ of the outputs of `f_val`.
+ validation_steps: Number of steps to run validation for (only if doing
+ validation from data tensors). Ignored with default value of `None`.
+ do_validation: Boolean value indicating whether we should do validation.
+
+ Raises:
+ ValueError: In case of mismatch between given number of inputs and
+ expectations of the model.
+ """
+ assert isinstance(inputs, iterator_ops.EagerIterator)
+ for step_index in range(steps_per_epoch):
+ batch_logs = {}
+ batch_logs['batch'] = step_index
+ batch_logs['size'] = 1
+ callbacks.on_batch_begin(step_index, batch_logs)
+
+ # Get data from the iterator.
+ try:
+ next_element = inputs.get_next()
+ except errors.OutOfRangeError:
+ logging.warning(
+ 'Your dataset iterator ran out of data; '
+ 'interrupting training. Make sure that your dataset'
+ ' can generate at least `steps_per_epoch * epochs` '
+ 'batches (in this case, %d batches).' % (steps_per_epoch * epochs))
+ break
+
+ if not isinstance(next_element, (list, tuple)) or len(next_element) != 2:
+ raise ValueError('Please provide data as a list or tuple of 2 elements '
+ ' - input and target pair. Received %s' % next_element)
+ x, y = next_element
+
+ # Validate and standardize data.
+ x, y, sample_weights = model._standardize_user_data(
+ x, y, class_weight=class_weight)
+ if sample_weights:
+ sample_weights = [
+ ops.convert_to_tensor(val, dtype=backend.floatx())
+ if val is not None else None for val in sample_weights
+ ]
+
+ if step_index == 0 and not callback_metrics:
+ out_labels = model.metrics_names
+ if do_validation:
+ callback_metrics = copy.copy(out_labels) + [
+ 'val_' + n for n in out_labels
+ ]
+ else:
+ callback_metrics = copy.copy(out_labels)
+ callbacks.set_params({
+ 'epochs': epochs,
+ 'steps': steps_per_epoch,
+ 'verbose': verbose,
+ 'do_validation': do_validation,
+ 'metrics': callback_metrics or [],
+ })
+
+ # Train model.
+ outs, loss, loss_metrics = _process_single_batch(
+ model, x, y, sample_weights=sample_weights, training=True)
+ if not isinstance(outs, list):
+ outs = [outs]
+
+ # Calculate metrics.
+ for l, o in zip(out_labels, outs):
+ batch_logs[l] = o
+ # Required for eager execution
+ metrics_results = _eager_metrics_fn(model, outs, y)
+ batch_logs['loss'] = tensor_util.constant_value(backend.mean(loss))
+
+ for k, v in zip(model.metrics_names,
+ [backend.mean(loss)] + loss_metrics + metrics_results):
+ batch_logs[k] = tensor_util.constant_value(v)
+ callbacks.on_batch_end(step_index, batch_logs)
+ if callback_model.stop_training:
+ break
+
+ if step_index == steps_per_epoch - 1:
+ if do_validation:
+ val_outs = test_loop(
+ model,
+ val_inputs,
+ val_targets,
+ sample_weights=val_sample_weights,
+ steps=validation_steps,
+ verbose=0)
+ if not isinstance(val_outs, list):
+ val_outs = [val_outs]
+ # Same labels assumed.
+ for l, o in zip(out_labels, val_outs):
+ epoch_logs['val_' + l] = o
+
+
+def batch_fit_loop(model,
+ inputs,
+ targets,
+ epoch_logs,
+ index_array,
+ out_labels,
+ callback_model,
+ batch_size,
+ sample_weights=None,
+ val_inputs=None,
+ val_targets=None,
+ val_sample_weights=None,
+ callbacks=None,
+ shuffle=True,
+ num_train_samples=None,
+ do_validation=False):
+ """Fit function for eager execution when input is given as arrays or tensors.
+
+ Updates the given epoch logs.
+
+ Arguments:
+ model: Instance of the `Model`.
+ inputs: List of input arrays.
+ targets: List of target arrays.
+ epoch_logs: Dictionary of logs from every epoch.
+ index_array: Index array generated from number of training samples.
+ out_labels: Output labels generated from model metric names.
+ callback_model: Instance of `Model` to callback.
+ batch_size: Integer batch size or None if unknown.
+ sample_weights: Optional list of sample weight arrays.
+ val_inputs: Input data for validation.
+ val_targets: Target data for validation.
+ val_sample_weights: Sample weight data for validation.
+ callbacks: List of callbacks to be called during training.
+ shuffle: Whether to shuffle the data at the beginning of each epoch.
+ num_train_samples: Integer number of training samples.
+ do_validation: Boolean value indicating whether we should do validation.
+ """
+ # TODO(psv): Create a dataset iterator instead of manually creating batches
+ # here and in batch_test_loop, batch_predict_loop.
+ if shuffle == 'batch':
+ index_array = model._batch_shuffle(index_array, batch_size)
+ elif shuffle:
+ np.random.shuffle(index_array)
+
+ batches = generic_utils.make_batches(num_train_samples, batch_size)
+
+ for batch_index, (batch_start, batch_end) in enumerate(batches):
+ batch_ids = index_array[batch_start:batch_end]
+ inputs_batch = slice_arrays(inputs, batch_ids, contiguous=not shuffle)
+ targets_batch = slice_arrays(targets, batch_ids, contiguous=not shuffle)
+ if sample_weights:
+ sample_weights_batch = slice_arrays(
+ sample_weights, batch_ids, contiguous=not shuffle)
+ else:
+ sample_weights_batch = None
+ batch_logs = {}
+ batch_logs['batch'] = batch_index
+ batch_logs['size'] = len(batch_ids)
+
+ callbacks.on_batch_begin(batch_index, batch_logs)
+
+ inputs_batch = [
+ ops.convert_to_tensor(val, dtype=backend.floatx())
+ for val in inputs_batch
+ ]
+ targets_batch = [
+ ops.convert_to_tensor(val, dtype=backend.floatx())
+ for val in targets_batch
+ ]
+ if sample_weights:
+ sample_weights_batch = [
+ ops.convert_to_tensor(val, dtype=backend.floatx())
+ if val is not None else None for val in sample_weights_batch
+ ]
+
+ outs, loss, loss_metrics = _process_single_batch(
+ model,
+ inputs_batch,
+ targets_batch,
+ sample_weights=sample_weights_batch,
+ training=True)
+
+ if not isinstance(outs, list):
+ outs = [outs]
+
+ for l, o in zip(out_labels, outs):
+ batch_logs[l] = o
+ # Required for eager execution
+ metrics_results = _eager_metrics_fn(model, outs, targets_batch)
+ batch_logs['loss'] = tensor_util.constant_value(backend.mean(loss))
+
+ for k, v in zip(model.metrics_names,
+ [backend.mean(loss)] + loss_metrics + metrics_results):
+ batch_logs[k] = tensor_util.constant_value(v)
+ callbacks.on_batch_end(batch_index, batch_logs)
+ if callback_model.stop_training:
+ break
+
+ if batch_index == len(batches) - 1: # Last batch.
+ if do_validation:
+ val_outs = test_loop(
+ model,
+ val_inputs,
+ val_targets,
+ sample_weights=val_sample_weights,
+ batch_size=batch_size,
+ verbose=0)
+ if not isinstance(val_outs, list):
+ val_outs = [val_outs]
+ # Same labels assumed.
+ for l, o in zip(out_labels, val_outs):
+ epoch_logs['val_' + l] = o
+
+
+def iterator_test_loop(model, inputs, steps, verbose=0):
+ """Test function for eager execution when input is given as dataset iterator.
+
+ Arguments:
+ model: Model instance that is being evaluated in Eager mode.
+ inputs: Input dataset iterator.
+ steps: Total number of steps (batches of samples) before declaring
+ predictions finished.
+ verbose: Verbosity mode.
+
+ Returns:
+ Scalar loss (if the model has a single output and no metrics)
+ or list of scalars (if the model has multiple outputs
+ and/or metrics). The attribute `model.metrics_names` will give you
+ the display labels for the scalar outputs.
+
+ Raises:
+ ValueError: In case of mismatch between given number of inputs and
+ expectations of the model.
+ """
+ assert isinstance(inputs, iterator_ops.EagerIterator)
+ outs = []
+ num_samples = 0
+ if verbose == 1:
+ progbar = generic_utils.Progbar(target=steps)
+ for step_index in range(steps):
+ # Get data from the iterator.
+ try:
+ next_element = inputs.get_next()
+ except errors.OutOfRangeError:
+ logging.warning(
+ 'Your dataset iterator ran out of data; interrupting testing. '
+ 'Make sure that your dataset can generate at least `steps` batches '
+ '(in this case, %d batches).', steps)
+ break
+
+ if not isinstance(next_element, (list, tuple)) or len(next_element) != 2:
+ raise ValueError('Please provide data as a list or tuple of 2 elements '
+ ' - input and target pair. Received %s' % next_element)
+ x, y = next_element
+
+ # Validate and standardize data.
+ x, y, sample_weights = model._standardize_user_data(x, y)
+
+ # Calculate model output, loss values.
+ loss_outs, loss, loss_metrics = _model_loss(
+ model, x, y, sample_weights=sample_weights, training=False)
+ metrics_results = _eager_metrics_fn(model, loss_outs, y)
+ batch_outs = []
+ for _, v in zip(model.metrics_names,
+ [backend.mean(loss)] + loss_metrics + metrics_results):
+ batch_outs.append(tensor_util.constant_value(v))
+
+ # Get current step size.
+ if isinstance(x, list):
+ step_size = x[0].get_shape().as_list()[0]
+ else:
+ step_size = x.get_shape().as_list()[0]
+
+ # Accumulate results in output array.
+ if not isinstance(batch_outs, list):
+ batch_outs = [batch_outs]
+ if step_index == 0:
+ for _ in enumerate(batch_outs):
+ outs.append(0.)
+ for i, batch_out in enumerate(batch_outs):
+ outs[i] += batch_out * step_size
+
+ # Calculate sample size.
+ num_samples += step_size
+ if verbose == 1:
+ progbar.update(step_index + 1)
+
+ for i in range(len(outs)):
+ outs[i] /= num_samples
+ if len(outs) == 1:
+ return outs[0]
+ return outs
+
+
+def batch_test_loop(model,
+ inputs,
+ targets,
+ batch_size,
+ sample_weights=None,
+ verbose=0):
+ """Test function for eager execution when input is given as arrays or tensors.
+
+ Arguments:
+ model: Model instance that is being evaluated in Eager mode.
+ inputs: List of input arrays.
+ targets: List of target arrays.
+ batch_size: Integer batch size.
+ sample_weights: Optional list of sample weight arrays.
+ verbose: Verbosity mode.
+
+ Returns:
+ Scalar loss (if the model has a single output and no metrics)
+ or list of scalars (if the model has multiple outputs
+ and/or metrics). The attribute `model.metrics_names` will give you
+ the display labels for the scalar outputs.
+ """
+ outs = []
+ feed_data = inputs + targets
+ if sample_weights:
+ feed_data += sample_weights
+ num_samples = training_utils.check_num_samples(
+ feed_data, batch_size=batch_size)
+ if verbose == 1:
+ progbar = generic_utils.Progbar(target=num_samples)
+ batches = generic_utils.make_batches(num_samples, batch_size)
+ index_array = np.arange(num_samples)
+ for batch_index, (batch_start, batch_end) in enumerate(batches):
+ batch_ids = index_array[batch_start:batch_end]
+ inputs_batch = slice_arrays(inputs, batch_ids)
+ targets_batch = slice_arrays(targets, batch_ids)
+ if sample_weights:
+ sample_weights_batch = slice_arrays(sample_weights, batch_ids)
+ else:
+ sample_weights_batch = None
+
+ inputs_batch = [
+ ops.convert_to_tensor(val, dtype=backend.floatx())
+ for val in inputs_batch
+ ]
+ targets_batch = [
+ ops.convert_to_tensor(val, dtype=backend.floatx())
+ for val in targets_batch
+ ]
+ if sample_weights:
+ sample_weights_batch = [
+ ops.convert_to_tensor(val, dtype=backend.floatx())
+ if val is not None else None for val in sample_weights_batch
+ ]
+
+ loss_outs, loss, loss_metrics = _model_loss(
+ model,
+ inputs_batch,
+ targets_batch,
+ sample_weights=sample_weights_batch,
+ training=False)
+ metrics_results = _eager_metrics_fn(model, loss_outs, targets_batch)
+ batch_outs = []
+ for _, v in zip(model.metrics_names,
+ [backend.mean(loss)] + loss_metrics + metrics_results):
+ batch_outs.append(tensor_util.constant_value(v))
+
+ if isinstance(batch_outs, list):
+ if batch_index == 0:
+ for _ in enumerate(batch_outs):
+ outs.append(0.)
+ for i, batch_out in enumerate(batch_outs):
+ outs[i] += batch_out * len(batch_ids)
+ else:
+ if batch_index == 0:
+ outs.append(0.)
+ outs[0] += batch_outs * len(batch_ids)
+
+ if verbose == 1:
+ progbar.update(batch_end)
+
+ for i in range(len(outs)):
+ outs[i] /= num_samples
+ if len(outs) == 1:
+ return outs[0]
+ return outs
+
+
+def iterator_predict_loop(model, inputs, steps, verbose=0):
+ """Predict function for eager execution when input is dataset iterator.
+
+ Arguments:
+ model: Instance of `Model`.
+ inputs: Input dataset iterator.
+ steps: Total number of steps (batches of samples) before declaring
+ `_predict_loop` finished.
+ verbose: Verbosity mode.
+
+ Returns:
+ Array of predictions (if the model has a single output)
+ or list of arrays of predictions (if the model has multiple outputs).
+
+ Raises:
+ ValueError: In case of mismatch between given number of inputs and
+ expectations of the model.
+ """
+ assert isinstance(inputs, iterator_ops.EagerIterator)
+ outs = []
+ if verbose == 1:
+ progbar = generic_utils.Progbar(target=steps)
+ for step_index in range(steps):
+ # Get data from the iterator.
+ try:
+ next_element = inputs.get_next()
+ except errors.OutOfRangeError:
+ logging.warning(
+ 'Your dataset iterator ran out of data; '
+ 'interrupting prediction. Make sure that your '
+ 'dataset can generate at least `steps` '
+ 'batches (in this case, %d batches).', steps)
+ break
+
+ if not isinstance(next_element, (list, tuple)) or len(next_element) != 2:
+ raise ValueError(
+ 'Please provide data as a list or tuple of 2 elements '
+ ' - input and target pair. Received %s. We do not use the '
+ '`target` value here.' % next_element)
+ x, _ = next_element
+
+ # Validate and standardize data.
+ x, _, _ = model._standardize_user_data(x)
+
+ if model._expects_training_arg:
+ batch_outs = model.call(x[0] if len(x) == 1 else x, training=False)
+ else:
+ batch_outs = model.call(x[0] if len(x) == 1 else x)
+ if not isinstance(batch_outs, list):
+ batch_outs = [batch_outs]
+
+ # We collect the results from every step and then concatenate them once
+ # in the end. This is an expensive process. We are doing this because we
+ # do not know the number of samples beforehand.
+ if step_index == 0:
+ for _ in batch_outs:
+ outs.append([])
+ for i, batch_out in enumerate(batch_outs):
+ outs[i].append(backend.get_value(batch_out))
+
+ if verbose == 1:
+ progbar.update(step_index + 1)
+ for i, out in enumerate(outs):
+ outs[i] = np.concatenate(tuple(out), axis=0)
+ if len(outs) == 1:
+ return outs[0]
+ return outs
+
+
+def batch_predict_loop(model, inputs, batch_size, verbose=0):
+ """Predict function for eager execution when input is arrays or tensors.
+
+ Arguments:
+ model: Instance of `Model`.
+ inputs: List of input arrays.
+ batch_size: Integer batch size.
+ verbose: Verbosity mode.
+
+ Returns:
+ Array of predictions (if the model has a single output)
+ or list of arrays of predictions (if the model has multiple outputs).
+ """
+ outs = []
+ num_samples = training_utils.check_num_samples(inputs, batch_size)
+ if verbose == 1:
+ progbar = generic_utils.Progbar(target=num_samples)
+ batches = generic_utils.make_batches(num_samples, batch_size)
+ index_array = np.arange(num_samples)
+ for batch_index, (batch_start, batch_end) in enumerate(batches):
+ batch_ids = index_array[batch_start:batch_end]
+ inputs_batch = slice_arrays(inputs, batch_ids)
+
+ inputs_batch = [
+ ops.convert_to_tensor(val, dtype=backend.floatx())
+ for val in inputs_batch
+ ]
+
+ if len(inputs_batch) == 1:
+ if model._expects_training_arg:
+ batch_outs = model.call(inputs_batch[0], training=False)
+ else:
+ batch_outs = model.call(inputs_batch[0])
+ else:
+ if model._expects_training_arg:
+ batch_outs = model.call(inputs_batch, training=False)
+ else:
+ batch_outs = model.call(inputs_batch)
+
+ if not isinstance(batch_outs, list):
+ batch_outs = [batch_outs]
+ if batch_index == 0:
+ # Pre-allocate the results arrays.
+ for batch_out in batch_outs:
+ dims = batch_out.shape[1:].dims
+ dims_list = [d.value for d in dims]
+ shape = (num_samples,) + tuple(dims_list)
+ outs.append(np.zeros(shape, dtype=batch_out.dtype.as_numpy_dtype))
+ for i, batch_out in enumerate(batch_outs):
+ outs[i][batch_start:batch_end] = batch_out
+ if verbose == 1:
+ progbar.update(batch_end)
+
+ if len(outs) == 1:
+ return outs[0]
+ return outs
+
+
def slice_arrays(arrays, indices, contiguous=True):
"""Slices batches out of provided arrays (workaround for eager tensors).
@@ -268,19 +814,24 @@ def train_on_batch(model, inputs, targets, sample_weights=None):
Returns:
total loss and the loss associated with each output.
"""
- inputs = [
- ops.convert_to_tensor(val, dtype=backend.floatx()) for val in inputs]
- targets = [
- ops.convert_to_tensor(val, dtype=backend.floatx()) for val in targets]
- sample_weights = [
- ops.convert_to_tensor(val, dtype=backend.floatx())
- if val is not None else None for val in sample_weights]
+ if len(inputs) and not tensor_util.is_tensor(inputs[0]):
+ inputs = [
+ ops.convert_to_tensor(val, dtype=backend.floatx()) for val in inputs
+ ]
+ targets = [
+ ops.convert_to_tensor(val, dtype=backend.floatx()) for val in targets
+ ]
+ if sample_weights:
+ sample_weights = [
+ ops.convert_to_tensor(val, dtype=backend.floatx())
+ if val is not None else None for val in sample_weights
+ ]
+
outs, loss, _ = _process_single_batch(
model, inputs, targets, sample_weights=sample_weights, training=True)
if not isinstance(outs, list):
outs = [outs]
- metrics_results = _eager_metrics_fn(
- model, outs, targets)
+ metrics_results = _eager_metrics_fn(model, outs, targets)
if not isinstance(loss, list):
loss = [loss]
return loss + metrics_results
@@ -298,48 +849,55 @@ def test_on_batch(model, inputs, targets, sample_weights=None):
Returns:
total loss, loss and metrics associated with each output.
"""
- inputs = [
- ops.convert_to_tensor(val, dtype=backend.floatx()) for val in inputs]
- targets = [
- ops.convert_to_tensor(val, dtype=backend.floatx()) for val in targets]
- sample_weights = [
- ops.convert_to_tensor(val, dtype=backend.floatx())
- if val is not None else None for val in sample_weights]
- outs, loss, loss_metrics = _process_single_batch(
+ if len(inputs) and not tensor_util.is_tensor(inputs[0]):
+ inputs = [
+ ops.convert_to_tensor(val, dtype=backend.floatx()) for val in inputs
+ ]
+ targets = [
+ ops.convert_to_tensor(val, dtype=backend.floatx()) for val in targets
+ ]
+ if sample_weights:
+ sample_weights = [
+ ops.convert_to_tensor(val, dtype=backend.floatx())
+ if val is not None else None for val in sample_weights
+ ]
+ outs, loss, loss_metrics = _model_loss(
model, inputs, targets, sample_weights=sample_weights, training=False)
if not isinstance(outs, list):
outs = [outs]
- metrics_results = _eager_metrics_fn(
- model, outs, targets)
+ metrics_results = _eager_metrics_fn(model, outs, targets)
if not isinstance(loss, list):
loss = [loss]
return loss + loss_metrics + metrics_results
-def fit_loop(
- model,
- inputs,
- targets,
- sample_weights=None,
- val_inputs=None,
- val_targets=None,
- val_sample_weights=None,
- batch_size=None,
- epochs=100,
- verbose=1,
- callbacks=None,
- shuffle=True,
- callback_metrics=None,
- initial_epoch=0,
- steps_per_epoch=None,
- validation_steps=None):
- """Abstract fit function for eager execution.
+def fit_loop(model,
+ inputs,
+ targets,
+ sample_weights=None,
+ class_weight=None,
+ val_inputs=None,
+ val_targets=None,
+ val_sample_weights=None,
+ batch_size=None,
+ epochs=1,
+ verbose=1,
+ callbacks=None,
+ shuffle=True,
+ callback_metrics=None,
+ initial_epoch=0,
+ steps_per_epoch=None,
+ validation_steps=None):
+ """Fit function for eager execution.
Arguments:
model: Instance of the model that is being executed in Eager mode.
inputs: List of input arrays.
targets: List of target arrays.
sample_weights: Optional list of sample weight arrays.
+ class_weight: Optional class-weight array to weight the importance of
+ samples in `inputs` based on the class they belong to, as conveyed by
+ `targets`.
val_inputs: Input data for validation.
val_targets: Target data for validation.
val_sample_weights: Sample weight data for validation.
@@ -366,47 +924,40 @@ def fit_loop(
Raises:
ValueError: In case of invalid argument values.
"""
- if not batch_size:
- raise ValueError('With eager execution, `batch_size` should be specified.')
- if steps_per_epoch or validation_steps:
- raise ValueError('With eager execution, `steps_per_epoch` and '
- '`validation_steps` are not valid arguments '
- '(set `batch_size` instead).')
- # Required for Eager mode
+ # Required for eager execution
with backend.learning_phase_scope(1):
do_validation = False
if val_inputs:
do_validation = True
- if (verbose and inputs and hasattr(inputs[0], 'shape') and
- hasattr(val_inputs[0], 'shape')):
+ if (steps_per_epoch is None and verbose and inputs and
+ hasattr(inputs[0], 'shape') and hasattr(val_inputs[0], 'shape')):
print('Train on %d samples, validate on %d samples' %
(inputs[0].shape[0], val_inputs[0].shape[0]))
- if validation_steps:
- if steps_per_epoch is None:
- raise ValueError('Can only use `validation_steps` when doing step-wise '
- 'training, i.e. `steps_per_epoch` must be set.')
- do_validation = True
- out_labels = model.metrics_names
- if do_validation:
- callback_metrics = copy.copy(out_labels) + [
- 'val_' + n for n in out_labels
- ]
- else:
- callback_metrics = copy.copy(out_labels)
+ num_train_samples = None
+ out_labels = None
+ if steps_per_epoch is None or model._is_compiled:
+ out_labels = model.metrics_names
+ if do_validation:
+ callback_metrics = copy.copy(out_labels) + [
+ 'val_' + n for n in out_labels
+ ]
+ else:
+ callback_metrics = copy.copy(out_labels)
- if sample_weights:
- feed_data = inputs + targets + sample_weights
- else:
- feed_data = inputs + targets
- num_train_samples = training_utils.check_num_samples(
- feed_data,
- batch_size=batch_size,
- steps=steps_per_epoch,
- steps_name='steps_per_epoch')
+ if steps_per_epoch is None:
+ if sample_weights:
+ feed_data = inputs + targets + sample_weights
+ else:
+ feed_data = inputs + targets
+ num_train_samples = training_utils.check_num_samples(
+ feed_data,
+ batch_size=batch_size,
+ steps=steps_per_epoch,
+ steps_name='steps_per_epoch')
- if num_train_samples is not None:
- index_array = np.arange(num_train_samples)
+ if num_train_samples is not None:
+ index_array = np.arange(num_train_samples)
model.history = cbks.History()
callbacks = [cbks.BaseLogger()] + (callbacks or []) + [model.history]
@@ -441,6 +992,8 @@ def fit_loop(
for cbk in callbacks:
if not val_inputs:
cbk.validation_data = []
+ elif isinstance(val_inputs, iterator_ops.EagerIterator):
+ cbk.validation_data = val_inputs
elif val_sample_weights:
cbk.validation_data = val_inputs + val_targets + val_sample_weights
else:
@@ -449,87 +1002,48 @@ def fit_loop(
for epoch in range(initial_epoch, epochs):
callbacks.on_epoch_begin(epoch)
epoch_logs = {}
- if shuffle == 'batch':
- index_array = model._batch_shuffle(index_array, batch_size)
- elif shuffle:
- np.random.shuffle(index_array)
-
- batches = generic_utils.make_batches(num_train_samples, batch_size)
-
- for batch_index, (batch_start, batch_end) in enumerate(batches):
- batch_ids = index_array[batch_start:batch_end]
- try:
- inputs_batch = slice_arrays(inputs, batch_ids,
- contiguous=not shuffle)
- targets_batch = slice_arrays(targets, batch_ids,
- contiguous=not shuffle)
- if sample_weights:
- sample_weights_batch = slice_arrays(sample_weights, batch_ids,
- contiguous=not shuffle)
- else:
- sample_weights_batch = None
- except TypeError:
- raise TypeError('TypeError while preparing batch. '
- 'If using HDF5 input data, '
- 'pass shuffle="batch".')
- batch_logs = {}
- batch_logs['batch'] = batch_index
- batch_logs['size'] = len(batch_ids)
-
- callbacks.on_batch_begin(batch_index, batch_logs)
-
- inputs_batch = [
- ops.convert_to_tensor(val, dtype=backend.floatx())
- for val in inputs_batch]
- targets_batch = [
- ops.convert_to_tensor(val, dtype=backend.floatx())
- for val in targets_batch]
- if sample_weights:
- sample_weights_batch = [
- ops.convert_to_tensor(val, dtype=backend.floatx())
- if val is not None else None
- for val in sample_weights_batch]
-
- outs, loss, loss_metrics = _process_single_batch(
+
+ if steps_per_epoch is not None:
+ iterator_fit_loop(
model,
- inputs_batch,
- targets_batch,
- sample_weights=sample_weights_batch,
- training=True)
-
- if not isinstance(outs, list):
- outs = [outs]
-
- for l, o in zip(out_labels, outs):
- batch_logs[l] = o
- # Required for Eager mode
- metrics_results = _eager_metrics_fn(model, outs, targets_batch)
- batch_logs['loss'] = tensor_util.constant_value(backend.mean(loss))
-
- for k, v in zip(model.metrics_names,
- [backend.mean(loss)] + loss_metrics + metrics_results):
- batch_logs[k] = tensor_util.constant_value(v)
- callbacks.on_batch_end(batch_index, batch_logs)
- if callback_model.stop_training:
- break
-
- if batch_index == len(batches) - 1: # Last batch.
- if do_validation:
- val_outs = test_loop(
- model, val_inputs, val_targets,
- sample_weights=val_sample_weights,
- batch_size=batch_size,
- verbose=0)
- if not isinstance(val_outs, list):
- val_outs = [val_outs]
- # Same labels assumed.
- for l, o in zip(out_labels, val_outs):
- epoch_logs['val_' + l] = o
+ inputs,
+ class_weight,
+ steps_per_epoch=steps_per_epoch,
+ callback_model=callback_model,
+ out_labels=out_labels,
+ epoch_logs=epoch_logs,
+ val_inputs=val_inputs,
+ val_targets=val_targets,
+ val_sample_weights=val_sample_weights,
+ epochs=epochs,
+ verbose=verbose,
+ callbacks=callbacks,
+ callback_metrics=callback_metrics,
+ validation_steps=validation_steps,
+ do_validation=do_validation)
+ else:
+ batch_fit_loop(
+ model,
+ inputs,
+ targets,
+ epoch_logs=epoch_logs,
+ index_array=index_array,
+ out_labels=out_labels,
+ callback_model=callback_model,
+ batch_size=batch_size,
+ sample_weights=sample_weights,
+ val_inputs=val_inputs,
+ val_targets=val_targets,
+ val_sample_weights=val_sample_weights,
+ callbacks=callbacks,
+ shuffle=shuffle,
+ num_train_samples=num_train_samples,
+ do_validation=do_validation)
callbacks.on_epoch_end(epoch, epoch_logs)
if callback_model.stop_training:
break
- callbacks.on_train_end()
- return model.history
+ callbacks.on_train_end()
+ return model.history
def test_loop(model, inputs, targets,
@@ -537,7 +1051,7 @@ def test_loop(model, inputs, targets,
batch_size=None,
verbose=0,
steps=None):
- """Abstract method to loop over some data in batches.
+ """Test function for eager execution.
Arguments:
model: Model instance that is being evaluated in Eager mode.
@@ -557,77 +1071,26 @@ def test_loop(model, inputs, targets,
the display labels for the scalar outputs.
"""
with backend.learning_phase_scope(0):
- feed_data = inputs + targets
- if sample_weights:
- feed_data += sample_weights
- num_samples = training_utils.check_num_samples(
- feed_data, batch_size=batch_size, steps=steps, steps_name='steps')
- outs = []
- if verbose == 1:
- progbar = generic_utils.Progbar(target=num_samples)
- batches = generic_utils.make_batches(num_samples, batch_size)
- index_array = np.arange(num_samples)
- for batch_index, (batch_start, batch_end) in enumerate(batches):
- batch_ids = index_array[batch_start:batch_end]
- inputs_batch = slice_arrays(inputs, batch_ids)
- targets_batch = slice_arrays(targets, batch_ids)
- if sample_weights:
- sample_weights_batch = slice_arrays(sample_weights, batch_ids)
- else:
- sample_weights_batch = None
-
- inputs_batch = [
- ops.convert_to_tensor(val, dtype=backend.floatx())
- for val in inputs_batch]
- targets_batch = [
- ops.convert_to_tensor(val, dtype=backend.floatx())
- for val in targets_batch]
- if sample_weights:
- sample_weights_batch = [
- ops.convert_to_tensor(val, dtype=backend.floatx())
- if val is not None else None
- for val in sample_weights_batch]
-
- loss_outs, loss, loss_metrics = _model_loss(
+ if steps is not None:
+ return iterator_test_loop(model, inputs, steps, verbose=verbose)
+ else:
+ return batch_test_loop(
model,
- inputs_batch,
- targets_batch,
- sample_weights=sample_weights_batch,
- training=False)
- metrics_results = _eager_metrics_fn(model, loss_outs, targets_batch)
- batch_outs = []
- for _, v in zip(model.metrics_names,
- [backend.mean(loss)] + loss_metrics + metrics_results):
- batch_outs.append(tensor_util.constant_value(v))
-
- if isinstance(batch_outs, list):
- if batch_index == 0:
- for batch_out in enumerate(batch_outs):
- outs.append(0.)
- for i, batch_out in enumerate(batch_outs):
- outs[i] += batch_out * len(batch_ids)
- else:
- if batch_index == 0:
- outs.append(0.)
- outs[0] += batch_outs * len(batch_ids)
-
- if verbose == 1:
- progbar.update(batch_end)
- for i in range(len(outs)):
- outs[i] /= num_samples
- if len(outs) == 1:
- return outs[0]
- return outs
+ inputs,
+ targets,
+ batch_size=batch_size,
+ sample_weights=sample_weights,
+ verbose=verbose)
def predict_loop(model, inputs,
batch_size=32,
verbose=0,
steps=None):
- """Abstract method to loop over some data in batches.
+ """Predict function for eager execution.
Arguments:
- model:
+ model: Instance of `Model`.
inputs: List of input arrays.
batch_size: integer batch size.
verbose: verbosity mode.
@@ -641,49 +1104,8 @@ def predict_loop(model, inputs,
(if the model has multiple outputs).
"""
with backend.learning_phase_scope(0):
- num_samples = training_utils.check_num_samples(
- inputs, batch_size, steps, 'steps')
- if verbose == 1:
- if steps is not None:
- progbar = generic_utils.Progbar(target=steps)
- else:
- progbar = generic_utils.Progbar(target=num_samples)
-
- outs = []
- batches = generic_utils.make_batches(num_samples, batch_size)
- index_array = np.arange(num_samples)
- for batch_index, (batch_start, batch_end) in enumerate(batches):
- batch_ids = index_array[batch_start:batch_end]
- inputs_batch = slice_arrays(inputs, batch_ids)
-
- inputs_batch = [
- ops.convert_to_tensor(val, dtype=backend.floatx())
- for val in inputs_batch]
-
- if len(inputs_batch) == 1:
- if model._expects_training_arg:
- batch_outs = model.call(inputs_batch[0], training=False)
- else:
- batch_outs = model.call(inputs_batch[0])
- else:
- if model._expects_training_arg:
- batch_outs = model.call(inputs_batch, training=False)
- else:
- batch_outs = model.call(inputs_batch)
-
- if not isinstance(batch_outs, list):
- batch_outs = [batch_outs]
- if batch_index == 0:
- # Pre-allocate the results arrays.
- for batch_out in batch_outs:
- dims = batch_out.shape[1:].dims
- dims_list = [d.value for d in dims]
- shape = (num_samples,) + tuple(dims_list)
- outs.append(np.zeros(shape, dtype=batch_out.dtype.as_numpy_dtype))
- for i, batch_out in enumerate(batch_outs):
- outs[i][batch_start:batch_end] = batch_out
- if verbose == 1:
- progbar.update(batch_end)
- if len(outs) == 1:
- return outs[0]
- return outs
+ if steps is not None:
+ return iterator_predict_loop(model, inputs, steps, verbose=verbose)
+ else:
+ return batch_predict_loop(
+ model, inputs, batch_size=batch_size, verbose=verbose)
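The eager training entry points above now dispatch on whether a step count is supplied: a step-driven path that pulls batches from an iterator, and a fallback that slices in-memory arrays by `batch_size`. A minimal pure-Python sketch of that dispatch, using hypothetical names (`run_loop`, `data`) rather than the TensorFlow API itself:

    def run_loop(data, steps=None, batch_size=32):
        """Dispatch between a step-driven loop and a batch-sliced loop."""
        if steps is not None:
            # Step-driven path: the data object yields batches itself
            # (e.g. an iterator), so we just pull `steps` batches from it.
            return [next(data) for _ in range(steps)]
        # Batch-sliced path: the data is an indexable array-like, so we
        # slice it into `batch_size`-sized chunks ourselves.
        return [data[i:i + batch_size] for i in range(0, len(data), batch_size)]

    # Usage: an iterator requires `steps`, a plain list does not.
    print(run_loop(iter([[1, 2], [3, 4]]), steps=2))    # [[1, 2], [3, 4]]
    print(run_loop(list(range(10)), batch_size=4))      # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]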
diff --git a/tensorflow/python/keras/_impl/keras/engine/training_test.py b/tensorflow/python/keras/_impl/keras/engine/training_test.py
index 58011a1412..cc2386a5bd 100644
--- a/tensorflow/python/keras/_impl/keras/engine/training_test.py
+++ b/tensorflow/python/keras/_impl/keras/engine/training_test.py
@@ -24,6 +24,7 @@ import unittest
import numpy as np
from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util as tf_test_util
from tensorflow.python.keras._impl import keras
@@ -1340,16 +1341,12 @@ class TestTrainingWithDataTensors(test.TestCase):
output_a_np)
# test fit
- out = model.fit(None,
- output_a_np, epochs=1, batch_size=10)
- out = model.fit(None,
- output_a_np, epochs=1, batch_size=10)
+ _ = model.fit(None, output_a_np, epochs=1, steps_per_epoch=3)
+ _ = model.fit(None, output_a_np, epochs=1, steps_per_epoch=3)
# test evaluate
- out = model.evaluate(None,
- output_a_np, batch_size=10)
- out = model.evaluate(None,
- output_a_np, batch_size=10)
+ _ = model.evaluate(None, output_a_np, steps=3)
+ _ = model.evaluate(None, output_a_np, steps=3)
# test predict
out = model.predict(None, steps=3)
@@ -1383,16 +1380,12 @@ class TestTrainingWithDataTensors(test.TestCase):
output_a_np)
# test fit
- out = model.fit(None,
- output_a_np, epochs=1, batch_size=10)
- out = model.fit(None,
- output_a_np, epochs=1, batch_size=10)
+ _ = model.fit(None, output_a_np, epochs=1, steps_per_epoch=10)
+ _ = model.fit(None, output_a_np, epochs=1, steps_per_epoch=10)
# test evaluate
- out = model.evaluate(None,
- output_a_np, batch_size=10)
- out = model.evaluate(None,
- output_a_np, batch_size=10)
+ _ = model.evaluate(None, output_a_np, steps=10)
+ _ = model.evaluate(None, output_a_np, steps=10)
# test predict
out = model.predict(None, steps=3)
@@ -1715,40 +1708,56 @@ class TestTrainingWithDataTensors(test.TestCase):
class TestTrainingWithDatasetIterators(test.TestCase):
+ @tf_test_util.run_in_graph_and_eager_modes()
def test_training_and_eval_methods_on_iterators_single_io(self):
with self.test_session():
x = keras.layers.Input(shape=(3,), name='input')
y = keras.layers.Dense(4, name='dense')(x)
model = keras.Model(x, y)
- optimizer = 'rmsprop'
+ optimizer = RMSPropOptimizer(learning_rate=0.001)
loss = 'mse'
metrics = ['mae']
model.compile(optimizer, loss, metrics=metrics)
- inputs = np.zeros((10, 3))
- targets = np.zeros((10, 4))
+ inputs = np.zeros((10, 3), dtype=np.float32)
+ targets = np.zeros((10, 4), dtype=np.float32)
dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
dataset = dataset.repeat(100)
dataset = dataset.batch(10)
iterator = dataset.make_one_shot_iterator()
- model.fit(iterator, epochs=1, steps_per_epoch=2, verbose=0)
- model.evaluate(iterator, steps=2, verbose=0)
+ model.fit(iterator, epochs=1, steps_per_epoch=2, verbose=1)
+ model.evaluate(iterator, steps=2, verbose=1)
model.predict(iterator, steps=2)
model.train_on_batch(iterator)
model.test_on_batch(iterator)
+ model.predict_on_batch(iterator)
+
# Test with validation data
model.fit(iterator,
epochs=1, steps_per_epoch=2, verbose=0,
validation_data=iterator, validation_steps=2)
# Test with validation split
- with self.assertRaisesRegexp(ValueError,
- 'you cannot use `validation_split`'):
+ with self.assertRaisesRegexp(
+ ValueError, '`validation_split` argument is not supported '
+ 'when input `x` is a dataset iterator'):
model.fit(iterator,
epochs=1, steps_per_epoch=2, verbose=0,
validation_split=0.5, validation_steps=2)
+ # Test with sample weight.
+ sample_weight = np.random.random((10,))
+ with self.assertRaisesRegexp(
+ ValueError, '`sample_weight` argument is not supported '
+ 'when input `x` is a dataset iterator'):
+ model.fit(
+ iterator,
+ epochs=1,
+ steps_per_epoch=2,
+ verbose=0,
+ sample_weight=sample_weight)
+
# Test invalid usage
with self.assertRaisesRegexp(ValueError,
'Instead, pass an `Iterator`'):
@@ -1759,19 +1768,54 @@ class TestTrainingWithDatasetIterators(test.TestCase):
model.fit(iterator, iterator,
epochs=1, steps_per_epoch=2, verbose=0)
+ with self.assertRaisesRegexp(
+ ValueError, 'you should specify the `steps_per_epoch` argument'):
+ model.fit(iterator, epochs=1, verbose=0)
+ with self.assertRaisesRegexp(ValueError,
+ 'you should specify the `steps` argument'):
+ model.evaluate(iterator, verbose=0)
+ with self.assertRaisesRegexp(ValueError,
+ 'you should specify the `steps` argument'):
+ model.predict(iterator, verbose=0)
+
+ def test_get_next_op_created_once(self):
+ with self.test_session():
+ x = keras.layers.Input(shape=(3,), name='input')
+ y = keras.layers.Dense(4, name='dense')(x)
+ model = keras.Model(x, y)
+
+ optimizer = RMSPropOptimizer(learning_rate=0.001)
+ loss = 'mse'
+ metrics = ['mae']
+ model.compile(optimizer, loss, metrics=metrics)
+
+ inputs = np.zeros((10, 3), dtype=np.float32)
+ targets = np.zeros((10, 4), dtype=np.float32)
+ dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
+ dataset = dataset.repeat(100)
+ dataset = dataset.batch(10)
+ iterator = dataset.make_one_shot_iterator()
+
+ model.fit(iterator, epochs=1, steps_per_epoch=2, verbose=1)
+ # Finalize graph to make sure we are not appending another iterator
+ # get_next op in the graph.
+ ops.get_default_graph().finalize()
+ model.fit(iterator, epochs=1, steps_per_epoch=2, verbose=1)
+
+ @tf_test_util.run_in_graph_and_eager_modes()
def test_iterators_running_out_of_data(self):
with self.test_session():
x = keras.layers.Input(shape=(3,), name='input')
y = keras.layers.Dense(4, name='dense')(x)
model = keras.Model(x, y)
- optimizer = 'rmsprop'
+ optimizer = RMSPropOptimizer(learning_rate=0.001)
loss = 'mse'
metrics = ['mae']
model.compile(optimizer, loss, metrics=metrics)
- inputs = np.zeros((10, 3))
- targets = np.zeros((10, 4))
+ inputs = np.zeros((10, 3), dtype=np.float32)
+ targets = np.zeros((10, 4), dtype=np.float32)
dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
dataset = dataset.repeat(2)
dataset = dataset.batch(10)
diff --git a/tensorflow/python/keras/_impl/keras/engine/training_utils.py b/tensorflow/python/keras/_impl/keras/engine/training_utils.py
index 662938f421..04d80c891f 100644
--- a/tensorflow/python/keras/_impl/keras/engine/training_utils.py
+++ b/tensorflow/python/keras/_impl/keras/engine/training_utils.py
@@ -22,6 +22,7 @@ import copy
import numpy as np
+from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.eager import context
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras._impl.keras import backend as K
@@ -65,14 +66,7 @@ def check_num_samples(ins,
if steps is not None and batch_size is not None:
raise ValueError(
'If ' + steps_name + ' is set, the `batch_size` must be None.')
-
- if not ins or has_symbolic_tensors(ins):
- if steps is None:
- raise ValueError('If your data is in the form of symbolic tensors, '
- 'you should specify the `' + steps_name + '` argument '
- '(instead of the `batch_size` argument, '
- 'because symbolic tensors are expected to produce '
- 'batches of input data).')
+ if check_steps_argument(ins, steps, steps_name):
return None
if hasattr(ins[0], 'shape'):
return int(ins[0].shape[0])
@@ -551,8 +545,11 @@ def standardize_weights(y,
def has_symbolic_tensors(ls):
- return (any(tensor_util.is_tensor(v) for v in ls)
- and not context.executing_eagerly())
+ if context.executing_eagerly():
+ return False
+ if isinstance(ls, (list, tuple)):
+ return any(tensor_util.is_tensor(v) for v in ls)
+ return tensor_util.is_tensor(ls)
def populate_metric_names(model):
@@ -614,3 +611,77 @@ def add_metric_name(model, metric_name, index):
metric_name = '%s_%d' % (base_metric_name, j)
j += 1
model.metrics_names.append(metric_name)
+
+
+def validate_iterator_input(x, y, sample_weight, validation_split=None):
+ """Validates user input arguments when a dataset iterator is passed.
+
+ Arguments:
+ x: Input data. A `tf.data` dataset iterator.
+ y: Target data. It could be either Numpy array(s) or TensorFlow tensor(s).
+ Expected to be `None` when `x` is a dataset iterator.
+ sample_weight: An optional sample-weight array passed by the user to
+ weight the importance of each sample in `x`. Expected to be `None` when
+ `x` is a dataset iterator.
+ validation_split: Float between 0 and 1. Fraction of the training data to
+ be used as validation data. Expected to be `None` when `x` is a dataset
+ iterator.
+
+ Raises:
+ ValueError: if `y`, `sample_weight`, or `validation_split` is provided by
+ the user.
+ """
+ if y is not None:
+ raise ValueError('You passed a dataset iterator (%s) as input `x` to '
+ 'your model. In that case, you should not specify '
+ 'a target (`y`) argument, since the dataset iterator '
+ 'generates both input data and target data. '
+ 'Received: %s' % (x, y))
+ if sample_weight is not None:
+ raise ValueError('`sample_weight` argument is not supported when input'
+ ' `x` is a dataset iterator. '
+ 'Received: x=%s, sample_weight=%s' % (x, sample_weight))
+ if validation_split is not None and validation_split != 0.0:
+ raise ValueError(
+ '`validation_split` argument is not supported when '
+ 'input `x` is a dataset iterator. '
+ 'Received: x=%s, validation_split=%f' % (x, validation_split))
+
+
+def check_steps_argument(input_data, steps, steps_name):
+ """Validates `steps` argument based on input data's type.
+
+ A `steps` value must be provided when:
+ 1. the input data passed is an iterator.
+ 2. the model was built on top of symbolic tensors, so input data is not
+ required and is `None`.
+ 3. the input data passed is a symbolic tensor.
+
+ Arguments:
+ input_data: Input data. Can be Numpy array(s) or TensorFlow tensor(s) or
+ tf.data.Dataset iterator or `None`.
+ steps: Integer or `None`. Total number of steps (batches of samples) to
+ execute.
+ steps_name: The public API's parameter name for `steps`.
+
+ Returns:
+ boolean, True if `steps` argument is required, else False.
+
+ Raises:
+ ValueError: if `steps` argument is required for given input data type
+ but not provided.
+ """
+
+ is_x_iterator = (
+ isinstance(input_data, iterator_ops.Iterator) or
+ isinstance(input_data, iterator_ops.EagerIterator))
+
+ if (input_data is None or is_x_iterator or has_symbolic_tensors(input_data) or
+ (isinstance(input_data, list) and not input_data)):
+ if steps is None:
+ input_type_str = 'iterators' if is_x_iterator else 'data tensors'
+ raise ValueError('When using {input_type} as input to a model, you should'
+ ' specify the `{steps_name}` argument.'.format(
+ input_type=input_type_str, steps_name=steps_name))
+ return True
+ return False
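A small standalone sketch of the `check_steps_argument` decision above: `steps` becomes mandatory whenever the input cannot report its own sample count. The `check_steps_argument_sketch` name and the simplified `is_iterator` flag are illustrative assumptions, not the helper's real signature:

    def check_steps_argument_sketch(input_data, steps, steps_name, is_iterator=False):
        """Pure-Python mirror of the decision above: `steps` is required whenever
        the input cannot report its own sample count (iterators, symbolic
        tensors, or no input at all)."""
        steps_required = input_data is None or is_iterator or input_data == []
        if steps_required and steps is None:
            input_type = 'iterators' if is_iterator else 'data tensors'
            raise ValueError('When using %s as input to a model, you should '
                             'specify the `%s` argument.' % (input_type, steps_name))
        return steps_required

    # An iterator-style input passes when `steps` is given, raises when it is not.
    check_steps_argument_sketch(object(), steps=2, steps_name='steps_per_epoch',
                                is_iterator=True)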
diff --git a/tensorflow/python/keras/_impl/keras/model_subclassing_test.py b/tensorflow/python/keras/_impl/keras/model_subclassing_test.py
index 295ad47f6b..1e88dc09fb 100644
--- a/tensorflow/python/keras/_impl/keras/model_subclassing_test.py
+++ b/tensorflow/python/keras/_impl/keras/model_subclassing_test.py
@@ -23,12 +23,15 @@ import os
import numpy as np
import six
+from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import context
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.keras._impl import keras
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import test
+from tensorflow.python.training import checkpointable
from tensorflow.python.training.rmsprop import RMSPropOptimizer
try:
@@ -248,6 +251,26 @@ class ModelSubclassingTest(test.TestCase):
model.fit([x1, x2], [y1, y2], epochs=2, steps_per_epoch=10, verbose=0)
_ = model.evaluate(steps=10, verbose=0)
+ @test_util.run_in_graph_and_eager_modes()
+ def test_single_io_workflow_with_dataset_iterators(self):
+ num_classes = 2
+ num_samples = 10
+ input_dim = 50
+
+ with self.test_session():
+ model = SimpleTestModel(num_classes=num_classes, use_dp=True, use_bn=True)
+ model.compile(loss='mse', optimizer=RMSPropOptimizer(learning_rate=0.001))
+
+ x = np.ones((num_samples, input_dim))
+ y = np.zeros((num_samples, num_classes))
+ dataset = dataset_ops.Dataset.from_tensor_slices((x, y))
+ dataset = dataset.repeat(100)
+ dataset = dataset.batch(10)
+ iterator = dataset.make_one_shot_iterator()
+
+ model.fit(iterator, epochs=2, steps_per_epoch=10, verbose=0)
+ _ = model.evaluate(iterator, steps=10, verbose=0)
+
def test_multi_io_workflow_with_numpy_arrays_and_custom_placeholders(self):
num_classes = (2, 3)
@@ -583,6 +606,22 @@ class ModelSubclassingTest(test.TestCase):
loss = model.train_on_batch(x, y)
self.assertGreater(loss, 0.1)
+ def test_no_dependency(self):
+ class Foo(keras.Model):
+
+ def __init__(self):
+ super(Foo, self).__init__()
+ self.isdep = keras.layers.Dense(1)
+ self.notdep = checkpointable.NoDependency(keras.layers.Dense(2))
+ self.notdep_var = checkpointable.NoDependency(
+ resource_variable_ops.ResourceVariable(1., name='notdep_var'))
+
+ m = Foo()
+ self.assertEqual([m.isdep, m.notdep], m.layers)
+ self.assertEqual(1, len(m._checkpoint_dependencies))
+ self.assertIs(m.isdep, m._checkpoint_dependencies[0].ref)
+ self.assertEqual('notdep_var:0', m.notdep_var.name)
+
class CustomCallModel(keras.Model):
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index 77e6f5f1a0..843759fed0 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -1847,6 +1847,23 @@ class ControlFlowTest(test.TestCase):
r = control_flow_ops.cond(math_ops.less(1, 2), fn1, lambda: x)
self.assertAllClose(9.0, r.eval(feed_dict={x: 1.0}))
+ def testGradInWhileWrtInitialLoopVal(self):
+ with self.test_session():
+ x = array_ops.placeholder(dtypes.float32, shape=(), name="x")
+ y = x + 1
+
+ def body(i, v):
+ z = v * 2
+ return i + 1, gradients_impl.gradients(z, x)[0]
+
+ with self.assertRaisesRegexp(
+ ValueError,
+ "Cannot compute gradient inside while loop with respect to op 'x'. "
+ "We do not support taking the gradient wrt or through the initial "
+ "value of a loop variable. Gradients can be computed through "
+ "loop invariants or wrt the input parameters to the loop body."):
+ control_flow_ops.while_loop(lambda i, x: i < 3, body, [0, y])
+
def testWhileGradInWhile(self):
with self.test_session():
n = ops.convert_to_tensor(1.0, name="n")
diff --git a/tensorflow/python/kernel_tests/conv2d_transpose_test.py b/tensorflow/python/kernel_tests/conv2d_transpose_test.py
index b692d3da60..27804be65c 100644
--- a/tensorflow/python/kernel_tests/conv2d_transpose_test.py
+++ b/tensorflow/python/kernel_tests/conv2d_transpose_test.py
@@ -23,6 +23,7 @@ from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import nn_ops
@@ -292,6 +293,7 @@ class Conv2DTransposeTest(test.TestCase):
self.assertAllClose(cache_values, value)
+ @test_util.enable_c_shapes
def testConv2DTransposeShapeInference(self):
# Test case for 8972
initializer = random_ops.truncated_normal(
@@ -301,7 +303,8 @@ class Conv2DTransposeTest(test.TestCase):
f_shape = array_ops.stack([array_ops.shape(x)[0], 10, 5, 5])
output = nn_ops.conv2d_transpose(
x, f, f_shape, strides=[1, 1, 1, 1], padding="SAME")
- self.assertEqual(output.get_shape().as_list(), [None, 10, 5, 5])
+ self.assertEqual(output.get_shape().as_list(), [3, 10, 5, 5])
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/kernel_tests/distributions/util_test.py b/tensorflow/python/kernel_tests/distributions/util_test.py
index 5ec80b95ee..dc465c867f 100644
--- a/tensorflow/python/kernel_tests/distributions/util_test.py
+++ b/tensorflow/python/kernel_tests/distributions/util_test.py
@@ -147,6 +147,32 @@ class AssertCloseTest(test.TestCase):
array_ops.identity(w).eval(feed_dict=feed_dict)
+class MaybeGetStaticTest(test.TestCase):
+
+ def testGetStaticInt(self):
+ x = 2
+ self.assertEqual(x, du.maybe_get_static_value(x))
+ self.assertAllClose(
+ np.array(2.), du.maybe_get_static_value(x, dtype=np.float64))
+
+ def testGetStaticNumpyArray(self):
+ x = np.array(2, dtype=np.int32)
+ self.assertEqual(x, du.maybe_get_static_value(x))
+ self.assertAllClose(
+ np.array(2.), du.maybe_get_static_value(x, dtype=np.float64))
+
+ def testGetStaticConstant(self):
+ x = constant_op.constant(2, dtype=dtypes.int32)
+ self.assertEqual(np.array(2, dtype=np.int32), du.maybe_get_static_value(x))
+ self.assertAllClose(
+ np.array(2.), du.maybe_get_static_value(x, dtype=np.float64))
+
+ def testGetStaticPlaceholder(self):
+ x = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
+ self.assertEqual(None, du.maybe_get_static_value(x))
+ self.assertEqual(None, du.maybe_get_static_value(x, dtype=np.float64))
+
+
@test_util.with_c_api
class GetLogitsAndProbsTest(test.TestCase):
diff --git a/tensorflow/python/kernel_tests/linalg/BUILD b/tensorflow/python/kernel_tests/linalg/BUILD
index 052f11f92e..91be80322c 100644
--- a/tensorflow/python/kernel_tests/linalg/BUILD
+++ b/tensorflow/python/kernel_tests/linalg/BUILD
@@ -85,7 +85,10 @@ cuda_py_test(
"//tensorflow/python:platform_test",
],
shard_count = 5,
- tags = ["noasan"], # times out b/63678675
+ tags = [
+ "noasan", # times out, b/63678675
+ "optonly", # times out, b/79171797
+ ],
)
cuda_py_test(
diff --git a/tensorflow/python/kernel_tests/list_ops_test.py b/tensorflow/python/kernel_tests/list_ops_test.py
index 098f9724a2..49855200c2 100644
--- a/tensorflow/python/kernel_tests/list_ops_test.py
+++ b/tensorflow/python/kernel_tests/list_ops_test.py
@@ -43,6 +43,7 @@ def scalar_shape():
return ops.convert_to_tensor([], dtype=dtypes.int32)
+@test_util.with_c_shapes
class ListOpsTest(test_util.TensorFlowTestCase):
@test_util.run_in_graph_and_eager_modes()
diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
index 984192258c..3daf07ea63 100644
--- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py
+++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
@@ -400,6 +400,15 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
resource_variable_ops.var_is_initialized_op(abc.handle)),
True)
+ def testScatterBool(self):
+ with context.eager_mode():
+ ref = resource_variable_ops.ResourceVariable(
+ [False, True, False], trainable=False)
+ indices = math_ops.range(3)
+ updates = constant_op.constant([True, True, True])
+ state_ops.scatter_update(ref, indices, updates)
+ self.assertAllEqual(ref.read_value(), [True, True, True])
+
@test_util.run_in_graph_and_eager_modes()
def testConstraintArg(self):
constraint = lambda x: x
diff --git a/tensorflow/python/kernel_tests/tensor_array_ops_test.py b/tensorflow/python/kernel_tests/tensor_array_ops_test.py
index 918bbd38ed..c0b36f143d 100644
--- a/tensorflow/python/kernel_tests/tensor_array_ops_test.py
+++ b/tensorflow/python/kernel_tests/tensor_array_ops_test.py
@@ -438,7 +438,6 @@ class TensorArrayTest(test.TestCase):
"Tried to read from index 3 but array size is: 3"):
self.evaluate(ta.read(3))
- @test_util.run_in_graph_and_eager_modes()
def testTensorArrayWriteMultipleFails(self):
with self.test_session(use_gpu=True):
ta = tensor_array_ops.TensorArray(
diff --git a/tensorflow/python/ops/check_ops.py b/tensorflow/python/ops/check_ops.py
index 306055d202..cabc1e724c 100644
--- a/tensorflow/python/ops/check_ops.py
+++ b/tensorflow/python/ops/check_ops.py
@@ -1169,19 +1169,35 @@ def _assert_same_base_type(items, expected_type=None):
Raises:
ValueError: If any types do not match.
"""
- original_item_str = None
+ original_expected_type = expected_type
+ mismatch = False
for item in items:
if item is not None:
item_type = item.dtype.base_dtype
if not expected_type:
expected_type = item_type
- original_item_str = item.name if hasattr(item, 'name') else str(item)
elif expected_type != item_type:
- raise ValueError('%s, type=%s, must be of the same type (%s)%s.' % (
- item.name if hasattr(item, 'name') else str(item),
- item_type, expected_type,
- (' as %s' % original_item_str) if original_item_str else ''))
- return expected_type
+ mismatch = True
+ break
+ if mismatch:
+ # Loop back through and build up an informative error message (this is very
+ # slow, so we don't do it unless we found an error above).
+ expected_type = original_expected_type
+ original_item_str = None
+ for item in items:
+ if item is not None:
+ item_type = item.dtype.base_dtype
+ if not expected_type:
+ expected_type = item_type
+ original_item_str = item.name if hasattr(item, 'name') else str(item)
+ elif expected_type != item_type:
+ raise ValueError('%s, type=%s, must be of the same type (%s)%s.' % (
+ item.name if hasattr(item, 'name') else str(item),
+ item_type, expected_type,
+ (' as %s' % original_item_str) if original_item_str else ''))
+ return expected_type # Should be unreachable
+ else:
+ return expected_type
@tf_export('assert_same_float_dtype')
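The `_assert_same_base_type` change above keeps the common case cheap: the first pass only detects a mismatch, and the descriptive message is assembled in a second pass once an error is certain. A generic sketch of that two-pass pattern, with hypothetical names:

    def assert_all_equal(items):
        """Cheap first pass detects a mismatch; the expensive, descriptive
        message is only built when we already know we will raise."""
        expected = None
        mismatch = False
        for item in items:
            if expected is None:
                expected = item
            elif item != expected:
                mismatch = True
                break
        if mismatch:
            # Slow path: re-walk the items to build a helpful message.
            details = ', '.join(repr(i) for i in items)
            raise ValueError('Items differ (expected %r): %s' % (expected, details))
        return expected

    assert_all_equal(['float32', 'float32', 'float32'])  # returns 'float32'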
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index 07d4ff7b02..5f60dab6ac 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -43,6 +43,7 @@ from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_control_flow_ops
from tensorflow.python.ops import gen_data_flow_ops
from tensorflow.python.ops import gen_logging_ops
+from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import tensor_array_ops
# go/tf-wildcard-import
@@ -1433,6 +1434,8 @@ def ZerosLikeOutsideLoop(op, index):
"""Create zeros_like for the specified output of an op."""
val = op.outputs[index]
if not util.IsSwitch(op):
+ if val.dtype == dtypes.resource:
+ return array_ops.zeros(gen_resource_variable_ops.variable_shape(val))
return array_ops.zeros_like(val, optimize=False)
else:
op_ctxt = op._get_control_flow_context()
@@ -1441,6 +1444,10 @@ def ZerosLikeOutsideLoop(op, index):
pred = op_ctxt.pred
branch = op_ctxt.branch
switch_val = switch(op.inputs[0], pred)[1 - branch]
+ if val.dtype == dtypes.resource:
+ with ops.control_dependencies([switch_val]):
+ return array_ops.zeros(
+ gen_resource_variable_ops.variable_shape(switch_val))
zeros_shape = array_ops.shape_internal(switch_val, optimize=False)
# Ensure ops created within array_ops.zeros are dominated by switch in
# cond context.
diff --git a/tensorflow/python/ops/distributions/bijector_impl.py b/tensorflow/python/ops/distributions/bijector_impl.py
index 36eee5ce78..caceadf53a 100644
--- a/tensorflow/python/ops/distributions/bijector_impl.py
+++ b/tensorflow/python/ops/distributions/bijector_impl.py
@@ -33,6 +33,7 @@ from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops.distributions import util as distribution_util
__all__ = [
@@ -527,8 +528,6 @@ class Bijector(object):
ValueError: If a member of `graph_parents` is not a `Tensor`.
"""
self._graph_parents = graph_parents or []
- forward_min_event_ndims = get_static_value(forward_min_event_ndims)
- inverse_min_event_ndims = get_static_value(inverse_min_event_ndims)
if forward_min_event_ndims is None and inverse_min_event_ndims is None:
raise ValueError("Must specify at least one of `forward_min_event_ndims` "
@@ -538,12 +537,23 @@ class Bijector(object):
elif forward_min_event_ndims is None:
forward_min_event_ndims = inverse_min_event_ndims
+ if not isinstance(forward_min_event_ndims, int):
+ raise TypeError("Expected forward_min_event_ndims to be of "
+ "type int, got {}".format(
+ type(forward_min_event_ndims).__name__))
+
+ if not isinstance(inverse_min_event_ndims, int):
+ raise TypeError("Expected inverse_min_event_ndims to be of "
+ "type int, got {}".format(
+ type(inverse_min_event_ndims).__name__))
+
if forward_min_event_ndims < 0:
raise ValueError("forward_min_event_ndims must be a non-negative "
"integer.")
if inverse_min_event_ndims < 0:
raise ValueError("inverse_min_event_ndims must be a non-negative "
"integer.")
+
self._forward_min_event_ndims = forward_min_event_ndims
self._inverse_min_event_ndims = inverse_min_event_ndims
self._is_constant_jacobian = is_constant_jacobian
@@ -994,7 +1004,6 @@ class Bijector(object):
def _reduce_jacobian_det_over_event(
self, y, ildj, min_event_ndims, event_ndims):
"""Reduce jacobian over event_ndims - min_event_ndims."""
- assert_static(min_event_ndims)
if not self.is_constant_jacobian:
return math_ops.reduce_sum(
@@ -1012,7 +1021,7 @@ class Bijector(object):
axis=self._get_event_reduce_dims(min_event_ndims, event_ndims))
# The multiplication by ones can change the inferred static shape so we try
# to recover as much as possible.
- event_ndims_ = get_static_value(event_ndims)
+ event_ndims_ = self._maybe_get_event_ndims_statically(event_ndims)
if (event_ndims_ is not None and
y.shape.ndims is not None and
ildj.shape.ndims is not None):
@@ -1027,8 +1036,7 @@ class Bijector(object):
def _get_event_reduce_dims(self, min_event_ndims, event_ndims):
"""Compute the reduction dimensions given event_ndims."""
- assert_static(min_event_ndims)
- event_ndims_ = get_static_value(event_ndims, np.int32)
+ event_ndims_ = self._maybe_get_event_ndims_statically(event_ndims)
if event_ndims_ is not None:
return [-index for index in range(1, event_ndims_ - min_event_ndims + 1)]
@@ -1038,8 +1046,7 @@ class Bijector(object):
def _check_valid_event_ndims(self, min_event_ndims, event_ndims):
"""Check whether event_ndims is atleast min_event_ndims."""
- assert_static(min_event_ndims)
- event_ndims_ = get_static_value(event_ndims, np.int32)
+ event_ndims_ = self._maybe_get_event_ndims_statically(event_ndims)
assertions = []
if event_ndims_ is not None:
if min_event_ndims > event_ndims_:
@@ -1051,21 +1058,15 @@ class Bijector(object):
check_ops.assert_greater_equal(event_ndims, min_event_ndims)]
return assertions
+ def _maybe_get_event_ndims_statically(self, event_ndims):
+ """Helper which returns tries to return an integer static value."""
+ event_ndims_ = distribution_util.maybe_get_static_value(event_ndims)
-def get_static_value(x, dtype=None):
- """Helper which returns static value; casting when dtype is preferred."""
- if x is None:
- return x
- try:
- x_ = tensor_util.constant_value(x)
- except TypeError:
- x_ = x
- if x_ is None or dtype is None:
- return x_
- return np.array(x_, dtype)
-
+ if isinstance(event_ndims_, np.ndarray):
+ if (event_ndims_.dtype not in (np.int32, np.int64) or
+ len(event_ndims_.shape)):
+ raise ValueError("Expected a scalar integer, got {}".format(
+ event_ndims_))
+ event_ndims_ = event_ndims_.tolist()
-def assert_static(x):
- """Helper which asserts that input arg is known statically."""
- if x is None or type(x) != type(get_static_value(x)): # pylint: disable=unidiomatic-typecheck
- raise TypeError("Input must be known statically.")
+ return event_ndims_
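A standalone sketch of the scalar-integer validation performed by `_maybe_get_event_ndims_statically` above, assuming the value has already been recovered as a NumPy array; `to_scalar_int` is an illustrative name only:

    import numpy as np

    def to_scalar_int(value):
        """Reject non-integer or non-scalar ndarrays, then unwrap to a Python
        int, mirroring the validation shown above."""
        if isinstance(value, np.ndarray):
            if value.dtype not in (np.int32, np.int64) or len(value.shape):
                raise ValueError('Expected a scalar integer, got {}'.format(value))
            value = value.tolist()
        return value

    print(to_scalar_int(np.array(2, dtype=np.int32)))  # -> 2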
diff --git a/tensorflow/python/ops/distributions/util.py b/tensorflow/python/ops/distributions/util.py
index 2e067eab45..3afa85fda0 100644
--- a/tensorflow/python/ops/distributions/util.py
+++ b/tensorflow/python/ops/distributions/util.py
@@ -162,6 +162,30 @@ def same_dynamic_shape(a, b):
lambda: constant_op.constant(False))
+def maybe_get_static_value(x, dtype=None):
+ """Helper which tries to return a static value.
+
+ Given `x`, extract its value statically, optionally casting to a specific
+ dtype. If this is not possible, None is returned.
+
+ Args:
+ x: `Tensor` for which to extract a value statically.
+ dtype: Optional dtype to cast to.
+
+ Returns:
+ Statically inferred value if possible, otherwise None.
+ """
+ if x is None:
+ return x
+ try:
+ x_ = tensor_util.constant_value(x)
+ except TypeError:
+ x_ = x
+ if x_ is None or dtype is None:
+ return x_
+ return np.array(x_, dtype)
+
+
def get_logits_and_probs(logits=None,
probs=None,
multidimensional=False,
diff --git a/tensorflow/python/ops/gradients_impl.py b/tensorflow/python/ops/gradients_impl.py
index 1448151fef..069b5a4308 100644
--- a/tensorflow/python/ops/gradients_impl.py
+++ b/tensorflow/python/ops/gradients_impl.py
@@ -297,7 +297,8 @@ def _DefaultGradYs(grad_ys,
def _IsTrainable(tensor):
dtype = dtypes.as_dtype(tensor.dtype)
return dtype.base_dtype in (dtypes.float16, dtypes.float32, dtypes.float64,
- dtypes.complex64, dtypes.complex128)
+ dtypes.complex64, dtypes.complex128,
+ dtypes.resource)
def _IsBackpropagatable(tensor):
@@ -417,6 +418,30 @@ def _MaybeCompile(scope, op, func, grad_fn):
return grad_fn()
+def _RaiseNoGradWrtInitialLoopValError(op, from_ops):
+ """Raises an error if we backprop through a loop var."""
+ # Find the nearest 'to_op' reachable from 'op' to provide a more helpful error
+ # message.
+ target_op = None
+ queue = collections.deque([op])
+ visited = set()
+ while queue:
+ curr_op = queue.popleft()
+ if curr_op in visited: continue
+ visited.add(curr_op)
+ if curr_op in from_ops:
+ target_op = curr_op
+ break
+ queue.extend(t.op for t in curr_op.inputs)
+ assert target_op
+ raise ValueError(
+ "Cannot compute gradient inside while loop with respect to op '%s'. "
+ "We do not support taking the gradient wrt or through the initial value "
+ "of a loop variable. Gradients can be computed through loop invariants "
+ "or wrt the input parameters to the loop body."
+ % target_op.name)
+
+
@tf_export("gradients")
def gradients(ys,
xs,
@@ -629,6 +654,21 @@ def _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops,
(op.name, op.type))
if loop_state:
loop_state.EnterGradWhileContext(op, before=False)
+
+ # NOTE(skyewm): We don't support computing gradients wrt a loop variable
+ # unless it's within the context of a single iteration (i.e. the
+ # gradient is wrt to the loop parameter in the body function, not wrt or
+ # through the initial value). This means if we're in a while loop
+ # context, we should never see a switch node from this context.
+ # pylint: disable=protected-access
+ if (control_flow_util.IsSwitch(op) and
+ op._control_flow_context is not None and
+ op._control_flow_context.IsWhileContext() and
+ op._control_flow_context ==
+ ops.get_default_graph()._get_control_flow_context()):
+ _RaiseNoGradWrtInitialLoopValError(op, from_ops)
+ # pylint: enable=protected-access
+
if (grad_fn or is_func_call) and has_out_grads:
# NOTE: If _AggregatedGrads didn't compute a value for the i'th
# output, it means that the cost does not depend on output[i],
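`_RaiseNoGradWrtInitialLoopValError` above walks backwards through op inputs to name the nearest offending source op before raising. A pure-Python sketch of that breadth-first ancestor search over a hypothetical `inputs` adjacency map:

    import collections

    def nearest_ancestor(start, targets, inputs):
        """Breadth-first walk over `inputs` (node -> list of parent nodes),
        returning the first member of `targets` reachable from `start`."""
        queue = collections.deque([start])
        visited = set()
        while queue:
            node = queue.popleft()
            if node in visited:
                continue
            visited.add(node)
            if node in targets:
                return node
            queue.extend(inputs.get(node, []))
        return None

    # 'grad' feeds from 'mul', which feeds from the loop-entry op 'x'.
    graph = {'grad': ['mul'], 'mul': ['x', 'y'], 'x': [], 'y': []}
    print(nearest_ancestor('grad', {'x'}, graph))  # -> 'x'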
diff --git a/tensorflow/python/ops/gradients_test.py b/tensorflow/python/ops/gradients_test.py
index 5e8b8822ef..e729950201 100644
--- a/tensorflow/python/ops/gradients_test.py
+++ b/tensorflow/python/ops/gradients_test.py
@@ -944,6 +944,21 @@ class CustomGradientTest(test_util.TensorFlowTestCase):
# Smoke test to ensure numpy inputs are accepted
F(x)
+ def testRVGradientsDynamicCond(self):
+ with self.test_session():
+ alpha = resource_variable_ops.ResourceVariable(
+ np.random.random((1,)),
+ dtype="float32")
+
+ conditional = array_ops.placeholder_with_default(True, shape=())
+ output = control_flow_ops.cond(
+ conditional, lambda: alpha * 2, lambda: alpha * 3)
+
+ g, = gradients_impl.gradients(output, alpha)
+ variables.global_variables_initializer().run()
+ self.assertAllEqual(g.eval(), [2.0])
+ self.assertAllEqual(g.eval(feed_dict={conditional: False}), [3.0])
+
if __name__ == "__main__":
googletest.main()
diff --git a/tensorflow/python/ops/tensor_array_ops.py b/tensorflow/python/ops/tensor_array_ops.py
index d2f45ce37b..cc92da4fd7 100644
--- a/tensorflow/python/ops/tensor_array_ops.py
+++ b/tensorflow/python/ops/tensor_array_ops.py
@@ -20,6 +20,7 @@ from __future__ import division
from __future__ import print_function
import contextlib
+import weakref
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
@@ -395,69 +396,8 @@ class _GraphTensorArray(object):
# pylint: enable=protected-access
-# pylint: disable=protected-access
-def _eager_write_no_copy(ta, index, value):
- """Writes value into an _EagerTensorArray without creating a new TensorArray.
-
- Args:
- ta: _EagerTensorArray into which to write value.
- index: 0-D. int32 scalar with the index to write to.
- value: N-D. Tensor of type `dtype`. The Tensor to write to this index.
-
- Raises:
- errors_impl.AlreadyExistsError: attempting to overwrite an entry.
- errors_impl.InvalidArgumentError: value dtype does not match `ta`'s dtype.
- errors_impl.OutOfRangeError: `index` is out of bounds.
- ValueError: shape of `value` is not consistent with inferred shape.
- """
-
- if isinstance(index, ops.EagerTensor):
- index = index.numpy()
-
- if index < 0:
- raise errors_impl.OutOfRangeError(
- None, None,
- "Writing to negative indices (index %d) is not allowed." % index)
-
- tensor_array = ta._tensor_array
- size = len(tensor_array)
- if index >= size:
- if not ta._dynamic_size:
- raise errors_impl.OutOfRangeError(
- None, None,
- "Tried to write to index %d but array is not resizeable and size "
- "is: %d" % (index, size))
- tensor_array.extend([None for _ in range(index - size + 1)])
-
- if not isinstance(value, ops.EagerTensor):
- value = constant_op.constant(value)
-
- if ta._infer_shape:
- if ta._element_shape is None:
- ta._element_shape = value.shape
- elif ta._element_shape != value.shape:
- raise ValueError("Incompatible shape for value (%s), expected (%s)" %
- (value.shape.as_list(), ta._element_shape.as_list()))
-
- if ta._dtype != value.dtype:
- raise errors_impl.InvalidArgumentError(
- None, None,
- "TensorArray dtype is %s but Op is trying to write dtype %s" %
- (ta._dtype.name, value.dtype.name))
-
- if ta._tensor_array[index] is not None:
- raise errors_impl.AlreadyExistsError(
- None, None,
- "Could not write to TensorArray index %d because it has already been "
- "written to." % index)
-
- tensor_array[index] = value
-
-# pylint: enable=protected-access
-
-
class _EagerTensorArray(object):
- """Eager-mode implementation of TensorArray.
+ """Eager-compatible implementation of TensorArray.
"""
def __init__(self,
@@ -472,7 +412,7 @@ class _EagerTensorArray(object):
element_shape=None,
colocate_with_first_write_call=True,
name=None):
- """Constructs an Eager mode TensorArray.
+ """Constructs a TensorArray compatible with eager execution.
Args:
dtype: (required) data type of the TensorArray.
@@ -495,16 +435,19 @@ class _EagerTensorArray(object):
ValueError: handle or flow are supplied, or if size is not supplied.
"""
- del (flow, tensor_array_name, name) # not meaningful in Eager
+ del (flow, tensor_array_name, name) # Unused.
if handle is not None:
- raise ValueError("TensorArray handles are not supported in Eager mode.")
+ raise ValueError("TensorArray handles are not supported when eager "
+ "execution is enabled.")
if size is None:
- raise ValueError("Size must be declared for TensorArrays in Eager mode.")
+ raise ValueError("Size must be declared for TensorArrays when eager "
+ "execution is enabled.")
- # These attributes are not meaningful in Eager, but some library functions
- # (e.g., those in control_flow_ops.py) access them to create new tensor
- # arrays; as such, we define them for the sake of compatibility.
+ # These attributes are not meaningful when eager is enabled, but some
+ # library functions (e.g., those in control_flow_ops.py) access them to
+ # create new tensor arrays; as such, we define them for the sake of
+ # compatibility.
self._handle = None
# we assign a dummy value to _flow in case other code assumes it to be
# a Tensor
@@ -525,7 +468,7 @@ class _EagerTensorArray(object):
@property
def flow(self):
- """Flows are not meaningful in Eager; this exists for compatibility."""
+ """For compatibility; flows are not meaningful when eager is enabled."""
return self._flow
@property
@@ -534,42 +477,22 @@ class _EagerTensorArray(object):
@property
def handle(self):
- """Handles are not meaningful in Eager; this exists for compatibility."""
+ """For compatibility; handles are not meaningful when eager is enabled."""
return self._handle
- def _identity_without_array(self):
- """Returns a new TensorArray with the same properties as this Eager one.
-
- NB: Does not set the underlying _tensor_array attribute.
- """
- ta = TensorArray(
- dtype=self._dtype,
- size=len(self._tensor_array),
- dynamic_size=self._dynamic_size,
- clear_after_read=self._clear_after_read,
- handle=self._handle,
- flow=self._flow,
- infer_shape=self._infer_shape,
- element_shape=self._element_shape,
- colocate_with_first_write_call=self._colocate_with_first_write_call)
- ta._implementation._previously_read_indices = self._previously_read_indices # pylint: disable=protected-access
- return ta
-
def identity(self):
"""See TensorArray."""
- ta = self._identity_without_array()
- ta._implementation._tensor_array = [t for t in self._tensor_array] # pylint: disable=protected-access
- return ta
+ return self.parent()
def grad(self, source, flow=None, name=None):
raise NotImplementedError(
- "TensorArray.grad is not supported in Eager mode; Eager's gradient "
- "implementation does not use/need this function to compute gradients "
- "of operations that use TensorArrays.")
+ "TensorArray.grad is not supported when executing eagerly; eager's "
+ "gradient implementation does not use/need this function to compute "
+ "gradients of operations that use TensorArrays.")
def read(self, index, name=None):
"""See TensorArray."""
- del name # not meaningful in Eager mode
+ del name # not meaningful when executing eagerly.
if isinstance(index, ops.EagerTensor):
index = index.numpy()
@@ -600,12 +523,58 @@ class _EagerTensorArray(object):
self._previously_read_indices.append(index)
return tensor
+ def _write(self, index, value):
+ """Writes `value` into index named by `index`.
+
+ Args:
+ index: 0-D. int32 scalar with the index to write to.
+ value: N-D. Tensor of type `dtype`. The `Tensor` to write to `index`.
+
+ Raises:
+ errors_impl.InvalidArgumentError: `value` dtype does not match dtype.
+ errors_impl.OutOfRangeError: `index` is out of bounds.
+ ValueError: shape of `value` is not consistent with inferred shape.
+ """
+
+ if isinstance(index, ops.EagerTensor):
+ index = index.numpy()
+
+ if index < 0:
+ raise errors_impl.OutOfRangeError(
+ None, None,
+ "Writing to negative indices (index %d) is not allowed." % index)
+
+ size = len(self._tensor_array)
+ if index >= size:
+ if not self._dynamic_size:
+ raise errors_impl.OutOfRangeError(
+ None, None,
+ "Tried to write to index %d but array is not resizeable and size "
+ "is: %d" % (index, size))
+ self._tensor_array.extend([None for _ in range(index - size + 1)])
+
+ if not isinstance(value, ops.EagerTensor):
+ value = constant_op.constant(value)
+
+ if self._infer_shape:
+ if self._element_shape is None:
+ self._element_shape = value.shape
+ elif self._element_shape != value.shape:
+ raise ValueError("Incompatible shape for value (%s), expected (%s)" %
+ (value.shape.as_list(), self._element_shape.as_list()))
+
+ if self._dtype != value.dtype:
+ raise errors_impl.InvalidArgumentError(
+ None, None,
+ "TensorArray dtype is %s but Op is trying to write dtype %s" %
+ (self._dtype.name, value.dtype.name))
+ self._tensor_array[index] = value
+
def write(self, index, value, name=None):
"""See TensorArray."""
- del name # not meaningful in Eager mode
- ta = self.identity()
- _eager_write_no_copy(ta._implementation, index, value) # pylint: disable=protected-access
- return ta
+ del name # not meaningful when executing eagerly.
+ self._write(index, value)
+ return self.parent()
def _maybe_zero(self, ix):
val = self._tensor_array[ix]
@@ -623,7 +592,7 @@ class _EagerTensorArray(object):
def gather(self, indices, name=None):
"""See TensorArray."""
- del name # not meaningful in Eager mode
+ del name # not meaningful when executing eagerly.
return array_ops.stack([self._maybe_zero(i) for i in indices.numpy()])
def concat(self, name=None):
@@ -651,17 +620,15 @@ class _EagerTensorArray(object):
raise ValueError(
"Cannot unstack %d tensors into a TensorArray of static size %d" %
(len(tensors), len(self._tensor_array)))
- ta = self._identity_without_array()
- ta._implementation._tensor_array = tensors # pylint: disable=protected-access
- return ta
+ self._tensor_array = tensors
+ return self.parent()
def scatter(self, indices, value, name=None):
"""See TensorArray."""
- del name # unused in Eager
- ta = self.identity()
+ del name # not meaningful when executing eagerly.
for index, val in zip(indices.numpy(), array_ops.unstack(value)):
- _eager_write_no_copy(ta._implementation, index, val) # pylint: disable=protected-access
- return ta
+ self._write(index, val) # pylint: disable=protected-access
+ return self.parent()
def split(self, value, lengths, name=None):
"""See TensorArray."""
@@ -690,20 +657,17 @@ class _EagerTensorArray(object):
"dynamically resizeable" % (len(self._tensor_array),
lengths.shape[0]))
else:
- ta = self._identity_without_array()
- tensor_array = array_ops.split(value, lengths, name=name)
- ta._implementation._tensor_array = tensor_array # pylint: disable=protected-access
- return ta
+ self._tensor_array = array_ops.split(value, lengths, name=name)
+ return self.parent()
def size(self, name=None):
"""See TensorArray."""
- del name # not meaningful in Eager mode
+ del name # not meaningful when executing eagerly.
return constant_op.constant(len(self._tensor_array))
def close(self, name=None):
- del name # not meaningful in Eager mode
+ del name # not meaningful when executing eagerly.
del self._tensor_array[:]
- return
# TensorArray is designed to hide an underlying implementation object
@@ -789,6 +753,8 @@ class TensorArray(object):
colocate_with_first_write_call=colocate_with_first_write_call,
name=name)
+ self._implementation.parent = weakref.ref(self)
+
@property
def flow(self):
"""The flow `Tensor` forcing ops leading to this TensorArray state."""
diff --git a/tensorflow/python/saved_model/builder_impl.py b/tensorflow/python/saved_model/builder_impl.py
index 01903ae596..8f1d5a099f 100644
--- a/tensorflow/python/saved_model/builder_impl.py
+++ b/tensorflow/python/saved_model/builder_impl.py
@@ -169,6 +169,25 @@ class SavedModelBuilder(object):
raise TypeError("main_op needs to be an Operation: %r" % main_op)
ops.add_to_collection(constants.MAIN_OP_KEY, main_op)
+ def _add_train_op(self, train_op):
+ """Add train op to the SavedModel.
+
+ Note that this functionality is in development, and liable to be
+ moved elsewhere.
+
+ Args:
+ train_op: Op or group of ops that are used for training. These are
+ stored as a collection with key TRAIN_OP_KEY, but not executed.
+
+ Raises:
+ TypeError: if `train_op` is not of type `Operation` or `Tensor`.
+ """
+ if train_op is not None:
+ if (not isinstance(train_op, ops.Tensor) and
+ not isinstance(train_op, ops.Operation)):
+ raise TypeError("train_op needs to be a Tensor or Op: %r" % train_op)
+ ops.add_to_collection(constants.TRAIN_OP_KEY, train_op)
+
def _tag_and_add_meta_graph(self, meta_graph_def, tags, signature_def_map):
"""Tags the meta graph def and adds it to the SavedModel.
@@ -239,6 +258,20 @@ class SavedModelBuilder(object):
for outputs_key in outputs:
self._validate_tensor_info(outputs[outputs_key])
+ def _add_collections(
+ self, assets_collection, legacy_init_op, main_op, train_op):
+ """Add asset and op collections to be saved."""
+ # Save asset files and write them to disk, if any.
+ self._save_and_write_assets(assets_collection)
+
+ if main_op is None:
+ # Add legacy init op to the SavedModel.
+ self._maybe_add_legacy_init_op(legacy_init_op)
+ else:
+ self._add_main_op(main_op)
+
+ self._add_train_op(train_op)
+
def add_meta_graph(self,
tags,
signature_def_map=None,
@@ -286,14 +319,8 @@ class SavedModelBuilder(object):
# properly populated.
self._validate_signature_def_map(signature_def_map)
- # Save asset files and write them to disk, if any.
- self._save_and_write_assets(assets_collection)
-
- if main_op is None:
- # Add legacy init op to the SavedModel.
- self._maybe_add_legacy_init_op(legacy_init_op)
- else:
- self._add_main_op(main_op)
+ # Add assets and ops
+ self._add_collections(assets_collection, legacy_init_op, main_op, None)
# Initialize a saver to generate a sharded output for all saveables in the
# current scope.
@@ -352,6 +379,7 @@ class SavedModelBuilder(object):
strip_default_attrs: Boolean. If `True`, default-valued attributes will be
removed from the NodeDefs. For a detailed guide, see
[Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
+
"""
# pylint: enable=line-too-long
if self._has_saved_variables:
@@ -363,8 +391,8 @@ class SavedModelBuilder(object):
# properly populated.
self._validate_signature_def_map(signature_def_map)
- # Save asset files and write them to disk, if any.
- self._save_and_write_assets(assets_collection)
+ # Add assets and ops
+ self._add_collections(assets_collection, legacy_init_op, main_op, None)
# Create the variables sub-directory, if it does not exist.
variables_dir = os.path.join(
@@ -377,12 +405,6 @@ class SavedModelBuilder(object):
compat.as_text(variables_dir),
compat.as_text(constants.VARIABLES_FILENAME))
- if main_op is None:
- # Add legacy init op to the SavedModel.
- self._maybe_add_legacy_init_op(legacy_init_op)
- else:
- self._add_main_op(main_op)
-
# Initialize a saver to generate a sharded output for all saveables in the
# current scope.
saver = tf_saver.Saver(
diff --git a/tensorflow/python/saved_model/constants.py b/tensorflow/python/saved_model/constants.py
index 34206c6f6d..61c6ffbd0d 100644
--- a/tensorflow/python/saved_model/constants.py
+++ b/tensorflow/python/saved_model/constants.py
@@ -41,6 +41,10 @@ MAIN_OP_KEY = "saved_model_main_op"
tf_export("saved_model.constants.MAIN_OP_KEY").export_constant(
__name__, "MAIN_OP_KEY")
+# CollectionDef key for the SavedModel train op.
+# Not exported while export_all_saved_models is in contrib.
+TRAIN_OP_KEY = "saved_model_train_op"
+
# Schema version for SavedModel.
SAVED_MODEL_SCHEMA_VERSION = 1
tf_export("saved_model.constants.SAVED_MODEL_SCHEMA_VERSION").export_constant(
@@ -65,3 +69,5 @@ tf_export("saved_model.constants.VARIABLES_DIRECTORY").export_constant(
VARIABLES_FILENAME = "variables"
tf_export("saved_model.constants.VARIABLES_FILENAME").export_constant(
__name__, "VARIABLES_FILENAME")
+
+
diff --git a/tensorflow/python/saved_model/saved_model_test.py b/tensorflow/python/saved_model/saved_model_test.py
index 804255375e..a4d994fd43 100644
--- a/tensorflow/python/saved_model/saved_model_test.py
+++ b/tensorflow/python/saved_model/saved_model_test.py
@@ -734,6 +734,96 @@ class SavedModelTest(test.TestCase):
builder.add_meta_graph_and_variables(
sess, ["foo"], legacy_init_op=legacy_init_op)
+ def testTrainOp(self):
+ export_dir = self._get_export_dir("test_train_op")
+ builder = saved_model_builder.SavedModelBuilder(export_dir)
+
+ with self.test_session(graph=ops.Graph()) as sess:
+ # Add `v1` and `v2` variables to the graph.
+ v1 = variables.Variable(1, name="v1")
+ ops.add_to_collection("v", v1)
+ v2 = variables.Variable(2, name="v2")
+ ops.add_to_collection("v", v2)
+
+ sess.run(variables.global_variables_initializer())
+ train_op = state_ops.assign_add(v1, v2)
+
+ sess.run(train_op)
+ # TODO(karmel): remove explicit call when in the public method.
+ builder._add_train_op(train_op)
+ builder.add_meta_graph_and_variables(sess, ["foo"])
+
+ # Save the SavedModel to disk.
+ builder.save()
+
+ with self.test_session(graph=ops.Graph()) as sess:
+ loader.load(sess, ["foo"], export_dir)
+ self.assertEqual(3, ops.get_collection("v")[0].eval())
+ self.assertEqual(2, ops.get_collection("v")[1].eval())
+ self.assertIsInstance(
+ ops.get_collection(constants.TRAIN_OP_KEY)[0], ops.Tensor)
+
+ def testTrainOpGroup(self):
+ export_dir = self._get_export_dir("test_train_op_group")
+ builder = saved_model_builder.SavedModelBuilder(export_dir)
+
+ with self.test_session(graph=ops.Graph()) as sess:
+ # Add `v1` and `v2` variables to the graph.
+ v1 = variables.Variable(1, name="v1")
+ ops.add_to_collection("v", v1)
+ v2 = variables.Variable(2, name="v2")
+ ops.add_to_collection("v", v2)
+
+ sess.run(variables.global_variables_initializer())
+ train_op = control_flow_ops.group()
+
+ sess.run(train_op)
+ # TODO(karmel): remove explicit call when in the public method.
+ builder._add_train_op(train_op)
+ builder.add_meta_graph_and_variables(sess, ["foo"])
+
+ # Save the SavedModel to disk.
+ builder.save()
+
+ with self.test_session(graph=ops.Graph()) as sess:
+ loader.load(sess, ["foo"], export_dir)
+ self.assertEqual(1, ops.get_collection("v")[0].eval())
+ self.assertEqual(2, ops.get_collection("v")[1].eval())
+ self.assertIsInstance(
+ ops.get_collection(constants.TRAIN_OP_KEY)[0], ops.Operation)
+
+ def testTrainOpAfterVariables(self):
+ export_dir = self._get_export_dir("test_train_op_after_variables")
+ builder = saved_model_builder.SavedModelBuilder(export_dir)
+
+ with self.test_session(graph=ops.Graph()) as sess:
+ # Add `v1` and `v2` variables to the graph.
+ v1 = variables.Variable(1, name="v1")
+ ops.add_to_collection("v", v1)
+ v2 = variables.Variable(2, name="v2")
+ ops.add_to_collection("v", v2)
+
+ sess.run(variables.global_variables_initializer())
+ builder.add_meta_graph_and_variables(sess, ["pre_foo"])
+
+ train_op = state_ops.assign_add(v1, v2)
+ sess.run(train_op)
+ # TODO(karmel): remove explicit call when in the public method.
+ builder._add_train_op(train_op)
+ builder.add_meta_graph(["foo"])
+
+ # Save the SavedModel to disk.
+ builder.save()
+
+ with self.test_session(graph=ops.Graph()) as sess:
+ loader.load(sess, ["foo"], export_dir)
+ self.assertIsInstance(
+ ops.get_collection(constants.TRAIN_OP_KEY)[0], ops.Tensor)
+
+ with self.test_session(graph=ops.Graph()) as sess:
+ loader.load(sess, ["pre_foo"], export_dir)
+ self.assertFalse(ops.get_collection(constants.TRAIN_OP_KEY))
+
def testMultipleAssets(self):
export_dir = self._get_export_dir("test_multiple_assets")
builder = saved_model_builder.SavedModelBuilder(export_dir)
diff --git a/tensorflow/python/saved_model/signature_constants.py b/tensorflow/python/saved_model/signature_constants.py
index 819f351291..99007a9634 100644
--- a/tensorflow/python/saved_model/signature_constants.py
+++ b/tensorflow/python/saved_model/signature_constants.py
@@ -94,3 +94,9 @@ tf_export("saved_model.signature_constants.REGRESS_OUTPUTS").export_constant(
__name__, "REGRESS_OUTPUTS")
################################################################################
+# Train/Eval API constants.
+# Not exported while export_all_saved_models is in contrib.
+
+SUPERVISED_TRAIN_METHOD_NAME = "tensorflow/supervised/training"
+
+SUPERVISED_EVAL_METHOD_NAME = "tensorflow/supervised/eval"
diff --git a/tensorflow/python/saved_model/signature_def_utils.py b/tensorflow/python/saved_model/signature_def_utils.py
index ea0f52f17e..27d6b70e9d 100644
--- a/tensorflow/python/saved_model/signature_def_utils.py
+++ b/tensorflow/python/saved_model/signature_def_utils.py
@@ -26,6 +26,8 @@ from tensorflow.python.saved_model.signature_def_utils_impl import classificatio
from tensorflow.python.saved_model.signature_def_utils_impl import is_valid_signature
from tensorflow.python.saved_model.signature_def_utils_impl import predict_signature_def
from tensorflow.python.saved_model.signature_def_utils_impl import regression_signature_def
+from tensorflow.python.saved_model.signature_def_utils_impl import supervised_eval_signature_def
+from tensorflow.python.saved_model.signature_def_utils_impl import supervised_train_signature_def
# pylint: enable=unused-import
del absolute_import
diff --git a/tensorflow/python/saved_model/signature_def_utils_impl.py b/tensorflow/python/saved_model/signature_def_utils_impl.py
index d033159188..f8ad788f77 100644
--- a/tensorflow/python/saved_model/signature_def_utils_impl.py
+++ b/tensorflow/python/saved_model/signature_def_utils_impl.py
@@ -185,6 +185,62 @@ def predict_signature_def(inputs, outputs):
return signature_def
+def supervised_train_signature_def(
+ inputs, loss, predictions=None, metrics=None):
+ return _supervised_signature_def(
+ signature_constants.SUPERVISED_TRAIN_METHOD_NAME, inputs, loss=loss,
+ predictions=predictions, metrics=metrics)
+
+
+def supervised_eval_signature_def(
+ inputs, loss, predictions=None, metrics=None):
+ return _supervised_signature_def(
+ signature_constants.SUPERVISED_EVAL_METHOD_NAME, inputs, loss=loss,
+ predictions=predictions, metrics=metrics)
+
+
+def _supervised_signature_def(
+ method_name, inputs, loss=None, predictions=None,
+ metrics=None):
+ """Creates a signature for training and eval data.
+
+ This function produces signatures that describe the inputs and outputs
+ of a supervised process, such as training or evaluation, that
+  results in loss, metrics, and the like. Note that only `inputs` is
+  required; it must be a non-empty dict of string to `Tensor`.
+
+ Args:
+ method_name: Method name of the SignatureDef as a string.
+ inputs: dict of string to `Tensor`.
+ loss: dict of string to `Tensor` representing computed loss.
+ predictions: dict of string to `Tensor` representing the output predictions.
+ metrics: dict of string to `Tensor` representing metric ops.
+
+ Returns:
+ A train- or eval-flavored signature_def.
+
+ Raises:
+    ValueError: If `inputs` is `None` or empty.
+ """
+ if inputs is None or not inputs:
+ raise ValueError('{} inputs cannot be None or empty.'.format(method_name))
+
+ signature_inputs = {key: utils.build_tensor_info(tensor)
+ for key, tensor in inputs.items()}
+
+ signature_outputs = {}
+ for output_set in (loss, predictions, metrics):
+ if output_set is not None:
+ sig_out = {key: utils.build_tensor_info(tensor)
+ for key, tensor in output_set.items()}
+ signature_outputs.update(sig_out)
+
+ signature_def = build_signature_def(
+ signature_inputs, signature_outputs, method_name)
+
+ return signature_def
+
+
@tf_export('saved_model.signature_def_utils.is_valid_signature')
def is_valid_signature(signature_def):
"""Determine whether a SignatureDef can be served by TensorFlow Serving."""
diff --git a/tensorflow/python/saved_model/signature_def_utils_test.py b/tensorflow/python/saved_model/signature_def_utils_test.py
index b2bd14db8c..ebc5450633 100644
--- a/tensorflow/python/saved_model/signature_def_utils_test.py
+++ b/tensorflow/python/saved_model/signature_def_utils_test.py
@@ -180,6 +180,101 @@ class SignatureDefUtilsTest(test.TestCase):
self.assertEqual(types_pb2.DT_STRING, output2_tensor_info_actual.dtype)
self.assertEqual(0, len(output2_tensor_info_actual.tensor_shape.dim))
+ def testTrainSignatureDef(self):
+ self._testSupervisedSignatureDef(
+ signature_def_utils_impl.supervised_train_signature_def,
+ signature_constants.SUPERVISED_TRAIN_METHOD_NAME)
+
+ def testEvalSignatureDef(self):
+ self._testSupervisedSignatureDef(
+ signature_def_utils_impl.supervised_eval_signature_def,
+ signature_constants.SUPERVISED_EVAL_METHOD_NAME)
+
+ def _testSupervisedSignatureDef(self, fn_to_test, method_name):
+ inputs = {
+ "input-1": constant_op.constant("a", name="input-1"),
+ "input-2": constant_op.constant("b", name="input-2"),
+ }
+ loss = {"loss-1": constant_op.constant(0.45, name="loss-1")}
+ predictions = {
+ "classes": constant_op.constant([100], name="classes"),
+ }
+ metrics_val = constant_op.constant(100.0, name="metrics_val")
+ metrics = {
+ "metrics/value": metrics_val,
+ "metrics/update_op": array_ops.identity(metrics_val, name="metrics_op"),
+ }
+
+ signature_def = fn_to_test(inputs, loss, predictions, metrics)
+
+ self.assertEqual(method_name, signature_def.method_name)
+
+ # Check inputs in signature def.
+ self.assertEqual(2, len(signature_def.inputs))
+ input1_tensor_info_actual = (signature_def.inputs["input-1"])
+ self.assertEqual("input-1:0", input1_tensor_info_actual.name)
+ self.assertEqual(types_pb2.DT_STRING, input1_tensor_info_actual.dtype)
+ self.assertEqual(0, len(input1_tensor_info_actual.tensor_shape.dim))
+ input2_tensor_info_actual = (signature_def.inputs["input-2"])
+ self.assertEqual("input-2:0", input2_tensor_info_actual.name)
+ self.assertEqual(types_pb2.DT_STRING, input2_tensor_info_actual.dtype)
+ self.assertEqual(0, len(input2_tensor_info_actual.tensor_shape.dim))
+
+ # Check outputs in signature def.
+ self.assertEqual(4, len(signature_def.outputs))
+ self.assertEqual("loss-1:0", signature_def.outputs["loss-1"].name)
+ self.assertEqual(types_pb2.DT_FLOAT, signature_def.outputs["loss-1"].dtype)
+
+ self.assertEqual("classes:0", signature_def.outputs["classes"].name)
+ self.assertEqual(1, len(signature_def.outputs["classes"].tensor_shape.dim))
+
+ self.assertEqual(
+ "metrics_val:0", signature_def.outputs["metrics/value"].name)
+ self.assertEqual(
+ types_pb2.DT_FLOAT, signature_def.outputs["metrics/value"].dtype)
+
+ self.assertEqual(
+ "metrics_op:0", signature_def.outputs["metrics/update_op"].name)
+ self.assertEqual(
+        types_pb2.DT_FLOAT, signature_def.outputs["metrics/update_op"].dtype)
+
+ def testTrainSignatureDefMissingInputs(self):
+ self._testSupervisedSignatureDefMissingInputs(
+ signature_def_utils_impl.supervised_train_signature_def,
+ signature_constants.SUPERVISED_TRAIN_METHOD_NAME)
+
+ def testEvalSignatureDefMissingInputs(self):
+ self._testSupervisedSignatureDefMissingInputs(
+ signature_def_utils_impl.supervised_eval_signature_def,
+ signature_constants.SUPERVISED_EVAL_METHOD_NAME)
+
+ def _testSupervisedSignatureDefMissingInputs(self, fn_to_test, method_name):
+ inputs = {
+ "input-1": constant_op.constant("a", name="input-1"),
+ "input-2": constant_op.constant("b", name="input-2"),
+ }
+ loss = {"loss-1": constant_op.constant(0.45, name="loss-1")}
+ predictions = {
+ "classes": constant_op.constant([100], name="classes"),
+ }
+ metrics_val = constant_op.constant(100, name="metrics_val")
+ metrics = {
+ "metrics/value": metrics_val,
+ "metrics/update_op": array_ops.identity(metrics_val, name="metrics_op"),
+ }
+
+ with self.assertRaises(ValueError):
+ signature_def = fn_to_test(
+ {}, loss=loss, predictions=predictions, metrics=metrics)
+
+ signature_def = fn_to_test(inputs, loss=loss)
+ self.assertEqual(method_name, signature_def.method_name)
+ self.assertEqual(1, len(signature_def.outputs))
+
+ signature_def = fn_to_test(inputs, metrics=metrics, loss=loss)
+ self.assertEqual(method_name, signature_def.method_name)
+ self.assertEqual(3, len(signature_def.outputs))
+
def testGetShapeAndTypes(self):
inputs = {
"input-1": constant_op.constant(["a", "b"]),
diff --git a/tensorflow/python/saved_model/tag_constants.py b/tensorflow/python/saved_model/tag_constants.py
index 5a797da791..c82154e7b9 100644
--- a/tensorflow/python/saved_model/tag_constants.py
+++ b/tensorflow/python/saved_model/tag_constants.py
@@ -32,6 +32,9 @@ TRAINING = "train"
tf_export("saved_model.tag_constants.TRAINING").export_constant(
__name__, "TRAINING")
+# Tag for the `eval` graph. Not exported while the export logic is in contrib.
+EVAL = "eval"
+
# Tag for the `gpu` graph.
GPU = "gpu"
tf_export("saved_model.tag_constants.GPU").export_constant(__name__, "GPU")
@@ -39,3 +42,5 @@ tf_export("saved_model.tag_constants.GPU").export_constant(__name__, "GPU")
# Tag for the `tpu` graph.
TPU = "tpu"
tf_export("saved_model.tag_constants.TPU").export_constant(__name__, "TPU")
+
+
diff --git a/tensorflow/python/training/checkpointable.py b/tensorflow/python/training/checkpointable.py
index 05afd37ccd..d00312a1f3 100644
--- a/tensorflow/python/training/checkpointable.py
+++ b/tensorflow/python/training/checkpointable.py
@@ -659,6 +659,31 @@ class CheckpointableBase(object):
return {}
+class NoDependency(object):
+ """Allows attribute assignment to `Checkpointable` objects with no dependency.
+
+ Example usage:
+ ```python
+ obj = Checkpointable()
+ obj.has_dependency = tf.Variable(0., name="dep")
+ obj.no_dependency = NoDependency(tf.Variable(1., name="nodep"))
+ assert obj.no_dependency.name == "nodep:0"
+ ```
+
+ `obj` in this example has a dependency on the variable "dep", and both
+ attributes contain un-wrapped `Variable` objects.
+
+ `NoDependency` also works with `tf.keras.Model`, but only for checkpoint
+ dependencies: wrapping a `Layer` in `NoDependency` will assign the (unwrapped)
+ `Layer` to the attribute without a checkpoint dependency, but the `Model` will
+ still track the `Layer` (so it will appear in `Model.layers`, and its
+ variables will appear in `Model.variables`).
+ """
+
+ def __init__(self, value):
+ self.value = value
+
+
class Checkpointable(CheckpointableBase):
"""Manages dependencies on other objects.
@@ -691,8 +716,11 @@ class Checkpointable(CheckpointableBase):
"""Support self.foo = checkpointable syntax."""
# Perform the attribute assignment, and potentially call other __setattr__
# overrides such as that for tf.keras.Model.
+ no_dependency = isinstance(value, NoDependency)
+ if no_dependency:
+ value = value.value
super(Checkpointable, self).__setattr__(name, value)
- if isinstance(value, CheckpointableBase):
+ if not no_dependency and isinstance(value, CheckpointableBase):
self._track_checkpointable(
value, name=name,
# Allow the user to switch the Checkpointable which is tracked by this
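Condensed from the docstring and the test below, a minimal sketch of the `NoDependency` behavior (pure `Checkpointable` objects, no variables, to keep it self-contained):

```python
from tensorflow.python.training import checkpointable

root = checkpointable.Checkpointable()
root.tracked = checkpointable.Checkpointable()  # recorded as a dependency
root.untracked = checkpointable.NoDependency(checkpointable.Checkpointable())

# __setattr__ unwraps the NoDependency marker, so the attribute holds the
# plain object, but only `tracked` appears among the checkpoint dependencies.
assert isinstance(root.untracked, checkpointable.Checkpointable)
assert len(root._checkpoint_dependencies) == 1
```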
diff --git a/tensorflow/python/training/checkpointable_test.py b/tensorflow/python/training/checkpointable_test.py
index e79acb4975..85802cb661 100644
--- a/tensorflow/python/training/checkpointable_test.py
+++ b/tensorflow/python/training/checkpointable_test.py
@@ -34,6 +34,16 @@ class InterfaceTests(test.TestCase):
root.leaf = duplicate_name_dep
root._track_checkpointable(duplicate_name_dep, name="leaf", overwrite=True)
+ def testNoDependency(self):
+ root = checkpointable.Checkpointable()
+ hasdep = checkpointable.Checkpointable()
+ root.hasdep = hasdep
+ nodep = checkpointable.Checkpointable()
+ root.nodep = checkpointable.NoDependency(nodep)
+ self.assertEqual(1, len(root._checkpoint_dependencies))
+ self.assertIs(root._checkpoint_dependencies[0].ref, root.hasdep)
+ self.assertIs(root.hasdep, hasdep)
+ self.assertIs(root.nodep, nodep)
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/training/checkpointable_utils.py b/tensorflow/python/training/checkpointable_utils.py
index cf4112ff99..f2a2b411fd 100644
--- a/tensorflow/python/training/checkpointable_utils.py
+++ b/tensorflow/python/training/checkpointable_utils.py
@@ -1044,8 +1044,11 @@ class Checkpoint(checkpointable_lib.Checkpointable):
if self._save_counter is None:
# Initialized to 0 and incremented before saving.
with ops.device("/cpu:0"):
- self._save_counter = add_variable(
- self, name="save_counter", initializer=0, dtype=dtypes.int64)
+ # add_variable creates a dependency named "save_counter"; NoDependency
+ # prevents creating a second dependency named "_save_counter".
+ self._save_counter = checkpointable_lib.NoDependency(
+ add_variable(self, name="save_counter", initializer=0,
+ dtype=dtypes.int64))
@property
def save_counter(self):
diff --git a/tensorflow/python/training/distribute.py b/tensorflow/python/training/distribute.py
index c0d5ea36dd..ab8b37bb65 100644
--- a/tensorflow/python/training/distribute.py
+++ b/tensorflow/python/training/distribute.py
@@ -357,14 +357,14 @@ class DistributionStrategy(object):
on different slices of the input data. This is in contrast to
_model parallelism_ where we divide up a single copy of a model
across multiple devices.
- Note: for now we only support data parallelism at this time, but
+ Note: we only support data parallelism for now, but
hope to add support for model parallelism in the future.
* A _tower_ is one copy of the model, running on one slice of the
input data.
- * _Synchronous_, or more commonly _sync_, training is when the
+ * _Synchronous_, or more commonly _sync_, training is where the
updates from each tower are aggregated together before updating
the model variables. This is in contrast to _asynchronous_, or
- _async_ training where each tower updates the model variables
+ _async_ training, where each tower updates the model variables
independently.
* Furthermore you might run your computation on multiple devices
on one machine (or "host"), or on multiple machines/hosts.
@@ -386,11 +386,11 @@ class DistributionStrategy(object):
* Reductions and Allreduce: A _reduction_ is some method of
aggregating multiple values into one value, like "sum" or
"mean". If doing sync training, we will perform a reduction on the
- gradients to a parameter from each tower before applying the
+ gradients to a parameter from all towers before applying the
update. Allreduce is an algorithm for performing a reduction on
values from multiple devices and making the result available on
all of those devices.
- * In the future we will have support for TensorFlows' partitioned
+ * In the future we will have support for TensorFlow's partitioned
variables, where a single variable is split across multiple
devices.
@@ -419,9 +419,9 @@ class DistributionStrategy(object):
`tower_fn` can use the `get_tower_context()` API to get enhanced
behavior in this case.
- You can also create an initializable iterator instead of one shot iterator.
- In that case, you will need to ensure that you initialize the iterator
- before calling get_next.
+ You can also create an initializable iterator instead of a one-shot
+ iterator. In that case, you will need to ensure that you initialize the
+ iterator before calling get_next.
```
iterator = my_distribution.distribute_dataset(
      dataset).make_initializable_iterator()
@@ -816,6 +816,7 @@ class DistributionStrategy(object):
# TODO(josh11b): Return an unwrapped value if colocate_with is a
# single device.
_require_cross_tower_context(self)
+ assert method_string in ("sum", "mean")
return self._reduce(method_string, value, destinations)
def _reduce(self, method_string, value, destinations):
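To make the "reduction" terminology above concrete, here is a toy, framework-free illustration of a "mean" reduction over per-tower gradients; it is plain Python, not the `DistributionStrategy` API:

```python
# Two towers, each with gradients computed on its own slice of the input.
per_tower_grads = [
    {"w": 2.0, "b": -1.0},  # tower 0
    {"w": 4.0, "b": -3.0},  # tower 1
]

def reduce_mean(tower_values):
  """Aggregates per-tower dicts into one dict of mean values."""
  keys = tower_values[0].keys()
  return {k: sum(t[k] for t in tower_values) / len(tower_values) for k in keys}

aggregated = reduce_mean(per_tower_grads)
# The single aggregated update is then applied once to the shared variables.
assert aggregated == {"w": 3.0, "b": -2.0}
```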
diff --git a/tensorflow/stream_executor/cuda/cuda_activation.cc b/tensorflow/stream_executor/cuda/cuda_activation.cc
index cf6b9e2c6e..02371c3c3a 100644
--- a/tensorflow/stream_executor/cuda/cuda_activation.cc
+++ b/tensorflow/stream_executor/cuda/cuda_activation.cc
@@ -38,5 +38,11 @@ ScopedActivateExecutorContext::~ScopedActivateExecutorContext() {
delete static_cast<ScopedActivateContext *>(driver_scoped_activate_context_);
}
+ScopedActivateExecutorContext::ScopedActivateExecutorContext(
+ ScopedActivateExecutorContext &&other)
+ : driver_scoped_activate_context_(other.driver_scoped_activate_context_) {
+ other.driver_scoped_activate_context_ = nullptr;
+}
+
} // namespace cuda
} // namespace stream_executor
diff --git a/tensorflow/stream_executor/cuda/cuda_activation.h b/tensorflow/stream_executor/cuda/cuda_activation.h
index 04ffaef364..ef9807820f 100644
--- a/tensorflow/stream_executor/cuda/cuda_activation.h
+++ b/tensorflow/stream_executor/cuda/cuda_activation.h
@@ -44,10 +44,11 @@ class ScopedActivateExecutorContext {
// fatal failure if it is not CUDA inside.
explicit ScopedActivateExecutorContext(StreamExecutor* stream_exec);
+ ScopedActivateExecutorContext(ScopedActivateExecutorContext&& other);
+
~ScopedActivateExecutorContext();
private:
-
// The cuda.h-using datatype that we wrap.
ScopedActivateContext* driver_scoped_activate_context_;
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc
index 773cac2c40..af78efe81d 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.cc
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc
@@ -46,8 +46,20 @@ limitations under the License.
#include "cuda/include/cudnn.h"
// clang-format on
+namespace stream_executor {
+namespace cuda {
+
+PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kCuDnnPlugin);
+
namespace {
+// TODO(csigg): remove dnn namespace qualifier from the RNN code below.
+using ::stream_executor::dnn::BatchDescriptor;
+using ::stream_executor::dnn::ConvolutionDescriptor;
+using ::stream_executor::dnn::FilterDescriptor;
+using ::stream_executor::dnn::NormalizeDescriptor;
+using ::stream_executor::dnn::PoolingDescriptor;
+
// Converts (via narrowing) a WideT value to a NarrowT value, and checks that
// the value does not change due to the conversion.
template <typename WideT, typename NarrowT>
@@ -58,20 +70,6 @@ NarrowT CheckedNarrowing(const WideT& wide) {
return narrow;
}
-} // namespace
-
-namespace stream_executor {
-
-using dnn::BatchDescriptor;
-using dnn::FilterDescriptor;
-using dnn::ConvolutionDescriptor;
-using dnn::PoolingDescriptor;
-using dnn::NormalizeDescriptor;
-
-namespace cuda {
-
-PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kCuDnnPlugin);
-
string ToString(cudnnStatus_t status) {
switch (status) {
case CUDNN_STATUS_SUCCESS:
@@ -136,226 +134,82 @@ cudnnDataType_t GetCudnnDataType<Eigen::half>() {
return CUDNN_DATA_HALF;
}
-namespace wrap {
-
-static port::ThreadPool* InitCudnnThreadpool() {
- port::ThreadPool* cudnn_threadpool_;
- port::ThreadOptions options;
- // TBD(keveman): Conservatively setting the stack size and guard size to 2MB,
- // until we can get some guarantees from NVIDIA on the minimum stack space
- // they will work with.
- options.stack_size = 2 * 1024 * 1024;
- options.guard_size = 2 * 1024 * 1024;
- cudnn_threadpool_ = new port::ThreadPool(port::Env::Default(), options,
- "cudnn_threadpool", 1);
- CHECK(cudnn_threadpool_);
- return cudnn_threadpool_;
-}
-
-static mutex cudnn_threadpool_mu(LINKER_INITIALIZED);
-static port::ThreadPool* GetCudaThreadpool() {
- mutex_lock lock(cudnn_threadpool_mu);
- static port::ThreadPool* cudnn_threadpool = InitCudnnThreadpool();
- return cudnn_threadpool;
-}
-
-#define STREAM_EXECUTOR_CUDNN_WRAP(__name) \
- struct WrapperShim__##__name { \
- template <typename... Args> \
- cudnnStatus_t operator()(CUDAExecutor* parent, Args... args) { \
- cuda::ScopedActivateExecutorContext sac{parent}; \
- cudnnStatus_t retval = ::__name(args...); \
- return retval; \
- } \
- } __name;
-
-#define STREAM_EXECUTOR_CUDNN_WRAP_WITH_CHECKED_STREAM(__name) \
- struct WrapperShim__##__name { \
- template <typename... Args> \
- cudnnStatus_t operator()(CudnnSupport* dnn, Stream* s, Args... args) \
- SHARED_LOCKS_REQUIRED(dnn->dnn_handle_mutex_) { \
- CHECK_NOTNULL(s); \
- CHECK_EQ(s, dnn->GetCurrentDnnStream()) \
- << "Stream is not set correctly!"; \
- cuda::ScopedActivateExecutorContext sac{dnn->GetParentExecutor()}; \
- cudnnStatus_t retval = ::__name(args...); \
- return retval; \
- } \
- } __name;
-
-// Handles cudnnSetStream differently in order to add debug information.
-struct WrapperShim__cudnnSetStream {
- cudnnStatus_t operator()(CudnnSupport* dnn, Stream* stream,
- cudnnHandle_t handle)
- EXCLUSIVE_LOCKS_REQUIRED(dnn->dnn_handle_mutex_) {
- dnn->SetCurrentDnnStream(stream);
- cuda::ScopedActivateExecutorContext sac{dnn->GetParentExecutor()};
- cudnnStatus_t retval = ::cudnnSetStream(handle, AsCUDAStreamValue(stream));
- return retval;
- }
-} cudnnSetStream;
-
-// clang-format off
-#define CUDNN_DNN_ROUTINE_EACH(__macro) \
- __macro(cudnnGetConvolutionNdForwardOutputDim) \
- __macro(cudnnGetConvolutionForwardAlgorithm) \
- __macro(cudnnCreateTensorDescriptor) \
- __macro(cudnnDestroyTensorDescriptor) \
- __macro(cudnnCreateFilterDescriptor) \
- __macro(cudnnSetPoolingNdDescriptor) \
- __macro(cudnnSetLRNDescriptor) \
- __macro(cudnnDestroyFilterDescriptor) \
- __macro(cudnnCreateConvolutionDescriptor) \
- __macro(cudnnCreatePoolingDescriptor) \
- __macro(cudnnDestroyPoolingDescriptor) \
- __macro(cudnnCreateLRNDescriptor) \
- __macro(cudnnDestroyLRNDescriptor) \
- __macro(cudnnDestroyConvolutionDescriptor) \
- __macro(cudnnCreate) \
- __macro(cudnnDestroy) \
- __macro(cudnnGetConvolutionForwardWorkspaceSize) \
- __macro(cudnnSetConvolutionNdDescriptor) \
- __macro(cudnnSetTensor4dDescriptor) \
- __macro(cudnnSetTensorNdDescriptor) \
- __macro(cudnnSetFilterNdDescriptor)
-
-// clang-format on
-CUDNN_DNN_ROUTINE_EACH(STREAM_EXECUTOR_CUDNN_WRAP)
-#undef CUDNN_DNN_ROUTINE_EACH
-
-// clang-format off
-#define CUDNN_DNN_ROUTINE_EACH_WITH_STREAM(__macro) \
- __macro(cudnnBatchNormalizationBackward) \
- __macro(cudnnBatchNormalizationForwardInference) \
- __macro(cudnnBatchNormalizationForwardTraining) \
- __macro(cudnnActivationForward) \
- __macro(cudnnConvolutionForward) \
- __macro(cudnnConvolutionBackwardBias) \
- __macro(cudnnTransformTensor) \
- __macro(cudnnPoolingForward) \
- __macro(cudnnPoolingBackward) \
- __macro(cudnnLRNCrossChannelForward) \
- __macro(cudnnLRNCrossChannelBackward) \
- __macro(cudnnAddTensor) \
- __macro(cudnnConvolutionBackwardData) \
- __macro(cudnnConvolutionBackwardFilter)
-
-// clang-format on
-CUDNN_DNN_ROUTINE_EACH_WITH_STREAM(
- STREAM_EXECUTOR_CUDNN_WRAP_WITH_CHECKED_STREAM)
-#undef CUDNN_DNN_ROUTINE_EACH_WITH_STREAM
-
-// APIs available after R3:
-#if CUDNN_VERSION >= 3000
-#define CUDNN_DNN_ROUTINE_EACH_AFTER_R3(__macro) \
- __macro(cudnnGetConvolutionBackwardFilterWorkspaceSize) \
- __macro(cudnnGetConvolutionBackwardDataAlgorithm) \
- __macro(cudnnGetConvolutionBackwardFilterAlgorithm) \
- __macro(cudnnGetConvolutionBackwardDataWorkspaceSize)
-CUDNN_DNN_ROUTINE_EACH_AFTER_R3(STREAM_EXECUTOR_CUDNN_WRAP)
-#undef CUDNN_DNN_ROUTINE_EACH_AFTER_R3
-#endif
-
-// APIs in R3 but not in R5
-// clang-format off
-#if CUDNN_VERSION >= 3000 && CUDNN_VERSION < 5000
-#define CUDNN_DNN_ROUTINE_EACH_R3_WITH_STREAM(__macro) \
- __macro(cudnnAddTensor_v3) \
- __macro(cudnnConvolutionBackwardData_v3) \
- __macro(cudnnConvolutionBackwardFilter_v3)
-// clang-format on
-
-CUDNN_DNN_ROUTINE_EACH_R3_WITH_STREAM(
- STREAM_EXECUTOR_CUDNN_WRAP_WITH_CHECKED_STREAM)
-#undef CUDNN_DNN_ROUTINE_EACH_R3_WITH_STREAM
-#endif
-
-// APIs in R5
-// clang-format off
-#if CUDNN_VERSION >= 5000
-#define CUDNN_DNN_ROUTINE_EACH_R5(__macro) \
- __macro(cudnnCreateActivationDescriptor) \
- __macro(cudnnSetActivationDescriptor) \
- __macro(cudnnGetActivationDescriptor) \
- __macro(cudnnDestroyActivationDescriptor) \
- __macro(cudnnCreateDropoutDescriptor) \
- __macro(cudnnDestroyDropoutDescriptor) \
- __macro(cudnnSetDropoutDescriptor) \
- __macro(cudnnDropoutGetStatesSize) \
- __macro(cudnnCreateRNNDescriptor) \
- __macro(cudnnDestroyRNNDescriptor) \
- __macro(cudnnGetRNNParamsSize) \
- __macro(cudnnGetRNNWorkspaceSize) \
- __macro(cudnnGetRNNTrainingReserveSize) \
- __macro(cudnnGetRNNLinLayerMatrixParams) \
- __macro(cudnnGetRNNLinLayerBiasParams) \
- __macro(cudnnSetRNNDescriptor) \
- __macro(cudnnGetFilterNdDescriptor)
-
-// clang-format on
-CUDNN_DNN_ROUTINE_EACH_R5(STREAM_EXECUTOR_CUDNN_WRAP)
-#undef CUDNN_DNN_ROUTINE_EACH_R5
-
-// clang-format off
-#define CUDNN_DNN_ROUTINE_EACH_R5_WITH_STREAM(__macro) \
- __macro(cudnnRNNForwardInference) \
- __macro(cudnnRNNForwardTraining) \
- __macro(cudnnRNNBackwardData) \
- __macro(cudnnRNNBackwardWeights)
+// RAII wrapper for all calls to cuDNN with a cuDNN handle argument.
+//
+// See CudnnAccess::GetHandle() for details.
+class CudnnHandle {
+ public:
+ // Takes ownership of the executor context and the lock to access cuDNN
+ // using handle.
+ CudnnHandle(cuda::ScopedActivateExecutorContext context, mutex_lock lock,
+ cudnnHandle_t handle)
+ : context_(std::move(context)), lock_(std::move(lock)), handle_(handle) {}
-// clang-format on
-CUDNN_DNN_ROUTINE_EACH_R5_WITH_STREAM(
- STREAM_EXECUTOR_CUDNN_WRAP_WITH_CHECKED_STREAM)
-#undef CUDNN_DNN_ROUTINE_EACH_R5_WITH_STREAM
-#endif
+  // Returns the cuDNN handle. Pass it directly to cuDNN APIs; do not keep
+  // a copy.
+ cudnnHandle_t handle() const { return handle_; }
-// APIs in R6
-// clang-format off
-#if CUDNN_VERSION >= 6000
-#define CUDNN_DNN_ROUTINE_EACH_R6(__macro) \
- __macro(cudnnSetRNNDescriptor_v6) \
- __macro(cudnnCreatePersistentRNNPlan) \
- __macro(cudnnDestroyPersistentRNNPlan) \
- __macro(cudnnSetPersistentRNNPlan)
+ private:
+ cuda::ScopedActivateExecutorContext context_;
+ mutex_lock lock_;
+ cudnnHandle_t handle_; // Not owned.
+};
-// clang-format on
-CUDNN_DNN_ROUTINE_EACH_R6(STREAM_EXECUTOR_CUDNN_WRAP)
-#undef CUDNN_DNN_ROUTINE_EACH_R6
+} // namespace
-// clang-format off
-#define CUDNN_DNN_ROUTINE_EACH_R6_WITH_STREAM(__macro) \
- __macro(cudnnConvolutionBiasActivationForward)
+// Wraps a cuDNN handle and provides access to it through CudnnHandle instances,
+// which also locks a mutex, acquires the CUDA context, and sets the stream
+// that cuDNN should use to enqueue any work.
+//
+// Note: CudnnSupport::cudnn_ should be the only instantiation of this class.
+class CudnnAccess {
+ public:
+ // Takes ownership of the handle.
+ explicit CudnnAccess(cudnnHandle_t handle) : handle_(handle) {}
-// clang-format on
-CUDNN_DNN_ROUTINE_EACH_R6_WITH_STREAM(
- STREAM_EXECUTOR_CUDNN_WRAP_WITH_CHECKED_STREAM)
-#undef CUDNN_DNN_ROUTINE_EACH_R6_WITH_STREAM
-#endif
+ ~CudnnAccess() {
+ mutex_lock lock(mutex_);
+ cudnnDestroy(handle_);
+ }
-// APIs in R7
-// clang-format off
-#if CUDNN_VERSION >= 7000
-#define CUDNN_DNN_ROUTINE_EACH_R7(__macro) \
- __macro(cudnnSetConvolutionMathType) \
- __macro(cudnnSetRNNMatrixMathType) \
- __macro(cudnnSetConvolutionGroupCount) \
- __macro(cudnnGetConvolutionGroupCount)
+ // Creates a CudnnHandle instance for stream.
+ //
+ // cuDNN API calls using the same handle instance need to be serialized across
+ // threads. This is guaranteed by CudnnHandle instances locking the mutex
+ // owned by this class.
+ //
+ // Most cuDNN APIs taking a handle perform work on a CUDA stream. The
+ // CudnnHandle instance acquires the executor's CUDA context and sets cuDNN to
+ // use the provided stream.
+ //
+ // The stream argument may be null, which translates to the legacy default
+ // stream. See
+ // https://docs.nvidia.com/cuda/cuda-driver-api/stream-sync-behavior.html.
+  // The legacy default stream synchronizes with all other streams, so it is
+  // a bad idea (performance-wise) to call any cuDNN APIs that enqueue work
+  // in that stream.
+ CudnnHandle GetHandle(CUDAExecutor* executor, Stream* stream) {
+ mutex_lock lock(mutex_);
+ cuda::ScopedActivateExecutorContext context(executor);
+ CUstream cu_stream = stream ? AsCUDAStreamValue(stream) : cudaStreamLegacy;
+ auto status = cudnnSetStream(handle_, cu_stream);
+ CHECK_EQ(status, CUDNN_STATUS_SUCCESS) << "Failed to set cuDNN stream.";
+ return CudnnHandle(std::move(context), std::move(lock), handle_);
+ }
-// clang-format on
-CUDNN_DNN_ROUTINE_EACH_R7(STREAM_EXECUTOR_CUDNN_WRAP)
-#undef CUDNN_DNN_ROUTINE_EACH_R7
-#endif
+ private:
+ // Guards the enqueueing of cuDNN operations via the handle_ below.
+ mutex mutex_;
-} // namespace wrap
+ // cuDNN library handle.
+ cudnnHandle_t handle_ GUARDED_BY(mutex_); // Owned.
+};
namespace {
cudnnDataType_t GetRnnComputeType(dnn::DataType data_type);
-cudnnHandle_t ToHandle(void* opaque_handle) {
- return static_cast<cudnnHandle_t>(opaque_handle);
-}
-
cudnnConvolutionFwdAlgo_t ToConvForwardAlgo(dnn::AlgorithmDesc algorithm) {
cudnnConvolutionFwdAlgo_t algo =
cudnnConvolutionFwdAlgo_t(algorithm.algo_id());
@@ -432,7 +286,7 @@ port::Status GetCudnnProperty(libraryPropertyType type, int* value) {
port::StrCat("cudnnGetProperty failed for type: ", ToString(type),
" with status: ", ToString(status));
LOG(ERROR) << error;
- return port::Status{port::error::INTERNAL, error};
+ return port::Status(port::error::INTERNAL, error);
}
return port::Status::OK();
}
@@ -471,19 +325,11 @@ port::Status GetLoadedCudnnVersion(CudnnVersion* version) {
} // namespace
-CudnnSupport::CudnnSupport(CUDAExecutor* parent)
- : parent_(parent), dnn_handle_(nullptr), current_dnn_stream_(nullptr) {}
-
-CudnnSupport::~CudnnSupport() {
- auto status = wrap::cudnnDestroy(parent_, ToHandle(dnn_handle_));
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(ERROR) << "could not destroy cudnn handle: " << ToString(status);
- }
-}
+CudnnSupport::CudnnSupport(CUDAExecutor* parent) : parent_(parent) {}
port::Status CudnnSupport::Init() {
- auto status = wrap::cudnnCreate(
- parent_, reinterpret_cast<cudnnHandle_t*>(&dnn_handle_));
+ cudnnHandle_t cudnn_handle = nullptr;
+ auto status = cudnnCreate(&cudnn_handle);
if (status == CUDNN_STATUS_SUCCESS) {
CudnnVersion source_version(CUDNN_MAJOR, CUDNN_MINOR, CUDNN_PATCHLEVEL);
@@ -499,9 +345,10 @@ port::Status CudnnSupport::Init() {
"from sources, make sure the library loaded at runtime is compatible "
"with the version specified during compile configuration.");
LOG(ERROR) << error;
- return port::Status{port::error::INTERNAL, error};
+ return port::Status(port::error::INTERNAL, error);
}
+ cudnn_.reset(new CudnnAccess(cudnn_handle));
return port::Status::OK();
}
@@ -525,9 +372,9 @@ port::Status CudnnSupport::Init() {
}
}
- return port::Status{port::error::INTERNAL,
+ return port::Status(port::error::INTERNAL,
port::StrCat("cudnn library could not create a handle: ",
- ToString(status))};
+ ToString(status)));
}
port::StatusOr<perftools::gputools::dnn::VersionInfo>
@@ -538,14 +385,15 @@ CudnnSupport::GetVersion() {
version.major_version, version.minor_version, version.patch_level);
}
+namespace {
+
// Turns a BatchDescriptor structure into a cudnn tensor handle within a scope.
class ScopedTensorDescriptor {
public:
- ScopedTensorDescriptor(CUDAExecutor* parent,
- const BatchDescriptor& batch_descriptor,
+ ScopedTensorDescriptor(const BatchDescriptor& batch_descriptor,
cudnnDataType_t elem_type)
- : parent_(parent), handle_(nullptr) {
- cudnnStatus_t status = wrap::cudnnCreateTensorDescriptor(parent_, &handle_);
+ : handle_(nullptr) {
+ cudnnStatus_t status = cudnnCreateTensorDescriptor(&handle_);
if (status != CUDNN_STATUS_SUCCESS) {
LOG(FATAL) << "could not create cudnn tensor descriptor: "
<< ToString(status);
@@ -568,8 +416,8 @@ class ScopedTensorDescriptor {
&CheckedNarrowing<int64, int>);
std::transform(dims64.cbegin(), dims64.cend(), dims.begin(),
&CheckedNarrowing<int64, int>);
- status = wrap::cudnnSetTensorNdDescriptor(
- parent_, handle_, elem_type, nd, dims.data(), strides.data());
+ status = cudnnSetTensorNdDescriptor(handle_, elem_type, nd, dims.data(),
+ strides.data());
if (status != CUDNN_STATUS_SUCCESS) {
LOG(FATAL) << "could not convert BatchDescriptor "
@@ -579,8 +427,8 @@ class ScopedTensorDescriptor {
} break;
#if CUDNN_VERSION >= 6000
case dnn::DataLayout::kBatchDepthYX4: {
- status = wrap::cudnnSetTensor4dDescriptor(
- parent_, handle_, CUDNN_TENSOR_NCHW_VECT_C, elem_type,
+ status = cudnnSetTensor4dDescriptor(
+ handle_, CUDNN_TENSOR_NCHW_VECT_C, elem_type,
batch_descriptor.count(), batch_descriptor.feature_map_count(),
batch_descriptor.height(), batch_descriptor.width());
if (status != CUDNN_STATUS_SUCCESS) {
@@ -598,7 +446,7 @@ class ScopedTensorDescriptor {
}
~ScopedTensorDescriptor() {
- cudnnStatus_t status = wrap::cudnnDestroyTensorDescriptor(parent_, handle_);
+ cudnnStatus_t status = cudnnDestroyTensorDescriptor(handle_);
if (status != CUDNN_STATUS_SUCCESS) {
LOG(ERROR) << "could not destroy cudnn tensor descriptor: "
<< ToString(status);
@@ -608,7 +456,6 @@ class ScopedTensorDescriptor {
cudnnTensorDescriptor_t handle() const { return handle_; }
private:
- CUDAExecutor* parent_; // Parent executor. Not owned.
cudnnTensorDescriptor_t handle_; // Owned.
SE_DISALLOW_COPY_AND_ASSIGN(ScopedTensorDescriptor);
@@ -617,12 +464,10 @@ class ScopedTensorDescriptor {
// Turns a FilterDescriptor structure into a cudnn filter handle within a scope.
class ScopedFilterDescriptor {
public:
- ScopedFilterDescriptor(CUDAExecutor* parent,
- const FilterDescriptor& filter_descriptor,
- const BatchDescriptor& batch_descriptor,
+ ScopedFilterDescriptor(const FilterDescriptor& filter_descriptor,
cudnnDataType_t elem_type)
- : parent_(parent), handle_(nullptr) {
- cudnnStatus_t status = wrap::cudnnCreateFilterDescriptor(parent_, &handle_);
+ : handle_(nullptr) {
+ cudnnStatus_t status = cudnnCreateFilterDescriptor(&handle_);
if (status != CUDNN_STATUS_SUCCESS) {
LOG(FATAL) << "could not create cudnn filter descriptor: "
<< ToString(status);
@@ -656,11 +501,11 @@ class ScopedFilterDescriptor {
const auto& spatial_dims = filter_descriptor.input_filter_dims();
std::copy(spatial_dims.begin(), spatial_dims.end(), dims.begin() + 2);
- status = wrap::cudnnSetFilterNdDescriptor(parent_, handle_, elem_type,
+ status = cudnnSetFilterNdDescriptor(handle_, elem_type,
#if CUDNN_VERSION >= 5000
- format,
+ format,
#endif
- dims.size(), dims.data());
+ dims.size(), dims.data());
if (status != CUDNN_STATUS_SUCCESS) {
LOG(FATAL) << "could not set cudnn filter descriptor: "
<< ToString(status);
@@ -668,7 +513,7 @@ class ScopedFilterDescriptor {
}
~ScopedFilterDescriptor() {
- cudnnStatus_t status = wrap::cudnnDestroyFilterDescriptor(parent_, handle_);
+ cudnnStatus_t status = cudnnDestroyFilterDescriptor(handle_);
if (status != CUDNN_STATUS_SUCCESS) {
LOG(ERROR) << "could not destroy cudnn filter descriptor: "
<< ToString(status);
@@ -678,11 +523,7 @@ class ScopedFilterDescriptor {
cudnnFilterDescriptor_t handle() const { return handle_; }
private:
- // Parent executor object. Not owned.
- CUDAExecutor* parent_;
-
- // cudnn filter descriptor this object creates. Owned.
- cudnnFilterDescriptor_t handle_;
+ cudnnFilterDescriptor_t handle_; // Owned.
SE_DISALLOW_COPY_AND_ASSIGN(ScopedFilterDescriptor);
};
@@ -736,11 +577,10 @@ static bool BatchnormSpatialPersistentEnabled() {
class ScopedConvolutionDescriptor {
public:
ScopedConvolutionDescriptor(
- CUDAExecutor* parent, const ConvolutionDescriptor& convolution_descriptor,
+ const ConvolutionDescriptor& convolution_descriptor,
cudnnDataType_t data_type)
- : parent_(parent), handle_(nullptr) {
- cudnnStatus_t status =
- wrap::cudnnCreateConvolutionDescriptor(parent_, &handle_);
+ : handle_(nullptr) {
+ cudnnStatus_t status = cudnnCreateConvolutionDescriptor(&handle_);
if (status != CUDNN_STATUS_SUCCESS) {
LOG(FATAL) << "could not create cudnn convolution descriptor: "
<< ToString(status);
@@ -766,9 +606,9 @@ class ScopedConvolutionDescriptor {
std::transform(dilations64.cbegin(), dilations64.cend(), dilations.begin(),
&CheckedNarrowing<int64, int>);
- status = wrap::cudnnSetConvolutionNdDescriptor(
- parent_, handle_, convolution_descriptor.ndims(), padding.data(),
- strides.data(), dilations.data(),
+ status = cudnnSetConvolutionNdDescriptor(
+ handle_, convolution_descriptor.ndims(), padding.data(), strides.data(),
+ dilations.data(),
// NOTE(keveman): cuDNN supports convolution and cross correlation.
// However, almost all the use cases do cross correlation, so just
// hard coding it here.
@@ -785,8 +625,8 @@ class ScopedConvolutionDescriptor {
#if CUDNN_MAJOR >= 7
VLOG(2) << "Requesting grouped convolution: "
<< convolution_descriptor.group_count();
- status = wrap::cudnnSetConvolutionGroupCount(
- parent_, handle_, convolution_descriptor.group_count());
+ status = cudnnSetConvolutionGroupCount(
+ handle_, convolution_descriptor.group_count());
if (status != CUDNN_STATUS_SUCCESS) {
LOG(FATAL) << "could not set cudnn convolution group count: "
<< ToString(status);
@@ -802,8 +642,7 @@ class ScopedConvolutionDescriptor {
cudnnMathType_t math_type =
(use_tensor_op_math ? CUDNN_TENSOR_OP_MATH : CUDNN_DEFAULT_MATH);
if (TensorOpMathEnabled()) {
- cudnnStatus_t status =
- wrap::cudnnSetConvolutionMathType(parent_, handle_, math_type);
+ cudnnStatus_t status = cudnnSetConvolutionMathType(handle_, math_type);
if (status != CUDNN_STATUS_SUCCESS) {
LOG(FATAL) << "could not set cudnn convolution math type: "
<< ToString(status);
@@ -813,8 +652,7 @@ class ScopedConvolutionDescriptor {
}
~ScopedConvolutionDescriptor() {
- cudnnStatus_t status =
- wrap::cudnnDestroyConvolutionDescriptor(parent_, handle_);
+ cudnnStatus_t status = cudnnDestroyConvolutionDescriptor(handle_);
if (status != CUDNN_STATUS_SUCCESS) {
LOG(ERROR) << "could not destroy cudnn convolution descriptor: "
<< ToString(status);
@@ -824,7 +662,6 @@ class ScopedConvolutionDescriptor {
cudnnConvolutionDescriptor_t handle() const { return handle_; }
private:
- CUDAExecutor* parent_; // Parent executor. Not owned.
cudnnConvolutionDescriptor_t handle_; // Owned.
SE_DISALLOW_COPY_AND_ASSIGN(ScopedConvolutionDescriptor);
@@ -834,11 +671,9 @@ class ScopedConvolutionDescriptor {
// within a scope.
class ScopedPoolingDescriptor {
public:
- ScopedPoolingDescriptor(CUDAExecutor* parent,
- const PoolingDescriptor& pooling_descriptor)
- : parent_(parent), handle_(nullptr) {
- cudnnStatus_t status =
- wrap::cudnnCreatePoolingDescriptor(parent_, &handle_);
+ explicit ScopedPoolingDescriptor(const PoolingDescriptor& pooling_descriptor)
+ : handle_(nullptr) {
+ cudnnStatus_t status = cudnnCreatePoolingDescriptor(&handle_);
if (status != CUDNN_STATUS_SUCCESS) {
LOG(FATAL) << "could not create cudnn pooling descriptor: "
<< ToString(status);
@@ -858,8 +693,8 @@ class ScopedPoolingDescriptor {
std::transform(shape64.cbegin(), shape64.cend(), shape.begin(),
&CheckedNarrowing<int64, int>);
bool propagate_nans = pooling_descriptor.propagate_nans();
- status = wrap::cudnnSetPoolingNdDescriptor(
- parent_, handle_,
+ status = cudnnSetPoolingNdDescriptor(
+ handle_,
(pooling_descriptor.mode() == dnn::PoolingMode::kMaximum
? CUDNN_POOLING_MAX
: CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING),
@@ -873,8 +708,7 @@ class ScopedPoolingDescriptor {
}
}
~ScopedPoolingDescriptor() {
- cudnnStatus_t status =
- wrap::cudnnDestroyPoolingDescriptor(parent_, handle_);
+ cudnnStatus_t status = cudnnDestroyPoolingDescriptor(handle_);
if (status != CUDNN_STATUS_SUCCESS) {
LOG(ERROR) << "could not destroy cudnn pooling descriptor: "
<< ToString(status);
@@ -884,7 +718,6 @@ class ScopedPoolingDescriptor {
cudnnPoolingDescriptor_t handle() const { return handle_; }
private:
- CUDAExecutor* parent_; // Parent executor. Not owned.
cudnnPoolingDescriptor_t handle_; // Owned.
SE_DISALLOW_COPY_AND_ASSIGN(ScopedPoolingDescriptor);
@@ -893,10 +726,10 @@ class ScopedPoolingDescriptor {
// Turns a NormalizeDescriptor structure into a cudnn LRN descriptor handle.
class ScopedNormalizeDescriptor {
public:
- ScopedNormalizeDescriptor(CUDAExecutor* parent,
- const NormalizeDescriptor& normalize_descriptor)
- : parent_(parent), handle_(nullptr) {
- cudnnStatus_t status = wrap::cudnnCreateLRNDescriptor(parent_, &handle_);
+ explicit ScopedNormalizeDescriptor(
+ const NormalizeDescriptor& normalize_descriptor)
+ : handle_(nullptr) {
+ cudnnStatus_t status = cudnnCreateLRNDescriptor(&handle_);
if (status != CUDNN_STATUS_SUCCESS) {
LOG(FATAL) << "could not create cudnn LRN descriptor: "
<< ToString(status);
@@ -922,15 +755,14 @@ class ScopedNormalizeDescriptor {
double lrnBeta = normalize_descriptor.beta();
double lrnK = normalize_descriptor.bias();
- status = wrap::cudnnSetLRNDescriptor(parent_, handle_, lrnN, lrnAlpha,
- lrnBeta, lrnK);
+ status = cudnnSetLRNDescriptor(handle_, lrnN, lrnAlpha, lrnBeta, lrnK);
if (status != CUDNN_STATUS_SUCCESS) {
LOG(FATAL) << "could not set cudnn LRN descriptor: " << ToString(status);
}
}
~ScopedNormalizeDescriptor() {
- cudnnStatus_t status = wrap::cudnnDestroyLRNDescriptor(parent_, handle_);
+ cudnnStatus_t status = cudnnDestroyLRNDescriptor(handle_);
if (status != CUDNN_STATUS_SUCCESS) {
LOG(ERROR) << "could not destroy cudnn LRN descriptor: "
<< ToString(status);
@@ -940,7 +772,6 @@ class ScopedNormalizeDescriptor {
cudnnLRNDescriptor_t handle() const { return handle_; }
private:
- CUDAExecutor* parent_; // Parent executor. Not owned.
cudnnLRNDescriptor_t handle_; // Owned.
SE_DISALLOW_COPY_AND_ASSIGN(ScopedNormalizeDescriptor);
@@ -951,13 +782,11 @@ class ScopedNormalizeDescriptor {
// descriptor handle within a scope.
class ScopedActivationDescriptor {
public:
- ScopedActivationDescriptor(CUDAExecutor* parent,
- dnn::ActivationMode activation_mode,
+ ScopedActivationDescriptor(dnn::ActivationMode activation_mode,
cudnnNanPropagation_t nan_propagation,
double value_max)
- : parent_(parent), handle_(nullptr) {
- cudnnStatus_t status =
- wrap::cudnnCreateActivationDescriptor(parent_, &handle_);
+ : handle_(nullptr) {
+ cudnnStatus_t status = cudnnCreateActivationDescriptor(&handle_);
if (status != CUDNN_STATUS_SUCCESS) {
LOG(FATAL) << "could not create cudnn activation descriptor: "
<< ToString(status);
@@ -988,8 +817,8 @@ class ScopedActivationDescriptor {
<< static_cast<int>(activation_mode);
}
- status = wrap::cudnnSetActivationDescriptor(parent_, handle_, mode,
- nan_propagation, relu_ceiling);
+ status = cudnnSetActivationDescriptor(handle_, mode, nan_propagation,
+ relu_ceiling);
if (status != CUDNN_STATUS_SUCCESS) {
LOG(FATAL) << "could not set cudnn activation descriptor: "
<< ToString(status);
@@ -997,8 +826,7 @@ class ScopedActivationDescriptor {
}
~ScopedActivationDescriptor() {
- cudnnStatus_t status =
- wrap::cudnnDestroyActivationDescriptor(parent_, handle_);
+ cudnnStatus_t status = cudnnDestroyActivationDescriptor(handle_);
if (status != CUDNN_STATUS_SUCCESS) {
LOG(ERROR) << "could not destroy cudnn activation descriptor: "
<< ToString(status);
@@ -1008,14 +836,12 @@ class ScopedActivationDescriptor {
cudnnActivationDescriptor_t handle() const { return handle_; }
private:
- CUDAExecutor* parent_; // Parent executor. Not owned.
cudnnActivationDescriptor_t handle_; // Owned.
SE_DISALLOW_COPY_AND_ASSIGN(ScopedActivationDescriptor);
};
#endif
-namespace {
cudnnDataType_t ToCudnnDataType(
dnn::DataType data_type,
dnn::DataLayout data_layout = dnn::DataLayout::kBatchDepthYX) {
@@ -1090,8 +916,6 @@ class MixinBase : public Base {};
template <>
class MixinBase<void> {};
-} // namespace
-
#if CUDNN_VERSION >= 5000
#define CUDNN_RETURN_IF_FAIL(STATUS, ...) \
@@ -1102,6 +926,7 @@ class MixinBase<void> {};
return; \
}
+// TODO(csigg): Remove inheritance for code reuse.
template <typename Base>
class CudnnDescriptorCommon : public MixinBase<Base> {
public:
@@ -1115,12 +940,11 @@ class CudnnDescriptorCommon : public MixinBase<Base> {
class CudnnDropoutDescriptor : public CudnnDescriptorCommon<void> {
public:
- CudnnDropoutDescriptor(CUDAExecutor* parent, cudnnHandle_t cudnn_handle,
- float dropout, uint64 seed,
+ CudnnDropoutDescriptor(const CudnnHandle& cudnn, float dropout, uint64 seed,
ScratchAllocator* state_allocator)
- : parent_(parent), handle_(nullptr) {
+ : handle_(nullptr) {
cudnnStatus_t status;
- status = wrap::cudnnCreateDropoutDescriptor(parent_, &handle_);
+ status = cudnnCreateDropoutDescriptor(&handle_);
CUDNN_RETURN_IF_FAIL(status, "Failed to create dropout descriptor");
if (dropout == 0.f) {
@@ -1130,8 +954,7 @@ class CudnnDropoutDescriptor : public CudnnDescriptorCommon<void> {
DeviceMemory<uint8> state_memory;
if (state_allocator) {
size_t state_sizes_in_bytes = 0;
- status = wrap::cudnnDropoutGetStatesSize(parent_, cudnn_handle,
- &state_sizes_in_bytes);
+ status = cudnnDropoutGetStatesSize(cudnn.handle(), &state_sizes_in_bytes);
CUDNN_RETURN_IF_FAIL(status, "Failed to query dropout state sizes");
auto allocated =
@@ -1146,9 +969,9 @@ class CudnnDropoutDescriptor : public CudnnDescriptorCommon<void> {
return;
}
}
- status = wrap::cudnnSetDropoutDescriptor(parent_, handle_, cudnn_handle,
- dropout, state_memory.opaque(),
- state_memory.size(), seed);
+ status = cudnnSetDropoutDescriptor(handle_, cudnn.handle(), dropout,
+ state_memory.opaque(),
+ state_memory.size(), seed);
CUDNN_RETURN_IF_FAIL(
status, port::StrCat(
"Failed to set dropout descriptor with state memory size: ",
@@ -1156,11 +979,9 @@ class CudnnDropoutDescriptor : public CudnnDescriptorCommon<void> {
}
~CudnnDropoutDescriptor() {
- if (handle_) {
- cudnnStatus_t status =
- wrap::cudnnDestroyDropoutDescriptor(parent_, handle_);
- CUDNN_RETURN_IF_FAIL(status, "Failed to destroy Cudnn dropout handle: ");
- }
+ cudnnStatus_t status = cudnnDestroyDropoutDescriptor(handle_);
+ // TODO(csigg): This is a no-op (error is not reported). Same below.
+ CUDNN_RETURN_IF_FAIL(status, "Failed to destroy Cudnn dropout handle: ");
}
cudnnDropoutDescriptor_t handle() const {
@@ -1169,8 +990,7 @@ class CudnnDropoutDescriptor : public CudnnDescriptorCommon<void> {
}
private:
- CUDAExecutor* parent_;
- cudnnDropoutDescriptor_t handle_;
+ cudnnDropoutDescriptor_t handle_; // Owned.
float dropout_;
uint64 seed_;
SE_DISALLOW_COPY_AND_ASSIGN(CudnnDropoutDescriptor);
@@ -1180,10 +1000,10 @@ class CudnnRnnParamsDescriptor : public CudnnDescriptorCommon<void> {
public:
typedef dnn::RnnDescriptor::ParamsRegion ParamsRegion;
typedef dnn::RnnDescriptor::ParamsRegions ParamsRegions;
- CudnnRnnParamsDescriptor(CUDAExecutor* parent, cudnnHandle_t cudnn_handle,
+ CudnnRnnParamsDescriptor(const CudnnHandle& cudnn,
const CudnnRnnDescriptor& rnn_desc);
~CudnnRnnParamsDescriptor() {
- cudnnStatus_t status = wrap::cudnnDestroyFilterDescriptor(parent_, handle_);
+ cudnnStatus_t status = cudnnDestroyFilterDescriptor(handle_);
CUDNN_RETURN_IF_FAIL(status, "Failed to destroy RNN filter descriptor");
}
cudnnFilterDescriptor_t handle() const {
@@ -1202,7 +1022,6 @@ class CudnnRnnParamsDescriptor : public CudnnDescriptorCommon<void> {
private:
int GetRegionCountPerLayer() const;
- CUDAExecutor* parent_;
cudnnFilterDescriptor_t handle_;
const CudnnRnnDescriptor* rnn_desc_;
int64 params_size_in_bytes_;
@@ -1211,19 +1030,20 @@ class CudnnRnnParamsDescriptor : public CudnnDescriptorCommon<void> {
SE_DISALLOW_COPY_AND_ASSIGN(CudnnRnnParamsDescriptor);
};
+} // namespace
+
class CudnnRnnDescriptor : public CudnnDescriptorCommon<dnn::RnnDescriptor> {
public:
- CudnnRnnDescriptor(CUDAExecutor* parent, cudnnHandle_t cudnn_handle,
- int num_layers, int hidden_size, int input_size,
- int batch_size, cudnnRNNInputMode_t input_mode,
+ CudnnRnnDescriptor(const CudnnHandle& cudnn, int num_layers, int hidden_size,
+ int input_size, int batch_size,
+ cudnnRNNInputMode_t input_mode,
cudnnDirectionMode_t direction_mode,
cudnnRNNMode_t rnn_mode, cudnnDataType_t data_type,
cudnnDataType_t compute_type,
const dnn::AlgorithmConfig& algorithm_config,
float dropout, uint64 seed,
ScratchAllocator* state_allocator)
- : parent_(parent),
- rnn_desc_(nullptr),
+ : rnn_desc_(nullptr),
num_layers_(num_layers),
hidden_size_(hidden_size),
input_size_(input_size),
@@ -1238,21 +1058,21 @@ class CudnnRnnDescriptor : public CudnnDescriptorCommon<dnn::RnnDescriptor> {
compute_type_(compute_type),
algorithm_config_(algorithm_config) {
// Create the dropout handle.
- cudnn_dropout_desc_.reset(new CudnnDropoutDescriptor(
- parent, cudnn_handle, dropout, seed, state_allocator));
+ cudnn_dropout_desc_.reset(
+ new CudnnDropoutDescriptor(cudnn, dropout, seed, state_allocator));
if (!cudnn_dropout_desc_->ok()) {
SetFailure(cudnn_dropout_desc_->Status());
return;
}
// Create the RNN handle
- cudnnStatus_t status = wrap::cudnnCreateRNNDescriptor(parent_, &rnn_desc_);
+ cudnnStatus_t status = cudnnCreateRNNDescriptor(&rnn_desc_);
CUDNN_RETURN_IF_FAIL(status, "Unable to create RNN descriptor");
#if CUDNN_VERSION >= 6000
// TODO: allow the user to choose an algorithm.
rnn_algo_ = ToCudnnRNNAlgo(algorithm_config_.algorithm());
- status = wrap::cudnnSetRNNDescriptor_v6(
- parent, cudnn_handle, /*rnnDesc=*/rnn_desc_, /*hiddenSize=*/hidden_size,
+ status = cudnnSetRNNDescriptor_v6(
+ cudnn.handle(), /*rnnDesc=*/rnn_desc_, /*hiddenSize=*/hidden_size,
/*numLayers=*/num_layers, /*dropoutDesc=*/dropout_handle(),
/*inputMode=*/input_mode, /*direction=*/direction_mode,
/*mode=*/rnn_mode, /*algo=*/rnn_algo_, /*dataType=*/compute_type);
@@ -1264,26 +1084,25 @@ class CudnnRnnDescriptor : public CudnnDescriptorCommon<dnn::RnnDescriptor> {
if (rnn_algo_ == CUDNN_RNN_ALGO_PERSIST_DYNAMIC) {
CHECK_GE(batch_size_, 0);
- status = wrap::cudnnCreatePersistentRNNPlan(
- parent, rnn_desc_, batch_size_, data_type_, &rnn_plan_);
+ status = cudnnCreatePersistentRNNPlan(rnn_desc_, batch_size_, data_type_,
+ &rnn_plan_);
CUDNN_RETURN_IF_FAIL(status, "Unable to create persistent RNN plan.");
- status = wrap::cudnnSetPersistentRNNPlan(parent, rnn_desc_, rnn_plan_);
+ status = cudnnSetPersistentRNNPlan(rnn_desc_, rnn_plan_);
CUDNN_RETURN_IF_FAIL(status, "Unable to update persistent RNN plan.");
}
#else
CHECK(algorithm_config_.is_default())
<< "Non-default algorithm not supported for CUDA version < 6.0";
- status = wrap::cudnnSetRNNDescriptor(
- parent, rnn_desc_ /*rnnDesc*/, hidden_size /*hiddenSize*/,
- num_layers /*numLayers*/, dropout_handle() /*dropoutDesc*/,
- input_mode /*inputMode*/, direction_mode /*direction*/,
- rnn_mode /*mode*/, compute_type /*dataType*/);
+ status = cudnnSetRNNDescriptor(
+ /*rnnDesc=*/rnn_desc_, /*hiddenSize=*/hidden_size,
+ /*numLayers=*/num_layers, /*dropoutDesc=*/dropout_handle(),
+ /*inputMode=*/input_mode, /*direction=*/direction_mode,
+ /*mode=*/rnn_mode, /*dataType=*/compute_type);
CUDNN_RETURN_IF_FAIL(status, "Unable to update RNN descriptor");
#endif
// Create the params handle.
- cudnn_params_desc_.reset(
- new CudnnRnnParamsDescriptor(parent, cudnn_handle, *this));
+ cudnn_params_desc_.reset(new CudnnRnnParamsDescriptor(cudnn, *this));
if (!cudnn_params_desc_->ok()) {
SetFailure(cudnn_params_desc_->Status());
return;
@@ -1295,11 +1114,11 @@ class CudnnRnnDescriptor : public CudnnDescriptorCommon<dnn::RnnDescriptor> {
cudnnStatus_t status;
#if CUDNN_VERSION >= 6000
if (rnn_algo_ == CUDNN_RNN_ALGO_PERSIST_DYNAMIC && rnn_plan_) {
- status = wrap::cudnnDestroyPersistentRNNPlan(parent_, rnn_plan_);
+ status = cudnnDestroyPersistentRNNPlan(rnn_plan_);
CUDNN_RETURN_IF_FAIL(status, "Unable to destroy persistent RNN plan.");
}
#endif
- status = wrap::cudnnDestroyRNNDescriptor(parent_, rnn_desc_);
+ status = cudnnDestroyRNNDescriptor(rnn_desc_);
CUDNN_RETURN_IF_FAIL(status, "Unable to destroy RNN descriptor");
}
}
@@ -1308,11 +1127,9 @@ class CudnnRnnDescriptor : public CudnnDescriptorCommon<dnn::RnnDescriptor> {
cudnnMathType_t math_type =
(use_tensor_op_math ? CUDNN_TENSOR_OP_MATH : CUDNN_DEFAULT_MATH);
if (RnnTensorOpMathEnabled()) {
- cudnnStatus_t status =
- wrap::cudnnSetRNNMatrixMathType(parent_, rnn_desc_, math_type);
+ cudnnStatus_t status = cudnnSetRNNMatrixMathType(rnn_desc_, math_type);
if (status != CUDNN_STATUS_SUCCESS) {
- LOG(FATAL) << "could not set cudnn RNN math type: "
- << ToString(status);
+ LOG(FATAL) << "could not set cudnn RNN math type: " << ToString(status);
}
}
#endif
@@ -1354,7 +1171,6 @@ class CudnnRnnDescriptor : public CudnnDescriptorCommon<dnn::RnnDescriptor> {
}
private:
- CUDAExecutor* parent_;
cudnnRNNDescriptor_t rnn_desc_;
int num_layers_;
int hidden_size_;
@@ -1377,30 +1193,28 @@ class CudnnRnnDescriptor : public CudnnDescriptorCommon<dnn::RnnDescriptor> {
SE_DISALLOW_COPY_AND_ASSIGN(CudnnRnnDescriptor);
};
+namespace {
+
CudnnRnnParamsDescriptor::CudnnRnnParamsDescriptor(
- CUDAExecutor* parent, cudnnHandle_t cudnn_handle,
- const CudnnRnnDescriptor& rnn_desc)
- : parent_(parent),
- handle_(nullptr),
- rnn_desc_(&rnn_desc),
- params_size_in_bytes_(0) {
+ const CudnnHandle& cudnn, const CudnnRnnDescriptor& rnn_desc)
+ : handle_(nullptr), rnn_desc_(&rnn_desc), params_size_in_bytes_(0) {
cudnnTensorDescriptor_t input_desc = nullptr;
{
// Query the params size.
- auto status = wrap::cudnnCreateTensorDescriptor(parent, &input_desc);
+ auto status = cudnnCreateTensorDescriptor(&input_desc);
CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to create tensor descriptor");
int dims[] = {1, rnn_desc.input_size(), 1};
int strides[] = {dims[1] * dims[2], dims[2], 1};
- status = wrap::cudnnSetTensorNdDescriptor(
- parent, input_desc /*tensorDesc*/, rnn_desc.data_type() /*dataType*/,
- sizeof(dims) / sizeof(dims[0]) /*nbDims*/, dims /*dimA*/,
- strides /*strideA*/);
+ status = cudnnSetTensorNdDescriptor(
+ /*tensorDesc=*/input_desc, rnn_desc.data_type() /*dataType*/,
+ sizeof(dims) / sizeof(dims[0]) /*nbDims*/, /*dimA=*/dims,
+ /*strideA=*/strides);
CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to set tensor descriptor");
size_t params_size = 0;
- status = wrap::cudnnGetRNNParamsSize(
- parent, cudnn_handle /*handle*/, rnn_desc.handle() /*rnnDesc*/,
- input_desc /*xDesc*/, &params_size /*sizeInBytes*/,
+ status = cudnnGetRNNParamsSize(
+ cudnn.handle() /*handle*/, rnn_desc.handle() /*rnnDesc*/,
+ /*xDesc=*/input_desc, /*sizeInBytes=*/&params_size,
rnn_desc.data_type() /*dataType*/);
CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to get RNN parameter size");
params_size_in_bytes_ = static_cast<int64>(params_size);
@@ -1408,13 +1222,13 @@ CudnnRnnParamsDescriptor::CudnnRnnParamsDescriptor(
{
// Create the params descriptor.
- auto status = wrap::cudnnCreateFilterDescriptor(parent, &handle_);
+ auto status = cudnnCreateFilterDescriptor(&handle_);
CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to create RNN filter descriptor");
int dims[] = {static_cast<int>(params_size_in_bytes_), 1, 1};
- status = wrap::cudnnSetFilterNdDescriptor(
- parent, handle_ /*filterDesc*/, rnn_desc.data_type() /*dataType*/,
- CUDNN_TENSOR_NCHW /*format*/, sizeof(dims) / sizeof(dims[0]) /*nbDims*/,
- dims /*filterDimA*/);
+ status = cudnnSetFilterNdDescriptor(
+ /*filterDesc=*/handle_, rnn_desc.data_type() /*dataType*/,
+ /*format=*/CUDNN_TENSOR_NCHW, sizeof(dims) / sizeof(dims[0]) /*nbDims*/,
+ /*filterDimA=*/dims);
CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to update RNN filter descriptor");
}
@@ -1422,8 +1236,7 @@ CudnnRnnParamsDescriptor::CudnnRnnParamsDescriptor(
// Create the weights and biases into the params buffer
int region_count_per_layer = GetRegionCountPerLayer();
cudnnFilterDescriptor_t region_desc_handle = nullptr;
- auto status =
- wrap::cudnnCreateFilterDescriptor(parent, &region_desc_handle);
+ auto status = cudnnCreateFilterDescriptor(&region_desc_handle);
CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to create filter descriptor");
const int layer_count = rnn_desc.direction_mode() == CUDNN_UNIDIRECTIONAL
? rnn_desc.num_layers()
@@ -1433,21 +1246,21 @@ CudnnRnnParamsDescriptor::CudnnRnnParamsDescriptor(
for (int type = 0; type < 2; type++) {
void* offset = nullptr;
if (type == 0) {
- status = wrap::cudnnGetRNNLinLayerMatrixParams(
- parent, cudnn_handle /*handle*/, rnn_desc.handle() /*rnnDesc*/,
- layer /*layer*/, input_desc /*xDesc*/, handle_ /*wDesc*/,
- nullptr /*w*/, region /*linLayerID*/,
- region_desc_handle /*linLayerMatDesc*/,
- &offset /*linLayerMat*/);
+ status = cudnnGetRNNLinLayerMatrixParams(
+ cudnn.handle() /*handle*/, rnn_desc.handle() /*rnnDesc*/,
+ /*layer=*/layer, /*xDesc=*/input_desc, /*wDesc=*/handle_,
+ /*w=*/nullptr, /*linLayerID=*/region,
+ /*linLayerMatDesc=*/region_desc_handle,
+ /*linLayerMat=*/&offset);
CUDNN_RETURN_IF_FAIL(
status, "Cudnn fails to call cudnnGetRNNLinLayerMatrixParams");
} else {
- status = wrap::cudnnGetRNNLinLayerBiasParams(
- parent, cudnn_handle /*rnnDesc*/, rnn_desc.handle() /*rnnDesc*/,
- layer /*layer*/, input_desc /*xDesc*/, handle_ /*wDesc*/,
- nullptr /*w*/, region /*linLayerID*/,
- region_desc_handle /*linLayerBiasDesc*/,
- &offset /*linLayerBias*/);
+ status = cudnnGetRNNLinLayerBiasParams(
+          cudnn.handle() /*handle*/, rnn_desc.handle() /*rnnDesc*/,
+ /*layer=*/layer, /*xDesc=*/input_desc, /*wDesc=*/handle_,
+ /*w=*/nullptr, /*linLayerID=*/region,
+ /*linLayerBiasDesc=*/region_desc_handle,
+ /*linLayerBias=*/&offset);
CUDNN_RETURN_IF_FAIL(
status, "Cudnn fails to call cudnnGetRNNLinLayerBiasParams");
}
@@ -1455,15 +1268,15 @@ CudnnRnnParamsDescriptor::CudnnRnnParamsDescriptor(
cudnnDataType_t data_type;
cudnnTensorFormat_t tensor_format;
int n_dims;
- status = wrap::cudnnGetFilterNdDescriptor(
- parent, region_desc_handle /*filterDesc*/,
+ status = cudnnGetFilterNdDescriptor(
+ /*filterDesc=*/region_desc_handle,
sizeof(dims) / sizeof(dims[0]) /*nbDimsRequested*/,
- &data_type /*dataType*/, &tensor_format /*format*/,
- &n_dims /*nbDims*/, dims /*filterDimA*/);
+ /*dataType=*/&data_type, /*format=*/&tensor_format,
+ /*nbDims=*/&n_dims, /*filterDimA=*/dims);
CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to get filter description");
int64 size = dims[0] * dims[1] * dims[2] *
CudnnDataTypeToByteSize(rnn_desc.data_type());
- auto region = ParamsRegion{reinterpret_cast<int64>(offset), size};
+ ParamsRegion region = {reinterpret_cast<int64>(offset), size};
if (type == 0) {
weights_.push_back(region);
} else {
@@ -1472,13 +1285,13 @@ CudnnRnnParamsDescriptor::CudnnRnnParamsDescriptor(
}
}
}
- status = wrap::cudnnDestroyFilterDescriptor(parent, region_desc_handle);
+ status = cudnnDestroyFilterDescriptor(region_desc_handle);
CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to destroy filter descriptor");
}
{
// Release the dummy input tensor descriptor.
- auto status = wrap::cudnnDestroyTensorDescriptor(parent, input_desc);
+ auto status = cudnnDestroyTensorDescriptor(input_desc);
CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to destroy tensor descriptor");
}
}
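
In the constructor above, cudnnGetRNNLinLayerMatrixParams and cudnnGetRNNLinLayerBiasParams are called with /*w=*/nullptr, so the pointer each call returns serves, in effect, as the region's byte offset from the start of the packed params buffer; that value is what reinterpret_cast<int64>(offset) records in a ParamsRegion. Assuming that interpretation, here is a host-side sketch of turning such an offset and size back into a usable address (the struct below is a stand-in, not the ParamsRegion defined elsewhere in this file):

#include <cstdint>
#include <iostream>

// Stand-in for the ParamsRegion in this file: a byte offset into the packed
// RNN params buffer plus the region's size in bytes. Illustrative only.
struct ParamsRegion {
  int64_t offset;
  int64_t size;
};

// Recover the address of a weight/bias region given the real base pointer of
// the params buffer (device memory in the real code; a host array here).
const void* RegionStart(const void* params_base, const ParamsRegion& region) {
  return static_cast<const char*>(params_base) + region.offset;
}

int main() {
  unsigned char params[256] = {};
  ParamsRegion region{/*offset=*/128, /*size=*/64};
  const void* start = RegionStart(params, region);
  std::cout << "region begins "
            << (static_cast<const unsigned char*>(start) - params)
            << " bytes past the base, spans " << region.size << " bytes\n";
}
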
@@ -1498,6 +1311,8 @@ int CudnnRnnParamsDescriptor::GetRegionCountPerLayer() const {
}
}
+} // namespace
+
class CudnnRnnSequenceTensorDescriptor
: public CudnnDescriptorCommon<dnn::RnnSequenceTensorDescriptor> {
public:
@@ -1517,14 +1332,14 @@ class CudnnRnnSequenceTensorDescriptor
SetFailure(port::Status(port::error::UNKNOWN, error_msg));
return;
}
- cudnnStatus_t status = wrap::cudnnCreateTensorDescriptor(parent, &handle);
+ cudnnStatus_t status = cudnnCreateTensorDescriptor(&handle);
CUDNN_RETURN_IF_FAIL(status, "Failed to create tensor descriptor");
int dims[] = {batch_size, data_size, 1};
int strides[] = {dims[1] * dims[2], dims[2], 1};
- status = wrap::cudnnSetTensorNdDescriptor(
- parent, handle /*tensorDesc*/, data_type /*dataType*/,
- sizeof(dims) / sizeof(dims[0]) /*nbDims*/, dims /*dimA*/,
- strides /*strideA*/);
+ status = cudnnSetTensorNdDescriptor(
+ /*tensorDesc=*/handle, /*dataType=*/data_type,
+ sizeof(dims) / sizeof(dims[0]) /*nbDims*/, /*dimA=*/dims,
+ /*strideA=*/strides);
CUDNN_RETURN_IF_FAIL(status, "Failed to update tensor descriptor");
// Replicate handle across the number of steps.
handles_.assign(seq_length, handle);
@@ -1532,8 +1347,7 @@ class CudnnRnnSequenceTensorDescriptor
~CudnnRnnSequenceTensorDescriptor() override {
// Only the first one needs to be destroyed. All others are the same.
- cudnnStatus_t status =
- wrap::cudnnDestroyTensorDescriptor(parent_, handles_[0]);
+ cudnnStatus_t status = cudnnDestroyTensorDescriptor(handles_[0]);
CUDNN_RETURN_IF_FAIL(status,
"Failed to destroy sequence tensor descriptor");
}
@@ -1570,21 +1384,20 @@ class CudnnRnnStateTensorDescriptor
batch_size_(batch_size),
data_size_(data_size),
data_type_(data_type) {
- cudnnStatus_t status = wrap::cudnnCreateTensorDescriptor(parent, &handle_);
+ cudnnStatus_t status = cudnnCreateTensorDescriptor(&handle_);
CUDNN_RETURN_IF_FAIL(status, "Failed to create tensor descriptor");
int dims[] = {num_layers, batch_size, data_size};
int strides[] = {dims[1] * dims[2], dims[2], 1};
- status = wrap::cudnnSetTensorNdDescriptor(
- parent, handle_ /*tensorDesc*/, data_type /*dataType*/,
- sizeof(dims) / sizeof(dims[0]) /*nbDims*/, dims /*dimA*/,
- strides /*strideA*/);
+ status = cudnnSetTensorNdDescriptor(
+ /*tensorDesc=*/handle_, /*dataType=*/data_type,
+ sizeof(dims) / sizeof(dims[0]) /*nbDims*/, /*dimA=*/dims,
+ /*strideA=*/strides);
CUDNN_RETURN_IF_FAIL(status, "Failed to update tensor descriptor");
}
~CudnnRnnStateTensorDescriptor() override {
    if (handle_) {
- cudnnStatus_t status =
- wrap::cudnnDestroyTensorDescriptor(parent_, handle_);
+ cudnnStatus_t status = cudnnDestroyTensorDescriptor(handle_);
CUDNN_RETURN_IF_FAIL(status, "Unable to destroy RNN state tensor");
}
}
@@ -1679,13 +1492,13 @@ bool ExtractAndCheckRnnForward(
return true;
}
-bool CheckRNNParameterSize(CUDAExecutor* parent, cudnnHandle_t cudnn_handle,
+bool CheckRNNParameterSize(const CudnnHandle& cudnn,
const CudnnRnnDescriptor& rnn_desc,
const CudnnRnnSequenceTensorDescriptor& input_desc) {
size_t params_size_in_bytes = 0;
- cudnnStatus_t status = wrap::cudnnGetRNNParamsSize(
- parent, cudnn_handle /*handle*/, rnn_desc.handle() /*rnnDesc*/,
- input_desc.handles()[0] /*xDesc*/, &params_size_in_bytes /*sizeInBytes*/,
+ cudnnStatus_t status = cudnnGetRNNParamsSize(
+ /*handle=*/cudnn.handle(), rnn_desc.handle() /*rnnDesc*/,
+ input_desc.handles()[0] /*xDesc*/, /*sizeInBytes=*/&params_size_in_bytes,
rnn_desc.data_type() /*dataType*/);
if (status != CUDNN_STATUS_SUCCESS) {
LOG(ERROR) << "Unable to check RNN param size: " << ToString(status);
@@ -1695,18 +1508,17 @@ bool CheckRNNParameterSize(CUDAExecutor* parent, cudnnHandle_t cudnn_handle,
rnn_desc.ParamsSizeInBytes();
}
-bool CreateRnnWorkspace(Stream* stream, CUDAExecutor* parent,
- cudnnHandle_t cudnn_handle,
+bool CreateRnnWorkspace(Stream* stream, const CudnnHandle& cudnn,
const CudnnRnnDescriptor& rnn_desc,
const CudnnRnnSequenceTensorDescriptor& input_desc,
ScratchAllocator* workspace_allocator,
DeviceMemory<uint8>* workspace) {
// Query the workspace size.
size_t workspace_size_in_bytes = 0;
- cudnnStatus_t status = wrap::cudnnGetRNNWorkspaceSize(
- parent, cudnn_handle /*handle*/, rnn_desc.handle() /*rnnDesc*/,
- input_desc.seq_length() /*seqLength*/, input_desc.handles() /*xDesc*/,
- &workspace_size_in_bytes /*sizeInBytes*/);
+ cudnnStatus_t status = cudnnGetRNNWorkspaceSize(
+ /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
+ /*seqLength=*/input_desc.seq_length(), /*xDesc=*/input_desc.handles(),
+ /*sizeInBytes=*/&workspace_size_in_bytes);
if (status != CUDNN_STATUS_SUCCESS) {
LOG(ERROR) << "Unable to query workspace size: " << ToString(status);
return false;
@@ -1758,25 +1570,18 @@ bool CudnnSupport::DoRnnForwardImpl(
return false;
}
- // check params size
- mutex_lock lock{dnn_handle_mutex_};
- auto set_stream_status =
- wrap::cudnnSetStream(this, stream, ToHandle(dnn_handle_));
- if (set_stream_status != CUDNN_STATUS_SUCCESS) {
- LOG(FATAL) << "failed to set stream for cudnn handle: "
- << ToString(set_stream_status);
- }
+ auto cudnn = cudnn_->GetHandle(parent_, stream);
- if (!CheckRNNParameterSize(parent_, ToHandle(dnn_handle_), rnn_desc,
- input_desc)) {
+ // check params size
+ if (!CheckRNNParameterSize(cudnn, rnn_desc, input_desc)) {
LOG(ERROR) << "Invalid parameters";
return false;
}
// create the workspace
DeviceMemory<uint8> workspace;
- if (!CreateRnnWorkspace(stream, parent_, ToHandle(dnn_handle_), rnn_desc,
- input_desc, workspace_allocator, &workspace)) {
+ if (!CreateRnnWorkspace(stream, cudnn, rnn_desc, input_desc,
+ workspace_allocator, &workspace)) {
LOG(ERROR) << "Unable to create rnn workspace";
return false;
}
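
The hunk above shows the pattern repeated throughout this file: an explicit mutex_lock on dnn_handle_mutex_ plus a wrap::cudnnSetStream call gives way to a single cudnn_->GetHandle(parent_, stream), whose return value holds the lock and a stream-bound cudnnHandle_t for as long as it stays in scope. Below is a minimal stand-alone sketch of that RAII idea only, with stand-in type and member names rather than the real StreamExecutor classes, and with the parent executor argument omitted.

#include <iostream>
#include <mutex>
#include <utility>

// Stand-ins for the real cuDNN/StreamExecutor types; illustrative only.
using cudnnHandle_t = int*;
struct Stream {};

// Guard returned by GetHandle(): while it is alive it owns the mutex that
// protects the shared cudnnHandle_t, which has already been bound to the
// caller's stream, so call sites simply pass cudnn.handle() to cuDNN calls.
class CudnnHandle {
 public:
  CudnnHandle(std::unique_lock<std::mutex> lock, cudnnHandle_t handle)
      : lock_(std::move(lock)), handle_(handle) {}
  cudnnHandle_t handle() const { return handle_; }

 private:
  std::unique_lock<std::mutex> lock_;
  cudnnHandle_t handle_;
};

// Owns the shared handle and its mutex, replacing the old pattern of a bare
// dnn_handle_ plus an explicit mutex_lock + cudnnSetStream at every call site.
class CudnnAccess {
 public:
  explicit CudnnAccess(cudnnHandle_t handle) : handle_(handle) {}

  // A null stream means "bind no stream": used when a descriptor is set up
  // without enqueueing work, as in createRnnDescriptor below.
  CudnnHandle GetHandle(Stream* stream) {
    std::unique_lock<std::mutex> lock(mutex_);
    (void)stream;  // the real code binds the stream to the handle here
    return CudnnHandle(std::move(lock), handle_);
  }

 private:
  std::mutex mutex_;
  cudnnHandle_t handle_;
};

int main() {
  int fake = 0;
  CudnnAccess cudnn_access(&fake);
  Stream stream;
  auto cudnn = cudnn_access.GetHandle(&stream);  // lock held until scope exit
  std::cout << "handle: " << cudnn.handle() << "\n";
}
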
@@ -1786,11 +1591,10 @@ bool CudnnSupport::DoRnnForwardImpl(
DeviceMemory<uint8> reserve_space;
if (is_training) {
size_t reserve_space_size_in_bytes = 0;
- cudnnStatus_t status = wrap::cudnnGetRNNTrainingReserveSize(
- parent_, ToHandle(dnn_handle_) /*handle*/,
- rnn_desc.handle() /*rnnDesc*/, model_dims.seq_length /*seqLength*/,
- input_desc.handles() /*xDesc*/,
- &reserve_space_size_in_bytes /*sizeInBytes*/);
+ cudnnStatus_t status = cudnnGetRNNTrainingReserveSize(
+ cudnn.handle() /*handle*/, rnn_desc.handle() /*rnnDesc*/,
+ /*seqLength=*/model_dims.seq_length, input_desc.handles() /*xDesc*/,
+ /*sizeInBytes=*/&reserve_space_size_in_bytes);
if (status != CUDNN_STATUS_SUCCESS) {
LOG(ERROR) << "Unable to query reserve space size: " << ToString(status);
return false;
@@ -1825,30 +1629,28 @@ bool CudnnSupport::DoRnnForwardImpl(
// make the forward call
cudnnStatus_t status;
if (!is_training) {
- status = wrap::cudnnRNNForwardInference(
- this, stream, ToHandle(dnn_handle_) /*handle*/,
- rnn_desc.handle() /*rnnDesc*/, model_dims.seq_length /*seqLength*/,
- input_desc.handles() /*xDesc*/, input_data.opaque() /*x*/,
- input_h_desc.handle() /*hxDesc*/, input_h_data.opaque() /*hx*/,
- input_c_desc.handle() /*cxDesc*/, input_c_data.opaque() /*cx*/,
- rnn_desc.params_handle() /*wDesc*/, params.opaque() /*w*/,
- output_desc.handles() /*yDesc*/, output_data->opaque() /*y*/,
- output_h_desc.handle() /*hyDesc*/, output_h_data->opaque() /*hy*/,
- output_c_desc.handle() /*cyDesc*/, output_c_data->opaque() /*cy*/,
- workspace.opaque() /*workspace*/,
+ status = cudnnRNNForwardInference(
+ cudnn.handle() /*handle*/, rnn_desc.handle() /*rnnDesc*/,
+ model_dims.seq_length /*seqLength*/, input_desc.handles() /*xDesc*/,
+ input_data.opaque() /*x*/, input_h_desc.handle() /*hxDesc*/,
+ input_h_data.opaque() /*hx*/, input_c_desc.handle() /*cxDesc*/,
+ input_c_data.opaque() /*cx*/, rnn_desc.params_handle() /*wDesc*/,
+ params.opaque() /*w*/, output_desc.handles() /*yDesc*/,
+ output_data->opaque() /*y*/, output_h_desc.handle() /*hyDesc*/,
+ output_h_data->opaque() /*hy*/, output_c_desc.handle() /*cyDesc*/,
+ output_c_data->opaque() /*cy*/, workspace.opaque() /*workspace*/,
workspace.size() /*workSpaceSizeInBytes*/);
} else {
- status = wrap::cudnnRNNForwardTraining(
- this, stream, ToHandle(dnn_handle_) /*handle*/,
- rnn_desc.handle() /*rnnDesc*/, model_dims.seq_length /*seqLength*/,
- input_desc.handles() /*xDesc*/, input_data.opaque() /*x*/,
- input_h_desc.handle() /*hxDesc*/, input_h_data.opaque() /*hx*/,
- input_c_desc.handle() /*cxDesc*/, input_c_data.opaque() /*cx*/,
- rnn_desc.params_handle() /*wDesc*/, params.opaque() /*w*/,
- output_desc.handles() /*yDesc*/, output_data->opaque() /*y*/,
- output_h_desc.handle() /*hyDesc*/, output_h_data->opaque() /*hy*/,
- output_c_desc.handle() /*cyDesc*/, output_c_data->opaque() /*cy*/,
- workspace.opaque() /*workspace*/,
+ status = cudnnRNNForwardTraining(
+ cudnn.handle() /*handle*/, rnn_desc.handle() /*rnnDesc*/,
+ model_dims.seq_length /*seqLength*/, input_desc.handles() /*xDesc*/,
+ input_data.opaque() /*x*/, input_h_desc.handle() /*hxDesc*/,
+ input_h_data.opaque() /*hx*/, input_c_desc.handle() /*cxDesc*/,
+ input_c_data.opaque() /*cx*/, rnn_desc.params_handle() /*wDesc*/,
+ params.opaque() /*w*/, output_desc.handles() /*yDesc*/,
+ output_data->opaque() /*y*/, output_h_desc.handle() /*hyDesc*/,
+ output_h_data->opaque() /*hy*/, output_c_desc.handle() /*cyDesc*/,
+ output_c_data->opaque() /*cy*/, workspace.opaque() /*workspace*/,
workspace.size() /*workSpaceSizeInBytes*/,
reserve_space.opaque() /*reserveSpace*/,
reserve_space.size() /*reserveSpaceSizeInBytes*/);
@@ -1914,25 +1716,18 @@ bool CudnnSupport::DoRnnBackwardImpl(
return false;
}
- // check params size
- mutex_lock lock{dnn_handle_mutex_};
- auto set_stream_status =
- wrap::cudnnSetStream(this, stream, ToHandle(dnn_handle_));
- if (set_stream_status != CUDNN_STATUS_SUCCESS) {
- LOG(FATAL) << "failed to set stream for cudnn handle: "
- << ToString(set_stream_status);
- }
+ auto cudnn = cudnn_->GetHandle(parent_, stream);
- if (!CheckRNNParameterSize(parent_, ToHandle(dnn_handle_), rnn_desc,
- input_desc)) {
+ // check params size
+ if (!CheckRNNParameterSize(cudnn, rnn_desc, input_desc)) {
LOG(ERROR) << "Invalid parameters";
return false;
}
// create the workspace
DeviceMemory<uint8> workspace;
- if (!CreateRnnWorkspace(stream, parent_, ToHandle(dnn_handle_), rnn_desc,
- input_desc, workspace_allocator, &workspace)) {
+ if (!CreateRnnWorkspace(stream, cudnn, rnn_desc, input_desc,
+ workspace_allocator, &workspace)) {
LOG(ERROR) << "Unable to create rnn workspace";
return false;
}
@@ -1952,12 +1747,11 @@ bool CudnnSupport::DoRnnBackwardImpl(
}
}
// make the backward data call
- cudnnStatus_t status = wrap::cudnnRNNBackwardData(
- this, stream, ToHandle(dnn_handle_) /*handle*/,
- rnn_desc.handle() /*rnnDesc*/, model_dims.seq_length /*seqLength*/,
- output_desc.handles() /*yDesc*/, output_data.opaque() /*y*/,
- output_desc.handles() /*dyDesc*/, output_backprop_data.opaque() /*dy*/,
- output_h_desc.handle() /*dhyDesc*/,
+ cudnnStatus_t status = cudnnRNNBackwardData(
+ cudnn.handle() /*handle*/, rnn_desc.handle() /*rnnDesc*/,
+ model_dims.seq_length /*seqLength*/, output_desc.handles() /*yDesc*/,
+ output_data.opaque() /*y*/, output_desc.handles() /*dyDesc*/,
+ output_backprop_data.opaque() /*dy*/, output_h_desc.handle() /*dhyDesc*/,
output_h_backprop_data.opaque() /*dhy*/,
output_c_desc.handle() /*dcyDesc*/,
output_c_backprop_data.opaque() /*dcy*/,
@@ -1985,13 +1779,12 @@ bool CudnnSupport::DoRnnBackwardImpl(
// Clear the dw to zeros.
stream->ThenMemZero(params_backprop_data, params_backprop_data->size());
// make the backward weight call
- status = wrap::cudnnRNNBackwardWeights(
- this, stream, ToHandle(dnn_handle_) /*handle*/,
- rnn_desc.handle() /*rnnDesc*/, model_dims.seq_length /*seqLength*/,
- input_desc.handles() /*xDesc*/, input_data.opaque() /*x*/,
- input_h_desc.handle() /*hxDesc*/, input_h_data.opaque() /*hx*/,
- output_desc.handles() /*yDesc*/, output_data.opaque() /*y*/,
- workspace.opaque() /*workspace*/,
+ status = cudnnRNNBackwardWeights(
+ cudnn.handle() /*handle*/, rnn_desc.handle() /*rnnDesc*/,
+ model_dims.seq_length /*seqLength*/, input_desc.handles() /*xDesc*/,
+ input_data.opaque() /*x*/, input_h_desc.handle() /*hxDesc*/,
+ input_h_data.opaque() /*hx*/, output_desc.handles() /*yDesc*/,
+ output_data.opaque() /*y*/, workspace.opaque() /*workspace*/,
workspace.size() /*workSpaceSizeInBytes*/,
rnn_desc.params_handle() /*dwDesc*/,
params_backprop_data->opaque() /*dw*/,
@@ -2029,13 +1822,15 @@ CudnnSupport::createRnnDescriptor(
const dnn::AlgorithmConfig& algorithm_config, float dropout, uint64 seed,
ScratchAllocator* state_allocator) {
#if CUDNN_VERSION >= 5000
- mutex_lock lock{dnn_handle_mutex_};
+ // Setting up a cudnnRNNDescriptor requires a cuDNN handle, but because it's
+ // not enqueueing anything into a stream, we pass in the null stream.
+ auto cudnn = cudnn_->GetHandle(parent_, /*stream=*/nullptr);
std::unique_ptr<CudnnRnnDescriptor> rnn_desc(new CudnnRnnDescriptor(
- parent_, ToHandle(dnn_handle_), num_layers, hidden_size, input_size,
- batch_size, ToCudnnRnnInputMode(input_mode),
- ToCudnnRnnDirectionMode(direction_mode), ToCudnnRnnMode(rnn_mode),
- ToCudnnDataType(data_type), GetRnnComputeType(data_type),
- algorithm_config, dropout, seed, state_allocator));
+ cudnn, num_layers, hidden_size, input_size, batch_size,
+ ToCudnnRnnInputMode(input_mode), ToCudnnRnnDirectionMode(direction_mode),
+ ToCudnnRnnMode(rnn_mode), ToCudnnDataType(data_type),
+ GetRnnComputeType(data_type), algorithm_config, dropout, seed,
+ state_allocator));
if (!rnn_desc->ok()) {
return rnn_desc->Status();
}
@@ -2046,7 +1841,7 @@ CudnnSupport::createRnnDescriptor(
port::StrCat("createRnnDescriptor needs at least Cudnn 5.0 to work. ",
"Current Cudnn version: ", CUDNN_VERSION, ". ");
LOG(ERROR) << error_msg;
- return port::Status{port::error::UNIMPLEMENTED, error_msg};
+ return port::Status(port::error::UNIMPLEMENTED, error_msg);
#endif // CUDNN_VERSION
}
@@ -2069,7 +1864,7 @@ CudnnSupport::createRnnSequenceTensorDescriptor(int seq_length, int batch_size,
"createRnnSequenceTensorDescriptor needs at least Cudnn 5.0 to work. ",
"Current Cudnn version: ", CUDNN_VERSION, ". ");
LOG(ERROR) << error_msg;
- return port::Status{port::error::UNIMPLEMENTED, error_msg};
+ return port::Status(port::error::UNIMPLEMENTED, error_msg);
#endif // CUDNN_VERSION
}
@@ -2091,7 +1886,7 @@ CudnnSupport::createRnnStateTensorDescriptor(int num_layer, int batch_size,
"createRnnStateTensorDescriptor needs at least Cudnn 5.0 to work. ",
"Current Cudnn version: ", CUDNN_VERSION, ". ");
LOG(ERROR) << error_msg;
- return port::Status{port::error::UNIMPLEMENTED, error_msg};
+ return port::Status(port::error::UNIMPLEMENTED, error_msg);
#endif // CUDNN_VERSION
}
@@ -2393,35 +2188,26 @@ bool CudnnSupport::DoRnnBackward(
namespace {
inline cudnnConvolutionFwdAlgo_t GetCudnnConvolutionForwardAlgo(
- Stream* stream, CUDAExecutor* parent, void* dnn_handle,
- const ScopedTensorDescriptor& input_nd,
+ const CudnnHandle& cudnn, const ScopedTensorDescriptor& input_nd,
const ScopedFilterDescriptor& filter,
const ScopedConvolutionDescriptor& conv,
const ScopedTensorDescriptor& output_nd, bool specify_workspace_limit,
- ScratchAllocator* scratch_allocator) {
+ size_t memory_limit_bytes) {
cudnnConvolutionFwdPreference_t preference =
specify_workspace_limit ? CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT
: CUDNN_CONVOLUTION_FWD_NO_WORKSPACE;
- auto memory_limit_bytes =
- scratch_allocator == nullptr
- ? 0
- : scratch_allocator->GetMemoryLimitInBytes(stream);
- if (memory_limit_bytes < 0) {
- memory_limit_bytes = 0;
- }
cudnnConvolutionFwdAlgo_t algo_to_use;
- auto status = wrap::cudnnGetConvolutionForwardAlgorithm(
- parent, ToHandle(dnn_handle), input_nd.handle(), filter.handle(),
- conv.handle(), output_nd.handle(), preference, memory_limit_bytes,
- &algo_to_use);
+ auto status = cudnnGetConvolutionForwardAlgorithm(
+ cudnn.handle(), input_nd.handle(), filter.handle(), conv.handle(),
+ output_nd.handle(), preference, memory_limit_bytes, &algo_to_use);
CHECK_EQ(status, CUDNN_STATUS_SUCCESS)
<< "Unable to find a suitable algorithm for doing forward convolution";
return algo_to_use;
}
dnn::AlgorithmDesc GetCudnnConvolutionForwardAlgorithm(
- Stream* stream, CUDAExecutor* parent, void* dnn_handle,
+ Stream* stream, const CudnnHandle& cudnn,
const dnn::AlgorithmConfig& algorithm_config, bool is_profiling,
const ScopedTensorDescriptor& input_nd,
const ScopedFilterDescriptor& filter,
@@ -2432,19 +2218,29 @@ dnn::AlgorithmDesc GetCudnnConvolutionForwardAlgorithm(
bool use_tensor_ops;
if (algorithm_config.algorithm().is_default()) {
use_tensor_ops = true;
+
+ auto memory_limit_bytes =
+ scratch_allocator == nullptr
+ ? 0
+ : scratch_allocator->GetMemoryLimitInBytes(stream);
+ if (memory_limit_bytes < 0) {
+ memory_limit_bytes = 0;
+ }
+
algo = GetCudnnConvolutionForwardAlgo(
- stream, parent, dnn_handle, input_nd, filter, conv, output_nd,
+ cudnn, input_nd, filter, conv, output_nd,
/*specify_workspace_limit=*/scratch_allocator != nullptr,
- scratch_allocator);
+ memory_limit_bytes);
} else {
use_tensor_ops = algorithm_config.algorithm().tensor_ops_enabled();
algo = ToConvForwardAlgo(algorithm_config.algorithm());
}
size_t size_in_bytes;
- auto status = wrap::cudnnGetConvolutionForwardWorkspaceSize(
- parent, ToHandle(dnn_handle), /*srcDesc=*/input_nd.handle(),
- /*filterDesc=*/filter.handle(), /*convDesc=*/conv.handle(),
- /*destDesc=*/output_nd.handle(), /*algo=*/algo,
+ auto status = cudnnGetConvolutionForwardWorkspaceSize(
+ cudnn.handle(),
+ /*xDesc=*/input_nd.handle(),
+ /*wDesc=*/filter.handle(), /*convDesc=*/conv.handle(),
+ /*yDesc=*/output_nd.handle(), /*algo=*/algo,
/*sizeInBytes=*/&size_in_bytes);
int64 size_in_bytes_int64 = size_in_bytes;
if (TF_PREDICT_FALSE(status != CUDNN_STATUS_SUCCESS)) {
@@ -2484,8 +2280,8 @@ dnn::AlgorithmDesc GetCudnnConvolutionForwardAlgorithm(
if (algorithm_config.algorithm_no_scratch().is_default()) {
use_tensor_ops = true;
algo = GetCudnnConvolutionForwardAlgo(
- stream, parent, dnn_handle, input_nd, filter, conv, output_nd,
- /*specify_workspace_limit=*/false, nullptr);
+ cudnn, input_nd, filter, conv, output_nd,
+ /*specify_workspace_limit=*/false, 0);
} else {
use_tensor_ops = algorithm_config.algorithm().tensor_ops_enabled();
algo = ToConvForwardAlgo(algorithm_config.algorithm_no_scratch());
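
With this change, GetCudnnConvolutionForwardAlgorithm computes the workspace limit itself before delegating to GetCudnnConvolutionForwardAlgo: when a scratch allocator is present it takes the allocator's byte limit and clamps negative values to zero, otherwise it falls back to the no-workspace preference with a zero limit. A small stand-alone sketch of just that decision, using a stand-in enum rather than cudnnConvolutionFwdPreference_t:

#include <cstddef>
#include <cstdint>
#include <iostream>

// Stand-in for cudnnConvolutionFwdPreference_t; illustrative only.
enum class FwdPreference { kSpecifyWorkspaceLimit, kNoWorkspace };

// Mirrors the limit/preference selection: a scratch allocator implies a byte
// limit (negative values clamped to zero); no allocator implies no workspace.
FwdPreference PickForwardPreference(bool have_scratch_allocator,
                                    int64_t allocator_limit_bytes,
                                    std::size_t* memory_limit_bytes) {
  int64_t limit = have_scratch_allocator ? allocator_limit_bytes : 0;
  if (limit < 0) limit = 0;
  *memory_limit_bytes = static_cast<std::size_t>(limit);
  return have_scratch_allocator ? FwdPreference::kSpecifyWorkspaceLimit
                                : FwdPreference::kNoWorkspace;
}

int main() {
  std::size_t limit = 0;
  PickForwardPreference(/*have_scratch_allocator=*/true, -1, &limit);
  std::cout << "clamped limit: " << limit << "\n";  // prints 0
}
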
@@ -2614,11 +2410,12 @@ cudnnDataType_t GetRnnComputeType(dnn::DataType data_type) {
LOG(FATAL) << "Invalid RNN data type: " << static_cast<int>(data_type);
}
}
+
} // namespace
template <class T>
bool CudnnSupport::DoConvolveImpl(
- Stream* stream, const BatchDescriptor& batch_descriptor,
+ Stream* stream, const BatchDescriptor& input_descriptor,
const DeviceMemory<T>& input_data,
const FilterDescriptor& filter_descriptor,
const DeviceMemory<T>& filter_data,
@@ -2628,18 +2425,13 @@ bool CudnnSupport::DoConvolveImpl(
const dnn::AlgorithmConfig& algorithm_config,
dnn::ProfileResult* output_profile_result) {
cudnnDataType_t cudnn_type = GetCudnnDataType<T>();
- ScopedTensorDescriptor input_nd{parent_, batch_descriptor, cudnn_type};
- ScopedTensorDescriptor output_nd{parent_, output_descriptor, cudnn_type};
- ScopedFilterDescriptor filter{parent_, filter_descriptor, batch_descriptor,
- cudnn_type};
- ScopedConvolutionDescriptor conv{parent_, convolution_descriptor,
- GetConvComputeType<T>()};
-
- mutex_lock lock{dnn_handle_mutex_};
- auto status = wrap::cudnnSetStream(this, stream, ToHandle(dnn_handle_));
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(FATAL) << "failed to set stream for cudnn handle: " << ToString(status);
- }
+ ScopedTensorDescriptor input_nd(input_descriptor, cudnn_type);
+ ScopedTensorDescriptor output_nd(output_descriptor, cudnn_type);
+ ScopedFilterDescriptor filter(filter_descriptor, cudnn_type);
+ ScopedConvolutionDescriptor conv(convolution_descriptor,
+ GetConvComputeType<T>());
+
+ auto cudnn = cudnn_->GetHandle(parent_, stream);
// Alpha is the scaling factor for input.
float falpha = 1.0;
double dalpha = 1.0;
@@ -2660,42 +2452,41 @@ bool CudnnSupport::DoConvolveImpl(
// GetCudnnConvolutionForwardAlgorithm().
if (algorithm_config.algorithm().is_default()) {
// With the default algorithm, use Cudnn's heuristics.
- auto get_algorithm =
- [&](bool specify_limit) SHARED_LOCKS_REQUIRED(dnn_handle_mutex_) {
- cudnnConvolutionFwdPreference_t preference =
- specify_limit ? CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT
- : CUDNN_CONVOLUTION_FWD_NO_WORKSPACE;
-
- auto memory_limit_bytes =
- scratch_allocator == nullptr
- ? 0
- : scratch_allocator->GetMemoryLimitInBytes(stream);
- if (memory_limit_bytes < 0) {
- memory_limit_bytes = 0;
- }
+ auto get_algorithm = [&](bool specify_limit) {
+ cudnnConvolutionFwdPreference_t preference =
+ specify_limit ? CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT
+ : CUDNN_CONVOLUTION_FWD_NO_WORKSPACE;
- cudnnConvolutionFwdAlgo_t algo_to_use;
- status = wrap::cudnnGetConvolutionForwardAlgorithm(
- parent_, ToHandle(dnn_handle_), input_nd.handle(),
- filter.handle(), conv.handle(), output_nd.handle(),
- /*preference=*/preference,
- /*memoryLimitInBytes=*/memory_limit_bytes,
- /*algo=*/&algo_to_use);
- CHECK_EQ(status, CUDNN_STATUS_SUCCESS)
- << "Unable to find a suitable "
- "algorithm for doing forward "
- "convolution";
- return algo_to_use;
- };
+ auto memory_limit_bytes =
+ scratch_allocator == nullptr
+ ? 0
+ : scratch_allocator->GetMemoryLimitInBytes(stream);
+ if (memory_limit_bytes < 0) {
+ memory_limit_bytes = 0;
+ }
+
+ cudnnConvolutionFwdAlgo_t algo_to_use;
+ auto status = cudnnGetConvolutionForwardAlgorithm(
+ cudnn.handle(), input_nd.handle(), filter.handle(), conv.handle(),
+ output_nd.handle(),
+ /*preference=*/preference,
+ /*memoryLimitInBytes=*/memory_limit_bytes,
+ /*algo=*/&algo_to_use);
+ CHECK_EQ(status, CUDNN_STATUS_SUCCESS) << "Unable to find a suitable "
+ "algorithm for doing forward "
+ "convolution";
+ return algo_to_use;
+ };
algo = get_algorithm(/*specify_limit=*/scratch_allocator != nullptr);
use_tensor_ops = true;
if (scratch_allocator != nullptr) {
size_t size_in_bytes;
- status = wrap::cudnnGetConvolutionForwardWorkspaceSize(
- parent_, ToHandle(dnn_handle_), /*srcDesc=*/input_nd.handle(),
- /*filterDesc=*/filter.handle(), /*convDesc=*/conv.handle(),
- /*destDesc=*/output_nd.handle(), /*algo=*/algo,
+ auto status = cudnnGetConvolutionForwardWorkspaceSize(
+ cudnn.handle(),
+ /*xDesc=*/input_nd.handle(),
+ /*wDesc=*/filter.handle(), /*convDesc=*/conv.handle(),
+ /*yDesc=*/output_nd.handle(), /*algo=*/algo,
/*sizeInBytes=*/&size_in_bytes);
int64 size_in_bytes_int64 = size_in_bytes;
if (status == CUDNN_STATUS_SUCCESS && size_in_bytes_int64 != 0) {
@@ -2727,10 +2518,11 @@ bool CudnnSupport::DoConvolveImpl(
use_tensor_ops = algotype.tensor_ops_enabled();
conv.set_use_tensor_op_math(use_tensor_ops);
size_t size_in_bytes;
- status = wrap::cudnnGetConvolutionForwardWorkspaceSize(
- parent_, ToHandle(dnn_handle_), /*srcDesc=*/input_nd.handle(),
- /*filterDesc=*/filter.handle(), /*convDesc=*/conv.handle(),
- /*destDesc=*/output_nd.handle(), /*algo=*/algo,
+ auto status = cudnnGetConvolutionForwardWorkspaceSize(
+ cudnn.handle(),
+ /*xDesc=*/input_nd.handle(),
+ /*wDesc=*/filter.handle(), /*convDesc=*/conv.handle(),
+ /*yDesc=*/output_nd.handle(), /*algo=*/algo,
/*sizeInBytes=*/&size_in_bytes);
if (status != CUDNN_STATUS_SUCCESS) {
if (is_profiling) {
@@ -2785,8 +2577,8 @@ bool CudnnSupport::DoConvolveImpl(
return false;
}
}
- status = wrap::cudnnConvolutionForward(
- this, stream, ToHandle(dnn_handle_),
+ auto status = cudnnConvolutionForward(
+ cudnn.handle(),
/*alpha=*/alpha, /*srcDesc=*/input_nd.handle(),
/*srcData=*/input_data.opaque(), /*filterDesc=*/filter.handle(),
/*filterData=*/filter_data.opaque(), /*convDesc=*/conv.handle(),
@@ -2840,30 +2632,22 @@ bool CudnnSupport::DoFusedConvolveImpl(
"supported for cuDNN version >= 6";
return false;
#else
- ScopedTensorDescriptor conv_input_nd{
- parent_, conv_input_descriptor,
- static_cast<cudnnDataType_t>(cudnn_data_type)};
- ScopedTensorDescriptor output_nd{
- parent_, output_descriptor,
- static_cast<cudnnDataType_t>(cudnn_data_type)};
- ScopedFilterDescriptor filter{parent_, filter_descriptor,
- conv_input_descriptor,
- static_cast<cudnnDataType_t>(cudnn_data_type)};
- ScopedTensorDescriptor bias_nd{parent_, bias_descriptor, CUDNN_DATA_FLOAT};
- ScopedConvolutionDescriptor conv{
- parent_, convolution_descriptor,
- static_cast<cudnnDataType_t>(cudnn_compute_type)};
-
- mutex_lock lock{dnn_handle_mutex_};
- auto status = wrap::cudnnSetStream(this, stream, ToHandle(dnn_handle_));
- CHECK(status == CUDNN_STATUS_SUCCESS)
- << "failed to set stream for cudnn handle: " << ToString(status);
-
+ ScopedTensorDescriptor conv_input_nd(
+ conv_input_descriptor, static_cast<cudnnDataType_t>(cudnn_data_type));
+ ScopedTensorDescriptor output_nd(
+ output_descriptor, static_cast<cudnnDataType_t>(cudnn_data_type));
+ ScopedFilterDescriptor filter(filter_descriptor,
+ static_cast<cudnnDataType_t>(cudnn_data_type));
+ ScopedTensorDescriptor bias_nd(bias_descriptor, CUDNN_DATA_FLOAT);
+ ScopedConvolutionDescriptor conv(
+ convolution_descriptor, static_cast<cudnnDataType_t>(cudnn_compute_type));
+
+ auto cudnn = cudnn_->GetHandle(parent_, stream);
const bool is_profiling = output_profile_result != nullptr;
DeviceMemory<uint8> scratch;
dnn::AlgorithmDesc algotype = GetCudnnConvolutionForwardAlgorithm(
- stream, parent_, dnn_handle_, algorithm_config, is_profiling,
- conv_input_nd, filter, conv, output_nd, scratch_allocator, &scratch);
+ stream, cudnn, algorithm_config, is_profiling, conv_input_nd, filter,
+ conv, output_nd, scratch_allocator, &scratch);
if (algotype.is_default()) {
if (!is_profiling) {
LOG(ERROR) << "No suitable algorithm found";
@@ -2897,9 +2681,8 @@ bool CudnnSupport::DoFusedConvolveImpl(
// activation descriptor. Note that this will change the nan propagation
// behavior from separate conv, bias, and relu (which by default is
// CUDNN_PROPAGATE_NAN).
- ScopedActivationDescriptor activation_desc{parent_, activation_mode,
- CUDNN_NOT_PROPAGATE_NAN,
- output_descriptor.value_max()};
+ ScopedActivationDescriptor activation_desc(
+ activation_mode, CUDNN_NOT_PROPAGATE_NAN, output_descriptor.value_max());
auto side_input_data_ptr = (side_input_scale == 0) ? output_data->opaque()
: side_input_data.opaque();
@@ -2920,8 +2703,9 @@ bool CudnnSupport::DoFusedConvolveImpl(
<< "\noutput_nd.handle() = " << output_nd.handle()
<< "\noutput_data->opaque() = " << output_data->opaque();
- status = wrap::cudnnConvolutionBiasActivationForward(
- this, stream, ToHandle(dnn_handle_), /*alpha1=*/&conv_input_scale,
+ auto status = cudnnConvolutionBiasActivationForward(
+ cudnn.handle(),
+ /*alpha1=*/&conv_input_scale,
/*srcDesc=*/conv_input_nd.handle(), /*srcData=*/conv_input_data.opaque(),
/*filterDesc=*/filter.handle(), /*filterData=*/filter_data.opaque(),
/*convDesc=*/conv.handle(), algo, /*workSpace=*/scratch.opaque(),
@@ -3125,17 +2909,9 @@ bool CudnnSupport::DoBatchNormalizationForwardImpl(
DeviceMemory<U>* saved_mean, DeviceMemory<U>* saved_inv_var,
bool is_training, std::function<const DeviceMemory<U>&()> var_to_inv_var,
std::function<void()> inv_var_to_var) {
- mutex_lock lock{dnn_handle_mutex_};
- auto status = wrap::cudnnSetStream(this, stream, ToHandle(dnn_handle_));
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(ERROR) << "failed to set stream for cudnn handle: " << ToString(status);
- return false;
- }
-
- ScopedTensorDescriptor x_descriptor{parent_, x_desc,
- ToCudnnDataType(input_data_type)};
- ScopedTensorDescriptor scale_offset_descriptor{
- parent_, scale_offset_desc, ToCudnnDataType(scale_data_type)};
+ ScopedTensorDescriptor x_descriptor(x_desc, ToCudnnDataType(input_data_type));
+ ScopedTensorDescriptor scale_offset_descriptor(
+ scale_offset_desc, ToCudnnDataType(scale_data_type));
cudnnBatchNormMode_t mode = CUDNN_BATCHNORM_SPATIAL;
#if CUDNN_VERSION >= 7000
if (BatchnormSpatialPersistentEnabled() && is_training) {
@@ -3144,7 +2920,9 @@ bool CudnnSupport::DoBatchNormalizationForwardImpl(
#endif
float one = 1.0;
float zero = 0.0;
+ auto cudnn = cudnn_->GetHandle(parent_, stream);
+ auto status = CUDNN_STATUS_SUCCESS;
if (is_training) {
CHECK_EQ(batch_mean->is_null(), batch_var->is_null())
<< "batch_mean and batch_var must both be null or both be non-null";
@@ -3161,11 +2939,11 @@ bool CudnnSupport::DoBatchNormalizationForwardImpl(
batch_var_opaque = nullptr;
}
- status = wrap::cudnnBatchNormalizationForwardTraining(
- this, stream, ToHandle(dnn_handle_), mode, &one, &zero,
- x_descriptor.handle(), x.opaque(), x_descriptor.handle(), y->opaque(),
- scale_offset_descriptor.handle(), scale.opaque(), offset.opaque(), 1.0,
- batch_mean_opaque, batch_var_opaque, epsilon, saved_mean->opaque(),
+ status = cudnnBatchNormalizationForwardTraining(
+ cudnn.handle(), mode, &one, &zero, x_descriptor.handle(), x.opaque(),
+ x_descriptor.handle(), y->opaque(), scale_offset_descriptor.handle(),
+ scale.opaque(), offset.opaque(), 1.0, batch_mean_opaque,
+ batch_var_opaque, epsilon, saved_mean->opaque(),
saved_inv_var->opaque());
#if CUDNN_VERSION < 5000
CHECK(inv_var_to_var);
@@ -3178,11 +2956,11 @@ bool CudnnSupport::DoBatchNormalizationForwardImpl(
#else
const void* maybe_inv_var = estimated_variance.opaque();
#endif
- status = wrap::cudnnBatchNormalizationForwardInference(
- this, stream, ToHandle(dnn_handle_), mode, &one, &zero,
- x_descriptor.handle(), x.opaque(), x_descriptor.handle(), y->opaque(),
- scale_offset_descriptor.handle(), scale.opaque(), offset.opaque(),
- estimated_mean.opaque(), maybe_inv_var, epsilon);
+ status = cudnnBatchNormalizationForwardInference(
+ cudnn.handle(), mode, &one, &zero, x_descriptor.handle(), x.opaque(),
+ x_descriptor.handle(), y->opaque(), scale_offset_descriptor.handle(),
+ scale.opaque(), offset.opaque(), estimated_mean.opaque(), maybe_inv_var,
+ epsilon);
}
if (status != CUDNN_STATUS_SUCCESS) {
LOG(ERROR) << "failed to enqueue forward batch normalization on stream: "
@@ -3229,18 +3007,10 @@ bool CudnnSupport::DoBatchNormalizationBackwardImpl(
const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
DeviceMemory<T>* x_backprop, DeviceMemory<U>* scale_backprop,
DeviceMemory<U>* offset_backprop) {
- mutex_lock lock{dnn_handle_mutex_};
- auto status = wrap::cudnnSetStream(this, stream, ToHandle(dnn_handle_));
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(ERROR) << "failed to set stream for cudnn handle: " << ToString(status);
- return false;
- }
-
- ScopedTensorDescriptor x_descriptor{
- parent_, x_desc, static_cast<cudnnDataType_t>(cudnn_input_type)};
- ScopedTensorDescriptor scale_offset_descriptor{
- parent_, scale_offset_desc,
- static_cast<cudnnDataType_t>(cudnn_scale_type)};
+ ScopedTensorDescriptor x_descriptor(
+ x_desc, static_cast<cudnnDataType_t>(cudnn_input_type));
+ ScopedTensorDescriptor scale_offset_descriptor(
+ scale_offset_desc, static_cast<cudnnDataType_t>(cudnn_scale_type));
cudnnBatchNormMode_t mode = CUDNN_BATCHNORM_SPATIAL;
#if CUDNN_VERSION >= 7000
if (BatchnormSpatialPersistentEnabled()) {
@@ -3250,10 +3020,12 @@ bool CudnnSupport::DoBatchNormalizationBackwardImpl(
float one = 1.0;
float zero = 0.0;
- status = wrap::cudnnBatchNormalizationBackward(
- this, stream, ToHandle(dnn_handle_), mode, &one, &zero, &one, &zero,
- x_descriptor.handle(), x.opaque(), x_descriptor.handle(),
- y_backprop.opaque(), x_descriptor.handle(), x_backprop->opaque(),
+ auto cudnn = cudnn_->GetHandle(parent_, stream);
+
+ auto status = cudnnBatchNormalizationBackward(
+ cudnn.handle(), mode, &one, &zero, &one, &zero, x_descriptor.handle(),
+ x.opaque(), x_descriptor.handle(), y_backprop.opaque(),
+ x_descriptor.handle(), x_backprop->opaque(),
scale_offset_descriptor.handle(), scale.opaque(),
scale_backprop->opaque(), offset_backprop->opaque(), epsilon,
mean.opaque(), inv_var.opaque());
@@ -3416,11 +3188,21 @@ bool CudnnSupport::DoFusedConvolve(
#endif
}
-template<class T>
-DeviceMemory<T> CudnnSupport::MaybeTransformLayout(
- Stream* stream,
- BatchDescriptor* output_descriptor,
- DeviceMemory<T> backward_output_data,
+namespace {
+// NOTE(keveman): Temporary data layout transformation until cuDNN supports
+// kBatchYXDepth for backward pass. This function allocates temporary memory,
+// lays out the source data into the temporary but in the kBatchDepthXY
+// layout, and returns the temporary memory. The caller is responsible for
+// deallocating the temporary. Since the allocation is done using Stream's
+// AllocateTemporaryMemory, a later BlockHostUntilDone could be used for
+// deallocation.
+//
+// transform_scratch is populated with a legitimate temporary allocation iff
+// the original output data needs to be transformed.
+template <class T>
+DeviceMemory<T> MaybeTransformLayout(
+ Stream* stream, const CudnnHandle& cudnn,
+ BatchDescriptor* output_descriptor, DeviceMemory<T> backward_output_data,
std::unique_ptr<TemporaryDeviceMemory<T>>* transform_scratch) {
if (output_descriptor->layout() == dnn::DataLayout::kBatchDepthYX) {
return backward_output_data;
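
The NOTE in the hunk above explains why MaybeTransformLayout exists: the backward pass only accepts kBatchDepthXY (NCHW) at this point, so kBatchYXDepth (NHWC) output gradients are first copied into a temporary in NCHW order via cudnnTransformTensor. As a host-side illustration of the index permutation involved (the real code operates on device buffers and never loops on the host):

#include <cstddef>
#include <vector>

// Copy an NHWC (kBatchYXDepth) buffer into NCHW (kBatchDepthXY) order.
// Purely illustrative of the layout change; cudnnTransformTensor performs the
// equivalent permutation on device memory.
std::vector<float> NhwcToNchw(const std::vector<float>& src, std::size_t n,
                              std::size_t h, std::size_t w, std::size_t c) {
  std::vector<float> dst(n * c * h * w);
  for (std::size_t in = 0; in < n; ++in)
    for (std::size_t ih = 0; ih < h; ++ih)
      for (std::size_t iw = 0; iw < w; ++iw)
        for (std::size_t ic = 0; ic < c; ++ic)
          dst[((in * c + ic) * h + ih) * w + iw] =
              src[((in * h + ih) * w + iw) * c + ic];
  return dst;
}

int main() {
  // 1x2x2x3 NHWC tensor with values 0..11; after the transform the four
  // channel-0 values (0, 3, 6, 9) come first, then channel 1, then channel 2.
  std::vector<float> nhwc(12);
  for (std::size_t i = 0; i < nhwc.size(); ++i) nhwc[i] = static_cast<float>(i);
  auto nchw = NhwcToNchw(nhwc, 1, 2, 2, 3);
  return (nchw[0] == 0.0f && nchw[1] == 3.0f) ? 0 : 1;
}
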
@@ -3433,15 +3215,14 @@ DeviceMemory<T> CudnnSupport::MaybeTransformLayout(
transformed_output_descriptor.CloneFrom(*output_descriptor);
transformed_output_descriptor.set_layout(dnn::DataLayout::kBatchDepthYX);
cudnnDataType_t cudnn_type = GetCudnnDataType<T>();
- ScopedTensorDescriptor orig_out_back_nd{parent_, *output_descriptor,
- cudnn_type};
- ScopedTensorDescriptor transformed_out_back_nd{
- parent_, transformed_output_descriptor, cudnn_type};
+ ScopedTensorDescriptor orig_out_back_nd(*output_descriptor, cudnn_type);
+ ScopedTensorDescriptor transformed_out_back_nd(transformed_output_descriptor,
+ cudnn_type);
float alpha = 1.0f;
float beta = 0.0f;
- auto status = wrap::cudnnTransformTensor(
- this, stream, ToHandle(dnn_handle_), &alpha, orig_out_back_nd.handle(),
+ auto status = cudnnTransformTensor(
+ cudnn.handle(), &alpha, orig_out_back_nd.handle(),
backward_output_data.opaque(), &beta, transformed_out_back_nd.handle(),
(*transform_scratch)->mutable_device_memory()->opaque());
@@ -3451,6 +3232,7 @@ DeviceMemory<T> CudnnSupport::MaybeTransformLayout(
output_descriptor->set_layout(dnn::DataLayout::kBatchDepthYX);
return (*transform_scratch)->device_memory();
}
+} // namespace
bool CudnnSupport::DoTransformTensor(Stream* stream,
const dnn::BatchDescriptor& input_desc,
@@ -3459,21 +3241,15 @@ bool CudnnSupport::DoTransformTensor(Stream* stream,
const dnn::BatchDescriptor& output_desc,
dnn::DataType output_type, float scale,
DeviceMemoryBase* output_data) {
- mutex_lock lock{dnn_handle_mutex_};
- auto status = wrap::cudnnSetStream(this, stream, ToHandle(dnn_handle_));
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(FATAL) << "failed to set stream for cudnn handle: " << ToString(status);
- }
-
float beta = 0.0f;
ScopedTensorDescriptor input_tensor_desc(
- parent_, input_desc, ToCudnnDataType(input_type, input_desc.layout()));
+ input_desc, ToCudnnDataType(input_type, input_desc.layout()));
ScopedTensorDescriptor output_tensor_desc(
- parent_, output_desc, ToCudnnDataType(output_type, output_desc.layout()));
- status = wrap::cudnnTransformTensor(
- this, stream, ToHandle(dnn_handle_), &scale, input_tensor_desc.handle(),
- input_data.opaque(), &beta, output_tensor_desc.handle(),
- output_data->opaque());
+ output_desc, ToCudnnDataType(output_type, output_desc.layout()));
+ auto cudnn = cudnn_->GetHandle(parent_, stream);
+ auto status = cudnnTransformTensor(
+ cudnn.handle(), &scale, input_tensor_desc.handle(), input_data.opaque(),
+ &beta, output_tensor_desc.handle(), output_data->opaque());
if (status != CUDNN_STATUS_SUCCESS) {
LOG(ERROR) << "Could not transform a tensor with layout "
<< input_desc.ToString() << " and data type "
@@ -3487,8 +3263,7 @@ bool CudnnSupport::DoTransformTensor(Stream* stream,
template <class T>
bool CudnnSupport::DoConvolveBackwardDataImpl(
- Stream* stream,
- const FilterDescriptor& filter_descriptor,
+ Stream* stream, const FilterDescriptor& filter_descriptor,
const DeviceMemory<T>& filter_data,
const BatchDescriptor& output_descriptor_in,
DeviceMemory<T> backward_output_data,
@@ -3497,12 +3272,6 @@ bool CudnnSupport::DoConvolveBackwardDataImpl(
DeviceMemory<T>* backward_input_data, ScratchAllocator* scratch_allocator,
const dnn::AlgorithmConfig& algorithm_config,
dnn::ProfileResult* output_profile_result) {
- mutex_lock lock{dnn_handle_mutex_};
- auto status = wrap::cudnnSetStream(this, stream, ToHandle(dnn_handle_));
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(FATAL) << "failed to set stream for cudnn handle: " << ToString(status);
- }
-
cudnnDataType_t cudnn_type = GetCudnnDataType<T>();
// Alpha is the scaling factor for input.
float falpha = 1.0;
@@ -3515,19 +3284,21 @@ bool CudnnSupport::DoConvolveBackwardDataImpl(
void* beta = cudnn_type == CUDNN_DATA_DOUBLE ? static_cast<void*>(&dbeta)
: static_cast<void*>(&fbeta);
+ auto cudnn = cudnn_->GetHandle(parent_, stream);
+
// TBD(keveman): remove once cuDNN supports kBatchYXDepth for backward pass.
BatchDescriptor output_descriptor;
output_descriptor.CloneFrom(output_descriptor_in);
std::unique_ptr<TemporaryDeviceMemory<T>> transform_scratch;
- backward_output_data = MaybeTransformLayout(
- stream, &output_descriptor, backward_output_data, &transform_scratch);
+ backward_output_data =
+ MaybeTransformLayout(stream, cudnn, &output_descriptor,
+ backward_output_data, &transform_scratch);
- ScopedTensorDescriptor out_back_nd{parent_, output_descriptor, cudnn_type};
- ScopedTensorDescriptor in_back_nd{parent_, input_descriptor, cudnn_type};
- ScopedFilterDescriptor filter{parent_, filter_descriptor, input_descriptor,
- cudnn_type};
- ScopedConvolutionDescriptor conv{parent_, convolution_descriptor,
- GetConvComputeType<T>()};
+ ScopedTensorDescriptor out_back_nd(output_descriptor, cudnn_type);
+ ScopedTensorDescriptor in_back_nd(input_descriptor, cudnn_type);
+ ScopedFilterDescriptor filter(filter_descriptor, cudnn_type);
+ ScopedConvolutionDescriptor conv(convolution_descriptor,
+ GetConvComputeType<T>());
const bool is_profiling = output_profile_result != nullptr;
cudnnConvolutionBwdDataAlgo_t algo;
@@ -3535,8 +3306,8 @@ bool CudnnSupport::DoConvolveBackwardDataImpl(
if (algorithm_config.algorithm().is_default()) {
// With the default algorithm, use Cudnn's heuristics.
- auto get_algorithm = [&](bool specify_limit) SHARED_LOCKS_REQUIRED(
- dnn_handle_mutex_) -> cudnnConvolutionBwdDataAlgo_t {
+ auto get_algorithm =
+ [&](bool specify_limit) -> cudnnConvolutionBwdDataAlgo_t {
cudnnConvolutionBwdDataPreference_t preference =
specify_limit ? CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT
: CUDNN_CONVOLUTION_BWD_DATA_NO_WORKSPACE;
@@ -3549,8 +3320,8 @@ bool CudnnSupport::DoConvolveBackwardDataImpl(
memory_limit_bytes = 0;
}
cudnnConvolutionBwdDataAlgo_t algo_to_use;
- cudnnStatus_t status = wrap::cudnnGetConvolutionBackwardDataAlgorithm(
- parent_, ToHandle(dnn_handle_),
+ cudnnStatus_t status = cudnnGetConvolutionBackwardDataAlgorithm(
+ cudnn.handle(),
/*filterDesc=*/filter.handle(),
/*diffDesc=*/out_back_nd.handle(),
/*convDesc=*/conv.handle(),
@@ -3568,8 +3339,8 @@ bool CudnnSupport::DoConvolveBackwardDataImpl(
if (scratch_allocator != nullptr) {
size_t size_in_bytes;
- status = wrap::cudnnGetConvolutionBackwardDataWorkspaceSize(
- parent_, ToHandle(dnn_handle_),
+ auto status = cudnnGetConvolutionBackwardDataWorkspaceSize(
+ cudnn.handle(),
/*filterDesc=*/filter.handle(),
/*diffDesc=*/out_back_nd.handle(),
/*convDesc=*/conv.handle(),
@@ -3605,8 +3376,8 @@ bool CudnnSupport::DoConvolveBackwardDataImpl(
algo = ToConvBackwardDataAlgo(algotype);
conv.set_use_tensor_op_math(algotype.tensor_ops_enabled());
size_t size_in_bytes;
- status = wrap::cudnnGetConvolutionBackwardDataWorkspaceSize(
- parent_, ToHandle(dnn_handle_),
+ auto status = cudnnGetConvolutionBackwardDataWorkspaceSize(
+ cudnn.handle(),
/*filterDesc=*/filter.handle(),
/*diffDesc=*/out_back_nd.handle(),
/*convDesc=*/conv.handle(),
@@ -3663,23 +3434,24 @@ bool CudnnSupport::DoConvolveBackwardDataImpl(
}
#if CUDNN_VERSION >= 5000
- status = wrap::cudnnConvolutionBackwardData(
+ auto status =
+ cudnnConvolutionBackwardData(cudnn.handle(),
#else
- status = wrap::cudnnConvolutionBackwardData_v3(
+ auto status =
+ cudnnConvolutionBackwardData_v3(cudnn.handle(),
#endif
- this, stream, ToHandle(dnn_handle_),
- /*alpha=*/alpha,
- /*filterDesc=*/filter.handle(),
- /*filterData=*/filter_data.opaque(),
- /*diffDesc=*/out_back_nd.handle(),
- /*diffData=*/backward_output_data.opaque(),
- /*convDesc=*/conv.handle(),
- /*algo=*/algo,
- /*workSpace=*/scratch.opaque(),
- /*workSpaceSizeInBytes=*/scratch.size(),
- /*beta=*/beta,
- /*gradDesc=*/in_back_nd.handle(),
- /*gradData=*/backward_input_data->opaque());
+ /*alpha=*/alpha,
+ /*wDesc=*/filter.handle(),
+ /*w=*/filter_data.opaque(),
+ /*dyDesc=*/out_back_nd.handle(),
+ /*dy=*/backward_output_data.opaque(),
+ /*convDesc=*/conv.handle(),
+ /*algo=*/algo,
+ /*workSpace=*/scratch.opaque(),
+ /*workSpaceSizeInBytes=*/scratch.size(),
+ /*beta=*/beta,
+ /*dxDesc=*/in_back_nd.handle(),
+ /*dx=*/backward_input_data->opaque());
if (is_profiling) {
timer->Stop(AsCUDAStream(stream));
if (status == CUDNN_STATUS_SUCCESS) {
@@ -3767,12 +3539,6 @@ bool CudnnSupport::DoConvolveBackwardFilterImpl(
DeviceMemory<T>* backward_filter_data, ScratchAllocator* scratch_allocator,
const dnn::AlgorithmConfig& algorithm_config,
dnn::ProfileResult* output_profile_result) {
- mutex_lock lock{dnn_handle_mutex_};
- auto status = wrap::cudnnSetStream(this, stream, ToHandle(dnn_handle_));
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(FATAL) << "failed to set stream for cudnn handle: " << ToString(status);
- }
-
cudnnDataType_t cudnn_type = GetCudnnDataType<T>();
// Alpha is the scaling factor for input.
float falpha = 1.0;
@@ -3785,19 +3551,21 @@ bool CudnnSupport::DoConvolveBackwardFilterImpl(
void* beta = cudnn_type == CUDNN_DATA_DOUBLE ? static_cast<void*>(&dbeta)
: static_cast<void*>(&fbeta);
+ auto cudnn = cudnn_->GetHandle(parent_, stream);
+
// TBD(keveman): remove once cuDNN supports kBatchYXDepth for backward pass.
BatchDescriptor output_descriptor;
output_descriptor.CloneFrom(output_descriptor_in);
std::unique_ptr<TemporaryDeviceMemory<T>> transform_scratch;
- backward_output_data = MaybeTransformLayout(
- stream, &output_descriptor, backward_output_data, &transform_scratch);
+ backward_output_data =
+ MaybeTransformLayout(stream, cudnn, &output_descriptor,
+ backward_output_data, &transform_scratch);
- ScopedTensorDescriptor out_back_nd{parent_, output_descriptor, cudnn_type};
- ScopedTensorDescriptor input_nd{parent_, input_descriptor, cudnn_type};
- ScopedFilterDescriptor filter{parent_, filter_descriptor, input_descriptor,
- cudnn_type};
- ScopedConvolutionDescriptor conv{parent_, convolution_descriptor,
- GetConvComputeType<T>()};
+ ScopedTensorDescriptor out_back_nd(output_descriptor, cudnn_type);
+ ScopedTensorDescriptor input_nd(input_descriptor, cudnn_type);
+ ScopedFilterDescriptor filter(filter_descriptor, cudnn_type);
+ ScopedConvolutionDescriptor conv(convolution_descriptor,
+ GetConvComputeType<T>());
const bool is_profiling = output_profile_result != nullptr;
cudnnConvolutionBwdFilterAlgo_t algo;
@@ -3809,8 +3577,7 @@ bool CudnnSupport::DoConvolveBackwardFilterImpl(
// Lambda that retrieves the algorithm.
// specify_limit will occur when we have a scratch allocator and it succeeds
// in allocating; otherwise, we'll fall back to the "no workspace" version.
- auto get_algorithm = [&](bool specify_limit) SHARED_LOCKS_REQUIRED(
- dnn_handle_mutex_) {
+ auto get_algorithm = [&](bool specify_limit) {
cudnnConvolutionBwdFilterPreference_t preference =
specify_limit ? CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT
: CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE;
@@ -3824,8 +3591,8 @@ bool CudnnSupport::DoConvolveBackwardFilterImpl(
}
cudnnConvolutionBwdFilterAlgo_t algo_to_use;
- cudnnStatus_t status = wrap::cudnnGetConvolutionBackwardFilterAlgorithm(
- parent_, ToHandle(dnn_handle_),
+ cudnnStatus_t status = cudnnGetConvolutionBackwardFilterAlgorithm(
+ cudnn.handle(),
/*srcDesc=*/input_nd.handle(),
/*diffDesc=*/out_back_nd.handle(),
/*convDesc=*/conv.handle(),
@@ -3843,9 +3610,10 @@ bool CudnnSupport::DoConvolveBackwardFilterImpl(
if (scratch_allocator != nullptr) {
size_t size_in_bytes;
- status = wrap::cudnnGetConvolutionBackwardFilterWorkspaceSize(
- parent_, ToHandle(dnn_handle_), /*srcDesc=*/input_nd.handle(),
- /*diffDesc=*/out_back_nd.handle(), /*convDesc=*/conv.handle(),
+ auto status = cudnnGetConvolutionBackwardFilterWorkspaceSize(
+ cudnn.handle(),
+ /*xDesc=*/input_nd.handle(),
+ /*dyDesc=*/out_back_nd.handle(), /*convDesc=*/conv.handle(),
/*gradDesc=*/filter.handle(), /*algo=*/algo,
/*sizeInBytes=*/&size_in_bytes);
int64 size_in_bytes_int64 = size_in_bytes;
@@ -3878,9 +3646,10 @@ bool CudnnSupport::DoConvolveBackwardFilterImpl(
conv.set_use_tensor_op_math(algotype.tensor_ops_enabled());
size_t size_in_bytes;
- status = wrap::cudnnGetConvolutionBackwardFilterWorkspaceSize(
- parent_, ToHandle(dnn_handle_), /*srcDesc=*/input_nd.handle(),
- /*diffDesc=*/out_back_nd.handle(), /*convDesc=*/conv.handle(),
+ auto status = cudnnGetConvolutionBackwardFilterWorkspaceSize(
+ cudnn.handle(),
+ /*xDesc=*/input_nd.handle(),
+ /*dyDesc=*/out_back_nd.handle(), /*convDesc=*/conv.handle(),
/*gradDesc=*/filter.handle(), /*algo=*/algo,
/*sizeInBytes=*/&size_in_bytes);
if (status != CUDNN_STATUS_SUCCESS) {
@@ -3934,11 +3703,13 @@ bool CudnnSupport::DoConvolveBackwardFilterImpl(
}
#if CUDNN_VERSION >= 5000
- status = wrap::cudnnConvolutionBackwardFilter(
+ auto status = cudnnConvolutionBackwardFilter(
+ cudnn.handle(),
#else
- status = wrap::cudnnConvolutionBackwardFilter_v3(
+ auto status = cudnnConvolutionBackwardFilter_v3(
+ cudnn.handle(),
#endif
- this, stream, ToHandle(dnn_handle_), /*alpha=*/alpha,
+ /*alpha=*/alpha,
/*srcDesc=*/input_nd.handle(),
/*srcData=*/input_data.opaque(),
/*diffDesc=*/out_back_nd.handle(),
@@ -4033,25 +3804,19 @@ bool CudnnSupport::DoConvolveBackwardBiasImpl(
const DeviceMemory<T>& input_data,
const dnn::BatchDescriptor& bias_descriptor,
DeviceMemory<T>* backward_bias_data) {
- mutex_lock lock{dnn_handle_mutex_};
- auto status = wrap::cudnnSetStream(this, stream, ToHandle(dnn_handle_));
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(FATAL) << "failed to set stream for cudnn handle: " << ToString(status);
- }
-
cudnnDataType_t cudnn_type = GetCudnnDataType<T>();
- ScopedTensorDescriptor input_nd{parent_, input_descriptor, cudnn_type};
- ScopedTensorDescriptor bias_nd{parent_, bias_descriptor, cudnn_type};
+ ScopedTensorDescriptor input_nd(input_descriptor, cudnn_type);
+ ScopedTensorDescriptor bias_nd(bias_descriptor, cudnn_type);
// Alpha is the scaling factor for input.
float alpha = 1.0;
// Beta is the scaling factor for output.
float beta = 0.0;
- status = wrap::cudnnConvolutionBackwardBias(
- this, stream, ToHandle(dnn_handle_), &alpha, input_nd.handle(),
- input_data.opaque(), &beta, bias_nd.handle(),
- backward_bias_data->opaque());
+ auto cudnn = cudnn_->GetHandle(parent_, stream);
+ auto status = cudnnConvolutionBackwardBias(
+ cudnn.handle(), &alpha, input_nd.handle(), input_data.opaque(), &beta,
+ bias_nd.handle(), backward_bias_data->opaque());
if (status != CUDNN_STATUS_SUCCESS) {
LOG(ERROR) << "failed to enqueue backward convolution on stream: "
<< ToString(status);
@@ -4227,8 +3992,7 @@ bool CudnnSupport::DoBiasAdd(Stream* stream,
const DeviceMemory<float>& biases,
const dnn::BatchDescriptor& dimensions,
DeviceMemory<float>* output_data) {
- ScopedTensorDescriptor input_descriptor{parent_, dimensions,
- CUDNN_DATA_FLOAT};
+ ScopedTensorDescriptor input_descriptor(dimensions, CUDNN_DATA_FLOAT);
BatchDescriptor bias_dimensions;
bias_dimensions.set_count(1)
@@ -4236,8 +4000,7 @@ bool CudnnSupport::DoBiasAdd(Stream* stream,
.set_height(1)
.set_width(1)
.set_layout(dnn::DataLayout::kBatchYXDepth);
- ScopedTensorDescriptor bias_descriptor{parent_, bias_dimensions,
- CUDNN_DATA_FLOAT};
+ ScopedTensorDescriptor bias_descriptor(bias_dimensions, CUDNN_DATA_FLOAT);
// cudnnAddTensor after R3 is in-place, so we need to copy input_data to
// output_data before doing the addition, unless the input and
@@ -4253,23 +4016,18 @@ bool CudnnSupport::DoBiasAdd(Stream* stream,
}
}
- mutex_lock lock{dnn_handle_mutex_};
- auto status = wrap::cudnnSetStream(this, stream, ToHandle(dnn_handle_));
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(ERROR) << "failed to set stream for cudnn handle: " << ToString(status);
- return false;
- }
-
const float alpha = 1.0f;
const float beta = 1.0f;
+ auto cudnn = cudnn_->GetHandle(parent_, stream);
+
#if CUDNN_VERSION >= 5000
- status = wrap::cudnnAddTensor(
+ auto status = cudnnAddTensor(
#else
- status = wrap::cudnnAddTensor_v3(
+ auto status = cudnnAddTensor_v3(
#endif
- this, stream, ToHandle(dnn_handle_), &alpha, bias_descriptor.handle(),
- biases.opaque(), &beta, input_descriptor.handle(), output_data->opaque());
+ cudnn.handle(), &alpha, bias_descriptor.handle(), biases.opaque(), &beta,
+ input_descriptor.handle(), output_data->opaque());
if (status != CUDNN_STATUS_SUCCESS) {
LOG(ERROR) << "stream " << stream << " could not enqueue bias addition.";
@@ -4285,16 +4043,9 @@ bool CudnnSupport::DoActivate(Stream* stream,
const DeviceMemory<float>& input_data,
DeviceMemory<float>* output_data,
uint64 options) {
- mutex_lock lock{dnn_handle_mutex_};
- auto status = wrap::cudnnSetStream(this, stream, ToHandle(dnn_handle_));
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(ERROR) << "failed to set stream for cudnn handle: " << ToString(status);
- return false;
- }
-
#if CUDNN_VERSION >= 5000
- ScopedActivationDescriptor activation_desc{
- parent_, activation_mode, CUDNN_PROPAGATE_NAN, dimensions.value_max()};
+ ScopedActivationDescriptor activation_desc(
+ activation_mode, CUDNN_PROPAGATE_NAN, dimensions.value_max());
#else
cudnnActivationMode_t mode;
switch (activation_mode) {
@@ -4324,20 +4075,22 @@ bool CudnnSupport::DoActivate(Stream* stream,
}
#endif
- ScopedTensorDescriptor input_nd{parent_, dimensions, CUDNN_DATA_FLOAT};
+ ScopedTensorDescriptor input_nd(dimensions, CUDNN_DATA_FLOAT);
// Alpha is the input scaling factor.
float alpha = 1.0;
// Beta is the output scaling factor.
float beta = 0.0;
- status = wrap::cudnnActivationForward(
- this, stream, ToHandle(dnn_handle_),
+
+ auto cudnn = cudnn_->GetHandle(parent_, stream);
+ auto status =
+ cudnnActivationForward(cudnn.handle(),
#if CUDNN_VERSION >= 5000
- activation_desc.handle(),
+ activation_desc.handle(),
#else
- mode,
+ mode,
#endif
- &alpha, input_nd.handle(), input_data.opaque(), &beta, input_nd.handle(),
- output_data->opaque());
+ &alpha, input_nd.handle(), input_data.opaque(),
+ &beta, input_nd.handle(), output_data->opaque());
if (status != CUDNN_STATUS_SUCCESS) {
LOG(ERROR) << "stream " << stream
<< " could not enqueue activation: " << ToString(status);
@@ -4353,26 +4106,19 @@ bool CudnnSupport::DoPoolForward(
const DeviceMemory<double>& input_data,
const dnn::BatchDescriptor& output_dimensions,
DeviceMemory<double>* output_data) {
- mutex_lock lock{dnn_handle_mutex_};
- auto status = wrap::cudnnSetStream(this, stream, ToHandle(dnn_handle_));
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(ERROR) << "failed to set stream for cudnn handle: " << ToString(status);
- return false;
- }
-
// Alpha is the scaling factor for input.
double alpha = 1.0;
// Beta is the scaling factor for output.
double beta = 0.0;
- ScopedTensorDescriptor src_desc{parent_, input_dimensions, CUDNN_DATA_DOUBLE};
- ScopedTensorDescriptor dest_desc{parent_, output_dimensions,
- CUDNN_DATA_DOUBLE};
- ScopedPoolingDescriptor pooling_desc{parent_, pooling_dimensions};
- status = wrap::cudnnPoolingForward(
- this, stream, ToHandle(dnn_handle_), pooling_desc.handle(), &alpha,
- src_desc.handle(), input_data.opaque(), &beta, dest_desc.handle(),
- output_data->opaque());
+ ScopedTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_DOUBLE);
+ ScopedTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_DOUBLE);
+ ScopedPoolingDescriptor pooling_desc(pooling_dimensions);
+
+ auto cudnn = cudnn_->GetHandle(parent_, stream);
+ auto status = cudnnPoolingForward(
+ cudnn.handle(), pooling_desc.handle(), &alpha, src_desc.handle(),
+ input_data.opaque(), &beta, dest_desc.handle(), output_data->opaque());
if (status != CUDNN_STATUS_SUCCESS) {
LOG(ERROR) << "failed to enqueue forward pooling on stream: "
<< ToString(status);
@@ -4387,26 +4133,19 @@ bool CudnnSupport::DoPoolForward(
const DeviceMemory<float>& input_data,
const dnn::BatchDescriptor& output_dimensions,
DeviceMemory<float>* output_data) {
- mutex_lock lock{dnn_handle_mutex_};
- auto status = wrap::cudnnSetStream(this, stream, ToHandle(dnn_handle_));
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(ERROR) << "failed to set stream for cudnn handle: " << ToString(status);
- return false;
- }
-
// Alpha is the scaling factor for input.
float alpha = 1.0;
// Beta is the scaling factor for output.
float beta = 0.0;
- ScopedTensorDescriptor src_desc{parent_, input_dimensions, CUDNN_DATA_FLOAT};
- ScopedTensorDescriptor dest_desc{parent_, output_dimensions,
- CUDNN_DATA_FLOAT};
- ScopedPoolingDescriptor pooling_desc{parent_, pooling_dimensions};
- status = wrap::cudnnPoolingForward(
- this, stream, ToHandle(dnn_handle_), pooling_desc.handle(), &alpha,
- src_desc.handle(), input_data.opaque(), &beta, dest_desc.handle(),
- output_data->opaque());
+ ScopedTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_FLOAT);
+ ScopedTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_FLOAT);
+ ScopedPoolingDescriptor pooling_desc(pooling_dimensions);
+
+ auto cudnn = cudnn_->GetHandle(parent_, stream);
+ auto status = cudnnPoolingForward(
+ cudnn.handle(), pooling_desc.handle(), &alpha, src_desc.handle(),
+ input_data.opaque(), &beta, dest_desc.handle(), output_data->opaque());
if (status != CUDNN_STATUS_SUCCESS) {
LOG(ERROR) << "failed to enqueue forward pooling on stream: "
<< ToString(status);
@@ -4421,25 +4160,18 @@ bool CudnnSupport::DoPoolForward(
const DeviceMemory<Eigen::half>& input_data,
const dnn::BatchDescriptor& output_dimensions,
DeviceMemory<Eigen::half>* output_data) {
- mutex_lock lock{dnn_handle_mutex_};
- auto status = wrap::cudnnSetStream(this, stream, ToHandle(dnn_handle_));
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(ERROR) << "failed to set stream for cudnn handle: " << ToString(status);
- return false;
- }
-
// Alpha is the scaling factor for input.
float alpha = 1.0;
// Beta is the scaling factor for output.
float beta = 0.0;
- ScopedTensorDescriptor src_desc{parent_, input_dimensions, CUDNN_DATA_HALF};
- ScopedTensorDescriptor dest_desc{parent_, output_dimensions, CUDNN_DATA_HALF};
- ScopedPoolingDescriptor pooling_desc{parent_, pooling_dimensions};
- status = wrap::cudnnPoolingForward(
- this, stream, ToHandle(dnn_handle_), pooling_desc.handle(), &alpha,
- src_desc.handle(), input_data.opaque(), &beta, dest_desc.handle(),
- output_data->opaque());
+ ScopedTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_HALF);
+ ScopedTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_HALF);
+ ScopedPoolingDescriptor pooling_desc(pooling_dimensions);
+ auto cudnn = cudnn_->GetHandle(parent_, stream);
+ auto status = cudnnPoolingForward(
+ cudnn.handle(), pooling_desc.handle(), &alpha, src_desc.handle(),
+ input_data.opaque(), &beta, dest_desc.handle(), output_data->opaque());
if (status != CUDNN_STATUS_SUCCESS) {
LOG(ERROR) << "failed to enqueue forward pooling on stream: "
<< ToString(status);
@@ -4456,27 +4188,21 @@ bool CudnnSupport::DoPoolBackward(
const DeviceMemory<double>& output_data,
const DeviceMemory<double>& input_diff_data,
DeviceMemory<double>* output_diff_data) {
- mutex_lock lock{dnn_handle_mutex_};
- auto status = wrap::cudnnSetStream(this, stream, ToHandle(dnn_handle_));
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(ERROR) << "failed to set stream for cudnn handle: " << ToString(status);
- return false;
- }
-
// Alpha is the scaling factor for input.
double alpha = 1.0;
// Beta is the scaling factor for output.
double beta = 0.0;
- ScopedTensorDescriptor src_desc{parent_, input_dimensions, CUDNN_DATA_DOUBLE};
- ScopedTensorDescriptor dest_desc{parent_, output_dimensions,
- CUDNN_DATA_DOUBLE};
- ScopedPoolingDescriptor pooling_desc{parent_, pooling_dimensions};
- status = wrap::cudnnPoolingBackward(
- this, stream, ToHandle(dnn_handle_), pooling_desc.handle(), &alpha,
- dest_desc.handle(), output_data.opaque(), dest_desc.handle(),
- input_diff_data.opaque(), src_desc.handle(), input_data.opaque(), &beta,
- src_desc.handle(), output_diff_data->opaque());
+ ScopedTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_DOUBLE);
+ ScopedTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_DOUBLE);
+ ScopedPoolingDescriptor pooling_desc(pooling_dimensions);
+
+ auto cudnn = cudnn_->GetHandle(parent_, stream);
+ auto status = cudnnPoolingBackward(
+ cudnn.handle(), pooling_desc.handle(), &alpha, dest_desc.handle(),
+ output_data.opaque(), dest_desc.handle(), input_diff_data.opaque(),
+ src_desc.handle(), input_data.opaque(), &beta, src_desc.handle(),
+ output_diff_data->opaque());
if (status != CUDNN_STATUS_SUCCESS) {
LOG(ERROR) << "failed to enqueue backward pooling on stream: "
<< ToString(status);
@@ -4493,27 +4219,21 @@ bool CudnnSupport::DoPoolBackward(
const DeviceMemory<float>& output_data,
const DeviceMemory<float>& input_diff_data,
DeviceMemory<float>* output_diff_data) {
- mutex_lock lock{dnn_handle_mutex_};
- auto status = wrap::cudnnSetStream(this, stream, ToHandle(dnn_handle_));
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(ERROR) << "failed to set stream for cudnn handle: " << ToString(status);
- return false;
- }
-
// Alpha is the scaling factor for input.
float alpha = 1.0;
// Beta is the scaling factor for output.
float beta = 0.0;
- ScopedTensorDescriptor src_desc{parent_, input_dimensions, CUDNN_DATA_FLOAT};
- ScopedTensorDescriptor dest_desc{parent_, output_dimensions,
- CUDNN_DATA_FLOAT};
- ScopedPoolingDescriptor pooling_desc{parent_, pooling_dimensions};
- status = wrap::cudnnPoolingBackward(
- this, stream, ToHandle(dnn_handle_), pooling_desc.handle(), &alpha,
- dest_desc.handle(), output_data.opaque(), dest_desc.handle(),
- input_diff_data.opaque(), src_desc.handle(), input_data.opaque(), &beta,
- src_desc.handle(), output_diff_data->opaque());
+ ScopedTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_FLOAT);
+ ScopedTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_FLOAT);
+ ScopedPoolingDescriptor pooling_desc(pooling_dimensions);
+
+ auto cudnn = cudnn_->GetHandle(parent_, stream);
+ auto status = cudnnPoolingBackward(
+ cudnn.handle(), pooling_desc.handle(), &alpha, dest_desc.handle(),
+ output_data.opaque(), dest_desc.handle(), input_diff_data.opaque(),
+ src_desc.handle(), input_data.opaque(), &beta, src_desc.handle(),
+ output_diff_data->opaque());
if (status != CUDNN_STATUS_SUCCESS) {
LOG(ERROR) << "failed to enqueue backward pooling on stream: "
<< ToString(status);
@@ -4530,26 +4250,21 @@ bool CudnnSupport::DoPoolBackward(
const DeviceMemory<Eigen::half>& output_data,
const DeviceMemory<Eigen::half>& input_diff_data,
DeviceMemory<Eigen::half>* output_diff_data) {
- mutex_lock lock{dnn_handle_mutex_};
- auto status = wrap::cudnnSetStream(this, stream, ToHandle(dnn_handle_));
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(ERROR) << "failed to set stream for cudnn handle: " << ToString(status);
- return false;
- }
-
// Alpha is the scaling factor for input.
float alpha = 1.0;
// Beta is the scaling factor for output.
float beta = 0.0;
- ScopedTensorDescriptor src_desc{parent_, input_dimensions, CUDNN_DATA_HALF};
- ScopedTensorDescriptor dest_desc{parent_, output_dimensions, CUDNN_DATA_HALF};
- ScopedPoolingDescriptor pooling_desc{parent_, pooling_dimensions};
- status = wrap::cudnnPoolingBackward(
- this, stream, ToHandle(dnn_handle_), pooling_desc.handle(), &alpha,
- dest_desc.handle(), output_data.opaque(), dest_desc.handle(),
- input_diff_data.opaque(), src_desc.handle(), input_data.opaque(), &beta,
- src_desc.handle(), output_diff_data->opaque());
+ ScopedTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_HALF);
+ ScopedTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_HALF);
+ ScopedPoolingDescriptor pooling_desc(pooling_dimensions);
+
+ auto cudnn = cudnn_->GetHandle(parent_, stream);
+ auto status = cudnnPoolingBackward(
+ cudnn.handle(), pooling_desc.handle(), &alpha, dest_desc.handle(),
+ output_data.opaque(), dest_desc.handle(), input_diff_data.opaque(),
+ src_desc.handle(), input_data.opaque(), &beta, src_desc.handle(),
+ output_diff_data->opaque());
if (status != CUDNN_STATUS_SUCCESS) {
LOG(ERROR) << "failed to enqueue backward pooling on stream: "
<< ToString(status);
@@ -4571,7 +4286,7 @@ bool CudnnSupport::DoNormalizeWithDimensions(
const DeviceMemory<float>& input_data, DeviceMemory<float>* output_data) {
// Check for unsupported modes.
if (normalize_descriptor.wrap_around()) {
- LOG(ERROR) << "CUDA LRN does not support wrap-around mode";
+ LOG(ERROR) << "CUDA LRN does not support cudnn-around mode";
return false;
}
if (normalize_descriptor.segment_size()) {
@@ -4579,26 +4294,21 @@ bool CudnnSupport::DoNormalizeWithDimensions(
return false;
}
- // Launch the normalization.
- mutex_lock lock{dnn_handle_mutex_};
- auto status = wrap::cudnnSetStream(this, stream, ToHandle(dnn_handle_));
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(ERROR) << "failed to set stream for cudnn handle: " << ToString(status);
- return false;
- }
-
- ScopedTensorDescriptor dims{parent_, dimensions, CUDNN_DATA_FLOAT};
- ScopedNormalizeDescriptor normalize{parent_, normalize_descriptor};
+ ScopedTensorDescriptor dims(dimensions, CUDNN_DATA_FLOAT);
+ ScopedNormalizeDescriptor normalize(normalize_descriptor);
// Alpha is the scaling factor for input.
float alpha = 1.0f;
// Beta is the scaling factor for output.
float beta = 0.0f;
- status = wrap::cudnnLRNCrossChannelForward(
- this, stream, ToHandle(dnn_handle_), normalize.handle(),
- CUDNN_LRN_CROSS_CHANNEL_DIM1, &alpha, dims.handle(), input_data.opaque(),
- &beta, dims.handle(), output_data->opaque());
+ auto cudnn = cudnn_->GetHandle(parent_, stream);
+
+ // Launch the normalization.
+ auto status = cudnnLRNCrossChannelForward(
+ cudnn.handle(), normalize.handle(), CUDNN_LRN_CROSS_CHANNEL_DIM1, &alpha,
+ dims.handle(), input_data.opaque(), &beta, dims.handle(),
+ output_data->opaque());
if (status != CUDNN_STATUS_SUCCESS) {
LOG(ERROR) << "failed to run cudnnLRNCrossChannelForward";
return false;
@@ -4614,7 +4324,7 @@ bool CudnnSupport::DoNormalizeBackwardWithDimensions(
DeviceMemory<float>* raw_variable_gradient) {
// Check for unsupported modes.
if (normalize_descriptor.wrap_around()) {
- LOG(ERROR) << "CUDA LRN does not support wrap-around mode";
+ LOG(ERROR) << "CUDA LRN does not support cudnn-around mode";
return false;
}
if (normalize_descriptor.segment_size()) {
@@ -4622,23 +4332,16 @@ bool CudnnSupport::DoNormalizeBackwardWithDimensions(
return false;
}
- mutex_lock lock{dnn_handle_mutex_};
- auto status = wrap::cudnnSetStream(this, stream, ToHandle(dnn_handle_));
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(ERROR) << "failed to set stream for cudnn handle: " << ToString(status);
- return false;
- }
-
- ScopedTensorDescriptor dims{parent_, dimensions, CUDNN_DATA_FLOAT};
- ScopedNormalizeDescriptor normalize{parent_, normalize_descriptor};
+ ScopedTensorDescriptor dims(dimensions, CUDNN_DATA_FLOAT);
+ ScopedNormalizeDescriptor normalize(normalize_descriptor);
float alpha = 1.0f;
float beta = 0.0f;
- status = wrap::cudnnLRNCrossChannelBackward(
- this, stream, ToHandle(dnn_handle_), normalize.handle(),
- CUDNN_LRN_CROSS_CHANNEL_DIM1, &alpha, dims.handle(),
- normalized_data.opaque(), dims.handle(),
+ auto cudnn = cudnn_->GetHandle(parent_, stream);
+ auto status = cudnnLRNCrossChannelBackward(
+ cudnn.handle(), normalize.handle(), CUDNN_LRN_CROSS_CHANNEL_DIM1, &alpha,
+ dims.handle(), normalized_data.opaque(), dims.handle(),
normalized_variable_gradient.opaque(), dims.handle(), raw_data.opaque(),
&beta, dims.handle(), raw_variable_gradient->opaque());
if (status != CUDNN_STATUS_SUCCESS) {
@@ -4754,17 +4457,14 @@ bool CudnnSupport::DeriveOutputBatchDescriptor(
const FilterDescriptor& filter_descriptor,
const dnn::ConvolutionDescriptor& convolution_descriptor,
dnn::BatchDescriptor* output_batch_descriptor) {
- ScopedTensorDescriptor input_nd{parent_, batch_descriptor, CUDNN_DATA_FLOAT};
- ScopedFilterDescriptor filter{parent_, filter_descriptor, batch_descriptor,
- CUDNN_DATA_FLOAT};
- ScopedConvolutionDescriptor conv{parent_, convolution_descriptor,
- CUDNN_DATA_FLOAT};
+ ScopedTensorDescriptor input_nd(batch_descriptor, CUDNN_DATA_FLOAT);
+ ScopedFilterDescriptor filter(filter_descriptor, CUDNN_DATA_FLOAT);
+ ScopedConvolutionDescriptor conv(convolution_descriptor, CUDNN_DATA_FLOAT);
int dn = batch_descriptor.ndims() + 2;
std::vector<int> dims(dn); // in BDYX
- auto status = wrap::cudnnGetConvolutionNdForwardOutputDim(
- parent_, conv.handle(), input_nd.handle(), filter.handle(), dn,
- dims.data());
+ auto status = cudnnGetConvolutionNdForwardOutputDim(
+ conv.handle(), input_nd.handle(), filter.handle(), dn, dims.data());
if (status != CUDNN_STATUS_SUCCESS) {
LOG(ERROR) << "could not get output tensor for convolution: "
<< ToString(status);
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.h b/tensorflow/stream_executor/cuda/cuda_dnn.h
index dfe2779949..e2de3c62d8 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.h
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.h
@@ -19,6 +19,7 @@ limitations under the License.
#ifndef TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_DNN_H_
#define TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_DNN_H_
+#include "tensorflow/stream_executor/cuda/cuda_activation.h"
#include "tensorflow/stream_executor/dnn.h"
#include "tensorflow/stream_executor/lib/status.h"
#include "tensorflow/stream_executor/platform/mutex.h"
@@ -42,7 +43,6 @@ extern const PluginId kCuDnnPlugin;
class CudnnSupport : public dnn::DnnSupport {
public:
explicit CudnnSupport(CUDAExecutor* parent);
- ~CudnnSupport() override;
port::Status Init() override;
port::StatusOr<perftools::gputools::dnn::VersionInfo> GetVersion() override;
@@ -624,54 +624,11 @@ class CudnnSupport : public dnn::DnnSupport {
dnn::DataType output_type, float scale,
DeviceMemoryBase* output_data) override;
- const Stream* GetCurrentDnnStream() const
- SHARED_LOCKS_REQUIRED(dnn_handle_mutex_) {
- return current_dnn_stream_;
- }
-
- void SetCurrentDnnStream(Stream* stream)
- EXCLUSIVE_LOCKS_REQUIRED(dnn_handle_mutex_) {
- current_dnn_stream_ = stream;
- }
-
- CUDAExecutor* GetParentExecutor() { return parent_; }
-
- // Guards the enqueueing of DNN operations via the dnn_handle_ below, and
- // access to current_dnn_stream_.
- //
- // This is a public member because we need to add thread safety annotations in
- // the cudnn wrapper functions in the cc file, which need to access this
- // mutex (the annotations require C++ permission checks).
- mutex dnn_handle_mutex_;
-
private:
CUDAExecutor* parent_; // Parent executor object. Not owned.
- // cudnn library handle. cudnnHandle_t type is not present in this header to
- // prevent third-party library header inclusions from leaking outside the
- // single cuda_dnn translation unit.
- void* dnn_handle_ GUARDED_BY(dnn_handle_mutex_);
-
- // The current cudnn stream that is set by cudnnSetStream().
- Stream* current_dnn_stream_ GUARDED_BY(dnn_handle_mutex_);
-
- // NOTE(keveman): Temporary data layout transformation until cuDNN supports
- // kBatchYXDepth for backward pass. This function allocates temporary memory,
- // lays out the source data into the temporary but in the kBatchDepthXY
- // layout, and returns the temporary memory. The caller is responsible for
- // deallocating the temporary. Since the allocation is done using Stream's
- // AllocateTemporaryMemory, a later BlockHostUntilDone could be used for
- // deallocation.
- //
- // transform_scratch is populated with a legitimate temporary allocation iff
- // the original output data needs to be transformed.
- template<class T>
- DeviceMemory<T> MaybeTransformLayout(
- Stream* stream,
- dnn::BatchDescriptor* output_descriptor,
- DeviceMemory<T> backward_output_data,
- std::unique_ptr<TemporaryDeviceMemory<T>>* transform_scratch)
- EXCLUSIVE_LOCKS_REQUIRED(dnn_handle_mutex_);
+ // Provides access to the cuDNN handle.
+ std::unique_ptr<class CudnnAccess> cudnn_;
template <class T, class U>
bool DoBatchNormalizationForwardImpl(
@@ -700,7 +657,7 @@ class CudnnSupport : public dnn::DnnSupport {
template <class T>
bool DoConvolveImpl(Stream* stream,
- const dnn::BatchDescriptor& batch_descriptor,
+ const dnn::BatchDescriptor& input_descriptor,
const DeviceMemory<T>& input_data,
const dnn::FilterDescriptor& filter_descriptor,
const DeviceMemory<T>& filter_data,
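The header change above completes the refactor visible throughout cuda_dnn.cc: the per-object dnn_handle_mutex_, the raw dnn_handle_, and the per-call wrap::cudnnSetStream boilerplate are replaced by a single CudnnAccess member whose GetHandle(parent_, stream) returns a scoped handle that each call site passes straight to the cuDNN API. The real CudnnAccess lives in cuda_dnn.cc and is not shown in this diff; the sketch below only illustrates the pattern, and its internals (the mutex, the single cudnnSetStream call, and a raw cudaStream_t parameter in place of StreamExecutor's CUDAExecutor*/Stream*) are assumptions, not the actual implementation.

#include <mutex>
#include <utility>

#include "cudnn.h"  // cudnnHandle_t, cudnnSetStream, cudnnDestroy

// RAII view returned by CudnnAccess::GetHandle(); keeps the lock alive for
// the duration of the cuDNN call that consumes handle().
class CudnnHandle {
 public:
  CudnnHandle(std::unique_lock<std::mutex> lock, cudnnHandle_t handle)
      : lock_(std::move(lock)), handle_(handle) {}
  cudnnHandle_t handle() const { return handle_; }

 private:
  std::unique_lock<std::mutex> lock_;
  cudnnHandle_t handle_;
};

class CudnnAccess {
 public:
  explicit CudnnAccess(cudnnHandle_t handle) : handle_(handle) {}
  ~CudnnAccess() { cudnnDestroy(handle_); }

  // Serializes access to the shared handle and binds it to `stream` once,
  // replacing the old mutex_lock + wrap::cudnnSetStream at every call site.
  CudnnHandle GetHandle(cudaStream_t stream) {
    std::unique_lock<std::mutex> lock(mutex_);
    cudnnSetStream(handle_, stream);
    return CudnnHandle(std::move(lock), handle_);
  }

 private:
  std::mutex mutex_;
  cudnnHandle_t handle_;
};

With a helper of this shape, every call site reduces to the two-line pattern seen in the hunks above: acquire the scoped handle, then hand cudnn.handle() to cudnnPoolingForward, cudnnActivationForward, and the rest of the cuDNN entry points.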
diff --git a/tensorflow/tools/benchmark/benchmark_model.cc b/tensorflow/tools/benchmark/benchmark_model.cc
index 15523028c7..eeb1fab40c 100644
--- a/tensorflow/tools/benchmark/benchmark_model.cc
+++ b/tensorflow/tools/benchmark/benchmark_model.cc
@@ -262,6 +262,10 @@ Status InitializeSession(int num_threads, const string& graph,
tensorflow::GraphDef tensorflow_graph;
Status s = ReadBinaryProto(Env::Default(), graph, graph_def->get());
if (!s.ok()) {
+ s = ReadTextProto(Env::Default(), graph, graph_def->get());
+ }
+
+ if (!s.ok()) {
LOG(ERROR) << "Could not create TensorFlow Graph: " << s;
return s;
}
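The fallback above lets benchmark_model accept a GraphDef in either wire format: it first tries ReadBinaryProto and, only if that fails, re-reads the same path with ReadTextProto. A minimal, self-contained sketch of the same pattern follows; the helper name ReadGraphDefAnyFormat is made up for illustration and is not part of the change.

#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/env.h"

namespace tensorflow {

// Accepts either a binary graphdef.pb or a text graphdef.pb.txt path.
Status ReadGraphDefAnyFormat(const string& path, GraphDef* graph_def) {
  Status s = ReadBinaryProto(Env::Default(), path, graph_def);
  if (!s.ok()) {
    // Binary parse failed; fall back to the text proto format.
    s = ReadTextProto(Env::Default(), path, graph_def);
  }
  return s;
}

}  // namespace tensorflow

The TextProto test added to benchmark_model_test.cc further down exercises exactly this path by writing the graph with WriteTextProto and then calling InitializeSession on the .pb.txt file.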
diff --git a/tensorflow/tools/benchmark/benchmark_model_test.cc b/tensorflow/tools/benchmark/benchmark_model_test.cc
index 16ab2ff66e..6813045d63 100644
--- a/tensorflow/tools/benchmark/benchmark_model_test.cc
+++ b/tensorflow/tools/benchmark/benchmark_model_test.cc
@@ -26,30 +26,36 @@ limitations under the License.
namespace tensorflow {
namespace {
-TEST(BenchmarkModelTest, InitializeAndRun) {
- const string dir = testing::TmpDir();
- const string filename_pb = io::JoinPath(dir, "graphdef.pb");
-
+void CreateTestGraph(const ::tensorflow::Scope& root,
+ benchmark_model::InputLayerInfo* input,
+ string* output_name, GraphDef* graph_def) {
// Create a simple graph and write it to filename_pb.
const int input_width = 400;
const int input_height = 10;
- benchmark_model::InputLayerInfo input;
- input.shape = TensorShape({input_width, input_height});
- input.data_type = DT_FLOAT;
+ input->shape = TensorShape({input_width, input_height});
+ input->data_type = DT_FLOAT;
const TensorShape constant_shape({input_height, input_width});
Tensor constant_tensor(DT_FLOAT, constant_shape);
test::FillFn<float>(&constant_tensor, [](int) -> float { return 3.0; });
- auto root = Scope::NewRootScope().ExitOnError();
auto placeholder =
- ops::Placeholder(root, DT_FLOAT, ops::Placeholder::Shape(input.shape));
- input.name = placeholder.node()->name();
+ ops::Placeholder(root, DT_FLOAT, ops::Placeholder::Shape(input->shape));
+ input->name = placeholder.node()->name();
auto m = ops::MatMul(root, placeholder, constant_tensor);
- const string output_name = m.node()->name();
+ *output_name = m.node()->name();
+ TF_ASSERT_OK(root.ToGraphDef(graph_def));
+}
+
+TEST(BenchmarkModelTest, InitializeAndRun) {
+ const string dir = testing::TmpDir();
+ const string filename_pb = io::JoinPath(dir, "graphdef.pb");
+ auto root = Scope::NewRootScope().ExitOnError();
+ benchmark_model::InputLayerInfo input;
+ string output_name;
GraphDef graph_def;
- TF_ASSERT_OK(root.ToGraphDef(&graph_def));
+ CreateTestGraph(root, &input, &output_name, &graph_def);
string graph_def_serialized;
graph_def.SerializeToString(&graph_def_serialized);
TF_ASSERT_OK(
@@ -69,5 +75,30 @@ TEST(BenchmarkModelTest, InitializeAndRun) {
ASSERT_EQ(num_runs, 10);
}
+TEST(BenchmarkModelTest, TextProto) {
+ const string dir = testing::TmpDir();
+ const string filename_txt = io::JoinPath(dir, "graphdef.pb.txt");
+ auto root = Scope::NewRootScope().ExitOnError();
+
+ benchmark_model::InputLayerInfo input;
+ string output_name;
+ GraphDef graph_def;
+ CreateTestGraph(root, &input, &output_name, &graph_def);
+ TF_ASSERT_OK(WriteTextProto(Env::Default(), filename_txt, graph_def));
+
+ std::unique_ptr<Session> session;
+ std::unique_ptr<GraphDef> loaded_graph_def;
+ TF_ASSERT_OK(benchmark_model::InitializeSession(1, filename_txt, &session,
+ &loaded_graph_def));
+ std::unique_ptr<StatSummarizer> stats;
+ stats.reset(new tensorflow::StatSummarizer(*(loaded_graph_def.get())));
+ int64 time;
+ int64 num_runs = 0;
+ TF_ASSERT_OK(benchmark_model::TimeMultipleRuns(
+ 0.0, 10, 0.0, {input}, {output_name}, {}, session.get(), stats.get(),
+ &time, &num_runs));
+ ASSERT_EQ(num_runs, 10);
+}
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/tools/ci_build/update_version.py b/tensorflow/tools/ci_build/update_version.py
index 9ddb219048..00bfcfd49b 100755
--- a/tensorflow/tools/ci_build/update_version.py
+++ b/tensorflow/tools/ci_build/update_version.py
@@ -250,7 +250,7 @@ def update_md_files(old_version, new_version):
# Update any links to colab notebooks.
def colab_url(version):
- version_string = "%d.%d.%d" % (version.major, version.minor, version.patch)
+ version_string = "%s.%s.%s" % (version.major, version.minor, version.patch)
prefix = "https://colab.research.google.com/github/tensorflow/models/blob/r"
return prefix + version_string + "/"
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index 8b6ad0a138..01d424f20b 100644
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -228,6 +228,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
sha256 = "e45ce5f68b1d80e2cb9a2b601605b374bdf51e1798ef1c2c2bd62131dfcf9eef",
strip_prefix = "libpng-1.6.34",
build_file = clean_dep("//third_party:png.BUILD"),
+ patch_file = clean_dep("//third_party:png_fix_rpi.patch"),
)
tf_http_archive(
@@ -452,11 +453,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
tf_http_archive(
name = "llvm",
urls = [
- "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/b3f6a6a61625296bb532a65c0bf51b91b05b3361.tar.gz",
- "https://github.com/llvm-mirror/llvm/archive/b3f6a6a61625296bb532a65c0bf51b91b05b3361.tar.gz",
+ "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/7b8a8728fbd27086efbf3c57cf2bb35a557108c9.tar.gz",
+ "https://github.com/llvm-mirror/llvm/archive/7b8a8728fbd27086efbf3c57cf2bb35a557108c9.tar.gz",
],
- sha256 = "93895b289a78a47a1e75652e12a1b9a6c119f086a509b00e0084cf2bb944b709",
- strip_prefix = "llvm-b3f6a6a61625296bb532a65c0bf51b91b05b3361",
+ sha256 = "c620859c3ae5818f316de4837f340b3bba1646f8add0a28e6d4da34ce47e3969",
+ strip_prefix = "llvm-7b8a8728fbd27086efbf3c57cf2bb35a557108c9",
build_file = clean_dep("//third_party/llvm:llvm.BUILD"),
)
@@ -744,6 +745,17 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
build_file = clean_dep("//third_party:tflite_smartreply.BUILD"),
)
+ tf_http_archive(
+ name = "tflite_ovic_testdata",
+ sha256 = "a9a705d8d519220178e2e65d383fdb21da37fdb31d1e909b0a1acdac46479e9c",
+ urls = [
+ "https://mirror.bazel.build/storage.googleapis.com/download.tensorflow.org/data/ovic.zip",
+ "https://storage.googleapis.com/download.tensorflow.org/data/ovic.zip",
+ ],
+ build_file = clean_dep("//third_party:tflite_ovic_testdata.BUILD"),
+ strip_prefix = "ovic",
+ )
+
##############################################################################
# BIND DEFINITIONS
#
diff --git a/third_party/clang_toolchain/download_clang.bzl b/third_party/clang_toolchain/download_clang.bzl
index 54d383d7d7..cfd8bfe98d 100644
--- a/third_party/clang_toolchain/download_clang.bzl
+++ b/third_party/clang_toolchain/download_clang.bzl
@@ -35,18 +35,18 @@ def download_clang(repo_ctx, out_folder):
# Latest CLANG_REVISION and CLANG_SUB_REVISION of the Chromiums's release
# can be found in https://chromium.googlesource.com/chromium/src/tools/clang/+/master/scripts/update.py
- CLANG_REVISION = '321529'
+ CLANG_REVISION = '330570'
CLANG_SUB_REVISION = 2
package_version = '%s-%s' % (CLANG_REVISION, CLANG_SUB_REVISION)
checksums = {
'Linux_x64':
- '76d4eb1ad011e3127c4a9de9b9f5d4ac624b5a9395c4d7395c9e0a487b13daf6',
+ '2108e172e05d4904c3c46125a33ab4a1175b36ec2a2226619a243e1d8f397e97',
'Mac':
- '4b2a7a65ac1ee892b318c723eec8771f514bb306f346aa8216bb0006f19d87b7',
+ '481b5c6909f0ea250216061bd45e9c982b4befff65cbfca2ee1090c21a109eac',
'Win':
- 'eba51bb8f84af41a85903113666bd21c22709010c39c4cb19dc20cf1ed14581b',
+ '8f04a3ac99d463d4179eb2f68a13575408c3dddc62887a1e441c77123e35e301',
}
platform_folder = _get_platform_folder(repo_ctx.os.name)
diff --git a/third_party/png_fix_rpi.patch b/third_party/png_fix_rpi.patch
new file mode 100644
index 0000000000..80da7b3c06
--- /dev/null
+++ b/third_party/png_fix_rpi.patch
@@ -0,0 +1,16 @@
+diff -r -u /tmp/libpng-1.6.34/scripts/pnglibconf.h.prebuilt ./scripts/pnglibconf.h.prebuilt
+--- /tmp/libpng-1.6.34/scripts/pnglibconf.h.prebuilt 2017-09-29 01:42:33.000000000 -0700
++++ ./scripts/pnglibconf.h.prebuilt 2018-05-01 09:51:24.719318242 -0700
+@@ -20,6 +20,12 @@
+ #define PNG_ALIGNED_MEMORY_SUPPORTED
+ /*#undef PNG_ARM_NEON_API_SUPPORTED*/
+ /*#undef PNG_ARM_NEON_CHECK_SUPPORTED*/
++
++/* Workaround not having a great build file by forcing
++ * png filter optimization to be disabled on arm */
++#define PNG_ARM_NEON_OPT 0
++
++
+ /*#undef PNG_POWERPC_VSX_API_SUPPORTED*/
+ /*#undef PNG_POWERPC_VSX_CHECK_SUPPORTED*/
+ #define PNG_BENIGN_ERRORS_SUPPORTED
diff --git a/third_party/tflite_ovic_testdata.BUILD b/third_party/tflite_ovic_testdata.BUILD
new file mode 100644
index 0000000000..de47ed61f9
--- /dev/null
+++ b/third_party/tflite_ovic_testdata.BUILD
@@ -0,0 +1,12 @@
+package(default_visibility = ["//visibility:public"])
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(
+ glob(
+ ["**/*"],
+ exclude = [
+ "BUILD",
+ ],
+ ),
+)