author     Michael Case <mikecase@google.com>  2018-04-05 07:34:25 -0700
committer  Michael Case <mikecase@google.com>  2018-04-05 07:34:25 -0700
commit     c9c17e3d277fffba647d76f1c3a1cfa4b3001761 (patch)
tree       1073e8354148c398d6abb87817e2d70e7eef582a
parent     c1c819b28476d72c1f086fc4e78ff7f013c225ce (diff)
parent     361a13cf0c2b65d26f6e2b5b68875adfcea98dd0 (diff)
Merge commit for internal changes
-rw-r--r--  configure.py  81
-rw-r--r--  tensorflow/c/c_api.h  8
-rw-r--r--  tensorflow/c/c_api_test.cc  2
-rw-r--r--  tensorflow/c/eager/c_api.cc  8
-rw-r--r--  tensorflow/cc/saved_model/loader_test.cc  15
-rw-r--r--  tensorflow/cc/tutorials/example_trainer.cc  6
-rw-r--r--  tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc  1
-rw-r--r--  tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc  30
-rw-r--r--  tensorflow/compiler/jit/mark_for_compilation_pass_test.cc  23
-rw-r--r--  tensorflow/compiler/jit/xla_cpu_device.cc  4
-rw-r--r--  tensorflow/compiler/jit/xla_gpu_device.cc  5
-rw-r--r--  tensorflow/compiler/tests/BUILD  1
-rw-r--r--  tensorflow/compiler/tests/binary_ops_test.py  43
-rw-r--r--  tensorflow/compiler/tests/build_defs.bzl  3
-rw-r--r--  tensorflow/compiler/tests/cholesky_op_test.py  7
-rw-r--r--  tensorflow/compiler/tests/matrix_triangular_solve_op_test.py  8
-rw-r--r--  tensorflow/compiler/tests/spacetobatch_op_test.py  7
-rw-r--r--  tensorflow/compiler/tests/unary_ops_test.py  5
-rw-r--r--  tensorflow/compiler/tests/variable_ops_test.py  12
-rw-r--r--  tensorflow/compiler/tf2xla/xla_op_registry.h  10
-rw-r--r--  tensorflow/compiler/xla/client/client.cc  34
-rw-r--r--  tensorflow/compiler/xla/client/lib/BUILD  2
-rw-r--r--  tensorflow/compiler/xla/client/lib/arithmetic.cc  95
-rw-r--r--  tensorflow/compiler/xla/client/lib/arithmetic.h  44
-rw-r--r--  tensorflow/compiler/xla/client/xla_client/xla_builder.cc  337
-rw-r--r--  tensorflow/compiler/xla/client/xla_client/xla_builder.h  18
-rw-r--r--  tensorflow/compiler/xla/python/local_computation_builder.cc  11
-rw-r--r--  tensorflow/compiler/xla/python/local_computation_builder.h  7
-rw-r--r--  tensorflow/compiler/xla/python/local_computation_builder.i  164
-rw-r--r--  tensorflow/compiler/xla/python/xla_client.py  14
-rw-r--r--  tensorflow/compiler/xla/python/xla_client_test.py  11
-rw-r--r--  tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc  1
-rw-r--r--  tensorflow/compiler/xla/service/cpu/simple_orc_jit.h  1
-rw-r--r--  tensorflow/compiler/xla/service/executable.cc  1
-rw-r--r--  tensorflow/compiler/xla/service/hlo_computation.cc  18
-rw-r--r--  tensorflow/compiler/xla/service/hlo_evaluator.cc  46
-rw-r--r--  tensorflow/compiler/xla/service/hlo_evaluator_test.cc  75
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction.cc  8
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction.h  7
-rw-r--r--  tensorflow/compiler/xla/service/hlo_module_group_metadata.cc  21
-rw-r--r--  tensorflow/compiler/xla/service/hlo_sharding.h  4
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc  3
-rw-r--r--  tensorflow/compiler/xla/service/reshape_mover.cc  306
-rw-r--r--  tensorflow/compiler/xla/service/reshape_mover_test.cc  80
-rw-r--r--  tensorflow/compiler/xla/service/service.cc  229
-rw-r--r--  tensorflow/compiler/xla/service/service.h  22
-rw-r--r--  tensorflow/compiler/xla/service/user_computation.cc  1
-rw-r--r--  tensorflow/compiler/xla/tests/BUILD  15
-rw-r--r--  tensorflow/compiler/xla/tests/client_library_test_base.cc  6
-rw-r--r--  tensorflow/compiler/xla/tests/client_library_test_base.h  13
-rw-r--r--  tensorflow/compiler/xla/tests/client_test.cc  10
-rw-r--r--  tensorflow/compiler/xla/tests/concat_test.cc  145
-rw-r--r--  tensorflow/compiler/xla/tests/convert_test.cc  37
-rw-r--r--  tensorflow/compiler/xla/tests/dynamic_ops_test.cc  84
-rw-r--r--  tensorflow/compiler/xla/tests/reduce_test.cc  69
-rw-r--r--  tensorflow/compiler/xla/tests/reduce_window_test.cc  42
-rw-r--r--  tensorflow/compiler/xla/tests/slice_test.cc  30
-rw-r--r--  tensorflow/compiler/xla/tests/test_utils.cc  25
-rw-r--r--  tensorflow/compiler/xla/tests/while_test.cc  368
-rw-r--r--  tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc  3
-rw-r--r--  tensorflow/contrib/BUILD  8
-rw-r--r--  tensorflow/contrib/__init__.py  7
-rw-r--r--  tensorflow/contrib/android/asset_manager_filesystem.cc  11
-rw-r--r--  tensorflow/contrib/android/asset_manager_filesystem.h  3
-rw-r--r--  tensorflow/contrib/autograph/converters/BUILD  3
-rw-r--r--  tensorflow/contrib/autograph/impl/BUILD  3
-rw-r--r--  tensorflow/contrib/autograph/impl/config.py  17
-rw-r--r--  tensorflow/contrib/autograph/operators/BUILD  25
-rw-r--r--  tensorflow/contrib/autograph/operators/__init__.py  24
-rw-r--r--  tensorflow/contrib/autograph/pyct/BUILD  1
-rw-r--r--  tensorflow/contrib/autograph/pyct/static_analysis/BUILD  2
-rw-r--r--  tensorflow/contrib/autograph/utils/BUILD  2
-rw-r--r--  tensorflow/contrib/boosted_trees/lib/utils/batch_features.h  6
-rwxr-xr-x  tensorflow/contrib/cmake/tf_python.cmake  2
-rw-r--r--  tensorflow/contrib/data/BUILD  8
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py  93
-rw-r--r--  tensorflow/contrib/data/python/ops/prefetching_ops.py  90
-rw-r--r--  tensorflow/contrib/distribute/README.md  2
-rw-r--r--  tensorflow/contrib/distribute/python/estimator_integration_test.py  2
-rw-r--r--  tensorflow/contrib/distribute/python/examples/simple_estimator_example.py  2
-rw-r--r--  tensorflow/contrib/distribute/python/examples/simple_tfkeras_example.py  2
-rw-r--r--  tensorflow/contrib/distribute/python/monitor.py  11
-rw-r--r--  tensorflow/contrib/distributions/BUILD  8
-rw-r--r--  tensorflow/contrib/distributions/python/kernel_tests/batch_reshape_test.py  37
-rw-r--r--  tensorflow/contrib/distributions/python/ops/batch_reshape.py  192
-rw-r--r--  tensorflow/contrib/eager/python/checkpointable_utils.py  113
-rw-r--r--  tensorflow/contrib/eager/python/checkpointable_utils_test.py  82
-rw-r--r--  tensorflow/contrib/eager/python/datasets.py  24
-rw-r--r--  tensorflow/contrib/eager/python/examples/linear_regression/BUILD  1
-rw-r--r--  tensorflow/contrib/estimator/BUILD  3
-rw-r--r--  tensorflow/contrib/estimator/python/estimator/dnn_linear_combined_test.py  9
-rw-r--r--  tensorflow/contrib/estimator/python/estimator/dnn_test.py  5
-rw-r--r--  tensorflow/contrib/estimator/python/estimator/head.py  12
-rw-r--r--  tensorflow/contrib/estimator/python/estimator/head_test.py  4
-rw-r--r--  tensorflow/contrib/estimator/python/estimator/linear_test.py  5
-rw-r--r--  tensorflow/contrib/estimator/python/estimator/multi_head_test.py  14
-rw-r--r--  tensorflow/contrib/framework/python/ops/arg_scope_test.py  24
-rw-r--r--  tensorflow/contrib/gan/BUILD  1
-rw-r--r--  tensorflow/contrib/gan/python/eval/python/summaries_impl.py  64
-rw-r--r--  tensorflow/contrib/gan/python/eval/python/summaries_test.py  20
-rw-r--r--  tensorflow/contrib/kfac/examples/BUILD  24
-rw-r--r--  tensorflow/contrib/kfac/examples/convnet.py  315
-rw-r--r--  tensorflow/contrib/kfac/examples/convnet_mnist_distributed_main.py  62
-rw-r--r--  tensorflow/contrib/kfac/examples/convnet_mnist_multi_tower_main.py  48
-rw-r--r--  tensorflow/contrib/kfac/examples/convnet_mnist_single_main.py (renamed from tensorflow/contrib/kfac/examples/convnet_mnist_main.py)  32
-rw-r--r--  tensorflow/contrib/kfac/examples/tests/convnet_test.py  17
-rw-r--r--  tensorflow/contrib/kfac/python/kernel_tests/BUILD  1
-rw-r--r--  tensorflow/contrib/kfac/python/ops/layer_collection.py  23
-rw-r--r--  tensorflow/contrib/labeled_tensor/BUILD  1
-rw-r--r--  tensorflow/contrib/layers/BUILD  2
-rw-r--r--  tensorflow/contrib/layers/__init__.py  2
-rw-r--r--  tensorflow/contrib/layers/python/layers/normalization.py  195
-rw-r--r--  tensorflow/contrib/layers/python/layers/normalization_test.py  226
-rw-r--r--  tensorflow/contrib/learn/BUILD  5
-rw-r--r--  tensorflow/contrib/learn/python/learn/estimators/run_config.py  2
-rw-r--r--  tensorflow/contrib/lite/BUILD  1
-rw-r--r--  tensorflow/contrib/lite/arena_planner.h  2
-rw-r--r--  tensorflow/contrib/lite/builtin_op_data.h  7
-rw-r--r--  tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java  13
-rw-r--r--  tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java  12
-rw-r--r--  tensorflow/contrib/lite/kernels/cast.cc  9
-rw-r--r--  tensorflow/contrib/lite/kernels/l2norm.cc  22
-rw-r--r--  tensorflow/contrib/lite/kernels/l2norm_test.cc  49
-rw-r--r--  tensorflow/contrib/lite/kernels/maximum.cc  18
-rw-r--r--  tensorflow/contrib/lite/kernels/maximum_test.cc  14
-rw-r--r--  tensorflow/contrib/lite/kernels/strided_slice.cc  5
-rw-r--r--  tensorflow/contrib/lite/kernels/strided_slice_test.cc  113
-rw-r--r--  tensorflow/contrib/lite/model.cc  70
-rw-r--r--  tensorflow/contrib/lite/nnapi_delegate.cc  2
-rw-r--r--  tensorflow/contrib/lite/python/BUILD  1
-rw-r--r--  tensorflow/contrib/lite/schema/schema.fbs  2
-rwxr-xr-x  tensorflow/contrib/lite/schema/schema_generated.h  38
-rw-r--r--  tensorflow/contrib/lite/toco/BUILD  5
-rw-r--r--  tensorflow/contrib/lite/toco/args.h  1
-rw-r--r--  tensorflow/contrib/lite/toco/export_tensorflow.cc  116
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc  5
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h  5
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc  7
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/merge_reshape_into_preceding_transpose.cc  190
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc  164
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc  20
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/quantize.cc  49
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/read_fake_quant_min_max.cc  7
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/reorder_activation_functions.cc  137
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/reorder_elementwise_unary.cc  153
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/reorder_reshape_transpose.cc  248
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_random_uniform.cc  116
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc  14
-rw-r--r--  tensorflow/contrib/lite/toco/import_tensorflow.cc  23
-rw-r--r--  tensorflow/contrib/lite/toco/model.h  17
-rw-r--r--  tensorflow/contrib/lite/toco/model_cmdline_flags.cc  24
-rw-r--r--  tensorflow/contrib/lite/toco/model_flags.proto  4
-rw-r--r--  tensorflow/contrib/lite/toco/tflite/operator.cc  26
-rw-r--r--  tensorflow/contrib/lite/toco/tflite/operator_test.cc  2
-rw-r--r--  tensorflow/contrib/lite/toco/toco_tooling.cc  5
-rw-r--r--  tensorflow/contrib/lite/toco/tooling_util.cc  43
-rw-r--r--  tensorflow/contrib/lite/toco/tooling_util.h  13
-rw-r--r--  tensorflow/contrib/lookup/BUILD  1
-rw-r--r--  tensorflow/contrib/makefile/proto_text_cc_files.txt  1
-rw-r--r--  tensorflow/contrib/nccl/BUILD  6
-rw-r--r--  tensorflow/contrib/nccl/kernels/nccl_manager.h  2
-rw-r--r--  tensorflow/contrib/nccl/kernels/nccl_ops.cc  2
-rw-r--r--  tensorflow/contrib/nccl/kernels/nccl_rewrite.cc  3
-rw-r--r--  tensorflow/contrib/optimizer_v2/optimizer_v2.py  4
-rw-r--r--  tensorflow/contrib/quantize/python/quantize.py  4
-rw-r--r--  tensorflow/contrib/quantize/python/quantize_test.py  21
-rw-r--r--  tensorflow/contrib/remote_fused_graph/pylib/BUILD  1
-rw-r--r--  tensorflow/contrib/saved_model/BUILD  1
-rw-r--r--  tensorflow/contrib/session_bundle/BUILD  1
-rw-r--r--  tensorflow/contrib/slim/python/slim/data/BUILD  1
-rw-r--r--  tensorflow/contrib/stat_summarizer/BUILD  1
-rw-r--r--  tensorflow/contrib/tensor_forest/BUILD  1
-rw-r--r--  tensorflow/contrib/tensorboard/BUILD  2
-rw-r--r--  tensorflow/contrib/tensorboard/db/summary_file_writer_test.cc  3
-rw-r--r--  tensorflow/contrib/timeseries/examples/BUILD  5
-rw-r--r--  tensorflow/contrib/timeseries/python/timeseries/BUILD  20
-rw-r--r--  tensorflow/contrib/timeseries/python/timeseries/estimators.py  8
-rw-r--r--  tensorflow/contrib/timeseries/python/timeseries/estimators_test.py  7
-rw-r--r--  tensorflow/contrib/timeseries/python/timeseries/head.py  90
-rw-r--r--  tensorflow/contrib/timeseries/python/timeseries/head_test.py  86
-rw-r--r--  tensorflow/contrib/timeseries/python/timeseries/state_space_models/BUILD  1
-rw-r--r--  tensorflow/contrib/tpu/BUILD  14
-rw-r--r--  tensorflow/contrib/tpu/__init__.py  1
-rw-r--r--  tensorflow/contrib/tpu/profiler/tf_op_stats.proto  4
-rw-r--r--  tensorflow/contrib/tpu/python/tpu/bfloat16.py  77
-rw-r--r--  tensorflow/contrib/tpu/python/tpu/bfloat16_test.py  50
-rw-r--r--  tensorflow/contrib/tpu/python/tpu/tpu_estimator.py  3
-rw-r--r--  tensorflow/contrib/tpu/python/tpu/tpu_system_metadata.py  3
-rw-r--r--  tensorflow/contrib/training/python/training/batch_sequences_with_states_test.py  2
-rw-r--r--  tensorflow/contrib/training/python/training/sequence_queueing_state_saver.py  5
-rw-r--r--  tensorflow/contrib/util/loader.py  7
-rw-r--r--  tensorflow/core/BUILD  1
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_For.pbtxt  29
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_If.pbtxt  40
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_While.pbtxt  33
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_For.pbtxt  1
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_If.pbtxt  1
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_ScatterAdd.pbtxt  4
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_While.pbtxt  1
-rw-r--r--  tensorflow/core/common_runtime/direct_session_test.cc  25
-rw-r--r--  tensorflow/core/common_runtime/executor.cc  40
-rw-r--r--  tensorflow/core/common_runtime/function_test.cc  9
-rw-r--r--  tensorflow/core/common_runtime/function_threadpool_test.cc  3
-rw-r--r--  tensorflow/core/common_runtime/placer.cc  4
-rw-r--r--  tensorflow/core/common_runtime/placer_test.cc  110
-rw-r--r--  tensorflow/core/common_runtime/process_function_library_runtime.h  5
-rw-r--r--  tensorflow/core/common_runtime/process_function_library_runtime_test.cc  3
-rw-r--r--  tensorflow/core/common_runtime/session_test.cc  16
-rw-r--r--  tensorflow/core/common_runtime/shape_refiner.cc  5
-rw-r--r--  tensorflow/core/common_runtime/shape_refiner_test.cc  9
-rw-r--r--  tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc  4
-rw-r--r--  tensorflow/core/distributed_runtime/rpc/grpc_session.cc  3
-rw-r--r--  tensorflow/core/framework/allocator.cc  38
-rw-r--r--  tensorflow/core/framework/attr_value_util.cc  26
-rw-r--r--  tensorflow/core/framework/common_shape_fns.cc  4
-rw-r--r--  tensorflow/core/framework/common_shape_fns_test.cc  18
-rw-r--r--  tensorflow/core/framework/dataset.h  6
-rw-r--r--  tensorflow/core/framework/function.cc  5
-rw-r--r--  tensorflow/core/framework/function_test.cc  2
-rw-r--r--  tensorflow/core/framework/graph_def_util.cc  3
-rw-r--r--  tensorflow/core/framework/node_def_builder_test.cc  5
-rw-r--r--  tensorflow/core/framework/node_def_util.cc  7
-rw-r--r--  tensorflow/core/framework/node_def_util_test.cc  5
-rw-r--r--  tensorflow/core/framework/op.cc  3
-rw-r--r--  tensorflow/core/framework/op_compatibility_test.cc  7
-rw-r--r--  tensorflow/core/framework/op_def.proto  6
-rw-r--r--  tensorflow/core/framework/op_def_builder.cc  49
-rw-r--r--  tensorflow/core/framework/op_def_util.cc  12
-rw-r--r--  tensorflow/core/framework/op_def_util_test.cc  2
-rw-r--r--  tensorflow/core/framework/op_gen_lib.cc  21
-rw-r--r--  tensorflow/core/framework/op_kernel.cc  70
-rw-r--r--  tensorflow/core/framework/op_kernel.h  42
-rw-r--r--  tensorflow/core/framework/op_kernel_test.cc  19
-rw-r--r--  tensorflow/core/framework/resource_mgr_test.cc  3
-rw-r--r--  tensorflow/core/framework/shape_inference_test.cc  209
-rw-r--r--  tensorflow/core/framework/shape_inference_testutil.cc  8
-rw-r--r--  tensorflow/core/framework/shape_inference_testutil.h  23
-rw-r--r--  tensorflow/core/framework/shape_inference_testutil_test.cc  14
-rw-r--r--  tensorflow/core/framework/types.cc  2
-rw-r--r--  tensorflow/core/framework/types_test.cc  6
-rw-r--r--  tensorflow/core/framework/variant_op_copy_test.cc  10
-rw-r--r--  tensorflow/core/framework/variant_op_registry_test.cc  11
-rw-r--r--  tensorflow/core/graph/graph.cc  5
-rw-r--r--  tensorflow/core/graph/graph_constructor.cc  15
-rw-r--r--  tensorflow/core/graph/graph_constructor_test.cc  4
-rw-r--r--  tensorflow/core/graph/graph_partition.cc  3
-rw-r--r--  tensorflow/core/graph/graph_partition_test.cc  5
-rw-r--r--  tensorflow/core/graph/graph_test.cc  3
-rw-r--r--  tensorflow/core/graph/quantize_training.cc  10
-rw-r--r--  tensorflow/core/graph/quantize_training_test.cc  5
-rw-r--r--  tensorflow/core/graph/subgraph_test.cc  4
-rw-r--r--  tensorflow/core/graph/tensor_id.cc  3
-rw-r--r--  tensorflow/core/graph/validate_test.cc  7
-rw-r--r--  tensorflow/core/grappler/clusters/cluster.cc  4
-rw-r--r--  tensorflow/core/grappler/clusters/utils.cc  2
-rw-r--r--  tensorflow/core/grappler/clusters/virtual_cluster.cc  2
-rw-r--r--  tensorflow/core/grappler/costs/graph_properties.cc  1
-rw-r--r--  tensorflow/core/grappler/costs/op_level_cost_estimator.cc  39
-rw-r--r--  tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc  156
-rw-r--r--  tensorflow/core/grappler/costs/virtual_scheduler.cc  4
-rw-r--r--  tensorflow/core/grappler/op_types.cc  19
-rw-r--r--  tensorflow/core/grappler/op_types.h  2
-rw-r--r--  tensorflow/core/grappler/optimizers/BUILD  9
-rw-r--r--  tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc  473
-rw-r--r--  tensorflow/core/grappler/optimizers/arithmetic_optimizer.h  20
-rw-r--r--  tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc  593
-rw-r--r--  tensorflow/core/grappler/optimizers/constant_folding.cc  49
-rw-r--r--  tensorflow/core/grappler/optimizers/constant_folding_test.cc  144
-rw-r--r--  tensorflow/core/grappler/optimizers/debug_stripper.cc  5
-rw-r--r--  tensorflow/core/grappler/optimizers/debug_stripper_test.cc  116
-rw-r--r--  tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc  44
-rw-r--r--  tensorflow/core/grappler/optimizers/function_optimizer.cc  240
-rw-r--r--  tensorflow/core/grappler/optimizers/function_optimizer_test.cc  137
-rw-r--r--  tensorflow/core/grappler/optimizers/graph_optimizer_stage.h  71
-rw-r--r--  tensorflow/core/grappler/optimizers/graph_optimizer_stage_test.cc  12
-rw-r--r--  tensorflow/core/grappler/optimizers/layout_optimizer.cc  4
-rw-r--r--  tensorflow/core/grappler/optimizers/meta_optimizer.cc  86
-rw-r--r--  tensorflow/core/grappler/optimizers/meta_optimizer_test.cc  14
-rw-r--r--  tensorflow/core/grappler/optimizers/symbolic_shapes.h  4
-rw-r--r--  tensorflow/core/grappler/utils.cc  8
-rw-r--r--  tensorflow/core/grappler/utils.h  3
-rw-r--r--  tensorflow/core/grappler/utils/BUILD  25
-rw-r--r--  tensorflow/core/grappler/utils/colocation.cc  122
-rw-r--r--  tensorflow/core/grappler/utils/colocation.h  39
-rw-r--r--  tensorflow/core/grappler/utils/colocation_test.cc  183
-rw-r--r--  tensorflow/core/grappler/utils/grappler_test.h  9
-rw-r--r--  tensorflow/core/kernels/assign_op.h  3
-rw-r--r--  tensorflow/core/kernels/data_format_ops.cc  40
-rw-r--r--  tensorflow/core/kernels/data_format_ops.h  18
-rw-r--r--  tensorflow/core/kernels/functional_ops.cc  189
-rw-r--r--  tensorflow/core/kernels/lookup_table_op.cc  32
-rw-r--r--  tensorflow/core/kernels/lookup_table_op.h  17
-rw-r--r--  tensorflow/core/kernels/queue_op.h  1
-rw-r--r--  tensorflow/core/kernels/resource_variable_ops.cc  34
-rw-r--r--  tensorflow/core/kernels/sparse_cross_op.cc  2
-rw-r--r--  tensorflow/core/lib/core/stringpiece.cc  2
-rw-r--r--  tensorflow/core/lib/core/stringpiece.h  2
-rw-r--r--  tensorflow/core/lib/io/format.cc  7
-rw-r--r--  tensorflow/core/lib/strings/numbers.cc  5
-rw-r--r--  tensorflow/core/lib/strings/ordered_code_test.cc  3
-rw-r--r--  tensorflow/core/lib/strings/scanner.h  5
-rw-r--r--  tensorflow/core/lib/wav/wav_io_test.cc  3
-rw-r--r--  tensorflow/core/ops/compat/ops_history.v1.pbtxt  96
-rw-r--r--  tensorflow/core/ops/functional_ops.cc  41
-rw-r--r--  tensorflow/core/ops/math_grad_test.cc  3
-rw-r--r--  tensorflow/core/ops/math_ops_test.cc  14
-rw-r--r--  tensorflow/core/ops/ops.pbtxt  96
-rw-r--r--  tensorflow/core/platform/abi.cc  8
-rw-r--r--  tensorflow/core/platform/cloud/BUILD  2
-rw-r--r--  tensorflow/core/platform/cloud/gcs_file_system.cc  9
-rw-r--r--  tensorflow/core/platform/cloud/gcs_file_system_test.cc  17
-rw-r--r--  tensorflow/core/platform/cloud/retrying_file_system_test.cc  29
-rw-r--r--  tensorflow/core/platform/cloud/retrying_utils_test.cc  8
-rw-r--r--  tensorflow/core/platform/default/tracing_impl.h  1
-rw-r--r--  tensorflow/core/platform/denormal.cc  3
-rw-r--r--  tensorflow/core/platform/file_system.cc  93
-rw-r--r--  tensorflow/core/platform/file_system.h  4
-rw-r--r--  tensorflow/core/platform/file_system_helper.cc  126
-rw-r--r--  tensorflow/core/platform/file_system_helper.h  51
-rw-r--r--  tensorflow/core/platform/hadoop/hadoop_file_system.cc  6
-rw-r--r--  tensorflow/core/platform/hadoop/hadoop_file_system.h  3
-rw-r--r--  tensorflow/core/platform/hadoop/hadoop_file_system_test.cc  3
-rw-r--r--  tensorflow/core/platform/mem.h  2
-rw-r--r--  tensorflow/core/platform/null_file_system.h  6
-rw-r--r--  tensorflow/core/platform/posix/port.cc  2
-rw-r--r--  tensorflow/core/platform/posix/posix_file_system.cc  6
-rw-r--r--  tensorflow/core/platform/posix/posix_file_system.h  3
-rw-r--r--  tensorflow/core/platform/s3/s3_file_system.cc  6
-rw-r--r--  tensorflow/core/platform/s3/s3_file_system.h  3
-rw-r--r--  tensorflow/core/platform/tracing.h  4
-rw-r--r--  tensorflow/core/platform/windows/port.cc  2
-rw-r--r--  tensorflow/core/platform/windows/windows_file_system.cc  4
-rw-r--r--  tensorflow/core/profiler/internal/advisor/tfprof_advisor_test.cc  14
-rw-r--r--  tensorflow/core/protobuf/rewriter_config.proto  12
-rw-r--r--  tensorflow/core/util/command_line_flags.cc  20
-rw-r--r--  tensorflow/core/util/device_name_utils_test.cc  2
-rw-r--r--  tensorflow/core/util/equal_graph_def.cc  5
-rw-r--r--  tensorflow/core/util/memmapped_file_system.cc  9
-rw-r--r--  tensorflow/core/util/memmapped_file_system.h  2
-rw-r--r--  tensorflow/core/util/reporter_test.cc  2
-rw-r--r--  tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc  13
-rw-r--r--  tensorflow/core/util/tensor_slice_reader_test.cc  3
-rw-r--r--  tensorflow/core/util/tensor_slice_writer_test.cc  9
-rw-r--r--  tensorflow/docs_src/extend/index.md  5
-rw-r--r--  tensorflow/docs_src/mobile/tflite/devguide.md  2
-rw-r--r--  tensorflow/docs_src/programmers_guide/eager.md  453
-rw-r--r--  tensorflow/docs_src/programmers_guide/index.md  1
-rw-r--r--  tensorflow/examples/label_image/main.cc  7
-rw-r--r--  tensorflow/examples/multibox_detector/main.cc  7
-rw-r--r--  tensorflow/go/op/wrappers.go  250
-rw-r--r--  tensorflow/java/maven/libtensorflow/pom.xml  2
-rw-r--r--  tensorflow/java/maven/libtensorflow_jni/pom.xml  2
-rw-r--r--  tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml  2
-rw-r--r--  tensorflow/java/maven/pom.xml  2
-rw-r--r--  tensorflow/java/maven/proto/pom.xml  2
-rw-r--r--  tensorflow/java/maven/tensorflow/pom.xml  2
-rw-r--r--  tensorflow/python/BUILD  117
-rw-r--r--  tensorflow/python/client/session.py  91
-rw-r--r--  tensorflow/python/client/session_list_devices_test.py  19
-rw-r--r--  tensorflow/python/client/tf_session.i  33
-rw-r--r--  tensorflow/python/client/tf_session_helper.cc  9
-rw-r--r--  tensorflow/python/client/tf_session_helper.h  16
-rw-r--r--  tensorflow/python/data/kernel_tests/iterator_ops_test.py  9
-rw-r--r--  tensorflow/python/data/ops/iterator_ops.py  38
-rw-r--r--  tensorflow/python/debug/BUILD  2
-rw-r--r--  tensorflow/python/eager/backprop.py  11
-rw-r--r--  tensorflow/python/eager/benchmarks_test.py  3
-rw-r--r--  tensorflow/python/eager/context.py  89
-rw-r--r--  tensorflow/python/eager/function.py  38
-rw-r--r--  tensorflow/python/eager/imperative_grad.py  6
-rw-r--r--  tensorflow/python/eager/python_eager_op_gen.cc  21
-rw-r--r--  tensorflow/python/eager/pywrap_tfe_src.cc  9
-rw-r--r--  tensorflow/python/estimator/BUILD  1
-rw-r--r--  tensorflow/python/estimator/estimator.py  2
-rw-r--r--  tensorflow/python/estimator/run_config.py  18
-rw-r--r--  tensorflow/python/feature_column/BUILD  1
-rw-r--r--  tensorflow/python/feature_column/feature_column.py  344
-rw-r--r--  tensorflow/python/feature_column/feature_column_test.py  1322
-rw-r--r--  tensorflow/python/framework/c_api_util.py  6
-rw-r--r--  tensorflow/python/framework/errors_impl.py  3
-rw-r--r--  tensorflow/python/framework/function.py  34
-rw-r--r--  tensorflow/python/framework/importer.py  5
-rw-r--r--  tensorflow/python/framework/importer_test.py  17
-rw-r--r--  tensorflow/python/framework/load_library.py  7
-rw-r--r--  tensorflow/python/framework/ops.py  141
-rw-r--r--  tensorflow/python/framework/ops_test.py  2
-rw-r--r--  tensorflow/python/framework/smart_cond.py  6
-rw-r--r--  tensorflow/python/framework/versions.py  2
-rw-r--r--  tensorflow/python/grappler/constant_folding_test.py  69
-rw-r--r--  tensorflow/python/grappler/item.py  4
-rw-r--r--  tensorflow/python/grappler/tf_optimizer_test.py  47
-rwxr-xr-x  tensorflow/python/keras/BUILD  5
-rw-r--r--  tensorflow/python/keras/_impl/keras/activations.py  16
-rw-r--r--  tensorflow/python/keras/_impl/keras/applications/imagenet_utils.py  6
-rw-r--r--  tensorflow/python/keras/_impl/keras/backend.py  18
-rw-r--r--  tensorflow/python/keras/_impl/keras/constraints.py  13
-rw-r--r--  tensorflow/python/keras/_impl/keras/engine/topology_test.py  2
-rw-r--r--  tensorflow/python/keras/_impl/keras/engine/training.py  3
-rw-r--r--  tensorflow/python/keras/_impl/keras/engine/training_eager.py  70
-rw-r--r--  tensorflow/python/keras/_impl/keras/engine/training_eager_test.py  104
-rw-r--r--  tensorflow/python/keras/_impl/keras/engine/training_test.py  115
-rw-r--r--  tensorflow/python/keras/_impl/keras/engine/training_utils.py  43
-rw-r--r--  tensorflow/python/keras/_impl/keras/estimator.py  4
-rw-r--r--  tensorflow/python/keras/_impl/keras/layers/advanced_activations.py  6
-rw-r--r--  tensorflow/python/keras/_impl/keras/layers/convolutional_recurrent.py  14
-rw-r--r--  tensorflow/python/keras/_impl/keras/layers/core.py  19
-rw-r--r--  tensorflow/python/keras/_impl/keras/layers/core_test.py  4
-rw-r--r--  tensorflow/python/keras/_impl/keras/layers/embeddings.py  8
-rw-r--r--  tensorflow/python/keras/_impl/keras/layers/merge.py  55
-rw-r--r--  tensorflow/python/keras/_impl/keras/layers/noise.py  12
-rw-r--r--  tensorflow/python/keras/_impl/keras/layers/recurrent.py  36
-rw-r--r--  tensorflow/python/keras/_impl/keras/layers/recurrent_test.py  15
-rw-r--r--  tensorflow/python/keras/_impl/keras/layers/wrappers.py  7
-rw-r--r--  tensorflow/python/keras/_impl/keras/losses.py  38
-rw-r--r--  tensorflow/python/keras/_impl/keras/metrics.py  28
-rw-r--r--  tensorflow/python/keras/_impl/keras/metrics_test.py  19
-rw-r--r--  tensorflow/python/keras/_impl/keras/optimizers.py  128
-rw-r--r--  tensorflow/python/keras/_impl/keras/regularizers.py  5
-rw-r--r--  tensorflow/python/keras/_impl/keras/utils/layer_utils.py  5
-rw-r--r--  tensorflow/python/kernel_tests/BUILD  10
-rw-r--r--  tensorflow/python/kernel_tests/array_ops_test.py  56
-rw-r--r--  tensorflow/python/kernel_tests/conv_ops_3d_test.py  3
-rw-r--r--  tensorflow/python/kernel_tests/cwise_ops_test.py  42
-rw-r--r--  tensorflow/python/kernel_tests/distributions/uniform_test.py  16
-rw-r--r--  tensorflow/python/kernel_tests/functional_ops_test.py  279
-rw-r--r--  tensorflow/python/kernel_tests/large_concat_op_test.py  2
-rw-r--r--  tensorflow/python/kernel_tests/linalg/linear_operator_util_test.py  136
-rw-r--r--  tensorflow/python/kernel_tests/metrics_test.py  63
-rw-r--r--  tensorflow/python/kernel_tests/resource_variable_ops_test.py  6
-rw-r--r--  tensorflow/python/layers/base.py  3
-rw-r--r--  tensorflow/python/lib/core/py_exception_registry.cc  50
-rw-r--r--  tensorflow/python/lib/core/py_exception_registry.h  73
-rw-r--r--  tensorflow/python/lib/core/py_exception_registry.i  28
-rw-r--r--  tensorflow/python/lib/core/py_seq_tensor.cc  7
-rw-r--r--  tensorflow/python/lib/io/tf_record.py  3
-rw-r--r--  tensorflow/python/ops/array_grad.py  7
-rw-r--r--  tensorflow/python/ops/array_ops.py  5
-rw-r--r--  tensorflow/python/ops/distributions/uniform.py  3
-rw-r--r--  tensorflow/python/ops/functional_ops.py  270
-rw-r--r--  tensorflow/python/ops/linalg/linear_operator_util.py  85
-rw-r--r--  tensorflow/python/ops/math_ops.py  84
-rw-r--r--  tensorflow/python/ops/math_ops_test.py  4
-rw-r--r--  tensorflow/python/ops/metrics_impl.py  77
-rw-r--r--  tensorflow/python/ops/nn_ops.py  4
-rw-r--r--  tensorflow/python/ops/nn_test.py  36
-rw-r--r--  tensorflow/python/ops/resource_variable_ops.py  4
-rw-r--r--  tensorflow/python/ops/state_ops.py  52
-rw-r--r--  tensorflow/python/ops/variable_scope.py  2
-rw-r--r--  tensorflow/python/platform/base.i  22
-rw-r--r--  tensorflow/python/pywrap_tfe.i  2
-rw-r--r--  tensorflow/python/tensorflow.i  2
-rw-r--r--  tensorflow/python/tools/optimize_for_inference.py  12
-rw-r--r--  tensorflow/python/tools/optimize_for_inference_lib.py  9
-rw-r--r--  tensorflow/python/training/distribute.py  11
-rw-r--r--  tensorflow/python/training/input.py  3
-rw-r--r--  tensorflow/stream_executor/BUILD  1
-rw-r--r--  tensorflow/stream_executor/cuda/cuda_dnn.cc  7
-rw-r--r--  tensorflow/stream_executor/cuda/cudnn_version.h  7
-rw-r--r--  tensorflow/tensorflow.bzl  41
-rw-r--r--  tensorflow/tools/api/generator/BUILD  1
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.estimator.-run-config.pbtxt  10
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.math.pbtxt  7
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.pbtxt  4
-rw-r--r--  tensorflow/tools/api/tests/api_compatibility_test.py  10
-rw-r--r--  tensorflow/tools/ci_build/windows/cpu/pip/build_tf_windows.sh  4
-rw-r--r--  tensorflow/tools/def_file_filter/BUILD  9
-rw-r--r--  tensorflow/tools/def_file_filter/BUILD.tpl  15
-rw-r--r--  tensorflow/tools/def_file_filter/def_file_filter.py.tpl  168
-rw-r--r--  tensorflow/tools/def_file_filter/def_file_filter_configure.bzl  56
-rw-r--r--  tensorflow/tools/graph_transforms/backports_test.cc  3
-rw-r--r--  tensorflow/tools/graph_transforms/fold_constants_lib.cc  4
-rw-r--r--  tensorflow/tools/graph_transforms/fold_constants_test.cc  49
-rw-r--r--  tensorflow/tools/graph_transforms/freeze_requantization_ranges.cc  9
-rw-r--r--  tensorflow/tools/graph_transforms/insert_logging.cc  3
-rw-r--r--  tensorflow/tools/graph_transforms/sparsify_gather.cc  7
-rw-r--r--  tensorflow/tools/graph_transforms/transform_graph_test.cc  8
-rw-r--r--  tensorflow/tools/graph_transforms/transform_utils.cc  7
-rw-r--r--  tensorflow/tools/pip_package/BUILD  133
-rwxr-xr-x  tensorflow/tools/pip_package/build_pip_package.sh  4
-rw-r--r--  tensorflow/workspace.bzl  40
-rw-r--r--  third_party/llvm/llvm.BUILD  12
-rw-r--r--  third_party/nccl/LICENSE  203
-rw-r--r--  third_party/nccl/nccl_archive.BUILD (renamed from third_party/nccl.BUILD)  2
-rw-r--r--  third_party/nccl/nccl_configure.bzl  172
-rw-r--r--  third_party/snappy.BUILD  96
-rw-r--r--  third_party/zlib.BUILD  16
486 files changed, 14248 insertions, 3991 deletions
diff --git a/configure.py b/configure.py
index 6744082d5d..81d5ad77ee 100644
--- a/configure.py
+++ b/configure.py
@@ -35,6 +35,7 @@ except ImportError:
_DEFAULT_CUDA_VERSION = '9.0'
_DEFAULT_CUDNN_VERSION = '7'
+_DEFAULT_NCCL_VERSION = '1.3'
_DEFAULT_CUDA_COMPUTE_CAPABILITIES = '3.5,5.2'
_DEFAULT_CUDA_PATH = '/usr/local/cuda'
_DEFAULT_CUDA_PATH_LINUX = '/opt/cuda'
@@ -484,6 +485,8 @@ def set_cc_opt_flags(environ_cp):
if is_ppc64le():
# gcc on ppc64le does not support -march, use mcpu instead
default_cc_opt_flags = '-mcpu=native'
+ elif is_windows():
+ default_cc_opt_flags = '/arch:AVX'
else:
default_cc_opt_flags = '-march=native'
question = ('Please specify optimization flags to use during compilation when'
@@ -494,7 +497,7 @@ def set_cc_opt_flags(environ_cp):
for opt in cc_opt_flags.split():
write_to_bazelrc('build:opt --copt=%s' % opt)
# It should be safe on the same build host.
- if not is_ppc64le():
+ if not is_ppc64le() and not is_windows():
write_to_bazelrc('build:opt --host_copt=-march=native')
write_to_bazelrc('build:opt --define with_default_optimizations=true')
# TODO(mikecase): Remove these default defines once we are able to get
@@ -1102,6 +1105,81 @@ def set_tf_tensorrt_install_path(environ_cp):
write_action_env_to_bazelrc('TF_TENSORRT_VERSION', tf_tensorrt_version)
+def set_tf_nccl_install_path(environ_cp):
+ """Set NCCL_INSTALL_PATH and TF_NCCL_VERSION.
+
+ Args:
+ environ_cp: copy of the os.environ.
+
+ Raises:
+ ValueError: if this method was called under non-Linux platform.
+ UserInputError: if user has provided invalid input multiple times.
+ """
+ if not is_linux():
+ raise ValueError('Currently NCCL is only supported on Linux platforms.')
+
+ ask_nccl_version = (
+ 'Please specify the NCCL version you want to use. '
+ '[Leave empty to default to NCCL %s]: ') % _DEFAULT_NCCL_VERSION
+
+ for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS):
+ tf_nccl_version = get_from_env_or_user_or_default(
+ environ_cp, 'TF_NCCL_VERSION', ask_nccl_version, _DEFAULT_NCCL_VERSION)
+ tf_nccl_version = reformat_version_sequence(str(tf_nccl_version), 1)
+
+ if tf_nccl_version == '1':
+ break # No need to get install path, NCCL 1 is a GitHub repo.
+
+ # TODO(csigg): Look with ldconfig first if we can find the library in paths
+ # like /usr/lib/x86_64-linux-gnu and the header file in the corresponding
+ # include directory. This is where the NCCL .deb packages install them.
+ # Then ask the user if we should use that. Instead of a single
+ # NCCL_INSTALL_PATH, pass separate NCCL_LIB_PATH and NCCL_HDR_PATH to
+ # nccl_configure.bzl
+ default_nccl_path = environ_cp.get('CUDA_TOOLKIT_PATH')
+ ask_nccl_path = (r'Please specify the location where NCCL %s library is '
+ 'installed. Refer to README.md for more details. [Default '
+ 'is %s]:') % (tf_nccl_version, default_nccl_path)
+ nccl_install_path = get_from_env_or_user_or_default(
+ environ_cp, 'NCCL_INSTALL_PATH', ask_nccl_path, default_nccl_path)
+
+  # Result returned from "read" will be used unexpanded. That makes "~"
+  # unusable. Going through one more level of expansion to handle that.
+ nccl_install_path = os.path.realpath(os.path.expanduser(nccl_install_path))
+ if is_windows() or is_cygwin():
+ nccl_install_path = cygpath(nccl_install_path)
+
+ if is_windows():
+ nccl_lib_path = 'lib/x64/nccl.lib'
+ elif is_linux():
+ nccl_lib_path = 'lib/libnccl.so.%s' % tf_nccl_version
+ elif is_macos():
+ nccl_lib_path = 'lib/libnccl.%s.dylib' % tf_nccl_version
+
+ nccl_lib_path = os.path.join(nccl_install_path, nccl_lib_path)
+ nccl_hdr_path = os.path.join(nccl_install_path, 'include/nccl.h')
+ if os.path.exists(nccl_lib_path) and os.path.exists(nccl_hdr_path):
+ # Set NCCL_INSTALL_PATH
+ environ_cp['NCCL_INSTALL_PATH'] = nccl_install_path
+ write_action_env_to_bazelrc('NCCL_INSTALL_PATH', nccl_install_path)
+ break
+
+ # Reset and Retry
+ print('Invalid path to NCCL %s toolkit, %s or %s not found. Please use the '
+ 'O/S agnostic package of NCCL 2' % (tf_nccl_version, nccl_lib_path,
+ nccl_hdr_path))
+
+ environ_cp['TF_NCCL_VERSION'] = ''
+ else:
+ raise UserInputError('Invalid TF_NCCL setting was provided %d '
+ 'times in a row. Assuming to be a scripting mistake.' %
+ _DEFAULT_PROMPT_ASK_ATTEMPTS)
+
+ # Set TF_NCCL_VERSION
+ environ_cp['TF_NCCL_VERSION'] = tf_nccl_version
+ write_action_env_to_bazelrc('TF_NCCL_VERSION', tf_nccl_version)
+
+
def get_native_cuda_compute_capabilities(environ_cp):
"""Get native cuda compute capabilities.
@@ -1438,6 +1516,7 @@ def main():
set_tf_cudnn_version(environ_cp)
if is_linux():
set_tf_tensorrt_install_path(environ_cp)
+ set_tf_nccl_install_path(environ_cp)
set_tf_cuda_compute_capabilities(environ_cp)
if 'LD_LIBRARY_PATH' in environ_cp and environ_cp.get(
'LD_LIBRARY_PATH') != '1':
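
[Editor's note] Like the existing CUDA and TensorRT prompts, the new NCCL prompt can be pre-seeded through the TF_NCCL_VERSION and NCCL_INSTALL_PATH environment variables, after which validation reduces to checking that the versioned library and header exist under the install path. A minimal sketch of that check, with hypothetical values for the two variables:

    import os

    # Hypothetical values; the real script reads these from the environment
    # or an interactive prompt via get_from_env_or_user_or_default.
    tf_nccl_version = '2'
    nccl_install_path = '/usr/local/nccl'

    # Mirrors the Linux branch above: NCCL 2 must provide
    # lib/libnccl.so.<version> and include/nccl.h under the install path.
    nccl_lib_path = os.path.join(nccl_install_path,
                                 'lib/libnccl.so.%s' % tf_nccl_version)
    nccl_hdr_path = os.path.join(nccl_install_path, 'include/nccl.h')
    print(os.path.exists(nccl_lib_path) and os.path.exists(nccl_hdr_path))
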
diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h
index b32f574628..fe85f8ee0e 100644
--- a/tensorflow/c/c_api.h
+++ b/tensorflow/c/c_api.h
@@ -1496,7 +1496,8 @@ TF_CAPI_EXPORT extern int TF_DeviceListCount(const TF_DeviceList* list);
// If index is out of bounds, an error code will be set in the status object,
// and a null pointer will be returned.
TF_CAPI_EXPORT extern const char* TF_DeviceListName(const TF_DeviceList* list,
- int index, TF_Status*);
+ int index,
+ TF_Status* status);
// Retrieves the type of the device at the given index.
//
@@ -1506,14 +1507,15 @@ TF_CAPI_EXPORT extern const char* TF_DeviceListName(const TF_DeviceList* list,
// If index is out of bounds, an error code will be set in the status object,
// and a null pointer will be returned.
TF_CAPI_EXPORT extern const char* TF_DeviceListType(const TF_DeviceList* list,
- int index, TF_Status*);
+ int index,
+ TF_Status* status);
// Retrieve the amount of memory associated with a given device.
//
// If index is out of bounds, an error code will be set in the status object,
// and -1 will be returned.
TF_CAPI_EXPORT extern int64_t TF_DeviceListMemoryBytes(
- const TF_DeviceList* list, int index, TF_Status*);
+ const TF_DeviceList* list, int index, TF_Status* status);
// --------------------------------------------------------------------------
// Load plugins containing custom ops and kernels
diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc
index 028f146be3..ca80db23ed 100644
--- a/tensorflow/c/c_api_test.cc
+++ b/tensorflow/c/c_api_test.cc
@@ -53,7 +53,7 @@ Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);
namespace {
static void ExpectHasSubstr(StringPiece s, StringPiece expected) {
- EXPECT_TRUE(StringPiece(s).contains(expected))
+ EXPECT_TRUE(str_util::StrContains(s, expected))
<< "'" << s << "' does not contain '" << expected << "'";
}
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index bb1492fca2..c96a38dec3 100644
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -496,9 +496,11 @@ tensorflow::Status ValidateInputTypeAndPlacement(
expected_device->name(), " but is actually on ",
actual_device->name(), " (operation running on ",
op_device->name(), ")",
- " Tensors can be copied explicitly using .gpu() or .cpu(),"
- " or transparently copied by using tfe.enable_eager_execution("
- "tfe.DEVICE_PLACEMENT_SILENT). Copying tensors between devices"
+ " Tensors can be copied explicitly using .gpu() or .cpu() "
+ "methods,"
+ " or transparently copied by using tf.enable_eager_execution("
+ "device_policy=tfe.DEVICE_PLACEMENT_SILENT). Copying tensors "
+ "between devices"
" may slow down your model");
case tensorflow::DEVICE_PLACEMENT_WARN:
LOG(WARNING) << "before computing " << op->name << " input #" << i
diff --git a/tensorflow/cc/saved_model/loader_test.cc b/tensorflow/cc/saved_model/loader_test.cc
index 4c64d2cfe3..72b8bc1871 100644
--- a/tensorflow/cc/saved_model/loader_test.cc
+++ b/tensorflow/cc/saved_model/loader_test.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
@@ -133,9 +134,9 @@ TEST_F(LoaderTest, NoTagMatch) {
Status st = LoadSavedModel(session_options, run_options, export_dir,
{"missing-tag"}, &bundle);
EXPECT_FALSE(st.ok());
- EXPECT_TRUE(StringPiece(st.error_message())
- .contains("Could not find meta graph def matching supplied "
- "tags: { missing-tag }"))
+ EXPECT_TRUE(str_util::StrContains(
+ st.error_message(),
+ "Could not find meta graph def matching supplied tags: { missing-tag }"))
<< st.error_message();
}
@@ -149,9 +150,9 @@ TEST_F(LoaderTest, NoTagMatchMultiple) {
Status st = LoadSavedModel(session_options, run_options, export_dir,
{kSavedModelTagServe, "missing-tag"}, &bundle);
EXPECT_FALSE(st.ok());
- EXPECT_TRUE(
- StringPiece(st.error_message())
- .contains("Could not find meta graph def matching supplied tags: "))
+ EXPECT_TRUE(str_util::StrContains(
+ st.error_message(),
+ "Could not find meta graph def matching supplied tags: "))
<< st.error_message();
}
@@ -169,7 +170,7 @@ TEST_F(LoaderTest, SessionCreationFailure) {
Status st = LoadSavedModel(session_options, run_options, export_dir,
{kSavedModelTagServe}, &bundle);
EXPECT_FALSE(st.ok());
- EXPECT_TRUE(StringPiece(st.error_message()).contains(kInvalidTarget))
+ EXPECT_TRUE(str_util::StrContains(st.error_message(), kInvalidTarget))
<< st.error_message();
}
diff --git a/tensorflow/cc/tutorials/example_trainer.cc b/tensorflow/cc/tutorials/example_trainer.cc
index 3675d72ee3..5dbc4f5f6a 100644
--- a/tensorflow/cc/tutorials/example_trainer.cc
+++ b/tensorflow/cc/tutorials/example_trainer.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/graph/default_device.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
@@ -166,7 +167,8 @@ namespace {
bool ParseInt32Flag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
int32* dst) {
- if (arg.Consume(flag) && arg.Consume("=")) {
+ if (tensorflow::str_util::ConsumePrefix(&arg, flag) &&
+ tensorflow::str_util::ConsumePrefix(&arg, "=")) {
char extra;
return (sscanf(arg.data(), "%d%c", dst, &extra) == 1);
}
@@ -176,7 +178,7 @@ bool ParseInt32Flag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
bool ParseBoolFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
bool* dst) {
- if (arg.Consume(flag)) {
+ if (tensorflow::str_util::ConsumePrefix(&arg, flag)) {
if (arg.empty()) {
*dst = true;
return true;
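
[Editor's note] str_util::ConsumePrefix strips the prefix from the StringPiece in place and returns whether it matched, which is what the removed StringPiece::Consume member did. A rough Python sketch of the parsing pattern above (the flag name is illustrative, not taken from the change):

    def consume_prefix(arg, prefix):
        # Mirrors str_util::ConsumePrefix: strip the prefix, report success.
        if arg.startswith(prefix):
            return True, arg[len(prefix):]
        return False, arg

    def parse_int32_flag(arg, flag):
        # Matches ParseInt32Flag above: expects "<flag>=<integer>".
        ok, rest = consume_prefix(arg, flag)
        if ok:
            ok, rest = consume_prefix(rest, '=')
            if ok and rest.lstrip('-').isdigit():
                return int(rest)
        return None

    print(parse_int32_flag('--num_iterations=100', '--num_iterations'))  # 100
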
diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
index 53ec6c1e60..b04b333141 100644
--- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
+++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
@@ -825,6 +825,7 @@ Status Encapsulator::Subgraph::AddHostComputes(
builder.Attr("key",
strings::StrCat("host_compute_channel_", subgraph_name, "_",
oc_subgraph_name));
+ builder.Attr("_outside_compilation_subgraph", oc_subgraph_name);
Status s = builder.Finalize(&host_compute_def);
if (!s.ok()) return s;
diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc
index 56efe98fdb..8599a7038a 100644
--- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc
+++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc
@@ -902,7 +902,8 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) {
{"key", "host_compute_channel_F1_O1"},
{"shape_inference_graph",
"_outside_compilation_shape_inference_F1_O1"},
- {"shapes", gtl::ArraySlice<DataType>({})}},
+ {"shapes", gtl::ArraySlice<DataType>({})},
+ {"_outside_compilation_subgraph", "O1"}},
{"c"}},
},
{{"f_0_retval", "F:o:0"}});
@@ -1046,7 +1047,8 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) {
{"key", "host_compute_channel_F1_O2"},
{"shape_inference_graph",
"_outside_compilation_shape_inference_F1_O2"},
- {"shapes", gtl::ArraySlice<DataType>({})}},
+ {"shapes", gtl::ArraySlice<DataType>({})},
+ {"_outside_compilation_subgraph", "O2"}},
{"F"}},
{{"outside_compilation_O1_host_compute"},
"XlaHostCompute",
@@ -1056,7 +1058,8 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) {
{"key", "host_compute_channel_F1_O1"},
{"shape_inference_graph",
"_outside_compilation_shape_inference_F1_O1"},
- {"shapes", gtl::ArraySlice<DataType>({})}},
+ {"shapes", gtl::ArraySlice<DataType>({})},
+ {"_outside_compilation_subgraph", "O1"}},
{"D"}},
},
{{"i_0_retval", "I:o:0"}});
@@ -1193,7 +1196,8 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) {
{"key", "host_compute_channel_F1_O1"},
{"shape_inference_graph",
"_outside_compilation_shape_inference_F1_O1"},
- {"shapes", gtl::ArraySlice<DataType>({})}},
+ {"shapes", gtl::ArraySlice<DataType>({})},
+ {"_outside_compilation_subgraph", "O1"}},
{"D"}},
},
{{"d_0_retval", "D:o:0"}, {"f_0_retval", "F:o:0"}});
@@ -1214,7 +1218,8 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) {
{"key", "host_compute_channel_F2_O1"},
{"shape_inference_graph", ""},
{"shapes",
- gtl::ArraySlice<TensorShapeProto>({shape_proto_expected})}}},
+ gtl::ArraySlice<TensorShapeProto>({shape_proto_expected})},
+ {"_outside_compilation_subgraph", "O1"}}},
},
{{"g_0_retval", "G:o:0"}, {"i_0_retval", "I:o:0"}});
@@ -1321,7 +1326,8 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) {
{"key", "host_compute_channel_F1_O1"},
{"shape_inference_graph", ""},
{"shapes",
- gtl::ArraySlice<TensorShapeProto>({shape_proto_expected})}}},
+ gtl::ArraySlice<TensorShapeProto>({shape_proto_expected})},
+ {"_outside_compilation_subgraph", "O1"}}},
},
{{"f_0_retval", "F:o:0"}});
@@ -1403,7 +1409,8 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) {
{"key", "host_compute_channel_F1_O1"},
{"shape_inference_graph", ""},
{"shapes",
- gtl::ArraySlice<TensorShapeProto>({shape_proto_expected})}},
+ gtl::ArraySlice<TensorShapeProto>({shape_proto_expected})},
+ {"_outside_compilation_subgraph", "O1"}},
{"D"}},
},
{{"f_0_retval", "F:o:0"}});
@@ -1482,7 +1489,8 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoOutputs) {
{"Toutputs", gtl::ArraySlice<DataType>({})},
{"key", "host_compute_channel_F1_O1"},
{"shape_inference_graph", ""},
- {"shapes", gtl::ArraySlice<TensorShapeProto>({})}}},
+ {"shapes", gtl::ArraySlice<TensorShapeProto>({})},
+ {"_outside_compilation_subgraph", "O1"}}},
},
{{"f_0_retval", "F:o:0"}});
@@ -1561,7 +1569,8 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) {
{"Toutputs", gtl::ArraySlice<DataType>({})},
{"key", "host_compute_channel_F1_O1"},
{"shape_inference_graph", ""},
- {"shapes", gtl::ArraySlice<TensorShapeProto>({})}}},
+ {"shapes", gtl::ArraySlice<TensorShapeProto>({})},
+ {"_outside_compilation_subgraph", "O1"}}},
},
{{"f_0_retval", "F:o:0"}});
@@ -1725,7 +1734,8 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) {
{"key", "host_compute_channel_F1_O1"},
{"shape_inference_graph",
"_outside_compilation_shape_inference_F1_O1"},
- {"shapes", gtl::ArraySlice<DataType>({})}},
+ {"shapes", gtl::ArraySlice<DataType>({})},
+ {"_outside_compilation_subgraph", "O1"}},
{"c"}},
},
{{"f_0_retval", "F:o:0"}});
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
index 381c0205fd..2e362e0a63 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
@@ -138,7 +138,7 @@ TEST(XlaCompilationTest, CompilableCycles) {
EXPECT_EQ(clusters["A"], clusters["C"]);
}
-TEST(XlaCompilationTest, UnsupportedTypes) {
+TEST(XlaCompilationTest, Complex128Unsupported) {
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
GraphDef graphdef;
{
@@ -158,6 +158,27 @@ TEST(XlaCompilationTest, UnsupportedTypes) {
EXPECT_TRUE(clusters.empty());
}
+TEST(XlaCompilationTest, HalfSupported) {
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ GraphDef graphdef;
+ {
+ GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
+ Tensor t(DT_HALF, TensorShape());
+ t.scalar<Eigen::half>()() = static_cast<Eigen::half>(0.0f);
+ Node* a = ops::SourceOp("Const", builder.opts()
+ .WithName("A")
+ .WithAttr("dtype", DT_HALF)
+ .WithAttr("value", t));
+ Node* b = ops::UnaryOp("Neg", a, builder.opts().WithName("B"));
+ ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C"));
+ TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
+ }
+
+ TF_ASSERT_OK(MarkForCompilation(&graph));
+ auto clusters = GetClusters(*graph);
+ EXPECT_FALSE(clusters.empty());
+}
+
TEST(XlaCompilationTest, ConcatWithConstArg) {
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
GraphDef graphdef;
diff --git a/tensorflow/compiler/jit/xla_cpu_device.cc b/tensorflow/compiler/jit/xla_cpu_device.cc
index d2dfdeea68..bc07dbd7bd 100644
--- a/tensorflow/compiler/jit/xla_cpu_device.cc
+++ b/tensorflow/compiler/jit/xla_cpu_device.cc
@@ -62,8 +62,8 @@ REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_CPU, XlaCpuDeviceFactory);
// Kernel registrations
-constexpr std::array<DataType, 6> kAllXlaCpuTypes = {
- {DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}};
+constexpr std::array<DataType, 7> kAllXlaCpuTypes = {
+ {DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}};
REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_CPU, XlaLocalLaunchOp, kAllXlaCpuTypes);
REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_CPU, kAllXlaCpuTypes);
diff --git a/tensorflow/compiler/jit/xla_gpu_device.cc b/tensorflow/compiler/jit/xla_gpu_device.cc
index 5a1db81774..ac60423d95 100644
--- a/tensorflow/compiler/jit/xla_gpu_device.cc
+++ b/tensorflow/compiler/jit/xla_gpu_device.cc
@@ -62,8 +62,9 @@ REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_GPU, XlaGpuDeviceFactory);
// Kernel registrations
-constexpr std::array<DataType, 6> kAllXlaGpuTypes = {
- {DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}};
+constexpr std::array<DataType, 8> kAllXlaGpuTypes = {
+ {DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL,
+ DT_BFLOAT16}};
REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_GPU, XlaLocalLaunchOp, kAllXlaGpuTypes);
REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_GPU, kAllXlaGpuTypes);
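
[Editor's note] With DT_HALF in both registration lists (and DT_BFLOAT16 on the GPU side), half-precision graphs can now be placed on the XLA devices, as the new HalfSupported test above exercises. A hedged sketch of what this enables, assuming a TF 1.x build compiled with XLA device support:

    import tensorflow as tf  # assumes a 1.x build with the XLA devices present

    with tf.Session() as sess:
      with tf.device('/device:XLA_CPU:0'):
        a = tf.constant([[1.0, 2.0], [3.0, 4.0]], dtype=tf.float16)
        b = tf.negative(a)
        # Before this change, DT_HALF had no XLA_CPU kernel registration
        # and this placement would fail.
        c = tf.matmul(a, b)
      print(sess.run(c))
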
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index 204a2a2f90..edabdc218a 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -375,7 +375,6 @@ tf_xla_py_test(
name = "momentum_test",
size = "small",
srcs = ["momentum_test.py"],
- tags = ["no_oss"],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py
index ba7b9bacd2..d1d7379c0a 100644
--- a/tensorflow/compiler/tests/binary_ops_test.py
+++ b/tensorflow/compiler/tests/binary_ops_test.py
@@ -190,19 +190,24 @@ class BinaryOpsTest(XLATestCase):
],
equality_test=self.ListsAreClose)
- self._testBinary(
- gen_nn_ops.sparse_softmax_cross_entropy_with_logits,
- np.array([[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8],
- [0.9, 1.0, 1.1, 1.2]], dtype=dtype),
- np.array([2, 1, 7], dtype=np.int32),
- expected=[
- np.array([1.342536, 1.442536, np.nan], dtype=dtype),
- np.array([[0.213838, 0.236328, -0.738817, 0.288651],
- [0.213838, -0.763672, 0.261183, 0.288651],
- [np.nan, np.nan, np.nan, np.nan]],
- dtype=dtype),
- ],
- equality_test=self.ListsAreClose)
+ # TODO(b/68813416): Fails with bfloat16.
+ if dtype != dtypes.bfloat16.as_numpy_dtype:
+ self._testBinary(
+ gen_nn_ops.sparse_softmax_cross_entropy_with_logits,
+ np.array(
+ [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8],
+ [0.9, 1.0, 1.1, 1.2]],
+ dtype=dtype),
+ np.array([2, 1, 7], dtype=np.int32),
+ expected=[
+ np.array([1.342536, 1.442536, np.nan], dtype=dtype),
+ np.array(
+ [[0.213838, 0.236328, -0.738817, 0.288651], [
+ 0.213838, -0.763672, 0.261183, 0.288651
+ ], [np.nan, np.nan, np.nan, np.nan]],
+ dtype=dtype),
+ ],
+ equality_test=self.ListsAreClose)
def testIntOps(self):
for dtype in self.int_types:
@@ -260,12 +265,6 @@ class BinaryOpsTest(XLATestCase):
np.array([[1], [2]], dtype=dtype),
dtype(7),
expected=np.array([[8], [9]], dtype=dtype))
- self._testBinary(
- math_ops.add,
- np.array([0xffffffff, 0xfffffffff, 1, 1], dtype=np.int64),
- np.array([1, 1, 0xffffffff, 0xfffffffff], dtype=np.int64),
- expected=np.array(
- [1 << 32, 1 << 36, 1 << 32, 1 << 36], dtype=np.int64))
self._testBinary(
math_ops.subtract,
@@ -361,6 +360,12 @@ class BinaryOpsTest(XLATestCase):
np.array([2, -1], dtype=dtype),
expected=np.array([[[[3, 1], [5, 3]]]], dtype=dtype))
+ self._testBinary(
+ math_ops.add,
+ np.array([0xffffffff, 0xfffffffff, 1, 1], dtype=np.int64),
+ np.array([1, 1, 0xffffffff, 0xfffffffff], dtype=np.int64),
+ expected=np.array([1 << 32, 1 << 36, 1 << 32, 1 << 36], dtype=np.int64))
+
def testComplexOps(self):
for dtype in self.complex_types:
ctypes = {np.complex64: np.float32}
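
[Editor's note] The relocated int64 case exercises sums that overflow 32 bits but not 64; the expected values follow directly from the operands, as this standalone numpy check shows:

    import numpy as np

    lhs = np.array([0xffffffff, 0xfffffffff, 1, 1], dtype=np.int64)
    rhs = np.array([1, 1, 0xffffffff, 0xfffffffff], dtype=np.int64)
    # 0xffffffff + 1 == 2**32 and 0xfffffffff + 1 == 2**36; both fit in int64.
    expected = np.array([1 << 32, 1 << 36, 1 << 32, 1 << 36], dtype=np.int64)
    print(np.array_equal(lhs + rhs, expected))  # True
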
diff --git a/tensorflow/compiler/tests/build_defs.bzl b/tensorflow/compiler/tests/build_defs.bzl
index 0528a5415d..a9db1c173d 100644
--- a/tensorflow/compiler/tests/build_defs.bzl
+++ b/tensorflow/compiler/tests/build_defs.bzl
@@ -56,7 +56,7 @@ def tf_xla_py_test(name, srcs=[], deps=[], tags=[], data=[], main=None,
elif backend == "gpu":
backend_args += [
"--test_device=XLA_GPU",
- "--types=DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_BOOL,DT_COMPLEX64"
+ "--types=DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_BOOL,DT_COMPLEX64,DT_BFLOAT16"
]
backend_tags += ["requires-gpu-sm35"]
elif backend in plugins:
@@ -89,4 +89,3 @@ def generate_backend_suites(backends=[]):
backends = all_backends()
for backend in backends:
native.test_suite(name="%s_tests" % backend, tags=["tf_xla_%s" % backend])
-
diff --git a/tensorflow/compiler/tests/cholesky_op_test.py b/tensorflow/compiler/tests/cholesky_op_test.py
index 5010fe5e21..1a8989d7c2 100644
--- a/tensorflow/compiler/tests/cholesky_op_test.py
+++ b/tensorflow/compiler/tests/cholesky_op_test.py
@@ -34,6 +34,13 @@ from tensorflow.python.platform import test
class CholeskyOpTest(XLATestCase):
+ # Cholesky defined for float64, float32, complex64, complex128
+ # (https://www.tensorflow.org/api_docs/python/tf/cholesky)
+ @property
+ def float_types(self):
+ return set(super(CholeskyOpTest, self).float_types).intersection(
+ (np.float64, np.float32, np.complex64, np.complex128))
+
def _verifyCholeskyBase(self, sess, placeholder, x, chol, verification, atol):
chol_np, verification_np = sess.run([chol, verification], {placeholder: x})
self.assertAllClose(x, verification_np, atol=atol)
diff --git a/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py b/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py
index cccb7f5789..5819b2bf2b 100644
--- a/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py
+++ b/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py
@@ -37,6 +37,14 @@ def MakePlaceholder(x):
class MatrixTriangularSolveOpTest(XLATestCase):
+ # MatrixTriangularSolve defined for float64, float32, complex64, complex128
+ # (https://www.tensorflow.org/api_docs/python/tf/matrix_triangular_solve)
+ @property
+ def float_types(self):
+ return set(super(MatrixTriangularSolveOpTest,
+ self).float_types).intersection(
+ (np.float64, np.float32, np.complex64, np.complex128))
+
def _VerifyTriangularSolveBase(self, sess, placeholder_a, placeholder_ca,
placeholder_b, a, clean_a, b, verification,
atol):
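
[Editor's note] Both float_types overrides above use the same idiom: intersect the dtypes the XLATestCase harness supplies with the set the op actually supports, so a bfloat16-enabled harness silently drops the unsupported type. A standalone sketch with an assumed harness-provided set:

    import numpy as np

    # Assumed harness-provided set; a bfloat16-enabled build would add the
    # custom bfloat16 dtype here, which neither op supports.
    harness_float_types = {np.float32, np.float64, np.complex64, np.complex128}

    supported = (np.float64, np.float32, np.complex64, np.complex128)
    float_types = set(harness_float_types).intersection(supported)
    print(sorted(t.__name__ for t in float_types))
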
diff --git a/tensorflow/compiler/tests/spacetobatch_op_test.py b/tensorflow/compiler/tests/spacetobatch_op_test.py
index 92518aadc4..6083981493 100644
--- a/tensorflow/compiler/tests/spacetobatch_op_test.py
+++ b/tensorflow/compiler/tests/spacetobatch_op_test.py
@@ -21,6 +21,7 @@ from __future__ import print_function
import numpy as np
from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.platform import test
@@ -156,6 +157,12 @@ class SpaceToBatchNDTest(XLATestCase):
paddings = np.array(paddings).reshape((len(block_shape), 2))
with self.test_session() as sess, self.test_scope():
for dtype in self.float_types:
+ # TODO(b/68813416): Skip bfloat16's as the input type for direct is
+ # float32 and results in a mismatch, while making testDirect provide the
+ # correctly typed input results in 'no fill-function for data-type'
+ # error.
+ if dtype == dtypes.bfloat16.as_numpy_dtype:
+ continue
placeholder = array_ops.placeholder(dtype)
# outputs = space_to_batch(inputs)
x_tf = array_ops.space_to_batch_nd(placeholder, block_shape, paddings)
diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py
index a8ab235378..17149aa1c8 100644
--- a/tensorflow/compiler/tests/unary_ops_test.py
+++ b/tensorflow/compiler/tests/unary_ops_test.py
@@ -793,7 +793,10 @@ class UnaryOpsTest(XLATestCase):
self._assertSoftplusMatchesExpected([[-2, 0, 8]], dtype)
self._assertSoftplusMatchesExpected(
[[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]], dtype)
- log_eps = np.log(np.finfo(dtype).eps)
+ if dtype == dtypes.bfloat16.as_numpy_dtype:
+ log_eps = np.log(np.finfo(np.float32).eps)
+ else:
+ log_eps = np.log(np.finfo(dtype).eps)
one = dtype(1)
ten = dtype(10)
self._assertSoftplusMatchesExpected([
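
[Editor's note] The new branch presumably exists because np.finfo has no entry for TensorFlow's custom bfloat16 numpy dtype, so the bfloat16 path borrows float32's machine epsilon for the softplus cutoff. A quick check of the fallback value:

    import numpy as np

    # Value the bfloat16 branch uses in place of np.finfo(bfloat16).eps:
    log_eps = np.log(np.finfo(np.float32).eps)
    print(log_eps)  # roughly -15.94; the test probes softplus near this cutoff
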
diff --git a/tensorflow/compiler/tests/variable_ops_test.py b/tensorflow/compiler/tests/variable_ops_test.py
index b08d6ab21e..8ecad00f6e 100644
--- a/tensorflow/compiler/tests/variable_ops_test.py
+++ b/tensorflow/compiler/tests/variable_ops_test.py
@@ -230,7 +230,10 @@ class SliceAssignTest(XLATestCase):
# shrink shape changes
checker[1:2, 1] = [66]
checker[1, 1:2] = [66]
- checker[1, 1] = 66
+ if dtype != dtypes.bfloat16.as_numpy_dtype:
+ # TODO(b/68813416): valnp call above results in an ndarray and not a
+ # number for bfloat16s.
+ checker[1, 1] = 66
# newaxis shape changes
checker[:, None, :] = [[[10, 20, 30]], [[40, 50, 50]]]
# shrink and newaxis
@@ -243,8 +246,11 @@ class SliceAssignTest(XLATestCase):
# Assign vector to scalar (rank-0) using newaxis
checker2 = StridedSliceAssignChecker(self, 222, dtype=dtype)
- checker2[()] = 6 # no indices
- checker2[...] = 6 # ellipsis
+ if dtype != dtypes.bfloat16.as_numpy_dtype:
+ # TODO(b/68813416): valnp call above results in an ndarray and not a
+ # number for bfloat16s.
+ checker2[()] = 6 # no indices
+ checker2[...] = 6 # ellipsis
checker2[None] = [6] # new axis
def testUninitialized(self):
diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h
index ff7453194a..e255b01dd7 100644
--- a/tensorflow/compiler/tf2xla/xla_op_registry.h
+++ b/tensorflow/compiler/tf2xla/xla_op_registry.h
@@ -51,13 +51,13 @@ constexpr std::array<DataType, 9> kNumericTypes = {
{DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE,
DT_COMPLEX64, DT_BFLOAT16}};
-constexpr std::array<DataType, 8> kCpuAllTypes = {
- {DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE,
+constexpr std::array<DataType, 9> kCpuAllTypes = {
+ {DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE,
DT_COMPLEX64, DT_BOOL}};
-constexpr std::array<DataType, 8> kGpuAllTypes = {
- {DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE,
- DT_COMPLEX64, DT_BOOL}};
+constexpr std::array<DataType, 10> kGpuAllTypes = {
+ {DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE,
+ DT_COMPLEX64, DT_BOOL, DT_BFLOAT16}};
// Class that manages registrations of operators and devices for the XLA JIT.
// Not thread-safe.
diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc
index c4c8894374..3f45167fcb 100644
--- a/tensorflow/compiler/xla/client/client.cc
+++ b/tensorflow/compiler/xla/client/client.cc
@@ -324,8 +324,38 @@ StatusOr<std::vector<std::unique_ptr<GlobalData>>> Client::ExecuteParallel(
StatusOr<std::vector<std::unique_ptr<GlobalData>>> Client::ExecuteParallel(
tensorflow::gtl::ArraySlice<XlaComputationInstance> computations) {
- return Unimplemented(
- "ExecuteParallel is not yet implemented for XlaComputation.");
+ ExecuteGraphParallelRequest request;
+
+ for (const XlaComputationInstance& computation : computations) {
+ ExecuteGraphRequest single_request;
+ *single_request.mutable_computation() = computation.computation.proto();
+ for (GlobalData* argument : computation.arguments) {
+ *single_request.add_arguments() = argument->handle();
+ }
+ *single_request.mutable_execution_options() = computation.execution_options;
+ *request.add_requests() = single_request;
+ }
+
+ ExecuteParallelResponse response;
+ VLOG(1) << "making execute-graph-parallel request: "
+ << request.ShortDebugString();
+ tensorflow::Status s = stub_->ExecuteGraphParallel(&request, &response);
+ VLOG(1) << "done with request";
+
+ if (!s.ok()) {
+ return s;
+ }
+
+ std::vector<std::unique_ptr<GlobalData>> outputs;
+ for (size_t i = 0; i < computations.size(); ++i) {
+ outputs.push_back(
+ MakeUnique<GlobalData>(stub_, response.responses(i).output()));
+ if (computations[i].execution_profile != nullptr) {
+ *computations[i].execution_profile = response.responses(i).profile();
+ }
+ }
+
+ return std::move(outputs);
}
StatusOr<std::vector<DeviceHandle>> Client::GetDeviceHandles(
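
A caller-side sketch of the newly implemented parallel-execution path, assuming
XlaComputationInstance is an aggregate exposing the four fields the
implementation reads (computation, arguments, execution_options,
execution_profile); the client, computation, argument, and option names below
are illustrative, not part of this change:

    // Hypothetical usage; error handling elided.
    std::vector<XlaComputationInstance> instances;
    XlaComputationInstance instance;
    instance.computation = std::move(built_computation);  // an XlaComputation
    instance.arguments = {arg_data.get()};                // GlobalData* handles
    instance.execution_options = execution_options;
    instance.execution_profile = &profile;  // or nullptr to skip profiling
    instances.push_back(std::move(instance));

    StatusOr<std::vector<std::unique_ptr<GlobalData>>> outputs =
        client->ExecuteParallel(instances);
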
diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD
index d02972f2c0..f4673a8204 100644
--- a/tensorflow/compiler/xla/client/lib/BUILD
+++ b/tensorflow/compiler/xla/client/lib/BUILD
@@ -24,6 +24,8 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:computation",
"//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/core:lib",
],
)
diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.cc b/tensorflow/compiler/xla/client/lib/arithmetic.cc
index 24048a1e5a..63df449e0b 100644
--- a/tensorflow/compiler/xla/client/lib/arithmetic.cc
+++ b/tensorflow/compiler/xla/client/lib/arithmetic.cc
@@ -26,6 +26,7 @@ limitations under the License.
namespace xla {
namespace {
+
using InstructionGenerator =
ComputationDataHandle (*)(ComputationBuilder*, const ComputationDataHandle&,
const ComputationDataHandle&);
@@ -47,6 +48,27 @@ Computation CreateScalarComputation(const string& name, PrimitiveType type,
generator(b.get(), lhs, rhs);
return b->BuildAndNoteError();
}
+
+using XlaOpGenerator = XlaOp (*)(XlaBuilder*, const XlaOp&, const XlaOp&);
+
+XlaComputation CreateScalarComputation(const string& name, PrimitiveType type,
+ XlaBuilder* builder,
+ XlaOpGenerator generator) {
+ std::unique_ptr<XlaBuilder> b;
+ if (type == PRED) {
+ b = builder->CreateSubBuilder(name);
+ } else {
+ b = builder->CreateSubBuilder(
+ tensorflow::strings::StrCat(name, "_", PrimitiveType_Name(type)));
+ }
+
+ const Shape scalar = ShapeUtil::MakeShape(type, {});
+ auto lhs = b->Parameter(0, scalar, "lhs");
+ auto rhs = b->Parameter(1, scalar, "rhs");
+ generator(b.get(), lhs, rhs);
+ return b->BuildAndNoteError();
+}
+
} // namespace
Computation CreateScalarAddComputation(PrimitiveType type,
@@ -60,7 +82,7 @@ Computation CreateScalarAddComputation(PrimitiveType type,
Computation CreateScalarMultiplyComputation(PrimitiveType type,
ComputationBuilder* builder) {
return CreateScalarComputation(
- "add", type, builder,
+ "mul", type, builder,
[](ComputationBuilder* b, const ComputationDataHandle& lhs,
const ComputationDataHandle& rhs) { return b->Mul(lhs, rhs); });
}
@@ -114,4 +136,75 @@ StatusOr<ComputationDataHandle> Any(const ComputationDataHandle& predicates,
return builder->Reduce(predicates, f, logical_or, all_dimensions);
}
+XlaComputation CreateScalarAddComputation(PrimitiveType type,
+ XlaBuilder* builder) {
+ return CreateScalarComputation(
+ "add", type, builder,
+ [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) {
+ return b->Add(lhs, rhs);
+ });
+}
+
+XlaComputation CreateScalarMultiplyComputation(PrimitiveType type,
+ XlaBuilder* builder) {
+ return CreateScalarComputation(
+ "mul", type, builder,
+ [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) {
+ return b->Mul(lhs, rhs);
+ });
+}
+
+XlaComputation CreateScalarGeComputation(PrimitiveType type,
+ XlaBuilder* builder) {
+ return CreateScalarComputation(
+ "ge", type, builder,
+ [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) {
+ return b->Ge(lhs, rhs);
+ });
+}
+
+XlaComputation CreateScalarMaxComputation(PrimitiveType type,
+ XlaBuilder* builder) {
+ return CreateScalarComputation(
+ "max", type, builder,
+ [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) {
+ return b->Max(lhs, rhs);
+ });
+}
+
+XlaComputation CreateScalarMinComputation(PrimitiveType type,
+ XlaBuilder* builder) {
+ return CreateScalarComputation(
+ "min", type, builder,
+ [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) {
+ return b->Min(lhs, rhs);
+ });
+}
+
+XlaComputation CreateScalarAndComputation(XlaBuilder* builder) {
+ return CreateScalarComputation(
+ "and", PRED, builder,
+ [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) {
+ return b->And(lhs, rhs);
+ });
+}
+
+XlaComputation CreateScalarOrComputation(XlaBuilder* builder) {
+ return CreateScalarComputation(
+ "or", PRED, builder,
+ [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) {
+ return b->Or(lhs, rhs);
+ });
+}
+
+StatusOr<XlaOp> Any(const XlaOp& predicates, XlaBuilder* builder) {
+ auto f = builder->ConstantR0<bool>(false);
+ XlaComputation logical_or = CreateScalarOrComputation(builder);
+ TF_ASSIGN_OR_RETURN(const Shape& predicates_shape,
+ builder->GetShape(predicates));
+ std::vector<int64> all_dimensions(ShapeUtil::Rank(predicates_shape));
+ std::iota(all_dimensions.begin(), all_dimensions.end(), 0);
+ return builder->Reduce(predicates, f, logical_or, all_dimensions);
+}
+
} // namespace xla
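
These XlaComputation helpers mirror the ComputationBuilder versions above them
one-for-one. A minimal sketch of how they combine with XlaBuilder::Reduce
(also implemented in this change) to sum a vector over its only dimension; the
builder name and shapes are illustrative:

    XlaBuilder builder("reduce_example");
    auto input =
        builder.Parameter(0, ShapeUtil::MakeShape(F32, {1024}), "input");
    auto zero = builder.ConstantR0<float>(0.0f);
    XlaComputation add = CreateScalarAddComputation(F32, &builder);
    // Sums over dimension 0, yielding an f32 scalar.
    auto sum = builder.Reduce(input, zero, add, /*dimensions_to_reduce=*/{0});
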
diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.h b/tensorflow/compiler/xla/client/lib/arithmetic.h
index ae89784bc2..f4d3fc8015 100644
--- a/tensorflow/compiler/xla/client/lib/arithmetic.h
+++ b/tensorflow/compiler/xla/client/lib/arithmetic.h
@@ -20,6 +20,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/computation.h"
#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {
@@ -56,6 +58,48 @@ Computation CreateScalarOrComputation(ComputationBuilder* builder);
StatusOr<ComputationDataHandle> Any(const ComputationDataHandle& predicates,
ComputationBuilder* builder);
+// TODO(b/74197823): This is part of a refactor that is NOT YET ready.
+//
+// Creates a scalar add computation and returns it.
+XlaComputation CreateScalarAddComputation(PrimitiveType type,
+ XlaBuilder* builder);
+// TODO(b/74197823): This is part of a refactor that is NOT YET ready.
+//
+// Creates a scalar multiply computation and returns it.
+XlaComputation CreateScalarMultiplyComputation(PrimitiveType type,
+ XlaBuilder* builder);
+// TODO(b/74197823): This is part of a refactor that is NOT YET ready.
+//
+// Creates a scalar ge computation and returns it.
+XlaComputation CreateScalarGeComputation(PrimitiveType type,
+ XlaBuilder* builder);
+// TODO(b/74197823): This is part of a refactor that is NOT YET ready.
+//
+// Creates a scalar max computation and returns it.
+XlaComputation CreateScalarMaxComputation(PrimitiveType type,
+ XlaBuilder* builder);
+// TODO(b/74197823): This is part of a refactor that is NOT YET ready.
+//
+// Creates a scalar min computation and returns it.
+XlaComputation CreateScalarMinComputation(PrimitiveType type,
+ XlaBuilder* builder);
+// TODO(b/74197823): This is part of a refactor that is NOT YET ready.
+//
+// Creates a scalar logical AND computation and returns it.
+XlaComputation CreateScalarAndComputation(XlaBuilder* builder);
+
+// TODO(b/74197823): This is part of a refactor that is NOT YET ready.
+//
+// Creates a scalar logical OR computation and returns it.
+XlaComputation CreateScalarOrComputation(XlaBuilder* builder);
+
+// TODO(b/74197823): This is part of a refactor that is NOT YET ready.
+//
+// Returns whether any predicate in "predicates" is set.
+//
+// Note: if predicates is zero-sized, Any() vacuously returns false.
+StatusOr<XlaOp> Any(const XlaOp& predicates, XlaBuilder* builder);
+
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_ARITHMETIC_H_
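
Like the ComputationBuilder overload above it, the XlaOp overload of Any
OR-reduces across every dimension of the operand, so a predicate array of any
rank collapses to a single PRED scalar. An illustrative sketch:

    XlaBuilder b("any_example");
    auto preds = b.Parameter(0, ShapeUtil::MakeShape(PRED, {16, 16}), "preds");
    // Initial value false, reduction OR: true iff any element is true.
    StatusOr<XlaOp> any_set = Any(preds, &b);
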
diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
index e51a8b14c0..2d587cc3b9 100644
--- a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include <functional>
#include <numeric>
#include <string>
#include <utility>
@@ -44,6 +45,7 @@ int64 GetUniqueId() {
bool CanBeRoot(HloOpcode opcode) {
switch (opcode) {
case HloOpcode::kSend:
+ case HloOpcode::kSendDone:
case HloOpcode::kOutfeed:
case HloOpcode::kTrace:
return false;
@@ -52,20 +54,35 @@ bool CanBeRoot(HloOpcode opcode) {
}
}
+StatusOr<std::vector<Shape>> GetOperandShapes(
+ tensorflow::gtl::ArraySlice<XlaOp> operands) {
+ std::vector<Shape> operand_shapes;
+ for (const XlaOp& operand : operands) {
+ TF_ASSIGN_OR_RETURN(const Shape& shape, operand.GetShape());
+ operand_shapes.push_back(shape);
+ }
+ return operand_shapes;
+}
+
} // namespace
StatusOr<Shape> XlaBuilder::GetShape(const XlaOp& op) const {
+ TF_RETURN_IF_ERROR(first_error_);
+
TF_ASSIGN_OR_RETURN(auto instr, LookUpInstruction(op));
return instr->shape();
}
StatusOr<Shape> XlaOp::GetShape() const {
- TF_RET_CHECK(builder_ != nullptr);
+ if (builder_ == nullptr) {
+ return InvalidArgument(
+ "cannot GetShape for an invalid XlaOp with handle %lld", handle());
+ }
return builder_->GetShape(*this);
}
XlaBuilder::XlaBuilder(const string& computation_name)
- : name_(computation_name) {}
+ : name_(computation_name), unique_id_(GetUniqueId()) {}
XlaBuilder::~XlaBuilder() {}
@@ -81,7 +98,22 @@ void XlaBuilder::NoteError(const Status& error) {
}
}
+XlaOp XlaBuilder::NoteErrorOrReturn(
+ const std::function<StatusOr<XlaOp>()>& op_creator) {
+ if (!first_error_.ok()) {
+ return {};
+ }
+ auto op = op_creator();
+ if (!op.ok()) {
+ NoteError(op.status());
+ return {};
+ }
+ return op.ConsumeValueOrDie();
+}
+
StatusOr<ProgramShape> XlaBuilder::GetProgramShape(int64* root_id) {
+ TF_RETURN_IF_ERROR(first_error_);
+
TF_RET_CHECK(root_id != nullptr);
ProgramShape program_shape;
@@ -148,7 +180,6 @@ StatusOr<XlaComputation> XlaBuilder::Build() {
}
HloComputationProto entry;
- entry.set_name(name_);
{
int64 root_id;
@@ -162,9 +193,9 @@ StatusOr<XlaComputation> XlaBuilder::Build() {
entry.add_instructions()->Swap(&instruction);
}
- const int64 id = GetUniqueId();
- entry.set_id(id);
- XlaComputation computation(id);
+ entry.set_id(unique_id_);
+ entry.set_name(StrCat(name_, entry.id())); // Ensure that the name is unique.
+ XlaComputation computation(entry.id());
HloModuleProto* module = computation.mutable_proto();
module->set_name(entry.name());
module->set_id(entry.id());
@@ -187,6 +218,8 @@ StatusOr<XlaComputation> XlaBuilder::Build() {
StatusOr<XlaOp> XlaBuilder::InDimBroadcast(
const Shape& shape, const XlaOp& operand,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ TF_RETURN_IF_ERROR(first_error_);
+
HloInstructionProto instr;
*instr.mutable_shape() = shape;
for (int64 dim : broadcast_dimensions) {
@@ -197,6 +230,8 @@ StatusOr<XlaOp> XlaBuilder::InDimBroadcast(
StatusOr<XlaOp> XlaBuilder::AddBroadcastSequence(const Shape& output_shape,
const XlaOp& operand) {
+ TF_RETURN_IF_ERROR(first_error_);
+
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, operand.GetShape());
CHECK(ShapeUtil::IsScalar(operand_shape) ||
@@ -240,7 +275,7 @@ XlaOp XlaBuilder::UnaryOp(HloOpcode unop, const XlaOp& operand) {
TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
ShapeInference::InferUnaryOpShape(unop, operand_shape));
return AddInstruction(std::move(instr), unop, {operand});
- }());
+ });
}
XlaOp XlaBuilder::BinaryOp(
@@ -297,7 +332,7 @@ XlaOp XlaBuilder::BinaryOp(
}
return AddInstruction(std::move(instr), binop, {updated_lhs, updated_rhs});
- }());
+ });
}
XlaOp XlaBuilder::TernaryOp(HloOpcode triop, const XlaOp& lhs, const XlaOp& rhs,
@@ -335,7 +370,7 @@ XlaOp XlaBuilder::TernaryOp(HloOpcode triop, const XlaOp& lhs, const XlaOp& rhs,
}
return AddInstruction(std::move(instr), triop,
{updated_lhs, updated_rhs, updated_ehs});
- }());
+ });
}
XlaOp XlaBuilder::Add(const XlaOp& lhs, const XlaOp& rhs,
@@ -354,7 +389,7 @@ XlaOp XlaBuilder::ConstantLiteral(const Literal& literal) {
*instr.mutable_shape() = literal.shape();
*instr.mutable_literal() = literal.ToProto();
return AddInstruction(std::move(instr), HloOpcode::kConstant);
- }());
+ });
}
XlaOp XlaBuilder::Call(const XlaComputation& computation,
@@ -362,11 +397,7 @@ XlaOp XlaBuilder::Call(const XlaComputation& computation,
return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
std::vector<const Shape*> operand_shape_ptrs;
- std::vector<Shape> operand_shapes;
- for (const auto& operand : operands) {
- TF_ASSIGN_OR_RETURN(const Shape& shape, operand.GetShape());
- operand_shapes.push_back(shape);
- }
+ TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands));
c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
[](const Shape& shape) { return &shape; });
TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape,
@@ -376,15 +407,10 @@ XlaOp XlaBuilder::Call(const XlaComputation& computation,
ShapeInference::InferCallShape(operand_shape_ptrs,
/*to_apply=*/called_program_shape));
- // Add called computation.
- instr.add_called_computation_ids(
- computation.proto().entry_computation_id());
- for (const HloComputationProto& e : computation.proto().computations()) {
- embedded_.insert({e.id(), e});
- }
+ AddCalledComputation(computation, &instr);
return AddInstruction(std::move(instr), HloOpcode::kCall, operands);
- }());
+ });
}
XlaOp XlaBuilder::Parameter(int64 parameter_number, const Shape& shape,
@@ -400,7 +426,7 @@ XlaOp XlaBuilder::Parameter(int64 parameter_number, const Shape& shape,
instr.set_name(name);
*instr.mutable_shape() = shape;
return AddInstruction(std::move(instr), HloOpcode::kParameter);
- }());
+ });
}
XlaOp XlaBuilder::Broadcast(
@@ -424,10 +450,12 @@ XlaOp XlaBuilder::Broadcast(
dimensions[i] = i + ShapeUtil::Rank(shape) - operand_rank;
}
return InDimBroadcast(shape, operand, dimensions);
- }());
+ });
}
StatusOr<XlaOp> XlaBuilder::Reshape(const Shape& shape, const XlaOp& operand) {
+ TF_RETURN_IF_ERROR(first_error_);
+
HloInstructionProto instr;
*instr.mutable_shape() = shape;
return AddInstruction(std::move(instr), HloOpcode::kReshape, {operand});
@@ -437,7 +465,22 @@ XlaOp XlaBuilder::Slice(const XlaOp& operand,
tensorflow::gtl::ArraySlice<int64> start_indices,
tensorflow::gtl::ArraySlice<int64> limit_indices,
tensorflow::gtl::ArraySlice<int64> strides) {
- return UnimplementedOp();
+ return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ HloInstructionProto instr;
+ TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
+ TF_ASSIGN_OR_RETURN(
+ *instr.mutable_shape(),
+ ShapeInference::InferSliceShape(operand_shape, start_indices,
+ limit_indices, strides));
+ for (int i = 0; i < start_indices.size(); i++) {
+ auto* slice_config = instr.add_slice_dimensions();
+ slice_config->set_start(start_indices[i]);
+ slice_config->set_limit(limit_indices[i]);
+ slice_config->set_stride(strides[i]);
+ }
+
+ return AddInstruction(std::move(instr), HloOpcode::kSlice, {operand});
+ });
}
XlaOp XlaBuilder::SliceInDim(const XlaOp& operand, int64 start_index,
@@ -447,17 +490,60 @@ XlaOp XlaBuilder::SliceInDim(const XlaOp& operand, int64 start_index,
XlaOp XlaBuilder::DynamicSlice(const XlaOp& operand, const XlaOp& start_indices,
tensorflow::gtl::ArraySlice<int64> slice_sizes) {
- return UnimplementedOp();
+ return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ HloInstructionProto instr;
+
+ TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
+ TF_ASSIGN_OR_RETURN(const Shape& start_indices_shape,
+ GetShape(start_indices));
+ TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
+ ShapeInference::InferDynamicSliceShape(
+ operand_shape, start_indices_shape, slice_sizes));
+
+ for (int64 size : slice_sizes) {
+ instr.add_dynamic_slice_sizes(size);
+ }
+
+ return AddInstruction(std::move(instr), HloOpcode::kDynamicSlice,
+ {operand, start_indices});
+ });
}
XlaOp XlaBuilder::DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update,
const XlaOp& start_indices) {
- return UnimplementedOp();
+ return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ HloInstructionProto instr;
+
+ TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
+ TF_ASSIGN_OR_RETURN(const Shape& update_shape, GetShape(update));
+ TF_ASSIGN_OR_RETURN(const Shape& start_indices_shape,
+ GetShape(start_indices));
+ TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
+ ShapeInference::InferDynamicUpdateSliceShape(
+ operand_shape, update_shape, start_indices_shape));
+
+ return AddInstruction(std::move(instr), HloOpcode::kDynamicUpdateSlice,
+ {operand, update, start_indices});
+ });
}
XlaOp XlaBuilder::ConcatInDim(tensorflow::gtl::ArraySlice<XlaOp> operands,
int64 dimension) {
- return UnimplementedOp();
+ return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ HloInstructionProto instr;
+
+ std::vector<const Shape*> operand_shape_ptrs;
+ TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands));
+ c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
+ [](const Shape& shape) { return &shape; });
+ TF_ASSIGN_OR_RETURN(
+ *instr.mutable_shape(),
+ ShapeInference::InferConcatOpShape(operand_shape_ptrs, dimension));
+
+ instr.add_dimensions(dimension);
+
+ return AddInstruction(std::move(instr), HloOpcode::kConcatenate, operands);
+ });
}
XlaOp XlaBuilder::Pad(const XlaOp& operand, const XlaOp& padding_value,
@@ -477,7 +563,7 @@ XlaOp XlaBuilder::Reshape(const XlaOp& operand,
? operand
: Transpose(operand, dimensions);
return Reshape(shape, transposed);
- }());
+ });
}
XlaOp XlaBuilder::Reshape(const XlaOp& operand,
@@ -487,7 +573,7 @@ XlaOp XlaBuilder::Reshape(const XlaOp& operand,
std::vector<int64> dimensions(shape.dimensions_size());
std::iota(dimensions.begin(), dimensions.end(), 0);
return Reshape(operand, dimensions, new_sizes);
- }());
+ });
}
XlaOp XlaBuilder::Collapse(const XlaOp& operand,
@@ -496,7 +582,12 @@ XlaOp XlaBuilder::Collapse(const XlaOp& operand,
}
void XlaBuilder::Trace(const string& tag, const XlaOp& operand) {
- UnimplementedOp();
+ NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ HloInstructionProto instr;
+ *instr.mutable_shape() = ShapeUtil::MakeNil();
+ *instr.mutable_literal() = Literal::CreateR1U8(tag)->ToProto();
+ return AddInstruction(std::move(instr), HloOpcode::kTrace, {operand});
+ });
}
XlaOp XlaBuilder::Select(const XlaOp& pred, const XlaOp& on_true,
@@ -508,18 +599,14 @@ XlaOp XlaBuilder::Tuple(tensorflow::gtl::ArraySlice<XlaOp> elements) {
return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
std::vector<const Shape*> operand_shape_ptrs;
- std::vector<Shape> operand_shapes;
- for (const XlaOp& e : elements) {
- TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(e));
- operand_shapes.push_back(shape);
- }
+ TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(elements));
c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
[](const Shape& shape) { return &shape; });
TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
ShapeInference::InferVariadicOpShape(
HloOpcode::kTuple, operand_shape_ptrs));
return AddInstruction(std::move(instr), HloOpcode::kTuple, elements);
- }());
+ });
}
XlaOp XlaBuilder::GetTupleElement(const XlaOp& tuple_data, int64 index) {
@@ -538,7 +625,7 @@ XlaOp XlaBuilder::GetTupleElement(const XlaOp& tuple_data, int64 index) {
return AddInstruction(std::move(instr), HloOpcode::kGetTupleElement,
{tuple_data});
- }());
+ });
}
XlaOp XlaBuilder::Eq(const XlaOp& lhs, const XlaOp& rhs,
@@ -572,12 +659,29 @@ XlaOp XlaBuilder::Lt(const XlaOp& lhs, const XlaOp& rhs,
}
XlaOp XlaBuilder::Dot(const XlaOp& lhs, const XlaOp& rhs) {
- return UnimplementedOp();
+ return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs));
+
+ DotDimensionNumbers dimension_numbers;
+ dimension_numbers.add_lhs_contracting_dimensions(
+ lhs_shape.dimensions_size() == 1 ? 0 : 1);
+ dimension_numbers.add_rhs_contracting_dimensions(0);
+ return DotGeneral(lhs, rhs, dimension_numbers);
+ });
}
XlaOp XlaBuilder::DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
const DotDimensionNumbers& dimension_numbers) {
- return UnimplementedOp();
+ return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ HloInstructionProto instr;
+ TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs));
+ TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs));
+ TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
+ ShapeInference::InferDotOpShape(lhs_shape, rhs_shape,
+ dimension_numbers));
+ *instr.mutable_dot_dimension_numbers() = dimension_numbers;
+ return AddInstruction(std::move(instr), HloOpcode::kDot, {lhs, rhs});
+ });
}
XlaOp XlaBuilder::Conv(const XlaOp& lhs, const XlaOp& rhs,
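
Dot is now sugar for DotGeneral: it contracts lhs dimension 1 (or dimension 0
for a rank-1 lhs) against rhs dimension 0, which for rank-2 operands is
ordinary matrix multiplication. Equivalently, assuming existing lhs and rhs
ops of shape [m, k] and [k, n] on a builder:

    // The dimension numbers Dot() fills in before delegating to DotGeneral;
    // the product has shape [m, n].
    DotDimensionNumbers dnums;
    dnums.add_lhs_contracting_dimensions(1);
    dnums.add_rhs_contracting_dimensions(0);
    auto product = builder.DotGeneral(lhs, rhs, dnums);
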
@@ -788,7 +892,7 @@ XlaOp XlaBuilder::Transpose(const XlaOp& operand,
instr.add_dimensions(dim);
}
return AddInstruction(std::move(instr), HloOpcode::kTranspose, {operand});
- }());
+ });
}
XlaOp XlaBuilder::Rev(const XlaOp& operand,
@@ -812,7 +916,14 @@ XlaOp XlaBuilder::Pow(const XlaOp& lhs, const XlaOp& rhs,
XlaOp XlaBuilder::ConvertElementType(const XlaOp& operand,
PrimitiveType new_element_type) {
- return UnimplementedOp();
+ return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ HloInstructionProto instr;
+ TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
+ TF_ASSIGN_OR_RETURN(
+ *instr.mutable_shape(),
+ ShapeInference::InferConvertShape(operand_shape, new_element_type));
+ return AddInstruction(std::move(instr), HloOpcode::kConvert, {operand});
+ });
}
XlaOp XlaBuilder::BitcastConvertType(const XlaOp& operand,
@@ -846,19 +957,64 @@ XlaOp XlaBuilder::Map(tensorflow::gtl::ArraySlice<XlaOp> operands,
return UnimplementedOp();
}
+XlaOp XlaBuilder::RngOp(RandomDistribution distribution,
+ tensorflow::gtl::ArraySlice<XlaOp> parameters,
+ const Shape& shape) {
+ return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ HloInstructionProto instr;
+
+ // Check the number of parameters per RNG distribution.
+ switch (distribution) {
+ case RandomDistribution::RNG_NORMAL:
+ case RandomDistribution::RNG_UNIFORM:
+ if (parameters.size() != 2) {
+ return InvalidArgument(
+ "RNG distribution (%s) expects 2 parameters, but got %ld",
+ RandomDistribution_Name(distribution).c_str(), parameters.size());
+ }
+ break;
+ default:
+ LOG(FATAL) << "unhandled distribution " << distribution;
+ }
+
+ TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(shape));
+ *instr.mutable_shape() = shape;
+
+ instr.set_distribution(distribution);
+
+ return AddInstruction(std::move(instr), HloOpcode::kRng, parameters);
+ });
+}
+
XlaOp XlaBuilder::RngNormal(const XlaOp& mu, const XlaOp& sigma,
const Shape& shape) {
- return UnimplementedOp();
+ return RngOp(RandomDistribution::RNG_NORMAL, {mu, sigma}, shape);
}
XlaOp XlaBuilder::RngUniform(const XlaOp& a, const XlaOp& b,
const Shape& shape) {
- return UnimplementedOp();
+ return RngOp(RandomDistribution::RNG_UNIFORM, {a, b}, shape);
}
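
Both RNG entry points now funnel through RngOp, which checks that the normal
and uniform distributions receive exactly two parameters and validates the
requested output shape; a wrong parameter count is recorded on the builder and
surfaces at Build(). A sketch, assuming an existing XlaBuilder b:

    // N(0, 1) noise and U[0, 1) samples, each of shape f32[128].
    auto noise = b.RngNormal(b.ConstantR0<float>(0.0f),
                             b.ConstantR0<float>(1.0f),
                             ShapeUtil::MakeShape(F32, {128}));
    auto uniform = b.RngUniform(b.ConstantR0<float>(0.0f),
                                b.ConstantR0<float>(1.0f),
                                ShapeUtil::MakeShape(F32, {128}));
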
XlaOp XlaBuilder::While(const XlaComputation& condition,
const XlaComputation& body, const XlaOp& init) {
- return UnimplementedOp();
+ return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ HloInstructionProto instr;
+
+ // Infer shape.
+ TF_ASSIGN_OR_RETURN(const auto& body_program_shape, body.GetProgramShape());
+ TF_ASSIGN_OR_RETURN(const auto& condition_program_shape,
+ condition.GetProgramShape());
+ TF_ASSIGN_OR_RETURN(const Shape& init_shape, GetShape(init));
+ TF_ASSIGN_OR_RETURN(
+ *instr.mutable_shape(),
+ ShapeInference::InferWhileShape(condition_program_shape,
+ body_program_shape, init_shape));
+ // Body comes before condition computation in the vector.
+ AddCalledComputation(body, &instr);
+ AddCalledComputation(condition, &instr);
+ return AddInstruction(std::move(instr), HloOpcode::kWhile, {init});
+ });
}
XlaOp XlaBuilder::Gather(const XlaOp& input, const XlaOp& gather_indices,
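
The While lowering infers the loop shape from the condition and body program
shapes and records the body before the condition among the called
computations, per the comment above. The calling convention, assuming cond and
body are XlaComputations built elsewhere with signatures (s32[]) -> pred[] and
(s32[]) -> s32[]:

    XlaBuilder b("while_example");
    auto init = b.ConstantR0<int32>(0);
    // The result carries the final loop state and has init's shape.
    auto result = b.While(cond, body, init);
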
@@ -878,7 +1034,27 @@ XlaOp XlaBuilder::Reduce(
const XlaOp& operand, const XlaOp& init_value,
const XlaComputation& computation,
tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce) {
- return UnimplementedOp();
+ return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ HloInstructionProto instr;
+
+ TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
+ TF_ASSIGN_OR_RETURN(const Shape& init_shape, GetShape(init_value));
+ TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape,
+ computation.GetProgramShape());
+ TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
+ ShapeInference::InferReduceShape(
+ operand_shape, init_shape, dimensions_to_reduce,
+ called_program_shape));
+
+ for (int64 dim : dimensions_to_reduce) {
+ instr.add_dimensions(dim);
+ }
+
+ AddCalledComputation(computation, &instr);
+
+ return AddInstruction(std::move(instr), HloOpcode::kReduce,
+ {operand, init_value});
+ });
}
XlaOp XlaBuilder::ReduceAll(const XlaOp& operand, const XlaOp& init_value,
@@ -952,11 +1128,43 @@ XlaOp XlaBuilder::ReducePrecision(const XlaOp& operand, const int exponent_bits,
}
void XlaBuilder::Send(const XlaOp& operand, const ChannelHandle& handle) {
- UnimplementedOp();
+ NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ HloInstructionProto instr;
+
+ // Send instruction produces a tuple of {aliased operand, U32 context}.
+ TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand));
+ *instr.mutable_shape() =
+ ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {})});
+ instr.set_channel_id(handle.handle());
+ TF_ASSIGN_OR_RETURN(
+ XlaOp send,
+ AddInstruction(std::move(instr), HloOpcode::kSend, {operand}));
+
+ HloInstructionProto send_done_instr;
+ *send_done_instr.mutable_shape() = ShapeUtil::MakeNil();
+ send_done_instr.set_channel_id(handle.handle());
+ return AddInstruction(std::move(send_done_instr), HloOpcode::kSendDone,
+ {send});
+ });
}
XlaOp XlaBuilder::Recv(const Shape& shape, const ChannelHandle& handle) {
- return UnimplementedOp();
+ return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ HloInstructionProto instr;
+
+ // Recv instruction produces a tuple of {receive buffer, U32 context}.
+ *instr.mutable_shape() =
+ ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {})});
+ instr.set_channel_id(handle.handle());
+ TF_ASSIGN_OR_RETURN(XlaOp recv,
+ AddInstruction(std::move(instr), HloOpcode::kRecv, {}));
+
+ HloInstructionProto recv_done_instr;
+ *recv_done_instr.mutable_shape() = shape;
+ recv_done_instr.set_channel_id(handle.handle());
+ return AddInstruction(std::move(recv_done_instr), HloOpcode::kRecvDone,
+ {recv});
+ });
}
StatusOr<bool> XlaBuilder::IsConstant(const XlaOp& operand,
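
Send and Recv each now lower to an instruction pair, kSend/kSendDone and
kRecv/kRecvDone: the first half produces a {data, U32 context} tuple and the
Done half consumes it, so only the completed operation is visible to callers
(and kSendDone joins kSend on the CanBeRoot blacklist earlier in this file).
A sketch, assuming an existing builder b, a data op, and a ChannelHandle
obtained from the client:

    // Producer side: enqueues kSend feeding kSendDone, returns nothing.
    b.Send(data, channel);
    // Consumer side: the result op has the payload shape, here f32[8].
    auto received = b.Recv(ShapeUtil::MakeShape(F32, {8}), channel);
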
@@ -1055,20 +1263,27 @@ XlaBuilder::CreateDefaultConvDimensionNumbers(int num_spatial_dims) {
StatusOr<XlaOp> XlaBuilder::AddInstruction(
HloInstructionProto&& instr, HloOpcode opcode,
tensorflow::gtl::ArraySlice<XlaOp> operands) {
+ TF_RETURN_IF_ERROR(first_error_);
+
const int64 handle = instructions_.size();
instr.set_id(handle);
instr.set_opcode(HloOpcodeString(opcode));
if (instr.name().empty()) {
- instr.set_name(StrCat(instr.opcode(), ".", handle));
+ instr.set_name(StrCat(instr.opcode(), ".", unique_id_, ".", handle));
} else {
// Append the handle to make sure the name is unique.
- instr.set_name(StrCat(instr.name(), ".", handle));
+ instr.set_name(StrCat(instr.name(), ".", unique_id_, ".", handle));
}
for (const auto& operand : operands) {
- TF_RET_CHECK(operand.builder_ != nullptr);
- TF_RET_CHECK(operand.builder_ == this)
- << "Do not add XlaOp from builder " << operand.builder_->name()
- << " to builder " << this->name();
+ if (operand.builder_ == nullptr) {
+ return InvalidArgument("invalid XlaOp with handle %lld",
+ operand.handle());
+ }
+ if (operand.builder_ != this) {
+ return InvalidArgument("Do not add XlaOp from builder %s to builder %s",
+ operand.builder_->name().c_str(),
+ this->name().c_str());
+ }
instr.add_operand_ids(operand.handle());
}
@@ -1083,8 +1298,22 @@ StatusOr<XlaOp> XlaBuilder::AddInstruction(
return op;
}
+void XlaBuilder::AddCalledComputation(const XlaComputation& computation,
+ HloInstructionProto* instr) {
+ instr->add_called_computation_ids(computation.proto().entry_computation_id());
+ for (const HloComputationProto& e : computation.proto().computations()) {
+ embedded_.insert({e.id(), e});
+ }
+}
+
StatusOr<const HloInstructionProto*> XlaBuilder::LookUpInstruction(
const XlaOp& op) const {
+ TF_RETURN_IF_ERROR(first_error_);
+
+ if (op.builder_ != this) {
+ return InvalidArgument("invalid XlaOp with handle %lld", op.handle());
+ }
+
TF_RET_CHECK(op.builder_ == this);
if (op.handle() >= instructions_.size() || op.handle() < 0) {
return InvalidArgument("no XlaOp value %lld", op.handle());
diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.h b/tensorflow/compiler/xla/client/xla_client/xla_builder.h
index f66feb93ce..0673b86646 100644
--- a/tensorflow/compiler/xla/client/xla_client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.h
@@ -803,19 +803,16 @@ class XlaBuilder {
HloInstructionProto&& instr, HloOpcode opcode,
tensorflow::gtl::ArraySlice<XlaOp> operands = {});
+ void AddCalledComputation(const XlaComputation& computation,
+ HloInstructionProto* instr);
+
// Notes that the error occurred by:
// * storing it internally and capturing a backtrace if it's the first error
// (this deferred value will be produced on the call to Build())
// * dying if die_immediately_on_error_ is true
void NoteError(const Status& error);
- XlaOp NoteErrorOrReturn(StatusOr<XlaOp>&& op) {
- if (!op.ok()) {
- NoteError(op.status());
- return XlaOp();
- }
- return op.ConsumeValueOrDie();
- }
+ XlaOp NoteErrorOrReturn(const std::function<StatusOr<XlaOp>()>& op_creator);
// Helper method that creates an empty op and notes error.
XlaOp UnimplementedOp();
@@ -835,6 +832,10 @@ class XlaBuilder {
XlaOp TernaryOp(HloOpcode triop, const XlaOp& lhs, const XlaOp& rhs,
const XlaOp& ehs);
+ XlaOp RngOp(RandomDistribution distribution,
+ tensorflow::gtl::ArraySlice<XlaOp> parameters,
+ const Shape& shape);
+
StatusOr<XlaOp> InDimBroadcast(
const Shape& shape, const XlaOp& operand,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
@@ -852,7 +853,8 @@ class XlaBuilder {
// computation and fills the root_id in the pointer.
StatusOr<ProgramShape> GetProgramShape(int64* root_id);
- string name_; // Name to use for the built computation.
+ string name_; // Name to use for the built computation.
+ int64 unique_id_; // The unique id for the built computation.
// The first error encountered while building the computation.
// This is OK until the first error is encountered.
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc
index b21ab3044f..2bacc6a914 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.cc
+++ b/tensorflow/compiler/xla/python/local_computation_builder.cc
@@ -521,6 +521,17 @@ ComputationDataHandle LocalComputationBuilder::Conditional(
false_computation.computation());
}
+StatusOr<bool> LocalComputationBuilder::IsConstant(
+ const ComputationDataHandle& operand, int64 num_parameters) {
+ return builder_.IsConstant(operand, num_parameters);
+}
+
+StatusOr<std::unique_ptr<Literal>> LocalComputationBuilder::ComputeConstant(
+ const ComputationDataHandle& operand, const Layout* output_layout,
+ tensorflow::gtl::ArraySlice<Literal> parameters) {
+ return builder_.ComputeConstant(operand, output_layout, parameters);
+}
+
#define _FORWARD(method_name, return_sig, args_sig, args) \
return_sig LocalComputationBuilder::method_name args_sig { \
return builder_.method_name args; \
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h
index a7375c8965..31046e60f1 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.h
+++ b/tensorflow/compiler/xla/python/local_computation_builder.h
@@ -268,6 +268,13 @@ class LocalComputationBuilder {
const ComputationDataHandle& false_operand,
const LocalComputation& false_computation);
+ StatusOr<bool> IsConstant(const ComputationDataHandle& operand,
+ int64 num_parameters);
+
+ StatusOr<std::unique_ptr<Literal> > ComputeConstant(
+ const ComputationDataHandle& operand, const Layout* output_layout,
+ tensorflow::gtl::ArraySlice<Literal> parameters);
+
#define _FORWARD(method_name, return_sig, args_sig) \
return_sig method_name args_sig;
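
IsConstant and ComputeConstant are written out by hand here, while _FORWARD
stamps out one-line delegates onto the wrapped builder_. For a hypothetical
forwarded unary op, _FORWARD(Neg, ComputationDataHandle,
(const ComputationDataHandle& operand), (operand)) would expand in the .cc
file to:

    ComputationDataHandle LocalComputationBuilder::Neg(
        const ComputationDataHandle& operand) {
      return builder_.Neg(operand);
    }
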
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i
index 8f231d1a12..ac792e8189 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.i
+++ b/tensorflow/compiler/xla/python/local_computation_builder.i
@@ -182,7 +182,7 @@ tensorflow::ImportNumpy();
%typemap(in) const ComputationDataHandle& (ComputationDataHandle temp) {
const int64 handle = numpy::PyIntOrPyLongToLong($input);
if (handle == -1 && PyErr_Occurred()) {
- return NULL;
+ SWIG_fail;
}
temp.set_handle(handle);
$1 = &temp;
@@ -201,7 +201,7 @@ tensorflow::ImportNumpy();
}
} else {
PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
- return NULL;
+ SWIG_fail;
}
}
@@ -211,7 +211,7 @@ tensorflow::ImportNumpy();
$result = numpy::PyObjectFromXlaLiteral(*value);
} else {
PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
- return NULL;
+ SWIG_fail;
}
}
@@ -224,7 +224,7 @@ tensorflow::ImportNumpy();
}
} else {
PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
- return NULL;
+ SWIG_fail;
}
}
@@ -233,7 +233,16 @@ tensorflow::ImportNumpy();
$result = numpy::PyShapeInfoFromXlaShape($1.ConsumeValueOrDie());
} else {
PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
- return NULL;
+ SWIG_fail;
+ }
+}
+
+%typemap(out) StatusOr<bool> {
+ if ($1.ok()) {
+ $result = PyBool_FromLong($1.ConsumeValueOrDie());
+ } else {
+ PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
+ SWIG_fail;
}
}
@@ -241,7 +250,7 @@ tensorflow::ImportNumpy();
if (!$1.ok()) {
PyErr_SetString(
PyExc_RuntimeError, $1.ToString().c_str());
- return NULL;
+ SWIG_fail;
}
Py_INCREF(Py_None);
$result = Py_None;
@@ -253,7 +262,7 @@ tensorflow::ImportNumpy();
(std::vector<int64> temps) {
if (!PySequence_Check($input)) {
PyErr_SetString(PyExc_TypeError, "Argument is not a sequence");
- return NULL;
+ SWIG_fail;
}
const int size = PySequence_Size($input);
temps.resize(size);
@@ -265,13 +274,13 @@ tensorflow::ImportNumpy();
PyExc_TypeError,
"Argument sequence element cannot be converted to int");
Py_DECREF(o);
- return NULL;
+ SWIG_fail;
}
temps[i] = numpy::PyIntOrPyLongToLong(py_int);
if (temps[i] == -1 && PyErr_Occurred()) {
Py_DECREF(py_int);
Py_DECREF(o);
- return NULL;
+ SWIG_fail;
}
Py_DECREF(py_int);
Py_DECREF(o);
@@ -285,7 +294,7 @@ tensorflow::ImportNumpy();
(std::vector<ComputationDataHandle> temps) {
if (!PySequence_Check($input)) {
PyErr_SetString(PyExc_TypeError, "Argument is not a sequence");
- return NULL;
+ SWIG_fail;
}
const int size = PySequence_Size($input);
temps.resize(size);
@@ -296,13 +305,13 @@ tensorflow::ImportNumpy();
PyErr_SetString(
PyExc_TypeError,
"Argument sequence element cannot be converted to int");
- return NULL;
+ SWIG_fail;
}
const int64 handle = numpy::PyIntOrPyLongToLong(py_int);
if (handle == -1 && PyErr_Occurred()) {
Py_DECREF(py_int);
Py_DECREF(o);
- return NULL;
+ SWIG_fail;
}
temps[i].set_handle(handle);
Py_DECREF(py_int);
@@ -317,7 +326,7 @@ tensorflow::ImportNumpy();
(std::vector<LocalShapedBuffer*> temps) {
if (!PySequence_Check($input)) {
PyErr_SetString(PyExc_TypeError, "Argument is not a sequence");
- return NULL;
+ SWIG_fail;
}
const int size = PySequence_Size($input);
temps.reserve(size);
@@ -326,7 +335,7 @@ tensorflow::ImportNumpy();
LocalShapedBuffer* lsbp;
if ((SWIG_ConvertPtr(o, (void**) &lsbp, $descriptor(xla::swig::LocalShapedBuffer*),
SWIG_POINTER_EXCEPTION)) == -1) {
- return NULL;
+ SWIG_fail;
}
temps.push_back(lsbp);
Py_DECREF(o);
@@ -340,7 +349,7 @@ tensorflow::ImportNumpy();
literal_status = numpy::XlaLiteralFromPyObject($input);
if (!literal_status.ok()) {
PyErr_SetString(PyExc_RuntimeError, literal_status.status().ToString().c_str());
- return NULL;
+ SWIG_fail;
}
$1 = literal_status.ValueOrDie().get();
}
@@ -352,7 +361,7 @@ tensorflow::ImportNumpy();
%typemap(out) StatusOr< std::unique_ptr<Literal> > {
if (!$1.ok()) {
PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
- return NULL;
+ SWIG_fail;
}
$result = numpy::PyObjectFromXlaLiteral(*$1.ValueOrDie());
}
@@ -360,7 +369,7 @@ tensorflow::ImportNumpy();
%typemap(in) const std::vector<Literal>& (std::vector<Literal> temps) {
if (!PySequence_Check($input)) {
PyErr_SetString(PyExc_TypeError, "Argument is not a sequence");
- return NULL;
+ SWIG_fail;
}
const int size = PySequence_Size($input);
for (int i = 0; i < size; ++i) {
@@ -369,7 +378,7 @@ tensorflow::ImportNumpy();
if (!literal_status.ok()) {
PyErr_SetString(PyExc_RuntimeError, literal_status.status().ToString().c_str());
Py_DECREF(o);
- return NULL;
+ SWIG_fail;
}
temps.push_back(std::move(*literal_status.ConsumeValueOrDie()));
Py_DECREF(o);
@@ -383,7 +392,7 @@ tensorflow::ImportNumpy();
StatusOr<OpMetadata> statusor = numpy::OpMetadataFromPyObject($input);
if (!statusor.ok()) {
PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str());
- return NULL;
+ SWIG_fail;
}
temp = std::move(statusor).ValueOrDie();
$1 = &temp;
@@ -395,7 +404,7 @@ tensorflow::ImportNumpy();
StatusOr<Shape> statusor = numpy::XlaShapeFromPyShape($input);
if (!statusor.ok()) {
PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str());
- return NULL;
+ SWIG_fail;
}
temp = std::move(statusor).ValueOrDie();
$1 = &temp;
@@ -410,7 +419,7 @@ tensorflow::ImportNumpy();
StatusOr<Shape> statusor = numpy::XlaShapeFromPyShape($input);
if (!statusor.ok()) {
PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str());
- return NULL;
+ SWIG_fail;
}
temp = std::move(statusor).ValueOrDie();
$1 = &temp;
@@ -424,7 +433,7 @@ tensorflow::ImportNumpy();
%typemap(in) const std::vector<Shape>& (std::vector<Shape> temps) {
if (!PySequence_Check($input)) {
PyErr_SetString(PyExc_TypeError, "Argument is not a sequence");
- return NULL;
+ SWIG_fail;
}
const int size = PySequence_Size($input);
for (int i = 0; i < size; ++i) {
@@ -433,7 +442,7 @@ tensorflow::ImportNumpy();
Py_DECREF(o);
if (!statusor.ok()) {
PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str());
- return NULL;
+ SWIG_fail;
}
temps.push_back(statusor.ConsumeValueOrDie());
}
@@ -444,7 +453,7 @@ tensorflow::ImportNumpy();
std::vector<tensorflow::gtl::optional<Shape> > temps) {
if (!PySequence_Check($input)) {
PyErr_SetString(PyExc_TypeError, "Argument is not a sequence");
- return NULL;
+ SWIG_fail;
}
const int size = PySequence_Size($input);
for (int i = 0; i < size; ++i) {
@@ -456,7 +465,7 @@ tensorflow::ImportNumpy();
Py_DECREF(o);
if (!statusor.ok()) {
PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str());
- return NULL;
+ SWIG_fail;
}
temps.push_back(statusor.ConsumeValueOrDie());
}
@@ -470,18 +479,18 @@ tensorflow::ImportNumpy();
PyObject* py_int = numpy::PyNumberToPyInt($input);
if (!py_int) {
PyErr_SetString(PyExc_TypeError, "Argument cannot be converted to int");
- return NULL;
+ SWIG_fail;
}
const long value = numpy::PyIntOrPyLongToLong(py_int);
if (value == -1 && PyErr_Occurred()) {
Py_DECREF(py_int);
- return NULL;
+ SWIG_fail;
}
if (!PrimitiveType_IsValid(value)) {
PyErr_SetString(
PyExc_TypeError, "Argument not valid for PrimitiveType enum");
Py_DECREF(py_int);
- return NULL;
+ SWIG_fail;
}
$1 = static_cast<PrimitiveType>(value);
}
@@ -492,19 +501,19 @@ tensorflow::ImportNumpy();
(std::vector<std::pair<int64, int64> > temps) {
if (!PySequence_Check($input)) {
PyErr_SetString(PyExc_TypeError, "Argument is not a sequence");
- return NULL;
+ SWIG_fail;
}
const int size = PySequence_Size($input);
temps.reserve(size);
for (int i = 0; i < size; ++i) {
PyObject* o = PySequence_GetItem($input, i);
if (!o) {
- return NULL;
+ SWIG_fail;
}
PyObject* first = PyTuple_GetItem(o, 0);
if (!first) {
Py_DECREF(o);
- return NULL;
+ SWIG_fail;
}
PyObject* first_pyint = numpy::PyNumberToPyInt(first);
if (!first_pyint) {
@@ -512,13 +521,13 @@ tensorflow::ImportNumpy();
PyExc_TypeError,
"First pair item cannot be converted to int");
Py_DECREF(o);
- return NULL;
+ SWIG_fail;
}
PyObject* second = PyTuple_GetItem(o, 1);
if (!second) {
Py_DECREF(o);
Py_DECREF(first_pyint);
- return NULL;
+ SWIG_fail;
}
PyObject* second_pyint = numpy::PyNumberToPyInt(second);
if (!second_pyint) {
@@ -527,21 +536,21 @@ tensorflow::ImportNumpy();
"Second pair item cannot be converted to int");
Py_DECREF(o);
Py_DECREF(first_pyint);
- return NULL;
+ SWIG_fail;
}
const int64 first_value = numpy::PyIntOrPyLongToLong(first_pyint);
if (first_value == -1 && PyErr_Occurred()) {
Py_DECREF(o);
Py_DECREF(first_pyint);
Py_DECREF(second_pyint);
- return NULL;
+ SWIG_fail;
}
const int64 second_value = numpy::PyIntOrPyLongToLong(second_pyint);
if (second_value == -1 && PyErr_Occurred()) {
Py_DECREF(o);
Py_DECREF(first_pyint);
Py_DECREF(second_pyint);
- return NULL;
+ SWIG_fail;
}
temps.push_back(std::make_pair(first_value, second_value));
Py_DECREF(o);
@@ -559,26 +568,26 @@ tensorflow::ImportNumpy();
PyObject* lhs_contracting_dimensions = PyObject_GetAttrString(
$input, "lhs_contracting_dimensions");
if (!lhs_contracting_dimensions) {
- return NULL;
+ SWIG_fail;
}
length = PySequence_Size(lhs_contracting_dimensions);
if (length == -1) {
Py_DECREF(lhs_contracting_dimensions);
- return NULL;
+ SWIG_fail;
}
for (int i = 0; i < length; ++i) {
PyObject* item = PySequence_GetItem(lhs_contracting_dimensions, i);
if (!item) {
Py_DECREF(lhs_contracting_dimensions);
- return NULL;
+ SWIG_fail;
}
const int64 dimension = numpy::PyIntOrPyLongToLong(item);
if (dimension == -1 && PyErr_Occurred()) {
Py_DECREF(item);
Py_DECREF(lhs_contracting_dimensions);
- return NULL;
+ SWIG_fail;
}
dimension_numbers.add_lhs_contracting_dimensions(dimension);
Py_DECREF(item);
@@ -589,26 +598,26 @@ tensorflow::ImportNumpy();
PyObject* rhs_contracting_dimensions = PyObject_GetAttrString(
$input, "rhs_contracting_dimensions");
if (!rhs_contracting_dimensions) {
- return NULL;
+ SWIG_fail;
}
length = PySequence_Size(rhs_contracting_dimensions);
if (length == -1) {
Py_DECREF(rhs_contracting_dimensions);
- return NULL;
+ SWIG_fail;
}
for (int i = 0; i < length; ++i) {
PyObject* item = PySequence_GetItem(rhs_contracting_dimensions, i);
if (!item) {
Py_DECREF(rhs_contracting_dimensions);
- return NULL;
+ SWIG_fail;
}
const int64 dimension = numpy::PyIntOrPyLongToLong(item);
if (dimension == -1 && PyErr_Occurred()) {
Py_DECREF(item);
Py_DECREF(rhs_contracting_dimensions);
- return NULL;
+ SWIG_fail;
}
dimension_numbers.add_rhs_contracting_dimensions(dimension);
Py_DECREF(item);
@@ -619,26 +628,26 @@ tensorflow::ImportNumpy();
PyObject* lhs_batch_dimensions = PyObject_GetAttrString(
$input, "lhs_batch_dimensions");
if (!lhs_batch_dimensions) {
- return NULL;
+ SWIG_fail;
}
length = PySequence_Size(lhs_batch_dimensions);
if (length == -1) {
Py_DECREF(lhs_batch_dimensions);
- return NULL;
+ SWIG_fail;
}
for (int i = 0; i < length; ++i) {
PyObject* item = PySequence_GetItem(lhs_batch_dimensions, i);
if (!item) {
Py_DECREF(lhs_batch_dimensions);
- return NULL;
+ SWIG_fail;
}
const int64 dimension = numpy::PyIntOrPyLongToLong(item);
if (dimension == -1 && PyErr_Occurred()) {
Py_DECREF(item);
Py_DECREF(lhs_batch_dimensions);
- return NULL;
+ SWIG_fail;
}
dimension_numbers.add_lhs_batch_dimensions(dimension);
Py_DECREF(item);
@@ -649,26 +658,26 @@ tensorflow::ImportNumpy();
PyObject* rhs_batch_dimensions = PyObject_GetAttrString(
$input, "rhs_batch_dimensions");
if (!rhs_batch_dimensions) {
- return NULL;
+ SWIG_fail;
}
length = PySequence_Size(rhs_batch_dimensions);
if (length == -1) {
Py_DECREF(rhs_batch_dimensions);
- return NULL;
+ SWIG_fail;
}
for (int i = 0; i < length; ++i) {
PyObject* item = PySequence_GetItem(rhs_batch_dimensions, i);
if (!item) {
Py_DECREF(rhs_batch_dimensions);
- return NULL;
+ SWIG_fail;
}
const int64 dimension = numpy::PyIntOrPyLongToLong(item);
if (dimension == -1 && PyErr_Occurred()) {
Py_DECREF(item);
Py_DECREF(rhs_batch_dimensions);
- return NULL;
+ SWIG_fail;
}
dimension_numbers.add_rhs_batch_dimensions(dimension);
Py_DECREF(item);
@@ -684,20 +693,20 @@ tensorflow::ImportNumpy();
(PaddingConfig padding_config) {
PyObject* dimensions = PyObject_GetAttrString($input, "dimensions");
if (!dimensions) {
- return NULL;
+ SWIG_fail;
}
int length = PySequence_Size(dimensions);
if (length == -1) {
Py_DECREF(dimensions);
- return NULL;
+ SWIG_fail;
}
for (int i = 0; i < length; ++i) {
PyObject* item = PySequence_GetItem(dimensions, i);
if (!item) {
Py_DECREF(dimensions);
- return NULL;
+ SWIG_fail;
}
int64 edge_padding_low, edge_padding_high, interior_padding;
if (!GetIntAttr(item, "edge_padding_low", &edge_padding_low)
@@ -705,7 +714,7 @@ tensorflow::ImportNumpy();
|| !GetIntAttr(item, "interior_padding", &interior_padding)) {
Py_DECREF(item);
Py_DECREF(dimensions);
- return NULL;
+ SWIG_fail;
}
Py_DECREF(item);
@@ -727,32 +736,32 @@ tensorflow::ImportNumpy();
int64 value;
if (!GetIntAttr($input, "input_batch_dimension", &value)) {
- return NULL;
+ SWIG_fail;
}
dimension_numbers.set_input_batch_dimension(value);
if (!GetIntAttr($input, "input_feature_dimension", &value)) {
- return NULL;
+ SWIG_fail;
}
dimension_numbers.set_input_feature_dimension(value);
if (!GetIntAttr($input, "output_batch_dimension", &value)) {
- return NULL;
+ SWIG_fail;
}
dimension_numbers.set_output_batch_dimension(value);
if (!GetIntAttr($input, "output_feature_dimension", &value)) {
- return NULL;
+ SWIG_fail;
}
dimension_numbers.set_output_feature_dimension(value);
if (!GetIntAttr($input, "kernel_output_feature_dimension", &value)) {
- return NULL;
+ SWIG_fail;
}
dimension_numbers.set_kernel_output_feature_dimension(value);
if (!GetIntAttr($input, "kernel_input_feature_dimension", &value)) {
- return NULL;
+ SWIG_fail;
}
dimension_numbers.set_kernel_input_feature_dimension(value);
@@ -761,24 +770,24 @@ tensorflow::ImportNumpy();
o = PyObject_GetAttrString($input, "input_spatial_dimensions");
if (!o) {
- return NULL;
+ SWIG_fail;
}
length = PySequence_Size(o);
if (length == -1) {
Py_DECREF(o);
- return NULL;
+ SWIG_fail;
}
for (int i = 0; i < length; ++i) {
PyObject* item = PySequence_GetItem(o, i);
if (!item) {
Py_DECREF(o);
- return NULL;
+ SWIG_fail;
}
const int64 dimension = numpy::PyIntOrPyLongToLong(item);
if (dimension == -1 && PyErr_Occurred()) {
Py_DECREF(item);
Py_DECREF(o);
- return NULL;
+ SWIG_fail;
}
dimension_numbers.add_input_spatial_dimensions(dimension);
Py_DECREF(item);
@@ -787,24 +796,24 @@ tensorflow::ImportNumpy();
o = PyObject_GetAttrString($input, "kernel_spatial_dimensions");
if (!o) {
- return NULL;
+ SWIG_fail;
}
length = PySequence_Size(o);
if (length == -1) {
Py_DECREF(o);
- return NULL;
+ SWIG_fail;
}
for (int i = 0; i < length; ++i) {
PyObject* item = PySequence_GetItem(o, i);
if (!item) {
Py_DECREF(o);
- return NULL;
+ SWIG_fail;
}
const int64 dimension = numpy::PyIntOrPyLongToLong(item);
if (dimension == -1 && PyErr_Occurred()) {
Py_DECREF(item);
Py_DECREF(o);
- return NULL;
+ SWIG_fail;
}
dimension_numbers.add_kernel_spatial_dimensions(dimension);
Py_DECREF(item);
@@ -813,24 +822,24 @@ tensorflow::ImportNumpy();
o = PyObject_GetAttrString($input, "output_spatial_dimensions");
if (!o) {
- return NULL;
+ SWIG_fail;
}
length = PySequence_Size(o);
if (length == -1) {
Py_DECREF(o);
- return NULL;
+ SWIG_fail;
}
for (int i = 0; i < length; ++i) {
PyObject* item = PySequence_GetItem(o, i);
if (!item) {
Py_DECREF(o);
- return NULL;
+ SWIG_fail;
}
const int64 dimension = numpy::PyIntOrPyLongToLong(item);
if (dimension == -1 && PyErr_Occurred()) {
Py_DECREF(item);
Py_DECREF(o);
- return NULL;
+ SWIG_fail;
}
dimension_numbers.add_output_spatial_dimensions(dimension);
Py_DECREF(item);
@@ -865,12 +874,12 @@ tensorflow::ImportNumpy();
PyObject* o = PyObject_GetAttrString($input, "hlo_profile");
if (o == NULL) {
- return NULL;
+ SWIG_fail;
}
if (o != Py_None) {
if (!PyBool_Check(o)) {
PyErr_SetString(PyExc_TypeError, "ExecutableBuildOptions.hlo_profile must be a bool or None.");
- return NULL;
+ SWIG_fail;
}
build_options.set_hlo_profile(o == Py_True);
}
@@ -885,7 +894,7 @@ tensorflow::ImportNumpy();
if (!statusor.ok()) {
PyErr_SetString(PyExc_TypeError, tensorflow::strings::StrCat("ExecutableBuildOptions.result_shape could not be created from Python shape value: ", statusor.status().ToString()).c_str());
Py_DECREF(o);
- return NULL;
+ SWIG_fail;
}
build_options.set_result_layout(statusor.ValueOrDie());
}
@@ -951,6 +960,7 @@ tensorflow::ImportNumpy();
%unignore xla::swig::LocalComputationBuilder::RngBernoulli;
%unignore xla::swig::LocalComputationBuilder::While;
%unignore xla::swig::LocalComputationBuilder::Conditional;
+%unignore xla::swig::LocalComputationBuilder::IsConstant;
%unignore xla::swig::LocalComputationBuilder::Eq;
%unignore xla::swig::LocalComputationBuilder::Ne;
%unignore xla::swig::LocalComputationBuilder::Ge;
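
The blanket replacement of return NULL with SWIG_fail in these typemaps
matters because a SWIG-generated wrapper ends with a fail: label where
argument-cleanup code is emitted; SWIG_fail is essentially a goto to that
label, so raising a Python error no longer skips the cleanup. The wrapper
shape, schematically rather than as literal SWIG output:

    PyObject* wrapper(PyObject* self, PyObject* args) {
      /* typemap bodies run here; on error they now do `goto fail` */
      return resultobj;
    fail:
      /* freearg/cleanup code emitted by SWIG runs here */
      return NULL;
    }
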
diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py
index e548d420f4..9c81f6439d 100644
--- a/tensorflow/compiler/xla/python/xla_client.py
+++ b/tensorflow/compiler/xla/python/xla_client.py
@@ -1028,6 +1028,20 @@ class ComputationBuilder(object):
_unwrap_data_handle(false_operand),
false_computation.c_local_computation))
+ def IsConstant(self, operand, num_parameters=0):
+ """Enqueues an IsConstant operation onto the computation.
+
+ Args:
+ operand: a ComputationDataHandle to test.
+ num_parameters: optional int, number of computation parameters to treat as
+ constant (default 0).
+
+    Returns:
+      A bool indicating whether `operand` is a compile-time constant, meaning
+      its value does not depend on parameters with index greater than or equal
+      to `num_parameters`.
+ """
+ return self._client.IsConstant(_unwrap_data_handle(operand), num_parameters)
+
def Dot(self, lhs, rhs):
"""Enqueues a dot operation onto the computation.
diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py
index 4c16c1f8b0..d97264ea64 100644
--- a/tensorflow/compiler/xla/python/xla_client_test.py
+++ b/tensorflow/compiler/xla/python/xla_client_test.py
@@ -855,6 +855,17 @@ class SingleOpTest(LocalComputationTest):
self.assertTrue(np.all(lo <= result))
self.assertTrue(np.all(result < hi))
+ def testIsConstant(self):
+ c = self._NewComputation()
+ a = c.ConstantS32Scalar(3)
+ b = c.ConstantS32Scalar(1)
+ x = c.ParameterFromNumpy(NumpyArrayS32(0))
+ const_expr = c.Sub(b, a)
+ non_const_expr = c.Mul(const_expr, x)
+ self.assertTrue(c.IsConstant(const_expr))
+ self.assertFalse(c.IsConstant(non_const_expr))
+ # self.assertTrue(c.IsConstant(c.Sub(c.Add(x, a), x))) # TODO(b/77245564)
+
class EmbeddedComputationsTest(LocalComputationTest):
"""Tests for XLA graphs with embedded computations (such as maps)."""
diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
index 80c24eaccf..4198260a22 100644
--- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
+++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
@@ -87,7 +87,6 @@ SimpleOrcJIT::SimpleOrcJIT(const llvm::TargetOptions& target_options,
/*MAttrs=*/DetectMachineAttributes()))),
disassembler_(*target_machine_),
data_layout_(target_machine_->createDataLayout()),
- execution_session_(string_pool_),
symbol_resolver_(llvm::orc::createLegacyLookupResolver(
[this](const std::string& name) -> llvm::JITSymbol {
return this->ResolveRuntimeSymbol(name);
diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h
index aaeff2de87..f4260a95bc 100644
--- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h
+++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h
@@ -102,7 +102,6 @@ class SimpleOrcJIT {
std::unique_ptr<llvm::TargetMachine> target_machine_;
const Disassembler disassembler_;
const llvm::DataLayout data_layout_;
- llvm::orc::SymbolStringPool string_pool_;
llvm::orc::ExecutionSession execution_session_;
std::shared_ptr<llvm::orc::SymbolResolver> symbol_resolver_;
ObjLayerT object_layer_;
diff --git a/tensorflow/compiler/xla/service/executable.cc b/tensorflow/compiler/xla/service/executable.cc
index be92b1629a..471d2fd6ce 100644
--- a/tensorflow/compiler/xla/service/executable.cc
+++ b/tensorflow/compiler/xla/service/executable.cc
@@ -80,6 +80,7 @@ StatusOr<std::unique_ptr<ShapedBuffer>> Executable::ExecuteOnStreamWrapper(
StatusOr<std::unique_ptr<ShapedBuffer>> return_value =
ExecuteOnStream(run_options, arguments, profile_ptr.get());
+ TF_RETURN_IF_ERROR(return_value.status());
if (profile != nullptr) {
VLOG(1) << "enqueueing 'stop timer' and blocking host until done...";
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc
index 6f983d0b95..594413e88f 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation.cc
@@ -304,19 +304,15 @@ void ComputeComputationPostOrder(
HloComputation* computation,
tensorflow::gtl::FlatSet<HloComputation*>* visited,
std::list<HloComputation*>* post_order) {
- if (visited->count(computation) > 0) {
- return;
- }
-
- for (auto* instruction : computation->instructions()) {
- for (HloComputation* called_computation :
- instruction->called_computations()) {
- ComputeComputationPostOrder(called_computation, visited, post_order);
+ if (visited->insert(computation).second) {
+ for (auto* instruction : computation->instructions()) {
+ for (HloComputation* called_computation :
+ instruction->called_computations()) {
+ ComputeComputationPostOrder(called_computation, visited, post_order);
+ }
}
+ post_order->push_back(computation);
}
-
- visited->insert(computation);
- post_order->push_back(computation);
}
} // namespace
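
The rewrite folds the visited check into FlatSet::insert, whose returned
.second is true only on first insertion, so each computation is expanded once
and appended only after every computation it calls; that is exactly a
post-order traversal of the call graph. The idiom in miniature, with
illustrative types:

    // Post-order over a call graph; a node follows all of its callees.
    void PostOrder(Node* node, std::set<Node*>* visited,
                   std::list<Node*>* post_order) {
      if (visited->insert(node).second) {
        for (Node* callee : node->callees()) {
          PostOrder(callee, visited, post_order);
        }
        post_order->push_back(node);
      }
    }
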
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc
index 693004d364..9d7251b6ae 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc
@@ -1520,14 +1520,12 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
arg_dim_counts[dim] = arg_dimensions[dim];
}
- // Create mapping from result index to arg index.
- const int64 result_rank = ShapeUtil::Rank(result->shape());
- int64 result_dim = 0;
- std::vector<int64> result_to_arg_index(result_rank);
+ // Map each dimension in the result to a dimension in arg that isn't
+ // being reduced.
+ std::vector<int64> result_to_arg_index;
for (int64 i = 0; i < arg_dimensions.size(); ++i) {
if (arg_dim_steps[i] == 0) {
- result_to_arg_index[result_dim] = i;
- ++result_dim;
+ result_to_arg_index.push_back(i);
}
}
@@ -1542,6 +1540,20 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
base[result_to_arg_index[i]] = multi_index[i];
}
+ // When the reduction is addition of floats, accumulate in a double
+ // for better precision. Also, avoid creating Literals for the
+ // intermediate results; it's much faster.
+ if (ShapeUtil::ElementIsFloating(init_literal.shape()) &&
+ IsScalarAdd(function)) {
+ double computed_result = 0;
+ auto func = [&](ArraySlice<int64> input_index) {
+ computed_result += arg_literal.Get<float>(input_index);
+ return true;
+ };
+ ShapeUtil::ForEachIndex(arg_literal.shape(), base, arg_dim_counts,
+ arg_dim_steps, func);
+ return static_cast<ReturnT>(computed_result);
+ }
auto func = [&](ArraySlice<int64> input_index) {
auto curr_val = arg_literal.Get<ReturnT>(input_index);
@@ -1554,19 +1566,17 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
std::unique_ptr<Literal> computed_result =
embedded_evaluator.Evaluate<const Literal*>(*function, args)
.ConsumeValueOrDie();
- // Clear visit states so that the we can use the evaluate again on
+ // Clear visit states so that we can use the evaluator again on
// the same computation.
embedded_evaluator.ResetVisitStates();
-
// Assign computed result to result_val.
result_val = computed_result->Get<ReturnT>({});
-
return true;
};
-
+ // Computes one element of the result, reducing all dimensions that
+ // contribute to that element.
ShapeUtil::ForEachIndex(arg_literal.shape(), base, arg_dim_counts,
arg_dim_steps, func);
-
return result_val;
}));
@@ -1574,6 +1584,20 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
return Status::OK();
}
+ bool IsScalarAdd(HloComputation* computation) {
+ HloInstruction* instruction = computation->root_instruction();
+ if (instruction->opcode() == HloOpcode::kAdd &&
+ computation->num_parameters() == 2) {
+ const HloInstruction* lhs = instruction->operand(0);
+ const HloInstruction* rhs = instruction->operand(1);
+ return lhs->opcode() == HloOpcode::kParameter &&
+ ShapeUtil::IsScalar(lhs->shape()) &&
+ rhs->opcode() == HloOpcode::kParameter &&
+ ShapeUtil::IsScalar(rhs->shape()) && lhs != rhs;
+ }
+ return false;
+ }
+
Status HandleSelectAndScatter(HloInstruction* select_and_scatter) override {
auto operand = select_and_scatter->operand(0);
auto source = select_and_scatter->operand(1);
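
The fast path matters for correctness as much as speed: a float accumulator stops growing at 2^24, because 1.0f is smaller than the gap between adjacent representable floats beyond that point. A self-contained sketch, independent of the evaluator, showing the saturation and the double-accumulator fix:

    #include <cstdio>

    int main() {
      const int kNumElements = 1 << 25;
      float fsum = 0.0f;
      double dsum = 0.0;
      for (int i = 0; i < kNumElements; ++i) {
        fsum += 1.0f;  // stops growing at 2^24 = 16777216
        dsum += 1.0f;  // double's 53-bit mantissa does not saturate here
      }
      std::printf("float: %.0f  double: %.0f\n", fsum, dsum);
      // Prints "float: 16777216  double: 33554432".
      return 0;
    }
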
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
index 685cacd7f7..dd14dd3853 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
@@ -40,6 +40,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
@@ -1205,6 +1206,80 @@ TEST_P(HloEvaluatorTest,
LiteralTestUtil::ExpectEqual(*expected, *result);
}
+class HloEvaluatorPreciseReduceTest : public HloVerifiedTestBase {};
+
+// Tests that Reduce doesn't lose precision when adding many numbers (because
+// it accumulates its result in a double).
+TEST_F(HloEvaluatorPreciseReduceTest, AddReductionPrecisionTest) {
+ HloComputation::Builder b(TestName());
+
+ constexpr int kNumElements = 1 << 25; // float += 1 saturates at 1<<24
+ std::vector<float> v(kNumElements, 1.0f);
+ HloInstruction* arg_instruction = b.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR1<float>(v)));
+ HloInstruction* init_value = b.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<float>(0.f)));
+
+ HloComputation::Builder add_computation("add");
+ Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
+ auto param_lhs = add_computation.AddInstruction(
+ HloInstruction::CreateParameter(0, scalar_shape, "lhs"));
+ auto param_rhs = add_computation.AddInstruction(
+ HloInstruction::CreateParameter(1, scalar_shape, "rhs"));
+ add_computation.AddInstruction(HloInstruction::CreateBinary(
+ scalar_shape, HloOpcode::kAdd, param_lhs, param_rhs));
+ auto add_func = module().AddEmbeddedComputation(add_computation.Build());
+
+ HloInstruction* reduce_instruction = b.AddInstruction(
+ HloInstruction::CreateReduce(scalar_shape, arg_instruction, init_value,
+ /*dimensions_to_reduce=*/{0}, add_func));
+ module().AddEntryComputation(b.Build());
+
+ HloEvaluator hlo_eval;
+ std::unique_ptr<Literal> result =
+ hlo_eval.Evaluate(reduce_instruction).ConsumeValueOrDie();
+ LiteralTestUtil::ExpectR0Equal<float>(kNumElements, *result);
+}
+
+// Reducing many numbers should be fast because it doesn't create
+// intermediate Literals; the microbenchmark should finish in < 1 msec.
+void BM_ReducePrecisely(int num_iters) {
+ tensorflow::testing::StopTiming();
+ HloComputation::Builder b("BM_ReducePrecisely");
+ HloModuleConfig config;
+ config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags());
+ HloModule module("BM_ReducePrecisely", VersionedComputationHandle(), config);
+
+ constexpr int kNumElements = 1 << 25; // float += 1 saturates at 1<<24
+ std::vector<float> v(kNumElements, 1.0f);
+ HloInstruction* arg_instruction = b.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR1<float>(v)));
+ auto init_value = b.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<float>(0.f)));
+
+ HloComputation::Builder add_computation("add");
+ Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
+ auto param_lhs = add_computation.AddInstruction(
+ HloInstruction::CreateParameter(0, scalar_shape, "lhs"));
+ auto param_rhs = add_computation.AddInstruction(
+ HloInstruction::CreateParameter(1, scalar_shape, "rhs"));
+ add_computation.AddInstruction(HloInstruction::CreateBinary(
+ scalar_shape, HloOpcode::kAdd, param_lhs, param_rhs));
+ auto add_func = module.AddEmbeddedComputation(add_computation.Build());
+
+ HloInstruction* reduce_instruction = b.AddInstruction(
+ HloInstruction::CreateReduce(scalar_shape, arg_instruction, init_value,
+ /*dimensions_to_reduce=*/{0}, add_func));
+ module.AddEntryComputation(b.Build());
+
+ HloEvaluator hlo_eval;
+ tensorflow::testing::StartTiming();
+ hlo_eval.Evaluate(reduce_instruction).ConsumeValueOrDie();
+ tensorflow::testing::StopTiming();
+}
+
+BENCHMARK(BM_ReducePrecisely);
+
TEST_P(HloEvaluatorTest, ReduceAdd) {
HloComputation::Builder b(TestName());
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index a2a2c1e615..fcf9ebf5f7 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -98,6 +98,13 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
}
}
+ if (instruction->opcode() == HloOpcode::kTrace) {
+ TF_RET_CHECK(instruction->operands().size() == 1)
+ << "Trace instruction should have 1 operand but sees "
+ << instruction->operands().size();
+ instruction->mutable_operand(0)->set_tracing(instruction.get());
+ }
+
TF_RET_CHECK(!proto.name().empty());
instruction->name_ = proto.name();
@@ -170,6 +177,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
WrapUnique(new HloInstruction(HloOpcode::kTrace, ShapeUtil::MakeNil()));
instruction->operands_.push_back(operand);
instruction->literal_ = Literal::CreateR1U8(tag);
+ operand->set_tracing(instruction.get());
return instruction;
}
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index a94ba145df..80f8408244 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -928,6 +928,13 @@ class HloInstruction {
const HloSharding& sharding_or_default(const HloSharding& default_) const {
return sharding_ ? *sharding_ : default_;
}
+ // Returns the unique device of the sharding, if any.
+ tensorflow::gtl::optional<int64> sharding_unique_device() const {
+ if (sharding_ == nullptr || !sharding_->HasUniqueDevice()) {
+ return tensorflow::gtl::optional<int64>();
+ }
+ return sharding_->UniqueDevice().ValueOrDie();
+ }
// Sets the sharding of this operator. Should only be called by HloModule or
// HloComputation methods.
void set_sharding(const HloSharding& sharding) {
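
The new accessor collapses the two-step "has a unique device, then fetch it" query into a single optional-returning call. A hypothetical caller (the surrounding names are assumptions, not taken from this patch):

    // `instr` is an HloInstruction*.
    if (auto device = instr->sharding_unique_device()) {
      VLOG(1) << "instruction is pinned to device " << *device;
    } else {
      VLOG(1) << "no sharding, or the sharding spans multiple devices";
    }
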
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc
index fa5dcb0b36..54c34ce116 100644
--- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc
@@ -313,6 +313,27 @@ Status HloModuleGroupMetadata::VerifyChannelInstructions() {
if (!ShapeUtil::Compatible(send_shape, recv_shape)) {
return FailedPrecondition("send/recv shapes do not match");
}
+ const HloModule* send_module = channel.send->parent()->parent();
+ const HloModule* send_done_module = channel.send_done->parent()->parent();
+ if (send_module != send_done_module) {
+ return FailedPrecondition(
+ "send and send-done (channel=%lld) must be on the same device: %lld "
+ "vs. %lld",
+ channel.id, GetModuleId(send_module), GetModuleId(send_done_module));
+ }
+ const HloModule* recv_module = channel.recv->parent()->parent();
+ const HloModule* recv_done_module = channel.recv_done->parent()->parent();
+ if (recv_module != recv_done_module) {
+ return FailedPrecondition(
+ "recv and recv-done (channel=%lld) must be on the same device: %lld "
+ "vs. %lld",
+ channel.id, GetModuleId(recv_module), GetModuleId(recv_done_module));
+ }
+ if (send_module == recv_module) {
+ return FailedPrecondition(
+ "send and recv (channel=%lld) must be on different devices: %lld",
+ channel.id, GetModuleId(send_module));
+ }
}
// Check if channel instructions are used only in allowed computations.
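
Taken together, the three new checks pin down the pairing contract for a channel c, writing M(x) for the module (device) of instruction x: M(send) == M(send-done), M(recv) == M(recv-done), and M(send) != M(recv). Each half of the channel must complete on the device that started it, and the two halves must live on different devices.
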
diff --git a/tensorflow/compiler/xla/service/hlo_sharding.h b/tensorflow/compiler/xla/service/hlo_sharding.h
index 18d406f370..06204acbca 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding.h
+++ b/tensorflow/compiler/xla/service/hlo_sharding.h
@@ -94,6 +94,10 @@ class HloSharding {
// Create a new sharding from a protobuf OpSharding.
static StatusOr<HloSharding> FromProto(const OpSharding& proto);
+ // Checks whether device is a reserved device number. A reserved device
+ // number usually has a special meaning, with dedicated handling logic.
+ static bool IsReservedDevice(int64 device) { return device < 0; }
+
OpSharding ToProto() const;
string ToString() const;
diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc
index 2a282f3be7..ec04239b4f 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc
@@ -34,6 +34,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/casts.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
@@ -762,7 +763,7 @@ void InitializeLLVMCommandLineOptions(const HloModuleConfig& config) {
fake_argv_storage.push_back("");
for (const auto& it : options) {
// Skip options the XLA backend itself consumes.
- if (!tensorflow::StringPiece(it.first).starts_with("xla_")) {
+ if (!tensorflow::str_util::StartsWith(it.first, "xla_")) {
if (it.second.empty()) {
fake_argv_storage.push_back(it.first);
} else {
diff --git a/tensorflow/compiler/xla/service/reshape_mover.cc b/tensorflow/compiler/xla/service/reshape_mover.cc
index f15117f45c..49ec38eb62 100644
--- a/tensorflow/compiler/xla/service/reshape_mover.cc
+++ b/tensorflow/compiler/xla/service/reshape_mover.cc
@@ -53,16 +53,8 @@ bool IsReshapeOrTranspose(const HloInstruction* instruction) {
instruction->opcode() == HloOpcode::kTranspose;
}
-// Returns true if `a` is a broadcast instruction to target shape `shape` and
-// its operand is a scalar.
-bool IsBroadcastScalarToShape(const HloInstruction* a, const Shape& shape) {
- return a->opcode() == HloOpcode::kBroadcast &&
- ShapeUtil::SameDimensions(a->shape(), shape) &&
- ShapeUtil::IsScalar(a->operand(0)->shape());
-}
-
-// Returns true iff `instruction` can change its shape simply by adjusting
-// metadata.
+// Returns true if `instruction` can change its shape simply by adjusting
+// metadata or if `instruction` is a broadcast of a scalar value.
bool CanTriviallyChangeShape(const HloInstruction* instruction) {
// NOTE: Technically a sequence of reshape(reshape(constant)) is also
// trivially reshapable, so we might be tempted to simply recurse if
@@ -97,19 +89,30 @@ bool CanTriviallyChangeShape(const HloInstruction* instruction) {
return true;
}
+ // A broadcast of a scalar can trivially change its shape.
+ if (instruction->opcode() == HloOpcode::kBroadcast &&
+ ShapeUtil::IsScalar(instruction->operand(0)->shape())) {
+ return true;
+ }
+
return false;
}
-// Finds the first non-scalar operand of an instruction that is a non-trivial
-// reshape or transpose. Returns the operand if it is found or nullptr if not
-// found.
+// Returns true iff `instruction` is a reshape/transpose instruction for which
+// a shape change is nontrivial.
+bool IsNontrivialReshape(const HloInstruction* instruction) {
+ return !ShapeUtil::IsScalar(instruction->shape()) &&
+ IsReshapeOrTranspose(instruction) &&
+ !CanTriviallyChangeShape(instruction->operand(0));
+}
+
+// Finds the first operand of an instruction that is a non-trivial reshape or
+// transpose. Returns such an operand or nullptr if not found.
HloInstruction* FirstNonScalarAndNonTrivialReshapeOperand(
const HloInstruction* hlo) {
for (HloInstruction* operand : hlo->operands()) {
- if (!ShapeUtil::IsScalar(operand->shape()) &&
- IsReshapeOrTranspose(operand) &&
- !CanTriviallyChangeShape(operand->operand(0))) {
- VLOG(5) << "Found first non-scalar and non-trivial reshape operand of "
+ if (IsNontrivialReshape(operand)) {
+ VLOG(5) << "Found first non-trivial reshape operand of "
<< hlo->ToString(HloPrintOptions().set_print_metadata(false))
<< ":\n\t"
<< operand->ToString(HloPrintOptions().set_print_metadata(false));
@@ -119,7 +122,7 @@ HloInstruction* FirstNonScalarAndNonTrivialReshapeOperand(
return nullptr;
}
-// Returns whether `a` and `b` are equivalent for the purposes of this pass.
+// Returns whether `a` and `b` are equivalent reshapes/transposes.
bool AreEquivalentReshapes(const HloInstruction* a, const HloInstruction* b) {
if (a->opcode() != b->opcode() ||
!ShapeUtil::SameDimensions(a->shape(), b->shape())) {
@@ -136,85 +139,14 @@ bool AreEquivalentReshapes(const HloInstruction* a, const HloInstruction* b) {
}
}
-// Returns true if all operands of `instruction` can easily change shape.
-// Operands can easily change shape if they are all reshapes/transposes to and
-// from the same shape. Additionally, operands like constant, rng, and any
-// scalar change shape with only an adjustment of metadata.
-bool AllOperandsHaveEasyShapeChanges(
- const HloInstruction* instruction,
- const HloInstruction* first_reshape_operand) {
- auto print_no_metadata = HloPrintOptions().set_print_metadata(false);
- VLOG(3) << "** Checking whether all operands have easy shape changes: "
- << instruction->ToString(print_no_metadata);
- // Check whether all operands:
- // 0. Have the same dimensions as the output -- if not, it may be
- // implicitly broadcast, which can confound the movement's
- // correctness.
- //
- // And one of the following:
- // 1. Are reshapes or transposes that have the same input and
- // output shapes as all other reshaped or transposed operands.
- // or
- // 2. Are one of kConstant, kRng, and scalars that can change shape
- // trivially,
- // or
- // 3. Are broadcast with a scalar operand.
- for (const HloInstruction* operand : instruction->operands()) {
- if (!ShapeUtil::SameDimensions(operand->shape(), instruction->shape())) {
- VLOG(5) << "Operand shape differs from output shape; may be "
- "implicitly broadcast, so preventing "
- "movement\n\toperand: "
- << operand->ToString(print_no_metadata) << "\n\tinstruction: "
- << instruction->ToString(print_no_metadata);
- return false;
- }
-
- // Skip the rest checks if the current operand is first_reshape_operand
- // itself.
- if (first_reshape_operand == operand) {
- continue;
- }
-
- if (AreEquivalentReshapes(first_reshape_operand, operand)) {
- VLOG(5) << "Are equivalent reshapes:\n\tfirst_reshape_operand: "
- << first_reshape_operand->ToString(print_no_metadata)
- << "\n\toperand: " << operand->ToString(print_no_metadata);
- continue;
- }
-
- if (CanTriviallyChangeShape(operand)) {
- VLOG(5) << "Operand can trivially change shape: "
- << operand->ToString(print_no_metadata);
- continue;
- }
-
- if (IsBroadcastScalarToShape(operand, first_reshape_operand->shape())) {
- VLOG(5) << "Broadcast scalar to shape: "
- << operand->ToString(print_no_metadata);
- continue;
- }
-
- // TODO(someone): Look into supporting general ops for the operands as
- // well.
- VLOG(5) << "Operand is neither equalivant to the first Reshape operand"
- "nor can trivially change shape: "
- << operand->ToString(print_no_metadata);
- return false;
- }
-
- VLOG(3) << "All operands have easy shape changes: "
- << instruction->ToString(print_no_metadata);
- return true;
-}
-
// This function is called once we've decided to sink reshape/transpose operands
// across an instruction. It returns an updated `operand` with a shape that
// plays nicely with `new_operand_shape`; either it has the same shape (of the
// correct type), or it is a scalar that may be implicitly broadcast.
-HloInstruction* UpdateOperand(HloComputation* computation,
- const HloInstruction* first_reshape_operand,
+HloInstruction* UpdateOperand(const HloInstruction* first_reshape_operand,
const Shape& new_operand_shape,
HloInstruction* operand) {
+ HloComputation* computation = operand->parent();
const PrimitiveType element_type = operand->shape().element_type();
const Shape new_shape =
ShapeUtil::ChangeElementType(new_operand_shape, element_type);
@@ -245,42 +177,24 @@ HloInstruction* UpdateOperand(HloComputation* computation,
VLOG(5) << "Using existing operand of kReshape or kTranspose";
return operand->mutable_operand(0);
}
- case HloOpcode::kBroadcast:
- CHECK(IsBroadcastScalarToShape(operand, first_reshape_operand->shape()));
- VLOG(5) << "Changing broadcast";
- return computation->AddInstruction(
+ case HloOpcode::kBroadcast: {
+ CHECK(ShapeUtil::IsScalar(operand->operand(0)->shape()));
+ HloInstruction* inst = computation->AddInstruction(
operand->CloneWithNewOperands(new_shape, operand->operands()));
+ VLOG(5) << "Changing broadcast from " << operand->ToString() << " to "
+ << inst->ToString();
+ return inst;
+ }
default:
LOG(FATAL) << "Unexpected operand opcode during update: " << operand;
}
}
-// Try to sink any reshape or transpose operands of `instruction` across it. We
-// do so if `instruction` is elementwise and all operands are either equivalent
-// reshapes/transposes or are trivially reshapable.
-StatusOr<bool> TrySinkReshapeOrTranspose(HloComputation* computation,
- HloInstruction* instruction) {
- // Only perform sinks for live elementwise instructions with operands.
- const bool is_dead = instruction->user_count() == 0 &&
- instruction != computation->root_instruction();
- if (!instruction->IsElementwise() || instruction->operands().empty() ||
- is_dead) {
- return false;
- }
-
- // Only perform sinks if there are any nontrivial reshape/transpose operands.
- const HloInstruction* first_reshape_operand =
- FirstNonScalarAndNonTrivialReshapeOperand(instruction);
- if (!first_reshape_operand) {
- return false;
- }
-
- // Only perform sinks if all operands can easily change shape.
- if (!AllOperandsHaveEasyShapeChanges(instruction, first_reshape_operand)) {
- return false;
- }
-
+// Actually performs the reshape-move transformation -- that is, sinks the
+// reshape or transpose operands of `instruction` across it.
+StatusOr<bool> PerformSinkReshapeOrTranspose(
+ HloInstruction* instruction, const HloInstruction* first_reshape_operand) {
auto print_no_metadata = HloPrintOptions().set_print_metadata(false);
// At this point we've decided to sink reshape/transpose operands.
const Shape& new_operand_shape = first_reshape_operand->operand(0)->shape();
@@ -301,8 +215,8 @@ StatusOr<bool> TrySinkReshapeOrTranspose(HloComputation* computation,
}
VLOG(3) << "Updating operand #" << i << ": "
<< operands[i]->ToString(print_no_metadata);
- operands[i] = UpdateOperand(computation, first_reshape_operand,
- new_operand_shape, operands[i]);
+ operands[i] =
+ UpdateOperand(first_reshape_operand, new_operand_shape, operands[i]);
}
if (HloOpcode::kFusion == instruction->opcode()) {
// Here we already know `instruction` is elementwise, and no operand is
@@ -314,6 +228,7 @@ StatusOr<bool> TrySinkReshapeOrTranspose(HloComputation* computation,
*shape->mutable_layout() = new_operand_shape.layout();
}
}
+ HloComputation* computation = instruction->parent();
HloInstruction* new_elementwise =
computation->AddInstruction(instruction->CloneWithNewOperands(
// `instruction` may change the element type, e.g., from
@@ -348,6 +263,141 @@ StatusOr<bool> TrySinkReshapeOrTranspose(HloComputation* computation,
return true;
}
+// Returns true if the instruction is a reshape-move candidate.
+//
+// An instruction is a reshape-move candidate if the instruction is elementwise,
+// has at least one nontrivial reshape/transpose operand, and its operands are
+// either trivially reshapable or are equivalent nontrivial reshapes/transposes.
+bool IsReshapeMoveCandidate(HloInstruction* instruction) {
+ auto print_no_metadata = HloPrintOptions().set_print_metadata(false);
+ VLOG(5) << "** Checking instruction: "
+ << instruction->ToString(print_no_metadata);
+
+ // Only perform reshape-move for live elementwise instructions with operands.
+ const bool is_dead = instruction->user_count() == 0 &&
+ instruction != instruction->parent()->root_instruction();
+ if (!instruction->IsElementwise() || instruction->operands().empty() ||
+ is_dead) {
+ return false;
+ }
+
+ // Check whether all operands:
+ // 0. Have the same dimensions as the output -- if not, they may be
+ // implicitly broadcast, which can confound the movement's
+ // correctness.
+ //
+ // And one of the following:
+ // 1. Are reshapes or transposes that have the same input and
+ // output shapes as all other reshaped or transposed operands.
+ // or
+ // 2. Are one of kConstant, kRng, broadcast of a scalar value, and scalars
+ // that can change shape trivially.
+ const HloInstruction* first_reshape_operand = nullptr;
+ for (const HloInstruction* operand : instruction->operands()) {
+ if (!ShapeUtil::SameDimensions(operand->shape(), instruction->shape())) {
+ VLOG(5) << "Operand shape differs from output shape; may be "
+ "implicitly broadcast, so preventing "
+ "movement\n\toperand: "
+ << operand->ToString(print_no_metadata) << "\n\tinstruction: "
+ << instruction->ToString(print_no_metadata);
+ return false;
+ }
+
+ if (CanTriviallyChangeShape(operand)) {
+ VLOG(5) << "Operand can trivially change shape: "
+ << operand->ToString(print_no_metadata);
+ continue;
+ }
+
+ if (!IsNontrivialReshape(operand)) {
+ VLOG(5) << "Operand can't trivially change shape: "
+ << operand->ToString(print_no_metadata);
+ return false;
+ }
+
+ if (first_reshape_operand == nullptr) {
+ first_reshape_operand = operand;
+ VLOG(5) << "First reshape operand "
+ << operand->ToString(print_no_metadata);
+ } else if (AreEquivalentReshapes(first_reshape_operand, operand)) {
+ VLOG(5)
+ << "Operand is an equivalent reshape of the first reshape operand "
+ << operand->ToString(print_no_metadata);
+ } else {
+ // TODO(someone): Look into supporting general ops for the operands as
+ // well.
+ VLOG(5) << "Operand is a reshape but is not equivalent to the first "
+ "Reshape operand"
+ << operand->ToString(print_no_metadata);
+ return false;
+ }
+ }
+
+ if (first_reshape_operand) {
+ VLOG(5) << "All operands have easy shape changes: "
+ << instruction->ToString(print_no_metadata);
+ }
+
+ return first_reshape_operand != nullptr;
+}
+
+// Reshape-moves all qualifying instructions in reshape_candidates. Returns
+// true if it makes changes.
+//
+// `reshape_candidates` is a set of HloInstructions with nontrivial reshape
+// operands, and an instruction in the set can be reshape-moved iff all the users
+// of its nontrivial reshape operands can also be reshaped-moved.
+//
+// The algorithm here iteratively finds the nontrivial operands with users that
+// are outside the set of `reshape_candidates`, and removes their users from
+// `reshape_candidates`, until either `reshape_candidates` becomes empty or none
+// of the remaining nontrivial operands have users outside `reshape_candidates`.
+// In the latter case, all the remaining instructions in `reshape_candidates`
+// are reshape-moved and the routine returns true.
+StatusOr<bool> TryReshapeMoveOnCandidates(
+ HloInstructionSet* reshape_candidates) {
+ bool removed = true;
+ while (!reshape_candidates->empty() && removed) {
+ if (VLOG_IS_ON(5)) {
+ for (const HloInstruction* instruction : *reshape_candidates) {
+ VLOG(5) << "candidate " << instruction->ToString();
+ }
+ }
+ ConstHloInstructionSet nontrivial_operands;
+ for (const HloInstruction* instruction : *reshape_candidates) {
+ for (const auto* operand : instruction->operands()) {
+ if (IsNontrivialReshape(operand)) {
+ nontrivial_operands.insert(operand);
+ }
+ }
+ }
+
+ removed = false;
+ for (auto operand : nontrivial_operands) {
+ if (c_any_of(operand->users(), [&](HloInstruction* user) {
+ return !reshape_candidates->count(user);
+ })) {
+ for (auto* user : operand->users()) {
+ removed |= reshape_candidates->erase(user) > 0;
+ }
+ }
+ }
+ }
+
+ if (reshape_candidates->empty()) {
+ return false;
+ }
+ for (HloInstruction* instruction : *reshape_candidates) {
+ const HloInstruction* first_reshape_operand =
+ FirstNonScalarAndNonTrivialReshapeOperand(instruction);
+ TF_ASSIGN_OR_RETURN(
+ bool did_change,
+ PerformSinkReshapeOrTranspose(instruction, first_reshape_operand));
+ CHECK(did_change);
+ }
+ return true;
+}
+
} // namespace
StatusOr<bool> ReshapeMover::Run(HloModule* module) {
@@ -355,11 +405,15 @@ StatusOr<bool> ReshapeMover::Run(HloModule* module) {
VLOG(2) << "Pre ReshapeMover HLO:";
XLA_VLOG_LINES(2, module->ToString());
for (auto* comp : module->MakeNonfusionComputations()) {
- for (HloInstruction* instruction : comp->MakeInstructionPostOrder()) {
- TF_ASSIGN_OR_RETURN(bool did_change,
- TrySinkReshapeOrTranspose(comp, instruction));
- changed |= did_change;
+ HloInstructionSet reshape_candidates;
+ for (HloInstruction* instruction : comp->instructions()) {
+ if (IsReshapeMoveCandidate(instruction)) {
+ reshape_candidates.insert(instruction);
+ }
}
+ TF_ASSIGN_OR_RETURN(bool did_change,
+ TryReshapeMoveOnCandidates(&reshape_candidates));
+ changed |= did_change;
}
VLOG(2) << "Post ReshapeMover HLO:";
XLA_VLOG_LINES(2, module->ToString());
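
The loop inside TryReshapeMoveOnCandidates is a fixed-point computation: keep evicting candidates whose nontrivial reshape operands have users outside the candidate set, and stop once a full pass removes nothing. A standalone sketch of the same loop shape over plain integers (the operand-to-users map is a stand-in for HLO use chains, not real XLA API):

    #include <map>
    #include <set>
    #include <vector>

    // Returns true iff any candidates survive pruning.
    bool PruneToFixedPoint(std::set<int>* candidates,
                           const std::map<int, std::vector<int>>& users) {
      bool removed = true;
      while (!candidates->empty() && removed) {
        removed = false;
        for (const auto& operand_and_users : users) {
          const std::vector<int>& u = operand_and_users.second;
          bool escapes = false;
          for (int user : u) {
            if (candidates->count(user) == 0) escapes = true;
          }
          // If any user escapes the set, none of this operand's users can
          // be moved; evict them all and run another pass.
          if (escapes) {
            for (int user : u) removed |= candidates->erase(user) > 0;
          }
        }
      }
      return !candidates->empty();
    }
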
diff --git a/tensorflow/compiler/xla/service/reshape_mover_test.cc b/tensorflow/compiler/xla/service/reshape_mover_test.cc
index 4e0a0a8832..094f7319f4 100644
--- a/tensorflow/compiler/xla/service/reshape_mover_test.cc
+++ b/tensorflow/compiler/xla/service/reshape_mover_test.cc
@@ -564,15 +564,15 @@ TEST_F(ReshapeMoverTest, SinkTransposeAcrossBroadcastScalar) {
const string hlo_string = R"(
HloModule TransposeMulInversedTransposeModule
ENTRY TransposeMulInversedTranspose {
- src0 = f32[1,20,8,32]{3,2,1,0} parameter(0)
- transpose0 = f32[1,8,20,32]{3,2,1,0} transpose(src0), dimensions={0,2,1,3}
+ src0 = f32[20,8]{1,0} parameter(0)
+ transpose0 = f32[8,20]{1,0} transpose(src0), dimensions={1,0}
src1 = f32[] parameter(1)
- broadcast0 = f32[1,8,20,32]{3,2,1,0} broadcast(src1), dimensions={}
- ROOT multiply0 = f32[1,8,20,32]{3,2,1,0} multiply(transpose0, broadcast0)
+ broadcast0 = f32[8,20]{1,0} broadcast(src1), dimensions={}
+ ROOT multiply0 = f32[8,20]{1,0} multiply(transpose0, broadcast0)
}
)";
- ParseAndVerifyModule(hlo_string.c_str());
+ ParseAndVerifyModule(hlo_string);
TF_ASSERT_OK_AND_ASSIGN(bool changed, ReshapeMover().Run(&module()));
EXPECT_TRUE(changed);
@@ -580,5 +580,75 @@ TEST_F(ReshapeMoverTest, SinkTransposeAcrossBroadcastScalar) {
op::Transpose(op::Multiply()));
}
+TEST_F(ReshapeMoverTest, ReshapeWithUsersOutsideCandidatesNotSink) {
+ const string hlo_string = R"(
+ HloModule ReshapeWithUsersOutsideCandidates
+ ENTRY ReshapeWithMultipleUsers {
+ param0 = f32[20,8]{1,0} parameter(0)
+ reshape0 = f32[8,20]{1,0} reshape(param0)
+ param1 = f32[] parameter(1)
+ broadcast0 = f32[8,20]{1,0} broadcast(param1), dimensions={}
+ param2 = f32[20,8]{1,0} parameter(2)
+ reshape1 = f32[8,20]{1,0} reshape(param2)
+ param3 = f32[20,8]{1,0} parameter(3)
+ reshape2 = f32[8,20]{1,0} reshape(param3)
+ param4 = f32[8,20]{1,0} parameter(4)
+ add0 = f32[8,20]{1,0} add(reshape0, broadcast0)
+ add1 = f32[8,20]{1,0} add(reshape0, reshape1)
+ add2 = f32[8,20]{1,0} add(reshape1, param4)
+ ROOT tuple = (f32[8,20]{1,0},f32[8,20]{1,0},
+ f32[8,20]{1,0}) tuple(add0, add1, add2)
+ }
+ )";
+
+ ParseAndVerifyModule(hlo_string);
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, ReshapeMover().Run(&module()));
+ EXPECT_FALSE(changed);
+}
+
+TEST_F(ReshapeMoverTest, ReshapeNoUsersOutsideCandidatesSink1) {
+ const string hlo_string = R"(
+ HloModule ReshapeNoUsersOutsideCandidates1
+ ENTRY ReshapeWithMultipleUsers1 {
+ param0 = f32[20,8]{1,0} parameter(0)
+ reshape0 = f32[8,20]{1,0} reshape(param0)
+ param1 = f32[] parameter(1)
+ broadcast0 = f32[8,20]{1,0} broadcast(param1), dimensions={}
+ param2 = f32[20,8]{1,0} parameter(2)
+ reshape1 = f32[8,20]{1,0} reshape(param2)
+ param3 = f32[20,8]{1,0} parameter(3)
+ reshape2 = f32[8,20]{1,0} reshape(param3)
+ add0 = f32[8,20]{1,0} add(reshape0, broadcast0)
+ add1 = f32[8,20]{1,0} add(reshape0, reshape1)
+ add2 = f32[8,20]{1,0} add(reshape1, reshape2)
+ ROOT tuple = (f32[8,20]{1,0},f32[8,20]{1,0},
+ f32[8,20]{1,0}) tuple(add0, add1, add2)
+ }
+ )";
+
+ ParseAndVerifyModule(hlo_string);
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, ReshapeMover().Run(&module()));
+ EXPECT_TRUE(changed);
+ EXPECT_THAT(module().entry_computation()->root_instruction(),
+ op::Tuple(op::Reshape(), op::Reshape(), op::Reshape()));
+}
+
+TEST_F(ReshapeMoverTest, ReshapeNoUsersOutsideCandidatesSink2) {
+ const string hlo_string = R"(
+ HloModule ReshapeNoUsersOutsideCandidates2
+ ENTRY ReshapeWithMultipleUsers2 {
+ param0 = f32[20,8]{1,0} parameter(0)
+ reshape0 = f32[8,20]{1,0} reshape(param0)
+ ROOT add0 = f32[8,20]{1,0} add(reshape0, reshape0)
+ }
+ )";
+
+ ParseAndVerifyModule(hlo_string);
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, ReshapeMover().Run(&module()));
+ EXPECT_TRUE(changed);
+ EXPECT_THAT(module().entry_computation()->root_instruction(),
+ op::Reshape(op::Add()));
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc
index ca8071b7bb..ec883a6cf3 100644
--- a/tensorflow/compiler/xla/service/service.cc
+++ b/tensorflow/compiler/xla/service/service.cc
@@ -409,6 +409,37 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> Service::BuildExecutables(
return std::move(executables);
}
+StatusOr<std::vector<std::unique_ptr<Executable>>> Service::BuildExecutables(
+ const std::vector<const HloModuleProto*>& module_protos,
+ std::vector<std::unique_ptr<HloModuleConfig>> module_configs,
+ Backend* backend,
+ std::vector<std::vector<perftools::gputools::StreamExecutor*>> executors,
+ DeviceMemoryAllocator* device_allocator) {
+ VLOG(1) << Printf("BuildExecutable on service %p", this);
+
+ VLOG(1) << "Computations:";
+ for (const HloModuleProto* proto : module_protos) {
+ VLOG(1) << proto->name();
+ }
+
+ CHECK_EQ(module_protos.size(), module_configs.size());
+ std::vector<std::unique_ptr<HloModule>> modules;
+ for (int64 i = 0; i < module_protos.size(); ++i) {
+ const HloModuleProto* proto = module_protos[i];
+ const HloModuleConfig& config = *module_configs[i];
+ TF_ASSIGN_OR_RETURN(auto module,
+ HloModule::CreateFromProto(*proto, config));
+ modules.push_back(std::move(module));
+ }
+
+ TF_ASSIGN_OR_RETURN(
+ std::vector<std::unique_ptr<Executable>> executables,
+ backend->compiler()->Compile(std::move(modules), std::move(executors),
+ device_allocator));
+
+ return std::move(executables);
+}
+
StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
const VersionedComputationHandle& versioned_handle,
std::unique_ptr<HloModuleConfig> module_config, Backend* backend,
@@ -703,6 +734,47 @@ tensorflow::Status Service::SetReturnValue(const SetReturnValueRequest* arg,
return computation->SetReturnValue(arg->operand());
}
+StatusOr<std::vector<perftools::gputools::StreamExecutor*>>
+Service::GetExecutors(const ExecutionOptions& execution_options,
+ int64 requests_size, int64 request_index) const {
+ if (execution_options.device_handles().empty()) {
+ return FailedPrecondition(
+ "device handles must be given to execute parallel computations");
+ }
+ if (requests_size > 1 && execution_options.device_handles_size() > 1) {
+ return InvalidArgument(
+ "Parallel requests with multiple device handles is not supported. "
+ "Found %lld parallel requests, with request %lld containing %d device "
+ "handles.",
+ requests_size, request_index, execution_options.device_handles_size());
+ }
+ std::vector<perftools::gputools::StreamExecutor*> executors;
+ for (const auto& device_handle : execution_options.device_handles()) {
+ TF_ASSIGN_OR_RETURN(auto replicas,
+ Replicas(*execute_backend_, device_handle));
+ se::StreamExecutor* executor = replicas[0];
+ CHECK(executor != nullptr);
+ executors.push_back(executor);
+ }
+ return executors;
+}
+
+StatusOr<std::vector<std::vector<const ShapedBuffer*>>> Service::GetArguments(
+ const ExecutionOptions& execution_options,
+ tensorflow::gtl::ArraySlice<const GlobalDataHandle*> arguments) {
+ // Resolve the allocations for the arguments of the computation, and create
+ // a vector of device memory offsets for the arguments from the allocations.
+ // In the case of partitioned computations, assume all arguments go on the
+ // zeroth core.
+ TF_ASSIGN_OR_RETURN(
+ auto replicas,
+ Replicas(*execute_backend_, execution_options.device_handles(0)));
+ TF_ASSIGN_OR_RETURN(
+ std::vector<std::vector<const ShapedBuffer*>> replicated_arguments,
+ ResolveAndValidateArguments(arguments, replicas));
+ return replicated_arguments;
+}
+
tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg,
ExecuteParallelResponse* result) {
VLOG(1) << "running execute-parallel request: " << arg->ShortDebugString();
@@ -731,26 +803,10 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg,
// is one of the executors to run the replicated computation.
const ExecutionOptions& execution_options =
arg->requests(i).execution_options();
- if (execution_options.device_handles().empty()) {
- return FailedPrecondition(
- "device handles must be given to execute parallel computations");
- }
- if (arg->requests_size() > 1 &&
- execution_options.device_handles_size() > 1) {
- return InvalidArgument(
- "Parallel requests with multiple device handles is not supported. "
- "Found %d parallel requests, with request %lld containing %d device "
- "handles.",
- arg->requests_size(), i, execution_options.device_handles_size());
- }
- std::vector<perftools::gputools::StreamExecutor*> executors;
- for (const auto& device_handle : execution_options.device_handles()) {
- TF_ASSIGN_OR_RETURN(auto replicas,
- Replicas(*execute_backend_, device_handle));
- se::StreamExecutor* executor = replicas[0];
- CHECK(executor != nullptr);
- executors.push_back(executor);
- }
+
+ // Get the executors.
+ TF_ASSIGN_OR_RETURN(auto executors, GetExecutors(execution_options,
+ arg->requests_size(), i));
// Resolve the UserComputation object associated with the requested
// computation and compute the program shape.
@@ -767,16 +823,9 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg,
std::shared_ptr<const ProgramShape> program_shape,
user_computation->ComputeProgramShape(versioned_handle.version));
- // Resolve the allocations for the arguments of the computation, and create
- // a vector of device memory offsets for the arguments from the allocations.
- // In the case of partitioned computations, assume all arguments go on the
- // zeroth core.
- TF_ASSIGN_OR_RETURN(
- auto replicas,
- Replicas(*execute_backend_, execution_options.device_handles(0)));
- TF_ASSIGN_OR_RETURN(
- std::vector<std::vector<const ShapedBuffer*>> replicated_arguments,
- ResolveAndValidateArguments(request.arguments(), replicas));
+ // Get the replicated arguments.
+ TF_ASSIGN_OR_RETURN(auto replicated_arguments,
+ GetArguments(execution_options, request.arguments()));
// Create an HloModuleConfig object for the computation, given the shape of
// the program and the argument allocations. Here, we care only about the
@@ -839,7 +888,103 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg,
tensorflow::Status Service::ExecuteGraphParallel(
const ExecuteGraphParallelRequest* arg, ExecuteParallelResponse* result) {
- return Unimplemented("execute-graph-parallel is not yet implemented");
+ VLOG(1) << "running execute-graph-parallel request";
+
+ std::vector<std::vector<std::vector<const ShapedBuffer*>>> all_arguments;
+ std::vector<std::vector<perftools::gputools::StreamExecutor*>> all_executors;
+ std::vector<const HloModuleProto*> module_protos;
+ std::vector<std::unique_ptr<HloModuleConfig>> module_configs;
+ std::vector<string> computation_names;
+ std::vector<DeviceHandle> device_handles;
+
+ int num_requested_devices =
+ std::accumulate(arg->requests().begin(), arg->requests().end(), 0,
+ [](int a, const ExecuteGraphRequest& r) -> int {
+ return a + r.execution_options().device_handles_size();
+ });
+ if (num_requested_devices * options_.number_of_replicas() >
+ execute_backend_->device_count()) {
+ return FailedPrecondition(
+ "there are not enough stream executors to execute %d computations",
+ num_requested_devices);
+ }
+
+ for (int64 i = 0; i < arg->requests_size(); ++i) {
+ // Get the stream executor for the i'th computation. This stream executor
+ // is one of the executors to run the replicated computation.
+ const ExecutionOptions& execution_options =
+ arg->requests(i).execution_options();
+ const ExecuteGraphRequest& request = arg->requests(i);
+ TF_RET_CHECK(request.has_computation()) << "computations may not be empty";
+ TF_RET_CHECK(request.computation().has_program_shape())
+ << "programe shape may not be empty";
+
+ // Get the executors.
+ TF_ASSIGN_OR_RETURN(auto executors, GetExecutors(execution_options,
+ arg->requests_size(), i));
+
+ // Get the replicated arguments.
+ TF_ASSIGN_OR_RETURN(auto replicated_arguments,
+ GetArguments(execution_options, request.arguments()));
+
+ // Create an HloModuleConfig object for the computation, given the shape of
+ // the program and the argument allocations. Here, we care only about the
+ // shapes of the arguments, so, it is sufficient to use the arguments of
+ // replica 0.
+ TF_ASSIGN_OR_RETURN(
+ std::unique_ptr<HloModuleConfig> module_config,
+ CreateModuleConfig(request.computation().program_shape(),
+ replicated_arguments.front(),
+ request.execution_options(),
+ /*user_computation=*/nullptr));
+ VLOG(3)
+ << "ExecuteGraphParallel created HloModuleConfig computation layout: "
+ << module_config->entry_computation_layout().ToString();
+
+ // Add to the vectors used to build and execute the computations after the loop.
+ all_arguments.push_back(replicated_arguments);
+ all_arguments.insert(all_arguments.end(), executors.size() - 1, {{}});
+ module_protos.push_back(&request.computation());
+ module_configs.push_back(std::move(module_config));
+ computation_names.insert(computation_names.end(), executors.size(),
+ request.computation().name());
+ all_executors.push_back(executors);
+ device_handles.insert(device_handles.end(),
+ execution_options.device_handles().begin(),
+ execution_options.device_handles().end());
+ }
+
+ // Build the HloModules and compile to generate the executables.
+ //
+ // TODO(jlebar): There's currently no way to pass a device allocator to
+ // ExecuteGraphParallel, so we have to pass a null device_allocator below.
+ TF_ASSIGN_OR_RETURN(std::vector<std::unique_ptr<Executable>> executables,
+ BuildExecutables(module_protos, std::move(module_configs),
+ execute_backend_.get(), all_executors,
+ /*device_allocator=*/nullptr));
+ std::vector<Executable*> executable_ptrs;
+ executable_ptrs.reserve(executables.size());
+ for (const auto& executable : executables) {
+ executable_ptrs.push_back(executable.get());
+ }
+
+ // Execute the generated executables in parallel and return the device
+ // handles for each computation's output.
+ ExecutionProfile profile;
+ TF_ASSIGN_OR_RETURN(
+ std::vector<GlobalDataHandle> outputs,
+ ExecuteParallelAndRegisterResult(executable_ptrs, all_arguments,
+ execute_backend_.get(), device_handles,
+ computation_names, &profile));
+ for (const GlobalDataHandle& output : outputs) {
+ ExecuteResponse response;
+ *response.mutable_output() = output;
+ *response.mutable_profile() = profile;
+ *result->add_responses() = response;
+ }
+
+ VLOG(1) << "successfully completed 'execute-graph-parallel' request";
+ return tensorflow::Status::OK();
}
tensorflow::Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg,
@@ -872,6 +1017,20 @@ tensorflow::Status Service::ExecuteOneToN(const ExecuteRequest* arg,
*parallel_arg.add_requests() = *arg;
ExecuteParallelResponse parallel_result;
TF_RETURN_IF_ERROR(ExecuteParallel(&parallel_arg, &parallel_result));
+ return PickParallelResponse(parallel_result, result);
+}
+
+tensorflow::Status Service::ExecuteOneToN(const ExecuteGraphRequest* arg,
+ ExecuteResponse* result) {
+ ExecuteGraphParallelRequest parallel_arg;
+ *parallel_arg.add_requests() = *arg;
+ ExecuteParallelResponse parallel_result;
+ TF_RETURN_IF_ERROR(ExecuteGraphParallel(&parallel_arg, &parallel_result));
+ return PickParallelResponse(parallel_result, result);
+}
+
+tensorflow::Status Service::PickParallelResponse(
+ const ExecuteParallelResponse& parallel_result, ExecuteResponse* result) {
// The "result device" selection is a bit hacky, but better than assuming it
// is device 0. We have b/76035356 for restructuring the client API to clean
// up the current asymmetries and support more functionalities.
@@ -999,8 +1158,14 @@ tensorflow::Status Service::ExecuteGraph(const ExecuteGraphRequest* arg,
if (!arg->has_computation()) {
return InvalidArgument("computations may not be empty");
}
+ if (!arg->computation().has_program_shape()) {
+ return InvalidArgument("programe shape may not be empty");
+ }
- // TODO(b/74197823): Handle partitioning.
+ // If we received multiple device handles, we must partition the module.
+ if (arg->execution_options().device_handles_size() > 1) {
+ return ExecuteOneToN(arg, result);
+ }
TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*execute_backend_,
SingleComputationDeviceHandle()));
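
The new device-count precondition totals the requested device handles across all requests with std::accumulate before any compilation starts. The same fold over a plain struct (the Request type here is a stand-in, not the ExecuteGraphRequest proto):

    #include <numeric>
    #include <vector>

    struct Request { int num_device_handles; };

    int TotalRequestedDevices(const std::vector<Request>& requests) {
      return std::accumulate(requests.begin(), requests.end(), 0,
                             [](int acc, const Request& r) {
                               return acc + r.num_device_handles;
                             });
    }
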
diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h
index ebe4a2e043..e09d58bbe7 100644
--- a/tensorflow/compiler/xla/service/service.h
+++ b/tensorflow/compiler/xla/service/service.h
@@ -278,6 +278,20 @@ class Service : public ServiceInterface {
const ExecutionOptions& execution_options,
const UserComputation* user_computation = nullptr);
+ // Picks a parallel response and fills the result.
+ Status PickParallelResponse(const ExecuteParallelResponse& parallel_result,
+ ExecuteResponse* result);
+
+ // Prepares the executors for parallel execution.
+ StatusOr<std::vector<perftools::gputools::StreamExecutor*>> GetExecutors(
+ const ExecutionOptions& execution_options, int64 requests_size,
+ int64 request_index) const;
+
+ // Prepares the arguments for parallel execution.
+ StatusOr<std::vector<std::vector<const ShapedBuffer*>>> GetArguments(
+ const ExecutionOptions& execution_options,
+ tensorflow::gtl::ArraySlice<const GlobalDataHandle*> arguments);
+
protected:
friend class LocalExecutable;
@@ -334,6 +348,12 @@ class Service : public ServiceInterface {
Backend* backend,
std::vector<std::vector<perftools::gputools::StreamExecutor*>> executors,
DeviceMemoryAllocator* device_allocator);
+ StatusOr<std::vector<std::unique_ptr<Executable>>> BuildExecutables(
+ const std::vector<const HloModuleProto*>& module_protos,
+ std::vector<std::unique_ptr<HloModuleConfig>> module_configs,
+ Backend* backend,
+ std::vector<std::vector<perftools::gputools::StreamExecutor*>> executors,
+ DeviceMemoryAllocator* device_allocator);
// Similar to BuildExecutable, but look in the compilation cache for the
// executable first. If the executable is not in the cache, it is built and
@@ -378,6 +398,8 @@ class Service : public ServiceInterface {
// will be the result of this computation.
tensorflow::Status ExecuteOneToN(const ExecuteRequest* arg,
ExecuteResponse* result);
+ tensorflow::Status ExecuteOneToN(const ExecuteGraphRequest* arg,
+ ExecuteResponse* result);
// Convenience function which checks whether the given shape_with_layout
// (presumably passed by the client to set the result layout) is valid for the
diff --git a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc
index fcdb2e01fb..532f7fd5bf 100644
--- a/tensorflow/compiler/xla/service/user_computation.cc
+++ b/tensorflow/compiler/xla/service/user_computation.cc
@@ -3491,7 +3491,6 @@ void ComputationLowerer::Visit(
HloInstruction* operand = lookup_instruction(trace_request.operand());
hlo_instruction = add_instruction(
HloInstruction::CreateTrace(trace_request.tag(), operand));
- operand->set_tracing(hlo_instruction);
break;
}
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index e337669aeb..6f58c20f34 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -347,10 +347,10 @@ xla_test(
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:client_library",
- "//tensorflow/compiler/xla/client:computation",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/lib:arithmetic",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla/service:platform_util",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
@@ -937,8 +937,8 @@ xla_test(
deps = [
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:reference_util",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
@@ -977,9 +977,8 @@ xla_test(
"//tensorflow/compiler/xla:reference_util",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla/client:client_library",
- "//tensorflow/compiler/xla/client:computation",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/service:computation_placer",
"//tensorflow/compiler/xla/service:device_memory_allocator",
"//tensorflow/compiler/xla/service:local_service",
@@ -1444,9 +1443,9 @@ xla_test(
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
- "//tensorflow/compiler/xla/client:computation",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
@@ -1566,6 +1565,8 @@ xla_test(
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:test_utils",
diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc
index 4a9faef1dc..17c6a83c1a 100644
--- a/tensorflow/compiler/xla/tests/client_library_test_base.cc
+++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc
@@ -601,6 +601,12 @@ ComputationDataHandle ClientLibraryTestBase::CreateConstantFromLiteral(
use_bfloat16_ ? *LiteralTestUtil::ConvertF32ToBF16(literal) : literal);
}
+XlaOp ClientLibraryTestBase::CreateConstantFromLiteral(const Literal& literal,
+ XlaBuilder* builder) {
+ return builder->ConstantLiteral(
+ use_bfloat16_ ? *LiteralTestUtil::ConvertF32ToBF16(literal) : literal);
+}
+
template void ClientLibraryTestBase::ComputeAndCompareLiteral(
ComputationBuilder* builder, const Literal& expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments,
diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h
index be90f14c8e..52f31b0669 100644
--- a/tensorflow/compiler/xla/tests/client_library_test_base.h
+++ b/tensorflow/compiler/xla/tests/client_library_test_base.h
@@ -312,6 +312,7 @@ class ClientLibraryTestBase : public ::testing::Test {
// will be converted to BF16s.
ComputationDataHandle CreateConstantFromLiteral(const Literal& literal,
ComputationBuilder* builder);
+ XlaOp CreateConstantFromLiteral(const Literal& literal, XlaBuilder* builder);
// Creates a constant instruction with the given array. When the use_bfloat16
// flag is set but the array has float elements, the elements will be
@@ -322,6 +323,12 @@ class ClientLibraryTestBase : public ::testing::Test {
return CreateConstantFromLiteral(*Literal::CreateFromArray(array), builder);
}
+ template <typename NativeT>
+ XlaOp CreateConstantFromArray(const Array<NativeT>& array,
+ XlaBuilder* builder) {
+ return CreateConstantFromLiteral(*Literal::CreateFromArray(array), builder);
+ }
+
// Same as CreateConstantFromArray, but for scalars.
template <typename NativeT>
ComputationDataHandle CreateConstantFromScalar(NativeT value,
@@ -330,6 +337,12 @@ class ClientLibraryTestBase : public ::testing::Test {
builder);
}
+ template <typename NativeT>
+ XlaOp CreateConstantFromScalar(NativeT value, XlaBuilder* builder) {
+ return CreateConstantFromLiteral(*Literal::CreateR0<NativeT>(value),
+ builder);
+ }
+
// Creates a parameter instruction that wraps a given value and then stores
// into "data_handle" the global handle for that parameter.
//
diff --git a/tensorflow/compiler/xla/tests/client_test.cc b/tensorflow/compiler/xla/tests/client_test.cc
index 045148cdd1..32e2f2c084 100644
--- a/tensorflow/compiler/xla/tests/client_test.cc
+++ b/tensorflow/compiler/xla/tests/client_test.cc
@@ -19,6 +19,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
@@ -109,14 +111,14 @@ XLA_TEST_F(ClientTest, ExecuteWithTupleLayout) {
XLA_TEST_F(ClientTest,
DISABLED_ON_CPU_PARALLEL(DISABLED_ON_GPU(ExecuteParallel))) {
- Computation add_with_one_arg, mul_with_two_args, dot_with_one_arg;
+ XlaComputation add_with_one_arg, mul_with_two_args, dot_with_one_arg;
Shape shape = ShapeUtil::MakeShape(S32, {2, 2});
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> const_arg,
client_->TransferToServer(*Literal::CreateR2<int32>({{5, 6}, {7, 8}})));
- ComputationBuilder b(client_, TestName() + ".add");
+ XlaBuilder b(TestName() + ".add");
b.Add(b.Parameter(0, shape, "param_0"),
b.ConstantR2<int32>({{1, 2}, {3, 4}}));
TF_ASSERT_OK_AND_ASSIGN(add_with_one_arg, b.Build());
@@ -124,14 +126,14 @@ XLA_TEST_F(ClientTest,
// We can't really test parallel execution on CPU since all of the cores in a
// CPU are presented as a single device. So for now we test "parallel"
// execution on a single device.
- std::vector<Client::ComputationInstance> computation_instances;
+ std::vector<Client::XlaComputationInstance> computation_instances;
TF_ASSERT_OK_AND_ASSIGN(std::vector<xla::DeviceHandle> devices,
client_->GetDeviceHandles(1));
ASSERT_EQ(devices.size(), 1);
ExecutionOptions options = execution_options_;
*options.add_device_handles() = devices[0];
- computation_instances.push_back(Client::ComputationInstance(
+ computation_instances.push_back(Client::XlaComputationInstance(
add_with_one_arg, {const_arg.get()}, options, nullptr));
TF_ASSERT_OK_AND_ASSIGN(auto results,
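
The test migration shows the shape of the builder API change: XlaBuilder is constructed from a name alone, with no Client*, and Build() yields a client-independent XlaComputation rather than a Computation. Condensed from the hunk above (shape and client setup elided):

    XlaBuilder b("add");
    b.Add(b.Parameter(0, shape, "param_0"),
          b.ConstantR2<int32>({{1, 2}, {3, 4}}));
    StatusOr<XlaComputation> computation = b.Build();
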
diff --git a/tensorflow/compiler/xla/tests/concat_test.cc b/tensorflow/compiler/xla/tests/concat_test.cc
index fb0e9c724a..a4c8a83eb1 100644
--- a/tensorflow/compiler/xla/tests/concat_test.cc
+++ b/tensorflow/compiler/xla/tests/concat_test.cc
@@ -18,9 +18,9 @@ limitations under the License.
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array3d.h"
-#include "tensorflow/compiler/xla/client/computation.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
#include "tensorflow/compiler/xla/reference_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
@@ -38,9 +38,9 @@ using ::testing::HasSubstr;
// Concatenate expects at least one argument.
XLA_TEST_F(ConcatTest, Concat_Nothing) {
- ComputationBuilder builder(client_, TestName());
- auto concatenated = builder.ConcatInDim({}, 0);
- StatusOr<Computation> computation_status = builder.Build();
+ XlaBuilder builder(TestName());
+ builder.ConcatInDim({}, 0);
+ StatusOr<XlaComputation> computation_status = builder.Build();
ASSERT_FALSE(computation_status.ok());
EXPECT_THAT(computation_status.status().ToString(),
HasSubstr("Concatenate expects at least one argument"));
@@ -48,18 +48,18 @@ XLA_TEST_F(ConcatTest, Concat_Nothing) {
// Concatenate with one argument works.
XLA_TEST_F(ConcatTest, Concat_R1_With_Nothing) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto a = builder.ConstantR1<float>({42.0, 64.0});
- auto concatenated = builder.ConcatInDim({a}, 0);
+ builder.ConcatInDim({a}, 0);
std::vector<float> expected = {42, 64};
ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(ConcatTest, Concat_R1_L0_With_Nothing) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto a = builder.ConstantR1<float>({});
- auto concatenated = builder.ConcatInDim({a}, 0);
+ builder.ConcatInDim({a}, 0);
std::vector<float> expected = {};
ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
@@ -68,51 +68,51 @@ XLA_TEST_F(ConcatTest, Concat_R1_L0_With_Nothing) {
// Show that we can't concatenate R0 with R0 because we can't name the dimension
// to concatenate on.
XLA_TEST_F(ConcatTest, CannotConcatR0WithR0) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto a = builder.ConstantR0<float>(42.0);
auto b = builder.ConstantR0<float>(64.0);
- auto concatenated = builder.ConcatInDim({a, b}, 0);
- StatusOr<Computation> computation_status = builder.Build();
+ builder.ConcatInDim({a, b}, 0);
+ StatusOr<XlaComputation> computation_status = builder.Build();
ASSERT_FALSE(computation_status.ok());
EXPECT_THAT(computation_status.status().ToString(),
HasSubstr("out of bounds: 0"));
}
XLA_TEST_F(ConcatTest, Concat_R1_L0_With_R1_L0) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto a = builder.ConstantR1<float>({});
auto b = builder.ConstantR1<float>({});
- auto concatenated = builder.ConcatInDim({a, b}, 0);
+ builder.ConcatInDim({a, b}, 0);
std::vector<float> expected = {};
ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(ConcatTest, Concat_R1_L0_With_R1_L1) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto a = builder.ConstantR1<float>({});
auto b = builder.ConstantR1<float>({256.0});
- auto concatenated = builder.ConcatInDim({a, b}, 0);
+ builder.ConcatInDim({a, b}, 0);
std::vector<float> expected = {256};
ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(ConcatTest, Concat_R1_L2_With_R1_L0) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto a = builder.ConstantR1<float>({42.0, 64.0});
auto b = builder.ConstantR1<float>({});
- auto concatenated = builder.ConcatInDim({a, b}, 0);
+ builder.ConcatInDim({a, b}, 0);
std::vector<float> expected = {42, 64};
ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(ConcatTest, Concat_R1_L2_With_R1_L1) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto a = builder.ConstantR1<float>({42.0, 64.0});
auto b = builder.ConstantR1<float>({256.0});
- auto concatenated = builder.ConcatInDim({a, b}, 0);
+ builder.ConcatInDim({a, b}, 0);
std::vector<float> expected = {42, 64, 256};
ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
@@ -129,20 +129,20 @@ XLA_TEST_F(ConcatTest, Concat_R1_L253_With_R1_L7) {
expected[253 + i] = rhs[i] = 253 + i + 1;
}
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto a = builder.ConstantR1<float>(lhs);
auto b = builder.ConstantR1<float>(rhs);
- auto concatenated = builder.ConcatInDim({a, b}, 0);
+ builder.ConcatInDim({a, b}, 0);
ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(ConcatTest, Concat_0x0_With_0x0) {
for (int dim : {0, 1}) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto a = builder.ConstantR2FromArray2D(Array2D<float>(0, 0));
auto b = builder.ConstantR2FromArray2D(Array2D<float>(0, 0));
- auto concatenated = builder.ConcatInDim({a, b}, dim);
+ builder.ConcatInDim({a, b}, dim);
ComputeAndCompareR2<float>(&builder, Array2D<float>(0, 0), {},
ErrorSpec(0.0001));
@@ -150,26 +150,27 @@ XLA_TEST_F(ConcatTest, Concat_0x0_With_0x0) {
}
XLA_TEST_F(ConcatTest, Concat_1x1_With_1x1_InDim0) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto a_array = CreatePatternedMatrix(1, 1);
auto b_array = CreatePatternedMatrix(1, 1, /*offset=*/64.0);
auto a = builder.ConstantR2FromArray2D(*a_array);
auto b = builder.ConstantR2FromArray2D(*b_array);
- auto concatenated = builder.ConcatInDim({a, b}, 0);
+ builder.ConcatInDim({a, b}, 0);
Array2D<float> expected({
- {0}, {64},
+ {0},
+ {64},
});
ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(ConcatTest, Concat_1x1_With_1x1_InDim1) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto a_array = CreatePatternedMatrix(1, 1);
auto b_array = CreatePatternedMatrix(1, 1, /*offset=*/64.0);
auto a = builder.ConstantR2FromArray2D(*a_array);
auto b = builder.ConstantR2FromArray2D(*b_array);
- auto concatenated = builder.ConcatInDim({a, b}, 1);
+ builder.ConcatInDim({a, b}, 1);
Array2D<float> expected({
{0, 64},
@@ -178,22 +179,22 @@ XLA_TEST_F(ConcatTest, Concat_1x1_With_1x1_InDim1) {
}
XLA_TEST_F(ConcatTest, Concat2x0With2x5) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto b_array = CreatePatternedMatrix(2, 5, /*offset=*/64.0);
auto a = builder.ConstantR2FromArray2D(Array2D<float>(2, 0));
auto b = builder.ConstantR2FromArray2D(*b_array);
- auto concatenated = builder.ConcatInDim({a, b}, 1);
+ builder.ConcatInDim({a, b}, 1);
ComputeAndCompareR2<float>(&builder, *b_array, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(ConcatTest, Concat2x3With2x5) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto a_array = CreatePatternedMatrix(2, 3);
auto b_array = CreatePatternedMatrix(2, 5, /*offset=*/64.0);
auto a = builder.ConstantR2FromArray2D(*a_array);
auto b = builder.ConstantR2FromArray2D(*b_array);
- auto concatenated = builder.ConcatInDim({a, b}, 1);
+ builder.ConcatInDim({a, b}, 1);
Array2D<float> expected({
{0, 1, 2, 64, 65, 66, 67, 68},
@@ -203,22 +204,22 @@ XLA_TEST_F(ConcatTest, Concat2x3With2x5) {
}
XLA_TEST_F(ConcatTest, Concat3x2With0x2) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto a_array = CreatePatternedMatrix(3, 2);
auto a = builder.ConstantR2FromArray2D(*a_array);
auto b = builder.ConstantR2FromArray2D(Array2D<float>(0, 2));
- auto concatenated = builder.ConcatInDim({a, b}, 0);
+ builder.ConcatInDim({a, b}, 0);
ComputeAndCompareR2<float>(&builder, *a_array, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(ConcatTest, Concat3x2With5x2) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto a_array = CreatePatternedMatrix(3, 2);
auto b_array = CreatePatternedMatrix(5, 2, /*offset=*/64.0);
auto a = builder.ConstantR2FromArray2D(*a_array);
auto b = builder.ConstantR2FromArray2D(*b_array);
- auto concatenated = builder.ConcatInDim({a, b}, 0);
+ builder.ConcatInDim({a, b}, 0);
Array2D<float> expected({
{0, 1},
@@ -234,16 +235,16 @@ XLA_TEST_F(ConcatTest, Concat3x2With5x2) {
}
XLA_TEST_F(ConcatTest, Concat_R3_3x0x2_3x0x1) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto a = builder.ConstantR3FromArray3D(Array3D<float>(3, 0, 2));
auto b = builder.ConstantR3FromArray3D(Array3D<float>(3, 0, 1));
- auto concatenated = builder.ConcatInDim({a, b}, 2);
+ builder.ConcatInDim({a, b}, 2);
ComputeAndCompareR3<float>(&builder, Array3D<float>(3, 0, 3), {},
ErrorSpec(0.0001));
}
XLA_TEST_F(ConcatTest, Concat_R3_3x1x2_3x1x1) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
Array3D<float> a_array({
// 3x1x2
{{0, 1}},
@@ -258,27 +259,29 @@ XLA_TEST_F(ConcatTest, Concat_R3_3x1x2_3x1x1) {
});
auto a = builder.ConstantR3FromArray3D(a_array);
auto b = builder.ConstantR3FromArray3D(b_array);
- auto concatenated = builder.ConcatInDim({a, b}, 2);
+ builder.ConcatInDim({a, b}, 2);
Array3D<float> expected({
- {{0, 1, 6}}, {{2, 3, 7}}, {{4, 5, 8}},
+ {{0, 1, 6}},
+ {{2, 3, 7}},
+ {{4, 5, 8}},
});
ComputeAndCompareR3<float>(&builder, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(ConcatTest, Concat_R1_1x1_1x1_1x1) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto a = builder.ConstantR1<float>({42.0});
auto b = builder.ConstantR1<float>({64.0});
auto c = builder.ConstantR1<float>({256.0});
- auto concatenated = builder.ConcatInDim({a, b, c}, 0);
+ builder.ConcatInDim({a, b, c}, 0);
std::vector<float> expected = {42, 64, 256};
ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(ConcatTest, Concat_R3_3x1x2_3x1x1_3x1x1) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
Array3D<float> a_array({
// 3x1x2
{{0, 1}},
@@ -300,35 +303,35 @@ XLA_TEST_F(ConcatTest, Concat_R3_3x1x2_3x1x1_3x1x1) {
auto a = builder.ConstantR3FromArray3D(a_array);
auto b = builder.ConstantR3FromArray3D(b_array);
auto c = builder.ConstantR3FromArray3D(c_array);
- auto concatenated = builder.ConcatInDim({a, b, c}, 2);
+ builder.ConcatInDim({a, b, c}, 2);
Array3D<float> expected({
- {{0, 1, 2, 3}}, {{4, 5, 6, 7}}, {{8, 9, 10, 11}},
+ {{0, 1, 2, 3}},
+ {{4, 5, 6, 7}},
+ {{8, 9, 10, 11}},
});
ComputeAndCompareR3<float>(&builder, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(ConcatTest, DoubleConcatLeftAssociative) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto a = builder.ConstantR1<float>({42.0});
auto b = builder.ConstantR1<float>({64.0});
auto c = builder.ConstantR1<float>({256.0});
// concatenated = (a concat b) concat c
- auto concatenated =
- builder.ConcatInDim({builder.ConcatInDim({a, b}, 0), c}, 0);
+ builder.ConcatInDim({builder.ConcatInDim({a, b}, 0), c}, 0);
std::vector<float> expected = {42, 64, 256};
ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(ConcatTest, DoubleConcatRightAssociative) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto a = builder.ConstantR1<float>({42.0});
auto b = builder.ConstantR1<float>({64.0});
auto c = builder.ConstantR1<float>({256.0});
// concatenated = a concat (b concat c)
- auto concatenated =
- builder.ConcatInDim({a, builder.ConcatInDim({b, c}, 0)}, 0);
+ builder.ConcatInDim({a, builder.ConcatInDim({b, c}, 0)}, 0);
std::vector<float> expected = {42, 64, 256};
ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
@@ -342,7 +345,7 @@ XLA_TEST_F(ConcatTest, Concat_1x1024_With_1x1024_InDim0) {
rhs(0, i) = i + 1024;
}
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto a = builder.ConstantR2FromArray2D<float>(lhs);
auto b = builder.ConstantR2FromArray2D<float>(rhs);
builder.ConcatInDim({a, b}, 0);
@@ -363,7 +366,7 @@ XLA_TEST_F(ConcatTest, Concat_1x1024_With_1x1024_InDim1) {
rhs(0, i) = i + 1024;
}
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto a = builder.ConstantR2FromArray2D<float>(lhs);
auto b = builder.ConstantR2FromArray2D<float>(rhs);
builder.ConcatInDim({a, b}, 1);
@@ -388,7 +391,7 @@ XLA_TEST_F(ConcatTest, Concat_64x64_With_64x2) {
}
}
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto a = builder.ConstantR2FromArray2D<float>(lhs);
auto b = builder.ConstantR2FromArray2D<float>(rhs);
builder.ConcatInDim({a, b}, 1);
@@ -404,13 +407,13 @@ XLA_TEST_F(ConcatTest, Concat_64x64_With_64x2) {
// Show that we can't concatenate with opaques.
XLA_TEST_F(ConcatTest, CannotConcatOpaques) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto opaque_shape = ShapeUtil::MakeOpaqueShape();
auto r1f32 = xla::ShapeUtil::MakeShape(xla::F32, {1});
auto x = builder.Parameter(0, r1f32, "x");
auto y = builder.Parameter(1, opaque_shape, "y");
- auto concatenated = builder.ConcatInDim({x, y}, 0);
- StatusOr<Computation> computation_status = builder.Build();
+ builder.ConcatInDim({x, y}, 0);
+ StatusOr<XlaComputation> computation_status = builder.Build();
ASSERT_FALSE(computation_status.ok());
EXPECT_THAT(
computation_status.status().ToString(),
@@ -418,23 +421,23 @@ XLA_TEST_F(ConcatTest, CannotConcatOpaques) {
}
XLA_TEST_F(ConcatTest, ConcatSeveralBoxedPredicates) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto p0 = builder.ConstantR1<bool>({true});
auto p1 = builder.ConstantR1<bool>({false});
auto p2 = builder.ConstantR1<bool>({true});
- auto concatenated = builder.ConcatInDim({p0, p1, p2}, 0);
+ builder.ConcatInDim({p0, p1, p2}, 0);
bool expected[] = {true, false, true};
ComputeAndCompareR1<bool>(&builder, expected, {});
}
XLA_TEST_F(ConcatTest, ConcatSeveralR1S32s) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto a0 = builder.ConstantR1<int32>({1});
auto a1 = builder.ConstantR1<int32>({2, 3});
auto a2 = builder.ConstantR1<int32>({4, 5, 6});
auto a3 = builder.ConstantR1<int32>({7, 8, 9, 10});
- auto concatenated = builder.ConcatInDim({a0, a1, a2, a3}, 0);
+ builder.ConcatInDim({a0, a1, a2, a3}, 0);
std::vector<int32> expected(10);
std::iota(expected.begin(), expected.end(), 1);
@@ -442,7 +445,7 @@ XLA_TEST_F(ConcatTest, ConcatSeveralR1S32s) {
}
XLA_TEST_F(ConcatTest, ConcatR3WeirdDims) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
Array3D<float> arr0(9, 17, 1);
arr0.Fill(1);
@@ -462,14 +465,14 @@ XLA_TEST_F(ConcatTest, ConcatR3WeirdDims) {
}
}
- ComputationDataHandle h0;
+ XlaOp h0;
auto p0 = CreateR3Parameter<float>(arr0, /*parameter_number=*/0, "p0",
&builder, &h0);
- ComputationDataHandle h1;
+ XlaOp h1;
auto p1 = CreateR3Parameter<float>(arr1, /*parameter_number=*/1, "p1",
&builder, &h1);
- auto concatenated = builder.ConcatInDim({h0, h1}, 2);
+ builder.ConcatInDim({h0, h1}, 2);
ComputeAndCompareR3<float>(&builder, expected, {p0.get(), p1.get()});
}
@@ -495,7 +498,7 @@ TEST_P(ConcatR2BinaryTest, DoIt) {
Array2D<int32> rhs(spec.rhs_dim0, spec.rhs_dim1);
rhs.FillUnique(1000);
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto a0 = builder.ConstantR2FromArray2D<int32>(lhs);
auto a1 = builder.ConstantR2FromArray2D<int32>(rhs);
builder.ConcatInDim({a0, a1}, spec.concat_dimension);
@@ -521,7 +524,7 @@ XLA_TEST_F(ConcatTest, ConcatOperandsOfSameOperand) {
auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie();
auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie();
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto x = builder.Parameter(0, f32_scalar, "x");
auto y = builder.Parameter(1, f32_scalar, "y");
auto mul = builder.Mul(x, y);
@@ -545,7 +548,7 @@ XLA_TEST_F(ConcatTest, ConcatBroadcastArgument) {
auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie();
auto z_data = client_->TransferToServer(*z_literal).ConsumeValueOrDie();
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto x = builder.Parameter(0, x_literal->shape(), "x");
auto y = builder.Parameter(1, f32_scalar, "y");
auto z = builder.Parameter(2, f32_scalar, "z");
@@ -573,7 +576,7 @@ XLA_TEST_F(ConcatTest, ConcatBroadcastArgumentR3) {
auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie();
auto z_data = client_->TransferToServer(*z_literal).ConsumeValueOrDie();
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto x = builder.Parameter(0, x_literal->shape(), "x");
auto y = builder.Parameter(1, f32_scalar, "y");
auto z = builder.Parameter(2, f32_scalar, "z");
diff --git a/tensorflow/compiler/xla/tests/convert_test.cc b/tensorflow/compiler/xla/tests/convert_test.cc
index 9a899b7914..0842a8918b 100644
--- a/tensorflow/compiler/xla/tests/convert_test.cc
+++ b/tensorflow/compiler/xla/tests/convert_test.cc
@@ -230,6 +230,43 @@ XLA_TEST_F(ConvertTest, ConvertR1S32ToR1S64) {
ComputeAndCompareR1<int64>(&builder, expected, {arg_data.get()});
}
+XLA_TEST_F(ConvertTest, ConvertR1F32ToR1S64) {
+ ComputationBuilder builder(client_, TestName());
+ // Test cases from the compiler_rt library.
+ std::vector<float> arg{0.0f,
+ 0.5f,
+ 0.99f,
+ 1.0f,
+ 1.5f,
+ 1.99f,
+ 2.0f,
+ 2.01f,
+ 2147483648.f,
+ -0.5f,
+ -0.99f,
+ -1.0f,
+ -1.5f,
+ -1.99f,
+ -2.0f,
+ -2.01f,
+ 0x1.FFFFFEp+62F,
+ 0x1.FFFFFCp+62F,
+ -0x1.FFFFFEp+62F,
+ -0x1.FFFFFCp+62F};
+ std::unique_ptr<Literal> arg_literal = Literal::CreateR1<float>({arg});
+ auto arg_param = builder.Parameter(0, arg_literal->shape(), "arg_param");
+ std::unique_ptr<GlobalData> arg_data =
+ client_->TransferToServer(*arg_literal).ConsumeValueOrDie();
+
+ builder.ConvertElementType(arg_param, S64);
+
+ std::vector<int64> expected(arg.size());
+ for (int64 i = 0; i < arg.size(); ++i) {
+ expected[i] = static_cast<int64>(arg[i]);
+ }
+ ComputeAndCompareR1<int64>(&builder, expected, {arg_data.get()});
+}
+
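Editor's note: the new test's inputs probe float-to-int64 truncation at awkward points. Values just above and below integers check rounding toward zero, 2147483648.f is the first value past int32 range, and ±0x1.FFFFFEp+62F / ±0x1.FFFFFCp+62F are, by my reading, the two largest float magnitudes below 2^63, where float's 2^39 ulp makes exact representability matter. The expected values are built with static_cast<int64>, i.e. ordinary C++ truncation; a self-contained sketch of that semantics, independent of XLA:

    #include <cstdio>

    int main() {
      const float cases[] = {0.5f, 1.99f, -2.01f, 2147483648.f, 0x1.FFFFFEp+62F};
      for (float f : cases) {
        // float -> integer conversion truncates toward zero, which is exactly
        // how the test computes its expected[] values.
        std::printf("%g -> %lld\n", f, static_cast<long long>(f));
      }
      return 0;
    }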
XLA_TEST_F(ConvertTest, ConvertR1U8ToR1F32) {
ComputationBuilder builder(client_, TestName());
auto a = builder.ConstantR1<uint8_t>({32, 64});
diff --git a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc
index 4f354e6aef..5f00c34002 100644
--- a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc
+++ b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc
@@ -18,9 +18,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/client/client_library.h"
-#include "tensorflow/compiler/xla/client/computation.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/reference_util.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/local_service.h"
@@ -112,10 +111,8 @@ class DynamicSliceTest : public ClientLibraryTestBase {
void TestR3Wrap() {
// Slice at dimension boundaries, but with sizes that cause indices to wrap.
RunR3<IndexT, DataT>(
- {{{1, 2}, {3, 4}, {5, 6}},
- {{7, 8}, {9, 10}, {11, 12}}},
- {0, 2, 1}, {2, 1, 2},
- {{{6, 5}}, {{12, 11}}});
+ {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}, {0, 2, 1},
+ {2, 1, 2}, {{{6, 5}}, {{12, 11}}});
}
template <typename IndexT, typename DataT>
@@ -137,9 +134,9 @@ class DynamicSliceTest : public ClientLibraryTestBase {
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
// Initialize and transfer dynamic slice start indices parameter.
- ComputationDataHandle starts;
+ XlaOp starts;
std::unique_ptr<GlobalData> start_data = CreateR1Parameter<IndexT>(
slice_starts, 0, "slice_starts", &builder, &starts);
// Build dynamic slice computation.
@@ -163,9 +160,9 @@ class DynamicSliceTest : public ClientLibraryTestBase {
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
// Initialize and transfer dynamic slice start indices parameter.
- ComputationDataHandle starts;
+ XlaOp starts;
std::unique_ptr<GlobalData> start_data = CreateR1Parameter<IndexT>(
slice_starts, 0, "slice_starts", &builder, &starts);
// Build dynamic slice computation.
@@ -189,9 +186,9 @@ class DynamicSliceTest : public ClientLibraryTestBase {
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
// Initialize and transfer dynamic slice start indices parameter.
- ComputationDataHandle starts;
+ XlaOp starts;
std::unique_ptr<GlobalData> start_data = CreateR1Parameter<IndexT>(
slice_starts, 0, "slice_starts", &builder, &starts);
// Build dynamic slice computation.
@@ -282,6 +279,15 @@ XLA_TEST_F(DynamicSliceTest, Int32R3Pred) {
class DynamicUpdateSliceTest : public ClientLibraryTestBase {
protected:
template <typename IndexT, typename DataT>
+ void TestR0() {
+ // Disable the algebraic simplifier; otherwise the op will be replaced by a
+ // constant.
+ execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes(
+ "algsimp");
+ RunR0<IndexT, DataT>(0, 123, {}, 123);
+ }
+
+ template <typename IndexT, typename DataT>
void TestR1() {
// Slice at dimension start.
RunR1<IndexT, DataT>({0, 1, 2, 3, 4, 5, 6, 7}, {8, 9, 10}, {0},
@@ -342,6 +348,35 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase {
}
template <typename IndexT, typename DataT>
+ void RunR0(int input_value_int, int update_value_int,
+ const std::vector<IndexT> slice_starts, int expected_value_int) {
+ Literal input_value =
+ std::move(*Literal::CreateR0(input_value_int)
+ ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
+ .ValueOrDie());
+ Literal update_value =
+ std::move(*Literal::CreateR0(update_value_int)
+ ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
+ .ValueOrDie());
+ Literal expected_value =
+ std::move(*Literal::CreateR0(expected_value_int)
+ ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
+ .ValueOrDie());
+
+ XlaBuilder builder(TestName());
+ // Initialize and transfer dynamic slice start indices parameter.
+ XlaOp starts;
+ std::unique_ptr<GlobalData> start_data = CreateR1Parameter<IndexT>(
+ slice_starts, 0, "slice_starts", &builder, &starts);
+ // Build dynamic slice computation.
+ auto input = builder.ConstantLiteral(input_value);
+ auto update = builder.ConstantLiteral(update_value);
+ builder.DynamicUpdateSlice(input, update, starts);
+ // Run computation and compare against expected values.
+ ComputeAndCompareLiteral(&builder, expected_value, {start_data.get()});
+ }
+
+ template <typename IndexT, typename DataT>
void RunR1(tensorflow::gtl::ArraySlice<int> input_values_int,
tensorflow::gtl::ArraySlice<int> update_values_int,
const std::vector<IndexT> slice_starts,
@@ -359,9 +394,9 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase {
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
// Initialize and transfer dynamic slice start indices parameter.
- ComputationDataHandle starts;
+ XlaOp starts;
std::unique_ptr<GlobalData> start_data = CreateR1Parameter<IndexT>(
slice_starts, 0, "slice_starts", &builder, &starts);
// Build dynamic slice computation.
@@ -390,9 +425,9 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase {
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
// Initialize and transfer dynamic slice start indices parameter.
- ComputationDataHandle starts;
+ XlaOp starts;
std::unique_ptr<GlobalData> start_data = CreateR1Parameter<IndexT>(
slice_starts, 0, "slice_starts", &builder, &starts);
// Build dynamic slice computation.
@@ -421,9 +456,9 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase {
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
// Initialize and transfer dynamic slice start indices parameter.
- ComputationDataHandle starts;
+ XlaOp starts;
std::unique_ptr<GlobalData> start_data = CreateR1Parameter<IndexT>(
slice_starts, 0, "slice_starts", &builder, &starts);
// Build dynamic slice computation.
@@ -474,13 +509,13 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase {
}
// Build dynamic slice computation.
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
// Initialize and transfer input parameter.
- ComputationDataHandle input;
+ XlaOp input;
std::unique_ptr<GlobalData> input_data =
CreateR3Parameter<T>(input_values, 0, "input_values", &builder, &input);
// Initialize and transfer update parameter.
- ComputationDataHandle update;
+ XlaOp update;
std::unique_ptr<GlobalData> update_data = CreateR3Parameter<T>(
update_values, 1, "update_values", &builder, &update);
auto starts = builder.ConstantR1<int32>({index, 0, 0});
@@ -500,6 +535,11 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase {
}
};
+XLA_TEST_F(DynamicUpdateSliceTest, Int32R0BF16) { TestR0<int32, bfloat16>(); }
+XLA_TEST_F(DynamicUpdateSliceTest, Int32R0) { TestR0<int32, float>(); }
+XLA_TEST_F(DynamicUpdateSliceTest, Int64R0) { TestR0<int64, float>(); }
+XLA_TEST_F(DynamicUpdateSliceTest, UInt64R0) { TestR0<uint64, float>(); }
+
// TODO(b/71820067): The CPU parallel backend failed for this on 2018-01-10.
XLA_TEST_F(DynamicUpdateSliceTest, DISABLED_ON_CPU_PARALLEL(Int32R1BF16)) {
TestR1<int32, bfloat16>();
@@ -672,7 +712,7 @@ void BM_DynamicSlice(int num_iters) {
TransferManager::GetForPlatform(platform).ValueOrDie();
int device_ordinal = client->default_device_ordinal();
- ComputationBuilder builder(client, "DynamicSlice");
+ XlaBuilder builder("DynamicSlice");
// Create input as a constant: shape [1, 2, 3, 4]
auto input_literal = Literal::CreateR4(
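Editor's note: the scalar DynamicUpdateSlice test must disable the algebraic simplifier because all of its operands are constants; without that, the op folds away before it ever reaches the backend under test. The hook used above is worth calling out, since it is generally useful for pinning an op in place (a fixture-level sketch, not standalone):

    //   execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes(
    //       "algsimp");  // disable the named HLO pass for this test's executions
    //   RunR0<IndexT, DataT>(/*input=*/0, /*update=*/123, /*starts=*/{}, /*expected=*/123);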
diff --git a/tensorflow/compiler/xla/tests/reduce_test.cc b/tensorflow/compiler/xla/tests/reduce_test.cc
index 3a097a01ab..d24927d22b 100644
--- a/tensorflow/compiler/xla/tests/reduce_test.cc
+++ b/tensorflow/compiler/xla/tests/reduce_test.cc
@@ -57,6 +57,11 @@ limitations under the License.
namespace xla {
namespace {
+using FuncGeneratorForType = Computation (*)(PrimitiveType,
+ ComputationBuilder*);
+
+using FuncGenerator = Computation (*)(ComputationBuilder*);
+
class ReduceTest : public ClientLibraryTestBase {
protected:
ReduceTest() {
@@ -755,53 +760,57 @@ XLA_TEST_F(ReduceTest, ReduceR3AmongDim2) {
}
XLA_TEST_F(ReduceTest, VectorizedReduce_Add) {
- RunVectorizedReduceTest(CreateScalarAddComputation,
- [](float a, float b) { return a + b; },
- [](int32 a, int32 b) {
- return static_cast<int32>(static_cast<uint32>(a) +
- static_cast<uint32>(b));
- },
- [](uint32 a, uint32 b) { return a + b; }, 0.0, 0, 0);
+ RunVectorizedReduceTest(
+ static_cast<FuncGeneratorForType>(CreateScalarAddComputation),
+ [](float a, float b) { return a + b; },
+ [](int32 a, int32 b) {
+ return static_cast<int32>(static_cast<uint32>(a) +
+ static_cast<uint32>(b));
+ },
+ [](uint32 a, uint32 b) { return a + b; }, 0.0, 0, 0);
}
XLA_TEST_F(ReduceTest, VectorizedReduce_Multiply) {
- RunVectorizedReduceTest(CreateScalarMultiplyComputation,
- [](float a, float b) { return a * b; },
- [](int32 a, int32 b) {
- return static_cast<int32>(static_cast<uint32>(a) *
- static_cast<uint32>(b));
- },
- [](uint32 a, uint32 b) { return a * b; }, 1.0, 1, 1);
+ RunVectorizedReduceTest(
+ static_cast<FuncGeneratorForType>(CreateScalarMultiplyComputation),
+ [](float a, float b) { return a * b; },
+ [](int32 a, int32 b) {
+ return static_cast<int32>(static_cast<uint32>(a) *
+ static_cast<uint32>(b));
+ },
+ [](uint32 a, uint32 b) { return a * b; }, 1.0, 1, 1);
}
XLA_TEST_F(ReduceTest, VectorizedReduce_Max) {
- RunVectorizedReduceTest(CreateScalarMaxComputation,
- [](float a, float b) { return std::max(a, b); },
- [](int32 a, int32 b) { return std::max(a, b); },
- [](uint32 a, uint32 b) { return std::max(a, b); },
- std::numeric_limits<float>::min(),
- std::numeric_limits<int32>::min(),
- std::numeric_limits<uint32>::min());
+ RunVectorizedReduceTest(
+ static_cast<FuncGeneratorForType>(CreateScalarMaxComputation),
+ [](float a, float b) { return std::max(a, b); },
+ [](int32 a, int32 b) { return std::max(a, b); },
+ [](uint32 a, uint32 b) { return std::max(a, b); },
+ std::numeric_limits<float>::min(), std::numeric_limits<int32>::min(),
+ std::numeric_limits<uint32>::min());
}
XLA_TEST_F(ReduceTest, VectorizedReduce_Min) {
- RunVectorizedReduceTest(CreateScalarMinComputation,
- [](float a, float b) { return std::min(a, b); },
- [](int32 a, int32 b) { return std::min(a, b); },
- [](uint32 a, uint32 b) { return std::min(a, b); },
- std::numeric_limits<float>::max(),
- std::numeric_limits<int32>::max(),
- std::numeric_limits<uint32>::max());
+ RunVectorizedReduceTest(
+ static_cast<FuncGeneratorForType>(CreateScalarMinComputation),
+ [](float a, float b) { return std::min(a, b); },
+ [](int32 a, int32 b) { return std::min(a, b); },
+ [](uint32 a, uint32 b) { return std::min(a, b); },
+ std::numeric_limits<float>::max(), std::numeric_limits<int32>::max(),
+ std::numeric_limits<uint32>::max());
}
XLA_TEST_F(ReduceTest, VectorizedReduce_BooleanAnd) {
RunVectorizedReduceTestForType<bool>(
- CreateScalarAndComputation, [](bool a, bool b) { return a && b; }, true);
+ static_cast<FuncGenerator>(CreateScalarAndComputation),
+ [](bool a, bool b) { return a && b; }, true);
}
XLA_TEST_F(ReduceTest, VectorizedReduce_BooleanOr) {
RunVectorizedReduceTestForType<bool>(
- CreateScalarOrComputation, [](bool a, bool b) { return a || b; }, false);
+ static_cast<FuncGenerator>(CreateScalarOrComputation),
+ [](bool a, bool b) { return a || b; }, false);
}
class ReduceR3ToR2Test : public ReduceTest,
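Editor's note: the static_casts added above are overload disambiguation. Once the scalar-computation helpers exist in both a ComputationBuilder and an XlaBuilder flavor, as this migration suggests, the bare function name is an overload set and cannot be deduced into RunVectorizedReduceTest's template parameter; casting to an explicit function-pointer alias selects one overload. A self-contained illustration of the trick (names hypothetical, plain C++):

    #include <cstdio>

    int MakeScalar(int x) { return x + 1; }         // overload A
    double MakeScalar(double x) { return x + 0.5; } // overload B

    template <typename F>
    void Run(F f) { std::printf("%d\n", static_cast<int>(f(41))); }

    using IntMaker = int (*)(int);

    int main() {
      // Run(MakeScalar);  // would not compile: ambiguous overload set
      Run(static_cast<IntMaker>(MakeScalar));  // selects overload A -> prints 42
      return 0;
    }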
diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc
index 9c317fe579..8dd24f1237 100644
--- a/tensorflow/compiler/xla/tests/reduce_window_test.cc
+++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc
@@ -252,6 +252,48 @@ TEST_P(ReduceWindowTest, AmongMajor2DimsMediumSize) {
DefaultErrorSpec());
}
+// Tests the super windowing logic w.r.t. handling a prime number of windows in
+// a major dimension with reduction.
+TEST_P(ReduceWindowTest, PrimeWindowsInReductionDimension) {
+ Array4D<float> input_array(15, 15, 4, 128);
+ input_array.FillRandom(2.f, 4.f);
+
+ int win_len = 3;
+ int win_stride = 2;
+
+ const auto input_data_handle =
+ CreateConstantFromArray(input_array, &builder_);
+
+ Padding padding = Padding::kSame;
+ // Reduce only along the x and y dimensions, using windows of length win_len.
+ ReduceWindowAdd(input_data_handle, {win_len, win_len, 1, 1},
+ {win_stride, win_stride, 1, 1}, padding);
+
+ auto result = ReferenceUtil::ReduceWindow4DAdd(
+ input_array, 0.0f, {win_len, win_len, 1, 1},
+ {win_stride, win_stride, 1, 1}, padding);
+
+ ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*result), {},
+ DefaultErrorSpec());
+}
+
+TEST_P(ReduceWindowTest, ReduceAlongLaneDimension) {
+ Array4D<float> input_array(19, 17, 8, 256);
+ input_array.FillWithMinorDimNum();
+
+ const auto input_data_handle =
+ CreateConstantFromArray(input_array, &builder_);
+
+ Padding padding = Padding::kSame;
+ ReduceWindowAdd(input_data_handle, {1, 1, 1, 11}, {1, 1, 1, 1}, padding);
+
+ auto result = ReferenceUtil::ReduceWindow4DAdd(
+ input_array, 0.0f, {1, 1, 1, 11}, {1, 1, 1, 1}, padding);
+
+ ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*result), {},
+ DefaultErrorSpec());
+}
+
// Tests a reduction function that is not a simple add/min/max/etc.
XLA_TEST_P(ReduceWindowTest, NonstandardReduceFunction) {
Array4D<float> input_array(1, 2, 2, 1);
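Editor's note: both new tests use Padding::kSame; under the usual TF/XLA rule (stated here as an assumption), the output extent along a dimension is ceil(input / stride), independent of window length. A small sketch of that arithmetic for the shapes above:

    #include <cstdio>

    // Same-padding output extent: ceil(in / stride).
    int SameOut(int in, int stride) { return (in + stride - 1) / stride; }

    int main() {
      std::printf("%d\n", SameOut(15, 2));   // 15-wide major dim, stride 2 -> 8
      std::printf("%d\n", SameOut(256, 1));  // lane dim 256, stride 1 -> 256
      return 0;
    }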
diff --git a/tensorflow/compiler/xla/tests/slice_test.cc b/tensorflow/compiler/xla/tests/slice_test.cc
index 574c494c6d..69fbe98bd6 100644
--- a/tensorflow/compiler/xla/tests/slice_test.cc
+++ b/tensorflow/compiler/xla/tests/slice_test.cc
@@ -19,8 +19,8 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/xla/array2d.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/reference_util.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
@@ -41,7 +41,7 @@ TEST_F(SliceTest, Slice3x3x3_To_3x3x1_F32) {
Array3D<float> values(3, 3, 3);
values.FillIota(0);
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto original = builder.ConstantR3FromArray3D<float>(values);
builder.Slice(original, {0, 0, 0}, {3, 3, 1}, {1, 1, 1});
@@ -54,7 +54,7 @@ TEST_F(SliceTest, Slice3x3x3_To_3x1x3_F32) {
Array3D<float> values(3, 3, 3);
values.FillIota(0);
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto original = builder.ConstantR3FromArray3D<float>(values);
builder.Slice(original, {0, 0, 0}, {3, 1, 3}, {1, 1, 1});
@@ -67,7 +67,7 @@ TEST_F(SliceTest, Slice3x3x3_To_1x3x3_F32) {
Array3D<float> values(3, 3, 3);
values.FillIota(0);
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto original = builder.ConstantR3FromArray3D<float>(values);
builder.Slice(original, {0, 0, 0}, {1, 3, 3}, {1, 1, 1});
@@ -77,7 +77,7 @@ TEST_F(SliceTest, Slice3x3x3_To_1x3x3_F32) {
}
XLA_TEST_F(SliceTest, Slice0x0to0x0F32) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto original = builder.ConstantR2FromArray2D<float>(Array2D<float>(0, 0));
builder.Slice(original, {0, 0}, {0, 0}, {1, 1});
@@ -85,7 +85,7 @@ XLA_TEST_F(SliceTest, Slice0x0to0x0F32) {
}
XLA_TEST_F(SliceTest, Slice0x20to0x5F32) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto original = builder.ConstantR2FromArray2D<float>(Array2D<float>(0, 20));
builder.Slice(original, {0, 15}, {0, 20}, {1, 1});
@@ -93,7 +93,7 @@ XLA_TEST_F(SliceTest, Slice0x20to0x5F32) {
}
XLA_TEST_F(SliceTest, Slice3x0to2x0F32) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto original = builder.ConstantR2FromArray2D<float>(Array2D<float>(3, 0));
builder.Slice(original, {1, 0}, {3, 0}, {1, 1});
@@ -108,7 +108,7 @@ XLA_TEST_F(SliceTest, SliceQuadrantOf256x256) {
}
}
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto original = builder.ConstantR2FromArray2D<float>(values);
builder.Slice(original, {128, 128}, {256, 256}, {1, 1});
@@ -126,7 +126,7 @@ TEST_F(SliceTest, Slice_1x4096_To_1x1024) {
Array2D<float> values(1, 4096);
std::iota(values.data(), values.data() + 4096, 0.0);
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto original = builder.ConstantR2FromArray2D<float>(values);
builder.Slice(original, {0, 3072}, {1, 4096}, {1, 1});
@@ -147,7 +147,7 @@ TEST_F(SliceTest, Slice_16x4_To_16x2) {
}
}
}
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto original = builder.ConstantR2FromArray2D<float>(values);
builder.Slice(original, {0, 0}, {16, 2}, {1, 1});
ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.000001));
@@ -159,7 +159,7 @@ TEST_F(SliceTest, SliceR4ThreeDimsMiddleMinor) {
values.FillRandom(3.14f);
auto expected = ReferenceUtil::Slice4D(
values, {{1, 0, 8, 0}}, {{2, 2, 16, 128}}, /*strides=*/{{1, 1, 1, 1}});
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto original = builder.ConstantR4FromArray4D(values);
builder.Slice(original, {1, 0, 8, 0}, {2, 2, 16, 128}, {1, 1, 1, 1});
ComputeAndCompareR4(&builder, *expected, {}, ErrorSpec(0.000001));
@@ -172,7 +172,7 @@ XLA_TEST_F(SliceTest, StridedSliceR4WithOutputLayout) {
/*strides=*/{{1, 1, 2, 1}});
auto expected_literal = Literal::CreateR4FromArray4DWithLayout(
*expected, LayoutUtil::MakeLayout({0, 1, 2, 3}));
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto original = builder.ConstantR4FromArray4D(values);
builder.Slice(original, {0, 0, 0, 0}, {2, 4, 6, 8}, {1, 1, 2, 1});
ComputeAndCompareLiteral(&builder, *expected_literal, {}, ErrorSpec(0.000001),
@@ -198,7 +198,7 @@ class SliceR1Test : public ClientLibraryTestBase,
tensorflow::gtl::InlinedVector<NativeT, 1> input(spec.input_dim0);
std::iota(input.begin(), input.end(), NativeT());
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto original = builder.ConstantR1<NativeT>(input);
builder.Slice(original, {spec.slice_start}, {spec.slice_limit},
{spec.slice_stride});
@@ -363,7 +363,7 @@ XLA_TEST_P(SliceR2Test, DoIt) {
Array2D<int32> input(spec.input_dim0, spec.input_dim1);
input.FillUnique();
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto a = builder.ConstantR2FromArray2DWithLayout<int32>(
input, LayoutUtil::MakeLayout(spec.layout));
builder.Slice(a, spec.slice_starts, spec.slice_limits, spec.slice_strides);
@@ -453,7 +453,7 @@ class SliceR4Test : public ClientLibraryTestBase,
values.FillRandom(3.14f);
auto expected = ReferenceUtil::Slice4D(
values, spec.slice_starts, spec.slice_limits, spec.slice_strides);
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto literal = Literal::CreateR4FromArray4DWithLayout(
values, LayoutUtil::MakeLayout(spec.input_layout));
auto parameter = builder.Parameter(0, literal->shape(), "p0");
diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc
index 0bc7df2a65..821432ef7d 100644
--- a/tensorflow/compiler/xla/tests/test_utils.cc
+++ b/tensorflow/compiler/xla/tests/test_utils.cc
@@ -23,14 +23,14 @@ namespace xla {
namespace {
-template <typename FloatT>
-void PopulateWithRandomFloatingPointData(Literal* literal,
- std::minstd_rand0* engine) {
+template <typename FloatT, typename GeneratorT>
+void PopulateWithRandomFloatingPointDataImpl(Literal* literal,
+ std::minstd_rand0* engine) {
CHECK_EQ(literal->shape().element_type(),
primitive_util::NativeToPrimitiveType<FloatT>());
// Create uniform numbers between 1 and 1.125 to avoid creating denormal
// numbers.
- std::uniform_real_distribution<FloatT> generator(1.0f, 1.125f);
+ std::uniform_real_distribution<GeneratorT> generator(1.0f, 1.125f);
const bool should_index_bias = ShapeUtil::ElementsIn(literal->shape()) > 1000;
TF_CHECK_OK(literal->Populate<FloatT>(
[&](tensorflow::gtl::ArraySlice<int64> indices) {
@@ -52,10 +52,22 @@ void PopulateWithRandomFloatingPointData(Literal* literal,
FloatT index_bias =
static_cast<FloatT>(index_product % 113 - negative_bias) /
static_cast<FloatT>(256.0f);
- return (generator(*engine) - 1.0625) + index_bias;
+ return static_cast<FloatT>(generator(*engine) - 1.0625f) + index_bias;
}));
}
+template <typename FloatT>
+void PopulateWithRandomFloatingPointData(Literal* literal,
+ std::minstd_rand0* engine) {
+ PopulateWithRandomFloatingPointDataImpl<FloatT, FloatT>(literal, engine);
+}
+
+template <>
+void PopulateWithRandomFloatingPointData<half>(Literal* literal,
+ std::minstd_rand0* engine) {
+ PopulateWithRandomFloatingPointDataImpl<half, float>(literal, engine);
+}
+
// The standard library does not have a case for bfloat16, unsurprisingly, so we
// handle that one specially.
template <>
@@ -100,6 +112,9 @@ StatusOr<std::unique_ptr<Literal>> MakeFakeLiteralInternal(
case BF16:
PopulateWithRandomFloatingPointData<bfloat16>(literal.get(), engine);
break;
+ case F16:
+ PopulateWithRandomFloatingPointData<half>(literal.get(), engine);
+ break;
case F32:
PopulateWithRandomFloatingPointData<float>(literal.get(), engine);
break;
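Editor's note: the GeneratorT indirection exists because std::uniform_real_distribution is only specified for float, double, and long double; instantiating it with a half-precision type is undefined. The fix generates in a supported type and narrows afterwards, which the new F16 case relies on. A self-contained sketch of the same split, with a stand-in 16-bit type (hypothetical; XLA's actual half is an Eigen type):

    #include <cstdio>
    #include <random>

    struct Half { float v; };  // stand-in for a 16-bit float

    // Sample in GeneratorT (a type the standard distribution supports), then
    // narrow to FloatT -- the same split as PopulateWithRandomFloatingPointDataImpl.
    template <typename FloatT, typename GeneratorT>
    FloatT UniformSample(std::minstd_rand0* engine) {
      std::uniform_real_distribution<GeneratorT> dist(1.0f, 1.125f);
      return FloatT{static_cast<float>(dist(*engine))};
    }

    int main() {
      std::minstd_rand0 engine(42);
      Half h = UniformSample<Half, float>(&engine);
      std::printf("%f\n", h.v);
      return 0;
    }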
diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc
index 33d457c70b..89ce2ce797 100644
--- a/tensorflow/compiler/xla/tests/while_test.cc
+++ b/tensorflow/compiler/xla/tests/while_test.cc
@@ -18,10 +18,10 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/xla/client/client_library.h"
-#include "tensorflow/compiler/xla/client/computation.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -54,29 +54,28 @@ TEST_F(WhileTest, WhileWithScalarS32Result) {
auto result_shape = ShapeUtil::MakeShape(S32, {});
// Create a computation for the condition: repeat for 5 iterations.
- Computation condition;
+ XlaComputation condition;
{
- ComputationBuilder builder(client_, "condition");
+ XlaBuilder builder("condition");
auto prev = builder.Parameter(0, result_shape, "prev");
builder.Gt(builder.ConstantR0<int32>(5), prev);
condition = builder.Build().ConsumeValueOrDie();
}
// Create a computation for the body: add 1 to the result variable.
- Computation body;
+ XlaComputation body;
{
- ComputationBuilder builder(client_, "body");
+ XlaBuilder builder("body");
auto prev = builder.Parameter(0, result_shape, "prev");
auto input = builder.ConstantR0<int32>(1);
- auto result = builder.Add(input, prev);
+ builder.Add(input, prev);
body = builder.Build().ConsumeValueOrDie();
}
// Create a While node with computations for the condition and the body.
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto init = builder.ConstantR0<int32>(0);
- auto result = builder.While(condition, body, init);
- auto shape = builder.GetShape(result).ConsumeValueOrDie();
+ builder.While(condition, body, init);
ComputeAndCompareR0<int32>(&builder, 5, {});
}
@@ -91,29 +90,28 @@ TEST_F(WhileTest, WhileWithScalarS64Result) {
auto result_shape = ShapeUtil::MakeShape(S64, {});
// Create a computation for the condition: repeat for 5 iterations.
- Computation condition;
+ XlaComputation condition;
{
- ComputationBuilder builder(client_, "condition");
+ XlaBuilder builder("condition");
auto prev = builder.Parameter(0, result_shape, "prev");
builder.Gt(builder.ConstantR0<int64>(5), prev);
condition = builder.Build().ConsumeValueOrDie();
}
// Create a computation for the body: add 1 to the result variable.
- Computation body;
+ XlaComputation body;
{
- ComputationBuilder builder(client_, "body");
+ XlaBuilder builder("body");
auto prev = builder.Parameter(0, result_shape, "prev");
auto input = builder.ConstantR0<int64>(1);
- auto result = builder.Add(input, prev);
+ builder.Add(input, prev);
body = builder.Build().ConsumeValueOrDie();
}
// Create a While node with computations for the condition and the body.
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto init = builder.ConstantR0<int64>(0);
- auto result = builder.While(condition, body, init);
- auto shape = builder.GetShape(result).ConsumeValueOrDie();
+ builder.While(condition, body, init);
ComputeAndCompareR0<int64>(&builder, 5, {});
}
@@ -123,31 +121,30 @@ TEST_F(WhileTest, WhileWithScalarResultNonConstInit) {
auto orig_shape = ShapeUtil::MakeShape(S32, {2});
// Create a computation for the condition: repeat for 5 iterations.
- Computation condition;
+ XlaComputation condition;
{
- ComputationBuilder builder(client_, "condition");
+ XlaBuilder builder("condition");
auto prev = builder.Parameter(0, result_shape, "prev");
builder.Gt(builder.ConstantR0<int32>(5), prev);
condition = builder.Build().ConsumeValueOrDie();
}
// Create a computation for the body: add 1 to the result variable.
- Computation body;
+ XlaComputation body;
{
- ComputationBuilder builder(client_, "body");
+ XlaBuilder builder("body");
auto prev = builder.Parameter(0, result_shape, "prev");
auto input = builder.ConstantR0<int32>(1);
- auto result = builder.Add(input, prev);
+ builder.Add(input, prev);
body = builder.Build().ConsumeValueOrDie();
}
// Create a While node with computations for the condition and the body.
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto init = builder.Reduce(builder.ConstantR1<int32>(2, 1),
builder.ConstantR0<int32>(0),
CreateScalarAddComputation(S32, &builder), {0});
- auto result = builder.While(condition, body, init);
- auto shape = builder.GetShape(result).ConsumeValueOrDie();
+ builder.While(condition, body, init);
ComputeAndCompareR0<int32>(&builder, 5, {});
}
@@ -156,28 +153,28 @@ TEST_F(WhileTest, WhileWithPredicateResult) {
auto result_shape = ShapeUtil::MakeShape(PRED, {});
// Create a computation for the condition: run until condition is true.
- Computation condition;
+ XlaComputation condition;
{
- ComputationBuilder builder(client_, "condition");
+ XlaBuilder builder("condition");
auto prev = builder.Parameter(0, result_shape, "prev");
builder.Ne(builder.ConstantR0<bool>(true), prev);
condition = builder.Build().ConsumeValueOrDie();
}
// Create a computation for the body: OR the condition with true.
- Computation body;
+ XlaComputation body;
{
- ComputationBuilder builder(client_, "body");
+ XlaBuilder builder("body");
auto prev = builder.Parameter(0, result_shape, "prev");
- auto result = builder.Or(prev, builder.ConstantR0<bool>(true));
+ builder.Or(prev, builder.ConstantR0<bool>(true));
body = builder.Build().ConsumeValueOrDie();
}
// Create a While node with computations for the condition and the body.
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto init = builder.Ne(builder.ConstantR0<bool>(false),
builder.ConstantR0<bool>(true));
- auto result = builder.While(condition, body, init);
+ builder.While(condition, body, init);
ComputeAndCompareR0<bool>(&builder, true, {});
}
@@ -194,9 +191,9 @@ TEST_F(WhileTest, DISABLED_WhileWithEmptyVectorResult) {
Shape result_shape = ShapeUtil::MakeShape(F32, {0});
// Create a computation for the reduction.
- Computation add;
+ XlaComputation add;
{
- ComputationBuilder builder(client_, "add");
+ XlaBuilder builder("add");
auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y");
builder.Add(x, y);
@@ -205,33 +202,34 @@ TEST_F(WhileTest, DISABLED_WhileWithEmptyVectorResult) {
// Create a computation for the condition.
// Repeat while the sum of the result vector is less than 15.5f.
- Computation condition;
+ XlaComputation condition;
{
- ComputationBuilder builder(client_, "condition");
+ XlaBuilder builder("condition");
auto prev = builder.Parameter(0, result_shape, "prev");
auto sum = builder.Reduce(prev, builder.ConstantR0<float>(0.0f), add,
/*dimensions_to_reduce=*/{0});
- auto test = builder.Gt(builder.ConstantR0<float>(15.5f), sum);
+ builder.Gt(builder.ConstantR0<float>(15.5f), sum);
condition = builder.Build().ConsumeValueOrDie();
}
// Create a computation for the body.
// Add a constant vector of 1.f to the result vector.
- Computation body;
+ XlaComputation body;
{
- ComputationBuilder builder(client_, "body");
+ XlaBuilder builder("body");
auto prev = builder.Parameter(0, result_shape, "prev");
auto input = builder.ConstantR1<float>({});
- auto result = builder.Add(input, prev);
+ builder.Add(input, prev);
body = builder.Build().ConsumeValueOrDie();
}
// Create a While node with computations for the condition and the body.
- ComputationBuilder builder(client_, "while");
+ XlaBuilder builder("while");
auto init = builder.ConstantR1<float>({});
auto result = builder.While(condition, body, init);
- VLOG(2) << "while = " << ShapeUtil::HumanString(
- *builder.GetShape(result).ConsumeValueOrDie());
+ VLOG(2) << "while = "
+ << ShapeUtil::HumanString(
+ builder.GetShape(result).ConsumeValueOrDie());
ComputeAndCompareR1<float>(&builder, {}, {}, ErrorSpec(0.0001));
}
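Editor's note: one more signature change threads through the VLOG edits in this file. The dropped leading '*' implies GetShape now returns the shape by value inside the StatusOr instead of behind a unique_ptr. Sketched below (my inference from the diff, not a verified declaration):

    //   ComputationBuilder::GetShape(op) -> StatusOr<std::unique_ptr<Shape>>
    //     ShapeUtil::HumanString(*builder.GetShape(result).ConsumeValueOrDie());
    //   XlaBuilder::GetShape(op)         -> StatusOr<Shape>
    //     ShapeUtil::HumanString(builder.GetShape(result).ConsumeValueOrDie());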
@@ -247,9 +245,9 @@ TEST_F(WhileTest, WhileWithVectorResult) {
Shape result_shape = ShapeUtil::MakeShape(F32, {8});
// Create a computation for the reduction.
- Computation add;
+ XlaComputation add;
{
- ComputationBuilder builder(client_, "add");
+ XlaBuilder builder("add");
auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y");
builder.Add(x, y);
@@ -258,33 +256,34 @@ TEST_F(WhileTest, WhileWithVectorResult) {
// Create a computation for the condition.
// Repeat while the sum of the result vector is less than 15.5f.
- Computation condition;
+ XlaComputation condition;
{
- ComputationBuilder builder(client_, "condition");
+ XlaBuilder builder("condition");
auto prev = builder.Parameter(0, result_shape, "prev");
auto sum = builder.Reduce(prev, builder.ConstantR0<float>(0.0f), add,
/*dimensions_to_reduce=*/{0});
- auto test = builder.Gt(builder.ConstantR0<float>(15.5f), sum);
+ builder.Gt(builder.ConstantR0<float>(15.5f), sum);
condition = builder.Build().ConsumeValueOrDie();
}
// Create a computation for the body.
// Add a constant vector of 1.f to the result vector.
- Computation body;
+ XlaComputation body;
{
- ComputationBuilder builder(client_, "body");
+ XlaBuilder builder("body");
auto prev = builder.Parameter(0, result_shape, "prev");
auto input = builder.ConstantR1<float>(8, 0.125f);
- auto result = builder.Add(input, prev);
+ builder.Add(input, prev);
body = builder.Build().ConsumeValueOrDie();
}
// Create a While node with computations for the condition and the body.
- ComputationBuilder builder(client_, "while");
+ XlaBuilder builder("while");
auto init = builder.ConstantR1<float>(8, 0.f);
auto result = builder.While(condition, body, init);
- VLOG(2) << "while = " << ShapeUtil::HumanString(
- *builder.GetShape(result).ConsumeValueOrDie());
+ VLOG(2) << "while = "
+ << ShapeUtil::HumanString(
+ builder.GetShape(result).ConsumeValueOrDie());
// Individual elements will increase by 1/8 each time through the loop, so
// the sum will increase by 1.0. It will first be >15.5 when the elements
@@ -306,9 +305,9 @@ TEST_F(WhileTest, WhileWithVectorResultIntoTuple) {
Shape result_shape = ShapeUtil::MakeShape(F32, {8});
// Create a computation for the reduction.
- Computation add;
+ XlaComputation add;
{
- ComputationBuilder builder(client_, "add");
+ XlaBuilder builder("add");
auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y");
builder.Add(x, y);
@@ -317,34 +316,34 @@ TEST_F(WhileTest, WhileWithVectorResultIntoTuple) {
// Create a computation for the condition.
// Repeat while the sum of the result vector is less than 15.5f.
- Computation condition;
+ XlaComputation condition;
{
- ComputationBuilder builder(client_, "condition");
+ XlaBuilder builder("condition");
auto prev = builder.Parameter(0, result_shape, "prev");
auto sum = builder.Reduce(prev, builder.ConstantR0<float>(0.0f), add,
/*dimensions_to_reduce=*/{0});
- auto test = builder.Gt(builder.ConstantR0<float>(15.5f), sum);
+ builder.Gt(builder.ConstantR0<float>(15.5f), sum);
condition = builder.Build().ConsumeValueOrDie();
}
// Create a computation for the body.
// Add a constant vector of 1.f to the result vector.
- Computation body;
+ XlaComputation body;
{
- ComputationBuilder builder(client_, "body");
+ XlaBuilder builder("body");
auto prev = builder.Parameter(0, result_shape, "prev");
auto input = builder.ConstantR1<float>(8, 0.125f);
- auto result = builder.Add(input, prev);
+ builder.Add(input, prev);
body = builder.Build().ConsumeValueOrDie();
}
// Create a While node with computations for the condition and the body.
- ComputationBuilder builder(client_, "while");
+ XlaBuilder builder("while");
auto init = builder.ConstantR1<float>(8, 0.f);
auto result = builder.While(condition, body, init);
VLOG(2) << "while = "
<< ShapeUtil::HumanString(
- *builder.GetShape(result).ConsumeValueOrDie());
+ builder.GetShape(result).ConsumeValueOrDie());
builder.Tuple({result});
// Individual elements will increase by 1/8 each time through the loop, so
@@ -366,9 +365,9 @@ TEST_F(WhileTest, WhileWithPermutationAndTupleResult) {
// Create a computation for the condition.
// Repeat for N iterations.
const int N = 2;
- Computation condition;
+ XlaComputation condition;
{
- ComputationBuilder builder(client_, "condition");
+ XlaBuilder builder("condition");
auto prev = builder.Parameter(0, result_shape, "prev");
auto iteration = builder.GetTupleElement(prev, 0);
builder.Gt(builder.ConstantR0<int32>(N), iteration);
@@ -377,28 +376,28 @@ TEST_F(WhileTest, WhileWithPermutationAndTupleResult) {
// Create a computation for the body.
// Add 1 to the iteration variable and permute the weights.
- Computation body;
+ XlaComputation body;
{
- ComputationBuilder builder(client_, "body");
+ XlaBuilder builder("body");
auto prev = builder.Parameter(0, result_shape, "prev");
auto iteration = builder.GetTupleElement(prev, 0);
auto w1 = builder.GetTupleElement(prev, 1);
auto w2 = builder.GetTupleElement(prev, 2);
auto w3 = builder.GetTupleElement(prev, 3);
- auto result = builder.Tuple(
+ builder.Tuple(
{builder.Add(iteration, builder.ConstantR0<int32>(1)), w3, w1, w2});
body = builder.Build().ConsumeValueOrDie();
}
// Create a While node with computations for the condition and the body.
- ComputationBuilder builder(client_, "while");
+ XlaBuilder builder("while");
auto init = builder.Tuple(
{builder.ConstantR0<int32>(0), builder.ConstantR1<float>(3, 1.f),
builder.ConstantR1<float>(3, 2.f), builder.ConstantR1<float>(3, 3.f)});
auto result = builder.While(condition, body, init);
VLOG(2) << "result = "
<< ShapeUtil::HumanString(
- *builder.GetShape(result).ConsumeValueOrDie());
+ builder.GetShape(result).ConsumeValueOrDie());
auto expected_counter = Literal::CreateR0<int32>(N);
auto expected_w1 = Literal::CreateR1<float>({1.0f, 1.0f, 1.0f});
@@ -419,9 +418,9 @@ TEST_F(WhileTest, WhileWithPermutationAndVectorResult) {
// Create a computation for the condition.
// Repeat for N iterations.
const int N = 2;
- Computation condition;
+ XlaComputation condition;
{
- ComputationBuilder builder(client_, "condition");
+ XlaBuilder builder("condition");
auto prev = builder.Parameter(0, result_shape, "prev");
auto iteration = builder.GetTupleElement(prev, 0);
builder.Gt(builder.ConstantR0<int32>(N), iteration);
@@ -430,21 +429,21 @@ TEST_F(WhileTest, WhileWithPermutationAndVectorResult) {
// Create a computation for the body.
// Add 1 to the iteration variable and permute the weights.
- Computation body;
+ XlaComputation body;
{
- ComputationBuilder builder(client_, "body");
+ XlaBuilder builder("body");
auto prev = builder.Parameter(0, result_shape, "prev");
auto iteration = builder.GetTupleElement(prev, 0);
auto w1 = builder.GetTupleElement(prev, 1);
auto w2 = builder.GetTupleElement(prev, 2);
auto w3 = builder.GetTupleElement(prev, 3);
- auto result = builder.Tuple(
+ builder.Tuple(
{builder.Add(iteration, builder.ConstantR0<int32>(1)), w3, w1, w2});
body = builder.Build().ConsumeValueOrDie();
}
// Create a While node with computations for the condition and the body.
- ComputationBuilder builder(client_, "while");
+ XlaBuilder builder("while");
auto init = builder.Tuple(
{builder.ConstantR0<int32>(0), builder.ConstantR1<float>(3, 1.f),
builder.ConstantR1<float>(3, 2.f), builder.ConstantR1<float>(3, 3.f)});
@@ -455,7 +454,7 @@ TEST_F(WhileTest, WhileWithPermutationAndVectorResult) {
auto result = builder.Add(add12, builder.GetTupleElement(xla_while, 3));
VLOG(2) << "result = "
<< ShapeUtil::HumanString(
- *builder.GetShape(result).ConsumeValueOrDie());
+ builder.GetShape(result).ConsumeValueOrDie());
std::vector<float> expected = {6.f, 6.f, 6.f};
ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
}
@@ -474,9 +473,9 @@ TEST_F(WhileTest, WhileWithTupleResult) {
// Create a computation for the condition.
// Repeat for 5 iterations.
- Computation condition;
+ XlaComputation condition;
{
- ComputationBuilder builder(client_, "condition");
+ XlaBuilder builder("condition");
auto prev = builder.Parameter(0, result_shape, "prev");
auto iteration = builder.GetTupleElement(prev, 0);
builder.Gt(builder.ConstantR0<int32>(5), iteration);
@@ -486,26 +485,27 @@ TEST_F(WhileTest, WhileWithTupleResult) {
// Create a computation for the body.
// Add 1 to the iteration variable and add a constant vector of 1.0f to
// the weight variable, both of which are tuple elements.
- Computation body;
+ XlaComputation body;
{
- ComputationBuilder builder(client_, "body");
+ XlaBuilder builder("body");
auto prev = builder.Parameter(0, result_shape, "prev");
auto iteration = builder.GetTupleElement(prev, 0);
auto weights = builder.GetTupleElement(prev, 1);
auto input = builder.ConstantR1<float>(10, 1.f);
auto new_weights = builder.Add(weights, input);
- auto result = builder.Tuple(
+ builder.Tuple(
{builder.Add(iteration, builder.ConstantR0<int32>(1)), new_weights});
body = builder.Build().ConsumeValueOrDie();
}
// Create a While node with computations for the condition and the body.
- ComputationBuilder builder(client_, "while");
+ XlaBuilder builder("while");
auto init = builder.Tuple(
{builder.ConstantR0<int32>(0), builder.ConstantR1<float>(10, 0.f)});
auto result = builder.While(condition, body, init);
- VLOG(2) << "while = " << ShapeUtil::HumanString(
- *builder.GetShape(result).ConsumeValueOrDie());
+ VLOG(2) << "while = "
+ << ShapeUtil::HumanString(
+ builder.GetShape(result).ConsumeValueOrDie());
auto expected_counter = Literal::CreateR0<int32>(5);
auto expected_data = Literal::CreateR1<float>(
@@ -523,9 +523,9 @@ TEST_F(WhileTest, WhileWithPredicateTupleResult) {
// Create a computation for the condition.
// Repeat for 5 iterations.
- Computation condition;
+ XlaComputation condition;
{
- ComputationBuilder builder(client_, "condition");
+ XlaBuilder builder("condition");
auto prev = builder.Parameter(0, result_shape, "prev");
auto iteration = builder.GetTupleElement(prev, 0);
builder.Gt(builder.ConstantR0<int32>(5), iteration);
@@ -534,27 +534,27 @@ TEST_F(WhileTest, WhileWithPredicateTupleResult) {
// Create a computation for the body.
// Add 1 to the iteration variable and OR the predicate with true.
- Computation body;
+ XlaComputation body;
{
- ComputationBuilder builder(client_, "body");
+ XlaBuilder builder("body");
auto prev = builder.Parameter(0, result_shape, "prev");
auto iteration = builder.GetTupleElement(prev, 0);
auto pred = builder.GetTupleElement(prev, 1);
auto new_pred = builder.Or(pred, builder.ConstantR0<bool>(true));
- auto result = builder.Tuple(
+ builder.Tuple(
{builder.Add(iteration, builder.ConstantR0<int32>(1)), new_pred});
body = builder.Build().ConsumeValueOrDie();
}
// Create a While node with computations for the condition and the body.
- ComputationBuilder builder(client_, "while");
+ XlaBuilder builder("while");
auto init = builder.Tuple({builder.ConstantR0<int32>(0),
builder.Ne(builder.ConstantR0<bool>(false),
builder.ConstantR0<bool>(true))});
auto result = builder.While(condition, body, init);
VLOG(2) << "while = "
<< ShapeUtil::HumanString(
- *builder.GetShape(result).ConsumeValueOrDie());
+ builder.GetShape(result).ConsumeValueOrDie());
auto expected_counter = Literal::CreateR0<int32>(5);
auto expected_predicate = Literal::CreateR0<bool>(true);
@@ -570,9 +570,9 @@ TEST_F(WhileTest, WhileWithTupleConstantScalarResult) {
// Create a computation for the condition.
// Repeat for 5 iterations.
- Computation condition;
+ XlaComputation condition;
{
- ComputationBuilder builder(client_, "condition");
+ XlaBuilder builder("condition");
auto prev = builder.Parameter(0, result_shape, "prev");
auto iteration = builder.GetTupleElement(prev, 0);
builder.Gt(builder.ConstantR0<int32>(5), iteration);
@@ -582,25 +582,24 @@ TEST_F(WhileTest, WhileWithTupleConstantScalarResult) {
// Create a computation for the body.
// Add 1 to the iteration variable and set the other tuple element to a
// constant.
- Computation body;
+ XlaComputation body;
{
- ComputationBuilder builder(client_, "body");
+ XlaBuilder builder("body");
auto prev = builder.Parameter(0, result_shape, "prev");
auto iteration = builder.GetTupleElement(prev, 0);
- auto result =
- builder.Tuple({builder.Add(iteration, builder.ConstantR0<int32>(1)),
- builder.ConstantR0<int32>(7)});
+ builder.Tuple({builder.Add(iteration, builder.ConstantR0<int32>(1)),
+ builder.ConstantR0<int32>(7)});
body = builder.Build().ConsumeValueOrDie();
}
// Create a While node with computations for the condition and the body.
- ComputationBuilder builder(client_, "while");
+ XlaBuilder builder("while");
auto init = builder.Tuple(
{builder.ConstantR0<int32>(0), builder.ConstantR0<int32>(7)});
auto result = builder.While(condition, body, init);
VLOG(2) << "while = "
<< ShapeUtil::HumanString(
- *builder.GetShape(result).ConsumeValueOrDie());
+ builder.GetShape(result).ConsumeValueOrDie());
auto expected_counter = Literal::CreateR0<int32>(5);
auto expected_data = Literal::CreateR0<int32>(7);
@@ -631,20 +630,20 @@ TEST_F(WhileTest, TwoWhileWithTupleResult) {
// Create a computation for the condition.
// Repeat for 5 iterations.
- Computation condition;
+ XlaComputation condition;
const int c1 = 5;
{
- ComputationBuilder builder(client_, "condition");
+ XlaBuilder builder("condition");
auto prev = builder.Parameter(0, result_shape, "prev");
auto iteration = builder.GetTupleElement(prev, 0);
builder.Lt(iteration, builder.ConstantR0<int32>(c1));
TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build());
}
- Computation condition2;
+ XlaComputation condition2;
const int c2 = 7;
{
- ComputationBuilder builder(client_, "condition2");
+ XlaBuilder builder("condition2");
auto prev = builder.Parameter(0, result_shape, "prev");
auto iteration = builder.GetTupleElement(prev, 0);
builder.Lt(iteration, builder.ConstantR0<int32>(c2));
@@ -654,34 +653,34 @@ TEST_F(WhileTest, TwoWhileWithTupleResult) {
// Create a computation for the body.
// Add 1 to the iteration variable and add a constant vector of 1.0f to
// the weight variable, both of which are tuple elements.
- Computation body;
+ XlaComputation body;
{
- ComputationBuilder builder(client_, "body");
+ XlaBuilder builder("body");
auto prev = builder.Parameter(0, result_shape, "prev");
auto iteration = builder.GetTupleElement(prev, 0);
auto weights = builder.GetTupleElement(prev, 1);
auto input = builder.ConstantR1<float>(10, 1.f);
auto new_weights = builder.Add(weights, input);
- auto result = builder.Tuple(
+ builder.Tuple(
{builder.Add(iteration, builder.ConstantR0<int32>(1)), new_weights});
TF_ASSERT_OK_AND_ASSIGN(body, builder.Build());
}
- Computation body2;
+ XlaComputation body2;
{
- ComputationBuilder builder(client_, "body");
+ XlaBuilder builder("body");
auto prev = builder.Parameter(0, result_shape, "prev");
auto iteration = builder.GetTupleElement(prev, 0);
auto weights = builder.GetTupleElement(prev, 1);
auto input = builder.ConstantR1<float>(10, 1.f);
auto new_weights = builder.Add(weights, input);
- auto result = builder.Tuple(
+ builder.Tuple(
{builder.Add(iteration, builder.ConstantR0<int32>(1)), new_weights});
TF_ASSERT_OK_AND_ASSIGN(body2, builder.Build());
}
// Create a While node with computations for the condition and the body.
- ComputationBuilder builder(client_, "while");
+ XlaBuilder builder("while");
auto init = builder.Tuple(
{builder.ConstantR0<int32>(0), builder.ConstantR1<float>(10, 0.f)});
auto while1 = builder.While(condition, body, init);
@@ -692,11 +691,11 @@ TEST_F(WhileTest, TwoWhileWithTupleResult) {
auto while_result2 = builder.GetTupleElement(while2, 1);
VLOG(2) << "while_result2 = "
<< ShapeUtil::HumanString(
- *builder.GetShape(while_result2).ConsumeValueOrDie());
+ builder.GetShape(while_result2).ConsumeValueOrDie());
auto result = builder.Add(while_result1, while_result2);
VLOG(2) << "result = "
<< ShapeUtil::HumanString(
- *builder.GetShape(result).ConsumeValueOrDie());
+ builder.GetShape(result).ConsumeValueOrDie());
const float sum = c1 + c2;
std::vector<float> expected(10, sum);
ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
@@ -710,20 +709,20 @@ TEST_F(WhileTest, TwoWhileLoopsAndSharedBody) {
// Create a computation for the condition.
// Repeat for 5 iterations.
- Computation condition;
+ XlaComputation condition;
const int c1 = 5;
{
- ComputationBuilder builder(client_, "condition");
+ XlaBuilder builder("condition");
auto prev = builder.Parameter(0, result_shape, "prev");
auto iteration = builder.GetTupleElement(prev, 0);
builder.Lt(iteration, builder.ConstantR0<int32>(c1));
TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build());
}
- Computation condition2;
+ XlaComputation condition2;
const int c2 = 7;
{
- ComputationBuilder builder(client_, "condition2");
+ XlaBuilder builder("condition2");
auto prev = builder.Parameter(0, result_shape, "prev");
auto iteration = builder.GetTupleElement(prev, 0);
builder.Lt(iteration, builder.ConstantR0<int32>(c2));
@@ -733,21 +732,21 @@ TEST_F(WhileTest, TwoWhileLoopsAndSharedBody) {
// Create a computation for the body.
// Add 1 to the iteration variable and add a constant vector of 1.0f to
// the weight variable, both of which are tuple elements.
- Computation body;
+ XlaComputation body;
{
- ComputationBuilder builder(client_, "body");
+ XlaBuilder builder("body");
auto prev = builder.Parameter(0, result_shape, "prev");
auto iteration = builder.GetTupleElement(prev, 0);
auto weights = builder.GetTupleElement(prev, 1);
auto input = builder.ConstantR1<float>(10, 1.f);
auto new_weights = builder.Add(weights, input);
- auto result = builder.Tuple(
+ builder.Tuple(
{builder.Add(iteration, builder.ConstantR0<int32>(1)), new_weights});
TF_ASSERT_OK_AND_ASSIGN(body, builder.Build());
}
// Create a While node with computations for the condition and the body.
- ComputationBuilder builder(client_, "while");
+ XlaBuilder builder("while");
auto init = builder.Tuple(
{builder.ConstantR0<int32>(0), builder.ConstantR1<float>(10, 0.f)});
auto while1 = builder.While(condition, body, init);
@@ -758,11 +757,11 @@ TEST_F(WhileTest, TwoWhileLoopsAndSharedBody) {
auto while_result2 = builder.GetTupleElement(while2, 1);
VLOG(2) << "while_result2 = "
<< ShapeUtil::HumanString(
- *builder.GetShape(while_result2).ConsumeValueOrDie());
+ builder.GetShape(while_result2).ConsumeValueOrDie());
auto result = builder.Add(while_result1, while_result2);
VLOG(2) << "result = "
<< ShapeUtil::HumanString(
- *builder.GetShape(result).ConsumeValueOrDie());
+ builder.GetShape(result).ConsumeValueOrDie());
const float sum = c1 + c2;
std::vector<float> expected(10, sum);
ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
@@ -777,20 +776,20 @@ TEST_F(WhileTest, DISABLED_ON_GPU(WhileLoopsWithSharedBodyAndInit)) {
// Create a computation for the condition.
// Repeat for 5 iterations.
- Computation condition;
+ XlaComputation condition;
const int c1 = 5;
{
- ComputationBuilder builder(client_, "condition");
+ XlaBuilder builder("condition");
auto prev = builder.Parameter(0, result_shape, "prev");
auto iteration = builder.GetTupleElement(prev, 0);
builder.Lt(iteration, builder.ConstantR0<int32>(c1));
TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build());
}
- Computation condition2;
+ XlaComputation condition2;
const int c2 = 7;
{
- ComputationBuilder builder(client_, "condition2");
+ XlaBuilder builder("condition2");
auto prev = builder.Parameter(0, result_shape, "prev");
auto iteration = builder.GetTupleElement(prev, 0);
builder.Lt(iteration, builder.ConstantR0<int32>(c2));
@@ -800,21 +799,21 @@ TEST_F(WhileTest, DISABLED_ON_GPU(WhileLoopsWithSharedBodyAndInit)) {
// Create a computation for the body.
// Add 1 to the iteration variable and add a constant vector of 1.0f to
// the weight variable, both of which are tuple elements.
- Computation body;
+ XlaComputation body;
{
- ComputationBuilder builder(client_, "body");
+ XlaBuilder builder("body");
auto prev = builder.Parameter(0, result_shape, "prev");
auto iteration = builder.GetTupleElement(prev, 0);
auto weights = builder.GetTupleElement(prev, 1);
auto input = builder.ConstantR1<float>(10, 1.f);
auto new_weights = builder.Add(weights, input);
- auto result = builder.Tuple(
+ builder.Tuple(
{builder.Add(iteration, builder.ConstantR0<int32>(1)), new_weights});
TF_ASSERT_OK_AND_ASSIGN(body, builder.Build());
}
// Create a While node with computations for the condition and the body.
- ComputationBuilder builder(client_, "while");
+ XlaBuilder builder("while");
auto init = builder.Tuple(
{builder.ConstantR0<int32>(0), builder.ConstantR1<float>(10, 0.f)});
auto while1 = builder.While(condition, body, init);
@@ -824,11 +823,11 @@ TEST_F(WhileTest, DISABLED_ON_GPU(WhileLoopsWithSharedBodyAndInit)) {
auto while_result2 = builder.GetTupleElement(while2, 1);
VLOG(2) << "while_result2 = "
<< ShapeUtil::HumanString(
- *builder.GetShape(while_result2).ConsumeValueOrDie());
+ builder.GetShape(while_result2).ConsumeValueOrDie());
auto result = builder.Add(while_result1, while_result2);
VLOG(2) << "result = "
<< ShapeUtil::HumanString(
- *builder.GetShape(result).ConsumeValueOrDie());
+ builder.GetShape(result).ConsumeValueOrDie());
const float sum = c1 + c2;
std::vector<float> expected(10, sum);
ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
@@ -844,9 +843,9 @@ XLA_TEST_F(WhileTest, WhileWithDynamicUpdateSlice) {
// Create a computation for the condition.
// Repeat for 5 iterations.
- Computation condition;
+ XlaComputation condition;
{
- ComputationBuilder builder(client_, "condition");
+ XlaBuilder builder("condition");
auto prev = builder.Parameter(0, result_shape, "prev");
auto iteration = builder.GetTupleElement(prev, 0);
builder.Gt(builder.ConstantR0<int32>(5), iteration);
@@ -856,9 +855,9 @@ XLA_TEST_F(WhileTest, WhileWithDynamicUpdateSlice) {
// Create a computation for the body.
// Add 1 to the iteration variable and add a constant vector of 1.0f to
// the weight variable, both of which are tuple elements.
- Computation body;
+ XlaComputation body;
{
- ComputationBuilder builder(client_, "body");
+ XlaBuilder builder("body");
auto prev = builder.Parameter(0, result_shape, "prev");
// TupleElement 0
auto iteration = builder.GetTupleElement(prev, 0);
@@ -873,18 +872,18 @@ XLA_TEST_F(WhileTest, WhileWithDynamicUpdateSlice) {
// UpdateSlice.
auto out1 = builder.DynamicUpdateSlice(input, update, starts);
- auto result = builder.Tuple({out0, out1});
+ builder.Tuple({out0, out1});
body = builder.Build().ConsumeValueOrDie();
}
// Create a While node with computations for the condition and the body.
- ComputationBuilder builder(client_, "while");
+ XlaBuilder builder("while");
auto init = builder.Tuple(
{builder.ConstantR0<int32>(0), builder.ConstantR1<float>(10, 0.f)});
auto result = builder.While(condition, body, init);
VLOG(2) << "while = "
<< ShapeUtil::HumanString(
- *builder.GetShape(result).ConsumeValueOrDie());
+ builder.GetShape(result).ConsumeValueOrDie());
auto expected_counter = Literal::CreateR0<int32>(5);
auto expected_data = Literal::CreateR1<float>(
@@ -915,18 +914,18 @@ TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithPrngScalarResult)) {
// Create a computation for the condition: repeat for count iterations.
auto build_condition = [this, v6s32](int count) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto prev = builder.Reshape(
builder.Slice(builder.Parameter(0, v6s32, "prev"), {0}, {1}, {1}), {0},
- {});
+ {});
builder.Gt(builder.ConstantR0<int32>(count), prev);
return builder.Build().ConsumeValueOrDie();
};
// Create a computation for the body: add 1 to the result variable.
- Computation body;
+ XlaComputation body;
{
- ComputationBuilder builder(client_, "body");
+ XlaBuilder builder("body");
auto prev = builder.Parameter(0, v6s32, "prev");
auto inc = builder.ConcatInDim(
{builder.ConstantR1<int32>({1}),
@@ -934,16 +933,15 @@ TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithPrngScalarResult)) {
builder.ConstantR0<int32>(100),
ShapeUtil::MakeShape(S32, {5}))},
0);
- auto result = builder.Add(inc, prev);
+ builder.Add(inc, prev);
body = builder.Build().ConsumeValueOrDie();
}
// Create a While node with computations for the condition and the body.
auto while_loop = [this, &body, build_condition](int count) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto init = builder.ConstantR1<int32>({0, 0, 0, 0, 0, 0});
- auto result = builder.While(build_condition(count), body, init);
- auto shape = builder.GetShape(result).ConsumeValueOrDie();
+ builder.While(build_condition(count), body, init);
return builder.Build();
};
@@ -1107,9 +1105,9 @@ XLA_TEST_F(WhileTest, NestedWhileWithScalarResult) {
auto inner_result_shape = ShapeUtil::MakeTupleShape(
{ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(S32, {})});
- Computation inner_condition;
+ XlaComputation inner_condition;
{
- ComputationBuilder builder(client_, "inner_condition");
+ XlaBuilder builder("inner_condition");
auto params = builder.Parameter(0, inner_result_shape, "prev");
auto i = builder.GetTupleElement(params, 0);
builder.Lt(i, builder.ConstantR0<int32>(7));
@@ -1118,9 +1116,9 @@ XLA_TEST_F(WhileTest, NestedWhileWithScalarResult) {
// Creates a computation for the outer loop condition:
// repeat while result < 30.
- Computation outer_condition;
+ XlaComputation outer_condition;
{
- ComputationBuilder builder(client_, "outer_condition");
+ XlaBuilder builder("outer_condition");
auto prev = builder.Parameter(0, outer_result_shape, "prev");
builder.Lt(prev, builder.ConstantR0<int32>(30));
outer_condition = builder.Build().ConsumeValueOrDie();
@@ -1128,34 +1126,33 @@ XLA_TEST_F(WhileTest, NestedWhileWithScalarResult) {
// Creates a computation for the inner loop body: add 1 to `i`, and add 2 to
// `result`.
- Computation inner_body;
+ XlaComputation inner_body;
{
- ComputationBuilder builder(client_, "inner_body");
+ XlaBuilder builder("inner_body");
auto params = builder.Parameter(0, inner_result_shape, "prev");
auto i = builder.GetTupleElement(params, 0);
auto result = builder.GetTupleElement(params, 1);
i = builder.Add(builder.ConstantR0<int32>(1), i);
result = builder.Add(builder.ConstantR0<int32>(2), result);
- auto output = builder.Tuple({i, result});
+ builder.Tuple({i, result});
inner_body = builder.Build().ConsumeValueOrDie();
}
// Creates a computation for the outer loop: run the inner loop with i = 0.
- Computation outer_body;
+ XlaComputation outer_body;
{
- ComputationBuilder builder(client_, "outer_body");
+ XlaBuilder builder("outer_body");
auto prev = builder.Parameter(0, outer_result_shape, "prev");
auto init = builder.Tuple({builder.ConstantR0<int32>(0), prev});
auto result = builder.While(inner_condition, inner_body, init);
- auto output = builder.GetTupleElement(result, 1);
+ builder.GetTupleElement(result, 1);
outer_body = builder.Build().ConsumeValueOrDie();
}
// Create a While node with computations for the condition and the body.
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto init = builder.ConstantR0<int32>(0);
- auto result = builder.While(outer_condition, outer_body, init);
- auto shape = builder.GetShape(result).ConsumeValueOrDie();
+ builder.While(outer_condition, outer_body, init);
ComputeAndCompareR0<int32>(&builder, 42, {});
}
@@ -1170,18 +1167,18 @@ TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithCallInsideCondition)) {
auto result_shape = ShapeUtil::MakeShape(S32, {});
// Create a computation for the condition: repeat for 5 iterations.
- Computation condition_callee;
+ XlaComputation condition_callee;
{
- ComputationBuilder builder(client_, "condition_callee");
+ XlaBuilder builder("condition_callee");
auto prev = builder.Parameter(0, result_shape, "prev");
builder.Tuple({builder.Gt(builder.ConstantR0<int32>(5), prev)});
condition_callee = builder.Build().ConsumeValueOrDie();
}
- Computation condition;
+ XlaComputation condition;
{
- ComputationBuilder builder(client_, "condition");
+ XlaBuilder builder("condition");
auto prev = builder.Parameter(0, result_shape, "prev");
auto result = builder.Call(condition_callee, {prev});
builder.GetTupleElement(result, 0);
@@ -1189,20 +1186,19 @@ TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithCallInsideCondition)) {
}
// Create a computation for the body: add 1 to the result variable.
- Computation body;
+ XlaComputation body;
{
- ComputationBuilder builder(client_, "body");
+ XlaBuilder builder("body");
auto prev = builder.Parameter(0, result_shape, "prev");
auto input = builder.ConstantR0<int32>(1);
- auto result = builder.Add(input, prev);
+ builder.Add(input, prev);
body = builder.Build().ConsumeValueOrDie();
}
// Create a While node with computations for the condition and the body.
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto init = builder.ConstantR0<int32>(0);
- auto result = builder.While(condition, body, init);
- auto shape = builder.GetShape(result).ConsumeValueOrDie();
+ builder.While(condition, body, init);
ComputeAndCompareR0<int32>(&builder, 5, {});
}
@@ -1214,28 +1210,28 @@ TEST_F(WhileTest, WhileWithLoopInvariantOperation) {
{scalar_s32, matrix_shape, matrix_shape, matrix_shape});
// Create a computation for the condition: repeat for 5 iterations.
- Computation condition;
+ XlaComputation condition;
{
- ComputationBuilder builder(client_, "condition");
+ XlaBuilder builder("condition");
auto state = builder.Parameter(0, while_shape, "state");
builder.Gt(builder.ConstantR0<int32>(5), builder.GetTupleElement(state, 0));
TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build());
}
- Computation body;
+ XlaComputation body;
{
- ComputationBuilder builder(client_, "body");
+ XlaBuilder builder("body");
auto state = builder.Parameter(0, while_shape, "state");
auto indvar = builder.GetTupleElement(state, 0);
auto input_0 = builder.GetTupleElement(state, 1);
auto input_1 = builder.GetTupleElement(state, 2);
auto output = builder.Tanh(builder.Dot(input_0, input_1));
auto indvar_next = builder.Add(indvar, builder.ConstantR0<int32>(1));
- auto tuple_result = builder.Tuple({indvar_next, input_0, input_1, output});
+ builder.Tuple({indvar_next, input_0, input_1, output});
TF_ASSERT_OK_AND_ASSIGN(body, builder.Build());
}
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto matrix_input = builder.Parameter(0, matrix_shape, "matrix");
auto init = builder.Tuple(
{builder.ConstantR0<int32>(0), matrix_input, matrix_input, matrix_input});
@@ -1268,9 +1264,9 @@ void BM_WhileLoop(int num_iters) {
// Create while condition computation with 'loop_limit'.
const int32 loop_limit = 100;
- Computation condition;
+ XlaComputation condition;
{
- ComputationBuilder builder(client, "condition");
+ XlaBuilder builder("condition");
auto prev = builder.Parameter(0, loop_state_shape, "prev");
auto iteration = builder.GetTupleElement(prev, 0);
builder.Lt(iteration, builder.ConstantR0<int32>(loop_limit));
@@ -1278,9 +1274,9 @@ void BM_WhileLoop(int num_iters) {
}
// Create while body computation with unit loop increment.
- Computation body;
+ XlaComputation body;
{
- ComputationBuilder builder(client, "body");
+ XlaBuilder builder("body");
auto prev = builder.Parameter(0, loop_state_shape, "prev");
// TupleElement 0
auto iteration = builder.GetTupleElement(prev, 0);
@@ -1294,12 +1290,12 @@ void BM_WhileLoop(int num_iters) {
auto starts = builder.ConstantR1<int32>({0, 0, 0});
// UpdateSlice.
auto out1 = builder.DynamicUpdateSlice(input, update, starts);
- auto result = builder.Tuple({out0, out1});
+ builder.Tuple({out0, out1});
body = builder.Build().ConsumeValueOrDie();
}
// Create a While instruction.
- ComputationBuilder builder(client, "while");
+ XlaBuilder builder("while");
auto zero = builder.ConstantR0<float>(0.0);
auto input = builder.Broadcast(zero, {seq_len, 1024, 1024});
auto init = builder.Tuple({builder.ConstantR0<int32>(0), input});
diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc
index 863081d654..adc8b1d620 100644
--- a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <string>
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"
namespace xla {
@@ -894,7 +895,7 @@ class HloParserTest : public ::testing::Test,
public ::testing::WithParamInterface<TestData> {
protected:
static void ExpectHasSubstr(StringPiece s, StringPiece expected) {
- EXPECT_TRUE(StringPiece(s).contains(expected))
+ EXPECT_TRUE(tensorflow::str_util::StrContains(s, expected))
<< "'" << s << "' does not contain '" << expected << "'";
}
diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD
index 0cebb49afb..bf69144ad8 100644
--- a/tensorflow/contrib/BUILD
+++ b/tensorflow/contrib/BUILD
@@ -8,6 +8,7 @@ package(default_visibility = ["//tensorflow:__subpackages__"])
load("//third_party/mpi:mpi.bzl", "if_mpi")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
load("@local_config_tensorrt//:build_defs.bzl", "if_tensorrt")
+load("//tensorflow:tensorflow.bzl", "if_not_windows")
py_library(
name = "contrib_py",
@@ -40,7 +41,6 @@ py_library(
"//tensorflow/contrib/estimator:estimator_py",
"//tensorflow/contrib/factorization:factorization_py",
"//tensorflow/contrib/feature_column:feature_column_py",
- "//tensorflow/contrib/ffmpeg:ffmpeg_ops_py",
"//tensorflow/contrib/framework:framework_py",
"//tensorflow/contrib/fused_conv:fused_conv_py",
"//tensorflow/contrib/gan",
@@ -63,7 +63,6 @@ py_library(
"//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/contrib/linear_optimizer:sdca_estimator_py",
"//tensorflow/contrib/linear_optimizer:sdca_ops_py",
- "//tensorflow/contrib/lite/python:lite",
"//tensorflow/contrib/lookup:lookup_py",
"//tensorflow/contrib/losses:losses_py",
"//tensorflow/contrib/losses:metric_learning_py",
@@ -117,7 +116,10 @@ py_library(
"//tensorflow/contrib/kafka",
],
"//conditions:default": [],
- }),
+ }) + if_not_windows([
+ "//tensorflow/contrib/ffmpeg:ffmpeg_ops_py",
+ "//tensorflow/contrib/lite/python:lite", # unix dependency, need to fix code
+ ]),
)
cc_library(
diff --git a/tensorflow/contrib/__init__.py b/tensorflow/contrib/__init__.py
index a8e05df708..1c5b00f92e 100644
--- a/tensorflow/contrib/__init__.py
+++ b/tensorflow/contrib/__init__.py
@@ -1,3 +1,4 @@
+# pylint: disable=g-import-not-at-top
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -18,6 +19,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import os
+
# Add projects here, they will show up under tf.contrib.
from tensorflow.contrib import batching
from tensorflow.contrib import bayesflow
@@ -84,7 +87,8 @@ from tensorflow.contrib import tpu
from tensorflow.contrib import training
from tensorflow.contrib import util
from tensorflow.contrib.eager.python import tfe as eager
-from tensorflow.contrib.lite.python import lite
+if os.name != "nt":
+ from tensorflow.contrib.lite.python import lite
from tensorflow.contrib.optimizer_v2 import optimizer_v2_symbols as optimizer_v2
from tensorflow.contrib.receptive_field import receptive_field_api as receptive_field
from tensorflow.contrib.remote_fused_graph import pylib as remote_fused_graph
@@ -94,6 +98,7 @@ from tensorflow.contrib.summary import summary
from tensorflow.python.util.lazy_loader import LazyLoader
ffmpeg = LazyLoader("ffmpeg", globals(),
"tensorflow.contrib.ffmpeg")
+del os
del LazyLoader
del absolute_import
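The hunk above makes `tensorflow.contrib.lite` optional so the rest of `tf.contrib` still imports on Windows, where `os.name == "nt"`. A minimal sketch of the same guarded-import pattern, using only the standard library (the module path is taken from the diff; the helper itself is illustrative, not TF code):

```python
import importlib
import os


def optional_import(name):
    """Return the module named `name`, or None if it cannot be imported."""
    try:
        return importlib.import_module(name)
    except ImportError:
        return None


# Mirror the guard in contrib/__init__.py: skip the Unix-only module on
# Windows, and tolerate its absence elsewhere.
lite = (optional_import("tensorflow.contrib.lite.python.lite")
        if os.name != "nt" else None)
```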
diff --git a/tensorflow/contrib/android/asset_manager_filesystem.cc b/tensorflow/contrib/android/asset_manager_filesystem.cc
index 380a652435..513d519eab 100644
--- a/tensorflow/contrib/android/asset_manager_filesystem.cc
+++ b/tensorflow/contrib/android/asset_manager_filesystem.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/file_system_helper.h"
namespace tensorflow {
namespace {
@@ -228,9 +229,8 @@ string AssetManagerFileSystem::NormalizeDirectoryPath(const string& fname) {
}
string AssetManagerFileSystem::RemoveAssetPrefix(const string& name) {
- string output(name);
- StringPiece piece(output);
- piece.Consume(prefix_);
+ StringPiece piece(name);
+ str_util::ConsumePrefix(&piece, prefix_);
return piece.ToString();
}
@@ -243,6 +243,11 @@ bool AssetManagerFileSystem::DirectoryExists(const std::string& fname) {
return AAssetDir_getNextFileName(dir.get()) != NULL;
}
+Status AssetManagerFileSystem::GetMatchingPaths(const string& pattern,
+ std::vector<string>* results) {
+ return internal::GetMatchingPaths(this, Env::Default(), pattern, results);
+}
+
Status AssetManagerFileSystem::NewWritableFile(
const string& fname, std::unique_ptr<WritableFile>* result) {
return errors::Unimplemented("Asset storage is read only.");
diff --git a/tensorflow/contrib/android/asset_manager_filesystem.h b/tensorflow/contrib/android/asset_manager_filesystem.h
index 665304b5ee..a87ff42ae2 100644
--- a/tensorflow/contrib/android/asset_manager_filesystem.h
+++ b/tensorflow/contrib/android/asset_manager_filesystem.h
@@ -66,6 +66,9 @@ class AssetManagerFileSystem : public FileSystem {
Status DeleteDir(const string& d) override;
Status RenameFile(const string& s, const string& t) override;
+ Status GetMatchingPaths(const string& pattern,
+ std::vector<string>* results) override;
+
private:
string RemoveAssetPrefix(const string& name);
diff --git a/tensorflow/contrib/autograph/converters/BUILD b/tensorflow/contrib/autograph/converters/BUILD
index 608bd82722..c5a0dc1095 100644
--- a/tensorflow/contrib/autograph/converters/BUILD
+++ b/tensorflow/contrib/autograph/converters/BUILD
@@ -61,6 +61,7 @@ py_test(
name = "asserts_test",
srcs = ["asserts_test.py"],
srcs_version = "PY2AND3",
+ tags = ["no_windows"],
deps = [
":test_lib",
"//tensorflow/python:client_testlib",
@@ -81,6 +82,7 @@ py_test(
name = "builtin_functions_test",
srcs = ["builtin_functions_test.py"],
srcs_version = "PY2AND3",
+ tags = ["no_windows"],
deps = [
":test_lib",
"//tensorflow/python:client_testlib",
@@ -92,6 +94,7 @@ py_test(
size = "large",
srcs = ["call_trees_test.py"],
srcs_version = "PY2AND3",
+ tags = ["no_windows"],
deps = [
":test_lib",
"//tensorflow/contrib/autograph/impl",
diff --git a/tensorflow/contrib/autograph/impl/BUILD b/tensorflow/contrib/autograph/impl/BUILD
index e468176da1..54424e2647 100644
--- a/tensorflow/contrib/autograph/impl/BUILD
+++ b/tensorflow/contrib/autograph/impl/BUILD
@@ -26,6 +26,7 @@ py_library(
visibility = ["//tensorflow:__subpackages__"],
deps = [
"//tensorflow/contrib/autograph/converters",
+ "//tensorflow/contrib/autograph/operators",
"//tensorflow/contrib/autograph/pyct",
"//tensorflow/contrib/autograph/pyct/static_analysis",
"//tensorflow/contrib/autograph/utils",
@@ -38,6 +39,7 @@ py_test(
name = "api_test",
srcs = ["api_test.py"],
srcs_version = "PY2AND3",
+ tags = ["no_windows"],
deps = [
":impl",
"//tensorflow/contrib/autograph/utils",
@@ -50,6 +52,7 @@ py_test(
name = "conversion_test",
srcs = ["conversion_test.py"],
srcs_version = "PY2AND3",
+ tags = ["no_windows"],
deps = [
":impl",
"//tensorflow/python:client_testlib",
diff --git a/tensorflow/contrib/autograph/impl/config.py b/tensorflow/contrib/autograph/impl/config.py
index 543c1486e6..26326465e2 100644
--- a/tensorflow/contrib/autograph/impl/config.py
+++ b/tensorflow/contrib/autograph/impl/config.py
@@ -41,10 +41,15 @@ DEFAULT_UNCOMPILED_MODULES = set((
NO_SIDE_EFFECT_CONSTRUCTORS = set(('tensorflow',))
-# TODO(mdan): Also allow controlling the generated names (for testability).
+# TODO(mdan): Also allow controlling the generated names.
+# TODO(mdan): Consolidate all internal imports into a single __ag module.
COMPILED_IMPORT_STATEMENTS = (
- 'from __future__ import print_function', 'import tensorflow as tf',
- 'from tensorflow.contrib.autograph.impl import api as '
- 'autograph_api',
- 'from tensorflow.contrib.autograph import utils as '
- 'autograph_utils')
+ 'from __future__ import print_function',
+ 'import tensorflow as tf',
+ 'from tensorflow.contrib.autograph.impl import api'
+ ' as autograph_api',
+ 'from tensorflow.contrib.autograph import utils'
+ ' as autograph_utils',
+ 'from tensorflow.contrib.autograph import operators'
+ ' as __ops',
+)
diff --git a/tensorflow/contrib/autograph/operators/BUILD b/tensorflow/contrib/autograph/operators/BUILD
new file mode 100644
index 0000000000..7856c253bd
--- /dev/null
+++ b/tensorflow/contrib/autograph/operators/BUILD
@@ -0,0 +1,25 @@
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
+
+py_library(
+ name = "operators",
+ srcs = [
+ "__init__.py",
+ ],
+ srcs_version = "PY2AND3",
+ visibility = ["//tensorflow:__subpackages__"],
+ deps = [],
+)
diff --git a/tensorflow/contrib/autograph/operators/__init__.py b/tensorflow/contrib/autograph/operators/__init__.py
new file mode 100644
index 0000000000..c3f4cab69e
--- /dev/null
+++ b/tensorflow/contrib/autograph/operators/__init__.py
@@ -0,0 +1,24 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""This module implements operators that we overload.
+
+Note that "operator" is used loosely here, and includes control structures like
+conditionals and loops, implemented in functional form using, for example,
+closures for the body.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
diff --git a/tensorflow/contrib/autograph/pyct/BUILD b/tensorflow/contrib/autograph/pyct/BUILD
index edec5f7712..c483ff68c4 100644
--- a/tensorflow/contrib/autograph/pyct/BUILD
+++ b/tensorflow/contrib/autograph/pyct/BUILD
@@ -66,6 +66,7 @@ py_test(
name = "compiler_test",
srcs = ["compiler_test.py"],
srcs_version = "PY2AND3",
+ tags = ["no_windows"],
deps = [
":pyct",
"//tensorflow/python:client_testlib",
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/BUILD b/tensorflow/contrib/autograph/pyct/static_analysis/BUILD
index d192bc7aab..83f3bafc42 100644
--- a/tensorflow/contrib/autograph/pyct/static_analysis/BUILD
+++ b/tensorflow/contrib/autograph/pyct/static_analysis/BUILD
@@ -34,6 +34,7 @@ py_test(
name = "activity_test",
srcs = ["activity_test.py"],
srcs_version = "PY2AND3",
+ tags = ["no_windows"],
deps = [
":static_analysis",
"//tensorflow/contrib/autograph/pyct",
@@ -46,6 +47,7 @@ py_test(
name = "live_values_test",
srcs = ["live_values_test.py"],
srcs_version = "PY2AND3",
+ tags = ["no_windows"],
deps = [
":static_analysis",
"//tensorflow/contrib/autograph/pyct",
diff --git a/tensorflow/contrib/autograph/utils/BUILD b/tensorflow/contrib/autograph/utils/BUILD
index b53fbb5c18..d3a1b94688 100644
--- a/tensorflow/contrib/autograph/utils/BUILD
+++ b/tensorflow/contrib/autograph/utils/BUILD
@@ -44,6 +44,7 @@ py_test(
name = "builtins_test",
srcs = ["builtins_test.py"],
srcs_version = "PY2AND3",
+ tags = ["no_windows"],
deps = [
":utils",
"//tensorflow/python:client_testlib",
@@ -84,6 +85,7 @@ py_test(
name = "py_func_test",
srcs = ["py_func_test.py"],
srcs_version = "PY2AND3",
+ tags = ["no_windows"],
deps = [
":utils",
"//tensorflow/python:client_testlib",
diff --git a/tensorflow/contrib/boosted_trees/lib/utils/batch_features.h b/tensorflow/contrib/boosted_trees/lib/utils/batch_features.h
index da5e744851..a3b1b013e3 100644
--- a/tensorflow/contrib/boosted_trees/lib/utils/batch_features.h
+++ b/tensorflow/contrib/boosted_trees/lib/utils/batch_features.h
@@ -48,9 +48,9 @@ class BatchFeatures {
Status GetFeatureColumnSizes(int64* const num_dense_float_features,
int64* const num_sparse_float_features,
int64* const num_sparse_int_features) const {
- QCHECK_NE(num_dense_float_features, nullptr);
- QCHECK_NE(num_sparse_float_features, nullptr);
- QCHECK_NE(num_sparse_int_features, nullptr);
+ QCHECK_NE(num_dense_float_features, static_cast<int64*>(nullptr));
+ QCHECK_NE(num_sparse_float_features, static_cast<int64*>(nullptr));
+ QCHECK_NE(num_sparse_int_features, static_cast<int64*>(nullptr));
*num_dense_float_features = dense_float_feature_columns_.size();
*num_sparse_float_features = sparse_float_feature_columns_.size();
*num_sparse_int_features = sparse_int_feature_columns_.size();
diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake
index b776307924..fae45ead5c 100755
--- a/tensorflow/contrib/cmake/tf_python.cmake
+++ b/tensorflow/contrib/cmake/tf_python.cmake
@@ -474,6 +474,8 @@ set (pywrap_tensorflow_internal_src
"${tensorflow_source_dir}/tensorflow/python/lib/core/ndarray_tensor_bridge.cc"
"${tensorflow_source_dir}/tensorflow/python/lib/core/py_func.h"
"${tensorflow_source_dir}/tensorflow/python/lib/core/py_func.cc"
+ "${tensorflow_source_dir}/tensorflow/python/lib/core/py_exception_registry.h"
+ "${tensorflow_source_dir}/tensorflow/python/lib/core/py_exception_registry.cc"
"${tensorflow_source_dir}/tensorflow/python/lib/core/py_seq_tensor.h"
"${tensorflow_source_dir}/tensorflow/python/lib/core/py_seq_tensor.cc"
"${tensorflow_source_dir}/tensorflow/python/lib/core/py_util.h"
diff --git a/tensorflow/contrib/data/BUILD b/tensorflow/contrib/data/BUILD
index 35312f06b3..7bb0dc1c0f 100644
--- a/tensorflow/contrib/data/BUILD
+++ b/tensorflow/contrib/data/BUILD
@@ -8,6 +8,7 @@ load(
"//tensorflow:tensorflow.bzl",
"tf_custom_op_library",
"tf_gen_op_libs",
+ "if_not_windows",
)
load(
"//tensorflow/core:platform/default/build_config_root.bzl",
@@ -31,12 +32,17 @@ py_library(
],
)
+cc_library(
+ name = "lib_proto_parsing_for_dataset_ops",
+ deps = if_not_windows(["//tensorflow/core:lib_proto_parsing"]),
+)
+
tf_custom_op_library(
name = "_dataset_ops.so",
srcs = ["ops/dataset_ops.cc"],
deps = ["//tensorflow/contrib/data/kernels:dataset_kernels"] +
if_static(
- extra_deps = ["//tensorflow/core:lib_proto_parsing"],
+ extra_deps = [":lib_proto_parsing_for_dataset_ops"],
otherwise = [],
),
)
diff --git a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py
index 676959a900..4b50260670 100644
--- a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py
@@ -33,7 +33,7 @@ from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import test
-class StagingAreaOpsTest(test.TestCase):
+class PrefetchingKernelsOpsTest(test.TestCase):
def setUp(self):
self._event = threading.Event()
@@ -200,6 +200,9 @@ class StagingAreaOpsTest(test.TestCase):
sess.run(destroy_op)
+
+class PrefetchToDeviceTest(test.TestCase):
+
def testPrefetchToDevice(self):
host_dataset = dataset_ops.Dataset.range(10)
device_dataset = host_dataset.apply(
@@ -231,6 +234,37 @@ class StagingAreaOpsTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
+ def testPrefetchDictToDevice(self):
+ host_dataset = dataset_ops.Dataset.range(10).map(lambda x: {"a": x})
+ device_dataset = host_dataset.apply(
+ prefetching_ops.prefetch_to_device("/cpu:1"))
+
+ # NOTE(mrry): This device block creates the "host" dataset and iterator on
+ # /cpu:0, and ensures that the prefetching is across devices. In typical use
+ # this would not be necessary, because the GPU device would not support any
+ # of the dataset-related ops.
+ with ops.device("/cpu:0"):
+ iterator = device_dataset.make_one_shot_iterator()
+
+ self.assertEqual(host_dataset.output_types, device_dataset.output_types)
+ self.assertEqual(host_dataset.output_types, iterator.output_types)
+ self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
+ self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
+ self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
+ self.assertEqual(host_dataset.output_classes, iterator.output_classes)
+
+ next_element = iterator.get_next()
+ self.assertEqual(dtypes.int64, next_element["a"].dtype)
+ self.assertEqual([], next_element["a"].shape)
+
+ worker_config = config_pb2.ConfigProto()
+ worker_config.device_count["CPU"] = 2
+ with self.test_session(config=worker_config) as sess:
+ for i in range(10):
+ self.assertEqual({"a": i}, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
def testPrefetchToDeviceGpu(self):
if not test_util.is_gpu_available():
self.skipTest("No GPU available")
@@ -248,5 +282,62 @@ class StagingAreaOpsTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
+ def testPrefetchToDeviceWithReInit(self):
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops.prefetch_to_device("/cpu:1"))
+
+ # NOTE(mrry): This device block creates the "host" dataset and iterator on
+ # /cpu:0, and ensures that the prefetching is across devices. In typical use
+ # this would not be necessary, because the GPU device would not support any
+ # of the dataset-related ops.
+ with ops.device("/cpu:0"):
+ iterator = device_dataset.make_initializable_iterator()
+
+ self.assertEqual(host_dataset.output_types, device_dataset.output_types)
+ self.assertEqual(host_dataset.output_types, iterator.output_types)
+ self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
+ self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
+ self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
+ self.assertEqual(host_dataset.output_classes, iterator.output_classes)
+
+ next_element = iterator.get_next()
+ self.assertEqual(dtypes.int64, next_element.dtype)
+ self.assertEqual([], next_element.shape)
+
+ worker_config = config_pb2.ConfigProto()
+ worker_config.device_count["CPU"] = 2
+ with self.test_session(config=worker_config) as sess:
+ sess.run(iterator.initializer)
+ for i in range(5):
+ self.assertEqual(i, sess.run(next_element))
+ sess.run(iterator.initializer)
+ for i in range(10):
+ self.assertEqual(i, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testPrefetchToDeviceGpuWithReInit(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops.prefetch_to_device("/gpu:0"))
+
+ iterator = device_dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with self.test_session() as sess:
+ sess.run(iterator.initializer)
+ for i in range(5):
+ self.assertEqual(i, sess.run(next_element))
+ sess.run(iterator.initializer)
+ for i in range(10):
+ self.assertEqual(i, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/data/python/ops/prefetching_ops.py b/tensorflow/contrib/data/python/ops/prefetching_ops.py
index 98651bb568..77e23d0319 100644
--- a/tensorflow/contrib/data/python/ops/prefetching_ops.py
+++ b/tensorflow/contrib/data/python/ops/prefetching_ops.py
@@ -28,6 +28,7 @@ from tensorflow.python.data.util import sparse
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
# TODO(rohanj): Add a python class that constructs resource in the __init__
@@ -67,28 +68,77 @@ def function_buffering_resource_reset(function_buffer_resource, name=None):
# pylint: disable=protected-access
class _PrefetchToDeviceIterator(object):
- """A replacement for @{tf.data.Iterator} that prefetches to another device."""
+ """A replacement for @{tf.data.Iterator} that prefetches to another device.
- def __init__(self, input_dataset, device, buffer_size):
+ Args:
+ input_dataset: The input dataset.
+ one_shot: If true, make a one-shot iterator that is already initialized.
+ device: A fully specified device string for the device to prefetch to.
+ buffer_size: Size of the prefetching buffer.
+ shared_name: (Optional.) If non-empty, the returned iterator will be
+ shared under the given name across multiple sessions that share the
+ same devices (e.g. when using a remote server).
+
+ Returns:
+ An Iterator type object.
+ """
+
+ def __init__(self,
+ input_dataset,
+ one_shot,
+ device,
+ buffer_size,
+ shared_name=None):
self._input_dataset = input_dataset
self._get_next_call_count = 0
- input_iterator = input_dataset.make_one_shot_iterator()
- input_iterator_handle = input_iterator.string_handle()
+ self._one_shot = one_shot
+ if shared_name is None:
+ shared_name = ""
+
+ if self._one_shot:
+ self._input_iterator = input_dataset.make_one_shot_iterator()
+ else:
+ self._input_iterator = iterator_ops.Iterator.from_structure(
+ self._input_dataset.output_types, self._input_dataset.output_shapes,
+ shared_name, self._input_dataset.output_classes)
+ input_iterator_handle = self._input_iterator.string_handle()
@function.Defun(dtypes.string)
def _prefetch_fn(handle):
+ """Prefetches one element from `input_iterator`."""
remote_iterator = iterator_ops.Iterator.from_string_handle(
- handle, input_iterator.output_types, input_iterator.output_shapes,
- input_iterator.output_classes)
- return remote_iterator.get_next()
+ handle, self._input_iterator.output_types,
+ self._input_iterator.output_shapes,
+ self._input_iterator.output_classes)
+ ret = remote_iterator.get_next()
+
+ # Convert any `SparseTensorValue`s to `SparseTensor`s.
+ ret = nest.pack_sequence_as(ret, [
+ sparse_tensor_lib.SparseTensor.from_value(t)
+ if sparse_tensor_lib.is_sparse(t) else t for t in nest.flatten(ret)
+ ])
+
+ # Serialize any sparse tensors and convert result to tensors.
+ ret = nest.pack_sequence_as(ret, [
+ ops.convert_to_tensor(t)
+ for t in nest.flatten(sparse.serialize_sparse_tensors(ret))
+ ])
+ return nest.flatten(ret)
with ops.device(device):
self._buffering_resource = function_buffering_resource(
f=_prefetch_fn,
target_device=gen_dataset_ops.iterator_get_device(
- input_iterator._iterator_resource),
+ self._input_iterator._iterator_resource),
string_arg=input_iterator_handle,
- buffer_size=buffer_size)
+ buffer_size=buffer_size,
+ shared_name=shared_name)
+
+ if not self._one_shot:
+ reset_op = function_buffering_resource_reset(self._buffering_resource)
+ with ops.control_dependencies([reset_op]):
+ self._initializer = self._input_iterator.make_initializer(
+ self._input_dataset)
def get_next(self, name=None):
"""See @{tf.data.Iterator.get_next}."""
@@ -113,6 +163,12 @@ class _PrefetchToDeviceIterator(object):
return ret
@property
+ def initializer(self):
+ if self._one_shot:
+ raise NotImplementedError("Can't initialize a one_shot_iterator")
+ return self._initializer
+
+ @property
def output_classes(self):
return self._input_dataset.output_classes
@@ -135,13 +191,19 @@ class _PrefetchToDeviceDataset(dataset_ops.Dataset):
self._buffer_size = buffer_size if buffer_size is not None else 1
def make_one_shot_iterator(self):
- return _PrefetchToDeviceIterator(self._input_dataset, self._device,
- self._buffer_size)
+ return _PrefetchToDeviceIterator(
+ self._input_dataset,
+ one_shot=True,
+ device=self._device,
+ buffer_size=self._buffer_size)
def make_initializable_iterator(self, shared_name=None):
- raise NotImplementedError("`prefetch_to_device()` is not currently "
- "compatible with initializable iterators. Use "
- "`make_one_shot_iterator()` instead.")
+ return _PrefetchToDeviceIterator(
+ self._input_dataset,
+ one_shot=False,
+ device=self._device,
+ buffer_size=self._buffer_size,
+ shared_name=shared_name)
def _as_variant_tensor(self):
# TODO(mrry): Raise this error earlier (e.g. when one of the Dataset
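Before this change, `prefetch_to_device` supported only one-shot iterators; the new `one_shot=False` path builds the input iterator from structure and chains the buffer reset op into the initializer, so re-running `iterator.initializer` restarts both the input pipeline and the prefetch buffer. A usage sketch mirroring the re-init tests above (assumes TF 1.x graph mode and the contrib module shown in this diff):

```python
import tensorflow as tf
from tensorflow.contrib.data.python.ops import prefetching_ops

dataset = tf.data.Dataset.range(10).apply(
    prefetching_ops.prefetch_to_device("/cpu:0"))

# Previously this raised NotImplementedError; now it returns an iterator
# whose initializer also resets the prefetching buffer.
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    sess.run(iterator.initializer)
    for _ in range(5):
        sess.run(next_element)
    sess.run(iterator.initializer)  # restart mid-stream
    for i in range(10):
        assert sess.run(next_element) == i
```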
diff --git a/tensorflow/contrib/distribute/README.md b/tensorflow/contrib/distribute/README.md
index 4af51bec1a..28483f4c88 100644
--- a/tensorflow/contrib/distribute/README.md
+++ b/tensorflow/contrib/distribute/README.md
@@ -77,7 +77,7 @@ parameter of `Estimator`.
```python
distribution = tf.contrib.distribute.MirroredStrategy()
-config = tf.estimator.RunConfig(distribute=distribution)
+config = tf.estimator.RunConfig(train_distribute=distribution)
classifier = tf.estimator.Estimator(model_fn=model_fn, config=config)
classifier.train(input_fn=input_fn)
```
diff --git a/tensorflow/contrib/distribute/python/estimator_integration_test.py b/tensorflow/contrib/distribute/python/estimator_integration_test.py
index 9be186a724..2b49b8f4ef 100644
--- a/tensorflow/contrib/distribute/python/estimator_integration_test.py
+++ b/tensorflow/contrib/distribute/python/estimator_integration_test.py
@@ -95,7 +95,7 @@ class DNNLinearCombinedClassifierIntegrationTest(test.TestCase,
# TODO(isaprykin): Work around the colocate_with error.
dnn_optimizer=adagrad.AdagradOptimizer(0.001),
linear_optimizer=adagrad.AdagradOptimizer(0.001),
- config=run_config.RunConfig(distribute=distribution))
+ config=run_config.RunConfig(train_distribute=distribution))
num_steps = 10
estimator.train(train_input_fn, steps=num_steps)
diff --git a/tensorflow/contrib/distribute/python/examples/simple_estimator_example.py b/tensorflow/contrib/distribute/python/examples/simple_estimator_example.py
index 5d6e02b4b9..00c25c7a24 100644
--- a/tensorflow/contrib/distribute/python/examples/simple_estimator_example.py
+++ b/tensorflow/contrib/distribute/python/examples/simple_estimator_example.py
@@ -59,7 +59,7 @@ def build_model_fn_optimizer():
def main(_):
distribution = tf.contrib.distribute.MirroredStrategy(
["/device:GPU:0", "/device:GPU:1"])
- config = tf.estimator.RunConfig(distribute=distribution)
+ config = tf.estimator.RunConfig(train_distribute=distribution)
def input_fn():
features = tf.data.Dataset.from_tensors([[1.]]).repeat(10)
diff --git a/tensorflow/contrib/distribute/python/examples/simple_tfkeras_example.py b/tensorflow/contrib/distribute/python/examples/simple_tfkeras_example.py
index e714255f69..b87224251c 100644
--- a/tensorflow/contrib/distribute/python/examples/simple_tfkeras_example.py
+++ b/tensorflow/contrib/distribute/python/examples/simple_tfkeras_example.py
@@ -41,7 +41,7 @@ def main(args):
strategy = tf.contrib.distribute.MirroredStrategy(
['/device:GPU:0', '/device:GPU:1'])
- config = tf.estimator.RunConfig(distribute=strategy)
+ config = tf.estimator.RunConfig(train_distribute=strategy)
optimizer = tf.train.GradientDescentOptimizer(0.2)
model = tf.keras.Sequential()
diff --git a/tensorflow/contrib/distribute/python/monitor.py b/tensorflow/contrib/distribute/python/monitor.py
index fe80bb4df5..7644acedc9 100644
--- a/tensorflow/contrib/distribute/python/monitor.py
+++ b/tensorflow/contrib/distribute/python/monitor.py
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.eager import context
+from tensorflow.python.framework import errors
from tensorflow.python.ops import variables
@@ -55,7 +56,9 @@ class Monitor(object):
def run_steps(self, num_steps=None):
step = 0
- done = False
- while done is not None and (num_steps is None or step < num_steps):
- done = self._run_step()
- step += 1
+ while num_steps is None or step < num_steps:
+ try:
+ self._run_step()
+ step += 1
+ except errors.OutOfRangeError:
+ break
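The rewritten `run_steps` replaces the sentinel-return loop with exception-driven termination: it keeps stepping until the requested step count is reached or the input pipeline signals exhaustion via `OutOfRangeError`. The same control flow in plain Python, with `StopIteration` standing in for the TF error (a sketch, not TF code):

```python
def run_steps(step_results, num_steps=None):
    """Run until `num_steps` steps complete or the input is exhausted."""
    step = 0
    it = iter(step_results)  # stands in for the per-step train op
    while num_steps is None or step < num_steps:
        try:
            next(it)             # one training step
            step += 1
        except StopIteration:    # plays the role of errors.OutOfRangeError
            break
    return step


assert run_steps(range(7)) == 7               # runs to exhaustion
assert run_steps(range(7), num_steps=3) == 3  # bounded by num_steps
```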
diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD
index de08eb491b..9799901483 100644
--- a/tensorflow/contrib/distributions/BUILD
+++ b/tensorflow/contrib/distributions/BUILD
@@ -454,6 +454,7 @@ cuda_py_test(
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
],
+ tags = ["no_windows"], # TODO: needs investigation on Windows
)
cuda_py_test(
@@ -501,12 +502,6 @@ cuda_py_test(
"//tensorflow/python:client_testlib",
],
shard_count = 4,
- tags = [
- "manual",
- "noasan",
- "noguitar",
- "optonly",
- ],
)
cuda_py_test(
@@ -1128,6 +1123,7 @@ cuda_py_test(
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
],
+ tags = ["no_windows"], # TODO: needs investigation on Windows
)
cuda_py_test(
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/batch_reshape_test.py b/tensorflow/contrib/distributions/python/kernel_tests/batch_reshape_test.py
index 4d2f40e27f..c6c8d2cf6e 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/batch_reshape_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/batch_reshape_test.py
@@ -22,6 +22,7 @@ import numpy as np
from tensorflow.contrib.distributions.python.ops import batch_reshape as batch_reshape_lib
from tensorflow.contrib.distributions.python.ops import mvn_diag as mvn_lib
+from tensorflow.contrib.distributions.python.ops import poisson as poisson_lib
from tensorflow.contrib.distributions.python.ops import wishart as wishart_lib
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import array_ops
@@ -514,6 +515,42 @@ class _BatchReshapeTest(object):
batch_shape=new_batch_shape_ph,
validate_args=True).sample().eval()
+ def test_broadcasting_explicitly_unsupported(self):
+ old_batch_shape = [4]
+ new_batch_shape = [1, 4, 1]
+ rate_ = self.dtype([1, 10, 2, 20])
+
+ rate = array_ops.placeholder_with_default(
+ rate_,
+ shape=old_batch_shape if self.is_static_shape else None)
+ poisson_4 = poisson_lib.Poisson(rate)
+ new_batch_shape_ph = (
+ constant_op.constant(np.int32(new_batch_shape)) if self.is_static_shape
+ else array_ops.placeholder_with_default(
+ np.int32(new_batch_shape), shape=None))
+ poisson_141_reshaped = batch_reshape_lib.BatchReshape(
+ poisson_4, new_batch_shape_ph, validate_args=True)
+
+ x_4 = self.dtype([2, 12, 3, 23])
+ x_114 = self.dtype([2, 12, 3, 23]).reshape(1, 1, 4)
+
+ if self.is_static_shape:
+ with self.assertRaisesRegexp(NotImplementedError,
+ "too few event dims"):
+ poisson_141_reshaped.log_prob(x_4)
+ with self.assertRaisesRegexp(NotImplementedError,
+ "unexpected batch and event shape"):
+ poisson_141_reshaped.log_prob(x_114)
+ return
+
+ with self.assertRaisesOpError("too few event dims"):
+ with self.test_session():
+ poisson_141_reshaped.log_prob(x_4).eval()
+
+ with self.assertRaisesOpError("unexpected batch and event shape"):
+ with self.test_session():
+ poisson_141_reshaped.log_prob(x_114).eval()
+
class BatchReshapeStaticTest(_BatchReshapeTest, test.TestCase):
diff --git a/tensorflow/contrib/distributions/python/ops/batch_reshape.py b/tensorflow/contrib/distributions/python/ops/batch_reshape.py
index c7ee9b2117..3e6c35e0d6 100644
--- a/tensorflow/contrib/distributions/python/ops/batch_reshape.py
+++ b/tensorflow/contrib/distributions/python/ops/batch_reshape.py
@@ -115,7 +115,7 @@ class BatchReshape(distribution_lib.Distribution):
self._batch_shape_static = tensor_util.constant_value(self._batch_shape_)
if self._batch_shape_static is not None:
self._batch_shape_static = np.int32(self._batch_shape_static)
- self._runtime_assertions = make_runtime_assertions(
+ self._runtime_assertions = validate_init_args(
self._distribution,
self._batch_shape_,
validate_args,
@@ -229,7 +229,8 @@ class BatchReshape(distribution_lib.Distribution):
def _call_reshape_input_output(self, fn, x):
"""Calls `fn`, appropriately reshaping its input `x` and output."""
- with ops.control_dependencies(self._runtime_assertions):
+ with ops.control_dependencies(
+ self._runtime_assertions + self._validate_sample_arg(x)):
sample_shape, static_sample_shape = self._sample_shape(x)
old_shape = array_ops.concat([
sample_shape,
@@ -273,61 +274,142 @@ class BatchReshape(distribution_lib.Distribution):
result.set_shape(static_shape)
return result
-
-def make_runtime_assertions(
+ def _validate_sample_arg(self, x):
+ """Helper which validates sample arg, e.g., input to `log_prob`."""
+ with ops.name_scope(name="validate_sample_arg", values=[x]):
+ x_ndims = (array_ops.rank(x) if x.shape.ndims is None else x.shape.ndims)
+ event_ndims = (array_ops.size(self.event_shape_tensor())
+ if self.event_shape.ndims is None
+ else self.event_shape.ndims)
+ batch_ndims = (array_ops.size(self.batch_shape_tensor())
+ if self.batch_shape.ndims is None
+ else self.batch_shape.ndims)
+ expected_batch_event_ndims = batch_ndims + event_ndims
+
+ if (isinstance(x_ndims, int) and
+ isinstance(expected_batch_event_ndims, int)):
+ if x_ndims < expected_batch_event_ndims:
+ raise NotImplementedError(
+ "Broadcasting is not supported; too few event dims "
+ "(expected at least {}, saw {}).".format(
+ expected_batch_event_ndims, x_ndims))
+ ndims_assertion = []
+ elif self.validate_args:
+ ndims_assertion = [
+ check_ops.assert_greater_equal(
+ x_ndims,
+ expected_batch_event_ndims,
+ message="Broadcasting is not supported; too few event dims.",
+ name="assert_batch_and_event_ndims_large_enough"),
+ ]
+
+ if (self.batch_shape.is_fully_defined() and
+ self.event_shape.is_fully_defined()):
+ expected_batch_event_shape = np.int32(self.batch_shape.concatenate(
+ self.event_shape).as_list())
+ else:
+ expected_batch_event_shape = array_ops.concat([
+ self.batch_shape_tensor(),
+ self.event_shape_tensor(),
+ ], axis=0)
+
+ sample_ndims = x_ndims - expected_batch_event_ndims
+ if isinstance(sample_ndims, int):
+ sample_ndims = max(sample_ndims, 0)
+ if (isinstance(sample_ndims, int) and
+ x.shape[sample_ndims:].is_fully_defined()):
+ actual_batch_event_shape = np.int32(x.shape[sample_ndims:].as_list())
+ else:
+ sample_ndims = math_ops.maximum(sample_ndims, 0)
+ actual_batch_event_shape = array_ops.shape(x)[sample_ndims:]
+
+ if (isinstance(expected_batch_event_shape, np.ndarray) and
+ isinstance(actual_batch_event_shape, np.ndarray)):
+ if any(expected_batch_event_shape != actual_batch_event_shape):
+ raise NotImplementedError("Broadcasting is not supported; "
+ "unexpected batch and event shape "
+ "(expected {}, saw {}).".format(
+ expected_batch_event_shape,
+ actual_batch_event_shape))
+ # We need to set the final runtime assertions to `ndims_assertion` since
+ # it's possible this assertion was created. We could add a condition to
+ # only do so if `self.validate_args == True`; however, this is redundant,
+ # as `ndims_assertion` already encodes this information.
+ runtime_assertions = ndims_assertion
+ elif self.validate_args:
+ # We need to make the `ndims_assertion` a control dep because otherwise
+ # TF itself might raise an exception owing to this assertion being
+ # ill-defined, i.e., one cannot even compare Tensors of different rank.
+ with ops.control_dependencies(ndims_assertion):
+ shape_assertion = check_ops.assert_equal(
+ expected_batch_event_shape,
+ actual_batch_event_shape,
+ message=("Broadcasting is not supported; "
+ "unexpected batch and event shape."),
+ name="assert_batch_and_event_shape_same")
+ runtime_assertions = [shape_assertion]
+ else:
+ runtime_assertions = []
+
+ return runtime_assertions
+
+
+def validate_init_args(
distribution,
batch_shape,
validate_args,
batch_shape_static):
"""Helper to __init__ which makes or raises assertions."""
- runtime_assertions = []
-
- if batch_shape.shape.ndims is not None:
- if batch_shape.shape.ndims != 1:
- raise ValueError("`batch_shape` must be a vector "
- "(saw rank: {}).".format(
- batch_shape.shape.ndims))
- elif validate_args:
- runtime_assertions += [
- check_ops.assert_rank(
- batch_shape,
- 1,
- message="`batch_shape` must be a vector.",
- name="assert_batch_shape_is_vector"),
- ]
-
- batch_size_static = np.prod(batch_shape_static)
- dist_batch_size_static = (
- None if not distribution.batch_shape.is_fully_defined()
- else np.prod(distribution.batch_shape).value)
-
- if batch_size_static is not None and dist_batch_size_static is not None:
- if batch_size_static != dist_batch_size_static:
- raise ValueError("`batch_shape` size ({}) must match "
- "`distribution.batch_shape` size ({}).".format(
- batch_size_static,
- dist_batch_size_static))
- elif validate_args:
- runtime_assertions += [
- check_ops.assert_equal(
- math_ops.reduce_prod(batch_shape),
- math_ops.reduce_prod(distribution.batch_shape_tensor()),
- message=("`batch_shape` size must match "
- "`distributions.batch_shape` size."),
- name="assert_batch_size"),
- ]
-
- if batch_shape_static is not None:
- if np.any(batch_shape_static < 1):
- raise ValueError("`batch_shape` elements must be positive "
- "(i.e., larger than zero).")
- elif validate_args:
- runtime_assertions += [
- check_ops.assert_positive(
- batch_shape,
- message=("`batch_shape` elements must be positive "
- "(i.e., larger than zero)."),
- name="assert_batch_shape_positive")
- ]
-
- return runtime_assertions
+ with ops.name_scope(name="validate_init_args",
+ values=[batch_shape] + distribution._graph_parents): # pylint: disable=protected-access
+ runtime_assertions = []
+
+ if batch_shape.shape.ndims is not None:
+ if batch_shape.shape.ndims != 1:
+ raise ValueError("`batch_shape` must be a vector "
+ "(saw rank: {}).".format(
+ batch_shape.shape.ndims))
+ elif validate_args:
+ runtime_assertions += [
+ check_ops.assert_rank(
+ batch_shape,
+ 1,
+ message="`batch_shape` must be a vector.",
+ name="assert_batch_shape_is_vector"),
+ ]
+
+ batch_size_static = np.prod(batch_shape_static)
+ dist_batch_size_static = (
+ None if not distribution.batch_shape.is_fully_defined()
+ else np.prod(distribution.batch_shape).value)
+
+ if batch_size_static is not None and dist_batch_size_static is not None:
+ if batch_size_static != dist_batch_size_static:
+ raise ValueError("`batch_shape` size ({}) must match "
+ "`distribution.batch_shape` size ({}).".format(
+ batch_size_static,
+ dist_batch_size_static))
+ elif validate_args:
+ runtime_assertions += [
+ check_ops.assert_equal(
+ math_ops.reduce_prod(batch_shape),
+ math_ops.reduce_prod(distribution.batch_shape_tensor()),
+ message=("`batch_shape` size must match "
+ "`distributions.batch_shape` size."),
+ name="assert_batch_size"),
+ ]
+
+ if batch_shape_static is not None:
+ if np.any(batch_shape_static < 1):
+ raise ValueError("`batch_shape` elements must be positive "
+ "(i.e., larger than zero).")
+ elif validate_args:
+ runtime_assertions += [
+ check_ops.assert_positive(
+ batch_shape,
+ message=("`batch_shape` elements must be positive "
+ "(i.e., larger than zero)."),
+ name="assert_batch_shape_positive")
+ ]
+
+ return runtime_assertions
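A minimal sketch (not part of the diff; names are illustrative) of how a returned list of assertion ops like `runtime_assertions` is typically consumed: attach it as control dependencies on the tensor that needs the validated shape, so the checks run only when that tensor is evaluated.

    import tensorflow as tf

    def _guard_with_assertions(assertions, x):
      # The assert ops execute only when `x` is actually evaluated;
      # tf.identity gives them a tensor to attach to.
      with tf.control_dependencies(assertions):
        return tf.identity(x)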
diff --git a/tensorflow/contrib/eager/python/checkpointable_utils.py b/tensorflow/contrib/eager/python/checkpointable_utils.py
index 91a7aded11..34cb8d0e08 100644
--- a/tensorflow/contrib/eager/python/checkpointable_utils.py
+++ b/tensorflow/contrib/eager/python/checkpointable_utils.py
@@ -19,6 +19,7 @@ from __future__ import print_function
import abc
import collections
+import functools
import weakref
from tensorflow.contrib.eager.proto import checkpointable_object_graph_pb2
@@ -867,3 +868,115 @@ class Checkpoint(core_checkpointable.Checkpointable):
# initialization when executing eagerly.
self._maybe_create_save_counter()
return status
+
+
+class _CallbackSaveable(saver_lib.BaseSaverBuilder.SaveableObject):
+ """Wraps save and restore callbacks as a `SaveableObject`."""
+
+ def __init__(self, name, dtype, save_callback, restore_callback):
+ self._restore_callback = restore_callback
+ spec = saver_lib.BaseSaverBuilder.SaveSpec(
+ tensor=save_callback,
+ slice_spec="",
+ name=name,
+ dtype=dtype)
+ super(_CallbackSaveable, self).__init__(
+ save_callback, [spec], name)
+
+ def restore(self, restored_tensors, restored_shapes):
+ """Restore the same value into both variables."""
+ tensor, = restored_tensors
+ return self._restore_callback(tensor)
+
+
+class _SplitDependency(core_checkpointable.CheckpointableBase):
+ """Looks like a regular variable while synchronizing save/restores."""
+
+ def __init__(self, save_buffer, restore_buffer, name, dtype, num_components,
+ fill_save_buffer_fn, consume_restore_buffer_fn):
+ self._save_buffer = save_buffer
+ self._restore_buffer = restore_buffer
+ self._name = name
+ self._dtype = dtype
+ self._num_components = num_components
+ self._fill_save_buffer_fn = fill_save_buffer_fn
+ self._consume_restore_buffer_fn = consume_restore_buffer_fn
+
+ def _save(self):
+ """Pull from the shared buffer, populating it if necessary."""
+ if self._name not in self._save_buffer:
+ if self._save_buffer:
+ raise AssertionError(
+ ("Split dependency %s (%s) unsynchronized. Split dependencies must "
+ "be saved together.") % (self._name, self))
+ self._fill_save_buffer_fn(self._save_buffer)
+ return self._save_buffer.pop(self._name)
+
+ def _restore(self, tensor):
+ """Push into the shared buffer, flushing it if necessary."""
+ if self._name in self._restore_buffer:
+ raise AssertionError(
+ ("Split dependency %s (%s) unsynchronized. Split dependencies must "
+ "be restored together.") % (self._name, self))
+ self._restore_buffer[self._name] = tensor
+ if len(self._restore_buffer) == self._num_components:
+ op = self._consume_restore_buffer_fn(self._restore_buffer)
+ self._restore_buffer.clear()
+ return op
+ else:
+ return control_flow_ops.no_op()
+
+ def _gather_saveables_for_checkpoint(self):
+ """Looks to Checkpointable like a regular variable."""
+ return {
+ core_checkpointable.VARIABLE_VALUE_KEY:
+ functools.partial(_CallbackSaveable,
+ dtype=self._dtype,
+ save_callback=self._save,
+ restore_callback=self._restore)
+ }
+
+
+def split_dependency(component_names, component_dtypes,
+ fill_save_buffer_fn, consume_restore_buffer_fn):
+ """Creates multiple dependencies with a synchronized save/restore.
+
+ Useful when a single op produces `Tensor`s which should each be saved under
+ different objects, or when `Tensor`s saved with many different objects need to
+ be restored together as inputs to a single op (i.e. an object which uses a
+ single fused op may be swapped out for a subgraph of objects, and these two
+ programs are checkpoint compatible).
+
+ Args:
+ component_names: A sequence of names for the split
+ dependencies. `fill_save_buffer_fn` must add these keys to the dictionary
+ it is passed, and `consume_restore_buffer_fn` will receive a dictionary
+ with these keys.
+ component_dtypes: Data types for the `Tensor`s being saved and restored, a
+ sequence corresponding to `component_names`.
+ fill_save_buffer_fn: A function which takes an empty dictionary as an
+ argument and adds `Tensor`s with `component_names` as keys. These
+ `Tensor`s will be saved as if they were individual variables.
+ consume_restore_buffer_fn: A function which takes a dictionary with
+ `component_names` as keys mapping to restored individual `Tensor`s and
+ returns a restore op (or if executing eagerly, runs the restoration and
+ may return `None`).
+
+ Returns:
+ A dictionary mapping from names to Checkpointable objects. If one is
+ reachable from an object as a dependency, the others should be too; adding
+ dependencies on some but not all of the objects will result in errors.
+ """
+ save_buffer = {}
+ restore_buffer = {}
+ split_dependencies = {}
+ for name, dtype in zip(component_names, component_dtypes):
+ split_dependencies[name] = _SplitDependency(
+ save_buffer=save_buffer,
+ restore_buffer=restore_buffer,
+ name=name,
+ dtype=dtype,
+ num_components=len(component_names),
+ fill_save_buffer_fn=fill_save_buffer_fn,
+ consume_restore_buffer_fn=consume_restore_buffer_fn)
+ return split_dependencies
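For orientation, a hedged plain-Python sketch of the fill/consume contract that `split_dependency` expects; the real callbacks produce and consume `Tensor`s, and a fuller usage appears in the test diff below.

    def fill_save_buffer_fn(save_buffer):
      # Must add exactly one entry per component name.
      save_buffer["first_half"] = [1.0, 2.0]
      save_buffer["second_half"] = [3.0, 4.0]

    def consume_restore_buffer_fn(restore_buffer):
      # Called once every component name is present; returns the restore op
      # (or may return None when executing eagerly).
      return restore_buffer["first_half"] + restore_buffer["second_half"]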
diff --git a/tensorflow/contrib/eager/python/checkpointable_utils_test.py b/tensorflow/contrib/eager/python/checkpointable_utils_test.py
index 5e1b64728a..891c093a0f 100644
--- a/tensorflow/contrib/eager/python/checkpointable_utils_test.py
+++ b/tensorflow/contrib/eager/python/checkpointable_utils_test.py
@@ -33,6 +33,7 @@ from tensorflow.python.framework import test_util
from tensorflow.python.keras._impl.keras.engine import sequential
from tensorflow.python.keras._impl.keras.engine import training
from tensorflow.python.layers import core
+from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import resource_variable_ops
@@ -69,6 +70,87 @@ class MyModel(training.Model):
return ret
+def _split_variable_closure(variable):
+ def _fill_save_buffer_fn(save_buffer):
+ save_buffer["first_half"] = variable[:2]
+ save_buffer["second_half"] = variable[2:]
+ return _fill_save_buffer_fn
+
+
+def _combine_variable_closure(variable):
+ def _consume_restore_buffer_fn(restore_buffer):
+ return variable.assign(
+ array_ops.concat([restore_buffer["first_half"],
+ restore_buffer["second_half"]],
+ axis=0))
+ return _consume_restore_buffer_fn
+
+
+class SaveTensorSlicesAsDeps(checkpointable.CheckpointableBase):
+
+ def __init__(self):
+ self.combined = resource_variable_ops.ResourceVariable([0., 0., 0., 0.])
+ split_dependencies = checkpointable_utils.split_dependency(
+ component_names=("first_half", "second_half"),
+ component_dtypes=(self.combined.dtype,) * 2,
+ fill_save_buffer_fn=_split_variable_closure(
+ self.combined),
+ consume_restore_buffer_fn=_combine_variable_closure(
+ self.combined))
+ for name, dep in split_dependencies.items():
+ self._track_checkpointable(dep, name=name)
+
+
+class HasRegularDeps(checkpointable.Checkpointable):
+
+ def __init__(self):
+ self.first_half = resource_variable_ops.ResourceVariable([0., 0.])
+ self.second_half = resource_variable_ops.ResourceVariable([0., 0.])
+
+
+class OnlyOneDep(checkpointable.Checkpointable):
+
+ def __init__(self):
+ self.first_half = resource_variable_ops.ResourceVariable([0., 0.])
+
+
+class SplitTests(test.TestCase):
+
+ @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
+ def testSaveRestoreSplitDep(self):
+ save_checkpoint = checkpointable_utils.Checkpoint(
+ dep=SaveTensorSlicesAsDeps())
+ self.evaluate(save_checkpoint.dep.combined.assign([1., 2., 3., 4.]))
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+ save_path = save_checkpoint.save(checkpoint_prefix)
+
+ regular_deps = HasRegularDeps()
+ regular_restore_checkpoint = checkpointable_utils.Checkpoint(
+ dep=regular_deps)
+ regular_restore_checkpoint.restore(
+ save_path).assert_consumed().run_restore_ops()
+ self.assertAllEqual([1., 2.], self.evaluate(regular_deps.first_half))
+ self.assertAllEqual([3., 4.], self.evaluate(regular_deps.second_half))
+
+ one_dep = OnlyOneDep()
+ one_dep_restore_checkpoint = checkpointable_utils.Checkpoint(dep=one_dep)
+ status = one_dep_restore_checkpoint.restore(save_path)
+ with self.assertRaises(AssertionError):
+ # Missing the second dependency.
+ status.assert_consumed()
+ status.run_restore_ops()
+ self.assertAllEqual([1., 2.], self.evaluate(one_dep.first_half))
+
+ restore_checkpoint = checkpointable_utils.Checkpoint()
+ status = restore_checkpoint.restore(save_path)
+ restore_checkpoint.dep = SaveTensorSlicesAsDeps()
+ status.assert_consumed().run_restore_ops()
+ self.assertAllEqual(
+ [1., 2., 3., 4.],
+ self.evaluate(restore_checkpoint.dep.combined))
+
+
class InterfaceTests(test.TestCase):
@test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
diff --git a/tensorflow/contrib/eager/python/datasets.py b/tensorflow/contrib/eager/python/datasets.py
index 60453006f4..99b1e098d5 100644
--- a/tensorflow/contrib/eager/python/datasets.py
+++ b/tensorflow/contrib/eager/python/datasets.py
@@ -107,16 +107,20 @@ class Iterator(iterator_ops.EagerIterator, checkpointable.CheckpointableBase):
def _next_internal(self):
"""Returns a nested structure of `tf.Tensor`s containing the next element.
"""
- if self._buffer_resource_handle is not None:
- with ops.device(self._device):
- ret = prefetching_ops.function_buffering_resource_get_next(
- function_buffer_resource=self._buffer_resource_handle,
- output_types=self._flat_output_types)
- return sparse.deserialize_sparse_tensors(
- nest.pack_sequence_as(self._output_types, ret), self._output_types,
- self._output_shapes, self._output_classes)
- else:
- return super(Iterator, self)._next_internal()
+ # This runs in sync mode as iterators use an error status to communicate
+ # that there is no more data to iterate over.
+ # TODO(b/77291417): Fix
+ with context.execution_mode(context.SYNC):
+ if self._buffer_resource_handle is not None:
+ with ops.device(self._device):
+ ret = prefetching_ops.function_buffering_resource_get_next(
+ function_buffer_resource=self._buffer_resource_handle,
+ output_types=self._flat_output_types)
+ return sparse.deserialize_sparse_tensors(
+ nest.pack_sequence_as(self._output_types, ret), self._output_types,
+ self._output_shapes, self._output_classes)
+ else:
+ return super(Iterator, self)._next_internal()
# TODO(shivaniagrawal): Expose checkpointable stateful objects from dataset
# attributes(potential).
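The SYNC wrapper above matters because iterators signal end-of-sequence through an error status (`OutOfRangeError`); under asynchronous eager execution that status could surface after the call returns. A minimal sketch of the same pattern, assuming the internal `context` and `errors` modules:

    from tensorflow.python.eager import context
    from tensorflow.python.framework import errors

    def next_or_stop(get_next_fn):
      # Block until the op's status is known so OutOfRangeError is raised
      # here rather than at a later flush point.
      with context.execution_mode(context.SYNC):
        try:
          return get_next_fn()
        except errors.OutOfRangeError:
          raise StopIteration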
diff --git a/tensorflow/contrib/eager/python/examples/linear_regression/BUILD b/tensorflow/contrib/eager/python/examples/linear_regression/BUILD
index f86331af6f..2f6cfdf31e 100644
--- a/tensorflow/contrib/eager/python/examples/linear_regression/BUILD
+++ b/tensorflow/contrib/eager/python/examples/linear_regression/BUILD
@@ -22,6 +22,7 @@ cuda_py_test(
":linear_regression",
"//tensorflow:tensorflow_py",
],
+ tags = ["no_windows"], # TODO: needs investigation on Windows
)
cuda_py_test(
diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD
index 2be62c9438..bec0329ebb 100644
--- a/tensorflow/contrib/estimator/BUILD
+++ b/tensorflow/contrib/estimator/BUILD
@@ -89,6 +89,7 @@ py_test(
"//tensorflow/python/estimator:numpy_io",
"//tensorflow/python/estimator:prediction_keys",
"//tensorflow/python/feature_column",
+ "//tensorflow/python/ops/losses",
"//third_party/py/numpy",
"@six_archive//:six",
],
@@ -129,6 +130,7 @@ py_test(
"//tensorflow/python/estimator:numpy_io",
"//tensorflow/python/estimator:prediction_keys",
"//tensorflow/python/feature_column",
+ "//tensorflow/python/ops/losses",
"//third_party/py/numpy",
"@six_archive//:six",
],
@@ -266,6 +268,7 @@ py_test(
"//tensorflow/python/estimator:numpy_io",
"//tensorflow/python/estimator:prediction_keys",
"//tensorflow/python/feature_column",
+ "//tensorflow/python/ops/losses",
"//third_party/py/numpy",
"@six_archive//:six",
],
diff --git a/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined_test.py b/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined_test.py
index b5e4d34dc7..dd009a6753 100644
--- a/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined_test.py
@@ -34,6 +34,7 @@ from tensorflow.python.estimator.inputs import numpy_io
from tensorflow.python.feature_column import feature_column
from tensorflow.python.framework import ops
from tensorflow.python.ops import nn
+from tensorflow.python.ops.losses import losses
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.summary.writer import writer_cache
@@ -52,7 +53,9 @@ def _dnn_only_estimator_fn(
config=None):
return dnn_linear_combined.DNNLinearCombinedEstimator(
head=head_lib.regression_head(
- weight_column=weight_column, label_dimension=label_dimension),
+ weight_column=weight_column, label_dimension=label_dimension,
+ # Tests in core (from which this test inherits) test the sum loss.
+ loss_reduction=losses.Reduction.SUM),
model_dir=model_dir,
dnn_feature_columns=feature_columns,
dnn_optimizer=optimizer,
@@ -100,7 +103,9 @@ def _linear_only_estimator_fn(
partitioner=None):
return dnn_linear_combined.DNNLinearCombinedEstimator(
head=head_lib.regression_head(
- weight_column=weight_column, label_dimension=label_dimension),
+ weight_column=weight_column, label_dimension=label_dimension,
+ # Tests in core (from which this test inherits) test the sum loss.
+ loss_reduction=losses.Reduction.SUM),
model_dir=model_dir,
linear_feature_columns=feature_columns,
linear_optimizer=optimizer,
diff --git a/tensorflow/contrib/estimator/python/estimator/dnn_test.py b/tensorflow/contrib/estimator/python/estimator/dnn_test.py
index 71f810acec..75e3107670 100644
--- a/tensorflow/contrib/estimator/python/estimator/dnn_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/dnn_test.py
@@ -32,6 +32,7 @@ from tensorflow.python.estimator.export import export
from tensorflow.python.estimator.inputs import numpy_io
from tensorflow.python.feature_column import feature_column
from tensorflow.python.framework import ops
+from tensorflow.python.ops.losses import losses
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.summary.writer import writer_cache
@@ -41,7 +42,9 @@ def _dnn_estimator_fn(weight_column=None, label_dimension=1, *args, **kwargs):
"""Returns a DNNEstimator that uses regression_head."""
return dnn.DNNEstimator(
head=head_lib.regression_head(
- weight_column=weight_column, label_dimension=label_dimension),
+ weight_column=weight_column, label_dimension=label_dimension,
+ # Tests in core (from which this test inherits) test the sum loss.
+ loss_reduction=losses.Reduction.SUM),
*args, **kwargs)
diff --git a/tensorflow/contrib/estimator/python/estimator/head.py b/tensorflow/contrib/estimator/python/estimator/head.py
index 74da2cbb3f..85ef3291ba 100644
--- a/tensorflow/contrib/estimator/python/estimator/head.py
+++ b/tensorflow/contrib/estimator/python/estimator/head.py
@@ -178,7 +178,7 @@ def binary_classification_head(
def regression_head(weight_column=None,
label_dimension=1,
- loss_reduction=losses.Reduction.SUM,
+ loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE,
loss_fn=None,
inverse_link_fn=None,
name=None):
@@ -218,7 +218,9 @@ def regression_head(weight_column=None,
of the last dimension of the labels `Tensor` (typically, this has shape
`[batch_size, label_dimension]`).
loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to
- reduce training loss over batch. Defaults to `SUM`.
+ reduce training loss over batch and label dimension. Defaults to
+ `SUM_OVER_BATCH_SIZE`, namely weighted sum of losses divided by
+ `batch size * label_dimension`. See `tf.losses.Reduction`.
loss_fn: Optional loss function. Defaults to `mean_squared_error`.
inverse_link_fn: Optional inverse link function, also known as 'mean
function'. Defaults to identity.
@@ -243,7 +245,7 @@ def regression_head(weight_column=None,
def poisson_regression_head(
weight_column=None,
label_dimension=1,
- loss_reduction=losses.Reduction.SUM,
+ loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE,
compute_full_loss=True,
name=None):
"""Creates a `_Head` for poisson regression using `tf.nn.log_poisson_loss`.
@@ -275,7 +277,9 @@ def poisson_regression_head(
of the last dimension of the labels `Tensor` (typically, this has shape
`[batch_size, label_dimension]`).
loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to
- reduce training loss over batch. Defaults to `SUM`.
+ reduce training loss over batch and label dimension. Defaults to
+ `SUM_OVER_BATCH_SIZE`, namely weighted sum of losses divided by
+ `batch size * label_dimension`. See `tf.losses.Reduction`.
compute_full_loss: Whether to include the constant `log(z!)` term in
computing the poisson loss. See `tf.nn.log_poisson_loss` for the full
documentation.
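To make the new default concrete, a small NumPy illustration with invented numbers: `SUM` adds all per-element losses, while `SUM_OVER_BATCH_SIZE` divides that sum by `batch_size * label_dimension`, the number of loss elements.

    import numpy as np

    per_element_loss = np.array([[1., 2., 3.],
                                 [4., 5., 6.]])  # batch_size=2, label_dimension=3
    sum_loss = per_element_loss.sum()             # 21.0 (the old SUM default)
    mean_loss = sum_loss / per_element_loss.size  # 3.5 (SUM_OVER_BATCH_SIZE)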
diff --git a/tensorflow/contrib/estimator/python/estimator/head_test.py b/tensorflow/contrib/estimator/python/estimator/head_test.py
index 8837dfdc6c..98962ca427 100644
--- a/tensorflow/contrib/estimator/python/estimator/head_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/head_test.py
@@ -1162,8 +1162,8 @@ class PoissonRegressionHead(test.TestCase):
# exp(-1) - 2 * (-1) + 2*ln(2) - 2 + 0.5*ln(2*pi*2),
# exp(1) - 3 * 1 + 3*ln(3) - 3 + 0.5*ln(2*pi*3)]
# = [1.0, 3.020, 1.482]
- # sum_loss = 5.502
- expected_loss = 5.502
+ # training_loss = (1.0 + 3.020 + 1.482) / 3
+ expected_loss = 1.834
atol = 0.001
expected_train_result = b'my_train_op'
def _train_op_fn(loss):
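The updated expectation is just the mean of the three per-example losses quoted in the comment, which a quick check confirms:

    per_example = [1.0, 3.020, 1.482]
    expected_loss = sum(per_example) / len(per_example)  # 5.502 / 3 = 1.834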
diff --git a/tensorflow/contrib/estimator/python/estimator/linear_test.py b/tensorflow/contrib/estimator/python/estimator/linear_test.py
index c63514eb68..c41996b9c6 100644
--- a/tensorflow/contrib/estimator/python/estimator/linear_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/linear_test.py
@@ -32,6 +32,7 @@ from tensorflow.python.estimator.export import export
from tensorflow.python.estimator.inputs import numpy_io
from tensorflow.python.feature_column import feature_column
from tensorflow.python.framework import ops
+from tensorflow.python.ops.losses import losses
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.summary.writer import writer_cache
@@ -42,7 +43,9 @@ def _linear_estimator_fn(
"""Returns a LinearEstimator that uses regression_head."""
return linear.LinearEstimator(
head=head_lib.regression_head(
- weight_column=weight_column, label_dimension=label_dimension),
+ weight_column=weight_column, label_dimension=label_dimension,
+ # Tests in core (from which this test inherits) test the sum loss.
+ loss_reduction=losses.Reduction.SUM),
*args, **kwargs)
diff --git a/tensorflow/contrib/estimator/python/estimator/multi_head_test.py b/tensorflow/contrib/estimator/python/estimator/multi_head_test.py
index 74d3d6d728..d9e5aca295 100644
--- a/tensorflow/contrib/estimator/python/estimator/multi_head_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/multi_head_test.py
@@ -483,14 +483,14 @@ class MultiHeadTest(test.TestCase):
[[2., 2., 0.], [2., 2., 0.]]], dtype=np.float32),
}
# Loss for the first head:
- # loss1 = (1+1)^2 + (0-1)^2 + (1+1)^2 + (0-1)^2 +
- # (1.5+1.5)^2 + (1.5-1.5)^2 + (1.5+1.5)^2 + (1.5-1.5)^2
- # = 28
+ # loss1 = ((1+1)^2 + (0-1)^2 + (1+1)^2 + (0-1)^2 +
+ # (1.5+1.5)^2 + (1.5-1.5)^2 + (1.5+1.5)^2 + (1.5-1.5)^2) / 8
+ # = 3.5
# Loss for the second head:
- # loss2 = (0-2)^2 + (1+2)^2 + (0-2)^2 + (0-2)^2 + (1+2)^2 + (0-2)^2 +
- # (2+2)^2 + (2-2)^2 + (0+2)^2 + (2+2)^2 + (2-2)^2 + (0+2)^2
- # = 74
- expected_training_loss = 28. + 74.
+ # loss2 = ((0-2)^2 + (1+2)^2 + (0-2)^2 + (0-2)^2 + (1+2)^2 + (0-2)^2 +
+ # (2+2)^2 + (2-2)^2 + (0+2)^2 + (2+2)^2 + (2-2)^2 + (0+2)^2) / 12
+ # = 6.167
+ expected_training_loss = 3.5 + 6.167
training_loss = multi_head.create_loss(
features={},
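Likewise, the revised per-head expectations above are the means of the squared terms from the comments, which can be checked directly:

    head1_terms = [4., 1., 4., 1., 9., 0., 9., 0.]                    # sum: 28
    head2_terms = [4., 9., 4., 4., 9., 4., 16., 0., 4., 16., 0., 4.]  # sum: 74
    loss1 = sum(head1_terms) / len(head1_terms)  # 28 / 8  = 3.5
    loss2 = sum(head2_terms) / len(head2_terms)  # 74 / 12 = 6.167 (approx.)
    expected_training_loss = loss1 + loss2       # approx. 9.667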
diff --git a/tensorflow/contrib/framework/python/ops/arg_scope_test.py b/tensorflow/contrib/framework/python/ops/arg_scope_test.py
index 7ba9d4ffa9..4c3879d4fc 100644
--- a/tensorflow/contrib/framework/python/ops/arg_scope_test.py
+++ b/tensorflow/contrib/framework/python/ops/arg_scope_test.py
@@ -170,6 +170,30 @@ class ArgScopeTest(test.TestCase):
self.assertTupleEqual(args, func1_args)
self.assertDictEqual(kwargs, func1_kwargs)
+ def testNestedArgScopeObjectCreatedOutsideScopeOverridesArgScope(self):
+
+ def get_scope_object():
+ with arg_scope([func1], a=1, b=None, c=[1]) as sc:
+ return sc
+
+ scope_object = get_scope_object()
+ with arg_scope([func1], b=2, d=10):
+ with arg_scope(scope_object):
+ args, kwargs = func1(0)
+ self.assertTupleEqual(args, (0,))
+ self.assertDictEqual(kwargs, {'a': 1, 'b': None, 'c': [1]})
+
+ def testArgScopeObjectCreatedWithinScopeInheritsArgScope(self):
+ def get_scope_object():
+ with arg_scope([func1], a=1, b=None, c=[1]) as sc:
+ return sc
+
+ with arg_scope([func1], b=2, d=10):
+ with arg_scope(get_scope_object()):
+ args, kwargs = func1(0)
+ self.assertTupleEqual(args, (0,))
+ self.assertDictEqual(kwargs, {'a': 1, 'b': None, 'c': [1], 'd': 10})
+
def testSharedArgScope(self):
func1_args = (0,)
func1_kwargs = {'a': 1, 'b': None, 'c': [1]}
diff --git a/tensorflow/contrib/gan/BUILD b/tensorflow/contrib/gan/BUILD
index 9e56d3c039..461066bbb4 100644
--- a/tensorflow/contrib/gan/BUILD
+++ b/tensorflow/contrib/gan/BUILD
@@ -354,6 +354,7 @@ py_test(
name = "classifier_metrics_test",
srcs = ["python/eval/python/classifier_metrics_test.py"],
srcs_version = "PY2AND3",
+ tags = ["no_windows"], # TODO: needs investigation on Windows
deps = [
":classifier_metrics",
"//tensorflow/core:protos_all_py",
diff --git a/tensorflow/contrib/gan/python/eval/python/summaries_impl.py b/tensorflow/contrib/gan/python/eval/python/summaries_impl.py
index 0d1afad72d..508f487722 100644
--- a/tensorflow/contrib/gan/python/eval/python/summaries_impl.py
+++ b/tensorflow/contrib/gan/python/eval/python/summaries_impl.py
@@ -31,6 +31,7 @@ __all__ = [
'add_image_comparison_summaries',
'add_gan_model_summaries',
'add_regularization_loss_summaries',
+ 'add_cyclegan_image_summaries',
]
@@ -51,14 +52,9 @@ def add_gan_model_image_summaries(gan_model, grid_size=4, model_summaries=True):
ValueError: If real and generated data aren't images.
"""
if isinstance(gan_model, namedtuples.CycleGANModel):
- saved_params = locals()
- saved_params.pop('gan_model', None)
- with ops.name_scope('cyclegan_x2y_image_summaries'):
- add_gan_model_image_summaries(gan_model.model_x2y, **saved_params)
- with ops.name_scope('cyclegan_y2x_image_summaries'):
- add_gan_model_image_summaries(gan_model.model_y2x, **saved_params)
- return
-
+ raise ValueError(
+ '`add_gan_model_image_summaries` does not take CycleGANModels. Please '
+ 'use `add_cyclegan_image_summaries` instead.')
_assert_is_image(gan_model.real_data)
_assert_is_image(gan_model.generated_data)
@@ -89,6 +85,49 @@ def add_gan_model_image_summaries(gan_model, grid_size=4, model_summaries=True):
add_gan_model_summaries(gan_model)
+def add_cyclegan_image_summaries(cyclegan_model):
+ """Adds image summaries for CycleGAN.
+
+ There are two summaries, one for each generator. The first image is the
+ generator input, the second is the generator output, and the third is G(F(x)).
+
+ Args:
+ cyclegan_model: A CycleGANModel tuple.
+
+ Raises:
+ ValueError: If `cyclegan_model` isn't a CycleGANModel.
+ ValueError: If generated data, generator inputs, and reconstructions aren't
+ images.
+ ValueError: If the generator input, generated data, and reconstructions
+ aren't all the same size.
+ """
+ if not isinstance(cyclegan_model, namedtuples.CycleGANModel):
+ raise ValueError('`cyclegan_model` was not a CycleGANModel. Instead, was '
+ '%s' % type(cyclegan_model))
+
+ _assert_is_image(cyclegan_model.model_x2y.generator_inputs)
+ _assert_is_image(cyclegan_model.model_x2y.generated_data)
+ _assert_is_image(cyclegan_model.reconstructed_x)
+ _assert_is_image(cyclegan_model.model_y2x.generator_inputs)
+ _assert_is_image(cyclegan_model.model_y2x.generated_data)
+ _assert_is_image(cyclegan_model.reconstructed_y)
+
+ def _add_comparison_summary(gan_model, reconstructions):
+ image_list = (array_ops.unstack(gan_model.generator_inputs[:1]) +
+ array_ops.unstack(gan_model.generated_data[:1]) +
+ array_ops.unstack(reconstructions[:1]))
+ summary.image(
+ 'image_comparison', eval_utils.image_reshaper(
+ image_list, num_cols=len(image_list)), max_outputs=1)
+
+ with ops.name_scope('x2y_image_comparison_summaries'):
+ _add_comparison_summary(
+ cyclegan_model.model_x2y, cyclegan_model.reconstructed_x)
+ with ops.name_scope('y2x_image_comparison_summaries'):
+ _add_comparison_summary(
+ cyclegan_model.model_y2x, cyclegan_model.reconstructed_y)
+
+
def add_image_comparison_summaries(gan_model, num_comparisons=2,
display_diffs=False):
"""Adds image summaries to compare triplets of images.
@@ -109,15 +148,6 @@ def add_image_comparison_summaries(gan_model, num_comparisons=2,
ValueError: If the generator input, real, and generated data aren't all the
same size.
"""
- if isinstance(gan_model, namedtuples.CycleGANModel):
- saved_params = locals()
- saved_params.pop('gan_model', None)
- with ops.name_scope('cyclegan_x2y_image_comparison_summaries'):
- add_image_comparison_summaries(gan_model.model_x2y, **saved_params)
- with ops.name_scope('cyclegan_y2x_image_comparison_summaries'):
- add_image_comparison_summaries(gan_model.model_y2x, **saved_params)
- return
-
_assert_is_image(gan_model.generator_inputs)
_assert_is_image(gan_model.generated_data)
_assert_is_image(gan_model.real_data)
diff --git a/tensorflow/contrib/gan/python/eval/python/summaries_test.py b/tensorflow/contrib/gan/python/eval/python/summaries_test.py
index 45eb108586..33d51bfc21 100644
--- a/tensorflow/contrib/gan/python/eval/python/summaries_test.py
+++ b/tensorflow/contrib/gan/python/eval/python/summaries_test.py
@@ -65,15 +65,14 @@ def get_cyclegan_model():
return namedtuples.CycleGANModel(
model_x2y=model_x2y,
model_y2x=model_y2x,
- reconstructed_x=array_ops.zeros([3, 30, 35, 6]),
- reconstructed_y=array_ops.zeros([3, 30, 35, 6]))
+ reconstructed_x=array_ops.zeros([4, 32, 32, 3]),
+ reconstructed_y=array_ops.zeros([4, 32, 32, 3]))
class SummariesTest(test.TestCase):
- def _test_add_gan_model_image_summaries_impl(self, get_model_fn,
- expected_num_summary_ops,
- model_summaries):
+ def _test_add_gan_model_image_summaries_impl(
+ self, get_model_fn, expected_num_summary_ops, model_summaries):
summaries.add_gan_model_image_summaries(get_model_fn(), grid_size=2,
model_summaries=model_summaries)
@@ -89,8 +88,9 @@ class SummariesTest(test.TestCase):
def test_add_gan_model_image_summaries_no_model(self):
self._test_add_gan_model_image_summaries_impl(get_gan_model, 2, False)
- def test_add_gan_model_image_summaries_for_cyclegan(self):
- self._test_add_gan_model_image_summaries_impl(get_cyclegan_model, 10, True)
+ def test_cyclegan_image_summaries_dont_work(self):
+ with self.assertRaises(ValueError):
+ summaries.add_gan_model_image_summaries(get_cyclegan_model())
def _test_add_gan_model_summaries_impl(self, get_model_fn,
expected_num_summary_ops):
@@ -137,7 +137,11 @@ class SummariesTest(test.TestCase):
self._test_add_image_comparison_summaries_impl(get_gan_model, 1)
def test_add_image_comparison_summaries_for_cyclegan(self):
- self._test_add_image_comparison_summaries_impl(get_cyclegan_model, 2)
+ summaries.add_cyclegan_image_summaries(get_cyclegan_model())
+
+    self.assertEqual(2, len(ops.get_collection(ops.GraphKeys.SUMMARIES)))
+ with self.test_session(use_gpu=True):
+ summary.merge_all().eval()
if __name__ == '__main__':
diff --git a/tensorflow/contrib/kfac/examples/BUILD b/tensorflow/contrib/kfac/examples/BUILD
index 7dd40c19c5..8186fa1c62 100644
--- a/tensorflow/contrib/kfac/examples/BUILD
+++ b/tensorflow/contrib/kfac/examples/BUILD
@@ -28,8 +28,28 @@ py_library(
)
py_binary(
- name = "convnet_mnist_main",
- srcs = ["convnet_mnist_main.py"],
+ name = "convnet_mnist_single_main",
+ srcs = ["convnet_mnist_single_main.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":convnet",
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
+py_binary(
+ name = "convnet_mnist_multi_tower_main",
+ srcs = ["convnet_mnist_multi_tower_main.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":convnet",
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
+py_binary(
+ name = "convnet_mnist_distributed_main",
+ srcs = ["convnet_mnist_distributed_main.py"],
srcs_version = "PY2AND3",
deps = [
":convnet",
diff --git a/tensorflow/contrib/kfac/examples/convnet.py b/tensorflow/contrib/kfac/examples/convnet.py
index 39d80addaa..e8e3353091 100644
--- a/tensorflow/contrib/kfac/examples/convnet.py
+++ b/tensorflow/contrib/kfac/examples/convnet.py
@@ -37,6 +37,8 @@ import tensorflow as tf
from tensorflow.contrib.kfac.examples import mlp
from tensorflow.contrib.kfac.examples import mnist
+from tensorflow.contrib.kfac.python.ops import optimizer as opt
+
lc = tf.contrib.kfac.layer_collection
oq = tf.contrib.kfac.op_queue
@@ -48,12 +50,18 @@ __all__ = [
"linear_layer",
"build_model",
"minimize_loss_single_machine",
- "minimize_loss_distributed",
+ "distributed_grads_only_and_ops_chief_worker",
+ "distributed_grads_and_ops_dedicated_workers",
"train_mnist_single_machine",
- "train_mnist_distributed",
+ "train_mnist_distributed_sync_replicas",
+ "train_mnist_multitower"
]
+# Inverse update ops will be run every _INVERT_EVERY iterations.
+_INVERT_EVERY = 10
+
+
def conv_layer(layer_id, inputs, kernel_size, out_channels):
"""Builds a convolutional layer with ReLU non-linearity.
@@ -161,8 +169,9 @@ def build_model(examples, labels, num_labels, layer_collection):
accuracy = tf.reduce_mean(
tf.cast(tf.equal(labels, tf.argmax(logits, axis=1)), dtype=tf.float32))
- tf.summary.scalar("loss", loss)
- tf.summary.scalar("accuracy", accuracy)
+ with tf.device("/cpu:0"):
+ tf.summary.scalar("loss", loss)
+ tf.summary.scalar("accuracy", accuracy)
# Register parameters. K-FAC needs to know about the inputs, outputs, and
# parameters of each conv/fully connected layer and the logits powering the
@@ -181,41 +190,59 @@ def build_model(examples, labels, num_labels, layer_collection):
def minimize_loss_single_machine(loss,
accuracy,
layer_collection,
+ device="/gpu:0",
session_config=None):
"""Minimize loss with K-FAC on a single machine.
- A single Session is responsible for running all of K-FAC's ops.
+ A single Session is responsible for running all of K-FAC's ops. The covariance
+ and inverse update ops are placed on `device`. All model variables are on CPU.
Args:
loss: 0-D Tensor. Loss to be minimized.
accuracy: 0-D Tensor. Accuracy of classifier on current minibatch.
layer_collection: LayerCollection instance describing model architecture.
Used by K-FAC to construct preconditioner.
+    device: string. Either '/cpu:0' or '/gpu:0'. The covariance and inverse
+      update ops are run on this device.
session_config: None or tf.ConfigProto. Configuration for tf.Session().
Returns:
final value for 'accuracy'.
"""
# Train with K-FAC.
- global_step = tf.train.get_or_create_global_step()
+ g_step = tf.train.get_or_create_global_step()
optimizer = opt.KfacOptimizer(
learning_rate=0.0001,
cov_ema_decay=0.95,
damping=0.001,
layer_collection=layer_collection,
+ placement_strategy="round_robin",
+ cov_devices=[device],
+ inv_devices=[device],
momentum=0.9)
- train_op = optimizer.minimize(loss, global_step=global_step)
+ (cov_update_thunks,
+ inv_update_thunks) = optimizer.make_vars_and_create_op_thunks()
+
+ with tf.device(device):
+ train_op = optimizer.minimize(loss, global_step=g_step)
+
+ def make_update_op(update_thunks):
+ update_op = [thunk() for thunk in update_thunks]
+ return tf.group(*update_op)
+
+ cov_update_op = make_update_op(cov_update_thunks)
+ with tf.control_dependencies([train_op, cov_update_op]):
+ inverse_op = tf.cond(
+ tf.equal(tf.mod(g_step + 1, _INVERT_EVERY), 0),
+ lambda: make_update_op(inv_update_thunks), tf.no_op)
tf.logging.info("Starting training.")
with tf.train.MonitoredTrainingSession(config=session_config) as sess:
while not sess.should_stop():
- global_step_, loss_, accuracy_, _, _ = sess.run(
- [global_step, loss, accuracy, train_op, optimizer.cov_update_op])
-
- if global_step_ % 100 == 0:
- sess.run(optimizer.inv_update_op)
+ global_step_, loss_, accuracy_, _ = sess.run(
+ [g_step, loss, accuracy, inverse_op])
- if global_step_ % 100 == 0:
+ if (global_step_ + 1) % _INVERT_EVERY == 0:
tf.logging.info("global_step: %d | loss: %f | accuracy: %s",
global_step_, loss_, accuracy_)
@@ -250,16 +277,62 @@ def _num_gradient_tasks(num_tasks):
return int(np.ceil(0.6 * num_tasks))
-def minimize_loss_distributed(task_id, num_worker_tasks, num_ps_tasks, master,
- checkpoint_dir, loss, accuracy, layer_collection):
- """Minimize loss with an synchronous implementation of K-FAC.
+def _make_distributed_train_op(
+ task_id,
+ num_worker_tasks,
+ num_ps_tasks,
+ layer_collection
+):
+ """Creates optimizer and distributed training op.
- Different tasks are responsible for different parts of K-FAC's Ops. The first
- 60% of tasks update weights; the next 20% accumulate covariance statistics;
- the last 20% invert the matrices used to precondition gradients.
+  Constructs a KFAC optimizer and wraps it in a `sync_replicas` optimizer;
+  the caller builds the train op from the returned objects.
+
+ Args:
+ task_id: int. Integer in [0, num_worker_tasks). ID for this worker.
+ num_worker_tasks: int. Number of workers in this distributed training setup.
+ num_ps_tasks: int. Number of parameter servers holding variables. If 0,
+ parameter servers are not used.
+ layer_collection: LayerCollection instance describing model architecture.
+ Used by K-FAC to construct preconditioner.
+
+ Returns:
+ sync_optimizer: `tf.train.SyncReplicasOptimizer` instance which wraps KFAC
+ optimizer.
+ optimizer: Instance of `opt.KfacOptimizer`.
+    global_step: `Tensor`. The global step.
+ """
+ tf.logging.info("Task id : %d", task_id)
+ with tf.device(tf.train.replica_device_setter(num_ps_tasks)):
+ global_step = tf.train.get_or_create_global_step()
+ optimizer = opt.KfacOptimizer(
+ learning_rate=0.0001,
+ cov_ema_decay=0.95,
+ damping=0.001,
+ layer_collection=layer_collection,
+ momentum=0.9)
+ sync_optimizer = tf.train.SyncReplicasOptimizer(
+ opt=optimizer,
+ replicas_to_aggregate=_num_gradient_tasks(num_worker_tasks),
+ total_num_replicas=num_worker_tasks)
+ return sync_optimizer, optimizer, global_step
+
+
+def distributed_grads_only_and_ops_chief_worker(
+ task_id, is_chief, num_worker_tasks, num_ps_tasks, master, checkpoint_dir,
+ loss, accuracy, layer_collection, invert_every=10):
+ """Minimize loss with a synchronous implementation of K-FAC.
+
+  All workers perform gradient computation. The chief worker applies the
+  gradient after averaging the gradients obtained from all the workers. All
+  workers block execution until the update is applied. The chief worker runs
+  the covariance and inverse update ops. Covariance and inverse matrices are
+  placed on parameter servers in a round-robin manner. For further details on
+  synchronous distributed optimization, see `tf.train.SyncReplicasOptimizer`.
Args:
task_id: int. Integer in [0, num_worker_tasks). ID for this worker.
+ is_chief: `boolean`, `True` if the worker is chief worker.
num_worker_tasks: int. Number of workers in this distributed training setup.
num_ps_tasks: int. Number of parameter servers holding variables. If 0,
parameter servers are not used.
@@ -271,6 +344,7 @@ def minimize_loss_distributed(task_id, num_worker_tasks, num_ps_tasks, master,
run with each step.
layer_collection: LayerCollection instance describing model architecture.
Used by K-FAC to construct preconditioner.
+    invert_every: `int`, Number of steps between inverse updates.
Returns:
final value for 'accuracy'.
@@ -278,19 +352,80 @@ def minimize_loss_distributed(task_id, num_worker_tasks, num_ps_tasks, master,
Raises:
ValueError: if task_id >= num_worker_tasks.
"""
- with tf.device(tf.train.replica_device_setter(num_ps_tasks)):
- global_step = tf.train.get_or_create_global_step()
- optimizer = opt.KfacOptimizer(
- learning_rate=0.0001,
- cov_ema_decay=0.95,
- damping=0.001,
- layer_collection=layer_collection,
- momentum=0.9)
- inv_update_queue = oq.OpQueue(optimizer.inv_update_ops)
- sync_optimizer = tf.train.SyncReplicasOptimizer(
- opt=optimizer,
- replicas_to_aggregate=_num_gradient_tasks(num_worker_tasks))
- train_op = sync_optimizer.minimize(loss, global_step=global_step)
+
+ sync_optimizer, optimizer, global_step = _make_distributed_train_op(
+ task_id, num_worker_tasks, num_ps_tasks, layer_collection)
+ (cov_update_thunks,
+ inv_update_thunks) = optimizer.make_vars_and_create_op_thunks()
+ train_op = sync_optimizer.minimize(loss, global_step=global_step)
+
+ tf.logging.info("Starting training.")
+ hooks = [sync_optimizer.make_session_run_hook(is_chief)]
+
+ def make_update_op(update_thunks):
+ update_op = [thunk() for thunk in update_thunks]
+ return tf.group(*update_op)
+
+ if is_chief:
+ cov_update_op = make_update_op(cov_update_thunks)
+ with tf.control_dependencies([train_op, cov_update_op]):
+ update_op = tf.cond(
+ tf.equal(tf.mod(global_step + 1, invert_every), 0),
+ lambda: make_update_op(inv_update_thunks),
+ tf.no_op)
+ else:
+ update_op = train_op
+
+ with tf.train.MonitoredTrainingSession(
+ master=master,
+ is_chief=is_chief,
+ checkpoint_dir=checkpoint_dir,
+ hooks=hooks,
+ stop_grace_period_secs=0) as sess:
+ while not sess.should_stop():
+ global_step_, loss_, accuracy_, _ = sess.run(
+ [global_step, loss, accuracy, update_op])
+ tf.logging.info("global_step: %d | loss: %f | accuracy: %s", global_step_,
+ loss_, accuracy_)
+ return accuracy_
+
+
+def distributed_grads_and_ops_dedicated_workers(
+ task_id, is_chief, num_worker_tasks, num_ps_tasks, master, checkpoint_dir,
+ loss, accuracy, layer_collection):
+ """Minimize loss with a synchronous implementation of K-FAC.
+
+ Different workers are responsible for different parts of K-FAC's Ops. The
+ first 60% of tasks compute gradients; the next 20% accumulate covariance
+ statistics; the last 20% invert the matrices used to precondition gradients.
+  The chief worker applies the gradient.
+
+ Args:
+ task_id: int. Integer in [0, num_worker_tasks). ID for this worker.
+ is_chief: `boolean`, `True` if the worker is chief worker.
+ num_worker_tasks: int. Number of workers in this distributed training setup.
+ num_ps_tasks: int. Number of parameter servers holding variables. If 0,
+ parameter servers are not used.
+ master: string. IP and port of TensorFlow runtime process. Set to empty
+ string to run locally.
+ checkpoint_dir: string or None. Path to store checkpoints under.
+ loss: 0-D Tensor. Loss to be minimized.
+ accuracy: dict mapping strings to 0-D Tensors. Additional accuracy to
+ run with each step.
+ layer_collection: LayerCollection instance describing model architecture.
+ Used by K-FAC to construct preconditioner.
+
+ Returns:
+ final value for 'accuracy'.
+
+ Raises:
+ ValueError: if task_id >= num_worker_tasks.
+ """
+ sync_optimizer, optimizer, global_step = _make_distributed_train_op(
+ task_id, num_worker_tasks, num_ps_tasks, layer_collection)
+ _, cov_update_op, inv_update_ops, _, _, _ = optimizer.make_ops_and_vars()
+ train_op = sync_optimizer.minimize(loss, global_step=global_step)
+ inv_update_queue = oq.OpQueue(inv_update_ops)
tf.logging.info("Starting training.")
is_chief = (task_id == 0)
@@ -306,7 +441,7 @@ def minimize_loss_distributed(task_id, num_worker_tasks, num_ps_tasks, master,
if _is_gradient_task(task_id, num_worker_tasks):
learning_op = train_op
elif _is_cov_update_task(task_id, num_worker_tasks):
- learning_op = optimizer.cov_update_op
+ learning_op = cov_update_op
elif _is_inv_update_task(task_id, num_worker_tasks):
# TODO(duckworthd): Running this op before cov_update_op has been run a
# few times can result in "InvalidArgumentError: Cholesky decomposition
@@ -324,13 +459,18 @@ def minimize_loss_distributed(task_id, num_worker_tasks, num_ps_tasks, master,
return accuracy_
-def train_mnist_single_machine(data_dir, num_epochs, use_fake_data=False):
+def train_mnist_single_machine(data_dir,
+ num_epochs,
+ use_fake_data=False,
+ device="/gpu:0"):
"""Train a ConvNet on MNIST.
Args:
data_dir: string. Directory to read MNIST examples from.
num_epochs: int. Number of passes to make over the training set.
use_fake_data: bool. If True, generate a synthetic dataset.
+    device: string. Either '/cpu:0' or '/gpu:0'. The covariance and inverse
+      update ops are run on this device.
Returns:
accuracy of model on the final minibatch of training data.
@@ -350,22 +490,38 @@ def train_mnist_single_machine(data_dir, num_epochs, use_fake_data=False):
examples, labels, num_labels=10, layer_collection=layer_collection)
# Fit model.
- return minimize_loss_single_machine(loss, accuracy, layer_collection)
+ return minimize_loss_single_machine(
+ loss, accuracy, layer_collection, device=device)
def train_mnist_multitower(data_dir, num_epochs, num_towers,
- use_fake_data=True):
+ use_fake_data=True, devices=None):
"""Train a ConvNet on MNIST.
+  Training data is split equally among the towers. Each tower computes its
+  loss on its own batch of data, and the losses are aggregated on the CPU. The
+  model variables are placed on the first tower. The covariance and inverse
+  update ops and variables are placed on GPUs in a round-robin manner.
+
Args:
data_dir: string. Directory to read MNIST examples from.
num_epochs: int. Number of passes to make over the training set.
num_towers: int. Number of CPUs to split inference across.
use_fake_data: bool. If True, generate a synthetic dataset.
+    devices: list of strings. CPU or GPU device names on which the covariance
+      and inverse update ops are run.
Returns:
accuracy of model on the final minibatch of training data.
"""
+ if devices:
+ device_count = {"GPU": num_towers}
+ else:
+ device_count = {"CPU": num_towers}
+
+ devices = devices or [
+ "/cpu:{}".format(tower_id) for tower_id in range(num_towers)
+ ]
# Load a dataset.
tf.logging.info("Loading MNIST into memory.")
tower_batch_size = 128
@@ -388,7 +544,7 @@ def train_mnist_multitower(data_dir, num_epochs, num_towers,
layer_collection = lc.LayerCollection()
tower_results = []
for tower_id in range(num_towers):
- with tf.device("/cpu:%d" % tower_id):
+ with tf.device(devices[tower_id]):
with tf.name_scope("tower%d" % tower_id):
with tf.variable_scope(tf.get_variable_scope(), reuse=(tower_id > 0)):
tf.logging.info("Building tower %d." % tower_id)
@@ -402,34 +558,79 @@ def train_mnist_multitower(data_dir, num_epochs, num_towers,
accuracy = tf.reduce_mean(accuracies)
# Fit model.
+
session_config = tf.ConfigProto(
- allow_soft_placement=False, device_count={
- "CPU": num_towers
- })
- return minimize_loss_single_machine(
- loss, accuracy, layer_collection, session_config=session_config)
+ allow_soft_placement=False,
+ device_count=device_count,
+ )
+
+ g_step = tf.train.get_or_create_global_step()
+ optimizer = opt.KfacOptimizer(
+ learning_rate=0.0001,
+ cov_ema_decay=0.95,
+ damping=0.001,
+ layer_collection=layer_collection,
+ placement_strategy="round_robin",
+ cov_devices=devices,
+ inv_devices=devices,
+ momentum=0.9)
+ (cov_update_thunks,
+ inv_update_thunks) = optimizer.make_vars_and_create_op_thunks()
+ train_op = optimizer.minimize(loss, global_step=g_step)
-def train_mnist_distributed(task_id,
- num_worker_tasks,
- num_ps_tasks,
- master,
- data_dir,
- num_epochs,
- use_fake_data=False):
- """Train a ConvNet on MNIST.
+ def make_update_op(update_thunks):
+ update_op = [thunk() for thunk in update_thunks]
+ return tf.group(*update_op)
+
+ cov_update_op = make_update_op(cov_update_thunks)
+ with tf.control_dependencies([train_op, cov_update_op]):
+ inverse_op = tf.cond(
+ tf.equal(tf.mod(g_step + 1, _INVERT_EVERY), 0),
+ lambda: make_update_op(inv_update_thunks), tf.no_op)
+
+ tf.logging.info("Starting training.")
+ with tf.train.MonitoredTrainingSession(config=session_config) as sess:
+ while not sess.should_stop():
+ global_step_, loss_, accuracy_, _ = sess.run(
+ [g_step, loss, accuracy, inverse_op])
+
+ if (global_step_ + 1) % _INVERT_EVERY == 0:
+ tf.logging.info("global_step: %d | loss: %f | accuracy: %s",
+ global_step_, loss_, accuracy_)
+
+
+def train_mnist_distributed_sync_replicas(task_id,
+ is_chief,
+ num_worker_tasks,
+ num_ps_tasks,
+ master,
+ data_dir,
+ num_epochs,
+ op_strategy,
+ use_fake_data=False):
+ """Train a ConvNet on MNIST using Sync replicas optimizer.
Args:
task_id: int. Integer in [0, num_worker_tasks). ID for this worker.
+ is_chief: `boolean`, `True` if the worker is chief worker.
num_worker_tasks: int. Number of workers in this distributed training setup.
num_ps_tasks: int. Number of parameter servers holding variables.
master: string. IP and port of TensorFlow runtime process.
data_dir: string. Directory to read MNIST examples from.
num_epochs: int. Number of passes to make over the training set.
+    op_strategy: `string`, Strategy for running the covariance and inverse
+      ops. If op_strategy == `chief_worker`, the covariance and inverse
+      update ops are run on the chief worker; otherwise they are run on
+      dedicated workers.
+
use_fake_data: bool. If True, generate a synthetic dataset.
Returns:
accuracy of model on the final minibatch of training data.
+
+ Raises:
+ ValueError: If `op_strategy` not in ["chief_worker", "dedicated_workers"].
"""
# Load a dataset.
tf.logging.info("Loading MNIST into memory.")
@@ -448,9 +649,17 @@ def train_mnist_distributed(task_id,
# Fit model.
checkpoint_dir = None if data_dir is None else os.path.join(data_dir, "kfac")
- return minimize_loss_distributed(task_id, num_worker_tasks, num_ps_tasks,
- master, checkpoint_dir, loss, accuracy,
- layer_collection)
+ if op_strategy == "chief_worker":
+ return distributed_grads_only_and_ops_chief_worker(
+ task_id, is_chief, num_worker_tasks, num_ps_tasks, master,
+ checkpoint_dir, loss, accuracy, layer_collection)
+ elif op_strategy == "dedicated_workers":
+ return distributed_grads_and_ops_dedicated_workers(
+ task_id, is_chief, num_worker_tasks, num_ps_tasks, master,
+ checkpoint_dir, loss, accuracy, layer_collection)
+ else:
+ raise ValueError("Only supported op strategies are : {}, {}".format(
+ "chief_worker", "dedicated_workers"))
if __name__ == "__main__":
diff --git a/tensorflow/contrib/kfac/examples/convnet_mnist_distributed_main.py b/tensorflow/contrib/kfac/examples/convnet_mnist_distributed_main.py
new file mode 100644
index 0000000000..b4c2d4a9e9
--- /dev/null
+++ b/tensorflow/contrib/kfac/examples/convnet_mnist_distributed_main.py
@@ -0,0 +1,62 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+r"""Train a ConvNet on MNIST using K-FAC.
+
+Distributed training with sync replicas optimizer. See
+`convnet.train_mnist_distributed_sync_replicas` for details.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+from absl import flags
+import tensorflow as tf
+
+from tensorflow.contrib.kfac.examples import convnet
+
+FLAGS = flags.FLAGS
+flags.DEFINE_integer("task", -1, "Task identifier")
+flags.DEFINE_string("data_dir", "/tmp/mnist", "local mnist dir")
+flags.DEFINE_string(
+ "cov_inv_op_strategy", "chief_worker",
+ "In dist training mode run the cov, inv ops on chief or dedicated workers."
+)
+flags.DEFINE_string("master", "local", "Session master.")
+flags.DEFINE_integer("ps_tasks", 2,
+ "Number of tasks in the parameter server job.")
+flags.DEFINE_integer("replicas_to_aggregate", 5,
+ "Number of replicas to aggregate.")
+flags.DEFINE_integer("worker_replicas", 5, "Number of replicas in worker job.")
+flags.DEFINE_integer("num_epochs", None, "Number of epochs.")
+
+
+def _is_chief():
+ """Determines whether a job is the chief worker."""
+ if "chief_worker" in FLAGS.brain_jobs:
+ return FLAGS.brain_job_name == "chief_worker"
+ else:
+ return FLAGS.task == 0
+
+
+def main(unused_argv):
+ _ = unused_argv
+ convnet.train_mnist_distributed_sync_replicas(
+ FLAGS.task, _is_chief(), FLAGS.worker_replicas, FLAGS.ps_tasks,
+ FLAGS.master, FLAGS.data_dir, FLAGS.num_epochs, FLAGS.cov_inv_op_strategy)
+
+if __name__ == "__main__":
+ tf.app.run(main=main)
diff --git a/tensorflow/contrib/kfac/examples/convnet_mnist_multi_tower_main.py b/tensorflow/contrib/kfac/examples/convnet_mnist_multi_tower_main.py
new file mode 100644
index 0000000000..4249bf8a8d
--- /dev/null
+++ b/tensorflow/contrib/kfac/examples/convnet_mnist_multi_tower_main.py
@@ -0,0 +1,48 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+r"""Train a ConvNet on MNIST using K-FAC.
+
+Multi-tower training mode. See `convnet.train_mnist_multitower` for details.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+from absl import flags
+import tensorflow as tf
+
+from tensorflow.contrib.kfac.examples import convnet
+
+FLAGS = flags.FLAGS
+flags.DEFINE_string("data_dir", "/tmp/multitower_1/mnist", "local mnist dir")
+flags.DEFINE_integer("num_towers", 2,
+ "Number of towers for multi tower training.")
+
+
+def main(unused_argv):
+ _ = unused_argv
+ assert FLAGS.num_towers > 1
+ devices = ["/gpu:{}".format(tower_id) for tower_id in range(FLAGS.num_towers)]
+ convnet.train_mnist_multitower(
+ FLAGS.data_dir,
+ num_epochs=200,
+ num_towers=FLAGS.num_towers,
+ devices=devices)
+
+
+if __name__ == "__main__":
+ tf.app.run(main=main)
diff --git a/tensorflow/contrib/kfac/examples/convnet_mnist_main.py b/tensorflow/contrib/kfac/examples/convnet_mnist_single_main.py
index b0c6fbde19..3aa52aff19 100644
--- a/tensorflow/contrib/kfac/examples/convnet_mnist_main.py
+++ b/tensorflow/contrib/kfac/examples/convnet_mnist_single_main.py
@@ -14,44 +14,26 @@
# ==============================================================================
r"""Train a ConvNet on MNIST using K-FAC.
-See convnet.py for details.
+Train on single machine. See `convnet.train_mnist_single_machine` for details.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import argparse
-import sys
+from absl import flags
import tensorflow as tf
from tensorflow.contrib.kfac.examples import convnet
-FLAGS = None
+FLAGS = flags.FLAGS
+flags.DEFINE_string("data_dir", "/tmp/mnist", "local mnist dir")
-def main(argv):
- _ = argv
-
- if FLAGS.num_towers > 1:
- convnet.train_mnist_multitower(
- FLAGS.data_dir, num_epochs=200, num_towers=FLAGS.num_towers)
- else:
- convnet.train_mnist_single_machine(FLAGS.data_dir, num_epochs=200)
+def main(unused_argv):
+  convnet.train_mnist_single_machine(FLAGS.data_dir, num_epochs=200)
if __name__ == "__main__":
- parser = argparse.ArgumentParser()
- parser.add_argument(
- "--data_dir",
- type=str,
- default="/tmp/mnist",
- help="Directory to store dataset in.")
- parser.add_argument(
- "--num_towers",
- type=int,
- default=1,
- help="Number of CPUs to split minibatch across.")
- FLAGS, unparsed = parser.parse_known_args()
- tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
+ tf.app.run(main=main)
diff --git a/tensorflow/contrib/kfac/examples/tests/convnet_test.py b/tensorflow/contrib/kfac/examples/tests/convnet_test.py
index 8d86c2bb51..6de775cc79 100644
--- a/tensorflow/contrib/kfac/examples/tests/convnet_test.py
+++ b/tensorflow/contrib/kfac/examples/tests/convnet_test.py
@@ -112,15 +112,16 @@ class ConvNetTest(tf.test.TestCase):
def testMinimizeLossSingleMachine(self):
with tf.Graph().as_default():
loss, accuracy, layer_collection = self._build_toy_problem()
- accuracy_ = convnet.minimize_loss_single_machine(loss, accuracy,
- layer_collection)
- self.assertLess(accuracy_, 1.0)
+ accuracy_ = convnet.minimize_loss_single_machine(
+ loss, accuracy, layer_collection, device="/cpu:0")
+ self.assertLess(accuracy_, 2.0)
def testMinimizeLossDistributed(self):
with tf.Graph().as_default():
loss, accuracy, layer_collection = self._build_toy_problem()
- accuracy_ = convnet.minimize_loss_distributed(
+ accuracy_ = convnet.distributed_grads_only_and_ops_chief_worker(
task_id=0,
+ is_chief=True,
num_worker_tasks=1,
num_ps_tasks=0,
master="",
@@ -128,7 +129,7 @@ class ConvNetTest(tf.test.TestCase):
loss=loss,
accuracy=accuracy,
layer_collection=layer_collection)
- self.assertLess(accuracy_, 1.0)
+ self.assertLess(accuracy_, 2.0)
def testTrainMnistSingleMachine(self):
with tf.Graph().as_default():
@@ -138,7 +139,7 @@ class ConvNetTest(tf.test.TestCase):
# but there are too few parameters for the model to effectively memorize
# the training set the way an MLP can.
convnet.train_mnist_single_machine(
- data_dir=None, num_epochs=1, use_fake_data=True)
+ data_dir=None, num_epochs=1, use_fake_data=True, device="/cpu:0")
def testTrainMnistMultitower(self):
with tf.Graph().as_default():
@@ -149,13 +150,15 @@ class ConvNetTest(tf.test.TestCase):
def testTrainMnistDistributed(self):
with tf.Graph().as_default():
# Ensure model training doesn't crash.
- convnet.train_mnist_distributed(
+ convnet.train_mnist_distributed_sync_replicas(
task_id=0,
+ is_chief=True,
num_worker_tasks=1,
num_ps_tasks=0,
master="",
data_dir=None,
num_epochs=1,
+ op_strategy="chief_worker",
use_fake_data=True)
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/BUILD b/tensorflow/contrib/kfac/python/kernel_tests/BUILD
index f73c24f8fb..2477d2bfc1 100644
--- a/tensorflow/contrib/kfac/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/kfac/python/kernel_tests/BUILD
@@ -114,6 +114,7 @@ py_test(
name = "utils_test",
srcs = ["utils_test.py"],
srcs_version = "PY2AND3",
+ tags = ["no_windows"], # TODO: needs investigation on Windows
deps = [
"//tensorflow/contrib/kfac/python/ops:utils",
"//tensorflow/contrib/tpu",
diff --git a/tensorflow/contrib/kfac/python/ops/layer_collection.py b/tensorflow/contrib/kfac/python/ops/layer_collection.py
index 586a004f88..19608aca47 100644
--- a/tensorflow/contrib/kfac/python/ops/layer_collection.py
+++ b/tensorflow/contrib/kfac/python/ops/layer_collection.py
@@ -990,9 +990,11 @@ class LayerCollection(object):
num_uses=num_uses),
reuse=reuse)
block.register_additional_tower(inputs, outputs)
-
- assert len(inputs) == len(outputs)
- self._add_uses(params, len(inputs))
+ if isinstance(inputs, (tuple, list)):
+ assert len(inputs) == len(outputs)
+ self._add_uses(params, len(inputs))
+ else:
+ self._add_uses(params, 1)
def register_conv2d_multi(self,
params,
@@ -1066,9 +1068,11 @@ class LayerCollection(object):
reuse=reuse)
block.register_additional_tower(inputs, outputs)
-
- assert len(inputs) == len(outputs)
- self._add_uses(params, len(inputs))
+ if isinstance(inputs, (tuple, list)):
+ assert len(inputs) == len(outputs)
+ self._add_uses(params, len(inputs))
+ else:
+ self._add_uses(params, 1)
# TODO(b/74108452): change the loss registration functions names to refer
# to "loss functions" instead of distributions. Following naming convention
@@ -1088,7 +1092,7 @@ class LayerCollection(object):
inputs: A list of Tensors, each of shape [batch_size, input_size] and
dtype int32. Indices into embedding matrix. The list indexes each use
in the graph (which might correspond to a "time-step" in an RNN).
- OR, can be single Tensor, of shape [num_uses, batch_size, input_size],
+ OR, can be a single Tensor of shape [num_uses*batch_size, input_size],
which is a reshaped version of a Tensor of shape [num_uses, batch_size,
input_size].
outputs: A list of Tensors, each of shape [batch_size, embedding_size].
@@ -1129,7 +1133,10 @@ class LayerCollection(object):
params, block_type(self, vocab_size, num_uses=num_uses), reuse=reuse)
block.register_additional_tower(inputs, outputs)
- self._add_uses(params, len(inputs))
+ if isinstance(inputs, (tuple, list)):
+ self._add_uses(params, len(inputs))
+ else:
+ self._add_uses(params, 1)
def register_categorical_predictive_distribution(self,
logits,
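
[Editor's note: the same guard appears in three registration methods above; a standalone sketch of the counting rule, with a hypothetical helper name.]

def _num_uses(inputs):
  # A list/tuple of per-use tensors contributes one use per element; a single
  # time-folded tensor of shape [num_uses * batch_size, ...] counts as one
  # registration.
  if isinstance(inputs, (tuple, list)):
    return len(inputs)
  return 1

assert _num_uses(["t0", "t1", "t2"]) == 3  # list of per-use tensors
assert _num_uses("folded_tensor") == 1     # single reshaped tensor
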
diff --git a/tensorflow/contrib/labeled_tensor/BUILD b/tensorflow/contrib/labeled_tensor/BUILD
index 18b265ae80..c8812d4b23 100644
--- a/tensorflow/contrib/labeled_tensor/BUILD
+++ b/tensorflow/contrib/labeled_tensor/BUILD
@@ -70,6 +70,7 @@ py_test(
"python/ops/core_test.py",
],
srcs_version = "PY2AND3",
+ tags = ["no_windows"], # TODO: needs investigation on Windows
deps = [
":_typecheck",
":core",
diff --git a/tensorflow/contrib/layers/BUILD b/tensorflow/contrib/layers/BUILD
index 4be55468db..d5b3b279a1 100644
--- a/tensorflow/contrib/layers/BUILD
+++ b/tensorflow/contrib/layers/BUILD
@@ -188,6 +188,7 @@ py_test(
size = "small",
srcs = ["python/layers/normalization_test.py"],
srcs_version = "PY2AND3",
+ tags = ["no_windows"], # TODO: needs investigation on Windows
deps = [
":layers_py",
"//tensorflow/contrib/framework:framework_py",
@@ -353,6 +354,7 @@ py_test(
size = "small",
srcs = ["python/ops/sparse_ops_test.py"],
srcs_version = "PY2AND3",
+ tags = ["no_windows"], # TODO: needs investigation on Windows
deps = [
":layers_py",
"//tensorflow/python:array_ops",
diff --git a/tensorflow/contrib/layers/__init__.py b/tensorflow/contrib/layers/__init__.py
index 337c9e06b8..00f03a111a 100644
--- a/tensorflow/contrib/layers/__init__.py
+++ b/tensorflow/contrib/layers/__init__.py
@@ -104,6 +104,7 @@ See the @{$python/contrib.layers} guide.
@@infer_real_valued_columns
@@sequence_input_from_feature_columns
+@@group_norm
@@instance_norm
"""
@@ -122,6 +123,7 @@ _allowed_symbols = ['bias_add',
'conv3d',
'elu',
'feature_column',
+ 'group_norm',
'instance_norm',
'legacy_fully_connected',
'legacy_linear',
diff --git a/tensorflow/contrib/layers/python/layers/normalization.py b/tensorflow/contrib/layers/python/layers/normalization.py
index e7d4080ff7..c807ab0f2e 100644
--- a/tensorflow/contrib/layers/python/layers/normalization.py
+++ b/tensorflow/contrib/layers/python/layers/normalization.py
@@ -24,11 +24,13 @@ from tensorflow.contrib.layers.python.layers import utils
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import variable_scope
__all__ = [
+ 'group_norm',
'instance_norm',
]
@@ -158,3 +160,196 @@ def instance_norm(inputs,
if activation_fn is not None:
outputs = activation_fn(outputs)
return utils.collect_named_outputs(outputs_collections, sc.name, outputs)
+
+
+@add_arg_scope
+def group_norm(inputs,
+ groups=32,
+ channels_axis=-1,
+ reduction_axes=(-3, -2),
+ center=True,
+ scale=True,
+ epsilon=1e-6,
+ activation_fn=None,
+ param_initializers=None,
+ reuse=None,
+ variables_collections=None,
+ outputs_collections=None,
+ trainable=True,
+ scope=None):
+ """Functional interface for the group normalization layer.
+
+ Reference: https://arxiv.org/abs/1803.08494.
+
+ "Group Normalization", Yuxin Wu, Kaiming He
+
+ Args:
+ inputs: A Tensor with at least 2 dimensions, one of which is channels. All
+ shape dimensions must be fully defined.
+ groups: Integer. Divide the channels into this number of groups over which
+ normalization statistics are computed. This number must be commensurate
+ with the number of channels in `inputs`.
+ channels_axis: An integer. Specifies the index of the channels axis, which
+ will be broken into `groups`, over each of which statistics will be
+ computed. Must be mutually exclusive with `reduction_axes`. Preferred usage
+ is to specify negative integers to be agnostic as to whether a batch
+ dimension is included.
+ reduction_axes: Tuple of integers. Specifies dimensions over which
+ statistics will be accumulated. Must be mutually exclusive with
+ `channels_axis`. Statistics will not be accumulated across axes not
+ specified in `reduction_axes` or `channels_axis`. Preferred usage is to
+ specify negative integers to be agnostic to whether a batch dimension is
+ included.
+
+ Some sample usage cases:
+ NHWC format: channels_axis=-1, reduction_axes=[-3, -2]
+ NCHW format: channels_axis=-3, reduction_axes=[-2, -1]
+
+ center: If True, add offset of `beta` to normalized tensor. If False, `beta`
+ is ignored.
+ scale: If True, multiply by `gamma`. If False, `gamma` is
+ not used. When the next layer is linear (this also applies to `nn.relu`),
+ this can be disabled since the scaling can be done by the next layer.
+ epsilon: Small float added to variance to avoid dividing by zero.
+ activation_fn: Activation function, default set to None to skip it and
+ maintain a linear activation.
+ param_initializers: Optional initializers for `beta` and `gamma`.
+ reuse: Whether or not the layer and its variables should be reused. To be
+ able to reuse the layer, `scope` must be given.
+ variables_collections: Optional collections for the variables.
+ outputs_collections: Collections to add the outputs to.
+ trainable: If `True` also add variables to the graph collection
+ `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
+ scope: Optional scope for `variable_scope`.
+
+ Returns:
+ A `Tensor` representing the output of the operation.
+
+ Raises:
+ ValueError: If the rank or the channels dimension of `inputs` is undefined.
+ ValueError: If number of groups is not commensurate with number of channels.
+ ValueError: If reduction_axes or channels_axis are out of bounds.
+ ValueError: If reduction_axes are not mutually exclusive with channels_axis.
+ """
+ # TODO(shlens): Support partially defined shapes for the inputs.
+ inputs = ops.convert_to_tensor(inputs)
+ original_shape = inputs.shape
+
+ if inputs.shape.ndims is None:
+ raise ValueError('Inputs %s has undefined rank.' % inputs.name)
+ if channels_axis > (inputs.shape.ndims - 1):
+ raise ValueError('Axis is out of bounds.')
+
+ # Standardize the channels_axis to be positive and identify # of channels.
+ if channels_axis < 0:
+ channels_axis = inputs.shape.ndims + channels_axis
+ channels = inputs.shape[channels_axis].value
+
+ if channels is None:
+ raise ValueError('Inputs %s has undefined channel dimension: %d.' % (
+ inputs.name, channels_axis))
+
+ # Standardize the reduction_axes to be positive.
+ reduction_axes = list(reduction_axes)
+ for i in range(len(reduction_axes)):
+ if reduction_axes[i] < 0:
+ reduction_axes[i] += inputs.shape.ndims
+
+ for a in reduction_axes:
+ if a > inputs.shape.ndims - 1:
+ raise ValueError('Axis is out of bounds.')
+ if inputs.shape[a].value is None:
+ raise ValueError('Inputs %s has undefined dimensions %d.' % (
+ inputs.name, a))
+ if channels_axis == a:
+ raise ValueError('reduction_axis must be mutually exclusive '
+ 'with channels_axis')
+ if groups > channels:
+ raise ValueError('Invalid groups %d for %d channels.' % (groups, channels))
+ if channels % groups != 0:
+ raise ValueError('%d channels is not commensurate with %d groups.' %
+ (channels, groups))
+
+ # Determine axes before channels. Some examples of common image formats:
+ # 'NCHW': before = [N], after = [HW]
+ # 'NHWC': before = [NHW], after = []
+ axes_before_channels = inputs.shape.as_list()[:channels_axis]
+ axes_after_channels = inputs.shape.as_list()[channels_axis+1:]
+
+ # Manually broadcast the parameters to conform to the number of groups.
+ params_shape_broadcast = ([1] * len(axes_before_channels) +
+ [groups, channels // groups] +
+ [1] * len(axes_after_channels))
+
+ # Reshape the input by the group within the channel dimension.
+ inputs_shape = (axes_before_channels + [groups, channels // groups] +
+ axes_after_channels)
+ inputs = array_ops.reshape(inputs, inputs_shape)
+
+ # Determine the dimensions across which moments are calculated.
+ moments_axes = [channels_axis + 1]
+ for a in reduction_axes:
+ if a > channels_axis:
+ moments_axes.append(a + 1)
+ else:
+ moments_axes.append(a)
+
+ with variable_scope.variable_scope(
+ scope, 'GroupNorm', [inputs], reuse=reuse) as sc:
+ # Note that params_shape is always the number of channels.
+ params_shape = [channels]
+
+ # Allocate parameters for the beta and gamma of the normalization.
+ beta, gamma = None, None
+ dtype = inputs.dtype.base_dtype
+ if param_initializers is None:
+ param_initializers = {}
+ if center:
+ beta_collections = utils.get_variable_collections(
+ variables_collections, 'beta')
+ beta_initializer = param_initializers.get(
+ 'beta', init_ops.zeros_initializer())
+ beta = variables.model_variable('beta',
+ shape=params_shape,
+ dtype=dtype,
+ initializer=beta_initializer,
+ collections=beta_collections,
+ trainable=trainable)
+ beta = array_ops.reshape(beta, params_shape_broadcast)
+
+ if scale:
+ gamma_collections = utils.get_variable_collections(
+ variables_collections, 'gamma')
+ gamma_initializer = param_initializers.get(
+ 'gamma', init_ops.ones_initializer())
+ gamma = variables.model_variable('gamma',
+ shape=params_shape,
+ dtype=dtype,
+ initializer=gamma_initializer,
+ collections=gamma_collections,
+ trainable=trainable)
+ gamma = array_ops.reshape(gamma, params_shape_broadcast)
+
+ # Calculate the moments.
+ mean, variance = nn.moments(inputs, moments_axes, keep_dims=True)
+
+ # Compute normalization.
+ # TODO(shlens): Fix nn.batch_normalization to handle the 5-D Tensor
+ # appropriately so that this operation may be faster.
+ gain = math_ops.rsqrt(variance + epsilon)
+ offset = -mean * gain
+ if gamma is not None:
+ gain *= gamma
+ offset *= gamma
+ if beta is not None:
+ offset += beta
+ outputs = inputs * gain + offset
+
+ # Collapse the groups into the channel dimension.
+ outputs = array_ops.reshape(outputs, original_shape)
+
+ if activation_fn is not None:
+ outputs = activation_fn(outputs)
+ return utils.collect_named_outputs(outputs_collections, sc.name, outputs)
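
[Editor's note: a minimal usage sketch of the new layer, assuming an NHWC activation tensor whose 8 channels split into 4 groups of 2; values are illustrative.]

import tensorflow as tf
from tensorflow.contrib.layers.python.layers import normalization

x = tf.random_uniform((5, 3, 3, 8), seed=1)  # NHWC input
y = normalization.group_norm(x, groups=4, channels_axis=-1,
                             reduction_axes=(-3, -2))
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  print(sess.run(y).shape)  # (5, 3, 3, 8): the input shape is preserved
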
diff --git a/tensorflow/contrib/layers/python/layers/normalization_test.py b/tensorflow/contrib/layers/python/layers/normalization_test.py
index 5cff1bf0eb..b6e96350db 100644
--- a/tensorflow/contrib/layers/python/layers/normalization_test.py
+++ b/tensorflow/contrib/layers/python/layers/normalization_test.py
@@ -166,5 +166,231 @@ class InstanceNormTest(test.TestCase):
def testOutputBigInput5DNCHW(self):
self.doOutputTest((1, 100, 100, 1, 1), 'NCHW', tol=1e-3)
+
+class GroupNormTest(test.TestCase):
+
+ def testInvalidGroupSize(self):
+ inputs = array_ops.placeholder(dtypes.float32, shape=(5, 2, 10, 10))
+ with self.assertRaisesRegexp(ValueError,
+ 'Invalid groups 10 for 2 channels.'):
+ normalization.group_norm(inputs, groups=10,
+ reduction_axes=[-2, -1], channels_axis=-3)
+
+ def testBadCommensurateGroup(self):
+ inputs = array_ops.placeholder(dtypes.float32, shape=(5, 4, 10, 10))
+ with self.assertRaisesRegexp(ValueError,
+ '4 channels is not commensurate with '
+ '3 groups.'):
+ normalization.group_norm(inputs, groups=3,
+ reduction_axes=[-2, -1], channels_axis=-3)
+
+ def testAxisIsBad(self):
+ inputs = array_ops.placeholder(dtypes.float32, shape=(1, 2, 4, 5))
+ with self.assertRaisesRegexp(ValueError,
+ 'Axis is out of bounds.'):
+ normalization.group_norm(inputs, channels_axis=5)
+ with self.assertRaisesRegexp(ValueError,
+ 'Axis is out of bounds.'):
+ normalization.group_norm(inputs, reduction_axes=[1, 5])
+
+ def testNotMutuallyExclusiveAxis(self):
+ inputs = array_ops.placeholder(dtypes.float32, shape=(10, 32, 32, 32))
+ # Specify axis with negative values.
+ with self.assertRaisesRegexp(ValueError, 'mutually exclusive'):
+ normalization.group_norm(inputs, channels_axis=-2, reduction_axes=[-2])
+ # Specify axis with positive values.
+ with self.assertRaisesRegexp(ValueError, 'mutually exclusive'):
+ normalization.group_norm(inputs, channels_axis=1, reduction_axes=[1, 3])
+ # Specify axis with mixed positive and negative values.
+ with self.assertRaisesRegexp(ValueError, 'mutually exclusive'):
+ normalization.group_norm(inputs, channels_axis=-2, reduction_axes=[2])
+
+ def testUnknownShape(self):
+ inputs = array_ops.placeholder(dtypes.float32)
+ with self.assertRaisesRegexp(ValueError, 'undefined rank'):
+ normalization.group_norm(inputs)
+
+ def testParamsShapeNotFullyDefinedReductionAxes(self):
+ inputs = array_ops.placeholder(dtypes.float32, shape=(1, 32, None, 4))
+ with self.assertRaisesRegexp(ValueError, 'undefined dimensions'):
+ normalization.group_norm(inputs)
+
+ def testParamsShapeNotFullyDefinedChannelsAxis(self):
+ inputs = array_ops.placeholder(dtypes.float32, shape=(1, 3, 4, None))
+ with self.assertRaisesRegexp(ValueError, 'undefined channel dimension'):
+ normalization.group_norm(inputs, channels_axis=-1,
+ reduction_axes=[-3, -2])
+
+ def testCreateOp(self):
+ height, width, groups = 3, 3, 4
+ images = random_ops.random_uniform((5, height, width, 2*groups), seed=1)
+ output = normalization.group_norm(images, groups=groups, channels_axis=-1,
+ reduction_axes=[-3, -2])
+ print('name: ', output.op.name)
+ self.assertListEqual([5, height, width, 2*groups], output.shape.as_list())
+
+ def testCreateOpFloat64(self):
+ height, width, groups = 3, 3, 5
+ images = random_ops.random_uniform(
+ (5, height, width, 4*groups), dtype=dtypes.float64, seed=1)
+ output = normalization.group_norm(images, groups=groups)
+ self.assertEqual(dtypes.float64, output.dtype)
+ self.assertListEqual([5, height, width, 4*groups], output.shape.as_list())
+
+ def testCreateOpNoScaleCenter(self):
+ height, width, groups = 3, 3, 7
+ images = random_ops.random_uniform(
+ (5, height, width, 3*groups), dtype=dtypes.float32, seed=1)
+ output = normalization.group_norm(images, groups=groups, center=False,
+ scale=False)
+ self.assertListEqual([5, height, width, 3*groups], output.shape.as_list())
+ self.assertEqual(0, len(contrib_variables.get_variables_by_name('beta')))
+ self.assertEqual(0, len(contrib_variables.get_variables_by_name('gamma')))
+
+ def testCreateVariables_NHWC(self):
+ height, width = 3, 3
+ images = random_ops.random_uniform((5, height, width, 8), seed=1)
+ normalization.group_norm(images, groups=4,
+ channels_axis=-1, reduction_axes=(-3, -2),
+ center=True, scale=True)
+ beta = contrib_variables.get_variables_by_name('beta')[0]
+ gamma = contrib_variables.get_variables_by_name('gamma')[0]
+ self.assertEqual('GroupNorm/beta', beta.op.name)
+ self.assertEqual('GroupNorm/gamma', gamma.op.name)
+
+ def testCreateVariables_NCHW(self):
+ height, width, groups = 3, 3, 4
+ images = random_ops.random_uniform((5, 2*groups, height, width), seed=1)
+ normalization.group_norm(images, groups=4,
+ channels_axis=-3, reduction_axes=(-2, -1),
+ center=True, scale=True)
+ beta = contrib_variables.get_variables_by_name('beta')[0]
+ gamma = contrib_variables.get_variables_by_name('gamma')[0]
+ self.assertEqual('GroupNorm/beta', beta.op.name)
+ self.assertEqual('GroupNorm/gamma', gamma.op.name)
+
+ def testReuseVariables(self):
+ height, width = 3, 3
+ images = random_ops.random_uniform((5, height, width, 4), seed=1)
+ normalization.group_norm(images, groups=2, scale=True, scope='IN')
+ normalization.group_norm(images, groups=2, scale=True, scope='IN',
+ reuse=True)
+ beta = contrib_variables.get_variables_by_name('beta')
+ gamma = contrib_variables.get_variables_by_name('gamma')
+ self.assertEqual(1, len(beta))
+ self.assertEqual(1, len(gamma))
+
+ def testValueCorrectWithReuseVars(self):
+ height, width = 3, 3
+ image_shape = (10, height, width, 4)
+ images = random_ops.random_uniform(image_shape, seed=1)
+ output_train = normalization.group_norm(images, groups=2, scope='IN')
+ output_eval = normalization.group_norm(images, groups=2, scope='IN',
+ reuse=True)
+ with self.test_session() as sess:
+ sess.run(variables.global_variables_initializer())
+ # output_train and output_eval should be the same.
+ train_np, eval_np = sess.run([output_train, output_eval])
+ self.assertAllClose(train_np, eval_np)
+
+ def doOutputTest(self, input_shape, channels_axis=None, reduction_axes=None,
+ groups=2, tol=1e-2):
+ # Select the axis for the channel and the dimensions along which statistics
+ # are accumulated.
+ if channels_axis < 0:
+ channels_axis += len(input_shape)
+ reduced_axes = [channels_axis + 1]
+ for a in reduction_axes:
+ if a < 0:
+ a += len(input_shape)
+ if a < channels_axis:
+ reduced_axes.append(a)
+ else:
+ reduced_axes.append(a+1)
+ reduced_axes = tuple(reduced_axes)
+
+ # Calculate the final shape for the output Tensor.
+ axes_before_channels = input_shape[:channels_axis]
+ axes_after_channels = input_shape[channels_axis+1:]
+ channels = input_shape[channels_axis]
+ outputs_shape = (axes_before_channels + [groups, channels // groups] +
+ axes_after_channels)
+
+ # Calculate the final shape for the output statistics.
+ reduced_shape = []
+ for i, a in enumerate(outputs_shape):
+ if i not in reduced_axes:
+ reduced_shape.append(a)
+
+ for mu in (0.0, 1e2):
+ for sigma in (1.0, 0.1):
+ # Determine shape of Tensor after normalization.
+ expected_mean = np.zeros(reduced_shape)
+ expected_var = np.ones(reduced_shape)
+
+ inputs = random_ops.random_uniform(input_shape, seed=0) * sigma + mu
+ output_op = normalization.group_norm(
+ inputs, groups=groups, center=False, scale=False,
+ channels_axis=channels_axis,
+ reduction_axes=reduction_axes)
+ with self.test_session() as sess:
+ sess.run(variables.global_variables_initializer())
+ outputs = sess.run(output_op)
+ # Make sure that there are no NaNs
+ self.assertFalse(np.isnan(outputs).any())
+
+ outputs = np.reshape(outputs, outputs_shape)
+ mean = np.mean(outputs, axis=reduced_axes)
+ var = np.var(outputs, axis=reduced_axes)
+ # The mean and variance of each example should be close to 0 and 1
+ # respectively.
+ self.assertAllClose(expected_mean, mean, rtol=tol, atol=tol)
+ self.assertAllClose(expected_var, var, rtol=tol, atol=tol)
+
+ def testOutputSmallInput4D_NHWC(self):
+ input_shape = [10, 10, 10, 30]
+ # Specify axes with positive values.
+ self.doOutputTest(input_shape, channels_axis=3, reduction_axes=[1, 2])
+ # Specify axes with negative values.
+ self.doOutputTest(input_shape, channels_axis=-1, reduction_axes=[-3, -2])
+
+ def testOutputSmallInput3D_NHWC(self):
+ input_shape = [10, 10, 30]
+ # Specify axes with positive values.
+ self.doOutputTest(input_shape, channels_axis=2, reduction_axes=[0, 1])
+ # Specify axes with negative values.
+ self.doOutputTest(input_shape, channels_axis=-1, reduction_axes=[-3, -2])
+
+ def testOutputSmallInput4D_NCHW(self):
+ input_shape = [10, 10, 10, 30]
+ # Specify axes with positive values.
+ self.doOutputTest(input_shape, channels_axis=1, reduction_axes=[2, 3])
+ # Specify axes with negative values.
+ self.doOutputTest(input_shape, channels_axis=-3, reduction_axes=[-2, -1])
+
+ def testOutputSmallInput3D_NCHW(self):
+ input_shape = [10, 10, 30]
+ # Specify axes with positive values.
+ self.doOutputTest(input_shape, channels_axis=0, reduction_axes=[1, 2])
+ # Specify axes with negative values.
+ self.doOutputTest(input_shape, channels_axis=-3, reduction_axes=[-2, -1])
+
+ def testOutputBigInput4D_NHWC(self):
+ self.doOutputTest([5, 100, 100, 1], channels_axis=3, reduction_axes=[1, 2],
+ groups=1)
+
+ def testOutputBigInput4D_NCHW(self):
+ self.doOutputTest([1, 100, 100, 4], channels_axis=1, reduction_axes=[2, 3],
+ groups=4)
+
+ def testOutputSmallInput2D_NC(self):
+ self.doOutputTest([10, 7*100], channels_axis=1, reduction_axes=[], groups=7)
+
+ def testOutputSmallInput5D_NCXXX(self):
+ self.doOutputTest([10, 10, 20, 40, 5],
+ channels_axis=1,
+ reduction_axes=[2, 3, 4],
+ groups=5)
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD
index ba55365c14..d665fc9335 100644
--- a/tensorflow/contrib/learn/BUILD
+++ b/tensorflow/contrib/learn/BUILD
@@ -117,6 +117,7 @@ py_test(
size = "small",
srcs = ["python/learn/learn_io/data_feeder_test.py"],
srcs_version = "PY2AND3",
+ tags = ["no_windows"], # TODO: needs investigation on Windows
deps = [
":learn",
"//tensorflow/python:client_testlib",
@@ -172,6 +173,7 @@ tf_py_test(
"//tensorflow/python:variables",
"//tensorflow/python/estimator",
],
+ tags = ["no_windows"], # TODO: needs investigation on Windows
)
py_test(
@@ -190,6 +192,7 @@ py_test(
size = "small",
srcs = ["python/learn/graph_actions_test.py"],
srcs_version = "PY2AND3",
+ tags = ["no_windows"], # TODO: needs investigation on Windows
deps = [
":learn",
"//tensorflow/contrib/framework:framework_py",
@@ -591,6 +594,7 @@ py_test(
size = "small",
srcs = ["python/learn/learn_io/io_test.py"],
srcs_version = "PY2AND3",
+ tags = ["no_windows"], # TODO: needs investigation on Windows
deps = [
":learn",
"//tensorflow/contrib/learn/python/learn/datasets",
@@ -820,6 +824,7 @@ py_test(
size = "small",
srcs = ["python/learn/utils/saved_model_export_utils_test.py"],
srcs_version = "PY2AND3",
+ tags = ["no_windows"], # TODO: needs investigation on Windows
deps = [
":learn",
"//tensorflow/contrib/layers:layers_py",
diff --git a/tensorflow/contrib/learn/python/learn/estimators/run_config.py b/tensorflow/contrib/learn/python/learn/estimators/run_config.py
index f3500bf56f..8c85c431be 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/run_config.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/run_config.py
@@ -298,7 +298,7 @@ class RunConfig(ClusterConfig, core_run_config.RunConfig):
# core_run_config.RunConfig.__init__(self)
# so instead of breaking compatibility with that assumption, we
# just manually initialize this field:
- self._distribute = None
+ self._train_distribute = None
gpu_options = config_pb2.GPUOptions(
per_process_gpu_memory_fraction=gpu_memory_fraction)
diff --git a/tensorflow/contrib/lite/BUILD b/tensorflow/contrib/lite/BUILD
index ac269d540a..9c4533079c 100644
--- a/tensorflow/contrib/lite/BUILD
+++ b/tensorflow/contrib/lite/BUILD
@@ -89,6 +89,7 @@ cc_library(
hdrs = [
"builtin_op_data.h",
],
+ deps = [":context"],
)
cc_library(
diff --git a/tensorflow/contrib/lite/arena_planner.h b/tensorflow/contrib/lite/arena_planner.h
index f84b3dad95..e9d0fbc5a9 100644
--- a/tensorflow/contrib/lite/arena_planner.h
+++ b/tensorflow/contrib/lite/arena_planner.h
@@ -25,7 +25,7 @@ limitations under the License.
namespace tflite {
-class AllocationInfo;
+struct AllocationInfo;
// A memory planner that makes all the allocations using arenas.
//
diff --git a/tensorflow/contrib/lite/builtin_op_data.h b/tensorflow/contrib/lite/builtin_op_data.h
index 5fc8954743..2b6c24768c 100644
--- a/tensorflow/contrib/lite/builtin_op_data.h
+++ b/tensorflow/contrib/lite/builtin_op_data.h
@@ -17,6 +17,8 @@ limitations under the License.
#include <stdint.h>
+#include "tensorflow/contrib/lite/context.h"
+
#ifdef __cplusplus
extern "C" {
#endif // __cplusplus
@@ -174,6 +176,11 @@ typedef struct {
int block_size;
} TfLiteSpaceToDepthParams;
+typedef struct {
+ TfLiteType in_data_type;
+ TfLiteType out_data_type;
+} TfLiteCastParams;
+
typedef enum {
kTfLiteCombinerTypeSum = 0,
kTfLiteCombinerTypeMean = 1,
diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java
index 14f461f5f9..a33959dca4 100644
--- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java
+++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java
@@ -68,6 +68,19 @@ public final class Interpreter implements AutoCloseable {
}
/**
+ * Initializes an {@code Interpreter} and specifies the number of threads used for inference.
+ *
+ * @param modelFile a file of a pre-trained TF Lite model
+ * @param numThreads number of threads to use for inference
+ */
+ public Interpreter(@NonNull File modelFile, int numThreads) {
+ if (modelFile == null) {
+ return;
+ }
+ wrapper = new NativeInterpreterWrapper(modelFile.getAbsolutePath(), numThreads);
+ }
+
+ /**
* Initializes a {@code Interpreter} with a {@code MappedByteBuffer} to the model file.
*
* <p>The {@code MappedByteBuffer} should remain unchanged after the construction of a {@code
diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
index dbf8f8f7cc..fc8187acfe 100644
--- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
+++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
@@ -32,9 +32,13 @@ import java.util.Map;
final class NativeInterpreterWrapper implements AutoCloseable {
NativeInterpreterWrapper(String modelPath) {
+ this(modelPath, /* numThreads= */ -1);
+ }
+
+ NativeInterpreterWrapper(String modelPath, int numThreads) {
errorHandle = createErrorReporter(ERROR_BUFFER_SIZE);
modelHandle = createModel(modelPath, errorHandle);
- interpreterHandle = createInterpreter(modelHandle, errorHandle, /* numThreads= */ -1);
+ interpreterHandle = createInterpreter(modelHandle, errorHandle, numThreads);
isMemoryAllocated = true;
}
@@ -44,11 +48,7 @@ final class NativeInterpreterWrapper implements AutoCloseable {
* NativeInterpreterWrapper}.
*/
NativeInterpreterWrapper(MappedByteBuffer mappedByteBuffer) {
- modelByteBuffer = mappedByteBuffer;
- errorHandle = createErrorReporter(ERROR_BUFFER_SIZE);
- modelHandle = createModelWithBuffer(modelByteBuffer, errorHandle);
- interpreterHandle = createInterpreter(modelHandle, errorHandle, /* numThreads= */ -1);
- isMemoryAllocated = true;
+ this(mappedByteBuffer, /* numThreads= */ -1);
}
/**
diff --git a/tensorflow/contrib/lite/kernels/cast.cc b/tensorflow/contrib/lite/kernels/cast.cc
index 19942de7bc..17ef2c572e 100644
--- a/tensorflow/contrib/lite/kernels/cast.cc
+++ b/tensorflow/contrib/lite/kernels/cast.cc
@@ -34,6 +34,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ // TODO(ahentz): these two checks would make the new implementation
+ // incompatible with some existing models, where params is not specified. It
+ // is OK not to have them because toco would have set input and output types
+ // to match the parameters.
+ // auto* params = reinterpret_cast<TfLiteCastParams*>(node->builtin_data);
+ // TF_LITE_ENSURE_EQ(context, input->type, params->in_data_type);
+ // TF_LITE_ENSURE_EQ(context, output->type, params->out_data_type);
+
return context->ResizeTensor(context, output,
TfLiteIntArrayCopy(input->dims));
}
diff --git a/tensorflow/contrib/lite/kernels/l2norm.cc b/tensorflow/contrib/lite/kernels/l2norm.cc
index ee8bfe56d9..e67f4e06f3 100644
--- a/tensorflow/contrib/lite/kernels/l2norm.cc
+++ b/tensorflow/contrib/lite/kernels/l2norm.cc
@@ -45,10 +45,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE(context, NumDimensions(input) <= 4);
- // TODO(ahentz): Our current implementations only support float32.
- TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32);
+ TF_LITE_ENSURE(
+ context, output->type == kTfLiteFloat32 || output->type == kTfLiteUInt8);
TF_LITE_ENSURE_EQ(context, input->type, output->type);
+ if (output->type == kTfLiteUInt8) {
+ TF_LITE_ENSURE_EQ(context, output->params.scale, (1. / 128.));
+ TF_LITE_ENSURE_EQ(context, output->params.zero_point, 128);
+ }
+
// TODO(ahentz): For some reason our implementations don't support
// activations.
TF_LITE_ENSURE_EQ(context, params->activation, kTfLiteActNone);
@@ -75,6 +80,19 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_L2NORM(optimized_ops);
}
#undef TF_LITE_L2NORM
+ } else if (output->type == kTfLiteUInt8) {
+#define TF_LITE_L2NORM(type) \
+ type::L2Normalization(GetTensorData<uint8>(input), GetTensorDims(input), \
+ input->params.zero_point, \
+ GetTensorData<uint8>(output), GetTensorDims(output))
+
+ if (kernel_type == kReference) {
+ TF_LITE_L2NORM(reference_ops);
+ }
+ if (kernel_type == kGenericOptimized) {
+ TF_LITE_L2NORM(optimized_ops);
+ }
+#undef TF_LITE_L2NORM
} else {
context->ReportError(context, "Inputs and outputs not all float types.");
return kTfLiteError;
diff --git a/tensorflow/contrib/lite/kernels/l2norm_test.cc b/tensorflow/contrib/lite/kernels/l2norm_test.cc
index 30e103f330..042314ccf5 100644
--- a/tensorflow/contrib/lite/kernels/l2norm_test.cc
+++ b/tensorflow/contrib/lite/kernels/l2norm_test.cc
@@ -25,10 +25,22 @@ using ::testing::ElementsAreArray;
class L2NormOpModel : public SingleOpModel {
public:
- L2NormOpModel(std::initializer_list<int> input_shape,
- ActivationFunctionType activation_type) {
- input_ = AddInput(TensorType_FLOAT32);
- output_ = AddOutput(TensorType_FLOAT32);
+ L2NormOpModel(const std::initializer_list<int> input_shape,
+ const TensorType tensor_type,
+ const ActivationFunctionType activation_type) {
+ TensorData data = TensorData{tensor_type};
+ if (tensor_type != TensorType_FLOAT32) {
+ data.min = -2.0;
+ data.max = 2.0;
+ data.scale = 2.0;
+ data.zero_point = 128;
+ }
+ input_ = AddInput(data);
+ if (tensor_type != TensorType_FLOAT32) {
+ data.min = -1.0;
+ data.max = 127.0 / 128.0;
+ }
+ output_ = AddOutput(data);
SetBuiltinOp(BuiltinOperator_L2_NORMALIZATION, BuiltinOptions_L2NormOptions,
CreateL2NormOptions(builder_, activation_type).Union());
BuildInterpreter({input_shape});
@@ -38,7 +50,17 @@ class L2NormOpModel : public SingleOpModel {
PopulateTensor(input_, data);
}
- std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+ template <typename T>
+ std::vector<T> GetOutput() {
+ return ExtractVector<T>(output_);
+ }
+
+ std::vector<float> GetDequantizedOutput() {
+ return Dequantize<uint8_t>(ExtractVector<uint8_t>(output_),
+ GetScale(output_), GetZeroPoint(output_));
+ }
+
+ int input() const { return input_; }
private:
int input_;
@@ -46,13 +68,26 @@ class L2NormOpModel : public SingleOpModel {
};
TEST(L2NormOpTest, SimpleTest) {
- L2NormOpModel m({1, 1, 1, 6}, ActivationFunctionType_NONE);
+ L2NormOpModel m({1, 1, 1, 6}, TensorType_FLOAT32,
+ ActivationFunctionType_NONE);
m.SetInput({-1.1, 0.6, 0.7, 1.2, -0.7, 0.1});
m.Invoke();
- EXPECT_THAT(m.GetOutput(),
+ EXPECT_THAT(m.GetOutput<float>(),
ElementsAreArray({-0.55, 0.3, 0.35, 0.6, -0.35, 0.05}));
}
+TEST(L2NormOpTest, SimpleUint8Test) {
+ L2NormOpModel m({1, 1, 1, 6}, TensorType_UINT8, ActivationFunctionType_NONE);
+
+ m.QuantizeAndPopulate<uint8_t>(m.input(), {-1.1, 0.6, 0.7, 1.2, -0.7, 0.1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput<uint8_t>(),
+ ElementsAreArray({58, 166, 173, 205, 83, 134}));
+ EXPECT_THAT(m.GetDequantizedOutput(),
+ ElementsAreArray(
+ ArrayFloatNear({-0.55, 0.3, 0.35, 0.6, -0.35, 0.05}, 0.1)));
+}
+
} // namespace
} // namespace tflite
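
[Editor's note: the expected uint8 values in SimpleUint8Test follow directly from the quantization parameters the kernel now enforces (scale 1/128, zero point 128); a quick numpy check of the affine dequantization.]

import numpy as np

def dequantize(q, scale=1.0 / 128.0, zero_point=128):
  # Inverse of affine quantization: real = scale * (quantized - zero_point).
  return scale * (np.asarray(q, dtype=np.float32) - zero_point)

print(dequantize([58, 166, 173, 205, 83, 134]))
# -> approx [-0.547, 0.297, 0.352, 0.602, -0.352, 0.047],
# within the test's 0.1 tolerance of the float output.
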
diff --git a/tensorflow/contrib/lite/kernels/maximum.cc b/tensorflow/contrib/lite/kernels/maximum.cc
index 9fdf2b47ea..13c40603ce 100644
--- a/tensorflow/contrib/lite/kernels/maximum.cc
+++ b/tensorflow/contrib/lite/kernels/maximum.cc
@@ -52,9 +52,21 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
MaximumContext op_context(context, node);
TF_LITE_ENSURE_EQ(context, op_context.input1->type, op_context.input2->type);
- TfLiteIntArray* output_dims = TfLiteIntArrayCopy(op_context.input2->dims);
- op_context.output->type = op_context.input2->type;
- return context->ResizeTensor(context, op_context.output, output_dims);
+ op_context.output->type = op_context.input1->type;
+
+ bool requires_broadcast =
+ !HaveSameShapes(op_context.input1, op_context.input2);
+
+ TfLiteIntArray* output_size = nullptr;
+ if (requires_broadcast) {
+ TF_LITE_ENSURE_OK(
+ context, CalculateShapeForBroadcast(context, op_context.input1,
+ op_context.input2, &output_size));
+ } else {
+ output_size = TfLiteIntArrayCopy(op_context.input1->dims);
+ }
+
+ return context->ResizeTensor(context, op_context.output, output_size);
}
template <KernelType kernel_type>
diff --git a/tensorflow/contrib/lite/kernels/maximum_test.cc b/tensorflow/contrib/lite/kernels/maximum_test.cc
index b3fd7d4e6f..df2bf29c20 100644
--- a/tensorflow/contrib/lite/kernels/maximum_test.cc
+++ b/tensorflow/contrib/lite/kernels/maximum_test.cc
@@ -71,6 +71,20 @@ TEST(MaximumOpTest, FloatTest) {
ElementsAreArray(ArrayFloatNear({1.0, 0.0, 1.0, 12.0, -2.0, -1.43})));
}
+TEST(MaximumOpTest, FloatWithBroadcastTest) {
+ std::initializer_list<float> data1 = {1.0, 0.0, -1.0, -2.0, -1.44, 11.0};
+ std::initializer_list<float> data2 = {0.5, 2.0};
+ MaximumOpModel m({TensorType_FLOAT32, {3, 1, 2}}, {TensorType_FLOAT32, {2}},
+ TensorType_FLOAT32);
+ m.SetInput1<float>(data1);
+ m.SetInput2<float>(data2);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 1, 2}));
+ EXPECT_THAT(
+ m.GetOutput<float>(),
+ ElementsAreArray(ArrayFloatNear({1.0, 2.0, 0.5, 2.0, 0.5, 11.0})));
+}
+
} // namespace
} // namespace tflite
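
[Editor's note: the broadcast behavior the reworked Prepare enables mirrors numpy semantics; reproducing FloatWithBroadcastTest as a sanity check.]

import numpy as np

a = np.array([1.0, 0.0, -1.0, -2.0, -1.44, 11.0]).reshape(3, 1, 2)
b = np.array([0.5, 2.0])  # shape (2,) broadcasts against (3, 1, 2)
print(np.maximum(a, b).ravel())
# -> [ 1.   2.   0.5  2.   0.5 11. ], matching the expected test output
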
diff --git a/tensorflow/contrib/lite/kernels/strided_slice.cc b/tensorflow/contrib/lite/kernels/strided_slice.cc
index eb374d9031..e6d5c300dc 100644
--- a/tensorflow/contrib/lite/kernels/strided_slice.cc
+++ b/tensorflow/contrib/lite/kernels/strided_slice.cc
@@ -228,6 +228,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_STRIDED_SLICE(reference_ops, int64_t);
}
break;
+ case kTfLiteUInt8:
+ if (kernel_type == kReference) {
+ TF_LITE_STRIDED_SLICE(reference_ops, uint8_t);
+ }
+ break;
default:
context->ReportError(context,
"Type is currently not supported "
diff --git a/tensorflow/contrib/lite/kernels/strided_slice_test.cc b/tensorflow/contrib/lite/kernels/strided_slice_test.cc
index 5c98c5f431..22d7b097cb 100644
--- a/tensorflow/contrib/lite/kernels/strided_slice_test.cc
+++ b/tensorflow/contrib/lite/kernels/strided_slice_test.cc
@@ -24,6 +24,8 @@ namespace {
using ::int32;
using ::testing::ElementsAreArray;
+template <typename input_type = float,
+ TensorType tensor_input_type = TensorType_FLOAT32>
class StridedSliceOpModel : public SingleOpModel {
public:
StridedSliceOpModel(std::initializer_list<int> input_shape,
@@ -32,11 +34,11 @@ class StridedSliceOpModel : public SingleOpModel {
std::initializer_list<int> strides_shape, int begin_mask,
int end_mask, int ellipsis_mask, int new_axis_mask,
int shrink_axis_mask) {
- input_ = AddInput(TensorType_FLOAT32);
+ input_ = AddInput(tensor_input_type);
begin_ = AddInput(TensorType_INT32);
end_ = AddInput(TensorType_INT32);
strides_ = AddInput(TensorType_INT32);
- output_ = AddOutput(TensorType_FLOAT32);
+ output_ = AddOutput(tensor_input_type);
SetBuiltinOp(
BuiltinOperator_STRIDED_SLICE, BuiltinOptions_StridedSliceOptions,
CreateStridedSliceOptions(builder_, begin_mask, end_mask, ellipsis_mask,
@@ -45,8 +47,8 @@ class StridedSliceOpModel : public SingleOpModel {
BuildInterpreter({input_shape, begin_shape, end_shape, strides_shape});
}
- void SetInput(std::initializer_list<float> data) {
- PopulateTensor<float>(input_, data);
+ void SetInput(std::initializer_list<input_type> data) {
+ PopulateTensor<input_type>(input_, data);
}
void SetBegin(std::initializer_list<int32> data) {
PopulateTensor<int32>(begin_, data);
@@ -58,7 +60,9 @@ class StridedSliceOpModel : public SingleOpModel {
PopulateTensor<int32>(strides_, data);
}
- std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+ std::vector<input_type> GetOutput() {
+ return ExtractVector<input_type>(output_);
+ }
std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
private:
@@ -71,19 +75,19 @@ class StridedSliceOpModel : public SingleOpModel {
TEST(StridedSliceOpTest, UnsupportedInputSize) {
EXPECT_DEATH(
- StridedSliceOpModel({2, 2, 2, 2, 2}, {5}, {5}, {5}, 0, 0, 0, 0, 0),
+ StridedSliceOpModel<>({2, 2, 2, 2, 2}, {5}, {5}, {5}, 0, 0, 0, 0, 0),
"StridedSlice op only supports 1D-4D input arrays.");
}
TEST(StridedSliceOpTest, UnssupportedArgs) {
- EXPECT_DEATH(StridedSliceOpModel({3, 2}, {2}, {2}, {2}, 0, 0, 1, 0, 0),
+ EXPECT_DEATH(StridedSliceOpModel<>({3, 2}, {2}, {2}, {2}, 0, 0, 1, 0, 0),
"ellipsis_mask is not implemented yet.");
- EXPECT_DEATH(StridedSliceOpModel({3, 2}, {2}, {2}, {2}, 0, 0, 0, 1, 0),
+ EXPECT_DEATH(StridedSliceOpModel<>({3, 2}, {2}, {2}, {2}, 0, 0, 0, 1, 0),
"new_axis_mask is not implemented yet.");
}
TEST(StridedSliceOpTest, In1D) {
- StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
+ StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
m.SetInput({1, 2, 3, 4});
m.SetBegin({1});
m.SetEnd({3});
@@ -94,7 +98,7 @@ TEST(StridedSliceOpTest, In1D) {
}
TEST(StridedSliceOpTest, In1D_EmptyOutput) {
- StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
+ StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
m.SetInput({1, 2, 3, 4});
m.SetBegin({10});
m.SetEnd({3});
@@ -104,7 +108,7 @@ TEST(StridedSliceOpTest, In1D_EmptyOutput) {
}
TEST(StridedSliceOpTest, In1D_NegativeBegin) {
- StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
+ StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
m.SetInput({1, 2, 3, 4});
m.SetBegin({-3});
m.SetEnd({3});
@@ -115,7 +119,7 @@ TEST(StridedSliceOpTest, In1D_NegativeBegin) {
}
TEST(StridedSliceOpTest, In1D_OutOfRangeBegin) {
- StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
+ StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
m.SetInput({1, 2, 3, 4});
m.SetBegin({-5});
m.SetEnd({3});
@@ -126,7 +130,7 @@ TEST(StridedSliceOpTest, In1D_OutOfRangeBegin) {
}
TEST(StridedSliceOpTest, In1D_NegativeEnd) {
- StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
+ StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
m.SetInput({1, 2, 3, 4});
m.SetBegin({1});
m.SetEnd({-2});
@@ -137,7 +141,7 @@ TEST(StridedSliceOpTest, In1D_NegativeEnd) {
}
TEST(StridedSliceOpTest, In1D_OutOfRangeEnd) {
- StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
+ StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
m.SetInput({1, 2, 3, 4});
m.SetBegin({-3});
m.SetEnd({5});
@@ -148,7 +152,7 @@ TEST(StridedSliceOpTest, In1D_OutOfRangeEnd) {
}
TEST(StridedSliceOpTest, In1D_BeginMask) {
- StridedSliceOpModel m({4}, {1}, {1}, {1}, 1, 0, 0, 0, 0);
+ StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 1, 0, 0, 0, 0);
m.SetInput({1, 2, 3, 4});
m.SetBegin({1});
m.SetEnd({3});
@@ -159,7 +163,7 @@ TEST(StridedSliceOpTest, In1D_BeginMask) {
}
TEST(StridedSliceOpTest, In1D_NegativeBeginNegativeStride) {
- StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
+ StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
m.SetInput({1, 2, 3, 4});
m.SetBegin({-2});
m.SetEnd({-3});
@@ -170,7 +174,7 @@ TEST(StridedSliceOpTest, In1D_NegativeBeginNegativeStride) {
}
TEST(StridedSliceOpTest, In1D_OutOfRangeBeginNegativeStride) {
- StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
+ StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
m.SetInput({1, 2, 3, 4});
m.SetBegin({5});
m.SetEnd({2});
@@ -181,7 +185,7 @@ TEST(StridedSliceOpTest, In1D_OutOfRangeBeginNegativeStride) {
}
TEST(StridedSliceOpTest, In1D_NegativeEndNegativeStride) {
- StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
+ StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
m.SetInput({1, 2, 3, 4});
m.SetBegin({2});
m.SetEnd({-4});
@@ -192,7 +196,7 @@ TEST(StridedSliceOpTest, In1D_NegativeEndNegativeStride) {
}
TEST(StridedSliceOpTest, In1D_OutOfRangeEndNegativeStride) {
- StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
+ StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
m.SetInput({1, 2, 3, 4});
m.SetBegin({-3});
m.SetEnd({-5});
@@ -203,7 +207,7 @@ TEST(StridedSliceOpTest, In1D_OutOfRangeEndNegativeStride) {
}
TEST(StridedSliceOpTest, In1D_EndMask) {
- StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 1, 0, 0, 0);
+ StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 1, 0, 0, 0);
m.SetInput({1, 2, 3, 4});
m.SetBegin({1});
m.SetEnd({3});
@@ -214,7 +218,7 @@ TEST(StridedSliceOpTest, In1D_EndMask) {
}
TEST(StridedSliceOpTest, In1D_NegStride) {
- StridedSliceOpModel m({3}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
+ StridedSliceOpModel<> m({3}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
m.SetInput({1, 2, 3});
m.SetBegin({-1});
m.SetEnd({-4});
@@ -225,7 +229,7 @@ TEST(StridedSliceOpTest, In1D_NegStride) {
}
TEST(StridedSliceOpTest, In1D_EvenLenStride2) {
- StridedSliceOpModel m({2}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
+ StridedSliceOpModel<> m({2}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
m.SetInput({1, 2});
m.SetBegin({0});
m.SetEnd({2});
@@ -236,7 +240,7 @@ TEST(StridedSliceOpTest, In1D_EvenLenStride2) {
}
TEST(StridedSliceOpTest, In1D_OddLenStride2) {
- StridedSliceOpModel m({3}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
+ StridedSliceOpModel<> m({3}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
m.SetInput({1, 2, 3});
m.SetBegin({0});
m.SetEnd({3});
@@ -247,7 +251,7 @@ TEST(StridedSliceOpTest, In1D_OddLenStride2) {
}
TEST(StridedSliceOpTest, In2D_Identity) {
- StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 0);
+ StridedSliceOpModel<> m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 0);
m.SetInput({1, 2, 3, 4, 5, 6});
m.SetBegin({0, 0});
m.SetEnd({2, 3});
@@ -258,7 +262,7 @@ TEST(StridedSliceOpTest, In2D_Identity) {
}
TEST(StridedSliceOpTest, In2D) {
- StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 0);
+ StridedSliceOpModel<> m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 0);
m.SetInput({1, 2, 3, 4, 5, 6});
m.SetBegin({1, 0});
m.SetEnd({2, 2});
@@ -269,7 +273,7 @@ TEST(StridedSliceOpTest, In2D) {
}
TEST(StridedSliceOpTest, In2D_Stride2) {
- StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 0);
+ StridedSliceOpModel<> m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 0);
m.SetInput({1, 2, 3, 4, 5, 6});
m.SetBegin({0, 0});
m.SetEnd({2, 3});
@@ -280,7 +284,7 @@ TEST(StridedSliceOpTest, In2D_Stride2) {
}
TEST(StridedSliceOpTest, In2D_NegStride) {
- StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 0);
+ StridedSliceOpModel<> m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 0);
m.SetInput({1, 2, 3, 4, 5, 6});
m.SetBegin({1, -1});
m.SetEnd({2, -4});
@@ -291,7 +295,7 @@ TEST(StridedSliceOpTest, In2D_NegStride) {
}
TEST(StridedSliceOpTest, In2D_BeginMask) {
- StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 1, 0, 0, 0, 0);
+ StridedSliceOpModel<> m({2, 3}, {2}, {2}, {2}, 1, 0, 0, 0, 0);
m.SetInput({1, 2, 3, 4, 5, 6});
m.SetBegin({1, 0});
m.SetEnd({2, 2});
@@ -302,7 +306,7 @@ TEST(StridedSliceOpTest, In2D_BeginMask) {
}
TEST(StridedSliceOpTest, In2D_EndMask) {
- StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 2, 0, 0, 0);
+ StridedSliceOpModel<> m({2, 3}, {2}, {2}, {2}, 0, 2, 0, 0, 0);
m.SetInput({1, 2, 3, 4, 5, 6});
m.SetBegin({1, 0});
m.SetEnd({2, 2});
@@ -313,7 +317,7 @@ TEST(StridedSliceOpTest, In2D_EndMask) {
}
TEST(StridedSliceOpTest, In2D_NegStrideBeginMask) {
- StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 2, 0, 0, 0, 0);
+ StridedSliceOpModel<> m({2, 3}, {2}, {2}, {2}, 2, 0, 0, 0, 0);
m.SetInput({1, 2, 3, 4, 5, 6});
m.SetBegin({1, -2});
m.SetEnd({2, -4});
@@ -324,7 +328,7 @@ TEST(StridedSliceOpTest, In2D_NegStrideBeginMask) {
}
TEST(StridedSliceOpTest, In2D_NegStrideEndMask) {
- StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 2, 0, 0, 0);
+ StridedSliceOpModel<> m({2, 3}, {2}, {2}, {2}, 0, 2, 0, 0, 0);
m.SetInput({1, 2, 3, 4, 5, 6});
m.SetBegin({1, -2});
m.SetEnd({2, -3});
@@ -335,7 +339,7 @@ TEST(StridedSliceOpTest, In2D_NegStrideEndMask) {
}
TEST(StridedSliceOpTest, In3D_Identity) {
- StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 0);
+ StridedSliceOpModel<> m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 0);
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
m.SetBegin({0, 0, 0});
m.SetEnd({2, 3, 2});
@@ -347,7 +351,7 @@ TEST(StridedSliceOpTest, In3D_Identity) {
}
TEST(StridedSliceOpTest, In3D_NegStride) {
- StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 0);
+ StridedSliceOpModel<> m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 0);
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
m.SetBegin({-1, -1, -1});
m.SetEnd({-3, -4, -3});
@@ -359,7 +363,7 @@ TEST(StridedSliceOpTest, In3D_NegStride) {
}
TEST(StridedSliceOpTest, In3D_Strided2) {
- StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 0);
+ StridedSliceOpModel<> m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 0);
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
m.SetBegin({0, 0, 0});
m.SetEnd({2, 3, 2});
@@ -370,7 +374,7 @@ TEST(StridedSliceOpTest, In3D_Strided2) {
}
TEST(StridedSliceOpTest, In1D_ShrinkAxisMask1) {
- StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 1);
+ StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 1);
m.SetInput({1, 2, 3, 4});
m.SetBegin({1});
m.SetEnd({3});
@@ -381,7 +385,7 @@ TEST(StridedSliceOpTest, In1D_ShrinkAxisMask1) {
}
TEST(StridedSliceOpTest, In1D_EmptyOutputShrinkAxisMask1) {
- StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 1);
+ StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 1);
m.SetInput({1, 2, 3, 4});
m.SetBegin({2});
m.SetEnd({1});
@@ -392,7 +396,7 @@ TEST(StridedSliceOpTest, In1D_EmptyOutputShrinkAxisMask1) {
}
TEST(StridedSliceOpTest, In1D_BeginMaskShrinkAxisMask1) {
- StridedSliceOpModel m({4}, {1}, {1}, {1}, 1, 0, 0, 0, 1);
+ StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 1, 0, 0, 0, 1);
m.SetInput({1, 2, 3, 4});
m.SetBegin({1});
m.SetEnd({3});
@@ -403,7 +407,7 @@ TEST(StridedSliceOpTest, In1D_BeginMaskShrinkAxisMask1) {
}
TEST(StridedSliceOpTest, In1D_NegativeBeginNegativeStrideShrinkAxisMask1) {
- StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 1);
+ StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 1);
m.SetInput({1, 2, 3, 4});
m.SetBegin({-2});
m.SetEnd({-3});
@@ -414,7 +418,7 @@ TEST(StridedSliceOpTest, In1D_NegativeBeginNegativeStrideShrinkAxisMask1) {
}
TEST(StridedSliceOpTest, In2D_ShrinkAxisMask1) {
- StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 1);
+ StridedSliceOpModel<> m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 1);
m.SetInput({1, 2, 3, 4, 5, 6});
m.SetBegin({0, 0});
m.SetEnd({2, 3});
@@ -425,7 +429,7 @@ TEST(StridedSliceOpTest, In2D_ShrinkAxisMask1) {
}
TEST(StridedSliceOpTest, In2D_ShrinkAxisMask2) {
- StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 2);
+ StridedSliceOpModel<> m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 2);
m.SetInput({1, 2, 3, 4, 5, 6});
m.SetBegin({0, 0});
m.SetEnd({2, 3});
@@ -436,7 +440,7 @@ TEST(StridedSliceOpTest, In2D_ShrinkAxisMask2) {
}
TEST(StridedSliceOpTest, In2D_ShrinkAxisMask3) {
- StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 3);
+ StridedSliceOpModel<> m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 3);
m.SetInput({1, 2, 3, 4, 5, 6});
m.SetBegin({0, 0});
m.SetEnd({2, 3});
@@ -447,7 +451,7 @@ TEST(StridedSliceOpTest, In2D_ShrinkAxisMask3) {
}
TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis1) {
- StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 1);
+ StridedSliceOpModel<> m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 1);
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
m.SetBegin({0, 0, 0});
m.SetEnd({2, 3, 2});
@@ -458,7 +462,7 @@ TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis1) {
}
TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis2) {
- StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 2);
+ StridedSliceOpModel<> m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 2);
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
m.SetBegin({0, 0, 0});
m.SetEnd({2, 3, 2});
@@ -469,7 +473,7 @@ TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis2) {
}
TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis3) {
- StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 3);
+ StridedSliceOpModel<> m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 3);
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
m.SetBegin({0, 0, 0});
m.SetEnd({2, 3, 2});
@@ -480,7 +484,7 @@ TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis3) {
}
TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis4) {
- StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 4);
+ StridedSliceOpModel<> m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 4);
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
m.SetBegin({0, 0, 0});
m.SetEnd({2, 3, 2});
@@ -491,7 +495,7 @@ TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis4) {
}
TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis5) {
- StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 5);
+ StridedSliceOpModel<> m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 5);
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
m.SetBegin({0, 0, 0});
m.SetEnd({2, 3, 2});
@@ -502,7 +506,7 @@ TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis5) {
}
TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis6) {
- StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 6);
+ StridedSliceOpModel<> m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 6);
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
m.SetBegin({0, 0, 0});
m.SetEnd({2, 3, 2});
@@ -513,7 +517,7 @@ TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis6) {
}
TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis7) {
- StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 7);
+ StridedSliceOpModel<> m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 7);
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
m.SetBegin({0, 0, 0});
m.SetEnd({2, 3, 2});
@@ -525,7 +529,7 @@ TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis7) {
// This test catches a very subtle bug that was fixed by cl/188403234.
TEST(StridedSliceOpTest, RunTwice) {
- StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 1, 0, 0, 0, 0);
+ StridedSliceOpModel<> m({2, 3}, {2}, {2}, {2}, 1, 0, 0, 0, 0);
auto setup_inputs = [&m]() {
m.SetInput({1, 2, 3, 4, 5, 6});
@@ -544,6 +548,17 @@ TEST(StridedSliceOpTest, RunTwice) {
EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 4, 5}));
}
+TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis1Uint8) {
+ StridedSliceOpModel<uint8, TensorType_UINT8> m({2, 3, 2}, {3}, {3}, {3}, 0, 0,
+ 0, 0, 1);
+ m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+ m.SetBegin({0, 0, 0});
+ m.SetEnd({2, 3, 2});
+ m.SetStrides({1, 1, 1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 2}));
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
+}
} // namespace
} // namespace tflite
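
[Editor's note: the new uint8 test exercises the same slicing arithmetic as the float path; the shrink-axis behavior can be checked with numpy.]

import numpy as np

x = np.arange(1, 13, dtype=np.uint8).reshape(2, 3, 2)
# begin=(0, 0, 0), end=(2, 3, 2), strides=(1, 1, 1); shrink_axis_mask=1
# collapses axis 0, so plain indexing replaces the first slice range.
out = x[0, 0:3:1, 0:2:1]
print(out.shape)   # (3, 2)
print(out.ravel()) # [1 2 3 4 5 6]
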
diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc
index 791d1378f3..606f4a5635 100644
--- a/tensorflow/contrib/lite/model.cc
+++ b/tensorflow/contrib/lite/model.cc
@@ -32,6 +32,32 @@ namespace tflite {
const char* kEmptyTensorName = "";
+TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type,
+ ErrorReporter* error_reporter) {
+ switch (tensor_type) {
+ case TensorType_FLOAT32:
+ *type = kTfLiteFloat32;
+ break;
+ case TensorType_INT32:
+ *type = kTfLiteInt32;
+ break;
+ case TensorType_UINT8:
+ *type = kTfLiteUInt8;
+ break;
+ case TensorType_INT64:
+ *type = kTfLiteInt64;
+ break;
+ case TensorType_STRING:
+ *type = kTfLiteString;
+ break;
+ default:
+ error_reporter->Report("Unimplemented data type %s (%d) in tensor\n",
+ EnumNameTensorType(tensor_type), tensor_type);
+ return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
+
// Loads a model from `filename`. If `mmap_file` is true then use mmap,
// otherwise make a copy of the model in a buffer.
std::unique_ptr<Allocation> GetAllocationFromFile(const char* filename,
@@ -307,10 +333,25 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type,
case BuiltinOperator_EXP:
case BuiltinOperator_TOPK_V2:
case BuiltinOperator_LOG_SOFTMAX:
- case BuiltinOperator_CAST:
case BuiltinOperator_DEQUANTIZE:
case BuiltinOperator_PRELU:
break;
+ case BuiltinOperator_CAST: {
+ TfLiteCastParams* params = MallocPOD<TfLiteCastParams>();
+ if (auto* schema_params = op->builtin_options_as_CastOptions()) {
+ auto in_status =
+ ConvertTensorType(schema_params->in_data_type(),
+ &params->in_data_type, error_reporter);
+ auto out_status =
+ ConvertTensorType(schema_params->out_data_type(),
+ &params->out_data_type, error_reporter);
+ if (in_status != kTfLiteOk || out_status != kTfLiteOk) {
+ break;
+ }
+ }
+ builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
case BuiltinOperator_LSH_PROJECTION: {
TfLiteLSHProjectionParams* params =
MallocPOD<TfLiteLSHProjectionParams>();
@@ -707,29 +748,10 @@ TfLiteStatus InterpreterBuilder::ParseTensors(
}
TfLiteType type;
- switch (tensor->type()) {
- case TensorType_FLOAT32:
- type = kTfLiteFloat32;
- break;
- case TensorType_INT32:
- type = kTfLiteInt32;
- break;
- case TensorType_UINT8:
- type = kTfLiteUInt8;
- break;
- case TensorType_INT64:
- type = kTfLiteInt64;
- break;
- case TensorType_STRING:
- type = kTfLiteString;
- break;
- default:
- // tensorType = ArrayType::NONE;
- error_reporter_->Report("Unimplemented data type %s (%d) in tensor\n",
- EnumNameTensorType(tensor->type()),
- tensor->type());
- status = kTfLiteError;
- continue;
+ if (ConvertTensorType(tensor->type(), &type, error_reporter_) !=
+ kTfLiteOk) {
+ status = kTfLiteError;
+ continue;
}
auto get_readonly_data = [&](const char** buffer_data,
size_t* buffer_size) {
diff --git a/tensorflow/contrib/lite/nnapi_delegate.cc b/tensorflow/contrib/lite/nnapi_delegate.cc
index decaf9f160..bc13444dc7 100644
--- a/tensorflow/contrib/lite/nnapi_delegate.cc
+++ b/tensorflow/contrib/lite/nnapi_delegate.cc
@@ -162,7 +162,7 @@ void AddOpsAndParams(tflite::Interpreter* interpreter,
};
auto duplicate_state_tensor_float32 =
- [interpreter, &nn_model, &augmented_inputs, &next_id](int tensor_id) {
+ [interpreter, &nn_model, &augmented_inputs](int tensor_id) {
const TfLiteTensor* tensor = interpreter->tensor(tensor_id);
CHECK_NN(ANeuralNetworksModel_setOperandValue(
nn_model, tensor_id, tensor->data.raw, tensor->bytes));
diff --git a/tensorflow/contrib/lite/python/BUILD b/tensorflow/contrib/lite/python/BUILD
index e70aa51298..e735062a7f 100644
--- a/tensorflow/contrib/lite/python/BUILD
+++ b/tensorflow/contrib/lite/python/BUILD
@@ -101,6 +101,7 @@ py_test(
name = "convert_saved_model_test",
srcs = ["convert_saved_model_test.py"],
srcs_version = "PY2AND3",
+ tags = ["no_windows"],
visibility = ["//visibility:public"],
deps = [
":convert_saved_model",
diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs
index 7d2e00fe32..c63bfb28cc 100644
--- a/tensorflow/contrib/lite/schema/schema.fbs
+++ b/tensorflow/contrib/lite/schema/schema.fbs
@@ -381,6 +381,8 @@ table LogSoftmaxOptions {
}
table CastOptions {
+ in_data_type: TensorType;
+ out_data_type: TensorType;
}
table DequantizeOptions {
diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h
index 66a97a1460..0735be5c8f 100755
--- a/tensorflow/contrib/lite/schema/schema_generated.h
+++ b/tensorflow/contrib/lite/schema/schema_generated.h
@@ -3702,14 +3702,30 @@ flatbuffers::Offset<LogSoftmaxOptions> CreateLogSoftmaxOptions(flatbuffers::Flat
struct CastOptionsT : public flatbuffers::NativeTable {
typedef CastOptions TableType;
- CastOptionsT() {
+ TensorType in_data_type;
+ TensorType out_data_type;
+ CastOptionsT()
+ : in_data_type(TensorType_FLOAT32),
+ out_data_type(TensorType_FLOAT32) {
}
};
struct CastOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
typedef CastOptionsT NativeTableType;
+ enum {
+ VT_IN_DATA_TYPE = 4,
+ VT_OUT_DATA_TYPE = 6
+ };
+ TensorType in_data_type() const {
+ return static_cast<TensorType>(GetField<int8_t>(VT_IN_DATA_TYPE, 0));
+ }
+ TensorType out_data_type() const {
+ return static_cast<TensorType>(GetField<int8_t>(VT_OUT_DATA_TYPE, 0));
+ }
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
+ VerifyField<int8_t>(verifier, VT_IN_DATA_TYPE) &&
+ VerifyField<int8_t>(verifier, VT_OUT_DATA_TYPE) &&
verifier.EndTable();
}
CastOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@@ -3720,6 +3736,12 @@ struct CastOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
struct CastOptionsBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
+ void add_in_data_type(TensorType in_data_type) {
+ fbb_.AddElement<int8_t>(CastOptions::VT_IN_DATA_TYPE, static_cast<int8_t>(in_data_type), 0);
+ }
+ void add_out_data_type(TensorType out_data_type) {
+ fbb_.AddElement<int8_t>(CastOptions::VT_OUT_DATA_TYPE, static_cast<int8_t>(out_data_type), 0);
+ }
explicit CastOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
@@ -3733,8 +3755,12 @@ struct CastOptionsBuilder {
};
inline flatbuffers::Offset<CastOptions> CreateCastOptions(
- flatbuffers::FlatBufferBuilder &_fbb) {
+ flatbuffers::FlatBufferBuilder &_fbb,
+ TensorType in_data_type = TensorType_FLOAT32,
+ TensorType out_data_type = TensorType_FLOAT32) {
CastOptionsBuilder builder_(_fbb);
+ builder_.add_out_data_type(out_data_type);
+ builder_.add_in_data_type(in_data_type);
return builder_.Finish();
}
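
The regenerated builder accepts the two new schema fields as defaulted arguments. A minimal sketch of serializing CastOptions with explicit types (the builder variable is illustrative):

    flatbuffers::FlatBufferBuilder fbb;
    auto opts = tflite::CreateCastOptions(fbb, tflite::TensorType_INT32,
                                          tflite::TensorType_FLOAT32);
    // Omitted arguments fall back to the TensorType_FLOAT32 defaults declared
    // in schema.fbs.
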
@@ -5727,6 +5753,8 @@ inline CastOptionsT *CastOptions::UnPack(const flatbuffers::resolver_function_t
inline void CastOptions::UnPackTo(CastOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
(void)_o;
(void)_resolver;
+ { auto _e = in_data_type(); _o->in_data_type = _e; };
+ { auto _e = out_data_type(); _o->out_data_type = _e; };
}
inline flatbuffers::Offset<CastOptions> CastOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const CastOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
@@ -5737,8 +5765,12 @@ inline flatbuffers::Offset<CastOptions> CreateCastOptions(flatbuffers::FlatBuffe
(void)_rehasher;
(void)_o;
struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const CastOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ auto _in_data_type = _o->in_data_type;
+ auto _out_data_type = _o->out_data_type;
return tflite::CreateCastOptions(
- _fbb);
+ _fbb,
+ _in_data_type,
+ _out_data_type);
}
inline DequantizeOptionsT *DequantizeOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD
index d552de313c..8a35fb9034 100644
--- a/tensorflow/contrib/lite/toco/BUILD
+++ b/tensorflow/contrib/lite/toco/BUILD
@@ -234,6 +234,7 @@ cc_library(
"graph_transformations/identify_relu1.cc",
"graph_transformations/lstm_utils.cc",
"graph_transformations/make_initial_dequantize_operator.cc",
+ "graph_transformations/merge_reshape_into_preceding_transpose.cc",
"graph_transformations/propagate_activation_function_into_constants.cc",
"graph_transformations/propagate_array_data_types.cc",
"graph_transformations/propagate_fixed_sizes.cc",
@@ -251,7 +252,8 @@ cc_library(
"graph_transformations/remove_trivial_reshape.cc",
"graph_transformations/remove_trivial_slice.cc",
"graph_transformations/remove_unused_op.cc",
- "graph_transformations/reorder_activation_functions.cc",
+ "graph_transformations/reorder_elementwise_unary.cc",
+ "graph_transformations/reorder_reshape_transpose.cc",
"graph_transformations/resolve_batch_normalization.cc",
"graph_transformations/resolve_batch_to_space_nd_attributes.cc",
"graph_transformations/resolve_constant_binary.cc",
@@ -259,6 +261,7 @@ cc_library(
"graph_transformations/resolve_constant_fake_quant.cc",
"graph_transformations/resolve_constant_fill.cc",
"graph_transformations/resolve_constant_gather.cc",
+ "graph_transformations/resolve_constant_random_uniform.cc",
"graph_transformations/resolve_constant_range.cc",
"graph_transformations/resolve_constant_shape_or_rank.cc",
"graph_transformations/resolve_constant_stack.cc",
diff --git a/tensorflow/contrib/lite/toco/args.h b/tensorflow/contrib/lite/toco/args.h
index 52c789293c..39e49bc347 100644
--- a/tensorflow/contrib/lite/toco/args.h
+++ b/tensorflow/contrib/lite/toco/args.h
@@ -211,6 +211,7 @@ struct ParsedModelFlags {
Arg<bool> allow_nonexistent_arrays = Arg<bool>(false);
Arg<bool> allow_nonascii_arrays = Arg<bool>(false);
Arg<string> arrays_extra_info_file;
+ Arg<string> model_flags_file;
};
// Flags that describe the operation you would like to do (what conversion
diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.cc b/tensorflow/contrib/lite/toco/export_tensorflow.cc
index 22a23357b3..5d51431005 100644
--- a/tensorflow/contrib/lite/toco/export_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/export_tensorflow.cc
@@ -357,6 +357,14 @@ void ConvertConvOperator(const Model& model, const ConvOperator& src_op,
strides.mutable_list()->add_i(src_op.stride_height);
strides.mutable_list()->add_i(src_op.stride_width);
strides.mutable_list()->add_i(1);
+ if ((src_op.dilation_width_factor != 1) ||
+ (src_op.dilation_height_factor != 1)) {
+ auto& dilations = (*conv2d_op->mutable_attr())["dilations"];
+ dilations.mutable_list()->add_i(1);
+ dilations.mutable_list()->add_i(src_op.dilation_height_factor);
+ dilations.mutable_list()->add_i(src_op.dilation_width_factor);
+ dilations.mutable_list()->add_i(1);
+ }
string padding;
if (src_op.padding.type == PaddingType::kSame) {
padding = "SAME";
@@ -391,84 +399,6 @@ void ConvertConvOperator(const Model& model, const ConvOperator& src_op,
}
}
-void ConvertDilatedConvOperator(const Model& model, const ConvOperator& src_op,
- GraphDef* tensorflow_graph) {
- CHECK((src_op.dilation_width_factor > 1) ||
- (src_op.dilation_height_factor > 1))
- << "Conv operator must have height or width dilation factor > 1. "
- "Otherwise, use regular conv op.";
- CHECK_EQ(src_op.stride_width, 1)
- << "Dilated AND strided convolution is unsupported";
- CHECK_EQ(src_op.stride_height, 1)
- << "Dilated AND strided convolution is unsupported";
-
- // Emulate dilated convolution with a chain of SpaceToBatchND -> Conv ->
- // BatchToSpaceND ops.
-
- // Compute padding
- const auto& input_array = model.GetArray(src_op.inputs[0]);
- const auto& input_shape = input_array.shape();
- CHECK_EQ(input_shape.dimensions_count(), 4);
- int height_mod_dilation = input_shape.dims(1) % src_op.dilation_height_factor;
- int pad_height;
- if (height_mod_dilation) {
- pad_height = src_op.dilation_height_factor - height_mod_dilation;
- } else {
- pad_height = 0;
- }
- int pad_width;
- int width_mod_dilation = input_shape.dims(2) % src_op.dilation_width_factor;
- if (width_mod_dilation) {
- pad_width = src_op.dilation_width_factor - width_mod_dilation;
- } else {
- pad_width = 0;
- }
-
- // SpaceToBatchND op "collapses" the spatially separated elements together
- string stb_output = src_op.outputs[0] + "/dilated_conv_SpaceToBatch";
- auto* stb_op = tensorflow_graph->add_node();
- stb_op->set_op("SpaceToBatchND");
- stb_op->set_name(stb_output);
- *stb_op->add_input() = src_op.inputs[0];
- (*stb_op->mutable_attr())["T"].set_type(DT_FLOAT);
- string block_shape = src_op.outputs[0] + "/dilated_conv_block_shape";
- CreateIntTensorConst(
- block_shape,
- {src_op.dilation_height_factor, src_op.dilation_width_factor}, {2},
- tensorflow_graph);
- *stb_op->add_input() = block_shape;
- (*stb_op->mutable_attr())["Tblock_shape"].set_type(DT_INT32);
- string stb_paddings = src_op.outputs[0] + "/dilated_conv_paddings";
- CreateIntTensorConst(stb_paddings, {0, pad_height, pad_width, 0}, {2, 2},
- tensorflow_graph);
- *stb_op->add_input() = stb_paddings;
- (*stb_op->mutable_attr())["Tpaddings"].set_type(DT_INT32);
-
- // Perform a regular conv on the "collapsed" elements
- ConvOperator conv_op;
- string conv_output = src_op.outputs[0] + "/dilated_conv_Conv2D";
- conv_op.inputs = src_op.inputs;
- conv_op.inputs[0] = stb_output;
- conv_op.outputs = {conv_output};
- conv_op.padding.type = src_op.padding.type;
- conv_op.stride_width = src_op.stride_width;
- conv_op.stride_height = src_op.stride_height;
- conv_op.dilation_width_factor = 1;
- conv_op.dilation_height_factor = 1;
- ConvertConvOperator(model, conv_op, tensorflow_graph);
-
- // BatchToSpaceND op restores elements to their original layout
- auto* bts_op = tensorflow_graph->add_node();
- bts_op->set_op("BatchToSpaceND");
- bts_op->set_name(src_op.outputs[0]);
- *bts_op->add_input() = conv_output;
- (*bts_op->mutable_attr())["T"].set_type(DT_FLOAT);
- *bts_op->add_input() = block_shape;
- (*bts_op->mutable_attr())["Tblock_shape"].set_type(DT_INT32);
- *bts_op->add_input() = stb_paddings;
- (*bts_op->mutable_attr())["Tcrops"].set_type(DT_INT32);
-}
-
void ConvertDepthwiseConvOperator(const Model& model,
const DepthwiseConvOperator& src_op,
GraphDef* tensorflow_graph) {
@@ -1711,6 +1641,23 @@ void ConvertTopKV2Operator(const Model& model, const TopKV2Operator& src_op,
(*topk_op->mutable_attr())["sorted"].set_b(true);
}
+void ConvertRandomUniformOperator(const Model& model,
+ const RandomUniformOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ CHECK(tensorflow_graph != nullptr);
+ auto* new_op = tensorflow_graph->add_node();
+ new_op->set_op("RandomUniform");
+ CHECK_EQ(src_op.inputs.size(), 1);
+ new_op->set_name(src_op.outputs[0]);
+ *new_op->add_input() = src_op.inputs[0];
+ const auto shape_type = GetTensorFlowDataType(model, src_op.inputs[0]);
+ (*new_op->mutable_attr())["T"].set_type(shape_type);
+ (*new_op->mutable_attr())["dtype"].set_type(
+ GetTensorFlowDataType(src_op.dtype));
+ (*new_op->mutable_attr())["seed"].set_i(src_op.seed);
+ (*new_op->mutable_attr())["seed2"].set_i(src_op.seed2);
+}
+
void ConvertOperator(const Model& model, const Operator& src_op,
GraphDef* tensorflow_graph) {
if (src_op.fused_activation_function != FusedActivationFunctionType::kNone) {
@@ -1719,13 +1666,8 @@ void ConvertOperator(const Model& model, const Operator& src_op,
}
if (src_op.type == OperatorType::kConv) {
- const ConvOperator& conv_op = static_cast<const ConvOperator&>(src_op);
- if ((conv_op.dilation_width_factor != 1) ||
- (conv_op.dilation_height_factor != 1)) {
- return ConvertDilatedConvOperator(model, conv_op, tensorflow_graph);
- } else {
- ConvertConvOperator(model, conv_op, tensorflow_graph);
- }
+ ConvertConvOperator(model, static_cast<const ConvOperator&>(src_op),
+ tensorflow_graph);
} else if (src_op.type == OperatorType::kDepthwiseConv) {
ConvertDepthwiseConvOperator(
model, static_cast<const DepthwiseConvOperator&>(src_op),
@@ -1897,6 +1839,10 @@ void ConvertOperator(const Model& model, const Operator& src_op,
ConvertTransposeConvOperator(
model, static_cast<const TransposeConvOperator&>(src_op),
tensorflow_graph);
+ } else if (src_op.type == OperatorType::kRandomUniform) {
+ ConvertRandomUniformOperator(
+ model, static_cast<const RandomUniformOperator&>(src_op),
+ tensorflow_graph);
} else {
LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(src_op.type);
}
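
With the hunks above, a dilated convolution is exported as a plain Conv2D carrying the "dilations" attribute, replacing the SpaceToBatchND -> Conv2D -> BatchToSpaceND chain deleted earlier in this file. For illustrative dilation factors (2, 2) with unit strides, the emitted node's attributes would be:

    strides   = [1, 1, 1, 1]
    dilations = [1, 2, 2, 1]   // NHWC order: batch, height, width, channel
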
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc b/tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc
index d38db85280..0fffab574d 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc
@@ -33,6 +33,11 @@ bool ConvertPureConvToDepthwise::Run(Model* model, std::size_t op_index) {
if (conv_op->stride_width != conv_op->stride_height) {
return false;
}
+ if ((conv_op->dilation_width_factor != 1) ||
+ (conv_op->dilation_height_factor != 1)) {
+ // Depthwise conv does not support dilation
+ return false;
+ }
auto& weights_array = model->GetArray(conv_op->inputs[1]);
if (!weights_array.buffer) {
// Yield until the weights are resolved as a constant array.
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
index 640afc7c74..27c5044bb3 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
+++ b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
@@ -128,6 +128,7 @@ DECLARE_GRAPH_TRANSFORMATION(IdentifyL2Pool)
DECLARE_GRAPH_TRANSFORMATION(IdentifyLstmCell)
DECLARE_GRAPH_TRANSFORMATION(SplitLstmCellInputs)
DECLARE_GRAPH_TRANSFORMATION(MergeLstmCellInputs)
+DECLARE_GRAPH_TRANSFORMATION(MergeReshapeIntoPrecedingTranspose)
DECLARE_GRAPH_TRANSFORMATION(IdentifyRelu1)
DECLARE_GRAPH_TRANSFORMATION(IdentifyPRelu)
DECLARE_GRAPH_TRANSFORMATION(IdentifyDilatedConv)
@@ -152,7 +153,8 @@ DECLARE_GRAPH_TRANSFORMATION(ResolveConstantUnaryOperator)
DECLARE_GRAPH_TRANSFORMATION(CreateIm2colArrays)
DECLARE_GRAPH_TRANSFORMATION(DropIm2colArrays)
DECLARE_GRAPH_TRANSFORMATION(ReadFakeQuantMinMax)
-DECLARE_GRAPH_TRANSFORMATION(ReorderActivationFunctions)
+DECLARE_GRAPH_TRANSFORMATION(ReorderElementwiseUnary)
+DECLARE_GRAPH_TRANSFORMATION(ReorderReshapeTranspose)
DECLARE_GRAPH_TRANSFORMATION(ResolveReorderAxes)
DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowConcat)
DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowMatMul)
@@ -173,6 +175,7 @@ DECLARE_GRAPH_TRANSFORMATION(ResolveStridedSliceAttributes)
DECLARE_GRAPH_TRANSFORMATION(ResolveSliceAttributes)
DECLARE_GRAPH_TRANSFORMATION(ResolveMeanAttributes)
DECLARE_GRAPH_TRANSFORMATION(ResolveTransposeAttributes)
+DECLARE_GRAPH_TRANSFORMATION(ResolveConstantRandomUniform)
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantRange)
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantShapeOrRank)
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantStack)
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc b/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc
index 7c97ef0d31..23c9e3246b 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc
@@ -223,8 +223,11 @@ bool PropagateMinMaxAmongArrays(Model* model,
if (array.minmax) {
CHECK(*array.minmax == *reference_minmax)
<< "Both the following arrays have minmax, and they disagree: "
- << reference_array_name << " and " << array_name
- << ". Expected that either only one of them would have minmax, or at "
+ << reference_array_name << " (" << reference_minmax->min << ","
+ << reference_minmax->max << ") and " << array_name << " ("
+ << array.minmax->min << "," << array.minmax->max
+ << "). Expected that either only one of them would have minmax, or "
+ "at "
"least that they would agree.";
} else {
array.GetOrCreateMinMax() = *reference_minmax;
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/merge_reshape_into_preceding_transpose.cc b/tensorflow/contrib/lite/toco/graph_transformations/merge_reshape_into_preceding_transpose.cc
new file mode 100644
index 0000000000..5065004093
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/merge_reshape_into_preceding_transpose.cc
@@ -0,0 +1,190 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <algorithm>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/runtime/types.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+namespace {
+
+bool OperatorReady(const Model& model, const Operator* op) {
+ if (!model.HasArray(op->inputs[0]) || !model.HasArray(op->inputs[1]) ||
+ !model.HasArray(op->outputs[0])) {
+ // Arrays are missing.
+ return false;
+ }
+
+ if (!model.GetArray(op->inputs[0]).has_shape() ||
+ !model.GetArray(op->outputs[0]).has_shape()) {
+ // Input and output need their shapes known.
+ return false;
+ }
+
+ if (!model.GetArray(op->inputs[1]).buffer) {
+ // Buffer needs to be a constant.
+ return false;
+ }
+
+ return true;
+}
+
+// Returns the permutation that makes the reshape equivalent to a transpose.
+std::vector<int32> ReshapeToTranspose(const Model& model,
+ const TensorFlowReshapeOperator* op) {
+ CHECK(!op->shape.empty());
+ CHECK(model.HasArray(op->inputs[0]));
+ CHECK(model.HasArray(op->outputs[0]));
+
+ const auto& input_array = model.GetArray(op->inputs[0]);
+ const auto& output_array = model.GetArray(op->outputs[0]);
+
+ CHECK(input_array.has_shape());
+ CHECK(output_array.has_shape());
+
+ std::vector<int> in_shape = input_array.shape().dims();
+ std::vector<int> out_shape = output_array.shape().dims();
+
+ std::vector<int> one_indices;
+ std::vector<int> not_one_indices;
+
+ // Separate the input indices into size-one and larger-than-one dimensions.
+ for (int i = 0; i < in_shape.size(); i++) {
+ if (in_shape[i] == 1) {
+ one_indices.push_back(i);
+ } else {
+ not_one_indices.push_back(i);
+ }
+ }
+
+ // Build the permutation by walking the output shape.
+ std::vector<int> perm;
+ perm.reserve(in_shape.size());
+ int one_index = 0;
+ int not_one_index = 0;
+ for (const auto val : out_shape) {
+ if (val == 1) {
+ perm.push_back(one_indices[one_index]);
+ one_index++;
+ } else {
+ perm.push_back(not_one_indices[not_one_index]);
+ not_one_index++;
+ }
+ }
+
+ return perm;
+}
+
+} // namespace
+
+// When a transpose is fed into a reshape, the two operators can be merged if
+// the reshape does not affect memory ordering and does not change the number
+// of dimensions. This only occurs when nothing but unary (size-1) dimensions
+// shift position.
+bool MergeReshapeIntoPrecedingTranspose::Run(Model* model,
+ std::size_t op_index) {
+ auto it = model->operators.begin() + op_index;
+ auto* reshape_op = ConvertOperator<TensorFlowReshapeOperator*>(
+ it->get(), OperatorType::kTensorFlowReshape);
+
+ if (reshape_op == nullptr) {
+ return false;
+ }
+
+ if (!OperatorReady(*model, reshape_op) || reshape_op->shape.empty()) {
+ return false;
+ }
+
+ const string intermediate_name = reshape_op->inputs[0];
+ const string output_name = reshape_op->outputs[0];
+
+ // Guarantee the input is consumed only by the reshape.
+ if (CountOpsWithInput(*model, intermediate_name) != 1) {
+ return false;
+ }
+
+ // Check for the parent operator.
+ const auto& transpose_it = FindOpWithOutput(*model, intermediate_name);
+ if (transpose_it == model->operators.end()) {
+ return false;
+ }
+
+ // Find the parent operator and guarantee it is a transpose.
+ TransposeOperator* transpose_op = ConvertOperator<TransposeOperator*>(
+ transpose_it->get(), OperatorType::kTranspose);
+
+ if (transpose_op == nullptr) {
+ return false;
+ }
+
+ if (!OperatorReady(*model, transpose_op) || transpose_op->perm.empty()) {
+ return false;
+ }
+
+ if (!ReshapeIsEquivalentToTranspose(*model, reshape_op,
+ false /*allow_extra_unary_dimensions*/)) {
+ return false;
+ }
+
+ // Check that the intermediate is not an output array.
+ if (!IsDiscardableArray(*model, intermediate_name)) {
+ AddMessageF(
+ "Cannot fuse %s and %s as it would invalidate the transpose "
+ "output array.",
+ LogName(*transpose_op), LogName(*reshape_op));
+ return false;
+ }
+
+ AddMessageF("Merging operations %s and %s", LogName(*transpose_op),
+ LogName(*reshape_op));
+
+ auto merged_perm = ReshapeToTranspose(*model, reshape_op);
+
+ // Combine the permutations.
+ const auto& transpose_perm = transpose_op->perm;
+ for (int i = 0; i < merged_perm.size(); i++) {
+ merged_perm[i] = transpose_perm[merged_perm[i]];
+ }
+
+ // Remove the reshape as a passthrough operation.
+ if (!RemoveTrivialPassthroughOp(this, model, op_index)) {
+ return false;
+ }
+
+ // Update transpose_op's constant buffer to contain the new permutation.
+ model->GetArray(transpose_op->inputs[1])
+ .GetMutableBuffer<ArrayDataType::kInt32>()
+ .data = merged_perm;
+ transpose_op->perm = merged_perm;
+
+ // transpose_op's output shape has likely changed.
+ model->GetArray(transpose_op->outputs[0]).clear_shape();
+
+ return true;
+}
+
+} // namespace toco
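
A worked example of the merge, with illustrative shapes: a Transpose with perm {2, 0, 1} feeding a Reshape from [1, 8, 1] to [8, 1, 1]. Only unary dimensions move, so ReshapeToTranspose derives perm {1, 0, 2} for the reshape, and the combination loop above composes the two:

    // merged_perm[i] = transpose_perm[merged_perm[i]]
    // merged_perm = {transpose_perm[1], transpose_perm[0], transpose_perm[2]}
    //             = {0, 2, 1}

The reshape is then dropped as a passthrough, and the transpose's constant perm buffer is rewritten to {0, 2, 1}.
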
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
index 778da39bf1..89ad58f887 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
@@ -50,78 +50,108 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
old_output_data_types[output] = model->GetArray(output).data_type;
}
// Do the actual output data types propagation.
- if (op->type == OperatorType::kDequantize ||
- op->type == OperatorType::kResizeBilinear) {
- // These operators unconditionally produce float outputs
- SetDataTypeForAllOutputs(model, op, ArrayDataType::kFloat);
- } else if (op->type == OperatorType::kTensorFlowLess ||
- op->type == OperatorType::kTensorFlowLessEqual ||
- op->type == OperatorType::kTensorFlowGreater ||
- op->type == OperatorType::kTensorFlowGreaterEqual) {
- // These operators unconditionally produce bool outputs
- SetDataTypeForAllOutputs(model, op, ArrayDataType::kBool);
- } else if (op->type == OperatorType::kRank ||
- op->type == OperatorType::kTensorFlowShape) {
- // These operators only produce int32 outputs.
- SetDataTypeForAllOutputs(model, op, ArrayDataType::kInt32);
- } else if (op->type == OperatorType::kTensorFlowSplit ||
- op->type == OperatorType::kTensorFlowConcat ||
- op->type == OperatorType::kFill) {
- // These operators produce an output with the same type as their 2nd input
- CHECK_GE(op->inputs.size(), 2);
- const ArrayDataType data_type = model->GetArray(op->inputs[1]).data_type;
- SetDataTypeForAllOutputs(model, op, data_type);
- } else if (op->type == OperatorType::kTransposeConv) {
- // These operators produce an output with the same type as their 3rd input
- CHECK_GE(op->inputs.size(), 3);
- const ArrayDataType data_type = model->GetArray(op->inputs[2]).data_type;
- SetDataTypeForAllOutputs(model, op, data_type);
- } else if (op->type == OperatorType::kCast) {
- // Data type of the Cast op is specified.
- CHECK_EQ(op->outputs.size(), 1);
- auto* cast_op = static_cast<CastOperator*>(op);
- model->GetArray(op->outputs[0]).data_type = cast_op->dst_data_type;
- } else if (op->type == OperatorType::kArgMax) {
- // Data type of the ArgMax op is specified.
- CHECK_EQ(op->outputs.size(), 1);
- auto* argmax_op = static_cast<ArgMaxOperator*>(op);
- model->GetArray(op->outputs[0]).data_type = argmax_op->output_data_type;
- } else if (op->type == OperatorType::kRange) {
- auto* range_op = static_cast<RangeOperator*>(op);
- // Output type of the Range op can be set via an attribute
- ArrayDataType data_type;
- if (range_op->dtype != ArrayDataType::kNone) {
- // Use the type if specified
- data_type = range_op->dtype;
- } else {
- // Otherwise use the first input
- CHECK_GE(op->inputs.size(), 1);
- data_type = model->GetArray(op->inputs[0]).data_type;
+ switch (op->type) {
+ case OperatorType::kDequantize:
+ case OperatorType::kResizeBilinear:
+ // These operators unconditionally produce float outputs
+ SetDataTypeForAllOutputs(model, op, ArrayDataType::kFloat);
+ break;
+ case OperatorType::kTensorFlowLess:
+ case OperatorType::kTensorFlowLessEqual:
+ case OperatorType::kTensorFlowGreater:
+ case OperatorType::kTensorFlowGreaterEqual:
+ // These operators unconditionally produce bool outputs
+ SetDataTypeForAllOutputs(model, op, ArrayDataType::kBool);
+ break;
+ case OperatorType::kRank:
+ case OperatorType::kTensorFlowShape:
+ // These operators only produce int32 outputs.
+ SetDataTypeForAllOutputs(model, op, ArrayDataType::kInt32);
+ break;
+ case OperatorType::kTensorFlowSplit:
+ case OperatorType::kTensorFlowConcat:
+ case OperatorType::kFill: {
+ // These operators produce an output with the same type as their 2nd input
+ CHECK_GE(op->inputs.size(), 2);
+ const ArrayDataType data_type = model->GetArray(op->inputs[1]).data_type;
+ SetDataTypeForAllOutputs(model, op, data_type);
+ break;
}
- CHECK_EQ(op->outputs.size(), 1);
- SetDataTypeForAllOutputs(model, op, data_type);
- } else if (op->type == OperatorType::kTensorFlowUnsupported) {
- auto* unsupported_op = static_cast<TensorFlowUnsupportedOperator*>(op);
- // Some output tensors from the op could be eliminated by optimization.
- // This can make unsupported_op->output_data_types have more elements than
- // op->outputs.
- if (unsupported_op->output_data_types.size() < op->outputs.size()) {
+ case OperatorType::kTransposeConv: {
+ // These operators produce an output with the same type as their 3rd input
+ CHECK_GE(op->inputs.size(), 3);
+ const ArrayDataType data_type = model->GetArray(op->inputs[2]).data_type;
+ SetDataTypeForAllOutputs(model, op, data_type);
+ break;
+ }
+ case OperatorType::kCast: {
+ // Data type of the Cast op is specified.
+ CHECK_EQ(op->outputs.size(), 1);
+ auto* cast_op = static_cast<CastOperator*>(op);
+ model->GetArray(op->outputs[0]).data_type = cast_op->dst_data_type;
+ break;
+ }
+ case OperatorType::kArgMax: {
+ // Data type of the ArgMax op is specified.
+ CHECK_EQ(op->outputs.size(), 1);
+ auto* argmax_op = static_cast<ArgMaxOperator*>(op);
+ model->GetArray(op->outputs[0]).data_type = argmax_op->output_data_type;
+ break;
+ }
+ case OperatorType::kRange: {
+ auto* range_op = static_cast<RangeOperator*>(op);
+ // Output type of the Range op can be set via an attribute
+ ArrayDataType data_type;
+ if (range_op->dtype != ArrayDataType::kNone) {
+ // Use the type if specified
+ data_type = range_op->dtype;
+ } else {
+ // Otherwise use the first input
+ CHECK_GE(op->inputs.size(), 1);
+ data_type = model->GetArray(op->inputs[0]).data_type;
+ }
+ CHECK_EQ(op->outputs.size(), 1);
+ SetDataTypeForAllOutputs(model, op, data_type);
+ break;
+ }
+ case OperatorType::kRandomUniform: {
+ auto* rand_op = static_cast<RandomUniformOperator*>(op);
+ // The output type of RandomUniform is specified with an attribute
+ if (rand_op->dtype == ArrayDataType::kNone) {
+ return false;
+ }
+ CHECK_EQ(op->outputs.size(), 1);
+ SetDataTypeForAllOutputs(model, op, rand_op->dtype);
+ break;
+ }
+ case OperatorType::kTensorFlowUnsupported: {
+ auto* unsupported_op = static_cast<TensorFlowUnsupportedOperator*>(op);
+ // Some output tensors from the op could be eliminated by optimization.
+ // This can make unsupported_op->output_data_types have more elements than
+ // op->outputs.
+ if (unsupported_op->output_data_types.size() < op->outputs.size()) {
+ return false;
+ }
+ for (int i = 0; i < op->outputs.size(); ++i) {
+ auto output = op->outputs[i];
+ auto data_type = unsupported_op->output_data_types[i];
+ model->GetArray(output).data_type = data_type;
+ }
+ break;
+ }
+ case OperatorType::kExpandDims: {
+ // Yield on ExpandDims until it is converted to Reshape.
return false;
}
- for (int i = 0; i < op->outputs.size(); ++i) {
- auto output = op->outputs[i];
- auto data_type = unsupported_op->output_data_types[i];
- model->GetArray(output).data_type = data_type;
+ default: {
+ // These operators produce outputs with the same type as their 1st input
+ CHECK_GT(op->inputs.size(), 0);
+ const ArrayDataType data_type = model->GetArray(op->inputs[0]).data_type;
+ SetDataTypeForAllOutputs(model, op, data_type);
+ break;
}
- } else if (op->type == OperatorType::kExpandDims) {
- // Yield on ExpandDim until it is converted to Reshape
- return false;
- } else {
- // These operators produce outputs with the same type as their 1st input
- CHECK_GT(op->inputs.size(), 0);
- const ArrayDataType data_type = model->GetArray(op->inputs[0]).data_type;
- SetDataTypeForAllOutputs(model, op, data_type);
}
+
// Return true if any output data type changed, false if none changed.
for (const auto& output : op->outputs) {
if (old_output_data_types[output] != model->GetArray(output).data_type) {
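
In summary, the rewritten switch preserves the propagation rules of the if-chain it replaces; sketched with illustrative operators:

    // Cast:          outputs[0]  <- cast_op->dst_data_type
    // ArgMax:        outputs[0]  <- argmax_op->output_data_type
    // RandomUniform: all outputs <- rand_op->dtype (yields while still kNone)
    // default:       all outputs <- inputs[0].data_type
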
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
index 676736cfc5..68d6f21cf8 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
@@ -38,6 +38,16 @@ void ComputeConvSizes(const Shape& input_shape, int output_depth, int kwidth,
const int input_height = input_shape.dims(1);
const int batch = input_shape.dims(0);
+ CHECK_GE(input_width, 1);
+ CHECK_GE(input_height, 1);
+ CHECK_GE(batch, 1);
+ CHECK_GE(kwidth, 1);
+ CHECK_GE(kheight, 1);
+ CHECK_GE(stride_width, 1);
+ CHECK_GE(stride_height, 1);
+ CHECK_GE(dilation_width_factor, 1);
+ CHECK_GE(dilation_height_factor, 1);
+
int dilated_kwidth = dilation_width_factor * (kwidth - 1) + 1;
int dilated_kheight = dilation_height_factor * (kheight - 1) + 1;
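
The dilated kernel extent computed above follows k_eff = d * (k - 1) + 1; for example, a 3x3 kernel with dilation factor 2 covers a 5x5 input window (2 * (3 - 1) + 1 = 5), which then feeds the usual output-size computation.
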
@@ -392,8 +402,7 @@ void ProcessSpaceToDepthOperator(Model* model, SpaceToDepthOperator* op) {
depth * block_size * block_size}));
}
-void ProcessFillOperator(Model* model, FillOperator* op) {
- CHECK_EQ(op->inputs.size(), 2);
+void ProcessOpWithShapeInput(Model* model, Operator* op) {
CHECK_EQ(op->outputs.size(), 1);
auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.has_shape()) {
@@ -1529,7 +1538,8 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
static_cast<SpaceToDepthOperator*>(op));
break;
case OperatorType::kFill:
- ProcessFillOperator(model, static_cast<FillOperator*>(op));
+ CHECK_EQ(op->inputs.size(), 2);
+ ProcessOpWithShapeInput(model, op);
break;
case OperatorType::kFullyConnected:
ProcessFullyConnectedOperator(model,
@@ -1659,6 +1669,10 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
// transforms that remove them, so we avoid propagating shapes through
// them and let things settle once they've been removed.
break;
+ case OperatorType::kRandomUniform:
+ CHECK_EQ(op->inputs.size(), 1);
+ ProcessOpWithShapeInput(model, op);
+ break;
default:
// Unimplemented, another graph transformation should drop it.
LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(op->type);
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc
index 9fcc95e1fe..7784558b22 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc
@@ -472,6 +472,44 @@ bool ChooseQuantizationForOperatorOutput(
return true;
}
+
+// Fixes array minmax info to match the quantization parameters.
+// This is required when quantization parameters change for an array during
+// quantization (as in ChooseQuantizationForOperatorOutput).
+void FixMinMaxPostQuantization(ArrayDataType quantized_data_type,
+ const QuantizationParams& quantization_params,
+ MinMax* minmax) {
+ double qmin, qmax;
+ switch (quantized_data_type) {
+ case ArrayDataType::kUint8:
+ qmin = 0;
+ qmax = 255;
+ break;
+ case ArrayDataType::kInt16:
+ qmin = -32768;
+ qmax = 32767;
+ break;
+ default:
+ // No update required.
+ return;
+ }
+
+ // Compute new minmax values.
+ double min =
+ (qmin - quantization_params.zero_point) * quantization_params.scale;
+ double max =
+ (qmax - quantization_params.zero_point) * quantization_params.scale;
+
+ // If the new values are close to the existing minmax values, don't change them.
+ // This prevents propagating small floating point precision errors.
+ constexpr double kMinMaxThreshold = 1e-5;
+ const double width = max - min;
+ if (std::abs(min - minmax->min) > kMinMaxThreshold * width ||
+ std::abs(max - minmax->max) > kMinMaxThreshold * width) {
+ minmax->min = min;
+ minmax->max = max;
+ }
+}
} // namespace
bool Quantize::Run(Model* model, std::size_t op_index) {
@@ -618,12 +656,19 @@ bool Quantize::Run(Model* model, std::size_t op_index) {
&quantization_params)) {
changed = true;
const auto& output = op.outputs[output_index];
+ auto& output_array = model->GetArray(output);
+
+ // Fix up the min/max information on the output array to match the chosen
+ // quantization parameters.
+ auto& output_minmax = output_array.GetMinMax();
+ FixMinMaxPostQuantization(quantized_data_type, quantization_params,
+ &output_minmax);
+
QuantizeArray(this, model, output, quantized_data_type,
quantization_params);
+
const auto& dequantized_output =
AvailableArrayName(*model, output + "_dequantized");
- const auto& output_array = model->GetArray(output);
- const auto& output_minmax = output_array.GetMinMax();
auto& dequantized_output_array =
model->GetOrCreateArray(dequantized_output);
dequantized_output_array.data_type = ArrayDataType::kFloat;
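
A worked example of FixMinMaxPostQuantization, with illustrative quantization parameters:

    // quantized_data_type = kUint8, zero_point = 128, scale = 0.02
    // min = (0   - 128) * 0.02 = -2.56
    // max = (255 - 128) * 0.02 =  2.54
    // The stored minmax is snapped to [-2.56, 2.54] unless it already agrees
    // to within 1e-5 of the range width.
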
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/read_fake_quant_min_max.cc b/tensorflow/contrib/lite/toco/graph_transformations/read_fake_quant_min_max.cc
index 11f8d4b6ee..bdcca5b7ca 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/read_fake_quant_min_max.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/read_fake_quant_min_max.cc
@@ -72,6 +72,13 @@ bool ReadFakeQuantMinMax::Run(Model* model, std::size_t op_index) {
minmax.min = min_array.GetBuffer<ArrayDataType::kFloat>().data[0];
minmax.max = max_array.GetBuffer<ArrayDataType::kFloat>().data[0];
// We always want [min, max] to contain 0.
+ if (minmax.min > 0 || minmax.max < 0) {
+ LOG(ERROR) << "For " << LogName(*fakequant_op) << " the MinMax range "
+ << "[" << minmax.min << ", " << minmax.max
+ << "] does not contain 0. "
+ << "Proceeding by tweaking it to contain 0, which will result "
+ "in poor accuracy.";
+ }
minmax.min = std::min(minmax.min, 0.);
minmax.max = std::max(minmax.max, 0.);
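
For example, a FakeQuant range of [0.5, 2.0] now triggers the error log above and is then widened to [0, 2.0] by the std::min/std::max clamping, trading accuracy for a representable zero.
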
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/reorder_activation_functions.cc b/tensorflow/contrib/lite/toco/graph_transformations/reorder_activation_functions.cc
deleted file mode 100644
index 9852c86c21..0000000000
--- a/tensorflow/contrib/lite/toco/graph_transformations/reorder_activation_functions.cc
+++ /dev/null
@@ -1,137 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include <memory>
-#include <string>
-#include <unordered_map>
-#include <vector>
-
-#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
-#include "tensorflow/contrib/lite/toco/model.h"
-#include "tensorflow/contrib/lite/toco/runtime/types.h"
-#include "tensorflow/contrib/lite/toco/tooling_util.h"
-#include "tensorflow/core/platform/logging.h"
-
-namespace toco {
-
-bool ReorderActivationFunctions::Run(Model* model, std::size_t op_index) {
- const auto ac_it = model->operators.begin() + op_index;
- std::unique_ptr<Operator>& ac_op = *ac_it;
- DCHECK(ac_op);
-
- if (ac_op->type != OperatorType::kRelu6 &&
- ac_op->type != OperatorType::kRelu1 &&
- ac_op->type != OperatorType::kRelu) {
- return false;
- }
-
- auto exchange_it = FindOpWithOutput(*model, ac_op->inputs[0]);
- if (exchange_it == model->operators.end()) return false;
- // Find the op producing the array passed to this activation function
- std::unique_ptr<Operator>& exchange_op = *exchange_it;
- DCHECK(exchange_op);
-
- // Allow activation functions to move up over any operator that does not
- // change the values.
- switch (exchange_op->type) {
- case OperatorType::kExpandDims:
- case OperatorType::kSqueeze:
- case OperatorType::kTensorFlowReshape:
- case OperatorType::kTranspose:
- break;
- default:
- return false;
- }
-
- DCHECK_EQ(exchange_op->outputs[0], ac_op->inputs[0]);
- const auto exchange_op_input = exchange_op->inputs[0];
- const auto intermediate_array = exchange_op->outputs[0];
- const auto ac_op_output = ac_op->outputs[0];
-
- int count_ops_consuming_output =
- CountOpsWithInput(*model, intermediate_array);
- DCHECK_GE(count_ops_consuming_output, 1);
- if (count_ops_consuming_output > 1) {
- AddMessageF(
- "Not exchanging activation function with %s because it is consumed by "
- "more than 1 other operator",
- LogName(*exchange_op));
- return false;
- }
-
- // If the ac_op was originally producing an output_array we can't trivially
- // reorder as otherwise the output array name would change and break
- // downstream assumptions. To work around that we perform some renaming below
- // in that case at the cost of a bit more confusing array names in this rare
- // case.
- bool is_ac_op_output =
- std::find(model->flags.output_arrays().begin(),
- model->flags.output_arrays().end(),
- ac_op_output) != model->flags.output_arrays().end();
- if (is_ac_op_output) {
- // To preserve the output array name of the activation function we need to
- // create a temporary to use to pass between ac->ex.
- //
- // Original:
- // (a) -> EX -> (b) -> AC -> (c)
- // Now:
- // (a) -> AC -> (c') -> EX -> (c)
- AddMessageF(
- "Exchanging activation function %s with %s but renaming to preserve "
- "output array %s",
- LogName(*ac_op), LogName(*exchange_op), ac_op->outputs[0]);
-
- auto renamed_ac_op_output =
- AvailableArrayName(*model, ac_op_output + "_exchange");
- ac_op->inputs[0] = exchange_op_input;
- ac_op->outputs[0] = renamed_ac_op_output;
- model->EraseArray(exchange_op->outputs[0]);
- exchange_op->inputs[0] = renamed_ac_op_output;
- exchange_op->outputs[0] = ac_op_output;
- } else {
- // Simply swap the order and update consumers to use the exchange_op output
- // array (b).
- //
- // Original:
- // (a) -> EX -> (b) -> AC -> (c)
- // Now:
- // (a) -> AC -> (c) -> EX -> (b)
- AddMessageF("Exchanging activation function %s with %s", LogName(*ac_op),
- LogName(*exchange_op));
-
- Operator* consumer = GetFirstOpWithInput(*model, ac_op_output);
- while (consumer) {
- for (int i = 0; i < consumer->inputs.size(); ++i) {
- if (consumer->inputs[i] == ac_op_output) {
- consumer->inputs[i] = intermediate_array;
- }
- }
- consumer = GetFirstOpWithInput(*model, ac_op_output);
- }
- ac_op->inputs[0] = exchange_op_input;
- exchange_op->inputs[0] = ac_op_output;
- }
-
- // Clear shapes; this will allow shape propagation to fix the sizes for us.
- model->GetOrCreateArray(ac_op->outputs[0]).clear_shape();
- model->GetOrCreateArray(exchange_op->outputs[0]).clear_shape();
-
- // Finally, reorder operators. Note that this only works when there are no
- // other direct descendents of the exchange_op.
- ac_op.swap(exchange_op);
-
- return true;
-}
-
-} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/reorder_elementwise_unary.cc b/tensorflow/contrib/lite/toco/graph_transformations/reorder_elementwise_unary.cc
new file mode 100644
index 0000000000..9f5b7920cb
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/reorder_elementwise_unary.cc
@@ -0,0 +1,153 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <iterator>
+#include <memory>
+#include <string>
+#include <unordered_set>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+namespace {
+
+bool IsElementwiseOperator(OperatorType optype) {
+ switch (optype) {
+ case OperatorType::kCast:
+ case OperatorType::kExp:
+ case OperatorType::kFloor:
+ case OperatorType::kNeg:
+ case OperatorType::kRelu:
+ case OperatorType::kRelu1:
+ case OperatorType::kRelu6:
+ case OperatorType::kTanh:
+ case OperatorType::kTensorFlowSqrt:
+ case OperatorType::kTensorFlowSquare:
+ return true;
+ default:
+ return false;
+ }
+}
+
+bool IsMoveOperator(OperatorType optype) {
+ switch (optype) {
+ case OperatorType::kDepthToSpace:
+ case OperatorType::kExpandDims:
+ case OperatorType::kSpaceToDepth:
+ case OperatorType::kSqueeze:
+ case OperatorType::kTensorFlowReshape:
+ case OperatorType::kTranspose:
+ return true;
+ default:
+ return false;
+ }
+}
+
+} // namespace
+
+// Swaps an elementwise unary operator with a preceding data-movement operator,
+// so that value-changing ops occur before element-moving ops, e.g. negation
+// before transpose.
+bool ReorderElementwiseUnary::Run(Model* model, std::size_t op_index) {
+ const auto element_op_it = model->operators.begin() + op_index;
+ std::unique_ptr<Operator>& element_op = *element_op_it;
+ if (!IsElementwiseOperator(element_op->type)) {
+ return false;
+ }
+
+ const string intermediate_name = element_op->inputs[0];
+ auto it = FindOpWithOutput(*model, intermediate_name);
+ if (it == model->operators.end()) {
+ AddMessageF("No preceding operator");
+ return false;
+ }
+
+ std::unique_ptr<Operator>& move_op = *it;
+ if (!IsMoveOperator(move_op->type)) {
+ AddMessageF("Preceding operator is not a move operator");
+ return false;
+ }
+
+ if (CountOpsWithInput(*model, intermediate_name) != 1) {
+ AddMessageF("Input %s used elsewhere", intermediate_name);
+ return false;
+ }
+
+ // Check that the intermediate is discardable.
+ if (!IsDiscardableArray(*model, intermediate_name)) {
+ AddMessageF(
+ "Cannot swap elementwise as it would invalidate %s which is "
+ "an output array.",
+ intermediate_name);
+ return false;
+ }
+
+ // op->inputs may change below, so keep copies of the names.
+ const string input_name = move_op->inputs[0];
+ const string output_name = element_op->outputs[0];
+
+ AddMessageF("Swapping around operators with %s and %s", LogName(*element_op),
+ LogName(*move_op));
+
+ // If the output array is an exit node for the graph then we need to retain
+ // the name as an output node. This makes the naming scheme a little confusing
+ // but is required in this rare case.
+ if (!IsDiscardableArray(*model, output_name)) {
+ // The output name of the sequence needs to stay static, so create a new
+ // array to use for the intermediate.
+ const auto new_intermediate_name =
+ AvailableArrayName(*model, element_op->outputs[0] + "_reorder");
+ AddMessageF("Adding new array %s to preserve output array name %s",
+ new_intermediate_name, output_name);
+
+ element_op->inputs[0] = input_name;
+ element_op->outputs[0] = new_intermediate_name;
+ model->EraseArray(intermediate_name);
+ move_op->inputs[0] = new_intermediate_name;
+ move_op->outputs[0] = output_name;
+ } else {
+ // The intermediate array is now the output array.
+ for (int i = 0; i < model->operators.size(); i++) {
+ Operator* consumer = model->operators[i].get();
+ for (int j = 0; j < consumer->inputs.size(); j++) {
+ if (consumer->inputs[j] == output_name) {
+ consumer->inputs[j] = intermediate_name;
+ }
+ }
+ }
+
+ element_op->inputs[0] = input_name;
+ move_op->inputs[0] = output_name;
+ }
+
+ // Reset both arrays as shape, type, min/max, etc can all change because of
+ // the position swap.
+ model->EraseArray(element_op->outputs[0]);
+ model->EraseArray(move_op->outputs[0]);
+
+ // Reconstruct.
+ model->GetOrCreateArray(element_op->outputs[0]);
+ model->GetOrCreateArray(move_op->outputs[0]);
+
+ // Swap the order of the operators.
+ element_op.swap(move_op);
+
+ return true;
+}
+
+} // namespace toco
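
A sketch of the rewiring this pass performs when the output array is discardable (array names illustrative):

    // Before: input -> Transpose -> intermediate -> Neg -> output -> consumers
    // After:  input -> Neg -> output -> Transpose -> intermediate -> consumers

Both output arrays are erased and recreated afterwards so that shape, type, and min/max information is re-propagated for the new order.
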
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/reorder_reshape_transpose.cc b/tensorflow/contrib/lite/toco/graph_transformations/reorder_reshape_transpose.cc
new file mode 100644
index 0000000000..9e7fe1b1cc
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/reorder_reshape_transpose.cc
@@ -0,0 +1,248 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <algorithm>
+#include <iterator>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+namespace {
+
+bool OperatorReady(const Model& model, const Operator* op) {
+ if (!model.HasArray(op->inputs[0]) || !model.HasArray(op->inputs[1]) ||
+ !model.HasArray(op->outputs[0])) {
+ return false;
+ }
+
+ if (!model.GetArray(op->inputs[0]).has_shape() ||
+ !model.GetArray(op->outputs[0]).has_shape()) {
+ // Input and output need their shapes known.
+ return false;
+ }
+
+ if (!model.GetArray(op->inputs[1]).buffer) {
+ // Buffer needs to be a constant.
+ return false;
+ }
+
+ return true;
+}
+
+// Utility function to filter out a value.
+void Filter(std::vector<int>* vec, int value) {
+ vec->erase(std::remove(vec->begin(), vec->end(), value), vec->end());
+}
+
+// Computes a new permutation used to swap a reshape-transpose to a
+// transpose-reshape. The original permutation operates on the intermediate
+// shape; the new one operates directly on the input shape.
+std::vector<int> ComputeNewPerm(std::vector<int> input_dims,
+ std::vector<int> intermediate_dims,
+ std::vector<int> perm) {
+ // These are the major (non-unary) axes of the input.
+ std::vector<int> input_indices;
+ for (int i = 0; i < input_dims.size(); i++) {
+ if (input_dims[i] != 1) {
+ input_indices.push_back(i);
+ }
+ }
+
+ // Maps each non-unary intermediate index to the input index that produced it.
+ std::unordered_map<int, int> intermediate_to_input_indices_map;
+ for (int i = 0; i < intermediate_dims.size(); i++) {
+ if (intermediate_dims[i] != 1) {
+ intermediate_to_input_indices_map[i] =
+ input_indices[intermediate_to_input_indices_map.size()];
+ }
+ }
+
+ // Translate the transpose permutation to a new permutation starting with the
+ // major indices.
+ std::vector<int> new_perm;
+ new_perm.reserve(input_dims.size());
+ for (int i = 0; i < perm.size(); i++) {
+ if (intermediate_dims[perm[i]] == 1) continue;
+
+ new_perm.push_back(intermediate_to_input_indices_map[perm[i]]);
+ }
+
+ // Fill in the rest of the permutation with the unary (size-1) dimensions.
+ for (int index = 0; index < input_dims.size(); index++) {
+ if (input_dims[index] == 1) {
+ new_perm.push_back(index);
+ }
+ }
+
+ CHECK_EQ(new_perm.size(), input_dims.size());
+ return new_perm;
+}
+
+} // namespace
+
+// Swaps reshape-transpose to transpose-reshape whenever possible. This is
+// possible when the reshape does not affect memory ordering.
+bool ReorderReshapeTranspose::Run(Model* model, std::size_t op_index) {
+ auto transpose_it = model->operators.begin() + op_index;
+
+ TransposeOperator* transpose_op = ConvertOperator<TransposeOperator*>(
+ transpose_it->get(), OperatorType::kTranspose);
+
+ if (transpose_op == nullptr) {
+ return false;
+ }
+
+ if (!OperatorReady(*model, transpose_op) || transpose_op->perm.empty()) {
+ // Wait for values to propagate.
+ return false;
+ }
+
+ // Find the operator that produces the transpose op's input.
+ auto reshape_it = FindOpWithOutput(*model, transpose_op->inputs[0]);
+ if (reshape_it == model->operators.end()) {
+ return false;
+ }
+
+ TensorFlowReshapeOperator* reshape_op =
+ ConvertOperator<TensorFlowReshapeOperator*>(
+ reshape_it->get(), OperatorType::kTensorFlowReshape);
+ if (reshape_op == nullptr) {
+ return false;
+ }
+
+ // Ignore if the reshape is uninitialized.
+ if (!OperatorReady(*model, reshape_op) || reshape_op->shape.empty()) {
+ return false;
+ }
+
+ // Copy the names; they must stay stable while the operators are swapped.
+ const string input_name = reshape_op->inputs[0];
+ const string intermediate_name = reshape_op->outputs[0];
+ const string output_name = transpose_op->outputs[0];
+
+ // Intermediate should not be consumed by any other operators.
+ if (CountOpsWithInput(*model, intermediate_name) != 1) {
+ AddMessageF("Input %s used elsewhere", intermediate_name);
+ return false;
+ }
+
+ // Check that the intermediate is not an output array.
+ if (!IsDiscardableArray(*model, intermediate_name)) {
+ AddMessageF(
+ "Cannot reorder reshape-transpose as it would invalidate %s which is "
+ "an output array.",
+ intermediate_name);
+ return false;
+ }
+
+ // Get the arrays.
+ const auto& input_array = model->GetArray(input_name);
+ const auto& intermediate_array = model->GetArray(intermediate_name);
+ const auto& output_array = model->GetArray(output_name);
+
+ // Get the shapes of each array.
+ Shape input_shape = input_array.shape();
+ Shape intermediate_shape = intermediate_array.shape();
+ Shape output_shape = output_array.shape();
+
+ // Copy out the dimensions of each array.
+ std::vector<int> input_dims = input_shape.dims();
+ std::vector<int> intermediate_dims = intermediate_shape.dims();
+ std::vector<int> output_dims = output_shape.dims();
+
+ // If the reshape is equivalent to a transpose up to extra or missing unary
+ // dimensions, the two operators can be swapped.
+ if (!ReshapeIsEquivalentToTranspose(*model, reshape_op,
+ true /*allow_extra_unary_dims*/)) {
+ return false;
+ }
+
+ if (!IsDiscardableArray(*model, output_name)) {
+ // The output name of the sequence needs to stay static, so create a new
+ // array to use for the intermediate.
+ const auto new_intermediate_name =
+ AvailableArrayName(*model, transpose_op->outputs[0] + "_exchange");
+ AddMessageF("Adding new array %s to preserve output array name %s",
+ new_intermediate_name, transpose_op->outputs[0]);
+ transpose_op->inputs[0] = input_name;
+ transpose_op->outputs[0] = new_intermediate_name;
+ reshape_op->inputs[0] = new_intermediate_name;
+ reshape_op->outputs[0] = output_name;
+ model->EraseArray(intermediate_name);
+ } else {
+ // The intermediate array is now the output array.
+ for (int i = 0; i < model->operators.size(); i++) {
+ Operator* consumer = model->operators[i].get();
+ for (int j = 0; j < consumer->inputs.size(); j++) {
+ if (consumer->inputs[j] == output_name) {
+ consumer->inputs[j] = intermediate_name;
+ }
+ }
+ }
+
+ transpose_op->inputs[0] = input_name;
+ reshape_op->inputs[0] = output_name;
+ }
+
+ // If the transpose's constant buffer is used elsewhere, make a new copy.
+ if (CountOpsWithInput(*model, transpose_op->inputs[1]) != 1) {
+ transpose_op->inputs[1] =
+ AvailableArrayName(*model, transpose_op->inputs[1] + "_copy");
+ }
+
+ // Make the new transpose permutation.
+ const std::vector<int> new_perm =
+ ComputeNewPerm(input_dims, intermediate_dims, transpose_op->perm);
+ CHECK_EQ(input_dims.size(), new_perm.size());
+
+ auto& transpose_array = model->GetOrCreateArray(transpose_op->inputs[1]);
+ transpose_array.GetMutableBuffer<ArrayDataType::kInt32>().data = new_perm;
+ *(transpose_array.mutable_shape()->mutable_dims()) = {
+ static_cast<int>(new_perm.size())};
+ transpose_op->perm = new_perm;
+
+ // If the reshape's constant buffer is reused, create a new one.
+ if (CountOpsWithInput(*model, reshape_op->inputs[1]) != 1) {
+ reshape_op->inputs[1] =
+ AvailableArrayName(*model, reshape_op->inputs[1] + "_copy");
+ }
+
+ // We need to modify the reshape input array to target the new output size.
+ auto& reshape_array = model->GetOrCreateArray(reshape_op->inputs[1]);
+ reshape_array.GetMutableBuffer<ArrayDataType::kInt32>().data = output_dims;
+ *(reshape_array.mutable_shape()->mutable_dims()) = {
+ static_cast<int>(output_shape.dimensions_count())};
+ reshape_op->shape.clear();
+
+ AddMessageF("Swapping around operators between %s and %s", input_name,
+ output_name);
+
+ model->GetOrCreateArray(transpose_op->outputs[0]).clear_shape();
+ model->GetOrCreateArray(reshape_op->outputs[0]).clear_shape();
+
+ // Swap the order of the operators.
+ transpose_it->swap(*reshape_it);
+
+ return true;
+}
+
+} // namespace toco
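
A worked example of ComputeNewPerm, with illustrative shapes: Reshape [2, 3] -> [1, 2, 3] followed by Transpose perm {0, 2, 1} (output [1, 3, 2]):

    // input_dims        = {2, 3}     non-unary input indices: {0, 1}
    // intermediate_dims = {1, 2, 3}  intermediate index 1 -> input 0, 2 -> input 1
    // perm              = {0, 2, 1}  skips the unary axis, maps 2 -> 1, then 1 -> 0
    // new_perm          = {1, 0}

After the swap, the Transpose turns [2, 3] into [3, 2] with perm {1, 0}, and the Reshape restores the final [1, 3, 2] shape; element order is unchanged.
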
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_random_uniform.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_random_uniform.cc
new file mode 100644
index 0000000000..88d06d7dc7
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_random_uniform.cc
@@ -0,0 +1,116 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <algorithm>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+#include "tensorflow/core/lib/random/philox_random.h"
+#include "tensorflow/core/lib/random/random_distributions.h"
+
+namespace toco {
+
+template <ArrayDataType Type>
+bool ComputeRandomUniformArray(Model* model, RandomUniformOperator* op) {
+ typedef tensorflow::random::UniformDistribution<
+ tensorflow::random::PhiloxRandom, DataType<Type>>
+ Distribution;
+
+ // Allocate output
+ auto& output_array = model->GetArray(op->outputs[0]);
+ CHECK(output_array.data_type == Type);
+ std::vector<DataType<Type>>& data =
+ output_array.GetMutableBuffer<Type>().data;
+ data.resize(RequiredBufferSizeForShape(output_array.shape()));
+
+ // We use the same random number generator and distribution as TensorFlow to
+ // produce the exact same values given the same seeds. See
+ // tensorflow::functor::FillPhiloxRandomTask<Distribution, false> in
+ // //third_party/tensorflow/core/kernels/random_op.cc for the implementation.
+ tensorflow::random::PhiloxRandom generator(op->seed, op->seed2);
+ Distribution dist;
+
+ // The generator creates Distribution::kResultElementCount samples at a time.
+ size_t offset = 0;
+ size_t num_samples = Distribution::kResultElementCount;
+ while (offset < data.size()) {
+ const typename Distribution::ResultType samples = dist(&generator);
+ std::copy(&samples[0],
+ &samples[0] + std::min(num_samples, data.size() - offset),
+ &data[0] + offset);
+ offset += num_samples;
+ }
+
+ return true;
+}
+
+bool ResolveConstantRandomUniform::Run(Model* model, std::size_t op_index) {
+ const auto it = model->operators.begin() + op_index;
+ auto* base_op = it->get();
+ if (base_op->type != OperatorType::kRandomUniform) {
+ return false;
+ }
+ auto* op = static_cast<RandomUniformOperator*>(base_op);
+
+ CHECK_EQ(op->inputs.size(), 1);
+ CHECK_EQ(op->outputs.size(), 1);
+
+ auto& output_array = model->GetArray(op->outputs[0]);
+ if (output_array.data_type == ArrayDataType::kNone) {
+ // Yield until the output type has been set by PropagateArrayDataTypes
+ return false;
+ }
+
+ if (!output_array.has_shape()) {
+ // Yield until the output shape has been set by PropagateFixedShapes
+ return false;
+ }
+
+ if ((op->seed == 0) && (op->seed2 == 0)) {
+ LOG(WARNING) << "RandomUniform op outputting \"" << op->outputs[0]
+ << "\" is truly random (using /dev/random system entropy). "
+ "Therefore, cannot resolve as constant. Set \"seed\" or "
+ "\"seed2\" attr non-zero to fix this";
+ return false;
+ }
+
+ switch (output_array.data_type) {
+ case ArrayDataType::kFloat:
+ if (!ComputeRandomUniformArray<ArrayDataType::kFloat>(model, op)) {
+ return false;
+ }
+ break;
+ // For future support of double or half.
+ // case ArrayDataType::kDouble...
+ default:
+ LOG(FATAL)
+ << "Unsupported data type given to RandomUniform op with output \""
+ << op->outputs[0] << "\"";
+ break;
+ }
+
+ // Erase input arrays if no longer used
+ toco::DeleteArrayIfUsedOnce(op->inputs[0], model);
+
+ // Erase the operator
+ model->operators.erase(it);
+
+ return true;
+}
+
+} // namespace toco
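This resolution relies on RandomUniform being fully determined by its seed and seed2 attrs. A small TF 1.x sketch of that property (illustrative only):

```python
import tensorflow as tf

# With an op-level seed set, the RandomUniform op gets non-zero seed/seed2
# attrs, so a fresh session always yields the same values -- which is what
# lets toco fold the op into a constant.
with tf.Graph().as_default():
  r = tf.random_uniform([4], seed=42)
  with tf.Session() as sess:
    first = sess.run(r)
  with tf.Session() as sess:
    second = sess.run(r)
  assert (first == second).all()
```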
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc
index f38203c80f..2a236d3f98 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc
@@ -60,6 +60,13 @@ bool ResolveTensorFlowMatMul::Run(Model* model, std::size_t op_index) {
string input_lhs = matmul_op->inputs[0];
string input_rhs = transpose_op->outputs[0];
+ // Construct the new FullyConnectedOperator.
+ auto* fc_op = new FullyConnectedOperator;
+ fc_op->outputs = matmul_op->outputs;
+
+ // Insert the newly constructed FullyConnectedOperator.
+  model->operators.emplace(matmul_it, fc_op);
+
// Find the op producing the array passed to this MatMul
auto previous_op_it = model->operators.begin();
bool found = false;
@@ -76,13 +83,6 @@ bool ResolveTensorFlowMatMul::Run(Model* model, std::size_t op_index) {
}
Operator* previous_op = (found) ? previous_op_it->get() : nullptr;
- // Construct the new FullyConnectedOperator.
- auto* fc_op = new FullyConnectedOperator;
- fc_op->outputs = matmul_op->outputs;
-
- // Insert the newly constructed FullyConnectedOperator.
- model->operators.emplace(matmul_it, fc_op) + 1;
-
// Refresh iterator.
matmul_it = model->operators.begin();
for (; matmul_it != model->operators.end(); ++matmul_it) {
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc
index c26e4bddff..876479079b 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc
@@ -74,7 +74,7 @@ const string& GetStringAttr(const NodeDef& node, const string& attr_name) {
return attr.s();
}
-int GetIntAttr(const NodeDef& node, const string& attr_name) {
+int64 GetIntAttr(const NodeDef& node, const string& attr_name) {
CHECK(HasAttr(node, attr_name)) << attr_name << " not found in:\n"
<< node.DebugString();
const auto& attr = node.attr().at(attr_name);
@@ -569,6 +569,23 @@ void ConvertBiasAddOperator(const NodeDef& node,
model->operators.emplace_back(biasadd);
}
+void ConvertRandomUniform(const NodeDef& node,
+ const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
+ CHECK_EQ(node.op(), "RandomUniform");
+ CheckInputsCount(node, tf_import_flags, 1);
+
+ CHECK_EQ(GetDataTypeAttr(node, "T"), DT_INT32);
+ auto op = absl::make_unique<RandomUniformOperator>();
+ op->inputs.push_back(node.input(0));
+ op->outputs.push_back(node.name());
+ op->dtype = ConvertDataType(GetDataTypeAttr(node, "dtype"));
+ op->seed = GetIntAttr(node, "seed");
+ op->seed2 = GetIntAttr(node, "seed2");
+ CHECK(model != nullptr);
+ model->operators.emplace_back(std::move(op));
+}
+
void ConvertReluOperator(const NodeDef& node,
const TensorFlowImportFlags& tf_import_flags,
Model* model) {
@@ -1931,7 +1948,7 @@ void ConvertTopKV2Operator(const NodeDef& node,
// K can be encoded as attr (TopK) convert it to a const.
if (HasAttr(node, "k")) {
string k_array = CreateConstArray<ArrayDataType::kInt32>(
- model, node.name() + "k", {GetIntAttr(node, "k")});
+ model, node.name() + "k", {static_cast<int32>(GetIntAttr(node, "k"))});
op->inputs.push_back(k_array);
} else {
CheckInputsCount(node, tf_import_flags, 2);
@@ -2168,6 +2185,8 @@ std::unique_ptr<Model> ImportTensorFlowGraphDef(
} else if (node.op() == "DynamicStitch" ||
node.op() == "ParallelDynamicStitch") {
ConvertDynamicStitchOperator(node, tf_import_flags, model);
+ } else if (node.op() == "RandomUniform") {
+ ConvertRandomUniform(node, tf_import_flags, model);
} else {
ConvertUnsupportedOperator(node, tf_import_flags, model);
}
diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h
index 5199e292e1..9bd72e7de1 100644
--- a/tensorflow/contrib/lite/toco/model.h
+++ b/tensorflow/contrib/lite/toco/model.h
@@ -60,6 +60,7 @@ enum class OperatorType {
kMaxPool,
kFakeQuant,
kMul,
+ kRandomUniform,
kRange,
kRank,
kRelu,
@@ -946,6 +947,13 @@ struct FloorModOperator : Operator {
FloorModOperator() : Operator(OperatorType::kFloorMod) {}
};
+struct RandomUniformOperator : Operator {
+ RandomUniformOperator() : Operator(OperatorType::kRandomUniform) {}
+ ArrayDataType dtype = ArrayDataType::kNone;
+ int64 seed;
+ int64 seed2;
+};
+
// Creates a sequence of numbers that begins at start and extends by increments
// of delta up to but not including limit.
//
@@ -1499,7 +1507,14 @@ class Shape {
// We still have that one convenience accessor to avoid
// the awkward double bracket issue: shape.dims()[i].
- int dims(int i) const { return dims_[i]; }
+ int dims(int i) const {
+    // Always check for out-of-bounds accesses, even in optimized builds where
+    // standard assertions are disabled. An out-of-bounds access here is a
+    // common occurrence.
+ CHECK_GE(i, 0);
+ CHECK_GT(dims_.size(), i);
+ return dims_[i];
+ }
bool operator==(const Shape& comp) const {
return (this->dims_ == comp.dims());
diff --git a/tensorflow/contrib/lite/toco/model_cmdline_flags.cc b/tensorflow/contrib/lite/toco/model_cmdline_flags.cc
index 4264f21c76..245eb52444 100644
--- a/tensorflow/contrib/lite/toco/model_cmdline_flags.cc
+++ b/tensorflow/contrib/lite/toco/model_cmdline_flags.cc
@@ -160,6 +160,11 @@ bool ParseModelFlagsFromCommandLineFlags(
"Path to an optional file containing a serialized ArraysExtraInfo "
"proto allowing to pass extra information about arrays not specified "
"in the input model file, such as extra MinMax information."),
+ Flag("model_flags_file", parsed_flags.model_flags_file.bind(),
+ parsed_flags.model_flags_file.default_value(),
+ "Path to an optional file containing a serialized ModelFlags proto. "
+ "Options specified on the command line will override the values in "
+ "the proto."),
};
bool asked_for_help =
*argc == 2 && (!strcmp(argv[1], "--help") || !strcmp(argv[1], "-help"));
@@ -182,7 +187,24 @@ void ReadModelFlagsFromCommandLineFlags(
const ParsedModelFlags& parsed_model_flags, ModelFlags* model_flags) {
toco::port::CheckInitGoogleIsDone("InitGoogle is not done yet");
-// "batch" flag only exists internally
+ // Load proto containing the initial model flags.
+ // Additional flags specified on the command line will overwrite the values.
+ if (parsed_model_flags.model_flags_file.specified()) {
+ string model_flags_file_contents;
+ QCHECK(port::file::GetContents(parsed_model_flags.model_flags_file.value(),
+ &model_flags_file_contents,
+ port::file::Defaults())
+ .ok())
+ << "Specified --model_flags_file="
+ << parsed_model_flags.model_flags_file.value()
+ << " was not found or could not be read";
+ QCHECK(ParseFromStringEitherTextOrBinary(model_flags_file_contents,
+ model_flags))
+ << "Specified --model_flags_file="
+ << parsed_model_flags.model_flags_file.value()
+ << " could not be parsed";
+ }
+
#ifdef PLATFORM_GOOGLE
CHECK(!((base::SpecifiedOnCommandLine("batch") &&
parsed_model_flags.variable_batch.specified())))
diff --git a/tensorflow/contrib/lite/toco/model_flags.proto b/tensorflow/contrib/lite/toco/model_flags.proto
index 42e0f54826..835dea49eb 100644
--- a/tensorflow/contrib/lite/toco/model_flags.proto
+++ b/tensorflow/contrib/lite/toco/model_flags.proto
@@ -98,8 +98,8 @@ message ArraysExtraInfo {
message Entry {
// Next ID to use: 7.
optional string name = 1;
- optional float min = 2;
- optional float max = 3;
+ optional double min = 2;
+ optional double max = 3;
optional IODataType data_type = 4;
optional InputArrayShape shape = 5;
optional float constant_float_value = 6;
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc
index 0cb348bda5..f991529569 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator.cc
@@ -204,17 +204,22 @@ class BatchToSpaceND
TocoOperator* op) const override {}
};
-class Cast : public CustomOperator<CastOperator> {
+class Cast : public BuiltinOperator<CastOperator, ::tflite::CastOptions,
+ ::tflite::BuiltinOptions_CastOptions> {
public:
- using CustomOperator::CustomOperator;
- void WriteOptions(const TocoOperator& op,
- flexbuffers::Builder* fbb) const override {
- fbb->Int("src_data_type", DataType::Serialize(op.src_data_type));
- fbb->Int("dst_data_type", DataType::Serialize(op.dst_data_type));
+ using BuiltinOperator::BuiltinOperator;
+ flatbuffers::Offset<TfLiteOptions> WriteOptions(
+ const TocoOperator& op,
+ flatbuffers::FlatBufferBuilder* builder) const override {
+ return ::tflite::CreateCastOptions(*builder,
+ DataType::Serialize(op.src_data_type),
+ DataType::Serialize(op.dst_data_type));
}
- void ReadOptions(const flexbuffers::Map& m, TocoOperator* op) const override {
- op->src_data_type = DataType::Deserialize(m["src_data_type"].AsInt64());
- op->dst_data_type = DataType::Deserialize(m["dst_data_type"].AsInt64());
+
+ void ReadOptions(const TfLiteOptions& options,
+ TocoOperator* op) const override {
+ op->src_data_type = DataType::Deserialize(options.in_data_type());
+ op->dst_data_type = DataType::Deserialize(options.out_data_type());
}
};
@@ -827,9 +832,10 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() {
new TopK_V2(::tflite::BuiltinOperator_TOPK_V2, OperatorType::kTopK_V2));
ops.emplace_back(
new Lstm(::tflite::BuiltinOperator_LSTM, OperatorType::kLstmCell));
+ ops.emplace_back(
+ new Cast(::tflite::BuiltinOperator_CAST, OperatorType::kCast));
// Custom Operators.
- ops.emplace_back(new Cast("CAST", OperatorType::kCast));
ops.emplace_back(
new DepthToSpace("DEPTH_TO_SPACE", OperatorType::kDepthToSpace));
ops.emplace_back(new FakeQuant("FAKE_QUANT", OperatorType::kFakeQuant));
diff --git a/tensorflow/contrib/lite/toco/tflite/operator_test.cc b/tensorflow/contrib/lite/toco/tflite/operator_test.cc
index f7a213ecfc..4783843b7f 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator_test.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator_test.cc
@@ -131,7 +131,7 @@ TEST_F(OperatorTest, BuiltinMean) {
EXPECT_EQ(op.keep_dims, output_toco_op->keep_dims);
}
-TEST_F(OperatorTest, CustomCast) {
+TEST_F(OperatorTest, BuiltinCast) {
CastOperator op;
op.src_data_type = ArrayDataType::kFloat;
op.dst_data_type = ArrayDataType::kUint8;
diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc
index 30dd6fab9e..76e9a27aef 100644
--- a/tensorflow/contrib/lite/toco/toco_tooling.cc
+++ b/tensorflow/contrib/lite/toco/toco_tooling.cc
@@ -74,11 +74,14 @@ void MakeGeneralGraphTransformationsSet(
transformations->Add(new ResolveTensorFlowMatMul);
transformations->Add(new FuseBinaryIntoPrecedingAffine);
transformations->Add(new FuseBinaryIntoFollowingAffine);
- transformations->Add(new ReorderActivationFunctions);
+ transformations->Add(new MergeReshapeIntoPrecedingTranspose);
+ transformations->Add(new ReorderElementwiseUnary);
+ transformations->Add(new ReorderReshapeTranspose);
transformations->Add(new ResolveBatchNormalization);
transformations->Add(new ResolveConstantBinaryOperator);
transformations->Add(new ResolveConstantFill);
transformations->Add(new ResolveConstantGather);
+ transformations->Add(new ResolveConstantRandomUniform);
transformations->Add(new ResolveConstantRange);
transformations->Add(new ResolveConstantStack);
transformations->Add(new ResolveConstantStridedSlice);
diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc
index f3f50487ff..56fa8f4b69 100644
--- a/tensorflow/contrib/lite/toco/tooling_util.cc
+++ b/tensorflow/contrib/lite/toco/tooling_util.cc
@@ -297,6 +297,7 @@ const char* OperatorTypeName(OperatorType type) {
HANDLE_OPERATORTYPENAME_CASE(L2Pool)
HANDLE_OPERATORTYPENAME_CASE(FakeQuant)
HANDLE_OPERATORTYPENAME_CASE(Mul)
+ HANDLE_OPERATORTYPENAME_CASE(RandomUniform)
HANDLE_OPERATORTYPENAME_CASE(Relu)
HANDLE_OPERATORTYPENAME_CASE(Relu1)
HANDLE_OPERATORTYPENAME_CASE(Relu6)
@@ -1920,6 +1921,35 @@ bool IsDiscardableArray(const Model& model, const string& array_name) {
return true;
}
+bool ReshapeIsEquivalentToTranspose(const Model& model,
+ const TensorFlowReshapeOperator* op,
+ bool allow_extra_unary_dims) {
+ CHECK(!op->shape.empty());
+ CHECK(model.HasArray(op->inputs[0]));
+ CHECK(model.HasArray(op->outputs[0]));
+
+ const auto& input_array = model.GetArray(op->inputs[0]);
+ const auto& output_array = model.GetArray(op->outputs[0]);
+
+ CHECK(input_array.has_shape());
+ CHECK(output_array.has_shape());
+
+ std::vector<int> in_shape = input_array.shape().dims();
+ std::vector<int> out_shape = output_array.shape().dims();
+
+  // If the reshape changes the number of dimensions, it cannot be
+  // interpreted as a transpose.
+ if (!allow_extra_unary_dims && in_shape.size() != out_shape.size()) {
+ return false;
+ }
+
+ in_shape.erase(std::remove(in_shape.begin(), in_shape.end(), 1),
+ in_shape.end());
+ out_shape.erase(std::remove(out_shape.begin(), out_shape.end(), 1),
+ out_shape.end());
+ return in_shape == out_shape;
+}
+
void CheckFinalDataTypesSatisfied(const Model& model) {
for (const auto& array_entry : model.GetArrayMap()) {
const auto& array = *array_entry.second;
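A minimal Python model of the check above: drop the size-1 dimensions from both shapes and compare what remains.

```python
def reshape_is_equivalent_to_transpose(in_shape, out_shape,
                                       allow_extra_unary_dims=False):
  # Mirrors the C++ logic: a reshape that only adds, removes, or moves
  # size-1 dimensions leaves the underlying data order untouched.
  if not allow_extra_unary_dims and len(in_shape) != len(out_shape):
    return False
  strip = lambda shape: [d for d in shape if d != 1]
  return strip(in_shape) == strip(out_shape)

assert reshape_is_equivalent_to_transpose([1, 4, 1, 5], [4, 1, 1, 5])
assert not reshape_is_equivalent_to_transpose([2, 3], [3, 2])
```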
@@ -1976,9 +2006,9 @@ void UseArraysExtraInfo(Model* model) {
continue;
}
auto& array = model->GetArray(entry.name());
- auto& minmax = array.GetOrCreateMinMax();
if (entry.has_min() || entry.has_max()) {
CHECK_EQ(entry.has_min(), entry.has_max());
+ auto& minmax = array.GetOrCreateMinMax();
minmax.min = entry.min();
minmax.max = entry.max();
}
@@ -1997,11 +2027,12 @@ void UseArraysExtraInfo(Model* model) {
}
if (entry.has_constant_float_value()) {
CHECK(array.has_shape());
- CHECK(array.data_type == ArrayDataType::kFloat);
- auto& data = array.GetMutableBuffer<ArrayDataType::kFloat>().data;
- data.resize(RequiredBufferSizeForShape(array.shape()));
- for (float& f : data) {
- f = entry.constant_float_value();
+ if (array.data_type == ArrayDataType::kFloat) {
+ auto& data = array.GetMutableBuffer<ArrayDataType::kFloat>().data;
+ data.resize(RequiredBufferSizeForShape(array.shape()));
+ for (float& f : data) {
+ f = entry.constant_float_value();
+ }
}
}
}
diff --git a/tensorflow/contrib/lite/toco/tooling_util.h b/tensorflow/contrib/lite/toco/tooling_util.h
index d3b7224fe3..259ee7fbd0 100644
--- a/tensorflow/contrib/lite/toco/tooling_util.h
+++ b/tensorflow/contrib/lite/toco/tooling_util.h
@@ -169,10 +169,23 @@ void GetQuantizationParamsFromMinMax(const MinMax& minmax,
::tflite::ChooseQuantizationParams<Integer>(rmin, rmax);
}
+template <typename T>
+T ConvertOperator(Operator* o, OperatorType type) {
+ if (o != nullptr && o->type == type) {
+ return static_cast<T>(o);
+ }
+
+ return nullptr;
+}
+
void CheckIsReadyForQuantization(const Model& model);
void UseDefaultMinMaxRangeValues(Model* model, double default_ranges_min,
double default_ranges_max);
+bool ReshapeIsEquivalentToTranspose(const Model& model,
+ const TensorFlowReshapeOperator* op,
+ bool allow_extra_unary_dims);
+
inline int Offset(const Shape& shape, const std::vector<int>& indices) {
DCHECK_EQ(shape.dimensions_count(), indices.size());
const int dims_count = shape.dimensions_count();
diff --git a/tensorflow/contrib/lookup/BUILD b/tensorflow/contrib/lookup/BUILD
index 02b4f80252..f616207d46 100644
--- a/tensorflow/contrib/lookup/BUILD
+++ b/tensorflow/contrib/lookup/BUILD
@@ -46,4 +46,5 @@ tf_py_test(
"//tensorflow/python:variables",
],
grpc_enabled = True,
+ tags = ["no_windows"], # TODO: needs investigation on Windows
)
diff --git a/tensorflow/contrib/makefile/proto_text_cc_files.txt b/tensorflow/contrib/makefile/proto_text_cc_files.txt
index 77c936d8c5..76428bc1d4 100644
--- a/tensorflow/contrib/makefile/proto_text_cc_files.txt
+++ b/tensorflow/contrib/makefile/proto_text_cc_files.txt
@@ -12,6 +12,7 @@ tensorflow/core/platform/posix/env.cc
tensorflow/core/platform/posix/load_library.cc
tensorflow/core/platform/posix/env_time.cc
tensorflow/core/platform/file_system.cc
+tensorflow/core/platform/file_system_helper.cc
tensorflow/core/platform/env.cc
tensorflow/core/platform/env_time.cc
tensorflow/core/platform/setround.cc
diff --git a/tensorflow/contrib/nccl/BUILD b/tensorflow/contrib/nccl/BUILD
index 6cbfd03881..334e70318d 100644
--- a/tensorflow/contrib/nccl/BUILD
+++ b/tensorflow/contrib/nccl/BUILD
@@ -31,7 +31,7 @@ tf_custom_op_library(
"kernels/nccl_ops.cc",
],
deps = if_cuda([
- "@nccl_archive//:nccl",
+ "@local_config_nccl//:nccl",
"//tensorflow/core:gpu_headers_lib",
]),
)
@@ -61,7 +61,7 @@ tf_cuda_cc_test(
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
- "@nccl_archive//:nccl",
+ "@local_config_nccl//:nccl",
],
)
@@ -80,7 +80,7 @@ tf_kernel_library(
"//tensorflow/core:lib",
"//tensorflow/core:proto_text",
"//tensorflow/core:stream_executor",
- "@nccl_archive//:nccl",
+ "@local_config_nccl//:nccl",
],
alwayslink = 1,
)
diff --git a/tensorflow/contrib/nccl/kernels/nccl_manager.h b/tensorflow/contrib/nccl/kernels/nccl_manager.h
index bb219e0edc..6ff8cea84e 100644
--- a/tensorflow/contrib/nccl/kernels/nccl_manager.h
+++ b/tensorflow/contrib/nccl/kernels/nccl_manager.h
@@ -20,7 +20,7 @@ limitations under the License.
#include <unordered_map>
#include <vector>
-#include "src/nccl.h"
+#include "third_party/nccl/nccl.h"
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/mutex.h"
diff --git a/tensorflow/contrib/nccl/kernels/nccl_ops.cc b/tensorflow/contrib/nccl/kernels/nccl_ops.cc
index 266d4f6f0d..c2b76caef3 100644
--- a/tensorflow/contrib/nccl/kernels/nccl_ops.cc
+++ b/tensorflow/contrib/nccl/kernels/nccl_ops.cc
@@ -17,7 +17,7 @@ limitations under the License.
#include <vector>
-#include "src/nccl.h"
+#include "third_party/nccl/nccl.h"
#include "tensorflow/contrib/nccl/kernels/nccl_manager.h"
#include "tensorflow/core/framework/op_kernel.h"
diff --git a/tensorflow/contrib/nccl/kernels/nccl_rewrite.cc b/tensorflow/contrib/nccl/kernels/nccl_rewrite.cc
index a4de46a93f..4676e937e5 100644
--- a/tensorflow/contrib/nccl/kernels/nccl_rewrite.cc
+++ b/tensorflow/contrib/nccl/kernels/nccl_rewrite.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include "tensorflow/core/lib/strings/str_util.h"
#if GOOGLE_CUDA
#include <forward_list>
@@ -254,7 +255,7 @@ class NcclReplacePass : public GraphOptimizationPass {
// Find reduction and broadcast ops and replace them with Send/Recv ops.
for (Node* node : graph->op_nodes()) {
StringPiece type = node->type_string();
- if (!type.starts_with("Nccl")) {
+ if (!str_util::StartsWith(type, "Nccl")) {
continue;
}
if (type == "NcclReduce") {
diff --git a/tensorflow/contrib/optimizer_v2/optimizer_v2.py b/tensorflow/contrib/optimizer_v2/optimizer_v2.py
index 471992fdac..25d19578ea 100644
--- a/tensorflow/contrib/optimizer_v2/optimizer_v2.py
+++ b/tensorflow/contrib/optimizer_v2/optimizer_v2.py
@@ -866,7 +866,7 @@ class OptimizerV2(optimizer_v1.Optimizer):
raise ValueError("No gradients provided for any variable: %s." %
([str(v) for _, v in grads_and_vars],))
return distribute_lib.get_tower_context().merge_call(
- self.distributed_apply, filtered, global_step=global_step, name=name)
+ self._distributed_apply, filtered, global_step=global_step, name=name)
def _get_or_create_state(self, var_list=None):
"""Either looks up or creates `_OptimizerV2State`.
@@ -899,7 +899,7 @@ class OptimizerV2(optimizer_v1.Optimizer):
self._per_graph_state[graph_key] = per_graph_state
return per_graph_state
- def distributed_apply(self, distribution, grads_and_vars, global_step, name):
+ def _distributed_apply(self, distribution, grads_and_vars, global_step, name):
"""`apply_gradients` for use with a `DistributionStrategy`."""
reduced_grads = distribution.batch_reduce("sum", grads_and_vars)
var_list = [v for _, v in grads_and_vars]
diff --git a/tensorflow/contrib/quantize/python/quantize.py b/tensorflow/contrib/quantize/python/quantize.py
index 2889016a84..d53d4d7b10 100644
--- a/tensorflow/contrib/quantize/python/quantize.py
+++ b/tensorflow/contrib/quantize/python/quantize.py
@@ -416,7 +416,9 @@ def _InsertQuantOp(context,
# name_prefix starts with 'TPUReplicate/loop/'; without dropping it
# variables are created as TPUReplicate/loop/TPUReplicate/loop/..., which
# breaks things later.
- name_prefix = common.DropStringPrefix(name_prefix, ops.get_name_scope() + '/')
+ name_scope = ops.get_name_scope()
+ if name_scope:
+ name_prefix = common.DropStringPrefix(name_prefix, name_scope + '/')
inputs = producer.outputs[0]
# Prevent ops from being quantized multiple times. Bypass ops can sometimes
diff --git a/tensorflow/contrib/quantize/python/quantize_test.py b/tensorflow/contrib/quantize/python/quantize_test.py
index 98f05c8bfc..8d057d3710 100644
--- a/tensorflow/contrib/quantize/python/quantize_test.py
+++ b/tensorflow/contrib/quantize/python/quantize_test.py
@@ -247,6 +247,27 @@ class QuantizeTest(test_util.TensorFlowTestCase):
self.assertTrue(not op.name.startswith('name_scope/name_scope/'),
'Broken op: %s' % op.name)
+ def testWithNullNameScope(self):
+ self._RunTestOverParameters(self._TestWithNullNameScope)
+
+ def _TestWithNullNameScope(self, is_training):
+ graph = ops.Graph()
+ with graph.as_default():
+ with graph.name_scope(None):
+ batch_size, height, width, depth = 5, 128, 128, 3
+ input1 = array_ops.zeros((batch_size, height, width, depth))
+ _ = conv2d(
+ input1,
+ 32, [5, 5],
+ stride=2,
+ padding='SAME',
+ weights_initializer=self._WeightInit(0.09),
+ activation_fn=None,
+ scope='test')
+
+ quantize.Quantize(graph, is_training, weight_bits=8, activation_bits=8)
+ # Passes if Quantize() does not crash.
+
def _WeightInit(self, stddev):
"""Returns truncated normal variable initializer.
diff --git a/tensorflow/contrib/remote_fused_graph/pylib/BUILD b/tensorflow/contrib/remote_fused_graph/pylib/BUILD
index 996b55f9b8..3aa8a14f44 100644
--- a/tensorflow/contrib/remote_fused_graph/pylib/BUILD
+++ b/tensorflow/contrib/remote_fused_graph/pylib/BUILD
@@ -38,7 +38,6 @@ py_test(
size = "small",
srcs = ["python/ops/remote_fused_graph_ops_test.py"],
srcs_version = "PY2AND3",
- tags = ["no_windows"],
deps = [
":remote_fused_graph_ops_py",
"//tensorflow/core:protos_all_py",
diff --git a/tensorflow/contrib/saved_model/BUILD b/tensorflow/contrib/saved_model/BUILD
index faad40d335..e431c464ef 100644
--- a/tensorflow/contrib/saved_model/BUILD
+++ b/tensorflow/contrib/saved_model/BUILD
@@ -53,6 +53,7 @@ py_test(
size = "small",
srcs = ["python/saved_model/reader_test.py"],
srcs_version = "PY2AND3",
+ tags = ["no_windows"], # TODO: needs investigation on Windows
visibility = ["//visibility:private"],
deps = [
":saved_model_py",
diff --git a/tensorflow/contrib/session_bundle/BUILD b/tensorflow/contrib/session_bundle/BUILD
index 31717305e7..9c08859180 100644
--- a/tensorflow/contrib/session_bundle/BUILD
+++ b/tensorflow/contrib/session_bundle/BUILD
@@ -151,6 +151,7 @@ py_test(
name = "gc_test",
srcs = ["gc_test.py"],
srcs_version = "PY2AND3",
+ tags = ["no_windows"], # TODO: needs investigation on Windows
visibility = ["//visibility:private"],
deps = [
":gc",
diff --git a/tensorflow/contrib/slim/python/slim/data/BUILD b/tensorflow/contrib/slim/python/slim/data/BUILD
index dc12e67fc6..eef043e832 100644
--- a/tensorflow/contrib/slim/python/slim/data/BUILD
+++ b/tensorflow/contrib/slim/python/slim/data/BUILD
@@ -61,6 +61,7 @@ py_test(
name = "dataset_data_provider_test",
srcs = ["dataset_data_provider_test.py"],
srcs_version = "PY2AND3",
+ tags = ["no_windows"], # TODO: needs investigation on Windows
deps = [
":dataset",
":dataset_data_provider",
diff --git a/tensorflow/contrib/stat_summarizer/BUILD b/tensorflow/contrib/stat_summarizer/BUILD
index d4096751c4..30be14c10c 100644
--- a/tensorflow/contrib/stat_summarizer/BUILD
+++ b/tensorflow/contrib/stat_summarizer/BUILD
@@ -31,4 +31,5 @@ tf_py_test(
"//tensorflow/python:math_ops",
"//tensorflow/python:variables",
],
+ tags = ["no_windows"],
)
diff --git a/tensorflow/contrib/tensor_forest/BUILD b/tensorflow/contrib/tensor_forest/BUILD
index 11a59ec22b..136856c015 100644
--- a/tensorflow/contrib/tensor_forest/BUILD
+++ b/tensorflow/contrib/tensor_forest/BUILD
@@ -539,7 +539,6 @@ py_test(
srcs = ["client/random_forest_test.py"],
srcs_version = "PY2AND3",
tags = [
- "no_windows",
"nomac", # b/63258195
"notsan",
],
diff --git a/tensorflow/contrib/tensorboard/BUILD b/tensorflow/contrib/tensorboard/BUILD
index f4efd9717d..2b6a2b2f3c 100644
--- a/tensorflow/contrib/tensorboard/BUILD
+++ b/tensorflow/contrib/tensorboard/BUILD
@@ -9,6 +9,7 @@ exports_files(["LICENSE"])
# For platform specific build config
load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library")
+load("//tensorflow:tensorflow.bzl", "py_test")
tf_proto_library(
name = "protos_all",
@@ -81,6 +82,7 @@ py_test(
size = "small",
srcs = ["plugins/trace/trace_test.py"],
srcs_version = "PY2AND3",
+ tags = ["no_windows"],
deps = [
":trace",
"//tensorflow/python:client_testlib",
diff --git a/tensorflow/contrib/tensorboard/db/summary_file_writer_test.cc b/tensorflow/contrib/tensorboard/db/summary_file_writer_test.cc
index c61b465596..cd3f712256 100644
--- a/tensorflow/contrib/tensorboard/db/summary_file_writer_test.cc
+++ b/tensorflow/contrib/tensorboard/db/summary_file_writer_test.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/io/record_reader.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/util/event.pb.h"
@@ -58,7 +59,7 @@ class SummaryFileWriterTest : public ::testing::Test {
TF_CHECK_OK(env_.GetChildren(testing::TmpDir(), &files));
bool found = false;
for (const string& f : files) {
- if (StringPiece(f).contains(test_name)) {
+ if (str_util::StrContains(f, test_name)) {
if (found) {
return errors::Unknown("Found more than one file for ", test_name);
}
diff --git a/tensorflow/contrib/timeseries/examples/BUILD b/tensorflow/contrib/timeseries/examples/BUILD
index 40cf9147b3..32e948a009 100644
--- a/tensorflow/contrib/timeseries/examples/BUILD
+++ b/tensorflow/contrib/timeseries/examples/BUILD
@@ -25,7 +25,10 @@ py_test(
srcs = ["predict_test.py"],
data = ["data/period_trend.csv"],
srcs_version = "PY2AND3",
- tags = ["notsan"], # b/67513579
+ tags = [
+ "no_windows", # TODO: needs investigation on Windows
+ "notsan", # b/67513579
+ ],
deps = [
":predict",
"//tensorflow/python:client_testlib",
diff --git a/tensorflow/contrib/timeseries/python/timeseries/BUILD b/tensorflow/contrib/timeseries/python/timeseries/BUILD
index 9b6c08150c..d2746032a0 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/BUILD
+++ b/tensorflow/contrib/timeseries/python/timeseries/BUILD
@@ -88,10 +88,14 @@ py_library(
"//tensorflow/python:array_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
+ "//tensorflow/python:parsing_ops",
"//tensorflow/python:tensor_shape",
+ "//tensorflow/python:tensor_util",
"//tensorflow/python:training",
+ "//tensorflow/python:util",
"//tensorflow/python/estimator:estimator_py",
"//tensorflow/python/estimator:export",
+ "//tensorflow/python/feature_column",
],
)
@@ -132,7 +136,6 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":feature_keys",
- "//tensorflow/contrib/framework:framework_py",
"//tensorflow/contrib/layers:layers_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:control_flow_ops",
@@ -141,6 +144,7 @@ py_library(
"//tensorflow/python:math_ops",
"//tensorflow/python:state_ops",
"//tensorflow/python:summary",
+ "//tensorflow/python:training",
"//tensorflow/python:util",
"//tensorflow/python:variable_scope",
"//tensorflow/python/estimator:estimator_py",
@@ -156,23 +160,30 @@ py_test(
"head_test.py",
],
srcs_version = "PY2AND3",
- tags = [
- "no_pip_gpu", # b/63391119
- ],
+ tags = ["no_pip_gpu"], # b/63391119
deps = [
+ ":estimators",
":feature_keys",
":head",
+ ":input_pipeline",
":model",
":state_management",
+ "//tensorflow/contrib/timeseries/examples:lstm",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:metrics",
+ "//tensorflow/python:session",
"//tensorflow/python:training",
"//tensorflow/python:variables",
"//tensorflow/python/estimator:estimator_py",
+ "//tensorflow/python/feature_column",
+ "//tensorflow/python/saved_model:loader",
+ "//tensorflow/python/saved_model:tag_constants",
+ "//third_party/py/numpy",
+ "@six_archive//:six",
],
)
@@ -428,6 +439,7 @@ py_test(
srcs_version = "PY2AND3",
tags = [
"no_pip_gpu", # b/63391119
+ "no_windows", # TODO: needs investigation on Windows
],
deps = [
":feature_keys",
diff --git a/tensorflow/contrib/timeseries/python/timeseries/estimators.py b/tensorflow/contrib/timeseries/python/timeseries/estimators.py
index 469cea4fd2..886e1846e2 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/estimators.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/estimators.py
@@ -44,7 +44,7 @@ class TimeSeriesRegressor(estimator_lib.Estimator):
"""An Estimator to fit and evaluate a time series model."""
def __init__(self, model, state_manager=None, optimizer=None, model_dir=None,
- config=None):
+ config=None, head_type=ts_head_lib.TimeSeriesRegressionHead):
"""Initialize the Estimator.
Args:
@@ -55,6 +55,8 @@ class TimeSeriesRegressor(estimator_lib.Estimator):
from tf.train.Optimizer. Defaults to Adam with step size 0.02.
model_dir: See `Estimator`.
config: See `Estimator`.
+ head_type: The kind of head to use for the model (inheriting from
+ `TimeSeriesRegressionHead`).
"""
input_statistics_generator = math_utils.InputStatisticsFromMiniBatch(
dtype=model.dtype, num_features=model.num_features)
@@ -63,8 +65,8 @@ class TimeSeriesRegressor(estimator_lib.Estimator):
if optimizer is None:
optimizer = train.AdamOptimizer(0.02)
self._model = model
- ts_regression_head = ts_head_lib.time_series_regression_head(
- model, state_manager, optimizer,
+ ts_regression_head = head_type(
+ model=model, state_manager=state_manager, optimizer=optimizer,
input_statistics_generator=input_statistics_generator)
model_fn = ts_regression_head.create_estimator_spec
super(TimeSeriesRegressor, self).__init__(
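A short usage sketch for the new head_type argument; my_model and my_optimizer are hypothetical stand-ins for a concrete model and optimizer.

```python
from tensorflow.contrib.timeseries.python.timeseries import estimators
from tensorflow.contrib.timeseries.python.timeseries import head as ts_head_lib

# Export a single stateless serving signature by swapping in the new head.
estimator = estimators.TimeSeriesRegressor(
    model=my_model,          # hypothetical: any time series model
    optimizer=my_optimizer,  # hypothetical: e.g. an AdamOptimizer
    head_type=ts_head_lib.OneShotPredictionHead)
```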
diff --git a/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py b/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py
index 51d0c0ca3f..9f161c1695 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py
@@ -19,6 +19,7 @@ from __future__ import print_function
import tempfile
import numpy
+import six
from tensorflow.contrib.timeseries.python.timeseries import ar_model
from tensorflow.contrib.timeseries.python.timeseries import estimators
@@ -127,6 +128,12 @@ class TimeSeriesRegressorTest(test.TestCase):
session=sess)
# Test cold starting
+ six.assertCountEqual(
+ self,
+ [feature_keys.FilteringFeatures.TIMES,
+ feature_keys.FilteringFeatures.VALUES],
+ signatures.signature_def[
+ feature_keys.SavedModelLabels.COLD_START_FILTER].inputs.keys())
batch_numpy_times = numpy.tile(
numpy.arange(30, dtype=numpy.int64)[None, :], (10, 1))
batch_numpy_values = numpy.ones([10, 30, 1])
diff --git a/tensorflow/contrib/timeseries/python/timeseries/head.py b/tensorflow/contrib/timeseries/python/timeseries/head.py
index 4cf6bbcfd4..a28a5872b8 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/head.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/head.py
@@ -39,27 +39,18 @@ from tensorflow.python.util import nest
from tensorflow.python.summary import summary
-def time_series_regression_head(model,
- state_manager,
- optimizer,
- input_statistics_generator=None):
- """Creates a `_Head` for time series regression.
+class _NoStatePredictOutput(export_lib.PredictOutput):
- Args:
- model: A model for time series regression.
- state_manager: A state manager.
- optimizer: An optimizer.
- input_statistics_generator: A input statistics generator.
-
- Returns:
- An instance of `_Head` for time series regression.
- """
- return _TimeSeriesRegressionHead(model, state_manager, optimizer,
- input_statistics_generator)
+ def as_signature_def(self, receiver_tensors):
+ no_state_receiver_tensors = {
+ key: value for key, value in receiver_tensors.items()
+ if not key.startswith(feature_keys.State.STATE_PREFIX)}
+ return super(_NoStatePredictOutput, self).as_signature_def(
+ receiver_tensors=no_state_receiver_tensors)
-class _TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-access
- """See `time_series_regression_head`."""
+class TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-access
+ """Determines input and output signatures for a time series model."""
def __init__(self,
model,
@@ -67,6 +58,15 @@ class _TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-acc
optimizer,
input_statistics_generator=None,
name=None):
+ """Creates a `_Head` for time series regression.
+
+ Args:
+ model: A model for time series regression.
+ state_manager: A state manager.
+ optimizer: An optimizer.
+      input_statistics_generator: An input statistics generator.
+ name: An optional name for the model.
+ """
self.model = model
self.state_manager = state_manager
self.optimizer = optimizer
@@ -167,7 +167,7 @@ class _TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-acc
export_lib.PredictOutput(
state_to_dictionary(filtering_outputs.end_state)),
feature_keys.SavedModelLabels.COLD_START_FILTER:
- export_lib.PredictOutput(
+ _NoStatePredictOutput(
state_to_dictionary(cold_filtering_outputs.end_state))
},
# Likely unused, but it is necessary to return `predictions` to satisfy
@@ -255,6 +255,58 @@ class _TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-acc
return self._serving_ops(features)
+class OneShotPredictionHead(TimeSeriesRegressionHead):
+ """A time series head which exports a single stateless serving signature.
+
+ The serving default signature exported by this head expects `times`, `values`,
+ and any exogenous features, but no state. `values` has shape `[batch_size,
+ filter_length, num_features]` and `times` has shape `[batch_size,
+ total_length]`, where `total_length > filter_length`. Any exogenous features
+ must have their shapes prefixed by the shape of the `times` feature.
+
+  When serving, the head first filters the series up to `filter_length`,
+  starting from the model's default start state, and then computes and
+  returns predictions for the remainder of the series.
+
+ Model state is neither accepted nor returned, so filtering must be performed
+ each time predictions are requested when using this head.
+ """
+
+ def _serving_ops(self, features):
+ """Add ops for serving to the graph."""
+ with variable_scope.variable_scope("model", use_resource=True):
+ filtering_features = {}
+ prediction_features = {}
+ values_length = array_ops.shape(
+ features[feature_keys.FilteringFeatures.VALUES])[1]
+ for key, value in features.items():
+ if key == feature_keys.State.STATE_TUPLE:
+ # Ignore state input. The model's default start state is replicated
+ # across the batch.
+ continue
+ if key == feature_keys.FilteringFeatures.VALUES:
+ filtering_features[key] = value
+ else:
+ filtering_features[key] = value[:, :values_length]
+ prediction_features[key] = value[:, values_length:]
+ cold_filtering_outputs = self.model.define_loss(
+ features=filtering_features, mode=estimator_lib.ModeKeys.EVAL)
+ prediction_features[feature_keys.State.STATE_TUPLE] = (
+ cold_filtering_outputs.end_state)
+ with variable_scope.variable_scope("model", reuse=True):
+ prediction_outputs = self.model.predict(
+ features=prediction_features)
+ return estimator_lib.EstimatorSpec(
+ mode=estimator_lib.ModeKeys.PREDICT,
+ export_outputs={
+ feature_keys.SavedModelLabels.PREDICT:
+ _NoStatePredictOutput(prediction_outputs),
+ },
+ # Likely unused, but it is necessary to return `predictions` to satisfy
+ # the Estimator's error checking.
+ predictions={})
+
+
def _check_feature_shapes_compatible_with(features,
compatible_with_name,
compatible_with_value,
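The shape contract from the docstring above, sketched with the concrete numbers used in the test below:

```python
import numpy as np

# `times` covers the full window; `values` only the filtering prefix.
batch_size, filter_length, total_length, num_features = 2, 20, 35, 5
times = np.tile(np.arange(total_length)[None, :], [batch_size, 1])  # [2, 35]
values = np.ones([batch_size, filter_length, num_features])         # [2, 20, 5]
# The head filters on the first filter_length steps, then predicts the
# remaining total_length - filter_length = 15: output "mean" is [2, 15, 5].
```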
diff --git a/tensorflow/contrib/timeseries/python/timeseries/head_test.py b/tensorflow/contrib/timeseries/python/timeseries/head_test.py
index 3415061cfd..c606db76a6 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/head_test.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/head_test.py
@@ -18,12 +18,20 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import numpy
+import six
+
+from tensorflow.contrib.timeseries.examples import lstm as lstm_example
+from tensorflow.contrib.timeseries.python.timeseries import estimators as ts_estimators
from tensorflow.contrib.timeseries.python.timeseries import feature_keys
from tensorflow.contrib.timeseries.python.timeseries import head as ts_head_lib
+from tensorflow.contrib.timeseries.python.timeseries import input_pipeline
from tensorflow.contrib.timeseries.python.timeseries import model
from tensorflow.contrib.timeseries.python.timeseries import state_management
+from tensorflow.python.client import session as session_lib
from tensorflow.python.estimator import estimator_lib
+from tensorflow.python.feature_column import feature_column
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@@ -31,6 +39,9 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import metrics
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
+from tensorflow.python.saved_model import loader
+from tensorflow.python.saved_model import tag_constants
+from tensorflow.python.training import adam
from tensorflow.python.training import coordinator as coordinator_lib
from tensorflow.python.training import queue_runner_impl
from tensorflow.python.training import training as train
@@ -90,7 +101,7 @@ class EvaluationMetricsTests(test.TestCase):
.count_up_to(10),
dtype=dtypes.float32), (1, 1, 1))
}
- model_fn = ts_head_lib.time_series_regression_head(
+ model_fn = ts_head_lib.TimeSeriesRegressionHead(
model=_TickerModel(),
state_manager=state_management.PassthroughStateManager(),
optimizer=train.GradientDescentOptimizer(0.001)).create_estimator_spec
@@ -127,7 +138,7 @@ class _StubModel(object):
def _stub_model_fn():
- return ts_head_lib.time_series_regression_head(
+ return ts_head_lib.TimeSeriesRegressionHead(
model=_StubModel(),
state_manager=state_management.PassthroughStateManager(),
optimizer=train.AdamOptimizer(0.001)).create_estimator_spec
@@ -263,5 +274,76 @@ class PredictFeatureCheckingTests(test.TestCase):
mode=estimator_lib.ModeKeys.PREDICT)
+class OneShotTests(test.TestCase):
+
+ def test_one_shot_prediction_head_export(self):
+ model_dir = self.get_temp_dir()
+ categorical_column = feature_column.categorical_column_with_hash_bucket(
+ key="categorical_exogenous_feature", hash_bucket_size=16)
+ exogenous_feature_columns = [
+ feature_column.numeric_column(
+ "2d_exogenous_feature", shape=(2,)),
+ feature_column.embedding_column(
+ categorical_column=categorical_column, dimension=10)]
+ estimator = ts_estimators.TimeSeriesRegressor(
+ model=lstm_example._LSTMModel(
+ num_features=5, num_units=128,
+ exogenous_feature_columns=exogenous_feature_columns),
+ optimizer=adam.AdamOptimizer(0.001),
+ config=estimator_lib.RunConfig(tf_random_seed=4),
+ state_manager=state_management.ChainingStateManager(),
+ head_type=ts_head_lib.OneShotPredictionHead,
+ model_dir=model_dir)
+ train_features = {
+ feature_keys.TrainEvalFeatures.TIMES: numpy.arange(
+ 20, dtype=numpy.int64),
+ feature_keys.TrainEvalFeatures.VALUES: numpy.tile(numpy.arange(
+ 20, dtype=numpy.float32)[:, None], [1, 5]),
+ "2d_exogenous_feature": numpy.ones([20, 2]),
+ "categorical_exogenous_feature": numpy.array(
+ ["strkey"] * 20)[:, None]
+ }
+ train_input_fn = input_pipeline.RandomWindowInputFn(
+ input_pipeline.NumpyReader(train_features), shuffle_seed=2,
+ num_threads=1, batch_size=16, window_size=16)
+ estimator.train(input_fn=train_input_fn, steps=5)
+ input_receiver_fn = estimator.build_raw_serving_input_receiver_fn()
+ export_location = estimator.export_savedmodel(self.get_temp_dir(),
+ input_receiver_fn)
+ graph = ops.Graph()
+ with graph.as_default():
+ with session_lib.Session() as session:
+ signatures = loader.load(
+ session, [tag_constants.SERVING], export_location)
+ self.assertEqual([feature_keys.SavedModelLabels.PREDICT],
+ list(signatures.signature_def.keys()))
+ predict_signature = signatures.signature_def[
+ feature_keys.SavedModelLabels.PREDICT]
+ six.assertCountEqual(
+ self,
+ [feature_keys.FilteringFeatures.TIMES,
+ feature_keys.FilteringFeatures.VALUES,
+ "2d_exogenous_feature",
+ "categorical_exogenous_feature"],
+ predict_signature.inputs.keys())
+ features = {
+ feature_keys.TrainEvalFeatures.TIMES: numpy.tile(
+ numpy.arange(35, dtype=numpy.int64)[None, :], [2, 1]),
+ feature_keys.TrainEvalFeatures.VALUES: numpy.tile(numpy.arange(
+ 20, dtype=numpy.float32)[None, :, None], [2, 1, 5]),
+ "2d_exogenous_feature": numpy.ones([2, 35, 2]),
+ "categorical_exogenous_feature": numpy.tile(numpy.array(
+ ["strkey"] * 35)[None, :, None], [2, 1, 1])
+ }
+ feeds = {
+ graph.as_graph_element(input_value.name): features[input_key]
+ for input_key, input_value in predict_signature.inputs.items()}
+ fetches = {output_key: graph.as_graph_element(output_value.name)
+ for output_key, output_value
+ in predict_signature.outputs.items()}
+ output = session.run(fetches, feed_dict=feeds)
+ self.assertAllEqual((2, 15, 5), output["mean"].shape)
+
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/BUILD b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/BUILD
index ca25ccd2b8..5d33e23a42 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/BUILD
+++ b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/BUILD
@@ -40,6 +40,7 @@ py_test(
timeout = "long", # Moderate but for asan
srcs = ["state_space_model_test.py"],
srcs_version = "PY2AND3",
+ tags = ["no_windows"], # TODO: needs investigation on Windows
deps = [
":state_space_model",
"//tensorflow/contrib/layers:layers_py",
diff --git a/tensorflow/contrib/tpu/BUILD b/tensorflow/contrib/tpu/BUILD
index 3e32a7a85c..4de09dd988 100644
--- a/tensorflow/contrib/tpu/BUILD
+++ b/tensorflow/contrib/tpu/BUILD
@@ -159,6 +159,7 @@ py_library(
name = "tpu_lib",
srcs = [
"python/tpu/__init__.py",
+ "python/tpu/bfloat16.py",
"python/tpu/device_assignment.py",
"python/tpu/topology.py",
"python/tpu/tpu.py",
@@ -214,6 +215,7 @@ tf_py_test(
":datasets",
],
grpc_enabled = True,
+ tags = ["no_windows"],
)
tf_py_test(
@@ -227,6 +229,7 @@ tf_py_test(
"//tensorflow/python:framework",
"//tensorflow/python:layers",
],
+ tags = ["no_windows"], # TODO: needs investigation on Windows
)
tf_py_test(
@@ -241,6 +244,17 @@ tf_py_test(
)
tf_py_test(
+ name = "bfloat16_test",
+ size = "small",
+ srcs = ["python/tpu/bfloat16_test.py"],
+ additional_deps = [
+ ":tpu",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework",
+ ],
+)
+
+tf_py_test(
name = "tpu_infeed_test",
size = "small",
srcs = ["python/tpu/tpu_infeed_test.py"],
diff --git a/tensorflow/contrib/tpu/__init__.py b/tensorflow/contrib/tpu/__init__.py
index ea6e874f2d..bb60f3e2d7 100644
--- a/tensorflow/contrib/tpu/__init__.py
+++ b/tensorflow/contrib/tpu/__init__.py
@@ -53,6 +53,7 @@ from __future__ import print_function
# pylint: disable=wildcard-import,unused-import
from tensorflow.contrib.tpu.python import profiler
from tensorflow.contrib.tpu.python.ops.tpu_ops import *
+from tensorflow.contrib.tpu.python.tpu.bfloat16 import *
from tensorflow.contrib.tpu.python.tpu.device_assignment import *
from tensorflow.contrib.tpu.python.tpu.topology import *
from tensorflow.contrib.tpu.python.tpu.tpu import *
diff --git a/tensorflow/contrib/tpu/profiler/tf_op_stats.proto b/tensorflow/contrib/tpu/profiler/tf_op_stats.proto
index 590db2c376..2a15875627 100644
--- a/tensorflow/contrib/tpu/profiler/tf_op_stats.proto
+++ b/tensorflow/contrib/tpu/profiler/tf_op_stats.proto
@@ -79,6 +79,10 @@ message StepInfoResult {
optional uint64 infeed_duration_ps = 3;
// The start time of this step in picoseconds.
optional uint64 begin_ps = 4;
+ // The waiting time within this step in picoseconds.
+ optional uint64 wait_duration_ps = 5;
+ // The time spent on cross-replica-sum in picoseconds.
+ optional uint64 crs_duration_ps = 6;
}
// Result proto for a sequence of steps.
diff --git a/tensorflow/contrib/tpu/python/tpu/bfloat16.py b/tensorflow/contrib/tpu/python/tpu/bfloat16.py
new file mode 100644
index 0000000000..5e49af6408
--- /dev/null
+++ b/tensorflow/contrib/tpu/python/tpu/bfloat16.py
@@ -0,0 +1,77 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+"""Helper context for running models with bfloat16."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.util import tf_contextlib
+
+
+def _get_custom_getter():
+ """Returns a custom getter that this class's methods must be called under.
+
+ All methods of this class must be called under a variable scope that was
+ passed this custom getter. Example:
+
+ ```python
+ network = ConvNetBuilder(...)
+ with tf.variable_scope('cg', custom_getter=network.get_custom_getter()):
+ network.conv(...)
+ # Call more methods of network here
+ ```
+
+ Currently, this custom getter only does anything if self.use_tf_layers is
+ True. In that case, it causes variables to be stored as dtype
+ self.variable_type, then casted to the requested dtype, instead of directly
+ storing the variable as the requested dtype.
+ """
+
+ def inner_custom_getter(getter, *args, **kwargs):
+ """Custom getter that forces variables to have type self.variable_type."""
+ cast_to_bfloat16 = False
+ requested_dtype = kwargs['dtype']
+ if requested_dtype == dtypes.bfloat16:
+ # Only change the variable dtype if doing so does not decrease variable
+ # precision.
+ kwargs['dtype'] = dtypes.float32
+ cast_to_bfloat16 = True
+ var = getter(*args, **kwargs)
+ # This if statement is needed to guard the cast, because batch norm
+ # assigns directly to the return value of this custom getter. The cast
+ # makes the return value not a variable so it cannot be assigned. Batch
+ # norm variables are always in fp32 so this if statement is never
+ # triggered for them.
+ if cast_to_bfloat16:
+ var = math_ops.cast(var, dtypes.bfloat16)
+ return var
+
+ return inner_custom_getter
+
+
+@tf_contextlib.contextmanager
+def bfloat16_scope():
+ """Scope class for bfloat16 variables so that the model uses custom getter.
+
+ This enables variables to be read as bfloat16 type when using get_variable.
+ """
+ with variable_scope.variable_scope(
+ 'bfloat16', custom_getter=_get_custom_getter()) as varscope:
+ yield varscope
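A minimal usage sketch for the new scope (TF 1.x variable API):

```python
import tensorflow as tf
from tensorflow.contrib.tpu.python.tpu import bfloat16

with bfloat16.bfloat16_scope():
  # Requested as bfloat16, stored as float32, read back as bfloat16.
  v = tf.get_variable("v", [], dtype=tf.bfloat16)
assert v.dtype.base_dtype == tf.bfloat16
```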
diff --git a/tensorflow/contrib/tpu/python/tpu/bfloat16_test.py b/tensorflow/contrib/tpu/python/tpu/bfloat16_test.py
new file mode 100644
index 0000000000..48a01c7308
--- /dev/null
+++ b/tensorflow/contrib/tpu/python/tpu/bfloat16_test.py
@@ -0,0 +1,50 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+"""Tests for bfloat16 helper."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.tpu.python.tpu import bfloat16
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import variable_scope
+
+from tensorflow.python.platform import test
+
+
+class BFloat16ScopeTest(test.TestCase):
+
+ def testScopeName(self):
+ """Test if name for the variable scope is propogated correctly.
+ """
+ with bfloat16.bfloat16_scope() as bf:
+ self.assertEqual(bf.name, "bfloat16")
+
+ def testRequestedDType(self):
+ """Test if requested dtype is honored in the getter.
+ """
+ with bfloat16.bfloat16_scope() as scope:
+ v1 = variable_scope.get_variable("v1", [])
+ self.assertEqual(v1.dtype.base_dtype, dtypes.float32)
+ v2 = variable_scope.get_variable("v2", [], dtype=dtypes.bfloat16)
+ self.assertEqual(v2.dtype.base_dtype, dtypes.bfloat16)
+ self.assertEqual([dtypes.float32, dtypes.float32],
+ [v.dtype.base_dtype for v in scope.global_variables()])
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
index fa56708f44..6834600b79 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
@@ -2019,7 +2019,8 @@ class TPUEstimator(estimator_lib.Estimator):
host_ops,
run_infeed_loop_on_coordinator=(
run_infeed_loop_on_coordinator)),
- ExamplesPerSecondHook(ctx.global_batch_size),
+ ExamplesPerSecondHook(ctx.global_batch_size,
+ output_dir=self.model_dir),
InstallSignalHandlerHook(),
training.LoggingTensorHook(
{
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_system_metadata.py b/tensorflow/contrib/tpu/python/tpu/tpu_system_metadata.py
index eea57ed336..3ae350c7bb 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_system_metadata.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_system_metadata.py
@@ -120,7 +120,8 @@ def _query_tpu_system_metadata(master_address, run_config,
logging.info('*** Num TPU Workers: %d', metadata.num_hosts)
logging.info('*** Num TPU Cores Per Worker: %d',
metadata.num_of_cores_per_host)
- logging.info('*** Available Devices: %s', metadata.devices)
+ for device in metadata.devices:
+ logging.info('*** Available Device: %s', device)
else:
logging.info('Failed to find TPU: %s', metadata)
return metadata
diff --git a/tensorflow/contrib/training/python/training/batch_sequences_with_states_test.py b/tensorflow/contrib/training/python/training/batch_sequences_with_states_test.py
index dbdbb08a82..f305197c19 100644
--- a/tensorflow/contrib/training/python/training/batch_sequences_with_states_test.py
+++ b/tensorflow/contrib/training/python/training/batch_sequences_with_states_test.py
@@ -27,6 +27,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
@@ -517,6 +518,7 @@ class BatchSequencesWithStatesTestWithCApi(BatchSequencesWithStatesTest):
ops._USE_C_API = self._prev_value
+@test_util.with_c_api
class PaddingTest(test.TestCase):
def testPaddingInvalidLengths(self):
diff --git a/tensorflow/contrib/training/python/training/sequence_queueing_state_saver.py b/tensorflow/contrib/training/python/training/sequence_queueing_state_saver.py
index 7223194885..99d486b183 100644
--- a/tensorflow/contrib/training/python/training/sequence_queueing_state_saver.py
+++ b/tensorflow/contrib/training/python/training/sequence_queueing_state_saver.py
@@ -1574,8 +1574,9 @@ def _padding(sequences, num_unroll):
if not sequences:
return 0, {}
- sequences_dict = {}
- for key, value in sequences.items():
+ # Sort 'sequences_dict' so 'length' will have a predictable value below.
+ sequences_dict = collections.OrderedDict()
+ for key, value in sorted(sequences.items()):
if not (isinstance(value, sparse_tensor.SparseTensor) or
isinstance(value, sparse_tensor.SparseTensorValue)):
sequences_dict[key] = ops.convert_to_tensor(value)
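A small sketch of why the sort matters (keys are illustrative): plain dicts had no guaranteed iteration order before Python 3.7, so whichever sequence was visited last -- and hence `length` -- could vary between runs.

```python
import collections

sequences = {"b": [[1, 2, 3]], "a": [[4, 5, 6]]}
ordered = collections.OrderedDict(sorted(sequences.items()))
assert list(ordered.keys()) == ["a", "b"]  # fixed, regardless of insertion
```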
diff --git a/tensorflow/contrib/util/loader.py b/tensorflow/contrib/util/loader.py
index f4283cd9ed..dca01d26f4 100644
--- a/tensorflow/contrib/util/loader.py
+++ b/tensorflow/contrib/util/loader.py
@@ -42,9 +42,10 @@ def load_op_library(path):
plugin.
"""
if os.name == 'nt':
- # To avoid makeing every user_ops aware of windows, re-write
- # the file extension from .so to .dll.
- path = re.sub(r'\.so$', '.dll', path)
+    # To avoid making every user_op aware of Windows, rewrite the file
+    # extension from .so to .dll if the .so file doesn't exist.
+ if not os.path.exists(path):
+ path = re.sub(r'\.so$', '.dll', path)
# Currently we have only some user_ops as dlls on windows - don't try
# to load them if the dll is not found.
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 21f7866abd..7d5ae1c5b5 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -349,6 +349,7 @@ cc_library(
"platform/env.h",
"platform/env_time.h",
"platform/file_system.h",
+ "platform/file_system_helper.h",
"platform/fingerprint.h",
"platform/init_main.h",
"platform/logging.h",
diff --git a/tensorflow/core/api_def/base_api/api_def_For.pbtxt b/tensorflow/core/api_def/base_api/api_def_For.pbtxt
new file mode 100644
index 0000000000..a7cd8e1a26
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_For.pbtxt
@@ -0,0 +1,29 @@
+op {
+ graph_op_name: "For"
+ in_arg { name: "start" description: "The lower bound. An int32." }
+ in_arg { name: "limit" description: "The upper bound. An int32." }
+ in_arg { name: "delta" description: "The increment. An int32." }
+ in_arg {
+ name: "input"
+ description: "A list of input tensors whose types are T."
+ }
+ out_arg {
+ name: "output"
+ description: "A list of output tensors whose types are T."
+ }
+ attr { name: "T" description: "A list of dtypes." }
+ attr {
+ name: "body"
+ description: <<END
+ A function that takes a list of tensors (int32, T) and returns another
+ list of tensors (T).
+END
+ }
+ summary: <<END
+ ```python
+ output = input
+ for i in range(start, limit, delta):
+   output = body(i, output)
+ ```
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_If.pbtxt b/tensorflow/core/api_def/base_api/api_def_If.pbtxt
new file mode 100644
index 0000000000..7ba5a3f37e
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_If.pbtxt
@@ -0,0 +1,39 @@
+op {
+ graph_op_name: "If"
+ in_arg {
+ name: "cond"
+ description: <<END
+ A Tensor. If the tensor is a scalar of non-boolean type, the
+ scalar is converted to a boolean according to the
+ following rule: if the scalar is a numerical value, non-zero means
+ `True` and zero means `False`; if the scalar is a string, non-empty
+ means `True` and empty means `False`. If the tensor is not a scalar,
+ being empty means `False` and being non-empty means `True`.
+END
+ }
+ in_arg {
+ name: "input"
+ description: "A list of input tensors."
+ }
+ out_arg {
+ name: "output"
+ description: "A list of return values."
+ }
+ attr { name: "Tin" description: "A list of input types." }
+ attr { name: "Tout" description: "A list of output types." }
+ attr {
+ name: "then_branch"
+ description: <<END
+ A function that takes 'inputs' and returns a list of tensors, whose
+ types are the same as what else_branch returns.
+END
+ }
+ attr {
+ name: "else_branch"
+ description: <<END
+ A function that takes 'inputs' and returns a list of tensors, whose
+ types are the same as what then_branch returns.
+END
+ }
+ summary: "output = cond ? then_branch(input) : else_branch(input)"
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_While.pbtxt b/tensorflow/core/api_def/base_api/api_def_While.pbtxt
new file mode 100644
index 0000000000..95a19c6dff
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_While.pbtxt
@@ -0,0 +1,33 @@
+op {
+ graph_op_name: "While"
+ in_arg {
+ name: "input"
+ description: "A list of input tensors whose types are T."
+ }
+ out_arg {
+ name: "output"
+ description: "A list of output tensors whose types are T."
+ }
+ attr { name: "T" description: "dtype in use." }
+ attr {
+ name: "cond"
+ description: <<END
+ A function that takes 'input' and returns a tensor. If the tensor
+ is a scalar of non-boolean type, the scalar is converted to a
+ boolean according to the following rule: if the scalar is a
+ numerical value, non-zero means True and zero means False; if the
+ scalar is a string, non-empty means True and empty means False.
+ If the tensor is not a scalar, non-empty means True and empty
+ means False.
+END
+ }
+ attr {
+ name: "body"
+ description: <<END
+ A function that takes a list of tensors and returns another
+ list of tensors. Both lists have the same types as specified
+ by T.
+END
+ }
+ summary: "output = input; While (Cond(output)) { output = Body(output) }"
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_For.pbtxt b/tensorflow/core/api_def/python_api/api_def_For.pbtxt
new file mode 100644
index 0000000000..a58ddf56fe
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_For.pbtxt
@@ -0,0 +1 @@
+op { graph_op_name: "For" visibility: HIDDEN }
diff --git a/tensorflow/core/api_def/python_api/api_def_If.pbtxt b/tensorflow/core/api_def/python_api/api_def_If.pbtxt
new file mode 100644
index 0000000000..a44db5da08
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_If.pbtxt
@@ -0,0 +1 @@
+op { graph_op_name: "If" visibility: HIDDEN }
diff --git a/tensorflow/core/api_def/python_api/api_def_ScatterAdd.pbtxt b/tensorflow/core/api_def/python_api/api_def_ScatterAdd.pbtxt
new file mode 100644
index 0000000000..4f5b6decf6
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_ScatterAdd.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "ScatterAdd"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_While.pbtxt b/tensorflow/core/api_def/python_api/api_def_While.pbtxt
new file mode 100644
index 0000000000..f47a9b0fce
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_While.pbtxt
@@ -0,0 +1 @@
+op { graph_op_name: "While" visibility: HIDDEN }
diff --git a/tensorflow/core/common_runtime/direct_session_test.cc b/tensorflow/core/common_runtime/direct_session_test.cc
index ee38960618..f95cecfc66 100644
--- a/tensorflow/core/common_runtime/direct_session_test.cc
+++ b/tensorflow/core/common_runtime/direct_session_test.cc
@@ -39,6 +39,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/protobuf/rewriter_config.pb.h"
@@ -155,22 +156,22 @@ TEST_F(DirectSessionMinusAXTest, RunSimpleNetwork_Callable) {
Status s = session->RunCallable(handle, {}, nullptr, nullptr);
EXPECT_TRUE(errors::IsInvalidArgument(s));
- EXPECT_TRUE(StringPiece(s.error_message())
- .contains("`fetch_tensors` must be provided"));
+ EXPECT_TRUE(str_util::StrContains(s.error_message(),
+ "`fetch_tensors` must be provided"));
TF_ASSERT_OK(session->ReleaseCallable(handle));
std::vector<Tensor> outputs;
s = session->RunCallable(handle, {}, &outputs, nullptr);
EXPECT_TRUE(errors::IsInvalidArgument(s));
- EXPECT_TRUE(
- StringPiece(s.error_message())
- .contains("Attempted to run callable after handle was released"));
+ EXPECT_TRUE(str_util::StrContains(
+ s.error_message(),
+ "Attempted to run callable after handle was released"));
s = session->RunCallable(handle + 1, {}, &outputs, nullptr);
EXPECT_TRUE(errors::IsInvalidArgument(s));
EXPECT_TRUE(
- StringPiece(s.error_message()).contains("No such callable handle"));
+ str_util::StrContains(s.error_message(), "No such callable handle"));
}
}
@@ -567,7 +568,7 @@ TEST(DirectSessionTest, MultipleFeedTest) {
{first_identity->name() + ":0", second_identity->name() + ":0"}, {},
&outputs);
EXPECT_TRUE(errors::IsInvalidArgument(s));
- EXPECT_TRUE(StringPiece(s.error_message()).contains("fed more than once"));
+ EXPECT_TRUE(str_util::StrContains(s.error_message(), "fed more than once"));
}
TEST(DirectSessionTest, MultipleFeedTest_Callable) {
@@ -650,7 +651,7 @@ TEST(DirectSessionTest, MultipleFeedTest_Callable) {
{first_identity->name() + ":0", second_identity->name() + ":0"}, {}),
&handle);
EXPECT_TRUE(errors::IsInvalidArgument(s));
- EXPECT_TRUE(StringPiece(s.error_message()).contains("fed more than once"));
+ EXPECT_TRUE(str_util::StrContains(s.error_message(), "fed more than once"));
}
TEST(DirectSessionTest, FetchMultipleTimes) {
@@ -845,8 +846,8 @@ TEST(DirectSessionTest, PartialRunMissingFeed) {
s = session->PRun(handle, {{first_const->name(), value_11}},
{third_identity->name() + ":0"}, &outputs);
ASSERT_TRUE(errors::IsInvalidArgument(s));
- EXPECT_TRUE(StringPiece(s.error_message())
- .contains("can't be computed from the feeds"));
+ EXPECT_TRUE(str_util::StrContains(s.error_message(),
+ "can't be computed from the feeds"));
}
TEST(DirectSessionTest, PartialRunMultiOutputFeed) {
@@ -875,8 +876,8 @@ TEST(DirectSessionTest, PartialRunMultiOutputFeed) {
// Fetch fourth_identity without feeds.
s = session->PRun(handle, {}, {fourth_identity->name() + ":0"}, &outputs);
ASSERT_TRUE(errors::IsInvalidArgument(s));
- EXPECT_TRUE(StringPiece(s.error_message())
- .contains("can't be computed from the feeds"));
+ EXPECT_TRUE(str_util::StrContains(s.error_message(),
+ "can't be computed from the feeds"));
// Feed switch_node:1 and fetch fourth_identity.
s = session->PRun(handle, {{switch_node->name() + ":1", bool_value}},
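The edits above follow a pattern repeated throughout this change: `StringPiece` member functions (`contains`, `starts_with`, `Consume`) give way to free functions in `str_util`. A minimal before/after sketch with an illustrative string, not code from these tests:

```cpp
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/strings/str_util.h"

void Example(tensorflow::StringPiece msg) {
  // Old style: methods on StringPiece.
  //   StringPiece(msg).contains("error");
  //   msg.starts_with("tf.");
  // New style: free functions that take the pieces as arguments.
  bool has_error = tensorflow::str_util::StrContains(msg, "error");
  bool has_prefix = tensorflow::str_util::StartsWith(msg, "tf.");
  (void)has_error;
  (void)has_prefix;
}
```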
diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc
index b06b75d658..0c461a9ee9 100644
--- a/tensorflow/core/common_runtime/executor.cc
+++ b/tensorflow/core/common_runtime/executor.cc
@@ -258,6 +258,13 @@ struct NodeItem {
// Return array of per-output allocator attributes.
const AllocatorAttributes* output_attrs() const { return output_attr_base(); }
+ // Return array of expected input indices from which each output should
+ // be forwarded:
+ // kNeverForward (-2) for DO NOT FORWARD (must allocate).
+ // kNoReservation (-1) for no expected forwarding.
+ // 0... for forward from that input.
+ const int* forward_from() const { return forward_from_base(); }
+
private:
friend class GraphView;
@@ -267,6 +274,7 @@ struct NodeItem {
// AllocatorAttributes output_attr[num_outputs];
// uint8 input_type[num_inputs];
// uint8 output_type[num_outputs];
+ // int forward_from[num_outputs];
// Return pointer to variable length section.
char* var() const {
@@ -292,6 +300,13 @@ struct NodeItem {
sizeof(AllocatorAttributes) * num_outputs + sizeof(uint8) * num_inputs);
}
+ int* forward_from_base() const {
+ return reinterpret_cast<int*>(var() + sizeof(EdgeInfo) * num_output_edges +
+ sizeof(AllocatorAttributes) * num_outputs +
+ sizeof(uint8) * num_inputs +
+ sizeof(uint8) * num_outputs);
+ }
+
TF_DISALLOW_COPY_AND_ASSIGN(NodeItem);
};
@@ -466,7 +481,8 @@ size_t GraphView::NodeItemBytes(const Node* n) {
+ num_output_edges * sizeof(EdgeInfo) // output_edges[...]
+ num_outputs * sizeof(AllocatorAttributes) // output_attr[...]
+ num_inputs * sizeof(uint8) // input_type[num_inputs]
- + num_outputs * sizeof(uint8); // output_type[num_outputs]
+ + num_outputs * sizeof(uint8) // output_type[num_outputs]
+ + num_outputs * sizeof(int); // forward_from[num_outputs]
static constexpr size_t kItemAlignment = sizeof(NodeItem*);
static_assert(kItemAlignment % alignof(NodeItem) == 0,
"NodeItem must be aligned with kItemAlignment");
@@ -737,8 +753,8 @@ Status InferAllocAttr(const Node* n, const Node* dst,
VLOG(2) << "node " << n->name() << " is the sink of an RPC in";
} else if ((local_dev_name.type == "CPU" || n->IsHostRecv()) &&
parsed_src_name.type != "CPU") {
- // Value is going to be the sink of a local DMA from GPU to CPU (or other
- // types of accelerators).
+ // Value is going to be the sink of a local DMA from GPU to CPU (or
+ // other types of accelerators).
attr->set_gpu_compatible(true);
VLOG(2) << "node " << n->name() << " is the sink of a gpu->cpu copy";
} else {
@@ -1022,7 +1038,8 @@ class ExecutorState {
int total_input_tensors = 0;
std::vector<const Node*>* nodes = nullptr;
- // Lock ordering: ExecutorState.mu_ < mu.
+ // Lock ordering: ExecutorState.mu_ < mu;
+ // during structured traversal: parent_frame->mu < mu.
mutex mu;
void InitializeFrameInfo(const string& enter_name) {
@@ -1090,7 +1107,8 @@ class ExecutorState {
void ActivateLoopInvs(const GraphView* gview, int64 iter,
TaggedNodeSeq* ready) EXCLUSIVE_LOCKS_REQUIRED(mu);
- // Add a new loop invariant and make it available to all active iterations.
+ // Add a new loop invariant and make it available to all active
+ // iterations.
void AddLoopInv(const NodeItem* item, const Entry& value,
TaggedNodeSeq* ready) EXCLUSIVE_LOCKS_REQUIRED(mu);
@@ -1147,8 +1165,8 @@ class ExecutorState {
if (front_index_ == ready_.size()) {
ready_.clear();
} else {
- // Lots of unused entries at beginning of vector: move everything down
- // to start of vector.
+ // Lots of unused entries at beginning of vector: move everything
+ // down to start of vector.
ready_.erase(ready_.begin(), ready_.begin() + front_index_);
}
front_index_ = 0;
@@ -1596,6 +1614,7 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_usec) {
params.frame_iter = FrameAndIter(input_frame->frame_id, input_iter);
params.is_input_dead = is_input_dead;
params.output_attr_array = item.output_attrs();
+ params.forward_from_array = nullptr; // later: item.forward_from();
if (item.kernel_is_async) {
// Asynchronous computes.
@@ -2333,8 +2352,9 @@ void ExecutorState::DeleteFrame(FrameState* frame, TaggedNodeSeq* ready) {
FrameState* parent_frame = frame->parent_frame;
const int64 parent_iter = frame->parent_iter;
if (parent_frame != nullptr) {
- mutex_lock paranet_frame_lock(parent_frame->mu);
+ mutex_lock parent_frame_lock(parent_frame->mu);
// Propagate all the dead exits to the parent frame.
+ mutex_lock this_frame_lock(frame->mu);
for (const Node* node : frame->dead_exits) {
auto parent_iter_state = parent_frame->GetIteration(parent_iter);
for (const Edge* e : node->out_edges()) {
@@ -2603,7 +2623,7 @@ void ExecutorImpl::RunAsync(const Args& args, DoneCallback done) {
(new ExecutorState(args, this))->RunAsync(std::move(done));
}
-} // end namespace
+} // namespace
Status NewLocalExecutor(const LocalExecutorParams& params,
std::unique_ptr<const Graph> graph,
@@ -2629,4 +2649,4 @@ Status CreateNonCachedKernel(Device* device, FunctionLibraryRuntime* flib,
void DeleteNonCachedKernel(OpKernel* kernel) { delete kernel; }
-} // end namespace tensorflow
+} // namespace tensorflow
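`forward_from_base()` above extends `NodeItem`'s hand-rolled variable-length layout: one allocation holds a fixed header plus several trailing arrays, each recovered by offset arithmetic from `var()`. A self-contained sketch of the technique with simplified names (not the real `NodeItem`); note the wider `int` array is placed first here to keep it naturally aligned:

```cpp
#include <cstddef>
#include <cstdint>
#include <cstdlib>

// One allocation: header, then int forward_from[num_outputs],
// then uint8_t input_type[num_inputs].
struct Item {
  int num_inputs;
  int num_outputs;

  // Start of the variable-length section, immediately after the header.
  char* var() const {
    return const_cast<char*>(reinterpret_cast<const char*>(this) +
                             sizeof(Item));
  }
  int* forward_from() const { return reinterpret_cast<int*>(var()); }
  uint8_t* input_types() const {
    return reinterpret_cast<uint8_t*>(var() + sizeof(int) * num_outputs);
  }

  static Item* Create(int num_inputs, int num_outputs) {
    std::size_t bytes = sizeof(Item) + sizeof(int) * num_outputs +
                        sizeof(uint8_t) * num_inputs;
    Item* item = static_cast<Item*>(std::calloc(1, bytes));
    item->num_inputs = num_inputs;
    item->num_outputs = num_outputs;
    return item;
  }
};
```

The payoff is the same as in the executor: one allocation per node instead of one per array, and better cache locality when the arrays are read together.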
diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc
index d17ef4d459..61b2f0e60f 100644
--- a/tensorflow/core/common_runtime/function_test.cc
+++ b/tensorflow/core/common_runtime/function_test.cc
@@ -39,6 +39,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/public/version.h"
@@ -53,8 +54,8 @@ Status GetOpSig(const string& op, const OpDef** sig) {
return OpRegistry::Global()->LookUpOpDef(op, sig);
}
-void HasError(const Status& s, const string& substr) {
- EXPECT_TRUE(StringPiece(s.ToString()).contains(substr))
+void HasError(const Status& s, StringPiece substr) {
+ EXPECT_TRUE(str_util::StrContains(s.ToString(), substr))
<< s << ", expected substring " << substr;
}
@@ -240,7 +241,7 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
Status status2 = Run(flr, handle, opts, args, std::move(rets));
EXPECT_TRUE(errors::IsInvalidArgument(status2));
EXPECT_TRUE(
- StringPiece(status2.error_message()).contains("remote execution."));
+ str_util::StrContains(status2.error_message(), "remote execution."));
return status;
}
@@ -310,7 +311,7 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
Status status2 = Run(flr, handle, opts, args, std::move(rets));
EXPECT_TRUE(errors::IsInvalidArgument(status2));
EXPECT_TRUE(
- StringPiece(status2.error_message()).contains("remote execution."));
+ str_util::StrContains(status2.error_message(), "remote execution."));
return status;
}
diff --git a/tensorflow/core/common_runtime/function_threadpool_test.cc b/tensorflow/core/common_runtime/function_threadpool_test.cc
index 6223a4e648..2d09e83d01 100644
--- a/tensorflow/core/common_runtime/function_threadpool_test.cc
+++ b/tensorflow/core/common_runtime/function_threadpool_test.cc
@@ -39,6 +39,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/public/version.h"
@@ -153,7 +154,7 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
Status status2 = Run(flr, handle, opts, args, std::move(rets));
EXPECT_TRUE(errors::IsInvalidArgument(status2));
EXPECT_TRUE(
- StringPiece(status2.error_message()).contains("remote execution."));
+ str_util::StrContains(status2.error_message(), "remote execution."));
return status;
}
diff --git a/tensorflow/core/common_runtime/placer.cc b/tensorflow/core/common_runtime/placer.cc
index e128b9257f..86851c2c07 100644
--- a/tensorflow/core/common_runtime/placer.cc
+++ b/tensorflow/core/common_runtime/placer.cc
@@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/strings/str_util.h"
namespace tensorflow {
@@ -151,7 +152,8 @@ class ColocationGraph {
if (attr_value != nullptr && attr_value->has_list()) {
for (const string& class_spec : attr_value->list().s()) {
StringPiece spec(class_spec);
- if (spec.Consume(kColocationGroupPrefixStringPiece)) {
+ if (str_util::ConsumePrefix(&spec,
+ kColocationGroupPrefixStringPiece)) {
found_spec = true;
TF_RETURN_IF_ERROR(
ColocateNodeToGroup(&colocation_group_root, node, spec));
diff --git a/tensorflow/core/common_runtime/placer_test.cc b/tensorflow/core/common_runtime/placer_test.cc
index 098024d219..5ad251c892 100644
--- a/tensorflow/core/common_runtime/placer_test.cc
+++ b/tensorflow/core/common_runtime/placer_test.cc
@@ -34,6 +34,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/error_codes.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/test.h"
@@ -262,9 +263,9 @@ class PlacerTest : public ::testing::Test {
->attributes() \
.device_type())
-#define EXPECT_DEVICE_CONTAINS(g, name, device_substr) \
- EXPECT_TRUE(StringPiece(GetNodeByName((g), (name))->assigned_device_name()) \
- .contains(device_substr))
+#define EXPECT_DEVICE_CONTAINS(g, name, device_substr) \
+ EXPECT_TRUE(::tensorflow::str_util::StrContains( \
+ GetNodeByName((g), (name))->assigned_device_name(), device_substr))
// Test that a graph with no constraints will successfully assign nodes to the
// "best available" device (i.e. prefer GPU over CPU).
@@ -488,11 +489,10 @@ TEST_F(PlacerTest, TestAssignedGpuDeviceToCpuDevice) {
Status s = Place(&g);
EXPECT_EQ(error::INTERNAL, s.code());
- EXPECT_TRUE(
- StringPiece(s.error_message())
- .contains(
- "Assigned device '/job:a/replica:0/task:0/device:fakegpu:0' "
- "does not have registered OpKernel support for TestInput"));
+ EXPECT_TRUE(str_util::StrContains(
+ s.error_message(),
+ "Assigned device '/job:a/replica:0/task:0/device:fakegpu:0' "
+ "does not have registered OpKernel support for TestInput"));
}
// Test that graphs with reference connections are correctly placed.
@@ -541,15 +541,15 @@ TEST_F(PlacerTest, TestReferenceConnection) {
{
Status s = ReferenceTestHelper("VariableCPU", "AssignGPU", "FakeCPU");
EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
- EXPECT_TRUE(StringPiece(s.error_message())
- .contains("no device type supports both of those nodes"));
+ EXPECT_TRUE(str_util::StrContains(
+ s.error_message(), "no device type supports both of those nodes"));
}
TF_EXPECT_OK(ReferenceTestHelper("VariableGPU", "TestAssign", "FakeGPU"));
{
Status s = ReferenceTestHelper("VariableGPU", "AssignCPU", "FakeCPU");
EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
- EXPECT_TRUE(StringPiece(s.error_message())
- .contains("no device type supports both of those nodes"));
+ EXPECT_TRUE(str_util::StrContains(
+ s.error_message(), "no device type supports both of those nodes"));
}
TF_EXPECT_OK(ReferenceTestHelper("VariableGPU", "AssignGPU", "FakeGPU"));
}
@@ -760,8 +760,9 @@ TEST_F(PlacerTest, TestInvalidMultipleColocationGroups) {
}
Status s = Place(&g);
- EXPECT_TRUE(StringPiece(s.error_message())
- .contains("Cannot colocate nodes 'foo' and 'in' because no "
+ EXPECT_TRUE(
+ str_util::StrContains(s.error_message(),
+ "Cannot colocate nodes 'foo' and 'in' because no "
"device type supports both of those nodes and the "
"other nodes colocated with them"));
}
@@ -824,11 +825,11 @@ TEST_F(PlacerTest, TestColocationGroupWithUnsatisfiableReferenceConnections) {
}
Status s = Place(&g);
- EXPECT_TRUE(
- StringPiece(s.error_message())
- .contains("Cannot colocate nodes 'var3' and 'assign3' because no "
- "device type supports both of those nodes and the other "
- "nodes colocated with them."));
+ EXPECT_TRUE(str_util::StrContains(
+ s.error_message(),
+ "Cannot colocate nodes 'var3' and 'assign3' because no "
+ "device type supports both of those nodes and the other "
+ "nodes colocated with them."));
}
TEST_F(PlacerTest, TestColocationAndReferenceConnections) {
@@ -888,7 +889,7 @@ TEST_F(PlacerTest, TestEmptyDeviceSet) {
Status s = Place(&g, &empty);
EXPECT_TRUE(
- StringPiece(s.error_message()).contains("No devices are registered"));
+ str_util::StrContains(s.error_message(), "No devices are registered"));
}
// Test that placement fails when the requested device forces an
@@ -913,16 +914,17 @@ TEST_F(PlacerTest, TestHeterogeneousDeviceSetFailure) {
heterogeneous.AddDevice(cpu.get());
Status s = Place(&g, &heterogeneous);
EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
- EXPECT_TRUE(StringPiece(s.error_message())
- .contains("colocated with a group of nodes that required "
+ EXPECT_TRUE(
+ str_util::StrContains(s.error_message(),
+ "colocated with a group of nodes that required "
"incompatible device"));
// The error message should contain information that indicates which
// op types have which registered device types.
- EXPECT_TRUE(StringPiece(s.error_message()).contains("VariableGPU: FakeGPU"))
+ EXPECT_TRUE(str_util::StrContains(s.error_message(), "VariableGPU: FakeGPU"))
<< s;
EXPECT_TRUE(
- StringPiece(s.error_message()).contains("TestAssign: FakeGPU FakeCPU"))
+ str_util::StrContains(s.error_message(), "TestAssign: FakeGPU FakeCPU"))
<< s;
}
@@ -937,7 +939,7 @@ TEST_F(PlacerTest, TestUnknownDevice) {
Status s = Place(&g);
EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
- EXPECT_TRUE(StringPiece(s.error_message()).contains("/job:foo"));
+ EXPECT_TRUE(str_util::StrContains(s.error_message(), "/job:foo"));
}
// Test that placement fails when the combination of partial
@@ -952,7 +954,7 @@ TEST_F(PlacerTest, TestUnknownMergedDevice) {
Status s = Place(&g);
EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
- EXPECT_TRUE(StringPiece(s.error_message()).contains("/job:foo"));
+ EXPECT_TRUE(str_util::StrContains(s.error_message(), "/job:foo"));
}
// Test that placement fails when the previously-assigned device for a
@@ -969,9 +971,9 @@ TEST_F(PlacerTest, TestUnknownAssignedDevice) {
Status s = Place(&g);
EXPECT_EQ(error::INTERNAL, s.code());
- EXPECT_TRUE(
- StringPiece(s.error_message())
- .contains("Assigned device '/job:foo' does not match any device"));
+ EXPECT_TRUE(str_util::StrContains(
+ s.error_message(),
+ "Assigned device '/job:foo' does not match any device"));
}
// Test that placement fails when an op with no registered kernels is
@@ -986,12 +988,11 @@ TEST_F(PlacerTest, TestNoKernelsRegistered) {
Status s = Place(&g);
EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
+ EXPECT_TRUE(str_util::StrContains(
+ s.error_message(),
+ "No OpKernel was registered to support Op 'VariableNoKernels'"));
EXPECT_TRUE(
- StringPiece(s.error_message())
- .contains(
- "No OpKernel was registered to support Op 'VariableNoKernels'"));
- EXPECT_TRUE(
- StringPiece(s.error_message()).contains("<no registered kernels>"));
+ str_util::StrContains(s.error_message(), "<no registered kernels>"));
}
// Test that placement fails when a kernel is registered but no known
@@ -1011,10 +1012,10 @@ TEST_F(PlacerTest, TestNoDevicesRegistered) {
Status s = Place(&g, &cpu_only);
EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
- EXPECT_TRUE(StringPiece(s.error_message())
- .contains("No OpKernel was registered to support "
- "Op 'VariableGPU'"));
- EXPECT_TRUE(StringPiece(s.error_message()).contains("device='FakeGPU'"));
+ EXPECT_TRUE(str_util::StrContains(
+ s.error_message(),
+ "No OpKernel was registered to support Op 'VariableGPU'"));
+ EXPECT_TRUE(str_util::StrContains(s.error_message(), "device='FakeGPU'"));
}
// Test that placement fails when a requested device is malformed.
@@ -1028,8 +1029,8 @@ TEST_F(PlacerTest, TestMalformedDeviceSpecification) {
Status s = Place(&g);
EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
- EXPECT_TRUE(StringPiece(s.error_message())
- .contains("Malformed device specification '/foo:bar'"));
+ EXPECT_TRUE(str_util::StrContains(
+ s.error_message(), "Malformed device specification '/foo:bar'"));
}
// Test that placement fails when a previously-assigned device is malformed.
@@ -1045,8 +1046,8 @@ TEST_F(PlacerTest, TestMalformedAssignedDevice) {
Status s = Place(&g);
EXPECT_EQ(error::INTERNAL, s.code());
- EXPECT_TRUE(StringPiece(s.error_message())
- .contains("Malformed assigned device '/foo:bar'"));
+ EXPECT_TRUE(str_util::StrContains(s.error_message(),
+ "Malformed assigned device '/foo:bar'"));
}
// Test that placement fails when a device was previously assigned to
@@ -1063,9 +1064,8 @@ TEST_F(PlacerTest, TestNonUniqueAssignedDevice) {
Status s = Place(&g);
EXPECT_EQ(error::INTERNAL, s.code());
- EXPECT_TRUE(
- StringPiece(s.error_message())
- .contains("Assigned device '/job:a' does not match any device"));
+ EXPECT_TRUE(str_util::StrContains(
+ s.error_message(), "Assigned device '/job:a' does not match any device"));
}
// Test that ops request to be placed on non-existent devices will be relocated
@@ -1099,7 +1099,7 @@ TEST_F(PlacerTest, TestNonexistentGpuNoAllowSoftPlacement) {
SessionOptions options;
Status s = Place(&g, &options);
EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
- EXPECT_TRUE(StringPiece(s.error_message()).contains("/device:fakegpu:11"));
+ EXPECT_TRUE(str_util::StrContains(s.error_message(), "/device:fakegpu:11"));
}
// Test that placement fails when a node requests an explicit device that is not
@@ -1116,10 +1116,10 @@ TEST_F(PlacerTest, TestUnsupportedDeviceNoAllowSoftPlacement) {
SessionOptions options;
Status s = Place(&g, &options);
EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
- EXPECT_TRUE(StringPiece(s.error_message()).contains("/device:fakecpu:0"));
- EXPECT_TRUE(
- StringPiece(s.error_message())
- .contains("no supported kernel for fakecpu devices is available"));
+ EXPECT_TRUE(str_util::StrContains(s.error_message(), "/device:fakecpu:0"));
+ EXPECT_TRUE(str_util::StrContains(
+ s.error_message(),
+ "no supported kernel for fakecpu devices is available"));
}
// Test that placement fails when a node requests an explicit device that is not
@@ -1137,9 +1137,9 @@ TEST_F(PlacerTest, TestNonExistentDevice) {
Status s = Place(&g, &options);
EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
LOG(WARNING) << s.error_message();
- EXPECT_TRUE(StringPiece(s.error_message())
- .contains("was explicitly assigned to /job:foo/replica:17 "
- "but available devices"));
+ EXPECT_TRUE(str_util::StrContains(
+ s.error_message(),
+ "was explicitly assigned to /job:foo/replica:17 but available devices"));
}
TEST_F(PlacerTest, TestUnsupportedDeviceAllowSoftPlacement) {
@@ -1205,8 +1205,8 @@ TEST_F(PlacerTest, TestUnsatisfiableConstraintWithReferenceConnections) {
Status s = Place(&g);
EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
- EXPECT_TRUE(StringPiece(s.error_message())
- .contains("Cannot colocate nodes 'var' and 'assign'"));
+ EXPECT_TRUE(str_util::StrContains(
+ s.error_message(), "Cannot colocate nodes 'var' and 'assign'"));
}
// Test that a generator node follows its consumers (where there are several
diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.h b/tensorflow/core/common_runtime/process_function_library_runtime.h
index d69e8bc2a0..c7b8259f78 100644
--- a/tensorflow/core/common_runtime/process_function_library_runtime.h
+++ b/tensorflow/core/common_runtime/process_function_library_runtime.h
@@ -155,7 +155,10 @@ class ProcessFunctionLibraryRuntime {
string target_device() { return target_device_; }
- FunctionLibraryRuntime::LocalHandle local_handle() { return local_handle_; }
+ FunctionLibraryRuntime::LocalHandle local_handle() {
+ mutex_lock l(mu_);
+ return local_handle_;
+ }
// Initializes the FunctionData object by potentially making an Initialize
// call to the DistributedFunctionLibraryRuntime.
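The accessor now takes the lock because `local_handle_` can be written after construction; without it a concurrent reader could observe a partial update. A minimal sketch of the guarded-getter pattern using standard-library types and hypothetical names:

```cpp
#include <mutex>

class FunctionData {
 public:
  // Copy the field under the lock so a reader never races a
  // concurrent set_handle().
  int handle() {
    std::lock_guard<std::mutex> l(mu_);
    return handle_;
  }
  void set_handle(int h) {
    std::lock_guard<std::mutex> l(mu_);
    handle_ = h;
  }

 private:
  std::mutex mu_;
  int handle_ = -1;
};
```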
diff --git a/tensorflow/core/common_runtime/process_function_library_runtime_test.cc b/tensorflow/core/common_runtime/process_function_library_runtime_test.cc
index 2da67b084a..4fbf2abc67 100644
--- a/tensorflow/core/common_runtime/process_function_library_runtime_test.cc
+++ b/tensorflow/core/common_runtime/process_function_library_runtime_test.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/public/version.h"
@@ -132,7 +133,7 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
});
done2.WaitForNotification();
EXPECT_TRUE(errors::IsNotFound(status));
- EXPECT_TRUE(StringPiece(status.error_message()).contains("not found."));
+ EXPECT_TRUE(str_util::StrContains(status.error_message(), "not found."));
return Status::OK();
}
diff --git a/tensorflow/core/common_runtime/session_test.cc b/tensorflow/core/common_runtime/session_test.cc
index a074154450..feaf29c7bb 100644
--- a/tensorflow/core/common_runtime/session_test.cc
+++ b/tensorflow/core/common_runtime/session_test.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/common_runtime/session_factory.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/public/session_options.h"
@@ -31,10 +32,9 @@ TEST(SessionTest, InvalidTargetReturnsNull) {
Session* session;
Status s = tensorflow::NewSession(options, &session);
EXPECT_EQ(s.code(), error::NOT_FOUND);
- EXPECT_TRUE(
- StringPiece(s.error_message())
- .contains(
- "No session factory registered for the given session options"));
+ EXPECT_TRUE(str_util::StrContains(
+ s.error_message(),
+ "No session factory registered for the given session options"));
}
// Register a fake session factory to test error handling paths in
@@ -44,7 +44,7 @@ class FakeSessionFactory : public SessionFactory {
FakeSessionFactory() {}
bool AcceptsOptions(const SessionOptions& options) override {
- return StringPiece(options.target).starts_with("fake");
+ return str_util::StartsWith(options.target, "fake");
}
Session* NewSession(const SessionOptions& options) override {
@@ -68,9 +68,9 @@ TEST(SessionTest, MultipleFactoriesForTarget) {
Status s = tensorflow::NewSession(options, &session);
EXPECT_EQ(s.code(), error::INTERNAL);
EXPECT_TRUE(
- StringPiece(s.error_message()).contains("Multiple session factories"));
- EXPECT_TRUE(StringPiece(s.error_message()).contains("FAKE_SESSION_1"));
- EXPECT_TRUE(StringPiece(s.error_message()).contains("FAKE_SESSION_2"));
+ str_util::StrContains(s.error_message(), "Multiple session factories"));
+ EXPECT_TRUE(str_util::StrContains(s.error_message(), "FAKE_SESSION_1"));
+ EXPECT_TRUE(str_util::StrContains(s.error_message(), "FAKE_SESSION_2"));
}
} // namespace
diff --git a/tensorflow/core/common_runtime/shape_refiner.cc b/tensorflow/core/common_runtime/shape_refiner.cc
index cef50be3b1..1b7e3138ee 100644
--- a/tensorflow/core/common_runtime/shape_refiner.cc
+++ b/tensorflow/core/common_runtime/shape_refiner.cc
@@ -351,6 +351,11 @@ Status ShapeRefiner::UpdateNode(const Node* node, bool relax, bool* refined) {
}
}
}
+ if (node_context->requested_input_tensor_as_partial_shape(dst_input)) {
+ // The input value may have changed. Since we have no way to know if
+ // that's indeed the case, err on the safe side.
+ *refined = true;
+ }
// Also propagate handle shape and dtype of edges which are carrying
// resource handles.
diff --git a/tensorflow/core/common_runtime/shape_refiner_test.cc b/tensorflow/core/common_runtime/shape_refiner_test.cc
index adf5a9afff..f48638afc0 100644
--- a/tensorflow/core/common_runtime/shape_refiner_test.cc
+++ b/tensorflow/core/common_runtime/shape_refiner_test.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/graph/testlib.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/public/version.h"
@@ -143,8 +144,8 @@ TEST_F(ShapeRefinerTest, BadShapes) {
// an error.
Status s = m.AddNode(mm.node());
ASSERT_FALSE(s.ok());
- ASSERT_TRUE(StringPiece(s.error_message())
- .contains("Dimensions must be equal, but are 1 and 2"));
+ ASSERT_TRUE(str_util::StrContains(
+ s.error_message(), "Dimensions must be equal, but are 1 and 2"));
}
TEST_F(ShapeRefinerTest, SetShape) {
@@ -1032,8 +1033,8 @@ TEST_F(ShapeRefinerTest, ConstantValueAsShape_PackInvalidInput) {
TF_ASSERT_OK(m.AddNode(input.node()));
}
TF_ASSERT_OK(m.AddNode(pack.node()));
- EXPECT_TRUE(
- StringPiece(m.AddNode(result).error_message()).contains("but is rank 2"));
+ EXPECT_TRUE(str_util::StrContains(m.AddNode(result).error_message(),
+ "but is rank 2"));
}
TEST_F(ShapeRefinerTest, ConstantValueAsShape_Concat) {
diff --git a/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc b/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc
index 049eec347c..bafd9bfc68 100644
--- a/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc
+++ b/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc
@@ -144,9 +144,9 @@ BaseRemoteRendezvous::~BaseRemoteRendezvous() {
// Returns true if "device_name" is a valid full name of local device
// of the "worker". This helper is purely based on the worker name
// and device name and does no lookups in the worker->device_mgr.
-static bool IsLocalDevice(const string& worker_name,
+static bool IsLocalDevice(const StringPiece worker_name,
const StringPiece device_name) {
- return device_name.starts_with(worker_name);
+ return str_util::StartsWith(device_name, worker_name);
}
Status BaseRemoteRendezvous::Initialize(WorkerSession* session) {
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_session.cc b/tensorflow/core/distributed_runtime/rpc/grpc_session.cc
index 120a33f17b..3e79a40683 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_session.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_session.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/protobuf/master.pb.h"
@@ -402,7 +403,7 @@ Status GrpcSession::Reset(const SessionOptions& options,
class GrpcSessionFactory : public SessionFactory {
public:
bool AcceptsOptions(const SessionOptions& options) override {
- return StringPiece(options.target).starts_with(kSchemePrefix);
+ return str_util::StartsWith(options.target, kSchemePrefix);
}
Session* NewSession(const SessionOptions& options) override {
diff --git a/tensorflow/core/framework/allocator.cc b/tensorflow/core/framework/allocator.cc
index a382b8be95..6182f95f28 100644
--- a/tensorflow/core/framework/allocator.cc
+++ b/tensorflow/core/framework/allocator.cc
@@ -61,6 +61,26 @@ static bool cpu_allocator_collect_stats = false;
// If true, cpu allocator collects full stats.
static bool cpu_allocator_collect_full_stats = false;
+// Individual allocations larger than this fraction of available RAM will trigger a warning.
+static const double kLargeAllocationWarningThreshold = 0.1;
+
+// If cpu_allocator_collect_stats is true, warn when the total allocated
+// memory exceeds this fraction of available RAM.
+static const double kTotalAllocationWarningThreshold = 0.5;
+
+// Cache the result of the first call to port::AvailableRam, as it can be expensive.
+static int64_t LargeAllocationWarningBytes() {
+ static int64_t value = static_cast<int64>(port::AvailableRam() *
+ kLargeAllocationWarningThreshold);
+ return value;
+}
+
+static int64_t TotalAllocationWarningBytes() {
+ static int64_t value = static_cast<int64>(port::AvailableRam() *
+ kTotalAllocationWarningThreshold);
+ return value;
+}
+
void EnableCPUAllocatorStats(bool enable) {
cpu_allocator_collect_stats = enable;
}
@@ -70,7 +90,8 @@ void EnableCPUAllocatorFullStats(bool enable) {
class CPUAllocator : public VisitableAllocator {
public:
- CPUAllocator() : allocation_begun_(false) {}
+ CPUAllocator()
+ : total_allocation_warning_triggered_(false), allocation_begun_(false) {}
~CPUAllocator() override {}
@@ -81,6 +102,12 @@ class CPUAllocator : public VisitableAllocator {
allocation_begun_ = true;
}
+ if (num_bytes > LargeAllocationWarningBytes()) {
+ LOG(WARNING) << "Allocation of " << num_bytes << " exceeds "
+ << 100 * kLargeAllocationWarningThreshold
+ << "% of system memory.";
+ }
+
void* p = port::AlignedMalloc(num_bytes, alignment);
if (cpu_allocator_collect_stats) {
const std::size_t alloc_size = port::MallocExtension_GetAllocatedSize(p);
@@ -91,6 +118,14 @@ class CPUAllocator : public VisitableAllocator {
std::max<int64>(stats_.max_bytes_in_use, stats_.bytes_in_use);
stats_.max_alloc_size =
std::max<int64>(stats_.max_alloc_size, alloc_size);
+
+ if (stats_.bytes_in_use > TotalAllocationWarningBytes() &&
+ !total_allocation_warning_triggered_) {
+ LOG(WARNING) << "Total allocated memory " << stats_.bytes_in_use
+ << "exceeds " << 100 * kTotalAllocationWarningThreshold
+ << "% of system memory";
+ total_allocation_warning_triggered_ = true;
+ }
}
// visit each Visitor in alloc_visitors_
@@ -162,6 +197,7 @@ class CPUAllocator : public VisitableAllocator {
private:
mutex mu_;
AllocatorStats stats_ GUARDED_BY(mu_);
+ bool total_allocation_warning_triggered_ GUARDED_BY(mu_);
// visitor_mutex_ protects write access to alloc_visitors_ and free_visitors_.
// While write access is mutually exclusive, reads may happen concurrently.
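`LargeAllocationWarningBytes()` and `TotalAllocationWarningBytes()` above lean on C++ function-local statics: the initializer runs once, on first use, and the value is cached for every later call (initialization is thread-safe since C++11). A standalone sketch of the idiom, with a stand-in for `port::AvailableRam()`:

```cpp
#include <cstdint>

// Stand-in for an expensive system query such as port::AvailableRam().
static int64_t QueryAvailableRam() {
  return int64_t{8} * 1024 * 1024 * 1024;  // pretend 8 GiB
}

static const double kWarningThreshold = 0.1;

static int64_t WarningBytes() {
  // The initializer runs exactly once, on the first call; subsequent
  // calls return the cached value without re-querying.
  static const int64_t value =
      static_cast<int64_t>(QueryAvailableRam() * kWarningThreshold);
  return value;
}
```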
diff --git a/tensorflow/core/framework/attr_value_util.cc b/tensorflow/core/framework/attr_value_util.cc
index ebb56d525e..87c1ddd15d 100644
--- a/tensorflow/core/framework/attr_value_util.cc
+++ b/tensorflow/core/framework/attr_value_util.cc
@@ -186,7 +186,7 @@ Status AttrValueHasType(const AttrValue& attr_value, StringPiece type) {
// check if has_list is false and some other field in attr_value is
// set to flag the error. This test can be made more strict once
// support for GraphDef versions <= 4 is dropped.
- if (StringPiece(type).starts_with("list(") && !attr_value.has_list()) {
+ if (str_util::StartsWith(type, "list(") && !attr_value.has_list()) {
if (num_set) {
return errors::InvalidArgument(
"AttrValue missing value with expected type '", type, "'");
@@ -197,7 +197,7 @@ Status AttrValueHasType(const AttrValue& attr_value, StringPiece type) {
}
// Okay to have an empty list, but not to be missing a non-list value.
- if (num_set == 0 && !StringPiece(type).starts_with("list(")) {
+ if (num_set == 0 && !str_util::StartsWith(type, "list(")) {
return errors::InvalidArgument(
"AttrValue missing value with expected type '", type, "'");
}
@@ -241,29 +241,29 @@ Status AttrValueHasType(const AttrValue& attr_value, StringPiece type) {
bool ParseAttrValue(StringPiece type, StringPiece text, AttrValue* out) {
// Parse type.
string field_name;
- bool is_list = type.Consume("list(");
- if (type.Consume("string")) {
+ bool is_list = str_util::ConsumePrefix(&type, "list(");
+ if (str_util::ConsumePrefix(&type, "string")) {
field_name = "s";
- } else if (type.Consume("int")) {
+ } else if (str_util::ConsumePrefix(&type, "int")) {
field_name = "i";
- } else if (type.Consume("float")) {
+ } else if (str_util::ConsumePrefix(&type, "float")) {
field_name = "f";
- } else if (type.Consume("bool")) {
+ } else if (str_util::ConsumePrefix(&type, "bool")) {
field_name = "b";
- } else if (type.Consume("type")) {
+ } else if (str_util::ConsumePrefix(&type, "type")) {
field_name = "type";
- } else if (type.Consume("shape")) {
+ } else if (str_util::ConsumePrefix(&type, "shape")) {
field_name = "shape";
- } else if (type.Consume("tensor")) {
+ } else if (str_util::ConsumePrefix(&type, "tensor")) {
field_name = "tensor";
- } else if (type.Consume("func")) {
+ } else if (str_util::ConsumePrefix(&type, "func")) {
field_name = "func";
- } else if (type.Consume("placeholder")) {
+ } else if (str_util::ConsumePrefix(&type, "placeholder")) {
field_name = "placeholder";
} else {
return false;
}
- if (is_list && !type.Consume(")")) {
+ if (is_list && !str_util::ConsumePrefix(&type, ")")) {
return false;
}
diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc
index 2fb17c2b02..72eeda7a43 100644
--- a/tensorflow/core/framework/common_shape_fns.cc
+++ b/tensorflow/core/framework/common_shape_fns.cc
@@ -504,8 +504,8 @@ Status Conv3DShape(shape_inference::InferenceContext* c) {
input_shape =
c->MakeShape({{dim('N'), dim('0'), dim('1'), dim('2'), dim('C')}});
stride_planes = strides[2];
- stride_cols = strides[3];
- stride_rows = strides[4];
+ stride_rows = strides[3];
+ stride_cols = strides[4];
} else {
stride_planes = strides[1];
stride_rows = strides[2];
diff --git a/tensorflow/core/framework/common_shape_fns_test.cc b/tensorflow/core/framework/common_shape_fns_test.cc
index 5f3e5ad457..13d429b895 100644
--- a/tensorflow/core/framework/common_shape_fns_test.cc
+++ b/tensorflow/core/framework/common_shape_fns_test.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/core/framework/shape_inference_testutil.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
@@ -140,9 +141,8 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
{}, {}, {});
auto s = MatMulShape(&c);
EXPECT_FALSE(s.ok());
- EXPECT_TRUE(
- StringPiece(s.ToString())
- .contains("Invalid argument: Shape must be rank 2 but is rank 1"));
+ EXPECT_TRUE(str_util::StrContains(
+ s.ToString(), "Invalid argument: Shape must be rank 2 but is rank 1"));
}
{
@@ -161,10 +161,9 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
{S({2, 5}), S({3, 4})}, {}, {}, {});
auto s = MatMulShape(&c);
EXPECT_FALSE(s.ok());
- EXPECT_TRUE(
- StringPiece(s.ToString())
- .contains(
- "Invalid argument: Dimensions must be equal, but are 5 and 3"));
+ EXPECT_TRUE(str_util::StrContains(
+ s.ToString(),
+ "Invalid argument: Dimensions must be equal, but are 5 and 3"));
}
{
@@ -173,9 +172,8 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
{S({2, 5, 3}), S({3, 5, 4})}, {}, {}, {});
auto s = MatMulShape(&c);
EXPECT_FALSE(s.ok());
- EXPECT_TRUE(
- StringPiece(s.ToString())
- .contains("Invalid argument: Shape must be rank 2 but is rank 3"));
+ EXPECT_TRUE(str_util::StrContains(
+ s.ToString(), "Invalid argument: Shape must be rank 2 but is rank 3"));
}
{
diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h
index beaf0adbc5..9e7ffe6c0b 100644
--- a/tensorflow/core/framework/dataset.h
+++ b/tensorflow/core/framework/dataset.h
@@ -201,7 +201,7 @@ class GraphDefBuilderWrapper {
// Also looks up the `op_def->name` in the global
// `WhitelistedStatefulOpRegistry`.
bool IsOpWhitelisted(const OpDef* op_def) const {
- return (StringPiece(op_def->name()).ends_with("Dataset") &&
+ return (str_util::EndsWith(op_def->name(), "Dataset") &&
op_def->output_arg_size() == 1 &&
op_def->output_arg(0).type() == DT_VARIANT) ||
dataset::WhitelistedStatefulOpRegistry::Global()->Contains(
@@ -474,11 +474,11 @@ class GraphDatasetBase : public DatasetBase {
}
// Key for storing the Dataset graph in the serialized format.
- static const char kDatasetGraphKey[];
+ TF_EXPORT static const char kDatasetGraphKey[];
// Key for storing the output node of the Dataset graph in the serialized
// format.
- static const char kDatasetGraphOutputNodeKey[];
+ TF_EXPORT static const char kDatasetGraphOutputNodeKey[];
private:
Status Serialize(OpKernelContext* ctx, string* serialized_graph_def,
diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc
index 3e7b89d4eb..bdc1af9fda 100644
--- a/tensorflow/core/framework/function.cc
+++ b/tensorflow/core/framework/function.cc
@@ -30,6 +30,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/util/equal_graph_def.h"
namespace tensorflow {
@@ -278,7 +279,7 @@ class FunctionInstantiationHelper {
auto it = index_.lower_bound(node_name);
while (it != index_.end() && it->first <= node_colon_bound) {
if (it->first == node_name ||
- tensorflow::StringPiece(it->first).starts_with(node_colon)) {
+ tensorflow::str_util::StartsWith(it->first, node_colon)) {
nid = it->second.nid;
break;
}
@@ -502,7 +503,7 @@ string Print(const NodeDef& n) {
std::vector<StringPiece> dat;
std::vector<string> dep;
for (StringPiece s : n.input()) {
- if (s.Consume("^")) {
+ if (str_util::ConsumePrefix(&s, "^")) {
dep.push_back(s.ToString());
} else {
dat.push_back(s);
diff --git a/tensorflow/core/framework/function_test.cc b/tensorflow/core/framework/function_test.cc
index 23685e9c53..44e1383719 100644
--- a/tensorflow/core/framework/function_test.cc
+++ b/tensorflow/core/framework/function_test.cc
@@ -496,7 +496,7 @@ MySelect(x:float) -> (z:float) {
}
static void HasError(const Status& s, const string& substr) {
- EXPECT_TRUE(StringPiece(s.ToString()).contains(substr))
+ EXPECT_TRUE(str_util::StrContains(s.ToString(), substr))
<< ">>" << s << "<<, expected substring >>" << substr << "<<";
}
diff --git a/tensorflow/core/framework/graph_def_util.cc b/tensorflow/core/framework/graph_def_util.cc
index 896cb3cd7f..f7539d37be 100644
--- a/tensorflow/core/framework/graph_def_util.cc
+++ b/tensorflow/core/framework/graph_def_util.cc
@@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/core/framework/versions.pb_text.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
namespace tensorflow {
@@ -94,7 +95,7 @@ static Status RemoveNewDefaultAttrsFromNodeDef(
std::vector<string> to_remove;
for (const auto& attr : node_def->attr()) {
// If the attr is not in consumer_op_def and doesn't start with '_'...
- if (!StringPiece(attr.first).starts_with("_") &&
+ if (!str_util::StartsWith(attr.first, "_") &&
FindAttr(attr.first, *consumer_op_def) == nullptr) {
const OpDef::AttrDef* producer_attr_def =
FindAttr(attr.first, *producer_op_def);
diff --git a/tensorflow/core/framework/node_def_builder_test.cc b/tensorflow/core/framework/node_def_builder_test.cc
index e836873f66..cc583df348 100644
--- a/tensorflow/core/framework/node_def_builder_test.cc
+++ b/tensorflow/core/framework/node_def_builder_test.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/framework/op_def_builder.h"
#include "tensorflow/core/framework/op_def_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
@@ -82,7 +83,7 @@ class NodeDefBuilderTest : public ::testing::Test {
EXPECT_FALSE(status.ok()) << SummarizeNodeDef(node_def);
if (status.ok()) return;
for (const string& message : messages) {
- EXPECT_TRUE(StringPiece(status.error_message()).contains(message))
+ EXPECT_TRUE(str_util::StrContains(status.error_message(), message))
<< status << ", " << message;
}
}
@@ -103,7 +104,7 @@ class NodeDefBuilderTest : public ::testing::Test {
}
EXPECT_FALSE(status.ok()) << SummarizeNodeDef(node_def);
if (status.ok()) return;
- EXPECT_TRUE(StringPiece(status.error_message()).contains(message))
+ EXPECT_TRUE(str_util::StrContains(status.error_message(), message))
<< "Actual error: " << status.error_message()
<< "\nDoes not contain: " << message;
}
diff --git a/tensorflow/core/framework/node_def_util.cc b/tensorflow/core/framework/node_def_util.cc
index 95fb386314..bad92ca9b3 100644
--- a/tensorflow/core/framework/node_def_util.cc
+++ b/tensorflow/core/framework/node_def_util.cc
@@ -31,6 +31,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/strings/scanner.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/protobuf.h"
@@ -131,7 +132,7 @@ Status AttrSlice::Find(StringPiece attr_name,
// Skip AttachDef for internal attrs since it is a little bit
// expensive and it is common for them to correctly not be included
// in a NodeDef.
- if (!attr_name.starts_with("_") && ndef_ != nullptr) {
+ if (!str_util::StartsWith(attr_name, "_") && ndef_ != nullptr) {
s = AttachDef(s, *ndef_);
}
return s;
@@ -399,7 +400,7 @@ Status ValidateNodeDef(const NodeDef& node_def, const OpDef& op_def) {
size_t num_inputs = 0;
// TODO(josh11b): Unify the input field validation.
for (const string& input : node_def.input()) {
- if (StringPiece(input).starts_with("^")) {
+ if (str_util::StartsWith(input, "^")) {
seen_control = true;
if (input.find(':') != string::npos) {
return errors::InvalidArgument(
@@ -425,7 +426,7 @@ Status ValidateNodeDef(const NodeDef& node_def, const OpDef& op_def) {
}
for (const auto& attr : node_def.attr()) {
// Allow internal optional attributes with names starting with "_".
- if (StringPiece(attr.first).starts_with("_")) {
+ if (str_util::StartsWith(attr.first, "_")) {
continue;
}
auto iter = op_attrs.find(attr.first);
diff --git a/tensorflow/core/framework/node_def_util_test.cc b/tensorflow/core/framework/node_def_util_test.cc
index ae3a93eafe..2a49425dba 100644
--- a/tensorflow/core/framework/node_def_util_test.cc
+++ b/tensorflow/core/framework/node_def_util_test.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/framework/op_def_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
@@ -65,7 +66,7 @@ void ExpectFailure(const NodeDef& bad, const OpDef& op_def,
<< "; OpDef: " << SummarizeOpDef(op_def);
LOG(INFO) << "Message: " << status.error_message();
- EXPECT_TRUE(StringPiece(status.ToString()).contains(message))
+ EXPECT_TRUE(str_util::StrContains(status.ToString(), message))
<< "NodeDef: " << SummarizeNodeDef(bad)
<< "; OpDef: " << SummarizeOpDef(op_def) << "\nActual error: " << status
<< "\nDoes not contain: " << message;
@@ -265,7 +266,7 @@ void ExpectInvalidSyntax(const NodeDef& bad, const string& message) {
EXPECT_TRUE(errors::IsInvalidArgument(status))
<< status << "; NodeDef: " << SummarizeNodeDef(bad);
- EXPECT_TRUE(StringPiece(status.ToString()).contains(message))
+ EXPECT_TRUE(str_util::StrContains(status.ToString(), message))
<< "NodeDef: " << SummarizeNodeDef(bad) << ", " << status << ", "
<< message;
}
diff --git a/tensorflow/core/framework/op.cc b/tensorflow/core/framework/op.cc
index fc5467b3c8..5f68c59fe9 100644
--- a/tensorflow/core/framework/op.cc
+++ b/tensorflow/core/framework/op.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/host_info.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
@@ -142,7 +143,7 @@ void OpRegistry::Export(bool include_internal, OpList* ops) const {
out->Reserve(sorted.size());
for (const auto& item : sorted) {
- if (include_internal || !StringPiece(item.first).starts_with("_")) {
+ if (include_internal || !str_util::StartsWith(item.first, "_")) {
*out->Add() = item.second->op_def;
}
}
diff --git a/tensorflow/core/framework/op_compatibility_test.cc b/tensorflow/core/framework/op_compatibility_test.cc
index b57bdcb841..c782480f1f 100644
--- a/tensorflow/core/framework/op_compatibility_test.cc
+++ b/tensorflow/core/framework/op_compatibility_test.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
@@ -96,7 +97,7 @@ class OpCompatibilityTest : public OpsTestBase {
ADD_FAILURE() << SummarizeOpDef(old_op_def) << " vs. "
<< SummarizeOpDef(new_op_def);
} else {
- EXPECT_TRUE(StringPiece(status.error_message()).contains(error))
+ EXPECT_TRUE(str_util::StrContains(status.error_message(), error))
<< status << " does not contain " << error;
}
}
@@ -118,7 +119,7 @@ class OpCompatibilityTest : public OpsTestBase {
ADD_FAILURE() << SummarizeNodeDef(*node_def());
} else {
EXPECT_TRUE(
- StringPiece(status.error_message()).contains(validation_error))
+ str_util::StrContains(status.error_message(), validation_error))
<< status << " does not contain " << validation_error;
}
@@ -179,7 +180,7 @@ class OpCompatibilityTest : public OpsTestBase {
<< SummarizeOpDef(*new_op_def);
} else {
EXPECT_TRUE(
- StringPiece(status.error_message()).contains(compatibility_error))
+ str_util::StrContains(status.error_message(), compatibility_error))
<< status << " does not contain " << compatibility_error;
}
}
diff --git a/tensorflow/core/framework/op_def.proto b/tensorflow/core/framework/op_def.proto
index ba545a1994..ca0e5e7133 100644
--- a/tensorflow/core/framework/op_def.proto
+++ b/tensorflow/core/framework/op_def.proto
@@ -126,6 +126,12 @@ message OpDef {
// -------------------------------------------------------------------------
// Optimization constraints.
+ // Ops are marked as stateful if their behavior depends on some state beyond
+ // their input tensors (e.g. variable reading op) or if they have
+ // a side-effect (e.g. printing or asserting ops). Equivalently, stateless ops
+ // must always produce the same output for the same input and have
+ // no side-effects.
+ //
// By default Ops may be moved between devices. Stateful ops should
// either not be moved, or should only be moved if that state can also
// be moved (e.g. via some sort of save / restore).
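On the registration side, an op declares this property when it is defined; in C++ the op builder exposes `SetIsStateful()`. A sketch with a hypothetical op name:

```cpp
#include "tensorflow/core/framework/op.h"

// Hypothetical op that reads external state: marking it stateful keeps the
// optimizer from constant-folding it or merging duplicate instances.
REGISTER_OP("ReadCounter")
    .Output("value: int64")
    .SetIsStateful();
```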
diff --git a/tensorflow/core/framework/op_def_builder.cc b/tensorflow/core/framework/op_def_builder.cc
index 962bc11ccb..403bd0b5e2 100644
--- a/tensorflow/core/framework/op_def_builder.cc
+++ b/tensorflow/core/framework/op_def_builder.cc
@@ -112,9 +112,11 @@ bool ConsumeAttrNumber(StringPiece* sp, int64* out) {
bool ConsumeCompoundAttrType(StringPiece* sp, StringPiece* out) {
auto capture_begin = sp->begin();
- if (sp->Consume("numbertype") || sp->Consume("numerictype") ||
- sp->Consume("quantizedtype") || sp->Consume("realnumbertype") ||
- sp->Consume("realnumberictype")) {
+ if (str_util::ConsumePrefix(sp, "numbertype") ||
+ str_util::ConsumePrefix(sp, "numerictype") ||
+ str_util::ConsumePrefix(sp, "quantizedtype") ||
+ str_util::ConsumePrefix(sp, "realnumbertype") ||
+ str_util::ConsumePrefix(sp, "realnumberictype")) {
*out = StringPiece(capture_begin, sp->begin() - capture_begin);
return true;
}
@@ -155,32 +157,32 @@ void FinalizeAttr(StringPiece spec, OpDef* op_def,
bool is_list = ConsumeListPrefix(&spec);
string type;
StringPiece type_string; // Used if type == "type"
- if (spec.Consume("string")) {
+ if (str_util::ConsumePrefix(&spec, "string")) {
type = "string";
- } else if (spec.Consume("int")) {
+ } else if (str_util::ConsumePrefix(&spec, "int")) {
type = "int";
- } else if (spec.Consume("float")) {
+ } else if (str_util::ConsumePrefix(&spec, "float")) {
type = "float";
- } else if (spec.Consume("bool")) {
+ } else if (str_util::ConsumePrefix(&spec, "bool")) {
type = "bool";
- } else if (spec.Consume("type")) {
+ } else if (str_util::ConsumePrefix(&spec, "type")) {
type = "type";
- } else if (spec.Consume("shape")) {
+ } else if (str_util::ConsumePrefix(&spec, "shape")) {
type = "shape";
- } else if (spec.Consume("tensor")) {
+ } else if (str_util::ConsumePrefix(&spec, "tensor")) {
type = "tensor";
- } else if (spec.Consume("func")) {
+ } else if (str_util::ConsumePrefix(&spec, "func")) {
type = "func";
} else if (ConsumeCompoundAttrType(&spec, &type_string)) {
type = "type";
AttrValue* allowed = attr->mutable_allowed_values();
VERIFY(ProcessCompoundType(type_string, allowed),
"Expected to see a compound type, saw: ", type_string);
- } else if (spec.Consume("{")) {
+ } else if (str_util::ConsumePrefix(&spec, "{")) {
// e.g. "{ int32, float, bool }" or "{ \"foo\", \"bar\" }"
AttrValue* allowed = attr->mutable_allowed_values();
str_util::RemoveLeadingWhitespace(&spec);
- if (spec.starts_with("\"") || spec.starts_with("'")) {
+ if (str_util::StartsWith(spec, "\"") || str_util::StartsWith(spec, "'")) {
type = "string"; // "{ \"foo\", \"bar\" }" or "{ 'foo', 'bar' }"
while (true) {
StringPiece escaped_string;
@@ -193,11 +195,12 @@ void FinalizeAttr(StringPiece spec, OpDef* op_def,
"Trouble unescaping \"", escaped_string,
"\", got error: ", error);
allowed->mutable_list()->add_s(unescaped);
- if (spec.Consume(",")) {
+ if (str_util::ConsumePrefix(&spec, ",")) {
str_util::RemoveLeadingWhitespace(&spec);
- if (spec.Consume("}")) break; // Allow ending with ", }".
+ if (str_util::ConsumePrefix(&spec, "}"))
+ break; // Allow ending with ", }".
} else {
- VERIFY(spec.Consume("}"),
+ VERIFY(str_util::ConsumePrefix(&spec, "}"),
"Expected , or } after strings in list, not: '", spec, "'");
break;
}
@@ -215,11 +218,12 @@ void FinalizeAttr(StringPiece spec, OpDef* op_def,
"Unrecognized type string '", type_string, "'");
allowed->mutable_list()->add_type(dt);
}
- if (spec.Consume(",")) {
+ if (str_util::ConsumePrefix(&spec, ",")) {
str_util::RemoveLeadingWhitespace(&spec);
- if (spec.Consume("}")) break; // Allow ending with ", }".
+ if (str_util::ConsumePrefix(&spec, "}"))
+ break; // Allow ending with ", }".
} else {
- VERIFY(spec.Consume("}"),
+ VERIFY(str_util::ConsumePrefix(&spec, "}"),
"Expected , or } after types in list, not: '", spec, "'");
break;
}
@@ -232,7 +236,8 @@ void FinalizeAttr(StringPiece spec, OpDef* op_def,
// Write the type into *attr.
if (is_list) {
- VERIFY(spec.Consume(")"), "Expected ) to close 'list(', not: '", spec, "'");
+ VERIFY(str_util::ConsumePrefix(&spec, ")"),
+ "Expected ) to close 'list(', not: '", spec, "'");
str_util::RemoveLeadingWhitespace(&spec);
attr->set_type(strings::StrCat("list(", type, ")"));
} else {
@@ -240,7 +245,7 @@ void FinalizeAttr(StringPiece spec, OpDef* op_def,
}
// Read optional minimum constraint at the end.
- if ((is_list || type == "int") && spec.Consume(">=")) {
+ if ((is_list || type == "int") && str_util::ConsumePrefix(&spec, ">=")) {
int64 min_limit = -999;
VERIFY(ConsumeAttrNumber(&spec, &min_limit),
"Could not parse integer lower limit after '>=', found '", spec,
@@ -250,7 +255,7 @@ void FinalizeAttr(StringPiece spec, OpDef* op_def,
}
// Parse default value, if present.
- if (spec.Consume("=")) {
+ if (str_util::ConsumePrefix(&spec, "=")) {
str_util::RemoveLeadingWhitespace(&spec);
VERIFY(ParseAttrValue(attr->type(), spec, attr->mutable_default_value()),
"Could not parse default value '", spec, "'");
diff --git a/tensorflow/core/framework/op_def_util.cc b/tensorflow/core/framework/op_def_util.cc
index c80802aad3..9be0dc69d2 100644
--- a/tensorflow/core/framework/op_def_util.cc
+++ b/tensorflow/core/framework/op_def_util.cc
@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/proto_serialization.h"
#include "tensorflow/core/lib/strings/scanner.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/protobuf.h"
@@ -239,7 +240,7 @@ static Status ValidateArg(const OpDef::ArgDef& arg, const OpDef& op_def,
Status ValidateOpDef(const OpDef& op_def) {
using ::tensorflow::strings::Scanner;
- if (!StringPiece(op_def.name()).starts_with("_")) {
+ if (!str_util::StartsWith(op_def.name(), "_")) {
VALIDATE(Scanner(op_def.name())
.One(Scanner::UPPERLETTER)
.Any(Scanner::LETTER_DIGIT)
@@ -259,11 +260,11 @@ Status ValidateOpDef(const OpDef& op_def) {
// Validate type
StringPiece type(attr.type());
- bool is_list = type.Consume("list(");
+ bool is_list = str_util::ConsumePrefix(&type, "list(");
bool found = false;
for (StringPiece valid : {"string", "int", "float", "bool", "type", "shape",
"tensor", "func"}) {
- if (type.Consume(valid)) {
+ if (str_util::ConsumePrefix(&type, valid)) {
found = true;
break;
}
@@ -271,8 +272,9 @@ Status ValidateOpDef(const OpDef& op_def) {
VALIDATE(found, "Unrecognized type '", type, "' in attr '", attr.name(),
"'");
if (is_list) {
- VALIDATE(type.Consume(")"), "'list(' is missing ')' in attr ",
- attr.name(), "'s type ", attr.type());
+ VALIDATE(str_util::ConsumePrefix(&type, ")"),
+ "'list(' is missing ')' in attr ", attr.name(), "'s type ",
+ attr.type());
}
VALIDATE(type.empty(), "Extra '", type, "' at the end of attr ",
attr.name(), "'s type ", attr.type());
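ValidateOpDef checks the same attr type grammar that FinalizeAttr parses above. A minimal sketch of specs that exercise it, assuming the standard REGISTER_OP macro; the op and attr names are hypothetical:

    #include "tensorflow/core/framework/op.h"

    REGISTER_OP("HypotheticalConcat")
        .Attr("N: int >= 2")                   // "int" plus a ">=" lower bound
        .Attr("T: {float, int32} = DT_FLOAT")  // "{...}" list with a default
        .Input("values: N * T")
        .Output("output: T");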
diff --git a/tensorflow/core/framework/op_def_util_test.cc b/tensorflow/core/framework/op_def_util_test.cc
index 2b9812d4fc..4514d92e38 100644
--- a/tensorflow/core/framework/op_def_util_test.cc
+++ b/tensorflow/core/framework/op_def_util_test.cc
@@ -57,7 +57,7 @@ class ValidateOpDefTest : public ::testing::Test {
EXPECT_FALSE(status.ok()) << "Did not see error with: " << message;
if (!status.ok()) {
LOG(INFO) << "message: " << status;
- EXPECT_TRUE(StringPiece(status.ToString()).contains(message))
+ EXPECT_TRUE(str_util::StrContains(status.ToString(), message))
<< "Actual: " << status << "\nExpected to contain: " << message;
}
}
diff --git a/tensorflow/core/framework/op_gen_lib.cc b/tensorflow/core/framework/op_gen_lib.cc
index 5f2eb9d99a..7f23272871 100644
--- a/tensorflow/core/framework/op_gen_lib.cc
+++ b/tensorflow/core/framework/op_gen_lib.cc
@@ -50,10 +50,10 @@ string WordWrap(StringPiece prefix, StringPiece str, int width) {
StringPiece to_append = str.substr(0, space);
str.remove_prefix(space + 1);
// Remove spaces at break.
- while (to_append.ends_with(" ")) {
+ while (str_util::EndsWith(to_append, " ")) {
to_append.remove_suffix(1);
}
- while (str.Consume(" ")) {
+ while (str_util::ConsumePrefix(&str, " ")) {
}
// Go on to the next line.
@@ -65,8 +65,9 @@ string WordWrap(StringPiece prefix, StringPiece str, int width) {
}
bool ConsumeEquals(StringPiece* description) {
- if (description->Consume("=")) {
- while (description->Consume(" ")) { // Also remove spaces after "=".
+ if (str_util::ConsumePrefix(description, "=")) {
+ while (str_util::ConsumePrefix(description,
+ " ")) { // Also remove spaces after "=".
}
return true;
}
@@ -98,7 +99,7 @@ static bool StartsWithFieldName(StringPiece line,
const std::vector<string>& multi_line_fields) {
StringPiece up_to_colon;
if (!SplitAt(':', &line, &up_to_colon)) return false;
- while (up_to_colon.Consume(" "))
+ while (str_util::ConsumePrefix(&up_to_colon, " "))
; // Remove leading spaces.
for (const auto& field : multi_line_fields) {
if (up_to_colon == field) {
@@ -119,9 +120,9 @@ static bool ConvertLine(StringPiece line,
StringPiece up_to_colon;
StringPiece after_colon = line;
SplitAt(':', &after_colon, &up_to_colon);
- while (after_colon.Consume(" "))
+ while (str_util::ConsumePrefix(&after_colon, " "))
; // Remove leading spaces.
- if (!after_colon.Consume("\"")) {
+ if (!str_util::ConsumePrefix(&after_colon, "\"")) {
// We only convert string fields, so don't convert this line.
return false;
}
@@ -181,9 +182,9 @@ string PBTxtToMultiline(StringPiece pbtxt,
static bool FindMultiline(StringPiece line, size_t colon, string* end) {
if (colon == StringPiece::npos) return false;
line.remove_prefix(colon + 1);
- while (line.Consume(" ")) {
+ while (str_util::ConsumePrefix(&line, " ")) {
}
- if (line.Consume("<<")) {
+ if (str_util::ConsumePrefix(&line, "<<")) {
*end = line.ToString();
return true;
}
@@ -228,7 +229,7 @@ string PBTxtFromMultiline(StringPiece multiline_pbtxt) {
string suffix;
while (!multiline_pbtxt.empty()) {
SplitAt('\n', &multiline_pbtxt, &line);
- if (line.Consume(end)) break;
+ if (str_util::ConsumePrefix(&line, end)) break;
if (first) {
first = false;
} else {
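These helpers implement the multi-line pbtxt convention for op documentation: a string field can be rendered as field: <<END ... END and parsed back. A round-trip sketch, assuming the declarations in op_gen_lib.h; the local names are illustrative:

    #include <string>
    #include "tensorflow/core/framework/op_gen_lib.h"

    void MultilineDemo() {
      const std::string pbtxt = "description: \"line one\\nline two\"\n";
      // Render "description" fields in the <<END ... END multi-line form.
      std::string multiline = tensorflow::PBTxtToMultiline(pbtxt, {"description"});
      // Convert back to ordinary escaped pbtxt.
      std::string round_trip = tensorflow::PBTxtFromMultiline(multiline);
    }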
diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc
index 9ec1c213c3..cfde1e8ea3 100644
--- a/tensorflow/core/framework/op_kernel.cc
+++ b/tensorflow/core/framework/op_kernel.cc
@@ -365,7 +365,7 @@ Status OpKernelContext::input_ref_mutex(StringPiece name, mutex** out_mutex) {
const Tensor& OpKernelContext::input(int index) {
DCHECK_GE(index, 0);
- DCHECK_LT(index, num_inputs());
+ DCHECK_LT(index, num_inputs()) << " name: " << op_kernel().name();
DCHECK(!input_is_ref(index));
const Tensor& tensor = *((*params_->inputs)[index].tensor);
record_tensor_reference(tensor);
@@ -420,8 +420,8 @@ bool OpKernelContext::forward_input_to_output_with_shape(
? AllocatorAttributes()
: output_alloc_attr(output_index);
std::unique_ptr<Tensor> new_tensor = forward_input(
- input_index, expected_output_dtype(output_index), output_shape,
- output_memory_type(output_index), output_attr);
+ input_index, output_index, expected_output_dtype(output_index),
+ output_shape, output_memory_type(output_index), output_attr);
if (new_tensor != nullptr) {
// Transfer ownership to the output slot in OpKernelContext.
outputs_[output_index] = TensorValue(new_tensor.release());
@@ -461,35 +461,66 @@ Status OpKernelContext::forward_input_to_output_with_shape(
}
std::unique_ptr<Tensor> OpKernelContext::forward_input(
- int input_index, DataType output_dtype, const TensorShape& output_shape,
- MemoryType output_memory_type, const AllocatorAttributes& output_attr) {
+ int input_index, int output_index, DataType output_dtype,
+ const TensorShape& output_shape, MemoryType output_memory_type,
+ const AllocatorAttributes& output_attr) {
DCHECK_GE(input_index, 0);
DCHECK_LT(input_index, num_inputs());
const TensorValue& input = (*params_->inputs)[input_index];
- // Check that input tensor exists, is not a ref, and has no other consumers.
- if (input.tensor == nullptr || input.is_ref() || !input->RefCountIsOne()) {
+ // Check whether, at graph construction time, this output was marked
+ // either for no forwarding or with a reservation for this input.
+ // If it is reserved for this input, we skip the refcount and
+ // AllocatorAttributes checks.
+ // TODO(tucker): Maybe we should skip all of the checks?
+ bool never_forward =
+ (params_->forward_from_array != nullptr && output_index >= 0 &&
+ params_->forward_from_array[output_index] == Params::kNeverForward);
+ if (never_forward) return nullptr;
+ bool forward_expected =
+ (params_->forward_from_array != nullptr && output_index >= 0 &&
+ params_->forward_from_array[output_index] == input_index);
+ if (!forward_expected && params_->forward_from_array != nullptr) {
+ // Check for possibly conflicting forward.
+ for (int i = 0; i < num_outputs(); ++i) {
+ if (params_->forward_from_array[i] == input_index) {
+ // This input is reserved for output i.
+ return nullptr;
+ }
+ }
+ }
+ // Check that input tensor exists and is not a ref.
+ if (input.tensor == nullptr || input.is_ref()) {
+ CHECK(!forward_expected);
return nullptr;
}
// Check that input type matches.
if (input_dtype(input_index) != output_dtype) {
+ CHECK(!forward_expected);
return nullptr;
}
// Check that the input and output sizes are compatible.
if (input.tensor->shape().num_elements() != output_shape.num_elements()) {
+ CHECK(!forward_expected);
return nullptr;
}
// Check that input and output memory types match, i.e.
// that they either both live in host or both live in device memory.
if (input_memory_type(input_index) != output_memory_type) {
+ CHECK(!forward_expected);
return nullptr;
}
- // Check that output allocator attributes are not more restrictive than
- // input allocator attributes.
- const auto input_attr = params_->input_alloc_attrs == nullptr
- ? AllocatorAttributes()
- : input_alloc_attr(input_index);
- if (!output_attr.IsEqualOrLessRestrictiveThan(input_attr)) {
- return nullptr;
+ if (!forward_expected) {
+ if (!input->RefCountIsOne()) {
+ return nullptr;
+ }
+ // Check that output allocator attributes are not more restrictive than
+ // input allocator attributes.
+ const auto input_attr = params_->input_alloc_attrs == nullptr
+ ? AllocatorAttributes()
+ : input_alloc_attr(input_index);
+ if (!output_attr.IsEqualOrLessRestrictiveThan(input_attr)) {
+ return nullptr;
+ }
}
// TODO(rmlarsen): Use MakeUnique here. There is already a copy in
// tensorflow/compiler/xla/ptr_util.h. Perhaps this should be part of
@@ -505,7 +536,8 @@ Status OpKernelContext::forward_input_or_allocate_temp(
Tensor* out_temp) {
for (int input_index : candidate_input_indices) {
std::unique_ptr<Tensor> new_tensor =
- forward_input(input_index, type, shape, DEVICE_MEMORY, allocator_attr);
+ forward_input(input_index, Params::kNoReservation /*output_index*/,
+ type, shape, DEVICE_MEMORY, allocator_attr);
if (new_tensor != nullptr) {
*out_temp = std::move(*new_tensor);
return Status::OK();
@@ -595,6 +627,14 @@ Status OpKernelContext::allocate_output(int index, const TensorShape& shape,
Tensor** output) {
DCHECK_GE(index, 0);
DCHECK_LT(index, num_outputs());
+ bool forward_expected =
+ (params_->forward_from_array != nullptr && index >= 0 &&
+ params_->forward_from_array[index] >= 0);
+ if (forward_expected) {
+ return errors::Internal(
+ "Explicit allocate_output call where input forwarding required. Try "
+ "turning off the ScopedAllocator optimizer.");
+ }
AllocatorAttributes attr = output_alloc_attr(index);
return allocate_output(index, shape, output, attr);
}
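From a kernel's perspective the new reservation checks are transparent: forwarding is still requested through the same helpers, and a reservation for a different output now simply makes forward_input return nullptr. A minimal sketch of an elementwise kernel, assuming matching input/output shape and dtype; the kernel itself is hypothetical:

    #include "tensorflow/core/framework/op_kernel.h"

    class PassThroughOp : public tensorflow::OpKernel {
     public:
      explicit PassThroughOp(tensorflow::OpKernelConstruction* ctx)
          : tensorflow::OpKernel(ctx) {}
      void Compute(tensorflow::OpKernelContext* ctx) override {
        const tensorflow::Tensor& input = ctx->input(0);
        tensorflow::Tensor* output = nullptr;
        // Reuses the input buffer for output 0 when the checks allow it (no
        // kNeverForward mark, no reservation for a different output);
        // otherwise falls back to a fresh allocation.
        OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
                                {0}, 0, input.shape(), &output));
        // ... write *output ...
      }
    };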
diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h
index 2d97160830..67943377b9 100644
--- a/tensorflow/core/framework/op_kernel.h
+++ b/tensorflow/core/framework/op_kernel.h
@@ -64,10 +64,11 @@ class AsyncOpKernel;
class CallFrameInterface;
class FunctionLibraryRuntime;
class OpKernelConstruction; // declared below
-class OpKernelContext; // declared below
+class OpKernelContext; // declared below
class OpRegistryInterface;
class ResourceMgr;
class ScopedStepContainer;
+class CollectiveExecutor;
class StepStatsCollector;
class OpKernel {
@@ -532,6 +533,10 @@ class OpKernelContext {
// computations running on other devices.
Rendezvous* rendezvous = nullptr;
+ // Mechanism for executing a collective op that needs to coordinate
+ // with parallel instances running on other devices.
+ CollectiveExecutor* collective_executor = nullptr;
+
// The session state for this op.
SessionState* session_state = nullptr;
@@ -565,6 +570,12 @@ class OpKernelContext {
// TensorSliceReaderCache support.
checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache = nullptr;
+
+ // Support for forwarding reservations (used by ScopedAllocator).
+ static const int kNeverForward = -2;
+ static const int kNoReservation = -1;
+ // Values in [0,...) represent reservations for the indexed output.
+ const int* forward_from_array = nullptr;
};
// params must outlive the OpKernelContext.
@@ -707,14 +718,31 @@ class OpKernelContext {
// input[input_index] are compatible with those given in dtype, shape,
// memory_type, and attr,
// * refcount on the underlying buffer is one.
+ // * Either there is no forwarding reservation for input_index or
+ // output_index, or the specified input is reserved for the specified
+ // output. More precisely:
+ //
+ // These cases mean neither input nor output has a reservation:
+ // forward_from_array = nullptr
+ // OR (input_index is not in forward_from_array AND
+ // (output_index == kNoReservation OR
+ // forward_from_array[output_index] == kNoReservation))
+ //
+ // This case means that input_index is reserved for output_index:
+ // forward_from_array[output_index] == input_index
+ //
+ // This case means the output is reserved to always be allocated,
+ // never assigned a forwarded input:
+ // forward_from_array[output_index] == kNeverForward
+ //
// Otherwise returns nullptr.
// NOTE: For CUDA kernels that read inputs using the __ldg() intrinsic,
// forwarding is only safe if there are no reads via __ldg() after writes
// to the same address.
std::unique_ptr<Tensor> forward_input(
- int input_index, DataType dtype, const TensorShape& shape,
- MemoryType memory_type,
- const AllocatorAttributes& attr) TF_MUST_USE_RESULT;
+ int input_index, int output_index, DataType output_dtype,
+ const TensorShape& output_shape, MemoryType output_memory_type,
+ const AllocatorAttributes& output_attr) TF_MUST_USE_RESULT;
// Tries to forward one of the inputs given in input_indices to
// output[output_index]. If none of the given inputs can be forwarded, calls
@@ -934,6 +962,10 @@ class OpKernelContext {
// Rendezvous Send() and Recv().
Rendezvous* rendezvous() const { return params_->rendezvous; }
+ CollectiveExecutor* collective_executor() const {
+ return params_->collective_executor;
+ }
+
// An op kernel can access the session state it belongs to.
SessionState* session_state() const { return params_->session_state; }
@@ -1102,7 +1134,7 @@ class OpKernelContext {
Status status_;
friend class CollectiveExecutor; // for access to params_
- Params* params_; // not owned
+ Params* params_; // not owned
mutable mutex mu_; // mutable so const accessors can acquire the lock
gtl::InlinedVector<WrappedAllocator, 4> wrapped_allocators_ GUARDED_BY(mu_);
gtl::InlinedVector<TensorValue, 4> outputs_;
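The sentinels keep the reservation table to one int per output. A hedged illustration of what ScopedAllocator-aware setup code might install in Params for a three-output kernel; the array name is hypothetical:

    // Output 0: no reservation; the usual forwarding checks apply.
    // Output 1: reserved to be forwarded from input 2.
    // Output 2: always freshly allocated, never forwarded.
    static const int kForwardFrom[3] = {
        tensorflow::OpKernelContext::Params::kNoReservation,
        2,
        tensorflow::OpKernelContext::Params::kNeverForward,
    };
    params.forward_from_array = kForwardFrom;  // params is an OpKernelContext::Params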
diff --git a/tensorflow/core/framework/op_kernel_test.cc b/tensorflow/core/framework/op_kernel_test.cc
index b53b877f28..bcd409e5c5 100644
--- a/tensorflow/core/framework/op_kernel_test.cc
+++ b/tensorflow/core/framework/op_kernel_test.cc
@@ -546,9 +546,9 @@ TEST_F(OpKernelBuilderTest, BuilderTypeListAttr) {
{"T|list(type)|[DT_FLOAT]"}));
ExpectFailure("BuildTypeListAttr", DEVICE_CPU, {}, error::INVALID_ARGUMENT);
- EXPECT_TRUE(
- StringPiece(GetKernelClassName("BuildTypeListAttr", DEVICE_CPU, {}))
- .contains("Invalid argument: "));
+ EXPECT_TRUE(str_util::StrContains(
+ GetKernelClassName("BuildTypeListAttr", DEVICE_CPU, {}),
+ "Invalid argument: "));
ExpectFailure("BuildTypeListAttr", DEVICE_CPU, {"T|int|7"},
error::INVALID_ARGUMENT);
@@ -565,8 +565,8 @@ TEST_F(OpKernelBuilderTest, DuplicateKernel) {
DeviceTypeVector devs;
Status status = SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs);
ASSERT_FALSE(status.ok());
- EXPECT_TRUE(StringPiece(status.error_message())
- .contains("Multiple OpKernel registrations match NodeDef"));
+ EXPECT_TRUE(str_util::StrContains(
+ status.error_message(), "Multiple OpKernel registrations match NodeDef"));
ExpectFailure("DuplicateKernel", DEVICE_CPU, {}, error::INVALID_ARGUMENT);
}
@@ -585,8 +585,8 @@ TEST_F(OpKernelBuilderTest, DuplicateKernelForT) {
DeviceTypeVector devs;
Status status = SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs);
ASSERT_FALSE(status.ok());
- EXPECT_TRUE(StringPiece(status.error_message())
- .contains("Multiple OpKernel registrations match NodeDef"));
+ EXPECT_TRUE(str_util::StrContains(
+ status.error_message(), "Multiple OpKernel registrations match NodeDef"));
ExpectFailure("DuplicateKernelForT", DEVICE_CPU, {"T|type|DT_FLOAT"},
error::INVALID_ARGUMENT);
@@ -606,8 +606,9 @@ TEST_F(OpKernelBuilderTest, BadConstraint) {
DeviceTypeVector devs;
Status status = SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs);
ASSERT_FALSE(status.ok());
- EXPECT_TRUE(StringPiece(status.error_message())
- .contains("OpKernel 'BadConstraint' has constraint on attr "
+ EXPECT_TRUE(
+ str_util::StrContains(status.error_message(),
+ "OpKernel 'BadConstraint' has constraint on attr "
"'T' not in NodeDef"));
ExpectFailure("BadConstraint", DEVICE_CPU, {"dtype|type|DT_FLOAT"},
diff --git a/tensorflow/core/framework/resource_mgr_test.cc b/tensorflow/core/framework/resource_mgr_test.cc
index 07272e2374..798220d4c3 100644
--- a/tensorflow/core/framework/resource_mgr_test.cc
+++ b/tensorflow/core/framework/resource_mgr_test.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/test.h"
@@ -71,7 +72,7 @@ string LookupOrCreate(ResourceMgr* rm, const string& container,
}
static void HasError(const Status& s, const string& substr) {
- EXPECT_TRUE(StringPiece(s.ToString()).contains(substr))
+ EXPECT_TRUE(str_util::StrContains(s.ToString(), substr))
<< s << ", expected substring " << substr;
}
diff --git a/tensorflow/core/framework/shape_inference_test.cc b/tensorflow/core/framework/shape_inference_test.cc
index f48a7b9c47..da103bfec9 100644
--- a/tensorflow/core/framework/shape_inference_test.cc
+++ b/tensorflow/core/framework/shape_inference_test.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/test.h"
@@ -152,10 +153,9 @@ TEST_F(ShapeInferenceTest, Run) {
};
Status s = c.Run(fn);
// Extra error message is attached when Run fails.
- EXPECT_TRUE(StringPiece(s.ToString())
- .contains("Shape must be at most rank 0 but "
- "is rank 1 for 'foo' (op: "
- "'foo_op')"))
+ EXPECT_TRUE(str_util::StrContains(
+ s.ToString(),
+ "Shape must be at most rank 0 but is rank 1 for 'foo' (op: 'foo_op')"))
<< s;
}
}
@@ -367,10 +367,9 @@ TEST_F(ShapeInferenceTest, WithRankAtMost) {
// WithRankAtMost on shape with known dimensionality.
s1 = in1;
- EXPECT_TRUE(
- StringPiece(c.WithRankAtMost(in1, 2, &s1).ToString())
- .contains(
- "Invalid argument: Shape must be at most rank 2 but is rank 3"));
+ EXPECT_TRUE(str_util::StrContains(
+ c.WithRankAtMost(in1, 2, &s1).ToString(),
+ "Invalid argument: Shape must be at most rank 2 but is rank 3"));
EXPECT_FALSE(IsSet(s1));
EXPECT_TRUE(c.WithRankAtMost(in1, 3, &s1).ok());
@@ -406,10 +405,9 @@ TEST_F(ShapeInferenceTest, WithRankAtLeast) {
// WithRankAtLeast on shape with known dimensionality.
s1 = in1;
- EXPECT_TRUE(
- StringPiece(c.WithRankAtLeast(in1, 4, &s1).ToString())
- .contains(
- "Invalid argument: Shape must be at least rank 4 but is rank 3"));
+ EXPECT_TRUE(str_util::StrContains(
+ c.WithRankAtLeast(in1, 4, &s1).ToString(),
+ "Invalid argument: Shape must be at least rank 4 but is rank 3"));
EXPECT_FALSE(IsSet(s1));
EXPECT_TRUE(c.WithRankAtLeast(in1, 3, &s1).ok());
@@ -449,12 +447,14 @@ TEST_F(ShapeInferenceTest, WithValue) {
// WithValue on dimension with known size.
out1 = d0;
- EXPECT_TRUE(StringPiece(c.WithValue(d0, 0, &out1).ToString())
- .contains("Invalid argument: Dimension must be 0 but is 1"));
+ EXPECT_TRUE(
+ str_util::StrContains(c.WithValue(d0, 0, &out1).ToString(),
+ "Invalid argument: Dimension must be 0 but is 1"));
EXPECT_FALSE(IsSet(out1));
out1 = d0;
- EXPECT_TRUE(StringPiece(c.WithValue(d0, 2, &out1).ToString())
- .contains("Invalid argument: Dimension must be 2 but is 1"));
+ EXPECT_TRUE(
+ str_util::StrContains(c.WithValue(d0, 2, &out1).ToString(),
+ "Invalid argument: Dimension must be 2 but is 1"));
EXPECT_FALSE(IsSet(out1));
EXPECT_TRUE(c.WithValue(d0, 1, &out1).ok());
@@ -513,16 +513,14 @@ TEST_F(ShapeInferenceTest, MergeDim) {
EXPECT_EQ(3, merged_dims.size());
// Merging unequal values is an error.
- EXPECT_TRUE(
- StringPiece(c.Merge(d2, d1, &out).ToString())
- .contains(
- "Invalid argument: Dimensions must be equal, but are 2 and 1"));
+ EXPECT_TRUE(str_util::StrContains(
+ c.Merge(d2, d1, &out).ToString(),
+ "Invalid argument: Dimensions must be equal, but are 2 and 1"));
EXPECT_FALSE(IsSet(out));
- EXPECT_TRUE(
- StringPiece(c.Merge(d1, d2, &out).ToString())
- .contains(
- "Invalid argument: Dimensions must be equal, but are 1 and 2"));
+ EXPECT_TRUE(str_util::StrContains(
+ c.Merge(d1, d2, &out).ToString(),
+ "Invalid argument: Dimensions must be equal, but are 1 and 2"));
EXPECT_FALSE(IsSet(out));
@@ -729,26 +727,23 @@ TEST_F(ShapeInferenceTest, MergeShape) {
// Incompatible merges give errors and set out to nullptr.
out = s_unknown;
- EXPECT_TRUE(
- StringPiece(c.Merge(s_u_2, s_1_3, &out).ToString())
- .contains(
- "Invalid argument: Dimension 1 in both shapes must be equal, but "
- "are 2 and 3"));
+ EXPECT_TRUE(str_util::StrContains(
+ c.Merge(s_u_2, s_1_3, &out).ToString(),
+ "Invalid argument: Dimension 1 in both shapes must be equal, but "
+ "are 2 and 3"));
EXPECT_FALSE(IsSet(out));
out = s_unknown;
- EXPECT_TRUE(
- StringPiece(c.Merge(s_1_3, s_u_2, &out).ToString())
- .contains(
- "Invalid argument: Dimension 1 in both shapes must be equal, but "
- "are 3 and 2"));
+ EXPECT_TRUE(str_util::StrContains(
+ c.Merge(s_1_3, s_u_2, &out).ToString(),
+ "Invalid argument: Dimension 1 in both shapes must be equal, but "
+ "are 3 and 2"));
EXPECT_FALSE(IsSet(out));
out = s_unknown;
- EXPECT_TRUE(
- StringPiece(c.Merge(s_1, s_1_2, &out).ToString())
- .contains(
- "Invalid argument: Shapes must be equal rank, but are 1 and 2"));
+ EXPECT_TRUE(str_util::StrContains(
+ c.Merge(s_1, s_1_2, &out).ToString(),
+ "Invalid argument: Shapes must be equal rank, but are 1 and 2"));
EXPECT_FALSE(IsSet(out));
@@ -795,22 +790,18 @@ TEST_F(ShapeInferenceTest, MergePrefix) {
// Incompatible merges give errors and set outs to nullptr.
s_out = s_unknown;
s_prefix_out = s_unknown;
- EXPECT_TRUE(
- StringPiece(
- c.MergePrefix(s_1_u_3, s_2_4, &s_out, &s_prefix_out).ToString())
- .contains(
- "Invalid argument: Dimensions must be equal, but are 1 and 2"));
+ EXPECT_TRUE(str_util::StrContains(
+ c.MergePrefix(s_1_u_3, s_2_4, &s_out, &s_prefix_out).ToString(),
+ "Invalid argument: Dimensions must be equal, but are 1 and 2"));
EXPECT_FALSE(IsSet(s_out));
EXPECT_FALSE(IsSet(s_prefix_out));
s_out = s_unknown;
s_prefix_out = s_unknown;
- EXPECT_TRUE(
- StringPiece(
- c.MergePrefix(s_2_4, s_1_u_3, &s_out, &s_prefix_out).ToString())
- .contains(
- "Invalid argument: Shape must be at least rank 3 but is rank 2"));
+ EXPECT_TRUE(str_util::StrContains(
+ c.MergePrefix(s_2_4, s_1_u_3, &s_out, &s_prefix_out).ToString(),
+ "Invalid argument: Shape must be at least rank 3 but is rank 2"));
EXPECT_FALSE(IsSet(s_out));
EXPECT_FALSE(IsSet(s_prefix_out));
}
@@ -868,24 +859,21 @@ TEST_F(ShapeInferenceTest, Subshape) {
// Errors.
out = unknown;
- EXPECT_TRUE(StringPiece(c.Subshape(in0, 6, -3, &out).ToString())
- .contains("Invalid argument: Subshape must have computed "
- "start <= end, but is 5 "
- "and 2 (computed from start 6 and end -3 over "
- "shape with rank 5)"));
+ EXPECT_TRUE(str_util::StrContains(
+ c.Subshape(in0, 6, -3, &out).ToString(),
+ "Invalid argument: Subshape must have computed start <= end, but is 5 "
+ "and 2 (computed from start 6 and end -3 over shape with rank 5)"));
EXPECT_FALSE(IsSet(out));
out = unknown;
- EXPECT_TRUE(StringPiece(c.Subshape(in0, -50, 100, &out).ToString())
- .contains("Invalid argument: Subshape start out of "
- "bounds: -50, for shape with "
- "rank 5"));
+ EXPECT_TRUE(str_util::StrContains(c.Subshape(in0, -50, 100, &out).ToString(),
+ "Invalid argument: Subshape start out of "
+ "bounds: -50, for shape with rank 5"));
EXPECT_FALSE(IsSet(out));
out = unknown;
- EXPECT_TRUE(StringPiece(c.Subshape(in0, 0, -50, &out).ToString())
- .contains("Invalid argument: Subshape end out of bounds: "
- "-50, for shape with rank "
- "5"));
+ EXPECT_TRUE(str_util::StrContains(c.Subshape(in0, 0, -50, &out).ToString(),
+ "Invalid argument: Subshape end out of "
+ "bounds: -50, for shape with rank 5"));
EXPECT_FALSE(IsSet(out));
}
@@ -1094,27 +1082,26 @@ TEST_F(ShapeInferenceTest, MakeShapeFromShapeTensor) {
EXPECT_EQ("[]", create(&t));
t = ::tensorflow::test::AsTensor<float>({1, 2, 3});
- EXPECT_TRUE(
- StringPiece(create(&t))
- .contains("Input tensor must be int32 or int64, but was float"));
+ EXPECT_TRUE(str_util::StrContains(
+ create(&t), "Input tensor must be int32 or int64, but was float"));
t = ::tensorflow::test::AsScalar<int32>(1);
- EXPECT_TRUE(StringPiece(create(&t))
- .contains("Input tensor must be rank 1, but was rank 0"));
+ EXPECT_TRUE(str_util::StrContains(
+ create(&t), "Input tensor must be rank 1, but was rank 0"));
t = ::tensorflow::test::AsTensor<int32>({1, 2}, TensorShape{2, 1});
- EXPECT_TRUE(StringPiece(create(&t))
- .contains("Input tensor must be rank 1, but was rank 2"));
+ EXPECT_TRUE(str_util::StrContains(
+ create(&t), "Input tensor must be rank 1, but was rank 2"));
// Test negative values for the dims.
t = ::tensorflow::test::AsTensor<int64>({3, -2, 1});
- EXPECT_TRUE(StringPiece(create(&t))
- .contains("Invalid value in tensor used for shape: -2"));
+ EXPECT_TRUE(str_util::StrContains(
+ create(&t), "Invalid value in tensor used for shape: -2"));
// Test negative values for the dims.
t = ::tensorflow::test::AsTensor<int32>({3, -2, 1});
- EXPECT_TRUE(StringPiece(create(&t))
- .contains("Invalid value in tensor used for shape: -2"));
+ EXPECT_TRUE(str_util::StrContains(
+ create(&t), "Invalid value in tensor used for shape: -2"));
// Test when the input shape is wrong.
{
@@ -1172,9 +1159,9 @@ TEST_F(ShapeInferenceTest, MakeShapeFromShapeProto) {
EXPECT_TRUE(c.MakeShapeFromShapeProto(proto, &out).ok());
EXPECT_EQ("?", c.DebugString(out));
proto.add_dim()->set_size(0);
- EXPECT_TRUE(
- StringPiece(c.MakeShapeFromShapeProto(proto, &out).error_message())
- .contains("An unknown shape must not have any dimensions set."));
+ EXPECT_TRUE(str_util::StrContains(
+ c.MakeShapeFromShapeProto(proto, &out).error_message(),
+ "An unknown shape must not have any dimensions set."));
EXPECT_FALSE(IsSet(out));
// With known rank.
@@ -1188,10 +1175,10 @@ TEST_F(ShapeInferenceTest, MakeShapeFromShapeProto) {
// With invalid dimension value.
proto.add_dim()->set_size(-2);
- EXPECT_TRUE(
- StringPiece(c.MakeShapeFromShapeProto(proto, &out).error_message())
- .contains("Shape [0,?,1000,-2] has dimensions with values below -1 "
- "(where -1 means unknown)"));
+ EXPECT_TRUE(str_util::StrContains(
+ c.MakeShapeFromShapeProto(proto, &out).error_message(),
+ "Shape [0,?,1000,-2] has dimensions with values below -1 "
+ "(where -1 means unknown)"));
EXPECT_FALSE(IsSet(out));
}
@@ -1257,9 +1244,10 @@ TEST_F(ShapeInferenceTest, MakeDimForScalarInput) {
EXPECT_TRUE(c.MakeDimForScalarInput(0, &d).ok());
EXPECT_EQ("20", c.DebugString(d));
- EXPECT_TRUE(StringPiece(c.MakeDimForScalarInput(1, &d).error_message())
- .contains("Dimension size, given by scalar input 1, must "
- "be non-negative but is -1"));
+ EXPECT_TRUE(
+ str_util::StrContains(c.MakeDimForScalarInput(1, &d).error_message(),
+ "Dimension size, given by scalar input 1, must be "
+ "non-negative but is -1"));
// Same tests, with int64 values.
t1 = tensorflow::test::AsScalar<int64>(20);
@@ -1267,9 +1255,10 @@ TEST_F(ShapeInferenceTest, MakeDimForScalarInput) {
EXPECT_TRUE(c.MakeDimForScalarInput(0, &d).ok());
EXPECT_EQ("20", c.DebugString(d));
- EXPECT_TRUE(StringPiece(c.MakeDimForScalarInput(1, &d).error_message())
- .contains("Dimension size, given by scalar input 1, must "
- "be non-negative but is -1"));
+ EXPECT_TRUE(
+ str_util::StrContains(c.MakeDimForScalarInput(1, &d).error_message(),
+ "Dimension size, given by scalar input 1, must be "
+ "non-negative but is -1"));
}
TEST_F(ShapeInferenceTest, GetAttr) {
@@ -1322,33 +1311,33 @@ TEST_F(ShapeInferenceTest, Divide) {
EXPECT_TRUE(c.Divide(d_6, d_2, evenly_divisible, &out).ok());
EXPECT_EQ("3", c.DebugString(out));
- EXPECT_TRUE(
- StringPiece(c.Divide(d_6, 5, evenly_divisible, &out).error_message())
- .contains("Dimension size must be evenly divisible by 5 but is 6"));
+ EXPECT_TRUE(str_util::StrContains(
+ c.Divide(d_6, 5, evenly_divisible, &out).error_message(),
+ "Dimension size must be evenly divisible by 5 but is 6"));
- EXPECT_TRUE(
- StringPiece(c.Divide(d_6, 0, evenly_divisible, &out).error_message())
- .contains("Divisor must be positive but is 0"));
- EXPECT_TRUE(
- StringPiece(c.Divide(d_6, d_0, evenly_divisible, &out).error_message())
- .contains("Divisor must be positive but is 0"));
+ EXPECT_TRUE(str_util::StrContains(
+ c.Divide(d_6, 0, evenly_divisible, &out).error_message(),
+ "Divisor must be positive but is 0"));
+ EXPECT_TRUE(str_util::StrContains(
+ c.Divide(d_6, d_0, evenly_divisible, &out).error_message(),
+ "Divisor must be positive but is 0"));
- EXPECT_TRUE(
- StringPiece(c.Divide(d_6, -1, evenly_divisible, &out).error_message())
- .contains("Divisor must be positive but is -1"));
+ EXPECT_TRUE(str_util::StrContains(
+ c.Divide(d_6, -1, evenly_divisible, &out).error_message(),
+ "Divisor must be positive but is -1"));
// Repeat error cases above with evenly_divisible=false.
evenly_divisible = false;
EXPECT_TRUE(c.Divide(d_6, 5, evenly_divisible, &out).ok());
EXPECT_EQ("1", c.DebugString(out));
- EXPECT_TRUE(
- StringPiece(c.Divide(d_6, 0, evenly_divisible, &out).error_message())
- .contains("Divisor must be positive but is 0"));
+ EXPECT_TRUE(str_util::StrContains(
+ c.Divide(d_6, 0, evenly_divisible, &out).error_message(),
+ "Divisor must be positive but is 0"));
- EXPECT_TRUE(
- StringPiece(c.Divide(d_6, -1, evenly_divisible, &out).error_message())
- .contains("Divisor must be positive but is -1"));
+ EXPECT_TRUE(str_util::StrContains(
+ c.Divide(d_6, -1, evenly_divisible, &out).error_message(),
+ "Divisor must be positive but is -1"));
}
TEST_F(ShapeInferenceTest, Add) {
@@ -1396,11 +1385,9 @@ TEST_F(ShapeInferenceTest, Add) {
EXPECT_TRUE(c.Add(d_0, d_6, &out).ok());
EXPECT_TRUE(SameHandle(out, d_6));
- EXPECT_TRUE(
- StringPiece(c.Add(d_6, std::numeric_limits<int64>::max() - 5, &out)
- .error_message())
- .contains(
- "Dimension size overflow from adding 6 and 9223372036854775802"));
+ EXPECT_TRUE(str_util::StrContains(
+ c.Add(d_6, std::numeric_limits<int64>::max() - 5, &out).error_message(),
+ "Dimension size overflow from adding 6 and 9223372036854775802"));
}
TEST_F(ShapeInferenceTest, Subtract) {
@@ -1448,9 +1435,9 @@ TEST_F(ShapeInferenceTest, Subtract) {
EXPECT_TRUE(c.Subtract(d_6, d_0, &out).ok());
EXPECT_TRUE(SameHandle(out, d_6));
- EXPECT_TRUE(
- StringPiece(c.Subtract(d_5, d_6, &out).error_message())
- .contains("Negative dimension size caused by subtracting 6 from 5"));
+ EXPECT_TRUE(str_util::StrContains(
+ c.Subtract(d_5, d_6, &out).error_message(),
+ "Negative dimension size caused by subtracting 6 from 5"));
}
TEST_F(ShapeInferenceTest, Multiply) {
diff --git a/tensorflow/core/framework/shape_inference_testutil.cc b/tensorflow/core/framework/shape_inference_testutil.cc
index b4765ab0b2..b54dd220ab 100644
--- a/tensorflow/core/framework/shape_inference_testutil.cc
+++ b/tensorflow/core/framework/shape_inference_testutil.cc
@@ -100,7 +100,7 @@ Status ShapeInferenceTestutil::InferShapes(ShapeInferenceTestOp op,
}
}
- if (expected.starts_with("in")) {
+ if (str_util::StartsWith(expected, "in")) {
if (in_index == -1) {
return Unknown(err_prefix,
" should have matched an input shape by "
@@ -135,7 +135,9 @@ Status ShapeInferenceTestutil::InferShapes(ShapeInferenceTestOp op,
}
// Verify the dimensions.
- CHECK(expected.starts_with("[") && expected.ends_with("]")) << expected;
+ CHECK(str_util::StartsWith(expected, "[") &&
+ str_util::EndsWith(expected, "]"))
+ << expected;
expected.remove_prefix(1);
expected.remove_suffix(1);
@@ -176,7 +178,7 @@ Status ShapeInferenceTestutil::InferShapes(ShapeInferenceTestOp op,
return Unknown(err_prefix, " expected to be unknown but was ",
c.Value(out_dim), err_suffix);
}
- } else if (expected_dim.starts_with("d")) {
+ } else if (str_util::StartsWith(expected_dim, "d")) {
// Compare the dimension values.
auto v = str_util::Split(expected_dim, '|');
if (in_dim_idx.first == -1) {
diff --git a/tensorflow/core/framework/shape_inference_testutil.h b/tensorflow/core/framework/shape_inference_testutil.h
index 7977841482..2a99af7659 100644
--- a/tensorflow/core/framework/shape_inference_testutil.h
+++ b/tensorflow/core/framework/shape_inference_testutil.h
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/version.h"
@@ -83,17 +84,17 @@ class ShapeInferenceTestutil {
"", ::tensorflow::shape_inference::ShapeInferenceTestutil::InferShapes( \
op, i, o) \
.error_message())
-#define INFER_ERROR(error_substring, op, i) \
- { \
- string error_message = \
- ::tensorflow::shape_inference::ShapeInferenceTestutil::InferShapes( \
- op, i, "e") \
- .error_message(); \
- const string& substring = error_substring; \
- EXPECT_NE("", error_message); \
- EXPECT_TRUE(StringPiece(error_message).contains(substring)) \
- << "Expected to see '" << substring << "' in '" << error_message \
- << "'"; \
+#define INFER_ERROR(error_substring, op, i) \
+ { \
+ string error_message = \
+ ::tensorflow::shape_inference::ShapeInferenceTestutil::InferShapes( \
+ op, i, "e") \
+ .error_message(); \
+ const string& substring = error_substring; \
+ EXPECT_NE("", error_message); \
+ EXPECT_TRUE(::tensorflow::str_util::StrContains(error_message, substring)) \
+ << "Expected to see '" << substring << "' in '" << error_message \
+ << "'"; \
}
} // namespace tensorflow
diff --git a/tensorflow/core/framework/shape_inference_testutil_test.cc b/tensorflow/core/framework/shape_inference_testutil_test.cc
index 20a6807064..a4405b502c 100644
--- a/tensorflow/core/framework/shape_inference_testutil_test.cc
+++ b/tensorflow/core/framework/shape_inference_testutil_test.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
@@ -25,10 +26,11 @@ namespace shape_inference {
namespace {
-#define EXPECT_CONTAINS(str, substr) \
- do { \
- string s = (str); \
- EXPECT_TRUE(StringPiece(s).contains(substr)) << "String: " << s; \
+#define EXPECT_CONTAINS(str, substr) \
+ do { \
+ string s = (str); \
+ EXPECT_TRUE(::tensorflow::str_util::StrContains(s, substr)) \
+ << "String: " << s; \
} while (false)
static OpShapeInferenceFn* global_fn_ptr = nullptr;
@@ -97,8 +99,8 @@ TEST(ShapeInferenceTestutilTest, Failures) {
auto error_message = ShapeInferenceTestutil::InferShapes(
ShapeInferenceTestOp("NoSuchOp"), "", "")
.error_message();
- EXPECT_TRUE(StringPiece(error_message)
- .starts_with("Op type not registered 'NoSuchOp'"));
+ EXPECT_TRUE(
+ str_util::StartsWith(error_message, "Op type not registered 'NoSuchOp'"));
// Wrong shape error messages.
EXPECT_CONTAINS(RunInferShapes(op, "[1];[2];[1]", "?", fn_copy_input_0),
diff --git a/tensorflow/core/framework/types.cc b/tensorflow/core/framework/types.cc
index adf4e1bae3..2280114de5 100644
--- a/tensorflow/core/framework/types.cc
+++ b/tensorflow/core/framework/types.cc
@@ -114,7 +114,7 @@ string DataTypeString(DataType dtype) {
}
bool DataTypeFromString(StringPiece sp, DataType* dt) {
- if (sp.ends_with("_ref")) {
+ if (str_util::EndsWith(sp, "_ref")) {
sp.remove_suffix(4);
DataType non_ref;
if (DataTypeFromString(sp, &non_ref) && !IsRefType(non_ref)) {
diff --git a/tensorflow/core/framework/types_test.cc b/tensorflow/core/framework/types_test.cc
index 60f2b4135a..16b069c70a 100644
--- a/tensorflow/core/framework/types_test.cc
+++ b/tensorflow/core/framework/types_test.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/type_traits.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
@@ -140,9 +141,8 @@ TEST(TypesTest, ComplexTypes) {
TEST(TypesTest, IntegerTypes) {
for (auto dt : AllTypes()) {
const string name = DataTypeString(dt);
- const StringPiece n = name;
- EXPECT_EQ(DataTypeIsInteger(dt),
- n.starts_with("int") || n.starts_with("uint"))
+ EXPECT_EQ(DataTypeIsInteger(dt), str_util::StartsWith(name, "int") ||
+ str_util::StartsWith(name, "uint"))
<< "DataTypeInteger failed for " << name;
}
}
diff --git a/tensorflow/core/framework/variant_op_copy_test.cc b/tensorflow/core/framework/variant_op_copy_test.cc
index 85e014f804..60fa7bd559 100644
--- a/tensorflow/core/framework/variant_op_copy_test.cc
+++ b/tensorflow/core/framework/variant_op_copy_test.cc
@@ -34,6 +34,7 @@ limitations under the License.
#include "tensorflow/core/framework/variant_tensor_data.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/util/port.h"
@@ -259,8 +260,8 @@ TEST(VariantOpCopyTest, CreateConstOnGPUFailsGracefully) {
ClientSession session(root);
std::vector<Tensor> outputs;
Status s = session.Run({create_const}, &outputs);
- EXPECT_TRUE(StringPiece(s.error_message())
- .contains("GPU copy from non-DMA string tensor"))
+ EXPECT_TRUE(str_util::StrContains(s.error_message(),
+ "GPU copy from non-DMA string tensor"))
<< s.ToString();
}
@@ -365,8 +366,9 @@ TEST(VariantOpCopyTest, CreateCopyCPUToGPUStringFailsSafely) {
std::vector<Tensor> outputs;
Status err = session.Run({create_op, identity}, &outputs);
EXPECT_EQ(err.code(), errors::Code::INVALID_ARGUMENT);
- EXPECT_TRUE(StringPiece(err.error_message())
- .contains("During Variant Host->Device Copy: non-DMA-copy "
+ EXPECT_TRUE(
+ str_util::StrContains(err.error_message(),
+ "During Variant Host->Device Copy: non-DMA-copy "
"attempted of tensor type: string"))
<< err.error_message();
}
diff --git a/tensorflow/core/framework/variant_op_registry_test.cc b/tensorflow/core/framework/variant_op_registry_test.cc
index 06ca211c76..7055e62c0e 100644
--- a/tensorflow/core/framework/variant_op_registry_test.cc
+++ b/tensorflow/core/framework/variant_op_registry_test.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include <memory>
+#include "tensorflow/core/lib/strings/str_util.h"
#define EIGEN_USE_THREADS
@@ -130,7 +131,7 @@ TEST(VariantOpShapeRegistryTest, TestBasic) {
Variant v = vv_early_exit;
Status s0 = (*shape_fn)(v, &shape);
EXPECT_FALSE(s0.ok());
- EXPECT_TRUE(StringPiece(s0.error_message()).contains("early exit!"));
+ EXPECT_TRUE(str_util::StrContains(s0.error_message(), "early exit!"));
VariantValue vv_ok{false /* early_exit */};
v = vv_ok;
@@ -229,7 +230,7 @@ TEST(VariantOpZerosLikeRegistryTest, TestBasicCPU) {
ZEROS_LIKE_VARIANT_UNARY_OP, v, &v_out);
EXPECT_FALSE(s0.ok());
EXPECT_TRUE(
- StringPiece(s0.error_message()).contains("early exit zeros_like"));
+ str_util::StrContains(s0.error_message(), "early exit zeros_like"));
VariantValue vv_ok{false /* early_exit */, 0 /* value */};
v = vv_ok;
@@ -254,7 +255,7 @@ TEST(VariantOpUnaryOpRegistryTest, TestBasicGPU) {
ZEROS_LIKE_VARIANT_UNARY_OP, v, &v_out);
EXPECT_FALSE(s0.ok());
EXPECT_TRUE(
- StringPiece(s0.error_message()).contains("early exit zeros_like"));
+ str_util::StrContains(s0.error_message(), "early exit zeros_like"));
VariantValue vv_ok{false /* early_exit */, 0 /* value */};
v = vv_ok;
@@ -299,7 +300,7 @@ TEST(VariantOpAddRegistryTest, TestBasicCPU) {
Status s0 = BinaryOpVariants<CPUDevice>(
null_context_pointer, ADD_VARIANT_BINARY_OP, v_a, v_b, &v_out);
EXPECT_FALSE(s0.ok());
- EXPECT_TRUE(StringPiece(s0.error_message()).contains("early exit add"));
+ EXPECT_TRUE(str_util::StrContains(s0.error_message(), "early exit add"));
VariantValue vv_ok{false /* early_exit */, 3 /* value */};
v_a = vv_ok;
@@ -325,7 +326,7 @@ TEST(VariantOpAddRegistryTest, TestBasicGPU) {
Status s0 = BinaryOpVariants<GPUDevice>(
null_context_pointer, ADD_VARIANT_BINARY_OP, v_a, v_b, &v_out);
EXPECT_FALSE(s0.ok());
- EXPECT_TRUE(StringPiece(s0.error_message()).contains("early exit add"));
+ EXPECT_TRUE(str_util::StrContains(s0.error_message(), "early exit add"));
VariantValue vv_ok{false /* early_exit */, 3 /* value */};
v_a = vv_ok;
diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc
index a7af5e2312..fb8a6c39e6 100644
--- a/tensorflow/core/graph/graph.cc
+++ b/tensorflow/core/graph/graph.cc
@@ -567,6 +567,11 @@ void Graph::ToGraphDefSubRange(GraphDef* graph_def, int from_node_id) const {
inputs[edge->dst_input()] = edge;
}
}
+ // Sort the control inputs for more predictable serialization.
+ std::sort(inputs.begin() + node->num_inputs(), inputs.end(),
+ [](const Edge* a, const Edge* b) -> bool {
+ return a->src()->name() < b->src()->name();
+ });
node_def->clear_input();
node_def->mutable_input()->Reserve(inputs.size());
diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc
index 76ee88e684..f15e2ce9fa 100644
--- a/tensorflow/core/graph/graph_constructor.cc
+++ b/tensorflow/core/graph/graph_constructor.cc
@@ -37,6 +37,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/strings/scanner.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/public/version.h"
@@ -73,7 +74,7 @@ class GraphConstructor {
Options(const ImportGraphDefOptions& in) // NOLINT(runtime/explicit)
: allow_internal_ops(false),
expect_device_spec(false),
- prefix(in.prefix.empty() || StringPiece(in.prefix).ends_with("/")
+ prefix(in.prefix.empty() || str_util::EndsWith(in.prefix, "/")
? in.prefix
: in.prefix + "/"),
uniquify_names(in.uniquify_names),
@@ -436,7 +437,7 @@ Status GraphConstructor::BuildNodeIndex() {
bool in_control_dependence = false;
for (int i = 0; i < node_def.input_size(); ++i) {
StringPiece input_name = node_def.input(i);
- if (!input_name.empty() && input_name.starts_with("^")) {
+ if (!input_name.empty() && str_util::StartsWith(input_name, "^")) {
in_control_dependence = true;
} else if (in_control_dependence) {
return errors::InvalidArgument(
@@ -484,7 +485,7 @@ Status GraphConstructor::InitFromEdges() {
bool has_loop_back_edge = false;
for (int i = 0; i < node_def.input_size(); ++i) {
StringPiece input_name(node_def.input(i));
- if (input_name.starts_with("^")) {
+ if (str_util::StartsWith(input_name, "^")) {
num_control_edges++;
} else {
TensorId id(ParseTensorName(input_name));
@@ -534,7 +535,7 @@ Status GraphConstructor::ValidateColocationConstraints(
if (iter == node_def.attr().end()) return Status::OK();
for (const string& c : iter->second.list().s()) {
StringPiece s(c);
- if (s.Consume(kColocationGroupPrefix) &&
+ if (str_util::ConsumePrefix(&s, kColocationGroupPrefix) &&
gdef_nodes_.find(s) == gdef_nodes_.end()) {
return errors::InvalidArgument(
"Node '", node_def.name(),
@@ -764,7 +765,7 @@ void GraphConstructor::AddPrefixToNodeDef(
// Skip remapped inputs (which already exist in g_ and are not being
// imported).
if (input_already_exists[i]) continue;
- if (input.Consume("^")) {
+ if (str_util::ConsumePrefix(&input, "^")) {
node_def->set_input(i, strings::StrCat("^", prefix_, input));
} else {
node_def->set_input(i, strings::StrCat(prefix_, input));
@@ -776,7 +777,7 @@ void GraphConstructor::AddPrefixToNodeDef(
node_def->mutable_attr()->at(kColocationAttrName).mutable_list();
for (int i = 0; i < list->s_size(); ++i) {
StringPiece v(list->s(i));
- if (v.Consume(kColocationGroupPrefix)) {
+ if (str_util::ConsumePrefix(&v, kColocationGroupPrefix)) {
list->set_s(i, strings::StrCat(kColocationGroupPrefix, prefix_, v));
}
}
@@ -819,7 +820,7 @@ void GraphConstructor::UpdateUniquifiedColocationNames() {
bool updated = false;
for (int i = 0; i < coloc_values.size(); ++i) {
StringPiece val(coloc_values[i]);
- if (val.Consume(kColocationGroupPrefix)) {
+ if (str_util::ConsumePrefix(&val, kColocationGroupPrefix)) {
const auto& name_pair = uniquified_names_.find(val.ToString());
if (name_pair == uniquified_names_.end()) continue;
updated = true;
diff --git a/tensorflow/core/graph/graph_constructor_test.cc b/tensorflow/core/graph/graph_constructor_test.cc
index 963c1dc024..c18ccf6ce4 100644
--- a/tensorflow/core/graph/graph_constructor_test.cc
+++ b/tensorflow/core/graph/graph_constructor_test.cc
@@ -156,7 +156,9 @@ class GraphConstructorTest : public ::testing::Test {
return "";
}
StringPiece loc(value[0]);
- return loc.Consume(kColocationGroupPrefix) ? loc.ToString() : "";
+ return str_util::ConsumePrefix(&loc, kColocationGroupPrefix)
+ ? loc.ToString()
+ : "";
}
string GraphDebugString() const {
diff --git a/tensorflow/core/graph/graph_partition.cc b/tensorflow/core/graph/graph_partition.cc
index 17a174101b..877e4f1b44 100644
--- a/tensorflow/core/graph/graph_partition.cc
+++ b/tensorflow/core/graph/graph_partition.cc
@@ -35,6 +35,7 @@ limitations under the License.
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/hash/hash.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/device_name_utils.h"
@@ -372,7 +373,7 @@ string ControlLoopName(const string& name) {
bool IsControlLoop(const Node* node) {
const string& name = node->name();
- return StringPiece(name).starts_with("_cloop");
+ return str_util::StartsWith(name, "_cloop");
}
// An enter node for control flow.
diff --git a/tensorflow/core/graph/graph_partition_test.cc b/tensorflow/core/graph/graph_partition_test.cc
index 6841f29149..83b24cafe2 100644
--- a/tensorflow/core/graph/graph_partition_test.cc
+++ b/tensorflow/core/graph/graph_partition_test.cc
@@ -35,6 +35,7 @@ limitations under the License.
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
@@ -120,7 +121,7 @@ void CheckLoopConstruction(const GraphDef& graph_def) {
if (ndef.op() == "_Recv") {
bool has_control = false;
for (const string& input_name : ndef.input()) {
- if (StringPiece(input_name).starts_with("^")) {
+ if (str_util::StartsWith(input_name, "^")) {
has_control = true;
break;
}
@@ -128,7 +129,7 @@ void CheckLoopConstruction(const GraphDef& graph_def) {
EXPECT_TRUE(has_control);
}
// Must have a control loop
- if (StringPiece(ndef.name()).starts_with("_cloop")) {
+ if (str_util::StartsWith(ndef.name(), "_cloop")) {
if (ndef.op() == "Enter") {
has_control_enter = true;
}
diff --git a/tensorflow/core/graph/graph_test.cc b/tensorflow/core/graph/graph_test.cc
index e2ce0ba046..c8c2b225fe 100644
--- a/tensorflow/core/graph/graph_test.cc
+++ b/tensorflow/core/graph/graph_test.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
@@ -408,7 +409,7 @@ TEST_F(GraphTest, NewName) {
EXPECT_NE(a1, a2);
EXPECT_NE(a1, b1);
EXPECT_NE(a2, b1);
- EXPECT_TRUE(StringPiece(a1).starts_with("A")) << a1;
+ EXPECT_TRUE(str_util::StartsWith(a1, "A")) << a1;
}
TEST_F(GraphTest, IsValidNode) {
diff --git a/tensorflow/core/graph/quantize_training.cc b/tensorflow/core/graph/quantize_training.cc
index cb0fc8a154..3b6e8cc233 100644
--- a/tensorflow/core/graph/quantize_training.cc
+++ b/tensorflow/core/graph/quantize_training.cc
@@ -259,8 +259,14 @@ Status AddRestoreVariableSubgraphs(Graph* graph, Node* save_op,
const string restore_op_name = strings::StrCat(name_prefix, "/RestoreV2");
const string assign_op_name = strings::StrCat(name_prefix, "/Assign");
for (Node* var : variables) {
- string new_restore_op_name = graph->NewName(restore_op_name);
- string new_assign_op_name = graph->NewName(assign_op_name);
+ // Add an extra suffix after calling graph->NewName because the "unique"
+ // name may conflict with names generated for Send nodes.
+ // TODO(b/77547936): fix this more generally and get rid of the extra suffix
+ // here.
+ string new_restore_op_name =
+ strings::StrCat(graph->NewName(restore_op_name), "_qt");
+ string new_assign_op_name =
+ strings::StrCat(graph->NewName(assign_op_name), "_qt");
string tensor_names_op_name =
strings::StrCat(new_restore_op_name, "/tensor_names");
string shape_and_slices_op_name =
diff --git a/tensorflow/core/graph/quantize_training_test.cc b/tensorflow/core/graph/quantize_training_test.cc
index 2ad69dbd0c..e46f92bc24 100644
--- a/tensorflow/core/graph/quantize_training_test.cc
+++ b/tensorflow/core/graph/quantize_training_test.cc
@@ -32,6 +32,7 @@ limitations under the License.
#include "tensorflow/core/graph/testlib.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/public/session.h"
@@ -215,7 +216,7 @@ TEST_F(QuantizeTrainingTest, WithBackwardNodes_QuantizeAndDequantize) {
Node* found_node;
Status s = FindNode(g, strings::StrCat(d->name(), "/QuantizeAndDequantizeV2"),
&found_node);
- EXPECT_TRUE(StringPiece(s.ToString()).contains("not found")) << s;
+ EXPECT_TRUE(str_util::StrContains(s.ToString(), "not found")) << s;
// Ensure that m1 and m2's inputs were quantized.
TF_ASSERT_OK(
@@ -269,7 +270,7 @@ TEST_F(QuantizeTrainingTest, WithBackwardNodes_FakeQuant) {
Node* found_node;
Status s = FindNode(g, strings::StrCat(d->name(), "/FakeQuantWithMinMaxVars"),
&found_node);
- EXPECT_TRUE(StringPiece(s.ToString()).contains("not found")) << s;
+ EXPECT_TRUE(str_util::StrContains(s.ToString(), "not found")) << s;
// Ensure that m1 and m2's inputs were quantized.
TF_ASSERT_OK(
diff --git a/tensorflow/core/graph/subgraph_test.cc b/tensorflow/core/graph/subgraph_test.cc
index 7219d9812f..6c014a8d44 100644
--- a/tensorflow/core/graph/subgraph_test.cc
+++ b/tensorflow/core/graph/subgraph_test.cc
@@ -312,8 +312,8 @@ TEST_F(SubgraphTest, ChainOfFools) {
EXPECT_TRUE(HasEdge("e", 0, "_send_e_0", 0));
}
-static bool HasSubstr(const string& base, const string& substr) {
- bool ok = StringPiece(base).contains(substr);
+static bool HasSubstr(StringPiece base, StringPiece substr) {
+ bool ok = str_util::StrContains(base, substr);
EXPECT_TRUE(ok) << base << ", expected substring " << substr;
return ok;
}
diff --git a/tensorflow/core/graph/tensor_id.cc b/tensorflow/core/graph/tensor_id.cc
index 089ea5e527..8af1936d64 100644
--- a/tensorflow/core/graph/tensor_id.cc
+++ b/tensorflow/core/graph/tensor_id.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <string>
#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/strings/str_util.h"
namespace tensorflow {
@@ -45,7 +46,7 @@ TensorId ParseTensorName(StringPiece name) {
if (p > base && *p == ':' && mul > 1) {
id.first = StringPiece(base, p - base);
id.second = index;
- } else if (name.starts_with("^")) {
+ } else if (str_util::StartsWith(name, "^")) {
// Control edge
id.first = StringPiece(base + 1);
id.second = Graph::kControlSlot;
diff --git a/tensorflow/core/graph/validate_test.cc b/tensorflow/core/graph/validate_test.cc
index cb6d107cad..d58cdc3c5b 100644
--- a/tensorflow/core/graph/validate_test.cc
+++ b/tensorflow/core/graph/validate_test.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
@@ -60,7 +61,7 @@ TEST(ValidateGraphDefTest, GraphWithUnspecifiedDefaultAttr) {
CHECK(parser.MergeFromString(graph_def_str, &graph_def)) << graph_def_str;
Status s = graph::ValidateGraphDef(graph_def, *OpRegistry::Global());
EXPECT_FALSE(s.ok());
- EXPECT_TRUE(StringPiece(s.ToString()).contains("NodeDef missing attr"));
+ EXPECT_TRUE(str_util::StrContains(s.ToString(), "NodeDef missing attr"));
// Add the defaults.
TF_ASSERT_OK(AddDefaultAttrsToGraphDef(&graph_def, *OpRegistry::Global(), 0));
@@ -83,7 +84,7 @@ TEST(ValidateGraphDefTest, GraphWithUnspecifiedRequiredAttr) {
CHECK(parser.MergeFromString(graph_def_str, &graph_def)) << graph_def_str;
Status s = graph::ValidateGraphDef(graph_def, *OpRegistry::Global());
EXPECT_FALSE(s.ok());
- EXPECT_TRUE(StringPiece(s.ToString()).contains("NodeDef missing attr"));
+ EXPECT_TRUE(str_util::StrContains(s.ToString(), "NodeDef missing attr"));
// Add the defaults.
TF_ASSERT_OK(AddDefaultAttrsToGraphDef(&graph_def, *OpRegistry::Global(), 0));
@@ -91,7 +92,7 @@ TEST(ValidateGraphDefTest, GraphWithUnspecifiedRequiredAttr) {
// Validation should still fail.
s = graph::ValidateGraphDef(graph_def, *OpRegistry::Global());
EXPECT_FALSE(s.ok());
- EXPECT_TRUE(StringPiece(s.ToString()).contains("NodeDef missing attr"));
+ EXPECT_TRUE(str_util::StrContains(s.ToString(), "NodeDef missing attr"));
}
TEST(ValidateGraphDefAgainstOpListTest, GraphWithOpOnlyInOpList) {
diff --git a/tensorflow/core/grappler/clusters/cluster.cc b/tensorflow/core/grappler/clusters/cluster.cc
index 39bfca244e..8d8c6084ec 100644
--- a/tensorflow/core/grappler/clusters/cluster.cc
+++ b/tensorflow/core/grappler/clusters/cluster.cc
@@ -62,6 +62,10 @@ void Cluster::DisableOptimizer(bool disable) {
options_.config.mutable_graph_options()->mutable_rewrite_options();
rewriter_config->set_layout_optimizer(RewriterConfig::OFF);
rewriter_config->set_disable_model_pruning(true);
+ rewriter_config->set_function_optimization(RewriterConfig::OFF);
+ rewriter_config->set_arithmetic_optimization(RewriterConfig::OFF);
+ rewriter_config->set_loop_optimization(RewriterConfig::OFF);
+ rewriter_config->set_dependency_optimization(RewriterConfig::OFF);
rewriter_config->set_constant_folding(RewriterConfig::OFF);
rewriter_config->set_memory_optimization(RewriterConfig::NO_MEM_OPT);
rewriter_config->mutable_auto_parallel()->set_enable(false);
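
With the four added setters, DisableOptimizer(true) now switches off every Grappler pass rather than a subset. The fully disabled configuration, collected into a hypothetical standalone helper for illustration:

#include "tensorflow/core/protobuf/rewriter_config.pb.h"

tensorflow::RewriterConfig AllRewritesOff() {  // hypothetical helper
  tensorflow::RewriterConfig cfg;
  cfg.set_layout_optimizer(tensorflow::RewriterConfig::OFF);
  cfg.set_disable_model_pruning(true);
  cfg.set_function_optimization(tensorflow::RewriterConfig::OFF);
  cfg.set_arithmetic_optimization(tensorflow::RewriterConfig::OFF);
  cfg.set_loop_optimization(tensorflow::RewriterConfig::OFF);
  cfg.set_dependency_optimization(tensorflow::RewriterConfig::OFF);
  cfg.set_constant_folding(tensorflow::RewriterConfig::OFF);
  cfg.set_memory_optimization(tensorflow::RewriterConfig::NO_MEM_OPT);
  cfg.mutable_auto_parallel()->set_enable(false);
  return cfg;
}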
diff --git a/tensorflow/core/grappler/clusters/utils.cc b/tensorflow/core/grappler/clusters/utils.cc
index b54b34959a..50d6e6468f 100644
--- a/tensorflow/core/grappler/clusters/utils.cc
+++ b/tensorflow/core/grappler/clusters/utils.cc
@@ -54,7 +54,7 @@ DeviceProperties GetLocalCPUInfo() {
int64 free_mem = port::AvailableRam();
if (free_mem < INT64_MAX) {
- device.set_memory_size(free_mem * 1024);
+ device.set_memory_size(free_mem);
}
(*device.mutable_environment())["cpu_instruction_set"] =
diff --git a/tensorflow/core/grappler/clusters/virtual_cluster.cc b/tensorflow/core/grappler/clusters/virtual_cluster.cc
index ae70c98608..abfa7bc48e 100644
--- a/tensorflow/core/grappler/clusters/virtual_cluster.cc
+++ b/tensorflow/core/grappler/clusters/virtual_cluster.cc
@@ -66,6 +66,7 @@ Status VirtualCluster::Run(const GraphDef& graph,
}
Costs node_costs;
+ int node_id = 0;
do {
OpContext op_context = scheduler.GetCurrNode();
node_costs = node_estimator_->PredictCosts(op_context);
@@ -73,6 +74,7 @@ Status VirtualCluster::Run(const GraphDef& graph,
CostGraphDef::Node* cost_node =
metadata->mutable_cost_graph()->add_node();
const string& op_name = op_context.name;
+ cost_node->set_id(node_id++);
cost_node->set_name(op_name);
cost_node->set_device(op_context.device_name);
cost_node->set_compute_cost(
diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc
index 5103098f27..8fe154dbf3 100644
--- a/tensorflow/core/grappler/costs/graph_properties.cc
+++ b/tensorflow/core/grappler/costs/graph_properties.cc
@@ -1011,6 +1011,7 @@ Status GraphProperties::InferStatically(bool assume_valid_feeds) {
}
// Skip any information that comes from fed nodes.
if (fed_ports.find(node->name()) != fed_ports.end()) {
+ VLOG(2) << "Skipping feed node shape: " << node->name();
continue;
}
for (const auto& merged_shapes : node_ctx->MergedShapes()) {
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
index 0f6307cfdf..14e46ecdd9 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
@@ -202,12 +202,9 @@ OpLevelCostEstimator::OpLevelCostEstimator() {
{kNoOp, wrap(&OpLevelCostEstimator::PredictNoOp)},
- // TODO(76227186): re-enable with output size check & test
- /*
{kGather, wrap(&OpLevelCostEstimator::PredictGatherOrSlice)},
{kGatherV2, wrap(&OpLevelCostEstimator::PredictGatherOrSlice)},
{kSlice, wrap(&OpLevelCostEstimator::PredictGatherOrSlice)},
- */
{kPlaceholder, wrap(&OpLevelCostEstimator::PredictIdentity)},
{kIdentity, wrap(&OpLevelCostEstimator::PredictIdentity)},
@@ -817,6 +814,7 @@ int64 OpLevelCostEstimator::CountConv2DBackpropInputOperations(
}
if (!shape_found) {
     // Set the minimum input size that's feasible.
+ input_shape.Clear();
for (int i = 0; i < 4; ++i) {
input_shape.add_dim()->set_size(1);
}
@@ -859,6 +857,7 @@ int64 OpLevelCostEstimator::CountConv2DBackpropFilterOperations(
}
if (!shape_found) {
// Set the minimum filter size that's feasible.
+ filter_shape.Clear();
for (int i = 0; i < 4; ++i) {
filter_shape.add_dim()->set_size(1);
}
@@ -1056,6 +1055,13 @@ Costs OpLevelCostEstimator::PredictGatherOrSlice(
   // part of it. For these ops the size of the output determines the memory cost.
const auto& op_info = op_context.op_info;
+ const int inputs_needed = op_info.op() == "Slice" ? 3 : 2;
+ if (op_info.outputs_size() == 0 || op_info.inputs_size() < inputs_needed) {
+ Costs costs = Costs::ZeroCosts();
+ costs.inaccurate = true;
+ return costs;
+ }
+
bool unknown_shapes = false;
// Each output element is a copy of some element from input.
@@ -1242,10 +1248,31 @@ Costs OpLevelCostEstimator::PredictAvgPoolGrad(
const OpContext& op_context) const {
bool found_unknown_shapes = false;
const auto& op_info = op_context.op_info;
- // x: op_info.inputs(0)
+ // x's shape: op_info.inputs(0)
// y_grad: op_info.inputs(1)
- ConvolutionDimensions dims = OpDimensionsFromInputs(
- op_info.inputs(0).shape(), op_info, &found_unknown_shapes);
+
+ // Extract x_shape from op_info.inputs(0).value() or op_info.outputs(0).
+ bool shape_found = false;
+ TensorShapeProto x_shape;
+ if (op_info.inputs_size() >= 1 && op_info.inputs(0).has_value()) {
+ const TensorProto& value = op_info.inputs(0).value();
+ shape_found = GetTensorShapeProtoFromTensorProto(value, &x_shape);
+ }
+ if (!shape_found && op_info.outputs_size() > 0) {
+ x_shape = op_info.outputs(0).shape();
+ shape_found = true;
+ }
+ if (!shape_found) {
+ // Set the minimum shape that's feasible.
+ x_shape.Clear();
+ for (int i = 0; i < 4; ++i) {
+ x_shape.add_dim()->set_size(1);
+ }
+ found_unknown_shapes = true;
+ }
+
+ ConvolutionDimensions dims =
+ OpDimensionsFromInputs(x_shape, op_info, &found_unknown_shapes);
int64 ops = 0;
if (dims.kx <= dims.sx && dims.ky <= dims.sy) {
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc
index 56915ed821..d797a8a8c1 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc
@@ -217,6 +217,39 @@ std::vector<int> GetPoolingOutputSize(const std::vector<int>& input,
return output;
}
+// Helper functions for testing GetTensorShapeProtoFromTensorProto().
+void GetTensorProto(const DataType dtype, const std::vector<int64>& shape,
+ const std::vector<int64> values, const bool tensor_content,
+ TensorProto* tensor_proto) {
+ tensor_proto->Clear();
+ TensorProto temp_tensor_proto;
+ temp_tensor_proto.set_dtype(dtype);
+ for (const auto& x : shape) {
+ temp_tensor_proto.mutable_tensor_shape()->add_dim()->set_size(x);
+ }
+ for (const auto& x : values) {
+ if (dtype == DT_INT64) {
+ temp_tensor_proto.add_int64_val(x);
+ } else if (dtype == DT_INT32 || dtype == DT_INT16 || dtype == DT_INT8 ||
+ dtype == DT_UINT8) {
+ temp_tensor_proto.add_int_val(x);
+ } else if (dtype == DT_UINT32) {
+ temp_tensor_proto.add_uint32_val(x);
+ } else if (dtype == DT_UINT64) {
+ temp_tensor_proto.add_uint64_val(x);
+ } else {
+ CHECK(false) << "Unsupported dtype: " << dtype;
+ }
+ }
+ Tensor tensor(dtype);
+ CHECK(tensor.FromProto(temp_tensor_proto));
+ if (tensor_content) {
+ tensor.AsProtoTensorContent(tensor_proto);
+ } else {
+ tensor.AsProtoField(tensor_proto);
+ }
+}
+
OpContext DescribePoolingOp(const string& op_name, const std::vector<int>& x,
const std::vector<int>& ksize,
const std::vector<int>& strides,
@@ -233,8 +266,11 @@ OpContext DescribePoolingOp(const string& op_name, const std::vector<int>& x,
DescribeTensor4D(x[0], x[1], x[2], x[3], op_info.add_inputs());
DescribeTensor4D(y[0], y[1], y[2], y[3], op_info.add_outputs());
} else if (op_name == "AvgPoolGrad") {
- // input: x, y_grad, output: x_grad.
- DescribeTensor4D(x[0], x[1], x[2], x[3], op_info.add_inputs());
+ // input: x's shape, y_grad, output: x_grad.
+ DescribeArbitraryRankInput({4}, DT_INT32, &op_info);
+ auto* tensor_proto = op_info.mutable_inputs(0)->mutable_value();
+ GetTensorProto(DT_INT32, {4}, {x[0], x[1], x[2], x[3]},
+ /*tensor_content=*/false, tensor_proto);
DescribeTensor4D(y[0], y[1], y[2], y[3], op_info.add_inputs());
DescribeTensor4D(x[0], x[1], x[2], x[3], op_info.add_outputs());
} else if (op_name == "MaxPoolGrad") {
@@ -365,43 +401,56 @@ class OpLevelCostEstimatorTest : public ::testing::Test {
OpLevelCostEstimator estimator_;
};
-// TODO(76227186): re-enable with output size check & test
-/*
TEST_F(OpLevelCostEstimatorTest, TestGatherCosts) {
-OpContext op_context;
-SetCpuDevice(&op_context.op_info);
-op_context.op_info.set_op("Gather");
+ OpContext op_context;
+ SetCpuDevice(&op_context.op_info);
+ op_context.op_info.set_op("Gather");
-// Huge first input shouldn't affect Gather execution and memory costs.
-DescribeArbitraryRankInput({10000000, 10}, DT_FLOAT, &op_context.op_info);
-DescribeArbitraryRankInput({16}, DT_INT64, &op_context.op_info);
-DescribeArbitraryRankOutput({16, 10}, DT_FLOAT, &op_context.op_info);
+ // Huge first input shouldn't affect Gather execution and memory costs.
+ DescribeArbitraryRankInput({10000000, 10}, DT_FLOAT, &op_context.op_info);
+ DescribeArbitraryRankInput({16}, DT_INT64, &op_context.op_info);
+ DescribeArbitraryRankOutput({16, 10}, DT_FLOAT, &op_context.op_info);
-auto cost = estimator_.PredictCosts(op_context);
-EXPECT_EQ(Costs::Duration(130), cost.memory_time);
-EXPECT_EQ(Costs::Duration(16), cost.compute_time);
-EXPECT_EQ(Costs::Duration(146), cost.execution_time);
-EXPECT_FALSE(cost.inaccurate);
+ auto cost = estimator_.PredictCosts(op_context);
+ EXPECT_EQ(Costs::Duration(130), cost.memory_time);
+ EXPECT_EQ(Costs::Duration(16), cost.compute_time);
+ EXPECT_EQ(Costs::Duration(146), cost.execution_time);
+ EXPECT_FALSE(cost.inaccurate);
}
-TEST_F(OpLevelCostEstimatorTest, TestSliceCosts) {
-OpContext op_context;
-SetCpuDevice(&op_context.op_info);
-op_context.op_info.set_op("Slice");
+TEST_F(OpLevelCostEstimatorTest, TestGatherCostsWithoutOutput) {
+ OpContext op_context;
+ SetCpuDevice(&op_context.op_info);
+ op_context.op_info.set_op("Gather");
-// Huge first input shouldn't affect Slice execution and memory costs.
-DescribeArbitraryRankInput({10000000, 10}, DT_FLOAT, &op_context.op_info);
-DescribeArbitraryRankInput({2}, DT_INT64, &op_context.op_info);
-DescribeArbitraryRankInput({2}, DT_INT64, &op_context.op_info);
-DescribeArbitraryRankOutput({10, 10}, DT_FLOAT, &op_context.op_info);
+ // Huge first input shouldn't affect Gather execution and memory costs.
+ DescribeArbitraryRankInput({10000000, 10}, DT_FLOAT, &op_context.op_info);
+ DescribeArbitraryRankInput({16}, DT_INT64, &op_context.op_info);
+
+ auto cost = estimator_.PredictCosts(op_context);
+ EXPECT_EQ(Costs::Duration(0), cost.memory_time);
+ EXPECT_EQ(Costs::Duration(0), cost.compute_time);
+ EXPECT_EQ(Costs::Duration(0), cost.execution_time);
+ EXPECT_TRUE(cost.inaccurate);
+}
-auto cost = estimator_.PredictCosts(op_context);
-EXPECT_EQ(Costs::Duration(81), cost.memory_time);
-EXPECT_EQ(Costs::Duration(10), cost.compute_time);
-EXPECT_EQ(Costs::Duration(91), cost.execution_time);
-EXPECT_FALSE(cost.inaccurate);
+TEST_F(OpLevelCostEstimatorTest, TestSliceCosts) {
+ OpContext op_context;
+ SetCpuDevice(&op_context.op_info);
+ op_context.op_info.set_op("Slice");
+
+ // Huge first input shouldn't affect Slice execution and memory costs.
+ DescribeArbitraryRankInput({10000000, 10}, DT_FLOAT, &op_context.op_info);
+ DescribeArbitraryRankInput({2}, DT_INT64, &op_context.op_info);
+ DescribeArbitraryRankInput({2}, DT_INT64, &op_context.op_info);
+ DescribeArbitraryRankOutput({10, 10}, DT_FLOAT, &op_context.op_info);
+
+ auto cost = estimator_.PredictCosts(op_context);
+ EXPECT_EQ(Costs::Duration(81), cost.memory_time);
+ EXPECT_EQ(Costs::Duration(10), cost.compute_time);
+ EXPECT_EQ(Costs::Duration(91), cost.execution_time);
+ EXPECT_FALSE(cost.inaccurate);
}
-*/
TEST_F(OpLevelCostEstimatorTest, BiasAddExecutionTime) {
auto cost = PredictCosts(DescribeBiasAdd(1000, 10));
@@ -510,39 +559,6 @@ TEST_F(OpLevelCostEstimatorTest, BatchMatMul) {
EXPECT_NE(matmul_inaccurate, batch_matmul_inaccurate);
}
-// Helper functions for testing GetTensorShapeProtoFromTensorProto().
-void GetTensorProto(const DataType dtype, const std::vector<int64>& shape,
- const std::vector<int64> values, const bool tensor_content,
- TensorProto* tensor_proto) {
- tensor_proto->Clear();
- TensorProto temp_tensor_proto;
- temp_tensor_proto.set_dtype(dtype);
- for (const auto& x : shape) {
- temp_tensor_proto.mutable_tensor_shape()->add_dim()->set_size(x);
- }
- for (const auto& x : values) {
- if (dtype == DT_INT64) {
- temp_tensor_proto.add_int64_val(x);
- } else if (dtype == DT_INT32 || dtype == DT_INT16 || dtype == DT_INT8 ||
- dtype == DT_UINT8) {
- temp_tensor_proto.add_int_val(x);
- } else if (dtype == DT_UINT32) {
- temp_tensor_proto.add_uint32_val(x);
- } else if (dtype == DT_UINT64) {
- temp_tensor_proto.add_uint64_val(x);
- } else {
- CHECK(false) << "Unsupported dtype: " << dtype;
- }
- }
- Tensor tensor(dtype);
- CHECK(tensor.FromProto(temp_tensor_proto));
- if (tensor_content) {
- tensor.AsProtoTensorContent(tensor_proto);
- } else {
- tensor.AsProtoField(tensor_proto);
- }
-}
-
void ExpectTensorShape(const std::vector<int64>& expected,
const TensorShapeProto& tensor_shape_proto) {
TensorShape tensor_shape_expected(expected);
@@ -746,25 +762,25 @@ TEST_F(OpLevelCostEstimatorTest, PredictAvgPoolGrad) {
{
     // Typical 3x3 window with 2x2 stride.
auto costs = predict_avg_pool_grad(10, 20, 384, 3, 2, "SAME");
- EXPECT_EQ(Costs::Duration(1920000), costs.execution_time);
+ EXPECT_EQ(Costs::Duration(1305602), costs.execution_time);
EXPECT_EQ(Costs::Duration(537600), costs.compute_time);
- EXPECT_EQ(Costs::Duration(1382400), costs.memory_time);
+ EXPECT_EQ(Costs::Duration(768002), costs.memory_time);
EXPECT_FALSE(costs.inaccurate);
}
{
// 1x1 window with 2x2 stride: used for shortcut in resnet-50.
auto costs = predict_avg_pool_grad(10, 20, 384, 1, 2, "SAME");
- EXPECT_EQ(Costs::Duration(1574400), costs.execution_time);
+ EXPECT_EQ(Costs::Duration(960002), costs.execution_time);
EXPECT_EQ(Costs::Duration(192000), costs.compute_time);
- EXPECT_EQ(Costs::Duration(1382400), costs.memory_time);
+ EXPECT_EQ(Costs::Duration(768002), costs.memory_time);
EXPECT_FALSE(costs.inaccurate);
}
{
// 2x2 window with 3x3 stride.
auto costs = predict_avg_pool_grad(10, 20, 384, 2, 3, "VALID");
- EXPECT_EQ(Costs::Duration(1476480), costs.execution_time);
+ EXPECT_EQ(Costs::Duration(862082), costs.execution_time);
EXPECT_EQ(Costs::Duration(172416), costs.compute_time);
- EXPECT_EQ(Costs::Duration(1304064), costs.memory_time);
+ EXPECT_EQ(Costs::Duration(689666), costs.memory_time);
EXPECT_FALSE(costs.inaccurate);
}
}
diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.cc b/tensorflow/core/grappler/costs/virtual_scheduler.cc
index 3ac3ae0f8f..0e5c654acf 100644
--- a/tensorflow/core/grappler/costs/virtual_scheduler.cc
+++ b/tensorflow/core/grappler/costs/virtual_scheduler.cc
@@ -44,6 +44,8 @@ Costs CombineCosts(const Costs& left, const Costs& right) {
Costs result = left;
result.execution_time += right.execution_time;
+ result.compute_time += right.compute_time;
+ result.memory_time += right.memory_time;
if (right.inaccurate) {
result.inaccurate = true;
}
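
Before this fix only execution_time accumulated; compute_time and memory_time silently kept the left operand's values, so the per-component numbers could drift from the total. A small sketch of the intended behavior (CombineCosts is file-local to virtual_scheduler.cc):

Costs a = Costs::ZeroCosts();
a.execution_time = Costs::Duration(100);
a.compute_time = Costs::Duration(60);
a.memory_time = Costs::Duration(40);

Costs b = Costs::ZeroCosts();
b.execution_time = Costs::Duration(11);
b.compute_time = Costs::Duration(4);
b.memory_time = Costs::Duration(7);

Costs total = CombineCosts(a, b);
// total.execution_time == 111, total.compute_time == 64,
// total.memory_time == 47. Previously the last two stayed at 60 and 40.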
@@ -841,6 +843,8 @@ bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) {
Costs VirtualScheduler::Summary() const {
// Print out basic execution summary.
VLOG(1) << "Expected execution time: " << graph_costs_.execution_time.count();
+ VLOG(1) << "Expected compute time: " << graph_costs_.compute_time.count();
+ VLOG(1) << "Expected memory time: " << graph_costs_.memory_time.count();
VLOG(1) << "Expected max memory: " << graph_costs_.max_memory;
VLOG(1) << "Expected max per-op buffers: " << graph_costs_.max_per_op_buffers;
VLOG(1) << "Expected max per-op streaming buffers: "
diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc
index c31ac9b59c..a24d2dbd9f 100644
--- a/tensorflow/core/grappler/op_types.cc
+++ b/tensorflow/core/grappler/op_types.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/strings/str_util.h"
namespace tensorflow {
namespace grappler {
@@ -68,6 +69,10 @@ bool IsBitcast(const NodeDef& node) { return node.op() == "Bitcast"; }
bool IsCast(const NodeDef& node) { return node.op() == "Cast"; }
+bool IsCheckNumerics(const NodeDef& node) {
+ return node.op() == "CheckNumerics";
+}
+
bool IsComplex(const NodeDef& node) { return node.op() == "Complex"; }
bool IsComplexAbs(const NodeDef& node) { return node.op() == "ComplexAbs"; }
@@ -360,6 +365,8 @@ bool IsTruncateDiv(const NodeDef& node) { return node.op() == "TruncateDiv"; }
bool IsTruncateMod(const NodeDef& node) { return node.op() == "TruncateMod"; }
+bool IsUnpack(const NodeDef& node) { return node.op() == "Unpack"; }
+
bool IsVariable(const NodeDef& node) {
const auto& op = node.op();
return op == "Variable" || op == "VariableV2" || op == "AutoReloadVariable" ||
@@ -404,8 +411,18 @@ bool IsFreeOfSideEffect(const NodeDef& node) {
bool ModifiesInputsInPlace(const NodeDef& node) {
// Some nodes do in-place updates on regular tensor inputs.
string op_name = node.op();
+
+  // Ops that update resource variables modify state through a resource handle
+  // rather than a regular tensor input, so they are not in-place updates here.
+ if (op_name == "AssignVariableOp" || op_name == "AssignAddVariableOp" ||
+ op_name == "AssignSubVariableOp" || op_name == "ResourceScatterUpdate" ||
+ op_name == "ResourceScatterAdd" || op_name == "ResourceScatterSub" ||
+ op_name == "ResourceScatterMul" || op_name == "ResourceScatterDiv" ||
+ op_name == "ResourceScatterMin" || op_name == "ResourceScatterMax") {
+ return false;
+ }
+
std::transform(op_name.begin(), op_name.end(), op_name.begin(), ::tolower);
- if (StringPiece(op_name).contains("inplace")) {
+ if (str_util::StrContains(op_name, "inplace")) {
return true;
}
return GetBoolAttr(node, "in_place") || GetBoolAttr(node, "inplace");
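
The effect of the new early-out, sketched as test-style expectations (the resource op names are the ones listed above; "InplaceUpdateHypothetical" is an invented name used only because its lowercase form contains "inplace"):

#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/platform/test.h"

using tensorflow::NodeDef;
using tensorflow::grappler::ModifiesInputsInPlace;

TEST(OpTypesTest, InPlaceSketch) {
  NodeDef assign;
  assign.set_op("AssignAddVariableOp");
  EXPECT_FALSE(ModifiesInputsInPlace(assign));  // resource update: early-out

  NodeDef scatter;
  scatter.set_op("ResourceScatterAdd");
  EXPECT_FALSE(ModifiesInputsInPlace(scatter));

  NodeDef inplace;
  inplace.set_op("InplaceUpdateHypothetical");  // invented op name
  EXPECT_TRUE(ModifiesInputsInPlace(inplace));  // lowercased name has "inplace"
}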
diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h
index 39affcbc24..8667f72c7e 100644
--- a/tensorflow/core/grappler/op_types.h
+++ b/tensorflow/core/grappler/op_types.h
@@ -37,6 +37,7 @@ bool IsBiasAdd(const NodeDef& node);
bool IsBiasAddGrad(const NodeDef& node);
bool IsBitcast(const NodeDef& node);
bool IsCast(const NodeDef& node);
+bool IsCheckNumerics(const NodeDef& node);
bool IsComplex(const NodeDef& node);
bool IsComplexAbs(const NodeDef& node);
bool IsConj(const NodeDef& node);
@@ -139,6 +140,7 @@ bool IsTile(const NodeDef& node);
bool IsTranspose(const NodeDef& node);
bool IsTruncateDiv(const NodeDef& node);
bool IsTruncateMod(const NodeDef& node);
+bool IsUnpack(const NodeDef& node);
bool IsVariable(const NodeDef& node);
bool IsZeta(const NodeDef& node);
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index 2c365c467c..122fd48584 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -251,6 +251,7 @@ cc_library(
":constant_folding",
":graph_optimizer",
":graph_optimizer_stage",
+ ":symbolic_shapes",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
@@ -260,6 +261,7 @@ cc_library(
"//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/costs:graph_properties",
"//tensorflow/core/grappler/utils:frame",
+ "//tensorflow/core/grappler/utils:topological_sort",
],
)
@@ -272,6 +274,11 @@ tf_cuda_cc_test(
":constant_folding",
":model_pruner",
"//tensorflow/cc:cc_ops",
+ "//tensorflow/cc:cc_ops_internal",
+ "//tensorflow/core:all_kernels",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
@@ -501,6 +508,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler/utils:colocation",
"//tensorflow/core/grappler/utils:topological_sort",
],
)
@@ -630,6 +638,7 @@ cc_library(
tf_cuda_cc_test(
name = "debug_stripper_test",
+ size = "small",
srcs = ["debug_stripper_test.cc"],
deps = [
":debug_stripper",
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index d155e0b289..59a5695af0 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/core/grappler/optimizers/arithmetic_optimizer.h"
#include <algorithm>
+#include <deque>
#include <limits>
#include <unordered_map>
#include <unordered_set>
@@ -31,8 +32,9 @@ limitations under the License.
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/constant_folding.h"
#include "tensorflow/core/grappler/optimizers/graph_optimizer_stage.h"
+#include "tensorflow/core/grappler/optimizers/symbolic_shapes.h"
#include "tensorflow/core/grappler/utils.h"
-#include "tensorflow/core/grappler/utils/frame.h"
+#include "tensorflow/core/grappler/utils/topological_sort.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/strings/str_util.h"
@@ -197,39 +199,6 @@ void SetSourceDataType(DataType dtype, NodeDef* node) {
bool IsNumberType(DataType dtype) { return kNumberTypes.Contains(dtype); }
-// Shape is symbolically defined if it has a known rank, and each dimension is
-// defined, or is an unknown symbol (dim.size <= -2).
-bool ShapeIsSymbolicallyDefined(const TensorShapeProto& shape) {
- return !shape.unknown_rank() &&
- std::all_of(
- shape.dim().begin(), shape.dim().end(),
- [](const TensorShapeProto::Dim& dim) { return dim.size() != -1; });
-}
-
-bool ShapeIsSymbolicallyDefined(const OpInfo::TensorProperties& properties) {
- return ShapeIsSymbolicallyDefined(properties.shape());
-}
-
-bool ShapesSymbolicallyEqual(const TensorShapeProto& left,
- const TensorShapeProto& right) {
- if (left.unknown_rank() || right.unknown_rank() ||
- left.dim_size() != right.dim_size()) {
- return false;
- }
- for (int i = 0; i < left.dim_size(); ++i) {
- if (left.dim(i).size() == -1 || right.dim(i).size() == -1 ||
- left.dim(i).size() != right.dim(i).size()) {
- return false;
- }
- }
- return true;
-}
-
-bool ShapesSymbolicallyEqual(const OpInfo::TensorProperties& left,
- const OpInfo::TensorProperties& right) {
- return ShapesSymbolicallyEqual(left.shape(), right.shape());
-}
-
// Returns whether `reshape` is an identity op. The tensor that `reshape`
// reshapes is the `output_pos`-th output of node `input`.
bool ReshapeIsIdentity(const NodeDef& reshape, const NodeDef& input,
@@ -320,21 +289,16 @@ class ArithmeticOptimizerStage : public GraphOptimizerStage<string> {
// TODO(ezhulenev): remove this method from ArithmeticOptimizer when all
// optimizations will be migrated to stages
- void AddFrameControlDeps(const NodeDef* old_node,
- const std::vector<NodeDef*>& new_nodes,
- const string& source_for_ctrl_dep,
- const std::vector<NodeDef*>& sinks_for_control_dep) {
- const auto frame_it = ctx_.frame_map->find(old_node);
- if (frame_it != ctx_.frame_map->end()) {
- for (auto node : new_nodes) {
- ctx_.frame_map->emplace(node, frame_it->second);
- }
- if (!source_for_ctrl_dep.empty() && !sinks_for_control_dep.empty()) {
- const string ctrl_dep = ConstantFolding::AddControlDependency(
- source_for_ctrl_dep, ctx_.optimized_graph, ctx_.node_map);
- for (auto node : sinks_for_control_dep) {
- MaybeAddControlInput(ctrl_dep, node, ctx_.optimized_graph,
- ctx_.node_map);
+ void ForwardControlDependencies(
+ NodeDef* target_node, const std::vector<const NodeDef*>& src_nodes) {
+ for (const auto& src : src_nodes) {
+ for (int i = src->input_size() - 1; i >= 0; --i) {
+ if (IsControlInput(src->input(i))) {
+ *target_node->add_input() = src->input(i);
+ ctx_.node_map->AddOutput(NodeName(src->input(i)),
+ target_node->name());
+ } else {
+ break;
}
}
}
@@ -348,17 +312,30 @@ class ArithmeticOptimizerStage : public GraphOptimizerStage<string> {
// Rewrite a tree of Add/AddN with a single AddN operation, consuming all the
// original inputs of absorbed nodes.
//
-// All nodes in a Add/AddN subgraph must have symbolically equal shape. All
-// nodes must have the same device placement.
+// 1) All nodes must have the same device placement.
+//
+// 2) If all nodes in an Add/AddN subgraph have symbolically equal shapes, the
+//    tree is optimized to a single AddN node.
//
-// Example:
// AddN_1
// / | \
-// Add_1 z Add_2 -> AddN(z, y, z, w, q, e)
+// Add_1 z Add_2 -> AddN(x, y, z, w, q, e)
// / \ / \
// x y w Add_3
// / \
// q e
+//
+// 3) If some nodes have a different shape (it must be broadcastable to the
+//    shape of the "root"), the tree is optimized to AddNs for symbolically
+//    equal shapes plus a tree of Add ops that minimizes broadcasts.
+//
+// AddN_1 Add
+// / | \ / \
+// Add_1 z Add_2 -> Add w
+// / \ / \ / \
+// x y w Add_3 AddN(x, y, q, e) z
+// / \
+// q e
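+//
+// For example, if x, y, q and e all have shape [32, 32] while w is [32, 1]
+// and z is [1], the four equal-shape inputs collapse into one
+// AddN(x, y, q, e), and each broadcast then happens exactly once in the
+// small Add tree built on top of it.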
class AddOpsRewriteStage : public ArithmeticOptimizerStage {
public:
explicit AddOpsRewriteStage(const GraphOptimizerContext& ctx,
@@ -379,7 +356,7 @@ class AddOpsRewriteStage : public ArithmeticOptimizerStage {
OpInfo::TensorProperties properties;
Status has_properties = GetTensorProperties(node->name(), &properties);
return has_properties.ok() && ShapeIsSymbolicallyDefined(properties) &&
- HasAllInputsOfSymbolicallyEqualShape(*node, properties);
+ HasAllInputsOfBroadcastableShape(*node, properties);
}
Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
@@ -387,7 +364,7 @@ class AddOpsRewriteStage : public ArithmeticOptimizerStage {
AddOpsGroup group;
TF_RETURN_IF_ERROR(CreateAddOpsGroup(node, &group));
- if (!group.absorbed_nodes.empty() && !IsRewritten(group)) {
+ if (!group.absorbed_nodes.empty()) {
*simplified_node_name = RewriteAddOpsGroup(group);
}
@@ -395,6 +372,14 @@ class AddOpsRewriteStage : public ArithmeticOptimizerStage {
}
private:
+ // Input name with a statically inferred shape from GraphProperties
+ struct InputAndShape {
+ InputAndShape(const string& input, const TensorShapeProto& shape)
+ : input(input), shape(shape) {}
+ string input;
+ TensorShapeProto shape;
+ };
+
// Holds together an add ops subgraph that we want to rewrite together.
//
// For the graph above the AddOpsGroup will be:
@@ -406,12 +391,12 @@ class AddOpsRewriteStage : public ArithmeticOptimizerStage {
TensorShapeProto root_shape;
// Add/AddN operations below the root level that were absorbed by this group
std::vector<NodeDef*> absorbed_nodes;
- // Inputs of absorbed nodes that will be forwarded to rewritten AddN node
- std::vector<string> inputs;
+ // Inputs of absorbed nodes that will be forwarded to optimized AddN ops
+ std::vector<InputAndShape> inputs;
};
- // Check if all inputs have symbolically equal shapes
- bool HasAllInputsOfSymbolicallyEqualShape(
+ // Check if all inputs can be broadcasted to the same shape
+ bool HasAllInputsOfBroadcastableShape(
const NodeDef& node, const OpInfo::TensorProperties& properties) const {
const AddOpsRewriteStage* self = this;
return std::all_of(
@@ -421,7 +406,7 @@ class AddOpsRewriteStage : public ArithmeticOptimizerStage {
Status has_input_properties =
self->GetTensorProperties(input, &input_properties);
return has_input_properties.ok() &&
- ShapesSymbolicallyEqual(properties, input_properties);
+ ShapesBroadcastable(properties, input_properties);
});
}
@@ -467,11 +452,11 @@ class AddOpsRewriteStage : public ArithmeticOptimizerStage {
if (node->device() != group.root_node->device()) {
return false;
}
- // All input shapes must be symbolically defined and equal to the node shape
+ // All input shapes must be broadcastable to the node shape
OpInfo::TensorProperties properties;
Status has_properties = GetTensorProperties(name, &properties);
return has_properties.ok() &&
- HasAllInputsOfSymbolicallyEqualShape(*node, properties);
+ HasAllInputsOfBroadcastableShape(*node, properties);
}
// Node requirements both for a root node and an absorbed node
@@ -490,18 +475,16 @@ class AddOpsRewriteStage : public ArithmeticOptimizerStage {
if (rewritten_nodes_.find(node->name()) != rewritten_nodes_.end()) {
return false;
}
+    // it must not have been created by this stage during any previous
+    // optimization run
+    if (str_util::StrContains(node->name(), stage_name_)) {
+ return false;
+ }
// should not drive or be driven by control dependency
// TODO(ezhulenev): relax this condition for root node
return !(IsDrivenByControlDependency(*node) ||
DrivesControlDependency(*node));
}
- // Check that optimized group node name doesn't exists. It might happen if
- // graph optimized multiple times without pruning between invocations.
- bool IsRewritten(const AddOpsGroup& group) const {
- return ctx_.node_map->NodeExists(AddOpsGroupName(group));
- }
-
// Create an AddOpsGroup with a root in a given node
Status CreateAddOpsGroup(const NodeDef* root_node, AddOpsGroup* group) {
OpInfo::TensorProperties root_node_output_properties;
@@ -513,7 +496,10 @@ class AddOpsRewriteStage : public ArithmeticOptimizerStage {
group->absorbed_nodes.reserve(root_node->input_size());
for (int i = 0; i < root_node->input_size(); ++i) {
- TF_RETURN_IF_ERROR(AbsorbInputByAddOpsGroup(root_node->input(i), group));
+ const string& input_i = root_node->input(i);
+ if (!IsControlInput(input_i)) {
+ TF_RETURN_IF_ERROR(AbsorbInputByAddOpsGroup(input_i, group));
+ }
}
return Status::OK();
@@ -526,71 +512,159 @@ class AddOpsRewriteStage : public ArithmeticOptimizerStage {
if (IsAbsorbableByAddOpsGroup(input, *group)) {
group->absorbed_nodes.push_back(node);
for (int i = 0; i < node->input_size(); ++i) {
- TF_RETURN_IF_ERROR(AbsorbInputByAddOpsGroup(node->input(i), group));
+ const string& input_i = node->input(i);
+        if (!IsControlInput(input_i)) {
+ TF_RETURN_IF_ERROR(AbsorbInputByAddOpsGroup(input_i, group));
+ }
}
} else {
// If node can't be absorbed, add it to AddOpsGroup input
- group->inputs.push_back(input);
+ OpInfo::TensorProperties properties;
+ TF_RETURN_IF_ERROR(GetTensorProperties(input, &properties));
+ group->inputs.emplace_back(input, properties.shape());
}
return Status::OK();
}
- // New node for AddOpsGroup is added to the same scope as a root_node. All
- // absorbed nodes are stripped of their scope, and only names are used in a
- // new node name.
- //
- // Example: AddOpsGroup(root="a/b/c/Add_2", absorbed=["d/Add_1", "e/Add"])
- // node_name="a/b/c/AddOpsGroup_Add_2_Add_1_Add
- string AddOpsGroupName(const AddOpsGroup& group) const {
- CHECK_NOTNULL(group.root_node);
-
- auto root = ParseNodeScopeAndName(group.root_node->name());
+ // Rewrite an add ops group into a single AddN if all input shapes are
+ // symbolically equal. If not, create AddN for equal shapes first, and then
+ // build an Add tree, minimizing the cost of broadcasts.
+ string RewriteAddOpsGroup(const AddOpsGroup& group) {
+ // all new nodes will be placed under the scope of a root node
+ auto root_scope_and_name = ParseNodeScopeAndName(group.root_node->name());
+
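+    // Builds a rank/dims signature, e.g. a [32, 1] shape maps to "r:2:d:32:1".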
+ auto shape_sig = [](const TensorShapeProto& shape) {
+ string name = strings::StrCat("r:", shape.dim_size(), ":d");
+ for (int i = 0; i < shape.dim_size(); ++i)
+ strings::StrAppend(&name, ":", shape.dim(i).size());
+ return name;
+ };
+
+ // Find what shapes are present in the inputs of absorbed nodes
+ std::unordered_map<string, std::vector<InputAndShape>> shape_sig_to_inputs;
+ for (const auto& input : group.inputs) {
+ shape_sig_to_inputs[shape_sig(input.shape)].push_back(input);
+ }
- std::vector<string> absorbed_node_names(group.absorbed_nodes.size());
- std::transform(group.absorbed_nodes.begin(), group.absorbed_nodes.end(),
- absorbed_node_names.begin(),
- [](const NodeDef* node) { return node->name(); });
+ // Collect all the shapes from representative elements
+ std::vector<TensorShapeProto> shapes;
+ shapes.reserve(shape_sig_to_inputs.size());
+ for (const auto& el : shape_sig_to_inputs)
+ shapes.push_back(el.second[0].shape);
+
+ // If all inputs have the same shape, rewrite whole group with a single AddN
+ if (shapes.size() == 1) {
+ string node_name = OptimizedNodeName(root_scope_and_name);
+ AddInputsOfSymbolicallyEqualShape(*group.root_node, node_name,
+ group.inputs);
+ // keep track of nodes that were created or absorbed as a part of rewrite
+ rewritten_nodes_.insert(node_name);
+ return node_name;
+ }
- return OptimizedNodeName(root, absorbed_node_names);
- }
+ // For inputs of different shapes:
+ // 1. Rewrite inputs of the same shape using AddN (leaf nodes)
+ // 2. Build a tree of Add nodes, minimizing cost of broadcast
+ std::sort(shapes.begin(), shapes.end(),
+ [](const TensorShapeProto& left, const TensorShapeProto& right) {
+ return CompareSymbolicallyShapedTensorSizes(left, right);
+ });
+
+ // optimized name for leaf AddN nodes
+ auto leaf_node_name = [&root_scope_and_name, this](int i) {
+ return OptimizedNodeName(root_scope_and_name,
+ strings::StrCat("Leaf_", i));
+ };
+ // optimized name for internal nodes of a tree built up from AddN leaves
+ auto internal_node_name = [&root_scope_and_name, this](int i) {
+ return OptimizedNodeName(root_scope_and_name,
+ strings::StrCat("Internal_", i));
+ };
+
+ // Add/AddN nodes that must be added to the tree
+ std::deque<InputAndShape> add_ops;
+
+ // Prepare leaf AddN nodes for inputs of equal shape
+ for (int i = 0; i < shapes.size(); ++i) {
+ const auto node_name = leaf_node_name(i);
+ const auto& inputs = shape_sig_to_inputs[shape_sig(shapes[i])];
+ add_ops.push_back(AddInputsOfSymbolicallyEqualShape(*group.root_node,
+ node_name, inputs));
+ }
- // Create a new node for a AddOpsGroup and return it's name.
- string RewriteAddOpsGroup(const AddOpsGroup& group) {
- CHECK_GT(group.absorbed_nodes.size(), 0)
- << "AddOpsGroup must have non empty absorbed nodes";
+ // Build up a tree of Add ops
+ int internal_nodes = 0;
+ do {
+ const InputAndShape lhs = add_ops.front();
+ add_ops.pop_front();
+ const InputAndShape rhs = add_ops.front();
+ add_ops.pop_front();
+ string name = add_ops.empty() ? OptimizedNodeName(root_scope_and_name)
+ : internal_node_name(internal_nodes++);
+ InputAndShape add = AddAggregatedInputs(*group.root_node, name, lhs, rhs);
+ add_ops.push_front(add);
+ } while (add_ops.size() > 1);
+
+ InputAndShape optimized_root_node = add_ops.front();
+ return optimized_root_node.input;
+ }
+
+ // Add 'AddN' node to aggregate inputs of symbolically equal shape
+ InputAndShape AddInputsOfSymbolicallyEqualShape(
+ const NodeDef& root_node, const string& node_name,
+ const std::vector<InputAndShape>& inputs) {
+ CHECK(!inputs.empty()) << "Inputs must be non-empty";
+
+ // Do not create redundant AddN nodes
+ if (inputs.size() == 1) {
+ return inputs[0];
+ }
- // name for a new node constructed from AddOpsGroup
- string node_name = AddOpsGroupName(group);
+ // get shape from representative element
+ auto shape = inputs[0].shape;
// copy attributes from a root node
- DataType dtype = group.root_node->attr().at("T").type();
+ DataType dtype = root_node.attr().at("T").type();
// add new AddN node
- NodeDef* added_node = AddEmptyNode(node_name);
- added_node->set_op("AddN");
- added_node->set_device(group.root_node->device());
- (*added_node->mutable_attr())["T"].set_type(dtype);
- (*added_node->mutable_attr())["N"].set_i(group.inputs.size());
-
- // all inputs of absorbed nodes are added to the new node
- for (const string& input : group.inputs) {
- ctx_.node_map->AddOutput(input, node_name);
- added_node->add_input(input);
+ NodeDef* node = AddEmptyNode(node_name);
+ node->set_op("AddN");
+ node->set_device(root_node.device());
+ (*node->mutable_attr())["T"].set_type(dtype);
+ (*node->mutable_attr())["N"].set_i(inputs.size());
+
+ for (const auto& inputAndShape : inputs) {
+ ctx_.node_map->AddOutput(inputAndShape.input, node_name);
+ node->add_input(inputAndShape.input);
}
- // Add frame dependencies that the original node might have had.
- AddFrameControlDeps(group.root_node, {added_node}, "", {});
+ rewritten_nodes_.insert(node_name);
+ return InputAndShape(node_name, shape);
+ }
+
+ // Add a single 'Add' node to sum two inputs
+ InputAndShape AddAggregatedInputs(const NodeDef& root_node,
+ const string& node_name,
+ const InputAndShape& left,
+ const InputAndShape& right) {
+ // copy attributes from a root node
+ DataType dtype = root_node.attr().at("T").type();
- VLOG(1) << "Absorbed " << group.absorbed_nodes.size()
- << " Add/AddN nodes from the graph";
+ // add new Add node
+ NodeDef* node = AddEmptyNode(node_name);
+ node->set_op("Add");
+ node->set_device(root_node.device());
+ (*node->mutable_attr())["T"].set_type(dtype);
- // keep track of nodes that were created or absorbed as a part of rewrite
- rewritten_nodes_.insert(node_name);
- for (const NodeDef* absorbed : group.absorbed_nodes) {
- rewritten_nodes_.insert(absorbed->name());
- }
+ ctx_.node_map->AddOutput(left.input, node_name);
+ ctx_.node_map->AddOutput(right.input, node_name);
+
+ node->add_input(left.input);
+ node->add_input(right.input);
- return node_name;
+ rewritten_nodes_.insert(node_name);
+ return InputAndShape(
+ node_name, TensorShapeProto()); // shape is not important at this point
}
// keep nodes that were added or absorbed as a part of AddOpsGroup rewrite
@@ -623,7 +697,8 @@ class HoistCommonFactorOutOfAggregation : public ArithmeticOptimizerStage {
CHECK(IsSupported(node));
std::set<string> common_factors;
- TF_RETURN_IF_ERROR(GetCommonFactors(node, &common_factors));
+ std::vector<string> ctrl_deps;
+ TF_RETURN_IF_ERROR(GetCommonFactors(node, &common_factors, &ctrl_deps));
if (common_factors.size() == 1) {
const string& common_factor = *common_factors.begin();
@@ -655,9 +730,11 @@ class HoistCommonFactorOutOfAggregation : public ArithmeticOptimizerStage {
new_add_node->set_input(i, unique_factors[i]);
}
- // Add frame dependencies that the original node might have had.
- AddFrameControlDeps(node, {new_add_node, new_mul_node}, common_factor,
- {new_add_node});
+ // Add control deps on add node
+ for (const string& ctrl_dep : ctrl_deps) {
+ *new_add_node->add_input() = ctrl_dep;
+ ctx_.node_map->AddOutput(NodeName(ctrl_dep), new_add_node->name());
+ }
// optimize new inner aggregation node
AddToOptimizationQueue(new_add_node);
@@ -683,14 +760,16 @@ class HoistCommonFactorOutOfAggregation : public ArithmeticOptimizerStage {
}
// Determine the set of common factors if the input nodes are all Mul nodes.
- Status GetCommonFactors(const NodeDef* node,
- std::set<string>* common_factors) const {
+ Status GetCommonFactors(const NodeDef* node, std::set<string>* common_factors,
+ std::vector<string>* ctrl_deps) const {
CHECK(common_factors->empty());
for (int i = 0; i < node->input_size(); ++i) {
if (i > 0 && common_factors->empty()) break;
- if (IsControlInput(node->input(i))) break;
-
+ if (IsControlInput(node->input(i))) {
+ ctrl_deps->push_back(node->input(i));
+ continue;
+ }
NodeDef* input;
TF_RETURN_IF_ERROR(GetInputNode(node->input(i), &input));
@@ -710,6 +789,9 @@ class HoistCommonFactorOutOfAggregation : public ArithmeticOptimizerStage {
std::inserter(intersection, intersection.begin()));
std::swap(*common_factors, intersection);
}
+ for (int i = 2; i < input->input_size(); ++i) {
+ ctrl_deps->push_back(input->input(i));
+ }
}
return Status::OK();
}
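
What the stage does, with the new control-dependency bookkeeping: control inputs found on the aggregate node or on its Mul inputs are collected into ctrl_deps and re-attached to the inner aggregation node. A schematic sketch using the C++ op wrappers (node names are illustrative):

#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/standard_ops.h"

using namespace tensorflow;

void BuildHoistExample() {
  Scope s = Scope::NewRootScope();
  Output c = ops::Const(s.WithOpName("c"), 2.0f, {2, 2});
  Output x = ops::Const(s.WithOpName("x"), 1.0f, {2, 2});
  Output y = ops::Const(s.WithOpName("y"), 3.0f, {2, 2});
  Output mul1 = ops::Mul(s.WithOpName("mul1"), c, x);
  Output mul2 = ops::Mul(s.WithOpName("mul2"), c, y);
  Output add = ops::AddN(s.WithOpName("add"), {mul1, mul2});
  // After the stage runs: "add" becomes Mul(c, add_inner) with
  // add_inner = AddN(x, y), and any "^ctrl" inputs of the original
  // nodes are moved onto add_inner so the dependency still fires.
}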
@@ -1195,20 +1277,15 @@ void ArithmeticOptimizer::DedupComputations() {
}
}
-void ArithmeticOptimizer::AddFrameControlDeps(
- const NodeDef* old_node, const std::vector<NodeDef*>& new_nodes,
- const string& source_for_ctrl_dep,
- const std::vector<NodeDef*>& sinks_for_control_dep) {
- const auto frame_it = frame_map_.find(old_node);
- if (frame_it != frame_map_.end()) {
- for (auto node : new_nodes) {
- frame_map_.emplace(node, frame_it->second);
- }
- if (!source_for_ctrl_dep.empty() && !sinks_for_control_dep.empty()) {
- const string ctrl_dep = ConstantFolding::AddControlDependency(
- source_for_ctrl_dep, optimized_graph_, node_map_.get());
- for (auto node : sinks_for_control_dep) {
- MaybeAddControlInput(ctrl_dep, node, optimized_graph_, node_map_.get());
+void ArithmeticOptimizer::ForwardControlDependencies(
+ NodeDef* target_node, const std::vector<const NodeDef*>& src_nodes) {
+ for (const auto& src : src_nodes) {
+ for (int i = src->input_size() - 1; i >= 0; --i) {
+ if (IsControlInput(src->input(i))) {
+ *target_node->add_input() = src->input(i);
+ node_map_->AddOutput(NodeName(src->input(i)), target_node->name());
+ } else {
+ break;
}
}
}
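
Control inputs always trail the data inputs in a NodeDef, which is why both copies of this helper walk the list backwards and stop at the first non-control entry. A standalone restatement of the loop (sketch; the real method also records each edge in the optimizer's NodeMap):

#include <string>
#include "tensorflow/core/framework/node_def.pb.h"

void CopyTrailingControlInputs(const tensorflow::NodeDef& src,
                               tensorflow::NodeDef* target) {
  for (int i = src.input_size() - 1; i >= 0; --i) {
    const std::string& in = src.input(i);
    if (in.empty() || in[0] != '^') break;  // first data input ends the run
    target->add_input(in);
  }
}

// With src.input() == {"x", "^a", "^b"} and target.input() == {"y"}
// (hypothetical names), the call leaves target.input() == {"y", "^b", "^a"}:
// a tail-first copy.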
@@ -1264,19 +1341,18 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
int output_pos = 0;
string input_node_name = ParseNodeName(node->input(0), &output_pos);
const NodeDef* input = node_map_->GetNode(input_node_name);
- if (input->op() == "Reshape") {
+ if (input->op() == "Reshape" && !HasControlInputs(*input)) {
reshape->set_input(0, input->input(0));
node_map_->UpdateInput(reshape->name(), input->name(), input->input(0));
nodes_to_simplify->PushBack(reshape);
return reshape->name();
}
- // If the reshape is a no-op, forward its input to its consumers. This is
- // considered aggressive, because users may state that the placeholder
- // outputs tensors of shape [M, N] while feeding it with tensors of shape
- // [M*N] (or worse). The reshape nodes are then necessary to update the
- // tensor metadata to the required shape.
- if (ReshapeIsIdentity(*reshape, *input, output_pos, *graph_properties_)) {
+ // If the reshape is a no-op, forward its input to its consumers, unless it
+ // anchors a control dependency since we want to make sure that control
+ // dependency is triggered.
+ if (ReshapeIsIdentity(*reshape, *input, output_pos, *graph_properties_) &&
+ !HasControlInputs(*reshape)) {
return reshape->input(0);
}
}
@@ -1329,10 +1405,7 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
node_map_->AddOutput(new_transpose->name(), new_cast->name());
nodes_to_simplify->PushBack(new_transpose);
- // Add frame dependencies that the original node might have had.
- AddFrameControlDeps(node, {new_transpose, new_cast},
- new_transpose->input(0), {new_transpose});
-
+ ForwardControlDependencies(new_transpose, {cast, node});
return new_cast->name();
}
}
@@ -1406,7 +1479,7 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
node_map_->AddOutput(weights->name(), scaled_weights->name());
scaled_weights->add_input(mul->input(1));
node_map_->AddOutput(scale->name(), scaled_weights->name());
- AddFrameControlDeps(node, {scaled_weights}, "", {});
+ ForwardControlDependencies(scaled_weights, {source});
// Update `conv`'s weights to `scaled_weights`.
conv->set_input(1, scaled_weights->name());
@@ -1442,7 +1515,7 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
}
if (IsAggregate(*node) && NumNonControlInputs(*node) > 0) {
- // Discard aggregate nodes with a single input.
+ // Discard aggregate nodes with a single input and no control dependencies.
if (node->input_size() == 1) {
return node->input(0);
}
@@ -1488,6 +1561,8 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
return "";
}
new_const_node->set_device(node->device());
+ MaybeAddControlInput(NodeName(node->input(0)), new_const_node,
+ optimized_graph_, node_map_.get());
nodes_to_simplify->PushBack(new_const_node);
// 2. Replace the aggregate node with Mul(Const(N), x).
@@ -1500,9 +1575,7 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
new_mul_node->add_input(node->input(0));
node_map_->AddOutput(node->input(0), new_mul_node->name());
- CopyControlInputs(*node, new_mul_node, optimized_graph_, node_map_.get());
- AddFrameControlDeps(node, {new_const_node, new_mul_node}, node->input(0),
- {new_const_node});
+ ForwardControlDependencies(new_mul_node, {node});
return new_mul_node->name();
}
}
@@ -1535,7 +1608,6 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
FlipBooleanAttr(attr_a, new_op);
new_op->set_input(0, a->input(0));
node_map_->UpdateInput(new_op->name(), a->name(), a->input(0));
- AddFrameControlDeps(node, {new_op}, a->input(0), {new_op});
}
if (b_is_foldable) {
const string attr_b =
@@ -1543,10 +1615,15 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
FlipBooleanAttr(attr_b, new_op);
new_op->set_input(1, b->input(0));
node_map_->UpdateInput(new_op->name(), b->name(), b->input(0));
- if (!a_is_foldable) {
- AddFrameControlDeps(node, {new_op}, b->input(0), {new_op});
- }
}
+ std::vector<const NodeDef*> deps_to_forward({node});
+ if (a_is_foldable) {
+ deps_to_forward.push_back(a);
+ }
+ if (b_is_foldable) {
+ deps_to_forward.push_back(b);
+ }
+ ForwardControlDependencies(new_op, deps_to_forward);
}
}
@@ -1568,7 +1645,7 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
: "Transpose");
new_op->set_input(0, input->input(0));
node_map_->UpdateInput(new_op->name(), node->name(), input->input(0));
- AddFrameControlDeps(node, {new_op}, "", {});
+ ForwardControlDependencies(new_op, {node, input});
return new_op->name();
}
}
@@ -1584,38 +1661,27 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps() {
}
const GraphOptimizerContext ctx(&nodes_to_preserve_, optimized_graph_,
- graph_properties_.get(), node_map_.get(),
- &frame_map_);
+ graph_properties_.get(), node_map_.get());
const ArithmeticOptimizerContext ctx_ext(&nodes_to_simplify);
- std::vector<std::unique_ptr<ArithmeticOptimizerStage>> stages;
-
- if (options_.combine_add_to_addn) {
- stages.push_back(std::unique_ptr<ArithmeticOptimizerStage>(
- new AddOpsRewriteStage(ctx, ctx_ext)));
- }
- if (options_.hoist_common_factor_out_of_aggregation) {
- stages.push_back(std::unique_ptr<ArithmeticOptimizerStage>(
- new HoistCommonFactorOutOfAggregation(ctx, ctx_ext)));
- }
- if (options_.remove_identity_transpose) {
- stages.push_back(std::unique_ptr<ArithmeticOptimizerStage>(
- new RemoveIdentityTranspose(ctx, ctx_ext)));
- }
- if (options_.remove_redundant_bitcast) {
- stages.push_back(std::unique_ptr<ArithmeticOptimizerStage>(
- new RemoveRedundantBitcastStage(ctx, ctx_ext)));
- }
- if (options_.remove_redundant_cast) {
- stages.push_back(std::unique_ptr<ArithmeticOptimizerStage>(
- new RemoveRedundantCastStage(ctx, ctx_ext)));
- }
- if (options_.remove_negation) {
- stages.push_back(std::unique_ptr<ArithmeticOptimizerStage>(
- new RemoveNegationStage(ctx, ctx_ext)));
- }
-
- VLOG(1) << "Simplify arithmetic ops using " << stages.size()
+ // Stop pipeline after first stage returning non-empty simplified tensor name.
+ const auto stop = [](const string& result) { return !result.empty(); };
+ GraphOptimizerStagePipeline<string> pipeline(stop);
+
+ if (options_.combine_add_to_addn)
+ pipeline.AddStage<AddOpsRewriteStage>(ctx, ctx_ext);
+ if (options_.hoist_common_factor_out_of_aggregation)
+ pipeline.AddStage<HoistCommonFactorOutOfAggregation>(ctx, ctx_ext);
+ if (options_.remove_identity_transpose)
+ pipeline.AddStage<RemoveIdentityTranspose>(ctx, ctx_ext);
+ if (options_.remove_redundant_bitcast)
+ pipeline.AddStage<RemoveRedundantBitcastStage>(ctx, ctx_ext);
+ if (options_.remove_redundant_cast)
+ pipeline.AddStage<RemoveRedundantCastStage>(ctx, ctx_ext);
+ if (options_.remove_negation)
+ pipeline.AddStage<RemoveNegationStage>(ctx, ctx_ext);
+
+ VLOG(1) << "Simplify arithmetic ops using " << pipeline.NumStages()
<< " arithmetic optimization stages";
while (!nodes_to_simplify.Empty()) {
@@ -1628,22 +1694,13 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps() {
}
// if it was not simplified try to run it through all configured stages
- if (simplified_tensor.empty()) {
- for (auto& stage : stages) {
- if (stage->IsSupported(node)) {
- TF_RETURN_IF_ERROR(stage->TrySimplify(node, &simplified_tensor));
- if (!simplified_tensor.empty()) {
- break;
- }
- }
+ if (!stop(simplified_tensor)) {
+ bool optimized = pipeline.PassThroughAllStages(node, &simplified_tensor);
+ if (!optimized) {
+ continue;
}
}
- // if it's still empty go to the next Node
- if (simplified_tensor.empty()) {
- continue;
- }
-
// re-wire consumers of an old node to the new one
if (NodeName(simplified_tensor) != node->name()) {
// Always consider simplified_tensor for further optimizations.
@@ -1686,24 +1743,28 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps() {
Status ArithmeticOptimizer::Optimize(Cluster* /*cluster*/,
const GrapplerItem& item,
GraphDef* optimized_graph) {
- optimized_graph_ = optimized_graph;
- *optimized_graph_ = item.graph;
+ GrapplerItem optimized_item(item);
+ optimized_graph_ = &optimized_item.graph;
// Set up helper data structures.
nodes_to_preserve_ = item.NodesToPreserve();
fetch_nodes_known_ = !item.fetch.empty();
node_map_.reset(new NodeMap(optimized_graph_));
- int num_frames;
- TF_RETURN_IF_ERROR(IdentifyFramesWithNodeMap(*optimized_graph_, *node_map_,
- &frame_map_, &num_frames));
+
+ DedupComputations();
+
+ // Perform topological sort on the graph in order to help AddOpsRewrite to
+ // optimize larger subgraphs starting from the roots with more inputs.
+ TF_RETURN_IF_ERROR(TopologicalSort(optimized_graph_));
+
// Shapes are only needed in aggressive mode.
graph_properties_.reset(new GraphProperties(item));
TF_RETURN_IF_ERROR(graph_properties_->InferStatically(false));
// Perform the optimizations.
- DedupComputations();
TF_RETURN_IF_ERROR(SimplifyArithmeticOps());
+ optimized_graph->Swap(optimized_graph_);
return Status::OK();
}
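
The hand-rolled stage vector gives way to GraphOptimizerStagePipeline. A condensed sketch of the pattern, assuming the interface implied by its use above (AddStage forwards constructor arguments; PassThroughAllStages runs each supported stage in order until the stop predicate accepts the result):

const auto stop = [](const string& result) { return !result.empty(); };
GraphOptimizerStagePipeline<string> pipeline(stop);

pipeline.AddStage<AddOpsRewriteStage>(ctx, ctx_ext);
pipeline.AddStage<RemoveNegationStage>(ctx, ctx_ext);

string simplified;
if (pipeline.PassThroughAllStages(node, &simplified)) {
  // Some stage produced a replacement tensor name for `node`.
}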
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
index 965f0e9ea2..7e81ed0a1f 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
@@ -20,7 +20,6 @@ limitations under the License.
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
#include "tensorflow/core/grappler/utils.h"
-#include "tensorflow/core/grappler/utils/frame.h"
#include "tensorflow/core/protobuf/rewriter_config.pb.h"
namespace tensorflow {
@@ -69,7 +68,13 @@ class ArithmeticOptimizer : public GraphOptimizer {
// optimization level by default.
static ArithmeticOptimizerOptions Default(
RewriterConfig::Toggle opt_level) {
- return ArithmeticOptimizerOptions();
+ ArithmeticOptimizerOptions options;
+ // TODO(ezhulenev): enable combine_add_to_addn by default after 1.8
+ // release cut
+ if (opt_level == RewriterConfig::AGGRESSIVE) {
+ options.combine_add_to_addn = true;
+ }
+ return options;
}
};
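
Per the TODO, the AddN combining stays opt-in until after the 1.8 release cut; only the aggressive level flips it on. Usage sketch:

ArithmeticOptimizerOptions opts =
    ArithmeticOptimizerOptions::Default(RewriterConfig::AGGRESSIVE);
// opts.combine_add_to_addn is now true; any other opt_level leaves
// every toggle at its default.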
@@ -94,13 +99,9 @@ class ArithmeticOptimizer : public GraphOptimizer {
// Dedup redundant nodes in the graph.
void DedupComputations();
- // Fix frame dependencies by adding control dependencies from old_input to
- // nodes in new_nodes_for_control_dep, and update frame_map for all nodes in
- // new_nodes.
- void AddFrameControlDeps(const NodeDef* old_node,
- const std::vector<NodeDef*>& new_nodes,
- const string& source_for_ctrl_dep,
- const std::vector<NodeDef*>& sinks_for_control_dep);
+  // Forward the control dependencies anchored on src_nodes to target_node.
+ void ForwardControlDependencies(NodeDef* target_node,
+ const std::vector<const NodeDef*>& src_nodes);
// Runs peep-hole optimizations on `optimized_graph`, e.g., removing inverse
// transposes.
@@ -129,7 +130,6 @@ class ArithmeticOptimizer : public GraphOptimizer {
bool fetch_nodes_known_ = false;
std::unordered_set<string> nodes_to_preserve_;
std::unique_ptr<NodeMap> node_map_;
- FrameMap frame_map_;
std::unique_ptr<GraphProperties> graph_properties_;
GraphDef* optimized_graph_ = nullptr; // Not owned.
};
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index ad3edc144a..e117341ba3 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -156,27 +156,24 @@ TEST_F(ArithmeticOptimizerTest, OpDedupping) {
TF_CHECK_OK(s.ToGraphDef(&item.graph));
item.fetch = {"div"};
- ArithmeticOptimizer optimizer;
- GraphDef output;
- auto tensors_expected = EvaluateNodes(item.graph, item.fetch, {});
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
EXPECT_EQ(1, tensors_expected.size());
- Status status = optimizer.Optimize(nullptr, item, &output);
- TF_EXPECT_OK(status);
- // Run the optimizer twice to make sure the rewrite is idempotent.
- item.graph.Swap(&output);
- status = optimizer.Optimize(nullptr, item, &output);
- TF_EXPECT_OK(status);
+ ArithmeticOptimizer optimizer;
+ GraphDef output;
+ OptimizeTwice(&optimizer, &item, &output);
+ NodeMap node_map(&output);
EXPECT_EQ(2, output.node_size());
- const NodeDef& new_c1 = output.node(0);
- EXPECT_EQ("c1", new_c1.name());
- const NodeDef& new_div = output.node(1);
- EXPECT_EQ("div", new_div.name());
- EXPECT_EQ(2, new_div.input_size());
- EXPECT_EQ("c1", new_div.input(0));
- EXPECT_EQ("c1", new_div.input(1));
-
- auto tensors = EvaluateNodes(output, item.fetch, {});
+ const NodeDef* new_c1 = node_map.GetNode("c1");
+ ASSERT_NE(new_c1, nullptr);
+
+ const NodeDef* new_div = node_map.GetNode("div");
+ ASSERT_NE(new_div, nullptr);
+ EXPECT_EQ(2, new_div->input_size());
+ EXPECT_EQ("c1", new_div->input(0));
+ EXPECT_EQ("c1", new_div->input(1));
+
+ auto tensors = EvaluateNodes(output, item.fetch);
EXPECT_EQ(1, tensors.size());
test::ExpectTensorNear<double>(tensors_expected[0], tensors[0], 1e-6);
}
@@ -195,23 +192,30 @@ TEST_F(ArithmeticOptimizerTest, OpDeduppingAssertAndCheckNumerics) {
GrapplerItem item;
TF_CHECK_OK(s.ToGraphDef(&item.graph));
item.fetch = {"div"};
+ Tensor bool_t(DT_BOOL, TensorShape({}));
+ bool_t.scalar<bool>().setConstant(true);
+ auto tensors_expected =
+ EvaluateNodes(item.graph, item.fetch, {{"Placeholder", bool_t}});
+ EXPECT_EQ(1, tensors_expected.size());
ArithmeticOptimizer optimizer;
GraphDef output;
- Status status = optimizer.Optimize(nullptr, item, &output);
- TF_EXPECT_OK(status);
- // Run the optimizer twice to make sure the rewrite is idempotent.
- item.graph.Swap(&output);
- status = optimizer.Optimize(nullptr, item, &output);
- TF_EXPECT_OK(status);
+
+ OptimizeTwice(&optimizer, &item, &output);
+ NodeMap node_map(&output);
EXPECT_EQ(5, output.node_size());
- const NodeDef& new_div = output.node(3);
- EXPECT_EQ(4, new_div.input_size());
- EXPECT_EQ("check1", new_div.input(0));
- EXPECT_EQ("check1", new_div.input(1));
- EXPECT_EQ("^assert1", new_div.input(2));
- EXPECT_EQ("^assert1", new_div.input(3));
+ const NodeDef* new_div = node_map.GetNode("div");
+ ASSERT_NE(new_div, nullptr);
+ EXPECT_EQ(4, new_div->input_size());
+ EXPECT_EQ("check1", new_div->input(0));
+ EXPECT_EQ("check1", new_div->input(1));
+ EXPECT_EQ("^assert1", new_div->input(2));
+ EXPECT_EQ("^assert1", new_div->input(3));
+
+ auto tensors = EvaluateNodes(output, item.fetch, {{"Placeholder", bool_t}});
+ EXPECT_EQ(1, tensors.size());
+ test::ExpectTensorNear<double>(tensors_expected[0], tensors[0], 1e-6);
}
TEST_F(ArithmeticOptimizerTest, OpDedupCommutative) {
@@ -223,32 +227,34 @@ TEST_F(ArithmeticOptimizerTest, OpDedupCommutative) {
Output div1 = ops::Div(s.WithOpName("div1"), mul1, mul2);
GrapplerItem item;
TF_CHECK_OK(s.ToGraphDef(&item.graph));
- item.fetch = {"div"};
+ item.fetch = {"div1"};
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+ EXPECT_EQ(1, tensors_expected.size());
ArithmeticOptimizer optimizer;
GraphDef output;
- Status status = optimizer.Optimize(nullptr, item, &output);
- TF_EXPECT_OK(status);
- // Run the optimizer twice to make sure the rewrite is idempotent.
- item.graph.Swap(&output);
- status = optimizer.Optimize(nullptr, item, &output);
- TF_EXPECT_OK(status);
+ OptimizeTwice(&optimizer, &item, &output);
+ NodeMap node_map(&output);
EXPECT_EQ(4, output.node_size());
- const NodeDef& new_c1 = output.node(0);
- EXPECT_EQ("c1", new_c1.name());
- const NodeDef& new_c2 = output.node(1);
- EXPECT_EQ("c2", new_c2.name());
- const NodeDef& new_mul1 = output.node(2);
- EXPECT_EQ("mul1", new_mul1.name());
- EXPECT_EQ(2, new_mul1.input_size());
- EXPECT_EQ("c1", new_mul1.input(0));
- EXPECT_EQ("c2", new_mul1.input(1));
- const NodeDef& new_div1 = output.node(3);
- EXPECT_EQ("div1", new_div1.name());
- EXPECT_EQ(2, new_div1.input_size());
- EXPECT_EQ("mul1", new_div1.input(0));
- EXPECT_EQ("mul1", new_div1.input(1));
+ const NodeDef* new_c1 = node_map.GetNode("c1");
+ ASSERT_NE(new_c1, nullptr);
+ const NodeDef* new_c2 = node_map.GetNode("c2");
+ ASSERT_NE(new_c2, nullptr);
+ const NodeDef* new_mul1 = node_map.GetNode("mul1");
+ ASSERT_NE(new_mul1, nullptr);
+ EXPECT_EQ(2, new_mul1->input_size());
+ EXPECT_EQ("c1", new_mul1->input(0));
+ EXPECT_EQ("c2", new_mul1->input(1));
+ const NodeDef* new_div1 = node_map.GetNode("div1");
+ ASSERT_NE(new_div1, nullptr);
+ EXPECT_EQ(2, new_div1->input_size());
+ EXPECT_EQ("mul1", new_div1->input(0));
+ EXPECT_EQ("mul1", new_div1->input(1));
+
+ auto tensors = EvaluateNodes(output, item.fetch);
+ EXPECT_EQ(1, tensors.size());
+ test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
TEST_F(ArithmeticOptimizerTest, MulToSquare) {
@@ -259,6 +265,9 @@ TEST_F(ArithmeticOptimizerTest, MulToSquare) {
Output id = ops::Identity(s.WithOpName("id"), mul);
GrapplerItem item;
TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ std::vector<string> fetch = {"id"};
+ auto tensors_expected = EvaluateNodes(item.graph, fetch);
+ EXPECT_EQ(1, tensors_expected.size());
ArithmeticOptimizer optimizer;
GraphDef output;
@@ -273,6 +282,10 @@ TEST_F(ArithmeticOptimizerTest, MulToSquare) {
EXPECT_EQ(2, output.node(4).input_size());
EXPECT_EQ("c", output.node(4).input(0));
EXPECT_EQ("^d", output.node(4).input(1));
+
+ auto tensors = EvaluateNodes(output, fetch);
+ EXPECT_EQ(1, tensors.size());
+ test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
TEST_F(ArithmeticOptimizerTest, SimplifyInvolutionsReal) {
@@ -285,6 +298,9 @@ TEST_F(ArithmeticOptimizerTest, SimplifyInvolutionsReal) {
Output id = ops::Identity(s.WithOpName("id"), recip2);
GrapplerItem item;
TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ std::vector<string> fetch = {"id"};
+ auto tensors_expected = EvaluateNodes(item.graph, fetch);
+ EXPECT_EQ(1, tensors_expected.size());
ArithmeticOptimizer optimizer;
GraphDef output;
@@ -295,6 +311,10 @@ TEST_F(ArithmeticOptimizerTest, SimplifyInvolutionsReal) {
EXPECT_EQ("c", output.node(1).input(0));
EXPECT_EQ("c", output.node(3).input(0));
EXPECT_EQ("c", output.node(5).input(0));
+
+ auto tensors = EvaluateNodes(output, fetch);
+ EXPECT_EQ(1, tensors.size());
+ test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
TEST_F(ArithmeticOptimizerTest, SimplifyInvolutionsWithChain) {
@@ -307,6 +327,9 @@ TEST_F(ArithmeticOptimizerTest, SimplifyInvolutionsWithChain) {
Output id2 = ops::Identity(s.WithOpName("id2"), recip2);
GrapplerItem item;
TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ std::vector<string> fetch = {"id2"};
+ auto tensors_expected = EvaluateNodes(item.graph, fetch);
+ EXPECT_EQ(1, tensors_expected.size());
ArithmeticOptimizer optimizer;
GraphDef output;
@@ -320,6 +343,10 @@ TEST_F(ArithmeticOptimizerTest, SimplifyInvolutionsWithChain) {
EXPECT_EQ(6, output.node_size());
EXPECT_EQ("squeeze", output.node(5).input(0));
EXPECT_EQ("c", output.node(2).input(0));
+
+ auto tensors = EvaluateNodes(output, fetch);
+ EXPECT_EQ(1, tensors.size());
+ test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
TEST_F(ArithmeticOptimizerTest, SimplifyInvolutionsWithControlChain) {
@@ -334,6 +361,10 @@ TEST_F(ArithmeticOptimizerTest, SimplifyInvolutionsWithControlChain) {
GrapplerItem item;
TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ std::vector<string> fetch = {"id2"};
+ auto tensors_expected = EvaluateNodes(item.graph, fetch);
+ EXPECT_EQ(1, tensors_expected.size());
+
ArithmeticOptimizer optimizer;
GraphDef output;
Status status = optimizer.Optimize(nullptr, item, &output);
@@ -351,6 +382,10 @@ TEST_F(ArithmeticOptimizerTest, SimplifyInvolutionsWithControlChain) {
EXPECT_EQ(original.input(j), optimized.input(j));
}
}
+
+ auto tensors = EvaluateNodes(output, fetch);
+ EXPECT_EQ(1, tensors.size());
+ test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
TEST_F(ArithmeticOptimizerTest, TrivialSumsSimple) {
@@ -362,28 +397,35 @@ TEST_F(ArithmeticOptimizerTest, TrivialSumsSimple) {
GrapplerItem item;
TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ std::vector<string> fetch = {"id"};
+ auto tensors_expected = EvaluateNodes(item.graph, fetch);
+ EXPECT_EQ(1, tensors_expected.size());
+
ArithmeticOptimizer optimizer;
GraphDef output;
- Status status = optimizer.Optimize(nullptr, item, &output);
- TF_EXPECT_OK(status);
- // Run the optimizer twice to make sure the rewrite is idempotent.
- item.graph.Swap(&output);
- status = optimizer.Optimize(nullptr, item, &output);
- TF_EXPECT_OK(status);
+ OptimizeTwice(&optimizer, &item, &output);
+ NodeMap node_map(&output);
EXPECT_EQ(5, output.node_size());
- const NodeDef& new_const = output.node(3);
- EXPECT_EQ(OptimizedName("add_const"), new_const.name());
- EXPECT_EQ("^x", new_const.input(0));
+
+ const NodeDef* new_const = node_map.GetNode(OptimizedName("add_const"));
+ ASSERT_NE(new_const, nullptr);
+ EXPECT_EQ("^x", new_const->input(0));
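+ // "\0\0\0@" is the little-endian encoding of 2.0f: x + x was rewritten
+ // as const(2) * x.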
EXPECT_EQ(std::string("\0\0\0@", 4),
- new_const.attr().at("value").tensor().tensor_content());
- const NodeDef& new_mul = output.node(4);
- EXPECT_EQ(OptimizedName("add_mul"), new_mul.name());
- EXPECT_EQ(OptimizedName("add_const"), new_mul.input(0));
- EXPECT_EQ("x", new_mul.input(1));
- const NodeDef& new_id = output.node(2);
- EXPECT_EQ("id", new_id.name());
- EXPECT_EQ(OptimizedName("add_mul"), new_id.input(0));
+ new_const->attr().at("value").tensor().tensor_content());
+
+ const NodeDef* new_mul = node_map.GetNode(OptimizedName("add_mul"));
+ ASSERT_NE(new_mul, nullptr);
+ EXPECT_EQ(OptimizedName("add_const"), new_mul->input(0));
+ EXPECT_EQ("x", new_mul->input(1));
+
+ const NodeDef* new_id = node_map.GetNode("id");
+ ASSERT_NE(new_id, nullptr);
+ EXPECT_EQ(OptimizedName("add_mul"), new_id->input(0));
+
+ auto tensors = EvaluateNodes(output, fetch);
+ EXPECT_EQ(1, tensors.size());
+ test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
TEST_F(ArithmeticOptimizerTest, TrivialSumsSimpleWithControlDep) {
@@ -396,29 +438,36 @@ TEST_F(ArithmeticOptimizerTest, TrivialSumsSimpleWithControlDep) {
GrapplerItem item;
TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ std::vector<string> fetch = {"id"};
+ auto tensors_expected = EvaluateNodes(item.graph, fetch);
+ EXPECT_EQ(1, tensors_expected.size());
+
ArithmeticOptimizer optimizer;
GraphDef output;
- Status status = optimizer.Optimize(nullptr, item, &output);
- TF_EXPECT_OK(status);
- // Run the optimizer twice to make sure the rewrite is idempotent.
- item.graph.Swap(&output);
- status = optimizer.Optimize(nullptr, item, &output);
- TF_EXPECT_OK(status);
+ OptimizeTwice(&optimizer, &item, &output);
+ NodeMap node_map(&output);
EXPECT_EQ(6, output.node_size());
- const NodeDef& new_const = output.node(4);
- EXPECT_EQ(OptimizedName("add_const"), new_const.name());
- EXPECT_EQ("^x", new_const.input(0));
+
+ const NodeDef* new_const = node_map.GetNode(OptimizedName("add_const"));
+ ASSERT_NE(new_const, nullptr);
+ EXPECT_EQ("^x", new_const->input(0));
EXPECT_EQ(std::string("\0\0\0@", 4),
- new_const.attr().at("value").tensor().tensor_content());
- const NodeDef& new_mul = output.node(5);
- EXPECT_EQ(OptimizedName("add_mul"), new_mul.name());
- EXPECT_EQ(OptimizedName("add_const"), new_mul.input(0));
- EXPECT_EQ("x", new_mul.input(1));
- EXPECT_EQ("^y", new_mul.input(2));
- const NodeDef& new_id = output.node(3);
- EXPECT_EQ("id", new_id.name());
- EXPECT_EQ(OptimizedName("add_mul"), new_id.input(0));
+ new_const->attr().at("value").tensor().tensor_content());
+
+ const NodeDef* new_mul = node_map.GetNode(OptimizedName("add_mul"));
+ ASSERT_NE(new_mul, nullptr);
+ EXPECT_EQ(OptimizedName("add_const"), new_mul->input(0));
+ EXPECT_EQ("x", new_mul->input(1));
+ EXPECT_EQ("^y", new_mul->input(2));
+
+ const NodeDef* new_id = node_map.GetNode("id");
+ ASSERT_NE(new_id, nullptr);
+ EXPECT_EQ(OptimizedName("add_mul"), new_id->input(0));
+
+ auto tensors = EvaluateNodes(output, fetch);
+ EXPECT_EQ(1, tensors.size());
+ test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
TEST_F(ArithmeticOptimizerTest, TrivialSumsRepeatedAdd) {
@@ -434,6 +483,7 @@ TEST_F(ArithmeticOptimizerTest, TrivialSumsRepeatedAdd) {
GrapplerItem item;
TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
const std::vector<string> devices{
"/device:CPU:0", "/device:GPU:0", "/device:CPU:0", "/device:GPU:1",
"/device:CPU:0", "/device:CPU:0", "/device:CPU:0",
@@ -458,48 +508,45 @@ TEST_F(ArithmeticOptimizerTest, TrivialSumsRepeatedAdd) {
EXPECT_EQ(17, output.node_size());
const NodeDef* id_node = node_map.GetNode("id");
- ASSERT_TRUE(id_node != nullptr);
+ ASSERT_NE(id_node, nullptr);
EXPECT_EQ(1, id_node->input_size());
EXPECT_EQ(HoistMulName("Add_6"), id_node->input(0));
const NodeDef* mul_node = node_map.GetNode(HoistMulName("Add_6"));
- ASSERT_TRUE(mul_node != nullptr);
+ ASSERT_NE(mul_node, nullptr);
EXPECT_EQ(2, mul_node->input_size());
EXPECT_EQ("Placeholder", mul_node->input(0));
EXPECT_EQ(HoistAddName("Add_6"), mul_node->input(1));
const NodeDef* add_6_node = node_map.GetNode(HoistAddName("Add_6"));
- ASSERT_TRUE(add_6_node != nullptr);
- EXPECT_EQ(3, add_6_node->input_size());
+ ASSERT_NE(add_6_node, nullptr);
+ EXPECT_EQ(2, add_6_node->input_size());
EXPECT_EQ(HoistAddName("Add_4"), add_6_node->input(0));
EXPECT_EQ(HoistAddName("Add_5"), add_6_node->input(1));
- EXPECT_EQ("^Placeholder", add_6_node->input(2));
const NodeDef* add_4_node = node_map.GetNode(HoistAddName("Add_4"));
- ASSERT_TRUE(add_4_node != nullptr);
+ ASSERT_NE(add_4_node, nullptr);
EXPECT_EQ("Add", add_4_node->op());
- EXPECT_EQ(3, add_4_node->input_size());
+ EXPECT_EQ(2, add_4_node->input_size());
EXPECT_EQ(OptimizedName("Add_const"), add_4_node->input(0));
EXPECT_EQ(OptimizedName("Add_1_const"), add_4_node->input(1));
- EXPECT_EQ("^Placeholder", add_4_node->input(2));
const NodeDef* add_5_node = node_map.GetNode(HoistAddName("Add_5"));
- ASSERT_TRUE(add_5_node != nullptr);
+ ASSERT_NE(add_5_node, nullptr);
EXPECT_EQ("Add", add_5_node->op());
- EXPECT_EQ(3, add_5_node->input_size());
+ EXPECT_EQ(2, add_5_node->input_size());
EXPECT_EQ(OptimizedName("Add_const"), add_5_node->input(0));
EXPECT_EQ(OptimizedName("Add_1_const"), add_5_node->input(1));
- EXPECT_EQ("^Placeholder", add_5_node->input(2));
const NodeDef* add_const_node = node_map.GetNode(OptimizedName("Add_const"));
- ASSERT_TRUE(add_const_node != nullptr);
+ ASSERT_NE(add_const_node, nullptr);
EXPECT_EQ("Const", add_const_node->op());
EXPECT_EQ(1, add_const_node->input_size());
EXPECT_EQ("^Placeholder", add_const_node->input(0));
const NodeDef* add_1_const_node =
node_map.GetNode(OptimizedName("Add_1_const"));
- ASSERT_TRUE(add_1_const_node != nullptr);
+ ASSERT_NE(add_1_const_node, nullptr);
EXPECT_EQ("Const", add_1_const_node->op());
EXPECT_EQ(1, add_1_const_node->input_size());
EXPECT_EQ("^Placeholder", add_1_const_node->input(0));
@@ -525,7 +572,8 @@ TEST_F(ArithmeticOptimizerTest, HoistFactor) {
GrapplerItem item;
item.fetch = {"id"};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
-
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+ EXPECT_EQ(1, tensors_expected.size());
ArithmeticOptimizer optimizer;
EnableOnlyHoistCommonFactor(&optimizer);
@@ -550,55 +598,63 @@ TEST_F(ArithmeticOptimizerTest, HoistFactor) {
EXPECT_EQ(9, output.node_size());
const NodeDef* new_add_node = node_map.GetNode(HoistAddName("add"));
- ASSERT_TRUE(new_add_node != nullptr) << "Hoisted Add node not found";
+ ASSERT_NE(new_add_node, nullptr) << "Hoisted Add node not found";
EXPECT_EQ("y1", new_add_node->input(0));
EXPECT_EQ("y2", new_add_node->input(1));
const NodeDef* new_mul_node = node_map.GetNode(HoistMulName("add"));
- ASSERT_TRUE(new_mul_node != nullptr) << "Hoisted Mul node not found";
+ ASSERT_NE(new_mul_node, nullptr) << "Hoisted Mul node not found";
EXPECT_EQ("x", new_mul_node->input(0));
EXPECT_EQ(new_add_node->name(), new_mul_node->input(1));
const NodeDef* id_node = node_map.GetNode("id");
- ASSERT_TRUE(id_node != nullptr) << "Id node not found";
+ ASSERT_NE(id_node, nullptr) << "Id node not found";
EXPECT_EQ("id", id_node->name());
EXPECT_EQ(HoistMulName("add"), id_node->input(0));
}
+ auto tensors = EvaluateNodes(output, item.fetch);
+ EXPECT_EQ(1, tensors.size());
+ test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
}
}
TEST_F(ArithmeticOptimizerTest, FuseConjAndTranspose) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- Output re = ops::Const(s.WithOpName("re"), {1.0, 2.0, 3.0, 4.0}, {2, 2});
- Output im = ops::Const(s.WithOpName("im"), {5.0, 6.0, 7.0, 8.0}, {2, 2});
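+ // Float literals keep the Complex result complex64, matching the
+ // ExpectTensorEqual<complex64> check below.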
+ Output re = ops::Const(s.WithOpName("re"), {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
+ Output im = ops::Const(s.WithOpName("im"), {5.0f, 6.0f, 7.0f, 8.0f}, {2, 2});
Output z = ops::Complex(s.WithOpName("z"), re, im);
Output perm = ops::Const(s.WithOpName("perm"), {1, 0}, {2});
Output conj = ops::Conj(s.WithOpName("conj"), z);
Output transp = ops::Transpose(s.WithOpName("trans"), conj, perm);
GrapplerItem item;
TF_CHECK_OK(s.ToGraphDef(&item.graph));
-
+ std::vector<string> fetch = {"trans"};
+ auto tensors_expected = EvaluateNodes(item.graph, fetch);
+ EXPECT_EQ(1, tensors_expected.size());
ArithmeticOptimizer optimizer;
GraphDef output;
- Status status = optimizer.Optimize(nullptr, item, &output);
- TF_EXPECT_OK(status);
- // Run the optimizer twice to make sure the rewrite is idempotent.
- item.graph.Swap(&output);
- status = optimizer.Optimize(nullptr, item, &output);
- TF_EXPECT_OK(status);
+ OptimizeTwice(&optimizer, &item, &output);
+ NodeMap node_map(&output);
EXPECT_EQ(7, output.node_size());
- EXPECT_EQ(OptimizedName("trans_fused"), output.node(6).name());
- EXPECT_EQ("ConjugateTranspose", output.node(6).op());
- EXPECT_EQ("z", output.node(6).input(0));
- EXPECT_EQ("perm", output.node(6).input(1));
+
+ const NodeDef* trans_fused_node =
+ node_map.GetNode(OptimizedName("trans_fused"));
+ ASSERT_NE(trans_fused_node, nullptr);
+ EXPECT_EQ("ConjugateTranspose", trans_fused_node->op());
+ EXPECT_EQ("z", trans_fused_node->input(0));
+ EXPECT_EQ("perm", trans_fused_node->input(1));
+
+ auto tensors = EvaluateNodes(output, fetch);
+ EXPECT_EQ(1, tensors.size());
+ test::ExpectTensorEqual<complex64>(tensors_expected[0], tensors[0]);
}
TEST_F(ArithmeticOptimizerTest, FuseConjAndConjugateTranspose) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- Output re = ops::Const(s.WithOpName("re"), {1.0, 2.0, 3.0, 4.0}, {2, 2});
- Output im = ops::Const(s.WithOpName("im"), {5.0, 6.0, 7.0, 8.0}, {2, 2});
+ Output re = ops::Const(s.WithOpName("re"), {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
+ Output im = ops::Const(s.WithOpName("im"), {5.0f, 6.0f, 7.0f, 8.0f}, {2, 2});
Output z = ops::Complex(s.WithOpName("z"), re, im);
Output perm = ops::Const(s.WithOpName("perm"), {1, 0}, {2});
Output conj = ops::Conj(s.WithOpName("conj"), z);
@@ -606,44 +662,56 @@ TEST_F(ArithmeticOptimizerTest, FuseConjAndConjugateTranspose) {
ops::ConjugateTranspose(s.WithOpName("conjugate_trans"), conj, perm);
GrapplerItem item;
TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ std::vector<string> fetch = {"conjugate_trans"};
+ auto tensors_expected = EvaluateNodes(item.graph, fetch);
+ EXPECT_EQ(1, tensors_expected.size());
ArithmeticOptimizer optimizer;
GraphDef output;
- Status status = optimizer.Optimize(nullptr, item, &output);
- TF_EXPECT_OK(status);
+ OptimizeTwice(&optimizer, &item, &output);
+ NodeMap node_map(&output);
EXPECT_EQ(7, output.node_size());
- EXPECT_EQ(OptimizedName("conjugate_trans_fused"), output.node(6).name());
- EXPECT_EQ("Transpose", output.node(6).op());
- EXPECT_EQ("z", output.node(6).input(0));
- EXPECT_EQ("perm", output.node(6).input(1));
+
+ const NodeDef* conjugate_trans_fused_node =
+ node_map.GetNode(OptimizedName("conjugate_trans_fused"));
+ ASSERT_NE(conjugate_trans_fused_node, nullptr);
+ EXPECT_EQ("Transpose", conjugate_trans_fused_node->op());
+ EXPECT_EQ("z", conjugate_trans_fused_node->input(0));
+ EXPECT_EQ("perm", conjugate_trans_fused_node->input(1));
+ auto tensors = EvaluateNodes(output, fetch);
+ EXPECT_EQ(1, tensors.size());
+ test::ExpectTensorEqual<complex64>(tensors_expected[0], tensors[0]);
}
TEST_F(ArithmeticOptimizerTest, FuseTransposeAndConj) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- Output re = ops::Const(s.WithOpName("re"), {1.0, 2.0, 3.0, 4.0}, {2, 2});
- Output im = ops::Const(s.WithOpName("im"), {5.0, 6.0, 7.0, 8.0}, {2, 2});
+ Output re = ops::Const(s.WithOpName("re"), {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
+ Output im = ops::Const(s.WithOpName("im"), {5.0f, 6.0f, 7.0f, 8.0f}, {2, 2});
Output z = ops::Complex(s.WithOpName("z"), re, im);
Output perm = ops::Const(s.WithOpName("perm"), {1, 0}, {2});
Output trans = ops::Transpose(s.WithOpName("trans"), z, perm);
Output conj = ops::Conj(s.WithOpName("conj"), trans);
GrapplerItem item;
TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ std::vector<string> fetch = {"conj"};
+ auto tensors_expected = EvaluateNodes(item.graph, fetch);
+ EXPECT_EQ(1, tensors_expected.size());
ArithmeticOptimizer optimizer;
GraphDef output;
- Status status = optimizer.Optimize(nullptr, item, &output);
- TF_EXPECT_OK(status);
- // Run the optimizer twice to make sure the rewrite is idempotent.
- item.graph.Swap(&output);
- status = optimizer.Optimize(nullptr, item, &output);
- TF_EXPECT_OK(status);
+ OptimizeTwice(&optimizer, &item, &output);
+ NodeMap node_map(&output);
EXPECT_EQ(7, output.node_size());
- EXPECT_EQ(OptimizedName("conj_fused"), output.node(6).name());
- EXPECT_EQ("ConjugateTranspose", output.node(6).op());
- EXPECT_EQ("z", output.node(6).input(0));
- EXPECT_EQ("perm", output.node(6).input(1));
+
+ const NodeDef* conj_fused_node =
+ node_map.GetNode(OptimizedName("conj_fused"));
+ ASSERT_NE(conj_fused_node, nullptr);
+ EXPECT_EQ("ConjugateTranspose", conj_fused_node->op());
+ EXPECT_EQ("z", conj_fused_node->input(0));
+ EXPECT_EQ("perm", conj_fused_node->input(1));
+ auto tensors = EvaluateNodes(output, fetch);
+ EXPECT_EQ(1, tensors.size());
+ test::ExpectTensorEqual<complex64>(tensors_expected[0], tensors[0]);
}
TEST_F(ArithmeticOptimizerTest, FoldTransposeIntoMatMul) {
@@ -665,27 +733,32 @@ TEST_F(ArithmeticOptimizerTest, FoldTransposeIntoMatMul) {
}
GrapplerItem item;
TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ std::vector<string> fetch = {"matmul"};
+ auto tensors_expected = EvaluateNodes(item.graph, fetch);
+ EXPECT_EQ(1, tensors_expected.size());
ArithmeticOptimizer optimizer;
GraphDef output;
- Status status = optimizer.Optimize(nullptr, item, &output);
- TF_EXPECT_OK(status);
- // Run the optimizer twice to make sure the rewrite is idempotent.
- item.graph.Swap(&output);
- status = optimizer.Optimize(nullptr, item, &output);
- TF_EXPECT_OK(status);
+ OptimizeTwice(&optimizer, &item, &output);
+ NodeMap node_map(&output);
EXPECT_EQ(7, output.node_size());
- EXPECT_EQ(OptimizedName("matmul_fused"), output.node(6).name());
- EXPECT_EQ("a", output.node(6).input(0));
- EXPECT_EQ("b", output.node(6).input(1));
+
+ const NodeDef* matmul_fused_node =
+ node_map.GetNode(OptimizedName("matmul_fused"));
+ ASSERT_NE(matmul_fused_node, nullptr);
+ EXPECT_EQ("a", matmul_fused_node->input(0));
+ EXPECT_EQ("b", matmul_fused_node->input(1));
if (matmul_type == "BatchMatMul") {
- EXPECT_TRUE(output.node(6).attr().at("adj_x").b());
- EXPECT_TRUE(output.node(6).attr().at("adj_y").b());
+ EXPECT_TRUE(matmul_fused_node->attr().at("adj_x").b());
+ EXPECT_TRUE(matmul_fused_node->attr().at("adj_y").b());
} else {
- EXPECT_TRUE(output.node(6).attr().at("transpose_a").b());
- EXPECT_TRUE(output.node(6).attr().at("transpose_b").b());
+ EXPECT_TRUE(matmul_fused_node->attr().at("transpose_a").b());
+ EXPECT_TRUE(matmul_fused_node->attr().at("transpose_b").b());
}
+ auto tensors = EvaluateNodes(output, fetch);
+ EXPECT_EQ(1, tensors.size());
+ test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
}
@@ -707,6 +780,9 @@ TEST_F(ArithmeticOptimizerTest, FoldConjugateTransposeIntoBatchMatMul) {
Output matmul = ops::BatchMatMul(s.WithOpName("matmul"), trans_a, trans_b);
GrapplerItem item;
TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ std::vector<string> fetch = {"matmul"};
+ auto tensors_expected = EvaluateNodes(item.graph, fetch);
+ EXPECT_EQ(1, tensors_expected.size());
ArithmeticOptimizer optimizer;
GraphDef output;
@@ -719,6 +795,9 @@ TEST_F(ArithmeticOptimizerTest, FoldConjugateTransposeIntoBatchMatMul) {
EXPECT_EQ("b", output.node(10).input(1));
EXPECT_TRUE(output.node(10).attr().at("adj_x").b());
EXPECT_TRUE(output.node(10).attr().at("adj_y").b());
+ auto tensors = EvaluateNodes(output, fetch);
+ EXPECT_EQ(1, tensors.size());
+ test::ExpectTensorNear<complex64>(tensors_expected[0], tensors[0], 1e-6);
}
TEST_F(ArithmeticOptimizerTest, IdentityReshape) {
@@ -739,7 +818,10 @@ TEST_F(ArithmeticOptimizerTest, IdentityReshape) {
GrapplerItem item;
item.fetch = {"outputs"};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
-
+ auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 3, 28, 28}));
+ auto tensors_expected =
+ EvaluateNodes(item.graph, item.fetch, {{"Placeholder", x_t}});
+ EXPECT_EQ(1, tensors_expected.size());
GraphDef output;
TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
@@ -747,6 +829,9 @@ TEST_F(ArithmeticOptimizerTest, IdentityReshape) {
TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
EXPECT_EQ(0, CountOpNodes(output, "Reshape"));
+ auto tensors = EvaluateNodes(output, item.fetch, {{"Placeholder", x_t}});
+ EXPECT_EQ(1, tensors.size());
+ test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
TEST_F(ArithmeticOptimizerTest, NotIdentityReshape) {
@@ -761,7 +846,10 @@ TEST_F(ArithmeticOptimizerTest, NotIdentityReshape) {
GrapplerItem item;
item.fetch = {"outputs"};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
-
+ auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({8, 3, 28, 28}));
+ item.feed = {{"Placeholder", x_t}};
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
+ EXPECT_EQ(1, tensors_expected.size());
GraphDef output;
TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
@@ -769,6 +857,9 @@ TEST_F(ArithmeticOptimizerTest, NotIdentityReshape) {
TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
EXPECT_EQ(1, CountOpNodes(output, "Reshape"));
+ auto tensors = EvaluateNodes(output, item.fetch, item.feed);
+ EXPECT_EQ(1, tensors.size());
+ test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
TEST_F(ArithmeticOptimizerTest, NotIdentityReshapeTooManyUnknownDimSizes) {
@@ -781,7 +872,6 @@ TEST_F(ArithmeticOptimizerTest, NotIdentityReshapeTooManyUnknownDimSizes) {
GrapplerItem item;
item.fetch = {"outputs"};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
-
GraphDef output;
TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
@@ -812,7 +902,10 @@ TEST_F(ArithmeticOptimizerTest, CombineReshapes) {
GrapplerItem item;
item.fetch = {"outputs"};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
-
+ auto x_t = GenerateRandomTensor<DT_INT8>(TensorShape({8, 3, 28, 28, 4}));
+ item.feed = {{"nchw_vect_c", x_t}};
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
+ EXPECT_EQ(1, tensors_expected.size());
GraphDef output;
TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
@@ -820,6 +913,9 @@ TEST_F(ArithmeticOptimizerTest, CombineReshapes) {
TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
EXPECT_EQ(1, CountOpNodes(output, "Reshape"));
+ auto tensors = EvaluateNodes(output, item.fetch, item.feed);
+ EXPECT_EQ(1, tensors.size());
+ test::ExpectTensorEqual<int8>(tensors_expected[0], tensors[0]);
}
TEST_F(ArithmeticOptimizerTest, ReorderTransposeCast) {
@@ -1322,8 +1418,8 @@ TEST_F(ArithmeticOptimizerTest, AddOpsRewrite_AddOpsOfIdenticalShape) {
// check add tree was replaced with AddN
const NodeDef* collapsed_add =
- node_map.GetNode("y/ArithmeticOptimizer/AddOpsRewrite_Add_abc_Add_ab");
- ASSERT_TRUE(collapsed_add != nullptr);
+ node_map.GetNode("y/ArithmeticOptimizer/AddOpsRewrite_Add_abc");
+ ASSERT_NE(collapsed_add, nullptr);
EXPECT_EQ("AddN", collapsed_add->op());
EXPECT_EQ(3, collapsed_add->input_size());
@@ -1333,7 +1429,7 @@ TEST_F(ArithmeticOptimizerTest, AddOpsRewrite_AddOpsOfIdenticalShape) {
// check output was re-wired to new node
const NodeDef* updated_outputs = node_map.GetNode("outputs");
- ASSERT_TRUE(updated_outputs != nullptr);
+ ASSERT_NE(updated_outputs, nullptr);
EXPECT_EQ(collapsed_add->name(), updated_outputs->input(0));
}
@@ -1381,8 +1477,8 @@ TEST_F(ArithmeticOptimizerTest, AddOpsRewrite_MultiplePasses) {
// check left Add subtree replaced with AddN
const NodeDef* collapsed_left =
- node_map.GetNode("ArithmeticOptimizer/AddOpsRewrite_Add_abc_Add_ab");
- ASSERT_TRUE(collapsed_left != nullptr);
+ node_map.GetNode("ArithmeticOptimizer/AddOpsRewrite_Add_abc");
+ ASSERT_NE(collapsed_left, nullptr);
EXPECT_EQ("AddN", collapsed_left->op());
EXPECT_EQ(3, collapsed_left->input_size());
@@ -1392,8 +1488,8 @@ TEST_F(ArithmeticOptimizerTest, AddOpsRewrite_MultiplePasses) {
// check right Add subtree replaced with AddN
const NodeDef* collapsed_right =
- node_map.GetNode("ArithmeticOptimizer/AddOpsRewrite_Add_xyz_Add_xy");
- ASSERT_TRUE(collapsed_right != nullptr);
+ node_map.GetNode("ArithmeticOptimizer/AddOpsRewrite_Add_xyz");
+ ASSERT_NE(collapsed_right, nullptr);
EXPECT_EQ("AddN", collapsed_right->op());
EXPECT_EQ(3, collapsed_right->input_size());
@@ -1403,7 +1499,7 @@ TEST_F(ArithmeticOptimizerTest, AddOpsRewrite_MultiplePasses) {
// check that Mul inputs re-wired to new Nodes
const NodeDef* updated_mul = node_map.GetNode("Mul");
- ASSERT_TRUE(updated_mul != nullptr);
+ ASSERT_NE(updated_mul, nullptr);
EXPECT_EQ("Mul", updated_mul->op());
EXPECT_EQ(2, updated_mul->input_size());
@@ -1444,9 +1540,9 @@ TEST_F(ArithmeticOptimizerTest, AddOpsRewrite_AddInputMultipleTimes) {
NodeMap node_map(&output);
// check Add tree replaced with AddN
- const NodeDef* collapsed_add = node_map.GetNode(
- "ArithmeticOptimizer/AddOpsRewrite_Add_all_Add_ab_Add_bc");
- ASSERT_TRUE(collapsed_add != nullptr);
+ const NodeDef* collapsed_add =
+ node_map.GetNode("ArithmeticOptimizer/AddOpsRewrite_Add_all");
+ ASSERT_NE(collapsed_add, nullptr);
EXPECT_EQ("AddN", collapsed_add->op());
EXPECT_EQ(4, collapsed_add->input_size());
@@ -1496,8 +1592,8 @@ TEST_F(ArithmeticOptimizerTest, AddOpsRewrite_AddOpsOfSymbolicallyEqualShape) {
// check add tree was replaced with AddN
const NodeDef* collapsed_add =
- node_map.GetNode("ArithmeticOptimizer/AddOpsRewrite_Add_abc_Add_ab");
- ASSERT_TRUE(collapsed_add != nullptr);
+ node_map.GetNode("ArithmeticOptimizer/AddOpsRewrite_Add_abc");
+ ASSERT_NE(collapsed_add, nullptr);
EXPECT_EQ("AddN", collapsed_add->op());
EXPECT_EQ(3, collapsed_add->input_size());
EXPECT_EQ("a", collapsed_add->input(0));
@@ -1506,10 +1602,173 @@ TEST_F(ArithmeticOptimizerTest, AddOpsRewrite_AddOpsOfSymbolicallyEqualShape) {
// check output was re-wired to new node
const NodeDef* updated_outputs = node_map.GetNode("outputs");
- ASSERT_TRUE(updated_outputs != nullptr);
+ ASSERT_NE(updated_outputs, nullptr);
EXPECT_EQ(collapsed_add->name(), updated_outputs->input(0));
}
+TEST_F(ArithmeticOptimizerTest, AddOpsRewrite_MinimizeBCast) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+
+ auto a = ops::Variable(s.WithOpName("a"), {32}, DT_FLOAT);
+ auto b = ops::Variable(s.WithOpName("b"), {32, 32}, DT_FLOAT);
+ auto c = ops::Variable(s.WithOpName("c"), {32, 32, 32}, DT_FLOAT);
+ auto add_ab = ops::Add(s.WithOpName("Add_ab"), a, b);
+ auto add_abc = ops::Add(s.WithOpName("Add_abc"), add_ab, c);
+
+ auto x = ops::Variable(s.WithOpName("x"), {32}, DT_FLOAT);
+ auto y = ops::Variable(s.WithOpName("y"), {32, 32}, DT_FLOAT);
+ auto z = ops::Variable(s.WithOpName("z"), {32, 32, 32}, DT_FLOAT);
+ auto add_xy = ops::Add(s.WithOpName("Add_xy"), x, y);
+ auto add_xyz = ops::Add(s.WithOpName("Add_xyz"), add_xy, z);
+
+ auto add_all = ops::Add(s.WithOpName("AddAll"), add_abc, add_xyz);
+ auto outputs = ops::Identity(s.WithOpName("outputs"), add_all);
+
+ GrapplerItem item;
+ item.fetch = {"outputs"};
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ GraphDef output;
+ ArithmeticOptimizer optimizer;
+ EnableOnlyAddToAddNCombining(&optimizer);
+
+ OptimizeAndPrune(&optimizer, &item, &output);
+
+ // We expect the following rewrite(s) to occur:
+ // 1) [a, x], [b, y], [c, z] - aggregate same shapes first
+ // 2) Build an aggregation tree minimizing cost of broadcast
+ //
+ //        +                            +
+ //       / \                          / \
+ //      +   +                        +   AddN(c, z)
+ //     / \ / \                      / \
+ //    +  c x  +        -->  AddN(a, x) AddN(b, y)
+ //   / \     / \
+ //  a   b   y   z
+ EXPECT_EQ(12, output.node_size());
+ NodeMap node_map(&output);
+
+ // expected names of outer and inner nodes
+ string outer_add_name = "ArithmeticOptimizer/AddOpsRewrite_AddAll";
+ string outer_0_add_name =
+ "ArithmeticOptimizer/AddOpsRewrite_Internal_0_AddAll";
+ string inner_0_add_name = "ArithmeticOptimizer/AddOpsRewrite_Leaf_0_AddAll";
+ string inner_1_add_name = "ArithmeticOptimizer/AddOpsRewrite_Leaf_1_AddAll";
+ string inner_2_add_name = "ArithmeticOptimizer/AddOpsRewrite_Leaf_2_AddAll";
+
+ // Add [a, x] first
+ const NodeDef* add_ax_node = node_map.GetNode(inner_0_add_name);
+ ASSERT_NE(add_ax_node, nullptr);
+ EXPECT_EQ("AddN", add_ax_node->op());
+ EXPECT_EQ(2, add_ax_node->input_size());
+ EXPECT_EQ("a", add_ax_node->input(0));
+ EXPECT_EQ("x", add_ax_node->input(1));
+
+ // Then add [b, y]
+ const NodeDef* add_by_node = node_map.GetNode(inner_1_add_name);
+ ASSERT_NE(add_by_node, nullptr);
+ EXPECT_EQ("AddN", add_by_node->op());
+ EXPECT_EQ(2, add_by_node->input_size());
+ EXPECT_EQ("b", add_by_node->input(0));
+ EXPECT_EQ("y", add_by_node->input(1));
+
+ // Then add [c, z]
+ const NodeDef* add_cz_node = node_map.GetNode(inner_2_add_name);
+ ASSERT_NE(add_cz_node, nullptr);
+ EXPECT_EQ("AddN", add_cz_node->op());
+ EXPECT_EQ(2, add_cz_node->input_size());
+ EXPECT_EQ("c", add_cz_node->input(0));
+ EXPECT_EQ("z", add_cz_node->input(1));
+
+ // Then add results together starting from smaller shapes [a, x] + [b, y]
+ const NodeDef* outer_0_node = node_map.GetNode(outer_0_add_name);
+ ASSERT_NE(outer_0_node, nullptr);
+ EXPECT_EQ("Add", outer_0_node->op());
+ EXPECT_EQ(2, outer_0_node->input_size());
+ EXPECT_EQ(inner_0_add_name, outer_0_node->input(0));
+ EXPECT_EQ(inner_1_add_name, outer_0_node->input(1));
+
+ // And finally top level Add node
+ const NodeDef* outer_node = node_map.GetNode(outer_add_name);
+ ASSERT_NE(outer_node, nullptr);
+ EXPECT_EQ("Add", outer_node->op());
+ EXPECT_EQ(2, outer_node->input_size());
+ EXPECT_EQ(outer_0_add_name, outer_node->input(0));
+ EXPECT_EQ(inner_2_add_name, outer_node->input(1));
+
+ // And outputs reading new top level Add node
+ const NodeDef* updated_outputs = node_map.GetNode("outputs");
+ ASSERT_NE(updated_outputs, nullptr);
+ EXPECT_EQ(outer_add_name, updated_outputs->input(0));
+}
+
+TEST_F(ArithmeticOptimizerTest, AddOpsRewrite_MinimizeBCastWithSymbolicShapes) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+
+ // We have a small input with one unknown dimension
+ auto small = ops::Variable(s.WithOpName("small"), {-1, 1, 1}, DT_FLOAT);
+
+ // And a second input which is larger, but has the same unknown dimension.
+ // The device spec prevents the "large" node from being rewritten.
+ auto d = "/job:do_not_rewrite_me";
+ auto v = ops::Variable(s.WithOpName("v"), {1, 32, 32}, DT_FLOAT);
+ auto large = ops::Add(s.WithOpName("large").WithDevice(d), small, v);
+
+ // [a, c] have {?, 1, 1} shape, [b] has {?, 32, 32}
+ auto a = ops::Sqrt(s.WithOpName("a"), small);
+ auto b = ops::Square(s.WithOpName("b"), large);
+ auto c = ops::Round(s.WithOpName("c"), small);
+
+ // [add_ab, add_abc] shape must be inferred from inputs
+ auto add_ab = ops::Add(s.WithOpName("Add_ab"), a, b);
+ auto add_abc = ops::Add(s.WithOpName("Add_abc"), add_ab, c);
+
+ auto outputs = ops::Identity(s.WithOpName("outputs"), add_abc);
+
+ GrapplerItem item;
+ item.fetch = {"outputs"};
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ GraphDef output;
+ ArithmeticOptimizer optimizer;
+ EnableOnlyAddToAddNCombining(&optimizer);
+ OptimizeAndPrune(&optimizer, &item, &output);
+
+ // We expect the following rewrite(s) to occur: it's much cheaper to add small
+ // tensors, and do the broadcast just once
+ //
+ //     +                 +
+ //    / \               / \
+ //   +   c    -->      +   b
+ //  / \               / \
+ // a   b             a   c
+ EXPECT_EQ(9, output.node_size());
+ NodeMap node_map(&output);
+
+ // expected names of outer and inner nodes
+ string outer_add_name = "ArithmeticOptimizer/AddOpsRewrite_Add_abc";
+ string inner_add_name = "ArithmeticOptimizer/AddOpsRewrite_Leaf_0_Add_abc";
+
+ // outer Add node
+ const NodeDef* outer_add = node_map.GetNode(outer_add_name);
+ ASSERT_NE(outer_add, nullptr);
+ EXPECT_EQ("Add", outer_add->op());
+ EXPECT_EQ(inner_add_name, outer_add->input(0));
+ EXPECT_EQ("b", outer_add->input(1));
+
+ // inner AddN node
+ const NodeDef* inner_add = node_map.GetNode(inner_add_name);
+ ASSERT_NE(inner_add, nullptr);
+ EXPECT_EQ(2, inner_add->input_size());
+ EXPECT_EQ("a", inner_add->input(0));
+ EXPECT_EQ("c", inner_add->input(1));
+
+ // check output was re-wired to new node
+ const NodeDef* updated_outputs = node_map.GetNode("outputs");
+ ASSERT_NE(updated_outputs, nullptr);
+ EXPECT_EQ(outer_add_name, updated_outputs->input(0));
+}
+
TEST_F(ArithmeticOptimizerTest, RemoveNegation) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
auto x = ops::Variable(s.WithOpName("x"), {2, 2}, DT_FLOAT);
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc
index 7de544de52..d941a0b3f9 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -747,10 +747,6 @@ bool ConstantFolding::IsFoldable(const NodeDef& node) const {
if (op.find("Quantized") != string::npos || op.find("Sparse") == 0) {
return false;
}
- if (node.attr().count("_XlaCompile") > 0 &&
- node.attr().at("_XlaCompile").b()) {
- return false;
- }
const OpDef* op_def = nullptr;
Status status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def);
@@ -777,7 +773,7 @@ bool ConstantFolding::IsFoldable(const NodeDef& node) const {
// the case of a merge node that propagates the first input that becomes
// available, and therefore only requires a single constant input to be
// foldable.
- bool has_constant_input = false;
+ bool merge_has_constant_input = false;
const bool is_merge = IsMerge(node);
for (const auto& input : node.input()) {
if (IsControlInput(input)) {
@@ -788,21 +784,20 @@ bool ConstantFolding::IsFoldable(const NodeDef& node) const {
return false;
}
bool is_const = IsReallyConstant(*input_node);
- if (!is_const && !is_merge) {
- return false;
- }
- // Don't fold strings constants for now since this causes problems with
- // checkpointing.
- if (is_const && input_node->attr().at("dtype").type() == DT_STRING) {
+ if (is_const) {
+ // Don't fold string constants for now since this causes problems with
+ // checkpointing.
+ if (input_node->attr().at("dtype").type() == DT_STRING) {
+ return false;
+ }
+ // Special case: If a Merge node has at least one constant input that
+ // does not depend on a control input, we can fold it.
+ merge_has_constant_input |= !HasControlInputs(*input_node);
+ } else if (!is_merge) {
return false;
}
- has_constant_input |= is_const;
- }
- if (is_merge) {
- return has_constant_input;
}
-
- return true;
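+ // A non-Merge node is foldable once the loop above has verified that all
+ // of its data inputs are constant; a Merge additionally needs a
+ // control-free constant input.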
+ return !is_merge || merge_has_constant_input;
}
namespace {
@@ -1542,6 +1537,16 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph,
for (int i = 0; i < optimized_graph->node_size(); ++i) {
NodeDef* node = optimized_graph->mutable_node(i);
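+ // A Split with num_split == 1 is a no-op; forward its value, which is
+ // input 1 (input 0 is the split dimension).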
+ if (IsSplit(*node) && node->attr().at("num_split").i() == 1) {
+ ReplaceOperationWithIdentity(1, node, optimized_graph);
+ continue;
+ }
+
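+ // Same for SplitV, whose value is input 0 (followed by size_splits and
+ // the split dimension).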
+ if (IsSplitV(*node) && node->attr().at("num_split").i() == 1) {
+ ReplaceOperationWithIdentity(0, node, optimized_graph);
+ continue;
+ }
+
// Remove Shuffle or Reverse op over scalar values.
if (use_shape_info &&
!properties->GetInputProperties(node->name()).empty() &&
@@ -1708,9 +1713,11 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph,
}
// Move constants past Enter.
- // TODO(rmlarsen): Reenable when we fix the root cause of b/76008022
- if (opt_level_ == RewriterConfig::AGGRESSIVE && IsEnter(*node) &&
- node->input_size() > 0) {
+ if (IsEnter(*node) && node->input_size() > 0) {
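+ // Only an Enter marked is_constant is loop invariant, so only then is
+ // it safe to push a constant past it into the frame.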
+ if (node->attr().count("is_constant") == 0 ||
+ !node->attr().at("is_constant").b()) {
+ continue;
+ }
const string& node_name = node->name();
const NodeDef* input = node_map_->GetNode(node->input(0));
if (input != nullptr && IsReallyConstant(*input) &&
@@ -1739,7 +1746,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph,
node_map_->AddOutput(node_name, new_node->name());
for (NodeDef* consumer : consumers) {
for (int i = 0; i < consumer->input_size(); ++i) {
- if (consumer->input(i) == node_name) {
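+ // Compare with NodeName() so control inputs ("^name") and inputs with
+ // an explicit output port ("name:1") get rewired as well.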
+ if (NodeName(consumer->input(i)) == node_name) {
node_map_->UpdateInput(consumer->name(), node_name,
new_node->name());
consumer->set_input(i, new_node->name());
diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
index 1db4fb9de7..71ee81dfde 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
@@ -83,14 +83,6 @@ class ConstantFoldingTest : public GrapplerTest {
}
};
-template <DataType DTYPE>
-Tensor GetRandomTensor(const TensorShape& shape) {
- typedef typename EnumToDataType<DTYPE>::Type T;
- Tensor tensor(DTYPE, shape);
- tensor.flat<T>() = tensor.flat<T>().random();
- return tensor;
-}
-
TEST_F(ConstantFoldingTest, SimpleFolding) {
// Build a simple graph with a few trivially prunable ops.
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
@@ -380,11 +372,11 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
EXPECT_EQ(2, t.tensor_shape().dim(1).size());
}
}
- auto a_t = GetRandomTensor<DT_FLOAT>(TensorShape({3, 2}));
- auto b_t = GetRandomTensor<DT_FLOAT>(TensorShape({2, 3}));
- auto x_t = GetRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
- auto y_t = GetRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
- auto bias_t = GetRandomTensor<DT_FLOAT>(TensorShape({2}));
+ auto a_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 2}));
+ auto b_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 3}));
+ auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
+ auto y_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
+ auto bias_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2}));
auto tensors_expected = EvaluateNodes(
item.graph, item.fetch,
@@ -1264,6 +1256,10 @@ TEST_F(ConstantFoldingTest, MergeNodes) {
ops::Merge m1(scope.WithOpName("m1"), {x, const1, const2});
ops::Merge m2(scope.WithOpName("m2"), {const1, const3});
ops::Merge m3(scope.WithOpName("m3"), {x, y});
+ // m4 is not foldable because the only constant input
+ // has a control input, so we cannot know if it will be
+ // triggered.
+ ops::Merge m4(scope.WithOpName("m4"), {x, const1});
ops::Identity out1(scope.WithOpName("out1"), m1.output);
ops::Identity idx1(scope.WithOpName("idx1"), m1.value_index);
@@ -1271,9 +1267,11 @@ TEST_F(ConstantFoldingTest, MergeNodes) {
ops::Identity idx2(scope.WithOpName("idx2"), m2.value_index);
ops::Identity out3(scope.WithOpName("out3"), m3.output);
ops::Identity idx3(scope.WithOpName("idx3"), m3.value_index);
+ ops::Identity out4(scope.WithOpName("out4"), m4.output);
+ ops::Identity idx4(scope.WithOpName("idx4"), m4.value_index);
GrapplerItem item;
- item.fetch = {"out1", "idx1", "out2", "idx2", "out3", "idx3"};
+ item.fetch = {"out1", "idx1", "out2", "idx2", "out3", "idx3", "out4", "idx4"};
TF_CHECK_OK(scope.ToGraphDef(&item.graph));
ConstantFolding optimizer(nullptr /* cpu_device */);
@@ -1281,6 +1279,7 @@ TEST_F(ConstantFoldingTest, MergeNodes) {
Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
+ EXPECT_EQ(19, output.node_size());
int found_nodes = 0;
for (const auto& node : output.node()) {
if (node.name() == "out1") {
@@ -1317,10 +1316,18 @@ TEST_F(ConstantFoldingTest, MergeNodes) {
EXPECT_EQ(1, node.input_size());
EXPECT_EQ("m3:1", node.input(0));
++found_nodes;
+ } else if (node.name() == "out4") {
+ EXPECT_EQ(1, node.input_size());
+ EXPECT_EQ("m4", node.input(0));
+ ++found_nodes;
+ } else if (node.name() == "idx4") {
+ EXPECT_EQ(1, node.input_size());
+ EXPECT_EQ("m4:1", node.input(0));
+ ++found_nodes;
}
}
// Make sure the graph contains all the nodes we're expecting.
- EXPECT_EQ(6, found_nodes);
+ EXPECT_EQ(8, found_nodes);
std::vector<string> fetch = {"out1", "idx1"};
auto tensors = EvaluateNodes(output, fetch);
@@ -1335,6 +1342,82 @@ TEST_F(ConstantFoldingTest, MergeNodes) {
EXPECT_EQ(2, out_idx.flat<int32>()(0));
}
+TEST_F(ConstantFoldingTest, SplitRemoval) {
+ tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
+
+ Output in1 =
+ ops::Variable(scope.WithOpName("in1"), TensorShape({2}), DT_FLOAT);
+ Output in2 =
+ ops::Variable(scope.WithOpName("in2"), TensorShape({4}), DT_FLOAT);
+ auto split_dim = ops::Const(scope.WithOpName("split_dim"), {0}, {});
+ ops::Split s1(scope.WithOpName("s1"), split_dim, in1, 1);
+ ops::Split s2(scope.WithOpName("s2"), split_dim, in2, 2);
+
+ ops::Add out(scope.WithOpName("out"), s1[0], s2[0]);
+
+ GrapplerItem item;
+ item.fetch = {"out"};
+ TF_CHECK_OK(scope.ToGraphDef(&item.graph));
+
+ ConstantFolding optimizer(nullptr /* cpu_device */);
+ GraphDef got;
+ Status status = optimizer.Optimize(nullptr, item, &got);
+ TF_EXPECT_OK(status);
+
+ GraphDef want;
+ AddNode("in1", "VariableV2", {}, {}, &want);
+ AddNode("in2", "VariableV2", {}, {}, &want);
+ AddNode("split_dim", "Const", {}, {}, &want);
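+ // s1 collapses into an Identity of in1 that keeps split_dim only as a
+ // control dependency.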
+ AddNode("s1", "Identity", {"in1", AsControlDependency("split_dim")}, {},
+ &want);
+ AddNode("s2", "Split", {"in2", "split_dim"}, {}, &want);
+ AddNode("out", "Add", {"s1", "s2"}, {}, &want);
+
+ CompareGraphs(want, got);
+}
+
+TEST_F(ConstantFoldingTest, SplitVRemoval) {
+ tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
+
+ Output in1 =
+ ops::Variable(scope.WithOpName("in1"), TensorShape({2}), DT_FLOAT);
+ Output in2 =
+ ops::Variable(scope.WithOpName("in2"), TensorShape({5}), DT_FLOAT);
+ auto split_dim = ops::Const(scope.WithOpName("split_dim"), {0}, {});
+ auto size_splits1 = ops::Const(scope.WithOpName("size_splits1"), {2}, {1});
+ auto size_splits2 = ops::Const(scope.WithOpName("size_splits2"), {2, 3}, {2});
+ ops::SplitV s1(scope.WithOpName("s1"), in1, size_splits1, split_dim, 1);
+ ops::SplitV s2(scope.WithOpName("s2"), in2, size_splits2, split_dim, 2);
+
+ ops::Add out(scope.WithOpName("out"), s1[0], s2[0]);
+
+ GrapplerItem item;
+ item.fetch = {"out"};
+ TF_CHECK_OK(scope.ToGraphDef(&item.graph));
+
+ ConstantFolding optimizer(nullptr /* cpu_device */);
+ GraphDef got;
+ Status status = optimizer.Optimize(nullptr, item, &got);
+ TF_EXPECT_OK(status);
+
+ GraphDef want;
+ AddNode("in1", "VariableV2", {}, {}, &want);
+ AddNode("in2", "VariableV2", {}, {}, &want);
+ AddNode("split_dim", "Const", {}, {}, &want);
+ AddNode("size_splits1", "Const", {}, {}, &want);
+ AddNode("size_splits2", "Const", {}, {}, &want);
+ AddNode("s1", "Identity",
+ {"in1", AsControlDependency("size_splits1"),
+ AsControlDependency("split_dim")},
+ {}, &want);
+ AddNode("s2", "SplitV", {"in2", "size_splits2", "split_dim"}, {}, &want);
+ AddNode("out", "Add", {"s1", "s2"}, {}, &want);
+
+ CompareGraphs(want, got);
+}
+
TEST_F(ConstantFoldingTest, ShuffleReverseOnScalarRemoval) {
tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
@@ -2252,6 +2335,10 @@ TEST_F(ConstantFoldingTest, Enter) {
GrapplerItem item;
AttrValue frame_name;
frame_name.set_s("foo");
+ AttrValue is_constant_true;
+ is_constant_true.set_b(true);
+ AttrValue is_constant_false;
+ is_constant_false.set_b(false);
AttrValue type;
type.set_type(DT_FLOAT);
AttrValue value;
@@ -2262,19 +2349,31 @@ TEST_F(ConstantFoldingTest, Enter) {
GraphDef& graph = item.graph;
AddNode("x", "Placeholder", {}, {{"T", type}}, &graph);
AddNode("c1", "Const", {"^x"}, {{"value", value}, {"dtype", type}}, &graph);
- AddNode("enter1", "Enter", {"x"}, {{"T", type}, {"frame_name", frame_name}},
+ AddNode("enter1", "Enter", {"x"},
+ {{"T", type},
+ {"frame_name", frame_name},
+ {"is_constant", is_constant_true}},
&graph);
- AddNode("enter2", "Enter", {"c1"}, {{"T", type}, {"frame_name", frame_name}},
+ AddNode("enter2", "Enter", {"c1"},
+ {{"T", type},
+ {"frame_name", frame_name},
+ {"is_constant", is_constant_true}},
+ &graph);
+ AddNode("enter3", "Enter", {"c1"},
+ {{"T", type},
+ {"frame_name", frame_name},
+ {"is_constant", is_constant_false}},
&graph);
AddNode("id1", "Identity", {"enter1"}, {{"T", type}}, &graph);
AddNode("id2", "Identity", {"enter2"}, {{"T", type}}, &graph);
AddNode("id3", "Identity", {"enter2"}, {{"T", type}}, &graph);
+ AddNode("id4", "Identity", {"enter3"}, {{"T", type}}, &graph);
item.fetch.push_back("id1");
item.fetch.push_back("id2");
item.fetch.push_back("id3");
+ item.fetch.push_back("id4");
- ConstantFolding optimizer(RewriterConfig::AGGRESSIVE,
- nullptr /* cpu_device */);
+ ConstantFolding optimizer(nullptr /* cpu_device */);
GraphDef output;
Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
@@ -2283,7 +2382,7 @@ TEST_F(ConstantFoldingTest, Enter) {
status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
- EXPECT_EQ(7, output.node_size());
+ EXPECT_EQ(9, output.node_size());
for (const NodeDef& node : output.node()) {
if (node.name() == "id1") {
EXPECT_EQ("Identity", node.op());
@@ -2295,6 +2394,11 @@ TEST_F(ConstantFoldingTest, Enter) {
EXPECT_EQ(1, node.input_size());
EXPECT_EQ("^enter2", node.input(0));
}
+ if (node.name() == "id4") {
+ EXPECT_EQ("Identity", node.op());
+ EXPECT_EQ(1, node.input_size());
+ EXPECT_EQ("enter3", node.input(0));
+ }
}
}
diff --git a/tensorflow/core/grappler/optimizers/debug_stripper.cc b/tensorflow/core/grappler/optimizers/debug_stripper.cc
index 0e058e3435..8bd10171f1 100644
--- a/tensorflow/core/grappler/optimizers/debug_stripper.cc
+++ b/tensorflow/core/grappler/optimizers/debug_stripper.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/grappler/optimizers/debug_stripper.h"
+#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/grappler/clusters/cluster.h"
#include "tensorflow/core/grappler/grappler_item.h"
@@ -39,6 +40,10 @@ Status DebugStripper::Optimize(Cluster* cluster, const GrapplerItem& item,
inp = AsControlDependency(inp);
}
}
+ } else if (IsCheckNumerics(node)) {
+ // Replace with Identity op which will be pruned later.
+ node.set_op("Identity");
+ node.mutable_attr()->erase("message");
}
}
return Status::OK();
diff --git a/tensorflow/core/grappler/optimizers/debug_stripper_test.cc b/tensorflow/core/grappler/optimizers/debug_stripper_test.cc
index aacd55f136..3f11febc64 100644
--- a/tensorflow/core/grappler/optimizers/debug_stripper_test.cc
+++ b/tensorflow/core/grappler/optimizers/debug_stripper_test.cc
@@ -16,7 +16,7 @@ limitations under the License.
#include "tensorflow/core/grappler/optimizers/debug_stripper.h"
#include "tensorflow/cc/ops/standard_ops.h"
-#include "tensorflow/core/framework/function_testlib.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/utils/grappler_test.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -29,14 +29,13 @@ namespace {
class DebugStripperTest : public GrapplerTest {};
TEST_F(DebugStripperTest, OutputEqualToInput) {
- constexpr char device[] = "/device:CPU:0";
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output x = ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({}));
+ Output y = ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({}));
+ Output add = ops::Add(s, x, y);
+ Output result = ops::Identity(s, add);
GrapplerItem item;
- item.graph = test::function::GDef(
- {test::function::NDef("x", "Placeholder", {}, {{"dtype", DT_FLOAT}},
- device),
- test::function::NDef("y", "XTimesTwo", {"x"}, {}, device),
- test::function::NDef("z", "Identity", {"y"}, {{"T", DT_FLOAT}}, device)},
- {});
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
DebugStripper optimizer;
GraphDef output;
@@ -45,19 +44,17 @@ TEST_F(DebugStripperTest, OutputEqualToInput) {
}
TEST_F(DebugStripperTest, StripAssertFromGraph) {
- constexpr char device[] = "/device:CPU:0";
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
+ ops::Placeholder::Shape({}));
+ Output y = ops::Placeholder(s.WithOpName("y"), DT_FLOAT,
+ ops::Placeholder::Shape({}));
+ auto greaterequal = ops::GreaterEqual(s.WithOpName("GreaterEqual"), x, y);
+ auto assert = ops::Assert(s.WithOpName("Assert"), greaterequal, {x, y});
+ Output add = ops::Add(
+ s.WithOpName("z").WithControlDependencies({assert.operation}), x, y);
GrapplerItem item;
- item.graph = test::function::GDef(
- {test::function::NDef("x", "Placeholder", {}, {{"dtype", DT_FLOAT}},
- device),
- test::function::NDef("y", "Placeholder", {}, {{"dtype", DT_FLOAT}},
- device),
- test::function::NDef("GreaterEqual", "GreaterEqual", {"x", "y"},
- {{"T", DT_FLOAT}}, device),
- test::function::NDef("Assert", "Assert", {"GreaterEqual"},
- {{"T", DT_FLOAT}}, device),
- test::function::NDef("z", "Add", {"x", "y", "^Assert"}, {}, device)},
- {});
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
DebugStripper optimizer;
GraphDef output;
@@ -68,31 +65,27 @@ TEST_F(DebugStripperTest, StripAssertFromGraph) {
if (node.name() == "x") {
count++;
EXPECT_EQ("Placeholder", node.op());
- EXPECT_EQ(device, node.device());
EXPECT_EQ(0, node.input_size());
} else if (node.name() == "y") {
count++;
EXPECT_EQ("Placeholder", node.op());
- EXPECT_EQ(device, node.device());
EXPECT_EQ(0, node.input_size());
} else if (node.name() == "GreaterEqual") {
count++;
EXPECT_EQ("GreaterEqual", node.op());
- EXPECT_EQ(device, node.device());
EXPECT_EQ(2, node.input_size());
EXPECT_EQ("x", node.input(0));
EXPECT_EQ("y", node.input(1));
} else if (node.name() == "Assert") {
count++;
EXPECT_EQ("NoOp", node.op());
- EXPECT_EQ(device, node.device());
- EXPECT_EQ(1, node.input_size());
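+ // The stripper turns Assert into a NoOp and rewrites its data inputs as
+ // control dependencies.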
+ EXPECT_EQ(3, node.input_size());
EXPECT_EQ("^GreaterEqual", node.input(0));
- EXPECT_EQ(0, node.attr_size());
+ EXPECT_EQ("^x", node.input(1));
+ EXPECT_EQ("^y", node.input(2));
} else if (node.name() == "z") {
count++;
EXPECT_EQ("Add", node.op());
- EXPECT_EQ(device, node.device());
EXPECT_EQ(3, node.input_size());
EXPECT_EQ("x", node.input(0));
EXPECT_EQ("y", node.input(1));
@@ -100,6 +93,75 @@ TEST_F(DebugStripperTest, StripAssertFromGraph) {
}
}
EXPECT_EQ(5, count);
+
+ Tensor x_t(DT_FLOAT, TensorShape({}));
+ Tensor y_t(DT_FLOAT, TensorShape({}));
+ x_t.flat<float>()(0) = 1.0f;
+ y_t.flat<float>()(0) = 0.5f;
+ std::vector<Tensor> expected =
+ EvaluateNodes(item.graph, {"z"}, {{"x", x_t}, {"y", y_t}});
+ std::vector<Tensor> optimized =
+ EvaluateNodes(output, {"z"}, {{"x", x_t}, {"y", y_t}});
+ test::ExpectTensorEqual<float>(expected[0], optimized[0]);
+}
+
+TEST_F(DebugStripperTest, StripCheckNumericsFromGraph) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
+ ops::Placeholder::Shape({}));
+ Output y = ops::Placeholder(s.WithOpName("y"), DT_FLOAT,
+ ops::Placeholder::Shape({}));
+ auto check1 = ops::CheckNumerics(s.WithOpName("CheckNumerics1"), x, "foo");
+ auto check2 = ops::CheckNumerics(s.WithOpName("CheckNumerics2"), y, "foo");
+ Output add = ops::Add(s.WithOpName("z"), check1, check2);
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ DebugStripper optimizer;
+ GraphDef output;
+ TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ int count = 0;
+ for (const NodeDef& node : output.node()) {
+ if (node.name() == "x") {
+ count++;
+ EXPECT_EQ("Placeholder", node.op());
+ EXPECT_EQ(0, node.input_size());
+ } else if (node.name() == "y") {
+ count++;
+ EXPECT_EQ("Placeholder", node.op());
+ EXPECT_EQ(0, node.input_size());
+ } else if (node.name() == "CheckNumerics1") {
+ count++;
+ EXPECT_EQ("Identity", node.op());
+ EXPECT_EQ(1, node.input_size());
+ EXPECT_EQ("x", node.input(0));
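+ // Only the "T" attr remains once the stripper erases "message".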
+ EXPECT_EQ(1, node.attr_size());
+ } else if (node.name() == "CheckNumerics2") {
+ count++;
+ EXPECT_EQ("Identity", node.op());
+ EXPECT_EQ(1, node.input_size());
+ EXPECT_EQ("y", node.input(0));
+ EXPECT_EQ(1, node.attr_size());
+ } else if (node.name() == "z") {
+ count++;
+ EXPECT_EQ("Add", node.op());
+ EXPECT_EQ(2, node.input_size());
+ EXPECT_EQ("CheckNumerics1", node.input(0));
+ EXPECT_EQ("CheckNumerics2", node.input(1));
+ }
+ }
+ EXPECT_EQ(5, count);
+
+ Tensor x_t(DT_FLOAT, TensorShape({}));
+ Tensor y_t(DT_FLOAT, TensorShape({}));
+ x_t.flat<float>()(0) = 1.0f;
+ y_t.flat<float>()(0) = 0.5f;
+ std::vector<Tensor> expected =
+ EvaluateNodes(item.graph, {"z"}, {{"x", x_t}, {"y", y_t}});
+ std::vector<Tensor> optimized =
+ EvaluateNodes(output, {"z"}, {{"x", x_t}, {"y", y_t}});
+ test::ExpectTensorEqual<float>(expected[0], optimized[0]);
}
} // namespace
diff --git a/tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc b/tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc
index 57b3118245..6a297da52d 100644
--- a/tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc
@@ -678,6 +678,50 @@ TEST_F(DependencyOptimizerTest, Identity_DeviceCrossing_ConsumerOnSameDevice) {
}
}
+TEST_F(DependencyOptimizerTest, RemoveGreaterEqualWithNoOp) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
+ ops::Placeholder::Shape({}));
+ Output y = ops::Placeholder(s.WithOpName("y"), DT_FLOAT,
+ ops::Placeholder::Shape({}));
+ auto greaterequal = ops::GreaterEqual(s.WithOpName("GreaterEqual"), x, y);
+ auto noop =
+ ops::NoOp(s.WithOpName("NoOp").WithControlDependencies(greaterequal));
+ Output add = ops::Add(
+ s.WithOpName("z").WithControlDependencies({noop.operation}), x, y);
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ DependencyOptimizer optimizer;
+ GraphDef output;
+ item.fetch.push_back("z");
+ TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+
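+ // The NoOp and the GreaterEqual feeding it are no longer needed and
+ // should have been pruned, leaving only x, y and z.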
+ int count = 0;
+ for (const NodeDef& node : output.node()) {
+ if (node.name() == "x") {
+ count++;
+ EXPECT_EQ("Placeholder", node.op());
+ EXPECT_EQ(0, node.input_size());
+ } else if (node.name() == "y") {
+ count++;
+ EXPECT_EQ("Placeholder", node.op());
+ EXPECT_EQ(0, node.input_size());
+ } else if (node.name() == "GreaterEqual") {
+ count++;
+ } else if (node.name() == "NoOp") {
+ count++;
+ } else if (node.name() == "z") {
+ count++;
+ EXPECT_EQ("Add", node.op());
+ EXPECT_EQ(2, node.input_size());
+ EXPECT_EQ("x", node.input(0));
+ EXPECT_EQ("y", node.input(1));
+ }
+ }
+ EXPECT_EQ(3, count);
+}
+
} // namespace
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/function_optimizer.cc b/tensorflow/core/grappler/optimizers/function_optimizer.cc
index 2a6b8a325f..f1da469a6c 100644
--- a/tensorflow/core/grappler/optimizers/function_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/function_optimizer.cc
@@ -32,16 +32,129 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
+namespace {
+
+class FunctionInliningContext {
+ public:
+ explicit FunctionInliningContext(const GrapplerItem& item)
+ : library_(&item.graph.library()), functions_(InliningCandidates(item)) {}
+
+ const FunctionDefLibrary& Library() const { return *library_; }
+
+ bool HasInlinedFunctions() const { return !functions_.empty(); }
+
+ // Find inlining candidate by name. Return nullptr if not found.
+ const FunctionDef* FindInlinedFunction(const string& name) const {
+ auto it = functions_.find(name);
+ if (it != functions_.end()) {
+ return it->second;
+ } else {
+ return nullptr;
+ }
+ }
+
+ private:
+ std::unordered_map<string, const FunctionDef*> InliningCandidates(
+ const GrapplerItem& item) const {
+ std::unordered_map<string, const FunctionDef*> functions;
+ for (const FunctionDef& func : item.graph.library().function()) {
+ // Don't inline functions marked as noinline
+ if (func.attr().count("_noinline") != 0) {
+ continue;
+ }
+ // Don't touch anything marked XLA to prevent XLA failures further down
+ // the road.
+ if (func.attr().count("_XlaCompile") > 0 &&
+ func.attr().at("_XlaCompile").b()) {
+ continue;
+ }
+ // Can't create IdentityN nodes with no input or output: skip these
+ // functions for now.
+ if (func.signature().input_arg_size() == 0 ||
+ func.signature().output_arg_size() == 0) {
+ continue;
+ }
+ functions[func.signature().name()] = &func;
+ }
+ return functions;
+ }
+
+ const FunctionDefLibrary* library_;
+ std::unordered_map<string, const FunctionDef*> functions_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(FunctionInliningContext);
+};
+
+// Copy the input/output argument type to the type_list. Returns an error if
+// the argument type is neither explicitly defined nor specified in the
+// function attributes.
+Status CopyArgType(const NodeDef& func_node,
+ const std::unordered_map<string, AttrValue>& func_attr,
+ const string& arg_kind, const OpDef::ArgDef& arg,
+ AttrValue::ListValue* type_list) {
+ if (arg.type() != DT_INVALID) {
+ type_list->add_type(arg.type());
+ } else {
+ auto it = func_attr.find(arg.type_attr());
+ if (it == func_attr.end() || it->second.type() == DT_INVALID) {
+ return errors::InvalidArgument(
+ "Invalid ", arg_kind, " argument ", arg.name(), " for function ",
+ func_node.op(), " instantiated by ", func_node.name());
+ }
+ type_list->add_type(it->second.type());
+ }
+ return Status::OK();
+}
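+// For example (a sketch): for an argument declared as "x: T" (so arg.type()
+// is DT_INVALID and arg.type_attr() is "T") with func_attr = {"T": DT_FLOAT},
+// DT_FLOAT is appended to type_list.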
+
+// Add an IdentityN op to hook the function inputs to: this ensures that
+// they're all evaluated before the evaluation of the function body starts.
+Status HookInlinedFunctionInputs(
+ const NodeDef& func_node, const FunctionDef& func,
+ const std::unordered_map<string, AttrValue>& func_attr, NodeDef* inputs) {
+ inputs->set_name(strings::StrCat(func_node.name(), "/", "inlined_inputs"));
+ inputs->set_op("IdentityN");
+ inputs->set_device(func_node.device());
+ *inputs->mutable_input() = func_node.input();
+ AttrValue::ListValue* type_list =
+ (*inputs->mutable_attr())["T"].mutable_list();
+ for (const OpDef::ArgDef& arg : func.signature().input_arg()) {
+ TF_RETURN_IF_ERROR(
+ CopyArgType(func_node, func_attr, "input", arg, type_list));
+ }
+ return Status::OK();
+}
+
+// Add an IdentityN op to hook the function outputs to: this ensures that the
+// function body is fully evaluated before its fanout gets scheduled.
+Status HookInlinedFunctionOutputs(
+ const NodeDef& func_node, const FunctionDef& func,
+ const std::unordered_map<string, AttrValue>& func_attr,
+ const gtl::ArraySlice<string> fetch, NodeDef* outputs) {
+ outputs->set_name(func_node.name());
+ outputs->set_op("IdentityN");
+ outputs->set_device(func_node.device());
+ AttrValue::ListValue* type_list =
+ (*outputs->mutable_attr())["T"].mutable_list();
+ for (int i = 0; i < func.signature().output_arg_size(); ++i) {
+ const OpDef::ArgDef& arg = func.signature().output_arg(i);
+ TF_RETURN_IF_ERROR(
+ CopyArgType(func_node, func_attr, "output", arg, type_list));
+ // Use the fetch names since they take into account the output mapping.
+ outputs->add_input(strings::StrCat(func_node.name(), "/", fetch[i]));
+ }
+ return Status::OK();
+}
+
+Status InlineFunction(const NodeDef& func_node, const FunctionDef& func,
+ const FunctionInliningContext& ctx,
+ GraphDef* optimized_graph) {
+ const std::unordered_map<string, AttrValue> func_attr(
+ func_node.attr().begin(), func_node.attr().end());
-Status InlineFunction(const NodeDef& node, const FunctionDef& func,
- const FunctionDefLibrary& library, GraphDef* graph) {
- const std::unordered_map<string, AttrValue> attr(node.attr().begin(),
- node.attr().end());
std::unique_ptr<GrapplerItem> item =
- GrapplerItemFromFunctionDef(func, attr, library);
+ GrapplerItemFromFunctionDef(func, func_attr, ctx.Library());
if (!item) {
- return errors::InvalidArgument("Failed to inline function ", node.op(),
- " instantiated by ", node.name());
+ return errors::InvalidArgument("Failed to inline function ", func_node.op(),
+ " instantiated by ", func_node.name());
}
std::unordered_map<string, int> input_nodes;
@@ -50,43 +163,25 @@ Status InlineFunction(const NodeDef& node, const FunctionDef& func,
input_nodes[arg.name()] = i;
}
- // Add an IdentityN op to hook the function inputs to: this ensures that
- // they're all evaluated before the evaluation of the function body starts.
- NodeDef* func_inputs = graph->add_node();
- func_inputs->set_name(strings::StrCat(node.name(), "/", "inlined_inputs"));
- func_inputs->set_op("IdentityN");
- func_inputs->set_device(node.device());
- *func_inputs->mutable_input() = node.input();
- AttrValue::ListValue* type_list =
- (*func_inputs->mutable_attr())["T"].mutable_list();
- for (const OpDef::ArgDef& arg : func.signature().input_arg()) {
- if (arg.type() != DT_INVALID) {
- type_list->add_type(arg.type());
- } else {
- auto it = attr.find(arg.type_attr());
- if (it == attr.end()) {
- return errors::InvalidArgument("Invalid input argument ", arg.name(),
- " for function ", node.op(),
- " instantiated by ", node.name());
- }
- type_list->add_type(it->second.type());
- }
- }
+ // Hook inlined function inputs to IdentityN node
+ NodeDef* func_inputs = optimized_graph->add_node();
+ TF_RETURN_IF_ERROR(
+ HookInlinedFunctionInputs(func_node, func, func_attr, func_inputs));
for (NodeDef& func_body_node : *item->graph.mutable_node()) {
if (input_nodes.find(func_body_node.name()) != input_nodes.end()) {
+ CHECK_EQ(0, func_body_node.input_size());
// Turn input placeholders into identity nodes
if (IsPlaceholder(func_body_node)) {
func_body_node.set_op("Identity");
}
- CHECK_EQ(0, func_body_node.input_size());
int input_id = input_nodes[func_body_node.name()];
func_body_node.add_input(
strings::StrCat(func_inputs->name(), ":", input_id));
} else {
// Update the input names if any.
for (string& input : *func_body_node.mutable_input()) {
- input = AddPrefixToNodeName(input, node.name());
+ input = AddPrefixToNodeName(input, /*prefix=*/func_node.name());
}
      // If the node has no input, hook it up to the func_inputs node to
// ensure it runs in the same frame as the other nodes of the function
@@ -98,39 +193,29 @@ Status InlineFunction(const NodeDef& node, const FunctionDef& func,
// Add the node name as a prefix to avoid collisions after inlining
func_body_node.set_name(
- strings::StrCat(node.name(), "/", func_body_node.name()));
+ strings::StrCat(func_node.name(), "/", func_body_node.name()));
// Make sure the node is placed
- func_body_node.set_device(node.device());
-
- // Move the node to the main graph
- graph->add_node()->Swap(&func_body_node);
- }
-
- // Add an IdentityN op to hook the function outputs to: this ensures that the
- // function body is fully evaluated before its fanout gets scheduled.
- NodeDef* func_outputs = graph->add_node();
- func_outputs->set_name(node.name());
- func_outputs->set_op("IdentityN");
- func_outputs->set_device(node.device());
- type_list = (*func_outputs->mutable_attr())["T"].mutable_list();
- for (int i = 0; i < func.signature().output_arg_size(); ++i) {
- const OpDef::ArgDef& arg = func.signature().output_arg(i);
- if (arg.type() != DT_INVALID) {
- type_list->add_type(arg.type());
+ func_body_node.set_device(func_node.device());
+
+ // Check if a body node is itself a function
+ const FunctionDef* func_body_node_func =
+ ctx.FindInlinedFunction(func_body_node.op());
+ if (func_body_node_func != nullptr) {
+ // Recursively inline function calls
+ TF_RETURN_IF_ERROR(InlineFunction(func_body_node, *func_body_node_func,
+ ctx, optimized_graph));
} else {
- auto it = attr.find(arg.type_attr());
- if (it == attr.end()) {
- return errors::InvalidArgument("Invalid output argument ", arg.name(),
- " for function ", node.op(),
- " instantiated by ", node.name());
- }
- type_list->add_type(it->second.type());
+ // Move the node to the main graph
+ optimized_graph->add_node()->Swap(&func_body_node);
}
- // Use the fetch names since they take into account the output mapping.
- func_outputs->add_input(strings::StrCat(node.name(), "/", item->fetch[i]));
}
+ // Hook inlined function outputs to IdentityN node
+ NodeDef* func_outputs = optimized_graph->add_node();
+ TF_RETURN_IF_ERROR(HookInlinedFunctionOutputs(func_node, func, func_attr,
+ item->fetch, func_outputs));
+
return Status::OK();
}
@@ -278,31 +363,14 @@ Status InlineSymbolicGradient(const NodeDef& node, SymbolicGradientEnv* env,
return Status::OK();
}
+} // namespace
+
Status FunctionOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
GraphDef* optimized_graph) {
- std::unordered_map<string, const FunctionDef*> functions;
- for (const FunctionDef& func : item.graph.library().function()) {
- // Don't inline functions marked as noinline
- if (func.attr().count("_noinline") != 0) {
- continue;
- }
- // Don't touch anything marked XLA to prevent XLA failures further down the
- // road.
- if (func.attr().count("_XlaCompile") > 0 &&
- func.attr().at("_XlaCompile").b()) {
- continue;
- }
- // Can't create IdentityN nodes with no input or output: skip these
- // functions for now.
- if (func.signature().input_arg_size() == 0 ||
- func.signature().output_arg_size() == 0) {
- continue;
- }
- functions[func.signature().name()] = &func;
- }
+ FunctionInliningContext function_inlining_ctx(item);
- // Nothing to do.
- if (functions.empty()) {
+ // Nothing to do here.
+ if (!function_inlining_ctx.HasInlinedFunctions()) {
*optimized_graph = item.graph;
return Status::OK();
}
@@ -315,12 +383,14 @@ Status FunctionOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
TF_RETURN_IF_ERROR(InlineSymbolicGradient(node, &env, optimized_graph));
continue;
}
- auto it = functions.find(node.op());
- if (it == functions.end()) {
- *optimized_graph->add_node() = node;
+
+ const FunctionDef* func =
+ function_inlining_ctx.FindInlinedFunction(node.op());
+ if (func != nullptr) {
+ TF_RETURN_IF_ERROR(
+ InlineFunction(node, *func, function_inlining_ctx, optimized_graph));
} else {
- TF_RETURN_IF_ERROR(InlineFunction(node, *it->second, item.graph.library(),
- optimized_graph));
+ *optimized_graph->add_node() = node;
}
}
diff --git a/tensorflow/core/grappler/optimizers/function_optimizer_test.cc b/tensorflow/core/grappler/optimizers/function_optimizer_test.cc
index deb2fabded..c804d75756 100644
--- a/tensorflow/core/grappler/optimizers/function_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/function_optimizer_test.cc
@@ -26,7 +26,22 @@ namespace tensorflow {
namespace grappler {
namespace {
-class FunctionOptimizerTest : public GrapplerTest {};
+constexpr char kDevice[] = "/device:CPU:0";
+
+class FunctionOptimizerTest : public GrapplerTest {
+ protected:
+ Tensor MakeScalarTensor(float value) {
+ Tensor tensor(DT_FLOAT, {});
+ tensor.scalar<float>()() = value;
+ return tensor;
+ }
+
+ Tensor MakeScalarTensor(int value) {
+ Tensor tensor(DT_INT32, {});
+ tensor.scalar<int>()() = value;
+ return tensor;
+ }
+};
TEST_F(FunctionOptimizerTest, SimpleFunction) {
// Build a graph to compute y = XTimesTwo(x)
@@ -94,9 +109,8 @@ TEST_F(FunctionOptimizerTest, SimpleFunction) {
}
EXPECT_EQ(7, count);
+ Tensor pi = MakeScalarTensor(3.14f);
item.fetch = {"z"};
- Tensor pi(DT_FLOAT, {});
- pi.flat<float>()(0) = 3.14f;
item.feed.emplace_back("x", pi);
auto tensors_expected = EvaluateFetchNodes(item);
GrapplerItem optimized(item, std::move(output));
@@ -183,9 +197,8 @@ TEST_F(FunctionOptimizerTest, FixedTypeFunction) {
}
EXPECT_EQ(6, count);
+ Tensor pi = MakeScalarTensor(3.14f);
item.fetch = {"z"};
- Tensor pi(DT_FLOAT, {});
- pi.flat<float>()(0) = 3.14f;
item.feed.emplace_back("x", pi);
auto tensors_expected = EvaluateFetchNodes(item);
GrapplerItem optimized(item, std::move(output));
@@ -268,9 +281,8 @@ TEST_F(FunctionOptimizerTest, FunctionWithOutputMapping) {
}
EXPECT_EQ(6, count);
+ Tensor pi = MakeScalarTensor(3.14f);
item.fetch = {"z"};
- Tensor pi(DT_FLOAT, {});
- pi.flat<float>()(0) = 3.14f;
item.feed.emplace_back("x", pi);
auto tensors_expected = EvaluateFetchNodes(item);
GrapplerItem optimized(item, std::move(output));
@@ -325,18 +337,11 @@ TEST_F(FunctionOptimizerTest, FunctionWithInputForwarding) {
TF_EXPECT_OK(status);
item.fetch = {"z0", "z1", "z2"};
- Tensor in(DT_FLOAT, {});
- in.flat<float>()(0) = 3.14f;
- item.feed.emplace_back("x0", in);
- in.flat<float>()(0) = 2.7f;
- item.feed.emplace_back("x1", in);
- in.flat<float>()(0) = 1.0f;
- item.feed.emplace_back("x2", in);
- in.flat<float>()(0) = -1.0f;
- item.feed.emplace_back("x4", in);
- Tensor in_int(DT_INT32, {});
- in_int.flat<int>()(0) = 1234;
- item.feed.emplace_back("x3", in_int);
+ item.feed.emplace_back("x0", MakeScalarTensor(3.14f));
+ item.feed.emplace_back("x1", MakeScalarTensor(2.7f));
+ item.feed.emplace_back("x2", MakeScalarTensor(1.0f));
+ item.feed.emplace_back("x4", MakeScalarTensor(-1.0f));
+ item.feed.emplace_back("x3", MakeScalarTensor(1234));
auto tensors_expected = EvaluateFetchNodes(item);
GrapplerItem optimized(item, std::move(output));
auto tensors = EvaluateFetchNodes(optimized);
@@ -379,6 +384,100 @@ TEST_F(FunctionOptimizerTest, FunctionWithoutInput) {
EXPECT_EQ(item.graph.DebugString(), output.DebugString());
}
+TEST_F(FunctionOptimizerTest, InlineFunctionWithNestedFunctionCall) {
+ // Define square via function library:
+ // MySquare(x) = MyMul(x, x)
+
+ FunctionDef mul_func = FunctionDefHelper::Create(
+ "MyMul", {"x:T", "y:T"}, {"z:T"}, {"T: {float, double}"},
+ {{{"output"}, "Mul", {"x", "y"}, {{"T", "$T"}}}},
+ /* Mapping between function returns and function node outputs. */
+ {{"z", "output:z:0"}});
+
+ FunctionDef square_func = FunctionDefHelper::Create(
+ "MySquare", {"x:T"}, {"z:T"}, {"T: {float, double}"},
+ {{{"output"}, "MyMul", {"x", "x"}, {{"T", "$T"}}}},
+ /* Mapping between function returns and function node outputs. */
+ {{"z", "output:z:0"}});
+
+ GrapplerItem item;
+ item.graph = test::function::GDef(
+ {test::function::NDef("a", "Placeholder", {}, {{"dtype", DT_FLOAT}},
+ kDevice),
+ test::function::NDef("square", "MySquare", {"a"}, {{"T", DT_FLOAT}},
+ kDevice),
+ test::function::NDef("outputs", "Identity", {"square:0"},
+ {{"T", DT_FLOAT}}, kDevice)},
+ // FunctionLib
+ {mul_func, square_func});
+
+ GraphDef output;
+ FunctionOptimizer optimizer(RewriterConfig::AGGRESSIVE);
+ TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ int count = 0;
+ for (const NodeDef& node : output.node()) {
+ if (node.name() == "square/inlined_inputs" && count++) {
+ EXPECT_EQ("IdentityN", node.op());
+ EXPECT_EQ(kDevice, node.device());
+ EXPECT_EQ(1, node.input_size());
+ EXPECT_EQ("a", node.input(0));
+ } else if (node.name() == "square/x" && count++) {
+ EXPECT_EQ("Identity", node.op());
+ EXPECT_EQ(kDevice, node.device());
+ EXPECT_EQ(1, node.input_size());
+ EXPECT_EQ("square/inlined_inputs:0", node.input(0));
+ } else if (node.name() == "square/output/inlined_inputs" && count++) {
+ EXPECT_EQ("IdentityN", node.op());
+ EXPECT_EQ(kDevice, node.device());
+ EXPECT_EQ(2, node.input_size());
+ EXPECT_EQ("square/x", node.input(0));
+ EXPECT_EQ("square/x", node.input(1));
+ } else if (node.name() == "square/output/x" && count++) {
+ EXPECT_EQ("Identity", node.op());
+ EXPECT_EQ(kDevice, node.device());
+ EXPECT_EQ(1, node.input_size());
+ EXPECT_EQ("square/output/inlined_inputs:0", node.input(0));
+ } else if (node.name() == "square/output/y" && count++) {
+ EXPECT_EQ("Identity", node.op());
+ EXPECT_EQ(kDevice, node.device());
+ EXPECT_EQ(1, node.input_size());
+ EXPECT_EQ("square/output/inlined_inputs:1", node.input(0));
+ } else if (node.name() == "square/output/output" && count++) {
+ EXPECT_EQ("Mul", node.op());
+ EXPECT_EQ(kDevice, node.device());
+ EXPECT_EQ(2, node.input_size());
+ EXPECT_EQ("square/output/x", node.input(0));
+ EXPECT_EQ("square/output/y", node.input(1));
+ } else if (node.name() == "square/output" && count++) {
+ EXPECT_EQ("IdentityN", node.op());
+ EXPECT_EQ(kDevice, node.device());
+ EXPECT_EQ(1, node.input_size());
+ EXPECT_EQ("square/output/output:0", node.input(0));
+ } else if (node.name() == "square" && count++) {
+ EXPECT_EQ("IdentityN", node.op());
+ EXPECT_EQ(kDevice, node.device());
+ EXPECT_EQ(1, node.input_size());
+ EXPECT_EQ("square/output:0", node.input(0));
+ } else if (node.name() == "outputs" && count++) {
+ EXPECT_EQ("Identity", node.op());
+ EXPECT_EQ(kDevice, node.device());
+ EXPECT_EQ(1, node.input_size());
+ EXPECT_EQ("square:0", node.input(0));
+ }
+ }
+ EXPECT_EQ(9, count);
+
+ item.fetch = {"outputs"};
+ item.feed.emplace_back("a", MakeScalarTensor(2.0f));
+ auto tensors_expected = EvaluateFetchNodes(item);
+
+ GrapplerItem optimized(item, std::move(output));
+ auto tensors = EvaluateFetchNodes(optimized);
+
+ test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
+}
+
TEST_F(FunctionOptimizerTest, SymbolicGradients) {
tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
diff --git a/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h b/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h
index be95c00d2d..7ed0474861 100644
--- a/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h
+++ b/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h
@@ -21,7 +21,6 @@ limitations under the License.
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/utils.h"
-#include "tensorflow/core/grappler/utils/frame.h"
namespace tensorflow {
namespace grappler {
@@ -45,21 +44,16 @@ const NodeScopeAndName ParseNodeScopeAndName(const string& node_name);
struct GraphOptimizerContext {
GraphOptimizerContext(const std::unordered_set<string>* nodes_to_preserve,
GraphDef* optimized_graph,
- GraphProperties* graph_properties, NodeMap* node_map,
- FrameMap* frame_map)
+ GraphProperties* graph_properties, NodeMap* node_map)
: nodes_to_preserve(nodes_to_preserve),
optimized_graph(optimized_graph),
graph_properties(graph_properties),
- node_map(node_map),
- frame_map(frame_map) {}
+ node_map(node_map) {}
const std::unordered_set<string>* nodes_to_preserve;
GraphDef* optimized_graph;
GraphProperties* graph_properties;
NodeMap* node_map;
- // TODO(ezhulenev): it seems that frame_map is only relevant for loop
- // optimizer? Move it to loop-optimizer specific context extension.
- FrameMap* frame_map;
};
Status GetInputNode(const GraphOptimizerContext& ctx, const string& input,
@@ -117,6 +111,9 @@ class GraphOptimizerStage {
: optimizer_name_(optimizer_name), stage_name_(stage_name), ctx_(ctx) {}
virtual ~GraphOptimizerStage() = default;
+ const string& stage_name() const { return stage_name_; }
+ const string& optimizer_name() const { return optimizer_name_; }
+
// Check if we should try to simplify node. Returning true doesn't
// guarantee that node will be simplified.
//
@@ -179,6 +176,64 @@ class GraphOptimizerStage {
const GraphOptimizerContext ctx_;
};
+template <typename Result>
+class GraphOptimizerStagePipeline {
+ public:
+  // Break predicate specifies whether a pipeline should stop early, not
+  // passing a node to the next registered optimizer stage. Typically that is
+  // the case when a stage successfully optimized a node and wants to yield
+  // control back to the optimizer.
+ explicit GraphOptimizerStagePipeline(
+ const std::function<bool(const Result&)> break_predicate)
+ : break_predicate_(break_predicate) {}
+
+ // Add a stage to the pipeline. It should be called with the arguments for the
+ // stage constructor:
+ //
+ // pipeline.AddStage<FooStage>(constructor_arg1, constructor_arg2);
+ //
+ // Returns a reference to the added stage.
+ template <typename T, typename... Args>
+ T& AddStage(Args&&... args) {
+ auto stage = new T(std::forward<Args>(args)...);
+ stages_.push_back(std::unique_ptr<T>(stage));
+ return *stage;
+ }
+
+  // Pass a node through all registered optimizer stages until the break
+  // predicate is true.
+  //
+  // Returns true if the pipeline exited after the break predicate evaluated
+  // to 'true', which typically means that the node was optimized by one of
+  // the registered stages.
+  //
+  // Returns false if the node was not optimized by any of the registered
+  // stages.
+ bool PassThroughAllStages(NodeDef* node, Result* result) {
+ for (auto& stage : stages_) {
+ if (stage->IsSupported(node)) {
+ const Status stage_status = stage->TrySimplify(node, result);
+        // Each stage must be "error safe" (analogous to exception safe): in
+        // case of any error it must leave the optimized graph unmodified.
+ if (!stage_status.ok()) {
+ LOG(WARNING) << "Failed to run optimizer " << stage->optimizer_name()
+ << ", stage " << stage->stage_name()
+ << ". Error: " << stage_status.error_message();
+ }
+ if (break_predicate_(*result)) return true;
+ }
+ }
+ return false;
+ }
+
+ std::size_t NumStages() { return stages_.size(); }
+
+ private:
+ std::vector<std::unique_ptr<GraphOptimizerStage<Result>>> stages_;
+ std::function<bool(const Result&)> break_predicate_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(GraphOptimizerStagePipeline);
+};
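+
+// A minimal usage sketch (FooStage, MyResult and IsOptimized are
+// hypothetical):
+//
+//   GraphOptimizerStagePipeline<MyResult> pipeline(
+//       [](const MyResult& result) { return IsOptimized(result); });
+//   pipeline.AddStage<FooStage>("my_optimizer", "my_stage", ctx);
+//   MyResult result;
+//   bool stopped_early = pipeline.PassThroughAllStages(&node, &result);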
+
} // end namespace grappler
} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/graph_optimizer_stage_test.cc b/tensorflow/core/grappler/optimizers/graph_optimizer_stage_test.cc
index 416327e622..3f5ab87a5a 100644
--- a/tensorflow/core/grappler/optimizers/graph_optimizer_stage_test.cc
+++ b/tensorflow/core/grappler/optimizers/graph_optimizer_stage_test.cc
@@ -58,8 +58,8 @@ TEST_F(GraphOptimizerStageTest, ParseNodeNameAndScope_InScope) {
TEST_F(GraphOptimizerStageTest, OptimizedNodeName) {
GraphOptimizerContext ctx(/*nodes_to_preserve*/ nullptr,
/*optimized_graph*/ nullptr,
- /*graph_properties*/ nullptr, /*node_name*/ nullptr,
- /*frame_map*/ nullptr);
+ /*graph_properties*/ nullptr,
+ /*node_name*/ nullptr);
FakeOptimizerStage stage("my_opt", "my_stg", ctx);
const auto node = ParseNodeScopeAndName("a/b/c/Add");
@@ -94,8 +94,7 @@ TEST_F(GraphOptimizerStageTest, GetInputNodeAndProperties) {
GraphOptimizerContext ctx(/*nodes_to_preserve*/ nullptr,
/*optimized_graph*/ &item.graph,
/*graph_properties*/ &properties,
- /*node_name*/ &node_map,
- /*frame_map*/ nullptr);
+ /*node_name*/ &node_map);
FakeOptimizerStage stage("my_opt", "my_stg", ctx);
NodeDef* add_node;
@@ -134,8 +133,7 @@ TEST_F(GraphOptimizerStageTest, AddNodes) {
GraphOptimizerContext ctx(/*nodes_to_preserve*/ nullptr,
/*optimized_graph*/ &item.graph,
/*graph_properties*/ &properties,
- /*node_name*/ &node_map,
- /*frame_map*/ nullptr);
+ /*node_name*/ &node_map);
FakeOptimizerStage stage("my_opt", "my_stg", ctx);
NodeDef* add_node;
@@ -165,4 +163,4 @@ TEST_F(GraphOptimizerStageTest, AddNodes) {
} // namespace
} // end namespace grappler
-} // end namespace tensorflow \ No newline at end of file
+} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer.cc b/tensorflow/core/grappler/optimizers/layout_optimizer.cc
index 254c1edf7b..308eecd420 100644
--- a/tensorflow/core/grappler/optimizers/layout_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/layout_optimizer.cc
@@ -2119,6 +2119,10 @@ Status LayoutOptimizer::Tune(const GrapplerItem& item,
Status LayoutOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
GraphDef* output) {
+ if (cluster == nullptr) {
+ return errors::InvalidArgument("cluster == nullptr");
+ }
+
if (GetNumGPUs(*cluster) < 1) {
// LayoutOptimizer is currently only tuned for GPU.
*output = item.graph;
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
index ad655db727..5723e397ab 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
@@ -1,4 +1,4 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/core/grappler/optimizers/loop_optimizer.h"
#include "tensorflow/core/grappler/optimizers/memory_optimizer.h"
#include "tensorflow/core/grappler/optimizers/model_pruner.h"
+#include "tensorflow/core/grappler/utils/colocation.h"
#include "tensorflow/core/grappler/utils/topological_sort.h"
#include "tensorflow/core/lib/core/status.h"
@@ -44,16 +45,15 @@ int64 NumEdges(const GraphDef& graph) {
}
string PrintSizesBeforeAfter(const GraphDef& before, const GraphDef& after) {
- return strings::StrCat("Graph size before: ", before.node_size(), " nodes, ",
- NumEdges(before),
- " edges. Graph size after: ", after.node_size(),
- " nodes, ", NumEdges(after), " edges.");
+ return strings::StrCat("Graph size after: ", after.node_size(), " nodes (",
+ after.node_size() - before.node_size(), "), ",
+ NumEdges(after), " edges (",
+ NumEdges(after) - NumEdges(before), ")");
}
} // namespace
std::unique_ptr<GraphOptimizer> MetaOptimizer::NewOptimizer(
const string& optimizer) {
- VLOG(1) << "Adding graph optimization pass: " << optimizer;
std::unique_ptr<GraphOptimizer> graph_optimizer;
if (optimizer == "pruning") {
graph_optimizer.reset(new ModelPruner());
@@ -171,46 +171,58 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
return Status::OK();
}
+ // Some optimizers should be run only once.
+ const std::set<string> run_once_optimizers = {"layout"};
bool already_optimized = false;
- for (const auto& optimizer : optimizers) {
- if (!already_optimized) {
- Status status = optimizer->Optimize(cluster, item, optimized_graph);
- string result;
- if (!status.ok()) {
- VLOG(1) << "Not able to apply optimizer " << optimizer->name()
- << ". Return status: " << status.ToString();
- result = status.ToString();
- } else {
- already_optimized = true;
- result = strings::StrCat(
- "OK. ", PrintSizesBeforeAfter(item.graph, *optimized_graph));
+ const int num_iterations =
+ cfg_.meta_optimizer_iterations() == RewriterConfig::DEFAULT_NUM_ITERS
+ ? 1
+ : cfg_.meta_optimizer_iterations();
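+  // For example (a sketch): with RewriterConfig::TWO, each optimizer other
+  // than the run-once ones (e.g. "layout") is applied in two passes.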
+ for (int iteration = 0; iteration < num_iterations; ++iteration) {
+ VLOG(1) << "Starting optimization iteration " << iteration + 1;
+ for (const auto& optimizer : optimizers) {
+ if (iteration > 0 && run_once_optimizers.count(optimizer->name())) {
+ continue;
}
- result_.push_back(std::make_pair(optimizer->name(), result));
- VLOG(1) << "Optimizer " << optimizer->name()
- << " return status: " << result;
- } else {
- GrapplerItem optimized_item(item, std::move(*optimized_graph));
- Status status =
- optimizer->Optimize(cluster, optimized_item, optimized_graph);
- string result;
- if (!status.ok()) {
- VLOG(1) << "Not able to apply optimizer " << optimizer->name()
- << ". Return status: " << status.ToString();
- optimized_graph->Swap(&optimized_item.graph);
- result = status.ToString();
+ if (!already_optimized) {
+ Status status = optimizer->Optimize(cluster, item, optimized_graph);
+ string result;
+ if (!status.ok()) {
+ VLOG(1) << "Not able to apply optimizer " << optimizer->name()
+ << ". Return status: " << status.ToString();
+ result = status.ToString();
+ } else {
+ already_optimized = true;
+ result = strings::StrCat(
+ "OK. ", PrintSizesBeforeAfter(item.graph, *optimized_graph));
+ }
+ result_.push_back(std::make_pair(optimizer->name(), result));
+ VLOG(1) << "Optimizer " << optimizer->name()
+ << " return status: " << result;
} else {
- result = strings::StrCat(
- "OK. ",
- PrintSizesBeforeAfter(optimized_item.graph, *optimized_graph));
+ GrapplerItem optimized_item(item, std::move(*optimized_graph));
+ Status status =
+ optimizer->Optimize(cluster, optimized_item, optimized_graph);
+ string result;
+ if (!status.ok()) {
+ VLOG(1) << "Not able to apply optimizer " << optimizer->name() << ": "
+ << status.ToString();
+ optimized_graph->Swap(&optimized_item.graph);
+ result = status.ToString();
+ } else {
+ result = strings::StrCat(
+ optimizer->name(), ": ",
+ PrintSizesBeforeAfter(optimized_item.graph, *optimized_graph));
+ }
+ result_.push_back(std::make_pair(optimizer->name(), result));
+ VLOG(1) << result;
}
- result_.push_back(std::make_pair(optimizer->name(), result));
- VLOG(1) << "Optimizer " << optimizer->name()
- << " return status: " << result;
}
}
if (already_optimized) {
TF_RETURN_IF_ERROR(TopologicalSort(optimized_graph));
+ ReassignColocation(optimized_graph);
// Make sure that the optimizers preserved the graph version and library.
DCHECK_GE(optimized_graph->library().function_size(),
item.graph.library().function_size());
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc
index 536347d834..d9a386b9be 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc
@@ -72,6 +72,20 @@ TEST(MetaOptimizerTest, RunsCustomOptimizer) {
EXPECT_TRUE(TestOptimizer::IsOptimized());
}
+TEST(MetaOptimizerTest, RunOptimizersTwice) {
+ TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"});
+ GrapplerItem item;
+ CHECK(fake_input.NextItem(&item));
+
+ RewriterConfig rewriter_config;
+ rewriter_config.set_meta_optimizer_iterations(RewriterConfig::TWO);
+
+ MetaOptimizer optimizer(nullptr, rewriter_config);
+ GraphDef output;
+ const Status status = optimizer.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
+}
+
} // namespace
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/symbolic_shapes.h b/tensorflow/core/grappler/optimizers/symbolic_shapes.h
index a9dcf44e23..eb79bab314 100644
--- a/tensorflow/core/grappler/optimizers/symbolic_shapes.h
+++ b/tensorflow/core/grappler/optimizers/symbolic_shapes.h
@@ -31,8 +31,8 @@ bool IsUnknown(const TensorShapeProto::Dim& dim);
bool ShapeIsSymbolicallyDefined(const TensorShapeProto& shape);
bool ShapeIsSymbolicallyDefined(const OpInfo::TensorProperties& properties);
-// Shapes are symbolically equal, if they have the same rank, they are
-// they are known or symbolically defined, and have matching dimensions.
+// Shapes are symbolically equal if they have the same rank, are known or
+// symbolically defined, and have matching dimensions.
bool ShapesSymbolicallyEqual(const TensorShapeProto& left,
const TensorShapeProto& right);
bool ShapesSymbolicallyEqual(const OpInfo::TensorProperties& left,
diff --git a/tensorflow/core/grappler/utils.cc b/tensorflow/core/grappler/utils.cc
index 86a6d5000d..5893f286ed 100644
--- a/tensorflow/core/grappler/utils.cc
+++ b/tensorflow/core/grappler/utils.cc
@@ -255,6 +255,14 @@ int NumOutputs(const NodeDef& node, GraphDef* graph) {
return num_outputs;
}
+bool HasControlInputs(const NodeDef& node) {
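+  // Control inputs ("^node") are appended after regular inputs in a NodeDef,
+  // so it suffices to check the last input (an assumption this helper relies
+  // on).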
+ int num_inputs = node.input_size();
+ if (num_inputs > 0 && IsControlInput(node.input(num_inputs - 1))) {
+ return true;
+ }
+ return false;
+}
+
int NumNonControlInputs(const NodeDef& node) {
int num_inputs = node.input_size();
for (const string& input : node.input()) {
diff --git a/tensorflow/core/grappler/utils.h b/tensorflow/core/grappler/utils.h
index 7aa31939f5..11555d712a 100644
--- a/tensorflow/core/grappler/utils.h
+++ b/tensorflow/core/grappler/utils.h
@@ -138,6 +138,9 @@ string AsControlDependency(const string& node);
// some of the outputs may be unconnected.
int NumOutputs(const NodeDef& node, GraphDef* graph);
+// Returns true iff the node has at least one control input.
+bool HasControlInputs(const NodeDef& node);
+
// Number of connected non-control inputs.
int NumNonControlInputs(const NodeDef& node);
diff --git a/tensorflow/core/grappler/utils/BUILD b/tensorflow/core/grappler/utils/BUILD
index baf24c2505..7419c26dff 100644
--- a/tensorflow/core/grappler/utils/BUILD
+++ b/tensorflow/core/grappler/utils/BUILD
@@ -181,3 +181,28 @@ tf_cc_test(
"//tensorflow/core:testlib",
],
)
+
+cc_library(
+ name = "colocation",
+ srcs = ["colocation.cc"],
+ hdrs = ["colocation.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core/grappler:utils",
+ ],
+)
+
+tf_cc_test(
+ name = "colocation_test",
+ size = "small",
+ srcs = ["colocation_test.cc"],
+ deps = [
+ ":colocation",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ ],
+)
diff --git a/tensorflow/core/grappler/utils/colocation.cc b/tensorflow/core/grappler/utils/colocation.cc
new file mode 100644
index 0000000000..0573e0a830
--- /dev/null
+++ b/tensorflow/core/grappler/utils/colocation.cc
@@ -0,0 +1,122 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/utils/colocation.h"
+
+#include <cstring>
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/grappler/utils.h"
+
+namespace tensorflow {
+namespace grappler {
+
+namespace {
+
+// Finds the root node of a colocation group.
+// The map maps a node name to its parent node name. node_name is the node
+// from which the search starts. By iteratively following the path from child
+// to parent, we find the root node of the colocation group that node_name
+// belongs to.
+string GetColocationGroupRoot(std::unordered_map<string, string>* map,
+ const string& node_name) {
+ if (map->find(node_name) == map->end()) {
+ // If node_name is not in the map, we create a new root node which points
+ // to itself.
+ map->insert({node_name, node_name});
+ return node_name;
+ }
+ string cur = node_name;
+ while ((*map)[cur] != cur) {
+    // Trace back through the map until we reach the root node.
+ cur = (*map)[cur];
+ }
+ return cur;
+}
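+// For example (a sketch): with map = {"B": "A", "A": "A"}, a call with
+// node_name "B" follows B -> A and returns "A"; a call with the unseen name
+// "C" inserts {"C": "C"} and returns "C".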
+
+// Merges two colocation groups into one.
+// left and right are the root nodes of the two colocation groups,
+// respectively.
+void MergeColocationGroup(std::unordered_map<string, string>* map,
+ const string& left, const string& right) {
+ // Do nothing if left or right node is not in the map.
+ if (map->find(left) == map->end() || map->find(right) == map->end()) {
+ return;
+ }
+ if (left != right) {
+ // Make the right node a child of the left node, which merges the two
+ // groups.
+ map->at(right) = left;
+ }
+}
+} // namespace
+
+// Uses a disjoint-set algorithm to build colocation groups from the input
+// graph. The core data structure is a hash map from a node to its parent
+// node. Whenever we see two nodes colocated with each other, we merge their
+// colocation groups. After traversing all colocation pairs in the graph, we
+// are left with several disjoint sets. We then pick the root node of each
+// disjoint set as the representative node, and let all other nodes in the
+// group colocate with the representative node.
+void ReassignColocation(GraphDef* graph) {
+ constexpr char kClassAttr[] = "_class";
+ constexpr char kColocPrefix[] = "loc:@";
+
+ // A hashmap that maps from a node name to its parent node name.
+ std::unordered_map<string, string> coloc_groups;
+ NodeMap node_map(graph);
+ for (const auto& node : graph->node()) {
+ auto iter = node.attr().find(kClassAttr);
+ if (iter != node.attr().end() && iter->second.has_list()) {
+ for (const auto& str : iter->second.list().s()) {
+ size_t pos = str.find(kColocPrefix);
+ if (pos == 0) {
+ // After we find a colocation, update the colocation groups.
+ string colocate_node = str.substr(pos + strlen(kColocPrefix));
+ MergeColocationGroup(
+ &coloc_groups, GetColocationGroupRoot(&coloc_groups, node.name()),
+ GetColocationGroupRoot(&coloc_groups, colocate_node));
+ }
+ }
+ }
+ }
+
+  // We use the root node of each colocation group as its representative
+  // node. Each node in a group is colocated with the representative node,
+  // provided the node is present in the graph.
+ for (const auto& pair : coloc_groups) {
+ if (pair.first != pair.second) {
+ // This is a child node.
+ NodeDef* node = node_map.GetNode(pair.first);
+ if (node) {
+ // Colocate this node with the root node.
+ AttrValue new_value;
+ new_value.mutable_list()->add_s(
+ kColocPrefix + GetColocationGroupRoot(&coloc_groups, pair.first));
+ node->mutable_attr()->erase(kClassAttr);
+ node->mutable_attr()->insert({kClassAttr, new_value});
+ }
+ } else {
+ // This is a root node. Clear the _class attribute.
+ NodeDef* node = node_map.GetNode(pair.first);
+      if (node) {  // The root node should always exist in the graph, as
+                   // guaranteed by the merging order; this check is just for
+                   // safety.
+ node->mutable_attr()->erase(kClassAttr);
+ }
+ }
+ }
+}
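+
+// For example (a sketch): if nodes A and B (scanned in that order) both carry
+// the attribute "_class: [loc:@X]" and X is absent from the graph, B becomes
+// the group root, so A is rewritten to "loc:@B" and B's _class attribute is
+// cleared.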
+
+} // namespace grappler
+} // namespace tensorflow
diff --git a/tensorflow/core/grappler/utils/colocation.h b/tensorflow/core/grappler/utils/colocation.h
new file mode 100644
index 0000000000..6062db6102
--- /dev/null
+++ b/tensorflow/core/grappler/utils/colocation.h
@@ -0,0 +1,39 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_UTILS_COLOCATION_H_
+#define TENSORFLOW_CORE_GRAPPLER_UTILS_COLOCATION_H_
+
+#include <unordered_map>
+#include "tensorflow/core/framework/graph.pb.h"
+
+namespace tensorflow {
+namespace grappler {
+
+// Evaluates the colocation relations in the graph and rewrites them. We scan
+// the graph nodes sequentially and build disjoint sets of nodes (within each
+// disjoint set the nodes are colocated with each other). We then select the
+// root node of each set as its representative node and colocate each node in
+// the set (provided it also exists in the graph) with the representative
+// node.
+// Note that there is currently one situation this function can't handle:
+// node A colocates with X, node B colocates with Y, and X colocates with Y,
+// but X and Y are removed from the graph. In this case we can't know that A
+// colocates with B.
+void ReassignColocation(GraphDef* graph);
+
+} // namespace grappler
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_UTILS_COLOCATION_H_
diff --git a/tensorflow/core/grappler/utils/colocation_test.cc b/tensorflow/core/grappler/utils/colocation_test.cc
new file mode 100644
index 0000000000..6638364240
--- /dev/null
+++ b/tensorflow/core/grappler/utils/colocation_test.cc
@@ -0,0 +1,183 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/utils/colocation.h"
+
+#include "tensorflow/core/framework/function_testlib.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace grappler {
+
+class ColocationTest : public ::testing::Test {};
+
+bool VerifyNodeHasColocation(const NodeDef& ndef, const string& coloc) {
+ if (ndef.attr().empty()) {
+ return false;
+ }
+ if (ndef.attr().find("_class") == ndef.attr().end()) {
+ return false;
+ }
+ return ndef.attr().at("_class").list().s(0) == coloc;
+}
+
+TEST(ColocationTest, ReassignColocation_SingleNode) {
+ // Node A colocates with B, but node B is not in the graph.
+ // A
+ // |
+ // |
+ // [B]
+
+ NodeDef ndef;
+ const Status status =
+ NodeDefBuilder("A", "Const").Attr("_class", {"loc:@B"}).Finalize(&ndef);
+ TF_EXPECT_OK(status);
+ GraphDef gdef = test::function::GDef({ndef});
+
+ EXPECT_EQ(1, gdef.node_size());
+ EXPECT_EQ(1, gdef.node(0).attr_size());
+
+ ReassignColocation(&gdef);
+
+ // Validates that node A's colocation info is cleared.
+ EXPECT_EQ(1, gdef.node_size());
+ EXPECT_EQ(0, gdef.node(0).attr_size());
+}
+
+TEST(ColocationTest, ReassignColocation_MultiNode_SingleGroup) {
+ // Node A, B, C colocate with X. D colocates with C. E colocates with D.
+ // Node X is not in the graph.
+ // A B C---D---E
+ // | | |
+ // | | |
+ // +--[X]--+
+ // After re-assign of colocation, A, B, C, D should colocate with E.
+ // A B C D
+ // | | | |
+ // | | | |
+ // +---+-E-+---+
+
+ NodeDef ndef_a, ndef_b, ndef_c, ndef_d, ndef_e;
+ Status status =
+ NodeDefBuilder("A", "Const").Attr("_class", {"loc:@X"}).Finalize(&ndef_a);
+ TF_EXPECT_OK(status);
+ status =
+ NodeDefBuilder("B", "Const").Attr("_class", {"loc:@X"}).Finalize(&ndef_b);
+ TF_EXPECT_OK(status);
+ status =
+ NodeDefBuilder("C", "Const").Attr("_class", {"loc:@X"}).Finalize(&ndef_c);
+ TF_EXPECT_OK(status);
+ status =
+ NodeDefBuilder("D", "Const").Attr("_class", {"loc:@C"}).Finalize(&ndef_d);
+ TF_EXPECT_OK(status);
+ status =
+ NodeDefBuilder("E", "Const").Attr("_class", {"loc:@D"}).Finalize(&ndef_e);
+ TF_EXPECT_OK(status);
+ GraphDef gdef =
+ test::function::GDef({ndef_a, ndef_b, ndef_c, ndef_d, ndef_e});
+
+ EXPECT_EQ(5, gdef.node_size());
+ EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(0), "loc:@X")); // A
+ EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(1), "loc:@X")); // B
+ EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(2), "loc:@X")); // C
+ EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(3), "loc:@C")); // D
+ EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(4), "loc:@D")); // E
+
+ ReassignColocation(&gdef);
+
+ EXPECT_EQ(5, gdef.node_size());
+ EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(0), "loc:@E")); // A
+ EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(1), "loc:@E")); // B
+ EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(2), "loc:@E")); // C
+ EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(3), "loc:@E")); // D
+ EXPECT_EQ(0, gdef.node(4).attr_size()); // E
+}
+
+TEST(ColocationTest, ReassignColocation_MultiNode_MultiGroup) {
+ // Before re-assign:
+ // Node A, B, C colocate with X. D colocates with C. E colocates with D.
+ // Node U, V colocates with W. Node X, W are not in the graph:
+ // A B C---D---E
+ // | | |
+ // | | |
+ // +--[X]--+
+ //
+ // U V
+ // | |
+ // | |
+ // +--[W]--+
+ //
+ // After re-assign:
+ // A, B, C, D should colocate with E. U should colocate with V.
+ // A B C D
+ // | | | |
+ // | | | |
+ // +---+-E-+---+
+ //
+ // U
+ // |
+ // |
+ // V
+
+ NodeDef ndef_a, ndef_b, ndef_c, ndef_d, ndef_e, ndef_u, ndef_v;
+ Status status =
+ NodeDefBuilder("A", "Const").Attr("_class", {"loc:@X"}).Finalize(&ndef_a);
+ TF_EXPECT_OK(status);
+ status =
+ NodeDefBuilder("B", "Const").Attr("_class", {"loc:@X"}).Finalize(&ndef_b);
+ TF_EXPECT_OK(status);
+ status =
+ NodeDefBuilder("C", "Const").Attr("_class", {"loc:@X"}).Finalize(&ndef_c);
+ TF_EXPECT_OK(status);
+ status =
+ NodeDefBuilder("D", "Const").Attr("_class", {"loc:@C"}).Finalize(&ndef_d);
+ TF_EXPECT_OK(status);
+ status =
+ NodeDefBuilder("E", "Const").Attr("_class", {"loc:@D"}).Finalize(&ndef_e);
+ TF_EXPECT_OK(status);
+ status =
+ NodeDefBuilder("U", "Const").Attr("_class", {"loc:@W"}).Finalize(&ndef_u);
+ TF_EXPECT_OK(status);
+ status =
+ NodeDefBuilder("V", "Const").Attr("_class", {"loc:@W"}).Finalize(&ndef_v);
+ TF_EXPECT_OK(status);
+ GraphDef gdef = test::function::GDef(
+ {ndef_a, ndef_b, ndef_c, ndef_d, ndef_e, ndef_u, ndef_v});
+
+ EXPECT_EQ(7, gdef.node_size());
+ EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(0), "loc:@X")); // A
+ EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(1), "loc:@X")); // B
+ EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(2), "loc:@X")); // C
+ EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(3), "loc:@C")); // D
+ EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(4), "loc:@D")); // E
+ EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(5), "loc:@W")); // U
+ EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(6), "loc:@W")); // V
+
+ ReassignColocation(&gdef);
+
+ EXPECT_EQ(7, gdef.node_size());
+ EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(0), "loc:@E")); // A
+ EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(1), "loc:@E")); // B
+ EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(2), "loc:@E")); // C
+ EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(3), "loc:@E")); // D
+ EXPECT_EQ(0, gdef.node(4).attr_size()); // E
+ EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(5), "loc:@V")); // U
+ EXPECT_EQ(0, gdef.node(6).attr_size()); // V
+}
+
+} // namespace grappler
+} // namespace tensorflow
diff --git a/tensorflow/core/grappler/utils/grappler_test.h b/tensorflow/core/grappler/utils/grappler_test.h
index 3bc7bea454..e1394b9c35 100644
--- a/tensorflow/core/grappler/utils/grappler_test.h
+++ b/tensorflow/core/grappler/utils/grappler_test.h
@@ -57,6 +57,15 @@ class GrapplerTest : public ::testing::Test {
// Count nodes of the given op-type in a graph.
int CountOpNodes(const GraphDef& graph, const string& op);
+  // Get a random tensor with the given shape.
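+  // For example: GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 3})).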
+ template <DataType DTYPE>
+ Tensor GenerateRandomTensor(const TensorShape& shape) const {
+ typedef typename EnumToDataType<DTYPE>::Type T;
+ Tensor tensor(DTYPE, shape);
+ tensor.flat<T>() = tensor.flat<T>().random();
+ return tensor;
+ }
+
private:
SessionOptions options_;
};
diff --git a/tensorflow/core/kernels/assign_op.h b/tensorflow/core/kernels/assign_op.h
index a312e8e8a4..2ed1628bf1 100644
--- a/tensorflow/core/kernels/assign_op.h
+++ b/tensorflow/core/kernels/assign_op.h
@@ -77,7 +77,8 @@ class AssignOp : public OpKernel {
// 1. Try to reuse the rhs.
std::unique_ptr<Tensor> input_alias = context->forward_input(
- 1, old_lhs.dtype(), old_lhs.shape(), DEVICE_MEMORY, attr);
+ 1, OpKernelContext::Params::kNoReservation /*output_index*/,
+ old_lhs.dtype(), old_lhs.shape(), DEVICE_MEMORY, attr);
if (input_alias != nullptr) {
// Transfer ownership to the ref.
context->replace_ref_input(0, *input_alias.release(),
diff --git a/tensorflow/core/kernels/data_format_ops.cc b/tensorflow/core/kernels/data_format_ops.cc
index 39ef8ee3ac..4485152e96 100644
--- a/tensorflow/core/kernels/data_format_ops.cc
+++ b/tensorflow/core/kernels/data_format_ops.cc
@@ -37,25 +37,37 @@ class DataFormatDimMapOp : public OpKernel {
OP_REQUIRES_OK(context, context->GetAttr("src_format", &src_format));
string dst_format;
OP_REQUIRES_OK(context, context->GetAttr("dst_format", &dst_format));
+ OP_REQUIRES(context, src_format.size() == 4,
+ errors::InvalidArgument(strings::StrCat(
+ "Source format must of length 4, received src_format = ",
+ src_format)));
OP_REQUIRES(
- context, src_format == "NHWC",
+ context, dst_format.size() == 4,
errors::InvalidArgument(strings::StrCat(
- "Current implementation doesn't support source data format ",
- src_format)));
- OP_REQUIRES(context, dst_format == "NCHW",
- errors::InvalidArgument(strings::StrCat(
- "Current implementation doesn't support dst data format ",
- dst_format)));
+ "Destination format must of length 4, received dst_format = ",
+ dst_format)));
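+    // Precompute, for each source dimension, its index in the destination
+    // format; e.g. (a sketch) src "NHWC" -> dst "NCHW" yields [0, 2, 3, 1].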
+ dst_idx_ = Tensor(DT_INT32, {static_cast<int64>(src_format.size())});
+ for (int i = 0; i < src_format.size(); ++i) {
+ for (int j = 0; j < dst_format.size(); ++j) {
+ if (dst_format[j] == src_format[i]) {
+ dst_idx_.vec<int>()(i) = j;
+ break;
+ }
+ }
+ }
}
void Compute(OpKernelContext* context) override {
const Tensor& input = context->input(0);
- Tensor* output = nullptr;
+ Tensor* output;
OP_REQUIRES_OK(context,
context->allocate_output(0, input.shape(), &output));
functor::DataFormatDimMap<Device, T>()(context->eigen_device<Device>(),
- input.flat<T>(), output->flat<T>());
+ input.flat<T>(), output->flat<T>(),
+ dst_idx_.vec<int>());
}
+
+ Tensor dst_idx_;
};
template <typename Device, typename T>
@@ -147,11 +159,11 @@ TF_CALL_int64(REGISTER_KERNEL);
#if GOOGLE_CUDA
// Forward declarations of the functor specializations for GPU.
namespace functor {
-#define DECLARE_GPU_SPEC(T) \
- template <> \
- void DataFormatDimMap<GPUDevice, T>::operator()( \
- const GPUDevice& d, typename TTypes<T>::ConstFlat x, \
- typename TTypes<T>::Flat y); \
+#define DECLARE_GPU_SPEC(T) \
+ template <> \
+ void DataFormatDimMap<GPUDevice, T>::operator()( \
+ const GPUDevice& d, typename TTypes<T>::ConstFlat x, \
+ typename TTypes<T>::Flat y, const TTypes<int>::Vec dst); \
extern template struct DataFormatDimMap<GPUDevice, T>;
#define DECLARE_GPU_SPECS(T) DECLARE_GPU_SPEC(T);
TF_CALL_int32(DECLARE_GPU_SPECS);
diff --git a/tensorflow/core/kernels/data_format_ops.h b/tensorflow/core/kernels/data_format_ops.h
index 2ccc919586..1ca144cb40 100644
--- a/tensorflow/core/kernels/data_format_ops.h
+++ b/tensorflow/core/kernels/data_format_ops.h
@@ -27,15 +27,25 @@ namespace functor {
template <typename Device, typename T>
struct DataFormatDimMap {
void operator()(const Device& d, typename TTypes<T>::ConstFlat x,
- typename TTypes<T>::Flat y) {
+ typename TTypes<T>::Flat y, const TTypes<int>::Vec dst) {
auto zero = x.constant(0);
auto one = x.constant(1);
- auto three = x.constant(3);
+ auto two = x.constant(2);
+
+ auto f_zero = x.constant(dst(0));
+ auto f_one = x.constant(dst(1));
+ auto f_two = x.constant(dst(2));
+ auto f_three = x.constant(dst(3));
+
auto four = x.constant(4);
auto x_mod = (x + four) % 4;
+
auto is_zero = (x_mod == zero);
- auto is_three = (x_mod == three);
- y.device(d) = is_zero.select(zero, is_three.select(one, x_mod + one));
+ auto is_one = (x_mod == one);
+ auto is_two = (x_mod == two);
+
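+    // Map each (wrapped) source dimension index i in {0, 1, 2, 3} to dst(i);
+    // e.g. (a sketch) with dst = [0, 2, 3, 1], input 1 ('H' in "NHWC") maps
+    // to 2 ('H' in "NCHW").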
+ y.device(d) = is_zero.select(
+ f_zero, is_one.select(f_one, is_two.select(f_two, f_three)));
}
};
diff --git a/tensorflow/core/kernels/functional_ops.cc b/tensorflow/core/kernels/functional_ops.cc
index b687088db1..911aa3a78f 100644
--- a/tensorflow/core/kernels/functional_ops.cc
+++ b/tensorflow/core/kernels/functional_ops.cc
@@ -1,4 +1,4 @@
-/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-
#define EIGEN_USE_THREADS
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
@@ -21,10 +20,12 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/threadpool.h"
-#include "tensorflow/core/platform/mutex.h"
-namespace tensorflow {
+#if GOOGLE_CUDA
+#include "tensorflow/stream_executor/stream.h"
+#endif // GOOGLE_CUDA
+namespace tensorflow {
typedef Eigen::GpuDevice GPUDevice;
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef FunctionLibraryRuntime::Handle FHandle;
@@ -106,11 +107,9 @@ void SetRunOptions(OpKernelContext* ctx, FunctionLibraryRuntime::Options* opts,
opts->runner = ctx->runner();
}
-} // end namespace
-
-class FunctionalIf : public AsyncOpKernel {
+class IfOp : public AsyncOpKernel {
public:
- explicit FunctionalIf(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
+ explicit IfOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
auto lib = ctx->function_library();
OP_REQUIRES(ctx, lib != nullptr, errors::Internal("No function library"));
const NameAttrList* func;
@@ -120,7 +119,7 @@ class FunctionalIf : public AsyncOpKernel {
OP_REQUIRES_OK(ctx, Instantiate(lib, *func, &else_handle_));
}
- ~FunctionalIf() override {}
+ ~IfOp() override {}
void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
bool cond;
@@ -134,8 +133,7 @@ class FunctionalIf : public AsyncOpKernel {
class State {
public:
- State(FunctionalIf* kernel, OpKernelContext* ctx, bool cond,
- DoneCallback done)
+ State(IfOp* kernel, OpKernelContext* ctx, bool cond, DoneCallback done)
: kernel_(kernel),
ctx_(ctx),
cond_(cond),
@@ -168,7 +166,7 @@ class FunctionalIf : public AsyncOpKernel {
}
private:
- FunctionalIf* const kernel_;
+ IfOp* const kernel_;
OpKernelContext* const ctx_;
const bool cond_;
const DoneCallback done_;
@@ -179,18 +177,22 @@ class FunctionalIf : public AsyncOpKernel {
};
};
-REGISTER_KERNEL_BUILDER(Name("_If").Device(DEVICE_CPU), FunctionalIf);
+// TODO(drpng): remove this.
+REGISTER_KERNEL_BUILDER(Name("_If").Device(DEVICE_CPU), IfOp);
REGISTER_KERNEL_BUILDER(Name("_If").Device(DEVICE_GPU).HostMemory("cond"),
- FunctionalIf);
+ IfOp);
+
+REGISTER_KERNEL_BUILDER(Name("If").Device(DEVICE_CPU), IfOp);
+REGISTER_KERNEL_BUILDER(Name("If").Device(DEVICE_GPU).HostMemory("cond"), IfOp);
-class FunctionalWhile : public AsyncOpKernel {
+class WhileOp : public AsyncOpKernel {
public:
- explicit FunctionalWhile(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
+ explicit WhileOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("cond", &cond_func_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("body", &body_func_));
}
- ~FunctionalWhile() override {}
+ ~WhileOp() override {}
void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
auto lib = ctx->function_library();
@@ -234,7 +236,7 @@ class FunctionalWhile : public AsyncOpKernel {
class State {
public:
- State(FunctionalWhile* kernel, OpKernelContext* ctx, FHandle cond_handle,
+ State(WhileOp* kernel, OpKernelContext* ctx, FHandle cond_handle,
FHandle body_handle, DoneCallback done)
: kernel_(kernel),
ctx_(ctx),
@@ -253,7 +255,7 @@ class FunctionalWhile : public AsyncOpKernel {
void Start() { EvalCond(); }
private:
- FunctionalWhile* const kernel_;
+ WhileOp* const kernel_;
OpKernelContext* const ctx_;
const FHandle cond_handle_;
const FHandle body_handle_;
@@ -316,7 +318,152 @@ class FunctionalWhile : public AsyncOpKernel {
}
};
};
-REGISTER_KERNEL_BUILDER(Name("_While").Device(DEVICE_CPU), FunctionalWhile);
-REGISTER_KERNEL_BUILDER(Name("_While").Device(DEVICE_GPU), FunctionalWhile);
+// TODO(drpng): remove these.
+REGISTER_KERNEL_BUILDER(Name("_While").Device(DEVICE_CPU), WhileOp);
+REGISTER_KERNEL_BUILDER(Name("_While").Device(DEVICE_GPU), WhileOp);
+
+REGISTER_KERNEL_BUILDER(Name("While").Device(DEVICE_CPU), WhileOp);
+REGISTER_KERNEL_BUILDER(Name("While").Device(DEVICE_GPU), WhileOp);
+
+Status GetScalar(OpKernelContext* ctx, int index, int32* value,
+ const char* label) {
+ Tensor t = ctx->input(index);
+ if (!TensorShapeUtils::IsScalar(t.shape())) {
+ return errors::InvalidArgument(label, " must be a scalar, but ",
+ t.shape().DebugString());
+ }
+ *value = t.scalar<int32>()();
+ return Status::OK();
+}
+
+class ForOp : public AsyncOpKernel {
+ public:
+ explicit ForOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
+ auto lib = ctx->function_library();
+ OP_REQUIRES(ctx, lib != nullptr, errors::Internal("No function library"));
+ const NameAttrList* func;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("body", &func));
+ OP_REQUIRES_OK(ctx, Instantiate(lib, *func, &body_handle_));
+ }
+
+ ~ForOp() override {}
+
+ void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
+ (new State(this, ctx, done))->Start();
+ }
+
+ private:
+ FHandle body_handle_;
+
+ class State {
+ public:
+ State(ForOp* kernel, OpKernelContext* ctx, DoneCallback done)
+ : kernel_(kernel),
+ ctx_(ctx),
+ done_(std::move(done)),
+ lib_(CHECK_NOTNULL(ctx_->function_library())),
+ args_(1 + ctx_->num_inputs() - 3) {
+ args_[0] = Tensor(DT_INT32, {});
+ iter_ = &args_[0].scalar<int32>()();
+
+ const int32 num_loop_inputs = ctx_->num_inputs() - 3;
+ rets_.reserve(num_loop_inputs);
+ for (int i = 0; i < num_loop_inputs; ++i) {
+ rets_.push_back(ctx_->input(3 + i));
+ }
+ }
+
+ ~State() {}
+
+ void Start() {
+ Status s = StartLoop();
+ if (!s.ok()) Finish(s);
+ }
+
+ private:
+ ForOp* const kernel_;
+ OpKernelContext* const ctx_;
+ const DoneCallback done_;
+ FunctionLibraryRuntime* const lib_;
+ FunctionLibraryRuntime::Options opts_;
+ TensorVec args_;
+ TensorVec rets_;
+
+ int32* iter_; // points to args_[0].
+ int32 limit_;
+ int32 delta_;
+
+ // If an error e is returned, caller must call Finish(e).
+ // If OK is returned, the async loop execution has been started.
+ Status StartLoop() {
+ SetRunOptions(ctx_, &opts_, false /* always_collect_stats */);
+
+ TF_RETURN_IF_ERROR(GetScalar(ctx_, 0, iter_, "start"));
+ TF_RETURN_IF_ERROR(GetScalar(ctx_, 1, &limit_, "limit"));
+ TF_RETURN_IF_ERROR(GetScalar(ctx_, 2, &delta_, "delta"));
+
+ if ((delta_ > 0 && *iter_ <= limit_) ||
+ (delta_ < 0 && *iter_ >= limit_) ||
+ (delta_ == 0 && *iter_ == limit_)) {
+ RunNext();
+ return Status::OK();
+ } else {
+ return errors::InvalidArgument("Invalid start/limit/delta: ", *iter_,
+ " ", limit_, " ", delta_);
+ }
+ }
+
+ void RunNext() {
+ bool done_loop;
+ if (delta_ > 0) {
+ done_loop = *iter_ >= limit_;
+ } else {
+ done_loop = *iter_ <= limit_;
+ }
+ if (done_loop) {
+ Finish(Status::OK());
+ return;
+ }
+
+ if (rets_.size() >= args_.size()) {
+ Finish(errors::InvalidArgument(
+ "For loop body returned ", rets_.size(),
+ " arguments. Expected: ", args_.size() - 1));
+ return;
+ }
+ for (int i = 0; i < rets_.size(); ++i) {
+ args_[1 + i] = std::move(rets_[i]);
+ }
+ rets_.clear();
+ lib_->Run(opts_, kernel_->body_handle_, args_, &rets_,
+ [this](const Status& s) {
+ if (s.ok()) {
+ *iter_ += delta_;
+ RunNext();
+ } else {
+ Finish(s);
+ }
+ });
+ }
+
+ void Finish(Status s) {
+ if (s.ok()) {
+ s = SetOutputs(kernel_, ctx_, rets_);
+ }
+ ctx_->SetStatus(s);
+ done_();
+ delete this;
+ }
+ };
+};
+
+REGISTER_KERNEL_BUILDER(Name("For").Device(DEVICE_CPU), ForOp);
+REGISTER_KERNEL_BUILDER(Name("For")
+ .Device(DEVICE_GPU)
+ .HostMemory("start")
+ .HostMemory("limit")
+ .HostMemory("delta"),
+ ForOp);
+} // namespace
} // namespace tensorflow
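
The ForOp kernel above drives its body asynchronously through the function library, but the contract it implements is an ordinary counted loop: validate start/limit/delta, then feed the iteration index plus the loop-carried state to the body until the index crosses the limit. A minimal synchronous sketch of that contract, using illustrative stand-ins (TensorVec as a plain value vector, BodyFn in place of FunctionLibraryRuntime::Run):

    #include <cstdint>
    #include <functional>
    #include <stdexcept>
    #include <vector>

    // Illustrative stand-ins; the real kernel moves Tensors through
    // FunctionLibraryRuntime::Run instead.
    using TensorVec = std::vector<int64_t>;
    using BodyFn = std::function<TensorVec(int32_t iter, const TensorVec&)>;

    // Synchronous equivalent of ForOp's StartLoop/RunNext pair.
    TensorVec RunForLoop(int32_t start, int32_t limit, int32_t delta,
                         TensorVec state, const BodyFn& body) {
      // StartLoop(): reject ranges the loop could never terminate on.
      bool valid = (delta > 0 && start <= limit) ||
                   (delta < 0 && start >= limit) ||
                   (delta == 0 && start == limit);
      if (!valid) throw std::invalid_argument("Invalid start/limit/delta");
      // RunNext(): the body's outputs become the next iteration's args_[1:].
      for (int32_t i = start; (delta > 0) ? i < limit : i > limit; i += delta) {
        state = body(i, state);
      }
      return state;
    }
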
diff --git a/tensorflow/core/kernels/lookup_table_op.cc b/tensorflow/core/kernels/lookup_table_op.cc
index e3872fee0e..57b7798ba0 100644
--- a/tensorflow/core/kernels/lookup_table_op.cc
+++ b/tensorflow/core/kernels/lookup_table_op.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/kernels/initializable_lookup_table.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/hash/hash.h"
@@ -62,8 +63,7 @@ class MutableHashTableOfScalars final : public LookupInterface {
mutex_lock l(mu_);
for (int64 i = 0; i < key_values.size(); ++i) {
value_values(i) = gtl::FindWithDefault(
- table_, SubtleMustCopyUnlessStringOrFloat(key_values(i)),
- default_val);
+ table_, SubtleMustCopyIfIntegral(key_values(i)), default_val);
}
return Status::OK();
@@ -78,9 +78,8 @@ class MutableHashTableOfScalars final : public LookupInterface {
table_.clear();
}
for (int64 i = 0; i < key_values.size(); ++i) {
- gtl::InsertOrUpdate(&table_,
- SubtleMustCopyUnlessStringOrFloat(key_values(i)),
- SubtleMustCopyUnlessStringOrFloat(value_values(i)));
+ gtl::InsertOrUpdate(&table_, SubtleMustCopyIfIntegral(key_values(i)),
+ SubtleMustCopyIfIntegral(value_values(i)));
}
return Status::OK();
}
@@ -172,8 +171,8 @@ class MutableHashTableOfTensors final : public LookupInterface {
mutex_lock l(mu_);
for (int64 i = 0; i < key_values.size(); ++i) {
- ValueArray* value_vec = gtl::FindOrNull(
- table_, SubtleMustCopyUnlessStringOrFloat(key_values(i)));
+ ValueArray* value_vec =
+ gtl::FindOrNull(table_, SubtleMustCopyIfIntegral(key_values(i)));
if (value_vec != nullptr) {
for (int64 j = 0; j < value_dim; j++) {
value_values(i, j) = value_vec->at(j);
@@ -203,8 +202,8 @@ class MutableHashTableOfTensors final : public LookupInterface {
V value = value_values(i, j);
value_vec.push_back(value);
}
- gtl::InsertOrUpdate(
- &table_, SubtleMustCopyUnlessStringOrFloat(key_values(i)), value_vec);
+ gtl::InsertOrUpdate(&table_, SubtleMustCopyIfIntegral(key_values(i)),
+ value_vec);
}
return Status::OK();
}
@@ -379,15 +378,14 @@ class MutableDenseHashTable final : public LookupInterface {
for (int64 j = 0; j < value_size; ++j) {
// TODO(andreasst): check if we can get rid of SubtleMustCopy
// here and elsewhere in this file.
- value_matrix(i, j) = SubtleMustCopyUnlessStringOrFloat(
- value_buckets_matrix(bucket_index, j));
+ value_matrix(i, j) =
+ SubtleMustCopyIfIntegral(value_buckets_matrix(bucket_index, j));
}
break;
}
if (IsEqualKey(key_buckets_matrix, bucket_index, empty_key_matrix, 0)) {
for (int64 j = 0; j < value_size; ++j) {
- value_matrix(i, j) =
- SubtleMustCopyUnlessStringOrFloat(default_flat(j));
+ value_matrix(i, j) = SubtleMustCopyIfIntegral(default_flat(j));
}
break;
}
@@ -531,7 +529,7 @@ class MutableDenseHashTable final : public LookupInterface {
if (IsEqualKey(key_buckets_matrix, bucket_index, key_matrix, i)) {
for (int64 j = 0; j < value_size; ++j) {
value_buckets_matrix(bucket_index, j) =
- SubtleMustCopyUnlessStringOrFloat(value_matrix(i, j));
+ SubtleMustCopyIfIntegral(value_matrix(i, j));
}
break;
}
@@ -539,11 +537,11 @@ class MutableDenseHashTable final : public LookupInterface {
++num_entries_;
for (int64 j = 0; j < key_size; ++j) {
key_buckets_matrix(bucket_index, j) =
- SubtleMustCopyUnlessStringOrFloat(key_matrix(i, j));
+ SubtleMustCopyIfIntegral(key_matrix(i, j));
}
for (int64 j = 0; j < value_size; ++j) {
value_buckets_matrix(bucket_index, j) =
- SubtleMustCopyUnlessStringOrFloat(value_matrix(i, j));
+ SubtleMustCopyIfIntegral(value_matrix(i, j));
}
break;
}
@@ -849,6 +847,7 @@ REGISTER_KERNEL(string, int64);
REGISTER_KERNEL(int64, string);
REGISTER_KERNEL(string, bool);
REGISTER_KERNEL(int64, float);
+REGISTER_KERNEL(int64, Variant);
#undef REGISTER_KERNEL
@@ -899,6 +898,7 @@ REGISTER_KERNEL(int64, double);
REGISTER_KERNEL(string, float);
REGISTER_KERNEL(string, bool);
REGISTER_KERNEL(int64, bool);
+REGISTER_KERNEL(int64, Variant);
#undef REGISTER_KERNEL
diff --git a/tensorflow/core/kernels/lookup_table_op.h b/tensorflow/core/kernels/lookup_table_op.h
index 3657fd5b6a..29a0cc91fe 100644
--- a/tensorflow/core/kernels/lookup_table_op.h
+++ b/tensorflow/core/kernels/lookup_table_op.h
@@ -125,19 +125,21 @@ namespace lookup {
// integral types. However, non-integer variables are not allowed and therefore
// the local copy is unnecessary.
template <typename T>
-T SubtleMustCopyUnlessStringOrFloat(const T& value) {
+T SubtleMustCopyIfIntegral(const T& value) {
return internal::SubtleMustCopy(value);
}
-inline const string& SubtleMustCopyUnlessStringOrFloat(const string& value) {
+inline const string& SubtleMustCopyIfIntegral(const string& value) {
return value;
}
-inline const float SubtleMustCopyUnlessStringOrFloat(const float value) {
+inline const float SubtleMustCopyIfIntegral(const float value) { return value; }
+
+inline const double SubtleMustCopyIfIntegral(const double value) {
return value;
}
-inline const double SubtleMustCopyUnlessStringOrFloat(const double value) {
+inline const Variant& SubtleMustCopyIfIntegral(const Variant& value) {
return value;
}
@@ -204,8 +206,8 @@ class HashTable : public InitializableLookupTable {
const auto key_values = keys.flat<K>();
const auto value_values = values.flat<V>();
for (int64 i = 0; i < key_values.size(); ++i) {
- const K key = SubtleMustCopyUnlessStringOrFloat(key_values(i));
- const V value = SubtleMustCopyUnlessStringOrFloat(value_values(i));
+ const K key = SubtleMustCopyIfIntegral(key_values(i));
+ const V value = SubtleMustCopyIfIntegral(value_values(i));
const V& previous_value = gtl::LookupOrInsert(table_.get(), key, value);
if (previous_value != value) {
return errors::FailedPrecondition(
@@ -224,8 +226,7 @@ class HashTable : public InitializableLookupTable {
for (int64 i = 0; i < key_values.size(); ++i) {
value_values(i) = gtl::FindWithDefault(
- *table_, SubtleMustCopyUnlessStringOrFloat(key_values(i)),
- default_val);
+ *table_, SubtleMustCopyIfIntegral(key_values(i)), default_val);
}
return Status::OK();
}
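
The rename makes the overload set's behavior explicit: only the generic template routes through internal::SubtleMustCopy, forcing a real copy of integral keys and values, while the string, float, double, and new Variant overloads pass their argument through untouched. A reduced sketch of the same dispatch, with a hypothetical MustCopy standing in for the internal helper:

    #include <cstdint>
    #include <iostream>
    #include <string>

    // Hypothetical stand-in for internal::SubtleMustCopy: the primary
    // template forces a copy so the caller cannot alias table-owned memory.
    template <typename T>
    T MustCopy(const T& value) {
      return value;  // integral types land here and get copied
    }
    // Non-template overloads are preferred by overload resolution and
    // skip the copy, as the string/float/double/Variant overloads do.
    inline const std::string& MustCopy(const std::string& value) { return value; }
    inline float MustCopy(float value) { return value; }

    int main() {
      int64_t key = 42;
      std::string name = "weights";
      int64_t key_copy = MustCopy(key);         // template: copied
      const std::string& ref = MustCopy(name);  // overload: same object back
      std::cout << key_copy << " " << (&ref == &name) << "\n";  // 42 1
    }
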
diff --git a/tensorflow/core/kernels/queue_op.h b/tensorflow/core/kernels/queue_op.h
index ad606803ee..6c19f9841c 100644
--- a/tensorflow/core/kernels/queue_op.h
+++ b/tensorflow/core/kernels/queue_op.h
@@ -43,6 +43,7 @@ class QueueOp : public ResourceOpKernel<QueueInterface> {
void Compute(OpKernelContext* context) override {
ResourceOpKernel<QueueInterface>::Compute(context);
+ mutex_lock l(mu_);
if (resource_ && context->track_allocations()) {
context->record_persistent_memory_allocation(resource_->MemoryUsed());
}
diff --git a/tensorflow/core/kernels/resource_variable_ops.cc b/tensorflow/core/kernels/resource_variable_ops.cc
index d1675f27dd..f49a05c70a 100644
--- a/tensorflow/core/kernels/resource_variable_ops.cc
+++ b/tensorflow/core/kernels/resource_variable_ops.cc
@@ -250,8 +250,9 @@ class AssignVariableOp : public OpKernel {
// Copying is unnecessary if we are the last user of the value
// tensor, we can just adopt the input tensor's buffer instead.
- std::unique_ptr<Tensor> input_alias =
- context->forward_input(1, dtype_, value.shape(), DEVICE_MEMORY, attr);
+ std::unique_ptr<Tensor> input_alias = context->forward_input(
+ 1, OpKernelContext::Params::kNoReservation /*output_index*/, dtype_,
+ value.shape(), DEVICE_MEMORY, attr);
mutex_lock ml(*variable->mu());
variable->is_initialized = true;
if (input_alias) {
@@ -363,9 +364,36 @@ class AssignVariableOp<Device, Variant> : public OpKernel {
DataTypeString(variable->tensor()->dtype()), " got ",
DataTypeString(DT_VARIANT)));
+ AllocatorAttributes attr;
+ attr.set_on_host(true);
+
+ // Copying is unnecessary if we are the last user of the value
+ // tensor, we can just adopt the input tensor's buffer instead.
+ // Note that Variant objects themselves always reside on host.
+ std::unique_ptr<Tensor> input_alias = context->forward_input(
+ 1, OpKernelContext::Params::kNoReservation /*output_index*/, DT_VARIANT,
+ value.shape(), HOST_MEMORY, attr);
+
mutex_lock ml(*variable->mu());
variable->is_initialized = true;
*variable->tensor() = Tensor(DT_VARIANT, value.shape());
+
+ if (input_alias) {
+ *variable->tensor() = *input_alias;
+ return;
+ }
+
+  // Need to copy, but maybe we can reuse the variable's buffer?
+ if (!variable->tensor()->RefCountIsOne() ||
+ !variable->tensor()->shape().IsSameSize(value.shape())) {
+ PersistentTensor unused;
+ Tensor* tmp;
+ OP_REQUIRES_OK(context,
+ context->allocate_persistent(DT_VARIANT, value.shape(),
+ &unused, &tmp, attr));
+ *variable->tensor() = *tmp;
+ }
+
const auto elements_in = value.flat<Variant>();
auto elements_out = variable->tensor()->flat<Variant>();
auto copy_fn = std::bind(&VariantCopyFn<Device>, context,
@@ -577,7 +605,7 @@ TF_CALL_QUANTIZED_TYPES(REGISTER_GATHER_CPU);
#if GOOGLE_CUDA
#define REGISTER_GATHER_GPU(type) REGISTER_GATHER_ALL_INDICES(GPU, type)
-TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_GATHER_GPU);
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_GATHER_GPU);
#endif // GOOGLE_CUDA
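
Both AssignVariableOp paths above now share a three-step strategy: adopt the input buffer when forward_input reports we were its last user, otherwise write in place when the variable's existing buffer is unshared and correctly shaped, otherwise allocate fresh storage and copy. A condensed sketch of that ladder, using shared_ptr use counts as a stand-in for Tensor::RefCountIsOne (illustrative only, not the real OpKernelContext API):

    #include <memory>

    struct Buffer { /* tensor payload */ };

    std::shared_ptr<Buffer> Assign(std::shared_ptr<Buffer> forwarded,
                                   std::shared_ptr<Buffer> current,
                                   bool shapes_match) {
      // Step 1: adopt the producer's buffer, zero copies (forward_input case).
      if (forwarded) return forwarded;
      // Step 2: reuse the variable's buffer in place (RefCountIsOne case).
      if (current.use_count() == 1 && shapes_match) return current;
      // Step 3: allocate new storage; the caller then copies element-wise.
      return std::make_shared<Buffer>();
    }
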
diff --git a/tensorflow/core/kernels/sparse_cross_op.cc b/tensorflow/core/kernels/sparse_cross_op.cc
index 4b5df7aff0..4ebb7fbcc7 100644
--- a/tensorflow/core/kernels/sparse_cross_op.cc
+++ b/tensorflow/core/kernels/sparse_cross_op.cc
@@ -419,7 +419,7 @@ class SparseCrossOp : public OpKernel {
context, TensorShapeUtils::IsMatrix(dense_list_in[i].shape()),
errors::InvalidArgument(
"Dense inputs should be a matrix but received shape ",
- indices_list_in[i].shape().DebugString(), " at position ", i));
+ dense_list_in[i].shape().DebugString(), " at position ", i));
OP_REQUIRES(context, dense_list_in[i].dim_size(0) == batch_size,
errors::InvalidArgument("Expected batch size ", batch_size,
" got ", dense_list_in[i].dim_size(0),
diff --git a/tensorflow/core/lib/core/stringpiece.cc b/tensorflow/core/lib/core/stringpiece.cc
index 5bd79778a6..0b006fa2b4 100644
--- a/tensorflow/core/lib/core/stringpiece.cc
+++ b/tensorflow/core/lib/core/stringpiece.cc
@@ -55,6 +55,4 @@ StringPiece StringPiece::substr(size_t pos, size_t n) const {
return StringPiece(data_ + pos, n);
}
-const StringPiece::size_type StringPiece::npos = size_type(-1);
-
} // namespace tensorflow
diff --git a/tensorflow/core/lib/core/stringpiece.h b/tensorflow/core/lib/core/stringpiece.h
index 79409cce4b..835b938cbf 100644
--- a/tensorflow/core/lib/core/stringpiece.h
+++ b/tensorflow/core/lib/core/stringpiece.h
@@ -65,7 +65,7 @@ class StringPiece {
iterator begin() const { return data_; }
iterator end() const { return data_ + size_; }
- static const size_t npos;
+ static const size_t npos = size_type(-1);
// Return the ith byte in the referenced data.
// REQUIRES: n < size()
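
Folding the value into the declaration relies on C++'s in-class initialization of static const integral members. Under pre-C++17 rules, dropping the .cc definition is safe only while npos is never odr-used (for example, its address is never taken); uses by value are fine. A minimal illustration of the distinction:

    #include <cstddef>
    #include <iostream>

    struct Piece {
      // In-class initializer is allowed for static const integral members.
      static const size_t npos = static_cast<size_t>(-1);
    };

    size_t ByValue(size_t n) { return n; }

    int main() {
      // Fine without an out-of-line definition: npos is used by value only.
      std::cout << ByValue(Piece::npos) << "\n";
      // Pre-C++17, the next line would odr-use npos and require a
      // namespace-scope definition 'const size_t Piece::npos;':
      // const size_t* p = &Piece::npos;
    }
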
diff --git a/tensorflow/core/lib/io/format.cc b/tensorflow/core/lib/io/format.cc
index 64852943ad..0c24c660a2 100644
--- a/tensorflow/core/lib/io/format.cc
+++ b/tensorflow/core/lib/io/format.cc
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include <limits>
+
#include "tensorflow/core/lib/io/format.h"
#include "tensorflow/core/lib/core/coding.h"
@@ -84,6 +86,11 @@ Status ReadBlock(RandomAccessFile* file, const BlockHandle& handle,
// Read the block contents as well as the type/crc footer.
// See table_builder.cc for the code that built this structure.
size_t n = static_cast<size_t>(handle.size());
+
+ if (kBlockTrailerSize > std::numeric_limits<size_t>::max() - n) {
+ return errors::DataLoss("handle.size() too big");
+ }
+
char* buf = new char[n + kBlockTrailerSize];
StringPiece contents;
Status s = file->Read(handle.offset(), n + kBlockTrailerSize, &contents, buf);
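
The added guard is the standard overflow-safe form of an unsigned addition check: n + kBlockTrailerSize wraps exactly when kBlockTrailerSize > SIZE_MAX - n, and the subtraction on the right-hand side can never itself overflow. A self-contained version of the predicate:

    #include <cstddef>
    #include <iostream>
    #include <limits>

    // Returns true iff a + b would wrap around for size_t.
    bool AddWouldOverflow(size_t a, size_t b) {
      return b > std::numeric_limits<size_t>::max() - a;
    }

    int main() {
      size_t n = std::numeric_limits<size_t>::max() - 3;
      std::cout << AddWouldOverflow(n, 3) << "\n";  // 0: sum fits exactly
      std::cout << AddWouldOverflow(n, 4) << "\n";  // 1: sum would wrap
    }
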
diff --git a/tensorflow/core/lib/strings/numbers.cc b/tensorflow/core/lib/strings/numbers.cc
index 516decc3c0..8f34baa7de 100644
--- a/tensorflow/core/lib/strings/numbers.cc
+++ b/tensorflow/core/lib/strings/numbers.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include <locale>
#include <unordered_map>
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
@@ -203,7 +204,7 @@ bool safe_strto64(StringPiece str, int64* value) {
int64 vlimit = kint64max;
int sign = 1;
- if (str.Consume("-")) {
+ if (str_util::ConsumePrefix(&str, "-")) {
sign = -1;
// Different limit for positive and negative integers.
vlimit = kint64min;
@@ -265,7 +266,7 @@ bool safe_strto32(StringPiece str, int32* value) {
int64 vmax = kint32max;
int sign = 1;
- if (str.Consume("-")) {
+ if (str_util::ConsumePrefix(&str, "-")) {
sign = -1;
// Different max for positive and negative integers.
++vmax;
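
The str_util::ConsumePrefix calls introduced here and in the following files keep StringPiece::Consume's contract: strip the prefix in place when present and report whether anything was consumed. A sketch of that contract over C++17 std::string_view (ConsumePrefix itself is the TF helper; this reimplementation is for illustration only):

    #include <iostream>
    #include <string_view>

    // Same contract as str_util::ConsumePrefix: if *s starts with prefix,
    // strip it in place and return true; otherwise leave *s untouched.
    bool ConsumePrefix(std::string_view* s, std::string_view prefix) {
      if (s->substr(0, prefix.size()) != prefix) return false;
      s->remove_prefix(prefix.size());
      return true;
    }

    int main() {
      std::string_view num = "-12345";
      int sign = ConsumePrefix(&num, "-") ? -1 : 1;  // as in safe_strto64 above
      std::cout << sign << " " << num << "\n";       // prints: -1 12345
    }
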
diff --git a/tensorflow/core/lib/strings/ordered_code_test.cc b/tensorflow/core/lib/strings/ordered_code_test.cc
index fee8a6f93e..ede9f4d390 100644
--- a/tensorflow/core/lib/strings/ordered_code_test.cc
+++ b/tensorflow/core/lib/strings/ordered_code_test.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
@@ -128,7 +129,7 @@ void TestWriteAppends(T first, U second) {
string encoded_first_only = encoded;
OCWriteToString<U>(&encoded, second);
EXPECT_NE(encoded, encoded_first_only);
- EXPECT_TRUE(StringPiece(encoded).starts_with(encoded_first_only));
+ EXPECT_TRUE(str_util::StartsWith(encoded, encoded_first_only));
}
template <typename T>
diff --git a/tensorflow/core/lib/strings/scanner.h b/tensorflow/core/lib/strings/scanner.h
index d3b63357ee..c82e771368 100644
--- a/tensorflow/core/lib/strings/scanner.h
+++ b/tensorflow/core/lib/strings/scanner.h
@@ -18,6 +18,7 @@ limitations under the License.
#include <string>
#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/macros.h"
namespace tensorflow {
@@ -75,14 +76,14 @@ class Scanner {
// Consume the next s.size() characters of the input, if they match <s>. If
// they don't match <s>, this is a no-op.
Scanner& ZeroOrOneLiteral(StringPiece s) {
- cur_.Consume(s);
+ str_util::ConsumePrefix(&cur_, s);
return *this;
}
// Consume the next s.size() characters of the input, if they match <s>. If
// they don't match <s>, then GetResult will ultimately return false.
Scanner& OneLiteral(StringPiece s) {
- if (!cur_.Consume(s)) {
+ if (!str_util::ConsumePrefix(&cur_, s)) {
error_ = true;
}
return *this;
diff --git a/tensorflow/core/lib/wav/wav_io_test.cc b/tensorflow/core/lib/wav/wav_io_test.cc
index d8a83fc464..9e41da6a20 100644
--- a/tensorflow/core/lib/wav/wav_io_test.cc
+++ b/tensorflow/core/lib/wav/wav_io_test.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/error_codes.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
@@ -203,7 +204,7 @@ TEST(WavIO, ChunkSizeOverflow) {
wav_data_string, &decoded_audio, &decoded_sample_count,
&decoded_channel_count, &decoded_sample_rate);
EXPECT_FALSE(decode_status.ok());
- EXPECT_TRUE(StringPiece(decode_status.error_message()).contains("too large"))
+ EXPECT_TRUE(str_util::StrContains(decode_status.error_message(), "too large"))
<< decode_status.error_message();
}
diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
index 7cdf36f423..10b24c2d34 100644
--- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
@@ -20672,6 +20672,38 @@ op {
is_stateful: true
}
op {
+ name: "For"
+ input_arg {
+ name: "start"
+ type: DT_INT32
+ }
+ input_arg {
+ name: "limit"
+ type: DT_INT32
+ }
+ input_arg {
+ name: "delta"
+ type: DT_INT32
+ }
+ input_arg {
+ name: "input"
+ type_list_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type_list_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "body"
+ type: "func"
+ }
+}
+op {
name: "FractionalAvgPool"
input_arg {
name: "value"
@@ -22755,6 +22787,45 @@ op {
is_stateful: true
}
op {
+ name: "If"
+ input_arg {
+ name: "cond"
+ type_attr: "Tcond"
+ }
+ input_arg {
+ name: "input"
+ type_list_attr: "Tin"
+ }
+ output_arg {
+ name: "output"
+ type_list_attr: "Tout"
+ }
+ attr {
+ name: "Tcond"
+ type: "type"
+ }
+ attr {
+ name: "Tin"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "Tout"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "then_branch"
+ type: "func"
+ }
+ attr {
+ name: "else_branch"
+ type: "func"
+ }
+}
+op {
name: "Igamma"
input_arg {
name: "a"
@@ -68076,6 +68147,31 @@ op {
}
}
op {
+ name: "While"
+ input_arg {
+ name: "input"
+ type_list_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type_list_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "cond"
+ type: "func"
+ }
+ attr {
+ name: "body"
+ type: "func"
+ }
+ is_stateful: true
+}
+op {
name: "WholeFileReader"
output_arg {
name: "reader_handle"
diff --git a/tensorflow/core/ops/functional_ops.cc b/tensorflow/core/ops/functional_ops.cc
index 4b21fac80a..792686cae1 100644
--- a/tensorflow/core/ops/functional_ops.cc
+++ b/tensorflow/core/ops/functional_ops.cc
@@ -50,6 +50,7 @@ REGISTER_OP("RemoteCall")
.SetIsStateful()
.SetShapeFn(shape_inference::UnknownShape);
+// TODO(drpng): remove this.
REGISTER_OP("_If")
.Input("cond: Tcond")
.Input("input: Tin")
@@ -76,8 +77,18 @@ else_branch: A function that takes 'inputs' and returns a list of
tensors, whose types are the same as what then_branch returns.
)doc");
-// TODO(b/37549631) setting the While Op to always be stateful is too
-// conservative.
+REGISTER_OP("If")
+ .Input("cond: Tcond")
+ .Input("input: Tin")
+ .Output("output: Tout")
+ .Attr("Tcond: type")
+ .Attr("Tin: list(type)")
+ .Attr("Tout: list(type)")
+ .Attr("then_branch: func")
+ .Attr("else_branch: func")
+ .SetShapeFn(shape_inference::UnknownShape);
+
+// TODO(drpng): remove this.
REGISTER_OP("_While")
.Input("input: T")
.Output("output: T")
@@ -108,4 +119,30 @@ body: A function that takes a list of tensors and returns another
by T.
)doc");
+// TODO(b/37549631) setting the While Op to always be stateful is too
+// conservative.
+REGISTER_OP("While")
+ .Input("input: T")
+ .Output("output: T")
+ .Attr("T: list(type) >= 0")
+ .Attr("cond: func")
+ .Attr("body: func")
+ .SetIsStateful()
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ for (int i = 0; i < c->num_outputs(); ++i) {
+ c->set_output(i, c->input(i));
+ }
+ return Status::OK();
+ });
+
+REGISTER_OP("For")
+ .Input("start: int32")
+ .Input("limit: int32")
+ .Input("delta: int32")
+ .Input("input: T")
+ .Output("output: T")
+ .Attr("T: list(type) >= 0")
+ .Attr("body: func")
+ .SetShapeFn(shape_inference::UnknownShape);
+
} // end namespace tensorflow
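
The public While registration forwards each input shape to the matching output, encoding the assumption that the loop body is shape-preserving; For and If settle for UnknownShape because their outputs need not mirror any input. The same forwarding idiom in isolation, on a hypothetical op:

    #include "tensorflow/core/framework/op.h"
    #include "tensorflow/core/framework/shape_inference.h"

    namespace tensorflow {

    // Hypothetical op, for illustration only: any op whose outputs are
    // shaped exactly like its inputs can use the forwarding shape function
    // that the new While registration uses.
    REGISTER_OP("ExampleIdentityList")
        .Input("input: T")
        .Output("output: T")
        .Attr("T: list(type) >= 0")
        .SetShapeFn([](shape_inference::InferenceContext* c) {
          for (int i = 0; i < c->num_outputs(); ++i) {
            c->set_output(i, c->input(i));  // output i inherits input i's shape
          }
          return Status::OK();
        });

    }  // namespace tensorflow
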
diff --git a/tensorflow/core/ops/math_grad_test.cc b/tensorflow/core/ops/math_grad_test.cc
index 8dcd3e815f..da38a6bc24 100644
--- a/tensorflow/core/ops/math_grad_test.cc
+++ b/tensorflow/core/ops/math_grad_test.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/public/session.h"
@@ -362,7 +363,7 @@ class MathGradTest : public ::testing::Test {
};
void HasError(const Status& s, const string& substr) {
- EXPECT_TRUE(StringPiece(s.ToString()).contains(substr))
+ EXPECT_TRUE(str_util::StrContains(s.ToString(), substr))
<< s << ", expected substring " << substr;
}
diff --git a/tensorflow/core/ops/math_ops_test.cc b/tensorflow/core/ops/math_ops_test.cc
index ca3772e6f8..8f974d5367 100644
--- a/tensorflow/core/ops/math_ops_test.cc
+++ b/tensorflow/core/ops/math_ops_test.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
@@ -239,20 +240,21 @@ TEST(MathOpsTest, Select_ShapeFn) {
// Expect an error when the shapes can't be merged.
handle_data[2]->at(0).first = shape_proto({2, 2});
- EXPECT_TRUE(StringPiece(run_inference_for_handles().error_message())
- .contains("must be equal, but are 1 and 2"));
+ EXPECT_TRUE(str_util::StrContains(run_inference_for_handles().error_message(),
+ "must be equal, but are 1 and 2"));
handle_data[2]->at(0).first = i1; // restore to valid
// Expect an error when the types can't be merged.
handle_data[2]->at(1).second = DT_INT64;
- EXPECT_TRUE(StringPiece(run_inference_for_handles().error_message())
- .contains("pointing to different dtypes"));
+ EXPECT_TRUE(str_util::StrContains(run_inference_for_handles().error_message(),
+ "pointing to different dtypes"));
handle_data[2]->at(1).second = DT_INT32; // restore to valid
// Expect an error when different numbers of tensors are merged.
handle_data[2]->push_back({i1, DT_FLOAT});
- EXPECT_TRUE(StringPiece(run_inference_for_handles().error_message())
- .contains("pointing to different numbers of tensors"));
+ EXPECT_TRUE(
+ str_util::StrContains(run_inference_for_handles().error_message(),
+ "pointing to different numbers of tensors"));
handle_data[2]->pop_back(); // restore to valid.
}
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index 42a68cb712..5764976aee 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -9780,6 +9780,38 @@ op {
is_stateful: true
}
op {
+ name: "For"
+ input_arg {
+ name: "start"
+ type: DT_INT32
+ }
+ input_arg {
+ name: "limit"
+ type: DT_INT32
+ }
+ input_arg {
+ name: "delta"
+ type: DT_INT32
+ }
+ input_arg {
+ name: "input"
+ type_list_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type_list_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "body"
+ type: "func"
+ }
+}
+op {
name: "FractionalAvgPool"
input_arg {
name: "value"
@@ -11184,6 +11216,45 @@ op {
is_stateful: true
}
op {
+ name: "If"
+ input_arg {
+ name: "cond"
+ type_attr: "Tcond"
+ }
+ input_arg {
+ name: "input"
+ type_list_attr: "Tin"
+ }
+ output_arg {
+ name: "output"
+ type_list_attr: "Tout"
+ }
+ attr {
+ name: "Tcond"
+ type: "type"
+ }
+ attr {
+ name: "Tin"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "Tout"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "then_branch"
+ type: "func"
+ }
+ attr {
+ name: "else_branch"
+ type: "func"
+ }
+}
+op {
name: "Igamma"
input_arg {
name: "a"
@@ -32937,6 +33008,31 @@ op {
}
}
op {
+ name: "While"
+ input_arg {
+ name: "input"
+ type_list_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type_list_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "cond"
+ type: "func"
+ }
+ attr {
+ name: "body"
+ type: "func"
+ }
+ is_stateful: true
+}
+op {
name: "WholeFileReader"
output_arg {
name: "reader_handle"
diff --git a/tensorflow/core/platform/abi.cc b/tensorflow/core/platform/abi.cc
index 4df62734e9..e597a490d6 100644
--- a/tensorflow/core/platform/abi.cc
+++ b/tensorflow/core/platform/abi.cc
@@ -15,7 +15,7 @@ limitations under the License.
#include "tensorflow/core/platform/abi.h"
-#if defined(PLATFORM_WINDOWS)
+#if defined(_MSC_VER)
#include <windows.h>
#include <cstring>
#else
@@ -26,19 +26,19 @@ limitations under the License.
#include <memory>
#include <string>
-#if defined(PLATFORM_WINDOWS)
+#if defined(_MSC_VER)
extern "C" char* __unDName(char* output_string, const char* name,
int max_string_length, void* (*p_alloc)(std::size_t),
void (*p_free)(void*), unsigned short disable_flags);
-#endif // defined(PLATFORM_WINDOWS)
+#endif // defined(_MSC_VER)
namespace tensorflow {
namespace port {
std::string MaybeAbiDemangle(const char* name) {
-#if defined(PLATFORM_WINDOWS)
+#if defined(_MSC_VER)
std::unique_ptr<char> demangled{__unDName(nullptr, name, 0, std::malloc,
std::free,
static_cast<unsigned short>(0))};
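
Keying on _MSC_VER instead of PLATFORM_WINDOWS selects the demangler by compiler rather than by OS, so MinGW-style toolchains on Windows fall through to the Itanium path their g++/clang mangling requires. Call sites are unaffected; a small usage sketch (assumes only the MaybeAbiDemangle declaration from abi.h):

    #include <iostream>
    #include <typeinfo>
    #include "tensorflow/core/platform/abi.h"

    template <typename T>
    void PrintTypeName() {
      // typeid(...).name() is mangled with g++/clang (e.g. "i" for int) and
      // already readable with MSVC; MaybeAbiDemangle normalizes both cases.
      std::cout << tensorflow::port::MaybeAbiDemangle(typeid(T).name())
                << "\n";
    }

    int main() { PrintTypeName<int>(); }  // "int" after demangling
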
diff --git a/tensorflow/core/platform/cloud/BUILD b/tensorflow/core/platform/cloud/BUILD
index 3ee7be3c4e..be84316c48 100644
--- a/tensorflow/core/platform/cloud/BUILD
+++ b/tensorflow/core/platform/cloud/BUILD
@@ -85,6 +85,7 @@ cc_library(
":retrying_utils",
":time_util",
"//tensorflow/core:framework_headers_lib",
+ "//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"@jsoncpp_git//:jsoncpp",
],
@@ -263,6 +264,7 @@ tf_cc_test(
deps = [
":gcs_file_system",
":http_request_fake",
+ "//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
diff --git a/tensorflow/core/platform/cloud/gcs_file_system.cc b/tensorflow/core/platform/cloud/gcs_file_system.cc
index 1691826483..3c0dc13d75 100644
--- a/tensorflow/core/platform/cloud/gcs_file_system.cc
+++ b/tensorflow/core/platform/cloud/gcs_file_system.cc
@@ -172,7 +172,7 @@ Status ParseGcsPath(StringPiece fname, bool empty_object_ok, string* bucket,
return errors::InvalidArgument("GCS path doesn't contain a bucket name: ",
fname);
}
- objectp.Consume("/");
+ str_util::ConsumePrefix(&objectp, "/");
*object = objectp.ToString();
if (!empty_object_ok && object->empty()) {
return errors::InvalidArgument("GCS path doesn't contain an object name: ",
@@ -535,7 +535,8 @@ class GcsWritableFile : public WritableFile {
*uploaded = 0;
} else {
StringPiece range_piece(received_range);
- range_piece.Consume("bytes="); // May or may not be present.
+ str_util::ConsumePrefix(&range_piece,
+ "bytes="); // May or may not be present.
std::vector<int64> range_parts;
if (!str_util::SplitAndParseAsInts(range_piece, '-', &range_parts) ||
range_parts.size() != 2) {
@@ -1172,7 +1173,7 @@ Status GcsFileSystem::GetChildrenBounded(const string& dirname,
// 'object_prefix', which is part of 'dirname', should be removed from
// the beginning of 'name'.
StringPiece relative_path(name);
- if (!relative_path.Consume(object_prefix)) {
+ if (!str_util::ConsumePrefix(&relative_path, object_prefix)) {
return errors::Internal(strings::StrCat(
"Unexpected response: the returned file name ", name,
" doesn't match the prefix ", object_prefix));
@@ -1201,7 +1202,7 @@ Status GcsFileSystem::GetChildrenBounded(const string& dirname,
}
const string& prefix_str = prefix.asString();
StringPiece relative_path(prefix_str);
- if (!relative_path.Consume(object_prefix)) {
+ if (!str_util::ConsumePrefix(&relative_path, object_prefix)) {
return errors::Internal(
"Unexpected response: the returned folder name ", prefix_str,
" doesn't match the prefix ", object_prefix);
diff --git a/tensorflow/core/platform/cloud/gcs_file_system_test.cc b/tensorflow/core/platform/cloud/gcs_file_system_test.cc
index 8516421614..2fbde9b6a7 100644
--- a/tensorflow/core/platform/cloud/gcs_file_system_test.cc
+++ b/tensorflow/core/platform/cloud/gcs_file_system_test.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/core/platform/cloud/gcs_file_system.h"
#include <fstream>
#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/cloud/http_request_fake.h"
#include "tensorflow/core/platform/test.h"
@@ -584,8 +585,9 @@ TEST(GcsFileSystemTest, NewWritableFile_ResumeUploadAllAttemptsFail) {
TF_EXPECT_OK(file->Append("content2"));
const auto& status = file->Close();
EXPECT_EQ(errors::Code::ABORTED, status.code());
- EXPECT_TRUE(StringPiece(status.error_message())
- .contains("All 10 retry attempts failed. The last failure: "
+ EXPECT_TRUE(
+ str_util::StrContains(status.error_message(),
+ "All 10 retry attempts failed. The last failure: "
"Unavailable: important HTTP error 503"))
<< status;
}
@@ -641,13 +643,12 @@ TEST(GcsFileSystemTest, NewWritableFile_UploadReturns410) {
const auto& status = file->Close();
EXPECT_EQ(errors::Code::UNAVAILABLE, status.code());
EXPECT_TRUE(
- StringPiece(status.error_message())
- .contains(
- "Upload to gs://bucket/path/writeable.txt failed, caused by: "
- "Not found: important HTTP error 410"))
+ str_util::StrContains(status.error_message(),
+ "Upload to gs://bucket/path/writeable.txt failed, "
+ "caused by: Not found: important HTTP error 410"))
<< status;
- EXPECT_TRUE(StringPiece(status.error_message())
- .contains("when uploading gs://bucket/path/writeable.txt"))
+ EXPECT_TRUE(str_util::StrContains(
+ status.error_message(), "when uploading gs://bucket/path/writeable.txt"))
<< status;
}
diff --git a/tensorflow/core/platform/cloud/retrying_file_system_test.cc b/tensorflow/core/platform/cloud/retrying_file_system_test.cc
index d3f763bb3c..ee6886fef7 100644
--- a/tensorflow/core/platform/cloud/retrying_file_system_test.cc
+++ b/tensorflow/core/platform/cloud/retrying_file_system_test.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/core/platform/cloud/retrying_file_system.h"
#include <fstream>
#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
@@ -245,7 +246,7 @@ TEST(RetryingFileSystemTest, NewRandomAccessFile_AllRetriesFailed) {
char scratch[10];
const auto& status = random_access_file->Read(0, 10, &result, scratch);
EXPECT_TRUE(
- StringPiece(status.error_message()).contains("Retriable error #10"))
+ str_util::StrContains(status.error_message(), "Retriable error #10"))
<< status;
}
@@ -399,7 +400,7 @@ TEST(RetryingFileSystemTest, NewWritableFile_AllRetriesFailed) {
// Use it and check the results.
const auto& status = writable_file->Sync();
EXPECT_TRUE(
- StringPiece(status.error_message()).contains("Retriable error #10"))
+ str_util::StrContains(status.error_message(), "Retriable error #10"))
<< status;
}
@@ -428,7 +429,7 @@ TEST(RetryingFileSystemTest, NewReadOnlyMemoryRegionFromFile_AllRetriesFailed) {
const auto& status =
fs.NewReadOnlyMemoryRegionFromFile("filename.txt", &result);
EXPECT_TRUE(
- StringPiece(status.error_message()).contains("Retriable error #10"))
+ str_util::StrContains(status.error_message(), "Retriable error #10"))
<< status;
}
@@ -454,7 +455,7 @@ TEST(RetryingFileSystemTest, GetChildren_AllRetriesFailed) {
std::vector<string> result;
const auto& status = fs.GetChildren("gs://path", &result);
EXPECT_TRUE(
- StringPiece(status.error_message()).contains("Retriable error #10"))
+ str_util::StrContains(status.error_message(), "Retriable error #10"))
<< status;
}
@@ -481,7 +482,7 @@ TEST(RetryingFileSystemTest, GetMatchingPaths_AllRetriesFailed) {
std::vector<string> result;
const auto& status = fs.GetMatchingPaths("gs://path/dir", &result);
EXPECT_TRUE(
- StringPiece(status.error_message()).contains("Retriable error #10"))
+ str_util::StrContains(status.error_message(), "Retriable error #10"))
<< status;
}
@@ -506,7 +507,7 @@ TEST(RetryingFileSystemTest, DeleteFile_AllRetriesFailed) {
std::vector<string> result;
const auto& status = fs.DeleteFile("gs://path/file.txt");
EXPECT_TRUE(
- StringPiece(status.error_message()).contains("Retriable error #10"))
+ str_util::StrContains(status.error_message(), "Retriable error #10"))
<< status;
}
@@ -531,7 +532,7 @@ TEST(RetryingFileSystemTest, CreateDir_AllRetriesFailed) {
std::vector<string> result;
const auto& status = fs.CreateDir("gs://path/newdir");
EXPECT_TRUE(
- StringPiece(status.error_message()).contains("Retriable error #10"))
+ str_util::StrContains(status.error_message(), "Retriable error #10"))
<< status;
}
@@ -556,7 +557,7 @@ TEST(RetryingFileSystemTest, DeleteDir_AllRetriesFailed) {
std::vector<string> result;
const auto& status = fs.DeleteDir("gs://path/dir");
EXPECT_TRUE(
- StringPiece(status.error_message()).contains("Retriable error #10"))
+ str_util::StrContains(status.error_message(), "Retriable error #10"))
<< status;
}
@@ -582,7 +583,7 @@ TEST(RetryingFileSystemTest, GetFileSize_AllRetriesFailed) {
uint64 size;
const auto& status = fs.GetFileSize("gs://path/file.txt", &size);
EXPECT_TRUE(
- StringPiece(status.error_message()).contains("Retriable error #10"))
+ str_util::StrContains(status.error_message(), "Retriable error #10"))
<< status;
}
@@ -605,7 +606,7 @@ TEST(RetryingFileSystemTest, RenameFile_AllRetriesFailed) {
const auto& status = fs.RenameFile("old_name", "new_name");
EXPECT_TRUE(
- StringPiece(status.error_message()).contains("Retriable error #10"))
+ str_util::StrContains(status.error_message(), "Retriable error #10"))
<< status;
}
@@ -630,7 +631,7 @@ TEST(RetryingFileSystemTest, Stat_AllRetriesFailed) {
FileStatistics stat;
const auto& status = fs.Stat("file_name", &stat);
EXPECT_TRUE(
- StringPiece(status.error_message()).contains("Retriable error #10"))
+ str_util::StrContains(status.error_message(), "Retriable error #10"))
<< status;
}
@@ -642,7 +643,7 @@ TEST(RetryingFileSystemTest, FileExists_AllRetriesFailed) {
const auto& status = fs.FileExists("file_name");
EXPECT_TRUE(
- StringPiece(status.error_message()).contains("Retriable error #10"))
+ str_util::StrContains(status.error_message(), "Retriable error #10"))
<< status;
}
@@ -677,7 +678,7 @@ TEST(RetryingFileSystemTest, IsDirectory_AllRetriesFailed) {
const auto& status = fs.IsDirectory("gs://path/dir");
EXPECT_TRUE(
- StringPiece(status.error_message()).contains("Retriable error #10"))
+ str_util::StrContains(status.error_message(), "Retriable error #10"))
<< status;
}
@@ -706,7 +707,7 @@ TEST(RetryingFileSystemTest, DeleteRecursively_AllRetriesFailed) {
const auto& status =
fs.DeleteRecursively("gs://path/dir", &undeleted_files, &undeleted_dirs);
EXPECT_TRUE(
- StringPiece(status.error_message()).contains("Retriable error #10"))
+ str_util::StrContains(status.error_message(), "Retriable error #10"))
<< status;
}
diff --git a/tensorflow/core/platform/cloud/retrying_utils_test.cc b/tensorflow/core/platform/cloud/retrying_utils_test.cc
index 6eb340e094..1b6527618a 100644
--- a/tensorflow/core/platform/cloud/retrying_utils_test.cc
+++ b/tensorflow/core/platform/cloud/retrying_utils_test.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/core/platform/cloud/retrying_utils.h"
#include <fstream>
#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"
@@ -31,10 +32,9 @@ TEST(RetryingUtilsTest, CallWithRetries_RetryDelays) {
const auto& status = RetryingUtils::CallWithRetries(f, 500000L, sleep);
EXPECT_EQ(errors::Code::ABORTED, status.code());
- EXPECT_TRUE(StringPiece(status.error_message())
- .contains("All 10 retry attempts "
- "failed. The last failure: "
- "Unavailable: Failed."))
+ EXPECT_TRUE(str_util::StrContains(
+ status.error_message(),
+ "All 10 retry attempts failed. The last failure: Unavailable: Failed."))
<< status;
EXPECT_EQ(10, requested_delays.size());
diff --git a/tensorflow/core/platform/default/tracing_impl.h b/tensorflow/core/platform/default/tracing_impl.h
index e813e4a17a..7834548896 100644
--- a/tensorflow/core/platform/default/tracing_impl.h
+++ b/tensorflow/core/platform/default/tracing_impl.h
@@ -22,7 +22,6 @@ limitations under the License.
// IWYU pragma: friend third_party/tensorflow/core/platform/tracing.h
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/tracing.h"
diff --git a/tensorflow/core/platform/denormal.cc b/tensorflow/core/platform/denormal.cc
index 3631d9ddf9..82cbc43b4f 100644
--- a/tensorflow/core/platform/denormal.cc
+++ b/tensorflow/core/platform/denormal.cc
@@ -13,8 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include <tuple>
+
#include "tensorflow/core/platform/denormal.h"
-#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/platform.h"
diff --git a/tensorflow/core/platform/file_system.cc b/tensorflow/core/platform/file_system.cc
index a2f42f44ac..b55e94d552 100644
--- a/tensorflow/core/platform/file_system.cc
+++ b/tensorflow/core/platform/file_system.cc
@@ -18,7 +18,6 @@ limitations under the License.
#include <deque>
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
@@ -28,28 +27,6 @@ limitations under the License.
namespace tensorflow {
-namespace {
-
-constexpr int kNumThreads = 8;
-
-// Run a function in parallel using a ThreadPool, but skip the ThreadPool
-// on the iOS platform due to its problems with more than a few threads.
-void ForEach(int first, int last, const std::function<void(int)>& f) {
-#if TARGET_OS_IPHONE
- for (int i = first; i < last; i++) {
- f(i);
- }
-#else
- int num_threads = std::min(kNumThreads, last - first);
- thread::ThreadPool threads(Env::Default(), "ForEach", num_threads);
- for (int i = first; i < last; i++) {
- threads.Schedule([f, i] { f(i); });
- }
-#endif
-}
-
-} // anonymous namespace
-
FileSystem::~FileSystem() {}
string FileSystem::TranslateName(const string& name) const {
@@ -94,76 +71,6 @@ bool FileSystem::FilesExist(const std::vector<string>& files,
return result;
}
-Status FileSystem::GetMatchingPaths(const string& pattern,
- std::vector<string>* results) {
- results->clear();
- // Find the fixed prefix by looking for the first wildcard.
- string fixed_prefix = pattern.substr(0, pattern.find_first_of("*?[\\"));
- string eval_pattern = pattern;
- std::vector<string> all_files;
- string dir = io::Dirname(fixed_prefix).ToString();
- // If dir is empty then we need to fix up fixed_prefix and eval_pattern to
- // include . as the top level directory.
- if (dir.empty()) {
- dir = ".";
- fixed_prefix = io::JoinPath(dir, fixed_prefix);
- eval_pattern = io::JoinPath(dir, pattern);
- }
-
- // Setup a BFS to explore everything under dir.
- std::deque<string> dir_q;
- dir_q.push_back(dir);
- Status ret; // Status to return.
- // children_dir_status holds is_dir status for children. It can have three
- // possible values: OK for true; FAILED_PRECONDITION for false; CANCELLED
- // if we don't calculate IsDirectory (we might do that because there isn't
- // any point in exploring that child path).
- std::vector<Status> children_dir_status;
- while (!dir_q.empty()) {
- string current_dir = dir_q.front();
- dir_q.pop_front();
- std::vector<string> children;
- Status s = GetChildren(current_dir, &children);
- ret.Update(s);
- if (children.empty()) continue;
- // This IsDirectory call can be expensive for some FS. Parallelizing it.
- children_dir_status.resize(children.size());
- ForEach(0, children.size(),
- [this, &current_dir, &children, &fixed_prefix,
- &children_dir_status](int i) {
- const string child_path = io::JoinPath(current_dir, children[i]);
- // In case the child_path doesn't start with the fixed_prefix then
- // we don't need to explore this path.
- if (!str_util::StartsWith(child_path, fixed_prefix)) {
- children_dir_status[i] = Status(tensorflow::error::CANCELLED,
- "Operation not needed");
- } else {
- children_dir_status[i] = IsDirectory(child_path);
- }
- });
- for (int i = 0; i < children.size(); ++i) {
- const string child_path = io::JoinPath(current_dir, children[i]);
- // If the IsDirectory call was cancelled we bail.
- if (children_dir_status[i].code() == tensorflow::error::CANCELLED) {
- continue;
- }
- // If the child is a directory add it to the queue.
- if (children_dir_status[i].ok()) {
- dir_q.push_back(child_path);
- }
- all_files.push_back(child_path);
- }
- }
-
- // Match all obtained files to the input pattern.
- for (const auto& f : all_files) {
- if (Env::Default()->MatchPath(f, eval_pattern)) {
- results->push_back(f);
- }
- }
- return ret;
-}
-
Status FileSystem::DeleteRecursively(const string& dirname,
int64* undeleted_files,
int64* undeleted_dirs) {
diff --git a/tensorflow/core/platform/file_system.h b/tensorflow/core/platform/file_system.h
index 8f99766e15..077b1d79cf 100644
--- a/tensorflow/core/platform/file_system.h
+++ b/tensorflow/core/platform/file_system.h
@@ -138,10 +138,8 @@ class FileSystem {
/// * OK - no errors
/// * UNIMPLEMENTED - Some underlying functions (like GetChildren) are not
/// implemented
- /// The default implementation uses a combination of GetChildren, MatchPath
- /// and IsDirectory.
virtual Status GetMatchingPaths(const string& pattern,
- std::vector<string>* results);
+ std::vector<string>* results) = 0;
/// \brief Obtains statistics for the given path.
virtual Status Stat(const string& fname, FileStatistics* stat) = 0;
diff --git a/tensorflow/core/platform/file_system_helper.cc b/tensorflow/core/platform/file_system_helper.cc
new file mode 100644
index 0000000000..22c5057281
--- /dev/null
+++ b/tensorflow/core/platform/file_system_helper.cc
@@ -0,0 +1,126 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/platform/file_system_helper.h"
+
+#include <deque>
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/file_system.h"
+#include "tensorflow/core/platform/platform.h"
+
+namespace tensorflow {
+namespace internal {
+
+namespace {
+
+constexpr int kNumThreads = 8;
+
+// Run a function in parallel using a ThreadPool, but skip the ThreadPool
+// on the iOS platform due to its problems with more than a few threads.
+void ForEach(int first, int last, const std::function<void(int)>& f) {
+#if TARGET_OS_IPHONE
+ for (int i = first; i < last; i++) {
+ f(i);
+ }
+#else
+ int num_threads = std::min(kNumThreads, last - first);
+ thread::ThreadPool threads(Env::Default(), "ForEach", num_threads);
+ for (int i = first; i < last; i++) {
+ threads.Schedule([f, i] { f(i); });
+ }
+#endif
+}
+
+} // namespace
+
+Status GetMatchingPaths(FileSystem* fs, Env* env, const string& pattern,
+ std::vector<string>* results) {
+ results->clear();
+ // Find the fixed prefix by looking for the first wildcard.
+ string fixed_prefix = pattern.substr(0, pattern.find_first_of("*?[\\"));
+ string eval_pattern = pattern;
+ std::vector<string> all_files;
+ string dir = io::Dirname(fixed_prefix).ToString();
+ // If dir is empty then we need to fix up fixed_prefix and eval_pattern to
+ // include . as the top level directory.
+ if (dir.empty()) {
+ dir = ".";
+ fixed_prefix = io::JoinPath(dir, fixed_prefix);
+ eval_pattern = io::JoinPath(dir, pattern);
+ }
+
+  // Set up a BFS to explore everything under dir.
+ std::deque<string> dir_q;
+ dir_q.push_back(dir);
+ Status ret; // Status to return.
+  // children_dir_status holds is_dir status for children. It can have three
+  // possible values: OK for true; FAILED_PRECONDITION for false; CANCELLED
+  // if we don't calculate IsDirectory (we may skip that check when there is
+  // no point in exploring that child path).
+ std::vector<Status> children_dir_status;
+ while (!dir_q.empty()) {
+ string current_dir = dir_q.front();
+ dir_q.pop_front();
+ std::vector<string> children;
+ Status s = fs->GetChildren(current_dir, &children);
+ ret.Update(s);
+ if (children.empty()) continue;
+    // This IsDirectory call can be expensive for some FS, so parallelize it.
+ children_dir_status.resize(children.size());
+ ForEach(0, children.size(),
+ [fs, &current_dir, &children, &fixed_prefix,
+ &children_dir_status](int i) {
+ const string child_path = io::JoinPath(current_dir, children[i]);
+          // If the child_path doesn't start with the fixed_prefix, then
+          // we don't need to explore this path.
+ if (!str_util::StartsWith(child_path, fixed_prefix)) {
+ children_dir_status[i] = Status(tensorflow::error::CANCELLED,
+ "Operation not needed");
+ } else {
+ children_dir_status[i] = fs->IsDirectory(child_path);
+ }
+ });
+ for (int i = 0; i < children.size(); ++i) {
+ const string child_path = io::JoinPath(current_dir, children[i]);
+      // If the IsDirectory call was cancelled, we skip this path.
+ if (children_dir_status[i].code() == tensorflow::error::CANCELLED) {
+ continue;
+ }
+ // If the child is a directory add it to the queue.
+ if (children_dir_status[i].ok()) {
+ dir_q.push_back(child_path);
+ }
+ all_files.push_back(child_path);
+ }
+ }
+
+ // Match all obtained files to the input pattern.
+ for (const auto& f : all_files) {
+ if (env->MatchPath(f, eval_pattern)) {
+ results->push_back(f);
+ }
+ }
+ return ret;
+}
+
+} // namespace internal
+} // namespace tensorflow
diff --git a/tensorflow/core/platform/file_system_helper.h b/tensorflow/core/platform/file_system_helper.h
new file mode 100644
index 0000000000..8d812b0e38
--- /dev/null
+++ b/tensorflow/core/platform/file_system_helper.h
@@ -0,0 +1,51 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_PLATFORM_FILE_SYSTEM_HELPER_H_
+#define TENSORFLOW_CORE_PLATFORM_FILE_SYSTEM_HELPER_H_
+
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+
+class FileSystem;
+class Env;
+
+namespace internal {
+
+// Given a pattern, stores in 'results' the set of paths (in the given file
+// system) that match that pattern.
+//
+// This helper may be used by implementations of FileSystem::GetMatchingPaths()
+// in order to provide parallel scanning of subdirectories (except on iOS).
+//
+// Arguments:
+// fs: may not be null and will be used to identify directories and list
+// their contents.
+// env: may not be null and will be used to check if a match has been found.
+// pattern: see FileSystem::GetMatchingPaths() for details.
+// results: will be cleared and may not be null.
+//
+// Returns an error status if any call to 'fs' failed.
+Status GetMatchingPaths(FileSystem* fs, Env* env, const string& pattern,
+ std::vector<string>* results);
+
+} // namespace internal
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_PLATFORM_FILE_SYSTEM_HELPER_H_
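
With FileSystem::GetMatchingPaths now pure virtual, each filesystem opts in explicitly, and this helper makes the old default a one-line override, as the Hadoop, POSIX, and S3 hunks below show. The pattern for any filesystem content with the generic BFS matcher (MyFileSystem is hypothetical; its other pure-virtual methods are elided):

    #include "tensorflow/core/platform/env.h"
    #include "tensorflow/core/platform/file_system.h"
    #include "tensorflow/core/platform/file_system_helper.h"

    namespace tensorflow {

    class MyFileSystem : public FileSystem {
     public:
      Status GetMatchingPaths(const string& pattern,
                              std::vector<string>* results) override {
        // Delegate to the shared BFS helper, exactly as the built-in
        // filesystems in this change do.
        return internal::GetMatchingPaths(this, Env::Default(), pattern,
                                          results);
      }
      // ... remaining FileSystem overrides elided for brevity.
    };

    }  // namespace tensorflow
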
diff --git a/tensorflow/core/platform/hadoop/hadoop_file_system.cc b/tensorflow/core/platform/hadoop/hadoop_file_system.cc
index 74863293a3..9a71fbe2b7 100644
--- a/tensorflow/core/platform/hadoop/hadoop_file_system.cc
+++ b/tensorflow/core/platform/hadoop/hadoop_file_system.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/file_system.h"
+#include "tensorflow/core/platform/file_system_helper.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/posix/error.h"
@@ -396,6 +397,11 @@ Status HadoopFileSystem::GetChildren(const string& dir,
return Status::OK();
}
+Status HadoopFileSystem::GetMatchingPaths(const string& pattern,
+ std::vector<string>* results) {
+ return internal::GetMatchingPaths(this, Env::Default(), pattern, results);
+}
+
Status HadoopFileSystem::DeleteFile(const string& fname) {
hdfsFS fs = nullptr;
TF_RETURN_IF_ERROR(Connect(fname, &fs));
diff --git a/tensorflow/core/platform/hadoop/hadoop_file_system.h b/tensorflow/core/platform/hadoop/hadoop_file_system.h
index 5f2b222622..6af7a698ff 100644
--- a/tensorflow/core/platform/hadoop/hadoop_file_system.h
+++ b/tensorflow/core/platform/hadoop/hadoop_file_system.h
@@ -49,6 +49,9 @@ class HadoopFileSystem : public FileSystem {
Status GetChildren(const string& dir, std::vector<string>* result) override;
+ Status GetMatchingPaths(const string& pattern,
+ std::vector<string>* results) override;
+
Status DeleteFile(const string& fname) override;
Status CreateDir(const string& name) override;
diff --git a/tensorflow/core/platform/hadoop/hadoop_file_system_test.cc b/tensorflow/core/platform/hadoop/hadoop_file_system_test.cc
index 6ba2f04d0f..b207d34749 100644
--- a/tensorflow/core/platform/hadoop/hadoop_file_system_test.cc
+++ b/tensorflow/core/platform/hadoop/hadoop_file_system_test.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/file_system.h"
#include "tensorflow/core/platform/test.h"
@@ -197,7 +198,7 @@ TEST_F(HadoopFileSystemTest, WriteWhileReading) {
// Skip the test if we're not testing on HDFS. Hadoop's local filesystem
// implementation makes no guarantees that writable files are readable while
// being written.
- if (!StringPiece(fname).starts_with("hdfs://")) {
+ if (!str_util::StartsWith(fname, "hdfs://")) {
return;
}
diff --git a/tensorflow/core/platform/mem.h b/tensorflow/core/platform/mem.h
index 7bb9fc264f..fca3a2332d 100644
--- a/tensorflow/core/platform/mem.h
+++ b/tensorflow/core/platform/mem.h
@@ -59,7 +59,7 @@ void MallocExtension_ReleaseToSystem(std::size_t num_bytes);
// routine, this routine returns 0.
std::size_t MallocExtension_GetAllocatedSize(const void* p);
-// Returns the amount of RAM available in kB, or INT64_MAX if unknown.
+// Returns the amount of RAM available in bytes, or INT64_MAX if unknown.
int64 AvailableRam();
} // namespace port
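
Note that AvailableRam switching from kB to bytes (with matching changes in the POSIX and Windows ports below) keeps INT64_MAX as the unknown sentinel, so any unit conversion now belongs at the call site. A defensive caller under the new contract (the free declaration below is an illustrative stand-in for tensorflow::port::AvailableRam):

    #include <cstdint>
    #include <iostream>

    int64_t AvailableRam();  // stand-in for the mem.h declaration; now bytes

    void ReportRam() {
      int64_t bytes = AvailableRam();
      if (bytes == INT64_MAX) {
        std::cout << "available RAM unknown\n";
        return;
      }
      // Convert at the call site now that the platform layer reports bytes.
      std::cout << "available RAM: " << (bytes >> 20) << " MiB\n";
    }
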
diff --git a/tensorflow/core/platform/null_file_system.h b/tensorflow/core/platform/null_file_system.h
index 008e6d54d0..420abc1ada 100644
--- a/tensorflow/core/platform/null_file_system.h
+++ b/tensorflow/core/platform/null_file_system.h
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/file_system.h"
+#include "tensorflow/core/platform/file_system_helper.h"
namespace tensorflow {
@@ -65,6 +66,11 @@ class NullFileSystem : public FileSystem {
return errors::Unimplemented("GetChildren unimplemented");
}
+ Status GetMatchingPaths(const string& pattern,
+ std::vector<string>* results) override {
+ return internal::GetMatchingPaths(this, Env::Default(), pattern, results);
+ }
+
Status DeleteFile(const string& fname) override {
return errors::Unimplemented("DeleteFile unimplemented");
}
diff --git a/tensorflow/core/platform/posix/port.cc b/tensorflow/core/platform/posix/port.cc
index 494acde803..8e316472fe 100644
--- a/tensorflow/core/platform/posix/port.cc
+++ b/tensorflow/core/platform/posix/port.cc
@@ -177,7 +177,7 @@ int64 AvailableRam() {
struct sysinfo info;
int err = sysinfo(&info);
if (err == 0) {
- return info.freeram / 1024;
+ return info.freeram;
}
#endif
return INT64_MAX;
diff --git a/tensorflow/core/platform/posix/posix_file_system.cc b/tensorflow/core/platform/posix/posix_file_system.cc
index 9a8021565c..47bfa020ce 100644
--- a/tensorflow/core/platform/posix/posix_file_system.cc
+++ b/tensorflow/core/platform/posix/posix_file_system.cc
@@ -31,6 +31,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/file_system_helper.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/posix/error.h"
#include "tensorflow/core/platform/posix/posix_file_system.h"
@@ -225,6 +226,11 @@ Status PosixFileSystem::GetChildren(const string& dir,
return Status::OK();
}
+Status PosixFileSystem::GetMatchingPaths(const string& pattern,
+ std::vector<string>* results) {
+ return internal::GetMatchingPaths(this, Env::Default(), pattern, results);
+}
+
Status PosixFileSystem::DeleteFile(const string& fname) {
Status result;
if (unlink(TranslateName(fname).c_str()) != 0) {
diff --git a/tensorflow/core/platform/posix/posix_file_system.h b/tensorflow/core/platform/posix/posix_file_system.h
index 98ffa43b8a..e8898d0a97 100644
--- a/tensorflow/core/platform/posix/posix_file_system.h
+++ b/tensorflow/core/platform/posix/posix_file_system.h
@@ -47,6 +47,9 @@ class PosixFileSystem : public FileSystem {
Status Stat(const string& fname, FileStatistics* stats) override;
+ Status GetMatchingPaths(const string& pattern,
+ std::vector<string>* results) override;
+
Status DeleteFile(const string& fname) override;
Status CreateDir(const string& name) override;
diff --git a/tensorflow/core/platform/s3/s3_file_system.cc b/tensorflow/core/platform/s3/s3_file_system.cc
index 301fcb9dbf..ee423699b2 100644
--- a/tensorflow/core/platform/s3/s3_file_system.cc
+++ b/tensorflow/core/platform/s3/s3_file_system.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/platform/s3/s3_file_system.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/file_system_helper.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/s3/aws_logging.h"
#include "tensorflow/core/platform/s3/s3_crypto.h"
@@ -497,6 +498,11 @@ Status S3FileSystem::Stat(const string& fname, FileStatistics* stats) {
return Status::OK();
}
+Status S3FileSystem::GetMatchingPaths(const string& pattern,
+ std::vector<string>* results) {
+ return internal::GetMatchingPaths(this, Env::Default(), pattern, results);
+}
+
Status S3FileSystem::DeleteFile(const string& fname) {
string bucket, object;
TF_RETURN_IF_ERROR(ParseS3Path(fname, false, &bucket, &object));
diff --git a/tensorflow/core/platform/s3/s3_file_system.h b/tensorflow/core/platform/s3/s3_file_system.h
index 31264be621..5d0565b378 100644
--- a/tensorflow/core/platform/s3/s3_file_system.h
+++ b/tensorflow/core/platform/s3/s3_file_system.h
@@ -46,6 +46,9 @@ class S3FileSystem : public FileSystem {
Status Stat(const string& fname, FileStatistics* stat) override;
+ Status GetMatchingPaths(const string& pattern,
+ std::vector<string>* results) override;
+
Status DeleteFile(const string& fname) override;
Status CreateDir(const string& name) override;
diff --git a/tensorflow/core/platform/tracing.h b/tensorflow/core/platform/tracing.h
index 8f7bff1bb0..3c6e7b0db5 100644
--- a/tensorflow/core/platform/tracing.h
+++ b/tensorflow/core/platform/tracing.h
@@ -103,7 +103,9 @@ class Tracing {
friend class ScopedAnnotation;
friend class TraceMe;
- static std::atomic<Tracing::Engine*> tracing_engine_;
+ // TODO: TF_EXPORT is for building //tensorflow/contrib/data:_dataset_ops.so
+ // on Windows. Figure out a way to remove TF_EXPORT here.
+ TF_EXPORT static std::atomic<Tracing::Engine*> tracing_engine_;
static Tracing::Engine* engine() {
return tracing_engine_.load(std::memory_order_acquire);
}
diff --git a/tensorflow/core/platform/windows/port.cc b/tensorflow/core/platform/windows/port.cc
index f3b27ea394..174f41a993 100644
--- a/tensorflow/core/platform/windows/port.cc
+++ b/tensorflow/core/platform/windows/port.cc
@@ -166,7 +166,7 @@ int64 AvailableRam() {
MEMORYSTATUSEX statex;
statex.dwLength = sizeof(statex);
if (GlobalMemoryStatusEx(&statex)) {
- return statex.ullAvailPhys / 1024;
+ return statex.ullAvailPhys;
}
return INT64_MAX;
}
diff --git a/tensorflow/core/platform/windows/windows_file_system.cc b/tensorflow/core/platform/windows/windows_file_system.cc
index 682e46e0fc..dc2efbeaf5 100644
--- a/tensorflow/core/platform/windows/windows_file_system.cc
+++ b/tensorflow/core/platform/windows/windows_file_system.cc
@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/error_codes.pb.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/file_system_helper.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/posix/error.h"
#include "tensorflow/core/platform/windows/error.h"
@@ -494,7 +495,8 @@ Status WindowsFileSystem::GetMatchingPaths(const string& pattern,
// but no code appears to rely on this behavior.
string converted_pattern(pattern);
std::replace(converted_pattern.begin(), converted_pattern.end(), '\\', '/');
- TF_RETURN_IF_ERROR(FileSystem::GetMatchingPaths(converted_pattern, results));
+ TF_RETURN_IF_ERROR(internal::GetMatchingPaths(this, Env::Default(),
+ converted_pattern, results));
for (string& result : *results) {
std::replace(result.begin(), result.end(), '/', '\\');
}
diff --git a/tensorflow/core/profiler/internal/advisor/tfprof_advisor_test.cc b/tensorflow/core/profiler/internal/advisor/tfprof_advisor_test.cc
index e968b9c97e..96b6cc30bd 100644
--- a/tensorflow/core/profiler/internal/advisor/tfprof_advisor_test.cc
+++ b/tensorflow/core/profiler/internal/advisor/tfprof_advisor_test.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/core/profiler/internal/advisor/tfprof_advisor.h"
#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"
@@ -82,8 +83,8 @@ TEST_F(TFProfAdvisorTest, OperationChecker) {
(*options.mutable_checkers())[kCheckers[1]];
AdviceProto advice = advisor_->Advise(options);
EXPECT_EQ(advice.checkers().at(kCheckers[1]).reports_size(), 1);
- EXPECT_TRUE(StringPiece(advice.checkers().at(kCheckers[1]).reports(0))
- .contains("NCHW"));
+ EXPECT_TRUE(str_util::StrContains(
+ advice.checkers().at(kCheckers[1]).reports(0), "NCHW"));
}
TEST_F(TFProfAdvisorTest, UtilizationChecker) {
@@ -91,16 +92,17 @@ TEST_F(TFProfAdvisorTest, UtilizationChecker) {
(*options.mutable_checkers())[kCheckers[0]];
AdviceProto advice = advisor_->Advise(options);
EXPECT_EQ(advice.checkers().at(kCheckers[0]).reports_size(), 1);
- EXPECT_TRUE(StringPiece(advice.checkers().at(kCheckers[0]).reports(0))
- .contains("low utilization"));
+ EXPECT_TRUE(str_util::StrContains(
+ advice.checkers().at(kCheckers[0]).reports(0), "low utilization"));
}
TEST_F(TFProfAdvisorTest, ExpensiveOperationChecker) {
AdvisorOptionsProto options;
(*options.mutable_checkers())[kCheckers[2]];
AdviceProto advice = advisor_->Advise(options);
- EXPECT_TRUE(StringPiece(advice.checkers().at(kCheckers[2]).reports(0))
- .contains("top 1 operation type: Conv2D"));
+ EXPECT_TRUE(
+ str_util::StrContains(advice.checkers().at(kCheckers[2]).reports(0),
+ "top 1 operation type: Conv2D"));
}
} // namespace tfprof
diff --git a/tensorflow/core/protobuf/rewriter_config.proto b/tensorflow/core/protobuf/rewriter_config.proto
index bb772460b0..9b6202e7b4 100644
--- a/tensorflow/core/protobuf/rewriter_config.proto
+++ b/tensorflow/core/protobuf/rewriter_config.proto
@@ -29,6 +29,14 @@ message RewriterConfig {
AGGRESSIVE = 3;
}
+  // Enum controlling the number of times to run optimizers. The default is to
+ // run them once.
+ enum NumIterationsType {
+ DEFAULT_NUM_ITERS = 0;
+ ONE = 1;
+ TWO = 2;
+ }
+
// Optimize tensor layouts (default is ON)
// e.g. This will try to use NCHW layout on GPU which is faster.
Toggle layout_optimizer = 1;
@@ -51,6 +59,10 @@ message RewriterConfig {
// If true, don't remove unnecessary ops from the graph
bool disable_model_pruning = 2;
+ // Controls how many times we run the optimizers in meta optimizer (default
+ // is once).
+ NumIterationsType meta_optimizer_iterations = 12;
+
enum MemOptType {
// The default setting (SCHEDULING and SWAPPING HEURISTICS only)
DEFAULT_MEM_OPT = 0;
diff --git a/tensorflow/core/util/command_line_flags.cc b/tensorflow/core/util/command_line_flags.cc
index 3efc703faf..480ce94fca 100644
--- a/tensorflow/core/util/command_line_flags.cc
+++ b/tensorflow/core/util/command_line_flags.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/command_line_flags.h"
@@ -28,7 +29,9 @@ bool ParseStringFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
const std::function<bool(string)>& hook,
bool* value_parsing_ok) {
*value_parsing_ok = true;
- if (arg.Consume("--") && arg.Consume(flag) && arg.Consume("=")) {
+ if (str_util::ConsumePrefix(&arg, "--") &&
+ str_util::ConsumePrefix(&arg, flag) &&
+ str_util::ConsumePrefix(&arg, "=")) {
*value_parsing_ok = hook(arg.ToString());
return true;
}
@@ -40,7 +43,9 @@ bool ParseInt32Flag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
const std::function<bool(int32)>& hook,
bool* value_parsing_ok) {
*value_parsing_ok = true;
- if (arg.Consume("--") && arg.Consume(flag) && arg.Consume("=")) {
+ if (str_util::ConsumePrefix(&arg, "--") &&
+ str_util::ConsumePrefix(&arg, flag) &&
+ str_util::ConsumePrefix(&arg, "=")) {
char extra;
int32 parsed_int32;
if (sscanf(arg.data(), "%d%c", &parsed_int32, &extra) != 1) {
@@ -60,7 +65,9 @@ bool ParseInt64Flag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
const std::function<bool(int64)>& hook,
bool* value_parsing_ok) {
*value_parsing_ok = true;
- if (arg.Consume("--") && arg.Consume(flag) && arg.Consume("=")) {
+ if (str_util::ConsumePrefix(&arg, "--") &&
+ str_util::ConsumePrefix(&arg, flag) &&
+ str_util::ConsumePrefix(&arg, "=")) {
char extra;
int64 parsed_int64;
if (sscanf(arg.data(), "%lld%c", &parsed_int64, &extra) != 1) {
@@ -80,7 +87,8 @@ bool ParseBoolFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
const std::function<bool(bool)>& hook,
bool* value_parsing_ok) {
*value_parsing_ok = true;
- if (arg.Consume("--") && arg.Consume(flag)) {
+ if (str_util::ConsumePrefix(&arg, "--") &&
+ str_util::ConsumePrefix(&arg, flag)) {
if (arg.empty()) {
*value_parsing_ok = hook(true);
return true;
@@ -107,7 +115,9 @@ bool ParseFloatFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
const std::function<bool(float)>& hook,
bool* value_parsing_ok) {
*value_parsing_ok = true;
- if (arg.Consume("--") && arg.Consume(flag) && arg.Consume("=")) {
+ if (str_util::ConsumePrefix(&arg, "--") &&
+ str_util::ConsumePrefix(&arg, flag) &&
+ str_util::ConsumePrefix(&arg, "=")) {
char extra;
float parsed_float;
if (sscanf(arg.data(), "%f%c", &parsed_float, &extra) != 1) {
diff --git a/tensorflow/core/util/device_name_utils_test.cc b/tensorflow/core/util/device_name_utils_test.cc
index c1bc0f3378..ff9c108f10 100644
--- a/tensorflow/core/util/device_name_utils_test.cc
+++ b/tensorflow/core/util/device_name_utils_test.cc
@@ -408,7 +408,7 @@ static void MergeDevNamesError(const string& name_a, const string& name_b,
DeviceNameUtils::ParsedName target_a = Name(name_a);
Status s = DeviceNameUtils::MergeDevNames(&target_a, Name(name_b));
EXPECT_EQ(s.code(), error::INVALID_ARGUMENT);
- EXPECT_TRUE(StringPiece(s.error_message()).contains(expected_error_substr))
+ EXPECT_TRUE(str_util::StrContains(s.error_message(), expected_error_substr))
<< s;
}
diff --git a/tensorflow/core/util/equal_graph_def.cc b/tensorflow/core/util/equal_graph_def.cc
index f1ec497a67..b87dce0dff 100644
--- a/tensorflow/core/util/equal_graph_def.cc
+++ b/tensorflow/core/util/equal_graph_def.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/lib/hash/hash.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/protobuf.h"
@@ -144,7 +145,7 @@ bool EqualNodeDef(const NodeDef& actual, const NodeDef& expected, string* diff,
int first_control_input = actual.input_size();
for (int i = 0; i < actual.input_size(); ++i) {
- if (StringPiece(actual.input(i)).starts_with("^")) {
+ if (str_util::StartsWith(actual.input(i), "^")) {
first_control_input = i;
break;
}
@@ -240,7 +241,7 @@ uint64 NodeDefHash(const NodeDef& ndef, const EqualGraphDefOptions& options) {
// Normal inputs. Order important.
int first_control_input = ndef.input_size();
for (int i = 0; i < ndef.input_size(); ++i) {
- if (StringPiece(ndef.input(i)).starts_with("^")) {
+ if (str_util::StartsWith(ndef.input(i), "^")) {
first_control_input = i;
break;
}
diff --git a/tensorflow/core/util/memmapped_file_system.cc b/tensorflow/core/util/memmapped_file_system.cc
index a0f43d2d4a..1fa6b8bec0 100644
--- a/tensorflow/core/util/memmapped_file_system.cc
+++ b/tensorflow/core/util/memmapped_file_system.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/util/memmapped_file_system.h"
#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/util/memmapped_file_system.pb.h"
@@ -157,6 +158,12 @@ Status MemmappedFileSystem::GetChildren(const string& filename,
return errors::Unimplemented("memmapped format doesn't support GetChildren");
}
+Status MemmappedFileSystem::GetMatchingPaths(const string& pattern,
+ std::vector<string>* results) {
+ return errors::Unimplemented(
+ "memmapped format doesn't support GetMatchingPaths");
+}
+
Status MemmappedFileSystem::DeleteFile(const string& filename) {
return errors::Unimplemented("memmapped format doesn't support DeleteFile");
}
@@ -236,7 +243,7 @@ Status MemmappedFileSystem::InitializeFromFile(Env* env,
}
bool MemmappedFileSystem::IsMemmappedPackageFilename(const string& filename) {
- return StringPiece(filename).starts_with(kMemmappedPackagePrefix);
+ return str_util::StartsWith(filename, kMemmappedPackagePrefix);
}
namespace {
diff --git a/tensorflow/core/util/memmapped_file_system.h b/tensorflow/core/util/memmapped_file_system.h
index 541587aeab..76cc4911f5 100644
--- a/tensorflow/core/util/memmapped_file_system.h
+++ b/tensorflow/core/util/memmapped_file_system.h
@@ -85,6 +85,8 @@ class MemmappedFileSystem : public FileSystem {
Status NewAppendableFile(const string& fname,
std::unique_ptr<WritableFile>* result) override;
Status GetChildren(const string& dir, std::vector<string>* r) override;
+ Status GetMatchingPaths(const string& pattern,
+ std::vector<string>* results) override;
Status DeleteFile(const string& f) override;
Status CreateDir(const string& d) override;
Status DeleteDir(const string& d) override;
diff --git a/tensorflow/core/util/reporter_test.cc b/tensorflow/core/util/reporter_test.cc
index 575c27d4ef..90ea09876e 100644
--- a/tensorflow/core/util/reporter_test.cc
+++ b/tensorflow/core/util/reporter_test.cc
@@ -29,7 +29,7 @@ namespace {
// Tests of all the error paths in log_reader.cc follow:
static void ExpectHasSubstr(StringPiece s, StringPiece expected) {
- EXPECT_TRUE(StringPiece(s).contains(expected))
+ EXPECT_TRUE(str_util::StrContains(s, expected))
<< s << " does not contain " << expected;
}
diff --git a/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc b/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc
index 08f1aa7125..7f166f0ec0 100644
--- a/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc
+++ b/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/io/table_builder.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
@@ -293,7 +294,7 @@ void VersionTest(const VersionDef& version, StringPiece expected_error) {
BundleReader reader(Env::Default(), path);
EXPECT_TRUE(errors::IsInvalidArgument(reader.status()));
EXPECT_TRUE(
- StringPiece(reader.status().error_message()).starts_with(expected_error));
+ str_util::StartsWith(reader.status().error_message(), expected_error));
}
} // namespace
@@ -588,7 +589,7 @@ TEST(TensorBundleTest, Error) {
TF_EXPECT_OK(writer.Add("foo", Constant_2x3(1.f)));
EXPECT_FALSE(writer.Add("foo", Constant_2x3(2.f)).ok());
EXPECT_TRUE(
- StringPiece(writer.status().ToString()).contains("duplicate key"));
+ str_util::StrContains(writer.status().ToString(), "duplicate key"));
EXPECT_FALSE(writer.Finish().ok());
}
{ // Double finish
@@ -598,7 +599,7 @@ TEST(TensorBundleTest, Error) {
}
{ // Not found.
BundleReader reader(Env::Default(), Prefix("nonexist"));
- EXPECT_TRUE(StringPiece(reader.status().ToString()).contains("Not found"));
+ EXPECT_TRUE(str_util::StrContains(reader.status().ToString(), "Not found"));
}
}
@@ -629,7 +630,7 @@ TEST(TensorBundleTest, Checksum) {
BundleReader reader(Env::Default(), Prefix(prefix));
Status status = reader.Lookup(key, &val);
EXPECT_TRUE(errors::IsDataLoss(status));
- EXPECT_TRUE(StringPiece(status.ToString()).contains(expected_msg));
+ EXPECT_TRUE(str_util::StrContains(status.ToString(), expected_msg));
};
// Corrupts a float tensor.
@@ -680,8 +681,8 @@ TEST(TensorBundleTest, Endianness) {
BundleReader reader(Env::Default(), Prefix("end"));
EXPECT_TRUE(errors::IsUnimplemented(reader.status()));
- EXPECT_TRUE(StringPiece(reader.status().ToString())
- .contains("different endianness from the reader"));
+ EXPECT_TRUE(str_util::StrContains(reader.status().ToString(),
+ "different endianness from the reader"));
}
TEST(TensorBundleTest, TruncatedTensorContents) {
diff --git a/tensorflow/core/util/tensor_slice_reader_test.cc b/tensorflow/core/util/tensor_slice_reader_test.cc
index 010cc36823..3c9590e488 100644
--- a/tensorflow/core/util/tensor_slice_reader_test.cc
+++ b/tensorflow/core/util/tensor_slice_reader_test.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
@@ -422,7 +423,7 @@ static void VersionTest(const VersionDef& versions, const string& error) {
// Read it back in and verify that we get the expected error
TensorSliceReader reader(path, OpenTableTensorSliceReader);
EXPECT_TRUE(reader.status().code() == error::INVALID_ARGUMENT &&
- StringPiece(reader.status().error_message()).starts_with(error))
+ str_util::StartsWith(reader.status().error_message(), error))
<< "Expected error starting with '" << errors::InvalidArgument(error)
<< "', got '" << reader.status() << "'";
}
diff --git a/tensorflow/core/util/tensor_slice_writer_test.cc b/tensorflow/core/util/tensor_slice_writer_test.cc
index ff5bfd65ae..31397f11b6 100644
--- a/tensorflow/core/util/tensor_slice_writer_test.cc
+++ b/tensorflow/core/util/tensor_slice_writer_test.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
@@ -333,8 +334,8 @@ TEST(TensorSliceWriteTest, SizeErrors) {
const std::vector<int8> data(300000000, -1);
Status s = writer.Add("test1", shape, slice, data.data());
EXPECT_EQ(s.code(), error::INVALID_ARGUMENT);
- EXPECT_TRUE(StringPiece(s.error_message())
- .contains("Tensor slice is too large to serialize"));
+ EXPECT_TRUE(str_util::StrContains(
+ s.error_message(), "Tensor slice is too large to serialize"));
}
// Add a large string tensor slice, which will fail.
@@ -344,8 +345,8 @@ TEST(TensorSliceWriteTest, SizeErrors) {
const std::vector<string> data(256 * 1024, std::string(8192, 'f'));
Status s = writer.Add("test2", shape, slice, data.data());
EXPECT_EQ(s.code(), error::INVALID_ARGUMENT);
- EXPECT_TRUE(StringPiece(s.error_message())
- .contains("Tensor slice is too large to serialize"));
+ EXPECT_TRUE(str_util::StrContains(
+ s.error_message(), "Tensor slice is too large to serialize"));
}
}
diff --git a/tensorflow/docs_src/extend/index.md b/tensorflow/docs_src/extend/index.md
index bdff60b39e..1ab0340ad9 100644
--- a/tensorflow/docs_src/extend/index.md
+++ b/tensorflow/docs_src/extend/index.md
@@ -16,9 +16,10 @@ TensorFlow:
for your own file and record formats.
Python is currently the only language supported by TensorFlow's API stability
-promises. However, TensorFlow also provides functionality in C++, Java, and Go,
+promises. However, TensorFlow also provides functionality in C++, Go, Java, and
+[JavaScript](https://js.tensorflow.org),
plus community support for [Haskell](https://github.com/tensorflow/haskell) and
-[Rust](https://github.com/tensorflow/rust). If you'd like to create or
+[Rust](https://github.com/tensorflow/rust). If you'd like to create or
develop TensorFlow features in another language, read the
following guide:
diff --git a/tensorflow/docs_src/mobile/tflite/devguide.md b/tensorflow/docs_src/mobile/tflite/devguide.md
index 5b521dca7b..96392a3c9b 100644
--- a/tensorflow/docs_src/mobile/tflite/devguide.md
+++ b/tensorflow/docs_src/mobile/tflite/devguide.md
@@ -88,7 +88,7 @@ Tensorflow Lite format. This process uses several model formats:
extracted from a `SavedModel`.
* *TensorFlow Lite model* (.tflite) —A serialized
[FlatBuffer](https://google.github.io/flatbuffers/) that contains TensorFlow
- Lite operators and tensors for the TensorFlow Lite interpreter, similiar to a
+ Lite operators and tensors for the TensorFlow Lite interpreter, similar to a
`FrozenGraphDef`.
### Freeze Graph
diff --git a/tensorflow/docs_src/programmers_guide/eager.md b/tensorflow/docs_src/programmers_guide/eager.md
index 8db65737dc..dc5b403428 100644
--- a/tensorflow/docs_src/programmers_guide/eager.md
+++ b/tensorflow/docs_src/programmers_guide/eager.md
@@ -1,35 +1,34 @@
# Eager Execution
TensorFlow's eager execution is an imperative programming environment that
-evaluates operations immediately, without an extra graph-building step.
-Operations return concrete values instead of constructing a computational graph
-to run later. This makes it easy to get started with TensorFlow, debug models,
-reduce boilerplate code, and is fun! To follow along with this guide, run the
-code samples below in an interactive `python` interpreter.
-
-Eager execution supports most TensorFlow operations and GPU acceleration.
-Automatic differentiation uses a dynamically-constructed tape instead of a static
-graph to compute gradients. Eager execution is a flexible machine learning
-platform for research and experimentation that provides:
-
-* *An intuitive interface* —Structure your code naturally and use Python data
+evaluates operations immediately, without building graphs: operations return
+concrete values instead of constructing a computational graph to run later. This
+makes it easy to get started with TensorFlow and debug models, and it
+reduces boilerplate as well. To follow along with this guide, run the code
+samples below in an interactive `python` interpreter.
+
+Eager execution is a flexible machine learning platform for research and
+experimentation, providing:
+
+* *An intuitive interface*—Structure your code naturally and use Python data
structures. Quickly iterate on small models and small data.
-* *Easier debugging* —Call ops directly to inspect running models and test
+* *Easier debugging*—Call ops directly to inspect running models and test
changes. Use standard Python debugging tools for immediate error reporting.
-* *Natural control flow* —Use Python control flow instead of graph control flow,
- including support for dynamic models.
+* *Natural control flow*—Use Python control flow instead of graph control
+ flow, simplifying the specification of dynamic models.
-For a collection of examples running in eager execution, see:
+Eager execution supports most TensorFlow operations and GPU acceleration. For a
+collection of examples running in eager execution, see:
[tensorflow/contrib/eager/python/examples](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/examples).
-Note: Some models may experience increased overhead with eager execution enabled.
-Performance improvements are ongoing, but please
+Note: Some models may experience increased overhead with eager execution
+enabled. Performance improvements are ongoing, but please
[file a bug](https://github.com/tensorflow/tensorflow/issues) if you find a
problem and share your benchmarks.
## Setup and basic usage
-Upgrade to TensorFlow 1.7 to include updates for eager execution:
+Upgrade to the latest version of TensorFlow:
```
$ pip install --upgrade tensorflow
@@ -110,9 +109,106 @@ environments and is useful for writing code to [work with graphs](#work_with_gra
import tensorflow.contrib.eager as tfe
```
+## Dynamic control flow
+
+A major benefit of eager execution is that all the functionality of the host
+language is available while your model is executing. So, for example,
+it is easy to write [fizzbuzz](https://en.wikipedia.org/wiki/Fizz_buzz):
+
+```py
+def fizzbuzz(max_num):
+ counter = tf.constant(0)
+ for num in range(max_num):
+ num = tf.constant(num)
+ if num % 3 == 0 and num % 5 == 0:
+ print('FizzBuzz')
+ elif num % 3 == 0:
+ print('Fizz')
+ elif num % 5 == 0:
+ print('Buzz')
+ else:
+ print(num)
+ counter += 1
+ return counter
+```
+
+This example has conditionals that depend on tensor values, and it prints
+these values at runtime.
+
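+As a quick check, a hypothetical invocation (assuming eager execution has been
+enabled as in the setup section above) might look like this:
+
+```py
+counter = fizzbuzz(16)  # prints 'Fizz', 'Buzz', 'FizzBuzz', or the number
+print(counter)          # => tf.Tensor(16, ...)
+```
+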
+## Build a model
+
+Many machine learning models are represented by composing layers. When
+using TensorFlow with eager execution you can either write your own layers or
+use a layer provided in the `tf.keras.layers` package.
+
+While you can use any Python object to represent a layer,
+TensorFlow has `tf.keras.layers.Layer` as a convenient base class. Inherit from
+it to implement your own layer:
+
+```py
+class MySimpleLayer(tf.keras.layers.Layer):
+  def __init__(self, output_units):
+    super(MySimpleLayer, self).__init__()
+    self.output_units = output_units
+
+  def build(self, input_shape):
+ # The build method gets called the first time your layer is used.
+ # Creating variables on build() allows you to make their shape depend
+ # on the input shape and hence remove the need for the user to specify
+ # full shapes. It is possible to create variables during __init__() if
+ # you already know their full shapes.
+ self.kernel = self.add_variable(
+ "kernel", [input.shape[-1], self.output_units])
+
+ def call(self, input):
+ # Override call() instead of __call__ so we can perform some bookkeeping.
+ return tf.matmul(input, self.kernel)
+```
+
+Use the `tf.keras.layers.Dense` layer instead of `MySimpleLayer` above, as it
+provides a superset of its functionality (it can also add a bias).
+
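+For instance, a minimal sketch of the drop-in replacement (`use_bias=False` is
+only needed to match the bias-free `MySimpleLayer`):
+
+```py
+layer = tf.keras.layers.Dense(units=10, use_bias=False)
+```
+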
+When composing layers into models you can use `tf.keras.Sequential` to represent
+models which are a linear stack of layers. It is easy to use for basic models:
+
+```py
+model = tf.keras.Sequential([
+ tf.keras.layers.Dense(10, input_shape=(784,)), # must declare input shape
+ tf.keras.layers.Dense(10)
+])
+```
+
+Alternatively, organize models in classes by inheriting from `tf.keras.Model`.
+This is a container for layers that is a layer itself, allowing `tf.keras.Model`
+objects to contain other `tf.keras.Model` objects.
+
+```py
+class MNISTModel(tf.keras.Model):
+ def __init__(self):
+ super(MNISTModel, self).__init__()
+ self.dense1 = tf.keras.layers.Dense(units=10)
+ self.dense2 = tf.keras.layers.Dense(units=10)
+
+ def call(self, input):
+ """Run the model."""
+ result = self.dense1(input)
+ result = self.dense2(result)
+ result = self.dense2(result) # reuse variables from dense2 layer
+ return result
+
+model = MNISTModel()
+```
+
+It's not required to set an input shape for the `tf.keras.Model` class since
+the parameters are set the first time input is passed to the layer.
+
+`tf.keras.layers` classes create and contain their own model variables that
+are tied to the lifetime of their layer objects. To share layer variables, share
+their objects.
+
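+For example, calling the same layer object twice reuses its variables rather
+than creating new ones (a minimal sketch; the shapes are illustrative):
+
+```py
+shared = tf.keras.layers.Dense(units=10)
+y1 = shared(tf.zeros([1, 4]))  # first call creates the kernel and bias
+y2 = shared(tf.ones([1, 4]))   # second call reuses the same variables
+```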
+
## Eager training
-### Automatic differentiation
+### Computing gradients
[Automatic differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation)
is useful for implementing machine learning algorithms such as
@@ -124,7 +220,7 @@ operations for computing gradients later.
not tracing. Since different operations can occur during each call, all
forward-pass operations get recorded to a "tape". To compute the gradient, play
the tape backwards and then discard. A particular `tfe.GradientTape` can only
-be computed once, subsequent calls throw a runtime error.
+compute one gradient; subsequent calls throw a runtime error.
```py
w = tfe.Variable([[1.0]])
@@ -216,189 +312,12 @@ for (batch, (images, labels)) in enumerate(tfe.Iterator(dataset)):
global_step=tf.train.get_or_create_global_step())
```
-#### Dynamic models
-
-`tfe.GradientTape` can also be used in dynamic models. This example for a
-[backtracking line search](https://wikipedia.org/wiki/Backtracking_line_search)
-algorithm looks like normal NumPy code, except there are gradients and is
-differentiable, despite the complex control flow:
-
-```py
-def line_search_step(fn, init_x, rate=1.0):
- with tfe.GradientTape() as tape:
- # Variables are automatically recorded, but manually watch a tensor
- tape.watch(init_x)
- value = fn(init_x)
- grad, = tape.gradient(value, [init_x])
- grad_norm = tf.reduce_sum(grad * grad)
- init_value = value
- while value > init_value - rate * grad_norm:
- x = init_x - rate * grad
- value = fn(x)
- rate /= 2.0
- return x, value
-```
-
-#### Additional functions to compute gradients
-
-`tfe.GradientTape` is a powerful interface for computing gradients, but there
-is another [Autograd](https://github.com/HIPS/autograd)-style API available for
-automatic differentiation. These functions are useful if writing math code with
-only tensors and gradient functions, and without `tfe.Variables`:
-
-* `tfe.gradients_function` —Returns a function that computes the derivatives
- of its input function parameter with respect to its arguments. The input
- function parameter must return a scalar value. When the returned function is
- invoked, it returns a list of `tf.Tensor` objects: one element for each
- argument of the input function. Since anything of interest must be passed as a
- function parameter, this becomes unwieldy if there's a dependency on many
- trainable parameters.
-* `tfe.value_and_gradients_function` —Similar to
- `tfe.gradients_function`, but when the returned function is invoked, it
- returns the value from the input function in addition to the list of
- derivatives of the input function with respect to its arguments.
-
-In the following example, `tfe.gradients_function` takes the `square`
-function as an argument and returns a function that computes the partial
-derivatives of `square` with respect to its inputs. To calculate the derivative
-of `square` at `3`, `grad(3.0)` returns `6`.
-
-```py
-def square(x):
- return tf.multiply(x, x)
-
-grad = tfe.gradients_function(square)
-
-square(3.) # => 9.0
-grad(3.) # => [6.0]
-
-# The second-order derivative of square:
-gradgrad = tfe.gradients_function(lambda x: grad(x)[0])
-gradgrad(3.) # => [2.0]
-
-# The third-order derivative is None:
-gradgradgrad = tfe.gradients_function(lambda x: gradgrad(x)[0])
-gradgradgrad(3.) # => [None]
-
-
-# With flow control:
-def abs(x):
- return x if x > 0. else -x
-
-grad = tfe.gradients_function(abs)
-
-grad(3.) # => [1.0]
-grad(-3.) # => [-1.0]
-```
-
-### Custom gradients
-
-Custom gradients are an easy way to override gradients in eager and graph
-execution. Within the forward function, define the gradient with respect to the
-inputs, outputs, or intermediate results. For example, here's an easy way to clip
-the norm of the gradients in the backward pass:
-
-```py
-@tf.custom_gradient
-def clip_gradient_by_norm(x, norm):
- y = tf.identity(x)
- def grad_fn(dresult):
- return [tf.clip_by_norm(dresult, norm), None]
- return y, grad_fn
-```
-
-Custom gradients are commonly used to provide a numerically stable gradient for a
-sequence of operations:
-
-```py
-def log1pexp(x):
- return tf.log(1 + tf.exp(x))
-grad_log1pexp = tfe.gradients_function(log1pexp)
-
-# The gradient computation works fine at x = 0.
-grad_log1pexp(0.) # => [0.5]
-
-# However, x = 100 fails because of numerical instability.
-grad_log1pexp(100.) # => [nan]
-```
-
-Here, the `log1pexp` function can be analytically simplified with a custom
-gradient. The implementation below reuses the value for `tf.exp(x)` that is
-computed during the forward pass—making it more efficient by eliminating
-redundant calculations:
-
-```py
-@tf.custom_gradient
-def log1pexp(x):
- e = tf.exp(x)
- def grad(dy):
- return dy * (1 - 1 / (1 + e))
- return tf.log(1 + e), grad
-
-grad_log1pexp = tfe.gradients_function(log1pexp)
-
-# As before, the gradient computation works fine at x = 0.
-grad_log1pexp(0.) # => [0.5]
-
-# And the gradient computation also works at x = 100.
-grad_log1pexp(100.) # => [1.0]
-```
-
-
-## Build and train models
-
-There are many parameters to optimize when calculating derivatives. TensorFlow
-code is easier to read when structured into reusable classes and objects instead
-of a single top-level function. Eager execution encourages the use of the
-Keras-style layer classes in the `tf.keras.layers` module. Additionally, the
-`tf.train.Optimizer` classes provide sophisticated techniques to calculate
-parameter updates.
The following example creates a multi-layer model that classifies the standard
[MNIST handwritten digits](https://www.tensorflow.org/tutorials/layers). It
demonstrates the optimizer and layer APIs to build trainable graphs in an eager
execution environment.
-### Build a model
-
-The `tf.keras.Sequential` model is a linear stack of layers. It is easy to
-use for basic models:
-
-```py
-model = tf.keras.Sequential([
- tf.keras.layers.Dense(10, input_shape=(784,)), # must declare input shape
- tf.keras.layers.Dense(10)
-])
-```
-
-Alternatively, organize models in classes by inheriting from `tf.keras.Model`.
-This is a container for layers that is a layer itself, allowing `tf.keras.Model`
-objects to contain other `tf.keras.Model` objects.
-
-```py
-class MNISTModel(tf.keras.Model):
- def __init__(self):
- super(MNISTModel, self).__init__()
- self.dense1 = tf.keras.layers.Dense(units=10)
- self.dense2 = tf.keras.layers.Dense(units=10)
-
- def call(self, input):
- """Run the model."""
- result = self.dense1(input)
- result = self.dense2(result)
- result = self.dense2(result) # reuse variables from dense2 layer
- return result
-
-model = MNISTModel()
-```
-
-It's not required to set an input shape for the `tf.keras.Model` class since
-the parameters are set the first time input is passed to the layer.
-
-`tf.keras.layers` classes create and contain their own model variables that
-are tied to the lifetime of their layer objects. To share layer variables, share
-their objects.
-
### Train a model
Even without training, call the model and inspect the output in eager execution:
@@ -415,7 +334,7 @@ result = model(batch)
This example uses the
[dataset.py module](https://github.com/tensorflow/models/blob/master/official/mnist/dataset.py)
from the
-[TensorFlow MNIST example](https://github.com/tensorflow/models/tree/master/official/mnist),
+[TensorFlow MNIST example](https://github.com/tensorflow/models/tree/master/official/mnist);
download this file to your local directory. Run the following to download the
MNIST data files to your working directory and prepare a `tf.data.Dataset`
for training:
@@ -662,11 +581,141 @@ for _ in range(iterations):
...
```
+## Advanced automatic differentiation topics
+
+### Dynamic models
+
+`tfe.GradientTape` can also be used in dynamic models. This example for a
+[backtracking line search](https://wikipedia.org/wiki/Backtracking_line_search)
+algorithm looks like normal NumPy code, except that it computes gradients and
+is differentiable, despite the complex control flow:
+
+```py
+def line_search_step(fn, init_x, rate=1.0):
+ with tfe.GradientTape() as tape:
+ # Variables are automatically recorded, but manually watch a tensor
+ tape.watch(init_x)
+ value = fn(init_x)
+ grad, = tape.gradient(value, [init_x])
+ grad_norm = tf.reduce_sum(grad * grad)
+ init_value = value
+ while value > init_value - rate * grad_norm:
+ x = init_x - rate * grad
+ value = fn(x)
+ rate /= 2.0
+ return x, value
+```
+
+### Additional functions to compute gradients
+
+`tfe.GradientTape` is a powerful interface for computing gradients, but there
+is another [Autograd](https://github.com/HIPS/autograd)-style API available for
+automatic differentiation. These functions are useful if writing math code with
+only tensors and gradient functions, and without `tfe.Variables`:
+
+* `tfe.gradients_function` —Returns a function that computes the derivatives
+ of its input function parameter with respect to its arguments. The input
+ function parameter must return a scalar value. When the returned function is
+ invoked, it returns a list of `tf.Tensor` objects: one element for each
+ argument of the input function. Since anything of interest must be passed as a
+ function parameter, this becomes unwieldy if there's a dependency on many
+ trainable parameters.
+* `tfe.value_and_gradients_function` —Similar to
+ `tfe.gradients_function`, but when the returned function is invoked, it
+ returns the value from the input function in addition to the list of
+ derivatives of the input function with respect to its arguments.
+
+In the following example, `tfe.gradients_function` takes the `square`
+function as an argument and returns a function that computes the partial
+derivatives of `square` with respect to its inputs. To calculate the derivative
+of `square` at `3`, `grad(3.0)` returns `6`.
+
+```py
+def square(x):
+ return tf.multiply(x, x)
+
+grad = tfe.gradients_function(square)
+
+square(3.) # => 9.0
+grad(3.) # => [6.0]
+
+# The second-order derivative of square:
+gradgrad = tfe.gradients_function(lambda x: grad(x)[0])
+gradgrad(3.) # => [2.0]
+
+# The third-order derivative is None:
+gradgradgrad = tfe.gradients_function(lambda x: gradgrad(x)[0])
+gradgradgrad(3.) # => [None]
+
+
+# With flow control:
+def abs(x):
+ return x if x > 0. else -x
+
+grad = tfe.gradients_function(abs)
+
+grad(3.) # => [1.0]
+grad(-3.) # => [-1.0]
+```
+
+### Custom gradients
+
+Custom gradients are an easy way to override gradients in eager and graph
+execution. Within the forward function, define the gradient with respect to the
+inputs, outputs, or intermediate results. For example, here's an easy way to clip
+the norm of the gradients in the backward pass:
+
+```py
+@tf.custom_gradient
+def clip_gradient_by_norm(x, norm):
+ y = tf.identity(x)
+ def grad_fn(dresult):
+ return [tf.clip_by_norm(dresult, norm), None]
+ return y, grad_fn
+```
+
+Custom gradients are commonly used to provide a numerically stable gradient for a
+sequence of operations:
+
+```py
+def log1pexp(x):
+ return tf.log(1 + tf.exp(x))
+grad_log1pexp = tfe.gradients_function(log1pexp)
+
+# The gradient computation works fine at x = 0.
+grad_log1pexp(0.) # => [0.5]
+
+# However, x = 100 fails because of numerical instability.
+grad_log1pexp(100.) # => [nan]
+```
+
+Here, the `log1pexp` function can be analytically simplified with a custom
+gradient. The implementation below reuses the value for `tf.exp(x)` that is
+computed during the forward pass—making it more efficient by eliminating
+redundant calculations:
+
+```py
+@tf.custom_gradient
+def log1pexp(x):
+ e = tf.exp(x)
+ def grad(dy):
+ return dy * (1 - 1 / (1 + e))
+ return tf.log(1 + e), grad
+
+grad_log1pexp = tfe.gradients_function(log1pexp)
+
+# As before, the gradient computation works fine at x = 0.
+grad_log1pexp(0.) # => [0.5]
+
+# And the gradient computation also works at x = 100.
+grad_log1pexp(100.) # => [1.0]
+```
+
## Performance
-Computation is not automatically offloaded to GPUs during eager execution. To
-explicitly direct a computation to a GPU, enclose it in a
-`tf.device('/gpu:0')` block:
+Computation is automatically offloaded to GPUs during eager execution. If you
+want control over where a computation runs, you can enclose it in a
+`tf.device('/gpu:0')` block (or the CPU equivalent):
```py
import time
diff --git a/tensorflow/docs_src/programmers_guide/index.md b/tensorflow/docs_src/programmers_guide/index.md
index e8c2fa6990..017db0e8cb 100644
--- a/tensorflow/docs_src/programmers_guide/index.md
+++ b/tensorflow/docs_src/programmers_guide/index.md
@@ -5,6 +5,7 @@ works. The units are as follows:
## High Level APIs
+ * @{$programmers_guide/eager}, which is the easiest way to use TensorFlow.
* @{$programmers_guide/estimators}, which introduces a high-level
TensorFlow API that greatly simplifies ML programming.
* @{$programmers_guide/datasets}, which explains how to
diff --git a/tensorflow/examples/label_image/main.cc b/tensorflow/examples/label_image/main.cc
index 63bc39de6c..baa65d3243 100644
--- a/tensorflow/examples/label_image/main.cc
+++ b/tensorflow/examples/label_image/main.cc
@@ -49,6 +49,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
@@ -137,15 +138,15 @@ Status ReadTensorFromImageFile(const string& file_name, const int input_height,
// Now try to figure out what kind of file it is and decode it.
const int wanted_channels = 3;
tensorflow::Output image_reader;
- if (tensorflow::StringPiece(file_name).ends_with(".png")) {
+ if (tensorflow::str_util::EndsWith(file_name, ".png")) {
image_reader = DecodePng(root.WithOpName("png_reader"), file_reader,
DecodePng::Channels(wanted_channels));
- } else if (tensorflow::StringPiece(file_name).ends_with(".gif")) {
+ } else if (tensorflow::str_util::EndsWith(file_name, ".gif")) {
// gif decoder returns 4-D tensor, remove the first dim
image_reader =
Squeeze(root.WithOpName("squeeze_first_dim"),
DecodeGif(root.WithOpName("gif_reader"), file_reader));
- } else if (tensorflow::StringPiece(file_name).ends_with(".bmp")) {
+ } else if (tensorflow::str_util::EndsWith(file_name, ".bmp")) {
image_reader = DecodeBmp(root.WithOpName("bmp_reader"), file_reader);
} else {
// Assume if it's neither a PNG nor a GIF then it must be a JPEG.
diff --git a/tensorflow/examples/multibox_detector/main.cc b/tensorflow/examples/multibox_detector/main.cc
index e38704fd98..96ea525a4e 100644
--- a/tensorflow/examples/multibox_detector/main.cc
+++ b/tensorflow/examples/multibox_detector/main.cc
@@ -30,6 +30,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
@@ -84,10 +85,10 @@ Status ReadTensorFromImageFile(const string& file_name, const int input_height,
// Now try to figure out what kind of file it is and decode it.
const int wanted_channels = 3;
tensorflow::Output image_reader;
- if (tensorflow::StringPiece(file_name).ends_with(".png")) {
+ if (tensorflow::str_util::EndsWith(file_name, ".png")) {
image_reader = DecodePng(root.WithOpName("png_reader"), file_reader,
DecodePng::Channels(wanted_channels));
- } else if (tensorflow::StringPiece(file_name).ends_with(".gif")) {
+ } else if (tensorflow::str_util::EndsWith(file_name, ".gif")) {
image_reader = DecodeGif(root.WithOpName("gif_reader"), file_reader);
} else {
// Assume if it's neither a PNG nor a GIF then it must be a JPEG.
@@ -131,7 +132,7 @@ Status ReadTensorFromImageFile(const string& file_name, const int input_height,
Status SaveImage(const Tensor& tensor, const string& file_path) {
LOG(INFO) << "Saving image to " << file_path;
- CHECK(tensorflow::StringPiece(file_path).ends_with(".png"))
+ CHECK(tensorflow::str_util::EndsWith(file_path, ".png"))
<< "Only saving of png files is supported.";
auto root = tensorflow::Scope::NewRootScope();
diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go
index a33703ad6f..0fd2177df7 100644
--- a/tensorflow/go/op/wrappers.go
+++ b/tensorflow/go/op/wrappers.go
@@ -1720,6 +1720,131 @@ func Size(scope *Scope, input tf.Output, optional ...SizeAttr) (output tf.Output
return op.Output(0)
}
+// Returns the rank of a tensor.
+//
+// This operation returns an integer representing the rank of `input`.
+//
+// For example:
+//
+// ```
+// # 't' is [[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]]
+// # shape of tensor 't' is [2, 2, 3]
+// rank(t) ==> 3
+// ```
+//
+// **Note**: The rank of a tensor is not the same as the rank of a matrix. The rank
+// of a tensor is the number of indices required to uniquely select each element
+// of the tensor. Rank is also known as "order", "degree", or "ndims."
+func Rank(scope *Scope, input tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "Rank",
+ Input: []tf.Input{
+ input,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// ReverseSequenceAttr is an optional argument to ReverseSequence.
+type ReverseSequenceAttr func(optionalAttr)
+
+// ReverseSequenceBatchDim sets the optional batch_dim attribute to value.
+//
+// value: The dimension along which reversal is performed.
+// If not specified, defaults to 0
+func ReverseSequenceBatchDim(value int64) ReverseSequenceAttr {
+ return func(m optionalAttr) {
+ m["batch_dim"] = value
+ }
+}
+
+// Reverses variable length slices.
+//
+// This op first slices `input` along the dimension `batch_dim`, and for each
+// slice `i`, reverses the first `seq_lengths[i]` elements along
+// the dimension `seq_dim`.
+//
+// The elements of `seq_lengths` must obey `seq_lengths[i] <= input.dims[seq_dim]`,
+// and `seq_lengths` must be a vector of length `input.dims[batch_dim]`.
+//
+// The output slice `i` along dimension `batch_dim` is then given by input
+// slice `i`, with the first `seq_lengths[i]` slices along dimension
+// `seq_dim` reversed.
+//
+// For example:
+//
+// ```
+// # Given this:
+// batch_dim = 0
+// seq_dim = 1
+// input.dims = (4, 8, ...)
+// seq_lengths = [7, 2, 3, 5]
+//
+// # then slices of input are reversed on seq_dim, but only up to seq_lengths:
+// output[0, 0:7, :, ...] = input[0, 7:0:-1, :, ...]
+// output[1, 0:2, :, ...] = input[1, 2:0:-1, :, ...]
+// output[2, 0:3, :, ...] = input[2, 3:0:-1, :, ...]
+// output[3, 0:5, :, ...] = input[3, 5:0:-1, :, ...]
+//
+// # while entries past seq_lens are copied through:
+// output[0, 7:, :, ...] = input[0, 7:, :, ...]
+// output[1, 2:, :, ...] = input[1, 2:, :, ...]
+// output[2, 3:, :, ...] = input[2, 3:, :, ...]
+// output[3, 5:, :, ...] = input[3, 5:, :, ...]
+// ```
+//
+// In contrast, if:
+//
+// ```
+// # Given this:
+// batch_dim = 2
+// seq_dim = 0
+// input.dims = (8, ?, 4, ...)
+// seq_lengths = [7, 2, 3, 5]
+//
+// # then slices of input are reversed on seq_dim, but only up to seq_lengths:
+// output[0:7, :, 0, :, ...] = input[7:0:-1, :, 0, :, ...]
+// output[0:2, :, 1, :, ...] = input[2:0:-1, :, 1, :, ...]
+// output[0:3, :, 2, :, ...] = input[3:0:-1, :, 2, :, ...]
+// output[0:5, :, 3, :, ...] = input[5:0:-1, :, 3, :, ...]
+//
+// # while entries past seq_lens are copied through:
+// output[7:, :, 0, :, ...] = input[7:, :, 0, :, ...]
+// output[2:, :, 1, :, ...] = input[2:, :, 1, :, ...]
+// output[3:, :, 2, :, ...] = input[3:, :, 2, :, ...]
+// output[5:, :, 3, :, ...] = input[5:, :, 3, :, ...]
+// ```
+//
+// Arguments:
+// input: The input to reverse.
+// seq_lengths: 1-D with length `input.dims(batch_dim)` and
+// `max(seq_lengths) <= input.dims(seq_dim)`
+// seq_dim: The dimension which is partially reversed.
+//
+// Returns The partially reversed input. It has the same shape as `input`.
+func ReverseSequence(scope *Scope, input tf.Output, seq_lengths tf.Output, seq_dim int64, optional ...ReverseSequenceAttr) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"seq_dim": seq_dim}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "ReverseSequence",
+ Input: []tf.Input{
+ input, seq_lengths,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Returns the complex conjugate of a complex number.
//
// Given a tensor `input` of complex numbers, this operation returns a tensor of
@@ -5128,102 +5253,6 @@ func RsqrtGrad(scope *Scope, y tf.Output, dy tf.Output) (z tf.Output) {
return op.Output(0)
}
-// ReverseSequenceAttr is an optional argument to ReverseSequence.
-type ReverseSequenceAttr func(optionalAttr)
-
-// ReverseSequenceBatchDim sets the optional batch_dim attribute to value.
-//
-// value: The dimension along which reversal is performed.
-// If not specified, defaults to 0
-func ReverseSequenceBatchDim(value int64) ReverseSequenceAttr {
- return func(m optionalAttr) {
- m["batch_dim"] = value
- }
-}
-
-// Reverses variable length slices.
-//
-// This op first slices `input` along the dimension `batch_dim`, and for each
-// slice `i`, reverses the first `seq_lengths[i]` elements along
-// the dimension `seq_dim`.
-//
-// The elements of `seq_lengths` must obey `seq_lengths[i] <= input.dims[seq_dim]`,
-// and `seq_lengths` must be a vector of length `input.dims[batch_dim]`.
-//
-// The output slice `i` along dimension `batch_dim` is then given by input
-// slice `i`, with the first `seq_lengths[i]` slices along dimension
-// `seq_dim` reversed.
-//
-// For example:
-//
-// ```
-// # Given this:
-// batch_dim = 0
-// seq_dim = 1
-// input.dims = (4, 8, ...)
-// seq_lengths = [7, 2, 3, 5]
-//
-// # then slices of input are reversed on seq_dim, but only up to seq_lengths:
-// output[0, 0:7, :, ...] = input[0, 7:0:-1, :, ...]
-// output[1, 0:2, :, ...] = input[1, 2:0:-1, :, ...]
-// output[2, 0:3, :, ...] = input[2, 3:0:-1, :, ...]
-// output[3, 0:5, :, ...] = input[3, 5:0:-1, :, ...]
-//
-// # while entries past seq_lens are copied through:
-// output[0, 7:, :, ...] = input[0, 7:, :, ...]
-// output[1, 2:, :, ...] = input[1, 2:, :, ...]
-// output[2, 3:, :, ...] = input[2, 3:, :, ...]
-// output[3, 2:, :, ...] = input[3, 2:, :, ...]
-// ```
-//
-// In contrast, if:
-//
-// ```
-// # Given this:
-// batch_dim = 2
-// seq_dim = 0
-// input.dims = (8, ?, 4, ...)
-// seq_lengths = [7, 2, 3, 5]
-//
-// # then slices of input are reversed on seq_dim, but only up to seq_lengths:
-// output[0:7, :, 0, :, ...] = input[7:0:-1, :, 0, :, ...]
-// output[0:2, :, 1, :, ...] = input[2:0:-1, :, 1, :, ...]
-// output[0:3, :, 2, :, ...] = input[3:0:-1, :, 2, :, ...]
-// output[0:5, :, 3, :, ...] = input[5:0:-1, :, 3, :, ...]
-//
-// # while entries past seq_lens are copied through:
-// output[7:, :, 0, :, ...] = input[7:, :, 0, :, ...]
-// output[2:, :, 1, :, ...] = input[2:, :, 1, :, ...]
-// output[3:, :, 2, :, ...] = input[3:, :, 2, :, ...]
-// output[2:, :, 3, :, ...] = input[2:, :, 3, :, ...]
-// ```
-//
-// Arguments:
-// input: The input to reverse.
-// seq_lengths: 1-D with length `input.dims(batch_dim)` and
-// `max(seq_lengths) <= input.dims(seq_dim)`
-// seq_dim: The dimension which is partially reversed.
-//
-// Returns The partially reversed input. It has the same shape as `input`.
-func ReverseSequence(scope *Scope, input tf.Output, seq_lengths tf.Output, seq_dim int64, optional ...ReverseSequenceAttr) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"seq_dim": seq_dim}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "ReverseSequence",
- Input: []tf.Input{
- input, seq_lengths,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// DepthwiseConv2dNativeAttr is an optional argument to DepthwiseConv2dNative.
type DepthwiseConv2dNativeAttr func(optionalAttr)
@@ -5808,35 +5837,6 @@ func FusedBatchNormV2(scope *Scope, x tf.Output, scale tf.Output, offset tf.Outp
return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4)
}
-// Returns the rank of a tensor.
-//
-// This operation returns an integer representing the rank of `input`.
-//
-// For example:
-//
-// ```
-// # 't' is [[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]]
-// # shape of tensor 't' is [2, 2, 3]
-// rank(t) ==> 3
-// ```
-//
-// **Note**: The rank of a tensor is not the same as the rank of a matrix. The rank
-// of a tensor is the number of indices required to uniquely select each element
-// of the tensor. Rank is also known as "order", "degree", or "ndims."
-func Rank(scope *Scope, input tf.Output) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "Rank",
- Input: []tf.Input{
- input,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Transforms a Tensor into a serialized TensorProto proto.
//
// Arguments:
diff --git a/tensorflow/java/maven/libtensorflow/pom.xml b/tensorflow/java/maven/libtensorflow/pom.xml
index 0b69a8cbe5..c99d04869a 100644
--- a/tensorflow/java/maven/libtensorflow/pom.xml
+++ b/tensorflow/java/maven/libtensorflow/pom.xml
@@ -6,7 +6,7 @@
<parent>
<groupId>org.tensorflow</groupId>
<artifactId>parentpom</artifactId>
- <version>1.7.0-rc1</version>
+ <version>1.7.0</version>
<relativePath>../</relativePath>
</parent>
<artifactId>libtensorflow</artifactId>
diff --git a/tensorflow/java/maven/libtensorflow_jni/pom.xml b/tensorflow/java/maven/libtensorflow_jni/pom.xml
index 541876f7f5..4561c2c8ad 100644
--- a/tensorflow/java/maven/libtensorflow_jni/pom.xml
+++ b/tensorflow/java/maven/libtensorflow_jni/pom.xml
@@ -6,7 +6,7 @@
<parent>
<groupId>org.tensorflow</groupId>
<artifactId>parentpom</artifactId>
- <version>1.7.0-rc1</version>
+ <version>1.7.0</version>
<relativePath>../</relativePath>
</parent>
<artifactId>libtensorflow_jni</artifactId>
diff --git a/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml b/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml
index d8933e5238..82a2b8e769 100644
--- a/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml
+++ b/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml
@@ -6,7 +6,7 @@
<parent>
<groupId>org.tensorflow</groupId>
<artifactId>parentpom</artifactId>
- <version>1.7.0-rc1</version>
+ <version>1.7.0</version>
<relativePath>../</relativePath>
</parent>
<artifactId>libtensorflow_jni_gpu</artifactId>
diff --git a/tensorflow/java/maven/pom.xml b/tensorflow/java/maven/pom.xml
index 6286fd73df..4c1ec0cc80 100644
--- a/tensorflow/java/maven/pom.xml
+++ b/tensorflow/java/maven/pom.xml
@@ -6,7 +6,7 @@
<modelVersion>4.0.0</modelVersion>
<groupId>org.tensorflow</groupId>
<artifactId>parentpom</artifactId>
- <version>1.7.0-rc1</version>
+ <version>1.7.0</version>
<packaging>pom</packaging>
<url>https://www.tensorflow.org</url>
diff --git a/tensorflow/java/maven/proto/pom.xml b/tensorflow/java/maven/proto/pom.xml
index 4e881f5a63..fcd8236bad 100644
--- a/tensorflow/java/maven/proto/pom.xml
+++ b/tensorflow/java/maven/proto/pom.xml
@@ -6,7 +6,7 @@
<parent>
<groupId>org.tensorflow</groupId>
<artifactId>parentpom</artifactId>
- <version>1.7.0-rc1</version>
+ <version>1.7.0</version>
<relativePath>../</relativePath>
</parent>
<artifactId>proto</artifactId>
diff --git a/tensorflow/java/maven/tensorflow/pom.xml b/tensorflow/java/maven/tensorflow/pom.xml
index d512a7eda9..241581713a 100644
--- a/tensorflow/java/maven/tensorflow/pom.xml
+++ b/tensorflow/java/maven/tensorflow/pom.xml
@@ -6,7 +6,7 @@
<parent>
<groupId>org.tensorflow</groupId>
<artifactId>parentpom</artifactId>
- <version>1.7.0-rc1</version>
+ <version>1.7.0</version>
<relativePath>../</relativePath>
</parent>
<artifactId>tensorflow</artifactId>
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index a3e79d46d8..6ec8a1cdab 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -28,6 +28,8 @@ load("//tensorflow:tensorflow.bzl", "py_tests")
load("//tensorflow:tensorflow.bzl", "tf_py_build_info_genrule")
load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc")
load("//tensorflow:tensorflow.bzl", "tf_cc_shared_object")
+load("//tensorflow:tensorflow.bzl", "tf_native_cc_binary")
+load("//tensorflow:tensorflow.bzl", "tf_custom_op_library_additional_deps_impl")
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
load("//tensorflow:tensorflow.bzl", "cuda_py_tests")
load("//tensorflow/core:platform/default/build_config.bzl", "pyx_library")
@@ -58,9 +60,10 @@ py_library(
"//tensorflow/tools/api/generator:__pkg__",
"//tensorflow/tools/quantization:__pkg__", # TODO(b/34059704): remove when fixed
],
- deps = [":no_contrib"] + if_not_windows([
+ deps = [
+ ":no_contrib",
"//tensorflow/contrib:contrib_py",
- ]),
+ ],
)
py_library(
@@ -284,6 +287,17 @@ cc_library(
)
cc_library(
+ name = "py_exception_registry",
+ srcs = ["lib/core/py_exception_registry.cc"],
+ hdrs = ["lib/core/py_exception_registry.h"],
+ deps = [
+ "//tensorflow/c:c_api",
+ "//tensorflow/core:lib",
+ "//util/python:python_headers",
+ ],
+)
+
+cc_library(
name = "kernel_registry",
srcs = ["util/kernel_registry.cc"],
hdrs = ["util/kernel_registry.h"],
@@ -413,6 +427,7 @@ tf_cc_shared_object(
"-lm",
],
"//tensorflow:darwin": [],
+ "//tensorflow:windows": [],
}),
deps = [
"//tensorflow/core:framework_headers_lib",
@@ -960,7 +975,6 @@ py_test(
srcs = ["framework/contrib_test.py"],
main = "framework/contrib_test.py",
srcs_version = "PY2AND3",
- tags = ["no_windows"],
deps = [
"//tensorflow:tensorflow_py",
"//tensorflow/python:client_testlib",
@@ -1330,7 +1344,6 @@ py_test(
srcs = ["framework/dtypes_test.py"],
main = "framework/dtypes_test.py",
srcs_version = "PY2AND3",
- tags = ["no_windows"],
deps = [
":framework_for_generated_wrappers",
":framework_test_lib",
@@ -1706,7 +1719,6 @@ py_test(
size = "small",
srcs = ["ops/clip_ops_test.py"],
srcs_version = "PY2AND3",
- tags = ["no_windows"],
deps = [
":client_testlib",
":clip_ops",
@@ -2775,7 +2787,6 @@ cuda_py_test(
],
data = ["//tensorflow/core:image_testdata"],
shard_count = 5,
- tags = ["no_windows"],
)
cuda_py_test(
@@ -3313,6 +3324,7 @@ tf_py_wrap_cc(
"grappler/model_analyzer.i",
"grappler/tf_optimizer.i",
"lib/core/bfloat16.i",
+ "lib/core/py_exception_registry.i",
"lib/core/py_func.i",
"lib/core/strings.i",
"lib/io/file_io.i",
@@ -3340,6 +3352,7 @@ tf_py_wrap_cc(
":kernel_registry",
":numpy_lib",
":safe_ptr",
+ ":py_exception_registry",
":py_func_lib",
":py_record_reader_lib",
":py_record_writer_lib",
@@ -3374,6 +3387,65 @@ tf_py_wrap_cc(
tf_additional_gdr_deps()),
)
+# ** Targets for Windows build (start) **
+# We need the following targets to expose symbols from _pywrap_tensorflow.dll
+
+# Build a cc_binary from tf_custom_op_library_additional_deps_impl,
+# which contains all object code from its dependencies.
+tf_native_cc_binary(
+ name = "tf_custom_op_library_additional_deps.so",
+ linkshared = 1,
+ linkstatic = 1,
+ deps = tf_custom_op_library_additional_deps_impl(),
+)
+
+# Get a DEF file generated by parsing all object files
+# of tf_custom_op_library_additional_deps.so
+filegroup(
+ name = "pywrap_tensorflow_def_file",
+ srcs = [":tf_custom_op_library_additional_deps.so"],
+ output_group = "def_file",
+)
+
+# Filter the DEF file to reduce the number of symbols to 64K or fewer.
+# Note that we also write the name of the pyd file into the DEF file so that
+# the dynamic libraries of custom ops can find it at runtime.
+genrule(
+ name = "pywrap_tensorflow_filtered_def_file",
+ srcs = [":pywrap_tensorflow_def_file"],
+ outs = ["pywrap_tensorflow_filtered_def_file.def"],
+ cmd = select({
+ "//tensorflow:windows": """
+ $(location @local_config_def_file_filter//:def_file_filter) \\
+ --input $(location :pywrap_tensorflow_def_file) \\
+ --output $@ \\
+ --target _pywrap_tensorflow_internal.pyd
+ """,
+ "//conditions:default": "touch $@", # Just a placeholder for Unix platforms
+ }),
+ tools = ["@local_config_def_file_filter//:def_file_filter"],
+)
+
+# Get the import library of _pywrap_tensorflow_internal.dll
+filegroup(
+ name = "pywrap_tensorflow_import_lib_file",
+ srcs = [":_pywrap_tensorflow_internal.so"],
+ output_group = "interface_library",
+)
+
+# Create a cc_import rule for the import library of _pywrap_tensorflow_internal.dll
+# so that custom ops' dynamic libraries can link against it.
+cc_import(
+ name = "pywrap_tensorflow_import_lib",
+ interface_library = select({
+ "//tensorflow:windows": ":pywrap_tensorflow_import_lib_file",
+ "//conditions:default": "not_exsiting_on_unix.lib", # Just a placeholder for Unix platforms
+ }),
+ system_provided = 1,
+)
+
+# ** Targets for Windows build (end) **
+
py_library(
name = "lib",
srcs = [
@@ -3707,6 +3779,7 @@ cuda_py_test(
":math_ops",
"//tensorflow/core:protos_all_py",
],
+ tags = ["no_windows"],
)
cuda_py_test(
@@ -3746,7 +3819,6 @@ py_test(
size = "small",
srcs = ["lib/core/bfloat16_test.py"],
srcs_version = "PY2AND3",
- tags = ["no_windows"],
deps = [
":client_testlib",
":lib",
@@ -3939,6 +4011,7 @@ py_test(
srcs = ["training/saver_large_partitioned_variable_test.py"],
srcs_version = "PY2AND3",
tags = [
+ "no_windows",
"noasan", # http://b/30782289
"notsan", # http://b/30782289
],
@@ -4054,7 +4127,6 @@ py_test(
size = "small",
srcs = ["training/checkpoint_ops_test.py"],
srcs_version = "PY2AND3",
- tags = ["no_windows"],
deps = [
":checkpoint_ops_gen",
":client",
@@ -4095,10 +4167,7 @@ py_test(
size = "medium",
srcs = ["training/monitored_session_test.py"],
srcs_version = "PY2AND3",
- tags = [
- "no_windows",
- "notsan", # b/67945581
- ],
+ tags = ["notsan"], # b/67945581
deps = [
":array_ops",
":client_testlib",
@@ -4710,6 +4779,7 @@ py_test(
":client_testlib",
":framework_for_generated_wrappers",
":math_ops",
+ ":tf_item",
":tf_optimizer",
"//tensorflow/core:protos_all_py",
"//third_party/py/numpy",
@@ -4771,6 +4841,29 @@ py_test(
)
cuda_py_test(
+ name = "constant_folding_test",
+ size = "medium",
+ srcs = [
+ "grappler/constant_folding_test.py",
+ ],
+ additional_deps = [
+ ":client_testlib",
+ ":framework_for_generated_wrappers",
+ ":array_ops",
+ ":control_flow_ops",
+ ":dtypes",
+ ":functional_ops",
+ ":math_ops",
+ ":ops",
+ "//third_party/py/numpy",
+ "//tensorflow/core:protos_all_py",
+ ],
+ tags = [
+ "grappler",
+ ],
+)
+
+cuda_py_test(
name = "layout_optimizer_test",
size = "medium",
srcs = [
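[Editor's note] The BUILD comments above describe the Windows export pipeline: a DEF file is generated from the object files of tf_custom_op_library_additional_deps.so, filtered to stay under the PE/COFF limit of 65,535 exported symbols, and stamped with the .pyd name so custom-op DLLs can resolve against it at runtime. As a rough, hypothetical sketch only (the real logic lives in @local_config_def_file_filter//:def_file_filter, whose internals are not shown in this diff), the filtering step conceptually does something like:

    # Illustrative sketch; the symbol-selection rule is invented and does not
    # reproduce the real def_file_filter tool.
    def filter_def_file(input_path, output_path,
                        target="_pywrap_tensorflow_internal.pyd"):
        with open(input_path) as f:
            symbols = [line.strip() for line in f if line.strip()]
        # Keep a subset so the total stays under the 64K export limit.
        kept = [s for s in symbols if not s.startswith("?$")]  # hypothetical rule
        with open(output_path, "w") as f:
            # Writing the pyd name lets custom-op DLLs find it at runtime.
            f.write("LIBRARY %s\nEXPORTS\n" % target)
            for s in kept:
                f.write("    %s\n" % s)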
diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py
index 5c9ed9ccaf..4c84d78f2e 100644
--- a/tensorflow/python/client/session.py
+++ b/tensorflow/python/client/session.py
@@ -27,7 +27,6 @@ import numpy as np
from tensorflow.core.protobuf import config_pb2
from tensorflow.python import pywrap_tensorflow as tf_session
-from tensorflow.python.framework import c_api_util
from tensorflow.python.framework import device
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
@@ -629,14 +628,12 @@ class BaseSession(SessionInterface):
self._session = None
opts = tf_session.TF_NewSessionOptions(target=self._target, config=config)
try:
- with errors.raise_exception_on_not_ok_status() as status:
- if self._created_with_new_api:
- # pylint: disable=protected-access
- self._session = tf_session.TF_NewSession(self._graph._c_graph, opts,
- status)
- # pylint: enable=protected-access
- else:
- self._session = tf_session.TF_NewDeprecatedSession(opts, status)
+ if self._created_with_new_api:
+ # pylint: disable=protected-access
+ self._session = tf_session.TF_NewSession(self._graph._c_graph, opts)
+ # pylint: enable=protected-access
+ else:
+ self._session = tf_session.TF_NewDeprecatedSession(opts)
finally:
tf_session.TF_DeleteSessionOptions(opts)
@@ -663,22 +660,20 @@ class BaseSession(SessionInterface):
Returns:
A list of devices in the session.
"""
- with errors.raise_exception_on_not_ok_status() as status:
- if self._created_with_new_api:
- raw_device_list = tf_session.TF_SessionListDevices(
- self._session, status)
- else:
- raw_device_list = tf_session.TF_DeprecatedSessionListDevices(
- self._session, status)
- device_list = []
- size = tf_session.TF_DeviceListCount(raw_device_list)
- for i in range(size):
- name = tf_session.TF_DeviceListName(raw_device_list, i, status)
- device_type = tf_session.TF_DeviceListType(raw_device_list, i, status)
- memory = tf_session.TF_DeviceListMemoryBytes(raw_device_list, i, status)
- device_list.append(_DeviceAttributes(name, device_type, memory))
- tf_session.TF_DeleteDeviceList(raw_device_list)
- return device_list
+ if self._created_with_new_api:
+ raw_device_list = tf_session.TF_SessionListDevices(self._session)
+ else:
+ raw_device_list = tf_session.TF_DeprecatedSessionListDevices(
+ self._session)
+ device_list = []
+ size = tf_session.TF_DeviceListCount(raw_device_list)
+ for i in range(size):
+ name = tf_session.TF_DeviceListName(raw_device_list, i)
+ device_type = tf_session.TF_DeviceListType(raw_device_list, i)
+ memory = tf_session.TF_DeviceListMemoryBytes(raw_device_list, i)
+ device_list.append(_DeviceAttributes(name, device_type, memory))
+ tf_session.TF_DeleteDeviceList(raw_device_list)
+ return device_list
def close(self):
"""Closes this session.
@@ -692,15 +687,13 @@ class BaseSession(SessionInterface):
if self._created_with_new_api:
if self._session and not self._closed:
self._closed = True
- with errors.raise_exception_on_not_ok_status() as status:
- tf_session.TF_CloseSession(self._session, status)
+ tf_session.TF_CloseSession(self._session)
else:
with self._extend_lock:
if self._opened and not self._closed:
self._closed = True
- with errors.raise_exception_on_not_ok_status() as status:
- tf_session.TF_CloseDeprecatedSession(self._session, status)
+ tf_session.TF_CloseDeprecatedSession(self._session)
def __del__(self):
# cleanly ignore all exceptions
@@ -710,11 +703,10 @@ class BaseSession(SessionInterface):
pass
if self._session is not None:
try:
- status = c_api_util.ScopedTFStatus()
if self._created_with_new_api:
- tf_session.TF_DeleteSession(self._session, status)
+ tf_session.TF_DeleteSession(self._session)
else:
- tf_session.TF_DeleteDeprecatedSession(self._session, status)
+ tf_session.TF_DeleteDeprecatedSession(self._session)
except AttributeError:
# At shutdown, `c_api_util` or `tf_session` may have been garbage
# collected, causing the above method calls to fail. In this case,
@@ -1031,11 +1023,11 @@ class BaseSession(SessionInterface):
# Set up a graph with feeds and fetches for partial run.
def _setup_fn(session, feed_list, fetch_list, target_list):
self._extend_graph()
- with errors.raise_exception_on_not_ok_status() as status:
- if self._created_with_new_api:
- return tf_session.TF_SessionPRunSetup_wrapper(
- session, feed_list, fetch_list, target_list, status)
- else:
+ if self._created_with_new_api:
+ return tf_session.TF_SessionPRunSetup_wrapper(
+ session, feed_list, fetch_list, target_list)
+ else:
+ with errors.raise_exception_on_not_ok_status() as status:
return tf_session.TF_PRunSetup(session, feed_list, fetch_list,
target_list, status)
@@ -1345,8 +1337,7 @@ class BaseSession(SessionInterface):
def _extend_graph(self):
if self._created_with_new_api:
with self._graph._lock: # pylint: disable=protected-access
- with errors.raise_exception_on_not_ok_status() as status:
- tf_session.ExtendSession(self._session, status)
+ tf_session.ExtendSession(self._session)
else:
# Ensure any changes to the graph are reflected in the runtime.
with self._extend_lock:
@@ -1412,22 +1403,22 @@ class BaseSession(SessionInterface):
def _call_tf_sessionrun(self, options, feed_dict, fetch_list, target_list,
run_metadata):
- with errors.raise_exception_on_not_ok_status() as status:
- if self._created_with_new_api:
- return tf_session.TF_SessionRun_wrapper(
- self._session, options, feed_dict, fetch_list, target_list,
- run_metadata, status)
- else:
+ if self._created_with_new_api:
+ return tf_session.TF_SessionRun_wrapper(
+ self._session, options, feed_dict, fetch_list, target_list,
+ run_metadata)
+ else:
+ with errors.raise_exception_on_not_ok_status() as status:
return tf_session.TF_Run(
self._session, options, feed_dict, fetch_list, target_list,
status, run_metadata)
def _call_tf_sessionprun(self, handle, feed_dict, fetch_list):
- with errors.raise_exception_on_not_ok_status() as status:
- if self._created_with_new_api:
- return tf_session.TF_SessionPRun_wrapper(
- self._session, handle, feed_dict, fetch_list, status)
- else:
+ if self._created_with_new_api:
+ return tf_session.TF_SessionPRun_wrapper(
+ self._session, handle, feed_dict, fetch_list)
+ else:
+ with errors.raise_exception_on_not_ok_status() as status:
return tf_session.TF_PRun(
self._session, handle, feed_dict, fetch_list, status)
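[Editor's note] The session.py hunks above all follow one pattern: the explicit TF_Status plumbing through errors.raise_exception_on_not_ok_status() is dropped for the new C API entry points, which now raise the registered Python exception classes directly (see the py_exception_registry target added earlier in this diff). A before/after sketch of the calling convention, using TF_CloseSession as the example:

    # Before: callers allocated a status and a context manager translated it.
    with errors.raise_exception_on_not_ok_status() as status:
        tf_session.TF_CloseSession(self._session, status)

    # After: the wrapped call itself raises the mapped exception on failure
    # (e.g. errors.InvalidArgumentError), so call sites shrink to one line.
    tf_session.TF_CloseSession(self._session)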
diff --git a/tensorflow/python/client/session_list_devices_test.py b/tensorflow/python/client/session_list_devices_test.py
index 5a7413c12e..38a3acb2dc 100644
--- a/tensorflow/python/client/session_list_devices_test.py
+++ b/tensorflow/python/client/session_list_devices_test.py
@@ -23,7 +23,6 @@ from tensorflow.core.protobuf import cluster_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.python import pywrap_tensorflow as tf_session
from tensorflow.python.client import session
-from tensorflow.python.framework import c_api_util
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
@@ -42,21 +41,13 @@ class SessionListDevicesTestMethods(object):
def testInvalidDeviceNumber(self):
opts = tf_session.TF_NewSessionOptions()
- with errors.raise_exception_on_not_ok_status() as status:
- c_session = tf_session.TF_NewSession(
- ops.get_default_graph()._c_graph, opts, status)
- raw_device_list = tf_session.TF_SessionListDevices(
- c_session, status)
+ c_session = tf_session.TF_NewSession(ops.get_default_graph()._c_graph, opts)
+ raw_device_list = tf_session.TF_SessionListDevices(c_session)
size = tf_session.TF_DeviceListCount(raw_device_list)
- # Test that invalid device numbers return -1 rather than a Swig-wrapped
- # pointer.
- status_no_exception = c_api_util.ScopedTFStatus()
- memory = tf_session.TF_DeviceListMemoryBytes(
- raw_device_list, size, status_no_exception)
- self.assertEqual(memory, -1)
+ with self.assertRaises(errors.InvalidArgumentError):
+ tf_session.TF_DeviceListMemoryBytes(raw_device_list, size)
tf_session.TF_DeleteDeviceList(raw_device_list)
- with errors.raise_exception_on_not_ok_status() as status:
- tf_session.TF_CloseSession(c_session, status)
+ tf_session.TF_CloseSession(c_session)
def testListDevicesGrpcSession(self):
server = server_lib.Server.create_local_server()
diff --git a/tensorflow/python/client/tf_session.i b/tensorflow/python/client/tf_session.i
index 77ce9195ee..b82182d5d3 100644
--- a/tensorflow/python/client/tf_session.i
+++ b/tensorflow/python/client/tf_session.i
@@ -18,11 +18,11 @@ limitations under the License.
%{
#include "tensorflow/c/python_api.h"
-#include "tensorflow/python/client/tf_session_helper.h"
#include "tensorflow/core/framework/session_state.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/public/version.h"
+#include "tensorflow/python/client/tf_session_helper.h"
// Helper function to convert a Python list of Tensors to a C++ vector of
// TF_Outputs.
@@ -72,7 +72,7 @@ void PyInt64ListToVector(PyObject* py_int_seq, std::vector<int64_t>* vec) {
int size = PySequence_Fast_GET_SIZE(py_int_seq);
for (int i = 0; i < size; ++i) {
PyObject* item = PySequence_Fast_GET_ITEM(py_int_seq, i);
- vec->push_back(PyInt_AsLong(item));
+ vec->push_back(PyLong_AsLongLong(item));
}
}
@@ -157,6 +157,25 @@ tensorflow::ImportNumpy();
}
}
+// We use TF_OperationGetControlOutputs_wrapper instead of
+// TF_OperationGetControlOutputs
+%ignore TF_OperationGetControlOutputs;
+%unignore TF_OperationGetControlOutputs_wrapper;
+// See comment for "%noexception TF_SessionRun_wrapper;"
+%noexception TF_OperationGetControlOutputs_wrapper;
+
+// Build a Python list of TF_Operation* and return it.
+%typemap(out) std::vector<TF_Operation*> tensorflow::TF_OperationGetControlOutputs_wrapper {
+ $result = PyList_New($1.size());
+ if (!$result) {
+ SWIG_exception_fail(SWIG_MemoryError, "$symname: couldn't create list");
+ }
+
+ for (size_t i = 0; i < $1.size(); ++i) {
+ PyList_SET_ITEM($result, i, CreateWrappedTFOperation($1[i]));
+ }
+}
+
%ignore TF_OperationOutputConsumers;
%unignore TF_OperationOutputConsumers_wrapper;
// See comment for "%noexception TF_SessionRun_wrapper;"
@@ -438,6 +457,11 @@ TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper{
$1 = PyLong_AsLongLong($input);
}
+// Override default py3 behavior of attempting to encode into Unicode.
+%typemap(out) std::string tensorflow::ResourceHandleShapeAndType {
+ $result = PyBytes_FromStringAndSize($1.data(), $1.size());
+}
+
// TODO(skyewm): SWIG emits a warning for the const char* in TF_WhileParams,
// skip for now
%ignore TF_WhileParams;
@@ -499,9 +523,8 @@ TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper{
_TF_SetTarget(opts, target)
if config is not None:
from tensorflow.python.framework import errors
- with errors.raise_exception_on_not_ok_status() as status:
- config_str = config.SerializeToString()
- _TF_SetConfig(opts, config_str, status)
+ config_str = config.SerializeToString()
+ _TF_SetConfig(opts, config_str)
return opts
%}
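[Editor's note] The new typemap above converts the std::vector<TF_Operation*> returned by TF_OperationGetControlOutputs_wrapper into a plain Python list, mirroring the existing control-inputs wrapper. A hedged sketch of a Python-side caller (variable names hypothetical; TF_OperationName is assumed to be SWIG-exposed as for other C API accessors):

    # c_op is a wrapped TF_Operation; the wrapper returns a Python list.
    control_outputs = tf_session.TF_OperationGetControlOutputs_wrapper(c_op)
    for c_out in control_outputs:
        print(tf_session.TF_OperationName(c_out))  # assumed accessor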
diff --git a/tensorflow/python/client/tf_session_helper.cc b/tensorflow/python/client/tf_session_helper.cc
index ca57abd712..b48d758e4a 100644
--- a/tensorflow/python/client/tf_session_helper.cc
+++ b/tensorflow/python/client/tf_session_helper.cc
@@ -550,6 +550,15 @@ std::vector<TF_Operation*> TF_OperationGetControlInputs_wrapper(
return control_inputs;
}
+std::vector<TF_Operation*> TF_OperationGetControlOutputs_wrapper(
+ TF_Operation* oper) {
+ std::vector<TF_Operation*> control_outputs(
+ TF_OperationNumControlOutputs(oper));
+ TF_OperationGetControlOutputs(oper, control_outputs.data(),
+ control_outputs.size());
+ return control_outputs;
+}
+
std::vector<const char*> TF_OperationOutputConsumers_wrapper(
TF_Output oper_out) {
int num_consumers = TF_OperationOutputNumConsumers(oper_out);
diff --git a/tensorflow/python/client/tf_session_helper.h b/tensorflow/python/client/tf_session_helper.h
index 603d03e315..d2b4abc476 100644
--- a/tensorflow/python/client/tf_session_helper.h
+++ b/tensorflow/python/client/tf_session_helper.h
@@ -136,8 +136,7 @@ string EqualAttrValueWrapper(const string& actual, const string& expected);
//
// If shape is unknown, sets unknown_shape to true.
tensorflow::gtl::InlinedVector<int64_t, 6> TF_GraphGetTensorShapeHelper(
- TF_Graph* graph, TF_Output output, TF_Status* out_status,
- bool* unknown_shape);
+ TF_Graph* graph, TF_Output output, TF_Status* status, bool* unknown_shape);
// Runs the graph associated with the session starting with the supplied inputs.
// On success, `py_outputs` is populated with a numpy ndarray for each output
@@ -149,7 +148,7 @@ void TF_SessionRun_wrapper(TF_Session* session, const TF_Buffer* run_options,
const std::vector<PyObject*>& input_ndarrays,
const std::vector<TF_Output>& outputs,
const std::vector<TF_Operation*>& targets,
- TF_Buffer* run_metadata, TF_Status* out_status,
+ TF_Buffer* run_metadata, TF_Status* status,
std::vector<PyObject*>* py_outputs);
// Set up the graph with the intended feeds (inputs) and fetches (output) for
@@ -165,8 +164,7 @@ void TF_SessionPRunSetup_wrapper(TF_Session* session,
const std::vector<TF_Output>& inputs,
const std::vector<TF_Output>& outputs,
const std::vector<TF_Operation*>& targets,
- const char** out_handle,
- TF_Status* out_status);
+ const char** out_handle, TF_Status* status);
// Continue to run the graph with additional feeds and fetches. The
// execution state is uniquely identified by the handle.
@@ -182,7 +180,7 @@ void TF_SessionPRun_wrapper(TF_Session* session, const char* handle,
const std::vector<TF_Output>& inputs,
const std::vector<PyObject*>& input_ndarrays,
const std::vector<TF_Output>& outputs,
- TF_Status* out_status,
+ TF_Status* status,
std::vector<PyObject*>* py_outputs);
// Retrieves the inputs of this operation.
@@ -192,6 +190,10 @@ std::vector<TF_Output> GetOperationInputs(TF_Operation* oper);
std::vector<TF_Operation*> TF_OperationGetControlInputs_wrapper(
TF_Operation* oper);
+// Retrieves the control outputs of this operation.
+std::vector<TF_Operation*> TF_OperationGetControlOutputs_wrapper(
+ TF_Operation* oper);
+
// Retrieves the op names of the consumers of `oper_out`. The returned strings
// have the lifetime of the underlying TF_Graph.
std::vector<const char*> TF_OperationOutputConsumers_wrapper(
@@ -204,7 +206,7 @@ TF_Function* TF_GraphToFunction_wrapper(
const std::vector<TF_Operation*>* opers,
const std::vector<TF_Output>& inputs, const std::vector<TF_Output>& outputs,
const NameVector& output_names, const TF_FunctionOptions* opts,
- const char* description, TF_Status* out_status);
+ const char* description, TF_Status* status);
// Set the shapes and types for the output's handle.
//
diff --git a/tensorflow/python/data/kernel_tests/iterator_ops_test.py b/tensorflow/python/data/kernel_tests/iterator_ops_test.py
index 4a14a915bd..0af282a024 100644
--- a/tensorflow/python/data/kernel_tests/iterator_ops_test.py
+++ b/tensorflow/python/data/kernel_tests/iterator_ops_test.py
@@ -28,6 +28,7 @@ from tensorflow.python.client import session
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.ops import readers
+from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -717,6 +718,14 @@ class IteratorTest(test.TestCase):
self.assertTrue(
iterator_ops.GET_NEXT_CALL_WARNING_MESSAGE in str(warning.message))
+ def testEagerIteratorAsync(self):
+ with context.eager_mode(), context.execution_mode(context.ASYNC):
+ val = 0
+ dataset = dataset_ops.Dataset.range(10)
+ for foo in dataset:
+ self.assertEqual(val, foo.numpy())
+ val += 1
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/data/ops/iterator_ops.py b/tensorflow/python/data/ops/iterator_ops.py
index d79b9d6011..0c76afd29d 100644
--- a/tensorflow/python/data/ops/iterator_ops.py
+++ b/tensorflow/python/data/ops/iterator_ops.py
@@ -488,23 +488,27 @@ class EagerIterator(object):
def _next_internal(self):
"""Returns a nested structure of `tf.Tensor`s containing the next element.
"""
- with ops.device(self._device):
- # TODO(ashankar): Consider removing this ops.device() contextmanager
- # and instead mimic ops placement in graphs: Operations on resource
- # handles execute on the same device as where the resource is placed.
- # NOTE(mrry): Here we use the "_sync" variant of `iterator_get_next`
- # because in eager mode this code will run synchronously on the calling
- # thread. Therefore we do not need to make a defensive context switch
- # to a background thread, and can achieve a small constant performance
- # boost by invoking the iterator synchronously.
- ret = gen_dataset_ops.iterator_get_next_sync(
- self._resource,
- output_types=self._flat_output_types,
- output_shapes=self._flat_output_shapes)
-
- return sparse.deserialize_sparse_tensors(
- nest.pack_sequence_as(self._output_types, ret), self._output_types,
- self._output_shapes, self._output_classes)
+ # This runs in sync mode as iterators use an error status to communicate
+ # that there is no more data to iterate over.
+ # TODO(b/77291417): Fix
+ with context.execution_mode(context.SYNC):
+ with ops.device(self._device):
+ # TODO(ashankar): Consider removing this ops.device() contextmanager
+ # and instead mimic ops placement in graphs: Operations on resource
+ # handles execute on the same device as where the resource is placed.
+ # NOTE(mrry): Here we use the "_sync" variant of `iterator_get_next`
+ # because in eager mode this code will run synchronously on the calling
+ # thread. Therefore we do not need to make a defensive context switch
+ # to a background thread, and can achieve a small constant performance
+ # boost by invoking the iterator synchronously.
+ ret = gen_dataset_ops.iterator_get_next_sync(
+ self._resource,
+ output_types=self._flat_output_types,
+ output_shapes=self._flat_output_shapes)
+
+ return sparse.deserialize_sparse_tensors(
+ nest.pack_sequence_as(self._output_types, ret), self._output_types,
+ self._output_shapes, self._output_classes)
def next(self):
"""Returns a nested structure of `tf.Tensor`s containing the next element.
diff --git a/tensorflow/python/debug/BUILD b/tensorflow/python/debug/BUILD
index 4195586313..250b4b1b6a 100644
--- a/tensorflow/python/debug/BUILD
+++ b/tensorflow/python/debug/BUILD
@@ -913,6 +913,7 @@ cuda_py_test(
"//tensorflow/python:util",
"//tensorflow/python:variables",
],
+ tags = ["no_windows"], # TODO: needs investigation on Windows
)
py_test(
@@ -920,6 +921,7 @@ py_test(
size = "small",
srcs = ["cli/profile_analyzer_cli_test.py"],
srcs_version = "PY2AND3",
+ tags = ["no_windows"],
deps = [
":debugger_cli_common",
":profile_analyzer_cli",
diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py
index 209b012621..92774d4d50 100644
--- a/tensorflow/python/eager/backprop.py
+++ b/tensorflow/python/eager/backprop.py
@@ -31,7 +31,6 @@ from tensorflow.python.eager import imperative_grad
from tensorflow.python.eager import tape
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
@@ -50,12 +49,10 @@ def op_attr_type(op_type, attr_name):
try:
return _op_attr_type_cache[(op_type, attr_name)]
except KeyError:
- with errors.raise_exception_on_not_ok_status() as status:
- h = context.context()._handle # pylint: disable=protected-access
- attr_type = pywrap_tensorflow.TFE_OpNameGetAttrType(
- h, op_type, attr_name, status)
- _op_attr_type_cache[(op_type, attr_name)] = attr_type
- return attr_type
+ h = context.context()._handle # pylint: disable=protected-access
+ attr_type = pywrap_tensorflow.TFE_OpNameGetAttrType(h, op_type, attr_name)
+ _op_attr_type_cache[(op_type, attr_name)] = attr_type
+ return attr_type
def make_attr(attr_type, value):
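[Editor's note] With the status argument gone, op_attr_type above reduces to a dictionary-backed memo around TFE_OpNameGetAttrType. A usage sketch (op and attr names illustrative):

    t = op_attr_type("MatMul", "T")  # first call crosses into the C API
    t = op_attr_type("MatMul", "T")  # served from _op_attr_type_cache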
diff --git a/tensorflow/python/eager/benchmarks_test.py b/tensorflow/python/eager/benchmarks_test.py
index 9ca5041c38..7ad37058fd 100644
--- a/tensorflow/python/eager/benchmarks_test.py
+++ b/tensorflow/python/eager/benchmarks_test.py
@@ -201,6 +201,9 @@ class MicroBenchmarks(test.Benchmark):
m = self._m_2
self._run(lambda: gen_array_ops.identity(m), 30000)
+ def benchmark_slowpath_tf_identity(self):
+ self._run(lambda: gen_array_ops.identity(1), 30000)
+
def benchmark_tfe_py_execute_identity(self):
m = self._m_2
ctx_handle = context.context()._handle
diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py
index 8c1bb06bc3..9e146f021e 100644
--- a/tensorflow/python/eager/context.py
+++ b/tensorflow/python/eager/context.py
@@ -28,7 +28,6 @@ from tensorflow.core.protobuf import config_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.framework import c_api_util
from tensorflow.python.framework import device as pydev
-from tensorflow.python.framework import errors
from tensorflow.python.util import compat
from tensorflow.python.util import is_in_graph_mode
from tensorflow.python.util import tf_contextlib
@@ -86,6 +85,7 @@ class _EagerContext(threading.local):
self.device_spec = pydev.DeviceSpec.from_string("")
self.device_name = self.device_spec.to_string()
self.mode = _default_mode
+ self.is_eager = _default_mode == EAGER_MODE
self.scope_name = ""
self.recording_summaries = False
self.summary_writer_resource = None
@@ -223,34 +223,27 @@ class Context(object):
assert self._context_devices is None
opts = pywrap_tensorflow.TFE_NewContextOptions()
try:
- with errors.raise_exception_on_not_ok_status() as status:
- if self._config is not None:
- config_str = self._config.SerializeToString()
- pywrap_tensorflow.TFE_ContextOptionsSetConfig(
- opts, config_str, len(config_str), status)
- if self._device_policy is not None:
- pywrap_tensorflow.TFE_ContextOptionsSetDevicePlacementPolicy(
- opts, self._device_policy)
- if self._execution_mode == ASYNC:
- pywrap_tensorflow.TFE_ContextOptionsSetAsync(opts, True)
- self._context_handle = pywrap_tensorflow.TFE_NewContext(opts, status)
+ if self._config is not None:
+ config_str = self._config.SerializeToString()
+ pywrap_tensorflow.TFE_ContextOptionsSetConfig(opts, config_str)
+ if self._device_policy is not None:
+ pywrap_tensorflow.TFE_ContextOptionsSetDevicePlacementPolicy(
+ opts, self._device_policy)
+ if self._execution_mode == ASYNC:
+ pywrap_tensorflow.TFE_ContextOptionsSetAsync(opts, True)
+ self._context_handle = pywrap_tensorflow.TFE_NewContext(opts)
finally:
pywrap_tensorflow.TFE_DeleteContextOptions(opts)
# Store list of devices
self._context_devices = []
- with errors.raise_exception_on_not_ok_status() as status:
- device_list = pywrap_tensorflow.TFE_ContextListDevices(
- self._context_handle, status)
+ device_list = pywrap_tensorflow.TFE_ContextListDevices(
+ self._context_handle)
try:
self._num_gpus = 0
for i in range(pywrap_tensorflow.TF_DeviceListCount(device_list)):
- with errors.raise_exception_on_not_ok_status() as status:
- dev_name = pywrap_tensorflow.TF_DeviceListName(
- device_list, i, status)
+ dev_name = pywrap_tensorflow.TF_DeviceListName(device_list, i)
self._context_devices.append(pydev.canonical_name(dev_name))
- with errors.raise_exception_on_not_ok_status() as status:
- dev_type = pywrap_tensorflow.TF_DeviceListType(
- device_list, i, status)
+ dev_type = pywrap_tensorflow.TF_DeviceListType(device_list, i)
if dev_type == "GPU":
self._num_gpus += 1
@@ -287,9 +280,12 @@ class Context(object):
@tf_contextlib.contextmanager
def _mode(self, mode):
+ """A context manager to allow setting the mode to EAGER/GRAPH."""
ctx = self._eager_context
old_mode = ctx.mode
+ old_is_eager = ctx.is_eager
ctx.mode = mode
+ ctx.is_eager = mode == EAGER_MODE
if mode == EAGER_MODE:
# Entering graph mode does not provide us with sufficient information to
# record a context switch; graph-based context switches are only logged
@@ -298,13 +294,14 @@ class Context(object):
try:
yield
finally:
+ ctx.is_eager = old_is_eager
ctx.mode = old_mode
if mode == EAGER_MODE:
self.context_switches.pop()
def executing_eagerly(self):
"""Returns True if current thread has eager executing enabled."""
- return self._eager_context.mode == EAGER_MODE
+ return self._eager_context.is_eager
def scalar_cache(self):
"""Per-device cache for scalars."""
@@ -411,9 +408,7 @@ class Context(object):
if mode is None:
mode = SYNC
self._eager_context.execution_mode = mode
- with errors.raise_exception_on_not_ok_status() as status:
- pywrap_tensorflow.TFE_ContextSetAsyncForThread(self._handle,
- mode == ASYNC, status)
+ pywrap_tensorflow.TFE_ContextSetAsyncForThread(self._handle, mode == ASYNC)
@tf_contextlib.contextmanager
def execution_mode(self, mode):
@@ -427,8 +422,7 @@ class Context(object):
def async_wait(self):
"""Waits for ops dispatched in ASYNC mode to finish."""
- with errors.raise_exception_on_not_ok_status() as status:
- pywrap_tensorflow.TFE_ContextAsyncWait(self._handle, status)
+ pywrap_tensorflow.TFE_ContextAsyncWait(self._handle)
def async_clear_error(self):
"""Clears errors raised during ASYNC execution."""
@@ -448,11 +442,9 @@ class Context(object):
Args:
fn: A wrapped TF_Function (returned from TF_GraphToFunction_wrapper).
"""
- with errors.raise_exception_on_not_ok_status() as status:
- pywrap_tensorflow.TFE_ContextAddFunction(
- self._handle, # pylint: disable=protected-access
- fn,
- status)
+ pywrap_tensorflow.TFE_ContextAddFunction(
+ self._handle, # pylint: disable=protected-access
+ fn)
def add_function_def(self, fdef):
"""Add a function definition to the context.
@@ -464,12 +456,10 @@ class Context(object):
fdef: A FunctionDef protocol buffer message.
"""
fdef_string = fdef.SerializeToString()
- with errors.raise_exception_on_not_ok_status() as status:
- pywrap_tensorflow.TFE_ContextAddFunctionDef(
- self._handle, # pylint: disable=protected-access
- fdef_string,
- len(fdef_string),
- status)
+ pywrap_tensorflow.TFE_ContextAddFunctionDef(
+ self._handle, # pylint: disable=protected-access
+ fdef_string,
+ len(fdef_string))
def add_post_execution_callback(self, callback):
"""Add a post-execution callback to the context.
@@ -512,23 +502,19 @@ class Context(object):
To retrieve the accumulated metadata call context.export_run_metadata()
and to stop tracing call context.disable_run_metadata().
"""
- if not self._context_handle:
- self._initialize_handle_and_devices()
- pywrap_tensorflow.TFE_ContextEnableRunMetadata(self._context_handle)
+ pywrap_tensorflow.TFE_ContextEnableRunMetadata(self._handle)
@tf_contextlib.contextmanager
def device_policy(self, policy):
- if not self._context_handle:
- self._initialize_handle_and_devices()
- old = pywrap_tensorflow.TFE_ContextGetDevicePlacementPolicy(
- self._context_handle)
+ handle = self._handle
+ old = pywrap_tensorflow.TFE_ContextGetDevicePlacementPolicy(handle)
pywrap_tensorflow.TFE_ContextSetThreadLocalDevicePlacementPolicy(
- self._handle, policy)
+ handle, policy)
try:
yield
finally:
pywrap_tensorflow.TFE_ContextSetThreadLocalDevicePlacementPolicy(
- self._handle, old)
+ handle, old)
def disable_run_metadata(self):
"""Disables tracing of op execution via RunMetadata."""
@@ -548,9 +534,8 @@ class Context(object):
if not self._context_handle:
return None
with c_api_util.tf_buffer() as buffer_:
- with errors.raise_exception_on_not_ok_status() as status:
- pywrap_tensorflow.TFE_ContextExportRunMetadata(
- self._context_handle, buffer_, status)
+ pywrap_tensorflow.TFE_ContextExportRunMetadata(
+ self._context_handle, buffer_)
proto_data = pywrap_tensorflow.TF_GetBuffer(buffer_)
run_metadata = config_pb2.RunMetadata()
run_metadata.ParseFromString(compat.as_bytes(proto_data))
@@ -579,6 +564,10 @@ def context():
return _context
+def context_safe():
+  """Returns the current context without creating one if none exists yet."""
+  return _context
+
+
# TODO(agarwal): remove this.
def get_default_context():
"""Same as context."""
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index 343012e552..711eddcec1 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -34,7 +34,6 @@ from tensorflow.python.eager.graph_only_ops import graph_placeholder
from tensorflow.python.framework import c_api_util
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes as dtypes_module
-from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
@@ -79,14 +78,10 @@ def capture_value(tensor_map, value, dtype, name):
ranks = [len(s.dim) if not s.unknown_rank else -1 for s in shapes]
shapes = [[d.size for d in s.dim]
if not s.unknown_rank else None for s in shapes]
- with errors.raise_exception_on_not_ok_status() as status:
- pywrap_tensorflow.TF_GraphSetOutputHandleShapesAndTypes_wrapper(
- captured_value._op._graph._c_graph, # pylint: disable=protected-access
- captured_value._as_tf_output(), # pylint: disable=protected-access
- shapes,
- ranks,
- types,
- status)
+ pywrap_tensorflow.TF_GraphSetOutputHandleShapesAndTypes_wrapper(
+ captured_value._op._graph._c_graph, # pylint: disable=protected-access
+ captured_value._as_tf_output(), # pylint: disable=protected-access
+ shapes, ranks, types)
tensor_map[ops.tensor_id(value)] = (value, captured_value)
else:
@@ -275,23 +270,20 @@ class _EagerDefinedFunction(object):
inputs: the tensors in the graph to be used as inputs to the function
outputs: the tensors in the graph which will be outputs to the function
"""
- with errors.raise_exception_on_not_ok_status() as status:
- fn = pywrap_tensorflow.TF_GraphToFunction_wrapper(
- graph._c_graph, # pylint: disable=protected-access
- compat.as_str(name),
- False,
- [o._c_op for o in operations], # pylint: disable=protected-access
- [t._as_tf_output() for t in inputs], # pylint: disable=protected-access
- [t._as_tf_output() for t in outputs], # pylint: disable=protected-access
- [],
- None,
- compat.as_str(""),
- status)
+ fn = pywrap_tensorflow.TF_GraphToFunction_wrapper(
+ graph._c_graph, # pylint: disable=protected-access
+ compat.as_str(name),
+ False,
+ [o._c_op for o in operations], # pylint: disable=protected-access
+ [t._as_tf_output() for t in inputs], # pylint: disable=protected-access
+ [t._as_tf_output() for t in outputs], # pylint: disable=protected-access
+ [],
+ None,
+ compat.as_str(""))
# TODO(apassos) avoid creating a FunctionDef (especially to grab the
# signature), but in general it's also nice not to depend on it.
with c_api_util.tf_buffer() as buffer_:
- with errors.raise_exception_on_not_ok_status() as status:
- pywrap_tensorflow.TF_FunctionToFunctionDef(fn, buffer_, status)
+ pywrap_tensorflow.TF_FunctionToFunctionDef(fn, buffer_)
proto_data = pywrap_tensorflow.TF_GetBuffer(buffer_)
function_def = function_pb2.FunctionDef()
function_def.ParseFromString(compat.as_bytes(proto_data))
diff --git a/tensorflow/python/eager/imperative_grad.py b/tensorflow/python/eager/imperative_grad.py
index 837cad974a..000152855d 100644
--- a/tensorflow/python/eager/imperative_grad.py
+++ b/tensorflow/python/eager/imperative_grad.py
@@ -21,7 +21,6 @@ from __future__ import print_function
import collections
from tensorflow.python import pywrap_tensorflow
-from tensorflow.python.framework import errors
VSpace = collections.namedtuple(
@@ -60,6 +59,5 @@ def imperative_grad(
or if only non-differentiable functions of the source were used in the
computation of target.
"""
- with errors.raise_exception_on_not_ok_status() as status:
- return pywrap_tensorflow.TFE_Py_TapeGradient(
- tape._tape, vspace, target, sources, output_gradients, status) # pylint: disable=protected-access
+ return pywrap_tensorflow.TFE_Py_TapeGradient(
+ tape._tape, vspace, target, sources, output_gradients) # pylint: disable=protected-access
diff --git a/tensorflow/python/eager/python_eager_op_gen.cc b/tensorflow/python/eager/python_eager_op_gen.cc
index c2ce8efd7f..9afab0077b 100644
--- a/tensorflow/python/eager/python_eager_op_gen.cc
+++ b/tensorflow/python/eager/python_eager_op_gen.cc
@@ -117,7 +117,7 @@ class GenEagerPythonOp : public python_op_gen_internal::GenPythonOp {
const string& function_name)
: python_op_gen_internal::GenPythonOp(op_def, api_def, function_name) {
op_name_ = function_name_;
- op_name_.Consume("_");
+ str_util::ConsumePrefix(&op_name_, "_");
}
~GenEagerPythonOp() override {}
@@ -366,8 +366,8 @@ string GenEagerPythonOp::Code() {
void GenEagerPythonOp::HandleGraphMode(const string& function_setup) {
// Handle graph-mode case
strings::StrAppend(&result_,
- " _ctx = _context.context()\n"
- " if not _ctx.executing_eagerly():\n",
+ " _ctx = _context._context\n"
+ " if _ctx is None or not _ctx._eager_context.is_eager:\n",
function_setup,
" _, _, _op = _op_def_lib._apply_op_helper(\n");
AddBodyNoReturn(" ");
@@ -492,7 +492,7 @@ bool GenEagerPythonOp::GetEagerFunctionSetup(const string& indentation,
strings::StrAppend(function_setup, indentation, " ", attr_api_name,
" = ", default_value, "\n");
}
- if (attr_type.starts_with("list(")) {
+ if (str_util::StartsWith(attr_type, "list(")) {
ExpectListArg(indentation, attr_api_name, function_setup);
}
@@ -683,13 +683,14 @@ bool GenEagerPythonOp::AddEagerFallbackCode(
return true;
}
- AddDefLine(strings::StrCat(function_name_, kEagerFallbackSuffix), parameters);
+ AddDefLine(strings::StrCat(function_name_, kEagerFallbackSuffix),
+ strings::StrCat(parameters, ", ctx=None"));
strings::StrAppend(
&result_, " r\"\"\"This is the slowpath function for Eager mode.\n");
strings::StrAppend(&result_, " This is for function ", function_name_,
"\n \"\"\"\n");
- strings::StrAppend(&result_, " _ctx = _context.context()\n");
+ strings::StrAppend(&result_, " _ctx = ctx if ctx else _context.context()\n");
string function_setup;
if (!GetEagerFunctionSetup(" ", &function_setup)) {
@@ -712,9 +713,9 @@ bool GenEagerPythonOp::AddEagerFallbackCode(
}
void GenEagerPythonOp::AddEagerFastPathExecute() {
- string fastpath_execute_params =
- strings::StrCat("_ctx._handle, _ctx.device_name, \"", op_def_.name(),
- "\", ", "name, _ctx._post_execution_callbacks");
+ string fastpath_execute_params = strings::StrCat(
+ "_ctx._context_handle, _ctx._eager_context.device_name, \"",
+ op_def_.name(), "\", ", "name, _ctx._post_execution_callbacks");
string fallback_params;
for (int i = 0; i < api_def_.in_arg_size(); i++) {
@@ -755,6 +756,8 @@ void GenEagerPythonOp::AddEagerFastPathExecute() {
strings::StrAppend(&result_, " ", "return _result\n");
// Handle fallback.
+ if (!fallback_params.empty()) strings::StrAppend(&fallback_params, ", ");
+ strings::StrAppend(&fallback_params, "ctx=_ctx");
strings::StrAppend(&result_, " ", "except _core._FallbackException:\n");
strings::StrAppend(
&result_, " ", "return ", function_name_, kEagerFallbackSuffix,
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc
index 8a398f6447..d99bd0b0ff 100644
--- a/tensorflow/python/eager/pywrap_tfe_src.cc
+++ b/tensorflow/python/eager/pywrap_tfe_src.cc
@@ -1844,6 +1844,15 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject*, PyObject* args) {
op_exec_info.ctx = reinterpret_cast<TFE_Context*>(
PyCapsule_GetPointer(PyTuple_GET_ITEM(args, 0), nullptr));
+
+ if (op_exec_info.ctx == nullptr) {
+ // The context hasn't been initialized yet; it will be initialized in the slow path.
+ RaiseFallbackException(
+ "This function does not handle the case of the path where "
+ "all inputs are not already EagerTensors.");
+ return nullptr;
+ }
+
op_exec_info.device_name = GetDeviceName(PyTuple_GET_ITEM(args, 1));
op_exec_info.op_name = PyTuple_GET_ITEM(args, 2);
op_exec_info.op_def = GetOpDef(op_exec_info.op_name);
diff --git a/tensorflow/python/estimator/BUILD b/tensorflow/python/estimator/BUILD
index f93bc221cc..5d8b19223f 100644
--- a/tensorflow/python/estimator/BUILD
+++ b/tensorflow/python/estimator/BUILD
@@ -966,5 +966,6 @@ cuda_py_test(
tags = [
"multi_gpu",
"noasan", # flaky time outs
+ "notsan", # flaky
],
)
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py
index ab69a093a2..4d3eff71ad 100644
--- a/tensorflow/python/estimator/estimator.py
+++ b/tensorflow/python/estimator/estimator.py
@@ -188,7 +188,7 @@ class Estimator(object):
self._config = config
# The distribute field contains an instance of DistributionStrategy.
- self._distribution = self._config.distribute
+ self._distribution = self._config.train_distribute
# Model directory.
model_dir = compat_internal.path_to_str(model_dir)
diff --git a/tensorflow/python/estimator/run_config.py b/tensorflow/python/estimator/run_config.py
index 41415b89e9..f62c9cece6 100644
--- a/tensorflow/python/estimator/run_config.py
+++ b/tensorflow/python/estimator/run_config.py
@@ -44,7 +44,7 @@ _DEFAULT_REPLACEABLE_LIST = [
'keep_checkpoint_max',
'keep_checkpoint_every_n_hours',
'log_step_count_steps',
- 'distribute'
+ 'train_distribute'
]
_SAVE_CKPT_ERR = (
@@ -302,7 +302,7 @@ class RunConfig(object):
keep_checkpoint_max=5,
keep_checkpoint_every_n_hours=10000,
log_step_count_steps=100,
- distribute=None):
+ train_distribute=None):
"""Constructs a RunConfig.
All distributed training related properties `cluster_spec`, `is_chief`,
@@ -426,10 +426,10 @@ class RunConfig(object):
the feature.
log_step_count_steps: The frequency, in number of global steps, that the
global step/sec and the loss will be logged during training.
- distribute: an optional instance of
+ train_distribute: an optional instance of
`tf.contrib.distribute.DistributionStrategy`. If specified,
- then Estimator will distribute the user's model according to the policy
- specified by that strategy.
+ then Estimator will distribute the user's model during training,
+ according to the policy specified by that strategy.
Raises:
ValueError: If both `save_checkpoints_steps` and `save_checkpoints_secs`
@@ -466,7 +466,7 @@ class RunConfig(object):
keep_checkpoint_max=keep_checkpoint_max,
keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
log_step_count_steps=log_step_count_steps,
- distribute=distribute)
+ train_distribute=train_distribute)
self._init_distributed_setting_from_environment_var(tf_config)
@@ -678,10 +678,10 @@ class RunConfig(object):
return self._service
@property
- def distribute(self):
+ def train_distribute(self):
"""Returns the optional `tf.contrib.distribute.DistributionStrategy` object.
"""
- return self._distribute
+ return self._train_distribute
def replace(self, **kwargs):
"""Returns a new instance of `RunConfig` replacing specified properties.
@@ -697,7 +697,7 @@ class RunConfig(object):
- `keep_checkpoint_max`,
- `keep_checkpoint_every_n_hours`,
- `log_step_count_steps`,
- - `distribute`.
+ - `train_distribute`.
In addition, either `save_checkpoints_steps` or `save_checkpoints_secs`
can be set (should not be both).
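[Editor's note] After the rename, distribution is requested specifically for training; a hedged usage sketch (strategy construction elided, since this diff only references the tf.contrib.distribute.DistributionStrategy interface; model_fn is hypothetical):

    strategy = ...  # any tf.contrib.distribute.DistributionStrategy instance
    config = tf.estimator.RunConfig(train_distribute=strategy)
    estimator = tf.estimator.Estimator(model_fn=model_fn, config=config)
    # Estimator reads config.train_distribute (see estimator.py above).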
diff --git a/tensorflow/python/feature_column/BUILD b/tensorflow/python/feature_column/BUILD
index 219105d386..295d4ca094 100644
--- a/tensorflow/python/feature_column/BUILD
+++ b/tensorflow/python/feature_column/BUILD
@@ -43,6 +43,7 @@ py_library(
"//tensorflow/python:util",
"//tensorflow/python:variable_scope",
"//tensorflow/python:variables",
+ "//tensorflow/python/keras",
"@six_archive//:six",
],
)
diff --git a/tensorflow/python/feature_column/feature_column.py b/tensorflow/python/feature_column/feature_column.py
index 92c6ff21c4..3a315e5c2e 100644
--- a/tensorflow/python/feature_column/feature_column.py
+++ b/tensorflow/python/feature_column/feature_column.py
@@ -139,6 +139,8 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
from tensorflow.python.framework import tensor_shape
+from tensorflow.python.keras._impl.keras.engine import training
+from tensorflow.python.layers import base
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
@@ -460,6 +462,154 @@ def linear_model(features,
return predictions
+class _FCLinearWrapper(base.Layer):
+ """Wraps a _FeatureColumn in a layer for use in a linear model.
+
+ See `linear_model` above.
+ """
+
+ def __init__(self,
+ feature_column,
+ units=1,
+ sparse_combiner='sum',
+ weight_collections=None,
+ trainable=True,
+ name=None,
+ **kwargs):
+ super(_FCLinearWrapper, self).__init__(
+ trainable=trainable, name=name, **kwargs)
+ self._feature_column = feature_column
+ self._units = units
+ self._sparse_combiner = sparse_combiner
+ self._weight_collections = weight_collections
+ self._state = {}
+
+ def build(self, _):
+ self._state = self._feature_column._create_state( # pylint: disable=protected-access
+ self._weight_collections, self.add_variable)
+
+ if isinstance(self._feature_column, _CategoricalColumn):
+ weight = self.add_variable(
+ name='weights',
+ shape=(self._feature_column._num_buckets, self._units), # pylint: disable=protected-access
+ initializer=init_ops.zeros_initializer(),
+ trainable=self.trainable)
+ else:
+ num_elements = self._feature_column._variable_shape.num_elements() # pylint: disable=protected-access
+ weight = self.add_variable(
+ name='weights',
+ shape=[num_elements, self._units],
+ initializer=init_ops.zeros_initializer(),
+ trainable=self.trainable)
+ ops.add_to_collections(self._weight_collections, weight)
+ self._weight_var = weight
+ self.built = True
+
+ def call(self, builder):
+ weighted_sum = _create_weighted_sum(
+ column=self._feature_column,
+ builder=builder,
+ units=self._units,
+ sparse_combiner=self._sparse_combiner,
+ weight_collections=self._weight_collections,
+ trainable=self.trainable,
+ weight_var=self._weight_var,
+ state=self._state)
+ return weighted_sum
+
+
+class _BiasLayer(base.Layer):
+ """A layer for the bias term.
+ """
+
+ def __init__(self,
+ units=1,
+ trainable=True,
+ weight_collections=None,
+ name=None,
+ **kwargs):
+ super(_BiasLayer, self).__init__(trainable=trainable, name=name, **kwargs)
+ self._units = units
+ self._weight_collections = weight_collections
+
+ def build(self, _):
+ self._bias_variable = self.add_variable(
+ 'bias_weights',
+ shape=[self._units],
+ initializer=init_ops.zeros_initializer(),
+ trainable=self.trainable)
+ ops.add_to_collections(self._weight_collections, self._bias_variable)
+ self.built = True
+
+ def call(self, _):
+ return self._bias_variable
+
+
+class _LinearModel(training.Model):
+ """Creates a linear model using feature columns.
+ """
+
+ def __init__(self,
+ feature_columns,
+ units=1,
+ sparse_combiner='sum',
+ weight_collections=None,
+ trainable=True,
+ name=None,
+ **kwargs):
+ super(_LinearModel, self).__init__(name=name, **kwargs)
+ self._feature_columns = _clean_feature_columns(feature_columns)
+ self._weight_collections = list(weight_collections or [])
+ if ops.GraphKeys.MODEL_VARIABLES not in self._weight_collections:
+ self._weight_collections.append(ops.GraphKeys.MODEL_VARIABLES)
+
+ column_layers = {}
+ for column in sorted(self._feature_columns, key=lambda x: x.name):
+ with variable_scope.variable_scope(
+ None, default_name=column._var_scope_name) as vs: # pylint: disable=protected-access
+ column_name = vs.name
+ column_layer = _FCLinearWrapper(column, units, sparse_combiner,
+ self._weight_collections, trainable,
+ column_name, **kwargs)
+ column_layers[column_name] = column_layer
+ self._column_layers = self._add_layers(column_layers)
+ self._bias_layer = _BiasLayer(
+ units=units,
+ trainable=trainable,
+ weight_collections=self._weight_collections,
+ name='bias_layer',
+ **kwargs)
+
+ def call(self, features):
+ for column in self._feature_columns:
+ if not isinstance(column, (_DenseColumn, _CategoricalColumn)):
+ raise ValueError(
+ 'Items of feature_columns must be either a '
+ '_DenseColumn or _CategoricalColumn. Given: {}'.format(column))
+ weighted_sums = []
+ ordered_columns = []
+ builder = _LazyBuilder(features)
+ for layer in sorted(self._column_layers.values(), key=lambda x: x.name):
+ ordered_columns.append(layer._feature_column) # pylint: disable=protected-access
+ weighted_sum = layer(builder)
+ weighted_sums.append(weighted_sum)
+
+ _verify_static_batch_size_equality(weighted_sums, ordered_columns)
+ predictions_no_bias = math_ops.add_n(
+ weighted_sums, name='weighted_sum_no_bias')
+ predictions = nn_ops.bias_add(
+ predictions_no_bias, self._bias_layer(builder), name='weighted_sum') # pylint: disable=not-callable
+ return predictions
+
+ def _add_layers(self, layers):
+ # "Magic" required for keras.Model classes to track all the variables in
+ # a list of layers.Layer objects.
+ # TODO(ashankar): Figure out API so user code doesn't have to do this.
+ for name, layer in layers.items():
+ setattr(self, 'layer-%s' % name, layer)
+ return layers
+
+
def _transform_features(features, feature_columns):
"""Returns transformed features based on features columns passed in.
@@ -1643,6 +1793,19 @@ class _FeatureColumn(object):
"""
pass
+ def _create_state(self, weight_collections=None, creator=None):
+ """Returns an object that captures the state of the column.
+
+ Args:
+ weight_collections: Collections to add the state variables to.
+ creator: Variable creator method to call, if provided.
+
+ Returns:
+ An object that encapsulates the state of the column. May be None.
+ """
+ del weight_collections, creator # Unused
+ return None
+
class _DenseColumn(_FeatureColumn):
"""Represents a column which can be represented as `Tensor`.
@@ -1662,7 +1825,11 @@ class _DenseColumn(_FeatureColumn):
pass
@abc.abstractmethod
- def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
+ def _get_dense_tensor(self,
+ inputs,
+ weight_collections=None,
+ trainable=None,
+ state=None):
"""Returns a `Tensor`.
The output of this function will be used by model-builder-functions. For
@@ -1680,6 +1847,9 @@ class _DenseColumn(_FeatureColumn):
will be created) are added.
trainable: If `True` also add variables to the graph collection
`GraphKeys.TRAINABLE_VARIABLES` (see @{tf.Variable}).
+ state: An object encapsulating the state of the column. Columns that
+ create state using the _create_state method have that state passed
+ to this method.
Returns:
`Tensor` of shape [batch_size] + `_variable_shape`.
@@ -1687,13 +1857,14 @@ class _DenseColumn(_FeatureColumn):
pass
-def _create_weighted_sum(
- column,
- builder,
- units,
- sparse_combiner,
- weight_collections,
- trainable):
+def _create_weighted_sum(column,
+ builder,
+ units,
+ sparse_combiner,
+ weight_collections,
+ trainable,
+ weight_var=None,
+ state=None):
"""Creates a weighted sum for a dense or sparse column for linear_model."""
if isinstance(column, _CategoricalColumn):
return _create_categorical_column_weighted_sum(
@@ -1702,32 +1873,50 @@ def _create_weighted_sum(
units=units,
sparse_combiner=sparse_combiner,
weight_collections=weight_collections,
- trainable=trainable)
+ trainable=trainable,
+ weight_var=weight_var)
else:
return _create_dense_column_weighted_sum(
column=column,
builder=builder,
units=units,
weight_collections=weight_collections,
- trainable=trainable)
+ trainable=trainable,
+ weight_var=weight_var,
+ state=state)
-def _create_dense_column_weighted_sum(
- column, builder, units, weight_collections, trainable):
+def _create_dense_column_weighted_sum(column,
+ builder,
+ units,
+ weight_collections,
+ trainable,
+ weight_var=None,
+ state=None):
"""Create a weighted sum of a dense column for linear_model."""
- tensor = column._get_dense_tensor( # pylint: disable=protected-access
- builder,
- weight_collections=weight_collections,
- trainable=trainable)
+ if state is not None:
+ tensor = column._get_dense_tensor( # pylint: disable=protected-access
+ builder,
+ weight_collections=weight_collections,
+ trainable=trainable,
+ state=state)
+ else:
+ tensor = column._get_dense_tensor( # pylint: disable=protected-access
+ builder,
+ weight_collections=weight_collections,
+ trainable=trainable)
num_elements = column._variable_shape.num_elements() # pylint: disable=protected-access
batch_size = array_ops.shape(tensor)[0]
tensor = array_ops.reshape(tensor, shape=(batch_size, num_elements))
- weight = variable_scope.get_variable(
- name='weights',
- shape=[num_elements, units],
- initializer=init_ops.zeros_initializer(),
- trainable=trainable,
- collections=weight_collections)
+ if weight_var is not None:
+ weight = weight_var
+ else:
+ weight = variable_scope.get_variable(
+ name='weights',
+ shape=[num_elements, units],
+ initializer=init_ops.zeros_initializer(),
+ trainable=trainable,
+ collections=weight_collections)
return math_ops.matmul(tensor, weight, name='weighted_sum')
@@ -1777,8 +1966,13 @@ class _CategoricalColumn(_FeatureColumn):
pass
-def _create_categorical_column_weighted_sum(
- column, builder, units, sparse_combiner, weight_collections, trainable):
+def _create_categorical_column_weighted_sum(column,
+ builder,
+ units,
+ sparse_combiner,
+ weight_collections,
+ trainable,
+ weight_var=None):
"""Create a weighted sum of a categorical column for linear_model."""
sparse_tensors = column._get_sparse_tensors( # pylint: disable=protected-access
builder,
@@ -1792,12 +1986,15 @@ def _create_categorical_column_weighted_sum(
weight_tensor = sparse_ops.sparse_reshape(
weight_tensor, [array_ops.shape(weight_tensor)[0], -1])
- weight = variable_scope.get_variable(
- name='weights',
- shape=(column._num_buckets, units), # pylint: disable=protected-access
- initializer=init_ops.zeros_initializer(),
- trainable=trainable,
- collections=weight_collections)
+ if weight_var is not None:
+ weight = weight_var
+ else:
+ weight = variable_scope.get_variable(
+ name='weights',
+ shape=(column._num_buckets, units), # pylint: disable=protected-access
+ initializer=init_ops.zeros_initializer(),
+ trainable=trainable,
+ collections=weight_collections)
return _safe_embedding_lookup_sparse(
weight,
id_tensor,
@@ -2195,8 +2392,33 @@ class _EmbeddingColumn(
self._shape = tensor_shape.vector(self.dimension)
return self._shape
- def _get_dense_tensor_internal(
- self, inputs, weight_collections=None, trainable=None):
+ def _create_state(self, weight_collections=None, creator=None):
+ variables_map = {}
+ embedding_shape = (self.categorical_column._num_buckets, self.dimension) # pylint: disable=protected-access
+ if creator is not None:
+ embedding_weights = creator(
+ name='embedding_weights',
+ shape=embedding_shape,
+ dtype=dtypes.float32,
+ initializer=self.initializer,
+ trainable=self.trainable)
+ ops.add_to_collections(weight_collections, embedding_weights)
+ else:
+ embedding_weights = variable_scope.get_variable(
+ name='embedding_weights',
+ shape=embedding_shape,
+ dtype=dtypes.float32,
+ initializer=self.initializer,
+ trainable=self.trainable,
+ collections=weight_collections)
+ variables_map['embedding_weights'] = embedding_weights
+ return variables_map
+
+ def _get_dense_tensor_internal(self,
+ inputs,
+ weight_collections=None,
+ trainable=None,
+ state=None):
"""Private method that follows the signature of _get_dense_tensor."""
# Get sparse IDs and weights.
sparse_tensors = self.categorical_column._get_sparse_tensors( # pylint: disable=protected-access
@@ -2204,14 +2426,10 @@ class _EmbeddingColumn(
sparse_ids = sparse_tensors.id_tensor
sparse_weights = sparse_tensors.weight_tensor
- embedding_shape = (self.categorical_column._num_buckets, self.dimension) # pylint: disable=protected-access
- embedding_weights = variable_scope.get_variable(
- name='embedding_weights',
- shape=embedding_shape,
- dtype=dtypes.float32,
- initializer=self.initializer,
- trainable=self.trainable and trainable,
- collections=weight_collections)
+ if state is None:
+ state = self._create_state(weight_collections)
+ embedding_weights = state['embedding_weights']
+
if self.ckpt_to_load_from is not None:
to_restore = embedding_weights
if isinstance(to_restore, variables.PartitionedVariable):
@@ -2229,7 +2447,11 @@ class _EmbeddingColumn(
name='%s_weights' % self.name,
max_norm=self.max_norm)
- def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
+ def _get_dense_tensor(self,
+ inputs,
+ weight_collections=None,
+ trainable=None,
+ state=None):
if isinstance(self.categorical_column, _SequenceCategoricalColumn):
raise ValueError(
'In embedding_column: {}. '
@@ -2242,8 +2464,10 @@ class _EmbeddingColumn(
self.name, type(self.categorical_column),
self.categorical_column))
return self._get_dense_tensor_internal(
- inputs=inputs, weight_collections=weight_collections,
- trainable=trainable)
+ inputs=inputs,
+ weight_collections=weight_collections,
+ trainable=trainable,
+ state=state)
def _get_sequence_dense_tensor(
self, inputs, weight_collections=None, trainable=None):
@@ -2299,7 +2523,39 @@ class _SharedEmbeddingColumn(
self._shape = tensor_shape.vector(self.dimension)
return self._shape
- def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
+ def _create_state(self, weight_collections=None, creator=None):
+ variables_map = {}
+ shared_embedding_collection = ops.get_collection(
+ self.shared_embedding_collection_name)
+ if not shared_embedding_collection:
+ embedding_shape = (self.categorical_column._num_buckets, self.dimension) # pylint: disable=protected-access
+ if creator is not None:
+ embedding_weights = creator(
+ name='embedding_weights',
+ shape=embedding_shape,
+ dtype=dtypes.float32,
+ initializer=self.initializer,
+ trainable=self.trainable)
+ ops.add_to_collections(weight_collections, embedding_weights)
+ else:
+ embedding_weights = variable_scope.get_variable(
+ name='embedding_weights',
+ shape=embedding_shape,
+ dtype=dtypes.float32,
+ initializer=self.initializer,
+ trainable=self.trainable,
+ collections=weight_collections)
+ ops.add_to_collection(self.shared_embedding_collection_name,
+ embedding_weights)
+ variables_map['embedding_weights'] = embedding_weights
+
+ return variables_map
+
+ def _get_dense_tensor(self,
+ inputs,
+ weight_collections=None,
+ trainable=None,
+ state=None):
# This method is called from a variable_scope with name _var_scope_name,
# which is shared among all shared embeddings. Open a name_scope here, so
# that the ops for different columns have distinct names.
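
Both embedding column classes now split variable creation out of the lookup path: `_create_state` builds the `embedding_weights` variable (and, for `_SharedEmbeddingColumn`, registers it in the graph collection named by `shared_embedding_collection_name` so sibling columns can find it) and returns it in a dict, while `_get_dense_tensor` accepts that dict through the new `state` argument. A hedged sketch of the intended call sequence, where `column` stands in for an `_EmbeddingColumn` and `builder` for a `_LazyBuilder`:

    # Create the column's variables up front; returns a name -> variable map.
    state = column._create_state(weight_collections=['my-vars'])
    embedding_var = state['embedding_weights']

    # Hand the pre-built variables back into the lookup. With state=None the
    # column creates 'embedding_weights' itself, preserving old behavior.
    dense = column._get_dense_tensor(builder, state=state)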
diff --git a/tensorflow/python/feature_column/feature_column_test.py b/tensorflow/python/feature_column/feature_column_test.py
index 6f366e7722..07588af37e 100644
--- a/tensorflow/python/feature_column/feature_column_test.py
+++ b/tensorflow/python/feature_column/feature_column_test.py
@@ -34,6 +34,7 @@ from tensorflow.python.feature_column.feature_column import _CategoricalColumn
from tensorflow.python.feature_column.feature_column import _DenseColumn
from tensorflow.python.feature_column.feature_column import _FeatureColumn
from tensorflow.python.feature_column.feature_column import _LazyBuilder
+from tensorflow.python.feature_column.feature_column import _LinearModel
from tensorflow.python.feature_column.feature_column import _transform_features
from tensorflow.python.feature_column.feature_column import InputLayer
from tensorflow.python.framework import constant_op
@@ -339,6 +340,20 @@ class NumericColumnTest(test.TestCase):
sess.run(price_var.assign([[10.]]))
self.assertAllClose([[10.], [50.]], predictions.eval())
+ def test_keras_linear_model(self):
+ price = fc.numeric_column('price')
+ with ops.Graph().as_default():
+ features = {'price': [[1.], [5.]]}
+ predictions = get_keras_linear_model_predictions(features, [price])
+ bias = get_keras_linear_model_bias()
+ price_var = get_linear_model_column_var(price)
+ with _initialized_session() as sess:
+ self.assertAllClose([0.], bias.eval())
+ self.assertAllClose([[0.]], price_var.eval())
+ self.assertAllClose([[0.], [0.]], predictions.eval())
+ sess.run(price_var.assign([[10.]]))
+ self.assertAllClose([[10.], [50.]], predictions.eval())
+
class BucketizedColumnTest(test.TestCase):
@@ -561,6 +576,62 @@ class BucketizedColumnTest(test.TestCase):
sess.run(bias.assign([1.]))
self.assertAllClose([[81.], [141.]], predictions.eval())
+ def test_keras_linear_model_one_input_value(self):
+ """Tests _LinearModel for input with shape=[1]."""
+ price = fc.numeric_column('price', shape=[1])
+ bucketized_price = fc.bucketized_column(price, boundaries=[0, 2, 4, 6])
+ with ops.Graph().as_default():
+ features = {'price': [[-1.], [1.], [5.], [6.]]}
+ predictions = get_keras_linear_model_predictions(features,
+ [bucketized_price])
+ bias = get_keras_linear_model_bias()
+ bucketized_price_var = get_linear_model_column_var(bucketized_price)
+ with _initialized_session() as sess:
+ self.assertAllClose([0.], bias.eval())
+ # One weight variable per bucket, all initialized to zero.
+ self.assertAllClose([[0.], [0.], [0.], [0.], [0.]],
+ bucketized_price_var.eval())
+ self.assertAllClose([[0.], [0.], [0.], [0.]], predictions.eval())
+ sess.run(
+ bucketized_price_var.assign([[10.], [20.], [30.], [40.], [50.]]))
+ # price -1. is in the 0th bucket, whose weight is 10.
+ # price 1. is in the 1st bucket, whose weight is 20.
+ # price 5. is in the 3rd bucket, whose weight is 40.
+ # price 6. is in the 4th bucket, whose weight is 50.
+ self.assertAllClose([[10.], [20.], [40.], [50.]], predictions.eval())
+ sess.run(bias.assign([1.]))
+ self.assertAllClose([[11.], [21.], [41.], [51.]], predictions.eval())
+
+ def test_keras_linear_model_two_input_values(self):
+ """Tests _LinearModel for input with shape=[2]."""
+ price = fc.numeric_column('price', shape=[2])
+ bucketized_price = fc.bucketized_column(price, boundaries=[0, 2, 4, 6])
+ with ops.Graph().as_default():
+ features = {'price': [[-1., 1.], [5., 6.]]}
+ predictions = get_keras_linear_model_predictions(features,
+ [bucketized_price])
+ bias = get_keras_linear_model_bias()
+ bucketized_price_var = get_linear_model_column_var(bucketized_price)
+ with _initialized_session() as sess:
+ self.assertAllClose([0.], bias.eval())
+ # One weight per bucket per input column, all initialized to zero.
+ self.assertAllClose(
+ [[0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.]],
+ bucketized_price_var.eval())
+ self.assertAllClose([[0.], [0.]], predictions.eval())
+ sess.run(
+ bucketized_price_var.assign([[10.], [20.], [30.], [40.], [50.],
+ [60.], [70.], [80.], [90.], [100.]]))
+ # 1st example:
+ # price -1. is in the 0th bucket, whose weight is 10.
+ # price 1. is in the 6th bucket, whose weight is 70.
+ # 2nd example:
+ # price 5. is in the 3rd bucket, whose weight is 40.
+ # price 6. is in the 9th bucket, whose weight is 100.
+ self.assertAllClose([[80.], [140.]], predictions.eval())
+ sess.run(bias.assign([1.]))
+ self.assertAllClose([[81.], [141.]], predictions.eval())
+
class HashedCategoricalColumnTest(test.TestCase):
@@ -767,6 +838,28 @@ class HashedCategoricalColumnTest(test.TestCase):
# 'skywalker' -> 2, 'omar' -> 2: wire_var[2] + wire_var[2] = 3+3 = 6
self.assertAllClose(((4.,), (6.,)), predictions.eval())
+ def test_keras_linear_model(self):
+ wire_column = fc.categorical_column_with_hash_bucket('wire', 4)
+ self.assertEqual(4, wire_column._num_buckets)
+ with ops.Graph().as_default():
+ predictions = get_keras_linear_model_predictions({
+ wire_column.name:
+ sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=('marlo', 'skywalker', 'omar'),
+ dense_shape=(2, 2))
+ }, (wire_column,))
+ bias = get_keras_linear_model_bias()
+ wire_var = get_linear_model_column_var(wire_column)
+ with _initialized_session():
+ self.assertAllClose((0.,), bias.eval())
+ self.assertAllClose(((0.,), (0.,), (0.,), (0.,)), wire_var.eval())
+ self.assertAllClose(((0.,), (0.,)), predictions.eval())
+ wire_var.assign(((1.,), (2.,), (3.,), (4.,))).eval()
+ # 'marlo' -> 3: wire_var[3] = 4
+ # 'skywalker' -> 2, 'omar' -> 2: wire_var[2] + wire_var[2] = 3+3 = 6
+ self.assertAllClose(((4.,), (6.,)), predictions.eval())
+
class CrossedColumnTest(test.TestCase):
@@ -1060,6 +1153,96 @@ class CrossedColumnTest(test.TestCase):
dense_shape=(2, 2)),
}, (crossed,))
+ def test_keras_linear_model(self):
+ """Tests _LinearModel.
+
+    Uses data from test_get_sparse_tensors_simple.
+ """
+ a = fc.numeric_column('a', dtype=dtypes.int32, shape=(2,))
+ b = fc.bucketized_column(a, boundaries=(0, 1))
+ crossed = fc.crossed_column([b, 'c'], hash_bucket_size=5, hash_key=5)
+ with ops.Graph().as_default():
+ predictions = get_keras_linear_model_predictions({
+ 'a':
+ constant_op.constant(((-1., .5), (.5, 1.))),
+ 'c':
+ sparse_tensor.SparseTensor(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=['cA', 'cB', 'cC'],
+ dense_shape=(2, 2)),
+ }, (crossed,))
+ bias = get_keras_linear_model_bias()
+ crossed_var = get_linear_model_column_var(crossed)
+ with _initialized_session() as sess:
+ self.assertAllClose((0.,), bias.eval())
+ self.assertAllClose(((0.,), (0.,), (0.,), (0.,), (0.,)),
+ crossed_var.eval())
+ self.assertAllClose(((0.,), (0.,)), predictions.eval())
+ sess.run(crossed_var.assign(((1.,), (2.,), (3.,), (4.,), (5.,))))
+ # Expected ids after cross = (1, 0, 1, 3, 4, 2)
+ self.assertAllClose(((3.,), (14.,)), predictions.eval())
+ sess.run(bias.assign((.1,)))
+ self.assertAllClose(((3.1,), (14.1,)), predictions.eval())
+
+ def test_keras_linear_model_with_weights(self):
+
+ class _TestColumnWithWeights(_CategoricalColumn):
+ """Produces sparse IDs and sparse weights."""
+
+ @property
+ def name(self):
+ return 'test_column'
+
+ @property
+ def _parse_example_spec(self):
+ return {
+ self.name:
+ parsing_ops.VarLenFeature(dtypes.int32),
+ '{}_weights'.format(self.name):
+ parsing_ops.VarLenFeature(dtypes.float32),
+ }
+
+ @property
+ def _num_buckets(self):
+ return 5
+
+ def _transform_feature(self, inputs):
+ return (inputs.get(self.name),
+ inputs.get('{}_weights'.format(self.name)))
+
+ def _get_sparse_tensors(self,
+ inputs,
+ weight_collections=None,
+ trainable=None):
+ """Populates both id_tensor and weight_tensor."""
+ ids_and_weights = inputs.get(self)
+ return _CategoricalColumn.IdWeightPair(
+ id_tensor=ids_and_weights[0], weight_tensor=ids_and_weights[1])
+
+ t = _TestColumnWithWeights()
+ crossed = fc.crossed_column([t, 'c'], hash_bucket_size=5, hash_key=5)
+ with ops.Graph().as_default():
+ with self.assertRaisesRegexp(
+ ValueError,
+ 'crossed_column does not support weight_tensor.*{}'.format(t.name)):
+ get_keras_linear_model_predictions({
+ t.name:
+ sparse_tensor.SparseTensor(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=[0, 1, 2],
+ dense_shape=(2, 2)),
+ '{}_weights'.format(t.name):
+ sparse_tensor.SparseTensor(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=[1., 10., 2.],
+ dense_shape=(2, 2)),
+ 'c':
+ sparse_tensor.SparseTensor(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=['cA', 'cB', 'cC'],
+ dense_shape=(2, 2)),
+ }, (crossed,))
+
def get_linear_model_bias():
with variable_scope.variable_scope('linear_model', reuse=True):
@@ -1071,6 +1254,28 @@ def get_linear_model_column_var(column):
'linear_model/' + column.name)[0]
+def get_keras_linear_model_bias():
+ with variable_scope.variable_scope('linear_model', reuse=True):
+ with variable_scope.variable_scope('bias_layer', reuse=True):
+ return variable_scope.get_variable('bias_weights')
+
+
+def get_keras_linear_model_predictions(features,
+ feature_columns,
+ units=1,
+ sparse_combiner='sum',
+ weight_collections=None,
+ trainable=True):
+ keras_linear_model = _LinearModel(
+ feature_columns,
+ units,
+ sparse_combiner,
+ weight_collections,
+ trainable,
+ name='linear_model')
+ return keras_linear_model(features) # pylint: disable=not-callable
+
+
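
These helpers mirror `get_linear_model_bias`/`fc.linear_model` but route through the Keras-style `_LinearModel` layer: construct it once with the column configuration, then call it on a feature dict. A short illustrative sketch of that two-step use (the feature values are made up):

    price = fc.numeric_column('price')
    model = _LinearModel(
        [price],  # feature_columns
        1,  # units
        'sum',  # sparse_combiner
        None,  # weight_collections
        True,  # trainable
        name='linear_model')
    predictions = model({'price': [[1.], [5.]]})  # Tensor of shape [2, 1]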
@test_util.with_c_api
class LinearModelTest(test.TestCase):
@@ -1698,6 +1903,629 @@ class LinearModelTest(test.TestCase):
sess.run(net, feed_dict={features['price']: np.array(1)})
+@test_util.with_c_api
+class _LinearModelTest(test.TestCase):
+
+ def test_raises_if_empty_feature_columns(self):
+ with self.assertRaisesRegexp(ValueError,
+ 'feature_columns must not be empty'):
+ get_keras_linear_model_predictions(features={}, feature_columns=[])
+
+ def test_should_be_feature_column(self):
+ with self.assertRaisesRegexp(ValueError, 'must be a _FeatureColumn'):
+ get_keras_linear_model_predictions(
+ features={'a': [[0]]}, feature_columns='NotSupported')
+
+ def test_should_be_dense_or_categorical_column(self):
+
+ class NotSupportedColumn(_FeatureColumn):
+
+ @property
+ def name(self):
+ return 'NotSupportedColumn'
+
+ def _transform_feature(self, cache):
+ pass
+
+ @property
+ def _parse_example_spec(self):
+ pass
+
+ with self.assertRaisesRegexp(
+ ValueError, 'must be either a _DenseColumn or _CategoricalColumn'):
+ get_keras_linear_model_predictions(
+ features={'a': [[0]]}, feature_columns=[NotSupportedColumn()])
+
+ def test_does_not_support_dict_columns(self):
+ with self.assertRaisesRegexp(
+ ValueError, 'Expected feature_columns to be iterable, found dict.'):
+ fc.linear_model(
+ features={'a': [[0]]}, feature_columns={'a': fc.numeric_column('a')})
+
+ def test_raises_if_duplicate_name(self):
+ with self.assertRaisesRegexp(
+ ValueError, 'Duplicate feature column name found for columns'):
+ get_keras_linear_model_predictions(
+ features={'a': [[0]]},
+ feature_columns=[fc.numeric_column('a'),
+ fc.numeric_column('a')])
+
+ def test_dense_bias(self):
+ price = fc.numeric_column('price')
+ with ops.Graph().as_default():
+ features = {'price': [[1.], [5.]]}
+ predictions = get_keras_linear_model_predictions(features, [price])
+ bias = get_keras_linear_model_bias()
+ price_var = get_linear_model_column_var(price)
+ with _initialized_session() as sess:
+ self.assertAllClose([0.], bias.eval())
+ sess.run(price_var.assign([[10.]]))
+ sess.run(bias.assign([5.]))
+ self.assertAllClose([[15.], [55.]], predictions.eval())
+
+ def test_sparse_bias(self):
+ wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
+ with ops.Graph().as_default():
+ wire_tensor = sparse_tensor.SparseTensor(
+ values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3]
+ indices=[[0, 0], [1, 0], [1, 1]],
+ dense_shape=[2, 2])
+ features = {'wire_cast': wire_tensor}
+ predictions = get_keras_linear_model_predictions(features, [wire_cast])
+ bias = get_keras_linear_model_bias()
+ wire_cast_var = get_linear_model_column_var(wire_cast)
+ with _initialized_session() as sess:
+ self.assertAllClose([0.], bias.eval())
+ self.assertAllClose([[0.], [0.], [0.], [0.]], wire_cast_var.eval())
+ sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]]))
+ sess.run(bias.assign([5.]))
+ self.assertAllClose([[1005.], [10015.]], predictions.eval())
+
+ def test_dense_and_sparse_bias(self):
+ wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
+ price = fc.numeric_column('price')
+ with ops.Graph().as_default():
+ wire_tensor = sparse_tensor.SparseTensor(
+ values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3]
+ indices=[[0, 0], [1, 0], [1, 1]],
+ dense_shape=[2, 2])
+ features = {'wire_cast': wire_tensor, 'price': [[1.], [5.]]}
+ predictions = get_keras_linear_model_predictions(features,
+ [wire_cast, price])
+ bias = get_keras_linear_model_bias()
+ wire_cast_var = get_linear_model_column_var(wire_cast)
+ price_var = get_linear_model_column_var(price)
+ with _initialized_session() as sess:
+ sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]]))
+ sess.run(bias.assign([5.]))
+ sess.run(price_var.assign([[10.]]))
+ self.assertAllClose([[1015.], [10065.]], predictions.eval())
+
+ def test_dense_and_sparse_column(self):
+ """When the column is both dense and sparse, uses sparse tensors."""
+
+ class _DenseAndSparseColumn(_DenseColumn, _CategoricalColumn):
+
+ @property
+ def name(self):
+ return 'dense_and_sparse_column'
+
+ @property
+ def _parse_example_spec(self):
+ return {self.name: parsing_ops.VarLenFeature(self.dtype)}
+
+ def _transform_feature(self, inputs):
+ return inputs.get(self.name)
+
+ @property
+ def _variable_shape(self):
+ raise ValueError('Should not use this method.')
+
+ def _get_dense_tensor(self,
+ inputs,
+ weight_collections=None,
+ trainable=None):
+ raise ValueError('Should not use this method.')
+
+ @property
+ def _num_buckets(self):
+ return 4
+
+ def _get_sparse_tensors(self,
+ inputs,
+ weight_collections=None,
+ trainable=None):
+ sp_tensor = sparse_tensor.SparseTensor(
+ indices=[[0, 0], [1, 0], [1, 1]],
+ values=[2, 0, 3],
+ dense_shape=[2, 2])
+ return _CategoricalColumn.IdWeightPair(sp_tensor, None)
+
+ dense_and_sparse_column = _DenseAndSparseColumn()
+ with ops.Graph().as_default():
+ sp_tensor = sparse_tensor.SparseTensor(
+ values=['omar', 'stringer', 'marlo'],
+ indices=[[0, 0], [1, 0], [1, 1]],
+ dense_shape=[2, 2])
+ features = {dense_and_sparse_column.name: sp_tensor}
+ predictions = get_keras_linear_model_predictions(
+ features, [dense_and_sparse_column])
+ bias = get_keras_linear_model_bias()
+ dense_and_sparse_column_var = get_linear_model_column_var(
+ dense_and_sparse_column)
+ with _initialized_session() as sess:
+ sess.run(
+ dense_and_sparse_column_var.assign([[10.], [100.], [1000.],
+ [10000.]]))
+ sess.run(bias.assign([5.]))
+ self.assertAllClose([[1005.], [10015.]], predictions.eval())
+
+ def test_dense_multi_output(self):
+ price = fc.numeric_column('price')
+ with ops.Graph().as_default():
+ features = {'price': [[1.], [5.]]}
+ predictions = get_keras_linear_model_predictions(
+ features, [price], units=3)
+ bias = get_keras_linear_model_bias()
+ price_var = get_linear_model_column_var(price)
+ with _initialized_session() as sess:
+ self.assertAllClose(np.zeros((3,)), bias.eval())
+ self.assertAllClose(np.zeros((1, 3)), price_var.eval())
+ sess.run(price_var.assign([[10., 100., 1000.]]))
+ sess.run(bias.assign([5., 6., 7.]))
+ self.assertAllClose([[15., 106., 1007.], [55., 506., 5007.]],
+ predictions.eval())
+
+ def test_sparse_multi_output(self):
+ wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
+ with ops.Graph().as_default():
+ wire_tensor = sparse_tensor.SparseTensor(
+ values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3]
+ indices=[[0, 0], [1, 0], [1, 1]],
+ dense_shape=[2, 2])
+ features = {'wire_cast': wire_tensor}
+ predictions = get_keras_linear_model_predictions(
+ features, [wire_cast], units=3)
+ bias = get_keras_linear_model_bias()
+ wire_cast_var = get_linear_model_column_var(wire_cast)
+ with _initialized_session() as sess:
+ self.assertAllClose(np.zeros((3,)), bias.eval())
+ self.assertAllClose(np.zeros((4, 3)), wire_cast_var.eval())
+ sess.run(
+ wire_cast_var.assign([[10., 11., 12.], [100., 110., 120.],
+ [1000., 1100.,
+ 1200.], [10000., 11000., 12000.]]))
+ sess.run(bias.assign([5., 6., 7.]))
+ self.assertAllClose([[1005., 1106., 1207.], [10015., 11017., 12019.]],
+ predictions.eval())
+
+ def test_dense_multi_dimension(self):
+ price = fc.numeric_column('price', shape=2)
+ with ops.Graph().as_default():
+ features = {'price': [[1., 2.], [5., 6.]]}
+ predictions = get_keras_linear_model_predictions(features, [price])
+ price_var = get_linear_model_column_var(price)
+ with _initialized_session() as sess:
+ self.assertAllClose([[0.], [0.]], price_var.eval())
+ sess.run(price_var.assign([[10.], [100.]]))
+ self.assertAllClose([[210.], [650.]], predictions.eval())
+
+ def test_sparse_multi_rank(self):
+ wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
+ with ops.Graph().as_default():
+ wire_tensor = array_ops.sparse_placeholder(dtypes.string)
+ wire_value = sparse_tensor.SparseTensorValue(
+ values=['omar', 'stringer', 'marlo', 'omar'], # hashed = [2, 0, 3, 2]
+ indices=[[0, 0, 0], [0, 1, 0], [1, 0, 0], [1, 0, 1]],
+ dense_shape=[2, 2, 2])
+ features = {'wire_cast': wire_tensor}
+ predictions = get_keras_linear_model_predictions(features, [wire_cast])
+ wire_cast_var = get_linear_model_column_var(wire_cast)
+ with _initialized_session() as sess:
+ self.assertAllClose(np.zeros((4, 1)), wire_cast_var.eval())
+ self.assertAllClose(
+ np.zeros((2, 1)),
+ predictions.eval(feed_dict={wire_tensor: wire_value}))
+ sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]]))
+ self.assertAllClose(
+ [[1010.], [11000.]],
+ predictions.eval(feed_dict={wire_tensor: wire_value}))
+
+ def test_sparse_combiner(self):
+ wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
+ with ops.Graph().as_default():
+ wire_tensor = sparse_tensor.SparseTensor(
+ values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3]
+ indices=[[0, 0], [1, 0], [1, 1]],
+ dense_shape=[2, 2])
+ features = {'wire_cast': wire_tensor}
+ predictions = get_keras_linear_model_predictions(
+ features, [wire_cast], sparse_combiner='mean')
+ bias = get_keras_linear_model_bias()
+ wire_cast_var = get_linear_model_column_var(wire_cast)
+ with _initialized_session() as sess:
+ sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]]))
+ sess.run(bias.assign([5.]))
+ self.assertAllClose([[1005.], [5010.]], predictions.eval())
+
+ def test_dense_multi_dimension_multi_output(self):
+ price = fc.numeric_column('price', shape=2)
+ with ops.Graph().as_default():
+ features = {'price': [[1., 2.], [5., 6.]]}
+ predictions = get_keras_linear_model_predictions(
+ features, [price], units=3)
+ bias = get_keras_linear_model_bias()
+ price_var = get_linear_model_column_var(price)
+ with _initialized_session() as sess:
+ self.assertAllClose(np.zeros((3,)), bias.eval())
+ self.assertAllClose(np.zeros((2, 3)), price_var.eval())
+ sess.run(price_var.assign([[1., 2., 3.], [10., 100., 1000.]]))
+ sess.run(bias.assign([2., 3., 4.]))
+ self.assertAllClose([[23., 205., 2007.], [67., 613., 6019.]],
+ predictions.eval())
+
+ def test_raises_if_shape_mismatch(self):
+ price = fc.numeric_column('price', shape=2)
+ with ops.Graph().as_default():
+ features = {'price': [[1.], [5.]]}
+ if ops._USE_C_API:
+ with self.assertRaisesRegexp(
+ Exception,
+ r'Cannot reshape a tensor with 2 elements to shape \[2,2\]'):
+ predictions = get_keras_linear_model_predictions(features, [price])
+ else:
+ predictions = get_keras_linear_model_predictions(features, [price])
+ with _initialized_session():
+ with self.assertRaisesRegexp(Exception, 'requested shape has 4'):
+ predictions.eval()
+
+ def test_dense_reshaping(self):
+ price = fc.numeric_column('price', shape=[1, 2])
+ with ops.Graph().as_default():
+ features = {'price': [[[1., 2.]], [[5., 6.]]]}
+ predictions = get_keras_linear_model_predictions(features, [price])
+ bias = get_keras_linear_model_bias()
+ price_var = get_linear_model_column_var(price)
+ with _initialized_session() as sess:
+ self.assertAllClose([0.], bias.eval())
+ self.assertAllClose([[0.], [0.]], price_var.eval())
+ self.assertAllClose([[0.], [0.]], predictions.eval())
+ sess.run(price_var.assign([[10.], [100.]]))
+ self.assertAllClose([[210.], [650.]], predictions.eval())
+
+ def test_dense_multi_column(self):
+ price1 = fc.numeric_column('price1', shape=2)
+ price2 = fc.numeric_column('price2')
+ with ops.Graph().as_default():
+ features = {'price1': [[1., 2.], [5., 6.]], 'price2': [[3.], [4.]]}
+ predictions = get_keras_linear_model_predictions(features,
+ [price1, price2])
+ bias = get_keras_linear_model_bias()
+ price1_var = get_linear_model_column_var(price1)
+ price2_var = get_linear_model_column_var(price2)
+ with _initialized_session() as sess:
+ self.assertAllClose([0.], bias.eval())
+ self.assertAllClose([[0.], [0.]], price1_var.eval())
+ self.assertAllClose([[0.]], price2_var.eval())
+ self.assertAllClose([[0.], [0.]], predictions.eval())
+ sess.run(price1_var.assign([[10.], [100.]]))
+ sess.run(price2_var.assign([[1000.]]))
+ sess.run(bias.assign([7.]))
+ self.assertAllClose([[3217.], [4657.]], predictions.eval())
+
+ def test_dense_collection(self):
+ price = fc.numeric_column('price')
+ with ops.Graph().as_default() as g:
+ features = {'price': [[1.], [5.]]}
+ get_keras_linear_model_predictions(
+ features, [price], weight_collections=['my-vars'])
+ my_vars = g.get_collection('my-vars')
+ bias = get_keras_linear_model_bias()
+ price_var = get_linear_model_column_var(price)
+ self.assertIn(bias, my_vars)
+ self.assertIn(price_var, my_vars)
+
+ def test_sparse_collection(self):
+ wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
+ with ops.Graph().as_default() as g:
+ wire_tensor = sparse_tensor.SparseTensor(
+ values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
+ features = {'wire_cast': wire_tensor}
+ get_keras_linear_model_predictions(
+ features, [wire_cast], weight_collections=['my-vars'])
+ my_vars = g.get_collection('my-vars')
+ bias = get_keras_linear_model_bias()
+ wire_cast_var = get_linear_model_column_var(wire_cast)
+ self.assertIn(bias, my_vars)
+ self.assertIn(wire_cast_var, my_vars)
+
+ def test_dense_trainable_default(self):
+ price = fc.numeric_column('price')
+ with ops.Graph().as_default() as g:
+ features = {'price': [[1.], [5.]]}
+ get_keras_linear_model_predictions(features, [price])
+ bias = get_keras_linear_model_bias()
+ price_var = get_linear_model_column_var(price)
+ trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
+ self.assertIn(bias, trainable_vars)
+ self.assertIn(price_var, trainable_vars)
+
+ def test_sparse_trainable_default(self):
+ wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
+ with ops.Graph().as_default() as g:
+ wire_tensor = sparse_tensor.SparseTensor(
+ values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
+ features = {'wire_cast': wire_tensor}
+ get_keras_linear_model_predictions(features, [wire_cast])
+ trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
+ bias = get_keras_linear_model_bias()
+ wire_cast_var = get_linear_model_column_var(wire_cast)
+ self.assertIn(bias, trainable_vars)
+ self.assertIn(wire_cast_var, trainable_vars)
+
+ def test_dense_trainable_false(self):
+ price = fc.numeric_column('price')
+ with ops.Graph().as_default() as g:
+ features = {'price': [[1.], [5.]]}
+ get_keras_linear_model_predictions(features, [price], trainable=False)
+ trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
+ self.assertEqual([], trainable_vars)
+
+ def test_sparse_trainable_false(self):
+ wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
+ with ops.Graph().as_default() as g:
+ wire_tensor = sparse_tensor.SparseTensor(
+ values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
+ features = {'wire_cast': wire_tensor}
+ get_keras_linear_model_predictions(features, [wire_cast], trainable=False)
+ trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
+ self.assertEqual([], trainable_vars)
+
+ def test_column_order(self):
+ price_a = fc.numeric_column('price_a')
+ price_b = fc.numeric_column('price_b')
+ wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
+ with ops.Graph().as_default() as g:
+ features = {
+ 'price_a': [[1.]],
+ 'price_b': [[3.]],
+ 'wire_cast':
+ sparse_tensor.SparseTensor(
+ values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
+ }
+ get_keras_linear_model_predictions(
+ features, [price_a, wire_cast, price_b],
+ weight_collections=['my-vars'])
+ my_vars = g.get_collection('my-vars')
+ self.assertIn('price_a', my_vars[0].name)
+ self.assertIn('price_b', my_vars[1].name)
+ self.assertIn('wire_cast', my_vars[2].name)
+
+ with ops.Graph().as_default() as g:
+ features = {
+ 'price_a': [[1.]],
+ 'price_b': [[3.]],
+ 'wire_cast':
+ sparse_tensor.SparseTensor(
+ values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
+ }
+ get_keras_linear_model_predictions(
+ features, [wire_cast, price_b, price_a],
+ weight_collections=['my-vars'])
+ my_vars = g.get_collection('my-vars')
+ self.assertIn('price_a', my_vars[0].name)
+ self.assertIn('price_b', my_vars[1].name)
+ self.assertIn('wire_cast', my_vars[2].name)
+
+ def test_static_batch_size_mismatch(self):
+ price1 = fc.numeric_column('price1')
+ price2 = fc.numeric_column('price2')
+ with ops.Graph().as_default():
+ features = {
+ 'price1': [[1.], [5.], [7.]], # batchsize = 3
+ 'price2': [[3.], [4.]] # batchsize = 2
+ }
+ with self.assertRaisesRegexp(
+ ValueError,
+ 'Batch size \(first dimension\) of each feature must be same.'): # pylint: disable=anomalous-backslash-in-string
+ get_keras_linear_model_predictions(features, [price1, price2])
+
+ def test_subset_of_static_batch_size_mismatch(self):
+ price1 = fc.numeric_column('price1')
+ price2 = fc.numeric_column('price2')
+ price3 = fc.numeric_column('price3')
+ with ops.Graph().as_default():
+ features = {
+ 'price1': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 3
+ 'price2': [[3.], [4.]], # batchsize = 2
+ 'price3': [[3.], [4.], [5.]] # batchsize = 3
+ }
+ with self.assertRaisesRegexp(
+ ValueError,
+ 'Batch size \(first dimension\) of each feature must be same.'): # pylint: disable=anomalous-backslash-in-string
+ get_keras_linear_model_predictions(features, [price1, price2, price3])
+
+ def test_runtime_batch_size_mismatch(self):
+ price1 = fc.numeric_column('price1')
+ price2 = fc.numeric_column('price2')
+ with ops.Graph().as_default():
+ features = {
+ 'price1': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 3
+ 'price2': [[3.], [4.]] # batchsize = 2
+ }
+ predictions = get_keras_linear_model_predictions(features,
+ [price1, price2])
+ with _initialized_session() as sess:
+ with self.assertRaisesRegexp(errors.OpError,
+ 'must have the same size and shape'):
+ sess.run(
+ predictions, feed_dict={features['price1']: [[1.], [5.], [7.]]})
+
+ def test_runtime_batch_size_matches(self):
+ price1 = fc.numeric_column('price1')
+ price2 = fc.numeric_column('price2')
+ with ops.Graph().as_default():
+ features = {
+ 'price1': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 2
+ 'price2': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 2
+ }
+ predictions = get_keras_linear_model_predictions(features,
+ [price1, price2])
+ with _initialized_session() as sess:
+ sess.run(
+ predictions,
+ feed_dict={
+ features['price1']: [[1.], [5.]],
+ features['price2']: [[1.], [5.]],
+ })
+
+ def test_with_numpy_input_fn(self):
+ price = fc.numeric_column('price')
+ price_buckets = fc.bucketized_column(
+ price, boundaries=[
+ 0.,
+ 10.,
+ 100.,
+ ])
+ body_style = fc.categorical_column_with_vocabulary_list(
+ 'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
+
+ input_fn = numpy_io.numpy_input_fn(
+ x={
+ 'price': np.array([-1., 2., 13., 104.]),
+ 'body-style': np.array(['sedan', 'hardtop', 'wagon', 'sedan']),
+ },
+ batch_size=2,
+ shuffle=False)
+ features = input_fn()
+ net = get_keras_linear_model_predictions(features,
+ [price_buckets, body_style])
+ # self.assertEqual(1 + 3 + 5, net.shape[1])
+ with _initialized_session() as sess:
+ coord = coordinator.Coordinator()
+ threads = queue_runner_impl.start_queue_runners(sess, coord=coord)
+
+ bias = get_keras_linear_model_bias()
+ price_buckets_var = get_linear_model_column_var(price_buckets)
+ body_style_var = get_linear_model_column_var(body_style)
+
+ sess.run(price_buckets_var.assign([[10.], [100.], [1000.], [10000.]]))
+ sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
+ sess.run(bias.assign([5.]))
+
+ self.assertAllClose([[10 - 1000 + 5.], [100 - 10 + 5.]], sess.run(net))
+
+ coord.request_stop()
+ coord.join(threads)
+
+ def test_with_1d_sparse_tensor(self):
+ price = fc.numeric_column('price')
+ price_buckets = fc.bucketized_column(
+ price, boundaries=[
+ 0.,
+ 10.,
+ 100.,
+ ])
+ body_style = fc.categorical_column_with_vocabulary_list(
+ 'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
+
+ # Provides 1-dim tensor and dense tensor.
+ features = {
+ 'price':
+ constant_op.constant([
+ -1.,
+ 12.,
+ ]),
+ 'body-style':
+ sparse_tensor.SparseTensor(
+ indices=((0,), (1,)),
+ values=('sedan', 'hardtop'),
+ dense_shape=(2,)),
+ }
+ self.assertEqual(1, features['price'].shape.ndims)
+ self.assertEqual(1, features['body-style'].dense_shape.get_shape()[0])
+
+ net = get_keras_linear_model_predictions(features,
+ [price_buckets, body_style])
+ with _initialized_session() as sess:
+ bias = get_keras_linear_model_bias()
+ price_buckets_var = get_linear_model_column_var(price_buckets)
+ body_style_var = get_linear_model_column_var(body_style)
+
+ sess.run(price_buckets_var.assign([[10.], [100.], [1000.], [10000.]]))
+ sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
+ sess.run(bias.assign([5.]))
+
+ self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]], sess.run(net))
+
+ def test_with_1d_unknown_shape_sparse_tensor(self):
+ price = fc.numeric_column('price')
+ price_buckets = fc.bucketized_column(
+ price, boundaries=[
+ 0.,
+ 10.,
+ 100.,
+ ])
+ body_style = fc.categorical_column_with_vocabulary_list(
+ 'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
+ country = fc.categorical_column_with_vocabulary_list(
+ 'country', vocabulary_list=['US', 'JP', 'CA'])
+
+ # Provides 1-dim tensor and dense tensor.
+ features = {
+ 'price': array_ops.placeholder(dtypes.float32),
+ 'body-style': array_ops.sparse_placeholder(dtypes.string),
+ 'country': array_ops.placeholder(dtypes.string),
+ }
+ self.assertIsNone(features['price'].shape.ndims)
+ self.assertIsNone(features['body-style'].get_shape().ndims)
+
+ price_data = np.array([-1., 12.])
+ body_style_data = sparse_tensor.SparseTensorValue(
+ indices=((0,), (1,)), values=('sedan', 'hardtop'), dense_shape=(2,))
+ country_data = np.array(['US', 'CA'])
+
+ net = get_keras_linear_model_predictions(
+ features, [price_buckets, body_style, country])
+ bias = get_keras_linear_model_bias()
+ price_buckets_var = get_linear_model_column_var(price_buckets)
+ body_style_var = get_linear_model_column_var(body_style)
+ with _initialized_session() as sess:
+ sess.run(price_buckets_var.assign([[10.], [100.], [1000.], [10000.]]))
+ sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
+ sess.run(bias.assign([5.]))
+
+ self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]],
+ sess.run(
+ net,
+ feed_dict={
+ features['price']: price_data,
+ features['body-style']: body_style_data,
+ features['country']: country_data
+ }))
+
+ def test_with_rank_0_feature(self):
+ price = fc.numeric_column('price')
+ features = {
+ 'price': constant_op.constant(0),
+ }
+ self.assertEqual(0, features['price'].shape.ndims)
+
+ # Static rank 0 should fail
+ with self.assertRaisesRegexp(ValueError, 'Feature .* cannot have rank 0'):
+ get_keras_linear_model_predictions(features, [price])
+
+ # Dynamic rank 0 should fail
+ features = {
+ 'price': array_ops.placeholder(dtypes.float32),
+ }
+ net = get_keras_linear_model_predictions(features, [price])
+ self.assertEqual(1, net.shape[1])
+ with _initialized_session() as sess:
+ with self.assertRaisesOpError('Feature .* cannot have rank 0'):
+ sess.run(net, feed_dict={features['price']: np.array(1)})
+
+
class InputLayerTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes()
@@ -2715,6 +3543,32 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
# 'skywalker' -> 3, 'omar' -> 0: wire_var[3] + wire_var[0] = 4+1 = 5
self.assertAllClose(((3.,), (5.,)), predictions.eval())
+ def test_keras_linear_model(self):
+ wire_column = fc.categorical_column_with_vocabulary_file(
+ key='wire',
+ vocabulary_file=self._wire_vocabulary_file_name,
+ vocabulary_size=self._wire_vocabulary_size,
+ num_oov_buckets=1)
+ self.assertEqual(4, wire_column._num_buckets)
+ with ops.Graph().as_default():
+ predictions = get_keras_linear_model_predictions({
+ wire_column.name:
+ sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=('marlo', 'skywalker', 'omar'),
+ dense_shape=(2, 2))
+ }, (wire_column,))
+ bias = get_keras_linear_model_bias()
+ wire_var = get_linear_model_column_var(wire_column)
+ with _initialized_session():
+ self.assertAllClose((0.,), bias.eval())
+ self.assertAllClose(((0.,), (0.,), (0.,), (0.,)), wire_var.eval())
+ self.assertAllClose(((0.,), (0.,)), predictions.eval())
+ wire_var.assign(((1.,), (2.,), (3.,), (4.,))).eval()
+ # 'marlo' -> 2: wire_var[2] = 3
+ # 'skywalker' -> 3, 'omar' -> 0: wire_var[3] + wire_var[0] = 4+1 = 5
+ self.assertAllClose(((3.,), (5.,)), predictions.eval())
+
class VocabularyListCategoricalColumnTest(test.TestCase):
@@ -3082,6 +3936,31 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
# 'skywalker' -> 3, 'omar' -> 0: wire_var[3] + wire_var[0] = 4+1 = 5
self.assertAllClose(((3.,), (5.,)), predictions.eval())
+ def test_keras_linear_model(self):
+ wire_column = fc.categorical_column_with_vocabulary_list(
+ key='aaa',
+ vocabulary_list=('omar', 'stringer', 'marlo'),
+ num_oov_buckets=1)
+ self.assertEqual(4, wire_column._num_buckets)
+ with ops.Graph().as_default():
+ predictions = get_keras_linear_model_predictions({
+ wire_column.name:
+ sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=('marlo', 'skywalker', 'omar'),
+ dense_shape=(2, 2))
+ }, (wire_column,))
+ bias = get_keras_linear_model_bias()
+ wire_var = get_linear_model_column_var(wire_column)
+ with _initialized_session():
+ self.assertAllClose((0.,), bias.eval())
+ self.assertAllClose(((0.,), (0.,), (0.,), (0.,)), wire_var.eval())
+ self.assertAllClose(((0.,), (0.,)), predictions.eval())
+ wire_var.assign(((1.,), (2.,), (3.,), (4.,))).eval()
+ # 'marlo' -> 2: wire_var[2] = 3
+ # 'skywalker' -> 3, 'omar' -> 0: wire_var[3] + wire_var[0] = 4+1 = 5
+ self.assertAllClose(((3.,), (5.,)), predictions.eval())
+
class IdentityCategoricalColumnTest(test.TestCase):
@@ -3306,6 +4185,28 @@ class IdentityCategoricalColumnTest(test.TestCase):
# weight_var[2] + weight_var[1] = 3+2 = 5
self.assertAllClose(((1.,), (5.,)), predictions.eval())
+ def test_keras_linear_model(self):
+ column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
+ self.assertEqual(3, column._num_buckets)
+ with ops.Graph().as_default():
+ predictions = get_keras_linear_model_predictions({
+ column.name:
+ sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(0, 2, 1),
+ dense_shape=(2, 2))
+ }, (column,))
+ bias = get_keras_linear_model_bias()
+ weight_var = get_linear_model_column_var(column)
+ with _initialized_session():
+ self.assertAllClose((0.,), bias.eval())
+ self.assertAllClose(((0.,), (0.,), (0.,)), weight_var.eval())
+ self.assertAllClose(((0.,), (0.,)), predictions.eval())
+ weight_var.assign(((1.,), (2.,), (3.,))).eval()
+ # weight_var[0] = 1
+ # weight_var[2] + weight_var[1] = 3+2 = 5
+ self.assertAllClose(((1.,), (5.,)), predictions.eval())
+
class TransformFeaturesTest(test.TestCase):
@@ -3537,6 +4438,25 @@ class IndicatorColumnTest(test.TestCase):
weight_var.assign([[1.], [2.], [3.], [4.]]).eval()
self.assertAllClose([[2. + 3.]], predictions.eval())
+ def test_keras_linear_model(self):
+ animal = fc.indicator_column(
+ fc.categorical_column_with_identity('animal', num_buckets=4))
+ with ops.Graph().as_default():
+ features = {
+ 'animal':
+ sparse_tensor.SparseTensor(
+ indices=[[0, 0], [0, 1]], values=[1, 2], dense_shape=[1, 2])
+ }
+
+ predictions = get_keras_linear_model_predictions(features, [animal])
+ weight_var = get_linear_model_column_var(animal)
+ with _initialized_session():
+ # All should be zero-initialized.
+ self.assertAllClose([[0.], [0.], [0.], [0.]], weight_var.eval())
+ self.assertAllClose([[0.]], predictions.eval())
+ weight_var.assign([[1.], [2.], [3.], [4.]]).eval()
+ self.assertAllClose([[2. + 3.]], predictions.eval())
+
def test_input_layer(self):
animal = fc.indicator_column(
fc.categorical_column_with_identity('animal', num_buckets=4))
@@ -3727,6 +4647,72 @@ class EmbeddingColumnTest(test.TestCase):
# Assert expected embedding variable and lookups.
global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
+ self.assertItemsEqual(('embedding_weights:0',),
+ tuple([v.name for v in global_vars]))
+ with _initialized_session():
+ self.assertAllEqual(embedding_values, global_vars[0].eval())
+ self.assertAllEqual(expected_lookups, embedding_lookup.eval())
+
+ def test_get_dense_tensor_with_state(self):
+ # Inputs.
+ vocabulary_size = 3
+ sparse_input = sparse_tensor.SparseTensorValue(
+ # example 0, ids [2]
+ # example 1, ids [0, 1]
+ # example 2, ids []
+ # example 3, ids [1]
+ indices=((0, 0), (1, 0), (1, 4), (3, 0)),
+ values=(2, 0, 1, 1),
+ dense_shape=(4, 5))
+
+ # Embedding variable.
+ embedding_dimension = 2
+ embedding_values = (
+ (1., 2.), # id 0
+ (3., 5.), # id 1
+ (7., 11.) # id 2
+ )
+
+ def _initializer(shape, dtype, partition_info):
+ self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
+ self.assertEqual(dtypes.float32, dtype)
+ self.assertIsNone(partition_info)
+ return embedding_values
+
+ # Expected lookup result, using combiner='mean'.
+ expected_lookups = (
+ # example 0, ids [2], embedding = [7, 11]
+ (7., 11.),
+ # example 1, ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5]
+ (2., 3.5),
+ # example 2, ids [], embedding = [0, 0]
+ (0., 0.),
+ # example 3, ids [1], embedding = [3, 5]
+ (3., 5.),
+ )
+
+ # Build columns.
+ categorical_column = fc.categorical_column_with_identity(
+ key='aaa', num_buckets=vocabulary_size)
+ embedding_column = fc.embedding_column(
+ categorical_column,
+ dimension=embedding_dimension,
+ initializer=_initializer)
+
+ # Create embedding_weights variable.
+ weight_collections = [
+ ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.MODEL_VARIABLES
+ ]
+ state = embedding_column._create_state(weight_collections)
+
+ # Provide sparse input and get dense result.
+ embedding_lookup = embedding_column._get_dense_tensor(
+ _LazyBuilder({
+ 'aaa': sparse_input
+ }), state=state)
+
+ # Assert expected embedding variable and lookups.
+ global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
self.assertItemsEqual(
('embedding_weights:0',), tuple([v.name for v in global_vars]))
with _initialized_session():
@@ -4023,6 +5009,82 @@ class EmbeddingColumnTest(test.TestCase):
# = [4*7 + 6*11, 4*2 + 6*3.5, 4*0 + 6*0, 4*3 + 6*5] = [94, 29, 0, 42]
self.assertAllClose(((94.,), (29.,), (0.,), (42.,)), predictions.eval())
+ def test_keras_linear_model(self):
+ # Inputs.
+ batch_size = 4
+ vocabulary_size = 3
+ sparse_input = sparse_tensor.SparseTensorValue(
+ # example 0, ids [2]
+ # example 1, ids [0, 1]
+ # example 2, ids []
+ # example 3, ids [1]
+ indices=((0, 0), (1, 0), (1, 4), (3, 0)),
+ values=(2, 0, 1, 1),
+ dense_shape=(batch_size, 5))
+
+ # Embedding variable.
+ embedding_dimension = 2
+ embedding_shape = (vocabulary_size, embedding_dimension)
+ zeros_embedding_values = np.zeros(embedding_shape)
+
+ def _initializer(shape, dtype, partition_info):
+ self.assertAllEqual(embedding_shape, shape)
+ self.assertEqual(dtypes.float32, dtype)
+ self.assertIsNone(partition_info)
+ return zeros_embedding_values
+
+ # Build columns.
+ categorical_column = fc.categorical_column_with_identity(
+ key='aaa', num_buckets=vocabulary_size)
+ embedding_column = fc.embedding_column(
+ categorical_column,
+ dimension=embedding_dimension,
+ initializer=_initializer)
+
+ with ops.Graph().as_default():
+ predictions = get_keras_linear_model_predictions({
+ categorical_column.name: sparse_input
+ }, (embedding_column,))
+ expected_var_names = (
+ 'linear_model/bias_layer/bias_weights:0',
+ 'linear_model/aaa_embedding/weights:0',
+ 'linear_model/aaa_embedding/embedding_weights:0',
+ )
+ self.assertItemsEqual(
+ expected_var_names,
+ [v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
+ trainable_vars = {
+ v.name: v
+ for v in ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
+ }
+ self.assertItemsEqual(expected_var_names, trainable_vars.keys())
+ bias = trainable_vars['linear_model/bias_layer/bias_weights:0']
+ embedding_weights = trainable_vars[
+ 'linear_model/aaa_embedding/embedding_weights:0']
+ linear_weights = trainable_vars['linear_model/aaa_embedding/weights:0']
+ with _initialized_session():
+ # Predictions with all zero weights.
+ self.assertAllClose(np.zeros((1,)), bias.eval())
+ self.assertAllClose(zeros_embedding_values, embedding_weights.eval())
+ self.assertAllClose(
+ np.zeros((embedding_dimension, 1)), linear_weights.eval())
+ self.assertAllClose(np.zeros((batch_size, 1)), predictions.eval())
+
+ # Predictions with all non-zero weights.
+ embedding_weights.assign((
+ (1., 2.), # id 0
+ (3., 5.), # id 1
+ (7., 11.) # id 2
+ )).eval()
+ linear_weights.assign(((4.,), (6.,))).eval()
+ # example 0, ids [2], embedding[0] = [7, 11]
+ # example 1, ids [0, 1], embedding[1] = mean([1, 2] + [3, 5]) = [2, 3.5]
+ # example 2, ids [], embedding[2] = [0, 0]
+ # example 3, ids [1], embedding[3] = [3, 5]
+ # sum(embeddings * linear_weights)
+ # = [4*7 + 6*11, 4*2 + 6*3.5, 4*0 + 6*0, 4*3 + 6*5] = [94, 29, 0, 42]
+ self.assertAllClose(((94.,), (29.,), (0.,), (42.,)), predictions.eval())
+
def test_input_layer(self):
# Inputs.
vocabulary_size = 3
@@ -4445,6 +5507,80 @@ class SharedEmbeddingColumnTest(test.TestCase):
# Assert expected embedding variable and lookups.
global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
+ self.assertItemsEqual(('embedding_weights:0',),
+ tuple([v.name for v in global_vars]))
+ embedding_var = global_vars[0]
+ with _initialized_session():
+ self.assertAllEqual(embedding_values, embedding_var.eval())
+ self.assertAllEqual(expected_lookups_a, embedding_lookup_a.eval())
+ self.assertAllEqual(expected_lookups_b, embedding_lookup_b.eval())
+
+ def test_get_dense_tensor_with_state(self):
+ # Inputs.
+ vocabulary_size = 3
+ # -1 values are ignored.
+ input_a = np.array([
+ [2, -1, -1], # example 0, ids [2]
+ [0, 1, -1]
+ ]) # example 1, ids [0, 1]
+ input_b = np.array([
+ [0, -1, -1], # example 0, ids [0]
+ [-1, -1, -1]
+ ]) # example 1, ids []
+ input_features = {'aaa': input_a, 'bbb': input_b}
+
+ # Embedding variable.
+ embedding_dimension = 2
+ embedding_values = (
+ (1., 2.), # id 0
+ (3., 5.), # id 1
+ (7., 11.) # id 2
+ )
+
+ def _initializer(shape, dtype, partition_info):
+ self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
+ self.assertEqual(dtypes.float32, dtype)
+ self.assertIsNone(partition_info)
+ return embedding_values
+
+ # Expected lookup result, using combiner='mean'.
+ expected_lookups_a = (
+ # example 0:
+ (7., 11.), # ids [2], embedding = [7, 11]
+ # example 1:
+ (2., 3.5), # ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5]
+ )
+ expected_lookups_b = (
+ # example 0:
+ (1., 2.), # ids [0], embedding = [1, 2]
+ # example 1:
+ (0., 0.), # ids [], embedding = [0, 0]
+ )
+
+ # Build columns.
+ categorical_column_a = fc.categorical_column_with_identity(
+ key='aaa', num_buckets=vocabulary_size)
+ categorical_column_b = fc.categorical_column_with_identity(
+ key='bbb', num_buckets=vocabulary_size)
+ embedding_column_a, embedding_column_b = fc.shared_embedding_columns(
+ [categorical_column_a, categorical_column_b],
+ dimension=embedding_dimension,
+ initializer=_initializer)
+
+ # Create state.
+ weight_collections = [
+ ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.MODEL_VARIABLES
+ ]
+ state = embedding_column_a._create_state(weight_collections)
+
+ # Provide sparse input and get dense result.
+ embedding_lookup_a = embedding_column_a._get_dense_tensor(
+ _LazyBuilder(input_features), state=state)
+ embedding_lookup_b = embedding_column_b._get_dense_tensor(
+ _LazyBuilder(input_features), state=state)
+
+ # Assert expected embedding variable and lookups.
+ global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
self.assertItemsEqual(
('embedding_weights:0',), tuple([v.name for v in global_vars]))
embedding_var = global_vars[0]
@@ -4595,6 +5731,97 @@ class SharedEmbeddingColumnTest(test.TestCase):
# = [3*1 + 5*2, 3*0 + 5*0] = [13, 0]
self.assertAllClose([[94. + 13.], [29.]], predictions.eval())
+ def test_keras_linear_model(self):
+ # Inputs.
+ batch_size = 2
+ vocabulary_size = 3
+ # -1 values are ignored.
+ input_a = np.array([
+ [2, -1, -1], # example 0, ids [2]
+ [0, 1, -1]
+ ]) # example 1, ids [0, 1]
+ input_b = np.array([
+ [0, -1, -1], # example 0, ids [0]
+ [-1, -1, -1]
+ ]) # example 1, ids []
+
+ # Embedding variable.
+ embedding_dimension = 2
+ embedding_shape = (vocabulary_size, embedding_dimension)
+ zeros_embedding_values = np.zeros(embedding_shape)
+
+ def _initializer(shape, dtype, partition_info):
+ self.assertAllEqual(embedding_shape, shape)
+ self.assertEqual(dtypes.float32, dtype)
+ self.assertIsNone(partition_info)
+ return zeros_embedding_values
+
+ # Build columns.
+ categorical_column_a = fc.categorical_column_with_identity(
+ key='aaa', num_buckets=vocabulary_size)
+ categorical_column_b = fc.categorical_column_with_identity(
+ key='bbb', num_buckets=vocabulary_size)
+ embedding_column_a, embedding_column_b = fc.shared_embedding_columns(
+ [categorical_column_a, categorical_column_b],
+ dimension=embedding_dimension,
+ initializer=_initializer)
+
+ with ops.Graph().as_default():
+ predictions = get_keras_linear_model_predictions({
+ categorical_column_a.name: input_a,
+ categorical_column_b.name: input_b,
+ }, (embedding_column_a, embedding_column_b))
+ # Linear weights do not follow the column name. But this is a rare use
+ # case, and fixing it would add too much complexity to the code.
+ expected_var_names = (
+ 'linear_model/bias_layer/bias_weights:0',
+ 'linear_model/aaa_bbb_shared_embedding/weights:0',
+ 'linear_model/aaa_bbb_shared_embedding/embedding_weights:0',
+ 'linear_model/aaa_bbb_shared_embedding_1/weights:0',
+ )
+ self.assertItemsEqual(
+ expected_var_names,
+ [v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
+ trainable_vars = {
+ v.name: v
+ for v in ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
+ }
+ self.assertItemsEqual(expected_var_names, trainable_vars.keys())
+ bias = trainable_vars['linear_model/bias_layer/bias_weights:0']
+ embedding_weights = trainable_vars[
+ 'linear_model/aaa_bbb_shared_embedding/embedding_weights:0']
+ linear_weights_a = trainable_vars[
+ 'linear_model/aaa_bbb_shared_embedding/weights:0']
+ linear_weights_b = trainable_vars[
+ 'linear_model/aaa_bbb_shared_embedding_1/weights:0']
+ with _initialized_session():
+ # Predictions with all zero weights.
+ self.assertAllClose(np.zeros((1,)), bias.eval())
+ self.assertAllClose(zeros_embedding_values, embedding_weights.eval())
+ self.assertAllClose(
+ np.zeros((embedding_dimension, 1)), linear_weights_a.eval())
+ self.assertAllClose(
+ np.zeros((embedding_dimension, 1)), linear_weights_b.eval())
+ self.assertAllClose(np.zeros((batch_size, 1)), predictions.eval())
+
+ # Predictions with all non-zero weights.
+ embedding_weights.assign((
+ (1., 2.), # id 0
+ (3., 5.), # id 1
+ (7., 11.) # id 2
+ )).eval()
+ linear_weights_a.assign(((4.,), (6.,))).eval()
+ # example 0, ids [2], embedding[0] = [7, 11]
+ # example 1, ids [0, 1], embedding[1] = mean([1, 2] + [3, 5]) = [2, 3.5]
+ # sum(embeddings * linear_weights)
+ # = [4*7 + 6*11, 4*2 + 6*3.5] = [94, 29]
+ linear_weights_b.assign(((3.,), (5.,))).eval()
+ # example 0, ids [0], embedding[0] = [1, 2]
+      # example 1, ids [], embedding[1] = [0, 0]
+      # sum(embeddings * linear_weights)
+      # = [3*1 + 5*2, 3*0 + 5*0] = [13, 0]
+ self.assertAllClose([[94. + 13.], [29.]], predictions.eval())
+
def _test_input_layer(self, trainable=True):
# Inputs.
vocabulary_size = 3
@@ -4880,6 +6107,101 @@ class WeightedCategoricalColumnTest(test.TestCase):
dense_shape=(2, 2)),
weight_tensor.eval())
+ def test_keras_linear_model(self):
+ column = fc.weighted_categorical_column(
+ categorical_column=fc.categorical_column_with_identity(
+ key='ids', num_buckets=3),
+ weight_feature_key='values')
+ with ops.Graph().as_default():
+ predictions = get_keras_linear_model_predictions({
+ 'ids':
+ sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(0, 2, 1),
+ dense_shape=(2, 2)),
+ 'values':
+ sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(.5, 1., .1),
+ dense_shape=(2, 2))
+ }, (column,))
+ bias = get_keras_linear_model_bias()
+ weight_var = get_linear_model_column_var(column)
+ with _initialized_session():
+ self.assertAllClose((0.,), bias.eval())
+ self.assertAllClose(((0.,), (0.,), (0.,)), weight_var.eval())
+ self.assertAllClose(((0.,), (0.,)), predictions.eval())
+ weight_var.assign(((1.,), (2.,), (3.,))).eval()
+ # weight_var[0] * weights[0, 0] = 1 * .5 = .5
+ # weight_var[2] * weights[1, 0] + weight_var[1] * weights[1, 1]
+ # = 3*1 + 2*.1 = 3+.2 = 3.2
+ self.assertAllClose(((.5,), (3.2,)), predictions.eval())
+
+ def test_keras_linear_model_mismatched_shape(self):
+ column = fc.weighted_categorical_column(
+ categorical_column=fc.categorical_column_with_identity(
+ key='ids', num_buckets=3),
+ weight_feature_key='values')
+ with ops.Graph().as_default():
+ with self.assertRaisesRegexp(ValueError,
+ r'Dimensions.*are not compatible'):
+ get_keras_linear_model_predictions({
+ 'ids':
+ sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(0, 2, 1),
+ dense_shape=(2, 2)),
+ 'values':
+ sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (0, 1), (1, 0), (1, 1)),
+ values=(.5, 11., 1., .1),
+ dense_shape=(2, 2))
+ }, (column,))
+
+ def test_keras_linear_model_mismatched_dense_values(self):
+ column = fc.weighted_categorical_column(
+ categorical_column=fc.categorical_column_with_identity(
+ key='ids', num_buckets=3),
+ weight_feature_key='values')
+ with ops.Graph().as_default():
+ predictions = get_keras_linear_model_predictions({
+ 'ids':
+ sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(0, 2, 1),
+ dense_shape=(2, 2)),
+ 'values': ((.5,), (1.,))
+ }, (column,))
+ with _initialized_session():
+ with self.assertRaisesRegexp(errors.OpError, 'Incompatible shapes'):
+ predictions.eval()
+
+ def test_keras_linear_model_mismatched_dense_shape(self):
+ column = fc.weighted_categorical_column(
+ categorical_column=fc.categorical_column_with_identity(
+ key='ids', num_buckets=3),
+ weight_feature_key='values')
+ with ops.Graph().as_default():
+ predictions = get_keras_linear_model_predictions({
+ 'ids':
+ sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(0, 2, 1),
+ dense_shape=(2, 2)),
+ 'values': ((.5,), (1.,), (.1,))
+ }, (column,))
+ bias = get_keras_linear_model_bias()
+ weight_var = get_linear_model_column_var(column)
+ with _initialized_session():
+ self.assertAllClose((0.,), bias.eval())
+ self.assertAllClose(((0.,), (0.,), (0.,)), weight_var.eval())
+ self.assertAllClose(((0.,), (0.,)), predictions.eval())
+ weight_var.assign(((1.,), (2.,), (3.,))).eval()
+ # weight_var[0] * weights[0, 0] = 1 * .5 = .5
+ # weight_var[2] * weights[1, 0] + weight_var[1] * weights[1, 1]
+ # = 3*1 + 2*.1 = 3+.2 = 3.2
+ self.assertAllClose(((.5,), (3.2,)), predictions.eval())
+
def test_linear_model(self):
column = fc.weighted_categorical_column(
categorical_column=fc.categorical_column_with_identity(
diff --git a/tensorflow/python/framework/c_api_util.py b/tensorflow/python/framework/c_api_util.py
index 6c522de452..4356a534b4 100644
--- a/tensorflow/python/framework/c_api_util.py
+++ b/tensorflow/python/framework/c_api_util.py
@@ -33,7 +33,7 @@ class ScopedTFStatus(object):
def __del__(self):
# Note: when we're destructing the global context (i.e when the process is
# terminating) we can have already deleted other modules.
- if c_api.TF_DeleteStatus is not None:
+ if c_api is not None and c_api.TF_DeleteStatus is not None:
c_api.TF_DeleteStatus(self.status)
@@ -46,7 +46,7 @@ class ScopedTFGraph(object):
def __del__(self):
# Note: when we're destructing the global context (i.e when the process is
# terminating) we can have already deleted other modules.
- if c_api.TF_DeleteGraph is not None:
+ if c_api is not None and c_api.TF_DeleteGraph is not None:
c_api.TF_DeleteGraph(self.graph)
@@ -59,7 +59,7 @@ class ScopedTFImportGraphDefOptions(object):
def __del__(self):
# Note: when we're destructing the global context (i.e when the process is
# terminating) we can have already deleted other modules.
- if c_api.TF_DeleteImportGraphDefOptions is not None:
+ if c_api is not None and c_api.TF_DeleteImportGraphDefOptions is not None:
c_api.TF_DeleteImportGraphDefOptions(self.options)
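
The three `__del__` guards above extend a standard CPython teardown idiom: while the interpreter is shutting down, module globals may already have been cleared to None, so a destructor must check the module itself before dereferencing its attributes. A generic sketch of the idiom, with hypothetical names (`some_c_extension`, `ScopedHandle`) that are not part of this diff:

    import some_c_extension as lib  # hypothetical wrapped C library

    class ScopedHandle(object):
      """Owns a C-side handle and frees it on garbage collection."""

      def __init__(self):
        self.handle = lib.NewHandle()

      def __del__(self):
        # At interpreter exit 'lib' (or its attributes) may already be None,
        # so guard both before calling back into the C layer.
        if lib is not None and lib.DeleteHandle is not None:
          lib.DeleteHandle(self.handle)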
diff --git a/tensorflow/python/framework/errors_impl.py b/tensorflow/python/framework/errors_impl.py
index 2a40316d51..84106c32c6 100644
--- a/tensorflow/python/framework/errors_impl.py
+++ b/tensorflow/python/framework/errors_impl.py
@@ -473,6 +473,8 @@ _CODE_TO_EXCEPTION_CLASS = {
DATA_LOSS: DataLossError,
}
+c_api.PyExceptionRegistry_Init(_CODE_TO_EXCEPTION_CLASS)
+
_EXCEPTION_CLASS_TO_CODE = dict((
(class_, code) for (code, class_) in _CODE_TO_EXCEPTION_CLASS.items()))
@@ -499,6 +501,7 @@ def _make_specific_exception(node_def, op, message, error_code):
# Named like a function for backwards compatibility with the
# @tf_contextlib.contextmanager version, which was switched to a class to avoid
# some object creation overhead.
+# TODO(b/77295559): expand use of TF_Status* SWIG typemap and deprecate this.
@tf_export("errors.raise_exception_on_not_ok_status") # pylint: disable=invalid-name
class raise_exception_on_not_ok_status(object):
"""Context manager to check for C API status."""
diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py
index 82dd2a3356..c5caf9ebc0 100644
--- a/tensorflow/python/framework/function.py
+++ b/tensorflow/python/framework/function.py
@@ -30,7 +30,6 @@ from tensorflow.python import pywrap_tensorflow as c_api
from tensorflow.python.eager import context
from tensorflow.python.framework import c_api_util
from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
from tensorflow.python.framework import graph_to_function_def
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@@ -275,8 +274,7 @@ class _DefinedFunction(object):
self._create_definition_if_needed()
if self._c_func:
with c_api_util.tf_buffer() as buf:
- with errors.raise_exception_on_not_ok_status() as status:
- c_api.TF_FunctionToFunctionDef(self._c_func, buf, status)
+ c_api.TF_FunctionToFunctionDef(self._c_func, buf)
fdef = function_pb2.FunctionDef()
proto_data = c_api.TF_GetBuffer(buf)
fdef.ParseFromString(compat.as_bytes(proto_data))
@@ -399,18 +397,16 @@ class _DefinedFunction(object):
if self._out_names else [])
description = self._func.__doc__ or None
# pylint: disable=protected-access
- with errors.raise_exception_on_not_ok_status() as status:
- self._c_func = c_api.TF_GraphToFunction_wrapper(
- temp_graph._c_graph,
- base_func_name,
- self._func_name is None, # append_hash_to_fn_name
- None, # opers
- [t._as_tf_output() for t in inputs],
- [t._as_tf_output() for t in outputs],
- output_names,
- None, # opts
- description,
- status)
+ self._c_func = c_api.TF_GraphToFunction_wrapper(
+ temp_graph._c_graph,
+ base_func_name,
+ self._func_name is None, # append_hash_to_fn_name
+ None, # opers
+ [t._as_tf_output() for t in inputs],
+ [t._as_tf_output() for t in outputs],
+ output_names,
+ None, # opts
+ description)
# pylint: enable=protected-access
self._set_c_attrs(kwargs_attr)
@@ -433,9 +429,8 @@ class _DefinedFunction(object):
serialized = attr_value.SerializeToString()
# TODO(skyewm): this creates and deletes a new TF_Status for every attr.
# It might be worth creating a convenient way to re-use the same status.
- with errors.raise_exception_on_not_ok_status() as status:
- c_api.TF_FunctionSetAttrValueProto(self._c_func, compat.as_str(name),
- serialized, status)
+ c_api.TF_FunctionSetAttrValueProto(self._c_func, compat.as_str(name),
+ serialized)
def _create_hash_str(self, input_arg, output_arg, node_def):
"""Creates an 8-character string unique to this input.
@@ -830,8 +825,7 @@ def _from_definition(fdef, grad_func=None):
# pylint: disable=protected-access
if ops._USE_C_API:
serialized = fdef.SerializeToString()
- with errors.raise_exception_on_not_ok_status() as status:
- result._c_func = c_api.TF_FunctionImportFunctionDef(serialized, status)
+ result._c_func = c_api.TF_FunctionImportFunctionDef(serialized)
result._extra_inputs = []
else:
result._definition = fdef
diff --git a/tensorflow/python/framework/importer.py b/tensorflow/python/framework/importer.py
index 4ea34d7bb2..23f529b988 100644
--- a/tensorflow/python/framework/importer.py
+++ b/tensorflow/python/framework/importer.py
@@ -485,9 +485,8 @@ def import_graph_def(graph_def,
with graph._lock: # pylint: disable=protected-access
with c_api_util.tf_buffer(graph_def.SerializeToString()) as serialized:
try:
- with errors.raise_exception_on_not_ok_status() as status:
- results = c_api.TF_GraphImportGraphDefWithResults(
- graph._c_graph, serialized, options, status) # pylint: disable=protected-access
+ results = c_api.TF_GraphImportGraphDefWithResults(
+ graph._c_graph, serialized, options) # pylint: disable=protected-access
except errors.InvalidArgumentError as e:
# Convert to ValueError for backwards compatibility.
raise ValueError(str(e))
diff --git a/tensorflow/python/framework/importer_test.py b/tensorflow/python/framework/importer_test.py
index 369669c2e6..2c913d1e02 100644
--- a/tensorflow/python/framework/importer_test.py
+++ b/tensorflow/python/framework/importer_test.py
@@ -219,6 +219,23 @@ class ImportGraphDefTest(test.TestCase):
self.assertEqual(outer_inner.name, "outer/inner_1")
self.assertEqual(outer_inner_c.name, "outer/inner/c_1")
+ def testEmptyNameScope(self):
+ with ops.Graph().as_default():
+ # Create name scope but don't create any ops with it
+ with ops.name_scope("foo"):
+ pass
+
+ # Import graph def that uses name scope name
+ op, = importer.import_graph_def(
+ self._MakeGraphDef("node { name: 'foo' op: 'IntOutput' }"),
+ return_elements=["foo"],
+ name="")
+
+ if ops._USE_C_API:
+ self.assertEqual(op.name, "foo")
+ else:
+ self.assertEqual(op.name, "foo_1")
+
def testInputMap(self):
with ops.Graph().as_default():
feed_a_0 = constant_op.constant(0, dtype=dtypes.int32)
diff --git a/tensorflow/python/framework/load_library.py b/tensorflow/python/framework/load_library.py
index 1f2aa264c1..535c6017f5 100644
--- a/tensorflow/python/framework/load_library.py
+++ b/tensorflow/python/framework/load_library.py
@@ -26,7 +26,6 @@ import threading # pylint: disable=unused-import
from tensorflow.core.framework import op_def_pb2
from tensorflow.core.lib.core import error_codes_pb2 # pylint: disable=unused-import
from tensorflow.python import pywrap_tensorflow as py_tf
-from tensorflow.python.framework import errors_impl
from tensorflow.python.util import compat
from tensorflow.python.util.tf_export import tf_export
@@ -54,8 +53,7 @@ def load_op_library(library_filename):
Raises:
RuntimeError: when unable to load the library or get the python wrappers.
"""
- with errors_impl.raise_exception_on_not_ok_status() as status:
- lib_handle = py_tf.TF_LoadLibrary(library_filename, status)
+ lib_handle = py_tf.TF_LoadLibrary(library_filename)
op_list_str = py_tf.TF_GetOpList(lib_handle)
op_list = op_def_pb2.OpList()
@@ -99,5 +97,4 @@ def load_file_system_library(library_filename):
Raises:
RuntimeError: when unable to load the library.
"""
- with errors_impl.raise_exception_on_not_ok_status() as status:
- lib_handle = py_tf.TF_LoadLibrary(library_filename, status)
+ py_tf.TF_LoadLibrary(library_filename)
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index 6930737a0c..2d55f98a1c 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -63,7 +63,7 @@ from tensorflow.python.util.tf_export import tf_export
# calls to the C API. Currently disabled by default but can be manually enabled
# in code or via the environment variable. This will be removed once all
# functionality is supported and there's no performance penalty with it enabled.
-_USE_C_API = os.getenv("TF_C_API_GRAPH_CONSTRUCTION", "0") is not "0"
+_USE_C_API = os.getenv("TF_C_API_GRAPH_CONSTRUCTION", "1") is not "0"
_USE_C_SHAPES = os.getenv("TF_C_API_GRAPH_CONSTRUCTION_SHAPES", "0") is not "0"
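The flipped default makes C API graph construction opt-out rather than opt-in (note the comment above the flag still says "disabled by default"). The old behavior remains reachable through the environment variable, which must be set before TensorFlow is imported. Note also that comparing strings with `is not` only works because small string literals are interned; the robust spelling is `!=`. A sketch:

```python
import os

# Must be set before `import tensorflow` so ops.py sees it at import time.
os.environ['TF_C_API_GRAPH_CONSTRUCTION'] = '0'  # opt out of the C API path

# Equivalent, interning-independent form of the check in ops.py:
use_c_api = os.getenv('TF_C_API_GRAPH_CONSTRUCTION', '1') != '0'
print(use_c_api)  # False
```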
@@ -373,15 +373,12 @@ class Tensor(_TensorLike):
"""
graph = self._op._graph._c_graph # pylint: disable=protected-access
if graph and _USE_C_SHAPES:
- with errors.raise_exception_on_not_ok_status() as status:
- num_dims = c_api.TF_GraphGetTensorNumDims(graph, self._as_tf_output(),
- status)
+ num_dims = c_api.TF_GraphGetTensorNumDims(graph, self._as_tf_output())
if num_dims == -1:
dim_list = None
else:
- with errors.raise_exception_on_not_ok_status() as status:
- dim_list = c_api.TF_GraphGetTensorShape_wrapper(
- graph, self._as_tf_output(), num_dims, status)
+ dim_list = c_api.TF_GraphGetTensorShape_wrapper(
+ graph, self._as_tf_output(), num_dims)
dim_list = [None if i == -1 else i for i in dim_list]
return tensor_shape.TensorShape(dim_list)
return self._shape_val
@@ -489,13 +486,11 @@ class Tensor(_TensorLike):
else:
dim_list.append(dim.value)
try:
- with errors.raise_exception_on_not_ok_status() as status:
- c_api.TF_GraphSetTensorShape_wrapper(
- self._op._graph._c_graph, # pylint: disable=protected-access
- self._as_tf_output(),
- dim_list,
- unknown_shape,
- status)
+ c_api.TF_GraphSetTensorShape_wrapper(
+ self._op._graph._c_graph, # pylint: disable=protected-access
+ self._as_tf_output(),
+ dim_list,
+ unknown_shape)
except errors.InvalidArgumentError as e:
# Convert to ValueError for backwards compatibility.
raise ValueError(str(e))
@@ -1514,13 +1509,10 @@ def _create_c_op(graph, node_def, inputs, control_inputs):
serialized = attr_value.SerializeToString()
# TODO(skyewm): this creates and deletes a new TF_Status for every attr.
# It might be worth creating a convenient way to re-use the same status.
- with errors.raise_exception_on_not_ok_status() as status:
- c_api.TF_SetAttrValueProto(op_desc,
- compat.as_str(name), serialized, status)
+ c_api.TF_SetAttrValueProto(op_desc, compat.as_str(name), serialized)
try:
- with errors.raise_exception_on_not_ok_status() as status:
- c_op = c_api.TF_FinishOperation(op_desc, status)
+ c_op = c_api.TF_FinishOperation(op_desc)
except errors.InvalidArgumentError as e:
# Convert to ValueError for backwards compatibility.
raise ValueError(str(e))
@@ -1943,12 +1935,10 @@ class Operation(object):
if self._c_op:
# Reset cached inputs.
self._inputs_val = None
- with errors.raise_exception_on_not_ok_status() as status:
- c_api.UpdateEdge(
- self._graph._c_graph, # pylint: disable=protected-access
- tensor._as_tf_output(), # pylint: disable=protected-access
- self._tf_input(index),
- status)
+ c_api.UpdateEdge(
+ self._graph._c_graph, # pylint: disable=protected-access
+ tensor._as_tf_output(), # pylint: disable=protected-access
+ self._tf_input(index))
else:
self._inputs_val[index].consumers().remove(self)
self._inputs_val[index] = tensor
@@ -2124,6 +2114,30 @@ class Operation(object):
return self._control_inputs_val
@property
+ def _control_outputs(self):
+ """The `Operation` objects which have a control dependency on this op.
+
+ Before any of the ops in `self._control_outputs` can execute, TensorFlow
+ will ensure that `self` has finished executing.
+
+ Returns:
+ A list of `Operation` objects.
+ """
+ if self._c_op:
+ control_c_ops = c_api.TF_OperationGetControlOutputs_wrapper(self._c_op)
+ # pylint: disable=protected-access
+ return [
+ self.graph._get_operation_by_name_unsafe(
+ c_api.TF_OperationName(c_op)) for c_op in control_c_ops
+ ]
+ # pylint: enable=protected-access
+ else:
+ # TODO(apassos): this should be less inefficient.
+ return [o for o in self._graph.get_operations()
+ if self in o.control_inputs]
+
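A minimal usage sketch of the new property, assuming a 1.x-style graph (the property is private, as the leading underscore indicates):

```python
import tensorflow as tf

with tf.Graph().as_default():
  x = tf.constant(1)
  with tf.control_dependencies([x]):
    z = tf.constant(3)
  print(z.op.control_inputs)    # [x.op]: z waits on x
  print(x.op._control_outputs)  # [z.op]: the reverse edge queried above
```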
+ @property
def _control_inputs(self):
logging.warning("Operation._control_inputs is private, use "
"Operation.control_inputs instead. "
@@ -2169,8 +2183,7 @@ class Operation(object):
# pylint: enable=line-too-long
if self._c_op:
with c_api_util.tf_buffer() as buf:
- with errors.raise_exception_on_not_ok_status() as status:
- c_api.TF_OperationToNodeDef(self._c_op, buf, status)
+ c_api.TF_OperationToNodeDef(self._c_op, buf)
data = c_api.TF_GetBuffer(buf)
node_def = node_def_pb2.NodeDef()
node_def.ParseFromString(compat.as_bytes(data))
@@ -2228,11 +2241,9 @@ class Operation(object):
buf = c_api.TF_NewBufferFromString(
compat.as_bytes(attr_value.SerializeToString()))
try:
- with errors.raise_exception_on_not_ok_status() as status:
- # pylint: disable=protected-access
- c_api.SetAttr(self._graph._c_graph, self._c_op, attr_name, buf,
- status)
- # pylint: enable=protected-access
+ # pylint: disable=protected-access
+ c_api.SetAttr(self._graph._c_graph, self._c_op, attr_name, buf)
+ # pylint: enable=protected-access
finally:
c_api.TF_DeleteBuffer(buf)
else:
@@ -2254,8 +2265,7 @@ class Operation(object):
if self._c_op:
try:
with c_api_util.tf_buffer() as buf:
- with errors.raise_exception_on_not_ok_status() as status:
- c_api.TF_OperationGetAttrValueProto(self._c_op, name, buf, status)
+ c_api.TF_OperationGetAttrValueProto(self._c_op, name, buf)
data = c_api.TF_GetBuffer(buf)
except errors.InvalidArgumentError as e:
# Convert to ValueError for backwards compatibility.
@@ -2469,11 +2479,10 @@ def _set_shapes_for_outputs_c_api(op):
# The C API computes the shapes when the TF_Operation is created. Fetch the
# output shapes from the C object.
for output in op.outputs:
- with errors.raise_exception_on_not_ok_status() as status:
- # pylint: disable=protected-access
- shape_vector, unknown_shape = c_api.TF_GraphGetTensorShapeHelper(
- op._graph._c_graph, output._as_tf_output(), status)
- # pylint: enable=protected-access
+ # pylint: disable=protected-access
+ shape_vector, unknown_shape = c_api.TF_GraphGetTensorShapeHelper(
+ op._graph._c_graph, output._as_tf_output())
+ # pylint: enable=protected-access
if unknown_shape:
output.set_shape(tensor_shape.unknown_shape())
elif not shape_vector:
@@ -2994,8 +3003,7 @@ class Graph(object):
# pylint: enable=line-too-long
if self._c_graph:
with c_api_util.tf_buffer() as buf:
- with errors.raise_exception_on_not_ok_status() as status:
- c_api.TF_GraphVersions(self._c_graph, buf, status)
+ c_api.TF_GraphVersions(self._c_graph, buf)
data = c_api.TF_GetBuffer(buf)
version_def = versions_pb2.VersionDef()
version_def.ParseFromString(compat.as_bytes(data))
@@ -3098,8 +3106,7 @@ class Graph(object):
if self._c_graph:
with self._lock:
with c_api_util.tf_buffer() as buf:
- with errors.raise_exception_on_not_ok_status() as status:
- c_api.TF_GraphToGraphDef(self._c_graph, buf, status)
+ c_api.TF_GraphToGraphDef(self._c_graph, buf)
data = c_api.TF_GetBuffer(buf)
graph = graph_pb2.GraphDef()
graph.ParseFromString(compat.as_bytes(data))
@@ -3208,14 +3215,10 @@ class Graph(object):
# remove this when all functions are generated using the C API by default
# as this will be unnecessary.
if not function._c_func:
- with errors.raise_exception_on_not_ok_status() as status:
- serialized = function.definition.SerializeToString()
- function._c_func = c_api.TF_FunctionImportFunctionDef(
- serialized, status)
- with errors.raise_exception_on_not_ok_status() as status:
- gradient = function._grad_func._c_func if function._grad_func else None
- c_api.TF_GraphCopyFunction(self._c_graph, function._c_func, gradient,
- status)
+ serialized = function.definition.SerializeToString()
+ function._c_func = c_api.TF_FunctionImportFunctionDef(serialized)
+ gradient = function._grad_func._c_func if function._grad_func else None
+ c_api.TF_GraphCopyFunction(self._c_graph, function._c_func, gradient)
else:
# If there is already a function with the same name, raise an error
# if bodies are different. Else, do nothing. The C API version above
@@ -3365,8 +3368,12 @@ class Graph(object):
"""
self._check_not_finalized()
ret = Operation(c_op, self)
- assert ret.name not in self._names_in_use
- self._names_in_use[ret.name] = 1
+ # If a name_scope was created with ret.name but no nodes were created in it,
+ # the name will still appear in _names_in_use even though the name hasn't
+ # been used. This is OK; just leave _names_in_use as-is in this case.
+ # TODO(skyewm): make the C API guarantee no name conflicts.
+ if ret.name not in self._names_in_use:
+ self._names_in_use[ret.name] = 1
self._create_op_helper(ret, compute_device=compute_device)
return ret
@@ -3732,11 +3739,9 @@ class Graph(object):
"""Returns the `OpDef` proto for `type`. `type` is a string."""
if self._c_graph:
with c_api_util.tf_buffer() as buf:
- with errors.raise_exception_on_not_ok_status() as status:
- # pylint: disable=protected-access
- c_api.TF_GraphGetOpDef(self._c_graph,
- compat.as_bytes(type), buf, status)
- # pylint: enable=protected-access
+ # pylint: disable=protected-access
+ c_api.TF_GraphGetOpDef(self._c_graph, compat.as_bytes(type), buf)
+ # pylint: enable=protected-access
data = c_api.TF_GetBuffer(buf)
op_def = op_def_pb2.OpDef()
op_def.ParseFromString(compat.as_bytes(data))
@@ -4512,6 +4517,22 @@ class Graph(object):
return tf.matmul(tensor, tensor)
```
+ Also note that while execution of ops created under this scope will trigger
+ execution of their dependencies, ops created under this scope might still
+ be pruned from a normal TensorFlow graph. For example, in the following
+ snippet of code the dependencies are never executed:
+
+ ```python
+ loss = model.loss()
+ with tf.control_dependencies(dependencies):
+ loss = loss + tf.constant(1) # note: dependencies ignored in the
+ # backward pass
+ return tf.gradients(loss, model.variables)
+ ```
+
+ This is because evaluating the gradient graph does not require evaluating
+ the constant(1) op created in the forward pass.
+
Args:
control_inputs: A list of `Operation` or `Tensor` objects which
must be executed or computed before running the operations
@@ -5350,6 +5371,10 @@ def enable_eager_execution(config=None, device_policy=None,
raise ValueError(
"tf.enable_eager_execution must be called at program startup.")
+ # Monkey patch to get rid of an unnecessary conditional since the context is
+ # now initialized.
+ context.context = context.context_safe
+
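The monkey patch swaps in the cheaper accessor once eager execution guarantees the context exists. The general trick, as a standalone sketch with hypothetical names:

```python
_state = None

def get_state():
  # Slow path: tolerates the not-yet-initialized case.
  global _state
  if _state is None:
    _state = object()
  return _state

def get_state_fast():
  # Fast path: only valid once _state is known to be initialized.
  return _state

def initialize():
  global get_state
  get_state()                  # establish the invariant
  get_state = get_state_fast   # all later callers skip the conditional

initialize()
assert get_state() is get_state_fast()
```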
def eager_run(main=None, argv=None):
"""Runs the program with an optional main function and argv list.
diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py
index aa51391871..58bead91ed 100644
--- a/tensorflow/python/framework/ops_test.py
+++ b/tensorflow/python/framework/ops_test.py
@@ -473,6 +473,7 @@ class OperationTest(test_util.TensorFlowTestCase):
self.assertEqual(z.control_inputs, [x, x])
z._add_control_inputs([x, y, y]) # pylint: disable=protected-access
self.assertEqual(z.control_inputs, [x, x, x, y, y])
+ self.assertEqual(x._control_outputs, [z])
def testAddControlInputC(self):
# The C API dedups redundant control edges, pure Python does not
@@ -487,6 +488,7 @@ class OperationTest(test_util.TensorFlowTestCase):
self.assertEqual(z.control_inputs, [x])
z._add_control_inputs([x, y, y]) # pylint: disable=protected-access
self.assertEqual(z.control_inputs, [x, y])
+ self.assertEqual(x._control_outputs, [z])
def testRemoveAllControlInputs(self):
a = constant_op.constant(1)
diff --git a/tensorflow/python/framework/smart_cond.py b/tensorflow/python/framework/smart_cond.py
index c7ff23e4ff..48a834392b 100644
--- a/tensorflow/python/framework/smart_cond.py
+++ b/tensorflow/python/framework/smart_cond.py
@@ -19,7 +19,6 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python import pywrap_tensorflow as c_api
-from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import control_flow_ops
@@ -83,9 +82,8 @@ def smart_constant_value(pred):
# wanted to limit the change hidden behind _USE_C_API).
# pylint: disable=protected-access
if pred_value is None and ops._USE_C_API:
- with errors.raise_exception_on_not_ok_status() as status:
- pred_value = c_api.TF_TryEvaluateConstant_wrapper(
- pred.graph._c_graph, pred._as_tf_output(), status)
+ pred_value = c_api.TF_TryEvaluateConstant_wrapper(pred.graph._c_graph,
+ pred._as_tf_output())
# pylint: enable=protected-access
else:
diff --git a/tensorflow/python/framework/versions.py b/tensorflow/python/framework/versions.py
index 06955b8858..d08b4bf48a 100644
--- a/tensorflow/python/framework/versions.py
+++ b/tensorflow/python/framework/versions.py
@@ -29,7 +29,7 @@ __cxx11_abi_flag__ = pywrap_tensorflow.__cxx11_abi_flag__
__monolithic_build__ = pywrap_tensorflow.__monolithic_build__
VERSION = __version__
-tf_export("VERSION").export_constant(__name__, "VERSION")
+tf_export("VERSION", "__version__").export_constant(__name__, "VERSION")
GIT_VERSION = __git_version__
tf_export("GIT_VERSION").export_constant(__name__, "GIT_VERSION")
COMPILER_VERSION = __compiler_version__
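With the second export name, the same constant becomes reachable as both `tf.VERSION` and `tf.__version__`; illustratively (assuming a build that includes this change):

```python
import tensorflow as tf

assert tf.VERSION == tf.__version__  # both names resolve to versions.VERSION
```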
diff --git a/tensorflow/python/grappler/constant_folding_test.py b/tensorflow/python/grappler/constant_folding_test.py
new file mode 100644
index 0000000000..ab1d0ed25b
--- /dev/null
+++ b/tensorflow/python/grappler/constant_folding_test.py
@@ -0,0 +1,69 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for Grappler Constant Folding."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.client import session
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import functional_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+class ConstantFoldingTest(test.TestCase):
+
+ # See b/76008022.
+ def testScanInsideWhile(self):
+
+ def loop_cond(idx_step, *unused_args):
+ return idx_step < 1
+
+ def loop_body(idx_step, y):
+ x = array_ops.zeros([10, 20, 30], dtype=dtypes.float32)
+ x = functional_ops.scan(
+ math_ops.add,
+ x,
+ initializer=array_ops.zeros([20, 30], dtype=dtypes.float32),
+ back_prop=False,
+ parallel_iterations=1)
+
+ with ops.device('/cpu:0'):
+ y = array_ops.identity(x)
+
+ return idx_step + 1, y
+
+ if test.is_gpu_available(cuda_only=True):
+ init_y = array_ops.zeros([10, 20, 30], dtype=dtypes.float32)
+ _, y = control_flow_ops.while_loop(
+ loop_cond,
+ loop_body,
+ loop_vars=[0, init_y],
+ back_prop=False,
+ parallel_iterations=1)
+ with session.Session() as sess:
+ y_v = sess.run(y)
+ self.assertAllEqual(np.zeros([10, 20, 30]), y_v)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/python/grappler/item.py b/tensorflow/python/grappler/item.py
index 4a083849bd..1748efdd13 100644
--- a/tensorflow/python/grappler/item.py
+++ b/tensorflow/python/grappler/item.py
@@ -51,9 +51,7 @@ class Item(object):
self._BuildTFItem()
def IdentifyImportantOps(self, sort_topologically=False):
- with errors.raise_exception_on_not_ok_status() as status:
- return tf_item.TF_IdentifyImportantOps(self.tf_item, sort_topologically,
- status)
+ return tf_item.TF_IdentifyImportantOps(self.tf_item, sort_topologically)
def GetOpProperties(self):
ret_from_swig = tf_item.TF_GetOpProperties(self.tf_item)
diff --git a/tensorflow/python/grappler/tf_optimizer_test.py b/tensorflow/python/grappler/tf_optimizer_test.py
index 3ee4d7807e..1c0f072dd3 100644
--- a/tensorflow/python/grappler/tf_optimizer_test.py
+++ b/tensorflow/python/grappler/tf_optimizer_test.py
@@ -17,12 +17,16 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.grappler import item as gitem
from tensorflow.python.grappler import tf_optimizer
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
@@ -74,6 +78,47 @@ class PyWrapOptimizeGraphTest(test.TestCase):
self.assertEqual(a2.op.name, optimized_graph.node[3].name)
self.assertEqual('Variable/Assign', optimized_graph.node[4].name)
+ def testLoops(self):
+ g = ops.Graph()
+ with g.as_default():
+
+ def _Cond(_, counter):
+ return counter < end
+
+ def _Body(buf, counter):
+ buf = array_ops.concat([buf, [counter]], 0)
+ counter += 1
+ return [buf, counter]
+
+ start = array_ops.placeholder(shape=[], dtype=dtypes.int32)
+ end = array_ops.placeholder(shape=[], dtype=dtypes.int32)
+ init_buf = array_ops.zeros(shape=[0], dtype=dtypes.int32)
+ loop_vars = [init_buf, start]
+ shape_inv = [
+ tensor_shape.TensorShape([None]),
+ tensor_shape.TensorShape([])
+ ]
+ buf, _ = control_flow_ops.while_loop(_Cond, _Body, loop_vars, shape_inv)
+
+ f = -array_ops.ones_like(buf, optimize=False)
+ buf_shape = array_ops.shape(buf)
+ f_shape = array_ops.shape(f)
+ ops.add_to_collection('train_op', buf_shape)
+ ops.add_to_collection('train_op', f_shape)
+
+ # Optimize the graph.
+ mg = meta_graph.create_meta_graph_def(graph=g)
+ rewriter_config = rewriter_config_pb2.RewriterConfig()
+ optimized_graph = tf_optimizer.OptimizeGraph(rewriter_config, mg)
+ mg.graph_def.CopyFrom(optimized_graph)
+
+ # Check that the nodes referenced in various collections have been preserved
+ item = gitem.Item(mg)
+ props = item.GetOpProperties()
+ buf_prop = props[buf.op.name]
+ f_prop = props[f.op.name]
+ self.assertEqual(buf_prop, f_prop)
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD
index 2a06907f49..57f5097639 100755
--- a/tensorflow/python/keras/BUILD
+++ b/tensorflow/python/keras/BUILD
@@ -637,7 +637,10 @@ py_test(
size = "small",
srcs = ["_impl/keras/utils/io_utils_test.py"],
srcs_version = "PY2AND3",
- tags = ["notsan"],
+ tags = [
+ "no_windows", # TODO: needs investigation on Windows
+ "notsan",
+ ],
deps = [
":keras",
"//tensorflow/python:client_testlib",
diff --git a/tensorflow/python/keras/_impl/keras/activations.py b/tensorflow/python/keras/_impl/keras/activations.py
index 236e17653e..b518898ad8 100644
--- a/tensorflow/python/keras/_impl/keras/activations.py
+++ b/tensorflow/python/keras/_impl/keras/activations.py
@@ -23,6 +23,8 @@ import six
from tensorflow.python.keras._impl.keras import backend as K
from tensorflow.python.keras._impl.keras.utils.generic_utils import deserialize_keras_object
from tensorflow.python.layers.base import Layer
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.tf_export import tf_export
@@ -43,10 +45,10 @@ def softmax(x, axis=-1):
"""
ndim = K.ndim(x)
if ndim == 2:
- return K.softmax(x)
+ return nn.softmax(x)
elif ndim > 2:
- e = K.exp(x - K.max(x, axis=axis, keepdims=True))
- s = K.sum(e, axis=axis, keepdims=True)
+ e = math_ops.exp(x - math_ops.reduce_max(x, axis=axis, keepdims=True))
+ s = math_ops.reduce_sum(e, axis=axis, keepdims=True)
return e / s
else:
raise ValueError('Cannot apply softmax to a tensor that is 1D')
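The `ndim > 2` branch is the standard numerically stable softmax: subtracting the per-axis maximum before exponentiating leaves the result unchanged (the common factor cancels) but prevents overflow. The same computation in NumPy, as a sketch:

```python
import numpy as np

def stable_softmax(x, axis=-1):
  # exp(x - max) / sum(exp(x - max)) == exp(x) / sum(exp(x)),
  # but the shifted form never exponentiates a large positive value.
  e = np.exp(x - np.max(x, axis=axis, keepdims=True))
  return e / np.sum(e, axis=axis, keepdims=True)

x = np.array([[1000., 1001., 1002.]])  # naive exp(x) overflows here
print(stable_softmax(x))  # [[0.09003057 0.24472847 0.66524096]]
```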
@@ -79,12 +81,12 @@ def selu(x):
@tf_export('keras.activations.softplus')
def softplus(x):
- return K.softplus(x)
+ return nn.softplus(x)
@tf_export('keras.activations.softsign')
def softsign(x):
- return K.softsign(x)
+ return nn.softsign(x)
@tf_export('keras.activations.relu')
@@ -94,12 +96,12 @@ def relu(x, alpha=0., max_value=None):
@tf_export('keras.activations.tanh')
def tanh(x):
- return K.tanh(x)
+ return nn.tanh(x)
@tf_export('keras.activations.sigmoid')
def sigmoid(x):
- return K.sigmoid(x)
+ return nn.sigmoid(x)
@tf_export('keras.activations.hard_sigmoid')
diff --git a/tensorflow/python/keras/_impl/keras/applications/imagenet_utils.py b/tensorflow/python/keras/_impl/keras/applications/imagenet_utils.py
index c26a28ed40..d928a7afdc 100644
--- a/tensorflow/python/keras/_impl/keras/applications/imagenet_utils.py
+++ b/tensorflow/python/keras/_impl/keras/applications/imagenet_utils.py
@@ -22,8 +22,10 @@ import json
import numpy as np
+from tensorflow.python.framework import constant_op
from tensorflow.python.keras._impl.keras import backend as K
from tensorflow.python.keras._impl.keras.utils.data_utils import get_file
+from tensorflow.python.ops import math_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.tf_export import tf_export
@@ -151,11 +153,11 @@ def _preprocess_symbolic_input(x, data_format, mode):
std = None
if _IMAGENET_MEAN is None:
- _IMAGENET_MEAN = K.constant(-np.array(mean))
+ _IMAGENET_MEAN = constant_op.constant(-np.array(mean), dtype=K.floatx())
# Zero-center by mean pixel
if K.dtype(x) != K.dtype(_IMAGENET_MEAN):
- x = K.bias_add(x, K.cast(_IMAGENET_MEAN, K.dtype(x)), data_format)
+ x = K.bias_add(x, math_ops.cast(_IMAGENET_MEAN, K.dtype(x)), data_format)
else:
x = K.bias_add(x, _IMAGENET_MEAN, data_format)
if std is not None:
diff --git a/tensorflow/python/keras/_impl/keras/backend.py b/tensorflow/python/keras/_impl/keras/backend.py
index 7baf27642a..3aac6a9065 100644
--- a/tensorflow/python/keras/_impl/keras/backend.py
+++ b/tensorflow/python/keras/_impl/keras/backend.py
@@ -34,6 +34,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes as dtypes_module
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import tensor_util
from tensorflow.python.layers import base as tf_base_layers
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
@@ -2795,6 +2796,8 @@ class Function(object):
else:
feed_dict = {}
+ session = get_session()
+ data_tensors_to_feed = []
for tensor, value in zip(self.inputs, inputs):
if value is None:
continue
@@ -2803,9 +2806,20 @@ class Function(object):
indices = np.concatenate((np.expand_dims(sparse_coo.row, 1),
np.expand_dims(sparse_coo.col, 1)), 1)
value = (indices, sparse_coo.data, sparse_coo.shape)
- feed_dict[tensor] = value
+ elif tensor_util.is_tensor(value):
+ data_tensors_to_feed.append((tensor, value))
+ else:
+ feed_dict[tensor] = value
+
+ if data_tensors_to_feed:
+ # This is a *temporary* workaround (i.e. hack) to feed a symbolic tensor
+ # to `feed_dict`. It is very inefficient. It will be removed as soon
+ # as it becomes possible to pass symbolic tensors to `feed_dict`.
+ data_tensor_values = session.run([x[1] for x in data_tensors_to_feed])
+ for i, v in enumerate(data_tensor_values):
+ feed_dict[data_tensors_to_feed[i][0]] = v
+
fetches = self.outputs + [self.updates_op] + self.fetches
- session = get_session()
updated = session.run(
fetches=fetches, feed_dict=feed_dict, **self.session_kwargs)
return updated[:len(self.outputs)]
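The workaround splits the inputs into plain values, which are fed directly, and symbolic tensors, which are first materialized with an extra `session.run` and then fed as concrete values; that extra run is the inefficiency the comment warns about. A self-contained 1.x-style sketch of the two-pass feeding:

```python
import tensorflow as tf

sess = tf.Session()
p = tf.placeholder(tf.float32, shape=(2,))
out = p * 2.

inputs = {p: tf.constant([1., 2.])}  # a symbolic tensor, not a numpy array

feed_dict = {}
to_materialize = []
for tensor, value in inputs.items():
  if isinstance(value, tf.Tensor):
    to_materialize.append((tensor, value))  # cannot go in feed_dict as-is
  else:
    feed_dict[tensor] = value

# The extra session.run that materializes symbolic inputs first.
values = sess.run([t for _, t in to_materialize])
for (tensor, _), v in zip(to_materialize, values):
  feed_dict[tensor] = v

print(sess.run(out, feed_dict=feed_dict))  # [2. 4.]
```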
diff --git a/tensorflow/python/keras/_impl/keras/constraints.py b/tensorflow/python/keras/_impl/keras/constraints.py
index 271fbbb63d..abe95d8e0c 100644
--- a/tensorflow/python/keras/_impl/keras/constraints.py
+++ b/tensorflow/python/keras/_impl/keras/constraints.py
@@ -24,6 +24,7 @@ import six
from tensorflow.python.keras._impl.keras import backend as K
from tensorflow.python.keras._impl.keras.utils.generic_utils import deserialize_keras_object
from tensorflow.python.keras._impl.keras.utils.generic_utils import serialize_keras_object
+from tensorflow.python.ops import math_ops
from tensorflow.python.util.tf_export import tf_export
@@ -65,7 +66,8 @@ class MaxNorm(Constraint):
self.axis = axis
def __call__(self, w):
- norms = K.sqrt(K.sum(K.square(w), axis=self.axis, keepdims=True))
+ norms = K.sqrt(
+ math_ops.reduce_sum(math_ops.square(w), axis=self.axis, keepdims=True))
desired = K.clip(norms, 0, self.max_value)
return w * (desired / (K.epsilon() + norms))
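The constraint computes per-axis norms, clips them to `max_value`, and rescales the weights proportionally; `epsilon` guards the division when a norm is zero. The same arithmetic in NumPy, as a sketch:

```python
import numpy as np

def max_norm(w, max_value=2., axis=0, epsilon=1e-7):
  norms = np.sqrt(np.sum(np.square(w), axis=axis, keepdims=True))
  desired = np.clip(norms, 0, max_value)
  return w * (desired / (epsilon + norms))

w = np.array([[3., 0.], [4., 1.]])  # column norms: 5.0 and 1.0
print(np.linalg.norm(max_norm(w), axis=0))  # ~[2. 1.]: only the big column shrinks
```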
@@ -79,7 +81,7 @@ class NonNeg(Constraint):
"""
def __call__(self, w):
- return w * K.cast(K.greater_equal(w, 0.), K.floatx())
+ return w * math_ops.cast(math_ops.greater_equal(w, 0.), K.floatx())
@tf_export('keras.constraints.UnitNorm', 'keras.constraints.unit_norm')
@@ -105,7 +107,9 @@ class UnitNorm(Constraint):
def __call__(self, w):
return w / (
- K.epsilon() + K.sqrt(K.sum(K.square(w), axis=self.axis, keepdims=True)))
+ K.epsilon() + K.sqrt(
+ math_ops.reduce_sum(
+ math_ops.square(w), axis=self.axis, keepdims=True)))
def get_config(self):
return {'axis': self.axis}
@@ -148,7 +152,8 @@ class MinMaxNorm(Constraint):
self.axis = axis
def __call__(self, w):
- norms = K.sqrt(K.sum(K.square(w), axis=self.axis, keepdims=True))
+ norms = K.sqrt(
+ math_ops.reduce_sum(math_ops.square(w), axis=self.axis, keepdims=True))
desired = (
self.rate * K.clip(norms, self.min_value, self.max_value) +
(1 - self.rate) * norms)
diff --git a/tensorflow/python/keras/_impl/keras/engine/topology_test.py b/tensorflow/python/keras/_impl/keras/engine/topology_test.py
index b50277c8ff..9ab4b6fdcf 100644
--- a/tensorflow/python/keras/_impl/keras/engine/topology_test.py
+++ b/tensorflow/python/keras/_impl/keras/engine/topology_test.py
@@ -783,7 +783,7 @@ class TopologyConstructionTest(test.TestCase):
def test_activity_regularization_with_model_composition(self):
def reg(x):
- return keras.backend.sum(x)
+ return math_ops.reduce_sum(x)
net_a_input = keras.Input((2,))
net_a = net_a_input
diff --git a/tensorflow/python/keras/_impl/keras/engine/training.py b/tensorflow/python/keras/_impl/keras/engine/training.py
index 971245c162..71de657da8 100644
--- a/tensorflow/python/keras/_impl/keras/engine/training.py
+++ b/tensorflow/python/keras/_impl/keras/engine/training.py
@@ -1181,6 +1181,9 @@ class Model(Network):
batch_size=batch_size)
elif validation_split and 0. < validation_split < 1.:
+ if training_utils.has_symbolic_tensors(x):
+ raise ValueError('If your data is in the form of symbolic tensors, '
+ 'you cannot use `validation_split`.')
if hasattr(x[0], 'shape'):
split_at = int(x[0].shape[0] * (1. - validation_split))
else:
diff --git a/tensorflow/python/keras/_impl/keras/engine/training_eager.py b/tensorflow/python/keras/_impl/keras/engine/training_eager.py
index 67858a578c..4cdb5f108a 100644
--- a/tensorflow/python/keras/_impl/keras/engine/training_eager.py
+++ b/tensorflow/python/keras/_impl/keras/engine/training_eager.py
@@ -31,9 +31,8 @@ from tensorflow.python.keras._impl.keras import callbacks as cbks
from tensorflow.python.keras._impl.keras import losses
from tensorflow.python.keras._impl.keras import metrics as metrics_module
from tensorflow.python.keras._impl.keras.engine import training_utils
-from tensorflow.python.keras._impl.keras.utils.generic_utils import make_batches
-from tensorflow.python.keras._impl.keras.utils.generic_utils import Progbar
-from tensorflow.python.keras._impl.keras.utils.generic_utils import slice_arrays
+from tensorflow.python.keras._impl.keras.utils import generic_utils
+from tensorflow.python.ops import array_ops
from tensorflow.python.platform import tf_logging as logging
@@ -173,6 +172,41 @@ def _model_loss(model, inputs, targets, sample_weights=None, training=False):
return outs, total_loss, loss_metrics
+def slice_arrays(arrays, indices, contiguous=True):
+ """Slices batches out of provided arrays (workaround for eager tensors).
+
+ Unfortunately eager tensors don't have the same slicing behavior as
+ Numpy arrays (they follow the same slicing behavior as symbolic TF tensors),
+ hence we cannot use `generic_utils.slice_arrays` directly
+ and we have to implement this workaround based on `concat`. This has a
+ performance cost.
+
+ Arguments:
+ arrays: Single array or list of arrays.
+ indices: List of indices in the array that should be included in the output
+ batch.
+ contiguous: Boolean flag indicating whether the indices are contiguous.
+
+ Returns:
+ Slice of data (either single array or list of arrays).
+ """
+ if any(tensor_util.is_tensor(x) for x in arrays):
+ converted_to_list = False
+ if not isinstance(arrays, list):
+ converted_to_list = True
+ arrays = [arrays]
+ if not contiguous:
+ entries = [[x[i:i + 1] for i in indices] for x in arrays]
+ slices = [array_ops.concat(x, axis=0) for x in entries]
+ else:
+ slices = [x[indices[0]:indices[-1] + 1] for x in arrays]
+ if converted_to_list:
+ slices = slices[0]
+ return slices
+ else:
+ return generic_utils.slice_arrays(arrays, indices)
+
+
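The reason the workaround exists: TF tensors reject NumPy-style fancy indexing with a list of indices, so a non-contiguous batch has to be rebuilt from one-element slices with `concat`. A sketch of the difference:

```python
import numpy as np
import tensorflow as tf

tf.enable_eager_execution()

a = np.arange(6.)
t = tf.constant(a)
idx = [0, 3, 5]  # non-contiguous batch indices

print(a[idx])  # NumPy fancy indexing works: [0. 3. 5.]
# t[idx] raises: tensors only support slice-style indexing, hence the
# gather-and-concat workaround above.
print(tf.concat([t[i:i + 1] for i in idx], axis=0))  # [0. 3. 5.]
```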
def _process_single_batch(model,
inputs,
targets,
@@ -270,9 +304,8 @@ def test_on_batch(model, inputs, targets, sample_weights=None):
model, inputs, targets, sample_weights=sample_weights, training=False)
if not isinstance(outs, list):
outs = [outs]
- metric_names, metrics_results = _eager_metrics_fn(
+ _, metrics_results = _eager_metrics_fn(
model, outs, targets)
- model.metrics_names.append(metric_names)
if not isinstance(loss, list):
loss = [loss]
return loss + loss_metrics + metrics_results
@@ -328,6 +361,12 @@ def fit_loop(
Raises:
ValueError: In case of invalid argument values.
"""
+ if not batch_size:
+ raise ValueError('With eager execution, `batch_size` should be specified.')
+ if steps_per_epoch or validation_steps:
+ raise ValueError('With eager execution, `steps_per_epoch` and '
+ '`validation_steps` are not valid arguments '
+ '(set `batch_size` instead).')
# Required for Eager mode
with backend.learning_phase_scope(1):
do_validation = False
@@ -410,15 +449,18 @@ def fit_loop(
elif shuffle:
np.random.shuffle(index_array)
- batches = make_batches(num_train_samples, batch_size)
+ batches = generic_utils.make_batches(num_train_samples, batch_size)
for batch_index, (batch_start, batch_end) in enumerate(batches):
batch_ids = index_array[batch_start:batch_end]
try:
- inputs_batch = slice_arrays(inputs, batch_ids)
- targets_batch = slice_arrays(targets, batch_ids)
+ inputs_batch = slice_arrays(inputs, batch_ids,
+ contiguous=not shuffle)
+ targets_batch = slice_arrays(targets, batch_ids,
+ contiguous=not shuffle)
if sample_weights:
- sample_weights_batch = slice_arrays(sample_weights, batch_ids)
+ sample_weights_batch = slice_arrays(sample_weights, batch_ids,
+ contiguous=not shuffle)
else:
sample_weights_batch = None
except TypeError:
@@ -539,8 +581,8 @@ def test_loop(model, inputs, targets,
feed_data, batch_size=batch_size, steps=steps, steps_name='steps')
outs = []
if verbose == 1:
- progbar = Progbar(target=num_samples)
- batches = make_batches(num_samples, batch_size)
+ progbar = generic_utils.Progbar(target=num_samples)
+ batches = generic_utils.make_batches(num_samples, batch_size)
index_array = np.arange(num_samples)
for batch_index, (batch_start, batch_end) in enumerate(batches):
batch_ids = index_array[batch_start:batch_end]
@@ -620,12 +662,12 @@ def predict_loop(model, inputs,
inputs, batch_size, steps, 'steps')
if verbose == 1:
if steps is not None:
- progbar = Progbar(target=steps)
+ progbar = generic_utils.Progbar(target=steps)
else:
- progbar = Progbar(target=num_samples)
+ progbar = generic_utils.Progbar(target=num_samples)
outs = []
- batches = make_batches(num_samples, batch_size)
+ batches = generic_utils.make_batches(num_samples, batch_size)
index_array = np.arange(num_samples)
for batch_index, (batch_start, batch_end) in enumerate(batches):
batch_ids = index_array[batch_start:batch_end]
diff --git a/tensorflow/python/keras/_impl/keras/engine/training_eager_test.py b/tensorflow/python/keras/_impl/keras/engine/training_eager_test.py
index 8848b393d5..6cdb6b0753 100644
--- a/tensorflow/python/keras/_impl/keras/engine/training_eager_test.py
+++ b/tensorflow/python/keras/_impl/keras/engine/training_eager_test.py
@@ -18,7 +18,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import os
import numpy as np
from tensorflow.python.framework import ops
@@ -308,6 +307,100 @@ class TrainingTest(test.TestCase):
model.compile(loss=None,
optimizer='rms')
+ def test_model_methods_with_eager_tensors_multi_io(self):
+ a = keras.layers.Input(shape=(3,), name='input_a')
+ b = keras.layers.Input(shape=(3,), name='input_b')
+
+ dense = keras.layers.Dense(4, name='dense')
+ c = dense(a)
+ d = dense(b)
+ e = keras.layers.Dropout(0.5, name='dropout')(c)
+
+ model = keras.models.Model([a, b], [d, e])
+
+ optimizer = RMSPropOptimizer(learning_rate=0.001)
+ loss = 'mse'
+ loss_weights = [1., 0.5]
+ metrics = ['mae']
+ model.compile(
+ optimizer,
+ loss,
+ metrics=metrics,
+ loss_weights=loss_weights,
+ sample_weight_mode=None)
+
+ input_a = keras.backend.zeros(shape=(10, 3))
+ input_b = keras.backend.zeros(shape=(10, 3))
+ target_d = keras.backend.zeros(shape=(10, 4))
+ target_e = keras.backend.zeros(shape=(10, 4))
+
+ model.fit(
+ [input_a, input_b], [target_d, target_e],
+ epochs=1,
+ batch_size=5,
+ verbose=0)
+ # Test: no shuffle.
+ model.fit(
+ [input_a, input_b], [target_d, target_e],
+ epochs=1,
+ batch_size=5,
+ verbose=0,
+ shuffle=False)
+ # Test: validation data.
+ model.fit([input_a, input_b], [target_d, target_e],
+ epochs=1, batch_size=2, verbose=0,
+ validation_data=([input_a, input_b], [target_d, target_e]))
+ model.train_on_batch([input_a, input_b], [target_d, target_e])
+ model.predict([input_a, input_b], batch_size=5)
+ model.evaluate([input_a, input_b], [target_d, target_e],
+ batch_size=2, verbose=0)
+ model.test_on_batch([input_a, input_b], [target_d, target_e])
+
+ # Test: mix np and tensors.
+ input_b = np.zeros(shape=(10, 3)).astype('float32')
+ target_e = np.zeros(shape=(10, 4)).astype('float32')
+ model.fit(
+ [input_a, input_b], [target_d, target_e],
+ epochs=1,
+ batch_size=5,
+ verbose=0)
+ model.fit([input_a, input_b], [target_d, target_e],
+ epochs=1, batch_size=2, verbose=0,
+ validation_data=([input_a, input_b], [target_d, target_e]))
+ model.fit(
+ [input_a, input_b], [target_d, target_e],
+ epochs=1,
+ batch_size=5,
+ verbose=0,
+ shuffle=False)
+ model.train_on_batch([input_a, input_b], [target_d, target_e])
+ model.predict([input_a, input_b], batch_size=5)
+ model.evaluate([input_a, input_b], [target_d, target_e],
+ batch_size=2, verbose=0)
+ model.test_on_batch([input_a, input_b], [target_d, target_e])
+
+ def test_model_methods_with_eager_tensors_single_io(self):
+ x = keras.layers.Input(shape=(3,), name='input')
+ y = keras.layers.Dense(4, name='dense')(x)
+ model = keras.Model(x, y)
+
+ optimizer = RMSPropOptimizer(learning_rate=0.001)
+ loss = 'mse'
+ metrics = ['mae']
+ model.compile(optimizer, loss, metrics=metrics)
+
+ inputs = keras.backend.zeros(shape=(10, 3))
+ targets = keras.backend.zeros(shape=(10, 4))
+
+ model.fit(inputs, targets, epochs=1, batch_size=2, verbose=0)
+ model.fit(inputs, targets, epochs=1, batch_size=3, verbose=0, shuffle=False)
+ model.fit(inputs, targets, epochs=1, batch_size=4, verbose=0,
+ validation_data=(inputs, targets))
+ model.evaluate(inputs, targets, batch_size=2, verbose=0)
+ model.predict(inputs, batch_size=2)
+ model.train_on_batch(inputs, targets)
+ model.test_on_batch(inputs, targets)
+
class LossWeightingTest(test.TestCase):
@@ -533,14 +626,5 @@ class LossWeightingTest(test.TestCase):
if __name__ == '__main__':
- # Bazel sets these environment variables to very long paths.
- # Tempfile uses them to create long paths, and in turn multiprocessing
- # library tries to create sockets named after paths. Delete whatever bazel
- # writes to these to avoid tests failing due to socket addresses being too
- # long.
- for var in ('TMPDIR', 'TMP', 'TEMP'):
- if var in os.environ:
- del os.environ[var]
-
ops.enable_eager_execution()
test.main()
diff --git a/tensorflow/python/keras/_impl/keras/engine/training_test.py b/tensorflow/python/keras/_impl/keras/engine/training_test.py
index fd91dbba52..08fd26dd18 100644
--- a/tensorflow/python/keras/_impl/keras/engine/training_test.py
+++ b/tensorflow/python/keras/_impl/keras/engine/training_test.py
@@ -1117,6 +1117,121 @@ class TestTrainingUtils(test.TestCase):
class TestTrainingWithDataTensors(test.TestCase):
+ def test_training_and_eval_methods_on_symbolic_tensors_single_io(self):
+ with self.test_session():
+ x = keras.layers.Input(shape=(3,), name='input')
+ y = keras.layers.Dense(4, name='dense')(x)
+ model = keras.Model(x, y)
+
+ optimizer = 'rmsprop'
+ loss = 'mse'
+ metrics = ['mae']
+ model.compile(optimizer, loss, metrics=metrics)
+
+ inputs = keras.backend.zeros(shape=(10, 3))
+ targets = keras.backend.zeros(shape=(10, 4))
+
+ model.fit(inputs, targets, epochs=1, steps_per_epoch=2, verbose=0)
+ model.evaluate(inputs, targets, steps=2, verbose=0)
+ model.predict(inputs, steps=2)
+ model.train_on_batch(inputs, targets)
+ model.test_on_batch(inputs, targets)
+ model.fit(inputs, targets,
+ epochs=1, steps_per_epoch=2, verbose=0,
+ validation_data=(inputs, targets), validation_steps=2)
+
+ def test_training_and_eval_methods_on_symbolic_tensors_multi_io(self):
+ with self.test_session():
+ a = keras.layers.Input(shape=(3,), name='input_a')
+ b = keras.layers.Input(shape=(3,), name='input_b')
+
+ dense = keras.layers.Dense(4, name='dense')
+ c = dense(a)
+ d = dense(b)
+ e = keras.layers.Dropout(0.5, name='dropout')(c)
+
+ model = keras.models.Model([a, b], [d, e])
+
+ optimizer = 'rmsprop'
+ loss = 'mse'
+ loss_weights = [1., 0.5]
+ metrics = ['mae']
+ model.compile(optimizer, loss, metrics=metrics, loss_weights=loss_weights)
+
+ input_a_tf = keras.backend.zeros(shape=(10, 3))
+ input_b_tf = keras.backend.zeros(shape=(10, 3))
+
+ output_d_tf = keras.backend.zeros(shape=(10, 4))
+ output_e_tf = keras.backend.zeros(shape=(10, 4))
+
+ model.fit(
+ [input_a_tf, input_b_tf], [output_d_tf, output_e_tf],
+ epochs=1,
+ steps_per_epoch=2,
+ verbose=0)
+ with self.assertRaisesRegexp(ValueError,
+ 'should specify the `steps_per_epoch`'):
+ model.fit(
+ [input_a_tf, input_b_tf], [output_d_tf, output_e_tf],
+ epochs=1,
+ batch_size=5,
+ verbose=0)
+ model.train_on_batch([input_a_tf, input_b_tf], [output_d_tf, output_e_tf])
+
+ # Test with dictionary inputs
+ model.fit(
+ {'input_a': input_a_tf,
+ 'input_b': input_b_tf},
+ {'dense': output_d_tf,
+ 'dropout': output_e_tf},
+ epochs=1,
+ steps_per_epoch=2,
+ verbose=0)
+ model.fit(
+ {'input_a': input_a_tf,
+ 'input_b': input_b_tf},
+ {'dense': output_d_tf,
+ 'dropout': output_e_tf},
+ validation_data=({'input_a': input_a_tf,
+ 'input_b': input_b_tf},
+ {'dense': output_d_tf,
+ 'dropout': output_e_tf}),
+ epochs=1,
+ steps_per_epoch=2,
+ validation_steps=2,
+ verbose=0)
+ model.train_on_batch(
+ {'input_a': input_a_tf,
+ 'input_b': input_b_tf},
+ {'dense': output_d_tf,
+ 'dropout': output_e_tf})
+
+ # Test with validation data
+ model.fit(
+ [input_a_tf, input_b_tf], [output_d_tf, output_e_tf],
+ validation_data=([input_a_tf, input_b_tf],
+ [output_d_tf, output_e_tf]),
+ epochs=1,
+ steps_per_epoch=2,
+ validation_steps=2,
+ verbose=0)
+ # Test with validation split
+ with self.assertRaisesRegexp(ValueError,
+ 'you cannot use `validation_split`'):
+ model.fit(
+ [input_a_tf, input_b_tf], [output_d_tf, output_e_tf],
+ epochs=2,
+ steps_per_epoch=2,
+ verbose=0,
+ validation_split=0.2,
+ validation_steps=2)
+
+ # Test evaluation / prediction methods
+ model.evaluate([input_a_tf, input_b_tf], [output_d_tf, output_e_tf],
+ steps=2, verbose=0)
+ model.predict([input_a_tf, input_b_tf], steps=2)
+ model.test_on_batch([input_a_tf, input_b_tf], [output_d_tf, output_e_tf])
+
def test_model_with_input_feed_tensor(self):
"""We test building a model with a TF variable as input.
diff --git a/tensorflow/python/keras/_impl/keras/engine/training_utils.py b/tensorflow/python/keras/_impl/keras/engine/training_utils.py
index 105638ce10..a3fc8ef2a0 100644
--- a/tensorflow/python/keras/_impl/keras/engine/training_utils.py
+++ b/tensorflow/python/keras/_impl/keras/engine/training_utils.py
@@ -22,9 +22,11 @@ import copy
import numpy as np
+from tensorflow.python.eager import context
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras._impl.keras import backend as K
from tensorflow.python.keras._impl.keras import losses
+from tensorflow.python.ops import math_ops
def check_num_samples(ins,
@@ -64,15 +66,29 @@ def check_num_samples(ins,
if batch_size is not None:
raise ValueError(
'If ' + steps_name + ' is set, the `batch_size` must be None.')
- elif ins and hasattr(ins[0], 'shape'):
- num_samples = ins[0].shape[0]
- else:
+ if has_symbolic_tensors(ins) and steps is None:
+ raise ValueError('If your data is in the form of symbolic tensors, '
+ 'you should specify the `' + steps_name + '` argument '
+ '(instead of the `batch_size` argument).')
+ if ins and hasattr(ins[0], 'shape'):
+ num_samples = int(ins[0].shape[0])
+ elif steps is None:
raise ValueError(
'Either the input data should have '
'a defined shape, or ' + steps_name + ' should be specified.')
return num_samples
+def standardize_single_array(x):
+ if x is None:
+ return None
+ elif tensor_util.is_tensor(x):
+ return x
+ elif x.ndim == 1:
+ x = np.expand_dims(x, 1)
+ return x
+
+
def standardize_input_data(data,
names,
shapes=None,
@@ -130,9 +146,7 @@ def standardize_input_data(data,
else:
data = data.values if data.__class__.__name__ == 'DataFrame' else data
data = [data]
- data = [
- np.expand_dims(x, 1) if x is not None and x.ndim == 1 else x for x in data
- ]
+ data = [standardize_single_array(x) for x in data]
if len(data) != len(names):
if data and hasattr(data[0], 'shape'):
@@ -158,7 +172,7 @@ def standardize_input_data(data,
# Check shapes compatibility.
if shapes:
for i in range(len(names)):
- if shapes[i] is not None:
+ if shapes[i] is not None and not tensor_util.is_tensor(data[i]):
data_shape = data[i].shape
shape = shapes[i]
if data[i].ndim != len(shape):
@@ -245,12 +259,13 @@ def check_array_lengths(inputs, targets, weights=None):
"""
def set_of_lengths(x):
- # return a set with the variation between
+ # Returns a set with the variation between
# different shapes, with None => 0
if x is None:
return {}
else:
- return set([y.shape[0] for y in x if y is not None])
+ return set([y.shape[0] for y in x
+ if y is not None and not tensor_util.is_tensor(y)])
set_x = set_of_lengths(inputs)
set_y = set_of_lengths(targets)
@@ -422,7 +437,7 @@ def weighted_masked_objective(fn):
score_array = fn(y_true, y_pred)
if mask is not None:
# Cast the mask to floatX to avoid float64 upcasting in theano
- mask = K.cast(mask, K.floatx())
+ mask = math_ops.cast(mask, K.floatx())
# mask should have the same shape as score_array
score_array *= mask
# the loss per batch should be proportional
@@ -436,7 +451,8 @@ def weighted_masked_objective(fn):
weight_ndim = K.ndim(weights)
score_array = K.mean(score_array, axis=list(range(weight_ndim, ndim)))
score_array *= weights
- score_array /= K.mean(K.cast(K.not_equal(weights, 0), K.floatx()))
+ score_array /= K.mean(
+ math_ops.cast(math_ops.not_equal(weights, 0), K.floatx()))
return K.mean(score_array)
return weighted
@@ -532,3 +548,8 @@ def standardize_weights(y,
return weights
else:
return None
+
+
+def has_symbolic_tensors(ls):
+ return (any(tensor_util.is_tensor(v) for v in ls)
+ and not context.executing_eagerly())
diff --git a/tensorflow/python/keras/_impl/keras/estimator.py b/tensorflow/python/keras/_impl/keras/estimator.py
index 8426d84df9..5d370ebbb5 100644
--- a/tensorflow/python/keras/_impl/keras/estimator.py
+++ b/tensorflow/python/keras/_impl/keras/estimator.py
@@ -466,8 +466,8 @@ def model_to_estimator(keras_model=None,
keras_model_fn, model_dir=model_dir, config=config)
# Pass the config into keras backend's default session.
- with session.Session(config=estimator._session_config) as sess:
- K.set_session(sess)
+ sess = session.Session(config=estimator._session_config)
+ K.set_session(sess)
keras_weights = keras_model.get_weights()
if keras_model._is_graph_network:
diff --git a/tensorflow/python/keras/_impl/keras/layers/advanced_activations.py b/tensorflow/python/keras/_impl/keras/layers/advanced_activations.py
index c40ee109aa..11ca89d625 100644
--- a/tensorflow/python/keras/_impl/keras/layers/advanced_activations.py
+++ b/tensorflow/python/keras/_impl/keras/layers/advanced_activations.py
@@ -26,6 +26,7 @@ from tensorflow.python.keras._impl.keras import regularizers
from tensorflow.python.keras._impl.keras.engine import InputSpec
from tensorflow.python.keras._impl.keras.engine import Layer
from tensorflow.python.keras._impl.keras.engine.base_layer import shape_type_conversion
+from tensorflow.python.ops import math_ops
from tensorflow.python.util.tf_export import tf_export
@@ -146,7 +147,7 @@ class PReLU(Layer):
if K.backend() == 'theano':
neg = (
K.pattern_broadcast(self.alpha, self.param_broadcast) *
- (inputs - K.abs(inputs)) * 0.5)
+ (inputs - math_ops.abs(inputs)) * 0.5)
else:
neg = -self.alpha * K.relu(-inputs)
return pos + neg
@@ -232,7 +233,8 @@ class ThresholdedReLU(Layer):
self.theta = K.cast_to_floatx(theta)
def call(self, inputs, mask=None):
- return inputs * K.cast(K.greater(inputs, self.theta), K.floatx())
+ return inputs * math_ops.cast(
+ math_ops.greater(inputs, self.theta), K.floatx())
def get_config(self):
config = {'theta': float(self.theta)}
diff --git a/tensorflow/python/keras/_impl/keras/layers/convolutional_recurrent.py b/tensorflow/python/keras/_impl/keras/layers/convolutional_recurrent.py
index d95a094245..b78962d66a 100644
--- a/tensorflow/python/keras/_impl/keras/layers/convolutional_recurrent.py
+++ b/tensorflow/python/keras/_impl/keras/layers/convolutional_recurrent.py
@@ -29,6 +29,8 @@ from tensorflow.python.keras._impl.keras.engine import InputSpec
from tensorflow.python.keras._impl.keras.engine.base_layer import shape_type_conversion
from tensorflow.python.keras._impl.keras.layers.recurrent import Recurrent
from tensorflow.python.keras._impl.keras.utils import conv_utils
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
from tensorflow.python.util.tf_export import tf_export
@@ -438,9 +440,9 @@ class ConvLSTM2D(ConvRecurrent2D):
def get_initial_state(self, inputs):
# (samples, timesteps, rows, cols, filters)
- initial_state = K.zeros_like(inputs)
+ initial_state = array_ops.zeros_like(inputs)
# (samples, rows, cols, filters)
- initial_state = K.sum(initial_state, axis=1)
+ initial_state = math_ops.reduce_sum(initial_state, axis=1)
shape = list(self.kernel_shape)
shape[-1] = self.filters
initial_state = self.input_conv(
@@ -483,8 +485,8 @@ class ConvLSTM2D(ConvRecurrent2D):
def get_constants(self, inputs, training=None):
constants = []
if self.implementation == 0 and 0 < self.dropout < 1:
- ones = K.zeros_like(inputs)
- ones = K.sum(ones, axis=1)
+ ones = array_ops.zeros_like(inputs)
+ ones = math_ops.reduce_sum(ones, axis=1)
ones += 1
def dropped_inputs():
@@ -501,8 +503,8 @@ class ConvLSTM2D(ConvRecurrent2D):
if 0 < self.recurrent_dropout < 1:
shape = list(self.kernel_shape)
shape[-1] = self.filters
- ones = K.zeros_like(inputs)
- ones = K.sum(ones, axis=1)
+ ones = array_ops.zeros_like(inputs)
+ ones = math_ops.reduce_sum(ones, axis=1)
ones = self.input_conv(ones, K.zeros(shape), padding=self.padding)
ones += 1.
diff --git a/tensorflow/python/keras/_impl/keras/layers/core.py b/tensorflow/python/keras/_impl/keras/layers/core.py
index 73e4f15f7e..c74fc1e4c0 100644
--- a/tensorflow/python/keras/_impl/keras/layers/core.py
+++ b/tensorflow/python/keras/_impl/keras/layers/core.py
@@ -37,6 +37,8 @@ from tensorflow.python.keras._impl.keras.utils.generic_utils import func_dump
from tensorflow.python.keras._impl.keras.utils.generic_utils import func_load
from tensorflow.python.keras._impl.keras.utils.generic_utils import has_arg
from tensorflow.python.layers import core as tf_core_layers
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
from tensorflow.python.util.tf_export import tf_export
@@ -75,12 +77,12 @@ class Masking(Layer):
self.mask_value = mask_value
def compute_mask(self, inputs, mask=None):
- return K.any(K.not_equal(inputs, self.mask_value), axis=-1)
+ return K.any(math_ops.not_equal(inputs, self.mask_value), axis=-1)
def call(self, inputs):
boolean_mask = K.any(
- K.not_equal(inputs, self.mask_value), axis=-1, keepdims=True)
- return inputs * K.cast(boolean_mask, inputs.dtype)
+ math_ops.not_equal(inputs, self.mask_value), axis=-1, keepdims=True)
+ return inputs * math_ops.cast(boolean_mask, inputs.dtype)
def compute_output_shape(self, input_shape):
return input_shape
@@ -170,7 +172,7 @@ class SpatialDropout1D(Dropout):
self.input_spec = InputSpec(ndim=3)
def _get_noise_shape(self, inputs):
- input_shape = K.shape(inputs)
+ input_shape = array_ops.shape(inputs)
noise_shape = (input_shape[0], 1, input_shape[2])
return noise_shape
@@ -222,7 +224,7 @@ class SpatialDropout2D(Dropout):
self.input_spec = InputSpec(ndim=4)
def _get_noise_shape(self, inputs):
- input_shape = K.shape(inputs)
+ input_shape = array_ops.shape(inputs)
if self.data_format == 'channels_first':
return (input_shape[0], input_shape[1], 1, 1)
elif self.data_format == 'channels_last':
@@ -275,7 +277,7 @@ class SpatialDropout3D(Dropout):
self.input_spec = InputSpec(ndim=5)
def _get_noise_shape(self, inputs):
- input_shape = K.shape(inputs)
+ input_shape = array_ops.shape(inputs)
if self.data_format == 'channels_first':
return (input_shape[0], input_shape[1], 1, 1, 1)
elif self.data_format == 'channels_last':
@@ -414,7 +416,8 @@ class Reshape(Layer):
return tensor_shape.TensorShape(output_shape)
def call(self, inputs):
- return K.reshape(inputs, (K.shape(inputs)[0],) + self.target_shape)
+ return array_ops.reshape(inputs,
+ (array_ops.shape(inputs)[0],) + self.target_shape)
def get_config(self):
config = {'target_shape': self.target_shape}
@@ -467,7 +470,7 @@ class Permute(Layer):
return tensor_shape.TensorShape(output_shape)
def call(self, inputs):
- return K.permute_dimensions(inputs, (0,) + self.dims)
+ return array_ops.transpose(inputs, perm=(0,) + self.dims)
def get_config(self):
config = {'dims': self.dims}
diff --git a/tensorflow/python/keras/_impl/keras/layers/core_test.py b/tensorflow/python/keras/_impl/keras/layers/core_test.py
index 2ca816adbd..551d1b1c3a 100644
--- a/tensorflow/python/keras/_impl/keras/layers/core_test.py
+++ b/tensorflow/python/keras/_impl/keras/layers/core_test.py
@@ -23,6 +23,7 @@ import numpy as np
from tensorflow.python.framework import test_util as tf_test_util
from tensorflow.python.keras._impl import keras
from tensorflow.python.keras._impl.keras import testing_utils
+from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
@@ -159,7 +160,7 @@ class CoreLayersTest(test.TestCase):
# test with lambda
ld = keras.layers.Lambda(
- lambda x: keras.backend.concatenate([keras.backend.square(x), x]))
+ lambda x: keras.backend.concatenate([math_ops.square(x), x]))
config = ld.get_config()
ld = keras.layers.Lambda.from_config(config)
@@ -235,4 +236,3 @@ class CoreLayersTest(test.TestCase):
if __name__ == '__main__':
test.main()
-
diff --git a/tensorflow/python/keras/_impl/keras/layers/embeddings.py b/tensorflow/python/keras/_impl/keras/layers/embeddings.py
index 006ecd3135..540e2d945c 100644
--- a/tensorflow/python/keras/_impl/keras/layers/embeddings.py
+++ b/tensorflow/python/keras/_impl/keras/layers/embeddings.py
@@ -24,6 +24,8 @@ from tensorflow.python.keras._impl.keras import initializers
from tensorflow.python.keras._impl.keras import regularizers
from tensorflow.python.keras._impl.keras.engine import Layer
from tensorflow.python.keras._impl.keras.engine.base_layer import shape_type_conversion
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
from tensorflow.python.util.tf_export import tf_export
@@ -126,7 +128,7 @@ class Embedding(Layer):
if not self.mask_zero:
return None
else:
- return K.not_equal(inputs, 0)
+ return math_ops.not_equal(inputs, 0)
@shape_type_conversion
def compute_output_shape(self, input_shape):
@@ -152,8 +154,8 @@ class Embedding(Layer):
def call(self, inputs):
if K.dtype(inputs) != 'int32':
- inputs = K.cast(inputs, 'int32')
- out = K.gather(self.embeddings, inputs)
+ inputs = math_ops.cast(inputs, 'int32')
+ out = array_ops.gather(self.embeddings, inputs)
return out
def get_config(self):
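
A hedged numpy sketch of Embedding.call after the change: cast the (possibly float) indices to int, then gather rows of the embedding matrix (vocabulary size and dims are hypothetical):

import numpy as np

embeddings = np.random.rand(10, 4)       # (vocab_size, output_dim)
inputs = np.array([[1., 3.], [0., 9.]])  # float indices, as Keras may receive
idx = inputs.astype(np.int32)            # math_ops.cast(inputs, 'int32')
out = embeddings[idx]                    # array_ops.gather along axis 0
assert out.shape == (2, 2, 4)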
diff --git a/tensorflow/python/keras/_impl/keras/layers/merge.py b/tensorflow/python/keras/_impl/keras/layers/merge.py
index c660cbd449..7c87e6c067 100644
--- a/tensorflow/python/keras/_impl/keras/layers/merge.py
+++ b/tensorflow/python/keras/_impl/keras/layers/merge.py
@@ -23,6 +23,9 @@ from __future__ import print_function
from tensorflow.python.keras._impl.keras import backend as K
from tensorflow.python.keras._impl.keras.engine.base_layer import Layer
from tensorflow.python.keras._impl.keras.engine.base_layer import shape_type_conversion
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn
from tensorflow.python.util.tf_export import tf_export
@@ -127,7 +130,7 @@ class _Merge(Layer):
for x in inputs:
x_ndim = K.ndim(x)
for _ in range(max_ndim - x_ndim):
- x = K.expand_dims(x, 1)
+ x = array_ops.expand_dims(x, axis=1)
reshaped_inputs.append(x)
return self._merge_function(reshaped_inputs)
else:
@@ -137,19 +140,22 @@ class _Merge(Layer):
for x in inputs:
x_ndim = K.ndim(x)
if x_ndim is None:
- x_shape = K.shape(x)
+ x_shape = array_ops.shape(x)
batch_size = x_shape[0]
- new_shape = K.concatenate([x_shape[1:], K.expand_dims(batch_size)])
- x_transposed = K.reshape(x,
- K.stack([batch_size,
- K.prod(x_shape[1:])]))
- x_transposed = K.permute_dimensions(x_transposed, (1, 0))
- x_transposed = K.reshape(x_transposed, new_shape)
+ new_shape = K.concatenate(
+ [x_shape[1:],
+ array_ops.expand_dims(batch_size, axis=-1)])
+ x_transposed = array_ops.reshape(
+ x,
+ array_ops.stack(
+ [batch_size, math_ops.reduce_prod(x_shape[1:])], axis=0))
+ x_transposed = array_ops.transpose(x_transposed, perm=(1, 0))
+ x_transposed = array_ops.reshape(x_transposed, new_shape)
reshaped_inputs.append(x_transposed)
transposed = True
elif x_ndim > 1:
dims = list(range(1, x_ndim)) + [0]
- reshaped_inputs.append(K.permute_dimensions(x, dims))
+ reshaped_inputs.append(array_ops.transpose(x, perm=dims))
transposed = True
else:
# We don't transpose inputs if they are 1D vectors or scalars.
@@ -159,17 +165,18 @@ class _Merge(Layer):
if transposed:
# If inputs have been transposed, we have to transpose the output too.
if y_ndim is None:
- y_shape = K.shape(y)
- y_ndim = K.shape(y_shape)[0]
+ y_shape = array_ops.shape(y)
+ y_ndim = array_ops.shape(y_shape)[0]
batch_size = y_shape[y_ndim - 1]
- new_shape = K.concatenate(
- [K.expand_dims(batch_size), y_shape[:y_ndim - 1]])
- y = K.reshape(y, (-1, batch_size))
- y = K.permute_dimensions(y, (1, 0))
- y = K.reshape(y, new_shape)
+ new_shape = K.concatenate([
+ array_ops.expand_dims(batch_size, axis=-1), y_shape[:y_ndim - 1]
+ ])
+ y = array_ops.reshape(y, (-1, batch_size))
+ y = array_ops.transpose(y, perm=(1, 0))
+ y = array_ops.reshape(y, new_shape)
elif y_ndim > 1:
dims = [y_ndim - 1] + list(range(y_ndim - 1))
- y = K.permute_dimensions(y, dims)
+ y = array_ops.transpose(y, perm=dims)
return y
else:
return self._merge_function(inputs)
@@ -207,7 +214,7 @@ class _Merge(Layer):
'should have the same length.')
if all([m is None for m in mask]):
return None
- masks = [K.expand_dims(m, 0) for m in mask if m is not None]
+ masks = [array_ops.expand_dims(m, axis=0) for m in mask if m is not None]
return K.all(K.concatenate(masks, axis=0), axis=0, keepdims=False)
@@ -325,7 +332,7 @@ class Maximum(_Merge):
def _merge_function(self, inputs):
output = inputs[0]
for i in range(1, len(inputs)):
- output = K.maximum(output, inputs[i])
+ output = math_ops.maximum(output, inputs[i])
return output
@@ -340,7 +347,7 @@ class Minimum(_Merge):
def _merge_function(self, inputs):
output = inputs[0]
for i in range(1, len(inputs)):
- output = K.minimum(output, inputs[i])
+ output = math_ops.minimum(output, inputs[i])
return output
@@ -418,10 +425,10 @@ class Concatenate(_Merge):
for input_i, mask_i in zip(inputs, mask):
if mask_i is None:
# Input is unmasked. Append all 1s to masks.
- masks.append(K.ones_like(input_i, dtype='bool'))
+ masks.append(array_ops.ones_like(input_i, dtype='bool'))
elif K.ndim(mask_i) < K.ndim(input_i):
# Mask is smaller than the input, expand it
- masks.append(K.expand_dims(mask_i))
+ masks.append(array_ops.expand_dims(mask_i, axis=-1))
else:
masks.append(mask_i)
concatenated = K.concatenate(masks, axis=self.axis)
@@ -511,8 +518,8 @@ class Dot(_Merge):
else:
axes.append(self.axes[i])
if self.normalize:
- x1 = K.l2_normalize(x1, axis=axes[0])
- x2 = K.l2_normalize(x2, axis=axes[1])
+ x1 = nn.l2_normalize(x1, axis=axes[0])
+ x2 = nn.l2_normalize(x2, axis=axes[1])
output = K.batch_dot(x1, x2, axes)
return output
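
A hedged numpy sketch of the _Merge transposition rewritten above: fold the trailing dims, move the batch axis last before merging, then move it back. np.moveaxis collapses the reshape/transpose/reshape sequence into one step; shapes are illustrative:

import numpy as np

x = np.random.rand(5, 3, 2)            # (batch, d1, d2)
batch = x.shape[0]
x_t = x.reshape(batch, -1).T           # (prod(rest), batch), as in the hunk
x_t = x_t.reshape((3, 2, 5))           # new_shape = rest + (batch,)
# ... merge along the trailing batch axis here ...
y = np.moveaxis(x_t, -1, 0)            # transpose the batch axis back to front
assert np.array_equal(y, x)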
diff --git a/tensorflow/python/keras/_impl/keras/layers/noise.py b/tensorflow/python/keras/_impl/keras/layers/noise.py
index e309d160e5..72dc7a1ff8 100644
--- a/tensorflow/python/keras/_impl/keras/layers/noise.py
+++ b/tensorflow/python/keras/_impl/keras/layers/noise.py
@@ -23,6 +23,8 @@ import numpy as np
from tensorflow.python.keras._impl.keras import backend as K
from tensorflow.python.keras._impl.keras.engine import Layer
from tensorflow.python.keras._impl.keras.engine.base_layer import shape_type_conversion
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
from tensorflow.python.util.tf_export import tf_export
@@ -58,7 +60,7 @@ class GaussianNoise(Layer):
def noised():
return inputs + K.random_normal(
- shape=K.shape(inputs), mean=0., stddev=self.stddev)
+ shape=array_ops.shape(inputs), mean=0., stddev=self.stddev)
return K.in_train_phase(noised, inputs, training=training)
@@ -104,7 +106,7 @@ class GaussianDropout(Layer):
def noised():
stddev = np.sqrt(self.rate / (1.0 - self.rate))
return inputs * K.random_normal(
- shape=K.shape(inputs), mean=1.0, stddev=stddev)
+ shape=array_ops.shape(inputs), mean=1.0, stddev=stddev)
return K.in_train_phase(noised, inputs, training=training)
return inputs
@@ -153,7 +155,7 @@ class AlphaDropout(Layer):
self.supports_masking = True
def _get_noise_shape(self, inputs):
- return self.noise_shape if self.noise_shape else K.shape(inputs)
+ return self.noise_shape if self.noise_shape else array_ops.shape(inputs)
def call(self, inputs, training=None):
if 0. < self.rate < 1.:
@@ -164,9 +166,9 @@ class AlphaDropout(Layer):
scale = 1.0507009873554804934193349852946
alpha_p = -alpha * scale
- kept_idx = K.greater_equal(
+ kept_idx = math_ops.greater_equal(
K.random_uniform(noise_shape, seed=seed), rate)
- kept_idx = K.cast(kept_idx, K.floatx())
+ kept_idx = math_ops.cast(kept_idx, K.floatx())
# Get affine transformation params
a = ((1 - rate) * (1 + rate * alpha_p**2))**-0.5
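
A hedged numpy sketch of the AlphaDropout keep-mask step above: compare uniform noise against the drop rate, then cast the boolean mask to floats (math_ops.cast in the hunk):

import numpy as np

rate = 0.1
noise = np.random.uniform(size=(2, 3))          # K.random_uniform(noise_shape)
kept_idx = (noise >= rate).astype(np.float32)   # 1.0 = keep, 0.0 = drop
assert set(np.unique(kept_idx)) <= {0.0, 1.0}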
diff --git a/tensorflow/python/keras/_impl/keras/layers/recurrent.py b/tensorflow/python/keras/_impl/keras/layers/recurrent.py
index 791f9b3113..7f9f77c296 100644
--- a/tensorflow/python/keras/_impl/keras/layers/recurrent.py
+++ b/tensorflow/python/keras/_impl/keras/layers/recurrent.py
@@ -33,6 +33,9 @@ from tensorflow.python.keras._impl.keras.engine import InputSpec
from tensorflow.python.keras._impl.keras.engine import Layer
from tensorflow.python.keras._impl.keras.engine.base_layer import shape_type_conversion
from tensorflow.python.keras._impl.keras.utils.generic_utils import has_arg
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import state_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.tf_export import tf_export
@@ -503,9 +506,12 @@ class RNN(Layer):
def get_initial_state(self, inputs):
# build an all-zero tensor of shape (samples, output_dim)
- initial_state = K.zeros_like(inputs) # (samples, timesteps, input_dim)
- initial_state = K.sum(initial_state, axis=(1, 2)) # (samples,)
- initial_state = K.expand_dims(initial_state) # (samples, 1)
+ initial_state = array_ops.zeros_like(inputs)
+ # shape of initial_state = (samples, timesteps, input_dim)
+ initial_state = math_ops.reduce_sum(initial_state, axis=(1, 2))
+ # shape of initial_state = (samples,)
+ initial_state = array_ops.expand_dims(initial_state, axis=-1)
+ # shape of initial_state = (samples, 1)
if hasattr(self.cell.state_size, '__len__'):
return [K.tile(initial_state, [1, dim]) for dim in self.cell.state_size]
else:
@@ -631,7 +637,7 @@ class RNN(Layer):
if self.stateful:
updates = []
for i in range(len(states)):
- updates.append(K.update(self.states[i], states[i]))
+ updates.append(state_ops.assign(self.states[i], states[i]))
self.add_update(updates, inputs)
if self.return_sequences:
@@ -907,8 +913,7 @@ class SimpleRNNCell(Layer):
prev_output = states[0]
if 0 < self.dropout < 1 and self._dropout_mask is None:
self._dropout_mask = _generate_dropout_mask(
- _generate_dropout_ones(inputs,
- K.shape(inputs)[-1]),
+ _generate_dropout_ones(inputs, array_ops.shape(inputs)[-1]),
self.dropout,
training=training)
if (0 < self.recurrent_dropout < 1 and
@@ -1309,8 +1314,7 @@ class GRUCell(Layer):
if 0 < self.dropout < 1 and self._dropout_mask is None:
self._dropout_mask = _generate_dropout_mask(
- _generate_dropout_ones(inputs,
- K.shape(inputs)[-1]),
+ _generate_dropout_ones(inputs, array_ops.shape(inputs)[-1]),
self.dropout,
training=training,
count=3)
@@ -1793,8 +1797,7 @@ class LSTMCell(Layer):
def call(self, inputs, states, training=None):
if 0 < self.dropout < 1 and self._dropout_mask is None:
self._dropout_mask = _generate_dropout_mask(
- _generate_dropout_ones(inputs,
- K.shape(inputs)[-1]),
+ _generate_dropout_ones(inputs, array_ops.shape(inputs)[-1]),
self.dropout,
training=training,
count=4)
@@ -2176,7 +2179,7 @@ class LSTM(RNN):
def _generate_dropout_ones(inputs, dims):
- return K.ones((K.shape(inputs)[0], dims))
+ return K.ones((array_ops.shape(inputs)[0], dims))
def _generate_dropout_mask(ones, rate, training=None, count=1):
@@ -2351,9 +2354,12 @@ class Recurrent(Layer):
def get_initial_state(self, inputs):
# build an all-zero tensor of shape (samples, output_dim)
- initial_state = K.zeros_like(inputs) # (samples, timesteps, input_dim)
- initial_state = K.sum(initial_state, axis=(1, 2)) # (samples,)
- initial_state = K.expand_dims(initial_state) # (samples, 1)
+ initial_state = array_ops.zeros_like(inputs)
+ # shape of initial_state = (samples, timesteps, input_dim)
+ initial_state = math_ops.reduce_sum(initial_state, axis=(1, 2))
+ # shape of initial_state = (samples,)
+ initial_state = array_ops.expand_dims(initial_state, axis=-1)
+ # shape of initial_state = (samples, 1)
initial_state = K.tile(initial_state, [1,
self.units]) # (samples, output_dim)
initial_state = [initial_state for _ in range(len(self.states))]
@@ -2456,7 +2462,7 @@ class Recurrent(Layer):
if self.stateful:
updates = []
for i in range(len(states)):
- updates.append(K.update(self.states[i], states[i]))
+ updates.append(state_ops.assign(self.states[i], states[i]))
self.add_update(updates, inputs)
# Properly set learning phase
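
A hedged numpy sketch of get_initial_state as rewritten above: derive an all-zero (samples, units) state from a (samples, timesteps, input_dim) input whose batch size is only known at runtime (unit count and shapes are hypothetical):

import numpy as np

units = 4
inputs = np.random.rand(2, 7, 3)
state = np.zeros_like(inputs).sum(axis=(1, 2))  # reduce_sum -> (samples,)
state = state[:, None]                          # expand_dims -> (samples, 1)
state = np.tile(state, [1, units])              # K.tile -> (samples, units)
assert state.shape == (2, 4) and (state == 0).all()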
diff --git a/tensorflow/python/keras/_impl/keras/layers/recurrent_test.py b/tensorflow/python/keras/_impl/keras/layers/recurrent_test.py
index de022153f6..fb743b617f 100644
--- a/tensorflow/python/keras/_impl/keras/layers/recurrent_test.py
+++ b/tensorflow/python/keras/_impl/keras/layers/recurrent_test.py
@@ -24,6 +24,9 @@ from __future__ import print_function
import numpy as np
from tensorflow.python.keras._impl import keras
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import state_ops
from tensorflow.python.platform import test
@@ -395,8 +398,8 @@ class RNNTest(test.TestCase):
# Test `get_losses_for` and `losses`
x = keras.Input((None, 1))
- loss_1 = keras.backend.sum(x)
- loss_2 = keras.backend.sum(cells[0].kernel)
+ loss_1 = math_ops.reduce_sum(x)
+ loss_2 = math_ops.reduce_sum(cells[0].kernel)
cells[0].add_loss(loss_1, inputs=x)
cells[0].add_loss(loss_2)
self.assertEqual(len(layer.losses), 2)
@@ -410,10 +413,10 @@ class RNNTest(test.TestCase):
layer.build((None, None, 1))
x = keras.Input((None, 1))
- update_1 = keras.backend.update_add(
- cells[0].kernel, x[0, 0, 0] * cells[0].kernel)
- update_2 = keras.backend.update_add(
- cells[0].kernel, keras.backend.ones_like(cells[0].kernel))
+ update_1 = state_ops.assign_add(cells[0].kernel,
+ x[0, 0, 0] * cells[0].kernel)
+ update_2 = state_ops.assign_add(cells[0].kernel,
+ array_ops.ones_like(cells[0].kernel))
cells[0].add_update(update_1, inputs=x)
cells[0].add_update(update_2)
self.assertEqual(len(layer.updates), 2)
diff --git a/tensorflow/python/keras/_impl/keras/layers/wrappers.py b/tensorflow/python/keras/_impl/keras/layers/wrappers.py
index 76ddd9299d..c510e464ae 100644
--- a/tensorflow/python/keras/_impl/keras/layers/wrappers.py
+++ b/tensorflow/python/keras/_impl/keras/layers/wrappers.py
@@ -28,6 +28,7 @@ from tensorflow.python.keras._impl.keras.engine import Layer
from tensorflow.python.keras._impl.keras.engine.base_layer import shape_type_conversion
from tensorflow.python.keras._impl.keras.utils.generic_utils import has_arg
from tensorflow.python.layers import utils as tf_layers_util
+from tensorflow.python.ops import array_ops
from tensorflow.python.util.tf_export import tf_export
@@ -209,11 +210,11 @@ class TimeDistributed(Wrapper):
# We can go with reshape-based implementation for performance.
input_length = input_shape[1]
if not input_length:
- input_length = K.shape(inputs)[1]
+ input_length = array_ops.shape(inputs)[1]
# Shape: (num_samples * timesteps, ...). And track the
# transformation in self._input_map.
input_uid = tf_layers_util.object_list_uid(inputs)
- inputs = K.reshape(inputs, (-1,) + input_shape[2:])
+ inputs = array_ops.reshape(inputs, (-1,) + input_shape[2:])
self._input_map[input_uid] = inputs
# (num_samples * timesteps, ...)
y = self.layer.call(inputs, **kwargs)
@@ -221,7 +222,7 @@ class TimeDistributed(Wrapper):
uses_learning_phase = y._uses_learning_phase
# Shape: (num_samples, timesteps, ...)
output_shape = self.compute_output_shape(input_shape).as_list()
- y = K.reshape(y, (-1, input_length) + tuple(output_shape[2:]))
+ y = array_ops.reshape(y, (-1, input_length) + tuple(output_shape[2:]))
# Apply activity regularizer if any:
if (hasattr(self.layer, 'activity_regularizer') and
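
A hedged numpy sketch of TimeDistributed's reshape-based path above: fold timesteps into the batch axis, apply the wrapped layer once, then unfold (the doubling layer is a stand-in for self.layer.call):

import numpy as np

inputs = np.random.rand(2, 5, 3)                  # (samples, timesteps, feats)
folded = inputs.reshape((-1,) + inputs.shape[2:]) # (10, 3)
y = folded * 2.0                                  # stand-in for the layer
y = y.reshape((-1, 5) + y.shape[1:])              # back to (2, 5, 3)
assert y.shape == (2, 5, 3)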
diff --git a/tensorflow/python/keras/_impl/keras/losses.py b/tensorflow/python/keras/_impl/keras/losses.py
index 1576ed7b99..1d634d3801 100644
--- a/tensorflow/python/keras/_impl/keras/losses.py
+++ b/tensorflow/python/keras/_impl/keras/losses.py
@@ -24,51 +24,55 @@ import six
from tensorflow.python.keras._impl.keras import backend as K
from tensorflow.python.keras._impl.keras.utils.generic_utils import deserialize_keras_object
from tensorflow.python.keras._impl.keras.utils.generic_utils import serialize_keras_object
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn
from tensorflow.python.util.tf_export import tf_export
@tf_export('keras.metrics.mean_squared_error',
'keras.losses.mean_squared_error')
def mean_squared_error(y_true, y_pred):
- return K.mean(K.square(y_pred - y_true), axis=-1)
+ return K.mean(math_ops.square(y_pred - y_true), axis=-1)
@tf_export('keras.metrics.mean_absolute_error',
'keras.losses.mean_absolute_error')
def mean_absolute_error(y_true, y_pred):
- return K.mean(K.abs(y_pred - y_true), axis=-1)
+ return K.mean(math_ops.abs(y_pred - y_true), axis=-1)
@tf_export('keras.metrics.mean_absolute_percentage_error',
'keras.losses.mean_absolute_percentage_error')
def mean_absolute_percentage_error(y_true, y_pred):
- diff = K.abs((y_true - y_pred) / K.clip(K.abs(y_true), K.epsilon(), None))
+ diff = math_ops.abs(
+ (y_true - y_pred) / K.clip(math_ops.abs(y_true), K.epsilon(), None))
return 100. * K.mean(diff, axis=-1)
@tf_export('keras.metrics.mean_squared_logarithmic_error',
'keras.losses.mean_squared_logarithmic_error')
def mean_squared_logarithmic_error(y_true, y_pred):
- first_log = K.log(K.clip(y_pred, K.epsilon(), None) + 1.)
- second_log = K.log(K.clip(y_true, K.epsilon(), None) + 1.)
- return K.mean(K.square(first_log - second_log), axis=-1)
+ first_log = math_ops.log(K.clip(y_pred, K.epsilon(), None) + 1.)
+ second_log = math_ops.log(K.clip(y_true, K.epsilon(), None) + 1.)
+ return K.mean(math_ops.square(first_log - second_log), axis=-1)
@tf_export('keras.metrics.squared_hinge', 'keras.losses.squared_hinge')
def squared_hinge(y_true, y_pred):
- return K.mean(K.square(K.maximum(1. - y_true * y_pred, 0.)), axis=-1)
+ return K.mean(
+ math_ops.square(math_ops.maximum(1. - y_true * y_pred, 0.)), axis=-1)
@tf_export('keras.metrics.hinge', 'keras.losses.hinge')
def hinge(y_true, y_pred):
- return K.mean(K.maximum(1. - y_true * y_pred, 0.), axis=-1)
+ return K.mean(math_ops.maximum(1. - y_true * y_pred, 0.), axis=-1)
@tf_export('keras.losses.categorical_hinge')
def categorical_hinge(y_true, y_pred):
- pos = K.sum(y_true * y_pred, axis=-1)
- neg = K.max((1. - y_true) * y_pred, axis=-1)
- return K.maximum(0., neg - pos + 1.)
+ pos = math_ops.reduce_sum(y_true * y_pred, axis=-1)
+ neg = math_ops.reduce_max((1. - y_true) * y_pred, axis=-1)
+ return math_ops.maximum(0., neg - pos + 1.)
@tf_export('keras.losses.logcosh')
@@ -89,7 +93,7 @@ def logcosh(y_true, y_pred):
"""
def _logcosh(x):
- return x + K.softplus(-2. * x) - K.log(2.)
+ return x + nn.softplus(-2. * x) - math_ops.log(2.)
return K.mean(_logcosh(y_pred - y_true), axis=-1)
@@ -117,19 +121,19 @@ def binary_crossentropy(y_true, y_pred):
def kullback_leibler_divergence(y_true, y_pred):
y_true = K.clip(y_true, K.epsilon(), 1)
y_pred = K.clip(y_pred, K.epsilon(), 1)
- return K.sum(y_true * K.log(y_true / y_pred), axis=-1)
+ return math_ops.reduce_sum(y_true * math_ops.log(y_true / y_pred), axis=-1)
@tf_export('keras.metrics.poisson', 'keras.losses.poisson')
def poisson(y_true, y_pred):
- return K.mean(y_pred - y_true * K.log(y_pred + K.epsilon()), axis=-1)
+ return K.mean(y_pred - y_true * math_ops.log(y_pred + K.epsilon()), axis=-1)
@tf_export('keras.metrics.cosine_proximity', 'keras.losses.cosine_proximity')
def cosine_proximity(y_true, y_pred):
- y_true = K.l2_normalize(y_true, axis=-1)
- y_pred = K.l2_normalize(y_pred, axis=-1)
- return -K.sum(y_true * y_pred, axis=-1)
+ y_true = nn.l2_normalize(y_true, axis=-1)
+ y_pred = nn.l2_normalize(y_pred, axis=-1)
+ return -math_ops.reduce_sum(y_true * y_pred, axis=-1)
# Aliases.
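
A hedged numeric check of the logcosh identity used in _logcosh above, log(cosh(x)) == x + softplus(-2x) - log(2), written with a stable softplus so large |x| does not overflow cosh:

import numpy as np

def softplus(z):
    # log(1 + e^z), stable form: max(z, 0) + log1p(e^-|z|)
    return np.maximum(z, 0) + np.log1p(np.exp(-np.abs(z)))

x = np.linspace(-5., 5., 11)
lhs = np.log(np.cosh(x))
rhs = x + softplus(-2. * x) - np.log(2.)
assert np.allclose(lhs, rhs)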
diff --git a/tensorflow/python/keras/_impl/keras/metrics.py b/tensorflow/python/keras/_impl/keras/metrics.py
index 82778a3dc4..747c3e6515 100644
--- a/tensorflow/python/keras/_impl/keras/metrics.py
+++ b/tensorflow/python/keras/_impl/keras/metrics.py
@@ -37,37 +37,45 @@ from tensorflow.python.keras._impl.keras.losses import sparse_categorical_crosse
from tensorflow.python.keras._impl.keras.losses import squared_hinge
from tensorflow.python.keras._impl.keras.utils.generic_utils import deserialize_keras_object
from tensorflow.python.keras._impl.keras.utils.generic_utils import serialize_keras_object
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn
from tensorflow.python.util.tf_export import tf_export
@tf_export('keras.metrics.binary_accuracy')
def binary_accuracy(y_true, y_pred):
- return K.mean(K.equal(y_true, K.round(y_pred)), axis=-1)
+ return K.mean(math_ops.equal(y_true, math_ops.round(y_pred)), axis=-1)
@tf_export('keras.metrics.categorical_accuracy')
def categorical_accuracy(y_true, y_pred):
- return K.cast(
- K.equal(K.argmax(y_true, axis=-1), K.argmax(y_pred, axis=-1)), K.floatx())
+ return math_ops.cast(
+ math_ops.equal(
+ math_ops.argmax(y_true, axis=-1), math_ops.argmax(y_pred, axis=-1)),
+ K.floatx())
def sparse_categorical_accuracy(y_true, y_pred):
- return K.cast(
- K.equal(
- K.max(y_true, axis=-1), K.cast(K.argmax(y_pred, axis=-1),
- K.floatx())), K.floatx())
+ return math_ops.cast(
+ math_ops.equal(
+ math_ops.reduce_max(y_true, axis=-1),
+ math_ops.cast(math_ops.argmax(y_pred, axis=-1), K.floatx())),
+ K.floatx())
@tf_export('keras.metrics.top_k_categorical_accuracy')
def top_k_categorical_accuracy(y_true, y_pred, k=5):
- return K.mean(K.in_top_k(y_pred, K.argmax(y_true, axis=-1), k), axis=-1)
+ return K.mean(
+ nn.in_top_k(y_pred, math_ops.argmax(y_true, axis=-1), k), axis=-1)
@tf_export('keras.metrics.sparse_top_k_categorical_accuracy')
def sparse_top_k_categorical_accuracy(y_true, y_pred, k=5):
return K.mean(
- K.in_top_k(y_pred, K.cast(K.max(y_true, axis=-1), 'int32'), k), axis=-1)
-
+ nn.in_top_k(y_pred,
+ math_ops.cast(math_ops.reduce_max(y_true, axis=-1), 'int32'),
+ k),
+ axis=-1)
# Aliases
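
A hedged numpy sketch of categorical_accuracy as rewritten above: compare the argmax of the one-hot targets against the argmax of the predictions, per sample (the inputs are made up for illustration):

import numpy as np

y_true = np.array([[0, 1, 0], [1, 0, 0]], dtype=np.float32)
y_pred = np.array([[0.1, 0.8, 0.1], [0.3, 0.6, 0.1]], dtype=np.float32)
acc = (y_true.argmax(axis=-1) == y_pred.argmax(axis=-1)).astype(np.float32)
assert np.array_equal(acc, [1.0, 0.0])   # first sample right, second wrong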
diff --git a/tensorflow/python/keras/_impl/keras/metrics_test.py b/tensorflow/python/keras/_impl/keras/metrics_test.py
index 44289ea02a..9deaab0c05 100644
--- a/tensorflow/python/keras/_impl/keras/metrics_test.py
+++ b/tensorflow/python/keras/_impl/keras/metrics_test.py
@@ -21,6 +21,8 @@ from __future__ import print_function
import numpy as np
from tensorflow.python.keras._impl import keras
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import state_ops
from tensorflow.python.platform import test
@@ -104,16 +106,15 @@ class KerasMetricsTest(test.TestCase):
The total number of true positives seen this epoch at the
completion of the batch.
"""
- y_true = keras.backend.cast(y_true, 'int32')
- y_pred = keras.backend.cast(keras.backend.round(y_pred), 'int32')
- correct_preds = keras.backend.cast(
- keras.backend.equal(y_pred, y_true), 'int32')
- true_pos = keras.backend.cast(
- keras.backend.sum(correct_preds * y_true), 'int32')
+ y_true = math_ops.cast(y_true, 'int32')
+ y_pred = math_ops.cast(math_ops.round(y_pred), 'int32')
+ correct_preds = math_ops.cast(math_ops.equal(y_pred, y_true), 'int32')
+ true_pos = math_ops.cast(
+ math_ops.reduce_sum(correct_preds * y_true), 'int32')
current_true_pos = self.true_positives * 1
- self.add_update(keras.backend.update_add(self.true_positives,
- true_pos),
- inputs=[y_true, y_pred])
+ self.add_update(
+ state_ops.assign_add(self.true_positives, true_pos),
+ inputs=[y_true, y_pred])
return current_true_pos + true_pos
metric_fn = BinaryTruePositives()
diff --git a/tensorflow/python/keras/_impl/keras/optimizers.py b/tensorflow/python/keras/_impl/keras/optimizers.py
index acbb9091d3..9f383deb72 100644
--- a/tensorflow/python/keras/_impl/keras/optimizers.py
+++ b/tensorflow/python/keras/_impl/keras/optimizers.py
@@ -31,6 +31,7 @@ from tensorflow.python.keras._impl.keras.utils.generic_utils import deserialize_
from tensorflow.python.keras._impl.keras.utils.generic_utils import serialize_keras_object
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import state_ops
from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.training import optimizer as tf_optimizer_module
from tensorflow.python.training import training_util
@@ -118,7 +119,8 @@ class Optimizer(object):
'Common ops without gradient: '
'K.argmax, K.round, K.eval.')
if hasattr(self, 'clipnorm') and self.clipnorm > 0:
- norm = K.sqrt(sum([K.sum(K.square(g)) for g in grads]))
+ norm = K.sqrt(
+ sum([math_ops.reduce_sum(math_ops.square(g)) for g in grads]))
grads = [clip_norm(g, self.clipnorm, norm) for g in grads]
if hasattr(self, 'clipvalue') and self.clipvalue > 0:
grads = [K.clip(g, -self.clipvalue, self.clipvalue) for g in grads]
@@ -204,20 +206,20 @@ class SGD(Optimizer):
def get_updates(self, loss, params):
grads = self.get_gradients(loss, params)
- self.updates = [K.update_add(self.iterations, 1)]
+ self.updates = [state_ops.assign_add(self.iterations, 1)]
lr = self.lr
if self.initial_decay > 0:
- lr = lr * (1. / # pylint: disable=g-no-augmented-assignment
- (1. + self.decay * K.cast(self.iterations,
- K.dtype(self.decay))))
+ lr = lr * ( # pylint: disable=g-no-augmented-assignment
+ 1. / (1. + self.decay * math_ops.cast(self.iterations,
+ K.dtype(self.decay))))
# momentum
shapes = [K.int_shape(p) for p in params]
moments = [K.zeros(shape) for shape in shapes]
self.weights = [self.iterations] + moments
for p, g, m in zip(params, grads, moments):
v = self.momentum * m - lr * g # velocity
- self.updates.append(K.update(m, v))
+ self.updates.append(state_ops.assign(m, v))
if self.nesterov:
new_p = p + self.momentum * v - lr * g
@@ -228,7 +230,7 @@ class SGD(Optimizer):
if getattr(p, 'constraint', None) is not None:
new_p = p.constraint(new_p)
- self.updates.append(K.update(p, new_p))
+ self.updates.append(state_ops.assign(p, new_p))
return self.updates
def get_config(self):
@@ -277,25 +279,25 @@ class RMSprop(Optimizer):
grads = self.get_gradients(loss, params)
accumulators = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
self.weights = accumulators
- self.updates = [K.update_add(self.iterations, 1)]
+ self.updates = [state_ops.assign_add(self.iterations, 1)]
lr = self.lr
if self.initial_decay > 0:
- lr = lr * (1. / # pylint: disable=g-no-augmented-assignment
- (1. + self.decay * K.cast(self.iterations,
- K.dtype(self.decay))))
+ lr = lr * ( # pylint: disable=g-no-augmented-assignment
+ 1. / (1. + self.decay * math_ops.cast(self.iterations,
+ K.dtype(self.decay))))
for p, g, a in zip(params, grads, accumulators):
# update accumulator
- new_a = self.rho * a + (1. - self.rho) * K.square(g)
- self.updates.append(K.update(a, new_a))
+ new_a = self.rho * a + (1. - self.rho) * math_ops.square(g)
+ self.updates.append(state_ops.assign(a, new_a))
new_p = p - lr * g / (K.sqrt(new_a) + self.epsilon)
# Apply constraints.
if getattr(p, 'constraint', None) is not None:
new_p = p.constraint(new_p)
- self.updates.append(K.update(p, new_p))
+ self.updates.append(state_ops.assign(p, new_p))
return self.updates
def get_config(self):
@@ -339,24 +341,24 @@ class Adagrad(Optimizer):
shapes = [K.int_shape(p) for p in params]
accumulators = [K.zeros(shape) for shape in shapes]
self.weights = accumulators
- self.updates = [K.update_add(self.iterations, 1)]
+ self.updates = [state_ops.assign_add(self.iterations, 1)]
lr = self.lr
if self.initial_decay > 0:
- lr = lr * (1. / # pylint: disable=g-no-augmented-assignment
- (1. + self.decay * K.cast(self.iterations,
- K.dtype(self.decay))))
+ lr = lr * ( # pylint: disable=g-no-augmented-assignment
+ 1. / (1. + self.decay * math_ops.cast(self.iterations,
+ K.dtype(self.decay))))
for p, g, a in zip(params, grads, accumulators):
- new_a = a + K.square(g) # update accumulator
- self.updates.append(K.update(a, new_a))
+ new_a = a + math_ops.square(g) # update accumulator
+ self.updates.append(state_ops.assign(a, new_a))
new_p = p - lr * g / (K.sqrt(new_a) + self.epsilon)
# Apply constraints.
if getattr(p, 'constraint', None) is not None:
new_p = p.constraint(new_p)
- self.updates.append(K.update(p, new_p))
+ self.updates.append(state_ops.assign(p, new_p))
return self.updates
def get_config(self):
@@ -403,18 +405,18 @@ class Adadelta(Optimizer):
accumulators = [K.zeros(shape) for shape in shapes]
delta_accumulators = [K.zeros(shape) for shape in shapes]
self.weights = accumulators + delta_accumulators
- self.updates = [K.update_add(self.iterations, 1)]
+ self.updates = [state_ops.assign_add(self.iterations, 1)]
lr = self.lr
if self.initial_decay > 0:
- lr = lr * (1. / # pylint: disable=g-no-augmented-assignment
- (1. + self.decay * K.cast(self.iterations,
- K.dtype(self.decay))))
+ lr = lr * ( # pylint: disable=g-no-augmented-assignment
+ 1. / (1. + self.decay * math_ops.cast(self.iterations,
+ K.dtype(self.decay))))
for p, g, a, d_a in zip(params, grads, accumulators, delta_accumulators):
# update accumulator
- new_a = self.rho * a + (1. - self.rho) * K.square(g)
- self.updates.append(K.update(a, new_a))
+ new_a = self.rho * a + (1. - self.rho) * math_ops.square(g)
+ self.updates.append(state_ops.assign(a, new_a))
# use the new accumulator and the *old* delta_accumulator
update = g * K.sqrt(d_a + self.epsilon) / K.sqrt(new_a + self.epsilon)
@@ -424,11 +426,11 @@ class Adadelta(Optimizer):
if getattr(p, 'constraint', None) is not None:
new_p = p.constraint(new_p)
- self.updates.append(K.update(p, new_p))
+ self.updates.append(state_ops.assign(p, new_p))
# update delta_accumulator
- new_d_a = self.rho * d_a + (1 - self.rho) * K.square(update)
- self.updates.append(K.update(d_a, new_d_a))
+ new_d_a = self.rho * d_a + (1 - self.rho) * math_ops.square(update)
+ self.updates.append(state_ops.assign(d_a, new_d_a))
return self.updates
def get_config(self):
@@ -483,17 +485,18 @@ class Adam(Optimizer):
def get_updates(self, loss, params):
grads = self.get_gradients(loss, params)
- self.updates = [K.update_add(self.iterations, 1)]
+ self.updates = [state_ops.assign_add(self.iterations, 1)]
lr = self.lr
if self.initial_decay > 0:
- lr = lr * (1. / # pylint: disable=g-no-augmented-assignment
- (1. + self.decay * K.cast(self.iterations,
- K.dtype(self.decay))))
+ lr = lr * ( # pylint: disable=g-no-augmented-assignment
+ 1. / (1. + self.decay * math_ops.cast(self.iterations,
+ K.dtype(self.decay))))
- t = K.cast(self.iterations, K.floatx()) + 1
+ t = math_ops.cast(self.iterations, K.floatx()) + 1
lr_t = lr * (
- K.sqrt(1. - K.pow(self.beta_2, t)) / (1. - K.pow(self.beta_1, t)))
+ K.sqrt(1. - math_ops.pow(self.beta_2, t)) /
+ (1. - math_ops.pow(self.beta_1, t)))
ms = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
vs = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
@@ -505,23 +508,23 @@ class Adam(Optimizer):
for p, g, m, v, vhat in zip(params, grads, ms, vs, vhats):
m_t = (self.beta_1 * m) + (1. - self.beta_1) * g
- v_t = (self.beta_2 * v) + (1. - self.beta_2) * K.square(g)
+ v_t = (self.beta_2 * v) + (1. - self.beta_2) * math_ops.square(g)
if self.amsgrad:
- vhat_t = K.maximum(vhat, v_t)
+ vhat_t = math_ops.maximum(vhat, v_t)
p_t = p - lr_t * m_t / (K.sqrt(vhat_t) + self.epsilon)
- self.updates.append(K.update(vhat, vhat_t))
+ self.updates.append(state_ops.assign(vhat, vhat_t))
else:
p_t = p - lr_t * m_t / (K.sqrt(v_t) + self.epsilon)
- self.updates.append(K.update(m, m_t))
- self.updates.append(K.update(v, v_t))
+ self.updates.append(state_ops.assign(m, m_t))
+ self.updates.append(state_ops.assign(v, v_t))
new_p = p_t
# Apply constraints.
if getattr(p, 'constraint', None) is not None:
new_p = p.constraint(new_p)
- self.updates.append(K.update(p, new_p))
+ self.updates.append(state_ops.assign(p, new_p))
return self.updates
def get_config(self):
@@ -573,16 +576,16 @@ class Adamax(Optimizer):
def get_updates(self, loss, params):
grads = self.get_gradients(loss, params)
- self.updates = [K.update_add(self.iterations, 1)]
+ self.updates = [state_ops.assign_add(self.iterations, 1)]
lr = self.lr
if self.initial_decay > 0:
- lr = lr * (1. / # pylint: disable=g-no-augmented-assignment
- (1. + self.decay * K.cast(self.iterations,
- K.dtype(self.decay))))
+ lr = lr * ( # pylint: disable=g-no-augmented-assignment
+ 1. / (1. + self.decay * math_ops.cast(self.iterations,
+ K.dtype(self.decay))))
- t = K.cast(self.iterations, K.floatx()) + 1
- lr_t = lr / (1. - K.pow(self.beta_1, t))
+ t = math_ops.cast(self.iterations, K.floatx()) + 1
+ lr_t = lr / (1. - math_ops.pow(self.beta_1, t))
shapes = [K.int_shape(p) for p in params]
# zero init of 1st moment
@@ -594,18 +597,18 @@ class Adamax(Optimizer):
for p, g, m, u in zip(params, grads, ms, us):
m_t = (self.beta_1 * m) + (1. - self.beta_1) * g
- u_t = K.maximum(self.beta_2 * u, K.abs(g))
+ u_t = math_ops.maximum(self.beta_2 * u, math_ops.abs(g))
p_t = p - lr_t * m_t / (u_t + self.epsilon)
- self.updates.append(K.update(m, m_t))
- self.updates.append(K.update(u, u_t))
+ self.updates.append(state_ops.assign(m, m_t))
+ self.updates.append(state_ops.assign(u, u_t))
new_p = p_t
# Apply constraints.
if getattr(p, 'constraint', None) is not None:
new_p = p.constraint(new_p)
- self.updates.append(K.update(p, new_p))
+ self.updates.append(state_ops.assign(p, new_p))
return self.updates
def get_config(self):
@@ -659,16 +662,17 @@ class Nadam(Optimizer):
def get_updates(self, loss, params):
grads = self.get_gradients(loss, params)
- self.updates = [K.update_add(self.iterations, 1)]
+ self.updates = [state_ops.assign_add(self.iterations, 1)]
- t = K.cast(self.iterations, K.floatx()) + 1
+ t = math_ops.cast(self.iterations, K.floatx()) + 1
# Due to the recommendations in [2], i.e. warming momentum schedule
momentum_cache_t = self.beta_1 * (
- 1. - 0.5 * (K.pow(K.cast_to_floatx(0.96), t * self.schedule_decay)))
+ 1. - 0.5 *
+ (math_ops.pow(K.cast_to_floatx(0.96), t * self.schedule_decay)))
momentum_cache_t_1 = self.beta_1 * (
1. - 0.5 *
- (K.pow(K.cast_to_floatx(0.96), (t + 1) * self.schedule_decay)))
+ (math_ops.pow(K.cast_to_floatx(0.96), (t + 1) * self.schedule_decay)))
m_schedule_new = self.m_schedule * momentum_cache_t
m_schedule_next = self.m_schedule * momentum_cache_t * momentum_cache_t_1
self.updates.append((self.m_schedule, m_schedule_new))
@@ -684,13 +688,13 @@ class Nadam(Optimizer):
g_prime = g / (1. - m_schedule_new)
m_t = self.beta_1 * m + (1. - self.beta_1) * g
m_t_prime = m_t / (1. - m_schedule_next)
- v_t = self.beta_2 * v + (1. - self.beta_2) * K.square(g)
- v_t_prime = v_t / (1. - K.pow(self.beta_2, t))
+ v_t = self.beta_2 * v + (1. - self.beta_2) * math_ops.square(g)
+ v_t_prime = v_t / (1. - math_ops.pow(self.beta_2, t))
m_t_bar = (
1. - momentum_cache_t) * g_prime + momentum_cache_t_1 * m_t_prime
- self.updates.append(K.update(m, m_t))
- self.updates.append(K.update(v, v_t))
+ self.updates.append(state_ops.assign(m, m_t))
+ self.updates.append(state_ops.assign(v, v_t))
p_t = p - self.lr * m_t_bar / (K.sqrt(v_t_prime) + self.epsilon)
new_p = p_t
@@ -699,7 +703,7 @@ class Nadam(Optimizer):
if getattr(p, 'constraint', None) is not None:
new_p = p.constraint(new_p)
- self.updates.append(K.update(p, new_p))
+ self.updates.append(state_ops.assign(p, new_p))
return self.updates
def get_config(self):
@@ -743,7 +747,7 @@ class TFOptimizer(Optimizer):
global_step = training_util.get_global_step()
opt_update = self.optimizer.apply_gradients(grads, global_step)
else:
- self.updates = [K.update_add(self.iterations, 1)]
+ self.updates = [state_ops.assign_add(self.iterations, 1)]
if not params:
return self.updates
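
A hedged plain-Python sketch of the shared decay schedule that every optimizer above applies once per update, lr_t = lr / (1 + decay * iterations):

lr, decay = 0.1, 0.01
for iterations in (0, 1, 10, 100):
    lr_t = lr * (1. / (1. + decay * iterations))
    print(iterations, lr_t)   # 0.1, ~0.0990, ~0.0909, 0.05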
diff --git a/tensorflow/python/keras/_impl/keras/regularizers.py b/tensorflow/python/keras/_impl/keras/regularizers.py
index 2c30844647..74c37d370e 100644
--- a/tensorflow/python/keras/_impl/keras/regularizers.py
+++ b/tensorflow/python/keras/_impl/keras/regularizers.py
@@ -23,6 +23,7 @@ import six
from tensorflow.python.keras._impl.keras import backend as K
from tensorflow.python.keras._impl.keras.utils.generic_utils import deserialize_keras_object
from tensorflow.python.keras._impl.keras.utils.generic_utils import serialize_keras_object
+from tensorflow.python.ops import math_ops
from tensorflow.python.util.tf_export import tf_export
@@ -55,9 +56,9 @@ class L1L2(Regularizer):
def __call__(self, x):
regularization = 0.
if self.l1:
- regularization += K.sum(self.l1 * K.abs(x))
+ regularization += math_ops.reduce_sum(self.l1 * math_ops.abs(x))
if self.l2:
- regularization += K.sum(self.l2 * K.square(x))
+ regularization += math_ops.reduce_sum(self.l2 * math_ops.square(x))
return regularization
def get_config(self):
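
A hedged numpy sketch of L1L2.__call__ as rewritten above: the penalty is sum(l1 * |x|) + sum(l2 * x^2):

import numpy as np

def l1_l2(x, l1=0.01, l2=0.01):
    return np.sum(l1 * np.abs(x)) + np.sum(l2 * np.square(x))

x = np.array([[1., -2.], [3., -4.]])
assert np.isclose(l1_l2(x), 0.01 * 10. + 0.01 * 30.)  # |x| sums to 10, x^2 to 30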
diff --git a/tensorflow/python/keras/_impl/keras/utils/layer_utils.py b/tensorflow/python/keras/_impl/keras/utils/layer_utils.py
index 4c8009dfd8..902972ecbb 100644
--- a/tensorflow/python/keras/_impl/keras/utils/layer_utils.py
+++ b/tensorflow/python/keras/_impl/keras/utils/layer_utils.py
@@ -35,7 +35,7 @@ def count_params(weights):
Returns:
The total number of scalars composing the weights
"""
- return int(np.sum([K.count_params(p) for p in set(weights)]))
+ return int(np.sum([np.prod(p.get_shape().as_list()) for p in set(weights)]))
def print_summary(model, line_length=None, positions=None, print_fn=None):
@@ -193,8 +193,7 @@ def print_summary(model, line_length=None, positions=None, print_fn=None):
else:
trainable_count = count_params(model.trainable_weights)
- non_trainable_count = int(
- np.sum([K.count_params(p) for p in set(model.non_trainable_weights)]))
+ non_trainable_count = count_params(model.non_trainable_weights)
print_fn('Total params: {:,}'.format(trainable_count + non_trainable_count))
print_fn('Trainable params: {:,}'.format(trainable_count))
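
A hedged sketch of the new count_params: multiply out each weight's static shape with np.prod instead of calling the backend per variable (the shapes stand for two small Dense layers):

import numpy as np

shapes = [(3, 4), (4,), (4, 2), (2,)]   # kernel/bias pairs, hypothetical
total = int(np.sum([np.prod(s) for s in shapes]))
assert total == 12 + 4 + 8 + 2 == 26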
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index ea210346c1..6c34ea1816 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -295,7 +295,6 @@ tf_py_test(
"//tensorflow/python:nn_grad",
],
data = ["//tensorflow/core:image_testdata"],
- tags = ["no_windows"],
)
tf_py_test(
@@ -1142,7 +1141,6 @@ tf_py_test(
"//tensorflow/python:variables",
],
data = ["//tensorflow/core:lmdb_testdata"],
- tags = ["no_windows"],
)
cuda_py_test(
@@ -2332,7 +2330,6 @@ cuda_py_test(
"//tensorflow/python:variables",
],
shard_count = 4,
- tags = ["no_windows"],
)
cuda_py_test(
@@ -2463,7 +2460,6 @@ cuda_py_test(
"//tensorflow/python/eager:context",
],
shard_count = 10,
- tags = ["no_windows"],
)
cuda_py_test(
@@ -2523,7 +2519,10 @@ cuda_py_test(
"//tensorflow/python:sparse_ops",
],
shard_count = 5,
- tags = ["noasan"],
+ tags = [
+ "noasan",
+ "optonly", # b/77589990
+ ],
)
cuda_py_test(
@@ -2726,6 +2725,7 @@ cuda_py_test(
],
data = ["//tensorflow/python/kernel_tests/testdata:self_adjoint_eig_op_test_files"],
shard_count = 20,
+ tags = ["no_windows"],
)
cuda_py_test(
diff --git a/tensorflow/python/kernel_tests/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops_test.py
index 64c1760d5e..5a20eebbc5 100644
--- a/tensorflow/python/kernel_tests/array_ops_test.py
+++ b/tensorflow/python/kernel_tests/array_ops_test.py
@@ -780,6 +780,14 @@ class StridedSliceGradTest(test_util.TensorFlowTestCase):
grad = GradSliceChecker(self, sess, var, np.array(8))
_ = grad[tuple()]
+ def testInt64Indices(self):
+ with self.test_session(use_gpu=True) as sess:
+ a = math_ops.range(3)
+ index = constant_op.constant(1, dtype=dtypes.int64)
+ b = 2 * a[index]
+ grad, = gradients_impl.gradients(b, a)
+ self.assertAllEqual(sess.run(grad), [0, 2, 0])
+
class StridedSliceGradTypeTest(test_util.TensorFlowTestCase):
"""Test varied index types and host located memory."""
@@ -999,30 +1007,38 @@ class SliceAssignTest(test_util.TensorFlowTestCase):
class ShapeSizeRankTest(test_util.TensorFlowTestCase):
+ @test_util.run_in_graph_and_eager_modes()
def testDenseShape(self):
- with self.test_session():
- t_value = [[0, 42], [24, 0]]
- self.assertAllEqual((2, 2), array_ops.shape(t_value).eval())
- self.assertEqual(4, array_ops.size(t_value).eval())
- self.assertEqual(2, array_ops.rank(t_value).eval())
+ t_value = [[0, 42], [24, 0]]
+ self.assertAllEqual((2, 2), self.evaluate(array_ops.shape(t_value)))
+ self.assertEqual(4, self.evaluate(array_ops.size(t_value)))
+ self.assertEqual(2, self.evaluate(array_ops.rank(t_value)))
- t = constant_op.constant(t_value)
- self.assertAllEqual((2, 2), array_ops.shape(t).eval())
- self.assertEqual(4, array_ops.size(t).eval())
- self.assertEqual(2, array_ops.rank(t).eval())
+ t = constant_op.constant(t_value)
+ self.assertAllEqual((2, 2), self.evaluate(array_ops.shape(t)))
+ self.assertEqual(4, self.evaluate(array_ops.size(t)))
+ self.assertEqual(2, self.evaluate(array_ops.rank(t)))
+ @test_util.run_in_graph_and_eager_modes()
def testSparseShape(self):
- with self.test_session():
- sp_value = sparse_tensor.SparseTensorValue(
- indices=((0, 1), (1, 0)), values=(42, 24), dense_shape=(2, 2))
- self.assertAllEqual((2, 2), array_ops.shape(sp_value).eval())
- self.assertEqual(4, array_ops.size(sp_value).eval())
- self.assertEqual(2, array_ops.rank(sp_value).eval())
-
- sp = sparse_tensor.SparseTensor.from_value(sp_value)
- self.assertAllEqual((2, 2), array_ops.shape(sp).eval())
- self.assertEqual(4, array_ops.size(sp).eval())
- self.assertEqual(2, array_ops.rank(sp).eval())
+ sp_value = sparse_tensor.SparseTensorValue(
+ indices=((0, 1), (1, 0)), values=(42, 24), dense_shape=(2, 2))
+ self.assertAllEqual((2, 2), self.evaluate(array_ops.shape(sp_value)))
+ self.assertEqual(4, self.evaluate(array_ops.size(sp_value)))
+ self.assertEqual(2, self.evaluate(array_ops.rank(sp_value)))
+
+ sp = sparse_tensor.SparseTensor.from_value(sp_value)
+ self.assertAllEqual((2, 2), self.evaluate(array_ops.shape(sp)))
+ self.assertEqual(4, self.evaluate(array_ops.size(sp)))
+ self.assertEqual(2, self.evaluate(array_ops.rank(sp)))
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testSizeDtype(self):
+ tensor = [1]
+ self.assertEqual(dtypes.int32, self.evaluate(array_ops.size(tensor)).dtype)
+ self.assertEqual(
+ dtypes.int64,
+ self.evaluate(array_ops.size(tensor, out_type=dtypes.int64)).dtype)
@test_util.with_c_api
diff --git a/tensorflow/python/kernel_tests/conv_ops_3d_test.py b/tensorflow/python/kernel_tests/conv_ops_3d_test.py
index ec8ac74163..f4616fd661 100644
--- a/tensorflow/python/kernel_tests/conv_ops_3d_test.py
+++ b/tensorflow/python/kernel_tests/conv_ops_3d_test.py
@@ -25,6 +25,7 @@ import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import nn_ops
@@ -344,6 +345,8 @@ class Conv3DTest(test.TestCase):
if data_format == "NCDHW":
conv = test_util.NCHWToNHWC(conv)
+ self.assertEqual(conv.shape, tensor_shape.TensorShape(output_shape))
+
if test_input:
jacob_t, jacob_n = gradient_checker.compute_gradient(
orig_input_tensor, input_shape, conv, output_shape)
diff --git a/tensorflow/python/kernel_tests/cwise_ops_test.py b/tensorflow/python/kernel_tests/cwise_ops_test.py
index 8db0bb6f0d..34e7751243 100644
--- a/tensorflow/python/kernel_tests/cwise_ops_test.py
+++ b/tensorflow/python/kernel_tests/cwise_ops_test.py
@@ -2165,5 +2165,47 @@ class AccumulateTest(test.TestCase):
math_ops.accumulate_n([a], tensor_dtype=np.int32)
+class PolyvalTest(test.TestCase):
+
+ def _runtest(self, dtype, degree):
+ x = np.random.rand(2, 2).astype(dtype)
+ coeffs = [np.random.rand(2, 2).astype(dtype) for _ in range(degree + 1)]
+ np_val = np.polyval(coeffs, x)
+ with self.test_session():
+ tf_val = math_ops.polyval(coeffs, x)
+ self.assertAllClose(np_val, tf_val.eval())
+
+ def testSimple(self):
+ for dtype in [
+ np.int32, np.float32, np.float64, np.complex64, np.complex128
+ ]:
+ for degree in range(5):
+ self._runtest(dtype, degree)
+
+ def testBroadcast(self):
+ dtype = np.float32
+ degree = 3
+ shapes = [(1,), (2, 1), (1, 2), (2, 2)]
+ for x_shape in shapes:
+ for coeff_shape in shapes:
+ x = np.random.rand(*x_shape).astype(dtype)
+ coeffs = [
+ np.random.rand(*coeff_shape).astype(dtype)
+ for _ in range(degree + 1)
+ ]
+ np_val = np.polyval(coeffs, x)
+ with self.test_session():
+ tf_val = math_ops.polyval(coeffs, x)
+ self.assertAllClose(np_val, tf_val.eval())
+
+ def testEmpty(self):
+ x = np.random.rand(2, 2).astype(np.float32)
+ coeffs = []
+ np_val = np.polyval(coeffs, x)
+ with self.test_session():
+ tf_val = math_ops.polyval(coeffs, x)
+ self.assertAllClose(np_val, tf_val.eval())
+
+
if __name__ == "__main__":
test.main()
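
A hedged sketch of what the PolyvalTest above exercises: Horner's rule over a coefficient list, matching np.polyval's convention of highest degree first:

import numpy as np

def polyval(coeffs, x):
    result = np.zeros_like(x)
    for c in coeffs:                 # Horner: result = result * x + c
        result = result * x + c
    return result

x = np.array([1., 2.])
coeffs = [2., 3., 4.]                # 2x^2 + 3x + 4
assert np.allclose(polyval(coeffs, x), np.polyval(coeffs, x))  # [9., 18.]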
diff --git a/tensorflow/python/kernel_tests/distributions/uniform_test.py b/tensorflow/python/kernel_tests/distributions/uniform_test.py
index df99a0ed25..a8def95b14 100644
--- a/tensorflow/python/kernel_tests/distributions/uniform_test.py
+++ b/tensorflow/python/kernel_tests/distributions/uniform_test.py
@@ -281,6 +281,22 @@ class UniformTest(test.TestCase):
expected_pdf = [1.0, 0.1]
self.assertAllClose(expected_pdf, pdf.eval())
+ def testUniformFloat64(self):
+ uniform = uniform_lib.Uniform(
+ low=np.float64(0.), high=np.float64(1.))
+
+ self.assertAllClose(
+ [1., 1.],
+ self.evaluate(uniform.prob(np.array([0.5, 0.6], dtype=np.float64))))
+
+ self.assertAllClose(
+ [0.5, 0.6],
+ self.evaluate(uniform.cdf(np.array([0.5, 0.6], dtype=np.float64))))
+
+ self.assertAllClose(0.5, self.evaluate(uniform.mean()))
+ self.assertAllClose(1 / 12., self.evaluate(uniform.variance()))
+ self.assertAllClose(0., self.evaluate(uniform.entropy()))
+
if __name__ == "__main__":
test.main()
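
A hedged check of the Uniform(low, high) constants asserted in testUniformFloat64 above: mean = (low + high) / 2, variance = (high - low)^2 / 12, entropy = log(high - low):

import numpy as np

low, high = 0., 1.
assert np.isclose((low + high) / 2., 0.5)            # mean
assert np.isclose((high - low) ** 2 / 12., 1. / 12.) # variance
assert np.isclose(np.log(high - low), 0.)            # entropy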
diff --git a/tensorflow/python/kernel_tests/functional_ops_test.py b/tensorflow/python/kernel_tests/functional_ops_test.py
index 1301ef9d19..34fb655035 100644
--- a/tensorflow/python/kernel_tests/functional_ops_test.py
+++ b/tensorflow/python/kernel_tests/functional_ops_test.py
@@ -24,6 +24,7 @@ from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
@@ -39,6 +40,7 @@ import tensorflow.python.ops.tensor_array_grad # pylint: disable=unused-import
from tensorflow.python.platform import test
+# pylint: disable=invalid-name
def simple_scoped_fn(a, x):
"""Simple function: (a, x) -> 2(x+a), but with "2" as a variable in scope."""
with variable_scope.variable_scope("body"):
@@ -158,6 +160,13 @@ class FunctionalOpsTest(test.TestCase):
values=constant_op.constant([0, 1, 2]),
dense_shape=[2, 2]))
+ @test_util.run_in_graph_and_eager_modes()
+ def testMapOverScalarErrors(self):
+ with self.assertRaisesRegexp(ValueError, "not scalars"):
+ functional_ops.map_fn(lambda x: x, [1, 2])
+ with self.assertRaisesRegexp(ValueError, "not a scalar"):
+ functional_ops.map_fn(lambda x: x, 1)
+
def testMap_Scoped(self):
with self.test_session() as sess:
@@ -607,6 +616,276 @@ class FunctionalOpsTest(test.TestCase):
mul = sess.run(remote_op)
self.assertEqual(mul, 9)
+ def testIf(self):
+
+ @function.Defun(dtypes.float32)
+ def Twice(x):
+ return x * 2
+
+ @function.Defun(dtypes.float32)
+ def Thrice(x):
+ return x * 3 + 1
+
+ with self.test_session(use_gpu=False) as sess:
+
+ def Run(x):
+ return sess.run(
+ functional_ops.If(math_ops.greater(x, 0), [x], Twice, Thrice))[0]
+
+ self.assertAllEqual(Run(9.), 18.)
+ self.assertAllEqual(Run(-8.), -23.)
+ self.assertAllEqual(Run(0.), 1.)
+
+ def testWhile(self):
+
+ @function.Defun(*[dtypes.float32] * 2)
+ def Cond(n, unused_x):
+ return n > 0
+
+ @function.Defun(*[dtypes.float32] * 2)
+ def Body(n, x):
+ return n - 1, x + n
+
+ # TODO(b/65752372): Set `use_gpu=False` because
+ # `functional_ops.While()` does not reliably work on GPU (apparently
+ # because the result of evaluating the condition may be in device
+ # memory, but it is read on the host).
+ with self.test_session(use_gpu=False) as sess:
+
+ def Run(n):
+ return sess.run(functional_ops.While([n, 0.], Cond, Body))[1]
+
+ self.assertAllEqual(Run(20.), 210.)
+ self.assertAllEqual(Run(100.), 5050.)
+
+ def testWhileError(self):
+
+ @function.Defun(*[dtypes.float32] * 2)
+ def Cond(n, unused_x):
+ return n > 0
+
+ @function.Defun(*[dtypes.float32] * 2)
+ def CondReturnsTooManyArgs(n, x):
+ return n > 0, x
+
+ @function.Defun(*[dtypes.float32] * 2)
+ def Body(n, x):
+ return n - 1, x + n
+
+ @function.Defun(*[dtypes.float32] * 2)
+ def BodyReturnsTooManyArgs(n, x):
+ return n - 1, x + n, x
+
+ # TODO(b/65752372): Set `use_gpu=False` because
+ # `functional_ops.While()` does not reliably work on GPU (apparently
+ # because the result of evaluating the condition may be in device
+ # memory, but it is read on the host).
+ with self.test_session(use_gpu=False):
+ with self.assertRaisesRegexp(errors.InvalidArgumentError,
+ "Expected a single scalar.*got 2 tensors."):
+ functional_ops.While([5., 0.], CondReturnsTooManyArgs, Body)[0].eval()
+ with self.assertRaisesRegexp(
+ errors.InvalidArgumentError,
+ "While loop body returned 3 arguments. Expected: 2"):
+ functional_ops.While([5., 0.], Cond, BodyReturnsTooManyArgs)[0].eval()
+
+ def testWhileInMultipleSubgraphs(self):
+
+ @function.Defun(* [dtypes.float32] * 2)
+ def Cond(n, x): # pylint: disable=unused-argument
+ return n > 0
+
+ @function.Defun(* [dtypes.float32] * 2)
+ def Body(n, x):
+ return n - 1, x + n
+
+ # TODO(b/65752372): Set `use_gpu=False` because
+ # `functional_ops.While()` does not reliably work on GPU (apparently
+ # because the result of evaluating the condition may be in device
+ # memory, but it is read on the host).
+ with self.test_session(use_gpu=False) as sess:
+ n = array_ops.placeholder(dtypes.float32)
+ _, result = functional_ops.While([n, 0.], Cond, Body)
+ c = constant_op.constant(37.)
+
+ self.assertAllEqual(210., sess.run(result, feed_dict={n: 20.}))
+ self.assertAllEqual(5050., sess.run(result, feed_dict={n: 100.}))
+ # Test that the result is the same when we run a different subgraph.
+ self.assertAllEqual(5050., sess.run([result, c], feed_dict={n: 100.})[0])
+
+ def _tfSum(self, rewrite_with_while):
+ # On GPU, don't rewrite using a while loop.
+ use_gpu = not rewrite_with_while
+ with self.test_session(use_gpu=use_gpu) as sess:
+
+ @function.Defun(dtypes.int32, dtypes.float32)
+ def Body(n, x):
+ return x + math_ops.to_float(n)
+
+ xs = [
+ # 1 + 2 + ... + 20
+ functional_ops.For(
+ 1, 21, 1, [0.], Body, rewrite_with_while=rewrite_with_while)[0],
+ # 100 + 99 + ... + 1
+ functional_ops.For(
+ 100, 0, -1, [0.], Body, rewrite_with_while=rewrite_with_while)[0],
+ ]
+ xvals = sess.run(xs)
+ self.assertAllEqual(210, xvals[0])
+ self.assertAllEqual(5050, xvals[1])
+
+ def testFor(self):
+ self._tfSum(False)
+
+ def testForWithWhile(self):
+ self._tfSum(True)
+
+ def testForWithWhileNaming(self):
+ g = ops.Graph()
+ with g.as_default():
+
+ @function.Defun(dtypes.int32, dtypes.float32, func_name="TestBody")
+ def TestBody(n, x):
+ return x + math_ops.to_float(n)
+
+ _ = functional_ops.For(
+ 1, 21, 1, [0.], TestBody, rewrite_with_while=True)[0]
+
+ names = []
+ for func in g.as_graph_def().library.function:
+ names.append(func.signature.name)
+ self.assertTrue("TestBody" in names)
+ self.assertTrue("TestBody_Cond" in names)
+ self.assertTrue("TestBody_Body" in names)
+
+ def testForCapturedInputs(self):
+ v = variables.Variable(1.0)
+
+ @function.Defun(dtypes.int32)
+ def TestNullary(n):
+ v + math_ops.to_float(n) # pylint: disable=expression-not-assigned
+
+ @function.Defun(dtypes.int32, dtypes.float32)
+ def TestUnary(n, x):
+ return x + math_ops.to_float(n) + v
+
+ @function.Defun(dtypes.int32, dtypes.float32, dtypes.float32)
+ def TestBinary(n, x, x2):
+ return x + math_ops.to_float(n) + v, x2 + v
+
+ for rewrite_with_while in (True, False):
+ # TODO(b/65752372): Set `use_gpu=False` because
+ # `functional_ops.While()` does not reliably work on GPU (apparently
+ # because the result of evaluating the condition may be in device
+ # memory, but it is read on the host).
+ use_gpu = not rewrite_with_while
+ with self.test_session(use_gpu=use_gpu) as sess:
+ result_nullary = functional_ops.For(
+ 1, 10, 1, [], TestNullary,
+ rewrite_with_while=rewrite_with_while)
+ result_unary = functional_ops.For(
+ 1, 10, 1, [0.], TestUnary,
+ rewrite_with_while=rewrite_with_while)
+ result_binary = functional_ops.For(
+ 1, 10, 1, [0., 0.], TestBinary,
+ rewrite_with_while=rewrite_with_while)
+ sess.run(variables.global_variables_initializer())
+ assert not result_nullary
+ # The nullary variant doesn't return anything so we can't easily run it.
+ # As a total hack, fetch the operation by name and run it.
+ sess.run(ops.get_default_graph().get_operation_by_name(
+ "While" if rewrite_with_while else "For"))
+ assert len(result_unary) == 1
+ self.assertEqual([54.0], sess.run(result_unary))
+ assert len(result_binary) == 2
+ self.assertEqual([54.0, 9.0], sess.run(result_binary))
+
+ def _tfMLP(self, xval, wsval, bsval, rewrite_with_while):
+ # On GPU, don't rewrite using a while loop.
+ use_gpu = not rewrite_with_while
+ with self.test_session(use_gpu=use_gpu):
+
+ @function.Defun(dtypes.int32, *[dtypes.float64] * 3)
+ def MLP(i, a, ws, bs):
+ a = math_ops.tanh(math_ops.matmul(a, ws[i, :]) + bs[i, :])
+ return a, ws, bs
+
+ ret = functional_ops.For(
+ 0,
+ wsval.shape[0],
+ 1, [xval, wsval, bsval],
+ MLP,
+ rewrite_with_while=rewrite_with_while)[0]
+
+ return ret.eval()
+
+ def _npMLP(self, xval, wsval, bsval):
+ for i in range(wsval.shape[0]):
+ xval = np.tanh(np.dot(xval, wsval[i, :]) + bsval[i, :])
+ return xval
+
+ def _testForMLP(self, rewrite_with_while):
+ # We construct a 5-layer Multi-Layer Perceptron network here.
+ # Each layer has the same number of hidden units (3), and the
+ # activation function is tanh(). We feed the input (xval) with
+ # batch size 2.
+ xval = np.random.normal(size=(2, 3))
+ wsval = np.random.normal(size=(5, 3, 3))
+ bsval = np.random.normal(size=(5, 3))
+ np_ans = self._npMLP(xval, wsval, bsval)
+ tf_for_ans = self._tfMLP(xval, wsval, bsval, rewrite_with_while)
+ self.assertAllClose(np_ans, tf_for_ans)
+
+ def testForMLP(self):
+ self._testForMLP(False)
+
+ def testForMLPWhile(self):
+ self._testForMLP(True)
+
+ def testForError(self):
+
+ @function.Defun(dtypes.int32, dtypes.float32)
+ def Foo(i, v):
+ return math_ops.to_float(i) + v
+
+ @function.Defun(dtypes.int32, dtypes.float32)
+ def ReturnsTooManyArgs(unused_i, v):
+ return v, v
+
+ with self.test_session(use_gpu=True):
+ with self.assertRaisesRegexp(errors.InvalidArgumentError,
+ "must be a scalar"):
+ functional_ops.For([0], 10, 1, [0.0], Foo)[0].eval()
+ with self.assertRaisesRegexp(errors.InvalidArgumentError,
+ "Invalid start/limit/delta"):
+ functional_ops.For(0, 10, -1, [0.0], Foo)[0].eval()
+ with self.assertRaisesRegexp(
+ errors.InvalidArgumentError,
+ "For loop body returned 2 arguments. Expected: 1"):
+ functional_ops.For(0, 10, 1, [0.0], ReturnsTooManyArgs)[0].eval()
+
+ def testGradient(self):
+
+ @function.Defun(dtypes.float32)
+ def Poly(x):
+ # y = 2x^3+3x^2+4x+8
+ return 2 * x * x * x + 3 * x * x + 4 * x + 8
+
+ @function.Defun(dtypes.float32)
+ def Grad(x):
+ # dy/dx = dy/dy * dy/dx = 1.0 * (6x^2+6x+4)
+ return functional_ops.Gradient([x, 1.0], Poly)[0]
+
+ with self.test_session(use_gpu=False) as sess:
+ a = constant_op.constant(0.)
+ avals = [Poly(a), Grad(a)]
+ b = constant_op.constant(1.)
+ bvals = [Poly(b), Grad(b)]
+ self.assertAllEqual(sess.run(avals), [8., 4.])
+ self.assertAllEqual(sess.run(bvals), [17., 16.])
+
if __name__ == "__main__":
test.main()
+
+# pylint: enable=invalid-name
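
A hedged arithmetic check of the expectations in testGradient above, y = 2x^3 + 3x^2 + 4x + 8 with dy/dx = 6x^2 + 6x + 4:

def poly(x):
    return 2 * x**3 + 3 * x**2 + 4 * x + 8

def grad(x):
    return 6 * x**2 + 6 * x + 4

assert (poly(0.), grad(0.)) == (8., 4.)    # matches sess.run(avals)
assert (poly(1.), grad(1.)) == (17., 16.)  # matches sess.run(bvals)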
diff --git a/tensorflow/python/kernel_tests/large_concat_op_test.py b/tensorflow/python/kernel_tests/large_concat_op_test.py
index 66afb6ec01..184d1dde2a 100644
--- a/tensorflow/python/kernel_tests/large_concat_op_test.py
+++ b/tensorflow/python/kernel_tests/large_concat_op_test.py
@@ -19,10 +19,12 @@ from __future__ import print_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
+@test_util.with_c_api
class LargeConcatOpTest(test.TestCase):
"""Tests that belong in concat_op_test.py, but run over large tensors."""
diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_util_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_util_test.py
index e1edffc3d9..7b291e29de 100644
--- a/tensorflow/python/kernel_tests/linalg/linear_operator_util_test.py
+++ b/tensorflow/python/kernel_tests/linalg/linear_operator_util_test.py
@@ -23,6 +23,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.linalg import linear_operator_util
from tensorflow.python.platform import test
@@ -94,8 +95,8 @@ class AssertNoEntriesWithModulusZeroTest(test.TestCase):
class BroadcastMatrixBatchDimsTest(test.TestCase):
def test_zero_batch_matrices_returned_as_empty_list(self):
- self.assertAllEqual(
- [], linear_operator_util.broadcast_matrix_batch_dims([]))
+ self.assertAllEqual([],
+ linear_operator_util.broadcast_matrix_batch_dims([]))
def test_one_batch_matrix_returned_after_tensor_conversion(self):
arr = rng.rand(2, 3, 4)
@@ -194,6 +195,44 @@ class BroadcastMatrixBatchDimsTest(test.TestCase):
linear_operator_util.broadcast_matrix_batch_dims([y, x])
+class CholeskySolveWithBroadcastTest(test.TestCase):
+
+ def test_static_dims_broadcast(self):
+ # batch_shape = [2]
+ chol = rng.rand(3, 3)
+ rhs = rng.rand(2, 3, 7)
+ chol_broadcast = chol + np.zeros((2, 1, 1))
+
+ with self.test_session():
+ result = linear_operator_util.cholesky_solve_with_broadcast(chol, rhs)
+ self.assertAllEqual((2, 3, 7), result.get_shape())
+ expected = linalg_ops.cholesky_solve(chol_broadcast, rhs)
+ self.assertAllEqual(expected.eval(), result.eval())
+
+ def test_dynamic_dims_broadcast_64bit(self):
+ # batch_shape = [2, 2]
+ chol = rng.rand(2, 3, 3)
+ rhs = rng.rand(2, 1, 3, 7)
+ chol_broadcast = chol + np.zeros((2, 2, 1, 1))
+ rhs_broadcast = rhs + np.zeros((2, 2, 1, 1))
+
+ chol_ph = array_ops.placeholder(dtypes.float64)
+ rhs_ph = array_ops.placeholder(dtypes.float64)
+
+ with self.test_session() as sess:
+ result, expected = sess.run(
+ [
+ linear_operator_util.cholesky_solve_with_broadcast(
+ chol_ph, rhs_ph),
+ linalg_ops.cholesky_solve(chol_broadcast, rhs_broadcast)
+ ],
+ feed_dict={
+ chol_ph: chol,
+ rhs_ph: rhs,
+ })
+ self.assertAllEqual(expected, result)
+
+
class MatmulWithBroadcastTest(test.TestCase):
def test_static_dims_broadcast(self):
@@ -209,7 +248,7 @@ class MatmulWithBroadcastTest(test.TestCase):
expected = math_ops.matmul(x, y_broadcast)
self.assertAllEqual(expected.eval(), result.eval())
- def test_dynamic_dims_broadcast_32bit(self):
+ def test_dynamic_dims_broadcast_64bit(self):
# batch_shape = [2]
# for each batch member, we have a 1x3 matrix times a 3x7 matrix ==> 1x7
x = rng.rand(2, 1, 3)
@@ -221,9 +260,90 @@ class MatmulWithBroadcastTest(test.TestCase):
with self.test_session() as sess:
result, expected = sess.run(
- [linear_operator_util.matmul_with_broadcast(x_ph, y_ph),
- math_ops.matmul(x, y_broadcast)],
- feed_dict={x_ph: x, y_ph: y})
+ [
+ linear_operator_util.matmul_with_broadcast(x_ph, y_ph),
+ math_ops.matmul(x, y_broadcast)
+ ],
+ feed_dict={
+ x_ph: x,
+ y_ph: y
+ })
+ self.assertAllEqual(expected, result)
+
+
+class MatrixSolveWithBroadcastTest(test.TestCase):
+
+ def test_static_dims_broadcast(self):
+ # batch_shape = [2]
+ matrix = rng.rand(3, 3)
+ rhs = rng.rand(2, 3, 7)
+ matrix_broadcast = matrix + np.zeros((2, 1, 1))
+
+ with self.test_session():
+ result = linear_operator_util.matrix_solve_with_broadcast(matrix, rhs)
+ self.assertAllEqual((2, 3, 7), result.get_shape())
+ expected = linalg_ops.matrix_solve(matrix_broadcast, rhs)
+ self.assertAllEqual(expected.eval(), result.eval())
+
+ def test_dynamic_dims_broadcast_64bit(self):
+ # batch_shape = [2, 2]
+ matrix = rng.rand(2, 3, 3)
+ rhs = rng.rand(2, 1, 3, 7)
+ matrix_broadcast = matrix + np.zeros((2, 2, 1, 1))
+ rhs_broadcast = rhs + np.zeros((2, 2, 1, 1))
+
+ matrix_ph = array_ops.placeholder(dtypes.float64)
+ rhs_ph = array_ops.placeholder(dtypes.float64)
+
+ with self.test_session() as sess:
+ result, expected = sess.run(
+ [
+ linear_operator_util.matrix_solve_with_broadcast(
+ matrix_ph, rhs_ph),
+ linalg_ops.matrix_solve(matrix_broadcast, rhs_broadcast)
+ ],
+ feed_dict={
+ matrix_ph: matrix,
+ rhs_ph: rhs,
+ })
+ self.assertAllEqual(expected, result)
+
+
+class MatrixTriangularSolveWithBroadcastTest(test.TestCase):
+
+ def test_static_dims_broadcast(self):
+ # batch_shape = [2]
+ matrix = rng.rand(2, 3, 3)
+ rhs = rng.rand(3, 7)
+ rhs_broadcast = rhs + np.zeros((2, 1, 1))
+
+ with self.test_session():
+ result = linear_operator_util.matrix_triangular_solve_with_broadcast(
+ matrix, rhs)
+ self.assertAllEqual((2, 3, 7), result.get_shape())
+ expected = linalg_ops.matrix_triangular_solve(matrix, rhs_broadcast)
+ self.assertAllEqual(expected.eval(), result.eval())
+
+ def test_dynamic_dims_broadcast_64bit(self):
+ # batch_shape = [2]
+ matrix = rng.rand(2, 3, 3)
+ rhs = rng.rand(3, 7)
+ rhs_broadcast = rhs + np.zeros((2, 1, 1))
+
+ matrix_ph = array_ops.placeholder(dtypes.float64)
+ rhs_ph = array_ops.placeholder(dtypes.float64)
+
+ with self.test_session() as sess:
+ result, expected = sess.run(
+ [
+ linear_operator_util.matrix_triangular_solve_with_broadcast(
+ matrix_ph, rhs_ph),
+ linalg_ops.matrix_triangular_solve(matrix, rhs_broadcast)
+ ],
+ feed_dict={
+ matrix_ph: matrix,
+ rhs_ph: rhs,
+ })
self.assertAllEqual(expected, result)
@@ -244,7 +364,7 @@ class AssertCompatibleMatrixDimensionsTest(test.TestCase):
operator = DomainDimensionStubOperator(3)
# Should not raise
linear_operator_util.assert_compatible_matrix_dimensions(
- operator, x).run()
+ operator, x).run() # pyformat: disable
def test_incompatible_dimensions_raise(self):
with self.test_session():
@@ -252,7 +372,7 @@ class AssertCompatibleMatrixDimensionsTest(test.TestCase):
operator = DomainDimensionStubOperator(3)
with self.assertRaisesOpError("Incompatible matrix dimensions"):
linear_operator_util.assert_compatible_matrix_dimensions(
- operator, x).run()
+ operator, x).run() # pyformat: disable
if __name__ == "__main__":
diff --git a/tensorflow/python/kernel_tests/metrics_test.py b/tensorflow/python/kernel_tests/metrics_test.py
index ad802f7e1f..55653489af 100644
--- a/tensorflow/python/kernel_tests/metrics_test.py
+++ b/tensorflow/python/kernel_tests/metrics_test.py
@@ -1124,40 +1124,91 @@ class AUCTest(test.TestCase):
self.assertAlmostEqual(0.7, auc.eval(), 5)
- def testAUCPRSpecialCase(self):
+ # Regarding the AUC-PR tests: note that the preferred method when
+ # calculating AUC-PR is summation_method='careful_interpolation'.
+ def testCorrectAUCPRSpecialCase(self):
with self.test_session() as sess:
predictions = constant_op.constant(
[0.1, 0.4, 0.35, 0.8], shape=(1, 4), dtype=dtypes_lib.float32)
labels = constant_op.constant([0, 0, 1, 1], shape=(1, 4))
- auc, update_op = metrics.auc(labels, predictions, curve='PR')
+ auc, update_op = metrics.auc(labels, predictions, curve='PR',
+ summation_method='careful_interpolation')
+
+ sess.run(variables.local_variables_initializer())
+ # expected ~= 0.79726744594
+ expected = 1 - math.log(1.5) / 2
+ self.assertAlmostEqual(expected, sess.run(update_op), delta=1e-3)
+ self.assertAlmostEqual(expected, auc.eval(), delta=1e-3)
+
+ def testCorrectAnotherAUCPRSpecialCase(self):
+ with self.test_session() as sess:
+ predictions = constant_op.constant(
+ [0.1, 0.4, 0.35, 0.8, 0.1, 0.135, 0.81],
+ shape=(1, 7),
+ dtype=dtypes_lib.float32)
+ labels = constant_op.constant([0, 0, 1, 0, 1, 0, 1], shape=(1, 7))
+ auc, update_op = metrics.auc(labels, predictions, curve='PR',
+ summation_method='careful_interpolation')
+
+ sess.run(variables.local_variables_initializer())
+ # expected ~= 0.61350593198
+ expected = (2.5 - 2 * math.log(4./3) - 0.25 * math.log(7./5)) / 3
+ self.assertAlmostEqual(expected, sess.run(update_op), delta=1e-3)
+ self.assertAlmostEqual(expected, auc.eval(), delta=1e-3)
+
+ def testThirdCorrectAUCPRSpecialCase(self):
+ with self.test_session() as sess:
+ predictions = constant_op.constant(
+ [0.0, 0.1, 0.2, 0.33, 0.3, 0.4, 0.5],
+ shape=(1, 7),
+ dtype=dtypes_lib.float32)
+ labels = constant_op.constant([0, 0, 0, 0, 1, 1, 1], shape=(1, 7))
+ auc, update_op = metrics.auc(labels, predictions, curve='PR',
+ summation_method='careful_interpolation')
+
+ sess.run(variables.local_variables_initializer())
+ # expected ~= 0.90410597584
+ expected = 1 - math.log(4./3) / 3
+ self.assertAlmostEqual(expected, sess.run(update_op), delta=1e-3)
+ self.assertAlmostEqual(expected, auc.eval(), delta=1e-3)
+
+ def testIncorrectAUCPRSpecialCase(self):
+ with self.test_session() as sess:
+ predictions = constant_op.constant(
+ [0.1, 0.4, 0.35, 0.8], shape=(1, 4), dtype=dtypes_lib.float32)
+ labels = constant_op.constant([0, 0, 1, 1], shape=(1, 4))
+ auc, update_op = metrics.auc(labels, predictions, curve='PR',
+ summation_method='trapezoidal')
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(0.79166, sess.run(update_op), delta=1e-3)
self.assertAlmostEqual(0.79166, auc.eval(), delta=1e-3)
- def testAnotherAUCPRSpecialCase(self):
+ def testAnotherIncorrectAUCPRSpecialCase(self):
with self.test_session() as sess:
predictions = constant_op.constant(
[0.1, 0.4, 0.35, 0.8, 0.1, 0.135, 0.81],
shape=(1, 7),
dtype=dtypes_lib.float32)
labels = constant_op.constant([0, 0, 1, 0, 1, 0, 1], shape=(1, 7))
- auc, update_op = metrics.auc(labels, predictions, curve='PR')
+ auc, update_op = metrics.auc(labels, predictions, curve='PR',
+ summation_method='trapezoidal')
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(0.610317, sess.run(update_op), delta=1e-3)
self.assertAlmostEqual(0.610317, auc.eval(), delta=1e-3)
- def testThirdAUCPRSpecialCase(self):
+ def testThirdIncorrectAUCPRSpecialCase(self):
with self.test_session() as sess:
predictions = constant_op.constant(
[0.0, 0.1, 0.2, 0.33, 0.3, 0.4, 0.5],
shape=(1, 7),
dtype=dtypes_lib.float32)
labels = constant_op.constant([0, 0, 0, 0, 1, 1, 1], shape=(1, 7))
- auc, update_op = metrics.auc(labels, predictions, curve='PR')
+ auc, update_op = metrics.auc(labels, predictions, curve='PR',
+ summation_method='trapezoidal')
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(0.90277, sess.run(update_op), delta=1e-3)
diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
index c31d5a1f91..edc63264a3 100644
--- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py
+++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
@@ -802,6 +802,12 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
state_ops.scatter_update(v, [1], [3.0])
self.assertAllEqual([1.0, 3.0], v.numpy())
+ def testScatterAddStateOps(self):
+ with context.eager_mode():
+ v = resource_variable_ops.ResourceVariable([1.0, 2.0], name="add")
+ state_ops.scatter_add(v, [1], [3])
+ self.assertAllEqual([1.0, 5.0], v.numpy())
+
def testScatterUpdateCast(self):
with context.eager_mode():
v = resource_variable_ops.ResourceVariable([1.0, 2.0], name="update")
diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py
index 242cdff6f3..ec741d3265 100644
--- a/tensorflow/python/layers/base.py
+++ b/tensorflow/python/layers/base.py
@@ -694,7 +694,8 @@ class Layer(checkpointable.CheckpointableBase):
self._dtype = input_list[0].dtype.base_dtype.name
except AttributeError:
pass
- input_shapes = nest.map_structure(lambda x: x.get_shape(), inputs)
+      input_shapes = None
+      if all(hasattr(x, 'get_shape') for x in input_list):
+        input_shapes = nest.map_structure(lambda x: x.get_shape(), inputs)
self.build(input_shapes)
try:
# Note: not all sub-classes of Layer call Layer.__init__ (especially
diff --git a/tensorflow/python/lib/core/py_exception_registry.cc b/tensorflow/python/lib/core/py_exception_registry.cc
new file mode 100644
index 0000000000..6637de632b
--- /dev/null
+++ b/tensorflow/python/lib/core/py_exception_registry.cc
@@ -0,0 +1,50 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/python/lib/core/py_exception_registry.h"
+
+#include <Python.h>
+
+namespace tensorflow {
+
+PyExceptionRegistry* PyExceptionRegistry::singleton_ = nullptr;
+
+void PyExceptionRegistry::Init(PyObject* code_to_exc_type_map) {
+ DCHECK(singleton_ == nullptr) << "PyExceptionRegistry::Init() already called";
+ singleton_ = new PyExceptionRegistry;
+
+ DCHECK(PyDict_Check(code_to_exc_type_map));
+ PyObject* key;
+ PyObject* value;
+ Py_ssize_t pos = 0;
+ while (PyDict_Next(code_to_exc_type_map, &pos, &key, &value)) {
+ TF_Code code = static_cast<TF_Code>(PyLong_AsLong(key));
+ singleton_->exc_types_[code] = value;
+ // The exception classes should also have the lifetime of the process, but
+ // incref just in case.
+ Py_INCREF(value);
+ }
+}
+
+PyObject* PyExceptionRegistry::Lookup(TF_Code code) {
+ DCHECK(singleton_ != nullptr) << "Must call PyExceptionRegistry::Init() "
+ "before PyExceptionRegistry::Lookup()";
+ DCHECK_NE(code, TF_OK);
+ DCHECK(singleton_->exc_types_.find(code) != singleton_->exc_types_.end())
+ << "Unknown error code passed to PyExceptionRegistry::Lookup: " << code;
+ return singleton_->exc_types_[code];
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/python/lib/core/py_exception_registry.h b/tensorflow/python/lib/core/py_exception_registry.h
new file mode 100644
index 0000000000..2b0f23b548
--- /dev/null
+++ b/tensorflow/python/lib/core/py_exception_registry.h
@@ -0,0 +1,73 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_PYTHON_LIB_CORE_PY_EXCEPTION_REGISTRY_H_
+#define TENSORFLOW_PYTHON_LIB_CORE_PY_EXCEPTION_REGISTRY_H_
+
+#include <map>
+
+#include "tensorflow/c/c_api.h"
+#include "tensorflow/core/platform/logging.h"
+
+#ifndef PyObject_HEAD
+struct _object;
+typedef _object PyObject;
+#endif
+
+namespace tensorflow {
+
+// Global registry mapping C API error codes to the corresponding custom Python
+// exception type. This is used to expose the exception types to C extension
+// code (i.e. so we can raise custom exceptions via SWIG).
+//
+// Init() must be called exactly once at the beginning of the process before
+// Lookup() can be used.
+//
+// Example usage:
+// TF_Status* status = TF_NewStatus();
+// TF_Foo(..., status);
+//
+// if (TF_GetCode(status) != TF_OK) {
+// PyObject* exc_type = PyExceptionRegistry::Lookup(TF_GetCode(status));
+// // Arguments to OpError base class. Set `node_def` and `op` to None.
+// PyObject* args =
+// Py_BuildValue("sss", nullptr, nullptr, TF_Message(status));
+// PyErr_SetObject(exc_type, args);
+// Py_DECREF(args);
+// TF_DeleteStatus(status);
+// return NULL;
+// }
+class PyExceptionRegistry {
+ public:
+ // Initializes the process-wide registry. Should be called exactly once near
+  // the beginning of the process. `code_to_exc_type_map` is a Python dict
+  // mapping each TF_Code value to the corresponding Python exception type
+  // (e.g. TF_CANCELLED maps to errors.CancelledError).
+ static void Init(PyObject* code_to_exc_type_map);
+
+ // Returns the Python exception type corresponding to `code`. Init() must be
+ // called before using this function. `code` should not be TF_OK.
+ static PyObject* Lookup(TF_Code code);
+
+ private:
+ static PyExceptionRegistry* singleton_;
+ PyExceptionRegistry() = default;
+
+ // Maps error codes to the corresponding Python exception type.
+ std::map<TF_Code, PyObject*> exc_types_;
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_PYTHON_LIB_CORE_PY_EXCEPTION_REGISTRY_H_
diff --git a/tensorflow/python/lib/core/py_exception_registry.i b/tensorflow/python/lib/core/py_exception_registry.i
new file mode 100644
index 0000000000..e872b74985
--- /dev/null
+++ b/tensorflow/python/lib/core/py_exception_registry.i
@@ -0,0 +1,28 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+%include "tensorflow/python/platform/base.i"
+
+%{
+#include "tensorflow/python/lib/core/py_exception_registry.h"
+%}
+
+%ignoreall
+
+%unignore tensorflow::PyExceptionRegistry;
+%unignore tensorflow::PyExceptionRegistry::Init;
+
+%include "tensorflow/python/lib/core/py_exception_registry.h"
+%unignoreall
diff --git a/tensorflow/python/lib/core/py_seq_tensor.cc b/tensorflow/python/lib/core/py_seq_tensor.cc
index 8247d354db..32ea737a99 100644
--- a/tensorflow/python/lib/core/py_seq_tensor.cc
+++ b/tensorflow/python/lib/core/py_seq_tensor.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/python/lib/core/numpy.h"
#include "tensorflow/python/lib/core/py_util.h"
@@ -77,9 +78,9 @@ string PyRepr(PyObject* obj) {
bool IsPyDimension(PyObject* obj) {
const char* tp_name = obj->ob_type->tp_name;
if (strcmp(tp_name, "Dimension") != 0) return false;
- bool ret =
- StringPiece(PyRepr(PyType(obj)))
- .ends_with("tensorflow.python.framework.tensor_shape.Dimension'>");
+ bool ret = str_util::EndsWith(
+ PyRepr(PyType(obj)),
+ "tensorflow.python.framework.tensor_shape.Dimension'>");
return ret;
}
diff --git a/tensorflow/python/lib/io/tf_record.py b/tensorflow/python/lib/io/tf_record.py
index 6fcf9c91d8..bf2d6f68b5 100644
--- a/tensorflow/python/lib/io/tf_record.py
+++ b/tensorflow/python/lib/io/tf_record.py
@@ -78,8 +78,7 @@ def tf_record_iterator(path, options=None):
try:
while True:
try:
- with errors.raise_exception_on_not_ok_status() as status:
- reader.GetNext(status)
+ reader.GetNext()
except errors.OutOfRangeError:
break
yield reader.record()
diff --git a/tensorflow/python/ops/array_grad.py b/tensorflow/python/ops/array_grad.py
index 3c6a5c9e56..57d2657838 100644
--- a/tensorflow/python/ops/array_grad.py
+++ b/tensorflow/python/ops/array_grad.py
@@ -255,10 +255,15 @@ def _SliceGrad(op, grad):
@ops.RegisterGradient("StridedSlice")
def _StridedSliceGrad(op, grad):
"""Gradient for StridedSlice op."""
- x = array_ops.shape(op.inputs[0])
begin = op.inputs[1]
end = op.inputs[2]
strides = op.inputs[3]
+ # StridedSliceGrad requires `x`, `begin`, `end` and `strides` to be of the
+ # same dtype so we build a shape of the same type as other args.
+ # Note that the choice of `begin` for specifying `out_type` is arbitrary.
+ # We could choose any of {begin|end|strides}.dtype since they are required to
+ # be the same.
+ x = array_ops.shape(op.inputs[0], out_type=begin.dtype)
return array_ops.strided_slice_grad(
x,
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index 207866610b..68d446602e 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -387,7 +387,10 @@ def size_internal(input, name=None, optimize=True, out_type=dtypes.int32):
"""
if context.executing_eagerly() and not isinstance(
input, (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)):
- return np.prod(ops.convert_to_tensor(input)._shape_tuple()) # pylint: disable=protected-access
+ input = ops.convert_to_tensor(input)
+ np_out_type = out_type.as_numpy_dtype
+    num_elements = np.prod(input._shape_tuple(), dtype=np_out_type)  # pylint: disable=protected-access
+ return ops.convert_to_tensor(num_elements, dtype=out_type)
with ops.name_scope(name, "Size", [input]) as name:
if isinstance(input, (sparse_tensor.SparseTensor,
sparse_tensor.SparseTensorValue)):
diff --git a/tensorflow/python/ops/distributions/uniform.py b/tensorflow/python/ops/distributions/uniform.py
index ec623b55eb..0891bffdd5 100644
--- a/tensorflow/python/ops/distributions/uniform.py
+++ b/tensorflow/python/ops/distributions/uniform.py
@@ -166,7 +166,8 @@ class Uniform(distribution.Distribution):
return self.low + self.range() * samples
def _prob(self, x):
- broadcasted_x = x * array_ops.ones(self.batch_shape_tensor())
+ broadcasted_x = x * array_ops.ones(
+ self.batch_shape_tensor(), dtype=x.dtype)
return array_ops.where(
math_ops.is_nan(broadcasted_x),
broadcasted_x,
diff --git a/tensorflow/python/ops/functional_ops.py b/tensorflow/python/ops/functional_ops.py
index a840b1eddf..161f6f3659 100644
--- a/tensorflow/python/ops/functional_ops.py
+++ b/tensorflow/python/ops/functional_ops.py
@@ -1,4 +1,4 @@
-# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -27,22 +27,24 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import gen_functional_ops
+from tensorflow.python.ops import math_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variable_scope as vs
-# go/tf-wildcard-import
-# pylint: disable=wildcard-import
-from tensorflow.python.ops.gen_functional_ops import *
-# pylint: enable=wildcard-import
# pylint: disable=unused-import
-from tensorflow.python.ops.gen_functional_ops import symbolic_gradient
+from tensorflow.python.ops.gen_functional_ops import remote_call
# pylint: enable=unused-import
+from tensorflow.python.ops.gen_functional_ops import symbolic_gradient
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export
@@ -365,7 +367,15 @@ def map_fn(fn, elems, dtype=None, parallel_iterations=10, back_prop=True,
dtype_flat = output_flatten(dtype)
# Convert elems to tensor array. n may be known statically.
- n = elems_flat[0].shape[0].value or array_ops.shape(elems_flat[0])[0]
+ static_shape = elems_flat[0].shape
+ if static_shape.ndims is not None and static_shape.ndims < 1:
+ if len(elems_flat) == 1:
+ raise ValueError("elems must be a 1+ dimensional Tensor, not a scalar")
+ else:
+ raise ValueError(
+ "elements in elems must be 1+ dimensional Tensors, not scalars"
+ )
+ n = static_shape[0].value or array_ops.shape(elems_flat[0])[0]
# TensorArrays are always flat
elems_ta = [
@@ -634,3 +644,249 @@ def scan(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
varscope.set_caching_device(None)
return output_pack(results_flat)
+
+
+# pylint: disable=invalid-name
+def If(cond, inputs, then_branch, else_branch, name=None):
+ r"""output = Cond(inputs) ? then_branch(inputs) : else_branch(inputs).
+
+ Args:
+ cond: A `Tensor`. A scalar. If the scalar is not a boolean, the scalar is
+ converted to a boolean according to the following rule: if the
+ scalar is a numerical value, non-zero means True and zero means
+ False; if the scalar is a string, non-empty means True and empty
+ means False.
+ inputs: A list of input tensors.
+    then_branch: A function that takes 'inputs' and returns a list of
+      tensors whose types are the same as what else_branch returns.
+    else_branch: A function that takes 'inputs' and returns a list of
+      tensors whose types are the same as what then_branch returns.
+ name: A name for the operation (optional).
+
+ Returns:
+ A list of tensors returned by either then_branch(inputs)
+ or else_branch(inputs).
+ """
+ # pylint: disable=protected-access
+ return gen_functional_ops._if(
+ cond,
+ inputs, [_.type for _ in then_branch.definition.signature.output_arg],
+ then_branch,
+ else_branch,
+ name=name)
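+
+
+# Example (illustrative sketch, not part of the original change): selecting
+# between two Defun branches with matching output types.
+#
+#   @function.Defun(dtypes.float32)
+#   def TwoTimes(x):
+#     return x * 2.
+#
+#   @function.Defun(dtypes.float32)
+#   def ThreeTimes(x):
+#     return x * 3.
+#
+#   out = If(constant_op.constant(True),
+#            [constant_op.constant(10.)], TwoTimes, ThreeTimes)[0]
+#   # out evaluates to 20.0.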
+
+
+def Gradient(inputs, f, name=None):
+ r"""Computes the gradient function for function f via backpropagation.
+
+ Args:
+ inputs: A list of tensors of size N + M.
+ f: The function we want to compute the gradient for.
+
+ The function 'f' must be a numerical function which takes N inputs and
+  produces M outputs. Its gradient function 'g' takes N + M inputs and
+  produces N outputs.
+
+ I.e. if we have
+ (y1, y2, ..., yM) = f(x1, x2, ..., xN),
+ then, g is
+ (dL/dx1, dL/dx2, ..., dL/dxN) = g(x1, x2, ..., xN,
+ dL/dy1, dL/dy2, ..., dL/dyM),
+
+ where L is a scalar-value function of (x1, x2, ..., xN) (e.g., the
+ loss function). dL/dxi is the partial derivative of L with respect
+ to xi.
+
+ name: A name for the operation (optional).
+
+ Returns:
+ A list of tensors of size N.
+ """
+ # TODO(zhifengc): Pretty-print the above spec in latex.
+  # TODO(zhifengc): Ask a math expert to state the comment above better.
+ tlist = [_.type for _ in f.definition.signature.input_arg]
+ return symbolic_gradient(input=inputs, Tout=tlist, f=f, name=name)
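+
+
+# Example (illustrative sketch, not part of the original change): for
+# f(x) = x^2, feeding dL/dy = 1.0 makes Gradient return df/dx = 2x.
+#
+#   @function.Defun(dtypes.float32)
+#   def Square(x):
+#     return x * x
+#
+#   g = Gradient([constant_op.constant(3.), 1.0], Square)[0]
+#   # g evaluates to 6.0.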
+
+
+# pylint: disable=invalid-name,protected-access
+def While(input_, cond, body, name=None, hostmem=None):
+ r"""output = input; While (Cond(output)) { output = Body(output) }.
+
+ Args:
+ input_: A list of `Tensor` objects.
+ A list of input tensors whose types are T.
+    cond: A function that takes 'input' and returns a tensor. If the
+      tensor is a non-boolean scalar, the scalar is converted to a
+      boolean according to the following rule: if the scalar is a
+      numerical value, non-zero means True and zero means False; if the
+      scalar is a string, non-empty means True and empty means False.
+      If the tensor is not a scalar, non-emptiness means True, and
+      False otherwise.
+    body: A function that takes a list of tensors and returns another
+      list of tensors. Both lists have the same types as specified
+      by T.
+    name: A name for the operation (optional).
+    hostmem: A list of integers. If i is in the list, input[i] is a
+      host memory tensor.
+
+ Returns:
+ A list of `Tensor` objects. Has the same type as `input`.
+ A list of output tensors whose types are T.
+ """
+ ret = gen_functional_ops._while(input_, cond, body, name=name)
+ if hostmem:
+ input_attr = attr_value_pb2.AttrValue()
+ input_attr.list.i.extend(hostmem)
+ ret[0].op._set_attr("_input_hostmem", input_attr) # pylint: disable=protected-access
+
+ output_attr = attr_value_pb2.AttrValue()
+ output_attr.list.i.extend(hostmem)
+ ret[0].op._set_attr("_output_hostmem", output_attr) # pylint: disable=protected-access
+ return ret
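+
+
+# Example (illustrative sketch, not part of the original change): summing the
+# integers 1..10. Note that cond and body must take identical signatures.
+#
+#   @function.Defun(dtypes.int32, dtypes.int32)
+#   def Cond(i, acc):
+#     del acc
+#     return i <= 10
+#
+#   @function.Defun(dtypes.int32, dtypes.int32)
+#   def Body(i, acc):
+#     return i + 1, acc + i
+#
+#   outputs = While([1, 0], Cond, Body)
+#   # outputs[1] evaluates to 55.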
+
+
+# b/36459430
+#
+# Ideally, we would not need to rewrite the For loop into a While loop.
+# However, today, if a While runs on GPU and the condition returns a
+# boolean, the While kernel crashes. Even if we fix the crash, the
+# bool needs to be copied between GPU and CPU. So, a For loop is much
+# preferred when running on GPU.
+#
+# On the other hand, the For op has no XLA kernel of its own. So, when we
+# run a For loop under XLA, we need to rewrite it using a While op.
+#
+# It should be possible, and probably better, to write an XLA C++ kernel
+# implementing the logic in _ForUsingWhile.
+def _ForUsingWhile(start,
+ limit,
+ delta,
+ inputs,
+ forbody,
+ name=None,
+ hostmem=None):
+ """Helper to implement a For loop using a While."""
+ # To support negative delta (e.g., range(100, 0, -3)), we iterate
+ # over the range(n) and use iter * delta + start as the real
+ # iteration index. (e.g., for i in range(34): iter = i * (-3) +
+ # 100).
+ d = math_ops.abs(delta)
+ # XLA on TPUs doesn't support integer division
+ n = math_ops.cast(
+ math_ops.cast((math_ops.abs(limit - start) + d - 1), dtypes.float32) /
+ math_ops.cast(d, dtypes.float32), dtypes.int32)
+
+ # Carried loop variables ("extra_args") are implicitly added to the input list
+ # of the WhileBody function. WhileCond does not call forbody, and so does not
+ # depend on any of forbody's extra_args. Since WhileCond and WhileBody
+ # must have identical inputs, we have to augment the cond signature to take
+ # the same types as the carried loop variables.
+ body_sig = [dtypes.int32] * 4 + list(forbody.declared_input_types)[1:]
+ cond_sig = body_sig + [t.dtype for t in forbody.captured_inputs]
+
+ cond_name = "%s_Cond" % forbody.name
+
+ @function.Defun(*cond_sig, func_name=cond_name)
+ def WhileCond(i, n, *args):
+ del args
+ return i < n
+
+ body_name = "%s_Body" % forbody.name
+
+ @function.Defun(*body_sig, func_name=body_name)
+ def WhileBody(i, n, start, delta, *args):
+ """A While wrapper for forbody that handles loop-carried captured inputs."""
+ for_result = forbody(start + i * delta, *args)
+ # Nullary functions return an Operation. Normal functions can't do this
+ # because their return values are converted to Tensors.
+ if isinstance(for_result, ops.Operation):
+ for_result = ()
+ # Unary functions return a single Tensor value.
+ elif isinstance(for_result, ops.Tensor):
+ for_result = (for_result,)
+ extra_args = tuple(function.get_extra_args())
+ return (i + 1, n, start, delta) + tuple(for_result) + extra_args
+
+ if hostmem is not None:
+ hostmem = [(4 + _) for _ in hostmem]
+
+ results = While(
+ input_=[0, n, start, delta] + inputs + WhileBody.captured_inputs,
+ cond=WhileCond,
+ body=WhileBody,
+ name=name,
+ hostmem=hostmem)
+ # Slice off the loop-carried captured inputs.
+ return list(results[4:len(results) - len(WhileBody.captured_inputs)])
+
+
+def For(start,
+ limit,
+ delta,
+ inputs,
+ body,
+ name=None,
+ hostmem=None,
+ rewrite_with_while=None):
+ r"""out = input; for i in range(start, limit, delta) out = body(i, out).
+
+ Args:
+ start: A `Tensor` of type `int32`.
+ limit: A `Tensor` of type `int32`.
+ delta: A `Tensor` of type `int32`.
+ inputs: A list of `Tensor` objects.
+ A list of input tensors whose types are T.
+    body: A function that takes a list of tensors and returns another
+      list of tensors. Both lists have the same types as (int32, T...).
+    name: A name for the operation (optional).
+    hostmem: A list of integers. If i is in the list, inputs[i] is a
+      host memory tensor. In other words, the (i+1)-th argument of the
+      body function expects a host memory tensor.
+    rewrite_with_while: If True, use a While op to implement the For loop.
+
+ Returns:
+ A list of `Tensor` objects. Has the same type as `input`.
+ A list of output tensors whose types are T.
+ """
+ if rewrite_with_while:
+ return _ForUsingWhile(start, limit, delta, inputs, body, name, hostmem)
+ if body.captured_inputs:
+ wrapper_name = "%s_BodyWrapper" % body.name
+
+ @function.Defun(*body.declared_input_types, func_name=wrapper_name)
+ def BodyWrapper(*args):
+ """A wrapper for body that handles loop-carried captured inputs."""
+ body_result = body(*args)
+ extra_args = tuple(function.get_extra_args())
+ # Nullary functions return an Operation. Normal functions can't do this
+ # because their return values are converted to Tensors.
+ if isinstance(body_result, ops.Operation):
+ return extra_args
+ # Unary functions return a single Tensor value.
+ elif not isinstance(body_result, tuple):
+ return (body_result,) + extra_args
+ # N-ary functions return a tuple of Tensors.
+ else:
+ return body_result + extra_args
+
+ inputs += BodyWrapper.captured_inputs
+ ret = gen_functional_ops._for(
+ start, limit, delta, inputs, BodyWrapper, name=name)
+ # Slice off the loop-carried captured inputs.
+ ret = ret[:-len(BodyWrapper.captured_inputs)]
+ else:
+ ret = gen_functional_ops._for(start, limit, delta, inputs, body, name=name)
+ if hostmem:
+ num_for_params = 3 # start/limit/delta
+
+ input_attr = attr_value_pb2.AttrValue()
+ input_attr.list.i.extend([num_for_params + i for i in hostmem])
+ ret[0].op._set_attr("_input_hostmem", input_attr) # pylint: disable=protected-access
+
+ output_attr = attr_value_pb2.AttrValue()
+ output_attr.list.i.extend(hostmem)
+ ret[0].op._set_attr("_output_hostmem", output_attr) # pylint: disable=protected-access
+ return ret
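+
+
+# Example (illustrative sketch, not part of the original change): the body
+# receives the loop index as its first argument, followed by the loop-carried
+# values.
+#
+#   @function.Defun(dtypes.int32, dtypes.float32)
+#   def Accumulate(i, acc):
+#     return acc + math_ops.to_float(i)
+#
+#   out = For(0, 10, 1, [constant_op.constant(0.)], Accumulate)[0]
+#   # out evaluates to 0 + 1 + ... + 9 = 45.0.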
+
+
+# pylint: enable=invalid-name,protected-access
diff --git a/tensorflow/python/ops/linalg/linear_operator_util.py b/tensorflow/python/ops/linalg/linear_operator_util.py
index 427bd1e890..9dd40765c2 100644
--- a/tensorflow/python/ops/linalg/linear_operator_util.py
+++ b/tensorflow/python/ops/linalg/linear_operator_util.py
@@ -23,6 +23,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
@@ -102,6 +103,22 @@ def assert_is_batch_matrix(tensor):
"%s" % tensor)
+def shape_tensor(shape, name=None):
+ """Convert Tensor using default type, unless empty list or tuple."""
+ # Works just like random_ops._ShapeTensor.
+ if isinstance(shape, (tuple, list)) and not shape:
+ dtype = dtypes.int32
+ else:
+ dtype = None
+ return ops.convert_to_tensor(shape, dtype=dtype, name=name)
+
+
+################################################################################
+# Broadcasting versions of common linear algebra functions.
+# TODO(b/77519145) Do this more efficiently in some special cases.
+################################################################################
+
+
def broadcast_matrix_batch_dims(batch_matrices, name=None):
"""Broadcast leading dimensions of zero or more [batch] matrices.
@@ -170,7 +187,8 @@ def broadcast_matrix_batch_dims(batch_matrices, name=None):
bcast_batch_shape = batch_matrices[0].get_shape()[:-2]
for mat in batch_matrices[1:]:
bcast_batch_shape = array_ops.broadcast_static_shape(
- bcast_batch_shape, mat.get_shape()[:-2])
+ bcast_batch_shape,
+ mat.get_shape()[:-2])
if bcast_batch_shape.is_fully_defined():
# The [1, 1] at the end will broadcast with anything.
bcast_shape = bcast_batch_shape.concatenate([1, 1])
@@ -183,7 +201,8 @@ def broadcast_matrix_batch_dims(batch_matrices, name=None):
bcast_batch_shape = array_ops.shape(batch_matrices[0])[:-2]
for mat in batch_matrices[1:]:
bcast_batch_shape = array_ops.broadcast_dynamic_shape(
- bcast_batch_shape, array_ops.shape(mat)[:-2])
+ bcast_batch_shape,
+ array_ops.shape(mat)[:-2])
bcast_shape = array_ops.concat([bcast_batch_shape, [1, 1]], axis=0)
for i, mat in enumerate(batch_matrices):
batch_matrices[i] = _broadcast_to_shape(mat, bcast_shape)
@@ -195,6 +214,13 @@ def _broadcast_to_shape(x, shape):
return x + array_ops.zeros(shape=shape, dtype=x.dtype)
+def cholesky_solve_with_broadcast(chol, rhs, name=None):
+ """Solve systems of linear equations."""
+ with ops.name_scope(name, "CholeskySolveWithBroadcast", [chol, rhs]):
+ chol, rhs = broadcast_matrix_batch_dims([chol, rhs])
+ return linalg_ops.cholesky_solve(chol, rhs)
+
+
def matmul_with_broadcast(a,
b,
transpose_a=False,
@@ -206,6 +232,11 @@ def matmul_with_broadcast(a,
name=None):
"""Multiplies matrix `a` by matrix `b`, producing `a @ b`.
+ Works identically to `tf.matmul`, but broadcasts batch dims
+ of `a` and `b` (by replicating) if they are determined statically to be
+ different, or if static shapes are not fully defined. Thus, this may result
+ in an inefficient replication of data.
+
The inputs must be matrices (or tensors of rank > 2, representing batches of
matrices).
@@ -276,7 +307,7 @@ def matmul_with_broadcast(a,
ValueError: If transpose_a and adjoint_a, or transpose_b and adjoint_b
are both set to True.
"""
- with ops.name_scope(name, "MatMulWithBroadcast", [a, b]) as name:
+ with ops.name_scope(name, "MatMulWithBroadcast", [a, b]):
a, b = broadcast_matrix_batch_dims([a, b])
return math_ops.matmul(
a,
@@ -289,11 +320,43 @@ def matmul_with_broadcast(a,
b_is_sparse=b_is_sparse)
-def shape_tensor(shape, name=None):
- """Convert Tensor using default type, unless empty list or tuple."""
- # Works just like random_ops._ShapeTensor.
- if isinstance(shape, (tuple, list)) and not shape:
- dtype = dtypes.int32
- else:
- dtype = None
- return ops.convert_to_tensor(shape, dtype=dtype, name=name)
+def matrix_solve_with_broadcast(matrix, rhs, adjoint=False, name=None):
+ """Solve systems of linear equations."""
+ with ops.name_scope(name, "MatrixSolveWithBroadcast", [matrix, rhs]):
+ matrix, rhs = broadcast_matrix_batch_dims([matrix, rhs])
+ return linalg_ops.matrix_solve(matrix, rhs, adjoint=adjoint)
+
+
+def matrix_triangular_solve_with_broadcast(matrix,
+ rhs,
+ lower=True,
+ adjoint=False,
+ name=None):
+ """Solves triangular systems of linear equations with by backsubstitution.
+
+ Works identically to `tf.matrix_triangular_solve`, but broadcasts batch dims
+ of `matrix` and `rhs` (by replicating) if they are determined statically to be
+ different, or if static shapes are not fully defined. Thus, this may result
+ in an inefficient replication of data.
+
+ Args:
+ matrix: A Tensor. Must be one of the following types:
+ `float64`, `float32`, `complex64`, `complex128`. Shape is `[..., M, M]`.
+ rhs: A `Tensor`. Must have the same `dtype` as `matrix`.
+ Shape is `[..., M, K]`.
+ lower: An optional `bool`. Defaults to `True`. Indicates whether the
+ innermost matrices in `matrix` are lower or upper triangular.
+ adjoint: An optional `bool`. Defaults to `False`. Indicates whether to solve
+ with matrix or its (block-wise) adjoint.
+ name: A name for the operation (optional).
+
+ Returns:
+ `Tensor` with same `dtype` as `matrix` and shape `[..., M, K]`.
+ """
+ with ops.name_scope(name, "MatrixTriangularSolve", [matrix, rhs]):
+ matrix, rhs = broadcast_matrix_batch_dims([matrix, rhs])
+ return linalg_ops.matrix_triangular_solve(
+ matrix,
+ rhs,
+ lower=lower,
+ adjoint=adjoint)
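+
+
+# Example (illustrative sketch, not part of the original change): an
+# unbatched rhs is broadcast against a batch of two triangular systems.
+#
+#   matrix = linalg_ops.eye(3, batch_shape=[2])  # shape [2, 3, 3]
+#   rhs = array_ops.ones([3, 7])                 # shape [3, 7]
+#   x = matrix_triangular_solve_with_broadcast(matrix, rhs)
+#   # x has shape [2, 3, 7]; with identity systems it equals the tiled rhs.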
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index 276897ab99..b460ce5b95 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -174,6 +174,7 @@ from tensorflow.python.ops.gen_math_ops import *
# pylint: enable=wildcard-import
from tensorflow.python.util import compat
from tensorflow.python.util import deprecation
+from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export
# Aliases for some automatically-generated names.
@@ -184,7 +185,6 @@ arg_min = deprecation.deprecated(None, "Use `argmin` instead")(arg_min) # pylin
tf_export("arg_max")(arg_max)
tf_export("arg_min")(arg_min)
-
# This is set by resource_variable_ops.py. It is included in this way since
# there is a circular dependency between math_ops and resource_variable_ops
_resource_variable_type = None
@@ -1343,8 +1343,7 @@ def _ReductionDims(x, axis, reduction_indices):
else:
# Fast path: avoid creating Rank and Range ops if ndims is known.
if isinstance(x, ops.Tensor) and x._rank() is not None: # pylint: disable=protected-access
- return constant_op.constant(
- np.arange(x._rank()), dtype=dtypes.int32) # pylint: disable=protected-access
+ return constant_op.constant(np.arange(x._rank()), dtype=dtypes.int32) # pylint: disable=protected-access
if (isinstance(x, sparse_tensor.SparseTensor) and
x.dense_shape.get_shape().is_fully_defined()):
rank = x.dense_shape.get_shape()[0].value # sparse.dense_shape is 1-D.
@@ -1522,7 +1521,7 @@ def reduce_mean(input_tensor,
input_tensor: The tensor to reduce. Should have numeric type.
axis: The dimensions to reduce. If `None` (the default),
reduces all dimensions. Must be in the range
- `[-rank(input_tensor), rank(input_tensor)]`.
+ `[-rank(input_tensor), rank(input_tensor))`.
keepdims: If true, retains reduced dimensions with length 1.
name: A name for the operation (optional).
reduction_indices: The old (deprecated) name for axis.
@@ -2273,10 +2272,11 @@ def accumulate_n(inputs, shape=None, tensor_dtype=None, name=None):
ValueError: If `inputs` don't all have same shape and dtype or the shape
cannot be inferred.
"""
+
def _input_error():
- return ValueError(
- "inputs must be a list of at least one Tensor with the "
- "same dtype and shape")
+ return ValueError("inputs must be a list of at least one Tensor with the "
+ "same dtype and shape")
+
if not inputs or not isinstance(inputs, (list, tuple)):
raise _input_error()
inputs = ops.convert_n_to_tensor_or_indexed_slices(inputs)
@@ -2294,8 +2294,8 @@ def accumulate_n(inputs, shape=None, tensor_dtype=None, name=None):
# tensor_dtype is for safety only; operator's output type computed in C++
if tensor_dtype is not None and tensor_dtype != inputs[0].dtype:
- raise TypeError("tensor_dtype is {}, but input is of type {}"
- .format(tensor_dtype, inputs[0].dtype))
+ raise TypeError("tensor_dtype is {}, but input is of type {}".format(
+ tensor_dtype, inputs[0].dtype))
if len(inputs) == 1 and name is None:
return inputs[0]
@@ -2761,14 +2761,14 @@ def sparse_segment_sum(data, indices, segment_ids, name=None,
name=name)
else:
return gen_math_ops.sparse_segment_sum(
- data=data,
- indices=indices,
- segment_ids=segment_ids,
- name=name)
+ data=data, indices=indices, segment_ids=segment_ids, name=name)
@tf_export("sparse_segment_mean")
-def sparse_segment_mean(data, indices, segment_ids, name=None,
+def sparse_segment_mean(data,
+ indices,
+ segment_ids,
+ name=None,
num_segments=None):
r"""Computes the mean along sparse segments of a tensor.
@@ -2805,14 +2805,14 @@ def sparse_segment_mean(data, indices, segment_ids, name=None,
name=name)
else:
return gen_math_ops.sparse_segment_mean(
- data=data,
- indices=indices,
- segment_ids=segment_ids,
- name=name)
+ data=data, indices=indices, segment_ids=segment_ids, name=name)
@tf_export("sparse_segment_sqrt_n")
-def sparse_segment_sqrt_n(data, indices, segment_ids, name=None,
+def sparse_segment_sqrt_n(data,
+ indices,
+ segment_ids,
+ name=None,
num_segments=None):
r"""Computes the sum along sparse segments of a tensor divided by the sqrt(N).
@@ -2842,10 +2842,7 @@ def sparse_segment_sqrt_n(data, indices, segment_ids, name=None,
name=name)
else:
return gen_math_ops.sparse_segment_sqrt_n(
- data=data,
- indices=indices,
- segment_ids=segment_ids,
- name=name)
+ data=data, indices=indices, segment_ids=segment_ids, name=name)
@tf_export("tensordot", "linalg.tensordot")
@@ -3016,6 +3013,47 @@ def tensordot(a, b, axes, name=None):
return product
+@tf_export("math.polyval")
+def polyval(coeffs, x, name=None):
+ r"""Computes the elementwise value of a polynomial.
+
+  If `x` is a tensor and `coeffs` is a list of n + 1 tensors, this function
+  returns the value of the n-th order polynomial
+
+     p(x) = coeffs[n] + coeffs[n-1] * x + ... + coeffs[0] * x**n
+
+  evaluated using Horner's method, i.e.
+
+     p(x) = coeffs[n] + x * (coeffs[n-1] + ... + x * (coeffs[1] +
+            x * coeffs[0]))
+
+ Args:
+ coeffs: A list of `Tensor` representing the coefficients of the polynomial.
+ x: A `Tensor` representing the variable of the polynomial.
+ name: A name for the operation (optional).
+
+ Returns:
+    A `Tensor` of the same shape as the expression p(x), with the usual
+    broadcasting rules for element-wise addition and multiplication applied.
+
+ @compatibility(numpy)
+ Equivalent to numpy.polyval.
+ @end_compatibility
+ """
+
+ with ops.name_scope(name, "polyval", nest.flatten(coeffs) + [x]) as name:
+ x = ops.convert_to_tensor(x, name="x")
+ if len(coeffs) < 1:
+ return array_ops.zeros_like(x, name=name)
+ coeffs = [
+ ops.convert_to_tensor(coeff, name=("coeff_%d" % index))
+ for index, coeff in enumerate(coeffs)
+ ]
+ p = coeffs[0]
+ for c in coeffs[1:]:
+ p = c + p * x
+ return p
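+
+
+# Example (illustrative sketch, not part of the original change):
+# polyval([3., 2., 1.], x) computes 3*x**2 + 2*x + 1 via Horner's method;
+# for x = 2. it evaluates to 3*4 + 2*2 + 1 = 17.0.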
+
# FFT ops were moved to tf.spectral. tf.fft symbols were part of the TensorFlow
# 1.0 API so we leave these here for backwards compatibility.
fft = gen_spectral_ops.fft
diff --git a/tensorflow/python/ops/math_ops_test.py b/tensorflow/python/ops/math_ops_test.py
index 9f85188b35..05bcee8801 100644
--- a/tensorflow/python/ops/math_ops_test.py
+++ b/tensorflow/python/ops/math_ops_test.py
@@ -155,9 +155,7 @@ class RoundTest(test_util.TensorFlowTestCase):
@test_util.run_in_graph_and_eager_modes()
def testRounding(self):
- x = [0.49, 0.7, -0.3, -0.8]
- # TODO(nolivia): Remove this when RoundOp is forwards compatible
- # x = np.arange(-5.0, 5.0, .25)
+ x = np.arange(-5.0, 5.0, .25)
for dtype in [np.float32, np.double, np.int32]:
x_np = np.array(x, dtype=dtype)
with test_util.device(use_gpu=True):
diff --git a/tensorflow/python/ops/metrics_impl.py b/tensorflow/python/ops/metrics_impl.py
index 9ec4954579..47eea6ef6b 100644
--- a/tensorflow/python/ops/metrics_impl.py
+++ b/tensorflow/python/ops/metrics_impl.py
@@ -33,6 +33,7 @@ from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import weights_broadcast_ops
+from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.deprecation import deprecated
from tensorflow.python.util.tf_export import tf_export
@@ -626,10 +627,16 @@ def auc(labels,
curve: Specifies the name of the curve to be computed, 'ROC' [default] or
'PR' for the Precision-Recall-curve.
name: An optional variable_scope name.
- summation_method: Specifies the Riemann summation method used, 'trapezoidal'
- [default] that applies the trapezoidal rule, 'minoring' that applies
- left summation for increasing intervals and right summation for decreasing
- intervals or 'majoring' that applies the opposite.
+ summation_method: Specifies the Riemann summation method used
+ (https://en.wikipedia.org/wiki/Riemann_sum): 'trapezoidal' [default] that
+ applies the trapezoidal rule; 'careful_interpolation', a variant of it
+ differing only by a more correct interpolation scheme for PR-AUC -
+ interpolating (true/false) positives but not the ratio that is precision;
+ 'minoring' that applies left summation for increasing intervals and right
+ summation for decreasing intervals; 'majoring' that does the opposite.
+ Note that 'careful_interpolation' is strictly preferred to 'trapezoidal'
+ (to be deprecated soon) as it applies the same method for ROC, and a
+ better one (see Davis & Goadrich 2006 for details) for the PR curve.
Returns:
auc: A scalar `Tensor` representing the current area-under-curve.
@@ -664,8 +671,62 @@ def auc(labels,
# Add epsilons to avoid dividing by 0.
epsilon = 1.0e-6
+ def interpolate_pr_auc(tp, fp, fn):
+ """Interpolation formula inspired by section 4 of Davis & Goadrich 2006.
+
+    Note that here we derive and use a closed formula, not present in the
+    paper, as follows:
+ Modeling all of TP (true positive weight),
+ FP (false positive weight) and their sum P = TP + FP (positive weight)
+ as varying linearly within each interval [A, B] between successive
+ thresholds, we get
+ Precision = (TP_A + slope * (P - P_A)) / P
+ with slope = dTP / dP = (TP_B - TP_A) / (P_B - P_A).
+ The area within the interval is thus (slope / total_pos_weight) times
+ int_A^B{Precision.dP} = int_A^B{(TP_A + slope * (P - P_A)) * dP / P}
+ int_A^B{Precision.dP} = int_A^B{slope * dP + intercept * dP / P}
+ where intercept = TP_A - slope * P_A = TP_B - slope * P_B, resulting in
+ int_A^B{Precision.dP} = TP_B - TP_A + intercept * log(P_B / P_A)
+ Bringing back the factor (slope / total_pos_weight) we'd put aside, we get
+ slope * [dTP + intercept * log(P_B / P_A)] / total_pos_weight
+ where dTP == TP_B - TP_A.
+ Note that when P_A == 0 the above calculation simplifies into
+ int_A^B{Precision.dTP} = int_A^B{slope * dTP} = slope * (TP_B - TP_A)
+ which is really equivalent to imputing constant precision throughout the
+ first bucket having >0 true positives.
+
+ Args:
+ tp: true positive counts
+ fp: false positive counts
+ fn: false negative counts
+ Returns:
+ pr_auc: an approximation of the area under the P-R curve.
+ """
+ dtp = tp[:num_thresholds - 1] - tp[1:]
+ p = tp + fp
+ prec_slope = _safe_div(dtp, p[:num_thresholds - 1] - p[1:], 'prec_slope')
+ intercept = tp[1:] - math_ops.multiply(prec_slope, p[1:])
+ safe_p_ratio = array_ops.where(
+ math_ops.logical_and(p[:num_thresholds - 1] > 0, p[1:] > 0),
+ _safe_div(p[:num_thresholds - 1], p[1:], 'recall_relative_ratio'),
+ array_ops.ones_like(p[1:]))
+ return math_ops.reduce_sum(
+ _safe_div(
+ prec_slope * (dtp + intercept * math_ops.log(safe_p_ratio)),
+ tp[1:] + fn[1:],
+ name='pr_auc_increment'),
+ name='interpolate_pr_auc')
+
def compute_auc(tp, fn, tn, fp, name):
"""Computes the roc-auc or pr-auc based on confusion counts."""
+ if curve == 'PR':
+ if summation_method == 'trapezoidal':
+ logging.warning(
+ 'Trapezoidal rule is known to produce incorrect PR-AUCs; '
+ 'please switch to "careful_interpolation" instead.')
+ elif summation_method == 'careful_interpolation':
+ # This one is a bit tricky and is handled separately.
+ return interpolate_pr_auc(tp, fp, fn)
rec = math_ops.div(tp + epsilon, tp + fn + epsilon)
if curve == 'ROC':
fp_rate = math_ops.div(fp, fp + tn + epsilon)
@@ -675,7 +736,9 @@ def auc(labels,
prec = math_ops.div(tp + epsilon, tp + fp + epsilon)
x = rec
y = prec
- if summation_method == 'trapezoidal':
+ if summation_method in ('trapezoidal', 'careful_interpolation'):
+ # Note that the case ('PR', 'careful_interpolation') has been handled
+ # above.
return math_ops.reduce_sum(
math_ops.multiply(x[:num_thresholds - 1] - x[1:],
(y[:num_thresholds - 1] + y[1:]) / 2.),
@@ -923,8 +986,8 @@ def mean_per_class_accuracy(labels,
weights = array_ops.reshape(weights, [-1])
weights = math_ops.to_float(weights)
- is_correct = is_correct * weights
- ones = ones * weights
+ is_correct *= weights
+ ones *= weights
update_total_op = state_ops.scatter_add(total, labels, ones)
update_count_op = state_ops.scatter_add(count, labels, is_correct)
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index 0c55386241..07ca32953f 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -1808,7 +1808,7 @@ def softmax_cross_entropy_with_logits_v2(
or `float64`).
Backpropagation will happen into both `logits` and `labels`. To disallow
- backpropagation into `labels`, pass label tensors through a `stop_gradients`
+ backpropagation into `labels`, pass label tensors through @{tf.stop_gradient}
before feeding it to this function.
**Note that to avoid confusion, it is required to pass only named arguments to
@@ -1895,7 +1895,7 @@ _XENT_DEPRECATION = """
Future major versions of TensorFlow will allow gradients to flow
into the labels input on backprop by default.
-See tf.nn.softmax_cross_entropy_with_logits_v2.
+See @{tf.nn.softmax_cross_entropy_with_logits_v2}.
"""
diff --git a/tensorflow/python/ops/nn_test.py b/tensorflow/python/ops/nn_test.py
index da86d5f6ca..46a5f4fae6 100644
--- a/tensorflow/python/ops/nn_test.py
+++ b/tensorflow/python/ops/nn_test.py
@@ -1081,6 +1081,42 @@ class DataFormatDimMapTest(test_lib.TestCase):
self._test([1, -3, -2], [2, 2, 3])
self._test([[1, -3], [1, -1]], [[2, 2], [2, 1]])
+ def testNHWCtoNCHW(self):
+ x_val = [1, -3, -2]
+ y_val_expected = [2, 2, 3]
+ x = constant_op.constant(x_val)
+ y = nn_ops.data_format_dim_map(x, src_format="NHWC", dst_format="NCHW")
+ with self.test_session(use_gpu=test_lib.is_gpu_available()) as sess:
+ y_val = sess.run(y)
+ self.assertAllEqual(y_val, y_val_expected)
+
+ def testNHWCtoHWNC(self):
+ x_val = [-4, -3, -2, -1, 0, 1, 2, 3]
+ y_val_expected = [2, 0, 1, 3, 2, 0, 1, 3]
+ x = constant_op.constant(x_val)
+ y = nn_ops.data_format_dim_map(x, src_format="NHWC", dst_format="HWNC")
+ with self.test_session(use_gpu=test_lib.is_gpu_available()) as sess:
+ y_val = sess.run(y)
+ self.assertAllEqual(y_val, y_val_expected)
+
+ def testNHWCtoWHCN(self):
+ x_val = [-4, -3, -2, -1, 0, 1, 2, 3]
+ y_val_expected = [3, 1, 0, 2, 3, 1, 0, 2]
+ x = constant_op.constant(x_val)
+ y = nn_ops.data_format_dim_map(x, src_format="NHWC", dst_format="WHCN")
+ with self.test_session(use_gpu=test_lib.is_gpu_available()) as sess:
+ y_val = sess.run(y)
+ self.assertAllEqual(y_val, y_val_expected)
+
+ def testArbitraryASCII(self):
+ x_val = [-4, -3, -2, -1, 0, 1, 2, 3]
+ y_val_expected = [3, 2, 1, 0, 3, 2, 1, 0]
+ x = constant_op.constant(x_val)
+ y = nn_ops.data_format_dim_map(x, src_format="qwer", dst_format="rewq")
+ with self.test_session(use_gpu=test_lib.is_gpu_available()) as sess:
+ y_val = sess.run(y)
+ self.assertAllEqual(y_val, y_val_expected)
+
class DataFormatVectorPermuteTest(test_lib.TestCase):
diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py
index 2f39ea2e7d..07e25e540c 100644
--- a/tensorflow/python/ops/resource_variable_ops.py
+++ b/tensorflow/python/ops/resource_variable_ops.py
@@ -171,7 +171,9 @@ class ResourceVariable(variables.Variable):
to see all modifications to the value of the variable which happen in any
operation on which the read_value depends on (either directly, indirectly, or
via a control dependency) and guaranteed to not see any modification to the
- value of the variable on which the read_value operation does not depend on.
+ value of the variable from operations that depend on the read_value operation.
+ Updates from operations that have no dependency relationship to the read_value
+ operation might or might not be visible to read_value.
For example, if there is more than one assignment to a ResourceVariable in
a single session.run call there is a well-defined value for each operation
diff --git a/tensorflow/python/ops/state_ops.py b/tensorflow/python/ops/state_ops.py
index 01fc3182bc..f6a11ca625 100644
--- a/tensorflow/python/ops/state_ops.py
+++ b/tensorflow/python/ops/state_ops.py
@@ -423,3 +423,55 @@ def scatter_nd_update(ref, indices, updates, use_locking=True, name=None):
ref.handle, indices, ops.convert_to_tensor(updates, dtype=ref.dtype),
use_locking, name)]):
return ref.read_value()
+
+
+@tf_export("scatter_add")
+def scatter_add(ref, indices, updates, use_locking=False, name=None):
+ # pylint: disable=line-too-long
+ r"""Adds sparse updates to the variable referenced by `resource`.
+
+ This operation computes
+
+ ```python
+ # Scalar indices
+ ref[indices, ...] += updates[...]
+
+ # Vector indices (for each i)
+ ref[indices[i], ...] += updates[i, ...]
+
+ # High rank indices (for each i, ..., j)
+ ref[indices[i, ..., j], ...] += updates[i, ..., j, ...]
+ ```
+
+ This operation outputs `ref` after the update is done.
+ This makes it easier to chain operations that need to use the updated value.
+ Duplicate entries are handled correctly: if multiple `indices` reference
+ the same location, their contributions add.
+
+ Requires `updates.shape = indices.shape + ref.shape[1:]`.
+
+ <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
+ <img style="width:100%" src='https://www.tensorflow.org/images/ScatterAdd.png' alt>
+ </div>
+
+ Args:
+ ref: A `Variable`.
+ indices: A `Tensor`. Must be one of the following types: `int32`, `int64`.
+ A tensor of indices into the first dimension of `ref`.
+ updates: A `Tensor`. Must have the same type as `ref`.
+ A tensor of updated values to store in `ref`.
+    use_locking: An optional `bool`. Defaults to `False`.
+      If True, the addition will be protected by a lock;
+ otherwise the behavior is undefined, but may exhibit less contention.
+ name: A name for the operation (optional).
+
+ Returns:
+ Same as `ref`. Returned as a convenience for operations that want
+ to use the updated values after the update is done.
+ """
+ if ref.dtype._is_ref_dtype:
+ return gen_state_ops.scatter_add(ref, indices, updates,
+ use_locking=use_locking, name=name)
+ return ref._lazy_read(gen_resource_variable_ops.resource_scatter_add( # pylint: disable=protected-access
+ ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype),
+ name=name))
diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py
index c35735ca65..e33085ba62 100644
--- a/tensorflow/python/ops/variable_scope.py
+++ b/tensorflow/python/ops/variable_scope.py
@@ -1164,7 +1164,7 @@ class _VariableScopeStore(threading.local):
self.variable_scopes_count[scope_name] = 1
def close_variable_subscopes(self, scope_name):
- for k in self.variable_scopes_count:
+ for k in list(self.variable_scopes_count.keys()):
if not scope_name or k.startswith(scope_name + "/"):
self.variable_scopes_count[k] = 0
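For context on the one-line fix above: mutating a dict's key set while iterating over it raises a `RuntimeError` in Python, so iterating over a `list(...)` snapshot is the standard defensive pattern. A minimal sketch, independent of TensorFlow:

```python
d = {"a/x": 1, "a/y": 2}

try:
  for k in d:
    d.pop(k)  # changing the key set during iteration...
except RuntimeError as e:
  print(e)    # ...raises: dictionary changed size during iteration

# Iterating over a snapshot of the keys, as in the fix, is safe:
for k in list(d.keys()):
  d[k] = 0
```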
diff --git a/tensorflow/python/platform/base.i b/tensorflow/python/platform/base.i
index dbefca2be9..478dd46f7e 100644
--- a/tensorflow/python/platform/base.i
+++ b/tensorflow/python/platform/base.i
@@ -229,3 +229,25 @@ _COPY_TYPEMAPS(unsigned int, mode_t);
%define final %enddef
%define override %enddef
#endif
+
+// Typemaps to automatically raise a Python exception from bad output TF_Status.
+// TODO(b/77295559): expand this to all TF_Status* output params and deprecate
+// raise_exception_on_not_ok_status (currently it only affects the C API).
+%typemap(in, numinputs=0) TF_Status* status (TF_Status* status) {
+ $1 = TF_NewStatus();
+}
+
+%typemap(freearg) (TF_Status* status) {
+ TF_DeleteStatus($1);
+}
+
+%typemap(argout) TF_Status* status {
+ TF_Code code = TF_GetCode($1);
+ if (code != TF_OK) {
+ PyObject* exc = tensorflow::PyExceptionRegistry::Lookup(code);
+ // Arguments to OpError.
+ PyObject* exc_args = Py_BuildValue("sss", nullptr, nullptr, TF_Message($1));
+ SWIG_SetErrorObj(exc, exc_args);
+ SWIG_fail;
+ }
+}
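The effect of these typemaps is that wrapped functions with a trailing `TF_Status*` output parameter raise the registered Python exception automatically instead of requiring the caller to thread a status object through and check it. A pure-Python analogue of that pattern, as a sketch under the assumption that the wrapped call reports `(code, message)` the way `TF_GetCode`/`TF_Message` do (none of these Python names are TensorFlow APIs):

```python
class OpError(Exception):
  pass

# Maps a status code to an exception class, like PyExceptionRegistry::Lookup.
_EXCEPTION_REGISTRY = {3: OpError}  # 3 ~ TF_INVALID_ARGUMENT

def raises_from_status(fn):
  """Wraps fn, which returns (code, message), to raise on non-OK codes."""
  def wrapper(*args):
    code, message = fn(*args)
    if code != 0:  # TF_OK == 0
      raise _EXCEPTION_REGISTRY.get(code, OpError)(message)
  return wrapper

@raises_from_status
def do_something(x):  # stand-in for a wrapped C call
  return (3, "invalid argument") if x < 0 else (0, "")

do_something(1)      # succeeds silently
# do_something(-1)   # raises OpError: invalid argument
```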
diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i
index 39fabb9c1b..7acb8eeb1a 100644
--- a/tensorflow/python/pywrap_tfe.i
+++ b/tensorflow/python/pywrap_tfe.i
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+%include "tensorflow/python/platform/base.i"
+
%ignore "";
%rename("%s") TFE_NewContext;
diff --git a/tensorflow/python/tensorflow.i b/tensorflow/python/tensorflow.i
index 82b908ac0e..26e8acd897 100644
--- a/tensorflow/python/tensorflow.i
+++ b/tensorflow/python/tensorflow.i
@@ -25,6 +25,7 @@ limitations under the License.
%include "tensorflow/python/util/tfprof.i"
%include "tensorflow/python/lib/core/py_func.i"
+%include "tensorflow/python/lib/core/py_exception_registry.i"
%include "tensorflow/python/lib/io/py_record_reader.i"
%include "tensorflow/python/lib/io/py_record_writer.i"
@@ -54,4 +55,3 @@ limitations under the License.
%include "tensorflow/python/grappler/tf_optimizer.i"
%include "tensorflow/python/grappler/cost_analyzer.i"
%include "tensorflow/python/grappler/model_analyzer.i"
-
diff --git a/tensorflow/python/tools/optimize_for_inference.py b/tensorflow/python/tools/optimize_for_inference.py
index 902748d55e..dac6a06a89 100644
--- a/tensorflow/python/tools/optimize_for_inference.py
+++ b/tensorflow/python/tools/optimize_for_inference.py
@@ -87,7 +87,9 @@ def main(unused_args):
output_graph_def = optimize_for_inference_lib.optimize_for_inference(
input_graph_def,
FLAGS.input_names.split(","),
- FLAGS.output_names.split(","), FLAGS.placeholder_type_enum)
+ FLAGS.output_names.split(","),
+ FLAGS.placeholder_type_enum,
+ FLAGS.toco_compatible)
if FLAGS.frozen_graph:
f = gfile.FastGFile(FLAGS.output, "w")
@@ -138,6 +140,14 @@ def parse_args():
type=int,
default=dtypes.float32.as_datatype_enum,
help="The AttrValue enum to use for placeholders.")
+ parser.add_argument(
+ "--toco_compatible",
+ type=bool,
+ default=False,
+ help="""\
+ If true, only use ops compatible with TensorFlow
+ Lite Optimizing Converter.\
+ """)
return parser.parse_known_args()
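A sketch of driving the tool with the new flag; the file paths are placeholders, and the flags other than `--toco_compatible` already exist in the tool:

```python
import subprocess

subprocess.check_call([
    "python", "-m", "tensorflow.python.tools.optimize_for_inference",
    "--input=frozen_graph.pb",
    "--output=optimized_graph.pb",
    "--input_names=input",
    "--output_names=softmax",
    "--frozen_graph=True",
    "--toco_compatible=True",
])
```

One caveat: argparse's `type=bool` treats any non-empty string as true, so `--toco_compatible=False` would still enable the option; omit the flag entirely to keep the default.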
diff --git a/tensorflow/python/tools/optimize_for_inference_lib.py b/tensorflow/python/tools/optimize_for_inference_lib.py
index 9c19271222..bb90d1cd6e 100644
--- a/tensorflow/python/tools/optimize_for_inference_lib.py
+++ b/tensorflow/python/tools/optimize_for_inference_lib.py
@@ -87,7 +87,7 @@ EPSILON_ATTR = {
def optimize_for_inference(input_graph_def, input_node_names, output_node_names,
- placeholder_type_enum):
+ placeholder_type_enum, toco_compatible=False):
"""Applies a series of inference optimizations on the input graph.
Args:
@@ -98,6 +98,8 @@ def optimize_for_inference(input_graph_def, input_node_names, output_node_names,
results.
placeholder_type_enum: The AttrValue enum for the placeholder data type, or
a list that specifies one value per input node name.
+ toco_compatible: Boolean, if True, only runs optimizations that result in
+ TOCO compatible graph operations (default=False).
Returns:
An optimized version of the input graph.
@@ -110,8 +112,9 @@ def optimize_for_inference(input_graph_def, input_node_names, output_node_names,
optimized_graph_def = graph_util.remove_training_nodes(
optimized_graph_def, output_node_names)
optimized_graph_def = fold_batch_norms(optimized_graph_def)
- optimized_graph_def = fuse_resize_and_conv(optimized_graph_def,
- output_node_names)
+ if not toco_compatible:
+ optimized_graph_def = fuse_resize_and_conv(optimized_graph_def,
+ output_node_names)
ensure_graph_is_valid(optimized_graph_def)
return optimized_graph_def
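A minimal sketch of calling the library entry point with the new parameter; it assumes a frozen `GraphDef` on disk, and the file and node names are placeholders:

```python
from tensorflow.core.framework import graph_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.tools import optimize_for_inference_lib

graph_def = graph_pb2.GraphDef()
with open("frozen_graph.pb", "rb") as f:
  graph_def.ParseFromString(f.read())

# toco_compatible=True skips fuse_resize_and_conv, keeping the result
# convertible by the TensorFlow Lite Optimizing Converter.
optimized = optimize_for_inference_lib.optimize_for_inference(
    graph_def,
    input_node_names=["input"],
    output_node_names=["softmax"],
    placeholder_type_enum=dtypes.float32.as_datatype_enum,
    toco_compatible=True)
```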
diff --git a/tensorflow/python/training/distribute.py b/tensorflow/python/training/distribute.py
index c82b898bd0..16e200d64d 100644
--- a/tensorflow/python/training/distribute.py
+++ b/tensorflow/python/training/distribute.py
@@ -376,7 +376,9 @@ class DistributionStrategy(object):
update. Allreduce is an algorithm for performing a reduction on
values from multiple devices and making the result available on
all of those devices.
- * TODO(josh11b): Future: partitioned variables
+ * In the future we will have support for TensorFlow's partitioned
+ variables, where a single variable is split across multiple
+ devices.
We have then a few approaches we want to support:
* Code written (as if) with no knowledge of class `DistributionStrategy`.
@@ -390,7 +392,6 @@ class DistributionStrategy(object):
```
with my_distribution.scope():
iterator = my_distribution.distribute_dataset(dataset)
- # TODO(josh11b): iterator = dataset.make_one_shot_iterator()
tower_train_ops = my_distribution.call_for_each_tower(
tower_fn, iterator.get_next())
train_op = tf.group(my_distribution.unwrap(tower_train_ops))
@@ -402,6 +403,10 @@ class DistributionStrategy(object):
using `my_distribution`'s policy, and library functions called by
`tower_fn` can use the `get_tower_context()` API to get enhanced
behavior in this case.
+
+ Note that in the future we will add support for initializable
+ Dataset iterators, at which point this example code will change.
+
* If you want to write a distributed algorithm, you may use any of
the `DistributionStrategy` APIs inside a
`with my_distribution.scope():` block of code.
@@ -514,7 +519,7 @@ class DistributionStrategy(object):
Steps 3 and 4 are done automatically by class `Optimizer` if you call
its `apply_gradients` method in a tower context. Otherwise you can
- manually call its `distributed_apply` method in a cross-tower context.
+ manually call its `_distributed_apply` method in a cross-tower context.
Another thing you might want to do in the middle of your tower function
is an all-reduce of some intermediate value, using `d.reduce()` or
diff --git a/tensorflow/python/training/input.py b/tensorflow/python/training/input.py
index 44f00a96de..caa26581e8 100644
--- a/tensorflow/python/training/input.py
+++ b/tensorflow/python/training/input.py
@@ -515,8 +515,7 @@ def _store_sparse_tensors(tensor_list, enqueue_many, keep_input,
def _sparse_values_to_keep(t, keep_input):
"""Convert a per-row `keep_input` vector to a per-value one."""
# Get the rows of every value in the sparse Tensor.
- row_values = array_ops.reshape(
- t.indices, [array_ops.shape(t.indices)[0], -1])[:, 0]
+ row_values = t.indices[:, 0]
# The value should be kept iff the row should be kept.
return array_ops.gather(keep_input, row_values)
if keep_input.shape.ndims == 1:
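The simplification is valid because a `SparseTensor`'s `indices` matrix is already 2-D with shape `[nnz, ndims]`, so the old reshape was a no-op and column 0 directly gives the row of each stored value. A quick NumPy illustration:

```python
import numpy as np

# indices of a 2-D SparseTensor: one [row, col] pair per stored value.
indices = np.array([[0, 1],
                    [0, 3],
                    [2, 0]])

row_values = indices[:, 0]  # row of each value, no reshape needed
print(row_values)  # [0 0 2]
```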
diff --git a/tensorflow/stream_executor/BUILD b/tensorflow/stream_executor/BUILD
index 27cdb860fe..1913fc20ee 100644
--- a/tensorflow/stream_executor/BUILD
+++ b/tensorflow/stream_executor/BUILD
@@ -75,7 +75,6 @@ cc_library(
":stream_executor",
"//tensorflow/core:lib",
"//tensorflow/core/kernels:ops_util",
- "@com_google_absl//absl/strings",
"@local_config_cuda//cuda:cuda_headers",
] + if_cuda_is_configured([
"//tensorflow/core:cuda",
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc
index 1aea0485fd..f408c06f46 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.cc
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc
@@ -18,7 +18,6 @@ limitations under the License.
#include <functional>
#include <memory>
-#include "absl/strings/str_cat.h"
#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/env_var.h"
@@ -113,7 +112,7 @@ string ToString(libraryPropertyType type) {
case PATCH_LEVEL:
return "PATCH_LEVEL";
default:
- return absl::StrCat(
+ return port::StrCat(
"<unknown libraryPropertyType: ", static_cast<int>(type), ">");
}
}
@@ -375,7 +374,7 @@ port::Status GetCudnnProperty(libraryPropertyType type, int* value) {
cudnnStatus_t status = cudnnGetProperty(type, value);
if (status != CUDNN_STATUS_SUCCESS) {
const string error =
- absl::StrCat("cudnnGetProperty failed for type: ", ToString(type),
+ port::StrCat("cudnnGetProperty failed for type: ", ToString(type),
" with status: ", ToString(status));
LOG(ERROR) << error;
return port::Status{port::error::INTERNAL, error};
@@ -419,7 +418,7 @@ port::Status CudnnSupport::Init() {
CudnnVersion loaded_version;
TF_RETURN_IF_ERROR(GetLoadedCudnnVersion(&loaded_version));
if (!IsSourceCompatibleWithCudnnLibrary(source_version, loaded_version)) {
- const tensorflow::string error = absl::StrCat(
+ const tensorflow::string error = port::StrCat(
"Loaded runtime CuDNN library: ", loaded_version.ToString(),
" but source was compiled with: ", source_version.ToString(),
". CuDNN library major and minor version needs to match or have "
diff --git a/tensorflow/stream_executor/cuda/cudnn_version.h b/tensorflow/stream_executor/cuda/cudnn_version.h
index 058cc87bfa..2ed02e1700 100644
--- a/tensorflow/stream_executor/cuda/cudnn_version.h
+++ b/tensorflow/stream_executor/cuda/cudnn_version.h
@@ -18,7 +18,7 @@ limitations under the License.
#include <string>
-#include "absl/strings/str_join.h"
+#include "tensorflow/core/lib/strings/strcat.h"
namespace perftools {
namespace gputools {
@@ -30,8 +30,9 @@ struct CudnnVersion {
CudnnVersion(int major, int minor, int patch)
: major_version(major), minor_version(minor), patch_level(patch) {}
- std::string ToString() const {
- return absl::StrJoin({major_version, minor_version, patch_level}, ".");
+ tensorflow::string ToString() const {
+ return tensorflow::strings::StrCat(major_version, ".", minor_version, ".",
+ patch_level);
}
int major_version;
diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl
index fcc57d506e..fd44b0eb3b 100644
--- a/tensorflow/tensorflow.bzl
+++ b/tensorflow/tensorflow.bzl
@@ -304,6 +304,7 @@ def tf_cc_shared_object(
clean_dep("//tensorflow:darwin"): [
"-Wl,-install_name,@rpath/" + name.split("/")[-1],
],
+ clean_dep("//tensorflow:windows"): [],
"//conditions:default": [
"-Wl,-soname," + name.split("/")[-1],
],
@@ -342,6 +343,22 @@ register_extension_info(
label_regex_for_dep = "{extension_name}.*",
)
+# A simple wrapper around the native.cc_binary rule.
+# Note that this rule does not link in any TensorFlow dependencies by
+# default.
+def tf_native_cc_binary(name,
+ copts=tf_copts(),
+ **kwargs):
+ native.cc_binary(
+ name=name,
+ copts=copts,
+ **kwargs)
+
+register_extension_info(
+ extension_name = "tf_native_cc_binary",
+ label_regex_for_dep = "{extension_name}.*",
+)
+
def tf_gen_op_wrapper_cc(name,
out_ops_file,
pkg="",
@@ -622,9 +639,12 @@ def tf_cc_test(name,
linkopts=select({
clean_dep("//tensorflow:android"): [
"-pie",
- ],
+ ],
clean_dep("//tensorflow:windows"): [],
clean_dep("//tensorflow:windows_msvc"): [],
+ clean_dep("//tensorflow:darwin"): [
+ "-lm",
+ ],
"//conditions:default": [
"-lpthread",
"-lm"
@@ -910,6 +930,7 @@ def tf_cuda_library(deps=None, cuda_deps=None, copts=tf_copts(), **kwargs):
if 'linkstatic' not in kwargs or kwargs['linkstatic'] != 1:
enable_text_relocation_linkopt = select({
clean_dep("//tensorflow:darwin"): [],
+ clean_dep("//tensorflow:windows"): [],
"//conditions:default": ['-Wl,-z,notext'],})
if 'linkopts' in kwargs:
kwargs['linkopts'] += enable_text_relocation_linkopt
@@ -1178,6 +1199,20 @@ def tf_custom_op_library_additional_deps():
"@protobuf_archive//:protobuf_headers",
clean_dep("//third_party/eigen3"),
clean_dep("//tensorflow/core:framework_headers_lib"),
+ ] + if_windows(["//tensorflow/python:pywrap_tensorflow_import_lib"])
+
+# A list of targets that contains the implementation of
+# tf_custom_op_library_additional_deps. It's used to generate a DEF file for
+# exporting symbols from _pywrap_tensorflow.dll on Windows.
+def tf_custom_op_library_additional_deps_impl():
+ return [
+ "@protobuf_archive//:protobuf",
+ "@nsync//:nsync_cpp",
+ # for //third_party/eigen3
+ clean_dep("//third_party/eigen3"),
+ # for //tensorflow/core:framework_headers_lib
+ clean_dep("//tensorflow/core:framework"),
+ clean_dep("//tensorflow/core:reader_base"),
]
# Traverse the dependency graph along the "deps" attribute of the
@@ -1264,6 +1299,7 @@ def tf_custom_op_library(name, srcs=[], gpu_srcs=[], deps=[], linkopts=[]):
deps=deps + if_cuda(cuda_deps),
data=[name + "_check_deps"],
copts=tf_copts(is_external=True),
+ features = ["windows_export_all_symbols"],
linkopts=linkopts + select({
"//conditions:default": [
"-lm",
@@ -1410,7 +1446,8 @@ def tf_py_wrap_cc(name,
]) + tf_extension_copts()),
linkopts=tf_extension_linkopts() + extra_linkopts,
linkstatic=1,
- deps=deps + extra_deps)
+ deps=deps + extra_deps,
+ **kwargs)
native.genrule(
name="gen_" + cc_library_pyd_name,
srcs=[":" + cc_library_name],
diff --git a/tensorflow/tools/api/generator/BUILD b/tensorflow/tools/api/generator/BUILD
index 6722536358..9f1bdd8aae 100644
--- a/tensorflow/tools/api/generator/BUILD
+++ b/tensorflow/tools/api/generator/BUILD
@@ -93,6 +93,7 @@ genrule(
"api/logging/__init__.py",
"api/losses/__init__.py",
"api/manip/__init__.py",
+ "api/math/__init__.py",
"api/metrics/__init__.py",
"api/nn/__init__.py",
"api/nn/rnn_cell/__init__.py",
diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-run-config.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-run-config.pbtxt
index 759ff752b0..05e603efb7 100644
--- a/tensorflow/tools/api/golden/tensorflow.estimator.-run-config.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.estimator.-run-config.pbtxt
@@ -7,10 +7,6 @@ tf_class {
mtype: "<type \'property\'>"
}
member {
- name: "distribute"
- mtype: "<type \'property\'>"
- }
- member {
name: "evaluation_master"
mtype: "<type \'property\'>"
}
@@ -82,9 +78,13 @@ tf_class {
name: "tf_random_seed"
mtype: "<type \'property\'>"
}
+ member {
+ name: "train_distribute"
+ mtype: "<type \'property\'>"
+ }
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'model_dir\', \'tf_random_seed\', \'save_summary_steps\', \'save_checkpoints_steps\', \'save_checkpoints_secs\', \'session_config\', \'keep_checkpoint_max\', \'keep_checkpoint_every_n_hours\', \'log_step_count_steps\', \'distribute\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'100\', \'<object object instance>\', \'<object object instance>\', \'None\', \'5\', \'10000\', \'100\', \'None\'], "
+ argspec: "args=[\'self\', \'model_dir\', \'tf_random_seed\', \'save_summary_steps\', \'save_checkpoints_steps\', \'save_checkpoints_secs\', \'session_config\', \'keep_checkpoint_max\', \'keep_checkpoint_every_n_hours\', \'log_step_count_steps\', \'train_distribute\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'100\', \'<object object instance>\', \'<object object instance>\', \'None\', \'5\', \'10000\', \'100\', \'None\'], "
}
member_method {
name: "replace"
diff --git a/tensorflow/tools/api/golden/tensorflow.math.pbtxt b/tensorflow/tools/api/golden/tensorflow.math.pbtxt
new file mode 100644
index 0000000000..897718c05e
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.math.pbtxt
@@ -0,0 +1,7 @@
+path: "tensorflow.math"
+tf_module {
+ member_method {
+ name: "polyval"
+ argspec: "args=[\'coeffs\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.pbtxt b/tensorflow/tools/api/golden/tensorflow.pbtxt
index 937044aece..afa3b78eb7 100644
--- a/tensorflow/tools/api/golden/tensorflow.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.pbtxt
@@ -405,6 +405,10 @@ tf_module {
mtype: "<type \'module\'>"
}
member {
+ name: "math"
+ mtype: "<type \'module\'>"
+ }
+ member {
name: "metrics"
mtype: "<type \'module\'>"
}
diff --git a/tensorflow/tools/api/tests/api_compatibility_test.py b/tensorflow/tools/api/tests/api_compatibility_test.py
index 603b2a4327..7eeae05847 100644
--- a/tensorflow/tools/api/tests/api_compatibility_test.py
+++ b/tensorflow/tools/api/tests/api_compatibility_test.py
@@ -145,6 +145,9 @@ class ApiCompatibilityTest(test.TestCase):
verbose_diff_message = ''
# First check if the key is not found in one or the other.
if key in only_in_expected:
+ # TODO(annarev): remove once we switch to using tf_export decorators.
+ if key == 'tensorflow.math':
+ continue
diff_message = 'Object %s expected but not found (removed). %s' % (
key, additional_missing_object_message)
verbose_diff_message = diff_message
@@ -229,6 +232,13 @@ class ApiCompatibilityTest(test.TestCase):
for filename in golden_file_list
}
+ # TODO(annarev): remove once we switch to using tf_export decorators.
+ tf_module = golden_proto_dict['tensorflow'].tf_module
+ for i in range(len(tf_module.member)):
+ if tf_module.member[i].name == 'math':
+ del tf_module.member[i]
+ break
+
# Diff them. Do not fail if called with update.
# If the test is run to update goldens, only report diffs but do not fail.
self._AssertProtoDictEquals(
diff --git a/tensorflow/tools/ci_build/windows/cpu/pip/build_tf_windows.sh b/tensorflow/tools/ci_build/windows/cpu/pip/build_tf_windows.sh
index 8b8ba31a0d..438c5d52f6 100644
--- a/tensorflow/tools/ci_build/windows/cpu/pip/build_tf_windows.sh
+++ b/tensorflow/tools/ci_build/windows/cpu/pip/build_tf_windows.sh
@@ -65,4 +65,6 @@ bazel test -c opt $BUILD_OPTS -k --test_output=errors \
--define=no_tensorflow_py_deps=true --test_lang_filters=py \
--test_tag_filters=-no_pip,-no_windows,-no_oss \
--build_tag_filters=-no_pip,-no_windows,-no_oss --build_tests_only \
- //${PY_TEST_DIR}/tensorflow/python/...
+ --flaky_test_attempts=3 \
+ //${PY_TEST_DIR}/tensorflow/python/... \
+ //${PY_TEST_DIR}/tensorflow/contrib/...
diff --git a/tensorflow/tools/def_file_filter/BUILD b/tensorflow/tools/def_file_filter/BUILD
new file mode 100644
index 0000000000..e390e0fb05
--- /dev/null
+++ b/tensorflow/tools/def_file_filter/BUILD
@@ -0,0 +1,9 @@
+# Description:
+# Tools for filtering the DEF file for TensorFlow on Windows
+#
+# On Windows, we use a DEF file generated by Bazel to export
+# symbols from the tensorflow dynamic library (_pywrap_tensorflow.dll).
+# The maximum number of symbols that can be exported per DLL is 64K,
+# so we have to filter out unneeded symbols with this Python script.
+
+package(default_visibility = ["//visibility:public"])
diff --git a/tensorflow/tools/def_file_filter/BUILD.tpl b/tensorflow/tools/def_file_filter/BUILD.tpl
new file mode 100644
index 0000000000..3cb72f4979
--- /dev/null
+++ b/tensorflow/tools/def_file_filter/BUILD.tpl
@@ -0,0 +1,15 @@
+# Description:
+# Tools for filtering the DEF file for TensorFlow on Windows
+#
+# On Windows, we use a DEF file generated by Bazel to export
+# symbols from the tensorflow dynamic library (_pywrap_tensorflow.dll).
+# The maximum number of symbols that can be exported per DLL is 64K,
+# so we have to filter out unneeded symbols with this Python script.
+
+package(default_visibility = ["//visibility:public"])
+
+py_binary(
+ name = "def_file_filter",
+ srcs = ["def_file_filter.py"],
+ srcs_version = "PY2AND3",
+)
diff --git a/tensorflow/tools/def_file_filter/def_file_filter.py.tpl b/tensorflow/tools/def_file_filter/def_file_filter.py.tpl
new file mode 100644
index 0000000000..8bdc03eb0f
--- /dev/null
+++ b/tensorflow/tools/def_file_filter/def_file_filter.py.tpl
@@ -0,0 +1,168 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""def_file_filter.py - tool to filter a windows def file.
+
+The def file can be used to export symbols from the tensorflow dll to enable
+tf.load_library().
+
+Because the linker allows only 64K symbols to be exported per DLL,
+we filter the symbols down to the essentials. The regular expressions
+we use for this are specific to tensorflow.
+
+TODO: this works fine, but there is an issue with exporting
+'const char * const' and importing it from a user op. The problem is
+on the importing end, and using __declspec(dllimport) works around it.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import io
+import os
+import re
+import subprocess
+import sys
+import tempfile
+
+# External tools we use that come with the Visual Studio SDK.
+UNDNAME = "%{undname_bin_path}"
+
+# Exclude if matched
+EXCLUDE_RE = re.compile(r"RTTI|deleting destructor|::internal::")
+
+# Include if matched before exclude
+INCLUDEPRE_RE = re.compile(r"google::protobuf::internal::ExplicitlyConstructed|"
+ r"google::protobuf::internal::ArenaImpl::AllocateAligned|" # for contrib/data/_prefetching_ops
+ r"google::protobuf::internal::ArenaImpl::AddCleanup|" # for contrib/data/_prefetching_ops
+ r"google::protobuf::Arena::OnArenaAllocation|" # for contrib/data/_prefetching_ops
+ r"tensorflow::internal::LogMessage|"
+ r"tensorflow::internal::LogString|"
+ r"tensorflow::internal::CheckOpMessageBuilder|"
+ r"tensorflow::internal::MakeCheckOpValueString|"
+ r"tensorflow::internal::PickUnusedPortOrDie|"
+ r"tensorflow::internal::ValidateDevice|"
+ r"tensorflow::ops::internal::Enter|"
+ r"tensorflow::strings::internal::AppendPieces|"
+ r"tensorflow::strings::internal::CatPieces|"
+ r"tensorflow::io::internal::JoinPathImpl")
+
+# Include if matched after exclude
+INCLUDE_RE = re.compile(r"^(TF_\w*)$|"
+ r"^(TFE_\w*)$|"
+ r"nsync::|"
+ r"tensorflow::|"
+ r"functor::|"
+ r"perftools::gputools")
+
+# We want to identify data members explicitly in the DEF file, so that no one
+# can implicitly link against the DLL if they use one of the variables exported
+# from the DLL and the header they use does not decorate the symbol with
+# __declspec(dllimport). It is easier to detect what a data symbol does
+# NOT look like, so we do that with the regex below.
+DATA_EXCLUDE_RE = re.compile(r"[)(]|"
+ r"vftable|"
+ r"vbtable|"
+ r"vcall|"
+ r"RTTI|"
+ r"protobuf::internal::ExplicitlyConstructed")
+
+def get_args():
+ """Parse command line."""
+ filename_list = lambda x: x.split(";")
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--input", type=filename_list,
+ help="paths to input def file",
+ required=True)
+ parser.add_argument("--output", help="output deffile", required=True)
+ parser.add_argument("--target", help="name of the target", required=True)
+ args = parser.parse_args()
+ return args
+
+
+def main():
+ """main."""
+ args = get_args()
+
+  # Collect the symbols from the input def files. Good symbols are
+  # gathered in candidates and also written to a temp file, which is
+  # later fed to undname.
+ candidates = []
+ tmpfile = tempfile.NamedTemporaryFile(mode="w", delete=False)
+  for def_file_path in args.input:
+    with open(def_file_path, "r") as def_file:
+      for line in def_file:
+        cols = line.split()
+        sym = cols[0]
+        tmpfile.file.write(sym + "\n")
+        candidates.append(sym)
+ tmpfile.file.close()
+
+ # Run the symbols through undname to get their undecorated name
+ # so we can filter on something readable.
+ with open(args.output, "w") as def_fp:
+ # track dupes
+ taken = set()
+
+ # Header for the def file.
+ def_fp.write("LIBRARY " + args.target + "\n")
+ def_fp.write("EXPORTS\n")
+ def_fp.write("\t ??1OpDef@tensorflow@@UEAA@XZ\n")
+
+  # Each symbol returned by undname matches the same position in candidates.
+ # We compare on undname but use the decorated name from candidates.
+ dupes = 0
+ proc = subprocess.Popen([UNDNAME, tmpfile.name], stdout=subprocess.PIPE)
+ for idx, line in enumerate(io.TextIOWrapper(proc.stdout, encoding="utf-8")):
+ decorated = candidates[idx]
+ if decorated in taken:
+ # Symbol is already in output, done.
+ dupes += 1
+ continue
+
+ if not INCLUDEPRE_RE.search(line):
+ if EXCLUDE_RE.search(line):
+ continue
+ if not INCLUDE_RE.search(line):
+ continue
+
+ if "deleting destructor" in line:
+      # Some of the symbols covered by INCLUDEPRE_RE export deleting
+      # destructor symbols, which is a bad idea.
+      # So we filter out such symbols here.
+ continue
+
+ if DATA_EXCLUDE_RE.search(line):
+ def_fp.write("\t" + decorated + "\n")
+ else:
+ def_fp.write("\t" + decorated + " DATA\n")
+ taken.add(decorated)
+ def_fp.close()
+
+ exit_code = proc.wait()
+ if exit_code != 0:
+ print("{} failed, exit={}".format(UNDNAME, exit_code))
+ return exit_code
+
+ os.unlink(tmpfile.name)
+
+ print("symbols={}, taken={}, dupes={}"
+ .format(len(candidates), len(taken), dupes))
+ return 0
+
+
+if __name__ == "__main__":
+ sys.exit(main())
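After the repository rule below expands `%{undname_bin_path}`, the script is driven roughly as follows; the def file names are hypothetical, and note that `--input` takes a semicolon-separated list (see the `filename_list` parser above):

```python
import subprocess

subprocess.check_call([
    "python", "def_file_filter.py",
    "--input", "core.def;kernels.def",     # semicolon-separated list
    "--output", "_pywrap_tensorflow.def",
    "--target", "_pywrap_tensorflow.dll",
])
```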
diff --git a/tensorflow/tools/def_file_filter/def_file_filter_configure.bzl b/tensorflow/tools/def_file_filter/def_file_filter_configure.bzl
new file mode 100644
index 0000000000..47539b2423
--- /dev/null
+++ b/tensorflow/tools/def_file_filter/def_file_filter_configure.bzl
@@ -0,0 +1,56 @@
+"""Repository rule for def file filter autoconfiguration.
+
+This repository reuses Bazel's VC detect mechanism to find undname.exe,
+which is a tool used in def_file_filter.py.
+
+def_file_filter.py is for filtering the DEF file for TensorFlow on Windows.
+On Windows, we use a DEF file generated by Bazel to export symbols from the
+tensorflow dynamic library (_pywrap_tensorflow.dll). The maximum number of
+symbols that can be exported per DLL is 64K, so we have to filter out
+unneeded symbols with this Python script.
+
+`def_file_filter_config` depends on the following environment variables:
+ * `BAZEL_VC`
+ * `BAZEL_VS`
+ * `VS90COMNTOOLS`
+ * `VS100COMNTOOLS`
+ * `VS110COMNTOOLS`
+ * `VS120COMNTOOLS`
+ * `VS140COMNTOOLS`
+"""
+
+load("@bazel_tools//tools/cpp:windows_cc_configure.bzl", "find_vc_path")
+load("@bazel_tools//tools/cpp:windows_cc_configure.bzl", "find_msvc_tool")
+load("@bazel_tools//tools/cpp:lib_cc_configure.bzl", "auto_configure_fail")
+
+def _def_file_filter_configure_impl(repository_ctx):
+ if repository_ctx.os.name.lower().find("windows") == -1:
+ repository_ctx.symlink(Label("//tensorflow/tools/def_file_filter:BUILD.tpl"), "BUILD")
+ repository_ctx.file("def_file_filter.py", "")
+ return
+ vc_path = find_vc_path(repository_ctx)
+ if vc_path == "visual-studio-not-found":
+ auto_configure_fail("Visual C++ build tools not found on your machine")
+ undname_bin_path = find_msvc_tool(repository_ctx, vc_path, "undname.exe").replace("\\", "\\\\")
+
+ repository_ctx.template(
+ "def_file_filter.py",
+ Label("//tensorflow/tools/def_file_filter:def_file_filter.py.tpl"),
+ {
+ "%{undname_bin_path}": undname_bin_path,
+ })
+ repository_ctx.symlink(Label("//tensorflow/tools/def_file_filter:BUILD.tpl"), "BUILD")
+
+
+def_file_filter_configure = repository_rule(
+ implementation = _def_file_filter_configure_impl,
+ environ = [
+ "BAZEL_VC",
+ "BAZEL_VS",
+ "VS90COMNTOOLS",
+ "VS100COMNTOOLS",
+ "VS110COMNTOOLS",
+ "VS120COMNTOOLS",
+ "VS140COMNTOOLS"
+ ],
+)
diff --git a/tensorflow/tools/graph_transforms/backports_test.cc b/tensorflow/tools/graph_transforms/backports_test.cc
index ab9a61afa7..80a954e062 100644
--- a/tensorflow/tools/graph_transforms/backports_test.cc
+++ b/tensorflow/tools/graph_transforms/backports_test.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/public/session.h"
@@ -191,7 +192,7 @@ TEST(BackportTensorArrayV3Test, TestBackportTensorArrayV3Subtypes) {
std::map<string, const NodeDef*> node_lookup;
MapNamesToNodes(result, &node_lookup);
ASSERT_EQ(1, node_lookup.count("v3_node"));
- EXPECT_TRUE(StringPiece(node_lookup.at("v3_node")->op()).ends_with("V2"));
+ EXPECT_TRUE(str_util::EndsWith(node_lookup.at("v3_node")->op(), "V2"));
}
}
diff --git a/tensorflow/tools/graph_transforms/fold_constants_lib.cc b/tensorflow/tools/graph_transforms/fold_constants_lib.cc
index 250f54e20f..85660f94a8 100644
--- a/tensorflow/tools/graph_transforms/fold_constants_lib.cc
+++ b/tensorflow/tools/graph_transforms/fold_constants_lib.cc
@@ -283,6 +283,10 @@ Status FoldConstants(const GraphDef& input_graph_def,
};
}
+ TF_RETURN_IF_ERROR(context.GetOneInt64Parameter(
+ "max_constant_size_in_bytes", cf_opts.max_constant_size_in_bytes,
+ &cf_opts.max_constant_size_in_bytes));
+
// Constant folding.
bool was_mutated;
TF_RETURN_IF_ERROR(ConstantFold(cf_opts, nullptr, Env::Default(), nullptr,
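Since the new `max_constant_size_in_bytes` parameter is read from the transform's `TransformFuncContext`, it can also be passed through the Python `TransformGraph` wrapper. A sketch, assuming a frozen `GraphDef` on disk with placeholder node names:

```python
from tensorflow.core.framework import graph_pb2
from tensorflow.tools.graph_transforms import TransformGraph

graph_def = graph_pb2.GraphDef()
with open("frozen_graph.pb", "rb") as f:
  graph_def.ParseFromString(f.read())

# Constants larger than 10 bytes are left unfolded (cf. the test below).
transformed = TransformGraph(
    graph_def, ["input"], ["output"],
    ["fold_constants(max_constant_size_in_bytes=10)"])
```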
diff --git a/tensorflow/tools/graph_transforms/fold_constants_test.cc b/tensorflow/tools/graph_transforms/fold_constants_test.cc
index 41106de008..a082399a87 100644
--- a/tensorflow/tools/graph_transforms/fold_constants_test.cc
+++ b/tensorflow/tools/graph_transforms/fold_constants_test.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/public/session.h"
@@ -209,10 +210,10 @@ class ConstantFoldingTest : public ::testing::Test {
for (const NodeDef& node : graph_def.node()) {
const StringPiece name(node.name());
const int occurrence_count = folded_node_map.count(node.name());
- if (name.ends_with("expect_removed")) {
+ if (str_util::EndsWith(name, "expect_removed")) {
EXPECT_EQ(0, occurrence_count) << "node.name()=" << node.name();
}
- if (name.ends_with("expect_remains")) {
+ if (str_util::EndsWith(name, "expect_remains")) {
EXPECT_EQ(1, occurrence_count) << "node.name()=" << node.name();
}
}
@@ -370,6 +371,46 @@ class ConstantFoldingTest : public ::testing::Test {
EXPECT_EQ(0, node_map.count("b"));
EXPECT_EQ(1, node_map.count("c"));
}
+
+ void TestMaxConstantSizeInBytes() {
+ auto root = tensorflow::Scope::NewRootScope();
+
+ const int width = 100;
+
+ Tensor a_data(DT_FLOAT, TensorShape({width}));
+ test::FillIota<float>(&a_data, 1.0f);
+ Output a_const = ::tensorflow::ops::Const(
+ root.WithOpName("a_expect_remains"), Input::Initializer(a_data));
+
+ Tensor b_data(DT_FLOAT, TensorShape({width}));
+ test::FillIota<float>(&b_data, 1.0f);
+ Output b_const = ::tensorflow::ops::Const(
+ root.WithOpName("b_expect_remains"), Input::Initializer(b_data));
+
+ Output add = ::tensorflow::ops::Add(root.WithOpName("add_expect_remains"),
+ a_const, b_const);
+
+ Output placeholder = ::tensorflow::ops::Placeholder(
+ root.WithOpName("placeholder_expect_remains"), DT_FLOAT);
+
+ Output mul = ::tensorflow::ops::Mul(
+ root.WithOpName("output_expect_remains"), add, placeholder);
+
+ GraphDef graph_def;
+ TF_ASSERT_OK(root.ToGraphDef(&graph_def));
+
+ Tensor placeholder_tensor(DT_FLOAT, TensorShape({width}));
+ test::FillIota<float>(&placeholder_tensor, 1.0f);
+
+ // Setting the maximum constant size to 10 bytes should stop the constant
+    // folding at add(a, b), which would have yielded a constant of
+ // 100*sizeof(float) bytes.
+ graph_transforms::TransformFuncContext context;
+ context.params["max_constant_size_in_bytes"] = {"10"};
+ TestConstantFolding(graph_def,
+ {{"placeholder_expect_remains", placeholder_tensor}},
+ {}, {"output_expect_remains"}, context);
+ }
};
TEST_F(ConstantFoldingTest, TestSimpleAdd) { TestSimpleAdd(); }
@@ -394,5 +435,9 @@ TEST_F(ConstantFoldingTest, TestRemoveUnusedNodesMultipleOutputs) {
TestRemoveUnusedNodesMultipleOutputs();
}
+TEST_F(ConstantFoldingTest, TestMaxConstantSizeInBytes) {
+ TestMaxConstantSizeInBytes();
+}
+
} // namespace graph_transforms
} // namespace tensorflow
diff --git a/tensorflow/tools/graph_transforms/freeze_requantization_ranges.cc b/tensorflow/tools/graph_transforms/freeze_requantization_ranges.cc
index 2436c7e4a2..f401723808 100644
--- a/tensorflow/tools/graph_transforms/freeze_requantization_ranges.cc
+++ b/tensorflow/tools/graph_transforms/freeze_requantization_ranges.cc
@@ -40,8 +40,8 @@ Status ExtractMinMaxRecords(const string& log_file_name,
for (const string& file_line : file_lines) {
// We expect to find a line with components separated by semicolons, so to
// start, make sure that the basic structure is in place.
- StringPiece line(file_line);
- if (!line.contains(print_suffix + ";" + requant_prefix)) {
+ if (!str_util::StrContains(file_line,
+ print_suffix + ";" + requant_prefix)) {
continue;
}
std::vector<string> line_parts = str_util::Split(file_line, ';');
@@ -53,8 +53,7 @@ Status ExtractMinMaxRecords(const string& log_file_name,
bool min_max_found = false;
int min_max_index;
for (int i = 1; i < line_parts.size(); ++i) {
- StringPiece line_part(line_parts[i]);
- if (line_part.starts_with(requant_prefix)) {
+ if (str_util::StartsWith(line_parts[i], requant_prefix)) {
min_max_found = true;
min_max_index = i;
}
@@ -90,7 +89,7 @@ Status ExtractMinMaxRecords(const string& log_file_name,
continue;
}
StringPiece name_string = line_parts[min_max_index - 1];
- if (!name_string.ends_with(print_suffix)) {
+ if (!str_util::EndsWith(name_string, print_suffix)) {
continue;
}
string name =
diff --git a/tensorflow/tools/graph_transforms/insert_logging.cc b/tensorflow/tools/graph_transforms/insert_logging.cc
index e1ee2b420b..377665448c 100644
--- a/tensorflow/tools/graph_transforms/insert_logging.cc
+++ b/tensorflow/tools/graph_transforms/insert_logging.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/subgraph.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/util/command_line_flags.h"
@@ -101,7 +102,7 @@ Status InsertLogging(const GraphDef& input_graph_def,
const bool op_matches = (ops.count(node.op()) > 0);
bool prefix_matches = false;
for (const string& prefix : prefixes) {
- if (StringPiece(node.name()).starts_with(prefix)) {
+ if (str_util::StartsWith(node.name(), prefix)) {
prefix_matches = true;
}
}
diff --git a/tensorflow/tools/graph_transforms/sparsify_gather.cc b/tensorflow/tools/graph_transforms/sparsify_gather.cc
index 701e350fc3..cc82100148 100644
--- a/tensorflow/tools/graph_transforms/sparsify_gather.cc
+++ b/tensorflow/tools/graph_transforms/sparsify_gather.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/subgraph.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/util/command_line_flags.h"
@@ -88,7 +89,7 @@ void CreateConstNode(const Tensor& tensor, const string& name,
string GetMonolithicTensorKey(const string& tensor_slice_name) {
std::vector<string> names = Split(tensor_slice_name, "/");
- if (StringPiece(names[names.size() - 1]).starts_with("part_")) {
+ if (str_util::StartsWith(names[names.size() - 1], "part_")) {
CHECK_GE(names.size(), 2);
names.pop_back();
}
@@ -102,8 +103,8 @@ Status ObtainTensorSlice(const GraphDef& input_graph_def,
for (const auto& node : input_graph_def.node()) {
std::vector<string> node_name_parts = Split(node.name(), "/");
if (node_name_parts.size() == 2 &&
- StringPiece(node_name_parts[0]).starts_with("save") &&
- StringPiece(node_name_parts[1]).starts_with("Assign") &&
+ str_util::StartsWith(node_name_parts[0], "save") &&
+ str_util::StartsWith(node_name_parts[1], "Assign") &&
node.input(0) == target_name) {
restore_node_name = node.input(1);
break;
diff --git a/tensorflow/tools/graph_transforms/transform_graph_test.cc b/tensorflow/tools/graph_transforms/transform_graph_test.cc
index bc2412fcbd..b276229aa4 100644
--- a/tensorflow/tools/graph_transforms/transform_graph_test.cc
+++ b/tensorflow/tools/graph_transforms/transform_graph_test.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/public/session.h"
@@ -112,12 +113,11 @@ class TransformGraphTest : public ::testing::Test {
graph_transforms::MapNamesToNodes(out_graph_def, &out_node_map);
for (const NodeDef& node : out_graph_def.node()) {
- const StringPiece name(node.name());
const int occurrence_count = out_node_map.count(node.name());
- if (name.ends_with("expect_removed")) {
+ if (str_util::EndsWith(node.name(), "expect_removed")) {
EXPECT_EQ(0, occurrence_count) << "node.name()=" << node.name();
}
- if (name.ends_with("expect_remains")) {
+ if (str_util::EndsWith(node.name(), "expect_remains")) {
EXPECT_EQ(1, occurrence_count) << "node.name()=" << node.name();
}
}
@@ -139,7 +139,7 @@ class TransformGraphTest : public ::testing::Test {
Status no_such_status =
TransformGraph({}, {}, {{"test_no_such_transform", {}}}, &graph_def);
EXPECT_TRUE(
- StringPiece(no_such_status.ToString()).contains("not recognized"));
+ str_util::StrContains(no_such_status.ToString(), "not recognized"));
}
void TestParseTransformParameters() {
diff --git a/tensorflow/tools/graph_transforms/transform_utils.cc b/tensorflow/tools/graph_transforms/transform_utils.cc
index 55f28a9e1d..367048965d 100644
--- a/tensorflow/tools/graph_transforms/transform_utils.cc
+++ b/tensorflow/tools/graph_transforms/transform_utils.cc
@@ -88,7 +88,7 @@ void NodeNamePartsFromInput(const string& input_name, string* prefix,
*suffix = ":" + input_parts[1];
}
StringPiece node_name_piece(input_parts[0]);
- if (node_name_piece.Consume("^")) {
+ if (str_util::ConsumePrefix(&node_name_piece, "^")) {
*prefix = "^";
} else {
*prefix = "";
@@ -200,8 +200,7 @@ Status SortByExecutionOrder(const GraphDef& input_graph_def,
// for merge only wait for one non-control input.
int32 num_control_edges = 0;
for (int i = 0; i < node_def.input_size(); ++i) {
- StringPiece input_name(node_def.input(i));
- if (input_name.starts_with("^")) {
+ if (str_util::StartsWith(node_def.input(i), "^")) {
num_control_edges++;
}
}
@@ -504,7 +503,7 @@ Status RenameNodeInputs(const GraphDef& input_graph_def,
const string& dest_name = input_to_rename.second;
bool is_match;
string match_name;
- if (StringPiece(source_name).ends_with(":*")) {
+ if (str_util::EndsWith(source_name, ":*")) {
is_match = true;
string prefix;
string unused_node_name;
diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD
index 62fec2c402..376644718f 100644
--- a/tensorflow/tools/pip_package/BUILD
+++ b/tensorflow/tools/pip_package/BUILD
@@ -48,36 +48,66 @@ py_binary(
deps = ["//tensorflow:tensorflow_py"],
)
+COMMON_PIP_DEPS = [
+ ":licenses",
+ "MANIFEST.in",
+ "README",
+ "setup.py",
+ ":included_headers",
+ "//tensorflow:tensorflow_py",
+ "//tensorflow/contrib/autograph:autograph",
+ "//tensorflow/contrib/autograph/converters:converters",
+ "//tensorflow/contrib/autograph/converters:test_lib",
+ "//tensorflow/contrib/autograph/impl:impl",
+ "//tensorflow/contrib/autograph/pyct:pyct",
+ "//tensorflow/contrib/autograph/pyct/static_analysis:static_analysis",
+ "//tensorflow/contrib/boosted_trees:boosted_trees_pip",
+ "//tensorflow/contrib/cluster_resolver:cluster_resolver_pip",
+ "//tensorflow/contrib/data/python/kernel_tests:dataset_serialization_test",
+ "//tensorflow/contrib/data/python/ops:contrib_op_loader",
+ "//tensorflow/contrib/eager/python/examples:examples_pip",
+ "//tensorflow/contrib/eager/python:checkpointable_utils",
+ "//tensorflow/contrib/eager/python:evaluator",
+ "//tensorflow/contrib/gan:gan",
+ "//tensorflow/contrib/graph_editor:graph_editor_pip",
+ "//tensorflow/contrib/keras:keras",
+ "//tensorflow/contrib/labeled_tensor:labeled_tensor_pip",
+ "//tensorflow/contrib/nn:nn_py",
+ "//tensorflow/contrib/predictor:predictor_pip",
+ "//tensorflow/contrib/receptive_field:receptive_field_pip",
+ "//tensorflow/contrib/session_bundle:session_bundle_pip",
+ "//tensorflow/contrib/signal:signal_py",
+ "//tensorflow/contrib/signal:test_util",
+ "//tensorflow/contrib/slim:slim",
+ "//tensorflow/contrib/slim/python/slim/data:data_pip",
+ "//tensorflow/contrib/slim/python/slim/nets:nets_pip",
+ "//tensorflow/contrib/specs:specs",
+ "//tensorflow/contrib/summary:summary_test_util",
+ "//tensorflow/contrib/tensor_forest:init_py",
+ "//tensorflow/contrib/tensor_forest/hybrid:hybrid_pip",
+ "//tensorflow/contrib/timeseries:timeseries_pip",
+ "//tensorflow/contrib/tpu",
+ "//tensorflow/examples/tutorials/mnist:package",
+ "//tensorflow/python:distributed_framework_test_lib",
+ "//tensorflow/python:meta_graph_testdata",
+ "//tensorflow/python:spectral_ops_test_util",
+ "//tensorflow/python:util_example_parser_configuration",
+ "//tensorflow/python/debug:debug_pip",
+ "//tensorflow/python/eager:eager_pip",
+ "//tensorflow/python/kernel_tests/testdata:self_adjoint_eig_op_test_files",
+ "//tensorflow/python/saved_model:saved_model",
+ "//tensorflow/python/tools:tools_pip",
+ "//tensorflow/python:test_ops",
+ "//tensorflow/tools/dist_test/server:grpc_tensorflow_server",
+]
+
# On Windows, python binary is a zip file of runfiles tree.
# Add everything to its data dependency for generating a runfiles tree
# for building the pip package on Windows.
py_binary(
name = "simple_console_for_windows",
srcs = ["simple_console_for_windows.py"],
- data = [
- "MANIFEST.in",
- "README",
- "setup.py",
- ":included_headers",
- "//tensorflow/contrib/nn:nn_py",
- "//tensorflow/contrib/session_bundle:session_bundle_pip",
- "//tensorflow/contrib/signal:signal_py",
- "//tensorflow/contrib/slim/python/slim/data:data_pip",
- "//tensorflow/python:util_example_parser_configuration",
- "//tensorflow/python/debug:debug_pip",
- "//tensorflow/python/saved_model",
- "//tensorflow/python:spectral_ops_test_util",
- "//tensorflow/python/tools:tools_pip",
- "//tensorflow/python/eager:eager_pip",
- "//tensorflow/contrib/summary:summary_test_util",
- # These targets don't build on Windows yet. Exclude them for now.
- # "//tensorflow/contrib/slim",
- # "//tensorflow/contrib/slim/python/slim/nets:nets_pip",
- # "//tensorflow/contrib/specs",
- # "//tensorflow/contrib/tensor_forest:init_py",
- # "//tensorflow/contrib/tensor_forest/hybrid:hybrid_pip",
- # "//tensorflow/examples/tutorials/mnist:package",
- ],
+ data = COMMON_PIP_DEPS,
srcs_version = "PY2AND3",
deps = ["//tensorflow:tensorflow_py"],
)
@@ -111,6 +141,7 @@ filegroup(
"@kafka//:LICENSE",
"@libxsmm_archive//:LICENSE",
"@lmdb//:LICENSE",
+ "@local_config_nccl//:LICENSE",
"@local_config_sycl//sycl:LICENSE.text",
"@grpc//third_party/nanopb:LICENSE.txt",
"@grpc//third_party/address_sorting:LICENSE",
@@ -127,8 +158,6 @@ filegroup(
"@org_python_pypi_backports_weakref//:LICENSE",
] + if_mkl([
"//third_party/mkl:LICENSE",
- ]) + if_not_windows([
- "@nccl_archive//:LICENSE.txt",
]) + tf_additional_license_deps(),
)
@@ -138,63 +167,13 @@ sh_binary(
data = select({
"//tensorflow:windows": [":simple_console_for_windows"],
"//tensorflow:windows_msvc": [":simple_console_for_windows"],
- "//conditions:default": [
- ":licenses",
- "MANIFEST.in",
- "README",
- "setup.py",
- ":included_headers",
+ "//conditions:default": COMMON_PIP_DEPS + [
":simple_console",
- "//tensorflow:tensorflow_py",
- "//tensorflow/contrib/boosted_trees:boosted_trees_pip",
- "//tensorflow/contrib/cluster_resolver:cluster_resolver_pip",
- "//tensorflow/contrib/data/python/kernel_tests:dataset_serialization_test",
- "//tensorflow/contrib/data/python/ops:contrib_op_loader",
- "//tensorflow/contrib/eager/python/examples:examples_pip",
- "//tensorflow/contrib/eager/python:checkpointable_utils",
- "//tensorflow/contrib/eager/python:evaluator",
- "//tensorflow/contrib/gan:gan",
- "//tensorflow/contrib/graph_editor:graph_editor_pip",
- "//tensorflow/contrib/keras:keras",
- "//tensorflow/contrib/labeled_tensor:labeled_tensor_pip",
"//tensorflow/contrib/lite/python:interpreter_test_data",
"//tensorflow/contrib/lite/python:tf_lite_py_pip",
"//tensorflow/contrib/lite/toco:toco",
"//tensorflow/contrib/lite/toco/python:toco_wrapper",
"//tensorflow/contrib/lite/toco/python:toco_from_protos",
- "//tensorflow/contrib/nn:nn_py",
- "//tensorflow/contrib/predictor:predictor_pip",
- "//tensorflow/contrib/autograph:autograph",
- "//tensorflow/contrib/autograph/converters:converters",
- "//tensorflow/contrib/autograph/converters:test_lib",
- "//tensorflow/contrib/autograph/impl:impl",
- "//tensorflow/contrib/autograph/pyct:pyct",
- "//tensorflow/contrib/autograph/pyct/static_analysis:static_analysis",
- "//tensorflow/contrib/receptive_field:receptive_field_pip",
- "//tensorflow/contrib/session_bundle:session_bundle_pip",
- "//tensorflow/contrib/signal:signal_py",
- "//tensorflow/contrib/signal:test_util",
- "//tensorflow/contrib/slim:slim",
- "//tensorflow/contrib/slim/python/slim/data:data_pip",
- "//tensorflow/contrib/slim/python/slim/nets:nets_pip",
- "//tensorflow/contrib/specs:specs",
- "//tensorflow/contrib/summary:summary_test_util",
- "//tensorflow/contrib/tensor_forest:init_py",
- "//tensorflow/contrib/tensor_forest/hybrid:hybrid_pip",
- "//tensorflow/contrib/timeseries:timeseries_pip",
- "//tensorflow/contrib/tpu",
- "//tensorflow/examples/tutorials/mnist:package",
- "//tensorflow/python:distributed_framework_test_lib",
- "//tensorflow/python:meta_graph_testdata",
- "//tensorflow/python:spectral_ops_test_util",
- "//tensorflow/python:util_example_parser_configuration",
- "//tensorflow/python/debug:debug_pip",
- "//tensorflow/python/eager:eager_pip",
- "//tensorflow/python/kernel_tests/testdata:self_adjoint_eig_op_test_files",
- "//tensorflow/python/saved_model:saved_model",
- "//tensorflow/python/tools:tools_pip",
- "//tensorflow/python:test_ops",
- "//tensorflow/tools/dist_test/server:grpc_tensorflow_server",
],
}) + if_mkl(["//third_party/mkl:intel_binary_blob"]) + if_tensorrt([
"//tensorflow/contrib/tensorrt:init_py",
diff --git a/tensorflow/tools/pip_package/build_pip_package.sh b/tensorflow/tools/pip_package/build_pip_package.sh
index feb3114bde..8f0cf8c3d1 100755
--- a/tensorflow/tools/pip_package/build_pip_package.sh
+++ b/tensorflow/tools/pip_package/build_pip_package.sh
@@ -162,7 +162,9 @@ function main() {
# Before we leave the top-level directory, make sure we know how to
# call python.
- source tools/python_bin_path.sh
+ if [[ -e tools/python_bin_path.sh ]]; then
+ source tools/python_bin_path.sh
+ fi
pushd ${TMPDIR}
rm -f MANIFEST
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index fe6b9407d6..ace0d411b9 100644
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -2,6 +2,7 @@
load("//third_party/gpus:cuda_configure.bzl", "cuda_configure")
load("//third_party/tensorrt:tensorrt_configure.bzl", "tensorrt_configure")
+load("//third_party:nccl/nccl_configure.bzl", "nccl_configure")
load("//third_party/mkl:build_defs.bzl", "mkl_repository")
load("//third_party/git:git_configure.bzl", "git_configure")
load("//third_party/py:python_configure.bzl", "python_configure")
@@ -13,6 +14,8 @@ load("//third_party:repo.bzl", "tf_http_archive")
load("//third_party/clang_toolchain:cc_configure_clang.bzl", "cc_download_clang_toolchain")
load("@io_bazel_rules_closure//closure/private:java_import_external.bzl", "java_import_external")
load("@io_bazel_rules_closure//closure:defs.bzl", "filegroup_external")
+load("//tensorflow/tools/def_file_filter:def_file_filter_configure.bzl",
+ "def_file_filter_configure")
# Sanitize a dependency so that it works correctly from code that includes
@@ -29,10 +32,15 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
cc_download_clang_toolchain(name="local_config_download_clang")
cuda_configure(name="local_config_cuda")
tensorrt_configure(name="local_config_tensorrt")
+ nccl_configure(name="local_config_nccl")
git_configure(name="local_config_git")
sycl_configure(name="local_config_sycl")
python_configure(name="local_config_python")
+ # For windows bazel build
+ # TODO: Remove def file filter when TensorFlow can export symbols properly on Windows.
+ def_file_filter_configure(name = "local_config_def_file_filter")
+
# Point //external/local_config_arm_compiler to //external/arm_compiler
arm_compiler_configure(
name="local_config_arm_compiler",
@@ -42,7 +50,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
mkl_repository(
name = "mkl_linux",
urls = [
- "https://mirror.bazel.build/intel/mkl-dnn/releases/download/v0.12/mklml_lnx_2018.0.1.20171227.tgz",
+ "https://mirror.bazel.build/github.com/intel/mkl-dnn/releases/download/v0.12/mklml_lnx_2018.0.1.20171227.tgz",
"https://github.com/intel/mkl-dnn/releases/download/v0.12/mklml_lnx_2018.0.1.20171227.tgz",
],
sha256 = "feacc3d82565c1231470359b42c696236fae873704e0b013436afba5fd4fd30f",
@@ -52,7 +60,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
mkl_repository(
name = "mkl_windows",
urls = [
- "https://mirror.bazel.build/intel/mkl-dnn/releases/download/v0.12/mklml_win_2018.0.1.20171227.zip",
+ "https://mirror.bazel.build/github.com/intel/mkl-dnn/releases/download/v0.12/mklml_win_2018.0.1.20171227.zip",
"https://github.com/intel/mkl-dnn/releases/download/v0.12/mklml_win_2018.0.1.20171227.zip"
],
sha256 = "24bae8d7b22b431a654acadea43f2243c46ae6b1e5a73a4a936825f31d284ee4",
@@ -62,7 +70,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
mkl_repository(
name = "mkl_darwin",
urls = [
- "https://mirror.bazel.build/intel/mkl-dnn/releases/download/v0.12/mklml_mac_2018.0.1.20171227.tgz",
+ "https://mirror.bazel.build/github.com/intel/mkl-dnn/releases/download/v0.12/mklml_mac_2018.0.1.20171227.tgz",
"https://github.com/intel/mkl-dnn/releases/download/v0.12/mklml_mac_2018.0.1.20171227.tgz"
],
sha256 = "0e954ec6fd3dc5e37f64c4043f6b5613dd687558da3df1028b3b7c29ff5cf77f",
@@ -454,11 +462,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
tf_http_archive(
name = "llvm",
urls = [
- "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/1c3cdea2f181d8e14ee184466c5fb237f1b4cda8.tar.gz",
- "https://github.com/llvm-mirror/llvm/archive/1c3cdea2f181d8e14ee184466c5fb237f1b4cda8.tar.gz",
+ "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/7e78daafdd22f3f17720a103d29d89590534004e.tar.gz",
+ "https://github.com/llvm-mirror/llvm/archive/7e78daafdd22f3f17720a103d29d89590534004e.tar.gz",
],
- sha256 = "1efbb9b05af88368be984d2f6526061d4a857181ef10f8841889a3a46869bb01",
- strip_prefix = "llvm-1c3cdea2f181d8e14ee184466c5fb237f1b4cda8",
+ sha256 = "a6d94bd9de23515a1e3792a830421e3885977ea43d03427cdbe68f98cb7e0045",
+ strip_prefix = "llvm-7e78daafdd22f3f17720a103d29d89590534004e",
build_file = clean_dep("//third_party/llvm:llvm.BUILD"),
)
@@ -497,11 +505,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
tf_http_archive(
name = "zlib_archive",
urls = [
- "https://mirror.bazel.build/zlib.net/zlib-1.2.8.tar.gz",
- "http://zlib.net/fossils/zlib-1.2.8.tar.gz",
+ "https://mirror.bazel.build/zlib.net/zlib-1.2.11.tar.gz",
+ "https://zlib.net/zlib-1.2.11.tar.gz",
],
- sha256 = "36658cb768a54c1d4dec43c3116c27ed893e88b02ecfcb44f2166f9c0b7f2a0d",
- strip_prefix = "zlib-1.2.8",
+ sha256 = "c3e5e9fdd5004dcb542feda5ee4f0ff0744628baf8ed2dd5d66f8ca1197cb1a1",
+ strip_prefix = "zlib-1.2.11",
build_file = clean_dep("//third_party:zlib.BUILD"),
)
@@ -518,11 +526,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
tf_http_archive(
name = "snappy",
urls = [
- "https://mirror.bazel.build/github.com/google/snappy/archive/1.1.4.tar.gz",
- "https://github.com/google/snappy/archive/1.1.4.tar.gz",
+ "https://mirror.bazel.build/github.com/google/snappy/archive/1.1.7.tar.gz",
+ "https://github.com/google/snappy/archive/1.1.7.tar.gz",
],
- sha256 = "2f7504c73d85bac842e893340333be8cb8561710642fc9562fccdd9d2c3fcc94",
- strip_prefix = "snappy-1.1.4",
+ sha256 = "3dfa02e873ff51a11ee02b9ca391807f0c8ea0529a4924afa645fbf97163f9d4",
+ strip_prefix = "snappy-1.1.7",
build_file = clean_dep("//third_party:snappy.BUILD"),
)
@@ -534,7 +542,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
],
sha256 = "2ca86fb6179ecbff789cc67c836139c1bbc0324ed8c04643405a30bf26325176",
strip_prefix = "nccl-03d856977ecbaac87e598c0c4bafca96761b9ac7",
- build_file = clean_dep("//third_party:nccl.BUILD"),
+ build_file = clean_dep("//third_party:nccl/nccl_archive.BUILD"),
)
tf_http_archive(
diff --git a/third_party/llvm/llvm.BUILD b/third_party/llvm/llvm.BUILD
index 28293a3659..075b46896e 100644
--- a/third_party/llvm/llvm.BUILD
+++ b/third_party/llvm/llvm.BUILD
@@ -163,13 +163,6 @@ all_cmake_vars = select({
# Performs CMake variable substitutions on configuration header files.
expand_cmake_vars(
- name = "datatypes_gen",
- src = "include/llvm/Support/DataTypes.h.cmake",
- cmake_vars = all_cmake_vars,
- dst = "include/llvm/Support/DataTypes.h",
-)
-
-expand_cmake_vars(
name = "config_gen",
src = "include/llvm/Config/config.h.cmake",
cmake_vars = all_cmake_vars,
@@ -305,9 +298,7 @@ cc_binary(
srcs = glob([
"utils/TableGen/*.cpp",
"utils/TableGen/*.h",
- ]) + [
- "lib/Target/X86/Disassembler/X86DisassemblerDecoderCommon.h",
- ],
+ ]),
linkopts = [
"-lm",
"-ldl",
@@ -2014,7 +2005,6 @@ cc_library(
"include/llvm/Support/WasmRelocs/*.def",
]) + [
"include/llvm/BinaryFormat/MachO.def",
- "include/llvm/Support/DataTypes.h",
"include/llvm/Support/VCSRevision.h",
"include/llvm/ExecutionEngine/ObjectMemoryBuffer.h",
],
diff --git a/third_party/nccl/LICENSE b/third_party/nccl/LICENSE
new file mode 100644
index 0000000000..146d9b765c
--- /dev/null
+++ b/third_party/nccl/LICENSE
@@ -0,0 +1,203 @@
+Copyright 2018 The TensorFlow Authors. All rights reserved.
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright 2018, The TensorFlow Authors.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/third_party/nccl.BUILD b/third_party/nccl/nccl_archive.BUILD
index b2b8e18824..a05899e38d 100644
--- a/third_party/nccl.BUILD
+++ b/third_party/nccl/nccl_archive.BUILD
@@ -43,6 +43,7 @@ cc_library(
"-Iexternal/nccl_archive/src",
"-O3",
] + cuda_default_copts(),
+ include_prefix = "third_party/nccl",
linkopts = select({
"@org_tensorflow//tensorflow:android": [
"-pie",
@@ -61,6 +62,7 @@ cc_library(
"-lrt",
],
}),
+ strip_include_prefix = "src",
visibility = ["//visibility:public"],
deps = ["@local_config_cuda//cuda:cuda_headers"],
)
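
The two attributes added above remap the archive's headers: `strip_include_prefix = "src"` drops the in-archive `src/` directory, and `include_prefix = "third_party/nccl"` re-exposes the result under that virtual path, so dependents include `third_party/nccl/nccl.h` regardless of whether NCCL comes from `@nccl_archive` or a local install. A minimal sketch of the mechanism (hypothetical target and header names):

```python
# Hypothetical cc_library: a header checked in at src/demo.h becomes
# includable by dependents as "third_party/nccl/demo.h".
cc_library(
    name = "demo",
    hdrs = ["src/demo.h"],
    strip_include_prefix = "src",         # remove the real path prefix
    include_prefix = "third_party/nccl",  # then prepend the virtual one
)
```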
diff --git a/third_party/nccl/nccl_configure.bzl b/third_party/nccl/nccl_configure.bzl
new file mode 100644
index 0000000000..9dfcb18369
--- /dev/null
+++ b/third_party/nccl/nccl_configure.bzl
@@ -0,0 +1,172 @@
+# -*- Python -*-
+"""Repository rule for NCCL configuration.
+
+`nccl_configure` depends on the following environment variables:
+
+ * `TF_NCCL_VERSION`: The NCCL version.
+ * `NCCL_INSTALL_PATH`: The installation path of the NCCL library.
+"""
+
+load(
+ "//third_party/gpus:cuda_configure.bzl",
+ "auto_configure_fail",
+ "find_cuda_define",
+ "matches_version",
+)
+
+_NCCL_INSTALL_PATH = "NCCL_INSTALL_PATH"
+_TF_NCCL_VERSION = "TF_NCCL_VERSION"
+
+_DEFINE_NCCL_MAJOR = "#define NCCL_MAJOR"
+_DEFINE_NCCL_MINOR = "#define NCCL_MINOR"
+_DEFINE_NCCL_PATCH = "#define NCCL_PATCH"
+
+_NCCL_DUMMY_BUILD_CONTENT = """
+filegroup(
+ name = "LICENSE",
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "nccl",
+ visibility = ["//visibility:public"],
+)
+"""
+
+_NCCL_ARCHIVE_BUILD_CONTENT = """
+filegroup(
+ name = "LICENSE",
+ data = ["@nccl_archive//:LICENSE.txt"],
+ visibility = ["//visibility:public"],
+)
+
+alias(
+ name = "nccl",
+ actual = "@nccl_archive//:nccl",
+ visibility = ["//visibility:public"],
+)
+"""
+
+_NCCL_LOCAL_BUILD_TEMPLATE = """
+filegroup(
+ name = "LICENSE",
+ data = ["nccl/NCCL-SLA.txt"],
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "nccl",
+ srcs = ["nccl/lib/libnccl.so.%s"],
+ hdrs = ["nccl/include/nccl.h"],
+ include_prefix = "third_party/nccl",
+ strip_include_prefix = "nccl/include",
+ deps = [
+ "@local_config_cuda//cuda:cuda_headers",
+ ],
+ visibility = ["//visibility:public"],
+)
+"""
+
+
+def _find_nccl_header(repository_ctx, nccl_install_path):
+ """Finds the NCCL header on the system.
+
+ Args:
+ repository_ctx: The repository context.
+ nccl_install_path: The NCCL library install directory.
+
+ Returns:
+ The path to the NCCL header.
+ """
+ header_path = repository_ctx.path("%s/include/nccl.h" % nccl_install_path)
+ if not header_path.exists:
+ auto_configure_fail("Cannot find %s" % str(header_path))
+ return header_path
+
+
+def _check_nccl_version(repository_ctx, nccl_install_path, nccl_version):
+ """Checks whether the header file matches the specified version of NCCL.
+
+ Args:
+ repository_ctx: The repository context.
+ nccl_install_path: The NCCL library install directory.
+ nccl_version: The expected NCCL version.
+
+ Returns:
+ A string containing the library version of NCCL.
+ """
+ header_path = _find_nccl_header(repository_ctx, nccl_install_path)
+ header_dir = str(header_path.realpath.dirname)
+ major_version = find_cuda_define(repository_ctx, header_dir, "nccl.h",
+ _DEFINE_NCCL_MAJOR)
+ minor_version = find_cuda_define(repository_ctx, header_dir, "nccl.h",
+ _DEFINE_NCCL_MINOR)
+ patch_version = find_cuda_define(repository_ctx, header_dir, "nccl.h",
+ _DEFINE_NCCL_PATCH)
+ header_version = "%s.%s.%s" % (major_version, minor_version, patch_version)
+ if not matches_version(nccl_version, header_version):
+ auto_configure_fail(
+ ("NCCL library version detected from %s/nccl.h (%s) does not match " +
+ "TF_NCCL_VERSION (%s). To fix this rerun configure again.") %
+ (header_dir, header_version, nccl_version))
+
+
+def _find_nccl_lib(repository_ctx, nccl_install_path, nccl_version):
+ """Finds the given NCCL library on the system.
+
+ Args:
+ repository_ctx: The repository context.
+ nccl_install_path: The NCCL library installation directory.
+    nccl_version: The expected NCCL version.
+
+ Returns:
+ The path to the NCCL library.
+ """
+ lib_path = repository_ctx.path("%s/lib/libnccl.so.%s" % (nccl_install_path,
+ nccl_version))
+ if not lib_path.exists:
+ auto_configure_fail("Cannot find NCCL library %s" % str(lib_path))
+ return lib_path
+
+
+def _nccl_configure_impl(repository_ctx):
+ """Implementation of the nccl_configure repository rule."""
+ if _TF_NCCL_VERSION not in repository_ctx.os.environ:
+ # Add a dummy build file to make bazel query happy.
+ repository_ctx.file("BUILD", _NCCL_DUMMY_BUILD_CONTENT)
+ return
+
+ nccl_version = repository_ctx.os.environ[_TF_NCCL_VERSION].strip()
+ if matches_version("1", nccl_version):
+ # Alias to GitHub target from @nccl_archive.
+ if not matches_version(nccl_version, "1.3"):
+ auto_configure_fail(
+ "NCCL from GitHub must use version 1.3 (got %s)" % nccl_version)
+ repository_ctx.file("BUILD", _NCCL_ARCHIVE_BUILD_CONTENT)
+ else:
+ # Create target for locally installed NCCL.
+ nccl_install_path = repository_ctx.os.environ[_NCCL_INSTALL_PATH].strip()
+ _check_nccl_version(repository_ctx, nccl_install_path, nccl_version)
+ repository_ctx.symlink(nccl_install_path, "nccl")
+ repository_ctx.file("BUILD", _NCCL_LOCAL_BUILD_TEMPLATE % nccl_version)
+
+
+nccl_configure = repository_rule(
+ implementation=_nccl_configure_impl,
+ environ=[
+ _NCCL_INSTALL_PATH,
+ _TF_NCCL_VERSION,
+ ],
+)
+"""Detects and configures the NCCL configuration.
+
+Add the following to your WORKSPACE file:
+
+```python
+nccl_configure(name = "local_config_nccl")
+```
+
+Args:
+ name: A unique name for this workspace rule.
+"""
diff --git a/third_party/snappy.BUILD b/third_party/snappy.BUILD
index fd48ed8941..cc11f52d0e 100644
--- a/third_party/snappy.BUILD
+++ b/third_party/snappy.BUILD
@@ -4,25 +4,12 @@ licenses(["notice"]) # BSD 3-Clause
exports_files(["COPYING"])
-config_setting(
- name = "windows",
- values = {"cpu": "x64_windows"},
- visibility = ["//visibility:public"],
-)
-
-config_setting(
- name = "windows_msvc",
- values = {"cpu": "x64_windows_msvc"},
- visibility = ["//visibility:public"],
-)
-
cc_library(
name = "snappy",
srcs = [
+ "config.h",
"snappy.cc",
"snappy.h",
- "snappy-c.cc",
- "snappy-c.h",
"snappy-internal.h",
"snappy-sinksource.cc",
"snappy-sinksource.h",
@@ -32,9 +19,18 @@ cc_library(
],
hdrs = ["snappy.h"],
copts = select({
- ":windows": [],
- ":windows_msvc": [],
+ "@org_tensorflow//tensorflow:windows": [
+ "/DHAVE_CONFIG_H",
+ "/EHsc",
+ ],
+ "@org_tensorflow//tensorflow:windows_msvc": [
+ "/DHAVE_CONFIG_H",
+ "/EHsc",
+ ],
"//conditions:default": [
+ "-DHAVE_CONFIG_H",
+ "-fno-exceptions",
+ "-Wno-sign-compare",
"-Wno-shift-negative-value",
"-Wno-implicit-function-declaration",
],
@@ -42,20 +38,66 @@ cc_library(
)
genrule(
+ name = "config_h",
+ outs = ["config.h"],
+ cmd = "\n".join([
+ "cat <<'EOF' >$@",
+ "#define HAVE_STDDEF_H 1",
+ "#define HAVE_STDINT_H 1",
+ "",
+ "#ifdef __has_builtin",
+ "# if !defined(HAVE_BUILTIN_EXPECT) && __has_builtin(__builtin_expect)",
+ "# define HAVE_BUILTIN_EXPECT 1",
+ "# endif",
+ "# if !defined(HAVE_BUILTIN_CTZ) && __has_builtin(__builtin_ctzll)",
+ "# define HAVE_BUILTIN_CTZ 1",
+ "# endif",
+ "#elif defined(__GNUC__) && (__GNUC__ > 3 || __GNUC__ == 3 && __GNUC_MINOR__ >= 4)",
+ "# ifndef HAVE_BUILTIN_EXPECT",
+ "# define HAVE_BUILTIN_EXPECT 1",
+ "# endif",
+ "# ifndef HAVE_BUILTIN_CTZ",
+ "# define HAVE_BUILTIN_CTZ 1",
+ "# endif",
+ "#endif",
+ "",
+ "#ifdef __has_include",
+ "# if !defined(HAVE_BYTESWAP_H) && __has_include(<byteswap.h>)",
+ "# define HAVE_BYTESWAP_H 1",
+ "# endif",
+ "# if !defined(HAVE_UNISTD_H) && __has_include(<unistd.h>)",
+ "# define HAVE_UNISTD_H 1",
+ "# endif",
+ "# if !defined(HAVE_SYS_ENDIAN_H) && __has_include(<sys/endian.h>)",
+ "# define HAVE_SYS_ENDIAN_H 1",
+ "# endif",
+ "# if !defined(HAVE_SYS_MMAN_H) && __has_include(<sys/mman.h>)",
+ "# define HAVE_SYS_MMAN_H 1",
+ "# endif",
+ "# if !defined(HAVE_SYS_UIO_H) && __has_include(<sys/uio.h>)",
+ "# define HAVE_SYS_UIO_H 1",
+ "# endif",
+ "#endif",
+ "",
+ "#ifndef SNAPPY_IS_BIG_ENDIAN",
+ "# ifdef __s390x__",
+ "# define SNAPPY_IS_BIG_ENDIAN 1",
+ "# elif defined(__BYTE_ORDER__) && defined(__ORDER_BIG_ENDIAN__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__",
+ "# define SNAPPY_IS_BIG_ENDIAN 1",
+ "# endif",
+ "#endif",
+ "EOF",
+ ]),
+)
+
+genrule(
name = "snappy_stubs_public_h",
srcs = ["snappy-stubs-public.h.in"],
outs = ["snappy-stubs-public.h"],
cmd = ("sed " +
- "-e 's/@ac_cv_have_stdint_h@/1/g' " +
- "-e 's/@ac_cv_have_stddef_h@/1/g' " +
- "-e 's/@ac_cv_have_stdint_h@/1/g' " +
- select({
- "@org_tensorflow//tensorflow:windows": "-e 's/@ac_cv_have_sys_uio_h@/0/g' ",
- "@org_tensorflow//tensorflow:windows_msvc": "-e 's/@ac_cv_have_sys_uio_h@/0/g' ",
- "//conditions:default": "-e 's/@ac_cv_have_sys_uio_h@/1/g' ",
- }) +
- "-e 's/@SNAPPY_MAJOR@/1/g' " +
- "-e 's/@SNAPPY_MINOR@/1/g' " +
- "-e 's/@SNAPPY_PATCHLEVEL@/4/g' " +
+ "-e 's/$${\\(.*\\)_01}/\\1/g' " +
+ "-e 's/$${SNAPPY_MAJOR}/1/g' " +
+ "-e 's/$${SNAPPY_MINOR}/1/g' " +
+ "-e 's/$${SNAPPY_PATCHLEVEL}/4/g' " +
"$< >$@"),
)
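
The sed expressions changed because snappy 1.1.7 generates `snappy-stubs-public.h` from CMake-style `${VAR}` placeholders rather than the autoconf `@VAR@` ones used by 1.1.4; the doubled `$$` is Bazel's escape for a literal `$` in genrule commands. A rough Python rendering of the substitution (sample input assumed, for illustration):

```python
# Mimic the sed pipeline: ${VAR_01} collapses to VAR, and the version
# placeholders are replaced with literal numbers.
import re

line = "#if ${HAVE_SYS_UIO_H_01}  /* snappy ${SNAPPY_MAJOR}.${SNAPPY_MINOR} */"
line = re.sub(r"\$\{(\w+)_01\}", r"\1", line)
line = line.replace("${SNAPPY_MAJOR}", "1").replace("${SNAPPY_MINOR}", "1")
print(line)  # -> "#if HAVE_SYS_UIO_H  /* snappy 1.1 */"
```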
diff --git a/third_party/zlib.BUILD b/third_party/zlib.BUILD
index d164ee719c..e8048dd98a 100644
--- a/third_party/zlib.BUILD
+++ b/third_party/zlib.BUILD
@@ -2,18 +2,6 @@ package(default_visibility = ["//visibility:public"])
licenses(["notice"]) # BSD/MIT-like license (for zlib)
-config_setting(
- name = "windows",
- values = {"cpu": "x64_windows"},
- visibility = ["//visibility:public"],
-)
-
-config_setting(
- name = "windows_msvc",
- values = {"cpu": "x64_windows_msvc"},
- visibility = ["//visibility:public"],
-)
-
cc_library(
name = "zlib",
srcs = [
@@ -45,8 +33,8 @@ cc_library(
],
hdrs = ["zlib.h"],
copts = select({
- ":windows": [],
- ":windows_msvc": [],
+ "@org_tensorflow//tensorflow:windows": [],
+ "@org_tensorflow//tensorflow:windows_msvc": [],
"//conditions:default": [
"-Wno-shift-negative-value",
"-DZ_HAVE_UNISTD_H",