-rw-r--r--  SECURITY.md | 2
-rw-r--r--  tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc | 246
-rw-r--r--  tensorflow/compiler/jit/encapsulate_subgraphs_pass.h | 4
-rw-r--r--  tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc | 6
-rw-r--r--  tensorflow/compiler/tests/BUILD | 9
-rw-r--r--  tensorflow/compiler/tests/unary_ops_test.py | 10
-rw-r--r--  tensorflow/compiler/tests/xla_test.py | 57
-rw-r--r--  tensorflow/compiler/tests/xla_test_test.py | 44
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/unary_ops.cc | 46
-rw-r--r--  tensorflow/compiler/tf2xla/lib/batch_dot.cc | 4
-rw-r--r--  tensorflow/compiler/xla/BUILD | 1
-rw-r--r--  tensorflow/compiler/xla/client/lib/arithmetic.cc | 84
-rw-r--r--  tensorflow/compiler/xla/client/lib/arithmetic.h | 14
-rw-r--r--  tensorflow/compiler/xla/client/xla_client/xla_builder.cc | 3
-rw-r--r--  tensorflow/compiler/xla/experimental/xla_sharding/BUILD | 18
-rw-r--r--  tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py | 204
-rw-r--r--  tensorflow/compiler/xla/layout_util.cc | 10
-rw-r--r--  tensorflow/compiler/xla/literal_comparison.cc | 10
-rw-r--r--  tensorflow/compiler/xla/literal_util.cc | 28
-rw-r--r--  tensorflow/compiler/xla/literal_util.h | 5
-rw-r--r--  tensorflow/compiler/xla/literal_util_test.cc | 16
-rw-r--r--  tensorflow/compiler/xla/primitive_util.cc | 5
-rw-r--r--  tensorflow/compiler/xla/primitive_util.h | 3
-rw-r--r--  tensorflow/compiler/xla/python_api/BUILD | 36
-rw-r--r--  tensorflow/compiler/xla/python_api/types.py | 124
-rw-r--r--  tensorflow/compiler/xla/python_api/xla_literal.py | 95
-rw-r--r--  tensorflow/compiler/xla/python_api/xla_shape.py | 155
-rw-r--r--  tensorflow/compiler/xla/service/BUILD | 4
-rw-r--r--  tensorflow/compiler/xla/service/algebraic_simplifier.cc | 24
-rw-r--r--  tensorflow/compiler/xla/service/bfloat16_propagation.cc | 2
-rw-r--r--  tensorflow/compiler/xla/service/buffer_assignment.cc | 5
-rw-r--r--  tensorflow/compiler/xla/service/copy_insertion.cc | 72
-rw-r--r--  tensorflow/compiler/xla/service/copy_insertion.h | 7
-rw-r--r--  tensorflow/compiler/xla/service/copy_insertion_test.cc | 39
-rw-r--r--  tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc | 4
-rw-r--r--  tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc | 4
-rw-r--r--  tensorflow/compiler/xla/service/cpu/ir_emitter.cc | 21
-rw-r--r--  tensorflow/compiler/xla/service/cpu/ir_emitter.h | 1
-rw-r--r--  tensorflow/compiler/xla/service/gather_expander.cc | 4
-rw-r--r--  tensorflow/compiler/xla/service/generic_transfer_manager.cc | 2
-rw-r--r--  tensorflow/compiler/xla/service/gpu/BUILD | 14
-rw-r--r--  tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc | 4
-rw-r--r--  tensorflow/compiler/xla/service/gpu/gpu_compiler.cc | 2
-rw-r--r--  tensorflow/compiler/xla/service/gpu/gpu_executable.cc | 73
-rw-r--r--  tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc | 82
-rw-r--r--  tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h | 68
-rw-r--r--  tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc | 2
-rw-r--r--  tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc | 4
-rw-r--r--  tensorflow/compiler/xla/service/gpu/ir_emitter.cc | 4
-rw-r--r--  tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc | 23
-rw-r--r--  tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h | 1
-rw-r--r--  tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc | 41
-rw-r--r--  tensorflow/compiler/xla/service/gpu/multi_output_fusion.h | 6
-rw-r--r--  tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc | 29
-rw-r--r--  tensorflow/compiler/xla/service/heap_simulator.cc | 52
-rw-r--r--  tensorflow/compiler/xla/service/heap_simulator.h | 58
-rw-r--r--  tensorflow/compiler/xla/service/heap_simulator_test.cc | 3
-rw-r--r--  tensorflow/compiler/xla/service/hlo_alias_analysis.cc | 11
-rw-r--r--  tensorflow/compiler/xla/service/hlo_computation.cc | 48
-rw-r--r--  tensorflow/compiler/xla/service/hlo_computation_test.cc | 32
-rw-r--r--  tensorflow/compiler/xla/service/hlo_evaluator.cc | 17
-rw-r--r--  tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h | 14
-rw-r--r--  tensorflow/compiler/xla/service/hlo_graph_dumper.cc | 4
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction.cc | 204
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction.h | 98
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instructions.cc | 155
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instructions.h | 120
-rw-r--r--  tensorflow/compiler/xla/service/hlo_matchers.h | 1
-rw-r--r--  tensorflow/compiler/xla/service/hlo_ordering.cc | 5
-rw-r--r--  tensorflow/compiler/xla/service/hlo_parser_test.cc | 4
-rw-r--r--  tensorflow/compiler/xla/service/hlo_rematerialization.cc | 18
-rw-r--r--  tensorflow/compiler/xla/service/hlo_rematerialization.h | 11
-rw-r--r--  tensorflow/compiler/xla/service/hlo_rematerialization_test.cc | 2
-rw-r--r--  tensorflow/compiler/xla/service/hlo_scheduling.cc | 37
-rw-r--r--  tensorflow/compiler/xla/service/hlo_scheduling_test.cc | 104
-rw-r--r--  tensorflow/compiler/xla/service/hlo_sharding_metadata.cc | 17
-rw-r--r--  tensorflow/compiler/xla/service/hlo_verifier.cc | 3
-rw-r--r--  tensorflow/compiler/xla/service/layout_assignment.cc | 60
-rw-r--r--  tensorflow/compiler/xla/service/layout_assignment.h | 50
-rw-r--r--  tensorflow/compiler/xla/service/layout_assignment_test.cc | 66
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc | 21
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h | 3
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc | 4
-rw-r--r--  tensorflow/compiler/xla/service/multi_output_fusion.cc | 16
-rw-r--r--  tensorflow/compiler/xla/service/multi_output_fusion.h | 4
-rw-r--r--  tensorflow/compiler/xla/service/service.cc | 27
-rw-r--r--  tensorflow/compiler/xla/service/shape_inference.cc | 141
-rw-r--r--  tensorflow/compiler/xla/service/shape_inference.h | 2
-rw-r--r--  tensorflow/compiler/xla/service/shape_inference_test.cc | 8
-rw-r--r--  tensorflow/compiler/xla/service/zero_sized_hlo_elimination.cc | 4
-rw-r--r--  tensorflow/compiler/xla/shape_util.cc | 31
-rw-r--r--  tensorflow/compiler/xla/shape_util.h | 15
-rw-r--r--  tensorflow/compiler/xla/shape_util_test.cc | 88
-rw-r--r--  tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc | 2
-rw-r--r--  tensorflow/compiler/xla/tests/concat_test.cc | 17
-rw-r--r--  tensorflow/compiler/xla/tests/token_hlo_test.cc | 37
-rw-r--r--  tensorflow/contrib/autograph/LIMITATIONS.md | 50
-rw-r--r--  tensorflow/contrib/autograph/README.md | 12
-rw-r--r--  tensorflow/contrib/autograph/STYLE_GUIDE.md | 16
-rw-r--r--  tensorflow/contrib/autograph/lang/BUILD | 40
-rw-r--r--  tensorflow/contrib/autograph/lang/directives.py | 68
-rw-r--r--  tensorflow/contrib/autograph/lang/special_functions.py | 59
-rw-r--r--  tensorflow/contrib/autograph/lang/special_functions_test.py | 54
-rw-r--r--  tensorflow/contrib/checkpoint/__init__.py | 4
-rw-r--r--  tensorflow/contrib/cloud/BUILD | 11
-rw-r--r--  tensorflow/contrib/cloud/__init__.py | 5
-rw-r--r--  tensorflow/contrib/cloud/kernels/BUILD | 1
-rw-r--r--  tensorflow/contrib/cloud/python/ops/gcs_config_ops_test.py | 34
-rw-r--r--  tensorflow/contrib/cmake/tf_tests.cmake | 2
-rw-r--r--  tensorflow/contrib/control_flow/python/cond_v2.py | 31
-rw-r--r--  tensorflow/contrib/control_flow/python/cond_v2_test.py | 254
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/BUILD | 1
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py | 41
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/bucketing_test.py | 4
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/resample_test.py | 2
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py | 18
-rw-r--r--  tensorflow/contrib/data/python/ops/grouping.py | 325
-rw-r--r--  tensorflow/contrib/data/python/ops/optimization.py | 1
-rw-r--r--  tensorflow/contrib/data/python/ops/scan_ops.py | 138
-rw-r--r--  tensorflow/contrib/distribute/python/values.py | 36
-rw-r--r--  tensorflow/contrib/distribute/python/values_test.py | 12
-rw-r--r--  tensorflow/contrib/eager/python/examples/BUILD | 2
-rw-r--r--  tensorflow/contrib/eager/python/examples/revnet/BUILD | 76
-rw-r--r--  tensorflow/contrib/eager/python/examples/revnet/blocks.py | 335
-rw-r--r--  tensorflow/contrib/eager/python/examples/revnet/blocks_test.py | 346
-rw-r--r--  tensorflow/contrib/eager/python/examples/revnet/config.py | 117
-rw-r--r--  tensorflow/contrib/eager/python/examples/revnet/ops.py | 70
-rw-r--r--  tensorflow/contrib/eager/python/examples/revnet/ops_test.py | 80
-rw-r--r--  tensorflow/contrib/eager/python/examples/revnet/revnet.py | 263
-rw-r--r--  tensorflow/contrib/eager/python/examples/revnet/revnet_test.py | 277
-rw-r--r--  tensorflow/contrib/estimator/BUILD | 2
-rw-r--r--  tensorflow/contrib/estimator/python/estimator/dnn.py | 18
-rw-r--r--  tensorflow/contrib/estimator/python/estimator/dnn_test.py | 17
-rw-r--r--  tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py | 279
-rw-r--r--  tensorflow/contrib/framework/python/ops/variables.py | 10
-rw-r--r--  tensorflow/contrib/framework/python/ops/variables_test.py | 120
-rw-r--r--  tensorflow/contrib/labeled_tensor/python/ops/ops.py | 3
-rw-r--r--  tensorflow/contrib/layers/python/layers/feature_column_ops.py | 11
-rw-r--r--  tensorflow/contrib/lite/Makefile | 45
-rw-r--r--  tensorflow/contrib/lite/arena_planner.cc | 58
-rw-r--r--  tensorflow/contrib/lite/arena_planner_test.cc | 13
-rw-r--r--  tensorflow/contrib/lite/build_def.bzl | 3
-rw-r--r--  tensorflow/contrib/lite/context.c | 3
-rw-r--r--  tensorflow/contrib/lite/context.h | 10
-rw-r--r--  tensorflow/contrib/lite/graph_info.h | 3
-rw-r--r--  tensorflow/contrib/lite/graph_info_test.cc | 2
-rw-r--r--  tensorflow/contrib/lite/interpreter.cc | 64
-rw-r--r--  tensorflow/contrib/lite/interpreter.h | 23
-rw-r--r--  tensorflow/contrib/lite/interpreter_test.cc | 15
-rw-r--r--  tensorflow/contrib/lite/java/demo/README.md | 9
-rw-r--r--  tensorflow/contrib/lite/java/demo/app/build.gradle | 2
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/BUILD | 65
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h | 50
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h | 182
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h | 50
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h | 35
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/tensor.h | 23
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/types.h | 97
-rw-r--r--  tensorflow/contrib/lite/kernels/l2norm.cc | 12
-rw-r--r--  tensorflow/contrib/lite/kernels/register.h | 2
-rw-r--r--  tensorflow/contrib/lite/kernels/transpose_conv.cc | 8
-rw-r--r--  tensorflow/contrib/lite/model.cc | 26
-rw-r--r--  tensorflow/contrib/lite/optional_debug_tools.cc | 2
-rw-r--r--  tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc | 4
-rw-r--r--  tensorflow/contrib/lite/python/lite.py | 1
-rw-r--r--  tensorflow/contrib/lite/schema/schema.fbs | 13
-rwxr-xr-x  tensorflow/contrib/lite/schema/schema_generated.h | 65
-rw-r--r--  tensorflow/contrib/lite/string_util.cc | 2
-rw-r--r--  tensorflow/contrib/lite/testing/tflite_driver.cc | 11
-rw-r--r--  tensorflow/contrib/lite/toco/BUILD | 3
-rw-r--r--  tensorflow/contrib/lite/toco/export_tensorflow.cc | 20
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_tile_to_concat.cc | 94
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc | 59
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/fuse_broadcast_into_following_binary.cc | 102
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h | 3
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc | 4
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc | 69
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_tile.cc | 97
-rw-r--r--  tensorflow/contrib/lite/toco/model.h | 6
-rw-r--r--  tensorflow/contrib/lite/toco/tflite/export.cc | 56
-rw-r--r--  tensorflow/contrib/lite/toco/tflite/operator.cc | 18
-rw-r--r--  tensorflow/contrib/lite/toco/tflite/operator.h | 11
-rw-r--r--  tensorflow/contrib/lite/toco/tflite/types.cc | 8
-rw-r--r--  tensorflow/contrib/lite/toco/tflite/types_test.cc | 6
-rw-r--r--  tensorflow/contrib/lite/toco/toco_tooling.cc | 3
-rw-r--r--  tensorflow/contrib/metrics/BUILD | 26
-rw-r--r--  tensorflow/contrib/metrics/python/ops/metric_ops_large_test.py | 66
-rw-r--r--  tensorflow/contrib/metrics/python/ops/metric_ops_test.py | 395
-rw-r--r--  tensorflow/contrib/tpu/ops/cross_replica_ops.cc | 12
-rw-r--r--  tensorflow/contrib/tpu/ops/replication_ops.cc | 21
-rw-r--r--  tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py | 20
-rw-r--r--  tensorflow/contrib/tpu/profiler/pip_package/setup.py | 4
-rw-r--r--  tensorflow/contrib/tpu/profiler/version.h | 2
-rw-r--r--  tensorflow/contrib/tpu/python/ops/tpu_ops.py | 3
-rw-r--r--  tensorflow/contrib/tpu/python/tpu/tpu.py | 20
-rw-r--r--  tensorflow/contrib/tpu/python/tpu/tpu_optimizer.py | 52
-rw-r--r--  tensorflow/core/api_def/excluded_ops.cc | 3
-rw-r--r--  tensorflow/core/common_runtime/direct_session.cc | 2
-rw-r--r--  tensorflow/core/common_runtime/lower_if_op.cc | 3
-rw-r--r--  tensorflow/core/framework/device_base.h | 4
-rw-r--r--  tensorflow/core/framework/op_kernel.cc | 16
-rw-r--r--  tensorflow/core/framework/op_kernel.h | 2
-rw-r--r--  tensorflow/core/graph/graph.cc | 23
-rw-r--r--  tensorflow/core/graph/graph.h | 20
-rw-r--r--  tensorflow/core/grappler/op_types.cc | 2
-rw-r--r--  tensorflow/core/grappler/op_types.h | 1
-rw-r--r--  tensorflow/core/grappler/optimizers/BUILD | 4
-rw-r--r--  tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc | 125
-rw-r--r--  tensorflow/core/grappler/optimizers/arithmetic_optimizer.h | 1
-rw-r--r--  tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc | 44
-rw-r--r--  tensorflow/core/grappler/optimizers/dependency_optimizer.cc | 110
-rw-r--r--  tensorflow/core/grappler/optimizers/dependency_optimizer.h | 10
-rw-r--r--  tensorflow/core/grappler/optimizers/graph_optimizer_stage.h | 8
-rw-r--r--  tensorflow/core/grappler/optimizers/graph_optimizer_stage_test.cc | 10
-rw-r--r--  tensorflow/core/kernels/control_flow_ops.cc | 1
-rw-r--r--  tensorflow/core/kernels/conv_grad_filter_ops.cc | 3
-rw-r--r--  tensorflow/core/kernels/conv_grad_input_ops.cc | 5
-rw-r--r--  tensorflow/core/kernels/data/map_and_batch_dataset_op.cc | 149
-rw-r--r--  tensorflow/core/kernels/data/prefetch_dataset_op.cc | 65
-rw-r--r--  tensorflow/core/kernels/deep_conv2d.cc | 10
-rw-r--r--  tensorflow/core/ops/control_flow_ops.cc | 13
-rw-r--r--  tensorflow/core/platform/cloud/gcs_file_system.cc | 4
-rw-r--r--  tensorflow/core/platform/default/build_config.bzl | 2
-rw-r--r--  tensorflow/core/util/tensor_format.cc | 12
-rw-r--r--  tensorflow/core/util/tensor_format.h | 47
-rw-r--r--  tensorflow/core/util/tensor_format_test.cc | 25
-rw-r--r--  tensorflow/docs_src/mobile/tflite/demo_android.md | 3
-rw-r--r--  tensorflow/examples/tutorials/mnist/BUILD | 2
-rw-r--r--  tensorflow/python/data/kernel_tests/BUILD | 1
-rw-r--r--  tensorflow/python/data/kernel_tests/prefetch_dataset_op_test.py | 26
-rw-r--r--  tensorflow/python/data/ops/dataset_ops.py | 528
-rw-r--r--  tensorflow/python/eager/BUILD | 17
-rw-r--r--  tensorflow/python/eager/function.py | 76
-rw-r--r--  tensorflow/python/eager/function_test.py | 17
-rw-r--r--  tensorflow/python/eager/memory_test.py | 108
-rw-r--r--  tensorflow/python/estimator/canned/baseline.py | 14
-rw-r--r--  tensorflow/python/estimator/canned/boosted_trees.py | 20
-rw-r--r--  tensorflow/python/estimator/canned/dnn.py | 10
-rw-r--r--  tensorflow/python/estimator/canned/dnn_linear_combined.py | 10
-rw-r--r--  tensorflow/python/estimator/canned/linear.py | 10
-rw-r--r--  tensorflow/python/estimator/estimator.py | 16
-rw-r--r--  tensorflow/python/feature_column/feature_column.py | 208
-rw-r--r--  tensorflow/python/framework/function.py | 54
-rw-r--r--  tensorflow/python/keras/engine/input_layer.py | 2
-rw-r--r--  tensorflow/python/keras/layers/normalization.py | 11
-rw-r--r--  tensorflow/python/kernel_tests/conv_ops_test.py | 2
-rw-r--r--  tensorflow/python/kernel_tests/depthwise_conv_op_test.py | 2
-rw-r--r--  tensorflow/python/kernel_tests/distributions/BUILD | 1
-rw-r--r--  tensorflow/python/kernel_tests/distributions/categorical_test.py | 20
-rw-r--r--  tensorflow/python/layers/convolutional.py | 5
-rw-r--r--  tensorflow/python/layers/core.py | 1
-rw-r--r--  tensorflow/python/layers/normalization.py | 1
-rw-r--r--  tensorflow/python/ops/distributions/categorical.py | 23
-rw-r--r--  tensorflow/python/ops/embedding_ops.py | 12
-rw-r--r--  tensorflow/python/ops/template.py | 67
-rw-r--r--  tensorflow/python/ops/variable_scope.py | 3
-rw-r--r--  tensorflow/python/training/checkpoint_utils.py | 8
-rw-r--r--  tensorflow/python/training/checkpointable/util.py | 88
-rw-r--r--  tensorflow/python/training/checkpointable/util_test.py | 31
-rw-r--r--  tensorflow/python/training/monitored_session.py | 14
-rw-r--r--  tensorflow/python/training/saver.py | 10
-rw-r--r--  tensorflow/stream_executor/cuda/cuda_blas.cc | 2
-rw-r--r--  tensorflow/stream_executor/stream.cc | 1
-rw-r--r--  tensorflow/tools/api/generator/BUILD | 21
-rw-r--r--  tensorflow/tools/api/generator/create_python_api.py | 11
-rw-r--r--  tensorflow/tools/api/generator/doc_srcs.py | 29
-rw-r--r--  tensorflow/tools/api/generator/doc_srcs_test.py | 11
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.train.pbtxt | 2
-rw-r--r--  tensorflow/tools/ci_build/Dockerfile.cmake | 2
-rwxr-xr-x  tensorflow/tools/ci_build/install/install_pip_packages.sh | 6
-rwxr-xr-x  tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh | 4
-rwxr-xr-x  tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh | 4
-rw-r--r--  tensorflow/tools/pip_package/BUILD | 1
-rw-r--r--  tensorflow/workspace.bzl | 8
-rw-r--r--  third_party/gpus/crosstool/CROSSTOOL.tpl | 242
275 files changed, 8726 insertions(+), 2481 deletions(-)
diff --git a/SECURITY.md b/SECURITY.md
index e2f6ff353a..0b52fdc7ab 100644
--- a/SECURITY.md
+++ b/SECURITY.md
@@ -245,4 +245,4 @@ v//Fw6ZeY+HmRDFdirjD7wXtIuER4vqCryIqR6Xe9X8oJXz9L/Jhslc=
### Known Vulnerabilities
For a list of known vulnerabilities and security advisories for TensorFlow,
-(https://github.com/tensorflow/tensorflow/blob/master/tensorflow/security/index.md)[click here].
+[click here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/security/index.md).
diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
index ea90d714c8..9448b8ebde 100644
--- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
+++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
@@ -106,41 +106,11 @@ void MarkGuaranteedConstants(
}
}
-// A node/slot pair.
-// TODO(phawkins): is there a common definition of this?
-struct NodeSlot {
- NodeSlot() : node(nullptr), slot(-1), dtype(DT_INVALID) {}
- NodeSlot(const Node* node, int slot)
- : node(node), slot(slot), dtype(DT_INVALID) {}
- NodeSlot(const Node* node, int slot, DataType dtype)
- : node(node), slot(slot), dtype(dtype) {}
-
- const Node* node;
- int slot;
-
- // Optional: used to record the destination type of a source NodeSlot in case
- // the source output is a Ref type that is cast to a Tensor at the
- // destination.
- DataType dtype;
-
- bool operator==(const NodeSlot& other) const {
- return node == other.node && slot == other.slot && dtype == other.dtype;
- }
-
- // Leave dtype out of the hash since there are never two NodeSlots with the
- // same node and slot and different dtypes.
- struct Hasher {
- uint64 operator()(NodeSlot const& s) const {
- return Hash64Combine(std::hash<const Node*>()(s.node),
- std::hash<int>()(s.slot));
- }
- };
-
- struct PairHasher {
- uint64 operator()(std::pair<NodeSlot, NodeSlot> const& s) const {
- return Hash64Combine(Hasher()(s.first), Hasher()(s.second));
- }
- };
+struct OutputInputTensorPairHasher {
+ uint64 operator()(std::pair<OutputTensor, InputTensor> const& s) const {
+ return Hash64Combine(OutputTensor::Hash()(s.first),
+ InputTensor::Hash()(s.second));
+ }
};
// TODO(phawkins) add a canonical copy of these operator names and refactor
@@ -376,7 +346,7 @@ class Encapsulator {
// Map from source (producer node/slot) tensors in the original graph to
// input index (slot number in the HostCompute/RecvAtHost nodes that will
// be created) for the outside_compilation subgraph.
- std::unordered_map<NodeSlot, int, NodeSlot::Hasher> inputs;
+ std::unordered_map<OutputTensor, int, OutputTensor::Hash> inputs;
// Set of nodes in the original graph that are the source of control edges
// that cross from the containing compiled subgraph into the
@@ -392,8 +362,15 @@ class Encapsulator {
// node/slot) tensors in the original graph to output index (slot number
// in the SendFromHost/HostCompute nodes that will be created) for the
// outside_compilation subgraph.
- std::unordered_map<NodeSlot, int, NodeSlot::Hasher> outputs_by_src;
- std::unordered_map<NodeSlot, int, NodeSlot::Hasher> outputs_by_dst;
+ struct ArgNumAndType {
+ int index;
+ DataType dtype;
+
+ ArgNumAndType(int i, DataType t) : index(i), dtype(t) {}
+ };
+ std::unordered_map<OutputTensor, ArgNumAndType, OutputTensor::Hash>
+ outputs_by_src;
+ std::unordered_map<InputTensor, int, InputTensor::Hash> outputs_by_dst;
// Set of nodes in the original graph that are the destination of control
// edges that cross from the outside_compilation subgraph into the
@@ -479,14 +456,14 @@ class Encapsulator {
// (consumer node/slot) tensors in the input graph to _Arg numbers in
// the subgraph. The source map is one-to-one, whereas the dest map may be
// many-to-one.
- std::unordered_map<NodeSlot, int, NodeSlot::Hasher> args_by_src_;
- std::unordered_map<NodeSlot, int, NodeSlot::Hasher> args_by_dst_;
+ std::unordered_map<OutputTensor, int, OutputTensor::Hash> args_by_src_;
+ std::unordered_map<InputTensor, int, InputTensor::Hash> args_by_dst_;
- // The _Arg nodes in the subgraph, in order by argument number.
+ // The arguments to the subgraph, in order.
std::vector<Node*> args_;
// Map from source tensor in the input graph to result #.
- std::unordered_map<NodeSlot, int, NodeSlot::Hasher> results_;
+ std::unordered_map<OutputTensor, int, OutputTensor::Hash> results_;
// The outside_compilation clusters in this subgraph.
std::unordered_map<string, OutsideCompilationSubgraph>
@@ -583,8 +560,8 @@ class Encapsulator {
const string& dst_outside_compilation_id,
const std::unordered_map<const Node*, Node*>& node_images,
Graph* graph_out,
- std::unordered_set<std::pair<NodeSlot, NodeSlot>, NodeSlot::PairHasher>*
- edges_added);
+ std::unordered_set<std::pair<OutputTensor, InputTensor>,
+ OutputInputTensorPairHasher>* edges_added);
// Adds control dependencies between subgraph call nodes that have
// dependencies via outside_compilation edges.
@@ -716,11 +693,11 @@ void TopologicalClusterSort(
Node* Encapsulator::Subgraph::GetCallNode() const { return call_node_; }
int Encapsulator::Subgraph::GetArgIndexForEdge(const Edge* edge) const {
- return args_by_dst_.at(NodeSlot(edge->dst(), edge->dst_input()));
+ return args_by_dst_.at(InputTensor(edge->dst(), edge->dst_input()));
}
int Encapsulator::Subgraph::GetResultIndexForEdge(const Edge* edge) const {
- return results_.at(NodeSlot(edge->src(), edge->src_output()));
+ return results_.at(OutputTensor(edge->src(), edge->src_output()));
}
Node* Encapsulator::Subgraph::GetRecvAtHostNode(
@@ -732,7 +709,7 @@ Node* Encapsulator::Subgraph::GetRecvAtHostNode(
int Encapsulator::Subgraph::GetRecvAtHostSlot(
const string& outside_compilation_subgraph_name, const Edge* edge) const {
return outside_compilation_subgraphs_.at(outside_compilation_subgraph_name)
- .inputs.at(NodeSlot(edge->src(), edge->src_output()));
+ .inputs.at(OutputTensor(edge->src(), edge->src_output()));
}
Node* Encapsulator::Subgraph::GetSendFromHostNode(
@@ -744,7 +721,7 @@ Node* Encapsulator::Subgraph::GetSendFromHostNode(
int Encapsulator::Subgraph::GetSendFromHostSlot(
const string& outside_compilation_subgraph_name, const Edge* edge) const {
return outside_compilation_subgraphs_.at(outside_compilation_subgraph_name)
- .outputs_by_dst.at(NodeSlot(edge->dst(), edge->dst_input()));
+ .outputs_by_dst.at(InputTensor(edge->dst(), edge->dst_input()));
}
Node* Encapsulator::Subgraph::MakeNodeImage(const Graph* graph_in, Node* node) {
@@ -769,10 +746,10 @@ Status Encapsulator::Subgraph::RecordArg(
std::vector<std::pair<const Node*, Node*>>* src_arg_pairs) {
Node* src_node = edge->src();
int src_slot = edge->src_output();
- std::unordered_map<NodeSlot, int, NodeSlot::Hasher>::iterator iter;
+ std::unordered_map<OutputTensor, int, OutputTensor::Hash>::iterator iter;
bool inserted;
- std::tie(iter, inserted) =
- args_by_src_.emplace(NodeSlot(src_node, src_slot), args_by_src_.size());
+ std::tie(iter, inserted) = args_by_src_.emplace(
+ OutputTensor(src_node, src_slot), args_by_src_.size());
int arg_index = iter->second;
if (inserted) {
NodeDef arg_def;
@@ -793,7 +770,7 @@ Status Encapsulator::Subgraph::RecordArg(
Node* dst_node = edge->dst();
Node* dst_image = node_images.at(dst_node);
int dst_slot = edge->dst_input();
- args_by_dst_[NodeSlot(dst_node, dst_slot)] = arg_index;
+ args_by_dst_[InputTensor(dst_node, dst_slot)] = arg_index;
graph_->AddEdge(args_[arg_index], 0, dst_image, dst_slot);
return Status::OK();
}
@@ -804,10 +781,10 @@ Status Encapsulator::Subgraph::RecordResult(
Node* src_node = edge->src();
Node* src_image = node_images.at(src_node);
int src_slot = edge->src_output();
- std::unordered_map<NodeSlot, int, NodeSlot::Hasher>::iterator iter;
+ std::unordered_map<OutputTensor, int, OutputTensor::Hash>::iterator iter;
bool inserted;
std::tie(iter, inserted) =
- results_.emplace(NodeSlot(src_node, src_slot), results_.size());
+ results_.emplace(OutputTensor(src_node, src_slot), results_.size());
int ret_index = iter->second;
if (inserted) {
NodeDef ret_def;
@@ -845,8 +822,8 @@ void Encapsulator::Subgraph::RecordOutsideCompilationInputOrControl(
outside_subgraph->control_inputs.insert(edge->src());
} else {
int input_index = outside_subgraph->inputs.size();
- outside_subgraph->inputs.emplace(NodeSlot(edge->src(), edge->src_output()),
- input_index);
+ outside_subgraph->inputs.emplace(
+ OutputTensor(edge->src(), edge->src_output()), input_index);
}
}
@@ -860,11 +837,13 @@ void Encapsulator::Subgraph::RecordOutsideCompilationOutputOrControl(
DataType dtype = edge->dst()->input_type(edge->dst_input());
auto output_iter =
outside_subgraph->outputs_by_src
- .emplace(NodeSlot(edge->src(), edge->src_output(), dtype),
- outside_subgraph->outputs_by_src.size())
+ .emplace(OutputTensor(edge->src(), edge->src_output()),
+ OutsideCompilationSubgraph::ArgNumAndType(
+ outside_subgraph->outputs_by_src.size(), dtype))
.first;
- int output_index = output_iter->second;
- outside_subgraph->outputs_by_dst[NodeSlot(edge->dst(), edge->dst_input())] =
+ const int output_index = output_iter->second.index;
+ outside_subgraph
+ ->outputs_by_dst[InputTensor(edge->dst(), edge->dst_input())] =
output_index;
}
}
@@ -946,7 +925,7 @@ Status Encapsulator::Subgraph::AddHostComputes(
for (const auto& input_src : oc_subgraph.inputs) {
const Node* src_node = input_src.first.node;
Node* src_image = node_images.at(src_node);
- int src_slot = input_src.first.slot;
+ int src_slot = input_src.first.index;
int input_index = input_src.second;
DataType dtype = src_node->output_type(src_slot);
@@ -954,8 +933,8 @@ Status Encapsulator::Subgraph::AddHostComputes(
input_dtypes[input_index] = dtype;
}
for (const auto& output : oc_subgraph.outputs_by_src) {
- DataType dtype = output.first.dtype;
- int output_index = output.second;
+ DataType dtype = output.second.dtype;
+ int output_index = output.second.index;
output_dtypes[output_index] = dtype;
}
@@ -993,7 +972,7 @@ Status Encapsulator::Subgraph::AddHostComputes(
for (auto& input_src : oc_subgraph.inputs) {
const Node* src_node = input_src.first.node;
Node* src_image = node_images.at(src_node);
- int src_slot = input_src.first.slot;
+ int src_slot = input_src.first.index;
int input_index = input_src.second;
graph_->AddEdge(src_image, src_slot, host_compute, input_index);
}
@@ -1015,7 +994,7 @@ Status Encapsulator::Subgraph::AddHostComputes(
for (const auto& output : oc_subgraph.outputs_by_dst) {
const Node* dst_node = output.first.node;
Node* dst_image = node_images.at(dst_node);
- int dst_slot = output.first.slot;
+ int dst_slot = output.first.index;
int output_index = output.second;
graph_->AddEdge(host_compute, output_index, dst_image, dst_slot);
@@ -1068,14 +1047,19 @@ Status Encapsulator::Subgraph::BuildFunctionDef(
call_node_def_.set_device(device_);
if (rewrite_subgraph_fn) {
+ std::vector<OutputTensor> arg_source_tensors(args_by_src_.size());
+ for (const auto& arg : args_by_src_) {
+ arg_source_tensors.at(arg.second) = arg.first;
+ }
// Initialize the input and output permutations to the identity.
std::vector<int> input_permutation(args_by_src_.size());
std::iota(input_permutation.begin(), input_permutation.end(), 0);
std::vector<int> output_permutation(results_.size());
std::iota(output_permutation.begin(), output_permutation.end(), 0);
- TF_RETURN_IF_ERROR(rewrite_subgraph_fn(
- &graph_, &input_permutation, &output_permutation, &call_node_def_));
+ TF_RETURN_IF_ERROR(
+ rewrite_subgraph_fn(arg_source_tensors, &graph_, &input_permutation,
+ &output_permutation, &call_node_def_));
// Apply the input/output permutations to the 'args_by_...' and 'results_'
// mappings, so when we build edges in BuildOutputGraph() we
@@ -1226,7 +1210,7 @@ Status Encapsulator::Subgraph::AddRecvAtHostNode(
for (const auto& input : oc_subgraph->inputs) {
const Node* src_node = input.first.node;
- int src_slot = input.first.slot;
+ int src_slot = input.first.index;
int input_index = input.second;
DataType dtype = src_node->output_type(src_slot);
@@ -1280,8 +1264,8 @@ Status Encapsulator::Subgraph::AddSendFromHostNode(
for (const auto& output : oc_subgraph->outputs_by_src) {
const Node* src_node = output.first.node;
Node* src_image = node_images.at(src_node);
- int src_slot = output.first.slot;
- int output_index = output.second;
+ int src_slot = output.first.index;
+ int output_index = output.second.index;
DataType dtype = src_node->output_type(src_slot);
dtypes[output_index] = dtype;
@@ -1680,8 +1664,8 @@ Status Encapsulator::CopyEdgeToOutputGraph(
const string& src_outside_compilation_id, const string& dst_func_id,
const string& dst_outside_compilation_id,
const std::unordered_map<const Node*, Node*>& node_images, Graph* graph_out,
- std::unordered_set<std::pair<NodeSlot, NodeSlot>, NodeSlot::PairHasher>*
- edges_added) {
+ std::unordered_set<std::pair<OutputTensor, InputTensor>,
+ OutputInputTensorPairHasher>* edges_added) {
Node* src_image;
TF_RETURN_IF_ERROR(FindOutputImageOfEdgeSrc(
src_func_id, src_outside_compilation_id, dst_func_id,
@@ -1696,7 +1680,8 @@ Status Encapsulator::CopyEdgeToOutputGraph(
if (edge->IsControlEdge()) {
// Add the control edge, if we have not already added it, using the images
// determined above (potentially call operators or RecvAtHost/SendFromHost).
- if (edges_added->emplace(NodeSlot(src_image, -1), NodeSlot(dst_image, -1))
+ if (edges_added
+ ->emplace(OutputTensor(src_image, -1), InputTensor(dst_image, -1))
.second) {
graph_out->AddControlEdge(src_image, dst_image);
}
@@ -1714,8 +1699,8 @@ Status Encapsulator::CopyEdgeToOutputGraph(
// Add the edge, if we have not already added it.
if (edges_added
- ->emplace(NodeSlot(src_image, src_output),
- NodeSlot(dst_image, dst_input))
+ ->emplace(OutputTensor(src_image, src_output),
+ InputTensor(dst_image, dst_input))
.second) {
graph_out->AddEdge(src_image, src_output, dst_image, dst_input);
}
@@ -1739,7 +1724,8 @@ Status Encapsulator::AddEdgesToOutputGraph(
// Set of edges already added to the output graph, represented as (src, dst)
// pairs. We use the set to deduplicate edges; multiple edges in the input
// graph may map to one edge in the output graph.
- std::unordered_set<std::pair<NodeSlot, NodeSlot>, NodeSlot::PairHasher>
+ std::unordered_set<std::pair<OutputTensor, InputTensor>,
+ OutputInputTensorPairHasher>
edges_added;
for (const Edge* edge : graph_in_->edges()) {
@@ -2472,64 +2458,66 @@ Status EncapsulateSubgraphsPass::Run(
FunctionLibraryRuntime* flr =
pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
- auto rewrite_subgraph = [flr](std::unique_ptr<Graph>* subgraph,
- std::vector<int>* input_permutation,
- std::vector<int>* output_permutation,
- NodeDef* node) {
- // Optimize the subgraph.
- OptimizeGraph(flr, subgraph);
-
- const int num_args = input_permutation->size();
- std::vector<bool> const_args(num_args);
- TF_RETURN_IF_ERROR(BackwardsConstAnalysis(**subgraph, &const_args));
-
- DataTypeVector arg_types(num_args);
- TF_RETURN_IF_ERROR(GetArgTypes(**subgraph, &arg_types));
-
- // Compute a permutation of the arguments such that the constant arguments
- // are first.
- const int num_consts =
- std::count(const_args.begin(), const_args.end(), true);
-
- const int num_resources =
- std::count(arg_types.begin(), arg_types.end(), DT_RESOURCE);
- const int num_nonconsts = num_args - num_resources - num_consts;
- if (num_nonconsts < 0) {
- return errors::Internal("num_nonconsts should be >= 0, was ",
- num_nonconsts);
- }
+ auto rewrite_subgraph =
+ [flr](const std::vector<OutputTensor>& arg_source_tensors,
+ std::unique_ptr<Graph>* subgraph,
+ std::vector<int>* input_permutation,
+ std::vector<int>* output_permutation, NodeDef* node) {
+ // Optimize the subgraph.
+ OptimizeGraph(flr, subgraph);
+
+ const int num_args = input_permutation->size();
+ std::vector<bool> const_args(num_args);
+ TF_RETURN_IF_ERROR(BackwardsConstAnalysis(**subgraph, &const_args));
+
+ DataTypeVector arg_types(num_args);
+ TF_RETURN_IF_ERROR(GetArgTypes(**subgraph, &arg_types));
+
+ // Compute a permutation of the arguments such that the constant
+ // arguments are first.
+ const int num_consts =
+ std::count(const_args.begin(), const_args.end(), true);
+
+ const int num_resources =
+ std::count(arg_types.begin(), arg_types.end(), DT_RESOURCE);
+ const int num_nonconsts = num_args - num_resources - num_consts;
+ if (num_nonconsts < 0) {
+ return errors::Internal("num_nonconsts should be >= 0, was ",
+ num_nonconsts);
+ }
- int const_pos = 0;
- int arg_pos = num_consts;
- int resource_pos = num_consts + num_nonconsts;
- for (int i = 0; i < num_args; ++i) {
- if (const_args[i]) {
- if (arg_types[i] == DT_RESOURCE) {
- return errors::Internal(
- "Resource arguments cannot be constant (argument ", i, ")");
+ int const_pos = 0;
+ int arg_pos = num_consts;
+ int resource_pos = num_consts + num_nonconsts;
+ for (int i = 0; i < num_args; ++i) {
+ if (const_args[i]) {
+ if (arg_types[i] == DT_RESOURCE) {
+ return errors::Internal(
+ "Resource arguments cannot be constant (argument ", i, ")");
+ }
+ (*input_permutation)[i] = const_pos;
+ ++const_pos;
+ } else if (arg_types[i] == DT_RESOURCE) {
+ (*input_permutation)[i] = resource_pos;
+ ++resource_pos;
+ } else {
+ (*input_permutation)[i] = arg_pos;
+ ++arg_pos;
+ }
}
- (*input_permutation)[i] = const_pos;
- ++const_pos;
- } else if (arg_types[i] == DT_RESOURCE) {
- (*input_permutation)[i] = resource_pos;
- ++resource_pos;
- } else {
- (*input_permutation)[i] = arg_pos;
- ++arg_pos;
- }
- }
- // Renumber argument nodes in the graph.
- TF_RETURN_IF_ERROR(RenumberArguments(subgraph->get(), *input_permutation));
+ // Renumber argument nodes in the graph.
+ TF_RETURN_IF_ERROR(
+ RenumberArguments(subgraph->get(), *input_permutation));
- // TODO(phawkins): add a forward is-constant analysis, similarly split
- // outputs into host-memory constants and device-memory non-constants.
+ // TODO(phawkins): add a forward is-constant analysis, similarly split
+ // outputs into host-memory constants and device-memory non-constants.
- AddNodeAttr(kXlaCompiledKernelAttr, true, node);
- AddNodeAttr(kXlaNumConstantArgsAttr, num_consts, node);
- AddNodeAttr(kXlaNumResourceArgsAttr, num_resources, node);
- return Status::OK();
- };
+ AddNodeAttr(kXlaCompiledKernelAttr, true, node);
+ AddNodeAttr(kXlaNumConstantArgsAttr, num_consts, node);
+ AddNodeAttr(kXlaNumResourceArgsAttr, num_resources, node);
+ return Status::OK();
+ };
TF_RETURN_IF_ERROR(EncapsulateSubgraphsInFunctions(
kXlaClusterAttr, kXlaOutsideCompilationAttr, **options.graph,
diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h
index e5dab7c657..926589546f 100644
--- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h
+++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h
@@ -28,6 +28,9 @@ limitations under the License.
namespace tensorflow {
// A rewriting function to apply to each subgraph during encapsulation.
+// 'arg_source_tensors' are the tensors corresponding to the arguments in the
+// original source graph (*not* 'graph').
+//
// 'graph' is the subgraph. The rewriting may renumber the inputs and outputs;
// 'input_permutation' is a mapping from old argument numbers to new argument
// numbers, whereas 'output_permutation' is the same for outputs. Both
@@ -37,6 +40,7 @@ namespace tensorflow {
// The rewrite may also change the NodeDef's operator name, and that
// name will be used as the name of the generated function.
typedef std::function<Status(
+ const std::vector<OutputTensor>& arg_source_tensors,
std::unique_ptr<Graph>* graph, std::vector<int>* input_permutation,
std::vector<int>* output_permutation, NodeDef* node_def)>
RewriteSubgraphFn;
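For reference, here is a minimal sketch (not part of the patch) of a callback matching the updated RewriteSubgraphFn signature. The attribute name "_my_marker" is made up for illustration; real callers, such as the lambda in EncapsulateSubgraphsPass::Run below, do considerably more work.

// Minimal RewriteSubgraphFn sketch: leaves the subgraph and the identity
// permutations untouched and only tags the generated call node.
RewriteSubgraphFn no_op_rewrite =
    [](const std::vector<OutputTensor>& arg_source_tensors,
       std::unique_ptr<Graph>* graph, std::vector<int>* input_permutation,
       std::vector<int>* output_permutation, NodeDef* node_def) {
      // input_permutation/output_permutation arrive initialized to the
      // identity, so a rewrite that does not reorder arguments or results
      // can simply leave them alone.
      AddNodeAttr("_my_marker", true, node_def);  // hypothetical attribute
      return Status::OK();
    };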
diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc
index 6a7cd932e5..4eb389e0c6 100644
--- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc
+++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc
@@ -757,7 +757,8 @@ TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Simple) {
TF_ASSERT_OK(EncapsulateSubgraphsInFunctions(
"_encapsulate", "_outside", graph_before,
/*rewrite_subgraph_fn=*/
- [&guaranteed_consts](std::unique_ptr<Graph>* graph_ptr,
+ [&guaranteed_consts](const std::vector<OutputTensor>& arg_source_tensors,
+ std::unique_ptr<Graph>* graph_ptr,
std::vector<int>* input_permutation,
std::vector<int>* output_permutation,
NodeDef* call_def) {
@@ -801,7 +802,8 @@ TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Add) {
TF_ASSERT_OK(EncapsulateSubgraphsInFunctions(
"_encapsulate", "_outside", graph_before,
/*rewrite_subgraph_fn=*/
- [&guaranteed_consts](std::unique_ptr<Graph>* graph_ptr,
+ [&guaranteed_consts](const std::vector<OutputTensor>& arg_source_tensors,
+ std::unique_ptr<Graph>* graph_ptr,
std::vector<int>* input_permutation,
std::vector<int>* output_permutation,
NodeDef* call_def) {
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index e6c92f9720..98fab319d6 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -51,6 +51,15 @@ py_library(
],
)
+py_test(
+ name = "xla_test_test",
+ size = "small",
+ srcs = ["xla_test_test.py"],
+ deps = [
+ ":xla_test",
+ ],
+)
+
tf_xla_py_test(
name = "adagrad_test",
size = "small",
diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py
index 689a4a1f4e..e610b63e30 100644
--- a/tensorflow/compiler/tests/unary_ops_test.py
+++ b/tensorflow/compiler/tests/unary_ops_test.py
@@ -201,6 +201,16 @@ class UnaryOpsTest(XLATestCase):
expected=np.array([1.54308063, 3.76219569, 10.067662, 27.30823284],
dtype=dtype))
+ # Disable float16 testing for now
+ if dtype != np.float16:
+ x = np.arange(-10, 10, 1).astype(dtype)
+ with self.test_session() as session:
+ erf_x = session.run(math_ops.erf(x))
+ erfc_x = session.run(math_ops.erfc(x))
+
+ self._assertOpOutputMatchesExpected(math_ops.erf, x, expected=erf_x)
+ self._assertOpOutputMatchesExpected(math_ops.erfc, x, expected=erfc_x)
+
self._assertOpOutputMatchesExpected(
math_ops.exp,
np.array([[-1, 1]], dtype=dtype),
diff --git a/tensorflow/compiler/tests/xla_test.py b/tensorflow/compiler/tests/xla_test.py
index e924fe1e61..88827cb53b 100644
--- a/tensorflow/compiler/tests/xla_test.py
+++ b/tensorflow/compiler/tests/xla_test.py
@@ -49,6 +49,32 @@ flags.DEFINE_string('tf_xla_flags', None,
'Value to set the TF_XLA_FLAGS environment variable to')
+def parse_disabled_manifest(manifest_content):
+ comments_re = re.compile('#.*$')
+ disabled_tests = []
+ disabled_method_types = []
+ for l in manifest_content.splitlines():
+ stripped = comments_re.sub('', l).strip()
+ if not stripped:
+ continue
+ entry = stripped.split(' ')
+ if len(entry) == 1:
+ disabled_tests.append(entry[0])
+ elif len(entry) == 2:
+ disabled_method_types.append((entry[0], entry[1].strip().split(',')))
+ else:
+ raise ValueError('Bad entry in manifest file.')
+
+ disabled_regex = '|'.join(disabled_tests)
+ method_types_filter = dict()
+ for method, types in disabled_method_types:
+ method_types_filter[method] = set([
+ dtypes.as_dtype(types_pb2.DataType.Value(name)).as_numpy_dtype
+ for name in types
+ ])
+ return disabled_regex, method_types_filter
+
+
class XLATestCase(test.TestCase):
"""XLA test cases are parameterized test cases."""
@@ -85,38 +111,21 @@ class XLATestCase(test.TestCase):
# Parse the manifest file, if any, into a regex identifying tests to
# disable
- self.disabled_regex = None
- self._method_types_filter = dict()
# TODO(xpan): Make it text proto if it doesn't scale.
# Each line of the manifest file specifies an entry. The entry can be
# 1) TestNameRegex // E.g. CumprodTest.* Or
# 2) TestName TypeName // E.g. AdamOptimizerTest.testSharing DT_BFLOAT16
# The 1) disables the entire test. While 2) only filter some numeric types
# so that they are not used in those tests.
+ self.disabled_regex = None
+ self._method_types_filter = {}
if FLAGS.disabled_manifest is not None:
- comments_re = re.compile('#.*$')
- manifest_file = open(FLAGS.disabled_manifest, 'r')
- disabled_tests = []
- disabled_method_types = []
- for l in manifest_file.read().splitlines():
- if not l:
- continue
- entry = comments_re.sub('', l).strip().split(' ')
- if len(entry) == 1:
- disabled_tests.append(entry[0])
- elif len(entry) == 2:
- disabled_method_types.append(
- (entry[0], entry[1].strip().split(',')))
- else:
- raise ValueError('Bad entry in manifest file.')
-
- self.disabled_regex = re.compile('|'.join(disabled_tests))
- for method, types in disabled_method_types:
- self._method_types_filter[method] = set([
- dtypes.as_dtype(types_pb2.DataType.Value(name)).as_numpy_dtype
- for name in types])
- manifest_file.close()
+ with open(FLAGS.disabled_manifest, 'r') as manifest_file:
+ disabled_regex, self._method_types_filter = (
+ parse_disabled_manifest(manifest_file.read()))
+ if disabled_regex:
+ self.disabled_regex = re.compile(disabled_regex)
if FLAGS.tf_xla_flags is not None:
os.environ['TF_XLA_FLAGS'] = FLAGS.tf_xla_flags
diff --git a/tensorflow/compiler/tests/xla_test_test.py b/tensorflow/compiler/tests/xla_test_test.py
new file mode 100644
index 0000000000..2466445157
--- /dev/null
+++ b/tensorflow/compiler/tests/xla_test_test.py
@@ -0,0 +1,44 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the XLATestCase test fixture base class."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.compiler.tests import xla_test
+from tensorflow.python.platform import test
+
+
+class XlaTestCaseTestCase(test.TestCase):
+
+ def testManifestEmptyLineDoesNotCatchAll(self):
+ manifest = """
+testCaseOne
+"""
+ disabled_regex, _ = xla_test.parse_disabled_manifest(manifest)
+ self.assertEqual(disabled_regex, "testCaseOne")
+
+ def testManifestWholeLineCommentDoesNotCatchAll(self):
+ manifest = """# I am a comment
+testCaseOne
+testCaseTwo
+"""
+ disabled_regex, _ = xla_test.parse_disabled_manifest(manifest)
+ self.assertEqual(disabled_regex, "testCaseOne|testCaseTwo")
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc
index 71a9fd051b..2521445e86 100644
--- a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc
@@ -16,9 +16,11 @@ limitations under the License.
// Native XLA implementations of simple unary Ops
#include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h"
+#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/client_library.h"
+#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
@@ -185,5 +187,49 @@ XLAJIT_MAKE_UNARY(Imag, b->Imag(x));
#undef XLAJIT_MAKE_UNARY
+// Erf/Erfc. For x in (-1, 1), the erf approximation is used; erfc polynomial
+// is used outside of this range.
+class ErfOp : public XlaOpKernel {
+ public:
+ explicit ErfOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+ void Compile(XlaOpKernelContext* ctx) override {
+ xla::XlaBuilder* b = ctx->builder();
+ xla::PrimitiveType primitive_type;
+ xla::XlaOp one = XlaHelpers::One(b, input_type(0));
+ xla::XlaOp x = ctx->Input(0);
+ xla::XlaOp abs_x = b->Abs(x);
+
+ OP_REQUIRES_OK(ctx,
+ DataTypeToPrimitiveType(input_type(0), &primitive_type));
+
+ auto y = b->Select(b->Gt(abs_x, one),
+ b->Sub(one, ComputeErfc(b, x, primitive_type)),
+ ComputeErf(b, x, primitive_type));
+ ctx->SetOutput(0, y);
+ }
+};
+REGISTER_XLA_OP(Name("Erf"), ErfOp);
+
+class ErfcOp : public XlaOpKernel {
+ public:
+ explicit ErfcOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+ void Compile(XlaOpKernelContext* ctx) override {
+ xla::XlaBuilder* b = ctx->builder();
+ xla::XlaOp one = XlaHelpers::One(b, input_type(0));
+ xla::XlaOp x = ctx->Input(0);
+ xla::XlaOp abs_x = b->Abs(x);
+
+ xla::PrimitiveType primitive_type;
+ OP_REQUIRES_OK(ctx,
+ DataTypeToPrimitiveType(input_type(0), &primitive_type));
+
+ auto y = b->Select(b->Lt(abs_x, one),
+ b->Sub(one, ComputeErf(b, x, primitive_type)),
+ ComputeErfc(b, x, primitive_type));
+ ctx->SetOutput(0, y);
+ }
+};
+REGISTER_XLA_OP(Name("Erfc"), ErfcOp);
+
} // namespace
} // namespace tensorflow
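The Select in ErfOp/ErfcOp follows the usual split for these functions: the erf polynomial is accurate near zero, the erfc polynomial is accurate in the tails, and the identity erf(x) + erfc(x) = 1 recovers one from the other. A scalar analogue using the standard library, purely for illustration and not part of the patch:

#include <cmath>

// Scalar analogue of the range split above: use erf directly for |x| <= 1,
// otherwise derive it from erfc via erf(x) = 1 - erfc(x).
double ErfViaRangeSplit(double x) {
  return std::abs(x) > 1.0 ? 1.0 - std::erfc(x) : std::erf(x);
}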
diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.cc b/tensorflow/compiler/tf2xla/lib/batch_dot.cc
index 526694d5a0..ee0bb91a6b 100644
--- a/tensorflow/compiler/tf2xla/lib/batch_dot.cc
+++ b/tensorflow/compiler/tf2xla/lib/batch_dot.cc
@@ -71,8 +71,8 @@ xla::StatusOr<xla::XlaOp> BatchDot(xla::XlaBuilder* builder, xla::XlaOp x,
}
// Check for zero lhs/rhs dim size.
- if (xla::ShapeUtil::HasZeroElements(x_shape) ||
- xla::ShapeUtil::HasZeroElements(y_shape)) {
+ if (xla::ShapeUtil::IsZeroElementArray(x_shape) ||
+ xla::ShapeUtil::IsZeroElementArray(y_shape)) {
std::vector<int64> dimensions(batch_dimension_numbers.size());
for (int i = 0; i < batch_dimension_numbers.size(); ++i) {
dimensions[i] = x_shape.dimensions(batch_dimension_numbers[i]);
diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD
index 1b8e516770..4525197146 100644
--- a/tensorflow/compiler/xla/BUILD
+++ b/tensorflow/compiler/xla/BUILD
@@ -309,7 +309,6 @@ cc_library(
":types",
":util",
":xla_data_proto",
- "//tensorflow/core:framework",
"//tensorflow/core:lib",
],
)
diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.cc b/tensorflow/compiler/xla/client/lib/arithmetic.cc
index a1d34796cc..639f85737f 100644
--- a/tensorflow/compiler/xla/client/lib/arithmetic.cc
+++ b/tensorflow/compiler/xla/client/lib/arithmetic.cc
@@ -121,4 +121,88 @@ StatusOr<XlaOp> Any(const XlaOp& predicates, XlaBuilder* builder) {
return builder->Reduce(predicates, f, logical_or, all_dimensions);
}
+namespace {
+xla::XlaOp FloatLiteral(xla::XlaBuilder* b, PrimitiveType data_type,
+ float value) {
+ return b->ConvertElementType(b->ConstantR0(value), data_type);
+}
+
+// Polynomials for computing erf/erfc. Originally from cephes.
+// Note we use float for compatibility across devices, at the cost of some
+// precision for 64 bit computations.
+//
+// Coefficients are in descending order.
+std::array<float, 9> kErfcPCoefficient = {
+ 2.46196981473530512524E-10, 5.64189564831068821977E-1,
+ 7.46321056442269912687E0, 4.86371970985681366614E1,
+ 1.96520832956077098242E2, 5.26445194995477358631E2,
+ 9.34528527171957607540E2, 1.02755188689515710272E3,
+ 5.57535335369399327526E2};
+std::array<float, 9> kErfcQCoefficient = {
+ 1.00000000000000000000E0, 1.32281951154744992508E1,
+ 8.67072140885989742329E1, 3.54937778887819891062E2,
+ 9.75708501743205489753E2, 1.82390916687909736289E3,
+ 2.24633760818710981792E3, 1.65666309194161350182E3,
+ 5.57535340817727675546E2};
+std::array<float, 6> kErfcRCoefficient = {
+ 5.64189583547755073984E-1, 1.27536670759978104416E0,
+ 5.01905042251180477414E0, 6.16021097993053585195E0,
+ 7.40974269950448939160E0, 2.97886665372100240670E0};
+std::array<float, 7> kErfcSCoefficient = {
+ 1.00000000000000000000E0, 2.26052863220117276590E0,
+ 9.39603524938001434673E0, 1.20489539808096656605E1,
+ 1.70814450747565897222E1, 9.60896809063285878198E0,
+ 3.36907645100081516050E0};
+std::array<float, 5> kErfTCoefficient = {
+ 9.60497373987051638749E0, 9.00260197203842689217E1,
+ 2.23200534594684319226E3, 7.00332514112805075473E3,
+ 5.55923013010394962768E4};
+std::array<float, 6> kErfUCoefficient = {
+ 1.00000000000000000000E0, 3.35617141647503099647E1,
+ 5.21357949780152679795E2, 4.59432382970980127987E3,
+ 2.26290000613890934246E4, 4.92673942608635921086E4};
+} // namespace
+
+// Evaluate the polynomial given coefficients and `x`.
+// N.B. Coefficients should be supplied in decreasing order.
+xla::XlaOp EvaluatePolynomial(xla::XlaBuilder* b, const xla::XlaOp& x,
+ tensorflow::gtl::ArraySlice<float> coefficients,
+ PrimitiveType data_type) {
+ xla::XlaOp poly = FloatLiteral(b, data_type, 0.0);
+ for (float c : coefficients) {
+ poly = b->Add(b->Mul(poly, x), FloatLiteral(b, data_type, c));
+ }
+ return poly;
+}
+
+// Compute an approximation of the error function complement (1 - erf(x)).
+xla::XlaOp ComputeErfc(xla::XlaBuilder* b, const xla::XlaOp& x,
+ PrimitiveType data_type) {
+ xla::XlaOp zero = FloatLiteral(b, data_type, 0.0);
+ xla::XlaOp two = FloatLiteral(b, data_type, 2.0);
+ xla::XlaOp eight = FloatLiteral(b, data_type, 8.0);
+
+ xla::XlaOp abs_x = b->Abs(x);
+ xla::XlaOp z = b->Exp(b->Mul(b->Neg(x), x));
+
+ xla::XlaOp pp = EvaluatePolynomial(b, abs_x, kErfcPCoefficient, data_type);
+ xla::XlaOp pq = EvaluatePolynomial(b, abs_x, kErfcQCoefficient, data_type);
+ xla::XlaOp pr = EvaluatePolynomial(b, abs_x, kErfcRCoefficient, data_type);
+ xla::XlaOp ps = EvaluatePolynomial(b, abs_x, kErfcSCoefficient, data_type);
+
+ xla::XlaOp y = b->Select(b->Lt(abs_x, eight), b->Div(b->Mul(z, pp), pq),
+ b->Div(b->Mul(z, pr), ps));
+
+ return b->Select(b->Lt(x, zero), b->Sub(two, y), y);
+}
+
+// Compute a polynomial approximation of the error function.
+xla::XlaOp ComputeErf(xla::XlaBuilder* b, const xla::XlaOp& x,
+ PrimitiveType data_type) {
+ xla::XlaOp z = b->Mul(x, x);
+ xla::XlaOp pt = EvaluatePolynomial(b, z, kErfTCoefficient, data_type);
+ xla::XlaOp pu = EvaluatePolynomial(b, z, kErfUCoefficient, data_type);
+ return b->Div(b->Mul(x, pt), pu);
+}
+
} // namespace xla
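EvaluatePolynomial above is Horner's scheme lifted to XlaOps: with the coefficients supplied in decreasing order, the running value is multiplied by x before each new coefficient is added. A scalar sketch of the same recurrence (illustrative only, not part of the patch):

#include <vector>

// Horner evaluation with coefficients in decreasing order, mirroring the
// XlaOp loop in EvaluatePolynomial: {2, -3, 5} evaluates 2*x*x - 3*x + 5.
float EvaluatePolynomialScalar(const std::vector<float>& coefficients,
                               float x) {
  float poly = 0.0f;
  for (float c : coefficients) {
    poly = poly * x + c;
  }
  return poly;
}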
diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.h b/tensorflow/compiler/xla/client/lib/arithmetic.h
index 64b6b7d633..f11cc00317 100644
--- a/tensorflow/compiler/xla/client/lib/arithmetic.h
+++ b/tensorflow/compiler/xla/client/lib/arithmetic.h
@@ -55,6 +55,20 @@ XlaComputation CreateScalarOrComputation(XlaBuilder* builder);
// Note: if predicates is zero-sized, Any() vacuously returns false.
StatusOr<XlaOp> Any(const XlaOp& predicates, XlaBuilder* builder);
+// Evaluate the polynomial given coefficients and `x`.
+// N.B. Coefficients should be supplied in decreasing order.
+xla::XlaOp EvaluatePolynomial(xla::XlaBuilder* b, const xla::XlaOp& x,
+ tensorflow::gtl::ArraySlice<double> coefficients,
+ PrimitiveType data_type);
+
+// Compute an approximation of the error function complement (1 - erf(x)).
+xla::XlaOp ComputeErfc(xla::XlaBuilder* b, const xla::XlaOp& x,
+ PrimitiveType data_type);
+
+// Compute an approximation of the error function.
+xla::XlaOp ComputeErf(xla::XlaBuilder* b, const xla::XlaOp& x,
+ PrimitiveType data_type);
+
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_ARITHMETIC_H_
diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
index ae8fbdb2dc..d7ebcf8beb 100644
--- a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
@@ -1632,8 +1632,7 @@ XlaOp XlaBuilder::CrossReplicaSum(
const tensorflow::gtl::optional<ChannelHandle>& channel_id) {
return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
if (channel_id.has_value()) {
- return Unimplemented(
- "replica_group_ids and channel_id and is not supported in AllReduce");
+ return Unimplemented("channel_id is not supported in AllReduce");
}
HloInstructionProto instr;
diff --git a/tensorflow/compiler/xla/experimental/xla_sharding/BUILD b/tensorflow/compiler/xla/experimental/xla_sharding/BUILD
new file mode 100644
index 0000000000..a26b20c861
--- /dev/null
+++ b/tensorflow/compiler/xla/experimental/xla_sharding/BUILD
@@ -0,0 +1,18 @@
+# Description:
+# Python API for shardings in XLA.
+
+licenses(["notice"]) # Apache 2.0
+
+package(default_visibility = ["//tensorflow:internal"])
+
+py_library(
+ name = "xla_sharding",
+ srcs = ["xla_sharding.py"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/compiler/xla:xla_data_proto_py",
+ "//tensorflow/compiler/xla/python_api:types",
+ "//tensorflow/compiler/xla/python_api:xla_shape",
+ "//third_party/py/numpy",
+ ],
+)
diff --git a/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py b/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py
new file mode 100644
index 0000000000..abd10b164e
--- /dev/null
+++ b/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py
@@ -0,0 +1,204 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ======================================
+"""Experimental support for defining XLA shardings."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+
+import numpy as np
+
+from tensorflow.compiler.xla import xla_data_pb2
+from tensorflow.compiler.xla.python_api import xla_shape
+from tensorflow.core.framework import attr_value_pb2
+
+
+class Sharding(object):
+ """A class to support adding sharding attributes to Ops.
+
+ Use the factory constructors and then call apply_to_tensor:
+ Sharding.replicate().apply_to_tensor(tensor)
+ """
+
+ def __init__(self, proto=None):
+ """Do not use this constructor; use the factory functions below."""
+ self._proto = proto
+
+ @classmethod
+ def replicate(cls):
+ """Returns a replicated sharding attribute.
+
+ This causes an op to be computed in its entirety independently on all
+ cores in the XLA device.
+ """
+ return Sharding(
+ proto=xla_data_pb2.OpSharding(type=xla_data_pb2.OpSharding.REPLICATED))
+
+ @classmethod
+ def assign_device(cls, core):
+ """Returns an AssignDevice sharding attribute.
+
+ This causes an op to be computed in its entirety only on one core in
+ the XLA device.
+ Args:
+ core: The core to assign this Op to.
+ """
+ return Sharding(
+ proto=xla_data_pb2.OpSharding(
+ type=xla_data_pb2.OpSharding.MAXIMAL,
+ tile_assignment_dimensions=[1],
+ tile_assignment_devices=[core]))
+
+ @classmethod
+ def tile(cls, tile_shape, tile_assignment):
+ """Returns a Tiled sharding attribute.
+
+ This causes an op to be partially computed on multiple cores in the
+ XLA device.
+
+ Args:
+ tile_shape: A xla_shape.Shape describing the tile shape that each core
+ will compute.
+ The tile shape does not need to be divisible by the tile assignment.
+ tile_assignment: An np.ndarray describing the topology of the tiling and
+ which device will compute which part of the topology.
+
+ Raises:
+ TypeError: tile_assignment was not of np.array type or tile_shape was
+ not of xla_shape.Shape type.
+
+ TODO(jmolloy): This concept is nefarious and is not
+ something we really want to expose to users (especially as the
+ contract for tile_assignment is very strict).
+ """
+ if not isinstance(tile_assignment, np.ndarray):
+ raise TypeError('Tile assignment must be of type np.ndarray')
+ if not isinstance(tile_shape, xla_shape.Shape):
+ raise TypeError('Tile shape must be of type xla_shape.Shape')
+ dims = list(tile_assignment.shape)
+ flattened_devices = tile_assignment.reshape(-1, order='C')
+ return Sharding(
+ proto=xla_data_pb2.OpSharding(
+ type=xla_data_pb2.OpSharding.OTHER,
+ tile_shape=tile_shape.message,
+ tile_assignment_dimensions=dims,
+ tile_assignment_devices=list(flattened_devices)))
+
+ @classmethod
+ def split(cls, tensor, split_dimension, num_devices):
+ """Returns a Sharding that splits a tensor across a dimension.
+
+ This creates a Tiled attribute, similar to tile(), but easier to use for the
+ common case of tiling a tensor N ways in one dimension.
+
+ Args:
+ tensor: A tf.Tensor to split.
+ split_dimension: The dimension number to split.
+ num_devices: The number of cores to split `tensor` over.
+
+ Raises:
+ ValueError: The tensor to split was smaller in the split dimension than
+ the number of devices to split over.
+ """
+ tensor.shape.assert_is_fully_defined()
+ shape = tensor.shape.as_list()
+ if shape[split_dimension] < num_devices:
+ raise ValueError('Split dimension was smaller than the required number '
+ 'of splits: shape=%r, dimension=%r, num_devices=%r' %
+ (shape, split_dimension, num_devices))
+
+ tile_shape = shape
+ tile_shape[split_dimension] = int(
+ math.ceil(tile_shape[split_dimension] / num_devices))
+ tile_shape_proto = xla_data_pb2.Shape(
+ element_type=xla_data_pb2.F32, dimensions=tile_shape)
+
+ tile_assignment_dims = [1] * len(shape)
+ tile_assignment_dims[split_dimension] = num_devices
+
+ return Sharding(
+ proto=xla_data_pb2.OpSharding(
+ type=xla_data_pb2.OpSharding.OTHER,
+ tile_shape=tile_shape_proto,
+ tile_assignment_dimensions=tile_assignment_dims,
+ tile_assignment_devices=range(num_devices)))
+
+ def apply_to_tensor(self, tensor):
+ """Applies this Sharding attribute to `tensor`."""
+ if len(tensor.op.outputs) > 1:
+ proto = self._get_or_create_tuple_proto(tensor.op)
+ # We can't mutate an element of old_proto.tuple_shardings, so create
+ # a new proto.
+ tuple_shardings = list(proto.tuple_shardings)
+ tuple_shardings[tensor.value_index] = self._proto
+ proto = xla_data_pb2.OpSharding(
+ type=xla_data_pb2.OpSharding.TUPLE, tuple_shardings=tuple_shardings)
+ else:
+ proto = self._proto
+
+ attr_value = attr_value_pb2.AttrValue(s=proto.SerializeToString())
+ # TODO(jmolloy): This needs to be seriously revisited before declaring this
+ # API available for public use.
+ # pylint: disable=protected-access
+ tensor.op._set_attr('_XlaSharding', attr_value)
+
+ @property
+ def proto(self):
+ """Return the sharding protobuf of type xla_data_pb2.OpSharding."""
+ return self._proto
+
+ def _get_or_create_tuple_proto(self, op):
+ try:
+ attr = op.get_attr('_XlaSharding')
+ proto = xla_data_pb2.OpSharding()
+ proto.ParseFromString(attr)
+ return proto
+ except ValueError:
+ return self._create_tuple_proto(op)
+
+ def _create_tuple_proto(self, op):
+ shardings = [
+ xla_data_pb2.OpSharding(type=xla_data_pb2.OpSharding.REPLICATED)
+ for _ in op.outputs
+ ]
+ return xla_data_pb2.OpSharding(
+ type=xla_data_pb2.OpSharding.TUPLE, tuple_shardings=shardings)
+
+
+# Helpers for the above factory functions that allow easy application of
+# shardings, for example:
+# tensor = xla_sharding.replicate(tensor)
+
+
+def replicate(tensor):
+ Sharding.replicate().apply_to_tensor(tensor)
+ return tensor
+
+
+def assign_device(tensor, device):
+ Sharding.assign_device(device).apply_to_tensor(tensor)
+ return tensor
+
+
+def tile(tensor, tile_shape, tile_assignment):
+ Sharding.tile(tile_shape, tile_assignment).apply_to_tensor(tensor)
+ return tensor
+
+
+def split(tensor, split_dimension, num_devices):
+ Sharding.split(tensor, split_dimension, num_devices).apply_to_tensor(tensor)
+ return tensor
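A minimal usage sketch for the helpers above, assuming a TensorFlow 1.x graph built for an XLA device; the placeholder, variable, and tensor names below are illustrative:

    import tensorflow as tf
    from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding

    x = tf.placeholder(tf.float32, shape=[8, 16])
    # Compute the op producing `x` identically on every core.
    x = xla_sharding.replicate(x)

    w = tf.get_variable('w', shape=[16, 4], dtype=tf.float32)
    y = tf.matmul(x, w)
    # Tile the matmul result across 4 cores along dimension 0 (8 rows -> 2 per core).
    y = xla_sharding.split(y, split_dimension=0, num_devices=4)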
diff --git a/tensorflow/compiler/xla/layout_util.cc b/tensorflow/compiler/xla/layout_util.cc
index e8f29b8329..3f059cac30 100644
--- a/tensorflow/compiler/xla/layout_util.cc
+++ b/tensorflow/compiler/xla/layout_util.cc
@@ -190,9 +190,13 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
}
if (!ShapeUtil::IsArray(shape)) {
- return InvalidArgument(
- "shape of primitive type %s should not have a layout",
- PrimitiveType_Name(shape.element_type()).c_str());
+ if (layout.minor_to_major_size() != 0 ||
+ layout.padded_dimensions_size() != 0) {
+ return InvalidArgument(
+ "shape of primitive type %s should not have a non-trivial layout",
+ PrimitiveType_Name(shape.element_type()).c_str());
+ }
+ return Status::OK();
}
if (layout.format() == INVALID_FORMAT) {
diff --git a/tensorflow/compiler/xla/literal_comparison.cc b/tensorflow/compiler/xla/literal_comparison.cc
index bf9679cafe..2125ab7c61 100644
--- a/tensorflow/compiler/xla/literal_comparison.cc
+++ b/tensorflow/compiler/xla/literal_comparison.cc
@@ -606,8 +606,8 @@ Status NearHelper(const LiteralSlice& expected, const LiteralSlice& actual,
} // namespace
Status EqualShapes(const Shape& expected, const Shape& actual) {
- if (ShapeUtil::IsTuple(expected) != ShapeUtil::IsTuple(actual)) {
- return InvalidArgument("tupleness-mismatch! want: %s got %s",
+ if (expected.element_type() != actual.element_type()) {
+ return InvalidArgument("element type mismatch, want: %s got %s",
ShapeUtil::HumanString(expected).c_str(),
ShapeUtil::HumanString(actual).c_str());
}
@@ -626,7 +626,7 @@ Status EqualShapes(const Shape& expected, const Shape& actual) {
return AppendStatus(result, StrCat("mismatch in tuple index", i));
}
}
- } else {
+ } else if (ShapeUtil::IsArray(expected)) {
if (ShapeUtil::Rank(expected) != ShapeUtil::Rank(actual)) {
return InvalidArgument("want rank of %s got rank of %s",
ShapeUtil::HumanString(expected).c_str(),
@@ -652,6 +652,7 @@ Status EqualShapes(const Shape& expected, const Shape& actual) {
}
}
}
+ // Non-array, non-tuple shapes are trivially equivalent.
return Status::OK();
}
@@ -705,6 +706,9 @@ Status Equal(const LiteralSlice& expected, const LiteralSlice& actual) {
}
break;
}
+ case TOKEN:
+ // Tokens have no on-device representation and are trivially equal.
+ return Status::OK();
default:
LOG(FATAL)
<< "Unsupported primitive type in LiteralTestUtil::ExpectEqual: "
diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc
index 6b29589700..19e6d288c0 100644
--- a/tensorflow/compiler/xla/literal_util.cc
+++ b/tensorflow/compiler/xla/literal_util.cc
@@ -148,8 +148,7 @@ void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays) {
piece->emplace_back(std::move(child_piece));
}
- } else {
- CHECK(ShapeUtil::IsArray(shape));
+ } else if (ShapeUtil::IsArray(shape)) {
if (allocate_arrays) {
if (LayoutUtil::IsSparseArray(shape)) {
// For sparse arrays, the buffer must be of the size of the maximum
@@ -165,6 +164,10 @@ void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays) {
piece->set_buffer(new char[piece->size_bytes()]);
}
}
+ } else {
+ // If the shape is neither an array nor a tuple, then it must be
+ // zero-sized; otherwise, some memory would need to be allocated for it.
+ CHECK_EQ(piece->size_bytes(), 0);
}
}
@@ -264,8 +267,8 @@ Status Literal::CopySliceFromInternal(
StridedCopy(data<NativeT>(), linear_index(shape(), dest_base), 0,
src_literal.data<NativeT>(),
linear_index(src_literal.shape(), src_base), 0, 1);
- } else if (!ShapeUtil::HasZeroElements(shape()) &&
- !ShapeUtil::HasZeroElements(src_literal.shape())) {
+ } else if (!ShapeUtil::IsZeroElementArray(shape()) &&
+ !ShapeUtil::IsZeroElementArray(src_literal.shape())) {
// Perform copy if neither src nor dest has dimensions with zero element,
// otherwise it's a no-op.
TF_RET_CHECK(src_base.size() == dest_base.size());
@@ -327,6 +330,10 @@ Status Literal::CopyElementFrom(const LiteralSlice& src_literal,
return Status::OK();
}
+/* static */ std::unique_ptr<Literal> Literal::CreateToken() {
+ return MakeUnique<Literal>(ShapeUtil::MakeTokenShape());
+}
+
std::vector<Literal> Literal::DecomposeTuple() {
CHECK(ShapeUtil::IsTuple(shape()));
std::vector<Literal> elements;
@@ -379,7 +386,7 @@ void CopyElementsBetween(tensorflow::gtl::MutableArraySlice<NativeT> dest,
tensorflow::gtl::ArraySlice<NativeT> src,
const Shape& dest_shape, const Shape& src_shape) {
CHECK(ShapeUtil::Compatible(dest_shape, src_shape));
- if (ShapeUtil::HasZeroElements(dest_shape)) {
+ if (ShapeUtil::IsZeroElementArray(dest_shape)) {
return;
}
std::vector<int64> index(ShapeUtil::Rank(dest_shape));
@@ -1177,7 +1184,7 @@ size_t LiteralBase::Hash() const {
ShapeUtil::ForEachSubshape(
shape(), [&](const Shape& subshape, const ShapeIndex& index) {
- if (ShapeUtil::IsTuple(subshape)) {
+ if (!ShapeUtil::IsArray(subshape)) {
return;
}
@@ -1368,6 +1375,11 @@ void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index,
return;
}
+ if (ShapeUtil::IsToken(subshape)) {
+ pieces->push_back("token");
+ return;
+ }
+
if (LayoutUtil::IsSparseArray(subshape)) {
pieces->push_back(shape_to_string(subshape));
pieces->push_back("{");
@@ -1556,7 +1568,7 @@ string LiteralBase::ToString(bool print_layout) const {
void LiteralBase::EachCellAsString(
const std::function<void(tensorflow::gtl::ArraySlice<int64> indices,
const string& value)>& per_cell) const {
- if (ShapeUtil::HasZeroElements(shape())) {
+ if (ShapeUtil::IsZeroElementArray(shape())) {
return;
}
std::vector<int64> indices = IndexUtil::LinearIndexToMultidimensionalIndex(
@@ -1962,7 +1974,7 @@ bool LiteralBase::IsAllFirst() const {
// Empty shapes are not all the first element since there is no first
// element.
- if (ShapeUtil::HasZeroElements(piece.subshape())) {
+ if (ShapeUtil::IsZeroElementArray(piece.subshape())) {
return false;
}
auto piece_is_all = [&]() {
diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h
index 8e4159e360..37ca8ea9f1 100644
--- a/tensorflow/compiler/xla/literal_util.h
+++ b/tensorflow/compiler/xla/literal_util.h
@@ -917,6 +917,9 @@ class Literal : public LiteralBase {
return MakeTupleOwned(std::move(v));
}
+ // Create a constant token literal. Token types have no value.
+ static std::unique_ptr<Literal> CreateToken();
+
// Returns a vector containing the tuple elements of this Literal as separate
// Literals. This Literal must be tuple-shaped and can be a nested tuple. The
// elements are moved into the new Literals; no data is copied. Upon return
@@ -1456,7 +1459,7 @@ void LiteralBase::EachCell(
std::function<void(tensorflow::gtl::ArraySlice<int64> indices,
NativeT value)>
per_cell) const {
- if (ShapeUtil::HasZeroElements(shape())) {
+ if (ShapeUtil::IsZeroElementArray(shape())) {
return;
}
std::vector<int64> indices(ShapeUtil::Rank(shape()), 0);
diff --git a/tensorflow/compiler/xla/literal_util_test.cc b/tensorflow/compiler/xla/literal_util_test.cc
index 53b926163c..493d807591 100644
--- a/tensorflow/compiler/xla/literal_util_test.cc
+++ b/tensorflow/compiler/xla/literal_util_test.cc
@@ -334,6 +334,22 @@ TEST_F(LiteralUtilTest, NonScalarEquality) {
EXPECT_EQ(nil, nil);
}
+TEST_F(LiteralUtilTest, TokenEquality) {
+ auto token0 = Literal::CreateToken();
+ auto token1 = Literal::CreateToken();
+ auto scalar = Literal::CreateR0<float>(1.0);
+
+ EXPECT_EQ(*token0, *token1);
+ EXPECT_NE(*token0, *scalar);
+
+ EXPECT_EQ(*Literal::MakeTuple({token0.get()}),
+ *Literal::MakeTuple({token0.get()}));
+ EXPECT_EQ(*Literal::MakeTuple({token0.get(), scalar.get()}),
+ *Literal::MakeTuple({token1.get(), scalar.get()}));
+ EXPECT_NE(*Literal::MakeTuple({token0.get(), scalar.get()}),
+ *Literal::MakeTuple({scalar.get(), token1.get()}));
+}
+
TEST_F(LiteralUtilTest, DifferentLayoutEquality) {
// Test equality with literals which have different layouts.
auto colmajor =
diff --git a/tensorflow/compiler/xla/primitive_util.cc b/tensorflow/compiler/xla/primitive_util.cc
index 143c9a2366..b16147e3be 100644
--- a/tensorflow/compiler/xla/primitive_util.cc
+++ b/tensorflow/compiler/xla/primitive_util.cc
@@ -85,5 +85,10 @@ PrimitiveType ComplexComponentType(PrimitiveType complex_type) {
}
}
+bool IsArrayType(PrimitiveType primitive_type) {
+ return primitive_type != PRIMITIVE_TYPE_INVALID && primitive_type != TUPLE &&
+ primitive_type != OPAQUE && primitive_type != TOKEN;
+}
+
} // namespace primitive_util
} // namespace xla
diff --git a/tensorflow/compiler/xla/primitive_util.h b/tensorflow/compiler/xla/primitive_util.h
index b26a10ade6..889e9a1cec 100644
--- a/tensorflow/compiler/xla/primitive_util.h
+++ b/tensorflow/compiler/xla/primitive_util.h
@@ -133,6 +133,9 @@ bool IsUnsignedIntegralType(PrimitiveType type);
bool IsIntegralType(PrimitiveType type);
+// Returns true if values of the given primitive type are held in array shapes.
+bool IsArrayType(PrimitiveType primitive_type);
+
// Returns the number of bits in the representation for a given type.
int BitWidth(PrimitiveType type);
diff --git a/tensorflow/compiler/xla/python_api/BUILD b/tensorflow/compiler/xla/python_api/BUILD
new file mode 100644
index 0000000000..8999cda5ef
--- /dev/null
+++ b/tensorflow/compiler/xla/python_api/BUILD
@@ -0,0 +1,36 @@
+# Description:
+# Python API for XLA.
+
+licenses(["notice"]) # Apache 2.0
+
+package(default_visibility = ["//tensorflow:internal"])
+
+py_library(
+ name = "types",
+ srcs = ["types.py"],
+ deps = [
+ "//tensorflow/compiler/xla:xla_data_proto_py",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_library(
+ name = "xla_shape",
+ srcs = ["xla_shape.py"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":types",
+ "//tensorflow/compiler/xla:xla_data_proto_py",
+ ],
+)
+
+py_library(
+ name = "xla_literal",
+ srcs = ["xla_literal.py"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":types",
+ ":xla_shape",
+ "//tensorflow/compiler/xla:xla_data_proto_py",
+ ],
+)
diff --git a/tensorflow/compiler/xla/python_api/types.py b/tensorflow/compiler/xla/python_api/types.py
new file mode 100644
index 0000000000..b60f8dce92
--- /dev/null
+++ b/tensorflow/compiler/xla/python_api/types.py
@@ -0,0 +1,124 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ======================================
+"""Utilities for XLA-specific Python types."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+import numpy as np
+
+from tensorflow.compiler.xla import xla_data_pb2
+
+# Records the correspondence between an XLA primitive type and Python/Numpy types.
+#
+# primitive_type: value of type xla_data_pb2.PrimitiveType
+# numpy_dtype: corresponding Numpy "dtype" (like np.float32)
+# literal_field_name: name of the field in the LiteralProto message that
+# elements of this type go into.
+# literal_field_type: type of the field named 'literal_field_name'.
+#
+# TODO(eliben): figure out how to avoid knowing the extra Python type and the
+# astype cast when writing into Literals.
+TypeConversionRecord = collections.namedtuple('TypeConversionRecord', [
+ 'primitive_type', 'numpy_dtype', 'literal_field_name', 'literal_field_type'
+])
+
+# Maps from XLA primitive types to TypeConversionRecord.
+MAP_XLA_TYPE_TO_RECORD = {
+ xla_data_pb2.F16:
+ TypeConversionRecord(
+ primitive_type=xla_data_pb2.F16,
+ numpy_dtype=np.float16,
+ literal_field_name='f16s',
+ literal_field_type=float),
+ xla_data_pb2.F32:
+ TypeConversionRecord(
+ primitive_type=xla_data_pb2.F32,
+ numpy_dtype=np.float32,
+ literal_field_name='f32s',
+ literal_field_type=float),
+ xla_data_pb2.F64:
+ TypeConversionRecord(
+ primitive_type=xla_data_pb2.F64,
+ numpy_dtype=np.float64,
+ literal_field_name='f64s',
+ literal_field_type=float),
+ xla_data_pb2.S8:
+ TypeConversionRecord(
+ primitive_type=xla_data_pb2.S8,
+ numpy_dtype=np.int8,
+ literal_field_name='s8s',
+ literal_field_type=int),
+ xla_data_pb2.S16:
+ TypeConversionRecord(
+ primitive_type=xla_data_pb2.S16,
+ numpy_dtype=np.int16,
+ literal_field_name='s16s',
+ literal_field_type=int),
+ xla_data_pb2.S32:
+ TypeConversionRecord(
+ primitive_type=xla_data_pb2.S32,
+ numpy_dtype=np.int32,
+ literal_field_name='s32s',
+ literal_field_type=int),
+ xla_data_pb2.S64:
+ TypeConversionRecord(
+ primitive_type=xla_data_pb2.S64,
+ numpy_dtype=np.int64,
+ literal_field_name='s64s',
+ literal_field_type=int),
+ xla_data_pb2.U8:
+ TypeConversionRecord(
+ primitive_type=xla_data_pb2.U8,
+ numpy_dtype=np.uint8,
+ literal_field_name='s8s',
+ literal_field_type=int),
+ xla_data_pb2.U16:
+ TypeConversionRecord(
+ primitive_type=xla_data_pb2.U16,
+ numpy_dtype=np.uint16,
+ literal_field_name='s16s',
+ literal_field_type=int),
+ xla_data_pb2.U32:
+ TypeConversionRecord(
+ primitive_type=xla_data_pb2.U32,
+ numpy_dtype=np.uint32,
+ literal_field_name='s32s',
+ literal_field_type=int),
+ xla_data_pb2.U64:
+ TypeConversionRecord(
+ primitive_type=xla_data_pb2.U64,
+ numpy_dtype=np.uint64,
+ literal_field_name='s64s',
+ literal_field_type=int),
+ xla_data_pb2.PRED:
+ TypeConversionRecord(
+ primitive_type=xla_data_pb2.PRED,
+ numpy_dtype=np.bool,
+ literal_field_name='preds',
+ literal_field_type=bool)
+}
+
+# Maps from Numpy dtypes to TypeConversionRecord.
+# Note the conversion on the key. Numpy has a known issue wherein dtype hashing
+# doesn't work as expected (https://github.com/numpy/numpy/issues/7242). Thus,
+# when keying by dtype in this dict, we use the string form of dtypes.
+MAP_DTYPE_TO_RECORD = {
+ str(np.dtype(record.numpy_dtype)): record
+ for record in MAP_XLA_TYPE_TO_RECORD.values()
+}
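A short sketch of how these maps are consulted; the lookups below simply restate entries from the table above (note the string-keyed dtype lookup, per the comment):

    import numpy as np

    from tensorflow.compiler.xla import xla_data_pb2
    from tensorflow.compiler.xla.python_api import types

    record = types.MAP_DTYPE_TO_RECORD[str(np.dtype(np.float32))]
    assert record.primitive_type == xla_data_pb2.F32
    assert record.literal_field_name == 'f32s'

    # The reverse direction: XLA primitive type -> Numpy dtype.
    assert types.MAP_XLA_TYPE_TO_RECORD[xla_data_pb2.S32].numpy_dtype == np.int32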
diff --git a/tensorflow/compiler/xla/python_api/xla_literal.py b/tensorflow/compiler/xla/python_api/xla_literal.py
new file mode 100644
index 0000000000..b040098c29
--- /dev/null
+++ b/tensorflow/compiler/xla/python_api/xla_literal.py
@@ -0,0 +1,95 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ======================================
+"""XLA LiteralProto utilities."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.compiler.xla import xla_data_pb2
+from tensorflow.compiler.xla.python_api import types
+from tensorflow.compiler.xla.python_api import xla_shape
+
+
+def ConvertLiteralToNumpyArray(literal):
+ """Converts a XLA literal to a Numpy array."""
+ element_type = literal.shape.element_type
+ if element_type == xla_data_pb2.TUPLE:
+ return tuple(
+ ConvertLiteralToNumpyArray(subliteral)
+ for subliteral in literal.tuple_literals)
+
+ type_record = types.MAP_XLA_TYPE_TO_RECORD[element_type]
+ if not literal.shape.dimensions:
+ return np.array(
+ getattr(literal, type_record.literal_field_name)[0],
+ type_record.numpy_dtype)
+ else:
+ # Infer the proper Numpy order from the LiteralProto's layout. The repeated
+ # field representing the array's content in the Literal is linearized.
+ # Reading is done in two steps:
+ #
+ # 1. Read the array as 1D from the LiteralProto repeated field.
+ # 2. Reshape the array to its proper shape, using the right order depending
+ # on the LiteralProto's layout.
+ layout_order = literal.shape.layout.minor_to_major
+ numpy_shape = tuple(literal.shape.dimensions)
+ if list(layout_order) == list(range(len(literal.shape.dimensions))):
+ numpy_reshaper = lambda arr: arr.reshape(numpy_shape, order='F')
+ elif list(layout_order) == list(range(len(literal.shape.dimensions) - 1, -1, -1)):
+ numpy_reshaper = lambda arr: arr.reshape(numpy_shape, order='C')
+ else:
+ raise NotImplementedError('Unsupported layout: {0}'.format(layout_order))
+ ndarray = np.array(
+ getattr(literal, type_record.literal_field_name),
+ copy=False,
+ dtype=type_record.numpy_dtype)
+ return numpy_reshaper(ndarray)
+
+
+def _ConvertNumpyArrayToLiteral(ndarray):
+ """Converts a Numpy array to a XLA literal."""
+ type_record = types.MAP_DTYPE_TO_RECORD[str(ndarray.dtype)]
+ literal = xla_data_pb2.LiteralProto()
+ literal.shape.CopyFrom(xla_shape.CreateShapeFromNumpy(ndarray).message)
+
+ if ndarray.ndim == 0:
+ getattr(literal, type_record.literal_field_name).append(
+ np.asscalar(ndarray.astype(type_record.literal_field_type)))
+ else:
+ # Ndarrays with boolean dtypes need special type conversion with protobufs
+ if ndarray.dtype in {np.bool_, np.dtype('bool')}:
+ for element in np.nditer(ndarray):
+ getattr(literal, type_record.literal_field_name).append(
+ type_record.literal_field_type(element))
+ else:
+ ndarray_flat = ndarray.ravel(order='A')
+ getattr(literal, type_record.literal_field_name).extend(ndarray_flat)
+ return literal
+
+
+def ConvertNumpyArrayToLiteral(value):
+ """Converts a Numpy array or a nested tuple thereof to an XLA literal."""
+ if isinstance(value, tuple):
+ literal = xla_data_pb2.LiteralProto()
+ literal.shape.CopyFrom(xla_shape.CreateShapeFromNumpy(value).message)
+ for component in value:
+ component_literal = literal.tuple_literals.add()
+ component_literal.CopyFrom(ConvertNumpyArrayToLiteral(component))
+ return literal
+ else:
+ return _ConvertNumpyArrayToLiteral(value)
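A small round-trip sketch using the converters above, assuming the generated xla_data_pb2 bindings are importable; the arrays are illustrative:

    import numpy as np

    from tensorflow.compiler.xla.python_api import xla_literal

    arr = np.arange(6, dtype=np.float32).reshape(2, 3)
    literal = xla_literal.ConvertNumpyArrayToLiteral(arr)  # fills f32s, C layout
    back = xla_literal.ConvertLiteralToNumpyArray(literal)
    assert np.array_equal(arr, back)

    # Nested tuples of arrays become tuple literals and convert back to tuples.
    pair = xla_literal.ConvertNumpyArrayToLiteral((arr, np.array([True, False])))
    arr_back, preds_back = xla_literal.ConvertLiteralToNumpyArray(pair)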
diff --git a/tensorflow/compiler/xla/python_api/xla_shape.py b/tensorflow/compiler/xla/python_api/xla_shape.py
new file mode 100644
index 0000000000..6af2895803
--- /dev/null
+++ b/tensorflow/compiler/xla/python_api/xla_shape.py
@@ -0,0 +1,155 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ======================================
+"""XLA Shape utilities."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.compiler.xla import xla_data_pb2
+from tensorflow.compiler.xla.python_api import types
+
+
+class Shape(object):
+ """Wraps a xla_data_pb2.Shape message with a convenient Python type.
+
+ Provides direct access to the underlying xla_data_pb2.Shape message in the
+ message attribute, along with accessor wrappers to the message's fields.
+ Avoid direct access to .message unless interacting directly with protobuf APIs
+ like CopyFrom. In other words, prefer hauling the shape around in a Shape, and
+ only access .message when strictly required by the protobuf API.
+ """
+
+ def __init__(self, element_type, dimensions, layout=None):
+ """Creates a new XLA Shape.
+
+ Args:
+ element_type: element type from xla_data_pb2.
+ dimensions: sequence of dimension sizes (integers), or sequence
+ of Shapes in the case of a tuple, i.e. when element_type is
+ TUPLE.
+ layout: optional minor_to_major sequence for layout. If not given, the
+ default major-to-minor layout is used.
+
+ Raises:
+ ValueError: if element_type is TUPLE but dimensions are not Shape objects.
+ """
+ self.message = xla_data_pb2.Shape()
+ self.message.element_type = element_type
+ if element_type == xla_data_pb2.TUPLE:
+ if not all(isinstance(subshape, Shape) for subshape in dimensions):
+ raise ValueError(
+ 'XLA tuple requires sequence of Shape objects as dimensions')
+ self._tuple_shapes = tuple(dimensions)
+ for component_shape in self._tuple_shapes:
+ component_message = self.message.tuple_shapes.add()
+ component_message.CopyFrom(component_shape.message)
+ else:
+ self.message.dimensions.extend(dimensions)
+ if layout is None:
+ layout = list(reversed(range(len(dimensions))))
+ self.message.layout.format = xla_data_pb2.DENSE
+ self.message.layout.minor_to_major.extend(layout)
+
+ def element_type(self):
+ return self.message.element_type
+
+ def is_tuple(self):
+ return self.element_type() == xla_data_pb2.TUPLE
+
+ def dimensions(self):
+ if self.is_tuple():
+ raise ValueError('Tuple shape has no dimensions. Try tuple_shapes()?')
+ return self.message.dimensions
+
+ def tuple_shapes(self):
+ """If this is a tuple, returns its sequence of constituent Shape objects.
+
+ Returns:
+ Tuple sub-shapes.
+
+ Raises:
+ ValueError: if this is not a tuple.
+ """
+ if not self.is_tuple():
+ raise ValueError('tuple_shapes() called on a non-tuple shape')
+ return self._tuple_shapes
+
+ def layout(self):
+ return self.message.layout
+
+ @staticmethod
+ def from_pyval(pyval):
+ return CreateShapeFromNumpy(pyval)
+
+
+def _CreateShapeFromNumpy(ndarray): # pylint: disable=invalid-name
+ """Create a Shape from a given Numpy array.
+
+ Args:
+ ndarray: Numpy array.
+
+ Returns:
+ A Shape object.
+ """
+ element_type = types.MAP_DTYPE_TO_RECORD[str(ndarray.dtype)].primitive_type
+ dimensions = ndarray.shape
+
+ # Set the shape's layout based on the ordering of ndarray.
+ # Numpy arrays come in two orders: Fortran (column-major) and C (row-major).
+ if np.isfortran(ndarray):
+ # Column-major layout. This corresponds to a "dimension order is
+ # minor-to-major" layout in XLA.
+ layout = range(ndarray.ndim)
+ else:
+ # Row-major layout. This corresponds to a "dimension order is
+ # major-to-minor" layout int XLA.
+ layout = list(reversed(range(ndarray.ndim)))
+
+ return Shape(element_type, dimensions, layout)
+
+
+def CreateShapeFromNumpy(value): # pylint: disable=invalid-name
+ """Create a Shape from a Numpy array or a nested tuple structure thereof.
+
+ Args:
+ value: Numpy array or (possibly nested) tuple structure that bottoms out in
+ Numpy arrays.
+
+ Returns:
+ A Shape object.
+ """
+ if isinstance(value, tuple):
+ return Shape(
+ xla_data_pb2.TUPLE,
+ [CreateShapeFromNumpy(component) for component in value])
+ else:
+ return _CreateShapeFromNumpy(value)
+
+
+def CreateShapeFromDtypeAndTuple(dtype, shape_tuple): # pylint: disable=invalid-name
+ """Create a shape from a Numpy dtype and a sequence of nonnegative integers.
+
+ Args:
+ dtype: a numpy dtype, e.g. np.dtype('int32').
+ shape_tuple: a sequence of nonnegative integers.
+
+ Returns:
+ A Shape object.
+ """
+ element_type = types.MAP_DTYPE_TO_RECORD[str(dtype)].primitive_type
+ return Shape(element_type, shape_tuple)
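A brief sketch of the shape helpers above; the dtypes and dimensions are illustrative:

    import numpy as np

    from tensorflow.compiler.xla.python_api import xla_shape

    s = xla_shape.CreateShapeFromNumpy(np.zeros((2, 3), dtype=np.float32))
    assert list(s.dimensions()) == [2, 3]
    assert list(s.layout().minor_to_major) == [1, 0]  # row-major ndarray

    t = xla_shape.CreateShapeFromNumpy(
        (np.zeros((2,), np.int32), np.ones((4, 4), np.float64)))
    assert t.is_tuple() and len(t.tuple_shapes()) == 2

    u = xla_shape.CreateShapeFromDtypeAndTuple(np.dtype('int32'), (1, 8))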
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 2942edbf71..8a1d1bf73d 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -1101,6 +1101,7 @@ tf_cc_test(
srcs = ["hlo_scheduling_test.cc"],
deps = [
":buffer_value",
+ ":heap_simulator",
":hlo",
":hlo_ordering",
":hlo_scheduling",
@@ -2123,6 +2124,7 @@ cc_library(
":buffer_liveness",
":buffer_value",
":call_graph",
+ ":copy_insertion",
":flatten_call_graph",
":hlo",
":hlo_dce",
@@ -2130,6 +2132,7 @@ cc_library(
":hlo_scheduling",
":logical_buffer",
":tuple_points_to_analysis",
+ ":tuple_simplifier",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
@@ -2143,6 +2146,7 @@ tf_cc_test(
name = "hlo_rematerialization_test",
srcs = ["hlo_rematerialization_test.cc"],
deps = [
+ ":flatten_call_graph",
":hlo",
":hlo_matchers",
":hlo_ordering",
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index 3b36939b8a..1fc8fb9b69 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -449,7 +449,7 @@ Status AlgebraicSimplifierVisitor::HandleConcatenate(
// Filter out and remove empty operands.
std::vector<HloInstruction*> nonempty_operands;
for (HloInstruction* operand : operands) {
- if (!ShapeUtil::HasZeroElements(operand->shape())) {
+ if (!ShapeUtil::IsZeroElementArray(operand->shape())) {
nonempty_operands.push_back(operand);
}
}
@@ -1058,9 +1058,9 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) {
}
// Replace a zero element dot with a broadcast of the constant 0.
- if (ShapeUtil::HasZeroElements(dot->shape()) ||
- ShapeUtil::HasZeroElements(lhs->shape()) ||
- ShapeUtil::HasZeroElements(rhs->shape())) {
+ if (ShapeUtil::IsZeroElementArray(dot->shape()) ||
+ ShapeUtil::IsZeroElementArray(lhs->shape()) ||
+ ShapeUtil::IsZeroElementArray(rhs->shape())) {
auto zero = computation_->AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0(0.0f)));
return ReplaceWithNewInstruction(
@@ -1392,7 +1392,7 @@ Status AlgebraicSimplifierVisitor::HandleImag(HloInstruction* imag) {
}
Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) {
- if (ShapeUtil::HasZeroElements(pad->operand(0)->shape())) {
+ if (ShapeUtil::IsZeroElementArray(pad->operand(0)->shape())) {
return ReplaceWithNewInstruction(
pad, HloInstruction::CreateBroadcast(pad->shape(),
pad->mutable_operand(1), {}));
@@ -1638,7 +1638,7 @@ Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) {
// Reshape directly to empty constant if the shape contains zero-element
// dimension.
- if (ShapeUtil::HasZeroElements(reshape->shape())) {
+ if (ShapeUtil::IsZeroElementArray(reshape->shape())) {
auto empty_constant = HloInstruction::CreateConstant(
Literal::CreateFromShape(reshape->shape()));
@@ -1739,7 +1739,7 @@ Status AlgebraicSimplifierVisitor::HandleDynamicUpdateSlice(
// If any dimension of update is 0, elide the DynamicUpdateSlice. This
// optimization becomes invalid should we later prefer to warn about out of
// bound indices.
- if (ShapeUtil::HasZeroElements(update->shape())) {
+ if (ShapeUtil::IsZeroElementArray(update->shape())) {
return ReplaceInstruction(dynamic_update_slice,
dynamic_update_slice->mutable_operand(0));
}
@@ -1751,8 +1751,8 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) {
auto init_value = reduce->mutable_operand(1);
tensorflow::gtl::ArraySlice<int64> dimensions(reduce->dimensions());
HloComputation* function = reduce->to_apply();
- if (ShapeUtil::HasZeroElements(arg->shape()) ||
- ShapeUtil::HasZeroElements(reduce->shape())) {
+ if (ShapeUtil::IsZeroElementArray(arg->shape()) ||
+ ShapeUtil::IsZeroElementArray(reduce->shape())) {
return ReplaceWithNewInstruction(
reduce,
HloInstruction::CreateBroadcast(reduce->shape(), init_value, {}));
@@ -1863,7 +1863,7 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) {
Status AlgebraicSimplifierVisitor::HandleReduceWindow(
HloInstruction* reduce_window) {
- if (ShapeUtil::HasZeroElements(reduce_window->operand(0)->shape())) {
+ if (ShapeUtil::IsZeroElementArray(reduce_window->operand(0)->shape())) {
return ReplaceWithNewInstruction(
reduce_window,
HloInstruction::CreateBroadcast(reduce_window->shape(),
@@ -2059,8 +2059,8 @@ Status AlgebraicSimplifierVisitor::HandleConvolution(
HloInstruction* convolution) {
auto lhs = convolution->mutable_operand(0);
auto rhs = convolution->mutable_operand(1);
- if (ShapeUtil::HasZeroElements(lhs->shape()) ||
- ShapeUtil::HasZeroElements(rhs->shape())) {
+ if (ShapeUtil::IsZeroElementArray(lhs->shape()) ||
+ ShapeUtil::IsZeroElementArray(rhs->shape())) {
return ReplaceWithNewInstruction(
convolution,
HloInstruction::CreateBroadcast(
diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.cc b/tensorflow/compiler/xla/service/bfloat16_propagation.cc
index ed0746980f..8f1d2f0804 100644
--- a/tensorflow/compiler/xla/service/bfloat16_propagation.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_propagation.cc
@@ -631,7 +631,7 @@ Status BFloat16Propagation::ResolveInconsistentFusions(HloModule* module) {
subshape, converted_outputs.element(parent_index),
output_index.back()));
}
- if (ShapeUtil::IsTuple(subshape)) {
+ if (!ShapeUtil::IsArray(subshape)) {
continue;
}
if (!ShapeUtil::Compatible(
diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc
index 5d3b0cb333..afe4b2e142 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment.cc
@@ -631,8 +631,9 @@ Status BufferAssignment::ComputeSummaryStats() {
}
}
if (module_sequence.size() == module_->computation_count()) {
- TF_ASSIGN_OR_RETURN(const int64 min_size,
- MinimumMemoryForModule(module_sequence, buffer_size_));
+ TF_ASSIGN_OR_RETURN(
+ const int64 min_size,
+ HeapSimulator::MinimumMemoryForModule(module_sequence, buffer_size_));
stats_.total_fragmentation_bytes = stats_.total_allocation_bytes - min_size;
}
diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc
index 33d8338809..e0ce2e3555 100644
--- a/tensorflow/compiler/xla/service/copy_insertion.cc
+++ b/tensorflow/compiler/xla/service/copy_insertion.cc
@@ -472,6 +472,10 @@ class CopyRemover {
// between copies added around aliased operations (kWhile) guarantees
// this strict order.
for (const HloValue* value_a : buffer.values()) {
+ if (ShapeUtil::IsToken(value_a->shape())) {
+ // Token values have no representation and cannot interfere.
+ continue;
+ }
for (const HloValue* value_b : buffer.values()) {
if (value_a != value_b) {
DCHECK(ordering_.LiveRangeStrictlyBefore(*value_a, *value_b,
@@ -613,7 +617,10 @@ class CopyRemover {
VLOG(2) << copy->name() << " is not removable";
return false;
}
-
+ if (!ShapeUtil::Equal(copy->shape(), copy->operand(0)->shape())) {
+ VLOG(2) << copy->name() << " is not removable (shape mismatch)";
+ return false;
+ }
const CopyNodes& copy_node = copy_map_.at(copy);
ValueNode* src = copy_node.src;
ValueNode* dest = copy_node.dest;
@@ -947,28 +954,6 @@ class CopyRemover {
BufferValueTracker buffer_value_tracker_;
};
-// Try to remove as many copies from the module as possible without introducing
-// live range interference. Copy instructions (identified by their unique id) in
-// the set copies_to_exclude are not considered for removal.
-Status RemoveUnnecessaryCopies(
- const HloOrdering& ordering,
- const tensorflow::gtl::FlatSet<int>& copies_to_exclude, HloModule* module) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
- HloAliasAnalysis::Run(module));
- CopyRemover copy_remover(*alias_analysis, ordering, module);
- XLA_VLOG_LINES(3, copy_remover.ToString());
-
- for (HloComputation* computation : module->computations()) {
- for (HloInstruction* instruction : computation->instructions()) {
- if (instruction->opcode() == HloOpcode::kCopy &&
- !ContainsKey(copies_to_exclude, instruction->unique_id())) {
- TF_RETURN_IF_ERROR(copy_remover.TryElideCopy(instruction).status());
- }
- }
- }
- return Status::OK();
-}
-
// Add copies to address special constraints on the roots of computations not
// related to live range interference:
//
@@ -1065,13 +1050,23 @@ Status AddSpecialCaseCopies(const CallGraph& call_graph, HloModule* module) {
HloInstruction* instruction = pair.first;
const ShapeTree<bool>& indices_to_copy = pair.second;
+ ShapeTree<HloInstruction*> copies_added(indices_to_copy.shape());
std::vector<HloInstruction*> users = instruction->users();
TF_ASSIGN_OR_RETURN(HloInstruction * deep_copy,
instruction->parent()->DeepCopyInstruction(
- instruction, &indices_to_copy));
+ instruction, &indices_to_copy, &copies_added));
for (HloInstruction* user : users) {
TF_RETURN_IF_ERROR(instruction->ReplaceUseWith(user, deep_copy));
}
+ // Special case copies are not eligible for later copy elision passes.
+ indices_to_copy.ForEachElement([&](const ShapeIndex& index, bool has_copy) {
+ if (has_copy) {
+ HloInstruction* copy = *copies_added.mutable_element(index);
+ if (copy != nullptr) {
+ copy->SetCopyElisionAllowed(false);
+ }
+ }
+ });
if (instruction == instruction->parent()->root_instruction()) {
instruction->parent()->set_root_instruction(deep_copy);
}
@@ -1097,6 +1092,31 @@ void MaybeDumpModule(const string& message, const HloModule& module) {
} // namespace
+Status RemoveUnnecessaryCopies(
+ const HloOrdering& ordering,
+ const tensorflow::gtl::FlatSet<int>& copies_to_exclude, HloModule* module) {
+ MaybeDumpModule("after adding copies to resolve interference", *module);
+
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
+ HloAliasAnalysis::Run(module));
+ CopyRemover copy_remover(*alias_analysis, ordering, module);
+ XLA_VLOG_LINES(3, copy_remover.ToString());
+
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
+ for (HloComputation* computation : module->computations()) {
+ for (HloInstruction* instruction : computation->instructions()) {
+ if (instruction->opcode() == HloOpcode::kCopy &&
+ !ContainsKey(copies_to_exclude, instruction->unique_id()) &&
+ instruction->CopyElisionAllowed()) {
+ TF_RETURN_IF_ERROR(copy_remover.TryElideCopy(instruction).status());
+ }
+ }
+ }
+ MaybeDumpModule("after removing unnecessary copies", *module);
+
+ return Status::OK();
+}
+
StatusOr<bool> CopyInsertion::Run(HloModule* module) {
// Copy insertion is performed in three steps:
//
@@ -1158,14 +1178,10 @@ StatusOr<bool> CopyInsertion::Run(HloModule* module) {
TF_DCHECK_OK(VerifyNoLiveRangeInterference(module));
- MaybeDumpModule("after adding copies to resolve interference", *module);
-
DependencyHloOrdering ordering(module);
TF_RETURN_IF_ERROR(
RemoveUnnecessaryCopies(ordering, existing_copies, module));
- MaybeDumpModule("after removing unnecessary copies", *module);
-
TF_RETURN_IF_ERROR(AddSpecialCaseCopies(*call_graph, module));
MaybeDumpModule("after adding special-case copies", *module);
diff --git a/tensorflow/compiler/xla/service/copy_insertion.h b/tensorflow/compiler/xla/service/copy_insertion.h
index 65e3d31e34..0d7b3c20f9 100644
--- a/tensorflow/compiler/xla/service/copy_insertion.h
+++ b/tensorflow/compiler/xla/service/copy_insertion.h
@@ -64,6 +64,13 @@ class CopyInsertion : public HloPassInterface {
static StatusOr<bool> AddCopiesForBufferAssignment(HloModule* module);
};
+// Try to remove as many copies from the module as possible without introducing
+// live range interference. Copy instructions (identified by their unique id) in
+// the set copies_to_exclude are not considered for removal.
+Status RemoveUnnecessaryCopies(
+ const HloOrdering& ordering,
+ const tensorflow::gtl::FlatSet<int>& copies_to_exclude, HloModule* module);
+
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_COPY_INSERTION_H_
diff --git a/tensorflow/compiler/xla/service/copy_insertion_test.cc b/tensorflow/compiler/xla/service/copy_insertion_test.cc
index 684fff8a6f..ed1a50f516 100644
--- a/tensorflow/compiler/xla/service/copy_insertion_test.cc
+++ b/tensorflow/compiler/xla/service/copy_insertion_test.cc
@@ -1595,6 +1595,45 @@ TEST_F(CopyInsertionTest, WhileBodyWithConstantRoot) {
EXPECT_THAT(condition->root_instruction(), op::Constant());
}
+TEST_F(CopyInsertionTest, TokensShouldNotBeCopied) {
+ string module_string = R"(
+HloModule TokensShouldNotBeCopied
+
+%Body (param.1: (s32[], token[])) -> (s32[], token[]) {
+ %param.1 = (s32[], token[]) parameter(0)
+ %get-tuple-element.1 = s32[] get-tuple-element((s32[], token[]) %param.1), index=0
+ %constant.1 = s32[] constant(1)
+ %add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1)
+ %get-tuple-element.2 = token[] get-tuple-element((s32[], token[]) %param.1), index=1
+ %generate-token = token[] generate-token(token[] %get-tuple-element.2)
+ ROOT %tuple = (s32[], token[]) tuple(s32[] %add, token[] %generate-token)
+}
+
+%Cond (param: (s32[], token[])) -> pred[] {
+ %param = (s32[], token[]) parameter(0)
+ %get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0
+ %constant = s32[] constant(42)
+ ROOT %less-than = pred[] less-than(s32[] %get-tuple-element, s32[] %constant)
+}
+
+ENTRY %TokensShouldNotBeCopied () -> s32[] {
+ %one = s32[] constant(1)
+ %negative_one = s32[] negate(%one)
+ %init_token = token[] generate-token()
+ %init_tuple = (s32[], token[]) tuple(s32[] %negative_one, token[] %init_token)
+ %while = (s32[], token[]) while((s32[], token[]) %init_tuple), condition=%Cond, body=%Body
+ ROOT %root = s32[] get-tuple-element((s32[], token[]) %while), index=0
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ HloRunner::CreateModuleFromString(
+ module_string, GetDebugOptionsForTest()));
+ InsertCopies(module.get());
+
+ // There should be no copies added because tokens should not be copied.
+ EXPECT_EQ(CountCopies(*module), 0);
+}
+
std::unique_ptr<HloComputation> MakeTrivialCondition(const Shape& shape) {
auto builder = HloComputation::Builder("trivial_condition");
builder.AddInstruction(
diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
index 8eb39d615f..e8b205051e 100644
--- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
@@ -1627,8 +1627,8 @@ bool PotentiallyImplementedAsEigenDot(
const Shape& lhs_shape = hlo.operand(0)->shape();
const Shape& rhs_shape = hlo.operand(1)->shape();
- if (ShapeUtil::HasZeroElements(lhs_shape) ||
- ShapeUtil::HasZeroElements(rhs_shape)) {
+ if (ShapeUtil::IsZeroElementArray(lhs_shape) ||
+ ShapeUtil::IsZeroElementArray(rhs_shape)) {
return false;
}
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc
index b560b7531c..1a8bedfe6a 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc
@@ -64,8 +64,8 @@ bool PotentiallyImplementedAsEigenConvolution(
return false;
}
- if (ShapeUtil::HasZeroElements(input_shape) ||
- ShapeUtil::HasZeroElements(kernel_shape)) {
+ if (ShapeUtil::IsZeroElementArray(input_shape) ||
+ ShapeUtil::IsZeroElementArray(kernel_shape)) {
return false;
}
// Make sure input and kernel has the same data type.
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index 59223fddac..758b8c62b4 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -226,10 +226,13 @@ Status IrEmitter::HandleCopy(HloInstruction* copy) {
// kCopy shallow copies a tuple so just memcpy the top-level buffer.
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(copy));
return EmitMemcpy(*(copy->operand(0)), *copy);
- } else {
- // Use the elemental emitter for non-tuple shapes.
+ } else if (ShapeUtil::IsArray(copy->shape())) {
+ // Use the elemental emitter for array shapes.
return DefaultAction(copy);
}
+ return Unimplemented(
+ "unsupported operand type %s for copy instruction",
+ PrimitiveType_Name(copy->shape().element_type()).c_str());
}
// Calculate the alignment of a buffer allocated for a given primitive type.
@@ -1873,7 +1876,7 @@ Status IrEmitter::HandleSlice(HloInstruction* slice) {
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(slice));
- if (ShapeUtil::HasZeroElements(slice->shape())) {
+ if (ShapeUtil::IsZeroElementArray(slice->shape())) {
return Status::OK();
}
@@ -2528,6 +2531,13 @@ Status IrEmitter::HandleConditional(HloInstruction* conditional) {
return Status::OK();
}
+Status IrEmitter::HandleGenerateToken(HloInstruction* gen_token) {
+ TF_RET_CHECK(ByteSizeOf(gen_token->shape()) == 0);
+ // No code to generate, but we need to emit an address for book-keeping.
+ TF_RETURN_IF_ERROR(EmitTargetAddressForOp(gen_token));
+ return Status::OK();
+}
+
Status IrEmitter::FinishVisit(HloInstruction* root) {
// When this method is called, we should have already emitted an IR value for
// the root (return) op. The IR value holds the address of the buffer holding
@@ -2809,7 +2819,10 @@ Status IrEmitter::EmitTargetAddressForOp(const HloInstruction* op) {
// For the root node, we write directly to the output buffer of the
// function.
llvm::Argument* retval = compute_function_->result_arg();
- if (!ShapeUtil::IsNil(target_shape)) {
+ if ((ShapeUtil::IsArray(target_shape) &&
+ !ShapeUtil::IsZeroElementArray(target_shape)) ||
+ (ShapeUtil::IsTuple(target_shape) &&
+ !ShapeUtil::IsEmptyTuple(target_shape))) {
llvm::AttrBuilder attr_builder;
attr_builder.addAlignmentAttr(MinimumAlignmentForShape(target_shape));
attr_builder.addDereferenceableAttr(ByteSizeOf(target_shape));
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
index 32c536e18f..e1815c1db7 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
@@ -150,6 +150,7 @@ class IrEmitter : public DfsHloVisitorWithDefault {
Status HandleWhile(HloInstruction* xla_while) override;
Status HandleConcatenate(HloInstruction* concatenate) override;
Status HandleConditional(HloInstruction* conditional) override;
+ Status HandleGenerateToken(HloInstruction* gen_token) override;
Status FinishVisit(HloInstruction* root) override;
Status Preprocess(HloInstruction* hlo) override;
diff --git a/tensorflow/compiler/xla/service/gather_expander.cc b/tensorflow/compiler/xla/service/gather_expander.cc
index 2d3e4b1fcd..7cd2c9c136 100644
--- a/tensorflow/compiler/xla/service/gather_expander.cc
+++ b/tensorflow/compiler/xla/service/gather_expander.cc
@@ -300,7 +300,7 @@ static StatusOr<HloInstruction*> PermuteGatherAndWindowDims(
StatusOr<HloInstruction*> GatherExpander::ExpandGather(
HloInstruction* gather_instr) {
- CHECK(!ShapeUtil::HasZeroElements(gather_instr->shape()));
+ CHECK(!ShapeUtil::IsZeroElementArray(gather_instr->shape()));
HloComputation* computation = gather_instr->parent();
HloInstruction* operand = gather_instr->mutable_operand(0);
@@ -369,7 +369,7 @@ StatusOr<bool> GatherExpander::Run(HloModule* module) {
return inst->opcode() == HloOpcode::kGather &&
// Avoid expanding gather ops that produce zero sized tensors,
// instead punt these to ZeroSizedHloElimination.
- !ShapeUtil::HasZeroElements(inst->shape());
+ !ShapeUtil::IsZeroElementArray(inst->shape());
};
std::vector<HloInstruction*> gather_instrs;
diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.cc b/tensorflow/compiler/xla/service/generic_transfer_manager.cc
index 5ee67ccb4a..d9f62c21c4 100644
--- a/tensorflow/compiler/xla/service/generic_transfer_manager.cc
+++ b/tensorflow/compiler/xla/service/generic_transfer_manager.cc
@@ -74,7 +74,7 @@ GenericTransferManager::TransferLiteralFromDevice(
TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
device_buffer.on_host_shape(),
[&](const Shape& subshape, const ShapeIndex& index) -> Status {
- if (!ShapeUtil::IsTuple(subshape)) {
+ if (ShapeUtil::IsArray(subshape)) {
TF_RETURN_IF_ERROR(TransferBufferFromDevice(
executor,
/*source=*/device_buffer.buffer(index),
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index 5e02631a58..541a5275a3 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -238,6 +238,19 @@ cc_library(
)
cc_library(
+ name = "hlo_execution_profiler",
+ srcs = ["hlo_execution_profiler.cc"],
+ hdrs = ["hlo_execution_profiler.h"],
+ deps = [
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:hlo_execution_profile",
+ "//tensorflow/compiler/xla/service:pool",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:stream_executor_no_cuda",
+ ],
+)
+
+cc_library(
name = "gpu_executable",
srcs = [
"conditional_thunk.cc",
@@ -278,6 +291,7 @@ cc_library(
":backend_configs",
":buffer_allocations",
":cudnn_convolution_runner",
+ ":hlo_execution_profiler",
":infeed_manager",
":ir_emission_utils",
":partition_assignment",
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc
index e0c73aa73a..f9dccd287d 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc
@@ -42,8 +42,8 @@ bool CanImplementAsCudnnForwardConv(HloInstruction* conv) {
}
// CuDNN does not accept zero-element arguments
- if (ShapeUtil::HasZeroElements(conv->operand(0)->shape()) ||
- ShapeUtil::HasZeroElements(conv->operand(1)->shape())) {
+ if (ShapeUtil::IsZeroElementArray(conv->operand(0)->shape()) ||
+ ShapeUtil::IsZeroElementArray(conv->operand(1)->shape())) {
return false;
}
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
index afefc740d7..9d66648a40 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
@@ -260,6 +260,8 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
fusion.AddPass<GpuInstructionFusion>(/*may_duplicate=*/true);
fusion.AddPass<FusionMerger>();
fusion.AddPass<GpuMultiOutputFusion>();
+ fusion.AddPass<HloCSE>(/*is_layout_sensitive=*/true,
+ /*only_fusion_computations=*/true);
TF_RETURN_IF_ERROR(fusion.Run(hlo_module).status());
HloPassPipeline reduce_pipeline("reduce-precision");
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
index 25d8f720ea..f20a828bc1 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
@@ -22,7 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
-#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
@@ -41,77 +41,6 @@ namespace {
using tensorflow::tracing::ScopedAnnotation;
-// A helper class for profiling HLO in the course of GPU program execution.
-// All of the profiling is guarded internally, to avoid the caller needing to
-// have lots of conditionals sprinkled around.
-class HloExecutionProfiler {
- public:
- // If profiling is enabled, start an execution timer running.
- explicit HloExecutionProfiler(
- bool do_profile, HloExecutionProfile* profile, se::Stream* stream,
- const std::vector<Pool<se::Stream>::SmartPtr>& sub_streams,
- const HloComputation* computation)
- : do_profile_(do_profile),
- profile_(profile),
- stream_(stream),
- sub_streams_(sub_streams),
- computation_(computation) {
- if (do_profile_) {
- clock_rate_ghz_ =
- stream->parent()->GetDeviceDescription().clock_rate_ghz();
- execution_timer_.reset(new se::Timer(stream->parent()));
- per_op_timer_.reset(new se::Timer(stream->parent()));
- stream->InitTimer(execution_timer_.get())
- .ThenStartTimer(execution_timer_.get());
- stream->InitTimer(per_op_timer_.get());
- }
- }
-
- // If profiling is enabled, sets the total cycle count on the profile from the
- // execution timer.
- void FinishExecution() {
- CHECK(!finished_execution_) << "Call FinishExecution only once!";
- finished_execution_ = true;
- if (do_profile_) {
- stream_->ThenWaitFor(&sub_streams_);
- stream_->ThenStopTimer(execution_timer_.get());
- stream_->BlockHostUntilDone().IgnoreError();
- profile_->set_total_cycles_executed(
- *computation_, execution_timer_->Nanoseconds() * clock_rate_ghz_);
- }
- }
-
- // If profiling is enabled, starts the per-operation timer.
- void StartOperation() {
- if (do_profile_) {
- stream_->ThenStartTimer(per_op_timer_.get());
- }
- }
-
- // If profiling is enabled, stops the per-operation timer and records the time
- // that the hlo_instruction took to execute in the profile.
- void FinishOperation(const HloInstruction* hlo_instruction) {
- if (do_profile_) {
- stream_->ThenWaitFor(&sub_streams_);
- stream_->ThenStopTimer(per_op_timer_.get());
- stream_->BlockHostUntilDone().IgnoreError();
- profile_->SetCyclesTakenBy(
- hlo_instruction, per_op_timer_->Nanoseconds() * clock_rate_ghz_);
- }
- }
-
- private:
- const bool do_profile_;
- double clock_rate_ghz_;
- HloExecutionProfile* profile_;
- se::Stream* stream_;
- const std::vector<Pool<se::Stream>::SmartPtr>& sub_streams_;
- const HloComputation* computation_;
- std::unique_ptr<se::Timer> execution_timer_;
- std::unique_ptr<se::Timer> per_op_timer_;
- bool finished_execution_ = false;
-};
-
} // namespace
// Implementation note: HLO profiling is always enabled for GPU executables,
diff --git a/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc b/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc
new file mode 100644
index 0000000000..daddd3738e
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc
@@ -0,0 +1,82 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
+
+#include <memory>
+#include <vector>
+
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_execution_profile.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/pool.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+
+namespace xla {
+namespace gpu {
+
+HloExecutionProfiler::HloExecutionProfiler(
+ bool do_profile, HloExecutionProfile* profile, se::Stream* stream,
+ const std::vector<Pool<se::Stream>::SmartPtr>& sub_streams,
+ const HloComputation* computation)
+ : do_profile_(do_profile),
+ profile_(profile),
+ stream_(stream),
+ sub_streams_(sub_streams),
+ computation_(computation) {
+ if (do_profile_) {
+ clock_rate_ghz_ = stream->parent()->GetDeviceDescription().clock_rate_ghz();
+ execution_timer_.reset(new se::Timer(stream->parent()));
+ per_op_timer_.reset(new se::Timer(stream->parent()));
+ stream->InitTimer(execution_timer_.get())
+ .ThenStartTimer(execution_timer_.get());
+ stream->InitTimer(per_op_timer_.get());
+ }
+}
+
+void HloExecutionProfiler::FinishExecution() {
+ CHECK(!finished_execution_) << "Call FinishExecution only once!";
+ finished_execution_ = true;
+ if (do_profile_) {
+ stream_->ThenWaitFor(&sub_streams_);
+ stream_->ThenStopTimer(execution_timer_.get());
+ stream_->BlockHostUntilDone().IgnoreError();
+ profile_->set_total_cycles_executed(
+ *computation_,
+ static_cast<uint64>(execution_timer_->Nanoseconds() * clock_rate_ghz_));
+ }
+}
+
+void HloExecutionProfiler::StartOperation() {
+ if (do_profile_) {
+ stream_->ThenStartTimer(per_op_timer_.get());
+ }
+}
+
+void HloExecutionProfiler::FinishOperation(
+ const HloInstruction* hlo_instruction) {
+ if (do_profile_) {
+ stream_->ThenWaitFor(&sub_streams_);
+ stream_->ThenStopTimer(per_op_timer_.get());
+ stream_->BlockHostUntilDone().IgnoreError();
+ profile_->SetCyclesTakenBy(
+ hlo_instruction,
+ static_cast<uint64>(per_op_timer_->Nanoseconds() * clock_rate_ghz_));
+ }
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h b/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h
new file mode 100644
index 0000000000..c9b882ff80
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h
@@ -0,0 +1,68 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HLO_EXECUTION_PROFILER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HLO_EXECUTION_PROFILER_H_
+
+#include <memory>
+#include <vector>
+
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_execution_profile.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/pool.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+
+namespace xla {
+namespace gpu {
+
+// A helper class for profiling HLO in the course of GPU program execution.
+// All of the profiling is guarded internally, to avoid the caller needing to
+// have lots of conditionals sprinkled around.
+class HloExecutionProfiler {
+ public:
+ // If profiling is enabled, start an execution timer running.
+ explicit HloExecutionProfiler(
+ bool do_profile, HloExecutionProfile* profile, se::Stream* stream,
+ const std::vector<Pool<se::Stream>::SmartPtr>& sub_streams,
+ const HloComputation* computation);
+
+ // If profiling is enabled, sets the total cycle count on the profile from the
+ // execution timer.
+ void FinishExecution();
+
+ // If profiling is enabled, starts the per-operation timer.
+ void StartOperation();
+
+ // If profiling is enabled, stops the per-operation timer and records the time
+ // that the hlo_instruction took to execute in the profile.
+ void FinishOperation(const HloInstruction* hlo_instruction);
+
+ private:
+ const bool do_profile_;
+ double clock_rate_ghz_;
+ HloExecutionProfile* profile_;
+ se::Stream* stream_;
+ const std::vector<Pool<se::Stream>::SmartPtr>& sub_streams_;
+ const HloComputation* computation_;
+ std::unique_ptr<se::Timer> execution_timer_;
+ std::unique_ptr<se::Timer> per_op_timer_;
+ bool finished_execution_ = false;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HLO_EXECUTION_PROFILER_H_
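As a rough illustration of the guard-everything-internally contract in the class comment, the toy class below mirrors the calling pattern a GPU executable would use (StartOperation/FinishOperation around each op, one FinishExecution at the end). It is a sketch with stand-in types and hypothetical op names, not the real StreamExecutor-backed implementation:

```cpp
#include <iostream>
#include <string>

// Toy stand-in mirroring HloExecutionProfiler's pattern: every call is a
// no-op unless do_profile is true, so the caller needs no conditionals.
class ToyProfiler {
 public:
  explicit ToyProfiler(bool do_profile) : do_profile_(do_profile) {}
  void StartOperation() { /* the real class starts the per-op GPU timer */ }
  void FinishOperation(const std::string& op) {
    if (do_profile_) std::cout << "recorded cycles for " << op << "\n";
  }
  void FinishExecution() {
    if (do_profile_) std::cout << "recorded total cycles\n";
  }

 private:
  const bool do_profile_;
};

int main() {
  ToyProfiler profiler(/*do_profile=*/true);
  for (const std::string& op : {"fusion.1", "copy.2"}) {  // hypothetical ops
    profiler.StartOperation();
    // ... launch the kernel for `op` here ...
    profiler.FinishOperation(op);
  }
  profiler.FinishExecution();
  return 0;
}
```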
diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc
index 061210352c..e303999c63 100644
--- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc
+++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc
@@ -202,7 +202,7 @@ llvm_ir::IrArray HloToIrBindings::GetIrArray(const HloInstruction& hlo,
<< " of " << hlo.ToString();
llvm_ir::IrArray ir_array(base_ptr,
ShapeUtil::GetSubshape(hlo.shape(), shape_index));
- alias_analysis_.AddAliasingInformationToIrArray(hlo, &ir_array);
+ alias_analysis_.AddAliasingInformationToIrArray(hlo, &ir_array, shape_index);
// The GPU backend emits one kernel per top-level HLO, and LLVM views
// execution of one kernel as the "whole program" executed on the GPU.
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
index 67890bfed1..388aa35d7d 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
@@ -56,8 +56,8 @@ bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape,
return type_is_allowed && IsRank2WithNoPadding(lhs_shape) &&
IsRank2WithNoPadding(rhs_shape) &&
IsRank2WithNoPadding(output_shape) &&
- !ShapeUtil::HasZeroElements(lhs_shape) &&
- !ShapeUtil::HasZeroElements(rhs_shape);
+ !ShapeUtil::IsZeroElementArray(lhs_shape) &&
+ !ShapeUtil::IsZeroElementArray(rhs_shape);
}
bool DotImplementedAsGemm(const HloInstruction& dot) {
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
index 547af33e9a..7b7dd673a5 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
@@ -610,7 +610,7 @@ Status IrEmitter::HandleDot(HloInstruction* dot) {
}
Status IrEmitter::HandleConvolution(HloInstruction* convolution) {
- if (ShapeUtil::HasZeroElements(convolution->shape())) {
+ if (ShapeUtil::IsZeroElementArray(convolution->shape())) {
// Emit no code for an empty output.
return Status::OK();
}
@@ -620,7 +620,7 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) {
}
Status IrEmitter::HandleFft(HloInstruction* fft) {
- if (ShapeUtil::HasZeroElements(fft->shape())) {
+ if (ShapeUtil::IsZeroElementArray(fft->shape())) {
// Emit no code for an empty output.
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index 726434c3df..ccbd99a042 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -1493,21 +1493,21 @@ Status IrEmitterUnnested::EmitRowReduction(
// x + (x_tile_size - 1) * warpSize < width) {
// // The entire x_tile is in bounds.
// for (int element_id_in_z_tile = 0; element_id_in_z_tile < z_tile_size;
- // ++element_id_in_z_tile) {
+ // ++element_id_in_z_tile) {
// z = z_in_tiles * z_tile_size + element_id_in_z_tile;
- // for (int element_id_in_x_tile = 0;element_id_in_x_tile < x_tile_size;
- // ++element_id_in_x_tile, x += warpSize) {
+ // for (int element_id_in_x_tile = 0;
+ // element_id_in_x_tile < x_tile_size;
+ // ++element_id_in_x_tile, x += warpSize) {
// partial_result = Reducer(partial_result, input[z][y][x]);
// }
// }
// } else {
// // The tile is partially in bounds.
// for (int element_id_in_z_tile = 0; element_id_in_z_tile < z_tile_size;
- // ++element_id_in_z_tile) {
+ // ++element_id_in_z_tile) {
// z = z_in_tiles * z_tile_size + element_id_in_z_tile;
// for (int element_id_in_x_tile = 0; element_id_in_x_tile <
- // x_tile_size;
- // ++element_id_in_tile, x += warpSize) {
+ // x_tile_size; ++element_id_in_tile, x += warpSize) {
// if (x < width)
// partial_result = Reducer(partial_result, input[z][y][x]);
// }
@@ -1558,8 +1558,7 @@ Status IrEmitterUnnested::EmitRowReduction(
x_tile, ir_builder_.getInt64(kWarpSize), "lane_id");
// The x-location of the last element in this z-x-tile.
- // last_x = lane_id + warpSize * (x_tile_size - 1 + warp_id *
- // x_tile_size);
+ // last_x = lane_id + warpSize * (x_tile_size - 1 + warp_id * x_tile_size);
llvm::Value* last_x = ir_builder_.CreateNSWAdd(
lane_id, ir_builder_.CreateNSWMul(
ir_builder_.getInt64(kWarpSize),
@@ -1586,8 +1585,8 @@ Status IrEmitterUnnested::EmitRowReduction(
"x_tile",
/*start=*/0, /*end=*/x_tile_loop_bound, /*step=*/1,
[&](llvm::Value* x_indvar) -> Status {
- // x = lane_id + warpSize * (element_id_in_x_tile + warp_id *
- // x_tile_size);
+ // x = lane_id +
+ // warpSize * (element_id_in_x_tile + warp_id * x_tile_size);
llvm::Value* x = ir_builder_.CreateNSWAdd(
lane_id,
ir_builder_.CreateNSWMul(
@@ -2206,6 +2205,10 @@ Status IrEmitterUnnested::HandleCrossReplicaSum(HloInstruction* crs) {
return Status::OK();
}
+Status IrEmitterUnnested::HandleGenerateToken(HloInstruction* gen_token) {
+ return Status::OK();
+}
+
Status IrEmitterUnnested::HandleInfeed(HloInstruction* infeed) {
thunk_sequence_->emplace_back(BuildInfeedThunk(infeed));
return Status::OK();
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
index 202231b82f..d228be81d4 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
@@ -77,6 +77,7 @@ class IrEmitterUnnested : public IrEmitter {
Status HandleRng(HloInstruction* random) override;
Status HandleSelect(HloInstruction* select) override;
Status HandleCrossReplicaSum(HloInstruction* crs) override;
+ Status HandleGenerateToken(HloInstruction* gen_token) override;
Status EmitTargetElementLoop(
const HloInstruction& hlo,
diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc
index e3f444a126..d541776f00 100644
--- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc
+++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc
@@ -81,22 +81,6 @@ bool GpuMultiOutputFusion::ShapesCompatibleForFusion(HloInstruction* instr1,
get_element_shape(element_instr_2));
}
-bool GpuMultiOutputFusion::IsProfitableOperand(HloInstruction* instr) {
- // kConstant instruction will not have memory reads, so it won't be a profit
- // source. Skip them.
- if (instr->opcode() == HloOpcode::kConstant &&
- ShapeUtil::IsEffectiveScalar(instr->shape())) {
- return false;
- }
- // We don't target to fuse producer/consumer instructions -- this should
- // be taken care of by the instruction_fusion pass. If instr has only
- // one user, it will not have sibling instructions. We won't consider it.
- if (instr->user_count() < 2) {
- return false;
- }
- return true;
-}
-
namespace {
bool IsReduction(HloInstruction* instr) {
if (instr->IsMultiOutputFusion()) {
@@ -116,7 +100,13 @@ bool IsReduction(HloInstruction* instr) {
} // namespace
bool GpuMultiOutputFusion::IsFusible(HloInstruction* instr) {
- return IsReduction(instr);
+ // We can fuse reduces and loop fusions.
+ return IsReduction(instr) ||
+ (instr->opcode() == HloOpcode::kFusion &&
+ instr->fusion_kind() == HloInstruction::FusionKind::kLoop &&
+ // TODO(b/110202584): bitcasts create nested fusions, and the GPU backend
+ // has no support for nested fusions.
+ instr->fused_expression_root()->opcode() != HloOpcode::kBitcast);
}
int64 GpuMultiOutputFusion::GetProfit(HloInstruction* instr1,
@@ -140,5 +130,22 @@ int64 GpuMultiOutputFusion::GetProfit(HloInstruction* instr1,
return profit;
}
+bool GpuMultiOutputFusion::LegalToFuse(HloInstruction* instr1,
+ HloInstruction* instr2) {
+ if (!MultiOutputFusion::LegalToFuse(instr1, instr2)) {
+ return false;
+ }
+ // If we're fusing fusions, only do it if the fusion kind matches. Loop fusions
+ // merge into bigger loop fusions, and input (reduce) fusions become fusions
+ // with multiple reduce outputs. We could fuse reduce and loop fusions
+ // together too (the result being an input fusion) if we find cases where this
+ // improves things.
+ CHECK(instr1->opcode() == HloOpcode::kFusion);
+ if (instr2->opcode() == HloOpcode::kFusion) {
+ return instr1->fusion_kind() == instr2->fusion_kind();
+ }
+ return instr1->fusion_kind() != HloInstruction::FusionKind::kLoop;
+}
+
} // namespace gpu
} // namespace xla
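A standalone sketch of the kind-compatibility rule spelled out in LegalToFuse above (the enum and helper names here are illustrative, not XLA API): two fusions may only merge when their kinds match, and a fusion may absorb a non-fusion sibling only when it is not a loop fusion.

```cpp
#include <iostream>

enum class FusionKind { kLoop, kInput };

// Mirrors the decision in GpuMultiOutputFusion::LegalToFuse: if the second
// instruction is also a fusion the kinds must match; otherwise the first
// fusion must not be a loop fusion.
bool KindsCompatible(FusionKind kind1, bool instr2_is_fusion, FusionKind kind2) {
  if (instr2_is_fusion) {
    return kind1 == kind2;
  }
  return kind1 != FusionKind::kLoop;
}

int main() {
  std::cout << KindsCompatible(FusionKind::kLoop, true, FusionKind::kLoop)    // 1
            << KindsCompatible(FusionKind::kLoop, true, FusionKind::kInput)   // 0
            << KindsCompatible(FusionKind::kInput, false, FusionKind::kLoop)  // 1
            << "\n";
  return 0;
}
```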
diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h
index 5451a93cec..16db0e0f02 100644
--- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h
+++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h
@@ -43,10 +43,8 @@ class GpuMultiOutputFusion : public MultiOutputFusion {
// estimated as the size of the common operands b/w instr1 and instr2.
int64 GetProfit(HloInstruction* instr1, HloInstruction* instr2) override;
- // Whether fusing the instruction can reduce memory reads.
- //
- // TODO(tjoerg): Move this method up into the MultiOutputFusion base class.
- bool IsProfitableOperand(HloInstruction* instr) override;
+ // Test if it's legal to fuse instr1 and instr2 into one fusion instruction.
+ bool LegalToFuse(HloInstruction* instr1, HloInstruction* instr2) override;
};
} // namespace gpu
diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc
index 924cfb11f3..5e7ceb7976 100644
--- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc
@@ -226,5 +226,34 @@ TEST_F(InstructionFusionTest,
ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
}
+TEST_F(InstructionFusionTest, MultiOutputFusionTwoLoops) {
+ auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
+ fused_computation_1 {
+ p0.1 = f32[6400]{0} parameter(0)
+ ROOT mul = f32[6400]{0} multiply(p0.1, p0.1)
+ }
+
+ fused_computation_2 {
+ p0.2 = f32[6400]{0} parameter(0)
+ const.2 = f32[] constant(1)
+ ROOT div = f32[6400]{0} divide(p0.2, const.2)
+ }
+
+ ENTRY entry {
+ p0 = f32[6400]{0} parameter(0)
+ fusion.1 = f32[6400]{0} fusion(p0), kind=kLoop, calls=fused_computation_1
+ fusion.2 = f32[6400]{0} fusion(p0), kind=kLoop, calls=fused_computation_2
+ ROOT root = (f32[6400]{0}, f32[6400]{0}) tuple(fusion.1, fusion.2)
+ })"))
+ .ValueOrDie();
+ ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
+ SCOPED_TRACE(module->ToString());
+ const HloInstruction* fusion =
+ module->entry_computation()->root_instruction()->operand(0)->operand(0);
+ ASSERT_TRUE(fusion->IsMultiOutputFusion());
+ EXPECT_THAT(fusion->fused_expression_root(),
+ op::Tuple(op::Multiply(), op::Divide()));
+}
+
} // namespace gpu
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/heap_simulator.cc b/tensorflow/compiler/xla/service/heap_simulator.cc
index 5dba50a63b..a04aa4069d 100644
--- a/tensorflow/compiler/xla/service/heap_simulator.cc
+++ b/tensorflow/compiler/xla/service/heap_simulator.cc
@@ -26,7 +26,8 @@ namespace xla {
using tensorflow::gtl::FlatMap;
using tensorflow::gtl::FlatSet;
-StatusOr<int64> MinimumMemoryForModule(
+/*static*/
+StatusOr<int64> HeapSimulator::MinimumMemoryForModule(
const SequentialHloOrdering::HloModuleSequence& module_sequence,
const LogicalBuffer::SizeFunction& size_function) {
if (module_sequence.empty()) {
@@ -49,15 +50,19 @@ StatusOr<int64> MinimumMemoryForModule(
return result.heap_size;
}
-StatusOr<int64> MinimumMemoryForComputation(
+/*static*/
+StatusOr<int64> HeapSimulator::MinimumMemoryForComputation(
const HloComputation& computation,
const std::vector<const HloInstruction*>& sequence,
const TuplePointsToAnalysis& points_to_analysis,
- const LogicalBuffer::SizeFunction& size_function) {
+ const LogicalBuffer::SizeFunction& size_function,
+ const tensorflow::gtl::FlatMap<const HloComputation*, int64>*
+ memory_by_computation) {
TF_ASSIGN_OR_RETURN(
HeapSimulator::Result result,
HeapSimulator::Run(MakeUnique<NoFragmentationStatsHeap>(), computation,
- sequence, points_to_analysis, size_function));
+ sequence, points_to_analysis, size_function,
+ HeapSimulator::Options(), memory_by_computation));
return result.heap_size;
}
@@ -81,9 +86,11 @@ StatusOr<HeapSimulator::Result> HeapSimulator::Run(
std::unique_ptr<HeapAlgorithm> algorithm, const HloComputation& computation,
const std::vector<const HloInstruction*>& instruction_sequence,
const TuplePointsToAnalysis& points_to_analysis,
- const BufferValue::SizeFunction& size_fn, const Options& options) {
+ const BufferValue::SizeFunction& size_fn, const Options& options,
+ const tensorflow::gtl::FlatMap<const HloComputation*, int64>*
+ memory_by_computation) {
HeapSimulator heap(std::move(algorithm), size_fn, options,
- /*module_sequence=*/nullptr);
+ /*module_sequence=*/nullptr, memory_by_computation);
TF_RETURN_IF_ERROR(heap.RunComputation(computation, instruction_sequence,
points_to_analysis));
return heap.Finish();
@@ -254,6 +261,12 @@ Status HeapSimulator::RunComputation(
Alloc(buffer, instruction);
}
}
+ // Account for the memory used by subcomputations when estimating the
+ // current heap size.
+ if (memory_by_computation_ != nullptr) {
+ algorithm_->AccountForSubcomputationMemory(instruction,
+ *memory_by_computation_);
+ }
// If the whole module is sequential, we can save memory by running the
// heap-simulation for sub-computations inline. E.g. the buffers for the
@@ -321,12 +334,15 @@ Status HeapSimulator::RunComputation(
HeapSimulator::HeapSimulator(
std::unique_ptr<HeapAlgorithm> algorithm,
const BufferValue::SizeFunction& size_fn, const Options& options,
- const SequentialHloOrdering::HloModuleSequence* module_sequence)
+ const SequentialHloOrdering::HloModuleSequence* module_sequence,
+ const tensorflow::gtl::FlatMap<const HloComputation*, int64>*
+ memory_by_computation)
: no_fragmentation_stats_(MakeUnique<NoFragmentationStatsHeap>()),
algorithm_(std::move(algorithm)),
size_fn_(size_fn),
options_(options),
- module_sequence_(module_sequence) {
+ module_sequence_(module_sequence),
+ memory_by_computation_(memory_by_computation) {
debug_trace_.set_whole_module_simulation(module_sequence_ != nullptr);
}
@@ -495,6 +511,26 @@ void NoFragmentationStatsHeap::Alloc(const BufferValue* buffer, int64 size) {
}
}
+void NoFragmentationStatsHeap::AccountForSubcomputationMemory(
+ const HloInstruction* instruction,
+ const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
+ memory_by_computation) {
+ // We only count the memory usage of the largest subcomputation, instead of
+ // adding them all, because subcomputations won't execute in parallel.
+ int64 max_subcomputation_bytes = 0;
+ for (const auto* c : instruction->called_computations()) {
+ auto it = memory_by_computation.find(c);
+ if (it != memory_by_computation.end()) {
+ int64 subcomputation_bytes = it->second;
+ if (subcomputation_bytes > max_subcomputation_bytes) {
+ max_subcomputation_bytes = subcomputation_bytes;
+ }
+ }
+ }
+ max_heap_size_ =
+ std::max(max_heap_size_, current_heap_size_ + max_subcomputation_bytes);
+}
+
void NoFragmentationStatsHeap::Free(const BufferValue* buffer, int64 size) {
current_heap_size_ -= size;
}
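The comment in AccountForSubcomputationMemory above explains why only the largest called subcomputation is added to the running heap size. A self-contained sketch of that accounting, using plain std types instead of XLA's FlatMap and HloInstruction:

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <map>
#include <string>
#include <vector>

// Called subcomputations never run concurrently, so the peak only needs to
// grow by the memory of the largest one on top of the current heap size.
int64_t PeakWithSubcomputations(
    int64_t current_heap_size, int64_t max_heap_size,
    const std::vector<std::string>& called_computations,
    const std::map<std::string, int64_t>& memory_by_computation) {
  int64_t max_subcomputation_bytes = 0;
  for (const std::string& c : called_computations) {
    auto it = memory_by_computation.find(c);
    if (it != memory_by_computation.end()) {
      max_subcomputation_bytes = std::max(max_subcomputation_bytes, it->second);
    }
  }
  return std::max(max_heap_size, current_heap_size + max_subcomputation_bytes);
}

int main() {
  const std::map<std::string, int64_t> memory = {{"cond", 16}, {"body", 40}};
  // 100 bytes currently live, previous peak 120, a while op calling cond/body:
  std::cout << PeakWithSubcomputations(100, 120, {"cond", "body"}, memory)
            << "\n";  // prints 140
  return 0;
}
```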
diff --git a/tensorflow/compiler/xla/service/heap_simulator.h b/tensorflow/compiler/xla/service/heap_simulator.h
index 3be3bb8e7f..811a6042df 100644
--- a/tensorflow/compiler/xla/service/heap_simulator.h
+++ b/tensorflow/compiler/xla/service/heap_simulator.h
@@ -34,21 +34,6 @@ limitations under the License.
namespace xla {
-// Returns the minimum memory required to compute an HLO module where all
-// computations have been scheduled (represented by the given module_sequence),
-// assuming no fragmentation.
-StatusOr<int64> MinimumMemoryForModule(
- const SequentialHloOrdering::HloModuleSequence& module_sequence,
- const LogicalBuffer::SizeFunction& size_function);
-
-// Returns the minimum memory required to compute the given computation,
-// assuming no fragmentation.
-StatusOr<int64> MinimumMemoryForComputation(
- const HloComputation& computation,
- const std::vector<const HloInstruction*>& sequence,
- const TuplePointsToAnalysis& points_to_analysis,
- const LogicalBuffer::SizeFunction& size_function);
-
// Forward declare classes defined below.
class HeapAlgorithm;
@@ -100,6 +85,23 @@ class HeapSimulator {
const BufferValueFlatSet* buffers_to_assign;
};
+ // Returns the minimum memory required to compute an HLO module where all
+ // computations have been scheduled (represented by the given
+ // module_sequence), assuming no fragmentation.
+ static StatusOr<int64> MinimumMemoryForModule(
+ const SequentialHloOrdering::HloModuleSequence& module_sequence,
+ const LogicalBuffer::SizeFunction& size_function);
+
+ // Returns the minimum memory required to compute the given computation,
+ // assuming no fragmentation.
+ static StatusOr<int64> MinimumMemoryForComputation(
+ const HloComputation& computation,
+ const std::vector<const HloInstruction*>& sequence,
+ const TuplePointsToAnalysis& points_to_analysis,
+ const LogicalBuffer::SizeFunction& size_function,
+ const tensorflow::gtl::FlatMap<const HloComputation*, int64>*
+ memory_by_computation = nullptr);
+
// Run the heap simulation with the given algorithm, assuming the given
// module_sequence, which must contain a topologically-consistent total
// ordering of all instructions within each computation. The result is invalid
@@ -126,7 +128,9 @@ class HeapSimulator {
const std::vector<const HloInstruction*>& instruction_sequence,
const TuplePointsToAnalysis& points_to_analysis,
const BufferValue::SizeFunction& size_fn,
- const Options& options = Options());
+ const Options& options = Options(),
+ const tensorflow::gtl::FlatMap<const HloComputation*, int64>*
+ memory_by_computation = nullptr);
private:
// If 'module_sequence' is non-null, it is used to find kCall and kWhile
@@ -135,7 +139,9 @@ class HeapSimulator {
HeapSimulator(
std::unique_ptr<HeapAlgorithm> algorithm,
const BufferValue::SizeFunction& size_fn, const Options& options,
- const SequentialHloOrdering::HloModuleSequence* module_sequence);
+ const SequentialHloOrdering::HloModuleSequence* module_sequence = nullptr,
+ const tensorflow::gtl::FlatMap<const HloComputation*, int64>*
+ memory_by_computation = nullptr);
~HeapSimulator();
Status RunComputation(
@@ -159,7 +165,13 @@ class HeapSimulator {
const std::unique_ptr<HeapAlgorithm> algorithm_;
const BufferValue::SizeFunction size_fn_;
const Options options_;
+ // module_sequence_ is set by buffer assignment, and memory_by_computation_ is
+ // set by hlo scheduling. Then, in RunComputation, we check both in order to
+ // handle subcomputations. It would be good to unify the handling of
+ // subcomputations, but it's not clear how.
const SequentialHloOrdering::HloModuleSequence* module_sequence_;
+ const tensorflow::gtl::FlatMap<const HloComputation*, int64>*
+ memory_by_computation_;
// In addition to Alloc and Free, the heap simulator exposes a concept of
// buffer sharing. When ShareBuffer is called, instead of allocating new
@@ -204,6 +216,11 @@ class HeapAlgorithm {
// Alloc allocates a buffer of 'size' bytes.
virtual void Alloc(const BufferValue* buffer, int64 size) = 0;
+ virtual void AccountForSubcomputationMemory(
+ const HloInstruction* instruction,
+ const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
+ memory_by_computation) {}
+
// Free de-allocates a previously allocated buffer.
virtual void Free(const BufferValue* buffer, int64 size) = 0;
@@ -222,7 +239,14 @@ class NoFragmentationStatsHeap : public HeapAlgorithm {
~NoFragmentationStatsHeap() override = default;
void Alloc(const BufferValue* buffer, int64 size) override;
+
+ void AccountForSubcomputationMemory(
+ const HloInstruction* instruction,
+ const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
+ memory_by_computation) override;
+
void Free(const BufferValue* buffer, int64 size) override;
+
Result Finish() override;
private:
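The MinimumMemoryFor* entry points above boil down to running the no-fragmentation heap over a schedule and reporting its peak. A standalone sketch of that peak computation (the event encoding is made up for illustration; the real simulator walks Alloc/Free calls driven by the HLO sequence):

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Walk alloc/free deltas in schedule order and report the peak live-byte
// count, i.e. the minimum memory required assuming no fragmentation.
int64_t PeakLiveBytes(const std::vector<int64_t>& deltas) {
  int64_t current = 0;
  int64_t peak = 0;
  for (int64_t delta : deltas) {  // positive = Alloc, negative = Free
    current += delta;
    peak = std::max(peak, current);
  }
  return peak;
}

int main() {
  // Alloc 16, alloc 24, free 16, alloc 8, free 24, free 8 -> peak is 40 bytes.
  std::cout << PeakLiveBytes({16, 24, -16, 8, -24, -8}) << "\n";
  return 0;
}
```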
diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc
index 309ab85f78..93d7a14125 100644
--- a/tensorflow/compiler/xla/service/heap_simulator_test.cc
+++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc
@@ -89,7 +89,8 @@ TEST_F(MinimumMemoryForSequenceTest, MultiComputation) {
cond_lt};
module_sequence[body_computation] = {body_param};
module_sequence[entry_computation] = {iter, data, tuple, while_op};
- EXPECT_EQ(56, MinimumMemoryForModule(module_sequence, size_fn).ValueOrDie());
+ EXPECT_EQ(56, HeapSimulator::MinimumMemoryForModule(module_sequence, size_fn)
+ .ValueOrDie());
}
const char kAlloc[] = "Alloc";
diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc
index a88283ed9a..0a948cc390 100644
--- a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc
@@ -493,6 +493,16 @@ StatusOr<std::unique_ptr<HloAliasAnalysis>> HloAliasAnalysis::Run(
bool HloAliasAnalysis::HasLiveRangeInterference(
const HloOrdering& ordering) const {
for (const HloBuffer& buffer : buffers()) {
+ CHECK(!buffer.values().empty());
+ if (ShapeUtil::IsToken(buffer.values().front()->shape())) {
+ // Tokens have no on-device representation and cannot interfere.
+ for (const HloValue* value : buffer.values()) {
+ // If one of the values is a token, all values must be a token.
+ DCHECK(ShapeUtil::IsToken(value->shape()));
+ }
+ continue;
+ }
+
// Check that the values in the buffer are totally ordered with respect to
// 'ordering'. Begin by sorting the values with respect to 'ordering' with a
// tie-break using value ID. The tie-break is necessary because we need a
@@ -517,7 +527,6 @@ bool HloAliasAnalysis::HasLiveRangeInterference(
// a buffer and A interferes with C, then necessarily A also interferes
// with B. So to check interference you only need to check interference
// between A and B, and between B and C.
- CHECK(!values.empty());
for (int i = 1; i < values.size(); ++i) {
if (!ordering.IsDefinedBefore(*values[i - 1], *values[i])) {
VLOG(1) << values[i - 1]->ToShortString() << " and "
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc
index b158f44923..ef8bb030fb 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation.cc
@@ -234,7 +234,6 @@ Status HloComputation::RemoveInstruction(HloInstruction* instruction) {
TF_RET_CHECK(instruction_iterators_.count(instruction) != 0);
auto inst_it = instruction_iterators_.at(instruction);
(*inst_it)->set_parent(nullptr);
- instruction->DetachFromOperands();
instructions_.erase(inst_it);
return Status::OK();
}
@@ -524,21 +523,7 @@ HloInstruction* HloComputation::CreateFusionInstruction(
StatusOr<HloInstruction*> HloComputation::DeepCopyHelper(
HloInstruction* instruction, const ShapeTree<bool>* indices_to_copy,
ShapeTree<HloInstruction*>* copies_added, ShapeIndex* index) {
- if (ShapeUtil::IsArray(instruction->shape())) {
- if (indices_to_copy == nullptr || indices_to_copy->element(*index)) {
- // Use kCopy to copy array elements
- HloInstruction* copy = AddInstruction(HloInstruction::CreateUnary(
- instruction->shape(), HloOpcode::kCopy, instruction));
- if (copies_added != nullptr) {
- *copies_added->mutable_element(*index) = copy;
- }
- return copy;
- } else {
- // Array elements which are not to be copied are passed through
- // transparently.
- return instruction;
- }
- } else if (ShapeUtil::IsTuple(instruction->shape())) {
+ if (ShapeUtil::IsTuple(instruction->shape())) {
std::vector<HloInstruction*> elements;
for (int64 i = 0; i < ShapeUtil::TupleElementCount(instruction->shape());
i++) {
@@ -555,9 +540,27 @@ StatusOr<HloInstruction*> HloComputation::DeepCopyHelper(
index->pop_back();
}
return AddInstruction(HloInstruction::CreateTuple(elements));
+ }
+ if (ShapeUtil::IsToken(instruction->shape())) {
+ // Tokens have no on-device representation and cannot be copied. Pass
+ // through transparently.
+ return instruction;
+ }
+
+ // Array shape.
+ TF_RET_CHECK(ShapeUtil::IsArray(instruction->shape()));
+ if (indices_to_copy == nullptr || indices_to_copy->element(*index)) {
+ // Use kCopy to copy array elements
+ HloInstruction* copy = AddInstruction(HloInstruction::CreateUnary(
+ instruction->shape(), HloOpcode::kCopy, instruction));
+ if (copies_added != nullptr) {
+ *copies_added->mutable_element(*index) = copy;
+ }
+ return copy;
} else {
- return FailedPrecondition(
- "Can only copy array and tuple shaped instructions");
+ // Elements which are not to be copied are passed through
+ // transparently.
+ return instruction;
}
}
@@ -863,15 +866,6 @@ std::unique_ptr<HloComputation> HloComputation::CloneWithReplacements(
}
}
context->MapComputation(this, result.get());
- // We cloned the elements of 'replacements', so they're all going to be
- // destroyed. HloInstructions need to be detached from their operands before
- // they're destroyed, otherwise they stick around in the operands' users lists
- // and cause use-after-frees.
- for (auto& kv : replacements) {
- if (std::unique_ptr<HloInstruction>& new_instr = kv.second) {
- new_instr->DetachFromOperands();
- }
- }
return result;
}
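The restructured DeepCopyHelper above now dispatches on three shape kinds. A standalone sketch of that dispatch (the enum and strings are illustrative only; the real code builds HLO instructions):

```cpp
#include <iostream>
#include <string>

enum class ShapeKind { kTuple, kToken, kArray };

// Mirrors the per-shape behaviour of DeepCopyHelper: tuples recurse into
// their elements, tokens pass through untouched, and array elements get an
// explicit kCopy unless excluded via indices_to_copy.
std::string DeepCopyAction(ShapeKind kind, bool index_selected_for_copy) {
  switch (kind) {
    case ShapeKind::kTuple:
      return "recurse into elements";
    case ShapeKind::kToken:
      return "pass through";
    case ShapeKind::kArray:
      return index_selected_for_copy ? "emit kCopy" : "pass through";
  }
  return "unreachable";
}

int main() {
  std::cout << DeepCopyAction(ShapeKind::kToken, true) << "\n"   // pass through
            << DeepCopyAction(ShapeKind::kArray, true) << "\n";  // emit kCopy
  return 0;
}
```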
diff --git a/tensorflow/compiler/xla/service/hlo_computation_test.cc b/tensorflow/compiler/xla/service/hlo_computation_test.cc
index 25469a54c4..3f59d31bb9 100644
--- a/tensorflow/compiler/xla/service/hlo_computation_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation_test.cc
@@ -371,6 +371,38 @@ TEST_F(HloComputationTest, DeepCopyTupleAtIndices) {
}
}
+TEST_F(HloComputationTest, DeepCopyToken) {
+ // Test that DeepCopyInstruction properly handles tokens which should not be
+ // copied.
+ auto builder = HloComputation::Builder(TestName());
+ auto token = builder.AddInstruction(HloInstruction::CreateGenerateToken({}));
+ auto module = CreateNewModule();
+ auto computation = module->AddEntryComputation(builder.Build());
+ auto copy = computation->DeepCopyInstruction(token).ValueOrDie();
+
+ // No copy should be added.
+ EXPECT_THAT(copy, op::GenerateToken());
+}
+
+TEST_F(HloComputationTest, DeepCopyTokenTuple) {
+ // Test that DeepCopyInstruction properly handles tokens which should not be
+ // copied.
+ auto builder = HloComputation::Builder(TestName());
+ auto token = builder.AddInstruction(HloInstruction::CreateGenerateToken({}));
+ auto constant = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0)));
+ auto tuple =
+ builder.AddInstruction(HloInstruction::CreateTuple({token, constant}));
+ auto module = CreateNewModule();
+ auto computation = module->AddEntryComputation(builder.Build());
+ auto copy = computation->DeepCopyInstruction(tuple).ValueOrDie();
+
+ // Only the array (second tuple element) should be copied. The token is passed
+ // through transparently.
+ EXPECT_THAT(copy, op::Tuple(op::GetTupleElement(tuple),
+ op::Copy(op::GetTupleElement(tuple))));
+}
+
TEST_F(HloComputationTest, CycleDetection) {
// Test whether the visitor can detect cycles in the graph.
auto builder = HloComputation::Builder(TestName());
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc
index e0648e1467..33424019b9 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc
@@ -300,12 +300,6 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateWithSubstitutions(
instruction->CloneWithNewOperands(instruction->shape(), operands);
auto result = Evaluate(cloned_instruction.get());
- // Clean up our cloned instructions before returning.
- cloned_instruction->DetachFromOperands();
- for (auto& operand : owned_operands) {
- operand->DetachFromOperands();
- }
-
return result;
}
@@ -321,7 +315,6 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateElementwiseBinaryOp(
rhs_instr.get());
auto result = Evaluate(cloned_instruction.get());
- cloned_instruction->DetachFromOperands();
return result;
}
@@ -334,7 +327,6 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateElementwiseUnaryOp(
HloInstruction::CreateUnary(operand.shape(), opcode, operand_instr.get());
auto result = Evaluate(cloned_instruction.get());
- cloned_instruction->DetachFromOperands();
return result;
}
@@ -372,7 +364,7 @@ Status HloEvaluator::HandleConcatenate(HloInstruction* concatenate) {
// The result concatenate dimension is going to be the sum of all
// concatenate dimensions of the operands taking part of the operation.
const Shape& reference_shape = operands[0]->shape();
- CHECK(!ShapeUtil::IsTuple(reference_shape));
+ CHECK(ShapeUtil::IsArray(reference_shape));
const int64 rank = ShapeUtil::Rank(reference_shape);
const int64 concat_dim = concatenate->dimensions()[0];
CHECK_GE(concat_dim, 0);
@@ -383,7 +375,7 @@ Status HloEvaluator::HandleConcatenate(HloInstruction* concatenate) {
for (int64 i = 1; i < operands.size(); ++i) {
const Shape& operand_shape = operands[i]->shape();
- CHECK(!ShapeUtil::IsTuple(operand_shape));
+ CHECK(ShapeUtil::IsArray(operand_shape));
// Accumulate the concat dimension from all tensors taking part to the
// operation.
concat_dimensions[concat_dim] +=
@@ -911,10 +903,7 @@ Status HloEvaluator::HandleBroadcast(HloInstruction* broadcast) {
}
Status HloEvaluator::HandleGenerateToken(HloInstruction* token) {
- // Literals cannot represent a TOKEN shape so just create an empty tuple as
- // the "result" of the kGenerateToken operation.
- // TODO(b/109929053): Add support for TOKENs in Literals.
- evaluated_[token] = Literal::MakeTuple({});
+ evaluated_[token] = Literal::CreateToken();
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
index 13f46407e3..bc7340aa03 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
@@ -778,7 +778,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
Status HandleSelect(HloInstruction* select) override {
CHECK(!ShapeUtil::IsScalar(select->operand(0)->shape()));
- CHECK(!ShapeUtil::IsTuple(select->shape()));
+ CHECK(ShapeUtil::IsArray(select->shape()));
std::function<ReturnT(bool, ReturnT, ReturnT)> select_op =
[](bool pred, ReturnT on_true, ReturnT on_false) {
if (pred) {
@@ -1103,7 +1103,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}
Status HandlePad(HloInstruction* pad) override {
- CHECK(!ShapeUtil::IsTuple(pad->operand(0)->shape()));
+ CHECK(ShapeUtil::IsArray(pad->operand(0)->shape()));
// Padding value must be scalar.
CHECK(ShapeUtil::IsScalar(pad->operand(1)->shape()));
CHECK_EQ(ShapeUtil::Rank(pad->operand(0)->shape()),
@@ -1116,7 +1116,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
/*padding_config=*/pad->padding_config()));
CHECK(ShapeUtil::Compatible(pad->shape(), inferred_return_shape))
<< "return shape is set to: " << ShapeUtil::HumanString(pad->shape())
- << "but is inferred to be: "
+ << " but is inferred to be: "
<< ShapeUtil::HumanString(inferred_return_shape);
// Create new HLO of padded shape with padding value.
@@ -1182,7 +1182,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
dynamic_slice->dynamic_slice_sizes()));
TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape))
<< "return shape is set to: " << ShapeUtil::HumanString(result_shape)
- << "but is inferred to be: "
+ << " but is inferred to be: "
<< ShapeUtil::HumanString(inferred_return_shape);
TF_RET_CHECK(
primitive_util::IsIntegralType(start_indices->shape().element_type()));
@@ -1237,7 +1237,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
operand->shape(), update->shape(), start_indices->shape()));
TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape))
<< "return shape is set to: " << ShapeUtil::HumanString(result_shape)
- << "but is inferred to be: "
+ << " but is inferred to be: "
<< ShapeUtil::HumanString(inferred_return_shape);
TF_RET_CHECK(
primitive_util::IsIntegralType(start_indices->shape().element_type()));
@@ -1393,7 +1393,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
/*to_apply=*/function->ComputeProgramShape()));
TF_RET_CHECK(ShapeUtil::Compatible(reduce->shape(), inferred_return_shape))
<< "return shape is set to: " << ShapeUtil::HumanString(reduce->shape())
- << "but is inferred to be: "
+ << " but is inferred to be: "
<< ShapeUtil::HumanString(inferred_return_shape);
const Literal& arg_literal = parent_->GetEvaluatedLiteralFor(arg);
@@ -1613,7 +1613,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
ShapeUtil::Compatible(reduce_window->shape(), inferred_return_shape))
<< "return shape is set to: "
<< ShapeUtil::HumanStringWithLayout(reduce_window->shape())
- << "but is inferred to be: "
+ << " but is inferred to be: "
<< ShapeUtil::HumanStringWithLayout(inferred_return_shape);
const Literal& operand_literal =
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
index 28fc6c4209..ab224021c5 100644
--- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
+++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
@@ -832,13 +832,13 @@ string HloDotDumper::GetInstructionNodeInlinedOperands(
// "{} (f32[42, 0, 10])". The alternative, calling Literal::ToString(),
// enumerates all of its empty dimensions (e.g. "{ { {}, {} }, ..."), which
// is just noise.
- if (!ShapeUtil::IsTuple(shape) && ShapeUtil::HasZeroElements(shape)) {
+ if (ShapeUtil::IsZeroElementArray(shape)) {
return Printf("{} (%s)", ShapeUtil::HumanString(constant->shape()));
}
// Print the literal value of constants with <= K elements.
optional<int64> elem_count;
- if (!ShapeUtil::IsOpaque(shape) && !ShapeUtil::IsTuple(shape)) {
+ if (ShapeUtil::IsArray(shape)) {
elem_count = 1;
for (int64 dim : shape.dimensions()) {
*elem_count *= dim;
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 39662d1735..0b4dd6412f 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -178,10 +178,14 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
break;
}
case HloOpcode::kConstant: {
- CHECK(proto.has_literal());
- TF_ASSIGN_OR_RETURN(auto literal,
- Literal::CreateFromProto(proto.literal()));
- instruction = CreateConstant(std::move(literal));
+ // TODO(b/110214922): Revert this to CHECK(proto.has_literal()).
+ if (proto.has_literal()) {
+ TF_ASSIGN_OR_RETURN(auto literal,
+ Literal::CreateFromProto(proto.literal()));
+ instruction = CreateConstant(std::move(literal));
+ } else {
+ instruction = MakeUnique<HloConstantInstruction>(proto.shape());
+ }
break;
}
case HloOpcode::kTrace: {
@@ -243,6 +247,28 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
CreateReducePrecision(proto.shape(), operands(0),
proto.exponent_bits(), proto.mantissa_bits());
break;
+ case HloOpcode::kInfeed:
+ instruction = CreateInfeed(proto.shape(), proto.infeed_config());
+ break;
+ case HloOpcode::kOutfeed:
+ instruction = CreateOutfeed(proto.outfeed_shape(), operands(0),
+ proto.outfeed_config());
+ break;
+ case HloOpcode::kCrossReplicaSum: {
+ CHECK_EQ(proto.called_computation_ids_size(), 1);
+ std::vector<HloInstruction*> all_operands(proto.operand_ids_size());
+ c_transform(proto.operand_ids(), all_operands.begin(),
+ [&instruction_map](int64 operand_id) {
+ return instruction_map.at(operand_id);
+ });
+ instruction = CreateCrossReplicaSum(
+ proto.shape(), all_operands, computations(0),
+ /*replica_group_ids=*/
+ std::vector<int64>(proto.replica_group_ids().begin(),
+ proto.replica_group_ids().end()),
+ /*barrier=*/"");
+ break;
+ }
default: {
instruction = WrapUnique(new HloInstruction(opcode, proto.shape()));
for (const int64 operand_id : proto.operand_ids()) {
@@ -293,10 +319,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
instruction->padding_config_ =
MakeUnique<PaddingConfig>(proto.padding_config());
}
- instruction->outfeed_config_ = proto.outfeed_config();
- instruction->infeed_config_ = proto.infeed_config();
instruction->custom_call_target_ = proto.custom_call_target();
- instruction->outfeed_shape_ = proto.outfeed_shape();
if (proto.has_sharding()) {
TF_ASSIGN_OR_RETURN(const auto& sharding,
@@ -315,10 +338,6 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
instruction->channel_name_ = proto.channel_name();
instruction->cost_estimate_ns_ = proto.cost_estimate_ns();
- for (int64 replica_group_id : proto.replica_group_ids()) {
- instruction->replica_group_ids_.push_back(replica_group_id);
- }
-
return std::move(instruction);
}
@@ -531,40 +550,21 @@ HloInstruction::CreateCrossReplicaSum(
HloComputation* reduce_computation,
tensorflow::gtl::ArraySlice<int64> replica_group_ids,
tensorflow::StringPiece barrier,
- const tensorflow::gtl::optional<int64>& channel_id) {
- // TODO(b/79737069): Remove the CHECK when supported.
- CHECK(!channel_id.has_value());
- auto instruction =
- WrapUnique(new HloInstruction(HloOpcode::kCrossReplicaSum, shape));
- for (auto operand : operands) {
- instruction->AppendOperand(operand);
- }
- instruction->called_computations_.push_back(reduce_computation);
- instruction->replica_group_ids_.assign(replica_group_ids.begin(),
- replica_group_ids.end());
- instruction->cross_replica_sum_barrier_ = std::string(barrier);
- return instruction;
+ const tensorflow::gtl::optional<int64>& all_reduce_id) {
+ return MakeUnique<HloAllReduceInstruction>(
+ shape, operands, reduce_computation, replica_group_ids, barrier,
+ all_reduce_id);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateInfeed(
const Shape& shape, const string& config) {
- auto instruction = WrapUnique(new HloInstruction(HloOpcode::kInfeed, shape));
- instruction->set_infeed_config(config);
- return instruction;
+ return MakeUnique<HloInfeedInstruction>(shape, config);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateOutfeed(
const Shape& shape, HloInstruction* operand,
tensorflow::StringPiece outfeed_config) {
- std::unique_ptr<HloInstruction> instruction =
- WrapUnique(new HloInstruction(HloOpcode::kOutfeed, ShapeUtil::MakeNil()));
- CHECK(ShapeUtil::Compatible(operand->shape(), shape))
- << "Outfeed shape " << shape << " must be compatible with operand shape "
- << operand->shape();
- instruction->AppendOperand(operand);
- instruction->outfeed_config_ = std::string(outfeed_config);
- instruction->outfeed_shape_ = shape;
- return instruction;
+ return MakeUnique<HloOutfeedInstruction>(shape, operand, outfeed_config);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSend(
@@ -1040,6 +1040,9 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
case HloOpcode::kParameter:
case HloOpcode::kGetTupleElement:
case HloOpcode::kReducePrecision:
+ case HloOpcode::kCrossReplicaSum:
+ case HloOpcode::kInfeed:
+ case HloOpcode::kOutfeed:
clone = CloneWithNewOperandsImpl(shape, new_operands, context);
break;
// Unary ops.
@@ -1136,11 +1139,6 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
clone = CreateDot(shape, new_operands[0], new_operands[1],
*dot_dimension_numbers_);
break;
- case HloOpcode::kCrossReplicaSum:
- clone =
- CreateCrossReplicaSum(shape, new_operands, to_apply(),
- replica_group_ids_, cross_replica_sum_barrier_);
- break;
case HloOpcode::kPad:
CHECK_EQ(new_operands.size(), 2);
clone =
@@ -1179,14 +1177,6 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
clone =
CreateWhile(shape, while_condition(), while_body(), new_operands[0]);
break;
- case HloOpcode::kInfeed:
- CHECK_EQ(new_operands.size(), 0);
- clone = CreateInfeed(shape, infeed_config());
- break;
- case HloOpcode::kOutfeed:
- CHECK_EQ(new_operands.size(), 1);
- clone = CreateOutfeed(outfeed_shape_, new_operands[0], outfeed_config());
- break;
case HloOpcode::kConditional:
CHECK_EQ(new_operands.size(), 3);
clone = CreateConditional(shape, new_operands[0], new_operands[1],
@@ -1222,7 +1212,29 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
return clone;
}
-HloInstruction::~HloInstruction() {}
+HloInstruction::~HloInstruction() {
+ // Detach from operands. An instruction may be repeated as an operand. To
+ // avoid calling RemoveUser twice on the same operand, check before removing.
+ for (int64 operand_num = 0; operand_num < operand_count(); ++operand_num) {
+ HloInstruction* operand = operands_[operand_num];
+ if (operand == nullptr) {
+ continue;
+ }
+ if (operand->user_set_.find(this) != operand->user_set_.end()) {
+ operand->RemoveUser(this);
+ }
+ operands_[operand_num] = nullptr;
+ }
+
+ // Update users: set each user's corresponding operand slot to `nullptr`.
+ for (auto& user : this->users()) {
+ for (int i = 0; i < user->operand_count(); ++i) {
+ if (user->operands_[i] == this) {
+ user->operands_[i] = nullptr;
+ }
+ }
+ }
+}
std::unique_ptr<HloInstruction> HloInstruction::Clone(
const string& suffix, HloCloneContext* context) const {
@@ -1505,8 +1517,6 @@ bool HloInstruction::IdenticalSlowPath(
eq_computations(false_computation(), other.false_computation());
// These opcodes are not yet supported.
- case HloOpcode::kInfeed:
- case HloOpcode::kOutfeed:
case HloOpcode::kSort:
case HloOpcode::kHostCompute:
return false;
@@ -1535,6 +1545,8 @@ bool HloInstruction::IdenticalSlowPath(
case HloOpcode::kParameter:
case HloOpcode::kGetTupleElement:
case HloOpcode::kReducePrecision:
+ case HloOpcode::kInfeed:
+ case HloOpcode::kOutfeed:
LOG(FATAL) << "Base class impl called for opcode with subclass: "
<< opcode();
}
@@ -1621,22 +1633,6 @@ Status HloInstruction::ReplaceAllUsesWith(HloInstruction* new_producer) {
return Status::OK();
}
-void HloInstruction::DetachFromOperands() {
- VLOG(3) << "DetachFromOperands:\n " << ToString();
- CHECK_EQ(0, user_count());
- // An instruction may be repeated as an operand. To avoid calling RemoveUser
- // twice on the same operand, keep a set of already detached operands.
- std::set<HloInstruction*> detached_operands;
- for (int64 operand_num = 0; operand_num < operand_count(); ++operand_num) {
- HloInstruction* operand = operands_[operand_num];
- if (!ContainsKey(detached_operands, operand)) {
- operand->RemoveUser(this);
- detached_operands.insert(operand);
- }
- operands_[operand_num] = nullptr;
- }
-}
-
HloComputation* HloInstruction::to_apply() const {
switch (opcode_) {
case HloOpcode::kCall:
@@ -1661,6 +1657,7 @@ void HloInstruction::set_to_apply(HloComputation* computation) {
case HloOpcode::kMap:
case HloOpcode::kReduceWindow:
case HloOpcode::kReduce:
+ case HloOpcode::kCrossReplicaSum:
CHECK_EQ(called_computations_.size(), 1);
called_computations_[0] = computation;
break;
@@ -1675,11 +1672,6 @@ const string& HloInstruction::custom_call_target() const {
return custom_call_target_;
}
-const string& HloInstruction::outfeed_config() const {
- CHECK_EQ(opcode_, HloOpcode::kOutfeed);
- return outfeed_config_;
-}
-
HloComputation* HloInstruction::while_condition() const {
CHECK_EQ(HloOpcode::kWhile, opcode_);
return called_computations_[kConditionComputationIndex];
@@ -1901,6 +1893,11 @@ string HloInstruction::OperandsToStringWithCanonicalNameMap(
slice.remove_suffix(slice.size() - kMaxOperandsToShowIfCompact);
}
operands = Join(slice, ", ", [&](string* out, HloInstruction* operand) {
+ // If the operand has already been deleted, print `null` in the string output.
+ if (operand == nullptr) {
+ StrAppend(out, "null ");
+ return;
+ }
std::vector<string> str;
if (options.print_operand_shape()) {
str.push_back(ShapeUtil::HumanStringWithLayout(operand->shape()));
@@ -2008,6 +2005,7 @@ std::vector<string> HloInstruction::ExtraAttributesToString(
case HloOpcode::kMap:
case HloOpcode::kReduceWindow:
case HloOpcode::kReduce:
+ case HloOpcode::kCrossReplicaSum:
extra.push_back(
StrCat("to_apply=\n", to_apply()->ToString(new_options)));
break;
@@ -2036,25 +2034,11 @@ std::vector<string> HloInstruction::ExtraAttributesToString(
}),
"}"));
}
- if (opcode() == HloOpcode::kInfeed && !infeed_config_.empty()) {
- extra.push_back(StrCat("infeed_config=\"", CEscape(infeed_config_), "\""));
- }
- if (opcode() == HloOpcode::kOutfeed && !outfeed_config_.empty()) {
- extra.push_back(
- StrCat("outfeed_config=\"", CEscape(outfeed_config_), "\""));
- }
if (operand_side_metadata_ != nullptr && user_side_metadata_ != nullptr) {
extra.push_back(StrCat("domain={kind=\"", operand_side_metadata_->Kind(),
"\", entry=", operand_side_metadata_->ToString(),
", exit=", user_side_metadata_->ToString(), "}"));
}
- if (!replica_group_ids().empty()) {
- extra.push_back(
- StrCat("replica_group_ids={", Join(replica_group_ids(), ","), "}"));
- }
- if (!cross_replica_sum_barrier().empty()) {
- extra.push_back(StrCat("barrier=\"", cross_replica_sum_barrier(), "\""));
- }
// By contract, we print the custom call target even if
// options.print_subcomputation_mode() == kOff, because the call target is not
@@ -2125,10 +2109,7 @@ HloInstructionProto HloInstruction::ToProto() const {
if (padding_config_ != nullptr) {
*proto.mutable_padding_config() = *padding_config_;
}
- proto.set_outfeed_config(outfeed_config_);
- proto.set_infeed_config(infeed_config_);
proto.set_custom_call_target(custom_call_target_);
- *proto.mutable_outfeed_shape() = outfeed_shape_;
if (has_sharding()) {
*proto.mutable_sharding() = sharding().ToProto();
@@ -2136,9 +2117,6 @@ HloInstructionProto HloInstruction::ToProto() const {
proto.set_channel_name(channel_name_);
proto.set_cost_estimate_ns(cost_estimate_ns_);
- for (int64 replica_group_id : replica_group_ids_) {
- proto.add_replica_group_ids(replica_group_id);
- }
return proto;
}
@@ -2629,12 +2607,6 @@ Status HloInstruction::AcceptOrdered(
return visitor->FinishVisit(this);
}
-const Shape& HloInstruction::outfeed_shape() const {
- DCHECK_EQ(opcode_, HloOpcode::kOutfeed);
- TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape_));
- return outfeed_shape_;
-}
-
const Shape& HloInstruction::shape() const {
TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape_));
return shape_;
@@ -3168,4 +3140,38 @@ int32 HloInstruction::exponent_bits() const {
int32 HloInstruction::mantissa_bits() const {
return Cast<HloReducePrecisionInstruction>(this)->mantissa_bits();
}
+
+string HloInstruction::infeed_config() const {
+ return Cast<HloInfeedInstruction>(this)->infeed_config();
+}
+
+void HloInstruction::set_infeed_config(const string& config) {
+ return Cast<HloInfeedInstruction>(this)->set_infeed_config(config);
+}
+
+const Shape& HloInstruction::outfeed_shape() const {
+ return Cast<HloOutfeedInstruction>(this)->outfeed_shape();
+}
+
+const string& HloInstruction::outfeed_config() const {
+ return Cast<HloOutfeedInstruction>(this)->outfeed_config();
+}
+
+const std::vector<int64>& HloInstruction::replica_group_ids() const {
+ return Cast<HloAllReduceInstruction>(this)->replica_group_ids();
+}
+
+string HloInstruction::cross_replica_sum_barrier() const {
+ return Cast<HloAllReduceInstruction>(this)->cross_replica_sum_barrier();
+}
+
+void HloInstruction::set_cross_replica_sum_barrier(const string& barrier) {
+ return Cast<HloAllReduceInstruction>(this)->set_cross_replica_sum_barrier(
+ barrier);
+}
+
+tensorflow::gtl::optional<int64> HloInstruction::all_reduce_id() const {
+ return Cast<HloAllReduceInstruction>(this)->all_reduce_id();
+}
+
} // namespace xla
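The new ~HloInstruction body above replaces the old explicit DetachFromOperands pass with bookkeeping at destruction time. The toy graph node below mirrors that pattern with plain std containers (types and fields are stand-ins, not XLA's):

```cpp
#include <iostream>
#include <set>
#include <vector>

// On destruction a node removes itself from each operand's user set and
// nulls out its slot in any users that still exist, mirroring the
// ~HloInstruction bookkeeping added in this patch.
struct Node {
  std::vector<Node*> operands;
  std::set<Node*> users;

  ~Node() {
    for (Node*& operand : operands) {
      if (operand != nullptr) operand->users.erase(this);
      operand = nullptr;
    }
    for (Node* user : users) {
      for (Node*& slot : user->operands) {
        if (slot == this) slot = nullptr;
      }
    }
  }
};

int main() {
  Node a;
  {
    Node b;
    b.operands.push_back(&a);
    a.users.insert(&b);
  }  // ~Node() for b runs here and erases b from a's user set.
  std::cout << a.users.size() << "\n";  // prints 0
  return 0;
}
```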
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index a206cdab27..8a0ffc21cd 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -435,9 +435,9 @@ class HloInstruction {
// For example, we have 4 replicas, then replica_group_ids={0,1,0,1} means,
// replica 0 and 2 are in subgroup 0, replica 1 and 3 are in subgroup 1.
//
- // `channel_id`: for Allreduce nodes from different models, if they have the
- // same channel_id, they will be 'Allreduce'd. If empty, Allreduce will not be
- // applied cross models.
+ // `all_reduce_id`: for Allreduce nodes from different modules, if they have
+ // the same all_reduce_id, they will be 'Allreduce'd. If empty, Allreduce will
+ // not be applied across modules.
//
// TODO(b/79737069): Rename this to AllReduce.
static std::unique_ptr<HloInstruction> CreateCrossReplicaSum(
@@ -445,7 +445,7 @@ class HloInstruction {
HloComputation* reduce_computation,
tensorflow::gtl::ArraySlice<int64> replica_group_ids,
tensorflow::StringPiece barrier,
- const tensorflow::gtl::optional<int64>& channel_id =
+ const tensorflow::gtl::optional<int64>& all_reduce_id =
tensorflow::gtl::nullopt);
// Creates a conversion instruction, where operand is the data to convert and
@@ -824,13 +824,6 @@ class HloInstruction {
// root to new_producer.
Status ReplaceAllUsesWith(HloInstruction* new_producer);
- // Detaches an instruction from its operands. That is, remove the instruction
- // from each operand's user set. This should only be called prior to
- // deallocating the instruction.
- //
- // TODO(b/78305363): Make this automatic when deleting an instruction.
- void DetachFromOperands();
-
// Performs a postorder DFS visit using this node as the root. If
// call_finish_visit is true, then DfsHloVisitor::FinishVisit is called when
// complete. If ignore_control_predecessors is true, instructions only
@@ -907,14 +900,6 @@ class HloInstruction {
// Precondition: opcode() == HloOpcode::kCustomCall
const string& custom_call_target() const;
- // Returns the config for the Outfeed instruction.
- // Precondition: opcode() == HloOpcode::kOutfeed
- const string& outfeed_config() const;
-
- // Returns the shape for the Outfeed instruction.
- // Precondition: opcode() == HloOpcode::kOutfeed
- const Shape& outfeed_shape() const;
-
// Gets/sets the while_condition or while_body HloComputation for While. The
// setters should only be called by HloModule or HloComputation methods.
//
@@ -988,12 +973,6 @@ class HloInstruction {
// Precondition: opcode() == HloOpcode::kHostCompute
string channel_name() const { return channel_name_; }
- // Returns the infeed configuration string. The infeed configuration includes
- // any metadata needed for the backend compiler (e.g., infeed buffer address)
- // and is target-dependent.
- string infeed_config() const { return infeed_config_; }
- void set_infeed_config(const string& config) { infeed_config_ = config; }
-
// Returns true if this instruction is fused, ie contained within a fusion
// instruction.
bool IsFused() const;
@@ -1060,6 +1039,19 @@ class HloInstruction {
// instruction.
void SetupDerivedInstruction(HloInstruction* derived_instruction) const;
+ // TODO(b/80249101): Remove these methods once HLO scheduling and copy
+ // insertion are integrated, and we don't need to run a separate pass
+ // of copy elision anymore.
+ bool CopyElisionAllowed() const {
+ CHECK_EQ(HloOpcode::kCopy, opcode_);
+ return copy_elision_allowed_;
+ }
+
+ void SetCopyElisionAllowed(bool value) {
+ CHECK_EQ(HloOpcode::kCopy, opcode_);
+ copy_elision_allowed_ = value;
+ }
+
// Returns the size of the slice in the given dimension for a dynamic
// slice node.
//
@@ -1422,26 +1414,34 @@ class HloInstruction {
// Delegates to HloGetTupleElementInstruction::tuple_index.
int64 tuple_index() const;
- // Returns the number of exponent bits for a reduce-precision node.
+ // Delegates to HloReducePrecisionInstruction::exponent_bits.
int32 exponent_bits() const;
- // Returns the number of mantissa bits for a reduce-precision node.
+ // Delegates to HloReducePrecisionInstruction::mantissa_bits.
int32 mantissa_bits() const;
- // Old methods kept for smooth subclassing transition END.
- // Returns the group ids of each replica for CrossReplicaSum op.
- const std::vector<int64>& replica_group_ids() const {
- return replica_group_ids_;
- }
+ // Delegates to HloInfeedInstruction::infeed_config.
+ string infeed_config() const;
- // Returns the barrier config used for the CrossReplicaSum implementation of
- // each backend.
- string cross_replica_sum_barrier() const {
- return cross_replica_sum_barrier_;
- }
- void set_cross_replica_sum_barrier(string barrier) {
- cross_replica_sum_barrier_ = barrier;
- }
+ // Delegates to HloInfeedInstruction::set_infeed_config.
+ void set_infeed_config(const string& config);
+
+ // Returns the config for the Outfeed instruction.
+ const string& outfeed_config() const;
+
+ // Returns the shape for the Outfeed instruction.
+ const Shape& outfeed_shape() const;
+
+ // Delegates to HloAllReduceInstruction::replica_group_ids.
+ const std::vector<int64>& replica_group_ids() const;
+
+ // Delegates to HloAllReduceInstruction::cross_replica_sum_barrier.
+ string cross_replica_sum_barrier() const;
+ void set_cross_replica_sum_barrier(const string& barrier);
+
+ // Delegates to HloAllReduceInstruction::all_reduce_id.
+ tensorflow::gtl::optional<int64> all_reduce_id() const;
+ // Old methods kept for smooth subclassing transition END.
protected:
enum class UseKind { kNoUse, kReuse, kUsePermutingElements, kUse };
@@ -1555,9 +1555,6 @@ class HloInstruction {
// The computation in which this instruction is contained.
HloComputation* parent_ = nullptr;
- // Shape of outfeed request.
- Shape outfeed_shape_;
-
// Result shape of this instruction.
Shape shape_;
@@ -1573,6 +1570,9 @@ class HloInstruction {
std::unique_ptr<GatherDimensionNumbers> gather_dimension_numbers_;
std::vector<int64> gather_window_bounds_;
+ // Used to tag kCopy instructions that are eligible for copy elision.
+ bool copy_elision_allowed_ = true;
+
// Describes the [start, start + size) range size for a dynamic slice
// ('start' is specified dynamically in the second operand of the operation).
std::vector<int64> dynamic_slice_sizes_;
@@ -1616,28 +1616,16 @@ class HloInstruction {
kFalseComputationIndex = 1,
};
- // Outfeed configuration information, only present for kOutfeed.
- string outfeed_config_;
-
// A trace instruction that consumes this instruction.
//
// Invariant: if trace_instruction_ != nullptr, trace_instruction has this as
// an operand.
HloInstruction* trace_instruction_ = nullptr;
- // The string representation of the infeed configuration.
- string infeed_config_;
-
// The backend-specific configuration for how a backend should compile this
// HLO. See the documentation on backend_config().
string backend_config_;
- // The group id of each replica for CrossReplicaSum.
- std::vector<int64> replica_group_ids_;
-
- // The string representation of the barrier config used for CrossReplicaSum.
- string cross_replica_sum_barrier_;
-
// String identifier for instruction.
string name_;
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index d326d5d009..5871a6605f 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -24,6 +24,7 @@ limitations under the License.
namespace xla {
namespace {
+using ::tensorflow::str_util::CEscape;
using ::tensorflow::str_util::Join;
using ::tensorflow::strings::StrAppend;
using ::tensorflow::strings::StrCat;
@@ -268,6 +269,68 @@ HloRecvDoneInstruction::CloneWithNewOperandsImpl(
Cast<HloRecvInstruction>(new_operands[0]));
}
+HloAllReduceInstruction::HloAllReduceInstruction(
+ const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ HloComputation* reduce_computation,
+ tensorflow::gtl::ArraySlice<int64> replica_group_ids,
+ tensorflow::StringPiece barrier,
+ const tensorflow::gtl::optional<int64>& all_reduce_id)
+ : HloInstruction(HloOpcode::kCrossReplicaSum, shape),
+ replica_group_ids_(replica_group_ids.begin(), replica_group_ids.end()),
+ cross_replica_sum_barrier_(barrier.begin(), barrier.end()),
+ all_reduce_id_(all_reduce_id) {
+ // TODO(b/79737069): Remove the CHECK when supported.
+ CHECK(!all_reduce_id_.has_value());
+ for (auto operand : operands) {
+ AppendOperand(operand);
+ }
+ AppendComputation(reduce_computation);
+}
+
+HloInstructionProto HloAllReduceInstruction::ToProto() const {
+ HloInstructionProto proto = HloInstruction::ToProto();
+ for (int64 i : replica_group_ids_) {
+ proto.add_replica_group_ids(i);
+ }
+ // TODO(b/79737069): handle barrier and all_reduce_id.
+ return proto;
+}
+
+std::vector<string> HloAllReduceInstruction::ExtraAttributesToStringImpl(
+ const HloPrintOptions& /*options*/) const {
+ std::vector<string> result = {
+ StrCat("replica_group_ids={", Join(replica_group_ids(), ","), "}")};
+ if (!cross_replica_sum_barrier().empty()) {
+ result.push_back(StrCat("barrier=\"", cross_replica_sum_barrier(), "\""));
+ }
+ if (all_reduce_id_.has_value()) {
+ result.push_back(StrCat("all_reduce_id=", *all_reduce_id_));
+ }
+ return result;
+}
+
+bool HloAllReduceInstruction::IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const {
+ const auto& casted_other = static_cast<const HloAllReduceInstruction&>(other);
+ return replica_group_ids() == casted_other.replica_group_ids() &&
+ eq_computations(to_apply(), casted_other.to_apply()) &&
+ cross_replica_sum_barrier() ==
+ casted_other.cross_replica_sum_barrier() &&
+ all_reduce_id() == casted_other.all_reduce_id();
+}
+
+std::unique_ptr<HloInstruction>
+HloAllReduceInstruction::CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* /*context*/) const {
+ return MakeUnique<HloAllReduceInstruction>(
+ shape, new_operands, to_apply(), replica_group_ids(),
+ cross_replica_sum_barrier(), all_reduce_id());
+}
+
HloReverseInstruction::HloReverseInstruction(
const Shape& shape, HloInstruction* operand,
tensorflow::gtl::ArraySlice<int64> dimensions)
@@ -609,9 +672,14 @@ HloConstantInstruction::HloConstantInstruction(std::unique_ptr<Literal> literal)
: HloInstruction(HloOpcode::kConstant, CHECK_NOTNULL(literal)->shape()),
literal_(std::move(literal)) {}
+HloConstantInstruction::HloConstantInstruction(const Shape& shape)
+ : HloInstruction(HloOpcode::kConstant, shape) {}
+
HloInstructionProto HloConstantInstruction::ToProto() const {
HloInstructionProto proto = HloInstruction::ToProto();
- *proto.mutable_literal() = literal_->ToProto();
+ if (literal_ != nullptr) {
+ *proto.mutable_literal() = literal_->ToProto();
+ }
return proto;
}
@@ -657,8 +725,9 @@ string HloConstantInstruction::OperandsToStringWithCanonicalNameMap(
CanonicalNameMap* canonical_name_map) const {
string operands;
// For constants, show the actual value in place of an empty operand list.
- if ((!ShapeUtil::IsTuple(shape()) && ShapeUtil::ElementsIn(shape()) <= 10) ||
- options.print_large_constants()) {
+ if (literal_ != nullptr &&
+ ((ShapeUtil::IsArray(shape()) && ShapeUtil::ElementsIn(shape()) <= 10) ||
+ options.print_large_constants())) {
// Literal::ToString emits multidimensional arrays over multiple
// lines. Compact this into one line by stripping out white space.
string tmp = literal().ToString();
@@ -830,10 +899,8 @@ void HloFusionInstruction::MergeFusionInstruction(
// Fuse 'unfused_instructions' into 'this'.
for (auto& instruction : unfused_instructions) {
FuseInstruction(instruction);
- instruction->DetachFromOperands();
}
CHECK_EQ(0, cloned_fusion->user_count());
- cloned_fusion->DetachFromOperands();
TF_CHECK_OK(parent()->parent()->RemoveEmbeddedComputation(
cloned_fusion->fused_instructions_computation()));
}
@@ -1284,4 +1351,82 @@ HloReducePrecisionInstruction::CloneWithNewOperandsImpl(
shape, new_operands[0], exponent_bits(), mantissa_bits());
}
+HloInfeedInstruction::HloInfeedInstruction(const Shape& shape,
+ const string& config)
+ : HloInstruction(HloOpcode::kInfeed, shape), infeed_config_(config) {}
+
+HloInstructionProto HloInfeedInstruction::ToProto() const {
+ HloInstructionProto proto = HloInstruction::ToProto();
+ proto.set_infeed_config(infeed_config_);
+ return proto;
+}
+
+std::vector<string> HloInfeedInstruction::ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const {
+ if (infeed_config_.empty()) {
+ return {};
+ }
+ return {StrCat("infeed_config=\"", CEscape(infeed_config_), "\"")};
+}
+
+bool HloInfeedInstruction::IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const {
+ // Not yet supported.
+ return false;
+}
+
+std::unique_ptr<HloInstruction> HloInfeedInstruction::CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const {
+ CHECK_EQ(new_operands.size(), 0);
+ return MakeUnique<HloInfeedInstruction>(shape, infeed_config());
+}
+
+HloOutfeedInstruction::HloOutfeedInstruction(
+ const Shape& shape, HloInstruction* operand,
+ tensorflow::StringPiece outfeed_config)
+ : HloInstruction(HloOpcode::kOutfeed, ShapeUtil::MakeNil()),
+ outfeed_shape_(shape),
+ outfeed_config_(outfeed_config.begin(), outfeed_config.end()) {
+ CHECK(ShapeUtil::Compatible(operand->shape(), shape))
+ << "Outfeed shape " << shape << " must be compatible with operand shape "
+ << operand->shape();
+ AppendOperand(operand);
+}
+
+HloInstructionProto HloOutfeedInstruction::ToProto() const {
+ HloInstructionProto proto = HloInstruction::ToProto();
+ proto.set_outfeed_config(outfeed_config());
+ *proto.mutable_outfeed_shape() = outfeed_shape();
+ return proto;
+}
+
+std::vector<string> HloOutfeedInstruction::ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const {
+ if (outfeed_config_.empty()) {
+ return {};
+ }
+ return {StrCat("outfeed_config=\"", CEscape(outfeed_config_), "\"")};
+}
+
+bool HloOutfeedInstruction::IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const {
+ // Not yet supported.
+ return false;
+}
+
+std::unique_ptr<HloInstruction> HloOutfeedInstruction::CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const {
+ CHECK_EQ(new_operands.size(), 1);
+ return MakeUnique<HloOutfeedInstruction>(outfeed_shape(), new_operands[0],
+ outfeed_config());
+}
+
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h
index 6749d87555..04df2d860e 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.h
+++ b/tensorflow/compiler/xla/service/hlo_instructions.h
@@ -207,6 +207,63 @@ class HloRecvDoneInstruction : public HloSendRecvInstruction {
HloCloneContext* context) const override;
};
+class HloAllReduceInstruction : public HloInstruction {
+ public:
+ explicit HloAllReduceInstruction(
+ const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ HloComputation* reduce_computation,
+ tensorflow::gtl::ArraySlice<int64> replica_group_ids,
+ tensorflow::StringPiece barrier,
+ const tensorflow::gtl::optional<int64>& all_reduce_id =
+ tensorflow::gtl::nullopt);
+
+ // Returns the group ids of each replica for CrossReplicaSum op.
+ const std::vector<int64>& replica_group_ids() const {
+ return replica_group_ids_;
+ }
+
+ // Returns the barrier config used for the CrossReplicaSum implementation of
+ // each backend.
+ string cross_replica_sum_barrier() const {
+ return cross_replica_sum_barrier_;
+ }
+ void set_cross_replica_sum_barrier(string barrier) {
+ cross_replica_sum_barrier_ = barrier;
+ }
+
+ tensorflow::gtl::optional<int64> all_reduce_id() const {
+ return all_reduce_id_;
+ }
+
+ // Returns a serialized representation of this instruction.
+ HloInstructionProto ToProto() const override;
+
+ private:
+ std::vector<string> ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const override;
+ bool IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const override;
+
+ // Implementation for non-common logic of CloneWithNewOperands.
+ std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const override;
+
+ // The group id of each replica for CrossReplicaSum.
+ std::vector<int64> replica_group_ids_;
+
+ // The string representation of the barrier config used for CrossReplicaSum.
+ string cross_replica_sum_barrier_;
+
+  // All-reduce nodes from different modules that have the same all_reduce_id
+  // will be 'all-reduced' together. If unset, the all-reduce is not applied
+  // across modules.
+ tensorflow::gtl::optional<int64> all_reduce_id_;
+};
+
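A construction sketch for the new subclass; `operand` and `add_computation` are assumed to come from the surrounding builder code, and all_reduce_id is left unset because the constructor currently CHECKs that it has no value:

// replica_group_ids {0,0,1,1} splits four replicas into two subgroups of two.
auto crs = MakeUnique<HloAllReduceInstruction>(
    operand->shape(), /*operands=*/{operand},
    /*reduce_computation=*/add_computation,
    /*replica_group_ids=*/{0, 0, 1, 1},
    /*barrier=*/"",
    /*all_reduce_id=*/tensorflow::gtl::nullopt);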
class HloReverseInstruction : public HloInstruction {
public:
explicit HloReverseInstruction(const Shape& shape, HloInstruction* operand,
@@ -436,6 +493,8 @@ class HloSliceInstruction : public HloInstruction {
class HloConstantInstruction : public HloInstruction {
public:
explicit HloConstantInstruction(std::unique_ptr<Literal> literal);
+  // Used when the literal has been dropped because it is too large.
+ explicit HloConstantInstruction(const Shape& shape);
// Returns the literal associated with this instruction.
const Literal& literal() const { return *literal_; }
// Returns a serialized representation of this instruction.
@@ -722,6 +781,67 @@ class HloReducePrecisionInstruction : public HloInstruction {
int32 exponent_bits_ = 0;
int32 mantissa_bits_ = 0;
};
+
+class HloInfeedInstruction : public HloInstruction {
+ public:
+ explicit HloInfeedInstruction(const Shape& shape, const string& config);
+ // Returns the infeed configuration string. The infeed configuration includes
+ // any metadata needed for the backend compiler (e.g., infeed buffer address)
+ // and is target-dependent.
+ string infeed_config() const { return infeed_config_; }
+ void set_infeed_config(const string& config) { infeed_config_ = config; }
+ // Returns a serialized representation of this instruction.
+ HloInstructionProto ToProto() const override;
+
+ private:
+ std::vector<string> ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const override;
+ bool IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const override;
+ // Implementation for non-common logic of CloneWithNewOperands.
+ std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const override;
+
+ // The string representation of the infeed configuration.
+ string infeed_config_;
+};
+
+class HloOutfeedInstruction : public HloInstruction {
+ public:
+ explicit HloOutfeedInstruction(const Shape& shape, HloInstruction* operand,
+ tensorflow::StringPiece outfeed_config);
+ // Returns the shape for the Outfeed instruction.
+ const Shape& outfeed_shape() const {
+ TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape()));
+ return outfeed_shape_;
+ }
+ // Returns the config for the Outfeed instruction.
+ const string& outfeed_config() const { return outfeed_config_; }
+ // Returns a serialized representation of this instruction.
+ HloInstructionProto ToProto() const override;
+
+ private:
+ std::vector<string> ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const override;
+ bool IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const override;
+ // Implementation for non-common logic of CloneWithNewOperands.
+ std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const override;
+
+ // Shape of outfeed request.
+ Shape outfeed_shape_;
+ // Outfeed configuration information, only present for kOutfeed.
+ string outfeed_config_;
+};
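A small sketch of the shape contract for the outfeed subclass, assuming `data` is an HloInstruction* producing the value to outfeed: the instruction's own result shape is nil, while outfeed_shape() records the shape of the outfed data.

auto outfeed = MakeUnique<HloOutfeedInstruction>(
    /*shape=*/data->shape(), /*operand=*/data, /*outfeed_config=*/"");
// The result shape is the nil tuple; the outfed shape matches the operand.
CHECK(ShapeUtil::Compatible(outfeed->shape(), ShapeUtil::MakeNil()));
CHECK(ShapeUtil::Compatible(outfeed->outfeed_shape(), data->shape()));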
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_
diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h
index c570b420c2..8a31a8e617 100644
--- a/tensorflow/compiler/xla/service/hlo_matchers.h
+++ b/tensorflow/compiler/xla/service/hlo_matchers.h
@@ -187,6 +187,7 @@ HLO_MATCHER(Exp);
HLO_MATCHER(Floor);
HLO_MATCHER(Fusion);
HLO_MATCHER(Ge);
+HLO_MATCHER(GenerateToken);
HLO_MATCHER(Gt);
HLO_MATCHER(Infeed);
HLO_MATCHER(IsFinite);
diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc
index dcd4725fe7..6c1e015f77 100644
--- a/tensorflow/compiler/xla/service/hlo_ordering.cc
+++ b/tensorflow/compiler/xla/service/hlo_ordering.cc
@@ -232,6 +232,11 @@ bool HloOrdering::UseIsBeforeValueDefinition(
<< " and def is in FALSE computation";
return true;
}
+ if (value.defining_instruction() == use.instruction) {
+ VLOG(4) << " use is conditional " << use << " and def is "
+ << value.ToShortString();
+ return true;
+ }
}
VLOG(4) << " use is not before value";
diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc
index f834d34d57..d551400d1e 100644
--- a/tensorflow/compiler/xla/service/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc
@@ -913,7 +913,7 @@ add {
ENTRY CRS {
input = f32[8]{0} parameter(0)
- ROOT crs = f32[8]{0} cross-replica-sum(input), to_apply=add
+ ROOT crs = f32[8]{0} cross-replica-sum(input), replica_group_ids={}, to_apply=add
}
)"
@@ -931,7 +931,7 @@ add {
ENTRY CrossReplicaSumWithSubgroups {
input = f32[128,32]{0,1} parameter(0)
- ROOT cross-replica-sum = f32[128,32]{0,1} cross-replica-sum(input), to_apply=add, replica_group_ids={0,0,1,1}, barrier="abc"
+ ROOT cross-replica-sum = f32[128,32]{0,1} cross-replica-sum(input), replica_group_ids={0,0,1,1}, barrier="abc", to_apply=add
}
)"
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
index 9c7bc7a5ea..62c07d7fac 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/buffer_value.h"
+#include "tensorflow/compiler/xla/service/copy_insertion.h"
#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_dce.h"
@@ -1201,7 +1202,8 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
StatusOr<bool> HloRematerialization::Run(
HloModule* module, SequentialHloOrdering::HloModuleSequence* sequence,
- int64 memory_limit_bytes, RematerializationSizes* sizes) {
+ int64 memory_limit_bytes, RematerializationSizes* sizes,
+ bool run_copy_elision) {
// The sequence is constructed entirely by this method.
TF_RET_CHECK(sequence->empty());
@@ -1236,6 +1238,15 @@ StatusOr<bool> HloRematerialization::Run(
return size_function_(buffer.shape());
},
scheduler_algorithm_));
+ if (run_copy_elision) {
+ // We run a separate pass of copy elision here because the sequential
+ // ordering from the HLO schedule allows for more copies to be eliminated.
+ // TODO(b/80249101): Instead of a separate copy elision pass, use the
+ // ordering from the HLO schedule directly for copy insertion.
+ SequentialHloOrdering ordering(module, *sequence);
+ TF_RETURN_IF_ERROR(RemoveUnnecessaryCopies(ordering, {}, module));
+ }
+
// Compute peak memory usage of all computations in the module called in a
// sequential context.
call_graph_ = CallGraph::Build(module);
@@ -1338,9 +1349,10 @@ StatusOr<bool> HloRematerialization::Run(
int64 memory_limit_bytes, HloModule* hlo_module,
MemorySchedulerAlgorithm scheduler_algorithm,
SequentialHloOrdering::HloModuleSequence* sequence,
- RematerializationSizes* sizes) {
+ RematerializationSizes* sizes, bool run_copy_elision) {
HloRematerialization remat(scheduler_algorithm, size_function);
- return remat.Run(hlo_module, sequence, memory_limit_bytes, sizes);
+ return remat.Run(hlo_module, sequence, memory_limit_bytes, sizes,
+ run_copy_elision);
}
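A sketch of a call site under the widened signature, inside a Status-returning function; `module` and `ByteSizeOf` are assumptions, and passing /*run_copy_elision=*/false skips the extra elision pass:

SequentialHloOrdering::HloModuleSequence sequence;
TF_ASSIGN_OR_RETURN(
    bool changed,
    HloRematerialization::RematerializeAndSchedule(
        ByteSizeOf, /*memory_limit_bytes=*/1LL << 30, module.get(),
        DefaultMemoryScheduler, &sequence, /*sizes=*/nullptr,
        /*run_copy_elision=*/false));
(void)changed;  // Whether any instruction was rematerialized.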
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.h b/tensorflow/compiler/xla/service/hlo_rematerialization.h
index 2ee2dd0571..59b4cf5dcc 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization.h
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization.h
@@ -57,6 +57,12 @@ class HloRematerialization {
// sizes: Optional outparam that indicates the peak memory usage of the HLO
// module before/after rematerialization.
//
+ // run_copy_elision: Enable copy elision. This pass is used to eliminate
+ // copies that were inserted before HLO scheduling.
+ //
+ // TODO(b/80249101): Remove the 'run_copy_elision' parameter when copy
+ // insertion is integrated with HLO scheduling.
+ //
// Returns whether any instructions were rematerialized. If memory use is
// already below the given limit then no instructions are rematerialized and
// false is returned.
@@ -68,7 +74,7 @@ class HloRematerialization {
const ShapeSizeFunction& size_function, int64 memory_limit_bytes,
HloModule* hlo_module, MemorySchedulerAlgorithm scheduler_algorithm,
SequentialHloOrdering::HloModuleSequence* sequence,
- RematerializationSizes* sizes = nullptr);
+ RematerializationSizes* sizes, bool run_copy_elision = true);
protected:
HloRematerialization(MemorySchedulerAlgorithm scheduler_algorithm,
@@ -83,7 +89,8 @@ class HloRematerialization {
// contains the memory-minimizing order in which to emit the HLO instructions.
StatusOr<bool> Run(HloModule* module,
SequentialHloOrdering::HloModuleSequence* sequence,
- int64 memory_limit, RematerializationSizes* sizes);
+ int64 memory_limit, RematerializationSizes* sizes,
+ bool run_copy_elision);
// Rematerializes instructions within the given computation. 'order' is the
// order in which the computation's instructions will be emitted in the
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc
index e81334d5a8..7a46da6efe 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc
@@ -147,7 +147,7 @@ class HloRematerializationTest : public HloTestBase {
TF_EXPECT_OK(verifier().Run(module).status());
return HloRematerialization::RematerializeAndSchedule(
ByteSizeOf, memory_limit_bytes, module, DefaultMemoryScheduler,
- sequence);
+ sequence, /*sizes=*/nullptr, /*run_copy_elision=*/false);
}
// Various shapes used in the canned computations.
diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.cc b/tensorflow/compiler/xla/service/hlo_scheduling.cc
index b14ade3549..641b9ecec9 100644
--- a/tensorflow/compiler/xla/service/hlo_scheduling.cc
+++ b/tensorflow/compiler/xla/service/hlo_scheduling.cc
@@ -375,7 +375,7 @@ int64 SumLogicalBufferSizes(
return size;
}
-StatusOr<std::vector<const HloInstruction*>> ScheduleComputationsInModule(
+StatusOr<std::vector<const HloInstruction*>> ScheduleComputationHelper(
const HloComputation& computation,
const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function,
@@ -498,29 +498,29 @@ StatusOr<std::vector<const HloInstruction*>> DefaultMemoryScheduler(
std::vector<const HloInstruction*> list_sequence,
ListMemoryScheduler(computation, points_to_analysis, size_function,
memory_by_computation));
- TF_ASSIGN_OR_RETURN(
- const int64 list_memory,
- MinimumMemoryForComputation(computation, list_sequence,
- points_to_analysis, size_function));
+ TF_ASSIGN_OR_RETURN(const int64 list_memory,
+ HeapSimulator::MinimumMemoryForComputation(
+ computation, list_sequence, points_to_analysis,
+ size_function, &memory_by_computation));
VLOG(2) << "Min-memory list sequence: " << HumanReadableNumBytes(list_memory);
TF_ASSIGN_OR_RETURN(std::vector<const HloInstruction*> dfs_sequence,
DFSMemoryScheduler(computation, points_to_analysis,
size_function, memory_by_computation));
- TF_ASSIGN_OR_RETURN(
- const int64 dfs_memory,
- MinimumMemoryForComputation(computation, dfs_sequence, points_to_analysis,
- size_function));
+ TF_ASSIGN_OR_RETURN(const int64 dfs_memory,
+ HeapSimulator::MinimumMemoryForComputation(
+ computation, dfs_sequence, points_to_analysis,
+ size_function, &memory_by_computation));
VLOG(2) << "Min-memory dfs sequence: " << HumanReadableNumBytes(dfs_memory);
TF_ASSIGN_OR_RETURN(
std::vector<const HloInstruction*> post_order_sequence,
PostOrderMemoryScheduler(computation, points_to_analysis, size_function,
memory_by_computation));
- TF_ASSIGN_OR_RETURN(
- const int64 post_order_memory,
- MinimumMemoryForComputation(computation, post_order_sequence,
- points_to_analysis, size_function));
+ TF_ASSIGN_OR_RETURN(const int64 post_order_memory,
+ HeapSimulator::MinimumMemoryForComputation(
+ computation, post_order_sequence, points_to_analysis,
+ size_function, &memory_by_computation));
VLOG(2) << "Min-memory post order sequence: "
<< HumanReadableNumBytes(post_order_memory);
@@ -551,12 +551,13 @@ StatusOr<SequentialHloOrdering::HloModuleSequence> ScheduleComputationsInModule(
for (const auto* computation : module.MakeComputationPostOrder()) {
if (!computation->IsFusionComputation()) {
TF_ASSIGN_OR_RETURN(auto one_computation_sequence,
- ScheduleComputationsInModule(
+ ScheduleComputationHelper(
*computation, *points_to_analysis, size_function,
algorithm, memory_by_computation));
memory_by_computation[computation] =
- MinimumMemoryForComputation(*computation, one_computation_sequence,
- *points_to_analysis, size_function)
+ HeapSimulator::MinimumMemoryForComputation(
+ *computation, one_computation_sequence, *points_to_analysis,
+ size_function, &memory_by_computation)
.ValueOrDie();
sequence[computation] = std::move(one_computation_sequence);
}
@@ -571,8 +572,8 @@ StatusOr<std::vector<const HloInstruction*>> ScheduleOneComputation(
TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis,
TuplePointsToAnalysis::Run(computation.parent()));
tensorflow::gtl::FlatMap<const HloComputation*, int64> empty_map;
- return ScheduleComputationsInModule(computation, *points_to_analysis,
- size_function, nullptr, empty_map);
+ return ScheduleComputationHelper(computation, *points_to_analysis,
+ size_function, nullptr, empty_map);
}
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc
index 6f1b1215d3..73f22f81f4 100644
--- a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <memory>
#include <string>
+#include "tensorflow/compiler/xla/service/heap_simulator.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
@@ -144,7 +145,7 @@ TEST_F(HloSchedulingTest, ListAccountsForSubcomputations) {
// ROOT %subtract = f32[4]{0} subtract(
// f32[4]{0} %body_param, f32[1,4]{1,0} %constant.1)
// }
- // %SubcomputationsNotAccounted () -> f32[2,4] {
+ // %ListAccountsForSubcomputations () -> f32[2,4] {
// %constant.3 = f32[2,4]{1,0} constant(
// f32[2,4] { { 1, 2, 3, 4 }, { 1, 2, 3, 4 } })
// %transpose = f32[2,4]{1,0} transpose(
@@ -210,16 +211,16 @@ TEST_F(HloSchedulingTest, ListAccountsForSubcomputations) {
module->AddEntryComputation(builder.Build());
- TF_ASSERT_OK_AND_ASSIGN(SequentialHloOrdering::HloModuleSequence sequence,
- ScheduleComputationsInModule(
- *module,
- [](const BufferValue& buffer) {
- return ShapeUtil::ByteSizeOf(buffer.shape());
- },
- ListMemoryScheduler));
+ auto size_fn = [](const BufferValue& buffer) {
+ return ShapeUtil::ByteSizeOf(buffer.shape());
+ };
+ TF_ASSERT_OK_AND_ASSIGN(
+ SequentialHloOrdering::HloModuleSequence sequence,
+ ScheduleComputationsInModule(*module, size_fn, ListMemoryScheduler));
// Verify that all instructions are in the sequence.
- EXPECT_EQ(module->entry_computation()->instruction_count(),
- sequence.at(module->entry_computation()).size());
+ auto entry_computation = module->entry_computation();
+ EXPECT_EQ(entry_computation->instruction_count(),
+ sequence.at(entry_computation).size());
SequentialHloOrdering ordering(module.get(), sequence);
// This schedule is an example of List's greedy heuristics being suboptimal.
// The while_loop is more expensive than transpose, so it would have been
@@ -228,6 +229,24 @@ TEST_F(HloSchedulingTest, ListAccountsForSubcomputations) {
EXPECT_TRUE(ordering.ExecutesBefore(transpose, bcast));
EXPECT_TRUE(ordering.ExecutesBefore(bcast, add));
EXPECT_TRUE(ordering.ExecutesBefore(transpose, add));
+
+ tensorflow::gtl::FlatMap<const HloComputation*, int64> memory_by_computation;
+ memory_by_computation[cond_computation] = 17;
+ memory_by_computation[body_computation] = 16;
+ std::unique_ptr<TuplePointsToAnalysis> points_to_analysis =
+ TuplePointsToAnalysis::Run(module.get()).ValueOrDie();
+
+ // HeapSimulator doesn't account for subcomputations
+ EXPECT_EQ(80, HeapSimulator::MinimumMemoryForComputation(
+ *entry_computation, sequence.at(entry_computation),
+ *points_to_analysis, size_fn)
+ .ValueOrDie());
+ // HeapSimulator accounts for subcomputations. The max mem doesn't change
+ // because the while body isn't live during the peak.
+ EXPECT_EQ(80, HeapSimulator::MinimumMemoryForComputation(
+ *entry_computation, sequence.at(entry_computation),
+ *points_to_analysis, size_fn, &memory_by_computation)
+ .ValueOrDie());
}
TEST_F(HloSchedulingTest, TuplesAreAccountedCorrectly) {
@@ -325,5 +344,70 @@ TEST_F(HloSchedulingTest, MultiOutputFusionAccountedCorrectly) {
EXPECT_TRUE(ordering.ExecutesBefore(exp, fusion));
}
+TEST_F(HloSchedulingTest, HeapSimulatorAccountsForSubcomputations) {
+ auto module = CreateNewModule();
+ const Shape r1f32 = ShapeUtil::MakeShape(F32, {4});
+ const Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 4});
+
+ // param != 0
+ // Needs 17 bytes
+ auto cond_builder = HloComputation::Builder("WhileCond");
+ HloInstruction* cond_param = cond_builder.AddInstruction(
+ HloInstruction::CreateParameter(0, r1f32, "cond_param"));
+ HloInstruction* zero_vector = cond_builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR2<float>({{0, 0, 0, 0}})));
+ cond_builder.AddInstruction(HloInstruction::CreateBinary(
+ ShapeUtil::MakeShape(PRED, {}), HloOpcode::kNe, cond_param, zero_vector));
+ auto cond_computation = module->AddEmbeddedComputation(cond_builder.Build());
+
+ // param - 1
+ // Needs 16 bytes
+ auto body_builder = HloComputation::Builder("WhileBody");
+ HloInstruction* body_param = body_builder.AddInstruction(
+ HloInstruction::CreateParameter(0, r1f32, "body_param"));
+ HloInstruction* one_vector = body_builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR2<float>({{1, 1, 1, 1}})));
+ body_builder.AddInstruction(HloInstruction::CreateBinary(
+ r1f32, HloOpcode::kSubtract, body_param, one_vector));
+ auto body_computation = module->AddEmbeddedComputation(body_builder.Build());
+
+ auto builder = HloComputation::Builder(TestName());
+ HloInstruction* while_init = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR2<float>({{1, 1, 1, 1}})));
+ // Creates 16 bytes, ignoring subcomputations
+ builder.AddInstruction(HloInstruction::CreateWhile(
+ r1f32, cond_computation, body_computation, while_init));
+
+ module->AddEntryComputation(builder.Build());
+
+ auto size_fn = [](const BufferValue& buffer) {
+ return ShapeUtil::ByteSizeOf(buffer.shape());
+ };
+ TF_ASSERT_OK_AND_ASSIGN(
+ SequentialHloOrdering::HloModuleSequence sequence,
+ ScheduleComputationsInModule(*module, size_fn, ListMemoryScheduler));
+ // Verify that all instructions are in the sequence.
+ auto entry_computation = module->entry_computation();
+ EXPECT_EQ(entry_computation->instruction_count(),
+ sequence.at(entry_computation).size());
+
+ tensorflow::gtl::FlatMap<const HloComputation*, int64> memory_by_computation;
+ memory_by_computation[cond_computation] = 17;
+ memory_by_computation[body_computation] = 16;
+ std::unique_ptr<TuplePointsToAnalysis> points_to_analysis =
+ TuplePointsToAnalysis::Run(module.get()).ValueOrDie();
+
+ // HeapSimulator doesn't account for subcomputations
+ EXPECT_EQ(16, HeapSimulator::MinimumMemoryForComputation(
+ *entry_computation, sequence.at(entry_computation),
+ *points_to_analysis, size_fn)
+ .ValueOrDie());
+ // HeapSimulator accounts for subcomputations
+ EXPECT_EQ(33, HeapSimulator::MinimumMemoryForComputation(
+ *entry_computation, sequence.at(entry_computation),
+ *points_to_analysis, size_fn, &memory_by_computation)
+ .ValueOrDie());
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc
index 7b4b071af4..748273a43c 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc
@@ -235,6 +235,23 @@ StatusOr<int64> ApplyDomainShardingPass(const DomainMetadata::Domain& domain,
Status ApplyDomainSharding(const DomainMetadata::Domain& domain,
const HloSharding& sharding) {
+ // Here is the place to call external sharding normalizers, which are
+  // implemented in other modules (i.e., spatial partitioning).
+ // The signature of the external normalizer function should be something
+ // like:
+ //
+ // StatusOr<bool> Normalizer(const DomainMetadata::Domain&,
+ // const HloSharding& sharding);
+ //
+  // The function should return true if it processed the domain normalization,
+  // false if the domain was not one it recognized, or an error status.
+ // We will call the functions in order below, and fall back to local code if
+ // none of the external normalizers acted on the domain.
+ // External normalizers should not handle the cases that are already handled
+ // locally.
+
+  // None of the external normalizers handled the domain sharding; check
+  // whether this is a single sharding first.
auto single_sharding = sharding.ExtractSingleSharding();
if (single_sharding) {
// Shortcut the simple case. We have a unique sharding, so we call
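A sketch of an external normalizer with the signature described in the comment above; the function name is hypothetical and nothing here is wired into the pass:

// Recognizes no domains, so callers would always fall back to the local code.
StatusOr<bool> ExampleShardingNormalizer(const DomainMetadata::Domain& domain,
                                         const HloSharding& sharding) {
  // A real normalizer would inspect `domain` and `sharding` here and return
  // true once it has applied its own normalization.
  return false;
}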
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index 9034073cc8..1d6cd4cb23 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -431,7 +431,8 @@ Status ShapeVerifier::HandleGenerateToken(HloInstruction* token) {
for (const HloInstruction* operand : token->operands()) {
operand_shapes.push_back(&operand->shape());
}
- return CheckShape(token, ShapeInference::InferTokenShape(operand_shapes));
+ return CheckShape(token,
+ ShapeInference::InferGenerateTokenShape(operand_shapes));
}
Status ShapeVerifier::CheckShape(const HloInstruction* instruction,
diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc
index 7067b6f86a..eb469e77a0 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment.cc
@@ -937,6 +937,11 @@ LayoutAssignment::LayoutAssignment(
ChannelLayoutConstraints* channel_constraints)
: entry_computation_layout_(entry_computation_layout),
channel_layout_constraints_(channel_constraints) {
+ if (channel_layout_constraints_ != nullptr) {
+ // Save a copy of the input ChannelLayoutConstraints so that we can reset it
+ // if we have to undo previous operations (ClearPreviousPassSideEffects()).
+ channel_constraints_ = *channel_layout_constraints_;
+ }
VLOG(1) << "Entry computation layout given to layout assignment: "
<< entry_computation_layout_->ToString();
// Layouts of all parameter instructions must be set.
@@ -1614,13 +1619,57 @@ Status LayoutAssignment::RunOnComputation(
// Record the layouts assigned for any communication ops in
// channel_constraints so that they are constrained for future modules.
+ if (channel_constraints != nullptr) {
+ TF_RETURN_IF_ERROR(
+ ConstrainChannelLayouts(computation, channel_constraints));
+ }
+ return Status::OK();
+}
+
+Status LayoutAssignment::ConstrainChannelLayouts(
+ HloComputation* computation,
+ ChannelLayoutConstraints* channel_constraints) {
+  // We go through the kRecvDone instructions first. These must either impose
+  // their layout, or find a matching one already set (ConstrainChannel()
+  // returns nullptr).
for (HloInstruction* instruction : computation->instructions()) {
+ if (instruction->opcode() == HloOpcode::kRecvDone) {
+ const Layout* layout = channel_constraints->ConstrainChannel(
+ instruction->channel_id(), instruction->shape().layout());
+ TF_RET_CHECK(layout == nullptr)
+ << instruction->ToString()
+ << " cannot constrain layout as it was set to "
+ << LayoutUtil::HumanString(*layout);
+ }
+ }
+  // Next we go through the kSend instructions. These are likely to have a
+  // kCopy as operand (otherwise we add one), so if the constrained layout does
+  // not match, we can change the kCopy layout (and the kSend one as well).
+ for (HloInstruction* instruction : computation->MakeInstructionPostOrder()) {
if (instruction->opcode() == HloOpcode::kSend) {
- channel_constraints->ConstrainChannel(
- instruction->channel_id(), instruction->operand(0)->shape().layout());
- } else if (instruction->opcode() == HloOpcode::kRecvDone) {
- channel_constraints->ConstrainChannel(instruction->channel_id(),
- instruction->shape().layout());
+ HloInstruction* operand = instruction->mutable_operand(0);
+ const Layout* layout = channel_constraints->ConstrainChannel(
+ instruction->channel_id(), operand->shape().layout());
+ if (layout != nullptr) {
+ // We found an already constrained layout which does not match the one
+        // the kSend wants to impose. Either add a new kCopy, or use the
+ // existing one to marshal the correct shape.
+ Shape shape = operand->shape();
+ *shape.mutable_layout() = *layout;
+ if (operand->opcode() != HloOpcode::kCopy) {
+ HloInstruction* copy = operand->parent()->AddInstruction(
+ HloInstruction::CreateUnary(shape, HloOpcode::kCopy, operand));
+ RegisterAddedCopy(copy);
+ SetupCopiedInstruction(*operand, copy, {});
+ TF_RETURN_IF_ERROR(instruction->ReplaceOperandWith(0, copy));
+ operand = copy;
+ } else {
+ *operand->mutable_shape() = shape;
+ }
+ Shape* send_shape =
+ ShapeUtil::GetMutableSubshape(instruction->mutable_shape(), {0});
+ *send_shape = shape;
+ }
}
}
return Status::OK();
@@ -1743,6 +1792,7 @@ Status LayoutAssignment::ClearPreviousPassSideEffects(HloModule* module) {
TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status());
TF_RETURN_IF_ERROR(dce.Run(module).status());
}
+ ResetChannelConstraints();
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/layout_assignment.h b/tensorflow/compiler/xla/service/layout_assignment.h
index c287cca0c5..eb4cd5936b 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.h
+++ b/tensorflow/compiler/xla/service/layout_assignment.h
@@ -249,25 +249,30 @@ class ChannelLayoutConstraints {
// Given `shape`, apply the layout for `channel_id`. `channel_id` must already
// be constrained.
Shape LayoutShapeForChannel(Shape shape, int64 channel_id) const {
- CHECK(IsChannelConstrained(channel_id));
- *shape.mutable_layout() = constraints_.at(channel_id);
+ auto it = constraints_.find(channel_id);
+ CHECK(it != constraints_.end()) << "Channel " << channel_id;
+ *shape.mutable_layout() = it->second;
return shape;
}
// Returns the layout constraint for `channel_id`, which must already be
// constrained.
- Layout LayoutForChannel(int64 channel_id) const {
- CHECK(IsChannelConstrained(channel_id));
- return constraints_.at(channel_id);
+ const Layout& LayoutForChannel(int64 channel_id) const {
+ auto it = constraints_.find(channel_id);
+ CHECK(it != constraints_.end()) << "Channel " << channel_id;
+ return it->second;
}
// Adds a new layout constraint for `channel_id`. If a constraint for
- // `channel_id` already exists, this operation requires that the new layout is
- // the same as the previously constrained layout.
- void ConstrainChannel(int64 channel_id, const Layout& layout) {
- CHECK(!IsChannelConstrained(channel_id) ||
- LayoutUtil::Equal(layout, constraints_[channel_id]));
- constraints_[channel_id] = layout;
+  // `channel_id` does not yet exist, or the existing constraint matches
+  // `layout`, this returns nullptr; otherwise it returns the layout which has
+  // already been set for the channel.
+ const Layout* ConstrainChannel(int64 channel_id, const Layout& layout) {
+ auto it = constraints_.emplace(std::make_pair(channel_id, layout));
+ if (it.second) {
+ return nullptr;
+ }
+ return LayoutUtil::Equal(layout, it.first->second) ? nullptr
+ : &it.first->second;
}
private:
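A short sketch of the new ConstrainChannel() contract, using two throwaway layouts built with LayoutUtil::MakeLayout:

ChannelLayoutConstraints constraints;
const Layout row_major = LayoutUtil::MakeLayout({1, 0});
const Layout col_major = LayoutUtil::MakeLayout({0, 1});
// The first constraint for a channel is always accepted.
CHECK(constraints.ConstrainChannel(/*channel_id=*/1, row_major) == nullptr);
// Re-constraining with an equal layout is a no-op and also returns nullptr.
CHECK(constraints.ConstrainChannel(/*channel_id=*/1, row_major) == nullptr);
// A mismatch returns the layout already registered for the channel.
const Layout* previous =
    constraints.ConstrainChannel(/*channel_id=*/1, col_major);
CHECK(previous != nullptr && LayoutUtil::Equal(*previous, row_major));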
@@ -464,6 +469,20 @@ class LayoutAssignment : public HloPassInterface {
// itself).
Status AddCopyForOperand(HloInstruction* instruction, int64 operand_number);
+  // Apply the channel layout constraints by populating the channel_constraints
+  // data structure passed in at constructor time. Adds copies if the two ends
+  // of a channel end up with different layouts.
+ Status ConstrainChannelLayouts(HloComputation* computation,
+ ChannelLayoutConstraints* channel_constraints);
+
+ // Resets the input ChannelLayoutConstraints to the original copy received
+ // from the constructor input.
+ void ResetChannelConstraints() {
+ if (channel_layout_constraints_ != nullptr) {
+ *channel_layout_constraints_ = channel_constraints_;
+ }
+ }
+
// Map containing the layouts of all computations assigned so
// far. Computations are handled in a topological sort where computations are
// handled before their caller instructions so the layouts of caller
@@ -474,7 +493,14 @@ class LayoutAssignment : public HloPassInterface {
// here.
tensorflow::gtl::FlatSet<HloInstruction*> added_copies_;
- ChannelLayoutConstraints* channel_layout_constraints_;
+ // The pointer to the channel layout constraints passed in with the
+ // constructor. If not nullptr, this is an input/output argument.
+ ChannelLayoutConstraints* channel_layout_constraints_ = nullptr;
+
+ // A copy of the input layout constraints used to reset the above pointer in
+ // case we have to undo operations due to the multiple passes over the
+ // computations/instructions.
+ ChannelLayoutConstraints channel_constraints_;
};
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc
index bf0448a676..62599b376a 100644
--- a/tensorflow/compiler/xla/service/layout_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc
@@ -52,10 +52,18 @@ using ::testing::ElementsAre;
class LayoutAssignmentTest : public HloTestBase {
protected:
void AssignLayouts(HloModule* module,
- ComputationLayout* entry_computation_layout) {
- LayoutAssignment layout_assignment(entry_computation_layout);
+ ComputationLayout* entry_computation_layout,
+ ChannelLayoutConstraints* channel_constraints = nullptr) {
+ LayoutAssignment layout_assignment(
+ entry_computation_layout, /*channel_constraints=*/channel_constraints);
EXPECT_IS_OK(layout_assignment.Run(module).status());
}
+
+ std::vector<int64> LayoutOf(HloModule* module, tensorflow::StringPiece name) {
+ auto minor_to_major =
+ FindInstruction(module, name)->shape().layout().minor_to_major();
+ return std::vector<int64>(minor_to_major.begin(), minor_to_major.end());
+ }
};
TEST_F(LayoutAssignmentTest, ComputationLayout) {
@@ -707,17 +715,10 @@ TEST_F(LayoutAssignmentTest, GTEInheritsLayoutFromOperand) {
LayoutUtil::MakeLayout({2, 1, 0}));
AssignLayouts(module.get(), &computation_layout);
- auto layout_of = [&](tensorflow::StringPiece name) {
- return FindInstruction(module.get(), name)
- ->shape()
- .layout()
- .minor_to_major();
- };
-
- EXPECT_THAT(layout_of("gte0"), ElementsAre(0, 1, 2));
- EXPECT_THAT(layout_of("gte1a"), ElementsAre(1, 2, 0));
- EXPECT_THAT(layout_of("gte1b"), ElementsAre(2, 0, 1));
- EXPECT_THAT(layout_of("fresult"), ElementsAre(2, 1, 0));
+ EXPECT_THAT(LayoutOf(module.get(), "gte0"), ElementsAre(0, 1, 2));
+ EXPECT_THAT(LayoutOf(module.get(), "gte1a"), ElementsAre(1, 2, 0));
+ EXPECT_THAT(LayoutOf(module.get(), "gte1b"), ElementsAre(2, 0, 1));
+ EXPECT_THAT(LayoutOf(module.get(), "fresult"), ElementsAre(2, 1, 0));
EXPECT_THAT(FindInstruction(module.get(), "gte1")
->shape()
.tuple_shapes(0)
@@ -816,5 +817,44 @@ TEST_F(LayoutAssignmentTest, InternalErrorOnBitcast) {
"Unexpected bitcast operation seen during layout assignment"));
}
+TEST_F(LayoutAssignmentTest, ChannelLayoutMismatch) {
+ // Pin non matching layouts to parameter and root.
+ const char* module_str = R"(
+ HloModule test_module
+
+ ENTRY entry_computation {
+ param = (f32[2,2]) parameter(0)
+ gte = f32[2,2] get-tuple-element(param), index=0
+ recv = (f32[2,2], u32[]) recv(), channel_id=1, sharding={maximal device=1}
+ ROOT recv-done = f32[2,2] recv-done(recv), channel_id=1,
+ sharding={maximal device=1}
+ send = (f32[2,2], u32[]) send(gte), channel_id=1,
+ sharding={maximal device=0}
+ send-done = () send-done(send), channel_id=1, sharding={maximal device=0}
+ }
+ )";
+
+ auto module = ParseHloString(module_str).ValueOrDie();
+ ComputationLayout computation_layout(
+ module->entry_computation()->ComputeProgramShape());
+ Shape param_shape = ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1})});
+ TF_ASSERT_OK(
+ computation_layout.mutable_parameter_layout(0)->CopyLayoutFromShape(
+ param_shape));
+ computation_layout.mutable_result_layout()->ResetLayout(
+ LayoutUtil::MakeLayout({1, 0}));
+
+ ChannelLayoutConstraints channel_constraints;
+ AssignLayouts(module.get(), &computation_layout, &channel_constraints);
+
+ EXPECT_THAT(LayoutOf(module.get(), "gte"), ElementsAre(0, 1));
+ EXPECT_THAT(LayoutOf(module.get(), "recv-done"), ElementsAre(1, 0));
+ EXPECT_TRUE(
+ ShapeUtil::Equal(ShapeUtil::GetSubshape(
+ FindInstruction(module.get(), "send")->shape(), {0}),
+ ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0})));
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc
index 21bca1d6be..f200a08a3c 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc
@@ -32,7 +32,8 @@ static const BufferAllocation* kParameterAllocation = new BufferAllocation(
LogicalBuffer::Color(0));
void AliasAnalysis::AddAliasingInformationToIrArray(const HloInstruction& hlo,
- llvm_ir::IrArray* array) {
+ llvm_ir::IrArray* array,
+ const ShapeIndex& index) {
BufferAllocation::Slice buffer_slice;
if (hlo.opcode() == HloOpcode::kParameter) {
// Parameters may alias with each other but may not alias with our temporary
@@ -40,7 +41,7 @@ void AliasAnalysis::AddAliasingInformationToIrArray(const HloInstruction& hlo,
buffer_slice = BufferAllocation::Slice(kParameterAllocation, 0, 0);
} else {
const std::set<BufferAllocation::Slice> slices =
- assignment_.GetAllSlices(&hlo, /*index=*/{});
+ assignment_.GetAllSlices(&hlo, index);
if (slices.empty() || slices.size() > 1) {
// Skip HLOs which don't have a buffer assigned or for which the
// buffer can't be determined statically. We cannot determine their
@@ -137,16 +138,18 @@ llvm::MDNode* AliasAnalysis::GetNoaliasMetadataForBuffer(
// 2. Operands of users of the given hlo.
// 3. Operands of the given hlo.
//
- // This set can be increased as we need. For now only consider top-level
- // buffers (index = {}) not buffers nested within the instruction's
- // operands/output which are not typically touched.
+  // This set can be extended as needed.
std::vector<const LogicalBuffer*> worklist;
auto add_buffers_to_worklist =
[&worklist, &assignment](const HloInstruction* instruction) {
- for (const LogicalBuffer* buffer :
- assignment.GetSourceBuffers(instruction, /*index=*/{})) {
- worklist.push_back(buffer);
- }
+ ShapeUtil::ForEachSubshape(
+ instruction->shape(),
+ [&](const Shape& /*shape*/, const ShapeIndex& index) {
+ for (const LogicalBuffer* buffer :
+ assignment.GetSourceBuffers(instruction, index)) {
+ worklist.push_back(buffer);
+ }
+ });
};
for (HloInstruction* user : hlo.users()) {
diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h
index 5244ac61e5..fe9eab93aa 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h
@@ -38,7 +38,8 @@ class AliasAnalysis {
// Augments IrArray with aliasing information.
void AddAliasingInformationToIrArray(const HloInstruction& hlo,
- llvm_ir::IrArray* array);
+ llvm_ir::IrArray* array,
+ const ShapeIndex& index = {});
private:
// Returns a unique alias domain for this emitter.
diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc
index ff64da87e9..d18c9dee82 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc
@@ -193,6 +193,10 @@ llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type,
// An Opaque is like a void*, use i8*.
case OPAQUE:
return llvm::Type::getInt8PtrTy(module->getContext());
+ case TOKEN:
+ // Tokens do not have a physical representation, but the compiler needs
+ // some placeholder type, so use int8*.
+ return llvm::Type::getInt8PtrTy(module->getContext());
default:
LOG(FATAL) << "unsupported type " << element_type;
}
diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.cc b/tensorflow/compiler/xla/service/multi_output_fusion.cc
index 29f787b86b..f9f9c7dcf7 100644
--- a/tensorflow/compiler/xla/service/multi_output_fusion.cc
+++ b/tensorflow/compiler/xla/service/multi_output_fusion.cc
@@ -151,6 +151,22 @@ HloInstruction* MultiOutputFusion::Fuse(HloInstruction* instr1,
return remaining;
}
+bool MultiOutputFusion::IsProfitableOperand(HloInstruction* instr) {
+  // kConstant instructions have no memory reads, so they are not a source of
+  // profit. Skip effective-scalar constants.
+ if (instr->opcode() == HloOpcode::kConstant &&
+ ShapeUtil::IsEffectiveScalar(instr->shape())) {
+ return false;
+ }
+  // We do not aim to fuse producer/consumer instructions here; that is
+  // handled by the instruction_fusion pass. If instr has only one user, it
+  // has no sibling instructions to fuse with, so we do not consider it.
+ if (instr->user_count() < 2) {
+ return false;
+ }
+ return true;
+}
+
void MultiOutputFusion::Update(HloInstruction* instr1, HloInstruction* instr2) {
HloInstruction* fusion = instr1;
HloInstruction* fused = instr2;
diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.h b/tensorflow/compiler/xla/service/multi_output_fusion.h
index cfdf83cfe8..d9c36fa284 100644
--- a/tensorflow/compiler/xla/service/multi_output_fusion.h
+++ b/tensorflow/compiler/xla/service/multi_output_fusion.h
@@ -72,8 +72,8 @@ class MultiOutputFusion : public HloPassInterface {
// multi-output fusion instruction.
virtual int64 GetProfit(HloInstruction* instr1, HloInstruction* instr2) = 0;
- // Whether fusing the instruction can reduce cost.
- virtual bool IsProfitableOperand(HloInstruction* instr) = 0;
+ // Whether fusing the instruction can reduce memory reads.
+ virtual bool IsProfitableOperand(HloInstruction* instr);
// Test if it's legal to fuse instr1 and instr2 into one fusion instruction.
virtual bool LegalToFuse(HloInstruction* instr1, HloInstruction* instr2);
diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc
index d01c35b992..961158e677 100644
--- a/tensorflow/compiler/xla/service/service.cc
+++ b/tensorflow/compiler/xla/service/service.cc
@@ -348,8 +348,8 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> Service::BuildExecutables(
module_protos[i]->entry_computation_name().c_str());
TF_RETURN_IF_ERROR(
Executable::DumpToDirectory(directory_path, filename, *hlo_snapshot));
- hlo_snapshots.push_back(std::move(hlo_snapshot));
}
+ hlo_snapshots.push_back(std::move(hlo_snapshot));
}
VLOG(1) << "Computations:";
@@ -721,6 +721,15 @@ Status Service::ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg,
executable_ptrs.push_back(executable.get());
}
+ for (int i = 0; i < executable_ptrs.size(); i++) {
+ if (executable_ptrs[i]->dumping_snapshot()) {
+ TF_RETURN_IF_ERROR(RecordArguments(all_arguments[i].front(),
+ all_executors[i][0],
+ execute_backend_->transfer_manager(),
+ executable_ptrs[i]->hlo_snapshot()));
+ }
+ }
+
// Execute the generated executables in parallel and return the device
// handles for each computation's output.
ExecutionProfile profile;
@@ -736,6 +745,18 @@ Status Service::ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg,
*result->add_responses() = response;
}
+ for (int i = 0; i < executable_ptrs.size(); i++) {
+ if (executable_ptrs[i]->dumping_snapshot()) {
+ TF_ASSIGN_OR_RETURN(const ShapedBuffer* result_buffer,
+ allocation_tracker_.ResolveForReplica(outputs[i], 0));
+ TF_RETURN_IF_ERROR(RecordResult(*result_buffer, all_executors[i][0],
+ execute_backend_->transfer_manager(),
+ executable_ptrs[i]->hlo_snapshot()));
+ // Dump out the ith snapshot.
+ TF_RETURN_IF_ERROR(executable_ptrs[i]->DumpHloSnapshot());
+ }
+ }
+
VLOG(1) << "successfully completed 'execute-graph-parallel' request";
return Status::OK();
}
@@ -835,6 +856,10 @@ StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
backend->compiler()->RunBackend(
std::move(module), executor, device_allocator));
+ if (!execution_directory_path.empty()) {
+ executable->set_hlo_snapshot(std::move(hlo_snapshot));
+ }
+
return std::move(executable);
}
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index bd98e86b08..e25f5e67c7 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -49,19 +49,13 @@ bool AllUnique(tensorflow::gtl::ArraySlice<int64> slice) {
return std::set<int64>(slice.begin(), slice.end()).size() == slice.size();
}
-Status ExpectNotTupleOrOpaque(const Shape& shape,
- tensorflow::StringPiece op_type) {
- if (ShapeUtil::IsTuple(shape)) {
- return InvalidArgument("Expected non-tuple argument for %s, but got %s.",
+Status ExpectArray(const Shape& shape, tensorflow::StringPiece op_type) {
+ if (!ShapeUtil::IsArray(shape)) {
+ return InvalidArgument("Expected array argument for %s, but got %s.",
std::string(op_type).c_str(),
ShapeUtil::HumanString(shape).c_str());
- } else if (ShapeUtil::IsOpaque(shape)) {
- return InvalidArgument("Expected non-opaque argument for %s, but got %s.",
- std::string(op_type).c_str(),
- ShapeUtil::HumanString(shape).c_str());
- } else {
- return Status::OK();
}
+ return Status::OK();
}
Status VerifyReducerShape(const ProgramShape& reducer_shape,
@@ -198,8 +192,7 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
return shape;
}
- TF_RETURN_IF_ERROR(
- ExpectNotTupleOrOpaque(shape, "operand of unary operation"));
+ TF_RETURN_IF_ERROR(ExpectArray(shape, "operand of unary operation"));
TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape));
switch (opcode) {
@@ -289,8 +282,7 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
const Shape* arg_shape = nullptr;
PrimitiveType element_type = PRIMITIVE_TYPE_INVALID;
for (const Shape* shape : arg_shapes) {
- TF_RETURN_IF_ERROR(
- ExpectNotTupleOrOpaque(*shape, "operand of concatenation"));
+ TF_RETURN_IF_ERROR(ExpectArray(*shape, "operand of concatenation"));
if (!arg_shape) {
arg_shape = shape;
element_type = arg_shape->element_type();
@@ -337,7 +329,7 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
return ShapeUtil::MakeShape(element_type, new_dimensions);
}
-/* static */ StatusOr<Shape> ShapeInference::InferTokenShape(
+/* static */ StatusOr<Shape> ShapeInference::InferGenerateTokenShape(
tensorflow::gtl::ArraySlice<const Shape*> arg_shapes) {
for (const Shape* arg_shape : arg_shapes) {
if (arg_shape->element_type() != TOKEN) {
@@ -358,12 +350,13 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
ShapeUtil::HumanString(operand_shape).c_str(),
PrimitiveType_Name(new_element_type).c_str());
}
- if (ShapeUtil::IsTuple(operand_shape) || new_element_type == TUPLE) {
+ if (!ShapeUtil::IsArray(operand_shape) ||
+ !primitive_util::IsArrayType(new_element_type)) {
// Note: we may want to support tuple conversions via this operation in the
// future, by recursing into the tuple elements to check all sub-conversions
// are valid. For now we just reject them, though.
return InvalidArgument(
- "Convert does not allow tuples, so cannot convert from %s to %s.",
+ "Convert does not allow non-arrays, so cannot convert from %s to %s.",
ShapeUtil::HumanString(operand_shape).c_str(),
PrimitiveType_Name(new_element_type).c_str());
}
@@ -380,7 +373,8 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
ShapeUtil::HumanString(operand_shape).c_str(),
PrimitiveType_Name(new_element_type).c_str());
}
- if (ShapeUtil::IsTuple(operand_shape) || new_element_type == TUPLE) {
+ if (!ShapeUtil::IsArray(operand_shape) ||
+ !primitive_util::IsArrayType(new_element_type)) {
// Note: we may want to support tuple conversions via this operation in the
// future, by recursing into the tuple elements to check all sub-conversions
// are valid. For now we just reject them, though.
@@ -427,7 +421,7 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
/* static */ StatusOr<Shape> ShapeInference::InferPadShape(
const Shape& operand_shape, const Shape& padding_value_shape,
const PaddingConfig& padding_config) {
- if (ShapeUtil::IsTuple(operand_shape)) {
+ if (!ShapeUtil::IsArray(operand_shape)) {
return InvalidArgument(
"Pad operation does not support tuple-shape operands.");
}
@@ -566,8 +560,8 @@ Status ValidateDotDimensionNumbers(
/* static */ StatusOr<Shape> ShapeInference::InferDotOpShape(
const Shape& lhs, const Shape& rhs,
const DotDimensionNumbers& dimension_numbers) {
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(lhs, "lhs of dot"));
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(rhs, "rhs of dot"));
+ TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of dot"));
+ TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of dot"));
auto fail = [lhs, rhs](const string& addendum) -> Status {
string message = tensorflow::strings::Printf(
@@ -786,10 +780,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
/* static */ StatusOr<Shape> ShapeInference::InferElementwiseBinaryOpShape(
HloOpcode operation, const Shape& lhs, const Shape& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
- TF_RETURN_IF_ERROR(
- ExpectNotTupleOrOpaque(lhs, "lhs of elementwise binary operation"));
- TF_RETURN_IF_ERROR(
- ExpectNotTupleOrOpaque(rhs, "rhs of elementwise binary operation"));
+ TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of elementwise binary operation"));
+ TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of elementwise binary operation"));
if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) {
return InvalidArgument(
@@ -853,12 +845,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs));
TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs));
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
- lhs, tensorflow::strings::StrCat("lhs of binary operation ",
- HloOpcodeString(opcode))));
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
- rhs, tensorflow::strings::StrCat("rhs of binary operation ",
- HloOpcodeString(opcode))));
+ TF_RETURN_IF_ERROR(
+ ExpectArray(lhs, tensorflow::strings::StrCat("lhs of binary operation ",
+ HloOpcodeString(opcode))));
+ TF_RETURN_IF_ERROR(
+ ExpectArray(rhs, tensorflow::strings::StrCat("rhs of binary operation ",
+ HloOpcodeString(opcode))));
switch (opcode) {
case HloOpcode::kMaximum:
case HloOpcode::kMinimum:
@@ -984,15 +976,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
// All arguments must have the same shape.
const Shape* arg_shape = arg_shapes[0];
for (size_t i = 1; i < arg_shapes.size(); ++i) {
- TF_RETURN_IF_ERROR(
- ExpectNotTupleOrOpaque(*arg_shapes[i], "operand of map"));
+ TF_RETURN_IF_ERROR(ExpectArray(*arg_shapes[i], "operand of map"));
if (ShapeUtil::CompatibleIgnoringFpPrecision(*arg_shapes[i], *arg_shape)) {
continue;
}
- if (!ShapeUtil::IsTuple(*arg_shapes[i]) &&
- !ShapeUtil::IsTuple(*arg_shape) &&
- ShapeUtil::SameElementTypeIgnoringFpPrecision(*arg_shapes[i],
+ if (ShapeUtil::SameElementTypeIgnoringFpPrecision(*arg_shapes[i],
*arg_shape)) {
if (ShapeUtil::IsScalar(*arg_shapes[i])) {
continue;
@@ -1075,11 +1064,11 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
const Shape& operand_shape, const Shape& scale_shape,
const Shape& offset_shape, int64 feature_index) {
TF_RETURN_IF_ERROR(
- ExpectNotTupleOrOpaque(operand_shape, "operand of batch norm training"));
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
- offset_shape, "offset input of batch norm training"));
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
- scale_shape, "scale input of batch norm training"));
+ ExpectArray(operand_shape, "operand of batch norm training"));
+ TF_RETURN_IF_ERROR(
+ ExpectArray(offset_shape, "offset input of batch norm training"));
+ TF_RETURN_IF_ERROR(
+ ExpectArray(scale_shape, "scale input of batch norm training"));
TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape) ==
Status::OK());
@@ -1181,11 +1170,11 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
const Shape& offset_shape, const Shape& mean_shape,
const Shape& variance_shape, int64 feature_index) {
TF_RETURN_IF_ERROR(
- ExpectNotTupleOrOpaque(operand_shape, "operand of batch norm inference"));
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
- offset_shape, "offset input of batch norm inference"));
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
- scale_shape, "scale input of batch norm inference"));
+ ExpectArray(operand_shape, "operand of batch norm inference"));
+ TF_RETURN_IF_ERROR(
+ ExpectArray(offset_shape, "offset input of batch norm inference"));
+ TF_RETURN_IF_ERROR(
+ ExpectArray(scale_shape, "scale input of batch norm inference"));
TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape) ==
Status::OK());
@@ -1328,16 +1317,13 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
const Shape& operand_shape, const Shape& scale_shape,
const Shape& mean_shape, const Shape& var_shape,
const Shape& output_grad_shape, int64 feature_index) {
+ TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of batch norm grad"));
TF_RETURN_IF_ERROR(
- ExpectNotTupleOrOpaque(operand_shape, "operand of batch norm grad"));
+ ExpectArray(scale_shape, "scale input of batch norm grad"));
+ TF_RETURN_IF_ERROR(ExpectArray(mean_shape, "mean input of batch norm grad"));
+ TF_RETURN_IF_ERROR(ExpectArray(var_shape, "var input of batch norm grad"));
TF_RETURN_IF_ERROR(
- ExpectNotTupleOrOpaque(scale_shape, "scale input of batch norm grad"));
- TF_RETURN_IF_ERROR(
- ExpectNotTupleOrOpaque(mean_shape, "mean input of batch norm grad"));
- TF_RETURN_IF_ERROR(
- ExpectNotTupleOrOpaque(var_shape, "var input of batch norm grad"));
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
- output_grad_shape, "output_grad input of batch norm grad"));
+ ExpectArray(output_grad_shape, "output_grad input of batch norm grad"));
TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape));
TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(mean_shape));
@@ -1486,8 +1472,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
/* static */ StatusOr<Shape> ShapeInference::InferConvolveShape(
const Shape& lhs, const Shape& rhs, const Window& window,
const ConvolutionDimensionNumbers& dnums) {
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(lhs, "lhs of convolution"));
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(rhs, "rhs of convolution"));
+ TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of convolution"));
+ TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of convolution"));
if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) {
return InvalidArgument(
@@ -1722,7 +1708,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
tensorflow::gtl::ArraySlice<const Shape*> operand_shapes) {
for (const Shape* operand_shape : operand_shapes) {
TF_RETURN_IF_ERROR(
- ExpectNotTupleOrOpaque(*operand_shape, "operand of cross replica sum"));
+ ExpectArray(*operand_shape, "operand of cross replica sum"));
}
if (operand_shapes.size() == 1) {
return *operand_shapes[0];
@@ -1764,8 +1750,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
/* static */ StatusOr<Shape> ShapeInference::InferReduceWindowShape(
const Shape& operand_shape, const Shape& init_value_shape,
const Window& window, const ProgramShape& to_apply_shape) {
- TF_RETURN_IF_ERROR(
- ExpectNotTupleOrOpaque(operand_shape, "operand of reduce-window"));
+ TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of reduce-window"));
TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply_shape, init_value_shape,
operand_shape.element_type()));
return InferWindowOutputShape(operand_shape, window,
@@ -1778,7 +1763,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
const Window& window, const Shape& source_shape,
const Shape& init_value_shape, const ProgramShape& scatter_shape) {
TF_RETURN_IF_ERROR(
- ExpectNotTupleOrOpaque(operand_shape, "operand of select-and-scatter"));
+ ExpectArray(operand_shape, "operand of select-and-scatter"));
// Check if the select function has a proper shape of (T,T) -> PRED.
if (select_shape.parameters_size() != 2) {
@@ -1843,7 +1828,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
Join(starts, ",").c_str(), Join(limits, ",").c_str(),
Join(strides, ",").c_str());
};
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(arg, "operand of slice"));
+ TF_RETURN_IF_ERROR(ExpectArray(arg, "operand of slice"));
VLOG(2) << tensorflow::strings::Printf(
"slicing shape %s starts={%s} limits={%s}",
ShapeUtil::HumanString(arg).c_str(), Join(starts, ", ").c_str(),
@@ -1902,10 +1887,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
/* static */ StatusOr<Shape> ShapeInference::InferDynamicSliceShape(
const Shape& operand_shape, const Shape& start_indices_shape,
tensorflow::gtl::ArraySlice<int64> slice_sizes) {
+ TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of dynamic slice"));
TF_RETURN_IF_ERROR(
- ExpectNotTupleOrOpaque(operand_shape, "operand of dynamic slice"));
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(start_indices_shape,
- "start indices of dynamic slice"));
+ ExpectArray(start_indices_shape, "start indices of dynamic slice"));
VLOG(2) << tensorflow::strings::Printf(
"slicing shape %s at dynamic start_indices %s with slice_sizes={%s}",
@@ -1963,11 +1947,11 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
const Shape& operand_shape, const Shape& update_shape,
const Shape& start_indices_shape) {
TF_RETURN_IF_ERROR(
- ExpectNotTupleOrOpaque(operand_shape, "operand of dynamic update slice"));
+ ExpectArray(operand_shape, "operand of dynamic update slice"));
TF_RETURN_IF_ERROR(
- ExpectNotTupleOrOpaque(update_shape, "update of dynamic update slice"));
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
- start_indices_shape, "start indices of dynamic update slice"));
+ ExpectArray(update_shape, "update of dynamic update slice"));
+ TF_RETURN_IF_ERROR(ExpectArray(start_indices_shape,
+ "start indices of dynamic update slice"));
VLOG(2) << tensorflow::strings::Printf(
"updating slice of shape %s at dynamic start_indices %s with update "
@@ -2035,8 +2019,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
/*static */ StatusOr<Shape> ShapeInference::InferReverseShape(
const Shape& operand_shape, tensorflow::gtl::ArraySlice<int64> dimensions) {
- TF_RETURN_IF_ERROR(
- ExpectNotTupleOrOpaque(operand_shape, "operand of reverse"));
+ TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of reverse"));
if (!AllUnique(dimensions)) {
return InvalidArgument("a dimension number is duplicated in reverse");
}
@@ -2166,7 +2149,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
/* static */ StatusOr<Shape> ShapeInference::InferBroadcastShape(
const Shape& operand, tensorflow::gtl::ArraySlice<int64> broadcast_sizes) {
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(operand, "operand of broadcast"));
+ TF_RETURN_IF_ERROR(ExpectArray(operand, "operand of broadcast"));
for (int64 size : broadcast_sizes) {
if (size < 0) {
return InvalidArgument("Broadcast with negative dimension size %lld.",
@@ -2185,7 +2168,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
/* static */ StatusOr<Shape> ShapeInference::InferReshapeShape(
const Shape& operand, tensorflow::gtl::ArraySlice<int64> dimensions,
tensorflow::gtl::ArraySlice<int64> new_sizes) {
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(operand, "reshape"));
+ TF_RETURN_IF_ERROR(ExpectArray(operand, "reshape"));
Shape inferred_shape =
ShapeUtil::MakeShape(operand.element_type(), new_sizes);
@@ -2217,7 +2200,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
/* static */ StatusOr<Shape> ShapeInference::InferTransposeShape(
const Shape& operand, tensorflow::gtl::ArraySlice<int64> dimensions) {
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(operand, "transpose"));
+ TF_RETURN_IF_ERROR(ExpectArray(operand, "transpose"));
std::vector<int64> indices(ShapeUtil::Rank(operand));
std::iota(indices.begin(), indices.end(), 0);
@@ -2238,9 +2221,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
// "degenerate" cases, as with binary elementwise ops.
/* static */ StatusOr<Shape> ShapeInference::InferClampShape(
const Shape& min, const Shape& operand, const Shape& max) {
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(min, "clamp min"));
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(operand, "clamp operand"));
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(max, "clamp max"));
+ TF_RETURN_IF_ERROR(ExpectArray(min, "clamp min"));
+ TF_RETURN_IF_ERROR(ExpectArray(operand, "clamp operand"));
+ TF_RETURN_IF_ERROR(ExpectArray(max, "clamp max"));
if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(min, operand) ||
!ShapeUtil::SameElementTypeIgnoringFpPrecision(max, operand)) {
return InvalidArgument("Clamp with different operand types: %s, %s, %s.",
@@ -2439,9 +2422,9 @@ static Status ValidateGatherDimensionNumbers(
const GatherDimensionNumbers& gather_dim_numbers,
tensorflow::gtl::ArraySlice<int64> window_bounds) {
TF_RETURN_IF_ERROR(
- ExpectNotTupleOrOpaque(input_shape, "input tensor operand gather op"));
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
- gather_indices_shape, "gather indices operand of gather op"));
+ ExpectArray(input_shape, "input tensor operand gather op"));
+ TF_RETURN_IF_ERROR(
+ ExpectArray(gather_indices_shape, "gather indices operand of gather op"));
if (!ShapeUtil::ElementIsIntegral(gather_indices_shape)) {
return InvalidArgument(
diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h
index f1f7b50902..eef6e62fc8 100644
--- a/tensorflow/compiler/xla/service/shape_inference.h
+++ b/tensorflow/compiler/xla/service/shape_inference.h
@@ -220,7 +220,7 @@ class ShapeInference {
// shape is always a TOKEN shape. However, ShapeInference serves two purposes:
// inferring shapes and checking operand shapes. This method verifies that the
// operand shapes are all TOKENs.
- static StatusOr<Shape> InferTokenShape(
+ static StatusOr<Shape> InferGenerateTokenShape(
tensorflow::gtl::ArraySlice<const Shape*> arg_shapes);
// Helper that validates the given operand shape can be converted to the
diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc
index 6d017dffe2..bafe14d6f4 100644
--- a/tensorflow/compiler/xla/service/shape_inference_test.cc
+++ b/tensorflow/compiler/xla/service/shape_inference_test.cc
@@ -1311,7 +1311,7 @@ TEST_F(ShapeInferenceTest, ConcatenateWithBadShapes) {
ASSERT_FALSE(inferred_status_error4.ok());
ASSERT_THAT(
inferred_status_error4.status().error_message(),
- HasSubstr("Expected non-tuple argument for operand of concatenation"));
+ HasSubstr("Expected array argument for operand of concatenation"));
const Shape vector_s32 = ShapeUtil::MakeShape(S32, {32});
auto inferred_status_error5 = ShapeInference::InferConcatOpShape(
@@ -1387,7 +1387,7 @@ TEST_F(ShapeInferenceTest, ReverseInvalidDimension) {
ShapeInference::InferReverseShape(tuple_shape, {0});
ASSERT_FALSE(inferred_status_error3.ok());
ASSERT_THAT(inferred_status_error3.status().error_message(),
- HasSubstr("Expected non-tuple argument"));
+ HasSubstr("Expected array argument"));
}
TEST_F(ShapeInferenceTest, Call) {
@@ -1686,7 +1686,7 @@ TEST_F(GatherShapeInferenceTest, TupleShapedTensorInput) {
/*window_bounds=*/{64, 1});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(statusor.status().error_message(),
- HasSubstr("Expected non-tuple argument for input"))
+ HasSubstr("Expected array argument for input"))
<< statusor.status();
}
@@ -1700,7 +1700,7 @@ TEST_F(GatherShapeInferenceTest, TupleShapedGatherIndicesInput) {
/*window_bounds=*/{64, 1});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(statusor.status().error_message(),
- HasSubstr("Expected non-tuple argument for gather indices"))
+ HasSubstr("Expected array argument for gather indices"))
<< statusor.status();
}
diff --git a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.cc b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.cc
index aa40b5cb26..44b0ec5cd4 100644
--- a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.cc
+++ b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.cc
@@ -32,11 +32,11 @@ StatusOr<bool> ZeroSizedHloElimination::Run(HloModule* module) {
for (HloComputation* comp : module->MakeNonfusionComputations()) {
for (HloInstruction* instruction : comp->MakeInstructionPostOrder()) {
if (instruction->HasSideEffect() ||
- ShapeUtil::IsTuple(instruction->shape())) {
+ !ShapeUtil::IsArray(instruction->shape())) {
continue;
}
if (comp->IsRemovable(instruction) &&
- ShapeUtil::HasZeroElements(instruction->shape())) {
+ ShapeUtil::IsZeroElementArray(instruction->shape())) {
TF_RETURN_IF_ERROR(comp->ReplaceWithNewInstruction(
instruction, HloInstruction::CreateConstant(
Literal::CreateFromShape(instruction->shape()))));
diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc
index 5db6659932..c85fb20e01 100644
--- a/tensorflow/compiler/xla/shape_util.cc
+++ b/tensorflow/compiler/xla/shape_util.cc
@@ -363,7 +363,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
}
/* static */ bool ShapeUtil::IsNil(const Shape& shape) {
- return IsTuple(shape) ? IsEmptyTuple(shape) : HasZeroElements(shape);
+ return IsEmptyTuple(shape);
}
/* static */ int64 ShapeUtil::TupleElementCount(const Shape& shape) {
@@ -413,8 +413,8 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
std::multiplies<int64>());
}
-/* static */ bool ShapeUtil::HasZeroElements(const Shape& shape) {
- return ElementsIn(shape) == 0;
+/* static */ bool ShapeUtil::IsZeroElementArray(const Shape& shape) {
+ return ShapeUtil::IsArray(shape) && ElementsIn(shape) == 0;
}
/* static */ bool ShapeUtil::IsScalarF32(const Shape& shape) {
@@ -645,15 +645,7 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
}
/* static */ bool ShapeUtil::Compatible(const Shape& lhs, const Shape& rhs) {
- if (IsArray(lhs)) {
- return SameElementType(lhs, rhs) && SameDimensions(lhs, rhs);
- } else if (lhs.element_type() == TUPLE) {
- return rhs.element_type() == TUPLE &&
- ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(), Compatible);
- } else {
- // Opaque, token, etc types are vacuously compatible.
- return true;
- }
+ return CompareShapes(lhs, rhs, /*compare_layouts=*/false);
}
/* static */ bool ShapeUtil::CompatibleIgnoringElementType(const Shape& lhs,
@@ -903,6 +895,21 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
return *return_shape;
}
+/* static */ StatusOr<const Shape*> ShapeUtil::TryGetSubshape(
+ const Shape& shape, ShapeIndexView index) {
+ const Shape* return_shape = &shape;
+ for (auto i : index) {
+ if (!IsTuple(*return_shape) || i < 0 ||
+ i >= return_shape->tuple_shapes_size()) {
+ return InvalidArgument(
+ "Shape index %s not a valid subshape index for tuple with shape %s",
+ index.ToString().c_str(), shape.DebugString().c_str());
+ }
+ return_shape = &return_shape->tuple_shapes(i);
+ }
+ return return_shape;
+}
+
/* static */ Shape* ShapeUtil::GetMutableSubshape(Shape* shape,
ShapeIndexView index) {
Shape* return_shape = shape;
diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h
index ae2d17d6bb..8ee3f490a0 100644
--- a/tensorflow/compiler/xla/shape_util.h
+++ b/tensorflow/compiler/xla/shape_util.h
@@ -175,8 +175,8 @@ class ShapeUtil {
// Precondition: IsArray(shape)
static int64 ElementsIn(const Shape& shape);
- // Returns true if 'shape' has zero elements.
- static bool HasZeroElements(const Shape& shape);
+ // Returns true if 'shape' is an array with zero elements.
+ static bool IsZeroElementArray(const Shape& shape);
// Returns the number of bytes required for an allocation of shape. The
// |pointer_size| parameter is used for calculating the size of tuple
@@ -336,7 +336,7 @@ class ShapeUtil {
// Appends a major dimension to the shape with the given bound.
static void AppendMajorDimension(int bound, Shape* shape);
- // Returns an empty tuple shape. Can be used to indicate side-effects.
+ // Returns an empty tuple shape. Can be used as a sentinel Shape value.
static Shape MakeNil() { return MakeTupleShape({}); }
// Checks whether the shape is initialized.
@@ -446,7 +446,7 @@ class ShapeUtil {
// Returns true if shape is an empty tuple.
static bool IsEmptyTuple(const Shape& shape);
- // Returns true if shape is an empty tuple, or is an array with no elements.
+ // Returns true if shape is the nil shape (an empty tuple).
static bool IsNil(const Shape& shape);
// Returns the number of elements in the given tuple shape.
@@ -476,8 +476,11 @@ class ShapeUtil {
static bool IndexIsValid(const Shape& shape, ShapeIndexView index);
// GetSubshape and GetMutableSubshape return a particular nested Shape within
- // the given Shape argument.
+  // the given Shape argument. The non-Try variants CHECK-fail if the index
+  // is invalid.
static const Shape& GetSubshape(const Shape& shape, ShapeIndexView index);
+ static StatusOr<const Shape*> TryGetSubshape(const Shape& shape,
+ ShapeIndexView index);
static Shape* GetMutableSubshape(Shape* shape, ShapeIndexView index);
// Returns whether the given index in the given shape is a leaf element of the
@@ -697,7 +700,7 @@ class ShapeUtil {
tensorflow::gtl::ArraySlice<int64> incr,
const FnType& visitor_function,
bool parallel = false) {
- if (ShapeUtil::HasZeroElements(shape)) {
+ if (ShapeUtil::IsZeroElementArray(shape)) {
return Status::OK();
}
CHECK_EQ(Rank(shape), base.size());
diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc
index 0ff514564b..61aa198e52 100644
--- a/tensorflow/compiler/xla/shape_util_test.cc
+++ b/tensorflow/compiler/xla/shape_util_test.cc
@@ -172,6 +172,41 @@ TEST(ShapeUtilTest, CompatibleIdenticalShapes) {
ASSERT_TRUE(ShapeUtil::Compatible(shape1, shape2));
}
+TEST(ShapeUtilTest, TokenCompatibility) {
+ EXPECT_TRUE(ShapeUtil::Compatible(ShapeUtil::MakeTokenShape(),
+ ShapeUtil::MakeTokenShape()));
+ EXPECT_FALSE(ShapeUtil::Compatible(ShapeUtil::MakeTokenShape(),
+ ShapeUtil::MakeShape(F32, {})));
+ EXPECT_FALSE(ShapeUtil::Compatible(ShapeUtil::MakeShape(F32, {}),
+ ShapeUtil::MakeTokenShape()));
+ EXPECT_TRUE(ShapeUtil::Compatible(
+ ShapeUtil::MakeTupleShape({ShapeUtil::MakeTokenShape()}),
+ ShapeUtil::MakeTupleShape({ShapeUtil::MakeTokenShape()})));
+}
+
+TEST(ShapeUtilTest, TokensEqualShapes) {
+ EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeTokenShape(),
+ ShapeUtil::MakeTokenShape()));
+ EXPECT_FALSE(ShapeUtil::Equal(ShapeUtil::MakeTokenShape(),
+ ShapeUtil::MakeShape(F32, {})));
+ EXPECT_FALSE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {}),
+ ShapeUtil::MakeTokenShape()));
+ EXPECT_TRUE(ShapeUtil::Equal(
+ ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeTokenShape(),
+ ShapeUtil::MakeShapeWithLayout(S32, {3, 4}, {0, 1})}),
+ ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeTokenShape(),
+ ShapeUtil::MakeShapeWithLayout(S32, {3, 4}, {0, 1})})));
+ EXPECT_FALSE(ShapeUtil::Equal(
+ ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeTokenShape(),
+ ShapeUtil::MakeShapeWithLayout(S32, {3, 4}, {0, 1})}),
+ ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeTokenShape(),
+ ShapeUtil::MakeShapeWithLayout(S32, {3, 4}, {1, 0})})));
+}
+
TEST(ShapeUtilTest, CompatibleNotIdenticalShapes) {
Shape shape_1 = ShapeUtil::MakeShape(F32, {3, 2});
auto layout_1 = shape_1.mutable_layout();
@@ -329,6 +364,16 @@ TEST(ShapeUtilTest, ByteSizeOfWithPadding) {
EXPECT_EQ(15 * 21 * 4, ShapeUtil::ByteSizeOf(shape));
}
+TEST(ShapeUtilTest, NilShape) {
+ EXPECT_TRUE(ShapeUtil::IsNil(ShapeUtil::MakeNil()));
+ EXPECT_FALSE(ShapeUtil::IsNil(ShapeUtil::MakeShape(F32, {1, 2, 3})));
+ EXPECT_FALSE(ShapeUtil::IsNil(ShapeUtil::MakeShape(F32, {0, 1})));
+ EXPECT_FALSE(ShapeUtil::IsNil(
+ ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S32, {})})));
+ EXPECT_FALSE(ShapeUtil::IsNil(
+ ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {0})})));
+}
+
TEST(ShapeUtilTest, NestedTuple) {
EXPECT_FALSE(ShapeUtil::IsNestedTuple(ShapeUtil::MakeTupleShape({})));
EXPECT_FALSE(ShapeUtil::IsNestedTuple(
@@ -359,25 +404,30 @@ TEST(ShapeUtilTest, ElementsIn) {
EXPECT_EQ(221, ShapeUtil::ElementsIn(ShapeUtil::MakeShape(S32, {13, 17})));
}
-TEST(ShapeUtilTest, HasZeroElements) {
- EXPECT_EQ(false, ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {})));
- EXPECT_EQ(true, ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {0})));
- EXPECT_EQ(false, ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {1})));
- EXPECT_EQ(false,
- ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {1, 1})));
- EXPECT_EQ(false, ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {2})));
- EXPECT_EQ(false,
- ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {2, 1})));
- EXPECT_EQ(false,
- ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {3, 5})));
- EXPECT_EQ(true,
- ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {3, 0, 5})));
- EXPECT_EQ(true,
- ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {0, 3, 0})));
- EXPECT_EQ(false,
- ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {1, 3, 5})));
- EXPECT_EQ(false,
- ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {13, 17})));
+TEST(ShapeUtilTest, IsZeroElementArray) {
+ EXPECT_FALSE(ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {})));
+ EXPECT_TRUE(ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {0})));
+ EXPECT_FALSE(ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {1})));
+ EXPECT_FALSE(
+ ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {1, 1})));
+ EXPECT_FALSE(ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {2})));
+ EXPECT_FALSE(
+ ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {2, 1})));
+ EXPECT_FALSE(
+ ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {3, 5})));
+ EXPECT_TRUE(
+ ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {3, 0, 5})));
+ EXPECT_TRUE(
+ ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {0, 3, 0})));
+ EXPECT_FALSE(
+ ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {1, 3, 5})));
+ EXPECT_FALSE(
+ ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {13, 17})));
+
+ EXPECT_FALSE(ShapeUtil::IsZeroElementArray(ShapeUtil::MakeNil()));
+ EXPECT_FALSE(ShapeUtil::IsZeroElementArray(ShapeUtil::MakeTupleShape({})));
+ EXPECT_FALSE(ShapeUtil::IsZeroElementArray(
+ ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S32, {0, 3, 0})})));
}
TEST(ShapeUtilTest, SameDimensions) {
diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
index 36a7064969..c3a289ee09 100644
--- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
+++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
@@ -2758,7 +2758,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, CannotAddOpaques) {
ASSERT_FALSE(computation_status.ok());
EXPECT_THAT(computation_status.status().ToString(),
::testing::ContainsRegex(
- "Expected non-opaque argument for lhs of binary operation"));
+ "Expected array argument for lhs of binary operation"));
}
XLA_TEST_F(ArrayElementwiseOpTest, IdentityBroadcastOfSameRankIsAllowed) {
diff --git a/tensorflow/compiler/xla/tests/concat_test.cc b/tensorflow/compiler/xla/tests/concat_test.cc
index a4c8a83eb1..352864502a 100644
--- a/tensorflow/compiler/xla/tests/concat_test.cc
+++ b/tensorflow/compiler/xla/tests/concat_test.cc
@@ -417,7 +417,22 @@ XLA_TEST_F(ConcatTest, CannotConcatOpaques) {
ASSERT_FALSE(computation_status.ok());
EXPECT_THAT(
computation_status.status().ToString(),
- HasSubstr("Expected non-opaque argument for operand of concatenation"));
+ HasSubstr("Expected array argument for operand of concatenation"));
+}
+
+// Show that we can't concatenate with tokens.
+XLA_TEST_F(ConcatTest, CannotConcatTokens) {
+ XlaBuilder builder(TestName());
+ auto token_shape = ShapeUtil::MakeTokenShape();
+ auto r1f32 = xla::ShapeUtil::MakeShape(xla::F32, {1});
+ auto x = builder.Parameter(0, r1f32, "x");
+ auto y = builder.Parameter(1, token_shape, "y");
+ builder.ConcatInDim({x, y}, 0);
+ StatusOr<XlaComputation> computation_status = builder.Build();
+ ASSERT_FALSE(computation_status.ok());
+ EXPECT_THAT(
+ computation_status.status().ToString(),
+ HasSubstr("Expected array argument for operand of concatenation"));
}
XLA_TEST_F(ConcatTest, ConcatSeveralBoxedPredicates) {
diff --git a/tensorflow/compiler/xla/tests/token_hlo_test.cc b/tensorflow/compiler/xla/tests/token_hlo_test.cc
index 4585244ce8..3ef54e6f89 100644
--- a/tensorflow/compiler/xla/tests/token_hlo_test.cc
+++ b/tensorflow/compiler/xla/tests/token_hlo_test.cc
@@ -28,8 +28,6 @@ namespace {
class TokenHloTest : public HloTestBase {};
-// TODO(b/79770375): Compile, not just verify the HLO module when the backends
-// support kGenerateToken.
XLA_TEST_F(TokenHloTest, SingleTokenInstruction) {
std::unique_ptr<HloModule> module = CreateNewModule();
auto builder = HloComputation::Builder(TestName());
@@ -120,5 +118,40 @@ XLA_TEST_F(TokenHloTest, InvalidOperandToTokenInstruction) {
"Operands of token instructions must be TOKEN types"));
}
+XLA_TEST_F(TokenHloTest, TokenInWhileLoop) {
+ // Thread a token around a while loop. Token is created and consumed by a
+ // GenerateToken instruction in the while body.
+ string module_string = R"(
+HloModule TokenInWhileLoop
+
+%Body (param.1: (s32[], token[])) -> (s32[], token[]) {
+ %param.1 = (s32[], token[]) parameter(0)
+ %get-tuple-element.1 = s32[] get-tuple-element((s32[], token[]) %param.1), index=0
+ %constant.1 = s32[] constant(1)
+ %add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1)
+ %get-tuple-element.2 = token[] get-tuple-element((s32[], token[]) %param.1), index=1
+ %generate-token = token[] generate-token(token[] %get-tuple-element.2)
+ ROOT %tuple = (s32[], token[]) tuple(s32[] %add, token[] %generate-token)
+}
+
+%Cond (param: (s32[], token[])) -> pred[] {
+ %param = (s32[], token[]) parameter(0)
+ %get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0
+ %constant = s32[] constant(42)
+ ROOT %less-than = pred[] less-than(s32[] %get-tuple-element, s32[] %constant)
+}
+
+ENTRY %TokenInWhileLoop () -> s32[] {
+ %zero = s32[] constant(0)
+ %init_token = token[] generate-token()
+ %init_tuple = (s32[], token[]) tuple(s32[] %zero, token[] %init_token)
+ %while = (s32[], token[]) while((s32[], token[]) %init_tuple), condition=%Cond, body=%Body
+ ROOT %root = s32[] get-tuple-element((s32[], token[]) %while), index=0
+}
+)";
+
+ EXPECT_TRUE(RunAndCompare(module_string, error_spec_));
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/contrib/autograph/LIMITATIONS.md b/tensorflow/contrib/autograph/LIMITATIONS.md
new file mode 100644
index 0000000000..d8b1cb7616
--- /dev/null
+++ b/tensorflow/contrib/autograph/LIMITATIONS.md
@@ -0,0 +1,50 @@
+# Capabilities and Limitations
+
+TF AutoGraph converts Eager Python code into TensorFlow graph-mode code. For example, users write code with `if` and `while`, and AutoGraph automatically converts it into the equivalent `tf.cond` and `tf.while_loop`.
+
+Python is a large language, so hoping to convert arbitrary Python code directly to TF graphs is overly ambitious. However, the Python code written to metaprogram TF graphs is in practice a restricted subset. We aim to support as much of this subset as possible. The table below lays out what we currently handle, what we hope to support, and what we have no plans to support.
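+
+For illustration, here is a sketch of the kind of conversion AutoGraph performs (assuming `import tensorflow as tf`; the graph-mode version shown is illustrative, not the exact generated code):
+
+```
+# Eager-style code the user writes:
+def f(x):
+  if x > 0:
+    x = x * x
+  else:
+    x = 0.0
+  return x
+
+# A roughly equivalent graph-mode version that AutoGraph would stage:
+def graph_f(x):
+  return tf.cond(x > 0, lambda: x * x, lambda: 0.0)
+```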
+
+# Python Language Support Status
+
+Note: as more complex features in TensorFlow are made more accessible using AutoGraph, we expect to come across use cases that haven't been tried before, some of which might reveal rare bugs. If we do find any such bugs, we may add additional restrictions for the affected configurations, until those bugs are resolved.
+
+ Construct | Supported now? | Plan to support? | Notes
+ :--------- | :--------------: | :----------------: | :-----
+If statement | Yes | | Converts to `tf.cond`. If one branch creates variables that do not exist in the other, which is inexpressible in TF, we throw a clear error.
+For statement | Yes | | We will specialize `for` loops with unknown and known lengths, as well as for loops over TF datasets. Converts to `tf.while_loop`, with an additional `maximum_iterations` hint, if that is known. Creating variables inside the loop that are used later outside the loop is not supported, as the loop may have no iterations.
+While statement | Yes | | Converts to `tf.while_loop`. Creating variables inside the loop is not supported, as the loop may have no iterations.
+Continue and break | Yes | | Converts to boolean flags and extra predicates in loop tests.
+Composition of control flow | Yes | | Arbitrary composition of `if`, `while`, `for`, `break`, and `continue`, along with other supported language elements, is supported and tested.
+Iterators | Some | Yes | Not all iterators supported, but we plan to support everything that can be desugared, such as `enumerate` and `zip`.
+Multiple return values | Yes | | We desugar them into variables, boolean flags and conditionals so that the function has a single return value at the end, and provide a clear error if we are unable to do so.
+Print expression | Yes | | Wrapped in `PyFunc` and given proper control dependencies. Optional support exists for using `tf.Log` when `py_func` is undesirable.
+Static function calls | Yes | | Non-recursive function calls
+Nested call trees | Yes | | For example, `f` calls `g` which calls `h`, all of which need conversion.
+Recursive function calls | No | Maybe | Based on available support in TF. Currently `function.Defun` is the best candidate, but it is not reentrant.
+Python built-ins | Some | Yes | `print`, `len`, `range`, `xrange`, `int`, `float` are supported, and we plan to support or clearly error on all [Python built-ins](https://docs.python.org/3/library/functions.html).
+List operations | Yes | | We convert list creation, append, pop and indexing to their TF TensorArray equivalents. However, some extra type hints are needed to fully convert correctly (see the sketch after this table); we hope to remove this limitation.
+Function variables | Yes | | e.g. `f_new = f_orig; f_new()`
+Lambda functions | No | Yes | Planned feature.
+Classes | Yes | | Classes can be converted all at once, or method-by-method. Some limitations exist around static and class methods.
+Subclasses | Yes | | Subclassing library objects like tf.keras.Model is also supported.
+Dynamic types | Some | | `o = C1() if foo else C2(); o.bar()`. Some scenarios where types are data-dependent may not be supported. We will raise a meaningful error in that case.
+Dynamic code / exec | No | |
+Reflection | No | |
+Try / Except | No | No | No current sane TF equivalent.
+Global variables | Restricted | | In general, we only support read-only access to arguments or variables defined outside the converted code. A few exceptions include TensorFlow library code.
+Functions with side effects | Some | | Side effects are allowed, under certain circumstances.
+Collections | Some | Yes | We currently support lists. There are currently no TF equivalents of dictionaries or tuples.
+List Comprehensions | Yes | | We desugar `ListComp` into the appropriate combination of `For` and `If` statements. Other comprehensions are currently very low priority.
+Custom context managers | No | Yes | Currently low priority; left unconverted for now.
+Generators | No | Maybe | Could be achievable using queues; very low priority.
+Assertions | Yes | | As `tf.Assert`
+Deletion | Yes | Maybe | Currently unconverted. If new semantics are required for `del`, we are able to add them in.
+Inline imports | No | Yes | For example, `import numpy as np; np.eye(3)`. Currently low priority.
+Async | No | No |
+
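+As noted in the list operations row above, converting lists may require extra type hints. Below is a sketch of what such a hint could look like, using the `set_element_type` directive added under `tensorflow/contrib/autograph/lang` in this change (illustrative only; import paths and conversion details may differ):
+
+```
+import tensorflow as tf
+
+from tensorflow.contrib.autograph.lang import directives
+
+def collect(n):
+  l = []
+  # Hint that the list holds int32 elements, so it can be staged as a
+  # TensorArray with a known dtype.
+  directives.set_element_type(l, dtype=tf.int32)
+  for i in range(n):
+    l.append(i)
+  return l
+```
+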
+## Extra capabilities
+
+ - We liberally add name scopes to generated functions
+ - Operations get decent default names everywhere (planned)
+ - Statements that have no output values are given correct control dependencies. For example, `for i in range(n): print(i)` will have control dependencies to ensure the `print` statements are executed serially.
+
diff --git a/tensorflow/contrib/autograph/README.md b/tensorflow/contrib/autograph/README.md
index 674859bed4..829a57d8e6 100644
--- a/tensorflow/contrib/autograph/README.md
+++ b/tensorflow/contrib/autograph/README.md
@@ -120,3 +120,15 @@ You can use the functional API to inspect the generated code as well:
print(ag.to_code(f))
# Output: <Python and TensorFlow code>
```
+
+## Filing bugs and feature requests
+
+### Reporting a bug
+
+ - If AutoGraph-generated code is compiling and running, but producing an incorrect result, send us a minimal reproduction case that includes the original Eager code, the inputs, and if possible the outputs or the error message.
+ - If AutoGraph-generated code is compiling, but not running, send us a minimal reproduction case that includes the original Eager code, the inputs, and if possible the outputs or the error message.
+ - If AutoGraph-generated code is not compiling, send us two minimal pieces of code. First, the Eager code that you would like to write, and second, the Graph code that you would like AutoGraph to have generated for you.
+
+### Requesting a feature
+
+If you’d like AutoGraph to convert a feature of Python or TF that we currently don’t handle, please let us know by filing a bug. We’ll make it as easy as possible to interact with us there.
diff --git a/tensorflow/contrib/autograph/STYLE_GUIDE.md b/tensorflow/contrib/autograph/STYLE_GUIDE.md
index 866e5f583a..7e6b0cc27d 100644
--- a/tensorflow/contrib/autograph/STYLE_GUIDE.md
+++ b/tensorflow/contrib/autograph/STYLE_GUIDE.md
@@ -20,7 +20,17 @@ Naming conventions:
Below are AutoGraph-specific conventions. In the event of conflict,
they supersede all previous conventions.
-1. __Citations in Docstrings.__ Write a `#### References` subsection at the
+1. __Types in docstrings.__ Use [PEP 484](https://www.python.org/dev/peps/pep-0484/)
+ notation to describe the type for args, return values and attributes.
+
+ Example:
+
+ ```
+ Args:
+ foo: Dict[str, List[int]], a dictionary of sorts
+ ```
+
+2. __Citations in Docstrings.__ Write a `#### References` subsection at the
bottom of any docstring with citations. Use ICLR’s bibliography style to
write references; for example, order entries by the first author's last
name. Add a link to the paper if the publication is open source (ideally,
@@ -60,12 +70,12 @@ it supercedes all previous conventions.
https://arxiv.org/abs/1803.04386
```
-2. Avoid LaTeX in docstrings.
+3. Avoid LaTeX in docstrings.
* It is not rendered in many (if not most) editors and can be hard to read
for both LaTeX experts and non-experts.
-3. Write docstring and comment math using ASCII friendly notation; python using
+4. Write docstring and comment math using ASCII friendly notation; python using
operators. E.g., `x**2` better than `x^2`, `x[i, j]` better than `x_{i,j}`,
`sum{ f(x[i]) : i=1...n }` better than `\sum_{i=1}^n f(x_i)` `int{sin(x) dx:
x in [0, 2 pi]}` better than `\int_0^{2\pi} sin(x) dx`.
diff --git a/tensorflow/contrib/autograph/lang/BUILD b/tensorflow/contrib/autograph/lang/BUILD
new file mode 100644
index 0000000000..77a2184e22
--- /dev/null
+++ b/tensorflow/contrib/autograph/lang/BUILD
@@ -0,0 +1,40 @@
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load("//tensorflow:tensorflow.bzl", "py_test")
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
+
+py_library(
+ name = "lang",
+ srcs = [
+ "directives.py",
+ "special_functions.py",
+ ],
+ srcs_version = "PY2AND3",
+ visibility = ["//tensorflow:__subpackages__"],
+ deps = [
+ "//tensorflow/contrib/autograph/operators",
+ ],
+)
+
+py_test(
+ name = "special_functions_test",
+ srcs = ["special_functions_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":lang",
+ "//tensorflow/python:client_testlib",
+ ],
+)
diff --git a/tensorflow/contrib/autograph/lang/directives.py b/tensorflow/contrib/autograph/lang/directives.py
new file mode 100644
index 0000000000..aabe5d9939
--- /dev/null
+++ b/tensorflow/contrib/autograph/lang/directives.py
@@ -0,0 +1,68 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Directives are special no-op functions that serve as compilation markers.
+
+They provide static information like type hints, compilation and TensorFlow
+overrides.
+
+These serve as annotations in the compiled code, allowing the user some control
+over the compilation process. They have no functional role at runtime.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+UNSPECIFIED = object()
+
+
+def set_element_type(entity, dtype, shape=UNSPECIFIED):
+ """Indicates that the entity is expected hold items of specified type/shape.
+
+  The staged TensorFlow ops will reflect and assert this data type. Outside of
+  staged code, the annotation is ignored.
+
+ Args:
+ entity: The entity to annotate.
+ dtype: TensorFlow dtype value to assert for entity.
+ shape: Optional shape to assert for entity.
+ """
+ del entity
+ del dtype
+ del shape
+
+
+def set_loop_options(
+ parallel_iterations=UNSPECIFIED,
+ back_prop=UNSPECIFIED,
+ swap_memory=UNSPECIFIED,
+ maximum_iterations=UNSPECIFIED):
+ """Specifies additional arguments to be passed to the enclosing while_loop.
+
+  The parameters apply only to the immediately enclosing loop. They take
+  effect only if the loop is staged as a TF while_loop; otherwise they are
+  ignored.
+
+ Args:
+ parallel_iterations: See tf.while_loop.
+ back_prop: See tf.while_loop.
+ swap_memory: See tf.while_loop.
+ maximum_iterations: See tf.while_loop.
+ """
+ del parallel_iterations
+ del back_prop
+ del swap_memory
+ del maximum_iterations
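+
+
+# A usage sketch for set_loop_options (illustrative only; the call is a no-op
+# at runtime and is only consulted during AutoGraph conversion):
+#
+#   def f(n):
+#     i = 0
+#     while i < n:
+#       set_loop_options(parallel_iterations=1, back_prop=False)
+#       i += 1
+#     return i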
diff --git a/tensorflow/contrib/autograph/lang/special_functions.py b/tensorflow/contrib/autograph/lang/special_functions.py
new file mode 100644
index 0000000000..11135295a7
--- /dev/null
+++ b/tensorflow/contrib/autograph/lang/special_functions.py
@@ -0,0 +1,59 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Special functions that only make sense for AutoGraph.
+
+These functions are meant to ensure feature parity between Python and AutoGraph,
+so that the exact same code works in both modes. In general, AutoGraph will
+replace these calls.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.autograph.operators import data_structures
+
+
+def stack(list_or_tensor, element_dtype=None, strict=True):
+ """Stacks the input, if it admits the notion of stacking.
+
+ For example, a list of tensors can be stacked into a larger tensor. This
+ function is similar to tf.stack, but it accepts non-lists and lists of
+ non-tensors as arguments. In the latter case, the function does nothing.
+
+ Args:
+ list_or_tensor: Any
+    element_dtype: tf.DType, optional dtype for the elements in the list.
+ Required if the input is stackable, and the list is untyped.
+ strict: bool, if True an error is raised if the input is not stackable.
+ Otherwise the function is a no-op.
+
+ Returns:
+ Any, if the input is stackable, the result will be a tf.Tensor. Otherwise,
+ if strict=False, the result will be list_or_tensor.
+
+ Raises:
+ ValueError: if strict=True and the input is not stackable.
+ """
+ if strict:
+ def raise_error(x):
+ raise ValueError('%s must be stackable when strict=True' % x)
+ original_call = raise_error
+ else:
+ original_call = lambda x: x
+ return data_structures.list_stack(
+ list_or_tensor,
+ data_structures.ListStackOpts(
+ element_dtype=element_dtype, original_call=original_call))
diff --git a/tensorflow/contrib/autograph/lang/special_functions_test.py b/tensorflow/contrib/autograph/lang/special_functions_test.py
new file mode 100644
index 0000000000..a49cb64075
--- /dev/null
+++ b/tensorflow/contrib/autograph/lang/special_functions_test.py
@@ -0,0 +1,54 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for special_functions module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.autograph.lang import special_functions
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import list_ops
+from tensorflow.python.platform import test
+
+
+class SpecialFunctionsTest(test.TestCase):
+
+ def test_basic(self):
+ self.assertEqual(special_functions.stack(1, strict=False), 1)
+ self.assertListEqual(
+ special_functions.stack([1, 2, 3], strict=False), [1, 2, 3])
+ # TODO(mdan): This should probably forward to tf.stack.
+ self.assertTrue(
+ isinstance(
+ special_functions.stack(
+ [constant_op.constant(1),
+ constant_op.constant(2)], strict=False), list))
+
+ with self.assertRaises(ValueError):
+ special_functions.stack([1, 2, 3])
+
+ t = constant_op.constant([1.0, 2.0])
+ l = list_ops.tensor_list_from_tensor(
+ t, element_shape=constant_op.constant([], dtype=dtypes.int32))
+ self.assertTrue(
+ tensor_util.is_tensor(
+ special_functions.stack(l, element_dtype=dtypes.float32)))
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/checkpoint/__init__.py b/tensorflow/contrib/checkpoint/__init__.py
index 8ae493ba99..9aa4614967 100644
--- a/tensorflow/contrib/checkpoint/__init__.py
+++ b/tensorflow/contrib/checkpoint/__init__.py
@@ -16,9 +16,11 @@
Visualization and inspection:
@@dot_graph_from_checkpoint
+@@list_objects
@@object_metadata
Managing dependencies:
+@@capture_dependencies
@@Checkpointable
@@CheckpointableObjectGraph
@@NoDependency
@@ -42,6 +44,8 @@ from tensorflow.python.training.checkpointable.base import Checkpointable
from tensorflow.python.training.checkpointable.base import NoDependency
from tensorflow.python.training.checkpointable.data_structures import List
from tensorflow.python.training.checkpointable.data_structures import Mapping
+from tensorflow.python.training.checkpointable.util import capture_dependencies
+from tensorflow.python.training.checkpointable.util import list_objects
from tensorflow.python.training.checkpointable.util import object_metadata
from tensorflow.python.util.all_util import remove_undocumented
diff --git a/tensorflow/contrib/cloud/BUILD b/tensorflow/contrib/cloud/BUILD
index 42ba368531..1a7a3759ba 100644
--- a/tensorflow/contrib/cloud/BUILD
+++ b/tensorflow/contrib/cloud/BUILD
@@ -74,3 +74,14 @@ tf_py_test(
],
tags = ["manual"],
)
+
+tf_py_test(
+ name = "gcs_config_ops_test",
+ size = "small",
+ srcs = ["python/ops/gcs_config_ops_test.py"],
+ additional_deps = [
+ ":cloud_py",
+ "//tensorflow/python:client_testlib",
+ ],
+ tags = ["manual"],
+)
diff --git a/tensorflow/contrib/cloud/__init__.py b/tensorflow/contrib/cloud/__init__.py
index a6e13ea3ae..ef7aa7624c 100644
--- a/tensorflow/contrib/cloud/__init__.py
+++ b/tensorflow/contrib/cloud/__init__.py
@@ -27,8 +27,9 @@ from tensorflow.python.util.all_util import remove_undocumented
_allowed_symbols = [
'BigQueryReader',
- 'ConfigureColabSession',
- 'ConfigureGcs',
+ 'BlockCacheParams',
+ 'configure_colab_session',
+ 'configure_gcs',
'ConfigureGcsHook',
]
remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/contrib/cloud/kernels/BUILD b/tensorflow/contrib/cloud/kernels/BUILD
index 40160706f7..1311063ec0 100644
--- a/tensorflow/contrib/cloud/kernels/BUILD
+++ b/tensorflow/contrib/cloud/kernels/BUILD
@@ -79,6 +79,7 @@ tf_kernel_library(
srcs = ["gcs_config_ops.cc"],
visibility = ["//tensorflow:internal"],
deps = [
+ "//tensorflow/contrib/cloud:gcs_config_ops_op_lib",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core/platform/cloud:curl_http_request",
diff --git a/tensorflow/contrib/cloud/python/ops/gcs_config_ops_test.py b/tensorflow/contrib/cloud/python/ops/gcs_config_ops_test.py
new file mode 100644
index 0000000000..fc0c994812
--- /dev/null
+++ b/tensorflow/contrib/cloud/python/ops/gcs_config_ops_test.py
@@ -0,0 +1,34 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the gcs_config_ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.cloud.python.ops import gcs_config_ops
+from tensorflow.python.platform import test
+
+
+class GcsConfigOpsTest(test.TestCase):
+
+ def testSetBlockCache(self):
+ cfg = gcs_config_ops.BlockCacheParams(max_bytes=1024*1024*1024)
+ with self.test_session() as sess:
+ gcs_config_ops.configure_gcs(sess, block_cache=cfg)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/cmake/tf_tests.cmake b/tensorflow/contrib/cmake/tf_tests.cmake
index eb9482dc25..c8de8db126 100644
--- a/tensorflow/contrib/cmake/tf_tests.cmake
+++ b/tensorflow/contrib/cmake/tf_tests.cmake
@@ -325,6 +325,8 @@ if (tensorflow_BUILD_PYTHON_TESTS)
"${tensorflow_source_dir}/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py" # b/71901810
# Broken io_utils_test
"${tensorflow_source_dir}/tensorflow/python/keras/_impl/keras/utils/io_utils_test.py" # b/72894325
+ # OOM
+ "${tensorflow_source_dir}/tensorflow/python/training/saver_large_variable_test.py" # b/110210559
)
endif()
list(REMOVE_ITEM tf_test_src_py ${tf_test_src_py_exclude})
diff --git a/tensorflow/contrib/control_flow/python/cond_v2.py b/tensorflow/contrib/control_flow/python/cond_v2.py
index 9ffad9caa9..90371cd8d7 100644
--- a/tensorflow/contrib/control_flow/python/cond_v2.py
+++ b/tensorflow/contrib/control_flow/python/cond_v2.py
@@ -44,11 +44,34 @@ from tensorflow.python.util import compat
def cond_v2(pred, true_fn, false_fn, name="cond"):
"""Like tf.cond, except emits a single If op."""
+ if not name:
+ name = "cond"
+
with ops.name_scope(name) as scope:
- true_graph = function.func_graph_from_py_func(true_fn, [], [],
- name="%s_true" % scope)
- false_graph = function.func_graph_from_py_func(false_fn, [], [],
- name="%s_false" % scope)
+    # Identify the caller device, if any, using the innermost one available.
+ device_stack = ops.get_default_graph()._device_function_stack
+ caller_device = device_stack[-1] if device_stack else None
+
+ caller_colocation_stack = ops.get_default_graph()._colocation_stack
+ caller_container = ops.get_default_graph()._container
+ caller_collection_ref = ops.get_default_graph()._collections
+
+ func_name_prefix = scope.replace("/", "_")
+
+ true_graph = function.func_graph_from_py_func(
+ true_fn, [], [],
+ name="%strue" % func_name_prefix,
+ device=caller_device,
+ colocation_stack=caller_colocation_stack,
+ collections_ref=caller_collection_ref,
+ container=caller_container)
+ false_graph = function.func_graph_from_py_func(
+ false_fn, [], [],
+ name="%sfalse" % func_name_prefix,
+ device=caller_device,
+ colocation_stack=caller_colocation_stack,
+ collections_ref=caller_collection_ref,
+ container=caller_container)
_check_same_outputs(true_graph, false_graph)
# Add inputs to true_graph and false_graph to make them match. Note that
diff --git a/tensorflow/contrib/control_flow/python/cond_v2_test.py b/tensorflow/contrib/control_flow/python/cond_v2_test.py
index 338601aa2c..94ed3e130b 100644
--- a/tensorflow/contrib/control_flow/python/cond_v2_test.py
+++ b/tensorflow/contrib/control_flow/python/cond_v2_test.py
@@ -25,10 +25,13 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import saver
+from tensorflow.python.util import compat
class NewCondTest(test.TestCase):
@@ -96,6 +99,37 @@ class NewCondTest(test.TestCase):
self.assertEqual(sess.run(out, {pred: True}), [1.0])
self.assertEqual(sess.run(out, {pred: False}), [2.0])
+ def _createCond(self, name):
+ pred = array_ops.placeholder(dtypes.bool, name="pred")
+ x = constant_op.constant(1.0, name="x")
+
+ def true_fn():
+ return x
+
+ def false_fn():
+ return x + 1
+
+ return cond_v2.cond_v2(pred, true_fn, false_fn, name=name)[0].op
+
+ def testDefaultName(self):
+ with ops.Graph().as_default():
+ cond = self._createCond(None)
+ self.assertEqual(cond.name, "cond")
+ self.assertIn("cond_true", ops.get_default_graph()._functions)
+ self.assertIn("cond_false", ops.get_default_graph()._functions)
+
+ with ops.Graph().as_default():
+ with ops.name_scope("foo"):
+ cond = self._createCond("")
+ self.assertEqual(cond.name, "foo/cond")
+ self.assertIn("foo_cond_true", ops.get_default_graph()._functions)
+ self.assertIn("foo_cond_false", ops.get_default_graph()._functions)
+
+ cond2 = self._createCond(None)
+ self.assertEqual(cond2.name, "foo/cond_1")
+ self.assertIn("foo_cond_1_true", ops.get_default_graph()._functions)
+ self.assertIn("foo_cond_1_false", ops.get_default_graph()._functions)
+
def testSecondDerivative(self):
pred = array_ops.placeholder(dtypes.bool, name="pred")
x = constant_op.constant(3.0, name="x")
@@ -167,5 +201,225 @@ class NewCondTest(test.TestCase):
self.assertEqual(false_val, [0.0])
+class CondV2CollectionTest(test.TestCase):
+
+ def testCollectionIntValueAccessInCond(self):
+ """Read values from graph collections inside of cond_v2."""
+ with ops.Graph().as_default() as g:
+ with self.test_session(graph=g):
+ x = 2
+ y = 5
+ ops.add_to_collection("x", x)
+ ops.add_to_collection("y", y)
+ def fn():
+ x_const = constant_op.constant(ops.get_collection("x")[0])
+ y_const = constant_op.constant(ops.get_collection("y")[0])
+ return math_ops.add(x_const, y_const)
+
+ cnd = cond_v2.cond_v2(True, fn, fn)
+ self.assertEquals(cnd[0].eval(), 7)
+
+ def testCollectionTensorValueAccessInCond(self):
+ """Read tensors from collections inside of cond_v2 & use them."""
+ with ops.Graph().as_default() as g:
+ with self.test_session(graph=g):
+ x = constant_op.constant(2)
+ y = constant_op.constant(5)
+ ops.add_to_collection("x", x)
+ ops.add_to_collection("y", y)
+
+ def fn():
+ x_read = ops.get_collection("x")[0]
+ y_read = ops.get_collection("y")[0]
+ return math_ops.add(x_read, y_read)
+
+ cnd = cond_v2.cond_v2(math_ops.less(x, y), fn, fn)
+ self.assertEquals(cnd[0].eval(), 7)
+
+ def testCollectionIntValueWriteInCond(self):
+ """Make sure Int writes to collections work inside of cond_v2."""
+ with ops.Graph().as_default() as g:
+ with self.test_session(graph=g):
+ x = constant_op.constant(2)
+ y = constant_op.constant(5)
+ def true_fn():
+ z = math_ops.add(x, y)
+ ops.add_to_collection("z", 7)
+ return math_ops.mul(x, z)
+
+ def false_fn():
+ z = math_ops.add(x, y)
+ return math_ops.mul(x, z)
+
+ cnd = cond_v2.cond_v2(
+ True, true_fn,
+ false_fn)
+ self.assertEquals(cnd[0].eval(), 14)
+
+ read_z_collection = ops.get_collection("z")
+ self.assertEquals(read_z_collection, [7])
+
+
+class CondV2ContainerTest(test.TestCase):
+
+ def testContainer(self):
+ """Set containers outside & inside of cond_v2.
+
+ Make sure the containers are set correctly for both variable creation
+    (tested by variables.Variable) and for stateful ops (tested by FIFOQueue).
+ """
+ with ops.Graph().as_default() as g:
+ with self.test_session(graph=g):
+
+ v0 = variables.Variable([0])
+ q0 = data_flow_ops.FIFOQueue(1, dtypes.float32)
+
+ def container(node):
+ return node.op.get_attr("container")
+
+ self.assertEqual(compat.as_bytes(""), container(v0))
+ self.assertEqual(compat.as_bytes(""), container(q0.queue_ref))
+
+ def true_fn():
+ # When this branch is created in cond below,
+          # the default container should be 'l1'
+ v1 = variables.Variable([1])
+ q1 = data_flow_ops.FIFOQueue(1, dtypes.float32)
+
+ with ops.container("l2t"):
+ v2 = variables.Variable([2])
+ q2 = data_flow_ops.FIFOQueue(1, dtypes.float32)
+
+ v3 = variables.Variable([1])
+ q3 = data_flow_ops.FIFOQueue(1, dtypes.float32)
+
+ self.assertEqual(compat.as_bytes("l1"), container(v1))
+ self.assertEqual(compat.as_bytes("l1"), container(q1.queue_ref))
+ self.assertEqual(compat.as_bytes("l2t"), container(v2))
+ self.assertEqual(compat.as_bytes("l2t"), container(q2.queue_ref))
+ self.assertEqual(compat.as_bytes("l1"), container(v3))
+ self.assertEqual(compat.as_bytes("l1"), container(q3.queue_ref))
+
+ return constant_op.constant(2.0)
+
+ def false_fn():
+ # When this branch is created in cond below,
+          # the default container should be 'l1'
+ v1 = variables.Variable([1])
+ q1 = data_flow_ops.FIFOQueue(1, dtypes.float32)
+
+ with ops.container("l2f"):
+ v2 = variables.Variable([2])
+ q2 = data_flow_ops.FIFOQueue(1, dtypes.float32)
+
+ v3 = variables.Variable([1])
+ q3 = data_flow_ops.FIFOQueue(1, dtypes.float32)
+
+ self.assertEqual(compat.as_bytes("l1"), container(v1))
+ self.assertEqual(compat.as_bytes("l1"), container(q1.queue_ref))
+ self.assertEqual(compat.as_bytes("l2f"), container(v2))
+ self.assertEqual(compat.as_bytes("l2f"), container(q2.queue_ref))
+ self.assertEqual(compat.as_bytes("l1"), container(v3))
+ self.assertEqual(compat.as_bytes("l1"), container(q3.queue_ref))
+
+ return constant_op.constant(6.0)
+
+ with ops.container("l1"):
+ cnd_true = cond_v2.cond_v2(True, true_fn, false_fn)
+ self.assertEquals(cnd_true[0].eval(), 2)
+
+ cnd_false = cond_v2.cond_v2(False, true_fn, false_fn)
+ self.assertEquals(cnd_false[0].eval(), 6)
+
+ v4 = variables.Variable([3])
+ q4 = data_flow_ops.FIFOQueue(1, dtypes.float32)
+ v5 = variables.Variable([4])
+ q5 = data_flow_ops.FIFOQueue(1, dtypes.float32)
+
+ self.assertEqual(compat.as_bytes("l1"), container(v4))
+ self.assertEqual(compat.as_bytes("l1"), container(q4.queue_ref))
+ self.assertEqual(compat.as_bytes(""), container(v5))
+ self.assertEqual(compat.as_bytes(""), container(q5.queue_ref))
+
+
+class CondV2ColocationGroupAndDeviceTest(test.TestCase):
+
+ def testColocateWithBeforeCond(self):
+ with ops.Graph().as_default() as g:
+ with self.test_session(graph=g):
+
+ a = constant_op.constant([2.0], name="a")
+ b = constant_op.constant([2.0], name="b")
+
+ def fn():
+ c = constant_op.constant(3.0)
+ self.assertEqual([b"loc:@a"], c.op.colocation_groups())
+ return c
+
+ with ops.colocate_with(a.op):
+ self.assertEquals(cond_v2.cond_v2(True, fn, fn)[0].eval(), 3)
+
+ def fn2():
+ c = constant_op.constant(3.0)
+ self.assertEqual([b"loc:@a", b"loc:@b"], c.op.colocation_groups())
+ return c
+
+ with ops.colocate_with(a.op):
+ with ops.colocate_with(b.op):
+ self.assertEquals(cond_v2.cond_v2(True, fn2, fn2)[0].eval(), 3)
+
+ def testColocateWithInAndOutOfCond(self):
+ with ops.Graph().as_default() as g:
+ with self.test_session(graph=g):
+
+ a = constant_op.constant([2.0], name="a")
+ b = constant_op.constant([2.0], name="b")
+
+ def fn2():
+ with ops.colocate_with(b.op):
+ c = constant_op.constant(3.0)
+ self.assertEqual([b"loc:@a", b"loc:@b"], c.op.colocation_groups())
+ return c
+
+ with ops.colocate_with(a.op):
+ self.assertEquals(cond_v2.cond_v2(True, fn2, fn2)[0].eval(), 3)
+
+ d = constant_op.constant([2.0], name="d")
+ self.assertEqual([b"loc:@a"], d.op.colocation_groups())
+
+ def testDeviceBeforeCond(self):
+ with ops.Graph().as_default() as g:
+ with self.test_session(graph=g):
+ def fn():
+ c = constant_op.constant(3.0)
+ self.assertEqual("/device:CPU:0", c.op.device)
+ return c
+
+ with ops.device("/device:CPU:0"):
+ self.assertEquals(cond_v2.cond_v2(True, fn, fn)[0].eval(), 3)
+
+ def fn2():
+ c = constant_op.constant(3.0)
+ self.assertEqual("/device:GPU:0", c.op.device)
+ return c
+
+ with ops.device("/device:GPU:0"):
+ self.assertEquals(cond_v2.cond_v2(True, fn2, fn2)[0].eval(), 3)
+
+ def testDeviceInAndOutOfCond(self):
+ with ops.Graph().as_default() as g:
+ with self.test_session(graph=g):
+ def fn2():
+ with ops.device("/device:GPU:0"):
+ c = constant_op.constant(3.0)
+ self.assertEqual("/device:GPU:0", c.op.device)
+ return c
+
+ with ops.device("/device:CPU:0"):
+ self.assertEquals(cond_v2.cond_v2(True, fn2, fn2)[0].eval(), 3)
+
+ d = constant_op.constant(4.0)
+ self.assertEqual("/device:CPU:0", d.op.device)
+
if __name__ == "__main__":
test.main()
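The naming assertions in testDefaultName above rely on cond_v2 lowering each branch into a graph function whose name is derived from the If op's name scope ("cond_true", "foo_cond_false", and so on). A minimal sketch of that behavior, assuming the same cond_v2 import used by this test and TF 1.x graph mode:

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import cond_v2  # assumed import, as in the test above

with ops.Graph().as_default() as g:
  pred = array_ops.placeholder(dtypes.bool, name="pred")
  x = constant_op.constant(1.0, name="x")
  out = cond_v2.cond_v2(pred, lambda: x, lambda: x + 1, name="cond")
  # The branch bodies are registered as graph functions named after the op,
  # e.g. "cond_true" / "cond_false" (private attribute, for illustration only).
  print(sorted(g._functions))  # pylint: disable=protected-access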
diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD
index 0dfd249ec2..4e3f9801d7 100644
--- a/tensorflow/contrib/data/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/BUILD
@@ -30,6 +30,7 @@ py_test(
"//tensorflow/python:tensor_shape",
"//tensorflow/python:util",
"//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
],
)
diff --git a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
index b5fbc45ad3..1435503beb 100644
--- a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
import math
import time
+from absl.testing import parameterized
import numpy as np
from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base
@@ -40,7 +41,7 @@ from tensorflow.python.platform import test
from tensorflow.python.util import compat
-class BatchDatasetTest(test.TestCase):
+class BatchDatasetTest(test.TestCase, parameterized.TestCase):
def assertSparseValuesEqual(self, a, b):
self.assertAllEqual(a.indices, b.indices)
@@ -427,9 +428,13 @@ class BatchDatasetTest(test.TestCase):
self.assertEqual([None], dataset.output_shapes[1][0].as_list())
self.assertEqual([None, 30], dataset.output_shapes[1][1].as_list())
- def _testMapAndBatchDatasetHelper(self,
- num_parallel_calls=None,
- num_parallel_batches=None):
+ @parameterized.named_parameters(
+ ("default", None, None),
+ ("sequential_calls", 1, None),
+ ("parallel_calls", 2, None),
+ ("parallel_batches", None, 10),
+ )
+ def testMapAndBatch(self, num_parallel_calls, num_parallel_batches):
"""Test a dataset that maps a TF function across its input elements."""
# The pipeline is TensorSliceDataset ->
# RepeatDataset(count) -> MapAndBatchDataset(square_3, batch_size).
@@ -500,19 +505,11 @@ class BatchDatasetTest(test.TestCase):
with self.assertRaises(errors.InvalidArgumentError):
sess.run(init_op, feed_dict={count: 14, batch_size: 0})
- def testMapAndBatch(self):
- return self._testMapAndBatchDatasetHelper()
-
- def testMapAndBatchWithParallelBatches(self):
- return self._testMapAndBatchDatasetHelper(num_parallel_batches=10)
-
- def testMapAndBatchWithSequentialCalls(self):
- return self._testMapAndBatchDatasetHelper(num_parallel_calls=1)
-
- def testMapAndBatchWithParallelCalls(self):
- return self._testMapAndBatchDatasetHelper(num_parallel_calls=2)
-
- def _testMapAndBatchPartialBatchHelper(self, drop_remainder=False):
+ @parameterized.named_parameters(
+ ("even", False),
+ ("uneven", True),
+ )
+ def testMapAndBatchPartialBatch(self, drop_remainder):
iterator = (
dataset_ops.Dataset.range(10).apply(
batching.map_and_batch(
@@ -532,12 +529,6 @@ class BatchDatasetTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
- def testMapAndBatchPartialBatch(self):
- return self._testMapAndBatchPartialBatchHelper()
-
- def testMapAndBatchPartialBatchDropRemainder(self):
- return self._testMapAndBatchPartialBatchHelper(drop_remainder=True)
-
def testMapAndBatchYieldsPartialBatch(self):
iterator = (dataset_ops.Dataset.range(10)
.apply(batching.map_and_batch(
@@ -614,7 +605,7 @@ class BatchDatasetTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
- def testMapAndBatchDatasetFails(self):
+ def testMapAndBatchFails(self):
"""Test a dataset that maps a TF function across its input elements."""
dataset = dataset_ops.Dataset.from_tensors(
array_ops.check_numerics(
@@ -628,7 +619,7 @@ class BatchDatasetTest(test.TestCase):
with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"):
sess.run(init_op, feed_dict={batch_size: 14})
- def testMapAndBatchDatasetShapeMismatch(self):
+ def testMapAndBatchShapeMismatch(self):
"""Test a dataset that maps a TF function across its input elements."""
def generator():
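The per-configuration helper methods removed above are folded into absl's parameterized.named_parameters decorator, which generates one test method per named argument tuple (testMapAndBatch_default, testMapAndBatch_parallel_calls, and so on). A self-contained sketch of the pattern on a hypothetical dataset test, not taken from this change:

from absl.testing import parameterized
import tensorflow as tf


class SquareDatasetTest(tf.test.TestCase, parameterized.TestCase):

  @parameterized.named_parameters(
      ("default", None),
      ("parallel_calls", 2),
  )
  def testSquare(self, num_parallel_calls):
    dataset = tf.data.Dataset.range(4).map(
        lambda x: x * x, num_parallel_calls=num_parallel_calls)
    get_next = dataset.make_one_shot_iterator().get_next()
    with self.test_session() as sess:
      self.assertEqual([sess.run(get_next) for _ in range(4)], [0, 1, 4, 9])


if __name__ == "__main__":
  tf.test.main()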
diff --git a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py
index bd3e034211..4fbfbfdbdd 100644
--- a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py
@@ -68,7 +68,7 @@ class GroupByReducerTest(test.TestCase):
reducer = grouping.Reducer(
init_func=lambda _: (0.0, 0.0),
reduce_func=reduce_fn,
- finalize_func=lambda x: x[0])
+ finalize_func=lambda x, _: x)
for i in range(1, 11):
dataset = dataset_ops.Dataset.range(2 * i).apply(
grouping.group_by_reducer(
@@ -121,7 +121,7 @@ class GroupByReducerTest(test.TestCase):
reducer = grouping.Reducer(
init_func=lambda x: ([0], 1),
reduce_func=reduce_fn,
- finalize_func=lambda x: x)
+ finalize_func=lambda x, y: (x, y))
for i in range(1, 11):
dataset = dataset_ops.Dataset.from_tensors(np.int64(0)).repeat(i).apply(
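The finalize_func updates above track a change in the Reducer contract: functions wrapped for group_by_reducer now receive the components of a multi-part state as separate positional arguments rather than as a single packed value. A hedged sketch of the new contract with a two-component (sum, count) state, assuming the same grouping.Reducer / group_by_reducer API exercised by this test:

import tensorflow as tf
from tensorflow.contrib.data.python.ops import grouping

# Average the values seen for each key; the state is the pair (sum, count).
reducer = grouping.Reducer(
    init_func=lambda _: (0.0, 0.0),
    reduce_func=lambda state, value: (state[0] + tf.to_float(value),
                                      state[1] + 1.0),
    finalize_func=lambda total, count: total / count)  # two args, not a tuple

dataset = tf.data.Dataset.range(10).apply(
    grouping.group_by_reducer(key_func=lambda x: x % 2, reducer=reducer))
# Iterating the dataset would yield the per-key averages 4.0 and 5.0.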
diff --git a/tensorflow/contrib/data/python/kernel_tests/resample_test.py b/tensorflow/contrib/data/python/kernel_tests/resample_test.py
index bdc003a8a5..520da7d6ff 100644
--- a/tensorflow/contrib/data/python/kernel_tests/resample_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/resample_test.py
@@ -17,10 +17,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from absl.testing import parameterized
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
import time
-from absl.testing import parameterized
from tensorflow.contrib.data.python.ops import resampling
from tensorflow.python.data.ops import dataset_ops
diff --git a/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py
index 1b67a33f04..25e9ea47b8 100644
--- a/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py
@@ -48,10 +48,10 @@ class ShuffleDatasetSerializationTest(
def testShuffleCore(self):
seed = 55
- range_limit = 10
- num_repeats = 5
+ range_limit = 5
+ num_repeats = 2
num_outputs = range_limit * num_repeats
- buffer_sizes = [1, 3, 8, 10, 25, 50]
+ buffer_sizes = [1, 3, 5, 8, 10]
# pylint: disable=cell-var-from-loop
# pylint: disable=g-long-lambda
for reshuffle_each_iteration in [True, False]:
@@ -75,10 +75,10 @@ class ShuffleDatasetSerializationTest(
def testNonDeterministicSeeding(self):
- range_limit = 10
- num_repeats = 5
+ range_limit = 5
+ num_repeats = 2
num_outputs = range_limit * num_repeats
- buffer_sizes = [1, 3, 8, 10, 25, 50]
+ buffer_sizes = [1, 3, 5, 8, 10]
for reshuffle_each_iteration in [True, False]:
for buffer_size in buffer_sizes:
@@ -111,10 +111,10 @@ class ShuffleDatasetSerializationTest(
self.match(expected, actual)
def testMultipleIterators(self):
- range_limit = 10
- num_repeats = 5
+ range_limit = 5
+ num_repeats = 2
num_outputs = range_limit * num_repeats
- buffer_sizes = [1, 3, 8, 10, 25, 50]
+ buffer_sizes = [1, 3, 5, 8, 10]
for reshuffle_each_iteration in [True, False]:
for buffer_size in buffer_sizes:
diff --git a/tensorflow/contrib/data/python/ops/grouping.py b/tensorflow/contrib/data/python/ops/grouping.py
index f9f25e6a06..4068a2ffa5 100644
--- a/tensorflow/contrib/data/python/ops/grouping.py
+++ b/tensorflow/contrib/data/python/ops/grouping.py
@@ -21,12 +21,9 @@ import numpy as np
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
-from tensorflow.python.data.util import sparse
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import function
from tensorflow.python.framework import ops
-from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
@@ -273,70 +270,27 @@ class GroupByReducerDataset(dataset_ops.Dataset):
def _make_key_func(self, key_func, input_dataset):
"""Make wrapping Defun for key_func."""
-
- @function.Defun(*nest.flatten(
- sparse.as_dense_types(input_dataset.output_types,
- input_dataset.output_classes)))
- def tf_key_func(*args):
- """A wrapper for Defun that facilitates shape inference."""
- # Pass in shape information from the input_dataset.
- dense_shapes = sparse.as_dense_shapes(input_dataset.output_shapes,
- input_dataset.output_classes)
- for arg, shape in zip(args, nest.flatten(dense_shapes)):
- arg.set_shape(shape)
-
- nested_args = nest.pack_sequence_as(input_dataset.output_types, args)
- nested_args = sparse.deserialize_sparse_tensors(
- nested_args, input_dataset.output_types, input_dataset.output_shapes,
- input_dataset.output_classes)
- # pylint: disable=protected-access
- if dataset_ops._should_unpack_args(nested_args):
- ret = key_func(*nested_args)
- # pylint: enable=protected-access
- else:
- ret = key_func(nested_args)
- ret = ops.convert_to_tensor(ret)
- if ret.dtype != dtypes.int64 or ret.get_shape() != tensor_shape.scalar():
- raise ValueError(
- "`key_func` must return a single tf.int64 tensor. "
- "Got type=%s and shape=%s" % (ret.dtype, ret.get_shape()))
- dataset_ops._warn_if_collections("tf.contrib.data.group_by_reducer()") # pylint: disable=protected-access
- return ret
-
- self._key_func = tf_key_func
- self._key_func.add_to_graph(ops.get_default_graph())
+ wrapped_func = dataset_ops.StructuredFunctionWrapper(
+ key_func, "tf.contrib.data.group_by_reducer()", input_dataset)
+ if not (
+ wrapped_func.output_types == dtypes.int64 and
+ wrapped_func.output_shapes.is_compatible_with(tensor_shape.scalar())):
+ raise ValueError(
+ "`key_func` must return a single tf.int64 tensor. "
+ "Got type=%s and shape=%s"
+ % (wrapped_func.output_types, wrapped_func.output_shapes))
+ self._key_func = wrapped_func.function
def _make_init_func(self, init_func):
"""Make wrapping Defun for init_func."""
-
- @function.Defun(dtypes.int64)
- def tf_init_func(key):
- """A wrapper for Defun that facilitates shape inference."""
- key.set_shape([])
- ret = init_func(key)
- # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
- # values to tensors.
- ret = nest.pack_sequence_as(ret, [
- sparse_tensor.SparseTensor.from_value(t)
- if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t)
- for t in nest.flatten(ret)
- ])
-
- self._state_classes = sparse.get_classes(ret)
- self._state_shapes = nest.pack_sequence_as(
- ret, [t.get_shape() for t in nest.flatten(ret)])
- self._state_types = nest.pack_sequence_as(
- ret, [t.dtype for t in nest.flatten(ret)])
-
- dataset_ops._warn_if_collections("tf.contrib.data.group_by_reducer()") # pylint: disable=protected-access
-
- # Serialize any sparse tensors.
- ret = nest.pack_sequence_as(
- ret, [t for t in nest.flatten(sparse.serialize_sparse_tensors(ret))])
- return nest.flatten(ret)
-
- self._init_func = tf_init_func
- self._init_func.add_to_graph(ops.get_default_graph())
+ wrapped_func = dataset_ops.StructuredFunctionWrapper(
+ init_func, "tf.contrib.data.group_by_reducer()",
+ input_classes=ops.Tensor, input_shapes=tensor_shape.scalar(),
+ input_types=dtypes.int64)
+ self._init_func = wrapped_func.function
+ self._state_classes = wrapped_func.output_classes
+ self._state_shapes = wrapped_func.output_shapes
+ self._state_types = wrapped_func.output_types
def _make_reduce_func(self, reduce_func, input_dataset):
"""Make wrapping Defun for reduce_func."""
@@ -346,85 +300,47 @@ class GroupByReducerDataset(dataset_ops.Dataset):
need_to_rerun = True
while need_to_rerun:
- # Create a list in which `tf_reduce_func` will store the new shapes.
- flat_new_state_shapes = []
-
- @function.Defun(*(nest.flatten(
- sparse.as_dense_types(
- self._state_types, self._state_classes)) + nest.flatten(
- sparse.as_dense_types(input_dataset.output_types,
- input_dataset.output_classes))))
- def tf_reduce_func(*args):
- """A wrapper for Defun that facilitates shape inference."""
- for arg, shape in zip(
- args,
- nest.flatten(
- sparse.as_dense_shapes(self._state_shapes, self._state_classes))
- + nest.flatten(
- sparse.as_dense_shapes(input_dataset.output_shapes,
- input_dataset.output_classes))):
- arg.set_shape(shape)
-
- pivot = len(nest.flatten(self._state_shapes))
- nested_state_args = nest.pack_sequence_as(self._state_types,
- args[:pivot])
- nested_state_args = sparse.deserialize_sparse_tensors(
- nested_state_args, self._state_types, self._state_shapes,
- self._state_classes)
- nested_input_args = nest.pack_sequence_as(input_dataset.output_types,
- args[pivot:])
- nested_input_args = sparse.deserialize_sparse_tensors(
- nested_input_args, input_dataset.output_types,
- input_dataset.output_shapes, input_dataset.output_classes)
-
- ret = reduce_func(nested_state_args, nested_input_args)
-
- # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
- # values to tensors.
- ret = nest.pack_sequence_as(ret, [
- sparse_tensor.SparseTensor.from_value(t)
- if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t)
- for t in nest.flatten(ret)
- ])
-
- # Extract shape information from the returned values.
- flat_new_state = nest.flatten(ret)
- flat_new_state_shapes.extend([t.get_shape() for t in flat_new_state])
-
- # Extract and validate type information from the returned values.
- for t, dtype in zip(flat_new_state, nest.flatten(self._state_types)):
- if t.dtype != dtype:
- raise TypeError(
- "The element types for the new state must match the initial "
- "state. Expected %s; got %s." %
- (self._state_types,
- nest.pack_sequence_as(self._state_types,
- [t.dtype for t in flat_new_state])))
-
- dataset_ops._warn_if_collections("tf.contrib.data.group_by_reducer()") # pylint: disable=protected-access
-
- # Serialize any sparse tensors.
- ret = nest.pack_sequence_as(
- ret,
- [t for t in nest.flatten(sparse.serialize_sparse_tensors(ret))])
- return nest.flatten(ret)
-
- # Use the private method that will execute `tf_reduce_func` but delay
- # adding it to the graph in case we need to rerun the function.
- tf_reduce_func._create_definition_if_needed() # pylint: disable=protected-access
-
+ wrapped_func = dataset_ops.StructuredFunctionWrapper(
+ reduce_func, "tf.contrib.data.group_by_reducer()",
+ input_classes=(self._state_classes, input_dataset.output_classes),
+ input_shapes=(self._state_shapes, input_dataset.output_shapes),
+ input_types=(self._state_types, input_dataset.output_types),
+ add_to_graph=False)
+
+ # Extract and validate class information from the returned values.
+ for new_state_class, state_class in zip(
+ nest.flatten(wrapped_func.output_classes),
+ nest.flatten(self._state_classes)):
+ if not issubclass(new_state_class, state_class):
+ raise TypeError(
+ "The element classes for the new state must match the initial "
+ "state. Expected %s; got %s." %
+ (self._state_classes, wrapped_func.output_classes))
+
+ # Extract and validate type information from the returned values.
+ for new_state_type, state_type in zip(
+ nest.flatten(wrapped_func.output_types),
+ nest.flatten(self._state_types)):
+ if new_state_type != state_type:
+ raise TypeError(
+ "The element types for the new state must match the initial "
+ "state. Expected %s; got %s." %
+ (self._state_types, wrapped_func.output_types))
+
+ # Extract shape information from the returned values.
flat_state_shapes = nest.flatten(self._state_shapes)
+ flat_new_state_shapes = nest.flatten(wrapped_func.output_shapes)
weakened_state_shapes = [
- old.most_specific_compatible_shape(new)
- for old, new in zip(flat_state_shapes, flat_new_state_shapes)
+ original.most_specific_compatible_shape(new)
+ for original, new in zip(flat_state_shapes, flat_new_state_shapes)
]
need_to_rerun = False
- for old_shape, weakened_shape in zip(flat_state_shapes,
- weakened_state_shapes):
- if old_shape.ndims is not None and (
+ for original_shape, weakened_shape in zip(flat_state_shapes,
+ weakened_state_shapes):
+ if original_shape.ndims is not None and (
weakened_shape.ndims is None or
- old_shape.as_list() != weakened_shape.as_list()):
+ original_shape.as_list() != weakened_shape.as_list()):
need_to_rerun = True
break
@@ -432,52 +348,19 @@ class GroupByReducerDataset(dataset_ops.Dataset):
self._state_shapes = nest.pack_sequence_as(self._state_shapes,
weakened_state_shapes)
- self._reduce_func = tf_reduce_func
+ self._reduce_func = wrapped_func.function
self._reduce_func.add_to_graph(ops.get_default_graph())
def _make_finalize_func(self, finalize_func):
"""Make wrapping Defun for finalize_func."""
-
- @function.Defun(*(nest.flatten(
- sparse.as_dense_types(self._state_types, self._state_classes))))
- def tf_finalize_func(*args):
- """A wrapper for Defun that facilitates shape inference."""
- for arg, shape in zip(
- args,
- nest.flatten(
- sparse.as_dense_shapes(self._state_shapes, self._state_classes))):
- arg.set_shape(shape)
-
- nested_args = nest.pack_sequence_as(self._state_types, args)
- nested_args = sparse.deserialize_sparse_tensors(
- nested_args, self._state_types, self._state_shapes,
- self._state_classes)
-
- ret = finalize_func(nested_args)
-
- # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
- # values to tensors.
- ret = nest.pack_sequence_as(ret, [
- sparse_tensor.SparseTensor.from_value(t)
- if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t)
- for t in nest.flatten(ret)
- ])
-
- self._output_classes = sparse.get_classes(ret)
- self._output_shapes = nest.pack_sequence_as(
- ret, [t.get_shape() for t in nest.flatten(ret)])
- self._output_types = nest.pack_sequence_as(
- ret, [t.dtype for t in nest.flatten(ret)])
-
- dataset_ops._warn_if_collections("tf.contrib.data.group_by_reducer()") # pylint: disable=protected-access
-
- # Serialize any sparse tensors.
- ret = nest.pack_sequence_as(
- ret, [t for t in nest.flatten(sparse.serialize_sparse_tensors(ret))])
- return nest.flatten(ret)
-
- self._finalize_func = tf_finalize_func
- self._finalize_func.add_to_graph(ops.get_default_graph())
+ wrapped_func = dataset_ops.StructuredFunctionWrapper(
+ finalize_func, "tf.contrib.data.group_by_reducer()",
+ input_classes=self._state_classes, input_shapes=self._state_shapes,
+ input_types=self._state_types)
+ self._finalize_func = wrapped_func.function
+ self._output_classes = wrapped_func.output_classes
+ self._output_shapes = wrapped_func.output_shapes
+ self._output_types = wrapped_func.output_types
@property
def output_classes(self):
@@ -520,77 +403,53 @@ class GroupByWindowDataset(dataset_ops.Dataset):
def _make_window_size_func(self, window_size_func):
"""Make wrapping Defun for window_size_func."""
-
- @function.Defun(dtypes.int64)
- def tf_window_size_func(key):
- key.set_shape([])
- window_size = ops.convert_to_tensor(
- window_size_func(key), dtype=dtypes.int64)
- if window_size.dtype != dtypes.int64:
- raise ValueError(
- "`window_size_func` must return a single tf.int64 tensor.")
- dataset_ops._warn_if_collections("tf.contrib.data.group_by_window()") # pylint: disable=protected-access
- return window_size
-
- self._window_size_func = tf_window_size_func
- self._window_size_func.add_to_graph(ops.get_default_graph())
+ def window_size_func_wrapper(key):
+ return ops.convert_to_tensor(window_size_func(key), dtype=dtypes.int64)
+ wrapped_func = dataset_ops.StructuredFunctionWrapper(
+ window_size_func_wrapper, "tf.contrib.data.group_by_window()",
+ input_classes=ops.Tensor, input_shapes=tensor_shape.scalar(),
+ input_types=dtypes.int64)
+ if not (
+ wrapped_func.output_types == dtypes.int64 and
+ wrapped_func.output_shapes.is_compatible_with(tensor_shape.scalar())):
+ raise ValueError(
+ "`window_size_func` must return a single tf.int64 scalar tensor.")
+ self._window_size_func = wrapped_func.function
def _make_key_func(self, key_func, input_dataset):
"""Make wrapping Defun for key_func."""
-
- @function.Defun(*nest.flatten(
- sparse.as_dense_types(input_dataset.output_types,
- input_dataset.output_classes)))
- def tf_key_func(*args):
- """A wrapper for Defun that facilitates shape inference."""
- # Pass in shape information from the input_dataset.
- dense_shapes = sparse.as_dense_shapes(input_dataset.output_shapes,
- input_dataset.output_classes)
- for arg, shape in zip(args, nest.flatten(dense_shapes)):
- arg.set_shape(shape)
-
- nested_args = nest.pack_sequence_as(input_dataset.output_types, args)
- nested_args = sparse.deserialize_sparse_tensors(
- nested_args, input_dataset.output_types, input_dataset.output_shapes,
- input_dataset.output_classes)
- # pylint: disable=protected-access
- if dataset_ops._should_unpack_args(nested_args):
- ret = key_func(*nested_args)
- # pylint: enable=protected-access
- else:
- ret = key_func(nested_args)
- ret = ops.convert_to_tensor(ret, dtype=dtypes.int64)
- if ret.dtype != dtypes.int64:
- raise ValueError("`key_func` must return a single tf.int64 tensor.")
- dataset_ops._warn_if_collections("tf.contrib.data.group_by_window()") # pylint: disable=protected-access
- return ret
-
- self._key_func = tf_key_func
- self._key_func.add_to_graph(ops.get_default_graph())
+ def key_func_wrapper(*args):
+ return ops.convert_to_tensor(key_func(*args), dtype=dtypes.int64)
+ wrapped_func = dataset_ops.StructuredFunctionWrapper(
+ key_func_wrapper, "tf.contrib.data.group_by_window()", input_dataset)
+ if not (
+ wrapped_func.output_types == dtypes.int64 and
+ wrapped_func.output_shapes.is_compatible_with(tensor_shape.scalar())):
+ raise ValueError(
+ "`key_func` must return a single tf.int64 scalar tensor.")
+ self._key_func = wrapped_func.function
def _make_reduce_func(self, reduce_func, input_dataset):
"""Make wrapping Defun for reduce_func."""
-
- @function.Defun(dtypes.int64, dtypes.variant)
- def tf_reduce_func(key, window_dataset_variant):
- """A wrapper for Defun that facilitates shape inference."""
- key.set_shape([])
+ def reduce_func_wrapper(key, window_dataset_variant):
+ """Wrapper that converts between tf.variant and Dataset objects."""
window_dataset = _VariantDataset(
window_dataset_variant, input_dataset.output_types,
input_dataset.output_shapes, input_dataset.output_classes)
- if not isinstance(window_dataset, dataset_ops.Dataset):
- raise TypeError("`window_dataset` must return a `Dataset` object.")
output_dataset = reduce_func(key, window_dataset)
if not isinstance(output_dataset, dataset_ops.Dataset):
raise TypeError("`reduce_func` must return a `Dataset` object.")
self._output_classes = output_dataset.output_classes
self._output_types = output_dataset.output_types
self._output_shapes = output_dataset.output_shapes
- dataset_ops._warn_if_collections("tf.contrib.data.group_by_window()") # pylint: disable=protected-access
return output_dataset._as_variant_tensor() # pylint: disable=protected-access
- self._reduce_func = tf_reduce_func
- self._reduce_func.add_to_graph(ops.get_default_graph())
+ wrapped_func = dataset_ops.StructuredFunctionWrapper(
+ reduce_func_wrapper, "tf.contrib.data.reduce_by_window()",
+ input_classes=(ops.Tensor, ops.Tensor),
+ input_shapes=(tensor_shape.scalar(), tensor_shape.scalar()),
+ input_types=(dtypes.int64, dtypes.variant))
+ self._reduce_func = wrapped_func.function
@property
def output_classes(self):
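The hand-rolled @function.Defun wrappers deleted above are replaced by dataset_ops.StructuredFunctionWrapper, which centralizes shape propagation, sparse-tensor (de)serialization, and the collections warning; the call sites keep only the validation of the wrapped function's reported output structure. A condensed sketch of that pattern for a key function, using the same internal API names as the code above:

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape


def _make_key_func(key_func, input_dataset):
  """Wraps `key_func` and checks that it returns a scalar tf.int64 key."""
  wrapped_func = dataset_ops.StructuredFunctionWrapper(
      key_func, "tf.contrib.data.group_by_reducer()", input_dataset)
  if not (wrapped_func.output_types == dtypes.int64 and
          wrapped_func.output_shapes.is_compatible_with(
              tensor_shape.scalar())):
    raise ValueError("`key_func` must return a single tf.int64 scalar tensor.")
  return wrapped_func.function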
diff --git a/tensorflow/contrib/data/python/ops/optimization.py b/tensorflow/contrib/data/python/ops/optimization.py
index 9612ac5ae9..2ca3805d66 100644
--- a/tensorflow/contrib/data/python/ops/optimization.py
+++ b/tensorflow/contrib/data/python/ops/optimization.py
@@ -61,6 +61,7 @@ class OptimizeDataset(dataset_ops.Dataset):
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
self._optimizations,
**dataset_ops.flat_structure(self))
+
@property
def output_classes(self):
return self._input_dataset.output_classes
diff --git a/tensorflow/contrib/data/python/ops/scan_ops.py b/tensorflow/contrib/data/python/ops/scan_ops.py
index 67eede981c..ea9dcfe68f 100644
--- a/tensorflow/contrib/data/python/ops/scan_ops.py
+++ b/tensorflow/contrib/data/python/ops/scan_ops.py
@@ -22,7 +22,6 @@ import collections
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import sparse
-from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import gen_dataset_ops
@@ -67,104 +66,45 @@ class _ScanDataset(dataset_ops.Dataset):
need_to_rerun = True
while need_to_rerun:
- # Create a list in which `tf_scan_func` will store the new shapes.
- flat_new_state_shapes = []
-
- @function.Defun(*(nest.flatten(
- sparse.as_dense_types(
- self._state_types, self._state_classes)) + nest.flatten(
- sparse.as_dense_types(input_dataset.output_types,
- input_dataset.output_classes))))
- def tf_scan_func(*args):
- """A wrapper for Defun that facilitates shape inference."""
- # Pass in shape information from the state and input_dataset.
- for arg, shape in zip(
- args,
- nest.flatten(
- sparse.as_dense_shapes(self._state_shapes, self._state_classes))
- + nest.flatten(
- sparse.as_dense_shapes(input_dataset.output_shapes,
- input_dataset.output_classes))):
- arg.set_shape(shape)
-
- pivot = len(nest.flatten(self._state_shapes))
- print(self._state_classes)
- nested_state_args = nest.pack_sequence_as(self._state_types,
- args[:pivot])
- nested_state_args = sparse.deserialize_sparse_tensors(
- nested_state_args, self._state_types, self._state_shapes,
- self._state_classes)
- print(input_dataset.output_classes)
- nested_input_args = nest.pack_sequence_as(input_dataset.output_types,
- args[pivot:])
- nested_input_args = sparse.deserialize_sparse_tensors(
- nested_input_args, input_dataset.output_types,
- input_dataset.output_shapes, input_dataset.output_classes)
-
- ret = scan_func(nested_state_args, nested_input_args)
- if not isinstance(ret, collections.Sequence) or len(ret) != 2:
- raise TypeError("The scan function must return a pair comprising the "
- "new state and the output value.")
-
- # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
- # values to tensors.
- ret = nest.pack_sequence_as(ret, [
- sparse_tensor.SparseTensor.from_value(t)
- if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t)
- for t in nest.flatten(ret)
- ])
- new_state, output_value = ret
-
- # Extract and validate class information from the returned values.
- for t, clazz in zip(
- nest.flatten(new_state), nest.flatten(self._state_classes)):
- if not isinstance(t, clazz):
- raise TypeError(
- "The element classes for the new state must match the initial "
- "state. Expected %s; got %s." %
- (self._state_classes,
- nest.pack_sequence_as(
- self._state_types,
- [type(t) for t in nest.flatten(new_state)])))
- self._output_classes = sparse.get_classes(output_value)
-
- # Extract shape information from the returned values.
- flat_new_state_shapes.extend(
- [t.get_shape() for t in nest.flatten(new_state)])
- self._output_shapes = nest.pack_sequence_as(
- output_value, [t.get_shape() for t in nest.flatten(output_value)])
-
- # Extract and validate type information from the returned values.
- for t, dtype in zip(
- nest.flatten(new_state), nest.flatten(self._state_types)):
- if t.dtype != dtype:
- raise TypeError(
- "The element types for the new state must match the initial "
- "state. Expected %s; got %s." %
- (self._state_types,
- nest.pack_sequence_as(
- self._state_types,
- [t.dtype for t in nest.flatten(new_state)])))
- self._output_types = nest.pack_sequence_as(
- output_value, [t.dtype for t in nest.flatten(output_value)])
-
- dataset_ops._warn_if_collections("tf.contrib.data.scan()") # pylint: disable=protected-access
-
- # Serialize any sparse tensors.
- new_state = nest.pack_sequence_as(new_state, [
- t for t in nest.flatten(sparse.serialize_sparse_tensors(new_state))
- ])
- output_value = nest.pack_sequence_as(output_value, [
- t for t in nest.flatten(
- sparse.serialize_sparse_tensors(output_value))
- ])
- return nest.flatten(new_state) + nest.flatten(output_value)
-
- # Use the private method that will execute `tf_scan_func` but delay
- # adding it to the graph in case we need to rerun the function.
- tf_scan_func._create_definition_if_needed() # pylint: disable=protected-access
+ wrapped_func = dataset_ops.StructuredFunctionWrapper(
+ scan_func, "tf.contrib.data.scan()",
+ input_classes=(self._state_classes, input_dataset.output_classes),
+ input_shapes=(self._state_shapes, input_dataset.output_shapes),
+ input_types=(self._state_types, input_dataset.output_types),
+ add_to_graph=False)
+ if not (
+ isinstance(wrapped_func.output_types, collections.Sequence) and
+ len(wrapped_func.output_types) == 2):
+ raise TypeError("The scan function must return a pair comprising the "
+ "new state and the output value.")
+
+ new_state_classes, self._output_classes = wrapped_func.output_classes
+
+ # Extract and validate class information from the returned values.
+ for new_state_class, state_class in zip(
+ nest.flatten(new_state_classes),
+ nest.flatten(self._state_classes)):
+ if not issubclass(new_state_class, state_class):
+ raise TypeError(
+ "The element classes for the new state must match the initial "
+ "state. Expected %s; got %s." %
+ (self._state_classes, new_state_classes))
+
+ # Extract and validate type information from the returned values.
+ new_state_types, self._output_types = wrapped_func.output_types
+ for new_state_type, state_type in zip(
+ nest.flatten(new_state_types), nest.flatten(self._state_types)):
+ if new_state_type != state_type:
+ raise TypeError(
+ "The element types for the new state must match the initial "
+ "state. Expected %s; got %s." %
+ (self._state_types, new_state_types))
+
+ # Extract shape information from the returned values.
+ new_state_shapes, self._output_shapes = wrapped_func.output_shapes
flat_state_shapes = nest.flatten(self._state_shapes)
+ flat_new_state_shapes = nest.flatten(new_state_shapes)
weakened_state_shapes = [
original.most_specific_compatible_shape(new)
for original, new in zip(flat_state_shapes, flat_new_state_shapes)
@@ -180,12 +120,10 @@ class _ScanDataset(dataset_ops.Dataset):
break
if need_to_rerun:
- # NOTE(mrry): `self._output_shapes` will be overwritten when we rerun
- # `tf_scan_func`.
self._state_shapes = nest.pack_sequence_as(self._state_shapes,
weakened_state_shapes)
- self._scan_func = tf_scan_func
+ self._scan_func = wrapped_func.function
self._scan_func.add_to_graph(ops.get_default_graph())
def _as_variant_tensor(self):
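Both group_by_reducer and scan trace the user function with add_to_graph=False while the state shapes it produces are less specific than the current estimate; the loop relaxes the recorded state shapes with most_specific_compatible_shape and retraces until a fixed point, and only then adds the final function to the graph. A small sketch of the relaxation step that drives need_to_rerun:

from tensorflow.python.framework import tensor_shape

old = tensor_shape.TensorShape([1])      # initial state: a length-1 vector
new = tensor_shape.TensorShape([None])   # scan_func grew it; length now unknown
weakened = old.most_specific_compatible_shape(new)
print(weakened)  # (?,) in TF 1.x -- shapes changed, so the function is retraced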
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py
index 9572ade8e4..aca544b7e7 100644
--- a/tensorflow/contrib/distribute/python/values.py
+++ b/tensorflow/contrib/distribute/python/values.py
@@ -238,17 +238,6 @@ class DistributedVariable(DistributedDelegate):
pass
-# Register a conversion function which reads the value of the variable,
-# allowing instances of the class to be used as tensors.
-def _tensor_conversion(var, dtype=None, name=None, as_ref=False):
- # Try to avoid assignments to and other mutations of MirroredVariable
- # state except through a DistributionStrategy.update() call.
- assert not as_ref
- return ops.internal_convert_to_tensor(
- var.get(), dtype=dtype, name=name, as_ref=as_ref)
-
-
-ops.register_tensor_conversion_function(DistributedVariable, _tensor_conversion)
ops.register_dense_tensor_like_type(DistributedVariable)
@@ -342,6 +331,20 @@ class MirroredVariable(DistributedVariable, Mirrored,
return {checkpointable.VARIABLE_VALUE_KEY: _saveable_factory}
+# Register a conversion function which reads the value of the variable,
+# allowing instances of the class to be used as tensors.
+def _tensor_conversion_mirrored(var, dtype=None, name=None, as_ref=False):
+ # Try to avoid assignments to and other mutations of MirroredVariable
+ # state except through a DistributionStrategy.update() call.
+ assert not as_ref
+ return ops.internal_convert_to_tensor(
+ var.get(), dtype=dtype, name=name, as_ref=as_ref)
+
+
+ops.register_tensor_conversion_function(MirroredVariable,
+ _tensor_conversion_mirrored)
+
+
class _TowerLocalSaveable(saver.BaseSaverBuilder.SaveableObject):
"""Class for defining how to restore a TowerLocalVariable."""
@@ -431,6 +434,17 @@ class TowerLocalVariable(DistributedVariable, PerDevice,
return {checkpointable.VARIABLE_VALUE_KEY: _saveable_factory}
+# Register a conversion function for TowerLocalVariable which allows as_ref to
+# be true.
+def _tensor_conversion_tower_local(var, dtype=None, name=None, as_ref=False):
+ return ops.internal_convert_to_tensor(
+ var.get(), dtype=dtype, name=name, as_ref=as_ref)
+
+
+ops.register_tensor_conversion_function(TowerLocalVariable,
+ _tensor_conversion_tower_local)
+
+
def _devices_match(d1, d2):
return device_util.canonicalize(d1) == device_util.canonicalize(d2)
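Moving the registration onto the concrete classes lets MirroredVariable keep the as_ref guard while TowerLocalVariable also accepts as_ref=True; in both cases conversion simply reads the underlying value via var.get(). A toy sketch of the register_tensor_conversion_function pattern on a hypothetical wrapper class (not part of this change):

import tensorflow as tf
from tensorflow.python.framework import ops


class Boxed(object):
  """Minimal wrapper holding a single tensor."""

  def __init__(self, value):
    self._value = tf.convert_to_tensor(value)

  def get(self):
    return self._value


def _boxed_conversion(box, dtype=None, name=None, as_ref=False):
  del as_ref  # this toy wrapper never hands out references
  return ops.internal_convert_to_tensor(box.get(), dtype=dtype, name=name)


ops.register_tensor_conversion_function(Boxed, _boxed_conversion)

y = tf.add(Boxed(2.0), 3.0)  # Boxed is converted through _boxed_conversion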
diff --git a/tensorflow/contrib/distribute/python/values_test.py b/tensorflow/contrib/distribute/python/values_test.py
index 1c95758d96..b0bd92c7b0 100644
--- a/tensorflow/contrib/distribute/python/values_test.py
+++ b/tensorflow/contrib/distribute/python/values_test.py
@@ -966,6 +966,18 @@ class TowerLocalVariableTest(test.TestCase):
save_path = self._save_normal()
self._restore_tower_local_sum(save_path)
+ def testTensorConversion(self):
+ with context.graph_mode():
+ _, tower_local = _make_tower_local("sum")
+ converted = ops.internal_convert_to_tensor(tower_local, as_ref=False)
+ self.assertIsInstance(converted, ops.Tensor)
+ self.assertEqual(converted.dtype, tower_local.dtype)
+
+ converted = ops.internal_convert_to_tensor(tower_local, as_ref=True)
+      # Resource variables are converted to tensors as well when as_ref is True.
+ self.assertIsInstance(converted, ops.Tensor)
+ self.assertEqual(converted.dtype, tower_local.dtype)
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/eager/python/examples/BUILD b/tensorflow/contrib/eager/python/examples/BUILD
index 1d9371c7ac..6f02c90368 100644
--- a/tensorflow/contrib/eager/python/examples/BUILD
+++ b/tensorflow/contrib/eager/python/examples/BUILD
@@ -11,6 +11,8 @@ py_library(
"//tensorflow/contrib/eager/python/examples/l2hmc:neural_nets",
"//tensorflow/contrib/eager/python/examples/linear_regression",
"//tensorflow/contrib/eager/python/examples/resnet50",
+ "//tensorflow/contrib/eager/python/examples/revnet",
+ "//tensorflow/contrib/eager/python/examples/revnet:config",
"//tensorflow/contrib/eager/python/examples/rnn_colorbot",
"//tensorflow/contrib/eager/python/examples/rnn_ptb",
"//tensorflow/contrib/eager/python/examples/spinn:data",
diff --git a/tensorflow/contrib/eager/python/examples/revnet/BUILD b/tensorflow/contrib/eager/python/examples/revnet/BUILD
new file mode 100644
index 0000000000..bfb53cfff8
--- /dev/null
+++ b/tensorflow/contrib/eager/python/examples/revnet/BUILD
@@ -0,0 +1,76 @@
+licenses(["notice"]) # Apache 2.0
+
+package(default_visibility = ["//tensorflow:internal"])
+
+load("//tensorflow:tensorflow.bzl", "cuda_py_test")
+
+# Model
+py_library(
+ name = "ops",
+ srcs = ["ops.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
+py_library(
+ name = "config",
+ srcs = ["config.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
+py_library(
+ name = "blocks",
+ srcs = ["blocks.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":ops",
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
+py_library(
+ name = "revnet",
+ srcs = ["revnet.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":blocks",
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
+# Tests
+cuda_py_test(
+ name = "ops_test",
+ size = "large",
+ srcs = ["ops_test.py"],
+ additional_deps = [
+ ":ops",
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
+cuda_py_test(
+ name = "blocks_test",
+ size = "large",
+ srcs = ["blocks_test.py"],
+ additional_deps = [
+ ":blocks",
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
+cuda_py_test(
+ name = "revnet_test",
+ size = "large",
+ srcs = ["revnet_test.py"],
+ additional_deps = [
+ ":config",
+ ":revnet",
+ "//tensorflow:tensorflow_py",
+ ],
+)
diff --git a/tensorflow/contrib/eager/python/examples/revnet/blocks.py b/tensorflow/contrib/eager/python/examples/revnet/blocks.py
new file mode 100644
index 0000000000..fb4f9f068f
--- /dev/null
+++ b/tensorflow/contrib/eager/python/examples/revnet/blocks.py
@@ -0,0 +1,335 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Reversible residual network compatible with eager execution.
+
+Building blocks with manual backward gradient computation.
+
+Reference [The Reversible Residual Network: Backpropagation
+Without Storing Activations](https://arxiv.org/pdf/1707.04585.pdf)
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+from tensorflow.contrib.eager.python.examples.revnet import ops
+
+
+class RevBlock(tf.keras.Model):
+ """Single reversible block containing several `_Residual` blocks.
+
+ Each `_Residual` block in turn contains two _ResidualInner blocks,
+ corresponding to the `F`/`G` functions in the paper.
+ """
+
+ def __init__(self,
+ n_res,
+ filters,
+ strides,
+ input_shape,
+ batch_norm_first=False,
+ data_format="channels_first",
+ bottleneck=False,
+ fused=True):
+ """Initialize RevBlock.
+
+ Args:
+ n_res: number of residual blocks
+ filters: list/tuple of integers for output filter sizes of each residual
+ strides: length 2 list/tuple of integers for height and width strides
+ input_shape: length 3 list/tuple of integers
+ batch_norm_first: whether to apply activation and batch norm before conv
+ data_format: tensor data format, "NCHW"/"NHWC"
+ bottleneck: use bottleneck residual if True
+ fused: use fused batch normalization if True
+ """
+ super(RevBlock, self).__init__()
+ self.blocks = tf.contrib.checkpoint.List()
+ for i in range(n_res):
+ curr_batch_norm_first = batch_norm_first and i == 0
+ curr_strides = strides if i == 0 else (1, 1)
+ block = _Residual(
+ filters,
+ curr_strides,
+ input_shape,
+ batch_norm_first=curr_batch_norm_first,
+ data_format=data_format,
+ bottleneck=bottleneck,
+ fused=fused)
+ self.blocks.append(block)
+
+ if data_format == "channels_first":
+ input_shape = (filters, input_shape[1] // curr_strides[0],
+ input_shape[2] // curr_strides[1])
+ else:
+ input_shape = (input_shape[0] // curr_strides[0],
+ input_shape[1] // curr_strides[1], filters)
+
+ def call(self, h, training=True):
+ """Apply reversible block to inputs."""
+
+ for block in self.blocks:
+ h = block(h, training=training)
+ return h
+
+ def backward_grads_and_vars(self, x, y, dy, training=True):
+ """Apply reversible block backward to outputs."""
+
+ grads_all = []
+ vars_all = []
+
+ for i in reversed(range(len(self.blocks))):
+ block = self.blocks[i]
+ y_inv = x if i == 0 else block.backward(y, training=training)
+ dy, grads, vars_ = block.backward_grads_and_vars(
+ y_inv, dy, training=training)
+ grads_all += grads
+ vars_all += vars_
+
+ return dy, grads_all, vars_all
+
+
+class _Residual(tf.keras.Model):
+ """Single residual block contained in a _RevBlock. Each `_Residual` object has
+ two _ResidualInner objects, corresponding to the `F` and `G` functions in the
+ paper.
+
+ Args:
+ filters: output filter size
+ strides: length 2 list/tuple of integers for height and width strides
+ input_shape: length 3 list/tuple of integers
+ batch_norm_first: whether to apply activation and batch norm before conv
+ data_format: tensor data format, "NCHW"/"NHWC",
+ bottleneck: use bottleneck residual if True
+ fused: use fused batch normalization if True
+ """
+
+ def __init__(self,
+ filters,
+ strides,
+ input_shape,
+ batch_norm_first=True,
+ data_format="channels_first",
+ bottleneck=False,
+ fused=True):
+ super(_Residual, self).__init__()
+
+ self.filters = filters
+ self.strides = strides
+ self.axis = 1 if data_format == "channels_first" else 3
+ if data_format == "channels_first":
+ f_input_shape = (input_shape[0] // 2,) + input_shape[1:]
+ g_input_shape = (filters // 2, input_shape[1] // strides[0],
+ input_shape[2] // strides[1])
+ else:
+ f_input_shape = input_shape[:2] + (input_shape[2] // 2,)
+ g_input_shape = (input_shape[0] // strides[0],
+ input_shape[1] // strides[1], filters // 2)
+
+ factory = _BottleneckResidualInner if bottleneck else _ResidualInner
+ self.f = factory(
+ filters=filters // 2,
+ strides=strides,
+ input_shape=f_input_shape,
+ batch_norm_first=batch_norm_first,
+ data_format=data_format,
+ fused=fused)
+ self.g = factory(
+ filters=filters // 2,
+ strides=(1, 1),
+ input_shape=g_input_shape,
+ batch_norm_first=batch_norm_first,
+ data_format=data_format,
+ fused=fused)
+
+ def call(self, x, training=True, concat=True):
+ """Apply residual block to inputs."""
+
+ x1, x2 = tf.split(x, num_or_size_splits=2, axis=self.axis)
+ f_x2 = self.f.call(x2, training=training)
+ # TODO(lxuechen): Replace with simpler downsampling
+ x1_down = ops.downsample(
+ x1, self.filters // 2, self.strides, axis=self.axis)
+ x2_down = ops.downsample(
+ x2, self.filters // 2, self.strides, axis=self.axis)
+ y1 = f_x2 + x1_down
+ g_y1 = self.g.call(y1, training=training) # self.g(y1) gives pylint error
+ y2 = g_y1 + x2_down
+ if not concat: # Concat option needed for correct backward grads
+ return y1, y2
+ return tf.concat([y1, y2], axis=self.axis)
+
+ def backward(self, y, training=True):
+ """Reconstruct inputs from outputs; only valid when stride 1."""
+
+ assert self.strides == (1, 1)
+
+ y1, y2 = tf.split(y, num_or_size_splits=2, axis=self.axis)
+ g_y1 = self.g.call(y1, training=training)
+ x2 = y2 - g_y1
+ f_x2 = self.f.call(x2, training=training)
+ x1 = y1 - f_x2
+
+ return tf.concat([x1, x2], axis=self.axis)
+
+ def backward_grads_and_vars(self, x, dy, training=True):
+ """Manually compute backward gradients given input and output grads."""
+
+ with tf.GradientTape(persistent=True) as tape:
+ x_stop = tf.stop_gradient(x)
+ x1, x2 = tf.split(x_stop, num_or_size_splits=2, axis=self.axis)
+ tape.watch([x1, x2])
+ # Stitch back x for `call` so tape records correct grads
+ x = tf.concat([x1, x2], axis=self.axis)
+ dy1, dy2 = tf.split(dy, num_or_size_splits=2, axis=self.axis)
+ y1, y2 = self.call(x, training=training, concat=False)
+ x2_down = ops.downsample(
+ x2, self.filters // 2, self.strides, axis=self.axis)
+
+ grads_combined = tape.gradient(
+ y2, [y1] + self.g.variables, output_gradients=[dy2])
+ dy2_y1, dg = grads_combined[0], grads_combined[1:]
+ dy1_plus = dy2_y1 + dy1
+
+ grads_combined = tape.gradient(
+ y1, [x1, x2] + self.f.variables, output_gradients=[dy1_plus])
+ dx1, dx2, df = grads_combined[0], grads_combined[1], grads_combined[2:]
+ dx2 += tape.gradient(x2_down, [x2], output_gradients=[dy2])[0]
+
+ del tape
+
+ grads = df + dg
+ vars_ = self.f.variables + self.g.variables
+
+ return tf.concat([dx1, dx2], axis=self.axis), grads, vars_
+
+
+def _BottleneckResidualInner(filters,
+ strides,
+ input_shape,
+ batch_norm_first=True,
+ data_format="channels_first",
+ fused=True):
+ """Single bottleneck residual inner function contained in _Resdual.
+
+ Corresponds to the `F`/`G` functions in the paper.
+ Suitable for training on ImageNet dataset.
+
+ Args:
+ filters: output filter size
+ strides: length 2 list/tuple of integers for height and width strides
+ input_shape: length 3 list/tuple of integers
+ batch_norm_first: whether to apply activation and batch norm before conv
+ data_format: tensor data format, "NCHW"/"NHWC"
+ fused: use fused batch normalization if True
+
+ Returns:
+ A keras model
+ """
+
+ axis = 1 if data_format == "channels_first" else 3
+ model = tf.keras.Sequential()
+ if batch_norm_first:
+ model.add(
+ tf.keras.layers.BatchNormalization(
+ axis=axis, input_shape=input_shape, fused=fused))
+ model.add(tf.keras.layers.LeakyReLU(alpha=0.))
+ model.add(
+ tf.keras.layers.Conv2D(
+ filters=filters // 4,
+ kernel_size=1,
+ strides=strides,
+ input_shape=input_shape,
+ data_format=data_format,
+ use_bias=False,
+ padding="SAME"))
+
+ model.add(tf.keras.layers.BatchNormalization(axis=axis, fused=fused))
+ model.add(tf.keras.layers.LeakyReLU(alpha=0.))
+ model.add(
+ tf.keras.layers.Conv2D(
+ filters=filters // 4,
+ kernel_size=3,
+ strides=(1, 1),
+ data_format=data_format,
+ use_bias=False,
+ padding="SAME"))
+
+ model.add(tf.keras.layers.BatchNormalization(axis=axis, fused=fused))
+ model.add(tf.keras.layers.LeakyReLU(alpha=0.))
+ model.add(
+ tf.keras.layers.Conv2D(
+ filters=filters,
+ kernel_size=1,
+ strides=(1, 1),
+ data_format=data_format,
+ use_bias=False,
+ padding="SAME"))
+
+ return model
+
+
+def _ResidualInner(filters,
+ strides,
+ input_shape,
+ batch_norm_first=True,
+ data_format="channels_first",
+ fused=True):
+ """Single residual inner function contained in _ResdualBlock.
+
+ Corresponds to the `F`/`G` functions in the paper.
+
+ Args:
+ filters: output filter size
+ strides: length 2 list/tuple of integers for height and width strides
+ input_shape: length 3 list/tuple of integers
+ batch_norm_first: whether to apply activation and batch norm before conv
+ data_format: tensor data format, "NCHW"/"NHWC"
+ fused: use fused batch normalization if True
+
+ Returns:
+ A keras model
+ """
+
+ axis = 1 if data_format == "channels_first" else 3
+ model = tf.keras.Sequential()
+ if batch_norm_first:
+ model.add(
+ tf.keras.layers.BatchNormalization(
+ axis=axis, input_shape=input_shape, fused=fused))
+ model.add(tf.keras.layers.LeakyReLU(alpha=0.))
+ model.add(
+ tf.keras.layers.Conv2D(
+ filters=filters,
+ kernel_size=3,
+ strides=strides,
+ input_shape=input_shape,
+ data_format=data_format,
+ use_bias=False,
+ padding="SAME"))
+
+ model.add(tf.keras.layers.BatchNormalization(axis=axis, fused=fused))
+ model.add(tf.keras.layers.LeakyReLU(alpha=0.))
+ model.add(
+ tf.keras.layers.Conv2D(
+ filters=filters,
+ kernel_size=3,
+ strides=(1, 1),
+ data_format=data_format,
+ use_bias=False,
+ padding="SAME"))
+
+ return model
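The block above computes y1 = x1 + F(x2) and y2 = x2 + G(y1), so backward() can reconstruct its inputs from its outputs instead of storing activations: first x2 = y2 - G(y1), then x1 = y1 - F(x2). A scalar sketch of that inversion (stride 1 only; f and g are hypothetical stand-ins for the F/G sub-networks, not the Keras models above):

def f(x2):  # plays the role of self.f
  return 3.0 * x2


def g(y1):  # plays the role of self.g
  return y1 + 1.0


x1, x2 = 2.0, 5.0
y1 = x1 + f(x2)          # forward: 2 + 15 = 17
y2 = x2 + g(y1)          # forward: 5 + 18 = 23

x2_rec = y2 - g(y1)      # backward: 23 - 18 = 5
x1_rec = y1 - f(x2_rec)  # backward: 17 - 15 = 2
assert (x1_rec, x2_rec) == (x1, x2)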
diff --git a/tensorflow/contrib/eager/python/examples/revnet/blocks_test.py b/tensorflow/contrib/eager/python/examples/revnet/blocks_test.py
new file mode 100644
index 0000000000..f4436fd925
--- /dev/null
+++ b/tensorflow/contrib/eager/python/examples/revnet/blocks_test.py
@@ -0,0 +1,346 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for basic building blocks used in eager mode RevNet."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+from tensorflow.contrib.eager.python.examples.revnet import blocks
+
+
+def _validate_block_call_channels_last(block_factory, test):
+ """Generic testing function for `channels_last` data format.
+
+  Completes a set of tests varying stride and training vs. test mode for the
+  `channels_last` data format.
+
+  Args:
+ block_factory: constructor of one of blocks.InitBlock, blocks.FinalBlock,
+ blocks._ResidualInner
+ test: tf.test.TestCase object
+ """
+ with tf.device("/cpu:0"): # NHWC format
+ input_shape = (224, 224, 32)
+ data_shape = (16,) + input_shape
+ x = tf.random_normal(shape=data_shape)
+
+ # Stride 1
+ block = block_factory(
+ filters=64,
+ strides=(1, 1),
+ input_shape=input_shape,
+ data_format="channels_last")
+ y_tr, y_ev = block(x, training=True), block(x, training=False)
+ test.assertEqual(y_tr.shape, y_ev.shape)
+ test.assertEqual(y_ev.shape, (16, 224, 224, 64))
+ test.assertNotAllClose(y_tr, y_ev)
+
+ # Stride of 2
+ block = block_factory(
+ filters=64,
+ strides=(2, 2),
+ input_shape=input_shape,
+ data_format="channels_last")
+ y_tr, y_ev = block(x, training=True), block(x, training=False)
+ test.assertEqual(y_tr.shape, y_ev.shape)
+ test.assertEqual(y_ev.shape, (16, 112, 112, 64))
+ test.assertNotAllClose(y_tr, y_ev)
+
+
+def _validate_block_call_channels_first(block_factory, test):
+ """Generic testing function for `channels_first` data format.
+
+  Completes a set of tests varying stride and training vs. test mode for the
+  `channels_first` data format.
+
+  Args:
+ block_factory: constructor of one of blocks.InitBlock, blocks.FinalBlock,
+ blocks._ResidualInner
+ test: tf.test.TestCase object
+ """
+ if not tf.test.is_gpu_available():
+ test.skipTest("GPU not available")
+
+ with tf.device("/gpu:0"): # Default NCHW format
+ input_shape = (32, 224, 224)
+ data_shape = (16,) + input_shape
+ x = tf.random_normal(shape=data_shape)
+
+ # Stride of 1
+ block = block_factory(filters=64, strides=(1, 1), input_shape=input_shape)
+ y_tr, y_ev = block(x, training=True), block(x, training=False)
+ test.assertEqual(y_tr.shape, y_ev.shape)
+ test.assertEqual(y_ev.shape, (16, 64, 224, 224))
+ test.assertNotAllClose(y_tr, y_ev)
+
+ # Stride of 2
+ block = block_factory(filters=64, strides=(2, 2), input_shape=input_shape)
+ y_tr, y_ev = block(x, training=True), block(x, training=False)
+ test.assertEqual(y_tr.shape, y_ev.shape)
+ test.assertEqual(y_ev.shape, (16, 64, 112, 112))
+ test.assertNotAllClose(y_tr, y_ev)
+
+
+class RevBlockTest(tf.test.TestCase):
+
+ def test_call_channels_first(self):
+ """Test `call` function with `channels_first` data format."""
+ if not tf.test.is_gpu_available():
+ self.skipTest("GPU not available")
+
+ with tf.device("/gpu:0"): # Default NCHW format
+ input_shape = (32, 224, 224)
+ data_shape = (16,) + input_shape
+ x = tf.random_normal(shape=data_shape)
+
+ # Stride of 1
+ block = blocks.RevBlock(
+ n_res=3, filters=64, strides=(1, 1), input_shape=input_shape)
+ y_tr, y_ev = block(x, training=True), block(x, training=False)
+ self.assertEqual(y_tr.shape, y_ev.shape)
+ self.assertEqual(y_ev.shape, (16, 64, 224, 224))
+ self.assertNotAllClose(y_tr, y_ev)
+
+ # Stride of 2
+ block = blocks.RevBlock(
+ n_res=3, filters=64, strides=(2, 2), input_shape=input_shape)
+ y_tr, y_ev = block(x, training=True), block(x, training=False)
+ self.assertEqual(y_tr.shape, y_ev.shape)
+      self.assertEqual(y_ev.shape, (16, 64, 112, 112))
+ self.assertNotAllClose(y_tr, y_ev)
+
+ def test_call_channels_last(self):
+ """Test `call` function with `channels_last` data format."""
+ with tf.device("/cpu:0"): # NHWC format
+ input_shape = (224, 224, 32)
+ data_shape = (16,) + input_shape
+ x = tf.random_normal(shape=data_shape)
+
+ # Stride 1
+ block = blocks.RevBlock(
+ n_res=3,
+ filters=64,
+ strides=(1, 1),
+ input_shape=input_shape,
+ data_format="channels_last")
+ y_tr, y_ev = block(x, training=True), block(x, training=False)
+ self.assertEqual(y_tr.shape, y_ev.shape)
+ self.assertEqual(y_ev.shape, (16, 224, 224, 64))
+ self.assertNotAllClose(y_tr, y_ev)
+
+ # Stride of 2
+ block = blocks.RevBlock(
+ n_res=3,
+ filters=64,
+ strides=(2, 2),
+ input_shape=input_shape,
+ data_format="channels_last")
+ y_tr, y_ev = block(x, training=True), block(x, training=False)
+ self.assertEqual(y_tr.shape, y_ev.shape)
+ self.assertEqual(y_ev.shape, (16, 112, 112, 64))
+ self.assertNotAllClose(y_tr, y_ev)
+
+ def test_backward_grads_and_vars_channels_first(self):
+ """Test `backward` function with `channels_first` data format."""
+ if not tf.test.is_gpu_available():
+ self.skipTest("GPU not available")
+
+ with tf.device("/gpu:0"): # Default NCHW format
+ input_shape = (32, 224, 224)
+ data_shape = (16,) + input_shape
+ x = tf.random_normal(shape=data_shape)
+
+ # Stride 1
+ y = tf.random_normal(shape=data_shape)
+ dy = tf.random_normal(shape=data_shape)
+ block = blocks.RevBlock(
+ n_res=3, filters=32, strides=(1, 1), input_shape=input_shape)
+ dy, grads, vars_ = block.backward_grads_and_vars(x, y, dy)
+ self.assertEqual(dy.shape, x.shape)
+ self.assertTrue(isinstance(grads, list))
+ self.assertTrue(isinstance(vars_, list))
+
+ # Stride 2
+ y = tf.random_normal(shape=(16, 32, 112, 112))
+ dy = tf.random_normal(shape=(16, 32, 112, 112))
+ block = blocks.RevBlock(
+ n_res=3, filters=32, strides=(2, 2), input_shape=input_shape)
+ dy, grads, vars_ = block.backward_grads_and_vars(x, y, dy)
+ self.assertEqual(dy.shape, x.shape)
+ self.assertTrue(isinstance(grads, list))
+ self.assertTrue(isinstance(vars_, list))
+
+ def test_backward_grads_and_vars_channels_last(self):
+ """Test `backward` function with `channels_last` data format."""
+ with tf.device("/cpu:0"): # NHWC format
+ input_shape = (224, 224, 32)
+ data_shape = (16,) + input_shape
+ x = tf.random_normal(shape=data_shape)
+
+ # Stride 1
+ y = tf.random_normal(shape=data_shape)
+ dy = tf.random_normal(shape=data_shape)
+ block = blocks.RevBlock(
+ n_res=3,
+ filters=32,
+ strides=(1, 1),
+ input_shape=input_shape,
+ data_format="channels_last")
+ dy, grads, vars_ = block.backward_grads_and_vars(x, y, dy)
+ self.assertEqual(dy.shape, x.shape)
+ self.assertTrue(isinstance(grads, list))
+ self.assertTrue(isinstance(vars_, list))
+
+ # Stride 2
+ y = tf.random_normal(shape=(16, 112, 112, 32))
+ dy = tf.random_normal(shape=(16, 112, 112, 32))
+ block = blocks.RevBlock(
+ n_res=3,
+ filters=32,
+ strides=(2, 2),
+ input_shape=input_shape,
+ data_format="channels_last")
+ dy, grads, vars_ = block.backward_grads_and_vars(x, y, dy)
+ self.assertEqual(dy.shape, x.shape)
+ self.assertTrue(isinstance(grads, list))
+ self.assertTrue(isinstance(vars_, list))
+
+
+class _ResidualTest(tf.test.TestCase):
+
+ def test_call(self):
+ """Test `call` function.
+
+ Varying downsampling and data format options.
+ """
+
+ _validate_block_call_channels_first(blocks._Residual, self)
+ _validate_block_call_channels_last(blocks._Residual, self)
+
+ def test_backward_channels_first(self):
+ """Test `backward` function with `channels_first` data format."""
+ if not tf.test.is_gpu_available():
+ self.skipTest("GPU not available")
+
+ with tf.device("/gpu:0"): # Default NCHW format
+ input_shape = (16, 224, 224)
+ data_shape = (16,) + input_shape
+ x = tf.random_normal(shape=data_shape)
+ residual = blocks._Residual(
+ filters=16, strides=(1, 1), input_shape=input_shape)
+ y_tr, y_ev = residual(x, training=True), residual(x, training=False)
+ x_ = residual.backward(y_tr, training=True)
+      # The numerical error is significant; reconstructed inputs can differ
+      # from the original inputs by more than 1e-3
+ self.assertAllClose(x, x_, rtol=1e-01, atol=1e-01)
+ x_ = residual.backward(y_ev, training=False)
+ self.assertAllClose(x, x_, rtol=1e-01, atol=1e-01)
+
+ def test_backward_channels_last(self):
+ """Test `backward` function with `channels_last` data format."""
+ with tf.device("/cpu:0"): # NHWC format
+ input_shape = (224, 224, 16)
+ data_shape = (16,) + input_shape
+ x = tf.random_normal(shape=data_shape)
+ residual = blocks._Residual(
+ filters=16,
+ strides=(1, 1),
+ input_shape=input_shape,
+ data_format="channels_last")
+ y_tr, y_ev = residual(x, training=True), residual(x, training=False)
+ x_ = residual.backward(y_tr, training=True)
+ # Egregious numerical error
+ self.assertAllClose(x, x_, rtol=1e-01, atol=1e-01)
+ x_ = residual.backward(y_ev, training=False)
+ self.assertAllClose(x, x_, rtol=1e-01, atol=1e-01)
+
+ def test_backward_grads_and_vars_channels_first(self):
+ """Test `backward_grads` function with `channels_first` data format."""
+ if not tf.test.is_gpu_available():
+ self.skipTest("GPU not available")
+
+ with tf.device("/gpu:0"): # Default NCHW format
+ input_shape = (16, 224, 224)
+ data_shape = (16,) + input_shape
+ x = tf.random_normal(shape=data_shape)
+ dy = tf.random_normal(shape=data_shape)
+ residual = blocks._Residual(
+ filters=16, strides=(1, 1), input_shape=input_shape)
+ dx_tr, grads_tr, vars_tr = residual.backward_grads_and_vars(
+ x, dy=dy, training=True)
+ dx_ev, grads_ev, vars_ev = residual.backward_grads_and_vars(
+ x, dy=dy, training=False)
+ self.assertNotAllClose(dx_tr, dx_ev)
+ self.assertTrue(isinstance(grads_tr, list))
+ self.assertTrue(isinstance(grads_ev, list))
+ self.assertTrue(isinstance(vars_tr, list))
+ self.assertTrue(isinstance(vars_ev, list))
+ for grad_tr, var_tr, grad_ev, var_ev in zip(grads_tr, vars_tr, grads_ev,
+ vars_ev):
+ if grad_tr is not None: # Batch norm moving mean, var gives None grad
+ self.assertEqual(grad_tr.shape, grad_ev.shape)
+ self.assertEqual(var_tr.shape, var_ev.shape)
+ self.assertEqual(grad_tr.shape, var_tr.shape)
+
+ def test_backward_grads_and_vars_channels_last(self):
+ """Test `backward_grads` function with `channels_last` data format."""
+ with tf.device("/cpu:0"): # NHWC format
+ input_shape = (224, 224, 16)
+ data_shape = (16,) + input_shape
+ x = tf.random_normal(shape=data_shape)
+ dy = tf.random_normal(shape=data_shape)
+ residual = blocks._Residual(
+ filters=16,
+ strides=(1, 1),
+ input_shape=input_shape,
+ data_format="channels_last")
+ dx_tr, grads_tr, vars_tr = residual.backward_grads_and_vars(
+ x, dy=dy, training=True)
+ dx_ev, grads_ev, vars_ev = residual.backward_grads_and_vars(
+ x, dy=dy, training=False)
+ self.assertNotAllClose(dx_tr, dx_ev)
+ self.assertTrue(isinstance(grads_tr, list))
+ self.assertTrue(isinstance(grads_ev, list))
+ self.assertTrue(isinstance(vars_tr, list))
+ self.assertTrue(isinstance(vars_ev, list))
+ for grad_tr, var_tr, grad_ev, var_ev in zip(grads_tr, vars_tr, grads_ev,
+ vars_ev):
+ if grad_tr is not None: # Batch norm moving mean, var gives None grad
+ self.assertEqual(grad_tr.shape, grad_ev.shape)
+ self.assertEqual(var_tr.shape, var_ev.shape)
+ self.assertEqual(grad_tr.shape, var_tr.shape)
+
+
+class _ResidualInnerTest(tf.test.TestCase):
+
+ def test_call(self):
+ """Test `call` function."""
+
+ _validate_block_call_channels_first(blocks._ResidualInner, self)
+ _validate_block_call_channels_last(blocks._ResidualInner, self)
+
+
+class _BottleneckResidualInnerTest(tf.test.TestCase):
+
+ def test_call(self):
+ """Test `call` function."""
+
+ _validate_block_call_channels_first(blocks._BottleneckResidualInner, self)
+ _validate_block_call_channels_last(blocks._BottleneckResidualInner, self)
+
+
+if __name__ == "__main__":
+ tf.enable_eager_execution()
+ tf.test.main()
diff --git a/tensorflow/contrib/eager/python/examples/revnet/config.py b/tensorflow/contrib/eager/python/examples/revnet/config.py
new file mode 100644
index 0000000000..495a78d550
--- /dev/null
+++ b/tensorflow/contrib/eager/python/examples/revnet/config.py
@@ -0,0 +1,117 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Reversible residual network compatible with eager execution.
+
+Configuration in format of tf.contrib.training.HParams.
+Supports CIFAR-10, CIFAR-100, and ImageNet datasets.
+
+Reference [The Reversible Residual Network: Backpropagation
+Without Storing Activations](https://arxiv.org/pdf/1707.04585.pdf)
+
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+
+def get_hparams_cifar_38():
+ """RevNet-38 configurations for CIFAR-10/CIFAR-100."""
+
+ config = tf.contrib.training.HParams()
+ config.add_hparam("init_filters", 32)
+ config.add_hparam("init_kernel", 3)
+ config.add_hparam("init_stride", 1)
+ config.add_hparam("n_classes", 10)
+ config.add_hparam("n_rev_blocks", 3)
+ config.add_hparam("n_res", [3, 3, 3])
+ config.add_hparam("filters", [32, 64, 112])
+ config.add_hparam("strides", [1, 2, 2])
+ config.add_hparam("batch_size", 10)
+ config.add_hparam("bottleneck", False)
+ config.add_hparam("fused", True)
+ config.add_hparam("init_max_pool", False)
+ if tf.test.is_gpu_available():
+ config.add_hparam("input_shape", (3, 32, 32))
+ config.add_hparam("data_format", "channels_first")
+ else:
+ config.add_hparam("input_shape", (32, 32, 3))
+ config.add_hparam("data_format", "channels_last")
+
+ # Training details
+ config.add_hparam("weight_decay", 2e-4)
+ config.add_hparam("momentum", .9)
+ config.add_hparam("lr_decay_steps", [40000, 60000])
+ config.add_hparam("lr_list", [1e-1, 1e-2, 1e-3])
+ config.add_hparam("max_train_iter", 80000)
+ config.add_hparam("seed", 1234)
+ config.add_hparam("shuffle", True)
+ config.add_hparam("prefetch", True)
+ config.add_hparam("print_every", 50)
+ config.add_hparam("dtype", tf.float32)
+ config.add_hparam("eval_batch_size", 500)
+ config.add_hparam("div255", True)
+ # For tf.data.Dataset
+ config.add_hparam("epochs", config.max_train_iter // config.batch_size)
+
+ return config
+
+
+def get_hparams_imagenet_56():
+ """RevNet-56 configurations for ImageNet."""
+
+ config = tf.contrib.training.HParams()
+ config.add_hparam("init_filters", 128)
+ config.add_hparam("init_kernel", 7)
+ config.add_hparam("init_stride", 2)
+ config.add_hparam("n_classes", 1000)
+ config.add_hparam("n_rev_blocks", 4)
+ config.add_hparam("n_res", [2, 2, 2, 2])
+ config.add_hparam("filters", [128, 256, 512, 832])
+ config.add_hparam("strides", [1, 2, 2, 2])
+ config.add_hparam("batch_size", 16)
+ config.add_hparam("bottleneck", True)
+ config.add_hparam("fused", True)
+ config.add_hparam("init_max_pool", True)
+ if tf.test.is_gpu_available():
+ config.add_hparam("input_shape", (3, 224, 224))
+ config.add_hparam("data_format", "channels_first")
+ else:
+ config.add_hparam("input_shape", (224, 224, 3))
+ config.add_hparam("data_format", "channels_last")
+
+ # Training details
+ config.add_hparam("weight_decay", 1e-4)
+ config.add_hparam("momentum", .9)
+ config.add_hparam("lr_decay_steps", [160000, 320000, 480000])
+ config.add_hparam("lr_list", [1e-1, 1e-2, 1e-3, 1e-4])
+ config.add_hparam("max_train_iter", 600000)
+ config.add_hparam("seed", 1234)
+ config.add_hparam("shuffle", True)
+ config.add_hparam("prefetch", True)
+ config.add_hparam("print_every", 50)
+ config.add_hparam("dtype", tf.float32)
+ config.add_hparam("eval_batch_size", 500)
+ config.add_hparam("div255", True)
+ # For tf.data.Dataset
+ config.add_hparam("epochs", config.max_train_iter // config.batch_size)
+
+ if config.bottleneck:
+ filters = [f * 4 for f in config.filters]
+ config.filters = filters
+
+ return config
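
As a point of reference, here is a minimal sketch of how the HParams objects defined above are meant to be consumed. It assumes eager execution and a random batch purely for illustration; nothing below is part of the change itself.

# Illustrative sketch only: consuming the configs defined above.
import tensorflow as tf

from tensorflow.contrib.eager.python.examples.revnet import config as config_

tf.enable_eager_execution()

config = config_.get_hparams_cifar_38()
# Input pipelines are expected to yield (batch_size,) + input_shape tensors in
# the data format selected when the config was built.
images = tf.random_normal((config.batch_size,) + config.input_shape)
labels = tf.random_uniform(
    [config.batch_size], minval=0, maxval=config.n_classes, dtype=tf.int32)
print(config.data_format, images.shape, labels.shape)
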
diff --git a/tensorflow/contrib/eager/python/examples/revnet/ops.py b/tensorflow/contrib/eager/python/examples/revnet/ops.py
new file mode 100644
index 0000000000..9ed5d363e6
--- /dev/null
+++ b/tensorflow/contrib/eager/python/examples/revnet/ops.py
@@ -0,0 +1,70 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Reversible residual network compatible with eager execution.
+
+Customized basic operations.
+
+Reference [The Reversible Residual Network: Backpropagation
+Without Storing Activations](https://arxiv.org/pdf/1707.04585.pdf)
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+
+def downsample(x, filters, strides, axis=1):
+  """Downsample feature map with average pooling; pad channels if needed."""
+
+ def pad_strides(strides, axis=1):
+ """Convert length 2 to length 4 strides.
+
+ Needed since `tf.layers.Conv2D` uses length 2 strides, whereas operations
+ such as `tf.nn.avg_pool` use length 4 strides.
+
+ Args:
+ strides: length 2 list/tuple strides for height and width
+ axis: integer specifying feature dimension according to data format
+ Returns:
+ length 4 strides padded with 1 on batch and channel dimension
+ """
+
+ assert len(strides) == 2
+
+ if axis == 1:
+ return [1, 1, strides[0], strides[1]]
+ return [1, strides[0], strides[1], 1]
+
+ assert len(x.shape) == 4 and (axis == 1 or axis == 3)
+
+ data_format = "NCHW" if axis == 1 else "NHWC"
+ strides_ = pad_strides(strides, axis=axis)
+
+ if strides[0] > 1:
+ x = tf.nn.avg_pool(
+ x, strides_, strides_, padding="VALID", data_format=data_format)
+
+ in_filter = x.shape[axis]
+ out_filter = filters
+
+ if in_filter < out_filter:
+ pad_size = [(out_filter - in_filter) // 2, (out_filter - in_filter) // 2]
+ if axis == 1:
+ x = tf.pad(x, [[0, 0], pad_size, [0, 0], [0, 0]])
+ else:
+ x = tf.pad(x, [[0, 0], [0, 0], [0, 0], pad_size])
+  # Add 0. so that `tape.gradient` does not return `None` when this is a no-op
+ return x + 0.
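
A hedged usage sketch for `ops.downsample` as defined above; the NHWC shapes mirror those exercised in the accompanying test and are chosen only for illustration, not taken from the change itself.

# Illustrative sketch only: spatial and channel adjustment with ops.downsample.
import tensorflow as tf

from tensorflow.contrib.eager.python.examples.revnet import ops

tf.enable_eager_execution()

x = tf.random_normal([8, 32, 32, 3])  # NHWC feature map

# strides > 1 triggers average pooling, which halves the spatial dimensions.
y = ops.downsample(x, filters=3, strides=(2, 2), axis=3)
assert y.shape == (8, 16, 16, 3)

# A larger filter count is reached by zero-padding the channel axis; the
# increase must be even for the symmetric padding to hit the target exactly.
z = ops.downsample(x, filters=5, strides=(1, 1), axis=3)
assert z.shape == (8, 32, 32, 5)
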
diff --git a/tensorflow/contrib/eager/python/examples/revnet/ops_test.py b/tensorflow/contrib/eager/python/examples/revnet/ops_test.py
new file mode 100644
index 0000000000..5bc2641faf
--- /dev/null
+++ b/tensorflow/contrib/eager/python/examples/revnet/ops_test.py
@@ -0,0 +1,80 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for basic ops used in eager mode RevNet."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+from tensorflow.contrib.eager.python.examples.revnet import ops
+tfe = tf.contrib.eager
+
+
+class OpsTest(tf.test.TestCase):
+
+ def test_downsample(self):
+    """Test `downsample` function."""
+
+ batch_size = 100
+ # NHWC format
+ x = tf.random_normal(shape=[batch_size, 32, 32, 3])
+ # HW doesn't change but number of features increased
+ y = ops.downsample(x, filters=5, strides=(1, 1), axis=3)
+ self.assertEqual(y.shape, [batch_size, 32, 32, 5])
+ # Feature map doesn't change but HW reduced
+ y = ops.downsample(x, filters=3, strides=(2, 2), axis=3)
+ self.assertEqual(y.shape, [batch_size, 16, 16, 3])
+    # Number of features increased and HW reduced
+ y = ops.downsample(x, filters=5, strides=(2, 2), axis=3)
+ self.assertEqual(y.shape, [batch_size, 16, 16, 5])
+
+ # Test gradient flow
+ x = tf.random_normal(shape=[batch_size, 32, 32, 3])
+ with tfe.GradientTape() as tape:
+ tape.watch(x)
+ y = ops.downsample(x, filters=3, strides=(1, 1))
+ self.assertEqual(y.shape, x.shape)
+ dy = tf.random_normal(shape=[batch_size, 3, 32, 32])
+ grad, = tape.gradient(y, [x], output_gradients=[dy])
+ self.assertEqual(grad.shape, x.shape)
+
+ # Default NCHW format
+ if tf.test.is_gpu_available():
+ x = tf.random_normal(shape=[batch_size, 3, 32, 32])
+      # HW doesn't change but number of features increased
+ y = ops.downsample(x, filters=5, strides=(1, 1))
+ self.assertEqual(y.shape, [batch_size, 5, 32, 32])
+ # Feature map doesn't change but HW reduced
+ y = ops.downsample(x, filters=3, strides=(2, 2))
+ self.assertEqual(y.shape, [batch_size, 3, 16, 16])
+      # Number of features increased and HW reduced
+ y = ops.downsample(x, filters=5, strides=(2, 2))
+ self.assertEqual(y.shape, [batch_size, 5, 16, 16])
+
+ # Test gradient flow
+ x = tf.random_normal(shape=[batch_size, 3, 32, 32])
+ with tfe.GradientTape() as tape:
+ tape.watch(x)
+ y = ops.downsample(x, filters=3, strides=(1, 1))
+ self.assertEqual(y.shape, x.shape)
+ dy = tf.random_normal(shape=[batch_size, 3, 32, 32])
+ grad, = tape.gradient(y, [x], output_gradients=[dy])
+ self.assertEqual(grad.shape, x.shape)
+
+
+if __name__ == '__main__':
+ tf.enable_eager_execution()
+ tf.test.main()
diff --git a/tensorflow/contrib/eager/python/examples/revnet/revnet.py b/tensorflow/contrib/eager/python/examples/revnet/revnet.py
new file mode 100644
index 0000000000..aa3f7efe1b
--- /dev/null
+++ b/tensorflow/contrib/eager/python/examples/revnet/revnet.py
@@ -0,0 +1,263 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Reversible residual network compatible with eager execution.
+
+Code for main model.
+
+Reference [The Reversible Residual Network: Backpropagation
+Without Storing Activations](https://arxiv.org/pdf/1707.04585.pdf)
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import functools
+import operator
+
+import tensorflow as tf
+from tensorflow.contrib.eager.python.examples.revnet import blocks
+
+
+# Global Conventions:
+# 1) Default data format is NCHW, targeting GPU
+# 2) Each block has attribute axis, inferred from data_format
+# 3) The training option defaults to True for batch normalization
+class RevNet(tf.keras.Model):
+ """RevNet that depends on all the blocks."""
+
+ def __init__(self, config):
+ """Initialize RevNet with building blocks.
+
+ Args:
+ config: tf.contrib.training.HParams object; specifies hyperparameters
+ """
+ super(RevNet, self).__init__()
+ self.axis = 1 if config.data_format == "channels_first" else 3
+ self.config = config
+
+ self._init_block = self._construct_init_block()
+ self._block_list = self._construct_intermediate_blocks()
+ self._final_block = self._construct_final_block()
+
+ def _construct_init_block(self):
+ init_block = tf.keras.Sequential(
+ [
+ tf.keras.layers.Conv2D(
+ filters=self.config.init_filters,
+ kernel_size=self.config.init_kernel,
+ strides=(self.config.init_stride, self.config.init_stride),
+ data_format=self.config.data_format,
+ use_bias=False,
+ padding="SAME",
+ input_shape=self.config.input_shape),
+ tf.keras.layers.BatchNormalization(
+ axis=self.axis, fused=self.config.fused),
+ tf.keras.layers.LeakyReLU(alpha=0.)
+ ],
+ name="init")
+ if self.config.init_max_pool:
+ init_block.add(
+ tf.keras.layers.MaxPooling2D(
+ pool_size=(3, 3),
+ strides=(2, 2),
+ padding="SAME",
+ data_format=self.config.data_format))
+ return init_block
+
+ def _construct_final_block(self):
+ f = self.config.filters[-1] # Number of filters
+ r = functools.reduce(operator.mul, self.config.strides, 1) # Reduce ratio
+ r *= self.config.init_stride
+ if self.config.init_max_pool:
+ r *= 2
+
+ if self.config.data_format == "channels_first":
+ w, h = self.config.input_shape[1], self.config.input_shape[2]
+ input_shape = (f, w // r, h // r)
+ elif self.config.data_format == "channels_last":
+ w, h = self.config.input_shape[0], self.config.input_shape[1]
+ input_shape = (w // r, h // r, f)
+ else:
+ raise ValueError("Data format should be either `channels_first`"
+ " or `channels_last`")
+
+ final_block = tf.keras.Sequential(
+ [
+ tf.keras.layers.BatchNormalization(
+ axis=self.axis,
+ input_shape=input_shape,
+ fused=self.config.fused),
+ tf.keras.layers.LeakyReLU(alpha=0.), # Vanilla ReLU
+ tf.keras.layers.GlobalAveragePooling2D(
+ data_format=self.config.data_format),
+ tf.keras.layers.Dense(self.config.n_classes)
+ ],
+ name="final")
+ return final_block
+
+ def _construct_intermediate_blocks(self):
+ # Precompute input shape after initial block
+ stride = self.config.init_stride
+ if self.config.init_max_pool:
+ stride *= 2
+ if self.config.data_format == "channels_first":
+ w, h = self.config.input_shape[1], self.config.input_shape[2]
+ input_shape = (self.config.init_filters, w // stride, h // stride)
+ else:
+ w, h = self.config.input_shape[0], self.config.input_shape[1]
+ input_shape = (w // stride, h // stride, self.config.init_filters)
+
+ # Aggregate intermediate blocks
+ block_list = tf.contrib.checkpoint.List()
+ for i in range(self.config.n_rev_blocks):
+ # RevBlock configurations
+ n_res = self.config.n_res[i]
+ filters = self.config.filters[i]
+ if filters % 2 != 0:
+        raise ValueError("Number of output filters must be even to ensure "
+                         "correct partitioning of channels")
+ stride = self.config.strides[i]
+ strides = (self.config.strides[i], self.config.strides[i])
+
+ # Add block
+ rev_block = blocks.RevBlock(
+ n_res,
+ filters,
+ strides,
+ input_shape,
+ batch_norm_first=(i != 0), # Only skip on first block
+ data_format=self.config.data_format,
+ bottleneck=self.config.bottleneck,
+ fused=self.config.fused)
+ block_list.append(rev_block)
+
+ # Precompute input shape for the next block
+ if self.config.data_format == "channels_first":
+ w, h = input_shape[1], input_shape[2]
+ input_shape = (filters, w // stride, h // stride)
+ else:
+ w, h = input_shape[0], input_shape[1]
+ input_shape = (w // stride, h // stride, filters)
+
+ return block_list
+
+ def call(self, inputs, training=True):
+ """Forward pass."""
+
+ # Only store hidden states during training
+ if training:
+ saved_hidden = [inputs]
+
+ h = self._init_block(inputs, training=training)
+ if training:
+ saved_hidden.append(h)
+
+ for block in self._block_list:
+ h = block(h, training=training)
+ if training:
+ saved_hidden.append(h)
+
+ logits = self._final_block(h, training=training)
+
+ return (logits, saved_hidden) if training else (logits, None)
+
+ def compute_loss(self, logits, labels):
+ """Compute cross entropy loss."""
+
+ cross_ent = tf.nn.sparse_softmax_cross_entropy_with_logits(
+ logits=logits, labels=labels)
+
+ return tf.reduce_mean(cross_ent)
+
+ def compute_gradients(self, inputs, labels, training=True):
+ """Manually computes gradients.
+
+ Args:
+ inputs: Image tensor, either NHWC or NCHW, conforming to `data_format`
+      labels: Integer class labels (not one-hot) for classification
+      training: Whether batch normalization should run in training mode
+
+ Returns:
+      Two lists, one of gradients and one of the corresponding variables, in
+      matching order, for use with an optimizer
+ """
+
+    # Forward pass; record hidden states before downsampling
+ _, saved_hidden = self.call(inputs, training=training)
+
+ grads_all = []
+ vars_all = []
+
+ # Manually backprop through last block
+ x = saved_hidden[-1]
+ with tf.GradientTape() as tape:
+ tape.watch(x)
+ logits = self._final_block(x, training=training)
+ cost = self.compute_loss(logits, labels)
+
+ grads_combined = tape.gradient(cost, [x] + self._final_block.variables)
+ dy, grads_ = grads_combined[0], grads_combined[1:]
+ grads_all += grads_
+ vars_all += self._final_block.variables
+
+ # Manually backprop through intermediate blocks
+ for block in reversed(self._block_list):
+ y = saved_hidden.pop()
+ x = saved_hidden[-1]
+ dy, grads, vars_ = block.backward_grads_and_vars(
+ x, y, dy, training=training)
+ grads_all += grads
+ vars_all += vars_
+
+ # Manually backprop through first block
+ saved_hidden.pop()
+ x = saved_hidden.pop()
+ assert not saved_hidden # Cleared after backprop
+
+ with tf.GradientTape() as tape:
+ y = self._init_block(x, training=training) # Recomputing
+
+ grads_all += tape.gradient(
+ y, self._init_block.variables, output_gradients=[dy])
+ vars_all += self._init_block.variables
+
+ return grads_all, vars_all
+
+ def train_step(self,
+ inputs,
+ labels,
+ optimizer,
+ global_step=None,
+ report=False):
+ """Train for one iteration."""
+
+ grads_all, vars_all = self.compute_gradients(inputs, labels, training=True)
+ optimizer.apply_gradients(zip(grads_all, vars_all), global_step=global_step)
+
+ if report:
+ logits, _ = self.call(inputs, training=True)
+ loss = self.compute_loss(logits, labels)
+
+ return loss
+
+ def eval_step(self, inputs, labels):
+ """Evaluate."""
+
+ logits, _ = self.call(inputs, training=False)
+ preds = tf.cast(tf.argmax(logits, axis=1), tf.int32)
+ corrects = tf.cast(tf.equal(preds, labels), tf.float32)
+ accuracy = tf.reduce_mean(corrects)
+
+ return accuracy
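
Tying the pieces together, a minimal, hypothetical eager training loop over the model above. The momentum optimizer and the single random batch are assumptions made only for illustration; a real run would iterate over a tf.data.Dataset.

# Illustrative sketch only: a few training steps of RevNet in eager mode.
import tensorflow as tf

from tensorflow.contrib.eager.python.examples.revnet import config as config_
from tensorflow.contrib.eager.python.examples.revnet import revnet

tf.enable_eager_execution()

config = config_.get_hparams_cifar_38()
model = revnet.RevNet(config=config)
optimizer = tf.train.MomentumOptimizer(config.lr_list[0], config.momentum)

images = tf.random_normal((config.batch_size,) + config.input_shape)
labels = tf.random_uniform(
    [config.batch_size], minval=0, maxval=config.n_classes, dtype=tf.int32)

for step in range(5):
  # train_step recomputes intermediate activations block by block through
  # backward_grads_and_vars instead of storing them during the forward pass.
  loss = model.train_step(images, labels, optimizer, report=True)
  print("step %d: loss %.4f" % (step, loss.numpy()))

print("train-batch accuracy: %.4f" % model.eval_step(images, labels).numpy())
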
diff --git a/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py b/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py
new file mode 100644
index 0000000000..68502ceac2
--- /dev/null
+++ b/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py
@@ -0,0 +1,277 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for basic building blocks used in eager mode RevNet."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gc
+import time
+
+import tensorflow as tf
+from tensorflow.contrib.eager.python.examples.revnet import config as config_
+from tensorflow.contrib.eager.python.examples.revnet import revnet
+from tensorflow.python.client import device_lib
+tfe = tf.contrib.eager
+
+
+class RevnetTest(tf.test.TestCase):
+
+ def setUp(self):
+ super(RevnetTest, self).setUp()
+ config = config_.get_hparams_imagenet_56()
+ shape = (config.batch_size,) + config.input_shape
+ self.model = revnet.RevNet(config=config)
+ self.x = tf.random_normal(shape=shape)
+ self.t = tf.random_uniform(
+ shape=[config.batch_size],
+ minval=0,
+ maxval=config.n_classes,
+ dtype=tf.int32)
+ self.config = config
+
+ def tearDown(self):
+ del self.model
+ del self.x
+ del self.t
+ del self.config
+ super(RevnetTest, self).tearDown()
+
+ def test_call(self):
+ """Test `call` function."""
+
+ y, _ = self.model(self.x, training=False)
+ self.assertEqual(y.shape, [self.config.batch_size, self.config.n_classes])
+
+ def test_compute_gradients(self):
+ """Test `compute_gradients` function."""
+
+ grads, vars_ = self.model.compute_gradients(inputs=self.x, labels=self.t)
+ self.assertTrue(isinstance(grads, list))
+ self.assertTrue(isinstance(vars_, list))
+ self.assertEqual(len(grads), len(vars_))
+ for grad, var in zip(grads, vars_):
+ if grad is not None:
+ self.assertEqual(grad.shape, var.shape)
+
+ def test_train_step(self):
+ """Test `train_step` function."""
+
+ logits, _ = self.model(self.x, training=True)
+ loss = self.model.compute_loss(logits=logits, labels=self.t)
+ optimizer = tf.train.AdamOptimizer(learning_rate=1e-3)
+
+    # Loss should not increase after each optimization step
+ for _ in range(3):
+ loss_ = self.model.train_step(self.x, self.t, optimizer, report=True)
+ self.assertTrue(loss_.numpy() <= loss.numpy())
+ loss = loss_
+
+ def test_call_defun(self):
+ """Test `call` function with tfe.defun apply."""
+
+ y, _ = tfe.defun(self.model.call)(self.x, training=False)
+ self.assertEqual(y.shape, [self.config.batch_size, self.config.n_classes])
+
+ def test_train_step_defun(self):
+ self.model.call = tfe.defun(self.model.call)
+ logits, _ = self.model(self.x, training=True)
+ loss = self.model.compute_loss(logits=logits, labels=self.t)
+ optimizer = tf.train.AdamOptimizer(learning_rate=1e-3)
+
+ for _ in range(3):
+ loss_ = self.model.train_step(self.x, self.t, optimizer, report=True)
+ self.assertTrue(loss_.numpy() <= loss.numpy())
+ loss = loss_
+
+ # Initialize new model, so that other tests are not affected
+ self.model = revnet.RevNet(config=self.config)
+
+
+# Benchmark related
+def device_and_data_format():
+ return ("/gpu:0",
+ "channels_first") if tf.test.is_gpu_available() else ("/cpu:0",
+ "channels_last")
+
+
+def random_batch(batch_size, config):
+ shape = (batch_size,) + config.input_shape
+ images = tf.random_uniform(shape)
+ labels = tf.random_uniform(
+ [batch_size], minval=0, maxval=config.n_classes, dtype=tf.int32)
+
+ return images, labels
+
+
+class MockIterator(object):
+
+ def __init__(self, tensors):
+ self._tensors = [tf.identity(x) for x in tensors]
+
+ def next(self):
+ return self._tensors
+
+
+class RevnetBenchmark(tf.test.Benchmark):
+ """Eager and graph benchmarks for RevNet."""
+
+ def _train_batch_sizes(self):
+ """Shamelessly copied from `resnet50_test.py`.
+
+ Note: This is targeted towards ImageNet. CIFAR-10 should allow more
+ aggressive batch sizes.
+
+ Returns:
+ A tuple of possible batch sizes
+ """
+ for device in device_lib.list_local_devices():
+ if tf.DeviceSpec.from_string(device.name).device_type == "GPU":
+ if "K20" in device.physical_device_desc:
+ return (16,)
+ if "P100" in device.physical_device_desc:
+ return (16, 32, 64)
+ if tf.DeviceSpec.from_string(device.name).device_type == "TPU":
+ return (32,)
+ return (16, 32)
+
+ def _force_device_sync(self):
+ """Shamelessly copied from `resnet50_test.py`."""
+ tf.constant(1.).cpu()
+
+ def _report(self, label, start, num_iters, device, batch_size, data_format):
+ avg_time = (time.time() - start) / num_iters
+ dev = tf.DeviceSpec.from_string(device).device_type.lower()
+ name = "%s_%s_batch_%d_%s" % (label, dev, batch_size, data_format)
+ extras = {"examples_per_sec": batch_size / avg_time}
+ self.report_benchmark(
+ iters=num_iters, wall_time=avg_time, name=name, extras=extras)
+
+ def _benchmark_eager_apply(self,
+ label,
+ device_and_format,
+ defun=False,
+ execution_mode=None,
+ compiled=False):
+ config = config_.get_hparams_imagenet_56()
+ with tfe.execution_mode(execution_mode):
+ device, data_format = device_and_format
+ model = revnet.RevNet(config=config)
+ if defun:
+ model.call = tfe.defun(model.call, compiled=compiled)
+ batch_size = 64
+ num_burn = 5
+ num_iters = 10
+ with tf.device(device):
+ images, _ = random_batch(batch_size, config)
+ for _ in range(num_burn):
+ model(images, training=False)
+ if execution_mode:
+ tfe.async_wait()
+ gc.collect()
+ start = time.time()
+ for _ in range(num_iters):
+ model(images, training=False)
+ if execution_mode:
+ tfe.async_wait()
+ self._report(label, start, num_iters, device, batch_size, data_format)
+
+ def benchmark_eager_apply_sync(self):
+ self._benchmark_eager_apply(
+ "eager_apply_sync", device_and_data_format(), defun=False)
+
+ def benchmark_eager_apply_async(self):
+ self._benchmark_eager_apply(
+ "eager_apply_async",
+ device_and_data_format(),
+ defun=False,
+ execution_mode=tfe.ASYNC)
+
+ def benchmark_eager_call_defun(self):
+ self._benchmark_eager_apply(
+ "eager_apply_with_defun", device_and_data_format(), defun=True)
+
+ def _benchmark_eager_train(self,
+ label,
+ make_iterator,
+ device_and_format,
+ defun=False,
+ execution_mode=None,
+ compiled=False):
+ config = config_.get_hparams_imagenet_56()
+ with tfe.execution_mode(execution_mode):
+ device, data_format = device_and_format
+ for batch_size in self._train_batch_sizes():
+ (images, labels) = random_batch(batch_size, config)
+ model = revnet.RevNet(config=config)
+ optimizer = tf.train.GradientDescentOptimizer(0.1)
+ if defun:
+ model.call = tfe.defun(model.call)
+
+ num_burn = 3
+ num_iters = 10
+ with tf.device(device):
+ iterator = make_iterator((images, labels))
+ for _ in range(num_burn):
+ (images, labels) = iterator.next()
+ model.train_step(images, labels, optimizer)
+ if execution_mode:
+ tfe.async_wait()
+ self._force_device_sync()
+ gc.collect()
+
+ start = time.time()
+ for _ in range(num_iters):
+ (images, labels) = iterator.next()
+ model.train_step(images, labels, optimizer)
+ if execution_mode:
+ tfe.async_wait()
+ self._force_device_sync()
+ self._report(label, start, num_iters, device, batch_size, data_format)
+
+ def benchmark_eager_train_sync(self):
+ self._benchmark_eager_train(
+ "eager_train_sync", MockIterator, device_and_data_format(), defun=False)
+
+ def benchmark_eager_train_async(self):
+ self._benchmark_eager_train(
+ "eager_train_async",
+ MockIterator,
+ device_and_data_format(),
+ defun=False,
+ execution_mode=tfe.ASYNC)
+
+ def benchmark_eager_train_defun(self):
+ self._benchmark_eager_train(
+        "eager_train_with_defun", MockIterator, device_and_data_format(),
+        defun=True)
+
+ def benchmark_eager_train_datasets_with_defun(self):
+
+ def make_iterator(tensors):
+ with tf.device("/device:CPU:0"):
+ ds = tf.data.Dataset.from_tensors(tensors).repeat()
+ return tfe.Iterator(ds)
+
+ self._benchmark_eager_train(
+ "eager_train_dataset_with_defun",
+ make_iterator,
+ device_and_data_format(),
+ defun=True)
+
+
+if __name__ == "__main__":
+ tf.enable_eager_execution()
+ tf.test.main()
diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD
index 1937ffb583..30d297a5fb 100644
--- a/tensorflow/contrib/estimator/BUILD
+++ b/tensorflow/contrib/estimator/BUILD
@@ -117,7 +117,7 @@ py_library(
py_test(
name = "dnn_test",
- size = "small",
+ size = "medium",
srcs = ["python/estimator/dnn_test.py"],
srcs_version = "PY2AND3",
tags = [
diff --git a/tensorflow/contrib/estimator/python/estimator/dnn.py b/tensorflow/contrib/estimator/python/estimator/dnn.py
index 7ff25b95c0..f1c60a912c 100644
--- a/tensorflow/contrib/estimator/python/estimator/dnn.py
+++ b/tensorflow/contrib/estimator/python/estimator/dnn.py
@@ -53,6 +53,13 @@ class DNNEstimator(estimator.Estimator):
l1_regularization_strength=0.001
))
+ # Or estimator with warm-starting from a previous checkpoint.
+ estimator = DNNEstimator(
+ head=tf.contrib.estimator.multi_label_head(n_classes=3),
+ feature_columns=[sparse_feature_a_emb, sparse_feature_b_emb],
+ hidden_units=[1024, 512, 256],
+ warm_start_from="/path/to/checkpoint/dir")
+
# Input builders
def input_fn_train: # returns x, y
pass
@@ -92,7 +99,8 @@ class DNNEstimator(estimator.Estimator):
activation_fn=nn.relu,
dropout=None,
input_layer_partitioner=None,
- config=None):
+ config=None,
+ warm_start_from=None):
"""Initializes a `DNNEstimator` instance.
Args:
@@ -116,6 +124,11 @@ class DNNEstimator(estimator.Estimator):
input_layer_partitioner: Optional. Partitioner for input layer. Defaults
to `min_max_variable_partitioner` with `min_slice_size` 64 << 20.
config: `RunConfig` object to configure the runtime settings.
+ warm_start_from: A string filepath to a checkpoint to warm-start from, or
+ a `WarmStartSettings` object to fully configure warm-starting. If the
+ string filepath is provided instead of a `WarmStartSettings`, then all
+ weights are warm-started, and it is assumed that vocabularies and Tensor
+ names are unchanged.
"""
def _model_fn(features, labels, mode, config):
return dnn_lib._dnn_model_fn( # pylint: disable=protected-access
@@ -131,4 +144,5 @@ class DNNEstimator(estimator.Estimator):
input_layer_partitioner=input_layer_partitioner,
config=config)
super(DNNEstimator, self).__init__(
- model_fn=_model_fn, model_dir=model_dir, config=config)
+ model_fn=_model_fn, model_dir=model_dir, config=config,
+ warm_start_from=warm_start_from)
diff --git a/tensorflow/contrib/estimator/python/estimator/dnn_test.py b/tensorflow/contrib/estimator/python/estimator/dnn_test.py
index 75e3107670..050b0428bf 100644
--- a/tensorflow/contrib/estimator/python/estimator/dnn_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/dnn_test.py
@@ -38,7 +38,7 @@ from tensorflow.python.platform import test
from tensorflow.python.summary.writer import writer_cache
-def _dnn_estimator_fn(weight_column=None, label_dimension=1, *args, **kwargs):
+def _dnn_estimator_fn(weight_column=None, label_dimension=1, *args, **kwargs): # pylint: disable=keyword-arg-before-vararg
"""Returns a DNNEstimator that uses regression_head."""
return dnn.DNNEstimator(
head=head_lib.regression_head(
@@ -48,6 +48,12 @@ def _dnn_estimator_fn(weight_column=None, label_dimension=1, *args, **kwargs):
*args, **kwargs)
+def _dnn_estimator_classifier_fn(n_classes=3, *args, **kwargs): # pylint: disable=keyword-arg-before-vararg
+ """Returns a DNNEstimator that uses multi_class_head."""
+ return dnn.DNNEstimator(head=head_lib.multi_class_head(n_classes=n_classes),
+ *args, **kwargs)
+
+
class DNNEstimatorEvaluateTest(
dnn_testing_utils.BaseDNNRegressorEvaluateTest, test.TestCase):
@@ -75,6 +81,15 @@ class DNNEstimatorTrainTest(
self, _dnn_estimator_fn)
+class DNNEstimatorWarmStartingTest(dnn_testing_utils.BaseDNNWarmStartingTest,
+ test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ dnn_testing_utils.BaseDNNWarmStartingTest.__init__(
+ self, _dnn_estimator_classifier_fn, _dnn_estimator_fn)
+
+
class DNNEstimatorIntegrationTest(test.TestCase):
def setUp(self):
diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py
index 89b5f4c413..45d7b74046 100644
--- a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py
+++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py
@@ -110,7 +110,7 @@ class SequenceInputLayerTest(test.TestCase):
expected_sequence_length, sequence_length.eval(session=sess))
def test_embedding_column_with_non_sequence_categorical(self):
- """Tests that error is raised for non-sequence categorical column."""
+ """Tests that error is raised for non-sequence embedding column."""
vocabulary_size = 3
sparse_input = sparse_tensor.SparseTensorValue(
# example 0, ids [2]
@@ -132,6 +132,107 @@ class SequenceInputLayerTest(test.TestCase):
features={'aaa': sparse_input},
feature_columns=[embedding_column_a])
+ def test_shared_embedding_column(self):
+ vocabulary_size = 3
+ sparse_input_a = sparse_tensor.SparseTensorValue(
+ # example 0, ids [2]
+ # example 1, ids [0, 1]
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(2, 0, 1),
+ dense_shape=(2, 2))
+ sparse_input_b = sparse_tensor.SparseTensorValue(
+ # example 0, ids [1]
+ # example 1, ids [2, 0]
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(1, 2, 0),
+ dense_shape=(2, 2))
+
+ embedding_dimension = 2
+ embedding_values = (
+ (1., 2.), # id 0
+ (3., 4.), # id 1
+ (5., 6.) # id 2
+ )
+
+ def _get_initializer(embedding_dimension, embedding_values):
+
+ def _initializer(shape, dtype, partition_info):
+ self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
+ self.assertEqual(dtypes.float32, dtype)
+ self.assertIsNone(partition_info)
+ return embedding_values
+
+ return _initializer
+
+ expected_input_layer = [
+ # example 0, ids_a [2], ids_b [1]
+ [[5., 6., 3., 4.], [0., 0., 0., 0.]],
+ # example 1, ids_a [0, 1], ids_b [2, 0]
+ [[1., 2., 5., 6.], [3., 4., 1., 2.]],
+ ]
+ expected_sequence_length = [1, 2]
+
+ categorical_column_a = sfc.sequence_categorical_column_with_identity(
+ key='aaa', num_buckets=vocabulary_size)
+ categorical_column_b = sfc.sequence_categorical_column_with_identity(
+ key='bbb', num_buckets=vocabulary_size)
+ # Test that columns are reordered alphabetically.
+ shared_embedding_columns = fc.shared_embedding_columns(
+ [categorical_column_b, categorical_column_a],
+ dimension=embedding_dimension,
+ initializer=_get_initializer(embedding_dimension, embedding_values))
+
+ input_layer, sequence_length = sfc.sequence_input_layer(
+ features={
+ 'aaa': sparse_input_a,
+ 'bbb': sparse_input_b,
+ },
+ feature_columns=shared_embedding_columns)
+
+ global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
+ self.assertItemsEqual(
+ ('sequence_input_layer/aaa_bbb_shared_embedding/embedding_weights:0',),
+ tuple([v.name for v in global_vars]))
+ with monitored_session.MonitoredSession() as sess:
+ self.assertAllEqual(embedding_values, global_vars[0].eval(session=sess))
+ self.assertAllEqual(expected_input_layer, input_layer.eval(session=sess))
+ self.assertAllEqual(
+ expected_sequence_length, sequence_length.eval(session=sess))
+
+ def test_shared_embedding_column_with_non_sequence_categorical(self):
+ """Tests that error is raised for non-sequence shared embedding column."""
+ vocabulary_size = 3
+ sparse_input_a = sparse_tensor.SparseTensorValue(
+ # example 0, ids [2]
+ # example 1, ids [0, 1]
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(2, 0, 1),
+ dense_shape=(2, 2))
+ sparse_input_b = sparse_tensor.SparseTensorValue(
+ # example 0, ids [2]
+ # example 1, ids [0, 1]
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(2, 0, 1),
+ dense_shape=(2, 2))
+
+ categorical_column_a = fc.categorical_column_with_identity(
+ key='aaa', num_buckets=vocabulary_size)
+ categorical_column_b = fc.categorical_column_with_identity(
+ key='bbb', num_buckets=vocabulary_size)
+ shared_embedding_columns = fc.shared_embedding_columns(
+ [categorical_column_a, categorical_column_b], dimension=2)
+
+ with self.assertRaisesRegexp(
+ ValueError,
+ r'In embedding_column: aaa_shared_embedding\. categorical_column must '
+ r'be of type _SequenceCategoricalColumn to use sequence_input_layer\.'):
+ _, _ = sfc.sequence_input_layer(
+ features={
+ 'aaa': sparse_input_a,
+ 'bbb': sparse_input_b
+ },
+ feature_columns=shared_embedding_columns)
+
def test_indicator_column(self):
vocabulary_size_a = 3
sparse_input_a = sparse_tensor.SparseTensorValue(
@@ -578,6 +679,182 @@ class SequenceEmbeddingColumnTest(test.TestCase):
expected_sequence_length, sequence_length.eval(session=sess))
+class SequenceSharedEmbeddingColumnTest(test.TestCase):
+
+ def test_get_sequence_dense_tensor(self):
+ vocabulary_size = 3
+ embedding_dimension = 2
+ embedding_values = (
+ (1., 2.), # id 0
+ (3., 5.), # id 1
+ (7., 11.) # id 2
+ )
+
+ def _initializer(shape, dtype, partition_info):
+ self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
+ self.assertEqual(dtypes.float32, dtype)
+ self.assertIsNone(partition_info)
+ return embedding_values
+
+ sparse_input_a = sparse_tensor.SparseTensorValue(
+ # example 0, ids [2]
+ # example 1, ids [0, 1]
+ # example 2, ids []
+ # example 3, ids [1]
+ indices=((0, 0), (1, 0), (1, 1), (3, 0)),
+ values=(2, 0, 1, 1),
+ dense_shape=(4, 2))
+ sparse_input_b = sparse_tensor.SparseTensorValue(
+ # example 0, ids [1]
+ # example 1, ids [0, 2]
+ # example 2, ids [0]
+ # example 3, ids []
+ indices=((0, 0), (1, 0), (1, 1), (2, 0)),
+ values=(1, 0, 2, 0),
+ dense_shape=(4, 2))
+
+ expected_lookups_a = [
+ # example 0, ids [2]
+ [[7., 11.], [0., 0.]],
+ # example 1, ids [0, 1]
+ [[1., 2.], [3., 5.]],
+ # example 2, ids []
+ [[0., 0.], [0., 0.]],
+ # example 3, ids [1]
+ [[3., 5.], [0., 0.]],
+ ]
+
+ expected_lookups_b = [
+ # example 0, ids [1]
+ [[3., 5.], [0., 0.]],
+ # example 1, ids [0, 2]
+ [[1., 2.], [7., 11.]],
+ # example 2, ids [0]
+ [[1., 2.], [0., 0.]],
+ # example 3, ids []
+ [[0., 0.], [0., 0.]],
+ ]
+
+ categorical_column_a = sfc.sequence_categorical_column_with_identity(
+ key='aaa', num_buckets=vocabulary_size)
+ categorical_column_b = sfc.sequence_categorical_column_with_identity(
+ key='bbb', num_buckets=vocabulary_size)
+ shared_embedding_columns = fc.shared_embedding_columns(
+ [categorical_column_a, categorical_column_b],
+ dimension=embedding_dimension,
+ initializer=_initializer)
+
+ embedding_lookup_a = shared_embedding_columns[0]._get_sequence_dense_tensor(
+ _LazyBuilder({
+ 'aaa': sparse_input_a
+ }))[0]
+ embedding_lookup_b = shared_embedding_columns[1]._get_sequence_dense_tensor(
+ _LazyBuilder({
+ 'bbb': sparse_input_b
+ }))[0]
+
+ global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
+ self.assertItemsEqual(('embedding_weights:0',),
+ tuple([v.name for v in global_vars]))
+ with monitored_session.MonitoredSession() as sess:
+ self.assertAllEqual(embedding_values, global_vars[0].eval(session=sess))
+ self.assertAllEqual(
+ expected_lookups_a, embedding_lookup_a.eval(session=sess))
+ self.assertAllEqual(
+ expected_lookups_b, embedding_lookup_b.eval(session=sess))
+
+ def test_sequence_length(self):
+ vocabulary_size = 3
+
+ sparse_input_a = sparse_tensor.SparseTensorValue(
+ # example 0, ids [2]
+ # example 1, ids [0, 1]
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(2, 0, 1),
+ dense_shape=(2, 2))
+ expected_sequence_length_a = [1, 2]
+ categorical_column_a = sfc.sequence_categorical_column_with_identity(
+ key='aaa', num_buckets=vocabulary_size)
+
+ sparse_input_b = sparse_tensor.SparseTensorValue(
+ # example 0, ids [0, 2]
+ # example 1, ids [1]
+ indices=((0, 0), (0, 1), (1, 0)),
+ values=(0, 2, 1),
+ dense_shape=(2, 2))
+ expected_sequence_length_b = [2, 1]
+ categorical_column_b = sfc.sequence_categorical_column_with_identity(
+ key='bbb', num_buckets=vocabulary_size)
+ shared_embedding_columns = fc.shared_embedding_columns(
+ [categorical_column_a, categorical_column_b], dimension=2)
+
+ sequence_length_a = shared_embedding_columns[0]._get_sequence_dense_tensor(
+ _LazyBuilder({
+ 'aaa': sparse_input_a
+ }))[1]
+ sequence_length_b = shared_embedding_columns[1]._get_sequence_dense_tensor(
+ _LazyBuilder({
+ 'bbb': sparse_input_b
+ }))[1]
+
+ with monitored_session.MonitoredSession() as sess:
+ sequence_length_a = sess.run(sequence_length_a)
+ self.assertAllEqual(expected_sequence_length_a, sequence_length_a)
+ self.assertEqual(np.int64, sequence_length_a.dtype)
+ sequence_length_b = sess.run(sequence_length_b)
+ self.assertAllEqual(expected_sequence_length_b, sequence_length_b)
+ self.assertEqual(np.int64, sequence_length_b.dtype)
+
+ def test_sequence_length_with_empty_rows(self):
+ """Tests _sequence_length when some examples do not have ids."""
+ vocabulary_size = 3
+ sparse_input_a = sparse_tensor.SparseTensorValue(
+ # example 0, ids []
+ # example 1, ids [2]
+ # example 2, ids [0, 1]
+ # example 3, ids []
+ # example 4, ids [1]
+ # example 5, ids []
+ indices=((1, 0), (2, 0), (2, 1), (4, 0)),
+ values=(2, 0, 1, 1),
+ dense_shape=(6, 2))
+ expected_sequence_length_a = [0, 1, 2, 0, 1, 0]
+ categorical_column_a = sfc.sequence_categorical_column_with_identity(
+ key='aaa', num_buckets=vocabulary_size)
+
+ sparse_input_b = sparse_tensor.SparseTensorValue(
+ # example 0, ids [2]
+ # example 1, ids []
+ # example 2, ids []
+ # example 3, ids []
+ # example 4, ids [1]
+ # example 5, ids [0, 1]
+ indices=((0, 0), (4, 0), (5, 0), (5, 1)),
+ values=(2, 1, 0, 1),
+ dense_shape=(6, 2))
+ expected_sequence_length_b = [1, 0, 0, 0, 1, 2]
+ categorical_column_b = sfc.sequence_categorical_column_with_identity(
+ key='bbb', num_buckets=vocabulary_size)
+
+ shared_embedding_columns = fc.shared_embedding_columns(
+ [categorical_column_a, categorical_column_b], dimension=2)
+
+ sequence_length_a = shared_embedding_columns[0]._get_sequence_dense_tensor(
+ _LazyBuilder({
+ 'aaa': sparse_input_a
+ }))[1]
+ sequence_length_b = shared_embedding_columns[1]._get_sequence_dense_tensor(
+ _LazyBuilder({
+ 'bbb': sparse_input_b
+ }))[1]
+
+ with monitored_session.MonitoredSession() as sess:
+ self.assertAllEqual(
+ expected_sequence_length_a, sequence_length_a.eval(session=sess))
+ self.assertAllEqual(
+ expected_sequence_length_b, sequence_length_b.eval(session=sess))
+
+
class SequenceIndicatorColumnTest(test.TestCase):
def test_get_sequence_dense_tensor(self):
diff --git a/tensorflow/contrib/framework/python/ops/variables.py b/tensorflow/contrib/framework/python/ops/variables.py
index 40ae01bfcc..e8e3180019 100644
--- a/tensorflow/contrib/framework/python/ops/variables.py
+++ b/tensorflow/contrib/framework/python/ops/variables.py
@@ -712,7 +712,8 @@ class VariableDeviceChooser(object):
num_tasks=0,
job_name='ps',
device_type='CPU',
- device_index=0):
+ device_index=0,
+ replica=None):
"""Initialize VariableDeviceChooser.
Usage:
@@ -733,12 +734,15 @@ class VariableDeviceChooser(object):
self._job_name = job_name
self._device_type = device_type
self._device_index = device_index
+ self._replica = replica
self._num_tasks = num_tasks
self._next_task_id = 0
def __call__(self, op):
- device_spec = tf_device.DeviceSpec(device_type=self._device_type,
- device_index=self._device_index)
+ device_spec = tf_device.DeviceSpec(
+ replica=self._replica,
+ device_type=self._device_type,
+ device_index=self._device_index)
if self._num_tasks > 0:
task_id = self._next_task_id
self._next_task_id = (self._next_task_id + 1) % self._num_tasks
diff --git a/tensorflow/contrib/framework/python/ops/variables_test.py b/tensorflow/contrib/framework/python/ops/variables_test.py
index 37ea6eb12a..7e0c7dbec1 100644
--- a/tensorflow/contrib/framework/python/ops/variables_test.py
+++ b/tensorflow/contrib/framework/python/ops/variables_test.py
@@ -506,6 +506,35 @@ class VariablesTest(test.TestCase):
self.assertDeviceEqual(e.device, '/job:ps/task:1/cpu:0')
self.assertDeviceEqual(e.initial_value.device, '/cpu:99')
+ def testVariableWithVariableDeviceChooserWithReplica(self):
+
+ with ops.Graph().as_default():
+ device_fn = variables_lib2.VariableDeviceChooser(replica=3, num_tasks=2)
+ with arg_scope([variables_lib2.variable], device=device_fn):
+ a = variables_lib2.variable('a', [])
+ b = variables_lib2.variable('b', [])
+ c = variables_lib2.variable('c', [], device='cpu:12')
+ d = variables_lib2.variable('d', [])
+ with ops.device('cpu:99'):
+ e_init = constant_op.constant(12)
+ e = variables_lib2.variable('e', initializer=e_init)
+ # The values below highlight how the VariableDeviceChooser puts initial
+ # values on the same device as the variable job.
+ self.assertDeviceEqual(a.device, '/job:ps/replica:3/task:0/cpu:0')
+ self.assertEqual(a.initial_value.op.colocation_groups(),
+ a.op.colocation_groups())
+ self.assertDeviceEqual(b.device, '/job:ps/replica:3/task:1/cpu:0')
+ self.assertEqual(b.initial_value.op.colocation_groups(),
+ b.op.colocation_groups())
+ self.assertDeviceEqual(c.device, '/cpu:12')
+ self.assertEqual(c.initial_value.op.colocation_groups(),
+ c.op.colocation_groups())
+ self.assertDeviceEqual(d.device, '/job:ps/replica:3/task:0/cpu:0')
+ self.assertEqual(d.initial_value.op.colocation_groups(),
+ d.op.colocation_groups())
+ self.assertDeviceEqual(e.device, '/job:ps/replica:3/task:1/cpu:0')
+ self.assertDeviceEqual(e.initial_value.device, '/cpu:99')
+
def testVariableGPUPlacement(self):
with ops.Graph().as_default():
@@ -930,8 +959,8 @@ class AssignFromCheckpointTest(test.TestCase):
return saver.save(sess, checkpoint_dir, global_step=global_step)
def testLoadExistingVariables(self):
- model_dir = tempfile.mkdtemp(prefix=os.path.join(self.get_temp_dir(),
- 'load_existing_variables'))
+ model_dir = tempfile.mkdtemp(
+ prefix=os.path.join(self.get_temp_dir(), 'load_existing_variables'))
init_value0 = 10.0
init_value1 = 20.0
@@ -944,8 +973,8 @@ class AssignFromCheckpointTest(test.TestCase):
var1 = variables_lib2.variable('my_var1', shape=[])
vars_to_restore = {'v0': var0, 'v1': var1}
- op, feed_dict = variables_lib2.assign_from_checkpoint(model_path,
- vars_to_restore)
+ op, feed_dict = variables_lib2.assign_from_checkpoint(
+ model_path, vars_to_restore)
# Initialize the variables.
sess.run(variables_lib.global_variables_initializer())
@@ -960,8 +989,8 @@ class AssignFromCheckpointTest(test.TestCase):
# Tests restoring PartitionedVariables and tests using a dictionary
# of lists as the assign_from_checkpoint() var_list param.
def testLoadPartitionedVariables(self):
- model_dir = tempfile.mkdtemp(prefix=os.path.join(
- self.get_temp_dir(), 'load_partitioned_variables'))
+ model_dir = tempfile.mkdtemp(
+ prefix=os.path.join(self.get_temp_dir(), 'load_partitioned_variables'))
init_value0 = np.array([[10.0, 11.0], [12.0, 13.0]])
init_value1 = np.array([20.0]) # Partitioned into 1 part, edge case.
@@ -974,15 +1003,14 @@ class AssignFromCheckpointTest(test.TestCase):
partitioner = partitioned_variables.variable_axis_size_partitioner(2)
var0 = variables_lib2.variable(
'var0', shape=init_value0.shape, partitioner=partitioner)
- var0full = variables_lib2.variable(
- 'var0full', shape=init_value0.shape)
+ var0full = variables_lib2.variable('var0full', shape=init_value0.shape)
var1 = variables_lib2.variable(
'var1', shape=init_value1.shape, partitioner=partitioner)
# Convert var0 and var1 into a list of underlying variables.
vars_to_restore = {'var0': list(var0) + [var0full], 'var1': list(var1)}
- op, feed_dict = variables_lib2.assign_from_checkpoint(model_path,
- vars_to_restore)
+ op, feed_dict = variables_lib2.assign_from_checkpoint(
+ model_path, vars_to_restore)
# Initialize the variables.
sess.run(variables_lib.global_variables_initializer())
@@ -992,16 +1020,18 @@ class AssignFromCheckpointTest(test.TestCase):
# Request and test the variable values. PartitionedVariables can't
# be evaled so we wrap them in an identity.
- self.assertTrue(np.array_equal(
- init_value0, array_ops.identity(var0).eval()))
- self.assertTrue(np.array_equal(
- init_value0, var0full.eval()))
- self.assertTrue(np.array_equal(
- init_value1, array_ops.identity(var1).eval()))
+ self.assertTrue(
+ np.array_equal(init_value0,
+ array_ops.identity(var0).eval()))
+ self.assertTrue(np.array_equal(init_value0, var0full.eval()))
+ self.assertTrue(
+ np.array_equal(init_value1,
+ array_ops.identity(var1).eval()))
def testRaisesValueErrorIfAVariableIsntFound(self):
- model_dir = tempfile.mkdtemp(prefix=os.path.join(
- self.get_temp_dir(), 'raises_value_error_if_var_isnt_found'))
+ model_dir = tempfile.mkdtemp(
+ prefix=os.path.join(self.get_temp_dir(),
+ 'raises_value_error_if_var_isnt_found'))
init_value0 = 10.0
init_value1 = 20.0
@@ -1019,8 +1049,9 @@ class AssignFromCheckpointTest(test.TestCase):
variables_lib2.assign_from_checkpoint(model_path, vars_to_restore)
def testInitFromCheckpointWithScopes(self):
- model_dir = tempfile.mkdtemp(prefix=os.path.join(
- self.get_temp_dir(), 'init_from_checkpoint_with_scopes'))
+ model_dir = tempfile.mkdtemp(
+ prefix=os.path.join(self.get_temp_dir(),
+ 'init_from_checkpoint_with_scopes'))
init_value0 = np.asarray(
[1.0, 3.0, 9.0], dtype=np.float32).reshape((1, 3, 1))
@@ -1038,8 +1069,8 @@ class AssignFromCheckpointTest(test.TestCase):
var1 = variables_lib2.variable('my_var1', shape=init_value1.shape)
vars_to_restore = {'layer0/v0': var0, 'layer1/v1': var1}
- op, feed_dict = variables_lib2.assign_from_checkpoint(model_path,
- vars_to_restore)
+ op, feed_dict = variables_lib2.assign_from_checkpoint(
+ model_path, vars_to_restore)
# Initialize the variables.
sess.run(variables_lib.global_variables_initializer())
@@ -1081,8 +1112,8 @@ class AssignFromCheckpointFnTest(test.TestCase):
return saver.save(sess, checkpoint_dir, global_step=global_step)
def testLoadExistingVariables(self):
- model_dir = tempfile.mkdtemp(prefix=os.path.join(self.get_temp_dir(),
- 'load_existing_variables'))
+ model_dir = tempfile.mkdtemp(
+ prefix=os.path.join(self.get_temp_dir(), 'load_existing_variables'))
if gfile.Exists(model_dir):
gfile.DeleteRecursively(model_dir)
@@ -1097,8 +1128,8 @@ class AssignFromCheckpointFnTest(test.TestCase):
var1 = variables_lib2.variable('my_var1', shape=[])
vars_to_restore = {'v0': var0, 'v1': var1}
- init_fn = variables_lib2.assign_from_checkpoint_fn(model_path,
- vars_to_restore)
+ init_fn = variables_lib2.assign_from_checkpoint_fn(
+ model_path, vars_to_restore)
# Initialize the variables.
sess.run(variables_lib.global_variables_initializer())
@@ -1111,8 +1142,9 @@ class AssignFromCheckpointFnTest(test.TestCase):
self.assertEqual(init_value1, var1.eval())
def testLoadExistingVariablesDifferentShapeDefaultDoesNotAllowReshape(self):
- model_dir = tempfile.mkdtemp(prefix=os.path.join(
- self.get_temp_dir(), 'load_existing_vars_no_reshape'))
+ model_dir = tempfile.mkdtemp(
+ prefix=os.path.join(self.get_temp_dir(),
+ 'load_existing_vars_no_reshape'))
if gfile.Exists(model_dir):
gfile.DeleteRecursively(model_dir)
@@ -1127,8 +1159,8 @@ class AssignFromCheckpointFnTest(test.TestCase):
var1 = variables_lib2.variable('my_var1', shape=[])
vars_to_restore = {'v0': var0, 'v1': var1}
- init_fn = variables_lib2.assign_from_checkpoint_fn(model_path,
- vars_to_restore)
+ init_fn = variables_lib2.assign_from_checkpoint_fn(
+ model_path, vars_to_restore)
# Initialize the variables.
sess.run(variables_lib.global_variables_initializer())
@@ -1138,9 +1170,10 @@ class AssignFromCheckpointFnTest(test.TestCase):
init_fn(sess)
def testLoadExistingVariablesDifferentShapeAllowReshape(self):
- model_dir = tempfile.mkdtemp(prefix=os.path.join(
- self.get_temp_dir(),
- 'load_existing_variables_different_shape_allow_reshape'))
+ model_dir = tempfile.mkdtemp(
+ prefix=os.path.join(
+ self.get_temp_dir(),
+ 'load_existing_variables_different_shape_allow_reshape'))
if gfile.Exists(model_dir):
gfile.DeleteRecursively(model_dir)
@@ -1169,8 +1202,8 @@ class AssignFromCheckpointFnTest(test.TestCase):
self.assertEqual(init_value1, var1.eval())
def testNotFoundError(self):
- model_dir = tempfile.mkdtemp(prefix=os.path.join(self.get_temp_dir(),
- 'not_found_error'))
+ model_dir = tempfile.mkdtemp(
+ prefix=os.path.join(self.get_temp_dir(), 'not_found_error'))
if gfile.Exists(model_dir):
gfile.DeleteRecursively(model_dir)
@@ -1186,8 +1219,8 @@ class AssignFromCheckpointFnTest(test.TestCase):
var2 = variables_lib2.variable('my_var2', shape=[])
vars_to_restore = {'v0': var0, 'v1': var1, 'v2': var2}
- init_fn = variables_lib2.assign_from_checkpoint_fn(model_path,
- vars_to_restore)
+ init_fn = variables_lib2.assign_from_checkpoint_fn(
+ model_path, vars_to_restore)
# Initialize the variables.
sess.run(variables_lib.global_variables_initializer())
@@ -1197,8 +1230,8 @@ class AssignFromCheckpointFnTest(test.TestCase):
init_fn(sess)
def testMissingVariablesList(self):
- model_dir = tempfile.mkdtemp(prefix=os.path.join(self.get_temp_dir(),
- 'missing_variables_list'))
+ model_dir = tempfile.mkdtemp(
+ prefix=os.path.join(self.get_temp_dir(), 'missing_variables_list'))
if gfile.Exists(model_dir):
gfile.DeleteRecursively(model_dir)
@@ -1228,8 +1261,8 @@ class AssignFromCheckpointFnTest(test.TestCase):
self.assertEqual(init_value1, var1.eval())
def testMissingVariablesDict(self):
- model_dir = tempfile.mkdtemp(prefix=os.path.join(self.get_temp_dir(),
- 'missing_variables_dict'))
+ model_dir = tempfile.mkdtemp(
+ prefix=os.path.join(self.get_temp_dir(), 'missing_variables_dict'))
if gfile.Exists(model_dir):
gfile.DeleteRecursively(model_dir)
@@ -1279,9 +1312,8 @@ class ZeroInitializerOpTest(test.TestCase):
def testZeroInitializer(self):
for dtype in (dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64):
for use_init in (False, True):
- self._testZeroInitializer(
- [10, 20], array_ops.ones(
- [10, 20], dtype=dtype), use_init)
+ self._testZeroInitializer([10, 20], array_ops.ones(
+ [10, 20], dtype=dtype), use_init)
class ZeroVarInitializerOpTest(test.TestCase):
diff --git a/tensorflow/contrib/labeled_tensor/python/ops/ops.py b/tensorflow/contrib/labeled_tensor/python/ops/ops.py
index 3ba1026383..2ede5daee7 100644
--- a/tensorflow/contrib/labeled_tensor/python/ops/ops.py
+++ b/tensorflow/contrib/labeled_tensor/python/ops/ops.py
@@ -652,7 +652,8 @@ def map_fn(fn, labeled_tensor, name=None):
tensor_lt = core.LabeledTensor(tensor, original_axes)
return fn(tensor_lt).tensor
- map_op = functional_ops.map_fn(tf_fn, labeled_tensor.tensor)
+ map_op = functional_ops.map_fn(
+ tf_fn, labeled_tensor.tensor, dtype=first_map_lt.dtype)
map_lt = core.LabeledTensor(map_op, final_axes)
return core.identity(map_lt, name=scope)
diff --git a/tensorflow/contrib/layers/python/layers/feature_column_ops.py b/tensorflow/contrib/layers/python/layers/feature_column_ops.py
index 06060b99e7..a85cff4f70 100644
--- a/tensorflow/contrib/layers/python/layers/feature_column_ops.py
+++ b/tensorflow/contrib/layers/python/layers/feature_column_ops.py
@@ -683,11 +683,12 @@ def parse_feature_columns_from_sequence_examples(
the serialized proto.
Returns:
- A tuple consisting of:
- context_features: a dict mapping `FeatureColumns` from
- `context_feature_columns` to their parsed `Tensors`/`SparseTensor`s.
- sequence_features: a dict mapping `FeatureColumns` from
- `sequence_feature_columns` to their parsed `Tensors`/`SparseTensor`s.
+ A tuple consisting of (context_features, sequence_features):
+
+ * context_features: a dict mapping `FeatureColumns` from
+ `context_feature_columns` to their parsed `Tensors`/`SparseTensor`s.
+ * sequence_features: a dict mapping `FeatureColumns` from
+ `sequence_feature_columns` to their parsed `Tensors`/`SparseTensor`s.
"""
# Sequence example parsing requires a single (scalar) example.
try:
diff --git a/tensorflow/contrib/lite/Makefile b/tensorflow/contrib/lite/Makefile
index cc8a8035d1..2b6997146e 100644
--- a/tensorflow/contrib/lite/Makefile
+++ b/tensorflow/contrib/lite/Makefile
@@ -70,6 +70,12 @@ LIB_PATH := $(LIBDIR)$(LIB_NAME)
# A small example program that shows how to link against the library.
MINIMAL_PATH := $(BINDIR)minimal
+# Benchmark static library and binary
+BENCHMARK_LIB_NAME := benchmark-lib.a
+BENCHMARK_BINARY_NAME := benchmark_model
+BENCHMARK_LIB := $(LIBDIR)$(BENCHMARK_LIB_NAME)
+BENCHMARK_BINARY := $(BINDIR)$(BENCHMARK_BINARY_NAME)
+
MINIMAL_SRCS := \
tensorflow/contrib/lite/examples/minimal/minimal.cc
MINIMAL_OBJS := $(addprefix $(OBJDIR), \
@@ -78,12 +84,19 @@ $(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(MINIMAL_SRCS))))
# What sources we want to compile, must be kept in sync with the main Bazel
# build files.
+PROFILER_SRCS := \
+ tensorflow/contrib/lite/profiling/time.cc
+PROFILE_SUMMARIZER_SRCS := \
+ tensorflow/contrib/lite/profiling/profile_summarizer.cc \
+ tensorflow/core/util/stats_calculator.cc
+
CORE_CC_ALL_SRCS := \
$(wildcard tensorflow/contrib/lite/*.cc) \
$(wildcard tensorflow/contrib/lite/kernels/*.cc) \
$(wildcard tensorflow/contrib/lite/kernels/internal/*.cc) \
$(wildcard tensorflow/contrib/lite/kernels/internal/optimized/*.cc) \
$(wildcard tensorflow/contrib/lite/kernels/internal/reference/*.cc) \
+$(PROFILER_SRCS) \
$(wildcard tensorflow/contrib/lite/*.c) \
$(wildcard tensorflow/contrib/lite/kernels/*.c) \
$(wildcard tensorflow/contrib/lite/kernels/internal/*.c) \
@@ -107,18 +120,31 @@ TF_LITE_CC_OBJS := $(addprefix $(OBJDIR), \
$(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(TF_LITE_CC_SRCS))))
LIB_OBJS := $(TF_LITE_CC_OBJS)
+
+# Benchmark sources
+BENCHMARK_SRCS_DIR := tensorflow/contrib/lite/tools/benchmark
+BENCHMARK_ALL_SRCS := $(TFLITE_CC_SRCS) \
+ $(wildcard $(BENCHMARK_SRCS_DIR)/*.cc) \
+ $(PROFILE_SUMMARIZER_SRCS)
+
+BENCHMARK_SRCS := $(filter-out \
+ $(wildcard $(BENCHMARK_SRCS_DIR)/*_test.cc), \
+ $(BENCHMARK_ALL_SRCS))
+
+BENCHMARK_OBJS := $(addprefix $(OBJDIR), \
+$(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(BENCHMARK_SRCS))))
+
# For normal manually-created TensorFlow C++ source files.
$(OBJDIR)%.o: %.cc
@mkdir -p $(dir $@)
$(CXX) $(CXXFLAGS) $(INCLUDES) -c $< -o $@
-
# For normal manually-created TensorFlow C++ source files.
$(OBJDIR)%.o: %.c
@mkdir -p $(dir $@)
$(CC) $(CCFLAGS) $(INCLUDES) -c $< -o $@
# The target that's compiled if there's no command-line arguments.
-all: $(LIB_PATH) $(MINIMAL_PATH)
+all: $(LIB_PATH) $(MINIMAL_PATH) $(BENCHMARK_BINARY)
# Gathers together all the objects we've compiled into a single '.a' archive.
$(LIB_PATH): $(LIB_OBJS)
@@ -131,6 +157,21 @@ $(MINIMAL_PATH): $(MINIMAL_OBJS) $(LIB_PATH)
-o $(MINIMAL_PATH) $(MINIMAL_OBJS) \
$(LIBFLAGS) $(LIB_PATH) $(LDFLAGS) $(LIBS)
+
+$(BENCHMARK_LIB) : $(LIB_PATH) $(BENCHMARK_OBJS)
+ @mkdir -p $(dir $@)
+ $(AR) $(ARFLAGS) $(BENCHMARK_LIB) $(LIB_OBJS) $(BENCHMARK_OBJS)
+
+benchmark_lib: $(BENCHMARK_LIB)
+$(info $(BENCHMARK_BINARY))
+$(BENCHMARK_BINARY) : $(BENCHMARK_LIB)
+ @mkdir -p $(dir $@)
+ $(CXX) $(CXXFLAGS) $(INCLUDES) \
+ -o $(BENCHMARK_BINARY) \
+ $(LIBFLAGS) $(BENCHMARK_LIB) $(LDFLAGS) $(LIBS)
+
+benchmark: $(BENCHMARK_BINARY)
+
# Gets rid of all generated files.
clean:
rm -rf $(MAKEFILE_DIR)/gen
diff --git a/tensorflow/contrib/lite/arena_planner.cc b/tensorflow/contrib/lite/arena_planner.cc
index 4f836d3677..22be64d6ff 100644
--- a/tensorflow/contrib/lite/arena_planner.cc
+++ b/tensorflow/contrib/lite/arena_planner.cc
@@ -31,7 +31,7 @@ struct AllocationInfo {
// The tensor index to be allocated or deallocated.
int tensor;
// Whether to allocate or deallocate
- enum { ALLOC, DEALLOC } type;
+ enum Type { ALLOC, DEALLOC } type;
};
ArenaPlanner::ArenaPlanner(TfLiteContext* context,
@@ -67,6 +67,33 @@ TfLiteStatus ArenaPlanner::PlanAllocations() {
// Keeps track of references to each tensor.
std::vector<int> refcounts(graph_info_->num_tensors(), 0);
+ // `allocated` and `deallocated` are technically lists of boolean values.
+ // We use `vector<int>` to reduce the compiled binary size.
+ std::vector<int> allocated(graph_info_->num_tensors(), false);
+ std::vector<int> deallocated(graph_info_->num_tensors(), false);
+
+ auto allocate = [this, &allocated, &deallocated](int node,
+ int tensor) -> TfLiteStatus {
+ if (allocated[tensor]) {
+ return kTfLiteOk;
+ }
+ TF_LITE_ENSURE(context_, !deallocated[tensor]);
+ alloc_queue_.push_back({node, tensor, AllocationInfo::ALLOC});
+ allocated[tensor] = true;
+ return kTfLiteOk;
+ };
+
+ auto deallocate = [this, &allocated, &deallocated](
+ int node, int tensor) -> TfLiteStatus {
+ if (!allocated[tensor]) {
+ // Do not enqueue a DEALLOC if the tensor was never allocated.
+ // This happens with constant tensors.
+ return kTfLiteOk;
+ }
+ TF_LITE_ENSURE(context_, !deallocated[tensor]);
+ alloc_queue_.push_back({node, tensor, AllocationInfo::DEALLOC});
+ return kTfLiteOk;
+ };
// There will be an entry in alloc_queue_ for the allocation of each tensor
// and another for their deallocation.
@@ -79,6 +106,28 @@ TfLiteStatus ArenaPlanner::PlanAllocations() {
refcounts[tensor_index]++;
}
+ // Variable tensors are also never overwritten and need to stay alive for
+ // the lifetime of the graph.
+ for (int tensor_index : graph_info_->variables()) {
+ refcounts[tensor_index]++;
+ }
+
+ // Queue all graph inputs for allocation.
+ for (int tensor_index : graph_info_->inputs()) {
+ if (tensor_index != kOptionalTensor) {
+ TF_LITE_ENSURE_STATUS(allocate(0, tensor_index));
+ }
+ }
+
+ // Queue all graph variable tensors for allocation.
+ for (int tensor_index : graph_info_->variables()) {
+ if (tensor_index != kOptionalTensor) {
+ // The reference count of every variable tensor was incremented above, so
+ // variable tensors are never deallocated.
+ TF_LITE_ENSURE_STATUS(allocate(0, tensor_index));
+ }
+ }
+
// Count references to node input tensors.
for (int i = 0; i < graph_info_->num_nodes(); ++i) {
const TfLiteNode& node = graph_info_->node(i);
@@ -94,10 +143,9 @@ TfLiteStatus ArenaPlanner::PlanAllocations() {
// Queue all graph inputs for allocation.
for (int tensor_index : graph_info_->inputs()) {
if (tensor_index != kOptionalTensor) {
- alloc_queue_.push_back({0, tensor_index, AllocationInfo::ALLOC});
+ TF_LITE_ENSURE_STATUS(allocate(0, tensor_index));
}
}
-
// Go through the graph in execution order.
for (int i = 0; i < graph_info_->num_nodes(); ++i) {
const TfLiteNode& node = graph_info_->node(i);
@@ -106,7 +154,7 @@ TfLiteStatus ArenaPlanner::PlanAllocations() {
TfLiteIntArray* node_outputs = node.outputs;
for (int j = 0; j < node_outputs->size; ++j) {
int tensor_index = node_outputs->data[j];
- alloc_queue_.push_back({i, tensor_index, AllocationInfo::ALLOC});
+ TF_LITE_ENSURE_STATUS(allocate(i, tensor_index));
}
// Then update the ref-counts of the node's inputs, and if necessary queue
@@ -117,7 +165,7 @@ TfLiteStatus ArenaPlanner::PlanAllocations() {
if (tensor_index != kOptionalTensor) {
refcounts[tensor_index]--;
if (refcounts[tensor_index] == 0) {
- alloc_queue_.push_back({i, tensor_index, AllocationInfo::DEALLOC});
+ TF_LITE_ENSURE_STATUS(deallocate(i, tensor_index));
}
}
}
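For context, here is a minimal standalone sketch (not the TFLite code itself) of the allocation-planning idea used above: each tensor is queued for allocation at most once, and a DEALLOC is queued only when the reference count of an already-allocated tensor drops to zero. The names `Event` and `PlanSimple` are illustrative only.

#include <utility>
#include <vector>

// Illustrative event type, mirroring the ALLOC/DEALLOC queue idea.
struct Event {
  int node;
  int tensor;
  enum Type { ALLOC, DEALLOC } type;
};

// nodes[i] = {input tensor indices, output tensor indices}.
std::vector<Event> PlanSimple(
    const std::vector<std::pair<std::vector<int>, std::vector<int>>>& nodes,
    int num_tensors, const std::vector<int>& graph_outputs) {
  std::vector<int> refcounts(num_tensors, 0);
  std::vector<int> allocated(num_tensors, false);
  std::vector<Event> plan;

  // Graph outputs must stay alive until the end.
  for (int t : graph_outputs) refcounts[t]++;
  // Count every use of a tensor as a node input.
  for (const auto& node : nodes)
    for (int t : node.first) refcounts[t]++;

  for (int i = 0; i < static_cast<int>(nodes.size()); ++i) {
    // Allocate outputs of this node, at most once per tensor.
    for (int t : nodes[i].second) {
      if (!allocated[t]) {
        plan.push_back({i, t, Event::ALLOC});
        allocated[t] = true;
      }
    }
    // Release inputs whose last use is this node; never-allocated tensors
    // (e.g. constants in the real planner) are skipped.
    for (int t : nodes[i].first) {
      if (--refcounts[t] == 0 && allocated[t]) {
        plan.push_back({i, t, Event::DEALLOC});
      }
    }
  }
  return plan;
}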
diff --git a/tensorflow/contrib/lite/arena_planner_test.cc b/tensorflow/contrib/lite/arena_planner_test.cc
index 16171df10a..f0fd35216f 100644
--- a/tensorflow/contrib/lite/arena_planner_test.cc
+++ b/tensorflow/contrib/lite/arena_planner_test.cc
@@ -100,12 +100,18 @@ class TestGraph {
std::vector<TfLiteTensor>* tensors() { return &tensors_; }
const std::vector<int>& inputs() { return inputs_; }
const std::vector<int>& outputs() { return outputs_; }
+ const std::vector<int>& variables() { return variables_; }
+
+ void SetVariables(const std::vector<int>& variables) {
+ variables_ = variables;
+ }
private:
std::vector<TfLiteNode> nodes_;
std::vector<TfLiteTensor> tensors_;
std::vector<int> inputs_;
std::vector<int> outputs_;
+ std::vector<int> variables_;
};
// The GraphInfo for a TestGraph.
@@ -123,6 +129,9 @@ class TestGraphInfo : public GraphInfo {
}
const std::vector<int>& inputs() const override { return graph_->inputs(); }
const std::vector<int>& outputs() const override { return graph_->outputs(); }
+ const std::vector<int>& variables() const override {
+ return graph_->variables();
+ }
private:
TestGraph* graph_;
@@ -306,13 +315,15 @@ TEST_F(ArenaPlannerTest, SimpleGraphWithPersistentTensor) {
{
/* in, out, tmp */
{{0, 1}, {2}, {}}, // First op
- {{2, 0}, {4}, {5}}, // Second op, with temporary
+ {{2, 0}, {4}, {5}}, // Second op, with persistent
{{4, -1}, {3}, {}} // Third op, with optional
},
{3});
// Make #1 persistent so it goes into its own arena.
(*graph.tensors())[1].allocation_type = kTfLiteArenaRwPersistent;
+ // Currently the only use case for kTfLiteArenaRwPersistent is variable tensors.
+ graph.SetVariables({1});
SetGraph(&graph);
Execute(0, 10);
diff --git a/tensorflow/contrib/lite/build_def.bzl b/tensorflow/contrib/lite/build_def.bzl
index 974e6c5d98..612813caee 100644
--- a/tensorflow/contrib/lite/build_def.bzl
+++ b/tensorflow/contrib/lite/build_def.bzl
@@ -221,8 +221,7 @@ def generated_test_models():
"local_response_norm",
"log_softmax",
"log",
- # TODO(b/110143200): Enable after resolving issues with LSTM conversion.
- # "lstm",
+ "lstm",
"max_pool",
"maximum",
"mean",
diff --git a/tensorflow/contrib/lite/context.c b/tensorflow/contrib/lite/context.c
index 5c6f5e72a4..7f2aa316f4 100644
--- a/tensorflow/contrib/lite/context.c
+++ b/tensorflow/contrib/lite/context.c
@@ -76,7 +76,7 @@ void TfLiteTensorFree(TfLiteTensor* t) {
void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims,
TfLiteQuantizationParams quantization, char* buffer,
size_t size, TfLiteAllocationType allocation_type,
- const void* allocation, TfLiteTensor* tensor) {
+ const void* allocation, bool is_variable, TfLiteTensor* tensor) {
TfLiteTensorFree(tensor);
tensor->type = type;
tensor->name = name;
@@ -86,6 +86,7 @@ void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims,
tensor->bytes = size;
tensor->allocation_type = allocation_type;
tensor->allocation = allocation;
+ tensor->is_variable = is_variable;
}
void TfLiteTensorRealloc(size_t num_bytes, TfLiteTensor* tensor) {
diff --git a/tensorflow/contrib/lite/context.h b/tensorflow/contrib/lite/context.h
index 4eb66cc225..15a37de9dc 100644
--- a/tensorflow/contrib/lite/context.h
+++ b/tensorflow/contrib/lite/context.h
@@ -138,6 +138,7 @@ typedef enum {
kTfLiteInt64 = 4,
kTfLiteString = 5,
kTfLiteBool = 6,
+ kTfLiteInt16 = 7,
} TfLiteType;
// Parameters for asymmetric quantization. Quantized values can be converted
@@ -148,7 +149,7 @@ typedef struct {
int32_t zero_point;
} TfLiteQuantizationParams;
-// A union of points that points to memory for a given tensor.
+// A union of pointers that points to memory for a given tensor.
typedef union {
int* i32;
int64_t* i64;
@@ -157,6 +158,7 @@ typedef union {
const char* raw_const;
uint8_t* uint8;
bool* b;
+ int16_t* i16;
} TfLitePtrUnion;
// Memory allocation strategies. kTfLiteMmapRo is for read-only memory-mapped
@@ -223,6 +225,9 @@ typedef struct {
// delegate buffer.
// WARNING: This is an experimental interface that is subject to change.
bool data_is_stale;
+
+ // True if the tensor is a variable.
+ bool is_variable;
} TfLiteTensor;
// Free data memory of tensor `t`;
@@ -235,7 +240,8 @@ void TfLiteTensorFree(TfLiteTensor* t);
void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims,
TfLiteQuantizationParams quantization, char* buffer,
size_t size, TfLiteAllocationType allocation_type,
- const void* allocation, TfLiteTensor* tensor);
+ const void* allocation, bool is_variable,
+ TfLiteTensor* tensor);
// Resize the allocated data of a (dynamic) tensor.
void TfLiteTensorRealloc(size_t num_bytes, TfLiteTensor* tensor);
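A minimal sketch of a call site updated for the new `TfLiteTensorReset` signature. It assumes the usual helpers from context.h such as `TfLiteIntArrayCreate`; the function name `ResetAsVariable` and the tensor/buffer names are hypothetical.

#include "tensorflow/contrib/lite/context.h"

// Reset `tensor` as a float variable tensor backed by `buffer`.
void ResetAsVariable(TfLiteTensor* tensor, char* buffer, size_t bytes) {
  TfLiteIntArray* dims = TfLiteIntArrayCreate(1);
  dims->data[0] = static_cast<int>(bytes / sizeof(float));
  TfLiteQuantizationParams quant = {/*scale=*/0.0f, /*zero_point=*/0};
  // The new trailing bool marks the tensor as a variable tensor.
  TfLiteTensorReset(kTfLiteFloat32, "my_var", dims, quant, buffer, bytes,
                    kTfLiteArenaRwPersistent, /*allocation=*/nullptr,
                    /*is_variable=*/true, tensor);
}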
diff --git a/tensorflow/contrib/lite/graph_info.h b/tensorflow/contrib/lite/graph_info.h
index 313af5fb75..77268d7aeb 100644
--- a/tensorflow/contrib/lite/graph_info.h
+++ b/tensorflow/contrib/lite/graph_info.h
@@ -46,6 +46,9 @@ class GraphInfo {
// Returns the indices of the output tensors.
virtual const std::vector<int>& outputs() const = 0;
+
+ // Returns the indices of the variable tensors.
+ virtual const std::vector<int>& variables() const = 0;
};
// Represents a subgraph of a TensorFlow Lite graph.
diff --git a/tensorflow/contrib/lite/graph_info_test.cc b/tensorflow/contrib/lite/graph_info_test.cc
index ea38b43993..89a8f36b41 100644
--- a/tensorflow/contrib/lite/graph_info_test.cc
+++ b/tensorflow/contrib/lite/graph_info_test.cc
@@ -45,6 +45,7 @@ class SimpleTestGraph : public GraphInfo {
TfLiteTensor* tensor(size_t index) override { return &tensors_[index]; }
const std::vector<int>& inputs() const override { return inputs_; }
const std::vector<int>& outputs() const override { return outputs_; }
+ const std::vector<int>& variables() const override { return variables_; }
void AddNode(const std::vector<int>& inputs,
const std::vector<int>& outputs) {
@@ -67,6 +68,7 @@ class SimpleTestGraph : public GraphInfo {
std::vector<TfLiteTensor> tensors_;
std::vector<int> inputs_;
std::vector<int> outputs_;
+ std::vector<int> variables_;
};
// Partition a graph to generate a list of subgraphs. This wraps the API call
diff --git a/tensorflow/contrib/lite/interpreter.cc b/tensorflow/contrib/lite/interpreter.cc
index ebb0aedc20..3287f9c4fd 100644
--- a/tensorflow/contrib/lite/interpreter.cc
+++ b/tensorflow/contrib/lite/interpreter.cc
@@ -82,6 +82,9 @@ class InterpreterInfo : public GraphInfo {
const std::vector<int>& outputs() const override {
return interpreter_->outputs();
}
+ const std::vector<int>& variables() const override {
+ return interpreter_->variables();
+ }
public:
Interpreter* interpreter_;
@@ -302,6 +305,13 @@ TfLiteStatus Interpreter::SetOutputs(std::vector<int> outputs) {
return kTfLiteOk;
}
+TfLiteStatus Interpreter::SetVariables(std::vector<int> variables) {
+ TF_LITE_ENSURE_OK(&context_, CheckTensorIndices("variables", variables.data(),
+ variables.size()));
+ variables_ = std::move(variables);
+ return kTfLiteOk;
+}
+
TfLiteStatus Interpreter::CheckTensorIndices(const char* label,
const int* indices, int length) {
// Making sure kOptionalTensor is not re-defined to something other than -1.
@@ -334,6 +344,9 @@ TfLiteStatus Interpreter::BytesRequired(TfLiteType type, const int* dims,
case kTfLiteFloat32:
*bytes = sizeof(float) * count;
break;
+ case kTfLiteInt16:
+ *bytes = sizeof(int16_t) * count;
+ break;
case kTfLiteInt32:
*bytes = sizeof(int32_t) * count;
break;
@@ -347,9 +360,9 @@ TfLiteStatus Interpreter::BytesRequired(TfLiteType type, const int* dims,
*bytes = sizeof(bool) * count;
break;
default:
- ReportError(
- &context_,
- "Only float32, int32, int64, uint8, bool supported currently.");
+ ReportError(&context_,
+ "Only float32, int16, int32, int64, uint8, bool supported "
+ "currently.");
return kTfLiteError;
}
return kTfLiteOk;
@@ -367,6 +380,7 @@ TfLiteStatus Interpreter::AllocateTensors() {
}
TF_LITE_ENSURE_STATUS(PrepareOpsAndTensors());
+
if (state_ == kStateUninvokable) {
state_ = kStateInvokable;
}
@@ -375,6 +389,25 @@ TfLiteStatus Interpreter::AllocateTensors() {
return kTfLiteOk;
}
+// TODO(ycling): Consider providing other functions to initialize variable
+// tensors to non-zero values.
+TfLiteStatus Interpreter::ResetVariableTensorsToZero() {
+ for (auto& tensor : tensors_) {
+ if (!tensor.is_variable) {
+ continue;
+ }
+
+ // Variable tensors have to be `kTfLiteArenaRwPersistent`, and must already
+ // be allocated after the initial `PrepareOpsAndTensors()` call.
+ TF_LITE_ENSURE_EQ(&context_, tensor.allocation_type,
+ kTfLiteArenaRwPersistent);
+ TF_LITE_ENSURE(&context_, tensor.data.raw != nullptr);
+
+ memset(tensor.data.raw, 0, tensor.bytes);
+ }
+ return kTfLiteOk;
+}
+
TfLiteStatus Interpreter::AddNodeWithParameters(
const std::vector<int>& inputs, const std::vector<int>& outputs,
const char* init_data, size_t init_data_size, void* builtin_data,
@@ -687,7 +720,7 @@ TfLiteStatus Interpreter::SetTensorParametersReadOnly(
state_ = kStateUninvokable;
TfLiteTensorReset(type, name, ConvertArrayToTfLiteIntArray(rank, dims),
quantization, const_cast<char*>(buffer), bytes,
- kTfLiteMmapRo, allocation, &tensor);
+ kTfLiteMmapRo, allocation, false, &tensor);
}
return kTfLiteOk;
}
@@ -698,7 +731,7 @@ TfLiteStatus Interpreter::SetTensorParametersReadOnly(
// to Interpreter.
TfLiteStatus Interpreter::SetTensorParametersReadWrite(
int tensor_index, TfLiteType type, const char* name, const size_t rank,
- const int* dims, TfLiteQuantizationParams quantization) {
+ const int* dims, TfLiteQuantizationParams quantization, bool is_variable) {
if (state_ == kStateInvokableAndImmutable) {
ReportError(
&context_,
@@ -716,11 +749,23 @@ TfLiteStatus Interpreter::SetTensorParametersReadWrite(
TF_LITE_ENSURE_OK(&context_,
BytesRequired(type, dims, rank, &required_bytes));
}
+
+ TfLiteAllocationType allocation_type = kTfLiteArenaRw;
+ if (type == kTfLiteString) {
+ if (is_variable) {
+ // We don't have a real use case for string variable tensors.
+ ReportError(&context_, "String variable tensor isn't supported.");
+ return kTfLiteError;
+ }
+ allocation_type = kTfLiteDynamic;
+ } else if (is_variable) {
+ allocation_type = kTfLiteArenaRwPersistent;
+ }
+
TfLiteTensorReset(type, name, ConvertArrayToTfLiteIntArray(rank, dims),
quantization,
- /*buffer=*/nullptr, required_bytes,
- type == kTfLiteString ? kTfLiteDynamic : kTfLiteArenaRw,
- nullptr, &context_.tensors[tensor_index]);
+ /*buffer=*/nullptr, required_bytes, allocation_type,
+ nullptr, is_variable, &context_.tensors[tensor_index]);
return kTfLiteOk;
}
@@ -736,7 +781,8 @@ TfLiteStatus Interpreter::ResizeTensorImpl(TfLiteTensor* tensor,
TfLiteIntArray* new_size) {
// Note that in theory we could resize kTfLiteArenaRwPersistent tensors too.
if (tensor->allocation_type == kTfLiteArenaRw ||
- tensor->allocation_type == kTfLiteDynamic) {
+ tensor->allocation_type == kTfLiteDynamic ||
+ tensor->allocation_type == kTfLiteArenaRwPersistent) {
if (tensor->type != kTfLiteString) {
size_t bytesRequired;
TfLiteStatus status = BytesRequired(tensor->type, new_size->data,
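A hedged usage sketch of the new variable-tensor path in the Interpreter API. The tensor indices, names, and shapes are made up for illustration; a real model would also register ops, and error handling is omitted.

#include "tensorflow/contrib/lite/interpreter.h"

// Minimal sketch: one input tensor (0) and one variable "state" tensor (1).
void SetUpVariableState(tflite::Interpreter* interpreter) {
  interpreter->AddTensors(2);
  interpreter->SetInputs({0});
  interpreter->SetOutputs({0});
  interpreter->SetVariables({1});

  TfLiteQuantizationParams quant = {/*scale=*/0.0f, /*zero_point=*/0};
  interpreter->SetTensorParametersReadWrite(0, kTfLiteFloat32, "input", {1, 4},
                                            quant);
  // is_variable=true gives the tensor kTfLiteArenaRwPersistent storage.
  interpreter->SetTensorParametersReadWrite(1, kTfLiteFloat32, "state", {1, 4},
                                            quant, /*is_variable=*/true);

  interpreter->AllocateTensors();
  // Variable tensors keep their contents across Invoke() calls; clear them
  // explicitly when starting a new sequence.
  interpreter->ResetVariableTensorsToZero();
}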
diff --git a/tensorflow/contrib/lite/interpreter.h b/tensorflow/contrib/lite/interpreter.h
index 7315d83606..37961cd1dc 100644
--- a/tensorflow/contrib/lite/interpreter.h
+++ b/tensorflow/contrib/lite/interpreter.h
@@ -118,6 +118,11 @@ class Interpreter {
// interpreter.
TfLiteStatus SetOutputs(std::vector<int> outputs);
+ // Provide a list of tensor indexes that are variable tensors.
+ // Each index is bounds checked and this modifies the consistent_ flag of the
+ // interpreter.
+ TfLiteStatus SetVariables(std::vector<int> variables);
+
// Adds a node with the given parameters and returns the index of the new
// node in `node_index` (optionally). Interpreter will take ownership of
// `builtin_data` and destroy it with `free`. Ownership of 'init_data'
@@ -160,13 +165,15 @@ class Interpreter {
// to Interpreter.
inline TfLiteStatus SetTensorParametersReadWrite(
int tensor_index, TfLiteType type, const char* name,
- const std::vector<int>& dims, TfLiteQuantizationParams quantization) {
+ const std::vector<int>& dims, TfLiteQuantizationParams quantization,
+ bool is_variable = false) {
return SetTensorParametersReadWrite(tensor_index, type, name, dims.size(),
- dims.data(), quantization);
+ dims.data(), quantization, is_variable);
}
TfLiteStatus SetTensorParametersReadWrite(
int tensor_index, TfLiteType type, const char* name, const size_t rank,
- const int* dims, TfLiteQuantizationParams quantization);
+ const int* dims, TfLiteQuantizationParams quantization,
+ bool is_variable = false);
// Functions to access tensor data
@@ -182,6 +189,9 @@ class Interpreter {
// Read only access to list of outputs.
const std::vector<int>& outputs() const { return outputs_; }
+ // Read only access to list of variable tensors.
+ const std::vector<int>& variables() const { return variables_; }
+
// Return the name of a given output. The given index must be between 0 and
// outputs().size().
const char* GetOutputName(int index) const {
@@ -379,6 +389,10 @@ class Interpreter {
allow_buffer_handle_output_ = allow_buffer_handle_output;
}
+ // Reset all variable tensors to zero.
+ // WARNING: This is an experimental API and subject to change.
+ TfLiteStatus ResetVariableTensorsToZero();
+
private:
// Give 'op_reg' a chance to initialize itself using the contents of
// 'buffer'.
@@ -541,6 +555,9 @@ class Interpreter {
// interpreter.
std::vector<int> outputs_;
+ // Array of indices representing the tensors that are variable tensors.
+ std::vector<int> variables_;
+
// The error reporter delegate that tflite will forward queries errors to.
ErrorReporter* error_reporter_;
diff --git a/tensorflow/contrib/lite/interpreter_test.cc b/tensorflow/contrib/lite/interpreter_test.cc
index 4c78466480..b977cb089c 100644
--- a/tensorflow/contrib/lite/interpreter_test.cc
+++ b/tensorflow/contrib/lite/interpreter_test.cc
@@ -106,10 +106,9 @@ TEST(BasicInterpreter, CheckAllocate) {
TfLiteType type;
size_t size;
} cases[] = {
- {kTfLiteFloat32, sizeof(float)},
- {kTfLiteInt32, sizeof(int32_t)},
- {kTfLiteUInt8, sizeof(uint8_t)},
- {kTfLiteInt64, sizeof(int64_t)},
+ {kTfLiteFloat32, sizeof(float)}, {kTfLiteInt32, sizeof(int32_t)},
+ {kTfLiteUInt8, sizeof(uint8_t)}, {kTfLiteInt64, sizeof(int64_t)},
+ {kTfLiteInt16, sizeof(int16_t)},
};
for (auto test : cases) {
@@ -134,6 +133,7 @@ TEST(BasicInterpreter, CheckResize) {
const int32_t int32s[] = {-3, -4};
const uint8_t uint8s[] = {3, 4};
const int64_t int64s[] = {6, -7};
+ const int16_t int16s[] = {8, -9};
struct {
TfLiteType type;
@@ -144,6 +144,7 @@ TEST(BasicInterpreter, CheckResize) {
{kTfLiteInt32, sizeof(int32_t), reinterpret_cast<const char*>(int32s)},
{kTfLiteUInt8, sizeof(uint8_t), reinterpret_cast<const char*>(uint8s)},
{kTfLiteInt64, sizeof(int64_t), reinterpret_cast<const char*>(int64s)},
+ {kTfLiteInt16, sizeof(int16_t), reinterpret_cast<const char*>(int16s)},
};
for (auto test : cases) {
@@ -179,10 +180,8 @@ TEST(BasicInterpreter, CheckAlignment) {
struct {
TfLiteType type;
} cases[] = {
- {kTfLiteFloat32},
- {kTfLiteInt32},
- {kTfLiteUInt8},
- {kTfLiteInt64},
+ {kTfLiteFloat32}, {kTfLiteInt32}, {kTfLiteUInt8},
+ {kTfLiteInt64}, {kTfLiteInt16},
};
for (auto test : cases) {
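A small sketch of what the new kTfLiteInt16 support enables at the kernel level, relying on the `data.i16` union member added above and the `GetTensorData<int16_t>` specialization added later in this patch. The helper name `FillInt16` is illustrative only.

#include "tensorflow/contrib/lite/context.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"

// Fill an int16 tensor with a constant value.
void FillInt16(TfLiteTensor* tensor, int16_t value) {
  if (tensor->type != kTfLiteInt16) return;
  int16_t* data = tflite::GetTensorData<int16_t>(tensor);
  const int count = static_cast<int>(tensor->bytes / sizeof(int16_t));
  for (int i = 0; i < count; ++i) data[i] = value;
}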
diff --git a/tensorflow/contrib/lite/java/demo/README.md b/tensorflow/contrib/lite/java/demo/README.md
index 2e818f728e..e3cea19e16 100644
--- a/tensorflow/contrib/lite/java/demo/README.md
+++ b/tensorflow/contrib/lite/java/demo/README.md
@@ -1,5 +1,14 @@
# TF Lite Android App
+## Building in Android Studio with TensorFlow Lite AAR from JCenter.
+The build.gradle is configured to use TensorFlow Lite's nightly build.
+
+If you see a build error related to compatibility with TensorFlow Lite's Java
+API (for example: method X is undefined for type Interpreter), there has likely
+been a backwards-compatible change to the API. You will need to pull new app
+code that is compatible with the nightly build, and you may need to wait a few
+days for our external and internal code to merge.
+
## Building from Source with Bazel
1. Follow the [Bazel steps for the TF Demo App](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android#bazel):
diff --git a/tensorflow/contrib/lite/java/demo/app/build.gradle b/tensorflow/contrib/lite/java/demo/app/build.gradle
index b76eaad8bb..7f29deed83 100644
--- a/tensorflow/contrib/lite/java/demo/app/build.gradle
+++ b/tensorflow/contrib/lite/java/demo/app/build.gradle
@@ -52,7 +52,7 @@ dependencies {
compile 'com.android.support:support-annotations:25.3.1'
compile 'com.android.support:support-v13:25.2.0'
- compile 'org.tensorflow:tensorflow-lite:+'
+ compile 'org.tensorflow:tensorflow-lite:0.0.0-nightly'
testCompile 'junit:junit:4.12'
}
diff --git a/tensorflow/contrib/lite/kernels/internal/BUILD b/tensorflow/contrib/lite/kernels/internal/BUILD
index 75298b995d..7962fcbc9d 100644
--- a/tensorflow/contrib/lite/kernels/internal/BUILD
+++ b/tensorflow/contrib/lite/kernels/internal/BUILD
@@ -177,6 +177,40 @@ cc_library(
)
cc_library(
+ name = "legacy_optimized_base",
+ srcs = [],
+ hdrs = [
+ "common.h",
+ "optimized/depthwiseconv_float.h",
+ "optimized/depthwiseconv_uint8.h",
+ "optimized/depthwiseconv_uint8_3x3_filter.h",
+ "optimized/legacy_optimized_ops.h",
+ "optimized/optimized_ops.h",
+ ],
+ copts = tflite_copts(),
+ deps = [
+ ":quantization_util",
+ ":strided_slice_logic",
+ ":types",
+ ":legacy_reference_base",
+ ":round",
+ "//third_party/eigen3",
+ "@gemmlowp",
+ "//tensorflow/contrib/lite:builtin_op_data",
+ ] + select({
+ ":haswell": tflite_deps_intel,
+ ":ios_x86_64": tflite_deps_intel,
+ ":k8": tflite_deps_intel,
+ ":x86": tflite_deps_intel,
+ ":x86_64": tflite_deps_intel,
+ ":darwin": tflite_deps_intel,
+ ":darwin_x86_64": tflite_deps_intel,
+ ":freebsd": tflite_deps_intel,
+ "//conditions:default": [],
+ }),
+)
+
+cc_library(
name = "optimized",
hdrs = [
"optimized/cblas_conv.h",
@@ -274,6 +308,37 @@ cc_library(
)
cc_library(
+ name = "legacy_reference_base",
+ srcs = [],
+ hdrs = [
+ "common.h",
+ "reference/depthwiseconv_float.h",
+ "reference/depthwiseconv_uint8.h",
+ "reference/legacy_reference_ops.h",
+ "reference/reference_ops.h",
+ ],
+ deps = [
+ ":quantization_util",
+ ":round",
+ ":strided_slice_logic",
+ ":types",
+ "//third_party/eigen3",
+ "@gemmlowp",
+ "//tensorflow/contrib/lite:builtin_op_data",
+ ] + select({
+ ":haswell": tflite_deps_intel,
+ ":ios_x86_64": tflite_deps_intel,
+ ":k8": tflite_deps_intel,
+ ":x86": tflite_deps_intel,
+ ":x86_64": tflite_deps_intel,
+ ":darwin": tflite_deps_intel,
+ ":darwin_x86_64": tflite_deps_intel,
+ ":freebsd": tflite_deps_intel,
+ "//conditions:default": [],
+ }),
+)
+
+cc_library(
name = "reference",
hdrs = ["tensor.h"],
deps = [
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h
new file mode 100644
index 0000000000..c0dda4acf1
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h
@@ -0,0 +1,50 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_LEGACY_OPTIMIZED_OPS_H_
+#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_LEGACY_OPTIMIZED_OPS_H_
+
+#include <stdint.h>
+#include <sys/types.h>
+
+#include "tensorflow/contrib/lite/kernels/internal/common.h"
+#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/types.h"
+
+namespace tflite {
+namespace optimized_ops {
+
+inline RuntimeShape DimsToShape(const tflite::Dims<4>& dims) {
+ return RuntimeShape(
+ {dims.sizes[3], dims.sizes[2], dims.sizes[1], dims.sizes[0]});
+}
+
+template <FusedActivationFunctionType Ac>
+void L2Normalization(const float* input_data, const Dims<4>& input_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ return L2Normalization<Ac>(input_data, DimsToShape(input_dims), output_data,
+ DimsToShape(output_dims));
+}
+
+inline void L2Normalization(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_zero_point, uint8* output_data,
+ const Dims<4>& output_dims) {
+ return L2Normalization(input_data, DimsToShape(input_dims), input_zero_point,
+ output_data, DimsToShape(output_dims));
+}
+
+} // namespace optimized_ops
+} // namespace tflite
+#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_LEGACY_OPTIMIZED_OPS_H_
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
index ed2d04f20d..d0008cc4fb 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
@@ -1,4 +1,4 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -1821,8 +1821,8 @@ void DilatedIm2col(const T* input_data, const Dims<4>& input_dims,
// Use dimensions M and N to construct dims for indexing directly into im2col
Dims<4> im2col_dims;
- im2col_dims.sizes[0] = col_dims.strides[3];
- im2col_dims.sizes[1] = row_dims.strides[3];
+ im2col_dims.sizes[0] = FlatSize(col_dims);
+ im2col_dims.sizes[1] = FlatSize(row_dims);
im2col_dims.sizes[2] = 1;
im2col_dims.sizes[3] = 1;
ComputeStrides(&im2col_dims);
@@ -1831,8 +1831,8 @@ void DilatedIm2col(const T* input_data, const Dims<4>& input_dims,
for (int batch = 0; batch < batches; ++batch) {
for (int out_y = 0; out_y < output_height; ++out_y) {
for (int out_x = 0; out_x < output_width; ++out_x) {
- // Each row is an output pixel. Arrange the input data into this row in
- // an order we can conveniently multiply with the filter data.
+ // Each im2col row is an output pixel. Arrange the input data in this
+ // row in an order we can conveniently multiply with the filter data.
int row_offset = Offset(row_dims, out_x, out_y, batch, 0);
const int in_x_origin = (out_x * stride_width) - pad_width;
const int in_y_origin = (out_y * stride_height) - pad_height;
@@ -1848,7 +1848,7 @@ void DilatedIm2col(const T* input_data, const Dims<4>& input_dims,
T* dst = im2col_data +
Offset(im2col_dims, col_offset, row_offset, 0, 0);
if ((in_x >= 0) && (in_x < input_width)) {
- // Filter pixel is within the input, copy the data.
+ // Filter pixel is within the input, copy the input data.
T const* src =
input_data + Offset(input_dims, 0, in_x, in_y, batch);
memcpy(dst, src, input_depth * sizeof(T));
@@ -1858,7 +1858,7 @@ void DilatedIm2col(const T* input_data, const Dims<4>& input_dims,
}
}
} else {
- // Filter row is outside the input, zero out the entire im2col row.
+ // Filter row is outside the input, zero out the entire filter row.
int col_offset = Offset(col_dims, 0, 0, filter_y, 0);
T* dst =
im2col_data + Offset(im2col_dims, col_offset, row_offset, 0, 0);
@@ -1922,7 +1922,7 @@ inline void Conv(const float* input_data, const Dims<4>& input_dims,
(void)im2col_dims;
gemmlowp::ScopedProfilingLabel label("Conv");
- // A float set to 0x00000000h == 0.0f
+ // NB: static_cast<float>(0x00000000h) == 0.0f
const uint8 float_zero_byte = 0x00;
const float* gemm_input_data = nullptr;
const Dims<4>* gemm_input_dims = nullptr;
@@ -2366,12 +2366,15 @@ inline void Relu6(const float* input_data, const Dims<4>& input_dims,
}
template <FusedActivationFunctionType Ac>
-void L2Normalization(const float* input_data, const Dims<4>& input_dims,
- float* output_data, const Dims<4>& output_dims) {
+void L2Normalization(const float* input_data, const RuntimeShape& input_shape,
+ float* output_data, const RuntimeShape& output_shape) {
gemmlowp::ScopedProfilingLabel label("L2Normalization");
static_assert(Ac == FusedActivationFunctionType::kNone, "");
- const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims);
- const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+ const int trailing_dim = input_shape.DimensionsCount() - 1;
+ const int outer_size =
+ MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
+ const int depth =
+ MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
for (int i = 0; i < outer_size; ++i) {
float squared_l2_norm = 0;
for (int c = 0; c < depth; ++c) {
@@ -2434,17 +2437,20 @@ inline void GetInvSqrtQuantizedMultiplierExp(int32 input,
*output_shift *= kReverseShift;
}
-inline void L2Normalization(const uint8* input_data, const Dims<4>& input_dims,
+inline void L2Normalization(const uint8* input_data,
+ const RuntimeShape& input_shape,
int32 input_zero_point, uint8* output_data,
- const Dims<4>& output_dims) {
+ const RuntimeShape& output_shape) {
gemmlowp::ScopedProfilingLabel label("L2Normalization/8bit");
- TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
- const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
- const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims);
+ const int trailing_dim = input_shape.DimensionsCount() - 1;
+ const int depth =
+ MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
+ const int outer_size =
+ MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
for (int i = 0; i < outer_size; ++i) {
int32 square_l2_norm = 0;
for (int c = 0; c < depth; c++) {
+ // Note that input_data advances by depth in the second pass below.
int32 diff = input_data[c] - input_zero_point;
square_l2_norm += diff * diff;
}
@@ -6365,69 +6371,84 @@ void Transpose(const T* input, const Dims<4>& input_dims, T* output,
}
}
-inline void TransposeConv(const float* input_data, const Dims<4>& input_dims,
- const float* filter_data, const Dims<4>& filter_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, float* output_data,
- const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("TransposeConv");
- // THIS FUNCTION IS A COPY FROM reference_ops.h.
- // To optimize, start by using the conv code with transposed weights for the
- // case of stride_height = stride_width = 1.
+template <typename T>
+void TransposeIm2col(const T* input_data, const Dims<4>& input_dims,
+ const Dims<4>& filter_dims, int stride_width,
+ int stride_height, int pad_width, int pad_height,
+ const Dims<4>& output_dims, uint8 zero_byte,
+ T* im2col_data) {
+ gemmlowp::ScopedProfilingLabel label("TransposeIm2col");
+ TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
+ TFLITE_DCHECK(IsPackedWithoutStrides(filter_dims));
+ TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
+ TFLITE_DCHECK(im2col_data);
+
const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 0);
- const int output_depth = MatchingArraySize(filter_dims, 3, output_dims, 0);
const int input_height = ArraySize(input_dims, 2);
const int input_width = ArraySize(input_dims, 1);
+ const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 3);
const int filter_height = ArraySize(filter_dims, 2);
const int filter_width = ArraySize(filter_dims, 1);
const int output_height = ArraySize(output_dims, 2);
const int output_width = ArraySize(output_dims, 1);
+ MatchingArraySize(output_dims, 0, filter_dims, 0); // output_depth
- // Although transpose convolution simplifies to convolution with transposed
- // weights for strides of 1, non-unitary striding complicates matters. To
- // keep this reference implementation as clear as possible, we use a "scatter"
- // access pattern, where we loop through all the input elements, computing
- // their influence on the output, rather than looping through the output
- // elements in the typical "gather" access pattern of a conv. We therefore
- // must initialize the output array to zero.
- for (int batch = 0; batch < batches; ++batch) {
- for (int out_y = 0; out_y < output_height; ++out_y) {
- for (int out_x = 0; out_x < output_width; ++out_x) {
- for (int out_channel = 0; out_channel < output_depth; ++out_channel) {
- output_data[Offset(output_dims, out_channel, out_x, out_y, batch)] =
- 0.0f;
- }
- }
- }
- }
+ // Construct the MxN sized im2col matrix.
+ // The rows, M, are sub-ordered B x H x W
+ Dims<4> row_dims;
+ row_dims.sizes[0] = output_width;
+ row_dims.sizes[1] = output_height;
+ row_dims.sizes[2] = batches;
+ row_dims.sizes[3] = 1;
+ ComputeStrides(&row_dims);
+
+ // The columns, N, are sub-ordered Kh x Kw x Din
+ Dims<4> col_dims;
+ col_dims.sizes[0] = input_depth;
+ col_dims.sizes[1] = filter_width;
+ col_dims.sizes[2] = filter_height;
+ col_dims.sizes[3] = 1;
+ ComputeStrides(&col_dims);
- // Loop through input elements one at a time.
+ // Use dimensions M and N to construct dims for indexing directly into im2col
+ Dims<4> im2col_dims;
+ im2col_dims.sizes[0] = FlatSize(col_dims);
+ im2col_dims.sizes[1] = FlatSize(row_dims);
+ im2col_dims.sizes[2] = 1;
+ im2col_dims.sizes[3] = 1;
+ ComputeStrides(&im2col_dims);
+
+ // Build the im2col matrix by looping through all the input pixels,
+ // computing their influence on the output, rather than looping through all
+ // the output pixels. We therefore must initialize the im2col array to zero.
+ // This is potentially inefficient because we subsequently overwrite bytes
+ // set here. However, memset is very fast in practice and its cost is negligible.
+ memset(im2col_data, zero_byte, FlatSize(im2col_dims) * sizeof(T));
+
+ // Loop through the output batches
for (int batch = 0; batch < batches; ++batch) {
+ // Loop through input pixels one at a time.
for (int in_y = 0; in_y < input_height; ++in_y) {
for (int in_x = 0; in_x < input_width; ++in_x) {
- for (int in_channel = 0; in_channel < input_depth; ++in_channel) {
- // Loop through the output elements it will influence
- const int out_x_origin = (in_x * stride_width) - pad_width;
- const int out_y_origin = (in_y * stride_height) - pad_height;
- for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
+ // Loop through the output pixels it will influence
+ const int out_x_origin = (in_x * stride_width) - pad_width;
+ const int out_y_origin = (in_y * stride_height) - pad_height;
+ for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
+ const int out_y = out_y_origin + filter_y;
+ // Is output pixel within height bounds?
+ if ((out_y >= 0) && (out_y < output_height)) {
for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
- for (int out_channel = 0; out_channel < output_depth;
- ++out_channel) {
- // Compute output element location
- const int out_x = out_x_origin + filter_x;
- const int out_y = out_y_origin + filter_y;
- // We cannot accumulate out of bounds
- if ((out_x >= 0) && (out_x < output_width) && (out_y >= 0) &&
- (out_y < output_height)) {
- float input_value = input_data[Offset(input_dims, in_channel,
- in_x, in_y, batch)];
- float filter_value =
- filter_data[Offset(filter_dims, in_channel, filter_x,
- filter_y, out_channel)];
- output_data[Offset(output_dims, out_channel, out_x, out_y,
- batch)] += input_value * filter_value;
- }
+ const int out_x = out_x_origin + filter_x;
+ // Is output pixel within width bounds?
+ if ((out_x >= 0) && (out_x < output_width)) {
+ // Copy the input elements of this pixel
+ T const* src =
+ input_data + Offset(input_dims, 0, in_x, in_y, batch);
+ T* dst = im2col_data +
+ Offset(im2col_dims,
+ Offset(col_dims, 0, filter_x, filter_y, 0),
+ Offset(row_dims, out_x, out_y, batch, 0), 0, 0);
+ memcpy(dst, src, input_depth * sizeof(T));
}
}
}
@@ -6437,6 +6458,31 @@ inline void TransposeConv(const float* input_data, const Dims<4>& input_dims,
}
}
+inline void TransposeConv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, float* output_data,
+ const Dims<4>& output_dims, float* im2col_data,
+ const Dims<4>& im2col_dims) {
+ gemmlowp::ScopedProfilingLabel label("TransposeConv");
+
+ // Note we could use transposed weights with forward conv for unstrided
+ // cases. But we are already getting good performance with this code as-is.
+ TFLITE_DCHECK(im2col_data);
+ TransposeIm2col(input_data, input_dims, filter_dims, stride_width,
+ stride_height, pad_width, pad_height, output_dims, 0,
+ im2col_data);
+
+ const auto im2col_matrix_map =
+ MapAsMatrixWithFirstDimAsRows(im2col_data, im2col_dims);
+ const auto filter_matrix_map =
+ MapAsMatrixWithLastDimAsCols(filter_data, filter_dims);
+ auto output_matrix_map =
+ MapAsMatrixWithFirstDimAsRows(output_data, output_dims);
+
+ Gemm(filter_matrix_map.transpose(), im2col_matrix_map, &output_matrix_map);
+}
+
} // namespace optimized_ops
} // namespace tflite
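To make the new GEMM formulation of TransposeConv concrete, here is a naive, hedged sketch of the product that the optimized path delegates to Gemm(): with M output pixels (batches x out_h x out_w) and K patch elements (kh x kw x in_depth), output = filterᵀ x im2col. The row-major storage layouts below are illustrative only, not TFLite's exact Eigen mappings.

// Naive reference for output = filter^T * im2col.
// filter:  K x out_depth, im2col: K x M, output: out_depth x M (row-major).
void NaiveTransposeConvGemm(const float* filter, const float* im2col,
                            float* output, int K, int out_depth, int M) {
  for (int d = 0; d < out_depth; ++d) {
    for (int m = 0; m < M; ++m) {
      float acc = 0.0f;
      for (int k = 0; k < K; ++k) {
        acc += filter[k * out_depth + d] * im2col[k * M + m];
      }
      output[d * M + m] = acc;
    }
  }
}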
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h
new file mode 100644
index 0000000000..6f5f6a3e6f
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h
@@ -0,0 +1,50 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_LEGACY_REFERENCE_OPS_H_
+#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_LEGACY_REFERENCE_OPS_H_
+
+#include <stdint.h>
+#include <sys/types.h>
+
+#include "tensorflow/contrib/lite/kernels/internal/common.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/types.h"
+
+namespace tflite {
+
+namespace reference_ops {
+
+inline RuntimeShape DimsToShape(const tflite::Dims<4>& dims) {
+ return RuntimeShape(
+ {dims.sizes[3], dims.sizes[2], dims.sizes[1], dims.sizes[0]});
+}
+
+template <FusedActivationFunctionType Ac>
+void L2Normalization(const float* input_data, const Dims<4>& input_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ return L2Normalization<Ac>(input_data, DimsToShape(input_dims), output_data,
+ DimsToShape(output_dims));
+}
+
+inline void L2Normalization(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_zero_point, uint8* output_data,
+ const Dims<4>& output_dims) {
+ return L2Normalization(input_data, DimsToShape(input_dims), input_zero_point,
+ output_data, DimsToShape(output_dims));
+}
+
+} // namespace reference_ops
+} // namespace tflite
+#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_LEGACY_REFERENCE_OPS_H_
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
index e10900c5bd..6cef94a606 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
@@ -950,11 +950,14 @@ inline void Relu6(const float* input_data, const Dims<4>& input_dims,
}
template <FusedActivationFunctionType Ac>
-void L2Normalization(const float* input_data, const Dims<4>& input_dims,
- float* output_data, const Dims<4>& output_dims) {
+void L2Normalization(const float* input_data, const RuntimeShape& input_shape,
+ float* output_data, const RuntimeShape& output_shape) {
static_assert(Ac == FusedActivationFunctionType::kNone, "");
- const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims);
- const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+ const int trailing_dim = input_shape.DimensionsCount() - 1;
+ const int outer_size =
+ MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
+ const int depth =
+ MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
for (int i = 0; i < outer_size; ++i) {
float squared_l2_norm = 0;
for (int c = 0; c < depth; ++c) {
@@ -1015,16 +1018,19 @@ inline void GetInvSqrtQuantizedMultiplierExp(int32 input,
*output_shift *= kReverseShift;
}
-inline void L2Normalization(const uint8* input_data, const Dims<4>& input_dims,
+inline void L2Normalization(const uint8* input_data,
+ const RuntimeShape& input_shape,
int32 input_zero_point, uint8* output_data,
- const Dims<4>& output_dims) {
- const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
- const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims);
+ const RuntimeShape& output_shape) {
+ const int trailing_dim = input_shape.DimensionsCount() - 1;
+ const int depth =
+ MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
+ const int outer_size =
+ MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
for (int i = 0; i < outer_size; ++i) {
int32 square_l2_norm = 0;
for (int c = 0; c < depth; c++) {
- int32 diff =
- input_data[Offset(input_dims, c, i, 0, 0)] - input_zero_point;
+ int32 diff = input_data[depth * i + c] - input_zero_point;
square_l2_norm += diff * diff;
}
int32 inv_l2norm_multiplier;
@@ -1033,14 +1039,12 @@ inline void L2Normalization(const uint8* input_data, const Dims<4>& input_dims,
&inv_l2norm_shift);
for (int c = 0; c < depth; c++) {
- int32 diff =
- input_data[Offset(input_dims, c, i, 0, 0)] - input_zero_point;
+ int32 diff = input_data[depth * i + c] - input_zero_point;
int32 rescaled_diff = MultiplyByQuantizedMultiplierSmallerThanOneExp(
128 * diff, inv_l2norm_multiplier, inv_l2norm_shift);
int32 unclamped_output_val = 128 + rescaled_diff;
int32 output_val = std::min(255, std::max(0, unclamped_output_val));
- output_data[Offset(output_dims, c, i, 0, 0)] =
- static_cast<uint8>(output_val);
+ output_data[depth * i + c] = static_cast<uint8>(output_val);
}
}
}
@@ -3821,7 +3825,8 @@ inline void TransposeConv(const float* input_data, const Dims<4>& input_dims,
const float* filter_data, const Dims<4>& filter_dims,
int stride_width, int stride_height, int pad_width,
int pad_height, float* output_data,
- const Dims<4>& output_dims) {
+ const Dims<4>& output_dims, float* /*im2col_data*/,
+ const Dims<4>& /*im2col_dims*/) {
const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 0);
const int output_depth = MatchingArraySize(filter_dims, 3, output_dims, 0);
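As a sanity check on the indexing change (flat `depth * i + c` addressing over the trailing dimension), here is a minimal float-only sketch that mirrors the reference loop structure, independent of the TFLite types. Like the reference float path, it does not guard against a zero norm.

#include <cmath>

// L2-normalize along the trailing dimension of a flattened [outer, depth]
// buffer.
void L2NormalizeTrailingDim(const float* input, float* output,
                            int outer_size, int depth) {
  for (int i = 0; i < outer_size; ++i) {
    float squared_l2_norm = 0.0f;
    for (int c = 0; c < depth; ++c) {
      const float val = input[depth * i + c];
      squared_l2_norm += val * val;
    }
    const float l2_norm = std::sqrt(squared_l2_norm);
    for (int c = 0; c < depth; ++c) {
      output[depth * i + c] = input[depth * i + c] / l2_norm;
    }
  }
}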
diff --git a/tensorflow/contrib/lite/kernels/internal/tensor.h b/tensorflow/contrib/lite/kernels/internal/tensor.h
index ce887cea8b..518bee1c63 100644
--- a/tensorflow/contrib/lite/kernels/internal/tensor.h
+++ b/tensorflow/contrib/lite/kernels/internal/tensor.h
@@ -35,6 +35,11 @@ inline uint8_t* GetTensorData(TfLiteTensor* tensor) {
}
template <>
+inline int16_t* GetTensorData(TfLiteTensor* tensor) {
+ return tensor != nullptr ? tensor->data.i16 : nullptr;
+}
+
+template <>
inline int32_t* GetTensorData(TfLiteTensor* tensor) {
return tensor != nullptr ? tensor->data.i32 : nullptr;
}
@@ -63,6 +68,11 @@ inline const uint8_t* GetTensorData(const TfLiteTensor* tensor) {
}
template <>
+inline const int16_t* GetTensorData(const TfLiteTensor* tensor) {
+ return tensor != nullptr ? tensor->data.i16 : nullptr;
+}
+
+template <>
inline const int32_t* GetTensorData(const TfLiteTensor* tensor) {
return tensor != nullptr ? tensor->data.i32 : nullptr;
}
@@ -114,6 +124,19 @@ inline Dims<4> GetTensorDims(const TfLiteTensor* tensor) {
return GetTensorDims(dims->data, dims->size);
}
+inline RuntimeShape GetTensorShape(std::vector<int32_t> data) {
+ return RuntimeShape(data.size(), data.data());
+}
+
+inline RuntimeShape GetTensorShape(const TfLiteTensor* tensor) {
+ if (tensor == nullptr) {
+ return RuntimeShape();
+ }
+
+ auto* dims = tensor->dims;
+ return RuntimeShape(dims->size, dims->data);
+}
+
// A list of tensors in a format that can be used by kernels like split and
// concatenation.
template <typename T>
diff --git a/tensorflow/contrib/lite/kernels/internal/types.h b/tensorflow/contrib/lite/kernels/internal/types.h
index 3ecef15271..64f4881a46 100644
--- a/tensorflow/contrib/lite/kernels/internal/types.h
+++ b/tensorflow/contrib/lite/kernels/internal/types.h
@@ -65,6 +65,10 @@ class RuntimeShape {
ReplaceWith(dimensions_count, dims_data);
}
+ RuntimeShape(const std::initializer_list<int> init_list) : size_(0) {
+ BuildFrom(init_list);
+ }
+
~RuntimeShape() {
if (size_ > kMaxSmallSize) {
delete[] dims_pointer_;
@@ -214,6 +218,15 @@ inline size_t ReducedOutputOffset(const int num_dims, const int* dims,
return offset;
}
+inline int Offset(const RuntimeShape& shape, int i0, int i1, int i2, int i3) {
+ TFLITE_DCHECK(i0 >= 0 && i0 < shape.Dims(0));
+ TFLITE_DCHECK(i1 >= 0 && i1 < shape.Dims(1));
+ TFLITE_DCHECK(i2 >= 0 && i2 < shape.Dims(2));
+ TFLITE_DCHECK(i3 >= 0 && i3 < shape.Dims(3));
+ const int* dims_data = shape.DimsData();
+ return ((i0 * dims_data[1] + i1) * dims_data[2] + i2) * dims_data[3] + i3;
+}
+
inline int Offset(const Dims<4>& dims, int i0, int i1, int i2, int i3) {
TFLITE_DCHECK(i0 >= 0 && i0 < dims.sizes[0]);
TFLITE_DCHECK(i1 >= 0 && i1 < dims.sizes[1]);
@@ -228,6 +241,9 @@ inline int Offset(const Dims<4>& dims, int* index) {
}
// Get array size, DCHECKing that the dim index is in range.
+//
+// Note that this will be phased out with Dims<4>, since RuntimeShape::Dims()
+// already performs this check.
template <int N>
int ArraySize(const Dims<N>& array, int index) {
TFLITE_DCHECK(index >= 0 && index < N);
@@ -249,6 +265,21 @@ int MatchingArraySize(const ArrayType1& array1, int index1,
return MatchingArraySize(array1, index1, args...);
}
+// Get common shape dim, DCHECKing that they all agree.
+inline int MatchingDim(const RuntimeShape& shape1, int index1,
+ const RuntimeShape& shape2, int index2) {
+ TFLITE_DCHECK_EQ(shape1.Dims(index1), shape2.Dims(index2));
+ return shape1.Dims(index1);
+}
+
+template <typename... Args>
+int MatchingDim(const RuntimeShape& shape1, int index1,
+ const RuntimeShape& shape2, int index2, Args... args) {
+ TFLITE_DCHECK_EQ(shape1.Dims(index1), shape2.Dims(index2));
+ return MatchingDim(shape1, index1, args...);
+}
+
+// Will be phased out with Dims<4>, replaced by RuntimeShape::FlatSize().
template <int N>
inline int FlatSize(const Dims<N>& dims) {
int flat_size = 1;
@@ -368,6 +399,72 @@ inline int MatchingFlatSizeSkipDim(const Dims<N>& dims, int skip_dim,
check_dims_3);
}
+// Data is required to be contiguous, and so many operators can use either the
+// full array flat size or the flat size with one dimension skipped (commonly
+// the depth).
+inline int FlatSizeSkipDim(const RuntimeShape& shape, int skip_dim) {
+ const int dims_count = shape.DimensionsCount();
+ TFLITE_DCHECK(skip_dim >= 0 && skip_dim < dims_count);
+ const auto* dims_data = shape.DimsData();
+ int flat_size = 1;
+ for (int i = 0; i < dims_count; ++i) {
+ flat_size *= (i == skip_dim) ? 1 : dims_data[i];
+ }
+ return flat_size;
+}
+
+// A combination of MatchingFlatSize() and FlatSizeSkipDim().
+inline int MatchingFlatSizeSkipDim(const RuntimeShape& shape, int skip_dim,
+ const RuntimeShape& check_shape_0) {
+ const int dims_count = shape.DimensionsCount();
+ for (int i = 0; i < dims_count; ++i) {
+ if (i != skip_dim) {
+ TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
+ }
+ }
+ return FlatSizeSkipDim(shape, skip_dim);
+}
+
+inline int MatchingFlatSizeSkipDim(const RuntimeShape& shape, int skip_dim,
+ const RuntimeShape& check_shape_0,
+ const RuntimeShape& check_shape_1) {
+ const int dims_count = shape.DimensionsCount();
+ for (int i = 0; i < dims_count; ++i) {
+ if (i != skip_dim) {
+ TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
+ }
+ }
+ return MatchingFlatSizeSkipDim(shape, skip_dim, check_shape_1);
+}
+
+inline int MatchingFlatSizeSkipDim(const RuntimeShape& shape, int skip_dim,
+ const RuntimeShape& check_shape_0,
+ const RuntimeShape& check_shape_1,
+ const RuntimeShape& check_shape_2) {
+ const int dims_count = shape.DimensionsCount();
+ for (int i = 0; i < dims_count; ++i) {
+ if (i != skip_dim) {
+ TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
+ }
+ }
+ return MatchingFlatSizeSkipDim(shape, skip_dim, check_shape_1, check_shape_2);
+}
+
+inline int MatchingFlatSizeSkipDim(const RuntimeShape& shape, int skip_dim,
+ const RuntimeShape& check_shape_0,
+ const RuntimeShape& check_shape_1,
+ const RuntimeShape& check_shape_2,
+ const RuntimeShape& check_shape_3) {
+ const int dims_count = shape.DimensionsCount();
+ for (int i = 0; i < dims_count; ++i) {
+ if (i != skip_dim) {
+ TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
+ }
+ }
+ return MatchingFlatSizeSkipDim(shape, skip_dim, check_shape_1, check_shape_2,
+ check_shape_3);
+}
+
template <int N>
bool IsPackedWithoutStrides(const Dims<N>& dims) {
int expected_stride = 1;
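
A small standalone sketch of the RuntimeShape helpers introduced above (assumes only the types.h shown here; the values in the comments are worked out by hand):

#include "tensorflow/contrib/lite/kernels/internal/types.h"

inline void RuntimeShapeHelpersSketch() {
  // 4-D NHWC shape built with the new initializer_list constructor.
  const tflite::RuntimeShape shape({2, 3, 4, 5});
  // Row-major flattening: ((0*3 + 1)*4 + 2)*5 + 3 == 33.
  const int flat = tflite::Offset(shape, 0, 1, 2, 3);
  // Product of all dims except the skipped one: 2*3*4 == 24.
  const int outer = tflite::FlatSizeSkipDim(shape, 3);
  // DCHECKs that both shapes agree on the queried dims and returns the value (5).
  const tflite::RuntimeShape other({2, 3, 4, 5});
  const int depth = tflite::MatchingDim(shape, 3, other, 3);
  (void)flat;
  (void)outer;
  (void)depth;
}
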
diff --git a/tensorflow/contrib/lite/kernels/l2norm.cc b/tensorflow/contrib/lite/kernels/l2norm.cc
index 3205c1cc52..a7b54c6b84 100644
--- a/tensorflow/contrib/lite/kernels/l2norm.cc
+++ b/tensorflow/contrib/lite/kernels/l2norm.cc
@@ -70,8 +70,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
if (output->type == kTfLiteFloat32) {
#define TF_LITE_L2NORM(type) \
type::L2Normalization<FusedActivationFunctionType::kNone>( \
- GetTensorData<float>(input), GetTensorDims(input), \
- GetTensorData<float>(output), GetTensorDims(output))
+ GetTensorData<float>(input), GetTensorShape(input), \
+ GetTensorData<float>(output), GetTensorShape(output))
if (kernel_type == kReference) {
TF_LITE_L2NORM(reference_ops);
@@ -81,10 +81,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
}
#undef TF_LITE_L2NORM
} else if (output->type == kTfLiteUInt8) {
-#define TF_LITE_L2NORM(type) \
- type::L2Normalization(GetTensorData<uint8>(input), GetTensorDims(input), \
- input->params.zero_point, \
- GetTensorData<uint8>(output), GetTensorDims(output))
+#define TF_LITE_L2NORM(type) \
+ type::L2Normalization(GetTensorData<uint8>(input), GetTensorShape(input), \
+ input->params.zero_point, \
+ GetTensorData<uint8>(output), GetTensorShape(output))
if (kernel_type == kReference) {
TF_LITE_L2NORM(reference_ops);
diff --git a/tensorflow/contrib/lite/kernels/register.h b/tensorflow/contrib/lite/kernels/register.h
index b928f1b302..940718d67e 100644
--- a/tensorflow/contrib/lite/kernels/register.h
+++ b/tensorflow/contrib/lite/kernels/register.h
@@ -32,4 +32,4 @@ class BuiltinOpResolver : public MutableOpResolver {
} // namespace ops
} // namespace tflite
-#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_BUILTIN_KERNELS_H
+#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_REGISTER_H_
diff --git a/tensorflow/contrib/lite/kernels/transpose_conv.cc b/tensorflow/contrib/lite/kernels/transpose_conv.cc
index e83b1ec987..8b9deeed20 100644
--- a/tensorflow/contrib/lite/kernels/transpose_conv.cc
+++ b/tensorflow/contrib/lite/kernels/transpose_conv.cc
@@ -119,10 +119,16 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
// Currently only support float32.
switch (input->type) {
case kTfLiteFloat32:
- optimized_ops::TransposeConv(
+ reference_ops::TransposeConv(
GetTensorData<float>(input), GetTensorDims(input),
GetTensorData<float>(weights), GetTensorDims(weights), stride_width,
stride_height, padding_size.width, padding_size.height,
+ GetTensorData<float>(output), GetTensorDims(output),
+ // Last two args specify im2col which reference_ops ignores.
+ // (Note this does not lead to a performance regression, as the
+ // previous optimized version was just a copy of the reference code.)
+ // TODO(b/110208176): Allocate im2col tensors and switch to
+ // optimized_ops.
GetTensorData<float>(output), GetTensorDims(output));
break;
default:
diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc
index 039f32b38e..bc62e4cc2d 100644
--- a/tensorflow/contrib/lite/model.cc
+++ b/tensorflow/contrib/lite/model.cc
@@ -45,6 +45,9 @@ TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type,
case TensorType_FLOAT32:
*type = kTfLiteFloat32;
break;
+ case TensorType_INT16:
+ *type = kTfLiteInt16;
+ break;
case TensorType_INT32:
*type = kTfLiteInt32;
break;
@@ -849,7 +852,16 @@ TfLiteStatus InterpreterBuilder::ParseTensors(
const char* buffer_ptr;
TF_LITE_ENSURE_STATUS(get_readonly_data(&buffer_ptr, &buffer_size));
+ bool is_variable = tensor->is_variable();
if (buffer_ptr) {
+ if (is_variable) {
+ error_reporter_->Report(
+ "Tensor %d is a variable tensor with buffer. "
+ "It's not supported now.\n",
+ i);
+ status = kTfLiteError;
+ }
+
if (interpreter->SetTensorParametersReadOnly(
i, type, get_name(tensor), dims, quantization, buffer_ptr,
buffer_size, allocation_) != kTfLiteOk) {
@@ -858,8 +870,9 @@ TfLiteStatus InterpreterBuilder::ParseTensors(
status = kTfLiteError;
}
} else {
- if (interpreter->SetTensorParametersReadWrite(
- i, type, get_name(tensor), dims, quantization) != kTfLiteOk) {
+ if (interpreter->SetTensorParametersReadWrite(i, type, get_name(tensor),
+ dims, quantization,
+ is_variable) != kTfLiteOk) {
error_reporter_->Report("Tensor %d is invalidly specified in schema.\n",
i);
status = kTfLiteError;
@@ -943,6 +956,15 @@ TfLiteStatus InterpreterBuilder::operator()(
if (ParseTensors(buffers, tensors, interpreter->get()) != kTfLiteOk)
return cleanup_and_error();
+ std::vector<int> variables;
+ for (int i = 0; i < (*interpreter)->tensors_size(); ++i) {
+ auto* tensor = (*interpreter)->tensor(i);
+ if (tensor->is_variable) {
+ variables.push_back(i);
+ }
+ }
+ (**interpreter).SetVariables(variables);
+
return kTfLiteOk;
}
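
Roughly, the client-side flow that the is_variable plumbing above enables looks like this (sketch under the assumption that the Interpreter methods keep the signatures used in the hunks; the tensor index and name are made up):

#include "tensorflow/contrib/lite/interpreter.h"

// Sketch: declare tensor 3 as a variable state tensor, register it, and zero
// it before the first Invoke(), mirroring ParseTensors()/SetVariables() above.
inline void MarkStateTensorSketch(tflite::Interpreter* interpreter) {
  TfLiteQuantizationParams quant = {};  // unquantized
  interpreter->SetTensorParametersReadWrite(
      /*tensor_index=*/3, kTfLiteFloat32, "state", {1, 16}, quant,
      /*is_variable=*/true);
  interpreter->SetVariables({3});
  interpreter->ResetVariableTensorsToZero();
}
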
diff --git a/tensorflow/contrib/lite/optional_debug_tools.cc b/tensorflow/contrib/lite/optional_debug_tools.cc
index dfdd80ea8a..3af809a2a1 100644
--- a/tensorflow/contrib/lite/optional_debug_tools.cc
+++ b/tensorflow/contrib/lite/optional_debug_tools.cc
@@ -50,6 +50,8 @@ const char* TensorTypeName(TfLiteType type) {
return "kTfLiteString";
case kTfLiteBool:
return "kTfLiteBool";
+ case kTfLiteInt16:
+ return "kTfLiteInt16";
}
return "(invalid)";
}
diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc
index 5f304ad45d..e5e5c4fb02 100644
--- a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc
+++ b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc
@@ -68,6 +68,8 @@ int TfLiteTypeToPyArrayType(TfLiteType tf_lite_type) {
return NPY_FLOAT32;
case kTfLiteInt32:
return NPY_INT32;
+ case kTfLiteInt16:
+ return NPY_INT16;
case kTfLiteUInt8:
return NPY_UINT8;
case kTfLiteInt64:
@@ -90,6 +92,8 @@ TfLiteType TfLiteTypeFromPyArray(PyArrayObject* array) {
return kTfLiteFloat32;
case NPY_INT32:
return kTfLiteInt32;
+ case NPY_INT16:
+ return kTfLiteInt16;
case NPY_UINT8:
return kTfLiteUInt8;
case NPY_INT64:
diff --git a/tensorflow/contrib/lite/python/lite.py b/tensorflow/contrib/lite/python/lite.py
index 876ffbbffa..8315066cd1 100644
--- a/tensorflow/contrib/lite/python/lite.py
+++ b/tensorflow/contrib/lite/python/lite.py
@@ -22,6 +22,7 @@ EXPERIMENTAL: APIs here are unstable and likely to change without notice.
@@Interpreter
@@OpHint
@@convert_op_hints_to_stubs
+@@build_toco_convert_protos
@@FLOAT
@@QUANTIZED_UINT8
diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs
index ee5208df14..c7b955a165 100644
--- a/tensorflow/contrib/lite/schema/schema.fbs
+++ b/tensorflow/contrib/lite/schema/schema.fbs
@@ -34,6 +34,7 @@ enum TensorType : byte {
INT64 = 4,
STRING = 5,
BOOL = 6,
+ INT16 = 7,
}
// Parameters for converting a quantized tensor back to float. Given a
@@ -63,6 +64,8 @@ table Tensor {
buffer:uint;
name:string; // For debugging and importing back into tensorflow.
quantization:QuantizationParameters; // Optional.
+
+ is_variable:bool = false;
}
// A list of builtin operators. Builtin operators are slightly faster than custom
@@ -520,6 +523,16 @@ table Operator {
builtin_options:BuiltinOptions;
custom_options:[ubyte];
custom_options_format:CustomOptionsFormat;
+
+  // A list of booleans indicating which input tensors are mutated by this
+  // operator (e.g. used by RNN and LSTM).
+ // For example, if the "inputs" array refers to 5 tensors and the second and
+ // fifth are mutable variables, then this list will contain
+ // [false, true, false, false, true].
+ //
+ // If the list is empty, no variable is mutated in this operator.
+ // The list either has the same length as `inputs`, or is empty.
+ mutating_variable_inputs:[bool];
}
// The root type, defining a subgraph, which typically represents an entire
diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h
index 887e47ed1e..81d4574da7 100755
--- a/tensorflow/contrib/lite/schema/schema_generated.h
+++ b/tensorflow/contrib/lite/schema/schema_generated.h
@@ -216,11 +216,12 @@ enum TensorType {
TensorType_INT64 = 4,
TensorType_STRING = 5,
TensorType_BOOL = 6,
+ TensorType_INT16 = 7,
TensorType_MIN = TensorType_FLOAT32,
- TensorType_MAX = TensorType_BOOL
+ TensorType_MAX = TensorType_INT16
};
-inline TensorType (&EnumValuesTensorType())[7] {
+inline TensorType (&EnumValuesTensorType())[8] {
static TensorType values[] = {
TensorType_FLOAT32,
TensorType_FLOAT16,
@@ -228,7 +229,8 @@ inline TensorType (&EnumValuesTensorType())[7] {
TensorType_UINT8,
TensorType_INT64,
TensorType_STRING,
- TensorType_BOOL
+ TensorType_BOOL,
+ TensorType_INT16
};
return values;
}
@@ -242,6 +244,7 @@ inline const char **EnumNamesTensorType() {
"INT64",
"STRING",
"BOOL",
+ "INT16",
nullptr
};
return names;
@@ -1671,9 +1674,11 @@ struct TensorT : public flatbuffers::NativeTable {
uint32_t buffer;
std::string name;
std::unique_ptr<QuantizationParametersT> quantization;
+ bool is_variable;
TensorT()
: type(TensorType_FLOAT32),
- buffer(0) {
+ buffer(0),
+ is_variable(false) {
}
};
@@ -1684,7 +1689,8 @@ struct Tensor FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
VT_TYPE = 6,
VT_BUFFER = 8,
VT_NAME = 10,
- VT_QUANTIZATION = 12
+ VT_QUANTIZATION = 12,
+ VT_IS_VARIABLE = 14
};
const flatbuffers::Vector<int32_t> *shape() const {
return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_SHAPE);
@@ -1701,6 +1707,9 @@ struct Tensor FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
const QuantizationParameters *quantization() const {
return GetPointer<const QuantizationParameters *>(VT_QUANTIZATION);
}
+ bool is_variable() const {
+ return GetField<uint8_t>(VT_IS_VARIABLE, 0) != 0;
+ }
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyOffset(verifier, VT_SHAPE) &&
@@ -1711,6 +1720,7 @@ struct Tensor FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
verifier.Verify(name()) &&
VerifyOffset(verifier, VT_QUANTIZATION) &&
verifier.VerifyTable(quantization()) &&
+ VerifyField<uint8_t>(verifier, VT_IS_VARIABLE) &&
verifier.EndTable();
}
TensorT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@@ -1736,6 +1746,9 @@ struct TensorBuilder {
void add_quantization(flatbuffers::Offset<QuantizationParameters> quantization) {
fbb_.AddOffset(Tensor::VT_QUANTIZATION, quantization);
}
+ void add_is_variable(bool is_variable) {
+ fbb_.AddElement<uint8_t>(Tensor::VT_IS_VARIABLE, static_cast<uint8_t>(is_variable), 0);
+ }
explicit TensorBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
@@ -1754,12 +1767,14 @@ inline flatbuffers::Offset<Tensor> CreateTensor(
TensorType type = TensorType_FLOAT32,
uint32_t buffer = 0,
flatbuffers::Offset<flatbuffers::String> name = 0,
- flatbuffers::Offset<QuantizationParameters> quantization = 0) {
+ flatbuffers::Offset<QuantizationParameters> quantization = 0,
+ bool is_variable = false) {
TensorBuilder builder_(_fbb);
builder_.add_quantization(quantization);
builder_.add_name(name);
builder_.add_buffer(buffer);
builder_.add_shape(shape);
+ builder_.add_is_variable(is_variable);
builder_.add_type(type);
return builder_.Finish();
}
@@ -1770,14 +1785,16 @@ inline flatbuffers::Offset<Tensor> CreateTensorDirect(
TensorType type = TensorType_FLOAT32,
uint32_t buffer = 0,
const char *name = nullptr,
- flatbuffers::Offset<QuantizationParameters> quantization = 0) {
+ flatbuffers::Offset<QuantizationParameters> quantization = 0,
+ bool is_variable = false) {
return tflite::CreateTensor(
_fbb,
shape ? _fbb.CreateVector<int32_t>(*shape) : 0,
type,
buffer,
name ? _fbb.CreateString(name) : 0,
- quantization);
+ quantization,
+ is_variable);
}
flatbuffers::Offset<Tensor> CreateTensor(flatbuffers::FlatBufferBuilder &_fbb, const TensorT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
@@ -5004,6 +5021,7 @@ struct OperatorT : public flatbuffers::NativeTable {
BuiltinOptionsUnion builtin_options;
std::vector<uint8_t> custom_options;
CustomOptionsFormat custom_options_format;
+ std::vector<bool> mutating_variable_inputs;
OperatorT()
: opcode_index(0),
custom_options_format(CustomOptionsFormat_FLEXBUFFERS) {
@@ -5019,7 +5037,8 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
VT_BUILTIN_OPTIONS_TYPE = 10,
VT_BUILTIN_OPTIONS = 12,
VT_CUSTOM_OPTIONS = 14,
- VT_CUSTOM_OPTIONS_FORMAT = 16
+ VT_CUSTOM_OPTIONS_FORMAT = 16,
+ VT_MUTATING_VARIABLE_INPUTS = 18
};
uint32_t opcode_index() const {
return GetField<uint32_t>(VT_OPCODE_INDEX, 0);
@@ -5205,6 +5224,9 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
CustomOptionsFormat custom_options_format() const {
return static_cast<CustomOptionsFormat>(GetField<int8_t>(VT_CUSTOM_OPTIONS_FORMAT, 0));
}
+ const flatbuffers::Vector<uint8_t> *mutating_variable_inputs() const {
+ return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_MUTATING_VARIABLE_INPUTS);
+ }
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyField<uint32_t>(verifier, VT_OPCODE_INDEX) &&
@@ -5218,6 +5240,8 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
VerifyOffset(verifier, VT_CUSTOM_OPTIONS) &&
verifier.Verify(custom_options()) &&
VerifyField<int8_t>(verifier, VT_CUSTOM_OPTIONS_FORMAT) &&
+ VerifyOffset(verifier, VT_MUTATING_VARIABLE_INPUTS) &&
+ verifier.Verify(mutating_variable_inputs()) &&
verifier.EndTable();
}
OperatorT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@@ -5465,6 +5489,9 @@ struct OperatorBuilder {
void add_custom_options_format(CustomOptionsFormat custom_options_format) {
fbb_.AddElement<int8_t>(Operator::VT_CUSTOM_OPTIONS_FORMAT, static_cast<int8_t>(custom_options_format), 0);
}
+ void add_mutating_variable_inputs(flatbuffers::Offset<flatbuffers::Vector<uint8_t>> mutating_variable_inputs) {
+ fbb_.AddOffset(Operator::VT_MUTATING_VARIABLE_INPUTS, mutating_variable_inputs);
+ }
explicit OperatorBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
@@ -5485,8 +5512,10 @@ inline flatbuffers::Offset<Operator> CreateOperator(
BuiltinOptions builtin_options_type = BuiltinOptions_NONE,
flatbuffers::Offset<void> builtin_options = 0,
flatbuffers::Offset<flatbuffers::Vector<uint8_t>> custom_options = 0,
- CustomOptionsFormat custom_options_format = CustomOptionsFormat_FLEXBUFFERS) {
+ CustomOptionsFormat custom_options_format = CustomOptionsFormat_FLEXBUFFERS,
+ flatbuffers::Offset<flatbuffers::Vector<uint8_t>> mutating_variable_inputs = 0) {
OperatorBuilder builder_(_fbb);
+ builder_.add_mutating_variable_inputs(mutating_variable_inputs);
builder_.add_custom_options(custom_options);
builder_.add_builtin_options(builtin_options);
builder_.add_outputs(outputs);
@@ -5505,7 +5534,8 @@ inline flatbuffers::Offset<Operator> CreateOperatorDirect(
BuiltinOptions builtin_options_type = BuiltinOptions_NONE,
flatbuffers::Offset<void> builtin_options = 0,
const std::vector<uint8_t> *custom_options = nullptr,
- CustomOptionsFormat custom_options_format = CustomOptionsFormat_FLEXBUFFERS) {
+ CustomOptionsFormat custom_options_format = CustomOptionsFormat_FLEXBUFFERS,
+ const std::vector<uint8_t> *mutating_variable_inputs = nullptr) {
return tflite::CreateOperator(
_fbb,
opcode_index,
@@ -5514,7 +5544,8 @@ inline flatbuffers::Offset<Operator> CreateOperatorDirect(
builtin_options_type,
builtin_options,
custom_options ? _fbb.CreateVector<uint8_t>(*custom_options) : 0,
- custom_options_format);
+ custom_options_format,
+ mutating_variable_inputs ? _fbb.CreateVector<uint8_t>(*mutating_variable_inputs) : 0);
}
flatbuffers::Offset<Operator> CreateOperator(flatbuffers::FlatBufferBuilder &_fbb, const OperatorT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
@@ -5885,6 +5916,7 @@ inline void Tensor::UnPackTo(TensorT *_o, const flatbuffers::resolver_function_t
{ auto _e = buffer(); _o->buffer = _e; };
{ auto _e = name(); if (_e) _o->name = _e->str(); };
{ auto _e = quantization(); if (_e) _o->quantization = std::unique_ptr<QuantizationParametersT>(_e->UnPack(_resolver)); };
+ { auto _e = is_variable(); _o->is_variable = _e; };
}
inline flatbuffers::Offset<Tensor> Tensor::Pack(flatbuffers::FlatBufferBuilder &_fbb, const TensorT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
@@ -5900,13 +5932,15 @@ inline flatbuffers::Offset<Tensor> CreateTensor(flatbuffers::FlatBufferBuilder &
auto _buffer = _o->buffer;
auto _name = _o->name.empty() ? 0 : _fbb.CreateString(_o->name);
auto _quantization = _o->quantization ? CreateQuantizationParameters(_fbb, _o->quantization.get(), _rehasher) : 0;
+ auto _is_variable = _o->is_variable;
return tflite::CreateTensor(
_fbb,
_shape,
_type,
_buffer,
_name,
- _quantization);
+ _quantization,
+ _is_variable);
}
inline Conv2DOptionsT *Conv2DOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
@@ -7429,6 +7463,7 @@ inline void Operator::UnPackTo(OperatorT *_o, const flatbuffers::resolver_functi
{ auto _e = builtin_options(); if (_e) _o->builtin_options.value = BuiltinOptionsUnion::UnPack(_e, builtin_options_type(), _resolver); };
{ auto _e = custom_options(); if (_e) { _o->custom_options.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->custom_options[_i] = _e->Get(_i); } } };
{ auto _e = custom_options_format(); _o->custom_options_format = _e; };
+ { auto _e = mutating_variable_inputs(); if (_e) { _o->mutating_variable_inputs.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->mutating_variable_inputs[_i] = _e->Get(_i) != 0; } } };
}
inline flatbuffers::Offset<Operator> Operator::Pack(flatbuffers::FlatBufferBuilder &_fbb, const OperatorT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
@@ -7446,6 +7481,7 @@ inline flatbuffers::Offset<Operator> CreateOperator(flatbuffers::FlatBufferBuild
auto _builtin_options = _o->builtin_options.Pack(_fbb);
auto _custom_options = _o->custom_options.size() ? _fbb.CreateVector(_o->custom_options) : 0;
auto _custom_options_format = _o->custom_options_format;
+ auto _mutating_variable_inputs = _o->mutating_variable_inputs.size() ? _fbb.CreateVector(_o->mutating_variable_inputs) : 0;
return tflite::CreateOperator(
_fbb,
_opcode_index,
@@ -7454,7 +7490,8 @@ inline flatbuffers::Offset<Operator> CreateOperator(flatbuffers::FlatBufferBuild
_builtin_options_type,
_builtin_options,
_custom_options,
- _custom_options_format);
+ _custom_options_format,
+ _mutating_variable_inputs);
}
inline SubGraphT *SubGraph::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
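
As a sketch of how the regenerated builders fit together (hypothetical values; mirrors the [false, true, false, false, true] example from the schema comment):

#include <cstdint>
#include <vector>

#include "tensorflow/contrib/lite/schema/schema_generated.h"

inline void BuildVariableOperatorSketch(flatbuffers::FlatBufferBuilder* fbb) {
  // A tensor flagged as a variable via the new is_variable field.
  std::vector<int32_t> shape = {1, 16};
  auto state_tensor = tflite::CreateTensorDirect(
      *fbb, &shape, tflite::TensorType_FLOAT32, /*buffer=*/0, "state",
      /*quantization=*/0, /*is_variable=*/true);
  // An operator whose 2nd and 5th inputs mutate variable tensors.
  std::vector<uint8_t> mutating = {false, true, false, false, true};
  auto op = tflite::CreateOperatorDirect(
      *fbb, /*opcode_index=*/0, /*inputs=*/nullptr, /*outputs=*/nullptr,
      tflite::BuiltinOptions_NONE, /*builtin_options=*/0,
      /*custom_options=*/nullptr, tflite::CustomOptionsFormat_FLEXBUFFERS,
      &mutating);
  (void)state_tensor;
  (void)op;
}
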
diff --git a/tensorflow/contrib/lite/string_util.cc b/tensorflow/contrib/lite/string_util.cc
index a89776b29f..a316a40b62 100644
--- a/tensorflow/contrib/lite/string_util.cc
+++ b/tensorflow/contrib/lite/string_util.cc
@@ -105,7 +105,7 @@ void DynamicBuffer::WriteToTensor(TfLiteTensor* tensor) {
dims->data[0] = offset_.size() - 1; // Store number of strings.
TfLiteTensorReset(tensor->type, tensor->name, dims, tensor->params,
tensor_buffer, bytes, kTfLiteDynamic, tensor->allocation,
- tensor);
+ tensor->is_variable, tensor);
}
int GetStringCount(const char* raw_buffer) {
diff --git a/tensorflow/contrib/lite/testing/tflite_driver.cc b/tensorflow/contrib/lite/testing/tflite_driver.cc
index f518bf864c..54edfdfb1d 100644
--- a/tensorflow/contrib/lite/testing/tflite_driver.cc
+++ b/tensorflow/contrib/lite/testing/tflite_driver.cc
@@ -285,7 +285,9 @@ bool TfLiteDriver::CheckResults() {
}
void TfLiteDriver::ResetLSTMStateTensors() {
- // This is a workaround for initializing state tensors for LSTM.
+ interpreter_->ResetVariableTensorsToZero();
+
+ // Below is a workaround for initializing state tensors for LSTM.
// TODO(ycling): Refactoring and find a better way to initialize state
// tensors. Maybe write the reset instructions into the test data.
for (auto node_index : interpreter_->execution_plan()) {
@@ -303,13 +305,6 @@ void TfLiteDriver::ResetLSTMStateTensors() {
int node_index = node.outputs->data[i];
ResetTensor(node_index);
}
- } else if (params->kernel_type == kTfLiteLSTMBasicKernel &&
- node.inputs->size == 5) {
- // The 2th and 5th inputs are state tensors.
- for (int i : {1, 4}) {
- int node_index = node.inputs->data[i];
- ResetTensor(node_index);
- }
}
}
}
diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD
index 0789dc9928..dd05c484fa 100644
--- a/tensorflow/contrib/lite/toco/BUILD
+++ b/tensorflow/contrib/lite/toco/BUILD
@@ -213,6 +213,7 @@ cc_library(
"graph_transformations/convert_squeeze_to_reshape.cc",
"graph_transformations/convert_trivial_addn_to_add.cc",
"graph_transformations/convert_trivial_stack_to_reshape.cc",
+ "graph_transformations/convert_trivial_tile_to_concat.cc",
"graph_transformations/convert_trivial_transpose_to_reshape.cc",
"graph_transformations/create_im2col_arrays.cc",
"graph_transformations/dequantize.cc",
@@ -224,6 +225,7 @@ cc_library(
"graph_transformations/fuse_activation_functions.cc",
"graph_transformations/fuse_binary_into_following_affine.cc",
"graph_transformations/fuse_binary_into_preceding_affine.cc",
+ "graph_transformations/fuse_broadcast_into_following_binary.cc",
"graph_transformations/graph_transformations.cc",
"graph_transformations/hardcode_min_max.cc",
"graph_transformations/identify_dilated_conv.cc",
@@ -293,7 +295,6 @@ cc_library(
"graph_transformations/resolve_tensorflow_matmul.cc",
"graph_transformations/resolve_tensorflow_merge.cc",
"graph_transformations/resolve_tensorflow_switch.cc",
- "graph_transformations/resolve_tensorflow_tile.cc",
"graph_transformations/resolve_transpose_attributes.cc",
"graph_transformations/unfuse_activation_functions.cc",
"graph_transformations/unpartition_embedding_lookup.cc",
diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.cc b/tensorflow/contrib/lite/toco/export_tensorflow.cc
index c7c80ab21c..6e5e0d0137 100644
--- a/tensorflow/contrib/lite/toco/export_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/export_tensorflow.cc
@@ -1687,6 +1687,22 @@ void ConvertSelectOperator(const Model& model, const SelectOperator& src_op,
(*sub_op->mutable_attr())["T"].set_type(data_type);
}
+void ConvertTileOperator(const Model& model,
+ const TensorFlowTileOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ auto* tile_op = tensorflow_graph->add_node();
+ tile_op->set_op("Tile");
+ tile_op->set_name(src_op.outputs[0]);
+ CHECK_EQ(src_op.inputs.size(), 2);
+ *tile_op->add_input() = src_op.inputs[0];
+ *tile_op->add_input() = src_op.inputs[1];
+ const auto data_type = GetTensorFlowDataType(model, src_op.inputs[0]);
+ (*tile_op->mutable_attr())["T"].set_type(data_type);
+ const auto multiples_data_type =
+ GetTensorFlowDataType(model, src_op.inputs[1]);
+ (*tile_op->mutable_attr())["Tmultiples"].set_type(multiples_data_type);
+}
+
void ConvertTopKV2Operator(const Model& model, const TopKV2Operator& src_op,
GraphDef* tensorflow_graph) {
auto* topk_op = tensorflow_graph->add_node();
@@ -1953,6 +1969,10 @@ void ConvertOperator(const Model& model, const Operator& src_op,
} else if (src_op.type == OperatorType::kSelect) {
ConvertSelectOperator(model, static_cast<const SelectOperator&>(src_op),
tensorflow_graph);
+ } else if (src_op.type == OperatorType::kTensorFlowTile) {
+ ConvertTileOperator(model,
+ static_cast<const TensorFlowTileOperator&>(src_op),
+ tensorflow_graph);
} else {
LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(src_op.type);
}
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_tile_to_concat.cc b/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_tile_to_concat.cc
new file mode 100644
index 0000000000..5ab399206b
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_tile_to_concat.cc
@@ -0,0 +1,94 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool ConvertTrivialTileToConcat::Run(Model* model, std::size_t op_index) {
+ auto tile_it = model->operators.begin() + op_index;
+ if (tile_it->get()->type != OperatorType::kTensorFlowTile) {
+ return false;
+ }
+  auto* tile_op = static_cast<TensorFlowTileOperator*>(tile_it->get());
+
+ const auto& input_array = model->GetArray(tile_op->inputs[0]);
+ const auto& multiples_array = model->GetArray(tile_op->inputs[1]);
+ const auto& output_array = model->GetArray(tile_op->outputs[0]);
+ if (!input_array.has_shape() || !multiples_array.has_shape() ||
+ !output_array.has_shape()) {
+ // Yield until PropagateFixedSizes has been run on this op.
+ return false;
+ }
+ // Note: We can assume we have error checked inputs in PropagateFixedSizes.
+
+ if (!multiples_array.buffer) {
+ // Yield until the multiples is constant.
+ return false;
+ }
+ std::vector<int32> const& multiples =
+ multiples_array.GetBuffer<ArrayDataType::kInt32>().data;
+
+ // We can simplify the tile if only a single dimension is being multiplied.
+ // It then just becomes a concat along that dimension.
+ int non_one_dims = 0;
+ int concat_axis = 0;
+ for (int i = 0; i < multiples.size(); ++i) {
+ if (multiples[i] != 1) {
+ ++non_one_dims;
+ concat_axis = i;
+ }
+ }
+ if (non_one_dims != 1) {
+ // The tile is non-trivial. Good luck.
+ AddMessageF("Tile %s is non-trivial (has more than one multiply dimension)",
+ LogName(*tile_op));
+ return false;
+ }
+
+ // The tile is like a concat.
+ AddMessageF("Simplifying %s to a Concat along a single axis %d",
+ LogName(*tile_op), concat_axis);
+
+ auto* concat_op = new ConcatenationOperator;
+
+ // Copy input and output.
+ // Note that we multiply out the input by the number of times requested.
+ for (int i = 0; i < multiples[concat_axis]; ++i) {
+ concat_op->inputs.push_back(tile_op->inputs[0]);
+ }
+ concat_op->axis = concat_axis;
+ concat_op->outputs = tile_op->outputs;
+
+ // Delete multiples array if unused.
+ if (IsDiscardableArray(*model, tile_op->inputs[1]) &&
+ CountOpsWithInput(*model, tile_op->inputs[1]) == 1) {
+ model->EraseArray(tile_op->inputs[1]);
+ }
+
+ // Replace the operator in the graph.
+ const auto concat_it = model->operators.emplace(tile_it, concat_op);
+ tile_it = concat_it + 1;
+ CHECK_EQ(tile_it->get(), tile_op);
+ model->operators.erase(tile_it);
+
+ return true;
+}
+
+} // namespace toco
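
For intuition, the triviality test used by this transformation as a standalone sketch (hypothetical helper, not part of the patch): a Tile whose multiples are 1 everywhere except on one axis is just an N-way concat of its input along that axis.

#include <vector>

// Returns the concat axis if exactly one multiple differs from 1, else -1.
// E.g. {1, 1, 3, 1} -> axis 2: Tile(x) == Concat(axis=2, x, x, x).
inline int TrivialTileConcatAxis(const std::vector<int>& multiples) {
  int non_one_dims = 0;
  int concat_axis = -1;
  for (int i = 0; i < static_cast<int>(multiples.size()); ++i) {
    if (multiples[i] != 1) {
      ++non_one_dims;
      concat_axis = i;
    }
  }
  return non_one_dims == 1 ? concat_axis : -1;
}
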
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc b/tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc
index 8ca2cd66ac..1e68cd678b 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc
@@ -25,17 +25,12 @@ limitations under the License.
namespace toco {
-bool CreateIm2colArrays::Run(Model* model, std::size_t op_index) {
- auto conv_it = model->operators.begin() + op_index;
- if (conv_it->get()->type != OperatorType::kConv) {
- return false;
- }
- auto* conv_op = static_cast<ConvOperator*>(conv_it->get());
- if (conv_op->outputs.size() == 2) {
+bool ProcessConvOperator(Model* model, ConvOperator* op) {
+ if (op->outputs.size() == 2) {
// We already have an im2col array
return false;
}
- const auto& weights_array = model->GetArray(conv_op->inputs[1]);
+ const auto& weights_array = model->GetArray(op->inputs[1]);
if (!weights_array.has_shape()) {
// We need to yield until weights dims have been resolved, because
// from the weights dims we determine whether an im2col array is
@@ -45,26 +40,52 @@ bool CreateIm2colArrays::Run(Model* model, std::size_t op_index) {
const auto& weights_shape = weights_array.shape();
const int kheight = weights_shape.dims(1);
const int kwidth = weights_shape.dims(2);
- if (kwidth == 1 && kheight == 1 && conv_op->stride_width == 1 &&
- conv_op->stride_height == 1 && conv_op->dilation_width_factor == 1 &&
- conv_op->dilation_height_factor == 1) {
+ if (kwidth == 1 && kheight == 1 && op->stride_width == 1 &&
+ op->stride_height == 1 && op->dilation_width_factor == 1 &&
+ op->dilation_height_factor == 1) {
// 1x1 unstrided undilated conv does not need an im2col array.
return false;
}
// Create the im2col array.
- CHECK_EQ(conv_op->outputs.size(), 1);
+ CHECK_EQ(op->outputs.size(), 1);
const string& im2col_array_name =
- AvailableArrayName(*model, conv_op->inputs[0] + "_im2col");
+ AvailableArrayName(*model, op->inputs[0] + "_im2col");
model->GetOrCreateArray(im2col_array_name);
- conv_op->outputs.push_back(im2col_array_name);
- AddMessageF(
- "Created an im2col array for %s, with %dx%d kernel and stride_width=%d, "
- "stride_height=%d",
- LogName(*conv_op), kwidth, kheight, conv_op->stride_width,
- conv_op->stride_height);
+ op->outputs.push_back(im2col_array_name);
return true;
}
+bool ProcessTransposeConvOperator(Model* model, TransposeConvOperator* op) {
+ if (op->outputs.size() == 2) {
+ // We already have an im2col array
+ return false;
+ }
+
+ // Always create an im2col array for transpose_conv.
+ CHECK_EQ(op->outputs.size(), 1);
+ const string& im2col_array_name = AvailableArrayName(
+ *model, op->inputs[TransposeConvOperator::DATA_INPUT] + "_im2col");
+ model->GetOrCreateArray(im2col_array_name);
+ op->outputs.push_back(im2col_array_name);
+
+ return true;
+}
+
+bool CreateIm2colArrays::Run(Model* model, std::size_t op_index) {
+ auto it = model->operators.begin() + op_index;
+ auto* op = it->get();
+
+ switch (op->type) {
+ case OperatorType::kConv:
+ return ProcessConvOperator(model, static_cast<ConvOperator*>(op));
+ case OperatorType::kTransposeConv:
+ return ProcessTransposeConvOperator(
+ model, static_cast<TransposeConvOperator*>(op));
+ default:
+ return false;
+ }
+}
+
} // namespace toco
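
The predicate factored out above can be summarized as a standalone sketch (hypothetical helper): only a 1x1, unstrided, undilated convolution can skip the im2col buffer.

// True when the convolution needs an im2col array: any kernel larger than
// 1x1, any stride != 1, or any dilation != 1 forces the im2col lowering.
inline bool ConvNeedsIm2col(int kwidth, int kheight, int stride_width,
                            int stride_height, int dilation_width,
                            int dilation_height) {
  return !(kwidth == 1 && kheight == 1 && stride_width == 1 &&
           stride_height == 1 && dilation_width == 1 && dilation_height == 1);
}
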
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/fuse_broadcast_into_following_binary.cc b/tensorflow/contrib/lite/toco/graph_transformations/fuse_broadcast_into_following_binary.cc
new file mode 100644
index 0000000000..874d8def57
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/fuse_broadcast_into_following_binary.cc
@@ -0,0 +1,102 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+namespace {
+
+// Returns true if the given op is strictly a broadcasting operation.
+// This is commonly seen as a Concat of the same input multiple times, and is
+// often generated from Tile ops that were converted via the
+// convert_trivial_tile_to_concat transformation.
+bool IsBroadcastingOp(const Model& model, Operator* op) {
+ // Concatenation of identical inputs is usually a broadcast.
+ if (op->type == OperatorType::kConcatenation) {
+ // Verify that all inputs are the same.
+ for (int i = 1; i < op->inputs.size(); ++i) {
+ if (op->inputs[i] != op->inputs[0]) {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ // There are other things we could look for (Stack/etc) when needed.
+ return false;
+}
+
+} // namespace
+
+// Finds an operation that looks like a broadcast (concat of the same sources
+// along the last dimension) and drops it by relying on the ability of certain
+// binary ops to perform an implicit broadcast.
+bool FuseBroadcastIntoFollowingBinary::Run(Model* model, std::size_t op_index) {
+ const auto binary_it = model->operators.begin() + op_index;
+ auto* binary_op = binary_it->get();
+
+ // Test for binary ops of types that we know how to resolve
+ if (binary_op->inputs.size() != 2) {
+ return false;
+ }
+ if (binary_op->type != OperatorType::kAdd &&
+ binary_op->type != OperatorType::kMul &&
+ binary_op->type != OperatorType::kSub &&
+ binary_op->type != OperatorType::kDiv) {
+ return false;
+ }
+
+ // NOTE: either of these ops may be nullptr if the input array is constant.
+ Operator* const op[2] = {
+ GetOpWithOutput(*model, binary_op->inputs[0]),
+ GetOpWithOutput(*model, binary_op->inputs[1]),
+ };
+
+ // Check whether either input is a broadcast-like concat.
+ bool is_op_0_broadcast = op[0] && IsBroadcastingOp(*model, op[0]);
+ bool is_op_1_broadcast = op[1] && IsBroadcastingOp(*model, op[1]);
+ if (!is_op_0_broadcast && !is_op_1_broadcast) {
+    // Neither input is a broadcast-like op; nothing to fuse here.
+    AddMessageF("Neither input looks like a broadcast");
+ return false;
+ } else if (is_op_0_broadcast && is_op_1_broadcast) {
+ AddMessageF(
+ "Unable to fuse broadcast into %s as both inputs (%s, %s) are "
+ "broadcasts",
+ LogName(*binary_op), op[0] ? LogName(*op[0]) : "(?)",
+ op[1] ? LogName(*op[1]) : "(?)");
+ return false;
+ }
+ int broadcast_index = is_op_0_broadcast ? 0 : 1;
+
+ // Just pull out the input of the broadcast op and pass it directly to the
+ // binary op.
+ AddMessageF("Fusing broadcast op %s into the following binary %s",
+ LogName(*op[broadcast_index]), LogName(*binary_op));
+ binary_op->inputs[broadcast_index] = op[broadcast_index]->inputs[0];
+
+ // We leave the broadcast op in; it'll get cleaned up if it's not used later.
+ return true;
+}
+
+} // namespace toco
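
Conceptually the rewrite turns Concat(x, x, ..., x) followed by a binary op into the binary op applied to x directly, relying on the op's implicit broadcast. A standalone sketch of the "all inputs identical" test (hypothetical helper):

#include <string>
#include <vector>

// A Concat whose inputs all name the same array acts as a pure broadcast
// along the concat axis, so a following Add/Mul/Sub/Div can read the
// original array and broadcast implicitly.
inline bool AllConcatInputsIdentical(const std::vector<std::string>& inputs) {
  for (size_t i = 1; i < inputs.size(); ++i) {
    if (inputs[i] != inputs[0]) return false;
  }
  return true;
}
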
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
index 1bc7557d46..62a09acdfb 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
+++ b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
@@ -117,12 +117,14 @@ DECLARE_GRAPH_TRANSFORMATION(ConvertPureConvToDepthwise)
DECLARE_GRAPH_TRANSFORMATION(ConvertSqueezeToReshape)
DECLARE_GRAPH_TRANSFORMATION(ConvertTrivialAddNToAdd)
DECLARE_GRAPH_TRANSFORMATION(ConvertTrivialStackToReshape)
+DECLARE_GRAPH_TRANSFORMATION(ConvertTrivialTileToConcat)
DECLARE_GRAPH_TRANSFORMATION(ConvertTrivialTransposeToReshape)
DECLARE_GRAPH_TRANSFORMATION(ConvertReorderAxes)
DECLARE_GRAPH_TRANSFORMATION(EnsureBiasVectors)
DECLARE_GRAPH_TRANSFORMATION(FuseActivationFunctions)
DECLARE_GRAPH_TRANSFORMATION(FuseBinaryIntoFollowingAffine)
DECLARE_GRAPH_TRANSFORMATION(FuseBinaryIntoPrecedingAffine)
+DECLARE_GRAPH_TRANSFORMATION(FuseBroadcastIntoFollowingBinary)
DECLARE_GRAPH_TRANSFORMATION(IdentifyL2Normalization)
DECLARE_GRAPH_TRANSFORMATION(IdentifyL2Pool)
DECLARE_GRAPH_TRANSFORMATION(IdentifyLstmCell)
@@ -165,7 +167,6 @@ DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowMatMul)
DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowMerge)
DECLARE_GRAPH_TRANSFORMATION(ResolveSqueezeAttributes)
DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowSwitch)
-DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowTile)
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantConcatenation)
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantReshape)
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantTranspose)
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc
index 6d51fc8c31..77c0886811 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc
@@ -103,6 +103,7 @@ bool DoesOpBlockBackwardPropagation(const Operator& op) {
case OperatorType::kTensorFlowReshape:
case OperatorType::kTranspose:
case OperatorType::kSelect:
+ case OperatorType::kTensorFlowTile:
// Reshapes and transposes don't change values.
return false;
default:
@@ -124,6 +125,9 @@ bool DoesOpInputBlockBackwardPropagation(const Operator& op, int input_index) {
case OperatorType::kTranspose:
// Ignore reshape/transpose shapes/dimensions.
return input_index != 0;
+ case OperatorType::kTensorFlowTile:
+ // Ignore tile multiples.
+ return input_index != 0;
default:
return false;
}
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
index 170a499d4e..e7da9051d8 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
@@ -211,12 +211,6 @@ void ProcessTransposeConvOperator(Model* model, TransposeConvOperator* op) {
// might as well calculate the output shape and ensure it matches the
// specified one
- // Check if we have already run.
- auto& output_array = model->GetArray(op->outputs[0]);
- if (output_array.has_shape()) {
- return;
- }
-
// SPECIFIED OUTPUT SHAPE
// The below is the specified, or prescribed output shape, _given_ to the
// operator as an input.
@@ -284,7 +278,17 @@ void ProcessTransposeConvOperator(Model* model, TransposeConvOperator* op) {
// Set the output shape according to the specified output shape.
std::vector<int32> const& specified_output_shape =
specified_output_shape_array.GetBuffer<ArrayDataType::kInt32>().data;
+ auto& output_array = model->GetArray(op->outputs[0]);
*(output_array.mutable_shape()->mutable_dims()) = specified_output_shape;
+
+ // Set im2col array dimensions if there is one.
+ if (op->outputs.size() == 2) {
+ const int input_depth = weights_shape.dims(3);
+ auto& im2col_array = model->GetArray(op->outputs[1]);
+ im2col_array.copy_shape(
+ Shape{specified_output_shape[0], specified_output_shape[1],
+ specified_output_shape[2], input_depth * kheight * kwidth});
+ }
}
void ProcessDepthwiseConvOperator(Model* model, DepthwiseConvOperator* op) {
@@ -1505,6 +1509,48 @@ void ProcessSparseToDenseOperator(Model* model, SparseToDenseOperator* op) {
}
}
+void ProcessTileOperator(Model* model, TensorFlowTileOperator* op) {
+ CHECK_EQ(op->inputs.size(), 2);
+ CHECK_EQ(op->outputs.size(), 1);
+
+ auto& output_array = model->GetArray(op->outputs[0]);
+ if (output_array.has_shape()) {
+ // We have already run.
+ return;
+ }
+
+ const auto& input_array = model->GetArray(op->inputs[0]);
+ if (!input_array.has_shape()) {
+ // Yield until input dims have been resolved.
+ return;
+ }
+ const auto& input_shape = input_array.shape();
+
+ auto& multiples_array = model->GetArray(op->inputs[1]);
+ if (!multiples_array.has_shape()) {
+    // Yield until the multiples shape has been resolved.
+ return;
+ }
+ if (!multiples_array.buffer) {
+ // Yield until the multiples is constant.
+ return;
+ }
+ CHECK(multiples_array.data_type == ArrayDataType::kInt32)
+ << "Tile multiples input must be int32";
+
+ std::vector<int32> const& multiples =
+ multiples_array.GetBuffer<ArrayDataType::kInt32>().data;
+ CHECK_EQ(multiples.size(), input_shape.dimensions_count())
+ << "Tile multiples input " << op->inputs[1]
+ << " must be same length as input dimensions";
+
+ auto* mutable_dims = output_array.mutable_shape()->mutable_dims();
+ mutable_dims->resize(multiples.size());
+ for (int i = 0; i < mutable_dims->size(); ++i) {
+ (*mutable_dims)[i] = input_shape.dims(i) * multiples[i];
+ }
+}
+
} // namespace
bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
@@ -1623,14 +1669,6 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
ProcessSliceOperator(model, static_cast<SliceOperator*>(op));
break;
- case OperatorType::kTensorFlowTile:
- // We don't currently implement the propagation of fixed sizes through
- // a TensorFlow Tile.
- //
- // Fortunately, we don't need to: so far, we have only dealt with Tile
- // or Slice ops in subgraphs that are identified as L2Normalization.
- // See IdentifyL2Normalization.
- break;
case OperatorType::kTensorFlowSwitch:
// We can't know the sizes of the outputs until we have resolved the
// predicate, and once we have resolved the predicate, the whole
@@ -1734,6 +1772,9 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
ProcessSparseToDenseOperator(model,
static_cast<SparseToDenseOperator*>(op));
break;
+ case OperatorType::kTensorFlowTile:
+ ProcessTileOperator(model, static_cast<TensorFlowTileOperator*>(op));
+ break;
default:
// Unimplemented, another graph transformation should drop it.
LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(op->type);
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_tile.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_tile.cc
deleted file mode 100644
index 1ddf54c778..0000000000
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_tile.cc
+++ /dev/null
@@ -1,97 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include <memory>
-#include <string>
-#include <unordered_map>
-#include <vector>
-
-#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
-#include "tensorflow/contrib/lite/toco/model.h"
-#include "tensorflow/contrib/lite/toco/tooling_util.h"
-#include "tensorflow/core/platform/logging.h"
-
-namespace toco {
-
-namespace {
-
-void RemoveTileOperator(Model* model, Operator* tile_op, Operator* binary_op,
- int operand_index) {
- CHECK(tile_op->type == OperatorType::kTensorFlowTile);
- CHECK_EQ(binary_op->inputs.size(), 2);
- CHECK_EQ(tile_op->inputs.size(), 2);
- const string tile_multiplier_array = tile_op->inputs[1];
- const string tile_output_array = tile_op->outputs[0];
- binary_op->inputs[operand_index] = tile_op->inputs[0];
- auto tile_it = model->operators.begin();
- for (; tile_it != model->operators.end(); ++tile_it) {
- if (tile_it->get() == tile_op) {
- break;
- }
- }
- CHECK(tile_it != model->operators.end());
- CHECK(tile_it->get() == tile_op);
- model->operators.erase(tile_it);
- if (!CountOpsWithInput(*model, tile_multiplier_array) &&
- !GetOpWithOutput(*model, tile_multiplier_array)) {
- model->EraseArray(tile_multiplier_array);
- }
- if (!CountOpsWithInput(*model, tile_output_array)) {
- model->EraseArray(tile_output_array);
- }
-}
-} // namespace
-
-bool ResolveTensorFlowTile::Run(Model* model, std::size_t op_index) {
- const auto binary_it = model->operators.begin() + op_index;
- auto* binary_op = binary_it->get();
- // Test for binary ops of types that we know how to resolve
- if (binary_op->inputs.size() != 2) {
- return false;
- }
- if (binary_op->type != OperatorType::kAdd &&
- binary_op->type != OperatorType::kMul &&
- binary_op->type != OperatorType::kSub &&
- binary_op->type != OperatorType::kDiv) {
- return false;
- }
-
- Operator* const op[2] = {
- GetOpWithOutput(*model, binary_op->inputs[0]),
- GetOpWithOutput(*model, binary_op->inputs[1]),
- };
-
- // In the unlikely case where both operands are Tile, we can't infer the
- // output
- // size without the Tile nodes, so we have to bail out.
- if (op[0] && op[0]->type == OperatorType::kTensorFlowTile && op[1] &&
- op[1]->type == OperatorType::kTensorFlowTile) {
- return false;
- }
-
- for (int i = 0; i < 2; i++) {
- if (op[i] && op[i]->type == OperatorType::kTensorFlowTile) {
- // We can only remove a Tile operator is no other op than the present
- // binary op was consuming its tiled output.
- if (CountOpsWithInput(*model, binary_op->inputs[i]) == 1) {
- AddMessageF("Removing %s", LogName(*op[i]));
- RemoveTileOperator(model, op[i], binary_op, i);
- return true;
- }
- }
- }
- return false;
-}
-
-} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h
index 2f43adb07b..7bdec47aa9 100644
--- a/tensorflow/contrib/lite/toco/model.h
+++ b/tensorflow/contrib/lite/toco/model.h
@@ -1222,8 +1222,10 @@ struct TensorFlowSumOperator : Operator {
};
// TensorFlow Tile equivalent. Refer to TensorFlow documentation for details.
-// Not fully supported, just a placeholder to handle TensorFlow graphs and
-// support graph transformations to other operator types by matching sub-graphs.
+//
+// Inputs:
+// inputs[0]: required: the input array
+//   inputs[1]: required: 1-D int32 array of per-axis multiples, whose length
+//              equals the rank of inputs[0]
struct TensorFlowTileOperator : Operator {
TensorFlowTileOperator() : Operator(OperatorType::kTensorFlowTile) {}
};
diff --git a/tensorflow/contrib/lite/toco/tflite/export.cc b/tensorflow/contrib/lite/toco/tflite/export.cc
index a2d753657b..7ba2603a95 100644
--- a/tensorflow/contrib/lite/toco/tflite/export.cc
+++ b/tensorflow/contrib/lite/toco/tflite/export.cc
@@ -99,7 +99,8 @@ void LoadOperatorsMap(
Offset<Vector<Offset<Tensor>>> ExportTensors(
const Model& model, const details::TensorsMap& tensors_map,
- FlatBufferBuilder* builder, std::vector<const Array*>* buffers_to_write) {
+ FlatBufferBuilder* builder, std::vector<const Array*>* buffers_to_write,
+ const std::set<int32_t>& variable_tensor_indices) {
// In the end we will need to produce a vector sorted by the indices of the
// tensors in the tensors_map.
std::map<int, Offset<Tensor>> ordered_tensors;
@@ -139,9 +140,11 @@ Offset<Vector<Offset<Tensor>>> ExportTensors(
scale, zero_point);
int index = tensors_map.at(tensor_name);
+ bool is_variable =
+ variable_tensor_indices.find(index) != variable_tensor_indices.end();
ordered_tensors[index] =
CreateTensor(*builder, builder->CreateVector(shape), type, buffer_index,
- builder->CreateString(tensor_name), q_param);
+ builder->CreateString(tensor_name), q_param, is_variable);
}
std::vector<Offset<Tensor>> tensor_vector;
@@ -239,7 +242,10 @@ Offset<Vector<Offset<Operator>>> ExportOperators(
const Model& model,
const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type,
const details::OperatorsMap& operators_map,
- const details::TensorsMap& tensors_map, FlatBufferBuilder* builder) {
+ const details::TensorsMap& tensors_map, FlatBufferBuilder* builder,
+ std::set<int32_t>* variable_tensor_indices) {
+ variable_tensor_indices->clear();
+
// The operators are in execution order, so we just follow tf.mini order.
std::vector<Offset<Operator>> op_vector;
for (const auto& op : model.operators) {
@@ -256,18 +262,36 @@ Offset<Vector<Offset<Operator>>> ExportOperators(
int op_index = operators_map.at(GetOperatorKey(*op, ops_by_type));
+ auto tflite_op_it = ops_by_type.find(op->type);
+ BaseOperator* tflite_op = tflite_op_it == ops_by_type.end()
+ ? nullptr
+ : tflite_op_it->second.get();
+
// This is a custom op unless we can find it in ops_by_type, and even then
// it could be a custom op (such as kTensorFlowUnsupported).
-
auto options = Options::Custom(0);
- if (ops_by_type.count(op->type) != 0) {
- options = ops_by_type.at(op->type)->Serialize(*op, builder);
+
+ std::vector<bool> mutating_input_variables;
+ if (tflite_op) {
+ options = tflite_op->Serialize(*op, builder);
+ mutating_input_variables = tflite_op->GetMutatingInputVariables(*op);
+
+ if (!mutating_input_variables.empty()) {
+ for (int i = 0; i < op->inputs.size(); ++i) {
+ if (!mutating_input_variables[i]) {
+ continue;
+ }
+ int32_t variable_tensor_index = tensors_map.at(op->inputs[i]);
+ variable_tensor_indices->insert(variable_tensor_index);
+ }
+ }
}
// The only supported CustomOptionFormat is FLEXBUFFERS now.
op_vector.push_back(CreateOperator(
*builder, op_index, builder->CreateVector(inputs),
builder->CreateVector(outputs), options.type, options.builtin,
- options.custom, ::tflite::CustomOptionsFormat_FLEXBUFFERS));
+ options.custom, ::tflite::CustomOptionsFormat_FLEXBUFFERS,
+ builder->CreateVector(mutating_input_variables)));
}
return builder->CreateVector(op_vector);
@@ -308,13 +332,10 @@ void Export(
Array empty_array;
buffers_to_write.push_back(&empty_array);
- auto tensors = ExportTensors(model, tensors_map, &builder, &buffers_to_write);
- auto inputs = ExportInputTensors(model, tensors_map, &builder);
- auto outputs = ExportOutputTensors(model, tensors_map, &builder);
-
std::set<string> error_summary;
auto op_codes = ExportOperatorCodes(model, ops_by_type, operators_map,
&builder, &error_summary);
+
const string fake_quant_operation_name = "FAKE_QUANT";
if (error_summary.count(fake_quant_operation_name) != 0) {
@@ -353,11 +374,18 @@ void Export(
<< absl::StrJoin(error_summary_final, ", ") << ".";
}
- auto ops =
- ExportOperators(model, ops_by_type, operators_map, tensors_map, &builder);
+ std::set<int32_t> variable_tensor_indices;
+ auto ops = ExportOperators(model, ops_by_type, operators_map, tensors_map,
+ &builder, &variable_tensor_indices);
+
+ auto tensors = ExportTensors(model, tensors_map, &builder, &buffers_to_write,
+ variable_tensor_indices);
+ auto inputs = ExportInputTensors(model, tensors_map, &builder);
+ auto outputs = ExportOutputTensors(model, tensors_map, &builder);
// TODO(aselle): add support to toco for multiple subgraphs.
- auto subgraph = CreateSubGraph(builder, tensors, inputs, outputs, ops);
+ auto subgraph = CreateSubGraph(builder, tensors, inputs, outputs, ops,
+ /* name */ 0);
std::vector<flatbuffers::Offset<SubGraph>> subgraphs = {subgraph};
auto buffers = ExportBuffers(model, buffers_to_write, &builder);
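
The reordering above exists because operators are now exported first, so that the variable-tensor indices they report can be applied while exporting tensors. A simplified sketch of the collection step (hypothetical helper with simplified types):

#include <set>
#include <vector>

// For one operator: mark each input flagged by GetMutatingInputVariables()
// as a variable tensor, keyed by its index in the tensors map.
inline void CollectVariableTensorsSketch(
    const std::vector<bool>& mutating_input_variables,
    const std::vector<int>& op_input_tensor_indices,
    std::set<int>* variable_tensor_indices) {
  for (size_t i = 0; i < mutating_input_variables.size(); ++i) {
    if (mutating_input_variables[i]) {
      variable_tensor_indices->insert(op_input_tensor_indices[i]);
    }
  }
}
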
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc
index 7490ab960b..a0fbb58aca 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator.cc
@@ -668,6 +668,24 @@ class Lstm : public BuiltinOperator<LstmCellOperator, ::tflite::LSTMOptions,
return 2;
}
}
+
+ std::vector<bool> GetMutatingInputVariables(
+ const Operator& op) const override {
+ const auto& lstm_op = static_cast<const LstmCellOperator&>(op);
+
+ switch (lstm_op.kernel_type) {
+ case LstmCellOperator::KERNEL_FULL:
+ // TODO(ycling): Change the full kernel to use the new variable tensor
+ // design. This requires moving the state tensors from output to input.
+ return std::vector<bool>();
+ case LstmCellOperator::KERNEL_BASIC: {
+ std::vector<bool> mutating_input_variables(op.inputs.size(), false);
+ mutating_input_variables[LstmCellOperator::PREV_ACTIV_INPUT] = true;
+ mutating_input_variables[LstmCellOperator::PREV_STATE_INPUT] = true;
+ return mutating_input_variables;
+ }
+ }
+ }
};
class Mean : public BuiltinOperator<MeanOperator, ::tflite::MeanOptions,
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.h b/tensorflow/contrib/lite/toco/tflite/operator.h
index 5e9c20e40d..d9ea23edf2 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator.h
+++ b/tensorflow/contrib/lite/toco/tflite/operator.h
@@ -87,6 +87,17 @@ class BaseOperator {
// overridden. (See example in `operator_test.cc`)
virtual int GetVersion(const Operator& op) const = 0;
+  // Given a Toco `Operator`, returns a list of booleans indicating which of
+  // the op's inputs are variable tensors mutated by the op.
+  // * If the op mutates any input variables, the returned list has the same
+  //   length as `inputs`.
+  // * Otherwise, the returned list is empty.
+ virtual std::vector<bool> GetMutatingInputVariables(
+ const Operator& op) const {
+ // Most ops don't have variable tensors. This function can be overridden.
+ return std::vector<bool>();
+ }
+
private:
string name_;
OperatorType type_;
diff --git a/tensorflow/contrib/lite/toco/tflite/types.cc b/tensorflow/contrib/lite/toco/tflite/types.cc
index 4867c3a62e..42c5d7e8eb 100644
--- a/tensorflow/contrib/lite/toco/tflite/types.cc
+++ b/tensorflow/contrib/lite/toco/tflite/types.cc
@@ -88,6 +88,8 @@ void CopyBuffer(const ::tflite::Buffer& buffer, Array* array) {
switch (array_data_type) {
case ArrayDataType::kFloat:
return ::tflite::TensorType_FLOAT32;
+ case ArrayDataType::kInt16:
+ return ::tflite::TensorType_INT16;
case ArrayDataType::kInt32:
return ::tflite::TensorType_INT32;
case ArrayDataType::kInt64:
@@ -109,6 +111,8 @@ ArrayDataType DataType::Deserialize(int tensor_type) {
switch (::tflite::TensorType(tensor_type)) {
case ::tflite::TensorType_FLOAT32:
return ArrayDataType::kFloat;
+ case ::tflite::TensorType_INT16:
+ return ArrayDataType::kInt16;
case ::tflite::TensorType_INT32:
return ArrayDataType::kInt32;
case ::tflite::TensorType_INT64:
@@ -131,6 +135,8 @@ flatbuffers::Offset<flatbuffers::Vector<uint8_t>> DataBuffer::Serialize(
switch (array.data_type) {
case ArrayDataType::kFloat:
return CopyBuffer<ArrayDataType::kFloat>(array, builder);
+ case ArrayDataType::kInt16:
+ return CopyBuffer<ArrayDataType::kInt16>(array, builder);
case ArrayDataType::kInt32:
return CopyBuffer<ArrayDataType::kInt32>(array, builder);
case ArrayDataType::kInt64:
@@ -154,6 +160,8 @@ void DataBuffer::Deserialize(const ::tflite::Tensor& tensor,
switch (tensor.type()) {
case ::tflite::TensorType_FLOAT32:
return CopyBuffer<ArrayDataType::kFloat>(buffer, array);
+ case ::tflite::TensorType_INT16:
+ return CopyBuffer<ArrayDataType::kInt16>(buffer, array);
case ::tflite::TensorType_INT32:
return CopyBuffer<ArrayDataType::kInt32>(buffer, array);
case ::tflite::TensorType_INT64:
diff --git a/tensorflow/contrib/lite/toco/tflite/types_test.cc b/tensorflow/contrib/lite/toco/tflite/types_test.cc
index 564f303b9b..8c6ef95bfa 100644
--- a/tensorflow/contrib/lite/toco/tflite/types_test.cc
+++ b/tensorflow/contrib/lite/toco/tflite/types_test.cc
@@ -151,6 +151,12 @@ TEST(DataBuffer, Int32) {
::testing::ElementsAre(1, 1 << 30));
}
+TEST(DataBuffer, Int16) {
+ Array recovered = ToFlatBufferAndBack<ArrayDataType::kInt16>({1, 1 << 14});
+ EXPECT_THAT(recovered.GetBuffer<ArrayDataType::kInt16>().data,
+ ::testing::ElementsAre(1, 1 << 14));
+}
+
TEST(DataBuffer, String) {
Array recovered = ToFlatBufferAndBack<ArrayDataType::kString>(
{"AA", "BBB", "Best. String. Ever."});
diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc
index 1fe76f8163..3173d524b7 100644
--- a/tensorflow/contrib/lite/toco/toco_tooling.cc
+++ b/tensorflow/contrib/lite/toco/toco_tooling.cc
@@ -56,6 +56,7 @@ void MakeGeneralGraphTransformationsSet(
transformations->Add(new ConvertSqueezeToReshape);
transformations->Add(new ConvertTrivialAddNToAdd);
transformations->Add(new ConvertTrivialStackToReshape);
+ transformations->Add(new ConvertTrivialTileToConcat);
transformations->Add(new ConvertTrivialTransposeToReshape);
transformations->Add(new ConvertReorderAxes);
transformations->Add(new ResolveReshapeAttributes);
@@ -76,6 +77,7 @@ void MakeGeneralGraphTransformationsSet(
transformations->Add(new ResolveTensorFlowMatMul);
transformations->Add(new FuseBinaryIntoPrecedingAffine);
transformations->Add(new FuseBinaryIntoFollowingAffine);
+ transformations->Add(new FuseBroadcastIntoFollowingBinary);
transformations->Add(new MergeReshapeIntoPrecedingTranspose);
transformations->Add(new ReorderElementwiseUnary);
transformations->Add(new ReorderReshapeTranspose);
@@ -94,7 +96,6 @@ void MakeGeneralGraphTransformationsSet(
transformations->Add(new ResolveTensorFlowMerge);
transformations->Add(new ResolveSqueezeAttributes);
transformations->Add(new ResolveTensorFlowSwitch);
- transformations->Add(new ResolveTensorFlowTile);
transformations->Add(new ResolveTensorFlowConcat);
transformations->Add(new ResolveMultiplyByZero);
transformations->Add(new IdentifyDilatedConv);
diff --git a/tensorflow/contrib/metrics/BUILD b/tensorflow/contrib/metrics/BUILD
index 4f2c82ca23..66cb493e5c 100644
--- a/tensorflow/contrib/metrics/BUILD
+++ b/tensorflow/contrib/metrics/BUILD
@@ -77,7 +77,31 @@ py_test(
py_test(
name = "metric_ops_test",
srcs = ["python/ops/metric_ops_test.py"],
- shard_count = 16,
+ shard_count = 30,
+ srcs_version = "PY2AND3",
+ tags = ["noasan"], # times out b/63678675
+ deps = [
+ ":metrics_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:data_flow_ops",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform_test",
+ "//tensorflow/python:random_ops",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python:variables",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "metric_ops_large_test",
+ size = "large",
+ srcs = ["python/ops/metric_ops_large_test.py"],
srcs_version = "PY2AND3",
tags = ["noasan"], # times out b/63678675
deps = [
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_large_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_large_test.py
new file mode 100644
index 0000000000..7acfc383eb
--- /dev/null
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops_large_test.py
@@ -0,0 +1,66 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Large tests for metric_ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+from six.moves import xrange # pylint: disable=redefined-builtin
+from tensorflow.contrib.metrics.python.ops import metric_ops
+from tensorflow.python.framework import dtypes as dtypes_lib
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+
+class StreamingPrecisionRecallAtEqualThresholdsLargeTest(test.TestCase):
+
+ def setUp(self):
+ np.random.seed(1)
+ ops.reset_default_graph()
+
+ def testLargeCase(self):
+ shape = [32, 512, 256, 1]
+ predictions = random_ops.random_uniform(
+ shape, 0.0, 1.0, dtype=dtypes_lib.float32)
+ labels = math_ops.greater(random_ops.random_uniform(shape, 0.0, 1.0), 0.5)
+
+ result, update_op = metric_ops.precision_recall_at_equal_thresholds(
+ labels=labels, predictions=predictions, num_thresholds=201)
+ # Run many updates, enough to cause highly inaccurate values if the
+ # code used float32 for accumulation.
+ num_updates = 71
+
+ with self.test_session() as sess:
+ sess.run(variables.local_variables_initializer())
+ for _ in xrange(num_updates):
+ sess.run(update_op)
+
+ prdata = sess.run(result)
+
+ # Since we use random values, we won't know the tp/fp/tn/fn values, but
+ # tp and fp at threshold 0 should be the total number of positive and
+ # negative labels, hence their sum should be total number of pixels.
+ expected_value = 1.0 * np.product(shape) * num_updates
+ got_value = prdata.tp[0] + prdata.fp[0]
+ # They should be at least within 1.
+ self.assertNear(got_value, expected_value, 1.0)
+
+if __name__ == '__main__':
+ test.main()
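
A quick back-of-the-envelope check of the comment in testLargeCase about float32 accumulation (a standalone sketch, not part of this change):

    import numpy as np

    shape = [32, 512, 256, 1]
    num_updates = 71
    # Total counts accumulated into the tp/fp buckets across all updates.
    total = float(np.prod(shape)) * num_updates
    print(total)  # ~2.98e8
    # float32 stops representing consecutive integers exactly at 2**24
    # (~1.68e7), so increments past that point would silently round away.
    big = np.float32(2.0**24)
    print(big + np.float32(1.0) == big)  # True
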
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
index b13f08a37d..e720097636 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
@@ -2391,34 +2391,6 @@ class StreamingPrecisionRecallAtEqualThresholdsTest(test.TestCase):
for _ in range(3):
self._testResultsEqual(initial_result, result)
- def testLargeCase(self):
- self.skipTest("Test consistently timing out")
- shape = [32, 512, 256, 1]
- predictions = random_ops.random_uniform(
- shape, 0.0, 1.0, dtype=dtypes_lib.float32)
- labels = math_ops.greater(random_ops.random_uniform(shape, 0.0, 1.0), 0.5)
-
- result, update_op = metric_ops.precision_recall_at_equal_thresholds(
- labels=labels, predictions=predictions, num_thresholds=201)
- # Run many updates, enough to cause highly inaccurate values if the
- # code used float32 for accumulation.
- num_updates = 71
-
- with self.test_session() as sess:
- sess.run(variables.local_variables_initializer())
- for _ in xrange(num_updates):
- sess.run(update_op)
-
- prdata = sess.run(result)
-
- # Since we use random values, we won't know the tp/fp/tn/fn values, but
- # tp and fp at threshold 0 should be the total number of positive and
- # negative labels, hence their sum should be total number of pixels.
- expected_value = 1.0 * np.product(shape) * num_updates
- got_value = prdata.tp[0] + prdata.fp[0]
- # They should be at least within 1.
- self.assertNear(got_value, expected_value, 1.0)
-
def _testCase(self,
predictions,
labels,
@@ -4727,199 +4699,204 @@ class StreamingSparseRecallTest(test.TestCase):
self._test_sparse_recall_at_top_k(
labels, top_k_predictions, expected=1.0 / 2)
- def test_one_label_at_k1_weighted(self):
+ def _test_one_label_at_k1_weighted(self, labels):
predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]]
top_k_predictions = [[3], [3]]
- sparse_labels = _binary_2d_label_to_sparse_value([[0, 0, 0, 1],
- [0, 0, 1, 0]])
- dense_labels = np.array([[3], [2]], dtype=np.int64)
- for labels in (sparse_labels, dense_labels):
- # Class 3: 1 label, 2 predictions, 1 correct.
- self._test_streaming_sparse_recall_at_k(
- predictions, labels, k=1, expected=NAN, class_id=3, weights=(0.0,))
- self._test_sparse_recall_at_top_k(
- labels, top_k_predictions, expected=NAN, class_id=3, weights=(0.0,))
- self._test_streaming_sparse_recall_at_k(
- predictions,
- labels,
- k=1,
- expected=1.0 / 1,
- class_id=3,
- weights=(1.0,))
- self._test_sparse_recall_at_top_k(
- labels,
- top_k_predictions,
- expected=1.0 / 1,
- class_id=3,
- weights=(1.0,))
- self._test_streaming_sparse_recall_at_k(
- predictions,
- labels,
- k=1,
- expected=1.0 / 1,
- class_id=3,
- weights=(2.0,))
- self._test_sparse_recall_at_top_k(
- labels,
- top_k_predictions,
- expected=1.0 / 1,
- class_id=3,
- weights=(2.0,))
- self._test_streaming_sparse_recall_at_k(
- predictions,
- labels,
- k=1,
- expected=NAN,
- class_id=3,
- weights=(0.0, 0.0))
- self._test_sparse_recall_at_top_k(
- labels,
- top_k_predictions,
- expected=NAN,
- class_id=3,
- weights=(0.0, 0.0))
- self._test_streaming_sparse_recall_at_k(
- predictions,
- labels,
- k=1,
- expected=NAN,
- class_id=3,
- weights=(0.0, 1.0))
- self._test_sparse_recall_at_top_k(
- labels,
- top_k_predictions,
- expected=NAN,
- class_id=3,
- weights=(0.0, 1.0))
- self._test_streaming_sparse_recall_at_k(
- predictions,
- labels,
- k=1,
- expected=1.0 / 1,
- class_id=3,
- weights=(1.0, 0.0))
- self._test_sparse_recall_at_top_k(
- labels,
- top_k_predictions,
- expected=1.0 / 1,
- class_id=3,
- weights=(1.0, 0.0))
- self._test_streaming_sparse_recall_at_k(
- predictions,
- labels,
- k=1,
- expected=1.0 / 1,
- class_id=3,
- weights=(1.0, 1.0))
- self._test_sparse_recall_at_top_k(
- labels,
- top_k_predictions,
- expected=1.0 / 1,
- class_id=3,
- weights=(1.0, 1.0))
- self._test_streaming_sparse_recall_at_k(
- predictions,
- labels,
- k=1,
- expected=2.0 / 2,
- class_id=3,
- weights=(2.0, 3.0))
- self._test_sparse_recall_at_top_k(
- labels,
- top_k_predictions,
- expected=2.0 / 2,
- class_id=3,
- weights=(2.0, 3.0))
- self._test_streaming_sparse_recall_at_k(
- predictions,
- labels,
- k=1,
- expected=3.0 / 3,
- class_id=3,
- weights=(3.0, 2.0))
- self._test_sparse_recall_at_top_k(
- labels,
- top_k_predictions,
- expected=3.0 / 3,
- class_id=3,
- weights=(3.0, 2.0))
- self._test_streaming_sparse_recall_at_k(
- predictions,
- labels,
- k=1,
- expected=0.3 / 0.3,
- class_id=3,
- weights=(0.3, 0.6))
- self._test_sparse_recall_at_top_k(
- labels,
- top_k_predictions,
- expected=0.3 / 0.3,
- class_id=3,
- weights=(0.3, 0.6))
- self._test_streaming_sparse_recall_at_k(
- predictions,
- labels,
- k=1,
- expected=0.6 / 0.6,
- class_id=3,
- weights=(0.6, 0.3))
- self._test_sparse_recall_at_top_k(
- labels,
- top_k_predictions,
- expected=0.6 / 0.6,
- class_id=3,
- weights=(0.6, 0.3))
+ # Class 3: 1 label, 2 predictions, 1 correct.
+ self._test_streaming_sparse_recall_at_k(
+ predictions, labels, k=1, expected=NAN, class_id=3, weights=(0.0,))
+ self._test_sparse_recall_at_top_k(
+ labels, top_k_predictions, expected=NAN, class_id=3, weights=(0.0,))
+ self._test_streaming_sparse_recall_at_k(
+ predictions,
+ labels,
+ k=1,
+ expected=1.0 / 1,
+ class_id=3,
+ weights=(1.0,))
+ self._test_sparse_recall_at_top_k(
+ labels,
+ top_k_predictions,
+ expected=1.0 / 1,
+ class_id=3,
+ weights=(1.0,))
+ self._test_streaming_sparse_recall_at_k(
+ predictions,
+ labels,
+ k=1,
+ expected=1.0 / 1,
+ class_id=3,
+ weights=(2.0,))
+ self._test_sparse_recall_at_top_k(
+ labels,
+ top_k_predictions,
+ expected=1.0 / 1,
+ class_id=3,
+ weights=(2.0,))
+ self._test_streaming_sparse_recall_at_k(
+ predictions,
+ labels,
+ k=1,
+ expected=NAN,
+ class_id=3,
+ weights=(0.0, 0.0))
+ self._test_sparse_recall_at_top_k(
+ labels,
+ top_k_predictions,
+ expected=NAN,
+ class_id=3,
+ weights=(0.0, 0.0))
+ self._test_streaming_sparse_recall_at_k(
+ predictions,
+ labels,
+ k=1,
+ expected=NAN,
+ class_id=3,
+ weights=(0.0, 1.0))
+ self._test_sparse_recall_at_top_k(
+ labels,
+ top_k_predictions,
+ expected=NAN,
+ class_id=3,
+ weights=(0.0, 1.0))
+ self._test_streaming_sparse_recall_at_k(
+ predictions,
+ labels,
+ k=1,
+ expected=1.0 / 1,
+ class_id=3,
+ weights=(1.0, 0.0))
+ self._test_sparse_recall_at_top_k(
+ labels,
+ top_k_predictions,
+ expected=1.0 / 1,
+ class_id=3,
+ weights=(1.0, 0.0))
+ self._test_streaming_sparse_recall_at_k(
+ predictions,
+ labels,
+ k=1,
+ expected=1.0 / 1,
+ class_id=3,
+ weights=(1.0, 1.0))
+ self._test_sparse_recall_at_top_k(
+ labels,
+ top_k_predictions,
+ expected=1.0 / 1,
+ class_id=3,
+ weights=(1.0, 1.0))
+ self._test_streaming_sparse_recall_at_k(
+ predictions,
+ labels,
+ k=1,
+ expected=2.0 / 2,
+ class_id=3,
+ weights=(2.0, 3.0))
+ self._test_sparse_recall_at_top_k(
+ labels,
+ top_k_predictions,
+ expected=2.0 / 2,
+ class_id=3,
+ weights=(2.0, 3.0))
+ self._test_streaming_sparse_recall_at_k(
+ predictions,
+ labels,
+ k=1,
+ expected=3.0 / 3,
+ class_id=3,
+ weights=(3.0, 2.0))
+ self._test_sparse_recall_at_top_k(
+ labels,
+ top_k_predictions,
+ expected=3.0 / 3,
+ class_id=3,
+ weights=(3.0, 2.0))
+ self._test_streaming_sparse_recall_at_k(
+ predictions,
+ labels,
+ k=1,
+ expected=0.3 / 0.3,
+ class_id=3,
+ weights=(0.3, 0.6))
+ self._test_sparse_recall_at_top_k(
+ labels,
+ top_k_predictions,
+ expected=0.3 / 0.3,
+ class_id=3,
+ weights=(0.3, 0.6))
+ self._test_streaming_sparse_recall_at_k(
+ predictions,
+ labels,
+ k=1,
+ expected=0.6 / 0.6,
+ class_id=3,
+ weights=(0.6, 0.3))
+ self._test_sparse_recall_at_top_k(
+ labels,
+ top_k_predictions,
+ expected=0.6 / 0.6,
+ class_id=3,
+ weights=(0.6, 0.3))
- # All classes: 2 labels, 2 predictions, 1 correct.
- self._test_streaming_sparse_recall_at_k(
- predictions, labels, k=1, expected=NAN, weights=(0.0,))
- self._test_sparse_recall_at_top_k(
- labels, top_k_predictions, expected=NAN, weights=(0.0,))
- self._test_streaming_sparse_recall_at_k(
- predictions, labels, k=1, expected=1.0 / 2, weights=(1.0,))
- self._test_sparse_recall_at_top_k(
- labels, top_k_predictions, expected=1.0 / 2, weights=(1.0,))
+ # All classes: 2 labels, 2 predictions, 1 correct.
+ self._test_streaming_sparse_recall_at_k(
+ predictions, labels, k=1, expected=NAN, weights=(0.0,))
+ self._test_sparse_recall_at_top_k(
+ labels, top_k_predictions, expected=NAN, weights=(0.0,))
+ self._test_streaming_sparse_recall_at_k(
+ predictions, labels, k=1, expected=1.0 / 2, weights=(1.0,))
+ self._test_sparse_recall_at_top_k(
+ labels, top_k_predictions, expected=1.0 / 2, weights=(1.0,))
- self._test_streaming_sparse_recall_at_k(
- predictions, labels, k=1, expected=1.0 / 2, weights=(2.0,))
- self._test_sparse_recall_at_top_k(
- labels, top_k_predictions, expected=1.0 / 2, weights=(2.0,))
+ self._test_streaming_sparse_recall_at_k(
+ predictions, labels, k=1, expected=1.0 / 2, weights=(2.0,))
+ self._test_sparse_recall_at_top_k(
+ labels, top_k_predictions, expected=1.0 / 2, weights=(2.0,))
- self._test_streaming_sparse_recall_at_k(
- predictions, labels, k=1, expected=1.0 / 1, weights=(1.0, 0.0))
- self._test_sparse_recall_at_top_k(
- labels, top_k_predictions, expected=1.0 / 1, weights=(1.0, 0.0))
+ self._test_streaming_sparse_recall_at_k(
+ predictions, labels, k=1, expected=1.0 / 1, weights=(1.0, 0.0))
+ self._test_sparse_recall_at_top_k(
+ labels, top_k_predictions, expected=1.0 / 1, weights=(1.0, 0.0))
- self._test_streaming_sparse_recall_at_k(
- predictions, labels, k=1, expected=0.0 / 1, weights=(0.0, 1.0))
- self._test_sparse_recall_at_top_k(
- labels, top_k_predictions, expected=0.0 / 1, weights=(0.0, 1.0))
+ self._test_streaming_sparse_recall_at_k(
+ predictions, labels, k=1, expected=0.0 / 1, weights=(0.0, 1.0))
+ self._test_sparse_recall_at_top_k(
+ labels, top_k_predictions, expected=0.0 / 1, weights=(0.0, 1.0))
- self._test_streaming_sparse_recall_at_k(
- predictions, labels, k=1, expected=1.0 / 2, weights=(1.0, 1.0))
- self._test_sparse_recall_at_top_k(
- labels, top_k_predictions, expected=1.0 / 2, weights=(1.0, 1.0))
+ self._test_streaming_sparse_recall_at_k(
+ predictions, labels, k=1, expected=1.0 / 2, weights=(1.0, 1.0))
+ self._test_sparse_recall_at_top_k(
+ labels, top_k_predictions, expected=1.0 / 2, weights=(1.0, 1.0))
- self._test_streaming_sparse_recall_at_k(
- predictions, labels, k=1, expected=2.0 / 5, weights=(2.0, 3.0))
- self._test_sparse_recall_at_top_k(
- labels, top_k_predictions, expected=2.0 / 5, weights=(2.0, 3.0))
+ self._test_streaming_sparse_recall_at_k(
+ predictions, labels, k=1, expected=2.0 / 5, weights=(2.0, 3.0))
+ self._test_sparse_recall_at_top_k(
+ labels, top_k_predictions, expected=2.0 / 5, weights=(2.0, 3.0))
- self._test_streaming_sparse_recall_at_k(
- predictions, labels, k=1, expected=3.0 / 5, weights=(3.0, 2.0))
- self._test_sparse_recall_at_top_k(
- labels, top_k_predictions, expected=3.0 / 5, weights=(3.0, 2.0))
+ self._test_streaming_sparse_recall_at_k(
+ predictions, labels, k=1, expected=3.0 / 5, weights=(3.0, 2.0))
+ self._test_sparse_recall_at_top_k(
+ labels, top_k_predictions, expected=3.0 / 5, weights=(3.0, 2.0))
- self._test_streaming_sparse_recall_at_k(
- predictions, labels, k=1, expected=0.3 / 0.9, weights=(0.3, 0.6))
- self._test_sparse_recall_at_top_k(
- labels, top_k_predictions, expected=0.3 / 0.9, weights=(0.3, 0.6))
+ self._test_streaming_sparse_recall_at_k(
+ predictions, labels, k=1, expected=0.3 / 0.9, weights=(0.3, 0.6))
+ self._test_sparse_recall_at_top_k(
+ labels, top_k_predictions, expected=0.3 / 0.9, weights=(0.3, 0.6))
- self._test_streaming_sparse_recall_at_k(
- predictions, labels, k=1, expected=0.6 / 0.9, weights=(0.6, 0.3))
- self._test_sparse_recall_at_top_k(
- labels, top_k_predictions, expected=0.6 / 0.9, weights=(0.6, 0.3))
+ self._test_streaming_sparse_recall_at_k(
+ predictions, labels, k=1, expected=0.6 / 0.9, weights=(0.6, 0.3))
+ self._test_sparse_recall_at_top_k(
+ labels, top_k_predictions, expected=0.6 / 0.9, weights=(0.6, 0.3))
+
+ def test_one_label_at_k1_weighted_sparse_labels(self):
+ sparse_labels = _binary_2d_label_to_sparse_value([[0, 0, 0, 1],
+ [0, 0, 1, 0]])
+ self._test_one_label_at_k1_weighted(sparse_labels)
+
+ def test_one_label_at_k1_weighted_dense_labels(self):
+ dense_labels = np.array([[3], [2]], dtype=np.int64)
+ self._test_one_label_at_k1_weighted(dense_labels)
def test_three_labels_at_k5_nan(self):
predictions = [[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
diff --git a/tensorflow/contrib/tpu/ops/cross_replica_ops.cc b/tensorflow/contrib/tpu/ops/cross_replica_ops.cc
index d389050e67..06553929dc 100644
--- a/tensorflow/contrib/tpu/ops/cross_replica_ops.cc
+++ b/tensorflow/contrib/tpu/ops/cross_replica_ops.cc
@@ -23,15 +23,23 @@ REGISTER_OP("CrossReplicaSum")
.Input("input: T")
.Output("output: T")
.Attr("T: {bfloat16, float}")
+ .Attr("group_assignment: list(int) = []")
.SetShapeFn(shape_inference::UnchangedShape)
.Doc(R"doc(
An Op to sum inputs across replicated TPU instances. Each
-instance supplies its own input, and the output of each is the sum of
-all the inputs.
+instance supplies its own input. If group_assignment is empty, the output of
+each instance is the sum of all the inputs; otherwise, each output is the sum
+of the inputs belonging to the same group.
+
+For example, suppose there are 4 TPU instances: `[A, B, C, D]`. Passing
+group_assignment=`[0,1,0,1]` sets `A, C` as group 0, and `B, D` as group 1.
+Thus we get the outputs: `[A+C, B+D, A+C, B+D]`.
input: The local input to the sum.
output: The sum of all the distributed inputs.
T: The type of elements to be summed.
+group_assignment: The list of group ids. `group_assignment[i]` represents the
+ group id of replica i.
)doc");
} // namespace tensorflow
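
A standalone numpy sketch of the group_assignment semantics documented above (not part of the patch): each replica receives the sum over its own group, and an empty assignment reduces over all replicas.

    import numpy as np

    def grouped_cross_replica_sum(per_replica_values, group_assignment):
      values = np.asarray(per_replica_values, dtype=np.float64)
      if not group_assignment:
        return np.full_like(values, values.sum())
      groups = np.asarray(group_assignment)
      out = np.empty_like(values)
      for g in np.unique(groups):
        members = groups == g
        out[members] = values[members].sum()
      return out

    # Four replicas [A, B, C, D] with group_assignment=[0, 1, 0, 1]:
    print(grouped_cross_replica_sum([1.0, 2.0, 3.0, 4.0], [0, 1, 0, 1]))
    # [4. 6. 4. 6.]  i.e. [A+C, B+D, A+C, B+D]
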
diff --git a/tensorflow/contrib/tpu/ops/replication_ops.cc b/tensorflow/contrib/tpu/ops/replication_ops.cc
index ab2a7a0d4b..f632c953c8 100644
--- a/tensorflow/contrib/tpu/ops/replication_ops.cc
+++ b/tensorflow/contrib/tpu/ops/replication_ops.cc
@@ -44,6 +44,27 @@ REGISTER_OP("TPUReplicatedInput")
" with other shapes.");
}
c->set_output(0, cur);
+
+ // If this is a resource, unify the resource shapes.
+ DataType dtype;
+ TF_RETURN_IF_ERROR(c->GetAttr("T", &dtype));
+ if (dtype == DT_RESOURCE) {
+ const std::vector<shape_inference::ShapeAndType>* shapes_and_types =
+ nullptr;
+ for (int i = c->num_inputs() - 1; i >= 0; --i) {
+ if (shapes_and_types) {
+ if (!c->MergeInputHandleShapesAndTypes(i, *shapes_and_types)) {
+ return errors::InvalidArgument(
+ "Incompatible resource shapes for replicated TPU input.");
+ }
+ } else {
+ shapes_and_types = c->input_handle_shapes_and_types(i);
+ }
+ }
+ if (shapes_and_types) {
+ c->set_output_handle_shapes_and_types(0, *shapes_and_types);
+ }
+ }
return Status::OK();
})
.Doc(
diff --git a/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py b/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py
index 508c7a842f..7f1d25732e 100644
--- a/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py
+++ b/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py
@@ -35,19 +35,19 @@ flags.DEFINE_string(
None,
help='GCE zone where the Cloud TPU is located in. If not specified, we '
'will attempt to automatically detect the GCE project from metadata.')
-flags.DEFINE_string('tpu_name', None,
+flags.DEFINE_string('tpu', None,
'Name of the Cloud TPU for Cluster Resolvers. You must '
'specify either this flag or --service_addr.')
# Tool specific parameters
flags.DEFINE_string(
'service_addr', None, 'Address of TPU profiler service e.g. '
- 'localhost:8466, you must specify either this flag or --tpu_name.')
+ 'localhost:8466, you must specify either this flag or --tpu.')
flags.DEFINE_string(
'workers_list', None, 'The list of worker TPUs that we are about to profile'
- ' e.g. 10.0.1.2, 10.0.1.3. You can specify this flag with --tpu_name or '
+ ' e.g. 10.0.1.2, 10.0.1.3. You can specify this flag with --tpu or '
'--service_addr to profile a subset of tpu nodes. You can also use only'
- '--tpu_name and leave this flag unspecified to profile all the tpus.')
+ ' --tpu and leave this flag unspecified to profile all the tpus.')
flags.DEFINE_string('logdir', None,
'Path of TensorBoard log directory e.g. /tmp/tb_log, '
'gs://tb_bucket')
@@ -76,19 +76,19 @@ def run_main():
def main(unused_argv=None):
tf.logging.set_verbosity(tf.logging.INFO)
- if FLAGS.service_addr is None and FLAGS.tpu_name is None:
- sys.exit('You must specify either --service_addr or --tpu_name.')
+ if FLAGS.service_addr is None and FLAGS.tpu is None:
+ sys.exit('You must specify either --service_addr or --tpu.')
tpu_cluster_resolver = None
if FLAGS.service_addr is not None:
- if FLAGS.tpu_name is not None:
- tf.logging.warn('Both --service_addr and --tpu_name are set. Ignoring '
- '--tpu_name and using --service_addr.')
+ if FLAGS.tpu is not None:
+ tf.logging.warn('Both --service_addr and --tpu are set. Ignoring '
+ '--tpu and using --service_addr.')
service_addr = FLAGS.service_addr
else:
tpu_cluster_resolver = (
tf.contrib.cluster_resolver.TPUClusterResolver(
- [FLAGS.tpu_name],
+ [FLAGS.tpu],
zone=FLAGS.tpu_zone,
project=FLAGS.gcp_project))
service_addr = tpu_cluster_resolver.get_master()
diff --git a/tensorflow/contrib/tpu/profiler/pip_package/setup.py b/tensorflow/contrib/tpu/profiler/pip_package/setup.py
index ebd478fd02..f97a972f01 100644
--- a/tensorflow/contrib/tpu/profiler/pip_package/setup.py
+++ b/tensorflow/contrib/tpu/profiler/pip_package/setup.py
@@ -20,7 +20,7 @@ from __future__ import print_function
from setuptools import setup
-_VERSION = '1.6.0'
+_VERSION = '1.7.0'
CONSOLE_SCRIPTS = [
'capture_tpu_profile=cloud_tpu_profiler.main:run_main',
@@ -46,7 +46,7 @@ setup(
# 3 - Alpha
# 4 - Beta
# 5 - Production/Stable
- 'Development Status :: 4 - Beta',
+ 'Development Status :: 5 - Production/Stable',
'Intended Audience :: Developers',
'Intended Audience :: Education',
'Intended Audience :: Science/Research',
diff --git a/tensorflow/contrib/tpu/profiler/version.h b/tensorflow/contrib/tpu/profiler/version.h
index 618479e1a6..bd9ba6697e 100644
--- a/tensorflow/contrib/tpu/profiler/version.h
+++ b/tensorflow/contrib/tpu/profiler/version.h
@@ -16,6 +16,6 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_TPU_PROFILER_VERSION_H_
#define TENSORFLOW_CONTRIB_TPU_PROFILER_VERSION_H_
-#define TPU_PROFILER_VERSION "1.6.0"
+#define TPU_PROFILER_VERSION "1.7.0"
#endif // TENSORFLOW_CONTRIB_TPU_PROFILER_VERSION_H_
diff --git a/tensorflow/contrib/tpu/python/ops/tpu_ops.py b/tensorflow/contrib/tpu/python/ops/tpu_ops.py
index 14c63a7976..bf442d9116 100644
--- a/tensorflow/contrib/tpu/python/ops/tpu_ops.py
+++ b/tensorflow/contrib/tpu/python/ops/tpu_ops.py
@@ -38,9 +38,8 @@ if platform.system() != "Windows":
@ops.RegisterGradient("CrossReplicaSum")
def _cross_replica_sum_grad(op, grad):
- del op # Unused
# The gradient of a cross replica sum is also a cross-replica sum.
- return gen_tpu_ops.cross_replica_sum(grad)
+ return gen_tpu_ops.cross_replica_sum(grad, op.get_attr("group_assignment"))
# This extra type checking exists to give a more helpful error message in
# the common case that uint8 and int64 values are infed. Remove when both
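
A small numeric check of why the gradient above can reuse the forward op's group_assignment (a sketch, not part of the change): backpropagating through a grouped sum is itself a grouped sum of the upstream gradients.

    import numpy as np

    group = np.array([0, 1, 0, 1])
    upstream_grad = np.array([0.1, 0.2, 0.3, 0.4])  # dL/dy per replica

    # y_i sums the x_j of replica i's group, so dL/dx_j sums dL/dy_i over
    # that same group.
    grad_x = np.empty_like(upstream_grad)
    for g in np.unique(group):
      members = group == g
      grad_x[members] = upstream_grad[members].sum()
    print(grad_x)  # [0.4 0.6 0.4 0.6]
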
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py
index cd0fd6ae8a..dc473c5846 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu.py
@@ -591,16 +591,22 @@ def split_compile_and_replicate(computation,
with tpu_function.tpu_shard_context(
num_replicas), ops.control_dependencies([metadata]):
- # The EncapsulateTPUComputations rewrite needs to identify the
- # replicated arguments inside each computation. Adds identity operators
- # tagged with an attribute _tpu_replicated_input to identify the
- # replicated inputs.
+ # We tag replicated inputs with the _tpu_replicated_input attribute. The
+ # attribute no longer does anything and is kept only for backward
+ # compatibility.
+ # TODO(phawkins): delete the attr_scope after 6/28/2018.
# pylint: disable=protected-access
- with graph._attr_scope({"_tpu_replicated_input":
- attr_value_pb2.AttrValue(b=True)}):
+ with graph._attr_scope({
+ "_tpu_replicated_input": attr_value_pb2.AttrValue(b=True)
+ }):
+ # Add identity ops so even unused inputs are "consumed" by the
+ # computation. This is to avoid orphaned TPUReplicatedInput nodes.
+ # TODO(phawkins): consider instead pruning unused TPUReplicatedInput
+ # and eliding trivial TPUReplicatedInput/TPUReplicatedOutput pairs.
computation_inputs = [
array_ops.identity(x, name="replicated_input_{}".format(i))
- for i, x in enumerate(computation_inputs)]
+ for i, x in enumerate(computation_inputs)
+ ]
# pylint: enable=protected-access
# If there is an infeed queue, adds the dequeued values to the
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_optimizer.py b/tensorflow/contrib/tpu/python/tpu/tpu_optimizer.py
index e76cf83e4d..15f99d7eeb 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_optimizer.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_optimizer.py
@@ -19,6 +19,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import collections
+
from tensorflow.contrib.tpu.python.ops import tpu_ops
from tensorflow.contrib.tpu.python.tpu import tpu_function
from tensorflow.python.ops.losses import losses
@@ -32,7 +34,8 @@ class CrossShardOptimizer(optimizer.Optimizer):
def __init__(self,
opt,
reduction=losses.Reduction.MEAN,
- name="CrossShardOptimizer"):
+ name="CrossShardOptimizer",
+ group_assignment=None):
"""Construct a new cross-shard optimizer.
Args:
@@ -40,6 +43,8 @@ class CrossShardOptimizer(optimizer.Optimizer):
reduction: The reduction to apply to the shard losses.
name: Optional name prefix for the operations created when applying
gradients. Defaults to "CrossShardOptimizer".
+ group_assignment: Optional list of group ids for applying the optimizer
+ to subgroups.
Raises:
ValueError: If reduction is not a valid cross-shard reduction.
@@ -50,6 +55,35 @@ class CrossShardOptimizer(optimizer.Optimizer):
super(CrossShardOptimizer, self).__init__(False, name)
self._opt = opt
self._reduction = reduction
+ self._group_assignment = group_assignment
+
+ def _verify_and_get_subgroup_size(self, group_assignment, num_shards):
+ """Verify group_assignment and get the subgroup size.
+
+ Args:
+ group_assignment: list of group ids for applying the optimizer
+ to subgroups.
+ num_shards: The number of TPU shards.
+
+ Returns:
+ The size of one subgroup in group_assignment.
+
+ Raises:
+ ValueError: If group_assignment is invalid.
+ """
+ if not group_assignment:
+ return None
+ if len(group_assignment) != num_shards:
+ raise ValueError("The size of group_assignment does not equal "
+ "num_shards({0}). Got group_assignment={1}".format(
+ num_shards, self._group_assignment))
+ subgroup_size_list = list(collections.Counter(group_assignment).values())
+ if all(subgroup_size_list[0] == size for size in subgroup_size_list):
+ return subgroup_size_list[0]
+ else:
+ raise ValueError("The size of each subgroup in group_assignment must "
+ "be equal. Got group_assignment={}".format(
+ self._group_assignment))
def compute_gradients(self, loss, var_list=None, **kwargs):
"""Compute gradients of "loss" for the variables in "var_list".
@@ -71,7 +105,8 @@ class CrossShardOptimizer(optimizer.Optimizer):
A list of (gradient, variable) pairs.
Raises:
- ValueError: If not within a tpu_shard_context.
+ ValueError: If not within a tpu_shard_context or group_assignment is
+ invalid.
"""
num_shards = tpu_function.get_tpu_context().number_of_shards
if num_shards is None:
@@ -79,9 +114,17 @@ class CrossShardOptimizer(optimizer.Optimizer):
"CrossShardOptimizer should be used within a tpu_shard_context, but "
"got unset number_of_shards. Assuming 1.")
num_shards = 1
+
+ subgroup_size = self._verify_and_get_subgroup_size(self._group_assignment,
+ num_shards)
+
if num_shards > 1 and self._reduction == losses.Reduction.MEAN:
- scale = 1.0 / num_shards
+ if self._group_assignment:
+ scale = 1.0 / subgroup_size
+ else:
+ scale = 1.0 / num_shards
loss *= scale
+
return self._opt.compute_gradients(loss, var_list=var_list, **kwargs)
def apply_gradients(self, grads_and_vars, global_step=None, name=None):
@@ -110,7 +153,8 @@ class CrossShardOptimizer(optimizer.Optimizer):
if grad is None:
summed_grads_and_vars.append((grad, var))
else:
- summed_grads_and_vars.append((tpu_ops.cross_replica_sum(grad), var))
+ summed_grads_and_vars.append((tpu_ops.cross_replica_sum(
+ grad, self._group_assignment), var))
return self._opt.apply_gradients(summed_grads_and_vars, global_step, name)
def get_slot(self, *args, **kwargs):
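
A minimal standalone sketch of the verification and scaling logic added to CrossShardOptimizer above (plain Python, no TF dependencies; not part of the patch):

    import collections

    def verify_and_get_subgroup_size(group_assignment, num_shards):
      if not group_assignment:
        return None
      if len(group_assignment) != num_shards:
        raise ValueError("group_assignment needs one entry per shard")
      sizes = list(collections.Counter(group_assignment).values())
      if any(size != sizes[0] for size in sizes):
        raise ValueError("all subgroups must have the same size")
      return sizes[0]

    num_shards = 4
    group_assignment = [0, 1, 0, 1]
    subgroup_size = verify_and_get_subgroup_size(group_assignment, num_shards)
    # With MEAN reduction the loss is scaled by 1/subgroup_size rather than
    # 1/num_shards, because each gradient is later summed only over its own
    # subgroup.
    scale = 1.0 / subgroup_size if group_assignment else 1.0 / num_shards
    print(subgroup_size, scale)  # 2 0.5
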
diff --git a/tensorflow/core/api_def/excluded_ops.cc b/tensorflow/core/api_def/excluded_ops.cc
index 07ac974ff9..931c943dbc 100644
--- a/tensorflow/core/api_def/excluded_ops.cc
+++ b/tensorflow/core/api_def/excluded_ops.cc
@@ -20,7 +20,8 @@ namespace tensorflow {
const std::unordered_set<std::string>* GetExcludedOps() {
static std::unordered_set<std::string>* excluded_ops =
new std::unordered_set<std::string>(
- {"BigQueryReader", "GenerateBigQueryReaderPartitions"});
+ {"BigQueryReader", "GenerateBigQueryReaderPartitions",
+ "GcsConfigureBlockCache", "GcsConfigureCredentials"});
return excluded_ops;
}
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc
index 5cef93c605..87ba609dd7 100644
--- a/tensorflow/core/common_runtime/direct_session.cc
+++ b/tensorflow/core/common_runtime/direct_session.cc
@@ -447,6 +447,7 @@ Status DirectSession::RunInternal(int64 step_id, const RunOptions& run_options,
// Create a run state and start execution.
RunState run_state(step_id, &devices_);
run_state.rendez = new IntraProcessRendezvous(device_mgr_.get());
+#ifndef __ANDROID__
// Set up for collectives if the RunOption declares a key.
if (run_options.experimental().collective_graph_key() > 0) {
if (!collective_executor_mgr_) {
@@ -461,6 +462,7 @@ Status DirectSession::RunInternal(int64 step_id, const RunOptions& run_options,
run_state.collective_executor.reset(new CollectiveExecutor::Handle(
collective_executor_mgr_->FindOrCreate(step_id), true /*inherit_ref*/));
}
+#endif
// Start parallel Executors.
const size_t num_executors = executors_and_keys->items.size();
diff --git a/tensorflow/core/common_runtime/lower_if_op.cc b/tensorflow/core/common_runtime/lower_if_op.cc
index b5fee36ff4..567c81870c 100644
--- a/tensorflow/core/common_runtime/lower_if_op.cc
+++ b/tensorflow/core/common_runtime/lower_if_op.cc
@@ -187,8 +187,7 @@ Status CondBuilder::AddOutputs() {
} else {
// Feed the outputs directly from the merge nodes so that downstream ops
// can start before all the outputs have been computed.
- graph_->AddEdge(merges[e->src_output()], e->src_output(), e->dst(),
- e->dst_input());
+ graph_->AddEdge(merges[e->src_output()], 0, e->dst(), e->dst_input());
}
}
return Status::OK();
diff --git a/tensorflow/core/framework/device_base.h b/tensorflow/core/framework/device_base.h
index b59ced869d..ec26d92a61 100644
--- a/tensorflow/core/framework/device_base.h
+++ b/tensorflow/core/framework/device_base.h
@@ -186,10 +186,6 @@ class DeviceBase {
virtual ScopedAllocatorMgr* GetScopedAllocatorMgr() const { return nullptr; }
- const bool has_eigen_cpu_device() const {
- return (eigen_cpu_device_ != nullptr);
- }
-
virtual const Eigen::ThreadPoolDevice* eigen_cpu_device() {
CHECK(eigen_cpu_device_ != nullptr);
return eigen_cpu_device_;
diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc
index a0f449d64f..ce213a63be 100644
--- a/tensorflow/core/framework/op_kernel.cc
+++ b/tensorflow/core/framework/op_kernel.cc
@@ -13,14 +13,12 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#define EIGEN_USE_THREADS
#include "tensorflow/core/framework/op_kernel.h"
#include <unordered_map>
#include <utility>
#include <vector>
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/graph.pb_text.h"
@@ -42,7 +40,6 @@ limitations under the License.
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
-#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
@@ -273,19 +270,6 @@ OpKernelContext::OpKernelContext(Params* params, int num_outputs)
if (params_->record_tensor_accesses) {
referenced_tensors_.Init();
}
- if (params->device->has_eigen_cpu_device()) {
- int64 block_size = -1, output_size = -1, num_threads = 1;
- const Eigen::ThreadPoolDevice* thread_pool =
- params_->device->eigen_cpu_device();
- AttrSlice attributes(op_kernel().def());
- if (GetNodeAttr(attributes, "_block_size", &block_size) == Status::OK() &&
- GetNodeAttr(attributes, "_output_size", &output_size) == Status::OK()) {
- num_threads = std::min(Eigen::divup(output_size, block_size),
- static_cast<int64>(thread_pool->numThreads()));
- eigen_cpu_device_ = MakeUnique<Eigen::ThreadPoolDevice>(
- thread_pool->getPool(), num_threads);
- }
- }
}
OpKernelContext::~OpKernelContext() {
diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h
index d307078e63..a3ad29e02f 100644
--- a/tensorflow/core/framework/op_kernel.h
+++ b/tensorflow/core/framework/op_kernel.h
@@ -1004,7 +1004,6 @@ class OpKernelContext {
// OpKernels can use these eigen devices to carry out their
// numerical computation.
const Eigen::ThreadPoolDevice& eigen_cpu_device() const {
- if (eigen_cpu_device_ != nullptr) return *eigen_cpu_device_;
return *device()->eigen_cpu_device();
}
const Eigen::GpuDevice& eigen_gpu_device() const {
@@ -1140,7 +1139,6 @@ class OpKernelContext {
mutable mutex mu_; // mutable so const accessors can acquire the lock
gtl::InlinedVector<WrappedAllocator, 4> wrapped_allocators_ GUARDED_BY(mu_);
gtl::InlinedVector<TensorValue, 4> outputs_;
- std::unique_ptr<Eigen::ThreadPoolDevice> eigen_cpu_device_;
// Constructed only if <params->record_tensor_accesses>.
ManualConstructor<UniqueTensorReferences> referenced_tensors_ GUARDED_BY(mu_);
diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc
index 0f748515ef..568f0870c0 100644
--- a/tensorflow/core/graph/graph.cc
+++ b/tensorflow/core/graph/graph.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/graph/while_context.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
@@ -265,6 +266,28 @@ Status Node::input_node(int idx, const Node** const_n) const {
return Status::OK();
}
+// InputTensor
+
+bool InputTensor::operator==(const InputTensor& other) const {
+ return node == other.node && index == other.index;
+}
+
+uint64 InputTensor::Hash::operator()(InputTensor const& s) const {
+ return Hash64Combine(std::hash<const Node*>()(s.node),
+ std::hash<int>()(s.index));
+}
+
+// OutputTensor
+
+bool OutputTensor::operator==(const OutputTensor& other) const {
+ return node == other.node && index == other.index;
+}
+
+uint64 OutputTensor::Hash::operator()(OutputTensor const& s) const {
+ return Hash64Combine(std::hash<const Node*>()(s.node),
+ std::hash<int>()(s.index));
+}
+
// Graph
Graph::Graph(const OpRegistryInterface* ops)
diff --git a/tensorflow/core/graph/graph.h b/tensorflow/core/graph/graph.h
index 33fb7cb57a..a147c94689 100644
--- a/tensorflow/core/graph/graph.h
+++ b/tensorflow/core/graph/graph.h
@@ -284,6 +284,16 @@ struct InputTensor {
InputTensor(const Node* n, int i) : node(n), index(i) {}
InputTensor() : node(nullptr), index(0) {}
+
+ // Returns true if this InputTensor is identical to 'other'. Nodes are
+ // compared using pointer equality.
+ bool operator==(const InputTensor& other) const;
+
+ // A hash function for InputTensors. Nodes are hashed based on their pointer
+ // value.
+ struct Hash {
+ uint64 operator()(InputTensor const& s) const;
+ };
};
// Represents an output of a node, i.e., the `index`-th output of `node`. Note
@@ -295,6 +305,16 @@ struct OutputTensor {
OutputTensor(const Node* n, int i) : node(n), index(i) {}
OutputTensor() : node(nullptr), index(0) {}
+
+ // Returns true if this OutputTensor is identical to 'other'. Nodes are
+ // compared using pointer equality.
+ bool operator==(const OutputTensor& other) const;
+
+ // A hash function for OutputTensors. Nodes are hashed based on their pointer
+ // value.
+ struct Hash {
+ uint64 operator()(OutputTensor const& s) const;
+ };
};
class Edge {
diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc
index 2a47a4c495..2227904dbf 100644
--- a/tensorflow/core/grappler/op_types.cc
+++ b/tensorflow/core/grappler/op_types.cc
@@ -193,6 +193,8 @@ bool IsLess(const NodeDef& node) { return node.op() == "Less"; }
bool IsLessEqual(const NodeDef& node) { return node.op() == "LessEqual"; }
+bool IsLog(const NodeDef& node) { return node.op() == "Log"; }
+
bool IsLogicalAnd(const NodeDef& node) { return node.op() == "LogicalAnd"; }
bool IsLogicalNot(const NodeDef& node) { return node.op() == "LogicalNot"; }
diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h
index e7f39981c0..7110a9c63d 100644
--- a/tensorflow/core/grappler/op_types.h
+++ b/tensorflow/core/grappler/op_types.h
@@ -74,6 +74,7 @@ bool IsImag(const NodeDef& node);
bool IsInvGrad(const NodeDef& node);
bool IsLess(const NodeDef& node);
bool IsLessEqual(const NodeDef& node);
+bool IsLog(const NodeDef& node);
bool IsLogicalAnd(const NodeDef& node);
bool IsLogicalNot(const NodeDef& node);
bool IsLogicalOr(const NodeDef& node);
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index 2073c2968b..33c2a0d420 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -210,8 +210,7 @@ cc_library(
hdrs = ["graph_optimizer_stage.h"],
visibility = ["//visibility:public"],
deps = [
- "//tensorflow/core:lib",
- "//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/costs:graph_properties",
@@ -225,6 +224,7 @@ tf_cuda_cc_test(
deps = [
":graph_optimizer_stage",
"//tensorflow/cc:cc_ops",
+ "//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/grappler:grappler_item",
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index 51110b4bda..9d500f8f54 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -1084,8 +1084,11 @@ class RemoveIdentityTranspose : public ArithmeticOptimizerStage {
Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
TF_RETURN_IF_ERROR(EnsureNodeIsSupported(node));
NodeDef* tail = node;
- tail = GetTailOfIdempotentChain(*tail, *ctx().node_map,
- *ctx().nodes_to_preserve);
+ // TODO(rmlarsen): Enable after debugging breakage in Bayesflow.
+ if (ctx().opt_level == RewriterConfig::AGGRESSIVE) {
+ tail = GetTailOfIdempotentChain(*tail, *ctx().node_map,
+ *ctx().nodes_to_preserve);
+ }
NodeDef* first_transpose;
TF_RETURN_IF_ERROR(GetInputNode(tail->input(0), &first_transpose));
@@ -2484,6 +2487,119 @@ class ConvertPowStage : public ArithmeticOptimizerStage {
}
};
+class ConvertLog1pStage : public ArithmeticOptimizerStage {
+ public:
+ explicit ConvertLog1pStage(const GraphOptimizerContext& ctx,
+ const ArithmeticOptimizerContext& ctx_ext)
+ : ArithmeticOptimizerStage("ConvertLog1p", ctx, ctx_ext) {}
+ ~ConvertLog1pStage() override = default;
+
+ bool IsSupported(const NodeDef* node) const override { return IsLog(*node); }
+
+ Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
+ NodeDef* input;
+ TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input));
+ if (!IsAdd(*input)) {
+ return Status::OK();
+ }
+
+ if (ctx().graph_properties->GetInputProperties(input->name()).size() < 2) {
+ return Status::OK();
+ }
+
+ bool modified = false;
+ TF_RETURN_IF_ERROR(TrySimplifyInternal(node, input, 0, 1, &modified));
+ if (!modified) {
+ TF_RETURN_IF_ERROR(TrySimplifyInternal(node, input, 1, 0, &modified));
+ }
+ if (modified) {
+ *simplified_node_name = node->name();
+ }
+ return Status::OK();
+ }
+
+ private:
+ Status TrySimplifyInternal(NodeDef* node, NodeDef* input, int i, int j,
+ bool* modified) {
+ const auto& t =
+ ctx().graph_properties->GetInputProperties(input->name())[i];
+ for (int k = 0; k < t.shape().dim_size(); ++k) {
+ // Skip if t shape is not fully determined.
+ if (t.shape().dim(k).size() < 0) {
+ return Status::OK();
+ }
+ }
+ const auto& c =
+ ctx().graph_properties->GetInputProperties(input->name())[j];
+ TensorShapeProto broadcast_shape;
+ if (!ShapeAfterBroadcast(t.shape(), c.shape(), &broadcast_shape)) {
+ return errors::InvalidArgument("Cannot get broadcast shape for: ",
+ t.DebugString(), " and ", c.DebugString());
+ }
+ if (!ShapesSymbolicallyEqual(t.shape(), broadcast_shape)) {
+ // skip if the non-constant tensor doesn't have the same shape after
+ // broadcast.
+ return Status::OK();
+ }
+ if (TensorShape::IsValid(t.shape()) && t.has_value()) {
+ Tensor tensor(t.dtype(), t.shape());
+ if (!tensor.FromProto(t.value())) {
+ return errors::InvalidArgument("Cannot parse tensor from proto: ",
+ t.value().DebugString());
+ }
+ complex128 element;
+ for (int k = 0; k < tensor.NumElements(); ++k) {
+ if (!GetElement(tensor, k, &element)) {
+ // input data type is not supported by log1p. Skip.
+ return Status::OK();
+ }
+ if (element != complex128(1)) {
+ // current element is not 1. Skip.
+ return Status::OK();
+ }
+ }
+ NodeDef *x, *y;
+ TF_RETURN_IF_ERROR(GetInputNode(input->input(i), &x));
+ TF_RETURN_IF_ERROR(GetInputNode(input->input(j), &y));
+ node->set_op("Log1p");
+ node->set_input(0, y->name());
+ node->add_input(AsControlDependency(x->name()));
+ ForwardControlDependencies(node, {input});
+
+ AddToOptimizationQueue(node);
+ AddToOptimizationQueue(x);
+ AddToOptimizationQueue(y);
+ *modified = true;
+ }
+ return Status::OK();
+ }
+
+ bool GetElement(const Tensor& t, int i, complex128* element) {
+ switch (t.dtype()) {
+ case DT_BFLOAT16:
+ *element = complex128(t.flat<bfloat16>()(i));
+ return true;
+ case DT_HALF:
+ *element = complex128(static_cast<double>(t.flat<Eigen::half>()(i)), 0);
+ return true;
+ case DT_FLOAT:
+ *element = complex128(t.flat<float>()(i));
+ return true;
+ case DT_DOUBLE:
+ *element = complex128(t.flat<double>()(i));
+ return true;
+ case DT_COMPLEX64:
+ *element = complex128(t.flat<complex64>()(i));
+ return true;
+ case DT_COMPLEX128:
+ *element = t.flat<complex128>()(i);
+ return true;
+ default:
+ return false;
+ }
+ }
+};
+
} // namespace
class UniqueNodes {
@@ -2713,7 +2829,8 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
}
const GraphOptimizerContext ctx(&nodes_to_preserve_, optimized_graph_,
- graph_properties_.get(), node_map_.get());
+ graph_properties_.get(), node_map_.get(),
+ opt_level_);
const ArithmeticOptimizerContext ctx_ext(&nodes_to_simplify);
// Stop pipeline after first stage returning non-empty simplified tensor name.
@@ -2759,6 +2876,8 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
if (options_.remove_idempotent)
pipeline.AddStage<RemoveIdempotentStage>(ctx, ctx_ext);
if (options_.convert_pow) pipeline.AddStage<ConvertPowStage>(ctx, ctx_ext);
+ if (options_.convert_log1p)
+ pipeline.AddStage<ConvertLog1pStage>(ctx, ctx_ext);
VLOG(1) << "Run " << pipeline.NumStages() << " arithmetic optimizer stages: "
<< str_util::Join(pipeline.StageNames(), ", ");
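
A quick illustration of why the Log1p rewrite above is worthwhile numerically (standalone, not part of the patch): for small x, computing 1 + x first discards most of x's precision, while log1p keeps it.

    import numpy as np

    x = 1e-12
    print(np.log(1.0 + x))  # ~1.0000889e-12: only the first few digits correct
    print(np.log1p(x))      # ~9.999999999995e-13: accurate to full precision
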
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
index 40c5e9fc56..9a6081dcd8 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
@@ -75,6 +75,7 @@ class ArithmeticOptimizer : public GraphOptimizer {
bool replace_mul_with_square = true;
bool simplify_aggregation = true;
bool convert_pow = true;
+ bool convert_log1p = true;
// Choose which arithmetic optimizer stages will be enabled for a given
// optimization level by default.
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index ff96cb6480..177c237fe7 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -264,6 +264,11 @@ class ArithmeticOptimizerTest : public GrapplerTest {
DisableAllStages(optimizer);
optimizer->options_.simplify_aggregation = true;
}
+
+ void EnableOnlyLog1p(ArithmeticOptimizer* optimizer) {
+ DisableAllStages(optimizer);
+ optimizer->options_.convert_log1p = true;
+ }
};
TEST_F(ArithmeticOptimizerTest, NoOp) {
@@ -1510,7 +1515,7 @@ TEST_F(ArithmeticOptimizerTest, RemoveIdentityTransposesThroughChain) {
TF_CHECK_OK(s.ToGraphDef(&item.graph));
GraphDef output;
- ArithmeticOptimizer optimizer;
+ ArithmeticOptimizer optimizer(RewriterConfig::AGGRESSIVE);
EnableOnlyRemoveIdentityTranspose(&optimizer);
OptimizeAndPrune(&optimizer, &item, &output);
@@ -2486,6 +2491,43 @@ TEST_F(ArithmeticOptimizerTest, ConvertPow) {
CompareGraphs(want, got);
}
+TEST_F(ArithmeticOptimizerTest, Log1p) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+
+ auto x1 = ops::Const(s.WithOpName("x1"), {1.0f, 1.0f}, {1, 2});
+ auto x2 = ops::Const(s.WithOpName("x2"), {2.0f, 2.0f}, {1, 2});
+ auto x3 = ops::Const(s.WithOpName("x3"), {3.0f, 3.0f}, {1, 2});
+ auto a12 = ops::Add(s.WithOpName("a12").WithControlDependencies(x3), x1, x2);
+ auto a23 = ops::Add(s.WithOpName("a23"), x2, x3);
+ Output out1 = ops::Log(s.WithOpName("out1"), a12);
+ Output out2 = ops::Log(s.WithOpName("out2"), a23);
+
+ GrapplerItem item;
+ item.fetch = {"out1", "out2"};
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+ EXPECT_EQ(2, tensors_expected.size());
+
+ GraphDef got;
+ ArithmeticOptimizer optimizer;
+ EnableOnlyLog1p(&optimizer);
+ OptimizeAndPrune(&optimizer, &item, &got);
+ auto tensors = EvaluateNodes(got, item.fetch);
+ EXPECT_EQ(2, tensors.size());
+
+ GraphDef want;
+ AddNode("x1", "Const", {}, {}, &want);
+ AddNode("x2", "Const", {}, {}, &want);
+ AddNode("x3", "Const", {}, {}, &want);
+ AddNode("a23", "Add", {"x2", "x3"}, {}, &want);
+ AddNode("out1", "Log1p",
+ {"x2", AsControlDependency("x1"), AsControlDependency("x3")}, {},
+ &want);
+ AddNode("out2", "Log", {"a23"}, {}, &want);
+
+ CompareGraphs(want, got);
+}
+
TEST_F(ArithmeticOptimizerTest, MinimizeBroadcasts_SimpleSwap) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
diff --git a/tensorflow/core/grappler/optimizers/dependency_optimizer.cc b/tensorflow/core/grappler/optimizers/dependency_optimizer.cc
index 78a6d0d835..3f5bab9d3b 100644
--- a/tensorflow/core/grappler/optimizers/dependency_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/dependency_optimizer.cc
@@ -65,7 +65,7 @@ void DeleteNodes(const std::set<int>& nodes_to_delete, GraphDef* graph) {
} // namespace
-bool DependencyOptimizer::SafeToRemoveIdentity(const NodeDef& node) {
+bool DependencyOptimizer::SafeToRemoveIdentity(const NodeDef& node) const {
if (!IsIdentity(node)) {
return true;
}
@@ -108,7 +108,7 @@ bool DependencyOptimizer::SafeToRemoveIdentity(const NodeDef& node) {
return true;
}
-bool DependencyOptimizer::SafeToConvertToNoOp(const NodeDef& node) {
+bool DependencyOptimizer::SafeToConvertToNoOp(const NodeDef& node) const {
if (!fetch_nodes_known_ ||
nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end()) {
return false;
@@ -142,6 +142,61 @@ bool DependencyOptimizer::SafeToConvertToNoOp(const NodeDef& node) {
return true;
}
+bool DependencyOptimizer::BypassingNodeIsBeneficial(
+ const NodeDef& node, const std::vector<NodeDef*>& input_nodes,
+ const std::vector<NodeDef*>& output_nodes) const {
+ const bool is_identity = IsIdentity(node);
+ const int num_outputs = output_nodes.size();
+ const int num_inputs = node.input_size();
+
+ // Don't increase the number of edges in the graph.
+ if (num_inputs * num_outputs > num_inputs + num_outputs) {
+ return false;
+ }
+
+ // Make sure that we don't increase the number of edges that cross
+ // device boundaries.
+ if ((num_inputs == 1 && num_outputs > 1 &&
+ input_nodes[0]->device() != node.device()) ||
+ (num_inputs > 1 && num_outputs == 1 &&
+ output_nodes[0]->device() != node.device())) {
+ return false;
+ }
+
+ // TODO(rmlarsen): Not all device crossings are equally expensive.
+ // Assign a cost to each based on device affinity and compute a
+ // cost before and after.
+ const string& node_dev = node.device();
+ int num_cross_in = 0;
+ for (NodeDef* input_node : input_nodes) {
+ num_cross_in += static_cast<int>(input_node->device() != node_dev);
+ }
+ int num_cross_out = 0;
+ for (NodeDef* output_node : output_nodes) {
+ num_cross_out += static_cast<int>(output_node->device() != node_dev);
+ }
+ if (is_identity && num_cross_in > 0 && num_cross_out > 0) {
+ // This identity node follows a device crossing, so it might be
+ // following a _Recv node after partitioning. Do not remove such nodes,
+ // unless they only have consumers on the same device as themselves.
+ return false;
+ }
+
+ // Make sure we do not increase the number of device crossings.
+ const int num_cross_before = num_cross_in + num_cross_out;
+ int num_cross_after = 0;
+ for (NodeDef* input_node : input_nodes) {
+ for (NodeDef* output_node : output_nodes) {
+ num_cross_after +=
+ static_cast<int>(input_node->device() != output_node->device());
+ }
+ }
+ if (num_cross_after > num_cross_before) {
+ return false;
+ }
+ return true;
+}
+
void DependencyOptimizer::OptimizeNode(int node_idx,
SetVector<int>* nodes_to_simplify,
std::set<int>* nodes_to_delete) {
@@ -269,21 +324,11 @@ void DependencyOptimizer::OptimizeNode(int node_idx,
// y --^> | | --^> b /\ +---+
// +----------+ y --^> b
- if (is_noop || is_identity) {
- if (is_identity && !SafeToRemoveIdentity(*node)) {
- return;
- }
-
+ if (is_noop || (is_identity && SafeToRemoveIdentity(*node))) {
const auto& output_node_set = node_map_->GetOutputs(node_name);
const std::vector<NodeDef*> output_nodes(output_node_set.begin(),
output_node_set.end());
- const int num_outputs = output_nodes.size();
const int num_inputs = node->input_size();
-
- // Don't increase the number of edges in the graph.
- if (num_inputs * num_outputs > num_inputs + num_outputs) {
- return;
- }
std::vector<NodeDef*> input_nodes;
for (int i = 0; i < num_inputs; ++i) {
NodeDef* input_node = node_map_->GetNode(node->input(i));
@@ -294,44 +339,7 @@ void DependencyOptimizer::OptimizeNode(int node_idx,
input_nodes.push_back(input_node);
}
- // Make sure that we don't increase the number of edges that cross
- // device boundaries.
- if ((num_inputs == 1 && num_outputs > 1 &&
- input_nodes[0]->device() != node->device()) ||
- (num_inputs > 1 && num_outputs == 1 &&
- output_nodes[0]->device() != node->device())) {
- return;
- }
-
- // TODO(rmlarsen): Not all device crossings are equally expensive.
- // Assign a cost to each based on device affinity and compute a
- // cost before and after.
- const string& node_dev = node->device();
- int num_cross_in = 0;
- for (NodeDef* input_node : input_nodes) {
- num_cross_in += static_cast<int>(input_node->device() != node_dev);
- }
- int num_cross_out = 0;
- for (NodeDef* output_node : output_nodes) {
- num_cross_out += static_cast<int>(output_node->device() != node_dev);
- }
- if (is_identity && num_cross_in > 0 && num_cross_out > 0) {
- // This identity node follows a device crossing, so it might be
- // following a _Recv node after partioning. Do not remove such nodes,
- // unless they only have consumers on the same device as themselves.
- return;
- }
-
- // Make sure we do not increase the number of device crossings.
- const int num_cross_before = num_cross_in + num_cross_out;
- int num_cross_after = 0;
- for (NodeDef* input_node : input_nodes) {
- for (NodeDef* output_node : output_nodes) {
- num_cross_after +=
- static_cast<int>(input_node->device() != output_node->device());
- }
- }
- if (num_cross_after > num_cross_before) {
+ if (!BypassingNodeIsBeneficial(*node, input_nodes, output_nodes)) {
return;
}
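
The first check in BypassingNodeIsBeneficial above can be sanity-checked in isolation (a sketch, not part of the patch): bypassing rewires every input to every output, so it only helps when num_inputs * num_outputs does not exceed num_inputs + num_outputs.

    def bypass_does_not_add_edges(num_inputs, num_outputs):
      # Before bypassing: num_inputs + num_outputs edges touch the node.
      # After bypassing: every input connects directly to every output.
      return num_inputs * num_outputs <= num_inputs + num_outputs

    print(bypass_does_not_add_edges(1, 5))  # True: 5 edges after vs. 6 before
    print(bypass_does_not_add_edges(2, 2))  # True: 4 edges either way
    print(bypass_does_not_add_edges(2, 3))  # False: 6 edges after vs. 5 before
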
diff --git a/tensorflow/core/grappler/optimizers/dependency_optimizer.h b/tensorflow/core/grappler/optimizers/dependency_optimizer.h
index c97ff23e88..48cfa236af 100644
--- a/tensorflow/core/grappler/optimizers/dependency_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/dependency_optimizer.h
@@ -43,11 +43,17 @@ class DependencyOptimizer : public GraphOptimizer {
const GraphDef& optimized_graph, double result) override;
private:
+ // Returns true if bypassing the node does not increase the number of edges
+ // or the number of edges crossing a device boundary.
+ bool BypassingNodeIsBeneficial(
+ const NodeDef& node, const std::vector<NodeDef*>& input_nodes,
+ const std::vector<NodeDef*>& output_nodes) const;
+
// Returns true if node is not an Identity node or if it is an Identity
// that is safe to remove.
- bool SafeToRemoveIdentity(const NodeDef& node);
+ bool SafeToRemoveIdentity(const NodeDef& node) const;
// Returns true if it is safe to convert node to NoOp.
- bool SafeToConvertToNoOp(const NodeDef& node);
+ bool SafeToConvertToNoOp(const NodeDef& node) const;
// Removes all duplicate control dependencies.
void CleanControlInputs();
// Builds a map from the &optimized_graph_->node(i) to i.
diff --git a/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h b/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h
index 2fbdd76a77..2afb5df431 100644
--- a/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h
+++ b/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/protobuf/rewriter_config.pb.h"
namespace tensorflow {
namespace grappler {
@@ -44,16 +45,19 @@ const NodeScopeAndName ParseNodeScopeAndName(const string& node_name);
struct GraphOptimizerContext {
GraphOptimizerContext(const std::unordered_set<string>* nodes_to_preserve,
GraphDef* optimized_graph,
- GraphProperties* graph_properties, NodeMap* node_map)
+ GraphProperties* graph_properties, NodeMap* node_map,
+ RewriterConfig::Toggle opt_level)
: nodes_to_preserve(nodes_to_preserve),
optimized_graph(optimized_graph),
graph_properties(graph_properties),
- node_map(node_map) {}
+ node_map(node_map),
+ opt_level(opt_level) {}
const std::unordered_set<string>* nodes_to_preserve;
GraphDef* optimized_graph;
GraphProperties* graph_properties;
NodeMap* node_map;
+ RewriterConfig::Toggle opt_level;
};
Status GetInputNode(const GraphOptimizerContext& ctx, const string& input,
diff --git a/tensorflow/core/grappler/optimizers/graph_optimizer_stage_test.cc b/tensorflow/core/grappler/optimizers/graph_optimizer_stage_test.cc
index 3f5ab87a5a..34f28c7c27 100644
--- a/tensorflow/core/grappler/optimizers/graph_optimizer_stage_test.cc
+++ b/tensorflow/core/grappler/optimizers/graph_optimizer_stage_test.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/protobuf/rewriter_config.pb.h"
namespace tensorflow {
namespace grappler {
@@ -59,7 +60,8 @@ TEST_F(GraphOptimizerStageTest, OptimizedNodeName) {
GraphOptimizerContext ctx(/*nodes_to_preserve*/ nullptr,
/*optimized_graph*/ nullptr,
/*graph_properties*/ nullptr,
- /*node_name*/ nullptr);
+ /*node_name*/ nullptr,
+ /*opt_level*/ RewriterConfig::ON);
FakeOptimizerStage stage("my_opt", "my_stg", ctx);
const auto node = ParseNodeScopeAndName("a/b/c/Add");
@@ -94,7 +96,8 @@ TEST_F(GraphOptimizerStageTest, GetInputNodeAndProperties) {
GraphOptimizerContext ctx(/*nodes_to_preserve*/ nullptr,
/*optimized_graph*/ &item.graph,
/*graph_properties*/ &properties,
- /*node_name*/ &node_map);
+ /*node_name*/ &node_map,
+ /*opt_level*/ RewriterConfig::ON);
FakeOptimizerStage stage("my_opt", "my_stg", ctx);
NodeDef* add_node;
@@ -133,7 +136,8 @@ TEST_F(GraphOptimizerStageTest, AddNodes) {
GraphOptimizerContext ctx(/*nodes_to_preserve*/ nullptr,
/*optimized_graph*/ &item.graph,
/*graph_properties*/ &properties,
- /*node_name*/ &node_map);
+ /*node_name*/ &node_map,
+ /*opt_level*/ RewriterConfig::ON);
FakeOptimizerStage stage("my_opt", "my_stg", ctx);
NodeDef* add_node;
diff --git a/tensorflow/core/kernels/control_flow_ops.cc b/tensorflow/core/kernels/control_flow_ops.cc
index ebf844d75f..fd3a0ad422 100644
--- a/tensorflow/core/kernels/control_flow_ops.cc
+++ b/tensorflow/core/kernels/control_flow_ops.cc
@@ -108,6 +108,7 @@ REGISTER_GPU_HOST_KERNEL(bool);
REGISTER_GPU_HOST_REF_KERNEL(bool);
REGISTER_GPU_HOST_KERNEL(string);
REGISTER_GPU_HOST_REF_KERNEL(string);
+REGISTER_GPU_HOST_KERNEL(ResourceHandle);
#undef REGISTER_GPU_HOST_KERNEL
#undef REGISTER_GPU_HOST_REF_KERNEL
diff --git a/tensorflow/core/kernels/conv_grad_filter_ops.cc b/tensorflow/core/kernels/conv_grad_filter_ops.cc
index bdd08222d4..aca75176a5 100644
--- a/tensorflow/core/kernels/conv_grad_filter_ops.cc
+++ b/tensorflow/core/kernels/conv_grad_filter_ops.cc
@@ -404,9 +404,10 @@ class Conv2DCustomBackpropFilterOp : public OpKernel {
// image ('work_unit_size').
// TODO(andydavis)
+ // *) Get L3 cache size from device at runtime (30MB is from Ivy Bridge).
// *) Consider reducing 'target_working_set_size' if L3 is shared by
// other concurrently running tensorflow ops.
- const size_t target_working_set_size = Eigen::l3CacheSize() / sizeof(T);
+ const size_t target_working_set_size = (30LL << 20) / sizeof(T);
const size_t size_A = output_image_size * filter_total_size;
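A quick back-of-the-envelope check of the hard-coded budget (sketch only, not part of the patch): 30 MB divided by the element size gives the number of elements the working set may hold.

# Rough working-set arithmetic for the hard-coded 30MB L3 budget above.
l3_cache_bytes = 30 << 20            # 31,457,280 bytes
for dtype_size in (4, 8):            # float32, float64
    print(dtype_size, l3_cache_bytes // dtype_size)
# 4 -> 7,864,320 elements; 8 -> 3,932,160 elements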
diff --git a/tensorflow/core/kernels/conv_grad_input_ops.cc b/tensorflow/core/kernels/conv_grad_input_ops.cc
index 95301b170f..63a775afa8 100644
--- a/tensorflow/core/kernels/conv_grad_input_ops.cc
+++ b/tensorflow/core/kernels/conv_grad_input_ops.cc
@@ -420,8 +420,9 @@ class Conv2DCustomBackpropInputOp : public OpKernel {
const int output_image_size =
dims.spatial_dims[0].output_size * dims.spatial_dims[1].output_size;
- const size_t l2_cache_size = Eigen::l2CacheSize();
- const size_t l3_cache_size = Eigen::l3CacheSize();
+ // TODO(andydavis) Get L2/L3 cache sizes from device.
+ const size_t l2_cache_size = 256LL << 10;
+ const size_t l3_cache_size = 30LL << 20;
// Use L3 cache size as target working set size.
const size_t target_working_set_size = l3_cache_size / sizeof(T);
diff --git a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
index 703ef194a1..586677a2d6 100644
--- a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
@@ -189,14 +189,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
class Iterator : public DatasetIterator<Dataset> {
public:
explicit Iterator(const Params& params)
- : DatasetIterator<Dataset>(params),
- batch_results_((params.dataset->num_parallel_calls_ +
- params.dataset->batch_size_ - 1) /
- params.dataset->batch_size_) {
- for (int i = 0; i < batch_results_.size(); ++i) {
- batch_results_[i].Initialize(params.dataset->batch_size_);
- }
- }
+ : DatasetIterator<Dataset>(params) {}
~Iterator() override {
mutex_lock l(mu_);
@@ -216,17 +209,23 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
- mutex_lock external_l(external_mu_);
- mutex_lock l(mu_);
- EnsureRunnerThreadStarted(ctx);
- BatchResult* result = &batch_results_[ComputeIndex(input_batch_)];
- WaitForBatch(result, &l);
+ std::shared_ptr<BatchResult> result;
+ {
+ mutex_lock l(mu_);
+ EnsureRunnerThreadStarted(ctx);
+ while (batch_results_.empty() ||
+ batch_results_.front()->num_calls > 0) {
+ cond_var_.wait(l);
+ }
+ std::swap(result, batch_results_.front());
+ batch_results_.pop_front();
+ cond_var_.notify_all();
+ }
return ProcessBatch(ctx, result, out_tensors, end_of_sequence);
}
protected:
Status SaveInternal(IteratorStateWriter* writer) override {
- mutex_lock external_l(external_mu_);
mutex_lock l(mu_);
// Wait for all in-flight calls to complete.
while (num_calls_ > 0) {
@@ -236,10 +235,6 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
TF_RETURN_IF_ERROR(
writer->WriteScalar(full_name("call_counter"), call_counter_));
- TF_RETURN_IF_ERROR(
- writer->WriteScalar(full_name("input_batch"), input_batch_));
- TF_RETURN_IF_ERROR(
- writer->WriteScalar(full_name("output_batch"), output_batch_));
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("batch_results_size"),
batch_results_.size()));
for (size_t i = 0; i < batch_results_.size(); ++i) {
@@ -250,19 +245,13 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
- mutex_lock external_l(external_mu_);
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
TF_RETURN_IF_ERROR(
reader->ReadScalar(full_name("call_counter"), &call_counter_));
- TF_RETURN_IF_ERROR(
- reader->ReadScalar(full_name("input_batch"), &input_batch_));
- TF_RETURN_IF_ERROR(
- reader->ReadScalar(full_name("output_batch"), &output_batch_));
int64 batch_results_size;
TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("batch_results_size"),
&batch_results_size));
- CHECK_EQ(batch_results_.size(), batch_results_size);
for (int i = 0; i < batch_results_size; ++i) {
TF_RETURN_IF_ERROR(ReadBatchResult(ctx, reader, i));
}
@@ -271,21 +260,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
private:
struct BatchResult {
- mutex mu;
- bool end_of_input GUARDED_BY(mu);
- int64 num_elements GUARDED_BY(mu);
- std::vector<Tensor> output;
- bool output_allocated GUARDED_BY(mu);
- Status status GUARDED_BY(mu);
- // Used for coordination between the main thread and the callback
- // threads. In particular, the main thread will wait for the value
- // of `num_calls` to reach zero before processing the batch result.
- condition_variable cond_var; // access guarded by owner's mutex
- // Counts the number of outstanding calls for this batch.
- int64 num_calls; // access guarded by owner's mutex
-
- void Initialize(int64 batch_size) {
- mutex_lock l(mu);
+ explicit BatchResult(int64 batch_size) {
end_of_input = false;
num_calls = batch_size;
num_elements = 0;
@@ -297,12 +272,21 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
mutex_lock l(mu);
status.Update(s);
}
+
+ mutex mu;
+ bool end_of_input GUARDED_BY(mu);
+ int64 num_elements GUARDED_BY(mu);
+ std::vector<Tensor> output;
+ bool output_allocated GUARDED_BY(mu);
+ Status status GUARDED_BY(mu);
+ // Counts the number of outstanding calls for this batch.
+ int64 num_calls; // access guarded by owner's mutex
};
void Callback(const std::shared_ptr<IteratorContext>& ctx,
- BatchResult* result, std::vector<Tensor>* return_values,
+ const std::shared_ptr<BatchResult>& result,
+ const std::shared_ptr<std::vector<Tensor>>& return_values,
int64 offset, const Status& status) {
- std::unique_ptr<std::vector<Tensor>> cleanup_retvals(return_values);
result->UpdateStatus(status);
if (status.ok()) {
EnsureOutputAllocated(ctx, result, return_values);
@@ -340,15 +324,16 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
}
}
- void CallCompleted(BatchResult* result) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ void CallCompleted(const std::shared_ptr<BatchResult>& result)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
num_calls_--;
cond_var_.notify_all();
result->num_calls--;
- result->cond_var.notify_all();
}
void CallFunction(std::shared_ptr<IteratorContext> ctx,
- BatchResult* result, int64 offset) {
+ const std::shared_ptr<BatchResult>& result,
+ int64 offset) {
// Get the next input element.
std::vector<Tensor> input_element;
bool end_of_input;
@@ -370,9 +355,10 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
(*ctx->runner())(std::bind(
[this, result, offset](std::shared_ptr<IteratorContext> ctx,
std::vector<Tensor> input_element) {
- std::vector<Tensor>* return_values = new std::vector<Tensor>();
+ std::shared_ptr<std::vector<Tensor>> return_values(
+ new std::vector<Tensor>());
dataset()->captured_func_->RunAsync(
- ctx.get(), std::move(input_element), return_values,
+ ctx.get(), std::move(input_element), return_values.get(),
[this, ctx, result, return_values, offset](Status status) {
Callback(ctx, result, return_values, offset, status);
});
@@ -380,10 +366,6 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
ctx, std::move(input_element)));
}
- int64 ComputeIndex(int64 n) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- return n % batch_results_.size();
- }
-
Status CopyPartialBatch(Tensor* output, const Tensor& value,
int64 num_elements) {
switch (value.dtype()) {
@@ -417,9 +399,10 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
}
}
- void EnsureOutputAllocated(const std::shared_ptr<IteratorContext>& ctx,
- BatchResult* result,
- const std::vector<Tensor>* return_values) {
+ void EnsureOutputAllocated(
+ const std::shared_ptr<IteratorContext>& ctx,
+ const std::shared_ptr<BatchResult>& result,
+ const std::shared_ptr<std::vector<Tensor>>& return_values) {
mutex_lock l(result->mu);
if (result->output_allocated) {
return;
@@ -437,15 +420,15 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
result->output_allocated = true;
}
- Status ProcessBatch(IteratorContext* ctx, BatchResult* result,
+ int MaxBatchResults() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ return (dataset()->num_parallel_calls_ + dataset()->batch_size_ - 1) /
+ dataset()->batch_size_;
+ }
+
+ Status ProcessBatch(IteratorContext* ctx,
+ const std::shared_ptr<BatchResult>& result,
std::vector<Tensor>* out_tensors,
- bool* end_of_sequence) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- auto cleanup =
- gtl::MakeCleanup([this, result]() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- result->Initialize(dataset()->batch_size_);
- input_batch_++;
- cond_var_.notify_all();
- });
+ bool* end_of_sequence) {
mutex_lock l(result->mu);
if (result->num_elements == 0) {
*end_of_sequence = true;
@@ -489,8 +472,10 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
mutex_lock l(mu_);
while (true) {
while (!cancelled_ &&
- (num_calls_ == dataset()->num_parallel_calls_ ||
- (output_batch_ - input_batch_ == batch_results_.size()))) {
+ (num_calls_ >= dataset()->num_parallel_calls_ ||
+ batch_results_.size() > MaxBatchResults() ||
+ (batch_results_.size() == MaxBatchResults() &&
+ call_counter_ % dataset()->batch_size_ == 0))) {
cond_var_.wait(l);
}
@@ -499,31 +484,27 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
}
while (num_calls_ < dataset()->num_parallel_calls_ &&
- (output_batch_ - input_batch_ < batch_results_.size())) {
- BatchResult* result = &batch_results_[ComputeIndex(output_batch_)];
+ (batch_results_.size() < MaxBatchResults() ||
+ (batch_results_.size() == MaxBatchResults() &&
+ call_counter_ % dataset()->batch_size_ != 0))) {
+ if (call_counter_ % dataset()->batch_size_ == 0) {
+ batch_results_.emplace_back(
+ new BatchResult(dataset()->batch_size_));
+ }
+ std::shared_ptr<BatchResult> result = batch_results_.back();
int64 offset = call_counter_++ % dataset()->batch_size_;
num_calls_++;
mu_.unlock();
CallFunction(ctx, result, offset);
mu_.lock();
- if (offset + 1 == dataset()->batch_size_) {
- // Done scheduling calls for the current batch.
- output_batch_++;
- }
}
}
}
- void WaitForBatch(BatchResult* result, mutex_lock* l)
- EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- while (result->num_calls > 0) {
- result->cond_var.wait(*l);
- }
- }
-
Status ReadBatchResult(IteratorContext* ctx, IteratorStateReader* reader,
size_t index) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- BatchResult* result = &batch_results_[index];
+ batch_results_.emplace_back(new BatchResult(dataset()->batch_size_));
+ std::shared_ptr<BatchResult> result = batch_results_.back();
string prefix = strings::StrCat("batch_results_", index);
mutex_lock l(result->mu);
result->end_of_input = reader->Contains(
@@ -585,7 +566,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
Status WriteBatchResult(IteratorStateWriter* writer, size_t index)
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- BatchResult* result = &batch_results_[index];
+ std::shared_ptr<BatchResult> result = batch_results_[index];
string prefix = strings::StrCat("batch_results_", index);
mutex_lock l(result->mu);
if (result->end_of_input) {
@@ -646,21 +627,13 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
// user specified level of parallelism and there are slots available in
// the `batch_results_` buffer.
condition_variable cond_var_;
- // Used for serializing external parallelism.
- mutex external_mu_ ACQUIRED_BEFORE(mu_);
// Counts the number of outstanding calls for this batch.
int64 num_calls_ GUARDED_BY(mu_) = 0;
// Counts the total number of calls.
int64 call_counter_ GUARDED_BY(mu_) = 0;
std::unique_ptr<IteratorBase> input_impl_;
- // Identifies the next batch to be read by the caller.
- int64 input_batch_ GUARDED_BY(mu_) = 0;
- // Identifies the next batch to create.
- int64 output_batch_ GUARDED_BY(mu_) = 0;
- // Circular buffer for storing the (intermediate) batch results. When
- // using `input_batch_` and `output_batch_` to index into the buffer,
- // their value should be interpreted modulo the size of the buffer.
- std::vector<BatchResult> batch_results_ GUARDED_BY(mu_);
+ // Buffer for storing the (intermediate) batch results.
+ std::deque<std::shared_ptr<BatchResult>> batch_results_ GUARDED_BY(mu_);
std::unique_ptr<Thread> runner_thread_ GUARDED_BY(mu_);
bool cancelled_ GUARDED_BY(mu_) = false;
};
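The new MaxBatchResults() bound is a ceiling division: the deque holds at most ceil(num_parallel_calls / batch_size) in-flight batches. A small Python sketch of the same arithmetic (illustrative only):

# Mirrors MaxBatchResults() above.
def max_batch_results(num_parallel_calls, batch_size):
    return (num_parallel_calls + batch_size - 1) // batch_size

assert max_batch_results(num_parallel_calls=8, batch_size=3) == 3
assert max_batch_results(num_parallel_calls=6, batch_size=3) == 2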
diff --git a/tensorflow/core/kernels/data/prefetch_dataset_op.cc b/tensorflow/core/kernels/data/prefetch_dataset_op.cc
index e2b6aa590e..2bafb985ef 100644
--- a/tensorflow/core/kernels/data/prefetch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/prefetch_dataset_op.cc
@@ -39,8 +39,8 @@ class PrefetchDatasetOp : public UnaryDatasetOpKernel {
OP_REQUIRES_OK(
ctx, ParseScalarArgument<int64>(ctx, "buffer_size", &buffer_size));
OP_REQUIRES(ctx,
- buffer_size > 0 || buffer_size == PrefetchAutotuner::kAutoTune,
- errors::InvalidArgument("buffer_size must be > 0"));
+ buffer_size >= 0 || buffer_size == PrefetchAutotuner::kAutoTune,
+ errors::InvalidArgument("buffer_size must be >= 0"));
*output = new Dataset(ctx, input, buffer_size);
}
@@ -112,13 +112,13 @@ class PrefetchDatasetOp : public UnaryDatasetOpKernel {
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
- mutex_lock l(mu_);
- TF_RETURN_IF_ERROR(EnsurePrefetchThreadStarted(ctx));
-
- while (true) {
+ {
+ mutex_lock l(mu_);
+ TF_RETURN_IF_ERROR(EnsurePrefetchThreadStarted(ctx));
// Wait until the next element in the buffer has been
// produced, or we are shutting down.
- while (!cancelled_ && !prefetch_thread_finished_ && buffer_.empty()) {
+ while (!cancelled_ && buffer_.empty() && !prefetch_thread_finished_ &&
+ auto_tuner_.buffer_limit() != 0) {
auto_tuner_.RecordEmpty();
cond_var_.wait(l);
}
@@ -129,29 +129,20 @@ class PrefetchDatasetOp : public UnaryDatasetOpKernel {
}
if (!buffer_.empty()) {
- // A new element is available. Forward the status from
- // computing it, and (if we successfully got an element)
- // the output values.
- Status s = buffer_.front().status;
- if (s.ok()) {
- *out_tensors = std::move(buffer_.front().value);
- }
- auto_tuner_.RecordConsumption(buffer_.size());
- buffer_.pop_front();
- *end_of_sequence = false;
-
- // Wake the prefetch thread, in case it has been waiting
- // for space in the buffer.
- // Also wake up threads from other calls to GetNext.
- // TODO(mrry): Consider using different condition variables
- // for GetNext and Prefetch.
- cond_var_.notify_all();
- return s;
- } else if (prefetch_thread_finished_) {
+ return Consume(out_tensors, end_of_sequence);
+ }
+
+ if (prefetch_thread_finished_) {
*end_of_sequence = true;
return Status::OK();
}
+
+ DCHECK_EQ(auto_tuner_.buffer_limit(), 0);
}
+
+ mutex_lock parent_l(parent_mu_);
+ mutex_lock l(mu_);
+ return input_impl_->GetNext(ctx, out_tensors, end_of_sequence);
}
protected:
@@ -227,6 +218,26 @@ class PrefetchDatasetOp : public UnaryDatasetOpKernel {
std::vector<Tensor> value;
};
+ Status Consume(std::vector<Tensor>* out_tensors, bool* end_of_sequence)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ // A new element is available. Forward the status from computing it, and
+ // (if we successfully got an element) the output values.
+ Status s = buffer_.front().status;
+ if (s.ok()) {
+ *out_tensors = std::move(buffer_.front().value);
+ }
+ buffer_.pop_front();
+ *end_of_sequence = false;
+
+ // Wake the prefetch thread, in case it has been waiting for space
+ // in the buffer. Also wake up threads from other calls to GetNext.
+ //
+ // TODO(mrry): Consider using different condition variables for
+ // GetNext and Prefetch.
+ cond_var_.notify_all();
+ return s;
+ }
+
Status EnsurePrefetchThreadStarted(IteratorContext* ctx)
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
if (!prefetch_thread_) {
@@ -251,7 +262,7 @@ class PrefetchDatasetOp : public UnaryDatasetOpKernel {
{
mutex_lock l(mu_);
while (!cancelled_ &&
- buffer_.size() == auto_tuner_.buffer_limit()) {
+ buffer_.size() >= auto_tuner_.buffer_limit()) {
cond_var_.wait(l);
}
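With the relaxed check, a zero-sized prefetch buffer is accepted at iterator initialization (it previously raised InvalidArgument), and GetNext falls through to the upstream iterator when the buffer limit is zero. A minimal usage sketch with the public tf.data API, assuming the 1.x graph-mode idioms used elsewhere in this patch:

import tensorflow as tf

# buffer_size=0 is now a valid argument to prefetch().
dataset = tf.data.Dataset.range(10).prefetch(buffer_size=0)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()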
diff --git a/tensorflow/core/kernels/deep_conv2d.cc b/tensorflow/core/kernels/deep_conv2d.cc
index 85a9702ae7..1aa8c72d66 100644
--- a/tensorflow/core/kernels/deep_conv2d.cc
+++ b/tensorflow/core/kernels/deep_conv2d.cc
@@ -393,8 +393,9 @@ struct TransformFilters {
// Calculate filter transform batch based on cache/filter sizes.
- // Cache budget (based on L2 cache size).
- const int64 cache_size = Eigen::l2CacheSize() / sizeof(T);
+ // Cache budget (based on L2 cache size = 256KB).
+ // TODO(andydavis) Read cache size from system.
+ const int64 cache_size = (256LL << 10) / sizeof(T);
// Fixed cost.
const int64 filter_transform_matrix_size =
@@ -1017,8 +1018,9 @@ struct DeepConv2D<CPUDevice, T> {
const int64 filter_shard_size = filter_shards_row * filter_shards_col;
const int64 out_tile_spatial_size = out_tile_rows * out_tile_cols;
- // Cache budget (based on L2 cache size).
- const int64 cache_size = Eigen::l2CacheSize() / sizeof(T);
+ // Cache budget (based on L2 cache size = 256KB).
+ // TODO(andydavis) Read cache size from the system.
+ const int64 cache_size = (256LL << 10) / sizeof(T);
// Fixed costs.
const int64 tile_transform_matrix_size =
diff --git a/tensorflow/core/ops/control_flow_ops.cc b/tensorflow/core/ops/control_flow_ops.cc
index 81e9fcfa95..b8028291b4 100644
--- a/tensorflow/core/ops/control_flow_ops.cc
+++ b/tensorflow/core/ops/control_flow_ops.cc
@@ -145,13 +145,12 @@ REGISTER_OP("Enter")
auto* handle_data = c->input_handle_shapes_and_types(0);
if (handle_data != nullptr) {
c->set_output_handle_shapes_and_types(0, *handle_data);
- } else {
- // Otherwise, propagate shape if output is a constant.
- bool is_constant;
- TF_RETURN_IF_ERROR(c->GetAttr("is_constant", &is_constant));
- if (is_constant) {
- c->set_output(0, c->input(0));
- }
+ }
+ // Propagate shape if output is a constant.
+ bool is_constant;
+ TF_RETURN_IF_ERROR(c->GetAttr("is_constant", &is_constant));
+ if (is_constant) {
+ c->set_output(0, c->input(0));
}
return Status::OK();
diff --git a/tensorflow/core/platform/cloud/gcs_file_system.cc b/tensorflow/core/platform/cloud/gcs_file_system.cc
index 22ae6121e0..ec77861480 100644
--- a/tensorflow/core/platform/cloud/gcs_file_system.cc
+++ b/tensorflow/core/platform/cloud/gcs_file_system.cc
@@ -804,7 +804,9 @@ void GcsFileSystem::ResetFileBlockCache(size_t block_size_bytes,
mutex_lock l(block_cache_lock_);
file_block_cache_ =
MakeFileBlockCache(block_size_bytes, max_bytes, max_staleness_secs);
- stats_->Configure(this, &throttle_, file_block_cache_.get());
+ if (stats_ != nullptr) {
+ stats_->Configure(this, &throttle_, file_block_cache_.get());
+ }
}
// A helper function to build a FileBlockCache for GcsFileSystem.
diff --git a/tensorflow/core/platform/default/build_config.bzl b/tensorflow/core/platform/default/build_config.bzl
index 7d558fe880..a319ccbdbe 100644
--- a/tensorflow/core/platform/default/build_config.bzl
+++ b/tensorflow/core/platform/default/build_config.bzl
@@ -634,6 +634,7 @@ def tf_additional_cloud_op_deps():
"//tensorflow:with_gcp_support_ios_override": [],
"//tensorflow:with_gcp_support": [
"//tensorflow/contrib/cloud:bigquery_reader_ops_op_lib",
+ "//tensorflow/contrib/cloud:gcs_config_ops_op_lib",
],
"//conditions:default": [],
})
@@ -646,6 +647,7 @@ def tf_additional_cloud_kernel_deps():
"//tensorflow:with_gcp_support_ios_override": [],
"//tensorflow:with_gcp_support": [
"//tensorflow/contrib/cloud/kernels:bigquery_reader_ops",
+ "//tensorflow/contrib/cloud/kernels:gcs_config_ops",
],
"//conditions:default": [],
})
diff --git a/tensorflow/core/util/tensor_format.cc b/tensorflow/core/util/tensor_format.cc
index d4311d1ab0..a5f7ecf0d1 100644
--- a/tensorflow/core/util/tensor_format.cc
+++ b/tensorflow/core/util/tensor_format.cc
@@ -43,6 +43,10 @@ string ToString(TensorFormat format) {
return "NCHW_VECT_C";
case FORMAT_NHWC_VECT_W:
return "NHWC_VECT_W";
+ case FORMAT_HWNC:
+ return "HWNC";
+ case FORMAT_HWCN:
+ return "HWCN";
default:
LOG(FATAL) << "Invalid Format: " << static_cast<int32>(format);
return "INVALID_FORMAT";
@@ -80,6 +84,14 @@ bool FormatFromString(const string& format_str, TensorFormat* format) {
*format = FORMAT_NHWC_VECT_W;
return true;
}
+ if (format_str == "HWNC") {
+ *format = FORMAT_HWNC;
+ return true;
+ }
+ if (format_str == "HWCN") {
+ *format = FORMAT_HWCN;
+ return true;
+ }
return false;
}
diff --git a/tensorflow/core/util/tensor_format.h b/tensorflow/core/util/tensor_format.h
index d3d5602f92..918835e1fb 100644
--- a/tensorflow/core/util/tensor_format.h
+++ b/tensorflow/core/util/tensor_format.h
@@ -59,6 +59,12 @@ enum TensorFormat {
// In the future we may change the meaning of these enums to include vectors
// of other types such as int16x2, with op implementations automatically
// determining which format is implied based on the datatype.
+
+ // FORMAT_HWNC is for TPUs.
+ FORMAT_HWNC = 4,
+
+ // FORMAT_HWCN is for TPUs.
+ FORMAT_HWCN = 5,
};
// Tensor format for convolutional filters.
@@ -105,11 +111,11 @@ string ToString(FilterTensorFormat format);
inline int GetTensorSpatialDims(int num_dims, TensorFormat format) {
switch (format) {
case FORMAT_NHWC:
- return num_dims - 2; // Exclude N,C.
case FORMAT_NCHW:
+ case FORMAT_HWNC:
+ case FORMAT_HWCN:
return num_dims - 2; // Exclude N,C.
case FORMAT_NCHW_VECT_C:
- return num_dims - 3; // Exclude N,C,VectDim.
case FORMAT_NHWC_VECT_W:
// Note: the VECT_W is not counted as an independent spatial dim here,
// since it just a component of the width dimension.
@@ -132,6 +138,8 @@ inline int GetTensorDimsFromSpatialDims(int num_spatial_dims,
switch (format) {
case FORMAT_NHWC:
case FORMAT_NCHW:
+ case FORMAT_HWNC:
+ case FORMAT_HWCN:
return num_spatial_dims + 2; // Include N,C.
case FORMAT_NCHW_VECT_C:
case FORMAT_NHWC_VECT_W:
@@ -158,6 +166,10 @@ inline int GetTensorBatchDimIndex(int num_dims, TensorFormat format) {
case FORMAT_NCHW_VECT_C:
case FORMAT_NHWC_VECT_W:
return 0;
+ case FORMAT_HWNC:
+ return num_dims - 2;
+ case FORMAT_HWCN:
+ return num_dims - 1;
default:
LOG(FATAL) << "Unknown format " << format;
return -1; // Avoid compiler warning about missing return value
@@ -170,8 +182,10 @@ inline int GetTensorBatchDimIndex(int num_dims, TensorFormat format) {
inline int GetTensorFeatureDimIndex(int num_dims, TensorFormat format) {
switch (format) {
case FORMAT_NHWC:
+ case FORMAT_HWNC:
return num_dims - 1;
case FORMAT_NHWC_VECT_W:
+ case FORMAT_HWCN:
return num_dims - 2;
case FORMAT_NCHW:
case FORMAT_NCHW_VECT_C:
@@ -210,6 +224,9 @@ inline int GetTensorSpatialDimIndex(int num_dims, TensorFormat format,
case FORMAT_NCHW:
case FORMAT_NCHW_VECT_C:
return spatial_dim + 2;
+ case FORMAT_HWNC:
+ case FORMAT_HWCN:
+ return spatial_dim;
default:
LOG(FATAL) << "Unknown format " << format;
return -1; // Avoid compiler warning about missing return value
@@ -310,6 +327,32 @@ inline int32 GetTensorDimIndex(TensorFormat format, char dimension) {
LOG(FATAL) << "Invalid dimension: " << dimension;
return -1; // Avoid compiler warning about missing return value
}
+ } else if (format == FORMAT_HWNC) {
+ switch (dimension) {
+ case '0': return 0;
+ case '1': return 1;
+ case '2': return 2;
+ case 'H': return NUM_SPATIAL_DIMS - 2;
+ case 'W': return NUM_SPATIAL_DIMS - 1;
+ case 'N': return NUM_SPATIAL_DIMS;
+ case 'C': return NUM_SPATIAL_DIMS + 1;
+ default:
+ LOG(FATAL) << "Invalid dimension: " << dimension;
+ return -1; // Avoid compiler warning about missing return value
+ }
+ } else if (format == FORMAT_HWCN) {
+ switch (dimension) {
+ case '0': return 0;
+ case '1': return 1;
+ case '2': return 2;
+ case 'H': return NUM_SPATIAL_DIMS - 2;
+ case 'W': return NUM_SPATIAL_DIMS - 1;
+ case 'C': return NUM_SPATIAL_DIMS;
+ case 'N': return NUM_SPATIAL_DIMS + 1;
+ default:
+ LOG(FATAL) << "Invalid dimension: " << dimension;
+ return -1; // Avoid compiler warning about missing return value
+ }
} else {
LOG(FATAL) << "Invalid format: " << static_cast<int>(format);
return -1; // Avoid compiler warning about missing return value
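For a rank-4 tensor (two spatial dimensions), the cases added above imply the following axis positions; a small Python sketch for reference (illustrative, not part of the patch):

# Axis positions for NUM_SPATIAL_DIMS == 2, matching the new switch cases.
AXES = {
    "HWNC": {"H": 0, "W": 1, "N": 2, "C": 3},
    "HWCN": {"H": 0, "W": 1, "C": 2, "N": 3},
}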
diff --git a/tensorflow/core/util/tensor_format_test.cc b/tensorflow/core/util/tensor_format_test.cc
index 93902290eb..07cdce998a 100644
--- a/tensorflow/core/util/tensor_format_test.cc
+++ b/tensorflow/core/util/tensor_format_test.cc
@@ -26,10 +26,9 @@ namespace tensorflow {
{ val, #val }
std::pair<TensorFormat, const char*> test_data_formats[] = {
- EnumStringPair(FORMAT_NHWC),
- EnumStringPair(FORMAT_NCHW),
- EnumStringPair(FORMAT_NCHW_VECT_C),
- EnumStringPair(FORMAT_NHWC_VECT_W),
+ EnumStringPair(FORMAT_NHWC), EnumStringPair(FORMAT_NCHW),
+ EnumStringPair(FORMAT_NCHW_VECT_C), EnumStringPair(FORMAT_NHWC_VECT_W),
+ EnumStringPair(FORMAT_HWNC), EnumStringPair(FORMAT_HWCN),
};
std::pair<FilterTensorFormat, const char*> test_filter_formats[] = {
@@ -85,6 +84,16 @@ struct DimMaps {
{ 0, 2, 3, 1, { 2, 3, -1 } },
{ 0, 3, 4, 1, { 2, 3, 4 } }
};
+ StaCoExTensorDm kTdmHWNC[4] = { kTdmInvalid,
+ { 1, -1, 0, 2, { 0, -1, -1 } },
+ { 2, 0, 1, 3, { 0, 1, -1 } },
+ { 3, 1, 2, 4, { 0, 1, 2 } }
+ };
+ StaCoExTensorDm kTdmHWCN[4] = { kTdmInvalid,
+ { 2, -1, 0, 1, { 0, -1, -1 } },
+ { 3, 0, 1, 2, { 0, 1, -1 } },
+ { 4, 1, 2, 3, { 0, 1, 2 } }
+ };
#undef StaCoExTensorDm
#define StaCoExFilterDm static constexpr FilterDimMap
// 'H', 'W', 'I', 'O' 0 1 2
@@ -108,8 +117,10 @@ GetTensorDimMap(const int num_spatial_dims, const TensorFormat format) {
(format == FORMAT_NHWC ||
format == FORMAT_NHWC_VECT_W) ? DimMaps::kTdmNHWC[num_spatial_dims] :
(format == FORMAT_NCHW ||
- format == FORMAT_NCHW_VECT_C) ? DimMaps::kTdmNCHW[num_spatial_dims]
- : DimMaps::kTdmInvalid;
+ format == FORMAT_NCHW_VECT_C) ? DimMaps::kTdmNCHW[num_spatial_dims] :
+ (format == FORMAT_HWNC) ? DimMaps::kTdmHWNC[num_spatial_dims] :
+ (format == FORMAT_HWCN) ? DimMaps::kTdmHWCN[num_spatial_dims]
+ : DimMaps::kTdmInvalid;
}
inline constexpr const FilterDimMap&
@@ -126,6 +137,8 @@ GetFilterDimMap(const int num_spatial_dims,
constexpr TensorDimMap DimMaps::kTdmInvalid;
constexpr TensorDimMap DimMaps::kTdmNHWC[4];
constexpr TensorDimMap DimMaps::kTdmNCHW[4];
+constexpr TensorDimMap DimMaps::kTdmHWNC[4];
+constexpr TensorDimMap DimMaps::kTdmHWCN[4];
constexpr FilterDimMap DimMaps::kFdmInvalid;
constexpr FilterDimMap DimMaps::kFdmHWIO[4];
constexpr FilterDimMap DimMaps::kFdmOIHW[4];
diff --git a/tensorflow/docs_src/mobile/tflite/demo_android.md b/tensorflow/docs_src/mobile/tflite/demo_android.md
index 7f2f8882a2..480d66bbb6 100644
--- a/tensorflow/docs_src/mobile/tflite/demo_android.md
+++ b/tensorflow/docs_src/mobile/tflite/demo_android.md
@@ -58,6 +58,9 @@ To get a model, either:
Now you can build and run the demo app.
+Some additional details are available on the
+[TF Lite Android App page](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/java/demo/README.md).
+
## Build TensorFlow Lite and the demo app from source
diff --git a/tensorflow/examples/tutorials/mnist/BUILD b/tensorflow/examples/tutorials/mnist/BUILD
index d7bc6a5a7d..d4070fdd1e 100644
--- a/tensorflow/examples/tutorials/mnist/BUILD
+++ b/tensorflow/examples/tutorials/mnist/BUILD
@@ -97,7 +97,7 @@ py_binary(
py_test(
name = "fully_connected_feed_test",
- size = "small",
+ size = "medium",
srcs = [
"fully_connected_feed.py",
],
diff --git a/tensorflow/python/data/kernel_tests/BUILD b/tensorflow/python/data/kernel_tests/BUILD
index e86c2f6993..3bde62fa1d 100644
--- a/tensorflow/python/data/kernel_tests/BUILD
+++ b/tensorflow/python/data/kernel_tests/BUILD
@@ -179,6 +179,7 @@ tf_py_test(
size = "small",
srcs = ["prefetch_dataset_op_test.py"],
additional_deps = [
+ "@absl_py//absl/testing:parameterized",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dataset_ops_gen",
diff --git a/tensorflow/python/data/kernel_tests/prefetch_dataset_op_test.py b/tensorflow/python/data/kernel_tests/prefetch_dataset_op_test.py
index 646324cb95..63a0830272 100644
--- a/tensorflow/python/data/kernel_tests/prefetch_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/prefetch_dataset_op_test.py
@@ -17,6 +17,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from absl.testing import parameterized
+
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -24,35 +26,33 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class PrefetchDatasetTest(test.TestCase):
+class PrefetchDatasetTest(test.TestCase, parameterized.TestCase):
- def testBufferSize(self):
- buffer_size = array_ops.placeholder(dtypes.int64, shape=[])
+ @parameterized.parameters((-1), (0), (5))
+ def testBufferSize(self, buffer_size):
+ buffer_size_t = array_ops.placeholder(dtypes.int64, shape=[])
iterator = dataset_ops.Dataset.range(10).prefetch(
- buffer_size=buffer_size).make_initializable_iterator()
+ buffer_size=buffer_size_t).make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
with self.test_session() as sess:
- sess.run(init_op, feed_dict={buffer_size: 5})
+ sess.run(init_op, feed_dict={buffer_size_t: buffer_size})
for m in range(10):
self.assertEqual(m, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
- def testInvalidBufferSize(self):
- buffer_size = array_ops.placeholder(dtypes.int64, shape=[])
+ @parameterized.parameters((-2), (-42))
+ def testInvalidBufferSize(self, buffer_size):
+ buffer_size_t = array_ops.placeholder(dtypes.int64, shape=[])
iterator = dataset_ops.Dataset.range(10).prefetch(
- buffer_size=buffer_size).make_initializable_iterator()
+ buffer_size=buffer_size_t).make_initializable_iterator()
init_op = iterator.initializer
with self.assertRaisesRegexp(errors.InvalidArgumentError, "buffer_size"):
with self.test_session() as sess:
- sess.run(init_op, feed_dict={buffer_size: 0})
-
- with self.assertRaisesRegexp(errors.InvalidArgumentError, "buffer_size"):
- with self.test_session() as sess:
- sess.run(init_op, feed_dict={buffer_size: -5})
+ sess.run(init_op, feed_dict={buffer_size_t: buffer_size})
if __name__ == "__main__":
diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py
index d0deed5ede..9e7af878d3 100644
--- a/tensorflow/python/data/ops/dataset_ops.py
+++ b/tensorflow/python/data/ops/dataset_ops.py
@@ -108,12 +108,7 @@ class Dataset(object):
if shared_name is None:
shared_name = ""
iterator_resource = gen_dataset_ops.iterator(
- container="",
- shared_name=shared_name,
- output_types=nest.flatten(
- sparse.as_dense_types(self.output_types, self.output_classes)),
- output_shapes=nest.flatten(
- sparse.as_dense_shapes(self.output_shapes, self.output_classes)))
+ container="", shared_name=shared_name, **flat_structure(self))
with ops.colocate_with(iterator_resource):
initializer = gen_dataset_ops.make_iterator(self._as_variant_tensor(),
iterator_resource)
@@ -171,13 +166,8 @@ class Dataset(object):
return iterator_ops.Iterator(
gen_dataset_ops.one_shot_iterator(
- dataset_factory=_make_dataset,
- output_types=nest.flatten(
- sparse.as_dense_types(self.output_types, self.output_classes)),
- output_shapes=nest.flatten(
- sparse.as_dense_shapes(self.output_shapes,
- self.output_classes))), None,
- self.output_types, self.output_shapes, self.output_classes)
+ dataset_factory=_make_dataset, **flat_structure(self)),
+ None, self.output_types, self.output_shapes, self.output_classes)
@abc.abstractproperty
def output_classes(self):
@@ -1158,6 +1148,121 @@ class SparseTensorSliceDataset(Dataset):
return (dtypes.int64, self._sparse_tensor.dtype, dtypes.int64)
+class StructuredFunctionWrapper(object):
+ """A wrapper for `Defun` that supports structured arguments and return values.
+ """
+
+ def __init__(self, func, transformation_name, dataset=None,
+ input_classes=None, input_shapes=None, input_types=None,
+ add_to_graph=True):
+ """Creates a new `StructuredFunctionWrapper` for the given function.
+
+ Args:
+ func: A function from a nested structure to another nested structure.
+ transformation_name: Human-readable name of the transformation in which
+ this function is being instantiated, for error messages.
+ dataset: (Optional.) A @{tf.data.Dataset}. If given, the structure of this
+ dataset will be used as the structure for `func` arguments; otherwise
+ `input_classes`, `input_shapes`, and `input_types` must be defined.
+ input_classes: (Optional.) A nested structure of `type`. If given, this
+ argument defines the Python types for `func` arguments.
+ input_shapes: (Optional.) A nested structure of @{tf.TensorShape}. If
+ given, this argument defines the shapes and structure for `func`
+ arguments.
+ input_types: (Optional.) A nested structure of @{tf.DType}. If given, this
+ argument defines the element types and structure for `func` arguments.
+ add_to_graph: (Optional.) If `True`, the function will be added to the
+ default graph.
+
+ Raises:
+ ValueError: If an invalid combination of `dataset`, `input_classes`,
+ `input_shapes`, and `input_types` is passed.
+ """
+ if dataset is None:
+ if input_classes is None or input_shapes is None or input_types is None:
+ raise ValueError("Either `dataset`, or all of `input_classes`, "
+ "`input_shapes`, and `input_types` must be specified.")
+ self._input_shapes = input_shapes
+ self._input_types = input_types
+ self._input_classes = input_classes
+ else:
+ if not (input_classes is None and input_shapes is None and
+ input_types is None):
+ raise ValueError("Either `dataset`, or all of `input_classes`, "
+ "`input_shapes`, and `input_types` must be specified.")
+ self._input_shapes = dataset.output_shapes
+ self._input_types = dataset.output_types
+ self._input_classes = dataset.output_classes
+
+ @function.Defun(*defun_args(
+ input_types=self._input_types, input_classes=self._input_classes))
+ def tf_data_structured_function_wrapper(*args):
+ """Wrapper for passing nested structures to and from tf.data functions."""
+ nested_args = restructure_args(args,
+ input_shapes=self._input_shapes,
+ input_types=self._input_types,
+ input_classes=self._input_classes)
+ ret = func(*nested_args)
+ # If `func` returns a list of tensors, `nest.flatten()` and
+ # `ops.convert_to_tensor()` would conspire to attempt to stack
+ # those tensors into a single tensor, because the customized
+ # version of `nest.flatten()` does not recurse into lists. Since
+ # it is more likely that the list arose from returning the
+ # result of an operation (such as `tf.py_func()`) that returns a
+ # list of not-necessarily-stackable tensors, we treat the
+ # returned value as a `tuple` instead. A user wishing to pack
+ # the return value into a single tensor can use an explicit
+ # `tf.stack()` before returning.
+ if isinstance(ret, list):
+ ret = tuple(ret)
+
+ # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
+ # values to tensors.
+ ret = nest.pack_sequence_as(ret, [
+ sparse_tensor_lib.SparseTensor.from_value(t)
+ if sparse_tensor_lib.is_sparse(t) else ops.convert_to_tensor(t)
+ for t in nest.flatten(ret)
+ ])
+
+ self._output_classes = sparse.get_classes(ret)
+ self._output_shapes = nest.pack_sequence_as(
+ ret, [t.get_shape() for t in nest.flatten(ret)])
+ self._output_types = nest.pack_sequence_as(
+ ret, [t.dtype for t in nest.flatten(ret)])
+
+ _warn_if_collections(transformation_name)
+
+ # Serialize any sparse tensors.
+ ret = nest.pack_sequence_as(
+ ret, [t for t in nest.flatten(sparse.serialize_sparse_tensors(ret))])
+ return nest.flatten(ret)
+
+ self._function = tf_data_structured_function_wrapper
+ if add_to_graph:
+ self._function.add_to_graph(ops.get_default_graph())
+ else:
+ # Use the private method that will execute
+ # `tf_data_structured_function_wrapper` but delay adding it to the graph
+ # in case (e.g.) we need to rerun the function.
+ self._function._create_definition_if_needed() # pylint: disable=protected-access
+
+ @property
+ def output_classes(self):
+ return self._output_classes
+
+ @property
+ def output_shapes(self):
+ return self._output_shapes
+
+ @property
+ def output_types(self):
+ return self._output_types
+
+ @property
+ def function(self):
+ return self._function
+
+
def flat_structure(dataset):
"""Helper for setting `output_shapes` and `output_types` attrs of Dataset ops.
@@ -1182,6 +1287,109 @@ def flat_structure(dataset):
}
+# TODO(mrry): Investigate adding a `Defun` wrapper that combines
+# `defun_args()`, `restructure_args()`, and a future helper that consumes the
+# outputs of the wrapped function.
+def defun_args(dataset=None, input_types=None, input_classes=None):
+ """Returns a flat list of @{tf.DType} for a given element structure.
+
+ The expected usage for an example function is as follows:
+
+ ```python
+ input_dataset = ... # A `tf.data.Dataset`.
+
+ @function.Defun(*defun_args(input_dataset))
+ def tf_example_func(*args):
+ nested_args = restructure_args(args, input_dataset)
+ # [Call `example_func(*nested_args)` and handle its return values.]
+ ```
+
+ Either `dataset`, or both of `input_types` and `input_classes` must be
+ specified. If `dataset` is not specified, the structures of `input_types` and
+ `input_classes` must be compatible.
+
+ Args:
+ dataset: (Optional.) A @{tf.data.Dataset} whose element structure should
+ be flattened.
+ input_types: (Optional.) A nested structure of @{tf.DType} with the desired
+ structure and types for each argument.
+ input_classes: (Optional.) A nested structure of `type` with the desired
+ structure and classes for each argument.
+
+ Returns:
+ A flat list of @{tf.DType} for the given element structure.
+ """
+ if input_types is None:
+ assert dataset is not None
+ assert input_classes is None
+ input_types = dataset.output_types
+ input_classes = dataset.output_classes
+ else:
+ assert input_types is not None and input_classes is not None
+ return nest.flatten(
+ sparse.as_dense_types(input_types, input_classes))
+
+
+def restructure_args(args, dataset=None, input_shapes=None, input_types=None,
+ input_classes=None):
+ """Converts a flat tuple of arguments into a given structure.
+
+ The intended use is to bridge between the flat tuple of unshaped @{tf.Tensor}
+ arguments that a `Defun` receives and the potentially nested structures that
+ `tf.data` functions expect.
+
+ The expected usage for an example function is as follows:
+
+ ```python
+ input_dataset = ... # A `tf.data.Dataset`.
+
+ @function.Defun(*defun_args(input_dataset))
+ def tf_example_func(*args):
+ nested_args = restructure_args(args, input_dataset)
+ ret = example_func(*nested_args)
+ # [Destructure and handle the return values from `example_func()`.]
+ ```
+
+ Either `dataset`, or all of `input_shapes`, `input_types` and `input_classes`
+ must be specified. If `dataset` is not specified, the structures of
+ `input_shapes`, `input_types` and `input_classes` must be compatible.
+
+ Args:
+ args: A flat tuple of @{tf.Tensor} objects, representing the arguments
+ to a TensorFlow function.
+ dataset: (Optional.) A @{tf.data.Dataset} whose element structure matches
+ the desired structure of the arguments.
+ input_shapes: (Optional.) A nested structure of @{tf.TensorShape} with the
+ desired structure and static shapes for each argument.
+ input_types: (Optional.) A nested structure of @{tf.DType} with the desired
+ structure and types for each argument.
+ input_classes: (Optional.) A nested structure of `type` with the desired
+ structure and classes for each argument.
+
+ Returns:
+ A nested structure representing the arguments.
+ """
+ if input_shapes is None:
+ assert dataset is not None
+ assert input_types is None and input_classes is None
+ input_shapes = dataset.output_shapes
+ input_types = dataset.output_types
+ input_classes = dataset.output_classes
+ else:
+ assert input_types is not None and input_classes is not None
+
+ dense_shapes = sparse.as_dense_shapes(input_shapes, input_classes)
+ for arg, shape in zip(args, nest.flatten(dense_shapes)):
+ arg.set_shape(shape)
+
+ nested_args = nest.pack_sequence_as(input_classes, args)
+ nested_args = sparse.deserialize_sparse_tensors(
+ nested_args, input_types, input_shapes, input_classes)
+ if not _should_unpack_args(nested_args):
+ nested_args = (nested_args,)
+ return nested_args
+
+
class _GeneratorDataset(Dataset):
"""A `Dataset` that generates elements by invoking a function."""
@@ -1214,137 +1422,26 @@ class _GeneratorDataset(Dataset):
init_args_types = nest.pack_sequence_as(
init_args, [t.dtype for t in nest.flatten(init_args)])
- @function.Defun(*nest.flatten(
- sparse.as_dense_types(init_args_types, init_args_classes)))
- def tf_init_func(*args):
- """A wrapper for Defun that facilitates shape inference."""
- dense_shapes = sparse.as_dense_shapes(init_args_shapes, init_args_classes)
- for arg, shape in zip(args, nest.flatten(dense_shapes)):
- arg.set_shape(shape)
-
- nested_args = nest.pack_sequence_as(init_args_classes, args)
- nested_args = sparse.deserialize_sparse_tensors(
- nested_args, init_args_types, init_args_shapes, init_args_classes)
- if _should_unpack_args(nested_args):
- ret = init_func(*nested_args)
- else:
- ret = init_func(nested_args)
-
- # If `init_func` returns a list of tensors, `nest.flatten()` and
- # `ops.convert_to_tensor()` would conspire to attempt to stack
- # those tensors into a single tensor, because the customized
- # version of `nest.flatten()` does not recurse into lists. Since
- # it is more likely that the list arose from returning the
- # result of an operation (such as `tf.py_func()`) that returns a
- # list of not-necessarily-stackable tensors, we treat the
- # returned value is a `tuple` instead. A user wishing to pack
- # the return value into a single tensor can use an explicit
- # `tf.stack()` before returning.
- if isinstance(ret, list):
- ret = tuple(ret)
-
- # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
- # values to tensors.
- ret = nest.pack_sequence_as(ret, [
- sparse_tensor_lib.SparseTensor.from_value(t)
- if sparse_tensor_lib.is_sparse(t) else ops.convert_to_tensor(t)
- for t in nest.flatten(ret)
- ])
-
- self._state_classes = sparse.get_classes(ret)
- self._state_shapes = nest.pack_sequence_as(
- ret, [t.get_shape() for t in nest.flatten(ret)])
- self._state_types = nest.pack_sequence_as(
- ret, [t.dtype for t in nest.flatten(ret)])
-
- # Serialize any sparse tensors.
- ret = nest.pack_sequence_as(
- ret, [t for t in nest.flatten(sparse.serialize_sparse_tensors(ret))])
- return nest.flatten(ret)
-
- self._init_func = tf_init_func
- self._init_func.add_to_graph(ops.get_default_graph())
-
- # These members will be initialized by `tf_next_func`.
- self._output_classes = None
- self._output_shapes = None
- self._output_types = None
-
- @function.Defun(*nest.flatten(
- sparse.as_dense_types(self._state_types, self._state_classes)))
- def tf_next_func(*args):
- """A wrapper for Defun that facilitates shape inference."""
- # Pass in shape information from the input_dataset.
- dense_shapes = sparse.as_dense_shapes(self._state_shapes,
- self._state_classes)
- for arg, shape in zip(args, nest.flatten(dense_shapes)):
- arg.set_shape(shape)
-
- nested_args = nest.pack_sequence_as(self._state_classes, args)
- nested_args = sparse.deserialize_sparse_tensors(
- nested_args, self._state_types, self._state_shapes,
- self._state_classes)
- if _should_unpack_args(nested_args):
- ret = next_func(*nested_args)
- else:
- ret = next_func(nested_args)
-
- # If `next_func` returns a list of tensors, `nest.flatten()` and
- # `ops.convert_to_tensor()` would conspire to attempt to stack
- # those tensors into a single tensor, because the customized
- # version of `nest.flatten()` does not recurse into lists. Since
- # it is more likely that the list arose from returning the
- # result of an operation (such as `tf.py_func()`) that returns a
- # list of not-necessarily-stackable tensors, we treat the
- # returned value is a `tuple` instead. A user wishing to pack
- # the return value into a single tensor can use an explicit
- # `tf.stack()` before returning.
- if isinstance(ret, list):
- ret = tuple(ret)
-
- # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
- # values to tensors.
- ret = nest.pack_sequence_as(ret, [
- sparse_tensor_lib.SparseTensor.from_value(t)
- if sparse_tensor_lib.is_sparse(t) else ops.convert_to_tensor(t)
- for t in nest.flatten(ret)
- ])
-
- self._output_classes = sparse.get_classes(ret)
- self._output_shapes = nest.pack_sequence_as(
- ret, [t.get_shape() for t in nest.flatten(ret)])
- self._output_types = nest.pack_sequence_as(
- ret, [t.dtype for t in nest.flatten(ret)])
-
- # Serialize any sparse tensors.
- ret = nest.pack_sequence_as(
- ret, [t for t in nest.flatten(sparse.serialize_sparse_tensors(ret))])
- return nest.flatten(ret)
-
- self._next_func = tf_next_func
- self._next_func.add_to_graph(ops.get_default_graph())
-
- @function.Defun(*nest.flatten(
- sparse.as_dense_types(self._state_types, self._state_classes)))
- def tf_finalize_func(*args):
- """A wrapper for Defun that facilitates shape inference."""
- # Pass in shape information from the state.
- dense_shapes = sparse.as_dense_shapes(self._state_shapes,
- self._state_classes)
- for arg, shape in zip(args, nest.flatten(dense_shapes)):
- arg.set_shape(shape)
-
- nested_args = nest.pack_sequence_as(self._state_classes, args)
- nested_args = sparse.deserialize_sparse_tensors(
- nested_args, self._state_types, self._state_shapes,
- self._state_classes)
- if _should_unpack_args(nested_args):
- return finalize_func(*nested_args)
- else:
- return finalize_func(nested_args)
-
- self._finalize_func = tf_finalize_func
- self._finalize_func.add_to_graph(ops.get_default_graph())
+ wrapped_init_func = StructuredFunctionWrapper(
+ init_func, "GeneratorDataset", input_classes=init_args_classes,
+ input_shapes=init_args_shapes, input_types=init_args_types)
+ self._state_classes = wrapped_init_func.output_classes
+ self._state_shapes = wrapped_init_func.output_shapes
+ self._state_types = wrapped_init_func.output_types
+ self._init_func = wrapped_init_func.function
+
+ wrapped_next_func = StructuredFunctionWrapper(
+ next_func, "GeneratorDataset", input_classes=self._state_classes,
+ input_shapes=self._state_shapes, input_types=self._state_types)
+ self._output_classes = wrapped_next_func.output_classes
+ self._output_shapes = wrapped_next_func.output_shapes
+ self._output_types = wrapped_next_func.output_types
+ self._next_func = wrapped_next_func.function
+
+ wrapped_finalize_func = StructuredFunctionWrapper(
+ finalize_func, "GeneratorDataset", input_classes=self._state_classes,
+ input_shapes=self._state_shapes, input_types=self._state_types)
+ self._finalize_func = wrapped_finalize_func.function
def _as_variant_tensor(self):
return gen_dataset_ops.generator_dataset(
@@ -1498,6 +1595,7 @@ class RangeDataset(Dataset):
self._parse_args(*args)
def _parse_args(self, *args):
+ """Parse arguments according to the same rules as the `range()` builtin."""
if len(args) == 1:
self._start = self._build_tensor(0, "start")
self._stop = self._build_tensor(args[0], "stop")
@@ -1823,7 +1921,7 @@ def _padding_value_to_tensor(value, output_type):
def _default_padding(input_dataset):
-
+ """Returns default padding tensors in a structure matching `input_dataset`."""
def make_zero(t):
if t.base_dtype == dtypes.string:
return ""
@@ -1949,66 +2047,12 @@ class MapDataset(Dataset):
super(MapDataset, self).__init__()
self._input_dataset = input_dataset
- self._output_classes = None
- self._output_shapes = None
- self._output_types = None
-
- @function.Defun(*nest.flatten(
- sparse.as_dense_types(input_dataset.output_types,
- input_dataset.output_classes)))
- def tf_map_func(*args):
- """A wrapper for Defun that facilitates shape inference."""
- # Pass in shape information from the input_dataset.
- dense_shapes = sparse.as_dense_shapes(input_dataset.output_shapes,
- input_dataset.output_classes)
- for arg, shape in zip(args, nest.flatten(dense_shapes)):
- arg.set_shape(shape)
-
- nested_args = nest.pack_sequence_as(input_dataset.output_types, args)
- nested_args = sparse.deserialize_sparse_tensors(
- nested_args, input_dataset.output_types, input_dataset.output_shapes,
- input_dataset.output_classes)
- if _should_unpack_args(nested_args):
- ret = map_func(*nested_args)
- else:
- ret = map_func(nested_args)
-
- # If `map_func` returns a list of tensors, `nest.flatten()` and
- # `ops.convert_to_tensor()` would conspire to attempt to stack
- # those tensors into a single tensor, because the customized
- # version of `nest.flatten()` does not recurse into lists. Since
- # it is more likely that the list arose from returning the
- # result of an operation (such as `tf.py_func()`) that returns a
- # list of not-necessarily-stackable tensors, we treat the
- # returned value is a `tuple` instead. A user wishing to pack
- # the return value into a single tensor can use an explicit
- # `tf.stack()` before returning.
- if isinstance(ret, list):
- ret = tuple(ret)
-
- # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
- # values to tensors.
- ret = nest.pack_sequence_as(ret, [
- sparse_tensor_lib.SparseTensor.from_value(t)
- if sparse_tensor_lib.is_sparse(t) else ops.convert_to_tensor(t)
- for t in nest.flatten(ret)
- ])
-
- self._output_classes = sparse.get_classes(ret)
- self._output_shapes = nest.pack_sequence_as(
- ret, [t.get_shape() for t in nest.flatten(ret)])
- self._output_types = nest.pack_sequence_as(
- ret, [t.dtype for t in nest.flatten(ret)])
-
- _warn_if_collections("Dataset.map()")
-
- # Serialize any sparse tensors.
- ret = nest.pack_sequence_as(
- ret, [t for t in nest.flatten(sparse.serialize_sparse_tensors(ret))])
- return nest.flatten(ret)
-
- self._map_func = tf_map_func
- self._map_func.add_to_graph(ops.get_default_graph())
+ wrapped_func = StructuredFunctionWrapper(
+ map_func, "Dataset.map()", input_dataset)
+ self._output_classes = wrapped_func.output_classes
+ self._output_shapes = wrapped_func.output_shapes
+ self._output_types = wrapped_func.output_types
+ self._map_func = wrapped_func.function
def _as_variant_tensor(self):
input_t = self._input_dataset._as_variant_tensor() # pylint: disable=protected-access
@@ -2061,39 +2105,20 @@ class FlatMapDataset(Dataset):
super(FlatMapDataset, self).__init__()
self._input_dataset = input_dataset
- @function.Defun(*nest.flatten(
- sparse.as_dense_types(input_dataset.output_types,
- input_dataset.output_classes)))
- def tf_map_func(*args):
- """A wrapper for Defun that facilitates shape inference."""
- # Pass in shape information from the input_dataset.
- dense_shapes = sparse.as_dense_shapes(input_dataset.output_shapes,
- input_dataset.output_classes)
- for arg, shape in zip(args, nest.flatten(dense_shapes)):
- arg.set_shape(shape)
-
- nested_args = nest.pack_sequence_as(input_dataset.output_types, args)
- nested_args = sparse.deserialize_sparse_tensors(
- nested_args, input_dataset.output_types, input_dataset.output_shapes,
- input_dataset.output_classes)
- if _should_unpack_args(nested_args):
- dataset = map_func(*nested_args)
- else:
- dataset = map_func(nested_args)
-
+ # TODO(b/110122868): When we handle nested datasets natively as the return
+ # value from `map_func`, we can avoid needing this wrapper.
+ def map_func_wrapper(*args):
+ dataset = map_func(*args)
if not isinstance(dataset, Dataset):
raise TypeError("`map_func` must return a `Dataset` object.")
-
- _warn_if_collections(self._transformation_name())
-
self._output_classes = dataset.output_classes
- self._output_types = dataset.output_types
self._output_shapes = dataset.output_shapes
-
+ self._output_types = dataset.output_types
return dataset._as_variant_tensor() # pylint: disable=protected-access
- self._map_func = tf_map_func
- self._map_func.add_to_graph(ops.get_default_graph())
+ wrapped_func = StructuredFunctionWrapper(
+ map_func_wrapper, self._transformation_name(), input_dataset)
+ self._map_func = wrapped_func.function
def _as_variant_tensor(self):
return gen_dataset_ops.flat_map_dataset(
@@ -2150,38 +2175,13 @@ class FilterDataset(Dataset):
"""See `Dataset.filter()` for details."""
super(FilterDataset, self).__init__()
self._input_dataset = input_dataset
-
- @function.Defun(*nest.flatten(
- sparse.as_dense_types(input_dataset.output_types,
- input_dataset.output_classes)))
- def tf_predicate(*args):
- """A wrapper for Defun that facilitates shape inference."""
- # Pass in shape information from the input_dataset.
- dense_shapes = sparse.as_dense_shapes(input_dataset.output_shapes,
- input_dataset.output_classes)
- for arg, shape in zip(args, nest.flatten(dense_shapes)):
- arg.set_shape(shape)
-
- nested_args = nest.pack_sequence_as(input_dataset.output_types, args)
- nested_args = sparse.deserialize_sparse_tensors(
- nested_args, input_dataset.output_types, input_dataset.output_shapes,
- input_dataset.output_classes)
- if _should_unpack_args(nested_args):
- ret = predicate(*nested_args)
- else:
- ret = predicate(nested_args)
-
- ret = ops.convert_to_tensor(ret, dtype=dtypes.bool)
- if not (ret.dtype == dtypes.bool and
- ret.shape.is_compatible_with(tensor_shape.scalar())):
- raise ValueError("`predicate` must return a scalar boolean tensor.")
-
- _warn_if_collections("Dataset.filter()")
-
- return ret
-
- self._predicate = tf_predicate
- self._predicate.add_to_graph(ops.get_default_graph())
+ wrapped_func = StructuredFunctionWrapper(
+ predicate, "Dataset.filter()", input_dataset)
+ if not (
+ wrapped_func.output_types == dtypes.bool and
+ wrapped_func.output_shapes.is_compatible_with(tensor_shape.scalar())):
+ raise ValueError("`predicate` must return a scalar boolean tensor.")
+ self._predicate = wrapped_func.function
def _as_variant_tensor(self):
return gen_dataset_ops.filter_dataset(
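For readers following the refactor above: the new `FilterDataset` code delegates tracing to the structured-function wrapper and then only checks that the traced predicate yields a scalar boolean. Below is a minimal standalone sketch of that check; `traced_output_type` and `traced_output_shape` are hypothetical stand-ins for whatever the wrapper records, not the actual wrapper API.

```python
import tensorflow as tf

def check_predicate_signature(traced_output_type, traced_output_shape):
  """Mirrors the scalar-boolean validation performed by FilterDataset above."""
  if not (traced_output_type == tf.bool and
          traced_output_shape.is_compatible_with(tf.TensorShape([]))):
    raise ValueError("`predicate` must return a scalar boolean tensor.")

# A predicate traced to a scalar tf.bool passes; anything else raises.
check_predicate_signature(tf.bool, tf.TensorShape([]))
```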
diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD
index dee86966f1..e8a7904a88 100644
--- a/tensorflow/python/eager/BUILD
+++ b/tensorflow/python/eager/BUILD
@@ -391,3 +391,20 @@ py_library(
srcs = ["imperative_grad.py"],
srcs_version = "PY2AND3",
)
+
+cuda_py_test(
+ name = "memory_test",
+ size = "medium",
+ srcs = ["memory_test.py"],
+ additional_deps = [
+ "//tensorflow/python/eager:backprop",
+ "//tensorflow/python/keras",
+ "//tensorflow/python/eager:test",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_test_lib",
+ ],
+ tags = [
+ "optonly", # The test is too slow in non-opt mode
+ ],
+)
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index 03393bcd46..dd3166735c 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -222,6 +222,11 @@ def _inference_name(n):
return "__inference_%s_%s" % (n, ops.uid())
+def _register(fn):
+ """Registers the function `fn`."""
+ context.context().add_function(fn)
+
+
# TODO(apassos) get rid of this by splitting framework.function._DefinedFunction
# so it doesn't have the definition-generating logic and is just a container for
# an already-defined function.
@@ -591,7 +596,7 @@ def _get_defun_inputs(args):
return nest.pack_sequence_as(args, ret)
-def _defun_internal(name, func, compiled, args, kwds):
+def _trace_and_define_function(name, func, compiled, args, kwds):
"""Defines and returns graph-mode version of func."""
graph_key = ops.get_default_graph()._graph_key # pylint: disable=protected-access
with context.graph_mode():
@@ -699,42 +704,57 @@ def _cache_key(x):
return x
-def _register(fn):
- """Registers the function `fn`."""
- context.context().add_function(fn)
+class _PolymorphicFunction(object):
+ """Wrapper class for the graph functions defined for a Python function.
+ See the documentation for `defun` for more information on the semantics of
+ defined functions.
+ """
-# TODO(apassos): better error messages for non-hashable arguments.
-def named_defun(func, name, compiled=False):
- """Defines a function with a given name.
+ def __init__(self, python_function, name, compiled=False):
+ """Initializes a polymorphic function.
- See the documentation for `defun` for more information on the semantics of
- this function.
+ Args:
+ python_function: the function to be wrapped.
+ name: the name given to it.
+ compiled: if True, the framework will attempt to compile func with XLA.
+ """
- Args:
- func: the function to be wrapped.
- name: the name given to it.
- compiled: if true, the framework will attempt to compile func with XLA.
+ self._python_function = python_function
+ self._name = name
+ self._compiled = compiled
+ self._arguments_to_functions = {}
+ self._variables = []
- Returns:
- the wrapped function.
- """
- arguments_to_functions = {}
+ def _maybe_define_function(self, *args, **kwds):
+ """Gets a function for these inputs, defining it if necessary."""
- def decorated(*args, **kwds):
- """Decorated version of func."""
- # Macroexpand on non-Tensor arguments
- cache_key = tuple(_cache_key(x) for x in args)
+ # TODO(akshayka): Remove this restriction.
if any(isinstance(x, ops.EagerTensor) for x in kwds.values()):
raise ValueError("Tensor keyword arguments are not supported.")
+
+ # TODO(apassos): Better error messages for non-hashable arguments.
+ cache_key = tuple(_cache_key(x) for x in args)
cache_key = (cache_key, tuple(kwds.items()))
- if cache_key not in arguments_to_functions:
- arguments_to_functions[cache_key] = _defun_internal(
- name, func, compiled, args, kwds)
- return arguments_to_functions[cache_key](*args)
+ if cache_key not in self._arguments_to_functions:
+ graph_function = _trace_and_define_function(
+ self._name, self._python_function, self._compiled, args, kwds)
+ self._arguments_to_functions[cache_key] = graph_function
+ self._variables.extend(
+ [v for v in graph_function.variables if v not in self._variables])
+ return graph_function
+ else:
+ return self._arguments_to_functions[cache_key]
- return decorated
+ def __call__(self, *args, **kwds):
+ """Calls a graph function specialized for this input signature."""
+ return self._maybe_define_function(*args, **kwds)(*args)
+
+ @property
+ def variables(self):
+ """Returns a list of variables used in any of the defined functions."""
+ return self._variables
# TODO(akshayka): Remove the `compiled` flag and create a separate
@@ -991,7 +1011,7 @@ def defun(func=None, compiled=False):
except AttributeError:
name = "function"
return tf_decorator.make_decorator(
- function, named_defun(function, name, compiled=compiled))
+ function, _PolymorphicFunction(function, name, compiled=compiled))
# This code path is for the `foo = tfe.defun(foo, ...)` use case
if func is not None:
@@ -1056,7 +1076,7 @@ def make_defun_op(func, *args, **kwds):
name = func.__name__
if any(isinstance(x, ops.EagerTensor) for x in kwds.values()):
raise ValueError("Tensor keyword arguments are not supported.")
- return _defun_internal(name, func, False, args, kwds)
+ return _trace_and_define_function(name, func, False, args, kwds)
class AutomaticControlDependencies(object):
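A rough sketch of the caching idea behind `_PolymorphicFunction`, stripped of TensorFlow specifics: one trace is produced per distinct (hashable) argument signature and reused on later calls. The names and the `trace_fn` callback here are illustrative only, not the real implementation.

```python
class PolymorphicCacheSketch(object):
  """Illustration: memoize one traced callable per argument signature."""

  def __init__(self, python_function, trace_fn):
    self._python_function = python_function
    self._trace_fn = trace_fn  # stand-in for _trace_and_define_function
    self._signature_to_trace = {}  # stand-in for _arguments_to_functions

  def __call__(self, *args, **kwds):
    cache_key = (args, tuple(sorted(kwds.items())))
    if cache_key not in self._signature_to_trace:
      self._signature_to_trace[cache_key] = self._trace_fn(
          self._python_function, args, kwds)
    return self._signature_to_trace[cache_key](*args, **kwds)
```

The real class additionally aggregates each traced function's `variables` across traces, which is what the `testVariablesAreTracked` test below exercises.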
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index cfdbe5f079..6ce2ceffda 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -633,6 +633,23 @@ class FunctionTest(test.TestCase):
y = model(x)
self.assertAllEqual([[[[4.0]]]], y.numpy())
+ def testVariablesAreTracked(self):
+ v = resource_variable_ops.ResourceVariable(1.0)
+
+ def foo(x):
+ return v * x
+
+ defined = function.defun(foo)
+
+ x = constant_op.constant([1.0])
+ self.assertAllEqual(defined.variables, [])
+ _ = defined(x)
+ self.assertAllEqual(defined.variables, [v])
+
+ x = constant_op.constant([1.0, 2.0])
+ _ = defined(x) # ensure the variables list remains the same
+ self.assertAllEqual(defined.variables, [v])
+
@test_util.with_c_shapes
class AutomaticControlDependenciesTest(test.TestCase):
diff --git a/tensorflow/python/eager/memory_test.py b/tensorflow/python/eager/memory_test.py
new file mode 100644
index 0000000000..74c6cbdd31
--- /dev/null
+++ b/tensorflow/python/eager/memory_test.py
@@ -0,0 +1,108 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for memory leaks in eager execution.
+
+It is possible that this test suite will eventually become flaky due to taking
+too long to run (since the tests iterate many times), but for now the tests are
+helpful for finding memory leaks, since not all PyObject leaks are found by
+introspection (test_util decorators). Please be careful adding new tests here.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python import keras
+from tensorflow.python.eager import backprop
+from tensorflow.python.eager import context
+from tensorflow.python.eager import test
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+
+# memory_profiler might not be available in the OSS version of TensorFlow.
+try:
+ import memory_profiler # pylint:disable=g-import-not-at-top
+except ImportError:
+ memory_profiler = None
+
+
+class SingleLayerNet(keras.Model):
+ """Simple keras model used to ensure that there are no leaks."""
+
+ def __init__(self):
+ super(SingleLayerNet, self).__init__()
+ self.fc1 = keras.layers.Dense(5)
+
+ def call(self, x):
+ return self.fc1(x)
+
+
+class MemoryTest(test.TestCase):
+
+ def assertNotIncreasingMemory(self,
+ f,
+ num_iters=100000,
+ increase_threshold_absolute_mb=10):
+ """Assert memory usage doesn't increase beyond given threshold for f."""
+
+ with context.eager_mode():
+ # Warm up.
+ f()
+
+ initial = memory_profiler.memory_usage(-1)[0]
+
+      for _ in range(num_iters):
+ f()
+
+ increase = memory_profiler.memory_usage(-1)[0] - initial
+
+ assert increase < increase_threshold_absolute_mb, (
+ "Increase is too high. Initial memory usage: %f MB. Increase: %f MB. "
+ "Maximum allowed increase: %f") % (initial, increase,
+ increase_threshold_absolute_mb)
+
+ def testMemoryLeakInSimpleModelForwardOnly(self):
+ if memory_profiler is None:
+ self.skipTest("memory_profiler required to run this test")
+
+ inputs = array_ops.zeros([32, 100], dtypes.float32)
+ net = SingleLayerNet()
+
+ def f():
+ with backprop.GradientTape():
+ net(inputs)
+
+ self.assertNotIncreasingMemory(f)
+
+ def testMemoryLeakInSimpleModelForwardAndBackward(self):
+ if memory_profiler is None:
+ self.skipTest("memory_profiler required to run this test")
+
+ inputs = array_ops.zeros([32, 100], dtypes.float32)
+ net = SingleLayerNet()
+
+ def f():
+ with backprop.GradientTape() as tape:
+ result = net(inputs)
+
+ tape.gradient(result, net.variables)
+
+ del tape
+
+ self.assertNotIncreasingMemory(f)
+
+
+if __name__ == "__main__":
+ test.main()
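The new test's `assertNotIncreasingMemory` helper leans on `memory_profiler.memory_usage(-1)[0]`, i.e. a resident-memory sample (in MiB) for the current process. The same bounded-growth check can be written outside the TensorFlow test harness; the threshold and iteration count below are arbitrary illustrations.

```python
import memory_profiler  # optional dependency; absent in some installs

def assert_bounded_memory_growth(f, num_iters=1000, max_increase_mb=10.0):
  """Fails if repeatedly calling f grows resident memory past the threshold."""
  f()  # warm up so one-time allocations are not counted
  initial_mb = memory_profiler.memory_usage(-1)[0]
  for _ in range(num_iters):
    f()
  increase_mb = memory_profiler.memory_usage(-1)[0] - initial_mb
  if increase_mb >= max_increase_mb:
    raise AssertionError("Memory grew by %.1f MiB (limit %.1f MiB)"
                         % (increase_mb, max_increase_mb))
```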
diff --git a/tensorflow/python/estimator/canned/baseline.py b/tensorflow/python/estimator/canned/baseline.py
index 15677ea3c1..20c7a69b7c 100644
--- a/tensorflow/python/estimator/canned/baseline.py
+++ b/tensorflow/python/estimator/canned/baseline.py
@@ -215,6 +215,13 @@ class BaselineClassifier(estimator.Estimator):
* if `weight_column` is not `None`, a feature with
`key=weight_column` whose value is a `Tensor`.
+
+ @compatibility(eager)
+ Estimators can be used while eager execution is enabled. Note that `input_fn`
+ and all hooks are executed inside a graph context, so they have to be written
+ to be compatible with graph mode. Note that `input_fn` code using `tf.data`
+ generally works in both graph and eager modes.
+ @end_compatibility
"""
def __init__(self,
@@ -313,6 +320,13 @@ class BaselineRegressor(estimator.Estimator):
* if `weight_column` is not `None`, a feature with
`key=weight_column` whose value is a `Tensor`.
+
+ @compatibility(eager)
+ Estimators can be used while eager execution is enabled. Note that `input_fn`
+ and all hooks are executed inside a graph context, so they have to be written
+ to be compatible with graph mode. Note that `input_fn` code using `tf.data`
+ generally works in both graph and eager modes.
+ @end_compatibility
"""
def __init__(self,
diff --git a/tensorflow/python/estimator/canned/boosted_trees.py b/tensorflow/python/estimator/canned/boosted_trees.py
index 6b54f51ca6..86dbf272ef 100644
--- a/tensorflow/python/estimator/canned/boosted_trees.py
+++ b/tensorflow/python/estimator/canned/boosted_trees.py
@@ -714,7 +714,15 @@ def _create_regression_head(label_dimension, weight_column=None):
@estimator_export('estimator.BoostedTreesClassifier')
class BoostedTreesClassifier(estimator.Estimator):
- """A Classifier for Tensorflow Boosted Trees models."""
+ """A Classifier for Tensorflow Boosted Trees models.
+
+ @compatibility(eager)
+ Estimators can be used while eager execution is enabled. Note that `input_fn`
+ and all hooks are executed inside a graph context, so they have to be written
+ to be compatible with graph mode. Note that `input_fn` code using `tf.data`
+ generally works in both graph and eager modes.
+ @end_compatibility
+ """
def __init__(self,
feature_columns,
@@ -832,7 +840,15 @@ class BoostedTreesClassifier(estimator.Estimator):
@estimator_export('estimator.BoostedTreesRegressor')
class BoostedTreesRegressor(estimator.Estimator):
- """A Regressor for Tensorflow Boosted Trees models."""
+ """A Regressor for Tensorflow Boosted Trees models.
+
+ @compatibility(eager)
+ Estimators can be used while eager execution is enabled. Note that `input_fn`
+ and all hooks are executed inside a graph context, so they have to be written
+ to be compatible with graph mode. Note that `input_fn` code using `tf.data`
+ generally works in both graph and eager modes.
+ @end_compatibility
+ """
def __init__(self,
feature_columns,
diff --git a/tensorflow/python/estimator/canned/dnn.py b/tensorflow/python/estimator/canned/dnn.py
index b924ad5df4..90889e3e5d 100644
--- a/tensorflow/python/estimator/canned/dnn.py
+++ b/tensorflow/python/estimator/canned/dnn.py
@@ -266,7 +266,10 @@ class DNNClassifier(estimator.Estimator):
Loss is calculated by using softmax cross entropy.
@compatibility(eager)
- Estimators are not compatible with eager execution.
+ Estimators can be used while eager execution is enabled. Note that `input_fn`
+ and all hooks are executed inside a graph context, so they have to be written
+ to be compatible with graph mode. Note that `input_fn` code using `tf.data`
+ generally works in both graph and eager modes.
@end_compatibility
"""
@@ -418,7 +421,10 @@ class DNNRegressor(estimator.Estimator):
Loss is calculated by using mean squared error.
@compatibility(eager)
- Estimators are not compatible with eager execution.
+ Estimators can be used while eager execution is enabled. Note that `input_fn`
+ and all hooks are executed inside a graph context, so they have to be written
+ to be compatible with graph mode. Note that `input_fn` code using `tf.data`
+ generally works in both graph and eager modes.
@end_compatibility
"""
diff --git a/tensorflow/python/estimator/canned/dnn_linear_combined.py b/tensorflow/python/estimator/canned/dnn_linear_combined.py
index 64d81c46ce..3d1ad1365b 100644
--- a/tensorflow/python/estimator/canned/dnn_linear_combined.py
+++ b/tensorflow/python/estimator/canned/dnn_linear_combined.py
@@ -292,7 +292,10 @@ class DNNLinearCombinedClassifier(estimator.Estimator):
Loss is calculated by using softmax cross entropy.
@compatibility(eager)
- Estimators are not compatible with eager execution.
+ Estimators can be used while eager execution is enabled. Note that `input_fn`
+ and all hooks are executed inside a graph context, so they have to be written
+ to be compatible with graph mode. Note that `input_fn` code using `tf.data`
+ generally works in both graph and eager modes.
@end_compatibility
"""
@@ -473,7 +476,10 @@ class DNNLinearCombinedRegressor(estimator.Estimator):
Loss is calculated by using mean squared error.
@compatibility(eager)
- Estimators are not compatible with eager execution.
+ Estimators can be used while eager execution is enabled. Note that `input_fn`
+ and all hooks are executed inside a graph context, so they have to be written
+ to be compatible with graph mode. Note that `input_fn` code using `tf.data`
+ generally works in both graph and eager modes.
@end_compatibility
"""
diff --git a/tensorflow/python/estimator/canned/linear.py b/tensorflow/python/estimator/canned/linear.py
index 705fc3ce06..ac59e786c4 100644
--- a/tensorflow/python/estimator/canned/linear.py
+++ b/tensorflow/python/estimator/canned/linear.py
@@ -227,7 +227,10 @@ class LinearClassifier(estimator.Estimator):
Loss is calculated by using softmax cross entropy.
@compatibility(eager)
- Estimators are not compatible with eager execution.
+ Estimators can be used while eager execution is enabled. Note that `input_fn`
+ and all hooks are executed inside a graph context, so they have to be written
+ to be compatible with graph mode. Note that `input_fn` code using `tf.data`
+ generally works in both graph and eager modes.
@end_compatibility
"""
@@ -370,7 +373,10 @@ class LinearRegressor(estimator.Estimator):
Loss is calculated by using mean squared error.
@compatibility(eager)
- Estimators are not compatible with eager execution.
+ Estimators can be used while eager execution is enabled. Note that `input_fn`
+ and all hooks are executed inside a graph context, so they have to be written
+ to be compatible with graph mode. Note that `input_fn` code using `tf.data`
+ generally works in both graph and eager modes.
@end_compatibility
"""
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py
index 41c25f1c73..2b87f7403f 100644
--- a/tensorflow/python/estimator/estimator.py
+++ b/tensorflow/python/estimator/estimator.py
@@ -103,6 +103,15 @@ class Estimator(object):
None of `Estimator`'s methods can be overridden in subclasses (its
constructor enforces this). Subclasses should use `model_fn` to configure
the base class, and may add methods implementing specialized functionality.
+
+  @compatibility(eager)
+  Calling methods of `Estimator` will work while eager execution is enabled.
+  However, the `model_fn` and `input_fn` are not executed eagerly: `Estimator`
+  will switch to graph mode before calling all user-provided functions (incl.
+ hooks), so their code has to be compatible with graph mode execution. Note
+ that `input_fn` code using `tf.data` generally works in both graph and eager
+ modes.
+ @end_compatibility
"""
def __init__(self, model_fn, model_dir=None, config=None, params=None,
@@ -1150,13 +1159,10 @@ class Estimator(object):
input_fn, model_fn_lib.ModeKeys.TRAIN))
worker_hooks.extend(input_hooks)
global_step_tensor = self._create_and_assert_global_step(g)
- # The default destination for the global_step_tensor fetch call is the
- # CPU.
- global_step_read_tensor = self._distribution.fetch(global_step_tensor)
# we want to add to the global collection in the main thread not the
# tower threads.
ops.add_to_collection(training_util.GLOBAL_STEP_READ_KEY,
- global_step_read_tensor)
+ self._distribution.read_var(global_step_tensor))
grouped_estimator_spec = self._distribution.call_for_each_tower(
self._call_model_fn,
features,
@@ -1254,7 +1260,7 @@ class Estimator(object):
training_chief_hooks=training_chief_hooks,
scaffold=scaffold)
return self._train_with_estimator_spec(estimator_spec, worker_hooks,
- hooks, global_step_read_tensor,
+ hooks, global_step_tensor,
saving_listeners)
def _train_with_estimator_spec(self, estimator_spec, worker_hooks, hooks,
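The compatibility notes added to the canned estimators above amount to one practical rule: `input_fn` runs under a graph context, so it should be written with graph-compatible APIs, and `tf.data` pipelines usually are. A small illustrative `input_fn` in that spirit (the feature names, shapes, and batch size are made up for the example):

```python
import numpy as np
import tensorflow as tf

def input_fn():
  """Returns a tf.data pipeline usable whether or not eager execution is on."""
  features = {"x": np.random.rand(100, 3).astype(np.float32)}
  labels = np.random.randint(0, 2, size=(100,)).astype(np.int32)
  dataset = tf.data.Dataset.from_tensor_slices((features, labels))
  return dataset.shuffle(100).repeat().batch(16)
```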
diff --git a/tensorflow/python/feature_column/feature_column.py b/tensorflow/python/feature_column/feature_column.py
index af2ead9b84..a58c5aabbe 100644
--- a/tensorflow/python/feature_column/feature_column.py
+++ b/tensorflow/python/feature_column/feature_column.py
@@ -172,7 +172,7 @@ def _internal_input_layer(features,
scope=None):
"""See input_layer. `scope` is a name or variable scope to use."""
- feature_columns = _clean_feature_columns(feature_columns)
+ feature_columns = _normalize_feature_columns(feature_columns)
for column in feature_columns:
if not isinstance(column, _DenseColumn):
raise ValueError(
@@ -350,10 +350,23 @@ def linear_model(features,
prediction itself for linear regression problems.
Note on supported columns: `linear_model` treats categorical columns as
- `indicator_column`s while `input_layer` explicitly requires wrapping each
- of them with an `embedding_column` or an `indicator_column`.
+  `indicator_column`s. To be specific, assume the `SparseTensor` input looks
+  like:
- Example:
+ ```python
+ shape = [2, 2]
+ {
+ [0, 0]: "a"
+ [1, 0]: "b"
+ [1, 1]: "c"
+ }
+ ```
+  `linear_model` assigns weights for the presence of "a", "b", "c" implicitly,
+  just like `indicator_column`, while `input_layer` explicitly requires wrapping
+  each categorical column with an `embedding_column` or an
+  `indicator_column`.
+
+ Example of usage:
```python
price = numeric_column('price')
@@ -374,13 +387,44 @@ def linear_model(features,
to your model. All items should be instances of classes derived from
`_FeatureColumn`s.
units: An integer, dimensionality of the output space. Default value is 1.
- sparse_combiner: A string specifying how to reduce if a sparse column is
- multivalent. Currently "mean", "sqrtn" and "sum" are supported, with "sum"
- the default. "sqrtn" often achieves good accuracy, in particular with
- bag-of-words columns. It combines each sparse columns independently.
+ sparse_combiner: A string specifying how to reduce if a categorical column
+      is multivalent. Except for `numeric_column`, almost all columns passed to
+      `linear_model` are considered categorical columns. It combines each
+      categorical column independently. Currently "mean", "sqrtn" and "sum" are
+      supported, with "sum" the default for linear models. "sqrtn" often
+      achieves good accuracy, in particular with bag-of-words columns.
* "sum": do not normalize features in the column
* "mean": do l1 normalization on features in the column
* "sqrtn": do l2 normalization on features in the column
+ For example, for two features represented as the categorical columns:
+
+ ```python
+ # Feature 1
+
+ shape = [2, 2]
+ {
+ [0, 0]: "a"
+ [0, 1]: "b"
+ [1, 0]: "c"
+ }
+
+ # Feature 2
+
+ shape = [2, 3]
+ {
+ [0, 0]: "d"
+ [1, 0]: "e"
+ [1, 1]: "f"
+ [1, 2]: "g"
+ }
+ ```
+      with `sparse_combiner` as "mean", the linear model outputs conceptually are:
+ ```
+      y_0 = 1.0 / 2.0 * (w_a + w_b) + w_d + b_0
+      y_1 = w_c + 1.0 / 3.0 * (w_e + w_f + w_g) + b_1
+ ```
+ where `y_i` is the output, `b_i` is the bias, and `w_x` is the weight
+ assigned to the presence of `x` in the input features.
weight_collections: A list of collection names to which the Variable will be
added. Note that, variables will also be added to collections
`tf.GraphKeys.GLOBAL_VARIABLES` and `ops.GraphKeys.MODEL_VARIABLES`.
@@ -536,7 +580,8 @@ class _LinearModel(training.Model):
name=None,
**kwargs):
super(_LinearModel, self).__init__(name=name, **kwargs)
- self._feature_columns = _clean_feature_columns(feature_columns)
+ self._feature_columns = _normalize_feature_columns(
+ feature_columns)
self._weight_collections = list(weight_collections or [])
if ops.GraphKeys.MODEL_VARIABLES not in self._weight_collections:
self._weight_collections.append(ops.GraphKeys.MODEL_VARIABLES)
@@ -643,7 +688,7 @@ def _transform_features(features, feature_columns):
Returns:
A `dict` mapping `_FeatureColumn` to `Tensor` and `SparseTensor` values.
"""
- feature_columns = _clean_feature_columns(feature_columns)
+ feature_columns = _normalize_feature_columns(feature_columns)
outputs = {}
with ops.name_scope(
None, default_name='transform_features', values=features.values()):
@@ -911,7 +956,8 @@ def shared_embedding_columns(
tensor_name_in_ckpt: Name of the `Tensor` in `ckpt_to_load_from` from
which to restore the column weights. Required if `ckpt_to_load_from` is
not `None`.
- max_norm: If not `None`, embedding values are l2-normalized to this value.
+ max_norm: If not `None`, each embedding is clipped if its l2-norm is
+ larger than this value, before combining.
trainable: Whether or not the embedding is trainable. Default is True.
Returns:
@@ -1182,12 +1228,13 @@ def categorical_column_with_hash_bucket(key,
Use this when your sparse features are in string or integer format, and you
want to distribute your inputs into a finite number of buckets by hashing.
- output_id = Hash(input_feature_string) % bucket_size
+ output_id = Hash(input_feature_string) % bucket_size for string type input.
+ For int type input, the value is converted to its string representation first
+ and then hashed by the same formula.
For input dictionary `features`, `features[key]` is either `Tensor` or
`SparseTensor`. If `Tensor`, missing values can be represented by `-1` for int
- and `''` for string. Note that these values are independent of the
- `default_value` argument.
+ and `''` for string, which will be dropped by this feature column.
Example:
@@ -1249,8 +1296,7 @@ def categorical_column_with_vocabulary_file(key,
For input dictionary `features`, `features[key]` is either `Tensor` or
`SparseTensor`. If `Tensor`, missing values can be represented by `-1` for int
- and `''` for string. Note that these values are independent of the
- `default_value` argument.
+ and `''` for string, which will be dropped by this feature column.
Example with `num_oov_buckets`:
File '/us/states.txt' contains 50 lines, each with a 2-character U.S. state
@@ -1366,8 +1412,7 @@ def categorical_column_with_vocabulary_list(
For input dictionary `features`, `features[key]` is either `Tensor` or
`SparseTensor`. If `Tensor`, missing values can be represented by `-1` for int
- and `''` for string. Note that these values are independent of the
- `default_value` argument.
+ and `''` for string, which will be dropped by this feature column.
Example with `num_oov_buckets`:
In the following example, each input in `vocabulary_list` is assigned an ID
@@ -1480,8 +1525,7 @@ def categorical_column_with_identity(key, num_buckets, default_value=None):
For input dictionary `features`, `features[key]` is either `Tensor` or
`SparseTensor`. If `Tensor`, missing values can be represented by `-1` for int
- and `''` for string. Note that these values are independent of the
- `default_value` argument.
+ and `''` for string, which will be dropped by this feature column.
In the following examples, each input in the range `[0, 1000000)` is assigned
the same value. All other inputs are assigned `default_value` 0. Note that a
@@ -1538,8 +1582,14 @@ def categorical_column_with_identity(key, num_buckets, default_value=None):
def indicator_column(categorical_column):
"""Represents multi-hot representation of given categorical column.
- Used to wrap any `categorical_column_*` (e.g., to feed to DNN). Use
- `embedding_column` if the inputs are sparse.
+  - For DNN models, `indicator_column` can be used to wrap any
+    `categorical_column_*` (e.g., to feed to a DNN). Consider using
+    `embedding_column` if the number of buckets/unique values is large.
+
+  - For Wide (aka linear) models, `indicator_column` is the internal
+    representation used for a categorical column when it is passed directly
+    (as any element in `feature_columns`) to `linear_model`. See
+    `linear_model` for details.
```python
name = indicator_column(categorical_column_with_vocabulary_list(
@@ -1956,7 +2006,7 @@ def _create_weighted_sum(column,
weight_collections,
trainable,
weight_var=None):
- """Creates a weighted sum for a dense or sparse column for linear_model."""
+ """Creates a weighted sum for a dense/categorical column for linear_model."""
if isinstance(column, _CategoricalColumn):
return _create_categorical_column_weighted_sum(
column=column,
@@ -2055,7 +2105,34 @@ def _create_categorical_column_weighted_sum(column,
weight_collections,
trainable,
weight_var=None):
- """Create a weighted sum of a categorical column for linear_model."""
+ # pylint: disable=g-doc-return-or-yield,g-doc-args
+ """Create a weighted sum of a categorical column for linear_model.
+
+  Note to maintainer: As an implementation detail, the weighted sum is
+  implemented via embedding_lookup_sparse for efficiency. Mathematically,
+  the two are equivalent.
+
+  To be specific, conceptually a categorical column can be treated as a
+  multi-hot vector. Say:
+
+ ```python
+ x = [0 0 1] # categorical column input
+ w = [a b c] # weights
+ ```
+  The weighted sum is `c` in this case, which is the same as `w[2]`.
+
+ Another example is
+
+ ```python
+ x = [0 1 1] # categorical column input
+ w = [a b c] # weights
+ ```
+  The weighted sum is `b + c` in this case, which is the same as `w[1] + w[2]`.
+
+  For both cases, we can implement the weighted sum via embedding_lookup_sparse
+  with sparse_combiner = "sum".
+ """
+
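As a sanity check of the note above, the multi-hot dot product and a sparse sum over the active ids do agree numerically; here is a tiny NumPy illustration (not the embedding_lookup_sparse code path itself):

```python
import numpy as np

w = np.array([1.0, 2.0, 3.0])  # weights a, b, c
x = np.array([0.0, 1.0, 1.0])  # multi-hot categorical input (ids 1 and 2 active)

dense_weighted_sum = x.dot(w)          # 2.0 + 3.0
sparse_weighted_sum = w[[1, 2]].sum()  # sum of weights at the active ids
assert dense_weighted_sum == sparse_weighted_sum == 5.0
```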
sparse_tensors = column._get_sparse_tensors( # pylint: disable=protected-access
builder,
weight_collections=weight_collections,
@@ -2249,7 +2326,7 @@ def _shape_offsets(shape):
# TODO(ptucker): Move to third_party/tensorflow/python/ops/sparse_ops.py
-def _to_sparse_input(input_tensor, ignore_value=None):
+def _to_sparse_input_and_drop_ignore_values(input_tensor, ignore_value=None):
"""Converts a `Tensor` to a `SparseTensor`, dropping ignore_value cells.
If `input_tensor` is already a `SparseTensor`, just return it.
@@ -2293,8 +2370,22 @@ def _to_sparse_input(input_tensor, ignore_value=None):
input_tensor, out_type=dtypes.int64, name='dense_shape'))
-def _clean_feature_columns(feature_columns):
- """Verifies and normalizes `feature_columns` input."""
+def _normalize_feature_columns(feature_columns):
+ """Normalizes the `feature_columns` input.
+
+  This method converts `feature_columns` to a list as best it can. In
+  addition, it verifies the type and other properties of `feature_columns`, as
+  required by downstream libraries.
+
+ Args:
+ feature_columns: The raw feature columns, usually passed by users.
+
+ Returns:
+ The normalized feature column list.
+
+ Raises:
+ ValueError: for any invalid inputs, such as empty, duplicated names, etc.
+ """
if isinstance(feature_columns, _FeatureColumn):
feature_columns = [feature_columns]
@@ -2420,6 +2511,7 @@ class _BucketizedColumn(_DenseColumn, _CategoricalColumn,
def _get_sparse_tensors(self, inputs, weight_collections=None,
trainable=None):
+ """Converts dense inputs to SparseTensor so downstream code can use it."""
input_tensor = inputs.get(self)
batch_size = array_ops.shape(input_tensor)[0]
# By construction, source_column is always one-dimensional.
@@ -2553,7 +2645,7 @@ def _get_graph_for_variable(var):
class _SharedEmbeddingColumn(
- _DenseColumn,
+ _DenseColumn, _SequenceDenseColumn,
collections.namedtuple(
'_SharedEmbeddingColumn',
('categorical_column', 'dimension', 'combiner', 'initializer',
@@ -2600,7 +2692,11 @@ class _SharedEmbeddingColumn(
self._shape = tensor_shape.vector(self.dimension)
return self._shape
- def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
+ def _get_dense_tensor_internal(self,
+ inputs,
+ weight_collections=None,
+ trainable=None):
+ """Private method that follows the signature of _get_dense_tensor."""
# This method is called from a variable_scope with name _var_scope_name,
# which is shared among all shared embeddings. Open a name_scope here, so
# that the ops for different columns have distinct names.
@@ -2641,6 +2737,44 @@ class _SharedEmbeddingColumn(
name='%s_weights' % self.name,
max_norm=self.max_norm)
+ def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
+ if isinstance(self.categorical_column, _SequenceCategoricalColumn):
+ raise ValueError(
+ 'In embedding_column: {}. '
+ 'categorical_column must not be of type _SequenceCategoricalColumn. '
+ 'Suggested fix A: If you wish to use input_layer, use a '
+ 'non-sequence categorical_column_with_*. '
+ 'Suggested fix B: If you wish to create sequence input, use '
+ 'sequence_input_layer instead of input_layer. '
+ 'Given (type {}): {}'.format(self.name, type(self.categorical_column),
+ self.categorical_column))
+ return self._get_dense_tensor_internal(
+ inputs=inputs,
+ weight_collections=weight_collections,
+ trainable=trainable)
+
+ def _get_sequence_dense_tensor(self,
+ inputs,
+ weight_collections=None,
+ trainable=None):
+ if not isinstance(self.categorical_column, _SequenceCategoricalColumn):
+ raise ValueError(
+ 'In embedding_column: {}. '
+ 'categorical_column must be of type _SequenceCategoricalColumn '
+ 'to use sequence_input_layer. '
+ 'Suggested fix: Use one of sequence_categorical_column_with_*. '
+ 'Given (type {}): {}'.format(self.name, type(self.categorical_column),
+ self.categorical_column))
+ dense_tensor = self._get_dense_tensor_internal( # pylint: disable=protected-access
+ inputs=inputs,
+ weight_collections=weight_collections,
+ trainable=trainable)
+ sparse_tensors = self.categorical_column._get_sparse_tensors(inputs) # pylint: disable=protected-access
+ sequence_length = _sequence_length_from_sparse_tensor(
+ sparse_tensors.id_tensor)
+ return _SequenceDenseColumn.TensorSequenceLengthPair(
+ dense_tensor=dense_tensor, sequence_length=sequence_length)
+
def _create_tuple(shape, value):
"""Returns a tuple with given shape and filled with value."""
@@ -2762,7 +2896,7 @@ class _HashedCategoricalColumn(
return {self.key: parsing_ops.VarLenFeature(self.dtype)}
def _transform_feature(self, inputs):
- input_tensor = _to_sparse_input(inputs.get(self.key))
+ input_tensor = _to_sparse_input_and_drop_ignore_values(inputs.get(self.key))
if not isinstance(input_tensor, sparse_tensor_lib.SparseTensor):
raise ValueError('SparseColumn input must be a SparseTensor.')
@@ -2813,7 +2947,7 @@ class _VocabularyFileCategoricalColumn(
return {self.key: parsing_ops.VarLenFeature(self.dtype)}
def _transform_feature(self, inputs):
- input_tensor = _to_sparse_input(inputs.get(self.key))
+ input_tensor = _to_sparse_input_and_drop_ignore_values(inputs.get(self.key))
if self.dtype.is_integer != input_tensor.dtype.is_integer:
raise ValueError(
@@ -2865,7 +2999,7 @@ class _VocabularyListCategoricalColumn(
return {self.key: parsing_ops.VarLenFeature(self.dtype)}
def _transform_feature(self, inputs):
- input_tensor = _to_sparse_input(inputs.get(self.key))
+ input_tensor = _to_sparse_input_and_drop_ignore_values(inputs.get(self.key))
if self.dtype.is_integer != input_tensor.dtype.is_integer:
raise ValueError(
@@ -2917,7 +3051,7 @@ class _IdentityCategoricalColumn(
return {self.key: parsing_ops.VarLenFeature(dtypes.int64)}
def _transform_feature(self, inputs):
- input_tensor = _to_sparse_input(inputs.get(self.key))
+ input_tensor = _to_sparse_input_and_drop_ignore_values(inputs.get(self.key))
if not input_tensor.dtype.is_integer:
raise ValueError(
@@ -2999,7 +3133,8 @@ class _WeightedCategoricalColumn(
self.dtype, weight_tensor.dtype))
if not isinstance(weight_tensor, sparse_tensor_lib.SparseTensor):
# The weight tensor can be a regular Tensor. In this case, sparsify it.
- weight_tensor = _to_sparse_input(weight_tensor, ignore_value=0.0)
+ weight_tensor = _to_sparse_input_and_drop_ignore_values(
+ weight_tensor, ignore_value=0.0)
if not weight_tensor.dtype.is_floating:
weight_tensor = math_ops.to_float(weight_tensor)
return (inputs.get(self.categorical_column), weight_tensor)
@@ -3444,3 +3579,8 @@ class _SequenceCategoricalColumn(
weight_tensor,
shape=array_ops.concat([weight_tensor.dense_shape, [1]], axis=0))
return _CategoricalColumn.IdWeightPair(id_tensor, weight_tensor)
+
+
+# TODO(xiejw): Remove the following alias once call sites are updated.
+_clean_feature_columns = _normalize_feature_columns
+_to_sparse_input = _to_sparse_input_and_drop_ignore_values
diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py
index 82ecba310b..002a3d3be5 100644
--- a/tensorflow/python/framework/function.py
+++ b/tensorflow/python/framework/function.py
@@ -36,6 +36,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.util import compat
+from tensorflow.python.util import tf_contextlib
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
@@ -650,6 +651,41 @@ class _FuncGraph(ops.Graph):
# TODO(skyewm): is this needed?
self.extra_vars = []
+ # pylint: disable=g-doc-return-or-yield
+
+ @tf_contextlib.contextmanager
+ def container(self, container_name):
+ """Returns a context manager that specifies the resource container to use.
+
+ Overridden from @{tf.Graph} to update both the init_scope container
+ and the present inner container. This is necessary to make sure setting
+ containers applies correctly both to created variables and to stateful
+ ops.
+
+ Args:
+ container_name: container name string.
+
+ Returns:
+ A context manager for defining resource containers for stateful ops,
+ yields the container name.
+ """
+ original_container = self._container
+ # pylint: disable=protected-access
+ with ops.init_scope():
+ original_init_container = ops.get_default_graph()._container
+ try:
+ self._container = container_name
+ with ops.init_scope():
+ ops.get_default_graph()._container = container_name
+ yield self._container
+ finally:
+ self._container = original_container
+ with ops.init_scope():
+ ops.get_default_graph()._container = original_init_container
+ # pylint: enable=protected-access
+
+ # pylint: enable=g-doc-return-or-yield
+
def getvar(
self,
getter,
@@ -773,7 +809,9 @@ class _FuncGraph(ops.Graph):
def func_graph_from_py_func(func, arg_names, arg_types, name=None,
- capture_by_value=False, device=None):
+ capture_by_value=False, device=None,
+ colocation_stack=None, container=None,
+ collections_ref=None):
"""Returns a _FuncGraph generated from `func`.
Args:
@@ -786,6 +824,10 @@ def func_graph_from_py_func(func, arg_names, arg_types, name=None,
capture_by_value: boolean. If True, captured values will be copied into the
function body.
device: device name or function.
+ colocation_stack: A colocation stack (list) the _FuncGraph should use.
+ container: A container name the _FuncGraph should start with.
+ collections_ref: A reference to a collections dict the _FuncGraph should
+ use internally.
Returns:
A _FuncGraph.
@@ -796,7 +838,17 @@ def func_graph_from_py_func(func, arg_names, arg_types, name=None,
if not name:
name = _get_func_name(func)
func_graph = _FuncGraph(name, capture_by_value)
+
with func_graph.as_default(), ops.device(device):
+ # pylint: disable=protected-access
+ if collections_ref is not None:
+ func_graph._collections = collections_ref
+ if container is not None:
+ func_graph._container = container
+ if colocation_stack is not None:
+ func_graph._colocation_stack = colocation_stack
+ # pylint: enable=protected-access
+
# Create placeholders for the function arguments.
for (argname, argtype) in zip(arg_names, arg_types):
argholder = array_ops.placeholder(argtype, name=argname)
diff --git a/tensorflow/python/keras/engine/input_layer.py b/tensorflow/python/keras/engine/input_layer.py
index 7996110829..8a4018a0df 100644
--- a/tensorflow/python/keras/engine/input_layer.py
+++ b/tensorflow/python/keras/engine/input_layer.py
@@ -215,7 +215,7 @@ def Input( # pylint: disable=invalid-name
if dtype is None:
dtype = K.floatx()
- if not shape and tensor is None:
+ if shape is None and tensor is None:
raise ValueError('Please provide to Input either a `shape`'
' or a `tensor` argument. Note that '
'`shape` does not include the batch '
diff --git a/tensorflow/python/keras/layers/normalization.py b/tensorflow/python/keras/layers/normalization.py
index ff51eadee9..28cedec338 100644
--- a/tensorflow/python/keras/layers/normalization.py
+++ b/tensorflow/python/keras/layers/normalization.py
@@ -364,11 +364,12 @@ class BatchNormalization(Layer):
def _assign_moving_average(self, variable, value, momentum):
with ops.name_scope(None, 'AssignMovingAvg',
[variable, value, momentum]) as scope:
- decay = ops.convert_to_tensor(1.0 - momentum, name='decay')
- if decay.dtype != variable.dtype.base_dtype:
- decay = math_ops.cast(decay, variable.dtype.base_dtype)
- update_delta = (variable - value) * decay
- return state_ops.assign_sub(variable, update_delta, name=scope)
+ with ops.colocate_with(variable):
+ decay = ops.convert_to_tensor(1.0 - momentum, name='decay')
+ if decay.dtype != variable.dtype.base_dtype:
+ decay = math_ops.cast(decay, variable.dtype.base_dtype)
+ update_delta = (variable - value) * decay
+ return state_ops.assign_sub(variable, update_delta, name=scope)
def _fused_batch_norm(self, inputs, training):
"""Returns the output of fused batch norm."""
diff --git a/tensorflow/python/kernel_tests/conv_ops_test.py b/tensorflow/python/kernel_tests/conv_ops_test.py
index 450428707d..80ba7dafc9 100644
--- a/tensorflow/python/kernel_tests/conv_ops_test.py
+++ b/tensorflow/python/kernel_tests/conv_ops_test.py
@@ -587,7 +587,7 @@ class Conv2DTest(test.TestCase):
values.append(_GetVal(data_format, use_gpu))
for i in range(1, len(values)):
- self.assertAllClose(values[0], values[i], rtol=1e-4, atol=1e-4)
+ self.assertAllClose(values[0], values[i], rtol=1e-2, atol=1e-2)
@test_util.run_in_graph_and_eager_modes()
def testConv2D2x2Depth1ValidBackpropInput(self):
diff --git a/tensorflow/python/kernel_tests/depthwise_conv_op_test.py b/tensorflow/python/kernel_tests/depthwise_conv_op_test.py
index 5e223b1828..7134e02c34 100644
--- a/tensorflow/python/kernel_tests/depthwise_conv_op_test.py
+++ b/tensorflow/python/kernel_tests/depthwise_conv_op_test.py
@@ -356,7 +356,7 @@ class DepthwiseConv2DTest(test.TestCase):
with self.test_session(graph=graph, use_gpu=use_gpu) as sess:
tolerance = {
dtypes.float16: 4e-0,
- dtypes.float32: 5e-4,
+ dtypes.float32: 8e-4,
dtypes.float64: 1e-12,
}[data_type]
diff --git a/tensorflow/python/kernel_tests/distributions/BUILD b/tensorflow/python/kernel_tests/distributions/BUILD
index cf2e8832fd..985922245e 100644
--- a/tensorflow/python/kernel_tests/distributions/BUILD
+++ b/tensorflow/python/kernel_tests/distributions/BUILD
@@ -93,6 +93,7 @@ cuda_py_test(
size = "small",
srcs = ["categorical_test.py"],
additional_deps = [
+ "@absl_py//absl/testing:parameterized",
"//tensorflow/python/ops/distributions",
"//third_party/py/numpy",
"//tensorflow/python:array_ops",
diff --git a/tensorflow/python/kernel_tests/distributions/categorical_test.py b/tensorflow/python/kernel_tests/distributions/categorical_test.py
index ca2358fe99..68b4ffdb58 100644
--- a/tensorflow/python/kernel_tests/distributions/categorical_test.py
+++ b/tensorflow/python/kernel_tests/distributions/categorical_test.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from absl.testing import parameterized
import numpy as np
from tensorflow.python.framework import constant_op
@@ -40,7 +41,7 @@ def make_categorical(batch_shape, num_classes, dtype=dtypes.int32):
return categorical.Categorical(logits, dtype=dtype)
-class CategoricalTest(test.TestCase):
+class CategoricalTest(test.TestCase, parameterized.TestCase):
def testP(self):
p = [0.2, 0.8]
@@ -131,7 +132,7 @@ class CategoricalTest(test.TestCase):
with self.test_session():
self.assertAllClose(dist.prob(0).eval(), 0.2)
- def testCDFWithDynamicEventShape(self):
+ def testCDFWithDynamicEventShapeKnownNdims(self):
"""Test that dynamically-sized events with unknown shape work."""
batch_size = 2
histograms = array_ops.placeholder(dtype=dtypes.float32,
@@ -167,6 +168,21 @@ class CategoricalTest(test.TestCase):
self.assertAllClose(actual_cdf_one, expected_cdf_one)
self.assertAllClose(actual_cdf_two, expected_cdf_two)
+ @parameterized.named_parameters(
+ ("test1", [0, 1], [[0.5, 0.3, 0.2], [1.0, 0.0, 0.0]], [0.0, 1.0]),
+ ("test2", [2, 5], [[0.9, 0.0, 0.0, 0.0, 0.0, 0.1],
+ [0.15, 0.2, 0.05, 0.35, 0.13, 0.12]], [0.9, 0.88]))
+ def testCDFWithDynamicEventShapeUnknownNdims(
+ self, events, histograms, expected_cdf):
+ """Test that dynamically-sized events with unknown shape work."""
+ event_ph = array_ops.placeholder_with_default(events, shape=None)
+ histograms_ph = array_ops.placeholder_with_default(histograms, shape=None)
+ dist = categorical.Categorical(probs=histograms_ph)
+ cdf_op = dist.cdf(event_ph)
+
+ actual_cdf = self.evaluate(cdf_op)
+ self.assertAllClose(actual_cdf, expected_cdf)
+
def testCDFWithBatch(self):
histograms = [[0.1, 0.2, 0.3, 0.25, 0.15],
[0.0, 0.75, 0.2, 0.05, 0.0]]
diff --git a/tensorflow/python/layers/convolutional.py b/tensorflow/python/layers/convolutional.py
index 267d78dbcb..36cef3855e 100644
--- a/tensorflow/python/layers/convolutional.py
+++ b/tensorflow/python/layers/convolutional.py
@@ -217,7 +217,6 @@ def conv1d(inputs,
bias_constraint=bias_constraint,
trainable=trainable,
name=name,
- dtype=inputs.dtype.base_dtype,
_reuse=reuse,
_scope=name)
return layer.apply(inputs)
@@ -421,7 +420,6 @@ def conv2d(inputs,
bias_constraint=bias_constraint,
trainable=trainable,
name=name,
- dtype=inputs.dtype.base_dtype,
_reuse=reuse,
_scope=name)
return layer.apply(inputs)
@@ -627,7 +625,6 @@ def conv3d(inputs,
bias_constraint=bias_constraint,
trainable=trainable,
name=name,
- dtype=inputs.dtype.base_dtype,
_reuse=reuse,
_scope=name)
return layer.apply(inputs)
@@ -1266,7 +1263,6 @@ def conv2d_transpose(inputs,
bias_constraint=bias_constraint,
trainable=trainable,
name=name,
- dtype=inputs.dtype.base_dtype,
_reuse=reuse,
_scope=name)
return layer.apply(inputs)
@@ -1438,7 +1434,6 @@ def conv3d_transpose(inputs,
bias_constraint=bias_constraint,
trainable=trainable,
name=name,
- dtype=inputs.dtype.base_dtype,
_reuse=reuse,
_scope=name)
return layer.apply(inputs)
diff --git a/tensorflow/python/layers/core.py b/tensorflow/python/layers/core.py
index abbacac442..aadff231da 100644
--- a/tensorflow/python/layers/core.py
+++ b/tensorflow/python/layers/core.py
@@ -184,7 +184,6 @@ def dense(
bias_constraint=bias_constraint,
trainable=trainable,
name=name,
- dtype=inputs.dtype.base_dtype,
_scope=name,
_reuse=reuse)
return layer.apply(inputs)
diff --git a/tensorflow/python/layers/normalization.py b/tensorflow/python/layers/normalization.py
index d082e312e9..ece6667981 100644
--- a/tensorflow/python/layers/normalization.py
+++ b/tensorflow/python/layers/normalization.py
@@ -308,7 +308,6 @@ def batch_normalization(inputs,
virtual_batch_size=virtual_batch_size,
adjustment=adjustment,
name=name,
- dtype=inputs.dtype.base_dtype,
_reuse=reuse,
_scope=name)
return layer.apply(inputs, training=training)
diff --git a/tensorflow/python/ops/distributions/categorical.py b/tensorflow/python/ops/distributions/categorical.py
index b88a0518b6..dd25fce2ec 100644
--- a/tensorflow/python/ops/distributions/categorical.py
+++ b/tensorflow/python/ops/distributions/categorical.py
@@ -32,12 +32,8 @@ from tensorflow.python.ops.distributions import util as distribution_util
from tensorflow.python.util.tf_export import tf_export
-def _broadcast_cat_event_and_params(event, params, base_dtype=dtypes.int32):
+def _broadcast_cat_event_and_params(event, params, base_dtype):
"""Broadcasts the event or distribution parameters."""
- if event.shape.ndims is None:
- raise NotImplementedError(
- "Cannot broadcast with an event tensor of unknown rank.")
-
if event.dtype.is_integer:
pass
elif event.dtype.is_floating:
@@ -47,15 +43,18 @@ def _broadcast_cat_event_and_params(event, params, base_dtype=dtypes.int32):
else:
raise TypeError("`value` should have integer `dtype` or "
"`self.dtype` ({})".format(base_dtype))
-
- if params.get_shape()[:-1] == event.get_shape():
- params = params
- else:
- params *= array_ops.ones_like(
- array_ops.expand_dims(event, -1), dtype=params.dtype)
+ shape_known_statically = (
+ params.shape.ndims is not None and
+ params.shape[:-1].is_fully_defined() and
+ event.shape.is_fully_defined())
+ if not shape_known_statically or params.shape[:-1] != event.shape:
+ params *= array_ops.ones_like(event[..., array_ops.newaxis],
+ dtype=params.dtype)
params_shape = array_ops.shape(params)[:-1]
event *= array_ops.ones(params_shape, dtype=event.dtype)
- event.set_shape(tensor_shape.TensorShape(params.get_shape()[:-1]))
+ if params.shape.ndims is not None:
+ event.set_shape(tensor_shape.TensorShape(params.shape[:-1]))
+
return event, params
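The broadcasting in the rewritten helper above is ordinary NumPy-style broadcasting: multiplying `params` by ones shaped like `event[..., newaxis]` tiles the distribution parameters up to the event's batch shape, and vice versa for the event. A small shape check of that idea (plain NumPy, illustration only):

```python
import numpy as np

params = np.arange(3, dtype=np.float64)  # one distribution over 3 classes
event = np.array([[0, 2], [1, 1]])       # events with batch shape [2, 2]

params = params * np.ones_like(event[..., np.newaxis], dtype=params.dtype)
event = event * np.ones(params.shape[:-1], dtype=event.dtype)
assert params.shape == (2, 2, 3) and event.shape == (2, 2)
```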
diff --git a/tensorflow/python/ops/embedding_ops.py b/tensorflow/python/ops/embedding_ops.py
index bcc717b043..c7919e4d4c 100644
--- a/tensorflow/python/ops/embedding_ops.py
+++ b/tensorflow/python/ops/embedding_ops.py
@@ -43,8 +43,8 @@ def _clip(params, ids, max_norm):
Args:
params: A `Tensor` of embeddings retrieved by `gather`.
ids: The `ids` argument that was passed to `gather`.
- max_norm: If provided, the embeddings are l2-normalized to the value of
- max_norm.
+ max_norm: If not `None`, each embedding is clipped if its l2-norm is
+ larger than this value.
Returns:
A `Tensor` with the same type as `params`.
@@ -290,8 +290,8 @@ def embedding_lookup(
in `indices` are always validated to be within range. If assigned to GPU,
out-of-bound indices result in safe but unspecified behavior, which may
include raising an error.
- max_norm: If provided, embedding values are l2-normalized to the value of
- max_norm.
+ max_norm: If not `None`, each embedding is clipped if its l2-norm is
+ larger than this value.
Returns:
A `Tensor` with the same type as the tensors in `params`.
@@ -346,8 +346,8 @@ def embedding_lookup_sparse(params,
"mean" is the weighted sum divided by the total weight.
"sqrtn" is the weighted sum divided by the square root of the sum of the
squares of the weights.
- max_norm: If provided, each embedding is normalized to have l2 norm equal
- to max_norm before combining.
+ max_norm: If not `None`, each embedding is clipped if its l2-norm is
+ larger than this value, before combining.
Returns:
A dense tensor representing the combined embeddings for the
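The reworded `max_norm` documentation describes clipping, not unconditional normalization: an embedding is rescaled only when its l2-norm exceeds the threshold. A hedged sketch of that behavior using the public clipping op rather than the internal `_clip` helper:

```python
import tensorflow as tf

embeddings = tf.constant([[3.0, 4.0],    # l2-norm 5.0 -> rescaled to norm 1.0
                          [0.3, 0.4]])   # l2-norm 0.5 -> left unchanged
clipped = tf.clip_by_norm(embeddings, clip_norm=1.0, axes=[1])
```

Row 0 becomes [0.6, 0.8]; row 1 is already under the threshold, so clipping is a no-op for it.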
diff --git a/tensorflow/python/ops/template.py b/tensorflow/python/ops/template.py
index 355b0d961e..161d9687d6 100644
--- a/tensorflow/python/ops/template.py
+++ b/tensorflow/python/ops/template.py
@@ -27,6 +27,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.checkpointable import base as checkpointable
+from tensorflow.python.training.checkpointable import util as checkpointable_util
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util import tf_decorator
from tensorflow.python.util.deprecation import deprecated
@@ -295,66 +296,6 @@ class Template(checkpointable.CheckpointableBase):
# which is not the same as whether the scope has been created.
self._variables_created = False
- def _checkpointable_custom_creator(self, next_creator, name, initial_value,
- checkpointable_parent=None, **kwargs):
- """A variable creation hook which adds Checkpointable dependencies.
-
- Set during the `Template`'s first wrapped function execution. Ensures that
- (a) `Template` objects depend on `Template`s created inside them which
- create variables, and (b) that any variables not in a more deeply nested
- `Template` are added as dependencies directly.
-
- The `checkpointable_parent` argument is passed between `Template` custom
- creators but ignored when the variable object itself is created. This
- argument indicates (if not `None`) that a more deeply nested `Template` has
- already added the variable as a dependency, and that parent `Template`s
- should add a dependency on that `Template` rather than on the variable
- directly.
-
- Args:
- next_creator: See `variable_scope.variable_creator_scope`; the next
- creator in the chain.
- name: The (full, scope-influenced) name of the variable. The scope name
- for the Template itself is stripped for the purposes of object-based
- dependency tracking, but scopes within Templates are respected.
- initial_value: See `variable_scope.variable_creator_scope`. Taken
- explicitly so the argument can be re-named and used with
- `Checkpointable._add_variable_with_custom_getter`.
- checkpointable_parent: If not None, a more deeply nested Template object
- to add a dependency on (rather than depending on the variable directly).
- **kwargs: Passed through to the next creator.
- Returns:
- The output of `next_creator`: the fetched/created variable object.
- """
- def _call_next_creator_renaming_initializer(initializer, **inner_kwargs):
- inner_kwargs.pop("name") # Ignored; this is the scope-stripped name which
- # we don't want to propagate.
- return next_creator(
- initial_value=initializer,
- name=name,
- **inner_kwargs)
- if name.startswith(self._variable_scope.name):
- scope_stripped_name = name[len(self._variable_scope.name) + 1:]
- if not checkpointable_parent:
- return self._add_variable_with_custom_getter(
- initializer=initial_value,
- name=scope_stripped_name,
- getter=_call_next_creator_renaming_initializer,
- # Disable error checking for Checkpointable. Exceptions are instead
- # raised if necessary when the object-based saver tries to
- # save/restore the object.
- overwrite=True,
- checkpointable_parent=self,
- **kwargs)
- else:
- self._track_checkpointable(
- checkpointable_parent,
- name=checkpointable_parent._variable_scope.name[ # pylint: disable=protected-access
- len(self._variable_scope.name) + 1:],
- overwrite=True)
- return next_creator(name=name, initial_value=initial_value,
- checkpointable_parent=self, **kwargs)
-
def _call_func(self, args, kwargs):
try:
vars_at_start = len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))
@@ -365,8 +306,7 @@ class Template(checkpointable.CheckpointableBase):
else:
# The first time we run, restore variables if necessary (via
# Checkpointable).
- with variable_scope.variable_creator_scope(
- self._checkpointable_custom_creator):
+ with checkpointable_util.capture_dependencies(template=self):
result = self._func(*args, **kwargs)
if self._variables_created:
@@ -634,8 +574,7 @@ class EagerTemplate(Template):
else:
# The first time we run, restore variables if necessary (via
# Checkpointable).
- with variable_scope.variable_creator_scope(
- self._checkpointable_custom_creator):
+ with checkpointable_util.capture_dependencies(template=self):
result = self._func(*args, **kwargs)
if self._variables_created:
diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py
index 9a711edaa4..47414c28af 100644
--- a/tensorflow/python/ops/variable_scope.py
+++ b/tensorflow/python/ops/variable_scope.py
@@ -1925,7 +1925,8 @@ class variable_scope(object):
for this scope as well as all sub-scopes; if tf.AUTO_REUSE, we create
variables if they do not exist, and return them otherwise; if None, we
inherit the parent scope's reuse flag. When eager execution is enabled,
- this argument is always forced to be tf.AUTO_REUSE.
+ new variables are always created unless an EagerVariableStore or
+ template is currently active.
dtype: type of variables created in this scope (defaults to the type
in the passed scope, or inherited from parent scope).
use_resource: If False, all variables will be regular Variables. If True,
diff --git a/tensorflow/python/training/checkpoint_utils.py b/tensorflow/python/training/checkpoint_utils.py
index e7f88de1d2..c2f0e9d3e6 100644
--- a/tensorflow/python/training/checkpoint_utils.py
+++ b/tensorflow/python/training/checkpoint_utils.py
@@ -219,8 +219,8 @@ def init_from_checkpoint(ckpt_dir_or_file, assignment_map):
else:
var_name = ",".join([v.name for v in var])
_set_variable_or_list_initializer(var, ckpt_file, tensor_name_in_ckpt)
- logging.info("Initialize variable %s from checkpoint %s with %s",
- var_name, ckpt_dir_or_file, tensor_name_in_ckpt)
+ logging.debug("Initialize variable %s from checkpoint %s with %s",
+ var_name, ckpt_dir_or_file, tensor_name_in_ckpt)
else:
scopes = ""
# TODO(vihanjain): Support list of 'current_var_or_name' here.
@@ -261,8 +261,8 @@ def init_from_checkpoint(ckpt_dir_or_file, assignment_map):
if var is None:
var = _collect_partitioned_variable(var_name, store_vars)
_set_variable_or_list_initializer(var, ckpt_file, full_tensor_name)
- logging.info("Initialize variable %s from checkpoint %s with %s",
- var_name, ckpt_dir_or_file, full_tensor_name)
+ logging.debug("Initialize variable %s from checkpoint %s with %s",
+ var_name, ckpt_dir_or_file, full_tensor_name)
def _get_checkpoint_filename(ckpt_dir_or_file):
diff --git a/tensorflow/python/training/checkpointable/util.py b/tensorflow/python/training/checkpointable/util.py
index 96e6d10791..0608076e6d 100644
--- a/tensorflow/python/training/checkpointable/util.py
+++ b/tensorflow/python/training/checkpointable/util.py
@@ -41,6 +41,7 @@ from tensorflow.python.training import saveable_object as saveable_object_lib
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training.checkpointable import base as checkpointable_lib
from tensorflow.python.util import deprecation
+from tensorflow.python.util import tf_contextlib
from tensorflow.python.util.tf_export import tf_export
@@ -564,6 +565,93 @@ def gather_initializers(root_checkpointable):
if hasattr(c, "initializer") and c.initializer is not None]
+@tf_contextlib.contextmanager
+def capture_dependencies(template):
+ """Capture variables created within this scope as `Template` dependencies.
+
+ Requires that `template.variable_scope` is active.
+
+ This scope is intended as a compatibility measure, allowing a checkpointable
+ object to add dependencies on variables created in a block of code which is
+ not aware of object-based saving (and instead uses variable names
+ heavily). This is how `Template` objects add dependencies on variables and
+ sub-`Template`s. Where possible, use `tf.make_template` directly.
+
+ Args:
+ template: The `Template` object to register dependencies with.
+
+ Yields:
+ None (when used as a context manager).
+ """
+ name_prefix = template.variable_scope.name
+
+ def _checkpointable_custom_creator(next_creator, name, initial_value,
+ checkpointable_parent=None, **kwargs):
+ """A variable creation hook which adds Checkpointable dependencies.
+
+    Set, for example, during a `Template`'s first wrapped function
+    execution. Ensures that (a) `template` depends on any checkpointable
+    objects which open their own `capture_dependencies` scope inside this
+    scope and create variables, and (b) that any variables not in a more
+    deeply nested scope are added as dependencies directly.
+
+ The `checkpointable_parent` argument is passed between custom creators but
+ ignored when the variable object itself is created. This argument indicates
+ (if not `None`) that a more deeply nested scope has already added the
+ variable as a dependency, and that parent scopes should add a dependency on
+ that object rather than on the variable directly.
+
+ Args:
+ next_creator: See `variable_scope.variable_creator_scope`; the next
+ creator in the chain.
+ name: The (full, scope-influenced) name of the variable. The `name_prefix`
+ itself is stripped for the purposes of object-based dependency tracking,
+ but scopes opened within this scope are respected.
+ initial_value: See `variable_scope.variable_creator_scope`. Taken
+ explicitly so the argument can be re-named and used with
+ `Checkpointable._add_variable_with_custom_getter`.
+ checkpointable_parent: If not None, a more deeply nested checkpointable
+ object and its name prefix which were passed to `capture_dependencies`
+ to add a dependency on (rather than depending on the variable directly).
+ **kwargs: Passed through to the next creator.
+
+ Returns:
+ The output of `next_creator`: the fetched/created variable object.
+ """
+ def _call_next_creator_renaming_initializer(initializer, **inner_kwargs):
+ inner_kwargs.pop("name") # Ignored; this is the scope-stripped name which
+ # we don't want to propagate.
+ return next_creator(
+ initial_value=initializer,
+ name=name,
+ **inner_kwargs)
+ if name.startswith(name_prefix):
+ scope_stripped_name = name[len(name_prefix) + 1:]
+ if not checkpointable_parent:
+ return template._add_variable_with_custom_getter( # pylint: disable=protected-access
+ initializer=initial_value,
+ name=scope_stripped_name,
+ getter=_call_next_creator_renaming_initializer,
+ # Disable error checking for Checkpointable. Exceptions are instead
+ # raised if necessary when the object-based saver tries to
+ # save/restore the object.
+ overwrite=True,
+ checkpointable_parent=(template, name_prefix),
+ **kwargs)
+ else:
+ parent_object, parent_name_prefix = checkpointable_parent
+ template._track_checkpointable( # pylint: disable=protected-access
+ parent_object,
+ name=parent_name_prefix[len(name_prefix) + 1:],
+ overwrite=True)
+ return next_creator(
+ name=name, initial_value=initial_value,
+ checkpointable_parent=(template, name_prefix), **kwargs)
+
+ with variable_scope.variable_creator_scope(_checkpointable_custom_creator):
+ yield
+
+
class _NoRestoreSaveable(saver_lib.BaseSaverBuilder.SaveableObject):
def __init__(self, tensor, name):
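For illustration only, a usage sketch of the new `capture_dependencies` scope, modeled on the `_ManualScope` test below; the class name `LegacyBlock` is hypothetical and the import aliases assume this commit's module layout.

from tensorflow.python.ops import variable_scope
from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.training.checkpointable import util as checkpointable_utils

class LegacyBlock(checkpointable.Checkpointable):
  """Wraps name-based variable creation so it is tracked object-style."""

  def __call__(self):
    with variable_scope.variable_scope("LegacyBlock") as vs:
      self.variable_scope = vs  # capture_dependencies requires this scope
      with checkpointable_utils.capture_dependencies(template=self):
        return variable_scope.get_variable("kernel", shape=[])

block = LegacyBlock()
kernel = block()
# The variable is now tracked as the dependency "kernel" of `block`, so
# checkpointable_utils.Checkpoint(block=block) saves/restores it by object path.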
diff --git a/tensorflow/python/training/checkpointable/util_test.py b/tensorflow/python/training/checkpointable/util_test.py
index 8cdf5d7855..e2115417c4 100644
--- a/tensorflow/python/training/checkpointable/util_test.py
+++ b/tensorflow/python/training/checkpointable/util_test.py
@@ -1243,6 +1243,18 @@ class CheckpointingTests(test.TestCase):
self.assertEqual(42., self.evaluate(optimizer.variables()[0]))
+class _ManualScope(checkpointable.Checkpointable):
+
+ def __call__(self):
+ with variable_scope.variable_scope("ManualScope") as vs:
+ self.variable_scope = vs
+ with checkpointable_utils.capture_dependencies(template=self):
+ return self._build()
+
+ def _build(self):
+ return variable_scope.get_variable(name="in_manual_scope", shape=[])
+
+
class TemplateTests(test.TestCase):
@test_util.run_in_graph_and_eager_modes()
@@ -1255,14 +1267,23 @@ class TemplateTests(test.TestCase):
v2 = variable_scope.get_variable(
"v2", shape=[1], initializer=init_ops.zeros_initializer(),
use_resource=True)
- return v, v + 1., v2
+ manual = _ManualScope()
+ return v, v + 1., v2, manual, manual()
save_template = template.make_template("s1", _templated)
- v1_save, _, v2_save = save_template()
+ v1_save, _, v2_save, manual_scope, manual_scope_v = save_template()
+ six.assertCountEqual(
+ self,
+ [v1_save, v2_save, manual_scope, manual_scope_v, save_template],
+ checkpointable_utils.list_objects(save_template))
+ manual_dep, = manual_scope._checkpoint_dependencies
+ self.assertEqual("in_manual_scope", manual_dep.name)
+ self.assertIs(manual_scope_v, manual_dep.ref)
optimizer = adam.AdamOptimizer(0.0)
save_root = checkpointable_utils.Checkpoint(
my_template=save_template, optimizer=optimizer)
optimizer.minimize(v1_save.read_value)
+ self.evaluate([v.initializer for v in save_template.variables])
self.evaluate([v.initializer for v in optimizer.variables()])
self.evaluate(v1_save.assign([12.]))
self.evaluate(v2_save.assign([14.]))
@@ -1275,11 +1296,13 @@ class TemplateTests(test.TestCase):
load_root = checkpointable_utils.Checkpoint(
my_template=load_template, optimizer=load_optimizer)
status = load_root.restore(save_path)
- var, var_plus_one, var2 = load_template()
+ var, var_plus_one, var2, _, _ = load_template()
load_optimizer.minimize(var.read_value)
- self.assertEqual(2, len(load_template._checkpoint_dependencies))
+ self.assertEqual(3, len(load_template._checkpoint_dependencies))
self.assertEqual("v", load_template._checkpoint_dependencies[0].name)
self.assertEqual("v2", load_template._checkpoint_dependencies[1].name)
+ self.assertEqual("ManualScope",
+ load_template._checkpoint_dependencies[2].name)
status.assert_consumed().run_restore_ops()
self.assertAllEqual([12.], self.evaluate(var))
self.assertAllEqual([13.], self.evaluate(var_plus_one))
diff --git a/tensorflow/python/training/monitored_session.py b/tensorflow/python/training/monitored_session.py
index fece3370f3..7b06bffa4b 100644
--- a/tensorflow/python/training/monitored_session.py
+++ b/tensorflow/python/training/monitored_session.py
@@ -298,7 +298,8 @@ def MonitoredTrainingSession(master='', # pylint: disable=invalid-name
stop_grace_period_secs=120,
log_step_count_steps=100,
max_wait_secs=7200,
- save_checkpoint_steps=USE_DEFAULT):
+ save_checkpoint_steps=USE_DEFAULT,
+ summary_dir=None):
"""Creates a `MonitoredSession` for training.
For a chief, this utility sets proper session initializer/restorer. It also
@@ -348,6 +349,8 @@ def MonitoredTrainingSession(master='', # pylint: disable=invalid-name
`save_checkpoint_steps` and `save_checkpoint_secs` are set to `None`, then
the default checkpoint saver isn't used. If both are provided, then only
`save_checkpoint_secs` is used. Default not enabled.
+ summary_dir: A string. Optional path to a directory where
+ summaries will be written. If None, checkpoint_dir is used instead.
Returns:
A `MonitoredSession` object.
@@ -388,11 +391,12 @@ def MonitoredTrainingSession(master='', # pylint: disable=invalid-name
master=master,
config=config)
- if checkpoint_dir:
+ summary_dir = summary_dir or checkpoint_dir
+ if summary_dir:
if log_step_count_steps and log_step_count_steps > 0:
all_hooks.append(
basic_session_run_hooks.StepCounterHook(
- output_dir=checkpoint_dir, every_n_steps=log_step_count_steps))
+ output_dir=summary_dir, every_n_steps=log_step_count_steps))
if (save_summaries_steps and save_summaries_steps > 0) or (
save_summaries_secs and save_summaries_secs > 0):
@@ -400,7 +404,9 @@ def MonitoredTrainingSession(master='', # pylint: disable=invalid-name
scaffold=scaffold,
save_steps=save_summaries_steps,
save_secs=save_summaries_secs,
- output_dir=checkpoint_dir))
+ output_dir=summary_dir))
+
+ if checkpoint_dir:
if (save_checkpoint_secs and save_checkpoint_secs > 0) or (
save_checkpoint_steps and save_checkpoint_steps > 0):
all_hooks.append(basic_session_run_hooks.CheckpointSaverHook(
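For illustration only, a minimal sketch of the new `summary_dir` argument: summaries and the step counter go to a separate directory while checkpoints stay under `checkpoint_dir` (both paths below are hypothetical).

import tensorflow as tf

global_step = tf.train.get_or_create_global_step()
train_op = tf.assign_add(global_step, 1)

with tf.train.MonitoredTrainingSession(
    checkpoint_dir="/tmp/ckpts",      # hypothetical; checkpoints still land here
    summary_dir="/tmp/summaries",     # hypothetical; defaults to checkpoint_dir if None
    save_summaries_steps=100) as sess:
  for _ in range(10):
    sess.run(train_op)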
diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py
index b8f58a288c..53ed89e4ab 100644
--- a/tensorflow/python/training/saver.py
+++ b/tensorflow/python/training/saver.py
@@ -206,21 +206,19 @@ class BaseSaverBuilder(object):
filename_tensor: String Tensor.
saveables: List of BaseSaverBuilder.SaveableObject objects.
preferred_shard: Int. Shard to open first when loading a sharded file.
- restore_sequentially: Bool. If true, each restore is sequential.
+ restore_sequentially: Unused; restores are no longer issued sequentially. Retained for API compatibility.
Returns:
A list of Tensors resulting from reading 'saveable' from
'filename'.
"""
+ del restore_sequentially
all_tensors = []
- assign_ops = []
for saveable in saveables:
- restore_control_inputs = assign_ops[-1:] if restore_sequentially else []
with ops.device(_set_cpu0(saveable.device) if saveable.device else None):
- with ops.control_dependencies(restore_control_inputs):
- all_tensors.extend(
- self.restore_op(filename_tensor, saveable, preferred_shard))
+ all_tensors.extend(
+ self.restore_op(filename_tensor, saveable, preferred_shard))
return all_tensors
# pylint: disable=unused-argument
diff --git a/tensorflow/stream_executor/cuda/cuda_blas.cc b/tensorflow/stream_executor/cuda/cuda_blas.cc
index 92c1a5fc07..31e407f199 100644
--- a/tensorflow/stream_executor/cuda/cuda_blas.cc
+++ b/tensorflow/stream_executor/cuda/cuda_blas.cc
@@ -2183,10 +2183,12 @@ bool CUDABlas::DoBlasGemmWithAlgorithmImpl(
// Return false if we might be hitting a cuBLAS bug that produces the wrong
// result. See nvbugs/2156201, b/79126339.
+#if (CUDA_VERSION >= 9000)
if (CUDA_VERSION < 9020 && algorithm != CUBLAS_GEMM_ALGO12 &&
std::max({m, n, k}) >= 2097153 && cc_major < 7) {
return false;
}
+#endif
cudaDataType_t cuda_in_type = CUDADataType<InT>::type;
// Since we are converting 'algorithm' to cublasGemmAlgo_t by static_cast,
diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc
index 4a98cfe164..0cd0790a72 100644
--- a/tensorflow/stream_executor/stream.cc
+++ b/tensorflow/stream_executor/stream.cc
@@ -192,6 +192,7 @@ string ToVlogString(dnn::DataType data_type) {
case dnn::DataType::kInt8:
return "dnn::DataType::kInt8";
}
+ return "unknown DataType";
}
// Used together with PARAM to VLOG calls made to the stream. Intended
diff --git a/tensorflow/tools/api/generator/BUILD b/tensorflow/tools/api/generator/BUILD
index 3a28153e52..6065c12cad 100644
--- a/tensorflow/tools/api/generator/BUILD
+++ b/tensorflow/tools/api/generator/BUILD
@@ -5,12 +5,16 @@ licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
+load("//tensorflow/tools/api/generator:api_gen.bzl", "ESTIMATOR_API_INIT_FILES")
load("//tensorflow/tools/api/generator:api_gen.bzl", "TENSORFLOW_API_INIT_FILES")
py_library(
name = "doc_srcs",
srcs = ["doc_srcs.py"],
srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:util",
+ ],
)
py_binary(
@@ -39,6 +43,7 @@ py_test(
srcs = ["doc_srcs_test.py"],
args = [
"--package=tensorflow.python",
+ "--api_name=tensorflow",
] + TENSORFLOW_API_INIT_FILES,
main = "doc_srcs_test.py",
srcs_version = "PY2AND3",
@@ -48,3 +53,19 @@ py_test(
"//tensorflow/python:no_contrib",
],
)
+
+py_test(
+ name = "estimator_doc_srcs_test",
+ srcs = ["doc_srcs_test.py"],
+ args = [
+ "--package=tensorflow.python.estimator",
+ "--api_name=estimator",
+ ] + ESTIMATOR_API_INIT_FILES,
+ main = "doc_srcs_test.py",
+ srcs_version = "PY2AND3",
+ deps = [
+ ":doc_srcs",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:no_contrib",
+ ],
+)
diff --git a/tensorflow/tools/api/generator/create_python_api.py b/tensorflow/tools/api/generator/create_python_api.py
index e375cd48d8..46b81e17c6 100644
--- a/tensorflow/tools/api/generator/create_python_api.py
+++ b/tensorflow/tools/api/generator/create_python_api.py
@@ -252,7 +252,7 @@ def get_module(dir_path, relative_to_dir):
return dir_path.replace('/', '.').strip('.')
-def get_module_docstring(module_name, package):
+def get_module_docstring(module_name, package, api_name):
"""Get docstring for the given module.
This method looks for docstring in the following order:
@@ -268,6 +268,7 @@ def get_module_docstring(module_name, package):
(excluding 'tensorflow.' prefix) to get a docstring for.
package: Base python package containing python with target tf_export
decorators.
+ api_name: API you want to generate (e.g. `tensorflow` or `estimator`).
Returns:
One-line docstring to describe the module.
@@ -275,8 +276,10 @@ def get_module_docstring(module_name, package):
# Module under base package to get a docstring from.
docstring_module_name = module_name
- if module_name in doc_srcs.TENSORFLOW_DOC_SOURCES:
- docsrc = doc_srcs.TENSORFLOW_DOC_SOURCES[module_name]
+ doc_sources = doc_srcs.get_doc_sources(api_name)
+
+ if module_name in doc_sources:
+ docsrc = doc_sources[module_name]
if docsrc.docstring:
return docsrc.docstring
if docsrc.docstring_module_name:
@@ -335,7 +338,7 @@ def create_api_files(
if module or not root_init_template:
contents = (
_GENERATED_FILE_HEADER %
- get_module_docstring(module, package) + text + _GENERATED_FILE_FOOTER)
+ get_module_docstring(module, package, api_name) + text)
else:
# Read base init file
with open(root_init_template, 'r') as root_init_template_file:
diff --git a/tensorflow/tools/api/generator/doc_srcs.py b/tensorflow/tools/api/generator/doc_srcs.py
index 74f6db98fd..ccd5bea481 100644
--- a/tensorflow/tools/api/generator/doc_srcs.py
+++ b/tensorflow/tools/api/generator/doc_srcs.py
@@ -19,6 +19,8 @@ from __future__ import print_function
import collections
+from tensorflow.python.util import tf_export
+
# Specifies docstring source for a module.
# Only one of docstring or docstring_module_name should be set.
@@ -31,7 +33,7 @@ DocSource = collections.namedtuple(
# Each attribute of DocSource is optional.
DocSource.__new__.__defaults__ = (None,) * len(DocSource._fields)
-TENSORFLOW_DOC_SOURCES = {
+_TENSORFLOW_DOC_SOURCES = {
'app': DocSource(docstring_module_name='platform.app'),
'compat': DocSource(docstring_module_name='util.compat'),
'distributions': DocSource(
@@ -63,3 +65,28 @@ TENSORFLOW_DOC_SOURCES = {
'train.queue_runner': DocSource(
docstring_module_name='training.queue_runner'),
}
+
+_ESTIMATOR_DOC_SOURCES = {
+ 'estimator': DocSource(
+ docstring_module_name='estimator_lib'),
+ 'estimator.export': DocSource(
+ docstring_module_name='export.export_lib'),
+ 'estimator.inputs': DocSource(
+ docstring_module_name='inputs.inputs'),
+}
+
+
+def get_doc_sources(api_name):
+ """Get a map from module to a DocSource object.
+
+ Args:
+ api_name: API you want to generate (e.g. `tensorflow` or `estimator`).
+
+ Returns:
+ Map from module name to DocSource object.
+ """
+ if api_name == tf_export.TENSORFLOW_API_NAME:
+ return _TENSORFLOW_DOC_SOURCES
+ if api_name == tf_export.ESTIMATOR_API_NAME:
+ return _ESTIMATOR_DOC_SOURCES
+ return {}
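For illustration only, a short sketch of the new lookup helper, assuming `doc_srcs` is importable as `tensorflow.tools.api.generator.doc_srcs`:

from tensorflow.python.util import tf_export
from tensorflow.tools.api.generator import doc_srcs

# Each API name resolves to its own doc-source map; anything else is empty.
tf_sources = doc_srcs.get_doc_sources(tf_export.TENSORFLOW_API_NAME)
estimator_sources = doc_srcs.get_doc_sources(tf_export.ESTIMATOR_API_NAME)
assert "app" in tf_sources
assert "estimator" in estimator_sources
assert doc_srcs.get_doc_sources("unknown_api") == {}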
diff --git a/tensorflow/tools/api/generator/doc_srcs_test.py b/tensorflow/tools/api/generator/doc_srcs_test.py
index 9ba95a3439..7b8f27c1b1 100644
--- a/tensorflow/tools/api/generator/doc_srcs_test.py
+++ b/tensorflow/tools/api/generator/doc_srcs_test.py
@@ -32,7 +32,7 @@ FLAGS = None
class DocSrcsTest(test.TestCase):
def testModulesAreValidAPIModules(self):
- for module_name in doc_srcs.TENSORFLOW_DOC_SOURCES:
+ for module_name in doc_srcs.get_doc_sources(FLAGS.api_name):
# Convert module_name to corresponding __init__.py file path.
file_path = module_name.replace('.', '/')
if file_path:
@@ -43,7 +43,7 @@ class DocSrcsTest(test.TestCase):
self.assertFalse('%s is not a valid API module' % module_name)
def testHaveDocstringOrDocstringModule(self):
- for module_name, docsrc in doc_srcs.TENSORFLOW_DOC_SOURCES.items():
+ for module_name, docsrc in doc_srcs.get_doc_sources(FLAGS.api_name).items():
if docsrc.docstring and docsrc.docstring_module_name:
self.assertFalse(
'%s contains DocSource has both a docstring and a '
@@ -52,12 +52,12 @@ class DocSrcsTest(test.TestCase):
% (module_name))
def testDocstringModulesAreValidModules(self):
- for _, docsrc in doc_srcs.TENSORFLOW_DOC_SOURCES.items():
+ for _, docsrc in doc_srcs.get_doc_sources(FLAGS.api_name).items():
if docsrc.docstring_module_name:
doc_module_name = '.'.join([
FLAGS.package, docsrc.docstring_module_name])
if doc_module_name not in sys.modules:
- sys.assertFalse(
+ self.assertFalse(
'docsources_module %s is not a valid module under %s.' %
(docsrc.docstring_module_name, FLAGS.package))
@@ -71,6 +71,9 @@ if __name__ == '__main__':
'--package', type=str,
help='Base package that imports modules containing the target tf_export '
'decorators.')
+ parser.add_argument(
+ '--api_name', type=str,
+ help='API name: tensorflow or estimator')
FLAGS, unparsed = parser.parse_known_args()
importlib.import_module(FLAGS.package)
diff --git a/tensorflow/tools/api/golden/tensorflow.train.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.pbtxt
index 5f45b3b1ad..b0fb04d7d4 100644
--- a/tensorflow/tools/api/golden/tensorflow.train.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.train.pbtxt
@@ -242,7 +242,7 @@ tf_module {
}
member_method {
name: "MonitoredTrainingSession"
- argspec: "args=[\'master\', \'is_chief\', \'checkpoint_dir\', \'scaffold\', \'hooks\', \'chief_only_hooks\', \'save_checkpoint_secs\', \'save_summaries_steps\', \'save_summaries_secs\', \'config\', \'stop_grace_period_secs\', \'log_step_count_steps\', \'max_wait_secs\', \'save_checkpoint_steps\'], varargs=None, keywords=None, defaults=[\'\', \'True\', \'None\', \'None\', \'None\', \'None\', \'<object object instance>\', \'<object object instance>\', \'<object object instance>\', \'None\', \'120\', \'100\', \'7200\', \'<object object instance>\'], "
+ argspec: "args=[\'master\', \'is_chief\', \'checkpoint_dir\', \'scaffold\', \'hooks\', \'chief_only_hooks\', \'save_checkpoint_secs\', \'save_summaries_steps\', \'save_summaries_secs\', \'config\', \'stop_grace_period_secs\', \'log_step_count_steps\', \'max_wait_secs\', \'save_checkpoint_steps\', \'summary_dir\'], varargs=None, keywords=None, defaults=[\'\', \'True\', \'None\', \'None\', \'None\', \'None\', \'<object object instance>\', \'<object object instance>\', \'<object object instance>\', \'None\', \'120\', \'100\', \'7200\', \'<object object instance>\', \'None\'], "
}
member_method {
name: "NewCheckpointReader"
diff --git a/tensorflow/tools/ci_build/Dockerfile.cmake b/tensorflow/tools/ci_build/Dockerfile.cmake
index d5dea4f3e4..e8c3199828 100644
--- a/tensorflow/tools/ci_build/Dockerfile.cmake
+++ b/tensorflow/tools/ci_build/Dockerfile.cmake
@@ -28,6 +28,8 @@ RUN pip install --upgrade astor
RUN pip install --upgrade gast
RUN pip install --upgrade numpy
RUN pip install --upgrade termcolor
+RUN pip install keras_applications==1.0.2
+RUN pip install keras_preprocessing==1.0.1
# Install golang
RUN apt-get install -t xenial-backports -y golang-1.9
diff --git a/tensorflow/tools/ci_build/install/install_pip_packages.sh b/tensorflow/tools/ci_build/install/install_pip_packages.sh
index b3d3f23ec8..386e66cc21 100755
--- a/tensorflow/tools/ci_build/install/install_pip_packages.sh
+++ b/tensorflow/tools/ci_build/install/install_pip_packages.sh
@@ -113,3 +113,9 @@ pip3 install --upgrade termcolor
# Install last working version of setuptools.
pip2 install --upgrade setuptools==39.1.0
pip3 install --upgrade setuptools==39.1.0
+
+# Keras
+pip2 install keras_applications==1.0.2
+pip3 install keras_applications==1.0.2
+pip2 install keras_preprocessing==1.0.1
+pip3 install keras_preprocessing==1.0.1
diff --git a/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh b/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh
index 61d34c7304..4e28fa74b9 100755
--- a/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh
+++ b/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh
@@ -84,4 +84,8 @@ pip3.5 install --upgrade termcolor
# Install last working version of setuptools.
pip3.5 install --upgrade setuptools==39.1.0
+# Keras
+pip3.5 install keras_applications==1.0.2
+pip3.5 install keras_preprocessing==1.0.1
+
# LINT.ThenChange(//tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh)
diff --git a/tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh b/tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh
index fe2d2cf11c..a0b43199a2 100755
--- a/tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh
+++ b/tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh
@@ -100,4 +100,8 @@ pip3 install --upgrade termcolor
# Install last working version of setuptools.
pip3 install --upgrade setuptools==39.1.0
+# Keras
+pip3 install keras_applications==1.0.2
+pip3 install keras_preprocessing==1.0.1
+
# LINT.ThenChange(//tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh)
diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD
index 9d4148c07f..d8356cec47 100644
--- a/tensorflow/tools/pip_package/BUILD
+++ b/tensorflow/tools/pip_package/BUILD
@@ -60,6 +60,7 @@ COMMON_PIP_DEPS = [
"//tensorflow/contrib/autograph/converters:test_lib",
"//tensorflow/contrib/autograph/impl:impl",
"//tensorflow/contrib/autograph/operators:operators",
+ "//tensorflow/contrib/autograph/lang:lang",
"//tensorflow/contrib/autograph/pyct:pyct",
"//tensorflow/contrib/autograph/pyct/static_analysis:static_analysis",
"//tensorflow/contrib/boosted_trees:boosted_trees_pip",
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index ef5cb60cee..423eff3bb2 100644
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -451,11 +451,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
tf_http_archive(
name = "llvm",
urls = [
- "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/582e5dd5553e3089fef97f9ab5a3f063e0160fa9.tar.gz",
- "https://github.com/llvm-mirror/llvm/archive/582e5dd5553e3089fef97f9ab5a3f063e0160fa9.tar.gz",
+ "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/45a02a4f8474b4b8c5cc106b5cecb06cf6e1b3c6.tar.gz",
+ "https://github.com/llvm-mirror/llvm/archive/45a02a4f8474b4b8c5cc106b5cecb06cf6e1b3c6.tar.gz",
],
- sha256 = "9a0e63469ae5a546e0c84b778955f0febabfc8497d312324546ec7d0db68430e",
- strip_prefix = "llvm-582e5dd5553e3089fef97f9ab5a3f063e0160fa9",
+ sha256 = "056f7316a354d1f95e013176bd9b8be74e8f4d47fb0d908e0e742613187dbd59",
+ strip_prefix = "llvm-45a02a4f8474b4b8c5cc106b5cecb06cf6e1b3c6",
build_file = clean_dep("//third_party/llvm:llvm.BUILD"),
)
diff --git a/third_party/gpus/crosstool/CROSSTOOL.tpl b/third_party/gpus/crosstool/CROSSTOOL.tpl
index 60b19daf1d..1424ff6511 100644
--- a/third_party/gpus/crosstool/CROSSTOOL.tpl
+++ b/third_party/gpus/crosstool/CROSSTOOL.tpl
@@ -295,3 +295,245 @@ toolchain {
%{host_compiler_includes}
}
+
+toolchain {
+ abi_version: "local"
+ abi_libc_version: "local"
+ compiler: "compiler"
+ host_system_name: "local"
+ needsPic: true
+ target_libc: "macosx"
+ target_cpu: "darwin"
+ target_system_name: "local"
+ toolchain_identifier: "local_darwin"
+ feature {
+ name: "c++11"
+ flag_set {
+ action: "c++-compile"
+ flag_group {
+ flag: "-std=c++11"
+ }
+ }
+ }
+
+ feature {
+ name: "stdlib"
+ flag_set {
+ action: "c++-link-executable"
+ action: "c++-link-dynamic-library"
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ flag: "-lc++"
+ }
+ }
+ }
+
+ feature {
+ name: "determinism"
+ flag_set {
+ action: "c-compile"
+ action: "c++-compile"
+ flag_group {
+ # Make C++ compilation deterministic. Use linkstamping instead of these
+ # compiler symbols.
+ flag: "-Wno-builtin-macro-redefined"
+ flag: "-D__DATE__=\"redacted\""
+ flag: "-D__TIMESTAMP__=\"redacted\""
+ flag: "-D__TIME__=\"redacted\""
+ }
+ }
+ }
+
+ # This feature will be enabled for builds that support pic by bazel.
+ feature {
+ name: "pic"
+ flag_set {
+ action: "c-compile"
+ action: "c++-compile"
+ flag_group {
+ expand_if_all_available: "pic"
+ flag: "-fPIC"
+ }
+ flag_group {
+ expand_if_none_available: "pic"
+ flag: "-fPIE"
+ }
+ }
+ }
+
+ # Security hardening on by default.
+ feature {
+ name: "hardening"
+ flag_set {
+ action: "c-compile"
+ action: "c++-compile"
+ flag_group {
+ # Conservative choice; -D_FORTIFY_SOURCE=2 may be unsafe in some cases.
+ # We need to undef it before redefining it as some distributions now
+ # have it enabled by default.
+ flag: "-U_FORTIFY_SOURCE"
+ flag: "-D_FORTIFY_SOURCE=1"
+ flag: "-fstack-protector"
+ }
+ }
+ flag_set {
+ action: "c++-link-executable"
+ flag_group {
+ flag: "-pie"
+ }
+ }
+ }
+
+ feature {
+ name: "warnings"
+ flag_set {
+ action: "c-compile"
+ action: "c++-compile"
+ flag_group {
+ # All warnings are enabled. Maybe enable -Werror as well?
+ flag: "-Wall"
+ %{host_compiler_warnings}
+ }
+ }
+ }
+
+ # Keep stack frames for debugging, even in opt mode.
+ feature {
+ name: "frame-pointer"
+ flag_set {
+ action: "c-compile"
+ action: "c++-compile"
+ flag_group {
+ flag: "-fno-omit-frame-pointer"
+ }
+ }
+ }
+
+ feature {
+ name: "no-canonical-prefixes"
+ flag_set {
+ action: "c-compile"
+ action: "c++-compile"
+ action: "c++-link-executable"
+ action: "c++-link-dynamic-library"
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ flag: "-no-canonical-prefixes"
+ }
+ }
+ }
+
+ feature {
+ name: "disable-assertions"
+ flag_set {
+ action: "c-compile"
+ action: "c++-compile"
+ flag_group {
+ flag: "-DNDEBUG"
+ }
+ }
+ }
+
+ feature {
+ name: "linker-bin-path"
+
+ flag_set {
+ action: "c++-link-executable"
+ action: "c++-link-dynamic-library"
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ flag: "-B/usr/bin/"
+ }
+ }
+ }
+
+ feature {
+ name: "undefined-dynamic"
+ flag_set {
+ action: "c++-link-dynamic-library"
+ action: "c++-link-nodeps-dynamic-library"
+ action: "c++-link-executable"
+ flag_group {
+ flag: "-undefined"
+ flag: "dynamic_lookup"
+ }
+ }
+ }
+
+ feature {
+ name: "common"
+ implies: "stdlib"
+ implies: "c++11"
+ implies: "determinism"
+ implies: "hardening"
+ implies: "warnings"
+ implies: "frame-pointer"
+ implies: "no-canonical-prefixes"
+ implies: "linker-bin-path"
+ implies: "undefined-dynamic"
+ }
+
+ feature {
+ name: "opt"
+ implies: "common"
+ implies: "disable-assertions"
+
+ flag_set {
+ action: "c-compile"
+ action: "c++-compile"
+ flag_group {
+ # No debug symbols.
+ # Maybe we should enable https://gcc.gnu.org/wiki/DebugFission for opt
+ # or even generally? However, that can't happen here, as it requires
+ # special handling in Bazel.
+ flag: "-g0"
+
+ # Conservative choice for -O
+ # -O3 can increase binary size and even slow down the resulting binaries.
+ # Profile first and / or use FDO if you need better performance than this.
+ flag: "-O2"
+
+ # Removal of unused code and data at link time (can this increase binary size in some cases?).
+ flag: "-ffunction-sections"
+ flag: "-fdata-sections"
+ }
+ }
+ }
+
+ feature {
+ name: "fastbuild"
+ implies: "common"
+ }
+
+ feature {
+ name: "dbg"
+ implies: "common"
+ flag_set {
+ action: "c-compile"
+ action: "c++-compile"
+ flag_group {
+ flag: "-g"
+ }
+ }
+ }
+
+ # Set clang as a C/C++ compiler.
+ tool_path { name: "gcc" path: "%{host_compiler_path}" }
+
+ # Use the default system toolchain for everything else.
+ tool_path { name: "ar" path: "/usr/bin/libtool" }
+ tool_path { name: "compat-ld" path: "/usr/bin/ld" }
+ tool_path { name: "cpp" path: "/usr/bin/cpp" }
+ tool_path { name: "dwp" path: "/usr/bin/dwp" }
+ tool_path { name: "gcov" path: "/usr/bin/gcov" }
+ tool_path { name: "ld" path: "/usr/bin/ld" }
+ tool_path { name: "nm" path: "/usr/bin/nm" }
+ tool_path { name: "objcopy" path: "/usr/bin/objcopy" }
+ tool_path { name: "objdump" path: "/usr/bin/objdump" }
+ tool_path { name: "strip" path: "/usr/bin/strip" }
+
+ # Enabled dynamic linking.
+ linking_mode_flags { mode: DYNAMIC }
+
+%{host_compiler_includes}
+}