aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/c/eager/tape.h65
-rw-r--r--tensorflow/compiler/aot/compile.cc12
-rw-r--r--tensorflow/compiler/aot/tests/tfcompile_test.cc14
-rw-r--r--tensorflow/compiler/tests/reduce_ops_test.py64
-rw-r--r--tensorflow/compiler/tf2xla/BUILD7
-rw-r--r--tensorflow/compiler/tf2xla/graph_compiler.cc7
-rw-r--r--tensorflow/compiler/tf2xla/kernels/BUILD8
-rw-r--r--tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc18
-rw-r--r--tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc16
-rw-r--r--tensorflow/compiler/tf2xla/kernels/bias_ops.cc4
-rw-r--r--tensorflow/compiler/tf2xla/kernels/binary_ops.cc32
-rw-r--r--tensorflow/compiler/tf2xla/kernels/cast_op.cc12
-rw-r--r--tensorflow/compiler/tf2xla/kernels/categorical_op.cc6
-rw-r--r--tensorflow/compiler/tf2xla/kernels/clip_by_value_op.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/concat_op.cc10
-rw-r--r--tensorflow/compiler/tf2xla/kernels/const_op.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/conv_ops.cc49
-rw-r--r--tensorflow/compiler/tf2xla/kernels/cross_op.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/cwise_ops.cc12
-rw-r--r--tensorflow/compiler/tf2xla/kernels/cwise_ops.h19
-rw-r--r--tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc12
-rw-r--r--tensorflow/compiler/tf2xla/kernels/diag_op.cc36
-rw-r--r--tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc4
-rw-r--r--tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc6
-rw-r--r--tensorflow/compiler/tf2xla/kernels/elu_op.cc11
-rw-r--r--tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc6
-rw-r--r--tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc134
-rw-r--r--tensorflow/compiler/tf2xla/kernels/fft_ops.cc5
-rw-r--r--tensorflow/compiler/tf2xla/kernels/fill_op.cc4
-rw-r--r--tensorflow/compiler/tf2xla/kernels/gather_op.cc20
-rw-r--r--tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h14
-rw-r--r--tensorflow/compiler/tf2xla/kernels/if_op.cc10
-rw-r--r--tensorflow/compiler/tf2xla/kernels/image_ops.cc69
-rw-r--r--tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc48
-rw-r--r--tensorflow/compiler/tf2xla/kernels/index_ops.cc6
-rw-r--r--tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc6
-rw-r--r--tensorflow/compiler/tf2xla/kernels/l2loss_op.cc4
-rw-r--r--tensorflow/compiler/tf2xla/kernels/lrn_ops.cc14
-rw-r--r--tensorflow/compiler/tf2xla/kernels/matmul_op.cc4
-rw-r--r--tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc12
-rw-r--r--tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc10
-rw-r--r--tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc17
-rw-r--r--tensorflow/compiler/tf2xla/kernels/one_hot_op.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/pack_op.cc4
-rw-r--r--tensorflow/compiler/tf2xla/kernels/pad_op.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/pooling_ops.cc45
-rw-r--r--tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc16
-rw-r--r--tensorflow/compiler/tf2xla/kernels/random_ops.cc58
-rw-r--r--tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc11
-rw-r--r--tensorflow/compiler/tf2xla/kernels/reduction_ops.cc63
-rw-r--r--tensorflow/compiler/tf2xla/kernels/reduction_ops.h22
-rw-r--r--tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc13
-rw-r--r--tensorflow/compiler/tf2xla/kernels/relu_op.cc10
-rw-r--r--tensorflow/compiler/tf2xla/kernels/retval_op.cc4
-rw-r--r--tensorflow/compiler/tf2xla/kernels/reverse_op.cc4
-rw-r--r--tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc4
-rw-r--r--tensorflow/compiler/tf2xla/kernels/scan_ops.cc6
-rw-r--r--tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc10
-rw-r--r--tensorflow/compiler/tf2xla/kernels/select_op.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/softmax_op.cc29
-rw-r--r--tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc17
-rw-r--r--tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc12
-rw-r--r--tensorflow/compiler/tf2xla/kernels/split_op.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/stack_ops.cc32
-rw-r--r--tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc36
-rw-r--r--tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc8
-rw-r--r--tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc82
-rw-r--r--tensorflow/compiler/tf2xla/kernels/tile_ops.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/training_ops.cc103
-rw-r--r--tensorflow/compiler/tf2xla/kernels/unary_ops.cc36
-rw-r--r--tensorflow/compiler/tf2xla/kernels/variable_ops.cc14
-rw-r--r--tensorflow/compiler/tf2xla/kernels/while_op.cc16
-rw-r--r--tensorflow/compiler/tf2xla/lib/BUILD26
-rw-r--r--tensorflow/compiler/tf2xla/lib/batch_dot.cc50
-rw-r--r--tensorflow/compiler/tf2xla/lib/batch_dot.h12
-rw-r--r--tensorflow/compiler/tf2xla/lib/cholesky.cc50
-rw-r--r--tensorflow/compiler/tf2xla/lib/cholesky.h9
-rw-r--r--tensorflow/compiler/tf2xla/lib/scatter.cc58
-rw-r--r--tensorflow/compiler/tf2xla/lib/scatter.h18
-rw-r--r--tensorflow/compiler/tf2xla/lib/triangular_solve.cc131
-rw-r--r--tensorflow/compiler/tf2xla/lib/triangular_solve.h21
-rw-r--r--tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc50
-rw-r--r--tensorflow/compiler/tf2xla/lib/util.cc92
-rw-r--r--tensorflow/compiler/tf2xla/lib/util.h67
-rw-r--r--tensorflow/compiler/tf2xla/lib/util_test.cc17
-rw-r--r--tensorflow/compiler/tf2xla/lib/while_loop.cc52
-rw-r--r--tensorflow/compiler/tf2xla/lib/while_loop.h29
-rw-r--r--tensorflow/compiler/tf2xla/tf2xla.cc6
-rw-r--r--tensorflow/compiler/tf2xla/tf2xla.h12
-rw-r--r--tensorflow/compiler/tf2xla/tf2xla_test.cc2
-rw-r--r--tensorflow/compiler/tf2xla/xla_compilation_device.cc7
-rw-r--r--tensorflow/compiler/tf2xla/xla_compilation_device.h10
-rw-r--r--tensorflow/compiler/tf2xla/xla_compiler.cc47
-rw-r--r--tensorflow/compiler/tf2xla/xla_compiler.h18
-rw-r--r--tensorflow/compiler/tf2xla/xla_compiler_test.cc36
-rw-r--r--tensorflow/compiler/tf2xla/xla_context.cc33
-rw-r--r--tensorflow/compiler/tf2xla/xla_context.h36
-rw-r--r--tensorflow/compiler/tf2xla/xla_helpers.cc97
-rw-r--r--tensorflow/compiler/tf2xla/xla_helpers.h66
-rw-r--r--tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc4
-rw-r--r--tensorflow/compiler/tf2xla/xla_op_kernel.cc71
-rw-r--r--tensorflow/compiler/tf2xla/xla_op_kernel.h30
-rw-r--r--tensorflow/compiler/tf2xla/xla_resource.cc33
-rw-r--r--tensorflow/compiler/tf2xla/xla_resource.h29
-rw-r--r--tensorflow/compiler/xla/client/BUILD1
-rw-r--r--tensorflow/compiler/xla/client/local_client.cc36
-rw-r--r--tensorflow/compiler/xla/client/local_client.h10
-rw-r--r--tensorflow/compiler/xla/layout_util.cc22
-rw-r--r--tensorflow/compiler/xla/layout_util.h3
-rw-r--r--tensorflow/compiler/xla/literal_util.cc22
-rw-r--r--tensorflow/compiler/xla/literal_util.h13
-rw-r--r--tensorflow/compiler/xla/service/BUILD1
-rw-r--r--tensorflow/compiler/xla/service/buffer_assignment.cc196
-rw-r--r--tensorflow/compiler/xla/service/buffer_assignment.h23
-rw-r--r--tensorflow/compiler/xla/service/buffer_assignment_test.cc78
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_compiler.cc4
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_executable.cc5
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h3
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc4
-rw-r--r--tensorflow/compiler/xla/service/cpu/ir_emitter.cc29
-rw-r--r--tensorflow/compiler/xla/service/cpu/ir_emitter.h16
-rw-r--r--tensorflow/compiler/xla/service/cpu/tests/BUILD163
-rw-r--r--tensorflow/compiler/xla/service/cpu/tests/cpu_bytesizeof_test.cc37
-rw-r--r--tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h30
-rw-r--r--tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc113
-rw-r--r--tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc73
-rw-r--r--tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc330
-rw-r--r--tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc294
-rw-r--r--tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc151
-rw-r--r--tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc125
-rw-r--r--tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc136
-rw-r--r--tensorflow/compiler/xla/service/executable.cc13
-rw-r--r--tensorflow/compiler/xla/service/executable.h17
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_compiler.cc2
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h3
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc8
-rw-r--r--tensorflow/compiler/xla/service/hlo_graph_dumper.cc2
-rw-r--r--tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc37
-rw-r--r--tensorflow/compiler/xla/service/hlo_module.cc17
-rw-r--r--tensorflow/compiler/xla/service/hlo_module.h16
-rw-r--r--tensorflow/compiler/xla/service/hlo_module_config.cc17
-rw-r--r--tensorflow/compiler/xla/service/hlo_module_config.h43
-rw-r--r--tensorflow/compiler/xla/service/hlo_module_group_metadata.cc7
-rw-r--r--tensorflow/compiler/xla/service/hlo_module_group_metadata.h3
-rw-r--r--tensorflow/compiler/xla/service/interpreter/compiler.cc2
-rw-r--r--tensorflow/compiler/xla/service/layout_assignment.cc15
-rw-r--r--tensorflow/compiler/xla/service/layout_assignment.h4
-rw-r--r--tensorflow/compiler/xla/service/layout_assignment_test.cc8
-rw-r--r--tensorflow/compiler/xla/service/service.cc156
-rw-r--r--tensorflow/compiler/xla/service/service.h3
-rw-r--r--tensorflow/compiler/xla/shape_util.cc23
-rw-r--r--tensorflow/compiler/xla/shape_util.h3
-rw-r--r--tensorflow/compiler/xla/tests/BUILD1
-rw-r--r--tensorflow/compiler/xla/tests/filecheck.cc7
-rw-r--r--tensorflow/compiler/xla/tests/hlo_test_base.h20
-rw-r--r--tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc4
-rw-r--r--tensorflow/compiler/xla/tools/parser/hlo_parser.cc10
-rw-r--r--tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc2
-rw-r--r--tensorflow/compiler/xla/xla_data.proto18
-rw-r--r--tensorflow/contrib/autograph/converters/break_statements.py4
-rw-r--r--tensorflow/contrib/autograph/converters/control_flow.py24
-rw-r--r--tensorflow/contrib/autograph/operators/__init__.py16
-rw-r--r--tensorflow/contrib/autograph/operators/control_flow.py105
-rw-r--r--tensorflow/contrib/autograph/operators/control_flow_test.py30
-rw-r--r--tensorflow/contrib/crf/python/ops/crf.py5
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/bucketing_test.py174
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py46
-rw-r--r--tensorflow/contrib/data/python/ops/batching.py4
-rw-r--r--tensorflow/contrib/data/python/ops/grouping.py301
-rw-r--r--tensorflow/contrib/data/python/ops/scan_ops.py122
-rw-r--r--tensorflow/contrib/distribute/python/BUILD53
-rw-r--r--tensorflow/contrib/distribute/python/cross_tower_ops.py7
-rw-r--r--tensorflow/contrib/distribute/python/input_ops.py141
-rw-r--r--tensorflow/contrib/distribute/python/input_ops_test.py265
-rw-r--r--tensorflow/contrib/distribute/python/mirrored_strategy.py1
-rw-r--r--tensorflow/contrib/distribute/python/multi_worker_test_base.py90
-rw-r--r--tensorflow/contrib/distribute/python/values.py101
-rw-r--r--tensorflow/contrib/distribute/python/values_test.py143
-rw-r--r--tensorflow/contrib/distributions/python/ops/quantized_distribution.py64
-rw-r--r--tensorflow/contrib/eager/README.md11
-rw-r--r--tensorflow/contrib/eager/python/examples/notebooks/1_basics.ipynb364
-rw-r--r--tensorflow/contrib/eager/python/examples/notebooks/2_gradients.ipynb473
-rw-r--r--tensorflow/contrib/eager/python/examples/notebooks/3_datasets.ipynb43
-rw-r--r--tensorflow/contrib/estimator/__init__.py1
-rw-r--r--tensorflow/contrib/estimator/python/estimator/head.py69
-rw-r--r--tensorflow/contrib/estimator/python/estimator/head_test.py119
-rw-r--r--tensorflow/contrib/factorization/python/ops/factorization_ops.py4
-rw-r--r--tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features_test.py6
-rw-r--r--tensorflow/contrib/kfac/python/ops/fisher_blocks.py7
-rw-r--r--tensorflow/contrib/layers/python/layers/embedding_ops_test.py6
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/linear_test.py32
-rw-r--r--tensorflow/contrib/linalg/BUILD44
-rw-r--r--tensorflow/contrib/linalg/__init__.py4
-rw-r--r--tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py8
-rw-r--r--tensorflow/contrib/lite/Makefile19
-rw-r--r--tensorflow/contrib/lite/examples/android/BUILD1
-rw-r--r--tensorflow/contrib/lite/examples/minimal/minimal.cc71
-rw-r--r--tensorflow/contrib/lite/g3doc/rpi.md2
-rw-r--r--tensorflow/contrib/lite/interpreter.cc3
-rw-r--r--tensorflow/contrib/lite/interpreter.h4
-rw-r--r--tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java8
-rw-r--r--tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc6
-rw-r--r--tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java8
-rw-r--r--tensorflow/contrib/lite/kernels/BUILD2
-rw-r--r--tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc70
-rw-r--r--tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm_test.cc8
-rw-r--r--tensorflow/contrib/lite/kernels/fully_connected.cc112
-rw-r--r--tensorflow/contrib/lite/kernels/fully_connected_test.cc141
-rw-r--r--tensorflow/contrib/lite/kernels/internal/BUILD1
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc125
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h8
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h10
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc24
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h14
-rw-r--r--tensorflow/contrib/lite/kernels/internal/tensor_utils.h30
-rw-r--r--tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc323
-rw-r--r--tensorflow/contrib/lite/kernels/lstm.cc49
-rw-r--r--tensorflow/contrib/lite/kernels/lstm_test.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/optional_tensor_test.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/test_util.h17
-rw-r--r--tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc47
-rw-r--r--tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc4
-rw-r--r--tensorflow/contrib/lite/models/speech_test.cc16
-rw-r--r--tensorflow/contrib/lite/models/testdata/speech_asr_lm_model.test_spec20
-rw-r--r--tensorflow/contrib/lite/profiling/profiler.h17
-rw-r--r--tensorflow/contrib/lite/profiling/profiler_test.cc14
-rwxr-xr-xtensorflow/contrib/lite/schema/schema_generated.h27
-rw-r--r--tensorflow/contrib/lite/toco/BUILD12
-rw-r--r--tensorflow/contrib/lite/toco/dump_graphviz.cc45
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc9
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc7
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc5
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h8
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/quantize.cc5
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_binary.cc19
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc27
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_stack.cc12
-rw-r--r--tensorflow/contrib/lite/toco/import_tensorflow.cc519
-rw-r--r--tensorflow/contrib/lite/toco/import_tensorflow_test.cc160
-rw-r--r--tensorflow/contrib/lite/toco/toco_port.h5
-rw-r--r--tensorflow/contrib/lite/toco/tooling_util.h29
-rw-r--r--tensorflow/contrib/lite/toco/tooling_util_test.cc81
-rw-r--r--tensorflow/contrib/signal/python/kernel_tests/shape_ops_test.py2
-rw-r--r--tensorflow/contrib/signal/python/ops/shape_ops.py14
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/state_space_models/BUILD3
-rw-r--r--tensorflow/core/api_def/base_api/api_def_GroupByReducerDataset.pbtxt69
-rw-r--r--tensorflow/core/common_runtime/device.h11
-rw-r--r--tensorflow/core/common_runtime/device_mgr.cc3
-rw-r--r--tensorflow/core/common_runtime/eager/kernel_and_device.cc1
-rw-r--r--tensorflow/core/common_runtime/eager/kernel_and_device.h4
-rw-r--r--tensorflow/core/common_runtime/executor.cc114
-rw-r--r--tensorflow/core/common_runtime/process_function_library_runtime.cc3
-rw-r--r--tensorflow/core/distributed_runtime/BUILD75
-rw-r--r--tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc404
-rw-r--r--tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h90
-rw-r--r--tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc324
-rw-r--r--tensorflow/core/distributed_runtime/device_resolver_distributed.cc133
-rw-r--r--tensorflow/core/distributed_runtime/device_resolver_distributed.h67
-rw-r--r--tensorflow/core/distributed_runtime/device_resolver_distributed_test.cc217
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc27
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc47
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.cc6
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h5
-rw-r--r--tensorflow/core/distributed_runtime/test_utils.h173
-rw-r--r--tensorflow/core/distributed_runtime/worker.cc87
-rw-r--r--tensorflow/core/distributed_runtime/worker.h17
-rw-r--r--tensorflow/core/distributed_runtime/worker_env.h5
-rw-r--r--tensorflow/core/distributed_runtime/worker_interface.h19
-rw-r--r--tensorflow/core/framework/tracking_allocator.h1
-rw-r--r--tensorflow/core/graph/graph.cc1
-rw-r--r--tensorflow/core/graph/graph.h10
-rw-r--r--tensorflow/core/grappler/costs/graph_properties.cc372
-rw-r--r--tensorflow/core/grappler/costs/graph_properties.h37
-rw-r--r--tensorflow/core/grappler/costs/graph_properties_test.cc15
-rw-r--r--tensorflow/core/grappler/op_types.cc8
-rw-r--r--tensorflow/core/grappler/op_types.h3
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc420
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer.h5
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc179
-rw-r--r--tensorflow/core/grappler/optimizers/function_optimizer.cc132
-rw-r--r--tensorflow/core/grappler/optimizers/function_optimizer_test.cc61
-rw-r--r--tensorflow/core/grappler/optimizers/memory_optimizer.cc76
-rw-r--r--tensorflow/core/grappler/optimizers/memory_optimizer_test.cc134
-rw-r--r--tensorflow/core/grappler/utils.cc30
-rw-r--r--tensorflow/core/grappler/utils.h20
-rw-r--r--tensorflow/core/grappler/utils/topological_sort.cc18
-rw-r--r--tensorflow/core/grappler/utils/topological_sort.h4
-rw-r--r--tensorflow/core/grappler/utils/topological_sort_test.cc34
-rw-r--r--tensorflow/core/kernels/BUILD10
-rw-r--r--tensorflow/core/kernels/assign_op.h74
-rw-r--r--tensorflow/core/kernels/broadcast_to_op.h34
-rw-r--r--tensorflow/core/kernels/conv_grad_filter_ops.cc71
-rw-r--r--tensorflow/core/kernels/conv_grad_input_ops.cc74
-rw-r--r--tensorflow/core/kernels/conv_grad_ops.cc7
-rw-r--r--tensorflow/core/kernels/conv_ops.cc85
-rw-r--r--tensorflow/core/kernels/data/BUILD15
-rw-r--r--tensorflow/core/kernels/data/captured_function.cc14
-rw-r--r--tensorflow/core/kernels/data/captured_function.h11
-rw-r--r--tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc422
-rw-r--r--tensorflow/core/kernels/data/group_by_window_dataset_op.cc2
-rw-r--r--tensorflow/core/kernels/depthwise_conv_grad_op.cc263
-rw-r--r--tensorflow/core/kernels/depthwise_conv_op.cc118
-rw-r--r--tensorflow/core/kernels/resource_variable_ops.cc18
-rw-r--r--tensorflow/core/kernels/segment_reduction_ops.h29
-rw-r--r--tensorflow/core/lib/hash/hash.h19
-rw-r--r--tensorflow/core/ops/array_ops.cc2
-rw-r--r--tensorflow/core/ops/compat/ops_history.v1.pbtxt76
-rw-r--r--tensorflow/core/ops/dataset_ops.cc22
-rw-r--r--tensorflow/core/ops/ops.pbtxt76
-rw-r--r--tensorflow/core/protobuf/worker.proto70
-rw-r--r--tensorflow/core/protobuf/worker_service.proto10
-rw-r--r--tensorflow/docs_src/community/benchmarks.md18
-rw-r--r--tensorflow/docs_src/community/swift.md8
-rw-r--r--tensorflow/docs_src/install/install_linux.md2
-rw-r--r--tensorflow/docs_src/install/install_mac.md2
-rw-r--r--tensorflow/docs_src/install/install_sources.md2
-rw-r--r--tensorflow/docs_src/install/install_windows.md2
-rw-r--r--tensorflow/docs_src/performance/xla/index.md4
-rw-r--r--tensorflow/go/op/wrappers.go190
-rw-r--r--tensorflow/python/BUILD1
-rw-r--r--tensorflow/python/data/ops/readers.py15
-rw-r--r--tensorflow/python/eager/backprop.py16
-rw-r--r--tensorflow/python/eager/backprop_test.py20
-rw-r--r--tensorflow/python/eager/function.py5
-rw-r--r--tensorflow/python/estimator/canned/boosted_trees.py292
-rw-r--r--tensorflow/python/estimator/canned/boosted_trees_test.py58
-rw-r--r--tensorflow/python/estimator/estimator.py45
-rw-r--r--tensorflow/python/estimator/estimator_test.py67
-rw-r--r--tensorflow/python/grappler/graph_placer.py5
-rw-r--r--tensorflow/python/keras/_impl/keras/layers/embeddings.py3
-rw-r--r--tensorflow/python/kernel_tests/BUILD16
-rw-r--r--tensorflow/python/kernel_tests/boosted_trees/stats_ops_test.py55
-rw-r--r--tensorflow/python/kernel_tests/depthwise_conv_op_test.py222
-rw-r--r--tensorflow/python/kernel_tests/distributions/bijector_test.py12
-rw-r--r--tensorflow/python/kernel_tests/linalg/BUILD44
-rw-r--r--tensorflow/python/kernel_tests/linalg/linear_operator_block_diag_test.py (renamed from tensorflow/contrib/linalg/python/kernel_tests/linear_operator_block_diag_test.py)2
-rw-r--r--tensorflow/python/kernel_tests/linalg/linear_operator_kronecker_test.py (renamed from tensorflow/contrib/linalg/python/kernel_tests/linear_operator_kronecker_test.py)2
-rw-r--r--tensorflow/python/kernel_tests/reduce_benchmark_test.py107
-rw-r--r--tensorflow/python/ops/distributions/bijector_impl.py222
-rw-r--r--tensorflow/python/ops/embedding_ops.py9
-rw-r--r--tensorflow/python/ops/linalg/linalg.py2
-rw-r--r--tensorflow/python/ops/linalg/linear_operator_block_diag.py (renamed from tensorflow/contrib/linalg/python/ops/linear_operator_block_diag.py)6
-rw-r--r--tensorflow/python/ops/linalg/linear_operator_kronecker.py (renamed from tensorflow/contrib/linalg/python/ops/linear_operator_kronecker.py)6
-rw-r--r--tensorflow/python/ops/math_ops.py11
-rw-r--r--tensorflow/python/training/device_util.py27
-rw-r--r--tensorflow/python/training/device_util_test.py89
-rw-r--r--tensorflow/python/training/distribute.py8
-rw-r--r--tensorflow/stream_executor/cuda/cuda_dnn.cc18
-rw-r--r--tensorflow/stream_executor/dnn.cc1
-rw-r--r--tensorflow/stream_executor/dnn.h6
-rw-r--r--tensorflow/tools/api/generator/create_python_api.py3
-rw-r--r--tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator-block-diag.__metaclass__.pbtxt14
-rw-r--r--tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator-block-diag.pbtxt134
-rw-r--r--tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator-kronecker.__metaclass__.pbtxt14
-rw-r--r--tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator-kronecker.pbtxt134
-rw-r--r--tensorflow/tools/api/golden/tensorflow.linalg.pbtxt8
-rwxr-xr-xtensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh4
-rw-r--r--tensorflow/tools/pip_package/pip_smoke_test.py27
-rw-r--r--tensorflow/workspace.bzl8
362 files changed, 13030 insertions, 4059 deletions
diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h
index 97c323b872..8026076b9e 100644
--- a/tensorflow/c/eager/tape.h
+++ b/tensorflow/c/eager/tape.h
@@ -380,49 +380,39 @@ Status InitialGradients(const VSpace<Gradient, BackwardFunction>& vspace,
gtl::ArraySlice<Gradient*> output_gradients,
const TensorTape& tensor_tape,
const OpTape<BackwardFunction>& op_tape,
- const gtl::FlatMap<int64, int64>& tensor_usage_counts,
gtl::FlatMap<int64, std::vector<Gradient*>>* result) {
for (int i = 0; i < target_tensor_ids.size(); ++i) {
const int64 id = target_tensor_ids[i];
- if (tensor_usage_counts.find(id) != tensor_usage_counts.end()) {
- if (!output_gradients.empty() && output_gradients[i] != nullptr) {
- // TODO(apassos) figure out how to print debugging information here.
- return errors::InvalidArgument(
- "A gradient was provided for a tensor which is used as part of the "
- "computation.");
- }
- } else {
- if (output_gradients.empty() || output_gradients[i] == nullptr) {
- auto tensor_it = tensor_tape.find(id);
- if (tensor_it != tensor_tape.end() && tensor_it->second != -1) {
- auto op_it = op_tape.find(tensor_it->second);
- if (op_it == op_tape.end()) {
- return errors::Internal(
- "Internal state of the gradient tape is invalid: "
- "failed to find operation producing a tensor");
- }
- bool found = false;
- for (int j = 0; j < op_it->second.output_tensor_info.size(); ++j) {
- if (op_it->second.output_tensor_info[j].id == id) {
- found = true;
- (*result)[id].push_back(
- vspace.Ones(op_it->second.output_tensor_info[j].shape,
- op_it->second.output_tensor_info[j].dtype));
- break;
- }
- }
- if (!found) {
- return errors::Internal(
- "Internal state of the gradient tape is invalid: "
- "none of operations outputs match expected tensor");
+ if (output_gradients.empty() || output_gradients[i] == nullptr) {
+ auto tensor_it = tensor_tape.find(id);
+ if (tensor_it != tensor_tape.end() && tensor_it->second != -1) {
+ auto op_it = op_tape.find(tensor_it->second);
+ if (op_it == op_tape.end()) {
+ return errors::Internal(
+ "Internal state of the gradient tape is invalid: "
+ "failed to find operation producing a tensor");
+ }
+ bool found = false;
+ for (int j = 0; j < op_it->second.output_tensor_info.size(); ++j) {
+ if (op_it->second.output_tensor_info[j].id == id) {
+ found = true;
+ (*result)[id].push_back(
+ vspace.Ones(op_it->second.output_tensor_info[j].shape,
+ op_it->second.output_tensor_info[j].dtype));
+ break;
}
- } else {
- // No record of the target tensor found on the tape, so no gradient
- // needs to be computed from it. Do nothing.
+ }
+ if (!found) {
+ return errors::Internal(
+ "Internal state of the gradient tape is invalid: "
+ "none of operations outputs match expected tensor");
}
} else {
- (*result)[id].push_back(output_gradients[i]);
+ // No record of the target tensor found on the tape, so no gradient
+ // needs to be computed from it. Do nothing.
}
+ } else {
+ (*result)[id].push_back(output_gradients[i]);
}
}
return Status::OK();
@@ -451,8 +441,7 @@ Status GradientTape<Gradient, BackwardFunction>::ComputeGradient(
InitialStack(state.op_tape, state.op_missing_tensor);
gtl::FlatMap<int64, std::vector<Gradient*>> gradients;
Status s = InitialGradients(vspace, target_tensor_ids, output_gradients,
- tensor_tape_, state.op_tape,
- state.tensor_usage_counts, &gradients);
+ tensor_tape_, state.op_tape, &gradients);
auto cleanup = [this, &state]() {
if (!persistent_) {
// Release all backprop functions
diff --git a/tensorflow/compiler/aot/compile.cc b/tensorflow/compiler/aot/compile.cc
index 31044ff85d..bbc35da2ef 100644
--- a/tensorflow/compiler/aot/compile.cc
+++ b/tensorflow/compiler/aot/compile.cc
@@ -44,7 +44,7 @@ namespace {
// Compiles the XLA computation into executable code.
Status CompileXla(xla::CompileOnlyClient* client,
- const xla::Computation& computation,
+ const xla::XlaComputation& computation,
const xla::cpu::CpuAotCompilationOptions& aot_opts,
CompileResult* compile_result) {
// Retrieves arg and result layouts from the computation.
@@ -62,7 +62,7 @@ Status CompileXla(xla::CompileOnlyClient* client,
for (int i = 0; i < pshape->parameters_size(); ++i) {
arg_layouts.push_back(pshape->mutable_parameters(i));
}
- xla::CompileOnlyClient::AotComputationInstance instance;
+ xla::CompileOnlyClient::AotXlaComputationInstance instance;
instance.computation = &computation;
instance.argument_layouts = std::move(arg_layouts);
instance.result_layout = &pshape->result();
@@ -93,14 +93,14 @@ Status CompileGraph(const GraphDef& graph_def, const tf2xla::Config& config,
xla::CompileOnlyClient* client =
xla::ClientLibrary::GetOrCreateCompileOnlyClient(cpu_platform)
.ValueOrDie();
- xla::Computation computation;
+ xla::XlaComputation computation;
TF_RETURN_IF_ERROR(
ConvertGraphDefToXla(graph_def, config, client, &computation));
if (!flags.out_session_module.empty()) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::SessionModule> module,
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::HloSnapshot> module,
computation.Snapshot());
- // Serialize the SessionModule deterministically so that all the outputs of
- // a tf_library genrule are deterministic.
+ // Serialize the HloSnapshot deterministically so that all the outputs of a
+ // tf_library genrule are deterministic.
string proto;
TF_RET_CHECK(SerializeToStringDeterministic(*module, &proto));
TF_RETURN_IF_ERROR(
diff --git a/tensorflow/compiler/aot/tests/tfcompile_test.cc b/tensorflow/compiler/aot/tests/tfcompile_test.cc
index aa9d968265..27ba42b31f 100644
--- a/tensorflow/compiler/aot/tests/tfcompile_test.cc
+++ b/tensorflow/compiler/aot/tests/tfcompile_test.cc
@@ -525,14 +525,16 @@ TEST(TFCompileTest, HloProfiling) {
auto header = HasSubstr("Execution profile for");
auto total_cycles_profile_line = HasSubstr("[total]");
auto dot_profile_line = HasSubstr(
- "%dot = f32[2,2]{1,0} dot(f32[2,2]{1,0} %arg0, f32[2,2]{1,0} %arg1)");
+ "%dot.0.2 = f32[2,2]{1,0} dot(f32[2,2]{1,0} %arg0.0.0, f32[2,2]{1,0} "
+ "%arg1.0.1)");
auto add_profile_line = HasSubstr(
- "%add = f32[2,2]{1,0} add(f32[2,2]{1,0} %arg0, f32[2,2]{1,0} %arg1)");
+ "%add.0.5 = f32[2,2]{1,0} add(f32[2,2]{1,0} %arg0.0.0, f32[2,2]{1,0} "
+ "%arg1.0.1)");
auto tuple_profile_line = HasSubstr(
- "%tuple.2 = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(f32[2,2]{1,0} %dot, "
- "f32[2,2]{1,0} %add)");
- auto arg0_profile_line = HasSubstr("%arg0 = f32[2,2]{1,0} parameter(0)");
- auto arg1_profile_line = HasSubstr("%arg1 = f32[2,2]{1,0} parameter(1)");
+ "%tuple.0.8 = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(f32[2,2]{1,0} "
+ "%dot.0.2, f32[2,2]{1,0} %add.0.5)");
+ auto arg0_profile_line = HasSubstr("%arg0.0.0 = f32[2,2]{1,0} parameter(0)");
+ auto arg1_profile_line = HasSubstr("%arg1.0.1 = f32[2,2]{1,0} parameter(1)");
hlo_profile_lines.erase(hlo_profile_lines.begin() + 7,
hlo_profile_lines.end());
diff --git a/tensorflow/compiler/tests/reduce_ops_test.py b/tensorflow/compiler/tests/reduce_ops_test.py
index 2c084b04fa..7420724bdb 100644
--- a/tensorflow/compiler/tests/reduce_ops_test.py
+++ b/tensorflow/compiler/tests/reduce_ops_test.py
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
import functools
+import itertools
import numpy as np
from tensorflow.compiler.tests.xla_test import XLATestCase
@@ -155,5 +156,68 @@ class ReduceOpsTest(XLATestCase):
self._testReduction(math_ops.reduce_any, np.any, np.bool, self.BOOL_DATA)
+class ReduceOpPrecisionTest(XLATestCase):
+
+ def _testReduceSum(self,
+ expected_result,
+ dtype,
+ test_inputs,
+ rtol=1e-3,
+ atol=1e-4):
+ """Tests reduce sum on a list of input arrays.
+
+ For each array in test_inputs, check that performing reduce sum on the array
+ produces a value that is close to the expected result.
+
+ Args:
+ expected_result: the expected result.
+ dtype: the data type of the reduce sum operation.
+ test_inputs: a list of input arrays for the reduce sum operation.
+ rtol: the relative error.
+ atol: the absolute error.
+ """
+
+ for test_input in test_inputs:
+ with self.test_session() as sess:
+ with self.test_scope():
+ a = array_ops.placeholder(dtype)
+ index = array_ops.placeholder(dtypes.int32)
+ out = math_ops.reduce_sum(a, index)
+ result = sess.run(out, {
+ a: np.array(test_input, dtype=dtype),
+ index: [0]
+ })
+ # Compare the results using float32 type.
+ self.assertAllClose(
+ np.float32(result),
+ np.float32(expected_result),
+ rtol=rtol,
+ atol=atol)
+
+ def testReduceSumF16(self):
+ """Tests the reduce sum of float16 doesn't lose too much precision."""
+
+ if np.float16 not in self.all_types:
+ return
+
+ f16_max = np.finfo(np.float16).max
+ self._testReduceSum(
+ f16_max, np.float16,
+ itertools.permutations([f16_max, f16_max, f16_max * (-1.0)], 3))
+
+ def testReduceSumBF16(self):
+ """Tests the reduce sum of bfloat16 doesn't lose too much precision."""
+
+ if dtypes.bfloat16.as_numpy_dtype not in self.all_types:
+ return
+
+ bf16_max = np.float32(dtypes.bfloat16.max)
+ f32_max = dtypes.float32.max
+ value = min(bf16_max, f32_max - bf16_max)
+ self._testReduceSum(
+ dtypes.bfloat16.as_numpy_dtype(value), dtypes.bfloat16.as_numpy_dtype,
+ itertools.permutations([bf16_max, value, bf16_max * (-1.0)], 3))
+
+
if __name__ == '__main__':
googletest.main()
diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD
index 942504e6bd..4fca51f54d 100644
--- a/tensorflow/compiler/tf2xla/BUILD
+++ b/tensorflow/compiler/tf2xla/BUILD
@@ -81,7 +81,7 @@ cc_library(
"//tensorflow/compiler/tf2xla/kernels:xla_cpu_only_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/xla/client",
- "//tensorflow/compiler/xla/client:computation",
+ "//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
@@ -168,9 +168,9 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:client_library",
- "//tensorflow/compiler/xla/client:computation",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
@@ -215,7 +215,6 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
"//tensorflow/compiler/xla:status_macros",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:sharding_builder",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc
index b20c1ffc7d..8115a26210 100644
--- a/tensorflow/compiler/tf2xla/graph_compiler.cc
+++ b/tensorflow/compiler/tf2xla/graph_compiler.cc
@@ -51,6 +51,7 @@ Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph,
const std::vector<const XlaExpression*>& expressions,
std::vector<XlaCompiler::Argument>* args) {
auto builder = ctx->builder();
+ auto client = ctx->compiler()->client();
std::vector<bool> compile_time_constant_flags(expressions.size());
TF_RETURN_IF_ERROR(
@@ -72,8 +73,10 @@ Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph,
arg.kind = XlaCompiler::Argument::kConstant;
TF_RET_CHECK(expressions[i]->resource() == nullptr)
<< "Input with resource is not yet implemented.";
+ TF_ASSIGN_OR_RETURN(auto constant_graph, builder->BuildConstantSubGraph(
+ expressions[i]->handle()));
TF_ASSIGN_OR_RETURN(auto literal,
- builder->ComputeConstant(expressions[i]->handle()));
+ client->ComputeConstant(constant_graph));
TF_RETURN_IF_ERROR(
LiteralToHostTensor(*literal, arg.type, &arg.constant_value));
} else {
@@ -212,7 +215,7 @@ Status GraphCompiler::CompileFunctionalNode(Node* n,
TF_RET_CHECK(arguments.size() == expressions.size());
- std::vector<xla::ComputationDataHandle> handles;
+ std::vector<xla::XlaOp> handles;
for (int64 i = 0; i < expressions.size(); ++i) {
if (arguments[i].kind == XlaCompiler::Argument::kConstant) {
continue;
diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD
index 00fd08b1a0..85ab4c41bf 100644
--- a/tensorflow/compiler/tf2xla/kernels/BUILD
+++ b/tensorflow/compiler/tf2xla/kernels/BUILD
@@ -114,8 +114,8 @@ tf_kernel_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:client_library",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client/lib:arithmetic",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/core:framework",
"//tensorflow/core:image_ops_op_lib",
"//tensorflow/core:lib",
@@ -151,7 +151,7 @@ tf_kernel_library(
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/ops:xla_ops",
"//tensorflow/compiler/xla:literal_util",
- "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
@@ -167,7 +167,7 @@ tf_kernel_library(
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/ops:xla_ops",
"//tensorflow/compiler/xla:literal_util",
- "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
@@ -203,8 +203,8 @@ tf_kernel_library(
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla/client:client_library",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client/lib:arithmetic",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core/kernels:argmax_op",
diff --git a/tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc b/tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc
index 5c9f66df10..1e59868621 100644
--- a/tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc
@@ -29,7 +29,7 @@ class AddNOp : public XlaOpKernel {
OP_REQUIRES(ctx, ctx->num_inputs() >= 1,
errors::InvalidArgument("AddN requires at least one argument"));
- xla::ComputationDataHandle sum = ctx->Input(0);
+ xla::XlaOp sum = ctx->Input(0);
for (int i = 1; i < ctx->num_inputs(); ++i) {
sum = ctx->builder()->Add(sum, ctx->Input(i));
}
diff --git a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc
index 931175be11..15e1815a4c 100644
--- a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc
@@ -48,9 +48,9 @@ class FusedBatchNormOp : public XlaOpKernel {
OP_REQUIRES_OK(ctx,
DataTypeToPrimitiveType(ctx->input_type(1), &scale_type));
- xla::ComputationBuilder* builder = ctx->builder();
+ xla::XlaBuilder* builder = ctx->builder();
- xla::ComputationDataHandle input = ctx->Input(0);
+ xla::XlaOp input = ctx->Input(0);
TensorShape input_shape = ctx->InputShape(0);
int feature_index =
@@ -62,7 +62,7 @@ class FusedBatchNormOp : public XlaOpKernel {
input = builder->ConvertElementType(input, scale_type);
if (is_training_) {
- xla::ComputationDataHandle output = builder->BatchNormTraining(
+ xla::XlaOp output = builder->BatchNormTraining(
input, ctx->Input(1), ctx->Input(2), epsilon_, feature_index);
// In training mode, outputs the normalized value as well as the
@@ -79,7 +79,7 @@ class FusedBatchNormOp : public XlaOpKernel {
ctx->SetOutput(3, builder->GetTupleElement(output, 1));
ctx->SetOutput(4, builder->GetTupleElement(output, 2));
} else {
- xla::ComputationDataHandle output = builder->BatchNormInference(
+ xla::XlaOp output = builder->BatchNormInference(
input, ctx->Input(1), ctx->Input(2), ctx->Input(3), ctx->Input(4),
epsilon_, feature_index);
ctx->SetOutput(0, builder->ConvertElementType(output, input_type));
@@ -118,7 +118,7 @@ class FusedBatchNormGradOp : public XlaOpKernel {
}
void Compile(XlaOpKernelContext* ctx) override {
- xla::ComputationBuilder* const b = ctx->builder();
+ xla::XlaBuilder* const b = ctx->builder();
DataType input_dtype = ctx->input_type(0);
DataType scale_dtype = ctx->input_type(2);
@@ -137,11 +137,11 @@ class FusedBatchNormGradOp : public XlaOpKernel {
const int feature_index =
GetTensorFeatureDimIndex(input_dims, data_format_);
- xla::ComputationDataHandle x_backprop;
- xla::ComputationDataHandle scale_backprop;
- xla::ComputationDataHandle offset_backprop;
+ xla::XlaOp x_backprop;
+ xla::XlaOp scale_backprop;
+ xla::XlaOp offset_backprop;
if (is_training_) {
- xla::ComputationDataHandle output =
+ xla::XlaOp output =
b->BatchNormGrad(activations, scale, mean, var, grad_backprop,
epsilon_, feature_index);
diff --git a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc
index 569950c2df..642278ab99 100644
--- a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc
@@ -20,9 +20,8 @@ limitations under the License.
namespace tensorflow {
namespace {
-void BatchToSpace(XlaOpKernelContext* ctx,
- const xla::ComputationDataHandle& input, DataType input_dtype,
- const TensorShape& input_tensor_shape,
+void BatchToSpace(XlaOpKernelContext* ctx, const xla::XlaOp& input,
+ DataType input_dtype, const TensorShape& input_tensor_shape,
gtl::ArraySlice<int64> block_shape,
const xla::Literal& crops) {
const int input_rank = input_tensor_shape.dims();
@@ -46,7 +45,7 @@ void BatchToSpace(XlaOpKernelContext* ctx,
", 2] instead of ",
xla::ShapeUtil::HumanString(crops.shape())));
- xla::ComputationBuilder* b = ctx->builder();
+ xla::XlaBuilder* b = ctx->builder();
const int64 batch_size = input_shape[0];
// Compute the product of the block_shape values.
@@ -73,7 +72,7 @@ void BatchToSpace(XlaOpKernelContext* ctx,
reshaped_shape[block_rank] = batch_size / block_num_elems;
std::copy(input_shape.begin() + 1, input_shape.end(),
reshaped_shape.begin() + block_rank + 1);
- xla::ComputationDataHandle reshaped = b->Reshape(input, reshaped_shape);
+ xla::XlaOp reshaped = b->Reshape(input, reshaped_shape);
// 2. Permute dimensions of `reshaped` to produce `permuted` of shape
// [batch / prod(block_shape),
@@ -91,7 +90,7 @@ void BatchToSpace(XlaOpKernelContext* ctx,
}
std::iota(permutation.begin() + 1 + block_rank * 2, permutation.end(),
1 + block_rank * 2);
- xla::ComputationDataHandle permuted = b->Transpose(reshaped, permutation);
+ xla::XlaOp permuted = b->Transpose(reshaped, permutation);
// 3. Reshape `permuted` to produce `reshaped_permuted` of shape
// [batch / prod(block_shape),
@@ -111,8 +110,7 @@ void BatchToSpace(XlaOpKernelContext* ctx,
std::copy(remainder_shape.begin(), remainder_shape.end(),
reshaped_permuted_shape.begin() + 1 + block_rank);
- xla::ComputationDataHandle reshaped_permuted =
- b->Reshape(permuted, reshaped_permuted_shape);
+ xla::XlaOp reshaped_permuted = b->Reshape(permuted, reshaped_permuted_shape);
// 4. Crop the start and end of dimensions `[1, ..., M]` of
// `reshaped_permuted` according to `crops` to produce the output of shape:
@@ -139,7 +137,7 @@ void BatchToSpace(XlaOpKernelContext* ctx,
"Cropped size must be non-negative: start: ", crop_start,
" end: ", crop_end, " size ", reshaped_permuted_shape[1 + i]));
}
- xla::ComputationDataHandle output =
+ xla::XlaOp output =
b->Slice(reshaped_permuted, start_indices, end_indices, strides);
ctx->SetOutput(0, output);
}
diff --git a/tensorflow/compiler/tf2xla/kernels/bias_ops.cc b/tensorflow/compiler/tf2xla/kernels/bias_ops.cc
index ed33b8ed2e..9d677f4266 100644
--- a/tensorflow/compiler/tf2xla/kernels/bias_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/bias_ops.cc
@@ -60,7 +60,7 @@ class BiasOp : public XlaOpKernel {
"of the input tensor: ",
bias_shape.DebugString(), " vs. ", input_shape.DebugString()));
- xla::ComputationDataHandle result =
+ xla::XlaOp result =
ctx->builder()->Add(ctx->Input(0), ctx->Input(1), {feature_dim});
ctx->SetOutput(0, result);
}
@@ -103,7 +103,7 @@ class BiasAddGradOp : public XlaOpKernel {
std::iota(reduce_dims.begin(), reduce_dims.begin() + feature_dim, 0);
std::iota(reduce_dims.begin() + feature_dim, reduce_dims.end(),
feature_dim + 1);
- xla::ComputationBuilder* const b = ctx->builder();
+ xla::XlaBuilder* const b = ctx->builder();
const DataType accumulation_type =
XlaHelpers::SumAccumulationType(input_type(0));
auto converted =
diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
index 2436a6074a..f04cde878e 100644
--- a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/client_library.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
@@ -34,14 +34,13 @@ namespace {
class NAME##Op : public XlaBinaryOp { \
public: \
explicit NAME##Op(OpKernelConstruction* ctx) : XlaBinaryOp(ctx) {} \
- xla::ComputationDataHandle Computation( \
- XlaOpKernelContext* ctx, const xla::ComputationDataHandle& lhs, \
- const gtl::ArraySlice<int64>& lhs_shape, \
- const xla::ComputationDataHandle& rhs, \
+ xla::XlaOp Computation( \
+ XlaOpKernelContext* ctx, const xla::XlaOp& lhs, \
+ const gtl::ArraySlice<int64>& lhs_shape, const xla::XlaOp& rhs, \
const gtl::ArraySlice<int64>& rhs_shape, \
const BCast& broadcast_helper, \
const std::vector<int64>& extend_dimensions) override { \
- xla::ComputationBuilder* b = ctx->builder(); \
+ xla::XlaBuilder* b = ctx->builder(); \
return HLO; \
} \
}; \
@@ -63,11 +62,8 @@ XLA_MAKE_BINARY(Complex, b->Complex(lhs, rhs, extend_dimensions));
// } else {
// return x / y;
// }
-static xla::ComputationDataHandle FloorDivImpl(xla::ComputationBuilder* b,
- DataType dtype,
- xla::ComputationDataHandle x,
- xla::ComputationDataHandle y,
- const BCast& broadcast_helper) {
+static xla::XlaOp FloorDivImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
+ xla::XlaOp y, const BCast& broadcast_helper) {
std::tie(x, y) = XlaBinaryOp::Broadcast(b, x, y, broadcast_helper);
auto zero = XlaHelpers::Zero(b, dtype);
auto one = XlaHelpers::One(b, dtype);
@@ -87,11 +83,8 @@ XLA_MAKE_BINARY(FloorDiv,
// Implementation of FloorMod. Pseudo-code:
// T trunc_mod = std::fmod(x, y);
// return (x < T(0)) == (y < T(0)) ? trunc_mod : std::fmod(trunc_mod + y, y);
-static xla::ComputationDataHandle FloorModImpl(xla::ComputationBuilder* b,
- DataType dtype,
- xla::ComputationDataHandle x,
- xla::ComputationDataHandle y,
- const BCast& broadcast_helper) {
+static xla::XlaOp FloorModImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
+ xla::XlaOp y, const BCast& broadcast_helper) {
std::tie(x, y) = XlaBinaryOp::Broadcast(b, x, y, broadcast_helper);
auto zero = XlaHelpers::Zero(b, dtype);
auto same_sign = b->Eq(b->Lt(x, zero), b->Lt(y, zero));
@@ -127,8 +120,7 @@ XLA_MAKE_BINARY(SqrtGrad,
XlaHelpers::FloatLiteral(b, input_type(0), 0.5)),
lhs, extend_dimensions));
-static xla::ComputationDataHandle Square(xla::ComputationBuilder* builder,
- const xla::ComputationDataHandle& x) {
+static xla::XlaOp Square(xla::XlaBuilder* builder, const xla::XlaOp& x) {
return builder->Mul(x, x);
}
@@ -175,11 +167,11 @@ class ApproximateEqualOp : public XlaOpKernel {
// Computes the max of the scalar input x and 0.
void Compile(XlaOpKernelContext* ctx) override {
- xla::ComputationBuilder* b = ctx->builder();
+ xla::XlaBuilder* b = ctx->builder();
auto abs = b->Abs(b->Sub(ctx->Input(0), ctx->Input(1)));
auto abs_shape = b->GetShape(abs);
OP_REQUIRES_OK(ctx, abs_shape.status());
- auto abs_type = abs_shape.ValueOrDie()->element_type();
+ auto abs_type = abs_shape.ValueOrDie().element_type();
auto result = b->Lt(
abs, b->ConvertElementType(b->ConstantR0<float>(tolerance_), abs_type));
ctx->SetOutput(0, result);
diff --git a/tensorflow/compiler/tf2xla/kernels/cast_op.cc b/tensorflow/compiler/tf2xla/kernels/cast_op.cc
index c52b2dcb7e..e9d98c7685 100644
--- a/tensorflow/compiler/tf2xla/kernels/cast_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/cast_op.cc
@@ -33,9 +33,9 @@ class CastOp : public XlaOpKernel {
}
void Compile(XlaOpKernelContext* ctx) override {
- xla::ComputationBuilder* builder = ctx->builder();
- xla::ComputationDataHandle input = ctx->Input(0);
- xla::ComputationDataHandle output;
+ xla::XlaBuilder* builder = ctx->builder();
+ xla::XlaOp input = ctx->Input(0);
+ xla::XlaOp output;
if (src_dtype_ == dst_dtype_) {
output = input;
@@ -72,9 +72,9 @@ class BitcastOp : public XlaOpKernel {
}
void Compile(XlaOpKernelContext* ctx) override {
- xla::ComputationBuilder* builder = ctx->builder();
- xla::ComputationDataHandle input = ctx->Input(0);
- xla::ComputationDataHandle output;
+ xla::XlaBuilder* builder = ctx->builder();
+ xla::XlaOp input = ctx->Input(0);
+ xla::XlaOp output;
if (src_dtype_ == dst_dtype_) {
output = input;
diff --git a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc
index 545aa364f9..835a7f5689 100644
--- a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc
@@ -34,7 +34,7 @@ class CategoricalOp : public XlaOpKernel {
void Compile(XlaOpKernelContext* ctx) override {
// Get the logits
- const xla::ComputationDataHandle& logits = ctx->Input(0);
+ const xla::XlaOp& logits = ctx->Input(0);
TensorShape logits_shape = ctx->InputShape(0);
int64 num_samples;
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &num_samples));
@@ -56,7 +56,7 @@ class CategoricalOp : public XlaOpKernel {
const int64 batch_size = logits_shape.dim_size(0);
const int64 num_classes = logits_shape.dim_size(1);
- xla::ComputationBuilder* builder = ctx->builder();
+ xla::XlaBuilder* builder = ctx->builder();
std::array<int64, 3> uniform_shape_array = {
{batch_size, num_samples, num_classes}};
@@ -78,7 +78,7 @@ class CategoricalOp : public XlaOpKernel {
/*broadcast_dimensions=*/{0, 2});
TensorShape softmax_shape(uniform_shape_array);
- xla::ComputationDataHandle argmax;
+ xla::XlaOp argmax;
OP_REQUIRES_OK(
ctx,
XlaHelpers::ArgMax(builder, ctx, softmax_entries, softmax_shape,
diff --git a/tensorflow/compiler/tf2xla/kernels/clip_by_value_op.cc b/tensorflow/compiler/tf2xla/kernels/clip_by_value_op.cc
index fdf75be7b1..a00bc912f9 100644
--- a/tensorflow/compiler/tf2xla/kernels/clip_by_value_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/clip_by_value_op.cc
@@ -29,7 +29,7 @@ class ClipByValueOp : public XlaOpKernel {
const TensorShape min_shape = ctx->InputShape(1);
const TensorShape max_shape = ctx->InputShape(2);
- xla::ComputationBuilder* builder = ctx->builder();
+ xla::XlaBuilder* builder = ctx->builder();
auto input = ctx->Input(0);
auto min = ctx->Input(1);
auto max = ctx->Input(2);
diff --git a/tensorflow/compiler/tf2xla/kernels/concat_op.cc b/tensorflow/compiler/tf2xla/kernels/concat_op.cc
index 1a246e8df9..78285affa1 100644
--- a/tensorflow/compiler/tf2xla/kernels/concat_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/concat_op.cc
@@ -54,7 +54,7 @@ class ConcatBaseOp : public XlaOpKernel {
// TODO(annarev): add a helper to support int64 input.
const int32 concat_dim = literal.Get<int>({});
- std::vector<xla::ComputationDataHandle> values;
+ std::vector<xla::XlaOp> values;
std::vector<TensorShape> shapes;
OP_REQUIRES_OK(ctx, ctx->InputList("values", &values, &shapes));
const int N = values.size();
@@ -70,13 +70,13 @@ class ConcatBaseOp : public XlaOpKernel {
"[",
-input_dims, ", ", input_dims, "), but got ", concat_dim));
- // Make a vector holding the ComputationDataHandles for each of
- // the inputs that has non-zero elements.
- std::vector<xla::ComputationDataHandle> input_data;
+ // Make a vector holding the XlaOp for each of the inputs that has non-zero
+ // elements.
+ std::vector<xla::XlaOp> input_data;
int output_concat_dim = 0;
const bool input_is_scalar = IsLegacyScalar(input_shape);
for (int i = 0; i < N; ++i) {
- xla::ComputationDataHandle handle = values[i];
+ xla::XlaOp handle = values[i];
const TensorShape& in_shape = shapes[i];
const bool in_is_scalar = IsLegacyScalar(in_shape);
OP_REQUIRES(
diff --git a/tensorflow/compiler/tf2xla/kernels/const_op.cc b/tensorflow/compiler/tf2xla/kernels/const_op.cc
index 8f78b4c8f9..59d06c654d 100644
--- a/tensorflow/compiler/tf2xla/kernels/const_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/const_op.cc
@@ -45,7 +45,7 @@ class ConstOp : public XlaOpKernel {
ctx->SetInvalidOutput(0);
return;
}
- xla::ComputationBuilder* b = ctx->builder();
+ xla::XlaBuilder* b = ctx->builder();
// To avoid blowups for large constants filled with the same value,
// recognize that case and emit a scalar broadcast instead.
diff --git a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc
index c0ee0c9c2e..627bad12f3 100644
--- a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc
@@ -47,9 +47,8 @@ TensorShape ExpandedFilterShapeForDepthwiseConvolution(
}
// Broadcast zeros to ExpandedFilterShapeForDepthwiseConvolution.
-xla::ComputationDataHandle CreateExpandedZero(
- const TensorShape& filter_shape, DataType dtype,
- xla::ComputationBuilder* builder) {
+xla::XlaOp CreateExpandedZero(const TensorShape& filter_shape, DataType dtype,
+ xla::XlaBuilder* builder) {
TensorShape expanded_filter_shape =
ExpandedFilterShapeForDepthwiseConvolution(filter_shape);
return builder->Broadcast(XlaHelpers::Zero(builder, dtype),
@@ -87,8 +86,8 @@ xla::ComputationDataHandle CreateExpandedZero(
//
// Finally compare A and broadcasted B in dimension 2 amd return the result at
// the beginning of the comment.
-xla::ComputationDataHandle CreateExpandedFilterMask(
- const TensorShape& filter_shape, xla::ComputationBuilder* builder) {
+xla::XlaOp CreateExpandedFilterMask(const TensorShape& filter_shape,
+ xla::XlaBuilder* builder) {
TensorShape expanded_filter_shape =
ExpandedFilterShapeForDepthwiseConvolution(filter_shape);
int64 depthwise_multiplier = filter_shape.dim_size(filter_shape.dims() - 1);
@@ -96,11 +95,11 @@ xla::ComputationDataHandle CreateExpandedFilterMask(
// Create a M sized linspace and an M*N sized linspace that will be
// broadcasted into perpendicular dimensions and compared.
- xla::ComputationDataHandle input_feature_iota;
+ xla::XlaOp input_feature_iota;
// DT_INT32 Iota will always return status::OK().
TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32, input_feature,
&input_feature_iota));
- xla::ComputationDataHandle expanded_feature_iota;
+ xla::XlaOp expanded_feature_iota;
TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32,
input_feature * depthwise_multiplier,
&expanded_feature_iota));
@@ -126,10 +125,10 @@ xla::ComputationDataHandle CreateExpandedFilterMask(
// Expands a filter of shape [H, W, ..., M, N] to [H, W, ..., M, M*N] by adding
// zeros for the cross-depth filters. Used to build a depthwise convolution.
-xla::ComputationDataHandle ExpandFilterForDepthwiseConvolution(
- const TensorShape& filter_shape, DataType dtype,
- const xla::ComputationDataHandle& filter,
- xla::ComputationBuilder* builder) {
+xla::XlaOp ExpandFilterForDepthwiseConvolution(const TensorShape& filter_shape,
+ DataType dtype,
+ const xla::XlaOp& filter,
+ xla::XlaBuilder* builder) {
int64 depthwise_multiplier = filter_shape.dim_size(filter_shape.dims() - 1);
int64 input_feature = filter_shape.dim_size(filter_shape.dims() - 2);
TensorShape expanded_filter_shape =
@@ -156,10 +155,11 @@ xla::ComputationDataHandle ExpandFilterForDepthwiseConvolution(
}
// Inverse of ExpandFilterForDepthwiseConvolution.
-xla::ComputationDataHandle ContractFilterForDepthwiseBackprop(
- XlaOpKernelContext* ctx, const TensorShape& filter_shape, DataType dtype,
- const xla::ComputationDataHandle& filter_backprop,
- xla::ComputationBuilder* builder) {
+xla::XlaOp ContractFilterForDepthwiseBackprop(XlaOpKernelContext* ctx,
+ const TensorShape& filter_shape,
+ DataType dtype,
+ const xla::XlaOp& filter_backprop,
+ xla::XlaBuilder* builder) {
TensorShape expanded_filter_shape =
ExpandedFilterShapeForDepthwiseConvolution(filter_shape);
auto masked_expanded_filter = builder->Select(
@@ -248,9 +248,9 @@ class ConvOp : public XlaOpKernel {
"input and filter must have the same depth: ", in_depth,
" vs ", input_shape.dim_size(feature_dim)));
- xla::ComputationBuilder* b = ctx->builder();
+ xla::XlaBuilder* b = ctx->builder();
- xla::ComputationDataHandle filter = ctx->Input(1);
+ xla::XlaOp filter = ctx->Input(1);
TensorShape expanded_filter_shape = filter_shape;
if (depthwise_) {
filter = ExpandFilterForDepthwiseConvolution(
@@ -288,7 +288,7 @@ class ConvOp : public XlaOpKernel {
&unused_output_size, &padding[i].first, &padding[i].second));
}
- xla::ComputationDataHandle conv =
+ xla::XlaOp conv =
b->ConvGeneralDilated(ctx->Input(0), filter, window_strides, padding,
lhs_dilation, rhs_dilation, dims);
ctx->SetOutput(0, conv);
@@ -391,7 +391,7 @@ class ConvBackpropInputOp : public XlaOpKernel {
expanded_filter_shape, out_backprop_shape, dilations_,
strides_, padding_, data_format_, &dims));
- xla::ComputationBuilder* b = ctx->builder();
+ xla::XlaBuilder* b = ctx->builder();
auto filter = ctx->Input(1);
auto out_backprop = ctx->Input(2);
@@ -435,12 +435,11 @@ class ConvBackpropInputOp : public XlaOpKernel {
}
// Mirror the filter in the spatial dimensions.
- xla::ComputationDataHandle mirrored_weights =
- b->Rev(filter, kernel_spatial_dims);
+ xla::XlaOp mirrored_weights = b->Rev(filter, kernel_spatial_dims);
// activation gradients
// = gradients (with padding and dilation) <conv> mirrored_weights
- xla::ComputationDataHandle in_backprop = b->ConvGeneralDilated(
+ xla::XlaOp in_backprop = b->ConvGeneralDilated(
out_backprop, mirrored_weights, /*window_strides=*/ones, padding,
lhs_dilation, rhs_dilation, dnums);
@@ -546,9 +545,9 @@ class ConvBackpropFilterOp : public XlaOpKernel {
expanded_filter_shape, out_backprop_shape, dilations_,
strides_, padding_, data_format_, &dims));
- xla::ComputationBuilder* b = ctx->builder();
- xla::ComputationDataHandle activations = ctx->Input(0);
- xla::ComputationDataHandle gradients = ctx->Input(2);
+ xla::XlaBuilder* b = ctx->builder();
+ xla::XlaOp activations = ctx->Input(0);
+ xla::XlaOp gradients = ctx->Input(2);
// The filter gradients are computed by a convolution of the input
// activations and the output gradients, with some appropriate padding.
diff --git a/tensorflow/compiler/tf2xla/kernels/cross_op.cc b/tensorflow/compiler/tf2xla/kernels/cross_op.cc
index 3df8c00f1b..7fcd4170fb 100644
--- a/tensorflow/compiler/tf2xla/kernels/cross_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/cross_op.cc
@@ -53,7 +53,7 @@ class CrossOp : public XlaOpKernel {
}
std::vector<int64> strides(in0_shape.dims(), 1);
- xla::ComputationBuilder* b = ctx->builder();
+ xla::XlaBuilder* b = ctx->builder();
auto in0 = ctx->Input(0);
auto in1 = ctx->Input(1);
starts.back() = 0;
diff --git a/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc b/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc
index 0cf03ceb94..01aa1a83e7 100644
--- a/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc
@@ -22,7 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/client_library.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/util/bcast.h"
@@ -75,7 +75,7 @@ void XlaBinaryOp::Compile(XlaOpKernelContext* ctx) {
}
// Call virtual method to emit the computation.
- xla::ComputationDataHandle output =
+ xla::XlaOp output =
Computation(ctx, lhs_handle, lhs_shape.dim_sizes(), rhs_handle,
rhs_shape.dim_sizes(), bcast, extend_dimension);
@@ -85,11 +85,9 @@ void XlaBinaryOp::Compile(XlaOpKernelContext* ctx) {
ctx->SetOutput(0, output);
}
-/* static */ std::pair<xla::ComputationDataHandle, xla::ComputationDataHandle>
-XlaBinaryOp::Broadcast(xla::ComputationBuilder* builder,
- const xla::ComputationDataHandle& lhs,
- const xla::ComputationDataHandle& rhs,
- const BCast& broadcast_helper) {
+/* static */ std::pair<xla::XlaOp, xla::XlaOp> XlaBinaryOp::Broadcast(
+ xla::XlaBuilder* builder, const xla::XlaOp& lhs, const xla::XlaOp& rhs,
+ const BCast& broadcast_helper) {
// Manually construct the broadcasting since MapN does not do
// automatic broadcasting. The bcast helper ensures that
// lhs.reshape(bcast.x_reshape()).broadcast(bcast.x_bcast()) and
diff --git a/tensorflow/compiler/tf2xla/kernels/cwise_ops.h b/tensorflow/compiler/tf2xla/kernels/cwise_ops.h
index 5bc1d5fb1f..4f92dbc874 100644
--- a/tensorflow/compiler/tf2xla/kernels/cwise_ops.h
+++ b/tensorflow/compiler/tf2xla/kernels/cwise_ops.h
@@ -20,7 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/xla/client/client_library.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/util/bcast.h"
@@ -30,7 +30,7 @@ namespace tensorflow {
// inputs that can be broadcast to the same shape. The base class
// contains pure virtual methods to override: description is a textual
// description of the operation; and Computation adds the
-// implementation of the operation to a xla::ComputationBuilder. For most
+// implementation of the operation to a xla::XlaBuilder. For most
// arithmetic Ops XLA handles the broadcasting automatically given the input
// tensors.
class XlaBinaryOp : public XlaOpKernel {
@@ -55,10 +55,9 @@ class XlaBinaryOp : public XlaOpKernel {
// higher-rank input should be matched when broadcasting the
// lower-rank input. See comment below and the documentation on broadcasting
// in the XLA documentation.
- virtual xla::ComputationDataHandle Computation(
- XlaOpKernelContext* ctx, const xla::ComputationDataHandle& lhs,
- const gtl::ArraySlice<int64>& lhs_shape,
- const xla::ComputationDataHandle& rhs,
+ virtual xla::XlaOp Computation(
+ XlaOpKernelContext* ctx, const xla::XlaOp& lhs,
+ const gtl::ArraySlice<int64>& lhs_shape, const xla::XlaOp& rhs,
const gtl::ArraySlice<int64>& rhs_shape, const BCast& broadcast_helper,
const std::vector<int64>& extend_dimensions) = 0;
@@ -67,11 +66,9 @@ class XlaBinaryOp : public XlaOpKernel {
// Helper function that performs the broadcasting described by
// 'broadcast_helper', yielding arguments 'lhs' and 'rhs' that have the same
// shape.
- static std::pair<xla::ComputationDataHandle, xla::ComputationDataHandle>
- Broadcast(xla::ComputationBuilder* builder,
- const xla::ComputationDataHandle& lhs,
- const xla::ComputationDataHandle& rhs,
- const BCast& broadcast_helper);
+ static std::pair<xla::XlaOp, xla::XlaOp> Broadcast(
+ xla::XlaBuilder* builder, const xla::XlaOp& lhs, const xla::XlaOp& rhs,
+ const BCast& broadcast_helper);
};
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc
index 96d7809f79..23243f6246 100644
--- a/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc
@@ -50,8 +50,8 @@ class DepthToSpaceOp : public XlaOpKernel {
const gtl::InlinedVector<int64, 4> input_shape =
input_tensor_shape.dim_sizes();
- xla::ComputationBuilder* b = ctx->builder();
- xla::ComputationDataHandle input = ctx->Input(0);
+ xla::XlaBuilder* b = ctx->builder();
+ xla::XlaOp input = ctx->Input(0);
int feature_dim = GetTensorFeatureDimIndex(input_rank, data_format_);
int num_spatial_dims = GetTensorSpatialDims(input_rank, data_format_);
@@ -130,7 +130,7 @@ class DepthToSpaceOp : public XlaOpKernel {
") is not divisible by square of the block size (",
block_size_, ")"));
- xla::ComputationDataHandle reshaped = b->Reshape(input, reshaped_shape);
+ xla::XlaOp reshaped = b->Reshape(input, reshaped_shape);
// 2. Permute dimensions of `reshaped` to produce
// `permuted_reshaped` of shape:
@@ -141,8 +141,7 @@ class DepthToSpaceOp : public XlaOpKernel {
// input_shape[2],
// block_size_,
// depth / (block_size_ * block_size_)]
- xla::ComputationDataHandle permuted_reshaped =
- b->Transpose(reshaped, transpose_order);
+ xla::XlaOp permuted_reshaped = b->Transpose(reshaped, transpose_order);
// 3. Reshape `permuted_reshaped` to flatten `block_shape` into the
// batch dimension, producing an output tensor of shape:
@@ -152,8 +151,7 @@ class DepthToSpaceOp : public XlaOpKernel {
// input_shape[2] * block_size_,
// depth / (block_size_ * block_size_)]
//
- xla::ComputationDataHandle output =
- b->Reshape(permuted_reshaped, output_shape);
+ xla::XlaOp output = b->Reshape(permuted_reshaped, output_shape);
ctx->SetOutput(0, output);
}
diff --git a/tensorflow/compiler/tf2xla/kernels/diag_op.cc b/tensorflow/compiler/tf2xla/kernels/diag_op.cc
index 765ea922a5..931705ba83 100644
--- a/tensorflow/compiler/tf2xla/kernels/diag_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/diag_op.cc
@@ -25,10 +25,10 @@ namespace tensorflow {
namespace {
// Create a diagonal / batch diagonal matrix with 'input' on the diagonal.
-xla::StatusOr<xla::ComputationDataHandle> CreateDiagonal(
- const xla::ComputationDataHandle& input, int64 last_dim_size,
+xla::StatusOr<xla::XlaOp> CreateDiagonal(
+ const xla::XlaOp& input, int64 last_dim_size,
tensorflow::gtl::ArraySlice<int64> other_dims, XlaOpKernelContext* ctx,
- xla::ComputationBuilder* builder) {
+ xla::XlaBuilder* builder) {
// Create two matrices that have the following forms, and compare them:
//
// [[0, 0, 0, 0] [[0, 1, 2, 3]
@@ -38,12 +38,11 @@ xla::StatusOr<xla::ComputationDataHandle> CreateDiagonal(
//
// This produces a predicate matrix of the right size, with "true" on the
// diagonal.
- xla::ComputationDataHandle iota;
+ xla::XlaOp iota;
TF_RETURN_IF_ERROR(
XlaHelpers::Iota(builder, DataType::DT_INT32, last_dim_size, &iota));
- xla::ComputationDataHandle iota_broadcast =
- builder->Broadcast(iota, {last_dim_size});
- xla::ComputationDataHandle mask = builder->Eq(iota_broadcast, iota, {0});
+ xla::XlaOp iota_broadcast = builder->Broadcast(iota, {last_dim_size});
+ xla::XlaOp mask = builder->Eq(iota_broadcast, iota, {0});
// If this is a batched diagonal, broadcast the mask across the other
// dimensions.
@@ -65,8 +64,7 @@ xla::StatusOr<xla::ComputationDataHandle> CreateDiagonal(
std::vector<int64> broadcast_dims(other_dims.begin(), other_dims.end());
broadcast_dims.push_back(1LL);
broadcast_dims.push_back(last_dim_size);
- xla::ComputationDataHandle input_broadcast =
- builder->Reshape(input, broadcast_dims);
+ xla::XlaOp input_broadcast = builder->Reshape(input, broadcast_dims);
broadcast_dims[broadcast_dims.size() - 2] = last_dim_size;
xla::PrimitiveType element_type;
@@ -74,7 +72,7 @@ xla::StatusOr<xla::ComputationDataHandle> CreateDiagonal(
DataTypeToPrimitiveType(ctx->input_type(0), &element_type));
auto broadcast_shape =
xla::ShapeUtil::MakeShape(element_type, broadcast_dims);
- xla::ComputationDataHandle zeros = Zeros(builder, broadcast_shape);
+ xla::XlaOp zeros = Zeros(builder, broadcast_shape);
input_broadcast = builder->Add(input_broadcast, zeros);
return builder->Select(mask, input_broadcast, zeros);
@@ -85,7 +83,7 @@ class DiagOp : public XlaOpKernel {
explicit DiagOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
- xla::ComputationBuilder* builder = ctx->builder();
+ xla::XlaBuilder* builder = ctx->builder();
OP_REQUIRES(ctx, ctx->num_inputs() >= 1,
errors::InvalidArgument("Diag op must have at an input"));
@@ -96,7 +94,7 @@ class DiagOp : public XlaOpKernel {
errors::InvalidArgument("Expected 1 <= dims, got shape ",
input_shape.DebugString()));
- xla::ComputationDataHandle input = ctx->Input(0);
+ xla::XlaOp input = ctx->Input(0);
// Picture:
// tf.diag([1, 2, 3, 4]) ==> [[1, 0, 0, 0]
@@ -112,7 +110,7 @@ class DiagOp : public XlaOpKernel {
auto diag_or_status =
CreateDiagonal(input, size, /*other_dims=*/{}, ctx, builder);
OP_REQUIRES_OK(ctx, diag_or_status.status());
- xla::ComputationDataHandle diag = diag_or_status.ValueOrDie();
+ xla::XlaOp diag = diag_or_status.ValueOrDie();
// Reshapes to the final shape.
std::vector<int64> new_dims(dims.size() * 2);
@@ -131,7 +129,7 @@ class DiagPartOp : public XlaOpKernel {
explicit DiagPartOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
- xla::ComputationBuilder* builder = ctx->builder();
+ xla::XlaBuilder* builder = ctx->builder();
const TensorShape input_shape = ctx->InputShape(0);
auto dims = input_shape.dim_sizes();
@@ -158,7 +156,7 @@ class DiagPartOp : public XlaOpKernel {
new_dims.push_back(dims[i]);
}
- xla::ComputationDataHandle diag = ctx->Input(0);
+ xla::XlaOp diag = ctx->Input(0);
// TODO(b/30878775): use Slice with strides when supported, in place of
// the Pad -> Reshape -> Slice.
@@ -199,7 +197,7 @@ class MatrixDiagOp : public XlaOpKernel {
explicit MatrixDiagOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
- xla::ComputationBuilder* builder = ctx->builder();
+ xla::XlaBuilder* builder = ctx->builder();
OP_REQUIRES(ctx, ctx->num_inputs() >= 1,
errors::InvalidArgument("MatrixDiag op must have at an input"));
@@ -210,7 +208,7 @@ class MatrixDiagOp : public XlaOpKernel {
errors::InvalidArgument("Expected 1 <= dims, got shape ",
input_shape.DebugString()));
- xla::ComputationDataHandle diag = ctx->Input(0);
+ xla::XlaOp diag = ctx->Input(0);
int last_dim = dims.size() - 1;
int64 last_dim_size = input_shape.dim_size(last_dim);
@@ -232,7 +230,7 @@ class MatrixDiagPartOp : public XlaOpKernel {
explicit MatrixDiagPartOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
- xla::ComputationBuilder* builder = ctx->builder();
+ xla::XlaBuilder* builder = ctx->builder();
const TensorShape input_shape = ctx->InputShape(0);
auto dims = input_shape.dim_sizes();
@@ -241,7 +239,7 @@ class MatrixDiagPartOp : public XlaOpKernel {
errors::InvalidArgument("Expected 2 <= dims, got shape ",
input_shape.DebugString()));
- xla::ComputationDataHandle diag = ctx->Input(0);
+ xla::XlaOp diag = ctx->Input(0);
int last_dim = dims.size() - 1;
int64 last_dim_size = dims[last_dim];
diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc
index 800ef5ab98..0419de78b2 100644
--- a/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
@@ -57,7 +57,7 @@ class DynamicUpdateSliceOp : public XlaOpKernel {
input_shape.DebugString(), "; update shape is ",
update_shape.DebugString()));
- xla::ComputationDataHandle result = ctx->builder()->DynamicUpdateSlice(
+ xla::XlaOp result = ctx->builder()->DynamicUpdateSlice(
ctx->Input(0), ctx->Input(1), ctx->Input(2));
ctx->SetOutput(0, result);
}
diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc
index f2cd21ffb9..dd4a169087 100644
--- a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc
@@ -56,7 +56,7 @@ class DynamicStitchOp : public XlaOpKernel {
std::vector<xla::Literal> indices_input;
OP_REQUIRES_OK(ctx, ctx->ConstantInputList("indices", &indices_input));
- std::vector<xla::ComputationDataHandle> data;
+ std::vector<xla::XlaOp> data;
std::vector<TensorShape> data_shapes;
OP_REQUIRES_OK(ctx, ctx->InputList("data", &data, &data_shapes));
@@ -136,7 +136,7 @@ class DynamicStitchOp : public XlaOpKernel {
// Look up all the children expressions that represent the data
// inputs.
- std::vector<xla::ComputationDataHandle> input(indices.size());
+ std::vector<xla::XlaOp> input(indices.size());
for (int input_num = 0; input_num < indices.size(); input_num++) {
TensorShape new_shape;
// first reshaped dimension is the number of indices for this input.
@@ -166,7 +166,7 @@ class DynamicStitchOp : public XlaOpKernel {
for (int d = indices0_shape.dims(); d < data0_shape.dims(); d++) {
slice_limit[1 + d - indices0_shape.dims()] = data0_shape.dim_size(d);
}
- std::vector<xla::ComputationDataHandle> to_concat(number_of_indices);
+ std::vector<xla::XlaOp> to_concat(number_of_indices);
for (int index_num = 0; index_num < number_of_indices; index_num++) {
const auto& expression = input[src_input_vector[index_num]];
// Take the appropriate slice of data.
diff --git a/tensorflow/compiler/tf2xla/kernels/elu_op.cc b/tensorflow/compiler/tf2xla/kernels/elu_op.cc
index 2fd27c5ca7..ed7462c166 100644
--- a/tensorflow/compiler/tf2xla/kernels/elu_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/elu_op.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/types.h"
@@ -32,7 +32,7 @@ class EluOp : public XlaOpKernel {
explicit EluOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
// Computes the max of the scalar input x and 0.
void Compile(XlaOpKernelContext* ctx) override {
- xla::ComputationBuilder* b = ctx->builder();
+ xla::XlaBuilder* b = ctx->builder();
const auto zero = XlaHelpers::Zero(b, input_type(0));
const auto one = XlaHelpers::One(b, input_type(0));
const auto pred = b->Gt(ctx->Input(0), zero);
@@ -47,7 +47,7 @@ class EluGradOp : public XlaOpKernel {
// Return the lhs (incoming gradient) if the rhs (input feature) > 0,
// otherwise return lhs * (1 + rhs).
void Compile(XlaOpKernelContext* ctx) override {
- xla::ComputationBuilder* b = ctx->builder();
+ xla::XlaBuilder* b = ctx->builder();
const auto zero = XlaHelpers::Zero(b, input_type(0));
const auto one = XlaHelpers::One(b, input_type(0));
const auto grad = ctx->Input(0);
@@ -66,7 +66,7 @@ class SeluOp : public XlaOpKernel {
explicit SeluOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
// Computes the max of the scalar input x and 0.
void Compile(XlaOpKernelContext* ctx) override {
- xla::ComputationBuilder* b = ctx->builder();
+ xla::XlaBuilder* b = ctx->builder();
const auto zero = XlaHelpers::Zero(b, input_type(0));
const auto one = XlaHelpers::One(b, input_type(0));
const auto scale = XlaHelpers::FloatLiteral(b, input_type(0),
@@ -86,9 +86,8 @@ class SeluGradOp : public XlaOpKernel {
// Return the lhs (incoming gradient) if the rhs (input feature) > 0,
// otherwise return lhs * (1 + rhs).
void Compile(XlaOpKernelContext* ctx) override {
- xla::ComputationBuilder* b = ctx->builder();
+ xla::XlaBuilder* b = ctx->builder();
const auto zero = XlaHelpers::Zero(b, input_type(0));
- const auto one = XlaHelpers::One(b, input_type(0));
const auto scale = XlaHelpers::FloatLiteral(b, input_type(0),
1.0507009873554804934193349852946);
const auto scale_alpha = XlaHelpers::FloatLiteral(b, input_type(0),
diff --git a/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc b/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc
index b2970eae20..6df01cabbf 100644
--- a/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc
@@ -93,7 +93,7 @@ class ExtractImagePatchesOp : public XlaOpKernel {
input_shape.DebugString()));
const int64 depth = input_shape.dim_size(feature_dim);
- xla::ComputationBuilder* builder = ctx->builder();
+ xla::XlaBuilder* builder = ctx->builder();
// The following code is equivalent to:
// eye = np.eye(kH * kW * D).reshape([kH, kW, D, kH * kW * kD])
@@ -110,7 +110,7 @@ class ExtractImagePatchesOp : public XlaOpKernel {
// Builds an identity matrix as a broadcast equality of iotas.
// iota = np.arange(np.prod(ksize), depth)
// filter = np.equal(np.reshape(iota, [-1, 1]), iota).astype(np.float32)
- xla::ComputationDataHandle iota;
+ xla::XlaOp iota;
TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32,
kernel_size * depth, &iota));
@@ -147,7 +147,7 @@ class ExtractImagePatchesOp : public XlaOpKernel {
&padding[i].first, &padding[i].second));
}
- xla::ComputationDataHandle conv =
+ xla::XlaOp conv =
builder->ConvGeneralDilated(ctx->Input(0), filter, window_strides,
padding, lhs_dilation, rhs_dilation, dims);
ctx->SetOutput(0, conv);
diff --git a/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc b/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc
index 99470d70e7..8f0de0a524 100644
--- a/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc
@@ -44,23 +44,20 @@ void CpuNudge(const float min, const float max, const float quant_min,
}
// An XLA version of CpuNudge().
-void XlaNudge(xla::ComputationBuilder* b, const DataType data_type,
- const xla::ComputationDataHandle& min,
- const xla::ComputationDataHandle& max,
+void XlaNudge(xla::XlaBuilder* b, const DataType data_type,
+ const xla::XlaOp& min, const xla::XlaOp& max,
const float quant_min_value, const float quant_max_value,
- xla::ComputationDataHandle* nudged_min,
- xla::ComputationDataHandle* nudged_max,
- xla::ComputationDataHandle* scale) {
+ xla::XlaOp* nudged_min, xla::XlaOp* nudged_max,
+ xla::XlaOp* scale) {
*scale = b->Div(b->Sub(max, min),
XlaHelpers::FloatLiteral(b, data_type,
quant_max_value - quant_min_value));
- xla::ComputationDataHandle quant_min =
+ xla::XlaOp quant_min =
XlaHelpers::FloatLiteral(b, data_type, quant_min_value);
- xla::ComputationDataHandle zero_point_from_min =
- b->Sub(quant_min, b->Div(min, *scale));
- xla::ComputationDataHandle quant_max =
+ xla::XlaOp zero_point_from_min = b->Sub(quant_min, b->Div(min, *scale));
+ xla::XlaOp quant_max =
XlaHelpers::FloatLiteral(b, data_type, quant_max_value);
- xla::ComputationDataHandle nudged_zero_point =
+ xla::XlaOp nudged_zero_point =
b->Select(b->Le(zero_point_from_min, quant_min), quant_min,
b->Select(b->Ge(zero_point_from_min, quant_max), quant_max,
b->Round(zero_point_from_min)));
@@ -68,22 +65,18 @@ void XlaNudge(xla::ComputationBuilder* b, const DataType data_type,
*nudged_max = b->Mul(b->Sub(quant_max, nudged_zero_point), *scale);
}
-xla::ComputationDataHandle Quantize(
- xla::ComputationBuilder* b, const xla::ComputationDataHandle& input,
- const DataType data_type,
- const xla::ComputationDataHandle& nudged_input_min,
- const xla::ComputationDataHandle& nudged_input_max,
- const xla::ComputationDataHandle& input_scale) {
- xla::ComputationDataHandle one = XlaHelpers::FloatLiteral(b, data_type, 1.0f);
- xla::ComputationDataHandle inv_scale = b->Div(one, input_scale);
- xla::ComputationDataHandle half =
- XlaHelpers::FloatLiteral(b, data_type, 0.5f);
-
- xla::ComputationDataHandle clamped =
- b->Clamp(nudged_input_min, input, nudged_input_max);
- xla::ComputationDataHandle clamped_shifted =
- b->Sub(clamped, nudged_input_min);
- xla::ComputationDataHandle rounded =
+xla::XlaOp Quantize(xla::XlaBuilder* b, const xla::XlaOp& input,
+ const DataType data_type,
+ const xla::XlaOp& nudged_input_min,
+ const xla::XlaOp& nudged_input_max,
+ const xla::XlaOp& input_scale) {
+ xla::XlaOp one = XlaHelpers::FloatLiteral(b, data_type, 1.0f);
+ xla::XlaOp inv_scale = b->Div(one, input_scale);
+ xla::XlaOp half = XlaHelpers::FloatLiteral(b, data_type, 0.5f);
+
+ xla::XlaOp clamped = b->Clamp(nudged_input_min, input, nudged_input_max);
+ xla::XlaOp clamped_shifted = b->Sub(clamped, nudged_input_min);
+ xla::XlaOp rounded =
b->Floor(b->Add(b->Mul(clamped_shifted, inv_scale), half));
return b->Add(b->Mul(rounded, input_scale), nudged_input_min);
}
@@ -111,18 +104,18 @@ class FakeQuantWithMinMaxArgsOp : public XlaOpKernel {
}
void Compile(XlaOpKernelContext* ctx) override {
- xla::ComputationDataHandle input = ctx->Input(0);
+ xla::XlaOp input = ctx->Input(0);
const DataType data_type = ctx->input_type(0);
- xla::ComputationBuilder* b = ctx->builder();
- xla::ComputationDataHandle nudged_input_min =
+ xla::XlaBuilder* b = ctx->builder();
+ xla::XlaOp nudged_input_min =
XlaHelpers::FloatLiteral(b, data_type, nudged_input_min_);
- xla::ComputationDataHandle nudged_input_max =
+ xla::XlaOp nudged_input_max =
XlaHelpers::FloatLiteral(b, data_type, nudged_input_max_);
- xla::ComputationDataHandle input_scale =
+ xla::XlaOp input_scale =
XlaHelpers::FloatLiteral(b, data_type, input_scale_);
- xla::ComputationDataHandle output = Quantize(
- b, input, data_type, nudged_input_min, nudged_input_max, input_scale);
+ xla::XlaOp output = Quantize(b, input, data_type, nudged_input_min,
+ nudged_input_max, input_scale);
ctx->SetOutput(0, output);
}
@@ -159,23 +152,22 @@ class FakeQuantWithMinMaxArgsGradOp : public XlaOpKernel {
}
void Compile(XlaOpKernelContext* ctx) override {
- xla::ComputationDataHandle gradient = ctx->Input(0);
+ xla::XlaOp gradient = ctx->Input(0);
const TensorShape gradient_shape = ctx->InputShape(0);
- xla::ComputationDataHandle input = ctx->Input(1);
+ xla::XlaOp input = ctx->Input(1);
const DataType data_type = ctx->input_type(1);
- xla::ComputationBuilder* b = ctx->builder();
- xla::ComputationDataHandle nudged_input_min =
+ xla::XlaBuilder* b = ctx->builder();
+ xla::XlaOp nudged_input_min =
XlaHelpers::FloatLiteral(b, data_type, nudged_input_min_);
- xla::ComputationDataHandle nudged_input_max =
+ xla::XlaOp nudged_input_max =
XlaHelpers::FloatLiteral(b, data_type, nudged_input_max_);
- xla::ComputationDataHandle between_nudged_min_max =
+ xla::XlaOp between_nudged_min_max =
b->And(b->Le(nudged_input_min, input), b->Le(input, nudged_input_max));
- xla::ComputationDataHandle zeroes = b->Broadcast(
- XlaHelpers::Zero(b, data_type), gradient_shape.dim_sizes());
- xla::ComputationDataHandle output =
- b->Select(between_nudged_min_max, gradient, zeroes);
+ xla::XlaOp zeroes = b->Broadcast(XlaHelpers::Zero(b, data_type),
+ gradient_shape.dim_sizes());
+ xla::XlaOp output = b->Select(between_nudged_min_max, gradient, zeroes);
ctx->SetOutput(0, output);
}
@@ -204,18 +196,18 @@ class FakeQuantWithMinMaxVarsOp : public XlaOpKernel {
}
void Compile(XlaOpKernelContext* ctx) override {
- xla::ComputationDataHandle input = ctx->Input(0);
+ xla::XlaOp input = ctx->Input(0);
const DataType data_type = ctx->input_type(0);
- xla::ComputationDataHandle input_min = ctx->Input(1);
- xla::ComputationDataHandle input_max = ctx->Input(2);
+ xla::XlaOp input_min = ctx->Input(1);
+ xla::XlaOp input_max = ctx->Input(2);
- xla::ComputationBuilder* b = ctx->builder();
- xla::ComputationDataHandle nudged_input_min, nudged_input_max, input_scale;
+ xla::XlaBuilder* b = ctx->builder();
+ xla::XlaOp nudged_input_min, nudged_input_max, input_scale;
XlaNudge(b, data_type, input_min, input_max, quant_min_, quant_max_,
&nudged_input_min, &nudged_input_max, &input_scale);
- xla::ComputationDataHandle output = Quantize(
- b, input, data_type, nudged_input_min, nudged_input_max, input_scale);
+ xla::XlaOp output = Quantize(b, input, data_type, nudged_input_min,
+ nudged_input_max, input_scale);
ctx->SetOutput(0, output);
}
@@ -243,47 +235,43 @@ class FakeQuantWithMinMaxVarsGradOp : public XlaOpKernel {
}
void Compile(XlaOpKernelContext* ctx) override {
- xla::ComputationDataHandle gradient = ctx->Input(0);
+ xla::XlaOp gradient = ctx->Input(0);
const TensorShape gradient_shape = ctx->InputShape(0);
- xla::ComputationDataHandle input = ctx->Input(1);
+ xla::XlaOp input = ctx->Input(1);
const DataType data_type = ctx->input_type(1);
const DataType accumulation_type =
XlaHelpers::SumAccumulationType(data_type);
- xla::ComputationDataHandle input_min = ctx->Input(2);
- xla::ComputationDataHandle input_max = ctx->Input(3);
+ xla::XlaOp input_min = ctx->Input(2);
+ xla::XlaOp input_max = ctx->Input(3);
- xla::ComputationBuilder* b = ctx->builder();
- xla::ComputationDataHandle nudged_input_min, nudged_input_max, input_scale;
+ xla::XlaBuilder* b = ctx->builder();
+ xla::XlaOp nudged_input_min, nudged_input_max, input_scale;
XlaNudge(b, data_type, input_min, input_max, quant_min_, quant_max_,
&nudged_input_min, &nudged_input_max, &input_scale);
- xla::ComputationDataHandle between_nudged_min_max =
+ xla::XlaOp between_nudged_min_max =
b->And(b->Le(nudged_input_min, input), b->Le(input, nudged_input_max));
- xla::ComputationDataHandle zero = XlaHelpers::Zero(b, data_type);
- xla::ComputationDataHandle zeroes =
- b->Broadcast(zero, gradient_shape.dim_sizes());
- xla::ComputationDataHandle output0 =
- b->Select(between_nudged_min_max, gradient, zeroes);
+ xla::XlaOp zero = XlaHelpers::Zero(b, data_type);
+ xla::XlaOp zeroes = b->Broadcast(zero, gradient_shape.dim_sizes());
+ xla::XlaOp output0 = b->Select(between_nudged_min_max, gradient, zeroes);
ctx->SetOutput(0, output0);
- xla::ComputationDataHandle below_min = b->Lt(input, nudged_input_min);
- xla::ComputationDataHandle select1 = b->Select(below_min, gradient, zeroes);
- xla::ComputationDataHandle reduce1 = b->ReduceAll(
+ xla::XlaOp below_min = b->Lt(input, nudged_input_min);
+ xla::XlaOp select1 = b->Select(below_min, gradient, zeroes);
+ xla::XlaOp reduce1 = b->ReduceAll(
XlaHelpers::ConvertElementType(b, select1, accumulation_type),
XlaHelpers::Zero(b, accumulation_type),
*ctx->GetOrCreateAdd(accumulation_type));
- xla::ComputationDataHandle output1 =
- XlaHelpers::ConvertElementType(b, reduce1, data_type);
+ xla::XlaOp output1 = XlaHelpers::ConvertElementType(b, reduce1, data_type);
ctx->SetOutput(1, output1);
- xla::ComputationDataHandle above_max = b->Gt(input, nudged_input_max);
- xla::ComputationDataHandle select2 = b->Select(above_max, gradient, zeroes);
- xla::ComputationDataHandle reduce2 = b->ReduceAll(
+ xla::XlaOp above_max = b->Gt(input, nudged_input_max);
+ xla::XlaOp select2 = b->Select(above_max, gradient, zeroes);
+ xla::XlaOp reduce2 = b->ReduceAll(
XlaHelpers::ConvertElementType(b, select2, accumulation_type),
XlaHelpers::Zero(b, accumulation_type),
*ctx->GetOrCreateAdd(accumulation_type));
- xla::ComputationDataHandle output2 =
- XlaHelpers::ConvertElementType(b, reduce2, data_type);
+ xla::XlaOp output2 = XlaHelpers::ConvertElementType(b, reduce2, data_type);
ctx->SetOutput(2, output2);
}
diff --git a/tensorflow/compiler/tf2xla/kernels/fft_ops.cc b/tensorflow/compiler/tf2xla/kernels/fft_ops.cc
index a4f3c1c3ad..fcb927dab0 100644
--- a/tensorflow/compiler/tf2xla/kernels/fft_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/fft_ops.cc
@@ -62,9 +62,8 @@ class GenericFftOp : public XlaOpKernel {
}
}
- xla::ComputationBuilder* b = ctx->builder();
- xla::ComputationDataHandle fft =
- b->Fft(ctx->Input(0), fft_type_, fft_length);
+ xla::XlaBuilder* b = ctx->builder();
+ xla::XlaOp fft = b->Fft(ctx->Input(0), fft_type_, fft_length);
ctx->SetOutput(0, fft);
}
diff --git a/tensorflow/compiler/tf2xla/kernels/fill_op.cc b/tensorflow/compiler/tf2xla/kernels/fill_op.cc
index eaa13b8dfa..e4467a0fb1 100644
--- a/tensorflow/compiler/tf2xla/kernels/fill_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/fill_op.cc
@@ -48,7 +48,7 @@ class FillOp : public XlaOpKernel {
0, {dims_shape.num_elements()}, &dims_literal));
// Convert the dims literal into a vector that we can pass to
- // ComputationBuilder.
+ // XlaBuilder.
std::vector<int64> broadcast;
broadcast.reserve(dims_literal.shape().dimensions(0));
for (int i = 0; i < dims_literal.shape().dimensions(0); ++i) {
@@ -56,7 +56,7 @@ class FillOp : public XlaOpKernel {
}
// Look up the value input, reshaping to a scalar if it was a
// 'legacy' scalar (secretly a vector).
- xla::ComputationDataHandle data = ctx->Input(1);
+ xla::XlaOp data = ctx->Input(1);
if (value_shape.dims() > 0) {
CHECK_EQ(value_shape.dims(), 1);
data = ctx->builder()->Reshape(data, {});
diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op.cc b/tensorflow/compiler/tf2xla/kernels/gather_op.cc
index 0b79cb0916..d13e25bcdd 100644
--- a/tensorflow/compiler/tf2xla/kernels/gather_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/gather_op.cc
@@ -26,13 +26,11 @@ limitations under the License.
namespace tensorflow {
-Status XlaGather(const xla::ComputationDataHandle& input,
- const TensorShape& input_shape,
- const xla::ComputationDataHandle& indices,
- const TensorShape& indices_shape, int64 axis,
- bool indices_are_nd, DataType dtype, DataType index_type,
- xla::ComputationBuilder* builder,
- xla::ComputationDataHandle* gather_output) {
+Status XlaGather(const xla::XlaOp& input, const TensorShape& input_shape,
+ const xla::XlaOp& indices, const TensorShape& indices_shape,
+ int64 axis, bool indices_are_nd, DataType dtype,
+ DataType index_type, xla::XlaBuilder* builder,
+ xla::XlaOp* gather_output) {
// There is no deep reason why we need this precondition, but this is the only
// combination that is used and tested today.
CHECK(!indices_are_nd || axis == 0);
@@ -153,7 +151,7 @@ class GatherOp : public XlaOpKernel {
explicit GatherOp(OpKernelConstruction* context) : XlaOpKernel(context) {}
void Compile(XlaOpKernelContext* context) override {
- xla::ComputationBuilder* builder = context->builder();
+ xla::XlaBuilder* builder = context->builder();
auto input = context->Input(0);
auto input_shape = context->InputShape(0);
auto indices = context->Input(1);
@@ -182,7 +180,7 @@ class GatherOp : public XlaOpKernel {
OP_REQUIRES(context, index_type == DT_INT32 || index_type == DT_INT64,
errors::InvalidArgument("indices must be int32 or int64"));
- xla::ComputationDataHandle gather;
+ xla::XlaOp gather;
OP_REQUIRES_OK(
context, XlaGather(input, input_shape, indices, indices_shape, axis,
/*indices_are_nd=*/false, input_type(0), index_type,
@@ -220,10 +218,10 @@ class GatherNdOp : public XlaOpKernel {
indices_shape.dim_size(indices_shape.dims() - 1), " vs. ",
params_shape.dims()));
- xla::ComputationBuilder* builder = context->builder();
+ xla::XlaBuilder* builder = context->builder();
auto params = context->Input(0);
auto indices = context->Input(1);
- xla::ComputationDataHandle gather;
+ xla::XlaOp gather;
OP_REQUIRES_OK(context, XlaGather(params, params_shape, indices,
indices_shape, /*axis=*/0,
/*indices_are_nd=*/true, params_type,
diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h b/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h
index f9376f0eab..d898e43b85 100644
--- a/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h
+++ b/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h
@@ -20,7 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/xla/client/client_library.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/util/bcast.h"
@@ -33,13 +33,11 @@ namespace tensorflow {
// If `indices_are_nd` is true, the last dimension of `indices` are treated as
// a multidimensional index values. Otherwise, `indices` is treated as a tensor
// of scalar indices.
-Status XlaGather(const xla::ComputationDataHandle& input,
- const TensorShape& input_shape,
- const xla::ComputationDataHandle& indices,
- const TensorShape& indices_shape, int64 axis,
- bool indices_are_nd, DataType dtype, DataType index_type,
- xla::ComputationBuilder* builder,
- xla::ComputationDataHandle* gather_output);
+Status XlaGather(const xla::XlaOp& input, const TensorShape& input_shape,
+ const xla::XlaOp& indices, const TensorShape& indices_shape,
+ int64 axis, bool indices_are_nd, DataType dtype,
+ DataType index_type, xla::XlaBuilder* builder,
+ xla::XlaOp* gather_output);
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/if_op.cc b/tensorflow/compiler/tf2xla/kernels/if_op.cc
index eefbe55c81..8b9b026643 100644
--- a/tensorflow/compiler/tf2xla/kernels/if_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/if_op.cc
@@ -37,7 +37,7 @@ XlaIfOp::XlaIfOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
// TODO(b/35949885): There is duplication here with the handling of the
// while_op. Refactor the common code out/rework.
void XlaIfOp::Compile(XlaOpKernelContext* ctx) {
- xla::ComputationBuilder* b = ctx->builder();
+ xla::XlaBuilder* b = ctx->builder();
OP_REQUIRES(ctx, cond_type_ == DT_BOOL,
errors::InvalidArgument(
@@ -48,7 +48,7 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) {
VLOG(1) << "Building If: " << input_types_.size() << " inputs";
- std::vector<xla::ComputationDataHandle> inputs(input_types_.size());
+ std::vector<xla::XlaOp> inputs(input_types_.size());
std::vector<XlaCompiler::Argument> arguments(input_types_.size());
for (int i = 0; i < input_types_.size(); ++i) {
XlaCompiler::Argument& arg = arguments[i];
@@ -175,19 +175,19 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) {
"Mismatch in resource of then and else branch for resource ", i));
}
- xla::ComputationDataHandle outputs =
+ xla::XlaOp outputs =
b->Conditional(ctx->Input(0), b->Tuple(inputs), *then_result.computation,
b->Tuple(inputs), *else_result.computation);
// Sets non-variable outputs.
for (int i = 0; i < output_types_.size(); ++i) {
if (ctx->input_type(i) != DT_RESOURCE) {
- xla::ComputationDataHandle output_handle = b->GetTupleElement(outputs, i);
+ xla::XlaOp output_handle = b->GetTupleElement(outputs, i);
if (VLOG_IS_ON(2)) {
LOG(INFO) << "Setting output " << i;
auto shape_or = b->GetShape(output_handle);
if (shape_or.ok()) {
LOG(INFO) << "Shape for output " << i << ": "
- << xla::ShapeUtil::HumanString(*shape_or.ValueOrDie());
+ << xla::ShapeUtil::HumanString(shape_or.ValueOrDie());
} else {
LOG(INFO) << "Shape unknown for output " << i;
}
diff --git a/tensorflow/compiler/tf2xla/kernels/image_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_ops.cc
index 5eeda79a93..1568b33679 100644
--- a/tensorflow/compiler/tf2xla/kernels/image_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/image_ops.cc
@@ -23,10 +23,9 @@ namespace {
// Converts 'input' from RGB format to HSV format.
// 'shape' is the shape of the red/green/blue tensors.
-std::array<xla::ComputationDataHandle, 3> RGBToHSV(
- XlaOpKernelContext* ctx, xla::ComputationBuilder* b,
- const std::array<xla::ComputationDataHandle, 3>& rgb, DataType dtype,
- const TensorShape& shape) {
+std::array<xla::XlaOp, 3> RGBToHSV(XlaOpKernelContext* ctx, xla::XlaBuilder* b,
+ const std::array<xla::XlaOp, 3>& rgb,
+ DataType dtype, const TensorShape& shape) {
auto zero = XlaHelpers::Zero(b, dtype);
auto one = XlaHelpers::One(b, dtype);
@@ -54,12 +53,12 @@ std::array<xla::ComputationDataHandle, 3> RGBToHSV(
}
// Converts 'input' from HSV format to RGB format.
-std::array<xla::ComputationDataHandle, 3> HSVToRGB(
- xla::ComputationBuilder* b,
- const std::array<xla::ComputationDataHandle, 3>& hsv, DataType dtype) {
- xla::ComputationDataHandle hue = hsv[0];
- xla::ComputationDataHandle saturation = hsv[1];
- xla::ComputationDataHandle value = hsv[2];
+std::array<xla::XlaOp, 3> HSVToRGB(xla::XlaBuilder* b,
+ const std::array<xla::XlaOp, 3>& hsv,
+ DataType dtype) {
+ xla::XlaOp hue = hsv[0];
+ xla::XlaOp saturation = hsv[1];
+ xla::XlaOp value = hsv[2];
auto zero = XlaHelpers::Zero(b, dtype);
auto one = XlaHelpers::FloatLiteral(b, dtype, 1.0);
auto two = XlaHelpers::FloatLiteral(b, dtype, 2.0);
@@ -95,16 +94,16 @@ class RGBToHSVOp : public XlaOpKernel {
errors::FailedPrecondition("input must have 3 channels but input has ",
channels, " channels."));
- xla::ComputationBuilder* b = context->builder();
- xla::ComputationDataHandle input = context->Input(0);
+ xla::XlaBuilder* b = context->builder();
+ xla::XlaOp input = context->Input(0);
- xla::ComputationDataHandle red =
+ xla::XlaOp red =
b->SliceInDim(input, /*start_index=*/0, /*limit_index=*/1, /*stride=*/1,
/*dimno=*/channel_dim);
- xla::ComputationDataHandle green =
+ xla::XlaOp green =
b->SliceInDim(input, /*start_index=*/1, /*limit_index=*/2, /*stride=*/1,
/*dimno=*/channel_dim);
- xla::ComputationDataHandle blue =
+ xla::XlaOp blue =
b->SliceInDim(input, /*start_index=*/2, /*limit_index=*/3, /*stride=*/1,
/*dimno=*/channel_dim);
TensorShape channel_shape = input_shape;
@@ -133,15 +132,15 @@ class HSVToRGBOp : public XlaOpKernel {
errors::FailedPrecondition("input must have 3 channels but input has ",
channels, " channels."));
- xla::ComputationBuilder* b = context->builder();
- xla::ComputationDataHandle input = context->Input(0);
- xla::ComputationDataHandle hue =
+ xla::XlaBuilder* b = context->builder();
+ xla::XlaOp input = context->Input(0);
+ xla::XlaOp hue =
b->SliceInDim(input, /*start_index=*/0, /*limit_index=*/1, /*stride=*/1,
/*dimno=*/channel_dim);
- xla::ComputationDataHandle saturation =
+ xla::XlaOp saturation =
b->SliceInDim(input, /*start_index=*/1, /*limit_index=*/2, /*stride=*/1,
/*dimno=*/channel_dim);
- xla::ComputationDataHandle value =
+ xla::XlaOp value =
b->SliceInDim(input, /*start_index=*/2, /*limit_index=*/3, /*stride=*/1,
/*dimno=*/channel_dim);
@@ -174,9 +173,9 @@ class AdjustContrastOpV2 : public XlaOpKernel {
errors::InvalidArgument("contrast_factor must be scalar: ",
factor_shape.DebugString()));
- xla::ComputationBuilder* b = context->builder();
- xla::ComputationDataHandle input = context->Input(0);
- xla::ComputationDataHandle factor = context->Input(1);
+ xla::XlaBuilder* b = context->builder();
+ xla::XlaOp input = context->Input(0);
+ xla::XlaOp factor = context->Input(1);
DataType type = context->input_type(0);
@@ -221,19 +220,19 @@ class AdjustSaturationOp : public XlaOpKernel {
errors::InvalidArgument("input must have 3 channels but instead has ",
channels, " channels."));
- xla::ComputationBuilder* b = context->builder();
- xla::ComputationDataHandle input = context->Input(0);
- xla::ComputationDataHandle scale = context->Input(1);
+ xla::XlaBuilder* b = context->builder();
+ xla::XlaOp input = context->Input(0);
+ xla::XlaOp scale = context->Input(1);
DataType type = context->input_type(0);
- xla::ComputationDataHandle red =
+ xla::XlaOp red =
b->SliceInDim(input, /*start_index=*/0, /*limit_index=*/1, /*stride=*/1,
/*dimno=*/channel_dim);
- xla::ComputationDataHandle green =
+ xla::XlaOp green =
b->SliceInDim(input, /*start_index=*/1, /*limit_index=*/2, /*stride=*/1,
/*dimno=*/channel_dim);
- xla::ComputationDataHandle blue =
+ xla::XlaOp blue =
b->SliceInDim(input, /*start_index=*/2, /*limit_index=*/3, /*stride=*/1,
/*dimno=*/channel_dim);
TensorShape channel_shape = input_shape;
@@ -271,19 +270,19 @@ class AdjustHueOp : public XlaOpKernel {
errors::InvalidArgument("input must have 3 channels but instead has ",
channels, " channels."));
- xla::ComputationBuilder* b = context->builder();
- xla::ComputationDataHandle input = context->Input(0);
- xla::ComputationDataHandle delta = context->Input(1);
+ xla::XlaBuilder* b = context->builder();
+ xla::XlaOp input = context->Input(0);
+ xla::XlaOp delta = context->Input(1);
DataType type = context->input_type(0);
- xla::ComputationDataHandle red =
+ xla::XlaOp red =
b->SliceInDim(input, /*start_index=*/0, /*limit_index=*/1, /*stride=*/1,
/*dimno=*/channel_dim);
- xla::ComputationDataHandle green =
+ xla::XlaOp green =
b->SliceInDim(input, /*start_index=*/1, /*limit_index=*/2, /*stride=*/1,
/*dimno=*/channel_dim);
- xla::ComputationDataHandle blue =
+ xla::XlaOp blue =
b->SliceInDim(input, /*start_index=*/2, /*limit_index=*/3, /*stride=*/1,
/*dimno=*/channel_dim);
TensorShape channel_shape = input_shape;
diff --git a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc
index f36b3f5948..9058cbc747 100644
--- a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc
@@ -99,9 +99,9 @@ ResizeConvolutionDims ComputeResizeConvolutionParameters(
return dims;
}
-xla::ComputationDataHandle MakeBilinearResizeKernel(
- xla::ComputationBuilder* builder, gtl::ArraySlice<int64> kernel_size,
- int64 channels) {
+xla::XlaOp MakeBilinearResizeKernel(xla::XlaBuilder* builder,
+ gtl::ArraySlice<int64> kernel_size,
+ int64 channels) {
// Form a 2D convolution kernel like:
// 1 2 3 2 1
// 2 4 6 4 2
@@ -120,7 +120,7 @@ xla::ComputationDataHandle MakeBilinearResizeKernel(
return kernel;
};
- xla::ComputationDataHandle channels_iota;
+ xla::XlaOp channels_iota;
// DT_INT32 Iota will always return status::OK().
TF_CHECK_OK(
XlaHelpers::Iota(builder, DataType::DT_INT32, channels, &channels_iota));
@@ -139,10 +139,12 @@ xla::ComputationDataHandle MakeBilinearResizeKernel(
/*broadcast_dimensions=*/{0});
}
-xla::ComputationDataHandle ResizeUsingDilationAndConvolution(
- xla::ComputationBuilder* builder, const xla::ComputationDataHandle& input,
- const int num_spatial_dims, std::vector<int64> in_size,
- std::vector<int64> out_size, const int64 channels) {
+xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder,
+ const xla::XlaOp& input,
+ const int num_spatial_dims,
+ std::vector<int64> in_size,
+ std::vector<int64> out_size,
+ const int64 channels) {
// Picture for a 1x3 to 1x4 resize:
// stride = 2, kernel size = 3
// Input:
@@ -168,9 +170,9 @@ xla::ComputationDataHandle ResizeUsingDilationAndConvolution(
ResizeConvolutionDims dims =
ComputeResizeConvolutionParameters(in_size, out_size);
- xla::ComputationDataHandle kernel =
+ xla::XlaOp kernel =
MakeBilinearResizeKernel(builder, dims.kernel_size, channels);
- xla::ComputationDataHandle output = builder->ConvGeneralDilated(
+ xla::XlaOp output = builder->ConvGeneralDilated(
input, kernel, dims.stride,
/*padding=*/
{{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1},
@@ -189,10 +191,12 @@ xla::ComputationDataHandle ResizeUsingDilationAndConvolution(
return output;
}
-xla::ComputationDataHandle ResizeUsingDilationAndConvolutionGradOp(
- xla::ComputationBuilder* builder, const xla::ComputationDataHandle& grad,
- const int num_spatial_dims, std::vector<int64> in_size,
- std::vector<int64> grad_size, const int64 channels) {
+xla::XlaOp ResizeUsingDilationAndConvolutionGradOp(xla::XlaBuilder* builder,
+ const xla::XlaOp& grad,
+ const int num_spatial_dims,
+ std::vector<int64> in_size,
+ std::vector<int64> grad_size,
+ const int64 channels) {
ResizeConvolutionDims dims =
ComputeResizeConvolutionParameters(in_size, grad_size);
@@ -210,7 +214,7 @@ xla::ComputationDataHandle ResizeUsingDilationAndConvolutionGradOp(
}
dimension_numbers.set_kernel_input_feature_dimension(num_spatial_dims);
dimension_numbers.set_kernel_output_feature_dimension(num_spatial_dims + 1);
- xla::ComputationDataHandle kernel =
+ xla::XlaOp kernel =
MakeBilinearResizeKernel(builder, dims.kernel_size, channels);
// Broadcast the input kernel where the forward op expanded from a size == 1
@@ -223,7 +227,7 @@ xla::ComputationDataHandle ResizeUsingDilationAndConvolutionGradOp(
}
}
- xla::ComputationDataHandle output = builder->ConvGeneralDilated(
+ xla::XlaOp output = builder->ConvGeneralDilated(
grad, kernel, /*window_strides=*/dims.kernel_size,
/*padding=*/
{{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1},
@@ -258,7 +262,7 @@ class ResizeBilinearOp : public XlaOpKernel {
}
void Compile(XlaOpKernelContext* ctx) override {
- xla::ComputationBuilder* b = ctx->builder();
+ xla::XlaBuilder* b = ctx->builder();
TensorShape input_shape = ctx->InputShape(0);
OP_REQUIRES(ctx, input_shape.dims() == 4,
@@ -283,7 +287,7 @@ class ResizeBilinearOp : public XlaOpKernel {
const int num_spatial_dims = 2;
- xla::ComputationDataHandle input = ctx->Input(0);
+ xla::XlaOp input = ctx->Input(0);
// If in_size[i] > 1 and out_size[i] == 1, slice out the first input in
// dimension i.
@@ -318,7 +322,7 @@ class ResizeBilinearOp : public XlaOpKernel {
// from image of size axb -> cxd is same as resizing axb -> exf -> cxd.
//
// This makes the convolutions kernels smaller and the operation faster.
- xla::ComputationDataHandle output = input;
+ xla::XlaOp output = input;
while (in_size != out_size) {
if (in_size[0] != 1 && in_size[1] != 1) {
std::vector<float> k = {
@@ -369,7 +373,7 @@ class ResizeBilinearGradOp : public XlaOpKernel {
}
void Compile(XlaOpKernelContext* ctx) override {
- xla::ComputationBuilder* b = ctx->builder();
+ xla::XlaBuilder* b = ctx->builder();
TensorShape input_shape = ctx->InputShape(1);
OP_REQUIRES(ctx, input_shape.dims() == 4,
@@ -406,9 +410,9 @@ class ResizeBilinearGradOp : public XlaOpKernel {
const int num_spatial_dims = 2;
- xla::ComputationDataHandle grad = ctx->Input(0);
+ xla::XlaOp grad = ctx->Input(0);
- xla::ComputationDataHandle output = grad;
+ xla::XlaOp output = grad;
while (in_size != grad_size) {
if (in_size[0] != 1 && in_size[1] != 1) {
std::vector<float> k = {
diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops.cc b/tensorflow/compiler/tf2xla/kernels/index_ops.cc
index 7bf4b435f5..36eb4c7545 100644
--- a/tensorflow/compiler/tf2xla/kernels/index_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/index_ops.cc
@@ -61,10 +61,10 @@ void XlaArgMinMaxOp::Compile(XlaOpKernelContext* ctx) {
DataType index_type = output_type(0);
- xla::ComputationBuilder* b = ctx->builder();
- xla::ComputationDataHandle input = ctx->Input(0);
+ xla::XlaBuilder* b = ctx->builder();
+ xla::XlaOp input = ctx->Input(0);
- xla::ComputationDataHandle output;
+ xla::XlaOp output;
if (is_min_) {
OP_REQUIRES_OK(ctx,
XlaHelpers::ArgMin(b, ctx, input, input_shape, input_type(0),
diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc
index b1f3c3c298..2c2d88486f 100644
--- a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc
+++ b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc
@@ -71,10 +71,10 @@ class ArgMaxCustomCallOp : public XlaOpKernel {
OP_REQUIRES(ctx, XlaContext::Get(ctx).allow_cpu_custom_calls(),
errors::InvalidArgument(
"ArgMax implementation requires a CustomCall on CPU"));
- xla::ComputationBuilder& b = *ctx->builder();
+ xla::XlaBuilder& b = *ctx->builder();
// XLA passes <out> to the function, so it is not included here.
- std::vector<xla::ComputationDataHandle> args;
+ std::vector<xla::XlaOp> args;
args.push_back(ctx->Input(0));
args.push_back(b.ConstantLiteral(
*xla::Literal::CreateR1<int64>(input_shape.dim_sizes())));
@@ -91,7 +91,7 @@ class ArgMaxCustomCallOp : public XlaOpKernel {
// Tell XLA to call the custom code, defined in
// index_ops_kernel_argmax_float_1d.cc.
- xla::ComputationDataHandle output;
+ xla::XlaOp output;
switch (input_shape.dims()) {
case 1:
output = b.CustomCall("argmax_float_1d_xla_impl", args, xla_shape);
diff --git a/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc b/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc
index c177f08d9c..1decf7d72d 100644
--- a/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc
@@ -16,7 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/no_op.h"
@@ -33,7 +33,7 @@ class L2LossOp : public XlaOpKernel {
std::iota(dims.begin(), dims.end(), 0);
DataType dtype = ctx->input_type(0);
- xla::ComputationBuilder* const b = ctx->builder();
+ xla::XlaBuilder* const b = ctx->builder();
// output = sum(t ** 2) / 2
const DataType accumulation_type = XlaHelpers::SumAccumulationType(dtype);
diff --git a/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc b/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc
index 1cfee3070f..39fbf98a62 100644
--- a/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc
@@ -38,8 +38,8 @@ class LRNOp : public XlaOpKernel {
OP_REQUIRES(ctx, in_shape.dims() == 4,
errors::InvalidArgument("in must be 4-dimensional"));
- xla::ComputationBuilder* builder = ctx->builder();
- xla::ComputationDataHandle input = ctx->Input(0);
+ xla::XlaBuilder* builder = ctx->builder();
+ xla::XlaOp input = ctx->Input(0);
// sqr_sum[a, b, c, d] =
// sum(input[a, b, c, d - depth_radius : d + depth_radius + 1] ** 2)
@@ -111,10 +111,10 @@ class LRNGradOp : public XlaOpKernel {
"input_grads, input_image, and out_image should have the same "
"shape"));
- xla::ComputationBuilder* builder = ctx->builder();
- xla::ComputationDataHandle in_grads = ctx->Input(0);
- xla::ComputationDataHandle in_image = ctx->Input(1);
- xla::ComputationDataHandle out_image = ctx->Input(2);
+ xla::XlaBuilder* builder = ctx->builder();
+ xla::XlaOp in_grads = ctx->Input(0);
+ xla::XlaOp in_image = ctx->Input(1);
+ xla::XlaOp out_image = ctx->Input(2);
// This code is ported from tensorflow/core/kernels/lrn_op.cc. In Python
// pseudo-code, the Eigen code does this for each spatial position:
@@ -166,7 +166,7 @@ class LRNGradOp : public XlaOpKernel {
auto dy_reduced =
XlaHelpers::ConvertElementType(builder, dy_reduce, input_type(0));
- xla::ComputationDataHandle gradients = builder->Add(
+ xla::XlaOp gradients = builder->Add(
builder->Mul(in_image, dy_reduced),
builder->Mul(in_grads,
builder->Pow(norm, builder->ConstantR0<float>(-beta_))));
diff --git a/tensorflow/compiler/tf2xla/kernels/matmul_op.cc b/tensorflow/compiler/tf2xla/kernels/matmul_op.cc
index 886baf8115..6949b296f4 100644
--- a/tensorflow/compiler/tf2xla/kernels/matmul_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/matmul_op.cc
@@ -66,8 +66,8 @@ class MatMulOp : public XlaOpKernel {
a_shape.DebugString(), ", In[1]: ",
b_shape.DebugString()));
- xla::ComputationDataHandle a = ctx->Input(0);
- xla::ComputationDataHandle b = ctx->Input(1);
+ xla::XlaOp a = ctx->Input(0);
+ xla::XlaOp b = ctx->Input(1);
if (is_sparse_) {
if (a_type_ == DT_BFLOAT16) {
a = ctx->builder()->ConvertElementType(a, xla::F32);
diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc
index faa415a97b..fbd5dc0fda 100644
--- a/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc
@@ -44,10 +44,10 @@ class MatrixBandPartOp : public XlaOpKernel {
errors::InvalidArgument("num_upper must be scalar, got shape ",
num_upper_in_shape.DebugString()));
- xla::ComputationBuilder* builder = context->builder();
- xla::ComputationDataHandle input = context->Input(0);
- xla::ComputationDataHandle num_lower = context->Input(1);
- xla::ComputationDataHandle num_upper = context->Input(2);
+ xla::XlaBuilder* builder = context->builder();
+ xla::XlaOp input = context->Input(0);
+ xla::XlaOp num_lower = context->Input(1);
+ xla::XlaOp num_upper = context->Input(2);
DataType input_type = context->input_type(0);
DataType index_type = context->input_type(1);
@@ -58,10 +58,10 @@ class MatrixBandPartOp : public XlaOpKernel {
// Compute 'offset', which is how many diagonals we are above/below the
// diagonal.
- xla::ComputationDataHandle iota_m;
+ xla::XlaOp iota_m;
OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, index_type, m, &iota_m));
- xla::ComputationDataHandle iota_n;
+ xla::XlaOp iota_n;
OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, index_type, n, &iota_n));
auto offset = builder->Sub(builder->Broadcast(iota_n, {m}), iota_m,
diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc
index b2940bdcff..db53f6fef8 100644
--- a/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc
@@ -54,16 +54,16 @@ class MatrixSetDiagOp : public XlaOpKernel {
input_shape.DebugString(),
" and diagonal shape: ", diag_shape.DebugString()));
- xla::ComputationBuilder* builder = context->builder();
- xla::ComputationDataHandle input = context->Input(0);
- xla::ComputationDataHandle diag = context->Input(1);
+ xla::XlaBuilder* builder = context->builder();
+ xla::XlaOp input = context->Input(0);
+ xla::XlaOp diag = context->Input(1);
auto zero = XlaHelpers::Zero(builder, context->input_type(0));
// Create an indicator tensor that is true only on the diagonal.
- xla::ComputationDataHandle iota_m;
+ xla::XlaOp iota_m;
OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, DT_INT32, m, &iota_m));
- xla::ComputationDataHandle iota_n;
+ xla::XlaOp iota_n;
OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, DT_INT32, n, &iota_n));
auto indicator = builder->Eq(iota_m,
builder->Broadcast(iota_n, {m}),
diff --git a/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc b/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc
index 05a36a031a..7e9de3ef9b 100644
--- a/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc
@@ -25,10 +25,11 @@ class MirrorPadOp : public XlaOpKernel {
public:
explicit MirrorPadOp(OpKernelConstruction* context) : XlaOpKernel(context) {}
- xla::StatusOr<xla::ComputationDataHandle> DoMirrorPad(
- const xla::ComputationDataHandle& t, const xla::Shape& original_shape,
- const xla::Literal& pad_literal, xla::ComputationBuilder* b) {
- xla::ComputationDataHandle accum = t;
+ xla::StatusOr<xla::XlaOp> DoMirrorPad(const xla::XlaOp& t,
+ const xla::Shape& original_shape,
+ const xla::Literal& pad_literal,
+ xla::XlaBuilder* b) {
+ xla::XlaOp accum = t;
for (int64 dimno = xla::ShapeUtil::Rank(original_shape) - 1; dimno >= 0;
--dimno) {
auto t_rev = b->Rev(accum, {dimno});
@@ -76,12 +77,12 @@ class MirrorPadOp : public XlaOpKernel {
OP_REQUIRES_OK(
ctx, ctx->ConstantInputReshaped(1, {fixed_dims, 2}, &pad_literal));
- xla::ComputationBuilder* b = ctx->builder();
+ xla::XlaBuilder* b = ctx->builder();
auto in0 = ctx->Input(0);
- xla::StatusOr<std::unique_ptr<xla::Shape>> in0_shape = b->GetShape(in0);
+ xla::StatusOr<xla::Shape> in0_shape = b->GetShape(in0);
OP_REQUIRES(ctx, in0_shape.ok(), in0_shape.status());
- xla::StatusOr<xla::ComputationDataHandle> accum_status =
- DoMirrorPad(in0, *in0_shape.ValueOrDie(), pad_literal, b);
+ xla::StatusOr<xla::XlaOp> accum_status =
+ DoMirrorPad(in0, in0_shape.ValueOrDie(), pad_literal, b);
OP_REQUIRES_OK(ctx, accum_status.status());
diff --git a/tensorflow/compiler/tf2xla/kernels/one_hot_op.cc b/tensorflow/compiler/tf2xla/kernels/one_hot_op.cc
index 9f7c991380..cac2eea96e 100644
--- a/tensorflow/compiler/tf2xla/kernels/one_hot_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/one_hot_op.cc
@@ -62,7 +62,7 @@ class OneHotOp : public XlaOpKernel {
ctx, depth >= 0,
errors::InvalidArgument("depth must be non-negative, got: ", depth));
- xla::ComputationDataHandle one_hot;
+ xla::XlaOp one_hot;
OP_REQUIRES_OK(
ctx, XlaHelpers::OneHot(ctx->builder(), depth, axis, input_type(0),
indices_shape, ctx->Input(0), ctx->Input(2),
diff --git a/tensorflow/compiler/tf2xla/kernels/pack_op.cc b/tensorflow/compiler/tf2xla/kernels/pack_op.cc
index a4318e29d2..aecaabb6dc 100644
--- a/tensorflow/compiler/tf2xla/kernels/pack_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/pack_op.cc
@@ -43,7 +43,7 @@ class PackOp : public XlaOpKernel {
}
void Compile(XlaOpKernelContext* ctx) override {
- std::vector<xla::ComputationDataHandle> values;
+ std::vector<xla::XlaOp> values;
std::vector<TensorShape> shapes;
OP_REQUIRES_OK(ctx, ctx->InputList("values", &values, &shapes));
const int num = values.size();
@@ -69,7 +69,7 @@ class PackOp : public XlaOpKernel {
-expanded_num_dims, ", ",
expanded_num_dims, ")"));
- std::vector<xla::ComputationDataHandle> reshaped_inputs(num);
+ std::vector<xla::XlaOp> reshaped_inputs(num);
TensorShape child_shape(shapes[0]);
child_shape.InsertDim(axis, 1);
diff --git a/tensorflow/compiler/tf2xla/kernels/pad_op.cc b/tensorflow/compiler/tf2xla/kernels/pad_op.cc
index 791351637a..7c95475e7b 100644
--- a/tensorflow/compiler/tf2xla/kernels/pad_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/pad_op.cc
@@ -70,7 +70,7 @@ class PadOp : public XlaOpKernel {
}
// PadV2 added a "constant_values" input that indicates the pad value.
- xla::ComputationDataHandle constant_values;
+ xla::XlaOp constant_values;
if (ctx->num_inputs() == 3) {
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(ctx->InputShape(2)),
errors::InvalidArgument("constant_values must be a scalar."));
diff --git a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc
index 5f635dd1bc..f8e7b48a0f 100644
--- a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc
@@ -66,15 +66,15 @@ class PoolingOp : public XlaOpKernel {
int num_dims() const { return num_spatial_dims_ + 2; }
// Method that builds an initial value to use in reductions.
- virtual xla::ComputationDataHandle InitValue(xla::ComputationBuilder* b) = 0;
+ virtual xla::XlaOp InitValue(xla::XlaBuilder* b) = 0;
// The reduction operation to apply to each window.
- virtual const xla::Computation* Reduction(XlaOpKernelContext* ctx) = 0;
+ virtual const xla::XlaComputation* Reduction(XlaOpKernelContext* ctx) = 0;
// A post-processing operation to apply on the outputs of the ReduceWindow.
- virtual xla::ComputationDataHandle PostProcessOutput(
- XlaOpKernelContext* ctx, const xla::ComputationDataHandle& output,
- DataType dtype, const TensorShape& input_shape) = 0;
+ virtual xla::XlaOp PostProcessOutput(XlaOpKernelContext* ctx,
+ const xla::XlaOp& output, DataType dtype,
+ const TensorShape& input_shape) = 0;
void Compile(XlaOpKernelContext* ctx) override {
std::vector<int64> ksize = ksize_;
@@ -110,7 +110,7 @@ class PoolingOp : public XlaOpKernel {
" operator must have ", num_dims(),
" dimensions"));
- xla::ComputationBuilder* const b = ctx->builder();
+ xla::XlaBuilder* const b = ctx->builder();
auto input =
XlaHelpers::ConvertElementType(b, ctx->Input(0), reduction_type_);
auto reduce = ctx->builder()->ReduceWindow(
@@ -135,17 +135,17 @@ class MaxPoolOp : public PoolingOp {
: PoolingOp(ctx, /*num_spatial_dims=*/num_spatial_dims,
/*reduction_type=*/ctx->input_type(0)) {}
- xla::ComputationDataHandle InitValue(xla::ComputationBuilder* b) override {
+ xla::XlaOp InitValue(xla::XlaBuilder* b) override {
return XlaHelpers::MinValue(b, reduction_type_);
}
- const xla::Computation* Reduction(XlaOpKernelContext* ctx) override {
+ const xla::XlaComputation* Reduction(XlaOpKernelContext* ctx) override {
return ctx->GetOrCreateMax(reduction_type_);
}
- xla::ComputationDataHandle PostProcessOutput(
- XlaOpKernelContext* ctx, const xla::ComputationDataHandle& output,
- DataType dtype, const TensorShape& input_shape) override {
+ xla::XlaOp PostProcessOutput(XlaOpKernelContext* ctx,
+ const xla::XlaOp& output, DataType dtype,
+ const TensorShape& input_shape) override {
return output;
}
};
@@ -176,9 +176,9 @@ REGISTER_XLA_OP(Name("MaxPool3D"), MaxPool3DOp);
// Common computation shared between AvgPool and AvgPoolGrad. Divide each
// element of an image by the count of elements that contributed to that
// element during pooling.
-static xla::ComputationDataHandle AvgPoolDivideByCount(
- XlaOpKernelContext* ctx, const xla::ComputationDataHandle& output,
- DataType dtype, const TensorShape& input_shape, xla::Padding padding,
+static xla::XlaOp AvgPoolDivideByCount(
+ XlaOpKernelContext* ctx, const xla::XlaOp& output, DataType dtype,
+ const TensorShape& input_shape, xla::Padding padding,
const std::vector<int64>& ksize, const std::vector<int64>& stride,
int num_spatial_dims, TensorFormat data_format) {
if (padding == xla::Padding::kValid) {
@@ -234,17 +234,17 @@ class AvgPoolOp : public PoolingOp {
/*reduction_type=*/
XlaHelpers::SumAccumulationType(ctx->input_type(0))) {}
- xla::ComputationDataHandle InitValue(xla::ComputationBuilder* b) override {
+ xla::XlaOp InitValue(xla::XlaBuilder* b) override {
return XlaHelpers::Zero(b, reduction_type_);
}
- const xla::Computation* Reduction(XlaOpKernelContext* ctx) override {
+ const xla::XlaComputation* Reduction(XlaOpKernelContext* ctx) override {
return ctx->GetOrCreateAdd(reduction_type_);
}
- xla::ComputationDataHandle PostProcessOutput(
- XlaOpKernelContext* ctx, const xla::ComputationDataHandle& output,
- DataType dtype, const TensorShape& input_shape) override {
+ xla::XlaOp PostProcessOutput(XlaOpKernelContext* ctx,
+ const xla::XlaOp& output, DataType dtype,
+ const TensorShape& input_shape) override {
return AvgPoolDivideByCount(ctx, output, dtype, input_shape, padding_,
ksize_, stride_, num_spatial_dims_,
data_format_);
@@ -344,11 +344,10 @@ class MaxPoolGradOp : public XlaOpKernel {
xla::PrimitiveType element_type;
OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(input_type(2), &element_type));
- xla::ComputationDataHandle init_value =
- XlaHelpers::Zero(ctx->builder(), input_type(2));
+ xla::XlaOp init_value = XlaHelpers::Zero(ctx->builder(), input_type(2));
auto select = CreateScalarGeComputation(element_type, ctx->builder());
auto scatter = CreateScalarAddComputation(element_type, ctx->builder());
- xla::ComputationDataHandle gradients = ctx->builder()->SelectAndScatter(
+ xla::XlaOp gradients = ctx->builder()->SelectAndScatter(
input, select, ksize_, stride_, xla_padding, out_backprop, init_value,
scatter);
@@ -462,7 +461,7 @@ class AvgPoolGradOp : public XlaOpKernel {
// The input gradients are computed by a convolution of the output gradients
// and the filter, with some appropriate padding. See the comment at the top
// of conv_grad_ops.h for details.
- xla::ComputationBuilder* const b = ctx->builder();
+ xla::XlaBuilder* const b = ctx->builder();
auto out_backprop = ctx->Input(1);
auto dtype = input_type(1);
xla::Padding xla_padding =
diff --git a/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc b/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc
index 4171e076ff..661cd5923e 100644
--- a/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc
@@ -35,7 +35,7 @@ class QuantizeAndDequantizeOp : public XlaOpKernel {
}
void Compile(XlaOpKernelContext* ctx) override {
- xla::ComputationDataHandle input = ctx->Input(0);
+ xla::XlaOp input = ctx->Input(0);
const DataType data_type = ctx->input_type(0);
// Comments taken from semantics description at
@@ -46,8 +46,8 @@ class QuantizeAndDequantizeOp : public XlaOpKernel {
// m = max(abs(input_min), abs(input_max)) if range_given is true,
// m = max(abs(min_elem(input)),
// abs(max_elem(input))) otherwise.
- xla::ComputationBuilder* b = ctx->builder();
- xla::ComputationDataHandle input_min, input_max;
+ xla::XlaBuilder* b = ctx->builder();
+ xla::XlaOp input_min, input_max;
if (range_given_) {
double input_min_value, input_max_value;
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsFloatScalar(1, &input_min_value));
@@ -55,14 +55,14 @@ class QuantizeAndDequantizeOp : public XlaOpKernel {
input_min = XlaHelpers::FloatLiteral(b, data_type, input_min_value);
input_max = XlaHelpers::FloatLiteral(b, data_type, input_max_value);
} else {
- const xla::Computation* fmax = ctx->GetOrCreateMax(data_type);
- const xla::Computation* fmin = ctx->GetOrCreateMin(data_type);
+ const xla::XlaComputation* fmax = ctx->GetOrCreateMax(data_type);
+ const xla::XlaComputation* fmin = ctx->GetOrCreateMin(data_type);
input_min =
b->ReduceAll(input, XlaHelpers::MaxValue(b, data_type), *fmin);
input_max =
b->ReduceAll(input, XlaHelpers::MinValue(b, data_type), *fmax);
}
- xla::ComputationDataHandle m = b->Max(b->Abs(input_min), b->Abs(input_max));
+ xla::XlaOp m = b->Max(b->Abs(input_min), b->Abs(input_max));
// Next, we choose our fixed-point quantization buckets, [min_fixed,
// max_fixed]. If signed_input is true, this is
@@ -85,7 +85,7 @@ class QuantizeAndDequantizeOp : public XlaOpKernel {
// From this we compute our scaling factor, s:
//
// s = (max_fixed - min_fixed) / (2 * m).
- xla::ComputationDataHandle s =
+ xla::XlaOp s =
b->Div(XlaHelpers::FloatLiteral(b, data_type, max_fixed - min_fixed),
b->Mul(XlaHelpers::FloatLiteral(b, data_type, 2.0), m));
@@ -93,7 +93,7 @@ class QuantizeAndDequantizeOp : public XlaOpKernel {
// e is transformed into e':
//
// e' = (e * s).round_to_nearest() / s.
- xla::ComputationDataHandle result = b->Div(b->Round(b->Mul(input, s)), s);
+ xla::XlaOp result = b->Div(b->Round(b->Mul(input, s)), s);
ctx->SetOutput(0, result);
}
diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops.cc b/tensorflow/compiler/tf2xla/kernels/random_ops.cc
index c0994c434b..5f5bd58637 100644
--- a/tensorflow/compiler/tf2xla/kernels/random_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/random_ops.cc
@@ -41,9 +41,9 @@ class RandomUniformOp : public XlaOpKernel {
xla::Shape xla_shape;
OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype, shape, &xla_shape));
- xla::ComputationBuilder* b = ctx->builder();
- xla::ComputationDataHandle result = b->RngUniform(
- XlaHelpers::Zero(b, dtype), XlaHelpers::One(b, dtype), xla_shape);
+ xla::XlaBuilder* b = ctx->builder();
+ xla::XlaOp result = b->RngUniform(XlaHelpers::Zero(b, dtype),
+ XlaHelpers::One(b, dtype), xla_shape);
ctx->SetOutput(0, result);
}
@@ -100,11 +100,11 @@ class RandomStandardNormalOp : public XlaOpKernel {
xla::Shape xla_shape;
OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype, shape, &xla_shape));
- xla::ComputationBuilder* b = ctx->builder();
+ xla::XlaBuilder* b = ctx->builder();
// Normal distribution with a mean of 0 and a standard deviation of 1:
- xla::ComputationDataHandle result = b->RngNormal(
- XlaHelpers::Zero(b, dtype), XlaHelpers::One(b, dtype), xla_shape);
+ xla::XlaOp result = b->RngNormal(XlaHelpers::Zero(b, dtype),
+ XlaHelpers::One(b, dtype), xla_shape);
ctx->SetOutput(0, result);
}
@@ -130,19 +130,18 @@ class TruncatedNormalOp : public XlaOpKernel {
xla::Shape xla_element_shape =
xla::ShapeUtil::MakeShape(xla_shape.element_type(), {});
- xla::ComputationBuilder* b = ctx->builder();
- xla::ComputationDataHandle mean = XlaHelpers::Zero(b, dtype);
- xla::ComputationDataHandle stddev = XlaHelpers::One(b, dtype);
- xla::ComputationDataHandle candidate =
- b->RngNormal(mean, stddev, xla_shape);
+ xla::XlaBuilder* b = ctx->builder();
+ xla::XlaOp mean = XlaHelpers::Zero(b, dtype);
+ xla::XlaOp stddev = XlaHelpers::One(b, dtype);
+ xla::XlaOp candidate = b->RngNormal(mean, stddev, xla_shape);
- auto two_sd = [dtype](bool negate, xla::ComputationBuilder* b) {
+ auto two_sd = [dtype](bool negate, xla::XlaBuilder* b) {
return XlaHelpers::FloatLiteral(b, dtype, negate ? -2.0 : 2.0);
};
- auto out_of_range_mask = [two_sd](xla::ComputationDataHandle candidate,
- xla::ComputationBuilder* b) {
- xla::ComputationDataHandle too_large = b->Gt(candidate, two_sd(false, b));
- xla::ComputationDataHandle too_small = b->Lt(candidate, two_sd(true, b));
+ auto out_of_range_mask = [two_sd](xla::XlaOp candidate,
+ xla::XlaBuilder* b) {
+ xla::XlaOp too_large = b->Gt(candidate, two_sd(false, b));
+ xla::XlaOp too_small = b->Lt(candidate, two_sd(true, b));
return b->Or(too_large, too_small);
};
@@ -152,35 +151,32 @@ class TruncatedNormalOp : public XlaOpKernel {
// out_of_range_mask := candidate < mean-2*sd || candidate > mean+2*sd
// candidate = select(out_of_range_mask, rng_normal(), candidate)
// }
- std::unique_ptr<xla::ComputationBuilder> test_builder =
+ std::unique_ptr<xla::XlaBuilder> test_builder =
b->CreateSubBuilder("truncated_normal_test");
{
auto* b = test_builder.get();
- xla::ComputationDataHandle candidate =
- b->Parameter(0, xla_shape, "candidate");
- xla::ComputationDataHandle oor_mask = out_of_range_mask(candidate, b);
+ xla::XlaOp candidate = b->Parameter(0, xla_shape, "candidate");
+ out_of_range_mask(candidate, b);
OP_REQUIRES_OK(ctx, Any(out_of_range_mask(candidate, b), b).status());
}
- std::unique_ptr<xla::ComputationBuilder> body_builder =
+ std::unique_ptr<xla::XlaBuilder> body_builder =
b->CreateSubBuilder("truncated_normal_body");
{
auto* b = body_builder.get();
- xla::ComputationDataHandle candidate =
- b->Parameter(0, xla_shape, "candidate");
- xla::ComputationDataHandle to_resample = out_of_range_mask(candidate, b);
- xla::ComputationDataHandle mean = XlaHelpers::Zero(b, dtype);
- xla::ComputationDataHandle stddev = XlaHelpers::One(b, dtype);
+ xla::XlaOp candidate = b->Parameter(0, xla_shape, "candidate");
+ xla::XlaOp to_resample = out_of_range_mask(candidate, b);
+ xla::XlaOp mean = XlaHelpers::Zero(b, dtype);
+ xla::XlaOp stddev = XlaHelpers::One(b, dtype);
b->Select(to_resample, b->RngNormal(mean, stddev, xla_shape), candidate);
}
- xla::StatusOr<xla::Computation> test_computation = test_builder->Build();
+ xla::StatusOr<xla::XlaComputation> test_computation = test_builder->Build();
OP_REQUIRES_OK(ctx, test_computation.status());
- xla::StatusOr<xla::Computation> body_computation = body_builder->Build();
+ xla::StatusOr<xla::XlaComputation> body_computation = body_builder->Build();
OP_REQUIRES_OK(ctx, body_computation.status());
- xla::ComputationDataHandle result =
- b->While(test_computation.ValueOrDie(), body_computation.ValueOrDie(),
- candidate);
+ xla::XlaOp result = b->While(test_computation.ValueOrDie(),
+ body_computation.ValueOrDie(), candidate);
ctx->SetOutput(0, result);
}
diff --git a/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc b/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc
index cb144bea9e..08894489ac 100644
--- a/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc
@@ -19,7 +19,6 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -65,7 +64,7 @@ class ReduceWindowOp : public XlaOpKernel {
"rank (",
padding_high_.size(), " vs. ", rank, ")"));
- xla::ComputationBuilder* builder = context->builder();
+ xla::XlaBuilder* builder = context->builder();
// Build the reducer function.
XlaCompiler::Argument reducer_arg;
@@ -95,15 +94,15 @@ class ReduceWindowOp : public XlaOpKernel {
xla::ShapeUtil::HumanString(reducer.xla_output_shape)));
// Wraps the reducer in a computation that unpacks the output tuple.
- xla::Computation wrapper;
+ xla::XlaComputation wrapper;
{
- std::unique_ptr<xla::ComputationBuilder> cb =
+ std::unique_ptr<xla::XlaBuilder> cb =
builder->CreateSubBuilder("wrapper");
auto x = cb->Parameter(0, scalar_shape, "x");
auto y = cb->Parameter(1, scalar_shape, "y");
auto outputs = cb->Call(*reducer.computation, {x, y});
cb->GetTupleElement(outputs, 0);
- xla::StatusOr<xla::Computation> result = cb->Build();
+ xla::StatusOr<xla::XlaComputation> result = cb->Build();
OP_REQUIRES_OK(context, result.status());
wrapper = std::move(result.ValueOrDie());
}
@@ -113,7 +112,7 @@ class ReduceWindowOp : public XlaOpKernel {
padding[i] = {padding_low_[i], padding_high_[i]};
}
- xla::ComputationDataHandle output = builder->ReduceWindowWithGeneralPadding(
+ xla::XlaOp output = builder->ReduceWindowWithGeneralPadding(
context->Input(0), context->Input(1), wrapper, window_dimensions_,
window_strides_, padding);
context->SetOutput(0, output);
diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc
index 812d258cd1..0f42563779 100644
--- a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc
@@ -30,13 +30,11 @@ class SumOp : public XlaReductionOp {
explicit SumOp(OpKernelConstruction* ctx)
: XlaReductionOp(ctx,
XlaHelpers::SumAccumulationType(ctx->input_type(0))) {}
- xla::ComputationDataHandle InitialValue(
- xla::ComputationBuilder* builder) override {
+ xla::XlaOp InitialValue(xla::XlaBuilder* builder) override {
return XlaHelpers::Zero(builder, reduction_type_);
}
- void BuildReducer(xla::ComputationBuilder* builder,
- const xla::ComputationDataHandle& scalar_lhs,
- const xla::ComputationDataHandle& scalar_rhs) override {
+ void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs,
+ const xla::XlaOp& scalar_rhs) override {
builder->Add(scalar_lhs, scalar_rhs);
}
};
@@ -49,14 +47,12 @@ class ProdOp : public XlaReductionOp {
: XlaReductionOp(ctx,
XlaHelpers::SumAccumulationType(ctx->input_type(0))) {}
- xla::ComputationDataHandle InitialValue(
- xla::ComputationBuilder* builder) override {
+ xla::XlaOp InitialValue(xla::XlaBuilder* builder) override {
return XlaHelpers::One(builder, reduction_type_);
}
- void BuildReducer(xla::ComputationBuilder* builder,
- const xla::ComputationDataHandle& scalar_lhs,
- const xla::ComputationDataHandle& scalar_rhs) override {
+ void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs,
+ const xla::XlaOp& scalar_rhs) override {
builder->Mul(scalar_lhs, scalar_rhs);
}
};
@@ -69,14 +65,12 @@ class MinOp : public XlaReductionOp {
explicit MinOp(OpKernelConstruction* ctx)
: XlaReductionOp(ctx, ctx->input_type(0)) {}
- xla::ComputationDataHandle InitialValue(
- xla::ComputationBuilder* builder) override {
+ xla::XlaOp InitialValue(xla::XlaBuilder* builder) override {
return XlaHelpers::MaxValue(builder, reduction_type_);
}
- void BuildReducer(xla::ComputationBuilder* builder,
- const xla::ComputationDataHandle& scalar_lhs,
- const xla::ComputationDataHandle& scalar_rhs) override {
+ void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs,
+ const xla::XlaOp& scalar_rhs) override {
builder->Min(scalar_lhs, scalar_rhs);
}
};
@@ -88,14 +82,12 @@ class MaxOp : public XlaReductionOp {
explicit MaxOp(OpKernelConstruction* ctx)
: XlaReductionOp(ctx, ctx->input_type(0)) {}
- xla::ComputationDataHandle InitialValue(
- xla::ComputationBuilder* builder) override {
+ xla::XlaOp InitialValue(xla::XlaBuilder* builder) override {
return XlaHelpers::MinValue(builder, reduction_type_);
}
- void BuildReducer(xla::ComputationBuilder* builder,
- const xla::ComputationDataHandle& scalar_lhs,
- const xla::ComputationDataHandle& scalar_rhs) override {
+ void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs,
+ const xla::XlaOp& scalar_rhs) override {
builder->Max(scalar_lhs, scalar_rhs);
}
};
@@ -108,20 +100,17 @@ class MeanOp : public XlaReductionOp {
: XlaReductionOp(ctx,
XlaHelpers::SumAccumulationType(ctx->input_type(0))) {}
- xla::ComputationDataHandle InitialValue(
- xla::ComputationBuilder* builder) override {
+ xla::XlaOp InitialValue(xla::XlaBuilder* builder) override {
return XlaHelpers::Zero(builder, reduction_type_);
}
- void BuildReducer(xla::ComputationBuilder* builder,
- const xla::ComputationDataHandle& scalar_lhs,
- const xla::ComputationDataHandle& scalar_rhs) override {
+ void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs,
+ const xla::XlaOp& scalar_rhs) override {
builder->Add(scalar_lhs, scalar_rhs);
}
- xla::ComputationDataHandle BuildFinalizer(
- xla::ComputationBuilder* builder,
- const xla::ComputationDataHandle& reduce_output,
- int64 num_elements_reduced) override {
+ xla::XlaOp BuildFinalizer(xla::XlaBuilder* builder,
+ const xla::XlaOp& reduce_output,
+ int64 num_elements_reduced) override {
auto divisor = XlaHelpers::IntegerLiteral(builder, input_type(0),
num_elements_reduced);
return builder->Div(reduce_output, divisor);
@@ -136,14 +125,12 @@ class AllOp : public XlaReductionOp {
explicit AllOp(OpKernelConstruction* ctx)
: XlaReductionOp(ctx, ctx->input_type(0)) {}
- xla::ComputationDataHandle InitialValue(
- xla::ComputationBuilder* builder) override {
+ xla::XlaOp InitialValue(xla::XlaBuilder* builder) override {
return builder->ConstantR0<bool>(true);
}
- void BuildReducer(xla::ComputationBuilder* builder,
- const xla::ComputationDataHandle& scalar_lhs,
- const xla::ComputationDataHandle& scalar_rhs) override {
+ void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs,
+ const xla::XlaOp& scalar_rhs) override {
builder->And(scalar_lhs, scalar_rhs);
}
};
@@ -155,14 +142,12 @@ class AnyOp : public XlaReductionOp {
explicit AnyOp(OpKernelConstruction* ctx)
: XlaReductionOp(ctx, ctx->input_type(0)) {}
- xla::ComputationDataHandle InitialValue(
- xla::ComputationBuilder* builder) override {
+ xla::XlaOp InitialValue(xla::XlaBuilder* builder) override {
return builder->ConstantR0<bool>(false);
}
- void BuildReducer(xla::ComputationBuilder* builder,
- const xla::ComputationDataHandle& scalar_lhs,
- const xla::ComputationDataHandle& scalar_rhs) override {
+ void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs,
+ const xla::XlaOp& scalar_rhs) override {
builder->Or(scalar_lhs, scalar_rhs);
}
};
diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.h b/tensorflow/compiler/tf2xla/kernels/reduction_ops.h
index f3181f0dad..2ecfb854a1 100644
--- a/tensorflow/compiler/tf2xla/kernels/reduction_ops.h
+++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops.h
@@ -19,7 +19,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_TF2XLA_KERNELS_REDUCTION_OPS_H_
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
namespace tensorflow {
@@ -28,35 +28,33 @@ namespace tensorflow {
// to override: description is a textual description of the mapped
// function; InitialValue constructs the base case for the reduction;
// BuildReducer adds the implementation of the reduction lambda to a
-// xla::ComputationBuilder and BuildFinalizer adds the
+// xla::XlaBuilder and BuildFinalizer adds the
// implementation of the finalizer lambda (if there is one) to a
-// xla::ComputationBuilder.
+// xla::XlaBuilder.
class XlaReductionOp : public XlaOpKernel {
public:
XlaReductionOp(OpKernelConstruction* ctx, DataType reduction_type);
~XlaReductionOp() override {}
// Return the base case for the reduction.
- virtual xla::ComputationDataHandle InitialValue(
- xla::ComputationBuilder* builder) = 0;
+ virtual xla::XlaOp InitialValue(xla::XlaBuilder* builder) = 0;
// Implement the (scalar,scalar)->scalar lambda that should be
// applied to each pair of elements to be reduced. The desired
// computation should be added to 'builder' and
// '(scalar_lhs,scalar_rhs)' are the function's inputs.
- virtual void BuildReducer(xla::ComputationBuilder* builder,
- const xla::ComputationDataHandle& scalar_lhs,
- const xla::ComputationDataHandle& scalar_rhs) = 0;
+ virtual void BuildReducer(xla::XlaBuilder* builder,
+ const xla::XlaOp& scalar_lhs,
+ const xla::XlaOp& scalar_rhs) = 0;
// Applies a transformation to the output of the reduction. The desired
// computation should be added to 'builder'. Argument 'reduce_output' is the
// output of the reduction. 'num_elements_reduced' is the number of elements
// that contributed to the reduction. Returns the transformed reduction
// output, Defaults to returning 'reduce_output' unchanged.
- virtual xla::ComputationDataHandle BuildFinalizer(
- xla::ComputationBuilder* builder,
- const xla::ComputationDataHandle& reduce_output,
- int64 num_elements_reduced);
+ virtual xla::XlaOp BuildFinalizer(xla::XlaBuilder* builder,
+ const xla::XlaOp& reduce_output,
+ int64 num_elements_reduced);
void Compile(XlaOpKernelContext* ctx) override;
diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc
index 64fe765ae9..4fd5bfd039 100644
--- a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc
+++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc
@@ -35,10 +35,9 @@ XlaReductionOp::XlaReductionOp(OpKernelConstruction* ctx,
// Unless BuildFinalizer is overridden the reduction has no
// finalizer.
-xla::ComputationDataHandle XlaReductionOp::BuildFinalizer(
- xla::ComputationBuilder* builder,
- const xla::ComputationDataHandle& reduce_output,
- int64 num_elements_reduced) {
+xla::XlaOp XlaReductionOp::BuildFinalizer(xla::XlaBuilder* builder,
+ const xla::XlaOp& reduce_output,
+ int64 num_elements_reduced) {
return reduce_output;
}
@@ -96,9 +95,9 @@ void XlaReductionOp::Compile(XlaOpKernelContext* ctx) {
string desc = ctx->op_kernel().name();
- xla::ComputationBuilder* const b = ctx->builder();
+ xla::XlaBuilder* const b = ctx->builder();
// Construct the builder for the reduction lambda.
- xla::ComputationBuilder r(b->client(), strings::StrCat(desc, "-reduction"));
+ xla::XlaBuilder r(strings::StrCat(desc, "-reduction"));
xla::PrimitiveType type;
TF_CHECK_OK(DataTypeToPrimitiveType(reduction_type_, &type));
@@ -110,7 +109,7 @@ void XlaReductionOp::Compile(XlaOpKernelContext* ctx) {
auto ry = r.Parameter(1, xla::ShapeUtil::MakeShape(type, {}), "y");
// Call virtual method to build the reduction lambda.
BuildReducer(&r, rx, ry);
- xla::Computation reduction_computation = r.Build().ConsumeValueOrDie();
+ xla::XlaComputation reduction_computation = r.Build().ConsumeValueOrDie();
auto reduce = b->Reduce(data, initial, reduction_computation, xla_axes);
auto deconverted = XlaHelpers::ConvertElementType(b, reduce, input_type(0));
diff --git a/tensorflow/compiler/tf2xla/kernels/relu_op.cc b/tensorflow/compiler/tf2xla/kernels/relu_op.cc
index 12a3552999..ba7d484d53 100644
--- a/tensorflow/compiler/tf2xla/kernels/relu_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/relu_op.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/types.h"
@@ -32,7 +32,7 @@ class ReluOp : public XlaOpKernel {
explicit ReluOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
// Computes the max of the scalar input x and 0.
void Compile(XlaOpKernelContext* ctx) override {
- xla::ComputationBuilder* builder = ctx->builder();
+ xla::XlaBuilder* builder = ctx->builder();
auto zero = XlaHelpers::Zero(builder, input_type(0));
ctx->SetOutput(0, builder->Max(zero, ctx->Input(0)));
}
@@ -43,7 +43,7 @@ class Relu6Op : public XlaOpKernel {
explicit Relu6Op(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
// Clamp the scalar input between 0 and 6.
void Compile(XlaOpKernelContext* ctx) override {
- xla::ComputationBuilder* builder = ctx->builder();
+ xla::XlaBuilder* builder = ctx->builder();
auto zero = XlaHelpers::Zero(builder, input_type(0));
auto six = XlaHelpers::IntegerLiteral(builder, input_type(0), 6);
ctx->SetOutput(0, builder->Clamp(zero, ctx->Input(0), six));
@@ -56,7 +56,7 @@ class ReluGradOp : public XlaOpKernel {
// Return the lhs (incoming gradient) if the rhs (input feature) > 0,
// otherwise return 0.
void Compile(XlaOpKernelContext* ctx) override {
- xla::ComputationBuilder* b = ctx->builder();
+ xla::XlaBuilder* b = ctx->builder();
const TensorShape shape = ctx->InputShape(0);
const auto zero =
b->Broadcast(XlaHelpers::Zero(b, input_type(0)), shape.dim_sizes());
@@ -71,7 +71,7 @@ class Relu6GradOp : public XlaOpKernel {
// Return the lhs (incoming gradient) if the rhs (input feature) > 0,
// otherwise return 0.
void Compile(XlaOpKernelContext* ctx) override {
- xla::ComputationBuilder* b = ctx->builder();
+ xla::XlaBuilder* b = ctx->builder();
const TensorShape shape = ctx->InputShape(0);
const auto zero =
b->Broadcast(XlaHelpers::Zero(b, input_type(0)), shape.dim_sizes());
diff --git a/tensorflow/compiler/tf2xla/kernels/retval_op.cc b/tensorflow/compiler/tf2xla/kernels/retval_op.cc
index c283e3b02c..70547290ea 100644
--- a/tensorflow/compiler/tf2xla/kernels/retval_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/retval_op.cc
@@ -16,7 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -45,7 +45,7 @@ class RetvalOp : public XlaOpKernel {
// compilation.
OP_REQUIRES_OK(ctx, frame->SetRetval(index_, input));
} else {
- xla::ComputationDataHandle input = ctx->Input(0);
+ xla::XlaOp input = ctx->Input(0);
const TensorShape input_shape = ctx->InputShape(0);
auto is_constant = ctx->builder()->IsConstant(input);
diff --git a/tensorflow/compiler/tf2xla/kernels/reverse_op.cc b/tensorflow/compiler/tf2xla/kernels/reverse_op.cc
index e51d386926..2872a3c4d4 100644
--- a/tensorflow/compiler/tf2xla/kernels/reverse_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/reverse_op.cc
@@ -48,7 +48,7 @@ class ReverseOp : public XlaOpKernel {
ctx->SetOutput(0, ctx->Input(0));
return;
}
- // ComputationBuilder::Rev() requires concrete values for dimensions arg.
+ // XlaBuilder::Rev() requires concrete values for dimensions arg.
xla::Literal lax;
OP_REQUIRES_OK(ctx, ctx->ConstantInputReshaped(1, {x_shape.dims()}, &lax));
std::vector<bool> revdims(x_shape.dims());
@@ -90,7 +90,7 @@ class ReverseV2Op : public XlaOpKernel {
ctx->SetOutput(0, ctx->Input(0));
return;
}
- // ComputationBuilder::Rev() requires concrete values for dimensions arg.
+ // XlaBuilder::Rev() requires concrete values for dimensions arg.
std::vector<int64> axes;
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &axes));
diff --git a/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc b/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc
index 6bc5d3adb0..0ed4c4707d 100644
--- a/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc
@@ -54,7 +54,7 @@ class ReverseSequenceOp : public XlaOpKernel {
"), ", "(", seq_lens_shape.num_elements(),
" vs. ", input_shape.dim_size(batch_dim_)));
- xla::ComputationBuilder* builder = context->builder();
+ xla::XlaBuilder* builder = context->builder();
const auto input = context->Input(0);
const auto seq_lens = context->Input(1);
@@ -155,7 +155,7 @@ class ReverseSequenceOp : public XlaOpKernel {
auto output = builder->GetTupleElement(loop_output, 2);
// Mask out elements after the sequence length.
- xla::ComputationDataHandle iota;
+ xla::XlaOp iota;
OP_REQUIRES_OK(
context, XlaHelpers::Iota(builder, seq_lens_type, max_seq_len, &iota));
std::vector<int64> dims(input_shape.dims(), 1);
diff --git a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc
index 4cfa28a0ce..1819fb5433 100644
--- a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc
@@ -74,7 +74,7 @@ class ScanOp : public XlaOpKernel {
return;
}
- xla::ComputationBuilder* builder = ctx->builder();
+ xla::XlaBuilder* builder = ctx->builder();
std::vector<int64> window_strides(input_shape.dims(), 1);
std::vector<int64> window_dims(input_shape.dims(), 1);
@@ -91,8 +91,8 @@ class ScanOp : public XlaOpKernel {
std::swap(padding[axis].first, padding[axis].second);
}
- xla::ComputationDataHandle init;
- const xla::Computation* reducer;
+ xla::XlaOp init;
+ const xla::XlaComputation* reducer;
if (sum_) {
init = XlaHelpers::Zero(builder, dtype);
reducer = ctx->GetOrCreateAdd(dtype);
diff --git a/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc b/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc
index 8433a29c4e..f2c63b4f90 100644
--- a/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc
@@ -102,7 +102,7 @@ class ScatterNdOp : public XlaOpKernel {
OP_REQUIRES_OK(context, ValidateUpdateShape(buffer_shape, indices_shape,
updates_shape));
- xla::ComputationBuilder* builder = context->builder();
+ xla::XlaBuilder* builder = context->builder();
auto buffer = builder->Broadcast(XlaHelpers::Zero(builder, dtype),
buffer_shape.dim_sizes());
auto indices = context->Input(0);
diff --git a/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc
index 498342a988..664078ca16 100644
--- a/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc
@@ -17,7 +17,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
namespace tensorflow {
namespace {
@@ -62,16 +62,16 @@ class UnsortedSegmentSum : public XlaOpKernel {
d, " differs ", data_shape.dim_size(d), " vs. ",
indices_shape.dim_size(d)));
}
- xla::ComputationBuilder* builder = ctx->builder();
+ xla::XlaBuilder* builder = ctx->builder();
TensorShape buffer_shape = data_shape;
buffer_shape.RemoveDimRange(0, indices_shape.dims());
buffer_shape.InsertDim(0, num_segments);
auto buffer = builder->Broadcast(XlaHelpers::Zero(builder, dtype_),
buffer_shape.dim_sizes());
- auto combiner =
- [](xla::ComputationDataHandle a, xla::ComputationDataHandle b,
- xla::ComputationBuilder* builder) { return builder->Add(a, b); };
+ auto combiner = [](xla::XlaOp a, xla::XlaOp b, xla::XlaBuilder* builder) {
+ return builder->Add(a, b);
+ };
auto result = XlaScatter(buffer, /*updates=*/data, indices,
/*indices_are_vectors=*/false, combiner, builder);
diff --git a/tensorflow/compiler/tf2xla/kernels/select_op.cc b/tensorflow/compiler/tf2xla/kernels/select_op.cc
index 8081d3c41c..f9f48164d6 100644
--- a/tensorflow/compiler/tf2xla/kernels/select_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/select_op.cc
@@ -40,7 +40,7 @@ class SelectOp : public XlaOpKernel {
"'then' and 'else' must have the same size. but received: ",
then_shape.DebugString(), " vs. ", else_shape.DebugString()));
- xla::ComputationBuilder* builder = ctx->builder();
+ xla::XlaBuilder* builder = ctx->builder();
auto cond_handle = ctx->Input(0);
auto then_handle = ctx->Input(1);
diff --git a/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc b/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc
index d079b89861..9ce01d0d44 100644
--- a/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/types.h"
diff --git a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc
index 463788b8b4..bbf5ee8b12 100644
--- a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc
@@ -43,8 +43,8 @@ class SoftmaxOp : public XlaOpKernel {
const DataType type = input_type(0);
auto logits = ctx->Input(0);
- xla::ComputationBuilder* const b = ctx->builder();
- const xla::Computation& max_func = *ctx->GetOrCreateMax(type);
+ xla::XlaBuilder* const b = ctx->builder();
+ const xla::XlaComputation& max_func = *ctx->GetOrCreateMax(type);
// Find the max in each batch, resulting in a tensor of shape [batch]
auto logits_max =
@@ -76,16 +76,15 @@ class SoftmaxOp : public XlaOpKernel {
REGISTER_XLA_OP(Name("Softmax"), SoftmaxOp);
REGISTER_XLA_OP(Name("LogSoftmax"), SoftmaxOp);
-std::pair<xla::ComputationDataHandle, xla::ComputationDataHandle>
-CrossEntropyWithLogits(XlaOpKernelContext* ctx, DataType type,
- const xla::ComputationDataHandle& logits,
- const xla::ComputationDataHandle& labels) {
- const xla::Computation& max_func = *ctx->GetOrCreateMax(type);
+std::pair<xla::XlaOp, xla::XlaOp> CrossEntropyWithLogits(
+ XlaOpKernelContext* ctx, DataType type, const xla::XlaOp& logits,
+ const xla::XlaOp& labels) {
+ const xla::XlaComputation& max_func = *ctx->GetOrCreateMax(type);
const int kBatchDim = 0;
const int kClassDim = 1;
- xla::ComputationBuilder* b = ctx->builder();
+ xla::XlaBuilder* b = ctx->builder();
// Find the max in each batch, resulting in a tensor of shape [batch]
auto logits_max =
b->Reduce(logits, XlaHelpers::MinValue(b, type), max_func, {kClassDim});
@@ -123,7 +122,7 @@ CrossEntropyWithLogits(XlaOpKernelContext* ctx, DataType type,
// backprop: prob - labels, where
// prob = exp(logits - max_logits) / sum(exp(logits - max_logits))
// (where the division broadcasts along the batch dimension)
- xla::ComputationDataHandle backprop =
+ xla::XlaOp backprop =
b->Sub(b->Div(exp_shifted_logits, sum_exp, {kBatchDim}), labels);
return {loss, backprop};
}
@@ -150,7 +149,7 @@ class SoftmaxXentWithLogitsOp : public XlaOpKernel {
auto logits = ctx->Input(0);
auto labels = ctx->Input(1);
- xla::ComputationDataHandle loss, backprop;
+ xla::XlaOp loss, backprop;
std::tie(loss, backprop) =
CrossEntropyWithLogits(ctx, type, logits, labels);
ctx->SetOutput(0, loss);
@@ -191,10 +190,10 @@ class SparseSoftmaxXentWithLogitsOp : public XlaOpKernel {
DataType logits_type = input_type(0);
DataType indices_type = input_type(1);
- xla::ComputationDataHandle indices = ctx->Input(1);
+ xla::XlaOp indices = ctx->Input(1);
- xla::ComputationBuilder* builder = ctx->builder();
- xla::ComputationDataHandle labels;
+ xla::XlaBuilder* builder = ctx->builder();
+ xla::XlaOp labels;
OP_REQUIRES_OK(ctx,
XlaHelpers::OneHot(
builder, depth, /*axis=*/1, input_type(1), labels_shape,
@@ -207,7 +206,7 @@ class SparseSoftmaxXentWithLogitsOp : public XlaOpKernel {
// Builds a vector of {batch_size} that is 0 if the index is in range, or
// NaN otherwise; then add that vector to the labels to force out-of-range
// values to NaNs.
- xla::ComputationDataHandle nan_or_zero = builder->Select(
+ xla::XlaOp nan_or_zero = builder->Select(
builder->And(
builder->Le(XlaHelpers::Zero(builder, indices_type), indices),
builder->Lt(indices, XlaHelpers::IntegerLiteral(
@@ -218,7 +217,7 @@ class SparseSoftmaxXentWithLogitsOp : public XlaOpKernel {
{batch_size}));
labels = builder->Add(labels, nan_or_zero, {0});
- xla::ComputationDataHandle loss, backprop;
+ xla::XlaOp loss, backprop;
std::tie(loss, backprop) =
CrossEntropyWithLogits(ctx, logits_type, ctx->Input(0), labels);
ctx->SetOutput(0, loss);
diff --git a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc
index 01b46e160d..ec077924b5 100644
--- a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc
@@ -20,9 +20,8 @@ limitations under the License.
namespace tensorflow {
namespace {
-void SpaceToBatch(XlaOpKernelContext* ctx,
- const xla::ComputationDataHandle& input, DataType input_dtype,
- const TensorShape& input_tensor_shape,
+void SpaceToBatch(XlaOpKernelContext* ctx, const xla::XlaOp& input,
+ DataType input_dtype, const TensorShape& input_tensor_shape,
gtl::ArraySlice<int64> block_shape,
const xla::Literal& paddings) {
const int input_rank = input_tensor_shape.dims();
@@ -46,7 +45,7 @@ void SpaceToBatch(XlaOpKernelContext* ctx,
", 2] instead of ",
xla::ShapeUtil::HumanString(paddings.shape())));
- xla::ComputationBuilder* b = ctx->builder();
+ xla::XlaBuilder* b = ctx->builder();
// 1. Zero-pad the start and end of dimensions `[1, ..., M]` of the
// input according to `paddings` to produce `padded` of shape `padded_shape`.
@@ -73,7 +72,7 @@ void SpaceToBatch(XlaOpKernelContext* ctx,
errors::InvalidArgument(
"The product of the block dimensions must be positive"));
- xla::ComputationDataHandle padded =
+ xla::XlaOp padded =
b->Pad(input, XlaHelpers::Zero(b, input_dtype), padding_config);
// 2. Reshape `padded` to `reshaped_padded` of shape:
@@ -101,8 +100,7 @@ void SpaceToBatch(XlaOpKernelContext* ctx,
std::copy(remainder_shape.begin(), remainder_shape.end(),
reshaped_padded_shape.begin() + 1 + 2 * block_rank);
- xla::ComputationDataHandle reshaped_padded =
- b->Reshape(padded, reshaped_padded_shape);
+ xla::XlaOp reshaped_padded = b->Reshape(padded, reshaped_padded_shape);
// 3. Permute dimensions of `reshaped_padded` to produce
// `permuted_reshaped_padded` of shape:
@@ -121,7 +119,7 @@ void SpaceToBatch(XlaOpKernelContext* ctx,
permutation[block_rank] = 0;
std::iota(permutation.begin() + 1 + block_rank * 2, permutation.end(),
1 + block_rank * 2);
- xla::ComputationDataHandle permuted_reshaped_padded =
+ xla::XlaOp permuted_reshaped_padded =
b->Transpose(reshaped_padded, permutation);
// 4. Reshape `permuted_reshaped_padded` to flatten `block_shape` into the
@@ -142,8 +140,7 @@ void SpaceToBatch(XlaOpKernelContext* ctx,
std::copy(remainder_shape.begin(), remainder_shape.end(),
output_shape.begin() + 1 + block_rank);
- xla::ComputationDataHandle output =
- b->Reshape(permuted_reshaped_padded, output_shape);
+ xla::XlaOp output = b->Reshape(permuted_reshaped_padded, output_shape);
ctx->SetOutput(0, output);
}
diff --git a/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc b/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc
index 806fda632c..4c5886ee2a 100644
--- a/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc
@@ -50,8 +50,8 @@ class SpaceToDepthOp : public XlaOpKernel {
const gtl::InlinedVector<int64, 4> input_shape =
input_tensor_shape.dim_sizes();
- xla::ComputationBuilder* b = ctx->builder();
- xla::ComputationDataHandle input = ctx->Input(0);
+ xla::XlaBuilder* b = ctx->builder();
+ xla::XlaOp input = ctx->Input(0);
int feature_dim = GetTensorFeatureDimIndex(input_rank, data_format_);
int num_spatial_dims = GetTensorSpatialDims(input_rank, data_format_);
@@ -135,7 +135,7 @@ class SpaceToDepthOp : public XlaOpKernel {
// input_shape[1] / block_size_, block_size_,
// input_shape[2] / block_size_, block_size_,
// depth]
- xla::ComputationDataHandle reshaped = b->Reshape(input, reshaped_shape);
+ xla::XlaOp reshaped = b->Reshape(input, reshaped_shape);
// 2. Permute dimensions of `reshaped` to produce
// `permuted_reshaped` of shape:
@@ -145,8 +145,7 @@ class SpaceToDepthOp : public XlaOpKernel {
// input_shape[2] / block_size_,
// block_size_, block_size_,
// depth]
- xla::ComputationDataHandle permuted_reshaped =
- b->Transpose(reshaped, transpose_order);
+ xla::XlaOp permuted_reshaped = b->Transpose(reshaped, transpose_order);
// 3. Reshape `permuted_reshaped` to flatten `block_shape` into the
// batch dimension, producing an output tensor of shape:
@@ -156,8 +155,7 @@ class SpaceToDepthOp : public XlaOpKernel {
// input_shape[2] / block_size_,
// block_size_ * block_size_ * depth]
//
- xla::ComputationDataHandle output =
- b->Reshape(permuted_reshaped, output_shape);
+ xla::XlaOp output = b->Reshape(permuted_reshaped, output_shape);
ctx->SetOutput(0, output);
}
diff --git a/tensorflow/compiler/tf2xla/kernels/split_op.cc b/tensorflow/compiler/tf2xla/kernels/split_op.cc
index 43c15e7538..8958b2e770 100644
--- a/tensorflow/compiler/tf2xla/kernels/split_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/split_op.cc
@@ -124,7 +124,7 @@ class SplitVOp : public XlaOpKernel {
input_shape.dims(), "), but got ",
split_dim_orig));
- xla::ComputationDataHandle input = ctx->Input(0);
+ xla::XlaOp input = ctx->Input(0);
OP_REQUIRES(ctx, input_shape.dims() > 0,
errors::InvalidArgument("Can't split a 0 dimensional input"));
diff --git a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc
index 1a78c7ab9b..0fb05a2be7 100644
--- a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc
@@ -38,13 +38,13 @@ limitations under the License.
namespace tensorflow {
namespace {
-Status GetStackShape(xla::ComputationBuilder* builder, XlaResource* resource,
+Status GetStackShape(xla::XlaBuilder* builder, XlaResource* resource,
TensorShape* stack_shape) {
auto shape_or_status = builder->GetShape(resource->value());
if (!shape_or_status.ok()) {
return shape_or_status.status();
}
- xla::Shape shape = *shape_or_status.ValueOrDie();
+ xla::Shape shape = shape_or_status.ValueOrDie();
TF_RET_CHECK(xla::ShapeUtil::IsTuple(shape));
return XLAShapeToTensorShape(xla::ShapeUtil::GetTupleElementShape(shape, 0),
stack_shape);
@@ -60,9 +60,8 @@ Status GetStackShape(xla::ComputationBuilder* builder, XlaResource* resource,
//
// TODO(phawkins): consider changing the API of the stack operators to
// allow an optional element shape at stack construction time.
-Status MaybeInitializeStack(xla::ComputationBuilder* builder,
- XlaResource* resource, DataType dtype,
- const TensorShape& elem_shape) {
+Status MaybeInitializeStack(xla::XlaBuilder* builder, XlaResource* resource,
+ DataType dtype, const TensorShape& elem_shape) {
if (resource->type() != dtype) {
return errors::InvalidArgument(
"Stack dtype is ", DataTypeString(resource->type()),
@@ -75,8 +74,6 @@ Status MaybeInitializeStack(xla::ComputationBuilder* builder,
if (!resource->initialized()) {
// Stack has not been initialized.
- xla::ComputationDataHandle zero =
- XlaHelpers::Zero(builder, resource->type());
TF_RETURN_IF_ERROR(resource->SetTypeAndShape(dtype, elem_shape));
TF_RETURN_IF_ERROR(resource->SetZeroValue(builder));
} else {
@@ -111,7 +108,7 @@ class StackOp : public XlaOpKernel {
// We defer initializing the Stack resource until we see the first push.
// Otherwise we do not know the shape of the stack elements.
- xla::ComputationDataHandle value;
+ xla::XlaOp value;
XlaContext& xc = XlaContext::Get(ctx);
XlaResource* resource;
string name = strings::StrCat("Stack: ", stack_name_);
@@ -138,7 +135,7 @@ class StackPushOp : public XlaOpKernel {
}
void Compile(XlaOpKernelContext* ctx) override {
- xla::ComputationBuilder* b = ctx->builder();
+ xla::XlaBuilder* b = ctx->builder();
TensorShape elem_shape = ctx->InputShape(1);
XlaResource* resource;
@@ -147,9 +144,9 @@ class StackPushOp : public XlaOpKernel {
// Initializes the Stack, if the element shape was not already known.
OP_REQUIRES_OK(ctx, MaybeInitializeStack(b, resource, dtype_, elem_shape));
- xla::ComputationDataHandle ta = b->GetTupleElement(resource->value(), 0);
- xla::ComputationDataHandle index = b->GetTupleElement(resource->value(), 1);
- xla::ComputationDataHandle value = ctx->Input(1);
+ xla::XlaOp ta = b->GetTupleElement(resource->value(), 0);
+ xla::XlaOp index = b->GetTupleElement(resource->value(), 1);
+ xla::XlaOp value = ctx->Input(1);
// start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0].
auto start_indices =
@@ -184,7 +181,7 @@ class StackPopOp : public XlaOpKernel {
}
void Compile(XlaOpKernelContext* ctx) override {
- xla::ComputationBuilder* b = ctx->builder();
+ xla::XlaBuilder* b = ctx->builder();
XlaResource* resource;
OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource));
@@ -199,9 +196,9 @@ class StackPopOp : public XlaOpKernel {
TensorShape stack_shape;
OP_REQUIRES_OK(ctx, GetStackShape(b, resource, &stack_shape));
- xla::ComputationDataHandle state = resource->value();
- xla::ComputationDataHandle ta = b->GetTupleElement(state, 0);
- xla::ComputationDataHandle index = b->GetTupleElement(state, 1);
+ xla::XlaOp state = resource->value();
+ xla::XlaOp ta = b->GetTupleElement(state, 0);
+ xla::XlaOp index = b->GetTupleElement(state, 1);
index = b->Sub(index, b->ConstantR0<int32>(1));
OP_REQUIRES_OK(ctx, resource->SetValue(b->Tuple({ta, index})));
@@ -216,8 +213,7 @@ class StackPopOp : public XlaOpKernel {
// TODO(phawkins): We don't check the index is in bounds --- there is no
// error mechanism in XLA.
- xla::ComputationDataHandle read =
- b->DynamicSlice(ta, start_indices, slice_shape);
+ xla::XlaOp read = b->DynamicSlice(ta, start_indices, slice_shape);
// Remove the leading '1' dimension.
std::vector<int64> value_shape(slice_shape.begin() + 1, slice_shape.end());
diff --git a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc
index 5bb773d97f..6340c22518 100644
--- a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc
@@ -30,9 +30,8 @@ namespace tensorflow {
namespace {
// Rotates a 32-bit integer 'v' left by 'distance' bits.
-xla::ComputationDataHandle RotateLeftS32(xla::ComputationBuilder* builder,
- const xla::ComputationDataHandle& v,
- int distance) {
+xla::XlaOp RotateLeftS32(xla::XlaBuilder* builder, const xla::XlaOp& v,
+ int distance) {
return builder->Or(
builder->ShiftLeft(v, builder->ConstantR0<int>(distance)),
builder->ShiftRightLogical(v, builder->ConstantR0<int>(32 - distance)));
@@ -40,25 +39,24 @@ xla::ComputationDataHandle RotateLeftS32(xla::ComputationBuilder* builder,
// TODO(b/65209188): add a primitive XOR to XLA and call it here, rather than
// building XOR out of other bitwise operators.
-xla::ComputationDataHandle BitwiseXor(xla::ComputationBuilder* builder,
- const xla::ComputationDataHandle& x,
- const xla::ComputationDataHandle& y) {
+xla::XlaOp BitwiseXor(xla::XlaBuilder* builder, const xla::XlaOp& x,
+ const xla::XlaOp& y) {
return builder->Or(builder->And(x, builder->Not(y)),
builder->And(builder->Not(x), y));
}
-using ThreeFry2x32State = std::array<xla::ComputationDataHandle, 2>;
+using ThreeFry2x32State = std::array<xla::XlaOp, 2>;
// Implements the ThreeFry counter-based PRNG algorithm.
// Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3.
// http://www.thesalmons.org/john/random123/papers/random123sc11.pdf
-ThreeFry2x32State ThreeFry2x32(xla::ComputationBuilder* builder,
+ThreeFry2x32State ThreeFry2x32(xla::XlaBuilder* builder,
ThreeFry2x32State input, ThreeFry2x32State key) {
// Rotation distances specified by the Threefry2x32 algorithm.
constexpr std::array<int, 8> rotations = {13, 15, 26, 6, 17, 29, 16, 24};
ThreeFry2x32State x;
- std::array<xla::ComputationDataHandle, 3> ks;
+ std::array<xla::XlaOp, 3> ks;
// 0x1BD11BDA is a parity constant specified by the ThreeFry2x32 algorithm.
ks[2] = builder->ConstantR0<int32>(0x1BD11BDA);
for (int i = 0; i < 2; ++i) {
@@ -121,10 +119,9 @@ ThreeFry2x32State ThreeFry2x32(xla::ComputationBuilder* builder,
// Returns a tensor of 'shape' random values uniformly distributed in the range
// [minval, maxval)
-xla::ComputationDataHandle RandomUniform(xla::ComputationBuilder* builder,
- const xla::ComputationDataHandle& seed,
- const TensorShape& shape,
- double minval, double maxval) {
+xla::XlaOp RandomUniform(xla::XlaBuilder* builder, const xla::XlaOp& seed,
+ const TensorShape& shape, double minval,
+ double maxval) {
// Split the seed into two 32-bit scalars to form a key.
auto seed0 = builder->Reshape(builder->Slice(seed, {0}, {1}, {1}), {});
auto seed1 = builder->Reshape(builder->Slice(seed, {1}, {2}, {1}), {});
@@ -178,9 +175,8 @@ xla::ComputationDataHandle RandomUniform(xla::ComputationBuilder* builder,
// p = sum_{i=1}^n gq[i]*w^i
// }
// return p*x
-xla::ComputationDataHandle ErfInvF32(xla::ComputationBuilder* b,
- const xla::ComputationDataHandle& x,
- const TensorShape& shape) {
+xla::XlaOp ErfInvF32(xla::XlaBuilder* b, const xla::XlaOp& x,
+ const TensorShape& shape) {
constexpr int kDegree = 9;
constexpr std::array<float, 9> w_less_than_5_constants = {
2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f,
@@ -220,7 +216,7 @@ class StatelessRandomUniformOp : public XlaOpKernel {
: XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
- xla::ComputationBuilder* builder = ctx->builder();
+ xla::XlaBuilder* builder = ctx->builder();
TensorShape shape;
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));
@@ -229,7 +225,7 @@ class StatelessRandomUniformOp : public XlaOpKernel {
OP_REQUIRES(ctx, seed_shape.dims() == 1 && seed_shape.dim_size(0) == 2,
errors::InvalidArgument("seed must have shape [2], not ",
seed_shape.DebugString()));
- xla::ComputationDataHandle seed = ctx->Input(1);
+ xla::XlaOp seed = ctx->Input(1);
ctx->SetOutput(0, RandomUniform(builder, seed, shape, 0.0, 1.0));
}
@@ -257,8 +253,8 @@ class StatelessRandomNormalOp : public XlaOpKernel {
OP_REQUIRES(ctx, seed_shape == TensorShape({2}),
errors::InvalidArgument("seed must have shape [2], not ",
seed_shape.DebugString()));
- xla::ComputationDataHandle seed = ctx->Input(1);
- xla::ComputationBuilder* builder = ctx->builder();
+ xla::XlaOp seed = ctx->Input(1);
+ xla::XlaBuilder* builder = ctx->builder();
auto uniform = RandomUniform(builder, seed, shape, -1.0, 1.0);
// Convert uniform distribution to normal distribution by computing
// sqrt(2) * erfinv(x)
diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc
index 6204aa4e27..55254c746e 100644
--- a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc
@@ -90,7 +90,7 @@ class StridedSliceOp : public XlaOpKernel {
}
}
- xla::ComputationDataHandle slice = ctx->Input(0);
+ xla::XlaOp slice = ctx->Input(0);
if (!dimensions_to_reverse.empty()) {
slice = ctx->builder()->Rev(slice, dimensions_to_reverse);
}
@@ -168,7 +168,7 @@ class StridedSliceGradOp : public XlaOpKernel {
auto zero = XlaHelpers::Zero(ctx->builder(), ctx->expected_output_dtype(0));
- xla::ComputationDataHandle grad = ctx->Input(4);
+ xla::XlaOp grad = ctx->Input(4);
// Undo any new/shrink axes.
grad = ctx->builder()->Reshape(grad, processing_shape.dim_sizes());
@@ -255,7 +255,7 @@ class StridedSliceAssignOp : public XlaOpKernel {
&strides_tensor));
TensorShape lhs_shape;
- xla::ComputationDataHandle lhs;
+ xla::XlaOp lhs;
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &lhs_shape, &lhs));
const TensorShape rhs_shape = ctx->InputShape(4);
@@ -284,7 +284,7 @@ class StridedSliceAssignOp : public XlaOpKernel {
" does not match r-value shape ", rhs_shape.DebugString(),
". Automatic broadcasting not yet implemented."));
- xla::ComputationDataHandle rhs = ctx->Input(4);
+ xla::XlaOp rhs = ctx->Input(4);
gtl::InlinedVector<int64, 4> dimensions_to_reverse;
gtl::InlinedVector<int64, 4> slice_begin, slice_dims;
diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
index 000b50af6b..9adee78a1f 100644
--- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
@@ -47,7 +47,7 @@ namespace {
// the TensorArray with elements of `elem_shape`. For both initialized and
// uninitialized TensorArrays, checks that the tensor has a type compatible with
// 'dtype' and shape compatible with 'elem_shape'.
-Status MaybeInitializeTensorArray(xla::ComputationBuilder* builder,
+Status MaybeInitializeTensorArray(xla::XlaBuilder* builder,
XlaResource* resource, DataType dtype,
const TensorShape& elem_shape) {
if (resource->kind() != XlaResource::kTensorArray) {
@@ -64,9 +64,6 @@ Status MaybeInitializeTensorArray(xla::ComputationBuilder* builder,
<< resource->name() << " size " << resource->tensor_array_size();
if (!resource->initialized()) {
- xla::ComputationDataHandle zero =
- XlaHelpers::Zero(builder, resource->type());
-
TF_RETURN_IF_ERROR(resource->SetTypeAndShape(dtype, elem_shape));
TF_RETURN_IF_ERROR(resource->SetZeroValue(builder));
} else {
@@ -77,7 +74,7 @@ Status MaybeInitializeTensorArray(xla::ComputationBuilder* builder,
}
TensorShape shape;
TF_RETURN_IF_ERROR(
- XLAShapeToTensorShape(*shape_or_status.ValueOrDie(), &shape));
+ XLAShapeToTensorShape(shape_or_status.ValueOrDie(), &shape));
TensorShape ta_shape;
ta_shape.AddDim(resource->tensor_array_size());
@@ -114,23 +111,21 @@ Status CheckTensorArrayIsInitialized(const string& op_name,
}
Status GetTensorArrayShape(const XlaResource* resource,
- xla::ComputationBuilder* builder,
- TensorShape* shape) {
+ xla::XlaBuilder* builder, TensorShape* shape) {
*shape = resource->shape();
shape->InsertDim(0, resource->tensor_array_size());
return Status::OK();
}
-// Like ComputationBuilder::DynamicUpdateSlice, but adds 'update' to the
+// Like XlaBuilder::DynamicUpdateSlice, but adds 'update' to the
// relevant slice of 'operand'.
-xla::ComputationDataHandle DynamicAddSlice(
- xla::ComputationBuilder* builder, const xla::ComputationDataHandle& operand,
- const xla::ComputationDataHandle& update,
- const gtl::ArraySlice<int64>& update_dims,
- const xla::ComputationDataHandle& start_indices) {
- xla::ComputationDataHandle current =
+xla::XlaOp DynamicAddSlice(xla::XlaBuilder* builder, const xla::XlaOp& operand,
+ const xla::XlaOp& update,
+ const gtl::ArraySlice<int64>& update_dims,
+ const xla::XlaOp& start_indices) {
+ xla::XlaOp current =
builder->DynamicSlice(operand, start_indices, update_dims);
- xla::ComputationDataHandle sum = builder->Add(current, update);
+ xla::XlaOp sum = builder->Add(current, update);
return builder->DynamicUpdateSlice(operand, sum, start_indices);
}
@@ -155,18 +150,18 @@ class TensorArrayOp : public XlaOpKernel {
OP_REQUIRES(ctx, size >= 0,
errors::InvalidArgument("TensorArray size must be >= 0"));
- xla::ComputationBuilder* b = ctx->builder();
+ xla::XlaBuilder* b = ctx->builder();
// Initializes the TensorArray value if we know the element shape.
// Otherwise, defer initialization to the first write.
- xla::ComputationDataHandle value;
+ xla::XlaOp value;
TensorShape shape;
if (element_shape_.IsFullyDefined()) {
CHECK(element_shape_.AsTensorShape(&shape));
TensorShape ta_shape;
ta_shape.AddDim(size);
ta_shape.AppendShape(shape);
- xla::ComputationDataHandle zero = XlaHelpers::Zero(b, dtype_);
+ xla::XlaOp zero = XlaHelpers::Zero(b, dtype_);
value = b->Broadcast(zero, ta_shape.dim_sizes());
}
@@ -202,7 +197,7 @@ class TensorArrayWriteOp : public XlaOpKernel {
}
void Compile(XlaOpKernelContext* ctx) override {
- xla::ComputationBuilder* b = ctx->builder();
+ xla::XlaBuilder* b = ctx->builder();
TensorShape elem_shape = ctx->InputShape(2);
@@ -213,10 +208,10 @@ class TensorArrayWriteOp : public XlaOpKernel {
OP_REQUIRES_OK(ctx,
MaybeInitializeTensorArray(b, resource, dtype_, elem_shape));
- xla::ComputationDataHandle ta = resource->value();
- xla::ComputationDataHandle index = ctx->Input(1);
- xla::ComputationDataHandle value = ctx->Input(2);
- xla::ComputationDataHandle flow = ctx->Input(3);
+ xla::XlaOp ta = resource->value();
+ xla::XlaOp index = ctx->Input(1);
+ xla::XlaOp value = ctx->Input(2);
+ xla::XlaOp flow = ctx->Input(3);
// start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0].
auto start_indices =
@@ -227,7 +222,7 @@ class TensorArrayWriteOp : public XlaOpKernel {
slice_shape.InsertDim(0, 1LL);
auto update = b->Reshape(value, slice_shape.dim_sizes());
- xla::ComputationDataHandle written =
+ xla::XlaOp written =
DynamicAddSlice(b, ta, update, slice_shape.dim_sizes(), start_indices);
OP_REQUIRES_OK(ctx, resource->SetValue(written));
@@ -249,7 +244,7 @@ class TensorArrayReadOp : public XlaOpKernel {
}
void Compile(XlaOpKernelContext* ctx) override {
- xla::ComputationBuilder* b = ctx->builder();
+ xla::XlaBuilder* b = ctx->builder();
XlaResource* resource;
OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource));
@@ -259,8 +254,8 @@ class TensorArrayReadOp : public XlaOpKernel {
TensorShape ta_shape;
OP_REQUIRES_OK(ctx, GetTensorArrayShape(resource, b, &ta_shape));
- xla::ComputationDataHandle ta = resource->value();
- xla::ComputationDataHandle index = ctx->Input(1);
+ xla::XlaOp ta = resource->value();
+ xla::XlaOp index = ctx->Input(1);
// start_indices of the DynamicSlice are [index, 0, 0, ..., 0].
auto start_indices =
@@ -270,8 +265,7 @@ class TensorArrayReadOp : public XlaOpKernel {
auto slice_shape = ta_shape.dim_sizes();
slice_shape[0] = 1LL;
- xla::ComputationDataHandle read =
- b->DynamicSlice(ta, start_indices, slice_shape);
+ xla::XlaOp read = b->DynamicSlice(ta, start_indices, slice_shape);
// Remove the leading '1' dimension.
std::vector<int64> value_shape(slice_shape.begin() + 1, slice_shape.end());
@@ -293,7 +287,7 @@ class TensorArrayGatherOp : public XlaOpKernel {
}
void Compile(XlaOpKernelContext* ctx) override {
- xla::ComputationBuilder* b = ctx->builder();
+ xla::XlaBuilder* b = ctx->builder();
XlaResource* resource;
OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource));
@@ -309,7 +303,7 @@ class TensorArrayGatherOp : public XlaOpKernel {
auto indices = ctx->Input(1);
DataType index_type = ctx->input_type(1);
- xla::ComputationDataHandle ta = resource->value();
+ xla::XlaOp ta = resource->value();
// Look for the case where the gather takes a simple slice from the
// tensor array (0, 1, 2, 3, 4, ..., N)
@@ -337,7 +331,7 @@ class TensorArrayGatherOp : public XlaOpKernel {
}
}
- xla::ComputationDataHandle gather;
+ xla::XlaOp gather;
OP_REQUIRES_OK(
ctx,
XlaGather(ta, ta_shape, indices, indices_shape, /*axis=*/0,
@@ -360,7 +354,7 @@ class TensorArrayScatterOp : public XlaOpKernel {
}
void Compile(XlaOpKernelContext* ctx) override {
- xla::ComputationBuilder* b = ctx->builder();
+ xla::XlaBuilder* b = ctx->builder();
const TensorShape value_shape = ctx->InputShape(2);
@@ -375,11 +369,11 @@ class TensorArrayScatterOp : public XlaOpKernel {
OP_REQUIRES(ctx, indices_shape.dims() >= 1,
errors::InvalidArgument("indices must be rank 1"));
const int num_indices = indices_shape.dim_size(0);
- const xla::ComputationDataHandle indices = ctx->Input(1);
+ const xla::XlaOp indices = ctx->Input(1);
- xla::ComputationDataHandle ta = resource->value();
- const xla::ComputationDataHandle value = ctx->Input(2);
- const xla::ComputationDataHandle flow = ctx->Input(3);
+ xla::XlaOp ta = resource->value();
+ const xla::XlaOp value = ctx->Input(2);
+ const xla::XlaOp flow = ctx->Input(3);
// Look for the case where the scatter is for each sub-tensor in order. The
// tensor array implementation allows for this to be a straight addition.
@@ -443,7 +437,7 @@ class TensorArrayConcatOp : public XlaOpKernel {
}
void Compile(XlaOpKernelContext* ctx) override {
- xla::ComputationBuilder* b = ctx->builder();
+ xla::XlaBuilder* b = ctx->builder();
XlaResource* resource;
OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource));
@@ -453,7 +447,7 @@ class TensorArrayConcatOp : public XlaOpKernel {
TensorShape ta_shape;
OP_REQUIRES_OK(ctx, GetTensorArrayShape(resource, b, &ta_shape));
- xla::ComputationDataHandle ta = resource->value();
+ xla::XlaOp ta = resource->value();
auto ta_dims = ta_shape.dim_sizes();
std::vector<int64> shape(ta_dims.begin() + 1, ta_dims.end());
@@ -503,12 +497,12 @@ class TensorArraySplitOp : public XlaOpKernel {
TensorShape elem_shape = value_shape;
elem_shape.set_dim(0, length);
- xla::ComputationBuilder* b = ctx->builder();
+ xla::XlaBuilder* b = ctx->builder();
XlaResource* resource;
OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource));
OP_REQUIRES_OK(ctx,
MaybeInitializeTensorArray(b, resource, dtype_, elem_shape));
- xla::ComputationDataHandle ta = resource->value();
+ xla::XlaOp ta = resource->value();
TensorShape ta_shape;
ta_shape.AddDim(resource->tensor_array_size());
@@ -520,8 +514,8 @@ class TensorArraySplitOp : public XlaOpKernel {
"TensorArray's size is not equal to the size of lengths (",
lengths.size(), " vs. ", resource->tensor_array_size(), ")"));
- const xla::ComputationDataHandle value = ctx->Input(1);
- const xla::ComputationDataHandle flow = ctx->Input(3);
+ const xla::XlaOp value = ctx->Input(1);
+ const xla::XlaOp flow = ctx->Input(3);
OP_REQUIRES(ctx, value_shape.num_elements() == ta_shape.num_elements(),
errors::InvalidArgument("mismatched element count ",
@@ -569,7 +563,7 @@ class TensorArrayGradOp : public XlaOpKernel {
}
void Compile(XlaOpKernelContext* ctx) override {
- xla::ComputationBuilder* b = ctx->builder();
+ xla::XlaBuilder* b = ctx->builder();
XlaResource* resource;
OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource));
diff --git a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc
index 9aefcd4fc7..e91075196b 100644
--- a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc
@@ -112,7 +112,7 @@ class TileOp : public XlaOpKernel {
flattened.push_back(i);
flattened.push_back(i + output_shape.size());
}
- xla::ComputationDataHandle output =
+ xla::XlaOp output =
ctx->builder()->Reshape(broadcasted, flattened, output_shape);
ctx->SetOutput(0, output);
diff --git a/tensorflow/compiler/tf2xla/kernels/training_ops.cc b/tensorflow/compiler/tf2xla/kernels/training_ops.cc
index f750f7003b..34caefa050 100644
--- a/tensorflow/compiler/tf2xla/kernels/training_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/training_ops.cc
@@ -16,7 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/types.h"
@@ -30,8 +30,8 @@ class ResourceApplyGradientDescent : public XlaOpKernel {
explicit ResourceApplyGradientDescent(OpKernelConstruction* ctx)
: XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
- xla::ComputationDataHandle handle;
- xla::ComputationBuilder* b = ctx->builder();
+ xla::XlaOp handle;
+ xla::XlaBuilder* b = ctx->builder();
DataType type = ctx->input_type(1);
TensorShape var_shape;
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &handle));
@@ -63,12 +63,12 @@ class ResourceApplyMomentum : public XlaOpKernel {
}
void Compile(XlaOpKernelContext* ctx) override {
- xla::ComputationBuilder* b = ctx->builder();
+ xla::XlaBuilder* b = ctx->builder();
DataType type = ctx->input_type(2);
TensorShape var_shape, accum_shape;
- xla::ComputationDataHandle var, accum;
+ xla::XlaOp var, accum;
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &var));
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, type, &accum_shape, &accum));
@@ -93,9 +93,9 @@ class ResourceApplyMomentum : public XlaOpKernel {
errors::InvalidArgument("momentum is not a scalar: ",
momentum_shape.DebugString()));
- xla::ComputationDataHandle lr = ctx->Input(2);
- xla::ComputationDataHandle grad = ctx->Input(3);
- xla::ComputationDataHandle momentum = ctx->Input(4);
+ xla::XlaOp lr = ctx->Input(2);
+ xla::XlaOp grad = ctx->Input(3);
+ xla::XlaOp momentum = ctx->Input(4);
accum = b->Add(b->Mul(accum, momentum), grad);
if (use_nesterov_) {
@@ -121,12 +121,12 @@ class ResourceApplyAdagrad : public XlaOpKernel {
explicit ResourceApplyAdagrad(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
- xla::ComputationBuilder* b = ctx->builder();
+ xla::XlaBuilder* b = ctx->builder();
DataType type = ctx->input_type(2);
TensorShape var_shape, accum_shape;
- xla::ComputationDataHandle var, accum;
+ xla::XlaOp var, accum;
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &var));
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, type, &accum_shape, &accum));
@@ -146,8 +146,8 @@ class ResourceApplyAdagrad : public XlaOpKernel {
"var and grad do not have the same shape",
var_shape.DebugString(), " ", grad_shape.DebugString()));
- xla::ComputationDataHandle lr = ctx->Input(2);
- xla::ComputationDataHandle grad = ctx->Input(3);
+ xla::XlaOp lr = ctx->Input(2);
+ xla::XlaOp grad = ctx->Input(3);
accum = b->Add(accum, b->Pow(grad, XlaHelpers::FloatLiteral(b, type, 2.0)));
var = b->Sub(
@@ -168,7 +168,7 @@ class ResourceApplyAdam : public XlaOpKernel {
void Compile(XlaOpKernelContext* ctx) override {
TensorShape var_shape, m_shape, v_shape;
- xla::ComputationDataHandle var, m, v;
+ xla::XlaOp var, m, v;
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &var_shape, &var));
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, dtype_, &m_shape, &m));
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, dtype_, &v_shape, &v));
@@ -213,25 +213,25 @@ class ResourceApplyAdam : public XlaOpKernel {
"var and grad do not have the same shape",
var_shape.DebugString(), " ", grad_shape.DebugString()));
- xla::ComputationDataHandle beta1_power = ctx->Input(3);
- xla::ComputationDataHandle beta2_power = ctx->Input(4);
- xla::ComputationDataHandle lr = ctx->Input(5);
- xla::ComputationDataHandle beta1 = ctx->Input(6);
- xla::ComputationDataHandle beta2 = ctx->Input(7);
- xla::ComputationDataHandle epsilon = ctx->Input(8);
- xla::ComputationDataHandle grad = ctx->Input(9);
+ xla::XlaOp beta1_power = ctx->Input(3);
+ xla::XlaOp beta2_power = ctx->Input(4);
+ xla::XlaOp lr = ctx->Input(5);
+ xla::XlaOp beta1 = ctx->Input(6);
+ xla::XlaOp beta2 = ctx->Input(7);
+ xla::XlaOp epsilon = ctx->Input(8);
+ xla::XlaOp grad = ctx->Input(9);
// alpha <- learning_rate * sqrt(1 - beta2^t) / (1 - beta1^t)
// m_t <- beta1 * m_{t-1} + (1 - beta1) * g_t
// v_t <- beta2 * v_{t-1} + (1 - beta2) * g_t * g_t
// variable <- variable - alpha * m_t / (sqrt(v_t) + epsilon)
- xla::ComputationBuilder* b = ctx->builder();
- xla::ComputationDataHandle half = XlaHelpers::FloatLiteral(b, dtype_, 0.5);
- xla::ComputationDataHandle one = XlaHelpers::FloatLiteral(b, dtype_, 1.0);
- xla::ComputationDataHandle two = XlaHelpers::FloatLiteral(b, dtype_, 2.0);
+ xla::XlaBuilder* b = ctx->builder();
+ xla::XlaOp half = XlaHelpers::FloatLiteral(b, dtype_, 0.5);
+ xla::XlaOp one = XlaHelpers::FloatLiteral(b, dtype_, 1.0);
+ xla::XlaOp two = XlaHelpers::FloatLiteral(b, dtype_, 2.0);
- xla::ComputationDataHandle alpha =
+ xla::XlaOp alpha =
b->Div(b->Mul(lr, b->Pow(b->Sub(one, beta2_power), half)),
b->Sub(one, beta1_power));
m = b->Add(m, b->Mul(b->Sub(grad, m), b->Sub(one, beta1)));
@@ -255,12 +255,12 @@ class ResourceApplyRMSProp : public XlaOpKernel {
explicit ResourceApplyRMSProp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
- xla::ComputationBuilder* b = ctx->builder();
+ xla::XlaBuilder* b = ctx->builder();
DataType type = ctx->input_type(3);
TensorShape var_shape, ms_shape, mom_shape;
- xla::ComputationDataHandle var, ms, mom;
+ xla::XlaOp var, ms, mom;
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &var));
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, type, &ms_shape, &ms));
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, type, &mom_shape, &mom));
@@ -297,11 +297,11 @@ class ResourceApplyRMSProp : public XlaOpKernel {
"var and grad do not have the same shape",
var_shape.DebugString(), " ", grad_shape.DebugString()));
- xla::ComputationDataHandle lr = ctx->Input(3);
- xla::ComputationDataHandle rho = ctx->Input(4);
- xla::ComputationDataHandle momentum = ctx->Input(5);
- xla::ComputationDataHandle epsilon = ctx->Input(6);
- xla::ComputationDataHandle grad = ctx->Input(7);
+ xla::XlaOp lr = ctx->Input(3);
+ xla::XlaOp rho = ctx->Input(4);
+ xla::XlaOp momentum = ctx->Input(5);
+ xla::XlaOp epsilon = ctx->Input(6);
+ xla::XlaOp grad = ctx->Input(7);
// ms <- rho * ms_{t-1} + (1-rho) * grad * grad
// mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon)
@@ -320,16 +320,16 @@ class ResourceApplyRMSProp : public XlaOpKernel {
// ms <- grad**2 (1 - rho) + ms * rho
//
// Which is the equation listed above.
- xla::ComputationDataHandle new_ms = b->Add(
+ xla::XlaOp new_ms = b->Add(
ms,
b->Mul(b->Sub(b->Pow(grad, XlaHelpers::FloatLiteral(b, type, 2.0)), ms),
b->Sub(XlaHelpers::FloatLiteral(b, type, 1.0), rho)));
- xla::ComputationDataHandle new_mom =
+ xla::XlaOp new_mom =
b->Add(b->Mul(mom, momentum),
b->Mul(b->Mul(grad, lr),
b->Pow(b->Add(new_ms, epsilon),
XlaHelpers::FloatLiteral(b, type, -0.5))));
- xla::ComputationDataHandle new_var = b->Sub(var, new_mom);
+ xla::XlaOp new_var = b->Sub(var, new_mom);
OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, new_var));
OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, type, new_ms));
@@ -341,10 +341,10 @@ REGISTER_XLA_OP(Name("ResourceApplyRMSProp").TypeConstraint("T", kFloatTypes),
void CompileFtrl(XlaOpKernelContext* ctx, DataType dtype,
bool has_l2_shrinkage) {
- xla::ComputationBuilder* b = ctx->builder();
+ xla::XlaBuilder* b = ctx->builder();
TensorShape var_shape, accum_shape, linear_shape;
- xla::ComputationDataHandle var, accum, linear;
+ xla::XlaOp var, accum, linear;
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype, &var_shape, &var));
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, dtype, &accum_shape, &accum));
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, dtype, &linear_shape, &linear));
@@ -399,12 +399,12 @@ void CompileFtrl(XlaOpKernelContext* ctx, DataType dtype,
errors::InvalidArgument("lr_power is not a scalar: ",
lr_power_shape.DebugString()));
- xla::ComputationDataHandle grad = ctx->Input(3);
- xla::ComputationDataHandle lr = ctx->Input(4);
- xla::ComputationDataHandle l1 = ctx->Input(5);
- xla::ComputationDataHandle l2 = ctx->Input(6);
- xla::ComputationDataHandle l2_shrinkage;
- xla::ComputationDataHandle lr_power;
+ xla::XlaOp grad = ctx->Input(3);
+ xla::XlaOp lr = ctx->Input(4);
+ xla::XlaOp l1 = ctx->Input(5);
+ xla::XlaOp l2 = ctx->Input(6);
+ xla::XlaOp l2_shrinkage;
+ xla::XlaOp lr_power;
if (has_l2_shrinkage) {
l2_shrinkage = ctx->Input(7);
lr_power = ctx->Input(8);
@@ -421,26 +421,23 @@ void CompileFtrl(XlaOpKernelContext* ctx, DataType dtype,
// var = (linear_clipped - linear) / quadratic
// accum = new_accum
- xla::ComputationDataHandle two = XlaHelpers::FloatLiteral(b, dtype, 2.0);
- xla::ComputationDataHandle grad_to_use;
+ xla::XlaOp two = XlaHelpers::FloatLiteral(b, dtype, 2.0);
+ xla::XlaOp grad_to_use;
if (has_l2_shrinkage) {
grad_to_use = b->Add(grad, b->Mul(two, b->Mul(l2_shrinkage, var)));
} else {
grad_to_use = grad;
}
- xla::ComputationDataHandle new_accum =
- b->Add(accum, b->Pow(grad_to_use, two));
- xla::ComputationDataHandle new_accum_lr_pow =
- b->Pow(new_accum, b->Neg(lr_power));
- xla::ComputationDataHandle accum_lr_pow = b->Pow(accum, b->Neg(lr_power));
+ xla::XlaOp new_accum = b->Add(accum, b->Pow(grad_to_use, two));
+ xla::XlaOp new_accum_lr_pow = b->Pow(new_accum, b->Neg(lr_power));
+ xla::XlaOp accum_lr_pow = b->Pow(accum, b->Neg(lr_power));
linear = b->Add(
linear,
b->Sub(grad_to_use,
b->Mul(b->Div(b->Sub(new_accum_lr_pow, accum_lr_pow), lr), var)));
- xla::ComputationDataHandle linear_clipped = b->Clamp(b->Neg(l1), linear, l1);
- xla::ComputationDataHandle quadratic =
- b->Add(b->Div(new_accum_lr_pow, lr), b->Mul(two, l2));
+ xla::XlaOp linear_clipped = b->Clamp(b->Neg(l1), linear, l1);
+ xla::XlaOp quadratic = b->Add(b->Div(new_accum_lr_pow, lr), b->Mul(two, l2));
var = b->Div(b->Sub(linear_clipped, linear), quadratic);
accum = new_accum;
diff --git a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc
index 7cb47f908d..a4f50f52eb 100644
--- a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/client_library.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
namespace tensorflow {
@@ -33,9 +33,9 @@ namespace {
public: \
explicit NAME##Op(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} \
void Compile(XlaOpKernelContext* ctx) { \
- xla::ComputationBuilder* b = ctx->builder(); \
- xla::ComputationDataHandle x = ctx->Input(0); \
- xla::ComputationDataHandle y = COMPUTATION; \
+ xla::XlaBuilder* b = ctx->builder(); \
+ xla::XlaOp x = ctx->Input(0); \
+ xla::XlaOp y = COMPUTATION; \
ctx->SetOutput(0, y); \
} \
}; \
@@ -124,9 +124,8 @@ XLAJIT_MAKE_UNARY(Neg, b->Neg(x));
// Implements Banker's rounding: numbers that are equidistant between two
// integers are rounded towards even.
-static xla::ComputationDataHandle Round(xla::ComputationBuilder* b,
- DataType dtype,
- const xla::ComputationDataHandle& x) {
+static xla::XlaOp Round(xla::XlaBuilder* b, DataType dtype,
+ const xla::XlaOp& x) {
auto half = XlaHelpers::FloatLiteral(b, dtype, 0.5);
auto one = XlaHelpers::FloatLiteral(b, dtype, 1.0);
auto two = XlaHelpers::FloatLiteral(b, dtype, 2.0);
@@ -148,9 +147,8 @@ XLAJIT_MAKE_UNARY(Rsqrt,
b->Pow(x, XlaHelpers::FloatLiteral(b, input_type(0), -0.5)));
// Expresses sigmoid as a rescaled tanh: sigmoid(x) == (tanh(x/2) + 1) / 2.
-static xla::ComputationDataHandle Sigmoid(xla::ComputationBuilder* b,
- DataType dtype,
- const xla::ComputationDataHandle& x) {
+static xla::XlaOp Sigmoid(xla::XlaBuilder* b, DataType dtype,
+ const xla::XlaOp& x) {
auto half = XlaHelpers::FloatLiteral(b, dtype, 0.5);
return b->Add(half, b->Mul(half, b->Tanh(b->Mul(half, x))));
}
@@ -162,20 +160,18 @@ XLAJIT_MAKE_UNARY(Sinh,
b->Mul(b->Sub(b->Exp(x), b->Exp(b->Neg(x))),
XlaHelpers::FloatLiteral(b, input_type(0), 0.5)));
-static xla::ComputationDataHandle Softplus(
- xla::ComputationBuilder* b, DataType dtype,
- const xla::ComputationDataHandle& features) {
- xla::ComputationDataHandle threshold =
- b->Add(b->Log(XlaHelpers::Epsilon(b, dtype)),
- XlaHelpers::FloatLiteral(b, dtype, 2.0));
+static xla::XlaOp Softplus(xla::XlaBuilder* b, DataType dtype,
+ const xla::XlaOp& features) {
+ xla::XlaOp threshold = b->Add(b->Log(XlaHelpers::Epsilon(b, dtype)),
+ XlaHelpers::FloatLiteral(b, dtype, 2.0));
// Value above which exp(x) may overflow, but softplus(x) == x
// is within machine epsilon.
- xla::ComputationDataHandle too_large = b->Gt(features, b->Neg(threshold));
+ xla::XlaOp too_large = b->Gt(features, b->Neg(threshold));
// Value below which exp(x) may underflow, but softplus(x) == exp(x)
// is within machine epsilon.
- xla::ComputationDataHandle too_small = b->Lt(features, threshold);
- xla::ComputationDataHandle features_exp = b->Exp(features);
- xla::ComputationDataHandle output = b->Select(
+ xla::XlaOp too_small = b->Lt(features, threshold);
+ xla::XlaOp features_exp = b->Exp(features);
+ xla::XlaOp output = b->Select(
too_large, features,
b->Select(too_small, features_exp,
b->Log(b->Add(features_exp, XlaHelpers::One(b, dtype)))));
diff --git a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc
index 71173f5aea..6109db8e89 100644
--- a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/types.h"
@@ -48,7 +48,7 @@ class ReadVariableOp : public XlaOpKernel {
}
void Compile(XlaOpKernelContext* ctx) override {
- xla::ComputationDataHandle handle;
+ xla::XlaOp handle;
OP_REQUIRES_OK(
ctx, ctx->ReadVariableInput(0, dtype_, /*shape=*/nullptr, &handle));
ctx->SetOutput(0, handle);
@@ -74,7 +74,7 @@ class AssignAddVariableOp : public XlaOpKernel {
explicit AssignAddVariableOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
DataType type = ctx->input_type(1);
- xla::ComputationDataHandle handle;
+ xla::XlaOp handle;
OP_REQUIRES_OK(ctx,
ctx->ReadVariableInput(0, type, /*shape=*/nullptr, &handle));
handle = ctx->builder()->Add(handle, ctx->Input(1));
@@ -90,7 +90,7 @@ class AssignSubVariableOp : public XlaOpKernel {
explicit AssignSubVariableOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
DataType type = ctx->input_type(1);
- xla::ComputationDataHandle handle;
+ xla::XlaOp handle;
OP_REQUIRES_OK(ctx,
ctx->ReadVariableInput(0, type, /*shape=*/nullptr, &handle));
handle = ctx->builder()->Sub(handle, ctx->Input(1));
@@ -105,19 +105,19 @@ class ResourceGatherOp : public XlaOpKernel {
public:
explicit ResourceGatherOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
- xla::ComputationBuilder* builder = ctx->builder();
+ xla::XlaBuilder* builder = ctx->builder();
DataType type = ctx->expected_output_dtype(0);
TensorShape resource_shape;
- xla::ComputationDataHandle resource_handle;
+ xla::XlaOp resource_handle;
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &resource_shape,
&resource_handle));
auto indices = ctx->Input(1);
auto indices_shape = ctx->InputShape(1);
DataType index_type = ctx->input_type(1);
- xla::ComputationDataHandle gather;
+ xla::XlaOp gather;
OP_REQUIRES_OK(
ctx, XlaGather(resource_handle, resource_shape, indices, indices_shape,
/*axis=*/0, /*indices_are_nd=*/false, type, index_type,
diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.cc b/tensorflow/compiler/tf2xla/kernels/while_op.cc
index 0ff1b65ae9..5467c5d994 100644
--- a/tensorflow/compiler/tf2xla/kernels/while_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/while_op.cc
@@ -21,7 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -101,7 +101,7 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
ctx, MakeXlaCompilerArgumentsFromInputs(
ctx, &arguments, &has_uninitialized_vars, &has_tensor_arrays));
- xla::ComputationBuilder* builder = ctx->builder();
+ xla::XlaBuilder* builder = ctx->builder();
XlaCompiler* compiler = ctx->compiler();
VLOG(1) << "Compiling body";
@@ -234,7 +234,7 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
xla::ShapeUtil::HumanString(cond.xla_output_shape)));
int num_inputs = body.input_mapping.size();
- std::vector<xla::ComputationDataHandle> inputs(num_inputs);
+ std::vector<xla::XlaOp> inputs(num_inputs);
for (int i = 0; i < num_inputs; ++i) {
int input_num = body.input_mapping[i];
if (ctx->input_type(input_num) == DT_RESOURCE) {
@@ -246,24 +246,24 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
}
}
- xla::ComputationDataHandle init = builder->Tuple(inputs);
+ xla::XlaOp init = builder->Tuple(inputs);
VLOG(1) << "Building while loop";
// Wraps the condition in a computation that unpacks the output tuple.
- xla::Computation cond_wrapper;
+ xla::XlaComputation cond_wrapper;
{
- std::unique_ptr<xla::ComputationBuilder> cb =
+ std::unique_ptr<xla::XlaBuilder> cb =
builder->CreateSubBuilder("cond_wrapper");
auto inputs = cb->Parameter(0, cond_input_shape, "inputs");
auto outputs = cb->Call(*cond.computation, {inputs});
cb->GetTupleElement(outputs, 0);
- xla::StatusOr<xla::Computation> result = cb->Build();
+ xla::StatusOr<xla::XlaComputation> result = cb->Build();
OP_REQUIRES_OK(ctx, result.status());
cond_wrapper = std::move(result.ValueOrDie());
}
- xla::ComputationDataHandle while_result =
+ xla::XlaOp while_result =
builder->While(cond_wrapper, *body.computation, init);
// Sets non-variable outputs.
diff --git a/tensorflow/compiler/tf2xla/lib/BUILD b/tensorflow/compiler/tf2xla/lib/BUILD
index 12fdfb605d..04ad3694a0 100644
--- a/tensorflow/compiler/tf2xla/lib/BUILD
+++ b/tensorflow/compiler/tf2xla/lib/BUILD
@@ -25,8 +25,8 @@ cc_library(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
- "//tensorflow/compiler/xla/client:computation",
- "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/core:lib",
],
)
@@ -44,8 +44,8 @@ cc_library(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
- "//tensorflow/compiler/xla/client:computation",
- "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/core:lib",
],
)
@@ -62,9 +62,9 @@ cc_library(
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
- "//tensorflow/compiler/xla/client:computation",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client/lib:arithmetic",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/core:lib",
],
)
@@ -82,8 +82,8 @@ cc_library(
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
- "//tensorflow/compiler/xla/client:computation",
- "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/core:lib",
],
)
@@ -101,9 +101,9 @@ xla_test(
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
@@ -122,8 +122,8 @@ cc_library(
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
- "//tensorflow/compiler/xla/client:computation",
- "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/core:lib",
],
)
@@ -161,8 +161,8 @@ cc_library(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
- "//tensorflow/compiler/xla/client:computation",
- "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/core:lib",
],
)
diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.cc b/tensorflow/compiler/tf2xla/lib/batch_dot.cc
index 798f0fa780..526694d5a0 100644
--- a/tensorflow/compiler/tf2xla/lib/batch_dot.cc
+++ b/tensorflow/compiler/tf2xla/lib/batch_dot.cc
@@ -25,24 +25,22 @@ limitations under the License.
namespace tensorflow {
-xla::StatusOr<xla::ComputationDataHandle> BatchDot(
- xla::ComputationBuilder* builder, xla::ComputationDataHandle x,
- xla::ComputationDataHandle y, bool transpose_x, bool transpose_y,
- bool conjugate_x, bool conjugate_y) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> x_shape,
- builder->GetShape(x));
- TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> y_shape,
- builder->GetShape(y));
+xla::StatusOr<xla::XlaOp> BatchDot(xla::XlaBuilder* builder, xla::XlaOp x,
+ xla::XlaOp y, bool transpose_x,
+ bool transpose_y, bool conjugate_x,
+ bool conjugate_y) {
+ TF_ASSIGN_OR_RETURN(xla::Shape x_shape, builder->GetShape(x));
+ TF_ASSIGN_OR_RETURN(xla::Shape y_shape, builder->GetShape(y));
// Check that both tensors have the same number of dimensions. There must be
// at least two (the batch dimensions can be empty).
- if (xla::ShapeUtil::Rank(*x_shape) != xla::ShapeUtil::Rank(*y_shape)) {
+ if (xla::ShapeUtil::Rank(x_shape) != xla::ShapeUtil::Rank(y_shape)) {
return errors::InvalidArgument(
"Arguments to BatchedDot have different ranks: ",
- xla::ShapeUtil::HumanString(*x_shape), " vs. ",
- xla::ShapeUtil::HumanString(*y_shape));
+ xla::ShapeUtil::HumanString(x_shape), " vs. ",
+ xla::ShapeUtil::HumanString(y_shape));
}
- const int ndims = xla::ShapeUtil::Rank(*x_shape);
+ const int ndims = xla::ShapeUtil::Rank(x_shape);
if (ndims < 2) {
return errors::InvalidArgument(
"Arguments to BatchedDot must have rank >= 2: ", ndims);
@@ -52,46 +50,46 @@ xla::StatusOr<xla::ComputationDataHandle> BatchDot(
// valid.
std::vector<int64> batch_dimension_numbers;
for (int i = 0; i < ndims - 2; ++i) {
- if (x_shape->dimensions(i) != y_shape->dimensions(i)) {
+ if (x_shape.dimensions(i) != y_shape.dimensions(i)) {
return errors::InvalidArgument(
"Dimension ", i, " of inputs to BatchedDot must be equal: ",
- xla::ShapeUtil::HumanString(*x_shape), " vs ",
- xla::ShapeUtil::HumanString(*y_shape));
+ xla::ShapeUtil::HumanString(x_shape), " vs ",
+ xla::ShapeUtil::HumanString(y_shape));
}
batch_dimension_numbers.push_back(i);
}
int x_inner_dim = transpose_x ? (ndims - 2) : (ndims - 1);
int y_inner_dim = transpose_y ? (ndims - 1) : (ndims - 2);
- if (x_shape->dimensions(x_inner_dim) != y_shape->dimensions(y_inner_dim)) {
+ if (x_shape.dimensions(x_inner_dim) != y_shape.dimensions(y_inner_dim)) {
return errors::InvalidArgument(
"Dimensions ", x_inner_dim, " and ", y_inner_dim,
" of arguments to BatchedDot must be equal: ",
- xla::ShapeUtil::HumanString(*x_shape), " transpose: ", transpose_x,
- " vs. ", xla::ShapeUtil::HumanString(*y_shape),
+ xla::ShapeUtil::HumanString(x_shape), " transpose: ", transpose_x,
+ " vs. ", xla::ShapeUtil::HumanString(y_shape),
" transpose: ", transpose_y);
}
// Check for zero lhs/rhs dim size.
- if (xla::ShapeUtil::HasZeroElements(*x_shape) ||
- xla::ShapeUtil::HasZeroElements(*y_shape)) {
+ if (xla::ShapeUtil::HasZeroElements(x_shape) ||
+ xla::ShapeUtil::HasZeroElements(y_shape)) {
std::vector<int64> dimensions(batch_dimension_numbers.size());
for (int i = 0; i < batch_dimension_numbers.size(); ++i) {
- dimensions[i] = x_shape->dimensions(batch_dimension_numbers[i]);
+ dimensions[i] = x_shape.dimensions(batch_dimension_numbers[i]);
}
int x_outer_dim = transpose_x ? (ndims - 1) : (ndims - 2);
int y_outer_dim = transpose_y ? (ndims - 2) : (ndims - 1);
- dimensions.push_back(x_shape->dimensions(x_outer_dim));
- dimensions.push_back(y_shape->dimensions(y_outer_dim));
+ dimensions.push_back(x_shape.dimensions(x_outer_dim));
+ dimensions.push_back(y_shape.dimensions(y_outer_dim));
return builder->Broadcast(
- builder->ConstantLiteral(xla::Literal::Zero(x_shape->element_type())),
+ builder->ConstantLiteral(xla::Literal::Zero(x_shape.element_type())),
dimensions);
}
- if (x_shape->element_type() == xla::C64 && conjugate_x) {
+ if (x_shape.element_type() == xla::C64 && conjugate_x) {
x = builder->Conj(x);
}
- if (y_shape->element_type() == xla::C64 && conjugate_y) {
+ if (y_shape.element_type() == xla::C64 && conjugate_y) {
y = builder->Conj(y);
}
diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.h b/tensorflow/compiler/tf2xla/lib/batch_dot.h
index b230e885f1..1acc72033b 100644
--- a/tensorflow/compiler/tf2xla/lib/batch_dot.h
+++ b/tensorflow/compiler/tf2xla/lib/batch_dot.h
@@ -16,8 +16,8 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_BATCH_DOT_H_
#define TENSORFLOW_COMPILER_TF2XLA_LIB_BATCH_DOT_H_
-#include "tensorflow/compiler/xla/client/computation.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
namespace tensorflow {
@@ -43,10 +43,10 @@ namespace tensorflow {
// It is computed as:
//
// output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :])
-xla::StatusOr<xla::ComputationDataHandle> BatchDot(
- xla::ComputationBuilder* builder, xla::ComputationDataHandle x,
- xla::ComputationDataHandle y, bool transpose_x, bool transpose_y,
- bool conjugate_x = false, bool conjugate_y = false);
+xla::StatusOr<xla::XlaOp> BatchDot(xla::XlaBuilder* builder, xla::XlaOp x,
+ xla::XlaOp y, bool transpose_x,
+ bool transpose_y, bool conjugate_x = false,
+ bool conjugate_y = false);
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.cc b/tensorflow/compiler/tf2xla/lib/cholesky.cc
index 203365e2ab..83e7382786 100644
--- a/tensorflow/compiler/tf2xla/lib/cholesky.cc
+++ b/tensorflow/compiler/tf2xla/lib/cholesky.cc
@@ -47,23 +47,21 @@ namespace {
// l[..., j+1:, j] = (a[..., j+1:, j] - np.dot(l[..., j+1:, :j], row_t)) /
// l[..., j, j]
// return l
-xla::StatusOr<xla::ComputationDataHandle> CholeskyUnblocked(
- xla::ComputationBuilder* builder, const xla::ComputationDataHandle& a) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> a_shape,
- builder->GetShape(a));
- const int n_dims = xla::ShapeUtil::Rank(*a_shape);
- const int64 n = xla::ShapeUtil::GetDimension(*a_shape, -1);
- gtl::ArraySlice<int64> major_dims(xla::AsInt64Slice(a_shape->dimensions()),
+xla::StatusOr<xla::XlaOp> CholeskyUnblocked(xla::XlaBuilder* builder,
+ const xla::XlaOp& a) {
+ TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a));
+ const int n_dims = xla::ShapeUtil::Rank(a_shape);
+ const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1);
+ gtl::ArraySlice<int64> major_dims(xla::AsInt64Slice(a_shape.dimensions()),
/*pos=*/0,
/*len=*/n_dims - 2);
- xla::ComputationDataHandle l = Zeros(builder, *a_shape);
+ xla::XlaOp l = Zeros(builder, a_shape);
// Construct the for loop body to iterate over rows.
- auto body_fn = [&](xla::ComputationDataHandle i,
- gtl::ArraySlice<xla::ComputationDataHandle> loop_vars,
- xla::ComputationBuilder* body_builder)
- -> xla::StatusOr<std::vector<xla::ComputationDataHandle>> {
+ auto body_fn = [&](xla::XlaOp i, gtl::ArraySlice<xla::XlaOp> loop_vars,
+ xla::XlaBuilder* body_builder)
+ -> xla::StatusOr<std::vector<xla::XlaOp>> {
xla::Shape col_shape;
xla::Shape row_shape;
for (int64 d : major_dims) {
@@ -72,12 +70,12 @@ xla::StatusOr<xla::ComputationDataHandle> CholeskyUnblocked(
}
row_shape.add_dimensions(1);
row_shape.add_dimensions(n);
- row_shape.set_element_type(a_shape->element_type());
+ row_shape.set_element_type(a_shape.element_type());
auto mask_zeros_row = Zeros(body_builder, row_shape);
col_shape.add_dimensions(n);
col_shape.add_dimensions(1);
- col_shape.set_element_type(a_shape->element_type());
+ col_shape.set_element_type(a_shape.element_type());
auto mask_zeros_col = Zeros(body_builder, col_shape);
std::vector<int32> mask_vector(n);
@@ -101,7 +99,7 @@ xla::StatusOr<xla::ComputationDataHandle> CholeskyUnblocked(
TF_ASSIGN_OR_RETURN(auto a_ii, DynamicSliceInMinorDims(body_builder, body_a,
{i, i}, {1, 1}));
// np.dot(row, np.swapaxes(row, -1, -2))
- xla::ComputationDataHandle diag_dot;
+ xla::XlaOp diag_dot;
TF_ASSIGN_OR_RETURN(diag_dot, BatchDot(body_builder, row, row,
/*transpose_x=*/false,
/*transpose_y=*/true));
@@ -109,7 +107,7 @@ xla::StatusOr<xla::ComputationDataHandle> CholeskyUnblocked(
// np.swapaxes(row, -1, -2)))
auto l_ii = body_builder->Pow(
body_builder->Sub(a_ii, diag_dot),
- FloatLiteral(body_builder, a_shape->element_type(), 0.5));
+ FloatLiteral(body_builder, a_shape.element_type(), 0.5));
// a[..., i+1:, i]
auto ip1 = body_builder->Add(i, body_builder->ConstantR0<int32>(1));
@@ -140,7 +138,7 @@ xla::StatusOr<xla::ComputationDataHandle> CholeskyUnblocked(
TF_ASSIGN_OR_RETURN(body_l, DynamicUpdateSliceInMinorDims(
body_builder, body_l, l_ii, {i, i}));
- return std::vector<xla::ComputationDataHandle>{body_a, body_l};
+ return std::vector<xla::XlaOp>{body_a, body_l};
};
TF_ASSIGN_OR_RETURN(
@@ -152,22 +150,20 @@ xla::StatusOr<xla::ComputationDataHandle> CholeskyUnblocked(
} // namespace
-xla::StatusOr<xla::ComputationDataHandle> Cholesky(
- xla::ComputationBuilder* builder, xla::ComputationDataHandle a,
- int64 block_size) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> a_shape,
- builder->GetShape(a));
- const int ndims = xla::ShapeUtil::Rank(*a_shape);
+xla::StatusOr<xla::XlaOp> Cholesky(xla::XlaBuilder* builder, xla::XlaOp a,
+ int64 block_size) {
+ TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a));
+ const int ndims = xla::ShapeUtil::Rank(a_shape);
if (ndims < 2) {
return errors::InvalidArgument(
"Arguments to Cholesky must have rank >= 2: ", ndims);
}
- const int64 n = xla::ShapeUtil::GetDimension(*a_shape, -1);
- if (n != xla::ShapeUtil::GetDimension(*a_shape, -2)) {
+ const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1);
+ if (n != xla::ShapeUtil::GetDimension(a_shape, -2)) {
return errors::InvalidArgument(
"Arguments to Cholesky must be square matrices: ",
- xla::ShapeUtil::HumanString(*a_shape));
+ xla::ShapeUtil::HumanString(a_shape));
}
if (block_size < 1) {
@@ -179,7 +175,7 @@ xla::StatusOr<xla::ComputationDataHandle> Cholesky(
// Algorithm 1 from
// Haidar, Azzam, et al. "High-performance Cholesky factorization for GPU-only
// execution." Proceedings of General Purpose GPUs. ACM, 2017.
- xla::ComputationDataHandle l = Zeros(builder, *a_shape);
+ xla::XlaOp l = Zeros(builder, a_shape);
for (int64 i = 0; i < n; i += block_size) {
int64 k = std::min(block_size, n - i);
if (i > 0) {
diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.h b/tensorflow/compiler/tf2xla/lib/cholesky.h
index 17da8d8b22..20fca7969e 100644
--- a/tensorflow/compiler/tf2xla/lib/cholesky.h
+++ b/tensorflow/compiler/tf2xla/lib/cholesky.h
@@ -16,8 +16,8 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_CHOLESKY_H_
#define TENSORFLOW_COMPILER_TF2XLA_LIB_CHOLESKY_H_
-#include "tensorflow/compiler/xla/client/computation.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
namespace tensorflow {
@@ -30,9 +30,8 @@ namespace tensorflow {
// TODO(phawkins): check for negative values on the diagonal and return an
// error, instead of silently yielding NaNs.
// TODO(znado): handle the complex Hermitian case
-xla::StatusOr<xla::ComputationDataHandle> Cholesky(
- xla::ComputationBuilder* builder, xla::ComputationDataHandle a,
- int64 block_size = 256);
+xla::StatusOr<xla::XlaOp> Cholesky(xla::XlaBuilder* builder, xla::XlaOp a,
+ int64 block_size = 256);
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/lib/scatter.cc b/tensorflow/compiler/tf2xla/lib/scatter.cc
index 45699233ea..d5a27abb25 100644
--- a/tensorflow/compiler/tf2xla/lib/scatter.cc
+++ b/tensorflow/compiler/tf2xla/lib/scatter.cc
@@ -30,24 +30,19 @@ limitations under the License.
namespace tensorflow {
-xla::StatusOr<xla::ComputationDataHandle> XlaScatter(
- const xla::ComputationDataHandle& buffer,
- const xla::ComputationDataHandle& updates,
- const xla::ComputationDataHandle& indices, bool indices_are_vectors,
- const std::function<xla::ComputationDataHandle(
- xla::ComputationDataHandle, xla::ComputationDataHandle,
- xla::ComputationBuilder*)>& combiner,
- xla::ComputationBuilder* builder) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> buffer_shape,
- builder->GetShape(buffer));
- TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> updates_shape,
- builder->GetShape(updates));
- TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> indices_shape,
- builder->GetShape(indices));
+xla::StatusOr<xla::XlaOp> XlaScatter(
+ const xla::XlaOp& buffer, const xla::XlaOp& updates,
+ const xla::XlaOp& indices, bool indices_are_vectors,
+ const std::function<xla::XlaOp(xla::XlaOp, xla::XlaOp, xla::XlaBuilder*)>&
+ combiner,
+ xla::XlaBuilder* builder) {
+ TF_ASSIGN_OR_RETURN(xla::Shape buffer_shape, builder->GetShape(buffer));
+ TF_RETURN_IF_ERROR(builder->GetShape(updates).status());
+ TF_ASSIGN_OR_RETURN(xla::Shape indices_shape, builder->GetShape(indices));
gtl::ArraySlice<int64> indices_dims =
- xla::AsInt64Slice(indices_shape->dimensions());
+ xla::AsInt64Slice(indices_shape.dimensions());
gtl::ArraySlice<int64> buffer_dims =
- xla::AsInt64Slice(buffer_shape->dimensions());
+ xla::AsInt64Slice(buffer_shape.dimensions());
// If the indices are N-dimensional, the minor dimension of indices contains
// the indices to update. Otherwise the indices are all scalars.
@@ -55,12 +50,12 @@ xla::StatusOr<xla::ComputationDataHandle> XlaScatter(
if (indices_are_vectors) {
TF_RET_CHECK(!indices_dims.empty());
num_index_dims = indices_dims.back();
- if (num_index_dims > xla::ShapeUtil::Rank(*buffer_shape)) {
+ if (num_index_dims > xla::ShapeUtil::Rank(buffer_shape)) {
return errors::InvalidArgument(
"The size of the minor dimension of the indices (shape: ",
- xla::ShapeUtil::HumanString(*indices_shape),
+ xla::ShapeUtil::HumanString(indices_shape),
") must be <= the rank of the buffer (shape: ",
- xla::ShapeUtil::HumanString(*buffer_shape), ")");
+ xla::ShapeUtil::HumanString(buffer_shape), ")");
}
indices_dims.pop_back();
}
@@ -78,10 +73,10 @@ xla::StatusOr<xla::ComputationDataHandle> XlaScatter(
// If any of the indexed dimensions are zero in the buffer, the update cannot
// succeed since it updates a slice of size 1.
for (int64 i = 0; i < num_index_dims; ++i) {
- if (xla::ShapeUtil::GetDimension(*buffer_shape, i) == 0) {
- return errors::InvalidArgument(
- "Scatter dimension ", i, " is of size zero in tensor with shape ",
- xla::ShapeUtil::HumanString(*buffer_shape));
+ if (xla::ShapeUtil::GetDimension(buffer_shape, i) == 0) {
+ return errors::InvalidArgument("Scatter dimension ", i,
+ " is of size zero in tensor with shape ",
+ xla::ShapeUtil::HumanString(buffer_shape));
}
}
@@ -111,18 +106,17 @@ xla::StatusOr<xla::ComputationDataHandle> XlaScatter(
// index = dynamic-slice(indices, i)
// update = dynamic-slice(updates, i)
// buffer = dynamic-update-slice(buffer, update, index)
- auto body_fn = [&](xla::ComputationDataHandle i,
- gtl::ArraySlice<xla::ComputationDataHandle> loop_vars,
- xla::ComputationBuilder* body_builder) {
+ auto body_fn = [&](xla::XlaOp i, gtl::ArraySlice<xla::XlaOp> loop_vars,
+ xla::XlaBuilder* body_builder) {
auto indices = loop_vars[0];
auto updates = loop_vars[1];
auto buffer = loop_vars[2];
auto zero_index = body_builder->ConstantLiteral(
- xla::Literal::Zero(indices_shape->element_type()));
+ xla::Literal::Zero(indices_shape.element_type()));
// Slice the i-th index from the indices array.
- xla::ComputationDataHandle index;
+ xla::XlaOp index;
auto indices_offset = body_builder->Reshape(i, {1});
if (indices_are_vectors) {
indices_offset = body_builder->Pad(indices_offset, zero_index,
@@ -180,12 +174,12 @@ xla::StatusOr<xla::ComputationDataHandle> XlaScatter(
// Apply the update.
buffer = body_builder->DynamicUpdateSlice(buffer, update, index);
- return std::vector<xla::ComputationDataHandle>{indices, updates, buffer};
+ return std::vector<xla::XlaOp>{indices, updates, buffer};
};
- TF_ASSIGN_OR_RETURN(
- auto outputs, XlaForEachIndex(num_indices, indices_shape->element_type(),
- body_fn, init, "scatter", builder));
+ TF_ASSIGN_OR_RETURN(auto outputs,
+ XlaForEachIndex(num_indices, indices_shape.element_type(),
+ body_fn, init, "scatter", builder));
return outputs[2];
}
diff --git a/tensorflow/compiler/tf2xla/lib/scatter.h b/tensorflow/compiler/tf2xla/lib/scatter.h
index 41e6d3b195..87309e10ed 100644
--- a/tensorflow/compiler/tf2xla/lib/scatter.h
+++ b/tensorflow/compiler/tf2xla/lib/scatter.h
@@ -18,8 +18,8 @@ limitations under the License.
#include <functional>
-#include "tensorflow/compiler/xla/client/computation.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
#include "tensorflow/compiler/xla/statusor.h"
namespace tensorflow {
@@ -39,14 +39,12 @@ namespace tensorflow {
// If a `combiner` is provided, updates are combined with the existing values in
// the buffer using the combiner function. Otherwise, the updates replace the
// existing values. The order of updates is implementation-defined.
-xla::StatusOr<xla::ComputationDataHandle> XlaScatter(
- const xla::ComputationDataHandle& buffer,
- const xla::ComputationDataHandle& updates,
- const xla::ComputationDataHandle& indices, bool indices_are_vectors,
- const std::function<xla::ComputationDataHandle(
- xla::ComputationDataHandle, xla::ComputationDataHandle,
- xla::ComputationBuilder*)>& combiner,
- xla::ComputationBuilder* builder);
+xla::StatusOr<xla::XlaOp> XlaScatter(
+ const xla::XlaOp& buffer, const xla::XlaOp& updates,
+ const xla::XlaOp& indices, bool indices_are_vectors,
+ const std::function<xla::XlaOp(xla::XlaOp, xla::XlaOp, xla::XlaBuilder*)>&
+ combiner,
+ xla::XlaBuilder* builder);
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc
index 9bf5821b54..d0279d4412 100644
--- a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc
+++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc
@@ -29,21 +29,20 @@ limitations under the License.
namespace tensorflow {
-xla::StatusOr<xla::ComputationDataHandle> TriangularSolve(
- xla::ComputationBuilder* builder, const xla::ComputationDataHandle& a,
- xla::ComputationDataHandle b, bool left_side, bool lower, bool transpose_a,
- bool conjugate_a, int64 block_size) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> a_shape,
- builder->GetShape(a));
- TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> b_shape,
- builder->GetShape(b));
- if (xla::ShapeUtil::Rank(*a_shape) != xla::ShapeUtil::Rank(*b_shape)) {
+xla::StatusOr<xla::XlaOp> TriangularSolve(xla::XlaBuilder* builder,
+ const xla::XlaOp& a, xla::XlaOp b,
+ bool left_side, bool lower,
+ bool transpose_a, bool conjugate_a,
+ int64 block_size) {
+ TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a));
+ TF_ASSIGN_OR_RETURN(xla::Shape b_shape, builder->GetShape(b));
+ if (xla::ShapeUtil::Rank(a_shape) != xla::ShapeUtil::Rank(b_shape)) {
return errors::InvalidArgument(
"Arguments to TriangularSolve have different ranks: ",
- xla::ShapeUtil::HumanString(*a_shape), " vs. ",
- xla::ShapeUtil::HumanString(*b_shape));
+ xla::ShapeUtil::HumanString(a_shape), " vs. ",
+ xla::ShapeUtil::HumanString(b_shape));
}
- const int ndims = xla::ShapeUtil::Rank(*a_shape);
+ const int ndims = xla::ShapeUtil::Rank(a_shape);
if (ndims < 2) {
return errors::InvalidArgument(
"Arguments to TriangularSolve must have rank >= 2: ", ndims);
@@ -51,30 +50,30 @@ xla::StatusOr<xla::ComputationDataHandle> TriangularSolve(
// The batch dimensions must be equal.
std::vector<int64> batch_dimensions;
for (int i = 0; i < ndims - 2; ++i) {
- int64 a_size = a_shape->dimensions(i);
- int64 b_size = b_shape->dimensions(i);
+ int64 a_size = a_shape.dimensions(i);
+ int64 b_size = b_shape.dimensions(i);
if (a_size != b_size) {
return errors::InvalidArgument(
"Batch dimensions of arguments to TriangularSolve must be equal: ",
- xla::ShapeUtil::HumanString(*a_shape), " vs ",
- xla::ShapeUtil::HumanString(*b_shape));
+ xla::ShapeUtil::HumanString(a_shape), " vs ",
+ xla::ShapeUtil::HumanString(b_shape));
}
batch_dimensions.push_back(a_size);
}
- if (xla::ShapeUtil::GetDimension(*a_shape, -1) !=
- xla::ShapeUtil::GetDimension(*a_shape, -2)) {
+ if (xla::ShapeUtil::GetDimension(a_shape, -1) !=
+ xla::ShapeUtil::GetDimension(a_shape, -2)) {
return errors::InvalidArgument(
"The 'a' arguments to TriangularSolve must be square matrices: ",
- xla::ShapeUtil::HumanString(*a_shape));
+ xla::ShapeUtil::HumanString(a_shape));
}
- const int64 m = xla::ShapeUtil::GetDimension(*b_shape, -2);
- const int64 n = xla::ShapeUtil::GetDimension(*b_shape, -1);
- if ((left_side ? m : n) != xla::ShapeUtil::GetDimension(*a_shape, -1)) {
+ const int64 m = xla::ShapeUtil::GetDimension(b_shape, -2);
+ const int64 n = xla::ShapeUtil::GetDimension(b_shape, -1);
+ if ((left_side ? m : n) != xla::ShapeUtil::GetDimension(a_shape, -1)) {
return errors::InvalidArgument(
"Arguments to TriangularSolve have incompatible matrix shapes: ",
- xla::ShapeUtil::HumanString(*a_shape), " vs ",
- xla::ShapeUtil::HumanString(*b_shape));
+ xla::ShapeUtil::HumanString(a_shape), " vs ",
+ xla::ShapeUtil::HumanString(b_shape));
}
if (block_size < 1) {
@@ -85,24 +84,23 @@ xla::StatusOr<xla::ComputationDataHandle> TriangularSolve(
// Applies a complex conjugation operation if `a` is complex and `conjugate_a`
// is true, otherwise returns its argument.
- auto maybe_conj = [&](xla::ComputationBuilder* builder,
- xla::ComputationDataHandle x) {
- auto perform_conj = a_shape->element_type() == xla::C64 && conjugate_a;
+ auto maybe_conj = [&](xla::XlaBuilder* builder, xla::XlaOp x) {
+ auto perform_conj = a_shape.element_type() == xla::C64 && conjugate_a;
return perform_conj ? builder->Conj(x) : x;
};
- std::map<int, xla::Computation> base_computations;
+ std::map<int, xla::XlaComputation> base_computations;
auto get_base_triangular_solve =
- [&](int k) -> xla::StatusOr<xla::Computation*> {
- xla::Computation& computation = base_computations[k];
+ [&](int k) -> xla::StatusOr<xla::XlaComputation*> {
+ xla::XlaComputation& computation = base_computations[k];
if (computation.IsNull()) {
- std::unique_ptr<xla::ComputationBuilder> sub = builder->CreateSubBuilder(
+ std::unique_ptr<xla::XlaBuilder> sub = builder->CreateSubBuilder(
tensorflow::strings::StrCat("trsm_base_", k));
auto a_param = sub->Parameter(
0,
xla::ShapeUtil::MakeShape(
- b_shape->element_type(),
+ b_shape.element_type(),
PrependMajorDims(sub.get(), batch_dimensions, {k, k})),
"a");
@@ -115,7 +113,7 @@ xla::StatusOr<xla::ComputationDataHandle> TriangularSolve(
auto b_param = sub->Parameter(
1,
xla::ShapeUtil::MakeShape(
- b_shape->element_type(),
+ b_shape.element_type(),
PrependMajorDims(sub.get(), batch_dimensions, b_lastd)),
"b");
@@ -142,7 +140,7 @@ xla::StatusOr<xla::ComputationDataHandle> TriangularSolve(
return &computation;
};
- xla::ComputationDataHandle output = Zeros(builder, *b_shape);
+ xla::XlaOp output = Zeros(builder, b_shape);
// Right-looking blocked triangular solve.
// For an explanation of the algorithm, see the TRSM discussion in:
@@ -165,9 +163,9 @@ xla::StatusOr<xla::ComputationDataHandle> TriangularSolve(
SliceInMinorDims(builder, a, {i, i}, {i + k, i + k}));
TF_ASSIGN_OR_RETURN(auto b_slice,
SliceInMinorDims(builder, b, {0, i}, {m, i + k}));
- xla::ComputationDataHandle update;
+ xla::XlaOp update;
if (k > 1) {
- TF_ASSIGN_OR_RETURN(xla::Computation * solve,
+ TF_ASSIGN_OR_RETURN(xla::XlaComputation * solve,
get_base_triangular_solve(k));
update = builder->Call(*solve, {a_slice, b_slice});
} else {
@@ -181,7 +179,7 @@ xla::StatusOr<xla::ComputationDataHandle> TriangularSolve(
// a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2
// b[..., :, i+k:] -= np.matmul(output[..., :, i:i+k], a_slice_2)
if (i + k < n) {
- xla::ComputationDataHandle a_slice_2;
+ xla::XlaOp a_slice_2;
if (lower) {
TF_ASSIGN_OR_RETURN(
a_slice_2, SliceInMinorDims(builder, a, {i + k, i}, {n, i + k}));
@@ -215,9 +213,9 @@ xla::StatusOr<xla::ComputationDataHandle> TriangularSolve(
SliceInMinorDims(builder, a, {i, i}, {i + k, i + k}));
TF_ASSIGN_OR_RETURN(auto b_slice,
SliceInMinorDims(builder, b, {i, 0}, {i + k, n}));
- xla::ComputationDataHandle update;
+ xla::XlaOp update;
if (k > 1) {
- TF_ASSIGN_OR_RETURN(xla::Computation * solve,
+ TF_ASSIGN_OR_RETURN(xla::XlaComputation * solve,
get_base_triangular_solve(k));
update = builder->Call(*solve, {a_slice, b_slice});
} else {
@@ -231,7 +229,7 @@ xla::StatusOr<xla::ComputationDataHandle> TriangularSolve(
// a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2
// b[..., i+k:, :] -= np.matmul(a_slice_2, output[..., i:i+k, :])
if (i + k < m) {
- xla::ComputationDataHandle a_slice_2;
+ xla::XlaOp a_slice_2;
if (lower) {
TF_ASSIGN_OR_RETURN(
a_slice_2, SliceInMinorDims(builder, a, {i + k, i}, {m, i + k}));
@@ -264,9 +262,9 @@ xla::StatusOr<xla::ComputationDataHandle> TriangularSolve(
SliceInMinorDims(builder, a, {i, i}, {i + k, i + k}));
TF_ASSIGN_OR_RETURN(auto b_slice,
SliceInMinorDims(builder, b, {0, i}, {m, i + k}));
- xla::ComputationDataHandle update;
+ xla::XlaOp update;
if (k > 1) {
- TF_ASSIGN_OR_RETURN(xla::Computation * solve,
+ TF_ASSIGN_OR_RETURN(xla::XlaComputation * solve,
get_base_triangular_solve(k));
update = builder->Call(*solve, {a_slice, b_slice});
} else {
@@ -280,7 +278,7 @@ xla::StatusOr<xla::ComputationDataHandle> TriangularSolve(
// a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2
// b[..., :, :i] -= np.matmul(out[..., :, i:i+k], a_slice_2)
if (i - k >= 0) {
- xla::ComputationDataHandle a_slice_2;
+ xla::XlaOp a_slice_2;
if (lower) {
TF_ASSIGN_OR_RETURN(a_slice_2,
SliceInMinorDims(builder, a, {i, 0}, {i + k, i}));
@@ -314,9 +312,9 @@ xla::StatusOr<xla::ComputationDataHandle> TriangularSolve(
SliceInMinorDims(builder, a, {i, i}, {i + k, i + k}));
TF_ASSIGN_OR_RETURN(auto b_slice,
SliceInMinorDims(builder, b, {i, 0}, {i + k, n}));
- xla::ComputationDataHandle update;
+ xla::XlaOp update;
if (k > 1) {
- TF_ASSIGN_OR_RETURN(xla::Computation * solve,
+ TF_ASSIGN_OR_RETURN(xla::XlaComputation * solve,
get_base_triangular_solve(k));
update = builder->Call(*solve, {a_slice, b_slice});
} else {
@@ -330,7 +328,7 @@ xla::StatusOr<xla::ComputationDataHandle> TriangularSolve(
// a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2
// b[..., :i, :] -= np.matmul(a_slice_2, out[..., i:i+k, :])
if (i - k >= 0) {
- xla::ComputationDataHandle a_slice_2;
+ xla::XlaOp a_slice_2;
if (lower) {
TF_ASSIGN_OR_RETURN(a_slice_2,
SliceInMinorDims(builder, a, {i, 0}, {i + k, i}));
@@ -356,26 +354,25 @@ xla::StatusOr<xla::ComputationDataHandle> TriangularSolve(
return output;
}
-xla::StatusOr<xla::ComputationDataHandle> TriangularSolveLeftLooking(
- xla::ComputationBuilder* builder, const xla::ComputationDataHandle& a,
- const xla::ComputationDataHandle& b, bool transpose_a, bool conjugate_a) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> a_shape,
- builder->GetShape(a));
- TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> b_shape,
- builder->GetShape(b));
- const int64 m = xla::ShapeUtil::GetDimension(*b_shape, -2);
- const int64 n = xla::ShapeUtil::GetDimension(*b_shape, -1);
- const int64 ndims = xla::ShapeUtil::Rank(*a_shape);
+xla::StatusOr<xla::XlaOp> TriangularSolveLeftLooking(xla::XlaBuilder* builder,
+ const xla::XlaOp& a,
+ const xla::XlaOp& b,
+ bool transpose_a,
+ bool conjugate_a) {
+ TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a));
+ TF_ASSIGN_OR_RETURN(xla::Shape b_shape, builder->GetShape(b));
+ const int64 m = xla::ShapeUtil::GetDimension(b_shape, -2);
+ const int64 n = xla::ShapeUtil::GetDimension(b_shape, -1);
+ const int64 ndims = xla::ShapeUtil::Rank(a_shape);
std::vector<int64> batch_dimensions;
for (int i = 0; i < ndims - 2; ++i) {
- int64 a_size = a_shape->dimensions(i);
+ int64 a_size = a_shape.dimensions(i);
batch_dimensions.push_back(a_size);
}
- auto maybe_conj = [&](xla::ComputationBuilder* builder,
- xla::ComputationDataHandle x) {
- auto perform_conj = a_shape->element_type() == xla::C64 && conjugate_a;
+ auto maybe_conj = [&](xla::XlaBuilder* builder, xla::XlaOp x) {
+ auto perform_conj = a_shape.element_type() == xla::C64 && conjugate_a;
return perform_conj ? builder->Conj(x) : x;
};
@@ -387,7 +384,7 @@ xla::StatusOr<xla::ComputationDataHandle> TriangularSolveLeftLooking(
// output[..., m-1:, :] = b[..., m-1:, :] / a[..., m-1:, m-1:]
// else:
// output[..., :1, :] = b[..., :1, :] / a[..., :1, :1]
- xla::ComputationDataHandle output = Zeros(builder, *b_shape);
+ xla::XlaOp output = Zeros(builder, b_shape);
{
auto i = transpose_a ? m - 1 : 0;
TF_ASSIGN_OR_RETURN(auto a_slice,
@@ -408,11 +405,11 @@ xla::StatusOr<xla::ComputationDataHandle> TriangularSolveLeftLooking(
// The loop iteration counter is a scalar, incremented each iteration.
xla::ShapeUtil::MakeShape(xla::S32, {}),
// The output has the shape of b, with one row updated each iteration.
- *b_shape,
+ b_shape,
// The coefficient matrix a is a loop invariant.
- *a_shape,
+ a_shape,
// The right-hand-side matrix b is a loop invariant.
- *b_shape};
+ b_shape};
xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(tuple_shapes);
auto init_i = builder->ConstantR0<int32>(transpose_a ? m - 2 : 1);
auto init = builder->Tuple({init_i, output, a, b});
@@ -421,7 +418,7 @@ xla::StatusOr<xla::ComputationDataHandle> TriangularSolveLeftLooking(
// def cond_fun(loop_carry):
// i, output, a, b = loop_carry
// return i >= 0 if transpose_a else i < m
- std::unique_ptr<xla::ComputationBuilder> condb =
+ std::unique_ptr<xla::XlaBuilder> condb =
builder->CreateSubBuilder("TriangularSolveLeftLookingWhileCond");
{
auto i = condb->GetTupleElement(
@@ -451,7 +448,7 @@ xla::StatusOr<xla::ComputationDataHandle> TriangularSolveLeftLooking(
// return (i + 1, output, a, b)
// We have to do some extra FLOPs propagating zeros in the matrix multiply
// because we can't have the size of its arguments depend on the loop counter.
- std::unique_ptr<xla::ComputationBuilder> bodyb =
+ std::unique_ptr<xla::XlaBuilder> bodyb =
builder->CreateSubBuilder("TriangularSolveLeftLookingWhileBody");
{
auto input_tuple = bodyb->Parameter(0, tuple_shape,
@@ -475,7 +472,7 @@ xla::StatusOr<xla::ComputationDataHandle> TriangularSolveLeftLooking(
// But since we can't have intermediate array sizes depend on the loop
// counter, we instead exploit the fact that we initialized the output to
// all zeros and use that as zero-padding (doing unnecessary FLOPs).
- xla::ComputationDataHandle a_row;
+ xla::XlaOp a_row;
if (transpose_a) {
TF_ASSIGN_OR_RETURN(a_row, DynamicSliceInMinorDims(bodyb.get(), body_a,
{zero, i}, {m, 1}));
diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.h b/tensorflow/compiler/tf2xla/lib/triangular_solve.h
index e32223bfdd..fd8f2489d1 100644
--- a/tensorflow/compiler/tf2xla/lib/triangular_solve.h
+++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.h
@@ -16,8 +16,8 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_TRIANGULAR_SOLVE_H_
#define TENSORFLOW_COMPILER_TF2XLA_LIB_TRIANGULAR_SOLVE_H_
-#include "tensorflow/compiler/xla/client/computation.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
namespace tensorflow {
@@ -57,14 +57,17 @@ namespace tensorflow {
//
// Uses a blocked algorithm if `block_size` is > 1; if block_size == 1 then no
// blocking is used.
-xla::StatusOr<xla::ComputationDataHandle> TriangularSolve(
- xla::ComputationBuilder* builder, const xla::ComputationDataHandle& a,
- xla::ComputationDataHandle b, bool left_side, bool lower, bool transpose_a,
- bool conjugate_a, int64 block_size = 256);
+xla::StatusOr<xla::XlaOp> TriangularSolve(xla::XlaBuilder* builder,
+ const xla::XlaOp& a, xla::XlaOp b,
+ bool left_side, bool lower,
+ bool transpose_a, bool conjugate_a,
+ int64 block_size = 256);
-xla::StatusOr<xla::ComputationDataHandle> TriangularSolveLeftLooking(
- xla::ComputationBuilder* builder, const xla::ComputationDataHandle& a,
- const xla::ComputationDataHandle& b, bool transpose_a, bool conjugate_a);
+xla::StatusOr<xla::XlaOp> TriangularSolveLeftLooking(xla::XlaBuilder* builder,
+ const xla::XlaOp& a,
+ const xla::XlaOp& b,
+ bool transpose_a,
+ bool conjugate_a);
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc b/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc
index 6617070629..87ea4763f7 100644
--- a/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc
+++ b/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc
@@ -20,7 +20,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/xla/array2d.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
@@ -80,9 +80,9 @@ xla::Array2D<float> AValsFull() {
}
XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTranspose) {
- xla::ComputationBuilder builder(client_, TestName());
+ xla::XlaBuilder builder(TestName());
- xla::ComputationDataHandle a, b;
+ xla::XlaOp a, b;
auto a_data = CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a);
auto b_data = CreateR2Parameter<float>(BValsRight(), 1, "b", &builder, &b);
auto result = TriangularSolve(&builder, a, b,
@@ -102,9 +102,9 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTranspose) {
}
XLA_TEST_F(TriangularSolveTest, SimpleRightLowerNotranspose) {
- xla::ComputationBuilder builder(client_, TestName());
+ xla::XlaBuilder builder(TestName());
- xla::ComputationDataHandle a, b;
+ xla::XlaOp a, b;
auto a_data = CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a);
auto b_data = CreateR2Parameter<float>(BValsRight(), 1, "b", &builder, &b);
auto result = TriangularSolve(&builder, a, b,
@@ -124,9 +124,9 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightLowerNotranspose) {
}
XLA_TEST_F(TriangularSolveTest, SimpleRightUpperTranspose) {
- xla::ComputationBuilder builder(client_, TestName());
+ xla::XlaBuilder builder(TestName());
- xla::ComputationDataHandle a, b;
+ xla::XlaOp a, b;
auto a_data = CreateR2Parameter<float>(AValsUpper(), 0, "a", &builder, &a);
auto b_data = CreateR2Parameter<float>(BValsRight(), 1, "b", &builder, &b);
auto result = TriangularSolve(&builder, a, b,
@@ -146,9 +146,9 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightUpperTranspose) {
}
XLA_TEST_F(TriangularSolveTest, SimpleRightUpperNotranspose) {
- xla::ComputationBuilder builder(client_, TestName());
+ xla::XlaBuilder builder(TestName());
- xla::ComputationDataHandle a, b;
+ xla::XlaOp a, b;
auto a_data = CreateR2Parameter<float>(AValsUpper(), 0, "a", &builder, &a);
auto b_data = CreateR2Parameter<float>(BValsRight(), 1, "b", &builder, &b);
auto result = TriangularSolve(&builder, a, b,
@@ -168,9 +168,9 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightUpperNotranspose) {
}
XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerTranspose) {
- xla::ComputationBuilder builder(client_, TestName());
+ xla::XlaBuilder builder(TestName());
- xla::ComputationDataHandle a, b;
+ xla::XlaOp a, b;
auto a_data = CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a);
auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b);
auto result = TriangularSolve(&builder, a, b,
@@ -191,9 +191,9 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerTranspose) {
}
XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNotranspose) {
- xla::ComputationBuilder builder(client_, TestName());
+ xla::XlaBuilder builder(TestName());
- xla::ComputationDataHandle a, b;
+ xla::XlaOp a, b;
auto a_data = CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a);
auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b);
auto result = TriangularSolve(&builder, a, b,
@@ -214,9 +214,9 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNotranspose) {
}
XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTranspose) {
- xla::ComputationBuilder builder(client_, TestName());
+ xla::XlaBuilder builder(TestName());
- xla::ComputationDataHandle a, b;
+ xla::XlaOp a, b;
auto a_data = CreateR2Parameter<float>(AValsUpper(), 0, "a", &builder, &a);
auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b);
auto result = TriangularSolve(&builder, a, b,
@@ -237,9 +237,9 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTranspose) {
}
XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperNotranspose) {
- xla::ComputationBuilder builder(client_, TestName());
+ xla::XlaBuilder builder(TestName());
- xla::ComputationDataHandle a, b;
+ xla::XlaOp a, b;
auto a_data = CreateR2Parameter<float>(AValsUpper(), 0, "a", &builder, &a);
auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b);
auto result = TriangularSolve(&builder, a, b,
@@ -260,9 +260,9 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperNotranspose) {
}
XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTransposeConjugate) {
- xla::ComputationBuilder builder(client_, TestName());
+ xla::XlaBuilder builder(TestName());
- xla::ComputationDataHandle a, b;
+ xla::XlaOp a, b;
auto a_data =
CreateR2Parameter<complex64>(AValsLowerComplex(), 0, "a", &builder, &a);
auto b_data =
@@ -288,9 +288,9 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTransposeConjugate) {
}
XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTransposeNoconjugate) {
- xla::ComputationBuilder builder(client_, TestName());
+ xla::XlaBuilder builder(TestName());
- xla::ComputationDataHandle a, b;
+ xla::XlaOp a, b;
auto a_data =
CreateR2Parameter<complex64>(AValsUpperComplex(), 0, "a", &builder, &a);
auto b_data =
@@ -318,9 +318,9 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTransposeNoconjugate) {
}
XLA_TEST_F(TriangularSolveLeftLookingTest, Simple) {
- xla::ComputationBuilder builder(client_, TestName());
+ xla::XlaBuilder builder(TestName());
- xla::ComputationDataHandle a, b;
+ xla::XlaOp a, b;
auto a_data = CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a);
auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b);
auto result = TriangularSolveLeftLooking(&builder, a, b,
@@ -340,9 +340,9 @@ XLA_TEST_F(TriangularSolveLeftLookingTest, Simple) {
}
XLA_TEST_F(TriangularSolveLeftLookingTest, NonzeroUpperTriangle) {
- xla::ComputationBuilder builder(client_, TestName());
+ xla::XlaBuilder builder(TestName());
- xla::ComputationDataHandle a, b;
+ xla::XlaOp a, b;
auto a_data = CreateR2Parameter<float>(AValsFull(), 0, "a", &builder, &a);
auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b);
auto result = TriangularSolveLeftLooking(&builder, a, b,
diff --git a/tensorflow/compiler/tf2xla/lib/util.cc b/tensorflow/compiler/tf2xla/lib/util.cc
index 31d823ca33..cc7b13571c 100644
--- a/tensorflow/compiler/tf2xla/lib/util.cc
+++ b/tensorflow/compiler/tf2xla/lib/util.cc
@@ -27,15 +27,14 @@ limitations under the License.
namespace tensorflow {
-xla::ComputationDataHandle Zeros(xla::ComputationBuilder* builder,
- const xla::Shape& shape) {
+xla::XlaOp Zeros(xla::XlaBuilder* builder, const xla::Shape& shape) {
return builder->Broadcast(
builder->ConstantLiteral(xla::Literal::Zero(shape.element_type())),
xla::AsInt64Slice(shape.dimensions()));
}
-xla::ComputationDataHandle FloatLiteral(xla::ComputationBuilder* builder,
- xla::PrimitiveType type, double value) {
+xla::XlaOp FloatLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type,
+ double value) {
switch (type) {
case xla::F16:
return builder->ConstantR0<xla::half>(static_cast<xla::half>(value));
@@ -57,9 +56,8 @@ xla::ComputationDataHandle FloatLiteral(xla::ComputationBuilder* builder,
}
}
-xla::ComputationDataHandle IntegerLiteral(xla::ComputationBuilder* builder,
- xla::PrimitiveType type,
- int64 value) {
+xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type,
+ int64 value) {
xla::Literal literal;
switch (type) {
case xla::U8:
@@ -112,17 +110,18 @@ xla::ComputationDataHandle IntegerLiteral(xla::ComputationBuilder* builder,
return builder->ConstantLiteral(literal);
}
-xla::StatusOr<xla::ComputationDataHandle> SliceInMinorDims(
- xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
- gtl::ArraySlice<int64> start, gtl::ArraySlice<int64> end) {
+xla::StatusOr<xla::XlaOp> SliceInMinorDims(xla::XlaBuilder* builder,
+ const xla::XlaOp& x,
+ gtl::ArraySlice<int64> start,
+ gtl::ArraySlice<int64> end) {
TF_RET_CHECK(start.size() == end.size());
int64 n_minor_dims = start.size();
- TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> shape, builder->GetShape(x));
+ TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x));
- const int64 n_dims = xla::ShapeUtil::Rank(*shape);
+ const int64 n_dims = xla::ShapeUtil::Rank(shape);
TF_RET_CHECK(n_minor_dims <= n_dims);
- gtl::ArraySlice<int64> major_dims(xla::AsInt64Slice(shape->dimensions()),
+ gtl::ArraySlice<int64> major_dims(xla::AsInt64Slice(shape.dimensions()),
/*pos=*/0,
/*len=*/n_dims - n_minor_dims);
@@ -140,7 +139,7 @@ xla::StatusOr<xla::ComputationDataHandle> SliceInMinorDims(
return builder->Slice(x, padded_start, padded_end, strides);
}
-std::vector<int64> PrependMajorDims(xla::ComputationBuilder* builder,
+std::vector<int64> PrependMajorDims(xla::XlaBuilder* builder,
const gtl::ArraySlice<int64>& major_dims,
const gtl::ArraySlice<int64>& indices) {
std::vector<int64> output(indices.size() + major_dims.size());
@@ -149,16 +148,16 @@ std::vector<int64> PrependMajorDims(xla::ComputationBuilder* builder,
return output;
}
-xla::StatusOr<xla::ComputationDataHandle> DynamicSliceInMinorDims(
- xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
- const std::vector<xla::ComputationDataHandle>& starts,
+xla::StatusOr<xla::XlaOp> DynamicSliceInMinorDims(
+ xla::XlaBuilder* builder, const xla::XlaOp& x,
+ const std::vector<xla::XlaOp>& starts,
const gtl::ArraySlice<int64>& sizes) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> shape, builder->GetShape(x));
- const int64 n_dims = xla::ShapeUtil::Rank(*shape);
+ TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x));
+ const int64 n_dims = xla::ShapeUtil::Rank(shape);
int64 n_minor_dims = starts.size();
TF_RET_CHECK(n_minor_dims == sizes.size());
TF_RET_CHECK(n_minor_dims <= n_dims);
- gtl::ArraySlice<int64> major_dims(xla::AsInt64Slice(shape->dimensions()),
+ gtl::ArraySlice<int64> major_dims(xla::AsInt64Slice(shape.dimensions()),
/*pos=*/0,
/*len=*/n_dims - sizes.size());
TF_ASSIGN_OR_RETURN(auto padded_starts,
@@ -167,27 +166,29 @@ xla::StatusOr<xla::ComputationDataHandle> DynamicSliceInMinorDims(
return builder->DynamicSlice(x, padded_starts, padded_sizes);
}
-xla::StatusOr<xla::ComputationDataHandle> UpdateSlice(
- xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
- const xla::ComputationDataHandle& update, gtl::ArraySlice<int64> start) {
+xla::StatusOr<xla::XlaOp> UpdateSlice(xla::XlaBuilder* builder,
+ const xla::XlaOp& x,
+ const xla::XlaOp& update,
+ gtl::ArraySlice<int64> start) {
// TODO(phawkins): make int64 work on all backends, remove the int32 cast.
std::vector<int32> start_as_int32(start.begin(), start.end());
auto start_constant = builder->ConstantR1<int32>(start_as_int32);
- TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> shape, builder->GetShape(x));
- const int64 n_dims = xla::ShapeUtil::Rank(*shape);
- TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> start_constant_shape,
+ TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x));
+ const int64 n_dims = xla::ShapeUtil::Rank(shape);
+ TF_ASSIGN_OR_RETURN(xla::Shape start_constant_shape,
builder->GetShape(start_constant));
const int64 start_length =
- xla::ShapeUtil::GetDimension(*start_constant_shape, -1);
+ xla::ShapeUtil::GetDimension(start_constant_shape, -1);
TF_RET_CHECK(start_length == n_dims);
return builder->DynamicUpdateSlice(x, update, start_constant);
}
-xla::StatusOr<xla::ComputationDataHandle> UpdateSliceInMinorDims(
- xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
- const xla::ComputationDataHandle& update, gtl::ArraySlice<int64> start) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> shape, builder->GetShape(x));
- const int64 n_dims = xla::ShapeUtil::Rank(*shape);
+xla::StatusOr<xla::XlaOp> UpdateSliceInMinorDims(xla::XlaBuilder* builder,
+ const xla::XlaOp& x,
+ const xla::XlaOp& update,
+ gtl::ArraySlice<int64> start) {
+ TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x));
+ const int64 n_dims = xla::ShapeUtil::Rank(shape);
const int64 n_minor_dims = start.size();
TF_RET_CHECK(n_minor_dims <= n_dims);
std::vector<int64> padded_start(n_dims, 0);
@@ -196,22 +197,21 @@ xla::StatusOr<xla::ComputationDataHandle> UpdateSliceInMinorDims(
return UpdateSlice(builder, x, update, padded_start);
}
-xla::StatusOr<xla::ComputationDataHandle> DynamicUpdateSliceInMinorDims(
- xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
- const xla::ComputationDataHandle& update,
- const std::vector<xla::ComputationDataHandle>& starts) {
+xla::StatusOr<xla::XlaOp> DynamicUpdateSliceInMinorDims(
+ xla::XlaBuilder* builder, const xla::XlaOp& x, const xla::XlaOp& update,
+ const std::vector<xla::XlaOp>& starts) {
TF_ASSIGN_OR_RETURN(auto padded_starts,
PrependZerosInMajorDims(builder, x, starts));
return builder->DynamicUpdateSlice(x, update, padded_starts);
}
-xla::StatusOr<xla::ComputationDataHandle> PrependZerosInMajorDims(
- xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
- const std::vector<xla::ComputationDataHandle>& starts) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> shape, builder->GetShape(x));
- const int64 n_dims = xla::ShapeUtil::Rank(*shape);
+xla::StatusOr<xla::XlaOp> PrependZerosInMajorDims(
+ xla::XlaBuilder* builder, const xla::XlaOp& x,
+ const std::vector<xla::XlaOp>& starts) {
+ TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x));
+ const int64 n_dims = xla::ShapeUtil::Rank(shape);
auto zero = builder->Reshape(builder->ConstantR0<int32>(0), {1});
- std::vector<xla::ComputationDataHandle> padded_starts(n_dims, zero);
+ std::vector<xla::XlaOp> padded_starts(n_dims, zero);
for (int i = 0; i < starts.size(); ++i) {
padded_starts[n_dims - starts.size() + i] =
builder->Reshape(starts[i], {1});
@@ -219,10 +219,10 @@ xla::StatusOr<xla::ComputationDataHandle> PrependZerosInMajorDims(
return builder->ConcatInDim(padded_starts, 0);
}
-xla::StatusOr<xla::ComputationDataHandle> TransposeInMinorDims(
- xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> shape, builder->GetShape(x));
- const int64 n_dims = xla::ShapeUtil::Rank(*shape);
+xla::StatusOr<xla::XlaOp> TransposeInMinorDims(xla::XlaBuilder* builder,
+ const xla::XlaOp& x) {
+ TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x));
+ const int64 n_dims = xla::ShapeUtil::Rank(shape);
TF_RET_CHECK(n_dims >= 2);
std::vector<int64> permutation(n_dims);
std::iota(permutation.begin(), permutation.end(), 0);
diff --git a/tensorflow/compiler/tf2xla/lib/util.h b/tensorflow/compiler/tf2xla/lib/util.h
index b684123f13..3df44ef035 100644
--- a/tensorflow/compiler/tf2xla/lib/util.h
+++ b/tensorflow/compiler/tf2xla/lib/util.h
@@ -16,75 +16,74 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_UTIL_H_
#define TENSORFLOW_COMPILER_TF2XLA_LIB_UTIL_H_
-#include "tensorflow/compiler/xla/client/computation.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
namespace tensorflow {
// Returns a zero-filled tensor with shape `shape`.
-xla::ComputationDataHandle Zeros(xla::ComputationBuilder* builder,
- const xla::Shape& shape);
+xla::XlaOp Zeros(xla::XlaBuilder* builder, const xla::Shape& shape);
// Returns a floating point scalar constant of 'type' with 'value'.
// If 'type' is complex, returns a real value with zero imaginary component.
-xla::ComputationDataHandle FloatLiteral(xla::ComputationBuilder* builder,
- xla::PrimitiveType type, double value);
+xla::XlaOp FloatLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type,
+ double value);
// Makes a 1D tensor [0, ..., x, y] from two tensors x and y with zeros
// prepended until the array is length n_dims.
-xla::ComputationDataHandle PrependZerosInMajorDims(
- xla::ComputationBuilder* builder,
- gtl::ArraySlice<xla::ComputationDataHandle> starts);
+xla::XlaOp PrependZerosInMajorDims(xla::XlaBuilder* builder,
+ gtl::ArraySlice<xla::XlaOp> starts);
// Returns a integer scalar constant of 'type' with 'value'.
// If 'type' is complex, returns a real value with zero imaginary component.
-xla::ComputationDataHandle IntegerLiteral(xla::ComputationBuilder* builder,
- xla::PrimitiveType type, int64 value);
+xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type,
+ int64 value);
// Builds a vector of zeros of length rank(x) with the last two values being
// those in `starts`.
-xla::StatusOr<xla::ComputationDataHandle> PrependZerosInMajorDims(
- xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
- const std::vector<xla::ComputationDataHandle>& starts);
+xla::StatusOr<xla::XlaOp> PrependZerosInMajorDims(
+ xla::XlaBuilder* builder, const xla::XlaOp& x,
+ const std::vector<xla::XlaOp>& starts);
// Performs a slice in the minor dimensions of a Tensor.
-xla::StatusOr<xla::ComputationDataHandle> SliceInMinorDims(
- xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
- gtl::ArraySlice<int64> start, gtl::ArraySlice<int64> end);
+xla::StatusOr<xla::XlaOp> SliceInMinorDims(xla::XlaBuilder* builder,
+ const xla::XlaOp& x,
+ gtl::ArraySlice<int64> start,
+ gtl::ArraySlice<int64> end);
// Builds a 1-d vector out of a concatenation of `major_dims` and `starts`.
-std::vector<int64> PrependMajorDims(xla::ComputationBuilder* builder,
+std::vector<int64> PrependMajorDims(xla::XlaBuilder* builder,
const gtl::ArraySlice<int64>& major_dims,
const gtl::ArraySlice<int64>& indices);
// Performs a dynamic slice in the minor dimensions of a Tensor.
-xla::StatusOr<xla::ComputationDataHandle> DynamicSliceInMinorDims(
- xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
- const std::vector<xla::ComputationDataHandle>& starts,
- const gtl::ArraySlice<int64>& sizes);
+xla::StatusOr<xla::XlaOp> DynamicSliceInMinorDims(
+ xla::XlaBuilder* builder, const xla::XlaOp& x,
+ const std::vector<xla::XlaOp>& starts, const gtl::ArraySlice<int64>& sizes);
// Updates a slice of 'x', i.e.,
// x[start[0], ..., start[n]] = update
-xla::StatusOr<xla::ComputationDataHandle> UpdateSlice(
- xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
- const xla::ComputationDataHandle& update, gtl::ArraySlice<int64> start);
+xla::StatusOr<xla::XlaOp> UpdateSlice(xla::XlaBuilder* builder,
+ const xla::XlaOp& x,
+ const xla::XlaOp& update,
+ gtl::ArraySlice<int64> start);
// Updates a slice of 'x', where 'start' contains a list of minor dimensions:
// x[..., start[0], ..., start[n]] = update
-xla::StatusOr<xla::ComputationDataHandle> UpdateSliceInMinorDims(
- xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
- const xla::ComputationDataHandle& update, gtl::ArraySlice<int64> start);
+xla::StatusOr<xla::XlaOp> UpdateSliceInMinorDims(xla::XlaBuilder* builder,
+ const xla::XlaOp& x,
+ const xla::XlaOp& update,
+ gtl::ArraySlice<int64> start);
-xla::StatusOr<xla::ComputationDataHandle> DynamicUpdateSliceInMinorDims(
- xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
- const xla::ComputationDataHandle& update,
- const std::vector<xla::ComputationDataHandle>& starts);
+xla::StatusOr<xla::XlaOp> DynamicUpdateSliceInMinorDims(
+ xla::XlaBuilder* builder, const xla::XlaOp& x, const xla::XlaOp& update,
+ const std::vector<xla::XlaOp>& starts);
// Transposes a stack of matrices `x` by swapping the last two dimensions.
-xla::StatusOr<xla::ComputationDataHandle> TransposeInMinorDims(
- xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x);
+xla::StatusOr<xla::XlaOp> TransposeInMinorDims(xla::XlaBuilder* builder,
+ const xla::XlaOp& x);
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/lib/util_test.cc b/tensorflow/compiler/tf2xla/lib/util_test.cc
index b6bd33af2e..265b39402c 100644
--- a/tensorflow/compiler/tf2xla/lib/util_test.cc
+++ b/tensorflow/compiler/tf2xla/lib/util_test.cc
@@ -21,7 +21,6 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/lib/batch_dot.h"
#include "tensorflow/compiler/xla/array2d.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
@@ -65,9 +64,9 @@ xla::Array3D<float> BatchedAValsFull() {
}
XLA_TEST_F(UtilTest, Simple2dLookup) {
- xla::ComputationBuilder builder(client_, TestName());
+ xla::XlaBuilder builder(TestName());
- xla::ComputationDataHandle a, x, y;
+ xla::XlaOp a, x, y;
auto a_data = CreateR2Parameter<float>(BValsRight(), 0, "a", &builder, &a);
auto x_data = CreateR0Parameter<int>(2, 1, "x", &builder, &x);
auto y_data = CreateR0Parameter<int>(1, 2, "y", &builder, &y);
@@ -80,9 +79,9 @@ XLA_TEST_F(UtilTest, Simple2dLookup) {
}
XLA_TEST_F(UtilTest, Simple3dLookup) {
- xla::ComputationBuilder builder(client_, TestName());
+ xla::XlaBuilder builder(TestName());
- xla::ComputationDataHandle a, index;
+ xla::XlaOp a, index;
auto a_data =
CreateR3Parameter<float>(BatchedAValsFull(), 0, "a", &builder, &a);
auto index_data = CreateR0Parameter<int>(1, 1, "index", &builder, &index);
@@ -97,9 +96,9 @@ XLA_TEST_F(UtilTest, Simple3dLookup) {
}
XLA_TEST_F(UtilTest, SimpleSliceUpdate) {
- xla::ComputationBuilder builder(client_, TestName());
+ xla::XlaBuilder builder(TestName());
- xla::ComputationDataHandle a, b, x, y;
+ xla::XlaOp a, b, x, y;
auto a_data = CreateR2Parameter<float>(AValsFull(), 0, "a", &builder, &a);
auto b_data = CreateR2Parameter<float>({{9, 1, -10}}, 1, "b", &builder, &b);
auto x_data = CreateR0Parameter<int>(2, 2, "x", &builder, &x);
@@ -117,11 +116,11 @@ XLA_TEST_F(UtilTest, SimpleSliceUpdate) {
}
XLA_TEST_F(UtilTest, RowBatchDot) {
- xla::ComputationBuilder builder(client_, TestName());
+ xla::XlaBuilder builder(TestName());
int n = 4;
- xla::ComputationDataHandle a, row, index;
+ xla::XlaOp a, row, index;
auto a_data =
CreateR3Parameter<float>(BatchedAValsFull(), 0, "a", &builder, &a);
auto row_data = CreateR3Parameter<float>({{{9, 1, 0, 0}}, {{2, 4, 0, 0}}}, 1,
diff --git a/tensorflow/compiler/tf2xla/lib/while_loop.cc b/tensorflow/compiler/tf2xla/lib/while_loop.cc
index 495d9c6078..09ce594930 100644
--- a/tensorflow/compiler/tf2xla/lib/while_loop.cc
+++ b/tensorflow/compiler/tf2xla/lib/while_loop.cc
@@ -20,24 +20,24 @@ limitations under the License.
namespace tensorflow {
-xla::StatusOr<std::vector<xla::ComputationDataHandle>> XlaWhileLoop(
+xla::StatusOr<std::vector<xla::XlaOp>> XlaWhileLoop(
const LoopConditionFunction& condition_function,
const LoopBodyFunction& body_function,
- gtl::ArraySlice<xla::ComputationDataHandle> initial_values,
- StringPiece name, xla::ComputationBuilder* builder) {
+ gtl::ArraySlice<xla::XlaOp> initial_values, StringPiece name,
+ xla::XlaBuilder* builder) {
int arity = initial_values.size();
std::vector<xla::Shape> var_shapes;
var_shapes.reserve(arity);
- for (const xla::ComputationDataHandle& input : initial_values) {
+ for (const xla::XlaOp& input : initial_values) {
TF_ASSIGN_OR_RETURN(auto shape, builder->GetShape(input));
- var_shapes.push_back(std::move(*shape));
+ var_shapes.push_back(std::move(shape));
}
xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(var_shapes);
// Unpacks a tuple into its component parts.
- auto unpack_tuple = [](xla::ComputationDataHandle tuple, int arity,
- xla::ComputationBuilder* builder) {
- std::vector<xla::ComputationDataHandle> elements(arity);
+ auto unpack_tuple = [](xla::XlaOp tuple, int arity,
+ xla::XlaBuilder* builder) {
+ std::vector<xla::XlaOp> elements(arity);
for (int i = 0; i < arity; ++i) {
elements[i] = builder->GetTupleElement(tuple, i);
}
@@ -45,20 +45,20 @@ xla::StatusOr<std::vector<xla::ComputationDataHandle>> XlaWhileLoop(
};
// Build the condition.
- std::unique_ptr<xla::ComputationBuilder> cond_builder =
+ std::unique_ptr<xla::XlaBuilder> cond_builder =
builder->CreateSubBuilder(strings::StrCat(name, "_condition"));
{
auto parameter = cond_builder->Parameter(0, tuple_shape, "parameter");
- TF_ASSIGN_OR_RETURN(
- auto result,
+ TF_RETURN_IF_ERROR(
condition_function(unpack_tuple(parameter, arity, cond_builder.get()),
- cond_builder.get()));
+ cond_builder.get())
+ .status());
}
TF_ASSIGN_OR_RETURN(auto cond, cond_builder->Build());
// Build the body.
- std::unique_ptr<xla::ComputationBuilder> body_builder =
+ std::unique_ptr<xla::XlaBuilder> body_builder =
builder->CreateSubBuilder(strings::StrCat(name, "_body"));
{
auto parameter = body_builder->Parameter(0, tuple_shape, "parameter");
@@ -78,38 +78,38 @@ xla::StatusOr<std::vector<xla::ComputationDataHandle>> XlaWhileLoop(
return unpack_tuple(outputs, arity, builder);
}
-xla::StatusOr<std::vector<xla::ComputationDataHandle>> XlaForEachIndex(
+xla::StatusOr<std::vector<xla::XlaOp>> XlaForEachIndex(
int64 num_iterations, xla::PrimitiveType num_iterations_type,
const ForEachIndexBodyFunction& body_function,
- gtl::ArraySlice<xla::ComputationDataHandle> initial_values,
- StringPiece name, xla::ComputationBuilder* builder) {
- auto while_cond_fn = [&](gtl::ArraySlice<xla::ComputationDataHandle> values,
- xla::ComputationBuilder* cond_builder)
- -> xla::StatusOr<xla::ComputationDataHandle> {
+ gtl::ArraySlice<xla::XlaOp> initial_values, StringPiece name,
+ xla::XlaBuilder* builder) {
+ auto while_cond_fn =
+ [&](gtl::ArraySlice<xla::XlaOp> values,
+ xla::XlaBuilder* cond_builder) -> xla::StatusOr<xla::XlaOp> {
return cond_builder->Lt(
values[0],
IntegerLiteral(cond_builder, num_iterations_type, num_iterations));
};
- auto while_body_fn = [&](gtl::ArraySlice<xla::ComputationDataHandle> values,
- xla::ComputationBuilder* body_builder)
- -> xla::StatusOr<std::vector<xla::ComputationDataHandle>> {
- xla::ComputationDataHandle iteration = values[0];
+ auto while_body_fn = [&](gtl::ArraySlice<xla::XlaOp> values,
+ xla::XlaBuilder* body_builder)
+ -> xla::StatusOr<std::vector<xla::XlaOp>> {
+ xla::XlaOp iteration = values[0];
- std::vector<xla::ComputationDataHandle> updated_values;
+ std::vector<xla::XlaOp> updated_values;
updated_values.reserve(values.size());
updated_values.push_back(body_builder->Add(
iteration,
body_builder->ConstantLiteral(xla::Literal::One(num_iterations_type))));
values.remove_prefix(1);
- TF_ASSIGN_OR_RETURN(std::vector<xla::ComputationDataHandle> body_outputs,
+ TF_ASSIGN_OR_RETURN(std::vector<xla::XlaOp> body_outputs,
body_function(iteration, values, body_builder));
updated_values.insert(updated_values.end(), body_outputs.begin(),
body_outputs.end());
return updated_values;
};
- std::vector<xla::ComputationDataHandle> values;
+ std::vector<xla::XlaOp> values;
values.reserve(initial_values.size() + 1);
values.push_back(
builder->ConstantLiteral(xla::Literal::Zero(num_iterations_type)));
diff --git a/tensorflow/compiler/tf2xla/lib/while_loop.h b/tensorflow/compiler/tf2xla/lib/while_loop.h
index 2e67a0c99b..5b6684c995 100644
--- a/tensorflow/compiler/tf2xla/lib/while_loop.h
+++ b/tensorflow/compiler/tf2xla/lib/while_loop.h
@@ -19,8 +19,8 @@ limitations under the License.
#include <functional>
#include <vector>
-#include "tensorflow/compiler/xla/client/computation.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
@@ -29,14 +29,14 @@ namespace tensorflow {
// Function that builds a loop condition. Takes as input a sequence of input
// values, and returns a boolean value representing if the condition succeeds.
-typedef std::function<xla::StatusOr<xla::ComputationDataHandle>(
- gtl::ArraySlice<xla::ComputationDataHandle>, xla::ComputationBuilder*)>
+typedef std::function<xla::StatusOr<xla::XlaOp>(gtl::ArraySlice<xla::XlaOp>,
+ xla::XlaBuilder*)>
LoopConditionFunction;
// Function that builds a loop body. Takes as input a sequence of input values
// and returns a sequence of output values.
-typedef std::function<xla::StatusOr<std::vector<xla::ComputationDataHandle>>(
- gtl::ArraySlice<xla::ComputationDataHandle>, xla::ComputationBuilder*)>
+typedef std::function<xla::StatusOr<std::vector<xla::XlaOp>>(
+ gtl::ArraySlice<xla::XlaOp>, xla::XlaBuilder*)>
LoopBodyFunction;
// Helper function for building an XLA while loop, where the values carried by
@@ -47,27 +47,26 @@ typedef std::function<xla::StatusOr<std::vector<xla::ComputationDataHandle>>(
// init: (a, b, c)
// )
// 'name' is a descriptive name for the loop.
-xla::StatusOr<std::vector<xla::ComputationDataHandle>> XlaWhileLoop(
+xla::StatusOr<std::vector<xla::XlaOp>> XlaWhileLoop(
const LoopConditionFunction& condition_function,
const LoopBodyFunction& body_function,
- gtl::ArraySlice<xla::ComputationDataHandle> initial_values,
- StringPiece name, xla::ComputationBuilder* builder);
+ gtl::ArraySlice<xla::XlaOp> initial_values, StringPiece name,
+ xla::XlaBuilder* builder);
// Builds an XLA loop that repeats a computation `num_iterations` times.
//
// The body function (ForEachIndexBodyFunction) takes as input a pair of
// (current iteration number, loop-carried values), and returns an updated
// vector of the loop-carried values.
-typedef std::function<xla::StatusOr<std::vector<xla::ComputationDataHandle>>(
- xla::ComputationDataHandle, gtl::ArraySlice<xla::ComputationDataHandle>,
- xla::ComputationBuilder*)>
+typedef std::function<xla::StatusOr<std::vector<xla::XlaOp>>(
+ xla::XlaOp, gtl::ArraySlice<xla::XlaOp>, xla::XlaBuilder*)>
ForEachIndexBodyFunction;
-xla::StatusOr<std::vector<xla::ComputationDataHandle>> XlaForEachIndex(
+xla::StatusOr<std::vector<xla::XlaOp>> XlaForEachIndex(
int64 num_iterations, xla::PrimitiveType num_iterations_type,
const ForEachIndexBodyFunction& body_function,
- gtl::ArraySlice<xla::ComputationDataHandle> initial_values,
- StringPiece name, xla::ComputationBuilder* builder);
+ gtl::ArraySlice<xla::XlaOp> initial_values, StringPiece name,
+ xla::XlaBuilder* builder);
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/tf2xla.cc b/tensorflow/compiler/tf2xla/tf2xla.cc
index 6051d7dffd..3a08aa8cf4 100644
--- a/tensorflow/compiler/tf2xla/tf2xla.cc
+++ b/tensorflow/compiler/tf2xla/tf2xla.cc
@@ -251,7 +251,7 @@ Status CreateXlaArgs(const Graph& graph,
// Converts the TensorFlow graph into an XLA computation, by executing the
// graph symbolically, with each op building up the XLA HLO.
Status ConvertGraphToXla(std::unique_ptr<Graph> graph, xla::Client* client,
- xla::Computation* computation) {
+ xla::XlaComputation* computation) {
XlaOpRegistry::RegisterCompilationKernels();
for (Node* node : graph->nodes()) {
node->set_assigned_device_name(
@@ -303,7 +303,7 @@ Status ConvertGraphToXla(std::unique_ptr<Graph> graph, xla::Client* client,
}
// InitGraph creates a graph based on the graph_def, that may then be converted
-// to an xla::Computation via ConvertGraphToXla.
+// to an xla::XlaComputation via ConvertGraphToXla.
//
// The graph is rewritten with _Arg and _Retval nodes, representing the inputs
// and outputs of the function that will be compiled. Each feed id causes a new
@@ -348,7 +348,7 @@ Status InitGraph(const GraphDef& graph_def, const tf2xla::Config& config,
Status ConvertGraphDefToXla(const GraphDef& graph_def,
const tf2xla::Config& config, xla::Client* client,
- xla::Computation* computation) {
+ xla::XlaComputation* computation) {
std::unique_ptr<Graph> graph;
TF_RETURN_IF_ERROR(InitGraph(graph_def, config, &graph));
TF_RETURN_IF_ERROR(ConvertGraphToXla(std::move(graph), client, computation));
diff --git a/tensorflow/compiler/tf2xla/tf2xla.h b/tensorflow/compiler/tf2xla/tf2xla.h
index 473c431b12..d02fc56c5b 100644
--- a/tensorflow/compiler/tf2xla/tf2xla.h
+++ b/tensorflow/compiler/tf2xla/tf2xla.h
@@ -18,21 +18,21 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
#include "tensorflow/compiler/xla/client/client.h"
-#include "tensorflow/compiler/xla/client/computation.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
#include "tensorflow/core/framework/graph.pb.h"
namespace tensorflow {
-// Converts a tensorflow::GraphDef into an xla::Computation. The given `config`
-// specifies the portion of the graph to convert, via feeds and fetches. Each
-// feed is a positional input argument for the generated computation, while each
-// fetch is a positional output argument.
+// Converts a tensorflow::GraphDef into an xla::XlaComputation. The given
+// `config` specifies the portion of the graph to convert, via feeds and
+// fetches. Each feed is a positional input argument for the generated
+// computation, while each fetch is a positional output argument.
//
// The computation is built in the context of the given `client`, which may
// subsequently be used to compile or execute the computation.
Status ConvertGraphDefToXla(const GraphDef& graph_def,
const tf2xla::Config& config, xla::Client* client,
- xla::Computation* computation);
+ xla::XlaComputation* computation);
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/tf2xla_test.cc b/tensorflow/compiler/tf2xla/tf2xla_test.cc
index b813668a9e..84c133ffab 100644
--- a/tensorflow/compiler/tf2xla/tf2xla_test.cc
+++ b/tensorflow/compiler/tf2xla/tf2xla_test.cc
@@ -69,7 +69,7 @@ TEST(ConvertGraphDefToXla, Sum) {
tf2xla::Config config = SumConfig();
xla::LocalClient* client = xla::ClientLibrary::LocalClientOrDie();
- xla::Computation computation;
+ xla::XlaComputation computation;
TF_EXPECT_OK(ConvertGraphDefToXla(graph_def, config, client, &computation));
// Set up arguments.
diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.cc b/tensorflow/compiler/tf2xla/xla_compilation_device.cc
index fcb0a4e638..fe7ec633ec 100644
--- a/tensorflow/compiler/tf2xla/xla_compilation_device.cc
+++ b/tensorflow/compiler/tf2xla/xla_compilation_device.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/sharding_util.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/common_runtime/local_device.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/platform/mem.h"
@@ -108,7 +109,7 @@ void XlaCompilationDevice::Compute(OpKernel* op_kernel,
// If no sharding metadata is found, XLA is free to use whatever device it
// wants. In practice this usually has the effect of placing things on device
// 0.
- xla::ScopedShardingAssignment assign_sharding(b, op_sharding);
+ xla::XlaScopedShardingAssignment assign_sharding(b, op_sharding);
op_kernel->Compute(context);
b->ClearOpMetadata();
@@ -126,9 +127,7 @@ Status XlaCompilationDevice::MakeTensorFromProto(
XlaExpression::XlaExpression() = default;
-void XlaExpression::set_handle(const xla::ComputationDataHandle& h) {
- handle_ = h;
-}
+void XlaExpression::set_handle(const xla::XlaOp& h) { handle_ = h; }
void XlaExpression::set_constant_value(Tensor value) {
has_constant_value_ = true;
diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.h b/tensorflow/compiler/tf2xla/xla_compilation_device.h
index 0243ee332f..d0b9e34e16 100644
--- a/tensorflow/compiler/tf2xla/xla_compilation_device.h
+++ b/tensorflow/compiler/tf2xla/xla_compilation_device.h
@@ -19,7 +19,7 @@ limitations under the License.
#include <memory>
#include "tensorflow/compiler/tf2xla/xla_resource.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/common_runtime/local_device.h"
#include "tensorflow/core/framework/device_base.h"
@@ -69,7 +69,7 @@ class XlaCompilationDevice : public LocalDevice {
// A XlaExpression wraps an XLA computation. Each Tensor on an
// XlaCompilationDevice contains an XlaExpression, and the shape of the Tensor
-// matches the shape of the subcomputation in the ComputationDataHandle. Each
+// matches the shape of the subcomputation in the XlaOp. Each
// expression is either a constant, or a function of previously-compiled
// expressions.
class XlaExpression {
@@ -78,8 +78,8 @@ class XlaExpression {
// handle() stores the XLA handle of the computation that the
// expression represents.
- void set_handle(const xla::ComputationDataHandle& h);
- const xla::ComputationDataHandle& handle() const { return handle_; }
+ void set_handle(const xla::XlaOp& h);
+ const xla::XlaOp& handle() const { return handle_; }
void set_constant_value(Tensor value);
bool has_constant_value() const { return has_constant_value_; }
@@ -90,7 +90,7 @@ class XlaExpression {
private:
// The XLA handle of the expression's computation.
- xla::ComputationDataHandle handle_;
+ xla::XlaOp handle_;
// If this expression is a constant with a known value, 'constant_value' is a
// host-memory Tensor containing the value. Used to avoid invoking XLA for
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc
index c0e9967684..3d1946c332 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler.cc
@@ -339,11 +339,11 @@ Status BuildComputation(
const std::vector<int>& arg_cores,
const std::vector<XlaExpression>& retvals,
const std::vector<std::unique_ptr<XlaResource>>& resources,
- bool return_updated_values_for_all_resources,
- xla::ComputationBuilder* builder, xla::Computation* computation,
- int* num_computation_outputs, int* num_nonconst_outputs,
+ bool return_updated_values_for_all_resources, xla::XlaBuilder* builder,
+ xla::XlaComputation* computation, int* num_computation_outputs,
+ int* num_nonconst_outputs,
std::vector<XlaCompiler::ResourceUpdate>* resource_updates) {
- std::vector<xla::ComputationDataHandle> elems;
+ std::vector<xla::XlaOp> elems;
elems.reserve(retvals.size());
for (const XlaExpression& retval : retvals) {
if (!retval.has_constant_value()) {
@@ -376,14 +376,12 @@ Status BuildComputation(
const XlaCompiler::Argument& arg = args[resource->arg_num()];
const int core = arg_cores[resource->arg_num()];
DCHECK_LT(resource->arg_num(), arg_cores.size());
- bool modified =
- resource->value().handle() != resource->initial_value().handle();
+ bool modified = resource->value() != resource->initial_value();
// TensorArray gradients were modified if their values changed or there are
// any newly created gradients.
for (const auto& grad : resource->tensor_array_gradients()) {
modified = modified ||
- grad.second->value().handle() !=
- grad.second->initial_value().handle() ||
+ grad.second->value() != grad.second->initial_value() ||
arg.tensor_array_gradients.count(grad.first) == 0;
}
if (return_updated_values_for_all_resources || modified) {
@@ -398,11 +396,11 @@ Status BuildComputation(
}
// Request that the value be returned on a specific core.
- xla::ScopedShardingAssignment assign_sharding(
+ xla::XlaScopedShardingAssignment assign_sharding(
builder, core == -1 ? tensorflow::gtl::optional<xla::OpSharding>()
: xla::sharding_builder::AssignDevice(core));
- xla::ComputationDataHandle handle;
+ xla::XlaOp handle;
TF_RETURN_IF_ERROR(resource->Pack(&handle, builder));
// Since we can't change the sharding metadata of <value> as this point,
@@ -421,7 +419,7 @@ Status BuildComputation(
builder->Tuple(elems);
builder->ClearOpMetadata();
- xla::StatusOr<xla::Computation> computation_status = builder->Build();
+ xla::StatusOr<xla::XlaComputation> computation_status = builder->Build();
if (!computation_status.ok()) {
return computation_status.status();
}
@@ -435,7 +433,7 @@ Status BuildComputation(
// `args` are the arguments to the computation.
Status XlaCompiler::BuildArguments(
const Graph& graph, const std::vector<XlaCompiler::Argument>& args,
- bool use_tuple_arg, xla::ComputationBuilder* builder, XlaContext* context,
+ bool use_tuple_arg, xla::XlaBuilder* builder, XlaContext* context,
std::vector<int>* arg_cores, std::vector<XlaExpression>* arg_expressions,
std::vector<int>* input_mapping, std::vector<xla::Shape>* input_shapes,
bool is_entry_computation) {
@@ -461,8 +459,7 @@ Status XlaCompiler::BuildArguments(
// alias.
XlaResource* resource;
TF_RETURN_IF_ERROR(context->CreateResource(
- arg.resource_kind, i, arg.name, arg.type, arg.shape,
- xla::ComputationDataHandle(),
+ arg.resource_kind, i, arg.name, arg.type, arg.shape, xla::XlaOp(),
/*tensor_array_size=*/arg.tensor_array_size,
/*tensor_array_gradients=*/arg.tensor_array_gradients, &resource));
arg_expression.set_resource(resource);
@@ -531,9 +528,9 @@ Status XlaCompiler::BuildArguments(
builder->SetOpMetadata(arg_metadata);
// Build parameter handles for non-constant arguments.
- std::vector<xla::ComputationDataHandle> arg_handles(input_mapping->size());
+ std::vector<xla::XlaOp> arg_handles(input_mapping->size());
if (use_tuple_arg) {
- xla::ComputationDataHandle tuple;
+ xla::XlaOp tuple;
if (is_entry_computation) {
xla::OpSharding tuple_sharding;
tuple_sharding.set_type(xla::OpSharding::Type::OpSharding_Type_TUPLE);
@@ -544,15 +541,15 @@ Status XlaCompiler::BuildArguments(
core == -1 ? xla::sharding_builder::AssignDevice(root_device)
: xla::sharding_builder::AssignDevice(core);
}
- xla::ScopedShardingAssignment assign_tuple_sharding(builder,
- tuple_sharding);
+ xla::XlaScopedShardingAssignment assign_tuple_sharding(builder,
+ tuple_sharding);
tuple = builder->Parameter(0, (*input_shapes)[0], "arg_tuple");
} else {
tuple = builder->Parameter(0, (*input_shapes)[0], "arg_tuple");
}
for (std::vector<int>::size_type i = 0; i < input_mapping->size(); ++i) {
const int core = (*arg_cores)[input_mapping->at(i)];
- xla::ScopedShardingAssignment assign_sharding(
+ xla::XlaScopedShardingAssignment assign_sharding(
builder, core == -1 ? tensorflow::gtl::optional<xla::OpSharding>()
: xla::sharding_builder::AssignDevice(core));
arg_handles[i] = builder->GetTupleElement(tuple, i);
@@ -560,7 +557,7 @@ Status XlaCompiler::BuildArguments(
} else {
for (std::vector<int>::size_type i = 0; i < input_mapping->size(); ++i) {
const int core = (*arg_cores)[input_mapping->at(i)];
- xla::ScopedShardingAssignment assign_sharding(
+ xla::XlaScopedShardingAssignment assign_sharding(
builder, core == -1 ? tensorflow::gtl::optional<xla::OpSharding>()
: xla::sharding_builder::AssignDevice(core));
arg_handles[i] =
@@ -647,7 +644,7 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
std::unique_ptr<Graph> graph,
const std::vector<XlaCompiler::Argument>& args,
CompilationResult* result) {
- VLOG(1) << "Executing graph symbolically to populate ComputationBuilder.";
+ VLOG(1) << "Executing graph symbolically to populate XlaBuilder.";
if (VLOG_IS_ON(2)) {
VLOG(2) << "XlaCompiler::CompileGraph: "
@@ -663,7 +660,7 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
TF_RETURN_IF_ERROR(
FunctionalizeControlFlow(graph.get(), local_flib_def_.get()));
- xla::ComputationBuilder builder(client(), name);
+ xla::XlaBuilder builder(name);
XlaContext* context =
new XlaContext(this, &builder, options_.allow_cpu_custom_calls,
options.resolve_compile_time_constants,
@@ -683,7 +680,7 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
int num_nonconst_outputs;
int num_computation_outputs;
- result->computation = std::make_shared<xla::Computation>();
+ result->computation = std::make_shared<xla::XlaComputation>();
TF_RETURN_IF_ERROR(BuildComputation(
args, arg_cores, context->retvals(), context->resources(),
options.return_updated_values_for_all_resources, &builder,
@@ -814,7 +811,7 @@ Status XlaCompiler::SetHostToDeviceMetadata(
}
Status XlaCompiler::GetHostComputeControlDependency(
- const string& host_compute_name, xla::ComputationDataHandle* handle) {
+ const string& host_compute_name, xla::XlaOp* handle) {
const auto iter = host_compute_control_output_.find(host_compute_name);
if (iter == host_compute_control_output_.end()) {
return errors::InvalidArgument(
@@ -827,7 +824,7 @@ Status XlaCompiler::GetHostComputeControlDependency(
}
Status XlaCompiler::SetHostComputeControlDependency(
- const string& host_compute_name, const xla::ComputationDataHandle& handle) {
+ const string& host_compute_name, const xla::XlaOp& handle) {
if (host_compute_control_output_.find(host_compute_name) !=
host_compute_control_output_.end()) {
return errors::InvalidArgument(
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h
index 8f564f35ec..ca6cd822ef 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.h
+++ b/tensorflow/compiler/tf2xla/xla_compiler.h
@@ -227,7 +227,7 @@ class XlaCompiler {
std::vector<ResourceUpdate> resource_updates;
// The XLA computation built from the tensorflow subgraph.
- std::shared_ptr<xla::Computation> computation;
+ std::shared_ptr<xla::XlaComputation> computation;
};
struct Options {
@@ -281,7 +281,7 @@ class XlaCompiler {
const NameAttrList& fn_name_attrs,
std::vector<Argument> args, CompilationResult* result);
- // Compiles a tensorflow::Graph into an xla::Computation.
+ // Compiles a tensorflow::Graph into an xla::XlaComputation.
// Similar to CompileFunction, but takes a Graph as input rather than a
// function.
Status CompileGraph(const CompileOptions& options, string const& name,
@@ -290,7 +290,7 @@ class XlaCompiler {
CompilationResult* result);
// Compiles a single Op, given by an OpKernelContext, into an
- // xla::Computation. Similar to CompileFunction but takes a single Op as
+ // xla::XlaComputation. Similar to CompileFunction but takes a single Op as
// input.
Status CompileSingleOp(const CompileOptions& options, string const& name,
OpKernelContext* ctx,
@@ -337,10 +337,9 @@ class XlaCompiler {
// a given HostCompute Op as long as the names are unique within the
// compilation.
Status GetHostComputeControlDependency(const string& host_compute_name,
- xla::ComputationDataHandle* handle);
- Status SetHostComputeControlDependency(
- const string& host_compute_name,
- const xla::ComputationDataHandle& handle);
+ xla::XlaOp* handle);
+ Status SetHostComputeControlDependency(const string& host_compute_name,
+ const xla::XlaOp& handle);
const Options& options() const { return options_; }
xla::Client* client() const { return options_.client; }
@@ -358,7 +357,7 @@ class XlaCompiler {
// `args` are the arguments to the computation.
Status BuildArguments(const Graph& graph,
const std::vector<XlaCompiler::Argument>& args,
- bool use_tuple_arg, xla::ComputationBuilder* builder,
+ bool use_tuple_arg, xla::XlaBuilder* builder,
XlaContext* context, std::vector<int>* arg_cores,
std::vector<XlaExpression>* arg_expressions,
std::vector<int>* input_mapping,
@@ -408,8 +407,7 @@ class XlaCompiler {
std::unordered_map<string, tf2xla::HostTransferMetadata> host_compute_sends_;
std::unordered_map<string, tf2xla::HostTransferMetadata> host_compute_recvs_;
- std::unordered_map<string, xla::ComputationDataHandle>
- host_compute_control_output_;
+ std::unordered_map<string, xla::XlaOp> host_compute_control_output_;
TF_DISALLOW_COPY_AND_ASSIGN(XlaCompiler);
};
diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
index 096dc7160b..6b8918b261 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
@@ -164,7 +164,6 @@ REGISTER_XLA_OP(Name("DummyDuplicateOp").Device(DEVICE_CPU_XLA_JIT),
REGISTER_XLA_OP(Name("DummyDuplicateOp").Device(DEVICE_GPU_XLA_JIT),
DummyDuplicateOp);
-
// Tests compilation and execution of an empty graph.
TEST_F(XlaCompilerTest, EmptyReturnValues) {
XlaCompiler compiler(DefaultOptions());
@@ -433,21 +432,26 @@ TEST_F(XlaCompilerTest, DeterministicCompilation) {
}
for (int64 i = 1; i < test_count; ++i) {
- auto m1 =
- results[i - 1].computation->Snapshot().ValueOrDie()->entry().requests();
- auto m2 =
- results[i].computation->Snapshot().ValueOrDie()->entry().requests();
- // Check if every entry is the same.
- for (auto& entry1 : m1) {
- int64 key = entry1.first;
- auto value1 = entry1.second;
- auto entry2 = m2.find(key);
- auto value2 = entry2->second;
- EXPECT_TRUE(entry2 != m2.end());
- string str1, str2;
- value1.AppendToString(&str1);
- value2.AppendToString(&str2);
- EXPECT_EQ(str1, str2);
+ const auto& m1 = results[i - 1].computation->proto();
+ const auto& m2 = results[i].computation->proto();
+ ASSERT_EQ(m1.computations_size(), m2.computations_size());
+ // Check if every hlo computation is the same.
+ for (int k = 0; k < m1.computations_size(); k++) {
+ const auto& c1 = m1.computations(k);
+ const auto& c2 = m2.computations(k);
+ ASSERT_EQ(c1.instructions_size(), c2.instructions_size());
+ for (int j = 0; j < c1.instructions_size(); j++) {
+ auto instr1 = c1.instructions(j);
+ auto instr2 = c2.instructions(j);
+ instr1.clear_name();
+ instr2.clear_name();
+ // The names of instructions were uniquified by the XlaBuilder, the rest
+ // of the fields should be identical.
+ string str1, str2;
+ instr1.AppendPartialToString(&str1);
+ instr2.AppendPartialToString(&str2);
+ EXPECT_EQ(str1, str2);
+ }
}
}
}
diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc
index 8423921086..3dd2d183f3 100644
--- a/tensorflow/compiler/tf2xla/xla_context.cc
+++ b/tensorflow/compiler/tf2xla/xla_context.cc
@@ -25,7 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/xla/client/client_library.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/statusor.h"
@@ -63,7 +63,7 @@ void XlaContext::set_args(std::vector<XlaExpression> args) {
}
XlaContext::XlaContext(
- XlaCompiler* compiler, xla::ComputationBuilder* builder,
+ XlaCompiler* compiler, xla::XlaBuilder* builder,
bool allow_cpu_custom_calls, bool resolve_compile_time_constants,
const std::function<TensorShape(const TensorShape&, DataType)>*
variable_representation_shape_fn)
@@ -78,7 +78,7 @@ string XlaContext::DebugString() { return "TLA JIT context"; }
// This is called by the Retval Op to associate a computed value
// with a specific return value of the subgraph.
void XlaContext::AddRetval(int retval_index, DataType type,
- const xla::ComputationDataHandle& handle) {
+ const xla::XlaOp& handle) {
VLOG(1) << "Added retval index " << retval_index << " to XLA computation";
// Add the return value to the list being built up.
if (retvals_.size() <= retval_index) {
@@ -104,13 +104,12 @@ Status XlaContext::AddConstRetval(int retval_index, DataType dtype,
return Status::OK();
}
-xla::ComputationBuilder* XlaContext::builder() { return builder_; }
+xla::XlaBuilder* XlaContext::builder() { return builder_; }
Status XlaContext::CreateResource(
XlaResource::Kind kind, int arg_num, string name, DataType type,
- TensorShape shape, const xla::ComputationDataHandle& handle,
- int64 tensor_array_size, const std::set<string>& tensor_array_gradients,
- XlaResource** resource) {
+ TensorShape shape, const xla::XlaOp& handle, int64 tensor_array_size,
+ const std::set<string>& tensor_array_gradients, XlaResource** resource) {
resources_.emplace_back(
new XlaResource(kind, arg_num, std::move(name), type, std::move(shape),
handle, tensor_array_size, tensor_array_gradients));
@@ -123,11 +122,11 @@ TensorShape XlaContext::VariableRepresentationShape(const TensorShape& shape,
return (*variable_representation_shape_fn_)(shape, type);
}
-const xla::Computation* XlaContext::GetOrCreateMax(const DataType type) {
+const xla::XlaComputation* XlaContext::GetOrCreateMax(const DataType type) {
return LookupOrCreate(type, &max_func_, [this, type] {
const string type_string = DataTypeString(type);
VLOG(1) << "Building Max() for " << type_string;
- xla::ComputationBuilder b(builder()->client(), "max<" + type_string + ">");
+ xla::XlaBuilder b("max<" + type_string + ">");
xla::PrimitiveType xla_type;
TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type));
auto x = b.Parameter(0, xla::ShapeUtil::MakeShape(xla_type, {}), "x");
@@ -137,11 +136,11 @@ const xla::Computation* XlaContext::GetOrCreateMax(const DataType type) {
});
}
-const xla::Computation* XlaContext::GetOrCreateMin(const DataType type) {
+const xla::XlaComputation* XlaContext::GetOrCreateMin(const DataType type) {
return LookupOrCreate(type, &min_func_, [this, type] {
const string type_string = DataTypeString(type);
VLOG(1) << "Building Min() for " << type_string;
- xla::ComputationBuilder b(builder()->client(), "min<" + type_string + ">");
+ xla::XlaBuilder b("min<" + type_string + ">");
xla::PrimitiveType xla_type;
TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type));
auto x = b.Parameter(0, xla::ShapeUtil::MakeShape(xla_type, {}), "x");
@@ -151,11 +150,11 @@ const xla::Computation* XlaContext::GetOrCreateMin(const DataType type) {
});
}
-const xla::Computation* XlaContext::GetOrCreateAdd(const DataType type) {
+const xla::XlaComputation* XlaContext::GetOrCreateAdd(const DataType type) {
return LookupOrCreate(type, &add_func_, [this, type] {
const string type_string = DataTypeString(type);
VLOG(1) << "Building Add() for " << type_string;
- xla::ComputationBuilder b(builder()->client(), "add<" + type_string + ">");
+ xla::XlaBuilder b("add<" + type_string + ">");
xla::PrimitiveType xla_type;
TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type));
auto x = b.Parameter(0, xla::ShapeUtil::MakeShape(xla_type, {}), "x");
@@ -165,11 +164,11 @@ const xla::Computation* XlaContext::GetOrCreateAdd(const DataType type) {
});
}
-const xla::Computation* XlaContext::GetOrCreateMul(const DataType type) {
+const xla::XlaComputation* XlaContext::GetOrCreateMul(const DataType type) {
return LookupOrCreate(type, &mul_func_, [this, type] {
const string type_string = DataTypeString(type);
VLOG(1) << "Building Mul() for " << type_string;
- xla::ComputationBuilder b(builder()->client(), "mul<" + type_string + ">");
+ xla::XlaBuilder b("mul<" + type_string + ">");
xla::PrimitiveType xla_type;
TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type));
auto x = b.Parameter(0, xla::ShapeUtil::MakeShape(xla_type, {}), "x");
@@ -179,9 +178,9 @@ const xla::Computation* XlaContext::GetOrCreateMul(const DataType type) {
});
}
-const xla::Computation* XlaContext::LookupOrCreate(
+const xla::XlaComputation* XlaContext::LookupOrCreate(
DataType type, ComputationMap* out,
- const std::function<xla::Computation()>& create) {
+ const std::function<xla::XlaComputation()>& create) {
{
const auto& entry = (*out)[type];
if (!entry.IsNull()) {
diff --git a/tensorflow/compiler/tf2xla/xla_context.h b/tensorflow/compiler/tf2xla/xla_context.h
index 00fbaba37c..1136ffe507 100644
--- a/tensorflow/compiler/tf2xla/xla_context.h
+++ b/tensorflow/compiler/tf2xla/xla_context.h
@@ -22,8 +22,8 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
-#include "tensorflow/compiler/xla/client/computation.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
@@ -43,7 +43,7 @@ class XlaContext : public ResourceBase {
static XlaContext& Get(const XlaOpKernelContext* ctx);
// Creates a new XlaContext.
- XlaContext(XlaCompiler* compiler, xla::ComputationBuilder* builder,
+ XlaContext(XlaCompiler* compiler, xla::XlaBuilder* builder,
bool allow_cpu_custom_calls, bool resolve_compile_time_constants,
const std::function<TensorShape(const TensorShape&, DataType)>*
variable_representation_shape_fn);
@@ -53,9 +53,8 @@ class XlaContext : public ResourceBase {
XlaCompiler* compiler() const { return compiler_; }
- // Returns the ComputationBuilder that Ops use for compiling new
- // expressions.
- xla::ComputationBuilder* builder();
+ // Returns the XlaBuilder that Ops use for compiling new expressions.
+ xla::XlaBuilder* builder();
bool allow_cpu_custom_calls() const { return allow_cpu_custom_calls_; }
@@ -66,8 +65,7 @@ class XlaContext : public ResourceBase {
// This is called by the Retval Op to associate a computed value
// with a specific return value of the subgraph.
- void AddRetval(int retval_index, DataType type,
- const xla::ComputationDataHandle& handle);
+ void AddRetval(int retval_index, DataType type, const xla::XlaOp& handle);
// As for Retval, but for return values that are compile-time constants.
Status AddConstRetval(int retval_index, DataType dtype,
@@ -79,8 +77,7 @@ class XlaContext : public ResourceBase {
// Fails if the resource already exists.
Status CreateResource(XlaResource::Kind kind, int arg_num, string name,
DataType type, TensorShape shape,
- const xla::ComputationDataHandle& handle,
- int64 tensor_array_size,
+ const xla::XlaOp& handle, int64 tensor_array_size,
const std::set<string>& tensor_array_gradients,
XlaResource** resource);
@@ -96,22 +93,22 @@ class XlaContext : public ResourceBase {
// Get an XLA lambda to compute Max. This is cached in the
// XlaContext since it may be used by multiple Ops. There is a
// separate specialization of the computation for each DataType.
- const xla::Computation* GetOrCreateMax(const DataType type);
+ const xla::XlaComputation* GetOrCreateMax(const DataType type);
// Get an XLA lambda to compute Min. This is cached in the
// XlaContext since it may be used by multiple Ops. There is a
// separate specialization of the computation for each DataType.
- const xla::Computation* GetOrCreateMin(const DataType type);
+ const xla::XlaComputation* GetOrCreateMin(const DataType type);
// Get an XLA lambda to compute Add. This is cached in the
// XlaContext since it may be used by multiple Ops. There is a
// separate specialization of the computation for each DataType.
- const xla::Computation* GetOrCreateAdd(const DataType type);
+ const xla::XlaComputation* GetOrCreateAdd(const DataType type);
// Get an XLA lambda to compute Mul. This is cached in the
// XlaContext since it may be used by multiple Ops. There is a
// separate specialization of the computation for each DataType.
- const xla::Computation* GetOrCreateMul(const DataType type);
+ const xla::XlaComputation* GetOrCreateMul(const DataType type);
// The name of the XlaContext resource during symbolic graph execution.
static const char kXlaContextResourceName[];
@@ -119,9 +116,8 @@ class XlaContext : public ResourceBase {
private:
XlaCompiler* const compiler_;
- // The ComputationBuilder used to construct the subgraph's compiled
- // representation.
- xla::ComputationBuilder* builder_;
+ // The XlaBuilder used to construct the subgraph's compiled representation.
+ xla::XlaBuilder* builder_;
// Allow ops to emit CustomCall operations for CPU.
const bool allow_cpu_custom_calls_;
@@ -146,14 +142,14 @@ class XlaContext : public ResourceBase {
variable_representation_shape_fn_;
// Cache of prebuilt computations indexed by their type.
- using ComputationMap = std::map<DataType, xla::Computation>;
+ using ComputationMap = std::map<DataType, xla::XlaComputation>;
// Finds the value for the given type in out map if it already
// exists or makes a new value with create function and keeps it the
// map. The returned value != nullptr and is owned by the map.
- const xla::Computation* LookupOrCreate(
+ const xla::XlaComputation* LookupOrCreate(
DataType type, ComputationMap* out,
- const std::function<xla::Computation()>& create);
+ const std::function<xla::XlaComputation()>& create);
// Cached computation to compute Max of two elements, specialized by type.
ComputationMap max_func_;
diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc
index 62a5114837..f1594193af 100644
--- a/tensorflow/compiler/tf2xla/xla_helpers.cc
+++ b/tensorflow/compiler/tf2xla/xla_helpers.cc
@@ -22,7 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
@@ -32,13 +32,12 @@ namespace tensorflow {
namespace {
-Status ArgMinMax(xla::ComputationBuilder* builder, XlaOpKernelContext* ctx,
- const xla::ComputationDataHandle& input,
- const TensorShape& input_shape, DataType input_type,
- DataType output_type, int axis, bool is_min,
- xla::ComputationDataHandle* argminmax) {
- xla::ComputationDataHandle init_value;
- const xla::Computation* reducer;
+Status ArgMinMax(xla::XlaBuilder* builder, XlaOpKernelContext* ctx,
+ const xla::XlaOp& input, const TensorShape& input_shape,
+ DataType input_type, DataType output_type, int axis,
+ bool is_min, xla::XlaOp* argminmax) {
+ xla::XlaOp init_value;
+ const xla::XlaComputation* reducer;
if (is_min) {
init_value = XlaHelpers::MaxValue(builder, input_type);
reducer = ctx->GetOrCreateMin(input_type);
@@ -50,13 +49,13 @@ Status ArgMinMax(xla::ComputationBuilder* builder, XlaOpKernelContext* ctx,
xla::PrimitiveType xla_output_type;
TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(output_type, &xla_output_type));
- xla::ComputationDataHandle input_max = builder->Reduce(
- input, init_value, *reducer, /*dimensions_to_reduce=*/{axis});
+ xla::XlaOp input_max = builder->Reduce(input, init_value, *reducer,
+ /*dimensions_to_reduce=*/{axis});
std::vector<int64> broadcast_dims(input_shape.dims() - 1);
std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0);
std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1);
// Compute a mask that has 1s for elements equal to the maximum.
- xla::ComputationDataHandle partial_mask = builder->ConvertElementType(
+ xla::XlaOp partial_mask = builder->ConvertElementType(
builder->Eq(input, input_max, broadcast_dims), xla_output_type);
// In order to make identity elements for a bitwise And, we:
@@ -65,23 +64,23 @@ Status ArgMinMax(xla::ComputationBuilder* builder, XlaOpKernelContext* ctx,
// 0xFF...F
int32 bits_in_type =
xla::ShapeUtil::ByteSizeOfPrimitiveType(xla_output_type) * 8 - 1;
- xla::ComputationDataHandle shift_amount =
+ xla::XlaOp shift_amount =
XlaHelpers::IntegerLiteral(builder, output_type, bits_in_type);
- xla::ComputationDataHandle full_mask = builder->ShiftRightArithmetic(
+ xla::XlaOp full_mask = builder->ShiftRightArithmetic(
builder->ShiftLeft(partial_mask, shift_amount), shift_amount);
// And with the vector [0, 1, 2, ...] to convert each 0xFF...F into its
// index.
- xla::ComputationDataHandle iota;
+ xla::XlaOp iota;
const int64 axis_size = input_shape.dim_size(axis);
TF_RETURN_IF_ERROR(XlaHelpers::Iota(builder, output_type, axis_size, &iota));
- xla::ComputationDataHandle product =
+ xla::XlaOp product =
builder->And(full_mask, iota, /*broadcast_dimensions=*/{axis});
// If there are multiple maximum elements, choose the one with the highest
// index.
- xla::ComputationDataHandle output =
+ xla::XlaOp output =
builder->Reduce(product, XlaHelpers::MinValue(builder, output_type),
*ctx->GetOrCreateMax(output_type),
/*dimensions_to_reduce=*/{axis});
@@ -91,36 +90,31 @@ Status ArgMinMax(xla::ComputationBuilder* builder, XlaOpKernelContext* ctx,
} // namespace
-xla::ComputationDataHandle XlaHelpers::MinValue(xla::ComputationBuilder* b,
- DataType data_type) {
+xla::XlaOp XlaHelpers::MinValue(xla::XlaBuilder* b, DataType data_type) {
xla::PrimitiveType type;
TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
return b->ConstantLiteral(xla::Literal::MinValue(type));
}
-xla::ComputationDataHandle XlaHelpers::MaxValue(xla::ComputationBuilder* b,
- DataType data_type) {
+xla::XlaOp XlaHelpers::MaxValue(xla::XlaBuilder* b, DataType data_type) {
xla::PrimitiveType type;
TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
return b->ConstantLiteral(xla::Literal::MaxValue(type));
}
-xla::ComputationDataHandle XlaHelpers::Zero(xla::ComputationBuilder* b,
- DataType data_type) {
+xla::XlaOp XlaHelpers::Zero(xla::XlaBuilder* b, DataType data_type) {
xla::PrimitiveType type;
TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
return b->ConstantLiteral(xla::Literal::Zero(type));
}
-xla::ComputationDataHandle XlaHelpers::One(xla::ComputationBuilder* b,
- DataType data_type) {
+xla::XlaOp XlaHelpers::One(xla::XlaBuilder* b, DataType data_type) {
xla::PrimitiveType type;
TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
return b->ConstantLiteral(xla::Literal::One(type));
}
-xla::ComputationDataHandle XlaHelpers::Epsilon(xla::ComputationBuilder* b,
- DataType data_type) {
+xla::XlaOp XlaHelpers::Epsilon(xla::XlaBuilder* b, DataType data_type) {
switch (data_type) {
case DT_HALF:
return b->ConstantR0<Eigen::half>(
@@ -137,16 +131,15 @@ xla::ComputationDataHandle XlaHelpers::Epsilon(xla::ComputationBuilder* b,
}
}
-xla::ComputationDataHandle XlaHelpers::IntegerLiteral(
- xla::ComputationBuilder* b, DataType data_type, int64 value) {
+xla::XlaOp XlaHelpers::IntegerLiteral(xla::XlaBuilder* b, DataType data_type,
+ int64 value) {
xla::PrimitiveType type;
TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
return ::tensorflow::IntegerLiteral(b, type, value);
}
-xla::ComputationDataHandle XlaHelpers::FloatLiteral(xla::ComputationBuilder* b,
- DataType data_type,
- double value) {
+xla::XlaOp XlaHelpers::FloatLiteral(xla::XlaBuilder* b, DataType data_type,
+ double value) {
xla::PrimitiveType type;
TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
return ::tensorflow::FloatLiteral(b, type, value);
@@ -183,28 +176,24 @@ static Tensor MakeLinspaceTensor(const TensorShape& shape, int64 depth) {
return linspace;
}
-Status XlaHelpers::ArgMax(xla::ComputationBuilder* builder,
- XlaOpKernelContext* ctx,
- const xla::ComputationDataHandle& input,
+Status XlaHelpers::ArgMax(xla::XlaBuilder* builder, XlaOpKernelContext* ctx,
+ const xla::XlaOp& input,
const TensorShape& input_shape, DataType input_type,
- DataType output_type, int axis,
- xla::ComputationDataHandle* argmax) {
+ DataType output_type, int axis, xla::XlaOp* argmax) {
return ArgMinMax(builder, ctx, input, input_shape, input_type, output_type,
axis, /*is_min=*/false, argmax);
}
-Status XlaHelpers::ArgMin(xla::ComputationBuilder* builder,
- XlaOpKernelContext* ctx,
- const xla::ComputationDataHandle& input,
+Status XlaHelpers::ArgMin(xla::XlaBuilder* builder, XlaOpKernelContext* ctx,
+ const xla::XlaOp& input,
const TensorShape& input_shape, DataType input_type,
- DataType output_type, int axis,
- xla::ComputationDataHandle* argmin) {
+ DataType output_type, int axis, xla::XlaOp* argmin) {
return ArgMinMax(builder, ctx, input, input_shape, input_type, output_type,
axis, /*is_min=*/true, argmin);
}
-Status XlaHelpers::Iota(xla::ComputationBuilder* builder, DataType dtype,
- int64 size, xla::ComputationDataHandle* iota) {
+Status XlaHelpers::Iota(xla::XlaBuilder* builder, DataType dtype, int64 size,
+ xla::XlaOp* iota) {
TensorShape linspace_shape({size});
Tensor linspace;
switch (dtype) {
@@ -227,13 +216,10 @@ Status XlaHelpers::Iota(xla::ComputationBuilder* builder, DataType dtype,
return Status::OK();
}
-Status XlaHelpers::OneHot(xla::ComputationBuilder* builder, int64 depth,
- int axis, DataType index_type,
- const TensorShape& indices_shape,
- const xla::ComputationDataHandle& indices,
- const xla::ComputationDataHandle& on_value,
- const xla::ComputationDataHandle& off_value,
- xla::ComputationDataHandle* one_hot) {
+Status XlaHelpers::OneHot(xla::XlaBuilder* builder, int64 depth, int axis,
+ DataType index_type, const TensorShape& indices_shape,
+ const xla::XlaOp& indices, const xla::XlaOp& on_value,
+ const xla::XlaOp& off_value, xla::XlaOp* one_hot) {
const int indices_dims = indices_shape.dims();
const int output_dims = indices_dims + 1;
@@ -267,7 +253,7 @@ Status XlaHelpers::OneHot(xla::ComputationBuilder* builder, int64 depth,
std::vector<int64> broadcast_dims(indices_shape.dims());
std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0);
std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1);
- xla::ComputationDataHandle one_hot_bool = builder->Eq(
+ xla::XlaOp one_hot_bool = builder->Eq(
indices, builder->ConstantLiteral(linspace_literal), broadcast_dims);
// Selects the user-provided off_value and on_value values.
@@ -278,16 +264,15 @@ Status XlaHelpers::OneHot(xla::ComputationBuilder* builder, int64 depth,
}
DataType XlaHelpers::SumAccumulationType(const DataType& dtype) {
- if (dtype == DT_BFLOAT16) {
+ if (dtype == DT_BFLOAT16 || dtype == DT_HALF) {
return DT_FLOAT;
}
return dtype;
}
-xla::ComputationDataHandle XlaHelpers::ConvertElementType(
- xla::ComputationBuilder* const builder,
- const xla::ComputationDataHandle& operand,
- const DataType new_element_type) {
+xla::XlaOp XlaHelpers::ConvertElementType(xla::XlaBuilder* const builder,
+ const xla::XlaOp& operand,
+ const DataType new_element_type) {
xla::PrimitiveType convert_to;
TF_CHECK_OK(DataTypeToPrimitiveType(new_element_type, &convert_to));
return builder->ConvertElementType(operand, convert_to);
diff --git a/tensorflow/compiler/tf2xla/xla_helpers.h b/tensorflow/compiler/tf2xla/xla_helpers.h
index 68ab93b64a..c3fdc5252e 100644
--- a/tensorflow/compiler/tf2xla/xla_helpers.h
+++ b/tensorflow/compiler/tf2xla/xla_helpers.h
@@ -19,7 +19,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_TF2XLA_XLA_HELPERS_H_
#include "tensorflow/compiler/tf2xla/xla_context.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
@@ -30,41 +30,34 @@ class XlaHelpers {
public:
// Returns a handle representing the minimum value of a scalar
// element of data_type.
- static xla::ComputationDataHandle MinValue(xla::ComputationBuilder* b,
- DataType data_type);
+ static xla::XlaOp MinValue(xla::XlaBuilder* b, DataType data_type);
// Returns a handle representing the maximum value of a scalar
// element of data_type.
- static xla::ComputationDataHandle MaxValue(xla::ComputationBuilder* b,
- DataType data_type);
+ static xla::XlaOp MaxValue(xla::XlaBuilder* b, DataType data_type);
// Returns a handle representing the zero value of a scalar
// element of data_type.
- static xla::ComputationDataHandle Zero(xla::ComputationBuilder* b,
- DataType data_type);
+ static xla::XlaOp Zero(xla::XlaBuilder* b, DataType data_type);
// Returns a handle representing the one value of a scalar
// element of data_type.
- static xla::ComputationDataHandle One(xla::ComputationBuilder* b,
- DataType data_type);
+ static xla::XlaOp One(xla::XlaBuilder* b, DataType data_type);
// Returns the machine epsilon for floating-point type `data_type`, i.e.,
// the difference between 1.0 and the next representable value.
- static xla::ComputationDataHandle Epsilon(xla::ComputationBuilder* b,
- DataType data_type);
+ static xla::XlaOp Epsilon(xla::XlaBuilder* b, DataType data_type);
// Returns a handle representing the given value of an integer scalar
// element of data_type.
// Note that unlike One and Zero, does not work on boolean types.
- static xla::ComputationDataHandle IntegerLiteral(xla::ComputationBuilder* b,
- DataType data_type,
- int64 value);
+ static xla::XlaOp IntegerLiteral(xla::XlaBuilder* b, DataType data_type,
+ int64 value);
// Returns a handle representing the given value of a floating-point scalar
// element of data_type.
- static xla::ComputationDataHandle FloatLiteral(xla::ComputationBuilder* b,
- DataType data_type,
- double value);
+ static xla::XlaOp FloatLiteral(xla::XlaBuilder* b, DataType data_type,
+ double value);
// Reshapes literal 'input' to have 'shape'. Both the original shape and
// 'shape' must contain the same number of elements.
@@ -75,38 +68,32 @@ class XlaHelpers {
// Sets `argmax` to the argmax of `input` along `axis`. `input_shape` and
// `input_dtype` are the shape and dtype of `input` respectively, and
// `output_type` is the dtype to use for `argmax`.
- static Status ArgMax(xla::ComputationBuilder* builder,
- XlaOpKernelContext* ctx,
- const xla::ComputationDataHandle& input,
- const TensorShape& input_shape, DataType input_type,
- DataType output_type, int axis,
- xla::ComputationDataHandle* argmax);
+ static Status ArgMax(xla::XlaBuilder* builder, XlaOpKernelContext* ctx,
+ const xla::XlaOp& input, const TensorShape& input_shape,
+ DataType input_type, DataType output_type, int axis,
+ xla::XlaOp* argmax);
// Sets `argmin` to the argmin of `input` along `axis`. `input_shape` and
// `input_dtype` are the shape and dtype of `input` respectively, and
// `output_type` is the dtype to use for `argmin`.
- static Status ArgMin(xla::ComputationBuilder* builder,
- XlaOpKernelContext* ctx,
- const xla::ComputationDataHandle& input,
- const TensorShape& input_shape, DataType input_type,
- DataType output_type, int axis,
- xla::ComputationDataHandle* argmin);
+ static Status ArgMin(xla::XlaBuilder* builder, XlaOpKernelContext* ctx,
+ const xla::XlaOp& input, const TensorShape& input_shape,
+ DataType input_type, DataType output_type, int axis,
+ xla::XlaOp* argmin);
// Sets *iota to a rank 1 tensor with values [0, 1, 2, ...] of `dtype`.
- static Status Iota(xla::ComputationBuilder* builder, DataType dtype,
- int64 size, xla::ComputationDataHandle* iota);
+ static Status Iota(xla::XlaBuilder* builder, DataType dtype, int64 size,
+ xla::XlaOp* iota);
// Converts `indices` into a one-hot representation. `depth` is the size
// of the new axis to add. `axis` is the position at which to add the new
// axis. `indices_shape` is the shape of `indices`. `on_value` and
// `off_value` represent the values to use for the on and off positions,
// respectively.
- static Status OneHot(xla::ComputationBuilder* builder, int64 depth, int axis,
+ static Status OneHot(xla::XlaBuilder* builder, int64 depth, int axis,
DataType index_type, const TensorShape& indices_shape,
- const xla::ComputationDataHandle& indices,
- const xla::ComputationDataHandle& on_value,
- const xla::ComputationDataHandle& off_value,
- xla::ComputationDataHandle* one_hot);
+ const xla::XlaOp& indices, const xla::XlaOp& on_value,
+ const xla::XlaOp& off_value, xla::XlaOp* one_hot);
// Certain DataTypes should use increased precision DataTypes when performing
// reductions. This function remaps a given DataType to a higher precision
@@ -115,10 +102,9 @@ class XlaHelpers {
// A helper for creating a ConvertElementType xla op given a DataType rather
// than the xla::PrimitiveType.
- static xla::ComputationDataHandle ConvertElementType(
- xla::ComputationBuilder* const builder,
- const xla::ComputationDataHandle& operand,
- const DataType new_element_type);
+ static xla::XlaOp ConvertElementType(xla::XlaBuilder* const builder,
+ const xla::XlaOp& operand,
+ const DataType new_element_type);
};
} // end namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc
index 1fe6e69ff2..9e17756b27 100644
--- a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc
+++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc
@@ -112,10 +112,10 @@ void CollectNames(const T& entries, std::vector<string>* nonempty_names,
XlaJitCompiledCpuFunction::Compile(
const GraphDef& graph_def, const tf2xla::Config& config,
const xla::ExecutableBuildOptions& build_options) {
- // Convert the graph_def into an xla::Computation.
+ // Convert the graph_def into an xla::XlaComputation.
TF_ASSIGN_OR_RETURN(xla::LocalClient * client,
xla::ClientLibrary::GetOrCreateLocalClient());
- xla::Computation computation;
+ xla::XlaComputation computation;
TF_RETURN_IF_ERROR(tensorflow::ConvertGraphDefToXla(graph_def, config, client,
&computation));
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
index c4bb90d587..2b65f4d5d5 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
@@ -30,7 +30,7 @@ bool XlaOpKernelContext::ValidateInputsAreSameShape(OpKernel* op) {
return context_->ValidateInputsAreSameShape(op);
}
-xla::ComputationBuilder* XlaOpKernelContext::builder() const {
+xla::XlaBuilder* XlaOpKernelContext::builder() const {
return XlaContext::Get(this).builder();
}
@@ -38,9 +38,9 @@ xla::ComputationBuilder* XlaOpKernelContext::builder() const {
static const XlaExpression* CastExpressionFromTensor(const Tensor& tensor) {
const XlaExpression* expression =
reinterpret_cast<const XlaExpression*>(tensor.tensor_data().data());
- CHECK(expression->handle().handle() != 0 ||
+ CHECK(expression->handle().builder() != nullptr ||
expression->resource() != nullptr);
- VLOG(1) << "Fetched T" << expression->handle().handle();
+ VLOG(1) << "Fetched T" << expression->handle();
return expression;
}
@@ -48,20 +48,18 @@ static const XlaExpression* CastExpressionFromTensor(const Tensor& tensor) {
static XlaExpression* CastExpressionFromUninitializedTensor(Tensor* tensor) {
const XlaExpression* expression =
reinterpret_cast<const XlaExpression*>(tensor->tensor_data().data());
- CHECK_EQ(expression->handle().handle(), 0);
+ CHECK_EQ(expression->handle().builder(), nullptr);
return const_cast<XlaExpression*>(expression);
}
-// Retrieves the ComputationDataHandle from an input Tensor to an Op. This
-// computation was constructed by an Op that executed previously and
-// created the output Tensor using CreateOutputTensorFromComputation
-// or CreateConstantOutputTensor.
-static const xla::ComputationDataHandle& GetComputationFromTensor(
- const Tensor& tensor) {
+// Retrieves the XlaOp from an input Tensor to an Op. This computation was
+// constructed by an Op that executed previously and created the output Tensor
+// using CreateOutputTensorFromComputation or CreateConstantOutputTensor.
+static const xla::XlaOp& GetComputationFromTensor(const Tensor& tensor) {
return CastExpressionFromTensor(tensor)->handle();
}
-const xla::ComputationDataHandle& XlaOpKernelContext::Input(int index) {
+const xla::XlaOp& XlaOpKernelContext::Input(int index) {
return GetComputationFromTensor(context_->input(index));
}
@@ -106,7 +104,7 @@ Status XlaOpKernelContext::ConstantInputReshaped(
return HostTensorToLiteral(temp, constant_literal);
}
- xla::ComputationDataHandle handle = expression->handle();
+ xla::XlaOp handle = expression->handle();
if (new_shape != tensor.shape()) {
// Reshape the handle to the desired shape.
handle = builder()->Reshape(handle, new_shape.dim_sizes());
@@ -141,8 +139,17 @@ Status XlaOpKernelContext::ConstantInputReshaped(
}
// Ask the XLA compiler to evaluate the data handle to a literal.
+ xla::StatusOr<xla::XlaComputation> constant_graph =
+ builder()->BuildConstantSubGraph(handle);
+ if (!constant_graph.ok()) {
+ return errors::Internal(
+ "Error getting a compile-time constant graph for ",
+ context_->op_kernel().name(), " input ", index,
+ ".\nError: ", constant_graph.status().error_message());
+ }
xla::StatusOr<std::unique_ptr<xla::Literal>> computed =
- builder()->ComputeConstant(handle, &layout);
+ compiler()->client()->ComputeConstant(constant_graph.ValueOrDie(),
+ &layout);
if (!computed.ok()) {
return errors::Internal("Error evaluating ", context_->op_kernel().name(),
" input ", index,
@@ -260,9 +267,9 @@ Status XlaOpKernelContext::ConstantInputAsShape(int index, TensorShape* shape) {
return Status::OK();
}
-Status XlaOpKernelContext::InputList(
- StringPiece name, std::vector<xla::ComputationDataHandle>* handles,
- std::vector<TensorShape>* shapes) {
+Status XlaOpKernelContext::InputList(StringPiece name,
+ std::vector<xla::XlaOp>* handles,
+ std::vector<TensorShape>* shapes) {
OpInputList inputs;
TF_RETURN_IF_ERROR(context_->input_list(name, &inputs));
handles->clear();
@@ -285,9 +292,9 @@ Status XlaOpKernelContext::ConstantInputList(
return Status::OK();
}
-Status XlaOpKernelContext::ReadVariableInput(
- int index, DataType type, TensorShape* shape,
- xla::ComputationDataHandle* value) {
+Status XlaOpKernelContext::ReadVariableInput(int index, DataType type,
+ TensorShape* shape,
+ xla::XlaOp* value) {
const Tensor& tensor = context_->input(index);
const XlaExpression* expression = CastExpressionFromTensor(tensor);
XlaResource* variable = expression->resource();
@@ -334,8 +341,7 @@ Status XlaOpKernelContext::GetVariableTypeAndShape(int index, DataType* type,
return Status::OK();
}
-void XlaOpKernelContext::SetOutput(int index,
- const xla::ComputationDataHandle& handle) {
+void XlaOpKernelContext::SetOutput(int index, const xla::XlaOp& handle) {
// Makes the host Tensor that will refer to the expression.
Tensor* output = nullptr;
auto shape = builder()->GetShape(handle);
@@ -349,7 +355,7 @@ void XlaOpKernelContext::SetOutput(int index,
// corresponds.
TensorShape tensor_shape;
OP_REQUIRES_OK(context_,
- XLAShapeToTensorShape(*shape.ValueOrDie(), &tensor_shape));
+ XLAShapeToTensorShape(shape.ValueOrDie(), &tensor_shape));
OP_REQUIRES_OK(context_,
context_->allocate_output(index, tensor_shape, &output));
@@ -364,8 +370,8 @@ void XlaOpKernelContext::SetConstantOutput(int index, const Tensor& constant) {
xla::Literal literal;
OP_REQUIRES_OK(context_, HostTensorToLiteral(constant, &literal));
- xla::ComputationDataHandle handle = builder()->ConstantLiteral(literal);
- CHECK_NE(handle.handle(), 0);
+ xla::XlaOp handle = builder()->ConstantLiteral(literal);
+ CHECK_NE(handle.builder(), nullptr);
// Make the Tensor that will refer to the expression.
Tensor* output = nullptr;
@@ -386,8 +392,7 @@ void XlaOpKernelContext::SetInvalidOutput(int index) {
OP_REQUIRES_OK(context_,
context_->allocate_output(index, TensorShape({}), &output));
XlaExpression* expression = CastExpressionFromUninitializedTensor(output);
- xla::ComputationDataHandle handle;
- handle.set_handle(0);
+ xla::XlaOp handle;
expression->set_handle(handle);
}
@@ -410,8 +415,8 @@ Status XlaOpKernelContext::GetResourceInput(int index, XlaResource** resource) {
}
Status XlaOpKernelContext::AssignVariable(int input_index, DataType type,
- xla::ComputationDataHandle handle) {
- TF_RET_CHECK(handle.handle() != 0);
+ xla::XlaOp handle) {
+ TF_RET_CHECK(handle.builder() != nullptr);
const XlaExpression* expression =
CastExpressionFromTensor(context_->input(input_index));
@@ -425,7 +430,7 @@ Status XlaOpKernelContext::AssignVariable(int input_index, DataType type,
}
TensorShape shape;
TF_RETURN_IF_ERROR(
- XLAShapeToTensorShape(*shape_or_status.ValueOrDie(), &shape));
+ XLAShapeToTensorShape(shape_or_status.ValueOrDie(), &shape));
TF_RETURN_IF_ERROR(variable->SetTypeAndShape(type, shape));
@@ -457,22 +462,22 @@ void XlaOpKernelContext::CtxFailureWithWarning(const char* file, int line,
context_->CtxFailureWithWarning(file, line, s);
}
-const xla::Computation* XlaOpKernelContext::GetOrCreateMax(
+const xla::XlaComputation* XlaOpKernelContext::GetOrCreateMax(
const DataType type) {
return XlaContext::Get(context_).GetOrCreateMax(type);
}
-const xla::Computation* XlaOpKernelContext::GetOrCreateMin(
+const xla::XlaComputation* XlaOpKernelContext::GetOrCreateMin(
const DataType type) {
return XlaContext::Get(context_).GetOrCreateMin(type);
}
-const xla::Computation* XlaOpKernelContext::GetOrCreateAdd(
+const xla::XlaComputation* XlaOpKernelContext::GetOrCreateAdd(
const DataType type) {
return XlaContext::Get(context_).GetOrCreateAdd(type);
}
-const xla::Computation* XlaOpKernelContext::GetOrCreateMul(
+const xla::XlaComputation* XlaOpKernelContext::GetOrCreateMul(
const DataType type) {
return XlaContext::Get(context_).GetOrCreateMul(type);
}
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h
index 4e4b97e0ce..667dc262ca 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.h
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h
@@ -17,7 +17,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_TF2XLA_XLA_OP_KERNEL_H_
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/platform/macros.h"
@@ -58,8 +58,8 @@ class XlaOpKernelContext {
public:
explicit XlaOpKernelContext(OpKernelContext* context);
- // Returns the XLA ComputationBuilder containing the output of compilation.
- xla::ComputationBuilder* builder() const;
+ // Returns the XLA XlaBuilder containing the output of compilation.
+ xla::XlaBuilder* builder() const;
// Inputs
@@ -72,10 +72,10 @@ class XlaOpKernelContext {
// Returns the shape of input 'index'.
TensorShape InputShape(int index);
- // Returns input 'index' as a ComputationDataHandle. Unlike
+ // Returns input 'index' as a XlaOp. Unlike
// OpKernelContext::Input returns a symbolic value rather than a concrete
// Tensor.
- const xla::ComputationDataHandle& Input(int index);
+ const xla::XlaOp& Input(int index);
// Returns true if all inputs are the same shape, otherwise sets the
// status to a non-OK value and returns false.
@@ -85,8 +85,7 @@ class XlaOpKernelContext {
// Returns the named list-valued immutable input in "list", as
// defined in the OpDef. If the named output is not list-valued,
// returns a one-element list.
- Status InputList(StringPiece name,
- std::vector<xla::ComputationDataHandle>* handles,
+ Status InputList(StringPiece name, std::vector<xla::XlaOp>* handles,
std::vector<TensorShape>* shapes);
// Helper methods for constant inputs.
@@ -132,10 +131,10 @@ class XlaOpKernelContext {
return context_->expected_output_dtype(index);
}
- // Sets output 'index' to the ComputationDataHandle 'handle'.
+ // Sets output 'index' to the XlaOp 'handle'.
// All outputs should be set using SetOutput and SetConstantOutput, not
// via the underlying OpKernelContext.
- void SetOutput(int index, const xla::ComputationDataHandle& handle);
+ void SetOutput(int index, const xla::XlaOp& handle);
// Sets output 'index' to compile-time constant 'host_tensor', where
// 'host_tensor' is a tensor in host memory. It is preferable to use
@@ -168,14 +167,13 @@ class XlaOpKernelContext {
// variable. Returns an error if the variable has not been initialized, or if
// its type does not match `type`.
Status ReadVariableInput(int index, DataType type, TensorShape* shape,
- xla::ComputationDataHandle* value);
+ xla::XlaOp* value);
// Assigns the value `handle` to the variable referenced by input
// `input_index`. The variable must be of `type`. Returns an error if the
// variable has been initialized with a different type or with a
// different shape.
- Status AssignVariable(int input_index, DataType type,
- xla::ComputationDataHandle handle);
+ Status AssignVariable(int input_index, DataType type, xla::XlaOp handle);
// Helper routines for the OP_REQUIRES macros
void CtxFailure(const Status& s);
@@ -205,22 +203,22 @@ class XlaOpKernelContext {
// Gets an XLA lambda to compute Max. This is cached in the
// XlaContext since it may be used by multiple Ops. There is a
// separate specialization of the computation for each DataType.
- const xla::Computation* GetOrCreateMax(const DataType type);
+ const xla::XlaComputation* GetOrCreateMax(const DataType type);
// Gets an XLA lambda to compute Min. This is cached in the
// XlaContext since it may be used by multiple Ops. There is a
// separate specialization of the computation for each DataType.
- const xla::Computation* GetOrCreateMin(const DataType type);
+ const xla::XlaComputation* GetOrCreateMin(const DataType type);
// Gets an XLA lambda to compute Add. This is cached in the
// XlaContext since it may be used by multiple Ops. There is a
// separate specialization of the computation for each DataType.
- const xla::Computation* GetOrCreateAdd(const DataType type);
+ const xla::XlaComputation* GetOrCreateAdd(const DataType type);
// Gets an XLA lambda to compute Mul. This is cached in the
// XlaContext since it may be used by multiple Ops. There is a
// separate specialization of the computation for each DataType.
- const xla::Computation* GetOrCreateMul(const DataType type);
+ const xla::XlaComputation* GetOrCreateMul(const DataType type);
private:
OpKernelContext* const context_;
diff --git a/tensorflow/compiler/tf2xla/xla_resource.cc b/tensorflow/compiler/tf2xla/xla_resource.cc
index c2075b44b8..540c65c597 100644
--- a/tensorflow/compiler/tf2xla/xla_resource.cc
+++ b/tensorflow/compiler/tf2xla/xla_resource.cc
@@ -26,8 +26,7 @@ limitations under the License.
namespace tensorflow {
XlaResource::XlaResource(Kind kind, int arg_num, string name, DataType type,
- TensorShape shape,
- const xla::ComputationDataHandle& initial_value,
+ TensorShape shape, const xla::XlaOp& initial_value,
int64 tensor_array_size,
const std::set<string>& tensor_array_gradients)
: kind_(kind),
@@ -41,11 +40,10 @@ XlaResource::XlaResource(Kind kind, int arg_num, string name, DataType type,
CHECK(kind_ != kInvalid);
for (const string& gradient : tensor_array_gradients) {
- tensor_array_gradients_[gradient].reset(
- new XlaResource(/*kind=*/kTensorArray, /*arg_num=*/-1,
- /*name=*/strings::StrCat("TensorArrayGrad: ", name_),
- type_, shape_, xla::ComputationDataHandle(),
- tensor_array_size_, /*tensor_array_gradients=*/{}));
+ tensor_array_gradients_[gradient].reset(new XlaResource(
+ /*kind=*/kTensorArray, /*arg_num=*/-1,
+ /*name=*/strings::StrCat("TensorArrayGrad: ", name_), type_, shape_,
+ xla::XlaOp(), tensor_array_size_, /*tensor_array_gradients=*/{}));
}
}
@@ -73,7 +71,7 @@ Status XlaResource::SetTypeAndShape(DataType type, const TensorShape& shape) {
return Status::OK();
}
-Status XlaResource::SetValue(const xla::ComputationDataHandle& value) {
+Status XlaResource::SetValue(const xla::XlaOp& value) {
if (type_ == DT_INVALID) {
return errors::InvalidArgument(
"Resource '", name_,
@@ -83,7 +81,7 @@ Status XlaResource::SetValue(const xla::ComputationDataHandle& value) {
return Status::OK();
}
-Status XlaResource::SetZeroValue(xla::ComputationBuilder* builder) {
+Status XlaResource::SetZeroValue(xla::XlaBuilder* builder) {
if (type_ == DT_INVALID) {
return errors::InvalidArgument(
"Resource '", name_,
@@ -121,9 +119,9 @@ Status XlaResource::SetZeroValue(xla::ComputationBuilder* builder) {
return Status::OK();
}
-Status XlaResource::GetOrCreateTensorArrayGradient(
- const string& source, xla::ComputationBuilder* builder,
- XlaResource** gradient_out) {
+Status XlaResource::GetOrCreateTensorArrayGradient(const string& source,
+ xla::XlaBuilder* builder,
+ XlaResource** gradient_out) {
VLOG(2) << "Gradient lookup for resource: " << name_
<< " gradient: " << source;
TF_RET_CHECK(kind_ == kTensorArray);
@@ -132,7 +130,7 @@ Status XlaResource::GetOrCreateTensorArrayGradient(
TensorShape ta_shape;
ta_shape.AddDim(tensor_array_size_);
ta_shape.AppendShape(shape_);
- xla::ComputationDataHandle gradient_value = builder->Broadcast(
+ xla::XlaOp gradient_value = builder->Broadcast(
XlaHelpers::Zero(builder, type_), ta_shape.dim_sizes());
gradient.reset(
new XlaResource(/*kind=*/kTensorArray, /*arg_num=*/-1,
@@ -144,13 +142,12 @@ Status XlaResource::GetOrCreateTensorArrayGradient(
return Status::OK();
}
-Status XlaResource::Pack(xla::ComputationDataHandle* pack,
- xla::ComputationBuilder* builder) const {
+Status XlaResource::Pack(xla::XlaOp* pack, xla::XlaBuilder* builder) const {
if (tensor_array_gradients_.empty()) {
*pack = value_;
} else {
TF_RET_CHECK(kind_ == kTensorArray);
- std::vector<xla::ComputationDataHandle> elems;
+ std::vector<xla::XlaOp> elems;
elems.push_back(value_);
for (const auto& gradient : tensor_array_gradients_) {
elems.push_back(gradient.second->value_);
@@ -161,8 +158,8 @@ Status XlaResource::Pack(xla::ComputationDataHandle* pack,
}
Status XlaResource::SetFromPack(const std::set<string>& gradient_sources,
- const xla::ComputationDataHandle& pack,
- xla::ComputationBuilder* builder) {
+ const xla::XlaOp& pack,
+ xla::XlaBuilder* builder) {
if (gradient_sources.empty()) {
if (!initialized()) {
initial_value_ = pack;
diff --git a/tensorflow/compiler/tf2xla/xla_resource.h b/tensorflow/compiler/tf2xla/xla_resource.h
index 1bb2c7274e..9ce36d1aa7 100644
--- a/tensorflow/compiler/tf2xla/xla_resource.h
+++ b/tensorflow/compiler/tf2xla/xla_resource.h
@@ -18,7 +18,7 @@ limitations under the License.
#include <memory>
-#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
@@ -37,8 +37,7 @@ class XlaResource {
};
XlaResource(Kind kind, int arg_num, string name, DataType type,
- TensorShape shape,
- const xla::ComputationDataHandle& initial_value,
+ TensorShape shape, const xla::XlaOp& initial_value,
int64 tensor_array_size,
const std::set<string>& tensor_array_gradients);
@@ -69,16 +68,14 @@ class XlaResource {
// this is the shape of each entry in the TensorArray/Stack.
const TensorShape& shape() const { return shape_; }
- const xla::ComputationDataHandle& value() const { return value_; }
+ const xla::XlaOp& value() const { return value_; }
// Value of the resource at computation entry. Used to detect which
// variables have new values that need to be written back.
- const xla::ComputationDataHandle& initial_value() const {
- return initial_value_;
- }
+ const xla::XlaOp& initial_value() const { return initial_value_; }
// A variable is initialized if it has a value.
- bool initialized() const { return value_.handle() > 0; }
+ bool initialized() const { return value_.builder() != nullptr; }
// Sets the type and shape of the resource. The type and shape of a resource
// must not change once the variable has been initialized.
@@ -86,17 +83,17 @@ class XlaResource {
// Sets the current value of the resource. Returns an error if the type is not
// set to a valid value.
- Status SetValue(const xla::ComputationDataHandle& value);
+ Status SetValue(const xla::XlaOp& value);
// Sets the current value of the resource to an all-zero value.
- Status SetZeroValue(xla::ComputationBuilder* builder);
+ Status SetZeroValue(xla::XlaBuilder* builder);
// Looks up the gradient for `source`, or creates it if it does not already
// exist. The call target must be an initialized TensorArray resource. A
// TensorArray can have multiple named gradients; see the operator
// documentation for TensorArrayGradV3 for details.
Status GetOrCreateTensorArrayGradient(const string& source,
- xla::ComputationBuilder* builder,
+ xla::XlaBuilder* builder,
XlaResource** gradient_out);
// Packs a resource into a single XLA value `pack`, suitable for use as
@@ -104,8 +101,7 @@ class XlaResource {
// gradients, sets `*pack` to `value`.
// For TensorArrays with gradients, packs the value and its gradient values in
// a tuple; the gradients values are packed in order by source name.
- Status Pack(xla::ComputationDataHandle* pack,
- xla::ComputationBuilder* builder) const;
+ Status Pack(xla::XlaOp* pack, xla::XlaBuilder* builder) const;
// Updates the resource with values from `pack`. If `gradient_sources` is
// non-empty, treats `pack` as a tuple that represents a TensorArray and
@@ -114,8 +110,7 @@ class XlaResource {
// values.
// Opposite of Pack().
Status SetFromPack(const std::set<string>& gradient_sources,
- const xla::ComputationDataHandle& pack,
- xla::ComputationBuilder* builder);
+ const xla::XlaOp& pack, xla::XlaBuilder* builder);
// TensorArray and Stack specific fields
@@ -144,8 +139,8 @@ class XlaResource {
DataType type_;
TensorShape shape_;
- xla::ComputationDataHandle value_;
- xla::ComputationDataHandle initial_value_;
+ xla::XlaOp value_;
+ xla::XlaOp initial_value_;
int64 tensor_array_size_ = -1;
diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD
index 286d06d12f..aac3273d5f 100644
--- a/tensorflow/compiler/xla/client/BUILD
+++ b/tensorflow/compiler/xla/client/BUILD
@@ -106,6 +106,7 @@ cc_library(
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla/service:backend",
"//tensorflow/compiler/xla/service:compiler",
"//tensorflow/compiler/xla/service:device_memory_allocator",
diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc
index 1c12705903..1acc6f8686 100644
--- a/tensorflow/compiler/xla/client/local_client.cc
+++ b/tensorflow/compiler/xla/client/local_client.cc
@@ -51,27 +51,49 @@ LocalExecutable::LocalExecutable(std::unique_ptr<Executable> executable,
tensorflow::Status LocalExecutable::ValidateExecutionOptions(
const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
const ExecutableRunOptions& run_options, const Backend& backend) {
- const ComputationLayout& computation_layout =
- executable_->module_config().entry_computation_layout();
+ const ComputationLayout& host_computation_layout =
+ executable_->module_config().host_entry_computation_layout();
+ const ComputationLayout& device_computation_layout =
+ executable_->module_config().device_entry_computation_layout();
// Check argument number, shapes, and layouts.
- if (arguments.size() != computation_layout.parameter_count()) {
+ if (arguments.size() != host_computation_layout.parameter_count()) {
return InvalidArgument(
"invalid number of arguments for computation: expected %d, got %zu",
- computation_layout.parameter_count(), arguments.size());
+ host_computation_layout.parameter_count(), arguments.size());
+ }
+ if (arguments.size() != device_computation_layout.parameter_count()) {
+ return InvalidArgument(
+ "invalid number of arguments for computation: expected %d, got %zu",
+ device_computation_layout.parameter_count(), arguments.size());
}
for (int i = 0; i < arguments.size(); ++i) {
- if (!computation_layout.parameter_layout(i).MatchesLayoutInShape(
+ if (!host_computation_layout.parameter_layout(i).MatchesLayoutInShape(
arguments[i]->on_host_shape())) {
return InvalidParameterArgument(
executable_.get(), i,
- "Argument does not match shape or layout of computation parameter "
+ "Argument does not match host shape or layout of computation "
+ "parameter "
"%d: want %s, got %s",
i,
- ShapeUtil::HumanString(computation_layout.parameter_layout(i).shape())
+ ShapeUtil::HumanString(
+ host_computation_layout.parameter_layout(i).shape())
.c_str(),
ShapeUtil::HumanString(arguments[i]->on_host_shape()).c_str());
}
+ if (!device_computation_layout.parameter_layout(i).MatchesLayoutInShape(
+ arguments[i]->on_device_shape())) {
+ return InvalidParameterArgument(
+ executable_.get(), i,
+ "Argument does not match device shape or layout of computation "
+ "parameter "
+ "%d: want %s, got %s",
+ i,
+ ShapeUtil::HumanString(
+ device_computation_layout.parameter_layout(i).shape())
+ .c_str(),
+ ShapeUtil::HumanString(arguments[i]->on_device_shape()).c_str());
+ }
}
if (run_options.stream() != nullptr) {
diff --git a/tensorflow/compiler/xla/client/local_client.h b/tensorflow/compiler/xla/client/local_client.h
index f306c520ed..d8fd7a5623 100644
--- a/tensorflow/compiler/xla/client/local_client.h
+++ b/tensorflow/compiler/xla/client/local_client.h
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/client.h"
#include "tensorflow/compiler/xla/client/computation.h"
#include "tensorflow/compiler/xla/client/executable_build_options.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
@@ -42,15 +43,6 @@ class LocalExecutable {
const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
ExecutableRunOptions run_options);
- // Return the layout (contained in a shape) of the result produced by the
- // computation.
- const Shape& result_layout() const {
- return executable_->module_config()
- .entry_computation_layout()
- .result_layout()
- .shape();
- }
-
// Return the options used to build the executable.
const ExecutableBuildOptions& build_options() const { return build_options_; }
diff --git a/tensorflow/compiler/xla/layout_util.cc b/tensorflow/compiler/xla/layout_util.cc
index fdc4bbdd8b..c6f8f6766e 100644
--- a/tensorflow/compiler/xla/layout_util.cc
+++ b/tensorflow/compiler/xla/layout_util.cc
@@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
@@ -465,4 +466,25 @@ std::ostream& operator<<(std::ostream& out, const Layout& layout) {
return out;
}
+/*static*/ size_t LayoutUtil::Hash(const Layout& layout) {
+ using tensorflow::hash;
+ using tensorflow::Hash64Combine;
+
+ size_t hash_value = hash<Format>()(layout.format());
+
+ for (int64 minor_to_major : layout.minor_to_major()) {
+ hash_value = Hash64Combine(hash_value, hash<int64>()(minor_to_major));
+ }
+
+ for (int64 padded_dim : layout.padded_dimensions()) {
+ hash_value = Hash64Combine(hash_value, hash<int64>()(padded_dim));
+ }
+
+ hash_value =
+ Hash64Combine(hash_value, hash<PaddingValue>()(layout.padding_value()));
+ hash_value = Hash64Combine(hash_value, layout.max_sparse_elements());
+
+ return hash_value;
+}
+
} // namespace xla
diff --git a/tensorflow/compiler/xla/layout_util.h b/tensorflow/compiler/xla/layout_util.h
index 6c54eb2201..6cec750101 100644
--- a/tensorflow/compiler/xla/layout_util.h
+++ b/tensorflow/compiler/xla/layout_util.h
@@ -195,6 +195,9 @@ class LayoutUtil {
static bool AreDimensionsConsecutive(const Layout& layout,
tensorflow::gtl::ArraySlice<int64> dims);
+ // Compute a hash for `layout`.
+ static size_t Hash(const Layout& layout);
+
private:
TF_DISALLOW_COPY_AND_ASSIGN(LayoutUtil);
};
diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc
index bb6dd4f909..b3b5e34ba2 100644
--- a/tensorflow/compiler/xla/literal_util.cc
+++ b/tensorflow/compiler/xla/literal_util.cc
@@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/casts.h"
#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
@@ -2148,6 +2149,27 @@ string Literal::GetR1U8AsString() const {
return LiteralView(literal, view_root);
}
+size_t Literal::Hash() const {
+ using tensorflow::Hash64;
+ using tensorflow::Hash64Combine;
+
+ size_t hash_value = ShapeUtil::Hash(shape());
+
+ ShapeUtil::ForEachSubshape(
+ shape(), [&](const Shape& subshape, const ShapeIndex& index) {
+ if (ShapeUtil::IsTuple(subshape)) {
+ return;
+ }
+
+ CHECK(LayoutUtil::IsDense(subshape.layout()));
+ hash_value = Hash64Combine(
+ hash_value, Hash64(static_cast<const char*>(untyped_data(index)),
+ size_bytes(index)));
+ });
+
+ return hash_value;
+}
+
LiteralView::LiteralView(const Literal& literal, const ShapeIndex& view_root) {
shape_ = ShapeUtil::GetSubshape(literal.shape(), view_root);
pieces_ = ShapeTree<Piece>(shape_);
diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h
index 8aa19222dc..290f388078 100644
--- a/tensorflow/compiler/xla/literal_util.h
+++ b/tensorflow/compiler/xla/literal_util.h
@@ -74,6 +74,10 @@ class Literal {
Literal(const Literal& other) = delete;
Literal& operator=(const Literal& other) = delete;
Literal(Literal&& other);
+ // 'allocate_arrays' indicates whether to allocate memory for the arrays in
+ // the shape. If false, buffer pointers inside of the Literal::Pieces are set
+ // to nullptr.
+ Literal(const Shape& shape, bool allocate_arrays);
Literal& operator=(Literal&& other);
// Literals are equal if they have compatible shapes and the same data
@@ -658,12 +662,11 @@ class Literal {
// LayoutUtil::MaxSparseElements(SetSubshape(shape(), index).layout()).
int64 sparse_element_count() const;
- protected:
- // 'allocate_arrays' indicates whether to allocate memory for the arrays in
- // the shape. If false, buffer pointers inside of the Literal::Pieces are set
- // to nullptr.
- Literal(const Shape& shape, bool allocate_arrays);
+ // Compute a hash for this literal. This literal must not be a sparse tensor
+ // or a tuple containing a sparse tensor.
+ size_t Hash() const;
+ protected:
// Internal template helper for the Literal::CopySliceFrom(), matching its
// arguments one by one.
template <typename NativeT>
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index ed0da47681..6e2510aa10 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -2421,7 +2421,6 @@ tf_cc_test(
":hlo_graph_dumper",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:xla_proto",
- "//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
"//tensorflow/core:lib",
diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc
index dbe45e932c..94ccfedf62 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment.cc
@@ -292,112 +292,6 @@ BufferAllocationProto BufferAllocation::ToProto() const {
return proto;
}
-std::pair<int64, std::vector<const LogicalBuffer*>>
-BufferAllocation::ComputePeakMemoryLogicalBuffers() const {
- if (HeapTraces().empty()) {
- // Just return the largest LogicalBuffer in the allocation.
- const LogicalBuffer* largest_buffer = nullptr;
- int64 largest_size = 0;
- for (const auto& pair : assigned_buffers()) {
- const LogicalBuffer* buffer = pair.first;
- int64 size = pair.second.size;
- if (largest_buffer == nullptr) {
- largest_buffer = buffer;
- largest_size = size;
- continue;
- }
- // Tie-break with LogicalBuffer::Id so the return value is stable relative
- // to changing addresses.
- if (size > largest_size ||
- ((size == largest_size) && (largest_buffer->id() > buffer->id()))) {
- largest_buffer = buffer;
- largest_size = size;
- }
- }
- CHECK(largest_buffer != nullptr)
- << "No logical buffers in allocation: " << ToString();
- return {largest_size, {largest_buffer}};
- }
-
- // Create a map from LogicalBuffer::Id to LogicalBuffer* for the logical
- // buffers in this allocation.
- tensorflow::gtl::FlatMap<LogicalBuffer::Id, const LogicalBuffer*>
- id_to_buffer;
- tensorflow::gtl::FlatMap<const LogicalBuffer*, int64> buffer_sizes;
- for (const auto& pair : assigned_buffers()) {
- const LogicalBuffer* buffer = pair.first;
- const OffsetSize& offset_size = pair.second;
- id_to_buffer[buffer->id()] = buffer;
- buffer_sizes[buffer] = offset_size.size;
- }
-
- // Returns how much the given event increases the total size of live
- // buffers. Can be negative.
- auto memory_delta = [this, &id_to_buffer, &buffer_sizes](
- const HeapSimulatorTrace::Event& event) -> int64 {
- const LogicalBuffer* buffer = id_to_buffer.at(event.buffer_id());
- const int64 buffer_size = buffer_sizes.at(buffer);
- if (event.kind() == HeapSimulatorTrace::Event::ALLOC) {
- return buffer_size;
- } else if (event.kind() == HeapSimulatorTrace::Event::SHARE_WITH) {
- // Sharing a buffer does not change the live set size for the purposes of
- // the heap simulator. Even though the shared-with buffer may be smaller,
- // the entire allocation remains live.
- return 0;
- } else if (event.kind() == HeapSimulatorTrace::Event::FREE) {
- return -1 * buffer_size;
- }
- LOG(FATAL) << "Unknown event kind: " << event.kind();
- };
-
- int64 total_max_live_size = 0;
- std::vector<const LogicalBuffer*> live_buffers_vector;
- for (const HeapSimulatorTrace& heap_trace : HeapTraces()) {
- // First compute the size of the maximal live set.
- int64 max_live_size = 0;
- int64 live_size = 0;
- for (const auto& event : heap_trace.events()) {
- live_size += memory_delta(event);
- if (max_live_size < live_size) {
- max_live_size = live_size;
- }
- }
-
- // Next gather the set of logical buffers live at the earliest point of
- // maximal live set size.
- tensorflow::gtl::FlatSet<const LogicalBuffer*> live_buffers;
- live_size = 0;
- for (const auto& event : heap_trace.events()) {
- const LogicalBuffer* buffer = id_to_buffer.at(event.buffer_id());
- if (event.kind() == HeapSimulatorTrace::Event::ALLOC) {
- InsertOrDie(&live_buffers, buffer);
- } else if (event.kind() == HeapSimulatorTrace::Event::SHARE_WITH) {
- // Nothing to do.
- } else if (event.kind() == HeapSimulatorTrace::Event::FREE) {
- CHECK(ContainsKey(live_buffers, buffer));
- live_buffers.erase(buffer);
- }
-
- live_size += memory_delta(event);
- if (live_size == max_live_size) {
- break;
- }
- }
- CHECK_EQ(live_size, max_live_size);
- total_max_live_size += max_live_size;
-
- live_buffers_vector.insert(live_buffers_vector.end(), live_buffers.begin(),
- live_buffers.end());
- }
-
- // Stabily sort the live buffers.
- std::sort(live_buffers_vector.begin(), live_buffers_vector.end(),
- [](const LogicalBuffer* a, const LogicalBuffer* b) {
- return a->id() < b->id();
- });
- return {total_max_live_size, live_buffers_vector};
-}
-
string BufferAllocation::ToString() const {
string output;
Appendf(&output, "allocation %lld: %p, size %lld", index_, this, size());
@@ -610,6 +504,7 @@ BufferAllocation* BufferAssignment::NewAllocation(const LogicalBuffer& buffer,
BufferAllocation* allocation =
NewEmptyAllocation(size, is_thread_local, is_reusable, buffer.color());
AddAssignment(allocation, buffer, /*offset=*/0, size);
+ allocation->peak_buffers_.push_back(&buffer);
return allocation;
}
@@ -680,6 +575,10 @@ void BufferAssignment::CombineTempAllocations() {
CHECK_EQ(temp_allocation.HeapTraces().size(), 1);
combined_allocation->AddHeapTrace(temp_allocation.HeapTraces().front());
}
+ combined_allocation->peak_buffers_.insert(
+ combined_allocation->peak_buffers_.end(),
+ temp_allocation.peak_buffers_.begin(),
+ temp_allocation.peak_buffers_.end());
}
// Replace all existing temporary allocations with the new combined
// allocations.
@@ -1228,6 +1127,89 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering(
return Status::OK();
}
+namespace {
+
+// Computes and returns the set of logical buffers live at the point of maximal
+// liveness in the given heap trace. LogicalBuffers are (stabily) sorted by id.
+std::vector<const LogicalBuffer*> ComputePeakMemoryLogicalBuffers(
+ const BufferAllocation& allocation, const HeapSimulatorTrace& heap_trace) {
+ // Create a map from LogicalBuffer::Id to LogicalBuffer* for the logical
+ // buffers in this allocation.
+ tensorflow::gtl::FlatMap<LogicalBuffer::Id, const LogicalBuffer*>
+ id_to_buffer;
+ tensorflow::gtl::FlatMap<const LogicalBuffer*, int64> buffer_sizes;
+ for (const auto& pair : allocation.assigned_buffers()) {
+ const LogicalBuffer* buffer = pair.first;
+ const BufferAllocation::OffsetSize& offset_size = pair.second;
+ id_to_buffer[buffer->id()] = buffer;
+ buffer_sizes[buffer] = offset_size.size;
+ }
+
+ // Returns how much the given event increases the total size of live
+ // buffers. Can be negative.
+ auto memory_delta = [&id_to_buffer, &buffer_sizes](
+ const HeapSimulatorTrace::Event& event) -> int64 {
+ const LogicalBuffer* buffer = id_to_buffer.at(event.buffer_id());
+ const int64 buffer_size = buffer_sizes.at(buffer);
+ if (event.kind() == HeapSimulatorTrace::Event::ALLOC) {
+ return buffer_size;
+ } else if (event.kind() == HeapSimulatorTrace::Event::SHARE_WITH) {
+ // Sharing a buffer does not change the live set size for the purposes of
+ // the heap simulator. Even though the shared-with buffer may be smaller,
+ // the entire allocation remains live.
+ return 0;
+ } else if (event.kind() == HeapSimulatorTrace::Event::FREE) {
+ return -1 * buffer_size;
+ }
+ LOG(FATAL) << "Unknown event kind: " << event.kind();
+ };
+
+ // First compute the size of the maximal live set.
+ int64 max_live_size = 0;
+ int64 live_size = 0;
+ for (const auto& event : heap_trace.events()) {
+ live_size += memory_delta(event);
+ if (max_live_size < live_size) {
+ max_live_size = live_size;
+ }
+ }
+
+ // Next gather the set of logical buffers live at the earliest point of
+ // maximal live set size.
+ tensorflow::gtl::FlatSet<const LogicalBuffer*> live_buffers;
+ live_size = 0;
+ for (const auto& event : heap_trace.events()) {
+ const LogicalBuffer* buffer = id_to_buffer.at(event.buffer_id());
+ if (event.kind() == HeapSimulatorTrace::Event::ALLOC) {
+ InsertOrDie(&live_buffers, buffer);
+ } else if (event.kind() == HeapSimulatorTrace::Event::SHARE_WITH) {
+ // Nothing to do.
+ } else if (event.kind() == HeapSimulatorTrace::Event::FREE) {
+ CHECK(ContainsKey(live_buffers, buffer));
+ live_buffers.erase(buffer);
+ }
+
+ live_size += memory_delta(event);
+ if (live_size == max_live_size) {
+ break;
+ }
+ }
+ CHECK_EQ(live_size, max_live_size);
+
+ std::vector<const LogicalBuffer*> live_buffers_vector;
+ live_buffers_vector.insert(live_buffers_vector.end(), live_buffers.begin(),
+ live_buffers.end());
+
+ // Stabily sort the live buffers.
+ std::sort(live_buffers_vector.begin(), live_buffers_vector.end(),
+ [](const LogicalBuffer* a, const LogicalBuffer* b) {
+ return a->id() < b->id();
+ });
+ return live_buffers_vector;
+}
+
+} // namespace
+
void BufferAssigner::AssignBuffersFromHeapSimulator(
const HeapSimulator::Result& result, BufferAssignment* assignment,
LogicalBuffer::Color color) {
@@ -1246,6 +1228,8 @@ void BufferAssigner::AssignBuffersFromHeapSimulator(
const HeapSimulator::Chunk& chunk = buffer_chunk.second;
assignment->AddAssignment(allocation, buffer, chunk.offset, chunk.size);
}
+ allocation->peak_buffers_ =
+ ComputePeakMemoryLogicalBuffers(*allocation, result.debug_trace);
VLOG(1) << "Ran heap simulation for allocation: " << allocation->ToString();
allocation->AddHeapTrace(result.debug_trace);
diff --git a/tensorflow/compiler/xla/service/buffer_assignment.h b/tensorflow/compiler/xla/service/buffer_assignment.h
index 3086d0e2ca..15fd905e8d 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment.h
+++ b/tensorflow/compiler/xla/service/buffer_assignment.h
@@ -206,17 +206,15 @@ class BufferAllocation {
return heap_traces_;
}
- // Compute and return the LogicalBuffers which are live at the point of peak
- // memory usage for the given allocation. The point of peak memory usage is
- // the point at which the total size of all live logical buffers is
- // maximal. If peak memory is reached at multiple points, the set of logical
- // buffers live at the earliest maximal point is returned. The vector is
- // stabily asserted by LogicalBuffer::Index.
- //
- // The return value is a pair of total size of the logical buffers at peak,
- // and the buffers themselves.
- std::pair<int64, std::vector<const LogicalBuffer*>>
- ComputePeakMemoryLogicalBuffers() const;
+ // Returns the LogicalBuffers which are live at the point of peak memory usage
+ // for this allocation. The point of peak memory usage is the point at which
+ // the total size of all live logical buffers is maximal. If peak memory is
+ // reached at multiple points, the set of logical buffers live at the earliest
+ // maximal point is returned. The vector is stabily sorted by
+ // LogicalBuffer::Index.
+ const std::vector<const LogicalBuffer*>& PeakMemoryLogicalBuffers() const {
+ return peak_buffers_;
+ }
// Get the number of bytes lost to fragmentation. This is equal to the
// difference between the size of the allocation and the size of the maximal
@@ -291,6 +289,9 @@ class BufferAllocation {
int64 fragmentation_bytes_ = 0;
std::vector<HeapSimulatorTrace> heap_traces_;
+
+ // Set of buffers live at the point of peak memory usage for this allocation.
+ std::vector<const LogicalBuffer*> peak_buffers_;
};
// Add stream operators for nicer output of CHECK/RET_CHECK failures.
diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
index 3ec9795a65..f6d6b5c36a 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
@@ -1519,12 +1519,8 @@ TEST_F(BufferAssignmentTest, TrivialPeakBuffers) {
// single logical buffer should be exactly the logical buffer in that
// allocation.
const BufferAllocation& mul_buffer = GetTopLevelAllocation(*buffers, mul);
- int64 peak_size;
- std::vector<const LogicalBuffer*> peak_buffers;
-
- std::tie(peak_size, peak_buffers) =
- mul_buffer.ComputePeakMemoryLogicalBuffers();
- EXPECT_EQ(peak_size, ShapeUtil::ByteSizeOf(f32vec100_));
+ const std::vector<const LogicalBuffer*>& peak_buffers =
+ mul_buffer.PeakMemoryLogicalBuffers();
ASSERT_EQ(peak_buffers.size(), 1);
EXPECT_EQ(peak_buffers[0]->instruction(), mul);
}
@@ -1555,6 +1551,7 @@ TEST_F(BufferAssignmentTest, PeakBuffers) {
HloInstruction::CreateConcatenate(concat_shape, {rev, neg}, 0));
// Make the root tiny so no interior nodes can share its buffer.
auto root = builder.AddInstruction(HloInstruction::CreateSlice(
+
ShapeUtil::MakeShape(F32, {1}), concat, {0}, {1}, {1}));
auto module = CreateNewModule();
@@ -1569,12 +1566,10 @@ TEST_F(BufferAssignmentTest, PeakBuffers) {
EXPECT_TRUE(buffer.IsPreallocatedTempBuffer());
ASSERT_EQ(buffer.assigned_buffers().size(), 4);
- int64 peak_size;
- std::vector<const LogicalBuffer*> peak_buffers;
- std::tie(peak_size, peak_buffers) = buffer.ComputePeakMemoryLogicalBuffers();
+ const std::vector<const LogicalBuffer*>& peak_buffers =
+ buffer.PeakMemoryLogicalBuffers();
// The peak live set should be concat and its inputs.
- EXPECT_EQ(peak_size, ShapeUtil::ByteSizeOf(ShapeUtil::MakeShape(F32, {400})));
ASSERT_EQ(peak_buffers.size(), 3);
std::vector<const HloInstruction*> peak_instructions;
for (const LogicalBuffer* logical_buffer : peak_buffers) {
@@ -1583,6 +1578,69 @@ TEST_F(BufferAssignmentTest, PeakBuffers) {
EXPECT_THAT(peak_instructions, UnorderedElementsAre(rev, neg, concat));
}
+TEST_F(BufferAssignmentTest, PeakBuffersWhile) {
+ auto module = CreateNewModule();
+ const Shape shape = ShapeUtil::MakeShape(F32, {123, 123});
+ HloComputation* condition;
+ {
+ auto b = HloComputation::Builder(TestName() + ".cond");
+ b.AddInstruction(HloInstruction::CreateParameter(0, shape, "x"));
+ b.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<bool>(true)));
+ condition = module->AddEmbeddedComputation(b.Build());
+ }
+ HloComputation* body;
+ {
+ auto b = HloComputation::Builder(TestName() + ".body");
+ auto param =
+ b.AddInstruction(HloInstruction::CreateParameter(0, shape, "x"));
+ b.AddInstruction(
+ HloInstruction::CreateUnary(shape, HloOpcode::kNegate, param));
+ body = module->AddEmbeddedComputation(b.Build());
+ }
+ auto builder = HloComputation::Builder(TestName());
+ auto param =
+ builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
+ auto copy = builder.AddInstruction(
+ HloInstruction::CreateUnary(shape, HloOpcode::kCopy, param));
+ auto while_op = builder.AddInstruction(
+ HloInstruction::CreateWhile(shape, condition, body, copy));
+ // This broadcast should get a temporary allocation which is merged with the
+ // allocation for the while. Peak buffers should include the while and the
+ // broadcast.
+ auto bcast = builder.AddInstruction(HloInstruction::CreateBroadcast(
+ ShapeUtil::MakeShape(F32, {123, 123, 123}), while_op, {0, 1}));
+ builder.AddInstruction(HloInstruction::CreateReverse(
+ ShapeUtil::MakeShape(F32, {123, 123, 123}), bcast, {0}));
+ module->AddEntryComputation(builder.Build());
+
+ auto buffers = RunBufferAssignment(module.get());
+ const BufferAllocation& buffer = GetTopLevelAllocation(*buffers, bcast);
+ const std::vector<const LogicalBuffer*>& peak_buffers =
+ buffer.PeakMemoryLogicalBuffers();
+ ASSERT_EQ(peak_buffers.size(), 2);
+
+ // The peak buffers should include the broadcast and one of the colocated
+ // buffers of the while (body param, condition param, body root, or the while
+ // itself).
+ const LogicalBuffer* bcast_buffer;
+ const LogicalBuffer* nonbcast_buffer;
+ if (peak_buffers[0]->instruction() == bcast) {
+ bcast_buffer = peak_buffers[0];
+ nonbcast_buffer = peak_buffers[1];
+ } else {
+ bcast_buffer = peak_buffers[1];
+ nonbcast_buffer = peak_buffers[0];
+ }
+ EXPECT_EQ(bcast_buffer->instruction(), bcast);
+ EXPECT_TRUE(
+ nonbcast_buffer->instruction() == copy ||
+ nonbcast_buffer->instruction() == while_op ||
+ nonbcast_buffer->instruction() == body->parameter_instruction(0) ||
+ nonbcast_buffer->instruction() == body->root_instruction() ||
+ nonbcast_buffer->instruction() == condition->parameter_instruction(0));
+}
+
class WhileBufferAssignmentTest : public HloTestBase {
protected:
std::unique_ptr<HloComputation> BuildWhileConditionComputation(
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index ec2bb6c762..e298d67e09 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -294,7 +294,7 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) {
ReducePrecisionInsertion::PassTiming::AFTER_FUSION);
pipeline.AddPass<CpuLayoutAssignment>(
- module->mutable_entry_computation_layout());
+ module->device_entry_computation_layout());
// The LayoutAssignment pass may leave behind kCopy instructions which are
// duplicate or NOPs, so remove them with algebraic simplification and CSE.
pipeline.AddPass<HloPassFix<AlgebraicSimplifier>>(
@@ -787,6 +787,8 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
TF_RETURN_IF_ERROR(verify_status);
}
+ XLA_VLOG_LINES(2, "LLVM IR:\n" + llvm_ir::DumpModuleToString(llvm_module));
+
Disassembler disassembler(*target_machine);
CompilerFunctor compiler_functor(
target_machine.get(), &disassembler, opt_level,
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
index aabf4d5161..32613b8690 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
@@ -249,8 +249,9 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::CreateResultShapedBuffer(
std::vector<bool>* buffers_in_result) {
se::Stream* stream = run_options->stream();
ScopedShapedBuffer result_buffer(
- /*on_host_shape=*/result_shape(), /*on_device_shape=*/result_shape(),
- run_options->allocator(), stream->parent()->device_ordinal());
+ /*on_host_shape=*/host_result_shape(),
+ /*on_device_shape=*/host_result_shape(), run_options->allocator(),
+ stream->parent()->device_ordinal());
// Copy DeviceMemoryBase values which contain the array(s) of the result into
// the respective location in ShapedBuffer which is returned to the caller.
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h
index c8edbb9e15..09adb5cb02 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h
@@ -27,7 +27,8 @@ namespace cpu {
// layout constraints for operands and results of library calls.
class CpuLayoutAssignment : public LayoutAssignment {
public:
- explicit CpuLayoutAssignment(ComputationLayout* entry_computation_layout)
+ explicit CpuLayoutAssignment(
+ const ComputationLayout& entry_computation_layout)
: LayoutAssignment(entry_computation_layout) {}
~CpuLayoutAssignment() override {}
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc
index 6ba030fff3..ba4c5a23d3 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc
@@ -49,7 +49,7 @@ class CpuLayoutAssignmentTest : public HloTestBase {
protected:
void AssignLayouts(HloModule* module,
ComputationLayout* entry_computation_layout) {
- cpu::CpuLayoutAssignment layout_assignment(entry_computation_layout);
+ cpu::CpuLayoutAssignment layout_assignment(*entry_computation_layout);
EXPECT_IS_OK(layout_assignment.Run(module).status());
}
};
@@ -311,7 +311,7 @@ static StatusOr<DotOutputFusionLayoutAssignmentResult> RunDotOutputFusion(
result.addend_fusion_param = fusion_instruction->operand(
fused_add->operand(1 - dot_operand_idx_in_add)->parameter_number());
- cpu::CpuLayoutAssignment layout_assignment(&computation_layout);
+ cpu::CpuLayoutAssignment layout_assignment(computation_layout);
TF_ASSIGN_OR_RETURN(result.layout_assignment_changed_something,
layout_assignment.Run(module));
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index d582b5aaae..e473389a29 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -160,10 +160,8 @@ Status IrEmitter::HandleBitcast(HloInstruction* bitcast) {
return Status::OK();
}
-Status IrEmitter::HandleConstant(HloInstruction* constant) {
- VLOG(2) << "HandleConstant: " << constant->ToString();
- const Literal& literal = constant->literal();
- llvm::GlobalVariable* global_for_const;
+llvm::GlobalVariable* IrEmitter::EmitGlobalForLiteral(const Literal& literal) {
+ llvm::GlobalVariable* result;
// We avoid creating large constants in the LLVM IR since LLVM is not
// efficient for large constant arrays. We still emit "small enough" constant
@@ -174,27 +172,42 @@ Status IrEmitter::HandleConstant(HloInstruction* constant) {
ByteSizeOf(literal.shape()) >= kMaxInternalConstantSizeInBytes) {
string global_name = tensorflow::strings::StrCat(
"constant_global_", external_global_constant_counter_++);
- global_for_const = new llvm::GlobalVariable(
+ result = new llvm::GlobalVariable(
/*Module=*/*module_,
/*Type=*/IrShapeType(literal.shape()),
/*isConstant=*/true,
/*Linkage=*/llvm::GlobalValue::ExternalLinkage,
/*Initializer=*/nullptr,
/*Name=*/AsStringRef(global_name));
- global_for_const->setAlignment(MinimumAlignmentForShape(literal.shape()));
+ result->setAlignment(MinimumAlignmentForShape(literal.shape()));
external_constant_pool_->Insert(global_name, literal,
MinimumAlignmentForShape(literal.shape()));
} else {
llvm::Constant* initializer =
llvm_ir::ConvertLiteralToIrConstant(literal, module_);
- global_for_const = new llvm::GlobalVariable(
+ result = new llvm::GlobalVariable(
/*Module=*/*module_,
/*Type=*/initializer->getType(),
/*isConstant=*/true,
/*Linkage=*/llvm::GlobalValue::PrivateLinkage,
/*Initializer=*/initializer,
/*Name=*/"");
- global_for_const->setAlignment(MinimumAlignmentForShape(literal.shape()));
+ result->setAlignment(MinimumAlignmentForShape(literal.shape()));
+ }
+ return result;
+}
+
+Status IrEmitter::HandleConstant(HloInstruction* constant) {
+ VLOG(2) << "HandleConstant: " << constant->ToString();
+ const Literal& literal = constant->literal();
+ llvm::GlobalVariable* global_for_const;
+
+ auto it = emitted_literals_.find(&literal);
+ if (it != emitted_literals_.end()) {
+ global_for_const = it->second;
+ } else {
+ global_for_const = EmitGlobalForLiteral(literal);
+ emitted_literals_[&literal] = global_for_const;
}
emitted_value_[constant] = global_for_const;
VLOG(2) << " emitted value: " << llvm_ir::DumpToString(*global_for_const);
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
index 0f2f3d1817..5a04076080 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
@@ -530,6 +530,8 @@ class IrEmitter : public DfsHloVisitorWithDefault {
Status EmitXfeedTransfer(XfeedKind kind, const Shape& shape,
llvm::Value* program_buffer_address);
+ llvm::GlobalVariable* EmitGlobalForLiteral(const Literal& literal);
+
const HloModuleConfig& hlo_module_config_;
bool is_top_level_computation_;
@@ -539,6 +541,20 @@ class IrEmitter : public DfsHloVisitorWithDefault {
int64 external_global_constant_counter_ = 0;
ExternalConstantPool* external_constant_pool_;
+ struct LiteralPtrHashFunctor {
+ size_t operator()(const Literal* literal) const { return literal->Hash(); }
+ };
+
+ struct LiteralPtrEqualityFunctor {
+ bool operator()(const Literal* lhs, const Literal* rhs) const {
+ return *lhs == *rhs;
+ }
+ };
+
+ tensorflow::gtl::FlatMap<const Literal*, llvm::GlobalVariable*,
+ LiteralPtrHashFunctor, LiteralPtrEqualityFunctor>
+ emitted_literals_;
+
TF_DISALLOW_COPY_AND_ASSIGN(IrEmitter);
};
diff --git a/tensorflow/compiler/xla/service/cpu/tests/BUILD b/tensorflow/compiler/xla/service/cpu/tests/BUILD
new file mode 100644
index 0000000000..4ddb7a85bc
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/tests/BUILD
@@ -0,0 +1,163 @@
+# Description:
+# Tests for LLVM-based CPU backend for XLA.
+
+licenses(["notice"]) # Apache 2.0
+
+package(
+ default_visibility = [":friends"],
+)
+
+package_group(
+ name = "friends",
+ includes = [
+ "//tensorflow/compiler/xla:friends",
+ ],
+)
+
+load("//tensorflow:tensorflow.bzl", "tf_cc_test")
+
+# Filegroup used to collect source files for dependency checking.
+filegroup(
+ name = "c_srcs",
+ data = glob([
+ "**/*.cc",
+ "**/*.h",
+ ]),
+)
+
+cc_library(
+ name = "cpu_codegen_test",
+ testonly = True,
+ hdrs = ["cpu_codegen_test.h"],
+ deps = [
+ "//tensorflow/compiler/xla/service:cpu_plugin",
+ "//tensorflow/compiler/xla/tests:llvm_irgen_test_base",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+tf_cc_test(
+ name = "cpu_fusion_test",
+ srcs = ["cpu_fusion_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/service:cpu_plugin",
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service/cpu:cpu_instruction_fusion",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+tf_cc_test(
+ name = "cpu_bytesizeof_test",
+ srcs = ["cpu_bytesizeof_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+tf_cc_test(
+ name = "cpu_external_constants_test",
+ srcs = ["cpu_external_constants_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:array2d",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service/cpu/tests:cpu_codegen_test",
+ "//tensorflow/compiler/xla/tests:filecheck",
+ "//tensorflow/core:test",
+ ],
+)
+
+tf_cc_test(
+ name = "cpu_noalias_test",
+ srcs = ["cpu_noalias_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/service:buffer_assignment",
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service/cpu/tests:cpu_codegen_test",
+ "//tensorflow/compiler/xla/service/llvm_ir:alias_analysis",
+ "//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
+ "//tensorflow/compiler/xla/tests:filecheck",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "@llvm//:core",
+ ],
+)
+
+tf_cc_test(
+ name = "cpu_intrinsic_test",
+ srcs = ["cpu_intrinsic_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service/cpu:cpu_compiler",
+ "//tensorflow/compiler/xla/service/cpu/tests:cpu_codegen_test",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+tf_cc_test(
+ name = "cpu_eigen_dot_operation_test",
+ srcs = ["cpu_eigen_dot_operation_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service/cpu:cpu_compiler",
+ "//tensorflow/compiler/xla/service/cpu/tests:cpu_codegen_test",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+tf_cc_test(
+ name = "cpu_infeed_test",
+ srcs = ["cpu_infeed_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:test_helpers",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:global_data",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client/lib:arithmetic",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_computation",
+ "//tensorflow/compiler/xla/service:cpu_plugin",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+tf_cc_test(
+ name = "cpu_literal_caching_test",
+ srcs = ["cpu_literal_caching_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service/cpu:cpu_compiler",
+ "//tensorflow/compiler/xla/service/cpu/tests:cpu_codegen_test",
+ "//tensorflow/compiler/xla/tools/parser:hlo_parser",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_bytesizeof_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_bytesizeof_test.cc
new file mode 100644
index 0000000000..d5bbe7677a
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_bytesizeof_test.cc
@@ -0,0 +1,37 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/core/platform/test.h"
+
+class CpuByteSizeOfTest : public ::testing::Test {};
+
+TEST_F(CpuByteSizeOfTest, ARM32) {
+ llvm::DataLayout data_layout(
+ "e-m:e-p:32:32-i64:64-v128:64:128-a:0:32-n32-S64");
+ auto tuple_shape =
+ xla::ShapeUtil::MakeTupleShape({xla::ShapeUtil::MakeShape(xla::F32, {})});
+ EXPECT_EQ(xla::llvm_ir::ByteSizeOf(tuple_shape, data_layout),
+ data_layout.getPointerSize(0 /* default address space */));
+}
+
+TEST_F(CpuByteSizeOfTest, ARM64) {
+ llvm::DataLayout data_layout("e-m:e-i64:64-i128:128-n32:64-S128");
+ auto tuple_shape =
+ xla::ShapeUtil::MakeTupleShape({xla::ShapeUtil::MakeShape(xla::F32, {})});
+ EXPECT_EQ(xla::llvm_ir::ByteSizeOf(tuple_shape, data_layout),
+ data_layout.getPointerSize(0 /* default address space */));
+}
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h b/tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h
new file mode 100644
index 0000000000..7c8d07a10b
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h
@@ -0,0 +1,30 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TESTS_CPU_CODEGEN_TEST_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TESTS_CPU_CODEGEN_TEST_H_
+
+#include "tensorflow/compiler/xla/tests/llvm_irgen_test_base.h"
+
+namespace xla {
+namespace cpu {
+
+// Tests that verify IR emitted by the CPU backend is as expected.
+class CpuCodegenTest : public LLVMIRGenTestBase {};
+
+} // namespace cpu
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TESTS_CPU_CODEGEN_TEST_H_
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc
new file mode 100644
index 0000000000..6fcce42eaa
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc
@@ -0,0 +1,113 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Tests that we call into Eigen for dot operations as needed.
+
+#include <algorithm>
+#include <cctype>
+#include <string>
+
+#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"
+#include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace cpu {
+namespace {
+
+struct DotTestSpec {
+ PrimitiveType primitive_type;
+ string filecheck_lines;
+};
+
+string DotTestSpecToString(const ::testing::TestParamInfo<DotTestSpec>& info) {
+ return PrimitiveType_Name(info.param.primitive_type);
+}
+
+class CpuEigenDotOperationTest
+ : public CpuCodegenTest,
+ public ::testing::WithParamInterface<DotTestSpec> {
+ protected:
+ void CompileAndCheck(std::unique_ptr<HloComputation> entry_computation,
+ const string& filecheck_lines) {
+ CpuAotCompilationOptions options{
+ /*triple=*/"x86_64", /*cpu_name=*/"", /*features=*/"",
+ /*entry_point_name=*/"entry",
+ /*relocation_model=*/CpuAotCompilationOptions::RelocationModel::Static};
+
+ auto hlo_module = CreateNewModule();
+ hlo_module->AddEntryComputation(std::move(entry_computation));
+
+ CompileAheadOfTimeAndVerifyIr(std::move(hlo_module), options,
+ filecheck_lines,
+ /*match_optimized_ir=*/true);
+ }
+};
+
+TEST_P(CpuEigenDotOperationTest, SimpleDotOp) {
+ HloComputation::Builder builder(TestName());
+ DotTestSpec spec = GetParam();
+
+ auto param_shape = ShapeUtil::MakeShape(spec.primitive_type, {128, 128});
+
+ HloInstruction* lhs = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, param_shape, "input"));
+ HloInstruction* rhs = builder.AddInstruction(
+ HloInstruction::CreateParameter(1, param_shape, "input"));
+
+ builder.AddInstruction(
+ HloInstruction::CreateCanonicalDot(param_shape, lhs, rhs));
+ CompileAndCheck(builder.Build(), spec.filecheck_lines);
+}
+
+TEST_P(CpuEigenDotOperationTest, DotTransposeOp) {
+ HloComputation::Builder builder(TestName());
+ DotTestSpec spec = GetParam();
+
+ auto param_shape = ShapeUtil::MakeShape(spec.primitive_type, {128, 128});
+
+ HloInstruction* lhs = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, param_shape, "input"));
+ HloInstruction* rhs = builder.AddInstruction(
+ HloInstruction::CreateParameter(1, param_shape, "input"));
+ HloInstruction* lhs_transposed = builder.AddInstruction(
+ HloInstruction::CreateTranspose(param_shape, lhs, {1, 0}));
+
+ builder.AddInstruction(
+ HloInstruction::CreateCanonicalDot(param_shape, lhs_transposed, rhs));
+ CompileAndCheck(builder.Build(), spec.filecheck_lines);
+}
+
+std::vector<DotTestSpec> GetDotTestCases() {
+ std::vector<DotTestSpec> result;
+ result.push_back(
+ {F16, R"(CHECK: call void @__xla_cpu_runtime_EigenMatMulF16)"});
+ result.push_back(
+ {F32, R"(CHECK: call void @__xla_cpu_runtime_EigenMatMulF32)"});
+ result.push_back(
+ {F64, R"(CHECK: call void @__xla_cpu_runtime_EigenMatMulF64)"});
+ return result;
+}
+
+INSTANTIATE_TEST_CASE_P(CpuEigenDotOperationTestInstantiation,
+ CpuEigenDotOperationTest,
+ ::testing::ValuesIn(GetDotTestCases()),
+ DotTestSpecToString);
+
+} // namespace
+} // namespace cpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc
new file mode 100644
index 0000000000..ed8f375bd6
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc
@@ -0,0 +1,73 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <utility>
+
+#include "tensorflow/compiler/xla/array2d.h"
+#include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/tests/filecheck.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace cpu {
+namespace {
+class CpuExternalConstantsTest : public CpuCodegenTest {
+ public:
+ void TestWithArray(int64 rows, int64 cols, const char* filecheck_pattern) {
+ HloComputation::Builder builder(TestName());
+
+ Array2D<float> backing_array(rows, cols);
+ backing_array.FillUnique();
+
+ auto shape = ShapeUtil::MakeShape(F32, {rows, cols});
+
+ HloInstruction* constant =
+ builder.AddInstruction(HloInstruction::CreateConstant(
+ Literal::CreateR2FromArray2D(backing_array)));
+ HloInstruction* param =
+ builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "x"));
+ builder.AddInstruction(
+ HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param, constant));
+
+ std::unique_ptr<HloModule> module = CreateNewModule();
+ module->AddEntryComputation(builder.Build());
+
+ CompileAndVerifyIr(std::move(module), filecheck_pattern,
+ /*match_optimized_ir=*/false);
+ }
+};
+
+TEST_F(CpuExternalConstantsTest, Basic) {
+ TestWithArray(/*rows=*/1024, /*cols=*/1024, R"(
+CHECK: @constant_global_0 = external constant [1024 x [1024 x float]], align 16
+)");
+}
+
+TEST_F(CpuExternalConstantsTest, BasicNegative) {
+ // The constant array in this test case is small enough that there is no need
+ // to externalize it.
+ TestWithArray(/*rows=*/4, /*cols=*/4, R"(
+CHECK-NOT: @constant_global_0 = external constant [4 x [4 x float]], align 8
+CHECK: @0 = private constant [4 x [4 x float]] {{.*}}, align 8
+)");
+}
+} // namespace
+} // namespace cpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc
new file mode 100644
index 0000000000..23e7a3de4d
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc
@@ -0,0 +1,330 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace cpu {
+namespace {
+
+class CpuFusionTest : public HloTestBase {
+ protected:
+ CpuFusionTest() {}
+
+ ErrorSpec error_spec_{0.0001, 1e-5};
+};
+
+TEST_F(CpuFusionTest, FuseTwoElementwiseOps) {
+ auto builder = HloComputation::Builder(TestName());
+ auto input_literal1 = Literal::CreateR1<float>({1.0, 2.0, 3.0});
+ auto input_literal2 = Literal::CreateR1<float>({-2.0, -42.0, 2.0});
+ Shape vshape = input_literal1->shape();
+
+ auto input1 = builder.AddInstruction(
+ HloInstruction::CreateConstant(std::move(input_literal1)));
+ auto input2 = builder.AddInstruction(
+ HloInstruction::CreateConstant(std::move(input_literal2)));
+
+ auto add1 = builder.AddInstruction(
+ HloInstruction::CreateBinary(vshape, HloOpcode::kAdd, input1, input2));
+ builder.AddInstruction(
+ HloInstruction::CreateUnary(vshape, HloOpcode::kNegate, add1));
+
+ auto module = CreateNewModule();
+ module->AddEntryComputation(builder.Build());
+
+ CpuInstructionFusion fusion;
+ EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie());
+
+ // The computation root instruction was fused. Verify the fusion instruction
+ // is now the root.
+ auto computation = module->entry_computation();
+ auto fusion_instruction = computation->root_instruction();
+ EXPECT_EQ(HloOpcode::kFusion, fusion_instruction->opcode());
+ EXPECT_EQ(HloOpcode::kNegate,
+ fusion_instruction->fused_expression_root()->opcode());
+ // There should be four fused instructions: 2 parameters, the add, and the
+ // negate.
+ EXPECT_EQ(4, fusion_instruction->fused_instruction_count());
+
+ // Compile and execute the computation.
+ auto result = ExecuteAndTransfer(std::move(module), {});
+
+ // Check the output correctness.
+ LiteralTestUtil::ExpectR1Near<float>({1.0, 40.0, -5.0}, *result, error_spec_);
+}
+
+TEST_F(CpuFusionTest, FuseElementwiseOpChain) {
+ auto builder = HloComputation::Builder(TestName());
+ auto input_literal = Literal::CreateR1<float>({-1.5, -2.5, -3.0});
+ Shape vshape = input_literal->shape();
+
+ auto input = builder.AddInstruction(
+ HloInstruction::CreateConstant(std::move(input_literal)));
+ auto negate = builder.AddInstruction(
+ HloInstruction::CreateUnary(vshape, HloOpcode::kNegate, input));
+ auto ceil = builder.AddInstruction(
+ HloInstruction::CreateUnary(vshape, HloOpcode::kCeil, negate));
+ auto exp = builder.AddInstruction(
+ HloInstruction::CreateUnary(vshape, HloOpcode::kExp, ceil));
+ auto floor = builder.AddInstruction(
+ HloInstruction::CreateUnary(vshape, HloOpcode::kFloor, exp));
+ auto two = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ builder.AddInstruction(
+ HloInstruction::CreateBinary(vshape, HloOpcode::kMultiply, two, floor));
+
+ auto module = CreateNewModule();
+ module->AddEntryComputation(builder.Build());
+
+ CpuInstructionFusion fusion;
+ EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie());
+
+ // The computation root instruction was fused. Verify the fusion instruction
+ // is now the root.
+ auto computation = module->entry_computation();
+ auto fusion_instruction = computation->root_instruction();
+ EXPECT_EQ(HloOpcode::kFusion, fusion_instruction->opcode());
+ EXPECT_EQ(HloOpcode::kMultiply,
+ fusion_instruction->fused_expression_root()->opcode());
+ // There should be 7 fused instructions: 2 parameters and the fused
+ // operations.
+ EXPECT_EQ(7, fusion_instruction->fused_instruction_count());
+
+ // Compile and execute the computation.
+ auto result = ExecuteAndTransfer(std::move(module), {});
+
+ // Check the output correctness.
+ LiteralTestUtil::ExpectR1Near<float>({14.0, 40.0, 40.0}, *result,
+ error_spec_);
+}
+
+TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusableInstruction) {
+ // Test a chain of fusable ops with a non-fusable op (a reduce) thrown in the
+ // middle.
+ auto module = CreateNewModule();
+ auto builder = HloComputation::Builder(TestName());
+ auto input_literal = Literal::CreateR1<float>({-1.5, -2.5, -3.0});
+ Shape vshape = input_literal->shape();
+
+ auto input = builder.AddInstruction(
+ HloInstruction::CreateConstant(std::move(input_literal)));
+ auto negate = builder.AddInstruction(
+ HloInstruction::CreateUnary(vshape, HloOpcode::kNegate, input));
+ auto ceil = builder.AddInstruction(
+ HloInstruction::CreateUnary(vshape, HloOpcode::kCeil, negate));
+
+ auto cshape = ShapeUtil::MakeShape(F32, {6});
+ auto concatenate = builder.AddInstruction(
+ HloInstruction::CreateConcatenate(cshape, {ceil, ceil}, /*dimension=*/0));
+
+ // Build an x+y computation to use in a reduce.
+ Shape r0f32 = ShapeUtil::MakeShape(F32, {});
+ auto embedded_builder = HloComputation::Builder("f32+f32");
+ embedded_builder.AddInstruction(HloInstruction::CreateBinary(
+ r0f32, HloOpcode::kAdd,
+ embedded_builder.AddInstruction(
+ HloInstruction::CreateParameter(0, r0f32, "x")),
+ embedded_builder.AddInstruction(
+ HloInstruction::CreateParameter(1, r0f32, "y"))));
+ auto add_f32 = module->AddEmbeddedComputation(embedded_builder.Build());
+
+ // This is a nop reduction.
+ auto reduce = builder.AddInstruction(HloInstruction::CreateReduce(
+ cshape,
+ builder.AddInstruction(HloInstruction::CreateReshape(
+ ShapeUtil::MakeShape(F32, {6, 1}), concatenate)),
+ /*init_value=*/
+ builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<float>(0))),
+ /*dimensions_to_reduce=*/{1}, add_f32));
+
+ auto exp = builder.AddInstruction(
+ HloInstruction::CreateUnary(cshape, HloOpcode::kExp, reduce));
+ auto floor = builder.AddInstruction(
+ HloInstruction::CreateUnary(cshape, HloOpcode::kFloor, exp));
+ auto two = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ builder.AddInstruction(
+ HloInstruction::CreateBinary(cshape, HloOpcode::kMultiply, two, floor));
+
+ module->AddEntryComputation(builder.Build());
+
+ CpuInstructionFusion fusion;
+ EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie());
+
+ // The computation root instruction was fused. Verify the fusion instruction
+ // is now the root.
+ auto computation = module->entry_computation();
+
+ auto fusion_instruction1 = computation->root_instruction();
+ EXPECT_EQ(HloOpcode::kFusion, fusion_instruction1->opcode());
+ EXPECT_EQ(HloOpcode::kMultiply,
+ fusion_instruction1->fused_expression_root()->opcode());
+ // There should be 5 fused instructions in the root fusion instruction: 2
+ // parameters, multiply, floor, and exp.
+ EXPECT_EQ(5, fusion_instruction1->fused_instruction_count())
+ << fusion_instruction1->fused_instructions_computation()->ToString();
+
+ auto fusion_instruction2 = reduce->operand(0);
+ EXPECT_EQ(HloOpcode::kFusion, fusion_instruction1->opcode());
+ EXPECT_EQ(HloOpcode::kReshape,
+ fusion_instruction2->fused_expression_root()->opcode());
+ // There should be 5 fused instructions in the second fusion instruction: 1
+ // parameter, negate, ceil, concat, and reshape.
+ EXPECT_EQ(5, fusion_instruction2->fused_instruction_count())
+ << fusion_instruction2->fused_instructions_computation()->ToString();
+
+ // Compile and execute the computation.
+ auto result = ExecuteAndTransfer(std::move(module), {});
+
+ // Check the output correctness.
+ LiteralTestUtil::ExpectR1Near<float>({14.0, 40.0, 40.0, 14.0, 40.0, 40.0},
+ *result, error_spec_);
+}
+
+TEST_F(CpuFusionTest, TestOperandOrderToAvoidDuplication) {
+ // Test that the operands of an instruction to be fused are considered in the
+ // proper order to avoid duplication. Test input:
+ //
+ // constant = {...}
+ // negate = neg(constant)
+ // ceil = ceil(negate)
+ // add1 = add(negate, ceil)
+ // add2 = add(ceil, negate)
+ //
+ // In this example, the operands of both add1 and add2 should be fused in the
+ // order {ceil, negate} even though they have different orders in their
+ // operand vectors. Test for this problem by counting the number of nodes in
+ // each fusion instruction to ensure that negate is not duplicated.
+ auto builder = HloComputation::Builder(TestName());
+ auto input_literal = Literal::CreateR1<float>({1.0, 2.0, 3.0});
+ Shape vshape = input_literal->shape();
+
+ auto constant = builder.AddInstruction(
+ HloInstruction::CreateConstant(std::move(input_literal)));
+ auto negate = builder.AddInstruction(
+ HloInstruction::CreateUnary(vshape, HloOpcode::kNegate, constant));
+ auto ceil = builder.AddInstruction(
+ HloInstruction::CreateUnary(vshape, HloOpcode::kCeil, negate));
+
+ auto add1 = builder.AddInstruction(
+ HloInstruction::CreateBinary(vshape, HloOpcode::kMultiply, negate, ceil));
+ auto add2 = builder.AddInstruction(
+ HloInstruction::CreateBinary(vshape, HloOpcode::kMultiply, ceil, negate));
+
+ // Tie together the two adds with a tuple to create a single root.
+ auto result =
+ builder.AddInstruction(HloInstruction::CreateTuple({add1, add2}));
+
+ // Create computation and module.
+ auto module = CreateNewModule();
+ module->AddEntryComputation(builder.Build());
+
+ // Run fusion.
+ CpuInstructionFusion fusion;
+ EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie());
+
+ auto fusion1 = result->operand(0);
+ auto fusion2 = result->operand(1);
+ EXPECT_EQ(HloOpcode::kFusion, fusion1->opcode());
+ EXPECT_EQ(HloOpcode::kFusion, fusion2->opcode());
+
+ // Each fusion instruction should have 4 fused instruction inside: add, ceil,
+ // negate, and the fused parameter.
+ EXPECT_EQ(4, fusion1->fused_instruction_count());
+ EXPECT_EQ(4, fusion2->fused_instruction_count());
+
+ // Each fusion instruction should have one parameter and the parameter should
+ // be the constant.
+ EXPECT_EQ(1, fusion1->operand_count());
+ EXPECT_EQ(constant, fusion1->operand(0));
+ EXPECT_EQ(1, fusion2->operand_count());
+ EXPECT_EQ(constant, fusion2->operand(0));
+}
+
+TEST_F(CpuFusionTest, DoNotDuplicateExpensiveOps) {
+ // Verify that expensive operations will not be fused if the fusion results in
+ // duplication. Test code:
+ //
+ // constant = 42.0
+ // exp1 = exp(constant)
+ // negate1 = negate(exp1)
+ // exp2 = exp(constant)
+ // negate2 = negate(exp2)
+ // tuple = tuple(negate1, negate2, exp2)
+ //
+ // exp1 should be fused down into negate1, but exp2 will not be fused into
+ // negate2 because this will result in duplication of the expensive exp
+ // computation. The duplication is caused by the other use of exp2 in the
+ // tuple.
+ auto builder = HloComputation::Builder(TestName());
+ auto input_literal1 = Literal::CreateR1<float>({1.0, 2.0, 3.0});
+ auto input_literal2 = Literal::CreateR1<float>({-2.0, -42.0, 2.0});
+ auto constant = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0)));
+ Shape shape = constant->shape();
+
+ auto exp1 = builder.AddInstruction(
+ HloInstruction::CreateUnary(shape, HloOpcode::kExp, constant));
+ auto negate1 = builder.AddInstruction(
+ HloInstruction::CreateUnary(shape, HloOpcode::kNegate, exp1));
+
+ auto exp2 = builder.AddInstruction(
+ HloInstruction::CreateUnary(shape, HloOpcode::kExp, constant));
+ auto negate2 = builder.AddInstruction(
+ HloInstruction::CreateUnary(shape, HloOpcode::kNegate, exp2));
+
+ auto tuple = builder.AddInstruction(
+ HloInstruction::CreateTuple({negate1, negate2, exp2}));
+
+ auto module = CreateNewModule();
+ module->AddEntryComputation(builder.Build());
+
+ CpuInstructionFusion fusion;
+ EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie());
+
+ // The only fusion instruction should be operand 0 of the tuple (formerly
+ // negate1).
+ EXPECT_EQ(HloOpcode::kFusion, tuple->operand(0)->opcode());
+ EXPECT_EQ(HloOpcode::kNegate, tuple->operand(1)->opcode());
+ EXPECT_EQ(HloOpcode::kExp, tuple->operand(2)->opcode());
+
+ auto fusion_inst = tuple->operand(0);
+ // There should be three fused instructions: negate2, exp2, and the fused
+ // parameter.
+ EXPECT_EQ(3, fusion_inst->fused_instruction_count());
+ EXPECT_EQ(1, fusion_inst->operand_count());
+ EXPECT_EQ(constant, fusion_inst->operand(0));
+}
+
+} // namespace
+} // namespace cpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc
new file mode 100644
index 0000000000..dd63b998e9
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc
@@ -0,0 +1,294 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <unistd.h>
+#include <memory>
+
+#include "tensorflow/compiler/xla/client/global_data.h"
+#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/math/math_util.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+class InfeedTest : public ClientLibraryTestBase {
+ protected:
+ // Transfers the given literal to the infeed interface of the device, and
+ // check if the returned data from Infeed HLO is same as the literal.
+ void TestInfeedRoundTrip(const Literal& literal) {
+ // TODO(b/31037751) Explicitly reset the Infeed state so that the
+ // test is not affected by the state from the previous tests by
+ // adding ClearInfeed if necessary when it is implemented. For now
+ // don't use ResetDevice since it is not implemented on CPU.
+ ASSERT_IS_OK(client_->TransferToInfeed(literal));
+ XlaBuilder builder(TestName());
+ builder.Infeed(literal.shape());
+ if (ShapeUtil::IsTuple(literal.shape())) {
+ // TODO(b/30609564): Use ComputeAndCompareLiteral instead.
+ ComputeAndCompareTuple(&builder, literal, {});
+ } else {
+ ComputeAndCompareLiteral(&builder, literal, {});
+ }
+ }
+};
+
+TEST_F(InfeedTest, SingleInfeedR0Bool) {
+ TestInfeedRoundTrip(*Literal::CreateR0<bool>(true));
+}
+
+TEST_F(InfeedTest, SingleInfeedR1U32) {
+ TestInfeedRoundTrip(*Literal::CreateR1<uint32>({1, 2, 3}));
+}
+
+TEST_F(InfeedTest, SingleInfeedR2F32) {
+ TestInfeedRoundTrip(*Literal::CreateR2F32Linspace(0.0, 1.0, 128, 64));
+}
+
+TEST_F(InfeedTest, SingleInfeedR3F32) {
+ TestInfeedRoundTrip(
+ *Literal::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
+ {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}));
+}
+
+TEST_F(InfeedTest, SingleInfeedR3F32DifferentLayout) {
+ const Layout r3_dim0minor = LayoutUtil::MakeLayout({0, 1, 2});
+ const Layout r3_dim0major = LayoutUtil::MakeLayout({2, 1, 0});
+
+ TestInfeedRoundTrip(
+ *Literal::CreateR3WithLayout({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
+ {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}},
+ r3_dim0minor));
+
+ TestInfeedRoundTrip(
+ *Literal::CreateR3WithLayout({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
+ {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}},
+ r3_dim0major));
+}
+
+TEST_F(InfeedTest, SingleInfeedR4S32) {
+ TestInfeedRoundTrip(*Literal::CreateR4(
+ {{{{1, -2}, {-4, 5}, {6, 7}}, {{8, 9}, {10, 11}, {12, 13}}},
+ {{{10, 3}, {7, -2}, {3, 6}}, {{2, 5}, {-11, 5}, {-2, -5}}}}));
+}
+
+TEST_F(InfeedTest, SingleInfeedTuple) {
+ TestInfeedRoundTrip(
+ *Literal::MakeTuple({Literal::CreateR1<uint32>({1, 2, 3}).get(),
+ Literal::CreateR0<bool>(false).get()}));
+}
+
+TEST_F(InfeedTest, SingleInfeedEmptyTuple) {
+ TestInfeedRoundTrip(*Literal::MakeTuple({}));
+}
+
+// Tests Infeed operation used in a while loop, as in the code below. The
+// computation is launched asynchronously, and then infeed data is transferred.
+//
+// float acc = 0.0f;
+// while (acc < 40.0f) {
+// acc += reduce_add(Infeed());
+// }
+// return acc;
+// TODO(b/30671675) enable this test once asynchronous execution is
+// implemented for CPU.
+TEST_F(InfeedTest, DISABLED_SingleInfeedInWhile) {
+ XlaBuilder builder(TestName());
+ const auto infeed_shape = ShapeUtil::MakeShape(F32, {3});
+ const auto result_shape = ShapeUtil::MakeShape(F32, {});
+
+ // Create a computation for the condition: repeat until (prev < 40.0f) holds.
+ XlaComputation condition;
+ {
+ XlaBuilder builder("condition");
+ auto prev = builder.Parameter(0, result_shape, "prev");
+ builder.Gt(builder.ConstantR0<float>(40.0f), prev);
+ condition = builder.Build().ConsumeValueOrDie();
+ }
+ // Create a computation for the body: add the reduced value of the Infeed
+ // data to the result variable.
+ XlaComputation body;
+ {
+ XlaBuilder builder("body");
+ auto prev = builder.Parameter(0, result_shape, "prev");
+ auto infeed = builder.Infeed(infeed_shape);
+ auto addend =
+ builder.Reduce(infeed, builder.ConstantR0<float>(0.0f),
+ CreateScalarAddComputation(F32, &builder), {0});
+ builder.Add(prev, addend);
+ body = builder.Build().ConsumeValueOrDie();
+ }
+ // Create a While node with computations for the condition and the body.
+ auto init = builder.ConstantR0<float>(0.0f);
+ builder.While(condition, body, init);
+
+ // Build and asynchronously launch the computation.
+ auto computation = builder.Build().ConsumeValueOrDie();
+ std::unique_ptr<GlobalData> result;
+ tensorflow::Thread* computation_thread =
+ tensorflow::Env::Default()->StartThread(
+ tensorflow::ThreadOptions{}, "computation_thread", [&] {
+ result = client_->Execute(computation, {}, &execution_options_)
+ .ValueOrDie();
+ });
+
+ // Send 5 Infeed data of shape F32[3].
+ ASSERT_IS_OK(client_->TransferToInfeed(*Literal::CreateR1<float>({1, 2, 3})));
+ ASSERT_IS_OK(client_->TransferToInfeed(*Literal::CreateR1<float>({4, 5, 6})));
+ ASSERT_IS_OK(client_->TransferToInfeed(*Literal::CreateR1<float>({7, 8, 9})));
+ ASSERT_IS_OK(
+ client_->TransferToInfeed(*Literal::CreateR1<float>({10, 11, 12})));
+ ASSERT_IS_OK(
+ client_->TransferToInfeed(*Literal::CreateR1<float>({13, 14, 15})));
+
+ delete computation_thread; // Joins the thread.
+ auto result_literal = client_->Transfer(*result).ConsumeValueOrDie();
+
+ // Only the first 3 infeed data should be added.
+ LiteralTestUtil::ExpectR0Near<float>(45.0f, *result_literal, ErrorSpec{1e-7});
+}
+
+// Tests two Infeed operations with a total order. The order is enforced by
+// using the result of the first while loop as the initial value of the second
+// while loop. The shapes of both Infeeds are Tuples, where the first tuple
+// element (R1F32) is for the data to reduce and accumulate, and the second
+// tuple element (PRED) to indicate whether the loop should continue. The
+// computation is launched asynchronously, and then infeed data is transferred.
+//
+// float acc = 0.0f;
+// continue = true;
+// while (!continue) {
+// (data, continue) = Infeed(shape1);
+// acc += reduce_add(data)
+// }
+// continue = true;
+// while(!continue) {
+// (data, continue) = Infeed(shape2);
+// acc += reduce_add(data)
+// }
+// return acc;
+// TODO(b/30671675) enable this test once asynchronous execution is
+// implemented for CPU.
+TEST_F(InfeedTest, DISABLED_TwoInfeedsInTotalOrder) {
+ XlaBuilder builder(TestName());
+ const auto infeed1_shape = ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeShape(F32, {2}), ShapeUtil::MakeShape(PRED, {})});
+ const auto infeed2_shape = ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeShape(F32, {3}), ShapeUtil::MakeShape(PRED, {})});
+ const auto result_shape = ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(PRED, {})});
+
+ // Create a computation for the condition: repeat until the second tuple
+ // element is false.
+ XlaComputation condition;
+ {
+ XlaBuilder builder("condition");
+ auto prev = builder.Parameter(0, result_shape, "prev");
+ builder.GetTupleElement(prev, 1);
+ condition = builder.Build().ConsumeValueOrDie();
+ }
+
+ // A lambda that builds the body computation of a while loop with the given
+ // infeed shape, and returns the computation with the ownership.
+ //
+ // The body adds the reduced value of the Infeed data (first tuple element)
+ // to the previous accumulator, and returns the accumulator and the continue
+ // flag (second tuple element) as a tuple.
+ const auto build_body = [this, &result_shape](const Shape& infeed_shape) {
+ XlaComputation body;
+ XlaBuilder builder("body");
+ auto prev = builder.Parameter(0, result_shape, "prev");
+ auto infeed = builder.Infeed(infeed_shape);
+ auto addend = builder.Reduce(
+ builder.GetTupleElement(infeed, 0), builder.ConstantR0<float>(0.0f),
+ CreateScalarAddComputation(F32, &builder), {0});
+ auto result = builder.Add(builder.GetTupleElement(prev, 0), addend);
+ builder.Tuple({result, builder.GetTupleElement(infeed, 1)});
+ return builder.Build().ConsumeValueOrDie();
+ };
+
+ // Create the first while loop with infeed1_shape.
+ auto init = builder.Tuple(
+ {builder.ConstantR0<float>(0.0f), builder.ConstantR0<bool>(true)});
+ auto while1 = builder.While(condition, build_body(infeed1_shape), init);
+ auto result1 = builder.Tuple(
+ {builder.GetTupleElement(while1, 0), builder.ConstantR0<bool>(true)});
+
+ // Create the second while loop with infeed2_shape. Note that the result from
+ // the first while loop is used as the initial value.
+ auto while2 = builder.While(condition, build_body(infeed2_shape), result1);
+ builder.GetTupleElement(while2, 0);
+
+ // Build the computation.
+ auto computation = builder.Build().ConsumeValueOrDie();
+
+ // Send the first 4 Infeed data of shape Tuple(F32[2], PRED).
+ ASSERT_IS_OK(client_->TransferToInfeed(
+ *Literal::MakeTuple({Literal::CreateR1<float>({1, 2}).get(),
+ Literal::CreateR0<bool>(true).get()})));
+ ASSERT_IS_OK(client_->TransferToInfeed(
+ *Literal::MakeTuple({Literal::CreateR1<float>({3, 4}).get(),
+ Literal::CreateR0<bool>(true).get()})));
+ ASSERT_IS_OK(client_->TransferToInfeed(
+ *Literal::MakeTuple({Literal::CreateR1<float>({5, 6}).get(),
+ Literal::CreateR0<bool>(true).get()})));
+ ASSERT_IS_OK(client_->TransferToInfeed(
+ *Literal::MakeTuple({Literal::CreateR1<float>({7, 8}).get(),
+ Literal::CreateR0<bool>(false).get()})));
+
+ // Asynchronously launch the execution on the device.
+ std::unique_ptr<GlobalData> result;
+ tensorflow::Thread* computation_thread =
+ tensorflow::Env::Default()->StartThread(
+ tensorflow::ThreadOptions{}, "computation_thread", [&] {
+ result = client_->Execute(computation, {}, &execution_options_)
+ .ValueOrDie();
+ });
+
+ // Wait for a second to ensure testing that the execution is waiting on the
+ // Infeed data, and send the rest Infeed data of shape Tuple(F32[3], PRED).
+ sleep(1);
+ ASSERT_IS_OK(client_->TransferToInfeed(
+ *Literal::MakeTuple({Literal::CreateR1<float>({1, 2, 3}).get(),
+ Literal::CreateR0<bool>(true).get()})));
+ ASSERT_IS_OK(client_->TransferToInfeed(
+ *Literal::MakeTuple({Literal::CreateR1<float>({7, 8, 9}).get(),
+ Literal::CreateR0<bool>(false).get()})));
+ ASSERT_IS_OK(client_->TransferToInfeed(
+ *Literal::MakeTuple({Literal::CreateR1<float>({4, 5, 6}).get(),
+ Literal::CreateR0<bool>(true).get()})));
+
+ // Wait for the execution to be done, and transfer the result.
+ delete computation_thread; // Joins the thread.
+ auto result_literal = client_->Transfer(*result).ConsumeValueOrDie();
+
+ // Only the first 6 infeed data should be added.
+ LiteralTestUtil::ExpectR0Near<float>(66.0f, *result_literal, ErrorSpec{1e-7});
+}
+
+} // namespace
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc
new file mode 100644
index 0000000000..973aac8766
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc
@@ -0,0 +1,151 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <algorithm>
+#include <cctype>
+#include <string>
+
+#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"
+#include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace cpu {
+namespace {
+
+const char* const kTriple_x86_64 = "x86_64-pc-linux";
+const char* const kTriple_android_arm = "armv7-none-android";
+
+struct IntrinsicTestSpec {
+ HloOpcode opcode;
+ tensorflow::StringPiece triple;
+ tensorflow::StringPiece features;
+ tensorflow::StringPiece check_lines;
+};
+
+// Tests that unary functions get lowered using intrinsic calls.
+class CpuUnaryIntrinsicTest
+ : public CpuCodegenTest,
+ public ::testing::WithParamInterface<IntrinsicTestSpec> {
+ public:
+ static string Name(const ::testing::TestParamInfo<IntrinsicTestSpec>& info) {
+ auto spec = info.param;
+
+ string opcode = HloOpcodeString(spec.opcode);
+ opcode[0] = toupper(opcode[0]);
+
+ string triple{spec.triple.data(), spec.triple.size()};
+ if (triple == kTriple_x86_64) {
+ triple = "x86_64";
+ } else if (triple == kTriple_android_arm) {
+ triple = "android_arm";
+ } else {
+ triple = "Unknown";
+ }
+
+ string features{spec.features.data(), spec.features.size()};
+ if (!features.empty()) {
+ std::replace_if(features.begin(), features.end(),
+ [](char c) { return c != '_' && !isalnum(c); }, '_');
+ } else {
+ features = "";
+ }
+
+ return tensorflow::strings::StrCat(opcode.c_str(), "_On_", triple.c_str(),
+ features.empty() ? "" : "_With",
+ features.c_str());
+ }
+};
+
+// Creates a module with a call to the unary op, and tests if the
+// compiler replaced it with a call to the intrinsic.
+TEST_P(CpuUnaryIntrinsicTest, DoIt) {
+ HloComputation::Builder builder(TestName());
+ IntrinsicTestSpec spec = GetParam();
+
+ auto param_shape = ShapeUtil::MakeShape(F32, {1024});
+ HloInstruction* param = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, param_shape, "input"));
+ builder.AddInstruction(
+ HloInstruction::CreateUnary(param_shape, spec.opcode, param));
+ std::unique_ptr<HloComputation> computation = builder.Build();
+
+ string triple{spec.triple.data(), spec.triple.size()};
+ string features{spec.features.data(), spec.features.size()};
+
+ CpuAotCompilationOptions options{
+ /*triple=*/triple, /*cpu_name=*/"", /*features=*/features,
+ /*entry_point_name=*/"entry",
+ /*relocation_model=*/CpuAotCompilationOptions::RelocationModel::Static};
+
+ auto hlo_module = CreateNewModule();
+ hlo_module->AddEntryComputation(std::move(computation));
+
+ string check_lines{spec.check_lines.data(), spec.check_lines.size()};
+
+ CompileAheadOfTimeAndVerifyIr(std::move(hlo_module), options, check_lines,
+ /*match_optimized_ir=*/true);
+}
+
+IntrinsicTestSpec CpuUnaryIntrinsicTestCases[] = {
+ // The intrinsics are always inlined, so we match a line from it instead of
+ // a function call.
+
+ IntrinsicTestSpec{
+ HloOpcode::kExp, kTriple_x86_64, "",
+ R"(CHECK: fmul fast <4 x float> <float 0xBF2BD01060000000, float 0xBF2BD01060000000, float 0xBF2BD01060000000, float 0xBF2BD01060000000>)"},
+
+ IntrinsicTestSpec{
+ HloOpcode::kExp, kTriple_x86_64, "+avx",
+ R"(CHECK: fmul fast <8 x float> <float 0xBF2BD01060000000, float 0xBF2BD01060000000, float 0xBF2BD01060000000, float 0xBF2BD01060000000, float 0xBF2BD01060000000, float 0xBF2BD01060000000, float 0xBF2BD01060000000, float 0xBF2BD01060000000>)"},
+
+ IntrinsicTestSpec{
+ HloOpcode::kExp, kTriple_android_arm, "+neon",
+ R"(CHECK: fmul fast <4 x float> <float 0xBF2BD01060000000, float 0xBF2BD01060000000, float 0xBF2BD01060000000, float 0xBF2BD01060000000>)"},
+
+ IntrinsicTestSpec{
+ HloOpcode::kTanh, kTriple_x86_64, "",
+ R"(CHECK: fcmp fast uge <4 x float> %wide.load, <float -9.000000e+00, float -9.000000e+00, float -9.000000e+00, float -9.000000e+00>)"},
+
+ IntrinsicTestSpec{
+ HloOpcode::kTanh, kTriple_x86_64, "+avx",
+ R"(CHECK: fcmp fast uge <8 x float> %wide.load, <float -9.000000e+00, float -9.000000e+00, float -9.000000e+00, float -9.000000e+00, float -9.000000e+00, float -9.000000e+00, float -9.000000e+00, float -9.000000e+00>)"},
+
+ IntrinsicTestSpec{
+ HloOpcode::kTanh, kTriple_android_arm, "",
+ R"(CHECK: fcmp fast uge <4 x float> %wide.load, <float -9.000000e+00, float -9.000000e+00, float -9.000000e+00, float -9.000000e+00>)"},
+
+ IntrinsicTestSpec{
+ HloOpcode::kLog, kTriple_x86_64, "",
+ R"(CHECK: fadd fast <4 x float> <float 0x3FBDE4A340000000, float 0x3FBDE4A340000000, float 0x3FBDE4A340000000, float 0x3FBDE4A340000000>)"},
+
+ IntrinsicTestSpec{
+ HloOpcode::kLog, kTriple_x86_64, "+avx",
+ R"(CHECK: fadd fast <8 x float> <float 0x3FBDE4A340000000, float 0x3FBDE4A340000000, float 0x3FBDE4A340000000, float 0x3FBDE4A340000000, float 0x3FBDE4A340000000, float 0x3FBDE4A340000000, float 0x3FBDE4A340000000, float 0x3FBDE4A340000000>)"},
+
+ IntrinsicTestSpec{
+ HloOpcode::kLog, kTriple_android_arm, "",
+ R"(CHECK: fadd fast <4 x float> <float 0x3FBDE4A340000000, float 0x3FBDE4A340000000, float 0x3FBDE4A340000000, float 0x3FBDE4A340000000>)"}};
+
+INSTANTIATE_TEST_CASE_P(CpuUnaryIntrinsicTestInstantiation,
+ CpuUnaryIntrinsicTest,
+ ::testing::ValuesIn(CpuUnaryIntrinsicTestCases),
+ CpuUnaryIntrinsicTest::Name);
+
+} // namespace
+} // namespace cpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc
new file mode 100644
index 0000000000..f0404d07d9
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc
@@ -0,0 +1,125 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"
+#include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h"
+#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
+
+namespace xla {
+namespace cpu {
+namespace {
+class CpuExternalConstantsTest : public CpuCodegenTest {};
+
+TEST_F(CpuExternalConstantsTest, RepeatedArrayConstants) {
+ // We use a while loop here to force the two constant HloInstructions to be in
+ // different computations. Otherwise the HLO optimizer itself CSEs them.
+ const string hlo_text = R"(
+HloModule RepeatedConstants
+
+while_body {
+ arg_body = f32[2,3,2] parameter(0)
+ ROOT const = f32[2,3,2] constant(
+ f32[2,3,2]
+ {{{1, 2}, {1001, 1002}, {2001, 2002}},
+ {{2, 1}, {2001, 3002}, {2001, 2002}}})
+}
+
+while_cond {
+ arg_cond = f32[2,3,2] parameter(0)
+ ROOT unknown = pred[] infeed()
+}
+
+ENTRY main {
+ param = f32[2,3,2] parameter(0)
+ const_a = f32[2,3,2] constant(
+ f32[2,3,2]
+ {{{1, 2}, {1001, 1002}, {2001, 2002}},
+ {{2, 1}, {2001, 3002}, {2001, 2002}}})
+ const_b = f32[2,3,2] while(f32[2,3,2] const_a), condition=while_cond, body=while_body
+
+ out0 = () outfeed(f32[2,3,2] const_a)
+ out1 = () outfeed(f32[2,3,2] const_b)
+
+ ROOT root = f32[] constant(1)
+}
+)";
+
+ string filecheck_pattern = R"(
+CHECK: private constant [2 x [3 x [2 x float]]]
+CHECK-NOT: private constant [2 x [3 x [2 x float]]]
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ tools::Parse(hlo_text));
+
+ CpuAotCompilationOptions options{
+ /*triple=*/"x86_64-pc-linux", /*cpu_name=*/"", /*features=*/"",
+ /*entry_point_name=*/"entry",
+ /*relocation_model=*/CpuAotCompilationOptions::RelocationModel::Static};
+
+ CompileAheadOfTimeAndVerifyIr(std::move(module), options, filecheck_pattern,
+ /*match_optimized_ir=*/false);
+}
+
+TEST_F(CpuExternalConstantsTest, RepeatedTupleConstants) {
+ // We use a while loop here to force the two constant HloInstructions to be in
+ // different computations. Otherwise the HLO optimizer itself CSEs them.
+ const string hlo_text = R"(
+HloModule RepeatedConstants
+
+while_body {
+ arg_body = (f32[2,1]{1,0}, f32[2]{0}) parameter(0)
+ ROOT const = (f32[2,1]{1,0}, f32[2]{0}) constant((f32[2,1], f32[2]) ( f32[2,1] { { 1 }, { 2 } }, {2, 42} ))
+}
+
+while_cond {
+ arg_cond = (f32[2,1]{1,0}, f32[2]{0}) parameter(0)
+ ROOT unknown = pred[] infeed()
+}
+
+ENTRY main {
+ param = f32[2,3,2] parameter(0)
+ const_a = (f32[2,1]{1,0}, f32[2]{0}) constant((f32[2,1], f32[2]) ( f32[2,1] { { 1 }, { 2 } }, {2, 42} ))
+ const_b = (f32[2,1]{1,0}, f32[2]{0}) while((f32[2,1]{1,0}, f32[2]{0}) const_a), condition=while_cond, body=while_body
+
+ out0 = () outfeed((f32[2,1]{1,0}, f32[2]{0}) const_a)
+ out1 = () outfeed((f32[2,1]{1,0}, f32[2]{0}) const_b)
+
+ ROOT root = f32[] constant(1)
+}
+)";
+
+ string filecheck_pattern = R"(
+CHECK: private constant [2 x float]
+CHECK: private constant [2 x [1 x float]]
+CHECK-NOT: private constant [2 x float]
+CHECK-NOT: private constant [2 x [1 x float]]
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ tools::Parse(hlo_text));
+
+ CpuAotCompilationOptions options{
+ /*triple=*/"x86_64-pc-linux", /*cpu_name=*/"", /*features=*/"",
+ /*entry_point_name=*/"entry",
+ /*relocation_model=*/CpuAotCompilationOptions::RelocationModel::Static};
+
+ CompileAheadOfTimeAndVerifyIr(std::move(module), options, filecheck_pattern,
+ /*match_optimized_ir=*/false);
+}
+
+} // namespace
+} // namespace cpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc
new file mode 100644
index 0000000000..3b6b0ed740
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc
@@ -0,0 +1,136 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <utility>
+
+#include "llvm/IR/Module.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/buffer_assignment.h"
+#include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/tests/filecheck.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace cpu {
+
+class CpuNoAliasTest : public CpuCodegenTest {};
+
+// Creates a simple HLO ir_module (runs concat(concat(x, y), x)), and then
+// inspects the aliasing information for loads to its buffers.
+TEST_F(CpuNoAliasTest, Concat) {
+ HloComputation::Builder builder(TestName());
+
+ std::unique_ptr<Literal> literal =
+ Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+ auto param_shape = ShapeUtil::MakeShape(F32, {2, 2});
+ HloInstruction* param_x = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, param_shape, "x"));
+ HloInstruction* param_y = builder.AddInstruction(
+ HloInstruction::CreateParameter(1, param_shape, "y"));
+ HloInstruction* concat1 =
+ builder.AddInstruction(HloInstruction::CreateConcatenate(
+ ShapeUtil::MakeShape(F32, {2, 4}), {param_x, param_y}, 1));
+ HloInstruction* concat2 =
+ builder.AddInstruction(HloInstruction::CreateConcatenate(
+ ShapeUtil::MakeShape(F32, {2, 6}), {concat1, param_x}, 1));
+
+ std::unique_ptr<HloComputation> computation = builder.Build();
+
+ auto hlo_module = CreateNewModule();
+ hlo_module->AddEntryComputation(std::move(computation));
+
+ // Now that we have an HLO module, build an llvm_ir::AliasAnalysis for it.
+ auto status_or_buffer_assn = BufferAssigner::Run(
+ hlo_module.get(), MakeUnique<DependencyHloOrdering>(hlo_module.get()),
+ backend().compiler()->BufferSizeBytesFunction(),
+ [](LogicalBuffer::Color) { return /*alignment=*/1; });
+ ASSERT_EQ(status_or_buffer_assn.status(), Status::OK());
+
+ llvm::LLVMContext context;
+ llvm_ir::AliasAnalysis aa(*hlo_module, *status_or_buffer_assn.ValueOrDie(),
+ &context);
+
+ // Construct an LLVM module containing loads that we annotate as being from
+ // the buffers in the HLO module. We'll inspect these loads to ensure that
+ // they have the expected alias information.
+ llvm::Module ir_module("test", context);
+ llvm::Function* func = llvm::cast<llvm::Function>(
+ ir_module.getOrInsertFunction("test_fn", llvm::Type::getVoidTy(context)));
+ llvm::BasicBlock* bb = llvm::BasicBlock::Create(context, "body", func);
+ llvm::IRBuilder<> ir_builder(bb);
+ auto* zero = llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), 0);
+ llvm_ir::IrArray::Index zero2D({zero, zero});
+
+ llvm::ArrayType* array2d_type = llvm::ArrayType::get(
+ llvm::ArrayType::get(llvm::Type::getFloatTy(context), 100), 100);
+
+ {
+ llvm::Value* param_x_val =
+ ir_module.getOrInsertGlobal("param_x", array2d_type);
+ llvm_ir::IrArray param_x_array(param_x_val, param_shape);
+ aa.AddAliasingInformationToIrArray(*param_x, &param_x_array);
+ param_x_array.EmitReadArrayElement(zero2D, &ir_builder)
+ ->setName("read_param_x_array");
+ }
+
+ {
+ llvm::Value* concat1_val =
+ ir_module.getOrInsertGlobal("concat1", array2d_type);
+ auto shape = ShapeUtil::MakeShape(F32, {2, 4});
+ llvm_ir::IrArray concat1_array(concat1_val, shape);
+ aa.AddAliasingInformationToIrArray(*concat1, &concat1_array);
+ concat1_array.EmitReadArrayElement(zero2D, &ir_builder)
+ ->setName("read_concat1_array");
+ }
+
+ {
+ llvm::Value* concat2_val =
+ ir_module.getOrInsertGlobal("concat2", array2d_type);
+ auto shape = ShapeUtil::MakeShape(F32, {2, 6});
+ llvm_ir::IrArray concat2_array(concat2_val, shape);
+ aa.AddAliasingInformationToIrArray(*concat2, &concat2_array);
+ concat2_array.EmitReadArrayElement(zero2D, &ir_builder)
+ ->setName("read_concat2_array");
+ }
+
+ // Check the AA info in the loads.
+ const char* filecheck_pattern = R"(
+ CHECK: %read_param_x_array = load {{.*}} !noalias [[param_x_noalias:![0-9]+]]
+ CHECK: %read_concat1_array = load {{.*}} !alias.scope [[concat1_scope:![0-9]+]], !noalias [[concat1_noalias:![0-9]+]]
+ CHECK: %read_concat2_array = load {{.*}} !alias.scope [[concat1_noalias]], !noalias [[concat1_scope]]
+ CHECK-DAG: [[buf_size32:![0-9]+]] = !{!"buffer:{{.*}} size:32
+ CHECK-DAG: [[buf_size48:![0-9]+]] = !{!"buffer:{{.*}} size:48
+ CHECK-DAG: [[param_x_noalias]] = !{[[buf_size32]], [[buf_size48]]}
+ CHECK-DAG: [[concat1_scope]] = !{[[buf_size32]]}
+ CHECK-DAG: [[concat1_noalias]] = !{[[buf_size48]]}
+ )";
+
+ TF_ASSERT_OK_AND_ASSIGN(
+ bool filecheck_match,
+ RunFileCheck(llvm_ir::DumpModuleToString(ir_module), filecheck_pattern));
+ EXPECT_TRUE(filecheck_match);
+}
+
+} // namespace cpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/executable.cc b/tensorflow/compiler/xla/service/executable.cc
index 021f09d310..8119478ce9 100644
--- a/tensorflow/compiler/xla/service/executable.cc
+++ b/tensorflow/compiler/xla/service/executable.cc
@@ -143,6 +143,19 @@ Status Executable::DumpSessionModule() {
*session_module_);
}
+Status Executable::DumpHloSnapshot() {
+ TF_RET_CHECK(dumping_snapshot());
+ TF_RET_CHECK(hlo_snapshot_->has_hlo() &&
+ hlo_snapshot_->hlo().has_hlo_module());
+ const string& directory_path =
+ module_config().debug_options().xla_dump_executions_to();
+ const auto& module = hlo_snapshot_->hlo().hlo_module();
+ string filename = tensorflow::strings::Printf(
+ "computation_%lld__%s__execution_%lld", module.id(),
+ module.entry_computation_name().c_str(), ++execution_count_);
+ return Executable::DumpToDirectory(directory_path, filename, *hlo_snapshot_);
+}
+
/* static */ Status Executable::DumpToDirectory(
const string& directory_path, string filename,
const SessionModule& session_module) {
diff --git a/tensorflow/compiler/xla/service/executable.h b/tensorflow/compiler/xla/service/executable.h
index f7af1ca574..4f0466c544 100644
--- a/tensorflow/compiler/xla/service/executable.h
+++ b/tensorflow/compiler/xla/service/executable.h
@@ -140,11 +140,11 @@ class Executable {
// The shape (including layout) that results from this execution. This is the
// shape of the DeviceMemoryBase result value in ExecuteOnStream above.
- const Shape& result_shape() const {
- return hlo_module_->config().entry_computation_layout().result_shape();
+ const Shape& host_result_shape() const {
+ return hlo_module_->config().host_entry_computation_layout().result_shape();
}
- // Dumping helpers.
+ // TODO(b/74197823): Delete the session module dumping helpers.
void set_session_module(std::unique_ptr<xla::SessionModule> session_module) {
session_module_ = std::move(session_module);
}
@@ -152,6 +152,14 @@ class Executable {
SessionModule* session_module() const { return session_module_.get(); }
Status DumpSessionModule();
+ // Dumping helpers.
+ void set_hlo_snapshot(std::unique_ptr<xla::HloSnapshot> hlo_snapshot) {
+ hlo_snapshot_ = std::move(hlo_snapshot);
+ }
+ bool dumping_snapshot() const { return hlo_snapshot_ != nullptr; }
+ HloSnapshot* hlo_snapshot() const { return hlo_snapshot_.get(); }
+ Status DumpHloSnapshot();
+
// Dump session_module to directory_path/filename.
static Status DumpToDirectory(const string& directory_path, string filename,
const SessionModule& session_module);
@@ -174,6 +182,9 @@ class Executable {
// SessionModule this was compiled from. Null if not dumping executions.
std::unique_ptr<SessionModule> session_module_;
+ // HloSnapshot this was compiled from. Null if not dumping executions.
+ std::unique_ptr<HloSnapshot> hlo_snapshot_;
+
// Execution count, used to generate a unique filename for each dumped
// execution.
int64 execution_count_ = 0;
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
index 796c3070f2..4fdc4c8961 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
@@ -248,7 +248,7 @@ tensorflow::Status OptimizeHloModule(HloModule* hlo_module,
{
HloPassPipeline pipeline("layout_assignment");
pipeline.AddPass<GpuLayoutAssignment>(
- hlo_module->mutable_entry_computation_layout());
+ hlo_module->device_entry_computation_layout());
// The LayoutAssignment pass may leave behind kCopy instructions which are
// duplicate or NOPs, so remove them with algebraic simplification and CSE.
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h
index 86a3a7111f..51aae79c3d 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h
@@ -27,7 +27,8 @@ namespace gpu {
// layout constraints for operands and results of library calls.
class GpuLayoutAssignment : public LayoutAssignment {
public:
- explicit GpuLayoutAssignment(ComputationLayout* entry_computation_layout)
+ explicit GpuLayoutAssignment(
+ const ComputationLayout& entry_computation_layout)
: LayoutAssignment(entry_computation_layout) {}
~GpuLayoutAssignment() override {}
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc
index 4c45d2e94a..7c80195594 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc
@@ -69,7 +69,7 @@ TEST_F(LayoutAssignmentTest, Elementwise) {
*computation_layout.mutable_result_layout() =
ShapeLayout(result_shape_with_layout);
- GpuLayoutAssignment layout_assignment(&computation_layout);
+ GpuLayoutAssignment layout_assignment(computation_layout);
EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie());
for (const HloInstruction* operand : add->operands()) {
@@ -156,7 +156,7 @@ TEST_F(LayoutAssignmentTest, BatchNormInference) {
*computation_layout.mutable_result_layout() = ShapeLayout(result_shape);
}
- GpuLayoutAssignment layout_assignment(&computation_layout);
+ GpuLayoutAssignment layout_assignment(computation_layout);
EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie());
// The first operand to batchnorm should have the same layout as the
@@ -225,7 +225,7 @@ TEST_F(LayoutAssignmentTest, BatchNormTraining) {
{result_shape, offset_scale_shape, offset_scale_shape}));
}
- GpuLayoutAssignment layout_assignment(&computation_layout);
+ GpuLayoutAssignment layout_assignment(computation_layout);
EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie());
// The first operand to batchnorm should have the same layout as the
@@ -305,7 +305,7 @@ TEST_F(LayoutAssignmentTest, BatchNormGrad) {
{result_shape, scale_shape, scale_shape}));
}
- GpuLayoutAssignment layout_assignment(&computation_layout);
+ GpuLayoutAssignment layout_assignment(computation_layout);
EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie());
// The first and fourth operands to the batchnorm call should have the
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
index 516e14b464..bb4db89f0a 100644
--- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
+++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
@@ -804,7 +804,7 @@ string HloDotDumper::GetInstructionNodeInlinedOperands(
// "{} (f32[42, 0, 10])". The alternative, calling Literal::ToString(),
// enumerates all of its empty dimensions (e.g. "{ { {}, {} }, ..."), which
// is just noise.
- if (ShapeUtil::HasZeroElements(shape)) {
+ if (!ShapeUtil::IsTuple(shape) && ShapeUtil::HasZeroElements(shape)) {
return Printf("{} (%s)", ShapeUtil::HumanString(constant->shape()));
}
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc
index b589cd573d..8e52d926d8 100644
--- a/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc
@@ -20,7 +20,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/test.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/xla.pb.h"
#include "tensorflow/core/lib/strings/strcat.h"
@@ -48,9 +47,7 @@ class DotRenderer : public hlo_graph_dumper::GraphRendererInterface {
XLA_REGISTER_GRAPH_RENDERER(DotRenderer);
-class HloGraphDumperTest : public HloTestBase {};
-
-TEST_F(HloGraphDumperTest, NestedFusion) {
+TEST(HloGraphDumperTest, NestedFusion) {
HloComputation::Builder b("b");
// Build param0 + param1 + param2 + param3 + param4.
@@ -67,9 +64,10 @@ TEST_F(HloGraphDumperTest, NestedFusion) {
sums.push_back(b.AddInstruction(HloInstruction::CreateBinary(
shape, HloOpcode::kAdd, sums[i], params[i + 2])));
}
- auto m = CreateNewModule();
- m->AddEntryComputation(b.Build());
- HloComputation* root_computation = m->entry_computation();
+ HloModuleConfig config;
+ HloModule m(TestName(), config);
+ m.AddEntryComputation(b.Build());
+ HloComputation* root_computation = m.entry_computation();
// Fuse into fusion(param0 + param1 + param2 + param3 + param4).
auto* outer_fusion = root_computation->CreateFusionInstruction(
@@ -119,18 +117,37 @@ TEST_F(HloGraphDumperTest, NestedFusion) {
HasSubstr(inner_sum->name()));
}
-TEST_F(HloGraphDumperTest, Constant) {
+TEST(HloGraphDumperTest, Constant) {
HloComputation::Builder b("b");
auto instruction = b.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<float>(-42)));
instruction->set_name("i_am_a_constant_root_instruction");
- auto m = CreateNewModule();
- HloComputation* root_computation = m->AddEntryComputation(b.Build());
+ HloModuleConfig config;
+ HloModule m(TestName(), config);
+ HloComputation* root_computation = m.AddEntryComputation(b.Build());
string graph = hlo_graph_dumper::DumpGraph(
*root_computation, /*label=*/"an_empty_graph", DebugOptions());
EXPECT_THAT(graph, HasSubstr("an_empty_graph"));
EXPECT_THAT(graph, Not(HasSubstr("i_am_a_constant_root_instruction")));
}
+TEST(HloGraphDumperTest, TupleConstant) {
+ Shape tuple_shape = ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeShape(F32, {3, 2}), ShapeUtil::MakeShape(S32, {4, 5})});
+ HloComputation::Builder b("b");
+ auto constant = b.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateFromShape(tuple_shape)));
+ auto gte = b.AddInstruction(HloInstruction::CreateGetTupleElement(
+ ShapeUtil::MakeShape(F32, {3, 2}), constant, 0));
+
+ HloModuleConfig config;
+ HloModule m(TestName(), config);
+ HloComputation* root_computation = m.AddEntryComputation(b.Build(gte));
+ string graph = hlo_graph_dumper::DumpGraph(
+ *root_computation, /*label=*/"tuple_constant", DebugOptions());
+ EXPECT_THAT(graph, HasSubstr("tuple_constant"));
+ EXPECT_THAT(graph, HasSubstr("constant (f32[3,2], s32[4,5])"));
+}
+
} // anonymous namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc
index d4bad16f79..987c4b2719 100644
--- a/tensorflow/compiler/xla/service/hlo_module.cc
+++ b/tensorflow/compiler/xla/service/hlo_module.cc
@@ -55,7 +55,7 @@ HloComputation* HloModule::AddComputationInternal(
// If the module configuration has no entry layout computation set, create a
// default one based on the program shape.
- if (!config_.has_entry_computation_layout()) {
+ if (!config_.has_host_entry_computation_layout()) {
config_.SetDefaultComputationLayout(
entry_computation_->ComputeProgramShape());
}
@@ -229,11 +229,14 @@ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
TF_RET_CHECK(proto.has_program_shape())
<< "No program shape found in the proto";
const auto& expected_program_shape = proto.program_shape();
- TF_RET_CHECK(expected_program_shape.parameters_size() ==
- module_config.entry_computation_layout().parameter_count());
+ TF_RET_CHECK(
+ expected_program_shape.parameters_size() ==
+ module_config.device_entry_computation_layout().parameter_count());
for (int i = 0; i < expected_program_shape.parameters_size(); ++i) {
const Shape& parameter_shape =
- module_config.entry_computation_layout().parameter_layout(i).shape();
+ module_config.device_entry_computation_layout()
+ .parameter_layout(i)
+ .shape();
TF_RET_CHECK(ShapeUtil::Compatible(expected_program_shape.parameters(i),
parameter_shape))
<< "HloModuleConfig has different shape for parameter " << i
@@ -243,7 +246,7 @@ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
<< ", actual: " << ShapeUtil::HumanStringWithLayout(parameter_shape);
}
const Shape& result_shape =
- module_config.entry_computation_layout().result_layout().shape();
+ module_config.device_entry_computation_layout().result_layout().shape();
TF_RET_CHECK(
ShapeUtil::Compatible(expected_program_shape.result(), result_shape))
<< "HloModuleConfig has different result shape than the HLO module. "
@@ -303,7 +306,7 @@ StatusOr<HloModuleConfig> HloModule::CreateModuleConfigFromProto(
// The module config is constructed with default layouts regardless of what is
// passed in via the ProgramShape. Set the layouts to the appropriate values.
ComputationLayout* entry_layout =
- module_config.mutable_entry_computation_layout();
+ module_config.mutable_host_entry_computation_layout();
for (int64 i = 0; i < entry_layout->parameter_count(); ++i) {
TF_RETURN_IF_ERROR(
entry_layout->mutable_parameter_layout(i)->CopyLayoutFromShape(
@@ -311,6 +314,8 @@ StatusOr<HloModuleConfig> HloModule::CreateModuleConfigFromProto(
}
TF_RETURN_IF_ERROR(entry_layout->mutable_result_layout()->CopyLayoutFromShape(
program_shape.result()));
+ *module_config.mutable_device_entry_computation_layout() =
+ module_config.host_entry_computation_layout();
return module_config;
}
diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h
index aa843ead51..82d790ec3b 100644
--- a/tensorflow/compiler/xla/service/hlo_module.h
+++ b/tensorflow/compiler/xla/service/hlo_module.h
@@ -98,12 +98,20 @@ class HloModule {
return entry_computation_;
}
- ComputationLayout* mutable_entry_computation_layout() {
- return config_.mutable_entry_computation_layout();
+ ComputationLayout* mutable_host_entry_computation_layout() {
+ return config_.mutable_host_entry_computation_layout();
}
- const ComputationLayout& entry_computation_layout() const {
- return config_.entry_computation_layout();
+ const ComputationLayout& host_entry_computation_layout() const {
+ return config_.host_entry_computation_layout();
+ }
+
+ ComputationLayout* mutable_device_entry_computation_layout() {
+ return config_.mutable_device_entry_computation_layout();
+ }
+
+ const ComputationLayout& device_entry_computation_layout() const {
+ return config_.device_entry_computation_layout();
}
const VersionedComputationHandle& entry_computation_handle() const {
diff --git a/tensorflow/compiler/xla/service/hlo_module_config.cc b/tensorflow/compiler/xla/service/hlo_module_config.cc
index 4205b0402c..dae5578a31 100644
--- a/tensorflow/compiler/xla/service/hlo_module_config.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_config.cc
@@ -31,11 +31,13 @@ using tensorflow::strings::StrAppend;
HloModuleConfig::HloModuleConfig() {}
HloModuleConfig::HloModuleConfig(const ProgramShape& program_shape)
- : entry_computation_layout_(program_shape) {}
+ : host_entry_computation_layout_(program_shape),
+ device_entry_computation_layout_(program_shape) {}
void HloModuleConfig::SetDefaultComputationLayout(
const ProgramShape& program_shape) {
- entry_computation_layout_ = ComputationLayout(program_shape);
+ host_entry_computation_layout_ = ComputationLayout(program_shape);
+ device_entry_computation_layout_ = ComputationLayout(program_shape);
}
string HloModuleConfig::compilation_cache_key() const {
@@ -44,11 +46,18 @@ string HloModuleConfig::compilation_cache_key() const {
StrAppend(&key, "::(");
std::vector<string> params;
for (const ShapeLayout& param_layout :
- entry_computation_layout_->parameter_layouts()) {
+ host_entry_computation_layout_->parameter_layouts()) {
params.push_back(param_layout.shape().DebugString());
}
StrAppend(&key, tensorflow::str_util::Join(params, ", "), ") => ",
- entry_computation_layout_->result_shape().SerializeAsString());
+ host_entry_computation_layout_->result_shape().SerializeAsString());
+ for (const ShapeLayout& param_layout :
+ device_entry_computation_layout_->parameter_layouts()) {
+ params.push_back(param_layout.shape().DebugString());
+ }
+ StrAppend(
+ &key, tensorflow::str_util::Join(params, ", "), ") => ",
+ device_entry_computation_layout_->result_shape().SerializeAsString());
if (seed() != 0) {
// TODO(b/32083678): force recompilation to reset global state.
static std::atomic<int> counter{0};
diff --git a/tensorflow/compiler/xla/service/hlo_module_config.h b/tensorflow/compiler/xla/service/hlo_module_config.h
index 586a03d412..cdb0b29a23 100644
--- a/tensorflow/compiler/xla/service/hlo_module_config.h
+++ b/tensorflow/compiler/xla/service/hlo_module_config.h
@@ -41,26 +41,44 @@ class HloModuleConfig {
explicit HloModuleConfig(const ProgramShape& program_shape);
// Checks if this config has an entry computation layout already.
- bool has_entry_computation_layout() const {
- return entry_computation_layout_.has_value();
+ bool has_host_entry_computation_layout() const {
+ return host_entry_computation_layout_.has_value();
+ }
+
+ bool has_device_entry_computation_layout() const {
+ return device_entry_computation_layout_.has_value();
}
// Sets the entry computation layout for this config. If the entry computation
// layout already exists, it is silently replaced.
void SetDefaultComputationLayout(const ProgramShape& program_shape);
- // Returns a constant reference to the layout of the entry computation.
+ // Returns a constant reference to the on-host layout of the entry
+ // computation. Assumes the layout was set.
+ const ComputationLayout& host_entry_computation_layout() const {
+ CHECK(host_entry_computation_layout_.has_value());
+ return *host_entry_computation_layout_;
+ }
+
+ // Returns a mutable pointer to the layout of the on-host entry computation.
// Assumes the layout was set.
- const ComputationLayout& entry_computation_layout() const {
- CHECK(entry_computation_layout_.has_value());
- return *entry_computation_layout_;
+ ComputationLayout* mutable_host_entry_computation_layout() {
+ CHECK(host_entry_computation_layout_.has_value());
+ return &(*host_entry_computation_layout_);
}
- // Returns a mutable pointer to the layout of the entry computation. Assumes
- // the layout was set.
- ComputationLayout* mutable_entry_computation_layout() {
- CHECK(entry_computation_layout_.has_value());
- return &(*entry_computation_layout_);
+ // Returns a constant reference to the on-device layout of the entry
+ // computation. Assumes the layout was set.
+ const ComputationLayout& device_entry_computation_layout() const {
+ CHECK(device_entry_computation_layout_.has_value());
+ return *device_entry_computation_layout_;
+ }
+
+ // Returns a mutable pointer to the layout of the on-device entry computation.
+ // Assumes the layout was set.
+ ComputationLayout* mutable_device_entry_computation_layout() {
+ CHECK(device_entry_computation_layout_.has_value());
+ return &(*device_entry_computation_layout_);
}
// Returns whether to enable HLO-level profiling.
@@ -109,7 +127,8 @@ class HloModuleConfig {
private:
// If you add new members, be sure to update compilation_cache_key.
- tensorflow::gtl::optional<ComputationLayout> entry_computation_layout_;
+ tensorflow::gtl::optional<ComputationLayout> host_entry_computation_layout_;
+ tensorflow::gtl::optional<ComputationLayout> device_entry_computation_layout_;
// Whether this is a 'host module'.
bool is_host_module_ = false;
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc
index 54c34ce116..3367d76ded 100644
--- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc
@@ -194,6 +194,13 @@ int64 HloModuleGroupMetadata::GetModuleId(const HloModule* module) const {
LOG(FATAL) << "unknown module";
}
+int64 HloModuleGroupMetadata::GetDeviceModulesCount() const {
+ return std::count_if(modules_.begin(), modules_.end(),
+ [](const HloModule* module) {
+ return !module->config().is_host_module();
+ });
+}
+
Status HloModuleGroupMetadata::RecordInstructions() {
const auto visitor = [this](HloInstruction* hlo) -> Status {
if (hlo->opcode() == HloOpcode::kWhile) {
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h
index c48a7ab0b5..d619082616 100644
--- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h
+++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h
@@ -147,6 +147,9 @@ class HloModuleGroupMetadata {
// the module in the module vector.
int64 GetModuleId(const HloModule* module) const;
+ // Returns the number of modules for devices (excluding the host module).
+ int64 GetDeviceModulesCount() const;
+
// Returns the companion instructions for the given instruction.
//
// Precondition: IsCompanionWhile(instruction) is true.
diff --git a/tensorflow/compiler/xla/service/interpreter/compiler.cc b/tensorflow/compiler/xla/service/interpreter/compiler.cc
index 76b3ecad26..eecbbcb93d 100644
--- a/tensorflow/compiler/xla/service/interpreter/compiler.cc
+++ b/tensorflow/compiler/xla/service/interpreter/compiler.cc
@@ -45,7 +45,7 @@ Status InterpreterCompiler::RunHloOptimization(HloModule* hlo_module) {
HloPassPipeline pipeline("Interpreter");
pipeline.AddPass<LayoutAssignment>(
- hlo_module->mutable_entry_computation_layout());
+ hlo_module->device_entry_computation_layout());
return pipeline.Run(hlo_module).status();
}
diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc
index 2494569db5..cfa7ba5e81 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment.cc
@@ -909,22 +909,19 @@ Status LayoutAssignment::CheckLayouts(HloModule* module) {
}
LayoutAssignment::LayoutAssignment(
- ComputationLayout* entry_computation_layout,
+ const ComputationLayout& entry_computation_layout,
ChannelLayoutConstraints* channel_constraints)
: entry_computation_layout_(entry_computation_layout),
channel_layout_constraints_(channel_constraints) {
VLOG(1) << "entry computation layout given to layout assignment: "
- << entry_computation_layout_->ToString();
+ << entry_computation_layout_.ToString();
// Layouts of all parameter instructions must be set.
for (const ShapeLayout& parameter_layout :
- entry_computation_layout_->parameter_layouts()) {
+ entry_computation_layout_.parameter_layouts()) {
CHECK(parameter_layout.LayoutIsSet());
}
- // If the result layout is not set, then choose the default.
- // TODO(b/29118294): Choose a better layout in this case.
- if (!entry_computation_layout_->result_layout().LayoutIsSet()) {
- entry_computation_layout_->mutable_result_layout()->SetToDefaultLayout();
- }
+ // TODO(b/29118294): Choose a better layout if the result layout is not set.
+ CHECK(entry_computation_layout_.result_layout().LayoutIsSet());
}
std::unique_ptr<Layout> LayoutAssignment::ChooseOperandLayoutFromOutputLayout(
@@ -1597,7 +1594,7 @@ StatusOr<bool> LayoutAssignment::Run(HloModule* module) {
}
if (computation == module->entry_computation()) {
TF_RETURN_IF_ERROR(RunOnComputation(
- *entry_computation_layout_, *points_to_analysis,
+ entry_computation_layout_, *points_to_analysis,
module->entry_computation(), channel_layout_constraints_));
} else {
ComputationLayout computation_layout(computation->ComputeProgramShape());
diff --git a/tensorflow/compiler/xla/service/layout_assignment.h b/tensorflow/compiler/xla/service/layout_assignment.h
index ae4986d6ad..c83ae0388b 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.h
+++ b/tensorflow/compiler/xla/service/layout_assignment.h
@@ -288,7 +288,7 @@ class LayoutAssignment : public HloPassInterface {
// If channel_constraints is nullptr, no kSend or kRecvs must be contained
// within any module passed to `Run`.
explicit LayoutAssignment(
- ComputationLayout* entry_computation_layout,
+ const ComputationLayout& entry_computation_layout,
ChannelLayoutConstraints* channel_constraints = nullptr);
~LayoutAssignment() override {}
tensorflow::StringPiece name() const override { return "layout-assignment"; }
@@ -402,7 +402,7 @@ class LayoutAssignment : public HloPassInterface {
// necessary conditions.
Status CheckLayouts(HloModule* module);
- ComputationLayout* entry_computation_layout_;
+ const ComputationLayout& entry_computation_layout_;
protected:
// Sets up the copy instruction according to the characteristic (sharding,
diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc
index 4b1c9bad41..7e1bb11eaa 100644
--- a/tensorflow/compiler/xla/service/layout_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc
@@ -53,7 +53,7 @@ class LayoutAssignmentTest : public HloTestBase {
protected:
void AssignLayouts(HloModule* module,
ComputationLayout* entry_computation_layout) {
- LayoutAssignment layout_assignment(entry_computation_layout);
+ LayoutAssignment layout_assignment(*entry_computation_layout);
EXPECT_IS_OK(layout_assignment.Run(module).status());
}
};
@@ -285,7 +285,7 @@ TEST_F(LayoutAssignmentTest, ConflictingLayoutTuple) {
TF_CHECK_OK(computation_layout.mutable_result_layout()->CopyLayoutFromShape(
result_shape));
- LayoutAssignment layout_assignment(&computation_layout);
+ LayoutAssignment layout_assignment(computation_layout);
AssignLayouts(module.get(), &computation_layout);
// Layout assignment should have deep copied the result of the computation to
@@ -488,7 +488,7 @@ class OperandsMustBeTheSameLayoutAssignment : public LayoutAssignment {
public:
explicit OperandsMustBeTheSameLayoutAssignment(
ComputationLayout* entry_computation_layout)
- : LayoutAssignment(entry_computation_layout) {}
+ : LayoutAssignment(*entry_computation_layout) {}
protected:
Status PropagateBufferConstraint(
@@ -808,7 +808,7 @@ TEST_F(LayoutAssignmentTest, InternalErrorOnBitcast) {
ComputationLayout computation_layout(
module->entry_computation()->ComputeProgramShape());
- LayoutAssignment layout_assignment(&computation_layout);
+ LayoutAssignment layout_assignment(computation_layout);
Status error_status = layout_assignment.Run(module.get()).status();
EXPECT_FALSE(error_status.ok());
EXPECT_THAT(
diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc
index 6e0d07a12f..6ce03ab39d 100644
--- a/tensorflow/compiler/xla/service/service.cc
+++ b/tensorflow/compiler/xla/service/service.cc
@@ -91,6 +91,34 @@ tensorflow::Status RecordResult(const ShapedBuffer& result,
return tensorflow::Status::OK();
}
+// Records the arguments used to invoke a computation in an HloSnapshot proto.
+tensorflow::Status RecordArguments(
+ const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ se::StreamExecutor* executor, TransferManager* transfer_manager,
+ HloSnapshot* module) {
+ module->clear_arguments();
+ for (const ShapedBuffer* argument : arguments) {
+ TF_ASSIGN_OR_RETURN(
+ std::unique_ptr<Literal> literal,
+ transfer_manager->TransferLiteralFromDevice(executor, *argument));
+ *module->add_arguments() = literal->ToProto();
+ }
+ return tensorflow::Status::OK();
+}
+
+// Records the result of a computation in a HloSnapshot proto.
+tensorflow::Status RecordResult(const ShapedBuffer& result,
+ se::StreamExecutor* executor,
+ TransferManager* transfer_manager,
+ HloSnapshot* module) {
+ module->clear_result();
+ TF_ASSIGN_OR_RETURN(
+ std::unique_ptr<Literal> literal,
+ transfer_manager->TransferLiteralFromDevice(executor, result));
+ *module->mutable_result() = literal->ToProto();
+ return tensorflow::Status::OK();
+}
+
} // namespace
ServiceOptions& ServiceOptions::set_platform(se::Platform* platform) {
@@ -268,8 +296,10 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
const ExecutionOptions* execution_options,
const UserComputation* user_computation) {
auto config = MakeUnique<HloModuleConfig>(program_shape);
- auto* computation_layout = config->mutable_entry_computation_layout();
-
+ ComputationLayout* host_computation_layout =
+ config->mutable_host_entry_computation_layout();
+ ComputationLayout* device_computation_layout =
+ config->mutable_device_entry_computation_layout();
if (program_shape.parameters_size() != argument_shapes.size()) {
return InvalidArgument("computation takes %d parameters, but %zu given",
program_shape.parameters_size(),
@@ -294,9 +324,10 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
i, ShapeUtil::HumanString(program_shape.parameters(i)).c_str(),
ShapeUtil::HumanString(*argument_shapes[i]).c_str());
}
- TF_RETURN_IF_ERROR(
- computation_layout->mutable_parameter_layout(i)->CopyLayoutFromShape(
- *argument_shapes[i]));
+ TF_RETURN_IF_ERROR(host_computation_layout->mutable_parameter_layout(i)
+ ->CopyLayoutFromShape(*argument_shapes[i]));
+ TF_RETURN_IF_ERROR(device_computation_layout->mutable_parameter_layout(i)
+ ->CopyLayoutFromShape(*argument_shapes[i]));
}
if (execution_options != nullptr &&
execution_options->has_shape_with_output_layout()) {
@@ -305,10 +336,17 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
TF_RETURN_IF_ERROR(ValidateResultShapeWithLayout(shape_with_output_layout,
program_shape.result()));
TF_RETURN_IF_ERROR(
- computation_layout->mutable_result_layout()->CopyLayoutFromShape(
+ host_computation_layout->mutable_result_layout()->CopyLayoutFromShape(
+ shape_with_output_layout));
+ TF_RETURN_IF_ERROR(
+ device_computation_layout->mutable_result_layout()->CopyLayoutFromShape(
shape_with_output_layout));
} else {
- computation_layout->mutable_result_layout()->Clear();
+ // If the result layout is not set, then choose the default.
+ // TODO(b/29118294): Allow the compiler to choose a better layout in this
+ // case.
+ host_computation_layout->mutable_result_layout()->SetToDefaultLayout();
+ device_computation_layout->mutable_result_layout()->SetToDefaultLayout();
}
config->set_replica_count(options_.number_of_replicas());
@@ -409,6 +447,28 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> Service::BuildExecutables(
DeviceMemoryAllocator* device_allocator) {
VLOG(1) << Printf("BuildExecutable on service %p", this);
+ // Dump computation proto state if flag is set.
+ std::vector<std::unique_ptr<HloSnapshot>> hlo_snapshots;
+ for (int64 i = 0; i < module_protos.size(); ++i) {
+ const string& directory_path =
+ module_configs[i]->debug_options().xla_dump_computations_to();
+ const string& execution_directory_path =
+ module_configs[i]->debug_options().xla_dump_executions_to();
+ if (directory_path.empty() && execution_directory_path.empty()) {
+ continue;
+ }
+ auto hlo_snapshot = MakeUnique<HloSnapshot>();
+ *hlo_snapshot->mutable_hlo()->mutable_hlo_module() = *module_protos[i];
+ if (!directory_path.empty()) {
+ string filename =
+ Printf("computation_%lld__%s", module_protos[i]->id(),
+ module_protos[i]->entry_computation_name().c_str());
+ TF_RETURN_IF_ERROR(
+ Executable::DumpToDirectory(directory_path, filename, *hlo_snapshot));
+ hlo_snapshots.push_back(std::move(hlo_snapshot));
+ }
+ }
+
VLOG(1) << "Computations:";
for (const HloModuleProto* proto : module_protos) {
VLOG(1) << proto->name();
@@ -429,9 +489,31 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> Service::BuildExecutables(
backend->compiler()->Compile(std::move(modules), std::move(executors),
device_allocator));
+ for (size_t i = 0; i < module_protos.size(); ++i) {
+ if (!module_configs[i]->debug_options().xla_dump_executions_to().empty()) {
+ executables[i]->set_hlo_snapshot(std::move(hlo_snapshots[i]));
+ }
+ }
+
return std::move(executables);
}
+Status Service::ValidateEntryComputationLayout(HloModule* module) {
+ const ComputationLayout& on_device =
+ module->device_entry_computation_layout();
+ for (int64 i = 0; i < on_device.parameter_count(); ++i) {
+ TF_RET_CHECK(ShapeUtil::Equal(
+ on_device.parameter_shape(i),
+ execute_backend_->transfer_manager()->HostShapeToDeviceShape(
+ module->host_entry_computation_layout().parameter_shape(i))));
+ }
+ TF_RET_CHECK(ShapeUtil::Equal(
+ module->device_entry_computation_layout().result_shape(),
+ execute_backend_->transfer_manager()->HostShapeToDeviceShape(
+ module->host_entry_computation_layout().result_shape())));
+ return tensorflow::Status::OK();
+}
+
StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
const VersionedComputationHandle& versioned_handle,
std::unique_ptr<HloModuleConfig> module_config, Backend* backend,
@@ -470,6 +552,8 @@ StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
TF_ASSIGN_OR_RETURN(
module, backend->compiler()->RunHloPasses(std::move(module), executor,
device_allocator));
+ // Check that on-host and on-device shapes are consistent.
+ TF_RETURN_IF_ERROR(ValidateEntryComputationLayout(module.get()));
TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
backend->compiler()->RunBackend(
@@ -542,9 +626,16 @@ Service::ExecuteParallelAndRegisterResult(
// profiled.
std::map<int64, se::Stream*> index_to_profiled_streams;
- TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment,
- backend->computation_placer()->AssignDevices(
- options_.number_of_replicas(), executables.size()));
+ // Build DeviceAssignment for all cores based on the provided device handles.
+ DeviceAssignment device_assignment(options_.number_of_replicas(),
+ executables.size());
+ for (int64 i = 0; i < executables.size(); i++) {
+ TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*backend, device_handles[i]));
+ CHECK_EQ(replicas.size(), arguments[i].size());
+ for (int64 replica = 0; replica < replicas.size(); ++replica) {
+ device_assignment(replica, i) = replicas[replica]->device_ordinal();
+ }
+ }
for (int64 i = 0; i < executables.size(); i++) {
// Stream executors for the replicas of the current computation.
@@ -826,7 +917,7 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg,
CreateModuleConfig(*program_shape, replicated_arguments.front(),
request.execution_options(), user_computation));
VLOG(3) << "ExecuteParallel created HloModuleConfig computation layout: "
- << module_config->entry_computation_layout().ToString();
+ << module_config->host_entry_computation_layout().ToString();
// Adds to the vectors to build and execute the computations after the loop.
all_arguments.push_back(replicated_arguments);
@@ -929,7 +1020,7 @@ tensorflow::Status Service::ExecuteGraphParallel(
/*user_computation=*/nullptr));
VLOG(3)
<< "ExecuteGraphParallel created HloModuleConfig computation layout: "
- << module_config->entry_computation_layout().ToString();
+ << module_config->host_entry_computation_layout().ToString();
// Adds to the vectors to build and execute the computations after the loop.
all_arguments.push_back(replicated_arguments);
@@ -1079,7 +1170,7 @@ tensorflow::Status Service::Execute(const ExecuteRequest* arg,
arg->execution_options(), user_computation));
VLOG(3) << "Execute created HloModuleConfig computation layout: "
- << module_config->entry_computation_layout().ToString();
+ << module_config->host_entry_computation_layout().ToString();
TF_ASSIGN_OR_RETURN(
std::shared_ptr<Executable> executable,
@@ -1125,6 +1216,22 @@ StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
"BuildExecutable on service %p with serialized module proto: %s", this,
module_proto.name().c_str());
+ // Dump computation proto state if flag is set.
+ auto hlo_snapshot = MakeUnique<HloSnapshot>();
+ const string& directory_path =
+ module_config->debug_options().xla_dump_computations_to();
+ const string& execution_directory_path =
+ module_config->debug_options().xla_dump_executions_to();
+ if (!directory_path.empty() || !execution_directory_path.empty()) {
+ *hlo_snapshot->mutable_hlo()->mutable_hlo_module() = module_proto;
+ if (!directory_path.empty()) {
+ string filename = Printf("computation_%lld__%s", module_proto.id(),
+ module_proto.entry_computation_name().c_str());
+ TF_RETURN_IF_ERROR(
+ Executable::DumpToDirectory(directory_path, filename, *hlo_snapshot));
+ }
+ }
+
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
HloModule::CreateFromProto(module_proto, *module_config));
@@ -1133,6 +1240,8 @@ StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
TF_ASSIGN_OR_RETURN(
module, backend->compiler()->RunHloPasses(std::move(module), executor,
device_allocator));
+ // Check that on-host and on-device shapes are consistent.
+ TF_RETURN_IF_ERROR(ValidateEntryComputationLayout(module.get()));
TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
backend->compiler()->RunBackend(
@@ -1175,12 +1284,31 @@ tensorflow::Status Service::ExecuteGraph(const ExecuteGraphRequest* arg,
execute_backend_->default_stream_executor(),
/*device_allocator=*/nullptr));
+ if (executable->dumping_snapshot()) {
+ executable->hlo_snapshot()->set_execution_platform(
+ execute_backend_->platform()->Name());
+ TF_RETURN_IF_ERROR(RecordArguments(
+ replicated_arguments.front(),
+ execute_backend_->default_stream_executor(),
+ execute_backend_->transfer_manager(), executable->hlo_snapshot()));
+ }
+
TF_ASSIGN_OR_RETURN(
*result->mutable_output(),
ExecuteAndRegisterResult(
executable.get(), replicated_arguments, execute_backend_.get(),
"result of " + arg->computation().name(), result->mutable_profile()));
+ if (executable->dumping_snapshot()) {
+ TF_ASSIGN_OR_RETURN(
+ const ShapedBuffer* result_buffer,
+ allocation_tracker_.ResolveForReplica(result->output(), 0));
+ TF_RETURN_IF_ERROR(RecordResult(
+ *result_buffer, execute_backend_->default_stream_executor(),
+ execute_backend_->transfer_manager(), executable->hlo_snapshot()));
+ TF_RETURN_IF_ERROR(executable->DumpHloSnapshot());
+ }
+
VLOG(1) << "successfully completed 'execute-graph' request";
return tensorflow::Status::OK();
}
@@ -1215,7 +1343,7 @@ tensorflow::Status Service::ExecuteAsync(const ExecuteAsyncRequest* arg,
arg->execution_options(), user_computation));
VLOG(3) << "ExecuteAsync created HloModuleConfig computation layout: "
- << module_config->entry_computation_layout().ToString();
+ << module_config->host_entry_computation_layout().ToString();
ExecutionProfile profile;
diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h
index 476bd0597d..f84fe407e0 100644
--- a/tensorflow/compiler/xla/service/service.h
+++ b/tensorflow/compiler/xla/service/service.h
@@ -295,6 +295,9 @@ class Service : public ServiceInterface {
const ExecutionOptions& execution_options,
tensorflow::gtl::ArraySlice<const GlobalDataHandle*> arguments);
+ // Assert that host- and device-shapes are in a consistent state.
+ Status ValidateEntryComputationLayout(HloModule* module);
+
protected:
friend class LocalExecutable;
diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc
index d58baa3220..c330473cda 100644
--- a/tensorflow/compiler/xla/shape_util.cc
+++ b/tensorflow/compiler/xla/shape_util.cc
@@ -32,6 +32,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/iterator_range.h"
#include "tensorflow/core/lib/gtl/optional.h"
+#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
@@ -1472,4 +1473,26 @@ std::ostream& operator<<(std::ostream& out, const Shape& shape) {
return out;
}
+/*static*/ size_t ShapeUtil::Hash(const Shape& shape) {
+ using tensorflow::hash;
+ using tensorflow::Hash64Combine;
+
+ size_t hash_value = hash<PrimitiveType>()(shape.element_type());
+
+ if (shape.tuple_shapes().empty()) {
+ for (int64 dim : shape.dimensions()) {
+ hash_value = Hash64Combine(hash_value, hash<int64>()(dim));
+ }
+
+ hash_value = Hash64Combine(hash_value, LayoutUtil::Hash(shape.layout()));
+ } else {
+ hash_value = 0;
+ for (const Shape& subshape : shape.tuple_shapes()) {
+ hash_value = Hash64Combine(hash_value, ShapeUtil::Hash(subshape));
+ }
+ }
+
+ return hash_value;
+}
+
} // namespace xla
diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h
index 5fa728e7c2..cb8bf5a2b9 100644
--- a/tensorflow/compiler/xla/shape_util.h
+++ b/tensorflow/compiler/xla/shape_util.h
@@ -650,6 +650,9 @@ class ShapeUtil {
.ok());
}
+ // Compute a hash for `shape`.
+ static size_t Hash(const Shape& shape);
+
private:
// Validates all of the non-layout properties of the shape -- this is a helper
// used by both the layout-optional and layout-required public method.
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index 840292010d..54cf0543b8 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -632,6 +632,7 @@ xla_test(
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
diff --git a/tensorflow/compiler/xla/tests/filecheck.cc b/tensorflow/compiler/xla/tests/filecheck.cc
index a5f6872c46..93d1c921c4 100644
--- a/tensorflow/compiler/xla/tests/filecheck.cc
+++ b/tensorflow/compiler/xla/tests/filecheck.cc
@@ -38,7 +38,7 @@ StatusOr<bool> RunFileCheck(const string& input, const string& pattern) {
TF_RETURN_IF_ERROR(tensorflow::WriteStringToFile(env, pattern_path, pattern));
// Invoke FileCheck to check whether input matches `pattern`.
- const char* file_check_path_suffix = "external/llvm/FileCheck";
+ const char* file_check_path_suffix = "org_tensorflow/external/llvm/FileCheck";
string file_check_path;
if (const char* test_srcdir = getenv("TEST_SRCDIR")) {
file_check_path = JoinPath(test_srcdir, file_check_path_suffix);
@@ -66,6 +66,11 @@ StatusOr<bool> RunFileCheck(const string& input, const string& pattern) {
// the error message generated by FileCheck and the inputs.
bool succeeded = (exit_status == 0);
if (!succeeded) {
+ LOG(WARNING) << "Tried to execute FileCheck at " << file_check_path;
+ if (!env->FileExists(file_check_path).ok()) {
+ LOG(WARNING) << "NOTE: FileCheck binary does not exist!";
+ }
+
LOG(WARNING) << "FileCheck error: " << standard_error;
LOG(WARNING) << "FileCheck input was:";
XLA_LOG_LINES(tensorflow::WARNING, input);
diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h
index 6491208895..9539ae0680 100644
--- a/tensorflow/compiler/xla/tests/hlo_test_base.h
+++ b/tensorflow/compiler/xla/tests/hlo_test_base.h
@@ -177,9 +177,13 @@ class HloTestBase : public ::testing::Test {
// 'layout'.
void ForceParameterLayout(HloModule* module, int64 param_no,
const Layout& layout) {
- ASSERT_LT(param_no,
- module->mutable_entry_computation_layout()->parameter_count());
- module->mutable_entry_computation_layout()
+ ASSERT_LT(
+ param_no,
+ module->mutable_host_entry_computation_layout()->parameter_count());
+ module->mutable_host_entry_computation_layout()
+ ->mutable_parameter_layout(param_no)
+ ->ResetLayout(layout);
+ module->mutable_device_entry_computation_layout()
->mutable_parameter_layout(param_no)
->ResetLayout(layout);
}
@@ -187,7 +191,10 @@ class HloTestBase : public ::testing::Test {
// Convenience method to force the layout of the computation result in a
// module. The result layout of 'module' is set to 'layout'.
void ForceResultLayout(HloModule* module, const Layout& layout) {
- module->mutable_entry_computation_layout()
+ module->mutable_host_entry_computation_layout()
+ ->mutable_result_layout()
+ ->ResetLayout(layout);
+ module->mutable_device_entry_computation_layout()
->mutable_result_layout()
->ResetLayout(layout);
}
@@ -195,7 +202,10 @@ class HloTestBase : public ::testing::Test {
// Convenience method to clear the layout of the computation result in
// 'module'.
void ForceClearResultLayout(HloModule* module) {
- module->mutable_entry_computation_layout()
+ module->mutable_host_entry_computation_layout()
+ ->mutable_result_layout()
+ ->Clear();
+ module->mutable_device_entry_computation_layout()
->mutable_result_layout()
->Clear();
}
diff --git a/tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc b/tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc
index 3023df47cd..2c45f19c09 100644
--- a/tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc
+++ b/tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc
@@ -62,8 +62,8 @@ void LLVMIRGenTestBase::CompileAheadOfTimeAndVerifyIr(
std::unique_ptr<HloModule> hlo_module, const AotCompilationOptions& options,
const string& pattern, bool match_optimized_ir) {
SetIrHook(match_optimized_ir);
- ASSERT_TRUE(
- CompileToAotCompilationResult(std::move(hlo_module), options).ok());
+ TF_ASSERT_OK(
+ CompileToAotCompilationResult(std::move(hlo_module), options).status());
ResetIrHook();
StatusOr<bool> filecheck_result = RunFileCheck(ir_, pattern);
diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
index fdbfc0210e..1bb31ddb7b 100644
--- a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
+++ b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
@@ -303,12 +303,18 @@ bool HloParser::ParseComputations() {
// set the layouts to what the hlo text says.
for (int p = 0; p < computation->num_parameters(); p++) {
const Shape& param_shape = computation->parameter_instruction(p)->shape();
- TF_CHECK_OK(module_->mutable_entry_computation_layout()
+ TF_CHECK_OK(module_->mutable_host_entry_computation_layout()
+ ->mutable_parameter_layout(p)
+ ->CopyLayoutFromShape(param_shape));
+ TF_CHECK_OK(module_->mutable_device_entry_computation_layout()
->mutable_parameter_layout(p)
->CopyLayoutFromShape(param_shape));
}
const Shape& result_shape = computation->root_instruction()->shape();
- TF_CHECK_OK(module_->mutable_entry_computation_layout()
+ TF_CHECK_OK(module_->mutable_host_entry_computation_layout()
+ ->mutable_result_layout()
+ ->CopyLayoutFromShape(result_shape));
+ TF_CHECK_OK(module_->mutable_device_entry_computation_layout()
->mutable_result_layout()
->CopyLayoutFromShape(result_shape));
}
diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc
index adc8b1d620..4e085bc89c 100644
--- a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc
@@ -1239,7 +1239,7 @@ ENTRY %Reduce (input: f32[8,16,256]) -> f32[8,16] {
auto module = Parse(original);
TF_ASSERT_OK(module.status());
- auto program_layout = module.ValueOrDie()->entry_computation_layout();
+ auto program_layout = module.ValueOrDie()->host_entry_computation_layout();
ASSERT_EQ(program_layout.parameter_count(), 1);
auto param_layout = program_layout.parameter_layout(0).layout();
auto result_layout = program_layout.result_layout().layout();
diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto
index d23f9e5918..750d72d797 100644
--- a/tensorflow/compiler/xla/xla_data.proto
+++ b/tensorflow/compiler/xla/xla_data.proto
@@ -134,6 +134,8 @@ enum Format {
// example, Convert) are ignored.
//
// See the XLA documentation for more information on shapes and layouts.
+//
+// LINT.IfChange
message Layout {
// The method used to store the data in memory. The format determines which of
// the other fields are used by the layout.
@@ -159,9 +161,12 @@ message Layout {
// memory. This field must be unset unless the format is SPARSE.
int64 max_sparse_elements = 5;
- // Important: if any field is added, be sure to modify ShapeUtil::Equal()
- // appropriately to account for the new field.
+ // Important: if any field is added, be sure to modify ShapeUtil::Equal() and
+ // LayoutUtil::Hash appropriately to account for the new field.
}
+// LINT.ThenChange( \
+// https://www.tensorflow.org/code/tensorflow/compiler/xla/shape_util.cc, \
+// https://www.tensorflow.org/code/tensorflow/compiler/xla/layout_util.cc)
// A shape describes the number of dimensions in the array, the size of each
// dimension, and the primitive component type.
@@ -170,6 +175,8 @@ message Layout {
// defined.
//
// See the XLA documentation for more information on shapes and layouts.
+//
+// LINT.IfChange
message Shape {
reserved 1;
reserved "rank";
@@ -190,9 +197,12 @@ message Shape {
// The layout used to back this shape.
Layout layout = 5;
- // Important: if any field is added, be sure to modify ShapeUtil::Equal() and
- // ShapeUtil::Compatible() appropriately to account for the new field.
+ // Important: if any field is added, be sure to modify ShapeUtil::Equal(),
+ // ShapeUtil::Compatible() and ShapeUtil::Hash() appropriately to account for
+ // the new field.
}
+// LINT.ThenChange( \
+// https://www.tensorflow.org/code/tensorflow/compiler/xla/shape_util.cc)
// Shape of the parameters and output of a computation (like a traditional
// function signature).
diff --git a/tensorflow/contrib/autograph/converters/break_statements.py b/tensorflow/contrib/autograph/converters/break_statements.py
index 91de82f0a7..1be1c96dd3 100644
--- a/tensorflow/contrib/autograph/converters/break_statements.py
+++ b/tensorflow/contrib/autograph/converters/break_statements.py
@@ -114,9 +114,9 @@ class BreakStatementTransformer(transformer.Base):
template,
var_name=break_var,
for_stmt=node)
- extra_cond = templates.replace_as_expression(
+ extra_test = templates.replace_as_expression(
'not var_name', var_name=break_var)
- anno.setanno(node[1], 'extra_cond', extra_cond)
+ anno.setanno(node[1], 'extra_test', extra_test)
return node
diff --git a/tensorflow/contrib/autograph/converters/control_flow.py b/tensorflow/contrib/autograph/converters/control_flow.py
index 2e26cdb3d9..935a2786db 100644
--- a/tensorflow/contrib/autograph/converters/control_flow.py
+++ b/tensorflow/contrib/autograph/converters/control_flow.py
@@ -207,7 +207,7 @@ class ControlFlowTransformer(transformer.Base):
def body_name(state_ssf):
body
return state_ssf,
- state_ast_tuple = ag__.while_loop(
+ state_ast_tuple = ag__.while_stmt(
test_name, body_name, (state,), (extra_deps,))
"""
node = templates.replace(
@@ -252,31 +252,31 @@ class ControlFlowTransformer(transformer.Base):
state_ast_tuple = gast.Tuple([n.ast() for n in state], None)
node_body = ast_util.rename_symbols(node.body, ssf_map)
- if anno.hasanno(node, 'extra_cond'):
- extra_cond = anno.getanno(node, 'extra_cond')
- extra_cond = ast_util.rename_symbols(extra_cond, ssf_map)
+ if anno.hasanno(node, 'extra_test'):
+ extra_test = anno.getanno(node, 'extra_test')
+ extra_test = ast_util.rename_symbols(extra_test, ssf_map)
else:
- extra_cond = parser.parse_expression('True')
+ extra_test = parser.parse_expression('True')
template = """
- def extra_cond_name(state_ssf):
- return extra_cond_expr
+ def extra_test_name(state_ssf):
+ return extra_test_expr
def body_name(iterate, state_ssf):
body
return state_ssf,
- state_ast_tuple = ag__.for_loop(
- iterated, extra_cond_name, body_name, (state,))
+ state_ast_tuple = ag__.for_stmt(
+ iter_, extra_test_name, body_name, (state,))
"""
node = templates.replace(
template,
state=state,
state_ssf=state_ssf,
state_ast_tuple=state_ast_tuple,
- iterated=node.iter,
+ iter_=node.iter,
iterate=node.target,
- extra_cond_name=self.context.namer.new_symbol('extra_cond',
+ extra_test_name=self.context.namer.new_symbol('extra_test',
all_referenced),
- extra_cond_expr=extra_cond,
+ extra_test_expr=extra_test,
body_name=self.context.namer.new_symbol('loop_body', all_referenced),
body=node_body)
diff --git a/tensorflow/contrib/autograph/operators/__init__.py b/tensorflow/contrib/autograph/operators/__init__.py
index 04b4734551..38b761d97d 100644
--- a/tensorflow/contrib/autograph/operators/__init__.py
+++ b/tensorflow/contrib/autograph/operators/__init__.py
@@ -19,11 +19,19 @@ conditionals and loops, implemented in functional form, using for example
closures for the body.
"""
+# Naming conventions:
+# * operator names match the name usually used for the respective Python
+# idiom; examples: for_stmt, list_append
+# * operator arguments match either of:
+# - the corresponding Python AST attribute (e.g. the condition of an if
+# statement is called test) if the operator represents an AST construct
+# - the names used in the Python docs, if the operator is a function (e.g.
+# list_ and x for append, see
+# https://docs.python.org/3.7/tutorial/datastructures.html)
+
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-# TODO(mdan): Add a container for implementation-specific toggles (throughout).
-
-from tensorflow.contrib.autograph.operators.control_flow import for_loop
-from tensorflow.contrib.autograph.operators.control_flow import while_loop
+from tensorflow.contrib.autograph.operators.control_flow import for_stmt
+from tensorflow.contrib.autograph.operators.control_flow import while_stmt
diff --git a/tensorflow/contrib/autograph/operators/control_flow.py b/tensorflow/contrib/autograph/operators/control_flow.py
index d9d8b0d593..9f7202821f 100644
--- a/tensorflow/contrib/autograph/operators/control_flow.py
+++ b/tensorflow/contrib/autograph/operators/control_flow.py
@@ -25,44 +25,55 @@ from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_math_ops
-# TODO(mdan): Rename _loop to _stmt to follow Python nomenclature.
-# TODO(mdan): Rename arguments to match the AST names.
-
-def for_loop(iterated, extra_cond, loop_body, init_state):
+def for_stmt(iter_, extra_test, body, init_state):
"""Functional form of a for statement.
- The loop operates on a so-called state, which includes all symbols that are
- variant across loop iterations, excluding the iterate. In what follows we
- refer to state as either a tuple of entities that represent an actual state,
- or a list of arguments of the corresponding types.
+ The loop operates on a state, which includes all symbols that are
+ variant across loop iterations, excluding the iterate as well as the
+ variables local to the loop.
+
+ For example, given the loop below that calculates the geometric and
+ arithmetic means or some numbers:
+
+ geo_mean = 1
+ arith_mean = 0
+ for i in range(n):
+ a = numbers[i]
+ geo_mean *= a
+ arith_mean += a
+
+ The state is represented by the variables geo_mean and arith_mean. The
+ argument for initial_state may contain the tuple (1, 0), the body will
+ include the arguments geo_mean and arith_mean and will return a tuple
+ representing the new values for geo_mean and respectively arith_mean.
Args:
- iterated: The entity being iterated over.
- extra_cond: Callable with the state as arguments, and boolean return type.
+ iter_: The entity being iterated over.
+ extra_test: Callable with the state as arguments, and boolean return type.
An additionnal loop condition.
- loop_body: Callable with the iterate and the state as arguments, and
+ body: Callable with the iterate and the state as arguments, and
state as return type. The actual loop body.
init_state: Tuple containing the initial state.
Returns:
Tuple containing the final state.
"""
- if tensor_util.is_tensor(iterated):
- return _known_len_for_loop(iterated, extra_cond, loop_body, init_state)
- elif isinstance(iterated, dataset_ops.Dataset):
- return _dataset_for_loop(iterated, extra_cond, loop_body, init_state)
+ if tensor_util.is_tensor(iter_):
+ return _known_len_for_stmt(iter_, extra_test, body, init_state)
+ elif isinstance(iter_, dataset_ops.Dataset):
+ return _dataset_for_stmt(iter_, extra_test, body, init_state)
else:
- return _py_for_loop(iterated, extra_cond, loop_body, init_state)
+ return _py_for_stmt(iter_, extra_test, body, init_state)
-def _py_for_loop(iterated, extra_cond, loop_body, init_state):
- """Overload of for_loop that executes a Python for loop."""
+def _py_for_stmt(iter_, extra_test, body, init_state):
+ """Overload of for_stmt that executes a Python for loop."""
state = init_state
- for iterate in iterated:
- if not extra_cond(*state):
+ for target in iter_:
+ if not extra_test(*state):
break
- state = loop_body(iterate, *state)
+ state = body(target, *state)
# TODO(mdan): Remove this special case.
if len(state) == 1:
@@ -70,23 +81,23 @@ def _py_for_loop(iterated, extra_cond, loop_body, init_state):
return state
-def _known_len_for_loop(iterated, extra_cond, loop_body, init_state):
- """Overload of for_loop that iterates over objects that define a length."""
- n = builtins.dynamic_len(iterated)
+def _known_len_for_stmt(iter_, extra_test, body, init_state):
+ """Overload of for_stmt that iterates over objects that define a length."""
+ n = builtins.dynamic_len(iter_)
def while_body(iterate_index, *state):
- iterate = iterated[iterate_index]
- new_state = loop_body(iterate, *state)
+ iterate = iter_[iterate_index]
+ new_state = body(iterate, *state)
return (iterate_index + 1,) + new_state
def while_cond(iterate_index, *state):
- return gen_math_ops.logical_and(iterate_index < n, extra_cond(*state))
+ return gen_math_ops.logical_and(iterate_index < n, extra_test(*state))
- results = while_loop(
+ results = while_stmt(
while_cond,
while_body,
init_state=(0,) + init_state,
- extra_deps=(iterated,),
+ extra_deps=(iter_,),
opts=dict(maximum_iterations=n))
# Dropping the iteration index because it's not syntactically visible.
results = results[1:]
@@ -97,8 +108,8 @@ def _known_len_for_loop(iterated, extra_cond, loop_body, init_state):
return results
-def _dataset_for_loop(ds, extra_cond, loop_body, init_state):
- """Overload of for_loop that iterates over TF Datasets."""
+def _dataset_for_stmt(ds, extra_test, body, init_state):
+ """Overload of for_stmt that iterates over TF Datasets."""
# Because Datsets only expose get_next, in the style of Python iterators,
# we are forced to unpack the loop as:
#
@@ -117,15 +128,15 @@ def _dataset_for_loop(ds, extra_cond, loop_body, init_state):
epoch_number, iterate = iterator.get_next()
def while_body(epoch_number, iterate, *state):
- new_state = loop_body(iterate, *state)
+ new_state = body(iterate, *state)
epoch_number, iterate = iterator.get_next()
return (epoch_number, iterate) + new_state
def while_cond(epoch_number, iterate, *state):
del iterate
- return gen_math_ops.logical_and(epoch_number < 1, extra_cond(*state))
+ return gen_math_ops.logical_and(epoch_number < 1, extra_test(*state))
- results = while_loop(
+ results = while_stmt(
while_cond,
while_body,
init_state=(epoch_number, iterate) + init_state,
@@ -140,7 +151,7 @@ def _dataset_for_loop(ds, extra_cond, loop_body, init_state):
return results
-def while_loop(loop_cond, loop_body, init_state, extra_deps, opts=None):
+def while_stmt(test, body, init_state, extra_deps, opts=None):
"""Functional form of a while statement.
The loop operates on a so-called state, which includes all symbols that are
@@ -149,13 +160,13 @@ def while_loop(loop_cond, loop_body, init_state, extra_deps, opts=None):
of the corresponding types.
Args:
- loop_cond: Callable with the state as arguments, and boolean return type.
+ test: Callable with the state as arguments, and boolean return type.
The loop condition.
- loop_body: Callable with the state as arguments, and state as return type.
+ body: Callable with the state as arguments, and state as return type.
The actual loop body.
init_state: Tuple containing the initial state.
extra_deps: Tuple containing additional entities on which the loop may
- depend, such as loop invariants referenced by loop_cond. Used
+ depend, such as loop invariants referenced by test. Used
exclusively for dispatch control.
opts: Optional dict of extra loop parameters.
@@ -166,24 +177,24 @@ def while_loop(loop_cond, loop_body, init_state, extra_deps, opts=None):
# That could be somethins as simple as a collection of dispatch rules, with
# some prioritization.
if any(tensor_util.is_tensor(v) for v in init_state + extra_deps):
- return _tf_while_loop(loop_cond, loop_body, init_state, opts)
+ return _tf_while_stmt(test, body, init_state, opts)
else:
- return _py_while_loop(loop_cond, loop_body, init_state, opts)
+ return _py_while_stmt(test, body, init_state, opts)
-def _tf_while_loop(loop_cond, loop_body, init_state, opts):
- """Overload of while_loop that stages a TF while_loop."""
+def _tf_while_stmt(test, body, init_state, opts):
+ """Overload of while_stmt that stages a TF while_stmt."""
if opts is None:
opts = {}
- return control_flow_ops.while_loop(loop_cond, loop_body, init_state, **opts)
+ return control_flow_ops.while_loop(test, body, init_state, **opts)
-def _py_while_loop(loop_cond, loop_body, init_state, opts):
- """Overload of while_loop that executes a Python while loop."""
+def _py_while_stmt(test, body, init_state, opts):
+ """Overload of while_stmt that executes a Python while loop."""
del opts
state = init_state
- while loop_cond(*state):
- state = loop_body(*state)
+ while test(*state):
+ state = body(*state)
return state
diff --git a/tensorflow/contrib/autograph/operators/control_flow_test.py b/tensorflow/contrib/autograph/operators/control_flow_test.py
index a0cd0bfa82..b14d7edba3 100644
--- a/tensorflow/contrib/autograph/operators/control_flow_test.py
+++ b/tensorflow/contrib/autograph/operators/control_flow_test.py
@@ -29,28 +29,28 @@ from tensorflow.python.platform import test
class ForLoopTest(test.TestCase):
def test_tensor(self):
- s = control_flow.for_loop(
+ s = control_flow.for_stmt(
constant_op.constant([1, 2, 3, 4]),
- extra_cond=lambda s: True,
- loop_body=lambda i, s: (s + i,),
+ extra_test=lambda s: True,
+ body=lambda i, s: (s + i,),
init_state=(0,))
with self.test_session() as sess:
self.assertEqual((10,), sess.run(s))
def test_python(self):
- s = control_flow.for_loop(
+ s = control_flow.for_stmt(
range(5),
- extra_cond=lambda s: True,
- loop_body=lambda i, s: (s + i,),
+ extra_test=lambda s: True,
+ body=lambda i, s: (s + i,),
init_state=(0,))
self.assertEqual(10, s)
def test_dataset(self):
to_int32 = lambda i: math_ops.cast(i, dtypes.int32)
- s = control_flow.for_loop(
+ s = control_flow.for_stmt(
dataset_ops.Dataset.range(5).map(to_int32),
- extra_cond=lambda s: True,
- loop_body=lambda i, s: (s + i,),
+ extra_test=lambda s: True,
+ body=lambda i, s: (s + i,),
init_state=(0,))
with self.test_session() as sess:
self.assertEqual((10,), sess.run(s))
@@ -60,9 +60,9 @@ class WhileLoopTest(test.TestCase):
def test_tensor(self):
n = constant_op.constant(5)
- results = control_flow.while_loop(
- loop_cond=lambda i, s: i < n,
- loop_body=lambda i, s: (i + 1, s + i,),
+ results = control_flow.while_stmt(
+ test=lambda i, s: i < n,
+ body=lambda i, s: (i + 1, s + i,),
init_state=(0, 0),
extra_deps=(n,))
with self.test_session() as sess:
@@ -70,9 +70,9 @@ class WhileLoopTest(test.TestCase):
def test_python(self):
n = 5
- results = control_flow.while_loop(
- loop_cond=lambda i, s: i < n,
- loop_body=lambda i, s: (i + 1, s + i),
+ results = control_flow.while_stmt(
+ test=lambda i, s: i < n,
+ body=lambda i, s: (i + 1, s + i),
init_state=(0, 0),
extra_deps=(n,))
self.assertEqual((5, 10), results)
diff --git a/tensorflow/contrib/crf/python/ops/crf.py b/tensorflow/contrib/crf/python/ops/crf.py
index d2beff849e..2d2cbdc199 100644
--- a/tensorflow/contrib/crf/python/ops/crf.py
+++ b/tensorflow/contrib/crf/python/ops/crf.py
@@ -52,6 +52,7 @@ from __future__ import print_function
import numpy as np
+from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.layers import utils
from tensorflow.python.ops import array_ops
@@ -147,7 +148,9 @@ def crf_log_norm(inputs, sequence_lengths, transition_params):
# partition function.
forward_cell = CrfForwardRnnCell(transition_params)
# Sequence length is not allowed to be less than zero.
- sequence_lengths_less_one = math_ops.maximum(0, sequence_lengths - 1)
+ sequence_lengths_less_one = math_ops.maximum(
+ constant_op.constant(0, dtype=sequence_lengths.dtype),
+ sequence_lengths - 1)
_, alphas = rnn.dynamic_rnn(
cell=forward_cell,
inputs=rest_of_input,
diff --git a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py
index 55a56b83a8..bd3e034211 100644
--- a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py
@@ -28,6 +28,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
@@ -35,6 +36,179 @@ from tensorflow.python.ops import string_ops
from tensorflow.python.platform import test
+class GroupByReducerTest(test.TestCase):
+
+ def checkResults(self, dataset, shapes, values):
+ self.assertEqual(shapes, dataset.output_shapes)
+ get_next = dataset.make_one_shot_iterator().get_next()
+ with self.test_session() as sess:
+ for expected in values:
+ got = sess.run(get_next)
+ self.assertEqual(got, expected)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def testSum(self):
+ reducer = grouping.Reducer(
+ init_func=lambda _: np.int64(0),
+ reduce_func=lambda x, y: x + y,
+ finalize_func=lambda x: x)
+ for i in range(1, 11):
+ dataset = dataset_ops.Dataset.range(2 * i).apply(
+ grouping.group_by_reducer(lambda x: x % 2, reducer))
+ self.checkResults(
+ dataset, shapes=tensor_shape.scalar(), values=[(i - 1) * i, i * i])
+
+ def testAverage(self):
+
+ def reduce_fn(x, y):
+ return (x[0] * x[1] + math_ops.cast(y, dtypes.float32)) / (
+ x[1] + 1), x[1] + 1
+
+ reducer = grouping.Reducer(
+ init_func=lambda _: (0.0, 0.0),
+ reduce_func=reduce_fn,
+ finalize_func=lambda x: x[0])
+ for i in range(1, 11):
+ dataset = dataset_ops.Dataset.range(2 * i).apply(
+ grouping.group_by_reducer(
+ lambda x: math_ops.cast(x, dtypes.int64) % 2, reducer))
+ self.checkResults(
+ dataset, shapes=tensor_shape.scalar(), values=[i - 1, i])
+
+ def testConcat(self):
+ components = np.array(list("abcdefghijklmnopqrst")).view(np.chararray)
+ reducer = grouping.Reducer(
+ init_func=lambda x: "",
+ reduce_func=lambda x, y: x + y[0],
+ finalize_func=lambda x: x)
+ for i in range(1, 11):
+ dataset = dataset_ops.Dataset.zip(
+ (dataset_ops.Dataset.from_tensor_slices(components),
+ dataset_ops.Dataset.range(2 * i))).apply(
+ grouping.group_by_reducer(lambda x, y: y % 2, reducer))
+ self.checkResults(
+ dataset,
+ shapes=tensor_shape.scalar(),
+ values=[b"acegikmoqs" [:i], b"bdfhjlnprt" [:i]])
+
+ def testSparseSum(self):
+ def _sparse(i):
+ return sparse_tensor.SparseTensorValue(
+ indices=np.array([[0, 0]]),
+ values=(i * np.array([1], dtype=np.int64)),
+ dense_shape=np.array([1, 1]))
+
+ reducer = grouping.Reducer(
+ init_func=lambda _: _sparse(np.int64(0)),
+ reduce_func=lambda x, y: _sparse(x.values[0] + y.values[0]),
+ finalize_func=lambda x: x.values[0])
+ for i in range(1, 11):
+ dataset = dataset_ops.Dataset.range(2 * i).map(_sparse).apply(
+ grouping.group_by_reducer(lambda x: x.values[0] % 2, reducer))
+ self.checkResults(
+ dataset, shapes=tensor_shape.scalar(), values=[(i - 1) * i, i * i])
+
+ def testChangingStateShape(self):
+
+ def reduce_fn(x, _):
+ # Statically known rank, but dynamic length.
+ larger_dim = array_ops.concat([x[0], x[0]], 0)
+ # Statically unknown rank.
+ larger_rank = array_ops.expand_dims(x[1], 0)
+ return larger_dim, larger_rank
+
+ reducer = grouping.Reducer(
+ init_func=lambda x: ([0], 1),
+ reduce_func=reduce_fn,
+ finalize_func=lambda x: x)
+
+ for i in range(1, 11):
+ dataset = dataset_ops.Dataset.from_tensors(np.int64(0)).repeat(i).apply(
+ grouping.group_by_reducer(lambda x: x, reducer))
+ self.assertEqual([None], dataset.output_shapes[0].as_list())
+ self.assertIs(None, dataset.output_shapes[1].ndims)
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+ with self.test_session() as sess:
+ x, y = sess.run(get_next)
+ self.assertAllEqual([0] * (2**i), x)
+ self.assertAllEqual(np.array(1, ndmin=i), y)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def testTypeMismatch(self):
+ reducer = grouping.Reducer(
+ init_func=lambda x: constant_op.constant(1, dtype=dtypes.int32),
+ reduce_func=lambda x, y: constant_op.constant(1, dtype=dtypes.int64),
+ finalize_func=lambda x: x)
+
+ dataset = dataset_ops.Dataset.range(10)
+ with self.assertRaisesRegexp(
+ TypeError,
+ "The element types for the new state must match the initial state."):
+ dataset.apply(
+ grouping.group_by_reducer(lambda _: np.int64(0), reducer))
+
+ # TODO(b/78665031): Remove once non-scalar keys are supported.
+ def testInvalidKeyShape(self):
+ reducer = grouping.Reducer(
+ init_func=lambda x: np.int64(0),
+ reduce_func=lambda x, y: x + y,
+ finalize_func=lambda x: x)
+
+ dataset = dataset_ops.Dataset.range(10)
+ with self.assertRaisesRegexp(
+ ValueError, "`key_func` must return a single tf.int64 tensor."):
+ dataset.apply(
+ grouping.group_by_reducer(lambda _: np.int64((0, 0)), reducer))
+
+ # TODO(b/78665031): Remove once non-int64 keys are supported.
+ def testInvalidKeyType(self):
+ reducer = grouping.Reducer(
+ init_func=lambda x: np.int64(0),
+ reduce_func=lambda x, y: x + y,
+ finalize_func=lambda x: x)
+
+ dataset = dataset_ops.Dataset.range(10)
+ with self.assertRaisesRegexp(
+ ValueError, "`key_func` must return a single tf.int64 tensor."):
+ dataset.apply(
+ grouping.group_by_reducer(lambda _: "wrong", reducer))
+
+
+class GroupByReducerSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def _build_dataset(self, components):
+ reducer = grouping.Reducer(
+ init_func=lambda _: np.int64(0),
+ reduce_func=lambda x, y: x + y,
+ finalize_func=lambda x: x)
+
+ return dataset_ops.Dataset.from_tensor_slices(components).apply(
+ grouping.group_by_reducer(lambda x: x % 5, reducer))
+
+ def testCoreGroupByReducer(self):
+ components = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=np.int64)
+ self.verify_unused_iterator(
+ lambda: self._build_dataset(components), 5, verify_exhausted=True)
+ self.verify_init_before_restore(
+ lambda: self._build_dataset(components), 5, verify_exhausted=True)
+ self.verify_multiple_breaks(
+ lambda: self._build_dataset(components), 5, verify_exhausted=True)
+ self.verify_reset_restored_iterator(
+ lambda: self._build_dataset(components), 5, verify_exhausted=True)
+ self.verify_restore_in_empty_graph(
+ lambda: self._build_dataset(components), 5, verify_exhausted=True)
+ diff_components = np.array([5, 4, 3, 2, 1, 0], dtype=np.int64)
+ self.verify_restore_in_modified_graph(
+ lambda: self._build_dataset(components),
+ lambda: self._build_dataset(diff_components),
+ 5,
+ verify_exhausted=True)
+
+
class GroupByWindowTest(test.TestCase):
def testSimple(self):
diff --git a/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py
index 1a97a84b2c..eb2ceff893 100644
--- a/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py
@@ -28,6 +28,7 @@ from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
+from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
@@ -35,15 +36,19 @@ from tensorflow.python.platform import test
class ScanDatasetTest(test.TestCase):
- def _count(self, start, step):
- return dataset_ops.Dataset.from_tensors(0).repeat(None).apply(
- scan_ops.scan(start, lambda state, _: (state + step, state)))
+ def _counting_dataset(self, start, scan_fn):
+ return dataset_ops.Dataset.from_tensors(0).repeat().apply(
+ scan_ops.scan(start, scan_fn))
def testCount(self):
+ def make_scan_fn(step):
+ return lambda state, _: (state + step, state)
+
start = array_ops.placeholder(dtypes.int32, shape=[])
step = array_ops.placeholder(dtypes.int32, shape=[])
take = array_ops.placeholder(dtypes.int64, shape=[])
- iterator = self._count(start, step).take(take).make_initializable_iterator()
+ iterator = self._counting_dataset(
+ start, make_scan_fn(step)).take(take).make_initializable_iterator()
next_element = iterator.get_next()
with self.test_session() as sess:
@@ -78,6 +83,37 @@ class ScanDatasetTest(test.TestCase):
self.assertEqual(5, self.evaluate(next_element()))
self.assertEqual(8, self.evaluate(next_element()))
+ def testSparseCount(self):
+ def _sparse(i):
+ return sparse_tensor.SparseTensorValue(
+ indices=np.array([[0, 0]]),
+ values=(i * np.array([1])),
+ dense_shape=np.array([1, 1]))
+
+ def make_scan_fn(step):
+ return lambda state, _: (_sparse(state.values[0] + step), state)
+
+ start = array_ops.placeholder(dtypes.int32, shape=[])
+ step = array_ops.placeholder(dtypes.int32, shape=[])
+ take = array_ops.placeholder(dtypes.int64, shape=[])
+ iterator = self._counting_dataset(
+ _sparse(start),
+ make_scan_fn(step)).take(take).make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with self.test_session() as sess:
+
+ for start_val, step_val, take_val in [(0, 1, 10), (0, 1, 0), (10, 1, 10),
+ (10, 2, 10), (10, -1, 10),
+ (10, -2, 10)]:
+ sess.run(iterator.initializer,
+ feed_dict={start: start_val, step: step_val, take: take_val})
+ for expected, _ in zip(
+ itertools.count(start_val, step_val), range(take_val)):
+ self.assertEqual(expected, sess.run(next_element).values[0])
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
def testChangingStateShape(self):
# Test the fixed-point shape invariant calculations: start with
# initial values with known shapes, and use a scan function that
@@ -132,7 +168,7 @@ class ScanDatasetTest(test.TestCase):
scan_ops.scan(constant_op.constant(1, dtype=dtypes.int32), _scan_fn))
-class ScanDatasetSerialzationTest(
+class ScanDatasetSerializationTest(
dataset_serialization_test_base.DatasetSerializationTestBase):
def _build_dataset(self, num_elements):
diff --git a/tensorflow/contrib/data/python/ops/batching.py b/tensorflow/contrib/data/python/ops/batching.py
index 2152bcde84..42ec2b0b01 100644
--- a/tensorflow/contrib/data/python/ops/batching.py
+++ b/tensorflow/contrib/data/python/ops/batching.py
@@ -364,7 +364,7 @@ class _RestructuredDataset(dataset_ops.Dataset):
with the structure of `dataset`.
"""
super(_RestructuredDataset, self).__init__()
- self._dataset = dataset
+ self._input_dataset = dataset
if not allow_unsafe_cast:
# Validate that the types are compatible.
@@ -408,7 +408,7 @@ class _RestructuredDataset(dataset_ops.Dataset):
self._output_classes = output_classes
def _as_variant_tensor(self):
- return self._dataset._as_variant_tensor() # pylint: disable=protected-access
+ return self._input_dataset._as_variant_tensor() # pylint: disable=protected-access
@property
def output_classes(self):
diff --git a/tensorflow/contrib/data/python/ops/grouping.py b/tensorflow/contrib/data/python/ops/grouping.py
index 0531f9cbb9..ea229b5b27 100644
--- a/tensorflow/contrib/data/python/ops/grouping.py
+++ b/tensorflow/contrib/data/python/ops/grouping.py
@@ -26,6 +26,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
@@ -33,6 +34,35 @@ from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import math_ops
+def group_by_reducer(key_func, reducer):
+ """A transformation that groups elements and performs a reduction.
+
+ This transformation maps element of a dataset to a key using `key_func` and
+ groups the elements by key. The `reducer` is used to process each group; its
+ `init_func` is used to initialize state for each group when it is created, the
+ `reduce_func` is used to update the state every time an element is mapped to
+ the matching group, and the `finalize_func` is used to map the final state to
+ an output value.
+
+ Args:
+ key_func: A function mapping a nested structure of tensors
+ (having shapes and types defined by `self.output_shapes` and
+ `self.output_types`) to a scalar `tf.int64` tensor.
+ reducer: An instance of `Reducer`, which captures the reduction logic using
+ the `init_func`, `reduce_func`, and `finalize_func` functions.
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ @{tf.data.Dataset.apply}.
+ """
+
+ def _apply_fn(dataset):
+ """Function from `Dataset` to `Dataset` that applies the transformation."""
+ return GroupByReducerDataset(dataset, key_func, reducer)
+
+ return _apply_fn
+
+
def group_by_window(key_func,
reduce_func,
window_size=None,
@@ -227,6 +257,250 @@ class _VariantDataset(dataset_ops.Dataset):
return self._output_types
+class GroupByReducerDataset(dataset_ops.Dataset):
+ """A `Dataset` that groups its input and performs a reduction."""
+
+ def __init__(self, input_dataset, key_func, reducer):
+ """See `group_by_reducer()` for details."""
+ super(GroupByReducerDataset, self).__init__()
+
+ self._input_dataset = input_dataset
+
+ self._make_key_func(key_func, input_dataset)
+ self._make_init_func(reducer.init_func)
+ self._make_reduce_func(reducer.reduce_func, input_dataset)
+ self._make_finalize_func(reducer.finalize_func)
+
+ def _make_key_func(self, key_func, input_dataset):
+ """Make wrapping Defun for key_func."""
+
+ @function.Defun(*nest.flatten(
+ sparse.as_dense_types(input_dataset.output_types,
+ input_dataset.output_classes)))
+ def tf_key_func(*args):
+ """A wrapper for Defun that facilitates shape inference."""
+ # Pass in shape information from the input_dataset.
+ dense_shapes = sparse.as_dense_shapes(input_dataset.output_shapes,
+ input_dataset.output_classes)
+ for arg, shape in zip(args, nest.flatten(dense_shapes)):
+ arg.set_shape(shape)
+
+ nested_args = nest.pack_sequence_as(input_dataset.output_types, args)
+ nested_args = sparse.deserialize_sparse_tensors(
+ nested_args, input_dataset.output_types, input_dataset.output_shapes,
+ input_dataset.output_classes)
+ # pylint: disable=protected-access
+ if dataset_ops._should_unpack_args(nested_args):
+ ret = key_func(*nested_args)
+ # pylint: enable=protected-access
+ else:
+ ret = key_func(nested_args)
+ ret = ops.convert_to_tensor(ret)
+ if ret.dtype != dtypes.int64 or ret.get_shape() != tensor_shape.scalar():
+ raise ValueError(
+ "`key_func` must return a single tf.int64 tensor. "
+ "Got type=%s and shape=%s" % (ret.dtype, ret.get_shape()))
+ return ret
+
+ self._key_func = tf_key_func
+ self._key_func.add_to_graph(ops.get_default_graph())
+
+ def _make_init_func(self, init_func):
+ """Make wrapping Defun for init_func."""
+
+ @function.Defun(dtypes.int64)
+ def tf_init_func(key):
+ """A wrapper for Defun that facilitates shape inference."""
+ key.set_shape([])
+ ret = init_func(key)
+ # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
+ # values to tensors.
+ ret = nest.pack_sequence_as(ret, [
+ sparse_tensor.SparseTensor.from_value(t)
+ if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t)
+ for t in nest.flatten(ret)
+ ])
+
+ self._state_classes = sparse.get_classes(ret)
+ self._state_shapes = nest.pack_sequence_as(
+ ret, [t.get_shape() for t in nest.flatten(ret)])
+ self._state_types = nest.pack_sequence_as(
+ ret, [t.dtype for t in nest.flatten(ret)])
+
+ # Serialize any sparse tensors.
+ ret = nest.pack_sequence_as(
+ ret, [t for t in nest.flatten(sparse.serialize_sparse_tensors(ret))])
+ return nest.flatten(ret)
+
+ self._init_func = tf_init_func
+ self._init_func.add_to_graph(ops.get_default_graph())
+
+ def _make_reduce_func(self, reduce_func, input_dataset):
+ """Make wrapping Defun for reduce_func."""
+
+ # Iteratively rerun the reduce function until reaching a fixed point on
+ # `self._state_shapes`.
+ need_to_rerun = True
+ while need_to_rerun:
+
+ # Create a list in which `tf_reduce_func` will store the new shapes.
+ flat_new_state_shapes = []
+
+ @function.Defun(*(nest.flatten(
+ sparse.as_dense_types(
+ self._state_types, self._state_classes)) + nest.flatten(
+ sparse.as_dense_types(input_dataset.output_types,
+ input_dataset.output_classes))))
+ def tf_reduce_func(*args):
+ """A wrapper for Defun that facilitates shape inference."""
+ for arg, shape in zip(
+ args,
+ nest.flatten(
+ sparse.as_dense_shapes(self._state_shapes, self._state_classes))
+ + nest.flatten(
+ sparse.as_dense_shapes(input_dataset.output_shapes,
+ input_dataset.output_classes))):
+ arg.set_shape(shape)
+
+ pivot = len(nest.flatten(self._state_shapes))
+ nested_state_args = nest.pack_sequence_as(self._state_types,
+ args[:pivot])
+ nested_state_args = sparse.deserialize_sparse_tensors(
+ nested_state_args, self._state_types, self._state_shapes,
+ self._state_classes)
+ nested_input_args = nest.pack_sequence_as(input_dataset.output_types,
+ args[pivot:])
+ nested_input_args = sparse.deserialize_sparse_tensors(
+ nested_input_args, input_dataset.output_types,
+ input_dataset.output_shapes, input_dataset.output_classes)
+
+ ret = reduce_func(nested_state_args, nested_input_args)
+
+ # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
+ # values to tensors.
+ ret = nest.pack_sequence_as(ret, [
+ sparse_tensor.SparseTensor.from_value(t)
+ if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t)
+ for t in nest.flatten(ret)
+ ])
+
+ # Extract shape information from the returned values.
+ flat_new_state = nest.flatten(ret)
+ flat_new_state_shapes.extend([t.get_shape() for t in flat_new_state])
+
+ # Extract and validate type information from the returned values.
+ for t, dtype in zip(flat_new_state, nest.flatten(self._state_types)):
+ if t.dtype != dtype:
+ raise TypeError(
+ "The element types for the new state must match the initial "
+ "state. Expected %s; got %s." %
+ (self._state_types,
+ nest.pack_sequence_as(self._state_types,
+ [t.dtype for t in flat_new_state])))
+
+ # Serialize any sparse tensors.
+ ret = nest.pack_sequence_as(
+ ret,
+ [t for t in nest.flatten(sparse.serialize_sparse_tensors(ret))])
+ return nest.flatten(ret)
+
+ # Use the private method that will execute `tf_reduce_func` but delay
+ # adding it to the graph in case we need to rerun the function.
+ tf_reduce_func._create_definition_if_needed() # pylint: disable=protected-access
+
+ flat_state_shapes = nest.flatten(self._state_shapes)
+ weakened_state_shapes = [
+ old.most_specific_compatible_shape(new)
+ for old, new in zip(flat_state_shapes, flat_new_state_shapes)
+ ]
+
+ need_to_rerun = False
+ for old_shape, weakened_shape in zip(flat_state_shapes,
+ weakened_state_shapes):
+ if old_shape.ndims is not None and (
+ weakened_shape.ndims is None or
+ old_shape.as_list() != weakened_shape.as_list()):
+ need_to_rerun = True
+ break
+
+ if need_to_rerun:
+ self._state_shapes = nest.pack_sequence_as(self._state_shapes,
+ weakened_state_shapes)
+
+ self._reduce_func = tf_reduce_func
+ self._reduce_func.add_to_graph(ops.get_default_graph())
+
+ def _make_finalize_func(self, finalize_func):
+ """Make wrapping Defun for finalize_func."""
+
+ @function.Defun(*(nest.flatten(
+ sparse.as_dense_types(self._state_types, self._state_classes))))
+ def tf_finalize_func(*args):
+ """A wrapper for Defun that facilitates shape inference."""
+ for arg, shape in zip(
+ args,
+ nest.flatten(
+ sparse.as_dense_shapes(self._state_shapes, self._state_classes))):
+ arg.set_shape(shape)
+
+ nested_args = nest.pack_sequence_as(self._state_types, args)
+ nested_args = sparse.deserialize_sparse_tensors(
+ nested_args, self._state_types, self._state_shapes,
+ self._state_classes)
+
+ ret = finalize_func(nested_args)
+
+ # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
+ # values to tensors.
+ ret = nest.pack_sequence_as(ret, [
+ sparse_tensor.SparseTensor.from_value(t)
+ if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t)
+ for t in nest.flatten(ret)
+ ])
+
+ self._output_classes = sparse.get_classes(ret)
+ self._output_shapes = nest.pack_sequence_as(
+ ret, [t.get_shape() for t in nest.flatten(ret)])
+ self._output_types = nest.pack_sequence_as(
+ ret, [t.dtype for t in nest.flatten(ret)])
+
+ # Serialize any sparse tensors.
+ ret = nest.pack_sequence_as(
+ ret, [t for t in nest.flatten(sparse.serialize_sparse_tensors(ret))])
+ return nest.flatten(ret)
+
+ self._finalize_func = tf_finalize_func
+ self._finalize_func.add_to_graph(ops.get_default_graph())
+
+ @property
+ def output_classes(self):
+ return self._output_classes
+
+ @property
+ def output_shapes(self):
+ return self._output_shapes
+
+ @property
+ def output_types(self):
+ return self._output_types
+
+ def _as_variant_tensor(self):
+ return gen_dataset_ops.group_by_reducer_dataset(
+ self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
+ self._key_func.captured_inputs,
+ self._init_func.captured_inputs,
+ self._reduce_func.captured_inputs,
+ self._finalize_func.captured_inputs,
+ key_func=self._key_func,
+ init_func=self._init_func,
+ reduce_func=self._reduce_func,
+ finalize_func=self._finalize_func,
+ output_types=nest.flatten(
+ sparse.as_dense_types(self.output_types, self.output_classes)),
+ output_shapes=nest.flatten(
+ sparse.as_dense_shapes(self.output_shapes, self.output_classes)))
+
+
class GroupByWindowDataset(dataset_ops.Dataset):
"""A `Dataset` that groups its input and performs a windowed reduction."""
@@ -336,3 +610,30 @@ class GroupByWindowDataset(dataset_ops.Dataset):
sparse.as_dense_types(self.output_types, self.output_classes)),
output_shapes=nest.flatten(
sparse.as_dense_shapes(self.output_shapes, self.output_classes)))
+
+
+class Reducer(object):
+ """A reducer is used for reducing a set of elements.
+
+ A reducer is represented as a tuple of the three functions:
+ 1) initialization function: key => initial state
+ 2) reduce function: (old state, input) => new state
+ 3) finalization function: state => result
+ """
+
+ def __init__(self, init_func, reduce_func, finalize_func):
+ self._init_func = init_func
+ self._reduce_func = reduce_func
+ self._finalize_func = finalize_func
+
+ @property
+ def init_func(self):
+ return self._init_func
+
+ @property
+ def reduce_func(self):
+ return self._reduce_func
+
+ @property
+ def finalize_func(self):
+ return self._finalize_func
diff --git a/tensorflow/contrib/data/python/ops/scan_ops.py b/tensorflow/contrib/data/python/ops/scan_ops.py
index 60ef7efba4..e911ad0fa0 100644
--- a/tensorflow/contrib/data/python/ops/scan_ops.py
+++ b/tensorflow/contrib/data/python/ops/scan_ops.py
@@ -24,6 +24,7 @@ from tensorflow.python.data.util import nest
from tensorflow.python.data.util import sparse
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import gen_dataset_ops
@@ -36,18 +37,22 @@ class _ScanDataset(dataset_ops.Dataset):
self._input_dataset = input_dataset
with ops.name_scope("initial_state"):
+ # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
+ # values to tensors.
self._initial_state = nest.pack_sequence_as(initial_state, [
- ops.convert_to_tensor(t, name="component_%d" % i)
+ sparse_tensor.SparseTensor.from_value(t)
+ if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(
+ t, name="component_%d" % i)
for i, t in enumerate(nest.flatten(initial_state))
])
- # Compute initial values for the state shapes and types based on
- # the initial state. These will be refined by running
- # `tf_scan_func` one or more times below.
- # TODO(b/68937811): Allow the initial state to be a tf.SparseTensor.
+ # Compute initial values for the state classes, shapes and types based on
+ # the initial state. The shapes may be refined by running `tf_scan_func` one
+ # or more times below.
+ self._state_classes = sparse.get_classes(self._initial_state)
self._state_shapes = nest.pack_sequence_as(
self._initial_state,
- [t.shape for t in nest.flatten(self._initial_state)])
+ [t.get_shape() for t in nest.flatten(self._initial_state)])
self._state_types = nest.pack_sequence_as(
self._initial_state,
[t.dtype for t in nest.flatten(self._initial_state)])
@@ -62,67 +67,102 @@ class _ScanDataset(dataset_ops.Dataset):
need_to_rerun = True
while need_to_rerun:
- flat_state_shapes = nest.flatten(self._state_shapes)
- flat_state_types = nest.flatten(self._state_types)
-
- # Create a list in which `tf_scan_func` will store the s
+ # Create a list in which `tf_scan_func` will store the new shapes.
flat_new_state_shapes = []
- @function.Defun(*(flat_state_types + nest.flatten(
- sparse.as_dense_types(input_dataset.output_types,
- input_dataset.output_classes))))
+ @function.Defun(*(nest.flatten(
+ sparse.as_dense_types(
+ self._state_types, self._state_classes)) + nest.flatten(
+ sparse.as_dense_types(input_dataset.output_types,
+ input_dataset.output_classes))))
def tf_scan_func(*args):
"""A wrapper for Defun that facilitates shape inference."""
# Pass in shape information from the state and input_dataset.
- # TODO(b/69424092): Check that neither inputs nor outputs are sparse.
- dense_shapes = sparse.as_dense_shapes(input_dataset.output_shapes,
- input_dataset.output_classes)
- for arg, shape in zip(args,
- flat_state_shapes + nest.flatten(dense_shapes)):
+ for arg, shape in zip(
+ args,
+ nest.flatten(
+ sparse.as_dense_shapes(self._state_shapes, self._state_classes))
+ + nest.flatten(
+ sparse.as_dense_shapes(input_dataset.output_shapes,
+ input_dataset.output_classes))):
arg.set_shape(shape)
- pivot = len(flat_state_shapes)
- old_state = nest.pack_sequence_as(self._initial_state, args[:pivot])
- input_value = nest.pack_sequence_as(input_dataset.output_types,
- args[pivot:])
-
- ret = scan_func(old_state, input_value)
+ pivot = len(nest.flatten(self._state_shapes))
+ print(self._state_classes)
+ nested_state_args = nest.pack_sequence_as(self._state_types,
+ args[:pivot])
+ nested_state_args = sparse.deserialize_sparse_tensors(
+ nested_state_args, self._state_types, self._state_shapes,
+ self._state_classes)
+ print(input_dataset.output_classes)
+ nested_input_args = nest.pack_sequence_as(input_dataset.output_types,
+ args[pivot:])
+ nested_input_args = sparse.deserialize_sparse_tensors(
+ nested_input_args, input_dataset.output_types,
+ input_dataset.output_shapes, input_dataset.output_classes)
+
+ ret = scan_func(nested_state_args, nested_input_args)
if not isinstance(ret, collections.Sequence) or len(ret) != 2:
raise TypeError("The scan function must return a pair comprising the "
"new state and the output value.")
+
+ # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
+ # values to tensors.
+ ret = nest.pack_sequence_as(ret, [
+ sparse_tensor.SparseTensor.from_value(t)
+ if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t)
+ for t in nest.flatten(ret)
+ ])
new_state, output_value = ret
- flat_new_state = [
- ops.convert_to_tensor(t) for t in nest.flatten(new_state)
- ]
- flat_output_value = [
- ops.convert_to_tensor(t) for t in nest.flatten(output_value)
- ]
+ # Extract and validate class information from the returned values.
+ for t, clazz in zip(
+ nest.flatten(new_state), nest.flatten(self._state_classes)):
+ if not isinstance(t, clazz):
+ raise TypeError(
+ "The element classes for the new state must match the initial "
+ "state. Expected %s; got %s." %
+ (self._state_classes,
+ nest.pack_sequence_as(
+ self._state_types,
+ [type(t) for t in nest.flatten(new_state)])))
+ self._output_classes = sparse.get_classes(output_value)
# Extract shape information from the returned values.
- flat_new_state_shapes.extend([t.shape for t in flat_new_state])
+ flat_new_state_shapes.extend(
+ [t.get_shape() for t in nest.flatten(new_state)])
self._output_shapes = nest.pack_sequence_as(
- output_value, [t.shape for t in flat_output_value])
+ output_value, [t.get_shape() for t in nest.flatten(output_value)])
# Extract and validate type information from the returned values.
- for t, dtype in zip(flat_new_state, flat_state_types):
+ for t, dtype in zip(
+ nest.flatten(new_state), nest.flatten(self._state_types)):
if t.dtype != dtype:
raise TypeError(
"The element types for the new state must match the initial "
"state. Expected %s; got %s." %
- (self._state_types, nest.pack_sequence_as(
- self._state_types, [t.dtype for t in flat_new_state])))
- self._output_classes = nest.pack_sequence_as(
- output_value, [ops.Tensor for _ in flat_output_value])
+ (self._state_types,
+ nest.pack_sequence_as(
+ self._state_types,
+ [t.dtype for t in nest.flatten(new_state)])))
self._output_types = nest.pack_sequence_as(
- output_value, [t.dtype for t in flat_output_value])
-
- return flat_new_state + flat_output_value
+ output_value, [t.dtype for t in nest.flatten(output_value)])
+
+ # Serialize any sparse tensors.
+ new_state = nest.pack_sequence_as(new_state, [
+ t for t in nest.flatten(sparse.serialize_sparse_tensors(new_state))
+ ])
+ output_value = nest.pack_sequence_as(output_value, [
+ t for t in nest.flatten(
+ sparse.serialize_sparse_tensors(output_value))
+ ])
+ return nest.flatten(new_state) + nest.flatten(output_value)
# Use the private method that will execute `tf_scan_func` but delay
# adding it to the graph in case we need to rerun the function.
tf_scan_func._create_definition_if_needed() # pylint: disable=protected-access
+ flat_state_shapes = nest.flatten(self._state_shapes)
weakened_state_shapes = [
original.most_specific_compatible_shape(new)
for original, new in zip(flat_state_shapes, flat_new_state_shapes)
@@ -150,7 +190,7 @@ class _ScanDataset(dataset_ops.Dataset):
input_t = self._input_dataset._as_variant_tensor() # pylint: disable=protected-access
return gen_dataset_ops.scan_dataset(
input_t,
- nest.flatten(self._initial_state),
+ nest.flatten(sparse.serialize_sparse_tensors(self._initial_state)),
self._scan_func.captured_inputs,
f=self._scan_func,
output_types=nest.flatten(
diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD
index c2834d8226..cdb3a8d65e 100644
--- a/tensorflow/contrib/distribute/python/BUILD
+++ b/tensorflow/contrib/distribute/python/BUILD
@@ -42,6 +42,7 @@ cuda_py_test(
srcs = ["values_test.py"],
additional_deps = [
":mirrored_strategy",
+ ":multi_worker_test_base",
":values",
"//tensorflow/core:protos_all_py",
"//tensorflow/python/data/ops:dataset_ops",
@@ -57,6 +58,9 @@ cuda_py_test(
"//tensorflow/python/eager:test",
"//tensorflow/python/estimator:model_fn",
],
+ tags = [
+ "no_pip",
+ ],
)
py_library(
@@ -217,6 +221,24 @@ cuda_py_test(
)
py_library(
+ name = "multi_worker_test_base",
+ testonly = 1,
+ srcs = ["multi_worker_test_base.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_pip",
+ ],
+ deps = [
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:distributed_framework_test_lib",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:session",
+ "//tensorflow/python:training",
+ "//tensorflow/python/eager:test",
+ ],
+)
+
+py_library(
name = "step_fn",
srcs = ["step_fn.py"],
visibility = ["//tensorflow:internal"],
@@ -479,3 +501,34 @@ cuda_py_test(
"//tensorflow/python/data/ops:iterator_ops",
],
)
+
+py_library(
+ name = "input_ops",
+ srcs = ["input_ops.py"],
+ visibility = ["//tensorflow:internal"],
+ deps = [
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python/data/util:nest",
+ ],
+)
+
+cuda_py_test(
+ name = "input_ops_test",
+ srcs = ["input_ops_test.py"],
+ additional_deps = [
+ ":input_ops",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/contrib/data/python/ops:batching",
+ "//tensorflow/contrib/data/python/ops:interleave_ops",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:io_ops",
+ "//tensorflow/python/data/ops:readers",
+ "//tensorflow/python:util",
+ ],
+ tags = [
+ "no_pip",
+ ],
+)
diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops.py b/tensorflow/contrib/distribute/python/cross_tower_ops.py
index cff717db80..c6a1bf6a9f 100644
--- a/tensorflow/contrib/distribute/python/cross_tower_ops.py
+++ b/tensorflow/contrib/distribute/python/cross_tower_ops.py
@@ -53,15 +53,14 @@ def _validate_value_destination_pairs(value_destination_pairs):
return True
+# TODO(yuefengz): consider calling this function in the caller of CrossTowerOps.
def _get_devices_from(destinations):
if isinstance(destinations, value_lib.DistributedValues):
return list(destinations.devices)
elif isinstance(destinations, six.string_types):
- return [device_util.canonicalize(destinations)]
+ return [device_util.resolve(destinations)]
else:
- return [
- device_util.canonicalize(destination) for destination in destinations
- ]
+ return [device_util.resolve(destination) for destination in destinations]
def _devices_match(left, right):
diff --git a/tensorflow/contrib/distribute/python/input_ops.py b/tensorflow/contrib/distribute/python/input_ops.py
new file mode 100644
index 0000000000..1f24f62947
--- /dev/null
+++ b/tensorflow/contrib/distribute/python/input_ops.py
@@ -0,0 +1,141 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Input-pipeline utilities for Distribution strategies."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.ops import readers
+from tensorflow.python.data.util import nest
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import tf_logging
+
+# TODO(priyag): Any other reader datasets to consider here?
+_READER_DATASET_OPS = [
+ "TextLineDataset",
+ "TFRecordDataset",
+ "FixedLengthRecordDataset"
+]
+
+
+# pylint: disable=protected-access
+def auto_shard_dataset(dataset, num_shards, index):
+ """Shard the input pipeline by sharding the underlying list of files.
+
+ Args:
+ dataset: A `tf.data.Dataset` instance, typically the result of a bunch of
+ dataset transformations.
+ num_shards: A `tf.int64` scalar `tf.Tensor`, representing the number of
+ shards operating in parallel. Same usage as in `Dataset.shard`.
+ index: A `tf.int64` scalar `tf.Tensor`, representing the worker index.
+ Same usage as in `Dataset.shard`.
+
+ Returns:
+ A modified `Dataset` obtained by updating the pipeline sharded by the
+ files.
+
+ Raises:
+ NotImplementedError: If we cannot automatically determine a good way to
+ shard the input dataset.
+ """
+
+ # TODO(priyag): Clone datasets instead of updating in place, similar to the
+ # clone method for TFRecordDataset.
+ def _auto_shard_impl(dataset, found_reader_op):
+ """Recursive implementation of auto sharding."""
+
+ if not found_reader_op:
+ # TODO(priyag): Make this check more robust by enforcing some common
+ # property on reader datasets.
+ if (isinstance(dataset, readers.TextLineDataset) or
+ isinstance(dataset, readers.FixedLengthRecordDataset)):
+ filenames_tensor = dataset._filenames
+ num_files = array_ops.size(filenames_tensor)
+ sharded_filenames_tensor = array_ops.gather(
+ filenames_tensor, math_ops.range(index, num_files, num_shards))
+ dataset._filenames = sharded_filenames_tensor
+ return dataset
+ elif isinstance(dataset, readers.TFRecordDataset):
+ # `TFRecordDataset` needs to be handled separately than other readers
+ # because it converts filenames to a dataset first. Also, we clone it
+ # instead of updating in place because it has special logic in the
+ # constructor. Eventually we will change all cases to clone datasets
+ # instead of updating in-place.
+ return dataset._clone(
+ filenames=dataset._filenames.shard(num_shards, index))
+ elif hasattr(dataset, "_map_func"):
+ # TODO(priyag): Make this check more robust by enforcing some common
+ # property on all map/flatmap/interleave datasets.
+ map_func_def = dataset._map_func.definition
+ for node in map_func_def.node_def:
+ if node.op in _READER_DATASET_OPS:
+ found_reader_op = True
+ break
+ elif node.op == "FlatMapDataset":
+ # TODO(priyag): Should this check for other map datasets? Should it
+ # be recursive? It is too specific to implementation of
+ # TFRecordDataset right now.
+ nested_func_name = node.attr["f"].func.name
+ nested_func = ops.get_default_graph()._functions[nested_func_name]
+ for nested_node in nested_func.definition.node_def:
+ if nested_node.op in _READER_DATASET_OPS:
+ found_reader_op = True
+ break
+ if found_reader_op:
+ break
+ if found_reader_op:
+ dataset._input_dataset = _auto_shard_impl(
+ dataset._input_dataset, found_reader_op)
+ return dataset
+
+ # TODO(priyag): Make _input_dataset(s) a common property of all datasets to
+ # make this check more robust.
+ if hasattr(dataset, "_input_dataset"):
+ dataset._input_dataset = _auto_shard_impl(
+ dataset._input_dataset, found_reader_op)
+ if hasattr(dataset, "_dataset_to_concatenate"):
+ # Special case for `ConcatentateDataset`. We want to shard all input
+ # datasets.
+ dataset._dataset_to_concatenate = _auto_shard_impl(
+ dataset._dataset_to_concatenate, found_reader_op)
+ return dataset
+
+ if hasattr(dataset, "_datasets"):
+ # Special case for `ZipDataset`.
+ dataset._datasets = nest.pack_sequence_as(dataset._datasets, [
+ _auto_shard_impl(ds, found_reader_op)
+ for ds in nest.flatten(dataset._datasets)
+ ])
+ return dataset
+
+ if not found_reader_op:
+ tf_logging.warn(
+ "Could not find a standard reader in the input pipeline"
+ "(one of TextLineDataset, TFRecordDataset, FixedLengthRecordDataset)."
+ "Falling back to sharding the dataset anyway. Please verify"
+ "correctness of auto-sharding for your input.")
+
+ # TODO(priyag): What do we want to do if the number of filenames is
+ # uneven in the number of shards? By default, this will just return as
+ # many items it can before throwing OutOfRangeError.
+ # TODO(priyag): This will shard the filenames before any shuffling of the
+ # filename dataset. It might be desirable to shard after shuffling
+ # filenames? If so, how do we achieve that?
+ return dataset.shard(num_shards, index)
+
+ return _auto_shard_impl(dataset=dataset, found_reader_op=False)
diff --git a/tensorflow/contrib/distribute/python/input_ops_test.py b/tensorflow/contrib/distribute/python/input_ops_test.py
new file mode 100644
index 0000000000..16179c3a49
--- /dev/null
+++ b/tensorflow/contrib/distribute/python/input_ops_test.py
@@ -0,0 +1,265 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for input pipeline modifications for distribution strategies."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from tensorflow.contrib.data.python.ops import batching
+from tensorflow.contrib.data.python.ops import interleave_ops
+from tensorflow.contrib.distribute.python import input_ops
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.ops import readers
+from tensorflow.python.framework import errors
+from tensorflow.python.lib.io import python_io
+from tensorflow.python.platform import test
+from tensorflow.python.util import compat
+
+
+class AutoShardDatasetTest(test.TestCase):
+
+ def setUp(self):
+ super(AutoShardDatasetTest, self).setUp()
+ self._num_files = 10
+ self._num_records = 4
+ self._num_shards = 2
+ self._shard_index = 0
+ self._record_bytes = 10
+
+ def _record(self, r, f):
+ return compat.as_bytes("Record %d of file %d" % (r, f))
+
+ def _text_line(self, r, f):
+ return compat.as_bytes("Text line %d of file %d" % (r, f))
+
+ def _fixed_length_record(self, r, f):
+ return compat.as_bytes(str((r * f) % 10) * self._record_bytes)
+
+ def _createTFRecordFiles(self):
+ filenames = []
+ for i in range(self._num_files):
+ fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i)
+ filenames.append(fn)
+ writer = python_io.TFRecordWriter(fn)
+ for j in range(self._num_records):
+ record = self._record(j, i)
+ writer.write(record)
+ writer.close()
+ return filenames
+
+ def _createTextFiles(self):
+ filenames = []
+ for i in range(self._num_files):
+ fn = os.path.join(self.get_temp_dir(), "text_line.%d.txt" % i)
+ filenames.append(fn)
+ contents = []
+ for j in range(self._num_records):
+ contents.append(self._text_line(j, i))
+ if j + 1 != self._num_records or i == 0:
+ contents.append(b"\r\n")
+ contents = b"".join(contents)
+
+ with open(fn, "wb") as f:
+ f.write(contents)
+ return filenames
+
+ def _createFixedLengthRecordFiles(self):
+ filenames = []
+ for i in range(self._num_files):
+ fn = os.path.join(self.get_temp_dir(), "fixed_length_record.%d.txt" % i)
+ filenames.append(fn)
+ with open(fn, "wb") as f:
+ for j in range(self._num_records):
+ f.write(self._fixed_length_record(j, i))
+ return filenames
+
+ def _verifySimpleShardingOutput(self, dataset, record_fn):
+ iterator = dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+ with self.test_session() as sess:
+ for f in range(self._shard_index, self._num_files, self._num_shards):
+ for r in range(self._num_records):
+ self.assertAllEqual(record_fn(r, f), sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testTFRecordDataset(self):
+ dataset = readers.TFRecordDataset(self._createTFRecordFiles())
+ dataset = input_ops.auto_shard_dataset(
+ dataset, self._num_shards, self._shard_index)
+
+ self._verifySimpleShardingOutput(dataset, self._record)
+
+ def testFlatMap(self):
+ dataset = dataset_ops.Dataset.from_tensor_slices(
+ self._createTFRecordFiles())
+ dataset = dataset.flat_map(readers.TFRecordDataset)
+ dataset = input_ops.auto_shard_dataset(
+ dataset, self._num_shards, self._shard_index)
+
+ self._verifySimpleShardingOutput(dataset, self._record)
+
+ def testInterleave(self):
+ dataset = dataset_ops.Dataset.from_tensor_slices(
+ self._createTFRecordFiles())
+ dataset = dataset.interleave(
+ readers.TFRecordDataset, cycle_length=4, block_length=self._num_records)
+ dataset = input_ops.auto_shard_dataset(
+ dataset, self._num_shards, self._shard_index)
+
+ # Since block_length == num records in each file, the output will still
+ # contain records in order of files.
+ self._verifySimpleShardingOutput(dataset, self._record)
+
+ def testParallelInterleave(self):
+ dataset = dataset_ops.Dataset.from_tensor_slices(
+ self._createTFRecordFiles())
+ dataset = dataset.apply(interleave_ops.parallel_interleave(
+ readers.TFRecordDataset,
+ cycle_length=4,
+ block_length=self._num_records))
+ dataset = input_ops.auto_shard_dataset(
+ dataset, self._num_shards, self._shard_index)
+
+ # Since block_length == num records in each file, the output will still
+ # contain records in order of files.
+ self._verifySimpleShardingOutput(dataset, self._record)
+
+ def testListfiles(self):
+ filenames = self._createTFRecordFiles()
+ file_pattern = filenames[0].rsplit("/", 1)[0] + "/tf_record.*.txt"
+ dataset = dataset_ops.Dataset.list_files(file_pattern, shuffle=False)
+ dataset = dataset.flat_map(readers.TFRecordDataset)
+ dataset = input_ops.auto_shard_dataset(
+ dataset, self._num_shards, self._shard_index)
+
+ iterator = dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+ with self.test_session() as sess:
+ actual, expected = [], []
+ for f in range(self._shard_index, self._num_files, self._num_shards):
+ for r in range(self._num_records):
+ actual.append(sess.run(next_element))
+ expected.append(self._record(r, f))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+ self.assertAllEqual(expected, actual)
+
+ def testComplexPipeline(self):
+ # Setup a complex input pipeline.
+ batch_size = 2
+ num_epochs = 5
+ dataset = dataset_ops.Dataset.from_tensor_slices(
+ self._createTFRecordFiles())
+ dataset = dataset.shuffle(buffer_size=self._num_files)
+ dataset = dataset.flat_map(readers.TFRecordDataset)
+ dataset = dataset.prefetch(buffer_size=batch_size)
+ dataset = dataset.shuffle(2 * self._num_files * self._num_records)
+ dataset = dataset.repeat(num_epochs)
+ dataset = dataset.apply(batching.map_and_batch(
+ lambda x: x, batch_size=batch_size))
+ dataset = dataset.prefetch(buffer_size=None)
+
+ # Auto shard.
+ dataset = input_ops.auto_shard_dataset(
+ dataset, self._num_shards, self._shard_index)
+
+ # Verify output.
+ iterator = dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+ with self.test_session() as sess:
+ actual = []
+ num_iterations = (self._num_files * self._num_records * num_epochs) // (
+ self._num_shards * batch_size)
+ for _ in range(num_iterations):
+ actual.extend(sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ expected = []
+ for f in range(0, self._num_files, self._num_shards):
+ for r in range(self._num_records):
+ expected.append(self._record(r, f))
+ expected *= num_epochs
+
+ self.assertAllEqual(sorted(expected), sorted(actual))
+
+ def testZip(self):
+ dataset1 = readers.TFRecordDataset(self._createTFRecordFiles())
+ dataset2 = readers.TextLineDataset(self._createTextFiles())
+ dataset = dataset_ops.Dataset.zip((dataset1, dataset2))
+ dataset = input_ops.auto_shard_dataset(
+ dataset, self._num_shards, self._shard_index)
+
+ record_fn = lambda r, f: (self._record(r, f), self._text_line(r, f))
+ self._verifySimpleShardingOutput(dataset, record_fn)
+
+ def testConcat(self):
+ dataset1 = readers.TFRecordDataset(self._createTFRecordFiles())
+ dataset2 = readers.TextLineDataset(self._createTextFiles())
+ dataset = dataset1.concatenate(dataset2)
+ dataset = input_ops.auto_shard_dataset(
+ dataset, self._num_shards, self._shard_index)
+
+ iterator = dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+ with self.test_session() as sess:
+ for f in range(self._shard_index, self._num_files, self._num_shards):
+ for r in range(self._num_records):
+ self.assertAllEqual(self._record(r, f), sess.run(next_element))
+ for f in range(self._shard_index, self._num_files, self._num_shards):
+ for r in range(self._num_records):
+ self.assertAllEqual(self._text_line(r, f), sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testTextLineReader(self):
+ dataset = readers.TextLineDataset(self._createTextFiles())
+ dataset = input_ops.auto_shard_dataset(
+ dataset, self._num_shards, self._shard_index)
+
+ self._verifySimpleShardingOutput(dataset, self._text_line)
+
+ def testTextLineReaderWithFlatMap(self):
+ dataset = dataset_ops.Dataset.from_tensor_slices(self._createTextFiles())
+ dataset = dataset.flat_map(readers.TextLineDataset)
+ dataset = input_ops.auto_shard_dataset(
+ dataset, self._num_shards, self._shard_index)
+
+ self._verifySimpleShardingOutput(dataset, self._text_line)
+
+ def testFixedLengthReader(self):
+ dataset = readers.FixedLengthRecordDataset(
+ self._createFixedLengthRecordFiles(), self._record_bytes)
+ dataset = input_ops.auto_shard_dataset(
+ dataset, self._num_shards, self._shard_index)
+
+ self._verifySimpleShardingOutput(dataset, self._fixed_length_record)
+
+ def testFixedLengthReaderWithFlatMap(self):
+ dataset = dataset_ops.Dataset.from_tensor_slices(
+ self._createFixedLengthRecordFiles())
+ dataset = dataset.flat_map(
+ lambda f: readers.FixedLengthRecordDataset(f, self._record_bytes))
+ dataset = input_ops.auto_shard_dataset(
+ dataset, self._num_shards, self._shard_index)
+
+ self._verifySimpleShardingOutput(dataset, self._fixed_length_record)
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py
index 6efd578a77..2e57b02583 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py
@@ -321,7 +321,6 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
def _fetch(self, val, destination, fn):
"""Return a copy of `val` or `fn(val)` on `destination`."""
- assert isinstance(destination, six.string_types)
if isinstance(val, values.TowerLocalVariable):
val = self.reduce(val.reduce_method, val, destinations=destination)
with ops.device(destination):
diff --git a/tensorflow/contrib/distribute/python/multi_worker_test_base.py b/tensorflow/contrib/distribute/python/multi_worker_test_base.py
new file mode 100644
index 0000000000..f659be5f42
--- /dev/null
+++ b/tensorflow/contrib/distribute/python/multi_worker_test_base.py
@@ -0,0 +1,90 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Base testing class for strategies that require multiple nodes."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import contextlib
+import copy
+
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.core.protobuf import rewriter_config_pb2
+from tensorflow.python.client import session
+from tensorflow.python.eager import test
+from tensorflow.python.framework import test_util
+
+
+class MultiWorkerTestBase(test.TestCase):
+ """Base class for testing multi node strategy and dataset."""
+
+ @classmethod
+ def setUpClass(cls):
+ """Create a local cluster with 2 workers."""
+ num_workers = 2
+ # Leave some memory for cuda runtime.
+ gpu_mem_frac = 0.7 / num_workers
+ default_config = config_pb2.ConfigProto()
+ default_config.gpu_options.per_process_gpu_memory_fraction = gpu_mem_frac
+
+ # The local cluster takes some portion of the local GPUs and there is no way
+ # for the cluster to terminate unless using multiple processes. Therefore,
+ # we have to only create only one cluster throughout a test process.
+ workers, _ = test_util.create_local_cluster(
+ num_workers, num_ps=0, worker_config=default_config)
+ cls._master_target = workers[0].target
+
+ @contextlib.contextmanager
+ def test_session(self, graph=None, config=None):
+ """Create a test session with master target set to the testing cluster.
+
+ This overrides the base class' method, removes arguments that are not needed
+ by the multi-node case and creates a test session that connects to the local
+ testing cluster.
+
+ Args:
+ graph: Optional graph to use during the returned session.
+ config: An optional config_pb2.ConfigProto to use to configure the
+ session.
+
+ Yields:
+ A Session object that should be used as a context manager to surround
+ the graph building and execution code in a test case.
+ """
+ if self.id().endswith('.test_session'):
+ self.skipTest('Not a test.')
+
+ if config is None:
+ config = config_pb2.ConfigProto(allow_soft_placement=True)
+ else:
+ config = copy.deepcopy(config)
+ # Don't perform optimizations for tests so we don't inadvertently run
+ # gpu ops on cpu
+ config.graph_options.optimizer_options.opt_level = -1
+ config.graph_options.rewrite_options.constant_folding = (
+ rewriter_config_pb2.RewriterConfig.OFF)
+
+ if graph is None:
+ if self._cached_session is None: # pylint: disable=access-member-before-definition
+ self._cached_session = session.Session(
+ graph=None, config=config, target=self._master_target)
+ sess = self._cached_session
+ with sess.graph.as_default(), sess.as_default():
+ yield sess
+ else:
+ with session.Session(
+ graph=graph, config=config, target=self._master_target) as sess:
+ yield sess
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py
index 8cb5276579..18afdaa7b0 100644
--- a/tensorflow/contrib/distribute/python/values.py
+++ b/tensorflow/contrib/distribute/python/values.py
@@ -29,6 +29,7 @@ import six
from tensorflow.contrib.data.python.ops import batching
from tensorflow.contrib.distribute.python import prefetching_ops_v2
from tensorflow.python.eager import context
+from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
@@ -229,6 +230,12 @@ class DistributedVariable(DistributedDelegate):
self._primary_var.op.type)
return self.get().op
+ def _as_graph_element(self):
+ # pylint: disable=protected-access
+ if distribute_lib.get_cross_tower_context():
+ return self._primary_var._as_graph_element()
+ return self.get()._as_graph_element()
+
def _should_act_as_resource_variable(self):
"""Pass resource_variable_ops.is_resource_variable check."""
pass
@@ -570,6 +577,100 @@ class PerDeviceDataset(object):
dataset_iterator, self._devices, self._prefetch_on_device)
+class MultiWorkerDataIterator(object):
+ """An iterator (like `tf.data.Iterator`) into a `MultiWorkerDataset`."""
+
+ def __init__(self, iterators, worker_device_map):
+ """Initialize the MultiWorkerDataIterator object.
+
+ Args:
+ iterators: a dict mapping from each worker to an iterator for
+ that worker.
+ worker_device_map: a dict mapping from each worker's devices to a list of
+ devices that belong to this worker.
+
+ Raises:
+ ValueError: if iterators and worker_device_map are not compatible.
+ """
+ self._iterators = iterators
+ self._worker_device_map = worker_device_map
+ if set(self._iterators) != set(self._worker_device_map):
+ raise ValueError("iterators and worker_device_map are not compatible.")
+
+ @property
+ def initializer(self):
+ return control_flow_ops.group(
+ [iterator.initializer for iterator in self._iterators.values()])
+
+ def get_next(self, name=None):
+ """Scatter the input across hosts and devices."""
+ index = {}
+ for worker, iterator in six.iteritems(self._iterators):
+ if name is not None:
+ d = tf_device.DeviceSpec.from_string(worker)
+ new_name = "%s_%s_%d" % (name, d.job, d.task)
+ else:
+ new_name = None
+ with ops.device(worker):
+ data_per_worker = iterator.get_next(name=new_name)
+
+ worker_devices = self._worker_device_map[worker]
+ # Ungroup these per-device value so as to get a flat map from devices to
+ # values.
+ for d in worker_devices:
+ v = select_device(d, data_per_worker)
+ if d in index:
+ raise ValueError("Duplicated devices in worker_device_map: %r" % v)
+ index[d] = v
+
+ return regroup(index)
+
+
+class MultiWorkerDataset(object):
+ """Like a `tf.data.Dataset` that distributes data to different workers.
+
+ Each worker gets one shard of the input dataset. It is currently not working
+ in
+ eager mode.
+ """
+
+ def __init__(self, dataset_fn, worker_device_map, prefetch_on_device=None):
+ """Initialize the MultiWorkerDataset object.
+
+ Args:
+ dataset_fn: a function that returns a `tf.data.Dataset`.
+ worker_device_map: a dict mapping from each worker to a list of devices
+ that belong to this worker.
+ prefetch_on_device: whether to prefetch to devices.
+ """
+ self._worker_device_map = worker_device_map
+ self._datasets = {}
+ # TODO(yuefengz, priyag): support different set of jobs for input
+ # processing.
+ for i, (worker, worker_devices) in enumerate(
+ six.iteritems(worker_device_map)):
+ with ops.device(worker):
+ worker_input = dataset_fn()
+ # TODO(yuefengz, priyag): support efficient sharding.
+ worker_input = worker_input.shard(len(worker_device_map), i)
+ self._datasets[worker] = PerDeviceDataset(
+ worker_input, worker_devices, prefetch_on_device=prefetch_on_device)
+
+ def make_one_shot_iterator(self):
+ iterators = {}
+ for worker, dataset in six.iteritems(self._datasets):
+ with ops.device(worker):
+ iterators[worker] = dataset.make_one_shot_iterator()
+ return MultiWorkerDataIterator(iterators, self._worker_device_map)
+
+ def make_initializable_iterator(self):
+ iterators = {}
+ for worker, dataset in six.iteritems(self._datasets):
+ with ops.device(worker):
+ iterators[worker] = dataset.make_initializable_iterator()
+ return MultiWorkerDataIterator(iterators, self._worker_device_map)
+
+
class PerIteration(object):
"""Holds input for multiple iterations at once."""
diff --git a/tensorflow/contrib/distribute/python/values_test.py b/tensorflow/contrib/distribute/python/values_test.py
index e96ce54741..9aeef9fa3e 100644
--- a/tensorflow/contrib/distribute/python/values_test.py
+++ b/tensorflow/contrib/distribute/python/values_test.py
@@ -18,9 +18,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import collections
import os
from tensorflow.contrib.distribute.python import mirrored_strategy
+from tensorflow.contrib.distribute.python import multi_worker_test_base
from tensorflow.contrib.distribute.python import values
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.data.ops import dataset_ops
@@ -34,8 +36,10 @@ from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.training import device_util
from tensorflow.python.training import saver as saver_lib
+from tensorflow.python.util import nest
@test_util.with_c_api
@@ -436,6 +440,130 @@ class PerDeviceDatasetTest(test.TestCase):
self.evaluate(next_element)
+class MultiWorkerDatasetTest(multi_worker_test_base.MultiWorkerTestBase):
+
+ def _test_iterator(self, iterator, devices, expected_values):
+ next_element = iterator.get_next()
+ for device in devices:
+ v = values.select_device(device, next_element)
+ # The `v` here can be a tuple.
+ for element in nest.flatten(v):
+ self.assertTrue(element.device in device)
+
+ for expected_value in expected_values:
+ actual = self.evaluate(
+ [values.select_device(d, next_element) for d in devices])
+ self.assertEqual(expected_value, actual)
+
+ with self.assertRaises(errors.OutOfRangeError):
+ self.evaluate([values.select_device(d, next_element) for d in devices])
+
+ def _test_dataset(self, dataset_fn, worker_device_map, devices,
+ expected_values):
+ multi_worker_dataset = values.MultiWorkerDataset(
+ dataset_fn, worker_device_map, prefetch_on_device=False)
+ multi_worker_iterator = multi_worker_dataset.make_one_shot_iterator()
+ self._test_iterator(multi_worker_iterator, devices, expected_values)
+
+ def _cpu_devices(self):
+ worker_device_map = collections.OrderedDict(
+ [("/job:worker/replica:0/task:0",
+ ["/job:worker/replica:0/task:0/device:CPU:0"]),
+ ("/job:worker/replica:0/task:1",
+ ["/job:worker/replica:0/task:1/device:CPU:0"])])
+ devices = [
+ "/job:worker/replica:0/task:0/device:CPU:0",
+ "/job:worker/replica:0/task:1/device:CPU:0"
+ ]
+ return worker_device_map, devices
+
+ def _cpu_and_one_gpu_devices(self):
+ # The worker_device_map doesn't have to be a OrderDict object, this is just
+ # to simplify the testing so that we can pass expected values as a list
+ # instead of a dict.
+ worker_device_map = collections.OrderedDict(
+ [("/job:worker/replica:0/task:0", [
+ "/job:worker/replica:0/task:0/device:GPU:0",
+ "/job:worker/replica:0/task:0/device:CPU:0"
+ ]), ("/job:worker/replica:0/task:1", [
+ "/job:worker/replica:0/task:1/device:GPU:0",
+ "/job:worker/replica:0/task:1/device:CPU:0"
+ ])])
+ devices = [
+ "/job:worker/replica:0/task:0/device:GPU:0",
+ "/job:worker/replica:0/task:0/device:CPU:0",
+ "/job:worker/replica:0/task:1/device:GPU:0",
+ "/job:worker/replica:0/task:1/device:CPU:0"
+ ]
+ return worker_device_map, devices
+
+ def testDataDistributionOneDevicePerWorker(self):
+ worker_device_map, devices = self._cpu_devices()
+ with context.graph_mode():
+ dataset_fn = lambda: dataset_ops.Dataset.range(8)
+ self._test_dataset(dataset_fn, worker_device_map, devices,
+ [[0, 1], [2, 3], [4, 5], [6, 7]])
+
+ def testDataDistributionTwoDevicePerWorker(self):
+ if context.num_gpus() < 1:
+ self.skipTest("A GPU is not available for this test.")
+ worker_device_map, devices = self._cpu_and_one_gpu_devices()
+ with context.graph_mode():
+ dataset_fn = lambda: dataset_ops.Dataset.range(8)
+ self._test_dataset(dataset_fn, worker_device_map, devices,
+ [[0, 2, 1, 3], [4, 6, 5, 7]])
+
+ def testTupleDataset(self):
+ worker_device_map, devices = self._cpu_devices()
+
+ with context.graph_mode():
+
+ def dataset_fn():
+ dataset1 = dataset_ops.Dataset.range(8)
+ dataset2 = dataset_ops.Dataset.range(8).map(lambda x: x**2)
+ return dataset_ops.Dataset.zip((dataset1, dataset2))
+
+ expected_values = [
+ [(i, i**2), (i + 1, (i + 1)**2)] for i in range(0, 8, 2)
+ ]
+ self._test_dataset(dataset_fn, worker_device_map, devices,
+ expected_values)
+
+ def testInitializableIterator(self):
+ worker_device_map, devices = self._cpu_devices()
+ with context.graph_mode():
+ dataset_fn = lambda: dataset_ops.Dataset.range(8)
+ multi_worker_dataset = values.MultiWorkerDataset(
+ dataset_fn, worker_device_map, prefetch_on_device=False)
+ multi_worker_iterator = multi_worker_dataset.make_initializable_iterator()
+
+ self.evaluate(multi_worker_iterator.initializer)
+ self._test_iterator(multi_worker_iterator, devices,
+ [[0, 1], [2, 3], [4, 5], [6, 7]])
+
+ # After re-initializing the iterator, should be able to iterate again.
+ self.evaluate(multi_worker_iterator.initializer)
+ self._test_iterator(multi_worker_iterator, devices,
+ [[0, 1], [2, 3], [4, 5], [6, 7]])
+
+ def testValueErrorForIterator(self):
+ # Incompatiable arguments.
+ with self.assertRaises(ValueError):
+ values.MultiWorkerDataIterator({"w1": None}, {"w1": "d1", "w2": "d2"})
+
+ # Test duplicated devices under same worker.
+ worker_device_map, _ = self._cpu_devices()
+ worker_device_map["/job:worker/replica:0/task:0"].append(
+ "/job:worker/replica:0/task:0/device:CPU:0")
+ with context.graph_mode():
+ dataset_fn = lambda: dataset_ops.Dataset.range(8)
+ multi_worker_dataset = values.MultiWorkerDataset(
+ dataset_fn, worker_device_map, prefetch_on_device=False)
+ multi_worker_iterator = multi_worker_dataset.make_initializable_iterator()
+ with self.assertRaises(ValueError):
+ multi_worker_iterator.get_next()
+
+
@test_util.with_c_api
class MirroredVariableTest(test.TestCase):
@@ -582,6 +710,21 @@ class MirroredVariableTest(test.TestCase):
save_path = self._save_normal()
self._restore_mirrored(save_path)
+ @test_util.run_in_graph_and_eager_modes(config=config)
+ def testFetchAMirroredVariable(self):
+ if context.num_gpus() < 1 or context.executing_eagerly():
+ self.skipTest("A GPU is not available for this test or it's eager mode.")
+
+ with self.test_session(
+ graph=ops.Graph()) as sess, mirrored_strategy.MirroredStrategy(
+ ["/device:GPU:0"]).scope():
+ with ops.device("/device:GPU:0"):
+ v = variable_scope.get_variable(
+ name="v", initializer=1., use_resource=True)
+ mirrored = values.MirroredVariable({"/device:GPU:0": v}, v)
+ sess.run(variables_lib.global_variables_initializer())
+ sess.run({"complicated": mirrored})
+
_devices = ["/device:GPU:0", "/device:CPU:0"]
diff --git a/tensorflow/contrib/distributions/python/ops/quantized_distribution.py b/tensorflow/contrib/distributions/python/ops/quantized_distribution.py
index 1ef7651d03..eb94760ad7 100644
--- a/tensorflow/contrib/distributions/python/ops/quantized_distribution.py
+++ b/tensorflow/contrib/distributions/python/ops/quantized_distribution.py
@@ -128,7 +128,7 @@ The base distribution's `log_cdf` method must be defined on `y - 1`.
class QuantizedDistribution(distributions.Distribution):
"""Distribution representing the quantization `Y = ceiling(X)`.
- #### Definition in terms of sampling.
+ #### Definition in Terms of Sampling
```
1. Draw X
@@ -138,7 +138,7 @@ class QuantizedDistribution(distributions.Distribution):
5. Return Y
```
- #### Definition in terms of the probability mass function.
+ #### Definition in Terms of the Probability Mass Function
Given scalar random variable `X`, we define a discrete random variable `Y`
supported on the integers as follows:
@@ -170,12 +170,62 @@ class QuantizedDistribution(distributions.Distribution):
`P[Y = j]` is still the mass of `X` within the `jth` interval.
- #### Caveats
+ #### Examples
+
+ We illustrate a mixture of discretized logistic distributions
+ [(Salimans et al., 2017)][1]. This is used, for example, for capturing 16-bit
+ audio in WaveNet [(van den Oord et al., 2017)][2]. The values range in
+ a 1-D integer domain of `[0, 2**16-1]`, and the discretization captures
+ `P(x - 0.5 < X <= x + 0.5)` for all `x` in the domain excluding the endpoints.
+ The lowest value has probability `P(X <= 0.5)` and the highest value has
+ probability `P(2**16 - 1.5 < X)`.
+
+ Below we assume a `wavenet` function. It takes as `input` right-shifted audio
+ samples of shape `[..., sequence_length]`. It returns a real-valued tensor of
+ shape `[..., num_mixtures * 3]`, i.e., each mixture component has a `loc` and
+ `scale` parameter belonging to the logistic distribution, and a `logits`
+ parameter determining the unnormalized probability of that component.
+
+ ```python
+ tfd = tf.contrib.distributions
+ tfb = tfd.bijectors
+
+ net = wavenet(inputs)
+ loc, unconstrained_scale, logits = tf.split(net,
+ num_or_size_splits=3,
+ axis=-1)
+ scale = tf.nn.softplus(unconstrained_scale)
+
+ # Form mixture of discretized logistic distributions. Note we shift the
+ # logistic distribution by -0.5. This lets the quantization capture "rounding"
+ # intervals, `(x-0.5, x+0.5]`, and not "ceiling" intervals, `(x-1, x]`.
+ discretized_logistic_dist = tfd.QuantizedDistribution(
+ distribution=tfd.TransformedDistribution(
+ distribution=tfd.Logistic(loc=loc, scale=scale),
+ bijector=tfb.AffineScalar(shift=-0.5)),
+ low=0.,
+ high=2**16 - 1.)
+ mixture_dist = tfd.MixtureSameFamily(
+ mixture_distribution=tfd.Categorical(logits=logits),
+ components_distribution=discretized_logistic_dist)
+
+ neg_log_likelihood = -tf.reduce_sum(mixture_dist.log_prob(targets))
+ train_op = tf.train.AdamOptimizer().minimize(neg_log_likelihood)
+ ```
+
+ After instantiating `mixture_dist`, we illustrate maximum likelihood by
+ calculating its log-probability of audio samples as `target` and optimizing.
+
+ #### References
- Since evaluation of each `P[Y = j]` involves a cdf evaluation (rather than
- a closed form function such as for a Poisson), computations such as mean and
- entropy are better done with samples or approximations, and are not
- implemented by this class.
+ [1]: Tim Salimans, Andrej Karpathy, Xi Chen, and Diederik P. Kingma.
+ PixelCNN++: Improving the PixelCNN with discretized logistic mixture
+ likelihood and other modifications.
+ _International Conference on Learning Representations_, 2017.
+ https://arxiv.org/abs/1701.05517
+ [2]: Aaron van den Oord et al. Parallel WaveNet: Fast High-Fidelity Speech
+ Synthesis. _arXiv preprint arXiv:1711.10433_, 2017.
+ https://arxiv.org/abs/1711.10433
"""
def __init__(self,
diff --git a/tensorflow/contrib/eager/README.md b/tensorflow/contrib/eager/README.md
index 9a3b780af8..762685db14 100644
--- a/tensorflow/contrib/eager/README.md
+++ b/tensorflow/contrib/eager/README.md
@@ -37,7 +37,7 @@ support for distributed and multi-GPU training and performance.
## Installation
-Eager execution is included in TensorFlow versions 1.7 and above.
+For eager execution, we recommend using TensorFlow version 1.8 or newer.
Installation instructions at https://www.tensorflow.org/install/
## Documentation
@@ -48,12 +48,3 @@ For an introduction to eager execution in TensorFlow, see:
- Notebook: [Basic Usage](python/examples/notebooks/1_basics.ipynb)
- Notebook: [Gradients](python/examples/notebooks/2_gradients.ipynb)
- Notebook: [Importing Data](python/examples/notebooks/3_datasets.ipynb)
-
-## Changelog
-
-- 2017/10/31: Initial preview release (in TensorFlow 1.5)
-- 2017/12/01: Example of dynamic neural network:
- [SPINN: Stack-augmented Parser-Interpreter Neural Network](https://arxiv.org/abs/1603.06021).
- See [README.md](python/examples/spinn/README.md) for details.
-- 2017/03: Core functionality moved out of the experimental tf.contrib namespace
- in TensorFlow 1.7.
diff --git a/tensorflow/contrib/eager/python/examples/notebooks/1_basics.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/1_basics.ipynb
index 459f2f4a7d..0279db80fa 100644
--- a/tensorflow/contrib/eager/python/examples/notebooks/1_basics.ipynb
+++ b/tensorflow/contrib/eager/python/examples/notebooks/1_basics.ipynb
@@ -1,11 +1,27 @@
{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "name": "Eager Execution Tutorial: Basics",
+ "version": "0.3.2",
+ "views": {},
+ "default_view": {},
+ "provenance": [
+ {
+ "file_id": "0B0kLcpwLFwKEVm9XNkFueGk4bTg",
+ "timestamp": 1504118841551
+ }
+ ]
+ }
+ },
"cells": [
{
- "cell_type": "markdown",
"metadata": {
- "colab_type": "text",
- "id": "U9i2Dsh-ziXr"
+ "id": "U9i2Dsh-ziXr",
+ "colab_type": "text"
},
+ "cell_type": "markdown",
"source": [
"# Eager Execution Tutorial: Basics\n",
"\n",
@@ -21,11 +37,11 @@
]
},
{
- "cell_type": "markdown",
"metadata": {
- "colab_type": "text",
- "id": "z1JcS5iBXMRO"
+ "id": "z1JcS5iBXMRO",
+ "colab_type": "text"
},
+ "cell_type": "markdown",
"source": [
"# Step 1: Import Eager\n",
"\n",
@@ -33,34 +49,34 @@
]
},
{
- "cell_type": "code",
- "execution_count": 0,
"metadata": {
- "cellView": "code",
+ "id": "RlIWhyeLoYnG",
+ "colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
},
- "colab_type": "code",
- "id": "RlIWhyeLoYnG"
+ "cellView": "code"
},
- "outputs": [],
+ "cell_type": "code",
"source": [
"# Import TensorFlow.\n",
"import tensorflow as tf\n",
"\n",
"# Import TensorFlow eager execution support (subject to future changes).\n",
- "import tensorflow.contrib.eager as tfe"
- ]
+ "tfe = tf.contrib.eager"
+ ],
+ "execution_count": 0,
+ "outputs": []
},
{
- "cell_type": "markdown",
"metadata": {
- "colab_type": "text",
- "id": "H9UySOPLXdaw"
+ "id": "H9UySOPLXdaw",
+ "colab_type": "text"
},
+ "cell_type": "markdown",
"source": [
"# Step 2: Enable eager execution\n",
"\n",
@@ -69,30 +85,30 @@
]
},
{
- "cell_type": "code",
- "execution_count": 0,
"metadata": {
- "cellView": "code",
+ "id": "WPTUfGq6kJ5w",
+ "colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
},
- "colab_type": "code",
- "id": "WPTUfGq6kJ5w"
+ "cellView": "code"
},
- "outputs": [],
+ "cell_type": "code",
"source": [
- "tfe.enable_eager_execution()"
- ]
+ "tf.enable_eager_execution()"
+ ],
+ "execution_count": 0,
+ "outputs": []
},
{
- "cell_type": "markdown",
"metadata": {
- "colab_type": "text",
- "id": "twBfWd5xyu_d"
+ "id": "twBfWd5xyu_d",
+ "colab_type": "text"
},
+ "cell_type": "markdown",
"source": [
"# Step 3: Interactively Use TensorFlow!\n",
"\n",
@@ -102,20 +118,18 @@
]
},
{
- "cell_type": "code",
- "execution_count": 0,
"metadata": {
- "cellView": "code",
+ "id": "ngUe237Wt48W",
+ "colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
},
- "colab_type": "code",
- "id": "ngUe237Wt48W"
+ "cellView": "code"
},
- "outputs": [],
+ "cell_type": "code",
"source": [
"print(tf.add(1, 2))\n",
"print(tf.add([1, 2], [3, 4]))\n",
@@ -131,32 +145,32 @@
"# Most TensorFlow ops are directly usable with eager execution, giving\n",
"# results immediately.\n",
"print(tf.contrib.signal.hamming_window(x * y + 1))"
- ]
+ ],
+ "execution_count": 0,
+ "outputs": []
},
{
- "cell_type": "markdown",
"metadata": {
- "colab_type": "text",
- "id": "IDY4WsYRhP81"
+ "id": "IDY4WsYRhP81",
+ "colab_type": "text"
},
+ "cell_type": "markdown",
"source": [
"Numpy arrays are supported, too:"
]
},
{
- "cell_type": "code",
- "execution_count": 0,
"metadata": {
+ "id": "lCUWzso6mbqR",
+ "colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
- },
- "colab_type": "code",
- "id": "lCUWzso6mbqR"
+ }
},
- "outputs": [],
+ "cell_type": "code",
"source": [
"import numpy as np\n",
"\n",
@@ -168,14 +182,16 @@
"\n",
"print(\"Multiplied by 42:\")\n",
"print(tf.multiply(ones, 42))"
- ]
+ ],
+ "execution_count": 0,
+ "outputs": []
},
{
- "cell_type": "markdown",
"metadata": {
- "colab_type": "text",
- "id": "PBNP8yTRfu_X"
+ "id": "PBNP8yTRfu_X",
+ "colab_type": "text"
},
+ "cell_type": "markdown",
"source": [
"# Step 4: Define and Print TensorFlow Variables\n",
"\n",
@@ -183,73 +199,66 @@
]
},
{
- "cell_type": "code",
- "execution_count": 0,
"metadata": {
- "cellView": "code",
+ "id": "3Twf_Rw-gQFM",
+ "colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
},
- "colab_type": "code",
- "id": "3Twf_Rw-gQFM"
+ "cellView": "code"
},
- "outputs": [],
+ "cell_type": "code",
"source": [
- "x = tf.get_variable(name=\"x\", shape=[], dtype=tf.float32, initializer=tf.zeros_initializer)"
- ]
+ "x = tfe.Variable(0.)"
+ ],
+ "execution_count": 0,
+ "outputs": []
},
{
- "cell_type": "markdown",
"metadata": {
- "colab_type": "text",
- "id": "45G7094TxsMb"
+ "id": "45G7094TxsMb",
+ "colab_type": "text"
},
+ "cell_type": "markdown",
"source": [
"## Printing TensorFlow Variables"
]
},
{
- "cell_type": "code",
- "execution_count": 0,
"metadata": {
- "cellView": "code",
+ "id": "UJBJeZ5XxuwA",
+ "colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
},
- "colab_type": "code",
- "id": "UJBJeZ5XxuwA"
+ "cellView": "code"
},
- "outputs": [],
+ "cell_type": "code",
"source": [
"# This does NOT print the Variable's actual value:\n",
"print(\"Printing a TensorFlow Variable:\")\n",
"print(x)\n",
"print(\"\")\n",
"\n",
- "# A TensorFlow variable represents a reference to a tensor.\n",
- "# The `read_value()` method provides access to the current value of the\n",
- "# variable. Tensorflow Variables are automatically initialized according to the\n",
- "# semantics defined in tf.get_variable().\n",
- "print(\"Printing a TensorFlow Variable's value using .read_value():\")\n",
- "print(x.read_value())\n",
- "print(\"\")\n",
"\n",
- "print(\"Printing a TensorFlow Variable's value using .read_value().numpy():\")\n",
- "print(x.read_value().numpy())"
- ]
+ "print(\"Printing a TensorFlow Variable's value as a numpy array:\")\n",
+ "print(x.numpy())"
+ ],
+ "execution_count": 0,
+ "outputs": []
},
{
- "cell_type": "markdown",
"metadata": {
- "colab_type": "text",
- "id": "2njjWHcTpBEn"
+ "id": "2njjWHcTpBEn",
+ "colab_type": "text"
},
+ "cell_type": "markdown",
"source": [
"## Changing a TensorFlow Variable's value\n",
"\n",
@@ -257,64 +266,64 @@
]
},
{
- "cell_type": "code",
- "execution_count": 0,
"metadata": {
+ "id": "v3wr6Erbo_hB",
+ "colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
- },
- "colab_type": "code",
- "id": "v3wr6Erbo_hB"
+ }
},
- "outputs": [],
+ "cell_type": "code",
"source": [
"x.assign(42)\n",
- "print(x.read_value())\n",
+ "print(x)\n",
"\n",
"x.assign_add(3)\n",
- "print(x.read_value())"
- ]
+ "print(x)"
+ ],
+ "execution_count": 0,
+ "outputs": []
},
{
- "cell_type": "markdown",
"metadata": {
- "colab_type": "text",
- "id": "uhtynjHVpTB5"
+ "id": "uhtynjHVpTB5",
+ "colab_type": "text"
},
+ "cell_type": "markdown",
"source": [
"## Use a Variable just like any other Tensor"
]
},
{
- "cell_type": "code",
- "execution_count": 0,
"metadata": {
+ "id": "7PbktdnHoehR",
+ "colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
- },
- "colab_type": "code",
- "id": "7PbktdnHoehR"
+ }
},
- "outputs": [],
+ "cell_type": "code",
"source": [
"print(x + 3)\n",
"\n",
"# This code will broadcast the value across the list of numbers:\n",
"print(x * [1, 2, 4])"
- ]
+ ],
+ "execution_count": 0,
+ "outputs": []
},
{
- "cell_type": "markdown",
"metadata": {
- "colab_type": "text",
- "id": "GVChqwlwy1SI"
+ "id": "GVChqwlwy1SI",
+ "colab_type": "text"
},
+ "cell_type": "markdown",
"source": [
"# Step 5: Debug Errors with Instant Feedback\n",
"\n",
@@ -326,60 +335,58 @@
]
},
{
- "cell_type": "code",
- "execution_count": 0,
"metadata": {
- "cellView": "code",
+ "id": "23ap04N0v4k0",
+ "colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
},
- "colab_type": "code",
- "id": "23ap04N0v4k0"
+ "cellView": "code"
},
- "outputs": [],
+ "cell_type": "code",
"source": [
"vector = tf.constant([10.0, 20.0, 30.0, 40.0])"
- ]
+ ],
+ "execution_count": 0,
+ "outputs": []
},
{
- "cell_type": "code",
- "execution_count": 0,
"metadata": {
- "cellView": "code",
+ "id": "FCUMsIYxxRRa",
+ "colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
},
- "colab_type": "code",
- "id": "FCUMsIYxxRRa"
+ "cellView": "code"
},
- "outputs": [],
+ "cell_type": "code",
"source": [
"# Works, because the values of `begin` and `size` (the 2nd and 3rd input\n",
"# arguments) are within the bound of `vector`.\n",
"print(tf.slice(vector, [1], [3]))"
- ]
+ ],
+ "execution_count": 0,
+ "outputs": []
},
{
- "cell_type": "code",
- "execution_count": 0,
"metadata": {
- "cellView": "code",
+ "id": "T8me2oCNxpFp",
+ "colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
},
- "colab_type": "code",
- "id": "T8me2oCNxpFp"
+ "cellView": "code"
},
- "outputs": [],
+ "cell_type": "code",
"source": [
"# The following does NOT work, because the value of `size` (the 3rd\n",
"# argument) causes the indices to go out of the bounds of `vector`. The\n",
@@ -388,87 +395,86 @@
" print(tf.slice(vector, [1], [4]))\n",
"except tf.OpError as e:\n",
" print(\"Caught error: %s\" % e)"
- ]
+ ],
+ "execution_count": 0,
+ "outputs": []
},
{
- "cell_type": "markdown",
"metadata": {
- "colab_type": "text",
- "id": "irxJhAgar84v"
+ "id": "irxJhAgar84v",
+ "colab_type": "text"
},
+ "cell_type": "markdown",
"source": [
"# Step 6: Using the GPU\n",
"\n",
- "You can place Tensors on the GPU by calling a Tensor's `.gpu()` method.\n",
+ "You can explicitly place Tensors on the GPU by calling a Tensor's `.gpu()` method. The `.device` property tells you whether the Tensor is backed by CPU or GPU memory.\n",
"\n",
"The first operation executing on the GPU may be slow as TensorFlow initializes. Subsequent uses will be much faster."
]
},
{
- "cell_type": "code",
- "execution_count": 0,
"metadata": {
+ "id": "7J4N9baqaKCL",
+ "colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
- },
- "colab_type": "code",
- "id": "7J4N9baqaKCL"
+ }
},
- "outputs": [],
+ "cell_type": "code",
"source": [
- "# The example code from here on will work only if your notebook\n",
- "# is running on a machine with a functional CUDA GPU. The following\n",
- "# line checks that.\n",
- "is_gpu_available = tfe.num_gpus() \u003e 0\n",
- "\n",
"# Create some Tensors\n",
"SIZE = 1000\n",
- "cpu_tensor = tf.random_normal([SIZE, SIZE])\n",
+ "tensor = tf.random_normal([SIZE, SIZE])\n",
+ "print(tensor.device)\n",
"\n",
- "if is_gpu_available:\n",
- " gpu_tensor = cpu_tensor.gpu()\n",
+ "\n",
+ "if tf.test.is_gpu_available():\n",
+ " gpu_tensor = tensor.gpu()\n",
+ " cpu_tensor = tensor.cpu()\n",
"else:\n",
- " print(\"GPU not available.\")"
- ]
+ " print(\"GPU not available.\")\n",
+ " cpu_tensor = tensor"
+ ],
+ "execution_count": 0,
+ "outputs": []
},
{
- "cell_type": "code",
- "execution_count": 0,
"metadata": {
+ "id": "4E-2n7VbzY1n",
+ "colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
- },
- "colab_type": "code",
- "id": "4E-2n7VbzY1n"
+ }
},
- "outputs": [],
+ "cell_type": "code",
"source": [
"# Time a CPU-based matrix multiplication\n",
"\n",
"print(\"Time to conduct matmul on CPU:\")\n",
"%time tf.matmul(cpu_tensor, cpu_tensor)"
- ]
+ ],
+ "execution_count": 0,
+ "outputs": []
},
{
- "cell_type": "code",
- "execution_count": 0,
"metadata": {
+ "id": "vbSFW-T5zhZF",
+ "colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
- },
- "colab_type": "code",
- "id": "vbSFW-T5zhZF"
+ }
},
- "outputs": [],
+ "cell_type": "code",
"source": [
"# Time GPU-based matrix multiplications.\n",
"\n",
@@ -481,51 +487,9 @@
" # Subsequent uses are much faster:\n",
" print(\"Time to conduct second matmul on GPU:\")\n",
" %time tf.matmul(gpu_tensor, gpu_tensor)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 0,
- "metadata": {
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- },
- "colab_type": "code",
- "id": "E5pIOe3Rz7iW"
- },
- "outputs": [],
- "source": [
- "# Second timing demo for GPUs, after it has been used once:\n",
- "\n",
- "cpu_tensor = tf.random_normal([SIZE, SIZE])\n",
- "print(\"Time to conduct CPU matmul:\")\n",
- "%time tf.matmul(cpu_tensor, cpu_tensor)\n",
- "print()\n",
- "\n",
- "if is_gpu_available:\n",
- " gpu_tensor = cpu_tensor.gpu()\n",
- " print(\"Time to conduct GPU matmul:\")\n",
- " %time tf.matmul(gpu_tensor, gpu_tensor)"
- ]
- }
- ],
- "metadata": {
- "colab": {
- "default_view": {},
- "name": "Eager Execution Tutorial: Basics",
- "provenance": [
- {
- "file_id": "0B0kLcpwLFwKEVm9XNkFueGk4bTg",
- "timestamp": 1504118841551
- }
],
- "version": "0.3.2",
- "views": {}
+ "execution_count": 0,
+ "outputs": []
}
- },
- "nbformat": 4,
- "nbformat_minor": 0
-}
+ ]
+} \ No newline at end of file
diff --git a/tensorflow/contrib/eager/python/examples/notebooks/2_gradients.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/2_gradients.ipynb
index e6c7c11733..1e65b27bc8 100644
--- a/tensorflow/contrib/eager/python/examples/notebooks/2_gradients.ipynb
+++ b/tensorflow/contrib/eager/python/examples/notebooks/2_gradients.ipynb
@@ -43,11 +43,9 @@
"# Import TensorFlow.\n",
"import tensorflow as tf\n",
"\n",
- "# Import TensorFlow eager execution support (subject to future changes).\n",
- "import tensorflow.contrib.eager as tfe\n",
"\n",
"# Enable eager execution.\n",
- "tfe.enable_eager_execution()"
+ "tf.enable_eager_execution()"
]
},
{
@@ -106,7 +104,7 @@
},
{
"cell_type": "code",
- "execution_count": 3,
+ "execution_count": 0,
"metadata": {
"cellView": "code",
"colab": {
@@ -114,34 +112,30 @@
"startup": false,
"wait_interval": 0
},
- "height": 360,
- "output_extras": [
- {
- "item_id": 1
- }
- ]
+ "base_uri": "https://localhost:8080/",
+ "height": 347
},
"colab_type": "code",
"executionInfo": {
- "elapsed": 127,
+ "elapsed": 374,
"status": "ok",
- "timestamp": 1505502830690,
+ "timestamp": 1525154227149,
"user": {
"displayName": "",
"photoUrl": "",
"userId": ""
},
- "user_tz": 240
+ "user_tz": 420
},
"id": "O4lsC4ckAcar",
- "outputId": "2f760690-cafb-4777-b970-91d839f99faf"
+ "outputId": "f8becb3f-498b-4cb7-9ef3-608a68cb65d0"
},
"outputs": [
{
"data": {
- "image/png": "iVBORw0KGgoAAAANSUhEUgAAAesAAAFXCAYAAACC+2avAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzsnXt8VPWd99+TK7kykxtJQIebqZfaqogtrhKNa1ooEKl9\nCrpVn9ZNW6x9VWsbCi7aVUt01NZ9tq21KVZlFey2YkQNohhj3QWK2liCF5RIBCc3yEwmIZnMTOY8\nf/zmzJwzSSBAYibh+369eIU5c87vXLh8zvdu0TRNQxAEQRCEmCVurC9AEARBEISjI2ItCIIgCDGO\niLUgCIIgxDgi1oIgCIIQ44hYC4IgCEKMI2ItCIIgCDHOiIj16tWrufjii1m8eHF4269//Wvmz5/P\n0qVLWbp0Ka+//vpInEoQBEEQTjksI1Fn/eabb5KWlkZFRQWbN28GlFinpaXx7W9/+6QvUhAEQRBO\nZUbEsr7wwgvJzMwcsF36rQiCIAjCyTOqMesnn3ySsrIybr/9drq6ukbzVIIgCIIwYRk1sb722mt5\n5ZVXqK6uJicnh8rKytE6lSAIgiBMaEZNrLOysrBYLAB885vfZPfu3cc8RtzmgiAIgjCQhJFaKFpo\n29vbyc3NBeDll1+mqKjomGtYLBba2yeuuzw3N0Pubxwzke9vIt8byP2Nd06F+zsWIyLWt912Gzt3\n7sTtdnPZZZfxwx/+kJ07d/Lee+8RFxfH1KlTueuuu0biVIIgCIJwyjEiYv3ggw8O2Hb11VePxNKC\nIAiCcMojHcwEQRAEIcYRsRYEQRCEGEfEWhAEQRBiHBFrQRAEQYhxRKwFQRAEIcYRsRYEQRCEGEfE\nWhAEQRBiHBFrQRAEQYhxRKwFQRAEIcYRsRYEQRCEGEfEWhAEQRBiHBFrQRAEQYhxRKwFQRAEIcYR\nsRYEQRCEGEfEWhAEQRBiHBFrQRAEQYhxRKwFQRAEIcYRsRYEQRCEGCdhrC9AEARBOHXo6HCzcmUt\nTU2Z2O2dOBwl2GzWsb6smEfEWhAEQfjMWLmylurq6wAL9fUasJ6qqqVjfVkxj7jBBUEQhM+MpqZM\nwBL6ZAl9Fo6FiLUgCILwmWG3dwJa6JOG3e4Zy8sZN4gbXBAEQfjMcDhKgPWhmLUHh+Pysb6kcYGI\ntSAIgvCZYbNZJUZ9AogbXBAEQRBiHBFrQRAEQYhxRKwFQRAEIcYRsRYEQRCEGEfEWhAEQRBiHBFr\nQRAEQYhxRKwFQRAEIcYRsRYEQRCEGEfEWhAEQRBiHBFrQRAEQYhxRKwFQRAEIcYRsRYEQRCEGEfE\nWhAEQRBiHBFrQRAEQYhxRKwFQRAEIcYRsRYEQRCEGEfEWhAEQRBiHBFrQRAEQYhxRKwFQRAEIcYR\nsRYEQRCEGEfEWhAEQRBiHBFrQRAEQYhxRKwFQRAEIcYRsRYEQRCEGEfEWhAEQRBiHBFrQRAEQYhx\nRKwFQRAEIcYRsRYEQRCEGCdhrC9AEARBODE6OtysXFmL02mjsLADh6MEm806rGOamjKx2zuHdYww\n9oyIWK9evZrXXnuN7OxsNm/eDEBnZye33norn376KdOmTeOhhx4iIyNjJE4nCIIgACtX1lJdfR1g\nATRgPVVVS037RIuzz9dDTc33AQv19YMfI8QeI+IG//rXv866detM237/+98zb948XnrpJb70pS/x\nyCOPjMSpBEEQhBBNTZkooQawhD6b0QW9vv4qqquvZ/v27mMeI8QeIyLWF154IZmZ5j/wbdu2sXSp\neltbunQpr7zyykicShAEQQhht3eiLGoADbvdM2CfaEGH7GMeI8Qeoxaz7ujoICcnB4Dc3FxcLtdo\nnUoQBOGUxOEoAdaHYtYuHI7LAbPru61tD1AM2AAXkyY5sVr/CBxi3rwMHI5FY3cDwrCJuQSz3NyJ\nHdeW+xvfTOT7m8j3BhPz/uLi+klOTgQgOTmBnJwMsrIyuPnm5w2x7DKmTbuPgoJzaG7ew8GDt6PH\nuDMyNpKdncFNNz3Pxx+nM2NGFw8/vJCsrNhLOJuIf37Hw6iJdXZ2NocOHSInJ4f29naysrKGdVx7\ne9doXdKYk5ubIfc3jpnI9zeR7w0m7v2Vlz8XFuVduzT6+lSy2N69KRhd3zk5Z/LCC5dRWtrPwYOR\n7Xv3pnDjjYOvEUtM1D8/neG8iIxYnbWmaabPJSUlPPPMMwBs2rSJK664YqROJQiCIDB0gtlQsezB\ntg8nSU0Ye0bEsr7tttvYuXMnbrebyy67jB/+8Id897vf5Uc/+hF/+ctfKCws5D/+4z9G4lSCIAhC\nCLu9M1R+pdzauijrsWxVruUJx7JXrZrDrl2VuFzTsNkOsnr1EtaufWvQNYTYYkTE+sEHHxx0+2OP\nPTYSywuCIAiDMFSCmc1mHdSVXVn5Nk7nKsBCb6/G2rXrhxR2IbaIuQQzQRAEYXjoojxYTHewTmWD\nubyHEnYhthCxFgRBmIAYu5vpncrsdk1c3uMUEWtBEIQYYai+3SfSz3swK/rpp+cgLu/xiYi1IAhC\njDCYNVxVtXTI7UdjsOQzcXmPX0SsBUEQYoShyqhOpLxKEscmFiLWgiAIMcJQpVjm7S7a2t6ltJSw\nS3ywphojYUXLOM3YQcRaEAQhRhjKGjZub2t7F6dzFU6nconX1T1AaelU7r770mEL6XBF+ETc78Lo\nIGItCIIQIwxlDRu3l5aC0xlxibvdZ/KnPy06rjahwxVh6W4WO4xYu1FBEARh9IluGQpqPvXxCOlw\nRXg4IziFzwaxrAVBEMYRuku8ttaPx5MCLAQ0CgoODXuNoWLjQ51LktTGHhFrQRCEcYTuEr/hhv+i\npiYBeBY4xNtvu3G53ANiz4PFp4crwlLqFTuIWAuCIIxDmpsLgF5gOWChtVWjomJg7Hmo+LSI8PhC\nxFoQBGEcEG0hFxT4qK+fwrFiz5IkNjGQBDNBEITPiI4ON+Xlmygt3UZ5+TO4XO5hf69byPX1V1Fd\nfT0QoLBwN8dKAJMksYmBWNaCIAifEdEu6V27KqmtvS4cZz5aSVW0hdzcXEBt7SIqKgaOyDQiSWIT\nAxFrQRCEz4howXU6P09FRe2Qgmx0WRcUNFNf/xSQAXgoKPAcdUSmznCTxKRbWWwjbnBBEIRBOJbL\n+kQwu6RdwLts3Up4/aO7rBOBa4DFwLWhzyNHtJu9oqJ2RNcXTg6xrAVBEAZhNFptOhwl7NpVidP5\neeBdYCW9vRaqq9X6DkcJfX3r2LEjDjiMz5cWLsdqbs7B7AbPOalriUYS0WIbsawFQRAG4XjFaziW\nuM1mpbb2OsrK3KSkFA5Y32azkpychNv9bdzun1JTsyJs4UZb3QUFLeHzLVv21Elb/pKIFtuIZS0I\ngjAIw+3ypROxxDupr3+RurqXKS6OHxD71WPI5eXPhCxq8/pDvSREJ4r5fAkmy/94eoMb0WPV+/Yl\nUFhYSXZ2ETNn9kgiWowhYi0IgjAIx5tFHRHZGmABbvcWqqvT2LXrCWprrx+QrOVwlODzPcL27V1A\nNj5ffzhuPdhLQnSiWGnpNoZr+R8teczo7geNuXNlslYsImItCIIwCMfbajMisunAFvTOYk7n4kE7\ni9lsVpKSUnG7vwdYqKnRSEpaf9SXBKPotrXtAcoYjuV/PCVhEquOTUSsBUEQBmEoa3So7brI1tW1\n4HafyXAEcDChPNpLgtkKLqawsJK8vLMpKurl7ruHtvyPJsjH6+4XxgYRa0EQhEEYyhodarsusi6X\nm8svfwKnczFDCaAu+Pv3t6CSuoYnlGbRtZGXdzZbt15Bbm4GH3xwgPLyTYO6uo8myNI0ZXwgYi0I\nQswylo06IsLoBmrC9dD79iUQbaV2dLi55ZaXQiVXh5gzJ5kvfnEdzc052O0eVq26wCSkPl8PNTXf\nBzqBDVitXoqLE44plEcT3aO5uo8myDJZa3wgYi0IQswyGrXOwyXSMcwJ3Bauhy4srCTaGl65spYt\nW24Mb9u2bQNlZQG2br0CgPLyTab7sFofCO1rBa5l+vRnqaq6Ilz+NdTLiS66+/bF09HRRGNjEeXl\nz/Doo2VHdXWLII9/RKwFQYhZxjb5Se8Y9rzpGrKzi5g7V1mpBQUt+HwJvPZaErABWIgS4AyamvrD\nK0XfB2SjOphtAdJoa9uDyzXnmC8nuuhef/3TNDSswum0sHu3xo03PoHdjsSeJzAi1oIgxCxjmfwU\n6RjWhdGSnjmzJyygRotZ7fMgUAD00NbWTmkphnGWkTXmzQvyzjsP43SuwpgxPtyXE+Vuj+xXV6ex\nY8cVSOx54iJiLQhCzDKWyU+RF4WFDBVXHmgxfw5YRHLy7TidP8XptFFfr7Fgwe8oKzPex1dYtuwt\nnM7IsVu3gs023HKsQxhfIOCQuLonOCLWgiDELCMtQEdLWIv+bvXqOUReFAI4HFcOSG6LtvyhO/T7\n01Eu7nSgiwMHMnn11SVHPba3N5He3psoLKwkK6uIjo697Ntnp7z8mQGx63nz0qmp2YCawNXF/Pky\nHWuiI2ItCMIpw2Ax4fvuu5yVK2upqwvgdicDl1FfP5nBktmGEvTaWj8eTwrKCteAg4BqdgIaHR2V\nA65F9xps3Qq9vX7Uf8dv0NOTwFlnfUJDw3SczgwaGjz4fM/z+OPfCl8DJGG1eoGDzJuXwaOPXkN/\n/4BTCBMIEWtBEE4ZBosJR7fbhI3ANYPGi4dKALvhhv+ipkYD/gDk4PMlocqyAGpwuQoHWMjmHuEp\nqGQ2C273Il5//U7g1vA1bd/+AKCEuqRkfTjWDarrWVaWdch51sLEQMRaEIRThsES1gbGndOJjhfr\nFvXWrRj27WTz5oMUFf03gcCh0PbbAQuapqGywy3ActMYzGhr3eEooa7uZdzuyDX090+PuqZsQL0s\nqPGa0h70VEPEWhCEU4bBEtYqKl41CbjV+j7FxS5TIlnEot5AJLHrRYLBVSGR1YDHMQusDzWF+OjC\narNZ+fKX+9myJXINOTkHaWszZ4+D7hno5ni6ngkTAxFrQRBOGWw2azhG3dSUSUXFq1GJZB4cjuUD\nEsn27YtHucctwL1YLFPQNLMQQzvmDG0L8Klp2/vvv0lJyRFmzQqYXOIWSwD1IqASxs49N430dHP2\nOOiegSWha0mjsLABh+O6UXteQuwgYi0IQkwQnby1atUcKivfHvFWoyfSFa2jowmIxImTk9fg9Z6F\nWZyt6CIKO1BW9W3AfcDZwBG83ttoaNhCQ8P1pvM2NxcAV4XPd/jws2zYcMWA61Cegc2hZ+LG4bgO\nTYNlyzawd2/KZ96SVfjsELEWBCEmiBbRXbsqw4lUx9NqdLDyrNzcjPD3J9IVbfLkqTidG9FLsU47\nrZDZsz1s3/4A3d2ZBAKTQmumAW8CPwXeAGzAOcBiw2rpA8473OYvg5WyRbcy/SxbsgqfHSLWgiDE\nBNEi6nJN40QSqQaznJ999vrw98cSxsHEvrPzU4yW9ZEjlfzqV9excmUtjY2pHD78AR5PBt3de1Ad\nzGxEOp8ZO6C5gJ1AO3v2OLnhBicPPbT4pJq/yDzqUwMRa0EQYoJoEbXZDtLbG/mcn38oPOSioKAZ\nSAxNtTK7fo8lXqtWzWHnzntoa8sjPv4Q3d3puFzu8PGDiX12dpGp21h2dtGAkq/k5DXARcAelCir\nzmeZmR34fHfg9Z4P7AXuBiz4/VqosckLJCWlnrC7X+ZRnxqIWAuCMGocz4jLaOty9eolrF0b+ezz\n+amuVpOt1DSsaxisucn+/QHgSeBrwOQB4lVZ+TYtLbOAawgGLWzbplFREXEdDyb2M2d2snu3uT94\n9H59fRcBS1Au7/tISSmktBQcjjKWLXuL+vqrgM2mYyCD7ds/xe3+HsNxYw/2PB2OEpKTN4Zi1tIT\nfKIiYi0IwqhxPMlcg8Vjq6rs4d+Xlm4jInQZGEVv61bYtesJnM6bUC5oNYayuHjKAPFSmd1O1DSt\nLmAhdXUBGhubqKx8m/37W4gujVq1ag67dlXick3DZjvA6tVl3HnndswJZkfC1wNTuOyyI1RVqa5j\nEeu3K+qYLlQN9fDc2EM9z6efvkaaokxwRKwFQRg1oq3PfftSB8xr1jSGZX2b3b0ejKKn+mqvRu8+\nBhamTz+DqqqBGdXRmd2wAbd7Epde+if8/n9HdR4zD+6oqKgNJ7v19mosXVpJd7debuUDmoHvh86g\nAclApP+ncQ71oUP30NMzlbi4w8yblw4khLqfHduNLfHpUxcRa0EQRo3oeGpHx14aGswZ3sCg1uJQ\nfbhVL20f8ERo3URgAZFsbDia6EXHn5XYXoXfHwh9tqLizVU0NZ1BRcWrNDamYRTJSBexxSjX9lVA\nDSrT+wPgX2lufi18zqMNJHG53CQlDS+5TOLTpy4i1oIgjBrRceh9++wGoeykrq6Vvr4pKAt1IWAN\nW4tDuXxVL20Vu1ax6eXo4lVY2EBeXvCoojdz5pFQ/LkTeDG09QXgI4zdydzun1Bfr85dWLiWgS5v\njYgrezLqheFFIAd4gYKCwYV0sLjzcEutxnJkqDC2iFgLgjBqRFuU5eXP0NBgFkTzAI3lYWtxKJev\nUbCUIK4LZYV7cDiuO2YmtX58bW0LHs9PDedfh8redtHb68bvj8S0s7KmM3euOmdb27s4nStQrvh7\ngQySklbS359Mf/9cVDvQBcBfBj3/8cTxT0bYhYmFiLUgCJ8ZRqHdv99rGl6RkuKntHR92FocyuUb\n/QJgFLSKilePWfqkH19auo36eqM73A/00tX1KZr2C4wx7Vmz+sOu+VtvPURPzyaOHPkYv//HgA2f\nL5Kdrr94NDfnDCq2xxN3PpFua8LERMRaEITPDKPQKnd2RIxLSzEJ0VAu32gB7Orq5NVXf4guaD7f\nOh5/fNmAc+vH7dsXT0dHE93d8Zhd25OBa9G05zCKqdXqxeG4ElDiWVNzI2ZvwDVEZ6dDGna7e1Cx\ntdu1YcedJaFM0BGxFgRhTDhW/HWopKxoAUxMXItR0LZvjxtwzOHDxjnQG1HZ4CrrOzPTS3d3C8Hg\nTaG9zVOtiosThmy4EkloM2enT5q0i9Wrl/G9731EtNg+/XT04BBJKBOOjYi1IAhjwtEypI9GtGD2\n9+dgtpAPDzjmpptqQhncnahJWJF4dFzcM+Tnazidk0N7LwDuwGqdQXFxAqtWXRAuN2tr2wMUo2q5\nXUyatAuLxU1m5l407S7a2s5HDez4MZdc8is0bSrwGCpbXDVoOZ77loQyQUfEWhCEMWG43c2i9yso\n8Jmszby8Vlpa9PGSrfT2urDbN2GzHWDTpjJmzLDz8cdqAIfK1r4NYzwaDvPHP17OkiVr6OubgcXy\nMZdcks4f/nAlmgaXXfY4LS0/ALYA55KYuJaUlCx6erLwes8EvkZv72Ss1gdQHcwUfv+Foc9DN2g5\nFif6QiNMPESsBUEYE4abPBW934IFv+OKKx6hrs5CMHiY/v4errjiEIcPp/L++014vSo5rLfXRXHx\nL5k9+4vs21cPfA7owezG9jFvXjq//e1H9PWpnt2appGVtR5Ng5KS9bS0fAEl1KpEzO/vxu83J5Op\nuHU2Q3U0G6pBiyAMFxFrQRDGhOEmT0Xv19xcQFvbuwQCqrlKe7vGe+9VUl9/RSimq++7Ba/3Lhoa\nLMDVKFFNxyioU6Y0AqezdStE13qvXFkbcp13o4+1VEQnk6k1580LkpS0nrq6AG53K8aOZhJrFk6W\nURfrkpIS0tPTiYuLIyEhgT//+c+jfUpBEMaI4xncEZ08ZZyqZTx2sCSrDz4wj89U4zTBZjsQmtTV\nCfQxUFQvJTPzfk4/fSYdHXvp7k4JZXfrDVKeBRIpKPDQ1FRApGb6d6huZQNbnVqt71Nc7MLh+Ao2\nmxWXy80ttzzP9u1/ALKZNy+Iw/GVkXvIwinJqIu1xWJh/fr1TJ48+dg7C4IwrhnKtR0t4qtWzaG7\n20Ni4lr6+3PIzm7i7bcTaWubA3RTX78E2ExV1VJWrDiDmprb8fnsQCtvvHGE9HSLaXympn1ISclL\n+P3dJCSsIRBIAaZjdkt3A5NJTw8wa1ZPqO3p86HvazDXSa8LvSQsAZ4DJmOxrCEjYzpz5/aQlGRs\nxLLc9EJis1l5/PFvfRaPWziFGHWx1jSNYDA42qcRBCEGGMq1HS3iu3ZV4nTmoeK8GbS3dwA/wxgH\nbmrKZN++JhYufJ5gMNKk5PDhDcTH/4NJk9agaTPw+/fh9f6UhgYbEXd3MlBCxPX9D1QG94N0d/tp\nbEwNradPwUrH7GrPCZVYbWbfvgQ6OtxkZ5/HzJlHcDiWhsW5o8NNRcXwPAmCcDJ8Jpb1jTfeiMVi\nYdmyZXzzm98c7VMKgjBGDFUXHC3iym3dBhgbjAxsKnL11c8RDH4u6rsM+vtn0t9fTmFhJU7nl1FC\nrH/vB3YDS1HWsga8AawGLHg8GocP672+F6Ji1fuARabr1jOxy8s30dCwCqfTEuopHvEWqNptFdeu\nr19CX99fSE5OEvEWRpxRF+uNGzeSm5tLR0cH3/72t5k5cyYXXnjhaJ9WEIQxYKi64GgRV7HlwtBn\nN7AntIKKERcWNrBq1RIuvvgg8CEDZ0C3As/jdPaihHmx4ftEYAZKhDOIWM+6ld1FZmYuUGmYnnUd\ncB9wNoWFDTgc14Xv6WjeAn1spr7+jh1xuN3SHlQYeUZdrHNzcwHIysriyiuvZPfu3UcV69zcjNG+\npDFF7m98M5Hv70Tv7fBhNzfdVMPHH6czY0YXjz66hKwsszX56KNlrFixMbRPN2vXfotLLnmMlhYN\nFS+OuMCnTbuPd965iRUraggGVwGfALejYtDtKCs6H7gUaADsqIEaBSj394LQmkYSME7n6u6+j6lT\nz8XpXBzeIzW1kEWLjvDwwzeZrr+oqMf0olFU1EtubgZOp41ob4DF8qlpm9Np+8z+zkzkv5sw8e/v\nWIyqWPf29hIMBklLS6Onp4c33niDm2+++ajHtLd3jeYljSm5uRlyf+OYiXx/J3Nv5eXPhePRu3Zp\n9PUNZk3G8+tfLzJtOf/8PGpqNgD6HGkACzk5Z9LfH8/evSmh7XagAvgD8AVU/PkHRIu8Emz98/6o\n78wtSW222WRkfAw8hbK+D5OW9iF7987lO9+pNrmvf/zjL/DGG5W4XNOw2Q5w221ltLd3UVjYgdHi\nLyxs4ItftFJTY9zm+kz+zkzkv5twatzfsRhVsT506BA333wzFouF/v5+Fi9ezCWXXDKapxQEYYQY\nbhnWiQ6baG4uQLXhfAyj6L3zzm7OO28PZ51lrImeDEwFFpGfX09Ly2Sik8JU0xPlyk5IsBIIGL/L\nMp1j5swedu7sBH4Y3tbevoH29qvCCXB5eWdjt3fi8/nD7u7eXo21a9dTVWUfxOWvXOdJSdIeVBh5\nRlWsTzvtNKqrq0fzFIIgjBLD7TA2VFLZYOValZVvG9qGHgkd14MxvqxpBTidNxIM3kNZ2XoaG1Np\nb3+Pnh6NuLg/cs45mZx//jq2b+/A7Y4khcFeVCMSK3APA/uFb8Bq9VJcnIDDcTnnnVdLdOKa/nun\n8/M4nUuor9ewWv/IYC8jQ7UClRi1MBpIBzNBEAZluBbzUEllt9zyElu2qGzv+nqNF164g0DgrvDn\nBQvWsWDBOmpqfIbzgJpkZaGz024Yp9kTfnHYts1Fbu6DdHUB3IPFkk1c3Kf09/8EJdQagUAPkYSy\nbiwWK0uWBHA4rgx7B1SSmwvVSjQNleR2KcqKj7QKhUMYhV+6kQljwcBZcoIgCCiLWYkUgMb+/R9S\nXv4MLpfbtJ/NZuW++y7Hbvewb18ql1/+BCUlz7FtWwuqMxiAhUDAjlH8X3stgaSkROLjm4GvEmnr\n+S7gQtP2hs9lfnF4hvb2NPr7LwJmoWnX0N9fBNRgtT5KYWElKhltOSpLfDmTJ3cDsGzZW+F72LSp\njEmTfhnabwnwMzIz/5NJk+5AJao9BbiYNy+DsrL1nHfes5SVrT8h13ZHh5vy8k2Ulm4b9BkKwrEQ\ny1oQhEFxOEro61vHK68ECAS6cLu7qK6eyvbtf+Svf/22KX5tdJmDhtO5EZXBvQG4FiX6jRgt1N7e\nZKqrl2Ox3INxUIYS2Pvwem+jokJ1MTO72l1EN1BRMenFTJv2JJ98YkH913Ynykp2091dSHV1PHBZ\nKCb9MHl5ZzNp0gy83sgLRFzcJLzefwuvXVhYyUMPXXfStdLDDSkIwlCIWAuCMCg2m5Xk5CQCAWPj\nko20ta3hllseISkpNRx/bmxMY2AfbjXVCjaj6qKTgcdR86SnAN9ATbmaiu76jhxfAPyOLVuslJc/\nw+rVc9Bd7Q0NGVHJY16U21qjo6MJj2cFSvwvBHYBd4X214UdnE7V5ASexBzbzjZdR17e2SPS1ORE\nk/AEQUfEWhCEIYkWGV2Et2/vwu3+HrqlOGXKHShhzkANuuhFiV8Lqq92I5oWaRmqLG5QrmY/8L+Y\nG5skAT+jr+9RqqsTqav7G8XF8Tz99Bxuuul5tm0zCmwLaWk9/PM/r6exsQin02ilw8Dr7yISz+4h\nM/NeZs48C7vdg8/Xbyq9Gqn49FBJeIIwXESsBeEURs/YdjptFBZ2DCjPihYZFVceaIEePpyAcRBG\nQsJdBAIbUNnZk5k82YXbbRRNN/BrlKtcubaTk9fg988kGExBNTbRXd7fwe22UF2t3MdJSQA/B+ag\nLOrvEx9fFWoN+gy7dycSEWM9acyGPiHL6+3E6707fK3p6ZVs3apmTbtc7lEpvRoqCU8QhouItSCc\nwkTHmqNjqQ5HCUeOPMJrr2kEAk7i4rK45JKH2Lv3CGoaVTdwMYFAPkbxPuusc5g5s4emptcGtVhV\n1na84RgbfX3TgY+BuahxlQuAHAa6jzNRLvUl4evs6ckIX++WLY/Q16eL8SLgDmy2WcyfH4fDsZxv\nfGMnu3dH1szOLgqvM1Q51skyWusKpw4i1oJwCjBUg5NjxVJtNitPPfUvpm3l5ZtoabkFXXgtltvR\ntHOIxH5d7NnzNh99VITNtodHHilD0+CVV+7E75+NillfS2Lievx+o4A3Ar8wrZuRkYHHE+0+1qiv\nbzadr7+8jtPXAAAgAElEQVT/AKWl27DbO5k58wzee89oxV/A7NkJVFVdBsDMmUdCAzkiDVIEIdYR\nsRaEcc5wOo0NlY18tFjqcAVe07JQFnEVqnd3F8FgJb29quPX0qWVzJ07Db//34kI8wbmz8/kf/7n\nDrzeuSh39udN606ePJudO6/kllseYfv2LiAbn6+fn/98Hps3dxAMRlzdmvYL6uvVveXn/wJz0th7\nfPRRIeXlz+BwlIhLWhiXiFgLwjgnWojr6h6guDjPJNpDWdC6cKmYtcskXMMVeNU0pNLw+bemc7lc\nBQPOHxfnYc8eNxbLLJQrfSHK9R1Zd9IkJwBJSanhZLaaGo2kpPV85Svp1NQsN5wzsnZPTz6RjmgN\nwApcLls45l1VtXTYLunhtlwVhNFGxFoQxjnRQuh2n0l19SKM8eehLGg9ljrYoITIum6ghq1bCZdR\n9fWtY8eOOI4c2Y/ff6bp/Kq1Z+RcmtaI3T47dP5O4EWCwSAtLbOAr6FqoTcCC4iLu51g8MvAEVpa\nfkBFxeZBXzSefnoO77xTidM5DeVWj2SS9/a2ohLXQL0IbEHPAt+3L/64BFjqo4VYQcRaEMY5g2ds\nm+PPJ+L6LShopr7+KVRJViK9vUuorp7MSy+t4Z/+yYbb/R3gEeAAZrezF2OrT6/XRm1tVyi2nQXc\nZth3I3ANKSl9XHbZ0/zP/2Tg8fhQXczcvPiik/nzC03rt7Q0cMstzbhc01D/hX0/tE4asAO/f7ph\n//0YG6h89NEdfPnLTtzunzAcAZb6aCFWELEWhHGOLsR1dQHc7kkol7Kyns1WpMbTT885DjduIsZy\nLF1Yvd6LeP31N1EW60qUtfwEaiBHO6qLsdFFvQGP59rQPtEzoPuA5wgEGnj11UmGLO6rgY34/d+n\noeEOpky5k9bWTCCHlpZp1NT0oCzq7xPp7f0ukAt8E3gUVfaVazqf13s+Xm8iwxVgqY8WYgURa0EY\n5+iubJfLTUVFbbhcyuG4nIqKod24RiEvKurh7rsvNQl5c7O5bEpZyhpwhP5+O5GuY1ZUE5PridRG\n3wecg5o9nQc0AR+iyq6Mk7KSgCX4/Xpf8IENWDyeWSQnt2O2yO8AZgPVqBeEbuAW8vP/MzQ+MxX4\nDip2bbT6+1CW//AEWJLRhFhBxFoQJgjRtbwdHW7q6gIMZUVGx2O3blWJafooy/37WzAL3QcoUfwq\nmnY/0EYkVmxsF2pDCfWi0P7LUeJ9F8oK3xD6eRi4OXSMGo9pPp9qwKJp+4AZmIXc+HKgERdXyeLF\nm1m9+uusXbuerVuht9eC8jJsJDXVj9V6EKdzRegY87jM4T5TQRgrRKwFYYKycmUtbncyRgFsa3sX\nl2vOoCVYbvdsqqt7eeGFZwkEfoBqevI4CQkHiY930dcHyjLdiKaBGtDxBMpSNSd5KYu6m0gnsjwi\nVvi1obUnh36BalGqhFUJ//+GjrkTTZvOpEmfmu7DYslG0yLXnpmZHxbVqio75eXPhLK/rcByFi3a\nyN13XxdOWLPbzeMyBSHWEbEWhAmKEuPLiCR7fYDTuWKISVYaKuY7hUAgDlV+dROwhUDgCwQC/wvM\nAv7VsP8TKAs3A9iHxXIP8fF5BAKTQts04K+AhylTPqa11Xiut4F+LJZ7yMgoJDHxQw4f/hg4HdUi\nVG9peit9fRZaWlwUFlaSl3c2druH7m6LqT/4vHlB071Hu68ffngJ/f3xYiUL4xYRa0GYYOix6P37\nA8ALRMqjGgAL+/bFU16+iX37EigsrKS7Ow+PJxXoR8V6VwHPM3Bs5YOYXdFeIq7opSQn34HFkkIg\ncD1qulYkOe3ccx8hGFxDe7sVSEFlmF+IpnnxeBaQmPhbIuVWoCzvFoyu9by8s009vCsqjLHkr5ie\nQbT7OitrYGmaIIwnRKwFYYKgi3RdXWu4NElZqA8CU1GZ0y/S0dFEQ8Oq8PcLFqwjI8PCn/6Ui7KI\nLaj4cXTCVzbmmHIbcC9wJtCL16u3En0KJeSRY1991UJcXAoqSWwjymqPZJn7/YVRax9BxbUHTwST\nWLJwqiFiLQgThEjC2POYRfZzKMsYrFYv2dlFoVnO6vvm5hxefPEqtmypxOPJRAnkQuBhzHHoD4B7\nUOVQn6Cs9dNQ/43obvR7UWIcB/wB1VAlm2CwjWBwNsYs78j1pQE+EhPvxO+/ECXUXwX+jB7DLixs\nwOG4bljPYbCmJ7m5GcN/kIIQg4hYC8I4xihMjY3NKGs0Oqv6g9C2i0lNbaGpqdXwvYuWlgYuuiie\n1FQfHs8hIsM0koG1wNmo+dR+4N8M696DuQ57PxExVo1U4EbD9/eGfkZf35vArcyf/0fee68Bl6uA\nYPBBEhPzSEhwMW9eBg89dJ0pGexoXcgG6zr27LPXj+RjF4TPHBFrQRjHDBxxuQFlFW/AYulE0yaj\nksImAz/B6ZyDsnZ19/X7tLTcTkuLGiepYtjxeDyRrl9qNvVUVDmWbhF3hn4+jxLfhahY9FMoV/jn\nQvsaLegzQ9fXgYpPzwQ+4owzTufsszfj82XidN4aPm9f30ZgOe+8U3nU+46uH5euY8JEJG6sL0AQ\nhBMnWpiURfssYCEpCVSZlJVIzPkaVLz4Z6i4snnSVV7e2cTFTQltawLuIxA4DTW+8gPUCwGooRv/\nhnKTXwO8SE5OV+j35aiMbo9hfw34O6rL2b+EzvuvQCVnn51OVdXSIZqwdOJ0JvGlL71MefkzuFzu\nQe/bKMh2ux7rVueVrmPCREAsa0EYx+Tnt2N2KX+KsliXk529FqfT+F02A8XQQ3QS1/79h4hY6SsN\nx98R+mVHZY4b1+rC69Vbieq11A+jeocfRjVKuZWMjN/j9f4Wv/8H4WN1oR28x/mLwG243Zbw1Kyf\n/ewC3n//TZStoWq5jYIsXceEiYiItSCMA4aK0VosAZRL+xxUYtZNJCb+ioUL13PTTZdTVnYHXu8Z\nKBHXk8f0lqC7gHxgLSkpUykuDuDz+QkGe1FCrTcyIfTzDOA6VH11E+aXhAz6+oyNS14GvoDKLs9A\nxbxtBAIF5OYewOmcjHLHv0hjo5fzzvt/ZGbmUlhYSUdHPl7vx8BZKE+B2YK++urn8HrvDp970qQ7\ncDi+G35WkikuTERErAVhHDBUjLa5uQBlyR5BWco1zJ49i6qqpZSXb8LrvQtdnJOSKklIuJ2enjhg\nGiqurGqws7PvIzm5kOrqG9HHWMI+zILsDH13AGVdr0G9JAAsJCnpMIHAGjStCJVsdhvKotbLxzR6\nexPp7b2JwsJKenoScbt/gsdjwePRcDo3AuUUFq7F6bwRlQneHzpeXdP+/V48Hv2zcu9bLGdIJzJh\nwiNiLQjjgH374ol0IusKfdZdx06MYyA7OytDx6QSsUq34PPdh893H2bX9qNAKocPF1BX10JEBK8F\nqlCZ4fkogc5CDc643XD8htC+GoFAK5p2t+E7NaVLfc5An1kNVvLyzgagvn7g4I7U1Azi4n5PMHgm\nyoJ/mMREF37/atzugee12Q6OwBMWhNhGxFoQxgEdHU2ozmJKrDo6lCA7HCVs2/Ys3d0RIXc6M7nh\nho20tzehRk0aB20UYnZtu4Dv0NtrobdXL69agcoeTw/9Mo67/H3U8T5SUp6gtBS2bp0V9V1a6Pca\neXlO2toy0NuPFhR4SEpKHSRGrXHwYDvB4C8M2+8jIeE0/P7I2gkJXSQmPoHNdpBNm5aMwBMWhNhG\nxFoQxgHRjUy6u/MpLd2G3d5JWlob3d03YxS3mpofkJFRScQa34PK3Nbjyrqr20qk3MsKTCUh4X7i\n4vz4fBehksOMAtwedbwPTfuE1auXs2tXdUjw9VjyLs48Mxjq5Z3Ntm3Gmux14USwxsZUDh/eS1aW\nnVmz1g8i+oXYbAdMa3/taykSlxZOKUSsBWEcMHPmEXbvjoiVxzOJ+vqrqK/XyMy8n4Edyyx0dWWh\nOoFtQcWoVwE5KDd2GrAas8t6OZBIIHAPSsD/GZXR/RyRCVr5qASzT9AbpHi9LoqLf8mMGbPp6FiD\nxTILm62ZTZuWMWOGHYDS0m2ma2xuzglN7oL4+ATmzp2KwzEfm83Keef9P5Mwx8W9z2OPLeI3v5EM\nb+HURcRaEGIUYwZ4QcERFixYR3NzDvv3f4jbXR7ay0JcXA4DO5Y9h7Ki70fVNH+CyuZuAaaj/ukb\nBb6XSExZjzHXYIyFq45lFtRkrFzD8Vvweu/ivffUfmVl66mq+qHpXqLLsvLzD1FSsh6n8/NAN/X1\nSwA1DWzTpjKKi+/A650LHCEY/Cm/+c1msaSFUxoRa0GIUaIzwBcseCRUB52NcZqWGg+5jr/+tZfu\n7k+Bi1CW8OmoxiMbMVvRG1ATuIwCvxeoNHzuIjLUg9DPPOC7od8/aTg+zbSfPtXLWGbmcJTQ17eO\nHTvigMP8/e+dtLYas8U3huutZ8ywc+aZc0ICrpAuZMKpjoi1IMQI0bXU+/aZrd/t27twu7+HLqiZ\nmfeTnh7gwAE7s2YFuPRSqKkxCq4+0jJ6cEYGkEpy8hr6+opQJVnXAveRklLIxRd3smePO9yCNLJe\nu2Gdr6EsbTvwIcaBH8apXvX1Gjt33oPXO4nu7kwCgRRUh7PJRJLZrEAaBQVt4WcRbYlLFzLhVEfE\nWhDGEKNAt7Xtwem8CbBRX69RWFiJ2fo1dyCLi8vB6VyK07mFhgYbCQnNmEVZH2kZPTiji4yMOI4c\nyUEN2zgHZWmfTmlpAJhMS8vNqCSyDajGJMmobmcuVAw8DeU670MN67gPOJ1Jk97D5TIniLW0nAbc\nYDi/XtJ1DsrVvhyVABeplT6RLmRHG+4hCOMdEWtBGEPMgzjK0OueIZ3u7ngWLPgdzc0F2O0efL5+\namoiohsMtqISwFTcNxBIxizKHwLrAB8JCXeQmmqnt7cVvz+Prq6bQsdGyrL0TmDLlr2FuW3oY6hE\ntXZUDFwvq1qMst5/HfrcH2rCsiHqOj7GPPAjncjMaj9KvFfQ3Pxa+LmcSBeyW255iS1b1JSv+noN\nn28djz++7LjWEIRYRcRaEMaQgYM4VN0zWPB4FvHOO5U888ylVFa+zYEDqRQWVpKdXcTMmT3s2NGD\nx6N3KFPlUCrTuwg4hEokawb6uPLKqTz++DJKS7dRX38VqtXnFNO5LZaZVFS8SkGBL6r+OQnV4/t7\nqKYo0Znnt6EEWo9xL0QJsB/1wvBjIrHpDSi3ezfqBaAGZWWfvKtbxcONYQOZUyRMHESsBWEM0F22\n+/cHUMlaKlksISGDQCAiOE7n5/n615/D6VyFXtvc3d3K4cOddHbOwFwjnQccxOxyvg/4N/7+919Q\nWrqNtrY9QDHKlW22xHt7J1Fd/VVycx8gLq6SYPBclKguBJ4hMfGXTJ6sceiQUcg7iMTBdXe7FWWx\nbwQuQAm1up+MjB4uuSSN5uYUCgr+Avhpbn52hMqx9AEk+rUdPsn1BCF2ELEWhDEgeg611foAxcVT\n8PniTK5u2EN7+ySU8H0K3IbHsxGP5ybMMWA97mu2llUi1ye0tEyipUUD4oiLewzoIRj8FuamKWnA\nb2lvn40S/UuIWMQp+P134fGsJGJFd6GsZz0uvjD0nQc129ofWidyPxkZbTz00HWjEkueNy+dmprI\ntc2blz7i5xCEsULEWhDGgGj39/TpZ1BVdQUul5va2kiNMXyf/v77gVtRcd9OlGgbY8CHgZ8Dp6Hi\nwy4iItsM/BfG0q1g8FGUqNeiXNyXoBLMsgFzJzRlraeg11/7fJ9HxbHdKBd2H6rZyixUL3Er8+f3\n8NFHHQZvQCRJzelcQUXF6NRMP/TQYpKSamlq6sduD+BwLBrxcwjCWCFiLQhjwGClSR0dbm699QW8\n3lTgXVQf799hsVhD+3Whz3c210w7iSR96f29P48S4FuBNxgqLq72vxPlEg9gdqufTWLiLvx+Y1xc\nb1eqZ3FvBCJWfn7+PVRV/V+WLXsr1B5VT1J7At3CHq2aaRmNKUxkRKwF4SQYrFxI0zhmCdGqVXPY\ntasSl2saNtsBVq8uY+XKWmpqMlGCGBHI/v5VKKH7J1Ss2Si8XSir1rgtK7R/AcrCPow5lpsVtX8i\ng7ce3cP8+Vbee68y1GnsCPA14uJuJxiczWA13IcO5bFs2VuG2Lhu4SeG1tyA3R44mUcuCKckItaC\ncBJEdxnbseNOLJZEWlq+SHQbTSOVlW+H3MRq2tXatetDFmc8oAshqCztIpYsWU9dXStudwCz8LpQ\n7m/jtkmoGHRC6LNuMetx5vej9j8Ns3gfAdYQF5cNTGLTpq+wdu3bNDVl8v77/43X+wsi5VnmGu5A\noIv6+u8BZRQWqpeR3t5EdDe61erF4bhyJB69IJxSiFgLwkkQHXtubc3E7KbeyL598dxww5Ns394F\nZDNvXj8HD9pMx+lWeH19AsqtHRHA5OSPqaqqoKTkJdzuuURiyZ+gBnTEoVzZXyAu7m0SE0/Dau0h\nP9/PO++sQpVyAVyKcksfQrnN1QuFwije7cDdBIMWtm1TLxL6y4YqrzKWZx1Cud3PRDVJ0T0IFvLy\nzmbu3E6qqyO13MXFCdKoRBBOABFrQTgJomPPaqqV0UpN49Chf9DQMBNVp2yhpkajsHAtRoFsa3uX\nRx5ZwvPPr6e/H2ANMAP4kOeeUz2yOzo+QM2n/hmq3KsIVaOs1igsrKS2dkVYDE8/3YHRnR5xbx9C\nJY3prURdpKTcyec+d0FoSMh0ol8kdDIzP6S39ymUlR4EWiksTCMvz0Jb236czhWhPbVQOdbxdyIT\nBGEgItaCcBI4HCXs2mWM6YJRhAsLG+juzid6KEZW1nSCwXtCrTgP4XTm8POf/5WMjM/hdn8ntJ+b\nxMTf8KMfHaCx8QX6+rJRTU+CqCEdlqg1i/jRj14KNQc5hNc7jYHu7TtQGdr5JCSsITGxCJvtIK+/\nfiOZmVmUl3dSXR003YOxWcm5506ltdU8lzovL4etW6/A5ZpDRcVmkzBL0pcgjAwi1oJwEthsVmpr\nr6OiQh9l6QbWceCAlY6OvWRl2Wlvfx9lyUYEsKOjiba2eIwNTF555U5SUuyoEqgkQMPvn857730F\n+CbKMs4Grg8d86RpzY8+eoeGBqMlvQqze/sQytJ+ArievLxK6uuVkObmZtDe3oXDUUJX1yb++te1\n9PfnkJfXyurVXw/fb0tLtOcgKyzmgwmz9OsWhJFBxFo45TlZQRlMpMrLN9HQsCpUvuQCfoXqo51D\ncvJenM6fAq9hFD6//zz8/q8DT2F0b0cGX6SjMrvNk6+s1qmkprbgdJ6FWUhPJ9L0pBs1IUvPFrdw\n6JCV8877T7KzizjrLB93330pNpuVjAwrfv8PUUM4VMz6vvsms3JlLR988AnGF4BJk/6Ow/HdIZ9N\ndAIerBdLWxBOABFr4ZTnZAVlMLGPTjxLTEwmISEPm+0AaWmz+fBDGwOzsveimo34MIuuPviiG5X8\npR8zGYsFtmy5iNLSF4nUQOvrdaJGUBpFXwM+ADz4fB04nbfjdFrYvVtj69YHKC7OGzCas6kp0/CM\nzE1OZs8+86gvNtHPQeZSC8KJIWItnPKcrKAMJvYFBUeor9cTsRrw+1fj96syrbi421GiOR01ZcuF\nSkzTUIMyEjGKrsXyDzTtDSAXaMNYhlVSMpnKyrfxeH6KEtInAC8WSxslJalYLI/wt78l0Nv7CX7/\n5NCx/4pqQ/p703273WdSXb1owGhOu91jeEZ6k5PNwCJmzVp/1Gcjc6kFYWQQsRZOeU5WUAYT+4IC\nH2ZXduT7YLAIZeUeRDUNKUSJbyJKjL9NxH39Ppr2A5S4bgD+D/AUcXFTyM9vZu3aMr73vY9QQl2D\ncnHXk5CQQnp6DqtWzaGy8m2ami6goaGVQOBaw5V7MFvi3YCFtjYbmZn3Ehc3hXnzgjgcX6Gi4lXT\nM7Ja36e42HXM7G7JBheEkUHEWjjlOVlBGUzsm5qMiVjdmEVxHzAXZSk3oTK0dSt6DZo2GX1spGpu\nYkW5x53AfwM/Ixi04HSqeLLdrlFf/yKq8cgW4Iv4/X+jujqVHTueprVVTzp7LOo6JqFqtnWL/VpU\nYxM3Hk8u8G2SktZjs1kHeUbLhxXXl2xwQRgZRKyFU56TFZTBxN5siS5g0iR9OEcD5vnOD2O0ujVt\nKuaksNND3+k9wfVhHjVAOi+++AlPPTWX6upGlFDrDUgWAxtob3cb1r8KPclNZZtnYh7c8SAwFfg+\najZ2JCQgoisIY8uoi/Xrr7/O2rVr0TSNq6++mu9+d+jMUUGINYzJY0VFPeGM6ejv7HaNp5+eg6ZB\nRUUtH3zQR1LSSgKByWhaFgkJCeTkvMGhQzMwzneO7tsdH3+Q/v7lKOFNA/4GrEe5rI3DPJSL3e9f\nxHXX3YHqIKYnkaWH9rMQF5dNMOgCnkPVWbtRZWT7UF3QjIlsn0OJPOgxdIkxC0JsMKpiHQwGufvu\nu3nsscfIy8vjG9/4BldccQWzZs0azdMKwogRnTzW1xfJFDd/52LXrofp6cnH7U5GtQA9D11Uu7s1\nurvvZaBLvBdju87MzG5crl+i3OTdKGv6d8TFHSYYfCp0nC7cABb6+magyrgexNyx7A4CgWyUq/si\nlBv9bsP390ZdS1doTY3MzGYuv3y9xJgFIUYYVbH+xz/+gd1uZ+rUqQB87WtfY9u2bSLWwrjhaJni\n5u+2hAdzRFzKUzBbrlNR85/10qckoAKYTGbm/Vx+eT61tfmodqLGcqtzCAZ3EElYM2drJyU10tc3\nGbgg6nznh873o9Dn56K+n0JcXCWZmflcfHEQTfPT3PxsyJX/LWleIggxxKiKdWtrKwUFBeHPU6ZM\nYffu3aN5SkE4LvQZ0sYhGw899NWwUB0tU9z8XRpmIcxhYLb1p6h48JbQfsbM7CyqqpYye/bTUesk\nomZbzyYya/paVO/wmUAj55+vMWXKeurqWnC7jefrwzzCMtqqn0QwuBq3WyM9fSO//vWiE3+QgiCM\nKqMq1pqmjebygnDSRGZIR4ZsJCVFXN3G5LGiol7uvjviFjZ+19a2B6fzUpT1qgGfEB9/iJSU/Xi9\nOaSmdjB3bhJJSX/hwAErDQ31GIWzp6cJgNTUZjweo6C+jZqQFT2M42z07O3333+A555bSmNjE5dd\npiey7UG9GNQYzrMAWE1CwukEgx0Egz8I3YmFl1/uw+VyizUtCDHKqIp1fn4+Tqcz/Lm1tZW8vLyj\nHpObmzGalzTmyP3FFk6nMdlL/Xz5Zbj55s08/PBCiopO49lnrx/02I8//pitWz/C651OUpKb5OT7\n6euLCGt//4P097t5//2vMmuWPXzcsmUbaGiwY8z61rQscnMzyM+fTUuLMRu8ELOl3YNyg98U3max\n5JKbm8HNN+8OCfUSYD5KqA9jjImXlc3i2Wf/lWXLnuJPf5ocWkPD5UpizZo3ePrpa07mccY04+3v\n5vEi9zexGVWxPvfcc/nkk0/49NNPyc3N5YUXXuCXv/zlUY9pb+866vfjGX1YwkTls7y/kRoQUVjY\ngbI8jVauxp/+dA11dXdywQWn09ycg93eyaOPltHfHx8+trj4v/F6VUJXX9/AMiz4HL29izjnnDWc\nddaF4evcuzcF1RAl0go0Le1+PvjgAG1tjcBqIpb0PZhd1wdRLvGI0H75ywHa27tC6+qubiuwHKv1\nAdzuVeFrfuGFRzj33Cc57bROMjPvx+M5K3TMQvbufW3C/v2Uf3vjm1Ph/o7FqIp1fHw8a9as4Tvf\n+Q6apvGNb3xDksuEESE6S9vne4SkpNTjFm+Ho4QdO35Pa2ukhSf4AQutrZnU1NwYPseKFea4rsrC\nNoqzLvzmjmB9fRdRX78k3IpUNTHRO5KloHqEZ1BS8gRO57+gLO40EhPfRNM6CQSM1xYAesOJYfPm\nBXnooa8Aegx9Sfj4KVPexGJJQLnmu4EFBAIZNDRYaGhYQWHhWjyeReHrLShoobx8k0zIEoQYZNTr\nrOfPn8/8+fNH+zTCKUZ0lvb27V243SrufDzDOGw2K7m5X6S19RuGrZtRYmseB/nxx+mmYxMT38Pn\n08up9gNWLJZVaNoUVCb4wtA6R8JrNDVl8vTTc/D5nmf79k85csSH378aj8cSilXrE7bgnHOCfPCB\nJ6pF6BPAdSxePPD+VAxdnyftxuc7Pfyyoa7j58A09KSzrKzpzJ0bicd3dSXIhCxBiFGkg5kwLonO\n0lZzno8+jGMo13lHxweYLeJ/oCxRv2n71KkdpvXmzTuNurprUAKryq1UUuUToWP+GlrrJlQzkhfZ\nv99LRcWr3HnnpVRWvs3WreD361neVlRWOeiZ521tB+jtjVxDYuJHLFwYqX8+WjigtHQbZsv/QmAR\nqu5aY9as/rAY5+ZmcP75zx7zGQqCMDaIWAvjkugWnz5fPzU1Rx/GMdQozKys6TidxqQuG5BGUtLf\n8fkiLmhN8wMRgfzb3zKJuLKNomhDJXlp5OfXc/75f2H7dhdu909wuy1UV2vs2lUZVZetZ3nvITPz\nAOnpnTQ2FnHWWekEg/fQ2WnHZjvIpk3fZMaMSLLa0cZ75ucbx2lG3PLJyZlkZ1fS2FhEefkzOBwl\n5OZmyIQsQYhhRKyFcUl0r2qXy01SkhLv/PxD+Hx+Skqeo6OjiezsImbOPGKY0+wGati6FcrLn+G0\n047Q0BA993k+gUAzxlpop3MzYBbIwTuBvQtYsFrfp67u/2KzWSkt3UZ9fUTQW1ryMQu8L3TeFfT2\n/gaPZzVOp1qvrGxod/TRmrZYLAHMDViUWz47243TuSo8xxrW8+yz18uELEGIYUSshQmBUbzLyzdR\nXX0jSvwiohSZ01wDLKe3V1m5Cxaso6xsPXV1AdzuScDngQcJBmcCT6JaeU5mxoxuYKBAKsv7DlQ8\nuAOYDnhITvawbNlb2O2dFBT4TFZrMNiIWeCdwCpAw++fynDd0UezhpubC1DDO9TLSUrKc5SWQmNj\nUagP95kAAB2VSURBVOhFwLy+DOsQhNhFxFqYcETE1Ni9y0J2dhFz565n61bo7Y1s37LFD8SRlLSP\nSy9NYceO9/H7jT221wCn8cYbbXz8cdMQ8fJ/Ae4HvoxyN/8Tra1NtLbGU1+vYbM1oAR9BtCIGk/5\nKOACckhI8HHmmU9y8KATtzsXo5C3tb2Ly6WGhETHp49mDUeuU5VxlZYqC728/JmQRS3ubkEYL4hY\nCxOOiEh1YU4QcwNJJCe3mJK21Pzoa+nr0/jb39bQ3z8Ts+V8EbAEp1Nj6dJKamuvo69vHVu39hMM\ndqFqnp/D3GnsTuDfw59drmYiPb9dwC9DP28DLAQCGtOmraOjw4/bnYkS9rMAC07nCioqlAteud87\nqa9/kc2bXyQ//xCbNpWZ4tg6Qwm5uLsFYfwhYi1MOHQx2rcvno6OylDMugefzx9yj3cCG7Bavbjd\nzcC3ULHddPr6JqHKsIyWc6T0yuWahs1mJTk5iWBQj1u7gD9hFvjZUZ+Nru0tqOlYz5v22bEjLtTA\nxAIsxVjGFXGFW1Bu/GsIBi3hF4j6+h8OeA5DubXF3S0I44+4sb4AQTgROjrclJdvorR0G+Xlz+By\nucPf6WL05z/PZ+7cacTHJwAaBw7o7nErcC3Tp2eRnNwN/A8qE/tS1HCM6cDtKOt3FZAMPAW4sNkO\nAkZXuxvVuSwdJewQGdox1Gd96EdX1D6HMQu8uYzLbu8M7Wd277tc007gCQqCMJ4Qy1oYM06mZejR\nSpaG2ieSYBaJ1WZm5vL66z6MFmvEoq4M/VKfU1LuZNOmbwJGV3sNKiFtPpFe3/XAdeidxPLz/8E5\n56Tw1lsPEAza6OvbT1/fYlR2trLwi4sT8PnSTOVn8CbgJjHxQ1avXoamESr5ysKY+Ka/QAiCMHER\nsRbGjOEI7lAcrWRpqH2ys4v4whfWsWNHHHAYny8Nl+t0Iv20zRYrmMurZs/+AmvXvk1T00cUFPhY\nsOB3vPZaGr293ai49TWhdeopK3s93EnM4bjB9BLicrmpqNBjxgEcjiux2azh8jOVld4K3ArY8Ps1\n1q5dD2CqzY6LqyQ/HzZtWjKsZyYIwvhFxFoYM4YjuDC4BR6dkd3W9i6NjbOprHw7FKtuort7CkYL\n9PDhvRw4kIjb/RNAjcMsLFwL5KFi1p+iOnzplu1HGC3xDz98h927fwxsob5+Cvn573DxxbBt23J0\nK1qNpnRTVXXLkPetu+n1+9LLuxyOEqqqluJyufnSl17G7Y5MBDPHrNXPL3zhbLZuveL4HrogCOMS\nEWthzBhux6zBLHCHoyTkEv48cASncwVf//rDIctT1Vfr61qtD5Ca6sfpXAG8gVHwsrKm09PTh9t9\nLSr+vBHoBeJJT7fS3R3pbOb1TkElhy0HOmlp6aalxQU4UAlk76EakMwIdwY7mls/+r5eeukOzjjj\ni8yceYR587yDdGTTpMOYIJyiiFgLY8ZwSog6OtzU1bWiMqe7gIXU1QUAyMs7G6cz4gJWiVadKAs5\nsj9kk5WVHJpdbSznctHR0YT6Z2BM9Ipj0qQP+dKXckNWs3EQhje0dgORUiy9i9mtqBj2tVRXR9z6\nQ8Xmoz0LXu9cdu9ewu7dkUYtA5+NlFwJwqmIiLUwZgynhGjlytqw21qJ4gbc7klUVNSGRk1GLE2b\n7QC9vS+i1y4b909N3R/6HEnqSk1tCVnincCDoTOqY71eDYvlEcrK9CYqicBpgHGKlTG+fQ7K6s4I\nb9Nd10PF5gc2V4mUiG3fHsfOnZcPsMyl5EoQTk2kdEuIaQa29vQBC2lqysThKKGsbD3nnfcsZWXr\n2bSpDKvVO+j+2dlFoX1fo6wswM6dV5KXdzaRUq5CoMh07JtvJlFVtZTSUg3l+p5i+F5PSoOI0Kah\nLHe1TXUecw8am+/ocOPz9ZCYeCeqocq9wFfDx+ovJIIgCCCWtRCj6K7j/ftbMDcoSQYmY7d7BrXM\ni4vfCrmgzfvPnNkzYF+zZbsA1S50cfjYI0eacbnchvi4RiQBbQEqLn4xSqi/Sn7+b9C0Plpbn0OP\no1dUbB7gAbDbPaxcWUtNzfdRVv2LZGZm0Nv7K/z+81Gu9oU0Nb02sg9VEIRxi4i1EJNEXMeq21hm\nppf09BaysuzMmrWeVasuoLx804A4sB4Hb2xM5fDhvUPuv2LFGezc+Qlxcb9H09rQtGyU5RwZien3\n9/GlL71McXE8mzYtYfHi/6Kt7V5UMtmnXHppOllZ7tCam3E4bmDZsrdobdXj6CrePm1aIYWFkU5q\nDsflLFv2FsYGLTNnPovdnkF19VVE9wQfbu25IAgTFxFrIWYwJmIpi7oTo5ht3fp/wvuqyVqROPCO\nHXdywQX/v717D66yvvM4/s4dSAI5QIBEuiGAEay2TC11YVxCsY0SwKBopXWkRZuV0sEx7Qw3124t\n3VBTrbZDhyJip1AqWNYkUAhVA4RWKcvWTTEqZYg0CLmS5DQJhlzI2T8eTs41yUlyDufJyef1jyR5\n8jy/x4if/G7f379QVTWelBQb+/bdicVyj5frjbra+/f/DZttKvZtXUbFskSMFd23AGeBWVitVyks\nXAgcYO7cmRQUrMAepnFxO/rorR/qPsMabMye7VhwVlv7IcYsVAuw8PqCMc8V7mvXHtA8tYgorMU8\nPM+Jfg3jPGnPbUru88A1NaMpKjIWf5WW2igpeZ709AleVl4bVcpsNvszXgVGYdTyjsGo8x2O8yEc\nsIeKitFERUW4PLOqarzHO2zYcAenTm2msXEyHR2tdHZ67iNft+6oS3GT5OTN5OU9isWS4LHCvbfj\nMUVk+NACMzEN9wBOSLjavXjMfZuSo0421/853uV7rdYZFBau6F6k1VNdbSOclwMP4Dhw4zxGr95+\nTSy1tR9y7twZl2c6/wJhr1V+773/Q2VlCq2t99HZGef1evf3nDDh1u6hbvf30l5qEQH1rMVE3Lcy\n/eu/dhET00RFxWjWrj3iUmTEfcjY4LywrAXn3qx9LrukpBqr1blKWTze64I7evUjRpyisvJx4G3g\nBSIj4/nqVyPIy3MMs3uOCuwBFpGQ8DyTJ6fS0HCW8vIUsrPfICmpvcfiJjq+UkS8UViLaTiOthxF\nQ8NZ3n13NE1NI4H5lJaOwbl2uMWSwNGjj/LUUwc5caKZrq6RjBr1X3z6adL178nEOQjtK8dd63I3\n0dLSRXGxZ4979OirhIe/CtTT1TWRq1dPYN9j3dlp429/20xj4z9Zu9Y+x96Ja489DhhDevpE4FPK\nyjZQWRlGWZmNhQt/1UPBEx1fKSLeKazFNOxBlZ2dT1mZY07Xfq6z+/ytxZJAdPQorNYngDCamowg\njI6OoqLimNeeqc3m8hG5uf9Gbu4ujh6toqnJ0eP+9NOP6ezcdP3j3TiOtQQIo7LyNpYuzae6ehoQ\nAdTg3LNPSDhDenqj28pv43urqpJU01tE+kVhLUHR2/GYnoVQjLlfb/O37td6C0LnZ9XWfkBl5SPA\nCUpLLZw6Vcirr36ZoqIyjCpmxtx3Z2eM030XERb2PDabYw82NFJdzfW2NQNfJyrqP/nsZ79w/ZeE\n5S7z0KrpLSKDobCWoOjteEz3cHPupdo5iqZ0Ar/AmKOezJkzZzl/fjqpqSlenwVZwHPAOowe8hKW\nLv0B7e3P4dqTHwf8DmNOu4nY2Gu0tPwEo6zoFYzKaP/h8j2xsVO89pg1Dy0ig6WwlqDo7XhMz3Bb\n7lIYpKHByoIFu64vLmsB/oF9q9XVqzbuv38zpaVrenyWUVrU8XFbW6rb12MxVoR/B8ee6k20tKzC\nqP8dS2Rkk8u2LIhlzhz7QjdXmocWkcFSWEvA+XIetfPQcF/h5r5PGTbjHLaNjZNdnlldfRqoAyYB\nTURHv097u+PZMTEfc/Wq4+Pw8L8QE3Mzra2OeyYm3sq8eYc5e3YkKSlW2tvDXY6wTE4u46WXHvXr\nvyNVLhMRO4W1BFxP51E79557Kh/qjWtP+Z/ANYzDMIxqYBbLRS9D369h1P22MW9eM7Gxjmd/97uZ\nfOtbRiETi+Ui+fnfIDfXtcb41KmfsnfvCurqjIM6GhutREc79/4fHVS49jYtICKisJaA8zbk7d57\ndi8f2ltYTZpUh2Pl9SGc545HjPgB+fkP88QT53Ad2o4HrEAR77wziowMG3v3Oupul5be7vKMvDxj\nq1hP88z+HtrubVpAREQVzCTgfKnK5R5Wb74J2dlv0Nho7b7GXiXs3XerMHrKBzAWejm+b8aMO0hN\nTfFS4awZo/DJclpbV7hUN3O/f0ZG8fUiLF9mz547AFi27CSf+cxmFizY79Euf1DlMhHpjXrWEnB9\nrYb2drBFa2sUhYXLgV0899yXWbfuKCUlnVitMcDNGNXGwFix7Riurq39kIwMSEq6wsKFO6iqGk9S\n0mWgg2PHYl3mod17r84nfZWWHqKk5C1Gjap2mR+/eHEPZWUr8PcwtVaMi0hvFNYyIN4WRCUmxvf6\n9Z7mdD0XjD0HrMIeqJ6lPH+CUdP7MAAjRjzDzTfPor7+LJWV36Gy0kJpqY2srF28+ebd3W2Jiemk\ntXU39pO2ej4cxCg9arWGYbXux3PPt/+HqbViXER6o7CWAfG2IMo4PrLnr/cURp5bq27FOBrTGA72\n/PoM4JcYx1oa27UmT95BRMStVFZauq9zPuXKOewTEp4nPX2i18NBjLY6lx5twbPmuIapReTGUljL\ngHhbEFVfbyU7e7+X86h774m6b+OaNOk0V69eBuppb48lKanNrUjKOVpaEl32OZ84EU56uvftYO5t\nnTLlZrZv77l4iethHwtJTt7MuHFpNDaeIyHhM0yb5jgFTFuuRORGUFjLgHjbJ716dZHP51E7y8tb\nQFvbDv7yl3CgHputDat1JRBGUZG3gy+Wc+edr2G1Ovd468nLM+a43ed9fS332dNhH/ZtWYmJD3Zv\n3bLTlisRuREU1jIg3vZJL1z4v7ifRz1lSkGfC6YslgRiYqKxWpdgzEN3YAR9JpDgtd73nDlxFBW9\nhrElq5k5c+KwWBJYv/4LLFu2n7//PYk//nEbqak3M2WKY7GZL4u3+jN/rC1XInIjKKxlQLztk25s\njMJ5fjc9PdLrcLOd8xCyMWy+H1iBo7e8B1jutSf80ktLiI4+SkXFNVJSOsnLWwzAsmX7XRarffTR\nHj76aEX3YrPBcB7m96USm4iIvyisxS+MHuV8jIAdQVTU/1FefgvZ2W/0OI9rDCHbe9MzgCqce6kj\nR3aQkbGrx+pm3nq/jY2TXe7hz9XbzsP8PVVi05YrEQkEhbX4hdHDHIOx//l3dHQ8S1lZGGVlrvO4\nnr3p/wYex3FutKOXmpFB9/nWvs4LWyyf0NoamNXb5887rxL3XolNRCQQFNbiF3l5CwgL28mxY9do\nauqgq8v7PK7nnukXcD43Oioql8jIz2CxXGTjxvsAKC8fhXNIfvzxKI/n238JGDNmOg0Nz2CzJREW\nVk1q6nTS0nb5pcebmtrMqVMa8haRG09hLT5raLDy1FN/vL5q+zJz5sTx0ktLsFgSsFgSiI6Oxmpd\njrE4zHuouS/IioyMp7PTfu0YOjpS6ej4Bq2tNnJzd7F9ewoNDX93uV99/VngHpe2uf4S8DWysnax\nffsK/Gnr1kza2jTkLSI3nsJafLZu3VEOH7YPWdsoKnqN6Oij3cPAjmHiTGDP9TlnXELNfUHWqFEN\nxMUZ+5g/+eQ8Vms29gM39u/v5NSpXxAbm4QxFx4HtDB2bIpH227EquyxYzXkLSLBobAWn3lWEoun\nouJa99cdw8QJwHIyMjznlh2FRzqxWkfQ1PQdmprGMHv2LqZOnUBh4Rjsq8BttjAqK22MGPEMsAl7\nwE+btsujbVqVLSKhTGEtPjMC0V6TOxb4gKQkxypvX4aJ7QuyMjKKKS1d2v35iorR7N17B7CLwsJm\nHD3pZq5ds7gVRfG8r1Zli0goU1iLz/LyFnDy5C+prjZqcsMSYEf31/szTOzeE66t/ZCHH4aUFBsx\nMVW0ta3u/lpExA/Yvv3fe72fVmWLSChTWIvPLJYEJk26jepqx1B4VdX4Ad3LuSdcW/uhy2lZ8fHb\naWtzPCM19TZ/NF9EZMhSWIvPvJ07PdC5YeeecEYGLqdlRURYcV79nZbWNui2i4gMZQpr8Zn7udPJ\nyZvJy3t00Pd1HxKfMyee6GjNP4uI2CmsxWfuq8EnTLgViyWhuyBJZaWF5OSGfh8T6bk4bLGOmRQR\ncaKwHqYGcg5zT9ujPKuS9e+YyIEsDtM50iIynCish6mBnMPc0/aoG3VMpHNA19Z+QGXlasCic6RF\nJOQprIcJ957oxx/H0t+A7akH3FOP29+9X9cefBbGXuyv+9x+EZGhSmE9TLj3pJOTc+mpfndf3EN4\n40ajmIkxZ93Y3eMeSO+9N54V1GKv/1kVy0QktAUsrLds2cLrr7/OuHHjAMjJyWHevHmBepz0wT3o\nxo6dwuzZA1tx3VMIJybGU1fX3OMzB9v7de/BJyeXMWFCl1aMi0jIC2jPeuXKlaxcuTKQjxAfuQfd\ntGnXBtzL9TWE+1uvu69hc88580e1qExEhoWAhrXNZgvk7aUf/Fk729cQ7u8z+xo2V0lRERmuAhrW\nu3fvprCwkNtuu43169cTHx8fyMdJL/wZdL6GcH+feaNWlYuIDDVhtkF0f1euXMnly5c9Pp+Tk8Os\nWbOwWCyEhYXx4osvUldXR25u7qAaKwNXX29l9eoizp+PIzW1ma1bMxk71lxDyA8//Dtef91Y3Q02\nvva1Pezd+/VgN0tEJOgGFda+unTpEqtWreLAgQN9Xuu8QCnUuC/AupGys/NdCpdkZfl/X/Jg36+x\n0cratUddeuxmmpMO5s8v0EL53UDvN9QNh/frS8CGwevq6khMTATgrbfeIi0tLVCPEh84hpitQBFv\nvgnZ2W+YqvKX5qRFRLwLWFj/9Kc/5aOPPiI8PJybbrqJH/3oR4F6lPjAsSisCFhOa2sYhYUD3/vs\nbeW2L78diohI/wUsrPPy8gJ1axmADRvu4NSpzVRVTcJmG/wiLm8rtwsKVviruSIi4kQVzIaJzZvf\nu3685Wv0VrnM3mMuL4+goaGCcePSmDr1isdwuVZui4jcOArrYcIRrpnAHkaO7CAjA49tV44e8x5g\nA5WVYbz/vudwube91vX1VrKz9+skLBERP1NYDxOOcE0AlpOR4Qhf5/nnf/yjEyOA43DuOZeXjyI7\nO9+jHrh95faGDV9g1qxfcfHiOvpbC1zHXYqI9E5hPUz0VsjE9TSr3RjD5M04D5dfvnyGsrKnsQdx\ne/sOfvObh7vvkZ2dz8WLtzKQoXF/H/ghIhJqFNbDRG/bolznnxeRkPA8kycn09Cw+fqc9accPZqA\ncxCfOBHu5R4tDOQkL81/i4j0TmEtbvPPY0hPn8j27fe5XJOWthXnIIZ6L/e4D2OuO5bk5DLy8h4d\nwPN13KWIiDuFtfhU63vOnDiKil4D4oFm5syJ87hHTMxhzp4dSUqKtV8nYvnzkBERkVB0Q8qN9keo\nl5Qbqu/nSynQofx+vgjl9wvldwO931A3HN6vL+pZB0ioVfhSKVARkeBRWAeIKnyJiIi/hPd9iQyE\nVjiLiIi/KKwDJCXlnxirpkErnEVEZDA0DB4gWuGsymQiIv6isA6Q/i7ICsVgU2UyERH/UFibRCgG\nm+btRUT8Q3PWJhGKwaZ5exER/1DP2iRCseSm5u1FRPxDYW0SoRhsKqQiIuIfCmuTULCJiEhPNGct\nIiJicgprERERk1NYi4iImJzCWkRExOQU1iIiIiansBYRETE5hbWIiIjJKaxFRERMTmEtIiJicgpr\nERERk1NYi4iImJzCWkRExOQU1iIiIiansBYRETE5hbWIiIjJKaxFRERMTmEtIiJicgprERERk1NY\ni4iImJzCWkRExOQU1iIiIiansBYRETE5hbWIiIjJKaxFRERMTmEtIiJicgprERERk1NYi4iImJzC\nWkRExOQU1iIiIiansBYRETE5hbWIiIjJKaxFRERMTmEtIiJicoMK68OHD7N48WJmzpzJBx984PK1\nbdu2kZGRwcKFC/nzn/88qEaKiIgMZ4MK67S0NLZs2cLs2bNdPl9eXk5RURGHDh1i+/btPPvss9hs\ntkE1VEREZLgaVFhPnTqVKVOmeARxcXExmZmZREZGMnnyZFJSUjh9+vSgGioiIjJcBWTOuqamhqSk\npO6PJ06cSE1NTSAeJSIiEvIi+7pg5cqVXL582ePzOTk5LFiwwOv3eBvyDgsLG0DzREREpM+w/vWv\nf93vm06aNImqqqruj6urq5kwYYJP35uYGN/v5w0ler+hLZTfL5TfDfR+Q12ov19f/DYM7tybXrBg\nAYcOHaK9vZ1PPvmECxcu8LnPfc5fjxIRERlWwmyDWKb99ttvs2nTJhobGxk9ejQzZszglVdeAYyt\nW/v27SMyMpKnn36au+66y2+NFhERGU4GFdYiIiISeKpgJiIiYnIKaxEREZNTWIuIiJicacN6x44d\nzJgxA6vVGuym+NXPf/5z7rvvPpYuXcrjjz9OXV1dsJvkV3l5eSxcuJCsrCzWrFlDS0tLsJvkN73V\nwh/Kjh8/zr333ss999zDyy+/HOzm+NXGjRuZO3cuS5YsCXZTAqK6upoVK1aQmZnJkiVL2LlzZ7Cb\n5Dft7e089NBDLF26lCVLlrBly5ZgNykgurq6uP/++1m1alWv15kyrKurq3n33XdJTk4OdlP87tvf\n/jb79++noKCA+fPnh9x/gHfddRcHDx6ksLCQlJQUtm3bFuwm+U1PtfCHsq6uLjZt2sSOHTv4wx/+\nwMGDBykvLw92s/zmgQceYMeOHcFuRsBERESwYcMGDh06xJ49e9i9e3fI/Pyio6PZuXMnBQUFFBQU\ncPz48ZAsW71z506mTZvW53WmDOvc3FzWrl0b7GYERGxsbPefW1tbCQ835Y9gwObOndv9TrNmzaK6\nujrILfKfnmrhD2WnT58mJSWFm266iaioKBYtWkRxcXGwm+U3X/ziFxk9enSwmxEwiYmJzJw5EzD+\n3zJt2jRqa2uD3Cr/GTlyJGD0sjs7O4PcGv+rrq6mpKSEhx56qM9r+6xgdqMdOXKEpKQkbrnllmA3\nJWBefPFFCgsLiY+PD6lhK3f79u1j0aJFwW6G9MJbHf/3338/iC2Sgbp48SJnzpwJqQJUXV1dPPDA\nA1y4cIFHHnkkpN4NHB3T5ubmPq8NSlj3VG/8qaeeYtu2bbz66qvdnxuKvZi+6qnn5OSQk5PDyy+/\nzG9/+1vWrFkThFYOnC/14rdu3UpUVNSQmyscSC38oWwo/v0ST1euXOHJJ59k48aNLqN3Q114eDgF\nBQW0tLSwevVqzp07x/Tp04PdLL84duwY48ePZ+bMmZw8ebLP64MS1j3VGz979iyXLl0iKysLm81G\nTU0Ny5Yt4/e//z3jxo27wa0cOF/rqS9evJgnnnhiyIV1X++Xn59PSUnJkBw1GEgt/KFs0qRJVFZW\ndn9cU1Pjcx1/MYfOzk6efPJJsrKy+MpXvhLs5gREXFwcX/rSl/jTn/4UMmH93nvvceTIEUpKSmhr\na+PKlSusXbuWvLw8r9ebasI0LS2Nd955h+LiYo4cOcLEiRPJz88fUkHdl4qKiu4/FxcXM3Xq1CC2\nxv+OHz/OK6+8wtatW4mOjg52cwImVHqkt99+OxcuXODSpUu0t7dz8OBB7r777mA3y69C5WfVk40b\nNzJ9+nS++c1vBrspftXQ0NA9PHz16lVOnDgRUv+//N73vsexY8coLi7mZz/7GXfeeWePQQ0mnLN2\nFhYWFnJ/0V544QXOnz9PeHg4ycnJPPvss8Fukl/9+Mc/pqOjg8ceewyAz3/+8/zwhz8MbqP8xLkW\n/qpVq1xq4Q9VERERPPPMMzz22GPYbDYefPBBn1amDhXf//73OXnyJFarlfnz57NmzRqWLVsW7Gb5\nzV//+lcOHDhAWloaS5cuJSwsjJycHObNmxfspg1aXV0d69evp6uri66uLjIzM0lPTw92s4JGtcFF\nRERMzlTD4CIiIuJJYS0iImJyCmsRERGTU1iLiIiYnMJaRETE5BTWIiIiJqewFhERMTmFtYiIiMn9\nPyQ+uNKCpR6MAAAAAElFTkSuQmCC\n",
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAAAecAAAFKCAYAAAAnj5dkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAIABJREFUeJzs3Xt8VPWdP/7X3M5MkpkkM8mEAAER\nQoICgUBALkUEQ7FucekDEeWL3VZXu121dler39pu1Vbb77b+2m1/3277qNXa2kUptGttt/tDEWqp\nyDWBiC6ES8slXDJJJpfJ3C+/P8JM5nLOmTOTmWQm83r+RebMnJyTAO/z+Xzen/dbFQqFQiAiIqKc\noR7rCyAiIqJYDM5EREQ5hsGZiIgoxzA4ExER5RgGZyIiohzD4ExERJRjtGN9AWE220DWzm02F8Nu\nd2bt/LmukO+/kO8d4P0X8v0X8r0D+XH/VqtJ8lhBjJy1Ws1YX8KYKuT7L+R7B3j/hXz/hXzvQP7f\nf0EEZyIionzC4ExERJRjGJyJiIhyDIMzERFRjmFwJiIiyjEMzkRERDmGwZmIiCjHMDgTERHlGAZn\nIiKiJDy+ADrtTnh8gVH5fjlTvpOIiCjXBIJBbNt9Gq3tNvT0e2Ap1aOxzopNq2uhUWdvfMvgTERE\nJGHb7tPYdfhi5Ovufk/k683NdVn7vpzWJiIiEuHxBdDabhM91treldUpbgZnIiIiEX0OD3r6PaLH\n7ANu9DnEj2UCgzMREZGIMqMellK96DGzyYAyo/ixTGBwJiIiEqHXadBYZxU91lhXCb0ue20pmRBG\nREQkYdPqWgBDa8z2ATfMJgMa6yojr2cLgzMREZEEjVqNzc112LByBvocHpQZ9VkdMYcxOBMRESWh\n12lQZS4ete/HNWciIsqa0a6sNV5w5ExERBk3VpW1xgsGZyIiyrixqqw1XvDxhYiIMmosK2uNFwzO\nRESUUWNZWWu8YHAmIqKMGsvKWuMFgzMREWXUWFbWGi+YEEZERBk3VpW1xgsGZyIiyrixqqw1XjA4\nExFR1ox2Za3xgmvORESUMawIlhmKRs7t7e34x3/8R3zmM5/Bli1bcPnyZTzxxBMIBAKwWq34zne+\nA0EQYj7zzW9+E8eOHYNKpcJTTz2FhoaGrNwAERGNPVYEy6ykPzGn04lvfOMbWLp0aeS1H/zgB9i8\neTO2bt2K6667Djt27Ij5zMGDB3Hu3Dls27YNzz//PJ5//vnMXzkREeWMcEWw7n4PQhiuCLZt9+mx\nvrS8lDQ4C4KAF198EVVVVZHXDhw4gFtvvRUAsGrVKrz//vsxn3n//ffR3NwMAJgxYwb6+vrgcDgy\ned1ERJQjlFQE43R3apJOa2u1Wmi1sW9zuVyRaeyKigrYbLG/lK6uLsyePTvytcVigc1mg9FozMQ1\nExFRCjy+QFYzppNVBHt150mcPG/ndHcKRpytHQqFMvIes7kYWm320uytVlPWzp0PCvn+C/neAd5/\nId+/xVKCl3/3IfYfvwxbrwvW8iIsmTMR962bDY0mc4HRVFYEq7kInXZXwjG9oMW+41ciX4enu4uL\nBDywfm7GrkFMPv/u0wrOxcXFcLvdMBgMuHr1asyUNwBUVVWhq6sr8nVnZyesVvFqMWF2uzOdS1HE\najXBZhvI2vlzXSHffyHfO8D7L+T7t1pN+L+/ao3pDNVpd+HNvWfhdHkz3hmqYUZFzPcKC4WCou9/\n79glfGLxlKztfc6H373cw0Naj07Lli3Dzp07AQBvvfUWVqxYEXN8+fLlkeMffvghqqqqOKVNRDSK\n3F7/qHaG2rS6Fs1NNagoNUCtAipKDVg+pxpur3hwZgMMeUlHzsePH8e//uu/oqOjA1qtFjt37sQL\nL7yA//2//ze2bduGSZMmYf369QCAf/qnf8K3vvUtLFiwALNnz8bdd98NlUqFp59+Ous3QkREw+z9\nyTtDZaI4SPR6dnxFMAA4cd6ObpHrYAMMeUmD85w5c/Dqq68mvP6zn/0s4bXvfe97kT8//vjjI7w0\nIiJKl7l0qDNUssCYLFlM6rjcvubooN9YZxWd7mYDDHks30lENA54fAHYel1AKASruRhWQSsbGLUa\nFbbuapcsGpKsqEh4X3NYONELQMx6tlgDjIYZFqxqnAyPL8AALYHBmYgojwWCQbz+zim898EVuL1D\n68gGQY3mxdfhzlumAxDvDJUsuMod37Byhsx6tg0bVs6IBN3oBhg9/W7sOnIRbae78MfWS9xWJYPB\nmYgoj23bfRrvHOmIec3tDeL3f/4L3G6faGeoZEVD1i2bJnv85nmTJNezu/s9eHXnSXz29lkxAVev\n02BPawf2tHTEvFdstE1sfEFElLfkgiwAtJy0RaaOq8zFkdFssqIhFzsdsscRCsFSKp3Mte/4lYSy\nnUqqiNEwBmciojE0krKWckEWAOwDHtHtSmVGvWRwNZsMqKkySh4XdBpYyorQWCdfuyI+4CZ7IOC2\nqlic1iYiGgPpdHGKz5wOB1mxjGwAMJv0otuV9DqNbLKYqViQPO72BvDG3rPYtLoWLrcf70VV/4oW\nv11L7lq5rSoRgzMR0RhQmu0MyAdyqSAKAAvqrZLZ0GJZ1OFkMQBYv+J6/LntciTJLFprexc2rJyB\nLWvr8T/netAz4E14T3zATfZAwKztWAzORESjLNn6a3S2MyAfyDetrkUoFIrL1tagefFU/O2y6ySv\nITqLWmwfs8Ppg0ckMAOxo+IF9VWKA26yBwIaxuBMRDTKlKy/hqeDlQTy/7WmHnfeUhuzz7lmUrmi\n2tLhZLF4SqehUwm4yR4IaBiDMxHRKEtl/VVpINfrNKixKu9hkKwymNJp6HQCrtQDAQ1jcCYiGmWp\nrL9mOpHK6fFh69uncOJcD+wDXlhK9WiorUTzwhpYSg0x3zuVUTEDbmYxOBMRjQGlgS9TiVThpLL4\nJK/ufg/2tAwVB6mIyxgXGxUDQHefO/JnTk9nB4MzEdEYSGU6eP2K6+F0+3HinB29Do9oIE82TR2f\nVCZGKmNcr9OgoswQyRjv7vfAIKgBqODxBliGMwsYnImIxlB4v7LSzk9LZ1fjnjV1KNZrJd/TWGfF\nw3c1Rs6TrJJYPCUZ49F9mlmGM/MYnImIRlH0CFerUaXc+em941fg8vrxd7fNgqlYkNxmVVwkYP3y\naQCSVxKLl0rGeDSxoE7pYXAmIsqAZNPKYiPcYoMOFzodkfco7fzU0t6F1vY/Y7K1BE63T/Q9+49f\nxicWT1FUSSye2aSH1xeI1OVWGtzjgzqlj8GZiGgElJbhFBvhSgXLZJ2fACAE4KJtUPJ4V68LZzv6\nMH1ymWxSmZhBtw9Pv3woci/rV0xXFNxZhjNzGJyJiEZASRnOVNd8ozs/KR3tJlAB33n9aCQDO7G3\nsx71U83QaVU4ftYO+4Abgk4DtzcQWU+OvhclwZ1lODOHwZmIKE1Ky3CmuuZbbtQDKhUaaitj+h+n\nIngtXyv+YUEsO9zjC8DW68K//eqoZC3tZ+9fhEAgiNZTXeh1eGEQhj7r9QVYhjMLGJyJiNKktHpX\nqmu+To8fT790EGaTgImWYlzucY74WsMPC2L0Og0ErRp2kQYWwNC9bH37FE6et6PP4YXZqMf8ukps\nWDkdDqeP+5yzgMGZiChNSqt3KV3zFbQqeP2hyOh1qNuTF3qdGh5fUPazydgH3Hh150mcPG8XXRuX\nuxdBp8G+qNaQdsdQ4RKNWsWtU1nC3eJERCny+ALotA+NZhvrrKLviV9/3bS6FsvmVMue1+sPib6u\nUqV5oVEEnRr7jl9Bd78HIQxPd7/+zikAww8Q4sSvq7W9Cx6feOcqGhmOnImIFBLLzJ4/sxKrF07G\nsVPdsmU4NWo17l1bj5Pn7Sknebm9QSyfU40j7TbRNWElfH7xkfd7H1zBnbfUQq/TYNPq2si6cp/D\nC0upAbOmluO9qFFzNG6dyh4GZyIihcQys9850oHmpho898BNSctwprqlKUytGirh+T/netIKztWW\nIlzpcYkec3uHksEmVhRj2+7TaDvTjT6HF+VGPRpqK7Bh5QyckHig4Nap7OG0NhGNC+Gp5mxNsybL\nzAYQad0oJRAMIhQKRTKdlQqGgE67SzJhKxmXxy//hlAo8uARnvYOryu/sfes4ql7yhyOnIkor8kV\nAckkpZnZcrbtPo13jqS+Ncpi0qOmypj2vue+QR/0WjU8IlPbBmGogpjcg8ez9y+K/DlZ60jKDAZn\nIsprckVAHr1nYca+z0j7KqdaiCRaSZEOGo0K9VPNMVnTSpUbBcybWYl3Wy8lHFs2txouj1/2wcPh\n9CnuoEWZweBMRHkr2VSz25tkOjcFep0GDTMqsEckwCmZ3k21EEm0C50OPP7DffB4AzAIGoRCoZS2\nVjXOrMTmNXXQadRoOWlDz4AHZSU6LKiz4p5bZ8IfkK5GFr8ljMlfoyPt4Lx9+3a8+eabka+PHz+O\n1tbWyNezZ8/GggULIl+/8sor0Gj4pEVEmZNsqtne78nICCQ8dd52phvAUIJWMDQ03bygXnoKPboZ\nRqqFSOKFE8FSTQibaCnG5jV10KjVQ9nYwRCOtneh1+FB25luaDSnsWl1rWSiGteVx0baf283btyI\njRs3AgAOHjyI//7v/445bjQa8eqrr47s6oiIZCSbajaX6jHQJ56lrEQ4uO48dCGmjGbw2rbfeTMr\nRYtwSK2Dz59Zmdaas5jwA4Icg6DGV/6uKdKAY9vu0zH3Eb0EEH7A4LpybsjItPYPf/hDvPDCC5k4\nFRGRYnJbkxrrKmEQtBgQ+Vwq7R27+z1QSxQBaTvdDc+qQMI5tr7dHjP9HQ6CtyyYhBpriWw3KRWk\nSn7EShaYAeBjDZNQrB/6b37A6cXh/+kUfV+4tCfXlXPHiINzW1sbJk6cCKs1NtXe6/XiscceQ0dH\nB9auXYvPfvazsucxm4uh1WbvL4LVasraufNBId9/Id87MP7v/+G7GlFcJGD/8cuw2V0wl+qxZM5E\nPLh+LoDY+w8Egnj5dx8OvbfXBWt5EZbMmYj71s2GRjO8s/TFNz6ICfhSgdA+4IZG0MFaWRI5/0/e\n+ADvHktclwaAAx9ehcsjPy2tJDADgLXcgAX1VXj70PlIk4toRXotHlg/F3pBi5++eRxvHzwHj1d8\nnTr+PmoUXkOuy+e/+yMOzjt27MCnPvWphNefeOIJ3HHHHVCpVNiyZQuampowd+5cyfPY7SMv7C7F\najXBZhN7fi4MhXz/hXzvQOHc/7qlUzEw6MFRXxfs/R4cOH4ZXq8fD9/ViJ6e4VHq1l3tMUG30+7C\nm3vPwunyxrR3fO+Ysqlns8mAgNcX+RnHnz9essAMDK1jz5tZibbT3TFtHOPNq63E8jnVeOvAedHz\neLx+/OWCHbuOXExa9CT+PsaDfPi7L/fwMOLgfODAAXz1q19NeP2ee+6J/HnJkiVob2+XDc5EROmS\nWkstLhKwfvk0ANlp7xidLDWSrVLRFtRbsbm5Dp5VQ1PvxmIBb+w9G7MWPG9mBUKhEP7tV0clR9qV\n5UUo0mvRclJ8KlvqPig3jKhC2NWrV1FSUgJBEGJeP3v2LB577DGEQiH4/X60tLRg5syZI7pQIiIx\nckFx//HLkYphSoqIAMNJZmLUqqEmFBWlBjQ31WDT6tpIZTKb3Zk0qMtVBrOY9JFzAsPblor1Wmxu\nrsNzD9yEbz64BF/7TBM8ngDeOdJxrWuVuCVzJg7tX05SVWzZnGomfeWgEY2cbTYbLBZL5Ouf/OQn\nWLRoERobG1FdXY0777wTarUaq1evRkNDw4gvlojGv2TJWvHkgq7N7sLZjj5Mn1yWkfaOK+dPwtrF\nU1Fm1EOrUSVkZAs6FTw+8bGsXqfGktlV+GPr5YRjy+dUY8vaetn71WpU2HXkIlpOdsoG3IprmeH3\nrZuNy1f7YTEJku+3lOpx79r6SDY35Y4RBec5c+bgpz/9aeTrBx98MPLnL33pSyM5NREVGLkynHLB\nQy7oqtTAC68fjZxr3sxK7BbZyiTW3hEQ31YUvpb49eVk+5c9viBuXVgDrUYje14p8ZXQxKgAPHpn\nA2qqTNBo1NDrNFhQXyX5uQV1Vk5n5yhWCCOinCBXhlNsL3GY3Eg3nMUcPtetCyejualGci9v9Kg9\nflsRAHT3uSN/Tmd9efeRDty7dlbS7UrxswceX0DR2rGl1ABrXAWvTatrEQyFsO+DK5HEMoOgwfK5\nnM7OZQzORDTmlCZrSYke6fb0u6GSKNBx9FQ3nnvgpoTgGAgGsXVXu+iovaLMkDCir59qTqsUZ9uZ\nHnh8AckymFKzB6saJyddOwbEE7s0ajW2rKnHxltqYbM7AZUK1vIijphzHIMzEY25kXZ80qjVkZHu\n2Y4+vPD6UdH39Qy4I2vQ0eeTG7UHAsGEgiL7jl+BRF0SWcnuRap4idcfkK0IJuhUWNEwSXYkrNdp\nUFOVv/t+Cw2DMxGNuZF2fArT6zSYPrlMeg0awHdePxpJmtq0uhb+QEhy1L637RK8EoU7lBYLiSZ1\nL0Mj91N496h48ZK2092yFcG8vhBUKhUTu8YR/iaJaMyF143FpLoHV+5c4QAXHpFu231adtTu8QbT\nCsJSpO4lvE9bKgD3ObwoNwriB69pbe+KbBsDALfXj067M+Y1yh8cORNRTshk44XwZ9rOdMPW64IK\n4lPCre1dWLds2oi6RSlRbhTQNKtK9F6UFC+xlBrQUFsRU2glXnjKPLxG3namGza7S3HWO+UWBmci\nygnR68YjbbwQPtfnNhTh4LEOfEdiDdo+4IbL45fM9s6EcqOAZ+9bDFOx+MhXSUWy6IeUd1vFR9jh\nKfN0s94pt/AxiohySjiTOZXAHK7SJTaFayrWoUKi4lc4oG1aXYtbF05GNgaWbq8fb+w9i8vdg6LX\nl6wi2c3zqrGqcTL8gRDuWlWLxTdMEH1vY10lAOktXvHT3pTbOHImorwltfXozlumY8cfz0amdvWC\neNSNXgMOBkOi3Z1SMVTeU4VA1NDW7R3K9t7TeikmES08xSy3T3tSZQk+/Isde49dgV7QAAjB7Q3C\nIKgBqOD1BWKm/7v73CPKeqfcweBMRHlLagr35PleXOh0RF53X8u41qiBwLUAbBA0CIVCCASDQxnb\np7pGdC2L6q24Z00dnvv5Ick9yVJTzGLr7cUGbdw9RCd7Dd3EsjnVuDeq7Gemst5p7DE4E1Fekkuk\n6rA5RF8PRI2M3d6h5hHBENBUZ0WvI3mRDznGEgFeXwB2BcVC4gurxK+3F+m1+Porh5Ke5+T53piv\n5Ubh7DyVXxiciSjjUm1ekY6efrdkhrXcnuB477Z2yGZBp3Ieh9MLs0yjiTD7gBu2XhcErTrmZxRe\nb+9U0OEKGPoZxE9VR2eqd/W6RpT1TmOHwZmIMibd5hXp2HVEOrtarppWvFQCebLzHDphg0bBbQo6\nDb63rRV2hw8Wk4AF9VUxPyO56eloekGTMFUdnal+5q/dWX1AouxhtjYRZUx4Dbi734MQYot9ZJLH\nF0Dbaek14urKsUt6CihIKnN7A7A7fACAngEvdh2+iNfeORXJOgcgWUhFKYOgTTnrnXIHR85ElBEj\nbV6RCrkpbQCwlhlwyebMyPfKJLUKQAgQi99/bO3A0XYb7ANeWEr1mDezErcunIyWk12wO8Tv1Xtt\n+YAZ2OMPR85ElBFKmldkytuHz0seU6uAY6d7Mva9MikoEZiBofaWPQPeyIzD7iMdUKlUeOa+RZKl\nO5mBPX4xOBNRRsgV05ALIh5fABc7B3DR5lBUJMPjC2D/h9K9jTO1hpwLWtu7IOg0aJpVJXqcGdjj\nF6e1iQjAyDOsU93GEwgG8do7p7Dvg8uRfbsGQYPlc6tx960zJRPIbL2umD2/41l4xiGTdccpPzA4\nExW4TGZYpxJEtu0+jd1HYrcwhfceq1QqbG6uE39gCOXe0DiV7PBUhGccMll3nPIDgzNRgctkowSl\nQcTjC6DlpPTUdGu7DYFAEG1nuhMeGKzmYhgEdWS0nQuyNZUeP+MQ3gdN4x/XnIkKWLIMa6WNEuIb\nTyRrXtHn8MgW6uju92BP6yXRLVl6nQY33Sje/GG8ELRqNDfVcNq6gHHkTFTAlGRYy43U0p0SL9Jr\nIWhU8AbEh5xS08QtJ20IBEM4fjY3s7GTUauGZuXNJj2cHr/o2rleq8a3/mEpypmFXdAYnIkKWCqN\nEsTWf1OdEo8O5lKBGZCeJu4Z8GSk1OZYWXzDBKxfcT3KjHr8+t0zoslzK+ZPYmAmBmeiQqYkw1pq\ndLx+xfUpFx2JD+ZSykp06Bv0JbyercSr0WAQNNiyth7F+qH/dpmBTXIYnIkKXLIgITU6drr9KU2J\ny61vR6soNWDuDDP+2Ho54Vi+BmYA+FjDxEhgBpQnz1FhYnAmKnByQUIuoJ44Z1c0JR6eDvf6Aoo6\nLTXWVSKYg9ul0lVRKj8iZgY2iWFwJiIA4kFCLmGs1+HB4hsmoPujqwnH6qeWIxAMYeuu9pjpcL3M\nFihTkRY3Xl+B25dch+d/cXjkN5QDyo0CvvaZJpiKxctvEklJOzgfOHAAjz76KGbOnAkAqKurw7/8\ny79Eju/btw/f/e53odFocPPNN+Ohhx4a+dUS0aiSSxgTdBq0X+wFMLwWrNepgJAK7x+/gpb2zphA\nnKz94YDLjwMfXR1KFvPnzh7mkegf9MLl8TM4U8pGNHJevHgxfvCDH4gee+655/DSSy9hwoQJ2LJl\nC9auXYvaWiY6EOWD6MxsqYQxtzcQ2QoUXgv2+EIAQteOiwdYg6BGKAR4fOLHMxGY9Tq15PnV6qEm\nE+kyCBrF5UPZmILSlZVp7QsXLqCsrAwTJ04EAKxcuRLvv/8+gzNRjhPLzJ4/sxKrF07GsVPdsA+4\nUW6U3qOrRCYre5mNAubXWdF2ujsmmW353Go8+zPxqfFQEBC0Knj9qa1rq1XAyvmToFar8M4RZdu5\nGmZYmORFaRlRcD59+jT+4R/+AX19fXj44YexfPlyAIDNZoPFYom8z2Kx4MKFCyO7UiLKOrHM7HeO\ndKC5qQbPPXDTUGKXP4inXzo4hlc5zHXtAeHhDXOgUathLS+CVqPC1rfbJbddhQD4ZfZYSwmGgOam\nKagyF0GlUqG1vQs9A27ZUt/NTVNS/j5EwAiC87Rp0/Dwww/jE5/4BC5cuIBPf/rTeOuttyAI6a2t\nmM3F0Gqz94RptZqydu58UMj3X8j3Dii/f6fLiz+3JW5fAoDWU134zLo5qJlUDrfXD6u5CJ12VyYv\nMy1ubwB7Wjqwp6UDVnMR5s6ohKBTY0/rJdnPpbsl670Pr+LzG+bh0XsWwu3140r3IL7+0/2w9boT\n3qtWD73/wfVzodGMTaVk/t3P3/tPOzhPmDABt99+OwBg6tSpqKysxNWrVzFlyhRUVVWhq6sr8t6r\nV6+iqkq8H2mY3e5M91KSslpNsNkGsnb+XFfI91/I9w4ov/9AMIiv/fSg5FR1d58bD39nN5pmVWHT\n6lo0zKhQVExkNNnsLuw+nN0Zuv0fXMa6pddFpqpLtGrMq60U/VkEg8Af9v0VXq8/5QYimcC/+7l/\n/3IPD2k/zr355pt46aWXAAxNY3d3d2PChKFi9DU1NXA4HLh48SL8fj/27NkTmfImotwQ3axi665T\nuNwj/4Dc6/BGmk9sWl2L5qYaVJQaoFYNJUllmirjZxy5ngEP+hyxWeebVtdiVeMkqCUuOJUGIkRh\naY+cV69ejccffxzvvPMOfD4fnnnmGfz+97+HyWTCmjVr8Mwzz+Cxxx4DANx+++24/vrrM3bRRJS+\n+KQvs0nAoMuv+PPh0pzRhUuMxTq8sfcvOHyiE70O6W5TqcjFMiRq1VDTjmgatRprF0/FHyWm0pU0\nECGKl3ZwNhqN+PGPfyx5fNGiRdi2bVu6pyeiLIlP+pJr3SgmOthEFy7Z3FyHdcum4emXD2YsQOea\nYAii+5ZTaSBCpAT7ORMVEKX1reWIdasKT4+bigUUG8Zv4cGKUr1ooA03EBETbiBClIrx+6+IiBLI\nleNUSq5bVZFei0td2UvuHGuNdVbJQMsuU5RJDM5EBaTMqIfZJIhOZet1ahiLdOgZ8KC8RI95MysQ\nCoVw7HQ3+hxeWEqTd6sCRhb4c9nK+ZNkAy27TFEmMTgTFRC9ToOSIvHgXGUuxlP3LoxJ8Gptt6HP\n4UW5UY+G2gpsWl0LjVqdkenxfPOJm6ZCo06+EsguU5QJDM5EBcTjC8Dp9okec7p9sNmdsJqL8et3\nz8SMiu0OD/a0dECjVmFzc11GpsfzidmoY1IXjSoGZ6ICIhdUu/s9+NrLh2AxCXB6xPfltrZ3Yf2K\n6fjDgXNQqSBbunI8MRbLT1FHNwrhVDZlAoMzUQGR2/ITJre1yj7gxvM/P5y0YMl4M+jyweMLJARe\nsaS4xjprZPqfKF3820NUQOS2/Cih1agKLjADQK8jsTIYMJwU193vQQhDsw/hKmpEI8HgTDTORO87\nFhNdelOVYo1Mf7odI/KcWCERuaQ4luykkeK0NlEOS2UtU2yKdfm8yVi3NDbLOLzlZ/2K6/HLnSdx\n4KNOxaUyg5lrxZxXxAqJyK3fs2QnjRSDM1EOSmctU2zf8Zt7z8Lp8op2RXpj71+w/6POrN1DPlOr\nhmp7W2QKibBkJ2UTgzNRDhILtOGvxQJtsinWDStnxIz8CnGfcipWzp+EtYunxsxYDDi9uNjpQE2V\nEaZiIbJ+L9YukiU7aaQYnIlyTKqBFkg+xWrrdUHQqiPBps/hkc3YLiRTqoxwuv0JJTfDMxRevx/P\n/6IFHTYHgqGhUfVkqxFf+fQCluykrGFwJsox6axlyk2xCjoN/u1XR2Ef8Eamx29fMhVq1VCXpZEQ\ntGp4/bm/EF1jLcGDd8zGntYOtJ3uTgik/kBIcm3/+V+04EKnI/J1MARc6HTg+V+04Nn7FrNkJ2UF\ngzNRjklnLVNuitXtDcDtHcocDk+PO93+EQdmAPjnTQ048FEn3j16KSPny4abGybi3tvqoVGrce/H\n6+FZlZhkp1FDNHlrwOlFh82R8DoAdNgcGHB6I1PcTP6iTOJWKqIck077wUAwiFAoBIMwfMwgqGEQ\nxP+Jnzhnh8UkiB5LxSt/OIlFCLMMAAAgAElEQVSb508as0phgjb5XjCNJvY94UCqZIR7sdMh+dAR\nDA0dJ8oGBmeiHBS9F1mtAipKDWhuqpFcy9y2+zTeOdIRGSEDgNsbhNsrPuXc6/DghussI77OK3YX\n/s8vWyDoUtwwnQHlJTo01FZCp5X/b2xP66W0i4LUVBmhlrg1tWroOFE2cFqbKAeF9yKvWzYtJkM4\nnscXgM3uTDnzWtBpcM+aOpzvdMSsp6bD4xubNefeQR8On1B231KJdMmYigVMthpFf0aTreK/E6JM\nYHAmyiHhoiPRLRvF9jlH74NOJ+s6FArhavcgBl3SdbTHk5EUBfnKpxdIZmsTZQuDM1EOiC86ohc0\nMVPU8fuc4/dBp8rjC+Ibvzgy4uvOFyMpCiJotXj2vsUJ+5yJsolrzkQ5YOvb7TENFKIDc7TW9i4M\nOL0sIJKiTBQFMRULuGGahYGZRgVHzkRjKBAMYuuuU3j36CVF7+/pd+Nip0NyH3Suq7YUodPuyvq2\nK/W1XtNWcxEaZlSwKAjlHQZnojHi8QXwy50n8d7xK4o/oxc0qKkyJu3JnItunj8RaxdNRZFei1f/\nvxM4cd4OlzeYkWIo8UIAHr97PhbPm4yBPldmT040ChiciUZZeH35yImrsDt8KX46hDf2nsWgO9XP\njb332i7jT0cvJ7yejVF0eYke0yeXwSBoMZD50xNlHdeciUZZOJkr9cA8tHd5T+ulhP3LGjWwsnEi\nKkpztxNSIMmOK71O+X9HS26sgtmokzw+n40nKM8xOBMl4fEF0Gl3wuMLiH6d6rmykcwVCAJqlVqy\nslg+ULJf2iBo0NxUg/s/eSMWzpog+p4pVUZsbp6Z6csjGlWc1iaSEL+9yWwSUFIkwOn2Ke6xHC+b\n3aBaT9owt9aSN80o4pmKdBB0atmfT4lBiw0rZ0CjVsd0hOrpd6PMKKBxZiU2r6lT/PsgylUMzkQS\n4vcS9wx40TMwXLQjWY/leIFgEDsPXchKAhQA9A56sfeY8uSyXNNYXwlBq5Hdv20f8ESKiYSrqLEj\nFI1HIwrO3/72t3HkyBH4/X587nOfw8c//vHIsdWrV6O6uhoazdA/lhdeeAETJohPQxHlmlSmn5WW\nhty2+zT2tHSM6LoErQpef462f7rm5saJOHuxHxdtg4o/o9WocO/H6wEAgWAI77Z2iD7AiBUTYUco\nGo/SDs779+/HqVOnsG3bNtjtdnzqU5+KCc4A8OKLL6KkpGTEF0k02uR6KsdTUhoyU2vNC+qtOH6m\nBw63f8TnyhatSoWnP7sIW99uR+upLvQ5vBB0atk15dJiHfyBELQaFTRqFXRa8fdnopgIUT5IOzgv\nWrQIDQ0NAIDS0lK4XC4EAoHISJkon8n1VI6npDSkXLBXqQBTkYB+p3yda4OggVarzunADADvHb+C\njatm4t61s3DX6qFa4V5fAF97+ZDkZ+wOL/ocHuw6clF0WnsoG30yi4lQwUg7OGs0GhQXD40UduzY\ngZtvvjkhMD/99NPo6OjAwoUL8dhjj0Glkm4rZzYXQ6vNXmC3Wk1ZO3c+KOT7T/fel8+bjDf3nlXw\nvkmomVQu+x5TWRGs5qHqWAnXV16EuuvK8WeRPcDRqiuKse+D3F9T9niD8KtUqCwrwmC3EyUmA2pM\nBljLDbD1ukU/Yy0vQs2kcrT96pjo8UAQMOh1qJ5QlvL18O9+4crn+x9xQtiuXbuwY8cOvPzyyzGv\nf+ELX8CKFStQVlaGhx56CDt37sRtt90meR673TnSS5FktZpgsxVuKYJCvv+R3Pu6pVPhdHnR2t4F\n+4Ab5UY9Sop0cLp9sA94YDYZ0FhXiXVLp8JmG4h0lJJKTGqYUSE6Kuwf9CQNzBMtxfjr5fz5Hf7i\nDx/hg9Ndkf3YBkGDynKD5PsbZlTg4qVe0YeXsPfbLmPd0utSmtbm3/3CvHcgP+5f7uFhRMF57969\n+PGPf4yf/vSnMJliv8n69esjf7755pvR3t4uG5yJco1GrcaGlTNwc8NEQKWCtbwIep0mIQgP1cdu\nl2zvGHbnLdNx8nxvpPVgWHxBkXhmowCvP/U91WPp0EedMV+7vQFc7BxEtaUI9gFPZD3ZIGiwfG41\nNq2uhT8QQrlRQK9DfHq/d9CTdttHonyTdnAeGBjAt7/9bbzyyisoLy9POPbFL34RP/rRjyAIAg4d\nOoS1a9eO+GKJRkv8HufogBufHRy/5Sp+i1U4mO88eB4XOh0pX8sN0yx4P4X627nsSo8LZpMejTPL\nsPam61BtKY6MhDVqoHFmJfa0ijcBsYyg7SNRvkk7OP/hD3+A3W7HF7/4xchrN910E+rr67FmzRrc\nfPPN2LRpE/R6PW688UaOmimvSAXcQDAU2fIDAANOLw6f6BQ7BVrbbQgEgmg7042efg9kUi4kLZ9T\njXvWzMTJ8/a8a3QhxT7gwf6POqFRq7FlbX3Msc1r6nC6o1/0IYaZ2lRIVKFQKCc2TWZzbSAf1h6y\nqZDvP5179/gC+OqL+0WDoVoFLLphAjavmYk39v4FR050ot+ZnSYUpmItHr2zAZOtppS7V+ULi0nA\ngvqqmCWAcBvNo+1d6B30wHJtbT+VSmxh/LtfmPcO5Mf9Z23NmSjfhaeci/RauDx+lBn1stuegiHg\nwEdXceCjq0nPrbrWUzhdA04/nvtFCwyCGotvqIJBUCddn843PQPehCprGrUad62qxarGyUAoBKu5\nmCNmKjgMzlSQwmvKLSc70TPgjZTUrCjVY870CpSVCOgdlN93nEym5qTc3iD+dOwKplQZ01qzzgfh\nKmtajUpyrZ/1sqmQMDhTQYpfUw5nT3f3e/DuUfGEpLE2MDg+1pzFhKusxRchSbV+OdF4wUdRKjjZ\natuYbb2D2VnbzgVmkwFFeq3k76W1vSutFp1E+YrBmcalcM9ltzex1GUqdbMzqdwoYGXjJOi16f2z\ny9VlV4Mw8v9GGusq4fL4JX8v4ZE1UaHgtDaNK/H7k63mIjTMqIhZs0ylbnaYkh7J1RVF6OxxiXZT\nUqmAL909H5ayIrSf68XlntQr4uXqwLGi1ICOLvn7MQgaeH0BmE16FBt0GHT50OsYrrIWLkIi9XtR\nUr+caDxhcKZxJX4tudPuSliz1Os0aKyzyvYNjmYx6THnegsOfHQVnmsBWqMeanPo8YWgAhAC4HL7\nJfs0h0LAsz8/BBVUst2Z8pHD5cOqBZPRdrob3f3uoZkBNeDzBSPBd/2K6XA4vZGqamKlTjVqSP5e\nuMeZCg2DM40bcmvJ8T2Xw92NWk7a0DPgiWRriykp0uFPbbG1rwNBYIKlCJe6nAh/rC/JmrDXFwKQ\nE2UFMqp/0Ie1i6bgrlW1kYALICH4FuuH/7uR6sEc/r2E65lHj6yJCgmDM40bcmvJ8T2XNWo1NjfX\nYcPKGZF9zg6XD7sOX0DbmZ5IYGiYYUHbmW7Rc15KMpVbKCylhkgQjg646dTAjv+9SDURIRrvGJxp\n3JBbS5Zas4wOKKZiAfeunRUz5WqzOyVrPcspLdZlrXJYrsnGlLPUyJqoUDBbm8aN8FqymPgAEs7m\njt+eEw7MxmIBv373DL63vS3l67CY9Nh068yUP5ePjEVaTjkTZQFHzjSuxK9ZVpYPZ2sD0t2m7rxl\nOnb88Sxa223o7vcoys6W4nB58dPffZSxe8plTrcfTrcfpmJhrC+FaFxhcKZxJX7Ncsa0Cgz0uSLH\npbpNnTzfG1MaM93APPTZ8Zf0JSUYAi52OnDDNMtYXwrRuMJpbRqXwmuWBmH4+VMum3u81qzONrUK\nqKkyjvVlEI07DM5UMMaqMthYaaqvhEFQlqhVbhSgVg01/qixlkCvG/6vwSBoYCwSn2SbbDVySpso\nCzitTQWjzKiH2SSgZ2Bk3abyhVarwavP3ob/OW2D1+fHd147KloAxSBo8Ox9iyMtM8NFQmx2J6BS\nwVpeBJUqhOd/0YIOmwPB0NCIebLViK98esEY3BnR+MfgTAVDr9Ng1nUW7Dt+ZawvZVS0n+8FANRY\njRhwehGSqrICQNBpYkbAep0GNVWxjeCfvW8xBpxeXOx0oKaKI2aibGJwpoKyec1MtLTb4PbmaKHq\nDOp1eNDV68L2XSfx52OX4Q2IB2ePNxBToEWOqVhg8hfRKOCaM41LUl2p9DoNKssMY3RVo8tsMuB3\ne89i95EO2ezzcIUvIsodHDmTLLEGBbkm+hq1GpVsV6ptu0/jom0w49dQXiKgpEiLrj53zjS2mD3d\njP0fJK9uJlbhKx9+70TjGYMziZIq1hHdenGsiV1jsUEXsy0quivVhpUzJLdSjVTvoBcurz9nArOx\nSIu2093odcgnvy2bUx1T4Ssffu9EhYDBmURJFesAhlsvjjWxa5Tq0dza3oWlN07I6laqXAnMAOBw\n+ZO/CYBOp4r5Oh9+70SFgI/ClCBZ68X4etRjQe4axXT3u/GDX7eNw4aNI/Nu62Vs230aQH783okK\nBYMzJVDSenGspVNQJFm/5ULVctIWWWPO9d87UaFgcKYE4daLYqRaL442uWuk1NgHPJHkr1z/vRMV\nCgZnSpBK68WxIneNwFDVK7UKqCiQbVMjYTbpI1nZuf57JyoUDM4katPqWjQ31aCi1HCt5rIBzU01\nOdW7d/2K6yVrR5cYtHjms4vw/X++BRUcYctaUG+NBN58+L0TFYK0s7W/+c1v4tixY1CpVHjqqafQ\n0NAQObZv3z5897vfhUajwc0334yHHnooIxdLoye+9WIu7nd1OH3wSFT6Cq+dlhn1mDXVjPcKpGQn\nANRPLcPJ831J32cQNFg2N3YrVT783okKQVrB+eDBgzh37hy2bduGM2fO4KmnnsK2bdsix5977jm8\n9NJLmDBhArZs2YK1a9eitpZP3vko3HpxrMgVwwivkYptnwoB+P6ONiw/1Y31N08vmOCsVgGf/cQN\n+MqL+xEQ2dmlF9T40j2NEDRqWM3FkoF3rH/vRIUureD8/vvvo7m5GQAwY8YM9PX1weFwwGg04sKF\nCygrK8PEiRMBACtXrsT777/P4EwpUVIMI7xGGr0vN1p3vwdv7j2LQx8WRmAGhjpFVZmLcUvjZLxz\npCPh+MfmTsT0iWVjcGVElIq0gnNXVxdmz54d+dpiscBms8FoNMJms8FiscQcu3DhQtJzms3F0Gqz\nN31mtZqSv2kcy7f7f/GND0SLYRQXCXhg/dzI6w/f1YjiIgHvf3AJtl636Lku9zizfr1jTa0GplWX\n4juPrIAgaPHIpgUoKdZj//HLsPW6YC0vwpI5E3HfutnQaAor1STf/u5nUiHfO5Df95+RCmGh0MhL\nO9jt2fsP1Go1wWYbyNr5c12u33/81LXT48NbB86Jvve9Y5fwicVTYqZjP7F4CqZXG/Fv29tG65Jz\nyv23z0JDbSVMxQL6+lyR19cvn4Z7b78BZ/7aHfnZ9vRkvq54Lsv1v/vZVMj3DuTH/cs9PKQVnKuq\nqtDV1RX5urOzE1arVfTY1atXUVVVlc63oXFOaura4fZJtnQMF8OoMhcnfL4QVZQa0HTDBMm1Y4Og\n5doxUR5Ka35r+fLl2LlzJwDgww8/RFVVFYxGIwCgpqYGDocDFy9ehN/vx549e7B8+fLMXTGNG+E6\nzt39HoQwPHXderJT8jPlJj28vgA8vkDC5wsR9x8TjU9pjZwXLFiA2bNn4+6774ZKpcLTTz+N3/zm\nNzCZTFizZg2eeeYZPPbYYwCA22+/Hddff31GL5ryn1wdZ49POtQ6nD48/fIhmE0CnJ7Cq/WsUSOS\nhW0Q1AiGQggEg+wYRTTOpL3m/Pjjj8d8PWvWrMifFy1aFLO1igqT3DaodGpjA4DXPxSZegbkWyGO\nFwZBA68vALPJgCK9JqYXtdsbxO4jHVCrVOwYRTTOsGUkZZySbVBye5QNgkZyzXm8U6uG9mhbTAY0\n1lVi/Yrp6HN4sPPQefz52GXRz7S2d2HDyhmc3iYaRxicKeOU9ASW26NcWW5AV687EqD1WjU8/uz3\nSv7nuxrQM+DBG386g95BZf2QM23l/ElYu3hqzGzDG3vP4k9HxQMzEJskR0TjA4MzZZTcWvKREzas\nWzYNpmIBACJlI1vbu2AfcMNsMqDYoMWFTkfsOf1BGAQ13N7EAG0QNCjSa2Af4TS3SgW88t8n0DPg\nRaa3AWvUgFajhscn/4BhEDTYcEstivXD/yyV9K1mxyii8YfBmTJKtieww4OnXz6IpllVkSnu6DrO\nRXotvv7KIYkzq0RftZYXYVJlMQ58JJ3hrUQoNLyOLVb2MhUq1dD5wl2xnti8AIJWjadfPoheh/RD\nhNcXgMPpjQnOStbmmbFNNP4wxZMyKlmf5V6HF7sOX8R/vH0y8lq4jrPL45cMRF5fANWWooTXL3Q6\ncPRUl8gnxk64Jk8wBNh63fiXn+7H7/b9FQvrpVtcAuIjYLmfp1oFrGqcxI5RROMQgzNlVLI+y2F/\nbL2MV986iUBweJhqLNZBL9ECUqdV42qPS/RYsuniseb2BrHr8EWEADQ31Ui2uWyYYUGfwwOPbzgZ\nTu7nubJxMu5dO4vbqIjGIU5rU0Z5fAGsapyMQDCElnYb+mSmcfe0dECjHt4G9Js/nZXM0s71AKzE\nsVPdeO6Bm7B+xfXY+vYpnDhnR6/Dg3KjHiVFOrSd6cYfWy8lZLeLrc031lVyxEw0jjE4U0ZEb5/q\n7vdA0Krg9Sev2xXeBgQA+z6QzkiWky9br6Kzqv/+kzdG9oHvPHQBe1qGO0jFZ7ezxzJR4eF8GGVE\ndClNAIoCMzAcsGx2p2g2thJL50xAjbUkrc+Opvg1Zb1OgzKjHm2nxdfMW9u7Eqa4q2R6MBPR+MHg\nTCOmZLuPlEjAUolnYyvh8wfh8ozNvuRUiGVVy2a3X3twIaLCw+BMafH4Aui0OyNTs+l2hQoHLGt5\nETRq8QCdLGy/13ZFtNKYEjoNUG4U0vqsHL1ODUE7fOUGQYPQtTrY0eSysbl/mahwcc2ZUiJWmrOh\nthJmk6Co3nV8ecropCadVoWAN3E6XC+oMb/Wiv0fXRU950g6UvkCkN17nK4qc3FMMRW3N4B3jnRA\nFVcHW6tRodigE3244P5losLF4EwpESvNuaelA1OqjIqC88caqnHTDdWoqTJGKoUBQ9O7UmvObm8Q\nK+dPwoGPruZ0a0iVauiho6G2AsdOiU/zRyfA9Tk82HnwfEJFNACYUmVkNjZRAWNwJsXk1padbh9W\nNU7C+x9eFc2cVquBSRUl+PAvduw9diVhu5CxWCebdf3j3x5XFJjD1blGm8WkxxfvmoeyEgEXOx0x\n2dfRevrd+OXOkzhx3o6efo/kUrvT7Yc/EMp4KVEiyg8MzqSYfPKSB2sXT8WGW2rx2tvtQ8FnwIOy\nEgGzppZDr9fi3dZLkffHbxd6Y+9fZLdD9Q36FF3jWARmAJhfV4k/HbsUme5Xq4YqhMXT6dR47/iV\nyNdS18tmFkSFjcGZFJNr8xhOXtLrNLj/2h5eW68LCIVQZtRL1sxube/CumXT0HJyZLWxR5v62gjd\nUjq0dh4KhWKm+6WCrldhMRUmgxEVNgZnUkyuzWN08lIgGMSv3z0TGUWWGQXJpCv7gBsXOx2K1qtz\nycrGyVg1fxKgUqGsRJB8+JAaQSfDZDCiwsbgTEmFt0uVGfWKSknGJ43JZUObTQZUmYvSDmIjoVYD\nQZmBrFoFlJYMPViEr89i0mN+XSVUAL6/oy3pw0cIQGmxgH6n/MNH/EicyWBEhY3BmSSJbZsKJ3FJ\nlZJMtSBJY10lXB7/qAdmYCgw11SVoMM2KDoNfcuCydh4S22knaXL40eZUY9fv3tG8cOHqViH/sHk\nswIrGydj7aIpLM1JRAAYnEmG2Lap6CQusWSlPodHtiCI2ahH36AHZpMBc6ab4XT78b3tbZm/eIVc\n7gD+n4eW4Ve7T+PDv3RjwBWA2ShgYVTP6fB9moqFlB8++gd9srMCFpMeC+qHs9aJiAAGZ5IgF4TC\ne3XjR3iBYBA7D12QDEYVpQZ87TNNcLh82HXkIt4/fjntetqZYh9ww+UJwFgsQNBpoXIFoJaoVAbI\nZ6wDQw8f9riSm1KBefmcamxZW8+RMhEl4KM6iVJa8zm6jOe23aexp6VDMhg11lXCVCxgT2sH9rR0\njHlgBoCyEj12HjofadoRwvAMwbbdp2PeGwgGsfPgecm9yRWlBjx17wLJcqBq1VAp0opSA5qbavCZ\n22cxMBORKI6cCUBs0le4W5LctiljsYCtu9qHM7JLBLi80s0nJltLsGl17YiaZGSD3eHBn4+Jt6qM\nnyHYtvs09kTt1Y7XWFeJQDAk2cM6BODxu+dj+uQyBmUiksXgXODkkr6ktk0V6TX4j7dO4v0Ph2td\n9yZJehp0+eDxBbD17VNpN6nIFqmRfnQhELmHCrVqKKFr0+pa+AMhyYcai8nAwExEijA457H40W46\n5JK+Nq2uxcnzvQm1ny/aBnHRNpjS9+lzeLH17VPYF1UdK9eVG/WRQiBy0/yhELB20RRo1Gpo1FC0\nF5yISA6Dcx6SG+2mkvGbLOlr3bJpcLqVlc1Mxlyqx4lzPSM6h0EYCmweXwAqZH9fdEmRLhJM5ab5\nLaWx1byU7AUnIpLD4JyHkm1xUipZ0tfFTkfafZrjOZw+eP3KE8AMggZeXwDma12emhfWwFJqiFz3\nHw6cw5+Oiq8VZ4rTPTQVr9dpFFdHAwCNWo3NzXWSe8GJiJJJKzj7/X585Stfwfnz5xEIBPDEE0+g\nqakp5j2zZ8/GggULIl+/8sor0Gj4H9RIpbPFSUqypK+aKqPkcTla9VBWcnQZaaWBWaUCbmmcjA0r\nZ8Dh9IoGtipzMdYumpr14Gwf8MQ0n0h1RKzXadi4gojSklZw/u1vf4uioiK89tprOHXqFL785S9j\nx44dMe8xGo149dVXM3KRNEzJFielASHZaFDQaVA/1ZzyOnEKA+QEEyxFuPfj9QCAYr30X09LqQEG\nQZ32dixBq4I/EJKdGo9vPpFsRJyJHAAiIiDN4HzHHXfgk5/8JADAYrGgt7c3oxdF0pR0hkqF2Ghw\n/swKBEMhfPXF/ejp90TWeuVaOmaKvd+DAac3UipTPshJFwvR69TwyHSA8vmTL1hLJXDFj4gzlQNA\nRBSWVnDW6XSRP//85z+PBOpoXq8Xjz32GDo6OrB27Vp89rOfTf8qKSKVtU8lxEaDv373DN6JOn84\nKC+bUw29To22Mz2wD7gh6DIftD2+IJ5+6SD6Br2yQa7P4YFH4vuqVEBTfVVM3+R4ZpMeKhVEH3LU\nKmDl/EmKE7gylQNARBSWNDhv374d27dvj3ntkUcewYoVK/Af//Ef+PDDD/HjH/844XNPPPEE7rjj\nDqhUKmzZsgVNTU2YO3eu5Pcxm4uh1WZvKtBqNWXt3KPt4bsaUVwkYP/xy+jqdaGyvAhL5kzEfetm\nQ6MRH6kpuf8aAG6vH21nukWPn+7oww+fWA0AuNLtRCAYwH/vO4cDx6+g15G5vcvhPdPhIFdcJOCB\n9bF/d0xlRbCai9BpdyV83lpehEfubkTFzpN4++A5uDyJQfxj8ycDAN7cezbh2G1Lp+HzG+Ypula5\nn1fbmW58bkMRDMLY5l2Op7/76Sjk+y/kewfy+/6T/q+xceNGbNy4MeH17du3Y/fu3fj3f//3mJF0\n2D333BP585IlS9De3i4bnO12p9JrTpnVaoLNNpC18482jy+AZTdW4dbGSTHTvz094nuPU7n/TrsT\nNpGABwBdvS60n+3CntYOtLbb0i4mkmp7yPeOXcInFk9JmBWYO92Cd450JLx/7nQLnA4P1i+fhs1r\n6/H/vt6KE+ftsA94Iklc65ZOBQA4Xd6EBK9PfWxaxn5eZ/7aPaZJYePt736qCvn+C/negfy4f7mH\nh7Qe6S9cuIDXX38dv/zlL6HXJ65xnj17Fj/84Q/xwgsvIBAIoKWlBbfddls634qiiK1tNsyoQHPT\nFFhKDRlJQkq2pr3r8AXZEpZKTLYaEwqbyOnpF090k4rv/mAQnXYnyox6WIsE3P/JGyWTtUa65SnT\nOQBERECawXn79u3o7e3Fgw8+GHntpZdewiuvvIJFixahsbER1dXVuPPOO6FWq7F69Wo0NDRk7KIL\nldja5p7WS9jTegkVGUpC0us0aKitxJ4WkRHpDEtMyc50LJ9TjXtvq8P2PWfw3gdXIuvVekENny8o\nOqLWC5qEIOfxBXBUYkvZ3qOX8afWy7CU6rF83mSsWzpVdlvTSLY8ZToHgIgIAFShkFib+dGXzemH\nfJjeAOS34nh8AXz1xf1Jp5Kbm2oSkpCU3n94ZN5yshM9A97I9HNFqR6zpprhDwZx4KPOpOdRqYZK\nWsazmPR4/sElkXvz+AKw9boQCASx52iH5L5lg6DB9x75WORzgWAQL/3+I+xXcC2A+M8kk4ZnNBL3\nP491tna+/N3PlkK+/0K+dyA/7j/j09qUWUq24iTrIxyWaiGSaPEj8/Ao1uHy4r3jV2Q2LsWqNhfj\nck9iDsGCemvMdel1GtRYjdi6q122oIj32kNLlbkYgWAQX3/lcErT4iP5mSjBimBElGnchJkDwkFR\nrp9weG0zmehey6mQqzzm8Q1F6WRTLAZBA4OgwZUeZ+TP4f7FqxZMxqrGyfD4YjOnlbSQjF673fp2\ne0qBGUj/Z5Kq8PQ4AzMRjRSD8xhLVo4zHMzCa5vJpJuEpHRkLkavU+OmG6vg9gbg9gYQAiJ/Xjqn\nGg0zLGg73YWvvngAX31xP7buakcgGFT8fcNrtx5fAK2nulK+PiZmEVG+4bT2GEulHGd0Na/ufrfo\nZ9JNQpLLOk5m2Zxqyb2+Le22mCIl8QU65L5vdJ9kYOhn1euQ7xstholZRJRvOHIeY3LT1VK1nZ97\n4CY8/8BNWLVgMipKDVCrhqaOm5tq0m5LqHRkDgwFTVXU92xumiL5gCFVPSw8KyD3fVfOn4R7P14f\nWXcv0mtRbhQUXWNYkXhq+nIAABFXSURBVF6D9Sump/QZIqKxxpHzGEtnK45ep8HEihLc+/F6eFZl\nrtlCfJ1tQacRDa4r50/C2sVTI9/T4wukPOruiZoVSNbtKTphLtWRs8cbgMPplW2iQUSUa/g/Vg5I\ntRVhtEy2JYzOOu7pd+Otw+dx4MOrkc5PBkGDJbOr0Nw0JeFhYNZUs2wt63gqADsPnsfmNXVJs53j\ns8hTUVlexPVmIso7DM45YKRbcTLdqlCv02BPawfebY3d3uT2BrD/w068e63Ax/yZlQgBOHaqC939\nHhgENQAVvL6A5Kg7LBgC9rRegkajjuxBFnvQkEuY0+vUKDFo0evwSn6/JXMmcr2ZiPIOg3MOSXUU\nnK1WhXIBMRwAu/s9CXWtwyPsJbMnoP28XVG3qmR7kOUS5nz+IL5413wIWjWMxQLe2Hs2YfbhvnWz\nJWuOExHlKgbnPJatVoUj2VYFACfP9cKucG04PiM9nnztaj2s5UWRwC42+yDVpYuIKJfxf648pXR/\ntNznO+1O0fcpLXgixZ5CwY9ke5D1Og2KDYldzwCg2KBLGHGzEAgRjQccOecpudFtd78bPf1uTKwo\nSTgWPxVebtRjfl0lNjfPjEyFy2WQK5FKS8hke5A9vgAGXeKj8EGXL7Idi4hoPOHIOU8lG93uOnxB\n9PX4UqF2hwd7Wjrw9VcOR6p2AUMZ5M1NNZF91AZBeQBUEpjVKmDVgslJM9L7HB7YB8SDc6/DMypl\nOYmIRhtHzjkqWQa2XGtHAGg705MwqpSbCr/Q6cDWt9tx79pZAGIzyMOdo/7Udhltp7sjCVfzZ1Zc\ny9Yefq2htgLHTtnQIxFQw0IhYO2iKUkT19gvmYgKEYNzjkklA7t5YY1kcBZLtEqW6PXe8SvYcEtt\npGBHIBjEr989E3MtDTMq0Nw0BZZSQyTwb7wl9kFCo1YlnRK3lIoH1vBDSZFeC5fHjzKjnv2Siajg\nMDjnmFQysC2lBlSkMKosM+pRbtRLJmx5fUG89nY77v/kjZLXEr83GUjcApZODXC5XtLzZ1Zi9cLJ\nMSN0pUVaiIjyEYNzDpGbdj58ohPrlk2DqXi4tnSqpT/1Og3m10lPhQPAifP2SAa3XDa43N7k+Epj\nuw5fQNuZHtnAKtVLOryfurmpBs89cBP7JRNRQWBwziFy0869Di+eefkQFs6KneJOtfTn5uaZOPFX\nOy73OEWP2weGk6yUdsuSEqkBvnaW7Bq6kp7O4QeCTJUqJSLKZQzOOSRZ20a7I3GKO9XSnxq1Gl/5\nu4V47P++B48vmHA8PB0eCIagF9SRql9i70mlbKhc9TMlRU+UPhAQEY0HDM45ROn+4tb2LqxbNi2S\nMKXXaVIq/Vms12HFvEmS0+EAsPXtdtHADADzZlYkJIqNpGyokl7SzMwmokLC4JxjwtPRh090SrZH\n7O534+mXD6LP4U07MIpNh8+bWYFQKISvvrhfMlAaBA2CwRB2tw6vW4+0bKiShxJmZhNRIWFwzjHh\naep1y6bhmZcPSWZWhwN3KoExfho6fjr81++eSTpq93gDOHaqW/RYskQxOeGHhZaTNvQMeGKytcMP\nH0REhYLBOUeZigUsnKW8hGZ8YIwOxIFAEFt3tYtOQ4enw5UkZQFAmVFAr8QDw0jWhePXzqP3OXPE\nTESFhsE5h8VPPZeVSO9RDgfGijJDQhGTMqMeZy/1R94rNtpW2omqcWYl2s50Z61iV/TaefS2MSKi\nQsLgnMPERpNff+WQbGAUKxwitX4cPdpOlpRlMemxoP7a2rbmNCt2ERFlEYNzHogeTcoVHQGkC4eI\niZ6GlkvKWj6nGlvW1kcCb6p7q4mIKDUMznkkEAzCHwhCr1XD4x/a5mQQNFg2txqbVteiu8+taGo6\nLH4aWi7oRmeCp7q3moiIUsPgnMOik7q0GhW+/sphXOh0xLzH7Q3A6fLjctegov3C0eKnoVMNuqns\nrSYiIuXSCs6/+c1v8P3vfx9Tp04FACxbtgyf//znY97z5ptv4uc//znUajXuuusubNy4ceRXWyDE\nOlMZ9Fp02AZF37//o6vY/9FVGAQ1KsuKACQG5ylVRjjdfkXT0Ay6RERjK+2R8+23344nn3xS9JjT\n6cQPf/hD7NixAzqdDnfeeSfWrFmD8vLytC+0kIgldYkF3HhubxAXbYMJgXj5vElYt3Qq/IEQp6GJ\niPJAVqa1jx07hrlz58JkMgEAFixYgJaWFqxevTob3y4vKK1D7fEF0HKyc0Tfa9Dlw9OfXRTZJ1wz\nqRw22wA0anBETESUB9IOzgcPHsT9998Pv9+PJ598EjfeeGPkWFdXFywWS+Rri8UCm00+i9hsLoZW\nm73RnNVqysp53V4/7P0emEv1MAiJP85AIIiXf/ch9h+/DFuvC9byIiyZMxH3rZsNjUad8N4f/Ooo\negbEy3YqZR/woKjEgOnXlURey9b954NCvneA91/I91/I9w7k9/0nDc7bt2/H9u3bY177m7/5Gzzy\nyCO45ZZb0NraiieffBK/+93vJM8RCoWSXojdLt7CMBOsVhNstoGMnlNsXVisxvXWXe0xU9Sddhfe\n3HsWTpc3odzm1l3t2K2wIpgcs0mPgNcXueds3H++KOR7B3j/hXz/hXzvQH7cv9zDQ9LgvHHjRtlk\nrsbGRvT09CAQCECjGRr5VlVVoaurK/Kezs5OzJ8/P5Vrznli68LxVbfkSmKKldtMtkc5vJbc0++G\nTqeGV6TlIwAsqLdyTZmIKI+l3t8PwIsvvojf//73AID29nZYLJZIYAaAefPm4YMPPkB/fz8GBwfR\n0tKCpqamzFxxDkgWdD2+AAD5kpjhAiBhycpnLptTja99pgnPPXATvvW5Jfjuwx/DrQsnwyAM/9wN\nggarF05mMRAiojyX1przunXr8KUvfQmvv/46/H4/nn/+eQDAT37yEyxatAiNjY147LHHcP/990Ol\nUuGhhx6KJIeNB0qCbpW5WHbfcXwBELn3VpTqce/aemjU6pikrv+1ph533lILW68LCIVgvVbpi4iI\n8ltawbm6uhqvvvpqwusPPvhg5M+33XYbbrvttvSvLIcpDbpyJTHjC4DIv1d6mlqv06DGakz3VoiI\nKAexQlgaUgm6qdShZs1qIiICGJzTpjSQplISkzWriYgIYHBOWzbrULN8JhFRYUsrW5uGhQNpvoxw\nPb4AOu3OSEY5ERHlHo6cC4TSoilERDT2GJwLhJKiKURElBvG7ZCJ07fD3F6/oqIpRESUG8bdyFls\n+nb5vMlYt3Rqzk/fKu1clSp7v7KiKURElBvGXXAWm76VajSRK7K9HmwuVV6pjIiIxl5uDyVTpLTm\nda4JP1B093sQwvB68LbdpzNyfoOgRWOdVfRYfNEUIiIae+MqOKfSaCJXjNYDxabVtWhuqkFFqQFq\nFVBRakBzUw2rjxER5aBxNa2dSqOJXKG0icZIsfoYEVH+GFcj53DNazG5On0bfqAQk40HinwrmkJE\nVIjGVXAGxKdv71gxPWenb/PxgYKIiLJrXE1rA+LTtzWTymGzDYz1pUliNyoiIoo27oJzWD41j8jm\nerDHF8DlrkEEfAGOwomI8sS4Dc75KJMPFDF7pwc8sJhYS5uIKF8wOI9TrKVNRJS/OIQah/K1GAsR\nEQ1hcB6H8rEYCxERDWNwHodGe+80ERFlFoPzOMS900RE+Y0JYRmWrbaPqeLeaSKi/MXgnCHZbvuY\nqui90xpBh4DXxxEzEVGe4LR2hmS77WO69DoNJlaWMDATEeURBucM4NYlIiLKJAbnDODWJSIiyqS0\n1px/9KMfYd++fQCAYDCIrq4u7Ny5M3L84sWLWLduHebMmQMAMJvN+MEPfpCBy81N+dhHmoiIclda\nwfnzn/88Pv/5zwMA/vM//xPd3d0J77n++uvx6quvjuzq8kR461J0ucwwbl0iIqJUjShb2+/347XX\nXsMvfvGLTF1P3uLWJSIiypQRBee33noLH/vYx2AwGBKOdXV14Qtf+AI6OzuxefNm3HHHHSP5Vjkv\nm20fiYiosKhCoVBI7g3bt2/H9u3bY1575JFHsGLFCtx///149tlnUVNTE3Pc4XBg586duOOOOzAw\nMICNGzfitddeQ1VVleT38fsD0GoZzIiIiJIGZylOpxMbN27Ef/3XfyV976OPPop77rkHS5YskXyP\nzTaQzmUoYrWasnr+XFfI91/I9w7w/gv5/gv53oH8uH+r1SR5LO2tVCdOnMD06dNFj+3fvx/f+ta3\nAAwF8RMnTuD6669P91sREREVlLSDs81mg8ViiXnt+eefx4ULF9DU1IS+vj5s2rQJn/70p/Hggw9i\nwoQJI75YIiKiQpD2tHamcVo7ewr5/gv53gHefyHffyHfO5Af95+VaW0iIiLKDgZnIiKiHMPgTERE\nlGMYnImIiHJMziSEERER0RCOnImIiHIMgzMREVGOYXAmIiLKMQzOREREOYbBmYiIKMcwOBMREeWY\nggjO3d3d+Pu//3vce++9uPvuu3Hs2LGxvqRR4/f78eSTT+Kee+7BXXfdhcOHD4/1JY26gwcPYunS\npdizZ89YX8qo+uY3v4lNmzbh7rvvRltb21hfzqhrb29Hc3MzfvnLX471pYy6b3/729i0aRM2bNiA\nt956a6wvZ1S5XC48+uij2LJlCzZu3Ji3/+61Y30Bo+HNN9/E3/7t32LdunU4ePAgvv/97+Pll18e\n68saFb/97W9RVFSE1157DadOncKXv/xl7NixY6wva9ScP38eP/vZz7BgwYKxvpRRdfDgQZw7dw7b\ntm3DmTNn8NRTT2Hbtm1jfVmjxul04hvf+AaWLl061pcy6vbv349Tp05h27ZtsNvt+NSnPoWPf/zj\nY31Zo2bPnj3/f3v3D5JaFIAB/BNvRtHfK9ewLVqKIlqaoqJoimgTWguChhqL4g7NRrQooZiDQ2Bo\nBEFDEVE0BOGoREtLiFEXScqSQHhDcHnCe5EP3j3q+X7TuWf6DlzOxz2IB/39/VhYWEA6ncb8/DzG\nx8dFxyqbFOU8NzdnjjOZjFTXV87MzGB6ehoAoKoqXl5eBCeylqZp8Pv90HVddBRLXV9fY3JyEgDQ\n3d2NXC6Ht7c3NDU1CU5mDYfDgVAohFAoJDqK5YaGhjAwMAAAaGlpwcfHB4rFIux2u+Bk1piamjLH\n1bzfS1HOwNf904uLi8jn84hEIqLjWKaurs4cRyIRs6hl0dDQIDqCEIZhoK+vz3xWVRXPz8/SlLOi\nKFAUaba3Ena7HY2NjQCAeDyO0dFRaYr5d7Ozs3h8fEQgEBAd5Z/U3Nsbi8UQi8VK5paXlzEyMoKD\ngwNcXl5ifX29Jo+1v1v73t4eUqlU1b6oP/Hd+mXHf+mVz9nZGeLxeE3udT8RjUZxe3uLlZUVHB0d\nwWaziY5UlporZ4/HA4/HUzJ3c3ODXC6H1tZWjI2NYXV1VVC6/+tPawe+Suv8/Bw7OzslX9K15m/r\nl5HL5YJhGObz09MTNE0TmIisdHV1hUAggN3dXTQ3N4uOY6lkMgmn0wm3243e3l4Ui0Vks1k4nU7R\n0coixa+1T09PcXh4CAC4u7uD2+0WnMg6Dw8PiEaj8Pv9qK+vFx2HLDI8PIyTkxMAQCqVgsvlkuZI\nW3avr6/Y3NxEMBhEW1ub6DiWSyQS5mmBYRh4f39He3u74FTlk+JWqmw2i7W1NeTzeXx+fkLXdQwO\nDoqOZYnt7W0cHx+js7PTnAuHw3A4HAJTWefi4gLhcBj39/dQVRWapklzzLe1tYVEIgGbzYaNjQ30\n9PSIjmSZZDIJr9eLdDoNRVHQ0dEBn88nRVnt7+/D5/Ohq6vLnPN6vSV7QC0rFArQdR2ZTAaFQgFL\nS0uYmJgQHatsUpQzERFRNZHiWJuIiKiasJyJiIgqDMuZiIiowrCciYiIKgzLmYiIqMKwnImIiCoM\ny5mIiKjCsJyJiIgqzC8iivHPF8qqogAAAABJRU5ErkJggg==\n",
"text/plain": [
- "\u003cmatplotlib.figure.Figure at 0xa813090\u003e"
+ "\u003cmatplotlib.figure.Figure at 0x7f7a18dfb8d0\u003e"
]
},
"metadata": {
@@ -155,7 +149,7 @@
"\n",
"import matplotlib.pyplot as plt\n",
"\n",
- "plt.scatter(inputs.numpy(), labels.numpy())\n",
+ "plt.scatter(inputs, labels)\n",
"plt.show()"
]
},
@@ -168,14 +162,12 @@
"source": [
"## Step 2: Define our TensorFlow variables\n",
"\n",
- "We'll use Keras's object-oriented [`Dense`](https://www.tensorflow.org/api_docs/python/tf/contrib/keras/layers/Dense) layer to create our variables. In this case, we'll create a `Dense` layer with a single weight and bias.\n",
- "\n",
- "(**Note**: We're using the implementation of `Dense` found in `tf.layers.Dense` though the documentation link is for `tf.contrib.keras.layers.Dense`. When TensorFlow 1.4 is released, the documentation will also be in `tf.layers.Dense`) "
+ "We'll use Keras's object-oriented [`Dense`](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Dense) layer to create our variables. In this case, we'll create a `Dense` layer with a single weight and bias."
]
},
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": 0,
"metadata": {
"cellView": "code",
"colab": {
@@ -183,27 +175,23 @@
"startup": false,
"wait_interval": 0
},
- "height": 34,
- "output_extras": [
- {
- "item_id": 1
- }
- ]
+ "base_uri": "https://localhost:8080/",
+ "height": 34
},
"colab_type": "code",
"executionInfo": {
- "elapsed": 22,
+ "elapsed": 332,
"status": "ok",
- "timestamp": 1505502830753,
+ "timestamp": 1525154229931,
"user": {
"displayName": "",
"photoUrl": "",
"userId": ""
},
- "user_tz": 240
+ "user_tz": 420
},
"id": "z9r-ZeyrXu3A",
- "outputId": "6230a7a3-29fe-4d08-f101-da80425bad82"
+ "outputId": "e19a698e-5892-4fcd-80d3-1394605ee72c"
},
"outputs": [
{
@@ -212,7 +200,7 @@
"[]"
]
},
- "execution_count": 4,
+ "execution_count": 48,
"metadata": {
"tags": []
},
@@ -222,7 +210,7 @@
"source": [
"# Create TensorFlow Variables using Keras's Dense layer.\n",
"\n",
- "wb = tf.layers.Dense(units=1, use_bias=True)\n",
+ "wb = tf.keras.layers.Dense(units=1, use_bias=True)\n",
"\n",
"# We can access the underlying TensorFlow variables using wb.variables.\n",
"# However, the variables won't exist until the dimensions of the input\n",
@@ -240,7 +228,7 @@
"id": "docKLUaonYG_"
},
"source": [
- "## Step 3: Define our loss function\n",
+ "## Step 3: *Define the loss function*\n",
"\n",
"Our loss function is the standard L2 loss (where we reduce the loss to its mean across its inputs)."
]
@@ -261,15 +249,14 @@
},
"outputs": [],
"source": [
- "def loss_fn(inputs, labels, wb):\n",
+ "def loss_fn(predictions, labels):\n",
" \"\"\"Calculates the mean L2 loss for our linear model.\"\"\"\n",
- " predictions = wb(inputs)\n",
" return tf.reduce_mean(tf.square(predictions - labels))"
]
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": 0,
"metadata": {
"cellView": "code",
"colab": {
@@ -277,36 +264,32 @@
"startup": false,
"wait_interval": 0
},
- "height": 34,
- "output_extras": [
- {
- "item_id": 1
- }
- ]
+ "base_uri": "https://localhost:8080/",
+ "height": 34
},
"colab_type": "code",
"executionInfo": {
- "elapsed": 24,
+ "elapsed": 348,
"status": "ok",
- "timestamp": 1505502830875,
+ "timestamp": 1525154234538,
"user": {
"displayName": "",
"photoUrl": "",
"userId": ""
},
- "user_tz": 240
+ "user_tz": 420
},
"id": "RkNbXoXkpjVH",
- "outputId": "c36fc98d-3a57-4074-901d-c10ae017ae3f"
+ "outputId": "e4688f3c-e29f-416d-f541-6d81953b5660"
},
"outputs": [
{
"data": {
"text/plain": [
- "\u003ctf.Tensor: id=40, shape=(), dtype=float32, numpy=7.3549819\u003e"
+ "\u003ctf.Tensor: id=1252, shape=(), dtype=float32, numpy=16.979801\u003e"
]
},
- "execution_count": 6,
+ "execution_count": 50,
"metadata": {
"tags": []
},
@@ -316,47 +299,43 @@
"source": [
"# Test loss function (optional).\n",
"\n",
- "loss_fn(inputs, labels, wb)"
+ "loss_fn(wb(inputs), labels)"
]
},
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": 0,
"metadata": {
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
},
- "height": 51,
- "output_extras": [
- {
- "item_id": 1
- }
- ]
+ "base_uri": "https://localhost:8080/",
+ "height": 51
},
"colab_type": "code",
"executionInfo": {
- "elapsed": 57,
+ "elapsed": 418,
"status": "ok",
- "timestamp": 1505502830981,
+ "timestamp": 1525154260083,
"user": {
"displayName": "",
"photoUrl": "",
"userId": ""
},
- "user_tz": 240
+ "user_tz": 420
},
"id": "K_7beXoHOU7t",
- "outputId": "1ad0856a-02ec-4117-a6c0-b41030981d87"
+ "outputId": "8f55c028-fe2b-4edb-ad68-a849afc60623"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "w: tf.Tensor([[ 1.56891453]], shape=(1, 1), dtype=float32)\n",
- "b: tf.Tensor([ 0.], shape=(1,), dtype=float32)\n"
+ "w: -0.311619\n",
+ "b: 0.000000\n"
]
}
],
@@ -364,31 +343,20 @@
"# At this point, the variables exist, and can now be queried:\n",
"\n",
"w, b = wb.variables\n",
- "print(\"w: \" + str(w.read_value()))\n",
- "print(\"b: \" + str(b.read_value()))"
+ "print(\"w: %f\" % w.numpy())\n",
+ "print(\"b: %f\" % b.numpy())"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
- "id": "YIlebeb_qYtC"
+ "id": "JVDWpL9VYWdP"
},
"source": [
- "## Step 4: Create our gradients function using `implicit_value_and_gradients()`\n",
- "\n",
- "With a loss function defined, we can calculate gradients and apply them to our variables to update them.\n",
+ "## Step 4: Create an optimizer\n",
"\n",
- "To calculate the gradients, we wrap our loss function using the `implicit_value_and_gradients()` function.\n",
- "\n",
- "`implicit_value_and_gradients()` returns a function that accepts the same inputs as the function passed in, and returns a tuple consisting of:\n",
- "\n",
- "1. the value returned by the function passed in (in this case, the loss calculated by `loss_fn()`), and\n",
- "1. a list of tuples consisting of:\n",
- " 1. The value of the gradient (a `tf.Tensor`) with respect to a given variable\n",
- " 1. The corresponding variable (`tf.Variable`)\n",
- "\n",
- "Test it out below to get a feel for what it does. Notice how the first value of the returned tuple (the loss) is the same as the value returned in the cell above that tests our loss function."
+ "We'll use a `GradientDescentOptimizer` to fit our model."
]
},
{
@@ -403,87 +371,29 @@
}
},
"colab_type": "code",
- "id": "v1spZQ4NwW1U"
+ "id": "DudNEebMKDWN"
},
"outputs": [],
"source": [
- "# Produce our gradients function. See description above for details about\n",
- "# the returned function's signature.\n",
- "\n",
- "value_and_gradients_fn = tfe.implicit_value_and_gradients(loss_fn)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 9,
- "metadata": {
- "cellView": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- },
- "height": 153,
- "output_extras": [
- {
- "item_id": 1
- }
- ]
- },
- "colab_type": "code",
- "executionInfo": {
- "elapsed": 46,
- "status": "ok",
- "timestamp": 1505502831114,
- "user": {
- "displayName": "",
- "photoUrl": "",
- "userId": ""
- },
- "user_tz": 240
- },
- "id": "21WMcpsmFFLd",
- "outputId": "f51b3171-33f5-4f87-8bf7-0be2dc8edc8a"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Outputs of value_and_gradients_fn:\n",
- "Loss: tf.Tensor(7.35498, shape=(), dtype=float32)\n",
- "\n",
- "Gradient: tf.Tensor([[-3.00773573]], shape=(1, 1), dtype=float32)\n",
- "Variable: \u003ctf.Variable 'dense/kernel:0' shape=(1, 1) dtype=float32\u003e\n",
- "\n",
- "Gradient: tf.Tensor([-4.06519032], shape=(1,), dtype=float32)\n",
- "Variable: \u003ctf.Variable 'dense/bias:0' shape=(1,) dtype=float32\u003e\n"
- ]
- }
- ],
- "source": [
- "# Show outputs of value_and_gradients_fn.\n",
- "\n",
- "print(\"Outputs of value_and_gradients_fn:\")\n",
- "\n",
- "value, grads_and_vars = value_and_gradients_fn(inputs, labels, wb)\n",
- "\n",
- "print('Loss: {}'.format(value))\n",
- "for (grad, var) in grads_and_vars:\n",
- " print(\"\")\n",
- " print('Gradient: {}\\nVariable: {}'.format(grad, var))"
+ "optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.1)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
- "id": "JVDWpL9VYWdP"
+ "id": "YBeJYxY8YaiO"
},
"source": [
- "## Step 5: Create an optimizer\n",
+ "### Step 5: Define a training step\n",
"\n",
- "We'll use a `GradientDescentOptimizer` to fit our model."
+ "To fit model variables to the data we'll need to:\n",
+ "\n",
+ "1. Calculate the gradients of the loss with respect to the model variables.\n",
+ "2. Use `optimizer` to compute updates to the variable values based on those gradients.\n",
+ "\n",
+ "To calculate the gradients, we use the [`tf.GradientTape`](https://www.tensorflow.org/api_docs/python/tf/GradientTape) context manager\n",
+ "and its `gradient` function to compute gradients through computation conducted within its context:\n"
]
},
{
@@ -498,94 +408,72 @@
}
},
"colab_type": "code",
- "id": "DudNEebMKDWN"
+ "id": "diDZfrMJM3OC"
},
"outputs": [],
"source": [
- "optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.1)"
+ "def run_step(inputs, labels):\n",
+ " with tf.GradientTape() as g:\n",
+ " loss = loss_fn(wb(inputs), labels)\n",
+ " # Compute the partial derivatives of loss with respect to the variables\n",
+ " grads = g.gradient(loss, wb.variables)\n",
+ " optimizer.apply_gradients(zip(grads, wb.variables))\n",
+ " return loss"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
- "id": "YBeJYxY8YaiO"
+ "id": "1WWepgmJQOzc"
},
"source": [
- "### Step 5a: Test Our Optimizer\n",
- "\n",
- "Now we have everything needed to start fitting our variables to the data!\n",
- "\n",
- "In the next cell, we'll demo these capabilities. We'll:\n",
- "\n",
- "1. Print the current values of `w` and `b`\n",
- "1. Calculate the loss and gradients\n",
- "1. Apply the gradients\n",
- "1. Print out the new values of `w` and `b`\n",
- "\n",
- "You can run the cell multiple times. Each time, you should see the values of `w` and `b` get closer to their true values of 3 and 2."
+ "Repeatedly running the training step will nudge the variables towards the values that best fit the data (i.e., \"w\" will move closer to 3.0, while \"b\" will tend to 2.0):\n",
+ "\n"
]
},
{
"cell_type": "code",
- "execution_count": 11,
+ "execution_count": 0,
"metadata": {
- "cellView": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
},
- "height": 102,
- "output_extras": [
- {
- "item_id": 1
- }
- ]
+ "base_uri": "https://localhost:8080/",
+ "height": 51
},
"colab_type": "code",
"executionInfo": {
- "elapsed": 103,
+ "elapsed": 380,
"status": "ok",
- "timestamp": 1505502831285,
+ "timestamp": 1525154412590,
"user": {
"displayName": "",
"photoUrl": "",
"userId": ""
},
- "user_tz": 240
+ "user_tz": 420
},
- "id": "diDZfrMJM3OC",
- "outputId": "d585fff0-ecb3-4e98-9b33-bbae07a95d8c"
+ "id": "ya5Qxz5XQlhU",
+ "outputId": "8dd47155-a6c1-44c5-c279-617c803f1723"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "Values of w, b, BEFORE applying gradients:\n",
- "(array([[ 1.56891453]], dtype=float32), array([ 0.], dtype=float32))\n",
- "()\n",
- "Values of w, b, AFTER applying gradients:\n",
- "(array([[ 1.86968815]], dtype=float32), array([ 0.40651903], dtype=float32))\n"
+ "Values of w, b BEFORE applying gradients: 2.725763, 1.894334\n",
+ "Values of w, b AFTER applying gradients: 2.774932, 1.922555\n"
]
}
],
"source": [
- "# Test the optimizer.\n",
- "\n",
- "print(\"Values of w, b, BEFORE applying gradients:\")\n",
"w, b = wb.variables\n",
- "print(w.read_value().numpy(), b.read_value().numpy())\n",
- "print()\n",
- "\n",
- "# Calculate the gradients:\n",
- "empirical_loss, gradients_and_variables = value_and_gradients_fn(\n",
- " inputs, labels, wb)\n",
- "optimizer.apply_gradients(gradients_and_variables)\n",
- "\n",
- "print(\"Values of w, b, AFTER applying gradients:\")\n",
- "print(w.read_value().numpy(), b.read_value().numpy())"
+ "print(\"Values of w, b BEFORE applying gradients: %f, %f\" % (w.numpy(), b.numpy()))\n",
+ "run_step(inputs, labels)\n",
+ "print(\"Values of w, b AFTER applying gradients: %f, %f\" % (w.numpy(), b.numpy()))\n"
]
},
{
@@ -602,51 +490,44 @@
},
{
"cell_type": "code",
- "execution_count": 12,
+ "execution_count": 0,
"metadata": {
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
},
- "height": 397,
- "output_extras": [
- {
- "item_id": 1
- },
- {
- "item_id": 2
- }
- ]
+ "base_uri": "https://localhost:8080/",
+ "height": 364
},
"colab_type": "code",
"executionInfo": {
- "elapsed": 225,
+ "elapsed": 580,
"status": "ok",
- "timestamp": 1505502831550,
+ "timestamp": 1525154278709,
"user": {
"displayName": "",
"photoUrl": "",
"userId": ""
},
- "user_tz": 240
+ "user_tz": 420
},
"id": "VukGe-huNaJ4",
- "outputId": "f0a8d665-1910-477c-d8ab-c94ccdc4afcd"
+ "outputId": "c79c8e63-c781-451e-f74f-20815d8da49f"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "[2.111051321029663, 2.3047544956207275, 2.4602210521698, 2.5850086212158203, 2.6851789951324463, 2.7655951976776123, 2.830157995223999, 2.8819968700408936, 2.9236228466033936, 2.9570505619049072]\n"
+ "[0.9409681558609009, 1.3733772039413452, 1.7128530740737915, 1.9793939590454102, 2.188689708709717, 2.3530514240264893, 2.4821391105651855, 2.583533763885498, 2.6631851196289062, 2.7257626056671143]\n"
]
},
{
"data": {
- "image/png": "iVBORw0KGgoAAAANSUhEUgAAAd0AAAFXCAYAAADnFpTQAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzs3Xd4FFUbBfAzu+m9koSShBQCSC+igIAgRRGkChJEiggo\nHURAEBQBQeADRcWCha50ULFLk6IivYRQQwskhPS6O/P9sckmm4Rkk2x2difn9zz7bLuZvC8JHO7M\n7FxBkiQJREREVOlUchdARERUVTB0iYiIzIShS0REZCYMXSIiIjNh6BIREZkJQ5eIiMhMjArdlJQU\njB8/Hk8//TS6d++OkydPVnZdREREiiMY8znd6dOno2XLlujbty80Gg0yMzPh4uJijvqIiIgUo9TQ\nTU1NRa9evfDbb7+ZqyYiIiJFKnX38s2bN+Hp6YkZM2agd+/emD17NjIzM81RGxERkaKUGroajQbn\nzp3DoEGDsH37djg4OOCzzz4zR21ERESKUmro+vv7w9/fHw0bNgQAdO3aFefOnSvxa3g5ZyIioqJs\nShvg4+ODgIAAXL16FbVr18aRI0cQGhpa4tcIgoC4uBSTFSkHX19Xq+8BUEYfSugBYB+WRAk9AMro\nQwk9ALo+jFFq6ALArFmzMHXqVGg0GtSqVQsLFy6sUHFERERVkVGhW7duXWzdurWyayEiIlI0XpGK\niIjITBi6REREZsLQJSIiMhOGLhERkZkwdImIiMyEoUtERCbRuXM7uUuweAxdIiIyCUEQ5C7B4hn1\nOV0iIqKy+OijFTh69BAEQYUhQ4ajU6fOuH8/HnPmzER6ehq0Wi2mTJmOJ59sgwUL3kZU1HkAArp3\n74nnn39B7vIrDUOXiEhh5s6dhd27d5h0mz169MLcue8aNXbv3t9x+XI01qz5Fg8eJODll4egadNm\n+PXXn9Cq1eN48cVhkCQJmZmZOH/+POLi7uGbbzYBANLSUk1at6Xh7mUiIjKp06dP4qmnugIAPD29\n0LRpc5w/fw716j2CH37Yha+++hyXLkXD0dERtWrVwp07t7F8+RIcPXoYTk7OMldfuTjTJSJSmLlz\n3zV6VloZCq80l/e8ceOm+Oijz3H48EEsWDAXAwcOxuDBA/D11xtx9Ohh7Ny5DX/88StmzHhLjrLN\ngjNdIiIyifxwbYbff/8VoijiwYMHOHXqBOrXfwSxsbHw8PDEs8/2wrPP9sLFixeQmJgIUdSiffsn\n8fLLoxEdHSVzF5WLM10iIjKJvLOX27d/EmfPnsbQoS9AEFR49dXx8PT0wp4932PjxrWwsbGBk5Mz\nZs16G7GxsXj99TcgSSIEQcDo0eNk7qJyCVIlrThv7esjKmmNR2vvQwk9AOzDkiihB0AZfSihB8D4\n9XS5e5mIiMhMGLpERERmwtAlIiIyE4YuERGRmTB0iYiIzIShS0REZCYMXSIismjHjx/DmTOn9M93\n7NiKn3/+0STbXrv2K5Nsx1gMXSIismjHjx/D6dP5odurV1907fqMSba9Zo15Q5dXpCIiogrbsGEN\n7O3t0bfvAHzwwVJcvnwJK1Z8gmPH/sGPP+7C7NnzDMZHRV3Ahx8ug0aTDWdnN7z55hx4eXlj8+ZN\n2LlzG2xsbBAcXBujR4/Fzp1boVbb4Ndf92DixNfx779/w8nJCQMHDsa4caNQp04ETp48gczMTMya\nNRdr136FK1cuo2PHzhg5cgwAYMaMqYiLu4fs7Cz07/8CevTohVWrViI7OwvDh0eidu0QzJ49D7/8\nsgebN2+CVqtB/foNMGXKdJOuE8zQJSJSGOe5s2Bv4qX9snr0QloJiyg0btwM3367Hn37DkBU1AXk\n5ORAq9Xi1KkTaNy4mcFYjUaD5csX4733liEsrBY2bdqGTz/9CDNmvIX167/Bli27YWNjg7S0VDg7\nu+C55/rqQxYA/v33b4Pt2dra4Ysv1mDz5k2YPn0KvvpqPVxcXDFgQC8MGBAJNzc3zJw5B66ursjK\nysLIkUPQvn1HjB49Ftu2bcaXX64HAFy/fg2///4LVq36Emq1GkuXLsIvv+wx2awaYOgSEZEJRETU\nRVTUeaSnp8PW1hYREXVx/vw5nDx5HJMmTTMYGxNzHVeuXMakSa9BrVYhO1sDHx9fAEBYWDjmzn0T\n7dp1wBNPdDDqe7dt2w4AEBoahpCQUHh6egEAqlevgXv37sLNzQ3ffbcBBw7sAwDcu3cPN2/GoH79\nBgYrIv3779+4eDEKI0cOgSRJyM7OhpeXV0X/aAwwdImIFCZt7rslzkorg42NDfz9A/Djj7vQsGFj\nhIWF4/jxf3H79i0EBQUXGi0hJCQUn3zyZZFrL7///gqcOPEfDh7cjzVrvsSaNd+W+r1tbe0A6BZc\nsLW11b8uCAK0Wi2OHz+G//77F5999jXs7OwwbtwoZGdnF7MlCd26dceoUa+V40/AODyRioiITKJx\n46bYuHEdmjRphkaNmmDHjq0ID69TZFxgYDAePEjEmTOnAeh2N1+9egUAcPduLJo2bY4xY8YhLS0N\nGRnpcHJyQlpaWrnrSktLhaurK+zs7HD9+jWcPXtG/56trS20Wi0AoHnzR7F37+948OABACA5ORmx\nsbHl/r7F4UyXiIhMonHjpli79is0aNAQ9vYOsLe3L3I8F9DNit99dxGWL38fy5cvQnZ2Dp5//gXU\nqhWId96ZnRuwEvr3HwhnZxe0adMOs2a9gb/+2o+JE183OLGppJOc8t5r1ao1duzYisGDn0dgYBAa\nNGioH9OzZ2+89NJARETUxezZ8/Dyy2MwefJrEEUJtra2mDx5Gvz9/U32Z8Sl/R5CSctNWXsfSugB\nYB+WRAk9AMroQwk9AFzaj4iIyOIwdImIiMyEoUtERGQmDF0iIiIzYegSERGZCUOXiIjITBi6RERk\ndt99txFZWVlyl2F2DF0iIjK7zZs3Iisrs9j3RFE0czXmw9AlIqIK27BhDbZu1V0n+YMPlmLCBN2S\neseO/YN582YbjN2yZRPi4+MwbtxovPTSSwCAzp3bYeXK5Rg2bBDOnDmF/v17Ijk5CQBw4cJ5jBs3\nCgCQmZmJhQvfwciRL2H48ME4eHC/uVo0CV4GkohIgbyaNyj29YRjZ4p9vazjCyvL0n79+g3Et99u\nxIcfforQ0BqIi0tBZmYGGjRoiLFjJ+aOMry8Y94lHb/5ZjWaN38UM2a8hdTUVIwcOQQtWz4Ke3sH\no+qUG0OXiIgqrCxL++lIuTcdtVqN9u07Fnq/qH/+OYpDhw5g48Y1AHSLJdy9G4vAwGCT9VKZGLpE\nRApk7Ay1vOMLK9vSfkXZ2dkbLF6gVqshirrgzc7OP+FKkiS8++5i1KoVWKF65cJjukREZBLGLu0H\nAE5OzgbL9RVeeycgoDqios4DAPbt+0P/+qOPPoYtWzbpn0dHR5myhUpn1Ey3Y8eOcHFxgUqlgo2N\nDbZs2VLZdRERkZUxdmk/AOjZsxemTh2PgAB/LFmyssgSfUOHjsR7770DFxcXNG3avMDrL+ODD5bi\npZcGAgD8/QOwaNH/Kq8pEzNqab9OnTph27ZtcHd3N2qjFy9ehKdnQIWLk5OSlpuy9j6U0APAPiyJ\nEnoAlNGHEnoATLy0nyRJZfrc1IABA5CTk2P0eCIioqrAqNAVBAEjRoxA37598d1335U6/sSJE/jw\nQ+uZ7hMREZmDUcd0N23aBF9fXyQkJGDYsGEICQlBixYtHjq+Ro0aWLp0Ebp164769R8xWbFERETW\nzKhjugWtXLkSzs7OGDZs2EPH/PDDD3j22WfRvHlzHDlyBDY2/GQSERFRqWmYkZEBURTh7OyM9PR0\nHDx4EGPHji3xa7p3747nn38B3323EXPnvosJE6aYrGBzUdLBfWvvQwk9AOzDkiihB0AZfSihB8D4\nE6lKDd34+HiMHTsWgiBAq9WiR48eaNu2bakbfvfd97Bv3594//2F6NatOyIi6hpVEBERkVKVeiJV\nrVq1sHPnTuzYsQO7d+/GK6+8YtSGPTw88f77y5GdnY0JE8ZAo9FUuFgiIrJMsbF3MGTIAJNuMzr6\nIg4f/kv//ODB/Vi//huTbFuupQUr9YpU3bo9g759n8d//x3DqlUfVea3IiIimRW+wEVFXbp0EUeO\n5Idu27btEBn5kkm2XdLSgpWp0s9wmj9/Efbv34tFi95F165PP/SSYEREZN00Gg3eeWc2Ll68gNq1\nQzFr1tuwt7c3GHPr1k0sW7YYSUmJcHBwwHvvLYCLiw/++OM3fP3151Cr1XB2dsHy5R/jiy9WITs7\nG6dPn8TgwcOQlZWJCxfOYdKkaViw4G3Y2dkjOjoKiYkPMGPGW9iz53ucPXsa9es3wMyZcwAAS5a8\nh6ioc8jKykKHDp0wfPgrBksLenh4YMWKT/D330fw5ZefIScnBzVq1MTMmXPg4GD6lYsqPXS9vLyx\nePH/MGxYJCZMeBW7d/8MtVpd2d+WiKjKmjvXHrt3m/af9x49NJg7t+TdsTEx1zFjxhw0aNAQCxe+\ng+3bN2PgwMEGYxYvXoBp02aiRo2aOHfuDObOnYslS1bim2++wLJlH8HHxwdpaamwsbHByy+PRlTU\neUyc+DoAYM+e7w1m06mpKfj0069w8OA+vPHGJKxa9RVq1w7BiBEv4tKlaISFhWPUqNfg6uoKURQx\nYcIYXLlyyWBpQTc3NyQlJWLNmi+xYsXHsLd3wPr132DTpnUYOvRlk/4ZAmZaZah79x7o1asPduzY\nhs8//wSjR5d89jMREVkfPz9/NGjQEADQtesz2LLlW4PQzcjIwJkzJzF79hsFFjjQ3Tds2Bjz589B\nx46d0b79k0Z9vzZtngAAhISEwcvLG7VrhwAAatcOQWzsbYSFheP333/Grl07oNVqkZBwH1evXkVI\nSBgKLi149uwZXLt2BWPGjIAkSdBoNGjQoFHF/0CKYbYP0C5YsAQHD+7HggXvoEuXbrlNExGRqc2d\nm1XqrLQyFD6mW/gQrySJcHV1w5dfrte/lveRoalTZ+D8+bM4dOggRox4EatXryv1+9nZ2QEAVCqV\n/nHec61Wizt3bmPTpvVYvXotnJ1dsGDB2wbLBObXJaFly8cwZ867ZWm3XMy2tJ+Pjw/ee28pMjMz\nMWHCa2W6ljMREVm+2Ng7OHtWty7vr7/+jEaNmhi87+TkjICA6vjzz9/0r124cAGA7lhvvXqPYMSI\nUfDw8MS9e3fh5ORksPxfSYq7zlNaWhocHR3h5OSMhIT7OHLkkEEtedt+5JGGOH36JG7dugkAyMrK\nxI0bMWXo3HhmvVRUz5690aPHduzevQOrV3+KkSPHmPPbExFRJQoKCsa2bd9h4cK3ERwcgl69+hUZ\nM2fOu3j//YX45psvodVq0LNnD/Tv/yI+/ngFbt68AQBo3rwlwsLCUa2aH9at+xrDh0di8OCHXwUR\nKP7M6bCwcISHRyAysh+qVfNDo0aN9e/lLS3o4+OLFSs+wcyZczB37kxkZ+dAEASMHDkGtWoFVvBP\npJg6y3oZSGM97AojcXFxeOKJlsjMzMSffx7S74O3NEq6Soq196GEHgD2YUmU0AOgjD6U0ANg4qX9\nTMnX1xcLFy5Beno6Jk0ay93MRERUZZg9dAGgV6++ePrpZ3Ho0EF8/fVqOUogIiIyO1lCVxAELF78\nP3h4eOCdd97C9evX5CiDiIjIrGQJXQDw8/PD/PmLkZ6ehsmTxxV75hkREZGSyBa6ANCv3wB06dIN\nBw7sw5o1X8lZChERUaWTNXQFQcCSJSvg7u6BuXNnVdrnooiIiCyBrKELAP7+AZg3byHS0lK5m5mI\nyEoZu7Tfnj3f4/79eDNUZJlkD10AGDBgEDp16ox9+/7Ehg1r5S6HiIjKwZil/X78cTfi4uKKfa8q\nfITUIkJXEAQsXfoBXF3d8NZbM3H79i25SyIiojLKW9pv8OD+mD17epFF4vfu/R0XLpzHvHmzMXx4\nJLKystCxY0d88smHGDHiRfz5528YN24UoqJ0l4ZMSkpE//49AegC+eOPV2DkyJcwdOgg7Nq13ez9\nmYJFhC4AVK9eA++8swApKcmYMmU8dzMTEVVA8+bOxd5MNb44MTHX0afP81i3bjOcnJywfftmg/c7\ndOiEevXqY86cd/Hll+v1a+26u3tg9eq16NSpSzFb1c2ev/9+J1xcXPH559/g88+/wa5d2xEbe6dM\n9VkCiwldABg06EV06NARv//+K779doPc5RARURkUXtrv1KmTRcZIkoTCc6pOnTqXuu2//z6Cn376\nAcOGDcIrr7yE5OQkqzz51qwLHpRGEAQsW/Yh2rV7DLNnz0CHDh3h7x8gd1lERFbn2DHjVucp7/ji\nlLa038M4OjrqH6vVakiS7thudnZ2gVESJk16HS1bPlbRMmVlUTNdAKhZsxbmzJmHpKRETJ06gbuZ\niYisRGlL+wGAs7Mz0tJSH7qNgIAauHDhHAAYLAH46KOPY9u2LdBoNACAGzdikJWVacryzcLiQhcA\nhgwZhieeaI9ffvkJW7Z8K3c5RERkhLyl/QYP7o+UlORil/Z7+ulnsWTJQv2JVIVnxy+8EInt27di\n+PDBSE5O1r/eo0cvBAfXxogRgzFkyAAsWbIQWq220nsyNbMv7WesmJjraNfuMdjZ2eLAgX/g5+dn\nosqMo6Tlpqy9DyX0ALAPS6KEHgBl9KGEHgALXtrPWIGBQZg9+20kJiZi2rRJ3M1MRERWz2JDFwCG\nDXsZrVu3xZ4932PHjq1yl0NERFQhFh26KpUK//vfSjg5OWHGjKm4d++e3CURERGVm0WHLgDUrh2C\nN9+cg4SEBMyYMVXucoiIiMrN4kMXAEaMGIVWrR7H7t07rPbSX0RERFYRuiqVCitWfAQHBwdMnz4F\n8fFVd4UKIiKyXlYRugAQEhKGGTPeQnx8PGbO5G5mIiKyPlYTugDwyitj0KLFo9ixYxt++GG33OUQ\nERGViVWFrlqtxooVH8Pe3h7Tpk1CQsJ9uUsiIiIymlWFLgCEh9fBG2/MQlzcPbz55htyl0NERGQ0\nqwtdABgzZiyaNWuOrVu/w08//Sh3OUREREaxytDV7Wb+BHZ2dnj99YlITHwgd0lERESlssrQBYCI\niLp4/fUZuHs3FrNnz5C7HCIiolJZbegCwGuvTUDjxk3x7bcb8OuvP8ldDhERUYmsOnRtbGzwwQef\nwNbWFlOnTkRSUqLcJRERET2UVYcuANSrVx+TJ0/DnTu3MWfOm3KXQ0RE9FBWH7oAMH78ZDRo0Agb\nNqzFH3/8Jnc5RERExVJE6Nra2uKDDz6BjY0NJk8eh5SUZLlLIiIiKkIRoQsADRo0xMSJU3H79i3M\nnTtb7nKIiIiKUEzoAsDEiVNRv34DrF37Ffbt+1PucoiIiAwYHbqiKKJ3794YPXp0ZdZTIXZ2dvjg\ng4+hVqsxefI4pKamyF0SERGRntGhu2bNGoSGhlZmLSbRqFETjB8/CTduxGDevDlyl0NERKRnVOjG\nxsZi37596N+/f2XXYxKTJ7+BunXr4auvvsDBg/vlLoeIiAiAkaG7YMECTJs2DYIgVHY9JmFvb48V\nKz6GSqXCxIljkZaWJndJREREsCltwN69e+Hj44N69erh6NGjRm/Y19e1QoVVVJcuHTBt2jS89957\nWLZsAT744IMyb0PuHkxFCX0ooQeAfVgSJfQAKKMPJfRgLEGSJKmkAcuWLcOuXbugVquRlZWFtLQ0\ndO7cGYsXLy5xw3Fx8p/ElJmZiaeeegIXL0Zh5849ePzxNkZ/ra+vq0X0UFFK6EMJPQDsw5IooQdA\nGX0ooQfA+P84lLp7efLkydi7dy9+//13LFu2DK1atSo1cC2Fg4MDli//CCqVChMmvIr09HS5SyIi\noipMUZ/TLU6LFo9i9OixuHbtKhYunCd3OUREVIWVKXQfffRRrFq1qrJqqTRvvPEmQkPD8NlnH+Po\n0SNyl0NERFWU4me6AODo6Ijlyz8GAEyc+CoyMjJkroiIiKqiKhG6ANCq1WN45ZUxuHz5EhYtmi93\nOUREVAVVmdAFgBkz3kJwcG2sWrUS//77t9zlEBFRFVOlQtfJyQkrVnwMURQxYcKryMzMlLskIiKq\nQqpU6ALA44+3wcsvj0J09EUsWfKe3OUQEVEVUuVCFwDefHMuAgODsXLlchw/fkzucoiIqIqokqHr\n7OyM5ctX6nczZ2VlyV0SERFVAVUydAGgbdt2GDp0BC5cOI///c86rrBFRETWrcqGLgC89dY7qFUr\nECtWLMOpUyfkLoeIiBSuSoeui4srli37EFqtFuPHv4rs7Gy5SyIiIgWr0qELAO3bP4kXXxyGc+fO\nYPnyJXKXQ0REClblQxcA5s6dhxo1amL58iU4c+a03OUQEZFCMXQBuLq6YenSD6DRaDB+/Bjk5OTI\nXRIRESkQQzdXx45PYdCgF3HmzCl8+OH/5C6HiIgUiKFbwNtvz4e/fwCWLl2E06e5m5mIiEyLoVuA\nu7sHli5dgZycHAwdOhSpqalyl0RERArC0C2kc+duiIwcgv/++w8DBvRGcnKS3CUREZFCMHSL8f77\nyzFo0CD8889R9O3bEwkJ9+UuiYiIFIChWwwbGxusWbMGgwa9iJMnj6N372cRFxcnd1lERGTlGLoP\noVarsWzZhxg+fCTOnz+LXr2exp07t+Uui4iIrBhDtwQqlQoLFy7Bq6+OR3T0RfTs2Q03bsTIXRYR\nEVkphm4pBEHAnDnzMGXKG7h+/Rp69uyGK1cuy10WERFZIYauEQRBwBtvvIlZs+bi1q2beO65pxEV\ndUHusoiIyMowdMtg/PjJmD9/Ee7ejUWvXk/j9OlTcpdERERWhKFbRiNHjsGSJSuQkJCAPn2exfHj\nx+QuiYiIrARDtxyGDBmGDz9chZSUZPTt2xNHjhyWuyQiIrICDN1yev75F/DZZ18hMzMDAwf2xoED\n++QuiYiILBxDtwJ69uyNr75aD41Gg0GD+uG3336WuyQiIrJgDN0K6tr1aaxb9x1UKhVeemkQfvhh\nt9wlERGRhWLomkCHDh2xceNW2NnZ4+WXh2Dbts1yl0RERBaIoWsirVu3xebNO+Ds7IIxY17Ghg1r\n5S6JiIgsDEPXhFq0eBTbtu2Gp6cnJk58DatXfyZ3SUREZEEYuibWqFETbN/+I3x9q2HGjKn4+OMP\n5S6JiIgsBEO3EtSrVx87d+5BQEB1zJ37JpYuXQRJkuQui4iIZMbQrSRhYeHYuXMPAgODsGjRfCxY\n8A6Dl4ioimPoVqLg4NrYuXMPQkJCsWLFUsyePZ3BS0RUhTF0K1mNGjWxc+dPqFu3Hj777BNMnToR\noijKXRYREcmAoWsGfn5+2L79RzRo0Ahr136FceNGQ6PRyF0WERGZGUPXTLy9vbFt2240b94Cmzdv\nwujRI5CTkyN3WUREZEYMXTPy8PDE5s078fjjbbBr13YMHz4YmZmZcpdFRERmwtA1MxcXV2zcuBXt\n2z+Jn3/egyFDBiI9PV3usoiIyAwYujJwcnLC2rXfokuXbti79w8MGtQPqakpcpdFRESVrNTQzc7O\nRv/+/dGrVy/06NEDK1euNEddiufg4IAvv1yHHj164dChg+jfvxeSkhLlLouIiCqRTWkD7OzssGbN\nGjg6OkKr1eKFF15Au3bt0KhRI3PUp2h2dnb49NMvYW9vjy1bvkWfPj3w3Xc74O3tLXdpRERUCYza\nvezo6AhAN+vlR11My8bGBitXfooXXxyK06dPok+f7rh7967cZRERUSUodaYLAKIook+fPoiJiUFk\nZGTps9zgYHiJRa+8lHDsTLHDvZo3KPZ1WcerhCI9VGY9XwFwGDkan3++Cr16PY2tW3ejevUaFd9+\ngT6s6s+/oNweLKaeco5HzHWLqofjOd4SxisiL4CH/v0uzKjQValU2LFjB1JTU/Hqq6/i0qVLCAsL\nK/Fr1CqhyGu+vq4P+QZFx1rC+MI9VHY9n376Mby83LFo0SL07v0M/vjjDwQHB1d4+3l9yP3nWZHx\napVgUfWUZ/xDv8ZK6i843uBrLaCe8ozXP7eQeso7vrh/a+Wsp8zjoYy8MJYglfFiwCtXroSzszOG\nDRtW4ri4OOs+G9fX11WWHiRJwtKli7B48QJUr14D27btRkhIyf/BKYlcfZiSEnoA2IclUUIPgDL6\nsPgeRBHIzISQmQEh9x4ZmRCyMiFkZgKZGRAyMuE+dJBRmyt1ppuQkABbW1u4uroiMzMThw8fxiuv\nvFLhPqh4giBg6tTpcHBwxDvvzEbPnk9jy5ZdqFu3ntylERHJy8gAzHtfN7bg89z3szLzt5M7Hpm5\n28nIHZf3fna2cbWZKnTj4uIwffp0iKIIURTxzDPPoH379sYVQeU2duwEODo6YMaM19G79zP47rsd\naNiwsdxlEREVJUlAdjaE9DQI6em5tzT9PdLTIaQ95D1JA9fEFMMAzMrSPy9XAJa1fEEAHB0hOThA\ncnCE5OICyccXkoM9JAdHIO91BwdIjo6Avb3hcwcHuBj5vUoN3YiICGzfvr2CLVF5jBgxCg4Ojpg8\neRz69OmBTZu2onnzlnKXRUTWSJKAjIwioWd4nw4UfC2taEgWHZf7ulZb7tIcCpZZXAB6+0BydCga\ngA4ORQPRwQGSfe57jo4FxjoCDvaGz/O2aWsLCGU7NluYyUKX5BUZOQQODg4YO3YU+vV7Dhs2bMbj\nj7eRuywiqkyiqAuylBQIqakQUpINH6emQJWaCkg5cI5/UCQQ9Y/T8h8jIx2CCdbzllQqSE7OkJyc\nACcniN4+kJyc9K9JTk6QnAs8dnIGDN43fM+rpi/i00VdANo7AHZ2FQ5AS8bQtQJ9+z4POzt7jB49\nHAMH9sGaNZvQvv2TcpdFRAVJku64YEpKbiimFB+aqbrHKv17KbrX8h6npEBISzU6IJ2KK8XWVh9u\nors7pIDqucH38PArGJYlhSTs7U0bir6ukCz5RCoTY+haiR49noODw3oMH/4iBg9+HqtXr0GXLk/L\nXRaR9cvJyZ095oeeKi0lPwALhmZaam5gFgzRlPyvL+fFgyQ7O0iurpBcXCEGBUN0dc197gLJxS3/\nsasrJFfc0+i4AAAgAElEQVQ3iC4ukFxc4FHTDwlZAJxzw9HRUReMtram/TMik2HoWpHOnbth/frN\nGDJkIIYOjcSnn36JHj16yV0WkbxEEUJyEoTERKiSEiEkJkJISoQqMbH415ISgdRkeCcl6YKynMtr\nSmo1JBddOIoB1SE560JRdHXLD0gXV/2Y/OB0g+iS/1hycdHNHsvD1xXaKjRLVAKGrpVp164DNm3a\nhkGD+mPkyKH48MNV6N9/oNxlEVWMkcGpSnxQJECF5KQyHauUnJwAd3eInl6QagUWmUnqQzMvLAuG\npqsrRGfdPRwdFX3skSoHQ9cKPfZYa2zZshMDBvTB2LGjkJWVhcGDX5K7LKrqSgzOB/qQzAvSigan\n6O4BsXp1iPXqQ/LwgOTuAbHQveThAdHDE5KHJ0R3D0ju7oC9PXx9XfGAM0SSAUPXSjVr1gLbtn2P\n559/DpMnj0NmZgZefnm03GWRUmRkQBUfB9X9eKjux0OIj4fq/n2o7scDmalwi40zTXB6eEKsXgNi\n/UfyQ1Iflh6FXvPUvwc7u0psnqjyMHStWMOGjbBjxx707dsDM2dOQ0ZGJsaNmyh3WWSJ0tL0AaoP\n0fgCz+/H54bsfaji43UXLShB3hFIyckZoocHg5PISAxdKxcRURe7du1B3749MW/eW8jISMfrr8+A\nwGNNyiVJhiEaHwchNyzzn+cFqm52KqSnl75Ze3uI3j7QhIVD8vaG6O2ju/n4QPLxzX3uDc/QWojX\n2up21TI4icqEoasAISFh2LlTN+NdsuQ9ZGZmYvbstxm81kKSdB9FiYszDMr4eMNdvPfv658bc8at\n5OCgC9HwiEIh6gvJx0cfonnPJWcX404MqmKfqyQyJYauQgQGBmHXrp/Qt28PrFy5HBkZ6Zg/f7Hc\nZVVdkqQ7eSg2Fqo7t6G6GwukJcL5+q1Cx0lzH2dllb5JR0eIPr7Q1K2nuwpQboDqZ6O5AZoXrnB2\n5tm1RBaGoasgAQHVsWPHHvTv/xxWr/4MWVlZ+Prr1XKXpTzp6VDF3oH6bm6g6oP1DtR37kAVeweq\nu7HFzkYLXj1IcnKG6OMDTf1HdCFaIDAfGqJEZNUYugpTrVo1bN/+PQYM6IN1677BvXt3MG/eYtSu\nHSJ3aZZPo4Hq3l1daBYIT/Wd27rHsXd0AZuU+NBNSCoVRN9qutmof4D+pg2oDrewIDywdc4PUafi\nLuBHRErG0FUgLy9vbN26C6+8Mgy//PIL9u/fjwkTpmDs2ImwL++Vb6yZJEF4kKAL0oKz0dhYqGIL\nzFTj7pX4kRfRwwNiQAA0TZvlBmkARL8AiAHVIfr76+59fAGbh/y18nWFhsdCiao0hq5Cubm5Y+PG\nrfjzzz2YMGEiFi2ajy1bvsWiRcvQrl0HucsznbQ0qO8WmJnmBqsqNm+GGgvV3TslHjOVHBwg+gcg\np9XjEIsJUq2fP0T/AN0ViIiIKoChq2CCIGDAgAFo0aINFi2aj9WrP0O/fj3Rp08/vP32Qvj5+cld\n4sPlnoikvhEDJMXB4eKV/BlqgWBVJSc9fBMqFUQ/f90xU78AXaDm7uoV/fz1wSq5e/CEIyIyC4Zu\nFeDm5o758xdjwIBBmDZtErZt24Jff/0FM2fOxtChL0OtVstTWFoa1DdioI65BlXMdaivX4c6RndT\nxVyHKiVZP9S10JeKnp4Qa9SEpnkLaP0Dip2hij6+gFy9EREVg6FbhTRq1AQ//PAb1q79GvPnv40Z\nM17Hpk0b8P77/0OTJs1M/w2zs6G6dVMfpLowvaZ7fP06VPFxxX6Z5OQEbWAQcgJbQxsYBKd6dZDs\n6gWtf26g+gcADg6mr5eIqJIxdKsYtVqNoUNH4JlneuDtt2dh8+ZN6Nr1SQwdOgIzZ74Fd3cP4zcm\nirqPzsRch+r6NYNZqjrmOlR3bkMQxSJfJtnaQluzFjSPNIA2MAjawCCIuffawGBIPj4Gu3udfF2R\nxROQiEgBGLpVVLVq1fDRR59h0KAXMW3aJHz11Rf4/vtdePvt+ejb93nd1awkCUJCAtS5s1OVfvdv\n7u7gmzcgZGcX2bYkCBADqiPn0ccKhGkQxKBg3b1/AHf7ElGVxNCt4to2boIDH32OXz/7GCd3bkPW\nqyNxcdZ0NPX0hGNsLFRpqcV+nejjkztTDS4UrEHQ1qhV/kW5iYgUjKGrdFlZUF+OLjBLzdv9mzt7\nTUgAAAzOvQEAEu4jOeE+Yn184dGmLVA7JDdYdTNVba1AwMVFro6IiKwWQ1cJJAlCfDxsoqOgvhgF\ndXQUbC5GQX0pGrh9C17FXPBBsreHtlYgNI2b5odpkC5Qf4m+iCnz38btO7cReOEC3hs6Ak891VWG\nxoiIlIWha01EEapbN2Fz8QLUFy/mh2t0FFQPHhQZrq1eA2jfHhkBNQ1OVBKDgiBW8wNUqmK/Taem\nzXHwmR5YunQRPv30Iwwa1B/du/fEu+++hxo1alZ2l0REisXQtUQ5OVBfvQL1xagCs9eLsLl0sci6\nqJJKBW1wbeS0ehza8Aho6kRAWycC2vA6kFxc4evritRynPnr4uKCOXPm4fnnX8C0aZPwww+78Oef\nv2PatJkYOXI0bG1tTdUtEVGVwdCVU1oabC5dzA/V3Fmr+uoVCBqNwVDJwQHa0HBo6tTJD9fwCGhD\nQiv1pKV69epj5849+PbbDXj77VmYO/dNfPvtBixe/D+0avVYpX1fIiIlYuiagXD/ftHjrdEXob55\no8hY0d0DmibN8kO1Th1owiMg1gqU7WM2KpUKL7wwGF27Po13352Ldeu+QY8eXRAZOQSzZ78NLy9v\nWeoiIrI2DF1TkSSobt8qsEv4ItQXL8AmOgqq+/eLDNf6+SP7iQ76UNXWiYAmPAJStWoWex1gLy9v\nLFv2IQYMiMS0aZOwfv0a7NnzPd56ax4GDoyE6iHHiImISIehW1YaDdTXrhaatUZBHR1d5DOtkkoF\nMTAIWc1bFtglXAfaOhGQ3NxlaqDiWrV6DL/9th9ffPEpFi2aj4kTX8OGDWuxePH/UL/+I3KXR0Rk\nsRi6D5OeDpvTJwuEq+5sYfWVyxBycgyGSnZ20IaGI7tAqGrCI6ANDVPsNYJtbW0xZsxYPPdcb8ya\nNR3ff78TnTq1xahRr2Hq1Olw4ed4iYiKYOgCEFKSYXPqJGxOnoDNqeOwOXkCuHIZnoU+3yq6uELT\nsBG0deoW2CVcB2JQcJW9rGH16jXw5Zdr8dtvP2P69Nfx8ccfYMeOrZg/fzGeeeZZ3eUkiYgIQBUM\nXSE5CTanT+kC9uR/uvsrlw3GiG7uQLt2yKgdVuCEpgjdNYMZIsV66qmuOHCgHVasWIIPP1yOYcMi\n0blzVyxY8D6CgoLlLo+IyCIoOnSF5KQiM9giAevugewn2kPTqAk0TZoip1ETiMG14VvNrVyfb63K\nHB0dMX36bPTtOwBvvDEZv/76Mw4e3I9Jk17Hq6+Oh52dndwlEhHJSjGhW6aAbdwUmsZN9AHL2atp\nhYfXwdatu7F163eYM+dNLFjwDjZv3oRFi5ahbdt2cpdHRCQbqwzdIgF74jhsrl4xGKML2A7QNG7C\ngJWBIAjo128AOnfuioUL5+Grr75Anz7Pol+/AZg7dz6qVasmd4lERGZn8aGrD9gTx/NnsKUFbOOm\nupObGLCyc3f3wHvvLcWAAYMwbdpkbNnyLX755Se8+eYcDBkyDOoqegIaEVVNFhW6QlJi0V3EJQRs\nTpOm0DRqwoC1Ak2bNsdPP/2Br79ejQUL3sEbb0zGpk3r8P77y9GoURO5yyMiMgvZQteogPXwQHa7\nJ3Nnr00YsFZOrVZjxIhX8OyzPTFnzkxs27YFXbp0wPDhIzF9+iy4WfEFQ4iIjGGW0DUI2JPHYXvy\nONTXrhqMYcBWHX5+/li16ku88MKLmD59Cr744lPs2rUD8+YtRK9effnZXiJSrMoJ3T/+gOPev2Bz\n6oRxAdu4KcTAIAZsFdO+/ZPYu/cwVq5cjuXLl2DUqOFYv34tFi1agtDQcLnLIyIyucoJ3U6dkHcR\nQIOAzTsGy4ClXPb29pgy5Q306dMfM2ZMxR9//Ib27R/HuHGTMGHCFDgo9DKaRFQ1VU7oTp+OpPD6\nDFgyWu3aIdi4cSu+/34XZs16A0uXLsLWrd/lnvncW+7yiIhMotS12GJjYzFkyBA888wz6NGjB9as\nWVP6VhcuRHaPXjwmS2UiCAJ69HgOf/31D0aNeg03bsRg4MA+6NevH44d+wdSoWthExFZm1JDV61W\nY8aMGfjxxx+xadMmrF+/HpcvXy7ty4jKzcXFFfPmLcSvv+5HixaPYuvWrXj66U7o0OFxfPbZx0hI\nKLo+MRGRNSg1dH19fVGvXj0AgLOzM0JDQ3Hv3r1KL4yoQYOG+P77X/DTTz+hZ8/euHQpGrNmTUej\nRhEYNWoY9u/fC1EU5S6TiMhoZTqme/PmTVy4cAGNGjWqrHqIDKhUKnTt2hXNmrVGfHw8Nm/ehHXr\nvsb27VuxfftWBAUFIzJyCAYOjIS/f4Dc5RIRlUiQjDxQlpaWhhdffBGvvvoqnnrqqRLHBgej2BnI\nsWNpxY5v3ty52NflHK9SqYr0YE315ynYhyXUU57xeT3kjZckCX//fRTr13+DXbu2Iz39LADAwcER\nLi4ucHBwgCAIFlN/npgYFeKKWbnK0v/8C4/39XU16EPuesozvmAPllBPecf7+roiMLD4vT3WUD8A\ntGzpavV5Aej+fhvDqJmuRqPB+PHj8dxzz5UauHlUqqIF+Pq6PmRs8duQe3zhHuSup7zj8/qwlHrK\nM16lUhmMf/bZznj22c5ISkpCSIgKqakpyMzMQGZmBtRqNVxcXJCUdB9hYWEWUX9JX2MNf/6Fxxd8\nbAn1lGd83nNLqaf844v/AmupX/c11p8XxjJqpjtt2jR4enpixowZRm+4uP/RW5PC/5u3Vkrow9ge\nTp8+hQ0b1mDLlu+QlJQIAGjbth0iI4ege/eesn/mVwk/C0AZfSihB0AZfSihB6Dk/1QUVGpGHzt2\nDLt378aRI0fQq1cv9O7dG/v3769wgUSm1rBhIyxcuASnTkXh448/R5s2T+Dgwf0YM+ZlNGpUBzNn\nvo6zZ8/IXSYRVWFGH9MtK2v/n4uS/vdl7X1UpIcrVy5hw4Z12LhxHeLidGfdN2vWHJGRL6F3775w\ncTHuf6emoISfBaCMPpTQA6CMPpTQA2DCmS6RNQsJCcOsWXNx4sR5fPPNRnTp0g0nThzHlCnj0aBB\nHUyc+Br++ecoL7xBRGbB0KUqwdbWFk8/3R3r1n2H48fPYcaM2fDx8cWGDWvRvXtntGvXCqtWrcT9\n+7zwBhFVHoYuVTkBAdUxadLr+PvvE9i8eSd69eqDq1ev4K23ZqJRozoYOXIo9u79gxfeICKTk20R\neyK5qVQqtG//JNq3fxL379/Hli2bsH79GuzcuQ07d25DYGAQXnhhMF54YTCqV68hd7lEVMmys4H0\ndCA9XShwr3uclpb/WkZG0TEbNxr3PXgi1UMo6eC+tfdhzh4kScKxY/9g/fo12L59K9LT06BSqdCx\n41OIjHwJXbp0g62tbbm2rYSfBaCMPpTQA6CMPsrSgyiimMDLv8/IKDksC79XOEA1mvIv0GNsknKm\nS1SAIAho0eJRtGjxKObNW4gdO7Zh/fpv8Ntvv+C3336Br281DBgwCIMHD0FISNELbxCRjiQBaWlA\naqqAlBQBKSmGj9PSdI+1WiA+3v6hQVg4LE1BpZLg5AQ4OenuvbzEAs91rzk7S3B0zB9jeF/0NehX\nkS8ZZ7oPoYT/QQLK6MMSejh37iw2bFiDzZs34cGDBwCA1q3bIjJyCJ599jk4OjqWug1L6MMUlNCH\nEnoATN+HJAFZWXnhmB+SqanIDUvd4/zwzH+vuK+RpPKHpL19ySFX8N7R0XCMs/PDw9LREbC3N/2q\ns8Z+ZIih+xD8S2k5LKmHzMxM7NnzPdatW4MDB/YCANzc3NGv3/OIjHwJDRs+fDEQS+qjIpTQhxJ6\nAPL70GhQKAx1j4ubZRYOybzHea/n5JQvjezsJLi6SnBxAVxcdI9dXQFXVwnOzvmPdWN0z11cJNSs\n6YTs7LQiM0y12sR/WJWMoVtBSvtLac0stYdr165i48a12LhxPWJj7wAAGjduisjIIejTpx/c3NwN\nxltqH2WlhD4srQdJ0h2rTEwUkJgoICkp7x548CD/eeH309JUSE6Wyr3bVRAMw9DZuWAwokBA5j/P\nC1NdkOaHp719+Xq3tJ9FeTF0K0hJvwjW3oel96DRaPD7779i/fpv8OuvP0Or1cLR0RE9e/ZGZORL\naNXqMQiCYPF9GEsJfVRWD5mZKBSQMAjJwqFZ8P3sbOOD09ZWgru7BC8vFRwdtUVmjwXDMO/1ggGa\n956Tk+l3s5aVEn6fAIZuhSnpF8Ha+7CmHmJj7+Dbbzdg/fo1uHbtKgAgLCwckZEvYeTIobCzc5O5\nwoqzpp/Hw5TUQ04ODELz4YFZ9P3MTOMTTK2W4OEhwd0d8PCQ9Dd3d6nQ86Lv54Wl0n8W1oShW0FK\n+kWw9j6ssQdRFHHo0EGsW/cNfvhhF7KysgAAoaFhaN26rf4WEFBd5krLzlp+Hnlnz8bHC7h/X0B8\nvID4eBXu3xeQnm6PO3dy9KFZcBduWXbVCoIuFN3dJXh65gem4fOi73t46HbXVnSWaS0/i5IooQeA\noVthSvpFsPY+rL2HBw8SsG3bZhw48Cf27z+A1NT8XkJCQvUB3KbNE1YRwnL+PDIzDUM0Li7vsapQ\nuOoeZ2QYl2pubg+bZepC82GzUFfXsq+nakrW/ncDUEYPAEO3wpT0i2DtfSihB0DXx507D3DmzCn8\n9ddBHDp0AEeOHEZKSrJ+TO3aIWjT5gk8/ngbtGnzhEVeCcuUP4+cHCAhQReexYWmLlhV+sepqaWH\nqL29BB8fw5u3twQfH1H/PDTUCZKUCg8PCW5ugI2VXrFACX83lNADwNCtMCX9Ilh7H0roASi+D61W\naxDChw8fMgjh4ODaBiFco0ZNc5ddREk/D61Wd7Zt4QDNn5EWfE+FxMTSQ9TGJi808wPU17domOa9\n7uxc+m5bJf9OWRsl9AAwdCtMSb8I1t6HEnoAjOtDq9Xi7NnTBiGcnJykfz8oKNgghGvWrFXZZUOj\nAe7dE3DnjoDYWBUyMx1x7VpWkRlpfLyAhAQBolhy4qlUEry8Cs9Ciz7OC1N398q5kEFV+Z2ydEro\nAWDoVpiSfhGsvQ8l9ACUrw+tVotz587gr78O4NChgzh8+BCSkhL17wcGBqNNm/wTs2rVCizT9tPS\ngNhYAbdvq/SheueOgNu38x/fu1d6kHp46EKy5Bmp7ubpKcl+4YOq/DtlaZTQA8DQrTAl/SJYex9K\n6AEwTR+6ED6LQ4cO4K+/DuLw4b8KhXAQWrdui8cea4v69dtDrQ7MDVEVYmMF3LmjC1LdTYXk5IeH\nqZ2dBH9/CQEBIgICJAQESPD3FxEW5gBb23T4+OhC1ctLQjnXgJANf6cshxJ6AIwPXSs9fYCoalKr\n1ahTpxHc3BqjcePx6NVLwokT93DqVDyuXMnErVu22LTJD5s21QBg99DtuLtLqFFDRPPmulD195dQ\nvXr+44AA3ey0uN26vr4OiIvTVl6TRArG0CWyEJIEJCejwK5e3Wy04K7e2FjdCUiGaufedMdLfXyy\n4eAQj5yca0hMPI2srCsAbgG4CT8/EW3ahKB9+1Zo3botAgODIMh9SSKiKoShS2QGWi1w6xZw5oyq\nwK7egrt7da+VdGEGJyfdDLRuXU3uzFTM3eWrm6FWr67b3as7XuoKoCFE8RGcP38Ohw8fxF9/peDw\n4YPYtu0Atm37BgBQo0ZN/WeEW7dui6CgYIYwUSXiMd2HUNJxBmvvw1p6SEkBrl9X4do1Fa5fF3D9\nukr//ObNkldv8fExPG4aEKAL1bxdvQEBItzcKn4WryiKuHDhfG4IH8Thwwdx//59/fs1atTUnxnd\nunVbBAfXLhLC1vLzKIkSegCU0YcSegB4IlWFKekXwdr7sJQetFrdmb7Fher16wLu3y/+0kQ+PiKC\ngiSEhqrh5ZVtcGJSQIAIP7/yr9BSUaIoIirqAg4dOoBDh/7CoUMHDEK4evUaBiFcu3YIqlVzs4if\nR0VYyu9URSmhDyX0ADB0K0xJvwjW3oc5e0hNhT5M84JVF6oq3LhR/EowtrYSAgMlBAWJ+ltwcP5z\nFxfz91FekiQhKuoC/vrrAA4f1oVwfHy8/n1//wA0bdoEgYEhqFMnAuHhEQgPrwNvb28Zqy47a/hZ\nGEMJfSihB4BnLxMVSxR1s9W8UL12LT9Ur18v7iQlHW9vEQ0aFAxV3ew1KEg3a5X7c6emIggC6tat\nh7p162HEiFcgSRIuXozSh/CRI4ewZ8+eIl/n7e2tD+D8WwRq1qwFlZwXJyayMAxdUpz0dBjMVAvu\nAo6JUSErq+hs1cZGQq1aEho00BQJ1aAg3fHUqkgQBERE1EVERF0MHz4SAGBjo8GRI/8hOvoiLl6M\nwqVLuvu//z6CI0cOGXy9o6MjQkPDUadOnQKhHIGQkFDYy7VPnUhGDF2yOpIE3L1reGy14Gz13r3i\nZ1aenhLq1Ss6Uw0K0p35a60XvTc3T09PtGjxKFq0eNTg9czMTFy9egXR0VEFwvgiLl+OxpkzpwzG\nqlQqBAUFo06dCISF1cndVa2bIbu7e5izHSKz4j8zZJEkSbcb+MIFFWJjgTNn7PWhGhOjKnbJNrVa\nQs2aEtq10+hDVXevu7m7y9BIFeLg4IB69eqjXr36Bq+LooibN2/khvFF/cw4OjoKP/+8Bz//bLi7\nulo1v9wwDjc4bhwQUJ0fZyKrx9AlWUmS7mL6Fy6oEBWlu124oMbFiyokJRX8B1Z3dSU3Nwnh4WKB\nMJX0M9caNThbtUQqlQqBgUEIDAxCp05dDN67f/8+oqOj9Luqo6OjcOlSNA4e3I+DB/cbjHV2dkF4\neDjCwyMMZsjBwbVha23XoaQqi/9EkdnExeWHa37Iqoss76ZWSwgJEfHEEyIiIkS0bGkPb+80BAWJ\n8OCeR0Xx9vaGt3drPPZYa4PX09PTcflydIEw1s2Qz507ixMnjhuMtbGxQe3aIQXCOFx/7+Ji3Bml\nRObC0CWTu39fKBSsulvhz7GqVBJq15bQurUGdevqAjYiQkRoqGjwuVVfX3vExYlm7oLk5OTkhIYN\nG6Nhw8YGr2s0GsTEXEN0dLTBSVzR0RcRHX0RP/6422B89eo1DM6mzpsh+/i4mLMdIj2GLpXbgwdA\nVJS60K5hVZGP3QiChKAgCS1b5hiEa1iYCAcHmYonq2RjY4OQkDCEhISha9en9a9LkoR79+4VOYkr\nOjoK+/b9iX37/jTYjouLC/z9A+DvHwA/P//cez+D1/z8/OHk5GTuFknhGLpUqqQk4MIFtUGwRkWp\nij1LODBQRJcuGkREaBERIaJuXV248t8uqkyCIMDPzw9+fn5o27adwXupqSkFPt6kmyHfvHkdt2/f\nxqVL0SVu193dA/7+/vDzC4C/v39uKPvnhnL+Y378iYzF0CW9lBTkBqraIFxjY4uGa61aIp56SpM7\na9Wibl0R4eEinJ1lKJyoBC4urmjatDmaNm2ufy3vKkhZWVm4d+8u7t6NRWxsLO7evYPY2FjExt5B\nbOyd3NfvICrqQonfw8vLq5hgDjAI6WrV/HjCFzF0q6LUVODixfwzhfPC9fbtouFao4aIjh01ubNW\n3ey1Tp38SxsSWTN7e3vUqhWIWrUCSxyXkZGBe/fuFgjm/HDOC+Zbt27i/PmzD92GIAjw9vbRB3HB\nXdsFw9nHxxc2PA1fsfiTVbD0dODff4HDh230s9eoKBVu3CgargEBIjp00Oh3CeftHnblyZ9EcHR0\nRFBQMIKCgkscl5aWhrt3Y/VBnB/M+Y+vXLlc5GIhBalUKvj6Vis0Yy46g7a2612TDkNXITQa3a7h\n48fVOH5chf/+081gRREAHPXjqlUT8cQT+WcL581eeeEIoopzdnZGSEgoQkJCSxyXmppisBu74K7t\n/F3a53Hy5PGHbsPGxgZeXl5wc3OHu7s73N09Ctx76J97eHjAza3ovVopFwy3MgxdKyRJQEyMgOPH\n1fjvP13InjqlNrhKk5OThJYttWjZ0gaBgZn62aunp4yFExEA3XHmsDBXhIWFP3SMJElITk4q9hjz\n3bt3ERt7B8nJiUhIeICYmOvIzs4uUw2urm7FhLW7QVjnv+ZpEOCOjo68Olg5MXStQEICcOJEXsDq\nQrbgx3JUKgl164po1kyLZs1ENG2qm73a2OSdMJIjY/VEVB6CIOhnrBERdYsdk3dCmCRJyMjIQHJy\nEhITE5GUlISkpAe597rniYmJ+vcL3sfEXEdKSnKZarOzs9PPmjnLLhuGroXJyABOn87bTawL2mvX\nDI/BBgaKeO65HDRtqgvZhg21PGuYqAoTBAFOTk5wcnKCv39Amb9eq9UiOTnJIKSTkhILBHhigZth\nkF+/fg05OWX7j72rq5s+gH18vGBraw9HR139jo6OcHJy1t87ORV87lRgXP69s7Pu3hrCnKErI60W\niI5W4b//VPpZ7PnzKmg0+bttPD0ldOyoyQ1YLZo0EeHrK8lYNREpjVqthqenFzw9vcr8tWWdZRcM\n7piY6zh79rTJ+rC3ty8S2oXDOu/2sJA3DHvDcLezs6vwbnWGrplIEnD7tqA/Bnv8uBonTqiRlpb/\nA7S3l9CkiW43cdOmulvt2hJ46ISILFVFZ9ne3s6IibmH9PR0ZGSkF3ufd8vIyEB6elqh+4LjdK+l\npaUjOTkZsbGxyMhIhyia5jKyarW6SFjnzcT3799r1DZKDd2ZM2di79698Pb2xu7du0sbTrmSkqDf\nRZx3NnHBKzgJgoSICBFNm4r6WWzduiLs7GQsmojIzFQqFZydneFcScfIJElCVlZWgSDXBXZ6enEB\nXtmry4cAAAsRSURBVFyQFwx9w/vExESkp6eVafd6qaHbp08fvPjii5g2bVqFGleyrCzg7FmVwdnE\nly4ZHluoXl1E9+45aNpUN5Nt3FjLz8ASEVUyQRDg4OAABweHcu0+N4ZJQ7dFixa4detWhQpSElEE\nLl/WHYfNm8meOaNCTk7+PmBXV91C6rrdxLqZrL8/j8MSESlRWS7vyWO6pbh7V3ccNu9kpxMn1EhJ\nyQ9YW1sJDRqI+mOwzZrplqZTFb3oExERVXEM3QIkCTh/XoUDB9Q4fhw4csS5yPWIw8K06NYt/2Sn\nRx4xXPuViIjoYSotdH19reOA5Y0bwG+/6W6//w7cvZv/np+fCj17Ao8+CrRqBbRoAXh4qAGoAVjP\naiHW8rMoiRJ6ANiHJVFCD4Ay+lBCD8YyKnQlqezHI+PiUsr8NeaQlAQcPGiD/fvV2L/fBpcv589k\nq1UT0a+fFu3aadCzpyMcHVMMPq6TkwPExclQdAXkXbHGmimhB4B9WBIl9AAoow8l9AAY/x+HUkN3\nypQpOHr0KBITE9GhQweMGzcOffv2rXCB5pKVBfzzj1ofsidOqCCKuiR1dpbQpYsG7dpp0K6d7tKJ\neSHr62t9AUtERJat1NBdunSpOeowGVHUfXxn3z5dyB49mr8QgI2NbhGAdu10t2bNtOCa0kREZC6K\nOJHq+nUB+/frdhkfOKBGQkL+LuN69fJCVoPHH9dy8XUiIpKNVYZuQoLuuGzebPb69fyQrV5dxMCB\nOWjXToMnntDCz4+fjyUiIstgFaGbkQEcPZp/XPb0aRUkSbfL2M1NwjPP5Ohns6GhvFYxERFZJosM\nXa0WOHVKpd9l/PffamRl6ZLUzk5Cmzb5u4wbNdKtG0tERGTpLCKuJAm4elXAvn26kD140AZJSfnT\n1YYN80O2VSstnJxkLJaIiKicZAvde/cEHDyYv8v45s3847KBgSJ69tTtMm7TRgsfHx6XJSIi62e2\n0E1NBY4cUetns+fP56/C4+kp6UO2XTsNgoMZskREpDyVFro5OcDx4/nHZf/9Vw2NRrfL2MFBQvv2\nugtStG+vQYMGXCCAiIiUr1JCt2dP4M8/XZCaqgtZQZDQpImov/JTy5ZaODhUxncmIiKyXJUSurt3\nAyEhEvr10+0ybttWAw+PyvhORERE1qNSQvfaNcDJKa0yNk1ERGS1KuVIalBQZWyViIjIuvH0JSIi\nIjNh6BIREZkJQ5eIiMhMGLpERERmwtAlIiIyE4YuERGRmTB0iYiIzIShS0REZCYMXSIiIjNh6BIR\nEZkJQ5eIiMhMGLpERERmwtAlIiIyE4YuERGRmTB0iYiIzIShS0REZCYMXSIiIjNh6BIREZkJQ5eI\niMhMGLpERERmwtAlIiIyE4YuERGRmTB0iYiIzIShS0REZCYMXSIiIjNh6BIREZkJQ5eIiMhMGLpE\nRERmwtAlIiIyE4YuERGRmRgVuvv370e3bt3QtWtXfPbZZ5VdExERkSKVGrqiKGLevHlYvXo1vv/+\ne/zwww+4fPmyOWojIiJSlFJD99SpUwgKCkKNGjVga2uL7t274/fffzdHbURERIpSaujevXsXAQEB\n+ud+fn64d+9epRZFRESkRKWGriRJ5qiDiIhI8WxKG+Dv74/bt2/rn9+9exfVqlUrdcO+vq4Vq8wC\nKKEHQBl9KKEHgH1YEiX0ACijDyX0YKxSZ7oNGzZETEwMbt26hezsbPzwww/o1KmTOWojIiJSlFJn\numq1GrNnz8bw4cMhSRL69euH0NBQc9RGRESkKILEg7ZERERmwStSERERmQlDl4iIyEwYukRERGZS\n6olUZbF//34sWLAAkiShb9++eOWVV0y5ebOYOXMm9u7dC29vb+zevVvucsolNjYW06ZNQ3x8PNRq\nNfr3748hQ4bIXVaZZWdnIzIyEjk5OdBqtejatSvGjh0rd1nlIooi+vbtCz8/P6xatUrucsqlY8eO\ncHFxgUqlgo2NDbZs2SJ3SeWSkpKCN998E9HR0VCpVFiwYAEaN24sd1lGu3r1KiZNmgRBECBJEm7c\nuIEJEyZY5d/xr7/+Glu2bIEgCKhTpw4W/r+9u3mJag8DOP6dHKRQexElCyzIjCySFr1AEyamSTXV\nxGCLNiVRbdIow14oghYJLfoHWkREEBEaRG1EszGmQiuGYIgwIhhMKkRT5yXPnOcu4l64G+89x7nz\na7rPZz1n+A6HmYcznHmmo4P8/HzTWY7cunXrr/fCv/qslQxJp9NSX18vsVhMfvz4IXv37pWhoaFM\nPX3WDAwMSDQaFb/fbzrFtS9fvkg0GhURkcnJSdmxY0dOngsRkXg8LiIilmVJU1OTRCIRw0Xu3Lx5\nU9ra2uT48eOmU1yrq6uTsbEx0xmzdvbsWbl//76IiExPT8vExIThIvfS6bT4fD4ZHh42neLYyMiI\n1NXVSSqVEhGRkydPSldXl+EqZ96/fy9+v19SqZRYliWHDx+WT58+zXhMxr5e/l12NG/YsIH58+eb\nzpiV0tJSqqqqACgoKKCioiJnV3fOmzcP+HnVa1mW4Rp3RkZGePr0KU1NTaZTZkVEsG3bdMasTE5O\nMjg4SDAYBMDr9VJYWGi4yr1wOMyyZcv+tqo3l9i2TSKRwLIsksnkv1q89Cv58OED69evJz8/n7y8\nPDZu3Eh3d/eMx2Rs6OqO5l9TLBbj3bt3VFdXm05xxbZtAoEAPp8Pn8+Xk6/j6tWrtLe34/F4TKfM\nisfj4ciRIwSDQe7du2c6x5VYLMaiRYs4f/48+/fv59KlSySTSdNZrj1+/Jjdu3ebznBl8eLFNDc3\nU1tbS01NDUVFRWzZssV0liOVlZUMDAwwPj5OIpEgFArx+fPnGY/J2NAV/bnvL2dqaorW1lYuXLhA\nQUGB6RxX5syZw4MHDwiFQkQiEYaGhkwnOdLX10dJSQlVVVU5/x65e/cunZ2d3Lhxgzt37jA4OGg6\nyTHLsohGoxw8eJCuri7mzp2bs/8RPj09TW9vLzt37jSd4sr379/p6enhyZMn9Pf3E4/Hc+4+moqK\nCo4ePUpzczPHjh1j9erVeL0z3yqVsaHrdkez+m9YlkVrayv79u2jvr7edM6sFRYWsmnTJvr7+02n\nOPL69Wt6e3vZvn07bW1tvHz5kvb2dtNZrpSWlgJQXFxMQ0MDb9++NVzkXFlZGWVlZaxbtw6AxsZG\notGo4Sp3QqEQa9eupbi42HSKK+FwmPLychYuXEheXh4NDQ28efPGdJZjwWCQzs5Obt++zYIFC1i+\nfPmMj8/Y0P2ddjTn+hUJ/LwLe+XKlRw6dMh0imujo6NMTEwAkEwmef78OStWrDBc5czp06fp6+uj\np6eH69evs3nzZq5du2Y6y7FEIsHU1BQA8XicZ8+eUVlZabjKuZKSEpYsWcLHjx8BePHiRc6utX30\n6BF+v990hmtLly4lEomQSqUQkZw9F6OjowAMDw/T3d39j+ckYz8Z+l12NP95NTI2NkZtbS0tLS1/\n3XSRK169esXDhw9ZtWoVgUAAj8fDqVOnqKmpMZ3myNevXzl37hy2bWPbNrt27WLbtm2ms/6Xvn37\nxokTJ/B4PKTTafbs2cPWrVtNZ7ly8eJFzpw5g2VZlJeX09HRYTrJsWQySTgc5sqVK6ZTXKuurqax\nsZFAIIDX62XNmjUcOHDAdJZjLS0tjI+P4/V6uXz5MkVFM/9jku5eVkoppbJEN1IppZRSWaJDVyml\nlMoSHbpKKaVUlujQVUoppbJEh65SSimVJTp0lVJKqSzRoauUUkpliQ5dpZRSKkv+AO2e4yf8wTuC\nAAAAAElFTkSuQmCC\n",
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAAAd8AAAFKCAYAAABcq1WoAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAIABJREFUeJzs3Xd4U2X/BvD7ZLRpumlLS6EDgbKh\niIggU7aAgPhDRKsIUoYgiK++ioAguBARXmZBEARFUBGhiChIEQcqe+/RMlpGd9KRcX5/nDZtaFra\nkuY07f25rlw5zXmSfPMk5OY5Oec8giiKIoiIiMhhFHIXQEREVN0wfImIiByM4UtERORgDF8iIiIH\nY/gSERE5GMOXiIjIwVSOeJJbtzLs/pi+vlqkpOjt/rhkjf3sGOxnx2A/Owb7WRIQ4FnsOqcd+apU\nSrlLqBbYz47BfnYM9rNjsJ/vzWnDl4iIyFkxfImIiByM4UtERORgDF8iIiIHY/gSERE5GMOXiIjI\nwRi+REREDsbwJSIih/vxx61YtGi+3GXIhuFLRETkYA45vSQREZEtGzeux65dPwMAOnbsjOeeG45/\n/tmHFSuWwNVVA1/fGnjnndk4eHB/kdtUKueNMKesPDZ2C7p37wSNxkfuUoiInN6MGVOxdetmuz2e\nQiGgb98BmDFjdontbty4hgMH/sGKFV8AAKKjX0DXrt3x3XcbMH78q2jZshX27PkVaWmpNm/z8/O3\nW82O5nSbnTMy0jFixHN48cUX5S6FiIjuw9mzZ9G0aXOoVCqoVCo0b94S58+fRdeu3fHxxx/giy9W\noUGDhvDz87d5mzNzupGvp6cXOnTohF27duHkyRNo0qSp3CURETm1GTNm33OUWhYBAZ6lms1OEABR\nFC1/GwwGCIICvXv3Rdu27fDbb3H4739fxezZc2zeFhYWbreaHc3pRr4AEB09DgCwYsVSmSshIqLy\niohoiOPHj8FoNMJoNOLkyROIiGiI1as/g1KpwoABT6Jbt564fPmizducmdONfAGgR49eqFevHr79\ndgPefnsG/P2de/MDEVF1FBQUjFatHsKECdEwm0X07z8AQUG1EBgYhEmTxsHT0wuenp4YOvQ56PX6\nIrc5M0EsPOavIKXZ/FBW69d/jokTJ+LNN6di8uQ37P74JCnt5iO6P+xnx2A/Owb7WRIQ4FnsOqfc\n7AwAL774Ijw9vbBq1Qrk5ubKXQ4REVGpOW34enp6YtiwKNy8mYQfftgkdzlERESl5rThCwAvvTQa\nCoUCMTFL4ICt50RERHbh1OEbFhaO3r374ujRw/j7731yl0NERFQqTh2+ADB6tHTY0fLlS2SuhIiI\nqHScPnwfeaQ9mjdviR9/3Ir4+Ctyl0NERHRPTh++giAgOnoszGYzVq5cLnc5REQkk/Pnz1kGYe+8\n8xZycrLL/ViHDx9ESkqyvUorwunDFwAGDhyMgICa+PLLL5CZmSl3OUREJIM9e35FQkI8AGDmzA/g\n6qop92Nt27alQsPXKc9wdTdXV1e8+OJLmDPnfWzY8BVGjoyWuyQiIrqHYcMGY+3ajRBFEX36PIaF\nC5ehUaMmmDx5PN54420EBdWCyWTCnDnv4fr1azAajXjppTFo3boNtm+PxaZNG6FSqVG/fgQGDhyM\nH37YhD17foWvry+mT38LX3yxAZ9+Oge+vr44c+Y0UlNT8OyzL2Dbtq1IS0vFokXLIQjAzJlTkZWV\nhezsbLz66uvQ6TKxd28cLl26iNmz5+DMmZP4+ut1UCpVaNiwMSZMePW+X3uVCF8AeOGFkZg/fy5W\nrFiKF198CQpFlRjUExFVOPcZU+FqxykFoRDg3ncAdPeYrKFhw8a4ePECjEYDGjVqjOPHjyIiohGS\nk5MRFFQLAPDLLz/Bz88fb701HampqZg4cQzWrPkaX3+9DnPmzEdgYBC2bduCOnXqoG3bdujSpRua\nNGlm9TxKpQoLFizFzJlTcezYUSxYsASzZk3DwYP7ER5eF/36DUSnTl1w4MC/+PLLNXjvvY9Rv34E\nJk9+A15eXlizZiWWLfscLi4umDbtTRw9ehgtWkTeVxdVmfANCAjA4MFDsH79Ouza9TN69Ogtd0lE\nRFSCyMgHceLEMeTm5uCpp57Gnj270bLleURENLS0OX78KI4cOYSjRw8DAHJycmAwGNC9ey9MmfI6\nevXqg+7de5W4iblxY2n2Oz8/f8tMSL6+ftDpMlGjhh/WrPkM69evhcFggEZj/TiXLl1EUlIiJk8e\nDwDQ6TKRmJiIFi3u77VXmfAFgFGjxmL9+nWIiVnK8CUiKiXdjNn3HKWWRUCAJ3SlOLdzq1atsW7d\nauTkZKNfvwHYtm0rjh07ggcffMjSRqVS4/nnRxT5To+KehE9evRBXNxOvPLKWCxeXPwOt0ql0uay\nKIrYuPEr+PvXxLRps3D69EksWjTf6r5qtbSped68Rfd8PWVRpbbNNmvWHB06dMJvv+3GqVMn5S6H\niIhKEBoahqSkJGRm6qDVusPPzw9798ZZhW+TJs3w++97AAApKcmIiVkMs9mMmJjF8Pf3x9Chz6FZ\ns+ZITEyEIAgwmUxlqiEtLRW1a9cBAOzZsxtGoxEAoFAoYDKZEBoajsuXL1l2vlq5Mga3bt2879de\nqvA9e/YsunfvjnXr1gEAbty4gaioKAwbNgwTJ06sVBMbcK5fIiLn4evri6CgIABS0N64cQM1awZa\n1j/2WHe4uWkxZswIvPHGq2jRIhIKhQJarTtGj34REyeOhSAIaNAgAi1btsL8+R9j//5/Sv38vXv3\nxYYNX+LVV19G06bNcOfOHWzbtgWRkQ9i6tT/4vr1a5g48TX85z8TMXbsCKSlpcLfP+C+X/c9pxTU\n6/UYPXo0wsPD0bBhQzz33HN466230KlTJ/Tp0wfz5s1DUFAQhg0bVuxjVMTUUsVNWWUymdCu3YO4\nceM6Dh06xbl+7xOnBnMM9rNjsJ8dg/0sua8pBV1cXLBixQrUrFnTctvff/+Nbt26AQC6du2Kv/76\nyw5l2odSqcSoUWOQk5ODtWs/l7scIiKiIu4ZviqVqsjeX1lZWXBxcQEA+Pn54datWxVTXTk988xz\nnOuXiIgqrfve27k0U/n5+mqhUinv2a6sihvSBwR44qWXRuLTTz9FXNxPePbZZ+3+3NVJSZtOyH7Y\nz47BfnYM9nPJyhW+Wq0W2dnZ0Gg0SEpKstokbUtKir5cxZXkXr8pDBv2IhYsWIC5cz9Bjx79IQiC\n3WuoDvjbjWOwnx2D/ewY7GfJff3ma0v79u2xY8cOAMDPP/+Mjh07lq+yChQWFo5evR7H4cOH8M8/\nf8tdDhERkcU9w/f48eOIiorC999/jy+++AJRUVEYP348Nm/ejGHDhiE1NRUDBw50RK1lxrl+iYio\nMrrnoUb24MhDjQoTRRHdunXEyZPH8e+/RxESEmr3Oqo6bj5yDPazY7CfHcPe/RwXtwtdunSz2+M5\nit03OzsLzvVLROTcbty4jp07d8hdht1V6fAFgEGDnoK/fwDWrVvDuX6JiCqRYcMGw2QywWg0okeP\nTjh9Wjot8OTJ45GYeAMAMG/eRzh8+CA+/3wFVq6MwaxZ0zFu3EvYv/8fTJ36huWx+vaVRsaXLl3E\nK6+MwcSJY/HWW68hI6Nybumo8uGbP9dvenoaNmz4Su5yiIgqpRqtm9m8aAptNfQcN8pmG8/o4ZY2\nmrWrgfDwUj1n/pSC586dsUwpaDabraYUfOaZKERGPogXXxwFADAaDViy5LNip42dP/9jvP76FCxY\nsBRt2jyCTZs2lqs/KlqVD19AmutXOlPXUpjNZrnLISIiFEwpeOzYETz11NM4efIELlywnlLwbvnT\nAxbn5MkT+Oij2Rg/Pho7dvxomRChsqlSUwoWp2bNmnjyyf/D119/ybl+iYhsSD5w/J5tMpasuGeb\n7Kjh8Jw8AbDTlIJ3U6vVAFDk3A35sxFpNBosXBhT6c/tUC1GvoA01y8AxMRwtiMiosqgNFMK5k/t\ndzd3d3fcuXMbAHD+/Dno9dLJnOrXb4B9+/4EAOzcuaNMMxw5UrUJ3+bNW+DRRztyrl8iokrkXlMK\nhoXVxZkzp/G//31idb/69SOg0bhhzJgR2LHjRwQFBQMAJk78D9au/Rzjx0fjxx9jS9yELacqfZzv\n3bZv34YXXngGzz33AubNW2j3mqoiHhfpGOxnx2A/Owb7WVJtj/O9W8+evREWFo5vvvkat2/flrsc\nIiKqpqpV+HKuXyIiqgyqVfgC0ly/Hh6enOuXiIhkU+3C19PTC88+G4WkpERs2fK93OUQEVE1VO3C\nFwBGjhwNQRCwfPkSOGB/MyIiIivVMnzDw+uid+++OHz4EP79t3IeA0ZERFVXtQxfgHP9EhHJ6ccf\nt2LRovl2eSydLhP//LMPALB27WocP3603I+VmJiIkyfvfbav+1Vtw7ddu0fRrFkLxMb+gISEeLnL\nISKicjpz5rQlfKOihqNZsxblfqyDB//FqVMn7FVasarFuZ1tyZ/r95VXxmLVqhV4551ZcpdERFSt\n3LhxDf/5zyu4eTMJQ4YMQ79+A6zWf/fdRuzc+RMEQYGOHbvgmWeew9mzp/HJJx9BrVbDxcUFM2d+\ngHnz5kCv1yEkJBTHjx9Fly7dkJaWisOHDyI1NRWXLl1EdPRY7Ny5A5cvX8L06bPRtGkzLFw4DydP\nnkBubi4GDhyMDh06Y9Wq5VCpVAgMDELt2iH49NM5EAQBWq0WU6bMgKdn8SfOKItqG76ANNfvu+9O\nx7p1a/Daa/+Fh4eH3CURETncjBmu2LrVfnGgUAB9+7pixoycEtslJMRj1aovodNlYvjwYejb9wnL\nhAjXr19DXNwuLFmyEgAwduxIdO3aHT/+uBWDBj2F3r374sCBf5GcfAfDhkXh4sULGDDgSatNzgkJ\n8Viy5DNs3boZ69atxqpVX2L79q3YuXMH6tdvgKCgYEyYMBk5OdkYMmQg+vcfiD59+sHHxwcdOnTG\nxIlj8frrUxASEopNm77Bpk0b8cILI+3SR9U6fPPn+v344w+wceN6jBgxSu6SiIiqjRYtIqFSqeDt\n7QN3d3ekpaXBx8cHAHDq1AlcvZqACRNGAwD0eh0SE6+jQ4fOmDv3QyQkxKNbtx4ICwvHiRPHbD5+\no0ZNIAgC/Pz8Ua9eAyiVSvj6+kGnOwJXV1ekp6dhzJgRUKlUSE1NKXL//OkJAcBgMKBx4yZ2e+3V\nOnwBaa7fBQs+wYoVSzF8+MhiJ2gmIqqqZszIuecotSykczuX5vGsp/0rPAugSqVGu3aP4o033i5y\nr88++wJ//rkXs2fPwPjxk4p9dKVSaXNZFEUcOnQABw/ux6JF0mbmHj06Frl/RU5PWO2TJn+u3wsX\nzuPXX3+RuxwiomrjxImjMJlMSElJQVZWFry8vC3rGjZsjIMHDyA7OxuiKGL+/LnIycnGd99tQHp6\nGnr27IOnnx6Gs2dPQxAEm9MOliQtLRU1awZCpVLh99/3wGQyw2AwWE1hWJHTE1b7kS8gzfX79ddf\nIiZmCbp37yV3OURE1UJoaDimTXsT164lIDp6nNUIMygoCEOGPIOXXx4FhUKBTp26wNVVg9q1QzBt\n2pvw8PCAWq3GlCnvIDU1BcuWLURAQM1SP/dDD7XFl1+uwfjx0ejYsTPat++AuXM/QPfuPTF79gz4\n+Phi4sT/YM6c9/Dll2vg4uKKGTNm2+21V6spBUsyaFBf/PHHXvz2299o1Kix3R7X2XFqMMdgPzsG\n+9kx2M8STilYCtHR0kk3VqxYKnMlRERU1TF88xSe6/fOnTtyl0NERFUYwzdP/ly/2dnZnOuXiIgq\nFMO3EM71S0REjsDwLSR/rt/ExBvYunWz3OUQEVEVxfC9S/5cvzExiznXLxERVQiG71041y8RUcUr\nzZSCu3fvdFA1jsfwtYFz/RIRyW/dujVyl1BhGL42cK5fIqKKlz+l4PPPP43Y2B+s1n311Rc4f/4s\npkx5HQcP7scbb0zC+PHROH36FPr27WZpN3XqGzh4cD/0eh2mTn0DEyeOxfjx0Th//pyjX06ZMHxt\nyJ/r12w2Y9WqFXKXQ0RU4Vq3drd5WblSbWkzbpzGZpvoaI2lzdq1aoSHl+45ExLi8eGH87BwYQxW\nroyx2s9m2LDn4eHhgfff/xgAcOHCecybt6jYMxBu3Lgebdu2x4IFS/Haa29i0aJPy94JDsTwLcbA\ngYPh7x+AdevWIDMzU+5yiIiqHFtTChanfv0GcHFxKXb9sWNHsXnzdxg/PhqffPIhdLrK/b3NiRWK\nodFoMHz4SMyd+yHn+iWiKu/AAd092yxZkn3PNlFRBkyerMGtW6V51uKnFLybWq22ebvRaMxbr8Kr\nr76OZs1alOaJZceRbwleeGEkXFxcsGLFUpjNZrnLISKqUkqaUhAAzGbbh3sKgoDs7GxkZ2fj7Nkz\nAIAmTZrht9/iAACXLl3E11+vq9Da7xfDtwSBgYEYNOgpzvVLRFQB8qcUnDRpbJEpBQEgIqIhRo16\nvsj9Bg58CtHRL+D992eiYUPpN+Cnnnoa164lYNy4l/DRR7MRGfmgQ15DeXFKwXs4duwIunXriM6d\nu+Kbb3649x2qGE4N5hjsZ8dgPzsG+1nCKQXvQ/PmLdG+fQfs2bMbp0+fkrscIiKqAhi+pcC5fomI\nyJ4YvqXQq1cfhIZyrl8iIrIPhm8pSHP9jkZ2djbWrVstdzlEROTkGL6lNGxYFDw8PLFy5XIYDAa5\nyyEiIifG8C0lT08vDBv2HOf6JSKi+8bwLQPO9UtERPbA8C2DunUfQK9ej+PQoYPYv59z/RIRUfmU\nK3x1Oh3Gjx+PqKgoDB06FHv37rV3XZVWwVy/POyIiIjKp1zh+/3336Nu3bpYu3YtFixYgPfee8/e\ndVVa7dt3QNOmzREb+wOuXk2QuxwiInJC5QpfX19fpKamAgDS09Ph6+tr16IqM0EQMHr0OJhMJs71\nS0RE5VLuczuPHDkS8fHxSE9PR0xMDCIjI4ttazSaoFIpy11kZZOdnY2wsDDk5ubi6tWrcHd3l7sk\nIiJyIuWaz/eHH35AcHAwVq5cidOnT2PKlCnYtGlTse1TUvTlLrA4cp+4+/nnR2Du3A+xePFyvPji\nS7LVUdHk7ufqgv3sGOxnx2A/S+w+scLBgwfRoUMHAECjRo1w8+ZNmEym8lXnpDjXLxERlVe5wjcs\nLAxHjhwBAFy7dg3u7u5QKqvOZuXSyJ/r9/z5c9i9e6fc5RARkRMpV/g+/fTTuHbtGp577jm89tpr\nmDFjhp3Lcg7R0WMBADExS2SuhIiInEm5fvN1d3fHggUL7F2L08mf6zcu7lecPn0KjRo1lrskIiJy\nAjzD1X0qmOt3mcyVEBGRs2D43qeCuX7XIzmZc/0SEdG9MXzvU+G5fteuXS13OURE5AQYvnbAuX6J\niKgsGL52wLl+iYioLBi+dsK5fomIqLQYvnbCuX6JiKi0GL52xLl+iYioNBi+dsS5fomIqDQYvnbE\nuX6JiKg0GL52NnDgYPj7B2Dt2tXQ6XRyl0NERJUQw9fONBoNhg8fibS0VGzcuF7ucoiIqBJi+FYA\nzvVLREQlYfhWgMDAQAwcOBjnz59DXNwuucshIqJKhuFbQTjXLxERFYfhW0FatIhEu3aPYvfuXThz\n5rTc5RARUSXC8K1AnOuXiIhsYfhWoN69H0doaBjn+iUiIisM3wqkVCrx0kujkZWVhXXr1shdDhER\nVRIM3wo2bFgU3N09ONcvERFZMHwrmJeXN4YNew43blxHbOwPcpdDRESVAMPXAV56aQwEQcAHH8xC\nZmam3OUQEZHMGL4OULfuA3j55Ym4fPkSpk79r9zlEBGRzBi+DvLmm1PRokUkvvpqLbZu3Sx3OURE\nJCOGr4O4uLhg2bKVcHNzw+TJr+Datatyl0RERDJh+DpQ/foNMGvWh0hLS8X48aNhMpnkLomIiGTA\n8HWwqKjh6NOnH/74Yy8WL/6f3OUQEZEMGL4OJggC5s1biMDAIHz44SwcPnxQ7pKIiMjBGL4y8PPz\nw6JFMTAajRgzZiR0Op3cJRERkQMxfGXSuXNXjBv3Ci5evIBp096UuxwiInIghq+M3nprGpo1a4F1\n69Zg61ae/YqIqLpg+MrI1dXVcvjRa69NwPXr1+QuiYiIHIDhK7OIiIaYOfN9pKamYsKEMTCbzXKX\nREREFYzhWwm88MII9O79OPbu3YMlSxbKXQ4REVUwhm8lIB1+tAg1awbigw/exdGjh+UuiYiIKhDD\nt5Lw9/fHwoXLYDAYePgREVEVx/CtRLp27YbRo1/G+fPnMH36FLnLISKiCsLwrWSmTp2Bpk2bY+3a\nz/Hjj7Fyl0NERBWA4VvJ5B9+pNFoMHnyeCQm3pC7JCIisjOGbyXUsGEjzJjxHpKTk/Hyy6N5+BER\nURXD8K2kXnzxJfTs2Rt798Zh2bLFcpdDRER2xPCtpARBwKefLkZAQE28994MHDt2RO6SiIjIThi+\nlVhAQAAWLlxqOfxIr9fLXRIREdkBw7eSe+yxHoiOHotz587inXfelrscIiKyA4avE5g6dSYaN26K\nNWtW4qeffpS7HCIiuk/lDt8tW7bgiSeewJNPPom4uDg7lkR302g0WLZsJVxdXfHqqy8jKSlR7pKI\niOg+lCt8U1JSsHjxYnz11VdYtmwZdu3aZe+66C6NGzfBjBmzcefOHc5+RETk5MoVvn/99RfatWsH\nDw8P1KxZE7NmzbJ3XWTDiBHR6N69J+LifsXy5UvkLoeIiMqpXOF79epVZGdnY8yYMRg2bBj++usv\ne9dFNgiCgAULlsLfPwCzZ8/AsWNH5S6JiIjKQRBFUSzrnZYvX46DBw9i0aJFuH79Op5//nns3r0b\ngiDYbG80mqBSKe+7WJJs374djz/+OBo3boz9+/dDq9XKXRIREZWBqjx38vPzQ6tWraBSqRAaGgp3\nd3ckJyfDz8/PZvuUFPsfnxoQ4IlbtzLs/rjO4KGHOuCll0bjs89iMH78RHz00bwKe67q3M+OxH52\nDPazY7CfJQEBnsWuK9dm5w4dOmDfvn0wm81ISUmBXq+Hr69vuQuksps+fRYaN26Czz//DD//vF3u\ncoiIqAzKFb6BgYHo1asXhgwZglGjRmHq1KlQKHjIsCNpNBosXSodfjRx4jgkJSXJXRIREZVSuX7z\nLauK2PzAzRqSFSuW4u23/4uuXbth/frv7P6fIPazY7CfHYP97BjsZ4ndNztT5fHSS2Pw2GPdsXv3\nLnz22TK5yyEiolJg+Dq5gsOP/PHuu9Nx4sRxuUsiIqJ7YPhWAYGBgZg/fzFyc3MxduxIZGVlyV0S\nERGVgOFbRfTs2QcjRozC6dOn8O670+Quh4iISsDwrULeeWc2GjZshJUrl2Pnzh1yl0NERMVg+FYh\nbm5uWLZsFVxcXPDKK+Nw8+ZNuUsiIiIbGL5VTNOmzTBt2kzcvn0LEyeOhQOOJCMiojJi+FZBo0aN\nRZcuj2HXrl+wcmWM3OUQEdFdGL5VkEKhwMKFy+Dn54eZM6fh1KmTcpdERESFMHyrqMDAIHz66WLk\n5ORgzJgRyM7OlrskIiLKw/Ctwnr3fhzDh4/EqVMnMWvWdLnLISKiPAzfKm7GjPcQEdEQK1Ysw65d\nP8tdDhERgeFb5Wm1WixdutJy+NGtW7fkLomIqNpj+FYDzZu3wNtvz8CtWzcxadI4Hn5ERCQzhm81\nMXr0OHTu3BW//LIDq1atkLscIqJqjeFbTeQfflSjRg3MnDkVp0+fkrskIqJqi+FbjQQF1cKnny5G\ndnY2xowZycOPiIhkwvCtZvr06Yvnnx+BkyeP4733ZspdDhFRtcTwrYZmznwP9es3QEzMYuzevUvu\ncoiIqh2GbzXk7u6OZctWQq1WY8KEMbh9+7bcJRERVSsM32qqRYtIvPXWdNy8mYRXX32Zhx8RETkQ\nw7caGzduAjp27IIdO7Zj9eqVcpdDRFRtMHyrMYVCgUWLlsHX1xfvvDMFZ8+ekbskIqJqgeFbzdWq\nFYx58xYhOzsbo0ePQE5OjtwlERFVeQxfQt++/REVNRwnThzD+++/K3c5RERVHsOXAADvvvsB6tWr\nj6VLFyIu7le5yyEiqtIYvgSg4PAjlUqFCRPG4M6dO3KXRERUZTF8yaJly1Z4881pSEpKxKuvjufh\nR0REFYThS1bGj5+IDh064aeftuGLLz6XuxwioiqJ4UtWpMOPYuDj44Pp09/C6dOn5S6JiKjKYfhS\nEcHBtfHJJwuRlZWF/v374+LFC3KXRERUpTB8yab+/Qdg8uTXcf78eTz+eDf8/fc+uUsiIqoyGL5U\nrDffnIbly5cjLS0Ngwf3w/fffyt3SUREVQLDl0o0atQorF//HVxdNRg9egQ+/fRj7gVNRHSfGL50\nT126PIbY2J9Rp04IPvhgFiZNehm5ublyl0VE5LQYvlQqjRs3wfbtuxAZ2Qrr16/DM88MRlpaqtxl\nERE5JYYvlVpgYBC+//5H9O7dF3v37kHfvj1w5cplucsiInI6DF8qE3d3d3z++TqMGTMeZ8+eQZ8+\nj+HAgX/lLouIyKkwfKnMlEol3n33fXz44SdITk7GoEF9sXXrD3KXRUTkNBi+VG4jRozCunUboFSq\nMHJkFBYtWsA9oYmISoHhS/ele/de2LLlJ9SqFYx3352G//xnEgwGg9xlERFVagxfum/Nm7fATz/9\nimbNWmDt2s/x7LP/h/T0NLnLIiKqtBi+ZBe1agVjy5af0KNHL8TF/Yr+/Xvh6tUEucsiIqqUGL5k\nNx4eHlizZj1GjozGqVMn0bv3Yzhy5JDcZRERVToMX7IrlUqFDz6Yi9mzP8StWzcxYEAfbN++Te6y\niIgqFYYvVYjo6HFYvforAMDw4cMQE7OYe0ITEeW5r/DNzs5G9+7dsWnTJnvVQ1VInz598cMP2xEQ\nUBPTpr2FKVNeh9FolLssIiLZ3Vf4Ll26FN7e3vaqhaqgli1b4aeffkXjxk2xcuVyvPDCM8jMzJS7\nLCIiWZU7fC9cuIDz58+jS5cudiyHqqI6dUIQG7sDXbo8hl9+2YEnnuiNGzeuy10WEZFsBLGcP8RF\nR0dj2rRp2Lx5M2rXro0nn3xblRT0AAAgAElEQVSy2LZGowkqlbLcRVLVYDAYMH78eCxfvhy1a9dG\nbGwsIiMj5S6LiMjhVOW50+bNmxEZGYmQkJBStU9J0ZfnaUoUEOCJW7cy7P64ZM3e/Txr1seoVSsU\nM2dOxaOPdsBnn61G9+697Pb4zoqfZ8dgPzsG+1kSEOBZ7LpyhW9cXBwSEhIQFxeHxMREuLi4ICgo\nCO3bty93kVQ9CIKAl19+BaGhYXj55VF47rmn8f77H2PEiFFyl0ZE5DDlCt/58+dblhcuXIjatWsz\neKlM+vcfgODgYERFDcWbb76GS5cuYsaM2VAq+fMEEVV9PM6XZNO6dRts374LERENEROzGC+++Bx0\nOp3cZRERVbj7Dt8JEyaUuLMVUUnCwsKxbdsv6NixM376aRsGDnwcSUmJcpdFRFShOPIl2Xl7+2D9\n+u/wzDPP4ciRQ+jTpxtOnTopd1lERBWG4UuVgouLC+bPX4wpU6bj6tUE9OvXE7t375K7LCKiCsHw\npUpDEARMmvQfxMSsQm5uDoYNewpr166WuywiIrtj+FKlM2jQU/j2263w9vbGa6+9glmz3oHZbJa7\nLCIiu2H4UqXUtu0j+PHHXahXrz4WLvwUo0YNR1ZWltxlERHZBcOXKq0HHqiHH3/ciXbtHsXWrZvx\n5JN9cevWLbnLIiK6bwxfqtR8fWtg48bNeOqpp3HgwH706dMNZ8+ekbssIqL7wvClSs/V1RWLFy/H\n66+/hfj4y+jbtwd+//03ucsiIio3hi85BUEQ8Prrb2HRohjo9ToMGTIQX3/9pdxlERGVC8OXnMqQ\nIc/gm29+gIeHB155ZSw+/HAWyjkrJhGRbBi+5HTat++AH3/chbCwcMyb9zHGjh2J7OxsucsiIio1\nhi85pfr1G2D79l/Rpk1bbNr0Lf7v/wbgzp07cpdFRFQqDF9yWv7+/vjuu60YOPBJ/P33X3j88W64\nePG83GUREd0Tw5ecmkajwbJlqzBp0n9w6dJF9OnTjXtCE1Glx/Alp6dQKDBlynTMn78YGRkZePLJ\nfnjhhWE4ffqU3KUREdnE8KUqY9iwKGzZ8hPatGmL7dtj0aVLO0yYMAbx8VfkLo2IyArDl6qUhx56\nGLGxP2Pdug1o1KgJNmz4Cu3aPYgpU17HzZs35S6PiAgAw5eqIEEQ0LNnH/z66+9YuvQzBAfXxmef\nxeDhh1vigw/eRVpaqtwlElE1x/ClKkuhUGDw4CH4888DmDPnU3h6euLTT+eiTZsWWLhwPvR6vdwl\nElE1xfClKk+tVmP48JH4++/DmDp1JgBg1qzpaNs2EqtXr4TBYJC5QiKqbgTRAefmu3Urw+6PGdCm\nOUzmoqXrx72C7JHRAADPcaOg/vuvIm0MrR9CxvLVAADN2tXQzp9r8zmS/zoIuLhAee4svIc+abNN\nxryFMHTuCgDw6dUFitu3i7TJHvIM9P99GwDg/s7bcI39oUgbU2gY0r7fBgBw2b4NHlP/a/P5Urfu\ngDm4NoTUFPh262izjW7KdOQMHgIA8Hr2/6CysddvbtfuyJw7HwDgtnA+3FZ/VqSNqNVCdfoUbt3K\ngGr/P/AaPcLm86WvWgtjy1YAAN+2kRCMxiJtsqLHImv0ywAAj0kvw2XvniJtjM1bIn21dL5m16+/\nhPvHH9h8vuQ9+wAPDyguX4LP4P4222TOmYfcbj0BAD79ekJx47plndlsRkZ6OlZl6fG60Yjw8Lr4\nrmEjtDxxHBAEq8cx1wpGauzPAACXXT/D443JNp8v9butMIfXBTIzUaPzIzbb6F5/CzlDnwUAeA1/\nFqpjRyzrlAoBJrOI3I6dkTl/MQDALWYx3JYvLfI4okqFlL8PAwBURw7Ba0SUzedLj1kF40MPAwB8\nOz4MwcZIP2v4S8iaMAkA4PGfSXDZvbNIG2Ojxkj/8hsAgOt3G+H+/rs2ny9l116IPr5QXL8Gn/69\nbLbJnP0Rcvv0BQB4D+oLpY2d4XL6DYBu5nsAAO1H70GzcX2RNmZ/f6TuiAMAqPfshufkCTafL+3r\nTTA1iAByc1Gj3YOWfi5MP+k/yI4aDgDwjB4O9YH9RR7H0LYdMpasAABoVi6Hdsn/bD5f8oHjAADl\nyRPwjnraZpuMRTEwtHsUAODb9VEI6WlF2mQ/+zz0k98AALhPeR2uO7YXaWOqVx9pGzcDAFy2bobH\njKk2ny9l+68Qa9aEcPMmfPs8ZrNN5ozZyO0/EADgPWQglBeKHi+f06sPdO9/DADQzpsDzZdfFGkj\nenkjZfcfCAjwROqWn+A5frTN50tbuwGmJk0BADVaN7PZRs7vcnsJCPAsdp3Krs9E5AQUCgW8fXzw\nwpBncBoivvjic+y4fAmBajW8vX3g5uYmd4lEVMU578g3wLNCHpesVYd+vnLlMj7++AN8883XEEUR\nDz/8CKZOnYFHHmnvsBqqQz9XBuxnx2A/S0oa+fI3X6r2wsLCsWhRDPbs2Yc+ffrhn3/24YknemPo\n0CdxrNCmYSIie2H4EuVp1Kgx1qz5Ctu370KHDp3w66870a1bR0RHD+c5o4nIrhi+RHdp3boNvvtu\nKzZu3IzIyFbYvHkTHn20DV577RVcv35N7vKIqApg+BLZIAgCunR5DDt2xGHlyrV44IF6WLt2Ndq2\njcQ777yN5GROX0hE5cfwJSqBIAjo338A9uzZhwULliAgoCaWLl2Ihx5qgblzP0RmJncqIXJqZjOg\n0wGZmQ59Wu7tTCViP1vLycnBmjUrMX/+XNy+fRv+/v6YOPE1vPDCSGg0mnI/LvvZMdjPjnHf/SyK\ngMEAITsLQlYWoNdDyM6GkKWHkJUFITsL0GdJfxe6HdlZEPRZBW2yCrXRF2pT+PbsbOkpFQqkffUt\nDI91t1MvlLy3M8OXSsR+ti0zMwMxMUuwZMlCZGSko3btOnj99bcwZMgzUKnKfvg8+9kx2M92IoqA\nTgdBp4Ogy4Sg00Ghy4SgywR0OngrTMhISraEoJCVBdwVggXhWNBG0OuB/DA1mexbsloN0U0L0c0N\n0GggarUQNRrLbaKPL3RvvwNznRC7PSfDl8qN/Vyy5OQ7+N//PsWqVcuRnZ2NBg0i8Oab09Cv3xMQ\n7jpbVknYz45RLftZFKWQKxSUQmZmwbLuruXM/GUdhMyMguXC6/Q6CHaIDlEQADc3KfzcCsIQbm4Q\nNW4QtXnrNG557Qq3KRScGqkd7g5UjRugzbsux3+K7xfDl8qN/Vw6169fwyeffISvvloLk8mEyMhW\nmDLlHXTu3LVUIcx+dgyn6Oe8sFSkp0FIS4OQnlZ8GBYXmlZ/Z0Iwm8tfjiBAdPeA6O4uXTw8Cy17\nWK9zl9Z5BvkhzaQoCNG84LQEZn6AuroWOaVrVcLwpXJjP5fNhQvn8NFH72Hz5k0AgA4dOuHtt99B\n69ZtSrwf+9kxHNLPZjOEjPS84Ey3CtGC5XTp70LLVuttnB+9tEStuyUMzR6eQKHl/Nvh7gGzR35o\neuSFqDtErUeRUIVWW+aA5OdZwvClcmM/l8+xY0fw/vvvYteuXwAAvXv3xVtvTUPjxk1stmc/O0ap\n+jk3F0JaGhQZxQVlqnVopqVBUWhZyMwo8yZZUauF2csbopcXRC9vmL3zl32k68JBefeIs/Cy1h1Q\nyH8QCz/PEoYvlRv7+f7s2/cnZs+egX/+2QdBEPDUU0/jjTemICws3Kod+9nOjEYIKSlQJN+BIvkO\nhDvStWeuDvrrN4sP0Yx0aeegMhAFAaKnF0Rvbyk4ve5a9vKG6C2FqNnLO29d/rJ0O9TqCuoIefDz\nLGH4Urmxn++fKIrYuXMH3nvvXZw8eRxqtRpRUcPx6qtvIDAwEAD7uURmM4TUFCiSky0hWjhQFcl3\nICTfgeJO3nVyMhRpqaV+eFGthujtDbMlQPOCMi9ArZel8LSMUr29IXp4VorRZmXCz7OE4Uvlxn62\nH7PZjM2bv8OHH87G5cuXoNVqER09Di+//Arq1w+pHv0sitIo0yo4k6G4Yx2iluWUZAgpKaXaYUhU\nqWCu4QfRz0+6riFdm/1qWJa9wmsjRVRLI868wIWbW5Xe6UcO/N6QMHyp3NjP9mcwGPDVV2sxd+6H\nSEpKhI+PDyZNmoQePfqhQYMIucsrPVGU9qS1MRq1CtSUZGn9HWm5NDsTiQoFxBo1pPDMD1K/QoFa\no4YUsr41LIErenrdM0T5eXYM9rOE4Uvlxn6uOHq9HitXLsfChfOQmiptJm3YsBH69n0C/foNQNOm\nzcp0rLDdmM0Qbt+GMukGFEmJUCQlQZFYaDnphnR96yaE3NzSPaSvr1WIWo9M85Z9a0D0ywtTb58K\n2ZTLz7NjsJ8lDF8qN/ZzxcvISMcff/yKr77agLi4XcjOO91deHhd9Os3AP36PYFWrVrffxAbjVDc\nvpUXoolQJCbeFah5t926WeLZhUQXF5gDg2AOCIDZz79oiNbwsx61+vjIcoIDW/h5dgz2s4ThS+XG\nfnaM/H7OzMzErl0/IzZ2C375ZQf0eh0AIDi4Nvr27Y9+/Qbg4YcfgVKpLLizwQDFrZt5o9OkvBC9\nAcXNJOuQvX2rxN9ORY0G5ppBMAcGwhxUC6bAQClk8y9BtWAODIToW8NpfyPl59kx2M8Shi+VG/vZ\nMWz1c1ZqKvbH/oCD27bg8l9/wFuvRy0AdTUaNK/hh1C1Gp46HRR3bpd4XKmo1cJcMxCmoFp5IRpk\nFbJSuAZKm3qdNFRLi59nx2A/S0oK38qxLYiousnKgjIhHsqEK1DExwOpt+B58Yr1ZuDkZIQCePLu\n+2ZnA9evIQPARYUCBv8AuNVvAP9mzSEE15HC1TJaDZIOhanioUrkbBi+RBUhOxvKawlQXLmSF7Lx\nUMRflpbj46G4dbPIXfInJDR7ecMcGAhj0+Yw1wwsGK3mBaohoCb+jr+CH3b9jG3btuLGjevArZvw\nOHYUPXr0RL/QAXisVWu4u7s79jUTUalxszOViP1cjNxcKK4m5IXpFSjyrqWQvQJlUqLNu4lqNcy1\n68AUGg5TaCjMIaEwhYbBq2kE7rh6wRwYJJ1Lt5TMZjMOHtyP2NgtiI3dgvj4ywAANzc3dO3aHX37\n9kevXn3g5eVtj1ft9Ph5dgz2s6RCfvOdM2cODhw4AKPRiNGjR6Nnz57FtmX4Oq9q288GAxTXrlqP\nWuPzlhPiobhx3ebvrKJSCXPtEJhCpVA1h4TCFBIKU2g4zKGhUrgW3lkqjz36WRRFHD9+DNu2/YDY\n2C04e/YMAECtVqNTpy7o128AevfuCz8/v/t6HmdWbT/PDsZ+ltg9fPft24eVK1dixYoVSElJwaBB\ngxAXF1dse4av86qy/Ww0QnHjuvWoNX85IR6K69ds7hksKhTSyDUktFCwhsEcGibdViu4XIfVVEQ/\nnz17BrGxUhAfP34UAKBUKtG+fQf07fsEHn+8H4KCatn1OSu7Kvt5rmTYzxK7h6/JZEJOTg60Wi1M\nJhPat2+PP//80/rwh0IYvs7LafvZZIIi8YYUpFcuW0asls3E167aPJZVFASYawVbNgebQkKlYM1f\nDq5dISfBr+h+vnz5ErZt24rY2B9w4MC/AABBEPDQQw+jX78B6Nu3P0JDwyrs+SsLp/08Oxn2s6RC\nDzXasGED9u/fj48//rjYNhXxJrRp4wmzjZHJuHG5GDnSkLeswd9/F/0PQevWJixfLp3IYO1aNebP\nd7H5HH/9pYOLC3DunAJDh7rZbDNvXjY6d5a+xHv10uL27aJ7lQ4ZYsB//yudCeidd1wRG1t0ZBQa\nasb330uzqWzfrsLUqa42n2/rVj2Cg0WkpgLdutneoWbKlBwMHiydwu/ZZ91w+nTRMwV17WrE3Lk5\nAICFC12wenXRQNFqRZw+rcStWxnYv1+B0aNt98GqVVlo2VJ6L9q2dYetswdGR+di9GjpfZk0yRV7\n9xbtg+bNTVi9Wnpfvv5ahY8/tt0He/bo4OEBXL4sYPBADWA0QDAYAEPetdGIJRiLx02xAIAO2Iur\nqFPwAAoloFLh/8L3YUbfP2EOCcP033rhu31hgEpptWdwrVpmxMZK78uuXUq88YYGtnz3nR7h4SIy\nM4HOnW2/L6+/noOhQ6XOGT5cg2PHCj6bCoUCZrMZHTsaMX++9L7ExKixfHnRz6ZKBfz9t3T875Ej\nCowYYft9iYnJwkMPSe9Lx45a6PXS6zIaTcjK0iMrS4/c3PkQxTkAAD+/b2A0doebmxvUhf6D0aiR\nGV9+mZX3OlV4/33b78uuXTr4+ADXrwvo39/279azZ+egTx+pDwYNckN8fNHPZr9+RsycKfXBRx+5\nYOPGop9Nf38RO3boAQB79igxebLt9+Xrr7PQoIEZublAu3buln4ubNKkXERFSZ/N6GgNDhwo+p3R\ntq0JS5ZIn82VK9VYssT2d8aBA9L7cvKkAlFRtt+XRYuy0a6d9J3RtasW6elFvzOefdaAyZOl74wp\nU1yxY0fRfy/16pmxcaP0vmzdqsKMGbbfl+3b9ahZU8TNmwL69LH9vsyYkYP+/aX3ZcgQN1y4UPR9\n6dXLiPffl96XefNc8OWXRd8XLy8Ru3frERDgiS1b9Bg/3vb7snZtFpo0kd6H1q1t/3uR87vcXirs\nUKOdO3fi22+/xapVq0ps5+urhUple1R8PxQ2Tj/n6alBQID0hms0ts9Q5+qqQECAOq998WexCwjw\nhIsLcOdO8W18fLQICJCWVSrb7dzdXREQIP3D0Gptt1GrFZY3ytu7+Ofz8/NAQEDxzwUAXl5ulppc\nXGy3c3NzQUCA9EH18LDdJn9DRkCAJ3x9i38+X193y/MplYCt8zh4eNzP+yICJhOQKwVswKyp8Dh/\nGBnH9VCkfFv0gRQKCA3qA62GAuHhwDf1gMy8sywpVZZwVT05CO4fDJJqugkoDhV9qLK+L25uxbfx\n9Cx4X1xdi7ZTKBTQaEr3vuTXVJb3Jb+di4sCLi7e8Pb2xvPPT0OdOvWwadMm/PxzMkQxFWlpqVCr\n1dBqtXB3d4eLi9ryfF5exT+fv7/0OcnJKb6Nt3dBH6jVtttptQV94F7M9LQqVUEf+JRwJsoaNaQ+\nyM0taHP390bh7wxb7wsAaDSl/86Qntd+3xnFfaZcXBSlfF+kz6bZbL/vjNK9L9p7vi9ASf9e5Psu\nd4Ryj3z37t2LBQsW4LPPPoOPj0+JbbnZ2Xk5tJ+NRigvX4Ly7Bkoz52BKu9aee4cFLpMq6aiQgFz\naBiMEQ1himgkXTeIgKlBBEQn3LO3MnyeU1NTsGPHdmzbtgW7d+9CTo40yqlb9wHLaS4jIx+U53zT\ndlIZ+rk6qMh+FkXAaATyNnbBYBBgMEj/wTIagdxcoci6/PWF/zYYhEL3kf7TMXSoAV5e9qvV7pud\nMzIyMGzYMKxevbpUe04yfJ1XhfRzVhaU589Bde4MlGfPQHXurBSyFy8UOVG/6OICU70GBeEa0RDG\nBg1hqldf+u9wFVHZPs+ZmRnYuVM6zeXOnT9bTnNZu3Ydy2ku27RpW+x+HpVVZevnyshkks7jkpsr\nBVlODpCTI11b3ybdLt0GZGdLyzk5AtRqV6Sm5tgIwMLhd3cAFl1XEKrW7SrKypVZls3v9mD38N2w\nYQMWLlyIunXrWm776KOPEBwcbLM9w9d53U8/C2mpBeFaaDSriL9S5DAds4cnTBERMDXIC9eIhtLf\noeGV5qT8Fakyf56zsrIQF/crYmN/wI4d25GengYA8PX1RYsWkWjZshVatoxEixaRCA0Nq9Qj48rc\nz2YzkJUlhVh2trSclSXkBaFgFXZSEBaEoK1l27cVF6QFj2syyfP+qdUi1GppBKpSiXnXgIuLdHvB\npaBd/nL+xcVFzLtP0XVqtfVjFrST1nl5AW3amGwdCVhuPLczlds9+1kUobiZJIXr2TPSaDY/bG2c\naMLsH5A3ipXC1digIUwNG8EcVKtanwLRWT7Pubm5+P33PYiN3YK9e/fgypXLVut9fHzQvHkkWraM\ntARyeHjdShPIZelnUcwf0UmBqNcXBGN2tmAVjkWvC9rq9db3KQhX679zciq+jwRBhKurFDouLiI0\nGmnZ1VXMuw3QaETL+vy2rq62lotfX7OmFnq9/q4ALByu1gGoUlXNf/4MXyo3Sz+bzVDEX8nbVHy2\n0G+yZ6FISy1yP1NIKEwNCsLVmBe2om8NGV5F5eesn+fU1BQcPXoER44cxtGj0uXSpYtWbby9fdCi\nRUs0b97SEsrh4Q/Y3GGyLERRGqnpdAJ0uruvbd8mii5ISTFYQq9wgNoKUbPZ/ong5iaFnkYjXbu5\niXBzk/7Ovy58e36wFQ486/C71/qCx3BUyDnr59neGL5UJkJKMlRHDkN15BA8Lp6F4dgJqM6fhZA3\nz2w+UaWCqe4DeTs8ReSNZhvCWK+BtDsklVpV+jynpaXi2LGjeYF8CEePHsGFCxcAuAPwAOABrTYI\nDzzQEqGhTREcHIGaNetBqw1EVpaiSGjq9cWHqz02kapURcPPVgjmL+cHp5tb6f8uuL1qjvDuVpU+\nz/eDsxpRsYSMdKiOHoHq8CGojhyE+tBBKO/alKjSamGMaGS9w1NEQ5jqPlAhJ5ygysFgANLTBaSn\nAxkZAtLSBMvf6ekCMjKKjjCloNRCpwuGTtfHchtgnTh6PXD8uHQpLa1WhLu7CHd3oEYNs2XZ+rrk\n22rXdkdWViY0GunxNJpqsUsBVUL82FUnej1Ux49BfeQgVIcOQnXkEJTnz1nt/GSuUQO5XbvB0OpB\nGFs+CO9Oj+C2WwkHk1KlZDYDmZn54WkdmmlpUnCmp8OynB+sGRkFt+WflKOslEoRHh5S2Pn6iqhT\n5+5QlJbV6hxkZNzAnTuXkZh4Htevn8aNG+cgiukAMgFkws1NRNOmddGqVSO0bNkSLVu2Qv36Dcq9\nl3VAAHDrVoVv7CO6J4ZvVZWbC9XJ49KI9vBBqA8fgvLMKatTKpo9vWB4tCOMkQ/CENkKxsgHYQ4J\ntd4uFuAJcPORQ4mi9Ptj4dC0DkncFZgC0tJQaFkKUVEsW3hKe3yK8PQEgoLM8PIS8y4otFxwm6en\nCA8PEVqtdai6uJRl02qtvEs7AIBOp8Px48dw9Oghy+/IBw/GYf/+Xy330Gq1aNq0uWWHrpYtW6FB\ngwioOIQlJ8LffKsCoxHKM6ehPnIob0R7EKqTJ6yOmRXd3GBs3jJvRCsFremBevcc0bKf7092NpCc\nLFhd7tyRrlNSCv7OzFQhOdlsGZ0aDGULTkGQQlMKTxHe3sWHZuG/vb0L7uPmVjl/j9Tr9Thx4hiO\nHj2MI0eky9mzp2Eq9B9JNzc3NGnSLG+HrlZo0SISDRs2KhLI/Dw7BvtZwh2uqhKzGcoL56E6fNAy\nolUdPwohK8vSRHRxgbFps7wRrRS2poiGlWa2HWeVk1NykOYvF/67tJtu3dwALy9zMSPNoiHq7Y1C\nISsWeyrKqiorKwsnTx63jI6PHDmMM2dOwVjoxOIajQZNmzbL28taCuSOHR9Gamp2CY9M9sDvDQnD\n11mJIhRXLhca0R6C6shhKDILXreoVMLUqIlls7ExshWMjZtK2/7soKr2c04OLCPPwkFaeDR6d5Dq\ndKULUq1WRI0a0sXXV4SfX/F/+/lJt4WEVM1+dqTs7GycOnXCKpBPnz4Jg8FgaaNUKlG7dh3UqRNi\nuYSEhFqua9euA1dX2xMUUOlV1e+NsuLezs5AFKG4cd0SsurD0rUiJaWgiSDA1CACuS1bFWw+btZC\nGjZVczodkJQkIClJgdu3bQdp4dFpZmbpglSjkQLygQfMRYLz7kt+kPLtkIdGo0GrVq3RqlVry205\nOTk4ffqkZXP1hQtncOnSZfz11x8obtwRGBiUF8YhCAkJsyzXqSOFtIeHh6NeElVhHPnKRLh1C+rD\nB6x2iFLcumnVxhReN29E21oa0bZoCdGj+P9JVQS5+zkzsyBUExMFJCUJSExU5N0mWNZlZNw7TF1d\nC8KzpCAt3EZrewY2u5O7n6uL/H7Ozc3FtWtXcfVqAq5eTUBCQrxlOT4+HtevX7XahF2Yr6/vXaEs\nBbMU1qHw8fGtNGf0kgs/zxKOfOWWlQX1/n+gOrgf6vxDfK5dtWpiql0HOY/3LxjRtoyssmeDEkUp\nVPNDND9Uk5IKQjV/3b029fr7mxESYkZgoIigIBGBgWYEBBQfpNX8O5HyuLi4oG7dB1C37gM215tM\nJiQlJSIhIQFXr8YjISHesnz1agLOnTuDo0cP27yvu7tHoVCWRs/5f4eEhCIgoOZ9n92LnB/DtyIY\njVAdPgiXvXug3rsH6n//hpA3PRsgnd84p0cvy2+0hpYPQqxZU8aC7UMUgfR0WI1MExMVuHlTsBq1\n3rxZ8o5IgiCFZt26+aEqXdesWRCwQUEiAgJEe/20TWRFqVQiOLg2goNro23bR4qsF0URd+7cQULC\nlbyRc0Ewx8dL16dPn7L52K6urggOrm0VyvnBHBISilq1gnnYVDXAd9geRBHKUyfhsjdOCts//7Da\nKcrYtDlyO3aG4eFHYGz1IMzBtZ1qCCaKQGoqrDb9Wo9SC/7Ozi7+dSkUIvz9RdSrZ7aEaGCgaBWw\ngYFSqPLEWVSZCYIAf39/+Pv7W/3GXFh6elpeKCcgIeGKZVkaSSfgt99227yfUqlErVrBVjuF1ahR\nA76+NQpd+6FGjRrw8vLmKNpJ8TffclJcvpQ3so2Dy++/QXH7tmWdse4DMHTsgtxOnWFo3xGiv79s\ndZaG0Qhcvy4gPl6BhAQBV64oLMtJSSrcuCGWOOOKQiGNSvM3/dasab0ZWLqWgpf/obdN7s9zdVGZ\n+lmv1+Patat3/d58xbKcmHgDZrO5xMdQKBTw9fWFr68Uyn5+fpbl/KDOX65RI3+dL1wqeJNRZepn\nOfE3XzsQkpLg8ru0Gdnl99+gjL9iWWcKDEL2U08jt1MXGDp0grlOiIyVFmU2SzstXbkiBWp+sMbH\nSyF77Zpg8wT1CoWIWolNSOEAAAkTSURBVLWAJk3Md41SrUet/v6iXefAJKoOtFotGjSIQIMGETbX\nGwwGXL9+DdevX0NycjJSUpILXd+x+jslJRmXLl20OvFISTw8PG2MpouGduEwd3d3r/Y7ktkTw7cY\nQloq1H/+IY1s9+6B6sxpyzqztw9yHu8vbUru1AWm+g1k3YwsisDt24JVoMbHFyxfvSogN9d2fUFB\nZjz4oBmhoWaEhZkREiIiNFT6OzhYRHCwJ27d0jv4FRGRWq1GWFg4wsLCS9XebDYjPT3NKpALL9+5\nU/T2s2dPI6vQCXpK4uLiYmMUbTu869ULQW6uAHd3d2i17uU+F3dVxvDNl5UF9T/7LJuSVUcOQ8jb\n5CO6uSG3y2PI7dgFhk6dpWNrHfxhSk0FEhIUVqPXwiPY4nZg8vc3o2lTsyVQ88M1LMyM2rWlWV2I\nyPkpFAr4+PjCx8cXQL1S30+v1xcJ6sIj7Ltvv379Ok6dOlmm2jQaDdzd3eHu7gGtVpsXyh5511q4\nuxddzg/u/PvZauvMv3dX3/A1GqE6dMB6j+S8cyGLKhWMDz1sGdkaHnxImp26AmVmSuEaHy9YQjZ/\nOT5egfR02+Hq5SWdACI/WMPCCpZDQszg+QCIqCRarRZarRa1a9cp9X2MRiNSUlKKDe2srAwkJ6dC\np9PlXTKh1+uh0+mQlJQInU6H3ELnnr+/2u8O6sIhbyvIrf8TkL/s6+sLb2+f+66ptKpP+JrN1nsk\n//WnZY9kURBgbNYChg6dYOjUGblt28PeqZWTgyKbhfODNT5ewJ07tv8Hp9VKI9VHHpHCVBrBFmwa\n9va2a5lERPekUqkQEBCAgIAAm+tLs8OVwWCAXq+zBHTBcmbe33rLsvV66zDPb5OamorMzIxS/+59\nN4VCga+++haPPda9XPcvq6obvqJYsEfy73uK7pFcrz5yBg+R9kh+tCPEGn52eVq9Hjh/XoEzZxQ4\nezb/WonLlwWYzUVHry4uIkJCRDRvbiwSrKGh0vGu3MeBiKoatVoNb28fu442RVFEbm6uzXC2DvOi\n6wEgIqKh3Wq5lyoVvoqkRGlUm79HckK8ZZ2pVjCyhzyD3A6dYOjYGeYybGKxJTMTOHs2P2CVlqBN\nSBCKzKNao4YZbdqYUK+eFKjSCFbaRFyzplitZqMhIqoogiDA1dUVrq6uqGGnAVVFcerwFdJSof7j\nd2lT8u+/We+R7OuLnH4DpLDt1AWmevXLtUdyaipw5owS584VjGbPnlXg2rWiiVmzphkdOpgQEWFG\nRIQZDRtK1/7+FX4oNRERORHnC19RhNvC+cCOWPgdOFCwR7JWi9zHukt7JHfsJO2RXIYh5e3bQqHN\nxAWbjG/eLPoYwcFmdOlitISrdDHB19dur5KIiKow5wtfnQ7un3wIGI0wPPwIDB07S5cHH7rnHLai\nCNy8Kdz1e6x0sbXDU2ioGd27G/NGsdKItkEDM7y8KurFERFRdeB84evhgTv7j8M/LBBpetunXhNF\n6XSJ1qNY6XfZtDTrTc+CICI8XESbNgarzcX165vh7u6IF0RERNWN84UvADEgAHB3hzkzAwkJgtVe\nxfnLd09Fp1RKx8N26GC22lxcr56Zk58TEZFDOV34iiIwfbor/v0XOHXKA1lZ1iGrVouoX99cZKen\nBx4wc/o5IiKqFJwufPV6YMMGNbKzYQnZ/IBt2NCE8HDOnENERJWb08WUuztw4kQmAgM9kZzME/4T\nEZHzccrTO6jVDp/XgIiIyG6cMnyJiIicGcOXiIjIwRi+REREDsbwJSIicjCGLxERkYMxfImIiByM\n4UtERORgDF8iIiIHY/gSERE5GMOXiIjIwRi+REREDiaIoijKXQQREVF1wpEvERGRgzF8iYiIHIzh\nS0RE5GAMXyIiIgdj+BIRETkYw5eIiMjBnC5833//fTz99NMYOnQojh49Knc5VdqcOXPw9NNPY/Dg\nwfj555/lLqdKy87ORvfu3bFp0ya5S6mytmzZgieeeAJPPvkk4uLi5C6nStLpdBg/fjyioqIwdOhQ\n7N27V+6SKi2V3AWUxT///IMrV65gw4YNuHDhAqZMmYINGzbIXVaVtG/fPpw7dw4bNmxASkoKBg0a\nhJ49e8pdVpW1dOlSeHv/f3v398r6H8Bx/LkzubBxzDJaIblRSigXWHJBLlz7kRa3cqVc0FKUq7lS\nKAp/gLZwI0pZuZgr5UJRXGExy8evxgU6d6fOt9x8a3vbp9fjbrt61i5ee38+n7bfpjNsy7IslpaW\niEajpNNpFhYW6OjoMJ1lO5ubm1RXVzM+Ps7d3R3Dw8Ps7u6azvqRcmp84/E4nZ2dANTU1PD09MTr\n6ytut9twmf00NzdTX18PQFFREW9vb3x+fuJ0Og2X2c/l5SUXFxcagwyKx+O0tLTgdrtxu93Mzs6a\nTrIlj8fD+fk5AM/Pz3g8HsNFP1dOXXZOpVL/fJglJSXc398bLLIvp9NJQUEBAJFIhPb2dg1vhoTD\nYSYnJ01n2Nr19TXv7++MjIwwODhIPB43nWRLPT09JBIJurq6CAaDTExMmE76sXLq5Ptf+mXMzNvf\n3ycSibC+vm46xZa2trZoaGigoqLCdIrtPT4+sri4SCKRYGhoiIODAxwOh+ksW9ne3sbv97O2tsbZ\n2RmhUEjPMXwjp8bX5/ORSqX+vk4mk5SWlhossrfDw0OWl5dZXV2lsLDQdI4txWIxrq6uiMVi3N7e\nkp+fT3l5Oa2trabTbMXr9dLY2EheXh6VlZW4XC4eHh7wer2m02zl+PiYQCAAQG1tLclkUrervpFT\nl53b2trY29sD4PT0FJ/Pp/u9GfLy8sLc3BwrKysUFxebzrGt+fl5otEoGxsb9Pb2Mjo6quHNgEAg\nwNHREV9fX1iWRTqd1v3IDKiqquLk5ASAm5sbXC6XhvcbOXXybWpqoq6ujoGBARwOB9PT06aTbGtn\nZwfLshgbG/v7Xjgcxu/3G6wS+X/Kysro7u6mr68PgKmpKX79yqmzR07o7+8nFAoRDAb5+PhgZmbG\ndNKPpb8UFBERyTJ99RMREckyja+IiEiWaXxFRESyTOMrIiKSZRpfERGRLNP4ioiIZJnGV0REJMs0\nviIiIln2BzQKNGAGnBgwAAAAAElFTkSuQmCC\n",
"text/plain": [
- "\u003cmatplotlib.figure.Figure at 0xc1dc310\u003e"
+ "\u003cmatplotlib.figure.Figure at 0x7f7a18df6b50\u003e"
]
},
"metadata": {
@@ -668,13 +549,10 @@
" w_at_step = []\n",
" b_at_step = []\n",
" for step_num in range(num_training_steps):\n",
- " loss, gradients_and_variables = value_and_gradients_fn(inputs, labels, wb)\n",
- " loss_at_step.append(np.asscalar(loss.numpy()))\n",
- " \n",
- " optimizer.apply_gradients(gradients_and_variables)\n",
+ " loss_at_step.append(run_step(inputs, labels))\n",
" w, b = wb.variables\n",
- " w_at_step.append(np.asscalar(w.read_value().numpy()))\n",
- " b_at_step.append(np.asscalar(b.read_value().numpy()))\n",
+ " w_at_step.append(np.asscalar(w.numpy()))\n",
+ " b_at_step.append(np.asscalar(b.numpy()))\n",
"\n",
" print(w_at_step)\n",
" t = range(0, num_training_steps)\n",
@@ -688,171 +566,12 @@
"\n",
"train_model(inputs, labels, wb, optimizer, num_training_steps)"
]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "colab_type": "text",
- "id": "UNurY9VJ-hpH"
- },
- "source": [
- "## Other Ways to Compute Gradients\n",
- "\n",
- "Using our loss function as an example (`loss_fn()`), there are several other ways we could compute gradients:\n",
- "\n",
- "1. `tfe.implicit_gradients()`\n",
- "1. `tfe.gradients_function()`\n",
- "1. `tfe.implicit_value_and_gradients()`\n",
- "1. `tfe.value_and_gradients_function()`\n",
- "\n",
- "Each of these functions does the following:\n",
- "* Wraps a function.\n",
- "* Returns a function with the same input signature as the wrapped function.\n",
- "\n",
- "They differ only in what information they return.\n",
- "\n",
- "### Gradients-only functions\n",
- "\n",
- "The following two functions return a function that returns only the variables' gradients:\n",
- "\n",
- "1. `tfe.gradients_function()`: Returns the partial derivatives of the function `f()` with respect to the parameters of `f()`.\n",
- "1. `tfe.implicit_gradients()`: Returns the partial derivatives of the function `f()` with respect to the trainable parameters (`tf.Variable`) used by `f()`.\n",
- "\n",
- "In our example above, the `tf.layers.Dense` object encapsulates the trainable parameters.\n",
- "\n",
- "### Value and gradients functions\n",
- "\n",
- "The following two functions are identical to their counterparts above, except that they also return the value of the wrapped function.\n",
- "\n",
- "1. `tfe.implicit_value_and_gradients()`\n",
- "1. `tfe.value_and_gradients_function()`\n",
- "\n",
- "### Gradient demos\n",
- "\n",
- "In the demos below, we show examples for the `implicit_*` functions, since our existing loss function works seamlessly with these versions. (The other versions require that your parameters are tensors and tensors only; in our example, we're using a `Dense` layer.)\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 13,
- "metadata": {
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- },
- "height": 85,
- "output_extras": [
- {
- "item_id": 1
- }
- ]
- },
- "colab_type": "code",
- "executionInfo": {
- "elapsed": 100,
- "status": "ok",
- "timestamp": 1505502831671,
- "user": {
- "displayName": "",
- "photoUrl": "",
- "userId": ""
- },
- "user_tz": 240
- },
- "id": "aEoCftnfAIH5",
- "outputId": "72f1c1dc-a574-463f-f860-c4e5f48fcdaa"
- },
- "outputs": [
- {
- "data": {
- "text/plain": [
- "[(\u003ctf.Tensor: id=673, shape=(1, 1), dtype=float32, numpy=array([[-0.26846504]], dtype=float32)\u003e,\n",
- " \u003ctf.Variable 'dense/kernel:0' shape=(1, 1) dtype=float32\u003e),\n",
- " (\u003ctf.Tensor: id=671, shape=(1,), dtype=float32, numpy=array([-0.32890949], dtype=float32)\u003e,\n",
- " \u003ctf.Variable 'dense/bias:0' shape=(1,) dtype=float32\u003e)]"
- ]
- },
- "execution_count": 13,
- "metadata": {
- "tags": []
- },
- "output_type": "execute_result"
- }
- ],
- "source": [
- "# tfe.implicit_gradients() demo\n",
- "gradients_fn = tfe.implicit_gradients(loss_fn)\n",
- "\n",
- "# Returns only gradients and variables:\n",
- "gradients_fn(inputs, labels, wb)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 14,
- "metadata": {
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- },
- "height": 102,
- "output_extras": [
- {
- "item_id": 1
- }
- ]
- },
- "colab_type": "code",
- "executionInfo": {
- "elapsed": 88,
- "status": "ok",
- "timestamp": 1505502831785,
- "user": {
- "displayName": "",
- "photoUrl": "",
- "userId": ""
- },
- "user_tz": 240
- },
- "id": "bbgCUdCzAVhH",
- "outputId": "152aa9b6-9e42-4b7e-848a-9423c0b1929c"
- },
- "outputs": [
- {
- "data": {
- "text/plain": [
- "(\u003ctf.Tensor: id=688, shape=(), dtype=float32, numpy=1.0623235\u003e,\n",
- " [(\u003ctf.Tensor: id=720, shape=(1, 1), dtype=float32, numpy=array([[-0.26846504]], dtype=float32)\u003e,\n",
- " \u003ctf.Variable 'dense/kernel:0' shape=(1, 1) dtype=float32\u003e),\n",
- " (\u003ctf.Tensor: id=718, shape=(1,), dtype=float32, numpy=array([-0.32890949], dtype=float32)\u003e,\n",
- " \u003ctf.Variable 'dense/bias:0' shape=(1,) dtype=float32\u003e)])"
- ]
- },
- "execution_count": 14,
- "metadata": {
- "tags": []
- },
- "output_type": "execute_result"
- }
- ],
- "source": [
- "# tfe.implicit_value_and_gradients() demo\n",
- "value_gradients_fn = tfe.implicit_value_and_gradients(loss_fn)\n",
- "\n",
- "# Returns the value returned by the function passed in, gradients, and variables:\n",
- "value_gradients_fn(inputs, labels, wb)"
- ]
}
],
"metadata": {
"colab": {
+ "collapsed_sections": [],
"default_view": {},
- "last_runtime": {
- "build_target": "",
- "kind": "local"
- },
"name": "Eager Execution Tutorial: Working with Gradients",
"provenance": [],
"version": "0.3.2",
diff --git a/tensorflow/contrib/eager/python/examples/notebooks/3_datasets.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/3_datasets.ipynb
index 0088da5c4b..bfcc7feb07 100644
--- a/tensorflow/contrib/eager/python/examples/notebooks/3_datasets.ipynb
+++ b/tensorflow/contrib/eager/python/examples/notebooks/3_datasets.ipynb
@@ -16,7 +16,9 @@
"\n",
"We recommend using the `Dataset`s API for building performant, complex input pipelines from simple, re-usable pieces that will feed your model's training or evaluation loops.\n",
"\n",
- "If you're familiar with TensorFlow graphs, the API for constructing the `Dataset` object remains exactly the same when eager execution is enabled, but the process of iterating over elements of the dataset is slightly different. You will use a Pythonic `Iterator()` class instead of using `make_one_shot_iterator()` and `get_next()`. As a result, the discussion on iterators in the [Programmer's Guide](https://www.tensorflow.org/programmers_guide/datasets) is not relevant when eager execution is enabled."
+ "If you're familiar with TensorFlow graphs, the API for constructing the `Dataset` object remains exactly the same when eager execution is enabled, but the process of iterating over elements of the dataset is slightly simpler.\n",
+ "You can use Python iteration over the `tf.data.Dataset` object and do not need to explicitly create an `tf.data.Iterator` object.\n",
+ "As a result, the discussion on iterators in the [Programmer's Guide](https://www.tensorflow.org/programmers_guide/datasets) is not relevant when eager execution is enabled."
]
},
{
@@ -48,11 +50,8 @@
"# Import TensorFlow.\n",
"import tensorflow as tf\n",
"\n",
- "# Import TensorFlow eager execution support (subject to future changes).\n",
- "import tensorflow.contrib.eager as tfe\n",
- "\n",
"# Enable eager execution\n",
- "tfe.enable_eager_execution()"
+ "tf.enable_eager_execution()"
]
},
{
@@ -137,32 +136,27 @@
"source": [
"# Step 3: Iterate\n",
"\n",
- "Use `tfe.Iterator` on the `Dataset` object to get a Python iterator over the contents of the dataset.\n",
- "\n",
- "If you're familiar with the use of `Dataset`s in TensorFlow graphs, note that this process of iteration is different. Here there are no calls to `Dataset.make_one_shot_iterator()` and no `get_next()` calls."
+ "When eager execution is enabled `Dataset` objects support iteration.\n",
+ "If you're familiar with the use of `Dataset`s in TensorFlow graphs, note that there is no need for calls to `Dataset.make_one_shot_iterator()` or `get_next()` calls."
]
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": 0,
"metadata": {
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
},
- "height": 153,
- "output_extras": [
- {
- "item_id": 1
- }
- ]
+ "base_uri": "https://localhost:8080/",
+ "height": 153
},
"colab_type": "code",
"executionInfo": {
- "elapsed": 201,
+ "elapsed": 388,
"status": "ok",
- "timestamp": 1505952405928,
+ "timestamp": 1525154629129,
"user": {
"displayName": "",
"photoUrl": "",
@@ -171,7 +165,7 @@
"user_tz": 420
},
"id": "lCUWzso6mbqR",
- "outputId": "ec027d30-96c6-4ea4-9ee1-ef74ec1ae29a"
+ "outputId": "8e4b0298-d27d-4ac7-e26a-ef94af0594ec"
},
"outputs": [
{
@@ -179,9 +173,9 @@
"output_type": "stream",
"text": [
"Elements of ds_tensors:\n",
- "tf.Tensor([4 9], shape=(2,), dtype=int32)\n",
+ "tf.Tensor([1 9], shape=(2,), dtype=int32)\n",
"tf.Tensor([16 25], shape=(2,), dtype=int32)\n",
- "tf.Tensor([36 1], shape=(2,), dtype=int32)\n",
+ "tf.Tensor([ 4 36], shape=(2,), dtype=int32)\n",
"\n",
"Elements in ds_file:\n",
"tf.Tensor(['Line 1' 'Line 2'], shape=(2,), dtype=string)\n",
@@ -191,22 +185,19 @@
],
"source": [
"print('Elements of ds_tensors:')\n",
- "for x in tfe.Iterator(ds_tensors):\n",
+ "for x in ds_tensors:\n",
" print(x)\n",
"\n",
"print('\\nElements in ds_file:')\n",
- "for x in tfe.Iterator(ds_file):\n",
+ "for x in ds_file:\n",
" print(x)"
]
}
],
"metadata": {
"colab": {
+ "collapsed_sections": [],
"default_view": {},
- "last_runtime": {
- "build_target": "",
- "kind": "local"
- },
"name": "Eager Execution Tutorial: Importing Data",
"provenance": [],
"version": "0.3.2",
diff --git a/tensorflow/contrib/estimator/__init__.py b/tensorflow/contrib/estimator/__init__.py
index be20d1b777..f66d844660 100644
--- a/tensorflow/contrib/estimator/__init__.py
+++ b/tensorflow/contrib/estimator/__init__.py
@@ -38,6 +38,7 @@ _allowed_symbols = [
'binary_classification_head',
'clip_gradients_by_norm',
'forward_features',
+ 'logistic_regression_head',
'multi_class_head',
'multi_head',
'multi_label_head',
diff --git a/tensorflow/contrib/estimator/python/estimator/head.py b/tensorflow/contrib/estimator/python/estimator/head.py
index 3dcf0374c8..2a6d17e81b 100644
--- a/tensorflow/contrib/estimator/python/estimator/head.py
+++ b/tensorflow/contrib/estimator/python/estimator/head.py
@@ -205,8 +205,9 @@ def regression_head(weight_column=None,
shape `[D0, D1, ... DN, label_dimension]`.
Also supports custom `inverse_link_fn`, also known as 'mean function'.
- `inverse_link_fn` takes `logits` as argument and returns predicted values.
- This function is the inverse of the link function defined in
+ `inverse_link_fn` is only used in `PREDICT` mode. It takes `logits` as
+ argument and returns predicted values. This function is the inverse of the
+ link function defined in
https://en.wikipedia.org/wiki/Generalized_linear_model#Link_function
Namely, for poisson regression, set `inverse_link_fn=tf.exp`.
@@ -305,6 +306,70 @@ def poisson_regression_head(
name=name)
+def logistic_regression_head(
+ weight_column=None,
+ loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE,
+ name=None):
+ """Creates a `_Head` for logistic regression.
+
+ Uses `sigmoid_cross_entropy_with_logits` loss, which is the same as
+ `binary_classification_head`. The differences compared to
+ `binary_classification_head` are:
+
+ * Does not support `label_vocabulary`. Instead, labels must be float in the
+ range [0, 1].
+ * Does not calculate some metrics that do not make sense, such as AUC.
+ * In `PREDICT` mode, only returns logits and predictions
+ (`=tf.sigmoid(logits)`), whereas `binary_classification_head` also returns
+ probabilities, classes, and class_ids.
+ * Export output defaults to `RegressionOutput`, whereas
+ `binary_classification_head` defaults to `PredictOutput`.
+
+ The head expects `logits` with shape `[D0, D1, ... DN, 1]`.
+ In many applications, the shape is `[batch_size, 1]`.
+
+ The `labels` shape must match `logits`, namely
+ `[D0, D1, ... DN]` or `[D0, D1, ... DN, 1]`.
+
+ If `weight_column` is specified, weights must be of shape
+ `[D0, D1, ... DN]` or `[D0, D1, ... DN, 1]`.
+
+ This is implemented as a generalized linear model, see
+ https://en.wikipedia.org/wiki/Generalized_linear_model.
+
+ Args:
+ weight_column: A string or a `_NumericColumn` created by
+ `tf.feature_column.numeric_column` defining feature column representing
+ weights. It is used to down weight or boost examples during training. It
+ will be multiplied by the loss of the example.
+ loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to
+ reduce training loss over batch and label dimension. Defaults to
+ `SUM_OVER_BATCH_SIZE`, namely weighted sum of losses divided by
+ `batch size * label_dimension`. See `tf.losses.Reduction`.
+ name: name of the head. If provided, summary and metrics keys will be
+ suffixed by `"/" + name`. Also used as `name_scope` when creating ops.
+
+ Returns:
+ An instance of `_Head` for logistic regression.
+
+ Raises:
+ ValueError: If `loss_reduction` is invalid.
+ """
+ def _logistic_loss(labels, logits):
+ labels = head_lib._assert_range( # pylint:disable=protected-access
+ labels, n_classes=2, message='Labels must be in range [0, 1]')
+ return nn.sigmoid_cross_entropy_with_logits(
+ labels=labels, logits=logits)
+ # TODO(roumposg): Rename to _regression_head, since it supports loss_fn arg.
+ return head_lib._regression_head_with_mean_squared_error_loss( # pylint:disable=protected-access
+ weight_column=weight_column,
+ label_dimension=1,
+ loss_reduction=loss_reduction,
+ loss_fn=_logistic_loss,
+ inverse_link_fn=math_ops.sigmoid,
+ name=name)
+
+
def multi_label_head(n_classes,
weight_column=None,
thresholds=None,
diff --git a/tensorflow/contrib/estimator/python/estimator/head_test.py b/tensorflow/contrib/estimator/python/estimator/head_test.py
index 98962ca427..19b86df556 100644
--- a/tensorflow/contrib/estimator/python/estimator/head_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/head_test.py
@@ -1211,5 +1211,124 @@ class PoissonRegressionHead(test.TestCase):
self.assertAllClose(logits, spec.predictions[keys.LOGITS].eval())
+class LogisticRegressionHead(test.TestCase):
+
+ def setUp(self):
+ ops.reset_default_graph()
+
+ def test_train(self):
+ head = head_lib.logistic_regression_head()
+
+ # Create estimator spec.
+ logits = np.array([[0], [-1], [1]], dtype=np.float32)
+ labels = np.array([[.4], [.6], [.8]], dtype=np.float32)
+ # Following the documentation in
+ # tf.nn.sigmoid_cross_entropy_with_logits:
+ # With x = logits, z = labels.
+ # loss = max(x, 0) - x * z + log(1 + exp(-abs(x)))
+ # loss = [0 - 0 * 0.4 + ln(1 + exp(-0)),
+ # 0 + 1 * 0.6 + ln(1 + exp(-1)),
+ # 1 - 1 * 0.8 + ln(1 + exp(-1))]
+ # = [0.6931, 0.9133, 0.5133]
+ # training_loss = (0.6931 + 0.9133 + 0.5133) / 3
+ expected_loss = 0.7066
+ atol = 0.001
+ expected_train_result = b'my_train_op'
+ def _train_op_fn(loss):
+ with ops.control_dependencies((check_ops.assert_near(
+ math_ops.to_float(expected_loss), math_ops.to_float(loss),
+ atol=atol, name='assert_loss'),)):
+ return constant_op.constant(expected_train_result)
+
+ spec = head.create_estimator_spec(
+ features={'x': np.array(((42.,),), dtype=np.int32)},
+ mode=model_fn.ModeKeys.TRAIN,
+ logits=logits,
+ labels=labels,
+ train_op_fn=_train_op_fn)
+
+ with self.test_session() as sess:
+ _initialize_variables(self, spec.scaffold)
+ loss, train_result = sess.run([spec.loss, spec.train_op])
+ self.assertAlmostEqual(expected_loss, loss, delta=atol)
+ self.assertEqual(expected_train_result, train_result)
+
+ def test_train_labels_too_large(self):
+ head = head_lib.logistic_regression_head()
+
+ # Create estimator spec.
+ logits = np.array([[0], [-1], [1]], dtype=np.float32)
+ labels = np.array([[.4], [1.2], [.8]], dtype=np.float32)
+ expected_train_result = b'my_train_op'
+ def _train_op_fn(loss):
+ del loss
+ return constant_op.constant(expected_train_result)
+
+ spec = head.create_estimator_spec(
+ features={'x': np.array(((42.,),), dtype=np.int32)},
+ mode=model_fn.ModeKeys.TRAIN,
+ logits=logits,
+ labels=labels,
+ train_op_fn=_train_op_fn)
+
+ with self.test_session() as sess:
+ _initialize_variables(self, spec.scaffold)
+ with self.assertRaisesRegexp(
+ errors.InvalidArgumentError,
+ r'\[Labels must be in range \[0, 1\]\] .* \[\[0.4\]\[1.2\]\[0.8\]\]'):
+ _ = sess.run(spec.loss)
+
+ def test_train_labels_negative(self):
+ head = head_lib.logistic_regression_head()
+
+ # Create estimator spec.
+ logits = np.array([[0], [-1], [1]], dtype=np.float32)
+ labels = np.array([[.4], [-0.2], [.8]], dtype=np.float32)
+ expected_train_result = b'my_train_op'
+ def _train_op_fn(loss):
+ del loss
+ return constant_op.constant(expected_train_result)
+
+ spec = head.create_estimator_spec(
+ features={'x': np.array(((42.,),), dtype=np.int32)},
+ mode=model_fn.ModeKeys.TRAIN,
+ logits=logits,
+ labels=labels,
+ train_op_fn=_train_op_fn)
+
+ with self.test_session() as sess:
+ _initialize_variables(self, spec.scaffold)
+ with self.assertRaisesRegexp(
+ errors.InvalidArgumentError,
+ r'\[Labels must be in range \[0, 1\]\] .* \[\[0.4\]\[-0.2\]\[0.8\]\]'
+ ):
+ _ = sess.run(spec.loss)
+
+ def test_predict(self):
+ head = head_lib.logistic_regression_head()
+
+ # Create estimator spec.
+ logits = np.array([[0], [-1], [1]], dtype=np.float32)
+ expected_predictions = 1. / (1. + np.exp(-logits))
+ spec = head.create_estimator_spec(
+ features={'x': np.array(((42.,),), dtype=np.int32)},
+ mode=model_fn.ModeKeys.PREDICT,
+ logits=logits)
+
+ # Assert spec contains expected tensors.
+ keys = prediction_keys.PredictionKeys
+ self.assertItemsEqual(
+ (keys.PREDICTIONS, keys.LOGITS), spec.predictions.keys())
+ self.assertEqual(dtypes.float32, spec.predictions[keys.PREDICTIONS].dtype)
+ self.assertEqual(dtypes.float32, spec.predictions[keys.LOGITS].dtype)
+
+ # Assert predictions.
+ with self.test_session():
+ _initialize_variables(self, spec.scaffold)
+ self.assertAllClose(
+ expected_predictions, spec.predictions[keys.PREDICTIONS].eval())
+ self.assertAllClose(logits, spec.predictions[keys.LOGITS].eval())
+
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/factorization/python/ops/factorization_ops.py b/tensorflow/contrib/factorization/python/ops/factorization_ops.py
index 811fa89bc3..5cef4068ed 100644
--- a/tensorflow/contrib/factorization/python/ops/factorization_ops.py
+++ b/tensorflow/contrib/factorization/python/ops/factorization_ops.py
@@ -107,7 +107,7 @@ class WALSModel(object):
# the prep_gramian_op for row(column) can be run.
worker_init_op = model.worker_init
- # To be run once per integration sweep before the row(column) update
+ # To be run once per iteration sweep before the row(column) update
# initialize ops can be run. Note that in the distributed training
# situations, this should only be run by the chief trainer. All other
# trainers need to block until this is done.
@@ -436,7 +436,7 @@ class WALSModel(object):
gramian: Variable storing the gramian calculated from the factors.
Returns:
- A op that updates the gramian with the calculated value from the factors.
+ An op that updates the gramian with the calculated value from the factors.
"""
partial_gramians = []
for f in factors:
diff --git a/tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features_test.py b/tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features_test.py
index 91929184a2..2ff4d41d75 100644
--- a/tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features_test.py
+++ b/tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features_test.py
@@ -31,7 +31,7 @@ from tensorflow.python.platform import googletest
def _inner_product(x, y):
- """Inner product between tensors x and y.
+ r"""Inner product between tensors x and y.
The input tensors are assumed to be in ROW representation, that is, the method
returns \\(x * y^T\\).
@@ -131,10 +131,6 @@ class RandomFourierFeatureMapperTest(TensorFlowTestCase):
mapped_dim = 5000
stddev = 5.0
- # TODO(sibyl-vie3Poto): Reduce test's running time before moving to third_party. One
- # possible way to speed the test up is to compute both the approximate and
- # the exact kernel matrix directly using matrix operations instead of
- # computing the values for each pair of points separately.
points_shape = [1, input_dim]
points = [
random_ops.random_uniform(shape=points_shape, maxval=1.0)
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
index 32c776cb38..3a5c8eb5f9 100644
--- a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
+++ b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
@@ -673,9 +673,6 @@ class KroneckerProductFB(FisherBlock):
output factors.
"""
- def __init__(self, layer_collection):
- super(KroneckerProductFB, self).__init__(layer_collection)
-
def _setup_damping(self, damping, normalization=None):
"""Makes functions that compute the damping values for both factors."""
def compute_damping():
@@ -1309,6 +1306,8 @@ class InputOutputMultiTowerMultiUse(InputOutputMultiTower):
else:
raise ValueError("Global config variable TOWER_STRATEGY must be one of "
"'concat' or 'separate'.")
+ else:
+ inputs = tuple(inputs)
# Now we perform the analogous processing for grads_list
if isinstance(grads_list[0][0], (list, tuple)):
@@ -1351,6 +1350,8 @@ class InputOutputMultiTowerMultiUse(InputOutputMultiTower):
else:
raise ValueError("Global config variable TOWER_STRATEGY must be one of "
"'concat' or 'separate'.")
+ else:
+ grads_list = tuple(tuple(grads) for grads in grads_list)
if self._num_uses is None:
raise ValueError("You must supply a value for the num_uses argument if "
diff --git a/tensorflow/contrib/layers/python/layers/embedding_ops_test.py b/tensorflow/contrib/layers/python/layers/embedding_ops_test.py
index bf25144982..dd2395f8c9 100644
--- a/tensorflow/contrib/layers/python/layers/embedding_ops_test.py
+++ b/tensorflow/contrib/layers/python/layers/embedding_ops_test.py
@@ -31,6 +31,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import init_ops
@@ -691,11 +692,12 @@ class EmbeddingLookupSparseWithDistributedAggregationTest(test.TestCase):
index += num_val
return grouped_vals
+ @test_util.enable_c_shapes
def testEmbeddingLookupSparse(self):
vocab_size = 13
batch_size = 10
param_shape = [2, 5]
- expected_lookup_result_shape = [None] + param_shape
+ expected_lookup_result_shape = param_shape
sp_ids, sp_weights, ids, weights, vals_per_batch_entry = (
self._RandomIdsAndWeights(batch_size, vocab_size))
@@ -719,7 +721,7 @@ class EmbeddingLookupSparseWithDistributedAggregationTest(test.TestCase):
None if ignore_weights else sp_weights,
combiner=combiner)
- self.assertEqual(embedding_sum.get_shape().as_list(),
+ self.assertEqual(embedding_sum.get_shape().as_list()[1:],
expected_lookup_result_shape)
tf_embedding_sum = embedding_sum.eval(feed_dict=feed_dict)
diff --git a/tensorflow/contrib/learn/python/learn/estimators/linear_test.py b/tensorflow/contrib/learn/python/learn/estimators/linear_test.py
index d3bb0fda57..0a863f0e20 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/linear_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/linear_test.py
@@ -863,6 +863,38 @@ class LinearClassifierTest(test.TestCase):
scores = classifier.evaluate(input_fn=input_fn, steps=1)
self.assertGreater(scores['accuracy'], 0.9)
+ def testSdcaOptimizerWeightedSparseFeaturesOOVWithNoOOVBuckets(self):
+ """LinearClassifier with SDCAOptimizer with OOV features (-1 IDs)."""
+
+ def input_fn():
+ return {
+ 'example_id':
+ constant_op.constant(['1', '2', '3']),
+ 'price':
+ sparse_tensor.SparseTensor(
+ values=[2., 3., 1.],
+ indices=[[0, 0], [1, 0], [2, 0]],
+ dense_shape=[3, 5]),
+ 'country':
+ sparse_tensor.SparseTensor(
+ # 'GB' is out of the vocabulary.
+ values=['IT', 'US', 'GB'],
+ indices=[[0, 0], [1, 0], [2, 0]],
+ dense_shape=[3, 5])
+ }, constant_op.constant([[1], [0], [1]])
+
+ country = feature_column_lib.sparse_column_with_keys(
+ 'country', keys=['US', 'CA', 'MK', 'IT', 'CN'])
+ country_weighted_by_price = feature_column_lib.weighted_sparse_column(
+ country, 'price')
+ sdca_optimizer = sdca_optimizer_lib.SDCAOptimizer(
+ example_id_column='example_id')
+ classifier = linear.LinearClassifier(
+ feature_columns=[country_weighted_by_price], optimizer=sdca_optimizer)
+ classifier.fit(input_fn=input_fn, steps=50)
+ scores = classifier.evaluate(input_fn=input_fn, steps=1)
+ self.assertGreater(scores['accuracy'], 0.9)
+
def testSdcaOptimizerCrossedFeatures(self):
"""Tests LinearClassifier with SDCAOptimizer and crossed features."""
diff --git a/tensorflow/contrib/linalg/BUILD b/tensorflow/contrib/linalg/BUILD
index 2e92ad6eb3..78b7970069 100644
--- a/tensorflow/contrib/linalg/BUILD
+++ b/tensorflow/contrib/linalg/BUILD
@@ -42,47 +42,3 @@ cuda_py_test(
"//tensorflow/python:platform_test",
],
)
-
-cuda_py_test(
- name = "linear_operator_block_diag_test",
- size = "medium",
- srcs = ["python/kernel_tests/linear_operator_block_diag_test.py"],
- additional_deps = [
- ":linalg_py",
- "//third_party/py/numpy",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:framework",
- "//tensorflow/python:framework_for_generated_wrappers",
- "//tensorflow/python:framework_test_lib",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:platform_test",
- ],
- shard_count = 5,
- tags = [
- "noasan",
- "optonly",
- ],
-)
-
-cuda_py_test(
- name = "linear_operator_kronecker_test",
- size = "medium",
- srcs = ["python/kernel_tests/linear_operator_kronecker_test.py"],
- additional_deps = [
- ":linalg_py",
- "//third_party/py/numpy",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:framework",
- "//tensorflow/python:framework_for_generated_wrappers",
- "//tensorflow/python:framework_test_lib",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:platform_test",
- ],
- shard_count = 8,
- tags = [
- "noasan",
- "optonly",
- ],
-)
diff --git a/tensorflow/contrib/linalg/__init__.py b/tensorflow/contrib/linalg/__init__.py
index 554854da84..a262a099cf 100644
--- a/tensorflow/contrib/linalg/__init__.py
+++ b/tensorflow/contrib/linalg/__init__.py
@@ -39,14 +39,14 @@ from __future__ import print_function
# pylint: disable=unused-import,wildcard-import,line-too-long,g-importing-member
from tensorflow.contrib.linalg.python.ops.linear_operator_addition import *
-from tensorflow.contrib.linalg.python.ops.linear_operator_block_diag import *
-from tensorflow.contrib.linalg.python.ops.linear_operator_kronecker import *
from tensorflow.python.ops.linalg.linear_operator import *
+from tensorflow.python.ops.linalg.linear_operator_block_diag import *
from tensorflow.python.ops.linalg.linear_operator_circulant import *
from tensorflow.python.ops.linalg.linear_operator_composition import *
from tensorflow.python.ops.linalg.linear_operator_diag import *
from tensorflow.python.ops.linalg.linear_operator_full_matrix import *
from tensorflow.python.ops.linalg.linear_operator_identity import *
+from tensorflow.python.ops.linalg.linear_operator_kronecker import *
from tensorflow.python.ops.linalg.linear_operator_low_rank_update import *
from tensorflow.python.ops.linalg.linear_operator_lower_triangular import *
diff --git a/tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py b/tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py
index 213c2eced5..12039ecc6f 100644
--- a/tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py
+++ b/tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py
@@ -198,6 +198,14 @@ class SDCAOptimizer(object):
example_ids = array_ops.reshape(id_tensor.indices[:, 0], [-1])
flat_ids = array_ops.reshape(id_tensor.values, [-1])
+ # Prune invalid IDs (< 0) from the flat_ids, example_ids, and
+ # weight_tensor. These can come from looking up an OOV entry in the
+ # vocabulary (default value being -1).
+ is_id_valid = math_ops.greater_equal(flat_ids, 0)
+ flat_ids = array_ops.boolean_mask(flat_ids, is_id_valid)
+ example_ids = array_ops.boolean_mask(example_ids, is_id_valid)
+ weight_tensor = array_ops.boolean_mask(weight_tensor, is_id_valid)
+
projection_length = math_ops.reduce_max(flat_ids) + 1
# project ids based on example ids so that we can dedup ids that
# occur multiple times for a single example.
diff --git a/tensorflow/contrib/lite/Makefile b/tensorflow/contrib/lite/Makefile
index 65fba52d46..e4f86e258a 100644
--- a/tensorflow/contrib/lite/Makefile
+++ b/tensorflow/contrib/lite/Makefile
@@ -1,4 +1,3 @@
-
# Find where we're running from, so we can store generated files here.
ifeq ($(origin MAKEFILE_DIR), undefined)
MAKEFILE_DIR := $(shell dirname $(realpath $(lastword $(MAKEFILE_LIST))))
@@ -69,12 +68,12 @@ LIB_NAME := libtensorflow-lite.a
LIB_PATH := $(LIBDIR)$(LIB_NAME)
# A small example program that shows how to link against the library.
-BENCHMARK_PATH := $(BINDIR)benchmark_model
+MINIMAL_PATH := $(BINDIR)minimal
-BENCHMARK_SRCS := \
-tensorflow/contrib/lite/tools/benchmark_model.cc
-BENCHMARK_OBJS := $(addprefix $(OBJDIR), \
-$(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(BENCHMARK_SRCS))))
+MINIMAL_SRCS := \
+tensorflow/contrib/lite/examples/minimal/minimal.cc
+MINIMAL_OBJS := $(addprefix $(OBJDIR), \
+$(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(MINIMAL_SRCS))))
# What sources we want to compile, must be kept in sync with the main Bazel
# build files.
@@ -100,7 +99,7 @@ $(wildcard tensorflow/contrib/lite/*/*test.cc) \
$(wildcard tensorflow/contrib/lite/*/*/*test.cc) \
$(wildcard tensorflow/contrib/lite/*/*/*/*test.cc) \
$(wildcard tensorflow/contrib/lite/kernels/test_util.cc) \
-$(BENCHMARK_SRCS)
+$(MINIMAL_SRCS)
# Filter out all the excluded files.
TF_LITE_CC_SRCS := $(filter-out $(CORE_CC_EXCLUDE_SRCS), $(CORE_CC_ALL_SRCS))
# File names of the intermediate files target compilation generates.
@@ -119,17 +118,17 @@ $(OBJDIR)%.o: %.c
$(CC) $(CCFLAGS) $(INCLUDES) -c $< -o $@
# The target that's compiled if there's no command-line arguments.
-all: $(LIB_PATH) $(BENCHMARK_PATH)
+all: $(LIB_PATH) $(MINIMAL_PATH)
# Gathers together all the objects we've compiled into a single '.a' archive.
$(LIB_PATH): $(LIB_OBJS)
@mkdir -p $(dir $@)
$(AR) $(ARFLAGS) $(LIB_PATH) $(LIB_OBJS)
-$(BENCHMARK_PATH): $(BENCHMARK_OBJS) $(LIB_PATH)
+$(MINIMAL_PATH): $(MINIMAL_OBJS) $(LIB_PATH)
@mkdir -p $(dir $@)
$(CXX) $(CXXFLAGS) $(INCLUDES) \
- -o $(BENCHMARK_PATH) $(BENCHMARK_OBJS) \
+ -o $(MINIMAL_PATH) $(MINIMAL_OBJS) \
$(LIBFLAGS) $(LIB_PATH) $(LDFLAGS) $(LIBS)
# Gets rid of all generated files.
diff --git a/tensorflow/contrib/lite/examples/android/BUILD b/tensorflow/contrib/lite/examples/android/BUILD
index 4928012997..5700007256 100644
--- a/tensorflow/contrib/lite/examples/android/BUILD
+++ b/tensorflow/contrib/lite/examples/android/BUILD
@@ -42,7 +42,6 @@ android_binary(
custom_package = "org.tensorflow.lite.demo",
inline_constants = 1,
manifest = "AndroidManifest.xml",
- manifest_merger = "android",
nocompress_extensions = [
".tflite",
],
diff --git a/tensorflow/contrib/lite/examples/minimal/minimal.cc b/tensorflow/contrib/lite/examples/minimal/minimal.cc
new file mode 100644
index 0000000000..106e3b0270
--- /dev/null
+++ b/tensorflow/contrib/lite/examples/minimal/minimal.cc
@@ -0,0 +1,71 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include <cstdio>
+
+// This is an example that is minimal to read a model
+// from disk and perform inference. There is no data being loaded
+// that is up to you to add as a user.
+//
+// NOTE: Do not add any dependencies to this that cannot be built with
+// the minimal makefile. This example must remain trivial to build with
+// the minimal build tool.
+//
+// Usage: minimal <tflite model>
+
+using namespace tflite;
+
+#define TFLITE_MINIMAL_CHECK(x) \
+ if(!(x)) { \
+ fprintf(stderr, "Error at %s:%d\n", __FILE__, __LINE__); \
+ exit(1); \
+ }
+
+
+int main(int argc, char *argv[]) {
+ if(argc != 2) {
+ fprintf(stderr, "Usage: %s <model>\n");
+ return 1;
+ }
+ const char* filename = argv[1];
+
+ // Load model
+ std::unique_ptr<tflite::FlatBufferModel> model
+ = tflite::FlatBufferModel::BuildFromFile(filename);
+ TFLITE_MINIMAL_CHECK(model != nullptr);
+
+ // Build the interpreter
+ tflite::ops::builtin::BuiltinOpResolver resolver;
+ InterpreterBuilder builder(*model.get(), resolver);
+ std::unique_ptr<Interpreter> interpreter;
+ builder(&interpreter);
+ TFLITE_MINIMAL_CHECK(interpreter != nullptr);
+
+ // Allocate tensor buffers.
+ TFLITE_MINIMAL_CHECK(interpreter->AllocateTensors() == kTfLiteOk);
+
+ // Fill input buffers
+ // TODO(user): Insert code to fill input tensors
+
+ // Run inference
+ TFLITE_MINIMAL_CHECK(interpreter->Invoke() == kTfLiteOk);
+
+ // Read output buffers
+ // TODO(user): Insert getting data out code.
+
+ return 0;
+}
diff --git a/tensorflow/contrib/lite/g3doc/rpi.md b/tensorflow/contrib/lite/g3doc/rpi.md
index 7a3a231626..ab50789307 100644
--- a/tensorflow/contrib/lite/g3doc/rpi.md
+++ b/tensorflow/contrib/lite/g3doc/rpi.md
@@ -32,7 +32,7 @@ This has been tested on Raspberry Pi 3b, Raspbian GNU/Linux 9.1 (stretch), gcc v
Log in to you RPI, install the toolchain.
```bash
-sudo apt-get instal build-essential
+sudo apt-get install build-essential
```
First, clone this TensorFlow repository. Run this at the root of the repository:
diff --git a/tensorflow/contrib/lite/interpreter.cc b/tensorflow/contrib/lite/interpreter.cc
index 9d8ea55fd1..ebb0aedc20 100644
--- a/tensorflow/contrib/lite/interpreter.cc
+++ b/tensorflow/contrib/lite/interpreter.cc
@@ -125,7 +125,8 @@ Interpreter::~Interpreter() {
for (int i = 0; i < context_.tensors_size; i++) {
TfLiteTensor* tensor = &context_.tensors[i];
- if (tensor->buffer_handle != kTfLiteNullBufferHandle) {
+ if (tensor->buffer_handle != kTfLiteNullBufferHandle &&
+ tensor->delegate->FreeBufferHandle != nullptr) {
tensor->delegate->FreeBufferHandle(tensor->delegate,
&tensor->buffer_handle);
}
diff --git a/tensorflow/contrib/lite/interpreter.h b/tensorflow/contrib/lite/interpreter.h
index 6f3433abcf..1074f64263 100644
--- a/tensorflow/contrib/lite/interpreter.h
+++ b/tensorflow/contrib/lite/interpreter.h
@@ -325,9 +325,7 @@ class Interpreter {
void SetProfiler(profiling::Profiler* profiler) { profiler_ = profiler; }
- profiling::Profiler* GetProfiler(profiling::Profiler* profiler) {
- return profiler_;
- }
+ profiling::Profiler* GetProfiler() { return profiler_; }
// The default capacity of `tensors_` vector.
static constexpr int kTensorsReservedCapacity = 128;
diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
index 2fc803715b..a43251cad1 100644
--- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
+++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
@@ -173,8 +173,8 @@ final class NativeInterpreterWrapper implements AutoCloseable {
} else {
throw new IllegalArgumentException(
String.format(
- "Input error: %s is not a valid name for any input. "
- + "The indexes of the inputs are %s",
+ "Input error: '%s' is not a valid name for any input. Names of inputs and their "
+ + "indexes are %s",
name, inputsIndexes.toString()));
}
}
@@ -195,8 +195,8 @@ final class NativeInterpreterWrapper implements AutoCloseable {
} else {
throw new IllegalArgumentException(
String.format(
- "Input error: %s is not a valid name for any output. "
- + "The indexes of the outputs are %s",
+ "Input error: '%s' is not a valid name for any output. Names of outputs and their "
+ + "indexes are %s",
name, outputsIndexes.toString()));
}
}
diff --git a/tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc
index 17f4be09c6..005dca0253 100644
--- a/tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc
+++ b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc
@@ -238,10 +238,6 @@ Java_org_tensorflow_lite_Tensor_shape(JNIEnv* env, jclass clazz, jlong handle) {
if (tensor == nullptr) return nullptr;
int num_dims = tensor->dims->size;
jintArray result = env->NewIntArray(num_dims);
- jint* dims = env->GetIntArrayElements(result, nullptr);
- for (int i = 0; i < num_dims; ++i) {
- dims[i] = static_cast<jint>(tensor->dims->data[i]);
- }
- env->ReleaseIntArrayElements(result, dims, 0);
+ env->SetIntArrayRegion(result, 0, num_dims, tensor->dims->data);
return result;
}
diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java
index 61d6c35ec8..210d943724 100644
--- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java
+++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java
@@ -195,8 +195,8 @@ public final class InterpreterTest {
assertThat(e)
.hasMessageThat()
.contains(
- "WrongInputName is not a valid name for any input. The indexes of the inputs"
- + " are {input=0}");
+ "'WrongInputName' is not a valid name for any input. Names of inputs and their "
+ + "indexes are {input=0}");
}
int index = interpreter.getInputIndex("input");
assertThat(index).isEqualTo(0);
@@ -212,8 +212,8 @@ public final class InterpreterTest {
assertThat(e)
.hasMessageThat()
.contains(
- "WrongOutputName is not a valid name for any output. The indexes of the outputs"
- + " are {MobilenetV1/Predictions/Softmax=0}");
+ "'WrongOutputName' is not a valid name for any output. Names of outputs and their"
+ + " indexes are {MobilenetV1/Predictions/Softmax=0}");
}
int index = interpreter.getOutputIndex("MobilenetV1/Predictions/Softmax");
assertThat(index).isEqualTo(0);
diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD
index 689f9bfa71..57b3136cce 100644
--- a/tensorflow/contrib/lite/kernels/BUILD
+++ b/tensorflow/contrib/lite/kernels/BUILD
@@ -31,6 +31,7 @@ cc_library(
"//tensorflow/contrib/lite:framework",
"//tensorflow/contrib/lite:schema_fbs_version",
"//tensorflow/contrib/lite:string_util",
+ "//tensorflow/contrib/lite/kernels/internal:tensor_utils",
"//tensorflow/contrib/lite/testing:util",
"//tensorflow/core:tflite_portable_logging",
"@com_google_googletest//:gtest",
@@ -672,6 +673,7 @@ tf_cc_test(
":builtin_ops",
"//tensorflow/contrib/lite:framework",
"//tensorflow/contrib/lite/kernels:test_util",
+ "//tensorflow/contrib/lite/kernels/internal:tensor_utils",
"@com_google_absl//absl/memory",
"@com_google_googletest//:gtest",
],
diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc
index a64ac42bc4..3ac0210f36 100644
--- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc
+++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc
@@ -96,15 +96,23 @@ constexpr int kBwProjectionWeightsTensor = 33; // Optional
constexpr int kBwProjectionBiasTensor = 34; // Optional
// Output tensors.
-constexpr int kFwScratchBufferTensor = 0;
-constexpr int kFwOutputStateTensor = 1;
-constexpr int kFwCellStateTensor = 2;
-constexpr int kFwOutputTensor = 3;
+constexpr int kFwOutputStateTensor = 0;
+constexpr int kFwCellStateTensor = 1;
+constexpr int kFwOutputTensor = 2;
+
+constexpr int kBwOutputStateTensor = 3;
+constexpr int kBwCellStateTensor = 4;
+constexpr int kBwOutputTensor = 5;
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+ auto* scratch_tensor_index = new int;
+ context->AddTensors(context, 2, scratch_tensor_index);
+ return scratch_tensor_index;
+}
-constexpr int kBwScratchBufferTensor = 4;
-constexpr int kBwOutputStateTensor = 5;
-constexpr int kBwCellStateTensor = 6;
-constexpr int kBwOutputTensor = 7;
+void Free(TfLiteContext* context, void* buffer) {
+ delete reinterpret_cast<int*>(buffer);
+}
// Check that input tensor dimensions matches with each other.
TfLiteStatus CheckLstmTensorDimensions(
@@ -296,9 +304,11 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
// Resize the output, state and scratch tensors based on the sizes of the input
// tensors. Also check that the size of the input tensors match each other.
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data);
+
// Check we have all the inputs and outputs we need.
TF_LITE_ENSURE_EQ(context, node->inputs->size, 35);
- TF_LITE_ENSURE_EQ(context, node->outputs->size, 8);
+ TF_LITE_ENSURE_EQ(context, node->outputs->size, 6);
// Inferring batch size, number of outputs and sequence length and
// number of cells from the input tensors.
@@ -330,12 +340,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* fw_output_state =
GetOutput(context, node, kFwOutputStateTensor);
TfLiteTensor* fw_cell_state = GetOutput(context, node, kFwCellStateTensor);
- // TODO(ghodrat): Modify this as soon as we have a finalized method for
- // scratch buffers.
- TfLiteTensor* fw_scratch_buffer =
- GetOutput(context, node, kFwScratchBufferTensor);
- // Resize the output and output_state tensors.
+ // Resize the output, output_state and cell_state tensors.
TfLiteIntArray* fw_output_size = TfLiteIntArrayCreate(3);
fw_output_size->data[0] = max_time;
fw_output_size->data[1] = n_batch;
@@ -349,13 +355,21 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, fw_output_state,
fw_output_state_size));
- // Resize the scratch buffer tensor.
TfLiteIntArray* fw_cell_size = TfLiteIntArrayCreate(2);
fw_cell_size->data[0] = n_batch;
fw_cell_size->data[1] = n_fw_cell;
TF_LITE_ENSURE_OK(
context, context->ResizeTensor(context, fw_cell_state, fw_cell_size));
+ // Create a scratch buffer tensor.
+ TfLiteIntArrayFree(node->temporaries);
+ node->temporaries = TfLiteIntArrayCreate(2);
+ node->temporaries->data[0] = *scratch_tensor_index;
+ TfLiteTensor* fw_scratch_buffer =
+ &context->tensors[node->temporaries->data[0]];
+ fw_scratch_buffer->type = input->type;
+ fw_scratch_buffer->allocation_type = kTfLiteArenaRw;
+
// Mark state tensors as persistent tensors.
fw_output_state->allocation_type = kTfLiteArenaRwPersistent;
fw_cell_state->allocation_type = kTfLiteArenaRwPersistent;
@@ -392,17 +406,13 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Check that input tensor dimensions matches with each other.
CheckInputTensorDimensions(context, node, n_input, n_bw_output, n_bw_cell);
- // Get the pointer to output, state and scratch buffer tensors.
+ // Get the pointer to output, output_state and cell_state buffer tensors.
TfLiteTensor* bw_output = GetOutput(context, node, kBwOutputTensor);
TfLiteTensor* bw_output_state =
GetOutput(context, node, kBwOutputStateTensor);
TfLiteTensor* bw_cell_state = GetOutput(context, node, kBwCellStateTensor);
- // TODO(ghodrat): Modify this as soon as we have a finalized method for
- // scratch buffers.
- TfLiteTensor* bw_scratch_buffer =
- GetOutput(context, node, kBwScratchBufferTensor);
- // Resize the output and output_state tensors.
+ // Resize the output, output_state and cell_state tensors.
TfLiteIntArray* bw_output_size = TfLiteIntArrayCreate(3);
bw_output_size->data[0] = max_time;
bw_output_size->data[1] = n_batch;
@@ -416,13 +426,19 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, bw_output_state,
bw_output_state_size));
- // Resize the scratch buffer tensor.
TfLiteIntArray* bw_cell_size = TfLiteIntArrayCreate(2);
bw_cell_size->data[0] = n_batch;
bw_cell_size->data[1] = n_bw_cell;
TF_LITE_ENSURE_OK(
context, context->ResizeTensor(context, bw_cell_state, bw_cell_size));
+ // Create a scratch buffer tensor.
+ node->temporaries->data[1] = *(scratch_tensor_index) + 1;
+ TfLiteTensor* bw_scratch_buffer =
+ &context->tensors[node->temporaries->data[1]];
+ bw_scratch_buffer->type = input->type;
+ bw_scratch_buffer->allocation_type = kTfLiteArenaRw;
+
// Mark state tensors as persistent tensors.
bw_output_state->allocation_type = kTfLiteArenaRwPersistent;
bw_cell_state->allocation_type = kTfLiteArenaRwPersistent;
@@ -553,7 +569,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
// Index the scratch buffers pointers to the global scratch buffer.
TfLiteTensor* fw_scratch_buffer =
- GetOutput(context, node, kFwScratchBufferTensor);
+ &context->tensors[node->temporaries->data[0]];
float* fw_input_gate_scratch = nullptr;
float* fw_cell_scratch = nullptr;
float* fw_forget_gate_scratch = nullptr;
@@ -624,7 +640,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
// Index the scratch buffers pointers to the global scratch buffer.
TfLiteTensor* bw_scratch_buffer =
- GetOutput(context, node, kBwScratchBufferTensor);
+ &context->tensors[node->temporaries->data[1]];
float* bw_input_gate_scratch = nullptr;
float* bw_cell_scratch = nullptr;
float* bw_forget_gate_scratch = nullptr;
@@ -691,9 +707,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace bidirectional_sequence_lstm
TfLiteRegistration* Register_BIDIRECTIONAL_SEQUENCE_LSTM() {
- static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
- bidirectional_sequence_lstm::Prepare,
- bidirectional_sequence_lstm::Eval};
+ static TfLiteRegistration r = {
+ bidirectional_sequence_lstm::Init, bidirectional_sequence_lstm::Free,
+ bidirectional_sequence_lstm::Prepare, bidirectional_sequence_lstm::Eval};
return &r;
}
diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm_test.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm_test.cc
index cca857bac0..a18e1bce34 100644
--- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm_test.cc
+++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm_test.cc
@@ -102,9 +102,6 @@ class BidirectionalLSTMOpModel : public SingleOpModel {
fw_projection_bias_ = AddNullInput();
}
- fw_scratch_buffer_ = AddOutput(TensorType_FLOAT32);
- // TODO(ghodrat): Modify these states when we have a permanent solution for
- // persistent buffer.
fw_output_state_ = AddOutput(TensorType_FLOAT32);
fw_cell_state_ = AddOutput(TensorType_FLOAT32);
fw_output_ = AddOutput(TensorType_FLOAT32);
@@ -164,9 +161,6 @@ class BidirectionalLSTMOpModel : public SingleOpModel {
bw_projection_bias_ = AddNullInput();
}
- bw_scratch_buffer_ = AddOutput(TensorType_FLOAT32);
- // TODO(ghodrat): Modify these states when we have a permanent solution for
- // persistent buffer.
bw_output_state_ = AddOutput(TensorType_FLOAT32);
bw_cell_state_ = AddOutput(TensorType_FLOAT32);
bw_output_ = AddOutput(TensorType_FLOAT32);
@@ -349,12 +343,10 @@ class BidirectionalLSTMOpModel : public SingleOpModel {
int fw_output_;
int fw_output_state_;
int fw_cell_state_;
- int fw_scratch_buffer_;
int bw_output_;
int bw_output_state_;
int bw_cell_state_;
- int bw_scratch_buffer_;
int n_batch_;
int n_input_;
diff --git a/tensorflow/contrib/lite/kernels/fully_connected.cc b/tensorflow/contrib/lite/kernels/fully_connected.cc
index 888e67966c..470b52b7bc 100644
--- a/tensorflow/contrib/lite/kernels/fully_connected.cc
+++ b/tensorflow/contrib/lite/kernels/fully_connected.cc
@@ -55,19 +55,24 @@ struct OpData {
// uint8_t these would be 0 and 255.
int32_t output_activation_min;
int32_t output_activation_max;
+ // The index of the temporary tensor where the quantized inputs are cached.
+ int input_quantized_index;
};
constexpr int kInputTensor = 0;
constexpr int kWeightsTensor = 1;
constexpr int kBiasTensor = 2;
constexpr int kOutputTensor = 0;
+constexpr int kScratchBufferTensor = 1;
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
// This is a builtin op, so we don't use the contents in 'buffer', if any.
// Instead, we allocate a new object to carry information from Prepare() to
// Eval().
gemm_support::IncrementUsageCounter(context);
- return new OpData;
+ auto* op_data = new OpData;
+ context->AddTensors(context, 1, &op_data->input_quantized_index);
+ return op_data;
}
void Free(TfLiteContext* context, void* buffer) {
@@ -121,6 +126,27 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
&data->output_activation_max);
}
+ // If we have to perform on-the-fly quantization (with quantized weights and
+ // float inputs) first we need to quantize the inputs. Allocate a temporary
+ // buffer to store the intermediate quantized values.
+ if (input->type == kTfLiteFloat32 && filter->type == kTfLiteUInt8) {
+ TfLiteIntArrayFree(node->temporaries);
+ node->temporaries = TfLiteIntArrayCreate(1);
+ node->temporaries->data[0] = data->input_quantized_index;
+
+ TfLiteTensor* input_quantized =
+ &context->tensors[node->temporaries->data[0]];
+ input_quantized->type = kTfLiteUInt8;
+ input_quantized->allocation_type = kTfLiteArenaRw;
+
+ // TODO(raziel): add this logic to ResizeTensor.
+ if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
+ TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims);
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized,
+ input_quantized_size));
+ }
+ }
+
// Resize output.
TfLiteIntArray* output_size_array = TfLiteIntArrayCreate(2);
output_size_array->data[0] = batch_size;
@@ -163,6 +189,74 @@ TfLiteStatus EvalPie(TfLiteContext* context, TfLiteNode* node,
return kTfLiteOk;
}
+TfLiteStatus EvalPieQuantized(TfLiteContext* context, TfLiteNode* node,
+ TfLiteFullyConnectedParams* params, OpData* data,
+ TfLiteTensor* input, TfLiteTensor* filter,
+ TfLiteTensor* bias, TfLiteTensor* input_quantized,
+ TfLiteTensor* output) {
+ // Check the types for this hybrid Op.
+ TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
+ TF_LITE_ENSURE_EQ(context, filter->type, kTfLiteUInt8);
+ TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteFloat32);
+ TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32);
+
+ int total_input_size = 1;
+ for (int i = 0; i < input->dims->size; i++) {
+ total_input_size *= input->dims->data[i];
+ }
+
+ const int input_size = filter->dims->data[1];
+ const int batch_size = total_input_size / filter->dims->data[1];
+ const int num_units = filter->dims->data[0];
+
+ // Output = bias if bias tensor exists.
+ if (bias) {
+ tensor_utils::VectorBatchVectorAssign(bias->data.f, num_units, batch_size,
+ output->data.f);
+ } else {
+ tensor_utils::ZeroVector(output->data.f, batch_size * num_units);
+ }
+
+ // TODO(mirkov): change std::minmax_element with a vectorized call.
+ auto minmax_element =
+ std::minmax_element(input->data.f, input->data.f + total_input_size);
+ // Save matrix multiplication computation for all zero input.
+ if (*minmax_element.first == 0.0 && *minmax_element.second == 0.0) {
+ tensor_utils::ApplyActivationToVector(output->data.f,
+ batch_size * num_units,
+ params->activation, output->data.f);
+ return kTfLiteOk;
+ }
+
+ // Quantize input from float to uint8 + quantization params (scaling factor).
+ float min, max;
+ float* scaling_factors = new float[batch_size];
+
+ // Quantize each batch independently.
+ for (int b = 0; b < batch_size; ++b) {
+ const int offset = b * input_size;
+ tensor_utils::SymmetricQuantizeFloats(
+ input->data.f + offset, input_size,
+ reinterpret_cast<int8_t*>(input_quantized->data.uint8) + offset, &min,
+ &max, &scaling_factors[b]);
+ // Incorporate scaling of the filter.
+ scaling_factors[b] *= filter->params.scale;
+ }
+
+ // Compute output += weight * quantized_input
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ reinterpret_cast<int8_t*>(filter->data.uint8), num_units, input_size,
+ reinterpret_cast<int8_t*>(input_quantized->data.uint8), scaling_factors,
+ batch_size, output->data.f, /*result_stride=*/1);
+
+ // Apply activation function to floats.
+ tensor_utils::ApplyActivationToVector(output->data.f, batch_size * num_units,
+ params->activation, output->data.f);
+ delete[] scaling_factors;
+
+ return kTfLiteOk;
+}
+
#define TF_LITE_MACRO_DISPATCH(macro_name, params, target_namespace) \
if (params->activation == kTfLiteActNone) { \
macro_name(target_namespace, kNone); \
@@ -195,9 +289,17 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
if (kernel_type == kReference) {
TF_LITE_FULLY_CONNECTED(reference_ops);
} else if (kernel_type == kPie) {
- // TODO(ahentz): we don't have a quantized version of the PIE kernels, so
- // we just defer to the MINI ones.
- TF_LITE_FULLY_CONNECTED(optimized_ops);
+ if (input->type == kTfLiteFloat32) {
+ // Pie currently only supports quantized models and float inputs/outputs.
+ TfLiteTensor* input_quantized =
+ &context->tensors[node->temporaries->data[0]];
+ return EvalPieQuantized(context, node, params, data, input, filter, bias,
+ input_quantized, output);
+ } else {
+ // TODO(ahentz): we don't have a quantized version of the PIE kernels, so
+ // we just defer to the MINI ones.
+ TF_LITE_FULLY_CONNECTED(optimized_ops);
+ }
} else {
TF_LITE_FULLY_CONNECTED(optimized_ops);
}
@@ -245,7 +347,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
- switch (input->type) { // Already know in/out types are same.
+ switch (filter->type) { // Already know in/out types are same.
case kTfLiteFloat32:
return EvalFloat<kernel_type>(context, node, params, data, input, filter,
bias, output);
diff --git a/tensorflow/contrib/lite/kernels/fully_connected_test.cc b/tensorflow/contrib/lite/kernels/fully_connected_test.cc
index 87413000a9..05dd028b48 100644
--- a/tensorflow/contrib/lite/kernels/fully_connected_test.cc
+++ b/tensorflow/contrib/lite/kernels/fully_connected_test.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include <gtest/gtest.h>
#include "absl/memory/memory.h"
#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/kernels/test_util.h"
#include "tensorflow/contrib/lite/model.h"
@@ -224,6 +225,60 @@ class QuantizedFullyConnectedOpModel : public BaseFullyConnectedOpModel {
}
};
+// In the hybrid model the weights are quantized (to uint8). But the bias,
+// input (and output) are expected to be in float precision.
+class HybridFullyConnectedOpModel : public SingleOpModel {
+ public:
+ HybridFullyConnectedOpModel(int units, int batches, const TensorData& input,
+ const TensorData& weights,
+ const TensorData& output = {TensorType_FLOAT32})
+ : batches_(batches), units_(units) {
+ int total_input_size = 1;
+ for (int i = 0; i < input.shape.size(); ++i) {
+ total_input_size *= input.shape[i];
+ }
+ input_size_ = total_input_size / batches_;
+
+ input_ = AddInput(input);
+ weights_ = AddInput(weights);
+
+ TensorData bias{TensorType_FLOAT32, {units_}};
+ bias_ = AddInput(bias);
+
+ output_ = AddOutput(output);
+
+ SetBuiltinOp(
+ BuiltinOperator_FULLY_CONNECTED, BuiltinOptions_FullyConnectedOptions,
+ CreateFullyConnectedOptions(builder_, ActivationFunctionType_RELU)
+ .Union());
+ resolver_ = absl::make_unique<SingleOpResolver>(
+ BuiltinOperator_FULLY_CONNECTED,
+ ops::builtin::Register_FULLY_CONNECTED_PIE());
+ BuildInterpreter({GetShape(input_), GetShape(weights_), GetShape(bias_)});
+ }
+ void SetBias(std::initializer_list<float> f) { PopulateTensor(bias_, f); }
+ void SetWeights(std::initializer_list<float> data) {
+ SymmetricQuantizeAndPopulate(weights_, data);
+ }
+
+ void SetInput(std::initializer_list<float> f) { PopulateTensor(input_, f); }
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+
+ int input_size() { return input_size_; }
+ int num_units() { return units_; }
+ int num_batches() { return batches_; }
+
+ protected:
+ int input_;
+ int weights_;
+ int bias_;
+ int output_;
+
+ int batches_;
+ int units_;
+ int input_size_;
+};
+
const auto kKernelMap = new std::map<string, TfLiteRegistration*>({
{"Reference", ops::builtin::Register_FULLY_CONNECTED_REF()},
{"NeonOptimized", ops::builtin::Register_FULLY_CONNECTED_NEON_OPT()},
@@ -231,18 +286,43 @@ const auto kKernelMap = new std::map<string, TfLiteRegistration*>({
{"Pie", ops::builtin::Register_FULLY_CONNECTED_PIE()},
});
-class FullyConnectedOpTest : public SingleOpTest {
+class FloatFullyConnectedOpTest : public SingleOpTest {
protected:
const std::map<string, TfLiteRegistration*>& GetKernelMap() override {
return *kKernelMap;
}
};
+const auto kKernelMapNoPie = new std::map<string, TfLiteRegistration*>({
+ {"Reference", ops::builtin::Register_FULLY_CONNECTED_REF()},
+ {"NeonOptimized", ops::builtin::Register_FULLY_CONNECTED_NEON_OPT()},
+ {"GenericOptimized", ops::builtin::Register_FULLY_CONNECTED_GENERIC_OPT()},
+});
+
+class QuantizedFullyConnectedOpTest : public SingleOpTest {
+ protected:
+ const std::map<string, TfLiteRegistration*>& GetKernelMap() override {
+ return *kKernelMapNoPie;
+ }
+};
+
+const auto kKernelMapPie = new std::map<string, TfLiteRegistration*>({
+ {"Pie", ops::builtin::Register_FULLY_CONNECTED_PIE()},
+});
+
+// Hybrid mode is used by the Pie quantized kernel.
+class HybridFullyConnectedOpTest : public SingleOpTest {
+ protected:
+ const std::map<string, TfLiteRegistration*>& GetKernelMap() override {
+ return *kKernelMapPie;
+ }
+};
+
// TODO(ahentz): add more small tests like this one, focused on making sure the
// calculations are correct.
-TEST_P(FullyConnectedOpTest, SimpleTest) {
- FloatFullyConnectedOpModel m(GetRegistration(), 3, 2,
- {TensorType_FLOAT32, {2, 10}});
+TEST_P(FloatFullyConnectedOpTest, SimpleTest) {
+ FloatFullyConnectedOpModel m(GetRegistration(), /*units=*/3, /*batches=*/2,
+ /*input=*/{TensorType_FLOAT32, {2, 10}});
m.SetWeights({
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1
@@ -260,9 +340,9 @@ TEST_P(FullyConnectedOpTest, SimpleTest) {
EXPECT_THAT(m.GetOutput(), ElementsAre(24, 25, 26, 58, 59, 60));
}
-TEST_P(FullyConnectedOpTest, SimpleTestQuantized) {
+TEST_P(QuantizedFullyConnectedOpTest, SimpleTestQuantized) {
QuantizedFullyConnectedOpModel m(
- GetRegistration(), 3, 2,
+ GetRegistration(), /*units=*/3, /*batches*/ 2,
/*input=*/{TensorType_UINT8, {2, 10}, -63.5, 64},
/*output=*/{TensorType_UINT8, {}, -127, 128});
@@ -288,13 +368,40 @@ TEST_P(FullyConnectedOpTest, SimpleTestQuantized) {
EXPECT_THAT(m.GetOutput(), ElementsAre(151, 152, 153, 185, 186, 187));
}
-TEST(FullyConnectedOpTest, SimpleTest4DInput) {
+TEST(HybridFullyConnectedOpTest, SimpleTestQuantized) {
+ HybridFullyConnectedOpModel m(
+ /*units=*/3, /*batches=*/2,
+ /*input=*/{TensorType_FLOAT32, {2, 10}},
+ /*weights=*/{TensorType_UINT8, {3, 10}, -63.5, 64}); // PIE
+
+ m.SetWeights({
+ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0
+ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1
+ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1
+ });
+ m.SetBias({1, 2, 3});
+
+ m.SetInput({
+ 1, 2, 3, 4, 5, 6, 7, 8, -9, -10, // b = 0
+ 1, 2, 3, 4, 5, 6, 7, -8, 9, -10, // b = 1
+ });
+
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(
+ {
+ 24, 25, 26, //
+ 58, 59, 60, //
+ },
+ /*max_abs_error=*/1.3f)));
+}
+
+TEST(FloatFullyConnectedOpTest, SimpleTest4DInput) {
// Note that it is not required that the first dimension be the number of
// batches. All we care is that the input can be evenly distributed in
// batches. In this case, we need the input to have multiples of '2'.
FloatFullyConnectedOpModel m(ops::builtin::Register_FULLY_CONNECTED_PIE(),
- /*units=*/3,
- /*batches=*/2,
+ /*units=*/3, /*batches=*/2,
/*input=*/{TensorType_FLOAT32, {4, 1, 5, 1}});
m.SetWeights({
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0
@@ -316,9 +423,9 @@ TEST(FullyConnectedOpTest, SimpleTest4DInput) {
}));
}
-TEST_P(FullyConnectedOpTest, SimpleTest4dInputQuantized) {
+TEST_P(QuantizedFullyConnectedOpTest, SimpleTest4dInputQuantized) {
QuantizedFullyConnectedOpModel m(
- GetRegistration(), 3, 2,
+ GetRegistration(), /*units=*/3, /*batches=*/2,
/*input=*/{TensorType_UINT8, {4, 1, 5, 1}, -63.5, 64},
/*output=*/{TensorType_UINT8, {}, -127, 128});
@@ -345,14 +452,18 @@ TEST_P(FullyConnectedOpTest, SimpleTest4dInputQuantized) {
}
INSTANTIATE_TEST_CASE_P(
- FullyConnectedOpTest, FullyConnectedOpTest,
+ FloatFullyConnectedOpTest, FloatFullyConnectedOpTest,
::testing::ValuesIn(SingleOpTest::GetKernelTags(*kKernelMap)));
+INSTANTIATE_TEST_CASE_P(
+ QuantizedFullyConnectedOpTest, QuantizedFullyConnectedOpTest,
+ ::testing::ValuesIn(SingleOpTest::GetKernelTags(*kKernelMapNoPie)));
+
// TODO(ahentz): Reconsider this test. Having arbitrary weights makes it hard
// to debug errors and doesn't necessarily test all the important details.
-TEST_P(FullyConnectedOpTest, BlackBoxTest) {
- FloatFullyConnectedOpModel m(GetRegistration(), 16, 2,
- {TensorType_FLOAT32, {2, 8}});
+TEST_P(FloatFullyConnectedOpTest, BlackBoxTest) {
+ FloatFullyConnectedOpModel m(GetRegistration(), /*units=*/16, /*batches=*/2,
+ /*input=*/{TensorType_FLOAT32, {2, 8}});
m.SetWeights(
{0.091327, 0.103366, -0.316505, -0.083120, 0.149366, -0.196636,
-0.123672, 0.062800, 0.063031, 0.191670, -0.062001, -0.061504,
diff --git a/tensorflow/contrib/lite/kernels/internal/BUILD b/tensorflow/contrib/lite/kernels/internal/BUILD
index c5539afb9c..df29172f83 100644
--- a/tensorflow/contrib/lite/kernels/internal/BUILD
+++ b/tensorflow/contrib/lite/kernels/internal/BUILD
@@ -303,6 +303,7 @@ cc_library(
],
hdrs = [
"common.h",
+ "compatibility.h",
"optimized/cpu_check.h",
"optimized/neon_tensor_utils.h",
"optimized/tensor_utils_impl.h",
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc
index 47dfcbeb01..65f25168e3 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/builtin_op_data.h"
#include "tensorflow/contrib/lite/kernels/activation_functor.h"
#include "tensorflow/contrib/lite/kernels/internal/common.h"
+#include "tensorflow/contrib/lite/kernels/internal/compatibility.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h"
#include "tensorflow/contrib/lite/kernels/internal/round.h"
@@ -27,6 +28,22 @@ limitations under the License.
namespace tflite {
namespace tensor_utils {
+namespace {
+
+// Allocates, at least, size bytes of uninitialized storage whose alignment is
+// specified by alignment. The size parameter must be an integral multiple of
+// alignment.
+// Caller is responsible by freeing the allocated memory by calling free on
+// the passed freeing_buffer pointer.
+void* aligned_alloc(size_t alignment, size_t size, void** freeing_buffer) {
+ *freeing_buffer = malloc(size + alignment);
+ const size_t offset = ((uintptr_t)*freeing_buffer) % alignment; // NOLINT
+ return offset == 0
+ ? *freeing_buffer
+ : ((char*)*freeing_buffer + (alignment - offset)); // NOLINT
+}
+
+} // namespace
void NeonMatrixBatchVectorMultiplyAccumulate(const float* matrix, int m_rows,
int m_cols, const float* vector,
@@ -114,6 +131,114 @@ void NeonMatrixBatchVectorMultiplyAccumulate(const float* matrix, int m_rows,
delete[] vector_cache_float32x4;
}
+void NeonMatrixBatchVectorMultiplyAccumulate(
+ const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
+ const int8_t* __restrict__ vectors, const float* scaling_factors,
+ int n_batch, float* __restrict__ result, int result_stride) {
+ const int kWeightsPerUint32 = 4;
+ const int kWeightsPerNeonLane = 16;
+ // If the number of rows is not divisible by kWeightsPerUint32, we set a
+ // flag and allocate an aligned memory block. The flag is used to use the
+ // aligned memory block later in the kernel loop.
+ bool unaligned = false;
+ int8* aligned_row = nullptr;
+ void* aligned_row_free = nullptr;
+ if ((m_cols & (kWeightsPerUint32 - 1)) != 0) {
+ unaligned = true;
+ aligned_row = (int8*)aligned_alloc(kWeightsPerUint32, m_cols, // NOLINT
+ &aligned_row_free);
+ }
+ void* aligned_vec_free = nullptr;
+ int8* aligned_vec = (int8*)aligned_alloc(kWeightsPerUint32, m_cols, // NOLINT
+ &aligned_vec_free);
+
+ // If m_cols is not at least kWeightsPerNeonLane, we cannot use the main
+ // vectorized loop, and we need to process sequentially. postamble_start shows
+ // the start index where this should happen.
+ const int postamble_start = m_cols - (m_cols & (kWeightsPerNeonLane - 1));
+
+ int batch, row, col;
+ for (batch = 0; batch < n_batch; ++batch) {
+ const float batch_scaling_factor_inv = 1.0 / scaling_factors[batch];
+ // Copy the vector data to an aligned vector.
+ memcpy(aligned_vec, vectors + batch * m_cols, sizeof(int8) * m_cols);
+ // Compute dot-product for every column.
+ for (row = 0; row < m_rows; ++row, result += result_stride) {
+ // Get the address of the first element of the row.
+ int8* row_ptr = (int8*)matrix + row * m_cols; // NOLINT
+ if (unaligned) {
+ memcpy(aligned_row, row_ptr, sizeof(int8) * m_cols);
+ row_ptr = aligned_row;
+ }
+
+ // Initialize the dot product sum for the row to 0.
+ int32x4_t dotprod = vmovq_n_s32(0);
+
+ // Prefetch the row to cache.
+ __builtin_prefetch(row_ptr, 0 /* prefetch for read */,
+ 3 /* temporal locality */);
+
+ // For every block of 16 8-bit elements.
+ col = 0;
+ for (; col < postamble_start; col += kWeightsPerNeonLane) {
+ // Load 16 8-bit values from the row and vector, each, to operate on.
+ // Here the assumption is that each buffer is 4-byte aligned.
+ TFLITE_CHECK_EQ((uintptr_t)(&row_ptr[col]) & (kWeightsPerUint32 - 1),
+ 0);
+ const int8x16_t s1_8x16 = vld1q_s8((const int8_t*)(aligned_vec + col));
+ const int8x16_t s2_8x16 = vld1q_s8((const int8_t*)(row_ptr + col));
+ // Multiply the low bits (i.e. the lower 8 8bit numbers in the
+ // registers).
+ int16x8_t prod_16x8 =
+ vmull_s8(vget_low_s8(s1_8x16), vget_low_s8(s2_8x16));
+ // Multiply the high bits (i.e. the lower 8 8bit numbers in the
+ // registers), and accumulate with the result of the low bits product.
+ // The assumption here is that overflow will not happen as we quantize
+ // our values to be in the range [-127, 127]. As such the sum of the 2
+ // products is always strictly smaller than 15-bits (32767 in absolute
+ // value).
+ prod_16x8 =
+ vmlal_s8(prod_16x8, vget_high_s8(s1_8x16), vget_high_s8(s2_8x16));
+
+ dotprod = vpadalq_s16(dotprod, prod_16x8);
+ } // for col
+
+ int32 postable_sum = 0;
+ // Postamble loop.
+ // TODO(raziel): if (ABSL_PREDICT_FALSE(postamble_start < m_rows))
+ if (postamble_start < m_cols) {
+ col = postamble_start;
+ if ((m_cols - postamble_start) >= (kWeightsPerNeonLane >> 1)) {
+ // Load 8 8-bit values from the row and column each to operate on.
+ // Here the assumption is that each buffer is 4-bytes aligned.
+ TFLITE_CHECK_EQ((uintptr_t)(&row_ptr[col]) & (kWeightsPerUint32 - 1),
+ 0);
+ const int8x8_t s1_8x8 = vld1_s8((const int8_t*)(aligned_vec + col));
+ const int8x8_t s2_8x8 = vld1_s8((const int8_t*)(row_ptr + col));
+ const int16x8_t prod_16x8 = vmull_s8(s1_8x8, s2_8x8);
+ dotprod = vpadalq_s16(dotprod, prod_16x8);
+ col += (kWeightsPerNeonLane >> 1);
+ }
+ for (; col < m_cols; ++col) {
+ postable_sum += row_ptr[col] * aligned_vec[col];
+ } // for col
+ }
+ // Add the 4 intermediate sum values to get the final dot-prod value for
+ // this row.
+ int64x2_t pairwiseAdded = vpaddlq_s32(dotprod);
+ int32 neon_sum =
+ vgetq_lane_s64(pairwiseAdded, 0) + vgetq_lane_s64(pairwiseAdded, 1);
+
+ *result += ((neon_sum + postable_sum) * batch_scaling_factor_inv);
+ } // for row
+ } // for batch
+
+ if (unaligned) {
+ free(aligned_row_free);
+ }
+ free(aligned_vec_free);
+}
+
void NeonVectorVectorCwiseProduct(const float* vector1, const float* vector2,
int v_size, float* result) {
// If v_size is not divisible by kWeightsPerNeonLane, we cannot use the main
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h
index 3b6f4bd583..9e60d0657b 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h
@@ -32,6 +32,14 @@ void MatrixBatchVectorMultiplyAccumulate(const float* matrix, int m_rows,
vector, n_batch, result, result_stride);
}
+void MatrixBatchVectorMultiplyAccumulate(
+ const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
+ const int8_t* __restrict__ vectors, const float* scaling_factors,
+ int n_batch, float* __restrict__ result, int result_stride) {
+ NEON_OR_PORTABLE(MatrixBatchVectorMultiplyAccumulate, matrix, m_rows, m_cols,
+ vectors, scaling_factors, n_batch, result, result_stride);
+}
+
void VectorVectorCwiseProduct(const float* vector1, const float* vector2,
int v_size, float* result) {
NEON_OR_PORTABLE(VectorVectorCwiseProduct, vector1, vector2, v_size, result);
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h b/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h
index 19220470f4..d570dadd86 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h
@@ -40,6 +40,16 @@ void NeonMatrixBatchVectorMultiplyAccumulate(const float* matrix, int m_rows,
int n_batch, float* result,
int result_stride);
+// Matrix multiplication for quantized values using symmetric quantization.
+void PortableMatrixBatchVectorMultiplyAccumulate(
+ const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
+ const int8_t* __restrict__ vectors, const float* scaling_factors,
+ int n_batch, float* __restrict__ result, int result_stride);
+void NeonMatrixBatchVectorMultiplyAccumulate(
+ const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
+ const int8_t* __restrict__ vectors, const float* scaling_factors,
+ int n_batch, float* __restrict__ result, int result_stride);
+
// Cwise product of two vectors.
void PortableVectorVectorCwiseProduct(const float* vector1,
const float* vector2, int v_size,
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc
index 5e7586eeda..2607adc0c1 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc
+++ b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc
@@ -69,6 +69,30 @@ void PortableMatrixBatchVectorMultiplyAccumulate(const float* matrix,
}
}
+void PortableMatrixBatchVectorMultiplyAccumulate(
+ const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
+ const int8_t* __restrict__ vectors, const float* scaling_factors,
+ int n_batch, float* __restrict__ result, int result_stride) {
+ int batch, row, col;
+ for (batch = 0; batch < n_batch; ++batch, vectors += m_cols) {
+ const float batch_scaling_factor_inv = 1.0 / scaling_factors[batch];
+ // Get the address of the first row.
+ int8_t* row_ptr = (int8_t*)matrix; // NOLINT
+ for (row = 0; row < m_rows; ++row, result += result_stride) {
+ // Initialize the dot product sum for the row to 0.
+ int32_t dotprod = 0;
+ // Prefetch the row to cache.
+ __builtin_prefetch(row_ptr, 0 /* prefetch for read */,
+ 3 /* temporal locality */);
+ // For every block of 16 8-bit elements (128-bit register) from each row.
+ for (col = 0; col < m_cols; ++col, ++row_ptr) {
+ dotprod += (*row_ptr) * (vectors[col]);
+ } // for col
+ *result += (dotprod * batch_scaling_factor_inv);
+ } // for row
+ } // for batch
+}
+
void PortableVectorVectorCwiseProduct(const float* vector1,
const float* vector2, int v_size,
float* result) {
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h
index 478cda8e19..1757a9f5e5 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h
@@ -37,6 +37,11 @@ void PortableMatrixBatchVectorMultiplyAccumulate(const float* matrix,
int n_batch, float* result,
int result_stride);
+void PortableMatrixBatchVectorMultiplyAccumulate(
+ const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
+ const int8_t* __restrict__ vectors, const float* scaling_factors,
+ int n_batch, float* __restrict__ result, int result_stride);
+
// Cwise product of two vectors.
void PortableVectorVectorCwiseProduct(const float* vector1,
const float* vector2, int v_size,
@@ -122,6 +127,15 @@ void MatrixBatchVectorMultiplyAccumulate(const float* matrix, int m_rows,
n_batch, result, result_stride);
}
+void MatrixBatchVectorMultiplyAccumulate(
+ const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
+ const int8_t* __restrict__ vector, const float* scaling_factors,
+ int n_batch, float* __restrict__ result, int result_stride) {
+ PortableMatrixBatchVectorMultiplyAccumulate(matrix, m_rows, m_cols, vector,
+ scaling_factors, n_batch, result,
+ result_stride);
+}
+
void VectorVectorCwiseProduct(const float* vector1, const float* vector2,
int v_size, float* result) {
PortableVectorVectorCwiseProduct(vector1, vector2, v_size, result);
diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/tensor_utils.h
index 997dc4425d..e1c9ccd84b 100644
--- a/tensorflow/contrib/lite/kernels/internal/tensor_utils.h
+++ b/tensorflow/contrib/lite/kernels/internal/tensor_utils.h
@@ -31,17 +31,35 @@ void SymmetricQuantizeFloats(const float* values, const int size,
int8_t* quantized_values, float* min, float* max,
float* scaling_factor);
-// Multiply a matrix by a batch vector, and store results in a batch-size
-// vector using a stride value provided in result_stride. 'result_stride' shows
-// how the number of elements between consecutive result values. For example
-// result_stride = 1, will cause the output to look like this:
-// [O_1, 0_2, ... O_rows] in memory, but result_stride = 3, will cause it to be
-// arranged like this in memory: [O_1, x, x, 0_2, x, x, ..., O_rows]
+// Multiplies a matrix by a "batched" vector (i.e. a matrix with a batch
+// dimension composed by input vectors independent from each other). The result
+// of the multiplication is accumulated to the passed result buffer.
+// More specifically, for a matrix M of shape [n, i] and a batched-vector
+// of shape [i, batch] it will first compute the product of shape [n, batch].
+// This product will be accumulated to the result buffer, using a stride value
+// provided in result_stride (the number of elements between consecutive result
+// values). For example result_stride = 1, will cause the output to look like
+// this:
+// [O_1, 0_2, ... O_rows]
+// but result_stride = 3, will cause it to be arranged like this in memory:
+// [O_1, x, x, 0_2, x, x, ..., O_rows]
void MatrixBatchVectorMultiplyAccumulate(const float* matrix, int m_rows,
int m_cols, const float* vector,
int n_batch, float* result,
int result_stride);
+// Same as the function above, but for values quantized using symmetric
+// quantization (e.g. by calling SymmetricQuantizeFloats).
+// The passed scaling factors is a buffer of the quantization scaling factors
+// that will be used to dequentize the products into the final result buffer.
+// These scaling factors are the multiplication of the matrix scaling factor
+// by the vector's scaling factor, one per batch (i.e. this allows quantizing
+// each batch in the batch-vector matrix independently).
+void MatrixBatchVectorMultiplyAccumulate(
+ const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
+ const int8_t* __restrict__ vectors, const float* scaling_factors,
+ int n_batch, float* __restrict__ result, int result_stride);
+
// Cwise product of two vectors.
void VectorVectorCwiseProduct(const float* vector1, const float* vector2,
int v_size, float* result);
diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc b/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc
index 22b016746f..3d8a2eada0 100644
--- a/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc
+++ b/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc
@@ -107,6 +107,329 @@ TEST(uKernels, MatrixBatchVectorMultiplyAccumulateTest) {
-1., 3., 7., 3., 23., 3.})));
}
+TEST(uKernels, MatrixBatchVectorMultiplyAccumulateSymmetricQuantizedTest) {
+ // Note we use 29 columns as this exercises all the neon kernel: the
+ // 16-block SIMD code, the 8-block postamble, and the leftover postamble.
+ const int a_rows = 4, a_cols = 29;
+ const int kWeightsPerUint32 = 4;
+ const float a_float_data[] = {
+ /* 1st row */
+ 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.1, 11.11, 12.12, 13.13,
+ 14.14, 15.15, 16.16, 17.17, 18.18, 19.19, 20.2, 21.21, 22.22, 23.23,
+ 24.24, 25.25, 26.26, 27.27, 28.28, 0,
+ /* 2nd row */
+ -1.1, -2.2, -3.3, -4.4, -5.5, -6.6, -7.7, -8.8, -9.9, -10.1, -11.11,
+ -12.12, -13.13, -14.14, -15.15, -16.16, -17.17, -18.18, -19.19, -20.2,
+ -21.21, -22.22, -23.23, -24.24, -25.25, -26.26, -27.27, -28.28, 0,
+ /* 3rd row */
+ 1.1, -2.2, 3.3, -4.4, 5.5, -6.6, 7.7, -8.8, 9.9, -10.1, 11.11, -12.12,
+ 13.13, -14.14, 15.15, -16.16, 17.17, -18.18, 19.19, -20.2, 21.21, -22.22,
+ 23.23, -24.24, 25.25, -26.26, 27.27, -28.28, 0,
+ /* 4th row */
+ -1.1, 2.2, -3.3, 4.4, -5.5, 6.6, -7.7, 8.8, -9.9, 10.1, -11.11, 12.12,
+ -13.13, 14.14, -15.15, 16.16, -17.17, 18.18, -19.19, 20.2, -21.21, 22.22,
+ -23.23, 24.24, -25.25, 26.26, -27.27, 28.28, 0};
+
+ int8* a_int8_data = reinterpret_cast<int8*>(
+ aligned_malloc(a_rows * a_cols, kWeightsPerUint32));
+ float a_min, a_max;
+ float scaling_factor_a;
+ SymmetricQuantizeFloats(a_float_data, a_rows * a_cols, a_int8_data, &a_min,
+ &a_max, &scaling_factor_a);
+ const int8 expected_a_int8_data[] = {
+ /* 1st row */
+ 5,
+ 10,
+ 15,
+ 20,
+ 25,
+ 30,
+ 35,
+ 40,
+ 44,
+ 45,
+ 50,
+ 54,
+ 59,
+ 64,
+ 68,
+ 73,
+ 77,
+ 82,
+ 86,
+ 91,
+ 95,
+ 100,
+ 104,
+ 109,
+ 113,
+ 118,
+ 122,
+ 127,
+ 0,
+ /* 2nd row */
+ -5,
+ -10,
+ -15,
+ -20,
+ -25,
+ -30,
+ -35,
+ -40,
+ -44,
+ -45,
+ -50,
+ -54,
+ -59,
+ -64,
+ -68,
+ -73,
+ -77,
+ -82,
+ -86,
+ -91,
+ -95,
+ -100,
+ -104,
+ -109,
+ -113,
+ -118,
+ -122,
+ -127,
+ 0,
+ /* 3rd row */
+ 5,
+ -10,
+ 15,
+ -20,
+ 25,
+ -30,
+ 35,
+ -40,
+ 44,
+ -45,
+ 50,
+ -54,
+ 59,
+ -64,
+ 68,
+ -73,
+ 77,
+ -82,
+ 86,
+ -91,
+ 95,
+ -100,
+ 104,
+ -109,
+ 113,
+ -118,
+ 122,
+ -127,
+ 0,
+ /* 4th row */
+ -5,
+ 10,
+ -15,
+ 20,
+ -25,
+ 30,
+ -35,
+ 40,
+ -44,
+ 45,
+ -50,
+ 54,
+ -59,
+ 64,
+ -68,
+ 73,
+ -77,
+ 82,
+ -86,
+ 91,
+ -95,
+ 100,
+ -104,
+ 109,
+ -113,
+ 118,
+ -122,
+ 127,
+ 0,
+ };
+ for (int i = 0; i < a_rows * a_cols; ++i) {
+ EXPECT_EQ(expected_a_int8_data[i], a_int8_data[i]);
+ }
+
+ const int b_rows = 29, b_cols = 1, batches = 2;
+ const float b_float_data[] = {
+ /* batch 1 */
+ 1.0,
+ -1.0,
+ 1.0,
+ -1.0,
+ 1.0,
+ -1.0,
+ 1.0,
+ -1.0,
+ 1.0,
+ -1.0,
+ 1.0,
+ -1.0,
+ 1.0,
+ -1.0,
+ 1.0,
+ -1.0,
+ 1.0,
+ -1.0,
+ 1.0,
+ -1.0,
+ 1.0,
+ -1.0,
+ 1.0,
+ -1.0,
+ 1.0,
+ -1.0,
+ 1.0,
+ -1.0,
+ 1.0,
+ /* batch 2 */
+ 2.5,
+ -2.1,
+ 3.0,
+ -1.3,
+ 1.3,
+ -1.1,
+ 2.0,
+ -1.7,
+ 1.9,
+ -1.5,
+ 0.5,
+ -0.7,
+ 0.8,
+ -0.3,
+ 2.8,
+ -2.8,
+ 1.1,
+ -2.3,
+ 1.9,
+ -1.9,
+ 2.1,
+ -0.5,
+ 2.4,
+ -0.1,
+ 1.0,
+ -2.5,
+ 0.7,
+ -1.9,
+ 0.2,
+ };
+
+ // Quantized values of B:
+ int8 b_int8_data[b_rows * b_cols * batches];
+ float b_min, b_max;
+ float scaling_factor_b[batches];
+ SymmetricQuantizeFloats(b_float_data, b_rows * b_cols, b_int8_data, &b_min,
+ &b_max, &scaling_factor_b[0]);
+ SymmetricQuantizeFloats(&b_float_data[b_rows * b_cols], b_rows * b_cols,
+ &b_int8_data[b_rows * b_cols], &b_min, &b_max,
+ &scaling_factor_b[1]);
+
+ const int8 expected_b_int8_data[] = {
+ /* batch 1 */
+ 127,
+ -127,
+ 127,
+ -127,
+ 127,
+ -127,
+ 127,
+ -127,
+ 127,
+ -127,
+ 127,
+ -127,
+ 127,
+ -127,
+ 127,
+ -127,
+ 127,
+ -127,
+ 127,
+ -127,
+ 127,
+ -127,
+ 127,
+ -127,
+ 127,
+ -127,
+ 127,
+ -127,
+ 127,
+ /* batch 2 */
+ 106,
+ -89,
+ 127,
+ -55,
+ 55,
+ -47,
+ 85,
+ -72,
+ 80,
+ -64,
+ 21,
+ -30,
+ 34,
+ -13,
+ 119,
+ -119,
+ 47,
+ -97,
+ 80,
+ -80,
+ 89,
+ -21,
+ 102,
+ -4,
+ 42,
+ -106,
+ 30,
+ -80,
+ 8,
+ };
+ for (int i = 0; i < b_rows * b_cols * batches; ++i) {
+ EXPECT_EQ(expected_b_int8_data[i], b_int8_data[i]);
+ }
+
+ // Full float operation results in:
+ // -13.69, 13.69, 414.11, -414.11
+ // -6.325, 6.325, 631.263, -631.263
+ float c_float_data[a_rows * b_cols * batches];
+ for (int i = 0; i < a_rows * b_cols * batches; ++i) {
+ c_float_data[i] = 0.0;
+ }
+
+ // Testing product.
+ const float scaling_factor_c[2] = {
+ scaling_factor_a * scaling_factor_b[0],
+ scaling_factor_a * scaling_factor_b[1],
+ };
+ MatrixBatchVectorMultiplyAccumulate(a_int8_data, a_rows, a_cols, b_int8_data,
+ scaling_factor_c, batches, c_float_data,
+ /*result_stride=*/1);
+
+ // Assert we obtain the expected recovered float values.
+ const float expected_c_float_data[] = {
+ -14.474, 14.474, 414.402, -414.402, -6.92228, 6.92228, 632.042, -632.042,
+ };
+ for (int i = 0; i < a_rows * b_cols * batches; ++i) {
+ EXPECT_NEAR(expected_c_float_data[i], c_float_data[i], 0.001);
+ }
+
+ aligned_free(a_int8_data);
+}
+
TEST(uKernels, VectorVectorCwiseProductTest) {
constexpr int kVectorSize = 10;
static float input1[kVectorSize] = {0.0, -0.5, 1.0, -1.5, 2.0,
diff --git a/tensorflow/contrib/lite/kernels/lstm.cc b/tensorflow/contrib/lite/kernels/lstm.cc
index 8cf1165135..668226e674 100644
--- a/tensorflow/contrib/lite/kernels/lstm.cc
+++ b/tensorflow/contrib/lite/kernels/lstm.cc
@@ -66,10 +66,19 @@ constexpr int kProjectionWeightsTensor = 16; // Optional
constexpr int kProjectionBiasTensor = 17; // Optional
// Output tensors.
-constexpr int kScratchBufferTensor = 0;
-constexpr int kOutputStateTensor = 1;
-constexpr int kCellStateTensor = 2;
-constexpr int kOutputTensor = 3;
+constexpr int kOutputStateTensor = 0;
+constexpr int kCellStateTensor = 1;
+constexpr int kOutputTensor = 2;
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+ auto* scratch_tensor_index = new int;
+ context->AddTensors(context, 1, scratch_tensor_index);
+ return scratch_tensor_index;
+}
+
+void Free(TfLiteContext* context, void* buffer) {
+ delete reinterpret_cast<int*>(buffer);
+}
// Check that input tensor dimensions matches with each other.
TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
@@ -220,12 +229,15 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
return kTfLiteOk;
}
-// Resize the output, state and scratch tensors based on the sizes of the input
-// tensors. Also check that the size of the input tensors match each other.
+// Resize the output, state tensors based on the sizes of the input tensors.
+// Allocate a temporary scratch tensor. Also check that the sizes of the input
+// tensors match each other.
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data);
+
// Check we have all the inputs and outputs we need.
TF_LITE_ENSURE_EQ(context, node->inputs->size, 18);
- TF_LITE_ENSURE_EQ(context, node->outputs->size, 4);
+ TF_LITE_ENSURE_EQ(context, node->outputs->size, 3);
// Inferring batch size, number of outputs and number of cells from the
// input tensors.
@@ -250,15 +262,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Check that input tensor dimensions matches with each other.
CheckInputTensorDimensions(context, node, n_input, n_output, n_cell);
- // Get the pointer to output, state and scratch buffer tensors.
+ // Get the pointer to output, output_state and cell_state tensors.
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TfLiteTensor* output_state = GetOutput(context, node, kOutputStateTensor);
TfLiteTensor* cell_state = GetOutput(context, node, kCellStateTensor);
- // TODO(ghodrat): Modify this as soon as we have a finalized method for
- // scratch buffers.
- TfLiteTensor* scratch_buffer = GetOutput(context, node, kScratchBufferTensor);
- // Resize the output and output_state tensors.
+ // Resize the output, output_state and cell_state tensors.
TfLiteIntArray* output_size = TfLiteIntArrayCreate(2);
output_size->data[0] = n_batch;
output_size->data[1] = n_output;
@@ -271,13 +280,20 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(
context, context->ResizeTensor(context, output_state, output_state_size));
- // Resize the output, state and scratch buffer tensors.
TfLiteIntArray* cell_size = TfLiteIntArrayCreate(2);
cell_size->data[0] = n_batch;
cell_size->data[1] = n_cell;
TF_LITE_ENSURE_OK(context,
context->ResizeTensor(context, cell_state, cell_size));
+ // Create a scratch buffer tensor.
+ TfLiteIntArrayFree(node->temporaries);
+ node->temporaries = TfLiteIntArrayCreate(1);
+ node->temporaries->data[0] = *scratch_tensor_index;
+ TfLiteTensor* scratch_buffer = &context->tensors[node->temporaries->data[0]];
+ scratch_buffer->type = input->type;
+ scratch_buffer->allocation_type = kTfLiteArenaRw;
+
// Mark state tensors as persistent tensors.
output_state->allocation_type = kTfLiteArenaRwPersistent;
cell_state->allocation_type = kTfLiteArenaRwPersistent;
@@ -362,7 +378,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const bool use_peephole = (cell_to_output_weights != nullptr);
// Index the scratch buffers pointers to the global scratch buffer.
- TfLiteTensor* scratch_buffer = GetOutput(context, node, kScratchBufferTensor);
+ TfLiteTensor* scratch_buffer = &context->tensors[node->temporaries->data[0]];
+
float* input_gate_scratch = nullptr;
float* cell_scratch = nullptr;
float* forget_gate_scratch = nullptr;
@@ -433,8 +450,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace lstm
TfLiteRegistration* Register_LSTM() {
- static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
- lstm::Prepare, lstm::Eval};
+ static TfLiteRegistration r = {lstm::Init, lstm::Free, lstm::Prepare,
+ lstm::Eval};
return &r;
}
diff --git a/tensorflow/contrib/lite/kernels/lstm_test.cc b/tensorflow/contrib/lite/kernels/lstm_test.cc
index c068286b0d..d81220d8d3 100644
--- a/tensorflow/contrib/lite/kernels/lstm_test.cc
+++ b/tensorflow/contrib/lite/kernels/lstm_test.cc
@@ -97,9 +97,6 @@ class LSTMOpModel : public SingleOpModel {
projection_bias_ = AddNullInput();
}
- scratch_buffer_ = AddOutput(TensorType_FLOAT32);
- // TODO(ghodrat): Modify these states when we have a permanent solution for
- // persistent buffer.
output_state_ = AddOutput(TensorType_FLOAT32);
cell_state_ = AddOutput(TensorType_FLOAT32);
output_ = AddOutput(TensorType_FLOAT32);
@@ -233,7 +230,6 @@ class LSTMOpModel : public SingleOpModel {
int output_;
int output_state_;
int cell_state_;
- int scratch_buffer_;
int n_batch_;
int n_input_;
diff --git a/tensorflow/contrib/lite/kernels/optional_tensor_test.cc b/tensorflow/contrib/lite/kernels/optional_tensor_test.cc
index cee3ec6197..bcad58406a 100644
--- a/tensorflow/contrib/lite/kernels/optional_tensor_test.cc
+++ b/tensorflow/contrib/lite/kernels/optional_tensor_test.cc
@@ -95,9 +95,6 @@ class LSTMOpModel : public SingleOpModel {
projection_bias_ = AddNullInput();
}
- scratch_buffer_ = AddOutput(TensorType_FLOAT32);
- // TODO(ghodrat): Modify these states when we have a permanent solution for
- // persistent buffer.
output_state_ = AddOutput(TensorType_FLOAT32);
cell_state_ = AddOutput(TensorType_FLOAT32);
output_ = AddOutput(TensorType_FLOAT32);
@@ -235,7 +232,6 @@ class LSTMOpModel : public SingleOpModel {
int output_;
int output_state_;
int cell_state_;
- int scratch_buffer_;
int n_batch_;
int n_input_;
diff --git a/tensorflow/contrib/lite/kernels/test_util.h b/tensorflow/contrib/lite/kernels/test_util.h
index a9064d54e7..6fb6fe27eb 100644
--- a/tensorflow/contrib/lite/kernels/test_util.h
+++ b/tensorflow/contrib/lite/kernels/test_util.h
@@ -21,6 +21,7 @@ limitations under the License.
#include <gtest/gtest.h>
#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/model.h"
#include "tensorflow/contrib/lite/string_util.h"
@@ -133,6 +134,22 @@ class SingleOpModel {
PopulateTensor(index, 0, q.data(), q.data() + q.size());
}
+ void SymmetricQuantizeAndPopulate(int index,
+ std::initializer_list<float> data) {
+ TfLiteTensor* t = interpreter_->tensor(index);
+ std::vector<float> values(data);
+ const int length = values.size();
+ std::vector<int8_t> q(length);
+ float min, max, scaling_factor;
+ tensor_utils::SymmetricQuantizeFloats(values.data(), length, q.data(), &min,
+ &max, &scaling_factor);
+ // Update quantization params.
+ t->params.scale = scaling_factor;
+ t->params.zero_point = 0;
+ PopulateTensor(index, /*offset=*/0, reinterpret_cast<uint8_t*>(q.data()),
+ reinterpret_cast<uint8_t*>(q.data() + q.size()));
+ }
+
const std::vector<int>& GetShape(int id) { return tensor_data_.at(id).shape; }
float GetScale(int id) { return tensor_data_.at(id).scale; }
diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc
index 42941a97db..3c1256d3a6 100644
--- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc
+++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc
@@ -66,10 +66,19 @@ constexpr int kProjectionWeightsTensor = 16; // Optional
constexpr int kProjectionBiasTensor = 17; // Optional
// Output tensors.
-constexpr int kScratchBufferTensor = 0;
-constexpr int kOutputStateTensor = 1;
-constexpr int kCellStateTensor = 2;
-constexpr int kOutputTensor = 3;
+constexpr int kOutputStateTensor = 0;
+constexpr int kCellStateTensor = 1;
+constexpr int kOutputTensor = 2;
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+ auto* scratch_tensor_index = new int;
+ context->AddTensors(context, 1, scratch_tensor_index);
+ return scratch_tensor_index;
+}
+
+void Free(TfLiteContext* context, void* buffer) {
+ delete reinterpret_cast<int*>(buffer);
+}
// Check that input tensor dimensions matches with each other.
TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
@@ -220,12 +229,15 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
return kTfLiteOk;
}
-// Resize the output, state and scratch tensors based on the sizes of the input
-// tensors. Also check that the size of the input tensors match each other.
+// Resize the output and state tensors based on the sizes of the input tensors.
+// Allocate a temprory scratch tensor. Also check that the sizes of the input
+// tensors match each other.
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data);
+
// Check we have all the inputs and outputs we need.
TF_LITE_ENSURE_EQ(context, node->inputs->size, 18);
- TF_LITE_ENSURE_EQ(context, node->outputs->size, 4);
+ TF_LITE_ENSURE_EQ(context, node->outputs->size, 3);
// Inferring batch size, number of outputs and sequence length and
// number of cells from the input tensors.
@@ -251,15 +263,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Check that input tensor dimensions matches with each other.
CheckInputTensorDimensions(context, node, n_input, n_output, n_cell);
- // Get the pointer to output, state and scratch buffer tensors.
+ // Get the pointer to output, output_state and cell_state buffer tensors.
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TfLiteTensor* output_state = GetOutput(context, node, kOutputStateTensor);
TfLiteTensor* cell_state = GetOutput(context, node, kCellStateTensor);
- // TODO(ghodrat): Modify this as soon as we have a finalized method for
- // scratch buffers.
- TfLiteTensor* scratch_buffer = GetOutput(context, node, kScratchBufferTensor);
- // Resize the output and output_state tensors.
+ // Resize the output, output_state and cell_state tensors.
TfLiteIntArray* output_size = TfLiteIntArrayCreate(3);
output_size->data[0] = max_time;
output_size->data[1] = n_batch;
@@ -273,13 +282,20 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(
context, context->ResizeTensor(context, output_state, output_state_size));
- // Resize the scratch buffer tensor.
TfLiteIntArray* cell_size = TfLiteIntArrayCreate(2);
cell_size->data[0] = n_batch;
cell_size->data[1] = n_cell;
TF_LITE_ENSURE_OK(context,
context->ResizeTensor(context, cell_state, cell_size));
+ // Create a scratch buffer tensor.
+ TfLiteIntArrayFree(node->temporaries);
+ node->temporaries = TfLiteIntArrayCreate(1);
+ node->temporaries->data[0] = *scratch_tensor_index;
+ TfLiteTensor* scratch_buffer = &context->tensors[node->temporaries->data[0]];
+ scratch_buffer->type = input->type;
+ scratch_buffer->allocation_type = kTfLiteArenaRw;
+
// Mark state tensors as persistent tensors.
output_state->allocation_type = kTfLiteArenaRwPersistent;
cell_state->allocation_type = kTfLiteArenaRwPersistent;
@@ -365,7 +381,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const bool use_peephole = (cell_to_output_weights != nullptr);
// Index the scratch buffers pointers to the global scratch buffer.
- TfLiteTensor* scratch_buffer = GetOutput(context, node, kScratchBufferTensor);
+ TfLiteTensor* scratch_buffer = &context->tensors[node->temporaries->data[0]];
float* input_gate_scratch = nullptr;
float* cell_scratch = nullptr;
float* forget_gate_scratch = nullptr;
@@ -439,7 +455,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace unidirectional_sequence_lstm
TfLiteRegistration* Register_UNIDIRECTIONAL_SEQUENCE_LSTM() {
- static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
+ static TfLiteRegistration r = {unidirectional_sequence_lstm::Init,
+ unidirectional_sequence_lstm::Free,
unidirectional_sequence_lstm::Prepare,
unidirectional_sequence_lstm::Eval};
return &r;
diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc
index 93b635ae57..5881ced7c7 100644
--- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc
+++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc
@@ -100,9 +100,6 @@ class UnidirectionalLSTMOpModel : public SingleOpModel {
projection_bias_ = AddNullInput();
}
- scratch_buffer_ = AddOutput(TensorType_FLOAT32);
- // TODO(ghodrat): Modify these states when we have a permanent solution for
- // persistent buffer.
output_state_ = AddOutput(TensorType_FLOAT32);
cell_state_ = AddOutput(TensorType_FLOAT32);
output_ = AddOutput(TensorType_FLOAT32);
@@ -238,7 +235,6 @@ class UnidirectionalLSTMOpModel : public SingleOpModel {
int output_;
int output_state_;
int cell_state_;
- int scratch_buffer_;
int n_batch_;
int n_input_;
diff --git a/tensorflow/contrib/lite/models/speech_test.cc b/tensorflow/contrib/lite/models/speech_test.cc
index a354179a94..206de1962d 100644
--- a/tensorflow/contrib/lite/models/speech_test.cc
+++ b/tensorflow/contrib/lite/models/speech_test.cc
@@ -131,8 +131,8 @@ TEST_P(SpeechTest, SpeakerIdOkGoogleTest) {
ASSERT_TRUE(ConvertCsvData(
"speech_speakerid_model.tflite", "speech_speakerid_model_in.csv",
"speech_speakerid_model_out.csv", /*input_tensor=*/"0",
- /*output_tensor=*/"66",
- /*persistent_tensors=*/"19,20,40,41,61,62",
+ /*output_tensor=*/"63",
+ /*persistent_tensors=*/"18,19,38,39,58,59",
/*sequence_size=*/80, &os));
testing::TfLiteDriver test_driver(/*use_nnapi=*/false);
ASSERT_TRUE(testing::ParseAndRunTests(&os, &test_driver, GetMaxInvocations()))
@@ -144,8 +144,8 @@ TEST_P(SpeechTest, AsrAmTest) {
ASSERT_TRUE(
ConvertCsvData("speech_asr_am_model.tflite", "speech_asr_am_model_in.csv",
"speech_asr_am_model_out.csv", /*input_tensor=*/"0",
- /*output_tensor=*/"109",
- /*persistent_tensors=*/"19,20,40,41,61,62,82,83,103,104",
+ /*output_tensor=*/"104",
+ /*persistent_tensors=*/"18,19,38,39,58,59,78,79,98,99",
/*sequence_size=*/320, &os));
testing::TfLiteDriver test_driver(/*use_nnapi=*/false);
ASSERT_TRUE(testing::ParseAndRunTests(&os, &test_driver, GetMaxInvocations()))
@@ -170,8 +170,8 @@ TEST_P(SpeechTest, EndpointerTest) {
ASSERT_TRUE(ConvertCsvData(
"speech_endpointer_model.tflite", "speech_endpointer_model_in.csv",
"speech_endpointer_model_out.csv", /*input_tensor=*/"0",
- /*output_tensor=*/"58",
- /*persistent_tensors=*/"28,29,49,50",
+ /*output_tensor=*/"56",
+ /*persistent_tensors=*/"27,28,47,48",
/*sequence_size=*/320, &os));
testing::TfLiteDriver test_driver(/*use_nnapi=*/false);
ASSERT_TRUE(testing::ParseAndRunTests(&os, &test_driver, GetMaxInvocations()))
@@ -183,8 +183,8 @@ TEST_P(SpeechTest, TtsTest) {
ASSERT_TRUE(ConvertCsvData("speech_tts_model.tflite",
"speech_tts_model_in.csv",
"speech_tts_model_out.csv", /*input_tensor=*/"0",
- /*output_tensor=*/"74",
- /*persistent_tensors=*/"25,26,46,47,67,68,73",
+ /*output_tensor=*/"71",
+ /*persistent_tensors=*/"24,25,44,45,64,65,70",
/*sequence_size=*/334, &os));
testing::TfLiteDriver test_driver(/*use_nnapi=*/false);
ASSERT_TRUE(testing::ParseAndRunTests(&os, &test_driver, GetMaxInvocations()))
diff --git a/tensorflow/contrib/lite/models/testdata/speech_asr_lm_model.test_spec b/tensorflow/contrib/lite/models/testdata/speech_asr_lm_model.test_spec
index 5812de4b30..f7f518b75f 100644
--- a/tensorflow/contrib/lite/models/testdata/speech_asr_lm_model.test_spec
+++ b/tensorflow/contrib/lite/models/testdata/speech_asr_lm_model.test_spec
@@ -1,5 +1,5 @@
load_model: "speech_asr_lm_model.tflite"
-init_state: "21,22,42,43,63,64"
+init_state: "20,21,40,41,60,61"
invoke {
id: 3
input: "63982"
@@ -18,7 +18,7 @@ invoke {
input: "63981"
output: "-0.314846"
}
-init_state: "21,22,42,43,63,64"
+init_state: "20,21,40,41,60,61"
invoke {
id: 6
input: "63982"
@@ -31,7 +31,7 @@ invoke {
input: "3082"
output: "-3.63721"
}
-init_state: "21,22,42,43,63,64"
+init_state: "20,21,40,41,60,61"
invoke {
id: 8
input: "63982"
@@ -44,7 +44,7 @@ invoke {
input: "18965"
output: "-6.93985"
}
-init_state: "21,22,42,43,63,64"
+init_state: "20,21,40,41,60,61"
invoke {
id: 13
input: "63982"
@@ -63,7 +63,7 @@ invoke {
input: "63981"
output: "-3.82091"
}
-init_state: "21,22,42,43,63,64"
+init_state: "20,21,40,41,60,61"
invoke {
id: 19
input: "63982"
@@ -88,7 +88,7 @@ invoke {
input: "63981"
output: "-0.677399"
}
-init_state: "21,22,42,43,63,64"
+init_state: "20,21,40,41,60,61"
invoke {
id: 26
input: "63982"
@@ -113,7 +113,7 @@ invoke {
input: "63981"
output: "0.415889"
}
-init_state: "21,22,42,43,63,64"
+init_state: "20,21,40,41,60,61"
invoke {
id: 30
input: "63982"
@@ -131,7 +131,7 @@ invoke {
input: "51923"
output: "-14.1147"
}
-init_state: "21,22,42,43,63,64"
+init_state: "20,21,40,41,60,61"
invoke {
id: 34
input: "63982"
@@ -144,7 +144,7 @@ invoke {
input: "16318"
output: "-1.54815"
}
-init_state: "21,22,42,43,63,64"
+init_state: "20,21,40,41,60,61"
invoke {
id: 36
input: "63982"
@@ -157,7 +157,7 @@ invoke {
input: "28303"
output: "-14.0947"
}
-init_state: "21,22,42,43,63,64"
+init_state: "20,21,40,41,60,61"
invoke {
id: 38
input: "63982"
diff --git a/tensorflow/contrib/lite/profiling/profiler.h b/tensorflow/contrib/lite/profiling/profiler.h
index dfa98a6708..8c3e4dc76d 100644
--- a/tensorflow/contrib/lite/profiling/profiler.h
+++ b/tensorflow/contrib/lite/profiling/profiler.h
@@ -85,7 +85,7 @@ class Profiler {
std::vector<const ProfileEvent*> GetProfileEvents() {
std::vector<const ProfileEvent*> profile_events;
profile_events.reserve(buffer_.Size());
- for (int i = 0; i < buffer_.Size(); i++) {
+ for (size_t i = 0; i < buffer_.Size(); i++) {
profile_events.push_back(buffer_.At(i));
}
return profile_events;
@@ -103,7 +103,9 @@ class ScopedProfile {
// Adds a profile event to profile that begins with the construction
// of object and ends when the object goes out of scope.
// The lifetime of tag should be at least the lifetime of profiler.
- ScopedProfile(Profiler* profiler, const char* tag) {
+
+ ScopedProfile(Profiler* profiler, const char* tag)
+ : buffer_(nullptr), event_handle_(0) {
if (profiler) {
buffer_ = profiler->GetProfileBuffer();
event_handle_ =
@@ -126,7 +128,8 @@ class ScopedOperatorProfile {
// Adds a profile event to profile that begins with the construction
// of object and ends when the object goes out of scope.
// The lifetime of tag should be at least the lifetime of profiler.
- ScopedOperatorProfile(Profiler* profiler, const char* tag, int node_index) {
+ ScopedOperatorProfile(Profiler* profiler, const char* tag, int node_index)
+ : buffer_(nullptr), event_handle_(0) {
if (profiler) {
buffer_ = profiler->GetProfileBuffer();
event_handle_ = buffer_->BeginEvent(
@@ -148,9 +151,11 @@ class ScopedOperatorProfile {
} // namespace profiling
} // namespace tflite
-#define SCOPED_OPERATOR_PROFILE(profiler, node_index) \
- tflite::profiling::ScopedOperatorProfile _profile((profiler), "OpInvoke", \
- (node_index))
+#define VARNAME_UNIQ(name, ctr) name##ctr
+
+#define SCOPED_OPERATOR_PROFILE(profiler, node_index) \
+ tflite::profiling::ScopedOperatorProfile VARNAME_UNIQ( \
+ _profile_, __COUNTER__)((profiler), "OpInvoke", (node_index))
#else
namespace tflite {
diff --git a/tensorflow/contrib/lite/profiling/profiler_test.cc b/tensorflow/contrib/lite/profiling/profiler_test.cc
index 7ea1d8f7d3..0fba0450a0 100644
--- a/tensorflow/contrib/lite/profiling/profiler_test.cc
+++ b/tensorflow/contrib/lite/profiling/profiler_test.cc
@@ -93,6 +93,20 @@ TEST(ProfilingTest, ProfilesAreCollected) {
#endif
}
+TEST(ProfilingTest, NullProfiler) {
+ Profiler* profiler = nullptr;
+ { SCOPED_OPERATOR_PROFILE(profiler, 1); }
+}
+
+TEST(ProfilingTest, ScopedProfile) {
+ Profiler profiler;
+ profiler.StartProfiling();
+ { SCOPED_OPERATOR_PROFILE(&profiler, 1); }
+ profiler.StopProfiling();
+ auto profile_events = profiler.GetProfileEvents();
+ EXPECT_EQ(1, profile_events.size());
+}
+
} // namespace
} // namespace profiling
} // namespace tflite
diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h
index 25ed9abd9f..57af973460 100755
--- a/tensorflow/contrib/lite/schema/schema_generated.h
+++ b/tensorflow/contrib/lite/schema/schema_generated.h
@@ -4711,6 +4711,7 @@ struct ModelT : public flatbuffers::NativeTable {
std::vector<std::unique_ptr<SubGraphT>> subgraphs;
std::string description;
std::vector<std::unique_ptr<BufferT>> buffers;
+ std::vector<int32_t> metadata_buffer;
ModelT()
: version(0) {
}
@@ -4723,7 +4724,8 @@ struct Model FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
VT_OPERATOR_CODES = 6,
VT_SUBGRAPHS = 8,
VT_DESCRIPTION = 10,
- VT_BUFFERS = 12
+ VT_BUFFERS = 12,
+ VT_METADATA_BUFFER = 14
};
uint32_t version() const {
return GetField<uint32_t>(VT_VERSION, 0);
@@ -4740,6 +4742,9 @@ struct Model FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
const flatbuffers::Vector<flatbuffers::Offset<Buffer>> *buffers() const {
return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<Buffer>> *>(VT_BUFFERS);
}
+ const flatbuffers::Vector<int32_t> *metadata_buffer() const {
+ return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_METADATA_BUFFER);
+ }
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyField<uint32_t>(verifier, VT_VERSION) &&
@@ -4754,6 +4759,8 @@ struct Model FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
VerifyOffset(verifier, VT_BUFFERS) &&
verifier.Verify(buffers()) &&
verifier.VerifyVectorOfTables(buffers()) &&
+ VerifyOffset(verifier, VT_METADATA_BUFFER) &&
+ verifier.Verify(metadata_buffer()) &&
verifier.EndTable();
}
ModelT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@@ -4779,6 +4786,9 @@ struct ModelBuilder {
void add_buffers(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<Buffer>>> buffers) {
fbb_.AddOffset(Model::VT_BUFFERS, buffers);
}
+ void add_metadata_buffer(flatbuffers::Offset<flatbuffers::Vector<int32_t>> metadata_buffer) {
+ fbb_.AddOffset(Model::VT_METADATA_BUFFER, metadata_buffer);
+ }
explicit ModelBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
@@ -4797,8 +4807,10 @@ inline flatbuffers::Offset<Model> CreateModel(
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<OperatorCode>>> operator_codes = 0,
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<SubGraph>>> subgraphs = 0,
flatbuffers::Offset<flatbuffers::String> description = 0,
- flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<Buffer>>> buffers = 0) {
+ flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<Buffer>>> buffers = 0,
+ flatbuffers::Offset<flatbuffers::Vector<int32_t>> metadata_buffer = 0) {
ModelBuilder builder_(_fbb);
+ builder_.add_metadata_buffer(metadata_buffer);
builder_.add_buffers(buffers);
builder_.add_description(description);
builder_.add_subgraphs(subgraphs);
@@ -4813,14 +4825,16 @@ inline flatbuffers::Offset<Model> CreateModelDirect(
const std::vector<flatbuffers::Offset<OperatorCode>> *operator_codes = nullptr,
const std::vector<flatbuffers::Offset<SubGraph>> *subgraphs = nullptr,
const char *description = nullptr,
- const std::vector<flatbuffers::Offset<Buffer>> *buffers = nullptr) {
+ const std::vector<flatbuffers::Offset<Buffer>> *buffers = nullptr,
+ const std::vector<int32_t> *metadata_buffer = nullptr) {
return tflite::CreateModel(
_fbb,
version,
operator_codes ? _fbb.CreateVector<flatbuffers::Offset<OperatorCode>>(*operator_codes) : 0,
subgraphs ? _fbb.CreateVector<flatbuffers::Offset<SubGraph>>(*subgraphs) : 0,
description ? _fbb.CreateString(description) : 0,
- buffers ? _fbb.CreateVector<flatbuffers::Offset<Buffer>>(*buffers) : 0);
+ buffers ? _fbb.CreateVector<flatbuffers::Offset<Buffer>>(*buffers) : 0,
+ metadata_buffer ? _fbb.CreateVector<int32_t>(*metadata_buffer) : 0);
}
flatbuffers::Offset<Model> CreateModel(flatbuffers::FlatBufferBuilder &_fbb, const ModelT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
@@ -6207,6 +6221,7 @@ inline void Model::UnPackTo(ModelT *_o, const flatbuffers::resolver_function_t *
{ auto _e = subgraphs(); if (_e) { _o->subgraphs.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->subgraphs[_i] = std::unique_ptr<SubGraphT>(_e->Get(_i)->UnPack(_resolver)); } } };
{ auto _e = description(); if (_e) _o->description = _e->str(); };
{ auto _e = buffers(); if (_e) { _o->buffers.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->buffers[_i] = std::unique_ptr<BufferT>(_e->Get(_i)->UnPack(_resolver)); } } };
+ { auto _e = metadata_buffer(); if (_e) { _o->metadata_buffer.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->metadata_buffer[_i] = _e->Get(_i); } } };
}
inline flatbuffers::Offset<Model> Model::Pack(flatbuffers::FlatBufferBuilder &_fbb, const ModelT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
@@ -6222,13 +6237,15 @@ inline flatbuffers::Offset<Model> CreateModel(flatbuffers::FlatBufferBuilder &_f
auto _subgraphs = _o->subgraphs.size() ? _fbb.CreateVector<flatbuffers::Offset<SubGraph>> (_o->subgraphs.size(), [](size_t i, _VectorArgs *__va) { return CreateSubGraph(*__va->__fbb, __va->__o->subgraphs[i].get(), __va->__rehasher); }, &_va ) : 0;
auto _description = _o->description.empty() ? 0 : _fbb.CreateString(_o->description);
auto _buffers = _o->buffers.size() ? _fbb.CreateVector<flatbuffers::Offset<Buffer>> (_o->buffers.size(), [](size_t i, _VectorArgs *__va) { return CreateBuffer(*__va->__fbb, __va->__o->buffers[i].get(), __va->__rehasher); }, &_va ) : 0;
+ auto _metadata_buffer = _o->metadata_buffer.size() ? _fbb.CreateVector(_o->metadata_buffer) : 0;
return tflite::CreateModel(
_fbb,
_version,
_operator_codes,
_subgraphs,
_description,
- _buffers);
+ _buffers,
+ _metadata_buffer);
}
inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type) {
diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD
index f92e546ab8..f16225fd66 100644
--- a/tensorflow/contrib/lite/toco/BUILD
+++ b/tensorflow/contrib/lite/toco/BUILD
@@ -364,6 +364,18 @@ cc_library(
}),
)
+tf_cc_test(
+ name = "import_tensorflow_test",
+ srcs = ["import_tensorflow_test.cc"],
+ deps = [
+ ":toco_tooling",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:graph",
+ "//tensorflow/core:protos_all_cc",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
cc_library(
name = "tooling_util",
srcs = [
diff --git a/tensorflow/contrib/lite/toco/dump_graphviz.cc b/tensorflow/contrib/lite/toco/dump_graphviz.cc
index 5bb0e3ba4d..166ead9184 100644
--- a/tensorflow/contrib/lite/toco/dump_graphviz.cc
+++ b/tensorflow/contrib/lite/toco/dump_graphviz.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/lite/toco/dump_graphviz.h"
+#include <cmath>
#include <memory>
#include <set>
#include <unordered_set>
@@ -63,6 +64,7 @@ struct NodeProperties {
// color will be chosen for the 'fontcolor' for the inside text
// label, see Color::TextColorString.
Color color;
+ float log2_buffer_size;
};
// All colors in this file are from:
@@ -162,9 +164,12 @@ NodeProperties GetPropertiesForArray(const Model& model,
}
node_properties.label += "]";
+ int buffer_size = RequiredBufferSizeForShape(array.shape());
+ node_properties.log2_buffer_size =
+ std::log2(static_cast<float>(buffer_size));
+
if (array.buffer) {
const auto& array = model.GetArray(array_name);
- int buffer_size = RequiredBufferSizeForShape(array.shape());
if (buffer_size <= 4) {
AppendF(&node_properties.label, " = ");
if (array.shape().dimensions_count() > 0) {
@@ -194,6 +199,8 @@ NodeProperties GetPropertiesForArray(const Model& model,
AppendF(&node_properties.label, "}");
}
}
+ } else {
+ node_properties.log2_buffer_size = 0.0f;
}
if (array.minmax) {
@@ -325,12 +332,18 @@ std::vector<const Operator*> OperatorsToDump(const Model& model) {
void DumpGraphviz(const Model& model, string* output_file_contents) {
AppendF(output_file_contents, "digraph Computegraph {\n");
+ // 'nslimit' is a graphviz (dot) paramater that limits the iterations during
+ // the layout phase. Omitting it allows infinite iterations, causing some
+ // complex graphs to never finish. A value of 125 produces good graphs
+ // while allowing complex graphs to finish.
+ AppendF(output_file_contents, "\t nslimit=125;\n");
constexpr char kNodeFormat[] =
"\t \"%s\" [label=\"%s\", shape=%s, style=filled, fillcolor=\"#%s\", "
"fontcolor = \"#%sDD\"];\n";
- constexpr char kEdgeFormat[] = "\t \"%s\" -> \"%s\";\n";
+ constexpr char kEdgeFormat[] =
+ "\t \"%s\" -> \"%s\" [penwidth=%f, weight=%f];\n";
constexpr char kRNNBackEdgeFormat[] =
"\t \"%s\" -> \"%s\" [color=\"#0F9D58\"];\n";
@@ -358,7 +371,22 @@ void DumpGraphviz(const Model& model, string* output_file_contents) {
array_properties.color.FillColorString().c_str(),
array_properties.color.TextColorString().c_str());
}
- AppendF(output_file_contents, kEdgeFormat, input, operator_id);
+
+ // Draw lines that transport more data thicker (Otherwise, where would the
+ // data fit? right?).
+ float line_width =
+ std::max(0.5f, array_properties.log2_buffer_size / 3.0f);
+ // Keep edges that transport more data shorter than those with less.
+ float weight = std::max(1.0f, array_properties.log2_buffer_size);
+ if (!IsInputArray(model, input) &&
+ GetOpWithOutput(model, input) == nullptr) {
+ // Give the main line of data flow a straighter path by penalizing edges
+ // to standalone buffers. Weights are generally very large buffers that
+ // otherwise skew the layout without this.
+ weight = 1.0f;
+ }
+ AppendF(output_file_contents, kEdgeFormat, input, operator_id, line_width,
+ weight);
already_added_arrays.insert(input);
}
// Add nodes and edges for all outputs of the operator.
@@ -374,7 +402,16 @@ void DumpGraphviz(const Model& model, string* output_file_contents) {
array_properties.color.FillColorString().c_str(),
array_properties.color.TextColorString().c_str());
}
- AppendF(output_file_contents, kEdgeFormat, operator_id, output);
+
+ // See comments above regarding weight and line_width calculations.
+ float line_width =
+ std::max(0.5f, array_properties.log2_buffer_size / 3.0f);
+ float weight = std::max(1.0f, array_properties.log2_buffer_size);
+ if (!IsArrayConsumed(model, output)) {
+ weight = 1.0f;
+ }
+ AppendF(output_file_contents, kEdgeFormat, operator_id, output,
+ line_width, weight);
already_added_arrays.insert(output);
}
}
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc
index 45335fd78c..3f768bfee1 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc
@@ -146,16 +146,19 @@ bool MergeLstmCellInputs::Run(Model* model, std::size_t op_index) {
lstm_cell_op->inputs[LstmCellOperator::PREV_ACTIV_INPUT] = prev_activ_input;
lstm_cell_op->inputs[LstmCellOperator::PREV_STATE_INPUT] = prev_state_input;
- // Reorder LstmCell's 4 outputs.
+ // Reorder LstmCell's 3 outputs.
lstm_cell_op->outputs.resize(LstmCellOperator::NUM_OUTPUTS);
lstm_cell_op->outputs[LstmCellOperator::ACTIV_OUTPUT] =
src_op->outputs[kOutputTensor];
lstm_cell_op->outputs[LstmCellOperator::STATE_OUTPUT] =
src_op->outputs[kCellStateTensor];
- lstm_cell_op->outputs[LstmCellOperator::CONCAT_TEMP] =
- src_op->outputs[kScratchBufferTensor];
lstm_cell_op->outputs[LstmCellOperator::ACTIV_TEMP] =
src_op->outputs[kOutputStateTensor];
+ // Create a new temp array for the fourth output.
+ const string& concat_temp_array_name =
+ AvailableArrayName(*model, base_name + "concat_temp");
+ model->GetOrCreateArray(concat_temp_array_name);
+ lstm_cell_op->outputs[LstmCellOperator::CONCAT_TEMP] = concat_temp_array_name;
// Add the op into model.
model->operators.emplace(op_it, std::move(lstm_cell_op));
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc
index eca717680a..8e66323bd7 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc
@@ -138,10 +138,9 @@ bool SplitLstmCellInputs::Run(Model* model, std::size_t op_index) {
CreateOptionalArray(model, &(lstm_cell_op->inputs[kProjectionBiasTensor]),
base_name + "proj_bias");
- // Reorder LstmCell's outputs.
- lstm_cell_op->outputs.resize(LstmCellOperator::NUM_OUTPUTS);
- lstm_cell_op->outputs[kScratchBufferTensor] =
- curr_op->outputs[LstmCellOperator::CONCAT_TEMP];
+ // Reorder and resize LstmCell's outputs.
+ lstm_cell_op->outputs.resize(
+ ExtendedLstmCellOutputs::kExtendedLstmOutputCount);
lstm_cell_op->outputs[kOutputStateTensor] =
curr_op->outputs[LstmCellOperator::ACTIV_TEMP];
lstm_cell_op->outputs[kCellStateTensor] =
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc
index de6d8889fb..bddb563206 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc
@@ -79,8 +79,9 @@ bool IdentifyRelu1::Run(Model* model, std::size_t op_index) {
const auto* max_op =
op_0->type == OperatorType::kTensorFlowMaximum ? op_0 : op_1;
- CHECK_EQ(min_op->inputs.size(), 2);
- CHECK_EQ(max_op->inputs.size(), 2);
+ if (min_op->inputs.size() != 2 || max_op->inputs.size() != 2) {
+ return false;
+ }
if (min_op->outputs.size() != 1 || max_op->outputs.size() != 1) {
return false;
}
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h b/tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h
index 4a9974ed4e..1c32a78169 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h
+++ b/tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h
@@ -51,10 +51,10 @@ enum ExtendedLstmCellInputs {
};
enum ExtendedLstmCellOutputs {
- kScratchBufferTensor = 0,
- kOutputStateTensor = 1,
- kCellStateTensor = 2,
- kOutputTensor = 3
+ kOutputStateTensor = 0,
+ kCellStateTensor = 1,
+ kOutputTensor = 2,
+ kExtendedLstmOutputCount = 3
};
// Create optional array used for optional tensor in ExtendedLstmCell inputs.
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc
index fa46e6bc38..347302c7a5 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc
@@ -96,6 +96,11 @@ const MinMax& GetOrComputeMinMax(Model* model, const string& array_name) {
min = std::min(min, val);
max = std::max(max, val);
}
+ if (min == 0.f && max == 0.f) {
+ // Prevent downstream anger from quantized math that expects min and max
+ // to not be equal.
+ max = 1.f;
+ }
auto& minmax = array.GetOrCreateMinMax();
minmax.min = min;
minmax.max = max;
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_binary.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_binary.cc
index 95a50c6179..0dfdc40e4c 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_binary.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_binary.cc
@@ -78,6 +78,25 @@ bool RemoveTrivialBinaryOperator::Run(Model* model, std::size_t op_index) {
CHECK(is_input_constant[index_of_constant_input]);
CHECK(!is_input_constant[index_of_variable_input]);
+ // If this was a broadcasting op we can't remove it as we need the broadcast.
+ // It's possible we could replace it with a cheaper op, though.
+ const auto& input_array_0 = model->GetArray(binary_op->inputs[0]);
+ const auto& input_array_1 = model->GetArray(binary_op->inputs[1]);
+ if (!input_array_0.has_shape() || !input_array_1.has_shape()) {
+ // Both input shapes must be known.
+ return false;
+ }
+ if (input_array_0.shape().dimensions_count() ==
+ input_array_1.shape().dimensions_count() &&
+ input_array_0.shape() != input_array_1.shape()) {
+ AddMessageF(
+ "Preserving %s even though it's trivial as we need to broadcast "
+ "(lhs %s, rhs %s)",
+ LogName(*binary_op), ShapeToString(input_array_0.shape()),
+ ShapeToString(input_array_1.shape()));
+ return false;
+ }
+
// Now check if the constant operand makes this binary
// operator trivial.
const auto& constant_input_array =
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc
index 8e6aaf544a..1956ab2d20 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc
@@ -88,13 +88,11 @@ bool RemoveUnusedOp::Run(Model* model, std::size_t op_index) {
// At that point we know that none of the outputs is used, so we will
// definitely remove the node and all its outputs.
- // Remove any input array that is not used by anything else,
- // and that is not the output of some other operator.
+ // Remove any input array that not the output of another op, and only used by
+ // this op.
for (const auto& input : op->inputs) {
- if (IsDiscardableArray(*model, input) &&
- CountOpsWithInput(*model, input) == 1 &&
- !GetOpWithOutput(*model, input)) {
- model->EraseArray(input);
+ if (!GetOpWithOutput(*model, input)) {
+ DeleteArrayIfUsedOnce(input, model);
}
}
@@ -102,22 +100,9 @@ bool RemoveUnusedOp::Run(Model* model, std::size_t op_index) {
for (const auto& output : op->outputs) {
// If the output array is the model's input array, don't remove that.
// That's the case when cropping a model at a given --input_array.
- if (!IsDiscardableArray(*model, output)) {
- continue;
- }
- // Likewise, if the output array is a RNN state array, don't remove that.
- bool found_output_as_rnn_state_array = false;
- for (const auto& rnn_state : model->flags.rnn_states()) {
- if (output == rnn_state.state_array()) {
- found_output_as_rnn_state_array = true;
- break;
- }
- }
- if (found_output_as_rnn_state_array) {
- continue;
+ if (IsDiscardableArray(*model, output)) {
+ model->EraseArray(output);
}
- // Generic case: do delete this output array.
- model->EraseArray(output);
}
model->operators.erase(it);
return true;
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_stack.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_stack.cc
index ea0d6dc820..69db1942cd 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_stack.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_stack.cc
@@ -77,6 +77,13 @@ bool ResolveConstantStack::Run(Model* model, std::size_t op_index) {
}
}
+ int axis = op->axis;
+ if (axis < 0) {
+ // Handle negative axis
+ axis += model->GetArray(op->inputs[0]).shape().dims().size();
+ }
+ CHECK_EQ(axis, 0) << "Stacking only supported along 0th axis";
+
CHECK(!output_array.buffer);
switch (output_array.data_type) {
case ArrayDataType::kFloat:
@@ -99,10 +106,7 @@ bool ResolveConstantStack::Run(Model* model, std::size_t op_index) {
// Erase input arrays if no longer used
for (const auto& input : op->inputs) {
- if (IsDiscardableArray(*model, input) &&
- CountOpsWithInput(*model, input) == 1) {
- model->EraseArray(input);
- }
+ toco::DeleteArrayIfUsedOnce(input, model);
}
// Erase the operator
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc
index 2b413c0290..453ff29b0d 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc
@@ -62,6 +62,9 @@ using tensorflow::TensorProto;
using tensorflow::TensorShapeProto;
namespace toco {
+
+using port::Status;
+
namespace {
bool HasAttr(const NodeDef& node, const string& attr_name) {
return node.attr().count(attr_name) > 0;
@@ -113,7 +116,7 @@ const TensorShapeProto& GetShapeAttr(const NodeDef& node,
}
const TensorProto& GetTensorAttr(const NodeDef& node, const string& attr_name) {
- CHECK(HasAttr(node, attr_name));
+ CHECK(HasAttr(node, attr_name)) << "No attr named '" << attr_name << "'";
const auto& attr = node.attr().at(attr_name);
CHECK_EQ(attr.value_case(), AttrValue::kTensor);
return attr.tensor();
@@ -145,9 +148,9 @@ ArrayDataType ConvertDataType(tensorflow::DataType dtype) {
return ArrayDataType::kNone;
}
-void ImportShape(const TFLITE_PROTO_NS::RepeatedPtrField<
- tensorflow::TensorShapeProto_Dim>& input_dims,
- Shape* shape) {
+Status ImportShape(const TFLITE_PROTO_NS::RepeatedPtrField<
+ tensorflow::TensorShapeProto_Dim>& input_dims,
+ int* input_flat_size, Shape* shape) {
std::vector<int> input_dims_only_sizes;
for (auto& d : input_dims) {
if (d.size() == 0) {
@@ -155,23 +158,33 @@ void ImportShape(const TFLITE_PROTO_NS::RepeatedPtrField<
// them of flat size 0 even though they have other nonzero dims.
// This breaks our invariant, that array dims can't be 0.
// For now, tweaking this to record a 0-D shape instead.
- input_dims_only_sizes.clear();
- break;
+ shape->mutable_dims()->clear();
+ if (input_flat_size != nullptr) *input_flat_size = 0;
+ return Status::OK();
+ }
+ // TensorFlow's shapes use int64s, while TOCO uses ints.
+ if (d.size() > std::numeric_limits<int>::max()) {
+ return Status(false, "Shape element overflows");
}
+
input_dims_only_sizes.push_back(d.size());
}
*shape->mutable_dims() = input_dims_only_sizes;
+
+ if (input_flat_size == nullptr) return Status::OK();
+
+ return NumElements(input_dims_only_sizes, input_flat_size);
}
-void ImportFloatArray(const TensorProto& input_tensor, Array* output_array) {
+Status ImportFloatArray(const TensorProto& input_tensor, Array* output_array) {
CHECK_EQ(input_tensor.dtype(), DT_FLOAT);
const auto& input_shape = input_tensor.tensor_shape();
CHECK_LE(input_shape.dim_size(), 4);
- ImportShape(input_shape.dim(), output_array->mutable_shape());
- int input_flat_size = 1;
- for (int k = 0; k < input_shape.dim_size(); k++) {
- input_flat_size *= input_shape.dim(k).size();
- }
+ int input_flat_size;
+ auto status = ImportShape(input_shape.dim(), &input_flat_size,
+ output_array->mutable_shape());
+ if (!status.ok()) return status;
+
auto& output_float_data =
output_array->GetMutableBuffer<ArrayDataType::kFloat>().data;
output_float_data.resize(RequiredBufferSizeForShape(output_array->shape()),
@@ -189,20 +202,22 @@ void ImportFloatArray(const TensorProto& input_tensor, Array* output_array) {
toco::port::CopyToBuffer(input_tensor.tensor_content(),
reinterpret_cast<char*>(output_float_data.data()));
} else {
- LOG(FATAL) << "Neither input_content nor float_val have the right "
- "dimensions for this float tensor.";
+ return Status(false,
+ "Neither input_content nor float_val have the right "
+ "dimensions for this float tensor");
}
+ return Status::OK();
}
-void ImportQuint8Array(const TensorProto& input_tensor, Array* output_array) {
+Status ImportQuint8Array(const TensorProto& input_tensor, Array* output_array) {
CHECK_EQ(input_tensor.dtype(), DT_QUINT8);
const auto& input_shape = input_tensor.tensor_shape();
CHECK_LE(input_shape.dim_size(), 4);
- ImportShape(input_shape.dim(), output_array->mutable_shape());
- int input_flat_size = 1;
- for (int k = 0; k < input_shape.dim_size(); k++) {
- input_flat_size *= input_shape.dim(k).size();
- }
+ int input_flat_size;
+ auto status = ImportShape(input_shape.dim(), &input_flat_size,
+ output_array->mutable_shape());
+ if (!status.ok()) return status;
+
auto& output_int_data =
output_array->GetMutableBuffer<ArrayDataType::kUint8>().data;
output_int_data.resize(RequiredBufferSizeForShape(output_array->shape()), 0);
@@ -215,20 +230,22 @@ void ImportQuint8Array(const TensorProto& input_tensor, Array* output_array) {
toco::port::CopyToBuffer(input_tensor.tensor_content(),
reinterpret_cast<char*>(output_int_data.data()));
} else {
- LOG(FATAL) << "Neither input_content nor int_val have the right "
- "dimensions for this uint8 tensor.";
+ return Status(false,
+ "Neither input_content nor int_val have the right dimensions "
+ "for this uint8 tensor");
}
+ return Status::OK();
}
-void ImportInt32Array(const TensorProto& input_tensor, Array* output_array) {
+Status ImportInt32Array(const TensorProto& input_tensor, Array* output_array) {
CHECK_EQ(input_tensor.dtype(), DT_INT32);
const auto& input_shape = input_tensor.tensor_shape();
CHECK_LE(input_shape.dim_size(), 4);
- ImportShape(input_shape.dim(), output_array->mutable_shape());
- int input_flat_size = 1;
- for (int k = 0; k < input_shape.dim_size(); k++) {
- input_flat_size *= input_shape.dim(k).size();
- }
+ int input_flat_size;
+ auto status = ImportShape(input_shape.dim(), &input_flat_size,
+ output_array->mutable_shape());
+ if (!status.ok()) return status;
+
auto& output_int_data =
output_array->GetMutableBuffer<ArrayDataType::kInt32>().data;
output_int_data.resize(RequiredBufferSizeForShape(output_array->shape()), 0);
@@ -241,20 +258,22 @@ void ImportInt32Array(const TensorProto& input_tensor, Array* output_array) {
toco::port::CopyToBuffer(input_tensor.tensor_content(),
reinterpret_cast<char*>(output_int_data.data()));
} else {
- LOG(FATAL) << "Neither input_content nor int_val have the right "
- "dimensions for this int32 tensor.";
+ return Status(false,
+ "Neither input_content nor int_val have the right dimensions "
+ "for this int32 tensor");
}
+ return Status::OK();
}
-void ImportInt64Array(const TensorProto& input_tensor, Array* output_array) {
+Status ImportInt64Array(const TensorProto& input_tensor, Array* output_array) {
CHECK_EQ(input_tensor.dtype(), DT_INT64);
const auto& input_shape = input_tensor.tensor_shape();
CHECK_LE(input_shape.dim_size(), 4);
- ImportShape(input_shape.dim(), output_array->mutable_shape());
- int input_flat_size = 1;
- for (int k = 0; k < input_shape.dim_size(); k++) {
- input_flat_size *= input_shape.dim(k).size();
- }
+ int input_flat_size;
+ auto status = ImportShape(input_shape.dim(), &input_flat_size,
+ output_array->mutable_shape());
+ if (!status.ok()) return status;
+
auto& output_int_data =
output_array->GetMutableBuffer<ArrayDataType::kInt64>().data;
output_int_data.resize(RequiredBufferSizeForShape(output_array->shape()), 0);
@@ -267,20 +286,22 @@ void ImportInt64Array(const TensorProto& input_tensor, Array* output_array) {
toco::port::CopyToBuffer(input_tensor.tensor_content(),
reinterpret_cast<char*>(output_int_data.data()));
} else {
- LOG(FATAL) << "Neither input_content nor int64_val have the right "
- "dimensions for this int64 tensor.";
+ return Status(false,
+ "Neither input_content nor int64_val have the right "
+ "dimensions for this int64 tensor");
}
+ return Status::OK();
}
-void ImportBoolArray(const TensorProto& input_tensor, Array* output_array) {
+Status ImportBoolArray(const TensorProto& input_tensor, Array* output_array) {
CHECK_EQ(input_tensor.dtype(), DT_BOOL);
const auto& input_shape = input_tensor.tensor_shape();
CHECK_LE(input_shape.dim_size(), 4);
- ImportShape(input_shape.dim(), output_array->mutable_shape());
- int input_flat_size = 1;
- for (int k = 0; k < input_shape.dim_size(); k++) {
- input_flat_size *= input_shape.dim(k).size();
- }
+ int input_flat_size;
+ auto status = ImportShape(input_shape.dim(), &input_flat_size,
+ output_array->mutable_shape());
+ if (!status.ok()) return status;
+
auto& output_bool_data =
output_array->GetMutableBuffer<ArrayDataType::kBool>().data;
output_bool_data.resize(RequiredBufferSizeForShape(output_array->shape()),
@@ -300,20 +321,25 @@ void ImportBoolArray(const TensorProto& input_tensor, Array* output_array) {
// assuming that 'false' is implied.
// So far only encountered that in an array with 1 entry, let's
// require that until we encounter a graph where that's not the case.
- CHECK_EQ(output_bool_data.size(), 1);
+ if (output_bool_data.size() != 1) {
+ return Status(false,
+ "Neither input_content nor bool_val have the right "
+ "dimensions for this bool tensor");
+ }
output_bool_data[0] = false;
}
+ return Status::OK();
}
-void ImportStringArray(const TensorProto& input_tensor, Array* output_array) {
+Status ImportStringArray(const TensorProto& input_tensor, Array* output_array) {
CHECK_EQ(input_tensor.dtype(), DT_STRING);
const auto& input_shape = input_tensor.tensor_shape();
CHECK_LE(input_shape.dim_size(), 4);
- ImportShape(input_shape.dim(), output_array->mutable_shape());
- int input_flat_size = 1;
- for (int k = 0; k < input_shape.dim_size(); k++) {
- input_flat_size *= input_shape.dim(k).size();
- }
+ int input_flat_size;
+ auto status = ImportShape(input_shape.dim(), &input_flat_size,
+ output_array->mutable_shape());
+ if (!status.ok()) return status;
+
auto& output_string_data =
output_array->GetMutableBuffer<ArrayDataType::kString>().data;
output_string_data.resize(RequiredBufferSizeForShape(output_array->shape()));
@@ -324,6 +350,7 @@ void ImportStringArray(const TensorProto& input_tensor, Array* output_array) {
for (int i = 0; i < input_flat_size; ++i) {
output_string_data[i] = input_tensor.string_val(i);
}
+ return Status::OK();
}
// Count the number of inputs of a given node. If
@@ -363,38 +390,40 @@ string CreateConstArray(Model* model, string const& name,
return array_name;
}
-void ConvertConstOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
+Status ConvertConstOperator(const NodeDef& node,
+ const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
CHECK_EQ(node.op(), "Const");
const auto& tensor = GetTensorAttr(node, "value");
const auto dtype = GetDataTypeAttr(node, "dtype");
+ Status status = Status::OK();
+
auto& array = model->GetOrCreateArray(node.name());
switch (dtype) {
case DT_FLOAT:
array.data_type = ArrayDataType::kFloat;
- ImportFloatArray(tensor, &array);
+ status = ImportFloatArray(tensor, &array);
break;
case DT_INT32:
array.data_type = ArrayDataType::kInt32;
- ImportInt32Array(tensor, &array);
+ status = ImportInt32Array(tensor, &array);
break;
case DT_QUINT8:
array.data_type = ArrayDataType::kUint8;
- ImportQuint8Array(tensor, &array);
+ status = ImportQuint8Array(tensor, &array);
break;
case DT_INT64:
array.data_type = ArrayDataType::kInt64;
- ImportInt64Array(tensor, &array);
+ status = ImportInt64Array(tensor, &array);
break;
case DT_STRING:
array.data_type = ArrayDataType::kString;
- ImportStringArray(tensor, &array);
+ status = ImportStringArray(tensor, &array);
break;
case DT_BOOL:
array.data_type = ArrayDataType::kBool;
- ImportBoolArray(tensor, &array);
+ status = ImportBoolArray(tensor, &array);
break;
default:
array.data_type = ArrayDataType::kNone;
@@ -404,6 +433,10 @@ void ConvertConstOperator(const NodeDef& node,
array.GetMutableBuffer<ArrayDataType::kNone>();
break;
}
+ if (!status.ok()) {
+ status.AppendMessage(" (while processing node '" + node.name() + "')");
+ }
+ return status;
}
void ConvertConvOperator(const NodeDef& node,
@@ -451,8 +484,18 @@ void ConvertConvOperator(const NodeDef& node,
if (HasAttr(node, "dilations")) {
const auto& dilations = GetListAttr(node, "dilations");
CHECK_EQ(dilations.i_size(), 4);
- CHECK_EQ(dilations.i(0), 1);
- CHECK_EQ(dilations.i(3), 1);
+ CHECK_EQ(dilations.i(0), 1)
+ << "Can only import Conv ops with dilation along the height (1st) or "
+ "width (2nd) axis. TensorFlow op \""
+ << node.name() << "\" had dilations:[ " << dilations.i(0) << ", "
+ << dilations.i(1) << ", " << dilations.i(2) << ", " << dilations.i(3)
+ << "].";
+ CHECK_EQ(dilations.i(3), 1)
+ << "Can only import Conv ops with dilation along the height (1st) or "
+ "width (2nd) axis. TensorFlow op \""
+ << node.name() << "\" had dilations:[ " << dilations.i(0) << ", "
+ << dilations.i(1) << ", " << dilations.i(2) << ", " << dilations.i(3)
+ << "].";
conv->dilation_height_factor = dilations.i(1);
conv->dilation_width_factor = dilations.i(2);
} else {
@@ -2023,6 +2066,186 @@ void ConvertDynamicStitchOperator(const NodeDef& node,
} // namespace
+namespace internal {
+Status ImportTensorFlowNode(const tensorflow::NodeDef& node,
+ const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
+ // TODO(ahentz): Historically these functions all CHECK-fail on error. We've
+ // been slowly converting them to return Status.
+ if (node.op() == "Const") {
+ return ConvertConstOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Conv2D") {
+ ConvertConvOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Conv2DBackpropInput") {
+ ConvertTransposeConvOperator(node, tf_import_flags, model);
+ } else if (node.op() == "DepthwiseConv2dNative") {
+ ConvertDepthwiseConvOperator(node, tf_import_flags, model);
+ } else if (node.op() == "DepthToSpace") {
+ ConvertDepthToSpaceOperator(node, tf_import_flags, model);
+ } else if (node.op() == "SpaceToDepth") {
+ ConvertSpaceToDepthOperator(node, tf_import_flags, model);
+ } else if (node.op() == "BiasAdd") {
+ ConvertBiasAddOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Relu") {
+ ConvertReluOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Relu6") {
+ ConvertRelu6Operator(node, tf_import_flags, model);
+ } else if (node.op() == "Sigmoid") {
+ ConvertLogisticOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Tanh") {
+ ConvertTanhOperator(node, tf_import_flags, model);
+ } else if (node.op() == "MaxPool") {
+ ConvertMaxPoolOperator(node, tf_import_flags, model);
+ } else if (node.op() == "AvgPool") {
+ ConvertAvgPoolOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Reshape") {
+ ConvertReshapeOperator(node, tf_import_flags, model);
+ } else if (node.op() == "BatchMatMul") {
+ ConvertBatchMatMulOperator(node, tf_import_flags, model);
+ } else if (node.op() == "MatMul") {
+ ConvertMatMulOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Div" || node.op() == "RealDiv") {
+ ConvertDivOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Identity" || node.op() == "CheckNumerics" ||
+ node.op() == "StopGradient") {
+ ConvertIdentityOperator(node, tf_import_flags, model);
+ } else if (node.op() == "FakeQuantWithMinMaxVars") {
+ ConvertFakeQuantWithMinMaxVars(node, tf_import_flags, model);
+ } else if (node.op() == "FakeQuantWithMinMaxArgs") {
+ ConvertFakeQuantWithMinMaxArgs(node, tf_import_flags, model);
+ } else if (node.op() == "Neg") {
+ ConvertNegOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Rsqrt") {
+ ConvertRsqrtOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Squeeze") {
+ ConvertSqueezeOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Sqrt") {
+ ConvertSqrtOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Square") {
+ ConvertSquareOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Add") {
+ ConvertAddOperator(node, tf_import_flags, model);
+ } else if (node.op() == "AddN") {
+ ConvertAddNOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Mul") {
+ ConvertMulOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Sub") {
+ ConvertSubOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Sum") {
+ ConvertSumOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Tile") {
+ ConvertTileOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Concat" || node.op() == "ConcatV2") {
+ ConvertConcatOperator(node, tf_import_flags, model);
+ } else if (node.op() == "LRN") {
+ ConvertLRNOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Softmax") {
+ ConvertSoftmaxOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Log") {
+ ConvertLogOperator(node, tf_import_flags, model);
+ } else if (node.op() == "LogSoftmax") {
+ ConvertLogSoftmaxOperator(node, tf_import_flags, model);
+ } else if (node.op() == "All") {
+ ConvertAllOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Assert") {
+ ConvertAssertOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Less") {
+ ConvertLessOperator(node, tf_import_flags, model);
+ } else if (node.op() == "LessEqual") {
+ ConvertLessEqualOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Greater") {
+ ConvertGreaterOperator(node, tf_import_flags, model);
+ } else if (node.op() == "GreaterEqual") {
+ ConvertGreaterEqualOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Max") {
+ ConvertMaxOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Min") {
+ ConvertMinOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Maximum") {
+ ConvertMaximumOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Minimum") {
+ ConvertMinimumOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Merge") {
+ ConvertMergeOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Pad") {
+ ConvertPadOperator(node, tf_import_flags, model);
+ } else if (node.op() == "StridedSlice") {
+ ConvertStridedSliceOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Shape") {
+ ConvertShapeOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Slice") {
+ ConvertSliceOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Split") {
+ ConvertSplitOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Switch") {
+ ConvertSwitchOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Placeholder") {
+ ConvertPlaceholderOperator(node, tf_import_flags, model);
+ } else if (node.op() == "PlaceholderWithDefault") {
+ ConvertIdentityOperator(node, tf_import_flags, model);
+ } else if (node.op() == "LegacyFedInput") {
+ ConvertPlaceholderOperator(node, tf_import_flags, model);
+ } else if (node.op() == "NoOp") {
+ ConvertNoOpOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Cast") {
+ ConvertCastOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Floor") {
+ ConvertFloorOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Gather" || node.op() == "GatherV2") {
+ ConvertGatherOperator(node, tf_import_flags, model);
+ } else if (node.op() == "ResizeBilinear") {
+ ConvertResizeBilinearOperator(node, tf_import_flags, model);
+ } else if (node.op() == "BatchNormWithGlobalNormalization") {
+ ConvertBatchNormWithGlobalNormalizationOperator(node, tf_import_flags,
+ model);
+ } else if (node.op() == "FusedBatchNorm") {
+ ConvertFusedBatchNormOperator(node, tf_import_flags, model);
+ } else if (node.op() == "SpaceToBatchND") {
+ ConvertSpaceToBatchNDOperator(node, tf_import_flags, model);
+ } else if (node.op() == "BatchToSpaceND") {
+ ConvertBatchToSpaceNDOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Mean") {
+ ConvertMeanOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Svdf") {
+ ConvertSvdfOperator(node, tf_import_flags, model);
+ } else if (node.op() == "NextIteration") {
+ ConvertOperatorSpecialCasedAsRNNBackEdge(node, tf_import_flags, model);
+ } else if (node.op() == "ExpandDims") {
+ ConvertExpandDimsOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Fill") {
+ ConvertFillOperator(node, tf_import_flags, model);
+ } else if (node.op() == "FloorDiv") {
+ ConvertFloorDivOperator(node, tf_import_flags, model);
+ } else if (node.op() == "FloorMod") {
+ ConvertFloorModOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Range") {
+ ConvertRangeOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Rank") {
+ ConvertRankOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Stack" || node.op() == "Pack") {
+ ConvertStackOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Transpose") {
+ ConvertTransposeOperator(node, tf_import_flags, model);
+ } else if (node.op() == "ArgMax") {
+ ConvertArgMaxOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Exp") {
+ ConvertExpOperator(node, tf_import_flags, model);
+ } else if (node.op() == "TopK" || node.op() == "TopKV2") {
+ ConvertTopKV2Operator(node, tf_import_flags, model);
+ } else if (node.op() == "DynamicPartition") {
+ ConvertDynamicPartitionOperator(node, tf_import_flags, model);
+ } else if (node.op() == "DynamicStitch" ||
+ node.op() == "ParallelDynamicStitch") {
+ ConvertDynamicStitchOperator(node, tf_import_flags, model);
+ } else if (node.op() == "RandomUniform") {
+ ConvertRandomUniform(node, tf_import_flags, model);
+ } else {
+ ConvertUnsupportedOperator(node, tf_import_flags, model);
+ }
+ return Status::OK();
+}
+} // namespace internal
+
std::unique_ptr<Model> ImportTensorFlowGraphDef(
const ModelFlags& model_flags, const TensorFlowImportFlags& tf_import_flags,
const GraphDef& tf_graph) {
@@ -2048,176 +2271,8 @@ std::unique_ptr<Model> ImportTensorFlowGraphDef(
for (auto node : inlined_graph.node()) {
StripZeroOutputIndexFromInputs(&node);
- if (node.op() == "Const") {
- ConvertConstOperator(node, tf_import_flags, model);
- } else if (node.op() == "Conv2D") {
- ConvertConvOperator(node, tf_import_flags, model);
- } else if (node.op() == "Conv2DBackpropInput") {
- ConvertTransposeConvOperator(node, tf_import_flags, model);
- } else if (node.op() == "DepthwiseConv2dNative") {
- ConvertDepthwiseConvOperator(node, tf_import_flags, model);
- } else if (node.op() == "DepthToSpace") {
- ConvertDepthToSpaceOperator(node, tf_import_flags, model);
- } else if (node.op() == "SpaceToDepth") {
- ConvertSpaceToDepthOperator(node, tf_import_flags, model);
- } else if (node.op() == "BiasAdd") {
- ConvertBiasAddOperator(node, tf_import_flags, model);
- } else if (node.op() == "Relu") {
- ConvertReluOperator(node, tf_import_flags, model);
- } else if (node.op() == "Relu6") {
- ConvertRelu6Operator(node, tf_import_flags, model);
- } else if (node.op() == "Sigmoid") {
- ConvertLogisticOperator(node, tf_import_flags, model);
- } else if (node.op() == "Tanh") {
- ConvertTanhOperator(node, tf_import_flags, model);
- } else if (node.op() == "MaxPool") {
- ConvertMaxPoolOperator(node, tf_import_flags, model);
- } else if (node.op() == "AvgPool") {
- ConvertAvgPoolOperator(node, tf_import_flags, model);
- } else if (node.op() == "Reshape") {
- ConvertReshapeOperator(node, tf_import_flags, model);
- } else if (node.op() == "BatchMatMul") {
- ConvertBatchMatMulOperator(node, tf_import_flags, model);
- } else if (node.op() == "MatMul") {
- ConvertMatMulOperator(node, tf_import_flags, model);
- } else if (node.op() == "Div" || node.op() == "RealDiv") {
- ConvertDivOperator(node, tf_import_flags, model);
- } else if (node.op() == "Identity" || node.op() == "CheckNumerics" ||
- node.op() == "StopGradient") {
- ConvertIdentityOperator(node, tf_import_flags, model);
- } else if (node.op() == "FakeQuantWithMinMaxVars") {
- ConvertFakeQuantWithMinMaxVars(node, tf_import_flags, model);
- } else if (node.op() == "FakeQuantWithMinMaxArgs") {
- ConvertFakeQuantWithMinMaxArgs(node, tf_import_flags, model);
- } else if (node.op() == "Neg") {
- ConvertNegOperator(node, tf_import_flags, model);
- } else if (node.op() == "Rsqrt") {
- ConvertRsqrtOperator(node, tf_import_flags, model);
- } else if (node.op() == "Squeeze") {
- ConvertSqueezeOperator(node, tf_import_flags, model);
- } else if (node.op() == "Sqrt") {
- ConvertSqrtOperator(node, tf_import_flags, model);
- } else if (node.op() == "Square") {
- ConvertSquareOperator(node, tf_import_flags, model);
- } else if (node.op() == "Add") {
- ConvertAddOperator(node, tf_import_flags, model);
- } else if (node.op() == "AddN") {
- ConvertAddNOperator(node, tf_import_flags, model);
- } else if (node.op() == "Mul") {
- ConvertMulOperator(node, tf_import_flags, model);
- } else if (node.op() == "Sub") {
- ConvertSubOperator(node, tf_import_flags, model);
- } else if (node.op() == "Sum") {
- ConvertSumOperator(node, tf_import_flags, model);
- } else if (node.op() == "Tile") {
- ConvertTileOperator(node, tf_import_flags, model);
- } else if (node.op() == "Concat" || node.op() == "ConcatV2") {
- ConvertConcatOperator(node, tf_import_flags, model);
- } else if (node.op() == "LRN") {
- ConvertLRNOperator(node, tf_import_flags, model);
- } else if (node.op() == "Softmax") {
- ConvertSoftmaxOperator(node, tf_import_flags, model);
- } else if (node.op() == "Log") {
- ConvertLogOperator(node, tf_import_flags, model);
- } else if (node.op() == "LogSoftmax") {
- ConvertLogSoftmaxOperator(node, tf_import_flags, model);
- } else if (node.op() == "All") {
- ConvertAllOperator(node, tf_import_flags, model);
- } else if (node.op() == "Assert") {
- ConvertAssertOperator(node, tf_import_flags, model);
- } else if (node.op() == "Less") {
- ConvertLessOperator(node, tf_import_flags, model);
- } else if (node.op() == "LessEqual") {
- ConvertLessEqualOperator(node, tf_import_flags, model);
- } else if (node.op() == "Greater") {
- ConvertGreaterOperator(node, tf_import_flags, model);
- } else if (node.op() == "GreaterEqual") {
- ConvertGreaterEqualOperator(node, tf_import_flags, model);
- } else if (node.op() == "Max") {
- ConvertMaxOperator(node, tf_import_flags, model);
- } else if (node.op() == "Min") {
- ConvertMinOperator(node, tf_import_flags, model);
- } else if (node.op() == "Maximum") {
- ConvertMaximumOperator(node, tf_import_flags, model);
- } else if (node.op() == "Minimum") {
- ConvertMinimumOperator(node, tf_import_flags, model);
- } else if (node.op() == "Merge") {
- ConvertMergeOperator(node, tf_import_flags, model);
- } else if (node.op() == "Pad") {
- ConvertPadOperator(node, tf_import_flags, model);
- } else if (node.op() == "StridedSlice") {
- ConvertStridedSliceOperator(node, tf_import_flags, model);
- } else if (node.op() == "Shape") {
- ConvertShapeOperator(node, tf_import_flags, model);
- } else if (node.op() == "Slice") {
- ConvertSliceOperator(node, tf_import_flags, model);
- } else if (node.op() == "Split") {
- ConvertSplitOperator(node, tf_import_flags, model);
- } else if (node.op() == "Switch") {
- ConvertSwitchOperator(node, tf_import_flags, model);
- } else if (node.op() == "Placeholder") {
- ConvertPlaceholderOperator(node, tf_import_flags, model);
- } else if (node.op() == "PlaceholderWithDefault") {
- ConvertIdentityOperator(node, tf_import_flags, model);
- } else if (node.op() == "LegacyFedInput") {
- ConvertPlaceholderOperator(node, tf_import_flags, model);
- } else if (node.op() == "NoOp") {
- ConvertNoOpOperator(node, tf_import_flags, model);
- } else if (node.op() == "Cast") {
- ConvertCastOperator(node, tf_import_flags, model);
- } else if (node.op() == "Floor") {
- ConvertFloorOperator(node, tf_import_flags, model);
- } else if (node.op() == "Gather" || node.op() == "GatherV2") {
- ConvertGatherOperator(node, tf_import_flags, model);
- } else if (node.op() == "ResizeBilinear") {
- ConvertResizeBilinearOperator(node, tf_import_flags, model);
- } else if (node.op() == "BatchNormWithGlobalNormalization") {
- ConvertBatchNormWithGlobalNormalizationOperator(node, tf_import_flags,
- model);
- } else if (node.op() == "FusedBatchNorm") {
- ConvertFusedBatchNormOperator(node, tf_import_flags, model);
- } else if (node.op() == "SpaceToBatchND") {
- ConvertSpaceToBatchNDOperator(node, tf_import_flags, model);
- } else if (node.op() == "BatchToSpaceND") {
- ConvertBatchToSpaceNDOperator(node, tf_import_flags, model);
- } else if (node.op() == "Mean") {
- ConvertMeanOperator(node, tf_import_flags, model);
- } else if (node.op() == "Svdf") {
- ConvertSvdfOperator(node, tf_import_flags, model);
- } else if (node.op() == "NextIteration") {
- ConvertOperatorSpecialCasedAsRNNBackEdge(node, tf_import_flags, model);
- } else if (node.op() == "ExpandDims") {
- ConvertExpandDimsOperator(node, tf_import_flags, model);
- } else if (node.op() == "Fill") {
- ConvertFillOperator(node, tf_import_flags, model);
- } else if (node.op() == "FloorDiv") {
- ConvertFloorDivOperator(node, tf_import_flags, model);
- } else if (node.op() == "FloorMod") {
- ConvertFloorModOperator(node, tf_import_flags, model);
- } else if (node.op() == "Range") {
- ConvertRangeOperator(node, tf_import_flags, model);
- } else if (node.op() == "Rank") {
- ConvertRankOperator(node, tf_import_flags, model);
- } else if (node.op() == "Stack" || node.op() == "Pack") {
- ConvertStackOperator(node, tf_import_flags, model);
- } else if (node.op() == "Transpose") {
- ConvertTransposeOperator(node, tf_import_flags, model);
- } else if (node.op() == "ArgMax") {
- ConvertArgMaxOperator(node, tf_import_flags, model);
- } else if (node.op() == "Exp") {
- ConvertExpOperator(node, tf_import_flags, model);
- } else if (node.op() == "TopK" || node.op() == "TopKV2") {
- ConvertTopKV2Operator(node, tf_import_flags, model);
- } else if (node.op() == "DynamicPartition") {
- ConvertDynamicPartitionOperator(node, tf_import_flags, model);
- } else if (node.op() == "DynamicStitch" ||
- node.op() == "ParallelDynamicStitch") {
- ConvertDynamicStitchOperator(node, tf_import_flags, model);
- } else if (node.op() == "RandomUniform") {
- ConvertRandomUniform(node, tf_import_flags, model);
- } else {
- ConvertUnsupportedOperator(node, tf_import_flags, model);
- }
+ auto status = internal::ImportTensorFlowNode(node, tf_import_flags, model);
+ CHECK(status.ok()) << status.error_message();
}
ResolveModelFlags(model_flags, model);
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow_test.cc b/tensorflow/contrib/lite/toco/import_tensorflow_test.cc
new file mode 100644
index 0000000000..5dc78f73ad
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/import_tensorflow_test.cc
@@ -0,0 +1,160 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/toco/import_tensorflow.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/attr_value_util.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/tensor_shape.pb.h"
+
+namespace toco {
+
+using port::Status;
+using tensorflow::AttrValue;
+using tensorflow::DT_BOOL;
+using tensorflow::DT_FLOAT;
+using tensorflow::DT_INT32;
+using tensorflow::DT_INT64;
+using tensorflow::DT_QUINT8;
+using tensorflow::DT_STRING;
+using tensorflow::NodeDef;
+
+namespace internal {
+Status ImportTensorFlowNode(const NodeDef&, const TensorFlowImportFlags&,
+ Model*);
+} // namespace internal
+
+namespace {
+
+class ShapeImportTest : public ::testing::TestWithParam<tensorflow::DataType> {
+ protected:
+ ShapeImportTest() {}
+
+ void BuildConstNode(std::initializer_list<int64_t> shape,
+ tensorflow::DataType dtype, int64_t num_elements,
+ NodeDef* node) {
+ node->set_op("Const");
+ node->set_name("Node1");
+
+ // An attribute describing the type of this const node.
+ AttrValue dtype_attr;
+ SetAttrValue(dtype, &dtype_attr);
+ (*node->mutable_attr())["dtype"] = dtype_attr;
+
+ // An attribute describing the content of this const node.
+ tensorflow::TensorProto t;
+ t.set_dtype(dtype);
+ auto* s = t.mutable_tensor_shape();
+ for (auto d : shape) {
+ s->add_dim()->set_size(d);
+ }
+
+ // TODO(ahentz): also need to test via tensor_content()
+ switch (dtype) {
+ case DT_FLOAT:
+ for (int64_t i = 0; i < num_elements; ++i) {
+ t.add_float_val(i / 10000.0);
+ }
+ break;
+ case DT_INT32:
+ for (int64_t i = 0; i < num_elements; ++i) {
+ t.add_int_val(i % std::numeric_limits<int>::max());
+ }
+ break;
+ case DT_QUINT8:
+ for (int64_t i = 0; i < num_elements; ++i) {
+ t.add_int_val(i % std::numeric_limits<uint8_t>::max());
+ }
+ break;
+ case DT_INT64:
+ for (int64_t i = 0; i < num_elements; ++i) {
+ t.add_int64_val(i);
+ }
+ break;
+ case DT_STRING:
+ break;
+ case DT_BOOL:
+ for (int64_t i = 0; i < num_elements; ++i) {
+ t.add_bool_val(i % 2);
+ }
+ break;
+ default:
+ break;
+ }
+
+ AttrValue value_attr;
+ SetAttrValue(t, &value_attr);
+ (*node->mutable_attr())["value"] = value_attr;
+ }
+
+ Status ImportNode(const NodeDef& node) {
+ Model model;
+ return internal::ImportTensorFlowNode(node, TensorFlowImportFlags(),
+ &model);
+ }
+};
+
+std::vector<tensorflow::DataType> TestTypes() {
+ return {DT_FLOAT, DT_INT32, DT_INT64, DT_BOOL, DT_QUINT8};
+}
+
+TEST_P(ShapeImportTest, ShapeElementIsNegative) {
+ NodeDef node;
+ BuildConstNode({1, -2, 10}, GetParam(), 0, &node);
+ auto status = ImportNode(node);
+ EXPECT_EQ(status.error_message(),
+ "Tensor shape should not include negative values (while processing "
+ "node 'Node1')");
+}
+INSTANTIATE_TEST_CASE_P(ShapeElementIsNegative, ShapeImportTest,
+ ::testing::ValuesIn(TestTypes()));
+
+TEST_P(ShapeImportTest, ShapeElementTooLarge) {
+ NodeDef node;
+ BuildConstNode({3000000000}, GetParam(), 0, &node);
+ auto status = ImportNode(node);
+ EXPECT_EQ(status.error_message(),
+ "Shape element overflows (while processing node 'Node1')");
+}
+INSTANTIATE_TEST_CASE_P(ShapeElementTooLarge, ShapeImportTest,
+ ::testing::ValuesIn(TestTypes()));
+
+TEST_P(ShapeImportTest, ShapeTooLarge) {
+ NodeDef node;
+ BuildConstNode({1000000, 2000000, 2000000, 2000000}, GetParam(), 0, &node);
+ auto status = ImportNode(node);
+ EXPECT_EQ(status.error_message(),
+ "Tensor shape is too large (while processing node 'Node1')");
+}
+INSTANTIATE_TEST_CASE_P(ShapeTooLarge, ShapeImportTest,
+ ::testing::ValuesIn(TestTypes()));
+
+TEST_P(ShapeImportTest, ValidShapeButZeroElements) {
+ NodeDef node;
+ BuildConstNode({1, 2, 2, 2}, GetParam(), 0, &node);
+ auto status = ImportNode(node);
+ EXPECT_THAT(status.error_message(),
+ ::testing::MatchesRegex(
+ "Neither input_content nor .*_val have the right dimensions "
+ "for this .* tensor .while processing node 'Node1'."));
+}
+INSTANTIATE_TEST_CASE_P(ValidShapeButZeroElements, ShapeImportTest,
+ ::testing::ValuesIn(TestTypes()));
+
+} // namespace
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/toco_port.h b/tensorflow/contrib/lite/toco/toco_port.h
index 2d5c231bef..906792ef56 100644
--- a/tensorflow/contrib/lite/toco/toco_port.h
+++ b/tensorflow/contrib/lite/toco/toco_port.h
@@ -38,10 +38,15 @@ namespace port {
class Status {
public:
+ static Status OK() { return Status(true, ""); }
+
+ // Create a failed status with no message.
Status() {}
Status(bool ok, const string& message) : ok_(ok), message_(message) {}
+ void AppendMessage(const string& message) { message_ += message; }
+
bool ok() const { return ok_; }
const string error_message() const { return message_; }
diff --git a/tensorflow/contrib/lite/toco/tooling_util.h b/tensorflow/contrib/lite/toco/tooling_util.h
index 5cc15fa57b..f5b596df0f 100644
--- a/tensorflow/contrib/lite/toco/tooling_util.h
+++ b/tensorflow/contrib/lite/toco/tooling_util.h
@@ -294,6 +294,35 @@ void FinishBuildingRNNStates(Model* model);
void UseArraysExtraInfo(Model* model, bool quantize_output);
+// Calculates the number of elements in tensor given a shape. Shape elements
+// are assumed to be of type T, while the result total is of type U. If U
+// doesn't have enough range to represent the sum of elements, an error is
+// returned.
+template <typename T, typename U>
+port::Status NumElements(const std::vector<T>& shape, U* num_elements) {
+ static_assert(
+ std::numeric_limits<T>::max() <= std::numeric_limits<uint64_t>::max(),
+ "vector type exceed capabilities of NumElements");
+
+ *num_elements = 1;
+ for (const T& dim : shape) {
+ if (dim < 0) {
+ // TensorFlow's shapes sometimes include -1 to represent an "unknown"
+ // size but TOCO isn't able to create arrays of unknown sizes and will
+ // crash in RequiredBufferSizeForShape().
+ return port::Status(false,
+ "Tensor shape should not include negative values");
+ }
+ if (static_cast<uint64_t>(dim) >
+ std::numeric_limits<U>::max() / *num_elements) {
+ *num_elements = 0;
+ return port::Status(false, "Tensor shape is too large");
+ }
+ *num_elements *= dim;
+ }
+ return port::Status::OK();
+}
+
} // namespace toco
#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TOOLING_UTIL_H_
diff --git a/tensorflow/contrib/lite/toco/tooling_util_test.cc b/tensorflow/contrib/lite/toco/tooling_util_test.cc
index 22955ce956..87fd30db2c 100644
--- a/tensorflow/contrib/lite/toco/tooling_util_test.cc
+++ b/tensorflow/contrib/lite/toco/tooling_util_test.cc
@@ -93,4 +93,85 @@ TEST_P(ShapeTest, Agrees) {
INSTANTIATE_TEST_CASE_P(AgreeBroadcast, ShapeTest,
::testing::ValuesIn(CreateShapePairs()));
+static const char kNegativeValuesMessage[] =
+ "Tensor shape should not include negative values";
+static const char kLargeTensorMessage[] = "Tensor shape is too large";
+
+TEST(NumElementsTest, Int) {
+ int count;
+ port::Status status = port::Status::OK();
+
+ status = NumElements(std::vector<int>{1024, 1024, 2047}, &count);
+ EXPECT_TRUE(status.ok());
+ EXPECT_EQ(count, 2146435072);
+
+ status = NumElements(std::vector<int>{1, 2, -3}, &count);
+ EXPECT_EQ(status.error_message(), kNegativeValuesMessage);
+
+ status = NumElements(std::vector<int>{1024, 1024, 2048}, &count);
+ EXPECT_EQ(status.error_message(), kLargeTensorMessage);
+}
+
+TEST(NumElementsTest, Int32) {
+ int32_t count;
+ port::Status status = port::Status::OK();
+
+ status = NumElements(std::vector<int32_t>{1024, 1024, 2047}, &count);
+ EXPECT_TRUE(status.ok());
+ EXPECT_EQ(count, 2146435072);
+
+ status = NumElements(std::vector<int32_t>{1, 2, -3}, &count);
+ EXPECT_EQ(status.error_message(), kNegativeValuesMessage);
+
+ status = NumElements(std::vector<int32_t>{1024, 1024, 2048}, &count);
+ EXPECT_EQ(status.error_message(), kLargeTensorMessage);
+}
+
+TEST(NumElementsTest, Int64) {
+ int64_t count;
+ port::Status status = port::Status::OK();
+
+ status = NumElements(std::vector<int64_t>{16777216, 16777216, 32767}, &count);
+ EXPECT_TRUE(status.ok());
+ EXPECT_EQ(count, 9223090561878065152LL);
+
+ status = NumElements(std::vector<int64_t>{1, 2, -3}, &count);
+ EXPECT_EQ(status.error_message(), kNegativeValuesMessage);
+
+ status = NumElements(std::vector<int64_t>{16777216, 16777216, 32768}, &count);
+ EXPECT_EQ(status.error_message(), kLargeTensorMessage);
+}
+
+TEST(NumElementsTest, UnsignedInt32) {
+ uint32_t count;
+ port::Status status = port::Status::OK();
+
+ status = NumElements(std::vector<uint32_t>{1024, 2048, 2047}, &count);
+ EXPECT_TRUE(status.ok());
+ EXPECT_EQ(count, 4292870144);
+
+ status = NumElements(std::vector<int>{1, 2, -3}, &count);
+ EXPECT_EQ(status.error_message(), kNegativeValuesMessage);
+
+ status = NumElements(std::vector<uint32_t>{1024, 2048, 2048}, &count);
+ EXPECT_EQ(status.error_message(), kLargeTensorMessage);
+}
+
+TEST(NumElementsTest, UnsignedInt64) {
+ uint64_t count;
+ port::Status status = port::Status::OK();
+
+ status =
+ NumElements(std::vector<uint64_t>{16777216, 16777216, 65535}, &count);
+ EXPECT_TRUE(status.ok());
+ EXPECT_EQ(count, 18446462598732840960ULL);
+
+ status = NumElements(std::vector<int>{1, 2, -3}, &count);
+ EXPECT_EQ(status.error_message(), kNegativeValuesMessage);
+
+ status =
+ NumElements(std::vector<uint64_t>{16777216, 16777216, 65536}, &count);
+ EXPECT_EQ(status.error_message(), kLargeTensorMessage);
+}
+
} // namespace toco
diff --git a/tensorflow/contrib/signal/python/kernel_tests/shape_ops_test.py b/tensorflow/contrib/signal/python/kernel_tests/shape_ops_test.py
index 64cc8c7ea5..f132050153 100644
--- a/tensorflow/contrib/signal/python/kernel_tests/shape_ops_test.py
+++ b/tensorflow/contrib/signal/python/kernel_tests/shape_ops_test.py
@@ -119,7 +119,7 @@ class FrameTest(test.TestCase):
frame_step = 1
result = shape_ops.frame(signal, frame_length, frame_step,
pad_end=True, pad_value=99, axis=1)
- self.assertEqual([1, None, None, 3, 4], result.shape.as_list())
+ self.assertEqual([1, 2, None, 3, 4], result.shape.as_list())
result = shape_ops.frame(signal, frame_length, frame_step,
pad_end=False, axis=1)
diff --git a/tensorflow/contrib/signal/python/ops/shape_ops.py b/tensorflow/contrib/signal/python/ops/shape_ops.py
index 1ddc2941ec..91862f0cc0 100644
--- a/tensorflow/contrib/signal/python/ops/shape_ops.py
+++ b/tensorflow/contrib/signal/python/ops/shape_ops.py
@@ -43,13 +43,13 @@ def _infer_frame_shape(signal, frame_length, frame_step, pad_end, axis):
outer_dimensions = signal_shape[:axis]
inner_dimensions = signal_shape[axis:][1:]
if signal_shape and frame_axis is not None:
- if frame_step and frame_length is not None:
- if pad_end:
- # Double negative is so that we round up.
- num_frames = -(-frame_axis // frame_step)
- else:
- num_frames = (frame_axis - frame_length + frame_step) // frame_step
- num_frames = max(0, num_frames)
+ if frame_step is not None and pad_end:
+ # Double negative is so that we round up.
+ num_frames = max(0, -(-frame_axis // frame_step))
+ elif frame_step is not None and frame_length is not None:
+ assert not pad_end
+ num_frames = max(
+ 0, (frame_axis - frame_length + frame_step) // frame_step)
return outer_dimensions + [num_frames, frame_length] + inner_dimensions
diff --git a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/BUILD b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/BUILD
index 5d33e23a42..3c07a74ed8 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/BUILD
+++ b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/BUILD
@@ -176,8 +176,9 @@ py_library(
py_test(
name = "structural_ensemble_test",
- timeout = "long", # Moderate but for asan/tsan timeouts
+ timeout = "long", # Moderate but for asan/tsan/msan timeouts
srcs = ["structural_ensemble_test.py"],
+ shard_count = 4,
srcs_version = "PY2AND3",
deps = [
":state_space_model",
diff --git a/tensorflow/core/api_def/base_api/api_def_GroupByReducerDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_GroupByReducerDataset.pbtxt
new file mode 100644
index 0000000000..067ad4018b
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_GroupByReducerDataset.pbtxt
@@ -0,0 +1,69 @@
+op {
+ graph_op_name: "GroupByReducerDataset"
+ visibility: HIDDEN
+ in_arg {
+ name: "input_dataset"
+ description: <<END
+A variant tensor representing the input dataset.
+END
+ }
+ in_arg {
+ name: "key_func_other_arguments"
+ description: <<END
+A list of tensors, typically values that were captured when
+building a closure for `key_func`.
+END
+ }
+ attr {
+ name: "key_func"
+ description: <<END
+A function mapping an element of `input_dataset`, concatenated
+with `key_func_other_arguments` to a scalar value of type DT_INT64.
+END
+ }
+ in_arg {
+ name: "init_func_other_arguments"
+ description: <<END
+A list of tensors, typically values that were captured when
+building a closure for `init_func`.
+END
+ }
+ attr {
+ name: "init_func"
+ description: <<END
+A function mapping a key of type DT_INT64, concatenated with
+`init_func_other_arguments` to the initial reducer state.
+END
+ }
+ in_arg {
+ name: "reduce_func_other_arguments"
+ description: <<END
+A list of tensors, typically values that were captured when
+building a closure for `reduce_func`.
+END
+ }
+ attr {
+ name: "reduce_func"
+ description: <<END
+A function mapping the current reducer state and an element of `input_dataset`,
+concatenated with `reduce_func_other_arguments` to a new reducer state.
+END
+ }
+ in_arg {
+ name: "finalize_func_other_arguments"
+ description: <<END
+A list of tensors, typically values that were captured when
+building a closure for `finalize_func`.
+END
+ }
+ attr {
+ name: "finalize_func"
+ description: <<END
+A function mapping the final reducer state to an output element.
+END
+ }
+ summary: "Creates a dataset that computes a group-by on `input_dataset`."
+ description: <<END
+Creates a dataset that computes a group-by on `input_dataset`.
+END
+}
diff --git a/tensorflow/core/common_runtime/device.h b/tensorflow/core/common_runtime/device.h
index 5918cd9bbf..b537666492 100644
--- a/tensorflow/core/common_runtime/device.h
+++ b/tensorflow/core/common_runtime/device.h
@@ -51,6 +51,8 @@ limitations under the License.
namespace tensorflow {
+class DeviceMgr;
+
class Device : public DeviceBase {
public:
Device(Env* env, const DeviceAttributes& device_attributes);
@@ -133,6 +135,10 @@ class Device : public DeviceBase {
// Returns the resource manager associated w/ this device.
virtual ResourceMgr* resource_manager() { return rmgr_; }
+ // Returns the device manager that owns this device, or nullptr if this Device
+ // is not owned by a device manager.
+ DeviceMgr* device_mgr() const { return device_mgr_; }
+
// Summarizes the status of this Device, for debugging.
string DebugString() const { return ProtoDebugString(device_attributes_); }
@@ -158,6 +164,11 @@ class Device : public DeviceBase {
}
private:
+ friend class DeviceMgr;
+
+ // Pointer to the device manager that owns this device. Not owned.
+ DeviceMgr* device_mgr_ = nullptr;
+
const DeviceAttributes device_attributes_;
DeviceNameUtils::ParsedName parsed_name_;
diff --git a/tensorflow/core/common_runtime/device_mgr.cc b/tensorflow/core/common_runtime/device_mgr.cc
index a77601ba79..470abc1431 100644
--- a/tensorflow/core/common_runtime/device_mgr.cc
+++ b/tensorflow/core/common_runtime/device_mgr.cc
@@ -27,6 +27,9 @@ namespace tensorflow {
DeviceMgr::DeviceMgr(const std::vector<Device*>& devices)
: name_backing_store_(128) {
for (Device* d : devices) {
+ CHECK(d->device_mgr_ == nullptr);
+ d->device_mgr_ = this;
+
devices_.push_back(d);
// Register under the (1) full name and (2) canonical name.
diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device.cc b/tensorflow/core/common_runtime/eager/kernel_and_device.cc
index 0a4895a938..a63b2b9711 100644
--- a/tensorflow/core/common_runtime/eager/kernel_and_device.cc
+++ b/tensorflow/core/common_runtime/eager/kernel_and_device.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/step_stats.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device.h b/tensorflow/core/common_runtime/eager/kernel_and_device.h
index 46ec550c78..f78d197fd5 100644
--- a/tensorflow/core/common_runtime/eager/kernel_and_device.h
+++ b/tensorflow/core/common_runtime/eager/kernel_and_device.h
@@ -32,6 +32,10 @@ limitations under the License.
namespace tensorflow {
+// Forward declaration for proto class NodeExecStats so we do not need to
+// include the proto header
+class NodeExecStats;
+
// KernelAndDevice encapsulates an instantiated kernel and the device it is on.
//
// Also see:
diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc
index 0c461a9ee9..e389eb9b2a 100644
--- a/tensorflow/core/common_runtime/executor.cc
+++ b/tensorflow/core/common_runtime/executor.cc
@@ -322,6 +322,7 @@ class GraphView {
void Initialize(const Graph* g);
Status SetAllocAttrs(const Graph* g, const Device* device);
+ void SetScopedAllocatorAttrs(const std::vector<const Node*>& sa_nodes);
NodeItem* node(size_t id) const {
DCHECK_GE(id, 0);
@@ -566,11 +567,46 @@ char* GraphView::InitializeNode(char* ptr, const Node* n) {
DCHECK_EQ(item->input_type(i), n->input_type(i));
}
- uint8* output_types = item->output_type_base();
- for (int i = 0; i < num_outputs; i++) {
- output_types[i] = static_cast<uint8>(n->output_type(i));
- DCHECK_EQ(item->output_type(i), n->output_type(i));
+ // Check ScopedAllocatorAttrs and forward_from. Also assign output_types.
+ {
+ std::vector<int> forward_input;
+ Status fwd_status =
+ GetNodeAttr(n->attrs(), "_forward_input", &forward_input);
+ std::vector<int> scoped_allocator_attrs;
+ Status sa_status =
+ GetNodeAttr(n->attrs(), "_scoped_allocator", &scoped_allocator_attrs);
+
+ int* forward_from = item->forward_from_base();
+ uint8* output_types = item->output_type_base();
+ for (int i = 0; i < num_outputs; ++i) {
+ output_types[i] = static_cast<uint8>(n->output_type(i));
+ DCHECK_EQ(item->output_type(i), n->output_type(i));
+
+ forward_from[i] = OpKernelContext::Params::kNoReservation;
+ if (sa_status.ok()) {
+ for (int j = 0; j < scoped_allocator_attrs.size(); j += 2) {
+ if (scoped_allocator_attrs[j] == i) {
+ // This output slot must be explicitly allocated from a
+ // ScopedAllocator.
+ forward_from[i] = OpKernelContext::Params::kNeverForward;
+ DCHECK_EQ(output_attrs[i].scope_id, 0);
+ output_attrs[i].scope_id = scoped_allocator_attrs[j + 1];
+ }
+ }
+ }
+ if (fwd_status.ok() && forward_from[i] == -1) {
+ DCHECK_EQ(forward_input.size() % 2, 0);
+ for (int j = 0; j < forward_input.size(); j += 2) {
+ if (forward_input[j + 1] == i) {
+ DCHECK_EQ(forward_from[i], OpKernelContext::Params::kNoReservation);
+ forward_from[i] = forward_input[j];
+ break;
+ }
+ }
+ }
+ }
}
+
return ptr;
}
@@ -696,22 +732,85 @@ Status ExecutorImpl::Initialize() {
return gview_.SetAllocAttrs(graph_.get(), params_.device);
}
+// If a Node has been marked to use a ScopedAllocator x for output i, then
+// sc_attr will contain the subsequence (i, x) at an even offset. This function
+// extracts and transfers that ScopedAllocator id to alloc_attr. For now, we
+// only allow one ScopedAllocator use per Node.
+bool ExtractScopedAllocatorAttr(const std::vector<int>& sc_attr,
+ int output_index,
+ AllocatorAttributes* alloc_attr) {
+ DCHECK_LE(2, sc_attr.size());
+ for (int i = 0; i < sc_attr.size(); i += 2) {
+ if (sc_attr[i] == output_index) {
+ CHECK_EQ(alloc_attr->scope_id, 0);
+ alloc_attr->scope_id = sc_attr[i + 1];
+ return true;
+ }
+ }
+ return false;
+}
+
+void GraphView::SetScopedAllocatorAttrs(
+ const std::vector<const Node*>& sa_nodes) {
+ for (const Node* sa : sa_nodes) {
+ NodeItem* sa_item = node(sa->id());
+ AllocatorAttributes* sa_attrs = sa_item->output_attr_base();
+ // Control edges out of the ScopedAllocator should be use instances, but may
+ // include a few other nodes.
+ for (const auto& e : sa->out_edges()) {
+ if (!e->IsControlEdge()) {
+ continue;
+ }
+ Node* use_node = e->dst();
+ NodeItem* item = node(use_node->id());
+ AllocatorAttributes* use_attrs = item->output_attr_base();
+ std::vector<int> scoped_allocator_attrs;
+ Status s = GetNodeAttr(use_node->attrs(), "_scoped_allocator",
+ &scoped_allocator_attrs);
+ if (!s.ok()) {
+ VLOG(2) << "Failed to find expected ScopedAllocator attr on "
+ << use_node->name();
+ continue;
+ }
+ // There should be exactly one output using ScopedAllocation.
+ for (const auto& e : use_node->out_edges()) {
+ if (!e->IsControlEdge()) {
+ AllocatorAttributes attr;
+ if (ExtractScopedAllocatorAttr(scoped_allocator_attrs,
+ e->src_output(), &attr)) {
+ // Set the scope_id on this use instance node.
+ (use_attrs + e->src_output())->Merge(attr);
+ // Propagate the other attributes of this node back to the SA node.
+ attr = *(use_attrs + e->src_output());
+ attr.scope_id = 0;
+ sa_attrs->Merge(attr);
+ }
+ }
+ }
+ }
+ }
+}
+
Status GraphView::SetAllocAttrs(const Graph* g, const Device* device) {
Status s;
DeviceNameUtils::ParsedName local_dev_name = device->parsed_name();
+ std::vector<const Node*> scoped_allocator_instances;
for (const Node* n : g->nodes()) {
NodeItem* item = node(n->id());
AllocatorAttributes* attrs = item->output_attr_base();
+ if (IsScopedAllocator(n)) {
+ scoped_allocator_instances.push_back(n);
+ }
// Examine the out edges of each node looking for special use
// cases that may affect memory allocation attributes.
- for (auto e : n->out_edges()) {
+ for (const auto& e : n->out_edges()) {
if (!e->IsControlEdge()) {
AllocatorAttributes attr;
s = InferAllocAttr(n, e->dst(), local_dev_name, &attr);
if (!s.ok()) return s;
- if (attr.value != 0) {
+ if (attr.value != 0 || attr.scope_id != 0) {
attrs[e->src_output()].Merge(attr);
}
}
@@ -728,6 +827,7 @@ Status GraphView::SetAllocAttrs(const Graph* g, const Device* device) {
}
}
}
+ SetScopedAllocatorAttrs(scoped_allocator_instances);
return s;
}
@@ -1614,7 +1714,7 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_usec) {
params.frame_iter = FrameAndIter(input_frame->frame_id, input_iter);
params.is_input_dead = is_input_dead;
params.output_attr_array = item.output_attrs();
- params.forward_from_array = nullptr; // later: item.forward_from();
+ params.forward_from_array = item.forward_from();
if (item.kernel_is_async) {
// Asynchronous computes.
diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.cc b/tensorflow/core/common_runtime/process_function_library_runtime.cc
index e61ed8c479..668ce87749 100644
--- a/tensorflow/core/common_runtime/process_function_library_runtime.cc
+++ b/tensorflow/core/common_runtime/process_function_library_runtime.cc
@@ -144,7 +144,8 @@ Status ProcessFunctionLibraryRuntime::GetDeviceContext(
}
Device* device = flr->device();
string device_type = device->parsed_name().type;
- if (device_type == "CPU" || device_type == "TPU_SYSTEM") {
+ if (device_type == "CPU" || device_type == "TPU_SYSTEM" ||
+ device_type == "TPU") {
// "TPU_SYSTEM" indicates that `device` is a CPU.
return Status::OK();
}
diff --git a/tensorflow/core/distributed_runtime/BUILD b/tensorflow/core/distributed_runtime/BUILD
index 343dd5d456..256ce527a4 100644
--- a/tensorflow/core/distributed_runtime/BUILD
+++ b/tensorflow/core/distributed_runtime/BUILD
@@ -452,6 +452,81 @@ cc_library(
],
)
+cc_library(
+ name = "collective_param_resolver_distributed",
+ srcs = ["collective_param_resolver_distributed.cc"],
+ hdrs = ["collective_param_resolver_distributed.h"],
+ deps = [
+ ":call_options",
+ ":device_resolver_distributed",
+ ":worker_cache",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:worker_proto_cc",
+ ],
+)
+
+cc_library(
+ name = "test_utils",
+ srcs = [],
+ hdrs = ["test_utils.h"],
+ deps = [
+ ":worker_cache",
+ ":worker_interface",
+ ],
+)
+
+tf_cc_test(
+ name = "collective_param_resolver_distributed_test",
+ size = "small",
+ srcs = ["collective_param_resolver_distributed_test.cc"],
+ deps = [
+ ":collective_param_resolver_distributed",
+ ":device_resolver_distributed",
+ ":test_utils",
+ "//tensorflow/core:core_cpu_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ ],
+)
+
+cc_library(
+ name = "device_resolver_distributed",
+ srcs = ["device_resolver_distributed.cc"],
+ hdrs = ["device_resolver_distributed.h"],
+ deps = [
+ ":worker_cache",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:worker_proto_cc",
+ ],
+)
+
+tf_cc_test(
+ name = "device_resolver_distributed_test",
+ size = "small",
+ srcs = ["device_resolver_distributed_test.cc"],
+ deps = [
+ ":device_resolver_distributed",
+ ":test_utils",
+ "//tensorflow/core:core_cpu_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ ],
+)
+
# TODO(mrry): Move executor_test.cc to ../common_runtime when once it no longer depends
# on grpc_testlib.
tf_cuda_cc_tests(
diff --git a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc
new file mode 100644
index 0000000000..ecf5db8110
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc
@@ -0,0 +1,404 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h"
+
+#include "tensorflow/core/distributed_runtime/call_options.h"
+#include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
+#include "tensorflow/core/distributed_runtime/worker_cache.h"
+#include "tensorflow/core/protobuf/config.pb.h"
+
+// TODO(tucker): When we're ready to enable collectives this const will
+// transition to a settable config member.
+static const char FLAGS_collective_group_leader[] =
+ "/job:worker/replica:0/task:0";
+
+namespace tensorflow {
+namespace {
+// Supports client side cancellation of WorkerInterface calls via
+// registration with a CancellationManager. Note that ParamResolverInterface
+// calls are done on behalf of an Op execution which needs to abort if the
+// step in which it executes is cancelled.
+class CancellableCall {
+ public:
+ CancellableCall(CancellationManager* cancel_mgr, const string& remote_worker,
+ WorkerCacheInterface* wc)
+ : cancel_mgr_(cancel_mgr), remote_worker_(remote_worker), wc_(wc) {
+ wi_ = wc_->CreateWorker(remote_worker_);
+ }
+ virtual ~CancellableCall() { wc_->ReleaseWorker(remote_worker_, wi_); }
+
+ virtual void IssueCall(const StatusCallback& done) = 0;
+
+ void Start(const StatusCallback& done) {
+ CancellationToken token = cancel_mgr_->get_cancellation_token();
+ const bool not_yet_cancelled = cancel_mgr_->RegisterCallback(
+ token, [this, token]() { opts_.StartCancel(); });
+ if (not_yet_cancelled) {
+ IssueCall([this, token, done](const Status& s) {
+ cancel_mgr_->DeregisterCallback(token);
+ done(s);
+ });
+ } else {
+ done(errors::Cancelled("RPC Request was cancelled"));
+ }
+ }
+
+ protected:
+ mutable mutex mu_;
+ CancellationManager* cancel_mgr_; // Not owned
+ const string remote_worker_;
+ WorkerCacheInterface* wc_; // Not owned
+ WorkerInterface* wi_; // Owned by wc_, must be released.
+ CallOptions opts_;
+};
+
+class CompleteGroupCall : public CancellableCall {
+ public:
+ CompleteGroupCall(const CollGroupParams& group, const string& device_name,
+ CancellationManager* cancel_mgr,
+ const string& remote_worker, WorkerCacheInterface* wc)
+ : CancellableCall(cancel_mgr, remote_worker, wc) {
+ req_.set_group_key(group.group_key);
+ req_.set_group_size(group.group_size);
+ req_.set_device_type(group.device_type.type_string());
+ req_.add_device_name(device_name);
+ }
+ ~CompleteGroupCall() override {}
+
+ void IssueCall(const StatusCallback& done) override {
+ wi_->CompleteGroupAsync(&opts_, &req_, &resp_, done);
+ }
+
+ CompleteGroupRequest req_;
+ CompleteGroupResponse resp_;
+};
+
+class CompleteInstanceCall : public CancellableCall {
+ public:
+ CompleteInstanceCall(const CollGroupParams& group,
+ const CollInstanceParams& instance,
+ const string& node_name, const string& device_name,
+ bool is_source, CancellationManager* cancel_mgr,
+ const string& remote_worker, WorkerCacheInterface* wc)
+ : CancellableCall(cancel_mgr, remote_worker, wc) {
+ req_.set_name(node_name);
+ req_.set_type(instance.type);
+ req_.set_data_type(instance.data_type);
+ instance.shape.AsProto(req_.mutable_shape());
+ req_.set_group_key(group.group_key);
+ req_.set_group_size(group.group_size);
+ req_.set_instance_key(instance.instance_key);
+ req_.set_device_type(group.device_type.type_string());
+ for (int32 offset : instance.impl_details.subdiv_offsets) {
+ req_.add_subdiv_offset(offset);
+ }
+ req_.set_device(device_name);
+ req_.set_is_source(is_source);
+ }
+
+ ~CompleteInstanceCall() override {}
+
+ void IssueCall(const StatusCallback& done) override {
+ wi_->CompleteInstanceAsync(&opts_, &req_, &resp_, done);
+ }
+
+ CompleteInstanceRequest req_;
+ CompleteInstanceResponse resp_;
+};
+
+} // namespace
+
+CollectiveParamResolverDistributed::CollectiveParamResolverDistributed(
+ const ConfigProto& config, const DeviceMgr* dev_mgr,
+ DeviceResolverDistributed* dev_resolver, WorkerCacheInterface* worker_cache,
+ const string& task_name)
+ : CollectiveParamResolverLocal(dev_mgr, dev_resolver, task_name),
+ worker_cache_(worker_cache),
+ group_leader_(task_name == FLAGS_collective_group_leader
+ ? ""
+ : FLAGS_collective_group_leader) {}
+
+void CollectiveParamResolverDistributed::CompleteParamsAsync(
+ const string& device, CollectiveParams* cp, CancellationManager* cancel_mgr,
+ const StatusCallback& done) {
+ CompleteGroupDistributed(device, cp, cancel_mgr,
+ [this, device, cp, cancel_mgr, done](
+ const Status& s, const GroupRec* gr) {
+ if (s.ok()) {
+ CompleteInstanceDistributed(device, gr, cp,
+ cancel_mgr, done);
+ } else {
+ done(s);
+ }
+ });
+}
+
+void CollectiveParamResolverDistributed::CompleteGroupAsync(
+ const CompleteGroupRequest* request, CompleteGroupResponse* response,
+ CancellationManager* cancel_mgr, const StatusCallback& done) {
+ CollectiveParams cp;
+ cp.group.group_key = request->group_key();
+ cp.group.group_size = request->group_size();
+ cp.group.device_type = DeviceType(request->device_type());
+ for (const string& dn : request->device_name()) {
+ cp.instance.device_names.push_back(dn);
+ }
+ CompleteGroupDistributed(
+ cp.instance.device_names[0], &cp, cancel_mgr,
+ [this, response, done](const Status& s, const GroupRec* gr) {
+ if (s.ok()) {
+ mutex_lock l(gr->mu);
+ response->set_group_key(gr->group.group_key);
+ response->set_group_size(gr->group.group_size);
+ response->set_device_type(gr->group.device_type.type_string());
+ response->set_num_tasks(gr->task_set.size());
+ for (const string& dn : gr->device_list) {
+ response->add_device_name(dn);
+ }
+ for (const string& tn : gr->task_list) {
+ response->add_task_name(tn);
+ }
+ } else {
+ LOG(ERROR) << "Bad status from CompleteGroupDistributed: " << s;
+ }
+ done(s);
+ });
+}
+
+void CollectiveParamResolverDistributed::CompleteInstanceAsync(
+ const CompleteInstanceRequest* request, CompleteInstanceResponse* response,
+ CancellationManager* cancel_mgr, const StatusCallback& done) {
+ CollectiveParams* cp = new CollectiveParams;
+ cp->name = request->name();
+ cp->group.group_key = request->group_key();
+ cp->group.group_size = request->group_size();
+ cp->group.device_type = DeviceType(request->device_type());
+ cp->instance.type = CollectiveType(request->type());
+ cp->instance.instance_key = request->instance_key();
+ cp->instance.data_type = request->data_type();
+ cp->instance.shape = TensorShape(request->shape());
+ for (int32 offset : request->subdiv_offset()) {
+ cp->instance.impl_details.subdiv_offsets.push_back(offset);
+ }
+ VLOG(1) << "New cp " << cp << " for device " << request->device() << " : "
+ << cp->ToString();
+ StatusCallback done_and_cleanup = [this, cp, done](const Status& s) {
+ done(s);
+ delete cp;
+ };
+ // Start by completing the group.
+ CompleteGroupDistributed(
+ request->device(), cp, cancel_mgr,
+ [this, cp, request, response, cancel_mgr, done_and_cleanup](
+ const Status& cg_status, const GroupRec* gr) {
+ if (cg_status.ok()) {
+ // Then complete the instance.
+ CompleteInstanceDistributed(
+ request->device(), gr, cp, cancel_mgr,
+ [this, gr, cp, response,
+ done_and_cleanup](const Status& ci_status) {
+ if (ci_status.ok()) {
+ // Now source_rank should be known, so
+ // retrieve it.
+ FindInstanceRec(
+ gr, cp,
+ [this, gr, cp, response, done_and_cleanup](
+ const Status& fi_status, InstanceRec* ir) {
+ if (fi_status.ok()) {
+ mutex_lock l(ir->out_mu);
+ response->set_instance_key(cp->instance.instance_key);
+ response->set_source_rank(ir->source_rank);
+ done_and_cleanup(fi_status);
+ } else {
+ done_and_cleanup(fi_status);
+ }
+ });
+ } else {
+ done_and_cleanup(ci_status);
+ }
+ });
+ } else {
+ done_and_cleanup(cg_status);
+ }
+ });
+}
+
+bool CollectiveParamResolverDistributed::GroupIsCached(int32 group_key) {
+ mutex_lock l(group_mu_);
+ const auto& it = group_table_.find(group_key);
+ return it != group_table_.end();
+}
+
+Status CollectiveParamResolverDistributed::UpdateGroupCache(
+ const CompleteGroupResponse& resp) {
+ // Build a new record from resp.
+ std::unique_ptr<GroupRec> gr(new GroupRec);
+ mutex_lock grl(gr->mu);
+ gr->group.device_type = DeviceType(resp.device_type());
+ gr->group.group_key = resp.group_key();
+ gr->group.group_size = resp.group_size();
+ gr->group.num_tasks = resp.num_tasks();
+ if (resp.device_name_size() != gr->group.group_size) {
+ return errors::Internal(
+ "CompleteGroupResponse group_size doesn't match device_name list");
+ }
+ for (const string& dn : resp.device_name()) {
+ gr->device_set.insert(dn);
+ gr->device_list.push_back(dn);
+ }
+ if (resp.task_name_size() != gr->group.group_size) {
+ return errors::Internal(
+ "CompleteGroupResponse group_size doesn't match task_name list");
+ }
+ for (const string& tn : resp.task_name()) {
+ gr->task_list.push_back(tn);
+ gr->task_set.insert(tn);
+ }
+ CHECK_EQ(gr->task_set.size(), gr->group.num_tasks);
+ {
+ // Group membership should never change. Once a record is in group_table_
+ // it never gets removed.
+ mutex_lock l(group_mu_);
+ auto it = group_table_.find(gr->group.group_key);
+ if (it == group_table_.end()) {
+ group_table_[gr->group.group_key] = std::move(gr);
+ }
+ }
+ return Status::OK();
+}
+
+void CollectiveParamResolverDistributed::CompleteGroupDistributed(
+ const string& device, CollectiveParams* cp, CancellationManager* cancel_mgr,
+ const GroupRecCallback& done) {
+ VLOG(1) << "CompleteGroupDistributed group_key=" << cp->group.group_key
+ << " dev: " << device << " is_leader=" << (group_leader_.empty());
+ VLOG(0) << "cp: " << cp->ToString();
+ if (group_leader_.empty()) {
+ // This is the group leader, so resolution is local.
+ return CompleteGroupLocal(device, cp, done);
+ } else if (!GroupIsCached(cp->group.group_key)) {
+ // Need to update Group cache from the leader.
+ CompleteGroupCall* call = new CompleteGroupCall(
+ cp->group, device, cancel_mgr, group_leader_, worker_cache_);
+ call->Start([this, device, cp, call, done](const Status& s) {
+ if (s.ok()) {
+ Status status = UpdateGroupCache(call->resp_);
+ if (status.ok()) {
+ CompleteGroupLocal(device, cp, done);
+ } else {
+ done(status, nullptr);
+ }
+ } else {
+ done(s, nullptr);
+ }
+ delete call;
+ });
+ return;
+ } else {
+ return CompleteGroupLocal(device, cp, done);
+ }
+}
+
+bool CollectiveParamResolverDistributed::InstanceIsCached(int32 instance_key) {
+ mutex_lock l(instance_mu_);
+ const auto& it = instance_table_.find(instance_key);
+ return it != instance_table_.end();
+}
+
+void CollectiveParamResolverDistributed::UpdateInstanceCache(
+ const GroupRec* gr, CollectiveParams* cp,
+ const CompleteInstanceResponse& resp, const StatusCallback& done) {
+ Notification note;
+ InstanceRec* ir = nullptr;
+ int32 source_rank = resp.source_rank();
+
+ auto continue_with_ir = [this, cp, &ir, source_rank, done](const Status& s) {
+ if (!s.ok()) {
+ done(s);
+ return;
+ }
+ Status status;
+ do {
+ mutex_lock l(ir->out_mu);
+ if (ir->source_rank != source_rank) {
+ if (ir->source_rank >= 0) {
+ ir->status = errors::Internal(
+ "UpdateInstanceCache: CompleteInstanceResponse for instance ",
+ cp->instance.instance_key, " gives source_rank=", source_rank,
+ " but cache already holds value=", ir->source_rank);
+ status = ir->status;
+ break;
+ }
+ ir->source_rank = source_rank;
+ }
+ if (ir->known_count < cp->group.group_size) {
+ ir->known_count = cp->group.group_size;
+ if (ir->known.size() != cp->group.group_size) {
+ ir->status = errors::Internal(
+ "UpdateInstanceCache:: CompleteInstanceResponse for instance ",
+ cp->instance.instance_key, " has known.size()=", ir->known.size(),
+ " < group_size=", cp->group.group_size);
+ status = ir->status;
+ break;
+ }
+ for (int i = 0; i < ir->known.size(); ++i) {
+ ir->known[i] = true;
+ }
+ }
+ status = ir->status;
+ } while (false);
+ // Callback outside of lock.
+ done(status);
+ };
+
+ FindInstanceRec(
+ gr, cp, [this, &ir, continue_with_ir](const Status s, InstanceRec* irec) {
+ ir = irec;
+ continue_with_ir(s);
+ });
+}
+
+void CollectiveParamResolverDistributed::CompleteInstanceDistributed(
+ const string& device, const GroupRec* gr, CollectiveParams* cp,
+ CancellationManager* cancel_mgr, const StatusCallback& done) {
+ if (group_leader_.empty()) {
+ // This is the group leader so resolution is local.
+ return CompleteInstanceLocal(device, gr, cp, cp->is_source, done);
+ } else if (InstanceIsCached(cp->instance.instance_key)) {
+ return CompleteInstanceLocal(device, gr, cp, cp->is_source, done);
+ } else {
+ CompleteInstanceCall* call = new CompleteInstanceCall(
+ cp->group, cp->instance, cp->name, device, cp->is_source, cancel_mgr,
+ group_leader_, worker_cache_);
+ call->Start([this, device, gr, cp, call, done](const Status& s) {
+ if (s.ok()) {
+ UpdateInstanceCache(
+ gr, cp, call->resp_, [this, device, gr, cp, done](const Status& s) {
+ if (!s.ok()) {
+ done(s);
+ } else {
+ CompleteInstanceLocal(device, gr, cp, cp->is_source, done);
+ }
+ });
+ } else {
+ done(s);
+ }
+ delete call;
+ });
+ return;
+ }
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h
new file mode 100644
index 0000000000..a35131d835
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h
@@ -0,0 +1,90 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_COLLECTIVE_PARAM_RESOLVER_DISTRIBUTED_H_
+#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_COLLECTIVE_PARAM_RESOLVER_DISTRIBUTED_H_
+
+#include "tensorflow/core/common_runtime/collective_param_resolver_local.h"
+
+namespace tensorflow {
+class ConfigProto;
+class WorkerCacheInterface;
+class DeviceResolverDistributed;
+class DeviceMgr;
+
+class CollectiveParamResolverDistributed : public CollectiveParamResolverLocal {
+ public:
+ CollectiveParamResolverDistributed(const ConfigProto& config,
+ const DeviceMgr* dev_mgr,
+ DeviceResolverDistributed* dev_resolver,
+ WorkerCacheInterface* worker_cache,
+ const string& task_name);
+
+ void CompleteParamsAsync(const string& device, CollectiveParams* cp,
+ CancellationManager* cancel_mgr,
+ const StatusCallback& done) override;
+
+ void CompleteGroupAsync(const CompleteGroupRequest* request,
+ CompleteGroupResponse* response,
+ CancellationManager* cancel_mgr,
+ const StatusCallback& done) override;
+
+ void CompleteInstanceAsync(const CompleteInstanceRequest* request,
+ CompleteInstanceResponse* response,
+ CancellationManager* cancel_mgr,
+ const StatusCallback& done) override;
+
+ protected:
+ // Returns true iff there's an entry for this group_key in the
+ // local group_table_.
+ bool GroupIsCached(int32 group_key) LOCKS_EXCLUDED(group_mu_);
+
+ // Updates group_table_ with contents of resp.
+ Status UpdateGroupCache(const CompleteGroupResponse& resp)
+ LOCKS_EXCLUDED(group_mu_);
+
+ // Finds the GroupRec that corresponds to cp->group_key and also
+ // populates cp->group from that GroupRec.
+ //
+ // Semantics are like those of CompleteGroupLocal but will make a
+ // remote call to the group leader if necessary.
+ void CompleteGroupDistributed(const string& device, CollectiveParams* cp,
+ CancellationManager* cancel_mgr,
+ const GroupRecCallback& done);
+
+ // Returns true iff there's an entry for this instance_key in the
+ // local instance_table_.
+ bool InstanceIsCached(int32 instance_key) LOCKS_EXCLUDED(instance_mu_);
+
+ // Updates instance_table_ with contents of resp.
+ void UpdateInstanceCache(const GroupRec* gr, CollectiveParams* cp,
+ const CompleteInstanceResponse& resp,
+ const StatusCallback& done)
+ LOCKS_EXCLUDED(instance_mu_, gr->mu, group_mu_);
+
+ // Finish populating *cp. Semantics are like those of
+ // CompleteInstanceLocal but will make a remote call to the group
+ // leader if necessary.
+ void CompleteInstanceDistributed(const string& device, const GroupRec* gr,
+ CollectiveParams* cp,
+ CancellationManager* cancel_mgr,
+ const StatusCallback& done)
+ LOCKS_EXCLUDED(instance_mu_, gr->mu, group_mu_);
+
+ WorkerCacheInterface* worker_cache_; // Not owned
+ const string group_leader_;
+};
+
+} // namespace tensorflow
+#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_COLLECTIVE_PARAM_RESOLVER_DISTRIBUTED_H_
diff --git a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc
new file mode 100644
index 0000000000..95a010286d
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc
@@ -0,0 +1,324 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h"
+
+#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
+#include "tensorflow/core/distributed_runtime/test_utils.h"
+#include "tensorflow/core/framework/cancellation.h"
+#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/util/device_name_utils.h"
+
+namespace tensorflow {
+namespace {
+
+static Device* NewDevice(const string& type, const string& name) {
+ class FakeDevice : public Device {
+ public:
+ explicit FakeDevice(const DeviceAttributes& attr) : Device(nullptr, attr) {}
+ Status Sync() override { return Status::OK(); }
+ Allocator* GetAllocator(AllocatorAttributes) override { return nullptr; }
+ };
+ DeviceAttributes attr;
+ attr.set_name(name);
+ attr.set_device_type(type);
+ attr.mutable_locality()->set_numa_node(3); // a non-default value
+ return new FakeDevice(attr);
+}
+
+class FakeWorker : public TestWorkerInterface {
+ public:
+ FakeWorker(const string& name, DeviceMgr* dev_mgr,
+ CollectiveParamResolverDistributed* cpres)
+ : name_(name), device_mgr_(dev_mgr), param_resolver_(cpres) {}
+
+ void GetStatusAsync(const GetStatusRequest* request,
+ GetStatusResponse* response,
+ StatusCallback done) override {
+ std::vector<DeviceAttributes> dev_attr;
+ device_mgr_->ListDeviceAttributes(&dev_attr);
+ for (const auto& da : dev_attr) {
+ *response->add_device_attributes() = da;
+ }
+ done(Status::OK());
+ }
+
+ void CompleteGroupAsync(CallOptions* opts,
+ const CompleteGroupRequest* request,
+ CompleteGroupResponse* response,
+ StatusCallback done) override {
+ param_resolver_->CompleteGroupAsync(request, response, &cm_, done);
+ }
+
+ void CompleteInstanceAsync(CallOptions* ops,
+ const CompleteInstanceRequest* request,
+ CompleteInstanceResponse* response,
+ StatusCallback done) override {
+ param_resolver_->CompleteInstanceAsync(request, response, &cm_, done);
+ }
+
+ private:
+ string name_;
+ DeviceMgr* device_mgr_;
+ CancellationManager cm_;
+ CollectiveParamResolverDistributed* param_resolver_;
+};
+
+class FakeCache : public TestWorkerCache {
+ public:
+ // Override the Locality methods to actually pass through to the
+ // worker.
+ bool GetDeviceLocalityNonBlocking(const string& device,
+ DeviceLocality* locality) override {
+ return false;
+ }
+
+ void GetDeviceLocalityAsync(const string& device, DeviceLocality* locality,
+ StatusCallback done) override {
+ string task_name;
+ string dev_part;
+ if (!DeviceNameUtils::SplitDeviceName(device, &task_name, &dev_part)) {
+ done(errors::Internal("failed to parse device name"));
+ return;
+ }
+ auto it = workers_.find(task_name);
+ if (it == workers_.end()) {
+ done(errors::Internal("failed to find worker ", task_name));
+ return;
+ }
+ WorkerInterface* wi = it->second;
+ GetStatusRequest req;
+ GetStatusResponse resp;
+ Notification note;
+ Status status = wi->GetStatus(&req, &resp);
+ if (!status.ok()) {
+ done(status);
+ return;
+ }
+ for (const auto& it : resp.device_attributes()) {
+ if (it.name() == device) {
+ *locality = it.locality();
+ done(Status::OK());
+ return;
+ }
+ }
+ done(errors::Internal("device not found: ", device));
+ }
+};
+
+class DeviceResDistTest : public ::testing::Test {
+ protected:
+ DeviceResDistTest() {}
+
+ ~DeviceResDistTest() override {
+ for (DeviceMgr* dm : device_mgrs_) {
+ delete dm;
+ }
+ for (auto it : dev_resolvers_) {
+ delete it.second;
+ }
+ for (auto it : cp_resolvers_) {
+ delete it.second;
+ }
+ for (FakeWorker* w : workers_) {
+ delete w;
+ }
+ }
+
+ void DefineWorkers(int num_workers, int num_devices,
+ const string& device_type) {
+ ConfigProto config;
+ for (int w = 0; w < num_workers; ++w) {
+ string name = strings::StrCat("/job:worker/replica:0/task:", w);
+ // TODO(tucker): When config option becomes available, set here.
+ // if (w == 0) {
+ // config.set_collective_group_leader(name);
+ // }
+ DefineWorker(config, name, device_type, num_devices);
+ }
+ }
+
+ void DefineWorker(const ConfigProto& config, const string& worker_name,
+ const string& device_type, int num_devices) {
+ std::vector<Device*> devices;
+ for (int i = 0; i < num_devices; ++i) {
+ devices.push_back(NewDevice(
+ device_type,
+ strings::StrCat(worker_name, "/device:", device_type, ":", i)));
+ }
+ DeviceMgr* dev_mgr = new DeviceMgr(devices);
+ device_mgrs_.push_back(dev_mgr);
+ std::vector<string>* dv = &dev_by_task_[worker_name];
+ for (auto d : devices) {
+ dv->push_back(d->name());
+ }
+ DeviceResolverDistributed* dev_res =
+ new DeviceResolverDistributed(dev_mgr, &wc_, worker_name);
+ dev_resolvers_[worker_name] = dev_res;
+ CollectiveParamResolverDistributed* cp_res =
+ new CollectiveParamResolverDistributed(config, dev_mgr, dev_res, &wc_,
+ worker_name);
+ cp_resolvers_[worker_name] = cp_res;
+ FakeWorker* fw = new FakeWorker(worker_name, dev_mgr, cp_res);
+ workers_.push_back(fw);
+ wc_.AddWorker(worker_name, fw);
+ }
+
+ void DefineCollectiveParams(int num_workers, int num_devices) {
+ const int kGroupKey = 5;
+ const int kInstanceKey = 3;
+ for (int wi = 0; wi < num_workers; ++wi) {
+ string task_name = strings::StrCat("/job:worker/replica:0/task:", wi);
+ for (int di = 0; di < num_devices; ++di) {
+ string device_name = strings::StrCat(task_name, "/device:CPU:", di);
+ cp_.push_back(CollectiveParams());
+ CollectiveParams& cp = cp_.back();
+ cp.group.group_key = kGroupKey;
+ cp.group.group_size = num_workers * num_devices;
+ cp.group.device_type = DEVICE_CPU;
+ cp.group.num_tasks = num_workers;
+ cp.instance.instance_key = kInstanceKey;
+ cp.instance.type = REDUCTION_COLLECTIVE;
+ cp.instance.data_type = DT_FLOAT;
+ cp.instance.shape = TensorShape({64});
+ cp.instance.impl_details.subdiv_offsets.push_back(0);
+ }
+ }
+ }
+
+ void IssueRequests(int num_workers, int num_devices) {
+ const int device_count = num_workers * num_devices;
+ {
+ mutex_lock l(mu_);
+ num_done_ = 0;
+ }
+ cp_.resize(device_count);
+ status_.resize(device_count);
+ int idx = 0;
+ for (int wi = 0; wi < num_workers; ++wi) {
+ for (int di = 0; di < num_devices; ++di) {
+ IssueRequest(num_workers, num_devices, idx);
+ ++idx;
+ }
+ }
+ }
+
+ void IssueRequest(int num_workers, int num_devices, int idx) {
+ int device_count = num_workers * num_devices;
+ int wi = idx / num_devices;
+ int di = idx % num_devices;
+ string task_name = strings::StrCat("/job:worker/replica:0/task:", wi);
+ string device_name = strings::StrCat(task_name, "/device:CPU:", di);
+ while (idx >= cp_.size()) {
+ status_.resize(idx + 1);
+ cp_.resize(idx + 1);
+ }
+ CollectiveParams* cp = &cp_[idx];
+ CollectiveParamResolverDistributed* cp_res = cp_resolvers_[task_name];
+ CHECK(cp_res);
+ cp_res->CompleteParamsAsync(device_name, cp, &cm_,
+ [this, idx, device_count](const Status& s) {
+ status_[idx] = s;
+ {
+ mutex_lock l(mu_);
+ ++num_done_;
+ if (num_done_ == device_count) {
+ done_.notify_all();
+ }
+ }
+ });
+ }
+
+ void ValidateCollectiveParams(int num_workers, int num_devices) {
+ int device_count = num_workers * num_devices;
+ {
+ mutex_lock l(mu_);
+ if (num_done_ < device_count) {
+ done_.wait(l);
+ }
+ }
+ // Verify that all cp_ values get the same set of task and device
+ // names, with unique default_rank in the expected order.
+ const int dev_count = num_workers * num_devices;
+ for (int wi = 0; wi < num_workers; ++wi) {
+ string task_name = strings::StrCat("/job:worker/replica:0/task:", wi);
+ for (int di = 0; di < num_devices; ++di) {
+ string device_name = strings::StrCat(task_name, "/device:CPU:", di);
+ int idx = wi * num_devices + di;
+ TF_ASSERT_OK(status_[idx]);
+ EXPECT_EQ(cp_[idx].default_rank, idx);
+ EXPECT_EQ(cp_[idx].instance.device_names.size(), dev_count);
+ EXPECT_EQ(cp_[idx].instance.device_names[idx], device_name);
+ EXPECT_EQ(cp_[idx].instance.task_names[idx], task_name);
+ if (idx > 0) {
+ for (int i = 0; i < dev_count; ++i) {
+ EXPECT_EQ(cp_[0].instance.device_names[i],
+ cp_[idx].instance.device_names[i]);
+ EXPECT_EQ(cp_[0].instance.task_names[i],
+ cp_[idx].instance.task_names[i]);
+ }
+ }
+ }
+ }
+ }
+
+ FakeCache wc_;
+ CancellationManager cm_;
+ std::vector<DeviceMgr*> device_mgrs_;
+ std::unordered_map<string, DeviceResolverDistributed*> dev_resolvers_;
+ std::unordered_map<string, CollectiveParamResolverDistributed*> cp_resolvers_;
+ std::unordered_map<string, std::vector<string>> dev_by_task_;
+ std::vector<FakeWorker*> workers_;
+ std::vector<CollectiveParams> cp_;
+ std::vector<Status> status_;
+ mutex mu_;
+ int num_done_ GUARDED_BY(mu_);
+ condition_variable done_;
+};
+
+TEST_F(DeviceResDistTest, Workers1Devices1) {
+ const int num_workers = 1;
+ const int num_devices = 1;
+ DefineWorkers(num_workers, num_devices, "CPU");
+ DefineCollectiveParams(num_workers, num_devices);
+ IssueRequests(num_workers, num_devices);
+ ValidateCollectiveParams(num_workers, num_devices);
+}
+
+TEST_F(DeviceResDistTest, Workers2Devices2) {
+ const int num_workers = 2;
+ const int num_devices = 2;
+ DefineWorkers(num_workers, num_devices, "CPU");
+ DefineCollectiveParams(num_workers, num_devices);
+ IssueRequests(num_workers, num_devices);
+ ValidateCollectiveParams(num_workers, num_devices);
+}
+
+TEST_F(DeviceResDistTest, Workers4Devices3) {
+ const int num_workers = 4;
+ const int num_devices = 3;
+ DefineWorkers(num_workers, num_devices, "CPU");
+ DefineCollectiveParams(num_workers, num_devices);
+ IssueRequests(num_workers, num_devices);
+ ValidateCollectiveParams(num_workers, num_devices);
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/device_resolver_distributed.cc b/tensorflow/core/distributed_runtime/device_resolver_distributed.cc
new file mode 100644
index 0000000000..038974cb39
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/device_resolver_distributed.cc
@@ -0,0 +1,133 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
+
+#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/distributed_runtime/worker_cache.h"
+
+namespace tensorflow {
+DeviceResolverDistributed::DeviceResolverDistributed(
+ const DeviceMgr* dev_mgr, WorkerCacheInterface* worker_cache,
+ const string& task_name)
+ : dev_mgr_(dev_mgr), worker_cache_(worker_cache), task_name_(task_name) {}
+
+void DeviceResolverDistributed::GetLocalityAsync(const string& device,
+ const string& task,
+ DeviceLocality* locality,
+ const StatusCallback& done) {
+ if (task.empty() || task == task_name_) {
+ // Device is local to this task.
+ Device* dev;
+ Status s = dev_mgr_->LookupDevice(device, &dev);
+ if (s.ok()) {
+ *locality = dev->attributes().locality();
+ }
+ done(s);
+ return;
+ } else {
+ // Lookup of a remote device: first try the local cache.
+ bool found = false;
+ {
+ mutex_lock l(mu_);
+ auto it = attr_table_.find(device);
+ if (it != attr_table_.end()) {
+ *locality = it->second.locality();
+ found = true;
+ }
+ }
+ if (found) {
+ done(Status::OK());
+ return;
+ }
+ }
+ // Device is remote and no cache entry was found. Refresh the cache
+ // then retry the lookup.
+ RefreshRemoteAttributes(
+ device, task, [this, device, task, locality, done](const Status& s) {
+ if (!s.ok()) {
+ done(s);
+ } else {
+ GetLocalityAsync(device, task, locality, done);
+ }
+ });
+}
+
+void DeviceResolverDistributed::GetDeviceLocalitiesAsync(
+ const CollInstanceParams& inst_params,
+ std::vector<DeviceLocality>* localities, const StatusCallback& done) {
+ localities->clear();
+ GetDeviceLocalitiesRecursive(inst_params, localities, done);
+}
+
+void DeviceResolverDistributed::GetDeviceLocalitiesRecursive(
+ const CollInstanceParams& inst_params,
+ std::vector<DeviceLocality>* localities, const StatusCallback& done) {
+ size_t i = localities->size();
+ if (i < inst_params.device_names.size()) {
+ localities->push_back(DeviceLocality());
+ GetLocalityAsync(inst_params.device_names[i], inst_params.task_names[i],
+ &localities->back(),
+ [this, &inst_params, localities, done](const Status& s) {
+ if (!s.ok()) {
+ done(s);
+ return;
+ } else {
+ GetDeviceLocalitiesRecursive(inst_params, localities,
+ done);
+ }
+ });
+ } else {
+ done(Status::OK());
+ }
+}
+
+void DeviceResolverDistributed::RefreshRemoteAttributes(
+ const string& device, const string& task, const StatusCallback& done) {
+ GetStatusRequest* req = new GetStatusRequest;
+ GetStatusResponse* resp = new GetStatusResponse;
+ WorkerInterface* worker = worker_cache_->CreateWorker(task);
+ CHECK(worker) << "Failed to get worker for " << task;
+ worker->GetStatusAsync(
+ req, resp, [this, device, task, req, resp, worker, done](Status s) {
+ if (s.ok()) {
+ mutex_lock l(mu_);
+ for (const DeviceAttributes& da : resp->device_attributes()) {
+ attr_table_[da.name()] = da;
+ }
+ }
+ done(s);
+ delete req;
+ delete resp;
+ worker_cache_->ReleaseWorker(task, worker);
+ });
+}
+
+void DeviceResolverDistributed::ClearTask(const string& task) {
+ mutex_lock l(mu_);
+ // First find all the keys belonging to the task.
+ std::unordered_set<string> task_keys;
+ for (const auto& it : attr_table_) {
+ const string& device_name = it.first;
+ if (DeviceNameUtils::IsSameAddressSpace(task, device_name)) {
+ task_keys.insert(device_name);
+ }
+ }
+ // Then delete them.
+ for (const string& key : task_keys) {
+ attr_table_.erase(key);
+ }
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/device_resolver_distributed.h b/tensorflow/core/distributed_runtime/device_resolver_distributed.h
new file mode 100644
index 0000000000..ac68ec6873
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/device_resolver_distributed.h
@@ -0,0 +1,67 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_DEVICE_RESOLVER_DISTRIBUTED_H_
+#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_DEVICE_RESOLVER_DISTRIBUTED_H_
+
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/framework/collective.h"
+#include "tensorflow/core/framework/device_attributes.pb.h"
+#include "tensorflow/core/lib/gtl/flatmap.h"
+
+namespace tensorflow {
+class DeviceMgr;
+class WorkerCacheInterface;
+
+class DeviceResolverDistributed : public DeviceResolverInterface {
+ public:
+ DeviceResolverDistributed(const DeviceMgr* dev_mgr,
+ WorkerCacheInterface* worker_cache,
+ const string& task_name);
+
+ virtual ~DeviceResolverDistributed() {}
+
+ void GetDeviceLocalitiesAsync(const CollInstanceParams& inst_params,
+ std::vector<DeviceLocality>* localities,
+ const StatusCallback& done) override;
+
+ void GetLocalityAsync(const string& device, const string& task,
+ DeviceLocality* locality,
+ const StatusCallback& done) override;
+
+ void ClearTask(const string& task) override;
+
+ protected:
+ // Loads attr_table_ with device attributes retrieved from remote task.
+ void RefreshRemoteAttributes(const string& device, const string& task,
+ const StatusCallback& done) LOCKS_EXCLUDED(mu_);
+
+ // Subroutine used by GetDeviceLocalitiesAsync. Recursively extends
+ // *localities with DeviceLocality of the corresponding device named
+ // by inst_params.instance.device_names.
+ void GetDeviceLocalitiesRecursive(const CollInstanceParams& inst_params,
+ std::vector<DeviceLocality>* localities,
+ const StatusCallback& done);
+
+ const DeviceMgr* dev_mgr_; // Not owned
+ WorkerCacheInterface* worker_cache_; // Not owned
+ const string task_name_;
+ mutex mu_;
+ gtl::FlatMap<string, DeviceAttributes> attr_table_ GUARDED_BY(mu_);
+};
+
+} // namespace tensorflow
+#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_DEVICE_RESOLVER_DISTRIBUTED_H_
diff --git a/tensorflow/core/distributed_runtime/device_resolver_distributed_test.cc b/tensorflow/core/distributed_runtime/device_resolver_distributed_test.cc
new file mode 100644
index 0000000000..ae44b98bd5
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/device_resolver_distributed_test.cc
@@ -0,0 +1,217 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
+
+#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/distributed_runtime/test_utils.h"
+#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/util/device_name_utils.h"
+
+namespace tensorflow {
+namespace {
+
+// Subclass of DeviceResolverDistributed which behaves identically but
+// allows access to the attr_table_.
+class TestableDeviceResolverDistributed : public DeviceResolverDistributed {
+ public:
+ TestableDeviceResolverDistributed(const DeviceMgr* dev_mgr,
+ WorkerCacheInterface* worker_cache,
+ const string& task)
+ : DeviceResolverDistributed(dev_mgr, worker_cache, task) {}
+
+ gtl::FlatMap<string, DeviceAttributes>& attr_table() { return attr_table_; }
+};
+
+// Create a fake 'Device' whose only interesting attribute is a non-default
+// DeviceLocality.
+static Device* NewDevice(const string& type, const string& name,
+ int numa_node) {
+ class FakeDevice : public Device {
+ public:
+ explicit FakeDevice(const DeviceAttributes& attr) : Device(nullptr, attr) {}
+ Status Sync() override { return Status::OK(); }
+ Allocator* GetAllocator(AllocatorAttributes) override { return nullptr; }
+ };
+ DeviceAttributes attr;
+ attr.set_name(name);
+ attr.set_device_type(type);
+ attr.mutable_locality()->set_numa_node(numa_node);
+ return new FakeDevice(attr);
+}
+
+// Create a fake WorkerInterface that responds to requests without RPCs,
+// in this case returning the DeviceAttributes of a fake remote worker.
+class FakeWorker : public TestWorkerInterface {
+ public:
+ FakeWorker(const string& name, DeviceMgr* dev_mgr,
+ DeviceResolverDistributed* dres)
+ : name_(name), device_mgr_(dev_mgr), device_resolver_(dres) {}
+
+ void GetStatusAsync(const GetStatusRequest* request,
+ GetStatusResponse* response,
+ StatusCallback done) override {
+ std::vector<DeviceAttributes> dev_attr;
+ device_mgr_->ListDeviceAttributes(&dev_attr);
+ for (const auto& da : dev_attr) {
+ *response->add_device_attributes() = da;
+ }
+ done(Status::OK());
+ }
+
+ private:
+ string name_;
+ DeviceMgr* device_mgr_;
+ DeviceResolverDistributed* device_resolver_;
+};
+
+// An implementation of WorkerCacheInterface that routes all requests
+// to local FakeWorkers, implementing only the methods needed for tests.
+class FakeCache : public TestWorkerCache {
+ public:
+ // Override the Locality methods to actually pass through to the
+ // worker.
+ bool GetDeviceLocalityNonBlocking(const string& device,
+ DeviceLocality* locality) override {
+ return false;
+ }
+
+ void GetDeviceLocalityAsync(const string& device, DeviceLocality* locality,
+ StatusCallback done) override {
+ string task_name;
+ string dev_part;
+ if (!DeviceNameUtils::SplitDeviceName(device, &task_name, &dev_part)) {
+ done(errors::Internal("failed to parse device name"));
+ return;
+ }
+ auto it = workers_.find(task_name);
+ if (it == workers_.end()) {
+ done(errors::Internal("failed to find worker ", task_name));
+ return;
+ }
+ WorkerInterface* wi = it->second;
+ GetStatusRequest req;
+ GetStatusResponse resp;
+ Notification note;
+ Status status = wi->GetStatus(&req, &resp);
+ if (!status.ok()) {
+ done(status);
+ return;
+ }
+ for (const auto& it : resp.device_attributes()) {
+ if (it.name() == device) {
+ *locality = it.locality();
+ done(Status::OK());
+ return;
+ }
+ }
+ done(errors::Internal("device not found: ", device));
+ }
+};
+
+class DeviceResDistTest : public ::testing::Test {
+ protected:
+ DeviceResDistTest() {}
+
+ ~DeviceResDistTest() override {
+ for (DeviceMgr* dm : device_mgrs_) {
+ delete dm;
+ }
+ for (auto it : resolvers_) {
+ delete it.second;
+ }
+ for (FakeWorker* w : workers_) {
+ delete w;
+ }
+ }
+
+ void DefineWorkers(int num_workers, int num_devices,
+ const string& device_type) {
+ for (int w = 0; w < num_workers; ++w) {
+ string name = strings::StrCat("/job:worker/replica:0/task:", w);
+ DefineWorker(name, device_type, num_devices);
+ }
+ }
+
+ void DefineWorker(const string& worker_name, const string& device_type,
+ int num_devices) {
+ std::vector<Device*> devices;
+ for (int i = 0; i < num_devices; ++i) {
+ devices.push_back(NewDevice(
+ device_type,
+ strings::StrCat(worker_name, "/device:", device_type, ":", i), i));
+ }
+ DeviceMgr* dev_mgr = new DeviceMgr(devices);
+ TestableDeviceResolverDistributed* dev_res =
+ new TestableDeviceResolverDistributed(dev_mgr, &wc_, worker_name);
+ resolvers_[worker_name] = dev_res;
+ device_mgrs_.push_back(dev_mgr);
+ std::vector<string>* dv = &dev_by_task_[worker_name];
+ for (auto d : devices) {
+ dv->push_back(d->name());
+ }
+ FakeWorker* fw = new FakeWorker(worker_name, dev_mgr, dev_res);
+ workers_.push_back(fw);
+ wc_.AddWorker(worker_name, fw);
+ }
+
+ FakeCache wc_;
+ std::vector<DeviceMgr*> device_mgrs_;
+ std::unordered_map<string, TestableDeviceResolverDistributed*> resolvers_;
+ std::unordered_map<string, std::vector<string>> dev_by_task_;
+ std::vector<FakeWorker*> workers_;
+};
+
+TEST_F(DeviceResDistTest, Workers3Devices4) {
+ DefineWorkers(3, 4, "CPU");
+ // Check that every device is available from every task.
+ for (auto it : resolvers_) {
+ DeviceResolverDistributed* dres = it.second;
+ for (auto it2 : dev_by_task_) {
+ const string& task_name = it2.first;
+ for (const auto& dev_name : it2.second) {
+ DeviceNameUtils::ParsedName parsed;
+ ASSERT_TRUE(DeviceNameUtils::ParseFullName(dev_name, &parsed));
+ Notification note;
+ Status status;
+ DeviceLocality locality;
+ dres->GetLocalityAsync(dev_name, task_name, &locality,
+ [this, &note, &status](const Status& s) {
+ status = s;
+ note.Notify();
+ });
+ note.WaitForNotification();
+ TF_EXPECT_OK(status);
+ EXPECT_EQ(parsed.id, locality.numa_node());
+ }
+ }
+ }
+ // Clear just task 0 from all.
+ const string w0_name = "/job:worker/replica:0/task:0";
+ for (auto it : resolvers_) {
+ if (it.first == w0_name) continue;
+ TestableDeviceResolverDistributed* dres = it.second;
+ EXPECT_EQ(8, it.second->attr_table().size());
+ dres->ClearTask("/job:worker/replica:0/task:0");
+ EXPECT_EQ(4, it.second->attr_table().size());
+ }
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc b/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc
index 895bbd97b7..5b7b74ce63 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc
@@ -56,6 +56,9 @@ class GrpcRemoteWorker : public WorkerInterface {
recvtensor_(Method(GrpcWorkerMethod::kRecvTensor)),
logging_(Method(GrpcWorkerMethod::kLogging)),
tracing_(Method(GrpcWorkerMethod::kTracing)),
+ completegroup_(Method(GrpcWorkerMethod::kCompleteGroup)),
+ instancesource_(Method(GrpcWorkerMethod::kCompleteInstance)),
+ getstepsequence_(Method(GrpcWorkerMethod::kGetStepSequence)),
logger_(logger) {}
~GrpcRemoteWorker() override {}
@@ -115,6 +118,27 @@ class GrpcRemoteWorker : public WorkerInterface {
IssueRequest(request, response, cleanupall_, std::move(done));
}
+ void CompleteGroupAsync(CallOptions* call_opts,
+ const CompleteGroupRequest* request,
+ CompleteGroupResponse* response,
+ StatusCallback done) override {
+ IssueRequest(request, response, completegroup_, std::move(done), call_opts);
+ }
+
+ void CompleteInstanceAsync(CallOptions* call_opts,
+ const CompleteInstanceRequest* request,
+ CompleteInstanceResponse* response,
+ StatusCallback done) override {
+ IssueRequest(request, response, instancesource_, std::move(done),
+ call_opts);
+ }
+
+ void GetStepSequenceAsync(const GetStepSequenceRequest* request,
+ GetStepSequenceResponse* response,
+ StatusCallback done) override {
+ IssueRequest(request, response, getstepsequence_, std::move(done));
+ }
+
void RecvTensorAsync(CallOptions* call_opts, const RecvTensorRequest* request,
TensorResponse* response, StatusCallback done) override {
VLOG(1) << "RecvTensorAsync req: " << request->DebugString();
@@ -217,6 +241,9 @@ class GrpcRemoteWorker : public WorkerInterface {
const ::grpc::string recvtensor_;
const ::grpc::string logging_;
const ::grpc::string tracing_;
+ const ::grpc::string completegroup_;
+ const ::grpc::string instancesource_;
+ const ::grpc::string getstepsequence_;
// Support for logging.
WorkerCacheLogger* logger_;
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc
index b20e744a97..bbf7391377 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc
@@ -172,6 +172,12 @@ class GrpcWorkerService : public AsyncServiceInterface {
ENQUEUE_REQUEST(Logging, false);
ENQUEUE_REQUEST(Tracing, false);
+ for (int i = 0; i < 10; ++i) {
+ ENQUEUE_REQUEST(CompleteGroup, false);
+ ENQUEUE_REQUEST(CompleteInstance, false);
+ ENQUEUE_REQUEST(GetStepSequence, false);
+ }
+
void* tag;
bool ok;
@@ -318,6 +324,47 @@ class GrpcWorkerService : public AsyncServiceInterface {
});
ENQUEUE_REQUEST(Tracing, false);
}
+
+ void CompleteGroupHandler(
+ WorkerCall<CompleteGroupRequest, CompleteGroupResponse>* call) {
+ Schedule([this, call]() {
+ CallOptions* call_opts = new CallOptions;
+ call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });
+ worker_->CompleteGroupAsync(call_opts, &call->request, &call->response,
+ [call, call_opts](const Status& s) {
+ call->ClearCancelCallback();
+ delete call_opts;
+ call->SendResponse(ToGrpcStatus(s));
+ });
+ });
+ ENQUEUE_REQUEST(CompleteGroup, false);
+ }
+
+ void CompleteInstanceHandler(
+ WorkerCall<CompleteInstanceRequest, CompleteInstanceResponse>* call) {
+ Schedule([this, call]() {
+ CallOptions* call_opts = new CallOptions;
+ call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });
+ worker_->CompleteInstanceAsync(call_opts, &call->request,
+ &call->response,
+ [call, call_opts](const Status& s) {
+ call->ClearCancelCallback();
+ delete call_opts;
+ call->SendResponse(ToGrpcStatus(s));
+ });
+ });
+ ENQUEUE_REQUEST(CompleteInstance, false);
+ }
+
+ void GetStepSequenceHandler(
+ WorkerCall<GetStepSequenceRequest, GetStepSequenceResponse>* call) {
+ Schedule([this, call]() {
+ worker_->GetStepSequenceAsync(
+ &call->request, &call->response,
+ [call](const Status& s) { call->SendResponse(ToGrpcStatus(s)); });
+ });
+ ENQUEUE_REQUEST(GetStepSequence, false);
+ }
#undef ENQUEUE_REQUEST
void EnqueueRecvTensorRequestRaw() {
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.cc b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.cc
index 05a9db10d3..a91cc0692a 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.cc
@@ -50,6 +50,12 @@ const char* GrpcWorkerMethodName(GrpcWorkerMethod id) {
return "/tensorflow.WorkerService/Logging";
case GrpcWorkerMethod::kTracing:
return "/tensorflow.WorkerService/Tracing";
+ case GrpcWorkerMethod::kCompleteGroup:
+ return "/tensorflow.WorkerService/CompleteGroup";
+ case GrpcWorkerMethod::kCompleteInstance:
+ return "/tensorflow.WorkerService/CompleteInstance";
+ case GrpcWorkerMethod::kGetStepSequence:
+ return "/tensorflow.WorkerService/GetStepSequence";
}
// Shouldn't be reached.
LOG(FATAL) << "Invalid id: this line shouldn't be reached.";
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h
index a54ea93796..c5104c6a50 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h
@@ -83,9 +83,12 @@ enum class GrpcWorkerMethod {
kRecvTensor,
kLogging,
kTracing,
+ kCompleteGroup,
+ kCompleteInstance,
+ kGetStepSequence,
};
static const int kGrpcNumWorkerMethods =
- static_cast<int>(GrpcWorkerMethod::kTracing) + 1;
+ static_cast<int>(GrpcWorkerMethod::kGetStepSequence) + 1;
const char* GrpcWorkerMethodName(GrpcWorkerMethod id);
diff --git a/tensorflow/core/distributed_runtime/test_utils.h b/tensorflow/core/distributed_runtime/test_utils.h
new file mode 100644
index 0000000000..0ed078241f
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/test_utils.h
@@ -0,0 +1,173 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_TEST_UTILS_H_
+#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_TEST_UTILS_H_
+
+#include <unordered_map>
+#include "tensorflow/core/distributed_runtime/worker_cache.h"
+#include "tensorflow/core/distributed_runtime/worker_interface.h"
+
+namespace tensorflow {
+
+// Some utilities for testing distributed-mode components in a single process
+// without RPCs.
+
+// Implements the worker interface with methods that just respond with
+// "unimplemented" status. Override just the methods needed for
+// testing.
+class TestWorkerInterface : public WorkerInterface {
+ public:
+ void GetStatusAsync(const GetStatusRequest* request,
+ GetStatusResponse* response,
+ StatusCallback done) override {
+ done(errors::Unimplemented("GetStatusAsync"));
+ }
+
+ void CreateWorkerSessionAsync(const CreateWorkerSessionRequest* request,
+ CreateWorkerSessionResponse* response,
+ StatusCallback done) override {
+ done(errors::Unimplemented("CreateWorkerSessionAsync"));
+ }
+
+ void DeleteWorkerSessionAsync(CallOptions* opts,
+ const DeleteWorkerSessionRequest* request,
+ DeleteWorkerSessionResponse* response,
+ StatusCallback done) override {
+ done(errors::Unimplemented("DeleteWorkerSessionAsync"));
+ }
+
+ void RegisterGraphAsync(const RegisterGraphRequest* request,
+ RegisterGraphResponse* response,
+ StatusCallback done) override {
+ done(errors::Unimplemented("RegisterGraphAsync"));
+ }
+
+ void DeregisterGraphAsync(const DeregisterGraphRequest* request,
+ DeregisterGraphResponse* response,
+ StatusCallback done) override {
+ done(errors::Unimplemented("DeregisterGraphAsync"));
+ }
+
+ void RunGraphAsync(CallOptions* opts, RunGraphRequestWrapper* request,
+ MutableRunGraphResponseWrapper* repsonse,
+ StatusCallback done) override {
+ done(errors::Unimplemented("RunGraphAsync"));
+ }
+
+ void CleanupGraphAsync(const CleanupGraphRequest* request,
+ CleanupGraphResponse* response,
+ StatusCallback done) override {
+ done(errors::Unimplemented("RunGraphAsync"));
+ }
+
+ void CleanupAllAsync(const CleanupAllRequest* request,
+ CleanupAllResponse* response,
+ StatusCallback done) override {
+ done(errors::Unimplemented("RunGraphAsync"));
+ }
+
+ void RecvTensorAsync(CallOptions* opts, const RecvTensorRequest* request,
+ TensorResponse* response, StatusCallback done) override {
+ done(errors::Unimplemented("RunGraphAsync"));
+ }
+
+ void LoggingAsync(const LoggingRequest* request, LoggingResponse* response,
+ StatusCallback done) override {
+ done(errors::Unimplemented("RunGraphAsync"));
+ }
+
+ void TracingAsync(const TracingRequest* request, TracingResponse* response,
+ StatusCallback done) override {
+ done(errors::Unimplemented("RunGraphAsync"));
+ }
+
+ void CompleteGroupAsync(CallOptions* opts,
+ const CompleteGroupRequest* request,
+ CompleteGroupResponse* response,
+ StatusCallback done) override {
+ done(errors::Unimplemented("RunGraphAsync"));
+ }
+
+ void CompleteInstanceAsync(CallOptions* ops,
+ const CompleteInstanceRequest* request,
+ CompleteInstanceResponse* response,
+ StatusCallback done) override {
+ done(errors::Unimplemented("RunGraphAsync"));
+ }
+
+ void GetStepSequenceAsync(const GetStepSequenceRequest* request,
+ GetStepSequenceResponse* response,
+ StatusCallback done) override {
+ done(errors::Unimplemented("RunGraphAsync"));
+ }
+};
+
+class TestWorkerCache : public WorkerCacheInterface {
+ public:
+ virtual ~TestWorkerCache() {}
+
+ void AddWorker(const string& target, WorkerInterface* wi) {
+ workers_[target] = wi;
+ }
+
+ void AddDevice(const string& device_name, const DeviceLocality& dev_loc) {
+ localities_[device_name] = dev_loc;
+ }
+
+ void ListWorkers(std::vector<string>* workers) const override {
+ workers->clear();
+ for (auto it : workers_) {
+ workers->push_back(it.first);
+ }
+ }
+
+ WorkerInterface* CreateWorker(const string& target) override {
+ auto it = workers_.find(target);
+ if (it != workers_.end()) {
+ return it->second;
+ }
+ return nullptr;
+ }
+
+ void ReleaseWorker(const string& target, WorkerInterface* worker) override {}
+
+ bool GetDeviceLocalityNonBlocking(const string& device,
+ DeviceLocality* locality) override {
+ auto it = localities_.find(device);
+ if (it != localities_.end()) {
+ *locality = it->second;
+ return true;
+ }
+ return false;
+ }
+
+ void GetDeviceLocalityAsync(const string& device, DeviceLocality* locality,
+ StatusCallback done) override {
+ auto it = localities_.find(device);
+ if (it != localities_.end()) {
+ *locality = it->second;
+ done(Status::OK());
+ return;
+ }
+ done(errors::Internal("Device not found: ", device));
+ }
+
+ protected:
+ std::unordered_map<string, WorkerInterface*> workers_;
+ std::unordered_map<string, DeviceLocality> localities_;
+};
+
+} // namespace tensorflow
+#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_TEST_UTILS_H_
diff --git a/tensorflow/core/distributed_runtime/worker.cc b/tensorflow/core/distributed_runtime/worker.cc
index e9073ef9f6..d682ac8f34 100644
--- a/tensorflow/core/distributed_runtime/worker.cc
+++ b/tensorflow/core/distributed_runtime/worker.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/worker.h"
+#include "tensorflow/core/common_runtime/collective_executor_mgr.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/common_runtime/step_stats_collector.h"
@@ -25,8 +26,7 @@ limitations under the License.
namespace tensorflow {
-Worker::Worker(WorkerEnv* env)
- : env_(env), cancellation_manager_(new CancellationManager) {}
+Worker::Worker(WorkerEnv* env) : env_(env) {}
void Worker::GetStatusAsync(const GetStatusRequest* request,
GetStatusResponse* response, StatusCallback done) {
@@ -185,19 +185,16 @@ void Worker::DoRunGraph(CallOptions* opts, RunGraphRequestWrapper* request,
AbortStep(step_id);
});
CancellationToken token;
- {
- mutex_lock l(mu_);
- token = cancellation_manager_->get_cancellation_token();
- bool already_cancelled = !cancellation_manager_->RegisterCallback(
- token, [cm]() { cm->StartCancel(); });
- if (already_cancelled) {
- opts->ClearCancelCallback();
- delete cm;
- delete collector;
- delete out;
- done(errors::Aborted("Call was aborted"));
- return;
- }
+ token = cancellation_manager_.get_cancellation_token();
+ bool already_cancelled = !cancellation_manager_.RegisterCallback(
+ token, [cm]() { cm->StartCancel(); });
+ if (already_cancelled) {
+ opts->ClearCancelCallback();
+ delete cm;
+ delete collector;
+ delete out;
+ done(errors::Aborted("Call was aborted"));
+ return;
}
session->graph_mgr->ExecuteAsync(
request->graph_handle(), step_id, session.get(), request->exec_opts(),
@@ -208,10 +205,7 @@ void Worker::DoRunGraph(CallOptions* opts, RunGraphRequestWrapper* request,
s = session->graph_mgr->RecvOutputs(step_id, out);
}
opts->ClearCancelCallback();
- {
- mutex_lock l(mu_);
- cancellation_manager_->DeregisterCallback(token);
- }
+ cancellation_manager_.DeregisterCallback(token);
delete cm;
if (s.ok()) {
@@ -276,20 +270,14 @@ void Worker::DoPartialRunGraph(CallOptions* opts,
// executors.
if (is_new_partial_run) {
CancellationToken token;
- {
- mutex_lock l(mu_);
- token = cancellation_manager_->get_cancellation_token();
- cancellation_manager_->RegisterCallback(token,
- [cm]() { cm->StartCancel(); });
- }
+ token = cancellation_manager_.get_cancellation_token();
+ cancellation_manager_.RegisterCallback(token,
+ [cm]() { cm->StartCancel(); });
session->graph_mgr->ExecuteAsync(
graph_handle, step_id, session.get(), request->exec_opts(),
nullptr /* collector */, nullptr /* response */, cm, in,
[this, token, step_id, session](Status s) {
- {
- mutex_lock l(mu_);
- cancellation_manager_->DeregisterCallback(token);
- }
+ cancellation_manager_.DeregisterCallback(token);
partial_run_mgr_.ExecutorDone(step_id, s);
});
} else {
@@ -324,6 +312,9 @@ void Worker::CleanupGraphAsync(const CleanupGraphRequest* request,
StatusCallback done) {
const int64 step_id = request->step_id();
env_->rendezvous_mgr->Cleanup(step_id);
+ if (env_->collective_executor_mgr) {
+ env_->collective_executor_mgr->Cleanup(step_id);
+ }
done(Status::OK());
}
@@ -346,6 +337,44 @@ void Worker::TracingAsync(const TracingRequest* request,
done(errors::Unimplemented("Tracing"));
}
+void Worker::CompleteGroupAsync(CallOptions* opts,
+ const CompleteGroupRequest* request,
+ CompleteGroupResponse* response,
+ StatusCallback done) {
+ if (env_->collective_executor_mgr) {
+ env_->collective_executor_mgr->GetParamResolver()->CompleteGroupAsync(
+ request, response, &cancellation_manager_, done);
+ } else {
+ done(
+ errors::Internal("Runtime not initialized with CollectiveExecutorMgr"));
+ }
+}
+
+void Worker::CompleteInstanceAsync(CallOptions* opts,
+ const CompleteInstanceRequest* request,
+ CompleteInstanceResponse* response,
+ StatusCallback done) {
+ if (env_->collective_executor_mgr) {
+ env_->collective_executor_mgr->GetParamResolver()->CompleteInstanceAsync(
+ request, response, &cancellation_manager_, done);
+ } else {
+ done(
+ errors::Internal("Runtime not initialized with CollectiveExecutorMgr"));
+ }
+}
+
+void Worker::GetStepSequenceAsync(const GetStepSequenceRequest* request,
+ GetStepSequenceResponse* response,
+ StatusCallback done) {
+ if (env_->collective_executor_mgr) {
+ env_->collective_executor_mgr->GetStepSequenceAsync(request, response,
+ done);
+ } else {
+ done(
+ errors::Internal("Runtime not initialized with CollectiveExecutorMgr"));
+ }
+}
+
// Helper for RecvTensor. Validates "key" and returns the source
// device in "*src_dev".
Status Worker::PrepareRecvTensor(const Rendezvous::ParsedKey& parsed,
diff --git a/tensorflow/core/distributed_runtime/worker.h b/tensorflow/core/distributed_runtime/worker.h
index 19aeeb752c..b5a9ada502 100644
--- a/tensorflow/core/distributed_runtime/worker.h
+++ b/tensorflow/core/distributed_runtime/worker.h
@@ -90,6 +90,20 @@ class Worker : public WorkerInterface {
void TracingAsync(const TracingRequest* request, TracingResponse* response,
StatusCallback done) override;
+ void CompleteGroupAsync(CallOptions* opts,
+ const CompleteGroupRequest* request,
+ CompleteGroupResponse* response,
+ StatusCallback done) override;
+
+ void CompleteInstanceAsync(CallOptions* opts,
+ const CompleteInstanceRequest* request,
+ CompleteInstanceResponse* response,
+ StatusCallback done) override;
+
+ void GetStepSequenceAsync(const GetStepSequenceRequest* request,
+ GetStepSequenceResponse* response,
+ StatusCallback done) override;
+
protected:
WorkerEnv* const env_; // Not owned.
@@ -101,8 +115,7 @@ class Worker : public WorkerInterface {
private:
PartialRunMgr partial_run_mgr_;
- mutex mu_;
- CancellationManager* cancellation_manager_ GUARDED_BY(mu_);
+ CancellationManager cancellation_manager_;
Status PrepareRunGraph(RunGraphRequestWrapper* req,
GraphMgr::NamedTensors* in,
diff --git a/tensorflow/core/distributed_runtime/worker_env.h b/tensorflow/core/distributed_runtime/worker_env.h
index 793d58c8a1..93d933bfa6 100644
--- a/tensorflow/core/distributed_runtime/worker_env.h
+++ b/tensorflow/core/distributed_runtime/worker_env.h
@@ -25,6 +25,7 @@ namespace thread {
class ThreadPool;
} // namespace thread
+class CollectiveExecutorMgrInterface;
class Device;
class DeviceMgr;
class Env;
@@ -57,6 +58,10 @@ struct WorkerEnv {
// A set of rendezvous keyed by step ids.
RendezvousMgrInterface* rendezvous_mgr = nullptr;
+ // Generates per-step CollectiveExecutors and has access to utilities
+ // supporting collective operations.
+ CollectiveExecutorMgrInterface* collective_executor_mgr = nullptr;
+
// A pool of threads for scheduling compute work.
thread::ThreadPool* compute_pool = nullptr;
};
diff --git a/tensorflow/core/distributed_runtime/worker_interface.h b/tensorflow/core/distributed_runtime/worker_interface.h
index a1597ee798..bad31d27b2 100644
--- a/tensorflow/core/distributed_runtime/worker_interface.h
+++ b/tensorflow/core/distributed_runtime/worker_interface.h
@@ -112,6 +112,20 @@ class WorkerInterface {
virtual void TracingAsync(const TracingRequest* request,
TracingResponse* response, StatusCallback done) = 0;
+ virtual void CompleteGroupAsync(CallOptions* opts,
+ const CompleteGroupRequest* request,
+ CompleteGroupResponse* response,
+ StatusCallback done) = 0;
+
+ virtual void CompleteInstanceAsync(CallOptions* ops,
+ const CompleteInstanceRequest* request,
+ CompleteInstanceResponse* response,
+ StatusCallback done) = 0;
+
+ virtual void GetStepSequenceAsync(const GetStepSequenceRequest* request,
+ GetStepSequenceResponse* response,
+ StatusCallback done) = 0;
+
Status GetStatus(const GetStatusRequest* request,
GetStatusResponse* response) {
return CallAndWait(&ME::GetStatusAsync, request, response);
@@ -156,6 +170,11 @@ class WorkerInterface {
return CallAndWait(&ME::TracingAsync, request, response);
}
+ Status GetStepSequence(const GetStepSequenceRequest* request,
+ GetStepSequenceResponse* response) {
+ return CallAndWait(&ME::GetStepSequenceAsync, request, response);
+ }
+
protected:
// Instances of WorkerInterface must be deleted by a call to
// WorkerCacheInterface::ReleaseWorker().
diff --git a/tensorflow/core/framework/tracking_allocator.h b/tensorflow/core/framework/tracking_allocator.h
index f6c3c0b71b..661c28969e 100644
--- a/tensorflow/core/framework/tracking_allocator.h
+++ b/tensorflow/core/framework/tracking_allocator.h
@@ -18,7 +18,6 @@ limitations under the License.
#include <unordered_map>
#include "tensorflow/core/framework/allocator.h"
-#include "tensorflow/core/framework/step_stats.pb.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/mutex.h"
diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc
index fb8a6c39e6..eeb6c60f71 100644
--- a/tensorflow/core/graph/graph.cc
+++ b/tensorflow/core/graph/graph.cc
@@ -79,6 +79,7 @@ const std::unordered_map<string, Node::NodeClass>& Node::kNodeClassTable =
{"Size", NC_METADATA},
{"Shape", NC_METADATA},
{"Rank", NC_METADATA},
+ {"_ScopedAllocator", NC_SCOPED_ALLOCATOR},
});
#undef REF_CLASS
diff --git a/tensorflow/core/graph/graph.h b/tensorflow/core/graph/graph.h
index f7ca7d0620..83a69e6b2d 100644
--- a/tensorflow/core/graph/graph.h
+++ b/tensorflow/core/graph/graph.h
@@ -34,8 +34,8 @@ limitations under the License.
// between output O of layer A and input I of layer B using
// "input index" and "output index" labels per edge.
-#ifndef TENSORFLOW_GRAPH_GRAPH_H_
-#define TENSORFLOW_GRAPH_GRAPH_H_
+#ifndef TENSORFLOW_CORE_GRAPH_GRAPH_H_
+#define TENSORFLOW_CORE_GRAPH_GRAPH_H_
#include <functional>
#include <string>
@@ -162,6 +162,7 @@ class Node {
}
bool IsHostSend() const { return class_ == NC_HOST_SEND; }
bool IsHostRecv() const { return class_ == NC_HOST_RECV; }
+ bool IsScopedAllocator() const { return class_ == NC_SCOPED_ALLOCATOR; }
bool IsMetadata() const { return class_ == NC_METADATA; }
@@ -233,6 +234,7 @@ class Node {
NC_GET_SESSION_TENSOR,
NC_DELETE_SESSION_TENSOR,
NC_METADATA,
+ NC_SCOPED_ALLOCATOR,
NC_OTHER // Not a special kind of node
};
@@ -696,6 +698,8 @@ inline bool IsControlFlow(const Node* n) { return n->IsControlFlow(); }
// (shape). Specifically, returns true for "Size", "Shape" and "Rank" ops.
inline bool IsMetadata(const Node* n) { return n->IsMetadata(); }
+inline bool IsScopedAllocator(const Node* n) { return n->IsScopedAllocator(); }
+
inline bool IsHostMemoryPreserving(const Node* node) {
return IsIdentity(node) || IsControlFlow(node);
}
@@ -827,4 +831,4 @@ inline const string& Node::assigned_device_name() const {
} // namespace tensorflow
-#endif // TENSORFLOW_GRAPH_GRAPH_H_
+#endif // TENSORFLOW_CORE_GRAPH_GRAPH_H_
diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc
index 313f63149d..2c7b57971a 100644
--- a/tensorflow/core/grappler/costs/graph_properties.cc
+++ b/tensorflow/core/grappler/costs/graph_properties.cc
@@ -256,18 +256,14 @@ typename DisjointSet<Handle>::Rep* DisjointSet<Handle>::Find(Handle value) {
return root;
}
-bool IsQueue(const NodeDef& node) {
- return str_util::EndsWith(node.op(), "QueueV2");
+bool IsEnqueue(const NodeDef& n) {
+ return (n.op().find("Enqueue") != std::string::npos &&
+ n.op().find("EnqueueMany") == std::string::npos);
}
-// Returns true if the node is an Enter op AND its input is a Queue.
-bool IsEnterWithQueue(const NodeDef& node, const GraphView& graph) {
- if (IsEnter(node)) {
- GraphView::InputPort input(&node, 0);
- GraphView::OutputPort fanin = graph.GetRegularFanin(input);
- return IsQueue(*fanin.node);
- }
- return false;
+bool IsDequeue(const NodeDef& n) {
+ return (n.op().find("Dequeue") != std::string::npos &&
+ n.op().find("DequeueMany") == std::string::npos);
}
bool HasAnyUnknownDimensions(const TensorShapeProto& proto) {
@@ -386,7 +382,7 @@ class TopoQueue {
std::set<const NodeDef*, CompareNodes> queue_;
};
-// Merge and relax symbolic shapes.
+// Processes symbolic shapes.
// Each symbolic shape or dimension is represented by a handle. Unlike the TF
// shape refiner which creates new handles every time it processes an unknown
// shape/dimension, the symbolic shape refiner assigns a specific handle to each
@@ -428,7 +424,8 @@ class SymbolicShapeRefiner {
}
return it->second.inference_context.get();
}
- Status UpdateNode(const NodeDef* node, bool relax, bool* refined) {
+
+ Status UpdateNode(const NodeDef* node, bool* refined) {
NodeContext* node_context = GetNodeContext(node);
if (node_context == nullptr) {
TF_RETURN_IF_ERROR(AddNode(node));
@@ -519,8 +516,12 @@ class SymbolicShapeRefiner {
}
}
+ // Make sure we schedule the fanout of resources (which have no input)
+ // whenever the resources are updated.
+ *refined |= inference_context->num_inputs() == 0;
+
if (!*refined) {
- // No input shape has changed, we're done
+ // No input shape has changed, we're done.
return Status::OK();
}
@@ -574,51 +575,6 @@ class SymbolicShapeRefiner {
};
// Compute the shape of the tensors outputed by node 'node' at output port
- // 'port_index' as the intersection of shape1 and shape2.
- ShapeHandle OutputAsIntersection(const NodeDef* node, int port_index,
- ShapeHandle shape1, ShapeHandle shape2) {
- if (shape1.SameHandle(shape2)) {
- return shape1;
- }
- InferenceContext* ctx = GetContext(node);
- ShapeHandle merged = shape1;
- if (!ctx->RankKnown(shape2) && !ctx->RankKnown(shape1)) {
- // Return either one since they're expected to represent the same value.
- return shape1;
- } else if (!ctx->RankKnown(shape2) && ctx->RankKnown(shape1)) {
- return shape1;
- } else if (ctx->RankKnown(shape2) && !ctx->RankKnown(shape1)) {
- return shape2;
- } else {
- const int rank = ctx->Rank(shape1);
- if (ctx->Rank(shape2) != rank) {
- // We detected an inconsistency, return an unknown shape. This can
- // happen in the fanout of a merge node since during the initial
- // propagation we optimistically assume that all the inputs to the merge
- // node have the same shape.
- return GetUnknownOutputShape(node, port_index);
- }
- for (int d = 0; d < rank; ++d) {
- if (!ctx->Dim(shape1, d).SameHandle(ctx->Dim(shape2, d))) {
- if (ctx->Value(ctx->Dim(shape1, d)) !=
- ctx->Value(ctx->Dim(shape2, d))) {
- DimensionHandle new_dim;
- if (ctx->Value(ctx->Dim(shape1, d)) < 0) {
- new_dim = ctx->Dim(shape2, d);
- } else if (ctx->Value(ctx->Dim(shape2, d)) < 0) {
- new_dim = ctx->Dim(shape1, d);
- } else {
- new_dim = GetUnknownOutputDim(node, port_index, d);
- }
- TF_CHECK_OK(ctx->ReplaceDim(merged, d, new_dim, &merged));
- }
- }
- }
- }
- return merged;
- }
-
- // Compute the shape of the tensors outputed by node 'node' at output port
// 'port_index' as the union of shape1 and shape2.
ShapeHandle OutputAsUnion(const NodeDef* node, int port_index,
ShapeHandle shape1, ShapeHandle shape2) {
@@ -822,6 +778,7 @@ class SymbolicShapeRefiner {
status.Update(SetUnknownShape(&node, output_port));
}
}
+
return status;
}
@@ -884,29 +841,6 @@ class SymbolicShapeManager {
DisjointSet<shape_inference::DimensionHandle> dims_;
};
-Status GraphProperties::MergeEnqueueShapesAndTypes(
- SymbolicShapeRefiner* shape_refiner, const NodeDef* qnode,
- const std::vector<ShapeAndType>& shapes_and_types,
- std::vector<ShapeAndType>* queue_shapes_and_types) {
- if (shapes_and_types.size() != queue_shapes_and_types->size()) {
- return errors::InvalidArgument(
- "Enqueue nodes mixed number of tensors: ", shapes_and_types.size(),
- " vs ", queue_shapes_and_types->size());
- }
- for (size_t i = 0; i < shapes_and_types.size(); ++i) {
- const ShapeAndType& a = shapes_and_types[i];
- ShapeAndType& b = (*queue_shapes_and_types)[i];
- if (a.dtype != b.dtype) {
- return errors::InvalidArgument("Enqueue nodes mixed dtypes for tensor ",
- i, ": ", DataTypeString(a.dtype), " vs ",
- DataTypeString(b.dtype));
- }
-
- b.shape = shape_refiner->OutputAsIntersection(qnode, i, a.shape, b.shape);
- }
- return Status::OK();
-}
-
Status GraphProperties::RelaxEnqueueShapesAndMergeTypes(
SymbolicShapeRefiner* shape_refiner, const NodeDef* qnode,
const std::vector<ShapeAndType>& shapes_and_types,
@@ -930,13 +864,10 @@ Status GraphProperties::RelaxEnqueueShapesAndMergeTypes(
return Status::OK();
}
-// If a Merge node has a NextIteration node as an input then that input will
-// try to forward an UnknownShape at graph construction time. However, the
-// Merge shape function will always propagate an UnknownShape if any of its
-// inputs are UnknownShapes. So we need to ignore the input from NextIteration
-// nodes to propagate any known shape from the Merge node.
+// Compute the output shape of the merge node as the union of the available
+// input shapes.
Status GraphProperties::UpdateMergeNode(SymbolicShapeRefiner* shape_refiner,
- const NodeDef* node, bool relax,
+ const NodeDef* node,
bool* new_shapes) const {
InferenceContext* c = shape_refiner->GetContext(node);
if (!c) {
@@ -955,15 +886,8 @@ Status GraphProperties::UpdateMergeNode(SymbolicShapeRefiner* shape_refiner,
bool out_initialized = false;
for (const GraphView::Edge fanin :
shape_refiner->graph().GetFaninEdges(*node, false)) {
- // Skip back edges during the initial propagation phase. This is equivalent
- // to assuming that all the inputs to the merge nodes are fed by the same
- // shape, and will be corrected as needed in the relaxation phase.
- if (!relax && IsNextIteration(*fanin.src.node)) {
- continue;
- }
-
InferenceContext* in = shape_refiner->GetContext(fanin.src.node);
- if (!relax && !in) {
+ if (!in) {
// Handling a loop for the first time, the back edge won't have any shape
// info.
continue;
@@ -976,11 +900,7 @@ Status GraphProperties::UpdateMergeNode(SymbolicShapeRefiner* shape_refiner,
out = input;
continue;
}
- if (relax) {
- out = shape_refiner->OutputAsUnion(node, 0, input, out);
- } else {
- out = shape_refiner->OutputAsIntersection(node, 0, input, out);
- }
+ out = shape_refiner->OutputAsUnion(node, 0, input, out);
}
if (*new_shapes || !shape_refiner->EquivalentShapes(out, c->output(0))) {
@@ -991,14 +911,12 @@ Status GraphProperties::UpdateMergeNode(SymbolicShapeRefiner* shape_refiner,
return Status::OK();
}
-// Manually propagate the input shape for Enter nodes and update any Merge node
-// outputs.
+// Manually propagate the input shape for Enter nodes.
Status GraphProperties::UpdateEnter(SymbolicShapeRefiner* shape_refiner,
- const NodeDef* node, bool relax,
- bool* new_shapes) {
+ const NodeDef* node, bool* new_shapes) {
auto enter_ctx = shape_refiner->GetContext(node);
if (!enter_ctx) {
- TF_RETURN_IF_ERROR(shape_refiner->UpdateNode(node, relax, new_shapes));
+ TF_RETURN_IF_ERROR(shape_refiner->UpdateNode(node, new_shapes));
enter_ctx = shape_refiner->GetContext(node);
}
@@ -1012,53 +930,56 @@ Status GraphProperties::UpdateEnter(SymbolicShapeRefiner* shape_refiner,
enter_ctx->set_output(0, input);
*new_shapes = true;
}
+ auto* outputs = in->output_handle_shapes_and_types(fanin.port_id);
+ if (outputs) {
+ enter_ctx->set_input_handle_shapes_and_types(0, *outputs);
+ enter_ctx->set_output_handle_shapes_and_types(0, *outputs);
+ *new_shapes = true;
+ }
return Status::OK();
}
-Status GraphProperties::UpdateShapes(SymbolicShapeRefiner* shape_refiner,
- bool relax, const NodeDef* n,
- bool* new_shapes) const {
+Status GraphProperties::UpdateShapes(
+ SymbolicShapeRefiner* shape_refiner,
+ const std::unordered_map<const NodeDef*, const NodeDef*>& resource_handles,
+ const NodeDef* n, bool* new_shapes) const {
if (IsEnter(*n)) {
// The Enter shape function always forwards an UnknownShape, so do the right
// thing here.
- TF_RETURN_IF_ERROR(UpdateEnter(shape_refiner, n, relax, new_shapes));
+ TF_RETURN_IF_ERROR(UpdateEnter(shape_refiner, n, new_shapes));
} else if (IsMerge(*n)) {
// Properly handle merge nodes.
- TF_RETURN_IF_ERROR(UpdateMergeNode(shape_refiner, n, relax, new_shapes));
+ TF_RETURN_IF_ERROR(UpdateMergeNode(shape_refiner, n, new_shapes));
+ } else if (IsEnqueue(*n)) {
+ // Make sure the shapes of enqueued tensors are propagated to the queue
+ // itself.
+ TF_RETURN_IF_ERROR(
+ UpdateEnqueue(n, resource_handles, shape_refiner, new_shapes));
} else {
// Rely on regular TF shape refinement for all the other nodes.
- bool updated = false;
- TF_RETURN_IF_ERROR(shape_refiner->UpdateNode(n, relax, &updated));
- if (updated) {
- // We want to avoid propagating through loops on the merge pass because
- // the shapes are not guaranteed to converge.
- if (relax || !IsNextIteration(*n)) {
- *new_shapes = true;
- }
- }
+ TF_RETURN_IF_ERROR(shape_refiner->UpdateNode(n, new_shapes));
}
return Status::OK();
}
// Propagates the shapes in the transitive fan-out of <new_shapes>.
Status GraphProperties::PropagateShapes(
- SymbolicShapeRefiner* shape_refiner, bool relax, TopoQueue* new_shapes,
- const std::unordered_map<const NodeDef*,
- std::unordered_set<const NodeDef*>>& resources,
+ SymbolicShapeRefiner* shape_refiner, TopoQueue* new_shapes,
+ const std::unordered_map<const NodeDef*, const NodeDef*>& resource_handles,
int num_loops) const {
// Limit the number of iterations to prevent infinite loops in the presence of
// incorrect shape functions. The algoritm should converge in at most
// num_nested_loops^2 * max_rank. We approximate max_rank with the constant 4.
// The same applies to resources.
- VLOG(1) << "Propagating (relax=" << relax << ") " << new_shapes->size()
- << " new shapes through " << num_loops << " loops and "
- << resources.size() << " resources" << std::endl;
+ VLOG(1) << "Propagating " << new_shapes->size() << " new shapes through "
+ << num_loops << " loops and " << resource_handles.size()
+ << " resources" << std::endl;
const int64 max_loop_length = item_.graph.node_size();
const int64 max_rank = 4;
const int64 max_loop_iterations =
max_rank * max_loop_length * std::max<int64>(1, num_loops * num_loops);
- const int64 num_queues = resources.size();
+ const int64 num_queues = resource_handles.size();
const int64 max_resource_iterations = num_queues * num_queues * max_rank;
int64 num_resource_iterations = 0;
@@ -1068,22 +989,22 @@ Status GraphProperties::PropagateShapes(
num_loop_iterations++ < max_loop_iterations) {
const NodeDef* n = new_shapes->pop();
bool updated = false;
- TF_RETURN_IF_ERROR(UpdateShapes(shape_refiner, relax, n, &updated));
+ TF_RETURN_IF_ERROR(
+ UpdateShapes(shape_refiner, resource_handles, n, &updated));
if (updated) {
- for (const GraphView::InputPort fanout :
+ for (const GraphView::InputPort& fanout :
shape_refiner->graph().GetFanouts(*n, false)) {
new_shapes->push(fanout.node);
}
+ // Make sure the corresponding queue nodes are (re)processed.
+ if (IsEnqueue(*n)) {
+ auto it = resource_handles.find(n);
+ if (it != resource_handles.end()) {
+ new_shapes->push(it->second);
+ }
+ }
}
}
-
- for (const auto& resource : resources) {
- // Resources need special handling: since the enqueue nodes are in the
- // fanout of the queues, we need to manually propagate the shapes from
- // enqueue node to the corresponding queue.
- TF_RETURN_IF_ERROR(UpdateResource(resource.first, resource.second,
- shape_refiner, new_shapes));
- }
} while (!new_shapes->empty() &&
num_resource_iterations++ < max_resource_iterations);
@@ -1094,54 +1015,48 @@ Status GraphProperties::PropagateShapes(
return Status::OK();
}
-Status GraphProperties::UpdateResource(
- const NodeDef* qnode,
- const std::unordered_set<const NodeDef*>& queue_inputs,
- SymbolicShapeRefiner* shape_refiner, TopoQueue* new_shapes) {
- // Proceed only if qnode is a queue or an Enter with queue input.
- if (!IsQueue(*qnode) && !IsEnterWithQueue(*qnode, shape_refiner->graph())) {
+Status GraphProperties::UpdateEnqueue(
+ const NodeDef* enqueue_node,
+ const std::unordered_map<const NodeDef*, const NodeDef*>& resource_handles,
+ SymbolicShapeRefiner* shape_refiner, bool* new_shapes) {
+ auto ctx = shape_refiner->GetNodeContext(enqueue_node);
+ if (!ctx) {
+ TF_RETURN_IF_ERROR(shape_refiner->AddNode(enqueue_node));
+ ctx = CHECK_NOTNULL(shape_refiner->GetNodeContext(enqueue_node));
+ }
+
+ auto it = resource_handles.find(enqueue_node);
+ if (it == resource_handles.end()) {
+ // The corresponding queue was not found, there isn't much we can do.
return Status::OK();
}
+ const NodeDef* qnode = it->second;
auto qctx = shape_refiner->GetContext(qnode);
if (!qctx) {
return Status::OK();
}
auto* queue_handle_data = qctx->output_handle_shapes_and_types(0);
- // Merge all inputs into the enqueue node, regardless of which phase we
- // are in.
- std::vector<ShapeAndType> queue_shapes_and_types;
- for (const auto& node : queue_inputs) {
- auto ctx = shape_refiner->GetNodeContext(node);
- if (!ctx) {
- continue;
- }
- // TODO(bsteiner): handle EnqueueMany as well.
- if (node->op().find("Enqueue") != std::string::npos &&
- node->op().find("EnqueueMany") == std::string::npos) {
- std::vector<ShapeAndType> shapes_and_types;
- for (int i = 1; i < ctx->input_types.size(); ++i) {
- shapes_and_types.push_back(
- {ctx->inference_context->input(i), ctx->input_types[i]});
- }
- if (queue_shapes_and_types.empty()) {
- queue_shapes_and_types = shapes_and_types;
- } else {
- TF_RETURN_IF_ERROR(RelaxEnqueueShapesAndMergeTypes(
- shape_refiner, qnode, shapes_and_types, &queue_shapes_and_types));
- }
- }
+ // TODO(bsteiner): handle EnqueueMany as well.
+ std::vector<ShapeAndType> shapes_and_types;
+ for (int i = 1; i < ctx->input_types.size(); ++i) {
+ GraphView::InputPort inp(enqueue_node, i);
+ GraphView::OutputPort fanin = shape_refiner->graph().GetRegularFanin(inp);
+ InferenceContext* in = shape_refiner->GetContext(fanin.node);
+ ShapeHandle input = in->output(fanin.port_id);
+ ctx->inference_context->SetInput(i, input);
+ shapes_and_types.push_back({input, ctx->input_types[i]});
}
- if (queue_handle_data == nullptr ||
- !shape_refiner->EquivalentShapesAndTypes(*queue_handle_data,
- queue_shapes_and_types)) {
- qctx->set_output_handle_shapes_and_types(0, queue_shapes_and_types);
-
- for (const GraphView::InputPort fanout :
- shape_refiner->graph().GetFanouts(*qnode, false)) {
- new_shapes->push(fanout.node);
- }
+ if (queue_handle_data == nullptr) {
+ qctx->set_output_handle_shapes_and_types(0, shapes_and_types);
+ *new_shapes = true;
+ } else {
+ TF_RETURN_IF_ERROR(RelaxEnqueueShapesAndMergeTypes(
+ shape_refiner, qnode, *queue_handle_data, &shapes_and_types));
+ *new_shapes |= !shape_refiner->EquivalentShapesAndTypes(*queue_handle_data,
+ shapes_and_types);
+ qctx->set_output_handle_shapes_and_types(0, shapes_and_types);
}
return Status::OK();
@@ -1159,75 +1074,96 @@ Status GraphProperties::InferStatically(bool assume_valid_feeds) {
}
}
- std::unordered_map<const NodeDef*, int> topo_order;
- TF_RETURN_IF_ERROR(ComputeTopologicalOrder(item_.graph, &topo_order));
-
- GraphView graph_view(&item_.graph);
+ GraphView graph_view(const_cast<GraphDef*>(&item_.graph));
// List the resources and the nodes using them. Also collect the Merge nodes,
// fed nodes, and primary inputs.
- std::unordered_map<const NodeDef*, std::unordered_set<const NodeDef*>>
+ std::unordered_map<const NodeDef*,
+ std::pair<std::unordered_set<const NodeDef*>,
+ std::unordered_set<const NodeDef*>>>
resources;
std::unordered_set<const NodeDef*> merge_nodes;
std::unordered_set<const NodeDef*> fed_nodes;
std::unordered_set<const NodeDef*> primary_inputs;
int num_loops = 0;
for (const NodeDef& node : item_.graph.node()) {
+ if (IsQueue(node)) {
+ for (const GraphView::InputPort& fanout :
+ graph_view.GetFanouts(node, false)) {
+ if (IsEnter(*fanout.node)) {
+ const NodeDef& enter = *fanout.node;
+ for (const GraphView::InputPort& fanout :
+ graph_view.GetFanouts(enter, false)) {
+ if (IsEnqueue(*fanout.node)) {
+ resources[&node].first.insert(fanout.node);
+ } else if (IsDequeue(*fanout.node)) {
+ resources[&node].second.insert(fanout.node);
+ }
+ }
+ } else {
+ if (IsEnqueue(*fanout.node)) {
+ resources[&node].first.insert(fanout.node);
+ } else if (IsDequeue(*fanout.node)) {
+ resources[&node].second.insert(fanout.node);
+ }
+ }
+ }
+ }
if (NumNonControlInputs(node) == 0) {
primary_inputs.insert(&node);
} else if (IsMerge(node)) {
merge_nodes.insert(&node);
} else if (IsNextIteration(node)) {
++num_loops;
- } else {
- const OpRegistrationData* op_data;
- TF_RETURN_IF_ERROR(function_library.LookUp(node.op(), &op_data));
- DataTypeVector input_types;
- DataTypeVector output_types;
- TF_RETURN_IF_ERROR(InOutTypesForNode(node, op_data->op_def, &input_types,
- &output_types));
- for (int i = 0; i < input_types.size(); ++i) {
- if (input_types[i] == DataType::DT_RESOURCE) {
- GraphView::InputPort input(&node, i);
- const GraphView::OutputPort resource =
- graph_view.GetRegularFanin(input);
- resources[resource.node].insert(&node);
- }
- }
}
if (fed_ports.find(node.name()) != fed_ports.end()) {
fed_nodes.insert(&node);
}
}
- SymbolicShapeRefiner refiner(graph_view, fed_ports);
-
- // We propagate shapes through the graph in two phases. In the first phase, we
- // exclusively merge shapes but we do not propagate shapes through the
- // backedge of loops (i.e. the NextIteration node). Then on the second phase,
- // we exclusively relax shapes and propagate shapes through loops until
- // reaching fixed point.
- for (int relax = 0; relax < 2; relax++) {
- TopoQueue new_shapes(topo_order);
- // Seed the propagation of shapes through merge nodes.
- if (relax) {
- for (const NodeDef* node : merge_nodes) {
- new_shapes.push(node);
+ std::unordered_map<const NodeDef*, const NodeDef*> resource_handles;
+ std::vector<std::pair<const NodeDef*, const NodeDef*>> extra_deps;
+ for (const auto& resource : resources) {
+ for (const NodeDef* src : resource.second.first) {
+ resource_handles[src] = resource.first;
+ for (const NodeDef* tgt : resource.second.second) {
+ // Add control edges from enqueue to dequeue nodes to ensure they are
+ // processed in their logical order.
+ extra_deps.emplace_back(src, tgt);
}
}
- // Also seed the propagation of shapes in the fanout of primary inputs.
- for (const NodeDef* node : primary_inputs) {
- new_shapes.push(node);
- }
- // Also seed the propagation of shapes in the fanout of fed nodes.
- for (const NodeDef* node : fed_nodes) {
- new_shapes.push(node);
+ }
+
+ std::unordered_map<const NodeDef*, int> topo_order;
+ Status s = ComputeTopologicalOrder(item_.graph, &topo_order, &extra_deps);
+ if (!s.ok()) {
+ if (extra_deps.empty()) {
+ return s;
+ } else {
+ // There is a loop between queues: we'll just use the graph topological
+ // order. This will make the shape inference less precise but since this
+ // isn't common it's not worth to figure out where to break the loop and
+ // do a proper relaxation.
+ TF_RETURN_IF_ERROR(
+ ComputeTopologicalOrder(item_.graph, &topo_order, nullptr));
}
- // Propagate shapes normally.
- TF_RETURN_IF_ERROR(
- PropagateShapes(&refiner, relax, &new_shapes, resources, num_loops));
}
+ SymbolicShapeRefiner refiner(graph_view, fed_ports);
+
+ TopoQueue new_shapes(topo_order);
+ // Also seed the propagation of shapes in the fanout of primary inputs.
+ for (const NodeDef* node : primary_inputs) {
+ new_shapes.push(node);
+ }
+ // Also seed the propagation of shapes in the fanout of fed nodes.
+ for (const NodeDef* node : fed_nodes) {
+ new_shapes.push(node);
+ }
+ // Propagate shapes normally.
+ TF_RETURN_IF_ERROR(
+ PropagateShapes(&refiner, &new_shapes, resource_handles, num_loops));
+
// Track shapes globally across the graph.
SymbolicShapeManager shape_manager;
bool found_error = false;
@@ -1271,7 +1207,6 @@ Status GraphProperties::InferStatically(bool assume_valid_feeds) {
// Fill input properties.
{
- // CHECK_EQ(ctx->num_inputs(), node.num_inputs());
auto& input_properties = input_properties_[node.name()];
// Should always be empty, node names in graph are supposed to be unique.
@@ -1295,7 +1230,6 @@ Status GraphProperties::InferStatically(bool assume_valid_feeds) {
// Fill output properties.
{
- // CHECK_EQ(ctx->num_outputs(), node->num_outputs());
auto& output_properties = output_properties_[node.name()];
// Should always be empty, node names in graph are supposed to be unique.
diff --git a/tensorflow/core/grappler/costs/graph_properties.h b/tensorflow/core/grappler/costs/graph_properties.h
index 7d685b5833..8703613a12 100644
--- a/tensorflow/core/grappler/costs/graph_properties.h
+++ b/tensorflow/core/grappler/costs/graph_properties.h
@@ -38,6 +38,7 @@ class TopoQueue;
// and data type properties.
class GraphProperties {
public:
+ // The item must outlive the properties
explicit GraphProperties(const GrapplerItem& item) : item_(item) {}
// Infer the shapes through abstract interpretation. Feed information can be
@@ -75,12 +76,6 @@ class GraphProperties {
void ClearOutputProperties(const string& node_name);
private:
- // Merges shapes <shapes_and_types>, determined from an EnqueueV2 node, into
- // <*queue_shapes_and_types>.
- static Status MergeEnqueueShapesAndTypes(
- SymbolicShapeRefiner* shape_refiner, const NodeDef* qnode,
- const std::vector<shape_inference::ShapeAndType>& shapes_and_types,
- std::vector<shape_inference::ShapeAndType>* queue_shapes_and_types);
// Relaxes shapes <shapes_and_types>, determined from an EnqueueV2 node, into
// <*queue_shapes_and_types>.
static Status RelaxEnqueueShapesAndMergeTypes(
@@ -88,35 +83,37 @@ class GraphProperties {
const std::vector<shape_inference::ShapeAndType>& shapes_and_types,
std::vector<shape_inference::ShapeAndType>* queue_shapes_and_types);
- // Update the shapes for qnode. If output shapes of qnode have changed,
- // enqueue its fanout in 'new_shapes'.
- static Status UpdateResource(
- const NodeDef* qnode,
- const std::unordered_set<const NodeDef*>& queue_inputs,
- SymbolicShapeRefiner* shape_refiner, TopoQueue* new_shapes);
+ // Update the shapes of the enqueue node, port them over to the corresponding
+ // queue, and schedule the reprocessing of the queue if needed.
+ static Status UpdateEnqueue(
+ const NodeDef* enqueue_node,
+ const std::unordered_map<const NodeDef*, const NodeDef*>&
+ resource_handles,
+ SymbolicShapeRefiner* shape_refiner, bool* new_shapes);
// Update the output shapes of a Merge node, and enqueue its fanout in
// new_shapes if needed.
Status UpdateMergeNode(SymbolicShapeRefiner* shape_refiner,
- const NodeDef* node, bool relax,
- bool* new_shapes) const;
+ const NodeDef* node, bool* new_shapes) const;
// Process the Enter node, and enqueue its fanout in new_shapes if needed.
static Status UpdateEnter(SymbolicShapeRefiner* shape_refiner,
- const NodeDef* node, bool relax, bool* new_shapes);
+ const NodeDef* node, bool* new_shapes);
// Update the shapes for node 'n'. If output shapes for n have changed,
// enqueue its fanout in 'new_shapes'.
- Status UpdateShapes(SymbolicShapeRefiner* shape_refiner, bool relax,
+ Status UpdateShapes(SymbolicShapeRefiner* shape_refiner,
+ const std::unordered_map<const NodeDef*, const NodeDef*>&
+ resource_handles,
const NodeDef* n, bool* new_shapes) const;
// Propagate the shapes for the nodes enqueued in new_shapes and their
// transitive fanout until a fixed point is reached.
Status PropagateShapes(
- SymbolicShapeRefiner* shape_refiner, bool relax, TopoQueue* new_shapes,
- const std::unordered_map<const NodeDef*,
- std::unordered_set<const NodeDef*>>& resources,
+ SymbolicShapeRefiner* shape_refiner, TopoQueue* new_shapes,
+ const std::unordered_map<const NodeDef*, const NodeDef*>&
+ resource_handles,
int num_loops) const;
// Data members
- GrapplerItem item_;
+ const GrapplerItem& item_;
std::map<string, std::vector<OpInfo::TensorProperties>> input_properties_;
std::map<string, std::vector<OpInfo::TensorProperties>> output_properties_;
const std::vector<OpInfo::TensorProperties> missing_properties_;
diff --git a/tensorflow/core/grappler/costs/graph_properties_test.cc b/tensorflow/core/grappler/costs/graph_properties_test.cc
index afe334dfa2..a53f6414c3 100644
--- a/tensorflow/core/grappler/costs/graph_properties_test.cc
+++ b/tensorflow/core/grappler/costs/graph_properties_test.cc
@@ -282,20 +282,11 @@ TEST_F(GraphPropertiesTest, Queues) {
auto dequeue2 =
ops::QueueDequeue(root.WithOpName("Dequeue2"), q2, {DataType::DT_FLOAT});
- // Create a queue that feeds itself.
- auto q3 =
- ops::RandomShuffleQueue(root.WithOpName("Queue3"), {DataType::DT_FLOAT});
- auto dequeue3 =
- ops::QueueDequeue(root.WithOpName("Dequeue3"), q3, {DataType::DT_FLOAT});
- auto merge3 = ops::Merge(root.WithOpName("Merge3"), {dequeue3[0], square2});
- auto enqueue3 =
- ops::QueueEnqueue(root.WithOpName("Enqueue3"), q3, {merge3.output});
-
auto q4 =
ops::RandomShuffleQueue(root.WithOpName("Queue4"), {DataType::DT_FLOAT});
auto enqueue4 = ops::QueueEnqueue(root.WithOpName("Enqueue4"), q4, {square2});
auto enqueue4_2 =
- ops::QueueEnqueue(root.WithOpName("Enqueue4_2"), q4, {dequeue3[0]});
+ ops::QueueEnqueue(root.WithOpName("Enqueue4_2"), q4, {dequeue2[0]});
auto dequeue4 =
ops::QueueDequeue(root.WithOpName("Dequeue4"), q4, {DataType::DT_FLOAT});
@@ -327,10 +318,6 @@ TEST_F(GraphPropertiesTest, Queues) {
ASSERT_EQ(1, props2.size());
EXPECT_EQ("float: [3,7]", PropToString(props2[0]));
- const auto props3 = properties.GetOutputProperties("Dequeue3");
- ASSERT_EQ(1, props3.size());
- EXPECT_EQ("float: [3,7]", PropToString(props3[0]));
-
// The dequeue3 op shape is unknown. The square2 op shape is known. Verify
// that we merge the 2 properly to determine the shape of the data coming out
// of the queue.
diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc
index 7a89c26374..bf6d4c0921 100644
--- a/tensorflow/core/grappler/op_types.cc
+++ b/tensorflow/core/grappler/op_types.cc
@@ -54,6 +54,10 @@ bool IsApproximateEqual(const NodeDef& node) {
bool IsAvgPoolGrad(const NodeDef& node) { return node.op() == "AvgPoolGrad"; }
+bool IsAssign(const NodeDef& node) {
+ return node.op() == "Assign" || node.op() == "AssignVariableOp";
+}
+
bool IsAssert(const NodeDef& node) { return node.op() == "Assert"; }
bool IsAtan2(const NodeDef& node) { return node.op() == "Atan2"; }
@@ -250,6 +254,10 @@ bool IsPrint(const NodeDef& node) { return node.op() == "Print"; }
bool IsProd(const NodeDef& node) { return node.op() == "Prod"; }
+bool IsQueue(const NodeDef& node) {
+ return str_util::EndsWith(node.op(), "QueueV2");
+}
+
bool IsRandomShuffle(const NodeDef& node) {
return node.op() == "RandomShuffle";
}
diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h
index 976d23e527..3dddf3f1ea 100644
--- a/tensorflow/core/grappler/op_types.h
+++ b/tensorflow/core/grappler/op_types.h
@@ -21,7 +21,6 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
-
bool IsAdd(const NodeDef& node);
bool IsAddN(const NodeDef& node);
bool IsAll(const NodeDef& node);
@@ -31,6 +30,7 @@ bool IsAnyDiv(const NodeDef& node);
bool IsApproximateEqual(const NodeDef& node);
bool IsAvgPoolGrad(const NodeDef& node);
bool IsAssert(const NodeDef& node);
+bool IsAssign(const NodeDef& node);
bool IsAtan2(const NodeDef& node);
bool IsBetainc(const NodeDef& node);
bool IsBiasAdd(const NodeDef& node);
@@ -98,6 +98,7 @@ bool IsPolygamma(const NodeDef& node);
bool IsPrint(const NodeDef& node);
bool IsProd(const NodeDef& node);
bool IsPow(const NodeDef& node);
+bool IsQueue(const NodeDef& node);
bool IsRandomShuffle(const NodeDef& node);
bool IsReal(const NodeDef& node);
bool IsRealDiv(const NodeDef& node);
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index 18076eee96..bf59b25449 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -302,6 +302,11 @@ class ArithmeticOptimizerStage : public GraphOptimizerStage<string> {
}
}
+ bool IsInPreserveSet(const NodeDef& node) const {
+ return ctx().nodes_to_preserve->find(node.name()) !=
+ ctx().nodes_to_preserve->end();
+ }
+
private:
// Extended context required for ArithmeticOptimizer.
const ArithmeticOptimizerContext ctx_ext_;
@@ -474,11 +479,6 @@ class ArithmeticNodesGroupOptimizerStage : public ArithmeticOptimizerStage {
return group.root_node->device() == node.device();
}
- bool IsInPreserveSet(const NodeDef& node) const {
- return ctx().nodes_to_preserve->find(node.name()) !=
- ctx().nodes_to_preserve->end();
- }
-
bool IsAlreadyOptimized(const NodeDef& node) const {
return optimized_nodes_.find(node.name()) != optimized_nodes_.end();
}
@@ -1340,65 +1340,143 @@ class RemoveNegationStage : public ArithmeticOptimizerStage {
};
// This optimization hoists the common prefix of unary ops of the inputs to
-// concat out of the concat.
-// For example: Concat([Exp(Sin(x)), Exp(Sin(y)), Exp(Sin(z))]) ->
-// Exp(Sin(Concat([x, y, z]))).
+// concat out of the concat, for example:
+// Concat([Exp(Sin(x)), Exp(Sin(y)), Exp(Sin(z))])
+// becomes
+// Exp(Sin(Concat([x, y, z]))).
+// Similarly, it will hoist the common postfix of unary ops into Split or
+// SplitV nodes, for example:
+// [Exp(Sin(y)) for y in Split(x)]
+// becomes
+// [y for y in Split(Exp(Sin(x))]
+//
// TODO(rmlarsen): Support casting. We would have to change the type attribute
-// on the concat node.
-class HoistCWiseUnaryFromConcatStage : public ArithmeticOptimizerStage {
+// on the concat/split node.
+// TODO(rmlarsen): Handle Enter/Exit.
+class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage {
public:
- explicit HoistCWiseUnaryFromConcatStage(
- const GraphOptimizerContext& ctx,
- const ArithmeticOptimizerContext& ctx_ext)
+ explicit HoistCWiseUnaryChainsStage(const GraphOptimizerContext& ctx,
+ const ArithmeticOptimizerContext& ctx_ext)
: ArithmeticOptimizerStage("", ctx, ctx_ext) {}
- ~HoistCWiseUnaryFromConcatStage() override = default;
+ ~HoistCWiseUnaryChainsStage() override = default;
+
+ struct ChainLink {
+ ChainLink() = default;
+ ChainLink(NodeDef* _node, int _port_origin)
+ : node(_node), port_origin(_port_origin) {}
+ NodeDef* node; // Node in a chain.
+ int port_origin; // Port on concat/split node from which this chain
+ // originates.
+
+ bool operator<(const ChainLink& other) const {
+ if (port_origin < other.port_origin) {
+ return true;
+ } else if (port_origin > other.port_origin) {
+ return false;
+ } else {
+ return node->name() < other.node->name();
+ }
+ }
+ };
+
+ // We use an ordinary set sorted on port and node name, so the order, and
+ // hence the node name used for the hoisted chain, will be deterministic.
+ using ChainLinkSet = std::set<ChainLink>;
bool IsSupported(const NodeDef* node) const override {
- if (!IsConcat(*node)) return false;
- const int n = node->attr().at("N").i();
- return n > 1;
+ if (IsInPreserveSet(*node)) return false;
+ if (IsConcat(*node)) {
+ const int n = node->attr().at("N").i();
+ return n > 1;
+ } else if (IsSplit(*node) || IsSplitV(*node)) {
+ const int num_split = node->attr().at("num_split").i();
+ return num_split > 1 && !IsAlreadyOptimized(*node);
+ }
+ return false;
}
- Status TrySimplify(NodeDef* concat_node,
- string* simplified_node_name) override {
+ Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
+ node_is_concat_ = IsConcat(*node);
int prefix_length;
std::set<string> ctrl_inputs;
+ ChainLinkSet tails;
TF_RETURN_IF_ERROR(
- FindCommonUnaryOpPrefix(*concat_node, &prefix_length, &ctrl_inputs));
- if (prefix_length > 0) {
+ FindCommonUnaryOpChain(*node, &prefix_length, &tails, &ctrl_inputs));
+ if (prefix_length > 0 && !tails.empty()) {
TF_RETURN_IF_ERROR(
- HoistUnaryOpPrefix(prefix_length, &ctrl_inputs, concat_node));
- AddToOptimizationQueue(concat_node);
+ HoistUnaryOpChain(prefix_length, tails, &ctrl_inputs, node));
}
return Status::OK();
}
private:
- void RemoveControlInputs(std::set<string>* removed_ctrl_inputs,
- NodeDef* node) const {
- const int num_inputs = node->input_size();
- for (int idx = num_inputs - 1; idx >= 0; --idx) {
- const string& input = node->input(idx);
- if (IsControlInput(input)) {
- removed_ctrl_inputs->insert(input);
- ctx().node_map->RemoveOutput(NodeName(input), node->name());
- node->mutable_input()->RemoveLast();
- } else {
- break;
+ // Returns the length of the common unary chain of ops that can be
+ // hoisted to the other side of concat or split.
+ Status FindCommonUnaryOpChain(const NodeDef& root_node, int* prefix_length,
+ ChainLinkSet* tails,
+ std::set<string>* ctrl_inputs) const {
+ *prefix_length = 0;
+ // Follow the chains starting at each concat input or split output as long
+ // as all the following conditions hold:
+ // 1. The ops in all chains are the same.
+ // 2. The ops are unary elemenwise op.
+ // 3. The op output has only a single consumer (concat only).
+ ChainLinkSet cur_tails;
+ TF_RETURN_IF_ERROR(InitializeChains(root_node, &cur_tails));
+ if (cur_tails.size() < 2) {
+ return Status::OK();
+ }
+ ctrl_inputs->clear();
+ bool stop = false;
+ while (!stop && !cur_tails.empty() &&
+ OpsAreSafeToHoist(root_node, cur_tails)) {
+ // We found one more link that can be hoisted.
+ ++(*prefix_length);
+ tails->swap(cur_tails);
+ GatherControlInputs(ctrl_inputs, *tails);
+
+ // Advance tail pointers to the next level.
+ TF_RETURN_IF_ERROR(AdvanceTails(*tails, &cur_tails, &stop));
+ }
+ return Status::OK();
+ }
+
+ // Hoists the chains to the other side of concat or split and attaches the
+ // control inputs gathered from them to the concat or split node.
+ Status HoistUnaryOpChain(const int prefix_length, const ChainLinkSet& tails,
+ std::set<string>* ctrl_inputs, NodeDef* root_node) {
+ if (tails.empty()) {
+ return Status::OK();
+ }
+ AddControlInputs(ctrl_inputs, root_node);
+ AddToOptimizationQueue(root_node);
+ optimized_nodes_.insert(root_node->name());
+ if (node_is_concat_) {
+ return HoistChainForConcat(prefix_length, tails, root_node);
+ } else {
+ return HoistChainForSplit(prefix_length, tails, root_node);
+ }
+ }
+
+ void GatherControlInputs(std::set<string>* ctrl_inputs,
+ const ChainLinkSet& ops) const {
+ for (const auto& link : ops) {
+ const NodeDef* node = link.node;
+ for (int i = node->input_size() - 1; i >= 0; --i) {
+ const string& input = node->input(i);
+ if (!IsControlInput(input)) break;
+ ctrl_inputs->insert(input);
}
}
}
void AddControlInputs(std::set<string>* new_ctrl_inputs,
NodeDef* node) const {
- for (int idx = node->input_size() - 1; idx >= 0; --idx) {
- const string& existing_input = node->input(idx);
- if (IsControlInput(existing_input)) {
- new_ctrl_inputs->erase(existing_input);
- } else {
- break;
- }
+ for (int i = node->input_size() - 1; i >= 0; --i) {
+ const string& existing_input = node->input(i);
+ if (!IsControlInput(existing_input)) break;
+ new_ctrl_inputs->erase(existing_input);
}
for (const string& new_input : *new_ctrl_inputs) {
ctx().node_map->AddOutput(NodeName(new_input), node->name());
@@ -1406,113 +1484,193 @@ class HoistCWiseUnaryFromConcatStage : public ArithmeticOptimizerStage {
}
}
- // Returns the length of the common unary prefix chain of ops that can be
- // hoisted out of concat.
- Status FindCommonUnaryOpPrefix(const NodeDef& concat_node, int* prefix_length,
- std::set<string>* ctrl_inputs) const {
- *prefix_length = 0;
- const int n = concat_node.attr().at("N").i();
- // Follow the chains backwards from each concat input as long as all the
- // following conditions hold:
- // 1. The ops in all chains are the same.
- // 2. The op is a unary elemenwise op.
- // 3. The op output has only a single consumer.
- std::vector<NodeDef*> tail(n, nullptr);
- const int start = concat_node.op() == "Concat" ? 1 : 0;
- const int end = start + n;
- // Set up tail pointers to point to the immediate inputs to Concat.
- for (int i = start; i < end; ++i) {
- if (IsControlInput(concat_node.input(i))) {
- return errors::FailedPrecondition("Got control input ",
- concat_node.input(i),
- " where normal input was expected.");
- }
- TF_RETURN_IF_ERROR(GetInputNode(concat_node.input(i), &tail[i - start]));
- }
-
- bool stop = false;
- ctrl_inputs->clear();
- while (!stop) {
- const NodeDef* tail0 = tail[0];
- if (!IsUnaryElementWise(*tail0)) break;
- for (int chain = 0; chain < n; ++chain) {
- // TODO(rmlarsen): Allow and hoist outgoing control edges.
- if (tail[chain]->op() != tail0->op() ||
- ctx().node_map->GetOutputs(tail[chain]->name()).size() > 1) {
- stop = true;
- break;
+ Status InitializeChains(const NodeDef& node, ChainLinkSet* tails) const {
+ if (node_is_concat_) {
+ // Handle concat nodes by looking backwards in the graph.
+ const int n = node.attr().at("N").i();
+ const int start = node.op() == "Concat" ? 1 : 0;
+ const int end = start + n;
+ // Set up tail pointers to point to the immediate inputs to Concat.
+ for (int input_port = start; input_port < end; ++input_port) {
+ if (IsControlInput(node.input(input_port))) {
+ return errors::FailedPrecondition(
+ "Got control input ", node.input(input_port),
+ " where normal input was expected.");
}
+ NodeDef* tail;
+ TF_RETURN_IF_ERROR(GetInputNode(node.input(input_port), &tail));
+ tails->insert(ChainLink(tail, input_port));
}
- if (stop) break;
- // We found one more op that can be hoisted.
- ++(*prefix_length);
- for (int chain = 0; chain < n; ++chain) {
- RemoveControlInputs(ctrl_inputs, tail[chain]);
- }
- // Advance tail pointers to the next level.
- for (int chain = 0; chain < n; ++chain) {
- if (tail[chain]->input_size() == 0 ||
- IsControlInput(tail[chain]->input(0))) {
- stop = true;
- break;
+ return Status::OK();
+ } else {
+ // Handle split nodes by looking forwards in the graph.
+ const auto& outputs = ctx().node_map->GetOutputs(node.name());
+ for (NodeDef* output : outputs) {
+ if (IsControlInput(output->input(0))) continue;
+ int port;
+ const string node_name = ParseNodeName(output->input(0), &port);
+ if (node_name == node.name()) {
+ tails->insert(ChainLink(output, port));
} else {
- NodeDef* new_tail = nullptr;
- TF_RETURN_IF_ERROR(GetInputNode(tail[chain]->input(0), &new_tail));
- tail[chain] = new_tail;
+ // This output node has a non-control input other than the split node,
+ // abort.
+ tails->clear();
+ return Status::OK();
}
}
}
return Status::OK();
}
- Status HoistUnaryOpPrefix(const int prefix_length,
- std::set<string>* ctrl_inputs,
- NodeDef* concat_node) {
- const int n = concat_node->attr().at("N").i();
- const int start = concat_node->op() == "Concat" ? 1 : 0;
- const int end = start + n;
- const std::set<NodeDef*> consumers =
- ctx().node_map->GetOutputs(concat_node->name());
- AddControlInputs(ctrl_inputs, concat_node);
- for (int chain = 0; chain < (end - start); ++chain) {
- NodeDef* tail = nullptr;
- const string concat_input = concat_node->input(chain + start);
- for (int distance = 0; distance < prefix_length; ++distance) {
- if (distance == 0) {
- TF_RETURN_IF_ERROR(GetInputNode(concat_input, &tail));
- } else {
- TF_RETURN_IF_ERROR(GetInputNode(tail->input(0), &tail));
+ bool OpsAreSafeToHoist(const NodeDef& root_node,
+ const ChainLinkSet& ops) const {
+ if (ops.empty()) return true;
+ const NodeDef* op0 = ops.begin()->node;
+ if (!IsUnaryElementWise(*op0)) return false;
+ for (const auto& link : ops) {
+ const NodeDef* op = link.node;
+ if (op->device() != root_node.device() || op->op() != op0->op() ||
+ IsInPreserveSet(*op)) {
+ return false;
+ }
+ if (node_is_concat_ &&
+ ctx().node_map->GetOutputs(op->name()).size() > 1) {
+ // TODO(rmlarsen): Allow and hoist outgoing control edges.
+ return false;
+ }
+ }
+ return true;
+ }
+
+ Status AdvanceTails(const ChainLinkSet& tails, ChainLinkSet* new_tails,
+ bool* stop) const {
+ *stop = true;
+ new_tails->clear();
+ for (const auto& link : tails) {
+ const NodeDef* tail = link.node;
+ if (node_is_concat_) {
+ if (tail->input_size() == 0 || IsControlInput(tail->input(0))) {
+ return Status::OK();
+ }
+ NodeDef* new_tail;
+ TF_RETURN_IF_ERROR(GetInputNode(tail->input(0), &new_tail));
+ // Remember original port.
+ new_tails->insert(ChainLink(new_tail, link.port_origin));
+ } else {
+ for (NodeDef* new_tail : ctx().node_map->GetOutputs(tail->name())) {
+ int port;
+ const string node_name = ParseNodeName(new_tail->input(0), &port);
+ if (node_name != tail->name()) {
+ return Status::OK();
+ }
+ // Skip control outputs.
+ if (port >= 0) {
+ // Remember original port.
+ new_tails->insert(ChainLink(new_tail, link.port_origin));
+ }
}
}
+ }
+ *stop = false;
+ return Status::OK();
+ }
+ Status HoistChainForConcat(const int prefix_length, const ChainLinkSet& tails,
+ NodeDef* concat_node) {
+ const string& concat_name = concat_node->name();
+ const int first_input = concat_node->op() == "Concat" ? 1 : 0;
+ for (const auto& link : tails) {
+ NodeDef* tail = CHECK_NOTNULL(link.node);
+ const int concat_port = link.port_origin;
+ CHECK_GE(concat_port, 0);
+ CHECK_LT(concat_port, concat_node->input_size());
+ const string concat_input = concat_node->input(concat_port);
// Hook the node following tail directly into the concat node.
const string tail_input = tail->input(0);
- concat_node->set_input(chain + start, tail_input);
- ctx().node_map->UpdateInput(concat_node->name(), concat_input,
- tail_input);
-
- if (chain == 0) {
- // Reuse nodes in the first chain to process output of concat.
- tail->set_input(0, concat_node->name());
- ctx().node_map->UpdateInput(tail->name(), tail_input,
- concat_node->name());
+ concat_node->set_input(concat_port, tail_input);
+ ctx().node_map->UpdateInput(concat_name, concat_input, tail_input);
+ if (concat_port == first_input) {
// Update the consumers of concat to consume the end of the chain
// instead.
- for (NodeDef* consumer : consumers) {
- for (int idx = 0; idx < consumer->input_size(); ++idx) {
- if (consumer->input(idx) == concat_node->name()) {
- consumer->set_input(idx, concat_input);
- ctx().node_map->UpdateInput(consumer->name(), concat_node->name(),
- concat_input);
- }
- }
- AddToOptimizationQueue(consumer);
- }
+ UpdateConsumers(concat_node, concat_input);
+ // Reuse nodes in the first chain to process output of concat.
+ tail->set_input(0, concat_name);
+ ctx().node_map->UpdateInput(tail->name(), tail_input, concat_name);
}
}
return Status::OK();
}
+
+ Status HoistChainForSplit(const int prefix_length, const ChainLinkSet& tails,
+ NodeDef* split_node) {
+ // Create a new chain before the split node to process the input tensor.
+ const string& split_name = split_node->name();
+ auto root_scope_and_name = ParseNodeScopeAndName(split_name);
+
+ // We use the first tail node in the set as a template to get the list of
+ // ops to apply (starting from the end).
+ NodeDef* cur_tail = tails.begin()->node;
+ NodeDef* cur_copy = AddCopyNode(
+ OptimizedNodeName(root_scope_and_name, cur_tail->name()), cur_tail);
+ cur_copy->clear_input();
+
+ // Update the split to take its input from the tail of the new chain.
+ const int value_slot = split_node->op() == "SplitV" ? 0 : 1;
+ const string orig_input = split_node->input(value_slot);
+ split_node->set_input(value_slot, cur_copy->name());
+ ctx().node_map->UpdateInput(split_node->name(), orig_input,
+ cur_copy->name());
+ TF_RETURN_IF_ERROR(GetInputNode(cur_tail->input(0), &cur_tail));
+
+ // Now walk backwards creating the rest of the chain.
+ while (cur_tail != split_node) {
+ NodeDef* new_copy = AddCopyNode(
+ OptimizedNodeName(root_scope_and_name, cur_tail->name()), cur_tail);
+ new_copy->clear_input();
+ cur_copy->add_input(new_copy->name());
+ ctx().node_map->AddOutput(new_copy->name(), cur_copy->name());
+ cur_copy = new_copy;
+ TF_RETURN_IF_ERROR(GetInputNode(cur_tail->input(0), &cur_tail));
+ }
+ // Connect the original input to the head of the new chain.
+ cur_copy->add_input(orig_input);
+ ctx().node_map->UpdateOutput(NodeName(orig_input), split_name,
+ cur_copy->name());
+
+ // Connect all consumers of the tail nodes directly to the
+ // output port of Split from which the chain started.
+ for (const auto& link : tails) {
+ UpdateConsumers(link.node,
+ link.port_origin == 0
+ ? split_name
+ : strings::StrCat(split_name, ":", link.port_origin));
+ }
+ return Status::OK();
+ }
+
+ // Update consumers of node to take new_input as input instead.
+ void UpdateConsumers(NodeDef* node, const string& new_input) {
+ const string& node_name = node->name();
+ const std::set<NodeDef*> consumers = ctx().node_map->GetOutputs(node_name);
+ for (NodeDef* consumer : consumers) {
+ for (int i = 0; i < consumer->input_size(); ++i) {
+ if (consumer->input(i) == node_name) {
+ consumer->set_input(i, new_input);
+ ctx().node_map->UpdateInput(consumer->name(), node_name, new_input);
+ }
+ }
+ AddToOptimizationQueue(consumer);
+ }
+ }
+
+ bool IsAlreadyOptimized(const NodeDef& node) const {
+ return optimized_nodes_.find(node.name()) != optimized_nodes_.end();
+ }
+
+ private:
+ bool node_is_concat_;
+ std::unordered_set<string> optimized_nodes_;
};
// Performs the conversion:
@@ -2200,8 +2358,8 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
pipeline.AddStage<RemoveRedundantCastStage>(ctx, ctx_ext);
if (options_.remove_negation)
pipeline.AddStage<RemoveNegationStage>(ctx, ctx_ext);
- if (options_.hoist_unary_out_of_concat)
- pipeline.AddStage<HoistCWiseUnaryFromConcatStage>(ctx, ctx_ext);
+ if (options_.hoist_cwise_unary_chains)
+ pipeline.AddStage<HoistCWiseUnaryChainsStage>(ctx, ctx_ext);
if (options_.convert_sqrt_div_to_rsqrt_mul)
pipeline.AddStage<SqrtDivToRsqrtMulStage>(ctx, ctx_ext);
@@ -2304,5 +2462,5 @@ void ArithmeticOptimizer::Feedback(Cluster* /*cluster*/,
// Nothing to do for ArithmeticOptimizer.
}
-} // end namespace grappler
-} // end namespace tensorflow
+} // namespace grappler
+} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
index 24a2a50719..3b297ec0aa 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
@@ -65,7 +65,7 @@ class ArithmeticOptimizer : public GraphOptimizer {
bool remove_redundant_bitcast = true;
bool remove_redundant_cast = true;
bool remove_negation = true;
- bool hoist_unary_out_of_concat = false;
+ bool hoist_cwise_unary_chains = false;
bool convert_sqrt_div_to_rsqrt_mul = false;
// Choose which arithmetic optimizer stages will be enabled for a given
@@ -73,9 +73,6 @@ class ArithmeticOptimizer : public GraphOptimizer {
static ArithmeticOptimizerOptions Default(
RewriterConfig::Toggle opt_level) {
ArithmeticOptimizerOptions options;
- if (opt_level == RewriterConfig::AGGRESSIVE) {
- options.hoist_unary_out_of_concat = true;
- }
return options;
}
};
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index 7485d99c3b..f903f53a35 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -94,6 +94,16 @@ class ArithmeticOptimizerTest : public GrapplerTest {
TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output));
}
+ // Run ArithmeticOptimizer twice to make sure the rewrite is idempotent.
+ void OptimizeTwiceAndPrune(ArithmeticOptimizer* optimizer, GrapplerItem* item,
+ GraphDef* output) {
+ TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output));
+ item->graph.Swap(output);
+ TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output));
+ item->graph.Swap(output);
+ TF_EXPECT_OK(ModelPruner().Optimize(nullptr, *item, output));
+ }
+
// TODO(ezhulenev): Make private. After migration to stages each test
// should explicitly enable required optimization for tests isolation
void DisableAllStages(ArithmeticOptimizer* optimizer) {
@@ -149,9 +159,9 @@ class ArithmeticOptimizerTest : public GrapplerTest {
optimizer->options_.remove_negation = true;
}
- void EnableOnlyHoistCWiseUnaryFromConcat(ArithmeticOptimizer* optimizer) {
+ void EnableOnlyHoistCWiseUnaryChains(ArithmeticOptimizer* optimizer) {
DisableAllStages(optimizer);
- optimizer->options_.hoist_unary_out_of_concat = true;
+ optimizer->options_.hoist_cwise_unary_chains = true;
}
void EnableOnlySqrtDivToRsqrtMul(ArithmeticOptimizer* optimizer) {
@@ -2136,14 +2146,18 @@ TEST_F(ArithmeticOptimizerTest, MinimizeBroadcasts_BuildTreeUp) {
TEST_F(ArithmeticOptimizerTest, HoistCWiseUnaryFromConcat) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- Output a = ops::Variable(s.WithOpName("a"), {32}, DT_FLOAT);
- Output b = ops::Variable(s.WithOpName("b"), {32}, DT_FLOAT);
- Output c = ops::Variable(s.WithOpName("c"), {32}, DT_FLOAT);
+ Output a = ops::Const(s.WithOpName("a"), 3.14f, {32});
+ Output b = ops::Const(s.WithOpName("b"), 1.0f, {32});
+ Output c = ops::Const(s.WithOpName("c"), 42.0f, {32});
Output axis = ops::Const(s.WithOpName("axis"), 0, {});
Output ctrl1 = ops::Const(s.WithOpName("ctrl1"), 1, {});
Output ctrl2 = ops::Const(s.WithOpName("ctrl2"), 2, {});
Output ctrl3 = ops::Const(s.WithOpName("ctrl3"), 3, {});
// Test case with chains of length 1.
+ // Rewrites
+ // Concat({Exp(a), Exp(b), Exp(c)})
+ // into
+ // Exp(Concat({a, b, c})).
Output sin_a =
ops::Sin(s.WithOpName("sin_a").WithControlDependencies(ctrl3), a);
Output exp_a =
@@ -2156,6 +2170,10 @@ TEST_F(ArithmeticOptimizerTest, HoistCWiseUnaryFromConcat) {
Output id = ops::Identity(s.WithOpName("id"), concat);
// Test case with chains of length 2.
+ // Rewrites
+ // Concat({Cos(Exp(a)), Cos(Exp(b)), Cos(Exp(c))})
+ // into
+ // Cos(Exp(Concat({a, b, c}))).
Output exp_a2 =
ops::Exp(s.WithOpName("exp_a2").WithControlDependencies(ctrl1), sin_a);
Output exp_b2 = ops::Exp(s.WithOpName("exp_b2"), b);
@@ -2173,11 +2191,13 @@ TEST_F(ArithmeticOptimizerTest, HoistCWiseUnaryFromConcat) {
item.fetch = {"id", "id2"};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+
GraphDef output;
ArithmeticOptimizer optimizer;
- EnableOnlyHoistCWiseUnaryFromConcat(&optimizer);
+ EnableOnlyHoistCWiseUnaryChains(&optimizer);
+ OptimizeTwiceAndPrune(&optimizer, &item, &output);
- OptimizeAndPrune(&optimizer, &item, &output);
int found = 0;
for (const NodeDef& node : output.node()) {
if (node.name() == "concat") {
@@ -2191,8 +2211,9 @@ TEST_F(ArithmeticOptimizerTest, HoistCWiseUnaryFromConcat) {
found++;
}
if (node.name() == "exp_a") {
- EXPECT_EQ(1, node.input_size());
+ EXPECT_EQ(2, node.input_size());
EXPECT_EQ("concat", node.input(0));
+ EXPECT_EQ("^ctrl1", node.input(1));
found++;
}
if (node.name() == "id") {
@@ -2213,13 +2234,15 @@ TEST_F(ArithmeticOptimizerTest, HoistCWiseUnaryFromConcat) {
found++;
}
if (node.name() == "exp_a2") {
- EXPECT_EQ(1, node.input_size());
+ EXPECT_EQ(2, node.input_size());
EXPECT_EQ("concat2", node.input(0));
+ EXPECT_EQ("^ctrl1", node.input(1));
found++;
}
if (node.name() == "cos_exp_a2") {
- EXPECT_EQ(1, node.input_size());
+ EXPECT_EQ(2, node.input_size());
EXPECT_EQ("exp_a2", node.input(0));
+ EXPECT_EQ("^ctrl1", node.input(1));
found++;
}
if (node.name() == "id2") {
@@ -2229,6 +2252,142 @@ TEST_F(ArithmeticOptimizerTest, HoistCWiseUnaryFromConcat) {
}
}
EXPECT_EQ(7, found);
+
+ auto tensors = EvaluateNodes(output, item.fetch);
+ EXPECT_EQ(tensors.size(), tensors_expected.size());
+ EXPECT_EQ(tensors.size(), item.fetch.size());
+ for (int i = 0; i < item.fetch.size(); ++i) {
+ test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-6);
+ }
+}
+
+TEST_F(ArithmeticOptimizerTest, HoistCWiseUnaryIntoSplit) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output x = ops::Const(s.WithOpName("x"), 3.1415f, {32});
+ Output axis = ops::Const(s.WithOpName("axis"), 0, {});
+ Output ctrl1 = ops::Const(s.WithOpName("ctrl1"), 1, {});
+ Output ctrl2 = ops::Const(s.WithOpName("ctrl2"), 2, {});
+ Output ctrl3 = ops::Const(s.WithOpName("ctrl3"), 3, {});
+ // Test case with chains of length 1.
+ // Rewrites
+ // [Sin(y) for y in Split(x)]
+ // into
+ // [y for y in Split(Sin(x))].
+ ops::Split split1(s.WithOpName("split1"), axis, x, 2);
+ Output sin_a =
+ ops::Sin(s.WithOpName("sin_a").WithControlDependencies(ctrl1), split1[0]);
+ Output id_a = ops::Identity(s.WithOpName("id_a"), sin_a);
+ Output sin_b = ops::Sin(s.WithOpName("sin_b"), split1[1]);
+ Output exp_b = ops::Exp(s.WithOpName("exp_b"), sin_b);
+ Output id_b = ops::Identity(s.WithOpName("id_b"), exp_b);
+
+ // Test case with SplitV and chains of length 2.
+ // Rewrites
+ // [Cos(Exp(y)) for y in Split(x)]
+ // into
+ // [y for y in Split(Cos(Exp(x)))].
+ Output size_splits2 = ops::Const(s.WithOpName("size_splits2"), {20, 12}, {2});
+ ops::SplitV split2(s.WithOpName("split2"), x, size_splits2, axis, 2);
+ Output exp_a2 = ops::Exp(
+ s.WithOpName("exp_a2").WithControlDependencies(ctrl1), split2[0]);
+ Output exp_b2 = ops::Exp(s.WithOpName("exp_b2"), split2[1]);
+ Output cos_exp_a2 = ops::Cos(
+ s.WithOpName("cos_exp_a2").WithControlDependencies(ctrl2), exp_a2);
+ Output cos_exp_b2 = ops::Cos(
+ s.WithOpName("cos_exp_b2").WithControlDependencies(ctrl3), exp_b2);
+ Output id_a2 = ops::Identity(s.WithOpName("id_a2"), cos_exp_a2);
+ Output id_b2 = ops::Identity(s.WithOpName("id_b2"), cos_exp_b2);
+
+ GrapplerItem item;
+ item.fetch = {"id_a", "id_b", "id_a2", "id_b2"};
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+
+ GraphDef output;
+ ArithmeticOptimizer optimizer;
+ EnableOnlyHoistCWiseUnaryChains(&optimizer);
+ OptimizeTwiceAndPrune(&optimizer, &item, &output);
+
+ int found = 0;
+ for (const NodeDef& node : output.node()) {
+ // The following 6 nodes should be pruned.
+ EXPECT_NE(node.name(), "sin_a");
+ EXPECT_NE(node.name(), "sin_b");
+ EXPECT_NE(node.name(), "exp_a2");
+ EXPECT_NE(node.name(), "exp_b2");
+ EXPECT_NE(node.name(), "cos_exp_a2");
+ EXPECT_NE(node.name(), "cos_exp_b2");
+
+ if (node.name() == "split1") {
+ EXPECT_EQ(3, node.input_size());
+ EXPECT_EQ("axis", node.input(0));
+ EXPECT_EQ("ArithmeticOptimizer/_sin_a_split1", node.input(1));
+ EXPECT_EQ("^ctrl1", node.input(2));
+ found++;
+ }
+ if (node.name() == "ArithmeticOptimizer/_sin_a_split1") {
+ EXPECT_EQ("Sin", node.op());
+ EXPECT_EQ(1, node.input_size());
+ EXPECT_EQ("x", node.input(0));
+ found++;
+ }
+ if (node.name() == "id_a") {
+ EXPECT_EQ(1, node.input_size());
+ EXPECT_EQ("split1", node.input(0));
+ found++;
+ }
+ if (node.name() == "exp_b") {
+ EXPECT_EQ(1, node.input_size());
+ EXPECT_EQ("split1:1", node.input(0));
+ found++;
+ }
+ if (node.name() == "id_b") {
+ EXPECT_EQ(1, node.input_size());
+ EXPECT_EQ("exp_b", node.input(0));
+ found++;
+ }
+ if (node.name() == "ArithmeticOptimizer/_exp_a2_split2") {
+ EXPECT_EQ("Exp", node.op());
+ EXPECT_EQ(1, node.input_size());
+ EXPECT_EQ("x", node.input(0));
+ found++;
+ }
+ if (node.name() == "ArithmeticOptimizer/_cos_exp_a2_split2") {
+ EXPECT_EQ("Cos", node.op());
+ EXPECT_EQ(1, node.input_size());
+ EXPECT_EQ("ArithmeticOptimizer/_exp_a2_split2", node.input(0));
+ found++;
+ }
+ if (node.name() == "split2") {
+ EXPECT_EQ(6, node.input_size());
+ EXPECT_EQ("ArithmeticOptimizer/_cos_exp_a2_split2", node.input(0));
+ EXPECT_EQ("size_splits2", node.input(1));
+ EXPECT_EQ("axis", node.input(2));
+ EXPECT_EQ("^ctrl1", node.input(3));
+ EXPECT_EQ("^ctrl2", node.input(4));
+ EXPECT_EQ("^ctrl3", node.input(5));
+ found++;
+ }
+ if (node.name() == "id_a2") {
+ EXPECT_EQ(1, node.input_size());
+ EXPECT_EQ("split2", node.input(0));
+ found++;
+ }
+ if (node.name() == "id_b2") {
+ EXPECT_EQ(1, node.input_size());
+ EXPECT_EQ("split2:1", node.input(0));
+ found++;
+ }
+ }
+ EXPECT_EQ(10, found);
+
+ auto tensors = EvaluateNodes(output, item.fetch);
+ EXPECT_EQ(tensors.size(), tensors_expected.size());
+ EXPECT_EQ(tensors.size(), item.fetch.size());
+ for (int i = 0; i < item.fetch.size(); ++i) {
+ test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-6);
+ }
}
} // namespace grappler
diff --git a/tensorflow/core/grappler/optimizers/function_optimizer.cc b/tensorflow/core/grappler/optimizers/function_optimizer.cc
index 3a6de9e3b2..1bec9086f7 100644
--- a/tensorflow/core/grappler/optimizers/function_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/function_optimizer.cc
@@ -79,6 +79,7 @@ class FunctionOptimizerContext {
explicit FunctionOptimizerContext(RewriterConfig::Toggle opt_level,
const GrapplerItem& item)
: function_library_(OpRegistry::Global(), item.graph.library()) {
+ InitializeTrulyConstNodes(item);
InitializeInlinedFunctions(opt_level, item);
}
@@ -86,20 +87,41 @@ class FunctionOptimizerContext {
return function_library_;
}
- FunctionLibraryDefinition& mutable_function_library() {
- return function_library_;
+ FunctionLibraryDefinition* mutable_function_library() {
+ return &function_library_;
}
bool IsInlinedFunction(const string& name) const {
return inlined_functions_.count(name) > 0;
}
+ bool IsTrulyConst(const string& name) const {
+ return TrulyConstNode(name) != nullptr;
+ }
+
+ const NodeDef* TrulyConstNode(const string& name) const {
+ return gtl::FindWithDefault(truly_const_nodes_, name, nullptr);
+ }
+
// Find inlining candidate by name. Return nullptr if not found.
const FunctionDef* FindInlinedFunction(const string& name) const {
return gtl::FindWithDefault(inlined_functions_, name, nullptr);
}
private:
+ void InitializeTrulyConstNodes(const GrapplerItem& item) {
+ std::unordered_set<string> feed_nodes;
+ for (const auto& feed : item.feed) {
+ feed_nodes.insert(NodeName(feed.first));
+ }
+
+ for (const NodeDef& node : item.graph.node()) {
+ if (IsConstant(node) && feed_nodes.count(node.name()) == 0) {
+ truly_const_nodes_[node.name()] = &node;
+ }
+ }
+ }
+
void InitializeInlinedFunctions(RewriterConfig::Toggle opt_level,
const GrapplerItem& item) {
bool aggressive = opt_level == RewriterConfig::AGGRESSIVE;
@@ -123,10 +145,20 @@ class FunctionOptimizerContext {
FunctionLibraryDefinition function_library_;
// Functions that can be inlined into optimized graph.
std::unordered_map<string, const FunctionDef*> inlined_functions_;
+ // Nodes that are Const and not in feed.
+ std::unordered_map<string, const NodeDef*> truly_const_nodes_;
TF_DISALLOW_COPY_AND_ASSIGN(FunctionOptimizerContext);
};
+bool HasTrulyConstInputs(const NodeDef& node,
+ const FunctionOptimizerContext& ctx) {
+ const auto is_truly_const = [&ctx](const string& input) {
+ return ctx.IsTrulyConst(NodeName(input));
+ };
+ return std::any_of(node.input().begin(), node.input().end(), is_truly_const);
+}
+
// Return trimmed FunctionDefLibrary with functions that are reachable from
// the optimized graph.
FunctionDefLibrary TrimFunctionLibrary(const FunctionLibraryDefinition& flib,
@@ -208,6 +240,77 @@ FunctionDefLibrary TrimFunctionLibrary(const FunctionLibraryDefinition& flib,
return lib;
}
+// Push all constant inputs of an instantiating node into the function body.
+Status PushDownConstInputs(const NodeDef& func_node,
+ const FunctionOptimizerContext& ctx,
+ GrapplerFunctionItem* item,
+ std::unordered_set<string>* const_inputs,
+ std::unordered_set<string>* control_deps) {
+ // Record node control dependencies in the control_deps set.
+ const auto record_control_deps = [&](const NodeDef* const_input) {
+ for (int i = const_input->input_size() - 1; i >= 0; --i) {
+ const string& input = const_input->input(i);
+ if (IsControlInput(input))
+ control_deps->insert(input);
+ else
+ break;
+ }
+ };
+
+ for (int i = func_node.input_size() - 1; i >= 0; --i) {
+ const string& input = func_node.input(i);
+ if (IsControlInput(input)) continue;
+
+ const string node_name = NodeName(input);
+ if (ctx.IsTrulyConst(node_name)) {
+ VLOG(3) << "Push const into function body: input=" << input;
+ const auto* const_input = CHECK_NOTNULL(ctx.TrulyConstNode(node_name));
+ const_inputs->insert(input);
+ record_control_deps(const_input);
+ TF_RETURN_IF_ERROR(ReplaceInputWithConst(*const_input, i, item));
+ }
+ }
+
+ return Status::OK();
+}
+
+// Remove inputs that were pushed into the function body, and attach their
+// control dependencies to the function caller node.
+void RemovePushedDownConstInputs(const std::unordered_set<string>& const_inputs,
+ const std::unordered_set<string>& control_deps,
+ NodeDef* specialized_func_node) {
+ // Nothing to do if it was no const inputs to the function node.
+ if (const_inputs.empty()) return;
+
+ // Keep only non-const inputs.
+ std::vector<string> keep_inputs;
+ const auto& inputs = specialized_func_node->input();
+ std::copy_if(inputs.begin(), inputs.end(), std::back_inserter(keep_inputs),
+ [&](const string& input) {
+ return const_inputs.find(input) == const_inputs.end();
+ });
+
+ specialized_func_node->clear_input();
+ for (const auto& keep : keep_inputs) specialized_func_node->add_input(keep);
+
+ // Attach control dependencies of pushed down const input to the caller node.
+ if (!control_deps.empty()) {
+ std::unordered_set<string> existing_control_deps;
+
+ for (const string& input : keep_inputs) {
+ existing_control_deps.insert(AsControlDependency(NodeName(input)));
+ }
+
+ for (const string& ctrl : control_deps) {
+ if (existing_control_deps.find(ctrl) == existing_control_deps.end()) {
+ VLOG(3) << "Forward control dependency to function caller node: input="
+ << ctrl;
+ specialized_func_node->add_input(ctrl);
+ }
+ }
+ }
+}
+
Status SpecializeFunction(const NodeDef& func_node, const FunctionDef& func,
FunctionOptimizerContext* ctx,
GraphDef* optimized_graph) {
@@ -219,11 +322,19 @@ Status SpecializeFunction(const NodeDef& func_node, const FunctionDef& func,
const auto& flib = ctx->function_library();
- // Make a GrapplerFunctionItem and immediately convert it back to FunctionDef.
+ // Make a GrapplerFunctionItem and convert it back to FunctionDef after
+ // pushing all constant inputs into the function body.
GrapplerFunctionItem item;
TF_RETURN_IF_ERROR(MakeGrapplerFunctionItem(func, func_attr, flib, &item));
- // TODO(ezhulenev): Push down const inputs and known input shapes.
+ // Push const inputs into the function body, and keep track of their control
+ // dependencies.
+ std::unordered_set<string> const_inputs;
+ std::unordered_set<string> control_deps;
+ TF_RETURN_IF_ERROR(PushDownConstInputs(func_node, *ctx, &item, &const_inputs,
+ &control_deps));
+
+ // TODO(ezhulenev): Push down known input shapes.
FunctionDef specialized_func;
TF_RETURN_IF_ERROR(MakeFunctionDef(item, flib, &specialized_func));
@@ -237,13 +348,16 @@ Status SpecializeFunction(const NodeDef& func_node, const FunctionDef& func,
// Add specialized function to the library.
TF_RETURN_IF_ERROR(
- ctx->mutable_function_library().AddFunctionDef(specialized_func));
+ ctx->mutable_function_library()->AddFunctionDef(specialized_func));
// Add a function call node for the specialized function.
NodeDef* specialized_func_node = optimized_graph->add_node();
*specialized_func_node = func_node;
specialized_func_node->set_op(specialized_func_name);
+ // Update specialized node to remove inputs for pushed down consts.
+ RemovePushedDownConstInputs(const_inputs, control_deps,
+ specialized_func_node);
return Status::OK();
}
@@ -582,11 +696,9 @@ Status FunctionOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
// Do not specialize if function has custom gradient.
const string grad_func = ctx.function_library().FindGradient(func_name);
- if (specialize_func && grad_func.empty() && IsParametrized(*func)) {
- // TODO(ezhulenev): Specialize function call if input is a Const or has
- // a known shape. Const input tensors can be pushed into the function
- // body and removed from function inputs.
-
+ if (specialize_func && grad_func.empty() &&
+ (IsParametrized(*func) || HasTrulyConstInputs(node, ctx))) {
+ // TODO(ezhulenev): Specialize function call if input has a known shape.
// Specialize function body for its instantiation attributes and inputs.
TF_RETURN_IF_ERROR(
SpecializeFunction(node, *func, &ctx, optimized_graph));
diff --git a/tensorflow/core/grappler/optimizers/function_optimizer_test.cc b/tensorflow/core/grappler/optimizers/function_optimizer_test.cc
index 6147e8a27c..147a264421 100644
--- a/tensorflow/core/grappler/optimizers/function_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/function_optimizer_test.cc
@@ -657,5 +657,66 @@ TEST_F(FunctionOptimizerTest, SpecializeFunction_XTimesTwo) {
test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
}
+TEST_F(FunctionOptimizerTest, SpecializeFunction_PushDownConstInput) {
+ using test::function::NDef;
+
+ FunctionOptimizer optimizer(RewriterConfig::DEFAULT);
+
+ FunctionDef mul_func = FunctionDefHelper::Create(
+ "MyMul", {"x:T", "y:T"}, {"z:T"}, {"T: {float, double}"},
+ {{{"output"}, "Mul", {"x", "y"}, {{"T", "$T"}}}},
+ /* Mapping between function returns and function node outputs. */
+ {{"z", "output:z:0"}});
+
+ // Mark MyMul as noinline.
+ (*mul_func.mutable_attr())["_noinline"].set_b(true);
+ std::vector<FunctionDef> function_library = {mul_func};
+
+ // Build a graph to compute y = MyMul(x, 2.0).
+ const Tensor kTwo = test::AsScalar<float>(2.0);
+
+ GrapplerItem item;
+ item.graph = test::function::GDef(
+ {NDef("x", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
+ NDef("init", "NoOp", {}, {}, kDevice),
+ NDef("two", "Const", {"^init", "^x"},
+ {{"dtype", DT_FLOAT}, {"value", kTwo}}, kDevice),
+ NDef("y", "MyMul", {"x", "two"}, {{"T", DT_FLOAT}}, kDevice),
+ NDef("z", "Identity", {"y"}, {{"T", DT_FLOAT}}, kDevice)},
+ function_library);
+
+ GraphDef output;
+ TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ // Make sure that specialized function was added to the library and original
+ // function was removed.
+ ASSERT_EQ(1, output.library().function_size());
+
+ const FunctionDef& specialized = output.library().function(0);
+ EXPECT_EQ("MyMul_specialized_for_y", specialized.signature().name());
+ EXPECT_EQ(1, specialized.signature().input_arg_size());
+
+ // And 'y' node has control dependencies of a pushed down const node.
+ int count = 0;
+ for (const NodeDef& node : output.node()) {
+ if (node.name() == "y" && count++) {
+ ASSERT_EQ(2, node.input_size());
+ EXPECT_EQ("x", node.input(0));
+ EXPECT_EQ("^init", node.input(1));
+ }
+ }
+ EXPECT_EQ(1, count);
+
+ // And that graph evaluation yields the same result.
+ Tensor pi = test::AsScalar<float>(3.14f);
+ item.fetch = {"z"};
+ item.feed.emplace_back("x", pi);
+
+ auto tensors_expected = EvaluateFetchNodes(item);
+ GrapplerItem optimized(item, std::move(output));
+ auto tensors = EvaluateFetchNodes(optimized);
+ test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
+}
+
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/memory_optimizer.cc b/tensorflow/core/grappler/optimizers/memory_optimizer.cc
index c1fee0e993..7c6468bfcb 100644
--- a/tensorflow/core/grappler/optimizers/memory_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/memory_optimizer.cc
@@ -1219,6 +1219,80 @@ bool SwappingPass(RewriterConfig::MemOptType optimization_level,
return updated_graph;
}
+// TODO(rmlarsen): Add distributed TF test.
+Status RelaxAllocatorConstraints(GraphDef* optimized_graph) {
+ std::unordered_set<string> devices;
+ std::vector<int> assign_nodes;
+ bool found_send = false;
+ for (int i = 0; i < optimized_graph->node_size(); ++i) {
+ const NodeDef& node = optimized_graph->node(i);
+ devices.insert(node.device());
+ if (IsAssign(node)) {
+ assign_nodes.push_back(i);
+ }
+ if (IsSend(node)) {
+ found_send = true;
+ break;
+ }
+ }
+ if (!found_send && devices.size() == 1) {
+ for (int assign_idx : assign_nodes) {
+ // Set an attribute telling AssignOp to ignore allocator constraints.
+ NodeDef* assign_node = optimized_graph->mutable_node(assign_idx);
+ (*assign_node->mutable_attr())["_grappler_relax_allocator_constraints"]
+ .set_b(true);
+ }
+ return Status::OK();
+ }
+
+ std::unordered_set<int> optimized_nodes;
+ SimpleGraphView graph_view;
+ TF_RETURN_IF_ERROR(graph_view.Initialize(*optimized_graph));
+ for (int i : assign_nodes) {
+ if (optimized_nodes.find(i) == optimized_nodes.end()) {
+ const NodeDef& node = optimized_graph->node(i);
+ optimized_nodes.insert(i);
+ std::vector<int> assign_nodes_in_fanout;
+ assign_nodes_in_fanout.push_back(i);
+ std::set<int> transitive_fanout;
+ graph_view.DepthFirstSearch(std::unordered_set<string>{}, i,
+ &transitive_fanout);
+ const string& assign_device = node.device();
+ bool relax_constraint = true;
+ // If all nodes in the transitive fanout are on the same device as the
+ // assign node, there is no need to allocate the output in pinned memory.
+ for (int fanout : transitive_fanout) {
+ const NodeDef& fanout_node = optimized_graph->node(fanout);
+ if (relax_constraint &&
+ (fanout_node.device() != assign_device || IsSend(fanout_node))) {
+ relax_constraint = false;
+ }
+ if (optimized_nodes.find(fanout) == optimized_nodes.end() &&
+ IsAssign(fanout_node)) {
+ assign_nodes_in_fanout.push_back(fanout);
+ }
+ }
+
+ for (int assign_idx : assign_nodes_in_fanout) {
+ if (relax_constraint) {
+ // If all devices match in fanout of node(i) then, by transitivity,
+ // they must also match in the fanout of other assign nodes
+ // node(assign_idx) in the fanout, so we can process them here,
+ // and save computing their transitive fanout later.
+ optimized_nodes.insert(assign_idx);
+
+ // Set an attribute telling AssignOp to ignore allocator constraints.
+ NodeDef* assign_node = optimized_graph->mutable_node(assign_idx);
+ (*assign_node
+ ->mutable_attr())["_grappler_relax_allocator_constraints"]
+ .set_b(true);
+ }
+ }
+ }
+ }
+ return Status::OK();
+}
+
Status MemoryOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
GraphDef* optimized_graph) {
*optimized_graph = item.graph;
@@ -1251,6 +1325,8 @@ Status MemoryOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
}
}
+ TF_RETURN_IF_ERROR(RelaxAllocatorConstraints(&optimized_item.graph));
+
optimized_graph->Swap(&optimized_item.graph);
return Status::OK();
}
diff --git a/tensorflow/core/grappler/optimizers/memory_optimizer_test.cc b/tensorflow/core/grappler/optimizers/memory_optimizer_test.cc
index a1f80802dd..a3f0e07861 100644
--- a/tensorflow/core/grappler/optimizers/memory_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/memory_optimizer_test.cc
@@ -440,6 +440,140 @@ TEST_F(MemoryOptimizerTest, AccumulationRewrites) {
}
}
+class RelaxAllocatorConstraintsTest : public GrapplerTest {};
+
+TEST_F(RelaxAllocatorConstraintsTest, SameDevice) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output constant = ops::Const(s.WithOpName("constant").WithDevice("/cpu:0"),
+ -3.14f, {128, 128});
+ Output variable = ops::Variable(s.WithOpName("variable").WithDevice("/cpu:0"),
+ {128, 128}, DT_FLOAT);
+ Output assign = ops::Assign(s.WithOpName("assign").WithDevice("/cpu:0"),
+ variable, constant);
+ Output exp = ops::Exp(s.WithOpName("exp").WithDevice("/cpu:0"), assign);
+
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ MemoryOptimizer optimizer(RewriterConfig::MANUAL);
+ GraphDef output;
+ TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ auto node = output.node(2);
+ EXPECT_EQ("assign", node.name());
+ EXPECT_EQ(1, node.attr().count("_grappler_relax_allocator_constraints"));
+ EXPECT_EQ(true, node.attr().at("_grappler_relax_allocator_constraints").b());
+
+ item.fetch = {"exp"};
+ item.init_ops = {"variable"};
+ auto tensors_expected = EvaluateFetchNodes(item);
+ GrapplerItem optimized(item, std::move(output));
+ auto tensors = EvaluateFetchNodes(optimized);
+ test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
+}
+
+TEST_F(RelaxAllocatorConstraintsTest, DifferentDevice) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output constant = ops::Const(s.WithOpName("constant").WithDevice("/cpu:0"),
+ -3.14f, {128, 128});
+ Output variable = ops::Variable(s.WithOpName("variable").WithDevice("/cpu:0"),
+ {128, 128}, DT_FLOAT);
+ Output assign = ops::Assign(s.WithOpName("assign").WithDevice("/cpu:0"),
+ variable, constant);
+ // exp runs on a different device, so we cannot relax the allocation
+ // constraints on assign.
+ Output exp = ops::Exp(s.WithOpName("exp").WithDevice("/gpu:0"), assign);
+
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ MemoryOptimizer optimizer(RewriterConfig::MANUAL);
+ GraphDef output;
+ TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ auto node = output.node(2);
+ EXPECT_EQ("assign", node.name());
+ EXPECT_EQ(0, node.attr().count("_grappler_relax_allocator_constraints"));
+#if GOOGLE_CUDA
+ item.fetch = {"exp"};
+ item.init_ops = {"variable"};
+ auto tensors_expected = EvaluateFetchNodes(item);
+ GrapplerItem optimized(item, std::move(output));
+ auto tensors = EvaluateFetchNodes(optimized);
+ test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
+#endif
+}
+
+TEST_F(RelaxAllocatorConstraintsTest, SendNode) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output constant = ops::Const(s.WithOpName("constant").WithDevice("/cpu:0"),
+ -3.14f, {128, 128});
+ Output variable = ops::Variable(s.WithOpName("variable").WithDevice("/cpu:0"),
+ {128, 128}, DT_FLOAT);
+ Output assign = ops::Assign(s.WithOpName("assign").WithDevice("/cpu:0"),
+ variable, constant);
+
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ NodeDef* send = item.graph.add_node();
+ // Add a send node to the graph in the fanout of "assign".
+ send->set_name("send");
+ send->set_op("_Send");
+ send->add_input("assign");
+
+ MemoryOptimizer optimizer(RewriterConfig::MANUAL);
+ GraphDef output;
+ TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ auto node = output.node(2);
+ EXPECT_EQ("assign", node.name());
+ EXPECT_EQ(0, node.attr().count("_grappler_relax_allocator_constraints"));
+}
+
+TEST_F(RelaxAllocatorConstraintsTest, AssignNodeInFanout) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output constant0 = ops::Const(s.WithOpName("constant0").WithDevice("/cpu:0"),
+ -42.0f, {128, 128});
+ Output variable0 = ops::Variable(
+ s.WithOpName("variable0").WithDevice("/cpu:0"), {128, 128}, DT_FLOAT);
+ Output assign0 = ops::Assign(s.WithOpName("assign0").WithDevice("/cpu:0"),
+ variable0, constant0);
+ // The rest of the graph is on a second device, so we can relax the
+ // constraint for assign1, but not for assign0.
+ Output exp1 = ops::Exp(s.WithOpName("exp1").WithDevice("/gpu:0"), assign0);
+ Output variable1 = ops::Variable(
+ s.WithOpName("variable1").WithDevice("/gpu:0"), {128, 128}, DT_FLOAT);
+ Output assign1 = ops::Assign(s.WithOpName("assign1").WithDevice("/gpu:0"),
+ variable1, exp1);
+
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ MemoryOptimizer optimizer(RewriterConfig::MANUAL);
+ GraphDef output;
+ TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ auto node = output.node(3);
+ EXPECT_EQ("assign0", node.name());
+ EXPECT_EQ(0, node.attr().count("_grappler_relax_allocator_constraints"));
+
+ node = output.node(5);
+ EXPECT_EQ("assign1", node.name());
+ EXPECT_EQ(1, node.attr().count("_grappler_relax_allocator_constraints"));
+ EXPECT_EQ(true, node.attr().at("_grappler_relax_allocator_constraints").b());
+
+#if GOOGLE_CUDA
+ item.fetch = {"assign0", "assign1"};
+ item.init_ops = {"exp1", "variable1"};
+ auto tensors_expected = EvaluateFetchNodes(item);
+ GrapplerItem optimized(item, std::move(output));
+ auto tensors = EvaluateFetchNodes(optimized);
+ for (int i = 0; i < tensors_expected.size(); ++i) {
+ test::ExpectTensorEqual<float>(tensors_expected[i], tensors[i]);
+ }
+#endif
+}
+
} // namespace
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/utils.cc b/tensorflow/core/grappler/utils.cc
index 7398d2c896..c8e63f95e1 100644
--- a/tensorflow/core/grappler/utils.cc
+++ b/tensorflow/core/grappler/utils.cc
@@ -361,8 +361,11 @@ inline void STLSortAndRemoveDuplicates(T* v) {
}
} // namespace
-Status SimpleGraphView::Initialize(const GraphDef& graph, bool dedup_inputs,
- bool dedup_outputs) {
+Status SimpleGraphView::Initialize(
+ const GraphDef& graph,
+ const std::vector<std::pair<const NodeDef*, const NodeDef*>>*
+ extra_dependencies,
+ bool dedup_inputs, bool dedup_outputs) {
graph_ = &graph;
const int num_nodes = graph.node_size();
inputs_.clear();
@@ -381,6 +384,23 @@ Status SimpleGraphView::Initialize(const GraphDef& graph, bool dedup_inputs,
index_to_name_.push_back(node.name());
}
+ if (extra_dependencies) {
+ for (const auto& dep : *extra_dependencies) {
+ auto itr_src = name_to_index_.find(dep.first->name());
+ if (itr_src == name_to_index_.end()) {
+ return errors::InvalidArgument("Non-existent src ", dep.first->name());
+ }
+ auto itr_tgt = name_to_index_.find(dep.second->name());
+ if (itr_tgt == name_to_index_.end()) {
+ return errors::InvalidArgument("Non-existent tgt ", dep.second->name());
+ }
+ const int src_idx = itr_src->second;
+ const int tgt_idx = itr_tgt->second;
+ inputs_[tgt_idx].push_back(src_idx);
+ outputs_[src_idx].push_back(tgt_idx);
+ }
+ }
+
// Build forward and reverse adjacency lists.
for (int node_idx = 0; node_idx < num_nodes; ++node_idx) {
const NodeDef& node = graph.node(node_idx);
@@ -415,7 +435,8 @@ void SimpleGraphView::DepthFirstSearch(
std::set<int>* nodes_found) const {
nodes_found->clear();
const string& op_type = graph_->node(root_node).op();
- if (op_types_to_traverse.find(op_type) == op_types_to_traverse.end()) {
+ if (!op_types_to_traverse.empty() &&
+ op_types_to_traverse.find(op_type) == op_types_to_traverse.end()) {
return;
}
std::vector<int> stack;
@@ -426,7 +447,8 @@ void SimpleGraphView::DepthFirstSearch(
stack.pop_back();
nodes_found->insert(node_idx);
const string& op_type = graph_->node(node_idx).op();
- if (op_types_to_traverse.find(op_type) != op_types_to_traverse.end()) {
+ if (op_types_to_traverse.empty() ||
+ op_types_to_traverse.find(op_type) != op_types_to_traverse.end()) {
for (auto output_idx : this->outputs(node_idx)) {
if (nodes_found->find(output_idx) == nodes_found->end()) {
stack.push_back(output_idx);
diff --git a/tensorflow/core/grappler/utils.h b/tensorflow/core/grappler/utils.h
index 54cb26bafa..9776e99f20 100644
--- a/tensorflow/core/grappler/utils.h
+++ b/tensorflow/core/grappler/utils.h
@@ -211,11 +211,24 @@ Status SetTensorValue(DataType dtype, int value, Tensor* tensor);
class SimpleGraphView {
public:
+ // Build a graph view for the specified graphdef.
Status Initialize(const GraphDef& graph) {
- return Initialize(graph, true, true);
+ return Initialize(graph, nullptr, true, true);
}
- Status Initialize(const GraphDef& graph, bool dedup_inputs,
- bool dedup_outputs);
+ // Build a graph view for the specified graphdef augmented with the additional
+ // edges specified in 'extra_dependencies' if any. Note that
+ // extra_dependencies can be null.
+ Status Initialize(
+ const GraphDef& graph,
+ const std::vector<std::pair<const NodeDef*, const NodeDef*>>*
+ extra_dependencies) {
+ return Initialize(graph, extra_dependencies, true, true);
+ }
+ Status Initialize(
+ const GraphDef& graph,
+ const std::vector<std::pair<const NodeDef*, const NodeDef*>>*
+ extra_dependencies,
+ bool dedup_inputs, bool dedup_outputs);
const GraphDef* graph() const { return graph_; }
inline int num_nodes() const { return index_to_name_.size(); }
@@ -238,6 +251,7 @@ class SimpleGraphView {
// visited in nodes_found. If a node has an op in `op_types_to_traverse`, the
// walk continues to its children. It is assumed that *graph_ was not modified
// after the call to Initialize().
+ // If `op_types_to_traverse` is empty the DFS will traverse any node type.
void DepthFirstSearch(const std::unordered_set<string>& op_types_to_traverse,
int node_idx, std::set<int>* nodes_found) const;
diff --git a/tensorflow/core/grappler/utils/topological_sort.cc b/tensorflow/core/grappler/utils/topological_sort.cc
index a8e464d09d..ff89035902 100644
--- a/tensorflow/core/grappler/utils/topological_sort.cc
+++ b/tensorflow/core/grappler/utils/topological_sort.cc
@@ -26,10 +26,12 @@ namespace grappler {
// Kahn's algorithm is implemented.
// For details, see https://en.wikipedia.org/wiki/Topological_sorting
-Status ComputeTopologicalOrder(const GraphDef& graph,
- std::vector<int>* ready_nodes) {
+Status ComputeTopologicalOrder(
+ const GraphDef& graph, std::vector<int>* ready_nodes,
+ const std::vector<std::pair<const NodeDef*, const NodeDef*>>*
+ extra_dependencies) {
SimpleGraphView graph_view;
- TF_RETURN_IF_ERROR(graph_view.Initialize(graph));
+ TF_RETURN_IF_ERROR(graph_view.Initialize(graph, extra_dependencies));
ready_nodes->reserve(graph_view.num_nodes());
@@ -70,10 +72,12 @@ Status ComputeTopologicalOrder(const GraphDef& graph,
}
Status ComputeTopologicalOrder(
- const GraphDef& graph,
- std::unordered_map<const NodeDef*, int>* topo_order) {
+ const GraphDef& graph, std::unordered_map<const NodeDef*, int>* topo_order,
+ const std::vector<std::pair<const NodeDef*, const NodeDef*>>*
+ extra_dependencies) {
std::vector<int> ready_nodes;
- TF_RETURN_IF_ERROR(ComputeTopologicalOrder(graph, &ready_nodes));
+ TF_RETURN_IF_ERROR(
+ ComputeTopologicalOrder(graph, &ready_nodes, extra_dependencies));
topo_order->reserve(graph.node_size());
for (int i = 0; i < ready_nodes.size(); ++i) {
(*topo_order)[&graph.node(ready_nodes[i])] = i;
@@ -83,7 +87,7 @@ Status ComputeTopologicalOrder(
Status TopologicalSort(GraphDef* graph) {
std::vector<int> ready_nodes;
- TF_RETURN_IF_ERROR(ComputeTopologicalOrder(*graph, &ready_nodes));
+ TF_RETURN_IF_ERROR(ComputeTopologicalOrder(*graph, &ready_nodes, nullptr));
PermuteNodesInPlace(graph, &ready_nodes, /*invert_permutation=*/true);
return Status::OK();
}
diff --git a/tensorflow/core/grappler/utils/topological_sort.h b/tensorflow/core/grappler/utils/topological_sort.h
index 668c88dc75..bc0299a7b8 100644
--- a/tensorflow/core/grappler/utils/topological_sort.h
+++ b/tensorflow/core/grappler/utils/topological_sort.h
@@ -24,7 +24,9 @@ namespace grappler {
// Compute a topological ordering for the graph nodes.
Status ComputeTopologicalOrder(
- const GraphDef& graph, std::unordered_map<const NodeDef*, int>* topo_order);
+ const GraphDef& graph, std::unordered_map<const NodeDef*, int>* topo_order,
+ const std::vector<std::pair<const NodeDef*, const NodeDef*>>*
+ extra_dependencies);
// Sort a graph in topological order.
Status TopologicalSort(GraphDef* graph);
diff --git a/tensorflow/core/grappler/utils/topological_sort_test.cc b/tensorflow/core/grappler/utils/topological_sort_test.cc
index f5c95009d2..48b7eb50bd 100644
--- a/tensorflow/core/grappler/utils/topological_sort_test.cc
+++ b/tensorflow/core/grappler/utils/topological_sort_test.cc
@@ -53,7 +53,7 @@ TEST_F(TopologicalSortTest, NoLoop) {
*graph.add_node() = CreateNode("4", {});
std::unordered_map<const NodeDef*, int> topo_order;
- TF_EXPECT_OK(ComputeTopologicalOrder(graph, &topo_order));
+ TF_EXPECT_OK(ComputeTopologicalOrder(graph, &topo_order, nullptr));
const std::vector<string> order = {"5", "4", "2", "0", "3", "1"};
for (const auto& topo : topo_order) {
@@ -80,7 +80,7 @@ TEST_F(TopologicalSortTest, WithLoop) {
*graph.add_node() = CreateNode("1", {});
std::unordered_map<const NodeDef*, int> topo_order;
- TF_EXPECT_OK(ComputeTopologicalOrder(graph, &topo_order));
+ TF_EXPECT_OK(ComputeTopologicalOrder(graph, &topo_order, nullptr));
const std::vector<string> order = {"1", "2", "3", "4", "5"};
for (const auto& topo : topo_order) {
@@ -143,6 +143,36 @@ TEST_F(TopologicalSortTest, Idempotent) {
}
}
+TEST_F(TopologicalSortTest, ExtraDependencies) {
+ GraphDef graph;
+ *graph.add_node() = CreateNode("2", {"5"});
+ *graph.add_node() = CreateNode("0", {"5", "4"});
+ *graph.add_node() = CreateNode("1", {"4", "3"});
+ *graph.add_node() = CreateNode("3", {"2"});
+ *graph.add_node() = CreateNode("5", {});
+ *graph.add_node() = CreateNode("4", {});
+
+ // Add an edge from 4 to 5.
+ std::vector<std::pair<const NodeDef*, const NodeDef*>> extra_dependencies;
+ extra_dependencies.emplace_back(&graph.node(5), &graph.node(4));
+
+ std::unordered_map<const NodeDef*, int> topo_order;
+ TF_EXPECT_OK(
+ ComputeTopologicalOrder(graph, &topo_order, &extra_dependencies));
+
+ const std::vector<string> order = {"4", "5", "2", "0", "3", "1"};
+ for (const auto& topo : topo_order) {
+ const string& node_name = topo.first->name();
+ const int topo_order = topo.second;
+ EXPECT_EQ(node_name, order[topo_order]);
+ }
+
+ // Add an edge from 0 to 4. This will create a loop
+ extra_dependencies.emplace_back(&graph.node(1), &graph.node(5));
+ EXPECT_FALSE(
+ ComputeTopologicalOrder(graph, &topo_order, &extra_dependencies).ok());
+}
+
} // namespace
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 6355f13654..3fb03cd5bd 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -3299,7 +3299,10 @@ tf_kernel_library(
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:nn_ops_op_lib",
- ] + if_cuda(["@cub_archive//:cub"]),
+ ] + if_cuda([
+ "@cub_archive//:cub",
+ "@local_config_cuda//cuda:cudnn",
+ ]),
)
tf_kernel_library(
@@ -3310,12 +3313,15 @@ tf_kernel_library(
prefix = "depthwise_conv_grad_op",
deps = [
":bounds_check",
+ ":conv_ops",
":ops_util",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:nn_ops_op_lib",
- ],
+ ] + if_cuda([
+ "@local_config_cuda//cuda:cudnn",
+ ]),
)
cc_library(
diff --git a/tensorflow/core/kernels/assign_op.h b/tensorflow/core/kernels/assign_op.h
index 2ed1628bf1..a450b1d1ee 100644
--- a/tensorflow/core/kernels/assign_op.h
+++ b/tensorflow/core/kernels/assign_op.h
@@ -36,6 +36,12 @@ class AssignOp : public OpKernel {
context->GetAttr("validate_shape", &validate_shape_));
OP_REQUIRES(context, IsRefType(context->input_type(0)),
errors::InvalidArgument("lhs input needs to be a ref type"));
+ if (!context
+ ->GetAttr("_grappler_relax_allocator_constraints",
+ &relax_constraints_)
+ .ok()) {
+ relax_constraints_ = false;
+ }
}
void Compute(OpKernelContext* context) override {
@@ -44,49 +50,37 @@ class AssignOp : public OpKernel {
// We always return the input ref.
context->forward_ref_input_to_ref_output(0, 0);
- // We can't always know how this value will be used downstream,
- // so make conservative assumptions in specifying constraints on
- // the memory allocation attributes.
- // TODO(rmlarsen): These conservative constraints make buffer
- // forwarding unlikely to happen very often. Try to use graph analysis
- // (possibly the InferAllocAttr pass in the executer) to improve the
- // situation.
+ // We can't always know how this value will be used downstream, so make
+ // conservative assumptions in specifying constraints on the memory
+ // allocation attributes, unless the Grappler graph analysis determined that
+ // it was safe not to.
AllocatorAttributes attr;
- attr.set_gpu_compatible(true);
- attr.set_nic_compatible(true);
+ if (!relax_constraints_) {
+ attr.set_gpu_compatible(true);
+ attr.set_nic_compatible(true);
+ }
{
mutex_lock l(*context->input_ref_mutex(0));
const Tensor& old_lhs = context->mutable_input(0, /* lock_held */ true);
const bool same_shape = old_lhs.shape().IsSameSize(rhs.shape());
if (validate_shape_) {
- OP_REQUIRES(
- context, same_shape,
- errors::InvalidArgument(
- "Assign requires shapes of both tensors to match. lhs shape= ",
- old_lhs.shape().DebugString(),
- " rhs shape= ", rhs.shape().DebugString()));
+ OP_REQUIRES(context, same_shape,
+ errors::InvalidArgument(
+ "Assign requires shapes of both tensors to match. "
+ "lhs shape= ",
+ old_lhs.shape().DebugString(),
+ " rhs shape= ", rhs.shape().DebugString()));
}
// In the code below we try to minimize the amount of memory allocation
// and copying by trying the following two shortcuts:
- // 1. If we can reuse the rhs buffer we avoid both a memory allocation
- // and copying.
- // 2. If the lhs is initialized and has the same number of elements as the
- // rhs we can avoid a memory allocation.
-
- // 1. Try to reuse the rhs.
- std::unique_ptr<Tensor> input_alias = context->forward_input(
- 1, OpKernelContext::Params::kNoReservation /*output_index*/,
- old_lhs.dtype(), old_lhs.shape(), DEVICE_MEMORY, attr);
- if (input_alias != nullptr) {
- // Transfer ownership to the ref.
- context->replace_ref_input(0, *input_alias.release(),
- /* lock_held */ true);
- return;
- }
+ // 1. If the lhs is initialized and has the same number of elements as
+ // the rhs we can avoid a memory allocation.
+ // 2. If we can reuse the rhs buffer we avoid both a memory allocation
+ // and copying.
- // 2. Try to copy into an existing buffer.
+ // 1. Try to copy into an existing buffer.
if (old_lhs.IsInitialized() &&
old_lhs.shape().num_elements() == rhs.shape().num_elements()) {
// The existing lhs tensor has already been initialized and the right
@@ -96,15 +90,26 @@ class AssignOp : public OpKernel {
reshaped_old_lhs = old_lhs;
} else {
CHECK(reshaped_old_lhs.CopyFrom(old_lhs, rhs.shape()));
- context->replace_ref_input(0, reshaped_old_lhs, /* lock_held */ true);
+ context->replace_ref_input(0, reshaped_old_lhs,
+ /* lock_held */ true);
}
if (use_exclusive_lock_) {
Copy(context, &reshaped_old_lhs, rhs);
return;
}
} else {
- // Create a new persistent tensor whose shape matches the right hand
- // side, hand off to lhs and copy the rhs into it.
+ // 2. Try to reuse the rhs.
+ std::unique_ptr<Tensor> input_alias = context->forward_input(
+ 1, OpKernelContext::Params::kNoReservation /*output_index*/,
+ rhs.dtype(), rhs.shape(), DEVICE_MEMORY, attr);
+ if (input_alias != nullptr) {
+ // Update the ref to point to the new buffer.
+ context->replace_ref_input(0, *input_alias, /* lock_held */ true);
+ return;
+ }
+
+ // Otherwise, create a new persistent tensor whose shape matches the
+ // right hand side, hand off to lhs and copy the rhs into it.
PersistentTensor copy;
Tensor* copyTensor = nullptr;
OP_REQUIRES_OK(
@@ -133,6 +138,7 @@ class AssignOp : public OpKernel {
bool use_exclusive_lock_;
bool validate_shape_;
+ bool relax_constraints_;
};
} // end namespace tensorflow
diff --git a/tensorflow/core/kernels/broadcast_to_op.h b/tensorflow/core/kernels/broadcast_to_op.h
index 608e9b6ac9..73fdd5d28e 100644
--- a/tensorflow/core/kernels/broadcast_to_op.h
+++ b/tensorflow/core/kernels/broadcast_to_op.h
@@ -34,14 +34,37 @@ struct BroadcastTo {
const TensorShape &input_shape) {
#define BROADCAST_SHAPE(broadcast, reshape, NDIMS, input_shape, output_shape) \
for (int i = 0; i < NDIMS; i++) { \
- OP_REQUIRES(ctx, (broadcast[i] % reshape[i] == 0), \
- errors::InvalidArgument("invalid shape to broadcast from ", \
- input_shape.DebugString(), " to ", \
- output_shape.DebugString())); \
- broadcast[i] = broadcast[i] / reshape[i]; \
+ if (reshape[i] != broadcast[i]) { \
+ OP_REQUIRES(ctx, \
+ ((reshape[i] != 0) && (broadcast[i] % reshape[i] == 0)), \
+ errors::InvalidArgument("invalid shape to broadcast from ", \
+ input_shape.DebugString(), " to ", \
+ output_shape.DebugString())); \
+ broadcast[i] = broadcast[i] / reshape[i]; \
+ } else { \
+ broadcast[i] = 1; \
+ } \
}
+ if (output_shape.num_elements() == 0) {
+ return;
+ }
+ if (output_shape == input_shape) {
+ output_tensor.flat<T>().device(d) = input_tensor.flat<T>();
+ return;
+ }
+
switch (output_shape.dims()) {
+ case 0: {
+ if (input_shape.dims() > 0) {
+ ctx->CtxFailure(errors::InvalidArgument(
+ "invalid shape to broadcast from ", input_shape.DebugString(),
+ " to ", output_shape.DebugString()));
+ break;
+ }
+ output_tensor.scalar<T>().device(d) = input_tensor.scalar<T>();
+ break;
+ }
case 1: {
auto reshape = AsEigenDSizesWithPrefix<1>(input_shape);
auto broadcast = output_shape.AsEigenDSizes<1>();
@@ -125,7 +148,6 @@ struct BroadcastTo {
auto broadcast = output_shape.AsEigenDSizes<4>();
BROADCAST_SHAPE(broadcast, reshape, 4, input_shape, output_shape);
-
auto output = output_tensor.tensor<T, 4>();
switch (input_shape.dims()) {
case 0: {
diff --git a/tensorflow/core/kernels/conv_grad_filter_ops.cc b/tensorflow/core/kernels/conv_grad_filter_ops.cc
index ef1e73e5ab..aca75176a5 100644
--- a/tensorflow/core/kernels/conv_grad_filter_ops.cc
+++ b/tensorflow/core/kernels/conv_grad_filter_ops.cc
@@ -96,7 +96,8 @@ template <typename T>
struct LaunchConv2DBackpropFilterOp<CPUDevice, T> {
void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
const Tensor& out_backprop, const Tensor& input,
- int row_stride, int col_stride, const Padding& padding,
+ int row_dilation, int col_dilation, int row_stride,
+ int col_stride, const Padding& padding,
Tensor* filter_backprop, TensorFormat data_format) {
const CPUDevice& d = ctx->eigen_device<CPUDevice>();
functor::SpatialConvolutionBackwardFilter<CPUDevice, T>()(
@@ -275,7 +276,8 @@ class Conv2DFastBackpropFilterOp : public OpKernel {
#endif
LaunchConv2DBackpropFilterOp<Device, T>()(
- context, false, false, out_backprop, input, dims.spatial_dims[0].stride,
+ context, false, false, out_backprop, input,
+ /*row_dilation=*/1, /*col_dilation=*/1, dims.spatial_dims[0].stride,
dims.spatial_dims[1].stride, padding_, filter_backprop, data_format_);
}
@@ -523,6 +525,11 @@ TF_CALL_float(REGISTER_CPU_KERNELS);
TF_CALL_double(REGISTER_CPU_KERNELS);
#undef REGISTER_CPU_KERNELS
+// To be used inside depthwise_conv_grad_op.cc.
+template struct LaunchConv2DBackpropFilterOp<CPUDevice, Eigen::half>;
+template struct LaunchConv2DBackpropFilterOp<CPUDevice, float>;
+template struct LaunchConv2DBackpropFilterOp<CPUDevice, double>;
+
// GPU definitions.
#if GOOGLE_CUDA
// The slow version (but compiles for GPU)
@@ -690,10 +697,15 @@ void LaunchConv2DBackpropFilterOp<Eigen::GpuDevice, T>::operator()(
return;
}
+ // If the filter in-depth (filter_shape.dim_size(2)) is 1 and smaller than the
+ // input depth, it's a depthwise convolution. More generally, if the filter
+ // in-depth divides but is smaller than the input depth, it is a grouped
+ // convolution.
+ bool is_grouped_convolution = filter_shape.dim_size(2) != dims.in_depth;
bool cudnn_disable_conv_1x1_optimization_ = CudnnDisableConv1x1Optimization();
if (!cudnn_disable_conv_1x1_optimization_ &&
dims.spatial_dims[0].filter_size == 1 &&
- dims.spatial_dims[1].filter_size == 1 &&
+ dims.spatial_dims[1].filter_size == 1 && !is_grouped_convolution &&
dims.spatial_dims[0].stride == 1 && dims.spatial_dims[1].stride == 1 &&
data_format == FORMAT_NHWC) {
const uint64 m = dims.in_depth;
@@ -734,9 +746,10 @@ void LaunchConv2DBackpropFilterOp<Eigen::GpuDevice, T>::operator()(
dims.spatial_dims[0].input_size &&
dims.spatial_dims[1].filter_size ==
dims.spatial_dims[1].input_size &&
- padding == VALID && data_format == FORMAT_NHWC) {
- // The input data and filter have the same height/width, so call cublas
- // directly.
+ !is_grouped_convolution && padding == VALID &&
+ data_format == FORMAT_NHWC) {
+ // The input data and filter have the same height/width, and we are not
+ // using grouped convolution, so call cublas directly.
const uint64 m = dims.spatial_dims[0].input_size *
dims.spatial_dims[1].input_size * dims.in_depth;
const uint64 k = dims.batch_size;
@@ -802,15 +815,16 @@ void LaunchConv2DBackpropFilterOp<Eigen::GpuDevice, T>::operator()(
se::dnn::FilterDescriptor filter_desc;
filter_desc.set_input_filter_height(dims.spatial_dims[0].filter_size)
.set_input_filter_width(dims.spatial_dims[1].filter_size)
- .set_input_feature_map_count(dims.in_depth)
- .set_output_feature_map_count(dims.out_depth);
+ .set_input_feature_map_count(filter_shape.dim_size(2))
+ .set_output_feature_map_count(filter_shape.dim_size(3));
se::dnn::ConvolutionDescriptor conv_desc;
conv_desc.set_vertical_dilation_rate(dims.spatial_dims[0].dilation)
.set_horizontal_dilation_rate(dims.spatial_dims[1].dilation)
.set_vertical_filter_stride(dims.spatial_dims[0].stride)
.set_horizontal_filter_stride(dims.spatial_dims[1].stride)
.set_zero_padding_height(padding_rows / 2)
- .set_zero_padding_width(padding_cols / 2);
+ .set_zero_padding_width(padding_cols / 2)
+ .set_group_count(dims.in_depth / filter_shape.dim_size(2));
// NOTE(zhengxq):
// cuDNN only supports the following layouts :
@@ -891,21 +905,22 @@ void LaunchConv2DBackpropFilterOp<Eigen::GpuDevice, T>::operator()(
int device_id = stream->parent()->device_ordinal();
DataType dtype = input.dtype();
ConvParameters conv_parameters = {
- dims.batch_size, // batch
- dims.in_depth, // in_depths
- {{input_desc.height(), // in_rows
- input_desc.width()}}, // in_cols
- dims.out_depth, // out_depths
- {{dims.spatial_dims[0].filter_size, // filter_rows
- dims.spatial_dims[1].filter_size}}, // filter_cols
- {{dims.spatial_dims[0].dilation, // dilation_rows
- dims.spatial_dims[1].dilation}}, // dilation_cols
- {{dims.spatial_dims[0].stride, // stride_rows
- dims.spatial_dims[1].stride}}, // stride_cols
- {{padding_rows, // padding_rows
- padding_cols}}, // padding_cols
- dtype, // tensor datatype
- device_id, // device_id
+ dims.batch_size, // batch
+ dims.in_depth, // in_depths
+ {{input_desc.height(), // in_rows
+ input_desc.width()}}, // in_cols
+ dims.out_depth, // out_depths
+ {{dims.spatial_dims[0].filter_size, // filter_rows
+ dims.spatial_dims[1].filter_size, // filter_cols
+ filter_shape.dim_size(2)}}, // filter_depth
+ {{dims.spatial_dims[0].dilation, // dilation_rows
+ dims.spatial_dims[1].dilation}}, // dilation_cols
+ {{dims.spatial_dims[0].stride, // stride_rows
+ dims.spatial_dims[1].stride}}, // stride_cols
+ {{padding_rows, // padding_rows
+ padding_cols}}, // padding_cols
+ dtype, // tensor datatype
+ device_id, // device_id
};
AlgorithmConfig algorithm_config;
if (cudnn_use_autotune && !AutoTuneConvBwdFilter::GetInstance()->Find(
@@ -1019,9 +1034,9 @@ namespace functor {
typename TTypes<T, 4, int>::Tensor out, TensorFormat data_format); \
extern template struct PadInput<GPUDevice, T, int, 4>;
-DECLARE_GPU_SPEC(double);
DECLARE_GPU_SPEC(float);
DECLARE_GPU_SPEC(Eigen::half);
+DECLARE_GPU_SPEC(double);
#undef DECLARE_GPU_SPEC
} // namespace functor
@@ -1040,6 +1055,12 @@ REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropFilter")
.TypeConstraint<Eigen::half>("T")
.HostMemory("filter_sizes"),
Conv2DSlowBackpropFilterOp<GPUDevice, Eigen::half>);
+
+// To be used inside depthwise_conv_grad_op.cc.
+template struct LaunchConv2DBackpropFilterOp<GPUDevice, float>;
+template struct LaunchConv2DBackpropFilterOp<GPUDevice, Eigen::half>;
+template struct LaunchConv2DBackpropFilterOp<GPUDevice, double>;
+
#endif // GOOGLE_CUDA
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/conv_grad_input_ops.cc b/tensorflow/core/kernels/conv_grad_input_ops.cc
index 35f2676023..63a775afa8 100644
--- a/tensorflow/core/kernels/conv_grad_input_ops.cc
+++ b/tensorflow/core/kernels/conv_grad_input_ops.cc
@@ -101,8 +101,9 @@ template <typename T>
struct LaunchConv2DBackpropInputOp<CPUDevice, T> {
void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
const Tensor& out_backprop, const Tensor& filter,
- int row_stride, int col_stride, const Padding& padding,
- Tensor* in_backprop, TensorFormat data_format) {
+ int row_dilation, int col_dilation, int row_stride,
+ int col_stride, const Padding& padding, Tensor* in_backprop,
+ TensorFormat data_format) {
const CPUDevice& d = ctx->eigen_device<CPUDevice>();
functor::SpatialConvolutionBackwardInput<CPUDevice, T>()(
d, in_backprop->tensor<T, 4>(), filter.tensor<T, 4>(),
@@ -280,8 +281,8 @@ class Conv2DFastBackpropInputOp : public OpKernel {
LaunchConv2DBackpropInputOp<Device, T>()(
context, false, false, out_backprop, filter,
- dims.spatial_dims[0].stride, dims.spatial_dims[1].stride, padding_,
- in_backprop, data_format_);
+ /*row_dilation=*/1, /*col_dilation=*/1, dims.spatial_dims[0].stride,
+ dims.spatial_dims[1].stride, padding_, in_backprop, data_format_);
}
private:
@@ -595,6 +596,11 @@ TF_CALL_float(REGISTER_CPU_KERNELS);
TF_CALL_double(REGISTER_CPU_KERNELS);
#undef REGISTER_CPU_KERNELS
+// To be used inside depthwise_conv_grad_op.cc.
+template struct LaunchConv2DBackpropInputOp<CPUDevice, Eigen::half>;
+template struct LaunchConv2DBackpropInputOp<CPUDevice, float>;
+template struct LaunchConv2DBackpropInputOp<CPUDevice, double>;
+
// GPU definitions.
#if GOOGLE_CUDA
// The slow version (but compiles for GPU)
@@ -761,8 +767,13 @@ void LaunchConv2DBackpropInputOp<GPUDevice, T>::operator()(
return;
}
+ // If the filter in-depth (filter_shape.dim_size(2)) is 1 and smaller than the
+ // input depth, it's a depthwise convolution. More generally, if the filter
+ // in-depth divides but is smaller than the input depth, it is a grouped
+ // convolution.
+ bool is_grouped_convolution = filter_shape.dim_size(2) != dims.in_depth;
if (dims.spatial_dims[0].filter_size == 1 &&
- dims.spatial_dims[1].filter_size == 1 &&
+ dims.spatial_dims[1].filter_size == 1 && !is_grouped_convolution &&
dims.spatial_dims[0].stride == 1 && dims.spatial_dims[1].stride == 1 &&
data_format == FORMAT_NHWC) {
// 1x1 filter, so call cublas directly.
@@ -795,9 +806,10 @@ void LaunchConv2DBackpropInputOp<GPUDevice, T>::operator()(
dims.spatial_dims[0].input_size &&
dims.spatial_dims[1].filter_size ==
dims.spatial_dims[1].input_size &&
- padding == VALID && data_format == FORMAT_NHWC) {
- // The input data and filter have the same height/width, so call cublas
- // directly.
+ !is_grouped_convolution && padding == VALID &&
+ data_format == FORMAT_NHWC) {
+ // The input data and filter have the same height/width, and we are not
+ // using grouped convolution, so call cublas directly.
const uint64 m = dims.batch_size;
const uint64 k = dims.out_depth;
const uint64 n = dims.spatial_dims[0].input_size *
@@ -856,15 +868,16 @@ void LaunchConv2DBackpropInputOp<GPUDevice, T>::operator()(
se::dnn::FilterDescriptor filter_desc;
filter_desc.set_input_filter_height(dims.spatial_dims[0].filter_size)
.set_input_filter_width(dims.spatial_dims[1].filter_size)
- .set_input_feature_map_count(dims.in_depth)
- .set_output_feature_map_count(dims.out_depth);
+ .set_input_feature_map_count(filter_shape.dim_size(2))
+ .set_output_feature_map_count(filter_shape.dim_size(3));
se::dnn::ConvolutionDescriptor conv_desc;
conv_desc.set_vertical_dilation_rate(dims.spatial_dims[0].dilation)
.set_horizontal_dilation_rate(dims.spatial_dims[1].dilation)
.set_vertical_filter_stride(dims.spatial_dims[0].stride)
.set_horizontal_filter_stride(dims.spatial_dims[1].stride)
.set_zero_padding_height(padding_rows / 2)
- .set_zero_padding_width(padding_cols / 2);
+ .set_zero_padding_width(padding_cols / 2)
+ .set_group_count(dims.in_depth / filter_shape.dim_size(2));
// NOTE(keveman):
// cuDNN only supports the following layouts :
@@ -940,21 +953,22 @@ void LaunchConv2DBackpropInputOp<GPUDevice, T>::operator()(
int device_id = stream->parent()->device_ordinal();
DataType dtype = out_backprop.dtype();
ConvParameters conv_parameters = {
- dims.batch_size, // batch
- dims.in_depth, // in_depths
- {{input_desc.height(), // in_rows
- input_desc.width()}}, // in_cols
- dims.out_depth, // out_depths
- {{dims.spatial_dims[0].filter_size, // filter_rows
- dims.spatial_dims[1].filter_size}}, // filter_cols
- {{dims.spatial_dims[0].dilation, // dilation_rows
- dims.spatial_dims[1].dilation}}, // dilation_cols
- {{dims.spatial_dims[0].stride, // stride_rows
- dims.spatial_dims[1].stride}}, // stride_cols
- {{padding_rows, // padding_rows
- padding_cols}}, // padding_cols
- dtype, // tensor data type
- device_id, // device_id
+ dims.batch_size, // batch
+ dims.in_depth, // in_depths
+ {{input_desc.height(), // in_rows
+ input_desc.width()}}, // in_cols
+ dims.out_depth, // out_depths
+ {{dims.spatial_dims[0].filter_size, // filter_rows
+ dims.spatial_dims[1].filter_size, // filter_cols
+ filter_shape.dim_size(2)}}, // filter_depths
+ {{dims.spatial_dims[0].dilation, // dilation_rows
+ dims.spatial_dims[1].dilation}}, // dilation_cols
+ {{dims.spatial_dims[0].stride, // stride_rows
+ dims.spatial_dims[1].stride}}, // stride_cols
+ {{padding_rows, // padding_rows
+ padding_cols}}, // padding_cols
+ dtype, // tensor data type
+ device_id, // device_id
};
AlgorithmConfig algorithm_config;
if (cudnn_use_autotune && !AutoTuneConvBwdData::GetInstance()->Find(
@@ -1092,9 +1106,9 @@ namespace functor {
typename TTypes<T, 4, int>::Tensor out, TensorFormat data_format); \
extern template struct PadInput<GPUDevice, T, int, 4>;
-DECLARE_GPU_SPEC(double);
DECLARE_GPU_SPEC(float);
DECLARE_GPU_SPEC(Eigen::half);
+DECLARE_GPU_SPEC(double);
#undef DECLARE_GPU_SPEC
} // namespace functor
@@ -1113,6 +1127,12 @@ REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropInput")
.TypeConstraint<Eigen::half>("T")
.HostMemory("input_sizes"),
Conv2DSlowBackpropInputOp<GPUDevice, Eigen::half>);
+
+// To be used inside depthwise_conv_grad_op.cc.
+template struct LaunchConv2DBackpropInputOp<GPUDevice, float>;
+template struct LaunchConv2DBackpropInputOp<GPUDevice, Eigen::half>;
+template struct LaunchConv2DBackpropInputOp<GPUDevice, double>;
+
#endif // GOOGLE_CUDA
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/conv_grad_ops.cc b/tensorflow/core/kernels/conv_grad_ops.cc
index 170ce31d17..5bf709af08 100644
--- a/tensorflow/core/kernels/conv_grad_ops.cc
+++ b/tensorflow/core/kernels/conv_grad_ops.cc
@@ -127,16 +127,17 @@ Status ConvBackpropComputeDimensionsV2(
dims->in_depth = input_shape.dim_size(feature_dim);
// The input and output feature dimensions are the second last and last
// dimensions of the filter Tensor.
- if (dims->in_depth != filter_shape.dim_size(num_dims - 2)) {
+ VLOG(2) << "input vs filter_in depth " << dims->in_depth << " "
+ << filter_shape.dim_size(num_dims - 2);
+ if (dims->in_depth % filter_shape.dim_size(num_dims - 2)) {
return errors::InvalidArgument(
- label, ": input and filter must have the same depth");
+ label, ": input depth must be evenly divisible by filter depth");
}
dims->out_depth = filter_shape.dim_size(num_dims - 1);
if (dims->out_depth != out_backprop_shape.dim_size(feature_dim)) {
return errors::InvalidArgument(
label, ": filter and out_backprop must have the same out_depth");
}
-
dims->spatial_dims.resize(num_spatial_dims);
for (int i = 0; i < num_spatial_dims; ++i) {
int image_dim = GetTensorSpatialDimIndex(num_dims, data_format, i);
diff --git a/tensorflow/core/kernels/conv_ops.cc b/tensorflow/core/kernels/conv_ops.cc
index c6d36b40fe..3b9886eece 100644
--- a/tensorflow/core/kernels/conv_ops.cc
+++ b/tensorflow/core/kernels/conv_ops.cc
@@ -18,10 +18,16 @@ limitations under the License.
#define USE_EIGEN_TENSOR
#define EIGEN_USE_THREADS
+#if GOOGLE_CUDA
+#define EIGEN_USE_GPU
+#endif // GOOGLE_CUDA
+
#include "tensorflow/core/kernels/conv_ops.h"
+
#include <string.h>
#include <map>
#include <vector>
+
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
@@ -32,9 +38,6 @@ limitations under the License.
#include "tensorflow/core/kernels/conv_2d.h"
#include "tensorflow/core/kernels/deep_conv2d.h"
#include "tensorflow/core/kernels/ops_util.h"
-#ifdef TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS
-#include "tensorflow/core/kernels/xsmm_conv2d.h"
-#endif
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/strings/numbers.h"
@@ -45,6 +48,10 @@ limitations under the License.
#include "tensorflow/core/util/tensor_format.h"
#include "tensorflow/core/util/use_cudnn.h"
+#ifdef TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS
+#include "tensorflow/core/kernels/xsmm_conv2d.h"
+#endif
+
#if GOOGLE_CUDA
#include "tensorflow/core/kernels/conv_ops_gpu.h"
#include "tensorflow/core/platform/stream_executor.h"
@@ -123,6 +130,10 @@ struct LaunchConv2DOp<CPUDevice, T> {
"NHWC tensor format for now."));
return;
}
+ const int64 in_depth = GetTensorDim(input, data_format, 'C');
+ OP_REQUIRES(ctx, in_depth == filter.dim_size(2),
+ errors::Unimplemented("Generic conv implementation does not "
+ "support grouped convolutions for now."));
LaunchGeneric<CPUDevice, T>()(ctx, input, filter, row_stride, col_stride,
row_dilation, col_dilation, padding, output,
data_format);
@@ -324,12 +335,13 @@ class Conv2DOp : public BinaryOp<T> {
}
// The last dimension for input is in_depth. It must be the same as the
- // filter's in_depth.
+ // filter's in_depth or be evenly divisible by filter's in_depth.
const int64 in_depth = GetTensorDim(input, data_format_, 'C');
- OP_REQUIRES(context, in_depth == filter.dim_size(2),
+ const int64 patch_depth = filter.dim_size(2);
+ OP_REQUIRES(context, in_depth % patch_depth == 0,
errors::InvalidArgument(
- "input and filter must have the same depth: ", in_depth,
- " vs ", filter.dim_size(2)));
+ "input depth must be evenly divisible by filter depth: ",
+ in_depth, " vs ", patch_depth));
// The last dimension for filter is out_depth.
const int out_depth = static_cast<int>(filter.dim_size(3));
@@ -386,6 +398,7 @@ class Conv2DOp : public BinaryOp<T> {
OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
VLOG(2) << "Conv2D: in_depth = " << in_depth
+ << ", patch_depth = " << patch_depth
<< ", input_cols = " << input_cols
<< ", filter_cols = " << filter_cols
<< ", input_rows = " << input_rows
@@ -450,7 +463,9 @@ TF_CALL_double(REGISTER_CPU);
#endif // USE_GEMM_FOR_CONV
// To be used inside depthwise_conv_op.cc.
+template struct LaunchConv2DOp<CPUDevice, Eigen::half>;
template struct LaunchConv2DOp<CPUDevice, float>;
+template struct LaunchConv2DOp<CPUDevice, double>;
#if GOOGLE_CUDA
int64 GetCudnnWorkspaceLimit(const string& envvar_in_mb,
@@ -498,13 +513,24 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
}
Tensor input = input_param;
-
- if (filter.dim_size(0) == 1 && filter.dim_size(1) == 1 && row_dilation == 1 &&
- col_dilation == 1 && row_stride == 1 && col_stride == 1 &&
- data_format == FORMAT_NHWC) {
+ const int64 in_batch = GetTensorDim(input, data_format, 'N');
+ int64 in_rows = GetTensorDim(input, data_format, 'H');
+ int64 in_cols = GetTensorDim(input, data_format, 'W');
+ const int64 in_depths = GetTensorDim(input, data_format, 'C');
+ const int64 patch_rows = filter.dim_size(0);
+ const int64 patch_cols = filter.dim_size(1);
+ const int64 patch_depths = filter.dim_size(2);
+
+ // If the filter in-depth (patch_depths) is 1 and smaller than the input
+ // depth, it's a depthwise convolution. More generally, if the filter in-depth
+ // divides but is smaller than the input depth, it is a grouped convolution.
+ bool is_grouped_convolution = patch_depths != in_depths;
+ if (patch_rows == 1 && patch_cols == 1 && !is_grouped_convolution &&
+ row_dilation == 1 && col_dilation == 1 && row_stride == 1 &&
+ col_stride == 1 && data_format == FORMAT_NHWC) {
// 1x1 filter, so call cublas directly.
- const uint64 m = input.dim_size(0) * input.dim_size(1) * input.dim_size(2);
- const uint64 k = filter.dim_size(2);
+ const uint64 m = in_batch * in_rows * in_cols;
+ const uint64 k = patch_depths;
const uint64 n = filter.dim_size(3);
auto a_ptr = AsDeviceMemory(input.template flat<T>().data(),
@@ -525,15 +551,14 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
", n=", n, ", k=", k));
}
return;
- } else if (filter.dim_size(0) == input.dim_size(1) &&
- filter.dim_size(1) == input.dim_size(2) && row_dilation == 1 &&
+ } else if (patch_rows == in_rows && patch_cols == in_cols &&
+ !is_grouped_convolution && row_dilation == 1 &&
col_dilation == 1 && padding == VALID &&
data_format == FORMAT_NHWC) {
// The input data and filter have the same height/width, so call cublas
// directly.
- const uint64 m = input.dim_size(0);
- const uint64 k =
- filter.dim_size(0) * filter.dim_size(1) * filter.dim_size(2);
+ const uint64 m = in_batch;
+ const uint64 k = patch_rows * patch_cols * patch_depths;
const uint64 n = filter.dim_size(3);
auto a_ptr = AsDeviceMemory(input.template flat<T>().data(),
@@ -558,16 +583,10 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
int padding_rows = 0;
int padding_cols = 0;
- const int64 in_batch = GetTensorDim(input, data_format, 'N');
- int64 in_rows = GetTensorDim(input, data_format, 'H');
- int64 in_cols = GetTensorDim(input, data_format, 'W');
- const int64 in_depths = GetTensorDim(input, data_format, 'C');
const int64 out_batch = GetTensorDim(*output, data_format, 'N');
const int64 out_rows = GetTensorDim(*output, data_format, 'H');
const int64 out_cols = GetTensorDim(*output, data_format, 'W');
const int64 out_depths = GetTensorDim(*output, data_format, 'C');
- const int64 patch_rows = filter.dim_size(0);
- const int64 patch_cols = filter.dim_size(1);
if (padding == SAME) {
// Total padding on rows and cols is
// Pr = (R' - 1) * S + (Kr - 1) * Dr + 1 - R
@@ -642,9 +661,9 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
.set_feature_map_count(out_depths)
.set_layout(se::dnn::DataLayout::kBatchDepthYX);
se::dnn::FilterDescriptor filter_desc;
- filter_desc.set_input_filter_height(filter.dim_size(0))
- .set_input_filter_width(filter.dim_size(1))
- .set_input_feature_map_count(filter.dim_size(2))
+ filter_desc.set_input_filter_height(patch_rows)
+ .set_input_filter_width(patch_cols)
+ .set_input_feature_map_count(patch_depths)
.set_output_feature_map_count(filter.dim_size(3));
se::dnn::ConvolutionDescriptor conv_desc;
conv_desc.set_vertical_dilation_rate(row_dilation)
@@ -652,7 +671,8 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
.set_vertical_filter_stride(row_stride)
.set_horizontal_filter_stride(col_stride)
.set_zero_padding_height(padding_rows / 2)
- .set_zero_padding_width(padding_cols / 2);
+ .set_zero_padding_width(padding_cols / 2)
+ .set_group_count(in_depths / patch_depths);
Tensor transformed_filter;
OP_REQUIRES_OK(ctx, ctx->allocate_temp(
@@ -695,7 +715,8 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
in_cols}}, // in_cols
out_depths, // out_depths
{{patch_rows, // filter_rows
- patch_cols}}, // filter_cols
+ patch_cols, // filter_cols
+ patch_depths}}, // filter_depths
{{row_dilation, // dilation_rows
col_dilation}}, // dilation_cols
{{row_stride, // stride_rows
@@ -812,9 +833,9 @@ namespace functor {
typename TTypes<T, 4, int>::Tensor out, TensorFormat data_format); \
extern template struct PadInput<GPUDevice, T, int, 4>
-DECLARE_GPU_SPEC(double);
DECLARE_GPU_SPEC(float);
DECLARE_GPU_SPEC(Eigen::half);
+DECLARE_GPU_SPEC(double);
#undef DECLARE_GPU_SPEC
} // namespace functor
@@ -830,7 +851,9 @@ REGISTER_KERNEL_BUILDER(
Conv2DOp<GPUDevice, double>);
// To be used inside depthwise_conv_op.cc.
-template class LaunchConv2DOp<GPUDevice, float>;
+template struct LaunchConv2DOp<GPUDevice, float>;
+template struct LaunchConv2DOp<GPUDevice, Eigen::half>;
+template struct LaunchConv2DOp<GPUDevice, double>;
#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD
index c78e0aff83..9ded2667eb 100644
--- a/tensorflow/core/kernels/data/BUILD
+++ b/tensorflow/core/kernels/data/BUILD
@@ -124,6 +124,20 @@ tf_kernel_library(
)
tf_kernel_library(
+ name = "group_by_reducer_dataset_op",
+ srcs = ["group_by_reducer_dataset_op.cc"],
+ deps = [
+ ":captured_function",
+ ":dataset",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:dataset_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ ],
+)
+
+tf_kernel_library(
name = "group_by_window_dataset_op",
srcs = ["group_by_window_dataset_op.cc"],
deps = [
@@ -550,6 +564,7 @@ tf_kernel_library(
":filter_dataset_op",
":flat_map_dataset_op",
":generator_dataset_op",
+ ":group_by_reducer_dataset_op",
":group_by_window_dataset_op",
":interleave_dataset_op",
":iterator_ops",
diff --git a/tensorflow/core/kernels/data/captured_function.cc b/tensorflow/core/kernels/data/captured_function.cc
index dd61b7daee..ee58341cfd 100644
--- a/tensorflow/core/kernels/data/captured_function.cc
+++ b/tensorflow/core/kernels/data/captured_function.cc
@@ -32,6 +32,20 @@ Status CapturedFunction::Create(
return Status::OK();
}
+/* static */
+Status CapturedFunction::Create(
+ const NameAttrList& func, OpKernelContext* ctx, const string& argument,
+ std::unique_ptr<CapturedFunction>* out_function) {
+ OpInputList argument_inputs;
+ TF_RETURN_IF_ERROR(ctx->input_list(argument, &argument_inputs));
+ std::vector<Tensor> arguments_t;
+ arguments_t.reserve(argument_inputs.size());
+ for (const Tensor& t : argument_inputs) {
+ arguments_t.push_back(t);
+ }
+ return CapturedFunction::Create(func, std::move(arguments_t), out_function);
+}
+
CapturedFunction::~CapturedFunction() {
if (lib_ != nullptr && f_handle_ != kInvalidHandle) {
lib_->ReleaseHandle(f_handle_).IgnoreError();
diff --git a/tensorflow/core/kernels/data/captured_function.h b/tensorflow/core/kernels/data/captured_function.h
index 490f5cd1e3..e9ad3e381d 100644
--- a/tensorflow/core/kernels/data/captured_function.h
+++ b/tensorflow/core/kernels/data/captured_function.h
@@ -40,12 +40,20 @@ class ResourceMgr;
// context.
class CapturedFunction {
public:
+ // Creates a new instance from a list of named attributes and captured inputs.
+ //
// NOTE(mrry): The `captured_inputs` are passed by value. For
// efficiency, you are recommended to move this argument into the call.
static Status Create(const NameAttrList& func,
std::vector<Tensor> captured_inputs,
std::unique_ptr<CapturedFunction>* out_function);
+ // Creates a new instance using a list of named attributes, fetching captured
+ // inputs from a context argument.
+ static Status Create(const NameAttrList& func, OpKernelContext* ctx,
+ const string& argument,
+ std::unique_ptr<CapturedFunction>* out_function);
+
~CapturedFunction();
// Runs the "Captured function" using the given FLR and caches the lib and
@@ -87,6 +95,9 @@ class CapturedFunction {
std::vector<Tensor>* rets,
FunctionLibraryRuntime::DoneCallback done);
+ // Returns the named list of function arguments.
+ const NameAttrList& func() { return func_; }
+
// Returns that additional captured inputs that will be passed to the function
// when `Run*()` is called.
const std::vector<Tensor>& captured_inputs() { return captured_inputs_; }
diff --git a/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc b/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc
new file mode 100644
index 0000000000..c8aeaab9cb
--- /dev/null
+++ b/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc
@@ -0,0 +1,422 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <map>
+
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/framework/partial_tensor_shape.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/kernels/data/captured_function.h"
+#include "tensorflow/core/kernels/data/dataset.h"
+#include "tensorflow/core/lib/random/random.h"
+
+namespace tensorflow {
+namespace {
+
+// See documentation in ../ops/dataset_ops.cc for a high-level
+// description of the following op.
+class GroupByReducerDatasetOp : public UnaryDatasetOpKernel {
+ public:
+ explicit GroupByReducerDatasetOp(OpKernelConstruction* ctx)
+ : UnaryDatasetOpKernel(ctx),
+ graph_def_version_(ctx->graph_def_version()) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("key_func", &key_func_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("init_func", &init_func_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("reduce_func", &reduce_func_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("finalize_func", &finalize_func_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+ }
+
+ void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
+ DatasetBase** output) override {
+ std::unique_ptr<CapturedFunction> captured_key_func;
+ OP_REQUIRES_OK(ctx, CapturedFunction::Create(key_func_, ctx,
+ "key_func_other_arguments",
+ &captured_key_func));
+ std::unique_ptr<CapturedFunction> captured_init_func;
+ OP_REQUIRES_OK(ctx, CapturedFunction::Create(init_func_, ctx,
+ "init_func_other_arguments",
+ &captured_init_func));
+ std::unique_ptr<CapturedFunction> captured_reduce_func;
+ OP_REQUIRES_OK(ctx, CapturedFunction::Create(reduce_func_, ctx,
+ "reduce_func_other_arguments",
+ &captured_reduce_func));
+ std::unique_ptr<CapturedFunction> captured_finalize_func;
+ OP_REQUIRES_OK(ctx,
+ CapturedFunction::Create(finalize_func_, ctx,
+ "finalize_func_other_arguments",
+ &captured_finalize_func));
+
+ *output = new Dataset(
+ ctx, input, std::move(captured_key_func), std::move(captured_init_func),
+ std::move(captured_reduce_func), std::move(captured_finalize_func),
+ output_types_, output_shapes_);
+ }
+
+ private:
+ class Dataset : public GraphDatasetBase {
+ public:
+ Dataset(OpKernelContext* ctx, const DatasetBase* input,
+ std::unique_ptr<CapturedFunction> captured_key_func,
+ std::unique_ptr<CapturedFunction> captured_init_func,
+ std::unique_ptr<CapturedFunction> captured_reduce_func,
+ std::unique_ptr<CapturedFunction> captured_finalize_func,
+ const DataTypeVector& output_types,
+ const std::vector<PartialTensorShape>& output_shapes)
+ : GraphDatasetBase(ctx),
+ input_(input),
+ captured_key_func_(std::move(captured_key_func)),
+ captured_init_func_(std::move(captured_init_func)),
+ captured_reduce_func_(std::move(captured_reduce_func)),
+ captured_finalize_func_(std::move(captured_finalize_func)),
+ output_types_(output_types),
+ output_shapes_(output_shapes) {
+ input_->Ref();
+ }
+
+ ~Dataset() override { input_->Unref(); }
+
+ std::unique_ptr<IteratorBase> MakeIterator(
+ const string& prefix) const override {
+ return std::unique_ptr<IteratorBase>(
+ new Iterator({this, strings::StrCat(prefix, "::GroupByReducer")}));
+ }
+
+ const DataTypeVector& output_dtypes() const override {
+ return output_types_;
+ }
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ return output_shapes_;
+ }
+
+ string DebugString() override { return "GroupByReducerDatasetOp::Dataset"; }
+
+ protected:
+ Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
+ Node** output) const override {
+ TF_RETURN_IF_ERROR(b->AddFunction(ctx, key_func().name()));
+ TF_RETURN_IF_ERROR(b->AddFunction(ctx, init_func().name()));
+ TF_RETURN_IF_ERROR(b->AddFunction(ctx, reduce_func().name()));
+ TF_RETURN_IF_ERROR(b->AddFunction(ctx, finalize_func().name()));
+ Node* input_graph_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_graph_node));
+
+ std::vector<Node*> key_func_other_arguments_node;
+ DataTypeVector key_func_other_arguments_types;
+ TF_RETURN_IF_ERROR(OtherArgumentsNodeAndType(
+ b, captured_key_func_, &key_func_other_arguments_node,
+ &key_func_other_arguments_types));
+
+ std::vector<Node*> init_func_other_arguments_node;
+ DataTypeVector init_func_other_arguments_types;
+ TF_RETURN_IF_ERROR(OtherArgumentsNodeAndType(
+ b, captured_init_func_, &init_func_other_arguments_node,
+ &init_func_other_arguments_types));
+
+ std::vector<Node*> reduce_func_other_arguments_node;
+ DataTypeVector reduce_func_other_arguments_types;
+ TF_RETURN_IF_ERROR(OtherArgumentsNodeAndType(
+ b, captured_reduce_func_, &reduce_func_other_arguments_node,
+ &reduce_func_other_arguments_types));
+
+ std::vector<Node*> finalize_func_other_arguments_node;
+ DataTypeVector finalize_func_other_arguments_types;
+ TF_RETURN_IF_ERROR(OtherArgumentsNodeAndType(
+ b, captured_finalize_func_, &finalize_func_other_arguments_node,
+ &finalize_func_other_arguments_types));
+
+ AttrValue key_func;
+ b->BuildAttrValue(this->key_func(), &key_func);
+ AttrValue init_func;
+ b->BuildAttrValue(this->init_func(), &init_func);
+ AttrValue reduce_func;
+ b->BuildAttrValue(this->reduce_func(), &reduce_func);
+ AttrValue finalize_func;
+ b->BuildAttrValue(this->finalize_func(), &finalize_func);
+
+ AttrValue key_func_other_arguments_types_attr;
+ b->BuildAttrValue(key_func_other_arguments_types,
+ &key_func_other_arguments_types_attr);
+ AttrValue init_func_other_arguments_types_attr;
+ b->BuildAttrValue(init_func_other_arguments_types,
+ &init_func_other_arguments_types_attr);
+ AttrValue reduce_func_other_arguments_types_attr;
+ b->BuildAttrValue(reduce_func_other_arguments_types,
+ &reduce_func_other_arguments_types_attr);
+ AttrValue finalize_func_other_arguments_types_attr;
+ b->BuildAttrValue(finalize_func_other_arguments_types,
+ &finalize_func_other_arguments_types_attr);
+
+ TF_RETURN_IF_ERROR(b->AddDataset(
+ this, {{0, input_graph_node}},
+ {{1, key_func_other_arguments_node},
+ {2, init_func_other_arguments_node},
+ {3, reduce_func_other_arguments_node},
+ {4, finalize_func_other_arguments_node}},
+ {{"key_func", key_func},
+ {"init_func", init_func},
+ {"reduce_func", reduce_func},
+ {"finalize_func", finalize_func},
+ {"Tkey_func_other_arguments", key_func_other_arguments_types_attr},
+ {"Tinit_func_other_arguments", init_func_other_arguments_types_attr},
+ {"Treduce_func_other_arguments",
+ reduce_func_other_arguments_types_attr},
+ {"Tfinalize_func_other_arguments",
+ finalize_func_other_arguments_types_attr}},
+ output));
+ return Status::OK();
+ }
+
+ private:
+ class Iterator : public DatasetIterator<Dataset> {
+ public:
+ explicit Iterator(const Params& params)
+ : DatasetIterator<Dataset>(params),
+ input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {}
+
+ Status GetNextInternal(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override {
+ mutex_lock l(mu_);
+
+ // Iterate through the input dataset, keying input elements to reducers.
+ while (!end_of_input_) {
+ std::vector<Tensor> next_input_element;
+ TF_RETURN_IF_ERROR(
+ input_impl_->GetNext(ctx, &next_input_element, &end_of_input_));
+
+ if (!end_of_input_) {
+ // Run the key function on the input element.
+ std::vector<Tensor> key_func_output;
+ TF_RETURN_IF_ERROR(
+ dataset()->captured_key_func_->RunWithBorrowedArgs(
+ ctx, next_input_element, &key_func_output));
+
+ if (key_func_output.size() != 1 ||
+ key_func_output[0].dtype() != DT_INT64 ||
+ key_func_output[0].NumElements() != 1) {
+ // TODO(b/78665031): Support non-int64 keys.
+ return errors::InvalidArgument(
+ "`key_func` must return a scalar int64.");
+ }
+ const int64 key = key_func_output[0].scalar<int64>()();
+
+ if (states_.find(key) == states_.end()) {
+ // Run the init function to create the initial state.
+ std::vector<Tensor> init_func_output;
+ TF_RETURN_IF_ERROR(dataset()->captured_init_func_->Run(
+ ctx, std::move(key_func_output), &init_func_output));
+ states_[key] = init_func_output;
+ }
+
+ // Run the reduce function to update the current state.
+ std::vector<Tensor> args;
+ args.reserve(states_[key].size() + next_input_element.size());
+ std::copy(states_[key].begin(), states_[key].end(),
+ std::back_inserter(args));
+ std::copy(next_input_element.begin(), next_input_element.end(),
+ std::back_inserter(args));
+
+ std::vector<Tensor> reduce_func_output;
+ TF_RETURN_IF_ERROR(dataset()->captured_reduce_func_->Run(
+ ctx, std::move(args), &reduce_func_output));
+ states_[key] = reduce_func_output;
+ } else {
+ keys_.resize(states_.size());
+ int idx = 0;
+ for (auto it = states_.begin(); it != states_.end(); ++idx, ++it) {
+ keys_[idx] = it->first;
+ }
+ }
+ }
+
+ if (keys_index_ == keys_.size()) {
+ *end_of_sequence = true;
+ return Status::OK();
+ }
+ TF_RETURN_IF_ERROR(
+ dataset()->captured_finalize_func_->RunWithBorrowedArgs(
+ ctx, states_[keys_[keys_index_++]], out_tensors));
+ return Status::OK();
+ }
+
+ protected:
+ Status SaveInternal(IteratorStateWriter* writer) override {
+ mutex_lock l(mu_);
+ TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
+
+ if (end_of_input_) {
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(full_name("end_of_input"), ""));
+ }
+
+ // Saving states_.
+ if (!states_.empty()) {
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(full_name("states_size"), states_.size()));
+ int idx = 0;
+ for (auto it = states_.begin(); it != states_.end(); ++idx, ++it) {
+ int64 key = it->first;
+ TF_RETURN_IF_ERROR(writer->WriteScalar(
+ full_name(strings::StrCat("states[", idx, "]->key")), key));
+ if (!it->second.empty()) {
+ TF_RETURN_IF_ERROR(writer->WriteScalar(
+ full_name(strings::StrCat("states[", idx, "]->state_size")),
+ it->second.size()));
+ for (int j = 0; j < it->second.size(); ++j) {
+ TF_RETURN_IF_ERROR(writer->WriteTensor(
+ full_name(
+ strings::StrCat("states[", idx, "]->state[", j, "]")),
+ it->second[j]));
+ }
+ }
+ }
+ }
+
+ // Saving keys_index_ and keys_.
+ if (end_of_input_) {
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(full_name("keys_index"), keys_index_));
+ if (!keys_.empty()) {
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(full_name("keys_size"), keys_.size()));
+ for (int idx = 0; idx < keys_.size(); ++idx) {
+ TF_RETURN_IF_ERROR(writer->WriteScalar(
+ full_name(strings::StrCat("keys[", idx, "]")), keys_[idx]));
+ }
+ }
+ }
+
+ return Status::OK();
+ }
+
+ Status RestoreInternal(IteratorContext* ctx,
+ IteratorStateReader* reader) override {
+ mutex_lock l(mu_);
+ TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
+
+ if (reader->Contains(full_name("end_of_input"))) end_of_input_ = true;
+
+ // Restoring states_.
+ if (reader->Contains(full_name("states_size"))) {
+ int64 size;
+ TF_RETURN_IF_ERROR(
+ reader->ReadScalar(full_name("states_size"), &size));
+ for (int idx = 0; idx < size; ++idx) {
+ int64 key;
+ TF_RETURN_IF_ERROR(reader->ReadScalar(
+ full_name(strings::StrCat("states[", idx, "]->key")), &key));
+ std::vector<Tensor> state;
+ if (reader->Contains(full_name(
+ strings::StrCat("states[", idx, "]->state_size")))) {
+ int64 state_size;
+ TF_RETURN_IF_ERROR(reader->ReadScalar(
+ full_name(strings::StrCat("states[", idx, "]->state_size")),
+ &state_size));
+ state.resize(state_size);
+ for (int j = 0; j < state_size; ++j) {
+ TF_RETURN_IF_ERROR(reader->ReadTensor(
+ full_name(
+ strings::StrCat("states[", idx, "]->state[", j, "]")),
+ &state[j]));
+ }
+ }
+ states_[key] = state;
+ }
+ }
+
+ // Restoring keys_index_ and keys_.
+ if (end_of_input_) {
+ TF_RETURN_IF_ERROR(
+ reader->ReadScalar(full_name("keys_index"), &keys_index_));
+ if (reader->Contains(full_name("keys_size"))) {
+ int64 size;
+ TF_RETURN_IF_ERROR(
+ reader->ReadScalar(full_name("keys_size"), &size));
+ keys_.resize(size);
+ for (int idx = 0; idx < size; ++idx) {
+ int64 key;
+ TF_RETURN_IF_ERROR(reader->ReadScalar(
+ full_name(strings::StrCat("keys[", idx, "]")), &key));
+ keys_[idx] = key;
+ }
+ }
+ }
+
+ return Status::OK();
+ }
+
+ private:
+ mutex mu_;
+ std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
+ bool end_of_input_ GUARDED_BY(mu_) = false;
+ std::map<int64, std::vector<Tensor>> states_ GUARDED_BY(mu_);
+ std::vector<int64> keys_ GUARDED_BY(mu_);
+ int64 keys_index_ GUARDED_BY(mu_) = 0;
+ };
+
+ const NameAttrList& key_func() const { return captured_key_func_->func(); }
+
+ const NameAttrList& init_func() const {
+ return captured_init_func_->func();
+ }
+
+ const NameAttrList& reduce_func() const {
+ return captured_reduce_func_->func();
+ }
+
+ const NameAttrList& finalize_func() const {
+ return captured_finalize_func_->func();
+ }
+
+ Status OtherArgumentsNodeAndType(
+ DatasetGraphDefBuilder* b,
+ const std::unique_ptr<CapturedFunction>& captured_func,
+ std::vector<Node*>* other_arguments_node,
+ DataTypeVector* other_arguments_types) const {
+ other_arguments_node->reserve(captured_func->captured_inputs().size());
+ other_arguments_types->reserve(captured_func->captured_inputs().size());
+ for (const Tensor& t : captured_func->captured_inputs()) {
+ Node* node;
+ TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
+ other_arguments_node->emplace_back(node);
+ other_arguments_types->emplace_back(t.dtype());
+ }
+ return Status::OK();
+ }
+
+ const DatasetBase* const input_;
+ const std::unique_ptr<CapturedFunction> captured_key_func_;
+ const std::unique_ptr<CapturedFunction> captured_init_func_;
+ const std::unique_ptr<CapturedFunction> captured_reduce_func_;
+ const std::unique_ptr<CapturedFunction> captured_finalize_func_;
+ const DataTypeVector output_types_;
+ const std::vector<PartialTensorShape> output_shapes_;
+ };
+
+ const int graph_def_version_;
+ DataTypeVector output_types_;
+ std::vector<PartialTensorShape> output_shapes_;
+ NameAttrList key_func_;
+ NameAttrList init_func_;
+ NameAttrList reduce_func_;
+ NameAttrList finalize_func_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("GroupByReducerDataset").Device(DEVICE_CPU),
+ GroupByReducerDatasetOp);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/group_by_window_dataset_op.cc b/tensorflow/core/kernels/data/group_by_window_dataset_op.cc
index 46f43dd1b1..03f847ce9c 100644
--- a/tensorflow/core/kernels/data/group_by_window_dataset_op.cc
+++ b/tensorflow/core/kernels/data/group_by_window_dataset_op.cc
@@ -241,7 +241,7 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
if (key_func_output.size() != 1 ||
key_func_output[0].dtype() != DT_INT64 ||
key_func_output[0].NumElements() != 1) {
- // TODO(mrry): Support non-int64 keys.
+ // TODO(b/78665031): Support non-int64 keys.
return errors::InvalidArgument(
"`key_func` must return a scalar int64.");
}
diff --git a/tensorflow/core/kernels/depthwise_conv_grad_op.cc b/tensorflow/core/kernels/depthwise_conv_grad_op.cc
index 91a9587174..7afa21acb9 100644
--- a/tensorflow/core/kernels/depthwise_conv_grad_op.cc
+++ b/tensorflow/core/kernels/depthwise_conv_grad_op.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/bounds_check.h"
+#include "tensorflow/core/kernels/conv_grad_ops.h"
#include "tensorflow/core/kernels/depthwise_conv_op.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/status.h"
@@ -33,9 +34,11 @@ limitations under the License.
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
+#include "tensorflow/core/util/use_cudnn.h"
#include "tensorflow/core/util/work_sharder.h"
#if GOOGLE_CUDA
+#include "cuda/include/cudnn.h"
#include "tensorflow/core/platform/stream_executor.h"
#endif // GOOGLE_CUDA
@@ -509,8 +512,19 @@ static void DepthwiseConvBackpropInputReference(const DepthwiseArgs& args,
}
}
+// Extern template instantiated in conv_grad_input_ops.cc.
+extern template struct LaunchConv2DBackpropInputOp<CPUDevice, Eigen::half>;
+extern template struct LaunchConv2DBackpropInputOp<CPUDevice, float>;
+extern template struct LaunchConv2DBackpropInputOp<CPUDevice, double>;
+
#if GOOGLE_CUDA
+// Extern template instantiated in conv_grad_input_ops.cc.
+extern template struct LaunchConv2DBackpropInputOp<GPUDevice, Eigen::half>;
+extern template struct LaunchConv2DBackpropInputOp<GPUDevice, float>;
+extern template struct LaunchConv2DBackpropInputOp<GPUDevice, double>;
+
+// Extern template instantiated in depthwise_conv_op_gpu.cu.cc.
extern template struct LaunchDepthwiseConvBackpropInputOp<GPUDevice,
Eigen::half>;
extern template struct LaunchDepthwiseConvBackpropInputOp<GPUDevice, float>;
@@ -548,6 +562,12 @@ class DepthwiseConv2dNativeBackpropInputOp : public OpKernel {
errors::InvalidArgument("Current implementation does not yet support "
"strides in the batch and depth dimensions."));
OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+
+ // For in_depth == 1 and grouped convolutions.
+ use_cudnn_ = CanUseCudnn();
+ cudnn_use_autotune_ = CudnnUseAutotune();
+ use_cudnn_grouped_conv_ = false;
+ dtype_ = DataTypeToEnum<T>::value;
}
void Compute(OpKernelContext* context) override {
@@ -560,6 +580,7 @@ class DepthwiseConv2dNativeBackpropInputOp : public OpKernel {
input_sizes.dims()));
TensorShape input_shape;
const int32* in_sizes_data = input_sizes.template flat<int32>().data();
+
for (int i = 0; i < input_sizes.NumElements(); ++i) {
OP_REQUIRES(context, in_sizes_data[i] >= 0,
errors::InvalidArgument("Dimension ", i,
@@ -568,27 +589,77 @@ class DepthwiseConv2dNativeBackpropInputOp : public OpKernel {
}
const TensorShape& filter_shape = filter.shape();
EXTRACT_AND_VERIFY_DIMENSIONS("DepthwiseConv2DBackpropInput");
+
Tensor* in_backprop = nullptr;
OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
{0}, 0, input_shape, &in_backprop));
- auto out_backprop_ptr = out_backprop.template flat<T>().data();
- auto filter_ptr = filter.template flat<T>().data();
- auto in_backprop_ptr = in_backprop->template flat<T>().data();
+
// If there is nothing to compute, return.
if (input_shape.num_elements() == 0) {
return;
}
+
+ // If in_depth==1, this operation is just a standard convolution.
+ // Depthwise convolution is a special case of cuDNN's grouped convolution.
+ bool use_cudnn = use_cudnn_ && (in_depth == 1 || use_cudnn_grouped_conv_);
+
+ VLOG(2) << "DepthwiseConv2dNativeBackpropInput: "
+ << " Input: [" << batch << ", " << input_rows << ", " << input_cols
+ << ", " << in_depth << "]; Filter: [" << filter_rows << ", "
+ << filter_cols << ", " << in_depth << ", " << depth_multiplier
+ << "]; Output: [" << batch << ", " << out_rows << ", " << out_cols
+ << ", " << out_depth << "], stride = " << stride_
+ << ", pad_rows = " << pad_rows << ", pad_cols = " << pad_cols
+ << ", Use cuDNN: " << use_cudnn;
+
+ if (use_cudnn) {
+ // Reshape from TF depthwise filter to cuDNN grouped convolution filter:
+ //
+ // | TensorFlow | cuDNN
+ // --------------------------------------------------------------------
+ // filter_out_depth | depth_multiplier | depth_multiplier * group_count
+ // filter_in_depth | in_depth | in_depth / group_count
+ //
+ // For depthwise convolution, we have group_count == in_depth.
+ int32 filter_in_depth = 1;
+ TensorShape shape =
+ TensorShape{filter_rows, filter_cols, filter_in_depth, out_depth};
+ Tensor reshaped_filter(/*type=*/dtype_);
+ OP_REQUIRES(
+ context, reshaped_filter.CopyFrom(filter, shape),
+ errors::Internal(
+ "Failed to reshape filter tensor for grouped convolution."));
+ // TODO(yangzihao): Send in arbitrary dilation rates after the dilated
+ // conv is supported.
+ launcher_(context, use_cudnn_, cudnn_use_autotune_, out_backprop,
+ reshaped_filter, /*row_dilation=*/1, /*col_dilation=*/1,
+ stride_, stride_, padding_, in_backprop, data_format_);
+ return;
+ }
+
+ auto out_backprop_ptr = out_backprop.template flat<T>().data();
+ auto filter_ptr = filter.template flat<T>().data();
+ auto in_backprop_ptr = in_backprop->template flat<T>().data();
LaunchDepthwiseConvBackpropInputOp<Device, T>()(
context, args, out_backprop_ptr, filter_ptr, in_backprop_ptr,
data_format_);
}
+ protected:
+ bool use_cudnn_grouped_conv_;
+
private:
std::vector<int32> strides_;
Padding padding_;
TensorFormat data_format_;
int64 stride_;
+ // For in_depth == 1 and grouped convolutions.
+ LaunchConv2DBackpropInputOp<Device, T> launcher_;
+ bool use_cudnn_;
+ bool cudnn_use_autotune_;
+ DataType dtype_;
+
TF_DISALLOW_COPY_AND_ASSIGN(DepthwiseConv2dNativeBackpropInputOp);
};
@@ -597,23 +668,52 @@ class DepthwiseConv2dNativeBackpropInputOp : public OpKernel {
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T"), \
DepthwiseConv2dNativeBackpropInputOp<CPUDevice, T>);
+
+TF_CALL_half(REGISTER_CPU_KERNEL);
TF_CALL_float(REGISTER_CPU_KERNEL);
+#if !defined(PLATFORM_WINDOWS) || !defined(_DEBUG)
TF_CALL_double(REGISTER_CPU_KERNEL);
+#endif
#undef REGISTER_CPU_KERNEL
#if GOOGLE_CUDA
-REGISTER_KERNEL_BUILDER(Name("DepthwiseConv2dNativeBackpropInput")
- .Device(DEVICE_GPU)
- .TypeConstraint<float>("T")
- .HostMemory("input_sizes"),
- DepthwiseConv2dNativeBackpropInputOp<GPUDevice, float>);
-
-REGISTER_KERNEL_BUILDER(
- Name("DepthwiseConv2dNativeBackpropInput")
- .Device(DEVICE_GPU)
- .TypeConstraint<double>("T")
- .HostMemory("input_sizes"),
- DepthwiseConv2dNativeBackpropInputOp<GPUDevice, double>);
+
+#define REGISTER_GPU_KERNEL(T) \
+ REGISTER_KERNEL_BUILDER(Name("DepthwiseConv2dNativeBackpropInput") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<T>("T") \
+ .HostMemory("input_sizes"), \
+ DepthwiseConv2dNativeBackpropInputOp<GPUDevice, T>)
+
+TF_CALL_half(REGISTER_GPU_KERNEL);
+TF_CALL_float(REGISTER_GPU_KERNEL);
+TF_CALL_double(REGISTER_GPU_KERNEL);
+#undef REGISTER_GPU_KERNEL
+
+#if CUDNN_VERSION >= 7000
+template <typename T>
+class DepthwiseConv2dGroupedConvBackpropInputOp
+ : public DepthwiseConv2dNativeBackpropInputOp<GPUDevice, T> {
+ public:
+ DepthwiseConv2dGroupedConvBackpropInputOp(OpKernelConstruction* context)
+ : DepthwiseConv2dNativeBackpropInputOp<GPUDevice, T>(context) {
+ this->use_cudnn_grouped_conv_ = true;
+ }
+};
+
+#define REGISTER_GROUPED_CONV_KERNEL(T) \
+ REGISTER_KERNEL_BUILDER(Name("DepthwiseConv2dNativeBackpropInput") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<T>("T") \
+ .HostMemory("input_sizes") \
+ .Label("cudnn_grouped_convolution"), \
+ DepthwiseConv2dGroupedConvBackpropInputOp<T>)
+
+TF_CALL_half(REGISTER_GROUPED_CONV_KERNEL);
+TF_CALL_float(REGISTER_GROUPED_CONV_KERNEL);
+TF_CALL_double(REGISTER_GROUPED_CONV_KERNEL);
+#undef REGISTER_GROUPED_CONV_KERNEL
+#endif // CUDNN_VERSION
#endif // GOOGLE_CUDA
// Kernels to compute the gradients of the filters for depthwise convolution.
@@ -885,8 +985,19 @@ static void DepthwiseConvBackpropFilterReference(const DepthwiseArgs& args,
}
}
+// Extern template instantiated in conv_grad_filter_ops.cc.
+extern template struct LaunchConv2DBackpropFilterOp<CPUDevice, Eigen::half>;
+extern template struct LaunchConv2DBackpropFilterOp<CPUDevice, float>;
+extern template struct LaunchConv2DBackpropFilterOp<CPUDevice, double>;
+
#if GOOGLE_CUDA
+// Extern template instantiated in conv_grad_filter_ops.cc.
+extern template struct LaunchConv2DBackpropFilterOp<GPUDevice, Eigen::half>;
+extern template struct LaunchConv2DBackpropFilterOp<GPUDevice, float>;
+extern template struct LaunchConv2DBackpropFilterOp<GPUDevice, double>;
+
+// Extern template instantiated in depthwise_conv_op_gpu.cu.cc.
extern template struct LaunchDepthwiseConvBackpropFilterOp<GPUDevice,
Eigen::half>;
extern template struct LaunchDepthwiseConvBackpropFilterOp<GPUDevice, float>;
@@ -924,6 +1035,21 @@ class DepthwiseConv2dNativeBackpropFilterOp : public OpKernel {
errors::InvalidArgument("Current implementation does not yet support "
"strides in the batch and depth dimensions."));
OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+
+ // For in_depth == 1 and grouped convolutions.
+ use_cudnn_ = CanUseCudnn();
+ cudnn_use_autotune_ = CudnnUseAutotune();
+ use_cudnn_grouped_conv_ = false;
+
+ if (std::is_same<T, Eigen::half>::value) {
+ dtype_ = DT_HALF;
+ } else if (std::is_same<T, float>::value) {
+ dtype_ = DT_FLOAT;
+ } else if (std::is_same<T, double>::value) {
+ dtype_ = DT_DOUBLE;
+ } else {
+ LOG(ERROR) << "Only half, float, and double are supported.";
+ }
}
void Compute(OpKernelContext* context) override {
@@ -949,24 +1075,73 @@ class DepthwiseConv2dNativeBackpropFilterOp : public OpKernel {
OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
{1}, 0, filter_shape, &filter_backprop));
- auto out_backprop_ptr = out_backprop.template flat<T>().data();
- auto input_ptr = input.template flat<T>().data();
- auto filter_backprop_ptr = filter_backprop->template flat<T>().data();
// If there is nothing to compute, return.
if (filter_shape.num_elements() == 0) {
return;
}
+
+ // If in_depth==1, this operation is just a standard convolution.
+ // Depthwise convolution is a special case of cuDNN's grouped convolution.
+ bool use_cudnn = use_cudnn_ && (in_depth == 1 || use_cudnn_grouped_conv_);
+
+ VLOG(2) << "DepthwiseConv2dNativeBackpropFilter: "
+ << " Input: [" << batch << ", " << input_rows << ", " << input_cols
+ << ", " << in_depth << "]; Filter: [" << filter_rows << ", "
+ << filter_cols << ", " << in_depth << ", " << depth_multiplier
+ << "]; Output: [" << batch << ", " << out_rows << ", " << out_cols
+ << ", " << out_depth << "], stride = " << stride_
+ << ", pad_rows = " << pad_rows << ", pad_cols = " << pad_cols
+ << ", Use cuDNN: " << use_cudnn;
+
+ if (use_cudnn) {
+ // Reshape from TF depthwise filter to cuDNN grouped convolution filter:
+ //
+ // | TensorFlow | cuDNN
+ // --------------------------------------------------------------------
+ // filter_out_depth | depth_multiplier | depth_multiplier * group_count
+ // filter_in_depth | in_depth | in_depth / group_count
+ //
+ // For depthwise convolution, we have group_count == in_depth.
+ int32 filter_in_depth = 1;
+ TensorShape shape =
+ TensorShape{filter_rows, filter_cols, filter_in_depth, out_depth};
+ Tensor reshaped_filter(/*type=*/dtype_);
+ OP_REQUIRES(
+ context, reshaped_filter.CopyFrom(*filter_backprop, shape),
+ errors::Internal(
+ "Failed to reshape filter tensor for grouped convolution."));
+
+ // TODO(yangzihao): Send in arbitrary dilation rates after the dilated
+ // conv is supported.
+ launcher_(context, use_cudnn_, cudnn_use_autotune_, out_backprop, input,
+ /*row_dilation=*/1, /*col_dilation=*/1, stride_, stride_,
+ padding_, &reshaped_filter, data_format_);
+ return;
+ }
+
+ auto out_backprop_ptr = out_backprop.template flat<T>().data();
+ auto input_ptr = input.template flat<T>().data();
+ auto filter_backprop_ptr = filter_backprop->template flat<T>().data();
LaunchDepthwiseConvBackpropFilterOp<Device, T>()(
context, args, out_backprop_ptr, input_ptr, filter_backprop_ptr,
data_format_);
}
+ protected:
+ bool use_cudnn_grouped_conv_;
+
private:
std::vector<int32> strides_;
Padding padding_;
TensorFormat data_format_;
int64 stride_;
+ // For in_depth == 1 and grouped convolutions.
+ LaunchConv2DBackpropFilterOp<Device, T> launcher_;
+ bool use_cudnn_;
+ bool cudnn_use_autotune_;
+ DataType dtype_;
+
TF_DISALLOW_COPY_AND_ASSIGN(DepthwiseConv2dNativeBackpropFilterOp);
};
@@ -976,24 +1151,50 @@ class DepthwiseConv2dNativeBackpropFilterOp : public OpKernel {
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T"), \
DepthwiseConv2dNativeBackpropFilterOp<CPUDevice, T>);
+TF_CALL_half(REGISTER_CPU_KERNEL);
TF_CALL_float(REGISTER_CPU_KERNEL);
+#if !defined(PLATFORM_WINDOWS) || !defined(_DEBUG)
TF_CALL_double(REGISTER_CPU_KERNEL);
+#endif
#undef REGISTER_CPU_KERNEL
#if GOOGLE_CUDA
-REGISTER_KERNEL_BUILDER(
- Name("DepthwiseConv2dNativeBackpropFilter")
- .Device(DEVICE_GPU)
- .TypeConstraint<float>("T")
- .HostMemory("filter_sizes"),
- DepthwiseConv2dNativeBackpropFilterOp<GPUDevice, float>);
-
-REGISTER_KERNEL_BUILDER(
- Name("DepthwiseConv2dNativeBackpropFilter")
- .Device(DEVICE_GPU)
- .TypeConstraint<double>("T")
- .HostMemory("filter_sizes"),
- DepthwiseConv2dNativeBackpropFilterOp<GPUDevice, double>);
+#define REGISTER_GPU_KERNEL(T) \
+ REGISTER_KERNEL_BUILDER(Name("DepthwiseConv2dNativeBackpropFilter") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<T>("T") \
+ .HostMemory("filter_sizes"), \
+ DepthwiseConv2dNativeBackpropFilterOp<GPUDevice, T>)
+
+TF_CALL_half(REGISTER_GPU_KERNEL);
+TF_CALL_float(REGISTER_GPU_KERNEL);
+TF_CALL_double(REGISTER_GPU_KERNEL);
+#undef REGISTER_GPU_KERNEL
+
+#if CUDNN_VERSION >= 7000
+template <typename T>
+class DepthwiseConv2dGroupedConvBackpropFilterOp
+ : public DepthwiseConv2dNativeBackpropFilterOp<GPUDevice, T> {
+ public:
+ DepthwiseConv2dGroupedConvBackpropFilterOp(OpKernelConstruction* context)
+ : DepthwiseConv2dNativeBackpropFilterOp<GPUDevice, T>(context) {
+ this->use_cudnn_grouped_conv_ = true;
+ }
+};
+
+#define REGISTER_GROUPED_CONV_KERNEL(T) \
+ REGISTER_KERNEL_BUILDER(Name("DepthwiseConv2dNativeBackpropFilter") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<T>("T") \
+ .HostMemory("filter_sizes") \
+ .Label("cudnn_grouped_convolution"), \
+ DepthwiseConv2dGroupedConvBackpropFilterOp<T>)
+
+TF_CALL_half(REGISTER_GROUPED_CONV_KERNEL);
+TF_CALL_float(REGISTER_GROUPED_CONV_KERNEL);
+TF_CALL_double(REGISTER_GROUPED_CONV_KERNEL);
+#undef REGISTER_GROUPED_CONV_KERNEL
+#endif // CUDNN_VERSION
#endif // GOOGLE_CUDA
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/depthwise_conv_op.cc b/tensorflow/core/kernels/depthwise_conv_op.cc
index 6dedb1a61e..d5f4a68120 100644
--- a/tensorflow/core/kernels/depthwise_conv_op.cc
+++ b/tensorflow/core/kernels/depthwise_conv_op.cc
@@ -39,6 +39,7 @@ limitations under the License.
#include "tensorflow/core/util/work_sharder.h"
#if GOOGLE_CUDA
+#include "cuda/include/cudnn.h"
#include "tensorflow/core/platform/stream_executor.h"
#endif // GOOGLE_CUDA
@@ -241,18 +242,22 @@ struct LaunchDepthwiseConvOp<CPUDevice, T> {
};
// Extern template instantiated in conv_ops.cc.
+extern template struct LaunchConv2DOp<CPUDevice, Eigen::half>;
extern template struct LaunchConv2DOp<CPUDevice, float>;
+extern template struct LaunchConv2DOp<CPUDevice, double>;
#if GOOGLE_CUDA
+// Extern template instantiated in conv_ops.cc.
+extern template struct LaunchConv2DOp<GPUDevice, Eigen::half>;
+extern template struct LaunchConv2DOp<GPUDevice, float>;
+extern template struct LaunchConv2DOp<GPUDevice, double>;
+
// Extern template instantiated in depthwise_conv_op_gpu.cc.
extern template struct LaunchDepthwiseConvOp<GPUDevice, Eigen::half>;
extern template struct LaunchDepthwiseConvOp<GPUDevice, float>;
extern template struct LaunchDepthwiseConvOp<GPUDevice, double>;
-// Extern template instantiated in conv_ops.cc.
-extern template struct LaunchConv2DOp<GPUDevice, float>;
-
#endif
template <typename Device, typename T>
@@ -284,9 +289,11 @@ class DepthwiseConv2dNativeOp : public BinaryOp<T> {
"strides in the batch and depth dimensions."));
OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
- // For special case when in_depth == 1.
+ // For in_depth == 1 and grouped convolutions.
use_cudnn_ = CanUseCudnn();
cudnn_use_autotune_ = CudnnUseAutotune();
+ use_cudnn_grouped_conv_ = false;
+ dtype_ = DataTypeToEnum<T>::value;
}
void Compute(OpKernelContext* context) override {
@@ -357,27 +364,47 @@ class DepthwiseConv2dNativeOp : public BinaryOp<T> {
Tensor* output = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
- VLOG(2) << "DepthwiseConv2dNative: "
- << " Input: [" << batch << ", " << input_rows << ", " << input_cols
- << ", " << in_depth << "]; Filter: [" << filter_rows << ", "
- << filter_cols << ", " << in_depth << ", " << depth_multiplier
- << "]; stride = " << stride_ << ", pad_rows = " << pad_rows
- << ", pad_cols = " << pad_cols << ", output: [" << batch << ", "
- << out_rows << ", " << out_cols << ", " << out_depth << "]";
-
// If there is nothing to compute, return.
if (out_shape.num_elements() == 0) {
return;
}
- // If in_depth==1, this operation is just a standard convolution, so
- // invoke that op.
- if (std::is_same<T, float>::value && in_depth == 1) {
+ // TODO(csigg): Have autotune decide if native is faster than cuDNN.
+ // If in_depth==1, this operation is just a standard convolution.
+ // Depthwise convolution is a special case of cuDNN's grouped convolution.
+ bool use_cudnn = use_cudnn_ && (in_depth == 1 || use_cudnn_grouped_conv_);
+
+ VLOG(2) << "DepthwiseConv2dNative: "
+ << " Input: [" << batch << ", " << input_rows << ", " << input_cols
+ << ", " << in_depth << "]; Filter: [" << filter_rows << ", "
+ << filter_cols << ", " << in_depth << ", " << depth_multiplier
+ << "]; Output: [" << batch << ", " << out_rows << ", " << out_cols
+ << ", " << out_depth << "], stride = " << stride_
+ << ", pad_rows = " << pad_rows << ", pad_cols = " << pad_cols
+ << ", Use cuDNN: " << use_cudnn;
+
+ if (use_cudnn) {
+ // Reshape from TF depthwise filter to cuDNN grouped convolution filter:
+ //
+ // | TensorFlow | cuDNN
+ // --------------------------------------------------------------------
+ // filter_out_depth | depth_multiplier | depth_multiplier * group_count
+ // filter_in_depth | in_depth | in_depth / group_count
+ //
+ // For depthwise convolution, we have group_count == in_depth.
+ int32 filter_in_depth = 1;
+ TensorShape shape =
+ TensorShape{filter_rows, filter_cols, filter_in_depth, out_depth};
+ Tensor reshaped_filter(/*type=*/dtype_);
+ OP_REQUIRES(
+ context, reshaped_filter.CopyFrom(filter, shape),
+ errors::Internal(
+ "Failed to reshape filter tensor for grouped convolution."));
// TODO(yangzihao): Send in arbitrary dilation rates after the dilated
// conv is supported.
- launcher_(context, use_cudnn_, cudnn_use_autotune_, input, filter,
- /*row_dilation=*/1, /*col_dilation=*/1, stride_, stride_,
- padding_, output, data_format_);
+ launcher_(context, use_cudnn_, cudnn_use_autotune_, input,
+ reshaped_filter, /*row_dilation=*/1, /*col_dilation=*/1,
+ stride_, stride_, padding_, output, data_format_);
return;
}
@@ -403,6 +430,9 @@ class DepthwiseConv2dNativeOp : public BinaryOp<T> {
output_ptr, data_format_);
}
+ protected:
+ bool use_cudnn_grouped_conv_;
+
private:
std::vector<int32> strides_;
Padding padding_;
@@ -410,10 +440,11 @@ class DepthwiseConv2dNativeOp : public BinaryOp<T> {
int64 stride_; // in height/width dimension.
- // For the case in_depth == 1.
+ // For in_depth == 1 and grouped convolutions.
LaunchConv2DOp<Device, T> launcher_;
bool use_cudnn_;
bool cudnn_use_autotune_;
+ DataType dtype_;
TF_DISALLOW_COPY_AND_ASSIGN(DepthwiseConv2dNativeOp);
};
@@ -421,7 +452,7 @@ class DepthwiseConv2dNativeOp : public BinaryOp<T> {
#define REGISTER_CPU_KERNEL(T) \
REGISTER_KERNEL_BUILDER( \
Name("DepthwiseConv2dNative").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
- DepthwiseConv2dNativeOp<CPUDevice, T>);
+ DepthwiseConv2dNativeOp<CPUDevice, T>)
TF_CALL_half(REGISTER_CPU_KERNEL);
TF_CALL_float(REGISTER_CPU_KERNEL);
@@ -430,19 +461,38 @@ TF_CALL_double(REGISTER_CPU_KERNEL);
#endif
#if GOOGLE_CUDA
-REGISTER_KERNEL_BUILDER(Name("DepthwiseConv2dNative")
- .Device(DEVICE_GPU)
- .TypeConstraint<Eigen::half>("T"),
- DepthwiseConv2dNativeOp<GPUDevice, Eigen::half>);
-
-REGISTER_KERNEL_BUILDER(
- Name("DepthwiseConv2dNative").Device(DEVICE_GPU).TypeConstraint<float>("T"),
- DepthwiseConv2dNativeOp<GPUDevice, float>);
-
-REGISTER_KERNEL_BUILDER(Name("DepthwiseConv2dNative")
- .Device(DEVICE_GPU)
- .TypeConstraint<double>("T"),
- DepthwiseConv2dNativeOp<GPUDevice, double>);
-#endif
+
+#define REGISTER_GPU_KERNEL(T) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("DepthwiseConv2dNative").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
+ DepthwiseConv2dNativeOp<GPUDevice, T>)
+
+TF_CALL_half(REGISTER_GPU_KERNEL);
+TF_CALL_float(REGISTER_GPU_KERNEL);
+TF_CALL_double(REGISTER_GPU_KERNEL);
+
+#if CUDNN_VERSION >= 7000
+template <typename T>
+class DepthwiseConv2dGroupedConvOp
+ : public DepthwiseConv2dNativeOp<GPUDevice, T> {
+ public:
+ DepthwiseConv2dGroupedConvOp(OpKernelConstruction* context)
+ : DepthwiseConv2dNativeOp<GPUDevice, T>(context) {
+ this->use_cudnn_grouped_conv_ = true;
+ }
+};
+
+#define REGISTER_GROUPED_CONV_KERNEL(T) \
+ REGISTER_KERNEL_BUILDER(Name("DepthwiseConv2dNative") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<T>("T") \
+ .Label("cudnn_grouped_convolution"), \
+ DepthwiseConv2dGroupedConvOp<T>)
+
+TF_CALL_half(REGISTER_GROUPED_CONV_KERNEL);
+TF_CALL_float(REGISTER_GROUPED_CONV_KERNEL);
+TF_CALL_double(REGISTER_GROUPED_CONV_KERNEL);
+#endif // CUDNN_VERSION
+#endif // GOOGLE_CUDA
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/resource_variable_ops.cc b/tensorflow/core/kernels/resource_variable_ops.cc
index 916869fb56..a8bcc7f7dc 100644
--- a/tensorflow/core/kernels/resource_variable_ops.cc
+++ b/tensorflow/core/kernels/resource_variable_ops.cc
@@ -211,6 +211,11 @@ class AssignVariableOp : public OpKernel {
public:
explicit AssignVariableOp(OpKernelConstruction* c) : OpKernel(c) {
OP_REQUIRES_OK(c, c->GetAttr("dtype", &dtype_));
+ if (!c->GetAttr("_grappler_relax_allocator_constraints",
+ &relax_constraints_)
+ .ok()) {
+ relax_constraints_ = false;
+ }
}
void Compute(OpKernelContext* context) override {
@@ -228,8 +233,10 @@ class AssignVariableOp : public OpKernel {
PersistentTensor unused;
Tensor* tmp;
AllocatorAttributes attr;
- attr.set_gpu_compatible(true);
- attr.set_nic_compatible(true);
+ if (!relax_constraints_) {
+ attr.set_gpu_compatible(true);
+ attr.set_nic_compatible(true);
+ }
TF_RETURN_IF_ERROR(context->allocate_persistent(
dtype_, context->input(1).shape(), &unused, &tmp, attr));
*(*ptr)->tensor() = *tmp;
@@ -245,8 +252,10 @@ class AssignVariableOp : public OpKernel {
const Tensor& value = context->input(1);
AllocatorAttributes attr;
- attr.set_gpu_compatible(true);
- attr.set_nic_compatible(true);
+ if (!relax_constraints_) {
+ attr.set_gpu_compatible(true);
+ attr.set_nic_compatible(true);
+ }
// Copying is unnecessary if we are the last user of the value
// tensor, we can just adopt the input tensor's buffer instead.
@@ -277,6 +286,7 @@ class AssignVariableOp : public OpKernel {
private:
DataType dtype_;
+ bool relax_constraints_;
};
template <typename Device>
diff --git a/tensorflow/core/kernels/segment_reduction_ops.h b/tensorflow/core/kernels/segment_reduction_ops.h
index 2ad9fa265e..d0703d7576 100644
--- a/tensorflow/core/kernels/segment_reduction_ops.h
+++ b/tensorflow/core/kernels/segment_reduction_ops.h
@@ -24,35 +24,6 @@ limitations under the License.
// non-GPU targets. This only breaks in clang, because it's more strict for
// template code and CudaAtomicMax is used in template context.
-
-// This file requires the following include because it uses CudaAtomicMax:
-// #include "tensorflow/core/util/cuda_kernel_helper.h"
-
-// Unfortunately we can't add the #include, since it breaks compilation for
-// non-GPU targets. This only breaks in clang, because it's more strict for
-// template code and CudaAtomicMax is used in template context.
-
-// This file requires the following include because it uses CudaAtomicMax:
-// #include "tensorflow/core/util/cuda_kernel_helper.h"
-
-// Unfortunately we can't add the #include, since it breaks compilation for
-// non-GPU targets. This only breaks in clang, because it's more strict for
-// template code and CudaAtomicMax is used in template context.
-
-// This file requires the following include because it uses CudaAtomicMax:
-// #include "tensorflow/core/util/cuda_kernel_helper.h"
-
-// Unfortunately we can't add the #include, since it breaks compilation for
-// non-GPU targets. This only breaks in clang, because it's more strict for
-// template code and CudaAtomicMax is used in template context.
-
-// This file requires the following include because it uses CudaAtomicMax:
-// #include "tensorflow/core/util/cuda_kernel_helper.h"
-
-// Unfortunately we can't add the #include, since it breaks compilation for
-// non-GPU targets. This only breaks in clang, because it's more strict for
-// template code and CudaAtomicMax is used in template context.
-
// This file requires the following include because it uses CudaAtomicMax:
// #include "tensorflow/core/util/cuda_kernel_helper.h"
diff --git a/tensorflow/core/lib/hash/hash.h b/tensorflow/core/lib/hash/hash.h
index ca05e6346e..3f85303c0f 100644
--- a/tensorflow/core/lib/hash/hash.h
+++ b/tensorflow/core/lib/hash/hash.h
@@ -21,6 +21,7 @@ limitations under the License.
#include <stddef.h>
#include <stdint.h>
+#include <functional>
#include <string>
#include "tensorflow/core/lib/core/stringpiece.h"
@@ -49,12 +50,28 @@ inline uint64 Hash64Combine(uint64 a, uint64 b) {
// In particular, tensorflow::hash is not the identity function for pointers.
// This is important for power-of-two sized hashtables like FlatMap and FlatSet,
// because otherwise they waste the majority of their hash buckets.
-template <typename T>
+//
+// The second type argument is only used for SFNIAE below.
+template <typename T, typename = void>
struct hash {
size_t operator()(const T& t) const { return std::hash<T>()(t); }
};
template <typename T>
+struct hash<T, typename std::enable_if<std::is_enum<T>::value>::type> {
+ size_t operator()(T value) const {
+ // This works around a defect in the std::hash C++ spec that isn't fixed in
+ // (at least) gcc 4.8.4:
+ // http://www.open-std.org/jtc1/sc22/wg21/docs/lwg-defects.html#2148
+ //
+ // We should be able to remove this and use the default
+ // tensorflow::hash<EnumTy>() once we stop building with GCC versions old
+ // enough to not have this defect fixed.
+ return std::hash<uint64>()(static_cast<uint64>(value));
+ }
+};
+
+template <typename T>
struct hash<T*> {
size_t operator()(const T* t) const {
// Hash pointers as integers, but bring more entropy to the lower bits.
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
index 88fc03826a..fce0b93cd7 100644
--- a/tensorflow/core/ops/array_ops.cc
+++ b/tensorflow/core/ops/array_ops.cc
@@ -466,7 +466,7 @@ REGISTER_OP("BroadcastTo")
// so no check needed.
if (i >= in_offset) {
DimensionHandle in_dim = c->Dim(in, i - in_offset);
- if (c->ValueKnown(in_dim)) {
+ if (c->ValueKnown(in_dim) && c->Value(in_dim) != 0) {
if (c->Value(dim) % c->Value(in_dim) != 0) {
return errors::InvalidArgument(
"Cannot broadcast a tensor with shape ", c->DebugString(in),
diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
index 71ba5f016a..cb466ef817 100644
--- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
@@ -24167,6 +24167,82 @@ op {
}
}
op {
+ name: "GroupByReducerDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "key_func_other_arguments"
+ type_list_attr: "Tkey_func_other_arguments"
+ }
+ input_arg {
+ name: "init_func_other_arguments"
+ type_list_attr: "Tinit_func_other_arguments"
+ }
+ input_arg {
+ name: "reduce_func_other_arguments"
+ type_list_attr: "Treduce_func_other_arguments"
+ }
+ input_arg {
+ name: "finalize_func_other_arguments"
+ type_list_attr: "Tfinalize_func_other_arguments"
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "key_func"
+ type: "func"
+ }
+ attr {
+ name: "init_func"
+ type: "func"
+ }
+ attr {
+ name: "reduce_func"
+ type: "func"
+ }
+ attr {
+ name: "finalize_func"
+ type: "func"
+ }
+ attr {
+ name: "Tkey_func_other_arguments"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "Tinit_func_other_arguments"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "Treduce_func_other_arguments"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "Tfinalize_func_other_arguments"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
name: "GroupByWindowDataset"
input_arg {
name: "input_dataset"
diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc
index 4ba3f15ef0..73174c184c 100644
--- a/tensorflow/core/ops/dataset_ops.cc
+++ b/tensorflow/core/ops/dataset_ops.cc
@@ -270,6 +270,26 @@ REGISTER_OP("ParallelInterleaveDataset")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(shape_inference::ScalarShape);
+REGISTER_OP("GroupByReducerDataset")
+ .Input("input_dataset: variant")
+ .Input("key_func_other_arguments: Tkey_func_other_arguments")
+ .Input("init_func_other_arguments: Tinit_func_other_arguments")
+ .Input("reduce_func_other_arguments: Treduce_func_other_arguments")
+ .Input("finalize_func_other_arguments: Tfinalize_func_other_arguments")
+ .Output("handle: variant")
+ .Attr("key_func: func")
+ .Attr("init_func: func")
+ .Attr("reduce_func: func")
+ .Attr("finalize_func: func")
+ .Attr("Tkey_func_other_arguments: list(type) >= 0")
+ .Attr("Tinit_func_other_arguments: list(type) >= 0")
+ .Attr("Treduce_func_other_arguments: list(type) >= 0")
+ .Attr("Tfinalize_func_other_arguments: list(type) >= 0")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetIsStateful()
+ .SetShapeFn(shape_inference::ScalarShape);
+
REGISTER_OP("GroupByWindowDataset")
.Input("input_dataset: variant")
.Input("key_func_other_arguments: Tkey_func_other_arguments")
@@ -458,11 +478,11 @@ REGISTER_OP("TextLineDataset")
shape_inference::ShapeHandle unused;
// `filenames` must be a scalar or a vector.
TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused));
- return shape_inference::ScalarShape(c);
// `compression_type` could only be a scalar.
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
// `buffer_size` could only be a scalar.
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
+ return shape_inference::ScalarShape(c);
});
REGISTER_OP("SqlDataset")
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index 90368fe614..207dd1c3d7 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -11537,6 +11537,82 @@ op {
}
}
op {
+ name: "GroupByReducerDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "key_func_other_arguments"
+ type_list_attr: "Tkey_func_other_arguments"
+ }
+ input_arg {
+ name: "init_func_other_arguments"
+ type_list_attr: "Tinit_func_other_arguments"
+ }
+ input_arg {
+ name: "reduce_func_other_arguments"
+ type_list_attr: "Treduce_func_other_arguments"
+ }
+ input_arg {
+ name: "finalize_func_other_arguments"
+ type_list_attr: "Tfinalize_func_other_arguments"
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "key_func"
+ type: "func"
+ }
+ attr {
+ name: "init_func"
+ type: "func"
+ }
+ attr {
+ name: "reduce_func"
+ type: "func"
+ }
+ attr {
+ name: "finalize_func"
+ type: "func"
+ }
+ attr {
+ name: "Tkey_func_other_arguments"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "Tinit_func_other_arguments"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "Treduce_func_other_arguments"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "Tfinalize_func_other_arguments"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
name: "GroupByWindowDataset"
input_arg {
name: "input_dataset"
diff --git a/tensorflow/core/protobuf/worker.proto b/tensorflow/core/protobuf/worker.proto
index 1819a35248..602f6a1ef1 100644
--- a/tensorflow/core/protobuf/worker.proto
+++ b/tensorflow/core/protobuf/worker.proto
@@ -27,6 +27,8 @@ import "tensorflow/core/framework/step_stats.proto";
import "tensorflow/core/framework/device_attributes.proto";
import "tensorflow/core/framework/graph.proto";
import "tensorflow/core/framework/tensor.proto";
+import "tensorflow/core/framework/tensor_shape.proto";
+import "tensorflow/core/framework/types.proto";
import "tensorflow/core/lib/core/error_codes.proto";
import "tensorflow/core/protobuf/config.proto";
import "tensorflow/core/protobuf/debug.proto";
@@ -413,3 +415,71 @@ message TracingRequest {
message TracingResponse {
}
+
+////////////////////////////////////////////////////////////////////////////////
+//
+// Collective Op dynamic group resolution messages.
+//
+////////////////////////////////////////////////////////////////////////////////
+
+// Supplies one or more device names as members of the group identified by
+// group_key. Service will respond when all group_size devices become known.
+// All devices in group must have same type.
+message CompleteGroupRequest {
+ int32 group_key = 1;
+ int32 group_size = 2;
+ string device_type = 3;
+ repeated string device_name = 4;
+}
+
+// Gives the complete membership of the group identified by group_key.
+message CompleteGroupResponse {
+ int32 group_key = 1;
+ int32 group_size = 2;
+ string device_type = 3;
+ int32 num_tasks = 4; // number of distinct tasks hosting the devices
+ repeated string device_name = 5;
+ repeated string task_name = 6; // task name prefixes of device_names
+}
+
+// Supplies data about one collective op belonging to the instance identified
+// by instance_key. Service will respond when all group_size ops have
+// become known. Most of the data being sent is for correctness checking,
+// to ensure that all ops in the instance share common attributes.
+message CompleteInstanceRequest {
+ string name = 1;
+ int32 type = 2;
+ DataType data_type = 3;
+ TensorShapeProto shape = 4;
+ int32 group_key = 5;
+ int32 group_size = 6;
+ int32 instance_key = 7;
+ string device_type = 8;
+ repeated int32 subdiv_offset = 9;
+ string device = 10;
+ bool is_source = 11;
+}
+
+// Confirms that every op in the instance has consistently declared itself.
+// Also gives the source_rank in case of broadcast.
+message CompleteInstanceResponse {
+ int32 instance_key = 1;
+ int32 source_rank = 2;
+}
+
+// Request for next agreed-upon step_id for the specified graph_keys.
+// This is used to enable multiple graphs containing nodes from
+// a common collective instance to coordinate using the same step_ids.
+message GetStepSequenceRequest {
+ repeated int64 graph_key = 1;
+}
+
+message StepSequence {
+ int64 graph_key = 1;
+ int64 next_step_id = 2;
+}
+
+// Next valid step_ids for one or more graph_keys.
+message GetStepSequenceResponse {
+ repeated StepSequence step_sequence = 1;
+}
diff --git a/tensorflow/core/protobuf/worker_service.proto b/tensorflow/core/protobuf/worker_service.proto
index e1bfb04d7c..01c76c01a9 100644
--- a/tensorflow/core/protobuf/worker_service.proto
+++ b/tensorflow/core/protobuf/worker_service.proto
@@ -72,4 +72,14 @@ service WorkerService {
// See worker.proto for details.
rpc Tracing(TracingRequest) returns (TracingResponse);
+
+ // See worker.proto for details.
+ rpc GetStepSequence(GetStepSequenceRequest) returns (GetStepSequenceResponse);
+
+ // See worker.proto for details.
+ rpc CompleteGroup(CompleteGroupRequest) returns (CompleteGroupResponse);
+
+ // See worker.proto for details.
+ rpc CompleteInstance(CompleteInstanceRequest)
+ returns (CompleteInstanceResponse);
}
diff --git a/tensorflow/docs_src/community/benchmarks.md b/tensorflow/docs_src/community/benchmarks.md
index 67856ce869..153ef4a015 100644
--- a/tensorflow/docs_src/community/benchmarks.md
+++ b/tensorflow/docs_src/community/benchmarks.md
@@ -1,14 +1,14 @@
# Defining and Running Benchmarks
-This guide contains instructions for defining and running a TensorFlow benchmark. These benchmarks store output in [TestResults](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/util/test_log.proto) format. If these benchmarks are added to TensorFlow github repo, then we will run them daily with our continuous build and display a graph on our dashboard: https://benchmarks-dot-tensorflow-testing.appspot.com/.
+This guide contains instructions for defining and running a TensorFlow benchmark. These benchmarks store output in [TestResults](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/util/test_log.proto) format. If these benchmarks are added to the TensorFlow github repo, we will run them daily with our continuous build and display a graph on our dashboard: https://benchmarks-dot-tensorflow-testing.appspot.com/.
[TOC]
## Defining a Benchmark
-Defining a TensorFlow benchmark requires extending from `tf.test.Benchmark`
-class and calling `self.report_benchmark` method. For example, take a look at the sample benchmark code below:
+Defining a TensorFlow benchmark requires extending the `tf.test.Benchmark`
+class and calling the `self.report_benchmark` method. Below, you'll find an example of benchmark code:
```python
import time
@@ -54,20 +54,20 @@ Key points to note in the example above:
## Running with Python
-Use the `--benchmarks` flag to run the benchmark with python. A [BenchmarkEntries](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/core/util/test_log.proto) proto will be printed.
+Use the `--benchmarks` flag to run the benchmark with Python. A [BenchmarkEntries](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/core/util/test_log.proto) proto will be printed.
```
python sample_benchmark.py --benchmarks=SampleBenchmark
```
-Setting the flag as `--benchmarks=.` or `--benchmarks=all` would work as well.
+Setting the flag as `--benchmarks=.` or `--benchmarks=all` works as well.
-(Please ensure that Tensorflow is installed to successfully import the package in the line `import tensorflow as tf`. For installation instructions, see [Installing TensorFlow](https://www.tensorflow.org/install/). This step is not necessary when running with bazel.)
+(Please ensure that Tensorflow is installed to successfully import the package in the line `import tensorflow as tf`. For installation instructions, see [Installing TensorFlow](https://www.tensorflow.org/install/). This step is not necessary when running with Bazel.)
## Adding a `bazel` Target
-We have a special target called `tf_py_logged_benchmark` for benchmarks defined under TensorFlow github repo. `tf_py_logged_benchmark` should wrap around a regular `py_test` target. Running a `tf_py_logged_benchmark` would print a [TestResults](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/util/test_log.proto) proto. Defining a `tf_py_logged_benchmark` also lets us run it with TensorFlow continuous build.
+We have a special target called `tf_py_logged_benchmark` for benchmarks defined under the TensorFlow github repo. `tf_py_logged_benchmark` should wrap around a regular `py_test` target. Running a `tf_py_logged_benchmark` would print a [TestResults](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/util/test_log.proto) proto. Defining a `tf_py_logged_benchmark` also lets us run it with TensorFlow continuous build.
First, define a regular `py_test` target. See example below:
@@ -82,7 +82,7 @@ py_test(
)
```
-You can run benchmarks in a `py_test` target by passing `--benchmarks` flag. The benchmark should just print out a [BenchmarkEntries](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/core/util/test_log.proto) proto.
+You can run benchmarks in a `py_test` target by passing the `--benchmarks` flag. The benchmark should just print out a [BenchmarkEntries](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/core/util/test_log.proto) proto.
```shell
bazel test :sample_benchmark --test_arg=--benchmarks=all
@@ -90,7 +90,7 @@ bazel test :sample_benchmark --test_arg=--benchmarks=all
Now, add the `tf_py_logged_benchmark` target (if available). This target would
-pass in `--benchmarks=all` to the wrapped `py_test` target and provide a way to store output for our TensorFlow continuous build. `tf_py_logged_benchmark` target should be available in TensorFlow repository.
+pass in `--benchmarks=all` to the wrapped `py_test` target and provide a way to store output for our TensorFlow continuous build. The target `tf_py_logged_benchmark` should be available in TensorFlow repository.
```build
load("//tensorflow/tools/test:performance.bzl", "tf_py_logged_benchmark")
diff --git a/tensorflow/docs_src/community/swift.md b/tensorflow/docs_src/community/swift.md
index a7da189a5c..e5a0f02a8c 100644
--- a/tensorflow/docs_src/community/swift.md
+++ b/tensorflow/docs_src/community/swift.md
@@ -8,7 +8,7 @@ Welcome to the Swift for TensorFlow development community!
Swift for TensorFlow is a new way to develop machine learning models. It
gives you the power of
-[TensorFlow](https://www.tensorflow.org/programmers_guide/eager) directly
+[TensorFlow](programmers_guide/eager) directly
integrated into the [Swift programming language](https://swift.org/about).
With Swift, you can write the following imperative code, and Swift
automatically turns it into **a single TensorFlow Graph** and runs it
@@ -28,15 +28,15 @@ print(x)
```
Swift combines the flexibility of
-[Eager Execution](https://www.tensorflow.org/programmers_guide/eager) with the
-high performance of [Graphs and Sessions](https://www.tensorflow.org/programmers_guide/graphs).
+[Eager Execution](programmers_guide/eager) with the
+high performance of [Graphs and Sessions](programmers_guide/graphs).
Behind the scenes, Swift analyzes your Tensor code and automatically builds
graphs for you. Swift also catches type errors and shape mismatches before
running your code, and has [Automatic Differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation)
built right in. We believe that machine learning tools are so important that
they deserve **a first-class language and a compiler**.
-**Note:** Swift for TensorFlow is an early stage research project. It has been
+Note: Swift for TensorFlow is an early stage research project. It has been
released to enable open source development and is not yet ready for general use
by machine learning developers.
diff --git a/tensorflow/docs_src/install/install_linux.md b/tensorflow/docs_src/install/install_linux.md
index c66d50c3cb..761555ca9a 100644
--- a/tensorflow/docs_src/install/install_linux.md
+++ b/tensorflow/docs_src/install/install_linux.md
@@ -496,6 +496,8 @@ If you are new to machine learning, we recommend the following:
* [Machine Learning Crash Course](https://developers.google.com/machine-learning/crash-course)
* @{$get_started/eager}
+If you are experienced with machine learning but new to TensorFlow, see
+@{$get_started/eager}.
<a name="NVIDIARequirements"></a>
## TensorFlow GPU support
diff --git a/tensorflow/docs_src/install/install_mac.md b/tensorflow/docs_src/install/install_mac.md
index ff6c2f5e44..90d9ea0288 100644
--- a/tensorflow/docs_src/install/install_mac.md
+++ b/tensorflow/docs_src/install/install_mac.md
@@ -409,7 +409,7 @@ If you are new to machine learning, we recommend the following:
* @{$get_started/get_started_for_beginners$Getting Started for ML Beginners}
If you are experienced with machine learning but new to TensorFlow, see
-@{$get_started/premade_estimators$Getting Started with TensorFlow}.
+@{$get_started/eager}.
## Common installation problems
diff --git a/tensorflow/docs_src/install/install_sources.md b/tensorflow/docs_src/install/install_sources.md
index 5c5c9e057b..a4fec382f4 100644
--- a/tensorflow/docs_src/install/install_sources.md
+++ b/tensorflow/docs_src/install/install_sources.md
@@ -388,7 +388,7 @@ TensorFlow programs:
<pre>Hello, TensorFlow!</pre>
-If you are new to TensorFlow, see @{$get_started/premade_estimators$Getting Started with TensorFlow}.
+If you are new to TensorFlow, see @{$get_started/eager}.
If the system outputs an error message instead of a greeting, see [Common
installation problems](#common_installation_problems).
diff --git a/tensorflow/docs_src/install/install_windows.md b/tensorflow/docs_src/install/install_windows.md
index 86add74da1..a139a49661 100644
--- a/tensorflow/docs_src/install/install_windows.md
+++ b/tensorflow/docs_src/install/install_windows.md
@@ -163,7 +163,7 @@ If you are new to machine learning, we recommend the following:
* @{$get_started/get_started_for_beginners$Getting Started for ML Beginners}
If you are experienced with machine learning but new to TensorFlow, see
-@{$get_started/premade_estimators$Getting Started with TensorFlow}.
+@{$get_started/eager}.
## Common installation problems
diff --git a/tensorflow/docs_src/performance/xla/index.md b/tensorflow/docs_src/performance/xla/index.md
index a884783074..8f5de83ea6 100644
--- a/tensorflow/docs_src/performance/xla/index.md
+++ b/tensorflow/docs_src/performance/xla/index.md
@@ -1,5 +1,9 @@
# XLA Overview
+<div style="width:50%; margin:auto; margin-bottom:10px; margin-top:20px;">
+<img style="width:50%" src="/images/xlalogo.png">
+</div>
+
> Note: XLA is experimental and considered alpha. Most use cases will not
> see improvements in performance (speed or decreased memory usage). We have
> released XLA early so the Open Source Community can contribute to its
diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go
index 9e87995441..2f1be51ada 100644
--- a/tensorflow/go/op/wrappers.go
+++ b/tensorflow/go/op/wrappers.go
@@ -6655,6 +6655,101 @@ func FusedBatchNormV2(scope *Scope, x tf.Output, scale tf.Output, offset tf.Outp
return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4)
}
+// Reverses specific dimensions of a tensor.
+//
+// NOTE `tf.reverse` has now changed behavior in preparation for 1.0.
+// `tf.reverse_v2` is currently an alias that will be deprecated before TF 1.0.
+//
+// Given a `tensor`, and a `int32` tensor `axis` representing the set of
+// dimensions of `tensor` to reverse. This operation reverses each dimension
+// `i` for which there exists `j` s.t. `axis[j] == i`.
+//
+// `tensor` can have up to 8 dimensions. The number of dimensions specified
+// in `axis` may be 0 or more entries. If an index is specified more than
+// once, a InvalidArgument error is raised.
+//
+// For example:
+//
+// ```
+// # tensor 't' is [[[[ 0, 1, 2, 3],
+// # [ 4, 5, 6, 7],
+// # [ 8, 9, 10, 11]],
+// # [[12, 13, 14, 15],
+// # [16, 17, 18, 19],
+// # [20, 21, 22, 23]]]]
+// # tensor 't' shape is [1, 2, 3, 4]
+//
+// # 'dims' is [3] or 'dims' is [-1]
+// reverse(t, dims) ==> [[[[ 3, 2, 1, 0],
+// [ 7, 6, 5, 4],
+// [ 11, 10, 9, 8]],
+// [[15, 14, 13, 12],
+// [19, 18, 17, 16],
+// [23, 22, 21, 20]]]]
+//
+// # 'dims' is '[1]' (or 'dims' is '[-3]')
+// reverse(t, dims) ==> [[[[12, 13, 14, 15],
+// [16, 17, 18, 19],
+// [20, 21, 22, 23]
+// [[ 0, 1, 2, 3],
+// [ 4, 5, 6, 7],
+// [ 8, 9, 10, 11]]]]
+//
+// # 'dims' is '[2]' (or 'dims' is '[-2]')
+// reverse(t, dims) ==> [[[[8, 9, 10, 11],
+// [4, 5, 6, 7],
+// [0, 1, 2, 3]]
+// [[20, 21, 22, 23],
+// [16, 17, 18, 19],
+// [12, 13, 14, 15]]]]
+// ```
+//
+// Arguments:
+// tensor: Up to 8-D.
+// axis: 1-D. The indices of the dimensions to reverse. Must be in the range
+// `[-rank(tensor), rank(tensor))`.
+//
+// Returns The same shape as `tensor`.
+func ReverseV2(scope *Scope, tensor tf.Output, axis tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "ReverseV2",
+ Input: []tf.Input{
+ tensor, axis,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Adds `bias` to `value`.
+//
+// This is a deprecated version of BiasAdd and will be soon removed.
+//
+// This is a special case of `tf.add` where `bias` is restricted to be 1-D.
+// Broadcasting is supported, so `value` may have any number of dimensions.
+//
+// Arguments:
+// value: Any number of dimensions.
+// bias: 1-D with size the last dimension of `value`.
+//
+// Returns Broadcasted sum of `value` and `bias`.
+func BiasAddV1(scope *Scope, value tf.Output, bias tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "BiasAddV1",
+ Input: []tf.Input{
+ value, bias,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Transforms a Tensor into a serialized TensorProto proto.
//
// Arguments:
@@ -13816,101 +13911,6 @@ func ResourceApplyCenteredRMSProp(scope *Scope, var_ tf.Output, mg tf.Output, ms
return scope.AddOperation(opspec)
}
-// Adds `bias` to `value`.
-//
-// This is a deprecated version of BiasAdd and will be soon removed.
-//
-// This is a special case of `tf.add` where `bias` is restricted to be 1-D.
-// Broadcasting is supported, so `value` may have any number of dimensions.
-//
-// Arguments:
-// value: Any number of dimensions.
-// bias: 1-D with size the last dimension of `value`.
-//
-// Returns Broadcasted sum of `value` and `bias`.
-func BiasAddV1(scope *Scope, value tf.Output, bias tf.Output) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "BiasAddV1",
- Input: []tf.Input{
- value, bias,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Reverses specific dimensions of a tensor.
-//
-// NOTE `tf.reverse` has now changed behavior in preparation for 1.0.
-// `tf.reverse_v2` is currently an alias that will be deprecated before TF 1.0.
-//
-// Given a `tensor`, and a `int32` tensor `axis` representing the set of
-// dimensions of `tensor` to reverse. This operation reverses each dimension
-// `i` for which there exists `j` s.t. `axis[j] == i`.
-//
-// `tensor` can have up to 8 dimensions. The number of dimensions specified
-// in `axis` may be 0 or more entries. If an index is specified more than
-// once, a InvalidArgument error is raised.
-//
-// For example:
-//
-// ```
-// # tensor 't' is [[[[ 0, 1, 2, 3],
-// # [ 4, 5, 6, 7],
-// # [ 8, 9, 10, 11]],
-// # [[12, 13, 14, 15],
-// # [16, 17, 18, 19],
-// # [20, 21, 22, 23]]]]
-// # tensor 't' shape is [1, 2, 3, 4]
-//
-// # 'dims' is [3] or 'dims' is [-1]
-// reverse(t, dims) ==> [[[[ 3, 2, 1, 0],
-// [ 7, 6, 5, 4],
-// [ 11, 10, 9, 8]],
-// [[15, 14, 13, 12],
-// [19, 18, 17, 16],
-// [23, 22, 21, 20]]]]
-//
-// # 'dims' is '[1]' (or 'dims' is '[-3]')
-// reverse(t, dims) ==> [[[[12, 13, 14, 15],
-// [16, 17, 18, 19],
-// [20, 21, 22, 23]
-// [[ 0, 1, 2, 3],
-// [ 4, 5, 6, 7],
-// [ 8, 9, 10, 11]]]]
-//
-// # 'dims' is '[2]' (or 'dims' is '[-2]')
-// reverse(t, dims) ==> [[[[8, 9, 10, 11],
-// [4, 5, 6, 7],
-// [0, 1, 2, 3]]
-// [[20, 21, 22, 23],
-// [16, 17, 18, 19],
-// [12, 13, 14, 15]]]]
-// ```
-//
-// Arguments:
-// tensor: Up to 8-D.
-// axis: 1-D. The indices of the dimensions to reverse. Must be in the range
-// `[-rank(tensor), rank(tensor))`.
-//
-// Returns The same shape as `tensor`.
-func ReverseV2(scope *Scope, tensor tf.Output, axis tf.Output) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "ReverseV2",
- Input: []tf.Input{
- tensor, axis,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// RealAttr is an optional argument to Real.
type RealAttr func(optionalAttr)
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 44d9147bb6..087b89b125 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -4032,6 +4032,7 @@ cuda_py_tests(
"training/basic_loops_test.py",
"training/coordinator_test.py",
"training/device_setter_test.py",
+ "training/device_util_test.py",
"training/ftrl_test.py",
"training/gradient_descent_test.py",
"training/learning_rate_decay_test.py",
diff --git a/tensorflow/python/data/ops/readers.py b/tensorflow/python/data/ops/readers.py
index fe033f5546..a73a8b5cdc 100644
--- a/tensorflow/python/data/ops/readers.py
+++ b/tensorflow/python/data/ops/readers.py
@@ -197,6 +197,11 @@ class TFRecordDataset(dataset_ops.Dataset):
filenames = array_ops.reshape(filenames, [-1], name="flat_filenames")
filenames = dataset_ops.Dataset.from_tensor_slices(filenames)
+ self._filenames = filenames
+ self._compression_type = compression_type
+ self._buffer_size = buffer_size
+ self._num_parallel_reads = num_parallel_reads
+
def read_one_file(filename):
return _TFRecordDataset(filename, compression_type, buffer_size)
@@ -208,6 +213,16 @@ class TFRecordDataset(dataset_ops.Dataset):
block_length=1, sloppy=False, buffer_output_elements=None,
prefetch_input_elements=None)
+ def _clone(self,
+ filenames=None,
+ compression_type=None,
+ buffer_size=None,
+ num_parallel_reads=None):
+ return TFRecordDataset(filenames or self._filenames,
+ compression_type or self._compression_type,
+ buffer_size or self._buffer_size,
+ num_parallel_reads or self._num_parallel_reads)
+
def _as_variant_tensor(self):
return self._impl._as_variant_tensor() # pylint: disable=protected-access
diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py
index 92774d4d50..d04b004451 100644
--- a/tensorflow/python/eager/backprop.py
+++ b/tensorflow/python/eager/backprop.py
@@ -681,8 +681,8 @@ class GradientTape(object):
with tfe.GradientTape() as gg:
gg.watch(x)
y = x * x
- dy_dx = gg.gradient(y, [x])[0] # Will compute to 6.0
- d2y_dx2 = g.gradient(dy_dx, [x])[0] # Will compute to 2.0
+ dy_dx = gg.gradient(y, x) # Will compute to 6.0
+ d2y_dx2 = g.gradient(dy_dx, x) # Will compute to 2.0
```
By default, the resources held by a GradientTape are released as soon as
@@ -697,8 +697,8 @@ class GradientTape(object):
g.watch(x)
y = x * x
z = y * y
- dy_dx = g.gradient(z, [x])[0] # 6.0
- dz_dx = g.gradient(y, [x])[0] # 108.0 (4*x^3 at x = 3)
+ dz_dx = g.gradient(z, x) # 108.0 (4*x^3 at x = 3)
+ dy_dx = g.gradient(y, x) # 6.0
del g # Drop the reference to the tape
"""
@@ -740,7 +740,7 @@ class GradientTape(object):
"""Computes the gradient using operations recorded in context of this tape.
Args:
- target: Tensor to be differentiated.
+ target: Tensor (or list of tensors) to be differentiated.
sources: a list or nested structure of Tensors or Variables. `target`
will be differentiated against elements in `sources`.
output_gradients: a list of gradients, one for each element of
@@ -762,8 +762,12 @@ class GradientTape(object):
flat_sources = nest.flatten(sources)
flat_sources = [_handle_or_self(x) for x in flat_sources]
+ if output_gradients is not None:
+ output_gradients = [None if x is None else ops.convert_to_tensor(x)
+ for x in nest.flatten(output_gradients)]
+
flat_grad = imperative_grad.imperative_grad(
- _default_vspace, self._tape, [target], flat_sources,
+ _default_vspace, self._tape, nest.flatten(target), flat_sources,
output_gradients=output_gradients)
if not self._persistent:
diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py
index 991b4dbe7a..8d9959fe20 100644
--- a/tensorflow/python/eager/backprop_test.py
+++ b/tensorflow/python/eager/backprop_test.py
@@ -96,6 +96,26 @@ class BackpropTest(test.TestCase):
self.assertAllEqual(grads_and_vars[0][0], 1.0)
self.assertAllEqual(id(grads_and_vars[0][1]), id(x))
+ def testTwoTargets(self):
+ with backprop.GradientTape() as t:
+ x = constant_op.constant(3.0)
+ y = constant_op.constant(2.0)
+ t.watch([x, y])
+ xx = 2 * x
+ yy = 3 * y
+ dx, dy = t.gradient([xx, yy], [x, y])
+ self.assertAllEqual(dx, 2.0)
+ self.assertAllEqual(dy, 3.0)
+
+ def testOutputGradUsedInComputation(self):
+ with backprop.GradientTape() as t:
+ x = constant_op.constant(3.0)
+ y = constant_op.constant(2.0)
+ t.watch([x, y])
+ loss = x * y
+ dx, = t.gradient([loss, x], [x], output_gradients=[1.0, 2.0])
+ self.assertAllEqual(dx, 4.0)
+
def testDy(self):
def f(x):
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index 426ee4c215..741bd2ac9c 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -716,8 +716,7 @@ def defun(func):
objects. Non-Tensor python objects are treated as constants, and new function
definitions are created internally based on their values.
- func must return a tf.Tensor (NOT a Tensor) or a list of tf.Tensor (NOT a
- Tensor).
+ func must return zero or more `tf.Tensor`.
Control flow constructs (e.g., `if`, `while`) are not yet compatible with
`defun`.
@@ -748,7 +747,7 @@ def defun(func):
Returns:
A callable that will execute the compiled function (and return zero
- or more Tensor objects).
+ or more `tf.Tensor` objects).
"""
# TODO(apassos): deal with captured global state. Deal with control flow.
try:
diff --git a/tensorflow/python/estimator/canned/boosted_trees.py b/tensorflow/python/estimator/canned/boosted_trees.py
index 085dace1b3..d281fd90ea 100644
--- a/tensorflow/python/estimator/canned/boosted_trees.py
+++ b/tensorflow/python/estimator/canned/boosted_trees.py
@@ -49,35 +49,10 @@ _TreeHParams = collections.namedtuple('TreeHParams', [
_HOLD_FOR_MULTI_CLASS_SUPPORT = object()
_HOLD_FOR_MULTI_DIM_SUPPORT = object()
+_DUMMY_NUM_BUCKETS = -1
-def _get_max_buckets(feature_columns):
- """Gets the maximum number of buckets from feature_columns.
-
- Args:
- feature_columns: a list/set of tf.feature_column.
-
- Returns:
- max_buckets: the maximum number of buckets among bucketized_columns.
-
- Raises:
- ValueError: when unsupported feature_columns are given.
- """
- if not feature_columns:
- raise ValueError('feature_columns must be a non-empty list/set of '
- 'tf.feature_column.')
- max_buckets = 1
- for fc in feature_columns:
- if isinstance(fc, feature_column_lib._BucketizedColumn): # pylint:disable=protected-access
- # N boundaries creates (N+1) buckets.
- max_buckets = max(max_buckets, len(fc.boundaries) + 1)
- else:
- raise ValueError('For now, only bucketized_column is supported but '
- 'got: {}'.format(fc))
- return max_buckets
-
-
-def _get_transformed_features(features, feature_columns):
+def _get_transformed_features(features, sorted_feature_columns):
"""Gets the transformed features from features/feature_columns pair.
Args:
@@ -91,22 +66,33 @@ def _get_transformed_features(features, feature_columns):
ValueError: when unsupported features/columns are tried.
"""
# pylint:disable=protected-access
- for fc in feature_columns:
- if not isinstance(fc, feature_column_lib._BucketizedColumn):
- raise ValueError('For now, only bucketized_column is supported but '
- 'got: {}'.format(fc))
transformed_features = feature_column_lib._transform_features(
- features, feature_columns)
- # pylint:enable=protected-access
+ features, sorted_feature_columns)
result_features = []
- for column in sorted(transformed_features, key=lambda tc: tc.name):
- source_name = column.source_column.name
- squeezed_tensor = array_ops.squeeze(transformed_features[column], axis=1)
- if len(squeezed_tensor.shape) > 1:
- raise ValueError('For now, only supports features equivalent to rank 1 '
- 'but column `{}` got: {}'.format(
- source_name, features[source_name].shape))
- result_features.append(squeezed_tensor)
+ for column in sorted_feature_columns:
+ if isinstance(column, feature_column_lib._BucketizedColumn):
+ source_name = column.source_column.name
+ squeezed_tensor = array_ops.squeeze(transformed_features[column], axis=1)
+ if len(squeezed_tensor.shape) > 1:
+ raise ValueError('For now, only supports features equivalent to rank 1 '
+ 'but column `{}` got: {}'.format(
+ source_name, features[source_name].shape))
+ result_features.append(squeezed_tensor)
+ elif isinstance(column, feature_column_lib._IndicatorColumn):
+ source_name = column.categorical_column.name
+ tensor = math_ops.to_int32(transformed_features[column])
+ if len(tensor.shape) > 2:
+ raise ValueError('Rank of indicator column must be no more than 2, '
+ 'but column `{}` got: {}'.format(
+ source_name, features[source_name].shape))
+ unstacked = array_ops.unstack(tensor, axis=1)
+ result_features.extend(unstacked)
+ else:
+ raise ValueError(
+ 'For now, only bucketized_column and indicator_column is supported '
+ 'but got: {}'.format(column))
+ # pylint:enable=protected-access
+
return result_features
@@ -120,9 +106,87 @@ def _local_variable(tensor, name=None):
name=name)
-def _cache_transformed_features(features, feature_columns, batch_size):
+def _group_features_by_num_buckets(sorted_feature_columns):
+ """Groups feature ids by the number of buckets.
+
+ Derives the feature ids based on iterating through ordered feature columns
+ and groups them by the number of buckets each feature require. Returns a
+ sorted list of buckets and a list of lists of feature ids for each of those
+ buckets.
+
+ Args:
+ sorted_feature_columns: a list/set of tf.feature_column sorted by name.
+
+ Returns:
+ bucket_size_list: a list of required bucket sizes.
+ feature_ids_list: a list of lists of feature ids for each bucket size.
+
+ Raises:
+ ValueError: when unsupported features columns are provided.
+ """
+ bucket_size_to_feature_ids_dict = collections.OrderedDict()
+
+ # TODO(nponomareva) for now we preserve the previous functionality and bucket
+ # all numeric into the same num of buckets. Can be easily changed to using
+ # each numeric's real buckets num, but we need to test that it does not cause
+ # a performance hit.
+
+ # We will replace this dummy key with the real max after we calculate it.
+ bucket_size_to_feature_ids_dict[_DUMMY_NUM_BUCKETS] = []
+
+ max_buckets_for_bucketized = 2
+ max_buckets_for_indicator = 2
+
+ feature_idx = 0
+ # pylint:disable=protected-access
+
+ for column in sorted_feature_columns:
+ if isinstance(column, feature_column_lib._IndicatorColumn):
+ num_categorical_features = column.categorical_column._num_buckets
+ if max_buckets_for_indicator not in bucket_size_to_feature_ids_dict:
+ bucket_size_to_feature_ids_dict[max_buckets_for_indicator] = []
+
+ for _ in range(num_categorical_features):
+ # We use bucket size of 2 for categorical.
+ bucket_size_to_feature_ids_dict[max_buckets_for_indicator].append(
+ feature_idx)
+ feature_idx += 1
+ elif isinstance(column, feature_column_lib._BucketizedColumn):
+ max_buckets_for_bucketized = max(max_buckets_for_bucketized,
+ len(column.boundaries) + 1)
+ bucket_size_to_feature_ids_dict[_DUMMY_NUM_BUCKETS].append(feature_idx)
+ feature_idx += 1
+ elif not isinstance(column, feature_column_lib._IndicatorColumn): # pylint:disable=protected-access
+ raise ValueError(
+ 'For now, only bucketized_column and indicator column are supported '
+ 'but got: {}'.format(column))
+
+ # pylint:enable=protected-access
+ # Replace the dummy key with the real max num of buckets for all bucketized
+ # columns.
+ bucket_size_to_feature_ids_dict[
+ max_buckets_for_bucketized] = bucket_size_to_feature_ids_dict[
+ _DUMMY_NUM_BUCKETS]
+ del bucket_size_to_feature_ids_dict[_DUMMY_NUM_BUCKETS]
+
+ feature_ids_list = list(bucket_size_to_feature_ids_dict.values())
+ bucket_size_list = list(bucket_size_to_feature_ids_dict.keys())
+ return bucket_size_list, feature_ids_list
+
+
+def _calculate_num_features(sorted_feature_columns):
+ num_features = 0
+ for column in sorted_feature_columns:
+ if isinstance(column, feature_column_lib._IndicatorColumn): # pylint:disable=protected-access
+ num_features += column.categorical_column._num_buckets # pylint:disable=protected-access
+ else:
+ num_features += 1
+ return num_features
+
+
+def _cache_transformed_features(features, sorted_feature_columns, batch_size):
"""Transform features and cache, then returns (cached_features, cache_op)."""
- num_features = len(feature_columns)
+ num_features = _calculate_num_features(sorted_feature_columns)
cached_features = [
_local_variable(
array_ops.zeros([batch_size], dtype=dtypes.int32),
@@ -132,7 +196,7 @@ def _cache_transformed_features(features, feature_columns, batch_size):
are_features_cached = _local_variable(False, name='are_features_cached')
def cache_features_and_return():
- """Caches transoformed features.
+ """Caches transformed features.
The intention is to hide get_transformed_features() from the graph by
caching the result except the first step, since bucketize operation
@@ -144,7 +208,8 @@ def _cache_transformed_features(features, feature_columns, batch_size):
the graph.
"""
- transformed_features = _get_transformed_features(features, feature_columns)
+ transformed_features = _get_transformed_features(features,
+ sorted_feature_columns)
cached = [
state_ops.assign(cached_features[i], transformed_features[i])
for i in range(num_features)
@@ -349,6 +414,8 @@ def _bt_model_fn(
ValueError: mode or params are invalid, or features has the wrong type.
"""
is_single_machine = (config.num_worker_replicas <= 1)
+
+ sorted_feature_columns = sorted(feature_columns, key=lambda tc: tc.name)
if train_in_memory:
assert n_batches_per_layer == 1, (
'When train_in_memory is enabled, input_fn should return the entire '
@@ -364,24 +431,26 @@ def _bt_model_fn(
# the dimension max_splits_per_layer, instead of max_splits (for the entire
# tree).
max_splits = (1 << tree_hparams.max_depth) - 1
- max_buckets = _get_max_buckets(feature_columns)
train_op = []
with ops.name_scope(name) as name:
# Prepare.
global_step = training_util.get_or_create_global_step()
- num_features = len(feature_columns)
+ bucket_size_list, feature_ids_list = _group_features_by_num_buckets(
+ sorted_feature_columns)
# Extract input features and set up cache for training.
training_state_cache = None
if mode == model_fn.ModeKeys.TRAIN and train_in_memory:
# cache transformed features as well for in-memory training.
batch_size = array_ops.shape(labels)[0]
- input_feature_list, input_cache_op = _cache_transformed_features(
- features, feature_columns, batch_size)
+ input_feature_list, input_cache_op = (
+ _cache_transformed_features(features, sorted_feature_columns,
+ batch_size))
train_op.append(input_cache_op)
training_state_cache = _CacheTrainingStatesUsingVariables(
batch_size, head.logits_dimension)
else:
- input_feature_list = _get_transformed_features(features, feature_columns)
+ input_feature_list = _get_transformed_features(features,
+ sorted_feature_columns)
if mode == model_fn.ModeKeys.TRAIN and example_id_column_name:
example_ids = features[example_id_column_name]
training_state_cache = _CacheTrainingStatesUsingHashTable(
@@ -446,34 +515,61 @@ def _bt_model_fn(
gradients = gradients_impl.gradients(loss, logits, name='Gradients')[0]
hessians = gradients_impl.gradients(
gradients, logits, name='Hessians')[0]
- stats_summary_list = [
- array_ops.squeeze(
- boosted_trees_ops.make_stats_summary(
- node_ids=node_ids,
- gradients=gradients,
- hessians=hessians,
- bucketized_features_list=[input_feature_list[f]],
- max_splits=max_splits,
- num_buckets=max_buckets),
- axis=0) for f in range(num_features)
- ]
-
- def grow_tree_from_stats_summaries(stats_summary_list):
+
+ stats_summaries_list = []
+ for i, feature_ids in enumerate(feature_ids_list):
+ num_buckets = bucket_size_list[i]
+ summaries = [
+ array_ops.squeeze(
+ boosted_trees_ops.make_stats_summary(
+ node_ids=node_ids,
+ gradients=gradients,
+ hessians=hessians,
+ bucketized_features_list=[input_feature_list[f]],
+ max_splits=max_splits,
+ num_buckets=num_buckets),
+ axis=0) for f in feature_ids
+ ]
+ stats_summaries_list.append(summaries)
+
+ accumulators = []
+
+ def grow_tree_from_stats_summaries(stats_summaries_list,
+ feature_ids_list):
"""Updates ensemble based on the best gains from stats summaries."""
- (node_ids_per_feature, gains_list, thresholds_list,
- left_node_contribs_list, right_node_contribs_list) = (
- boosted_trees_ops.calculate_best_gains_per_feature(
- node_id_range=last_layer_nodes_range,
- stats_summary_list=stats_summary_list,
- l1=tree_hparams.l1,
- l2=tree_hparams.l2,
- tree_complexity=tree_hparams.tree_complexity,
- min_node_weight=tree_hparams.min_node_weight,
- max_splits=max_splits))
+ node_ids_per_feature = []
+ gains_list = []
+ thresholds_list = []
+ left_node_contribs_list = []
+ right_node_contribs_list = []
+ all_feature_ids = []
+
+ assert len(stats_summaries_list) == len(feature_ids_list)
+
+ for i, feature_ids in enumerate(feature_ids_list):
+ (numeric_node_ids_per_feature, numeric_gains_list,
+ numeric_thresholds_list, numeric_left_node_contribs_list,
+ numeric_right_node_contribs_list) = (
+ boosted_trees_ops.calculate_best_gains_per_feature(
+ node_id_range=last_layer_nodes_range,
+ stats_summary_list=stats_summaries_list[i],
+ l1=tree_hparams.l1,
+ l2=tree_hparams.l2,
+ tree_complexity=tree_hparams.tree_complexity,
+ min_node_weight=tree_hparams.min_node_weight,
+ max_splits=max_splits))
+
+ all_feature_ids += feature_ids
+ node_ids_per_feature += numeric_node_ids_per_feature
+ gains_list += numeric_gains_list
+ thresholds_list += numeric_thresholds_list
+ left_node_contribs_list += numeric_left_node_contribs_list
+ right_node_contribs_list += numeric_right_node_contribs_list
+
grow_op = boosted_trees_ops.update_ensemble(
# Confirm if local_tree_ensemble or tree_ensemble should be used.
tree_ensemble.resource_handle,
- feature_ids=math_ops.range(0, num_features, dtype=dtypes.int32),
+ feature_ids=all_feature_ids,
node_ids=node_ids_per_feature,
gains=gains_list,
thresholds=thresholds_list,
@@ -486,32 +582,50 @@ def _bt_model_fn(
if train_in_memory and is_single_machine:
train_op.append(distribute_lib.increment_var(global_step))
- train_op.append(grow_tree_from_stats_summaries(stats_summary_list))
+ train_op.append(
+ grow_tree_from_stats_summaries(stats_summaries_list,
+ feature_ids_list))
else:
- summary_accumulator = data_flow_ops.ConditionalAccumulator(
- dtype=dtypes.float32,
- # The stats consist of gradients and hessians (the last dimension).
- shape=[num_features, max_splits, max_buckets, 2],
- shared_name='stats_summary_accumulator')
- apply_grad = summary_accumulator.apply_grad(
- array_ops.stack(stats_summary_list, axis=0), stamp_token)
+ dependencies = []
+
+ for i, feature_ids in enumerate(feature_ids_list):
+ stats_summaries = stats_summaries_list[i]
+ accumulator = data_flow_ops.ConditionalAccumulator(
+ dtype=dtypes.float32,
+ # The stats consist of grads and hessians (the last dimension).
+ shape=[len(feature_ids), max_splits, bucket_size_list[i], 2],
+ shared_name='numeric_stats_summary_accumulator_' + str(i))
+ accumulators.append(accumulator)
+
+ apply_grad = accumulator.apply_grad(
+ array_ops.stack(stats_summaries, axis=0), stamp_token)
+ dependencies.append(apply_grad)
def grow_tree_from_accumulated_summaries_fn():
"""Updates the tree with the best layer from accumulated summaries."""
# Take out the accumulated summaries from the accumulator and grow.
- stats_summary_list = array_ops.unstack(
- summary_accumulator.take_grad(1), axis=0)
- grow_op = grow_tree_from_stats_summaries(stats_summary_list)
+ stats_summaries_list = []
+
+ stats_summaries_list = [
+ array_ops.unstack(accumulator.take_grad(1), axis=0)
+ for accumulator in accumulators
+ ]
+
+ grow_op = grow_tree_from_stats_summaries(stats_summaries_list,
+ feature_ids_list)
return grow_op
- with ops.control_dependencies([apply_grad]):
+ with ops.control_dependencies(dependencies):
train_op.append(distribute_lib.increment_var(global_step))
if config.is_chief:
+ min_accumulated = math_ops.reduce_min(
+ array_ops.stack(
+ [acc.num_accumulated() for acc in accumulators]))
+
train_op.append(
control_flow_ops.cond(
- math_ops.greater_equal(
- summary_accumulator.num_accumulated(),
- n_batches_per_layer),
+ math_ops.greater_equal(min_accumulated,
+ n_batches_per_layer),
grow_tree_from_accumulated_summaries_fn,
control_flow_ops.no_op,
name='wait_until_n_batches_accumulated'))
diff --git a/tensorflow/python/estimator/canned/boosted_trees_test.py b/tensorflow/python/estimator/canned/boosted_trees_test.py
index c8c52d3bc6..95bb9b5a3b 100644
--- a/tensorflow/python/estimator/canned/boosted_trees_test.py
+++ b/tensorflow/python/estimator/canned/boosted_trees_test.py
@@ -46,6 +46,7 @@ INPUT_FEATURES = np.array(
[3.0, 20.0, 50.0, -100.0, 102.75], # feature_2 quantized:[2,3,3,0,3]
],
dtype=np.float32)
+
CLASSIFICATION_LABELS = [[0.], [1.], [1.], [0.], [0.]]
REGRESSION_LABELS = [[1.5], [0.3], [0.2], [2.], [5.]]
FEATURES_DICT = {'f_%d' % i: INPUT_FEATURES[i] for i in range(NUM_FEATURES)}
@@ -101,17 +102,25 @@ class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase):
def _assert_checkpoint(self, model_dir, global_step, finalized_trees,
attempted_layers):
+ self._assert_checkpoint_and_return_model(model_dir, global_step,
+ finalized_trees, attempted_layers)
+
+ def _assert_checkpoint_and_return_model(self, model_dir, global_step,
+ finalized_trees, attempted_layers):
reader = checkpoint_utils.load_checkpoint(model_dir)
self.assertEqual(global_step, reader.get_tensor(ops.GraphKeys.GLOBAL_STEP))
serialized = reader.get_tensor('boosted_trees:0_serialized')
ensemble_proto = boosted_trees_pb2.TreeEnsemble()
ensemble_proto.ParseFromString(serialized)
+
self.assertEqual(
finalized_trees,
sum([1 for t in ensemble_proto.tree_metadata if t.is_finalized]))
self.assertEqual(attempted_layers,
ensemble_proto.growing_metadata.num_layers_attempted)
+ return ensemble_proto
+
def testTrainAndEvaluateBinaryClassifier(self):
input_fn = _make_train_input_fn(is_classification=True)
@@ -325,6 +334,55 @@ class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase):
[[0.353850], [0.254100], [0.106850], [0.712100], [1.012100]],
[pred['predictions'] for pred in predictions])
+ def testTrainEvaluateAndPredictWithIndicatorColumn(self):
+ categorical = feature_column.categorical_column_with_vocabulary_list(
+ key='categorical', vocabulary_list=('bad', 'good', 'ok'))
+ feature_indicator = feature_column.indicator_column(categorical)
+ bucketized_col = feature_column.bucketized_column(
+ feature_column.numeric_column(
+ 'an_uninformative_feature', dtype=dtypes.float32),
+ BUCKET_BOUNDARIES)
+
+ labels = np.array([[0.], [5.7], [5.7], [0.], [0.]], dtype=np.float32)
+ # Our categorical feature defines the labels perfectly
+ input_fn = numpy_io.numpy_input_fn(
+ x={
+ 'an_uninformative_feature': np.array([1, 1, 1, 1, 1]),
+ 'categorical': np.array(['bad', 'good', 'good', 'ok', 'bad']),
+ },
+ y=labels,
+ batch_size=5,
+ shuffle=False)
+
+ # Train depth 1 tree.
+ est = boosted_trees.BoostedTreesRegressor(
+ feature_columns=[bucketized_col, feature_indicator],
+ n_batches_per_layer=1,
+ n_trees=1,
+ learning_rate=1.0,
+ max_depth=1)
+
+ num_steps = 1
+ est.train(input_fn, steps=num_steps)
+ ensemble = self._assert_checkpoint_and_return_model(
+ est.model_dir, global_step=1, finalized_trees=1, attempted_layers=1)
+
+ # We learnt perfectly.
+ eval_res = est.evaluate(input_fn=input_fn, steps=1)
+ self.assertAllClose(eval_res['loss'], 0)
+
+ predictions = list(est.predict(input_fn))
+ self.assertAllClose(
+ labels,
+ [pred['predictions'] for pred in predictions])
+
+ self.assertEqual(3, len(ensemble.trees[0].nodes))
+
+ # Check that the split happened on 'good' value, which will be encoded as
+ # feature with index 2 (0-numeric, 1 - 'bad')
+ self.assertEqual(2, ensemble.trees[0].nodes[0].bucketized_split.feature_id)
+ self.assertEqual(0, ensemble.trees[0].nodes[0].bucketized_split.threshold)
+
class ModelFnTests(test_util.TensorFlowTestCase):
"""Tests bt_model_fn including unexposed internal functionalities."""
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py
index e750e243be..3691c99dda 100644
--- a/tensorflow/python/estimator/estimator.py
+++ b/tensorflow/python/estimator/estimator.py
@@ -155,12 +155,12 @@ class Estimator(object):
config: Configuration object.
params: `dict` of hyper parameters that will be passed into `model_fn`.
Keys are names of parameters, values are basic python types.
- warm_start_from: Optional string filepath to a checkpoint to warm-start
- from, or a `tf.estimator.WarmStartSettings` object to
- fully configure warm-starting. If the string filepath is
- provided instead of a `WarmStartSettings`, then all
- variables are warm-started, and it is assumed that
- vocabularies and Tensor names are unchanged.
+ warm_start_from: Optional string filepath to a checkpoint or SavedModel to
+ warm-start from, or a `tf.estimator.WarmStartSettings`
+ object to fully configure warm-starting. If the string
+ filepath is provided instead of a `WarmStartSettings`,
+ then all variables are warm-started, and it is assumed
+ that vocabularies and Tensor names are unchanged.
Raises:
ValueError: parameters of `model_fn` don't match `params`.
@@ -400,7 +400,9 @@ class Estimator(object):
hooks: List of `SessionRunHook` subclass instances. Used for callbacks
inside the evaluation call.
checkpoint_path: Path of a specific checkpoint to evaluate. If `None`, the
- latest checkpoint in `model_dir` is used.
+ latest checkpoint in `model_dir` is used. If there are no checkpoints
+ in `model_dir`, evaluation is run with newly initialized `Variables`
+ instead of restored from checkpoint.
name: Name of the evaluation if user needs to run multiple evaluations on
different data sets, such as on training data vs test data. Metrics for
different evaluations are saved in separate folders, and appear
@@ -464,7 +466,9 @@ class Estimator(object):
hooks: List of `SessionRunHook` subclass instances. Used for callbacks
inside the prediction call.
checkpoint_path: Path of a specific checkpoint to predict. If `None`, the
- latest checkpoint in `model_dir` is used.
+ latest checkpoint in `model_dir` is used. If there are no checkpoints
+ in `model_dir`, prediction is run with newly initialized `Variables`
+ instead of restored from checkpoint.
yield_single_examples: If False, yield the whole batch as returned by the
`model_fn` instead of decomposing the batch into individual elements.
This is useful if `model_fn` returns some tensors whose first dimension
@@ -487,9 +491,8 @@ class Estimator(object):
if not checkpoint_path:
checkpoint_path = saver.latest_checkpoint(self._model_dir)
if not checkpoint_path:
- raise ValueError(
- 'Could not find trained model in model_dir: {}.'.format(
- self._model_dir))
+ logging.info('Could not find trained model in model_dir: {}, running '
+ 'initialization to predict.'.format(self._model_dir))
with ops.Graph().as_default() as g:
random_seed.set_random_seed(self._config.tf_random_seed)
@@ -1066,8 +1069,8 @@ class Estimator(object):
if not checkpoint_path:
latest_path = saver.latest_checkpoint(self._model_dir)
if not latest_path:
- raise ValueError('Could not find trained model in model_dir: {}.'.
- format(self._model_dir))
+ logging.info('Could not find trained model in model_dir: {}, running '
+ 'initialization to evaluate.'.format(self._model_dir))
checkpoint_path = latest_path
# Setup output directory.
@@ -1499,7 +1502,7 @@ def _get_default_warm_start_settings(warm_start_from):
Args:
warm_start_from: Either a string representing the filepath of a checkpoint
- to initialize from, or an instance of WarmStartSettings.
+ or SavedModel to initialize from, or an instance of WarmStartSettings.
Returns:
Either None or an instance of WarmStartSettings.
@@ -1510,9 +1513,19 @@ def _get_default_warm_start_settings(warm_start_from):
"""
if warm_start_from is None:
return None
- if isinstance(warm_start_from, six.string_types):
+ if isinstance(warm_start_from, (six.string_types, six.binary_type)):
+ # Infer that this is a SavedModel if export_path +
+ # 'variables/variables.index' exists, and if so, construct the
+ # WarmStartSettings pointing to export_path + 'variables/variables'.
+ if gfile.Exists(os.path.join(compat.as_bytes(warm_start_from),
+ compat.as_bytes('variables/variables.index'))):
+ logging.info('Warm-starting from a SavedModel')
+ return WarmStartSettings(ckpt_to_initialize_from=os.path.join(
+ compat.as_bytes(warm_start_from),
+ compat.as_bytes('variables/variables')))
return WarmStartSettings(ckpt_to_initialize_from=warm_start_from)
elif isinstance(warm_start_from, WarmStartSettings):
return warm_start_from
else:
- raise ValueError('warm_start_from must be a string or a WarmStartSettings')
+ raise ValueError('warm_start_from must be a string or a WarmStartSettings, '
+ 'instead got {}'.format(type(warm_start_from)))
diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py
index 0fea86124c..4d958f8b43 100644
--- a/tensorflow/python/estimator/estimator_test.py
+++ b/tensorflow/python/estimator/estimator_test.py
@@ -658,6 +658,41 @@ class EstimatorTrainTest(test.TestCase):
5, estimator._load_global_step_from_checkpoint_dir(
warm_started_est.model_dir))
+ def test_warm_starts_from_savedmodel(self):
+ def _make_model_fn(x):
+ def _variable_creating_and_export_model_fn(features, labels, mode):
+ _, _ = features, labels
+ variable_scope.get_variable('x', initializer=x)
+ global_step = training.get_global_step()
+ return model_fn_lib.EstimatorSpec(
+ mode,
+ predictions={'y': constant_op.constant(1.0)},
+ loss=constant_op.constant(1.),
+ train_op=state_ops.assign_add(global_step, 1),
+ export_outputs={'test': export_output.ClassificationOutput(
+ constant_op.constant([4.2]), constant_op.constant(['label']))})
+ return _variable_creating_and_export_model_fn
+
+ est = estimator.Estimator(model_fn=_make_model_fn(42.))
+ est.train(dummy_input_fn, steps=10)
+ feature_spec = {'x': parsing_ops.VarLenFeature(dtype=dtypes.int64),
+ 'y': parsing_ops.VarLenFeature(dtype=dtypes.int64)}
+ serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
+ feature_spec)
+ tmpdir = tempfile.mkdtemp()
+ export_dir_base = os.path.join(
+ compat.as_bytes(tmpdir), compat.as_bytes('export'))
+ export_dir = est.export_savedmodel(
+ export_dir_base, serving_input_receiver_fn)
+
+ warm_started_est = estimator.Estimator(
+ model_fn=_make_model_fn(36.),
+ warm_start_from=export_dir)
+ warm_started_est.train(dummy_input_fn, steps=5)
+ # warm_start is called after the model_fn, so x should have the value
+ # from the SavedModel.
+ self.assertEqual(42., warm_started_est.get_variable_value('x'))
+
def test_max_step(self):
est = estimator.Estimator(model_fn=model_fn_global_step_incrementer)
est.train(dummy_input_fn, max_steps=5)
@@ -1067,11 +1102,19 @@ class EstimatorEvaluateTest(test.TestCase):
ValueError, 'model_fn should return an EstimatorSpec'):
est.evaluate(dummy_input_fn, steps=1)
- def test_no_trained_model(self):
- est = estimator.Estimator(model_fn=_model_fn_with_eval_metric_ops)
- with self.assertRaisesRegexp(
- ValueError, 'Could not find trained model in model_dir'):
- est.evaluate(dummy_input_fn, steps=1)
+ def test_no_checkpoint_uses_init(self):
+ def _model_fn(features, labels, mode, params):
+ del features, labels, params
+ return model_fn_lib.EstimatorSpec(
+ mode,
+ loss=constant_op.constant(1.),
+ eval_metric_ops={'metric': metrics_lib.mean(
+ variables.Variable(2.) + 1)})
+ est = estimator.Estimator(model_fn=_model_fn)
+ metrics = est.evaluate(dummy_input_fn, steps=1)
+ # Metric value here is set to 1 + the value of the Variable that is newly
+ # initialized (since there is no checkpoint).
+ self.assertEqual(3., metrics['metric'])
def test_scores(self):
est = estimator.Estimator(
@@ -1331,11 +1374,15 @@ class EstimatorPredictTest(test.TestCase):
next(est.predict(_input_fn))
self.assertEqual(1, input_fn_call_count[0])
- def test_no_trained_model_in_model_dir(self):
- est = estimator.Estimator(model_fn=model_fn_global_step_incrementer)
- with self.assertRaisesRegexp(ValueError,
- 'Could not find trained model in model_dir'):
- next(est.predict(dummy_input_fn))
+ def test_no_checkpoint_uses_init(self):
+ def _model_fn(features, labels, mode, params, config):
+ del features, labels, params, config
+ x = variables.Variable([[3.]], name='x')
+ return model_fn_lib.EstimatorSpec(mode, predictions=math_ops.add(x, 1.))
+ est = estimator.Estimator(model_fn=_model_fn)
+ # Expected prediction value is 1 + the value of the Variable that is newly
+ # initialized (since there is no checkpoint).
+ self.assertEqual(4., next(est.predict(dummy_input_fn)))
def test_no_trained_model_invalid_checkpoint_path(self):
est = estimator.Estimator(model_fn=model_fn_global_step_incrementer)
diff --git a/tensorflow/python/grappler/graph_placer.py b/tensorflow/python/grappler/graph_placer.py
index 1cd51df4d9..654013b23c 100644
--- a/tensorflow/python/grappler/graph_placer.py
+++ b/tensorflow/python/grappler/graph_placer.py
@@ -55,11 +55,6 @@ def PlaceGraph(metagraph,
# Optimize the metagraph to speedup the placement
rewriter_config = rewriter_config_pb2.RewriterConfig()
- rewriter_config.optimizers.append("pruning")
- rewriter_config.optimizers.append("constfold")
- rewriter_config.optimizers.append("arithmetic")
- rewriter_config.optimizers.append("dependency")
- rewriter_config.optimizers.append("pruning")
optimized_graph = tf_optimizer.OptimizeGraph(
rewriter_config, metagraph, verbose=verbose, cluster=cluster)
optimized_metagraph = meta_graph_pb2.MetaGraphDef()
diff --git a/tensorflow/python/keras/_impl/keras/layers/embeddings.py b/tensorflow/python/keras/_impl/keras/layers/embeddings.py
index 2b353ac007..f7398845d4 100644
--- a/tensorflow/python/keras/_impl/keras/layers/embeddings.py
+++ b/tensorflow/python/keras/_impl/keras/layers/embeddings.py
@@ -153,7 +153,8 @@ class Embedding(Layer):
return (input_shape[0],) + tuple(in_lens) + (self.output_dim,)
def call(self, inputs):
- if K.dtype(inputs) != 'int32':
+ dtype = K.dtype(inputs)
+ if dtype != 'int32' and dtype != 'int64':
inputs = math_ops.cast(inputs, 'int32')
out = embedding_ops.embedding_lookup(self.embeddings, inputs)
return out
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index b4ff094cdf..c892b6ee9a 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -112,6 +112,22 @@ cuda_py_test(
tags = ["no_windows"],
)
+cuda_py_test(
+ name = "reduce_benchmark_test",
+ srcs = ["reduce_benchmark_test.py"],
+ additional_deps = [
+ "//tensorflow/python/eager:backprop",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python/eager:context",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:gradients",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:platform_benchmark",
+ ],
+)
+
tf_py_test(
name = "bincount_op_test",
size = "small",
diff --git a/tensorflow/python/kernel_tests/boosted_trees/stats_ops_test.py b/tensorflow/python/kernel_tests/boosted_trees/stats_ops_test.py
index f0bb84e69a..5cceb98cff 100644
--- a/tensorflow/python/kernel_tests/boosted_trees/stats_ops_test.py
+++ b/tensorflow/python/kernel_tests/boosted_trees/stats_ops_test.py
@@ -224,7 +224,7 @@ class StatsOpsTest(test_util.TensorFlowTestCase):
self.assertAllClose([[[-.424658], [-.6]], [[-.043478], [.485294]]],
sess.run(right_node_contribs_list))
- def testCalculateBestGainsWithMinNodeWEight(self):
+ def testCalculateBestGainsWithMinNodeWeight(self):
"""Testing Gain calculation without any regularization."""
with self.test_session() as sess:
max_splits = 7
@@ -271,6 +271,59 @@ class StatsOpsTest(test_util.TensorFlowTestCase):
self.assertAllClose([[[-0.75]], [[-0.014925]]],
sess.run(right_node_contribs_list))
+ def testCalculateBestGainsWithMinNodeWeightNoSplitOnFeturePossible(self):
+ """Testing Gain calculation without any regularization."""
+ with self.test_session() as sess:
+ max_splits = 7
+ node_id_range = [1, 3] # node 1 through 2 will be processed.
+ stats_summary_list = [
+ [
+ [[0., 0.], [.08, .09], [0., 0.], [0., 0.]], # node 0; ignored
+ [[0., 0.], [.15, .0036], [.06, .007], [.1, .2]], # node 1
+ [[0., 0.], [-.33, .068], [0., 0.], [.3, .04]], # node 2
+ [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 3; ignored
+ [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 4; ignored
+ [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 5; ignored
+ [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 6; ignored
+ ], # feature 0
+ [
+ [[0., 0.], [0., 0.], [.08, .09], [0., 0.]], # node 0; ignored
+ [[0., 0.], [.3, .5], [-.05, .6], [.06, .07]], # node 1
+ [[.1, .1], [.2, .03], [-.4, .05], [.07, .08]], # node 2
+ [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 3; ignored
+ [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 4; ignored
+ [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 5; ignored
+ [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 6; ignored
+ ], # feature 1
+ ] # num_features * shape=[max_splits, num_buckets, 2]
+
+ (node_ids_list, _, _, _,
+ _) = boosted_trees_ops.calculate_best_gains_per_feature(
+ node_id_range,
+ stats_summary_list,
+ l1=0.0,
+ l2=0.0,
+ tree_complexity=0.0,
+ min_node_weight=1,
+ max_splits=max_splits)
+
+ # We can't split either of the nodes on the first feature
+ self.assertEqual(2, len(sess.run(node_ids_list)))
+ self.assertAllEqual([], sess.run(node_ids_list)[0])
+ self.assertAllEqual([1], sess.run(node_ids_list)[1])
+
+ # Now check when we can't split on any feature
+ (node_ids_list, _, _, _,
+ _) = boosted_trees_ops.calculate_best_gains_per_feature(
+ node_id_range,
+ stats_summary_list,
+ l1=0.0,
+ l2=0.0,
+ tree_complexity=0.0,
+ min_node_weight=10,
+ max_splits=max_splits)
+ self.assertAllEqual([[], []], sess.run(node_ids_list))
+
def testMakeStatsSummarySimple(self):
"""Simple test for MakeStatsSummary."""
with self.test_session():
diff --git a/tensorflow/python/kernel_tests/depthwise_conv_op_test.py b/tensorflow/python/kernel_tests/depthwise_conv_op_test.py
index f7ae1a0f37..659dc0419a 100644
--- a/tensorflow/python/kernel_tests/depthwise_conv_op_test.py
+++ b/tensorflow/python/kernel_tests/depthwise_conv_op_test.py
@@ -22,12 +22,15 @@ import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import nn_impl
from tensorflow.python.ops import nn_ops
import tensorflow.python.ops.nn_grad # pylint: disable=unused-import
from tensorflow.python.platform import test
+from tensorflow.python.platform import tf_logging
def ConfigsToTest():
@@ -98,6 +101,7 @@ class DepthwiseConv2DTest(test.TestCase):
padding,
data_type,
use_gpu,
+ grouped_conv=False,
data_format="NHWC"):
"""Verifies the output values of the convolution function.
@@ -110,25 +114,26 @@ class DepthwiseConv2DTest(test.TestCase):
padding: Padding type.
data_type: The data type to use.
use_gpu: Whether to use GPU.
+ grouped_conv: Whether to use cuDNN 7's grouped convolution.
data_format: The data_format of the input. "NHWC" or "NCHW".
"""
- total_size_1 = 1
- total_size_2 = 1
+ input_size = 1
+ filter_size = 1
for s in tensor_in_sizes:
- total_size_1 *= s
+ input_size *= s
for s in filter_in_sizes:
- total_size_2 *= s
+ filter_size *= s
# Initializes the input and filter tensor with numbers incrementing from 1.
- x1 = [f * 1.0 for f in range(1, total_size_1 + 1)]
- x2 = [f * 1.0 for f in range(1, total_size_2 + 1)]
- with self.test_session(use_gpu=use_gpu) as sess:
- if data_type == dtypes.float16:
- tolerance = 1e-5
- elif data_type == dtypes.float32:
- tolerance = 1e-5
- else:
- self.assertEqual(data_type, dtypes.float64)
- tolerance = 1e-8
+ x1 = [f * 1.0 / input_size for f in range(1, input_size + 1)]
+ x2 = [f * 1.0 / filter_size for f in range(1, filter_size + 1)]
+ ops.reset_default_graph()
+ graph = ops.get_default_graph()
+ with self.test_session(graph=graph, use_gpu=use_gpu) as sess:
+ tolerance = {
+ dtypes.float16: 4e-2,
+ dtypes.float32: 1e-8,
+ dtypes.float64: 1e-13,
+ }[data_type]
t1 = constant_op.constant(x1, shape=tensor_in_sizes, dtype=data_type)
t1.set_shape(tensor_in_sizes)
@@ -142,25 +147,39 @@ class DepthwiseConv2DTest(test.TestCase):
native_t1 = array_ops.transpose(t1, [0, 3, 1, 2])
strides = [1, 1, stride, stride]
- conv_native = nn_ops.depthwise_conv2d_native(
- native_t1,
- t2,
- strides=strides,
- data_format=data_format,
- padding=padding)
+ with sess.graph._kernel_label_map({
+ "DepthwiseConv2dNative": "cudnn_grouped_convolution"
+ } if grouped_conv else {}):
+ conv_native = nn_ops.depthwise_conv2d_native(
+ native_t1,
+ t2,
+ strides=strides,
+ data_format=data_format,
+ padding=padding)
if data_format == "NCHW":
# Transpose back from NCHW to NHWC
conv_native = array_ops.transpose(conv_native, [0, 2, 3, 1])
+ try:
+ native_result = sess.run(conv_native)
+ except errors.InvalidArgumentError as e:
+ # Grouped convolution kernel is only registered for cuDNN 7. Silently
+ # return when we are running on an earlier version or without GPU.
+ if e.message.startswith(
+ "No OpKernel was registered to support Op 'DepthwiseConv2dNative'"):
+ tf_logging.warn("Skipping grouped convolution test")
+ return
+ raise e
+
conv_interface = nn_impl.depthwise_conv2d(
t1, t2, strides=[1, stride, stride, 1], padding=padding)
-
- native_result = sess.run(conv_native)
interface_result = sess.run(conv_interface)
- print("data_type:", data_type, "use_gpu:", use_gpu, "max diff = ",
- np.amax(np.absolute(native_result - interface_result)))
+ tf_logging.info(
+ "data_type: %r, use_gpu: %r, grouped_conv: %r, max diff = %f",
+ data_type, use_gpu, grouped_conv,
+ np.amax(np.absolute(native_result - interface_result)))
self.assertArrayNear(
np.ravel(native_result), np.ravel(interface_result), tolerance)
self.assertShapeEqual(native_result, conv_native)
@@ -169,11 +188,22 @@ class DepthwiseConv2DTest(test.TestCase):
def testDepthwiseConv2D(self):
for index, (input_size, filter_size, _, stride,
padding) in enumerate(ConfigsToTest()):
- print("Testing DepthwiseConv2D,", index, "th config:", input_size, "*",
- filter_size, "stride:", stride, "padding:", padding)
+ tf_logging.info(
+ "Testing DepthwiseConv2D, %dth config: %r * %r, stride: %d, padding: "
+ "%s", index, input_size, filter_size, stride, padding)
for data_type in [dtypes.float16, dtypes.float32, dtypes.float64]:
+ tf_logging.info("Testing without grouped_conv")
self._VerifyValues(
input_size, filter_size, stride, padding, data_type, use_gpu=True)
+ tf_logging.info("Testing with grouped_conv")
+ self._VerifyValues(
+ input_size,
+ filter_size,
+ stride,
+ padding,
+ data_type,
+ use_gpu=True,
+ grouped_conv=True)
def testDepthwiseConv2DFormat(self):
if not test.is_gpu_available():
@@ -181,8 +211,9 @@ class DepthwiseConv2DTest(test.TestCase):
for index, (input_size, filter_size, _, stride,
padding) in enumerate(ConfigsToTest()):
- print("Testing DepthwiseConv2DFormat,", index, "th config:", input_size,
- "*", filter_size, "stride:", stride, "padding:", padding)
+ tf_logging.info(
+ "Testing DepthwiseConv2DFormat, %dth config: %r * %r, stride: %d, "
+ "padding: %s", index, input_size, filter_size, stride, padding)
for data_type in [dtypes.float16, dtypes.float32, dtypes.float64]:
self._VerifyValues(
input_size,
@@ -226,7 +257,7 @@ class DepthwiseConv2DTest(test.TestCase):
conv = nn_ops.depthwise_conv2d_native(
t1, t2, strides=[1, stride, stride, 1], padding=padding)
value = sess.run(conv)
- print("value = ", value)
+ tf_logging.info("value = %r", value)
self.assertArrayNear(expected, np.ravel(value), 1e-5)
self.assertShapeEqual(value, conv)
@@ -296,7 +327,7 @@ class DepthwiseConv2DTest(test.TestCase):
expected=expected_output,
use_gpu=True)
- # Gradient checkers.This tests depthwise gradient computations for both
+ # Gradient checkers. This tests depthwise gradient computations for both
# BackpropFilter and BackpropInput by comparing gradients computed by the
# depthwise gradient ops with the gradients computed numerically (details can
# be found in the compute_gradient_error().
@@ -310,6 +341,7 @@ class DepthwiseConv2DTest(test.TestCase):
data_type,
test_input,
use_gpu,
+ grouped_conv=False,
data_format="NHWC"):
input_size = 1
for x in input_shape:
@@ -319,14 +351,14 @@ class DepthwiseConv2DTest(test.TestCase):
filter_size *= x
input_data = [x * 1.0 / input_size for x in range(0, input_size)]
filter_data = [x * 1.0 / filter_size for x in range(0, filter_size)]
- with self.test_session(use_gpu=use_gpu):
- if data_type == dtypes.float16:
- tolerance = 0.002
- elif data_type == dtypes.float32:
- tolerance = 0.002
- else:
- self.assertEqual(data_type, dtypes.float64)
- tolerance = 1e-8
+ ops.reset_default_graph()
+ graph = ops.get_default_graph()
+ with self.test_session(graph=graph, use_gpu=use_gpu) as sess:
+ tolerance = {
+ dtypes.float16: 2e-0,
+ dtypes.float32: 5e-4,
+ dtypes.float64: 1e-12,
+ }[data_type]
input_tensor = constant_op.constant(
input_data, shape=input_shape, dtype=data_type, name="input")
@@ -347,35 +379,49 @@ class DepthwiseConv2DTest(test.TestCase):
]
strides = [1, 1, stride, stride]
- depthwise_conv2d = nn_ops.depthwise_conv2d_native(
- native_input,
- filter_tensor,
- strides,
- padding,
- data_format=data_format,
- name="depthwise_conv2d")
+ with sess.graph._kernel_label_map({
+ "DepthwiseConv2dNative": "cudnn_grouped_convolution",
+ "DepthwiseConv2dNativeBackpropInput": "cudnn_grouped_convolution",
+ "DepthwiseConv2dNativeBackpropFilter": "cudnn_grouped_convolution",
+ } if grouped_conv else {}):
+ depthwise_conv2d = nn_ops.depthwise_conv2d_native(
+ native_input,
+ filter_tensor,
+ strides,
+ padding,
+ data_format=data_format,
+ name="depthwise_conv2d")
self.assertEqual(output_shape, depthwise_conv2d.get_shape())
- if test_input:
- err = gradient_checker.compute_gradient_error(
- native_input, input_shape, depthwise_conv2d, output_shape)
- else:
- err = gradient_checker.compute_gradient_error(filter_tensor,
- filter_shape,
- depthwise_conv2d,
- output_shape)
- print("data_type:", data_type, "use_gpu:", use_gpu, ", error = ", err)
+
+ try:
+ if test_input:
+ err = gradient_checker.compute_gradient_error(
+ native_input, input_shape, depthwise_conv2d, output_shape)
+ else:
+ err = gradient_checker.compute_gradient_error(
+ filter_tensor, filter_shape, depthwise_conv2d, output_shape)
+ except errors.InvalidArgumentError as e:
+ # Grouped convolution kernel is only registered for cuDNN 7. Silently
+ # return when we are running on an earlier version or without GPU.
+ if grouped_conv and e.message.startswith(
+ "No OpKernel was registered to support Op 'DepthwiseConv2dNative'"):
+ tf_logging.warn("Skipping grouped convolution test")
+ return
+ raise e
+
+ tf_logging.info(
+ "data_type: %r, use_gpu: %r, grouped_conv: %r, error = %f", data_type,
+ use_gpu, grouped_conv, err)
self.assertLess(err, tolerance)
def testDepthwiseConv2DInputGrad(self):
for index, (input_size, filter_size, output_size, stride,
padding) in enumerate(CheckGradConfigsToTest()):
- print("Testing DepthwiseConv2DInputGrad,", index, "th config:",
- input_size, "*", filter_size, "stride:", stride, "padding:",
- padding)
- # Note: float16 test for DepthwiseConv2DInputGrad is not enabled,
- # calculations are not very precise.
- for data_type in [dtypes.float32, dtypes.float64]:
+ tf_logging.info(
+ "Testing DepthwiseConv2DInputGrad, %dth config: %r * %r, stride: %d, "
+ "padding: %s", index, input_size, filter_size, stride, padding)
+ for data_type in [dtypes.float16, dtypes.float32, dtypes.float64]:
self._ConstructAndTestGradient(
input_size,
filter_size,
@@ -385,6 +431,16 @@ class DepthwiseConv2DTest(test.TestCase):
data_type,
test_input=True,
use_gpu=True)
+ self._ConstructAndTestGradient(
+ input_size,
+ filter_size,
+ output_size,
+ stride,
+ padding,
+ data_type,
+ test_input=True,
+ use_gpu=True,
+ grouped_conv=True)
def testDepthwiseConv2DInputGradFormat(self):
if not test.is_gpu_available():
@@ -392,12 +448,11 @@ class DepthwiseConv2DTest(test.TestCase):
for index, (input_size, filter_size, output_size, stride,
padding) in enumerate(CheckGradConfigsToTest()):
- print("Testing DepthwiseConv2DInputGradFormat,", index, "th config:",
- input_size, "*", filter_size, "stride:", stride, "padding:",
- padding)
- # Note: float16 test for DepthwiseConv2DInputGradFormat is not enabled,
- # calculations are not very precise.
- for data_type in [dtypes.float32, dtypes.float64]:
+ tf_logging.info(
+ "Testing DepthwiseConv2DInputGradFormat, %dth config: %r * %r, "
+ "stride: %d, padding: %s", index, input_size, filter_size, stride,
+ padding)
+ for data_type in [dtypes.float16, dtypes.float32, dtypes.float64]:
self._ConstructAndTestGradient(
input_size,
filter_size,
@@ -412,12 +467,10 @@ class DepthwiseConv2DTest(test.TestCase):
def testDepthwiseConv2DFilterGrad(self):
for index, (input_size, filter_size, output_size, stride,
padding) in enumerate(CheckGradConfigsToTest()):
- print("Testing DepthwiseConv2DFilterGrad,", index, "th config:",
- input_size, "*", filter_size, "stride:", stride, "padding:",
- padding)
- # Note: float16 test for DepthwiseConv2DFilterGrad is not enabled,
- # calculations are not very precise.
- for data_type in [dtypes.float32, dtypes.float64]:
+ tf_logging.info(
+ "Testing DepthwiseConv2DFilterGrad, %dth config: %r * %r, stride: "
+ "%d, padding: %s", index, input_size, filter_size, stride, padding)
+ for data_type in [dtypes.float16, dtypes.float32, dtypes.float64]:
self._ConstructAndTestGradient(
input_size,
filter_size,
@@ -434,12 +487,11 @@ class DepthwiseConv2DTest(test.TestCase):
for index, (input_size, filter_size, output_size, stride,
padding) in enumerate(CheckGradConfigsToTest()):
- print("Testing DepthwiseConv2DFilterGradFormat,", index, "th config:",
- input_size, "*", filter_size, "stride:", stride, "padding:",
- padding)
- # Note: float16 test for DepthwiseConv2DFilterGradFormat is not enabled,
- # calculations are not very precise.
- for data_type in [dtypes.float32, dtypes.float64]:
+ tf_logging.info(
+ "Testing DepthwiseConv2DFilterGradFormat, %dth config: %r * %r, "
+ "stride: %d, padding: %s", index, input_size, filter_size, stride,
+ padding)
+ for data_type in [dtypes.float16, dtypes.float32, dtypes.float64]:
self._ConstructAndTestGradient(
input_size,
filter_size,
@@ -494,9 +546,10 @@ class DepthwiseConv2DTest(test.TestCase):
def testDepthwiseConv2DInputGradCompare(self):
for index, (input_size, filter_size, output_size, stride,
padding) in enumerate(ConfigsToTest()):
- print("Testing DepthwiseConv2DInputGradCompare,", index, "th config:",
- input_size, "*", filter_size, "stride:", stride, "padding:",
- padding)
+ tf_logging.info(
+ "Testing DepthwiseConv2DInputGradCompare, %dth config: %r * %r, "
+ "stride: %d, padding: %s", index, input_size, filter_size, stride,
+ padding)
self._CompareBackpropInputFloat(input_size, filter_size, output_size,
stride, padding)
self._CompareBackpropInputDouble(input_size, filter_size, output_size,
@@ -545,9 +598,10 @@ class DepthwiseConv2DTest(test.TestCase):
def testDepthwiseConv2DFilterGradCompare(self):
for index, (input_size, filter_size, output_size, stride,
padding) in enumerate(ConfigsToTest()):
- print("Testing DepthwiseConv2DFilterGradCompare,", index, "th config:",
- input_size, "*", filter_size, "stride:", stride, "padding:",
- padding)
+ tf_logging.info(
+ "Testing DepthwiseConv2DFilterGradCompare, %dth config: %r * %r, "
+ "stride: %d, padding: %s", index, input_size, filter_size, stride,
+ padding)
self._CompareBackpropFilterFloat(input_size, filter_size, output_size,
stride, padding)
self._CompareBackpropFilterDouble(input_size, filter_size, output_size,
diff --git a/tensorflow/python/kernel_tests/distributions/bijector_test.py b/tensorflow/python/kernel_tests/distributions/bijector_test.py
index 18582241e2..33db014279 100644
--- a/tensorflow/python/kernel_tests/distributions/bijector_test.py
+++ b/tensorflow/python/kernel_tests/distributions/bijector_test.py
@@ -24,6 +24,7 @@ import numpy as np
import six
from tensorflow.python.framework import constant_op
+from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import bijector
from tensorflow.python.platform import test
@@ -275,6 +276,17 @@ class BijectorReduceEventDimsTest(test.TestCase):
8.,
self.evaluate(bij.inverse_log_det_jacobian(x, event_ndims=2)))
+ def testHandlesNonStaticEventNdims(self):
+ x_ = [[[1., 2.], [3., 4.]]]
+ x = array_ops.placeholder_with_default(x_, shape=None)
+ event_ndims = array_ops.placeholder(dtype=np.int32, shape=[])
+ bij = ExpOnlyJacobian(forward_min_event_ndims=1)
+ bij.inverse_log_det_jacobian(x, event_ndims=event_ndims)
+ with self.test_session() as sess:
+ ildj = sess.run(bij.inverse_log_det_jacobian(x, event_ndims=event_ndims),
+ feed_dict={event_ndims: 1})
+ self.assertAllClose(-np.log(x_), ildj)
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/kernel_tests/linalg/BUILD b/tensorflow/python/kernel_tests/linalg/BUILD
index faeccc8fba..6573cb9a1a 100644
--- a/tensorflow/python/kernel_tests/linalg/BUILD
+++ b/tensorflow/python/kernel_tests/linalg/BUILD
@@ -25,6 +25,28 @@ cuda_py_test(
)
cuda_py_test(
+ name = "linear_operator_block_diag_test",
+ size = "medium",
+ srcs = ["linear_operator_block_diag_test.py"],
+ additional_deps = [
+ "//tensorflow/python/ops/linalg",
+ "//third_party/py/numpy",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform_test",
+ ],
+ shard_count = 6,
+ tags = [
+ "noasan",
+ "optonly",
+ ],
+)
+
+cuda_py_test(
name = "linear_operator_composition_test",
size = "medium",
srcs = ["linear_operator_composition_test.py"],
@@ -115,6 +137,28 @@ cuda_py_test(
)
cuda_py_test(
+ name = "linear_operator_kronecker_test",
+ size = "medium",
+ srcs = ["linear_operator_kronecker_test.py"],
+ additional_deps = [
+ "//tensorflow/python/ops/linalg",
+ "//third_party/py/numpy",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform_test",
+ ],
+ shard_count = 8,
+ tags = [
+ "noasan",
+ "optonly",
+ ],
+)
+
+cuda_py_test(
name = "linear_operator_lower_triangular_test",
size = "medium",
srcs = ["linear_operator_lower_triangular_test.py"],
diff --git a/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_block_diag_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_block_diag_test.py
index e7407ede11..2b80f01b73 100644
--- a/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_block_diag_test.py
+++ b/tensorflow/python/kernel_tests/linalg/linear_operator_block_diag_test.py
@@ -19,11 +19,11 @@ from __future__ import print_function
import numpy as np
-from tensorflow.contrib.linalg.python.ops import linear_operator_block_diag as block_diag
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import random_seed
from tensorflow.python.ops import array_ops
from tensorflow.python.ops.linalg import linalg as linalg_lib
+from tensorflow.python.ops.linalg import linear_operator_block_diag as block_diag
from tensorflow.python.ops.linalg import linear_operator_test_util
from tensorflow.python.ops.linalg import linear_operator_util
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_kronecker_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_kronecker_test.py
index 6574da22a1..cce1ecd45e 100644
--- a/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_kronecker_test.py
+++ b/tensorflow/python/kernel_tests/linalg/linear_operator_kronecker_test.py
@@ -19,12 +19,12 @@ from __future__ import print_function
import numpy as np
-from tensorflow.contrib.linalg.python.ops import linear_operator_kronecker as kronecker
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.ops import array_ops
from tensorflow.python.ops.linalg import linalg as linalg_lib
+from tensorflow.python.ops.linalg import linear_operator_kronecker as kronecker
from tensorflow.python.ops.linalg import linear_operator_test_util
from tensorflow.python.ops.linalg import linear_operator_util
from tensorflow.python.platform import test
diff --git a/tensorflow/python/kernel_tests/reduce_benchmark_test.py b/tensorflow/python/kernel_tests/reduce_benchmark_test.py
new file mode 100644
index 0000000000..3a2fb81157
--- /dev/null
+++ b/tensorflow/python/kernel_tests/reduce_benchmark_test.py
@@ -0,0 +1,107 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Simple benchmarks for reductions and their gradients."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import time
+
+import numpy as np
+from six.moves import range # pylint: disable=redefined-builtin
+
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.client import session
+from tensorflow.python.eager import backprop
+from tensorflow.python.eager import context
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+class ReduceBenchmarks(test.Benchmark):
+ """Benchmarks for reductions."""
+
+ def _run(self, func, num_iters):
+ # call func to maybe warm up the GPU
+ func()
+ start = time.time()
+ for _ in range(num_iters):
+ func()
+ end = time.time()
+ mean_us = (end - start) * 1e6 / num_iters
+ self.report_benchmark(
+ iters=num_iters,
+ wall_time=mean_us,
+ extras={"examples_per_sec": num_iters / (end - start)})
+
+ def benchmark_reduce_sum_grad_eager(self):
+ with context.eager_mode():
+ tensor = array_ops.zeros([100, 1000])
+
+ def fn():
+ backprop.gradients_function(math_ops.reduce_sum, [0])(tensor)
+
+ self._run(fn, 10000)
+
+ def benchmark_reduce_sum_grad_eager_cpu(self):
+ with context.eager_mode(), ops.device("/cpu:0"):
+ tensor = array_ops.zeros([100, 1000])
+
+ def fn():
+ backprop.gradients_function(math_ops.reduce_sum, [0])(tensor)
+
+ self._run(fn, 10000)
+
+ def benchmark_reduce_sum_grad_graph(self):
+ config = config_pb2.ConfigProto(
+ graph_options=config_pb2.GraphOptions(
+ optimizer_options=config_pb2.OptimizerOptions(
+ opt_level=config_pb2.OptimizerOptions.L0)))
+ with ops.Graph().as_default(), session.Session(config=config) as sess:
+
+ tensor = constant_op.constant(np.zeros([100, 1000], dtype=np.float32))
+ reduction = math_ops.reduce_sum(tensor)
+ grad, = gradients_impl.gradients(reduction, tensor)
+
+ def fn():
+ sess.run(grad.op)
+
+ self._run(fn, 10000)
+
+ def benchmark_reduce_sum_grad_graph_cpu(self):
+ config = config_pb2.ConfigProto(
+ graph_options=config_pb2.GraphOptions(
+ optimizer_options=config_pb2.OptimizerOptions(
+ opt_level=config_pb2.OptimizerOptions.L0)))
+ with ops.Graph().as_default(), session.Session(config=config) as sess:
+
+ with ops.device("/cpu:0"):
+ tensor = constant_op.constant(np.zeros([100, 1000], dtype=np.float32))
+ reduction = math_ops.reduce_sum(tensor)
+ grad, = gradients_impl.gradients(reduction, tensor)
+
+ def fn():
+ sess.run(grad.op)
+
+ self._run(fn, 10000)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/ops/distributions/bijector_impl.py b/tensorflow/python/ops/distributions/bijector_impl.py
index 4ebc600d03..36eee5ce78 100644
--- a/tensorflow/python/ops/distributions/bijector_impl.py
+++ b/tensorflow/python/ops/distributions/bijector_impl.py
@@ -23,6 +23,7 @@ import collections
import contextlib
import re
+import numpy as np
import six
from tensorflow.python.framework import dtypes
@@ -146,15 +147,21 @@ class Bijector(object):
for transforming a `Distribution` generated `Tensor`. A `Bijector` is
characterized by three operations:
- 1. Forward\
+ 1. Forward
+
Useful for turning one random outcome into another random outcome from a
different distribution.
- 2. Inverse\
+
+ 2. Inverse
+
Useful for "reversing" a transformation to compute one probability in
terms of another.
- 3. `log_det_jacobian(x)`\
+
+ 3. `log_det_jacobian(x)`
+
"The log of the determinant of the matrix of all first-order partial
- derivatives of the inverse function."\
+ derivatives of the inverse function."
+
Useful for inverting a transformation to compute one probability in terms
of another. Geometrically, the Jacobian determinant is the volume of the
transformation and is used to scale the probability.
@@ -520,6 +527,8 @@ class Bijector(object):
ValueError: If a member of `graph_parents` is not a `Tensor`.
"""
self._graph_parents = graph_parents or []
+ forward_min_event_ndims = get_static_value(forward_min_event_ndims)
+ inverse_min_event_ndims = get_static_value(inverse_min_event_ndims)
if forward_min_event_ndims is None and inverse_min_event_ndims is None:
raise ValueError("Must specify at least one of `forward_min_event_ndims` "
@@ -795,33 +804,37 @@ class Bijector(object):
return self._constant_ildj_map[event_ndims]
y = ops.convert_to_tensor(y, name="y")
self._maybe_assert_dtype(y)
- if not self._is_injective: # No caching for non-injective
- ildjs = self._inverse_log_det_jacobian(y, **kwargs)
- return tuple(self._reduce_jacobian_det_over_event(
- y, ildj, self.inverse_min_event_ndims, event_ndims)
- for ildj in ildjs)
- mapping = self._lookup(y=y, kwargs=kwargs)
- if mapping.ildj_map is not None and event_ndims in mapping.ildj_map:
- return mapping.ildj_map[event_ndims]
- try:
- x = None # Not needed; leave cache as is.
- ildj = self._inverse_log_det_jacobian(y, **kwargs)
- ildj = self._reduce_jacobian_det_over_event(
- y, ildj, self.inverse_min_event_ndims, event_ndims)
- except NotImplementedError as original_exception:
+ with ops.control_dependencies(self._check_valid_event_ndims(
+ min_event_ndims=self.inverse_min_event_ndims,
+ event_ndims=event_ndims)):
+ if not self._is_injective: # No caching for non-injective
+ ildjs = self._inverse_log_det_jacobian(y, **kwargs)
+ return tuple(self._reduce_jacobian_det_over_event(
+ y, ildj, self.inverse_min_event_ndims, event_ndims)
+ for ildj in ildjs)
+ mapping = self._lookup(y=y, kwargs=kwargs)
+ if mapping.ildj_map is not None and event_ndims in mapping.ildj_map:
+ return mapping.ildj_map[event_ndims]
try:
- x = mapping.x if mapping.x is not None else self._inverse(y, **kwargs)
- ildj = -self._forward_log_det_jacobian(x, **kwargs)
+ x = None # Not needed; leave cache as is.
+ ildj = self._inverse_log_det_jacobian(y, **kwargs)
ildj = self._reduce_jacobian_det_over_event(
- x, ildj, self.forward_min_event_ndims, event_ndims)
- except NotImplementedError:
- raise original_exception
-
- mapping = mapping.merge(x=x, ildj_map={event_ndims: ildj})
- self._cache(mapping)
- if self.is_constant_jacobian:
- self._constant_ildj_map[event_ndims] = ildj
- return ildj
+ y, ildj, self.inverse_min_event_ndims, event_ndims)
+ except NotImplementedError as original_exception:
+ try:
+ x = (mapping.x if mapping.x is not None
+ else self._inverse(y, **kwargs))
+ ildj = -self._forward_log_det_jacobian(x, **kwargs)
+ ildj = self._reduce_jacobian_det_over_event(
+ x, ildj, self.forward_min_event_ndims, event_ndims)
+ except NotImplementedError:
+ raise original_exception
+
+ mapping = mapping.merge(x=x, ildj_map={event_ndims: ildj})
+ self._cache(mapping)
+ if self.is_constant_jacobian:
+ self._constant_ildj_map[event_ndims] = ildj
+ return ildj
def inverse_log_det_jacobian(
self, y, event_ndims, name="inverse_log_det_jacobian"):
@@ -852,9 +865,7 @@ class Bijector(object):
`self.dtype`.
NotImplementedError: if `_inverse_log_det_jacobian` is not implemented.
"""
- with ops.control_dependencies(self._check_valid_event_ndims(
- min_event_ndims=self.inverse_min_event_ndims, event_ndims=event_ndims)):
- return self._call_inverse_log_det_jacobian(y, event_ndims, name)
+ return self._call_inverse_log_det_jacobian(y, event_ndims, name)
def _forward_log_det_jacobian(self, x):
"""Subclass implementation of `forward_log_det_jacobian` public function.
@@ -876,38 +887,46 @@ class Bijector(object):
"forward_log_det_jacobian not implemented.")
def _call_forward_log_det_jacobian(self, x, event_ndims, name, **kwargs):
+ if not self._is_injective:
+ raise NotImplementedError(
+ "forward_log_det_jacobian cannot be implemented for non-injective "
+ "transforms.")
with self._name_scope(name, [x]):
- if event_ndims in self._constant_ildj_map:
- # Need "-1. *" to avoid invalid-unary-operand-type linter warning.
- return -1. * self._constant_ildj_map[event_ndims]
- x = ops.convert_to_tensor(x, name="x")
- self._maybe_assert_dtype(x)
- if not self._is_injective:
- fldjs = self._forward_log_det_jacobian(x, **kwargs) # No caching.
- return tuple(self._reduce_jacobian_det_over_event(
- x, fldj, self.forward_min_event_ndims, event_ndims)
- for fldj in fldjs)
- mapping = self._lookup(x=x, kwargs=kwargs)
- if mapping.ildj_map is not None and event_ndims in mapping.ildj_map:
- return -mapping.ildj_map[event_ndims]
- try:
- y = None # Not needed; leave cache as is.
- ildj = -self._forward_log_det_jacobian(x, **kwargs)
- ildj = self._reduce_jacobian_det_over_event(
- x, ildj, self.forward_min_event_ndims, event_ndims)
- except NotImplementedError as original_exception:
+ with ops.control_dependencies(self._check_valid_event_ndims(
+ min_event_ndims=self.forward_min_event_ndims,
+ event_ndims=event_ndims)):
+ if event_ndims in self._constant_ildj_map:
+ # Need "-1. *" to avoid invalid-unary-operand-type linter warning.
+ return -1. * self._constant_ildj_map[event_ndims]
+ x = ops.convert_to_tensor(x, name="x")
+ self._maybe_assert_dtype(x)
+ if not self._is_injective:
+ fldjs = self._forward_log_det_jacobian(x, **kwargs) # No caching.
+ return tuple(self._reduce_jacobian_det_over_event(
+ x, fldj, self.forward_min_event_ndims, event_ndims)
+ for fldj in fldjs)
+ mapping = self._lookup(x=x, kwargs=kwargs)
+ if mapping.ildj_map is not None and event_ndims in mapping.ildj_map:
+ return -mapping.ildj_map[event_ndims]
try:
- y = mapping.y if mapping.y is not None else self._forward(x, **kwargs)
- ildj = self._inverse_log_det_jacobian(y, **kwargs)
+ y = None # Not needed; leave cache as is.
+ ildj = -self._forward_log_det_jacobian(x, **kwargs)
ildj = self._reduce_jacobian_det_over_event(
- y, ildj, self.inverse_min_event_ndims, event_ndims)
- except NotImplementedError:
- raise original_exception
- mapping = mapping.merge(y=y, ildj_map={event_ndims: ildj})
- self._cache(mapping)
- if self.is_constant_jacobian:
- self._constant_ildj_map[event_ndims] = ildj
- return -ildj
+ x, ildj, self.forward_min_event_ndims, event_ndims)
+ except NotImplementedError as original_exception:
+ try:
+ y = (mapping.y if mapping.y is not None
+ else self._forward(x, **kwargs))
+ ildj = self._inverse_log_det_jacobian(y, **kwargs)
+ ildj = self._reduce_jacobian_det_over_event(
+ y, ildj, self.inverse_min_event_ndims, event_ndims)
+ except NotImplementedError:
+ raise original_exception
+ mapping = mapping.merge(y=y, ildj_map={event_ndims: ildj})
+ self._cache(mapping)
+ if self.is_constant_jacobian:
+ self._constant_ildj_map[event_ndims] = ildj
+ return -ildj
def forward_log_det_jacobian(
self, x, event_ndims, name="forward_log_det_jacobian"):
@@ -933,13 +952,7 @@ class Bijector(object):
nor {`_inverse`, `_inverse_log_det_jacobian`} are implemented, or
this is a non-injective bijector.
"""
- if not self._is_injective:
- raise NotImplementedError(
- "forward_log_det_jacobian cannot be implemented for non-injective "
- "transforms.")
- with ops.control_dependencies(self._check_valid_event_ndims(
- min_event_ndims=self.forward_min_event_ndims, event_ndims=event_ndims)):
- return self._call_forward_log_det_jacobian(x, event_ndims, name)
+ return self._call_forward_log_det_jacobian(x, event_ndims, name)
@contextlib.contextmanager
def _name_scope(self, name=None, values=None):
@@ -981,12 +994,14 @@ class Bijector(object):
def _reduce_jacobian_det_over_event(
self, y, ildj, min_event_ndims, event_ndims):
"""Reduce jacobian over event_ndims - min_event_ndims."""
+ assert_static(min_event_ndims)
+
if not self.is_constant_jacobian:
return math_ops.reduce_sum(
ildj,
self._get_event_reduce_dims(min_event_ndims, event_ndims))
- # In this case, we need to tile the jacobian over the event and reduce.
+ # In this case, we need to tile the Jacobian over the event and reduce.
y_rank = array_ops.rank(y)
y_shape = array_ops.shape(y)[
y_rank - event_ndims : y_rank - min_event_ndims]
@@ -997,47 +1012,60 @@ class Bijector(object):
axis=self._get_event_reduce_dims(min_event_ndims, event_ndims))
# The multiplication by ones can change the inferred static shape so we try
# to recover as much as possible.
- if (isinstance(event_ndims, int) and
- y.get_shape().ndims and ildj.get_shape().ndims):
- y_shape = y.get_shape()
- y_shape = y_shape[y_shape.ndims - event_ndims :
- y_shape.ndims - min_event_ndims]
- ildj_shape = ildj.get_shape()
- broadcast_shape = array_ops.broadcast_static_shape(
- ildj_shape, y_shape)
+ event_ndims_ = get_static_value(event_ndims)
+ if (event_ndims_ is not None and
+ y.shape.ndims is not None and
+ ildj.shape.ndims is not None):
+ y_shape = y.shape[y.shape.ndims - event_ndims_ :
+ y.shape.ndims - min_event_ndims]
+ broadcast_shape = array_ops.broadcast_static_shape(ildj.shape, y_shape)
reduced_ildj.set_shape(
broadcast_shape[: broadcast_shape.ndims - (
- event_ndims - min_event_ndims)])
+ event_ndims_ - min_event_ndims)])
return reduced_ildj
def _get_event_reduce_dims(self, min_event_ndims, event_ndims):
"""Compute the reduction dimensions given event_ndims."""
- min_event_ndims_ = (min_event_ndims if isinstance(min_event_ndims, int)
- else tensor_util.constant_value(min_event_ndims))
- event_ndims_ = (event_ndims if isinstance(event_ndims, int)
- else tensor_util.constant_value(event_ndims))
+ assert_static(min_event_ndims)
+ event_ndims_ = get_static_value(event_ndims, np.int32)
- if min_event_ndims_ is not None and event_ndims_ is not None:
- return [-index for index in range(1, event_ndims_ - min_event_ndims_ + 1)]
+ if event_ndims_ is not None:
+ return [-index for index in range(1, event_ndims_ - min_event_ndims + 1)]
else:
reduce_ndims = event_ndims - min_event_ndims
return math_ops.range(-reduce_ndims, 0)
def _check_valid_event_ndims(self, min_event_ndims, event_ndims):
"""Check whether event_ndims is atleast min_event_ndims."""
- min_event_ndims_ = (min_event_ndims if isinstance(min_event_ndims, int)
- else tensor_util.constant_value(min_event_ndims))
- event_ndims_ = (event_ndims if isinstance(event_ndims, int)
- else tensor_util.constant_value(event_ndims))
-
- if min_event_ndims_ is not None and event_ndims_ is not None:
- if min_event_ndims_ > event_ndims_:
+ assert_static(min_event_ndims)
+ event_ndims_ = get_static_value(event_ndims, np.int32)
+ assertions = []
+ if event_ndims_ is not None:
+ if min_event_ndims > event_ndims_:
raise ValueError("event_ndims ({}) must be larger than "
"min_event_ndims ({})".format(
- event_ndims_, min_event_ndims_))
- return []
-
- if self.validate_args:
- return [check_ops.assert_greater_equal(event_ndims, min_event_ndims)]
- return []
+ event_ndims_, min_event_ndims))
+ elif self.validate_args:
+ assertions += [
+ check_ops.assert_greater_equal(event_ndims, min_event_ndims)]
+ return assertions
+
+
+def get_static_value(x, dtype=None):
+ """Helper which returns static value; casting when dtype is preferred."""
+ if x is None:
+ return x
+ try:
+ x_ = tensor_util.constant_value(x)
+ except TypeError:
+ x_ = x
+ if x_ is None or dtype is None:
+ return x_
+ return np.array(x_, dtype)
+
+
+def assert_static(x):
+ """Helper which asserts that input arg is known statically."""
+ if x is None or type(x) != type(get_static_value(x)): # pylint: disable=unidiomatic-typecheck
+ raise TypeError("Input must be known statically.")
diff --git a/tensorflow/python/ops/embedding_ops.py b/tensorflow/python/ops/embedding_ops.py
index 6f2a34c731..bcc717b043 100644
--- a/tensorflow/python/ops/embedding_ops.py
+++ b/tensorflow/python/ops/embedding_ops.py
@@ -385,7 +385,7 @@ def embedding_lookup_sparse(params,
```
Raises:
- TypeError: If `sp_ids` is not a `SparseTensor`, or if `sp_weights` is
+ TypeError: If `sp_ids` is not a `SparseTensor`, or if `sp_weights` is
neither `None` nor `SparseTensor`.
ValueError: If `combiner` is not one of {"mean", "sqrtn", "sum"}.
"""
@@ -421,10 +421,7 @@ def embedding_lookup_sparse(params,
segment_ids = math_ops.cast(segment_ids, dtypes.int32)
ids = sp_ids.values
- if ignore_weights:
- ids, idx = array_ops.unique(ids)
- else:
- idx = None
+ ids, idx = array_ops.unique(ids)
embeddings = embedding_lookup(
params, ids, partition_strategy=partition_strategy, max_norm=max_norm)
@@ -433,6 +430,8 @@ def embedding_lookup_sparse(params,
if weights.dtype != embeddings.dtype:
weights = math_ops.cast(weights, embeddings.dtype)
+ embeddings = array_ops.gather(embeddings, idx)
+
# Reshape weights to allow broadcast
ones = array_ops.fill(
array_ops.expand_dims(array_ops.rank(embeddings) - 1, 0), 1)
diff --git a/tensorflow/python/ops/linalg/linalg.py b/tensorflow/python/ops/linalg/linalg.py
index d73c21cdc0..a7ba0bbe9c 100644
--- a/tensorflow/python/ops/linalg/linalg.py
+++ b/tensorflow/python/ops/linalg/linalg.py
@@ -22,11 +22,13 @@ from __future__ import print_function
# pylint: disable=wildcard-import,unused-import
from tensorflow.python.ops.linalg.linalg_impl import *
from tensorflow.python.ops.linalg.linear_operator import *
+from tensorflow.python.ops.linalg.linear_operator_block_diag import *
from tensorflow.python.ops.linalg.linear_operator_circulant import *
from tensorflow.python.ops.linalg.linear_operator_composition import *
from tensorflow.python.ops.linalg.linear_operator_diag import *
from tensorflow.python.ops.linalg.linear_operator_full_matrix import *
from tensorflow.python.ops.linalg.linear_operator_identity import *
+from tensorflow.python.ops.linalg.linear_operator_kronecker import *
from tensorflow.python.ops.linalg.linear_operator_low_rank_update import *
from tensorflow.python.ops.linalg.linear_operator_lower_triangular import *
# pylint: enable=wildcard-import
diff --git a/tensorflow/contrib/linalg/python/ops/linear_operator_block_diag.py b/tensorflow/python/ops/linalg/linear_operator_block_diag.py
index 9d3af66c92..438c3496bd 100644
--- a/tensorflow/contrib/linalg/python/ops/linear_operator_block_diag.py
+++ b/tensorflow/python/ops/linalg/linear_operator_block_diag.py
@@ -27,8 +27,14 @@ from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops.linalg import linear_operator
from tensorflow.python.ops.linalg import linear_operator_util
+from tensorflow.python.util.tf_export import tf_export
+__all__ = [
+ "LinearOperatorBlockDiag",
+]
+
+@tf_export("linalg.LinearOperatorBlockDiag")
class LinearOperatorBlockDiag(linear_operator.LinearOperator):
"""Combines one or more `LinearOperators` in to a Block Diagonal matrix.
diff --git a/tensorflow/contrib/linalg/python/ops/linear_operator_kronecker.py b/tensorflow/python/ops/linalg/linear_operator_kronecker.py
index 79080d194f..da959f9a1c 100644
--- a/tensorflow/contrib/linalg/python/ops/linear_operator_kronecker.py
+++ b/tensorflow/python/ops/linalg/linear_operator_kronecker.py
@@ -28,6 +28,11 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.linalg import linalg_impl as linalg
from tensorflow.python.ops.linalg import linear_operator
+from tensorflow.python.util.tf_export import tf_export
+
+__all__ = [
+ "LinearOperatorKronecker",
+]
def _vec(x):
@@ -59,6 +64,7 @@ def _rotate_last_dim(x, rotate_right=False):
return array_ops.transpose(x, transpose_perm)
+@tf_export("linalg.LinearOperatorKronecker")
class LinearOperatorKronecker(linear_operator.LinearOperator):
"""Kronecker product between two `LinearOperators`.
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index b9529ce3ed..7ac3bd8091 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -1768,6 +1768,7 @@ def reduce_logsumexp(input_tensor,
"keep_dims", keep_dims)
if keepdims is None:
keepdims = False
+ input_tensor = ops.convert_to_tensor(input_tensor)
with ops.name_scope(name, "ReduceLogSumExp", [input_tensor]) as name:
raw_max = reduce_max(
input_tensor,
@@ -1780,13 +1781,13 @@ def reduce_logsumexp(input_tensor,
array_ops.zeros_like(raw_max)))
result = gen_math_ops.log(
reduce_sum(
- gen_math_ops.exp(input_tensor - my_max),
+ gen_math_ops.exp(gen_math_ops.sub(input_tensor, my_max)),
axis,
keepdims=keepdims,
reduction_indices=reduction_indices))
if not keepdims:
my_max = array_ops.reshape(my_max, array_ops.shape(result))
- result += my_max
+ result = gen_math_ops.add(result, my_max)
return _may_reduce_to_scalar(keepdims, axis, reduction_indices, result)
@@ -2486,6 +2487,12 @@ def reduced_shape(input_shape, axes):
"""
# Example:
# cast needed for SparseTensor reductions
+ if context.executing_eagerly():
+ input_shape = input_shape.numpy()
+ axes = axes.numpy()
+ input_shape[axes] = 1
+ return input_shape
+
input_shape = to_int32(input_shape) # [2, 3, 5, 7]
axes = to_int32(axes) # [1, 2]
diff --git a/tensorflow/python/training/device_util.py b/tensorflow/python/training/device_util.py
index f1137e80ab..e31fa02d60 100644
--- a/tensorflow/python/training/device_util.py
+++ b/tensorflow/python/training/device_util.py
@@ -23,17 +23,42 @@ from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import ops
-def canonicalize(d):
+def canonicalize(d, default=None):
+ """Canonicalize device string.
+
+ If d has missing components, the rest would be deduced from the `default`
+ argument or from '/job:localhost/replica:0/task:0/device:CPU:0'. For example:
+ If d = '/cpu:0', default='/job:worker/task:1', it returns
+ '/job:worker/replica:0/task:1/device:CPU:0'.
+ If d = '/cpu:0', default='/job:worker', it returns
+ '/job:worker/replica:0/task:0/device:CPU:0'.
+ If d = '/gpu:0', default=None, it returns
+ '/job:localhost/replica:0/task:0/device:GPU:0'.
+
+ Args:
+ d: a device string.
+ default: a string for default device if d doesn't have all components.
+
+ Returns:
+ a canonicalized device string.
+ """
d = tf_device.DeviceSpec.from_string(d)
assert d.device_type is None or d.device_type == d.device_type.upper(), (
"Device type '%s' must be all-caps." % (d.device_type,))
# Fill in missing device fields using defaults.
result = tf_device.DeviceSpec(
job="localhost", replica=0, task=0, device_type="CPU", device_index=0)
+ if default:
+ result.merge_from(tf_device.DeviceSpec.from_string(default))
result.merge_from(d)
return result.to_string()
+def resolve(d):
+ """Canonicalize `d` with current device as default."""
+ return canonicalize(d, default=current())
+
+
class _FakeNodeDef(object):
"""A fake NodeDef for _FakeOperation."""
diff --git a/tensorflow/python/training/device_util_test.py b/tensorflow/python/training/device_util_test.py
new file mode 100644
index 0000000000..61525e21f5
--- /dev/null
+++ b/tensorflow/python/training/device_util_test.py
@@ -0,0 +1,89 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for device utilities."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.eager import context
+from tensorflow.python.framework import ops
+from tensorflow.python.platform import test
+from tensorflow.python.training import device_util
+
+
+class DeviceUtilTest(test.TestCase):
+
+ def testCurrentDeviceWithGlobalGraph(self):
+ with ops.device("/cpu:0"):
+ self.assertEqual(device_util.current(), "/device:CPU:0")
+
+ with ops.device("/job:worker"):
+ with ops.device("/cpu:0"):
+ self.assertEqual(device_util.current(), "/job:worker/device:CPU:0")
+
+ with ops.device("/cpu:0"):
+ with ops.device("/gpu:0"):
+ self.assertEqual(device_util.current(), "/device:GPU:0")
+
+ def testCurrentDeviceWithNonGlobalGraph(self):
+ with ops.Graph().as_default():
+ with ops.device("/cpu:0"):
+ self.assertEqual(device_util.current(), "/device:CPU:0")
+
+ def testCurrentDeviceWithEager(self):
+ with context.eager_mode():
+ with ops.device("/cpu:0"):
+ self.assertEqual(device_util.current(),
+ "/job:localhost/replica:0/task:0/device:CPU:0")
+
+ def testCanonicalizeWithoutDefaultDevice(self):
+ self.assertEqual(
+ device_util.canonicalize("/cpu:0"),
+ "/job:localhost/replica:0/task:0/device:CPU:0")
+ self.assertEqual(
+ device_util.canonicalize("/job:worker/cpu:0"),
+ "/job:worker/replica:0/task:0/device:CPU:0")
+ self.assertEqual(
+ device_util.canonicalize("/job:worker/task:1/cpu:0"),
+ "/job:worker/replica:0/task:1/device:CPU:0")
+
+ def testCanonicalizeWithDefaultDevice(self):
+ self.assertEqual(
+ device_util.canonicalize("/job:worker/task:1/cpu:0", default="/gpu:0"),
+ "/job:worker/replica:0/task:1/device:CPU:0")
+ self.assertEqual(
+ device_util.canonicalize("/job:worker/task:1", default="/gpu:0"),
+ "/job:worker/replica:0/task:1/device:GPU:0")
+ self.assertEqual(
+ device_util.canonicalize("/cpu:0", default="/job:worker"),
+ "/job:worker/replica:0/task:0/device:CPU:0")
+
+ def testResolveWithDeviceScope(self):
+ with ops.device("/gpu:0"):
+ self.assertEqual(
+ device_util.resolve("/job:worker/task:1/cpu:0"),
+ "/job:worker/replica:0/task:1/device:CPU:0")
+ self.assertEqual(
+ device_util.resolve("/job:worker/task:1"),
+ "/job:worker/replica:0/task:1/device:GPU:0")
+ with ops.device("/job:worker"):
+ self.assertEqual(
+ device_util.resolve("/cpu:0"),
+ "/job:worker/replica:0/task:0/device:CPU:0")
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/training/distribute.py b/tensorflow/python/training/distribute.py
index 21ec5292ad..c16b05102e 100644
--- a/tensorflow/python/training/distribute.py
+++ b/tensorflow/python/training/distribute.py
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
import threading
+import six
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import ops
@@ -896,6 +897,8 @@ class DistributionStrategy(object):
A `Tensor` on `destination`.
"""
_require_cross_tower_context(self)
+ assert isinstance(destination, six.string_types)
+ destination = device_util.resolve(destination)
return self._fetch(val, destination, fn)
def _fetch(self, val, destination, fn):
@@ -1124,8 +1127,7 @@ class _DefaultDistributionStrategy(DistributionStrategy):
def creator(next_creator, *args, **kwargs):
_require_distribution_strategy_scope(self)
- if kwargs.pop("tower_local_reduce_method", None) is not None:
- kwargs["trainable"] = False
+ kwargs.pop("tower_local_reduce_method", None)
return next_creator(*args, **kwargs)
return _CurrentDistributionContext(
@@ -1135,7 +1137,7 @@ class _DefaultDistributionStrategy(DistributionStrategy):
"""Does not set to resource variables."""
def create_tower_local_variable(next_creator, *args, **kwargs):
_require_distribution_strategy_scope(self)
- kwargs["tower_local_reduce_method"] = reduce_method
+ kwargs["trainable"] = False
return next_creator(*args, **kwargs)
_require_distribution_strategy_scope(self)
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc
index 42a77aa3f8..773cac2c40 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.cc
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc
@@ -337,7 +337,9 @@ CUDNN_DNN_ROUTINE_EACH_R6_WITH_STREAM(
#if CUDNN_VERSION >= 7000
#define CUDNN_DNN_ROUTINE_EACH_R7(__macro) \
__macro(cudnnSetConvolutionMathType) \
- __macro(cudnnSetRNNMatrixMathType)
+ __macro(cudnnSetRNNMatrixMathType) \
+ __macro(cudnnSetConvolutionGroupCount) \
+ __macro(cudnnGetConvolutionGroupCount)
// clang-format on
CUDNN_DNN_ROUTINE_EACH_R7(STREAM_EXECUTOR_CUDNN_WRAP)
@@ -779,6 +781,20 @@ class ScopedConvolutionDescriptor {
// NOTE(benbarsdell): This only applies if tensor op math is enabled
// and algo selection is set to Default.
this->set_use_tensor_op_math(true);
+
+#if CUDNN_MAJOR >= 7
+ VLOG(2) << "Requesting grouped convolution: "
+ << convolution_descriptor.group_count();
+ status = wrap::cudnnSetConvolutionGroupCount(
+ parent_, handle_, convolution_descriptor.group_count());
+ if (status != CUDNN_STATUS_SUCCESS) {
+ LOG(FATAL) << "could not set cudnn convolution group count: "
+ << ToString(status);
+ }
+#else
+ CHECK_EQ(convolution_descriptor.group_count(), 1)
+ << "Requested grouped convolution for cuDNN version < 7";
+#endif
}
void set_use_tensor_op_math(bool use_tensor_op_math) {
diff --git a/tensorflow/stream_executor/dnn.cc b/tensorflow/stream_executor/dnn.cc
index 031c82d3f4..eed93efc8d 100644
--- a/tensorflow/stream_executor/dnn.cc
+++ b/tensorflow/stream_executor/dnn.cc
@@ -434,6 +434,7 @@ ConvolutionDescriptor::ConvolutionDescriptor(int ndims)
filter_strides_(ndims, 1),
dilation_rates_(ndims, 1),
pad_alignment_(PadAlignment::kDefault),
+ group_count_(1),
ndims_(ndims) {}
ConvolutionDescriptor::ConvolutionDescriptor()
diff --git a/tensorflow/stream_executor/dnn.h b/tensorflow/stream_executor/dnn.h
index 0c2e083b39..18606eb717 100644
--- a/tensorflow/stream_executor/dnn.h
+++ b/tensorflow/stream_executor/dnn.h
@@ -543,6 +543,10 @@ class ConvolutionDescriptor {
pad_alignment_ = pad_alignment;
return *this;
}
+ ConvolutionDescriptor& set_group_count(int group_count) {
+ group_count_ = group_count;
+ return *this;
+ }
int64 zero_padding_height() const {
return GetDim(zero_padding_, DimIndex::Y);
}
@@ -566,6 +570,7 @@ class ConvolutionDescriptor {
int filter_stride(DimIndex dim) const { return GetDim(filter_strides_, dim); }
int dilation_rate(DimIndex dim) const { return GetDim(dilation_rates_, dim); }
PadAlignment pad_alignment() const { return pad_alignment_; }
+ int group_count() const { return group_count_; }
int ndims() const { return ndims_; }
std::vector<int64> strides() const { return filter_strides_; }
@@ -578,6 +583,7 @@ class ConvolutionDescriptor {
std::vector<int64> filter_strides_;
std::vector<int64> dilation_rates_;
PadAlignment pad_alignment_;
+ int group_count_;
int ndims_;
// TODO(leary) cudnn provides these fields, but need to characterize what
// their effect is -- they may be boolean rather than integral.
diff --git a/tensorflow/tools/api/generator/create_python_api.py b/tensorflow/tools/api/generator/create_python_api.py
index c06a39bfbd..788f6d3573 100644
--- a/tensorflow/tools/api/generator/create_python_api.py
+++ b/tensorflow/tools/api/generator/create_python_api.py
@@ -23,6 +23,7 @@ import collections
import os
import sys
+from tensorflow import python # pylint: disable=unused-import
from tensorflow.python.util import tf_decorator
@@ -158,7 +159,7 @@ def get_api_init_text():
# Traverse over everything imported above. Specifically,
# we want to traverse over TensorFlow Python modules.
- for module in sys.modules.values():
+ for module in list(sys.modules.values()):
# Only look at tensorflow modules.
if (not module or not hasattr(module, '__name__') or
'tensorflow.' not in module.__name__):
diff --git a/tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator-block-diag.__metaclass__.pbtxt b/tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator-block-diag.__metaclass__.pbtxt
new file mode 100644
index 0000000000..b6dee63176
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator-block-diag.__metaclass__.pbtxt
@@ -0,0 +1,14 @@
+path: "tensorflow.linalg.LinearOperatorBlockDiag.__metaclass__"
+tf_class {
+ is_instance: "<class \'abc.ABCMeta\'>"
+ member_method {
+ name: "__init__"
+ }
+ member_method {
+ name: "mro"
+ }
+ member_method {
+ name: "register"
+ argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator-block-diag.pbtxt b/tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator-block-diag.pbtxt
new file mode 100644
index 0000000000..973705dae2
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator-block-diag.pbtxt
@@ -0,0 +1,134 @@
+path: "tensorflow.linalg.LinearOperatorBlockDiag"
+tf_class {
+ is_instance: "<class \'tensorflow.python.ops.linalg.linear_operator_block_diag.LinearOperatorBlockDiag\'>"
+ is_instance: "<class \'tensorflow.python.ops.linalg.linear_operator.LinearOperator\'>"
+ is_instance: "<type \'object\'>"
+ member {
+ name: "batch_shape"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "domain_dimension"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "dtype"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "graph_parents"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "is_non_singular"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "is_positive_definite"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "is_self_adjoint"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "is_square"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "name"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "operators"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "range_dimension"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "shape"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "tensor_rank"
+ mtype: "<type \'property\'>"
+ }
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'operators\', \'is_non_singular\', \'is_self_adjoint\', \'is_positive_definite\', \'is_square\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+ }
+ member_method {
+ name: "add_to_tensor"
+ argspec: "args=[\'self\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'add_to_tensor\'], "
+ }
+ member_method {
+ name: "assert_non_singular"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_non_singular\'], "
+ }
+ member_method {
+ name: "assert_positive_definite"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_positive_definite\'], "
+ }
+ member_method {
+ name: "assert_self_adjoint"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_self_adjoint\'], "
+ }
+ member_method {
+ name: "batch_shape_tensor"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'batch_shape_tensor\'], "
+ }
+ member_method {
+ name: "determinant"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'det\'], "
+ }
+ member_method {
+ name: "diag_part"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'diag_part\'], "
+ }
+ member_method {
+ name: "domain_dimension_tensor"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'domain_dimension_tensor\'], "
+ }
+ member_method {
+ name: "log_abs_determinant"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'log_abs_det\'], "
+ }
+ member_method {
+ name: "matmul"
+ argspec: "args=[\'self\', \'x\', \'adjoint\', \'adjoint_arg\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'matmul\'], "
+ }
+ member_method {
+ name: "matvec"
+ argspec: "args=[\'self\', \'x\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'matvec\'], "
+ }
+ member_method {
+ name: "range_dimension_tensor"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'range_dimension_tensor\'], "
+ }
+ member_method {
+ name: "shape_tensor"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'shape_tensor\'], "
+ }
+ member_method {
+ name: "solve"
+ argspec: "args=[\'self\', \'rhs\', \'adjoint\', \'adjoint_arg\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'solve\'], "
+ }
+ member_method {
+ name: "solvevec"
+ argspec: "args=[\'self\', \'rhs\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'solve\'], "
+ }
+ member_method {
+ name: "tensor_rank_tensor"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'tensor_rank_tensor\'], "
+ }
+ member_method {
+ name: "to_dense"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'to_dense\'], "
+ }
+ member_method {
+ name: "trace"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'trace\'], "
+ }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator-kronecker.__metaclass__.pbtxt b/tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator-kronecker.__metaclass__.pbtxt
new file mode 100644
index 0000000000..5c6784dd02
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator-kronecker.__metaclass__.pbtxt
@@ -0,0 +1,14 @@
+path: "tensorflow.linalg.LinearOperatorKronecker.__metaclass__"
+tf_class {
+ is_instance: "<class \'abc.ABCMeta\'>"
+ member_method {
+ name: "__init__"
+ }
+ member_method {
+ name: "mro"
+ }
+ member_method {
+ name: "register"
+ argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator-kronecker.pbtxt b/tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator-kronecker.pbtxt
new file mode 100644
index 0000000000..c11d390829
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator-kronecker.pbtxt
@@ -0,0 +1,134 @@
+path: "tensorflow.linalg.LinearOperatorKronecker"
+tf_class {
+ is_instance: "<class \'tensorflow.python.ops.linalg.linear_operator_kronecker.LinearOperatorKronecker\'>"
+ is_instance: "<class \'tensorflow.python.ops.linalg.linear_operator.LinearOperator\'>"
+ is_instance: "<type \'object\'>"
+ member {
+ name: "batch_shape"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "domain_dimension"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "dtype"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "graph_parents"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "is_non_singular"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "is_positive_definite"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "is_self_adjoint"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "is_square"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "name"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "operators"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "range_dimension"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "shape"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "tensor_rank"
+ mtype: "<type \'property\'>"
+ }
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'operators\', \'is_non_singular\', \'is_self_adjoint\', \'is_positive_definite\', \'is_square\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "add_to_tensor"
+ argspec: "args=[\'self\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'add_to_tensor\'], "
+ }
+ member_method {
+ name: "assert_non_singular"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_non_singular\'], "
+ }
+ member_method {
+ name: "assert_positive_definite"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_positive_definite\'], "
+ }
+ member_method {
+ name: "assert_self_adjoint"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_self_adjoint\'], "
+ }
+ member_method {
+ name: "batch_shape_tensor"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'batch_shape_tensor\'], "
+ }
+ member_method {
+ name: "determinant"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'det\'], "
+ }
+ member_method {
+ name: "diag_part"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'diag_part\'], "
+ }
+ member_method {
+ name: "domain_dimension_tensor"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'domain_dimension_tensor\'], "
+ }
+ member_method {
+ name: "log_abs_determinant"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'log_abs_det\'], "
+ }
+ member_method {
+ name: "matmul"
+ argspec: "args=[\'self\', \'x\', \'adjoint\', \'adjoint_arg\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'matmul\'], "
+ }
+ member_method {
+ name: "matvec"
+ argspec: "args=[\'self\', \'x\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'matvec\'], "
+ }
+ member_method {
+ name: "range_dimension_tensor"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'range_dimension_tensor\'], "
+ }
+ member_method {
+ name: "shape_tensor"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'shape_tensor\'], "
+ }
+ member_method {
+ name: "solve"
+ argspec: "args=[\'self\', \'rhs\', \'adjoint\', \'adjoint_arg\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'solve\'], "
+ }
+ member_method {
+ name: "solvevec"
+ argspec: "args=[\'self\', \'rhs\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'solve\'], "
+ }
+ member_method {
+ name: "tensor_rank_tensor"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'tensor_rank_tensor\'], "
+ }
+ member_method {
+ name: "to_dense"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'to_dense\'], "
+ }
+ member_method {
+ name: "trace"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'trace\'], "
+ }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.linalg.pbtxt b/tensorflow/tools/api/golden/tensorflow.linalg.pbtxt
index 7a5c533872..00b9238543 100644
--- a/tensorflow/tools/api/golden/tensorflow.linalg.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.linalg.pbtxt
@@ -5,6 +5,10 @@ tf_module {
mtype: "<class \'abc.ABCMeta\'>"
}
member {
+ name: "LinearOperatorBlockDiag"
+ mtype: "<class \'abc.ABCMeta\'>"
+ }
+ member {
name: "LinearOperatorCirculant"
mtype: "<class \'abc.ABCMeta\'>"
}
@@ -33,6 +37,10 @@ tf_module {
mtype: "<class \'abc.ABCMeta\'>"
}
member {
+ name: "LinearOperatorKronecker"
+ mtype: "<class \'abc.ABCMeta\'>"
+ }
+ member {
name: "LinearOperatorLowRankUpdate"
mtype: "<class \'abc.ABCMeta\'>"
}
diff --git a/tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh b/tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh
index bfaa044c82..275abeb669 100755
--- a/tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh
+++ b/tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh
@@ -49,9 +49,13 @@ cd Python-3.6.1
make altinstall
ln -s /usr/local/bin/pip3.6 /usr/local/bin/pip3
+pip3 install --upgrade setuptools
+pip3 install --upgrade pip
+
pip3 install --upgrade virtualenv
set -e
+
# Install six.
pip3 install --upgrade absl-py
pip3 install --upgrade six==1.10.0
diff --git a/tensorflow/tools/pip_package/pip_smoke_test.py b/tensorflow/tools/pip_package/pip_smoke_test.py
index e2518f6cbf..b23dde2019 100644
--- a/tensorflow/tools/pip_package/pip_smoke_test.py
+++ b/tensorflow/tools/pip_package/pip_smoke_test.py
@@ -79,6 +79,16 @@ BLACKLIST = [
]
+def bazel_query(query_target):
+ """Run bazel query on target."""
+ try:
+ output = subprocess.check_output(
+ ["bazel", "query", "--keep_going", query_target])
+ except subprocess.CalledProcessError as e:
+ output = e.output
+ return output
+
+
def main():
"""This script runs the pip smoke test.
@@ -93,15 +103,13 @@ def main():
"""
# pip_package_dependencies_list is the list of included files in pip packages
- pip_package_dependencies = subprocess.check_output(
- ["bazel", "query", PIP_PACKAGE_QUERY_EXPRESSION])
+ pip_package_dependencies = bazel_query(PIP_PACKAGE_QUERY_EXPRESSION)
pip_package_dependencies_list = pip_package_dependencies.strip().split("\n")
print("Pip package superset size: %d" % len(pip_package_dependencies_list))
# tf_py_test_dependencies is the list of dependencies for all python
# tests in tensorflow
- tf_py_test_dependencies = subprocess.check_output(
- ["bazel", "query", PY_TEST_QUERY_EXPRESSION])
+ tf_py_test_dependencies = bazel_query(PY_TEST_QUERY_EXPRESSION)
tf_py_test_dependencies_list = tf_py_test_dependencies.strip().split("\n")
print("Pytest dependency subset size: %d" % len(tf_py_test_dependencies_list))
@@ -135,14 +143,15 @@ def main():
print("Affected Tests:")
rdep_query = ("rdeps(kind(py_test, //tensorflow/python/...), %s)" %
missing_dependency)
- affected_tests = subprocess.check_output(["bazel", "query", rdep_query])
+ affected_tests = bazel_query(rdep_query)
affected_tests_list = affected_tests.split("\n")[:-2]
print("\n".join(affected_tests_list))
- raise RuntimeError("""One or more dependencies are not in the pip package.
-Please either blacklist the dependencies in
-//tensorflow/tools/pip_package/pip_smoke_test.py
-or add them to //tensorflow/tools/pip_package/BUILD.""")
+ raise RuntimeError("""
+ One or more added test dependencies are not in the pip package.
+If these test dependencies need to be in TensorFlow pip package, please add them to //tensorflow/tools/pip_package/BUILD.
+Else either blacklist the dependencies in //tensorflow/tools/pip_package/pip_smoke_test.py
+or add no_pip tag to the test.""")
else:
print("TEST PASSED")
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index ef2e035ef9..16da59c5cf 100644
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -452,11 +452,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
tf_http_archive(
name = "llvm",
urls = [
- "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/3b2f0b2c7e66d226a9342be5163da4240e2951a8.tar.gz",
- "https://github.com/llvm-mirror/llvm/archive/3b2f0b2c7e66d226a9342be5163da4240e2951a8.tar.gz",
+ "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/068c967842b83d22007eee4515b57e8d9aaccb82.tar.gz",
+ "https://github.com/llvm-mirror/llvm/archive/068c967842b83d22007eee4515b57e8d9aaccb82.tar.gz",
],
- sha256 = "49bb3cbb7c8e9af091c5a743fa7ae749656994408438f38c9b6ac6a052fdce56",
- strip_prefix = "llvm-3b2f0b2c7e66d226a9342be5163da4240e2951a8",
+ sha256 = "4950432fb5cc68e5bf1f87a30b17dfdc69a5b93dac1e89d5274242d3ce7dae7c",
+ strip_prefix = "llvm-068c967842b83d22007eee4515b57e8d9aaccb82",
build_file = clean_dep("//third_party/llvm:llvm.BUILD"),
)