aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--RELEASE.md1
-rw-r--r--tensorflow/c/c_api.cc13
-rw-r--r--tensorflow/c/c_api_test.cc65
-rw-r--r--tensorflow/c/c_test_util.cc7
-rw-r--r--tensorflow/c/c_test_util.h3
-rw-r--r--tensorflow/cc/framework/scope.cc30
-rw-r--r--tensorflow/cc/framework/scope_internal.h3
-rw-r--r--tensorflow/cc/framework/scope_test.cc10
-rw-r--r--tensorflow/compiler/tests/fft_test.py6
-rw-r--r--tensorflow/compiler/tests/segment_reduction_ops_test.py98
-rw-r--r--tensorflow/compiler/tf2xla/functionalize_control_flow.cc35
-rw-r--r--tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc56
-rw-r--r--tensorflow/compiler/tf2xla/graph_compiler.cc3
-rw-r--r--tensorflow/compiler/tf2xla/kernels/random_ops.cc4
-rw-r--r--tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc109
-rw-r--r--tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc8
-rw-r--r--tensorflow/compiler/tf2xla/lib/random.cc10
-rw-r--r--tensorflow/compiler/tf2xla/lib/random.h6
-rw-r--r--tensorflow/compiler/tf2xla/lib/util_test.cc13
-rw-r--r--tensorflow/compiler/tf2xla/xla_compiler.cc9
-rw-r--r--tensorflow/compiler/tf2xla/xla_helpers.cc38
-rw-r--r--tensorflow/compiler/tf2xla/xla_helpers.h12
-rw-r--r--tensorflow/compiler/tf2xla/xla_op_registry.cc8
-rw-r--r--tensorflow/compiler/tf2xla/xla_op_registry_test.cc33
-rw-r--r--tensorflow/compiler/xla/BUILD3
-rw-r--r--tensorflow/compiler/xla/client/lib/arithmetic.cc89
-rw-r--r--tensorflow/compiler/xla/client/lib/arithmetic.h10
-rw-r--r--tensorflow/compiler/xla/client/xla_client/BUILD1
-rw-r--r--tensorflow/compiler/xla/client/xla_client/xla_builder.cc262
-rw-r--r--tensorflow/compiler/xla/client/xla_client/xla_builder.h74
-rw-r--r--tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc97
-rw-r--r--tensorflow/compiler/xla/layout_util.cc6
-rw-r--r--tensorflow/compiler/xla/overflow_util.h50
-rw-r--r--tensorflow/compiler/xla/python/local_computation_builder.cc13
-rw-r--r--tensorflow/compiler/xla/python/local_computation_builder.h4
-rw-r--r--tensorflow/compiler/xla/python/local_computation_builder.i1
-rw-r--r--tensorflow/compiler/xla/python/xla_client.py6
-rw-r--r--tensorflow/compiler/xla/python/xla_client_test.py15
-rw-r--r--tensorflow/compiler/xla/rpc/grpc_client_test.cc2
-rw-r--r--tensorflow/compiler/xla/service/BUILD1
-rw-r--r--tensorflow/compiler/xla/service/algebraic_simplifier.cc149
-rw-r--r--tensorflow/compiler/xla/service/algebraic_simplifier_test.cc219
-rw-r--r--tensorflow/compiler/xla/service/buffer_assignment_test.cc12
-rw-r--r--tensorflow/compiler/xla/service/call_inliner_test.cc6
-rw-r--r--tensorflow/compiler/xla/service/conditional_simplifier_test.cc6
-rw-r--r--tensorflow/compiler/xla/service/copy_insertion_test.cc6
-rw-r--r--tensorflow/compiler/xla/service/cpu/BUILD26
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_compiler.cc21
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc8
-rw-r--r--tensorflow/compiler/xla/service/cpu/external_constant_pool.cc50
-rw-r--r--tensorflow/compiler/xla/service/cpu/external_constant_pool.h65
-rw-r--r--tensorflow/compiler/xla/service/cpu/external_constant_pool_test.cc82
-rw-r--r--tensorflow/compiler/xla/service/cpu/ir_emitter.cc103
-rw-r--r--tensorflow/compiler/xla/service/cpu/ir_emitter.h12
-rw-r--r--tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc5
-rw-r--r--tensorflow/compiler/xla/service/cpu/sample_harness.cc2
-rw-r--r--tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc7
-rw-r--r--tensorflow/compiler/xla/service/cpu/simple_orc_jit.h6
-rw-r--r--tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc3
-rw-r--r--tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc22
-rw-r--r--tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc10
-rw-r--r--tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc3
-rw-r--r--tensorflow/compiler/xla/service/dfs_hlo_visitor.h2
-rw-r--r--tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h2
-rw-r--r--tensorflow/compiler/xla/service/gpu/BUILD1
-rw-r--r--tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.cc22
-rw-r--r--tensorflow/compiler/xla/service/gpu/infeed_thunk.cc38
-rw-r--r--tensorflow/compiler/xla/service/gpu/infeed_thunk.h11
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emitter.cc12
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc31
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h2
-rw-r--r--tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc123
-rw-r--r--tensorflow/compiler/xla/service/gpu/multi_output_fusion.h3
-rw-r--r--tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc94
-rw-r--r--tensorflow/compiler/xla/service/hlo_computation.cc69
-rw-r--r--tensorflow/compiler/xla/service/hlo_computation.h5
-rw-r--r--tensorflow/compiler/xla/service/hlo_computation_test.cc6
-rw-r--r--tensorflow/compiler/xla/service/hlo_cost_analysis.cc2
-rw-r--r--tensorflow/compiler/xla/service/hlo_cost_analysis.h2
-rw-r--r--tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc40
-rw-r--r--tensorflow/compiler/xla/service/hlo_dce_test.cc11
-rw-r--r--tensorflow/compiler/xla/service/hlo_domain_test.cc44
-rw-r--r--tensorflow/compiler/xla/service/hlo_element_type_converter_test.cc6
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator.cc2
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator.h2
-rw-r--r--tensorflow/compiler/xla/service/hlo_graph_dumper.cc2
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.cc105
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.h48
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction_test.cc51
-rw-r--r--tensorflow/compiler/xla/service/hlo_instructions.cc87
-rw-r--r--tensorflow/compiler/xla/service/hlo_instructions.h33
-rw-r--r--tensorflow/compiler/xla/service/hlo_matchers.h3
-rw-r--r--tensorflow/compiler/xla/service/hlo_opcode.h2
-rw-r--r--tensorflow/compiler/xla/service/hlo_opcode_test.cc2
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser.cc52
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser_test.cc22
-rw-r--r--tensorflow/compiler/xla/service/hlo_verifier.cc67
-rw-r--r--tensorflow/compiler/xla/service/hlo_verifier.h2
-rw-r--r--tensorflow/compiler/xla/service/instruction_fusion.cc2
-rw-r--r--tensorflow/compiler/xla/service/layout_assignment_test.cc10
-rw-r--r--tensorflow/compiler/xla/service/name_uniquer.cc17
-rw-r--r--tensorflow/compiler/xla/service/name_uniquer.h36
-rw-r--r--tensorflow/compiler/xla/service/name_uniquer_test.cc29
-rw-r--r--tensorflow/compiler/xla/service/shape_inference.cc2
-rw-r--r--tensorflow/compiler/xla/service/shape_inference.h10
-rw-r--r--tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc39
-rw-r--r--tensorflow/compiler/xla/service/while_loop_simplifier_test.cc5
-rw-r--r--tensorflow/compiler/xla/service/while_util_test.cc4
-rw-r--r--tensorflow/compiler/xla/shape_util.cc73
-rw-r--r--tensorflow/compiler/xla/shape_util.h7
-rw-r--r--tensorflow/compiler/xla/shape_util_test.cc18
-rw-r--r--tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc4
-rw-r--r--tensorflow/compiler/xla/tests/batch_normalization_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/bfloat16_test.cc4
-rw-r--r--tensorflow/compiler/xla/tests/broadcast_simple_test.cc4
-rw-r--r--tensorflow/compiler/xla/tests/check_execution_arity_test.cc4
-rw-r--r--tensorflow/compiler/xla/tests/constants_test.cc5
-rw-r--r--tensorflow/compiler/xla/tests/convert_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/copy_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/dot_operation_test.cc76
-rw-r--r--tensorflow/compiler/xla/tests/fusion_test.cc5
-rw-r--r--tensorflow/compiler/xla/tests/local_client_execute_test.cc29
-rw-r--r--tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc6
-rw-r--r--tensorflow/compiler/xla/tests/multioutput_fusion_test.cc3
-rw-r--r--tensorflow/compiler/xla/tests/params_test.cc23
-rw-r--r--tensorflow/compiler/xla/tests/pred_test.cc21
-rw-r--r--tensorflow/compiler/xla/tests/prng_test.cc6
-rw-r--r--tensorflow/compiler/xla/tests/reduce_test.cc6
-rw-r--r--tensorflow/compiler/xla/tests/reshape_motion_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/reshape_test.cc50
-rw-r--r--tensorflow/compiler/xla/tests/scalar_computations_test.cc4
-rw-r--r--tensorflow/compiler/xla/tests/select_test.cc30
-rw-r--r--tensorflow/compiler/xla/tests/token_hlo_test.cc50
-rw-r--r--tensorflow/compiler/xla/tests/transfer_manager_test.cc12
-rw-r--r--tensorflow/compiler/xla/tests/transpose_test.cc16
-rw-r--r--tensorflow/compiler/xla/tests/unary_op_test.cc14
-rw-r--r--tensorflow/compiler/xla/tests/vector_ops_reduce_test.cc65
-rw-r--r--tensorflow/compiler/xla/tests/vector_ops_simple_test.cc38
-rw-r--r--tensorflow/compiler/xla/tests/while_test.cc20
-rw-r--r--tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc2
-rw-r--r--tensorflow/contrib/autograph/examples/notebooks/rnn_keras_estimator.ipynb311
-rw-r--r--tensorflow/contrib/cmake/tf_c.cmake13
-rwxr-xr-xtensorflow/contrib/cmake/tf_python.cmake12
-rw-r--r--tensorflow/contrib/data/kernels/prefetching_kernels.cc10
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py30
-rw-r--r--tensorflow/contrib/eager/python/examples/revnet/blocks.py39
-rw-r--r--tensorflow/contrib/eager/python/examples/revnet/blocks_test.py48
-rw-r--r--tensorflow/contrib/eager/python/examples/revnet/cifar_input.py35
-rw-r--r--tensorflow/contrib/eager/python/examples/revnet/config.py12
-rw-r--r--tensorflow/contrib/eager/python/examples/revnet/main.py169
-rw-r--r--tensorflow/contrib/eager/python/examples/revnet/revnet.py39
-rw-r--r--tensorflow/contrib/eager/python/examples/revnet/revnet_test.py14
-rw-r--r--tensorflow/contrib/estimator/python/estimator/dnn.py17
-rw-r--r--tensorflow/contrib/estimator/python/estimator/dnn_linear_combined.py19
-rw-r--r--tensorflow/contrib/estimator/python/estimator/linear.py17
-rw-r--r--tensorflow/contrib/lite/java/demo/app/build.gradle3
-rw-r--r--tensorflow/contrib/lite/python/lite.py45
-rw-r--r--tensorflow/contrib/lite/python/lite_test.py276
-rw-r--r--tensorflow/contrib/lite/python/tflite_convert.py7
-rw-r--r--tensorflow/contrib/lite/toco/g3doc/python_api.md46
-rw-r--r--tensorflow/contrib/lite/toco/import_tensorflow.cc2
-rw-r--r--tensorflow/contrib/summary/summary_ops_test.py13
-rw-r--r--tensorflow/core/BUILD3
-rw-r--r--tensorflow/core/framework/kernel_def_util.cc83
-rw-r--r--tensorflow/core/framework/kernel_def_util.h31
-rw-r--r--tensorflow/core/framework/kernel_def_util_test.cc133
-rw-r--r--tensorflow/core/framework/op_kernel.cc59
-rw-r--r--tensorflow/core/graph/tensor_id.cc3
-rw-r--r--tensorflow/core/graph/tensor_id.h1
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc30
-rw-r--r--tensorflow/core/kernels/dense_update_ops.cc2
-rw-r--r--tensorflow/core/kernels/function_ops.cc18
-rw-r--r--tensorflow/core/kernels/non_max_suppression_op.cc1
-rw-r--r--tensorflow/core/kernels/pad_op.cc4
-rw-r--r--tensorflow/core/kernels/pad_op_gpu.cu.cc2
-rw-r--r--tensorflow/core/lib/bfloat16/bfloat16.h12
-rw-r--r--tensorflow/core/util/device_name_utils.cc57
-rw-r--r--tensorflow/core/util/device_name_utils.h12
-rw-r--r--tensorflow/core/util/device_name_utils_test.cc47
-rw-r--r--tensorflow/core/util/saved_tensor_slice_util.h1
-rw-r--r--tensorflow/docs_src/performance/xla/operation_semantics.md4
-rw-r--r--tensorflow/python/BUILD2
-rw-r--r--tensorflow/python/client/session.py2
-rw-r--r--tensorflow/python/client/session_test.py69
-rw-r--r--tensorflow/python/data/kernel_tests/batch_dataset_op_test.py53
-rw-r--r--tensorflow/python/eager/function.py7
-rw-r--r--tensorflow/python/eager/function_test.py15
-rw-r--r--tensorflow/python/estimator/canned/dnn.py32
-rw-r--r--tensorflow/python/estimator/canned/dnn_linear_combined.py38
-rw-r--r--tensorflow/python/estimator/canned/linear.py42
-rw-r--r--tensorflow/python/estimator/canned/optimizers.py2
-rw-r--r--tensorflow/python/estimator/canned/optimizers_test.py30
-rw-r--r--tensorflow/python/framework/importer.py6
-rw-r--r--tensorflow/python/framework/ops.py46
-rw-r--r--tensorflow/python/framework/test_util.py20
-rw-r--r--tensorflow/python/keras/engine/network.py5
-rw-r--r--tensorflow/python/kernel_tests/functional_ops_test.py19
-rw-r--r--tensorflow/python/ops/control_flow_ops.py7
-rw-r--r--tensorflow/python/ops/gradients_impl.py8
-rw-r--r--tensorflow/python/ops/summary_ops_v2.py12
-rw-r--r--tensorflow/python/training/checkpointable/BUILD7
-rw-r--r--tensorflow/python/training/checkpointable/data_structures.py6
-rw-r--r--tensorflow/python/training/checkpointable/layer_utils.py85
-rw-r--r--tensorflow/python/util/lock_util_test.py3
-rw-r--r--tensorflow/tools/api/generator/doc_srcs_test.py24
-rw-r--r--tensorflow/tools/docs/BUILD2
-rw-r--r--tensorflow/tools/docs/generate_lib.py80
-rw-r--r--tensorflow/tools/docs/generate_lib_test.py110
-rw-r--r--tensorflow/workspace.bzl8
-rw-r--r--third_party/toolchains/clang6/CROSSTOOL.tpl3
210 files changed, 4284 insertions, 1939 deletions
diff --git a/RELEASE.md b/RELEASE.md
index f0a7afe684..2ee2b67435 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -32,7 +32,6 @@
* Using `tf.keras.layers` with custom variable scopes.
* Using `tf.layers` in a subclassed `tf.keras.Model` class. See
[here](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/layers) for more details
-
* `tf.data`:
* The `DatasetBase::DebugString()` method is now `const`.
* Added the `tf.contrib.data.sample_from_datasets()` API for randomly sampling from multiple datasets.
diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc
index 9d5f98d4d6..a8ad8e4b94 100644
--- a/tensorflow/c/c_api.cc
+++ b/tensorflow/c/c_api.cc
@@ -2414,7 +2414,18 @@ void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, TF_Output* x, int nx,
for (int i = first_new_node_id; i < g->graph.num_node_ids(); ++i) {
Node* n = g->graph.FindNodeId(i);
if (n == nullptr) continue;
- g->name_map[n->name()] = n;
+ // We have a convoluted scheme here: Using the C++ graph construction API
+ // to add potentially many nodes to the graph without running the checks
+ // (such as uniqueness of the names of nodes) we run with other functions
+ // that add a node to the graph (like TF_FinishOperation).
+ if (!g->name_map.insert(std::make_pair(n->name(), n)).second) {
+ status->status = tensorflow::errors::Internal(
+ "BUG: The API allowed construction of a graph with duplicate node "
+ "names (",
+ n->name(),
+ "). This is a bug. Please file an issue at "
+ "https://github.com/tensorflow/tensorflow/issues.");
+ }
}
}
diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc
index 577f10c5e6..bc04b53fbb 100644
--- a/tensorflow/c/c_api_test.cc
+++ b/tensorflow/c/c_api_test.cc
@@ -1160,7 +1160,7 @@ TEST(CAPI, GetOpDef) {
}
void StringVectorToArrays(const std::vector<string>& v,
- std::unique_ptr<const void* []>* ptrs,
+ std::unique_ptr<const void*[]>* ptrs,
std::unique_ptr<size_t[]>* lens) {
ptrs->reset(new const void*[v.size()]);
lens->reset(new size_t[v.size()]);
@@ -1196,7 +1196,7 @@ class CApiColocationTest : public ::testing::Test {
void SetViaStringList(TF_OperationDescription* desc,
const std::vector<string>& list) {
- std::unique_ptr<const void* []> list_ptrs;
+ std::unique_ptr<const void*[]> list_ptrs;
std::unique_ptr<size_t[]> list_lens;
StringVectorToArrays(list, &list_ptrs, &list_lens);
TF_SetAttrStringList(desc, tensorflow::kColocationAttrName, list_ptrs.get(),
@@ -1700,6 +1700,61 @@ TEST_F(CApiGradientsTest, OpWithNoGradientRegistered_NoGradInputs) {
TestGradientsError(false);
}
+void ScalarFloatFromTensor(const TF_Tensor* t, float* f) {
+ ASSERT_TRUE(t != nullptr);
+ ASSERT_EQ(TF_FLOAT, TF_TensorType(t));
+ ASSERT_EQ(0, TF_NumDims(t));
+ ASSERT_EQ(4, TF_TensorByteSize(t));
+ float* p = static_cast<float*>(TF_TensorData(t));
+ *f = *p;
+}
+
+TEST_F(CApiGradientsTest, MultipleCallsToAddGradients) {
+ const float X = 3.0f, Y = 7.0f;
+ TF_Operation* x = Placeholder(graph_, s_, "x", TF_FLOAT);
+ TF_Operation* y = Placeholder(graph_, s_, "y", TF_FLOAT);
+ TF_Operation* xy = Mul(x, y, graph_, s_, "xy");
+ TF_Output dxy_dx, dxy_dy;
+
+ TF_Output outputs[1] = {{xy, 0}};
+ TF_Output inputs[1] = {{x, 0}};
+ TF_AddGradients(graph_, outputs, 1, inputs, 1, nullptr, s_, &dxy_dx);
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+
+ inputs[0] = {y, 0};
+ TF_AddGradients(graph_, outputs, 1, inputs, 1, nullptr, s_, &dxy_dy);
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+
+ TF_SessionOptions* opts = TF_NewSessionOptions();
+ TF_Session* sess = TF_NewSession(graph_, opts, s_);
+ TF_DeleteSessionOptions(opts);
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+
+ TF_Output feeds[] = {{x, 0}, {y, 0}};
+ TF_Tensor* feedValues[] = {FloatTensor(X), FloatTensor(Y)};
+ TF_Output fetches[] = {dxy_dx, dxy_dy};
+ TF_Tensor* fetchValues[] = {nullptr, nullptr};
+
+ TF_SessionRun(sess, nullptr /* run_options */, feeds, feedValues, 2, fetches,
+ fetchValues, 2, nullptr /* target_opers */, 0,
+ nullptr /* run_metadata */, s_);
+ TF_DeleteTensor(feedValues[0]);
+ TF_DeleteTensor(feedValues[1]);
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+ TF_DeleteSession(sess, s_);
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+
+ float dxy_dxValue = 0.0f, dxy_dyValue = 0.0f;
+ ScalarFloatFromTensor(fetchValues[0], &dxy_dxValue);
+ EXPECT_EQ(Y, dxy_dxValue);
+
+ ScalarFloatFromTensor(fetchValues[1], &dxy_dyValue);
+ EXPECT_EQ(X, dxy_dyValue);
+
+ TF_DeleteTensor(fetchValues[0]);
+ TF_DeleteTensor(fetchValues[1]);
+}
+
// REGISTER_OP for CApiAttributesTest test cases.
// Registers two ops, each with a single attribute called 'v'.
// The attribute in one op will have a type 'type', the other
@@ -1784,7 +1839,7 @@ TEST_F(CApiAttributesTest, String) {
TEST_F(CApiAttributesTest, StringList) {
std::vector<string> list = {"bugs", "bunny", "duck"};
- std::unique_ptr<const void* []> list_ptrs;
+ std::unique_ptr<const void*[]> list_ptrs;
std::unique_ptr<size_t[]> list_lens;
StringVectorToArrays(list, &list_ptrs, &list_lens);
int list_total_size = 0;
@@ -1800,7 +1855,7 @@ TEST_F(CApiAttributesTest, StringList) {
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
EXPECT_TF_META("v", list.size(), TF_ATTR_STRING, list_total_size);
- std::unique_ptr<void* []> values(new void*[list.size()]);
+ std::unique_ptr<void*[]> values(new void*[list.size()]);
std::unique_ptr<size_t[]> lens(new size_t[list.size()]);
std::unique_ptr<char[]> storage(new char[list_total_size]);
TF_OperationGetAttrStringList(oper, "v", values.get(), lens.get(),
@@ -2025,7 +2080,7 @@ TEST_F(CApiAttributesTest, TensorShapeProtoList) {
tensorflow::PartialTensorShape(pts2).AsProto(&proto);
proto.SerializeToString(&bytes2);
- std::unique_ptr<const void* []> list_ptrs;
+ std::unique_ptr<const void*[]> list_ptrs;
std::unique_ptr<size_t[]> list_lens;
const std::vector<string> list = {bytes1, bytes2};
StringVectorToArrays(list, &list_ptrs, &list_lens);
diff --git a/tensorflow/c/c_test_util.cc b/tensorflow/c/c_test_util.cc
index f3b28c1708..24eb6c069b 100644
--- a/tensorflow/c/c_test_util.cc
+++ b/tensorflow/c/c_test_util.cc
@@ -216,6 +216,13 @@ TF_Operation* Min(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
return MinWithDevice(l, r, graph, /*op_device=*/"", s, name);
}
+TF_Operation* Mul(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
+ TF_Status* s, const char* name) {
+ TF_Operation* op;
+ BinaryOpHelper("Mul", l, r, graph, s, name, &op, "", true);
+ return op;
+}
+
TF_Operation* Add(TF_Output l, TF_Output r, TF_Graph* graph, TF_Status* s,
const char* name) {
TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", name);
diff --git a/tensorflow/c/c_test_util.h b/tensorflow/c/c_test_util.h
index c16aba666e..38313d647c 100644
--- a/tensorflow/c/c_test_util.h
+++ b/tensorflow/c/c_test_util.h
@@ -80,6 +80,9 @@ TF_Operation* Add(TF_Output l, TF_Output r, TF_Graph* graph, TF_Status* s,
TF_Operation* Min(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
TF_Status* s, const char* name = "min");
+TF_Operation* Mul(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
+ TF_Status* s, const char* name = "mul");
+
// If `op_device` is non-empty, set the created op on that device.
TF_Operation* MinWithDevice(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
const string& op_device, TF_Status* s,
diff --git a/tensorflow/cc/framework/scope.cc b/tensorflow/cc/framework/scope.cc
index 62a889181e..8c886f3171 100644
--- a/tensorflow/cc/framework/scope.cc
+++ b/tensorflow/cc/framework/scope.cc
@@ -37,6 +37,11 @@ Scope& Scope::operator=(const Scope& other) {
return *this;
}
+namespace {
+const char kScopeSeparator[] = "/";
+const char kSuffixSeparator[] = "_";
+} // namespace
+
Scope::Impl::Impl(Graph* graph, Status* status, NameMap* name_map,
ShapeRefiner* refiner, bool disable_shape_inference)
: graph_(graph),
@@ -308,19 +313,23 @@ string Scope::Impl::GetUniqueName(const string& prefix,
return prefix;
}
auto entry = name_map_->find(prefix);
- string unique_name = prefix;
if (entry == name_map_->end()) {
name_map_->insert({prefix, 0});
- } else {
- unique_name = strings::StrCat(unique_name, "_", ++entry->second);
+ return prefix;
}
+ string unique_name;
+ do {
+ unique_name = strings::StrCat(prefix, kSuffixSeparator, ++entry->second);
+ } while (name_map_->find(unique_name) != name_map_->end());
+ name_map_->insert({unique_name, 0});
return unique_name;
}
string Scope::Impl::GetNameForOp(const string& default_name) const {
const string unique_name =
GetUniqueName(default_name, true /* check_single_use */);
- const string sep = name_.empty() || unique_name.empty() ? "" : "/";
+ const string sep =
+ name_.empty() || unique_name.empty() ? "" : kScopeSeparator;
return strings::StrCat(name_, sep, unique_name);
}
@@ -345,7 +354,8 @@ Scope Scope::NewSubScope(const string& child_scope_name) const {
}
const string unique_name =
impl()->GetUniqueName(child_scope_name, false /* check_single_use */);
- const string sep = impl()->name_.empty() || unique_name.empty() ? "" : "/";
+ const string sep =
+ impl()->name_.empty() || unique_name.empty() ? "" : kScopeSeparator;
return Scope(new Impl(*this, Impl::Tags::ScopeName(),
strings::StrCat(impl()->name_, sep, unique_name),
false /* copy_names */));
@@ -412,7 +422,7 @@ CompositeOpScopes Scope::GetCompositeOpScopes(
if (!impl()->single_use_scope()) {
Scope child = NewSubScope(impl()->op_name_.empty() ? composite_op_name
: impl()->op_name_);
- const string child_op_sep = impl()->name_.empty() ? "" : "_";
+ const string child_op_sep = impl()->name_.empty() ? "" : kSuffixSeparator;
const string child_name =
strings::StrCat(impl()->name_, child_op_sep, child.impl()->name_);
return {child,
@@ -435,7 +445,13 @@ class InternalScope {
static Scope NewScope(Graph* graph, Status* status, ShapeRefiner* refiner) {
Scope::Impl::NameMap* name_map = new Scope::Impl::NameMap;
for (const Node* node : graph->nodes()) {
- (*name_map)[node->name()] = 0;
+ const string& name = node->name();
+ (*name_map)[name] = 0;
+ // Add all name prefixes ('/' separated).
+ size_t idx = -1;
+ while ((idx = name.find(kScopeSeparator, idx + 1)) != string::npos) {
+ (*name_map)[name.substr(0, idx)] = 0;
+ }
}
// We provide null destructors for these shared ptrs (except for name_map)
// since the caller owns them and doesn't want the scope to destroy them.
diff --git a/tensorflow/cc/framework/scope_internal.h b/tensorflow/cc/framework/scope_internal.h
index 8efcfed20d..58adaef2e9 100644
--- a/tensorflow/cc/framework/scope_internal.h
+++ b/tensorflow/cc/framework/scope_internal.h
@@ -34,8 +34,7 @@ class Scope::Impl {
// name that has not been used so far in a scope will get no suffix. Later
// uses of the same name will get suffixes _1, _2, _3, etc. Multiple scopes
// can share the same NameMap. For instance, a new scope created using
- // WithControlDependencies() should would share the same NameMap with the
- // parent.
+ // WithControlDependencies() would share the same NameMap with the parent.
typedef std::unordered_map<string, int> NameMap;
Impl(const std::shared_ptr<Graph>& graph,
diff --git a/tensorflow/cc/framework/scope_test.cc b/tensorflow/cc/framework/scope_test.cc
index 9eca9d3fac..b40b345eb8 100644
--- a/tensorflow/cc/framework/scope_test.cc
+++ b/tensorflow/cc/framework/scope_test.cc
@@ -26,6 +26,16 @@ TEST(ScopeTest, BasicNames) {
EXPECT_EQ(root.GetUniqueNameForOp("mul"), "mul");
}
+TEST(ScopeTest, OpAndScopeNameCollision) {
+ Scope root = Scope::NewRootScope();
+ EXPECT_EQ(root.GetUniqueNameForOp("foo"), "foo");
+ EXPECT_EQ(root.GetUniqueNameForOp("foo"), "foo_1");
+ EXPECT_EQ(root.GetUniqueNameForOp("foo_1"), "foo_1_1");
+ EXPECT_EQ(root.GetUniqueNameForOp("foo_2"), "foo_2");
+ EXPECT_EQ(root.GetUniqueNameForOp("foo"), "foo_3");
+ EXPECT_EQ(root.GetUniqueNameForOp("foo_2"), "foo_2_1");
+}
+
TEST(ScopeTest, HierarchicalNames) {
Scope root = Scope::NewRootScope();
Scope child = root.NewSubScope("child");
diff --git a/tensorflow/compiler/tests/fft_test.py b/tensorflow/compiler/tests/fft_test.py
index afb5fa4bb4..b2360dd009 100644
--- a/tensorflow/compiler/tests/fft_test.py
+++ b/tensorflow/compiler/tests/fft_test.py
@@ -27,6 +27,7 @@ from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.contrib.signal.python.ops import spectral_ops as signal
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import spectral_ops
from tensorflow.python.platform import googletest
@@ -97,8 +98,11 @@ class FFTTest(XLATestCase):
ph = array_ops.placeholder(
dtypes.as_dtype(data.dtype), shape=data.shape)
out = signal.stft(ph, ws, hs)
+ grad = gradients_impl.gradients(out, ph,
+ grad_ys=array_ops.ones_like(out))
- value = sess.run(out, {ph: data})
+ # For gradients, we simply verify that they compile & execute.
+ value, _ = sess.run([out, grad], {ph: data})
self.assertAllClose(expected, value, rtol=RTOL, atol=ATOL)
def testFFT(self):
diff --git a/tensorflow/compiler/tests/segment_reduction_ops_test.py b/tensorflow/compiler/tests/segment_reduction_ops_test.py
index 4a9c0e7471..772c20fd42 100644
--- a/tensorflow/compiler/tests/segment_reduction_ops_test.py
+++ b/tensorflow/compiler/tests/segment_reduction_ops_test.py
@@ -21,26 +21,40 @@ from __future__ import print_function
import functools
import numpy as np
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
+from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import googletest
-class SegmentReductionOpsTest(XLATestCase):
+class SegmentReductionOpsTest(xla_test.XLATestCase):
"""Test cases for segment reduction ops."""
- def UnsortedSegmentSum(self, data, indices, num_segments):
+ def _segmentReduction(self, op, data, indices, num_segments):
with self.test_session() as sess, self.test_scope():
d = array_ops.placeholder(data.dtype, shape=data.shape)
if isinstance(indices, int):
i = array_ops.placeholder(np.int32, shape=[])
else:
i = array_ops.placeholder(indices.dtype, shape=indices.shape)
- return sess.run(
- math_ops.unsorted_segment_sum(d, i, num_segments),
- {d: data,
- i: indices})
+ return sess.run(op(d, i, num_segments), {d: data, i: indices})
+
+ def _unsortedSegmentSum(self, data, indices, num_segments):
+ return self._segmentReduction(math_ops.unsorted_segment_sum, data, indices,
+ num_segments)
+
+ def _unsortedSegmentProd(self, data, indices, num_segments):
+ return self._segmentReduction(math_ops.unsorted_segment_prod, data, indices,
+ num_segments)
+
+ def _unsortedSegmentMin(self, data, indices, num_segments):
+ return self._segmentReduction(math_ops.unsorted_segment_min, data, indices,
+ num_segments)
+
+ def _unsortedSegmentMax(self, data, indices, num_segments):
+ return self._segmentReduction(math_ops.unsorted_segment_max, data, indices,
+ num_segments)
def testUnsortedSegmentSum0DIndices1DData(self):
for dtype in self.numeric_types:
@@ -49,14 +63,14 @@ class SegmentReductionOpsTest(XLATestCase):
[[0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 1, 2, 3, 4, 5],
[0, 0, 0, 0, 0, 0]],
dtype=dtype),
- self.UnsortedSegmentSum(
+ self._unsortedSegmentSum(
np.array([0, 1, 2, 3, 4, 5], dtype=dtype), 2, 4))
def testUnsortedSegmentSum1DIndices1DData(self):
for dtype in self.numeric_types:
self.assertAllClose(
np.array([1, 3, 2, 9], dtype=dtype),
- self.UnsortedSegmentSum(
+ self._unsortedSegmentSum(
np.array([0, 1, 2, 3, 4, 5], dtype=dtype),
np.array([3, 0, 2, 1, 3, 3], dtype=np.int32), 4))
@@ -64,7 +78,7 @@ class SegmentReductionOpsTest(XLATestCase):
for dtype in self.numeric_types:
self.assertAllClose(
np.array([6, 3, 0, 6], dtype=dtype),
- self.UnsortedSegmentSum(
+ self._unsortedSegmentSum(
np.array([0, 1, 2, 3, 4, 5, 6], dtype=dtype),
np.array([3, -1, 0, 1, 0, -1, 3], dtype=np.int32), 4))
@@ -76,7 +90,7 @@ class SegmentReductionOpsTest(XLATestCase):
dtype=dtype)
indices = np.array([8, 1, 0, 3, 7], dtype=np.int32)
num_segments = 10
- y = self.UnsortedSegmentSum(data, indices, num_segments)
+ y = self._unsortedSegmentSum(data, indices, num_segments)
self.assertAllClose(
np.array(
[[30, 31, 32, 33], [20, 21, 22, 23], [0, 0, 0, 0],
@@ -92,7 +106,7 @@ class SegmentReductionOpsTest(XLATestCase):
dtype=dtype)
indices = np.array([0, 1, 2, 0, 1], dtype=np.int32)
num_segments = 4
- y = self.UnsortedSegmentSum(data, indices, num_segments)
+ y = self._unsortedSegmentSum(data, indices, num_segments)
self.assertAllClose(
np.array(
[[40, 42, 44, 46], [70, 72, 74, 76], [30, 31, 32, 33],
@@ -102,30 +116,30 @@ class SegmentReductionOpsTest(XLATestCase):
def testUnsortedSegmentSum2DIndices3DData(self):
for dtype in self.numeric_types:
data = np.array(
- [[[0, 1, 2], [10, 11, 12]], [[100, 101, 102], [110, 111, 112]],
- [[200, 201, 202], [210, 211, 212]], [[300, 301, 302],
- [310, 311, 312]]],
+ [[[0, 1, 2], [10, 11, 12]], [[100, 101, 102], [110, 111, 112]], [[
+ 200, 201, 202
+ ], [210, 211, 212]], [[300, 301, 302], [310, 311, 312]]],
dtype=dtype)
indices = np.array([[3, 5], [3, 1], [5, 0], [6, 2]], dtype=np.int32)
num_segments = 8
- y = self.UnsortedSegmentSum(data, indices, num_segments)
+ y = self._unsortedSegmentSum(data, indices, num_segments)
self.assertAllClose(
np.array(
- [[210, 211, 212], [110, 111, 112], [310, 311, 312],
- [100, 102, 104], [0, 0, 0.], [210, 212, 214], [300, 301,
- 302], [0, 0, 0]],
+ [[210, 211, 212], [110, 111, 112], [310, 311, 312], [
+ 100, 102, 104
+ ], [0, 0, 0.], [210, 212, 214], [300, 301, 302], [0, 0, 0]],
dtype=dtype), y)
def testUnsortedSegmentSum1DIndices3DData(self):
for dtype in self.numeric_types:
data = np.array(
- [[[0, 1, 2], [10, 11, 12]], [[100, 101, 102], [110, 111, 112]],
- [[200, 201, 202], [210, 211, 212]], [[300, 301, 302],
- [310, 311, 312]]],
+ [[[0, 1, 2], [10, 11, 12]], [[100, 101, 102], [110, 111, 112]], [[
+ 200, 201, 202
+ ], [210, 211, 212]], [[300, 301, 302], [310, 311, 312]]],
dtype=dtype)
indices = np.array([3, 0, 2, 5], dtype=np.int32)
num_segments = 6
- y = self.UnsortedSegmentSum(data, indices, num_segments)
+ y = self._unsortedSegmentSum(data, indices, num_segments)
self.assertAllClose(
np.array(
[[[100, 101, 102.], [110, 111, 112]], [[0, 0, 0], [0, 0, 0]],
@@ -138,10 +152,40 @@ class SegmentReductionOpsTest(XLATestCase):
data = np.ones((4, 8, 7), dtype=dtype)
indices = np.ones((3, 2), dtype=np.int32)
num_segments = 4
- self.assertRaises(ValueError,
- functools.partial(self.UnsortedSegmentSum, data,
- indices, num_segments))
+ self.assertRaises(
+ ValueError,
+ functools.partial(self._segmentReduction,
+ math_ops.unsorted_segment_sum, data, indices,
+ num_segments))
+
+ def testUnsortedSegmentOps1DIndices1DDataNegativeIndices(self):
+ """Tests for min, max, and prod ops.
+
+ These share most of their implementation with sum, so we only test basic
+ functionality.
+ """
+ for dtype in self.numeric_types:
+ self.assertAllClose(
+ np.array([8, 3, 1, 0], dtype=dtype),
+ self._unsortedSegmentProd(
+ np.array([0, 1, 2, 3, 4, 5, 6], dtype=dtype),
+ np.array([3, -1, 0, 1, 0, -1, 3], dtype=np.int32), 4))
+
+ for dtype in self.int_types | self.float_types:
+ minval = dtypes.as_dtype(dtype).min
+ maxval = dtypes.as_dtype(dtype).max
+
+ self.assertAllClose(
+ np.array([2, 3, maxval, 0], dtype=dtype),
+ self._unsortedSegmentMin(
+ np.array([0, 1, 2, 3, 4, 5, 6], dtype=dtype),
+ np.array([3, -1, 0, 1, 0, -1, 3], dtype=np.int32), 4))
+ self.assertAllClose(
+ np.array([4, 3, minval, 6], dtype=dtype),
+ self._unsortedSegmentMax(
+ np.array([0, 1, 2, 3, 4, 5, 6], dtype=dtype),
+ np.array([3, -1, 0, 1, 0, -1, 3], dtype=np.int32), 4))
-if __name__ == '__main__':
+if __name__ == "__main__":
googletest.main()
diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
index 140dad61d9..6cc95149a1 100644
--- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
+++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
@@ -166,6 +166,27 @@ StatusOr<Node*> AddNode(const NodeDef& node_def, Graph* graph) {
return inserted_node;
}
+// Check that the graph has no cycle containing the given node.
+Status CheckNoCycleContains(const Node* node, const int num_nodes) {
+ std::vector<const Node*> ready;
+ ready.push_back(node);
+ std::vector<bool> visited(num_nodes);
+ while (!ready.empty()) {
+ const Node* current_node = ready.back();
+ ready.pop_back();
+ visited[current_node->id()] = true;
+ for (const Edge* out : current_node->out_edges()) {
+ if (out->dst() == node) {
+ return errors::Internal("Detect a cycle: Node \"", node->name(), "\"(",
+ node->def().op(), ") feeds into itself.");
+ } else if (!visited[out->dst()->id()]) {
+ ready.push_back(out->dst());
+ }
+ }
+ }
+ return Status::OK();
+}
+
StatusOr<Node*> BuildArgNode(Graph* graph, DataType type, int index) {
NodeDef arg_def;
NodeDefBuilder builder(strings::StrCat(kArgOp, index), kArgOp);
@@ -1407,6 +1428,10 @@ StatusOr<Node*> FunctionalizeCond::ConvertToXlaIf(
TF_RETURN_IF_ERROR(
AddInputEdges(cond_arg_nodes, switch_cluster.predicate_edge, if_node));
TF_RETURN_IF_ERROR(AddOutputEdges(merge_nodes, if_node));
+ // Check that the if_node doesn't feed into itself.
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(
+ CheckNoCycleContains(if_node, graph_->num_node_ids()),
+ "ConvertToXlaIf failed.");
return if_node;
}
@@ -1506,6 +1531,16 @@ Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library,
worklist.push_back(frame->parent);
}
}
+ // There should be no cycle at this point, since while loops have been removed
+ // from graph.
+ // Check that the newly added XlaWhile nodes don't feed into themselves.
+ for (const Node* node : graph->op_nodes()) {
+ if (node->def().op() == "XlaWhile") {
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(
+ CheckNoCycleContains(node, graph->num_node_ids()),
+ "FunctionalizeLoop failed.");
+ }
+ }
// FunctionalizeControlFlow is invoked for every function, so the loops's
// bodies and conditionals that were extracted into functions will be handled
diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc
index 14977a908a..aae2f8ee5a 100644
--- a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc
+++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc
@@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_def_builder.h"
+#include "tensorflow/core/graph/validate.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/util/equal_graph_def.h"
@@ -1012,5 +1013,60 @@ TEST(FunctionalizeControlFlow, Complex) {
}
}
+TEST(FunctionalizeControlFlow, Cycle) {
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ // -----------------------------------------------------
+ // | |
+ // | v
+ // less -> switch_1 --> add -> merge_1 -> identity -> switch_2
+ // | ^ |
+ // | | v
+ // --------> one -------------------------> add_2 ---> merge_2
+ {
+ Scope scope = Scope::NewRootScope().ExitOnError();
+
+ auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32);
+ auto y = ops::Placeholder(scope.WithOpName("y"), DT_INT32);
+ auto less = ops::Less(scope.WithOpName("cond/Less"), y, x);
+ auto switch_1 = ops::Switch(scope.WithOpName("cond/Switch"), x, less);
+ auto two =
+ ops::Const<int32>(scope.WithOpName("cond/two")
+ .WithControlDependencies(switch_1.output_true),
+ 2);
+ auto mul = ops::Multiply(scope.WithOpName("cond/true/mul"),
+ switch_1.output_true, two);
+ auto one =
+ ops::Const<int32>(scope.WithOpName("cond/one")
+ .WithControlDependencies(switch_1.output_false),
+ 1);
+ auto add = ops::Add(scope.WithOpName("cond/false/add"),
+ switch_1.output_false, one);
+
+ auto merge_1 = ops::Merge(scope.WithOpName("cond/Merge"),
+ std::initializer_list<Input>{add, mul});
+ auto identity =
+ ops::Identity(scope.WithOpName("cond/Merge/identity"), merge_1.output);
+ auto switch_2 =
+ ops::Switch(scope.WithOpName("grad/cond/Switch"), identity, less);
+ auto add_2 = ops::Add(scope.WithOpName("cond_2/false/add"),
+ switch_2.output_false, one);
+ auto mul_2 = ops::Multiply(scope.WithOpName("cond_2/true/mul"),
+ switch_2.output_true, two);
+ auto merge_2 = ops::Merge(scope.WithOpName("cond_2/Merge"),
+ std::initializer_list<Input>{add_2, mul_2});
+ TF_ASSERT_OK(scope.ToGraph(graph.get()));
+ }
+ // No cycle before functionalize control flow.
+ TF_EXPECT_OK(graph::ValidateGraphHasNoCycle(*graph));
+ FunctionLibraryDefinition library(OpRegistry::Global(), {});
+ // switch_1 and switch_2 have the same switch depth. They are replaced by a
+ // single XlaIf node during FunctionalizeControlFlow, resulting in a cycle:
+ // less -> XlaIf <--> identity.
+ Status status = FunctionalizeControlFlow(graph.get(), &library);
+ EXPECT_FALSE(status.ok());
+ EXPECT_TRUE(str_util::StrContains(status.error_message(), "Detect a cycle"))
+ << status.error_message();
+}
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc
index 212f6f3966..4a6622ed73 100644
--- a/tensorflow/compiler/tf2xla/graph_compiler.cc
+++ b/tensorflow/compiler/tf2xla/graph_compiler.cc
@@ -39,6 +39,7 @@ limitations under the License.
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/graph/validate.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/logging.h"
@@ -87,6 +88,8 @@ Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph,
}
} // namespace
Status GraphCompiler::Compile() {
+ // Check that the graph has no illegal cycles.
+ TF_RETURN_IF_ERROR(graph::ValidateGraphHasNoCycle(*graph_));
// Maintain a mapping from node id to node outputs.
using NodeOutputs = std::vector<TensorValue>;
std::vector<NodeOutputs> output_registry(graph_->num_node_ids());
diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops.cc b/tensorflow/compiler/tf2xla/kernels/random_ops.cc
index be83834e86..3bab4ae917 100644
--- a/tensorflow/compiler/tf2xla/kernels/random_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/random_ops.cc
@@ -210,9 +210,7 @@ class TruncatedNormalOp : public XlaOpKernel {
xla::XlaOp min_positive =
XlaHelpers::FloatLiteral(b, dtype, std::numeric_limits<float>::min());
auto uniform = b->RngUniform(min_positive, one, xla_shape);
- auto truncated_normal_or_status = TruncatedNormal(dtype, uniform, b);
- OP_REQUIRES_OK(ctx, truncated_normal_or_status.status());
- ctx->SetOutput(0, truncated_normal_or_status.ValueOrDie());
+ ctx->SetOutput(0, TruncatedNormal(dtype, uniform));
}
};
diff --git a/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc
index 664078ca16..ff14483347 100644
--- a/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc
@@ -22,12 +22,19 @@ limitations under the License.
namespace tensorflow {
namespace {
-class UnsortedSegmentSum : public XlaOpKernel {
+class UnsortedSegmentReduce : public XlaOpKernel {
public:
- explicit UnsortedSegmentSum(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ explicit UnsortedSegmentReduce(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
}
+ // The initial value to initialize elements of the output to.
+ virtual xla::XlaOp InitialValue(xla::XlaBuilder* builder) = 0;
+
+ // A function to combine two scalars with the same index (e.g., sum).
+ virtual xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b,
+ xla::XlaBuilder* builder) = 0;
+
void Compile(XlaOpKernelContext* ctx) override {
// output = unsorted_segment_sum(data, indices, num_segments)
// Compute a tensor such that:
@@ -50,27 +57,29 @@ class UnsortedSegmentSum : public XlaOpKernel {
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(2, &num_segments));
OP_REQUIRES(ctx, data_shape.dims() >= indices_shape.dims(),
- errors::InvalidArgument(
- "UnsortedSegmentSum requires that indices' rank be"
- " less than or equal to data's rank."));
+ errors::InvalidArgument(type_string(),
+ " requires that indices' rank be"
+ " less than or equal to data's rank."));
// Validate that indices.shape is a prefix of data.shape.
for (int d = 0; d < indices_shape.dims(); ++d) {
- OP_REQUIRES(ctx, (data_shape.dim_size(d) == indices_shape.dim_size(d)),
- errors::InvalidArgument(
- "UnsortedSegmentSum requires indices shape to be prefix"
- " of data_shape, but dimension ",
- d, " differs ", data_shape.dim_size(d), " vs. ",
- indices_shape.dim_size(d)));
+ OP_REQUIRES(
+ ctx, (data_shape.dim_size(d) == indices_shape.dim_size(d)),
+ errors::InvalidArgument(type_string(),
+ " requires indices shape to be prefix"
+ " of data_shape, but dimension ",
+ d, " differs ", data_shape.dim_size(d),
+ " vs. ", indices_shape.dim_size(d)));
}
xla::XlaBuilder* builder = ctx->builder();
TensorShape buffer_shape = data_shape;
buffer_shape.RemoveDimRange(0, indices_shape.dims());
buffer_shape.InsertDim(0, num_segments);
- auto buffer = builder->Broadcast(XlaHelpers::Zero(builder, dtype_),
- buffer_shape.dim_sizes());
+ auto buffer =
+ builder->Broadcast(InitialValue(builder), buffer_shape.dim_sizes());
- auto combiner = [](xla::XlaOp a, xla::XlaOp b, xla::XlaBuilder* builder) {
- return builder->Add(a, b);
+ auto combiner = [this](xla::XlaOp a, xla::XlaOp b,
+ xla::XlaBuilder* builder) {
+ return Combine(a, b, builder);
};
auto result = XlaScatter(buffer, /*updates=*/data, indices,
@@ -79,13 +88,81 @@ class UnsortedSegmentSum : public XlaOpKernel {
ctx->SetOutput(0, result.ValueOrDie());
}
- private:
+ protected:
DataType dtype_;
};
+class UnsortedSegmentSum : public UnsortedSegmentReduce {
+ public:
+ explicit UnsortedSegmentSum(OpKernelConstruction* ctx)
+ : UnsortedSegmentReduce(ctx) {}
+
+ xla::XlaOp InitialValue(xla::XlaBuilder* builder) override {
+ return XlaHelpers::Zero(builder, dtype_);
+ };
+ xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b,
+ xla::XlaBuilder* builder) override {
+ return builder->Add(a, b);
+ };
+};
+
REGISTER_XLA_OP(
Name("UnsortedSegmentSum").CompileTimeConstInput("num_segments"),
UnsortedSegmentSum);
+class UnsortedSegmentProd : public UnsortedSegmentReduce {
+ public:
+ explicit UnsortedSegmentProd(OpKernelConstruction* ctx)
+ : UnsortedSegmentReduce(ctx) {}
+
+ xla::XlaOp InitialValue(xla::XlaBuilder* builder) override {
+ return XlaHelpers::One(builder, dtype_);
+ };
+ xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b,
+ xla::XlaBuilder* builder) override {
+ return builder->Mul(a, b);
+ };
+};
+
+REGISTER_XLA_OP(
+ Name("UnsortedSegmentProd").CompileTimeConstInput("num_segments"),
+ UnsortedSegmentProd);
+
+class UnsortedSegmentMin : public UnsortedSegmentReduce {
+ public:
+ explicit UnsortedSegmentMin(OpKernelConstruction* ctx)
+ : UnsortedSegmentReduce(ctx) {}
+
+ xla::XlaOp InitialValue(xla::XlaBuilder* builder) override {
+ return XlaHelpers::MaxFiniteValue(builder, dtype_);
+ };
+ xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b,
+ xla::XlaBuilder* builder) override {
+ return builder->Min(a, b);
+ };
+};
+
+REGISTER_XLA_OP(
+ Name("UnsortedSegmentMin").CompileTimeConstInput("num_segments"),
+ UnsortedSegmentMin);
+
+class UnsortedSegmentMax : public UnsortedSegmentReduce {
+ public:
+ explicit UnsortedSegmentMax(OpKernelConstruction* ctx)
+ : UnsortedSegmentReduce(ctx) {}
+
+ xla::XlaOp InitialValue(xla::XlaBuilder* builder) override {
+ return XlaHelpers::MinFiniteValue(builder, dtype_);
+ };
+ xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b,
+ xla::XlaBuilder* builder) override {
+ return builder->Max(a, b);
+ };
+};
+
+REGISTER_XLA_OP(
+ Name("UnsortedSegmentMax").CompileTimeConstInput("num_segments"),
+ UnsortedSegmentMax);
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc
index 0367501433..43ab4642e9 100644
--- a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc
@@ -207,10 +207,8 @@ class StatelessRandomNormalOp : public XlaOpKernel {
RandomUniform(builder, seed, shape, std::nextafter(-1.0f, 0.0f), 1.0);
// Convert uniform distribution to normal distribution by computing
// sqrt(2) * erfinv(x)
- auto erfinv_or_status = ErfInv(uniform);
- OP_REQUIRES_OK(ctx, erfinv_or_status.status());
auto normal = builder->Mul(builder->ConstantR0<float>(std::sqrt(2.0)),
- erfinv_or_status.ValueOrDie());
+ ErfInv(uniform));
ctx->SetOutput(0, normal);
}
@@ -245,9 +243,7 @@ class StatelessTruncatedNormalOp : public XlaOpKernel {
auto uniform =
RandomUniform(b, seed, shape, std::numeric_limits<float>::min(), 1.0);
- auto truncated_normal_or_status = TruncatedNormal(dtype, uniform, b);
- OP_REQUIRES_OK(ctx, truncated_normal_or_status.status());
- ctx->SetOutput(0, truncated_normal_or_status.ValueOrDie());
+ ctx->SetOutput(0, TruncatedNormal(dtype, uniform));
}
private:
diff --git a/tensorflow/compiler/tf2xla/lib/random.cc b/tensorflow/compiler/tf2xla/lib/random.cc
index 4a2516244a..e4f195901e 100644
--- a/tensorflow/compiler/tf2xla/lib/random.cc
+++ b/tensorflow/compiler/tf2xla/lib/random.cc
@@ -23,9 +23,9 @@ limitations under the License.
#include "tensorflow/compiler/xla/status_macros.h"
namespace tensorflow {
-xla::StatusOr<xla::XlaOp> TruncatedNormal(const DataType dtype,
- const xla::XlaOp& uniform,
- xla::XlaBuilder* builder) {
+
+xla::XlaOp TruncatedNormal(const DataType dtype, xla::XlaOp uniform) {
+ xla::XlaBuilder* builder = uniform.builder();
auto normal_cdf = [](double x) {
return (1.0 + std::erf(x / std::sqrt(2.0))) / 2.0;
};
@@ -51,7 +51,7 @@ xla::StatusOr<xla::XlaOp> TruncatedNormal(const DataType dtype,
// probit(p) = sqrt(2) * erfinv(2*p-1)
auto p = builder->Add(alpha_normal_cdf, builder->Mul(z, uniform));
auto erfinv_input = builder->Sub(builder->Mul(p, two), one);
- TF_ASSIGN_OR_RETURN(auto erfinv_or_status, ErfInv(erfinv_input));
- return builder->Mul(sqrt_2, erfinv_or_status);
+ return builder->Mul(sqrt_2, ErfInv(erfinv_input));
}
+
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/lib/random.h b/tensorflow/compiler/tf2xla/lib/random.h
index 18c873dba5..39cbcf9c5e 100644
--- a/tensorflow/compiler/tf2xla/lib/random.h
+++ b/tensorflow/compiler/tf2xla/lib/random.h
@@ -21,15 +21,15 @@ limitations under the License.
#include "tensorflow/core/framework/types.pb.h"
namespace tensorflow {
+
// Builds an array filled with values sampled from a truncated normal
// distribution such that no values are greater than two or less than negative
// two.
//
// The "uniform" parameter must be an array of random numbers distributed in
// (0,1).
-xla::StatusOr<xla::XlaOp> TruncatedNormal(DataType dtype,
- const xla::XlaOp& uniform,
- xla::XlaBuilder* builder);
+xla::XlaOp TruncatedNormal(DataType dtype, xla::XlaOp uniform);
+
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_RANDOM_H_
diff --git a/tensorflow/compiler/tf2xla/lib/util_test.cc b/tensorflow/compiler/tf2xla/lib/util_test.cc
index 265b39402c..5f408f2ed0 100644
--- a/tensorflow/compiler/tf2xla/lib/util_test.cc
+++ b/tensorflow/compiler/tf2xla/lib/util_test.cc
@@ -86,10 +86,9 @@ XLA_TEST_F(UtilTest, Simple3dLookup) {
CreateR3Parameter<float>(BatchedAValsFull(), 0, "a", &builder, &a);
auto index_data = CreateR0Parameter<int>(1, 1, "index", &builder, &index);
- TF_ASSERT_OK_AND_ASSIGN(
- auto l_index,
- DynamicSliceInMinorDims(&builder, a,
- {index, builder.ConstantR0<int32>(0)}, {1, 4}));
+ TF_ASSERT_OK(DynamicSliceInMinorDims(
+ &builder, a, {index, builder.ConstantR0<int32>(0)}, {1, 4})
+ .status());
ComputeAndCompareR3<float>(&builder, {{{3, 6, 0, 1}}, {{24, 61, 82, 48}}},
{a_data.get(), index_data.get()});
@@ -132,9 +131,9 @@ XLA_TEST_F(UtilTest, RowBatchDot) {
auto l_index,
DynamicSliceInMinorDims(&builder, a,
{index, builder.ConstantR0<int32>(0)}, {1, n}));
- TF_ASSERT_OK_AND_ASSIGN(
- auto dot, BatchDot(&builder, l_index, row,
- /*transpose_x=*/false, /*transpose_y=*/true));
+ TF_ASSERT_OK(BatchDot(&builder, l_index, row,
+ /*transpose_x=*/false, /*transpose_y=*/true)
+ .status());
ComputeAndCompareR3<float>(&builder, {{{33}}, {{292}}},
{a_data.get(), row_data.get(), index_data.get()});
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc
index 9c8e56a17e..e646ffe39f 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler.cc
@@ -384,13 +384,14 @@ Status BuildComputation(
const XlaCompiler::Argument& arg = args[resource->arg_num()];
const int core = arg_cores[resource->arg_num()];
DCHECK_LT(resource->arg_num(), arg_cores.size());
- bool modified = resource->value() != resource->initial_value();
+ bool modified = !resource->value().IsIdenticalTo(resource->initial_value());
// TensorArray gradients were modified if their values changed or there are
// any newly created gradients.
for (const auto& grad : resource->tensor_array_gradients()) {
- modified = modified ||
- grad.second->value() != grad.second->initial_value() ||
- arg.tensor_array_gradients.count(grad.first) == 0;
+ modified =
+ modified ||
+ !grad.second->value().IsIdenticalTo(grad.second->initial_value()) ||
+ arg.tensor_array_gradients.count(grad.first) == 0;
}
if (return_updated_values_for_all_resources || modified) {
resource_updates->emplace_back();
diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc
index 93cd340485..31115eea60 100644
--- a/tensorflow/compiler/tf2xla/xla_helpers.cc
+++ b/tensorflow/compiler/tf2xla/xla_helpers.cc
@@ -97,12 +97,48 @@ xla::XlaOp XlaHelpers::MinValue(xla::XlaBuilder* b, DataType data_type) {
return b->ConstantLiteral(xla::Literal::MinValue(type));
}
+xla::XlaOp XlaHelpers::MinFiniteValue(xla::XlaBuilder* b, DataType data_type) {
+ xla::PrimitiveType type;
+ TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
+ switch (type) {
+ case xla::F16:
+ return b->ConstantR0<Eigen::half>(
+ Eigen::NumTraits<Eigen::half>::lowest());
+ case xla::BF16:
+ return b->ConstantR0<bfloat16>(bfloat16::lowest());
+ case xla::F32:
+ return b->ConstantR0<float>(-std::numeric_limits<float>::max());
+ case xla::F64:
+ return b->ConstantR0<double>(-std::numeric_limits<double>::max());
+ default:
+ return b->ConstantLiteral(xla::Literal::MinValue(type));
+ }
+}
+
xla::XlaOp XlaHelpers::MaxValue(xla::XlaBuilder* b, DataType data_type) {
xla::PrimitiveType type;
TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
return b->ConstantLiteral(xla::Literal::MaxValue(type));
}
+xla::XlaOp XlaHelpers::MaxFiniteValue(xla::XlaBuilder* b, DataType data_type) {
+ xla::PrimitiveType type;
+ TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
+ switch (type) {
+ case xla::F16:
+ return b->ConstantR0<Eigen::half>(
+ Eigen::NumTraits<Eigen::half>::highest());
+ case xla::BF16:
+ return b->ConstantR0<bfloat16>(bfloat16::highest());
+ case xla::F32:
+ return b->ConstantR0<float>(std::numeric_limits<float>::max());
+ case xla::F64:
+ return b->ConstantR0<double>(std::numeric_limits<double>::max());
+ default:
+ return b->ConstantLiteral(xla::Literal::MaxValue(type));
+ }
+}
+
xla::XlaOp XlaHelpers::Zero(xla::XlaBuilder* b, DataType data_type) {
xla::PrimitiveType type;
TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
@@ -267,6 +303,8 @@ Status XlaHelpers::OneHot(xla::XlaBuilder* builder, int64 depth, int axis,
}
DataType XlaHelpers::SumAccumulationType(const DataType& dtype) {
+ // Upcast 16 bit sum reductions to 32 bit to reduce the precision loss from
+ // repeated floating point additions.
if (dtype == DT_BFLOAT16 || dtype == DT_HALF) {
return DT_FLOAT;
}
diff --git a/tensorflow/compiler/tf2xla/xla_helpers.h b/tensorflow/compiler/tf2xla/xla_helpers.h
index c3fdc5252e..c320016998 100644
--- a/tensorflow/compiler/tf2xla/xla_helpers.h
+++ b/tensorflow/compiler/tf2xla/xla_helpers.h
@@ -29,13 +29,21 @@ namespace tensorflow {
class XlaHelpers {
public:
// Returns a handle representing the minimum value of a scalar
- // element of data_type.
+ // element of data_type. -inf for floating-point types.
static xla::XlaOp MinValue(xla::XlaBuilder* b, DataType data_type);
- // Returns a handle representing the maximum value of a scalar
+ // Returns a handle representing the minimum finite value of a scalar
// element of data_type.
+ static xla::XlaOp MinFiniteValue(xla::XlaBuilder* b, DataType data_type);
+
+ // Returns a handle representing the maximum value of a scalar
+ // element of data_type. inf for floating point types.
static xla::XlaOp MaxValue(xla::XlaBuilder* b, DataType data_type);
+ // Returns a handle representing the maximum finite value of a scalar
+ // element of data_type.
+ static xla::XlaOp MaxFiniteValue(xla::XlaBuilder* b, DataType data_type);
+
// Returns a handle representing the zero value of a scalar
// element of data_type.
static xla::XlaOp Zero(xla::XlaBuilder* b, DataType data_type);
diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc
index ee6da6a67a..46785bc1f0 100644
--- a/tensorflow/compiler/tf2xla/xla_op_registry.cc
+++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc
@@ -240,6 +240,7 @@ void XlaOpRegistry::RegisterCompilationKernels() {
// a) the types supported by the backend, and
// b) the types allowed by the OpDef, and
// c) the type constraints.
+ bool unsatisfiable_type_constraint = false;
for (const string& type_attr : type_attrs) {
KernelDef::AttrConstraint* attr_constraint = kdef->add_constraint();
attr_constraint->set_name(type_attr);
@@ -276,7 +277,14 @@ void XlaOpRegistry::RegisterCompilationKernels() {
if (op_registration->allow_resource_types) {
allowed_values->add_type(DT_RESOURCE);
}
+ // Don't build KernelDefs that have unsatisfiable type constraints.
+ if (allowed_values->type().empty()) {
+ unsatisfiable_type_constraint = true;
+ break;
+ }
}
+ if (unsatisfiable_type_constraint) continue;
+
if (backend.second.op_filter != nullptr &&
!backend.second.op_filter(kdef.get())) {
continue;
diff --git a/tensorflow/compiler/tf2xla/xla_op_registry_test.cc b/tensorflow/compiler/tf2xla/xla_op_registry_test.cc
index 266cbc4395..7b3b15b1af 100644
--- a/tensorflow/compiler/tf2xla/xla_op_registry_test.cc
+++ b/tensorflow/compiler/tf2xla/xla_op_registry_test.cc
@@ -82,5 +82,38 @@ TEST(XlaOpRegistryTest, XlaOpRegistrationWithOverride) {
}
}
+// A dummy generic OpKernel for all backends.
+class DummyInfeasibleTypeConstraintOp : public XlaOpKernel {
+ public:
+ explicit DummyInfeasibleTypeConstraintOp(OpKernelConstruction* ctx)
+ : XlaOpKernel(ctx) {}
+ void Compile(XlaOpKernelContext* ctx) override {
+ LOG(FATAL) << "unreachable";
+ }
+};
+
+REGISTER_OP("DummyInfeasibleTypeConstraintOp")
+ .Attr("T: {float, string}")
+ .Input("input: T")
+ .Output("output: T")
+ .Doc(R"doc(
+A dummy Op.
+
+input: dummy input.
+output: dummy output.
+)doc");
+REGISTER_XLA_OP(
+ Name("DummyInfeasibleTypeConstraintOp").TypeConstraint("T", DT_STRING),
+ DummyInfeasibleTypeConstraintOp);
+
+TEST(XlaOpRegistryTest, OpWithInfeasibleTypeConstraintIsNotRegistered) {
+ XlaOpRegistry::RegisterCompilationKernels();
+ auto registered_kernels = GetAllRegisteredKernels().kernel();
+ for (const auto& kernels : registered_kernels) {
+ // The operator should not be registered.
+ EXPECT_NE(kernels.op(), "DummyInfeasibleTypeConstraintOp");
+ }
+}
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD
index 4525197146..95bd725850 100644
--- a/tensorflow/compiler/xla/BUILD
+++ b/tensorflow/compiler/xla/BUILD
@@ -175,6 +175,7 @@ cc_library(
hdrs = [
"iterator_util.h",
"map_util.h",
+ "overflow_util.h",
"ptr_util.h",
"util.h",
],
@@ -250,7 +251,7 @@ cc_library(
":types",
":util",
":xla_data_proto",
- "//tensorflow/core:framework_internal",
+ "//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:regexp_internal",
diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.cc b/tensorflow/compiler/xla/client/lib/arithmetic.cc
index 8e875bf352..0d7758eef9 100644
--- a/tensorflow/compiler/xla/client/lib/arithmetic.cc
+++ b/tensorflow/compiler/xla/client/lib/arithmetic.cc
@@ -111,14 +111,17 @@ XlaComputation CreateScalarOrComputation(XlaBuilder* builder) {
});
}
-StatusOr<XlaOp> Any(const XlaOp& predicates, XlaBuilder* builder) {
- auto f = builder->ConstantR0<bool>(false);
- XlaComputation logical_or = CreateScalarOrComputation(builder);
- TF_ASSIGN_OR_RETURN(const Shape& predicates_shape,
- builder->GetShape(predicates));
- std::vector<int64> all_dimensions(ShapeUtil::Rank(predicates_shape));
- std::iota(all_dimensions.begin(), all_dimensions.end(), 0);
- return builder->Reduce(predicates, f, logical_or, all_dimensions);
+XlaOp Any(XlaOp predicates) {
+ XlaBuilder* builder = predicates.builder();
+ return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ auto f = builder->ConstantR0<bool>(false);
+ XlaComputation logical_or = CreateScalarOrComputation(builder);
+ TF_ASSIGN_OR_RETURN(const Shape& predicates_shape,
+ builder->GetShape(predicates));
+ std::vector<int64> all_dimensions(ShapeUtil::Rank(predicates_shape));
+ std::iota(all_dimensions.begin(), all_dimensions.end(), 0);
+ return builder->Reduce(predicates, f, logical_or, all_dimensions);
+ });
}
namespace {
@@ -164,7 +167,7 @@ std::array<float, 6> kErfUCoefficient = {
// Evaluate the polynomial given coefficients and `x`.
// N.B. Coefficients should be supplied in decreasing order.
-XlaOp EvaluatePolynomial(const XlaOp& x,
+XlaOp EvaluatePolynomial(XlaOp x,
tensorflow::gtl::ArraySlice<float> coefficients,
PrimitiveType data_type) {
XlaBuilder* b = x.builder();
@@ -176,7 +179,7 @@ XlaOp EvaluatePolynomial(const XlaOp& x,
}
// Compute an approximation of the error function complement (1 - erf(x)).
-XlaOp Erfc(const XlaOp& x, PrimitiveType data_type) {
+XlaOp Erfc(XlaOp x, PrimitiveType data_type) {
XlaBuilder* b = x.builder();
XlaOp zero = FloatLiteral(b, data_type, 0.0);
XlaOp two = FloatLiteral(b, data_type, 2.0);
@@ -197,7 +200,7 @@ XlaOp Erfc(const XlaOp& x, PrimitiveType data_type) {
}
// Compute a polynomial approximation of the error function.
-XlaOp Erf(const XlaOp& x, PrimitiveType data_type) {
+XlaOp Erf(XlaOp x, PrimitiveType data_type) {
XlaBuilder* b = x.builder();
XlaOp z = b->Mul(x, x);
XlaOp pt = EvaluatePolynomial(z, kErfTCoefficient, data_type);
@@ -217,38 +220,40 @@ XlaOp Erf(const XlaOp& x, PrimitiveType data_type) {
// p = sum_{i=1}^n gq[i]*w^i
// }
// return p*x
-StatusOr<XlaOp> ErfInv(const XlaOp& x) {
+XlaOp ErfInv(XlaOp x) {
XlaBuilder* b = x.builder();
- TF_ASSIGN_OR_RETURN(Shape shape, b->GetShape(x));
- constexpr int kDegree = 9;
- constexpr std::array<float, 9> w_less_than_5_constants = {
- 2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f,
- -4.39150654e-06f, 0.00021858087f, -0.00125372503f,
- -0.00417768164f, 0.246640727f, 1.50140941f};
- constexpr std::array<float, 9> w_greater_than_5_constants = {
- -0.000200214257f, 0.000100950558f, 0.00134934322f,
- -0.00367342844f, 0.00573950773f, -0.0076224613f,
- 0.00943887047f, 1.00167406f, 2.83297682f};
-
- auto one = b->ConstantR0<float>(1.0);
- auto w = b->Neg(b->Log(b->Mul(b->Sub(one, x), b->Add(one, x))));
-
- auto lt = b->Lt(w, b->ConstantR0<float>(5.0));
- auto coefficient = [&](int i) {
- return b->Select(
- lt,
- b->Broadcast(b->ConstantR0<float>(w_less_than_5_constants[i]),
- AsInt64Slice(shape.dimensions())),
- b->Broadcast(b->ConstantR0<float>(w_greater_than_5_constants[i]),
- AsInt64Slice(shape.dimensions())));
- };
- w = b->Select(lt, b->Sub(w, b->ConstantR0<float>(2.5f)),
- b->Sub(b->SqrtF32(w), b->ConstantR0<float>(3.0f)));
- auto p = coefficient(0);
- for (int i = 1; i < kDegree; ++i) {
- p = b->Add(coefficient(i), b->Mul(p, w));
- }
- return b->Mul(p, x);
+ return b->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ TF_ASSIGN_OR_RETURN(Shape shape, b->GetShape(x));
+ constexpr int kDegree = 9;
+ constexpr std::array<float, 9> w_less_than_5_constants = {
+ 2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f,
+ -4.39150654e-06f, 0.00021858087f, -0.00125372503f,
+ -0.00417768164f, 0.246640727f, 1.50140941f};
+ constexpr std::array<float, 9> w_greater_than_5_constants = {
+ -0.000200214257f, 0.000100950558f, 0.00134934322f,
+ -0.00367342844f, 0.00573950773f, -0.0076224613f,
+ 0.00943887047f, 1.00167406f, 2.83297682f};
+
+ auto one = b->ConstantR0<float>(1.0);
+ auto w = b->Neg(b->Log(b->Mul(b->Sub(one, x), b->Add(one, x))));
+
+ auto lt = b->Lt(w, b->ConstantR0<float>(5.0));
+ auto coefficient = [&](int i) {
+ return b->Select(
+ lt,
+ b->Broadcast(b->ConstantR0<float>(w_less_than_5_constants[i]),
+ AsInt64Slice(shape.dimensions())),
+ b->Broadcast(b->ConstantR0<float>(w_greater_than_5_constants[i]),
+ AsInt64Slice(shape.dimensions())));
+ };
+ w = b->Select(lt, b->Sub(w, b->ConstantR0<float>(2.5f)),
+ b->Sub(b->SqrtF32(w), b->ConstantR0<float>(3.0f)));
+ auto p = coefficient(0);
+ for (int i = 1; i < kDegree; ++i) {
+ p = b->Add(coefficient(i), b->Mul(p, w));
+ }
+ return b->Mul(p, x);
+ });
}
} // namespace xla
diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.h b/tensorflow/compiler/xla/client/lib/arithmetic.h
index 33a8254274..d0e04bbb5e 100644
--- a/tensorflow/compiler/xla/client/lib/arithmetic.h
+++ b/tensorflow/compiler/xla/client/lib/arithmetic.h
@@ -53,22 +53,22 @@ XlaComputation CreateScalarOrComputation(XlaBuilder* builder);
// Returns whether any predicate in "predicates" is set.
//
// Note: if predicates is zero-sized, Any() vacuously returns false.
-StatusOr<XlaOp> Any(const XlaOp& predicates, XlaBuilder* builder);
+XlaOp Any(XlaOp predicates);
// Evaluate the polynomial given coefficients and `x`.
// N.B. Coefficients should be supplied in decreasing order.
-XlaOp EvaluatePolynomial(const XlaOp& x,
+XlaOp EvaluatePolynomial(XlaOp x,
tensorflow::gtl::ArraySlice<float> coefficients,
PrimitiveType data_type);
// Compute an approximation of the error function complement (1 - erf(x)).
-XlaOp Erfc(const XlaOp& x, PrimitiveType data_type);
+XlaOp Erfc(XlaOp x, PrimitiveType data_type);
// Compute an approximation of the error function.
-XlaOp Erf(const XlaOp& x, PrimitiveType data_type);
+XlaOp Erf(XlaOp x, PrimitiveType data_type);
// Compute an approximation of the inverse of the error function.
-StatusOr<XlaOp> ErfInv(const XlaOp& x);
+XlaOp ErfInv(XlaOp x);
} // namespace xla
diff --git a/tensorflow/compiler/xla/client/xla_client/BUILD b/tensorflow/compiler/xla/client/xla_client/BUILD
index 507a2dc5f0..b0f41ac1d3 100644
--- a/tensorflow/compiler/xla/client/xla_client/BUILD
+++ b/tensorflow/compiler/xla/client/xla_client/BUILD
@@ -52,6 +52,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:padding",
+ "//tensorflow/compiler/xla/client:sharding_builder",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_proto",
"//tensorflow/compiler/xla/service:shape_inference",
diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
index 256667cbe0..8515d120da 100644
--- a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include <string>
#include <utility>
+#include "tensorflow/compiler/xla/client/sharding_builder.h"
#include "tensorflow/compiler/xla/execution_options_util.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
@@ -59,6 +60,54 @@ bool CanBeRoot(HloOpcode opcode) {
} // namespace
+XlaOp operator-(const XlaOp& x) { return x.builder()->Neg(x); }
+XlaOp operator+(const XlaOp& x, const XlaOp& y) {
+ return x.builder()->Add(x, y);
+}
+XlaOp operator-(const XlaOp& x, const XlaOp& y) {
+ return x.builder()->Sub(x, y);
+}
+XlaOp operator*(const XlaOp& x, const XlaOp& y) {
+ return x.builder()->Mul(x, y);
+}
+XlaOp operator/(const XlaOp& x, const XlaOp& y) {
+ return x.builder()->Div(x, y);
+}
+XlaOp operator%(const XlaOp& x, const XlaOp& y) {
+ return x.builder()->Rem(x, y);
+}
+
+XlaOp operator~(const XlaOp& x) { return x.builder()->Not(x); }
+XlaOp operator&(const XlaOp& x, const XlaOp& y) {
+ return x.builder()->And(x, y);
+}
+XlaOp operator|(const XlaOp& x, const XlaOp& y) {
+ return x.builder()->Or(x, y);
+}
+XlaOp operator^(const XlaOp& x, const XlaOp& y) {
+ return x.builder()->Xor(x, y);
+}
+XlaOp operator<<(const XlaOp& x, const XlaOp& y) {
+ return x.builder()->ShiftLeft(x, y);
+}
+
+XlaOp operator>>(const XlaOp& x, const XlaOp& y) {
+ XlaBuilder* builder = x.builder();
+ return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x));
+ if (!ShapeUtil::ElementIsIntegral(shape)) {
+ return InvalidArgument(
+ "Argument to >> operator does not have an integral type (%s).",
+ ShapeUtil::HumanString(shape).c_str());
+ }
+ if (ShapeUtil::ElementIsSigned(shape)) {
+ return builder->ShiftRightArithmetic(x, y);
+ } else {
+ return builder->ShiftRightLogical(x, y);
+ }
+ });
+}
+
StatusOr<Shape> XlaBuilder::GetShape(const XlaOp& op) const {
TF_RETURN_IF_ERROR(first_error_);
@@ -81,7 +130,7 @@ XlaBuilder::XlaBuilder(const string& computation_name)
XlaBuilder::~XlaBuilder() {}
-void XlaBuilder::NoteError(const Status& error) {
+XlaOp XlaBuilder::ReportError(const Status& error) {
CHECK(!error.ok());
if (die_immediately_on_error_) {
LOG(FATAL) << "error building computation: " << error;
@@ -91,19 +140,22 @@ void XlaBuilder::NoteError(const Status& error) {
first_error_ = error;
first_error_backtrace_.CreateCurrent(/*skip_count=*/1);
}
+ return XlaOp(this);
}
-XlaOp XlaBuilder::NoteErrorOrReturn(
- const std::function<StatusOr<XlaOp>()>& op_creator) {
+XlaOp XlaBuilder::ReportErrorOrReturn(const StatusOr<XlaOp>& op) {
if (!first_error_.ok()) {
return XlaOp(this);
}
- auto op = op_creator();
if (!op.ok()) {
- NoteError(op.status());
- return XlaOp(this);
+ return ReportError(op.status());
}
- return op.ConsumeValueOrDie();
+ return op.ValueOrDie();
+}
+
+XlaOp XlaBuilder::ReportErrorOrReturn(
+ const std::function<StatusOr<XlaOp>()>& op_creator) {
+ return ReportErrorOrReturn(op_creator());
}
StatusOr<ProgramShape> XlaBuilder::GetProgramShape(int64* root_id) const {
@@ -207,7 +259,7 @@ XlaComputation XlaBuilder::BuildAndNoteError() {
DCHECK(parent_builder_ != nullptr);
auto build_status = Build();
if (!build_status.ok()) {
- parent_builder_->NoteError(
+ parent_builder_->ReportError(
AddStatus(build_status.status(),
tensorflow::strings::StrCat("error from: ", name_)));
return {};
@@ -315,7 +367,7 @@ StatusOr<XlaOp> XlaBuilder::AddBroadcastSequence(const Shape& output_shape,
}
XlaOp XlaBuilder::UnaryOp(HloOpcode unop, const XlaOp& operand) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
@@ -327,7 +379,7 @@ XlaOp XlaBuilder::UnaryOp(HloOpcode unop, const XlaOp& operand) {
XlaOp XlaBuilder::BinaryOp(
HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs));
TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs));
@@ -383,7 +435,7 @@ XlaOp XlaBuilder::BinaryOp(
XlaOp XlaBuilder::TernaryOp(HloOpcode triop, const XlaOp& lhs, const XlaOp& rhs,
const XlaOp& ehs) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs));
TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs));
@@ -430,7 +482,7 @@ XlaOp XlaBuilder::Mul(const XlaOp& lhs, const XlaOp& rhs,
}
XlaOp XlaBuilder::ConstantLiteral(const LiteralSlice& literal) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
*instr.mutable_shape() = literal.shape();
*instr.mutable_literal() = literal.ToProto();
@@ -440,7 +492,7 @@ XlaOp XlaBuilder::ConstantLiteral(const LiteralSlice& literal) {
XlaOp XlaBuilder::Call(const XlaComputation& computation,
tensorflow::gtl::ArraySlice<XlaOp> operands) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
std::vector<const Shape*> operand_shape_ptrs;
TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands));
@@ -461,7 +513,7 @@ XlaOp XlaBuilder::Call(const XlaComputation& computation,
XlaOp XlaBuilder::Parameter(int64 parameter_number, const Shape& shape,
const string& name) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
if (!parameter_numbers_.insert(parameter_number).second) {
return InvalidArgument("parameter %lld already registered",
@@ -476,7 +528,7 @@ XlaOp XlaBuilder::Parameter(int64 parameter_number, const Shape& shape,
XlaOp XlaBuilder::Broadcast(
const XlaOp& operand, tensorflow::gtl::ArraySlice<int64> broadcast_sizes) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
TF_ASSIGN_OR_RETURN(
const Shape& shape,
@@ -510,7 +562,7 @@ XlaOp XlaBuilder::Slice(const XlaOp& operand,
tensorflow::gtl::ArraySlice<int64> start_indices,
tensorflow::gtl::ArraySlice<int64> limit_indices,
tensorflow::gtl::ArraySlice<int64> strides) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
TF_ASSIGN_OR_RETURN(
@@ -530,7 +582,7 @@ XlaOp XlaBuilder::Slice(const XlaOp& operand,
XlaOp XlaBuilder::SliceInDim(const XlaOp& operand, int64 start_index,
int64 limit_index, int64 stride, int64 dimno) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand));
std::vector<int64> starts(ShapeUtil::Rank(shape), 0);
std::vector<int64> limits(shape.dimensions().begin(),
@@ -545,7 +597,7 @@ XlaOp XlaBuilder::SliceInDim(const XlaOp& operand, int64 start_index,
XlaOp XlaBuilder::DynamicSlice(const XlaOp& operand, const XlaOp& start_indices,
tensorflow::gtl::ArraySlice<int64> slice_sizes) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
@@ -566,7 +618,7 @@ XlaOp XlaBuilder::DynamicSlice(const XlaOp& operand, const XlaOp& start_indices,
XlaOp XlaBuilder::DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update,
const XlaOp& start_indices) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
@@ -584,7 +636,7 @@ XlaOp XlaBuilder::DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update,
XlaOp XlaBuilder::ConcatInDim(tensorflow::gtl::ArraySlice<XlaOp> operands,
int64 dimension) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
std::vector<const Shape*> operand_shape_ptrs;
@@ -603,7 +655,7 @@ XlaOp XlaBuilder::ConcatInDim(tensorflow::gtl::ArraySlice<XlaOp> operands,
XlaOp XlaBuilder::Pad(const XlaOp& operand, const XlaOp& padding_value,
const PaddingConfig& padding_config) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
@@ -624,7 +676,7 @@ XlaOp XlaBuilder::Pad(const XlaOp& operand, const XlaOp& padding_value,
XlaOp XlaBuilder::Reshape(const XlaOp& operand,
tensorflow::gtl::ArraySlice<int64> dimensions,
tensorflow::gtl::ArraySlice<int64> new_sizes) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
TF_ASSIGN_OR_RETURN(const Shape& shape,
ShapeInference::InferReshapeShape(
@@ -638,7 +690,7 @@ XlaOp XlaBuilder::Reshape(const XlaOp& operand,
XlaOp XlaBuilder::Reshape(const XlaOp& operand,
tensorflow::gtl::ArraySlice<int64> new_sizes) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(auto shape, GetShape(operand));
std::vector<int64> dimensions(shape.dimensions_size());
std::iota(dimensions.begin(), dimensions.end(), 0);
@@ -648,7 +700,7 @@ XlaOp XlaBuilder::Reshape(const XlaOp& operand,
XlaOp XlaBuilder::Collapse(const XlaOp& operand,
tensorflow::gtl::ArraySlice<int64> dimensions) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
if (dimensions.size() <= 1) {
// Not collapsing anything, trivially we can return the operand versus
// enqueueing a trivial reshape.
@@ -690,7 +742,7 @@ XlaOp XlaBuilder::Collapse(const XlaOp& operand,
}
void XlaBuilder::Trace(const string& tag, const XlaOp& operand) {
- NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
*instr.mutable_shape() = ShapeUtil::MakeNil();
*instr.mutable_literal() = Literal::CreateR1U8(tag)->ToProto();
@@ -704,7 +756,7 @@ XlaOp XlaBuilder::Select(const XlaOp& pred, const XlaOp& on_true,
}
XlaOp XlaBuilder::Tuple(tensorflow::gtl::ArraySlice<XlaOp> elements) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
std::vector<const Shape*> operand_shape_ptrs;
TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(elements));
@@ -718,7 +770,7 @@ XlaOp XlaBuilder::Tuple(tensorflow::gtl::ArraySlice<XlaOp> elements) {
}
XlaOp XlaBuilder::GetTupleElement(const XlaOp& tuple_data, int64 index) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& tuple_shape, GetShape(tuple_data));
if (!ShapeUtil::IsTuple(tuple_shape)) {
@@ -767,7 +819,7 @@ XlaOp XlaBuilder::Lt(const XlaOp& lhs, const XlaOp& rhs,
}
XlaOp XlaBuilder::Dot(const XlaOp& lhs, const XlaOp& rhs) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs));
DotDimensionNumbers dimension_numbers;
@@ -780,7 +832,7 @@ XlaOp XlaBuilder::Dot(const XlaOp& lhs, const XlaOp& rhs) {
XlaOp XlaBuilder::DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
const DotDimensionNumbers& dimension_numbers) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs));
TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs));
@@ -859,7 +911,7 @@ XlaOp XlaBuilder::ConvWithGeneralDimensions(
const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
const ConvolutionDimensionNumbers& dimension_numbers) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs));
TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs));
@@ -905,7 +957,7 @@ XlaOp XlaBuilder::ConvGeneralDilated(
tensorflow::gtl::ArraySlice<int64> lhs_dilation,
tensorflow::gtl::ArraySlice<int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs));
TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs));
@@ -992,7 +1044,7 @@ StatusOr<Window> XlaBuilder::MakeWindow(
XlaOp XlaBuilder::Fft(const XlaOp& operand, const FftType fft_type,
const tensorflow::gtl::ArraySlice<int64> fft_length) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
TF_ASSIGN_OR_RETURN(
@@ -1009,23 +1061,69 @@ XlaOp XlaBuilder::Fft(const XlaOp& operand, const FftType fft_type,
}
XlaOp XlaBuilder::Infeed(const Shape& shape, const string& config) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
if (!LayoutUtil::HasLayout(shape)) {
return InvalidArgument("Given shape to Infeed must have a layout");
}
- *instr.mutable_shape() = shape;
+ const Shape infeed_instruction_shape =
+ ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()});
+ *instr.mutable_shape() = infeed_instruction_shape;
instr.set_infeed_config(config);
- return AddInstruction(std::move(instr), HloOpcode::kInfeed);
+
+ if (ShapeUtil::IsArray(shape) && sharding() &&
+ sharding()->type() == OpSharding::Type::OpSharding_Type_OTHER) {
+ // TODO(b/110793772): Support tiled array-shaped infeeds.
+ return InvalidArgument(
+ "Tiled sharding is not yet supported for array-shaped infeeds");
+ }
+
+ if (sharding() &&
+ sharding()->type() == OpSharding::Type::OpSharding_Type_REPLICATED) {
+ return InvalidArgument(
+ "Replicated sharding is not yet supported for infeeds");
+ }
+
+ // The sharding is set by the client according to the data tuple shape.
+ // However, the shape of the infeed instruction is a tuple containing the
+ // data and a token. For tuple sharding type, the sharding must be changed
+ // to accommodate the token.
+ XlaOp infeed;
+ if (sharding() &&
+ sharding()->type() == OpSharding::Type::OpSharding_Type_TUPLE) {
+ // TODO(b/80000000): Remove this when clients have been updated to handle
+ // tokens.
+ OpSharding infeed_instruction_sharding = *sharding();
+ // Arbitrarily assign the token to device 0.
+ *infeed_instruction_sharding.add_tuple_shardings() =
+ sharding_builder::AssignDevice(0);
+ XlaScopedShardingAssignment scoped_sharding(this,
+ infeed_instruction_sharding);
+ TF_ASSIGN_OR_RETURN(infeed,
+ AddInstruction(std::move(instr), HloOpcode::kInfeed));
+ } else {
+ TF_ASSIGN_OR_RETURN(infeed,
+ AddInstruction(std::move(instr), HloOpcode::kInfeed));
+ }
+
+ // The infeed instruction produces a tuple of the infeed data and a token
+ // type. Return XLA op containing the data.
+ // TODO(b/80000000): Remove this when clients have been updated to handle
+ // tokens.
+ HloInstructionProto infeed_data;
+ *infeed_data.mutable_shape() = shape;
+ infeed_data.set_tuple_index(0);
+ return AddInstruction(std::move(infeed_data), HloOpcode::kGetTupleElement,
+ {infeed});
});
}
void XlaBuilder::Outfeed(const XlaOp& operand, const Shape& shape_with_layout,
const string& outfeed_config) {
- NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
- *instr.mutable_shape() = ShapeUtil::MakeNil();
+ *instr.mutable_shape() = ShapeUtil::MakeTokenShape();
// Check and set outfeed shape.
if (!LayoutUtil::HasLayout(shape_with_layout)) {
@@ -1042,14 +1140,33 @@ void XlaBuilder::Outfeed(const XlaOp& operand, const Shape& shape_with_layout,
instr.set_outfeed_config(outfeed_config);
- return AddInstruction(std::move(instr), HloOpcode::kOutfeed, {operand});
+ TF_RETURN_IF_ERROR(
+ AddInstruction(std::move(instr), HloOpcode::kOutfeed, {operand})
+ .status());
+
+ // The outfeed instruction produces a token. However, existing users expect
+ // a nil shape (empty tuple). This should only be relevant if the outfeed is
+ // the root of a computation.
+ // TODO(b/80000000): Remove this when clients have been updated to handle
+ // tokens.
+ HloInstructionProto tuple_instr;
+ *tuple_instr.mutable_shape() = ShapeUtil::MakeNil();
+
+ // The dummy tuple should have no sharding.
+ {
+ XlaScopedShardingAssignment scoped_sharding(this, OpSharding());
+ TF_ASSIGN_OR_RETURN(
+ XlaOp empty_tuple,
+ AddInstruction(std::move(tuple_instr), HloOpcode::kTuple, {}));
+ return empty_tuple;
+ }
});
}
XlaOp XlaBuilder::CustomCall(const string& call_target_name,
tensorflow::gtl::ArraySlice<XlaOp> operands,
const Shape& shape) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
if (tensorflow::str_util::StartsWith(call_target_name, "$")) {
return InvalidArgument(
@@ -1066,7 +1183,7 @@ XlaOp XlaBuilder::CustomCall(const string& call_target_name,
XlaOp XlaBuilder::HostCompute(tensorflow::gtl::ArraySlice<XlaOp> operands,
const string& channel_name,
int64 cost_estimate_ns, const Shape& shape) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
*instr.mutable_shape() = shape;
instr.set_channel_name(channel_name);
@@ -1221,7 +1338,7 @@ XlaOp XlaBuilder::IsFinite(const XlaOp& operand) {
XlaOp XlaBuilder::Transpose(const XlaOp& operand,
tensorflow::gtl::ArraySlice<int64> permutation) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
TF_ASSIGN_OR_RETURN(
@@ -1236,7 +1353,7 @@ XlaOp XlaBuilder::Transpose(const XlaOp& operand,
XlaOp XlaBuilder::Rev(const XlaOp& operand,
tensorflow::gtl::ArraySlice<int64> dimensions) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
TF_ASSIGN_OR_RETURN(
@@ -1265,7 +1382,7 @@ XlaOp XlaBuilder::Pow(const XlaOp& lhs, const XlaOp& rhs,
XlaOp XlaBuilder::ConvertElementType(const XlaOp& operand,
PrimitiveType new_element_type) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
TF_ASSIGN_OR_RETURN(
@@ -1277,7 +1394,7 @@ XlaOp XlaBuilder::ConvertElementType(const XlaOp& operand,
XlaOp XlaBuilder::BitcastConvertType(const XlaOp& operand,
PrimitiveType new_element_type) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
TF_ASSIGN_OR_RETURN(
@@ -1311,13 +1428,12 @@ XlaOp XlaBuilder::Map(tensorflow::gtl::ArraySlice<XlaOp> operands,
const XlaComputation& computation,
tensorflow::gtl::ArraySlice<int64> dimensions,
tensorflow::gtl::ArraySlice<XlaOp> static_operands) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
if (!static_operands.empty()) {
return Unimplemented("static_operands is not supported in Map");
}
HloInstructionProto instr;
-
std::vector<const Shape*> operand_shape_ptrs;
TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands));
c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
@@ -1329,16 +1445,32 @@ XlaOp XlaBuilder::Map(tensorflow::gtl::ArraySlice<XlaOp> operands,
ShapeInference::InferMapShape(operand_shape_ptrs, called_program_shape,
dimensions));
+ const Shape& output_shape = instr.shape();
+ const int64 output_rank = ShapeUtil::Rank(output_shape);
AddCalledComputation(computation, &instr);
+ std::vector<XlaOp> new_operands(operands.begin(), operands.end());
+ for (XlaOp& new_operand : new_operands) {
+ TF_ASSIGN_OR_RETURN(Shape shape, GetShape(new_operand));
+ const int64 rank = ShapeUtil::Rank(shape);
+ if (rank != output_rank) {
+ TF_ASSIGN_OR_RETURN(new_operand,
+ InDimBroadcast(output_shape, new_operand, {}));
+ TF_ASSIGN_OR_RETURN(shape, GetShape(new_operand));
+ }
+ if (!ShapeUtil::SameDimensions(output_shape, shape)) {
+ TF_ASSIGN_OR_RETURN(new_operand,
+ AddBroadcastSequence(output_shape, new_operand));
+ }
+ }
- return AddInstruction(std::move(instr), HloOpcode::kMap, operands);
+ return AddInstruction(std::move(instr), HloOpcode::kMap, new_operands);
});
}
XlaOp XlaBuilder::RngOp(RandomDistribution distribution,
tensorflow::gtl::ArraySlice<XlaOp> parameters,
const Shape& shape) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
// Check the number of parameters per RNG distribution.
@@ -1376,7 +1508,7 @@ XlaOp XlaBuilder::RngUniform(const XlaOp& a, const XlaOp& b,
XlaOp XlaBuilder::While(const XlaComputation& condition,
const XlaComputation& body, const XlaOp& init) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
// Infer shape.
@@ -1398,7 +1530,7 @@ XlaOp XlaBuilder::While(const XlaComputation& condition,
XlaOp XlaBuilder::Gather(const XlaOp& input, const XlaOp& gather_indices,
const GatherDimensionNumbers& dimension_numbers,
tensorflow::gtl::ArraySlice<int64> window_bounds) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& input_shape, GetShape(input));
@@ -1423,7 +1555,7 @@ XlaOp XlaBuilder::Conditional(const XlaOp& predicate, const XlaOp& true_operand,
const XlaComputation& true_computation,
const XlaOp& false_operand,
const XlaComputation& false_computation) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& predicate_shape, GetShape(predicate));
@@ -1455,7 +1587,7 @@ XlaOp XlaBuilder::Reduce(
const XlaOp& operand, const XlaOp& init_value,
const XlaComputation& computation,
tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
@@ -1480,7 +1612,7 @@ XlaOp XlaBuilder::Reduce(
XlaOp XlaBuilder::ReduceAll(const XlaOp& operand, const XlaOp& init_value,
const XlaComputation& computation) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
std::vector<int64> all_dimnos(ShapeUtil::Rank(operand_shape));
std::iota(all_dimnos.begin(), all_dimnos.end(), 0);
@@ -1493,7 +1625,7 @@ XlaOp XlaBuilder::ReduceWindow(
const XlaComputation& computation,
tensorflow::gtl::ArraySlice<int64> window_dimensions,
tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
@@ -1516,7 +1648,7 @@ XlaOp XlaBuilder::ReduceWindowWithGeneralPadding(
tensorflow::gtl::ArraySlice<int64> window_dimensions,
tensorflow::gtl::ArraySlice<int64> window_strides,
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
@@ -1540,7 +1672,7 @@ XlaOp XlaBuilder::ReduceWindowWithGeneralPadding(
XlaOp XlaBuilder::BatchNormTraining(const XlaOp& operand, const XlaOp& scale,
const XlaOp& offset, float epsilon,
int64 feature_index) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
@@ -1563,7 +1695,7 @@ XlaOp XlaBuilder::BatchNormInference(const XlaOp& operand, const XlaOp& scale,
const XlaOp& offset, const XlaOp& mean,
const XlaOp& variance, float epsilon,
int64 feature_index) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
@@ -1588,7 +1720,7 @@ XlaOp XlaBuilder::BatchNormGrad(const XlaOp& operand, const XlaOp& scale,
const XlaOp& batch_mean, const XlaOp& batch_var,
const XlaOp& grad_output, float epsilon,
int64 feature_index) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
@@ -1612,7 +1744,7 @@ XlaOp XlaBuilder::BatchNormGrad(const XlaOp& operand, const XlaOp& scale,
XlaOp XlaBuilder::CrossReplicaSum(
const XlaOp& operand,
tensorflow::gtl::ArraySlice<int64> replica_group_ids) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand));
const Shape& scalar_shape = ShapeUtil::MakeShape(shape.element_type(), {});
auto b = CreateSubBuilder("sum");
@@ -1628,7 +1760,7 @@ XlaOp XlaBuilder::CrossReplicaSum(
const XlaOp& operand, const XlaComputation& computation,
tensorflow::gtl::ArraySlice<int64> replica_group_ids,
const tensorflow::gtl::optional<ChannelHandle>& channel_id) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
if (channel_id.has_value()) {
return Unimplemented("channel_id is not supported in AllReduce");
}
@@ -1655,7 +1787,7 @@ XlaOp XlaBuilder::SelectAndScatter(
tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
const XlaOp& source, const XlaOp& init_value,
const XlaComputation& scatter) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
return SelectAndScatterWithGeneralPadding(
operand, select, window_dimensions, window_strides,
@@ -1672,7 +1804,7 @@ XlaOp XlaBuilder::SelectAndScatterWithGeneralPadding(
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
const XlaOp& source, const XlaOp& init_value,
const XlaComputation& scatter) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
@@ -1700,7 +1832,7 @@ XlaOp XlaBuilder::SelectAndScatterWithGeneralPadding(
XlaOp XlaBuilder::ReducePrecision(const XlaOp& operand, const int exponent_bits,
const int mantissa_bits) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
@@ -1714,7 +1846,7 @@ XlaOp XlaBuilder::ReducePrecision(const XlaOp& operand, const int exponent_bits,
}
void XlaBuilder::Send(const XlaOp& operand, const ChannelHandle& handle) {
- NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
// Send instruction produces a tuple of {aliased operand, U32 context}.
@@ -1735,7 +1867,7 @@ void XlaBuilder::Send(const XlaOp& operand, const ChannelHandle& handle) {
}
XlaOp XlaBuilder::Recv(const Shape& shape, const ChannelHandle& handle) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
// Recv instruction produces a tuple of {receive buffer, U32 context}.
diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.h b/tensorflow/compiler/xla/client/xla_client/xla_builder.h
index f18306fff0..d7e50772c4 100644
--- a/tensorflow/compiler/xla/client/xla_client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.h
@@ -18,6 +18,7 @@ limitations under the License.
#include <map>
#include <string>
+#include <type_traits>
#include <utility>
#include "tensorflow/compiler/xla/client/padding.h"
@@ -46,22 +47,25 @@ class XlaBuilder;
// instruction as an operand.
class XlaOp {
public:
- XlaOp() : handle_(-1), builder_(nullptr) {}
- ~XlaOp() {}
-
- XlaBuilder* builder() const { return builder_; }
-
- bool operator==(const XlaOp& rhs) const {
- return handle_ == rhs.handle_ && builder_ == rhs.builder_;
+ XlaOp() : handle_(-1), builder_(nullptr) {
+ static_assert(std::is_trivially_destructible<XlaOp>::value,
+ "XlaOp should be trivially destructible");
}
+ ~XlaOp() = default;
- bool operator!=(const XlaOp& rhs) const {
- return handle_ != rhs.handle_ || builder_ != rhs.builder_;
- }
+ XlaBuilder* builder() const { return builder_; }
// Returns true if the XlaOp represents valid, non-erroneous value.
bool valid() const { return handle_ >= 0; }
+ // Returns true if the XlaOp was created by the XlaOp() constructor and
+ // not returned by a builder.
+ bool IsUninitialized() const { return builder_ == nullptr; }
+
+ bool IsIdenticalTo(const XlaOp& rhs) const {
+ return handle_ == rhs.handle_ && builder_ == rhs.builder_;
+ }
+
friend std::ostream& operator<<(std::ostream& out, const XlaOp& op) {
out << op.handle();
return out;
@@ -84,6 +88,30 @@ class XlaOp {
XlaBuilder* builder_;
};
+// Arithmetic operator overloads for the XlaOp type.
+XlaOp operator-(const XlaOp& x);
+XlaOp operator+(const XlaOp& x, const XlaOp& y);
+XlaOp operator-(const XlaOp& x, const XlaOp& y);
+XlaOp operator*(const XlaOp& x, const XlaOp& y);
+XlaOp operator/(const XlaOp& x, const XlaOp& y);
+XlaOp operator%(const XlaOp& x, const XlaOp& y);
+
+// Bitwise operator overloads for the XlaOp type.
+XlaOp operator~(const XlaOp& x);
+XlaOp operator&(const XlaOp& x, const XlaOp& y);
+XlaOp operator|(const XlaOp& x, const XlaOp& y);
+XlaOp operator^(const XlaOp& x, const XlaOp& y);
+XlaOp operator<<(const XlaOp& x, const XlaOp& y);
+// Performs a right arithmetic shift if 'x' is a signed type, otherwise performs
+// a right logical shift.
+XlaOp operator>>(const XlaOp& x, const XlaOp& y);
+
+// We don't overload the relational operators (==, !=, <, <=, >, >=) because the
+// semantics might be surprising since their result types are usually 'bool'.
+// Further, programmers may expect == to be a structural equality.
+// We also choose not to overload any of the mutating operators (e.g., +=, -=)
+// because the semantics might be misleading — XLA computations are immutable.
+
// A convenient interface for building up computations.
//
// Thread-compatible.
@@ -822,6 +850,24 @@ class XlaBuilder {
// Returns the (inferred) result for the current computation's shape.
StatusOr<ProgramShape> GetProgramShape() const;
+ // Reports an error to the builder by:
+ // * storing it internally and capturing a backtrace if it's the first error
+ // (this deferred value will be produced on the call to
+ // Build()/GetShape()/...)
+ // * dying if die_immediately_on_error_ is true.
+ // Returns an XlaOp with an invalid handle but a valid builder. This value can
+ // be returned in place of a value in APIs that return an XlaOp.
+ XlaOp ReportError(const Status& error);
+
+ // A helper function that converts a StatusOr<XlaOp> into an XlaOp.
+ // If the Status was an error, reports the error to builder and returns an
+ // invalid XlaOp handle.
+ XlaOp ReportErrorOrReturn(const StatusOr<XlaOp>& op);
+
+ // A helper function that runs a function that returns a StatusOr<XlaOp> and
+ // returns an XlaOp.
+ XlaOp ReportErrorOrReturn(const std::function<StatusOr<XlaOp>()>& op_creator);
+
private:
StatusOr<XlaOp> AddInstruction(
HloInstructionProto&& instr, HloOpcode opcode,
@@ -830,14 +876,6 @@ class XlaBuilder {
void AddCalledComputation(const XlaComputation& computation,
HloInstructionProto* instr);
- // Notes that the error occurred by:
- // * storing it internally and capturing a backtrace if it's the first error
- // (this deferred value will be produced on the call to Build())
- // * dying if die_immediately_on_error_ is true
- void NoteError(const Status& error);
-
- XlaOp NoteErrorOrReturn(const std::function<StatusOr<XlaOp>()>& op_creator);
-
StatusOr<const HloInstructionProto*> LookUpInstruction(const XlaOp& op) const;
// Internal helper method that does the building for an arbitrary unary op.
diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc b/tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc
index 0680b38f3a..8a5bf96714 100644
--- a/tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc
+++ b/tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc
@@ -59,6 +59,76 @@ TEST_F(XlaBuilderTest, OnePlusTwo) {
EXPECT_THAT(root, op::Add(op::Constant(), op::Constant()));
}
+TEST_F(XlaBuilderTest, UnaryOperatorsBuildExpectedHLO) {
+ auto test_unary_operator =
+ [&](std::function<XlaOp(XlaOp)> op,
+ ::testing::Matcher<const ::xla::HloInstruction*> matches_pattern) {
+ XlaBuilder b(TestName());
+ op(b.ConstantR0<int32>(1));
+ TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
+ auto root = module->entry_computation()->root_instruction();
+ EXPECT_THAT(root, matches_pattern);
+ };
+ test_unary_operator([](XlaOp x) { return -x; }, op::Negate(op::Constant()));
+ test_unary_operator([](XlaOp x) { return ~x; }, op::Not(op::Constant()));
+}
+
+TEST_F(XlaBuilderTest, BinaryOperatorsBuildExpectedHLO) {
+ auto test_binary_operator =
+ [&](std::function<XlaOp(XlaOp, XlaOp)> op,
+ ::testing::Matcher<const ::xla::HloInstruction*> matches_pattern) {
+ XlaBuilder b(TestName());
+ op(b.ConstantR0<int32>(1), b.ConstantR0<int32>(2));
+ TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
+ auto root = module->entry_computation()->root_instruction();
+ EXPECT_THAT(root, matches_pattern);
+ };
+
+ test_binary_operator([](XlaOp x, XlaOp y) { return x + y; },
+ op::Add(op::Constant(), op::Constant()));
+ test_binary_operator([](XlaOp x, XlaOp y) { return x - y; },
+ op::Subtract(op::Constant(), op::Constant()));
+ test_binary_operator([](XlaOp x, XlaOp y) { return x * y; },
+ op::Multiply(op::Constant(), op::Constant()));
+ test_binary_operator([](XlaOp x, XlaOp y) { return x / y; },
+ op::Divide(op::Constant(), op::Constant()));
+
+ test_binary_operator([](XlaOp x, XlaOp y) { return x & y; },
+ op::And(op::Constant(), op::Constant()));
+ test_binary_operator([](XlaOp x, XlaOp y) { return x | y; },
+ op::Or(op::Constant(), op::Constant()));
+ test_binary_operator([](XlaOp x, XlaOp y) { return x ^ y; },
+ op::Xor(op::Constant(), op::Constant()));
+ test_binary_operator([](XlaOp x, XlaOp y) { return x << y; },
+ op::ShiftLeft(op::Constant(), op::Constant()));
+ test_binary_operator(
+ [](XlaOp x, XlaOp y) { return x >> y; },
+ op::ShiftRightArithmetic(op::Constant(), op::Constant()));
+
+ auto test_unsigned_binary_operator =
+ [&](std::function<XlaOp(XlaOp, XlaOp)> op,
+ ::testing::Matcher<const ::xla::HloInstruction*> matches_pattern) {
+ XlaBuilder b(TestName());
+ op(b.ConstantR0<uint32>(1), b.ConstantR0<uint32>(2));
+ TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
+ auto root = module->entry_computation()->root_instruction();
+ EXPECT_THAT(root, matches_pattern);
+ };
+ test_unsigned_binary_operator(
+ [](XlaOp x, XlaOp y) { return x >> y; },
+ op::ShiftRightLogical(op::Constant(), op::Constant()));
+}
+
+TEST_F(XlaBuilderTest, ShiftRightOperatorOnNonIntegerProducesError) {
+ XlaBuilder b(TestName());
+ b.ConstantR0<float>(1) >> b.ConstantR0<float>(2);
+ auto statusor = b.Build();
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(
+ statusor.status().error_message(),
+ HasSubstr("Argument to >> operator does not have an integral type"));
+}
+
TEST_F(XlaBuilderTest, ParamPlusConstantHasScalarBroadcast) {
XlaBuilder b(TestName());
auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {3, 5}), "x");
@@ -221,5 +291,32 @@ TEST_F(XlaBuilderTest, Transpose) {
EXPECT_THAT(root, op::Transpose(op::Parameter()));
}
+TEST_F(XlaBuilderTest, ReportError) {
+ XlaBuilder b(TestName());
+ auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {5, 7}), "x");
+ b.Add(b.ReportError(InvalidArgument("a test error")), x);
+ auto statusor = b.Build();
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(statusor.status().error_message(), HasSubstr("a test error"));
+}
+
+TEST_F(XlaBuilderTest, ReportErrorOrReturnHandlesNonErrors) {
+ XlaBuilder b(TestName());
+ StatusOr<XlaOp> op(b.ConstantR0<float>(1.0));
+ b.Add(b.ReportErrorOrReturn(op), b.ConstantR0<float>(2.0));
+ TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
+ auto root = module->entry_computation()->root_instruction();
+ EXPECT_THAT(root, op::Add(op::Constant(), op::Constant()));
+}
+
+TEST_F(XlaBuilderTest, ReportErrorOrReturnHandlesErrors) {
+ XlaBuilder b(TestName());
+ StatusOr<XlaOp> op(InvalidArgument("a test error"));
+ b.Add(b.ReportErrorOrReturn(op), b.ConstantR0<float>(2.0));
+ auto statusor = b.Build();
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(statusor.status().error_message(), HasSubstr("a test error"));
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/layout_util.cc b/tensorflow/compiler/xla/layout_util.cc
index 3f059cac30..15eeb2ea13 100644
--- a/tensorflow/compiler/xla/layout_util.cc
+++ b/tensorflow/compiler/xla/layout_util.cc
@@ -248,6 +248,12 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
}
}
+ if (layout.format() == SPARSE) {
+ if (!layout.padded_dimensions().empty()) {
+ return InvalidArgument("Sparse layout has padded dimensions");
+ }
+ }
+
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/overflow_util.h b/tensorflow/compiler/xla/overflow_util.h
new file mode 100644
index 0000000000..8657d3a4bf
--- /dev/null
+++ b/tensorflow/compiler/xla/overflow_util.h
@@ -0,0 +1,50 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_OVERFLOW_UTIL_H_
+#define TENSORFLOW_COMPILER_XLA_OVERFLOW_UTIL_H_
+
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+// Multiplies two nonnegative int64's, returning negative for overflow.
+inline int64 MultiplyWithoutOverflow(const int64 x, const int64 y) {
+ // Multiply in uint64 rather than int64 since signed overflow is undefined.
+ // Negative values will wrap around to large unsigned values in the casts
+ // (see section 4.7 [conv.integral] of the C++14 standard).
+ const uint64 ux = x;
+ const uint64 uy = y;
+ const uint64 uxy = ux * uy;
+
+ // Check if we overflow uint64, using a cheap check if both inputs are small
+ if (TF_PREDICT_FALSE((ux | uy) >> 32 != 0)) {
+ // Ensure nonnegativity. Note that negative numbers will appear "large"
+ // to the unsigned comparisons above.
+ CHECK(x >= 0 && y >= 0);
+
+ // Otherwise, detect overflow using a division
+ if (ux != 0 && uxy / ux != uy) return -1;
+ }
+
+ // Cast back to signed. Any negative value will signal an error.
+ return static_cast<int64>(uxy);
+}
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_OVERFLOW_UTIL_H_
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc
index 29062348b0..734d9334fd 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.cc
+++ b/tensorflow/compiler/xla/python/local_computation_builder.cc
@@ -511,22 +511,14 @@ LocalOp LocalComputationBuilder::Rev(
LocalOp LocalComputationBuilder::Map(
tensorflow::gtl::ArraySlice<LocalOp> operands,
const LocalComputation& local_computation,
- tensorflow::gtl::ArraySlice<int64> dimensions,
- tensorflow::gtl::ArraySlice<LocalOp> static_operands) {
+ tensorflow::gtl::ArraySlice<int64> dimensions) {
std::vector<XlaOp> xla_ops;
xla_ops.reserve(operands.size());
for (const auto& op : operands) {
xla_ops.push_back(op.op());
}
- std::vector<XlaOp> static_xla_ops;
- static_xla_ops.reserve(static_operands.size());
- for (const auto& op : static_operands) {
- static_xla_ops.push_back(op.op());
- }
-
- return builder_.Map(xla_ops, local_computation.computation(), dimensions,
- static_xla_ops);
+ return builder_.Map(xla_ops, local_computation.computation(), dimensions);
}
LocalOp LocalComputationBuilder::Reduce(
@@ -621,6 +613,7 @@ _FORWARD_BINOP(Max)
_FORWARD_BINOP(Min)
_FORWARD_BINOP(And)
_FORWARD_BINOP(Or)
+_FORWARD_BINOP(Xor)
_FORWARD_UNOP(Not)
_FORWARD_UNOP(Abs)
_FORWARD_UNOP(Exp)
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h
index 95f0a0610b..e920f8aecd 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.h
+++ b/tensorflow/compiler/xla/python/local_computation_builder.h
@@ -270,8 +270,7 @@ class LocalComputationBuilder {
LocalOp Map(tensorflow::gtl::ArraySlice<LocalOp> operands,
const LocalComputation& local_computation,
- tensorflow::gtl::ArraySlice<int64> dimensions,
- tensorflow::gtl::ArraySlice<LocalOp> static_operands);
+ tensorflow::gtl::ArraySlice<int64> dimensions);
LocalOp Reduce(const LocalOp& operand, const LocalOp& init_value,
const LocalComputation& local_computation,
@@ -333,6 +332,7 @@ class LocalComputationBuilder {
_FORWARD_BINOP(Min)
_FORWARD_BINOP(And)
_FORWARD_BINOP(Or)
+ _FORWARD_BINOP(Xor)
_FORWARD_UNOP(Not)
_FORWARD_UNOP(Abs)
_FORWARD_UNOP(Exp)
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i
index 477df6fde2..76e9e637cd 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.i
+++ b/tensorflow/compiler/xla/python/local_computation_builder.i
@@ -988,6 +988,7 @@ tensorflow::ImportNumpy();
%unignore xla::swig::LocalComputationBuilder::Min;
%unignore xla::swig::LocalComputationBuilder::And;
%unignore xla::swig::LocalComputationBuilder::Or;
+%unignore xla::swig::LocalComputationBuilder::Xor;
%unignore xla::swig::LocalComputationBuilder::Not;
%unignore xla::swig::LocalComputationBuilder::Abs;
%unignore xla::swig::LocalComputationBuilder::Exp;
diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py
index a1fc25303c..abb97d0c6f 100644
--- a/tensorflow/compiler/xla/python/xla_client.py
+++ b/tensorflow/compiler/xla/python/xla_client.py
@@ -123,6 +123,7 @@ _BINARY_OPS = [
'Min',
'And',
'Or',
+ 'Xor',
'Pow',
]
@@ -908,20 +909,19 @@ class ComputationBuilder(object):
"""
return self._client.Call(computation_to_apply.c_local_computation, operands)
- def Map(self, operands, computation_to_apply, dimensions, static_operands=()):
+ def Map(self, operands, computation_to_apply, dimensions):
"""Enqueues a map operation onto the computation.
Args:
operands: an iterable of LocalOp.
computation_to_apply: a Computation object.
dimensions: dimensions over which to apply map the function.
- static_operands: auxiliary arguments passed to the applied computation.
Returns:
A LocalOp representing the added Map op.
"""
return self._client.Map(operands, computation_to_apply.c_local_computation,
- dimensions, static_operands)
+ dimensions)
def Reduce(self, operand, init_value, computation_to_apply, dimensions):
"""Enqueues a reduction operation onto the computation.
diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py
index 71e1d60a4e..0564ddcb85 100644
--- a/tensorflow/compiler/xla/python/xla_client_test.py
+++ b/tensorflow/compiler/xla/python/xla_client_test.py
@@ -157,6 +157,13 @@ class ComputationsWithConstantsTest(LocalComputationTest):
c.Constant(NumpyArrayBool([True, True, False, False])))
self._ExecuteAndCompareExact(c, expected=[True, True, True, False])
+ def testBooleanXor(self):
+ c = self._NewComputation()
+ c.Xor(
+ c.Constant(NumpyArrayBool([True, False, True, False])),
+ c.Constant(NumpyArrayBool([True, True, False, False])))
+ self._ExecuteAndCompareExact(c, expected=[False, True, True, False])
+
def testSum2DF32(self):
c = self._NewComputation()
c.Add(
@@ -1168,14 +1175,6 @@ class EmbeddedComputationsTest(LocalComputationTest):
self._CreateBinaryDivF64Computation(), [0])
self._ExecuteAndCompareClose(c, expected=[0.2, 0.4, 0.75, 1.0])
- def DISABLED_testMapWithStaticOperands(self):
- c = self._NewComputation()
- factor = c.ConstantF32Scalar(3.0)
- c.Map([c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0]))],
- self._CreateMulF32ByParamComputation(), [0],
- static_operands=[factor])
- self._ExecuteAndCompareClose(c, expected=[3.0, 6.0, 9.0, 12.0])
-
def testSelectAndScatterF32(self):
c = self._NewComputation()
c.SelectAndScatter(c.Constant(NumpyArrayF32([[1., 2., 6.], [4., 5., 3.]])),
diff --git a/tensorflow/compiler/xla/rpc/grpc_client_test.cc b/tensorflow/compiler/xla/rpc/grpc_client_test.cc
index d7dd9786a2..4031320001 100644
--- a/tensorflow/compiler/xla/rpc/grpc_client_test.cc
+++ b/tensorflow/compiler/xla/rpc/grpc_client_test.cc
@@ -91,7 +91,7 @@ TEST_F(GRPCClientTestBase, AxpyTenValues) {
auto y = builder.ConstantR1<float>(
{5.0, -5.0, -4.0, 4.0, 3.0, -3.0, -2.0, 2.0, 1.0, -1.0});
auto ax = builder.Mul(alpha, x);
- auto axpy = builder.Add(ax, y);
+ builder.Add(ax, y);
std::vector<float> expected = {
1.85840735, -1.85840735, 2.28318531, -2.28318531, -6.42477796,
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index c08960a57b..0833289b73 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -2094,6 +2094,7 @@ cc_library(
hdrs = ["hlo_verifier.h"],
deps = [
":hlo",
+ ":hlo_casting_utils",
":hlo_pass",
":shape_inference",
"//tensorflow/compiler/xla:status_macros",
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index d8a9aba834..4858fe61e0 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -50,20 +50,15 @@ namespace {
namespace m = match;
-// Returns whether operand is a literal with the given value.
-bool IsLiteralWithValue(const HloInstruction* operand, int8 value) {
- return operand->opcode() == HloOpcode::kConstant &&
- operand->literal().IsAll(value);
-}
-
bool IsAll(const HloInstruction* op, int8 value) {
- if (IsLiteralWithValue(op, value)) {
- return true;
- }
- if (op->opcode() == HloOpcode::kBroadcast && IsAll(op->operand(0), value)) {
- return true;
+ switch (op->opcode()) {
+ case HloOpcode::kBroadcast:
+ return IsAll(op->operand(0), value);
+ case HloOpcode::kConstant:
+ return op->literal().IsAll(value);
+ default:
+ return false;
}
- return false;
}
// Returns whether the given transpose produces a result which is bit-wise
@@ -160,9 +155,6 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
Status HandleMap(HloInstruction* map) override;
- Status HandleMaximum(HloInstruction* maximum) override;
- Status HandleMinimum(HloInstruction* minimum) override;
-
// Returns whether algebraic simplification has occurred.
const bool changed() const { return changed_; }
@@ -201,8 +193,9 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
// Helper method to perform and add reduction in a single dimension.
HloInstruction* AddReduce(HloInstruction* hlo, int64 dim) {
- HloInstruction* zero = computation_->AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0(0.0f)));
+ HloInstruction* zero =
+ computation_->AddInstruction(HloInstruction::CreateConstant(
+ Literal::Zero(hlo->shape().element_type()).CloneToUnique()));
HloComputation* AddReduce_computation = GetOrCreateScalarAddComputation();
Shape shape = ShapeUtil::DeleteDimension(dim, hlo->shape());
return computation_->AddInstruction(HloInstruction::CreateReduce(
@@ -572,6 +565,14 @@ Status AlgebraicSimplifierVisitor::HandleSubtract(HloInstruction* sub) {
return Status::OK();
}
+namespace {
+template <typename T>
+Status InvertConstant(const HloInstruction& constant, Literal* result) {
+ return result->Populate<T>([&](tensorflow::gtl::ArraySlice<int64> indices) {
+ return T{1.0} / constant.literal().Get<T>(indices);
+ });
+}
+} // namespace
Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) {
Shape* shape;
@@ -633,14 +634,31 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) {
// (Backends can do this transformation, but generally only if the constant is
// a scalar.)
if (Match(divide, m::Divide(m::NonConstant(&a), m::Constant(&b)))) {
- HloInstruction* one =
- computation_->AddInstruction(HloInstruction::CreateConstant(
- Literal::One(a->shape().element_type()).CloneToUnique()));
- HloInstruction* inverse = computation_->AddInstruction(
- HloInstruction::CreateBinary(b->shape(), HloOpcode::kDivide, one, b));
- return ReplaceWithNewInstruction(
- divide, HloInstruction::CreateBinary(divide->shape(),
- HloOpcode::kMultiply, a, inverse));
+ Literal new_literal(b->shape());
+ switch (b->shape().element_type()) {
+ case F16:
+ TF_RETURN_IF_ERROR(InvertConstant<half>(*b, &new_literal));
+ break;
+ case F32:
+ TF_RETURN_IF_ERROR(InvertConstant<float>(*b, &new_literal));
+ break;
+ case BF16:
+ TF_RETURN_IF_ERROR(InvertConstant<bfloat16>(*b, &new_literal));
+ break;
+ case F64:
+ TF_RETURN_IF_ERROR(InvertConstant<double>(*b, &new_literal));
+ break;
+ case C64:
+ TF_RETURN_IF_ERROR(InvertConstant<complex64>(*b, &new_literal));
+ break;
+ default:
+ return Status::OK();
+ }
+ auto inverse = computation_->AddInstruction(
+ HloInstruction::CreateConstant((new_literal.CloneToUnique())));
+ TF_ASSIGN_OR_RETURN(auto new_divide,
+ MakeBinaryHlo(HloOpcode::kMultiply, a, inverse));
+ return ReplaceInstruction(divide, new_divide);
}
// (A / B) / (C / D) => (A / B)*(D / C) => (A * D) / (B * C)
@@ -660,18 +678,18 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) {
if (Match(divide, m::Divide(m::Divide(m::Op(&a), m::Op(&b)), m::Op(&c)))) {
TF_ASSIGN_OR_RETURN(auto b_times_c,
MakeBinaryHlo(HloOpcode::kMultiply, b, c));
- return ReplaceWithNewInstruction(
- divide, HloInstruction::CreateBinary(divide->shape(),
- HloOpcode::kDivide, a, b_times_c));
+ TF_ASSIGN_OR_RETURN(auto new_divide,
+ MakeBinaryHlo(HloOpcode::kDivide, a, b_times_c));
+ return ReplaceInstruction(divide, new_divide);
}
// A / (B / C) => (A*C) / B
if (Match(divide, m::Divide(m::Op(&a), m::Divide(m::Op(&b), m::Op(&c))))) {
TF_ASSIGN_OR_RETURN(auto a_times_c,
MakeBinaryHlo(HloOpcode::kMultiply, a, c));
- return ReplaceWithNewInstruction(
- divide, HloInstruction::CreateBinary(divide->shape(),
- HloOpcode::kDivide, a_times_c, b));
+ TF_ASSIGN_OR_RETURN(auto new_divide,
+ MakeBinaryHlo(HloOpcode::kDivide, a_times_c, b));
+ return ReplaceInstruction(divide, new_divide);
}
return Status::OK();
@@ -2074,10 +2092,9 @@ Status AlgebraicSimplifierVisitor::HandleConvolution(
convolution,
HloInstruction::CreateBroadcast(
convolution->shape(),
- computation_->AddInstruction(HloInstruction::CreateConvert(
- ShapeUtil::MakeShape(convolution->shape().element_type(), {}),
- computation_->AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0(0.0f))))),
+ computation_->AddInstruction(HloInstruction::CreateConstant(
+ Literal::Zero(convolution->shape().element_type())
+ .CloneToUnique())),
{}));
}
const auto& window = convolution->window();
@@ -2249,68 +2266,6 @@ Status AlgebraicSimplifierVisitor::HandleMap(HloInstruction* map) {
return ReplaceWithNewInstruction(map, std::move(clone));
}
-Status AlgebraicSimplifierVisitor::HandleMaximum(HloInstruction* maximum) {
- // Match the following tree:
- // min_operand operand
- // \ /
- // max_operand min
- // \ /
- // max
- // where max_operand and min_operand are scalar constants.
- {
- HloInstruction* min;
- HloInstruction* max_operand;
- HloInstruction* min_operand;
- HloInstruction* operand;
-
- if (hlo_query::MatchBinaryInstructionOperandOpcode(
- HloOpcode::kMinimum, maximum,
- /*matching_operand=*/&min,
- /*other_operand=*/&max_operand) &&
- hlo_query::MatchBinaryInstructionOperand(
- hlo_query::IsScalarConstant, min,
- /*matching_operand=*/&min_operand,
- /*other_operand=*/&operand) &&
- TransformToClampIfSameShape(maximum, min, min_operand, operand, maximum,
- max_operand)) {
- return Status::OK();
- }
- }
-
- return Status::OK();
-}
-
-Status AlgebraicSimplifierVisitor::HandleMinimum(HloInstruction* minimum) {
- // Match the following tree:
- // max_operand operand
- // \ /
- // min_operand max
- // \ /
- // min
- // where max_operand and min_operand are scalar constants.
- {
- HloInstruction* max;
- HloInstruction* max_operand;
- HloInstruction* min_operand;
- HloInstruction* operand;
-
- if (hlo_query::MatchBinaryInstructionOperandOpcode(
- HloOpcode::kMaximum, minimum,
- /*matching_operand=*/&max,
- /*other_operand=*/&min_operand) &&
- hlo_query::MatchBinaryInstructionOperand(
- hlo_query::IsScalarConstant, max,
- /*matching_operand=*/&max_operand,
- /*other_operand=*/&operand) &&
- TransformToClampIfSameShape(minimum, minimum, min_operand, operand, max,
- max_operand)) {
- return Status::OK();
- }
- }
-
- return Status::OK();
-}
-
StatusOr<bool> AlgebraicSimplifier::Run(HloModule* module) {
XLA_VLOG_LINES(2,
"AlgebraicSimplifier::Run(), before:\n" + module->ToString());
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
index 49cc0b808b..b733f6f59e 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -201,8 +201,11 @@ TEST_F(AlgebraicSimplifierTest, InlineTrivialMap) {
HloInstruction::CreateParameter(0, r2f32, "param0"));
HloInstruction* zero = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
- builder.AddInstruction(
- HloInstruction::CreateMap(r2f32, {param0, zero}, add_computation));
+ builder.AddInstruction(HloInstruction::CreateMap(
+ r2f32,
+ {param0, builder.AddInstruction(
+ HloInstruction::CreateBroadcast(r2f32, zero, {}))},
+ add_computation));
auto computation = module().AddEntryComputation(builder.Build());
HloInstruction* root = computation->root_instruction();
@@ -211,7 +214,7 @@ TEST_F(AlgebraicSimplifierTest, InlineTrivialMap) {
non_bitcasting_callback());
ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
root = computation->root_instruction();
- EXPECT_THAT(root, op::Add(param0, zero));
+ EXPECT_THAT(root, op::Add(param0, op::Broadcast(zero)));
}
TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR1Operand) {
@@ -367,17 +370,16 @@ TEST_F(AlgebraicSimplifierTest, RhsDivOfDiv) {
// Test that (A/B)/(C/D) is simplified to (A*D)/(B*C).
TEST_F(AlgebraicSimplifierTest, DivOfDivAndDiv) {
- Shape r0f32 = ShapeUtil::MakeShape(F32, {});
Shape r2f32 = ShapeUtil::MakeShape(F32, {42, 123});
HloComputation::Builder builder(TestName());
HloInstruction* param0 = builder.AddInstruction(
- HloInstruction::CreateParameter(0, r0f32, "param0"));
+ HloInstruction::CreateParameter(0, r2f32, "param0"));
HloInstruction* param1 = builder.AddInstruction(
HloInstruction::CreateParameter(1, r2f32, "param1"));
HloInstruction* param2 = builder.AddInstruction(
HloInstruction::CreateParameter(2, r2f32, "param2"));
HloInstruction* param3 = builder.AddInstruction(
- HloInstruction::CreateParameter(3, r0f32, "param3"));
+ HloInstruction::CreateParameter(3, r2f32, "param3"));
HloInstruction* div0 = builder.AddInstruction(
HloInstruction::CreateBinary(r2f32, HloOpcode::kDivide, param0, param1));
HloInstruction* div1 = builder.AddInstruction(
@@ -398,8 +400,6 @@ TEST_F(AlgebraicSimplifierTest, DivOfDivAndDiv) {
EXPECT_THAT(
computation->root_instruction(),
op::Divide(op::Multiply(param0, param3), op::Multiply(param1, param2)));
- EXPECT_TRUE(
- ShapeUtil::Compatible(computation->root_instruction()->shape(), r2f32));
}
// Test that A/exp(B) is simplified to A*exp(-B).
@@ -459,7 +459,6 @@ TEST_F(AlgebraicSimplifierTest, DivOfPower) {
// Test that broadcasting is done on the right step when simplifying A/pow(B,C)
// to A*pow(B,-C).
TEST_F(AlgebraicSimplifierTest, DivOfBroadcastingPower) {
- Shape r0f32 = ShapeUtil::MakeShape(F32, {});
Shape r1f32 = ShapeUtil::MakeShape(F32, {7});
HloComputation::Builder builder(TestName());
HloInstruction* param0 = builder.AddInstruction(
@@ -467,7 +466,7 @@ TEST_F(AlgebraicSimplifierTest, DivOfBroadcastingPower) {
HloInstruction* param1 = builder.AddInstruction(
HloInstruction::CreateParameter(1, r1f32, "param1"));
HloInstruction* param2 = builder.AddInstruction(
- HloInstruction::CreateParameter(2, r0f32, "param2"));
+ HloInstruction::CreateParameter(2, r1f32, "param2"));
HloInstruction* power = builder.AddInstruction(
HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, param1, param2));
builder.AddInstruction(
@@ -484,14 +483,9 @@ TEST_F(AlgebraicSimplifierTest, DivOfBroadcastingPower) {
ASSERT_THAT(computation->root_instruction(),
op::Multiply(param0, op::Power(param1, op::Negate(param2))));
-
- const HloInstruction* negate =
- computation->root_instruction()->operand(1)->operand(1);
- const Shape& negate_shape = negate->shape();
- EXPECT_EQ(0, negate_shape.dimensions_size());
}
-// A / Const => A * (1 / Const)
+// A / Const => A * InvertedConst
TEST_F(AlgebraicSimplifierTest, DivideByConstant) {
Shape r1f32 = ShapeUtil::MakeShape(F32, {3});
HloComputation::Builder builder(TestName());
@@ -510,20 +504,19 @@ TEST_F(AlgebraicSimplifierTest, DivideByConstant) {
ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
EXPECT_THAT(computation->root_instruction(),
- op::Multiply(param0, op::Divide(op::Constant(), constant)));
+ op::Multiply(param0, op::Constant()));
}
// pow(pow(A, X), Y) => pow(A, X*Y)
TEST_F(AlgebraicSimplifierTest, PowerOfPower) {
- Shape r0f32 = ShapeUtil::MakeShape(F32, {});
Shape r1f32 = ShapeUtil::MakeShape(F32, {7});
HloComputation::Builder builder(TestName());
HloInstruction* base = builder.AddInstruction(
HloInstruction::CreateParameter(0, r1f32, "param0"));
HloInstruction* exp1 = builder.AddInstruction(
- HloInstruction::CreateParameter(1, r0f32, "param1"));
+ HloInstruction::CreateParameter(1, r1f32, "param1"));
HloInstruction* exp2 = builder.AddInstruction(
- HloInstruction::CreateParameter(2, r0f32, "param2"));
+ HloInstruction::CreateParameter(2, r1f32, "param2"));
HloInstruction* inner_power = builder.AddInstruction(
HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, base, exp1));
builder.AddInstruction(HloInstruction::CreateBinary(r1f32, HloOpcode::kPower,
@@ -540,15 +533,14 @@ TEST_F(AlgebraicSimplifierTest, PowerOfPower) {
// Don't simplify pow(pow(A, X), Y) => pow(A, X*Y) if X and Y are complex
// numbers.
TEST_F(AlgebraicSimplifierTest, PowerOfPowerComplex) {
- Shape r0c64 = ShapeUtil::MakeShape(C64, {});
Shape r1c64 = ShapeUtil::MakeShape(C64, {7});
HloComputation::Builder builder(TestName());
HloInstruction* base = builder.AddInstruction(
HloInstruction::CreateParameter(0, r1c64, "param0"));
HloInstruction* exp1 = builder.AddInstruction(
- HloInstruction::CreateParameter(1, r0c64, "param1"));
+ HloInstruction::CreateParameter(1, r1c64, "param1"));
HloInstruction* exp2 = builder.AddInstruction(
- HloInstruction::CreateParameter(2, r0c64, "param2"));
+ HloInstruction::CreateParameter(2, r1c64, "param2"));
HloInstruction* inner_power = builder.AddInstruction(
HloInstruction::CreateBinary(r1c64, HloOpcode::kPower, base, exp1));
builder.AddInstruction(HloInstruction::CreateBinary(r1c64, HloOpcode::kPower,
@@ -1416,33 +1408,6 @@ TEST_F(AlgebraicSimplifierTest, ReshapeReplacedWithBitcast) {
op::Tuple(op::Bitcast(), dimensions_wrong_reshape, layout_wrong_reshape));
}
-// Regression test for a bug in the reshape sinking transformation, where
-// moving a reshape to a scalar led to a crash.
-TEST_F(AlgebraicSimplifierTest, ReshapeToScalarNotHoistedAfterEffectiveUnary) {
- HloComputation::Builder builder(TestName());
- HloInstruction* param =
- builder.AddInstruction(HloInstruction::CreateParameter(
- 0, ShapeUtil::MakeShape(F32, {1, 1}), "param"));
- HloInstruction* reshape = builder.AddInstruction(
- HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {}), param));
- HloInstruction* zero = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<float>({1., 2., 3.})));
- builder.AddInstruction(HloInstruction::CreateBinary(
- ShapeUtil::MakeShape(F32, {3}), HloOpcode::kMaximum, reshape, zero));
- auto computation = module().AddEntryComputation(builder.Build());
-
- EXPECT_THAT(computation->root_instruction(),
- op::Maximum(op::Reshape(param), zero));
-
- AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
- bitcasting_callback());
-
- simplifier.Run(&module()).ValueOrDie();
-
- EXPECT_THAT(computation->root_instruction(),
- op::Maximum(op::Reshape(param), zero));
-}
-
// Regression test for a bug where if we failed to sink a reshape, we'd set the
// 'changed' bit in AlgebraicSimplifier to false.
TEST_F(AlgebraicSimplifierTest, FailureToSinkReshapeDoesntAffectChangedBit) {
@@ -2103,160 +2068,6 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) {
EXPECT_EQ("NO_CHANGE", build_and_simplify());
}
-// Test that max(min(A, x), y) is transformed to clamp(y, A, x)
-TEST_F(AlgebraicSimplifierTest, MaxMinToClamp) {
- Shape r0f32 = ShapeUtil::MakeShape(F32, {});
- HloComputation::Builder builder(TestName());
- HloInstruction* param0 = builder.AddInstruction(
- HloInstruction::CreateParameter(0, r0f32, "param0"));
- HloInstruction* min_value = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
- HloInstruction* max_value = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0f)));
- HloInstruction* min = builder.AddInstruction(HloInstruction::CreateBinary(
- r0f32, HloOpcode::kMinimum, param0, min_value));
- builder.AddInstruction(
- HloInstruction::CreateBinary(r0f32, HloOpcode::kMaximum, min, max_value));
-
- auto module = CreateNewModule();
- auto computation = module->AddEntryComputation(builder.Build());
-
- EXPECT_THAT(computation->root_instruction(),
- op::Maximum(op::Minimum(param0, min_value), max_value));
-
- AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
- non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module).ValueOrDie());
-
- EXPECT_THAT(computation->root_instruction(),
- op::Clamp(max_value, param0, min_value));
-}
-
-// Test that min(max(A, x), y) is transformed to clamp(x, A, y) for scalar
-// values.
-TEST_F(AlgebraicSimplifierTest, MinMaxToClamp) {
- Shape r0f32 = ShapeUtil::MakeShape(F32, {});
- HloComputation::Builder builder(TestName());
- HloInstruction* param0 = builder.AddInstruction(
- HloInstruction::CreateParameter(0, r0f32, "param0"));
- HloInstruction* min_value = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
- HloInstruction* max_value = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0f)));
- HloInstruction* max = builder.AddInstruction(HloInstruction::CreateBinary(
- r0f32, HloOpcode::kMaximum, param0, max_value));
- builder.AddInstruction(
- HloInstruction::CreateBinary(r0f32, HloOpcode::kMinimum, max, min_value));
-
- auto module = CreateNewModule();
- auto computation = module->AddEntryComputation(builder.Build());
-
- EXPECT_THAT(computation->root_instruction(),
- op::Minimum(op::Maximum(param0, max_value), min_value));
-
- AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
- non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module).ValueOrDie());
-
- EXPECT_THAT(computation->root_instruction(),
- op::Clamp(max_value, param0, min_value));
-}
-
-// Test that min(max(A, x), y) is transformed to clamp(x, A, y) for
-// broadcasted scalar values.
-TEST_F(AlgebraicSimplifierTest, MinMaxWithBroadcastToClamp) {
- Shape r0f32 = ShapeUtil::MakeShape(F32, {});
- Shape r1f32 = ShapeUtil::MakeShape(F32, {100});
- HloComputation::Builder builder(TestName());
- HloInstruction* param0 = builder.AddInstruction(
- HloInstruction::CreateParameter(0, r1f32, "param0"));
- HloInstruction* min_value = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
- HloInstruction* max_value = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0f)));
- HloInstruction* max = builder.AddInstruction(HloInstruction::CreateBinary(
- r1f32, HloOpcode::kMaximum, param0, max_value));
- builder.AddInstruction(
- HloInstruction::CreateBinary(r1f32, HloOpcode::kMinimum, max, min_value));
-
- auto module = CreateNewModule();
- auto computation = module->AddEntryComputation(builder.Build());
-
- EXPECT_THAT(computation->root_instruction(),
- op::Minimum(op::Maximum(param0, max_value), min_value));
-
- AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
- non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module).ValueOrDie());
-
- EXPECT_THAT(computation->root_instruction(),
- op::Clamp(max_value, param0, min_value));
-}
-
-// Test that min(max(A, non-constant1), non-constant2) is not canonicalized to
-// clamp(non-constant1, A, non-constant2)
-TEST_F(AlgebraicSimplifierTest, MinMaxNotToClamp) {
- Shape r0f32 = ShapeUtil::MakeShape(F32, {});
- HloComputation::Builder builder(TestName());
- HloInstruction* param0 = builder.AddInstruction(
- HloInstruction::CreateParameter(0, r0f32, "param0"));
- HloInstruction* min_value = builder.AddInstruction(
- HloInstruction::CreateParameter(1, r0f32, "param1"));
- HloInstruction* max_value = builder.AddInstruction(
- HloInstruction::CreateParameter(2, r0f32, "param2"));
- HloInstruction* max = builder.AddInstruction(HloInstruction::CreateBinary(
- r0f32, HloOpcode::kMaximum, param0, max_value));
- builder.AddInstruction(
- HloInstruction::CreateBinary(r0f32, HloOpcode::kMinimum, max, min_value));
-
- auto module = CreateNewModule();
- auto computation = module->AddEntryComputation(builder.Build());
-
- EXPECT_THAT(computation->root_instruction(),
- op::Minimum(op::Maximum(param0, max_value), min_value));
-
- AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
- non_bitcasting_callback());
- EXPECT_FALSE(simplifier.Run(module).ValueOrDie());
-
- EXPECT_THAT(computation->root_instruction(),
- op::Minimum(op::Maximum(param0, max_value), min_value));
-}
-
-// Test that min(f(max(A, constant1)), constant2) is not transformed to
-// clamp(constant1, A, constant2)
-TEST_F(AlgebraicSimplifierTest, MinEquationWithMaxNotToClamp) {
- Shape r0f32 = ShapeUtil::MakeShape(F32, {});
- HloComputation::Builder builder(TestName());
- HloInstruction* param0 = builder.AddInstruction(
- HloInstruction::CreateParameter(0, r0f32, "param0"));
- HloInstruction* min_value = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
- HloInstruction* max_value = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0f)));
- HloInstruction* max = builder.AddInstruction(HloInstruction::CreateBinary(
- r0f32, HloOpcode::kMaximum, param0, max_value));
- HloInstruction* fmax = builder.AddInstruction(
- HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, max, max_value));
- builder.AddInstruction(HloInstruction::CreateBinary(
- r0f32, HloOpcode::kMinimum, fmax, min_value));
-
- auto module = CreateNewModule();
- auto computation = module->AddEntryComputation(builder.Build());
-
- EXPECT_THAT(computation->root_instruction(),
- op::Minimum(op::Add(op::Maximum(param0, max_value), max_value),
- min_value));
-
- AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
- non_bitcasting_callback());
- EXPECT_FALSE(simplifier.Run(module).ValueOrDie());
-
- EXPECT_THAT(computation->root_instruction(),
- op::Minimum(op::Add(op::Maximum(param0, max_value), max_value),
- min_value));
-}
-
// Test that slice(broadcast(/*scalar value*/)) simplifies to a single
// broadcast.
TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToSlice) {
diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
index efa4696130..28b5a5784f 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
@@ -1874,11 +1874,15 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) {
auto module = CreateNewModule();
auto builder = HloComputation::Builder("entry");
- auto infeed = builder.AddInstruction(HloInstruction::CreateInfeed(r0s32, ""));
+ auto token = builder.AddInstruction(HloInstruction::CreateAfterAll({}));
+ auto infeed =
+ builder.AddInstruction(HloInstruction::CreateInfeed(r0s32, token, ""));
+ auto infeed_data = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(r0s32, infeed, 0));
auto cond0 = module->AddEmbeddedComputation(build_cond());
auto body0 = module->AddEmbeddedComputation(build_body());
auto while0 = builder.AddInstruction(
- HloInstruction::CreateWhile(r0s32, cond0, body0, infeed));
+ HloInstruction::CreateWhile(r0s32, cond0, body0, infeed_data));
auto cond1 = module->AddEmbeddedComputation(build_cond());
auto body1 = module->AddEmbeddedComputation(build_body());
@@ -1909,8 +1913,8 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) {
// computation, since the issue this test stresses depends on the order the
// nodes are traversed during BufferAssignment.
SequentialHloOrdering::HloModuleSequence sequence;
- sequence[module->entry_computation()] = {infeed, while0, while1, zero,
- add, while2, tuple};
+ sequence[module->entry_computation()] = {
+ token, infeed, infeed_data, while0, while1, zero, add, while2, tuple};
TF_ASSERT_OK_AND_ASSIGN(
auto assignment,
BufferAssigner::Run(
diff --git a/tensorflow/compiler/xla/service/call_inliner_test.cc b/tensorflow/compiler/xla/service/call_inliner_test.cc
index 738d00881d..924348c870 100644
--- a/tensorflow/compiler/xla/service/call_inliner_test.cc
+++ b/tensorflow/compiler/xla/service/call_inliner_test.cc
@@ -148,14 +148,16 @@ TEST_F(CallInlinerTest, CallToOutfeedComputationIsInlined) {
HloComputation::Builder outfeeder(TestName() + ".outfeeder");
auto value = outfeeder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0)));
+ auto token = outfeeder.AddInstruction(HloInstruction::CreateAfterAll({}));
outfeeder.AddInstruction(
- HloInstruction::CreateOutfeed(f32, value, /*outfeed_config=*/""));
+ HloInstruction::CreateOutfeed(f32, value, token, /*outfeed_config=*/""));
auto outfeed_computation = module->AddEmbeddedComputation(outfeeder.Build());
HloComputation::Builder outer(TestName() + ".outer");
outer.AddInstruction(HloInstruction::CreateCall(
- ShapeUtil::MakeNil(), /*operands=*/{}, outfeed_computation));
+ outfeed_computation->root_instruction()->shape(), /*operands=*/{},
+ outfeed_computation));
module->AddEntryComputation(outer.Build());
diff --git a/tensorflow/compiler/xla/service/conditional_simplifier_test.cc b/tensorflow/compiler/xla/service/conditional_simplifier_test.cc
index 868348547d..c38719d50e 100644
--- a/tensorflow/compiler/xla/service/conditional_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/conditional_simplifier_test.cc
@@ -144,8 +144,10 @@ TEST_F(ConditionalSimplifierTest, NotRemovedIfContainsNonRemovableInstruction) {
auto* conditional = computation->root_instruction();
ASSERT_EQ(conditional->opcode(), HloOpcode::kConditional);
auto* false_computation = conditional->false_computation();
- false_computation->AddInstruction(
- HloInstruction::CreateInfeed(ShapeUtil::MakeShape(F32, {1}), "config"));
+ auto token =
+ false_computation->AddInstruction(HloInstruction::CreateAfterAll({}));
+ false_computation->AddInstruction(HloInstruction::CreateInfeed(
+ ShapeUtil::MakeShape(F32, {1}), token, "config"));
EXPECT_FALSE(ConditionalSimplifier().Run(&module()).ValueOrDie());
}
diff --git a/tensorflow/compiler/xla/service/copy_insertion_test.cc b/tensorflow/compiler/xla/service/copy_insertion_test.cc
index ed1a50f516..e7539759ce 100644
--- a/tensorflow/compiler/xla/service/copy_insertion_test.cc
+++ b/tensorflow/compiler/xla/service/copy_insertion_test.cc
@@ -1605,8 +1605,8 @@ HloModule TokensShouldNotBeCopied
%constant.1 = s32[] constant(1)
%add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1)
%get-tuple-element.2 = token[] get-tuple-element((s32[], token[]) %param.1), index=1
- %generate-token = token[] generate-token(token[] %get-tuple-element.2)
- ROOT %tuple = (s32[], token[]) tuple(s32[] %add, token[] %generate-token)
+ %after-all = token[] after-all(token[] %get-tuple-element.2)
+ ROOT %tuple = (s32[], token[]) tuple(s32[] %add, token[] %after-all)
}
%Cond (param: (s32[], token[])) -> pred[] {
@@ -1619,7 +1619,7 @@ HloModule TokensShouldNotBeCopied
ENTRY %TokensShouldNotBeCopied () -> s32[] {
%one = s32[] constant(1)
%negative_one = s32[] negate(%one)
- %init_token = token[] generate-token()
+ %init_token = token[] after-all()
%init_tuple = (s32[], token[]) tuple(s32[] %negative_one, token[] %init_token)
%while = (s32[], token[]) while((s32[], token[]) %init_tuple), condition=%Cond, body=%Body
ROOT %root = s32[] get-tuple-element((s32[], token[]) %while), index=0
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index b703be0f39..2c3eb1ae36 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -54,29 +54,6 @@ cc_library(
)
cc_library(
- name = "external_constant_pool",
- srcs = ["external_constant_pool.cc"],
- hdrs = ["external_constant_pool.h"],
- deps = [
- "//tensorflow/compiler/xla:literal_util",
- "//tensorflow/compiler/xla:shape_util",
- "//tensorflow/compiler/xla:util",
- "//tensorflow/core:lib",
- ],
-)
-
-tf_cc_test(
- name = "external_constant_pool_test",
- srcs = ["external_constant_pool_test.cc"],
- deps = [
- ":external_constant_pool",
- "//tensorflow/compiler/xla:shape_util",
- "//tensorflow/compiler/xla/tests:xla_internal_test_main",
- "//tensorflow/core:test",
- ],
-)
-
-cc_library(
name = "cpu_compiler",
srcs = ["cpu_compiler.cc"],
hdrs = ["cpu_compiler.h"],
@@ -175,7 +152,6 @@ cc_library(
":cpu_runtime",
":custom_call_target_registry",
":disassembler",
- ":external_constant_pool",
":orc_jit_memory_mapper",
":runtime_fp16",
":runtime_conv2d",
@@ -256,7 +232,6 @@ cc_library(
":cpu_options",
":cpu_runtime",
":dot_op_emitter",
- ":external_constant_pool",
":ir_emission_utils",
":ir_function",
":parallel_loop_emitter",
@@ -273,6 +248,7 @@ cc_library(
"//tensorflow/compiler/xla/service:buffer_assignment",
"//tensorflow/compiler/xla/service:elemental_ir_emitter",
"//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:hlo_casting_utils",
"//tensorflow/compiler/xla/service:hlo_module_config",
"//tensorflow/compiler/xla/service:name_uniquer",
"//tensorflow/compiler/xla/service/llvm_ir:alias_analysis",
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index 52da9d6eac..55962ba70d 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -269,6 +269,7 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile,
/*is_layout_sensitive=*/false,
[](const Shape&, const Shape&) { return false; },
/*enable_dot_strength_reduction=*/false);
+ pass.AddPass<HloDCE>();
// BatchNormExpander can create zero-sized ops, so zero-sized HLO
// elimination has to come after that pass.
@@ -306,11 +307,16 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile,
module->mutable_entry_computation_layout(), &target_machine_features);
// The LayoutAssignment pass may leave behind kCopy instructions which are
// duplicate or NOPs, so remove them with algebraic simplification and CSE.
- pipeline.AddPass<HloPassFix<AlgebraicSimplifier>>(
- /*is_layout_sensitive=*/true,
- [](const Shape&, const Shape&) { return true; },
- /*enable_dot_strength_reduction=*/false);
- pipeline.AddPass<HloCSE>(/*is_layout_sensitive=*/true);
+ {
+ auto& pass = pipeline.AddPass<HloPassFix<HloPassPipeline>>(
+ "after layout assignement");
+ pass.AddPass<HloPassFix<AlgebraicSimplifier>>(
+ /*is_layout_sensitive=*/true,
+ [](const Shape&, const Shape&) { return true; },
+ /*enable_dot_strength_reduction=*/false);
+ pass.AddPass<HloDCE>();
+ pass.AddPass<HloCSE>(/*is_layout_sensitive=*/true);
+ }
pipeline.AddPass<HloElementTypeConverter>(BF16, F32);
// Outline ops in the entry computation into calls to subcomputations.
const int max_parallelism =
@@ -578,7 +584,7 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
IrEmitter ir_emitter(*module, *assignment, llvm_module.get(),
std::move(instruction_to_profile_idx),
std::move(computation_to_profile_idx),
- &target_machine_features, jit->external_constant_pool());
+ &target_machine_features);
for (auto embedded_computation :
entry_computation->MakeEmbeddedComputationsList()) {
@@ -765,8 +771,7 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
IrEmitter ir_emitter(*module, *assignment, &llvm_module,
std::move(instruction_to_profile_idx),
std::move(computation_to_profile_idx),
- &target_machine_features,
- /*external_constant_pool=*/nullptr);
+ &target_machine_features);
HloComputation* computation = module->entry_computation();
for (auto embedded_computation :
computation->MakeEmbeddedComputationsList()) {
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
index 97e10a89a2..750310c633 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
@@ -501,8 +501,8 @@ TEST_F(OpcodeFusionTest, UnaryMapOfExp) {
HloInstruction* exp = builder.AddInstruction(
HloInstruction::CreateUnary(shape, HloOpcode::kExp, param0));
- builder.AddInstruction(HloInstruction::CreateMap(
- shape, {exp}, CreateAdderToOne(module.get()), /*static_operands=*/{}));
+ builder.AddInstruction(
+ HloInstruction::CreateMap(shape, {exp}, CreateAdderToOne(module.get())));
module->AddEntryComputation(builder.Build());
@@ -525,8 +525,8 @@ TEST_F(OpcodeFusionTest, BinaryMapOfExps) {
HloInstruction* exp1 = builder.AddInstruction(
HloInstruction::CreateUnary(shape, HloOpcode::kExp, param1));
- builder.AddInstruction(HloInstruction::CreateMap(
- shape, {exp0, exp1}, CreateMax(module.get()), /*static_operands=*/{}));
+ builder.AddInstruction(
+ HloInstruction::CreateMap(shape, {exp0, exp1}, CreateMax(module.get())));
module->AddEntryComputation(builder.Build());
diff --git a/tensorflow/compiler/xla/service/cpu/external_constant_pool.cc b/tensorflow/compiler/xla/service/cpu/external_constant_pool.cc
deleted file mode 100644
index c562865591..0000000000
--- a/tensorflow/compiler/xla/service/cpu/external_constant_pool.cc
+++ /dev/null
@@ -1,50 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/compiler/xla/service/cpu/external_constant_pool.h"
-
-#include <algorithm>
-#include <cstdlib>
-#include <cstring>
-
-#include "tensorflow/compiler/xla/map_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
-#include "tensorflow/compiler/xla/shape_util.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
-
-namespace xla {
-namespace cpu {
-void ExternalConstantPool::Insert(string name, const LiteralSlice& literal,
- int64 alignment) {
- CHECK(!ShapeUtil::IsTuple(literal.shape()));
- CHECK(alignment > 0 && IsPowerOfTwo(static_cast<uint64>(alignment)));
- CHECK(entries_.find(name) == entries_.end());
-
- const int64 literal_size = ShapeUtil::ByteSizeOf(literal.shape());
- void* raw_pointer = tensorflow::port::AlignedMalloc(
- literal_size, std::max<size_t>(alignment, sizeof(void*)));
- CHECK(raw_pointer != nullptr) << "failed to allocate " << literal_size
- << " bytes with alignment of " << alignment;
-
- std::memcpy(raw_pointer, literal.untyped_data(), literal_size);
- entries_.emplace(std::move(name), static_cast<uint8*>(raw_pointer));
-}
-
-const uint8* ExternalConstantPool::Find(const string& name) {
- auto it = entries_.find(name);
- return it == entries_.end() ? nullptr : it->second.get();
-}
-} // namespace cpu
-} // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/external_constant_pool.h b/tensorflow/compiler/xla/service/cpu/external_constant_pool.h
deleted file mode 100644
index 0677f5f0b5..0000000000
--- a/tensorflow/compiler/xla/service/cpu/external_constant_pool.h
+++ /dev/null
@@ -1,65 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_EXTERNAL_CONSTANT_POOL_H_
-#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_EXTERNAL_CONSTANT_POOL_H_
-
-#include <memory>
-
-#include "tensorflow/compiler/xla/literal_util.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
-#include "tensorflow/core/platform/mem.h"
-
-namespace xla {
-namespace cpu {
-// An ExternalConstantPool maintains a set of constants kept external to
-// generated LLVM IR. These constants are accessed from the IR via globals with
-// extern linkage. This current incarnation of ExternalConstantPool only
-// supports the JIT CPU backend; the AOT backend is not supported.
-//
-// Implementation-wise, this is a simple wrapper around a map of strings to byte
-// buffers. This simply implementation works in a JIT scenario. This class
-// will have to become smarter if we decide to support external constant pools
-// on AOT compiles in the future.
-class ExternalConstantPool {
- public:
- // Inserts a buffer with the contents of `literal` into the constant pool with
- // the name `name`. It is an error to try to insert two constants with the
- // same `name` into the same constant pool. The buffer for literal is aligned
- // to `aligment` bytes, and `alignment` must be a power of 2.
- //
- // The constant pool copies out the contents of `literal` into a buffer it
- // owns -- it does not keep pointers to `literal`, or to memory owned by
- // `literal`.
- void Insert(string name, const LiteralSlice& literal, int64 alignment);
-
- // Find the constant with name `name` in this constant pool. If there isn't
- // such constant, return nullptr.
- const uint8* Find(const string& name);
-
- private:
- // We need to `AlignedFree` pointers allocated into `entries_` since we
- // allocate them with `AlignedMalloc`.
- struct FreeDeleter {
- void operator()(void* ptr) { tensorflow::port::AlignedFree(ptr); }
- };
-
- tensorflow::gtl::FlatMap<string, std::unique_ptr<uint8, FreeDeleter>>
- entries_;
-};
-} // namespace cpu
-} // namespace xla
-
-#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_EXTERNAL_CONSTANT_POOL_H_
diff --git a/tensorflow/compiler/xla/service/cpu/external_constant_pool_test.cc b/tensorflow/compiler/xla/service/cpu/external_constant_pool_test.cc
deleted file mode 100644
index 9290a4e5df..0000000000
--- a/tensorflow/compiler/xla/service/cpu/external_constant_pool_test.cc
+++ /dev/null
@@ -1,82 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/compiler/xla/service/cpu/external_constant_pool.h"
-#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/core/platform/test.h"
-
-namespace xla {
-namespace cpu {
-namespace {
-class ExternalConstantPoolTest : public ::testing::Test {};
-
-template <typename T>
-T GetFromBuffer(const uint8* buffer, int64 index) {
- T result;
- std::memcpy(&result, buffer + index * sizeof(T), sizeof(T));
- return result;
-}
-
-TEST(ExternalConstantPoolTest, Basic) {
- ExternalConstantPool constant_pool;
- EXPECT_EQ(constant_pool.Find("name-0"), nullptr);
- const auto literal = Literal::CreateR2({{1, 2}, {3, 4}});
- constant_pool.Insert("name-0", *literal, 4);
- const uint8* constant = constant_pool.Find("name-0");
- ASSERT_NE(constant, nullptr);
-
- EXPECT_EQ(GetFromBuffer<int32>(constant, 0), 1);
- EXPECT_EQ(GetFromBuffer<int32>(constant, 1), 2);
- EXPECT_EQ(GetFromBuffer<int32>(constant, 2), 3);
- EXPECT_EQ(GetFromBuffer<int32>(constant, 3), 4);
-
- EXPECT_EQ(constant_pool.Find("name-1"), nullptr);
-}
-
-TEST(ExternalConstantPoolTest, RowMinorLayout) {
- ExternalConstantPool constant_pool;
- EXPECT_EQ(constant_pool.Find("name-0"), nullptr);
- const auto literal = Literal::CreateR2WithLayout(
- {{1, 2}, {3, 4}}, LayoutUtil::MakeLayout({0, 1}));
- constant_pool.Insert("name-0", *literal, 4);
- const uint8* constant = constant_pool.Find("name-0");
- ASSERT_NE(constant, nullptr);
-
- EXPECT_EQ(GetFromBuffer<int32>(constant, 0), 1);
- EXPECT_EQ(GetFromBuffer<int32>(constant, 1), 3);
- EXPECT_EQ(GetFromBuffer<int32>(constant, 2), 2);
- EXPECT_EQ(GetFromBuffer<int32>(constant, 3), 4);
-}
-
-TEST(ExternalConstantPoolTest, Alignment) {
- ExternalConstantPool constant_pool;
- EXPECT_EQ(constant_pool.Find("name-0"), nullptr);
-
- for (int i = 0; i < 8; i++) {
- int64 alignment = 1 << i;
- string name = tensorflow::strings::StrCat("name-", i);
-
- const auto literal = Literal::CreateR2({{1, 2}, {3, 4}});
- constant_pool.Insert(name, *literal, alignment);
-
- const uint8* constant = constant_pool.Find(name);
- ASSERT_NE(constant, nullptr);
- EXPECT_EQ(reinterpret_cast<intptr_t>(constant) % alignment, 0);
- }
-}
-
-} // namespace
-} // namespace cpu
-} // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index 75e8e9a835..6b66a4b0b7 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -48,6 +48,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/cpu/shape_partition.h"
#include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h"
#include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
@@ -83,8 +85,7 @@ IrEmitter::IrEmitter(
llvm::Module* llvm_module,
std::unordered_map<const HloInstruction*, int64> instruction_to_profile_idx,
std::unordered_map<const HloComputation*, int64> computation_to_profile_idx,
- const TargetMachineFeatures* target_machine_features,
- ExternalConstantPool* external_constant_pool)
+ const TargetMachineFeatures* target_machine_features)
: assignment_(assignment),
module_(llvm_module),
arch_type_(llvm::Triple(llvm_module->getTargetTriple()).getArch()),
@@ -94,8 +95,7 @@ IrEmitter::IrEmitter(
alias_analysis_(hlo_module, assignment, &llvm_module->getContext()),
hlo_module_config_(hlo_module.config()),
is_top_level_computation_(false),
- target_machine_features_(*target_machine_features),
- external_constant_pool_(external_constant_pool) {
+ target_machine_features_(*target_machine_features) {
ir_builder_.setFastMathFlags(llvm_ir::GetFastMathFlags(
/*fast_math_enabled=*/hlo_module_config_.debug_options()
.xla_enable_fast_math()));
@@ -161,45 +161,18 @@ Status IrEmitter::HandleBitcast(HloInstruction* bitcast) {
}
llvm::Constant* IrEmitter::EmitGlobalForLiteral(const Literal& literal) {
- llvm::Constant* result;
-
- // We avoid creating large constants in the LLVM IR since LLVM is not
- // efficient for large constant arrays. We still emit "small enough" constant
- // arrays into the Ir, in the off chance the LLVM optimizer can do something
- // interesting with it.
- //
- // TODO(b/29904935): Remove the large constant pool.
- const int kMaxInternalConstantSizeInBytes = 128;
- if (external_constant_pool_ &&
- ByteSizeOf(literal.shape()) >= kMaxInternalConstantSizeInBytes) {
- string global_name = tensorflow::strings::StrCat(
- "constant_global_", external_global_constant_counter_++);
- llvm::GlobalVariable* result_global = new llvm::GlobalVariable(
- /*Module=*/*module_,
- /*Type=*/IrShapeType(literal.shape()),
- /*isConstant=*/true,
- /*Linkage=*/llvm::GlobalValue::ExternalLinkage,
- /*Initializer=*/nullptr,
- /*Name=*/AsStringRef(global_name));
- result_global->setAlignment(MinimumAlignmentForShape(literal.shape()));
- external_constant_pool_->Insert(global_name, literal,
- MinimumAlignmentForShape(literal.shape()));
- result = result_global;
- } else {
- llvm::Constant* initializer =
- llvm_ir::ConvertLiteralToIrConstant(literal, module_);
- llvm::GlobalVariable* result_global = new llvm::GlobalVariable(
- /*Module=*/*module_,
- /*Type=*/initializer->getType(),
- /*isConstant=*/true,
- /*Linkage=*/llvm::GlobalValue::PrivateLinkage,
- /*Initializer=*/initializer,
- /*Name=*/"");
- result_global->setAlignment(MinimumAlignmentForShape(literal.shape()));
- result = llvm::ConstantExpr::getBitCast(
- result_global, IrShapeType(literal.shape())->getPointerTo());
- }
- return result;
+ llvm::Constant* initializer =
+ llvm_ir::ConvertLiteralToIrConstant(literal, module_);
+ llvm::GlobalVariable* result_global = new llvm::GlobalVariable(
+ /*Module=*/*module_,
+ /*Type=*/initializer->getType(),
+ /*isConstant=*/true,
+ /*Linkage=*/llvm::GlobalValue::PrivateLinkage,
+ /*Initializer=*/initializer,
+ /*Name=*/"");
+ result_global->setAlignment(MinimumAlignmentForShape(literal.shape()));
+ return llvm::ConstantExpr::getBitCast(
+ result_global, IrShapeType(literal.shape())->getPointerTo());
}
Status IrEmitter::HandleConstant(HloInstruction* constant) {
@@ -321,30 +294,42 @@ Status IrEmitter::HandleSelect(HloInstruction* select) {
return DefaultAction(select);
}
-Status IrEmitter::HandleInfeed(HloInstruction* infeed) {
+Status IrEmitter::HandleInfeed(HloInstruction* instruction) {
+ HloInfeedInstruction* infeed = Cast<HloInfeedInstruction>(instruction);
VLOG(2) << "HandleInfeed: " << infeed->ToString();
- const Shape& shape = infeed->shape();
-
- // The infeed operation produces data (dequeued from the infeed queue) at this
- // address, which has been provided by buffer assignment.
+ // The infeed operation produces a two-element tuple containing data and a
+ // token value. HloInfeedInstruction::infeed_shape gives us the data shape.
+ const Shape& data_shape = infeed->infeed_shape();
+ DCHECK(ShapeUtil::Equal(data_shape,
+ ShapeUtil::GetTupleElementShape(infeed->shape(), 0)));
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(infeed));
- llvm_ir::IrArray infeed_array = GetIrArrayFor(infeed);
- if (ShapeUtil::IsTuple(shape)) {
- TF_RET_CHECK(!ShapeUtil::IsNestedTuple(shape));
+ // Write the tuple index table.
+ TF_ASSIGN_OR_RETURN(BufferAllocation::Slice data_slice,
+ assignment_.GetUniqueSlice(infeed, {0}));
+ llvm::Value* data_address = EmitTempBufferPointer(data_slice, data_shape);
+ TF_ASSIGN_OR_RETURN(BufferAllocation::Slice token_slice,
+ assignment_.GetUniqueSlice(infeed, {1}));
+ llvm::Value* token_address = EmitTempBufferPointer(
+ token_slice, ShapeUtil::GetTupleElementShape(infeed->shape(), 1));
+ llvm_ir::EmitTuple(GetIrArrayFor(infeed), {data_address, token_address},
+ &ir_builder_, module_);
+
+ if (ShapeUtil::IsTuple(data_shape)) {
+ TF_RET_CHECK(!ShapeUtil::IsNestedTuple(data_shape));
// For a tuple, we first copy each of the internal elements to
// their corresponding target locations. We then construct the
// tuple outer buffer containing pointers to the internal
// elements.
std::vector<llvm::Value*> tuple_element_addresses;
- for (int64 i = 0; i < shape.tuple_shapes_size(); ++i) {
+ for (int64 i = 0; i < data_shape.tuple_shapes_size(); ++i) {
TF_ASSIGN_OR_RETURN(BufferAllocation::Slice buffer,
- assignment_.GetUniqueSlice(infeed, {i}));
+ assignment_.GetUniqueSlice(infeed, {0, i}));
const Shape& tuple_element_shape =
- ShapeUtil::GetTupleElementShape(shape, i);
+ ShapeUtil::GetTupleElementShape(data_shape, i);
// Only the outer tuple buffer's target address is obtained from
// GetEmittedValueFor, to handle the case when Infeed is the root
@@ -359,11 +344,11 @@ Status IrEmitter::HandleInfeed(HloInstruction* infeed) {
tuple_element_addresses.push_back(tuple_element_address);
}
- llvm_ir::EmitTuple(infeed_array, tuple_element_addresses, &ir_builder_,
- module_);
+ llvm_ir::EmitTuple(llvm_ir::IrArray(data_address, data_shape),
+ tuple_element_addresses, &ir_builder_, module_);
} else {
- TF_RETURN_IF_ERROR(EmitXfeedTransfer(XfeedKind::kInfeed, shape,
- GetEmittedValueFor(infeed)));
+ TF_RETURN_IF_ERROR(
+ EmitXfeedTransfer(XfeedKind::kInfeed, data_shape, data_address));
}
return Status::OK();
@@ -2539,7 +2524,7 @@ Status IrEmitter::HandleConditional(HloInstruction* conditional) {
return Status::OK();
}
-Status IrEmitter::HandleGenerateToken(HloInstruction* gen_token) {
+Status IrEmitter::HandleAfterAll(HloInstruction* gen_token) {
TF_RET_CHECK(ByteSizeOf(gen_token->shape()) == 0);
// No code to generate, but we need to emit an address for book-keeping.
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(gen_token));
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
index e1815c1db7..3c110a320f 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
@@ -30,7 +30,6 @@ limitations under the License.
#include "llvm/IR/Value.h"
#include "llvm/Target/TargetMachine.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
-#include "tensorflow/compiler/xla/service/cpu/external_constant_pool.h"
#include "tensorflow/compiler/xla/service/cpu/ir_function.h"
#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
@@ -67,17 +66,13 @@ class IrEmitter : public DfsHloVisitorWithDefault {
// index in the profiling array.
// computation_to_profile_idx: the mapping from HLO computations to their
// index in the profiling array.
- // external_constant_pool: if non-null, points to an ExternalConstantPool
- // instance into which the Ir emitter can spill
- // constants.
IrEmitter(const HloModule& hlo_module, const BufferAssignment& assignment,
llvm::Module* llvm_module,
std::unordered_map<const HloInstruction*, int64>
instruction_to_profile_idx,
std::unordered_map<const HloComputation*, int64>
computation_to_profile_idx,
- const TargetMachineFeatures* target_machine,
- ExternalConstantPool* external_constant_pool);
+ const TargetMachineFeatures* target_machine);
~IrEmitter() override;
// Emit and return the given HLO computation as an LLVM IR
@@ -150,7 +145,7 @@ class IrEmitter : public DfsHloVisitorWithDefault {
Status HandleWhile(HloInstruction* xla_while) override;
Status HandleConcatenate(HloInstruction* concatenate) override;
Status HandleConditional(HloInstruction* conditional) override;
- Status HandleGenerateToken(HloInstruction* gen_token) override;
+ Status HandleAfterAll(HloInstruction* gen_token) override;
Status FinishVisit(HloInstruction* root) override;
Status Preprocess(HloInstruction* hlo) override;
@@ -537,9 +532,6 @@ class IrEmitter : public DfsHloVisitorWithDefault {
const TargetMachineFeatures& target_machine_features_;
- int64 external_global_constant_counter_ = 0;
- ExternalConstantPool* external_constant_pool_;
-
struct LiteralPtrHashFunctor {
size_t operator()(const Literal* literal) const { return literal->Hash(); }
};
diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc
index fc2efbaf9a..36c9f74385 100644
--- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc
@@ -110,8 +110,9 @@ TEST_F(ParallelTaskAssignmentTest, InfeedOutfeedOperationNotParallelized) {
const string hlo_string = R"(
HloModule TestTaskParallel_infeed_outfeed
ENTRY InfeedOutfeed {
- infeed0 = u32[12345678,2]{1,0} infeed()
- ROOT outfeed0 = u32[12345678,2]{1,0} outfeed(infeed0)
+ infeed0 = (u32[12345678,2]{1,0}, token[]) infeed()
+ infeed0.data = u32[12345678,2]{1,0} get-tuple-element((u32[12345678,2]{1,0}, token[]) infeed0), index=0
+ ROOT outfeed0 = token[] outfeed(infeed0.data)
}
)";
diff --git a/tensorflow/compiler/xla/service/cpu/sample_harness.cc b/tensorflow/compiler/xla/service/cpu/sample_harness.cc
index 167aa4adda..e3965b4e05 100644
--- a/tensorflow/compiler/xla/service/cpu/sample_harness.cc
+++ b/tensorflow/compiler/xla/service/cpu/sample_harness.cc
@@ -51,7 +51,7 @@ int main(int argc, char** argv) {
xla::XlaBuilder builder("");
auto p0 = builder.Parameter(0, param0_literal->shape(), "param0");
auto p1 = builder.Parameter(1, param1_literal->shape(), "param1");
- auto add = builder.Add(p1, p0, {0});
+ builder.Add(p1, p0, {0});
xla::StatusOr<xla::XlaComputation> computation_status = builder.Build();
xla::XlaComputation computation = computation_status.ConsumeValueOrDie();
diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
index c4c90515ac..be772cfb7e 100644
--- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
+++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
@@ -127,13 +127,6 @@ SimpleOrcJIT::SimpleOrcJIT(const llvm::TargetOptions& target_options,
}
llvm::JITSymbol SimpleOrcJIT::ResolveRuntimeSymbol(const std::string& name) {
- if (const uint8* from_constant_pool =
- external_constant_pool_.Find(string(name))) {
- return llvm::JITEvaluatedSymbol(
- reinterpret_cast<uint64_t>(from_constant_pool),
- llvm::JITSymbolFlags::None);
- }
-
void* func_addr = CustomCallTargetRegistry::Global()->Lookup(name);
if (func_addr == nullptr) {
return nullptr;
diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h
index 1851a3ee0b..d74b63fcf4 100644
--- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h
+++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h
@@ -29,7 +29,6 @@ limitations under the License.
#include "llvm/Target/TargetMachine.h"
#include "tensorflow/compiler/xla/service/cpu/compiler_functor.h"
#include "tensorflow/compiler/xla/service/cpu/disassembler.h"
-#include "tensorflow/compiler/xla/service/cpu/external_constant_pool.h"
#include "tensorflow/compiler/xla/types.h"
namespace xla {
@@ -91,10 +90,6 @@ class SimpleOrcJIT {
llvm::TargetMachine* target_machine() const { return target_machine_.get(); }
- ExternalConstantPool* external_constant_pool() {
- return &external_constant_pool_;
- }
-
// Creates an llvm::TargetMachine suitable for JITting code that will run on
// the current machine.
static std::unique_ptr<llvm::TargetMachine> InferTargetMachineForJIT(
@@ -112,7 +107,6 @@ class SimpleOrcJIT {
std::shared_ptr<llvm::orc::SymbolResolver> symbol_resolver_;
ObjLayerT object_layer_;
CompileLayerT compile_layer_;
- ExternalConstantPool external_constant_pool_;
};
} // namespace cpu
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc
index 3a7255c1d2..1d4bf483ae 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc
@@ -56,7 +56,8 @@ class CpuExternalConstantsTest : public CpuCodegenTest {
TEST_F(CpuExternalConstantsTest, Basic) {
TestWithArray(/*rows=*/1024, /*cols=*/1024, R"(
-CHECK: @constant_global_0 = external constant [1024 x [1024 x float]], align 16
+CHECK-NOT: @constant_global_0 = external constant [1024 x [1024 x float]], align 16
+CHECK: @0 = private constant [4194304 x i8] {{.*}}, align 16
)");
}
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc
index 23e7a3de4d..783b2820e9 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc
@@ -96,8 +96,11 @@ TEST_F(CpuFusionTest, FuseElementwiseOpChain) {
HloInstruction::CreateUnary(vshape, HloOpcode::kExp, ceil));
auto floor = builder.AddInstruction(
HloInstruction::CreateUnary(vshape, HloOpcode::kFloor, exp));
- auto two = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ auto two = builder.AddInstruction(HloInstruction::CreateBroadcast(
+ vshape,
+ builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0))),
+ {}));
builder.AddInstruction(
HloInstruction::CreateBinary(vshape, HloOpcode::kMultiply, two, floor));
@@ -114,9 +117,9 @@ TEST_F(CpuFusionTest, FuseElementwiseOpChain) {
EXPECT_EQ(HloOpcode::kFusion, fusion_instruction->opcode());
EXPECT_EQ(HloOpcode::kMultiply,
fusion_instruction->fused_expression_root()->opcode());
- // There should be 7 fused instructions: 2 parameters and the fused
+ // There should be 8 fused instructions: 2 parameters and the fused
// operations.
- EXPECT_EQ(7, fusion_instruction->fused_instruction_count());
+ EXPECT_EQ(8, fusion_instruction->fused_instruction_count());
// Compile and execute the computation.
auto result = ExecuteAndTransfer(std::move(module), {});
@@ -170,8 +173,11 @@ TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusableInstruction) {
HloInstruction::CreateUnary(cshape, HloOpcode::kExp, reduce));
auto floor = builder.AddInstruction(
HloInstruction::CreateUnary(cshape, HloOpcode::kFloor, exp));
- auto two = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ auto two = builder.AddInstruction(HloInstruction::CreateBroadcast(
+ cshape,
+ builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0))),
+ {}));
builder.AddInstruction(
HloInstruction::CreateBinary(cshape, HloOpcode::kMultiply, two, floor));
@@ -188,9 +194,9 @@ TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusableInstruction) {
EXPECT_EQ(HloOpcode::kFusion, fusion_instruction1->opcode());
EXPECT_EQ(HloOpcode::kMultiply,
fusion_instruction1->fused_expression_root()->opcode());
- // There should be 5 fused instructions in the root fusion instruction: 2
+ // There should be 6 fused instructions in the root fusion instruction: 2
// parameters, multiply, floor, and exp.
- EXPECT_EQ(5, fusion_instruction1->fused_instruction_count())
+ EXPECT_EQ(6, fusion_instruction1->fused_instruction_count())
<< fusion_instruction1->fused_instructions_computation()->ToString();
auto fusion_instruction2 = reduce->operand(0);
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc
index 1739b6e8b7..90b99c828e 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc
@@ -38,7 +38,8 @@ while_body {
while_cond {
arg_cond = f32[2,3,2] parameter(0)
- ROOT unknown = pred[] infeed()
+ infeed = (pred[], token[]) infeed()
+ ROOT unknown = pred[] get-tuple-element((pred[], token[]) infeed), index=0
}
ENTRY main {
@@ -49,8 +50,8 @@ ENTRY main {
{{2, 1}, {2001, 3002}, {2001, 2002}}})
const_b = f32[2,3,2] while(f32[2,3,2] const_a), condition=while_cond, body=while_body
- out0 = () outfeed(f32[2,3,2] const_a)
- ROOT out1 = () outfeed(f32[2,3,2] const_b)
+ out0 = token[] outfeed(f32[2,3,2] const_a)
+ ROOT out1 = token[] outfeed(f32[2,3,2] const_b)
}
)";
@@ -84,7 +85,8 @@ while_body {
while_cond {
arg_cond = (f32[2,1]{1,0}, f32[1]{0}) parameter(0)
- ROOT unknown = pred[] infeed()
+ infeed = (pred[], token[]) infeed()
+ ROOT unknown = pred[] get-tuple-element((pred[], token[]) infeed), index=0
}
ENTRY main {
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc
index 40b4d0ed00..dac416e1c7 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc
@@ -32,7 +32,8 @@ ENTRY main {
{{{1, 2}, {1001, 1002}, {2001, 2002}},
{{2, 1}, {2001, 3002}, {2001, 2002}}})
- ROOT out = () outfeed(f32[2,3,2] const_a)
+ outfeed = token[] outfeed(f32[2,3,2] const_a)
+ ROOT root = () tuple()
}
)";
diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
index 7d56d57b5f..cb3676c5ba 100644
--- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
+++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
@@ -246,7 +246,7 @@ class DfsHloVisitorBase {
virtual Status HandleBatchNormGrad(HloInstructionPtr hlo) = 0;
- virtual Status HandleGenerateToken(HloInstructionPtr token) = 0;
+ virtual Status HandleAfterAll(HloInstructionPtr token) = 0;
// Invoked to inform the visitor that the traversal has completed, and that
// the root was "root".
diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h
index 6934e00a4b..987c91e5ba 100644
--- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h
+++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h
@@ -188,7 +188,7 @@ class DfsHloVisitorWithDefaultBase
Status HandleGather(HloInstructionPtr gather) override {
return DefaultAction(gather);
}
- Status HandleGenerateToken(HloInstructionPtr token) override {
+ Status HandleAfterAll(HloInstructionPtr token) override {
return DefaultAction(token);
}
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index af6d298589..2508755e4c 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -442,6 +442,7 @@ cc_library(
srcs = ["multi_output_fusion.cc"],
hdrs = ["multi_output_fusion.h"],
deps = [
+ ":ir_emission_utils",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:multi_output_fusion",
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.cc
index db6924c742..c77e3c81c9 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.cc
@@ -126,12 +126,17 @@ Status Visitor::HandleBatchNormTraining(HloInstruction* batch_norm) {
HloInstruction* variance_plus_epsilon =
computation_->AddInstruction(HloInstruction::CreateBinary(
inverse_stddev->shape(), HloOpcode::kPower, inverse_stddev,
- computation_->AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(-2)))));
+ computation_->AddInstruction(HloInstruction::CreateBroadcast(
+ inverse_stddev->shape(),
+ computation_->AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<float>(-2))),
+ {}))));
HloInstruction* variance =
computation_->AddInstruction(HloInstruction::CreateBinary(
variance_plus_epsilon->shape(), HloOpcode::kSubtract,
- variance_plus_epsilon, epsilon));
+ variance_plus_epsilon,
+ computation_->AddInstruction(HloInstruction::CreateBroadcast(
+ variance_plus_epsilon->shape(), epsilon, {}))));
// Repackage the results.
std::unique_ptr<HloInstruction> new_tuple = HloInstruction::CreateTuple({
@@ -175,12 +180,17 @@ Status Visitor::HandleBatchNormGrad(HloInstruction* batch_norm) {
HloInstruction* var_plus_epsilon =
computation_->AddInstruction(HloInstruction::CreateBinary(
batch_norm->operand(3)->shape(), HloOpcode::kAdd,
- batch_norm->mutable_operand(3), epsilon));
+ batch_norm->mutable_operand(3),
+ computation_->AddInstruction(HloInstruction::CreateBroadcast(
+ batch_norm->operand(3)->shape(), epsilon, {}))));
HloInstruction* inverse_stddev =
computation_->AddInstruction(HloInstruction::CreateBinary(
var_plus_epsilon->shape(), HloOpcode::kPower, var_plus_epsilon,
- computation_->AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(-.5)))));
+ computation_->AddInstruction(HloInstruction::CreateBroadcast(
+ var_plus_epsilon->shape(),
+ computation_->AddInstruction(HloInstruction::CreateConstant(
+ Literal::CreateR0<float>(-.5))),
+ {}))));
std::vector<HloInstruction*> operands(batch_norm->operands().begin(),
batch_norm->operands().end());
diff --git a/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc b/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc
index ea34d5b30c..2b63d8727c 100644
--- a/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc
@@ -22,29 +22,29 @@ namespace xla {
namespace gpu {
InfeedThunk::InfeedThunk(
- tensorflow::gtl::ArraySlice<BufferAllocation::Slice> tuple_element_buffers,
- const BufferAllocation::Slice& destination_buffer,
+ const ShapeTree<BufferAllocation::Slice>& infeed_slices,
const HloInstruction* hlo_instruction)
- : Thunk(Kind::kInfeed, hlo_instruction),
- tuple_element_buffers_(tuple_element_buffers.begin(),
- tuple_element_buffers.end()),
- destination_buffer_(destination_buffer) {}
+ : Thunk(Kind::kInfeed, hlo_instruction), infeed_slices_(infeed_slices) {}
Status InfeedThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
se::Stream* stream) {
VLOG(2) << "Infeeding to GPU ";
- se::DeviceMemoryBase destination_address =
- buffer_allocations.GetDeviceAddress(destination_buffer_);
-
+ // First copy the infeed data which is element 0 of the infeed instruction's
+ // two-tuple output (the other element is a token).
+ se::DeviceMemoryBase data_address =
+ buffer_allocations.GetDeviceAddress(infeed_slices_.element({0}));
InfeedManager* infeed_manager = GetOrCreateInfeedManager();
std::vector<InfeedBuffer*> infeed_buffers;
- if (ShapeUtil::IsTuple(hlo_instruction()->shape())) {
- CHECK(!ShapeUtil::IsNestedTuple(hlo_instruction()->shape()));
+ const Shape& data_shape =
+ ShapeUtil::GetTupleElementShape(hlo_instruction()->shape(), 0);
+ if (ShapeUtil::IsTuple(data_shape)) {
+ CHECK(!ShapeUtil::IsNestedTuple(data_shape));
// Transfer the tuple elements first.
std::vector<void*> tuple_element_addresses;
- for (BufferAllocation::Slice tuple_element_buffer :
- tuple_element_buffers_) {
+ for (int i = 0; i < ShapeUtil::TupleElementCount(data_shape); ++i) {
+ const BufferAllocation::Slice& tuple_element_buffer =
+ infeed_slices_.element({0, i});
se::DeviceMemoryBase tuple_element_address =
buffer_allocations.GetDeviceAddress(tuple_element_buffer);
@@ -56,15 +56,23 @@ Status InfeedThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
}
// Transfer the tuple outer buffer.
auto host_size = tuple_element_addresses.size() * sizeof(void*);
- stream->ThenMemcpy(&destination_address, tuple_element_addresses.data(),
+ stream->ThenMemcpy(&data_address, tuple_element_addresses.data(),
host_size);
} else {
InfeedBuffer* buffer = infeed_manager->BlockingDequeueBuffer();
infeed_buffers.push_back(buffer);
- stream->ThenMemcpy(&destination_address, *(buffer->device_memory()),
+ stream->ThenMemcpy(&data_address, *(buffer->device_memory()),
buffer->length());
}
+ // Construct top-level tuple of infeed containing the data and the token. Use
+ // a nullptr for the token, it should never be dereferenced.
+ std::vector<void*> infeed_addresses = {data_address.opaque(), nullptr};
+ se::DeviceMemoryBase top_level_address =
+ buffer_allocations.GetDeviceAddress(infeed_slices_.element({}));
+ stream->ThenMemcpy(&top_level_address, infeed_addresses.data(),
+ 2 * sizeof(void*));
+
Status block_status = stream->BlockHostUntilDone();
if (!block_status.ok()) {
return InternalError("Failed to complete data transfer on stream %p: %s",
diff --git a/tensorflow/compiler/xla/service/gpu/infeed_thunk.h b/tensorflow/compiler/xla/service/gpu/infeed_thunk.h
index 93713cb12d..cb9a6232f3 100644
--- a/tensorflow/compiler/xla/service/gpu/infeed_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/infeed_thunk.h
@@ -32,12 +32,8 @@ namespace gpu {
class InfeedThunk : public Thunk {
public:
// Constructs a InfeedThunk that copies data from the on-device
- // infeed queue to the device buffer
- // `destination_buffer`. `mem_size` is the size of the data in
- // bytes.
- InfeedThunk(tensorflow::gtl::ArraySlice<BufferAllocation::Slice>
- tuple_element_buffers,
- const BufferAllocation::Slice& destination_buffer,
+ // infeed queue into the buffers in the given shape tree.
+ InfeedThunk(const ShapeTree<BufferAllocation::Slice>& infeed_slices,
const HloInstruction* hlo_instruction);
InfeedThunk(const InfeedThunk&) = delete;
@@ -47,8 +43,7 @@ class InfeedThunk : public Thunk {
se::Stream* stream) override;
private:
- const std::vector<BufferAllocation::Slice> tuple_element_buffers_;
- const BufferAllocation::Slice destination_buffer_;
+ const ShapeTree<BufferAllocation::Slice> infeed_slices_;
};
} // namespace gpu
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
index efeb276470..d5e07c3afb 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
@@ -191,6 +191,8 @@ bool IrEmitter::MaybeEmitDirectAtomicOperation(
HloOpcode root_opcode = computation.root_instruction()->opcode();
PrimitiveType element_type =
computation.root_instruction()->shape().element_type();
+ bool is_atomic_integral = element_type == S32 || element_type == U32 ||
+ element_type == S64 || element_type == U64;
llvm::Value* source = ir_builder_.CreateLoad(source_address, "source");
if (root_opcode == HloOpcode::kAdd) {
// NVPTX supports atomicAdd on F32 and integer types.
@@ -201,7 +203,7 @@ bool IrEmitter::MaybeEmitDirectAtomicOperation(
{output_address->getType()}, &ir_builder_);
return true;
}
- if (primitive_util::IsIntegralType(element_type)) {
+ if (is_atomic_integral) {
// integral + integral
ir_builder_.CreateAtomicRMW(llvm::AtomicRMWInst::Add, output_address,
source,
@@ -210,9 +212,8 @@ bool IrEmitter::MaybeEmitDirectAtomicOperation(
}
}
- // NVPTX supports atomicMax and atomicMin on only integer types.
- if (root_opcode == HloOpcode::kMaximum &&
- primitive_util::IsIntegralType(element_type)) {
+ // NVPTX supports atomicMax and atomicMin only on integer types.
+ if (root_opcode == HloOpcode::kMaximum && is_atomic_integral) {
// max(integral, integral)
auto opcode = primitive_util::IsSignedIntegralType(element_type)
? llvm::AtomicRMWInst::Max
@@ -222,8 +223,7 @@ bool IrEmitter::MaybeEmitDirectAtomicOperation(
return true;
}
- if (root_opcode == HloOpcode::kMinimum &&
- primitive_util::IsIntegralType(element_type)) {
+ if (root_opcode == HloOpcode::kMinimum && is_atomic_integral) {
// min(integral, integral)
auto opcode = primitive_util::IsSignedIntegralType(element_type)
? llvm::AtomicRMWInst::Min
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index f6f0a45124..fbd647f251 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -615,6 +615,8 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
output_shape_index = {i};
}
if (inst->opcode() == HloOpcode::kReduce) {
+ CHECK(IsReductionToVector(*inst))
+ << "Only reductions to vector are supported";
// Shapes, layouts and dimensions must be the same for all reduces
// inside of this fusion.
CHECK(ShapeUtil::Equal(first_reduce->shape(), inst->shape()));
@@ -1970,10 +1972,8 @@ Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) {
HloComputation* reducer = reduce->to_apply();
// HandleReduce specializes reduction from a multi-dimensional array to a 1D
// array. The specialized version requires an initializer thunk that
- // ingitializes the output array to the initial value of the reduce.
- if (IsReductionToVector(*reduce) &&
- // NVPTX backend can't do atomic cmpxchg any narrower than 32 bits
- 32 <= primitive_util::BitWidth(reduce->shape().element_type())) {
+ // initializes the output array to the initial value of the reduce.
+ if (IsReductionToVector(*reduce)) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<Thunk> initializer_thunk,
BuildInitializerThunk(reduce));
std::vector<std::unique_ptr<Thunk>> thunks;
@@ -2311,7 +2311,7 @@ Status IrEmitterUnnested::HandleCrossReplicaSum(HloInstruction* crs) {
return Status::OK();
}
-Status IrEmitterUnnested::HandleGenerateToken(HloInstruction* gen_token) {
+Status IrEmitterUnnested::HandleAfterAll(HloInstruction* gen_token) {
return Status::OK();
}
@@ -2563,17 +2563,14 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildInfeedThunk(
const HloInstruction* inst) {
CHECK_EQ(HloOpcode::kInfeed, inst->opcode());
- std::vector<BufferAllocation::Slice> tuple_element_buffers;
- for (int64 i = 0; i < inst->shape().tuple_shapes_size(); ++i) {
- BufferAllocation::Slice buffer = ir_emitter_context_->buffer_assignment()
- .GetUniqueSlice(inst, {i})
- .ConsumeValueOrDie();
- tuple_element_buffers.push_back(buffer);
- }
-
- return MakeUnique<InfeedThunk>(
- tuple_element_buffers,
- /*destination_buffer=*/GetAllocationSlice(*inst), inst);
+ ShapeTree<BufferAllocation::Slice> slices(inst->shape());
+ slices.ForEachMutableElement(
+ [this, inst](const ShapeIndex& index, BufferAllocation::Slice* slice) {
+ *slice = ir_emitter_context_->buffer_assignment()
+ .GetUniqueSlice(inst, index)
+ .ConsumeValueOrDie();
+ });
+ return MakeUnique<InfeedThunk>(slices, inst);
}
namespace {
@@ -2718,7 +2715,7 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
uint8 b = literal_bytes.front();
pattern16 = uint16{b} | (uint16{b} << 8);
} else {
- pattern16 = literal_bytes.front();
+ memcpy(&pattern16, literal_bytes.data(), sizeof(pattern16));
}
uint32 pattern32 = uint32{pattern16} | (uint32{pattern16} << 16);
return {MakeUnique<Memset32BitValueThunk>(
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
index 279a5c386a..819060061a 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
@@ -76,7 +76,7 @@ class IrEmitterUnnested : public IrEmitter {
Status HandleRng(HloInstruction* random) override;
Status HandleSelect(HloInstruction* select) override;
Status HandleCrossReplicaSum(HloInstruction* crs) override;
- Status HandleGenerateToken(HloInstruction* gen_token) override;
+ Status HandleAfterAll(HloInstruction* gen_token) override;
Status EmitTargetElementLoop(
const HloInstruction& hlo,
diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc
index d541776f00..652b5c7687 100644
--- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc
+++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc
@@ -23,9 +23,11 @@ limitations under the License.
#include <string>
#include <utility>
+#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
@@ -69,6 +71,7 @@ bool GpuMultiOutputFusion::ShapesCompatibleForFusion(HloInstruction* instr1,
// In that case, the operand of the reduce needs to have the same shape
// as the other tuple operands, but also we need to compare the output
// shapes of the reduces.
+ // TODO(tjoerg): Allow differences in fp precision.
auto* element_instr_1 = get_element_instr(instr1);
auto* element_instr_2 = get_element_instr(instr2);
if (element_instr_1->opcode() == HloOpcode::kReduce &&
@@ -82,26 +85,33 @@ bool GpuMultiOutputFusion::ShapesCompatibleForFusion(HloInstruction* instr1,
}
namespace {
-bool IsReduction(HloInstruction* instr) {
+bool IsInputFusibleReduction(HloInstruction* instr) {
if (instr->IsMultiOutputFusion()) {
for (const HloInstruction* operand :
instr->fused_expression_root()->operands()) {
if (operand->opcode() == HloOpcode::kReduce) {
+ CHECK(instr->fusion_kind() == HloInstruction::FusionKind::kInput)
+ << " Reduce multi-output fusion " << instr->ToString()
+ << " must be an input fusion.";
return true;
}
}
return false;
} else if (instr->opcode() == HloOpcode::kFusion) {
- return instr->fused_expression_root()->opcode() == HloOpcode::kReduce;
+ // The loop emitter can handle to-vector reduce fusions. Such reduce
+ // fusions have the fusion kind kLoop rather than kInput. We do not fuse
+ // to-vector reduce fusions, because the resulting fusions may no longer be
+ // supported by the loop emitter.
+ return IsReductionToVector(*instr->fused_expression_root());
} else {
- return instr->opcode() == HloOpcode::kReduce;
+ return IsReductionToVector(*instr);
}
}
} // namespace
bool GpuMultiOutputFusion::IsFusible(HloInstruction* instr) {
// We can fuse reduces and loop fusions.
- return IsReduction(instr) ||
+ return IsInputFusibleReduction(instr) ||
(instr->opcode() == HloOpcode::kFusion &&
instr->fusion_kind() == HloInstruction::FusionKind::kLoop &&
// TODO(b/110202584): bitcasts make nested fusions, GPU has no support
@@ -147,5 +157,110 @@ bool GpuMultiOutputFusion::LegalToFuse(HloInstruction* instr1,
return instr1->fusion_kind() != HloInstruction::FusionKind::kLoop;
}
+bool GpuMultiOutputFusion::DoProducerConsumerMultiOutputFusion() {
+ bool changed = false;
+ RecomputeReachability();
+
+ tensorflow::gtl::FlatSet<HloInstruction*> to_fuse;
+ // Keep a list of the instructions to fuse after making all the fusion
+ // decisions. We first aggressively add instructions to potential_fusion_list,
+ // then filter out instructions that will be no longer fusable because of
+ // reachability change. This avoids recalculating reachability on a large set
+ // of instructions.
+ std::vector<std::pair<HloInstruction*, HloInstruction*>>
+ potential_fusion_list;
+ std::vector<std::pair<HloInstruction*, HloInstruction*>> fusion_list;
+ std::vector<HloInstruction*> instrs_to_update_reachability;
+
+ // For each reduce or reduce multi-output fusion, try to fuse it with its
+ // loop fusion operands.
+ for (HloInstruction* consumer : computation()->MakeInstructionPostOrder()) {
+ if (consumer->user_count() == 0) {
+ continue;
+ }
+ if (!IsInputFusibleReduction(consumer)) {
+ continue;
+ }
+
+ auto consumer_operands = consumer->operands();
+ for (size_t i = 0; i < consumer_operands.size(); ++i) {
+ HloInstruction* producer = consumer_operands[i];
+ if (!producer->IsFusable()) {
+ continue;
+ }
+ const bool is_loop_fusion =
+ producer->opcode() == HloOpcode::kFusion &&
+ producer->fusion_kind() == HloInstruction::FusionKind::kLoop;
+ if (!is_loop_fusion) {
+ continue;
+ }
+ if (!ShapesCompatibleForFusion(producer, consumer)) {
+ continue;
+ }
+ // If we have already decided to fuse this producer, skip it.
+ if (ContainsKey(to_fuse, producer)) {
+ continue;
+ }
+ // Do not fuse a producer if the other operands of the fusion are
+ // reachable from the producer, as this would create a cycle.
+ if (c_any_of(consumer_operands, [&](HloInstruction* operand) {
+ return producer != operand &&
+ reachability()->IsReachable(producer, operand);
+ })) {
+ break;
+ }
+ to_fuse.insert(producer);
+ potential_fusion_list.emplace_back(producer, consumer);
+ instrs_to_update_reachability.push_back(producer);
+ instrs_to_update_reachability.push_back(consumer);
+ break;
+ }
+ }
+
+ // Filter out pairs that will no longer be fusable because of reachability
+ // change.
+ for (auto& fusion_pair : potential_fusion_list) {
+ HloInstruction* producer = fusion_pair.first;
+ HloInstruction* consumer = fusion_pair.second;
+ if (!c_any_of(consumer->operands(), [&](HloInstruction* operand) {
+ return producer != operand &&
+ reachability()->IsReachable(producer, operand);
+ })) {
+ UpdateReachability(producer, consumer, instrs_to_update_reachability);
+ fusion_list.push_back(fusion_pair);
+ }
+ }
+
+ for (auto fusions_to_create : fusion_list) {
+ HloInstruction* producer = fusions_to_create.first;
+ HloInstruction* consumer = fusions_to_create.second;
+ if (consumer->opcode() != HloOpcode::kFusion) {
+ // Fusing with a reduce (fusion) always results in an input fusion.
+ HloInstruction* input_fusion =
+ computation()->AddInstruction(HloInstruction::CreateFusion(
+ consumer->shape(), HloInstruction::FusionKind::kInput, consumer));
+ VLOG(2) << "Fuse producer " << producer->name() << " and its consumer "
+ << consumer->name() << " into " << input_fusion->name();
+ TF_CHECK_OK(computation()->ReplaceInstruction(consumer, input_fusion));
+ if (producer->opcode() == HloOpcode::kFusion) {
+ input_fusion->MergeFusionInstructionIntoMultiOutput(producer);
+ } else {
+ input_fusion->FuseInstructionIntoMultiOutput(producer);
+ }
+ } else {
+ VLOG(2) << "Fuse producer " << producer->name() << " into its consumer "
+ << consumer->name();
+
+ if (producer->opcode() == HloOpcode::kFusion) {
+ consumer->MergeFusionInstructionIntoMultiOutput(producer);
+ } else {
+ consumer->FuseInstructionIntoMultiOutput(producer);
+ }
+ }
+ changed = true;
+ }
+ return changed;
+}
+
} // namespace gpu
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h
index 16db0e0f02..67ca5d49ee 100644
--- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h
+++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h
@@ -45,6 +45,9 @@ class GpuMultiOutputFusion : public MultiOutputFusion {
// Test if it's legal to fuse instr1 and instr2 into one fusion instruction.
bool LegalToFuse(HloInstruction* instr1, HloInstruction* instr2) override;
+
+ // Fuse loop fusions into reduce fusions.
+ bool DoProducerConsumerMultiOutputFusion() override;
};
} // namespace gpu
diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc
index 5e7ceb7976..979ea79243 100644
--- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc
@@ -255,5 +255,99 @@ TEST_F(InstructionFusionTest, MultiOutputFusionTwoLoops) {
op::Tuple(op::Multiply(), op::Divide()));
}
+TEST_F(InstructionFusionTest, ProducerConsumerFusionLoopFusionAndReduce) {
+ auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
+ fused_add {
+ p0.1 = f32[2,2,2]{2,1,0} parameter(0)
+ p1.1 = f32[2,2,2]{2,1,0} parameter(1)
+ ROOT add = f32[2,2,2]{2,1,0} add(p0.1, p1.1)
+ }
+
+ ENTRY reduce {
+ p0 = f32[2,2,2]{2,1,0} parameter(0)
+ p1 = f32[2,2,2]{2,1,0} parameter(1)
+ c0 = f32[] constant(0)
+ add = f32[2,2,2]{2,1,0} fusion(p0, p1), kind=kLoop, calls=fused_add
+ reduce = f32[2,2]{1,0} reduce(add, c0), dimensions={2}, to_apply=scalar_add_computation
+ ROOT root = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}) tuple(reduce, add)
+ })"))
+ .ValueOrDie();
+ ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
+ SCOPED_TRACE(module->ToString());
+ const HloInstruction* root = module->entry_computation()->root_instruction();
+ EXPECT_THAT(root, op::Tuple(op::GetTupleElement(), op::GetTupleElement()));
+ const HloInstruction* fusion = root->operand(0)->operand(0);
+ ASSERT_TRUE(fusion->IsMultiOutputFusion());
+ EXPECT_THAT(fusion->fused_expression_root(),
+ op::Tuple(op::Reduce(), op::Add()));
+}
+
+TEST_F(InstructionFusionTest, ProducerConsumerFusionLoopFusionAndReduceFusion) {
+ auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
+ fused_select {
+ p1.1 = f32[2,2,2]{2,1,0} parameter(1)
+ c0 = f32[] constant(0)
+ broadcast = f32[2,2,2]{2,1,0} broadcast(f32[] c0), dimensions={}
+ greater-than = pred[2,2,2]{2,1,0} greater-than(f32[2,2,2]{2,1,0} p1.1, f32[2,2,2]{2,1,0} broadcast)
+ p0.1 = f32[2,2,2]{2,1,0} parameter(0)
+ ROOT select = f32[2,2,2]{2,1,0} select(pred[2,2,2]{2,1,0} greater-than, f32[2,2,2]{2,1,0} p0.1, f32[2,2,2]{2,1,0} broadcast)
+ }
+
+ fused_reduce {
+ p0.2 = f32[2,2,2]{2,1,0} parameter(0)
+ c1 = f32[] constant(0)
+ r1 = f32[2,2]{1,0} reduce(p0.2, c1), dimensions={2}, to_apply=scalar_add_computation
+ mul = f32[2,2,2]{2,1,0} multiply(p0.2, p0.2)
+ r2 = f32[2,2]{1,0} reduce(mul, c1), dimensions={2}, to_apply=scalar_add_computation
+ ROOT tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(r1, r2)
+ }
+
+ ENTRY reduce {
+ p0 = f32[2,2,2]{2,1,0} parameter(0)
+ p1 = f32[2,2,2]{2,1,0} parameter(1)
+ select = f32[2,2,2]{2,1,0} fusion(p0, p1), kind=kLoop, calls=fused_select
+ fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(select), kind=kInput, calls=fused_reduce
+ gte0 = f32[2,2]{1,0} get-tuple-element(fusion), index=0
+ gte1 = f32[2,2]{1,0} get-tuple-element(fusion), index=1
+ ROOT root = (f32[2,2]{1,0}, f32[2,2]{1,0}, f32[2,2,2]{2,1,0}) tuple(gte1, gte1, select)
+ })"))
+ .ValueOrDie();
+ ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
+ SCOPED_TRACE(module->ToString());
+ const HloInstruction* root = module->entry_computation()->root_instruction();
+ EXPECT_THAT(root, op::Tuple(op::GetTupleElement(), op::GetTupleElement(),
+ op::GetTupleElement()));
+ const HloInstruction* fusion = root->operand(0)->operand(0);
+ ASSERT_TRUE(fusion->IsMultiOutputFusion());
+ EXPECT_THAT(fusion->fused_expression_root(),
+ op::Tuple(op::Reduce(), op::Reduce(), op::Select()));
+}
+
+TEST_F(InstructionFusionTest, ProducerConsumerFusionDoNotFuseLoopReduceFusion) {
+ auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
+ fused_element_wise {
+ p0.1 = f32[2,2,2]{2,1,0} parameter(0)
+ p1.1 = f32[2,2,2]{2,1,0} parameter(1)
+ ROOT root = f32[2,2,2]{2,1,0} add(p0.1, p1.1)
+ }
+
+ fused_reduce {
+ p0.2 = f32[2,2,2]{2,1,0} parameter(0)
+ mul = f32[2,2,2]{2,1,0} multiply(f32[2,2,2]{2,1,0} p0.2, f32[2,2,2]{2,1,0} p0.2)
+ c1 = f32[] constant(0)
+ ROOT reduce = f32[2,2]{1,0} reduce(f32[2,2,2]{2,1,0} mul, f32[] c1), dimensions={1}, to_apply=scalar_add_computation
+ }
+
+ ENTRY reduce {
+ p0 = f32[2,2,2]{2,1,0} parameter(0)
+ p1 = f32[2,2,2]{2,1,0} parameter(1)
+ element_wise = f32[2,2,2]{2,1,0} fusion(p0, p1), kind=kLoop, calls=fused_element_wise
+ fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(element_wise), kind=kLoop, calls=fused_reduce
+ ROOT root = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}) tuple(fusion, element_wise)
+ })"))
+ .ValueOrDie();
+ ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
+}
+
} // namespace gpu
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc
index c057be8201..34b18b0e21 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation.cc
@@ -120,6 +120,30 @@ HloInstruction* HloComputation::AddParameter(
return instructions_.back().get();
}
+namespace {
+
+// Returns the new name for a fusion parameter when we change its number.
+//
+// Fusion parameters are named foo.param_1, bar.param_2, etc. We are
+// renumbering the parameters, so replace the final number in the name with
+// the updated value.
+string RenameFusionParameter(const string& original_name, int64 new_param_no) {
+ const string param_underscore = ".param_";
+ size_t index = original_name.rfind(param_underscore);
+ if (index == string::npos) {
+ return original_name;
+ }
+ string after_param = original_name.substr(index + param_underscore.size());
+ int64 numeric_suffix;
+ if (tensorflow::strings::safe_strto64(after_param, &numeric_suffix)) {
+ return StrCat(original_name.substr(0, index + param_underscore.size()),
+ new_param_no);
+ }
+ return original_name;
+}
+
+} // namespace
+
Status HloComputation::RemoveParameter(int64 param_no) {
CHECK_GE(param_no, 0);
CHECK_LT(param_no, param_instructions_.size());
@@ -132,21 +156,8 @@ Status HloComputation::RemoveParameter(int64 param_no) {
while (param_no < param_instructions_.size()) {
param_instruction = param_instructions_[param_no];
- string param_name = param_instruction->name();
- // Fusion parameters are named foo.param_1, bar.param_2, etc. We are
- // renumbering the parameters, so replace the final number in the name with
- // the updated value.
- const string param_underscore = ".param_";
- size_t index = param_name.rfind(param_underscore);
- if (index == string::npos) {
- string after_param = name().substr(index + param_underscore.size());
- int64 numeric_suffix;
- if (tensorflow::strings::safe_strto64(after_param, &numeric_suffix)) {
- param_name =
- StrCat(param_name.substr(0, index), param_underscore, param_no);
- }
- }
-
+ string param_name =
+ RenameFusionParameter(param_instruction->name(), param_no);
HloInstruction* new_instr =
AddInstructionInternal(HloInstruction::CreateParameter(
param_no, param_instruction->shape(), param_name));
@@ -159,6 +170,34 @@ Status HloComputation::RemoveParameter(int64 param_no) {
return Status::OK();
}
+Status HloComputation::RemoveUnusedParameters() {
+ CHECK(IsFusionComputation());
+ int64 removed = 0;
+ for (int64 i = 0; i < param_instructions_.size(); ++i) {
+ HloInstruction* param_instruction = param_instructions_[i];
+ if (param_instruction->user_count() == 0 &&
+ param_instruction != root_instruction()) {
+ TF_RETURN_IF_ERROR(RemoveInstruction(param_instruction));
+ ++removed;
+ continue;
+ }
+
+ if (removed > 0) {
+ const int64 param_no = i - removed;
+ string param_name =
+ RenameFusionParameter(param_instruction->name(), param_no);
+ HloInstruction* new_instr =
+ AddInstructionInternal(HloInstruction::CreateParameter(
+ param_no, param_instruction->shape(), param_name));
+ TF_RETURN_IF_ERROR(param_instruction->ReplaceAllUsesWith(new_instr));
+ param_instructions_[param_no] = new_instr;
+ TF_RETURN_IF_ERROR(RemoveInstruction(param_instruction));
+ }
+ }
+ param_instructions_.resize(param_instructions_.size() - removed);
+ return Status::OK();
+}
+
bool HloComputation::IsRemovable(const HloInstruction* instruction) {
// If the instruction has control predecessors or successors then we cannot
// remove the instruction without violating ordering constraints (added, for
diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h
index 0f111a1a76..c1c3e79ebc 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.h
+++ b/tensorflow/compiler/xla/service/hlo_computation.h
@@ -113,6 +113,11 @@ class HloComputation {
// instruction.
Status RemoveParameter(int64 param_no);
+ // Removes unused parameters from the computation.
+ // Note that this is only applicable to the computation for the fusion
+ // instruction.
+ Status RemoveUnusedParameters();
+
// Add new parameter instruction to the computation.
// This should be a new parameter. Instruction will be appended to parameters
// and inserted to the instruction list.
diff --git a/tensorflow/compiler/xla/service/hlo_computation_test.cc b/tensorflow/compiler/xla/service/hlo_computation_test.cc
index c504fc51d2..a8f3f0e9c2 100644
--- a/tensorflow/compiler/xla/service/hlo_computation_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation_test.cc
@@ -375,20 +375,20 @@ TEST_F(HloComputationTest, DeepCopyToken) {
// Test that DeepCopyInstruction properly handles tokens which should not be
// copied.
auto builder = HloComputation::Builder(TestName());
- auto token = builder.AddInstruction(HloInstruction::CreateGenerateToken({}));
+ auto token = builder.AddInstruction(HloInstruction::CreateAfterAll({}));
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
auto copy = computation->DeepCopyInstruction(token).ValueOrDie();
// No copy should be added.
- EXPECT_THAT(copy, op::GenerateToken());
+ EXPECT_THAT(copy, op::AfterAll());
}
TEST_F(HloComputationTest, DeepCopyTokenTuple) {
// Test that DeepCopyInstruction properly handles tokens which should not be
// copied.
auto builder = HloComputation::Builder(TestName());
- auto token = builder.AddInstruction(HloInstruction::CreateGenerateToken({}));
+ auto token = builder.AddInstruction(HloInstruction::CreateAfterAll({}));
auto constant = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0)));
auto tuple =
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
index 762e1afc71..8955e26d5c 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
@@ -393,7 +393,7 @@ Status HloCostAnalysis::HandleTranspose(const HloInstruction*) {
return Status::OK();
}
-Status HloCostAnalysis::HandleGenerateToken(const HloInstruction*) {
+Status HloCostAnalysis::HandleAfterAll(const HloInstruction*) {
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h
index 0d66736fe1..44e5df587c 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h
@@ -97,7 +97,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor {
Status HandleBroadcast(const HloInstruction* broadcast) override;
Status HandlePad(const HloInstruction* pad) override;
Status HandleReshape(const HloInstruction* reshape) override;
- Status HandleGenerateToken(const HloInstruction* token) override;
+ Status HandleAfterAll(const HloInstruction* token) override;
Status HandleTranspose(const HloInstruction* transpose) override;
Status HandleWhile(const HloInstruction* xla_while) override;
Status HandleConditional(const HloInstruction* conditional) override;
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc
index d22bef5673..f77e880a77 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc
@@ -139,7 +139,7 @@ TEST_F(HloCostAnalysisTest, MatrixMultiply) {
XlaBuilder builder("matrix_multiply");
auto lhs = builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 5}), "lhs");
auto rhs = builder.Parameter(1, ShapeUtil::MakeShape(F32, {5, 30}), "rhs");
- auto result = builder.Dot(lhs, rhs);
+ builder.Dot(lhs, rhs);
// Run HLO cost analysis.
auto hlo_module = BuildHloGraph(&builder);
@@ -160,7 +160,7 @@ TEST_F(HloCostAnalysisTest, MatrixMultiply) {
TEST_F(HloCostAnalysisTest, Map) {
XlaBuilder builder("map");
auto input = builder.Parameter(0, ShapeUtil::MakeShape(F32, {10}), "in");
- auto result = builder.Map({input}, add_and_exp_, {0});
+ builder.Map({input}, add_and_exp_, {0});
// Run HLO cost analysis.
auto hlo_module = BuildHloGraph(&builder);
@@ -186,7 +186,7 @@ TEST_F(HloCostAnalysisTest, Convolution) {
ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/3,
/*x_dim=*/3}),
"kernel");
- auto result = builder.Conv(input, kernel, {1, 1}, Padding::kValid);
+ builder.Conv(input, kernel, {1, 1}, Padding::kValid);
// Run HLO cost analysis.
auto hlo_module = BuildHloGraph(&builder);
@@ -207,8 +207,7 @@ TEST_F(HloCostAnalysisTest, Reduce) {
XlaBuilder builder("reduce");
auto input =
builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 20}), "input");
- auto result =
- builder.Reduce(input, builder.ConstantR0<float>(0.0f), add_, {1});
+ builder.Reduce(input, builder.ConstantR0<float>(0.0f), add_, {1});
// Run HLO cost analysis.
auto hlo_module = BuildHloGraph(&builder);
@@ -225,8 +224,8 @@ TEST_F(HloCostAnalysisTest, ReduceWindow) {
XlaBuilder builder("reduce_window");
auto input =
builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 20}), "input");
- auto result = builder.ReduceWindow(input, builder.ConstantR0<float>(0), add_,
- {4, 5}, {4, 5}, Padding::kValid);
+ builder.ReduceWindow(input, builder.ConstantR0<float>(0), add_, {4, 5},
+ {4, 5}, Padding::kValid);
// Run HLO cost analysis.
auto hlo_module = BuildHloGraph(&builder);
@@ -244,9 +243,8 @@ TEST_F(HloCostAnalysisTest, SelectAndScatter) {
builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 20}), "input");
auto source =
builder.Parameter(1, ShapeUtil::MakeShape(F32, {2, 4}), "source");
- auto result =
- builder.SelectAndScatter(operand, gt_, {4, 5}, {4, 5}, Padding::kValid,
- source, builder.ConstantR0<float>(0), add_);
+ builder.SelectAndScatter(operand, gt_, {4, 5}, {4, 5}, Padding::kValid,
+ source, builder.ConstantR0<float>(0), add_);
// Run HLO cost analysis.
auto hlo_module = BuildHloGraph(&builder);
@@ -278,8 +276,8 @@ TEST_F(HloCostAnalysisTest, FullyConnectedForward) {
builder.Parameter(1, ShapeUtil::MakeShape(F32, {5, 20}), "weight");
auto bias = builder.Parameter(2, ShapeUtil::MakeShape(F32, {20}), "bias");
// sigmoid(input * weight + bias)
- auto result = builder.Map(
- {builder.Add(builder.Dot(input, weight), bias, {1})}, sigmoid_, {0, 1});
+ builder.Map({builder.Add(builder.Dot(input, weight), bias, {1})}, sigmoid_,
+ {0, 1});
// Run HLO cost analysis.
auto hlo_module = BuildHloGraph(&builder);
@@ -421,7 +419,7 @@ TEST_F(HloCostAnalysisTest, TupleCost) {
XlaBuilder builder("matmul");
auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {123}), "x");
auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {42}), "y");
- auto tuple = builder.Tuple({x, y});
+ builder.Tuple({x, y});
auto hlo_module = BuildHloGraph(&builder);
ASSERT_IS_OK(
@@ -446,10 +444,10 @@ TEST_F(HloCostAnalysisTest, BaseDilatedConvolution) {
/*x_dim=*/3}),
"kernel");
- auto result = builder.ConvGeneralDilated(
- input, kernel, /*window_strides=*/{1, 1}, /*padding=*/{{1, 1}, {1, 1}},
- /*lhs_dilation=*/{3, 5}, /*rhs_dilation=*/{7, 11},
- XlaBuilder::CreateDefaultConvDimensionNumbers(2));
+ builder.ConvGeneralDilated(input, kernel, /*window_strides=*/{1, 1},
+ /*padding=*/{{1, 1}, {1, 1}},
+ /*lhs_dilation=*/{3, 5}, /*rhs_dilation=*/{7, 11},
+ XlaBuilder::CreateDefaultConvDimensionNumbers(2));
// Run HLO cost analysis.
auto hlo_module = BuildHloGraph(&builder);
@@ -464,7 +462,7 @@ TEST_F(HloCostAnalysisTest, Slice) {
// Test the analysis on a slice.
XlaBuilder builder("slice");
auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2}), "x");
- auto slice = builder.Slice(x, {0}, {1}, {1});
+ builder.Slice(x, {0}, {1}, {1});
auto hlo_module = BuildHloGraph(&builder);
// Run HLO cost analysis.
@@ -479,7 +477,7 @@ TEST_F(HloCostAnalysisTest, DynamicSlice) {
// Test the analysis on a slice.
XlaBuilder builder("dynamic-slice");
auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2}), "x");
- auto slice = builder.DynamicSlice(x, builder.ConstantR1<int32>({1}), {1});
+ builder.DynamicSlice(x, builder.ConstantR1<int32>({1}), {1});
auto hlo_module = BuildHloGraph(&builder);
// Run HLO cost analysis.
@@ -494,8 +492,8 @@ TEST_F(HloCostAnalysisTest, DynamicUpdateSlice) {
// Test the analysis on a slice.
XlaBuilder builder("dynamic-update-slice");
auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2}), "x");
- auto slice = builder.DynamicUpdateSlice(x, builder.ConstantR1<float>({1.0}),
- builder.ConstantR1<int32>({1}));
+ builder.DynamicUpdateSlice(x, builder.ConstantR1<float>({1.0}),
+ builder.ConstantR1<int32>({1}));
auto hlo_module = BuildHloGraph(&builder);
// Run HLO cost analysis.
diff --git a/tensorflow/compiler/xla/service/hlo_dce_test.cc b/tensorflow/compiler/xla/service/hlo_dce_test.cc
index 5a56607a66..2822ecd788 100644
--- a/tensorflow/compiler/xla/service/hlo_dce_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_dce_test.cc
@@ -234,9 +234,10 @@ TEST_F(HloDceTest, CalledComputationWithSideEffect) {
{
auto param = body_builder.AddInstruction(
HloInstruction::CreateParameter(0, shape, "param"));
-
- auto infeed =
- body_builder.AddInstruction(HloInstruction::CreateInfeed(shape, ""));
+ auto token =
+ body_builder.AddInstruction(HloInstruction::CreateAfterAll({}));
+ auto infeed = body_builder.AddInstruction(
+ HloInstruction::CreateInfeed(shape, token, ""));
body_builder.AddInstruction(
HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param, infeed));
}
@@ -278,8 +279,10 @@ TEST_F(HloDceTest, CalledComputationWithNestedSideEffect) {
{
auto param = nested_callee_builder.AddInstruction(
HloInstruction::CreateParameter(0, shape, "param"));
+ auto token = nested_callee_builder.AddInstruction(
+ HloInstruction::CreateAfterAll({}));
nested_callee_builder.AddInstruction(
- HloInstruction::CreateOutfeed(shape, param, ""));
+ HloInstruction::CreateOutfeed(shape, param, token, ""));
}
auto nested_called_computation =
module->AddEmbeddedComputation(nested_callee_builder.Build());
diff --git a/tensorflow/compiler/xla/service/hlo_domain_test.cc b/tensorflow/compiler/xla/service/hlo_domain_test.cc
index 5d8081c1ef..ff356bdd6d 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_domain_test.cc
@@ -340,10 +340,12 @@ TEST_F(HloDomainTest, CheckNormalizationOnInfeedTuple) {
HloModule Module
ENTRY entry {
- infeed = (f32[4], f32[4]) infeed(),
- sharding={{maximal device=1}, {maximal device=0}}
- gte0 = f32[4] get-tuple-element(infeed), index=0
- gte1 = f32[4] get-tuple-element(infeed), index=1
+ token = token[] after-all()
+ infeed = ((f32[4], f32[4]), token[]) infeed(token),
+ sharding={{maximal device=1}, {maximal device=0}, {maximal device=0}}
+ infeed.data = (f32[4], f32[4]) get-tuple-element(infeed), index=0
+ gte0 = f32[4] get-tuple-element(infeed.data), index=0
+ gte1 = f32[4] get-tuple-element(infeed.data), index=1
copy0 = f32[4] copy(gte0)
copy1 = f32[4] copy(gte1)
ROOT add = f32[4] add(copy0, copy1)
@@ -357,8 +359,7 @@ ENTRY entry {
TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module));
EXPECT_TRUE(isolator_changed);
- EXPECT_TRUE(HasDomainEdge(module, "gte0", "infeed"));
- EXPECT_TRUE(HasDomainEdge(module, "gte1", "infeed"));
+ EXPECT_TRUE(HasDomainEdge(module, "infeed.data", "infeed"));
EXPECT_FALSE(HasDomainEdge(module, "copy0", "gte0"));
EXPECT_FALSE(HasDomainEdge(module, "copy1", "gte1"));
@@ -366,6 +367,8 @@ ENTRY entry {
// HLO passes adding unexpected instructions.
//
// infeed
+ // |
+ // infeed.data (tuple element 0 of infeed)
// / \
// GTE0 GTE1
// / \
@@ -374,26 +377,31 @@ ENTRY entry {
// \ /
// TUPLE
// |
- // DOMAIN
HloInstruction* infeed = FindInstruction(module, "infeed");
ASSERT_NE(infeed, nullptr);
- auto infeed_users = infeed->users();
- HloInstruction* new_gte0 =
+ HloInstruction* infeed_data =
infeed->parent()->AddInstruction(HloInstruction::CreateGetTupleElement(
ShapeUtil::GetTupleElementShape(infeed->shape(), 0), infeed, 0));
+
+ auto infeed_data_users = infeed_data->users();
+ HloInstruction* new_gte0 = infeed_data->parent()->AddInstruction(
+ HloInstruction::CreateGetTupleElement(
+ ShapeUtil::GetTupleElementShape(infeed_data->shape(), 0), infeed_data,
+ 0));
HloInstruction* new_copy0 =
- infeed->parent()->AddInstruction(HloInstruction::CreateUnary(
+ infeed_data->parent()->AddInstruction(HloInstruction::CreateUnary(
new_gte0->shape(), HloOpcode::kCopy, new_gte0));
- HloInstruction* new_gte1 =
- infeed->parent()->AddInstruction(HloInstruction::CreateGetTupleElement(
- ShapeUtil::GetTupleElementShape(infeed->shape(), 1), infeed, 1));
+ HloInstruction* new_gte1 = infeed_data->parent()->AddInstruction(
+ HloInstruction::CreateGetTupleElement(
+ ShapeUtil::GetTupleElementShape(infeed_data->shape(), 1), infeed_data,
+ 1));
HloInstruction* new_copy1 =
- infeed->parent()->AddInstruction(HloInstruction::CreateUnary(
+ infeed_data->parent()->AddInstruction(HloInstruction::CreateUnary(
new_gte1->shape(), HloOpcode::kCopy, new_gte1));
- HloInstruction* new_tuple = infeed->parent()->AddInstruction(
+ HloInstruction* new_tuple = infeed_data->parent()->AddInstruction(
HloInstruction::CreateTuple({new_copy0, new_copy1}));
- for (HloInstruction* user : infeed_users) {
- TF_EXPECT_OK(infeed->ReplaceUseWith(user, new_tuple));
+ for (HloInstruction* user : infeed_data_users) {
+ TF_EXPECT_OK(infeed_data->ReplaceUseWith(user, new_tuple));
}
HloDomainRemover remover(ShardingMetadata::KindName(),
@@ -412,7 +420,7 @@ ENTRY entry {
};
for (auto& assignment : assignments) {
auto device = assignment.instruction->sharding_unique_device();
- EXPECT_TRUE(device.has_value());
+ ASSERT_TRUE(device.has_value());
EXPECT_EQ(*device, assignment.device);
}
EXPECT_TRUE(new_tuple->has_sharding());
diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter_test.cc b/tensorflow/compiler/xla/service/hlo_element_type_converter_test.cc
index 5c5a059e0f..c170e36c73 100644
--- a/tensorflow/compiler/xla/service/hlo_element_type_converter_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_element_type_converter_test.cc
@@ -57,8 +57,10 @@ TEST_F(HloElementTypeConverterTest, InfeedsOutfeedsNotConverted) {
const string& hlo_string = R"(
HloModule InfeedOutfeed
ENTRY RoundTrip16MiBR1.v2 {
- ROOT infeed = bf16[4]{0} infeed()
- outfeed = () outfeed(infeed)
+ token = token[] after-all()
+ infeed = (bf16[4]{0}, token[]) infeed(token)
+ ROOT infeed.data = bf16[4]{0} get-tuple-element(infeed), index=0
+ outfeed = token[] outfeed(infeed.data, token)
}
)";
auto module = CreateModuleFromHloString(hlo_string);
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc
index 33424019b9..deb7f28d84 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc
@@ -902,7 +902,7 @@ Status HloEvaluator::HandleBroadcast(HloInstruction* broadcast) {
return Status::OK();
}
-Status HloEvaluator::HandleGenerateToken(HloInstruction* token) {
+Status HloEvaluator::HandleAfterAll(HloInstruction* token) {
evaluated_[token] = Literal::CreateToken();
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h
index fc2fc9437b..2ad56080d8 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.h
@@ -174,7 +174,7 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
Status HandleBroadcast(HloInstruction* broadcast) override;
- Status HandleGenerateToken(HloInstruction* token) override;
+ Status HandleAfterAll(HloInstruction* token) override;
// Returns the already-evaluated literal result for the instruction.
// A Constant instruction is considered evaluated and its literal will be
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
index b349f7d46f..8856723f67 100644
--- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
+++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
@@ -984,7 +984,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
case HloOpcode::kBitcast:
case HloOpcode::kGetTupleElement:
case HloOpcode::kTrace:
- case HloOpcode::kGenerateToken:
+ case HloOpcode::kAfterAll:
case HloOpcode::kTuple:
return kWhite;
case HloOpcode::kBroadcast:
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index a07dbe6256..1c8c9a8d6d 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -263,12 +263,30 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
CreateReducePrecision(proto.shape(), operands(0),
proto.exponent_bits(), proto.mantissa_bits());
break;
- case HloOpcode::kInfeed:
- instruction = CreateInfeed(proto.shape(), proto.infeed_config());
- break;
+ case HloOpcode::kInfeed: {
+ const Shape& data_shape =
+ ShapeUtil::GetTupleElementShape(proto.shape(), 0);
+ if (proto.operand_ids_size() == 0) {
+ // TODO(b/80000000): Remove this when all uses of infeed are
+ // converted to take tokens.
+ instruction = CreateInfeed(data_shape, proto.infeed_config());
+ } else {
+ CHECK_EQ(proto.operand_ids_size(), 2);
+ instruction =
+ CreateInfeed(data_shape, operands(0), proto.infeed_config());
+ }
+ } break;
case HloOpcode::kOutfeed:
- instruction = CreateOutfeed(proto.outfeed_shape(), operands(0),
- proto.outfeed_config());
+ if (proto.operand_ids_size() == 1) {
+ // TODO(b/80000000): Remove this when all uses of outfeed are
+ // converted to take tokens.
+ instruction = CreateOutfeed(proto.outfeed_shape(), operands(0),
+ proto.outfeed_config());
+ } else {
+ CHECK_EQ(proto.operand_ids_size(), 2);
+ instruction = CreateOutfeed(proto.outfeed_shape(), operands(0),
+ operands(1), proto.outfeed_config());
+ }
break;
case HloOpcode::kCrossReplicaSum: {
TF_RET_CHECK(proto.called_computation_ids_size() == 1)
@@ -543,10 +561,8 @@ HloInstruction::CreateGetTupleElement(const Shape& shape,
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateMap(
const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
- HloComputation* map_computation,
- tensorflow::gtl::ArraySlice<HloInstruction*> static_operands) {
- return MakeUnique<HloMapInstruction>(shape, operands, map_computation,
- static_operands);
+ HloComputation* map_computation) {
+ return MakeUnique<HloMapInstruction>(shape, operands, map_computation);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConvolve(
@@ -610,14 +626,28 @@ HloInstruction::CreateCrossReplicaSum(
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateInfeed(
- const Shape& shape, const string& config) {
- return MakeUnique<HloInfeedInstruction>(shape, config);
+ const Shape& infeed_shape, HloInstruction* token_operand,
+ const string& config) {
+ return MakeUnique<HloInfeedInstruction>(infeed_shape, token_operand, config);
+}
+
+/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateInfeed(
+ const Shape& infeed_shape, const string& config) {
+ return MakeUnique<HloInfeedInstruction>(infeed_shape, config);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateOutfeed(
- const Shape& shape, HloInstruction* operand,
+ const Shape& outfeed_shape, HloInstruction* operand,
+ HloInstruction* token_operand, tensorflow::StringPiece outfeed_config) {
+ return MakeUnique<HloOutfeedInstruction>(outfeed_shape, operand,
+ token_operand, outfeed_config);
+}
+
+/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateOutfeed(
+ const Shape& outfeed_shape, HloInstruction* operand,
tensorflow::StringPiece outfeed_config) {
- return MakeUnique<HloOutfeedInstruction>(shape, operand, outfeed_config);
+ return MakeUnique<HloOutfeedInstruction>(outfeed_shape, operand,
+ outfeed_config);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSend(
@@ -652,11 +682,10 @@ HloInstruction::CreateCrossReplicaSum(
return MakeUnique<HloReverseInstruction>(shape, operand, dimensions);
}
-/* static */ std::unique_ptr<HloInstruction>
-HloInstruction::CreateGenerateToken(
+/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateAfterAll(
tensorflow::gtl::ArraySlice<HloInstruction*> operands) {
- auto instruction = WrapUnique(new HloInstruction(
- HloOpcode::kGenerateToken, ShapeUtil::MakeTokenShape()));
+ auto instruction = WrapUnique(
+ new HloInstruction(HloOpcode::kAfterAll, ShapeUtil::MakeTokenShape()));
for (auto operand : operands) {
instruction->AppendOperand(operand);
}
@@ -1183,8 +1212,8 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
CreateDomain(shape, new_operands[0], operand_side_metadata_->Clone(),
user_side_metadata_->Clone());
break;
- case HloOpcode::kGenerateToken:
- clone = CreateGenerateToken(new_operands);
+ case HloOpcode::kAfterAll:
+ clone = CreateAfterAll(new_operands);
break;
}
SetupDerivedInstruction(clone.get());
@@ -1369,6 +1398,30 @@ void HloInstruction::AppendOperand(HloInstruction* operand) {
operand->AddUser(this);
}
+void HloInstruction::RemoveOperandsAtAscendingIndices(
+ tensorflow::gtl::ArraySlice<int> ascending_indices) {
+ if (ascending_indices.empty()) {
+ return;
+ }
+ int next_index = 0;
+ int removed_count = 0;
+ for (int to_remove : ascending_indices) {
+ while (next_index < to_remove) {
+ operands_[next_index - removed_count] = operands_[next_index];
+ ++next_index;
+ }
+ CHECK_LT(to_remove, operands_.size());
+ ++removed_count;
+ ++next_index;
+ }
+ while (next_index < operands_.size()) {
+ operands_[next_index - removed_count] = operands_[next_index];
+ ++next_index;
+ }
+ CHECK_EQ(removed_count, ascending_indices.size());
+ operands_.resize(operands_.size() - removed_count);
+}
+
void HloInstruction::AddUser(HloInstruction* user) {
if (!ContainsKey(user_set_, user)) {
user_set_.insert(user);
@@ -1447,7 +1500,7 @@ bool HloInstruction::IdenticalSlowPath(
// These opcodes have complex or special behavior so just return false.
case HloOpcode::kDomain:
case HloOpcode::kWhile:
- case HloOpcode::kGenerateToken:
+ case HloOpcode::kAfterAll:
return false;
// Check dot dimension numbers.
@@ -1539,6 +1592,10 @@ Status HloInstruction::ReplaceUseWith(HloInstruction* user,
std::replace(user->operands_.begin(), user->operands_.end(), this,
new_producer);
new_producer->AddUser(user);
+ if (user->opcode() == HloOpcode::kFusion) {
+ TF_RETURN_IF_ERROR(
+ Cast<HloFusionInstruction>(user)->DeduplicateFusionOperands());
+ }
return Status::OK();
}
@@ -1577,6 +1634,10 @@ Status HloInstruction::ReplaceAllUsesWith(HloInstruction* new_producer) {
std::replace(user->operands_.begin(), user->operands_.end(), this,
new_producer);
new_producer->AddUser(user);
+ if (user->opcode() == HloOpcode::kFusion) {
+ TF_RETURN_IF_ERROR(
+ Cast<HloFusionInstruction>(user)->DeduplicateFusionOperands());
+ }
}
}
users_.clear();
@@ -2226,8 +2287,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) {
return visitor->HandleGather(this);
case HloOpcode::kDomain:
return visitor->HandleDomain(this);
- case HloOpcode::kGenerateToken:
- return visitor->HandleGenerateToken(this);
+ case HloOpcode::kAfterAll:
+ return visitor->HandleAfterAll(this);
// These opcodes are not handled here.
case HloOpcode::kTrace:
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 8f59e67123..59a383218c 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -389,11 +389,10 @@ class HloInstruction {
// Creates a map instruction, where the computation (given by the handle) is
// applied element-wise to every element in operands (across the operands,
- // at a given index) with the same `static_operands`.
+ // at a given index).
static std::unique_ptr<HloInstruction> CreateMap(
const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
- HloComputation* map_computation,
- tensorflow::gtl::ArraySlice<HloInstruction*> static_operands = {});
+ HloComputation* map_computation);
// Creates a convolution op, where rhs is the convolutional filter
// and window describes how the filter is applied to lhs.
@@ -459,13 +458,29 @@ class HloInstruction {
const Shape& shape, HloInstruction* operand);
// Creates an infeed instruction, which reads data of the given shape from the
- // Infeed interface of the device.
- static std::unique_ptr<HloInstruction> CreateInfeed(const Shape& shape,
+ // Infeed interface of the device. infeed_shape is the shape of the data
+ // received from the infeed *not* the shape of the infeed instruction which
+ // is a tuple containing the infeed_shape and the TOKEN.
+ static std::unique_ptr<HloInstruction> CreateInfeed(
+ const Shape& infeed_shape, HloInstruction* token_operand,
+ const string& config);
+ // Overload which does not require a token.
+ // TODO(b/80000000): Remove this overload when all uses of infeed are
+ // converted to take tokens.
+ static std::unique_ptr<HloInstruction> CreateInfeed(const Shape& infeed_shape,
const string& config);
- // Creates an outfeed instruction, which outputs data.
+ // Creates an outfeed instruction, which outputs data. outfeed_shape is the
+ // shape of the data being outfed *not* the shape of the outfeed instruction
+ // which is a TOKEN.
static std::unique_ptr<HloInstruction> CreateOutfeed(
- const Shape& shape, HloInstruction* operand,
+ const Shape& outfeed_shape, HloInstruction* operand,
+ HloInstruction* token_operand, tensorflow::StringPiece outfeed_config);
+ // Overload which does not require a token.
+ // TODO(b/80000000): Remove this overload when all uses of outfeed are
+ // converted to take tokens.
+ static std::unique_ptr<HloInstruction> CreateOutfeed(
+ const Shape& outfeed_shape, HloInstruction* operand,
tensorflow::StringPiece outfeed_config);
// Creates an asynchronous send instruction with the given channel id, which
@@ -665,9 +680,9 @@ class HloInstruction {
const Shape& shape, HloInstruction* operand,
tensorflow::gtl::ArraySlice<int64> dimensions);
- // Creates a token instruction used for joining or creating token types which
- // thread through side-effecting operations.
- static std::unique_ptr<HloInstruction> CreateGenerateToken(
+ // Creates a token instruction used for joining or creating new values of
+ // token type which thread through side-effecting operations.
+ static std::unique_ptr<HloInstruction> CreateAfterAll(
tensorflow::gtl::ArraySlice<HloInstruction*> operands);
// Creates an instance of GatherDimensionNumbers.
@@ -811,9 +826,15 @@ class HloInstruction {
// Replaces the use of this instruction in "user" with "new_producer". Note
// that there might be multiple uses of this instruction in "user"; all will
// be replaced.
+ //
+ // If user is a fusion instruction, this function will remove any duplicated
+ // operands of it which could be created due to this replacement.
Status ReplaceUseWith(HloInstruction* user, HloInstruction* new_producer);
// Replaces the specified operand with new_operand.
+ //
+ // This function does NOT remove duplicated operands even if this instruction
+ // is a fusion, so that the existing operand numbers do not change.
Status ReplaceOperandWith(int64 operand_no, HloInstruction* new_operand);
// Replaces all uses of this instruction with the new producer. If
@@ -822,6 +843,9 @@ class HloInstruction {
//
// If this instruction is the root of its computation, sets the computation's
// root to new_producer.
+ //
+ // If a user is a fusion instruction, this function will remove any duplicated
+ // operands of it which could be created due to this replacement.
Status ReplaceAllUsesWith(HloInstruction* new_producer);
// Performs a postorder DFS visit using this node as the root. If
@@ -1440,6 +1464,10 @@ class HloInstruction {
operands_.erase(operands_.begin() + index);
}
+ // Removes a list of operands with the given indices in ascending order.
+ void RemoveOperandsAtAscendingIndices(
+ tensorflow::gtl::ArraySlice<int> ascending_indices);
+
void AppendComputation(HloComputation* computation) {
called_computations_.push_back(computation);
}
diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
index 8ee24f9d92..d8ca99dfd1 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
@@ -716,10 +716,11 @@ TEST_F(HloInstructionTest, PreserveOutfeedShapeThroughClone) {
})));
auto shape10 = ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0});
auto shape01 = ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1});
+ auto token = builder.AddInstruction(HloInstruction::CreateAfterAll({}));
auto outfeed10 = builder.AddInstruction(
- HloInstruction::CreateOutfeed(shape10, constant, ""));
+ HloInstruction::CreateOutfeed(shape10, constant, token, ""));
auto outfeed01 = builder.AddInstruction(
- HloInstruction::CreateOutfeed(shape01, constant, ""));
+ HloInstruction::CreateOutfeed(shape01, constant, token, ""));
auto clone01 = builder.AddInstruction(outfeed01->Clone());
auto clone10 = builder.AddInstruction(outfeed10->Clone());
@@ -763,12 +764,12 @@ TEST_F(HloInstructionTest, FusionOpWithCalledComputations) {
HloComputation::Builder builder(TestName());
auto constant = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f)));
- auto map_1_x = builder.AddInstruction(HloInstruction::CreateMap(
- scalar_shape, {constant}, computation_x, /*static_operands=*/{}));
- auto map_2_x = builder.AddInstruction(HloInstruction::CreateMap(
- scalar_shape, {map_1_x}, computation_x, /*static_operands=*/{}));
- auto map_3_y = builder.AddInstruction(HloInstruction::CreateMap(
- scalar_shape, {map_2_x}, computation_y, /*static_operands=*/{}));
+ auto map_1_x = builder.AddInstruction(
+ HloInstruction::CreateMap(scalar_shape, {constant}, computation_x));
+ auto map_2_x = builder.AddInstruction(
+ HloInstruction::CreateMap(scalar_shape, {map_1_x}, computation_x));
+ auto map_3_y = builder.AddInstruction(
+ HloInstruction::CreateMap(scalar_shape, {map_2_x}, computation_y));
auto* computation = module->AddEntryComputation(builder.Build());
auto* fusion = computation->CreateFusionInstruction(
@@ -1170,6 +1171,40 @@ TEST_F(HloInstructionTest, CloneOfFusionPreservesShape) {
EXPECT_TRUE(StructuralEqual(*fusion, *fusion2));
}
+TEST_F(HloInstructionTest, NoRedundantFusionOperandsAfterReplacingUse) {
+ // Fused expression:
+ //
+ // x y
+ // | |
+ // | transpose
+ // \ /
+ // dot
+ const Shape s = ShapeUtil::MakeShape(F32, {10, 10});
+
+ HloComputation::Builder builder("TransposeDot");
+ HloInstruction* x =
+ builder.AddInstruction(HloInstruction::CreateParameter(0, s, "x"));
+ HloInstruction* y =
+ builder.AddInstruction(HloInstruction::CreateParameter(1, s, "y"));
+ HloInstruction* reshape =
+ builder.AddInstruction(HloInstruction::CreateTranspose(s, y, {1, 0}));
+ DotDimensionNumbers dot_dnums;
+ dot_dnums.add_lhs_contracting_dimensions(1);
+ dot_dnums.add_rhs_contracting_dimensions(0);
+ HloInstruction* dot = builder.AddInstruction(
+ HloInstruction::CreateDot(s, x, reshape, dot_dnums));
+
+ auto module = CreateNewModule();
+ auto* computation = module->AddEntryComputation(builder.Build());
+ HloInstruction* fusion = computation->CreateFusionInstruction(
+ {dot, reshape}, HloInstruction::FusionKind::kLoop);
+
+ EXPECT_TRUE(x->ReplaceAllUsesWith(y).ok());
+
+ EXPECT_THAT(fusion->operands(), UnorderedElementsAre(y));
+ EXPECT_EQ(fusion->fused_instructions_computation()->num_parameters(), 1);
+}
+
TEST_F(HloInstructionTest, FusionEquality) {
auto module = CreateNewModule();
HloComputation::Builder builder(TestName());
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index 803fde73a5..e2f43f5810 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/window_util.h"
+#include "tensorflow/core/lib/gtl/flatmap.h"
namespace xla {
namespace {
@@ -553,10 +554,8 @@ HloBroadcastInstruction::CloneWithNewOperandsImpl(
HloMapInstruction::HloMapInstruction(
const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
- HloComputation* map_computation,
- tensorflow::gtl::ArraySlice<HloInstruction*> static_operands)
+ HloComputation* map_computation)
: HloInstruction(HloOpcode::kMap, shape) {
- CHECK(static_operands.empty()) << "static_operands not yet supported";
for (auto operand : operands) {
AppendOperand(operand);
}
@@ -1210,6 +1209,26 @@ std::unique_ptr<HloInstruction> HloFusionInstruction::CloneWithNewOperandsImpl(
new_fused_computation);
}
+Status HloFusionInstruction::DeduplicateFusionOperands() {
+ tensorflow::gtl::FlatMap<const HloInstruction*, int> operand_indices;
+ std::vector<int> operands_to_remove;
+ for (int i = 0; i < operand_count(); ++i) {
+ auto emplace_result = operand_indices.emplace(operand(i), i);
+ if (!emplace_result.second) {
+ TF_RETURN_IF_ERROR(fused_parameter(i)->ReplaceAllUsesWith(
+ fused_parameter(emplace_result.first->second)));
+ operands_to_remove.push_back(i);
+ }
+ }
+ if (operands_to_remove.empty()) {
+ return Status::OK();
+ }
+ TF_RETURN_IF_ERROR(
+ fused_instructions_computation()->RemoveUnusedParameters());
+ RemoveOperandsAtAscendingIndices(operands_to_remove);
+ return Status::OK();
+}
+
HloRngInstruction::HloRngInstruction(
const Shape& shape, RandomDistribution distribution,
tensorflow::gtl::ArraySlice<HloInstruction*> parameters)
@@ -1365,9 +1384,22 @@ HloReducePrecisionInstruction::CloneWithNewOperandsImpl(
shape, new_operands[0], exponent_bits(), mantissa_bits());
}
-HloInfeedInstruction::HloInfeedInstruction(const Shape& shape,
+HloInfeedInstruction::HloInfeedInstruction(const Shape& infeed_shape,
+ HloInstruction* token_operand,
const string& config)
- : HloInstruction(HloOpcode::kInfeed, shape), infeed_config_(config) {}
+ : HloInstruction(HloOpcode::kInfeed,
+ ShapeUtil::MakeTupleShape(
+ {infeed_shape, ShapeUtil::MakeTokenShape()})),
+ infeed_config_(config) {
+ AppendOperand(token_operand);
+}
+
+HloInfeedInstruction::HloInfeedInstruction(const Shape& infeed_shape,
+ const string& config)
+ : HloInstruction(HloOpcode::kInfeed,
+ ShapeUtil::MakeTupleShape(
+ {infeed_shape, ShapeUtil::MakeTokenShape()})),
+ infeed_config_(config) {}
HloInstructionProto HloInfeedInstruction::ToProto() const {
HloInstructionProto proto = HloInstruction::ToProto();
@@ -1395,19 +1427,37 @@ std::unique_ptr<HloInstruction> HloInfeedInstruction::CloneWithNewOperandsImpl(
const Shape& shape,
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
HloCloneContext* context) const {
- CHECK_EQ(new_operands.size(), 0);
- return MakeUnique<HloInfeedInstruction>(shape, infeed_config());
+ if (new_operands.empty()) {
+ return MakeUnique<HloInfeedInstruction>(infeed_shape(), infeed_config());
+ } else {
+ CHECK_EQ(new_operands.size(), 1);
+ return MakeUnique<HloInfeedInstruction>(infeed_shape(), new_operands[0],
+ infeed_config());
+ }
}
HloOutfeedInstruction::HloOutfeedInstruction(
- const Shape& shape, HloInstruction* operand,
+ const Shape& outfeed_shape, HloInstruction* operand,
+ HloInstruction* token_operand, tensorflow::StringPiece outfeed_config)
+ : HloInstruction(HloOpcode::kOutfeed, ShapeUtil::MakeTokenShape()),
+ outfeed_shape_(outfeed_shape),
+ outfeed_config_(outfeed_config.begin(), outfeed_config.end()) {
+ CHECK(ShapeUtil::Compatible(operand->shape(), outfeed_shape))
+ << "Outfeed shape " << outfeed_shape
+ << " must be compatible with operand shape " << operand->shape();
+ AppendOperand(operand);
+ AppendOperand(token_operand);
+}
+
+HloOutfeedInstruction::HloOutfeedInstruction(
+ const Shape& outfeed_shape, HloInstruction* operand,
tensorflow::StringPiece outfeed_config)
- : HloInstruction(HloOpcode::kOutfeed, ShapeUtil::MakeNil()),
- outfeed_shape_(shape),
+ : HloInstruction(HloOpcode::kOutfeed, ShapeUtil::MakeTokenShape()),
+ outfeed_shape_(outfeed_shape),
outfeed_config_(outfeed_config.begin(), outfeed_config.end()) {
- CHECK(ShapeUtil::Compatible(operand->shape(), shape))
- << "Outfeed shape " << shape << " must be compatible with operand shape "
- << operand->shape();
+ CHECK(ShapeUtil::Compatible(operand->shape(), outfeed_shape))
+ << "Outfeed shape " << outfeed_shape
+ << " must be compatible with operand shape " << operand->shape();
AppendOperand(operand);
}
@@ -1438,9 +1488,14 @@ std::unique_ptr<HloInstruction> HloOutfeedInstruction::CloneWithNewOperandsImpl(
const Shape& shape,
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
HloCloneContext* context) const {
- CHECK_EQ(new_operands.size(), 1);
- return MakeUnique<HloOutfeedInstruction>(outfeed_shape(), new_operands[0],
- outfeed_config());
+ if (new_operands.size() == 1) {
+ return MakeUnique<HloOutfeedInstruction>(outfeed_shape(), new_operands[0],
+ outfeed_config());
+ } else {
+ CHECK_EQ(new_operands.size(), 2);
+ return MakeUnique<HloOutfeedInstruction>(outfeed_shape(), new_operands[0],
+ new_operands[1], outfeed_config());
+ }
}
HloConvolutionInstruction::HloConvolutionInstruction(
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h
index 1a2e4ae0a5..ec8a42bd3b 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.h
+++ b/tensorflow/compiler/xla/service/hlo_instructions.h
@@ -407,8 +407,7 @@ class HloMapInstruction : public HloInstruction {
public:
explicit HloMapInstruction(
const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
- HloComputation* map_computation,
- tensorflow::gtl::ArraySlice<HloInstruction*> static_operands = {});
+ HloComputation* map_computation);
// Returns the dimension sizes or numbers associated with this instruction.
const std::vector<int64>& dimensions() const override { return dimensions_; }
int64 dimensions(int64 index) const override { return dimensions()[index]; }
@@ -636,6 +635,9 @@ class HloFusionInstruction : public HloInstruction {
void set_fusion_kind(FusionKind kind) { fusion_kind_ = kind; }
+ // If multiple operands are the same instruction, keeps only one of them.
+ Status DeduplicateFusionOperands();
+
private:
// Fuses the given instruction into this fusion instruction. When add_output
// is false (which is the default), instruction_to_fuse is cloned and the
@@ -785,12 +787,25 @@ class HloReducePrecisionInstruction : public HloInstruction {
class HloInfeedInstruction : public HloInstruction {
public:
- explicit HloInfeedInstruction(const Shape& shape, const string& config);
+ explicit HloInfeedInstruction(const Shape& infeed_shape,
+ HloInstruction* token_operand,
+ const string& config);
+ // TODO(b/80000000): Remove this constructor when all uses of infeed are
+ // converted to take tokens.
+ explicit HloInfeedInstruction(const Shape& infeed_shape,
+ const string& config);
// Returns the infeed configuration string. The infeed configuration includes
// any metadata needed for the backend compiler (e.g., infeed buffer address)
// and is target-dependent.
string infeed_config() const { return infeed_config_; }
void set_infeed_config(const string& config) { infeed_config_ = config; }
+ // Returns the shape of the data received by the infeed. This is not the same
+ // as the shape of the infeed instruction which produces a tuple containing
+ // the infeed data shape and a TOKEN.
+ const Shape& infeed_shape() const {
+ TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape()));
+ return ShapeUtil::GetSubshape(shape(), {0});
+ }
// Returns a serialized representation of this instruction.
HloInstructionProto ToProto() const override;
@@ -813,11 +828,19 @@ class HloInfeedInstruction : public HloInstruction {
class HloOutfeedInstruction : public HloInstruction {
public:
- explicit HloOutfeedInstruction(const Shape& shape, HloInstruction* operand,
+ explicit HloOutfeedInstruction(const Shape& outfeed_shape,
+ HloInstruction* operand,
+ HloInstruction* token_operand,
tensorflow::StringPiece outfeed_config);
+ // TODO(b/80000000): Remove this constructor when all uses of outfeed are
+ // converted to take tokens.
+ explicit HloOutfeedInstruction(const Shape& outfeed_shape,
+ HloInstruction* operand,
+ tensorflow::StringPiece outfeed_config);
+
// Returns the shape for the Outfeed instruction.
const Shape& outfeed_shape() const {
- TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape()));
+ TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(outfeed_shape_));
return outfeed_shape_;
}
// Returns the config for the Outfeed instruction.
diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h
index 8a31a8e617..b57c940238 100644
--- a/tensorflow/compiler/xla/service/hlo_matchers.h
+++ b/tensorflow/compiler/xla/service/hlo_matchers.h
@@ -187,7 +187,7 @@ HLO_MATCHER(Exp);
HLO_MATCHER(Floor);
HLO_MATCHER(Fusion);
HLO_MATCHER(Ge);
-HLO_MATCHER(GenerateToken);
+HLO_MATCHER(AfterAll);
HLO_MATCHER(Gt);
HLO_MATCHER(Infeed);
HLO_MATCHER(IsFinite);
@@ -196,6 +196,7 @@ HLO_MATCHER(Log);
HLO_MATCHER(And);
HLO_MATCHER(Not);
HLO_MATCHER(Or);
+HLO_MATCHER(Xor);
HLO_MATCHER(Lt);
HLO_MATCHER(Map);
HLO_MATCHER(Maximum);
diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h
index 7083321276..05e47a698f 100644
--- a/tensorflow/compiler/xla/service/hlo_opcode.h
+++ b/tensorflow/compiler/xla/service/hlo_opcode.h
@@ -81,7 +81,7 @@ namespace xla {
V(kFusion, "fusion", kHloOpcodeIsVariadic) \
V(kGather, "gather") \
V(kGe, "greater-than-or-equal-to", kHloOpcodeIsComparison) \
- V(kGenerateToken, "generate-token", kHloOpcodeIsVariadic) \
+ V(kAfterAll, "after-all", kHloOpcodeIsVariadic) \
V(kGetTupleElement, "get-tuple-element") \
V(kGt, "greater-than", kHloOpcodeIsComparison) \
V(kHostCompute, "host-compute") \
diff --git a/tensorflow/compiler/xla/service/hlo_opcode_test.cc b/tensorflow/compiler/xla/service/hlo_opcode_test.cc
index 774345124b..6f3f83f63a 100644
--- a/tensorflow/compiler/xla/service/hlo_opcode_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_opcode_test.cc
@@ -58,7 +58,7 @@ TEST(HloOpcodeTest, OpcodeProperties) {
case HloOpcode::kConcatenate:
case HloOpcode::kFusion:
case HloOpcode::kMap:
- case HloOpcode::kGenerateToken:
+ case HloOpcode::kAfterAll:
case HloOpcode::kTuple:
EXPECT_TRUE(HloOpcodeIsVariadic(opcode));
break;
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index 605c6ae741..57d17064c1 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -617,12 +617,12 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
HloInstruction::CreateReshape(shape, operands[0]));
break;
}
- case HloOpcode::kGenerateToken: {
+ case HloOpcode::kAfterAll: {
if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
return false;
}
- instruction = builder->AddInstruction(
- HloInstruction::CreateGenerateToken(operands));
+ instruction =
+ builder->AddInstruction(HloInstruction::CreateAfterAll(operands));
break;
}
case HloOpcode::kTuple: {
@@ -978,23 +978,53 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
case HloOpcode::kInfeed: {
optional<string> config;
attrs["infeed_config"] = {/*required=*/false, AttrTy::kString, &config};
- if (!ParseOperands(&operands, /*expected_size=*/0) ||
- !ParseAttributes(attrs)) {
+ if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
return false;
}
- instruction = builder->AddInstruction(
- HloInstruction::CreateInfeed(shape, config ? *config : ""));
+ // We need to know the infeed data shape to construct the infeed
+ // instruction. This is the zero-th element of the tuple-shaped output of
+ // the infeed instruction. ShapeUtil::GetTupleElementShape will check fail
+ // if the shape is not a non-empty tuple, so add a guard so that an error
+ // message can be emitted instead of a check failure.
+ if (!ShapeUtil::IsTuple(shape) || ShapeUtil::IsEmptyTuple(shape)) {
+ return Error(lexer_.GetLoc(),
+ "infeed must have a non-empty tuple shape");
+ }
+
+ if (operands.empty()) {
+ // TODO(b/80000000): Remove this when all uses of infeed are
+ // converted to take tokens.
+ instruction = builder->AddInstruction(HloInstruction::CreateInfeed(
+ ShapeUtil::GetTupleElementShape(shape, 0), config ? *config : ""));
+ } else if (operands.size() == 1) {
+ instruction = builder->AddInstruction(HloInstruction::CreateInfeed(
+ ShapeUtil::GetTupleElementShape(shape, 0), operands[0],
+ config ? *config : ""));
+ } else {
+ return Error(lexer_.GetLoc(),
+ "infeed must have exactly zero or one operands");
+ }
break;
}
case HloOpcode::kOutfeed: {
optional<string> config;
attrs["outfeed_config"] = {/*required=*/false, AttrTy::kString, &config};
- if (!ParseOperands(&operands, /*expected_size=*/1) ||
- !ParseAttributes(attrs)) {
+ if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
return false;
}
- instruction = builder->AddInstruction(HloInstruction::CreateOutfeed(
- operands[0]->shape(), operands[0], config ? *config : ""));
+ if (operands.size() == 1) {
+ // TODO(b/80000000): Remove this when all uses of outfeed are
+ // converted to take tokens.
+ instruction = builder->AddInstruction(HloInstruction::CreateOutfeed(
+ operands[0]->shape(), operands[0], config ? *config : ""));
+ } else if (operands.size() == 2) {
+ instruction = builder->AddInstruction(
+ HloInstruction::CreateOutfeed(operands[0]->shape(), operands[0],
+ operands[1], config ? *config : ""));
+ } else {
+ return Error(lexer_.GetLoc(),
+ "outfeed must have exactly one or two operands");
+ }
break;
}
case HloOpcode::kRng: {
diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc
index d481e07f60..da1a34ae3c 100644
--- a/tensorflow/compiler/xla/service/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc
@@ -795,10 +795,14 @@ ENTRY ReduceR3ToR2.v3 {
R"(HloModule outfeed_module
ENTRY InfeedToOutfeed {
- infeed = (u32[3]{0}, pred[]) infeed()
- outfeed = () outfeed(infeed)
- ROOT infeed.1 = (u32[3]{0}, pred[]) infeed()
- outfeed.1 = () outfeed(infeed.1)
+ token = token[] after-all()
+ infeed = ((u32[3]{0}, pred[]), token[]) infeed(token)
+ infeed.data = (u32[3]{0}, pred[]) get-tuple-element(infeed), index=0
+ outfeed = token[] outfeed(infeed.data, token)
+ ROOT infeed.1 = ((u32[3]{0}, pred[]), token[]) infeed(token)
+ infeed.1.data = (u32[3]{0}, pred[]) get-tuple-element(infeed.1), index=0
+ infeed.1.token = token[] get-tuple-element(infeed.1), index=1
+ outfeed.1 = token[] outfeed(infeed.1.data, infeed.1.token)
}
)"
@@ -1418,5 +1422,15 @@ TEST_F(HloParserTest, ParseConvolutionDimensionNumbers) {
EXPECT_EQ(original, ConvolutionDimensionNumbersToString(dnums));
}
+TEST_F(HloParserTest, NontupleInfeed) {
+ const string original = R"(HloModule nontuple_infeed:
+ENTRY nontuple_infeed {
+ token = token[] after-all()
+ ROOT infeed = pred[] infeed(token)
+})";
+ ExpectHasSubstr(ParseHloString(original).status().error_message(),
+ "infeed must have a non-empty tuple shape");
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index 1d6cd4cb23..fb39c6f085 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -15,6 +15,8 @@ limitations under the License.
#include <set>
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -106,22 +108,50 @@ Status ShapeVerifier::HandleReducePrecision(HloInstruction* reduce_precision) {
reduce_precision->mantissa_bits()));
}
-Status ShapeVerifier::HandleInfeed(HloInstruction*) { return Status::OK(); }
+Status ShapeVerifier::HandleInfeed(HloInstruction* instruction) {
+ HloInfeedInstruction* infeed = Cast<HloInfeedInstruction>(instruction);
+ // Infeed has an optional single token operand.
+ // TODO(b/80000000): Update when token is not optional.
+ if (infeed->operand_count() == 1 &&
+ !ShapeUtil::Equal(infeed->operand(0)->shape(),
+ ShapeUtil::MakeTokenShape())) {
+ return InternalError(
+ "Expected infeed operand to be token-shaped, actual shape is %s:\n%s",
+ ShapeUtil::HumanString(infeed->operand(0)->shape()).c_str(),
+ infeed->ToString().c_str());
+ }
+
+ // The output of infeed is a tuple containing the data value and a token.
+ return CheckShape(infeed,
+ ShapeUtil::MakeTupleShape(
+ {infeed->infeed_shape(), ShapeUtil::MakeTokenShape()}));
+}
+
+Status ShapeVerifier::HandleOutfeed(HloInstruction* instruction) {
+ HloOutfeedInstruction* outfeed = Cast<HloOutfeedInstruction>(instruction);
+ // Outfeed has an optional token operand (operand 1).
+ // TODO(b/80000000): Update when token is not optional.
+ if (outfeed->operand_count() == 2 &&
+ !ShapeUtil::Equal(outfeed->operand(1)->shape(),
+ ShapeUtil::MakeTokenShape())) {
+ return InternalError(
+ "Expected operand 1 of outfeed to be a token, actual shape is %s:\n%s",
+ ShapeUtil::HumanString(outfeed->operand(1)->shape()).c_str(),
+ outfeed->ToString().c_str());
+ }
-Status ShapeVerifier::HandleOutfeed(HloInstruction* outfeed) {
// Outfeed has a separate shape field for the value which is outfed to the
- // host. The shape of the instruction itself is always nil because the outfeed
- // produces no HLO value in the graph.
+ // host. The shape of the instruction itself is always a token.
if (!ShapeUtil::Compatible(outfeed->outfeed_shape(),
outfeed->operand(0)->shape())) {
return InternalError(
- "Expected outfeed to have shape compatible with operand's shape %s, "
+ "Expected outfeed shape to be compatible with operand's shape %s, "
"actual shape is %s:\n%s",
ShapeUtil::HumanString(outfeed->operand(0)->shape()).c_str(),
ShapeUtil::HumanString(outfeed->outfeed_shape()).c_str(),
outfeed->ToString().c_str());
}
- return CheckShape(outfeed, ShapeUtil::MakeNil());
+ return CheckShape(outfeed, ShapeUtil::MakeTokenShape());
}
Status ShapeVerifier::HandleHostCompute(HloInstruction*) {
@@ -426,13 +456,12 @@ Status ShapeVerifier::HandleGather(HloInstruction* gather) {
gather->gather_dimension_numbers(), gather->gather_window_bounds()));
}
-Status ShapeVerifier::HandleGenerateToken(HloInstruction* token) {
+Status ShapeVerifier::HandleAfterAll(HloInstruction* token) {
std::vector<const Shape*> operand_shapes;
for (const HloInstruction* operand : token->operands()) {
operand_shapes.push_back(&operand->shape());
}
- return CheckShape(token,
- ShapeInference::InferGenerateTokenShape(operand_shapes));
+ return CheckShape(token, ShapeInference::InferAfterAllShape(operand_shapes));
}
Status ShapeVerifier::CheckShape(const HloInstruction* instruction,
@@ -786,8 +815,7 @@ Status HloVerifier::CheckElementwiseInstruction(HloInstruction* instruction) {
const Shape& out_shape = instruction->shape();
for (HloInstruction* operand : instruction->operands()) {
const Shape& operand_shape = operand->shape();
- if (!ShapeUtil::IsScalar(operand_shape) &&
- !ShapeUtil::CompatibleIgnoringElementType(operand_shape, out_shape)) {
+ if (!ShapeUtil::CompatibleIgnoringElementType(operand_shape, out_shape)) {
return FailedPrecondition(
"Implicit broadcast is not allowed in HLO."
"Found non-compatible shapes for instruction %s.\n"
@@ -815,9 +843,10 @@ bool ShapeContainsToken(const Shape& shape) {
}
// Verifies that all types entering and exiting the entry computation are
-// legal. For example, TOKEN types have no Literal representation and cannot be
-// on the interface of the entry computation (parameters and root instruction).
+// legal.
Status VerifyEntryAndExitShapes(const HloModule& module) {
+ // Tokens cannot be passed as entry parameters.
+ // TODO(b/80000000): Remove this constraint.
for (int i = 0; i < module.entry_computation()->num_parameters(); ++i) {
HloInstruction* param =
module.entry_computation()->parameter_instruction(i);
@@ -827,14 +856,6 @@ Status VerifyEntryAndExitShapes(const HloModule& module) {
ShapeUtil::HumanString(param->shape()).c_str());
}
}
- if (ShapeContainsToken(
- module.entry_computation()->root_instruction()->shape())) {
- return InternalError(
- "Entry root is or contains a token shape: %s",
- ShapeUtil::HumanString(
- module.entry_computation()->root_instruction()->shape())
- .c_str());
- }
return Status::OK();
}
@@ -881,7 +902,9 @@ StatusOr<bool> HloVerifier::Run(HloModule* module) {
<< " != " << ShapeUtil::Rank(instruction->operand(0)->shape());
} else if (instruction->opcode() == HloOpcode::kWhile) {
TF_RETURN_IF_ERROR(CheckWhileInstruction(instruction));
- } else if (instruction->IsElementwise()) {
+ } else if (instruction->opcode() !=
+ HloOpcode::kRng /* Rng operands are always scalar. */
+ && instruction->IsElementwise()) {
TF_RETURN_IF_ERROR(CheckElementwiseInstruction(instruction));
}
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h
index 7283b3e7dc..da6b5d2222 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.h
+++ b/tensorflow/compiler/xla/service/hlo_verifier.h
@@ -81,7 +81,7 @@ class ShapeVerifier : public DfsHloVisitor {
HloInstruction* batch_norm_inference) override;
Status HandleBatchNormGrad(HloInstruction* batch_norm_grad) override;
Status HandleGather(HloInstruction* gather) override;
- Status HandleGenerateToken(HloInstruction* token) override;
+ Status HandleAfterAll(HloInstruction* token) override;
Status FinishVisit(HloInstruction*) override { return Status::OK(); }
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc
index 9ac8635767..088cc26226 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/instruction_fusion.cc
@@ -97,7 +97,7 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) {
case HloOpcode::kShiftRightLogical:
case HloOpcode::kSlice:
case HloOpcode::kSubtract:
- case HloOpcode::kGenerateToken:
+ case HloOpcode::kAfterAll:
case HloOpcode::kTranspose:
case HloOpcode::kTuple:
return false;
diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc
index 62599b376a..67e2cf6c77 100644
--- a/tensorflow/compiler/xla/service/layout_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc
@@ -770,9 +770,13 @@ TEST_F(LayoutAssignmentTest, ConditionalAsymmetricLayout) {
false_builder.AddInstruction(
HloInstruction::CreateParameter(0, tshape, "param"));
// Using infeed as layout assignment does not mess up with it.
- auto infeed =
- false_builder.AddInstruction(HloInstruction::CreateInfeed(xshape, ""));
- false_builder.AddInstruction(HloInstruction::CreateTuple({infeed}));
+ auto token =
+ false_builder.AddInstruction(HloInstruction::CreateAfterAll({}));
+ auto infeed = false_builder.AddInstruction(
+ HloInstruction::CreateInfeed(xshape, token, ""));
+ auto infeed_data = false_builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(xshape, infeed, 0));
+ false_builder.AddInstruction(HloInstruction::CreateTuple({infeed_data}));
}
HloComputation* false_computation =
module->AddEmbeddedComputation(false_builder.Build());
diff --git a/tensorflow/compiler/xla/service/name_uniquer.cc b/tensorflow/compiler/xla/service/name_uniquer.cc
index 3a6a7c25f4..f6e7578a89 100644
--- a/tensorflow/compiler/xla/service/name_uniquer.cc
+++ b/tensorflow/compiler/xla/service/name_uniquer.cc
@@ -67,22 +67,17 @@ string NameUniquer::GetUniqueName(tensorflow::StringPiece prefix) {
has_numeric_suffix = true;
// Remove numeric suffix from root.
root = root.substr(0, separator_index);
- // Update count to at least the numeric suffix value to avoid future
- // colisions with this name.
- generated_names_[root] = std::max(generated_names_[root], numeric_suffix);
}
}
- int64* count = &(generated_names_[root]);
- if (*count == 0) {
- *count = 1;
+
+ SequentialIdGenerator& id_generator = generated_names_[root];
+ numeric_suffix = id_generator.RegisterId(numeric_suffix);
+ if (numeric_suffix == 0) {
return has_numeric_suffix ? tensorflow::strings::StrCat(root, separator_, 0)
: root;
- } else {
- tensorflow::strings::StrAppend(&root, separator_, *count);
- // Increment lookup under old 'root' name.
- (*count)++;
- return root;
}
+ tensorflow::strings::StrAppend(&root, separator_, numeric_suffix);
+ return root;
}
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/name_uniquer.h b/tensorflow/compiler/xla/service/name_uniquer.h
index 4139c2700b..4423d61069 100644
--- a/tensorflow/compiler/xla/service/name_uniquer.h
+++ b/tensorflow/compiler/xla/service/name_uniquer.h
@@ -17,10 +17,11 @@ limitations under the License.
#define TENSORFLOW_COMPILER_XLA_SERVICE_NAME_UNIQUER_H_
#include <string>
-#include <unordered_map>
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/gtl/flatmap.h"
+#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/macros.h"
namespace xla {
@@ -44,13 +45,40 @@ class NameUniquer {
static string GetSanitizedName(const string& name);
private:
+ // Used to track and generate new identifiers for the same instruction name
+ // root.
+ class SequentialIdGenerator {
+ public:
+ SequentialIdGenerator() = default;
+
+    // Tries to register id as a used identifier. If id is not already used,
+    // the id itself will be returned. Otherwise a new one will be generated
+    // and returned.
+ int64 RegisterId(int64 id) {
+ if (used_.insert(id).second) {
+ return id;
+ }
+ while (!used_.insert(next_).second) {
+ ++next_;
+ }
+ return next_++;
+ }
+
+ private:
+ // The next identifier to be tried.
+ int64 next_ = 0;
+
+    // Set of all the identifiers which have been used.
+ tensorflow::gtl::FlatSet<int64> used_;
+ };
+
// The string to use to separate the prefix of the name from the uniquing
// integer value.
string separator_;
- // Map from name prefix to the number of names generated using that prefix
- // so far.
- std::unordered_map<string, int64> generated_names_;
+ // Map from name prefix to the generator data structure which tracks used
+ // identifiers and generates new ones.
+ tensorflow::gtl::FlatMap<string, SequentialIdGenerator> generated_names_;
TF_DISALLOW_COPY_AND_ASSIGN(NameUniquer);
};
diff --git a/tensorflow/compiler/xla/service/name_uniquer_test.cc b/tensorflow/compiler/xla/service/name_uniquer_test.cc
index 2ec255558c..3e2592c6ac 100644
--- a/tensorflow/compiler/xla/service/name_uniquer_test.cc
+++ b/tensorflow/compiler/xla/service/name_uniquer_test.cc
@@ -54,12 +54,13 @@ TEST_F(NameUniquerTest, NumericSuffixes) {
EXPECT_EQ("foo", uniquer.GetUniqueName("foo"));
EXPECT_EQ("foo.54", uniquer.GetUniqueName("foo.54"));
- EXPECT_EQ("foo.55", uniquer.GetUniqueName("foo"));
+ EXPECT_EQ("foo.1", uniquer.GetUniqueName("foo"));
EXPECT_EQ("foo.55.1", uniquer.GetUniqueName("foo.55.1"));
- EXPECT_EQ("foo.55.2", uniquer.GetUniqueName("foo.55.1"));
- EXPECT_EQ("bar.0", uniquer.GetUniqueName("bar.-1000"));
- EXPECT_EQ("bar.1", uniquer.GetUniqueName("bar.-2000"));
- EXPECT_EQ("bar.2", uniquer.GetUniqueName("bar.1"));
+ EXPECT_EQ("foo.55.0", uniquer.GetUniqueName("foo.55.1"));
+ EXPECT_EQ("bar.1000", uniquer.GetUniqueName("bar.1000"));
+ EXPECT_EQ("bar.2000", uniquer.GetUniqueName("bar.2000"));
+ EXPECT_EQ("bar.-2000", uniquer.GetUniqueName("bar.-2000"));
+ EXPECT_EQ("bar.1", uniquer.GetUniqueName("bar.1"));
}
TEST_F(NameUniquerTest, PrefixHasSuffix) {
@@ -77,12 +78,12 @@ TEST_F(NameUniquerTest, Sanitize) {
EXPECT_EQ("foo.54", uniquer.GetUniqueName("foo.54"));
EXPECT_EQ("foo_54", uniquer.GetUniqueName("foo_54"));
EXPECT_EQ("foo_54.1", uniquer.GetUniqueName("foo_54.1"));
- EXPECT_EQ("foo_55", uniquer.GetUniqueName("foo"));
+ EXPECT_EQ("foo_2", uniquer.GetUniqueName("foo"));
// Invalid characters will be replaced with '_'.
- EXPECT_EQ("bar_0", uniquer.GetUniqueName("bar<-1000"));
- EXPECT_EQ("bar_1", uniquer.GetUniqueName("bar<-2000"));
- EXPECT_EQ("bar_2", uniquer.GetUniqueName("bar_1"));
+ EXPECT_EQ("bar_1000", uniquer.GetUniqueName("bar<1000"));
+ EXPECT_EQ("bar_2000", uniquer.GetUniqueName("bar<2000"));
+ EXPECT_EQ("bar_1", uniquer.GetUniqueName("bar_1"));
// Separator is only recognized in the middle of the prefix.
EXPECT_EQ("_10", uniquer.GetUniqueName(
@@ -93,5 +94,15 @@ TEST_F(NameUniquerTest, Sanitize) {
EXPECT_EQ("foobar__1", uniquer.GetUniqueName("foobar_"));
}
+TEST_F(NameUniquerTest, KeepNamesInRandomOrder) {
+ NameUniquer uniquer(".");
+
+ EXPECT_EQ("foo.11", uniquer.GetUniqueName("foo.11"));
+ EXPECT_EQ("foo.10", uniquer.GetUniqueName("foo.10"));
+ EXPECT_EQ("foo.1", uniquer.GetUniqueName("foo.1"));
+ EXPECT_EQ("foo.12", uniquer.GetUniqueName("foo.12"));
+ EXPECT_EQ("foo.3", uniquer.GetUniqueName("foo.3"));
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index bbc95f8630..096bbde922 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -329,7 +329,7 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
return ShapeUtil::MakeShape(element_type, new_dimensions);
}
-/* static */ StatusOr<Shape> ShapeInference::InferGenerateTokenShape(
+/* static */ StatusOr<Shape> ShapeInference::InferAfterAllShape(
tensorflow::gtl::ArraySlice<const Shape*> arg_shapes) {
for (const Shape* arg_shape : arg_shapes) {
if (arg_shape->element_type() != TOKEN) {
diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h
index eef6e62fc8..ad34a2aa18 100644
--- a/tensorflow/compiler/xla/service/shape_inference.h
+++ b/tensorflow/compiler/xla/service/shape_inference.h
@@ -216,11 +216,11 @@ class ShapeInference {
static StatusOr<Shape> InferConcatOpShape(
tensorflow::gtl::ArraySlice<const Shape*> arg_shapes, int64 dimension);
- // Infers the shape produced by a kGenerateToken operation. Trivially this
- // shape is always a TOKEN shape. However, ShapeInference serves two purposes:
- // inferring shapes and checking operand shapes. This method verifies that the
- // operand shapes are all TOKENs.
- static StatusOr<Shape> InferGenerateTokenShape(
+  // Infers the shape produced by a kAfterAll operation. Trivially this shape
+ // TOKEN shape. However, ShapeInference serves two purposes: inferring shapes
+ // and checking operand shapes. This method verifies that the operand shapes
+ // are all TOKENs.
+ static StatusOr<Shape> InferAfterAllShape(
tensorflow::gtl::ArraySlice<const Shape*> arg_shapes);
// Helper that validates the given operand shape can be converted to the
diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc
index 8831c513ee..23519e445e 100644
--- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc
+++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc
@@ -248,7 +248,9 @@ TEST_F(WhileLoopInvariantCodeMotionTest,
TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistInstructionWithSideEffects) {
auto scalar_s32 = ShapeUtil::MakeShape(S32, {});
- Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32});
+ auto token_shape = ShapeUtil::MakeTokenShape();
+ Shape while_shape =
+ ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32, token_shape});
HloComputation* while_body = [&]() {
HloComputation::Builder builder(TestName() + ".while_body");
@@ -258,25 +260,32 @@ TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistInstructionWithSideEffects) {
HloInstruction::CreateGetTupleElement(scalar_s32, param, 0));
HloInstruction* gte_1 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(scalar_s32, param, 1));
+ HloInstruction* in_token = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(token_shape, param, 2));
+ HloInstruction* out_token = builder.AddInstruction(
+ HloInstruction::CreateOutfeed(scalar_s32, gte_0, in_token, ""));
builder.AddInstruction(
- HloInstruction::CreateOutfeed(scalar_s32, gte_0, ""));
- builder.AddInstruction(HloInstruction::CreateTuple({gte_0, gte_1}));
+ HloInstruction::CreateTuple({gte_0, gte_1, out_token}));
return module().AddEmbeddedComputation(builder.Build());
}();
HloComputation::Builder builder(TestName());
+ auto* scalar_param = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, scalar_s32, "param"));
+ auto* token = builder.AddInstruction(HloInstruction::CreateAfterAll({}));
auto* init_value = builder.AddInstruction(
- HloInstruction::CreateParameter(0, while_shape, "init_value"));
+ HloInstruction::CreateTuple({scalar_param, scalar_param, token}));
auto* while_inst = builder.AddInstruction(HloInstruction::CreateWhile(
while_shape, MakeAlwaysTrueComputation(while_shape, &module()),
while_body, init_value));
-
+ builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(scalar_s32, while_inst, 0));
module().AddEntryComputation(builder.Build());
TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
WhileLoopInvariantCodeMotion{}.Run(&module()));
- EXPECT_FALSE(simplified_loop);
+ ASSERT_FALSE(simplified_loop);
EXPECT_THAT(while_inst->while_body()->instructions(),
Contains(op::Outfeed()));
@@ -287,7 +296,9 @@ TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistBitcastAlone) {
// bitcast either.
auto scalar_s32 = ShapeUtil::MakeShape(S32, {});
auto scalar_f32 = ShapeUtil::MakeShape(F32, {});
- Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32});
+ auto token_shape = ShapeUtil::MakeTokenShape();
+ Shape while_shape =
+ ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32, token_shape});
HloComputation* while_body = [&]() {
HloComputation::Builder builder(TestName() + ".while_body");
@@ -297,21 +308,29 @@ TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistBitcastAlone) {
HloInstruction::CreateGetTupleElement(scalar_s32, param, 0));
HloInstruction* gte_1 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(scalar_s32, param, 1));
+ HloInstruction* in_token = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(token_shape, param, 2));
HloInstruction* bitcast_inst = builder.AddInstruction(
HloInstruction::CreateUnary(scalar_f32, HloOpcode::kBitcast, gte_0));
+ HloInstruction* out_token = builder.AddInstruction(
+ HloInstruction::CreateOutfeed(scalar_f32, bitcast_inst, in_token, ""));
builder.AddInstruction(
- HloInstruction::CreateOutfeed(scalar_f32, bitcast_inst, ""));
- builder.AddInstruction(HloInstruction::CreateTuple({gte_0, gte_1}));
+ HloInstruction::CreateTuple({gte_0, gte_1, out_token}));
return module().AddEmbeddedComputation(builder.Build());
}();
HloComputation::Builder builder(TestName());
+ auto* scalar_param = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, scalar_s32, "param"));
+ auto* token = builder.AddInstruction(HloInstruction::CreateAfterAll({}));
auto* init_value = builder.AddInstruction(
- HloInstruction::CreateParameter(0, while_shape, "init_value"));
+ HloInstruction::CreateTuple({scalar_param, scalar_param, token}));
auto* while_inst = builder.AddInstruction(HloInstruction::CreateWhile(
while_shape, MakeAlwaysTrueComputation(while_shape, &module()),
while_body, init_value));
+ builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(scalar_s32, while_inst, 0));
module().AddEntryComputation(builder.Build());
diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc
index 619e87caa5..0536c99b67 100644
--- a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc
@@ -208,8 +208,9 @@ TEST_F(WhileLoopSimplifierTest, LoopWithInfeedNotSimplified) {
auto* while_op = computation->root_instruction();
ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile);
auto* while_body = while_op->while_body();
- while_body->AddInstruction(
- HloInstruction::CreateInfeed(ShapeUtil::MakeShape(F32, {1}), "config"));
+ auto token = while_body->AddInstruction(HloInstruction::CreateAfterAll({}));
+ while_body->AddInstruction(HloInstruction::CreateInfeed(
+ ShapeUtil::MakeShape(F32, {1}), token, "config"));
EXPECT_FALSE(WhileLoopSimplifier().Run(the_module).ValueOrDie());
}
diff --git a/tensorflow/compiler/xla/service/while_util_test.cc b/tensorflow/compiler/xla/service/while_util_test.cc
index d79d329721..2ccb919acf 100644
--- a/tensorflow/compiler/xla/service/while_util_test.cc
+++ b/tensorflow/compiler/xla/service/while_util_test.cc
@@ -179,7 +179,9 @@ body {
cond {
param.c = (s32[], s32[]) parameter(0)
- ROOT condition = pred[] infeed()
+ token = token[] after-all()
+ infeed = (pred[], token[]) infeed(token)
+ ROOT condition = pred[] get-tuple-element(infeed), index=0
}
ENTRY main {
diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc
index 98c3095499..e827ec5a22 100644
--- a/tensorflow/compiler/xla/shape_util.cc
+++ b/tensorflow/compiler/xla/shape_util.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/index_util.h"
#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/overflow_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
@@ -94,8 +95,11 @@ bool IsArrayPrimitiveType(PrimitiveType primitive_type) {
// Recursive helper for comparing the equality of two shapes. Returns true if
// the shapes are the same. If compare_layouts is true, then layouts must also
// match.
-bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) {
- if (!ShapeUtil::SameElementType(lhs, rhs)) {
+bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts,
+ bool ignore_fp_precision) {
+ if ((ignore_fp_precision &&
+ !ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) ||
+ (!ignore_fp_precision && !ShapeUtil::SameElementType(lhs, rhs))) {
VLOG(3) << "CompareShapes: lhs element type != rhs element type";
return false;
}
@@ -103,7 +107,8 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) {
if (ShapeUtil::IsTuple(lhs)) {
return ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(),
[=](const Shape& l, const Shape& r) {
- return CompareShapes(l, r, compare_layouts);
+ return CompareShapes(l, r, compare_layouts,
+ ignore_fp_precision);
});
} else if (!ShapeUtil::IsArray(lhs)) {
// Non-tuple, non-array tupes such as opaque and token types are trivially
@@ -170,7 +175,8 @@ StatusOr<Shape> MakeShapeWithLayoutInternal(
} // namespace
/* static */ bool ShapeUtil::Equal(const Shape& lhs, const Shape& rhs) {
- bool equal = CompareShapes(lhs, rhs, /*compare_layouts=*/true);
+ bool equal = CompareShapes(lhs, rhs, /*compare_layouts=*/true,
+ /*ignore_fp_precision=*/false);
if (!equal && VLOG_IS_ON(3)) {
VLOG(3) << "ShapeUtil::Equal differ: lhs = " << lhs.ShortDebugString()
<< ", rhs = " << rhs.ShortDebugString();
@@ -179,6 +185,18 @@ StatusOr<Shape> MakeShapeWithLayoutInternal(
return equal;
}
+/* static */ bool ShapeUtil::EqualIgnoringFpPrecision(const Shape& lhs,
+ const Shape& rhs) {
+ bool equal = CompareShapes(lhs, rhs, /*compare_layouts=*/true,
+ /*ignore_fp_precision=*/true);
+ if (!equal && VLOG_IS_ON(3)) {
+ VLOG(3) << "ShapeUtil::EqualIgnoringFpPrecision differ: lhs = "
+ << lhs.ShortDebugString() << ", rhs = " << rhs.ShortDebugString();
+ }
+
+ return equal;
+}
+
/* static */ int64 ShapeUtil::Rank(const Shape& shape) {
CHECK(ShapeUtil::IsArray(shape))
<< "Non-arrays do not have a rank, shape: " << shape;
@@ -665,7 +683,8 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
}
/* static */ bool ShapeUtil::Compatible(const Shape& lhs, const Shape& rhs) {
- return CompareShapes(lhs, rhs, /*compare_layouts=*/false);
+ return CompareShapes(lhs, rhs, /*compare_layouts=*/false,
+ /*ignore_fp_precision=*/false);
}
/* static */ bool ShapeUtil::CompatibleIgnoringElementType(const Shape& lhs,
@@ -867,6 +886,50 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
}
}
+ TF_RETURN_IF_ERROR(ValidateShapeSize(shape));
+ return Status::OK();
+}
+
+/* static */ Status ShapeUtil::ValidateShapeSize(const Shape& shape) {
+ VLOG(3) << "Validating shape size: " << ShapeUtil::HumanString(shape);
+ auto invalid_argument =
+ InvalidArgument("Shape %s size may overflow int64.",
+ ShapeUtil::HumanString(shape).c_str());
+ if (!IsArray(shape)) {
+ return Status::OK();
+ }
+ int64 shape_size;
+ if (LayoutUtil::IsSparseArray(shape)) {
+ shape_size = LayoutUtil::MaxSparseElements(shape.layout());
+ shape_size = MultiplyWithoutOverflow(shape_size, ShapeUtil::Rank(shape));
+ if (shape_size < 0) {
+ return invalid_argument;
+ }
+ shape_size = MultiplyWithoutOverflow(shape_size, sizeof(int64));
+ if (shape_size < 0) {
+ return invalid_argument;
+ }
+ }
+
+ // This is intentionally unconditional: even if the shape is sparse, we want
+ // to verify the densified version has a reasonable size.
+ if (shape.dimensions().empty()) {
+ return Status::OK();
+ }
+ shape_size = 1;
+ for (int64 dim : shape.dimensions()) {
+ shape_size = MultiplyWithoutOverflow(shape_size, dim);
+ if (shape_size < 0) {
+ return invalid_argument;
+ }
+ }
+ shape_size = MultiplyWithoutOverflow(
+ shape_size, ByteSizeOfPrimitiveType(shape.element_type()));
+ if (shape_size < 0) {
+ return invalid_argument;
+ }
+
+ VLOG(3) << "Shape size is valid: " << shape_size;
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h
index 02e4f41505..5ae04451d3 100644
--- a/tensorflow/compiler/xla/shape_util.h
+++ b/tensorflow/compiler/xla/shape_util.h
@@ -280,6 +280,9 @@ class ShapeUtil {
// Returns whether the lhs and rhs shapes are identical protobufs.
static bool Equal(const Shape& lhs, const Shape& rhs);
+ // As Equal, but allow one of lhs and rhs to be F16 while the other is F32.
+ static bool EqualIgnoringFpPrecision(const Shape& lhs, const Shape& rhs);
+
// Returns the rank (number of dimensions) of the given shape.
// Precondition: !IsTuple(shape)
static int64 Rank(const Shape& shape);
@@ -699,6 +702,10 @@ class ShapeUtil {
static size_t Hash(const Shape& shape);
private:
+ // Validates the shape size is sane. This makes sure it's safe to do
+ // calculations in int64 without overflowing.
+ static Status ValidateShapeSize(const Shape& shape);
+
// Validates all of the non-layout properties of the shape -- this is a helper
// used by both the layout-optional and layout-required public method.
static Status ValidateShapeWithOptionalLayoutInternal(const Shape& shape);
diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc
index 606f7492ce..b6f30af381 100644
--- a/tensorflow/compiler/xla/shape_util_test.cc
+++ b/tensorflow/compiler/xla/shape_util_test.cc
@@ -242,6 +242,24 @@ TEST(ShapeUtilTest, IncompatibleDifferentElementShapes) {
EXPECT_FALSE(ShapeUtil::Compatible(shape_1, shape_2));
}
+TEST(ShapeUtilTest, EqualIgnoringFpPrecision) {
+ EXPECT_TRUE(ShapeUtil::EqualIgnoringFpPrecision(
+ ShapeUtil::MakeShapeWithLayout(F32, {4, 3}, {0, 1}),
+ ShapeUtil::MakeShapeWithLayout(F16, {4, 3}, {0, 1})));
+}
+
+TEST(ShapeUtilTest, UnequalIgnoringFpPrecision) {
+ EXPECT_FALSE(ShapeUtil::EqualIgnoringFpPrecision(
+ ShapeUtil::MakeShapeWithLayout(F32, {4, 3}, {0, 1}),
+ ShapeUtil::MakeShapeWithLayout(F16, {3, 4}, {0, 1})));
+ EXPECT_FALSE(ShapeUtil::EqualIgnoringFpPrecision(
+ ShapeUtil::MakeShapeWithLayout(F32, {3, 4}, {0, 1}),
+ ShapeUtil::MakeShapeWithLayout(F16, {3, 4}, {1, 0})));
+ EXPECT_FALSE(ShapeUtil::EqualIgnoringFpPrecision(
+ ShapeUtil::MakeShapeWithLayout(F32, {4, 3}, {0, 1}),
+ ShapeUtil::MakeShapeWithLayout(PRED, {4, 3}, {0, 1})));
+}
+
TEST(ShapeUtilTest, CompatibleTuples) {
Shape tuple1 = ShapeUtil::MakeTupleShape(
{ShapeUtil::MakeShape(F32, {3, 2}), ShapeUtil::MakeShape(PRED, {4, 5})});
diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
index 8ac771ae5a..0aaa990503 100644
--- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
+++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
@@ -282,7 +282,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantS64s) {
std::unique_ptr<GlobalData> rhs_data =
client_->TransferToServer(*rhs_literal).ConsumeValueOrDie();
- auto sub = b.Sub(lhs_param, rhs_param);
+ b.Sub(lhs_param, rhs_param);
std::vector<int64> expected(lhs.size());
for (int64 i = 0; i < lhs.size(); ++i) {
@@ -2456,7 +2456,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Eq) {
// comparison.
auto cmp_dim_0 = builder.Eq(v, m, /*broadcast_dimensions=*/{1});
auto cmp_dim_1 = builder.Eq(v, m, /*broadcast_dimensions=*/{0});
- auto result = builder.Tuple({cmp_dim_0, cmp_dim_1});
+ builder.Tuple({cmp_dim_0, cmp_dim_1});
auto expected = Literal::MakeTuple(
{Literal::CreateR2<bool>({{true, true}, {true, false}}).get(),
diff --git a/tensorflow/compiler/xla/tests/batch_normalization_test.cc b/tensorflow/compiler/xla/tests/batch_normalization_test.cc
index f3dac75a44..3489514fe8 100644
--- a/tensorflow/compiler/xla/tests/batch_normalization_test.cc
+++ b/tensorflow/compiler/xla/tests/batch_normalization_test.cc
@@ -252,7 +252,7 @@ XLA_TEST_P(BatchNormalizationTest, BasicTraining) {
ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1));
}
-XLA_TEST_P(BatchNormalizationTest, BasicTrainingOnSublane) {
+XLA_TEST_P(BatchNormalizationTest, BasicTrainingOnDimension2) {
const int kFeatureIndex = 2;
XlaBuilder builder(TestName());
diff --git a/tensorflow/compiler/xla/tests/bfloat16_test.cc b/tensorflow/compiler/xla/tests/bfloat16_test.cc
index ca337e7884..9d4f723ed6 100644
--- a/tensorflow/compiler/xla/tests/bfloat16_test.cc
+++ b/tensorflow/compiler/xla/tests/bfloat16_test.cc
@@ -92,8 +92,8 @@ XLA_TEST_F(Bfloat16Test, BatchNormTraining) {
auto offset = builder.ConstantR1<bfloat16>(
{static_cast<bfloat16>(1.0f), static_cast<bfloat16>(2.0f)});
- auto tuple = builder.BatchNormTraining(operand, scale, offset,
- /*epsilon=*/0.001, kFeatureIndex);
+ builder.BatchNormTraining(operand, scale, offset, /*epsilon=*/0.001,
+ kFeatureIndex);
auto expected = Literal::MakeTuple(
{Literal::CreateR4<bfloat16>(
diff --git a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc
index 3a0f51fc66..1a7f188346 100644
--- a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc
+++ b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc
@@ -262,7 +262,7 @@ XLA_TEST_P(BroadcastR3ImplicitTest, Doit) {
auto r3_implicit_parameter = builder.Parameter(0, r3_implicit_shape, "input");
auto r3_parameter = builder.Parameter(1, r3_shape, "input");
- XlaOp op = BuildBinOp(spec.op, r3_implicit_parameter, r3_parameter, &builder);
+ BuildBinOp(spec.op, r3_implicit_parameter, r3_parameter, &builder);
Array3D<float> expected_array(spec.output_bounds[0], spec.output_bounds[1],
spec.output_bounds[2]);
@@ -516,7 +516,7 @@ XLA_TEST_P(BroadcastR2ImplicitTest, Doit) {
XlaOp op1 =
BuildBinOp(spec.op1, r2_implicit_parameter1, r2_parameter, &builder);
- XlaOp op2 = BuildBinOp(spec.op2, op1, r2_implicit_parameter2, &builder);
+ BuildBinOp(spec.op2, op1, r2_implicit_parameter2, &builder);
Array2D<float> expected_array(spec.output_bounds[0], spec.output_bounds[1]);
diff --git a/tensorflow/compiler/xla/tests/check_execution_arity_test.cc b/tensorflow/compiler/xla/tests/check_execution_arity_test.cc
index 660ff0cad5..7c73e80d69 100644
--- a/tensorflow/compiler/xla/tests/check_execution_arity_test.cc
+++ b/tensorflow/compiler/xla/tests/check_execution_arity_test.cc
@@ -40,7 +40,7 @@ TEST_F(CheckExecutionArityTest, TwoParamComputationNumArguments) {
auto p0 = builder.Parameter(0, param_literal->shape(), "param0");
auto p1 = builder.Parameter(1, param_literal->shape(), "param1");
- auto add = builder.Add(p0, p1);
+ builder.Add(p0, p1);
auto param0_data =
client_->TransferToServer(*param_literal).ConsumeValueOrDie();
@@ -79,7 +79,7 @@ XLA_TEST_F(CheckExecutionArityTest, CheckArgumentShapes) {
auto p0 = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param0");
auto p1 = builder.Parameter(1, ShapeUtil::MakeShape(F32, {4}), "param1");
- auto add = builder.Mul(p0, p1);
+ builder.Mul(p0, p1);
auto computation_status = builder.Build();
ASSERT_IS_OK(computation_status.status());
diff --git a/tensorflow/compiler/xla/tests/constants_test.cc b/tensorflow/compiler/xla/tests/constants_test.cc
index 916ffadbc7..1b929d7d2f 100644
--- a/tensorflow/compiler/xla/tests/constants_test.cc
+++ b/tensorflow/compiler/xla/tests/constants_test.cc
@@ -109,7 +109,7 @@ TEST_F(ConstantsTest, Small_2x2) {
TEST_F(ConstantsTest, Empty_3x0x2) {
XlaBuilder builder(TestName());
- auto constant = builder.ConstantLiteral(
+ builder.ConstantLiteral(
*Literal::CreateR3FromArray3D<float>(Array3D<float>(3, 0, 2)));
ComputeAndCompareR3<float>(&builder, Array3D<float>(3, 0, 2), {});
@@ -125,8 +125,7 @@ TEST_F(ConstantsTest, Small_2x2x2) {
{{5.f, 6.f}, // y0
{7.f, 8.f}}, // y1
});
- auto constant =
- builder.ConstantLiteral(*Literal::CreateR3FromArray3D<float>(array3d));
+ builder.ConstantLiteral(*Literal::CreateR3FromArray3D<float>(array3d));
ComputeAndCompareR3<float>(&builder, array3d, {});
}
diff --git a/tensorflow/compiler/xla/tests/convert_test.cc b/tensorflow/compiler/xla/tests/convert_test.cc
index 3a885b4389..ba5ba3a82f 100644
--- a/tensorflow/compiler/xla/tests/convert_test.cc
+++ b/tensorflow/compiler/xla/tests/convert_test.cc
@@ -478,7 +478,7 @@ XLA_TEST_F(ConvertTest, ConvertBF16F32) {
xla::XlaOp all_bfloats_bf16 = builder.ConstantR1<bfloat16>(all_bfloats);
xla::XlaOp all_bfloats_f32 =
builder.ConvertElementType(all_bfloats_bf16, F32);
- xla::XlaOp all_bfloats_u32 = builder.BitcastConvertType(all_bfloats_f32, U32);
+ builder.BitcastConvertType(all_bfloats_f32, U32);
ComputeAndCompareR1<uint32>(&builder, expected, {});
}
diff --git a/tensorflow/compiler/xla/tests/copy_test.cc b/tensorflow/compiler/xla/tests/copy_test.cc
index 2b3390ca98..b20499f252 100644
--- a/tensorflow/compiler/xla/tests/copy_test.cc
+++ b/tensorflow/compiler/xla/tests/copy_test.cc
@@ -248,7 +248,7 @@ XLA_TEST_F(CopyOpClientTest, Copy0x0) {
auto empty = Literal::CreateFromShape(in_shape);
XlaBuilder builder(TestName());
- auto param0 = builder.Parameter(0, in_shape, "input");
+ builder.Parameter(0, in_shape, "input");
auto input_data = client_->TransferToServer(*empty).ConsumeValueOrDie();
auto actual = ExecuteAndTransfer(&builder, {input_data.get()}, &out_shape)
diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc
index 0fd846cef8..6a2c581aec 100644
--- a/tensorflow/compiler/xla/tests/dot_operation_test.cc
+++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc
@@ -89,7 +89,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, ZeroElementVectorDot) {
auto lhs = builder.ConstantR1<T>({});
auto rhs = builder.ConstantR1<T>({});
- auto result = builder.Dot(lhs, rhs);
+ builder.Dot(lhs, rhs);
this->template ComputeAndCompareR0<T>(&builder, static_cast<T>(0.0), {},
this->error_spec_);
@@ -104,7 +104,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, TrivialMatrixVectorDot) {
XlaBuilder builder(this->TestName());
auto lhs = builder.ConstantR2FromArray2D<T>({{3.0f, 4.0f}});
auto rhs = builder.ConstantFromArray<T>({3.0f, 4.0f});
- auto result = builder.Dot(lhs, rhs);
+ builder.Dot(lhs, rhs);
this->template ComputeAndCompareR1<T>(&builder, {static_cast<T>(25.0f)}, {},
this->error_spec_);
@@ -115,7 +115,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, OneElementVectorDot) {
XlaBuilder builder(this->TestName());
auto lhs = builder.ConstantR1<T>({static_cast<T>(2.0f)});
auto rhs = builder.ConstantR1<T>({static_cast<T>(3.0f)});
- auto result = builder.Dot(lhs, rhs);
+ builder.Dot(lhs, rhs);
this->template ComputeAndCompareR0<T>(&builder, static_cast<T>(6.0f), {},
this->error_spec_);
@@ -126,7 +126,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, VectorDot) {
XlaBuilder builder(this->TestName());
auto lhs = builder.ConstantFromArray<T>({1.0f, 2.5f, 42.0f});
auto rhs = builder.ConstantFromArray<T>({11.0f, -1.0f, 0.5f});
- auto result = builder.Dot(lhs, rhs);
+ builder.Dot(lhs, rhs);
this->template ComputeAndCompareR0<T>(&builder, static_cast<T>(29.5f), {},
this->error_spec_);
@@ -141,7 +141,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_0x2_2x0) {
XlaBuilder builder(this->TestName());
auto lhs = builder.ConstantR2FromArray2D<T>(Array2D<T>(0, 2));
auto rhs = builder.ConstantR2FromArray2D<T>(Array2D<T>(2, 0));
- auto result = builder.Dot(lhs, rhs);
+ builder.Dot(lhs, rhs);
this->template ComputeAndCompareR2<T>(&builder, Array2D<T>(0, 0), {},
this->error_spec_);
@@ -153,7 +153,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_0x2_2x3) {
auto lhs = builder.ConstantR2FromArray2D<T>(Array2D<T>(0, 2));
auto rhs = builder.ConstantR2FromArray2D<T>(
{{7.0f, 8.0f, 9.0f}, {42.0f, 77.0f, 101.0f}});
- auto result = builder.Dot(lhs, rhs);
+ builder.Dot(lhs, rhs);
this->template ComputeAndCompareR2<T>(&builder, Array2D<T>(0, 3), {},
this->error_spec_);
@@ -165,7 +165,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_3x2_2x0) {
auto lhs = builder.ConstantR2FromArray2D<T>(
{{7.0f, 8.0f}, {9.0f, 42.0f}, {77.0f, 101.0f}});
auto rhs = builder.ConstantR2FromArray2D<T>(Array2D<T>(2, 0));
- auto result = builder.Dot(lhs, rhs);
+ builder.Dot(lhs, rhs);
this->template ComputeAndCompareR2<T>(&builder, Array2D<T>(3, 0), {},
this->error_spec_);
@@ -176,7 +176,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_2x0_0x2) {
XlaBuilder builder(this->TestName());
auto lhs = builder.ConstantR2FromArray2D<T>(Array2D<T>(2, 0));
auto rhs = builder.ConstantR2FromArray2D<T>(Array2D<T>(0, 2));
- auto result = builder.Dot(lhs, rhs);
+ builder.Dot(lhs, rhs);
this->template ComputeAndCompareR2<T>(
&builder, Array2D<T>(2, 2, static_cast<T>(0.0f)), {}, this->error_spec_);
@@ -190,7 +190,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, FusedDot) {
auto param1 =
builder.Parameter(1, ShapeUtil::MakeShapeWithType<T>({4, 1}), "arg1");
auto exp0 = builder.Exp(param0);
- auto result = builder.Dot(exp0, param1);
+ builder.Dot(exp0, param1);
auto lhs_handle =
this->client_
@@ -231,7 +231,7 @@ class SquareMatrixDot : public DotOperationTest {
.ConsumeValueOrDie();
XlaBuilder builder(TestName());
auto prim_type = primitive_util::NativeToPrimitiveType<T>();
- auto result = builder.Dot(
+ builder.Dot(
builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {2, 2}), "lhs"),
builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {2, 2}), "rhs"));
@@ -492,7 +492,7 @@ class NonsquareMatrixDot : public DotOperationTest {
XlaBuilder builder(TestName());
auto prim_type = primitive_util::NativeToPrimitiveType<T>();
- auto result = builder.Dot(
+ builder.Dot(
builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {2, 3}), "lhs"),
builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {3, 2}), "rhs"));
@@ -524,7 +524,7 @@ XLA_TEST_F(DotOperationTest, MatrixVectorC64) {
XlaBuilder builder(TestName());
auto prim_type = primitive_util::NativeToPrimitiveType<complex64>();
- auto result = builder.Dot(
+ builder.Dot(
builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {1, 4}), "lhs"),
builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {4, 2}), "rhs"));
@@ -626,7 +626,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, GeneralMatMul) {
dnums.add_lhs_batch_dimensions(0);
dnums.add_rhs_batch_dimensions(0);
- auto out = builder.DotGeneral(x, y, dnums);
+ builder.DotGeneral(x, y, dnums);
auto x_data =
this->client_
@@ -690,7 +690,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, TransposeFolding) {
if (transpose_rhs) {
rhs_arg = builder.Transpose(rhs_arg, {1, 0});
}
- auto result = builder.Dot(lhs_arg, rhs_arg);
+ builder.Dot(lhs_arg, rhs_arg);
Array2D<T> expected({{26.0f, 0.0f}, {-12.0f, 10.0f}});
VLOG(1) << "TestTransposeFolding " << transpose_lhs << " "
@@ -720,8 +720,8 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64,
"rhs_arg_1");
auto rhs_arg_2 = builder.Parameter(2, ShapeUtil::MakeShape(prim_type, {1, 2}),
"rhs_arg_2");
- auto result = builder.Dot(
- lhs_constant, builder.ConcatInDim({rhs_arg_0, rhs_arg_1, rhs_arg_2}, 0));
+ builder.Dot(lhs_constant,
+ builder.ConcatInDim({rhs_arg_0, rhs_arg_1, rhs_arg_2}, 0));
std::unique_ptr<Array2D<T>> arg_0_value_array(
new Array2D<T>({{1.0f, 2.0f}, {3.0f, 4.0f}}));
@@ -768,8 +768,8 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64,
"lhs_arg_1");
auto lhs_arg_2 = builder.Parameter(2, ShapeUtil::MakeShapeWithType<T>({2, 1}),
"lhs_arg_2");
- auto result = builder.Dot(
- builder.ConcatInDim({lhs_arg_0, lhs_arg_1, lhs_arg_2}, 1), rhs_constant);
+ builder.Dot(builder.ConcatInDim({lhs_arg_0, lhs_arg_1, lhs_arg_2}, 1),
+ rhs_constant);
std::unique_ptr<Array2D<T>> arg_0_value_array(
new Array2D<T>({{1.0f, 2.0f}, {3.0f, 4.0f}}));
@@ -820,7 +820,7 @@ XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstRHSClassicMM) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
- auto result = builder.DotGeneral(dynamic_slice, rhs_constant, dot_dnums);
+ builder.DotGeneral(dynamic_slice, rhs_constant, dot_dnums);
Array2D<float> expected({{96.0, 105.0, 114.0}});
ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
@@ -848,7 +848,7 @@ XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSClassicMM) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
- auto result = builder.DotGeneral(lhs_constant, dynamic_slice, dot_dnums);
+ builder.DotGeneral(lhs_constant, dynamic_slice, dot_dnums);
Array2D<float> expected({{105.0}, {105.0}});
ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
@@ -856,8 +856,8 @@ XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSClassicMM) {
// TODO (b/69062148) Enable when Dot implements general contracting dimensions.
XLA_TEST_F(DotOperationTest,
- DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER(
- DotOfGatherOptimizationWithConstRHSReverseMM)))) {
+ DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER(
+ DotOfGatherOptimizationWithConstRHSReverseMM)))) {
std::unique_ptr<Array2D<float>> constant_lhs_array(
new Array2D<float>({{1.0, 2.0, 3.0},
{4.0, 5.0, 6.0},
@@ -879,7 +879,7 @@ XLA_TEST_F(DotOperationTest,
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(0);
dot_dnums.add_rhs_contracting_dimensions(1);
- auto result = builder.DotGeneral(dynamic_slice, rhs_constant, dot_dnums);
+ builder.DotGeneral(dynamic_slice, rhs_constant, dot_dnums);
Array2D<float> expected({{105.0, 105.0}});
ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
@@ -887,8 +887,8 @@ XLA_TEST_F(DotOperationTest,
// TODO (b/69062148) Enable when Dot implements general contracting dimensions.
XLA_TEST_F(DotOperationTest,
- DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER(
- DotOfGatherOptimizationWithConstLHSReverseMM)))) {
+ DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER(
+ DotOfGatherOptimizationWithConstLHSReverseMM)))) {
std::unique_ptr<Array2D<float>> constant_lhs_array(
new Array2D<float>({{1.0, 2.0, 3.0},
{4.0, 5.0, 6.0},
@@ -910,7 +910,7 @@ XLA_TEST_F(DotOperationTest,
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(0);
dot_dnums.add_rhs_contracting_dimensions(1);
- auto result = builder.DotGeneral(lhs_constant, dynamic_slice, dot_dnums);
+ builder.DotGeneral(lhs_constant, dynamic_slice, dot_dnums);
Array2D<float> expected({{96.0}, {105.0}, {114.0}});
ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
@@ -918,8 +918,8 @@ XLA_TEST_F(DotOperationTest,
// TODO (b/69062148) Enable when Dot implements general contracting dimensions.
XLA_TEST_F(DotOperationTest,
- DISABLED_ON_CPU(DISABLED_ON_GPU(
- DISABLED_ON_INTERPRETER(DotOfGatherOptimizationWithConstRHSRows)))) {
+ DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER(
+ DotOfGatherOptimizationWithConstRHSRows)))) {
std::unique_ptr<Array2D<float>> constant_lhs_array(
new Array2D<float>({{1.0, 2.0},
{3.0, 4.0},
@@ -946,7 +946,7 @@ XLA_TEST_F(DotOperationTest,
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(0);
dot_dnums.add_rhs_contracting_dimensions(0);
- auto result = builder.DotGeneral(dynamic_slice, rhs_constant, dot_dnums);
+ builder.DotGeneral(dynamic_slice, rhs_constant, dot_dnums);
Array2D<float> expected({{126.0, 129.0, 132.0}});
ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
@@ -954,8 +954,8 @@ XLA_TEST_F(DotOperationTest,
// TODO (b/69062148) Enable when Dot implements general contracting dimensions.
XLA_TEST_F(DotOperationTest,
- DISABLED_ON_CPU(DISABLED_ON_GPU(
- DISABLED_ON_INTERPRETER(DotOfGatherOptimizationWithConstLHSRows)))) {
+ DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER(
+ DotOfGatherOptimizationWithConstLHSRows)))) {
std::unique_ptr<Array2D<float>> constant_lhs_array(
new Array2D<float>({{1.0, 2.0},
{3.0, 4.0},
@@ -982,7 +982,7 @@ XLA_TEST_F(DotOperationTest,
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(0);
dot_dnums.add_rhs_contracting_dimensions(0);
- auto result = builder.DotGeneral(lhs_constant, dynamic_slice, dot_dnums);
+ builder.DotGeneral(lhs_constant, dynamic_slice, dot_dnums);
Array2D<float> expected({{129.0}, {129.0}});
ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
@@ -990,8 +990,8 @@ XLA_TEST_F(DotOperationTest,
// TODO (b/69062148) Enable when Dot implements general contracting dimensions.
XLA_TEST_F(DotOperationTest,
- DISABLED_ON_CPU(DISABLED_ON_GPU(
- DISABLED_ON_INTERPRETER(DotOfGatherOptimizationWithConstRHSCols)))) {
+ DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER(
+ DotOfGatherOptimizationWithConstRHSCols)))) {
std::unique_ptr<Array2D<float>> constant_lhs_array(new Array2D<float>(
{{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}}));
std::unique_ptr<Array2D<float>> constant_rhs_array(
@@ -1010,7 +1010,7 @@ XLA_TEST_F(DotOperationTest,
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(1);
- auto result = builder.DotGeneral(dynamic_slice, rhs_constant, dot_dnums);
+ builder.DotGeneral(dynamic_slice, rhs_constant, dot_dnums);
Array2D<float> expected({{56.0, 168.0, 91.0}});
ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
@@ -1018,8 +1018,8 @@ XLA_TEST_F(DotOperationTest,
// TODO (b/69062148) Enable when Dot implements general contracting dimensions.
XLA_TEST_F(DotOperationTest,
- DISABLED_ON_CPU(DISABLED_ON_GPU(
- DISABLED_ON_INTERPRETER(DotOfGatherOptimizationWithConstLHSCols)))) {
+ DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER(
+ DotOfGatherOptimizationWithConstLHSCols)))) {
std::unique_ptr<Array2D<float>> constant_lhs_array(new Array2D<float>(
{{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}}));
std::unique_ptr<Array2D<float>> constant_rhs_array(
@@ -1038,7 +1038,7 @@ XLA_TEST_F(DotOperationTest,
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(1);
- auto result = builder.DotGeneral(lhs_constant, dynamic_slice, dot_dnums);
+ builder.DotGeneral(lhs_constant, dynamic_slice, dot_dnums);
Array2D<float> expected({{168.0}, {168.0}});
ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
diff --git a/tensorflow/compiler/xla/tests/fusion_test.cc b/tensorflow/compiler/xla/tests/fusion_test.cc
index e6f79b5ac5..45a5cdc896 100644
--- a/tensorflow/compiler/xla/tests/fusion_test.cc
+++ b/tensorflow/compiler/xla/tests/fusion_test.cc
@@ -557,8 +557,7 @@ XLA_TEST_F(FusionTest, ReshapeNegate) {
*ExecuteAndTransfer(std::move(hlo_module), {})));
}
-// TODO(b/64070202): Investigate failure.
-XLA_TEST_F(FusionTest, DISABLED_ON_GPU(TransposeNegate)) {
+XLA_TEST_F(FusionTest, TransposeNegate) {
auto builder = HloComputation::Builder(TestName());
auto hlo_module = CreateNewModule();
auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
@@ -800,7 +799,7 @@ void BM_ParallelFusion(int num_iters) {
auto param2 = builder.Parameter(2, shape2, "param2");
auto x = builder.Mul(param0, param1);
- auto y = builder.Add(x, param2);
+ builder.Add(x, param2);
auto computation = builder.Build().ConsumeValueOrDie();
// Transfer literals to device.
diff --git a/tensorflow/compiler/xla/tests/local_client_execute_test.cc b/tensorflow/compiler/xla/tests/local_client_execute_test.cc
index 5a70c2a9ae..77f9c33ee1 100644
--- a/tensorflow/compiler/xla/tests/local_client_execute_test.cc
+++ b/tensorflow/compiler/xla/tests/local_client_execute_test.cc
@@ -54,7 +54,7 @@ class LocalClientExecuteTest : public LocalClientTestBase {
XLA_TEST_F(LocalClientExecuteTest, Constant) {
XlaBuilder builder(TestName());
- auto y = builder.ConstantR0<float>(123.0f);
+ builder.ConstantR0<float>(123.0f);
ScopedShapedBuffer result =
ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {});
@@ -701,7 +701,7 @@ XLA_TEST_F(LocalClientExecuteTest,
TestAllocator allocator(wrong_platform);
XlaBuilder builder(TestName());
- auto y = builder.ConstantR0<float>(123.0f);
+ builder.ConstantR0<float>(123.0f);
auto execute_status = ExecuteLocally(
builder.Build().ValueOrDie(), {}, DefaultExecutableBuildOptions(),
@@ -841,6 +841,31 @@ XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion64bit) {
Literal::CreateR0<int64>(123456789000LL).get()}));
}
+XLA_TEST_F(LocalClientExecuteTest, InfeedTest) {
+ XlaBuilder builder(TestName());
+ const Shape shape = ShapeUtil::MakeShape(F32, {3});
+ auto in = builder.Infeed(shape);
+ auto constant = builder.ConstantR1<float>({1.0f, 2.0f, 3.0f});
+ builder.Add(in, constant);
+
+ std::unique_ptr<Literal> result;
+ std::unique_ptr<tensorflow::Thread> thread(
+ tensorflow::Env::Default()->StartThread(
+ tensorflow::ThreadOptions(), "execute_thread", [&] {
+ result = ShapedBufferToLiteral(ExecuteLocallyOrDie(
+ builder.Build().ValueOrDie(), /*arguments=*/{}));
+ }));
+
+ ASSERT_IS_OK(local_client_->TransferToInfeedLocal(
+ *Literal::CreateR1<float>({-5.0, 123.0, 42.0}),
+ local_client_->default_device_ordinal()));
+
+ // Join the thread.
+ thread.reset();
+
+ LiteralTestUtil::ExpectR1Equal<float>({-4.0, 125.0, 45.0}, *result);
+}
+
// TODO(b/34359662): Support infeed/outfeed on GPU and CPU parallel.
// 2017-10-18.
XLA_TEST_F(LocalClientExecuteTest, DISABLED_ON_GPU(InfeedOutfeedTest)) {
diff --git a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc
index 27fd36e06a..c1f1c45c8c 100644
--- a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc
+++ b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc
@@ -89,7 +89,7 @@ XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, MapTwoByTwo) {
{1.0f, 0.0f}, // row 0
{-1.0f, 0.5f}, // row 1
});
- auto map = builder.Map({data}, add_half, {0, 1});
+ builder.Map({data}, add_half, {0, 1});
std::unique_ptr<Literal> expected =
Literal::CreateR2FromArray2D<T>({{1.5f, 0.5f}, // row 0
@@ -108,7 +108,7 @@ XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, MaxTwoByTwoValues) {
{5.0f, 6.0f}, // row 0
{1.0f, -8.0f}, // row 1
});
- auto max = builder.Max(lhs, rhs);
+ builder.Max(lhs, rhs);
std::unique_ptr<Literal> expected =
Literal::CreateR2FromArray2D<T>({{7.0f, 6.0f}, // row 0
@@ -139,7 +139,7 @@ class TestLinspaceMaxParametric
tensorflow::strings::Printf("max_%lldx%lld_linspace", rows, cols));
auto lhs = builder.ConstantR2FromArray2D<T>(*alhs);
auto rhs = builder.ConstantR2FromArray2D<T>(*arhs);
- auto max = builder.Max(lhs, rhs);
+ builder.Max(lhs, rhs);
Array2D<T> expected(rows, cols);
for (int row = 0; row < rows; ++row) {
diff --git a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc
index a42a19af15..6597748c8d 100644
--- a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc
+++ b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc
@@ -454,7 +454,8 @@ XLA_TEST_F(MultiOutputFusionTest,
r1 = f32[2]{0} reduce(p0, c0), dimensions={0,2}, to_apply=Add
mul = f32[2,2,2]{2,1,0} multiply(p0, p0)
c1 = f32[] constant(5)
- mul2 = f32[2,2,2]{2,1,0} multiply(p0, c1)
+ b1 = f32[2,2,2]{2,1,0} broadcast(c1), dimensions={}
+ mul2 = f32[2,2,2]{2,1,0} multiply(p0, b1)
ROOT tuple = (f32[2]{0}, f32[2,2,2]{2,1,0}, f32[2,2,2]{2,1,0})
tuple(r1, mul, mul2)
}
diff --git a/tensorflow/compiler/xla/tests/params_test.cc b/tensorflow/compiler/xla/tests/params_test.cc
index 838f1b4e2f..3c3c865673 100644
--- a/tensorflow/compiler/xla/tests/params_test.cc
+++ b/tensorflow/compiler/xla/tests/params_test.cc
@@ -46,7 +46,7 @@ XLA_TEST_F(ParamsTest, ConstantR0F32Param) {
std::unique_ptr<GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
- auto p = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param0");
+ builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param0");
ComputeAndCompareR0<float>(&builder, 3.14159f, {param0_data.get()},
ErrorSpec(0.0001f));
@@ -58,7 +58,7 @@ XLA_TEST_F(ParamsTest, ConstantR1S0F32Param) {
std::unique_ptr<GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
- auto p = builder.Parameter(0, ShapeUtil::MakeShape(F32, {0}), "param0");
+ builder.Parameter(0, ShapeUtil::MakeShape(F32, {0}), "param0");
ComputeAndCompareR1<float>(&builder, {}, {param0_data.get()},
ErrorSpec(0.01f));
@@ -71,7 +71,7 @@ XLA_TEST_F(ParamsTest, ConstantR1S2F32Param) {
std::unique_ptr<GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
- auto p = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2}), "param0");
+ builder.Parameter(0, ShapeUtil::MakeShape(F32, {2}), "param0");
ComputeAndCompareR1<float>(&builder, {3.14f, -100.25f}, {param0_data.get()},
ErrorSpec(0.01f));
@@ -84,7 +84,7 @@ XLA_TEST_F(ParamsTest, ConstantR1U8Param) {
std::unique_ptr<GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
- auto p = builder.Parameter(
+ builder.Parameter(
0, ShapeUtil::MakeShape(U8, {static_cast<int64>(str.size())}), "param0");
ComputeAndCompareR1U8(&builder, str, {param0_data.get()});
@@ -97,7 +97,7 @@ XLA_TEST_F(ParamsTest, ConstantR2_3x0_F32Param) {
std::unique_ptr<GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
- auto p = builder.Parameter(0, ShapeUtil::MakeShape(F32, {3, 0}), "param0");
+ builder.Parameter(0, ShapeUtil::MakeShape(F32, {3, 0}), "param0");
ComputeAndCompareR2<float>(&builder, Array2D<float>(3, 0),
{param0_data.get()}, ErrorSpec(0.01f));
@@ -110,7 +110,7 @@ XLA_TEST_F(ParamsTest, ConstantR2F32Param) {
std::unique_ptr<GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
- auto p = builder.Parameter(0, ShapeUtil::MakeShape(F32, {3, 2}), "param0");
+ builder.Parameter(0, ShapeUtil::MakeShape(F32, {3, 2}), "param0");
Array2D<float> expected_array(
{{3.14f, -100.25f}, {7e8f, 7e-9f}, {30.3f, -100.0f}});
@@ -142,7 +142,7 @@ XLA_TEST_F(ParamsTest, TwoParameters) {
// parameters to test that the parameters are not swapped.
//
// {11, 22} * {10, 20} = {110, 440}
- auto prod = builder.Mul(sum, param1);
+ builder.Mul(sum, param1);
ComputeAndCompareR1<float>(&builder, {110, 440},
{param0_data.get(), param1_data.get()},
@@ -157,7 +157,7 @@ XLA_TEST_F(ParamsTest, MissingParameter) {
client_->TransferToServer(*literal).ConsumeValueOrDie();
XlaBuilder builder(TestName());
- auto p = builder.Parameter(2, ShapeUtil::MakeShape(F32, {}), "param2");
+ builder.Parameter(2, ShapeUtil::MakeShape(F32, {}), "param2");
auto computation_status = builder.Build();
ASSERT_NE(computation_status.status(), Status::OK());
@@ -169,12 +169,12 @@ XLA_TEST_F(ParamsTest, UnusedParameter) {
std::unique_ptr<Literal> literal0 = Literal::CreateR1<float>({1, 2});
std::unique_ptr<GlobalData> param0_data =
client_->TransferToServer(*literal0).ConsumeValueOrDie();
- auto param0 = builder.Parameter(0, literal0->shape(), "param0");
+ builder.Parameter(0, literal0->shape(), "param0");
std::unique_ptr<Literal> literal1 = Literal::CreateR1<float>({10, 20});
std::unique_ptr<GlobalData> param1_data =
client_->TransferToServer(*literal1).ConsumeValueOrDie();
- auto param1 = builder.Parameter(1, literal1->shape(), "param1");
+ builder.Parameter(1, literal1->shape(), "param1");
ComputeAndCompareR1<float>(&builder, {10, 20},
{param0_data.get(), param1_data.get()},
@@ -478,7 +478,8 @@ XLA_TEST_F(ParamsTest, R2_2x2_Layout_10) {
XLA_TEST_F(ParamsTest, R2_2x2_TryToPassReverseLayoutToParameter) {
std::unique_ptr<Literal> literal = Literal::CreateR2<float>({
- {1, 3}, {2, 4},
+ {1, 3},
+ {2, 4},
});
const Shape original = literal->shape();
{
diff --git a/tensorflow/compiler/xla/tests/pred_test.cc b/tensorflow/compiler/xla/tests/pred_test.cc
index 77159efb26..f405bb3d49 100644
--- a/tensorflow/compiler/xla/tests/pred_test.cc
+++ b/tensorflow/compiler/xla/tests/pred_test.cc
@@ -36,20 +36,20 @@ class PredTest : public ClientLibraryTestBase {
XlaBuilder builder(TestName());
XlaOp lhs_op = builder.ConstantR0<bool>(lhs);
XlaOp rhs_op = builder.ConstantR0<bool>(rhs);
- XlaOp result = (builder.*op)(lhs_op, rhs_op, {});
+ (builder.*op)(lhs_op, rhs_op, {});
ComputeAndCompareR0<bool>(&builder, expected, {});
}
};
TEST_F(PredTest, ConstantR0PredTrue) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR0<bool>(true);
+ builder.ConstantR0<bool>(true);
ComputeAndCompareR0<bool>(&builder, true, {});
}
TEST_F(PredTest, ConstantR0PredFalse) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR0<bool>(false);
+ builder.ConstantR0<bool>(false);
ComputeAndCompareR0<bool>(&builder, false, {});
}
@@ -79,14 +79,13 @@ TEST_F(PredTest, ConstantR0PredCompareGt) {
TEST_F(PredTest, ConstantR1Pred) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<bool>({true, false, false, true});
+ builder.ConstantR1<bool>({true, false, false, true});
ComputeAndCompareR1<bool>(&builder, {true, false, false, true}, {});
}
TEST_F(PredTest, ConstantR2Pred) {
XlaBuilder builder(TestName());
- auto a =
- builder.ConstantR2<bool>({{false, true, true}, {true, false, false}});
+ builder.ConstantR2<bool>({{false, true, true}, {true, false, false}});
const string expected = R"(pred[2,3] {
{ 011 },
{ 100 }
@@ -97,21 +96,21 @@ TEST_F(PredTest, ConstantR2Pred) {
TEST_F(PredTest, AnyR1True) {
XlaBuilder builder(TestName());
auto a = builder.ConstantR1<bool>({true, false});
- TF_ASSERT_OK(Any(a, &builder).status());
+ Any(a);
ComputeAndCompareR0<bool>(&builder, true, {});
}
TEST_F(PredTest, AnyR1False) {
XlaBuilder builder(TestName());
auto a = builder.ConstantR1<bool>({false, false});
- TF_ASSERT_OK(Any(a, &builder).status());
+ Any(a);
ComputeAndCompareR0<bool>(&builder, false, {});
}
TEST_F(PredTest, AnyR1VacuouslyFalse) {
XlaBuilder builder(TestName());
auto a = builder.ConstantR1<bool>({});
- TF_ASSERT_OK(Any(a, &builder).status());
+ Any(a);
ComputeAndCompareR0<bool>(&builder, false, {});
}
@@ -122,7 +121,7 @@ TEST_F(PredTest, AnyR2True) {
{false, false, false},
{false, false, true},
});
- TF_ASSERT_OK(Any(a, &builder).status());
+ Any(a);
ComputeAndCompareR0<bool>(&builder, true, {});
}
@@ -133,7 +132,7 @@ TEST_F(PredTest, AnyR2False) {
{false, false, false},
{false, false, false},
});
- TF_ASSERT_OK(Any(a, &builder).status());
+ Any(a);
ComputeAndCompareR0<bool>(&builder, false, {});
}
diff --git a/tensorflow/compiler/xla/tests/prng_test.cc b/tensorflow/compiler/xla/tests/prng_test.cc
index 1a2de6937c..ba58feea8e 100644
--- a/tensorflow/compiler/xla/tests/prng_test.cc
+++ b/tensorflow/compiler/xla/tests/prng_test.cc
@@ -294,9 +294,9 @@ XLA_TEST_F(PrngTest, RngUniformCrash) {
XlaBuilder builder(TestName());
// This used to crash XLA during LLVM IR generation for CPUs.
- auto rng_uniform = builder.RngUniform(builder.ConstantR0<int32>(0),
- builder.ConstantR0<int32>(1000 * 1000),
- ShapeUtil::MakeShape(S32, {}));
+ builder.RngUniform(builder.ConstantR0<int32>(0),
+ builder.ConstantR0<int32>(1000 * 1000),
+ ShapeUtil::MakeShape(S32, {}));
SetSeed(0);
ExecuteAndTransfer(&builder, /*arguments=*/{}).ConsumeValueOrDie();
}
diff --git a/tensorflow/compiler/xla/tests/reduce_test.cc b/tensorflow/compiler/xla/tests/reduce_test.cc
index d671d40456..579be77b24 100644
--- a/tensorflow/compiler/xla/tests/reduce_test.cc
+++ b/tensorflow/compiler/xla/tests/reduce_test.cc
@@ -829,8 +829,8 @@ XLA_TEST_P(ReduceR3ToR2Test, ReduceR3ToR2) {
auto input_activations =
builder.Parameter(0, input_literal->shape(), "input");
XlaComputation add = CreateScalarAddComputation(F32, &builder);
- auto sum = builder.Reduce(input_activations, builder.ConstantR0<float>(0.0f),
- add, GetParam().reduce_dims);
+ builder.Reduce(input_activations, builder.ConstantR0<float>(0.0f), add,
+ GetParam().reduce_dims);
auto expected =
ReferenceUtil::Reduce3DTo2D(input_array, 0.0f, GetParam().reduce_dims,
@@ -878,7 +878,7 @@ XLA_TEST_F(ReduceTest, DISABLED_ON_GPU(OperationOnConstantAsInitValue)) {
std::unique_ptr<GlobalData> b_data =
client_->TransferToServer(*b_literal).ConsumeValueOrDie();
auto b = builder.Parameter(0, b_literal->shape(), "b");
- auto max = builder.Reduce(b, a2, max_f32, {0});
+ builder.Reduce(b, a2, max_f32, {0});
ComputeAndCompareR0<float>(&builder, 4.0f, {b_data.get()});
}
diff --git a/tensorflow/compiler/xla/tests/reshape_motion_test.cc b/tensorflow/compiler/xla/tests/reshape_motion_test.cc
index da1b588ec4..3e5087922c 100644
--- a/tensorflow/compiler/xla/tests/reshape_motion_test.cc
+++ b/tensorflow/compiler/xla/tests/reshape_motion_test.cc
@@ -48,7 +48,7 @@ TEST_F(ReshapeMotionTest, ElementwiseOfReshapesWithNonSameInputShapes) {
auto b = builder.ConstantR2<int32>({{17, 19}, {23, 29}, {31, 37}});
auto c = builder.Reshape(a, {6});
auto d = builder.Reshape(b, {6});
- auto e = builder.Mul(c, d);
+ builder.Mul(c, d);
ComputeAndCompareR1<int32>(&builder, {34, 57, 115, 203, 341, 481}, {});
}
diff --git a/tensorflow/compiler/xla/tests/reshape_test.cc b/tensorflow/compiler/xla/tests/reshape_test.cc
index a4580cd71d..fccc497550 100644
--- a/tensorflow/compiler/xla/tests/reshape_test.cc
+++ b/tensorflow/compiler/xla/tests/reshape_test.cc
@@ -125,10 +125,7 @@ XLA_TEST_P(ReshapeTest, ScalarToSingleElementArray) {
zero_error_spec_);
}
-// TODO(b/29185393): Make this work with the GPU backend. The GPU backend
-// does not handle zero-sized shapes correctly. Failed last on 2017-11-30
-// with an incorrect result rank.
-XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Trivial0x3)) {
+XLA_TEST_P(ReshapeTest, Trivial0x3) {
XlaBuilder builder(TestName());
Array2D<float> input_array(0, 3);
auto input_literal = Literal::CreateR2FromArray2D(input_array);
@@ -141,10 +138,7 @@ XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Trivial0x3)) {
zero_error_spec_);
}
-// TODO(b/29185393): Make this work with the GPU backend. The GPU backend
-// does not handle zero-sized shapes correctly. Failed last on 2017-05-15
-// with an incorrect result rank.
-XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Trivial0x3WithParameter)) {
+XLA_TEST_P(ReshapeTest, Trivial0x3WithParameter) {
XlaBuilder builder(TestName());
std::unique_ptr<Literal> param0_literal =
@@ -158,10 +152,7 @@ XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Trivial0x3WithParameter)) {
zero_error_spec_);
}
-// TODO(b/29185393): Make this work with the GPU backend. The GPU backend
-// does not handle zero-sized shapes correctly. Failed last on 2017-11-30
-// with an incorrect result rank.
-XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Trivial3x0)) {
+XLA_TEST_P(ReshapeTest, Trivial3x0) {
XlaBuilder builder(TestName());
Array2D<float> input_array(3, 0);
auto input_literal = Literal::CreateR2FromArray2D(input_array);
@@ -200,12 +191,8 @@ XLA_TEST_P(ReshapeTest, Trivial3x1) {
zero_error_spec_);
}
-// TODO(b/29185393): Make this work with the GPU backend. The GPU backend
-// does not handle zero-sized shapes correctly. Failed last on 2017-11-30
-// with an incorrect result rank.
-//
// Splits an empty vector into an empty matrix.
-XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(R1ToR2_0_To_2x0)) {
+XLA_TEST_P(ReshapeTest, R1ToR2_0_To_2x0) {
XlaBuilder builder(TestName());
auto input_literal = Literal::CreateR1<float>({});
XlaOp parameter;
@@ -234,12 +221,8 @@ XLA_TEST_P(ReshapeTest, R1ToR2_6_To_2x3) {
zero_error_spec_);
}
-// TODO(b/29185393): Make this work with the GPU backend. The GPU backend
-// does not handle zero-sized shapes correctly. Failed last on 2017-11-30
-// with an incorrect result rank.
-//
// Transposes a 2x0 array to a 0x2 array.
-XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Reshape0x2To2x0)) {
+XLA_TEST_P(ReshapeTest, Reshape0x2To2x0) {
XlaBuilder builder(TestName());
auto input_literal = Literal::CreateFromArray(Array2D<float>(0, 2));
XlaOp parameter;
@@ -286,12 +269,8 @@ XLA_TEST_P(ReshapeTest, TransposeAsReshape) {
zero_error_spec_);
}
-// TODO(b/29185393): Make this work with the GPU backend. The GPU backend
-// does not handle zero-sized shapes correctly. Failed last on 2017-11-30
-// with an incorrect result rank.
-//
// Transposes a 0x4 array with XlaBuilder::Transpose.
-XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Transpose0x4)) {
+XLA_TEST_P(ReshapeTest, Transpose0x4) {
XlaBuilder builder(TestName());
auto input_literal = Literal::CreateFromArray(Array2D<float>(0, 4));
XlaOp parameter;
@@ -319,13 +298,9 @@ XLA_TEST_P(ReshapeTest, Transpose4x3) {
zero_error_spec_);
}
-// TODO(b/29185393): Make this work with the GPU backend. The GPU backend
-// does not handle zero-sized shapes correctly. Failed last on 2017-11-30
-// with an incorrect result rank.
-//
// Reshapes an empty 2-dimensional array with dimensions that are not just a
// rearrangement of the originals (split), but no reordering (no shuffle).
-XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(ReshapeSplitNoShuffleZeroElements)) {
+XLA_TEST_P(ReshapeTest, ReshapeSplitNoShuffleZeroElements) {
XlaBuilder builder(TestName());
auto input_literal = Literal::CreateFromArray(Array2D<float>(6, 0));
XlaOp parameter;
@@ -338,10 +313,7 @@ XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(ReshapeSplitNoShuffleZeroElements)) {
zero_error_spec_);
}
-// TODO(b/29185393): Make this work with the GPU backend. The GPU backend
-// does not handle zero-sized shapes correctly. Failed last on 2017-11-30
-// with an incorrect result rank.
-XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(ReshapeR4ToR2ZeroElements)) {
+XLA_TEST_P(ReshapeTest, ReshapeR4ToR2ZeroElements) {
XlaBuilder builder(TestName());
auto input_literal = Literal::CreateFromArray(Array4D<float>(2, 3, 4, 0));
XlaOp parameter;
@@ -372,11 +344,7 @@ XLA_TEST_P(ReshapeTest, ReshapeSplitNoShuffle) {
zero_error_spec_);
}
-// TODO(b/29185393): Make this work with the GPU backend. The GPU backend
-// does not handle zero-sized shapes correctly. Failed last on 2017-11-30
-// with an incorrect result rank.
-//
-XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(ReshapeSplitAndShuffleZeroElements)) {
+XLA_TEST_P(ReshapeTest, ReshapeSplitAndShuffleZeroElements) {
XlaBuilder builder(TestName());
auto input_literal = Literal::CreateFromArray(Array2D<float>(0, 6));
XlaOp parameter;
diff --git a/tensorflow/compiler/xla/tests/scalar_computations_test.cc b/tensorflow/compiler/xla/tests/scalar_computations_test.cc
index 308d3fc78a..323635b0e6 100644
--- a/tensorflow/compiler/xla/tests/scalar_computations_test.cc
+++ b/tensorflow/compiler/xla/tests/scalar_computations_test.cc
@@ -51,7 +51,7 @@ class ScalarComputationsTest : public ClientLibraryTestBase {
XlaBuilder builder(TestName());
XlaOp lhs_op = builder.ConstantR0<NativeT>(lhs);
XlaOp rhs_op = builder.ConstantR0<NativeT>(rhs);
- XlaOp result = (builder.*op)(lhs_op, rhs_op, {});
+ (builder.*op)(lhs_op, rhs_op, {});
ComputeAndCompareR0<bool>(&builder, expected, {});
}
@@ -62,7 +62,7 @@ class ScalarComputationsTest : public ClientLibraryTestBase {
XlaBuilder builder(TestName());
XlaOp lhs_op = builder.ConstantR0<NativeT>(lhs);
XlaOp rhs_op = builder.ConstantR0<NativeT>(rhs);
- XlaOp result = (builder.*op)(lhs_op, rhs_op, {});
+ (builder.*op)(lhs_op, rhs_op, {});
ComputeAndCompareR0<NativeT>(&builder, expected, {});
}
};
diff --git a/tensorflow/compiler/xla/tests/select_test.cc b/tensorflow/compiler/xla/tests/select_test.cc
index 72707f2244..6d6c393655 100644
--- a/tensorflow/compiler/xla/tests/select_test.cc
+++ b/tensorflow/compiler/xla/tests/select_test.cc
@@ -38,7 +38,7 @@ TEST_F(SelectTest, SelectScalarF32True) {
auto pred = builder.ConstantR0<bool>(true);
auto on_true = builder.ConstantR0<float>(123.0f);
auto on_false = builder.ConstantR0<float>(42.0f);
- auto result = builder.Select(pred, on_true, on_false);
+ builder.Select(pred, on_true, on_false);
ComputeAndCompareR0<float>(&builder, 123.0f, {}, error_spec_);
}
@@ -48,7 +48,7 @@ TEST_F(SelectTest, SelectScalarS32True) {
auto pred = builder.ConstantR0<bool>(true);
auto on_true = builder.ConstantR0<int32>(-42);
auto on_false = builder.ConstantR0<int32>(42);
- auto result = builder.Select(pred, on_true, on_false);
+ builder.Select(pred, on_true, on_false);
ComputeAndCompareR0<int32>(&builder, -42, {});
}
@@ -58,7 +58,7 @@ TEST_F(SelectTest, SelectScalarF32False) {
auto pred = builder.ConstantR0<bool>(false);
auto on_true = builder.ConstantR0<float>(123.0f);
auto on_false = builder.ConstantR0<float>(42.0f);
- auto result = builder.Select(pred, on_true, on_false);
+ builder.Select(pred, on_true, on_false);
ComputeAndCompareR0<float>(&builder, 42.0f, {}, error_spec_);
}
@@ -68,7 +68,7 @@ XLA_TEST_F(SelectTest, SelectR1S0F32WithConstantR1S0PRED) {
auto pred = builder.ConstantR1<bool>({});
auto on_true = builder.ConstantR1<float>({});
auto on_false = builder.ConstantR1<float>({});
- auto select = builder.Select(pred, on_true, on_false);
+ builder.Select(pred, on_true, on_false);
ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
}
@@ -78,7 +78,7 @@ TEST_F(SelectTest, SelectR1F32WithConstantR1PRED) {
auto pred = builder.ConstantR1<bool>({false, true, false, true, false});
auto on_true = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, -10.0f, 6.0f});
auto on_false = builder.ConstantR1<float>({10.0f, 5.0f, 1.0f, 10.0f, -6.0f});
- auto select = builder.Select(pred, on_true, on_false);
+ builder.Select(pred, on_true, on_false);
ComputeAndCompareR1<float>(&builder, {10.0f, 25.5f, 1.0f, -10.0f, -6.0f}, {},
error_spec_);
@@ -93,7 +93,7 @@ XLA_TEST_F(SelectTest, SelectR1S0F32WithCmpR1S0S32s) {
auto cmp = builder.Eq(v1, v2);
auto on_true = builder.ConstantR1<float>({});
auto on_false = builder.ConstantR1<float>({});
- auto select = builder.Select(cmp, on_true, on_false);
+ builder.Select(cmp, on_true, on_false);
ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
}
@@ -107,7 +107,7 @@ TEST_F(SelectTest, SelectR1F32WithCmpR1S32s) {
auto cmp = builder.Eq(v1, v2);
auto on_true = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, -10.0f, 6.0f});
auto on_false = builder.ConstantR1<float>({10.0f, 5.0f, 1.0f, 10.0f, -6.0f});
- auto select = builder.Select(cmp, on_true, on_false);
+ builder.Select(cmp, on_true, on_false);
ComputeAndCompareR1<float>(&builder, {10.0f, 25.5f, 1.0f, -10.0f, -6.0f}, {},
error_spec_);
@@ -121,7 +121,7 @@ TEST_F(SelectTest, SelectR1F32WithCmpR1F32s) {
auto cmp = builder.Gt(v1, v2);
auto on_true = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, -10.0f, 6.0f});
auto on_false = builder.ConstantR1<float>({10.0f, 5.0f, 1.0f, 10.0f, -6.0f});
- auto select = builder.Select(cmp, on_true, on_false);
+ builder.Select(cmp, on_true, on_false);
ComputeAndCompareR1<float>(&builder, {-2.5f, 25.5f, 1.0f, 10.0f, 6.0f}, {},
error_spec_);
@@ -141,7 +141,7 @@ TEST_F(SelectTest, SelectR1F32WithCmpR1F32sFromParamsSmall) {
/*builder=*/&builder, /*data_handle=*/&v2);
auto cmp = builder.Gt(v1, v2);
- auto select = builder.Select(cmp, v1, v2);
+ builder.Select(cmp, v1, v2);
ComputeAndCompareR1<float>(&builder, {41.0f, 22.0f, 23.0f, 84.0f},
{param0_data.get(), param1_data.get()},
error_spec_);
@@ -182,7 +182,7 @@ TEST_F(SelectTest, SelectR1F32WithCmpR1F32sFromParamsLarge) {
/*builder=*/&builder, /*data_handle=*/&v2);
auto cmp = builder.Gt(v1, v2);
- auto select = builder.Select(cmp, v1, v2);
+ builder.Select(cmp, v1, v2);
ComputeAndCompareR1<float>(&builder, expected_vec,
{param0_data.get(), param1_data.get()},
error_spec_);
@@ -199,7 +199,7 @@ TEST_F(SelectTest, SelectR1F32WithCmpR1S32ToScalar) {
auto on_true = builder.ConstantR1<float>({11.0f, 22.0f, 33.0f, 44.0f});
auto on_false =
builder.ConstantR1<float>({-111.0f, -222.0f, -333.0f, -444.0f});
- auto select = builder.Select(cmp, on_true, on_false);
+ builder.Select(cmp, on_true, on_false);
ComputeAndCompareR1<float>(&builder, {11.0f, -222.0f, 33.0f, -444.0f}, {},
error_spec_);
@@ -216,7 +216,7 @@ TEST_F(SelectTest, SelectR1F32WithCmpR1F32ToScalar) {
auto on_true = builder.ConstantR1<float>({11.0f, 22.0f, 33.0f, 44.0f});
auto on_false =
builder.ConstantR1<float>({-111.0f, -222.0f, -333.0f, -444.0f});
- auto select = builder.Select(cmp, on_true, on_false);
+ builder.Select(cmp, on_true, on_false);
ComputeAndCompareR1<float>(&builder, {-111.0f, -222.0f, 33.0f, 44.0f}, {},
error_spec_);
@@ -228,7 +228,7 @@ XLA_TEST_F(SelectTest, SelectR1S0F32WithScalarPredicate) {
auto pred = builder.ConstantR0<bool>(which);
auto on_true = builder.ConstantR1<float>({});
auto on_false = builder.ConstantR1<float>({});
- auto select = builder.Select(pred, on_true, on_false);
+ builder.Select(pred, on_true, on_false);
ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
}
@@ -239,7 +239,7 @@ TEST_F(SelectTest, SelectR1F32WithScalarPredicateTrue) {
auto pred = builder.ConstantR0<bool>(true);
auto on_true = builder.ConstantR1<float>({-2.5f, 25.5f});
auto on_false = builder.ConstantR1<float>({10.0f, 5.0f});
- auto select = builder.Select(pred, on_true, on_false);
+ builder.Select(pred, on_true, on_false);
ComputeAndCompareR1<float>(&builder, {-2.5f, 25.5f}, {}, error_spec_);
}
@@ -249,7 +249,7 @@ TEST_F(SelectTest, SelectR1F32WithScalarPredicateFalse) {
auto pred = builder.ConstantR0<bool>(false);
auto on_true = builder.ConstantR1<float>({-2.5f, 25.5f});
auto on_false = builder.ConstantR1<float>({10.0f, 5.0f});
- auto select = builder.Select(pred, on_true, on_false);
+ builder.Select(pred, on_true, on_false);
ComputeAndCompareR1<float>(&builder, {10.0f, 5.0f}, {}, error_spec_);
}
diff --git a/tensorflow/compiler/xla/tests/token_hlo_test.cc b/tensorflow/compiler/xla/tests/token_hlo_test.cc
index 8541698576..e9008fa48a 100644
--- a/tensorflow/compiler/xla/tests/token_hlo_test.cc
+++ b/tensorflow/compiler/xla/tests/token_hlo_test.cc
@@ -31,27 +31,29 @@ class TokenHloTest : public HloTestBase {};
XLA_TEST_F(TokenHloTest, SingleTokenInstruction) {
std::unique_ptr<HloModule> module = CreateNewModule();
auto builder = HloComputation::Builder(TestName());
- builder.AddInstruction(HloInstruction::CreateGenerateToken({}));
- builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int32>(42)));
+ builder.AddInstruction(HloInstruction::CreateAfterAll({}));
module->AddEntryComputation(builder.Build());
- EXPECT_IS_OK(HloVerifier().Run(module.get()).status());
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result,
+ Execute(std::move(module), {}));
+ EXPECT_TRUE(LiteralTestUtil::Equal(*result, *Literal::CreateToken()));
}
XLA_TEST_F(TokenHloTest, TokenTree) {
std::unique_ptr<HloModule> module = CreateNewModule();
auto builder = HloComputation::Builder(TestName());
- auto token0 = builder.AddInstruction(HloInstruction::CreateGenerateToken({}));
- auto token1 = builder.AddInstruction(HloInstruction::CreateGenerateToken({}));
- auto token2 = builder.AddInstruction(HloInstruction::CreateGenerateToken({}));
- builder.AddInstruction(
- HloInstruction::CreateGenerateToken({token0, token0, token1, token2}));
+ auto token0 = builder.AddInstruction(HloInstruction::CreateAfterAll({}));
+ auto token1 = builder.AddInstruction(HloInstruction::CreateAfterAll({}));
+ auto token2 = builder.AddInstruction(HloInstruction::CreateAfterAll({}));
builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int32>(42)));
+ HloInstruction::CreateAfterAll({token0, token0, token1, token2}));
module->AddEntryComputation(builder.Build());
- EXPECT_IS_OK(HloVerifier().Run(module.get()).status());
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result,
+ Execute(std::move(module), {}));
+ EXPECT_TRUE(LiteralTestUtil::Equal(*result, *Literal::CreateToken()));
}
XLA_TEST_F(TokenHloTest, InvalidTokenShapedEntryParameter) {
@@ -89,24 +91,12 @@ XLA_TEST_F(TokenHloTest, InvalidTupleTokenShapedEntryParameter) {
::testing::HasSubstr("Entry parameter 0 is or contains a token shape"));
}
-XLA_TEST_F(TokenHloTest, InvalidTokenRoot) {
- std::unique_ptr<HloModule> module = CreateNewModule();
- auto builder = HloComputation::Builder(TestName());
- builder.AddInstruction(HloInstruction::CreateGenerateToken({}));
- module->AddEntryComputation(builder.Build());
-
- Status status = HloVerifier().Run(module.get()).status();
- ASSERT_IS_NOT_OK(status);
- EXPECT_THAT(status.error_message(),
- ::testing::HasSubstr("Entry root is or contains a token shape"));
-}
-
XLA_TEST_F(TokenHloTest, InvalidOperandToTokenInstruction) {
std::unique_ptr<HloModule> module = CreateNewModule();
auto builder = HloComputation::Builder(TestName());
auto param = builder.AddInstruction(
HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "p0"));
- builder.AddInstruction(HloInstruction::CreateGenerateToken({param}));
+ builder.AddInstruction(HloInstruction::CreateAfterAll({param}));
builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<int32>(123)));
module->AddEntryComputation(builder.Build());
@@ -120,7 +110,7 @@ XLA_TEST_F(TokenHloTest, InvalidOperandToTokenInstruction) {
XLA_TEST_F(TokenHloTest, TokenInWhileLoop) {
// Thread a token around a while loop. Token is created and consumed by a
- // GenerateToken instruction in the while body.
+ // AfterAll instruction in the while body.
string module_string = R"(
HloModule TokenInWhileLoop
@@ -130,8 +120,8 @@ HloModule TokenInWhileLoop
%constant.1 = s32[] constant(1)
%add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1)
%get-tuple-element.2 = token[] get-tuple-element((s32[], token[]) %param.1), index=1
- %generate-token = token[] generate-token(token[] %get-tuple-element.2)
- ROOT %tuple = (s32[], token[]) tuple(s32[] %add, token[] %generate-token)
+ %after-all = token[] after-all(token[] %get-tuple-element.2)
+ ROOT %tuple = (s32[], token[]) tuple(s32[] %add, token[] %after-all)
}
%Cond (param: (s32[], token[])) -> pred[] {
@@ -143,7 +133,7 @@ HloModule TokenInWhileLoop
ENTRY %TokenInWhileLoop () -> s32[] {
%zero = s32[] constant(0)
- %init_token = token[] generate-token()
+ %init_token = token[] after-all()
%init_tuple = (s32[], token[]) tuple(s32[] %zero, token[] %init_token)
%while = (s32[], token[]) while((s32[], token[]) %init_tuple), condition=%Cond, body=%Body
ROOT %root = s32[] get-tuple-element((s32[], token[]) %while), index=0
@@ -172,13 +162,13 @@ HloModule TokenInConditional
%False (param.2: s32[]) -> (s32[], token[]) {
%param.2 = s32[] parameter(0)
- %new_token = token[] generate-token()
+ %new_token = token[] after-all()
ROOT %tuple = (s32[], token[]) tuple(s32[] %param.2, token[] %new_token)
}
ENTRY %TokenInConditional (param.3: pred[]) -> s32[] {
%param.3 = pred[] parameter(0)
- %init_token = token[] generate-token()
+ %init_token = token[] after-all()
%seven = s32[] constant(7)
%cond = (s32[], token[]) conditional(pred[] %param.3, token[] %init_token, s32[] %seven), true_computation=True, false_computation=False
ROOT %root = s32[] get-tuple-element((s32[], token[]) %cond), index=0
diff --git a/tensorflow/compiler/xla/tests/transfer_manager_test.cc b/tensorflow/compiler/xla/tests/transfer_manager_test.cc
index 85799d4cfb..86babb58c9 100644
--- a/tensorflow/compiler/xla/tests/transfer_manager_test.cc
+++ b/tensorflow/compiler/xla/tests/transfer_manager_test.cc
@@ -256,6 +256,18 @@ XLA_TEST_F(TransferManagerTest, TransferComplexValueInTuple) {
EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result));
}
+XLA_TEST_F(TransferManagerTest, TransferTokenFromDevice) {
+ // "Copy" a token from the device. The token has no physical representation so
+ // no copying is actually performed, but it shouldn't fail.
+ // TODO(b/110532604): Add transferring the token to device when this is
+ // supported.
+ auto device_buffer = AllocateDeviceBuffer(ShapeUtil::MakeTokenShape());
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<Literal> result,
+ transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
+ EXPECT_TRUE(LiteralTestUtil::Equal(*Literal::CreateToken(), *result));
+}
+
XLA_TEST_F(TransferManagerTest, MultiStreamRoundTripSoak) {
const int64 kIterationCount = 5000;
std::unique_ptr<Literal> literal1 = Literal::MakeTuple(
diff --git a/tensorflow/compiler/xla/tests/transpose_test.cc b/tensorflow/compiler/xla/tests/transpose_test.cc
index fe1e3da7ec..db85344ed6 100644
--- a/tensorflow/compiler/xla/tests/transpose_test.cc
+++ b/tensorflow/compiler/xla/tests/transpose_test.cc
@@ -39,7 +39,7 @@ class TransposeTest : public ClientLibraryTestBase {
XLA_TEST_F(TransposeTest, Transpose0x0) {
XlaBuilder builder("Transpose");
auto lhs = builder.ConstantR2FromArray2D<float>(Array2D<float>(0, 0));
- auto result = builder.Transpose(lhs, {1, 0});
+ builder.Transpose(lhs, {1, 0});
ComputeAndCompareR2<float>(&builder, Array2D<float>(0, 0), {}, error_spec_);
}
@@ -47,7 +47,7 @@ XLA_TEST_F(TransposeTest, Transpose0x0) {
XLA_TEST_F(TransposeTest, Transpose0x42) {
XlaBuilder builder("Transpose");
auto lhs = builder.ConstantR2FromArray2D<float>(Array2D<float>(0, 42));
- auto result = builder.Transpose(lhs, {1, 0});
+ builder.Transpose(lhs, {1, 0});
ComputeAndCompareR2<float>(&builder, Array2D<float>(42, 0), {}, error_spec_);
}
@@ -55,7 +55,7 @@ XLA_TEST_F(TransposeTest, Transpose0x42) {
XLA_TEST_F(TransposeTest, Transpose7x0) {
XlaBuilder builder("Transpose");
auto lhs = builder.ConstantR2FromArray2D<float>(Array2D<float>(7, 0));
- auto result = builder.Transpose(lhs, {1, 0});
+ builder.Transpose(lhs, {1, 0});
ComputeAndCompareR2<float>(&builder, Array2D<float>(0, 7), {}, error_spec_);
}
@@ -65,7 +65,7 @@ TEST_F(TransposeTest, Transpose2x2) {
auto lhs = builder.ConstantR2<float>({
{1.0, 2.0}, {3.0, 4.0},
});
- auto result = builder.Transpose(lhs, {1, 0});
+ builder.Transpose(lhs, {1, 0});
Array2D<float> expected({{1.0f, 3.0f}, {2.0f, 4.0f}});
@@ -75,7 +75,7 @@ TEST_F(TransposeTest, Transpose2x2) {
XLA_TEST_F(TransposeTest, Transpose0x2x3_2x3x0) {
XlaBuilder builder("Transpose");
auto operand = builder.ConstantR3FromArray3D<int32>(Array3D<int32>(0, 2, 3));
- auto result = builder.Transpose(operand, {1, 2, 0});
+ builder.Transpose(operand, {1, 2, 0});
ComputeAndCompareR3<int32>(&builder, Array3D<int32>(2, 3, 0), {});
}
@@ -83,7 +83,7 @@ XLA_TEST_F(TransposeTest, Transpose0x2x3_2x3x0) {
TEST_F(TransposeTest, Transpose1x2x3_2x3x1) {
XlaBuilder builder("Transpose");
auto operand = builder.ConstantR3FromArray3D<int32>({{{1, 2, 3}, {4, 5, 6}}});
- auto result = builder.Transpose(operand, {1, 2, 0});
+ builder.Transpose(operand, {1, 2, 0});
Array3D<int32> expected({{{1}, {2}, {3}}, {{4}, {5}, {6}}});
@@ -93,7 +93,7 @@ TEST_F(TransposeTest, Transpose1x2x3_2x3x1) {
TEST_F(TransposeTest, Transpose1x2x3_3x2x1) {
XlaBuilder builder("Transpose");
auto operand = builder.ConstantR3FromArray3D<int32>({{{1, 2, 3}, {4, 5, 6}}});
- auto result = builder.Transpose(operand, {2, 1, 0});
+ builder.Transpose(operand, {2, 1, 0});
Array3D<int32> expected({{{1}, {4}}, {{2}, {5}}, {{3}, {6}}});
@@ -103,7 +103,7 @@ TEST_F(TransposeTest, Transpose1x2x3_3x2x1) {
TEST_F(TransposeTest, Transpose1x2x3_1x2x3) {
XlaBuilder builder("Transpose");
auto operand = builder.ConstantR3FromArray3D<int32>({{{1, 2, 3}, {4, 5, 6}}});
- auto result = builder.Transpose(operand, {0, 1, 2});
+ builder.Transpose(operand, {0, 1, 2});
Array3D<int32> expected({{{1, 2, 3}, {4, 5, 6}}});
diff --git a/tensorflow/compiler/xla/tests/unary_op_test.cc b/tensorflow/compiler/xla/tests/unary_op_test.cc
index c3abe22797..dbbe1b49e4 100644
--- a/tensorflow/compiler/xla/tests/unary_op_test.cc
+++ b/tensorflow/compiler/xla/tests/unary_op_test.cc
@@ -39,7 +39,7 @@ class UnaryOpTest : public ClientLibraryTestBase {
void AbsSize0TestHelper() {
XlaBuilder builder(TestName());
auto arg = builder.ConstantR1<T>({});
- auto abs = builder.Abs(arg);
+ builder.Abs(arg);
if (primitive_util::NativeToPrimitiveType<T>() == C64) {
ComputeAndCompareR1<float>(&builder, {}, {});
@@ -52,7 +52,7 @@ class UnaryOpTest : public ClientLibraryTestBase {
void AbsTestHelper() {
XlaBuilder builder(TestName());
auto arg = builder.ConstantR1<T>({-2, 25, 0, -123, inf<T>(), -inf<T>()});
- auto abs = builder.Abs(arg);
+ builder.Abs(arg);
ComputeAndCompareR1<T>(&builder, {2, 25, 0, 123, inf<T>(), inf<T>()}, {});
}
@@ -62,7 +62,7 @@ class UnaryOpTest : public ClientLibraryTestBase {
XlaBuilder builder(TestName());
auto arg = builder.ConstantR1<T>(
{-2, 25, 0, static_cast<T>(-0.0), -123, inf<T>(), -inf<T>()});
- auto sign = builder.Sign(arg);
+ builder.Sign(arg);
ComputeAndCompareR1<T>(&builder, {-1, 1, 0, 0, -1, 1, -1}, {});
}
@@ -98,7 +98,7 @@ void UnaryOpTest::AbsTestHelper<complex64>() {
{-0.3f, 0.4f},
{0, inf<float>()},
{-inf<float>(), 0}});
- auto abs = builder.Abs(arg);
+ builder.Abs(arg);
std::unique_ptr<Literal> expected =
Literal::CreateR1<float>({2, 25, 0, 0.5, inf<float>(), inf<float>()});
@@ -110,7 +110,7 @@ void UnaryOpTest::SignTestHelper<complex64>() {
XlaBuilder builder(TestName());
auto arg = builder.ConstantR1<complex64>(
{{-2, 0}, {0, 25}, {0, 0}, {static_cast<float>(-0.0), 0}, {-1, 1}});
- auto sign = builder.Sign(arg);
+ builder.Sign(arg);
std::unique_ptr<Literal> expected = Literal::CreateR1<complex64>(
{{-1, 0}, {0, 1}, {0, 0}, {0, 0}, {-std::sqrt(0.5f), std::sqrt(0.5f)}});
@@ -196,7 +196,7 @@ XLA_TEST_F(UnaryOpTest, UnsignedAbsTestR1) {
XlaBuilder builder(TestName());
auto arg = builder.ConstantR1<unsigned int>(
{2, 25, 0, 123, std::numeric_limits<unsigned int>::max()});
- auto abs = builder.Abs(arg);
+ builder.Abs(arg);
ComputeAndCompareR1<unsigned int>(
&builder, {2, 25, 0, 123, std::numeric_limits<unsigned int>::max()}, {});
@@ -206,7 +206,7 @@ XLA_TEST_F(UnaryOpTest, UnsignedSignTestR1) {
XlaBuilder builder(TestName());
auto arg = builder.ConstantR1<unsigned int>(
{2, 25, 0, 123, std::numeric_limits<unsigned int>::max()});
- auto sign = builder.Sign(arg);
+ builder.Sign(arg);
ComputeAndCompareR1<unsigned int>(&builder, {1, 1, 0, 1, 1}, {});
}
diff --git a/tensorflow/compiler/xla/tests/vector_ops_reduce_test.cc b/tensorflow/compiler/xla/tests/vector_ops_reduce_test.cc
index 82d301983f..9e76177483 100644
--- a/tensorflow/compiler/xla/tests/vector_ops_reduce_test.cc
+++ b/tensorflow/compiler/xla/tests/vector_ops_reduce_test.cc
@@ -58,9 +58,8 @@ TEST_F(VecOpsReduceTest, AddReduceR1F32) {
auto x = builder_.ConstantR1<float>(
{2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
- auto add_reduce =
- builder_.Reduce(x, builder_.ConstantR0<float>(0.0f), sum_reducer,
- /*dimensions_to_reduce=*/{0});
+ builder_.Reduce(x, builder_.ConstantR0<float>(0.0f), sum_reducer,
+ /*dimensions_to_reduce=*/{0});
ComputeAndCompareR0<float>(&builder_, -4.2f, {}, errspec_);
}
@@ -72,9 +71,8 @@ TEST_F(VecOpsReduceTest, AddReduceBigR1F32) {
std::iota(input.begin(), input.end(), 100.0f);
auto x = builder_.ConstantR1<float>(input);
- auto add_reduce =
- builder_.Reduce(x, builder_.ConstantR0<float>(0.0f), sum_reducer,
- /*dimensions_to_reduce=*/{0});
+ builder_.Reduce(x, builder_.ConstantR0<float>(0.0f), sum_reducer,
+ /*dimensions_to_reduce=*/{0});
float expected = std::accumulate(input.begin(), input.end(), 0.0f);
ComputeAndCompareR0<float>(&builder_, expected, {}, errspec_);
@@ -85,9 +83,8 @@ TEST_F(VecOpsReduceTest, MaxReduceR1F32) {
auto x = builder_.ConstantR1<float>(
{2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
- auto max_reduce =
- builder_.Reduce(x, builder_.ConstantR0<float>(0.0f), max_reducer,
- /*dimensions_to_reduce=*/{0});
+ builder_.Reduce(x, builder_.ConstantR0<float>(0.0f), max_reducer,
+ /*dimensions_to_reduce=*/{0});
ComputeAndCompareR0<float>(&builder_, 2.6f, {}, errspec_);
}
@@ -97,9 +94,8 @@ TEST_F(VecOpsReduceTest, MaxReduceR1F32WithNontrivialInit) {
auto x = builder_.ConstantR1<float>(
{2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
- auto max_reduce =
- builder_.Reduce(x, builder_.ConstantR0<float>(4.0f), max_reducer,
- /*dimensions_to_reduce=*/{0});
+ builder_.Reduce(x, builder_.ConstantR0<float>(4.0f), max_reducer,
+ /*dimensions_to_reduce=*/{0});
ComputeAndCompareR0<float>(&builder_, 4.0f, {}, errspec_);
}
@@ -114,9 +110,8 @@ TEST_F(VecOpsReduceTest, AddReduceR2F32Dim1) {
// ------ dim 1 ----------
// clang-format on
- auto add_reduce =
- builder_.Reduce(x, builder_.ConstantR0<float>(0.0f), sum_reducer,
- /*dimensions_to_reduce=*/{1});
+ builder_.Reduce(x, builder_.ConstantR0<float>(0.0f), sum_reducer,
+ /*dimensions_to_reduce=*/{1});
ComputeAndCompareR1<float>(&builder_, {6.0, 15.0}, {}, errspec_);
}
@@ -129,9 +124,8 @@ TEST_F(VecOpsReduceTest, AddReduceR2F32Dim0) {
{1.0, 2.0, 3.0},
{4.0, 5.0, 6.0}});
// clang-format on
- auto add_reduce =
- builder_.Reduce(x, builder_.ConstantR0<float>(0.0f), sum_reducer,
- /*dimensions_to_reduce=*/{0});
+ builder_.Reduce(x, builder_.ConstantR0<float>(0.0f), sum_reducer,
+ /*dimensions_to_reduce=*/{0});
ComputeAndCompareR1<float>(&builder_, {5.0, 7.0, 9.0}, {}, errspec_);
}
@@ -139,9 +133,8 @@ TEST_F(VecOpsReduceTest, AddReduceR2F32Dim0) {
TEST_F(VecOpsReduceTest, AddReduceR3F32Dim2) {
auto sum_reducer = CreateScalarAddComputation(F32, &builder_);
auto x = BuildSampleConstantCube();
- auto add_reduce =
- builder_.Reduce(x, builder_.ConstantR0<float>(0.0f), sum_reducer,
- /*dimensions_to_reduce=*/{2});
+ builder_.Reduce(x, builder_.ConstantR0<float>(0.0f), sum_reducer,
+ /*dimensions_to_reduce=*/{2});
Array2D<float> expected_array({{6.0f, 15.0f}, {6.0f, 15.0f}, {6.0f, 15.0f}});
@@ -151,9 +144,8 @@ TEST_F(VecOpsReduceTest, AddReduceR3F32Dim2) {
TEST_F(VecOpsReduceTest, AddReduceR3F32Dim1) {
auto sum_reducer = CreateScalarAddComputation(F32, &builder_);
auto x = BuildSampleConstantCube();
- auto add_reduce =
- builder_.Reduce(x, builder_.ConstantR0<float>(0.0f), sum_reducer,
- /*dimensions_to_reduce=*/{1});
+ builder_.Reduce(x, builder_.ConstantR0<float>(0.0f), sum_reducer,
+ /*dimensions_to_reduce=*/{1});
Array2D<float> expected_array(
{{5.0f, 7.0f, 9.0f}, {5.0f, 7.0f, 9.0f}, {5.0f, 7.0f, 9.0f}});
@@ -164,9 +156,8 @@ TEST_F(VecOpsReduceTest, AddReduceR3F32Dim1) {
TEST_F(VecOpsReduceTest, AddReduceR3F32Dim0) {
auto sum_reducer = CreateScalarAddComputation(F32, &builder_);
auto x = BuildSampleConstantCube();
- auto add_reduce =
- builder_.Reduce(x, builder_.ConstantR0<float>(0.0f), sum_reducer,
- /*dimensions_to_reduce=*/{0});
+ builder_.Reduce(x, builder_.ConstantR0<float>(0.0f), sum_reducer,
+ /*dimensions_to_reduce=*/{0});
Array2D<float> expected_array({{3.0f, 6.0f, 9.0f}, {12.0f, 15.0f, 18.0f}});
@@ -176,9 +167,8 @@ TEST_F(VecOpsReduceTest, AddReduceR3F32Dim0) {
TEST_F(VecOpsReduceTest, AddReduceR3F32Dims1and2) {
auto sum_reducer = CreateScalarAddComputation(F32, &builder_);
auto x = BuildSampleConstantCube();
- auto add_reduce =
- builder_.Reduce(x, builder_.ConstantR0<float>(0.0f), sum_reducer,
- /*dimensions_to_reduce=*/{1, 2});
+ builder_.Reduce(x, builder_.ConstantR0<float>(0.0f), sum_reducer,
+ /*dimensions_to_reduce=*/{1, 2});
ComputeAndCompareR1<float>(&builder_, {21.0, 21.0, 21.0}, {}, errspec_);
}
@@ -186,9 +176,8 @@ TEST_F(VecOpsReduceTest, AddReduceR3F32Dims1and2) {
XLA_TEST_F(VecOpsReduceTest, AddReduceR3F32Dims0and2) {
auto sum_reducer = CreateScalarAddComputation(F32, &builder_);
auto x = BuildSampleConstantCube();
- auto add_reduce =
- builder_.Reduce(x, builder_.ConstantR0<float>(0.0f), sum_reducer,
- /*dimensions_to_reduce=*/{0, 2});
+ builder_.Reduce(x, builder_.ConstantR0<float>(0.0f), sum_reducer,
+ /*dimensions_to_reduce=*/{0, 2});
ComputeAndCompareR1<float>(&builder_, {18.0, 45.0}, {}, errspec_);
}
@@ -196,9 +185,8 @@ XLA_TEST_F(VecOpsReduceTest, AddReduceR3F32Dims0and2) {
TEST_F(VecOpsReduceTest, AddReduceR3F32Dims0and1) {
auto sum_reducer = CreateScalarAddComputation(F32, &builder_);
auto x = BuildSampleConstantCube();
- auto add_reduce =
- builder_.Reduce(x, builder_.ConstantR0<float>(0.0f), sum_reducer,
- /*dimensions_to_reduce=*/{0, 1});
+ builder_.Reduce(x, builder_.ConstantR0<float>(0.0f), sum_reducer,
+ /*dimensions_to_reduce=*/{0, 1});
ComputeAndCompareR1<float>(&builder_, {15.0, 21.0, 27.0}, {}, errspec_);
}
@@ -206,9 +194,8 @@ TEST_F(VecOpsReduceTest, AddReduceR3F32Dims0and1) {
TEST_F(VecOpsReduceTest, AddReduceR3F32AllDims) {
auto sum_reducer = CreateScalarAddComputation(F32, &builder_);
auto x = BuildSampleConstantCube();
- auto add_reduce =
- builder_.Reduce(x, builder_.ConstantR0<float>(0.0f), sum_reducer,
- /*dimensions_to_reduce=*/{0, 1, 2});
+ builder_.Reduce(x, builder_.ConstantR0<float>(0.0f), sum_reducer,
+ /*dimensions_to_reduce=*/{0, 1, 2});
ComputeAndCompareR0<float>(&builder_, 63.0, {}, errspec_);
}
diff --git a/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc b/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc
index 5cce7a2bf8..4f7168204f 100644
--- a/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc
+++ b/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc
@@ -52,7 +52,7 @@ XLA_TEST_F(VecOpsSimpleTest, ExpTenValues) {
XlaBuilder builder(TestName());
auto x = builder.ConstantR1<float>(
{2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
- auto exp = builder.Exp(x);
+ builder.Exp(x);
std::vector<float> expected = {8.1662, 7.4274e-02, 13.4637, 1.8316e-02,
8.1662, 9.9742, 6.7379e-03, 4.0657e-01,
@@ -70,7 +70,7 @@ XLA_TEST_F(VecOpsSimpleTest, ExpManyValues) {
exponents.push_back(i / static_cast<float>(count));
}
auto x = builder.ConstantR1<float>(exponents);
- auto exp = builder.Exp(x);
+ builder.Exp(x);
std::vector<float> expected;
expected.reserve(exponents.size());
@@ -99,7 +99,7 @@ XLA_TEST_F(VecOpsSimpleTest, ExpIn4D) {
Array4D<float> expected(2, 2, 2, 2, expected_vector);
auto x = builder.ConstantR4FromArray4D<float>(exponents);
- auto exp = builder.Exp(x);
+ builder.Exp(x);
ComputeAndCompareR4<float>(&builder, expected, {},
ErrorSpec(/*aabs=*/1e-2, /*arel=*/1e-3));
@@ -161,7 +161,7 @@ XLA_TEST_F(VecOpsSimpleTest, ReciprocalTenValues) {
XLA_TEST_F(VecOpsSimpleTest, SqrtZeroes) {
XlaBuilder builder(TestName());
auto x = builder.ConstantR1<float>({0.0, -0.0});
- auto exp = builder.SqrtF32(x);
+ builder.SqrtF32(x);
ComputeAndCompareR1<float>(&builder, {0, 0}, {}, error_spec_);
}
@@ -169,7 +169,7 @@ XLA_TEST_F(VecOpsSimpleTest, SqrtZeroes) {
XLA_TEST_F(VecOpsSimpleTest, SqrtSixValues) {
XlaBuilder builder(TestName());
auto x = builder.ConstantR1<float>({16.0, 1.0, 1024.0, 0.16, 0.2, 12345});
- auto exp = builder.SqrtF32(x);
+ builder.SqrtF32(x);
std::vector<float> expected = {4, 1, 32, 0.4, 0.4472, 111.1080};
ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
@@ -179,7 +179,7 @@ XLA_TEST_F(VecOpsSimpleTest, InvSqrtSevenValues) {
XlaBuilder builder(TestName());
auto x =
builder.ConstantR1<float>({16.0, 1.0, 1024.0, 0.16, 0.2, 12345, 1.2345});
- auto exp = builder.Pow(x, builder.ConstantR0<float>(-.5f));
+ builder.Pow(x, builder.ConstantR0<float>(-.5f));
std::vector<float> expected = {.25, 1, .03125, 2.5,
2.23607, .009000, .900025};
@@ -195,7 +195,7 @@ XLA_TEST_F(VecOpsSimpleTest, AddTenValuesViaMap) {
{2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
auto y = builder.ConstantR1<float>(
{-0.4, -0.6, -3.0, 0.2, 3.8, -2.2, -1.8, 4.9, 1.4, 0.6});
- auto max = builder.Map({x, y}, add, {0});
+ builder.Map({x, y}, add, {0});
std::vector<float> expected = {1.7, -3.2, -0.4, -3.8, 5.9,
0.1, -6.8, 4., -1., 2.2};
@@ -208,7 +208,7 @@ XLA_TEST_F(VecOpsSimpleTest, MaxTenValues) {
{2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
auto y = builder.ConstantR1<float>(
{-0.4, -0.6, -3.0, 0.2, 3.8, -2.2, -1.8, 4.9, 1.4, 0.6});
- auto max = builder.Max(x, y);
+ builder.Max(x, y);
std::vector<float> expected = {2.1, -0.6, 2.6, 0.2, 3.8,
2.3, -1.8, 4.9, 1.4, 1.6};
@@ -227,7 +227,7 @@ XLA_TEST_F(VecOpsSimpleTest, MaxTenValuesFromParams) {
{21.0f, 22.0f, 23.0f, 24.0f}, /*parameter_number=*/1, /*name=*/"v2",
/*builder=*/&builder, /*data_handle=*/&v2);
- auto max = builder.Max(v1, v2);
+ builder.Max(v1, v2);
ComputeAndCompareR1<float>(&builder, {41.0f, 22.0f, 23.0f, 84.0f},
{param0_data.get(), param1_data.get()},
error_spec_);
@@ -267,7 +267,7 @@ XLA_TEST_F(VecOpsSimpleTest, Max15000ValuesFromParams) {
CreateR1Parameter<float>(v2vec, /*parameter_number=*/1, /*name=*/"v2",
/*builder=*/&builder, /*data_handle=*/&v2);
- auto max = builder.Max(v1, v2);
+ builder.Max(v1, v2);
ComputeAndCompareR1<float>(&builder, expected_vec,
{param0_data.get(), param1_data.get()},
error_spec_);
@@ -278,7 +278,7 @@ XLA_TEST_F(VecOpsSimpleTest, MaxTenValuesWithScalar) {
auto x = builder.ConstantR1<float>(
{2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
auto y = builder.ConstantR0<float>(0);
- auto max = builder.Max(x, y);
+ builder.Max(x, y);
std::vector<float> expected = {2.1, 0.0, 2.6, 0.0, 2.1,
2.3, 0.0, 0.0, 0.0, 1.6};
@@ -291,7 +291,7 @@ XLA_TEST_F(VecOpsSimpleTest, MinTenValues) {
{2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
auto y = builder.ConstantR1<float>(
{-0.4, -0.6, -3.0, 0.2, 3.8, -2.2, -1.8, 4.9, 1.4, 0.6});
- auto min = builder.Min(x, y);
+ builder.Min(x, y);
std::vector<float> expected = {-0.4, -2.6, -3.0, -4.0, 2.1,
-2.2, -5.0, -0.9, -2.4, 0.6};
@@ -304,7 +304,7 @@ XLA_TEST_F(VecOpsSimpleTest, MinMaxTenValues) {
auto one = builder.ConstantR0<float>(1);
auto x = builder.ConstantR1<float>(
{2.1, -2.6, 2.6, 0.3, 3.1, 0.9, -5.0, 0.1, -2.4, 0.6});
- auto clamp = builder.Min(builder.Max(x, zero), one);
+ builder.Min(builder.Max(x, zero), one);
std::vector<float> expected = {1.0, 0.0, 1.0, 0.3, 1.0,
0.9, 0.0, 0.1, 0.0, 0.6};
@@ -317,7 +317,7 @@ XLA_TEST_F(VecOpsSimpleTest, ClampTenValuesConstant) {
auto one = builder.ConstantR0<float>(1);
auto x = builder.ConstantR1<float>(
{2.1, -2.6, 2.6, 0.3, 3.1, 0.9, -5.0, 0.1, -2.4, 0.6});
- auto clamp = builder.Clamp(zero, x, one);
+ builder.Clamp(zero, x, one);
std::vector<float> expected = {1.0, 0.0, 1.0, 0.3, 1.0,
0.9, 0.0, 0.1, 0.0, 0.6};
@@ -329,7 +329,7 @@ XLA_TEST_F(VecOpsSimpleTest, ClampTwoValuesConstant) {
auto zero = builder.ConstantR1<float>({0.0f, 0.0f});
auto one = builder.ConstantR1<float>({1.0f, 1.0f});
auto x = builder.ConstantR1<float>({2.1, -2.6});
- auto clamp = builder.Clamp(zero, x, one);
+ builder.Clamp(zero, x, one);
std::vector<float> expected = {1.0, 0.0};
ComputeAndCompareR1<float>(&builder, expected, {});
@@ -341,7 +341,7 @@ XLA_TEST_F(VecOpsSimpleTest, ClampTenValuesConstantNonzeroLower) {
auto two = builder.ConstantR0<float>(2);
auto x = builder.ConstantR1<float>(
{2.1, -2.6, 2.6, 0.3, 3.1, 0.9, -5.0, 0.1, -2.4, 0.6});
- auto clamp = builder.Clamp(one, x, two);
+ builder.Clamp(one, x, two);
std::vector<float> expected = {2.0, 1.0, 2.0, 1.0, 2.0,
1.0, 1.0, 1.0, 1.0, 1.0};
@@ -353,7 +353,7 @@ XLA_TEST_F(VecOpsSimpleTest, ClampValuesConstantS64) {
auto zero = builder.ConstantR0<int64>(0);
auto one = builder.ConstantR0<int64>(10);
auto x = builder.ConstantR1<int64>({-3, 3, 9, 13});
- auto clamp = builder.Clamp(zero, x, one);
+ builder.Clamp(zero, x, one);
std::vector<int64> expected = {0, 3, 9, 10};
ComputeAndCompareR1<int64>(&builder, expected, {});
@@ -380,7 +380,7 @@ XLA_TEST_F(VecOpsSimpleTest, MapTenValues) {
auto y_value =
builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "y_value");
auto zero = builder.ConstantR0<float>(0.0);
- auto clamped = builder.Clamp(zero, y_value, builder.ConstantR0<float>(5));
+ builder.Clamp(zero, y_value, builder.ConstantR0<float>(5));
auto computation_status = builder.Build();
ASSERT_IS_OK(computation_status.status());
clamp = computation_status.ConsumeValueOrDie();
@@ -407,7 +407,7 @@ XLA_TEST_F(VecOpsSimpleTest, MapTenValues) {
{
auto x = builder.ConstantR1<float>(
{2.1, -21.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
- auto activations = builder.Map({x}, mult_relu_add, {0});
+ builder.Map({x}, mult_relu_add, {0});
}
std::vector<float> expected = {4.7, 0.5, 5.0, 0.5, 4.7,
diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc
index c463f3eac5..3119456347 100644
--- a/tensorflow/compiler/xla/tests/while_test.cc
+++ b/tensorflow/compiler/xla/tests/while_test.cc
@@ -184,8 +184,7 @@ TEST_F(WhileTest, WhileWithPredicateResult) {
// while (result.sum() < 15.5f) {
// result = result + vector<float>(0);
// }
-// TODO(b/29185393): does not terminate on CPU.
-TEST_F(WhileTest, DISABLED_WhileWithEmptyVectorResult) {
+TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithEmptyVectorResult)) {
Shape result_shape = ShapeUtil::MakeShape(F32, {0});
// Create a computation for the reduction.
@@ -965,10 +964,8 @@ TEST_F(WhileTest, WhileThatSwapsParameterWithTupleElement) {
XlaBuilder cond("cond");
auto cond_t = cond.Parameter(0, tuple_shape, "t");
- TF_ASSERT_OK(Any(cond.Eq(cond.GetTupleElement(cond_t, 0),
- cond.ConstantR1<float>({42, 42})),
- &cond)
- .status());
+ Any(cond.Eq(cond.GetTupleElement(cond_t, 0),
+ cond.ConstantR1<float>({42, 42})));
XlaBuilder body("body");
auto body_t = body.Parameter(0, tuple_shape, "t");
@@ -997,12 +994,11 @@ TEST_F(WhileTest, WhileThatSwapsParameterWithBroadcast) {
XlaBuilder cond("cond");
auto cond_t = cond.Parameter(0, element_shape, "t");
- TF_ASSERT_OK(
- Any(cond.Eq(cond_t, cond.ConstantR1<float>({42, 42})), &cond).status());
+ Any(cond.Eq(cond_t, cond.ConstantR1<float>({42, 42})));
XlaBuilder body("body");
- auto body_t = body.Parameter(0, element_shape, "t");
- auto e = body.Broadcast(body.ConstantR0<float>(1.0), {2});
+ body.Parameter(0, element_shape, "t");
+ body.Broadcast(body.ConstantR0<float>(1.0), {2});
TF_ASSERT_OK_AND_ASSIGN(auto cond_computation, cond.Build());
TF_ASSERT_OK_AND_ASSIGN(auto body_computation, body.Build());
@@ -1029,7 +1025,7 @@ TEST_F(WhileTest, WhileThatTurnsScalarParameterToTupleElement) {
auto body_t = body.Parameter(0, element_shape, "t");
auto tuple =
body.Tuple({body_t, body.Add(body_t, body.ConstantR0<float>(1))});
- auto e = body.GetTupleElement(tuple, 1);
+ body.GetTupleElement(tuple, 1);
TF_ASSERT_OK_AND_ASSIGN(auto cond_computation, cond.Build());
TF_ASSERT_OK_AND_ASSIGN(auto body_computation, body.Build());
@@ -1068,7 +1064,7 @@ TEST_F(WhileTest, WhileWithMixedTupleElements) {
XlaBuilder body("body");
auto body_t = body.Parameter(0, result_shape, "t");
- auto tuple = body.Tuple(
+ body.Tuple(
{body.Add(body.GetTupleElement(body_t, 0), body.ConstantR0<int32>(1)),
body.Add(body.GetTupleElement(body_t, 1), body.ConstantR0<int32>(1))});
diff --git a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc
index 0be950cacb..b081850eb5 100644
--- a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc
+++ b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc
@@ -187,7 +187,7 @@ XLA_TEST_F(HloProfileTest, ProfileSingleComputation) {
ClientLibrary::GetOrCreateLocalClient(platform));
XlaBuilder builder(TestName());
- auto result = builder.Tanh(builder.Add(
+ builder.Tanh(builder.Add(
builder.Parameter(0, ShapeUtil::MakeShape(F32, {m, k}), "dot_lhs"),
builder.Parameter(1, ShapeUtil::MakeShape(F32, {k, n}), "dot_rhs")));
diff --git a/tensorflow/contrib/autograph/examples/notebooks/rnn_keras_estimator.ipynb b/tensorflow/contrib/autograph/examples/notebooks/rnn_keras_estimator.ipynb
index 324b23c24b..44532cb078 100644
--- a/tensorflow/contrib/autograph/examples/notebooks/rnn_keras_estimator.ipynb
+++ b/tensorflow/contrib/autograph/examples/notebooks/rnn_keras_estimator.ipynb
@@ -190,7 +190,6 @@
" self.upper_cell = tf.contrib.rnn.LSTMBlockCell(128)\n",
" self.relu_layer = tf.layers.Dense(3, activation=tf.nn.relu)\n",
"\n",
- "\n",
" def _rnn_layer(self, chars, cell, batch_size, training):\n",
" \"\"\"A single RNN layer.\n",
"\n",
@@ -203,13 +202,12 @@
" Returns:\n",
" A Tensor of shape (max_sequence_length, batch_size, output_size).\n",
" \"\"\"\n",
- " hidden_outputs = []\n",
- " autograph.utils.set_element_type(hidden_outputs, tf.float32)\n",
+ " hidden_outputs = tf.TensorArray(tf.float32, 0, True)\n",
" state, output = cell.zero_state(batch_size, tf.float32)\n",
" for ch in chars:\n",
" cell_output, (state, output) = cell.call(ch, (state, output))\n",
" hidden_outputs.append(cell_output)\n",
- " hidden_outputs = hidden_outputs.stack()\n",
+ " hidden_outputs = autograph.stack(hidden_outputs)\n",
" if training:\n",
" hidden_outputs = tf.nn.dropout(hidden_outputs, 0.5)\n",
" return hidden_outputs\n",
@@ -223,7 +221,7 @@
"\n",
"\n",
" def call(self, inputs, training=False):\n",
- " \"\"\"The RNN model code. Uses Eager and \n",
+ " \"\"\"The RNN model code. Uses Eager.\n",
"\n",
" The model consists of two RNN layers (made by lower_cell and upper_cell),\n",
" followed by a fully connected layer with ReLU activation.\n",
@@ -243,7 +241,8 @@
" seq = self._rnn_layer(seq, self.upper_cell, batch_size, training)\n",
"\n",
" # Grab just the end-of-sequence from each output.\n",
- " indices = tf.stack([length - 1, range(batch_size)], axis=1)\n",
+ " indices = (length - 1, range(batch_size))\n",
+ " indices = tf.stack(indices, 1)\n",
" sequence_ends = tf.gather_nd(seq, indices)\n",
" return self.relu_layer(sequence_ends)\n",
"\n",
@@ -381,7 +380,7 @@
},
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": 107,
"metadata": {
"colab": {
"autoexec": {
@@ -392,9 +391,9 @@
},
"colab_type": "code",
"executionInfo": {
- "elapsed": 10604,
+ "elapsed": 5454,
"status": "ok",
- "timestamp": 1524095272039,
+ "timestamp": 1529952160455,
"user": {
"displayName": "",
"photoUrl": "",
@@ -403,7 +402,7 @@
"user_tz": 240
},
"id": "2pg1AfbxBJQq",
- "outputId": "9c924b4f-06e1-4538-976c-a3e1ddac5660",
+ "outputId": "4aef3052-f7c7-4bb1-a0a2-73fef2e96efb",
"slideshow": {
"slide_type": "-"
}
@@ -413,7 +412,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "Eval loss at step 100: 0.0674834\n"
+ "Eval loss at step 100: 0.0705221\n"
]
}
],
@@ -423,8 +422,8 @@
" 'learning_rate': 0.01,\n",
"}\n",
"\n",
- "train_url = \"https://raw.githubusercontent.com/random-forests/tensorflow-workshop/master/extras/colorbot/data/train.csv\"\n",
- "test_url = \"https://raw.githubusercontent.com/random-forests/tensorflow-workshop/master/extras/colorbot/data/test.csv\"\n",
+ "train_url = \"https://raw.githubusercontent.com/random-forests/tensorflow-workshop/master/archive/extras/colorbot/data/train.csv\"\n",
+ "test_url = \"https://raw.githubusercontent.com/random-forests/tensorflow-workshop/master/archive/extras/colorbot/data/test.csv\"\n",
"data_dir = \"tmp/rnn/data\"\n",
"\n",
"regressor = tf.estimator.Estimator(\n",
@@ -457,7 +456,7 @@
},
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": 108,
"metadata": {
"colab": {
"autoexec": {
@@ -468,9 +467,9 @@
},
"colab_type": "code",
"executionInfo": {
- "elapsed": 7990,
+ "elapsed": 3432,
"status": "ok",
- "timestamp": 1524095280105,
+ "timestamp": 1529952163923,
"user": {
"displayName": "",
"photoUrl": "",
@@ -479,7 +478,7 @@
"user_tz": 240
},
"id": "dxHex2tUN_10",
- "outputId": "2b889e5a-b9ed-4645-bf03-d98f26c72101",
+ "outputId": "1ff438f2-b045-4f4e-86a0-4dae7503f6b2",
"slideshow": {
"slide_type": "slide"
}
@@ -491,12 +490,12 @@
"\u003clink rel=stylesheet type=text/css href='/nbextensions/google.colab/tabbar.css'\u003e\u003c/link\u003e"
],
"text/plain": [
- "\u003cIPython.core.display.HTML at 0x7f3f36aa6cd0\u003e"
+ "\u003cIPython.core.display.HTML at 0x7fcd7222a110\u003e"
]
},
"metadata": {
"tags": [
- "outputarea_id1"
+ "outputarea_id3"
]
},
"output_type": "display_data"
@@ -507,12 +506,12 @@
"\u003cscript src='/nbextensions/google.colab/tabbar_main.min.js'\u003e\u003c/script\u003e"
],
"text/plain": [
- "\u003cIPython.core.display.HTML at 0x7f3eca67f7d0\u003e"
+ "\u003cIPython.core.display.HTML at 0x7fcd7222a8d0\u003e"
]
},
"metadata": {
"tags": [
- "outputarea_id1"
+ "outputarea_id3"
]
},
"output_type": "display_data"
@@ -520,15 +519,15 @@
{
"data": {
"text/html": [
- "\u003cdiv id=\"id1\"\u003e\u003c/div\u003e"
+ "\u003cdiv id=\"id3\"\u003e\u003c/div\u003e"
],
"text/plain": [
- "\u003cIPython.core.display.HTML at 0x7f3eca67f8d0\u003e"
+ "\u003cIPython.core.display.HTML at 0x7fcd7222a050\u003e"
]
},
"metadata": {
"tags": [
- "outputarea_id1"
+ "outputarea_id3"
]
},
"output_type": "display_data"
@@ -536,16 +535,16 @@
{
"data": {
"application/javascript": [
- "window[\"e8ddfa22-4362-11e8-91ec-c8d3ffb5fbe0\"] = colab_lib.createTabBar({\"contentBorder\": [\"0px\"], \"elementId\": \"id1\", \"borderColor\": [\"#a7a7a7\"], \"contentHeight\": [\"initial\"], \"tabNames\": [\"RNN Colorbot\"], \"location\": \"top\", \"initialSelection\": 0});\n",
- "//# sourceURL=js_71b9087b6d"
+ "window[\"8a03307e-78a7-11e8-99f9-c8d3ffb5fbe0\"] = colab_lib.createTabBar({\"contentBorder\": [\"0px\"], \"elementId\": \"id3\", \"contentHeight\": [\"initial\"], \"tabNames\": [\"RNN Colorbot\"], \"location\": \"top\", \"initialSelection\": 0, \"borderColor\": [\"#a7a7a7\"]});\n",
+ "//# sourceURL=js_dc5d7f2784"
],
"text/plain": [
- "\u003cIPython.core.display.Javascript at 0x7f3eca67f950\u003e"
+ "\u003cIPython.core.display.Javascript at 0x7fcd7222a190\u003e"
]
},
"metadata": {
"tags": [
- "outputarea_id1"
+ "outputarea_id3"
]
},
"output_type": "display_data"
@@ -553,16 +552,16 @@
{
"data": {
"application/javascript": [
- "window[\"e8ddfa23-4362-11e8-91ec-c8d3ffb5fbe0\"] = window[\"id1\"].setSelectedTabIndex(0);\n",
- "//# sourceURL=js_e390445f33"
+ "window[\"8a03307f-78a7-11e8-99f9-c8d3ffb5fbe0\"] = window[\"id3\"].setSelectedTabIndex(0);\n",
+ "//# sourceURL=js_be7950150b"
],
"text/plain": [
- "\u003cIPython.core.display.Javascript at 0x7f3eca67f990\u003e"
+ "\u003cIPython.core.display.Javascript at 0x7fcd7222ac90\u003e"
]
},
"metadata": {
"tags": [
- "outputarea_id1"
+ "outputarea_id3"
]
},
"output_type": "display_data"
@@ -570,17 +569,17 @@
{
"data": {
"application/javascript": [
- "window[\"e8ddfa24-4362-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.getActiveOutputArea();\n",
- "//# sourceURL=js_241dd76d85"
+ "window[\"8a033080-78a7-11e8-99f9-c8d3ffb5fbe0\"] = google.colab.output.getActiveOutputArea();\n",
+ "//# sourceURL=js_d0c3bd4eaa"
],
"text/plain": [
- "\u003cIPython.core.display.Javascript at 0x7f3eca67fc50\u003e"
+ "\u003cIPython.core.display.Javascript at 0x7fcd7222aad0\u003e"
]
},
"metadata": {
"tags": [
- "id1_content_0",
- "outputarea_id1"
+ "id3_content_0",
+ "outputarea_id3"
]
},
"output_type": "display_data"
@@ -588,17 +587,17 @@
{
"data": {
"application/javascript": [
- "window[\"e8ddfa25-4362-11e8-91ec-c8d3ffb5fbe0\"] = document.querySelector(\"#id1_content_0\");\n",
- "//# sourceURL=js_60c64e3d50"
+ "window[\"8a033081-78a7-11e8-99f9-c8d3ffb5fbe0\"] = document.querySelector(\"#id3_content_0\");\n",
+ "//# sourceURL=js_f10f6eba86"
],
"text/plain": [
- "\u003cIPython.core.display.Javascript at 0x7f3eca67fd90\u003e"
+ "\u003cIPython.core.display.Javascript at 0x7fcd7222aed0\u003e"
]
},
"metadata": {
"tags": [
- "id1_content_0",
- "outputarea_id1"
+ "id3_content_0",
+ "outputarea_id3"
]
},
"output_type": "display_data"
@@ -606,17 +605,17 @@
{
"data": {
"application/javascript": [
- "window[\"e8ddfa26-4362-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"e8ddfa25-4362-11e8-91ec-c8d3ffb5fbe0\"]);\n",
- "//# sourceURL=js_14ea437cbd"
+ "window[\"8a033082-78a7-11e8-99f9-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"8a033081-78a7-11e8-99f9-c8d3ffb5fbe0\"]);\n",
+ "//# sourceURL=js_ff29697179"
],
"text/plain": [
- "\u003cIPython.core.display.Javascript at 0x7f3eca67fe10\u003e"
+ "\u003cIPython.core.display.Javascript at 0x7fcd7222abd0\u003e"
]
},
"metadata": {
"tags": [
- "id1_content_0",
- "outputarea_id1"
+ "id3_content_0",
+ "outputarea_id3"
]
},
"output_type": "display_data"
@@ -624,17 +623,17 @@
{
"data": {
"application/javascript": [
- "window[\"e8ddfa27-4362-11e8-91ec-c8d3ffb5fbe0\"] = window[\"id1\"].setSelectedTabIndex(0);\n",
- "//# sourceURL=js_09294c2226"
+ "window[\"8a033083-78a7-11e8-99f9-c8d3ffb5fbe0\"] = window[\"id3\"].setSelectedTabIndex(0);\n",
+ "//# sourceURL=js_ff85295dc7"
],
"text/plain": [
- "\u003cIPython.core.display.Javascript at 0x7f3eca67fcd0\u003e"
+ "\u003cIPython.core.display.Javascript at 0x7fcd7222ab90\u003e"
]
},
"metadata": {
"tags": [
- "id1_content_0",
- "outputarea_id1"
+ "id3_content_0",
+ "outputarea_id3"
]
},
"output_type": "display_data"
@@ -642,17 +641,17 @@
{
"data": {
"application/javascript": [
- "window[\"ec965514-4362-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"e8ddfa24-4362-11e8-91ec-c8d3ffb5fbe0\"]);\n",
- "//# sourceURL=js_e5e8266997"
+ "window[\"8b18d8dc-78a7-11e8-99f9-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"8a033080-78a7-11e8-99f9-c8d3ffb5fbe0\"]);\n",
+ "//# sourceURL=js_ed7aabfedb"
],
"text/plain": [
- "\u003cIPython.core.display.Javascript at 0x7f3eca67fe10\u003e"
+ "\u003cIPython.core.display.Javascript at 0x7fcd7222a110\u003e"
]
},
"metadata": {
"tags": [
- "id1_content_0",
- "outputarea_id1"
+ "id3_content_0",
+ "outputarea_id3"
]
},
"output_type": "display_data"
@@ -660,17 +659,17 @@
{
"data": {
"application/javascript": [
- "window[\"ec965515-4362-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.getActiveOutputArea();\n",
- "//# sourceURL=js_07a097f0ee"
+ "window[\"8b18d8dd-78a7-11e8-99f9-c8d3ffb5fbe0\"] = google.colab.output.getActiveOutputArea();\n",
+ "//# sourceURL=js_c86f8feaf4"
],
"text/plain": [
- "\u003cIPython.core.display.Javascript at 0x7f3eca67fc90\u003e"
+ "\u003cIPython.core.display.Javascript at 0x7fcd7222acd0\u003e"
]
},
"metadata": {
"tags": [
- "id1_content_0",
- "outputarea_id1"
+ "id3_content_0",
+ "outputarea_id3"
]
},
"output_type": "display_data"
@@ -678,17 +677,17 @@
{
"data": {
"application/javascript": [
- "window[\"ec965516-4362-11e8-91ec-c8d3ffb5fbe0\"] = document.querySelector(\"#id1_content_0\");\n",
- "//# sourceURL=js_790d669ca8"
+ "window[\"8b18d8de-78a7-11e8-99f9-c8d3ffb5fbe0\"] = document.querySelector(\"#id3_content_0\");\n",
+ "//# sourceURL=js_4d0fde6662"
],
"text/plain": [
- "\u003cIPython.core.display.Javascript at 0x7f3eca67f8d0\u003e"
+ "\u003cIPython.core.display.Javascript at 0x7fcd7222ae50\u003e"
]
},
"metadata": {
"tags": [
- "id1_content_0",
- "outputarea_id1"
+ "id3_content_0",
+ "outputarea_id3"
]
},
"output_type": "display_data"
@@ -696,17 +695,17 @@
{
"data": {
"application/javascript": [
- "window[\"ec965517-4362-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"ec965516-4362-11e8-91ec-c8d3ffb5fbe0\"]);\n",
- "//# sourceURL=js_d30df771f0"
+ "window[\"8b18d8df-78a7-11e8-99f9-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"8b18d8de-78a7-11e8-99f9-c8d3ffb5fbe0\"]);\n",
+ "//# sourceURL=js_3f66d52720"
],
"text/plain": [
- "\u003cIPython.core.display.Javascript at 0x7f3eca67fd90\u003e"
+ "\u003cIPython.core.display.Javascript at 0x7fcd7222a210\u003e"
]
},
"metadata": {
"tags": [
- "id1_content_0",
- "outputarea_id1"
+ "id3_content_0",
+ "outputarea_id3"
]
},
"output_type": "display_data"
@@ -714,32 +713,32 @@
{
"data": {
"application/javascript": [
- "window[\"ec965518-4362-11e8-91ec-c8d3ffb5fbe0\"] = window[\"id1\"].setSelectedTabIndex(0);\n",
- "//# sourceURL=js_8a43a2da4b"
+ "window[\"8b18d8e0-78a7-11e8-99f9-c8d3ffb5fbe0\"] = window[\"id3\"].setSelectedTabIndex(0);\n",
+ "//# sourceURL=js_375f5ae6d7"
],
"text/plain": [
- "\u003cIPython.core.display.Javascript at 0x7f3eca67fc50\u003e"
+ "\u003cIPython.core.display.Javascript at 0x7fcd7222a310\u003e"
]
},
"metadata": {
"tags": [
- "id1_content_0",
- "outputarea_id1"
+ "id3_content_0",
+ "outputarea_id3"
]
},
"output_type": "display_data"
},
{
"data": {
- "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQwAAAENCAYAAAD60Fs2AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAACMBJREFUeJzt3F+I1XX+x/G32zjiFERUpgaFd2JBzOg5joX4h0SiMgmM\n/uhVGIlgFBlERGB3hUEkhkRdtDfRP1ACL6KpLBqcguxCjEAkmGamQcSohFHzsxe7O6zssvsydtff\n+ns8rs758j3f8z7fiyef7/k3o7XWCiDwh4s9APC/QzCAmGAAMcEAYoIBxAQDiAkGF8XTTz9d3W63\n7rvvvhoZGakVK1Zc7JEICMYlbvXq1TU8PHyxxzjPV199VcPDw/XZZ5/V22+/XVVVM2bMuMhTkRAM\n/qt+++23+uGHH+r666+vWbNmXexxuECCcQl76qmnanx8vLZs2VIDAwP1+uuv1zfffFP3339/dTqd\nWr9+fY2MjEzvv2nTpnr55ZfrgQceqIGBgXr44Yfr5MmTVVV1+vTp2r59ey1durQ6nU5t2LChTpw4\nUVVVk5OTtWXLllq6dGmtXbu23nnnnelj7tq1q7Zt21bbt2+vJUuW1HvvvVfPPvtsHTp0qAYGBmrX\nrl1/N/fRo0dr06ZN1el06u67766hoaGqqhodHa1OpzO93zPPPFO33nrr9P3t27fXm2+++e89iZyv\ncUlbtWpVGx4ebq21NjEx0brdbjtw4EBrrbUvvviidbvdduLEidZaaxs3bmxr1qxp33//fZuammob\nN25sO3fubK219tZbb7VHH320TU1NtXPnzrXDhw+3X375pbXW2kMPPdR27NjRTp8+3Y4cOdIGBwen\nn/OVV15pN910U/voo49aa61NTU21999/vz344IPTMx48eLCtWLGitdbamTNn2po1a9qePXvamTNn\n2vDwcOvv72/Hjh2bfj2HDx9urbW2du3advvtt7ejR4+21lpbuXJlO3LkyH/qVNJas8L4f6D95edC\n+/btq5UrV9by5curqmrZsmV1880316effjq977333ls33HBD9fb21h133FFHjhypqqqenp46efJk\nHTt2rGbMmFGLFi2qyy+/vCYmJurrr7+uJ598smbOnFkLFy6sDRs21N69e6eP2d/fX6tXr66qqt7e\n3n8666FDh+rUqVP1yCOPVE9PTw0ODtaqVavqgw8+qKqqJUuW1MjISB0/fryqqtauXVtffvlljY6O\n1q+//loLFy78N501/pGeiz0A/z1jY2O1f//++vjjj6vqzyE5e/ZsLVu2bHqfa665Zvr27Nmz69Sp\nU1VVdc8999TExEQ98cQT9fPPP9e6devq8ccfr8nJybryyitr9uzZ04+bP39+HT58ePr+3Llz4xkn\nJydr3rx5522bP39+TU5OVlVVp9OpoaGhuu6666rb7Va32629e/dWb29vLV68+ALOBr+HYFzi/vbT\nh3nz5tX69etrx44dF3ycnp6e2rp1a23durXGxsZq8+bNtWDBgrrtttvqp59+qlOnTlVfX19VVY2P\nj9ecOXP+4Qz/ypw5c2p8fPy8bWNjY7VgwYKqqup2u/Xiiy/WvHnzqtPp1MDAQD333HPV29tb3W73\ngl8XF8YlySXu2muvrdHR0aqqWrduXQ0NDdXnn39e586dq6mpqRoZGakff/zxXx7n4MGD9d1339W5\nc+eqr6+venp66rLLLqu5c+dWf39/vfTSS3X69On69ttv6913361169b9rnlvueWW6uvrq9dee63O\nnj1bBw8erE8++aTuvPPOqqq68cYba9asWbVv377qdDp1xRVX1NVXX10ffvjheW+I8p8hGJe4zZs3\n1+7du6vb7db+/ftr9+7dtWfPnlq2bFmtWrWq3njjjen3OP7ZSuD48eO1bdu2Wrx4cd111121dOnS\n6Sjs3LmzRkdHa/ny5bVt27Z67LHHzrvMuR
AzZ86sV199tQ4cOFCDg4P1/PPP1wsvvDC9wqj68yrj\nqquumr7U+WsoFi1a9Luek9yM1vyBDpCxwgBiggHEBAOICQYQ+z/7PYzjf/QRGVxM12z68u+2WWEA\nMcEAYoIBxAQDiAkGEBMMICYYQEwwgJhgADHBAGKCAcQEA4gJBhATDCAmGEBMMICYYAAxwQBiggHE\nBAOICQYQEwwgJhhATDCAmGAAMcEAYoIBxAQDiAkGEBMMICYYQEwwgJhgADHBAGKCAcQEA4gJBhAT\nDCAmGEBMMICYYAAxwQBiggHEBAOICQYQEwwgJhhATDCAmGAAMcEAYoIBxAQDiAkGEBMMICYYQEww\ngJhgADHBAGKCAcQEA4gJBhATDCAmGEBMMICYYAAxwQBiggHEBAOICQYQEwwgJhhATDCAmGAAMcEA\nYoIBxAQDiAkGEBMMICYYQEwwgJhgADHBAGKCAcQEA4gJBhATDCAmGEBMMICYYAAxwQBiggHEBAOI\nCQYQEwwgJhhATDCAmGAAMcEAYoIBxAQDiAkGEBMMICYYQEwwgJhgADHBAGKCAcQEA4gJBhATDCAm\nGEBMMICYYAAxwQBiggHEBAOICQYQEwwgJhhATDCAmGAAMcEAYoIBxAQDiAkGEBMMICYYQEwwgJhg\nADHBAGKCAcQEA4gJBhATDCAmGEBMMICYYAAxwQBiggHEBAOICQYQEwwgJhhATDCAmGAAMcEAYoIB\nxAQDiAkGEBMMICYYQEwwgJhgADHBAGKCAcQEA4gJBhATDCAmGEBMMICYYAAxwQBiggHEBAOICQYQ\nEwwgJhhATDCAmGAAMcEAYoIBxAQDiAkGEBMMICYYQEwwgJhgADHBAGKCAcQEA4gJBhATDCAmGEBM\nMICYYAAxwQBiggHEBAOICQYQEwwgJhhATDCAmGAAMcEAYoIBxAQDiAkGEBMMICYYQEwwgJhgADHB\nAGKCAcQEA4gJBhATDCAmGEBMMICYYAAxwQBiggHEBAOICQYQEwwgJhhATDCAmGAAMcEAYoIBxAQD\niAkGEBMMIDajtdYu9hDA/wYrDCAmGEBMMICYYAAxwQBiggHEBAOICQYQEwwgJhhATDCAmGAAMcEA\nYoIBxAQDiAkGEBMMICYYQEwwgJhgADHBAGKCAcQEA4j9CY2LTAbbRbWuAAAAAElFTkSuQmCC\n",
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQwAAAENCAYAAAD60Fs2AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAABTFJREFUeJzt3C+LV30eh/HP6EZvbP4ZJmkXDA6oQdZRMIhYLIKCMGVA\nyyaLT2ERLMqEDfoUFA2y3WpRrOKoSUSECePcYUEWdsN1OzfOyr5e8ZwT3unie34cfgvb29vbAxDs\n2e0BwK9DMIBMMIBMMIBMMIBMMIBMMPipXrx4MWfOnNntGfwgweCnW1hY2O0J/CDBYEe2trZ2ewI/\nkWDwh509e3bW19fn0qVLc/z48dnY2Jhbt27NyZMn59y5c/Pw4cPvz25ubs7t27dneXl5Ll68OC9f\nvtzF5ezUX3Z7AL+mJ0+ezPr6+uzfv3+uXr0658+fn7t3787GxsbcuHFjjhw5MqdPn5579+7N27dv\n5/nz5/P169dZXV3d7ensgBMGP+T69etz8ODBef369Xz69GnW1tZm7969s7S0NFeuXJnHjx/PzMzT\np09nbW1tfvvttzl48OBcu3Ztl5ezE04Y/JBDhw7NzMy7d+/mw4cPs7y8PDMz29vb8+3btzlx4sTM\nzHz8+PH7szMzi4uLP38sfxrBYEcOHz48S0tL8+zZs/96/8CBA7OxsTFHjx6dmX8Fhl+XVxJ25Nix\nY7Nv375ZX1+fzc3N2dramjdv3nz/cfPChQvz4MGD+fz587x//34ePXq0y4vZCcHgD/v37yj27Nkz\n9+/fn1evXs3KysqcOnVq7ty5M1++fJmZmZs3b87i4uKsrKzM6urqXL58ebdm8ydY8Ac6QOWEAWSC\nAWSCAWSCAWT/s99h/P3GX3d7Avxf+9s//vkf15wwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEww\ngEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEww\ngEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEww\ngEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEww\ngEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEww\ngEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEww\ngEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEww\ngEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEww\ngEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEww\ngEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEww\ngEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEww\ngEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEww\ngEwwgEwwgEwwgGxhe3t7e7dHAL8GJwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwg\nEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwg+x1QoZHG4XIe4gAAAABJRU5ErkJggg==\n",
"text/plain": [
- "\u003cmatplotlib.figure.Figure at 0x7f3ecc00bf10\u003e"
+ "\u003cmatplotlib.figure.Figure at 0x7fcd0d02dc90\u003e"
]
},
"metadata": {
"tags": [
- "id1_content_0",
- "outputarea_id1",
+ "id3_content_0",
+ "outputarea_id3",
"user_output"
]
},
@@ -748,17 +747,17 @@
{
"data": {
"application/javascript": [
- "window[\"ec965519-4362-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"ec965515-4362-11e8-91ec-c8d3ffb5fbe0\"]);\n",
- "//# sourceURL=js_893ad561f4"
+ "window[\"8b18d8e1-78a7-11e8-99f9-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"8b18d8dd-78a7-11e8-99f9-c8d3ffb5fbe0\"]);\n",
+ "//# sourceURL=js_34b0509660"
],
"text/plain": [
- "\u003cIPython.core.display.Javascript at 0x7f3f31b55c90\u003e"
+ "\u003cIPython.core.display.Javascript at 0x7fcd08e6e850\u003e"
]
},
"metadata": {
"tags": [
- "id1_content_0",
- "outputarea_id1"
+ "id3_content_0",
+ "outputarea_id3"
]
},
"output_type": "display_data"
@@ -766,17 +765,17 @@
{
"data": {
"application/javascript": [
- "window[\"ec96551a-4362-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.getActiveOutputArea();\n",
- "//# sourceURL=js_2d99e0ac17"
+ "window[\"8b18d8e2-78a7-11e8-99f9-c8d3ffb5fbe0\"] = google.colab.output.getActiveOutputArea();\n",
+ "//# sourceURL=js_518a0f26fe"
],
"text/plain": [
- "\u003cIPython.core.display.Javascript at 0x7f3eca67fe50\u003e"
+ "\u003cIPython.core.display.Javascript at 0x7fcd08e6ec90\u003e"
]
},
"metadata": {
"tags": [
- "id1_content_0",
- "outputarea_id1"
+ "id3_content_0",
+ "outputarea_id3"
]
},
"output_type": "display_data"
@@ -784,17 +783,17 @@
{
"data": {
"application/javascript": [
- "window[\"ec96551b-4362-11e8-91ec-c8d3ffb5fbe0\"] = document.querySelector(\"#id1_content_0\");\n",
- "//# sourceURL=js_5c19462e32"
+ "window[\"8b18d8e3-78a7-11e8-99f9-c8d3ffb5fbe0\"] = document.querySelector(\"#id3_content_0\");\n",
+ "//# sourceURL=js_17eb3ff612"
],
"text/plain": [
- "\u003cIPython.core.display.Javascript at 0x7f3f31b55dd0\u003e"
+ "\u003cIPython.core.display.Javascript at 0x7fcd08e6eb50\u003e"
]
},
"metadata": {
"tags": [
- "id1_content_0",
- "outputarea_id1"
+ "id3_content_0",
+ "outputarea_id3"
]
},
"output_type": "display_data"
@@ -802,17 +801,17 @@
{
"data": {
"application/javascript": [
- "window[\"ec96551c-4362-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"ec96551b-4362-11e8-91ec-c8d3ffb5fbe0\"]);\n",
- "//# sourceURL=js_b9c8b7567b"
+ "window[\"8b18d8e4-78a7-11e8-99f9-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"8b18d8e3-78a7-11e8-99f9-c8d3ffb5fbe0\"]);\n",
+ "//# sourceURL=js_99da807c8e"
],
"text/plain": [
- "\u003cIPython.core.display.Javascript at 0x7f3f31b55a50\u003e"
+ "\u003cIPython.core.display.Javascript at 0x7fcd08e6eb90\u003e"
]
},
"metadata": {
"tags": [
- "id1_content_0",
- "outputarea_id1"
+ "id3_content_0",
+ "outputarea_id3"
]
},
"output_type": "display_data"
@@ -820,17 +819,17 @@
{
"data": {
"application/javascript": [
- "window[\"ec96551d-4362-11e8-91ec-c8d3ffb5fbe0\"] = window[\"id1\"].setSelectedTabIndex(0);\n",
- "//# sourceURL=js_fd05186348"
+ "window[\"8b18d8e5-78a7-11e8-99f9-c8d3ffb5fbe0\"] = window[\"id3\"].setSelectedTabIndex(0);\n",
+ "//# sourceURL=js_dee01cb4b6"
],
"text/plain": [
- "\u003cIPython.core.display.Javascript at 0x7f3f31b55810\u003e"
+ "\u003cIPython.core.display.Javascript at 0x7fcd08e6e610\u003e"
]
},
"metadata": {
"tags": [
- "id1_content_0",
- "outputarea_id1"
+ "id3_content_0",
+ "outputarea_id3"
]
},
"output_type": "display_data"
@@ -838,16 +837,16 @@
{
"data": {
"text/html": [
- "\u003cdiv class=id_888646481 style=\"margin-right:10px; display:flex;align-items:center;\"\u003e\u003cspan style=\"margin-right: 3px;\"\u003e\u003c/span\u003e\u003c/div\u003e"
+ "\u003cdiv class=id_853612217 style=\"margin-right:10px; display:flex;align-items:center;\"\u003e\u003cspan style=\"margin-right: 3px;\"\u003e\u003c/span\u003e\u003c/div\u003e"
],
"text/plain": [
- "\u003cIPython.core.display.HTML at 0x7f3f32414810\u003e"
+ "\u003cIPython.core.display.HTML at 0x7fcd7222aa10\u003e"
]
},
"metadata": {
"tags": [
- "id1_content_0",
- "outputarea_id1",
+ "id3_content_0",
+ "outputarea_id3",
"user_output"
]
},
@@ -856,17 +855,17 @@
{
"data": {
"application/javascript": [
- "window[\"ec96551e-4362-11e8-91ec-c8d3ffb5fbe0\"] = jQuery(\".id_888646481 span\");\n",
- "//# sourceURL=js_efef96e882"
+ "window[\"8b18d8e6-78a7-11e8-99f9-c8d3ffb5fbe0\"] = jQuery(\".id_853612217 span\");\n",
+ "//# sourceURL=js_8c378be329"
],
"text/plain": [
- "\u003cIPython.core.display.Javascript at 0x7f3f31b55710\u003e"
+ "\u003cIPython.core.display.Javascript at 0x7fcd08e6e990\u003e"
]
},
"metadata": {
"tags": [
- "id1_content_0",
- "outputarea_id1",
+ "id3_content_0",
+ "outputarea_id3",
"user_output"
]
},
@@ -875,17 +874,17 @@
{
"data": {
"application/javascript": [
- "window[\"ec96551f-4362-11e8-91ec-c8d3ffb5fbe0\"] = window[\"ec96551e-4362-11e8-91ec-c8d3ffb5fbe0\"].text(\"Give me a color name (or press 'enter' to exit): \");\n",
- "//# sourceURL=js_6eca889864"
+ "window[\"8b18d8e7-78a7-11e8-99f9-c8d3ffb5fbe0\"] = window[\"8b18d8e6-78a7-11e8-99f9-c8d3ffb5fbe0\"].text(\"Give me a color name (or press 'enter' to exit): \");\n",
+ "//# sourceURL=js_f0b946600c"
],
"text/plain": [
- "\u003cIPython.core.display.Javascript at 0x7f3eca67f990\u003e"
+ "\u003cIPython.core.display.Javascript at 0x7fcd08e6e310\u003e"
]
},
"metadata": {
"tags": [
- "id1_content_0",
- "outputarea_id1",
+ "id3_content_0",
+ "outputarea_id3",
"user_output"
]
},
@@ -894,17 +893,17 @@
{
"data": {
"application/javascript": [
- "window[\"ed8ea972-4362-11e8-91ec-c8d3ffb5fbe0\"] = jQuery(\".id_888646481 input\");\n",
- "//# sourceURL=js_f02070cc60"
+ "window[\"8b18d8e9-78a7-11e8-99f9-c8d3ffb5fbe0\"] = jQuery(\".id_853612217 input\");\n",
+ "//# sourceURL=js_9e21b1373a"
],
"text/plain": [
- "\u003cIPython.core.display.Javascript at 0x7f3f31b553d0\u003e"
+ "\u003cIPython.core.display.Javascript at 0x7fcd08e6ea90\u003e"
]
},
"metadata": {
"tags": [
- "id1_content_0",
- "outputarea_id1",
+ "id3_content_0",
+ "outputarea_id3",
"user_output"
]
},
@@ -913,17 +912,17 @@
{
"data": {
"application/javascript": [
- "window[\"ed8ea973-4362-11e8-91ec-c8d3ffb5fbe0\"] = window[\"ed8ea972-4362-11e8-91ec-c8d3ffb5fbe0\"].remove();\n",
- "//# sourceURL=js_ed9faba660"
+ "window[\"8b18d8ea-78a7-11e8-99f9-c8d3ffb5fbe0\"] = window[\"8b18d8e9-78a7-11e8-99f9-c8d3ffb5fbe0\"].remove();\n",
+ "//# sourceURL=js_a7764968c6"
],
"text/plain": [
- "\u003cIPython.core.display.Javascript at 0x7f3f31a95450\u003e"
+ "\u003cIPython.core.display.Javascript at 0x7fcd08e6e5d0\u003e"
]
},
"metadata": {
"tags": [
- "id1_content_0",
- "outputarea_id1",
+ "id3_content_0",
+ "outputarea_id3",
"user_output"
]
},
@@ -932,17 +931,17 @@
{
"data": {
"application/javascript": [
- "window[\"ed8ea974-4362-11e8-91ec-c8d3ffb5fbe0\"] = jQuery(\".id_888646481 span\");\n",
- "//# sourceURL=js_f3458d7074"
+ "window[\"8b18d8eb-78a7-11e8-99f9-c8d3ffb5fbe0\"] = jQuery(\".id_853612217 span\");\n",
+ "//# sourceURL=js_74279d3ff0"
],
"text/plain": [
- "\u003cIPython.core.display.Javascript at 0x7f3f31a95250\u003e"
+ "\u003cIPython.core.display.Javascript at 0x7fcd08e6e890\u003e"
]
},
"metadata": {
"tags": [
- "id1_content_0",
- "outputarea_id1",
+ "id3_content_0",
+ "outputarea_id3",
"user_output"
]
},
@@ -951,17 +950,17 @@
{
"data": {
"application/javascript": [
- "window[\"ed8ea975-4362-11e8-91ec-c8d3ffb5fbe0\"] = window[\"ed8ea974-4362-11e8-91ec-c8d3ffb5fbe0\"].text(\"Give me a color name (or press 'enter' to exit): \");\n",
- "//# sourceURL=js_3ffd97bd6f"
+ "window[\"8b18d8ec-78a7-11e8-99f9-c8d3ffb5fbe0\"] = window[\"8b18d8eb-78a7-11e8-99f9-c8d3ffb5fbe0\"].text(\"Give me a color name (or press 'enter' to exit): \");\n",
+ "//# sourceURL=js_82b6c34cdb"
],
"text/plain": [
- "\u003cIPython.core.display.Javascript at 0x7f3f31a953d0\u003e"
+ "\u003cIPython.core.display.Javascript at 0x7fcd08e6e8d0\u003e"
]
},
"metadata": {
"tags": [
- "id1_content_0",
- "outputarea_id1",
+ "id3_content_0",
+ "outputarea_id3",
"user_output"
]
},
@@ -970,17 +969,17 @@
{
"data": {
"application/javascript": [
- "window[\"ed8ea976-4362-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"ec96551a-4362-11e8-91ec-c8d3ffb5fbe0\"]);\n",
- "//# sourceURL=js_7f73e8bcca"
+ "window[\"8b18d8ed-78a7-11e8-99f9-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"8b18d8e2-78a7-11e8-99f9-c8d3ffb5fbe0\"]);\n",
+ "//# sourceURL=js_ff6144734a"
],
"text/plain": [
- "\u003cIPython.core.display.Javascript at 0x7f3f31b55710\u003e"
+ "\u003cIPython.core.display.Javascript at 0x7fcd08e6e8d0\u003e"
]
},
"metadata": {
"tags": [
- "id1_content_0",
- "outputarea_id1"
+ "id3_content_0",
+ "outputarea_id3"
]
},
"output_type": "display_data"
@@ -1043,28 +1042,6 @@
"kind": "local"
},
"name": "RNN Colorbot using Keras and Estimators",
- "provenance": [
- {
- "file_id": "1CtzefX39ffFibX_BqE6cRbT0UW_DdVKl",
- "timestamp": 1523579810961
- },
- {
- "file_id": "1DcfimonWU11tmyivKBGVrbpAl3BIOaRG",
- "timestamp": 1523016192637
- },
- {
- "file_id": "1wCZUh73zTNs1jzzYjqoxMIdaBWCdKJ2K",
- "timestamp": 1522238054357
- },
- {
- "file_id": "1_HpC-RrmIv4lNaqeoslUeWaX8zH5IXaJ",
- "timestamp": 1521743157199
- },
- {
- "file_id": "1mjO2fQ2F9hxpAzw2mnrrUkcgfb7xSGW-",
- "timestamp": 1520522344607
- }
- ],
"version": "0.3.2",
"views": {}
},
diff --git a/tensorflow/contrib/cmake/tf_c.cmake b/tensorflow/contrib/cmake/tf_c.cmake
index 2e0a2fcef4..7a30eb94f5 100644
--- a/tensorflow/contrib/cmake/tf_c.cmake
+++ b/tensorflow/contrib/cmake/tf_c.cmake
@@ -36,16 +36,3 @@ add_dependencies(
tf_cc_while_loop
tf_core_lib
tf_protos_cc)
-
-if(tensorflow_BUILD_PYTHON_BINDINGS)
- add_library(tf_c_python_api OBJECT
- "${tensorflow_source_dir}/tensorflow/c/python_api.cc"
- "${tensorflow_source_dir}/tensorflow/c/python_api.h"
- )
- add_dependencies(
- tf_c_python_api
- tf_c
- tf_core_lib
- tf_core_framework
- tf_protos_cc)
-endif()
diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake
index 786ea05c74..e3b59001bc 100755
--- a/tensorflow/contrib/cmake/tf_python.cmake
+++ b/tensorflow/contrib/cmake/tf_python.cmake
@@ -456,6 +456,18 @@ add_custom_command(
COMMENT "Running SWIG to generate Python wrappers"
VERBATIM )
+add_library(tf_c_python_api OBJECT
+ "${tensorflow_source_dir}/tensorflow/c/python_api.cc"
+ "${tensorflow_source_dir}/tensorflow/c/python_api.h"
+)
+add_dependencies(
+ tf_c_python_api
+ tf_c
+ tf_core_lib
+ tf_core_framework
+ tf_protos_cc
+ tf_python_protos_cc)
+
set (pywrap_tensorflow_internal_src
"${tensorflow_source_dir}/tensorflow/core/profiler/internal/print_model_analysis.h"
"${tensorflow_source_dir}/tensorflow/core/profiler/internal/print_model_analysis.cc"
diff --git a/tensorflow/contrib/data/kernels/prefetching_kernels.cc b/tensorflow/contrib/data/kernels/prefetching_kernels.cc
index a2bfce0362..0fc3773475 100644
--- a/tensorflow/contrib/data/kernels/prefetching_kernels.cc
+++ b/tensorflow/contrib/data/kernels/prefetching_kernels.cc
@@ -269,18 +269,20 @@ class FunctionBufferResourceHandleOp : public OpKernel {
std::vector<Tensor> func_args;
func_args.push_back(*string_arg);
+ const string& source_device = ctx->device()->name();
+
// Obtain and canonicalize target_device.
const Tensor* target_arg;
OP_REQUIRES_OK(ctx, ctx->input("target_device", &target_arg));
- const string& target_device =
- DeviceNameUtils::CanonicalizeDeviceName(target_arg->scalar<string>()());
+ string target_device;
+ OP_REQUIRES_OK(ctx, DeviceNameUtils::CanonicalizeDeviceName(
+ target_arg->scalar<string>()(), source_device,
+ &target_device));
FunctionLibraryRuntime* lib = ctx->function_library();
OP_REQUIRES(ctx, lib != nullptr,
errors::Internal("No function library is provided."));
- const string& source_device = ctx->device()->name();
-
mutex_lock l(mu_);
if (!initialized_) {
OP_REQUIRES_OK(ctx, cinfo_.Init(ctx->resource_manager(), def()));
diff --git a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py
index b08132cd72..9c7040de9e 100644
--- a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py
@@ -235,6 +235,36 @@ class PrefetchToDeviceTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
+ def testPrefetchToSameDevice(self):
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops.prefetch_to_device(
+ "/job:localhost/replica:0/task:0/device:CPU:0"))
+
+ # NOTE(mrry): This device block creates the "host" dataset and iterator on
+ # /cpu:0, and ensures that the prefetching is across devices. In typical use
+ # this would not be necessary, because the GPU device would not support any
+ # of the dataset-related ops.
+ with ops.device("/cpu:0"):
+ iterator = device_dataset.make_one_shot_iterator()
+
+ self.assertEqual(host_dataset.output_types, device_dataset.output_types)
+ self.assertEqual(host_dataset.output_types, iterator.output_types)
+ self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
+ self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
+ self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
+ self.assertEqual(host_dataset.output_classes, iterator.output_classes)
+
+ next_element = iterator.get_next()
+ self.assertEqual(dtypes.int64, next_element.dtype)
+ self.assertEqual([], next_element.shape)
+
+ with self.test_session() as sess:
+ for i in range(10):
+ self.assertEqual(i, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
def testPrefetchDictToDevice(self):
host_dataset = dataset_ops.Dataset.range(10).map(lambda x: {"a": x})
device_dataset = host_dataset.apply(
diff --git a/tensorflow/contrib/eager/python/examples/revnet/blocks.py b/tensorflow/contrib/eager/python/examples/revnet/blocks.py
index af41f64286..74c1825a49 100644
--- a/tensorflow/contrib/eager/python/examples/revnet/blocks.py
+++ b/tensorflow/contrib/eager/python/examples/revnet/blocks.py
@@ -24,6 +24,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import six
import tensorflow as tf
from tensorflow.contrib.eager.python.examples.revnet import ops
@@ -93,9 +94,18 @@ class RevBlock(tf.keras.Model):
for i in reversed(range(len(self.blocks))):
block = self.blocks[i]
- y_inv = x if i == 0 else block.backward(y, training=training)
+ if i == 0:
+ y_inv = x
+ else:
+ # Don't update running stats when reconstructing activations
+ vars_and_vals = block.get_moving_stats()
+ y_inv = block.backward(y, training=training)
+ block.restore_moving_stats(vars_and_vals)
+
+ # Update running stats when computing gradients during training
dy, grads, vars_ = block.backward_grads_and_vars(
y_inv, dy, training=training)
+
grads_all += grads
vars_all += vars_
@@ -159,17 +169,18 @@ class _Residual(tf.keras.Model):
"""Apply residual block to inputs."""
x1, x2 = tf.split(x, num_or_size_splits=2, axis=self.axis)
- f_x2 = self.f.call(x2, training=training)
+ f_x2 = self.f(x2, training=training)
# TODO(lxuechen): Replace with simpler downsampling
x1_down = ops.downsample(
x1, self.filters // 2, self.strides, axis=self.axis)
x2_down = ops.downsample(
x2, self.filters // 2, self.strides, axis=self.axis)
y1 = f_x2 + x1_down
- g_y1 = self.g.call(y1, training=training) # self.g(y1) gives pylint error
+ g_y1 = self.g(y1, training=training)
y2 = g_y1 + x2_down
- if not concat: # Concat option needed for correct backward grads
+ if not concat: # For correct backward grads
return y1, y2
+
return tf.concat([y1, y2], axis=self.axis)
def backward(self, y, training=True):
@@ -178,9 +189,9 @@ class _Residual(tf.keras.Model):
assert self.strides == (1, 1)
y1, y2 = tf.split(y, num_or_size_splits=2, axis=self.axis)
- g_y1 = self.g.call(y1, training=training)
+ g_y1 = self.g(y1, training=training)
x2 = y2 - g_y1
- f_x2 = self.f.call(x2, training=training)
+ f_x2 = self.f(x2, training=training)
x1 = y1 - f_x2
return tf.concat([x1, x2], axis=self.axis)
@@ -216,6 +227,22 @@ class _Residual(tf.keras.Model):
return tf.concat([dx1, dx2], axis=self.axis), grads, vars_
+ def get_moving_stats(self):
+ vars_and_vals = {}
+
+ def _is_moving_var(v): # pylint: disable=invalid-name
+ n = v.name
+ return n.endswith("moving_mean:0") or n.endswith("moving_variance:0")
+
+ for v in filter(_is_moving_var, self.f.variables + self.g.variables):
+ vars_and_vals[v] = v.read_value()
+
+ return vars_and_vals
+
+ def restore_moving_stats(self, vars_and_vals):
+ for var_, val in six.iteritems(vars_and_vals):
+ var_.assign(val)
+
def _BottleneckResidualInner(filters,
strides,
diff --git a/tensorflow/contrib/eager/python/examples/revnet/blocks_test.py b/tensorflow/contrib/eager/python/examples/revnet/blocks_test.py
index f4436fd925..a28ca6e3e0 100644
--- a/tensorflow/contrib/eager/python/examples/revnet/blocks_test.py
+++ b/tensorflow/contrib/eager/python/examples/revnet/blocks_test.py
@@ -240,13 +240,12 @@ class _ResidualTest(tf.test.TestCase):
x = tf.random_normal(shape=data_shape)
residual = blocks._Residual(
filters=16, strides=(1, 1), input_shape=input_shape)
+
y_tr, y_ev = residual(x, training=True), residual(x, training=False)
- x_ = residual.backward(y_tr, training=True)
- # The numerical loss is alarming; reconstructed inputs could differ from
- # the original inputs often by more than 1e-3
- self.assertAllClose(x, x_, rtol=1e-01, atol=1e-01)
x_ = residual.backward(y_ev, training=False)
- self.assertAllClose(x, x_, rtol=1e-01, atol=1e-01)
+ self.assertAllClose(x, x_, rtol=1e-1, atol=1e-1)
+ x_ = residual.backward(y_tr, training=True) # This updates moving avg
+ self.assertAllClose(x, x_, rtol=1e-1, atol=1e-1)
def test_backward_channels_last(self):
"""Test `backward` function with `channels_last` data format."""
@@ -259,12 +258,12 @@ class _ResidualTest(tf.test.TestCase):
strides=(1, 1),
input_shape=input_shape,
data_format="channels_last")
+
y_tr, y_ev = residual(x, training=True), residual(x, training=False)
- x_ = residual.backward(y_tr, training=True)
- # Egregious numerical error
- self.assertAllClose(x, x_, rtol=1e-01, atol=1e-01)
x_ = residual.backward(y_ev, training=False)
- self.assertAllClose(x, x_, rtol=1e-01, atol=1e-01)
+ self.assertAllClose(x, x_, rtol=1e-1, atol=1e-1)
+ x_ = residual.backward(y_tr, training=True) # This updates moving avg
+ self.assertAllClose(x, x_, rtol=1e-1, atol=1e-1)
def test_backward_grads_and_vars_channels_first(self):
"""Test `backward_grads` function with `channels_first` data format."""
@@ -278,6 +277,8 @@ class _ResidualTest(tf.test.TestCase):
dy = tf.random_normal(shape=data_shape)
residual = blocks._Residual(
filters=16, strides=(1, 1), input_shape=input_shape)
+
+ vars_and_vals = residual.get_moving_stats()
dx_tr, grads_tr, vars_tr = residual.backward_grads_and_vars(
x, dy=dy, training=True)
dx_ev, grads_ev, vars_ev = residual.backward_grads_and_vars(
@@ -289,10 +290,23 @@ class _ResidualTest(tf.test.TestCase):
self.assertTrue(isinstance(vars_ev, list))
for grad_tr, var_tr, grad_ev, var_ev in zip(grads_tr, vars_tr, grads_ev,
vars_ev):
- if grad_tr is not None: # Batch norm moving mean, var gives None grad
- self.assertEqual(grad_tr.shape, grad_ev.shape)
- self.assertEqual(var_tr.shape, var_ev.shape)
- self.assertEqual(grad_tr.shape, var_tr.shape)
+ self.assertEqual(grad_tr.shape, grad_ev.shape)
+ self.assertEqual(var_tr.shape, var_ev.shape)
+ self.assertEqual(grad_tr.shape, var_tr.shape)
+
+ # Compare against the true gradient computed by the tape
+ residual.restore_moving_stats(vars_and_vals)
+ with tf.GradientTape(persistent=True) as tape:
+ tape.watch(x)
+ y = residual(x, training=True)
+ grads = tape.gradient(
+ y, [x] + residual.trainable_variables, output_gradients=[dy])
+ dx_tr_true, grads_tr_true = grads[0], grads[1:]
+
+ del tape
+
+ self.assertAllClose(dx_tr, dx_tr_true, rtol=1e-1, atol=1e-1)
+ self.assertAllClose(grads_tr, grads_tr_true, rtol=1e-1, atol=1e-1)
def test_backward_grads_and_vars_channels_last(self):
"""Test `backward_grads` function with `channels_last` data format."""
@@ -306,6 +320,7 @@ class _ResidualTest(tf.test.TestCase):
strides=(1, 1),
input_shape=input_shape,
data_format="channels_last")
+
dx_tr, grads_tr, vars_tr = residual.backward_grads_and_vars(
x, dy=dy, training=True)
dx_ev, grads_ev, vars_ev = residual.backward_grads_and_vars(
@@ -317,10 +332,9 @@ class _ResidualTest(tf.test.TestCase):
self.assertTrue(isinstance(vars_ev, list))
for grad_tr, var_tr, grad_ev, var_ev in zip(grads_tr, vars_tr, grads_ev,
vars_ev):
- if grad_tr is not None: # Batch norm moving mean, var gives None grad
- self.assertEqual(grad_tr.shape, grad_ev.shape)
- self.assertEqual(var_tr.shape, var_ev.shape)
- self.assertEqual(grad_tr.shape, var_tr.shape)
+ self.assertEqual(grad_tr.shape, grad_ev.shape)
+ self.assertEqual(var_tr.shape, var_ev.shape)
+ self.assertEqual(grad_tr.shape, var_tr.shape)
class _ResidualInnerTest(tf.test.TestCase):
diff --git a/tensorflow/contrib/eager/python/examples/revnet/cifar_input.py b/tensorflow/contrib/eager/python/examples/revnet/cifar_input.py
index 3bc69da5ad..e1d8b3a055 100644
--- a/tensorflow/contrib/eager/python/examples/revnet/cifar_input.py
+++ b/tensorflow/contrib/eager/python/examples/revnet/cifar_input.py
@@ -26,8 +26,6 @@ import tensorflow as tf
IMAGE_HEIGHT = 32
IMAGE_WIDTH = 32
NUM_CHANNEL = 3
-NUM_TRAIN_IMG = 50000
-NUM_TEST_IMG = 10000
def get_ds_from_tfrecords(data_dir,
@@ -37,8 +35,8 @@ def get_ds_from_tfrecords(data_dir,
epochs=None,
shuffle=True,
data_format="channels_first",
- num_parallel_calls=4,
- prefetch=True,
+ num_parallel_calls=8,
+ prefetch=0,
div255=True,
dtype=tf.float32):
"""Returns a tf.train.Dataset object from reading tfrecords.
@@ -48,11 +46,12 @@ def get_ds_from_tfrecords(data_dir,
split: "train", "validation", or "test"
data_aug: Apply data augmentation if True
batch_size: Batch size of dataset object
- epochs: Number of epochs to repeat the dataset
+ epochs: Number of epochs to repeat the dataset; default `None` means
+ repeating indefinitely
shuffle: Shuffle the dataset if True
data_format: `channels_first` or `channels_last`
num_parallel_calls: Number of threads for dataset preprocess
- prefetch: Apply prefetch for the dataset if True
+ prefetch: Buffer size for prefetch
div255: Divide the images by 255 if True
dtype: Data type of images
Returns:
@@ -62,7 +61,7 @@ def get_ds_from_tfrecords(data_dir,
ValueError: Unknown split
"""
- if split not in ["train", "validation", "test"]:
+ if split not in ["train", "validation", "test", "train_all"]:
raise ValueError("Unknown split {}".format(split))
def _parser(serialized_example):
@@ -74,7 +73,11 @@ def get_ds_from_tfrecords(data_dir,
"label": tf.FixedLenFeature([], tf.int64),
})
image = tf.decode_raw(features["image"], tf.uint8)
- image = tf.reshape(image, [IMAGE_HEIGHT, IMAGE_WIDTH, NUM_CHANNEL])
+    # Raw bytes are stored channel-major, so reshape to [C, H, W] first;
+ image = tf.reshape(image, [NUM_CHANNEL, IMAGE_HEIGHT, IMAGE_WIDTH])
+ # This is needed for `tf.image.resize_image_with_crop_or_pad`
+ image = tf.transpose(image, [1, 2, 0])
+
image = tf.cast(image, dtype)
label = tf.cast(features["label"], tf.int32)
@@ -93,13 +96,21 @@ def get_ds_from_tfrecords(data_dir,
return image, label
filename = os.path.join(data_dir, split + ".tfrecords")
- dataset = tf.data.TFRecordDataset(filename).repeat(epochs)
+ dataset = tf.data.TFRecordDataset(filename)
+ dataset = dataset.repeat(epochs)
dataset = dataset.map(_parser, num_parallel_calls=num_parallel_calls)
+ dataset = dataset.prefetch(prefetch)
- if prefetch:
- dataset = dataset.prefetch(batch_size)
if shuffle:
- dataset = dataset.shuffle(NUM_TRAIN_IMG)
+ # Find the right size according to the split
+ size = {
+ "train": 40000,
+ "validation": 10000,
+ "test": 10000,
+ "train_all": 50000
+ }[split]
+ dataset = dataset.shuffle(size)
+
dataset = dataset.batch(batch_size)
return dataset
diff --git a/tensorflow/contrib/eager/python/examples/revnet/config.py b/tensorflow/contrib/eager/python/examples/revnet/config.py
index 263a65dc76..30b0edbf43 100644
--- a/tensorflow/contrib/eager/python/examples/revnet/config.py
+++ b/tensorflow/contrib/eager/python/examples/revnet/config.py
@@ -61,12 +61,13 @@ def get_hparams_cifar_38():
config.add_hparam("max_train_iter", 80000)
config.add_hparam("seed", 1234)
config.add_hparam("shuffle", True)
- config.add_hparam("prefetch", True)
- config.add_hparam("log_every", 50)
- config.add_hparam("save_every", 50)
+ config.add_hparam("log_every", 500)
+ config.add_hparam("save_every", 500)
config.add_hparam("dtype", tf.float32)
- config.add_hparam("eval_batch_size", 500)
+ config.add_hparam("eval_batch_size", 1000)
config.add_hparam("div255", True)
+  # TODO(lxuechen): This is imprecise: when training with a held-out
+  # validation set, only 40k images remain in the training data
config.add_hparam("iters_per_epoch", 50000 // config.batch_size)
config.add_hparam("epochs", config.max_train_iter // config.iters_per_epoch)
@@ -104,11 +105,10 @@ def get_hparams_imagenet_56():
config.add_hparam("max_train_iter", 600000)
config.add_hparam("seed", 1234)
config.add_hparam("shuffle", True)
- config.add_hparam("prefetch", True)
config.add_hparam("log_every", 50)
config.add_hparam("save_every", 50)
config.add_hparam("dtype", tf.float32)
- config.add_hparam("eval_batch_size", 500)
+ config.add_hparam("eval_batch_size", 1000)
config.add_hparam("div255", True)
# TODO(lxuechen): Update this according to ImageNet data
config.add_hparam("iters_per_epoch", 50000 // config.batch_size)
diff --git a/tensorflow/contrib/eager/python/examples/revnet/main.py b/tensorflow/contrib/eager/python/examples/revnet/main.py
index 9ef11f8e9b..1065592509 100644
--- a/tensorflow/contrib/eager/python/examples/revnet/main.py
+++ b/tensorflow/contrib/eager/python/examples/revnet/main.py
@@ -19,9 +19,11 @@ from __future__ import division
from __future__ import print_function
import os
+import sys
from absl import flags
import tensorflow as tf
+from tqdm import tqdm
from tensorflow.contrib.eager.python.examples.revnet import cifar_input
from tensorflow.contrib.eager.python.examples.revnet import config as config_
from tensorflow.contrib.eager.python.examples.revnet import revnet
@@ -38,28 +40,54 @@ def main(_):
tf.enable_eager_execution()
config = config_.get_hparams_cifar_38()
- model = revnet.RevNet(config=config)
-
- ds_train = cifar_input.get_ds_from_tfrecords(
- data_dir=FLAGS.data_dir,
- split="train",
- data_aug=True,
- batch_size=config.batch_size,
- epochs=config.epochs,
- shuffle=config.shuffle,
- data_format=config.data_format,
- dtype=config.dtype,
- prefetch=config.prefetch)
- ds_validation = cifar_input.get_ds_from_tfrecords(
+ if FLAGS.validate:
+ # 40k Training set
+ ds_train = cifar_input.get_ds_from_tfrecords(
+ data_dir=FLAGS.data_dir,
+ split="train",
+ data_aug=True,
+ batch_size=config.batch_size,
+ epochs=config.epochs,
+ shuffle=config.shuffle,
+ data_format=config.data_format,
+ dtype=config.dtype,
+ prefetch=config.batch_size)
+ # 10k Training set
+ ds_validation = cifar_input.get_ds_from_tfrecords(
+ data_dir=FLAGS.data_dir,
+ split="validation",
+ data_aug=False,
+ batch_size=config.eval_batch_size,
+ epochs=1,
+ shuffle=False,
+ data_format=config.data_format,
+ dtype=config.dtype,
+ prefetch=config.eval_batch_size)
+ else:
+ # 50k Training set
+ ds_train = cifar_input.get_ds_from_tfrecords(
+ data_dir=FLAGS.data_dir,
+ split="train_all",
+ data_aug=True,
+ batch_size=config.batch_size,
+ epochs=config.epochs,
+ shuffle=config.shuffle,
+ data_format=config.data_format,
+ dtype=config.dtype,
+ prefetch=config.batch_size)
+
+ # Always compute loss and accuracy on whole training and test set
+ ds_train_one_shot = cifar_input.get_ds_from_tfrecords(
data_dir=FLAGS.data_dir,
- split="validation",
+ split="train_all",
data_aug=False,
batch_size=config.eval_batch_size,
epochs=1,
+ shuffle=False,
data_format=config.data_format,
dtype=config.dtype,
- prefetch=config.prefetch)
+ prefetch=config.eval_batch_size)
ds_test = cifar_input.get_ds_from_tfrecords(
data_dir=FLAGS.data_dir,
@@ -67,69 +95,116 @@ def main(_):
data_aug=False,
batch_size=config.eval_batch_size,
epochs=1,
+ shuffle=False,
data_format=config.data_format,
dtype=config.dtype,
- prefetch=config.prefetch)
+ prefetch=config.eval_batch_size)
+ model = revnet.RevNet(config=config)
global_step = tfe.Variable(1, trainable=False)
-
- def learning_rate(): # TODO(lxuechen): Remove once cl/201089859 is in place
- return tf.train.piecewise_constant(global_step, config.lr_decay_steps,
- config.lr_list)
-
- optimizer = tf.train.MomentumOptimizer(learning_rate, momentum=0.9)
- checkpoint = tf.train.Checkpoint(
+ learning_rate = tf.train.piecewise_constant(
+ global_step, config.lr_decay_steps, config.lr_list)
+ optimizer = tf.train.MomentumOptimizer(
+ learning_rate, momentum=config.momentum)
+ checkpointer = tf.train.Checkpoint(
optimizer=optimizer, model=model, optimizer_step=global_step)
if FLAGS.train_dir:
summary_writer = tf.contrib.summary.create_file_writer(FLAGS.train_dir)
if FLAGS.restore:
latest_path = tf.train.latest_checkpoint(FLAGS.train_dir)
- checkpoint.restore(latest_path)
+ checkpointer.restore(latest_path)
+      print("Restored latest checkpoint at path: \"{}\" "
+ "with global_step: {}".format(latest_path, global_step.numpy()))
+ sys.stdout.flush()
+
+ warmup(model, config)
for x, y in ds_train:
loss = train_one_iter(model, x, y, optimizer, global_step=global_step)
- if global_step % config.log_every == 0:
- it_validation = ds_validation.make_one_shot_iterator()
+ if global_step.numpy() % config.log_every == 0:
+ it_train = ds_train_one_shot.make_one_shot_iterator()
+ acc_train, loss_train = evaluate(model, it_train)
it_test = ds_test.make_one_shot_iterator()
- acc_validation = evaluate(model, it_validation)
- acc_test = evaluate(model, it_test)
- print("Iter {}, "
- "train loss {}, "
- "validation accuracy {}, "
- "test accuracy {}".format(global_step.numpy(), loss, acc_validation,
- acc_test))
+ acc_test, loss_test = evaluate(model, it_test)
+ if FLAGS.validate:
+ it_validation = ds_validation.make_one_shot_iterator()
+ acc_validation, loss_validation = evaluate(model, it_validation)
+ print("Iter {}, "
+ "training set accuracy {:.4f}, loss {:.4f}; "
+                "validation set accuracy {:.4f}, loss {:.4f}; "
+ "test accuracy {:.4f}, loss {:.4f}".format(
+ global_step.numpy(), acc_train, loss_train, acc_validation,
+ loss_validation, acc_test, loss_test))
+ else:
+ print("Iter {}, "
+ "training set accuracy {:.4f}, loss {:.4f}; "
+ "test accuracy {:.4f}, loss {:.4f}".format(
+ global_step.numpy(), acc_train, loss_train, acc_test,
+ loss_test))
+ sys.stdout.flush()
if FLAGS.train_dir:
with summary_writer.as_default():
with tf.contrib.summary.always_record_summaries():
- tf.contrib.summary.scalar("Validation accuracy", acc_validation)
- tf.contrib.summary.scalar("Test accuracy", acc_test)
tf.contrib.summary.scalar("Training loss", loss)
+ tf.contrib.summary.scalar("Test accuracy", acc_test)
+ if FLAGS.validate:
+ tf.contrib.summary.scalar("Validation accuracy", acc_validation)
if global_step.numpy() % config.save_every == 0 and FLAGS.train_dir:
- checkpoint.save(file_prefix=FLAGS.train_dir + "ckpt")
+ saved_path = checkpointer.save(
+ file_prefix=os.path.join(FLAGS.train_dir, "ckpt"))
+ print("Saved checkpoint at path: \"{}\" "
+ "with global_step: {}".format(saved_path, global_step.numpy()))
+ sys.stdout.flush()
+
+def warmup(model, config, steps=1):
+ mock_input = tf.random_normal((config.batch_size,) + config.input_shape)
+ for _ in range(steps):
+ model(mock_input, training=False)
-def train_one_iter(model, inputs, labels, optimizer, global_step=None):
+
+def train_one_iter(model,
+ inputs,
+ labels,
+ optimizer,
+ global_step=None,
+ verbose=False):
"""Train for one iteration."""
- grads, vars_, loss = model.compute_gradients(inputs, labels, training=True)
- optimizer.apply_gradients(zip(grads, vars_), global_step=global_step)
+ if FLAGS.manual_grad:
+ if verbose:
+ print("Using manual gradients")
+ grads, vars_, loss = model.compute_gradients(inputs, labels)
+ optimizer.apply_gradients(zip(grads, vars_), global_step=global_step)
+ else: # For correctness validation
+ if verbose:
+ print("Not using manual gradients")
+ with tf.GradientTape() as tape:
+ logits, _ = model(inputs, training=True)
+ loss = model.compute_loss(logits=logits, labels=labels)
+ grads = tape.gradient(loss, model.trainable_variables)
+ optimizer.apply_gradients(
+ zip(grads, model.trainable_variables), global_step=global_step)
return loss.numpy()
def evaluate(model, iterator):
"""Compute accuracy with the given dataset iterator."""
+ mean_loss = tfe.metrics.Mean()
accuracy = tfe.metrics.Accuracy()
- for x, y in iterator:
+ for x, y in tqdm(iterator):
logits, _ = model(x, training=False)
+ loss = model.compute_loss(logits=logits, labels=y)
accuracy(
labels=tf.cast(y, tf.int64),
predictions=tf.argmax(logits, axis=1, output_type=tf.int64))
+ mean_loss(loss)
- return accuracy.result().numpy()
+ return accuracy.result().numpy(), mean_loss.result().numpy()
if __name__ == "__main__":
@@ -138,10 +213,18 @@ if __name__ == "__main__":
default=None,
help="[Optional] Directory to store the training information")
flags.DEFINE_string(
- "data_dir", default=None, help="Directory to load tfrecords.")
+ "data_dir", default=None, help="Directory to load tfrecords")
flags.DEFINE_boolean(
"restore",
- default=True,
+ default=False,
help="[Optional] Restore the latest checkpoint from `train_dir` if True")
+ flags.DEFINE_boolean(
+ "validate",
+ default=False,
+ help="[Optional] Use the validation set or not for hyperparameter search")
+ flags.DEFINE_boolean(
+ "manual_grad",
+ default=False,
+ help="[Optional] Use manual gradient graph to save memory")
FLAGS = flags.FLAGS
tf.app.run(main)
diff --git a/tensorflow/contrib/eager/python/examples/revnet/revnet.py b/tensorflow/contrib/eager/python/examples/revnet/revnet.py
index b3b8c262b1..0228bff6fa 100644
--- a/tensorflow/contrib/eager/python/examples/revnet/revnet.py
+++ b/tensorflow/contrib/eager/python/examples/revnet/revnet.py
@@ -27,6 +27,7 @@ from __future__ import print_function
import functools
import operator
+import six
import tensorflow as tf
from tensorflow.contrib.eager.python.examples.revnet import blocks
@@ -47,6 +48,7 @@ class RevNet(tf.keras.Model):
self._init_block = self._construct_init_block()
self._block_list = self._construct_intermediate_blocks()
self._final_block = self._construct_final_block()
+ self._moving_stats_vars = None
def _construct_init_block(self):
init_block = tf.keras.Sequential(
@@ -153,7 +155,6 @@ class RevNet(tf.keras.Model):
def call(self, inputs, training=True):
"""Forward pass."""
- # Only store hidden states during training
if training:
saved_hidden = [inputs]
@@ -181,17 +182,22 @@ class RevNet(tf.keras.Model):
def compute_gradients(self, inputs, labels, training=True):
"""Manually computes gradients.
+ This method also SILENTLY updates the running averages of batch
+ normalization when `training` is set to True.
+
Args:
inputs: Image tensor, either NHWC or NCHW, conforming to `data_format`
labels: One-hot labels for classification
- training: for batch normalization
+ training: Use the mini-batch stats in batch norm if set to True
Returns:
- list of tuple each being (grad, var) for optimizer use
+ list of tuples each being (grad, var) for optimizer to use
"""
- # Forward pass record hidden states before downsampling
+ # Run forward pass to record hidden states; avoid updating running averages
+ vars_and_vals = self.get_moving_stats()
_, saved_hidden = self.call(inputs, training=training)
+ self.restore_moving_stats(vars_and_vals)
grads_all = []
vars_all = []
@@ -201,6 +207,7 @@ class RevNet(tf.keras.Model):
with tf.GradientTape() as tape:
x = tf.identity(x) # TODO(lxuechen): Remove after b/110264016 is fixed
tape.watch(x)
+ # Running stats updated below
logits = self._final_block(x, training=training)
loss = self.compute_loss(logits, labels)
@@ -226,16 +233,38 @@ class RevNet(tf.keras.Model):
with tf.GradientTape() as tape:
x = tf.identity(x) # TODO(lxuechen): Remove after b/110264016 is fixed
+ # Running stats updated below
y = self._init_block(x, training=training)
grads_all += tape.gradient(
y, self._init_block.trainable_variables, output_gradients=[dy])
vars_all += self._init_block.trainable_variables
+ # Apply weight decay
grads_all = self._apply_weight_decay(grads_all, vars_all)
return grads_all, vars_all, loss
def _apply_weight_decay(self, grads, vars_):
"""Update gradients to reflect weight decay."""
- return [g + self.config.weight_decay * v for g, v in zip(grads, vars_)]
+ # Don't decay bias
+ return [
+ g + self.config.weight_decay * v if v.name.endswith("kernel:0") else g
+ for g, v in zip(grads, vars_)
+ ]
+
+ def get_moving_stats(self):
+ vars_and_vals = {}
+
+ def _is_moving_var(v):
+ n = v.name
+ return n.endswith("moving_mean:0") or n.endswith("moving_variance:0")
+
+ for v in filter(_is_moving_var, self.variables):
+ vars_and_vals[v] = v.read_value()
+
+ return vars_and_vals
+
+ def restore_moving_stats(self, vars_and_vals):
+ for var_, val in six.iteritems(vars_and_vals):
+ var_.assign(val)
diff --git a/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py b/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py
index cb3bac13f9..a5f240436a 100644
--- a/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py
+++ b/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py
@@ -36,10 +36,11 @@ def train_one_iter(model, inputs, labels, optimizer, global_step=None):
return loss
-class RevnetTest(tf.test.TestCase):
+class RevNetTest(tf.test.TestCase):
def setUp(self):
- super(RevnetTest, self).setUp()
+ super(RevNetTest, self).setUp()
+ tf.set_random_seed(1)
config = config_.get_hparams_imagenet_56()
shape = (config.batch_size,) + config.input_shape
self.model = revnet.RevNet(config=config)
@@ -56,7 +57,7 @@ class RevnetTest(tf.test.TestCase):
del self.x
del self.t
del self.config
- super(RevnetTest, self).tearDown()
+ super(RevNetTest, self).tearDown()
def test_call(self):
"""Test `call` function."""
@@ -67,7 +68,8 @@ class RevnetTest(tf.test.TestCase):
def test_compute_gradients(self):
"""Test `compute_gradients` function."""
- grads, vars_, _ = self.model.compute_gradients(inputs=self.x, labels=self.t)
+ grads, vars_, _ = self.model.compute_gradients(
+ inputs=self.x, labels=self.t, training=True)
self.assertTrue(isinstance(grads, list))
self.assertTrue(isinstance(vars_, list))
self.assertEqual(len(grads), len(vars_))
@@ -84,7 +86,7 @@ class RevnetTest(tf.test.TestCase):
def test_compute_gradients_defun(self):
"""Test `compute_gradients` function with defun."""
compute_gradients = tfe.defun(self.model.compute_gradients)
- grads, vars_, _ = compute_gradients(self.x, self.t)
+ grads, vars_, _ = compute_gradients(self.x, self.t, training=True)
self.assertTrue(isinstance(grads, list))
self.assertTrue(isinstance(vars_, list))
self.assertEqual(len(grads), len(vars_))
@@ -144,7 +146,7 @@ class MockIterator(object):
return self._tensors
-class RevnetBenchmark(tf.test.Benchmark):
+class RevNetBenchmark(tf.test.Benchmark):
"""Eager and graph benchmarks for RevNet."""
def _train_batch_sizes(self):
diff --git a/tensorflow/contrib/estimator/python/estimator/dnn.py b/tensorflow/contrib/estimator/python/estimator/dnn.py
index f1c60a912c..4bb90cf81b 100644
--- a/tensorflow/contrib/estimator/python/estimator/dnn.py
+++ b/tensorflow/contrib/estimator/python/estimator/dnn.py
@@ -53,6 +53,18 @@ class DNNEstimator(estimator.Estimator):
l1_regularization_strength=0.001
))
+ # Or estimator using an optimizer with a learning rate decay.
+ estimator = DNNEstimator(
+ head=tf.contrib.estimator.multi_label_head(n_classes=3),
+ feature_columns=[sparse_feature_a_emb, sparse_feature_b_emb],
+ hidden_units=[1024, 512, 256],
+      optimizer=lambda: tf.train.AdamOptimizer(
+          learning_rate=tf.train.exponential_decay(
+ learning_rate=0.1,
+ global_step=tf.get_global_step(),
+ decay_steps=10000,
+ decay_rate=0.96))
+
# Or estimator with warm-starting from a previous checkpoint.
estimator = DNNEstimator(
head=tf.contrib.estimator.multi_label_head(n_classes=3),
@@ -115,8 +127,9 @@ class DNNEstimator(estimator.Estimator):
model_dir: Directory to save model parameters, graph and etc. This can
also be used to load checkpoints from the directory into a estimator to
continue training a previously saved model.
- optimizer: An instance of `tf.Optimizer` used to train the model. Defaults
- to Adagrad optimizer.
+ optimizer: An instance of `tf.Optimizer` used to train the model. Can also
+ be a string (one of 'Adagrad', 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or
+ callable. Defaults to Adagrad optimizer.
activation_fn: Activation function applied to each layer. If `None`, will
use `tf.nn.relu`.
dropout: When not `None`, the probability we will drop out a given
diff --git a/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined.py b/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined.py
index ccaf1128bf..894a295498 100644
--- a/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined.py
+++ b/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined.py
@@ -53,12 +53,19 @@ class DNNLinearCombinedEstimator(estimator.Estimator):
dnn_hidden_units=[1000, 500, 100],
dnn_optimizer=tf.train.ProximalAdagradOptimizer(...))
- # To apply L1 and L2 regularization, you can set optimizers as follows:
+ # To apply L1 and L2 regularization, you can set dnn_optimizer to:
tf.train.ProximalAdagradOptimizer(
learning_rate=0.1,
l1_regularization_strength=0.001,
l2_regularization_strength=0.001)
- # It is same for FtrlOptimizer.
+ # To apply learning rate decay, you can set dnn_optimizer to a callable:
+  lambda: tf.train.AdamOptimizer(
+      learning_rate=tf.train.exponential_decay(
+          learning_rate=0.1,
+          global_step=tf.get_global_step(),
+          decay_steps=10000,
+          decay_rate=0.96))
+ # It is the same for linear_optimizer.
# Input builders
def input_fn_train: # returns x, y
@@ -116,12 +123,16 @@ class DNNLinearCombinedEstimator(estimator.Estimator):
used by linear part of the model. All items in the set must be
instances of classes derived from `FeatureColumn`.
linear_optimizer: An instance of `tf.Optimizer` used to apply gradients to
- the linear part of the model. Defaults to FTRL optimizer.
+ the linear part of the model. Can also be a string (one of 'Adagrad',
+ 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or callable. Defaults to FTRL
+ optimizer.
dnn_feature_columns: An iterable containing all the feature columns used
by deep part of the model. All items in the set must be instances of
classes derived from `FeatureColumn`.
dnn_optimizer: An instance of `tf.Optimizer` used to apply gradients to
- the deep part of the model. Defaults to Adagrad optimizer.
+ the deep part of the model. Can also be a string (one of 'Adagrad',
+ 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or callable. Defaults to Adagrad
+ optimizer.
dnn_hidden_units: List of hidden units per layer. All layers are fully
connected.
dnn_activation_fn: Activation function applied to each layer. If None,
diff --git a/tensorflow/contrib/estimator/python/estimator/linear.py b/tensorflow/contrib/estimator/python/estimator/linear.py
index 3bf4abe83d..b960b16f1b 100644
--- a/tensorflow/contrib/estimator/python/estimator/linear.py
+++ b/tensorflow/contrib/estimator/python/estimator/linear.py
@@ -39,6 +39,18 @@ class LinearEstimator(estimator.Estimator):
feature_columns=[categorical_column_a,
categorical_feature_a_x_categorical_feature_b])
+ # Or estimator using an optimizer with a learning rate decay.
+ estimator = LinearEstimator(
+ head=tf.contrib.estimator.multi_label_head(n_classes=3),
+ feature_columns=[categorical_column_a,
+ categorical_feature_a_x_categorical_feature_b],
+ optimizer=lambda: tf.train.FtrlOptimizer(
+      learning_rate=tf.train.exponential_decay(
+ learning_rate=0.1,
+ global_step=tf.get_global_step(),
+ decay_steps=10000,
+ decay_rate=0.96))
+
# Or estimator using the FTRL optimizer with regularization.
estimator = LinearEstimator(
head=tf.contrib.estimator.multi_label_head(n_classes=3),
@@ -99,8 +111,9 @@ class LinearEstimator(estimator.Estimator):
model_dir: Directory to save model parameters, graph and etc. This can
also be used to load checkpoints from the directory into a estimator
to continue training a previously saved model.
- optimizer: An instance of `tf.Optimizer` used to train the model. Defaults
- to FTRL optimizer.
+ optimizer: An instance of `tf.Optimizer` used to train the model. Can also
+ be a string (one of 'Adagrad', 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or
+ callable. Defaults to FTRL optimizer.
config: `RunConfig` object to configure the runtime settings.
partitioner: Optional. Partitioner for input layer.
"""
diff --git a/tensorflow/contrib/lite/java/demo/app/build.gradle b/tensorflow/contrib/lite/java/demo/app/build.gradle
index 44ea2dcd90..192162cfce 100644
--- a/tensorflow/contrib/lite/java/demo/app/build.gradle
+++ b/tensorflow/contrib/lite/java/demo/app/build.gradle
@@ -5,7 +5,8 @@ android {
buildToolsVersion "26.0.1"
defaultConfig {
applicationId "android.example.com.tflitecamerademo"
- minSdkVersion 15
+ // Required by Camera2 API.
+ minSdkVersion 21
targetSdkVersion 26
versionCode 1
versionName "1.0"
diff --git a/tensorflow/contrib/lite/python/lite.py b/tensorflow/contrib/lite/python/lite.py
index 69a2f638af..a4229f91f5 100644
--- a/tensorflow/contrib/lite/python/lite.py
+++ b/tensorflow/contrib/lite/python/lite.py
@@ -50,6 +50,7 @@ from tensorflow.contrib.lite.python.interpreter import Interpreter # pylint: di
from tensorflow.contrib.lite.python.op_hint import convert_op_hints_to_stubs # pylint: disable=unused-import
from tensorflow.contrib.lite.python.op_hint import OpHint # pylint: disable=unused-import
from tensorflow.core.framework import graph_pb2 as _graph_pb2
+from tensorflow.python import keras as _keras
from tensorflow.python.client import session as _session
from tensorflow.python.framework import graph_util as tf_graph_util
from tensorflow.python.framework.importer import import_graph_def
@@ -269,6 +270,48 @@ class TocoConverter(object):
return cls(
graph_def=result[0], input_tensors=result[1], output_tensors=result[2])
+ @classmethod
+ def from_keras_model_file(cls,
+ model_file,
+ input_arrays=None,
+ input_shapes=None,
+ output_arrays=None):
+ """Creates a TocoConverter class from a tf.keras model file.
+
+ Args:
+ model_file: Full filepath of HDF5 file containing the tf.keras model.
+ input_arrays: List of input tensors to freeze graph with. Uses input
+ arrays from the Keras model when none are provided. (default None)
+ input_shapes: Dict of strings representing input tensor names to list of
+ integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}).
+ Automatically determined when input shapes is None (e.g., {"foo" :
+ None}). (default None)
+ output_arrays: List of output tensors to freeze graph with. Uses output
+ arrays from the Keras model when none are provided. (default None)
+
+ Returns:
+ TocoConverter class.
+ """
+ _keras.backend.clear_session()
+ _keras.backend.set_learning_phase(False)
+ keras_model = _keras.models.load_model(model_file)
+ sess = _keras.backend.get_session()
+
+ # Get input and output tensors.
+ if input_arrays:
+ input_tensors = get_tensors_from_tensor_names(sess.graph, input_arrays)
+ else:
+ input_tensors = keras_model.inputs
+
+ if output_arrays:
+ output_tensors = get_tensors_from_tensor_names(sess.graph, output_arrays)
+ else:
+ output_tensors = keras_model.outputs
+ set_tensor_shapes(input_tensors, input_shapes)
+
+ graph_def = _freeze_graph(sess, output_tensors)
+ return cls(graph_def, input_tensors, output_tensors)
+
def convert(self):
"""Converts a TensorFlow GraphDef based on instance variables.
@@ -366,7 +409,7 @@ def _is_frozen_graph(sess):
Bool.
"""
for op in sess.graph.get_operations():
- if op.type.startswith("Variable"):
+ if op.type.startswith("Variable") or op.type.endswith("VariableOp"):
return False
return True
diff --git a/tensorflow/contrib/lite/python/lite_test.py b/tensorflow/contrib/lite/python/lite_test.py
index a9475de474..ca2af5aaed 100644
--- a/tensorflow/contrib/lite/python/lite_test.py
+++ b/tensorflow/contrib/lite/python/lite_test.py
@@ -19,11 +19,13 @@ from __future__ import division
from __future__ import print_function
import os
+import tempfile
import numpy as np
from tensorflow.contrib.lite.python import lite
from tensorflow.contrib.lite.python import lite_constants
from tensorflow.contrib.lite.python.interpreter import Interpreter
+from tensorflow.python import keras
from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -618,5 +620,279 @@ class FromSavedModelTest(test_util.TensorFlowTestCase):
self.assertTrue(tflite_model)
+class FromKerasFile(test_util.TensorFlowTestCase):
+
+ def setUp(self):
+ keras.backend.clear_session()
+
+ def _getSequentialModel(self):
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(2, input_shape=(3,)))
+ model.add(keras.layers.RepeatVector(3))
+ model.add(keras.layers.TimeDistributed(keras.layers.Dense(3)))
+ model.compile(
+ loss=keras.losses.MSE,
+ optimizer=keras.optimizers.RMSprop(),
+ metrics=[keras.metrics.categorical_accuracy],
+ sample_weight_mode='temporal')
+ x = np.random.random((1, 3))
+ y = np.random.random((1, 3, 3))
+ model.train_on_batch(x, y)
+ model.predict(x)
+
+ try:
+ fd, keras_file = tempfile.mkstemp('.h5')
+ keras.models.save_model(model, keras_file)
+ finally:
+ os.close(fd)
+ return keras_file
+
+ def testSequentialModel(self):
+ """Test a Sequential tf.keras model with default inputs."""
+ keras_file = self._getSequentialModel()
+
+ converter = lite.TocoConverter.from_keras_model_file(keras_file)
+ tflite_model = converter.convert()
+ self.assertTrue(tflite_model)
+
+ os.remove(keras_file)
+
+ # Check values from converted model.
+ interpreter = Interpreter(model_content=tflite_model)
+ interpreter.allocate_tensors()
+
+ input_details = interpreter.get_input_details()
+ self.assertEqual(1, len(input_details))
+ self.assertEqual('dense_input', input_details[0]['name'])
+ self.assertEqual(np.float32, input_details[0]['dtype'])
+ self.assertTrue(([1, 3] == input_details[0]['shape']).all())
+ self.assertEqual((0., 0.), input_details[0]['quantization'])
+
+ output_details = interpreter.get_output_details()
+ self.assertEqual(1, len(output_details))
+ self.assertEqual('time_distributed/Reshape_1', output_details[0]['name'])
+ self.assertEqual(np.float32, output_details[0]['dtype'])
+ self.assertTrue(([1, 3, 3] == output_details[0]['shape']).all())
+ self.assertEqual((0., 0.), output_details[0]['quantization'])
+
+ def testSequentialModelInputArray(self):
+ """Test a Sequential tf.keras model testing input arrays argument."""
+ keras_file = self._getSequentialModel()
+
+ # Invalid input array raises error.
+ with self.assertRaises(ValueError) as error:
+ lite.TocoConverter.from_keras_model_file(
+ keras_file, input_arrays=['invalid-input'])
+ self.assertEqual("Invalid tensors 'invalid-input' were found.",
+ str(error.exception))
+
+ # Valid input array.
+ converter = lite.TocoConverter.from_keras_model_file(
+ keras_file, input_arrays=['dense_input'])
+ tflite_model = converter.convert()
+ os.remove(keras_file)
+ self.assertTrue(tflite_model)
+
+ def testSequentialModelInputShape(self):
+ """Test a Sequential tf.keras model testing input shapes argument."""
+ keras_file = self._getSequentialModel()
+
+ # Passing in shape of invalid input array has no impact as long as all input
+ # arrays have a shape.
+ converter = lite.TocoConverter.from_keras_model_file(
+ keras_file, input_shapes={'invalid-input': [2, 3]})
+ tflite_model = converter.convert()
+ self.assertTrue(tflite_model)
+
+ # Passing in shape of valid input array.
+ converter = lite.TocoConverter.from_keras_model_file(
+ keras_file, input_shapes={'dense_input': [2, 3]})
+ tflite_model = converter.convert()
+ os.remove(keras_file)
+ self.assertTrue(tflite_model)
+
+ # Check input shape from converted model.
+ interpreter = Interpreter(model_content=tflite_model)
+ interpreter.allocate_tensors()
+
+ input_details = interpreter.get_input_details()
+ self.assertEqual(1, len(input_details))
+ self.assertEqual('dense_input', input_details[0]['name'])
+ self.assertTrue(([2, 3] == input_details[0]['shape']).all())
+
+ def testSequentialModelOutputArray(self):
+ """Test a Sequential tf.keras model testing output arrays argument."""
+ keras_file = self._getSequentialModel()
+
+ # Invalid output array raises error.
+ with self.assertRaises(ValueError) as error:
+ lite.TocoConverter.from_keras_model_file(
+ keras_file, output_arrays=['invalid-output'])
+ self.assertEqual("Invalid tensors 'invalid-output' were found.",
+ str(error.exception))
+
+ # Valid output array.
+ converter = lite.TocoConverter.from_keras_model_file(
+ keras_file, output_arrays=['time_distributed/Reshape_1'])
+ tflite_model = converter.convert()
+ os.remove(keras_file)
+ self.assertTrue(tflite_model)
+
+ def testFunctionalModel(self):
+ """Test a Functional tf.keras model with default inputs."""
+ inputs = keras.layers.Input(shape=(3,), name='input')
+ x = keras.layers.Dense(2)(inputs)
+ output = keras.layers.Dense(3)(x)
+
+ model = keras.models.Model(inputs, output)
+ model.compile(
+ loss=keras.losses.MSE,
+ optimizer=keras.optimizers.RMSprop(),
+ metrics=[keras.metrics.categorical_accuracy])
+ x = np.random.random((1, 3))
+ y = np.random.random((1, 3))
+ model.train_on_batch(x, y)
+
+ model.predict(x)
+ fd, keras_file = tempfile.mkstemp('.h5')
+ keras.models.save_model(model, keras_file)
+
+ # Convert to TFLite model.
+ converter = lite.TocoConverter.from_keras_model_file(keras_file)
+ tflite_model = converter.convert()
+ self.assertTrue(tflite_model)
+
+ os.close(fd)
+ os.remove(keras_file)
+
+ # Check values from converted model.
+ interpreter = Interpreter(model_content=tflite_model)
+ interpreter.allocate_tensors()
+
+ input_details = interpreter.get_input_details()
+ self.assertEqual(1, len(input_details))
+ self.assertEqual('input', input_details[0]['name'])
+ self.assertEqual(np.float32, input_details[0]['dtype'])
+ self.assertTrue(([1, 3] == input_details[0]['shape']).all())
+ self.assertEqual((0., 0.), input_details[0]['quantization'])
+
+ output_details = interpreter.get_output_details()
+ self.assertEqual(1, len(output_details))
+ self.assertEqual('dense_1/BiasAdd', output_details[0]['name'])
+ self.assertEqual(np.float32, output_details[0]['dtype'])
+ self.assertTrue(([1, 3] == output_details[0]['shape']).all())
+ self.assertEqual((0., 0.), output_details[0]['quantization'])
+
+ def testFunctionalModelMultipleInputs(self):
+ """Test a Functional tf.keras model with multiple inputs and outputs."""
+ a = keras.layers.Input(shape=(3,), name='input_a')
+ b = keras.layers.Input(shape=(3,), name='input_b')
+ dense = keras.layers.Dense(4, name='dense')
+ c = dense(a)
+ d = dense(b)
+ e = keras.layers.Dropout(0.5, name='dropout')(c)
+
+ model = keras.models.Model([a, b], [d, e])
+ model.compile(
+ loss=keras.losses.MSE,
+ optimizer=keras.optimizers.RMSprop(),
+ metrics=[keras.metrics.mae],
+ loss_weights=[1., 0.5])
+
+ input_a_np = np.random.random((10, 3))
+ input_b_np = np.random.random((10, 3))
+ output_d_np = np.random.random((10, 4))
+ output_e_np = np.random.random((10, 4))
+ model.train_on_batch([input_a_np, input_b_np], [output_d_np, output_e_np])
+
+ model.predict([input_a_np, input_b_np], batch_size=5)
+ fd, keras_file = tempfile.mkstemp('.h5')
+ keras.models.save_model(model, keras_file)
+
+ # Convert to TFLite model.
+ converter = lite.TocoConverter.from_keras_model_file(keras_file)
+ tflite_model = converter.convert()
+ self.assertTrue(tflite_model)
+
+ os.close(fd)
+ os.remove(keras_file)
+
+ # Check values from converted model.
+ interpreter = Interpreter(model_content=tflite_model)
+ interpreter.allocate_tensors()
+
+ input_details = interpreter.get_input_details()
+ self.assertEqual(2, len(input_details))
+ self.assertEqual('input_a', input_details[0]['name'])
+ self.assertEqual(np.float32, input_details[0]['dtype'])
+ self.assertTrue(([1, 3] == input_details[0]['shape']).all())
+ self.assertEqual((0., 0.), input_details[0]['quantization'])
+
+ self.assertEqual('input_b', input_details[1]['name'])
+ self.assertEqual(np.float32, input_details[1]['dtype'])
+ self.assertTrue(([1, 3] == input_details[1]['shape']).all())
+ self.assertEqual((0., 0.), input_details[1]['quantization'])
+
+ output_details = interpreter.get_output_details()
+ self.assertEqual(2, len(output_details))
+ self.assertEqual('dense_1/BiasAdd', output_details[0]['name'])
+ self.assertEqual(np.float32, output_details[0]['dtype'])
+ self.assertTrue(([1, 4] == output_details[0]['shape']).all())
+ self.assertEqual((0., 0.), output_details[0]['quantization'])
+
+ self.assertEqual('dropout/Identity', output_details[1]['name'])
+ self.assertEqual(np.float32, output_details[1]['dtype'])
+ self.assertTrue(([1, 4] == output_details[1]['shape']).all())
+ self.assertEqual((0., 0.), output_details[1]['quantization'])
+
+ def testFunctionalSequentialModel(self):
+ """Test a Functional tf.keras model containing a Sequential model."""
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(2, input_shape=(3,)))
+ model.add(keras.layers.RepeatVector(3))
+ model.add(keras.layers.TimeDistributed(keras.layers.Dense(3)))
+ model = keras.models.Model(model.input, model.output)
+
+ model.compile(
+ loss=keras.losses.MSE,
+ optimizer=keras.optimizers.RMSprop(),
+ metrics=[keras.metrics.categorical_accuracy],
+ sample_weight_mode='temporal')
+ x = np.random.random((1, 3))
+ y = np.random.random((1, 3, 3))
+ model.train_on_batch(x, y)
+ model.predict(x)
+
+ model.predict(x)
+ fd, keras_file = tempfile.mkstemp('.h5')
+ keras.models.save_model(model, keras_file)
+
+ # Convert to TFLite model.
+ converter = lite.TocoConverter.from_keras_model_file(keras_file)
+ tflite_model = converter.convert()
+ self.assertTrue(tflite_model)
+
+ os.close(fd)
+ os.remove(keras_file)
+
+ # Check values from converted model.
+ interpreter = Interpreter(model_content=tflite_model)
+ interpreter.allocate_tensors()
+
+ input_details = interpreter.get_input_details()
+ self.assertEqual(1, len(input_details))
+ self.assertEqual('dense_input', input_details[0]['name'])
+ self.assertEqual(np.float32, input_details[0]['dtype'])
+ self.assertTrue(([1, 3] == input_details[0]['shape']).all())
+ self.assertEqual((0., 0.), input_details[0]['quantization'])
+
+ output_details = interpreter.get_output_details()
+ self.assertEqual(1, len(output_details))
+ self.assertEqual('time_distributed/Reshape_1', output_details[0]['name'])
+ self.assertEqual(np.float32, output_details[0]['dtype'])
+ self.assertTrue(([1, 3, 3] == output_details[0]['shape']).all())
+ self.assertEqual((0., 0.), output_details[0]['quantization'])
+
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/lite/python/tflite_convert.py b/tensorflow/contrib/lite/python/tflite_convert.py
index d18a29834b..249b940f92 100644
--- a/tensorflow/contrib/lite/python/tflite_convert.py
+++ b/tensorflow/contrib/lite/python/tflite_convert.py
@@ -74,6 +74,9 @@ def _get_toco_converter(flags):
converter_kwargs["saved_model_dir"] = flags.saved_model_dir
converter_kwargs["tag_set"] = _parse_set(flags.saved_model_tag_set)
converter_kwargs["signature_key"] = flags.saved_model_signature_key
+ elif flags.keras_model_file:
+ converter_fn = lite.TocoConverter.from_keras_model_file
+ converter_kwargs["model_file"] = flags.keras_model_file
return converter_fn(**converter_kwargs)
@@ -227,6 +230,10 @@ def run_main(_):
"--saved_model_dir",
type=str,
help="Full filepath of directory containing the SavedModel.")
+ input_file_group.add_argument(
+ "--keras_model_file",
+ type=str,
+ help="Full filepath of HDF5 file containing tf.Keras model.")
# Model format flags.
parser.add_argument(
diff --git a/tensorflow/contrib/lite/toco/g3doc/python_api.md b/tensorflow/contrib/lite/toco/g3doc/python_api.md
index afa6fd6957..b04d166f89 100644
--- a/tensorflow/contrib/lite/toco/g3doc/python_api.md
+++ b/tensorflow/contrib/lite/toco/g3doc/python_api.md
@@ -15,6 +15,7 @@ Table of contents:
* [Exporting a GraphDef from tf.Session](#basic-graphdef-sess)
* [Exporting a GraphDef from file](#basic-graphdef-file)
* [Exporting a SavedModel](#basic-savedmodel)
+ * [Exporting a tf.keras File](#basic-keras-file)
* [Complex examples](#complex)
* [Exporting a quantized GraphDef](#complex-quant)
* [TensorFlow Lite Python interpreter](#interpreter)
@@ -114,6 +115,51 @@ For more complex SavedModels, the optional parameters that can be passed into
`output_arrays`, `tag_set` and `signature_key`. Details of each parameter are
available by running `help(tf.contrib.lite.TocoConverter)`.
+### Exporting a tf.keras File <a name="basic-keras-file"></a>
+
+The following example shows how to convert a tf.keras model into a TensorFlow
+Lite FlatBuffer.
+
+```python
+import tensorflow as tf
+
+converter = tf.contrib.lite.TocoConverter.from_keras_model_file("keras_model.h5")
+tflite_model = converter.convert()
+open("converted_model.tflite", "wb").write(tflite_model)
+```
+
+The tf.keras file must contain both the model and the weights. A comprehensive
+example including model construction can be seen below.
+
+```python
+import numpy as np
+import tensorflow as tf
+
+# Generate tf.keras model.
+model = tf.keras.models.Sequential()
+model.add(tf.keras.layers.Dense(2, input_shape=(3,)))
+model.add(tf.keras.layers.RepeatVector(3))
+model.add(tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(3)))
+model.compile(loss=tf.keras.losses.MSE,
+ optimizer=tf.keras.optimizers.RMSprop(lr=0.0001),
+ metrics=[tf.keras.metrics.categorical_accuracy],
+ sample_weight_mode='temporal')
+
+x = np.random.random((1, 3))
+y = np.random.random((1, 3, 3))
+model.train_on_batch(x, y)
+model.predict(x)
+
+# Save tf.keras model in HDF5 format.
+keras_file = "keras_model.h5"
+tf.keras.models.save_model(model, keras_file)
+
+# Convert to TensorFlow Lite model.
+converter = tf.contrib.lite.TocoConverter.from_keras_model_file(keras_file)
+tflite_model = converter.convert()
+open("converted_model.tflite", "wb").write(tflite_model)
+```
+
## Complex examples <a name="complex"></a>
For models where the default value of the attributes is not sufficient, the
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc
index da7e5add7e..485e853e25 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc
@@ -378,7 +378,7 @@ tensorflow::Status ImportBoolArray(const TensorProto& input_tensor,
for (int i = 0; i < input_flat_size; i++) {
output_bool_data[i] = input_tensor.bool_val(0);
}
- } else if (input_tensor.int_val_size() == input_flat_size) {
+ } else if (input_tensor.bool_val_size() == input_flat_size) {
for (int i = 0; i < input_tensor.bool_val_size(); i++) {
output_bool_data[i] = input_tensor.bool_val(i);
}
diff --git a/tensorflow/contrib/summary/summary_ops_test.py b/tensorflow/contrib/summary/summary_ops_test.py
index f1ef218e74..3e41e3d0b4 100644
--- a/tensorflow/contrib/summary/summary_ops_test.py
+++ b/tensorflow/contrib/summary/summary_ops_test.py
@@ -81,6 +81,19 @@ class EagerFileTest(test_util.TensorFlowTestCase):
# test here that we're calling them correctly.
self.assertTrue(gfile.Exists(logdir))
+ @test_util.assert_no_new_pyobjects_executing_eagerly
+ def testEagerMemory(self):
+ training_util.get_or_create_global_step()
+ logdir = self.get_temp_dir()
+ with summary_ops.create_file_writer(
+ logdir, max_queue=0,
+ name='t0').as_default(), summary_ops.always_record_summaries():
+ summary_ops.generic('tensor', 1, '')
+ summary_ops.scalar('scalar', 2.0)
+ summary_ops.histogram('histogram', [1.0])
+ summary_ops.image('image', [[[[1.0]]]])
+ summary_ops.audio('audio', [[1.0]], 1.0, 1)
+
def testDefunSummarys(self):
training_util.get_or_create_global_step()
logdir = tempfile.mkdtemp()
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 59e76cb575..0e41170367 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -793,6 +793,7 @@ tf_cuda_library(
"framework/graph_def_util.h",
"framework/graph_to_functiondef.h",
"framework/kernel_def_builder.h",
+ "framework/kernel_def_util.h",
"framework/log_memory.h",
"framework/lookup_interface.h",
"framework/memory_types.h",
@@ -1198,6 +1199,7 @@ tf_cuda_library(
hdrs = [
"common_runtime/device.h",
"common_runtime/device_factory.h",
+ "common_runtime/function.h",
"common_runtime/optimization_registry.h",
"common_runtime/shape_refiner.h",
"graph/algorithm.h",
@@ -3377,6 +3379,7 @@ tf_cc_tests(
"framework/graph_def_util_test.cc",
"framework/graph_to_functiondef_test.cc",
"framework/kernel_def_builder_test.cc",
+ "framework/kernel_def_util_test.cc",
"framework/memory_types_test.cc",
"framework/node_def_builder_test.cc",
"framework/node_def_util_test.cc",
diff --git a/tensorflow/core/framework/kernel_def_util.cc b/tensorflow/core/framework/kernel_def_util.cc
new file mode 100644
index 0000000000..bbd3dd3e57
--- /dev/null
+++ b/tensorflow/core/framework/kernel_def_util.cc
@@ -0,0 +1,83 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/kernel_def_util.h"
+
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/attr_value_util.h"
+#include "tensorflow/core/framework/kernel_def.pb_text.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/types.h"
+
+namespace tensorflow {
+
+namespace {
+// Helper for KernelAttrsMatch().
+bool InTypeList(DataType dt, const AttrValue& type_list) {
+ for (int in_list : type_list.list().type()) {
+ if (dt == in_list) return true;
+ }
+ return false;
+}
+} // namespace
+
+Status KernelAttrsMatch(const KernelDef& kernel_def, AttrSlice attrs,
+ bool* match) {
+ *match = false;
+ for (const auto& constraint : kernel_def.constraint()) {
+ if (constraint.allowed_values().list().type_size() == 0) {
+ return errors::Unimplemented(
+ "KernelDef '", ProtoShortDebugString(kernel_def),
+ " has constraint on attr '", constraint.name(),
+ "' with unsupported type: ",
+ SummarizeAttrValue(constraint.allowed_values()));
+ }
+
+ const AttrValue* found = attrs.Find(constraint.name());
+ if (found) {
+ if (found->type() != DT_INVALID) {
+ if (!InTypeList(found->type(), constraint.allowed_values())) {
+ return Status::OK();
+ }
+ } else {
+ if (!AttrValueHasType(*found, "list(type)").ok()) {
+ return errors::InvalidArgument(
+ "KernelDef '", ProtoShortDebugString(kernel_def),
+ "' has constraint on attr '", constraint.name(),
+ "' that has value '", SummarizeAttrValue(*found),
+ "' that does not have type 'type' or 'list(type)' in NodeDef "
+ "'",
+ attrs.SummarizeNode(), "'");
+ }
+
+ for (int t : found->list().type()) {
+ if (!InTypeList(static_cast<DataType>(t),
+ constraint.allowed_values())) {
+ return Status::OK();
+ }
+ }
+ }
+ } else {
+ return errors::InvalidArgument(
+ "OpKernel '", kernel_def.op(), "' has constraint on attr '",
+ constraint.name(), "' not in NodeDef '", attrs.SummarizeNode(),
+ "', KernelDef: '", ProtoShortDebugString(kernel_def), "'");
+ }
+ }
+ *match = true;
+ return Status::OK();
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/kernel_def_util.h b/tensorflow/core/framework/kernel_def_util.h
new file mode 100644
index 0000000000..b973cefc4f
--- /dev/null
+++ b/tensorflow/core/framework/kernel_def_util.h
@@ -0,0 +1,31 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_FRAMEWORK_KERNEL_DEF_UTIL_H_
+#define TENSORFLOW_CORE_FRAMEWORK_KERNEL_DEF_UTIL_H_
+
+#include "tensorflow/core/framework/kernel_def.pb.h"
+#include "tensorflow/core/framework/node_def_util.h"
+
+namespace tensorflow {
+
+// Returns whether the attrs satisfy the constraints in the kernel_def. Returns
+// an error if attrs in kernel_def are not found, or have a mismatching type.
+Status KernelAttrsMatch(const KernelDef& kernel_def, AttrSlice attrs,
+ bool* match);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_FRAMEWORK_KERNEL_DEF_UTIL_H_
diff --git a/tensorflow/core/framework/kernel_def_util_test.cc b/tensorflow/core/framework/kernel_def_util_test.cc
new file mode 100644
index 0000000000..a2e4aa82fa
--- /dev/null
+++ b/tensorflow/core/framework/kernel_def_util_test.cc
@@ -0,0 +1,133 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/kernel_def_util.h"
+
+#include "tensorflow/core/framework/kernel_def.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+
+namespace {
+
+NodeDef NodeDefFromText(const string& text) {
+ NodeDef node_def;
+ EXPECT_TRUE(protobuf::TextFormat::MergeFromString(text, &node_def));
+ return node_def;
+}
+
+KernelDef KernelDefFromText(const string& text) {
+ KernelDef kernel_def;
+ EXPECT_TRUE(protobuf::TextFormat::MergeFromString(text, &kernel_def));
+ return kernel_def;
+}
+
+class AttrsMatchTest : public ::testing::Test {
+ protected:
+ void ExpectStatus(const string& node_def_str, const string& kernel_def_str,
+ error::Code code) {
+ bool match;
+ auto status = KernelAttrsMatch(KernelDefFromText(kernel_def_str),
+ NodeDefFromText(node_def_str), &match);
+ LOG(INFO) << "status: " << status;
+ EXPECT_EQ(code, status.code());
+ if (!status.ok()) {
+ EXPECT_FALSE(match)
+ << "Expect no match between the given NodeDef and KernelDef";
+ }
+ }
+};
+
+TEST_F(AttrsMatchTest, ValidConstraint) {
+ string node_def_str = R"(
+ name: "ValidConstraint-op"
+ op: "ValidConstraint"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ )";
+ string kernel_def_str = R"(
+ op: "ValidConstraint"
+ device_type: "CPU"
+ constraint {
+ name: "T"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ }
+ }
+ }
+ )";
+ ExpectStatus(node_def_str, kernel_def_str, error::OK);
+}
+
+TEST_F(AttrsMatchTest, BadConstraint) {
+ string node_def_str = R"(
+ name: "BadConstraint-op"
+ op: "BadConstraint"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ )";
+ string kernel_def_str = R"(
+ op: "BadConstraint"
+ device_type: "CPU"
+ constraint {
+ name: "T"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ }
+ }
+ }
+ )";
+ ExpectStatus(node_def_str, kernel_def_str, error::INVALID_ARGUMENT);
+}
+
+TEST_F(AttrsMatchTest, Unimplemented) {
+ string node_def_str = R"(
+ name: "BadConstraint-op"
+ op: "BadConstraint"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ )";
+ string kernel_def_str = R"(
+ op: "BadConstraint"
+ device_type: "CPU"
+ constraint {
+ name: "T"
+ allowed_values {
+ list {
+ }
+ }
+ }
+ )";
+ ExpectStatus(node_def_str, kernel_def_str, error::UNIMPLEMENTED);
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc
index c2561b5019..8a332fa1d8 100644
--- a/tensorflow/core/framework/op_kernel.cc
+++ b/tensorflow/core/framework/op_kernel.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/graph.pb_text.h"
#include "tensorflow/core/framework/kernel_def.pb_text.h"
+#include "tensorflow/core/framework/kernel_def_util.h"
#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/memory_types.h"
#include "tensorflow/core/framework/node_def.pb.h"
@@ -969,62 +970,6 @@ void OpKernelRegistrar::InitInternal(const KernelDef* kernel_def,
namespace {
-// Helper for AttrsMatch().
-bool InTypeList(DataType dt, const AttrValue& type_list) {
- for (int in_list : type_list.list().type()) {
- if (dt == in_list) return true;
- }
- return false;
-}
-
-// Returns whether the attrs satisfy the constraints in the kernel_def. Returns
-// an error if attrs in kernel_def are not found, or have a mismatching type.
-Status AttrsMatch(AttrSlice attrs, const KernelDef& kernel_def, bool* match) {
- *match = false;
- for (const auto& constraint : kernel_def.constraint()) {
- if (constraint.allowed_values().list().type_size() == 0) {
- return errors::Unimplemented(
- "KernelDef '", ProtoShortDebugString(kernel_def),
- " has constraint on attr '", constraint.name(),
- "' with unsupported type: ",
- SummarizeAttrValue(constraint.allowed_values()));
- }
-
- const AttrValue* found = attrs.Find(constraint.name());
- if (found) {
- if (found->type() != DT_INVALID) {
- if (!InTypeList(found->type(), constraint.allowed_values())) {
- return Status::OK();
- }
- } else {
- if (!AttrValueHasType(*found, "list(type)").ok()) {
- return errors::InvalidArgument(
- "KernelDef '", ProtoShortDebugString(kernel_def),
- "' has constraint on attr '", constraint.name(),
- "' that has value '", SummarizeAttrValue(*found),
- "' that does not have type 'type' or 'list(type)' in NodeDef "
- "'",
- attrs.SummarizeNode(), "'");
- }
-
- for (int t : found->list().type()) {
- if (!InTypeList(static_cast<DataType>(t),
- constraint.allowed_values())) {
- return Status::OK();
- }
- }
- }
- } else {
- return errors::InvalidArgument(
- "OpKernel '", kernel_def.op(), "' has constraint on attr '",
- constraint.name(), "' not in NodeDef '", attrs.SummarizeNode(),
- "', KernelDef: '", ProtoShortDebugString(kernel_def), "'");
- }
- }
- *match = true;
- return Status::OK();
-}
-
static const StringPiece kKernelAttr("_kernel");
// TODO(irving): Replace with const Node& version below.
@@ -1043,7 +988,7 @@ Status FindKernelRegistration(const DeviceType& device_type,
// If there is a kernel registered for the op and device_type,
// check that the attrs match.
bool match;
- TF_RETURN_IF_ERROR(AttrsMatch(node_def, iter->second.def, &match));
+ TF_RETURN_IF_ERROR(KernelAttrsMatch(iter->second.def, node_def, &match));
if (match) {
if (*reg != nullptr) {
return errors::InvalidArgument(
diff --git a/tensorflow/core/graph/tensor_id.cc b/tensorflow/core/graph/tensor_id.cc
index 80c76df255..b5c2c2aac8 100644
--- a/tensorflow/core/graph/tensor_id.cc
+++ b/tensorflow/core/graph/tensor_id.cc
@@ -24,6 +24,9 @@ namespace tensorflow {
TensorId::TensorId(const SafeTensorId& id) : TensorId(id.first, id.second) {}
+SafeTensorId::SafeTensorId(StringPiece str, int idx)
+ : SafeTensorId(str.ToString(), idx) {}
+
SafeTensorId::SafeTensorId(const TensorId& id)
: SafeTensorId(id.first.ToString(), id.second) {}
diff --git a/tensorflow/core/graph/tensor_id.h b/tensorflow/core/graph/tensor_id.h
index bf13fc78a6..b0978b4120 100644
--- a/tensorflow/core/graph/tensor_id.h
+++ b/tensorflow/core/graph/tensor_id.h
@@ -68,6 +68,7 @@ struct SafeTensorId : public std::pair<string, int> {
// NOTE(skyewm): this is required on some platforms. I'm not sure why the
// using statement above isn't always sufficient.
SafeTensorId() : Base() {}
+ SafeTensorId(StringPiece str, int idx);
SafeTensorId(const TensorId& id);
string ToString() const {
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index 90be051764..d8c5d09c4d 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -2519,33 +2519,32 @@ class ConvertLog1pStage : public ArithmeticOptimizerStage {
bool* modified) {
const auto& t =
ctx().graph_properties->GetInputProperties(input->name())[i];
- for (int k = 0; k < t.shape().dim_size(); ++k) {
- // Skip if t shape is not fully determined.
- if (t.shape().dim(k).size() < 0) {
+ const auto& c =
+ ctx().graph_properties->GetInputProperties(input->name())[j];
+ for (int k = 0; k < c.shape().dim_size(); ++k) {
+ // Skip if c shape is not fully determined.
+ if (c.shape().dim(k).size() < 0) {
return Status::OK();
}
}
- const auto& c =
- ctx().graph_properties->GetInputProperties(input->name())[j];
TensorShapeProto broadcast_shape;
if (!ShapeAfterBroadcast(t.shape(), c.shape(), &broadcast_shape)) {
- return errors::InvalidArgument("Cannot get broadcast shape for: ",
- t.DebugString(), " and ", c.DebugString());
+ return Status::OK();
}
if (!ShapesSymbolicallyEqual(t.shape(), broadcast_shape)) {
// skip if the non-constant tensor doesn't have the same shape after
// broadcast.
return Status::OK();
}
- if (TensorShape::IsValid(t.shape()) && t.has_value()) {
- Tensor tensor(t.dtype(), t.shape());
- if (!tensor.FromProto(t.value())) {
+ if (TensorShape::IsValid(c.shape()) && c.has_value()) {
+ Tensor constant(c.dtype(), c.shape());
+ if (!constant.FromProto(c.value())) {
return errors::InvalidArgument("Cannot parse tensor from proto: ",
- t.value().DebugString());
+ c.value().DebugString());
}
complex128 element;
- for (int k = 0; k < tensor.NumElements(); ++k) {
- if (!GetElement(tensor, k, &element)) {
+ for (int k = 0; k < constant.NumElements(); ++k) {
+ if (!GetElement(constant, k, &element)) {
// input data type is not supported by log1p. Skip.
return Status::OK();
}
@@ -2558,11 +2557,12 @@ class ConvertLog1pStage : public ArithmeticOptimizerStage {
TF_RETURN_IF_ERROR(GetInputNode(input->input(i), &x));
TF_RETURN_IF_ERROR(GetInputNode(input->input(j), &y));
node->set_op("Log1p");
- node->set_input(0, y->name());
- node->add_input(AsControlDependency(x->name()));
+ node->set_input(0, input->input(i));
+ node->add_input(AsControlDependency(y->name()));
ForwardControlDependencies(node, {input});
AddToOptimizationQueue(node);
+ AddToOptimizationQueue(input);
AddToOptimizationQueue(x);
AddToOptimizationQueue(y);
*modified = true;
diff --git a/tensorflow/core/kernels/dense_update_ops.cc b/tensorflow/core/kernels/dense_update_ops.cc
index 0de97de205..f942b1a8a9 100644
--- a/tensorflow/core/kernels/dense_update_ops.cc
+++ b/tensorflow/core/kernels/dense_update_ops.cc
@@ -98,6 +98,8 @@ typedef Eigen::SyclDevice SYCLDevice;
TF_CALL_ALL_TYPES(REGISTER_KERNELS);
TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS);
+// quint16 not included in QUANTIZED_TYPES
+TF_CALL_quint16(REGISTER_KERNELS);
#undef REGISTER_KERNELS
#if GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/function_ops.cc b/tensorflow/core/kernels/function_ops.cc
index f2724735bf..fcdf6c447c 100644
--- a/tensorflow/core/kernels/function_ops.cc
+++ b/tensorflow/core/kernels/function_ops.cc
@@ -302,15 +302,21 @@ class RemoteCallOp : public AsyncOpKernel {
~RemoteCallOp() override {}
void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
- const Tensor* target;
- OP_REQUIRES_OK_ASYNC(ctx, ctx->input("target", &target), done);
- const string& target_device =
- DeviceNameUtils::CanonicalizeDeviceName(target->scalar<string>()());
-
FunctionLibraryRuntime* lib = ctx->function_library();
OP_REQUIRES_ASYNC(ctx, lib != nullptr,
errors::Internal("No function library is provided."),
done);
+
+ const string& source_device = lib->device()->name();
+ const Tensor* target;
+ OP_REQUIRES_OK_ASYNC(ctx, ctx->input("target", &target), done);
+ string target_device;
+ OP_REQUIRES_OK_ASYNC(
+ ctx,
+ DeviceNameUtils::CanonicalizeDeviceName(target->scalar<string>()(),
+ source_device, &target_device),
+ done);
+
AttrValueMap attr_values = func_.attr();
FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
instantiate_opts.target = target_device;
@@ -345,7 +351,7 @@ class RemoteCallOp : public AsyncOpKernel {
FunctionLibraryRuntime::Options opts;
opts.step_id = ctx->step_id();
opts.runner = ctx->runner();
- opts.source_device = lib->device()->name();
+ opts.source_device = source_device;
if (opts.source_device != target_device) {
opts.remote_execution = true;
}
diff --git a/tensorflow/core/kernels/non_max_suppression_op.cc b/tensorflow/core/kernels/non_max_suppression_op.cc
index 23fdfe944a..f08dd4f750 100644
--- a/tensorflow/core/kernels/non_max_suppression_op.cc
+++ b/tensorflow/core/kernels/non_max_suppression_op.cc
@@ -133,7 +133,6 @@ void DoNonMaxSuppressionOp(OpKernelContext* context, const Tensor& boxes,
bool should_select = true;
for (int j = selected.size() - 1; j >= 0; --j) {
iou = IOU(boxes_data, next_candidate.box_index, selected[j]);
- if (iou == 0.0) continue;
if (iou > iou_threshold) should_select = false;
}
diff --git a/tensorflow/core/kernels/pad_op.cc b/tensorflow/core/kernels/pad_op.cc
index 41494f56c5..3b9133ed7e 100644
--- a/tensorflow/core/kernels/pad_op.cc
+++ b/tensorflow/core/kernels/pad_op.cc
@@ -320,7 +320,7 @@ namespace functor {
DECLARE_GPU_SPEC(T, 5); \
DECLARE_GPU_SPEC(T, 6);
-TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS);
+TF_CALL_GPU_ALL_TYPES(DECLARE_GPU_SPECS);
TF_CALL_int8(DECLARE_GPU_SPECS);
} // namespace functor
@@ -353,7 +353,7 @@ TF_CALL_int8(DECLARE_GPU_SPECS);
.HostMemory("constant_values"), \
PadOp<GPUDevice, T, int64>)
-TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNEL);
+TF_CALL_GPU_ALL_TYPES(REGISTER_GPU_KERNEL);
TF_CALL_int8(REGISTER_GPU_KERNEL);
// A special GPU kernel for int32.
diff --git a/tensorflow/core/kernels/pad_op_gpu.cu.cc b/tensorflow/core/kernels/pad_op_gpu.cu.cc
index 8e13e19e2e..00ec44adc2 100644
--- a/tensorflow/core/kernels/pad_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/pad_op_gpu.cu.cc
@@ -39,7 +39,7 @@ typedef Eigen::GpuDevice GPUDevice;
DEFINE_GPU_PAD_SPECS(T, int32) \
DEFINE_GPU_PAD_SPECS(T, int64)
-TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPECS);
+TF_CALL_GPU_ALL_TYPES(DEFINE_GPU_SPECS);
TF_CALL_int8(DEFINE_GPU_SPECS);
} // namespace tensorflow
diff --git a/tensorflow/core/lib/bfloat16/bfloat16.h b/tensorflow/core/lib/bfloat16/bfloat16.h
index 2c0576ff10..1c130ba300 100644
--- a/tensorflow/core/lib/bfloat16/bfloat16.h
+++ b/tensorflow/core/lib/bfloat16/bfloat16.h
@@ -354,6 +354,18 @@ struct bfloat16 {
return x;
}
+ static bfloat16 highest() {
+ bfloat16 x;
+ x.value = 0x7F7F; // 0x1.FEp127
+ return x;
+ }
+
+ static bfloat16 lowest() {
+ bfloat16 x;
+ x.value = 0xFF7F; // -0x1.FEp127
+ return x;
+ }
+
uint16_t value;
// A value that represents "not a number".
diff --git a/tensorflow/core/util/device_name_utils.cc b/tensorflow/core/util/device_name_utils.cc
index 90c3fed2e8..8c24076aa9 100644
--- a/tensorflow/core/util/device_name_utils.cc
+++ b/tensorflow/core/util/device_name_utils.cc
@@ -184,16 +184,65 @@ bool DeviceNameUtils::ParseFullName(StringPiece fullname, ParsedName* p) {
return true;
}
+namespace {
+
+void CompleteName(const DeviceNameUtils::ParsedName& parsed_basename,
+ DeviceNameUtils::ParsedName* parsed_name) {
+ if (!parsed_name->has_job) {
+ parsed_name->job = parsed_basename.job;
+ parsed_name->has_job = true;
+ }
+ if (!parsed_name->has_replica) {
+ parsed_name->replica = parsed_basename.replica;
+ parsed_name->has_replica = true;
+ }
+ if (!parsed_name->has_task) {
+ parsed_name->task = parsed_basename.task;
+ parsed_name->has_task = true;
+ }
+ if (!parsed_name->has_type) {
+ parsed_name->type = parsed_basename.type;
+ parsed_name->has_type = true;
+ }
+ if (!parsed_name->has_id) {
+ parsed_name->id = parsed_basename.id;
+ parsed_name->has_id = true;
+ }
+}
+
+} // namespace
+
/* static */
-string DeviceNameUtils::CanonicalizeDeviceName(StringPiece fullname) {
+Status DeviceNameUtils::CanonicalizeDeviceName(StringPiece fullname,
+ StringPiece basename,
+ string* canonical_name) {
+ *canonical_name = "";
+ ParsedName parsed_basename;
+ if (!ParseFullName(basename, &parsed_basename)) {
+ return errors::InvalidArgument("Could not parse basename: ", basename,
+ " into a device specification.");
+ }
+ if (!(parsed_basename.has_job && parsed_basename.has_replica &&
+ parsed_basename.has_task && parsed_basename.has_type &&
+ parsed_basename.has_id)) {
+ return errors::InvalidArgument("Basename: ", basename,
+ " should be fully "
+ "specified.");
+ }
ParsedName parsed_name;
if (ParseLocalName(fullname, &parsed_name)) {
- return ParsedNameToString(parsed_name);
+ CompleteName(parsed_basename, &parsed_name);
+ *canonical_name = ParsedNameToString(parsed_name);
+ return Status::OK();
}
if (ParseFullName(fullname, &parsed_name)) {
- return ParsedNameToString(parsed_name);
+ CompleteName(parsed_basename, &parsed_name);
+ *canonical_name = ParsedNameToString(parsed_name);
+ return Status::OK();
}
- return "";
+ return errors::InvalidArgument("Could not parse ", fullname,
+ " into a device "
+ "specification.");
}
/* static */
diff --git a/tensorflow/core/util/device_name_utils.h b/tensorflow/core/util/device_name_utils.h
index 0ae28df997..4071a70836 100644
--- a/tensorflow/core/util/device_name_utils.h
+++ b/tensorflow/core/util/device_name_utils.h
@@ -88,10 +88,14 @@ class DeviceNameUtils {
// Parses "fullname" into "*parsed". Returns true iff succeeds.
static bool ParseFullName(StringPiece fullname, ParsedName* parsed);
- // Canonicalizes "fullname". Accepts both legacy, newer and local versions of
- // the device spec. Returns the newer version of the device spec. If we were
- // unable to interpret / parse "fullname" returns "".
- static string CanonicalizeDeviceName(StringPiece fullname);
+ // Canonicalizes "fullname" into "*canonical_name". Uses a fully specified
+ // basename to fill in fields that are missing. Accepts both legacy, newer
+ // and local versions of the device spec. Returns the newer version of the
+ // device spec. If we were unable to interpret / parse "fullname" returns
+ // an error and *canonical_name is set to "".
+ static Status CanonicalizeDeviceName(StringPiece fullname,
+ StringPiece basename,
+ string* canonical_name);
// Returns true if "name" specifies any non-trivial constraint on the device.
static bool HasSomeDetails(const ParsedName& name) {
diff --git a/tensorflow/core/util/device_name_utils_test.cc b/tensorflow/core/util/device_name_utils_test.cc
index ff9c108f10..dafb3b20b9 100644
--- a/tensorflow/core/util/device_name_utils_test.cc
+++ b/tensorflow/core/util/device_name_utils_test.cc
@@ -467,18 +467,41 @@ TEST(DeviceNameUtilsTest, GetNamesForDeviceMappings) {
}
TEST(DeviceNameUtilsTest, CanonicalizeDeviceName) {
- EXPECT_EQ("/job:foo/replica:10/task:0/device:CPU:1",
- DeviceNameUtils::CanonicalizeDeviceName(
- "/job:foo/replica:10/task:0/device:CPU:1"));
- EXPECT_EQ("/job:foo/replica:10/task:0/device:CPU:1",
- DeviceNameUtils::CanonicalizeDeviceName(
- "/job:foo/task:0/replica:10/device:CPU:1"));
- EXPECT_EQ("/job:foo/replica:10/task:0/device:CPU:1",
- DeviceNameUtils::CanonicalizeDeviceName(
- "/job:foo/task:0/replica:10/cpu:1"));
- EXPECT_EQ("/device:CPU:0", DeviceNameUtils::CanonicalizeDeviceName("CPU:0"));
- EXPECT_EQ("", DeviceNameUtils::CanonicalizeDeviceName(
- "/job:foo/task:0/replica/cpu:1"));
+ string canonical_name;
+ {
+ // Good basename.
+ string basename = "/job:foo/replica:10/task:0/device:CPU:0";
+ TF_EXPECT_OK(DeviceNameUtils::CanonicalizeDeviceName(
+ "/job:foo/replica:10/task:0/device:CPU:1", basename, &canonical_name));
+ EXPECT_EQ("/job:foo/replica:10/task:0/device:CPU:1", canonical_name);
+ TF_EXPECT_OK(DeviceNameUtils::CanonicalizeDeviceName(
+ "/job:foo/task:0/replica:10/device:CPU:1", basename, &canonical_name));
+ EXPECT_EQ("/job:foo/replica:10/task:0/device:CPU:1", canonical_name);
+ TF_EXPECT_OK(DeviceNameUtils::CanonicalizeDeviceName(
+ "/job:foo/task:0/replica:10/cpu:1", basename, &canonical_name));
+ EXPECT_EQ("/job:foo/replica:10/task:0/device:CPU:1", canonical_name);
+ TF_EXPECT_OK(DeviceNameUtils::CanonicalizeDeviceName("CPU:0", basename,
+ &canonical_name));
+ EXPECT_EQ("/job:foo/replica:10/task:0/device:CPU:0", canonical_name);
+ Status s = DeviceNameUtils::CanonicalizeDeviceName(
+ "/job:foo/task:0/replica/cpu:1", basename, &canonical_name);
+ EXPECT_EQ(s.code(), error::INVALID_ARGUMENT);
+ EXPECT_EQ("", canonical_name);
+ }
+
+ {
+ // Try out malformed basenames.
+ string fullname = "/device:CPU:0";
+
+ Status s = DeviceNameUtils::CanonicalizeDeviceName(
+ fullname, "/device:CPU:0", &canonical_name);
+ EXPECT_EQ(s.code(), error::INVALID_ARGUMENT);
+ EXPECT_EQ("", canonical_name);
+ s = DeviceNameUtils::CanonicalizeDeviceName(
+ fullname, "/job:foo/task:0/replica/cpu:1", &canonical_name);
+ EXPECT_EQ(s.code(), error::INVALID_ARGUMENT);
+ EXPECT_EQ("", canonical_name);
+ }
}
static void BM_ParseFullName(int iters) {
diff --git a/tensorflow/core/util/saved_tensor_slice_util.h b/tensorflow/core/util/saved_tensor_slice_util.h
index ee43945a39..90672a10a8 100644
--- a/tensorflow/core/util/saved_tensor_slice_util.h
+++ b/tensorflow/core/util/saved_tensor_slice_util.h
@@ -123,6 +123,7 @@ TENSOR_PROTO_EXTRACT_TYPE(int8, int, int32);
TENSOR_PROTO_EXTRACT_TYPE(int16, int, int32);
TENSOR_PROTO_EXTRACT_TYPE(qint8, int, int32);
TENSOR_PROTO_EXTRACT_TYPE(quint8, int, int32);
+TENSOR_PROTO_EXTRACT_TYPE(quint16, int, int32);
#undef TENSOR_PROTO_EXTRACT_TYPE_COMPLEX
#undef TENSOR_PROTO_EXTRACT_TYPE_HELPER
diff --git a/tensorflow/docs_src/performance/xla/operation_semantics.md b/tensorflow/docs_src/performance/xla/operation_semantics.md
index f7e116bf0f..ce43d09b63 100644
--- a/tensorflow/docs_src/performance/xla/operation_semantics.md
+++ b/tensorflow/docs_src/performance/xla/operation_semantics.md
@@ -1308,12 +1308,10 @@ See also
: : : parameters of type T and M of :
: : : arbitrary type :
| `dimensions` | `int64` array | array of map dimensions |
-| `static_operands` | sequence of M `XlaOp`s | M arrays of arbitrary type |
Applies a scalar function over the given `operands` arrays, producing an array
of the same dimensions where each element is the result of the mapped function
-applied to the corresponding elements in the input arrays with `static_operands`
-given as additional input to `computation`.
+applied to the corresponding elements in the input arrays.
The mapped function is an arbitrary computation with the restriction that it has
N inputs of scalar type `T` and a single output with type `S`. The output has
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 5d9a5130a0..f19bdeaa39 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -3925,7 +3925,7 @@ tf_cuda_library(
tf_py_test(
name = "session_test",
- size = "small",
+ size = "medium",
srcs = ["client/session_test.py"],
additional_deps = [
":array_ops",
diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py
index 35aa37ac6d..f3b788f931 100644
--- a/tensorflow/python/client/session.py
+++ b/tensorflow/python/client/session.py
@@ -1291,7 +1291,7 @@ class BaseSession(SessionInterface):
raise type(e)(node_def, op, message)
def _extend_graph(self):
- with self._graph._lock: # pylint: disable=protected-access
+ with self._graph._session_run_lock(): # pylint: disable=protected-access
tf_session.ExtendSession(self._session)
# The threshold to run garbage collection to delete dead tensors.
diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py
index e49d067105..b72e029d1c 100644
--- a/tensorflow/python/client/session_test.py
+++ b/tensorflow/python/client/session_test.py
@@ -18,6 +18,7 @@ from __future__ import division
from __future__ import print_function
import collections
 import os
+import random
import sys
import threading
@@ -1040,40 +1041,72 @@ class SessionTest(test_util.TensorFlowTestCase):
for t in threads:
t.join()
- def testParallelRunAndBuild(self):
+ @staticmethod
+ def _build_graph():
+ time.sleep(random.random() * 0.1)
+ # Do some graph construction. Try to exercise non-trivial paths.
+ graph = ops.get_default_graph()
+ gdef = None
+ for _ in range(10):
+ x = array_ops.placeholder(dtype=dtypes.float32)
+ with ops.colocate_with(x):
+ y = array_ops.placeholder(dtype=dtypes.float32)
+ with ops.device('/cpu:0'):
+ z = control_flow_ops.while_loop(
+ lambda x, y: x < 10, lambda x, y: (x + 1, x * y), [x, y])
+ with graph._attr_scope({'_a': attr_value_pb2.AttrValue(b=False)}):
+ gradients_impl.gradients(z, [x, y])
+ if gdef is None:
+ gdef = graph.as_graph_def()
+ else:
+ importer.import_graph_def(gdef, name='import')
+
+ def testParallelRunAndSingleBuild(self):
with session.Session() as sess:
c = constant_op.constant(5.0)
stop = threading.Event()
def run_loop():
while not stop.is_set():
+ time.sleep(random.random() * 0.1)
self.assertEqual(sess.run(c), 5.0)
- threads = [self.checkedThread(target=run_loop) for _ in range(100)]
+ threads = [self.checkedThread(target=run_loop) for _ in range(10)]
for t in threads:
t.start()
- # Do some graph construction. Try to exercise non-trivial paths.
- graph = ops.get_default_graph()
- gdef = None
- for _ in range(10):
- x = array_ops.placeholder(dtype=dtypes.float32)
- with ops.colocate_with(x):
- y = array_ops.placeholder(dtype=dtypes.float32)
- with ops.device('/cpu:0'):
- z = control_flow_ops.while_loop(
- lambda x, y: x < 10, lambda x, y: (x + 1, x * y), [x, y])
- with graph._attr_scope({'_a': attr_value_pb2.AttrValue(b=False)}):
- gradients_impl.gradients(z, [x, y])
- if gdef is None:
- gdef = graph.as_graph_def()
- else:
- importer.import_graph_def(gdef, name='import')
+ SessionTest._build_graph()
stop.set()
for t in threads:
t.join()
+ def testParallelRunAndParallelBuild(self):
+ with session.Session() as sess:
+ c = constant_op.constant(5.0)
+ stop = threading.Event()
+
+ def run_loop():
+ while not stop.is_set():
+ time.sleep(random.random() * 0.1)
+ self.assertEqual(sess.run(c), 5.0)
+
+ run_threads = [self.checkedThread(target=run_loop) for _ in range(10)]
+ for t in run_threads:
+ t.start()
+
+ build_threads = [self.checkedThread(target=SessionTest._build_graph)
+ for _ in range(10)]
+ for t in build_threads:
+ t.start()
+ for t in build_threads:
+ t.join()
+
+ # Let the run_threads run until the build threads are finished.
+ stop.set()
+ for t in run_threads:
+ t.join()
+
def testRunFeedDict(self):
with session.Session() as s:
x = array_ops.zeros([2])
diff --git a/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py b/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py
index 50bb0837b7..c3d42b49af 100644
--- a/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py
@@ -18,9 +18,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import time
+
from absl.testing import parameterized
import numpy as np
+from tensorflow.python.client import session
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -461,5 +464,55 @@ class PaddedBatchDatasetTest(test.TestCase, parameterized.TestCase):
5, padded_shapes=shape_as_tensor)
+class BatchDatasetBenchmark(test.Benchmark):
+
+ def benchmarkBatchSparse(self):
+ non_zeros_per_row_values = [0, 1, 5, 10, 100]
+ batch_size_values = [1, 32, 64, 128, 1024]
+
+ sparse_placeholder = array_ops.sparse_placeholder(dtype=dtypes.int64)
+ batch_size_placeholder = array_ops.placeholder(dtype=dtypes.int64, shape=[])
+
+ dataset = dataset_ops.Dataset.from_tensors(sparse_placeholder).repeat(
+ ).batch(batch_size_placeholder)
+ iterator = dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ for non_zeros_per_row in non_zeros_per_row_values:
+
+ sparse_value = sparse_tensor.SparseTensorValue(
+ indices=np.arange(non_zeros_per_row, dtype=np.int64)[:, np.newaxis],
+ values=np.arange(non_zeros_per_row, dtype=np.int64),
+ dense_shape=[1000])
+
+ for batch_size in batch_size_values:
+
+ with session.Session() as sess:
+ sess.run(iterator.initializer, feed_dict={
+ sparse_placeholder: sparse_value,
+ batch_size_placeholder: batch_size})
+ # Run five steps to warm up the session caches before taking the
+ # first measurement.
+ for _ in range(5):
+ sess.run(next_element.indices.op)
+ deltas = []
+ for _ in range(100):
+ start = time.time()
+ for _ in range(100):
+ sess.run(next_element.indices.op)
+ end = time.time()
+ deltas.append(end - start)
+
+ median_wall_time = np.median(deltas) / 100.0
+
+ print('Batch sparse dataset non-zeros per row: %d batch_size: %d '
+ 'wall time: %f'
+ % (non_zeros_per_row, batch_size, median_wall_time))
+ self.report_benchmark(
+ iters=10000, wall_time=median_wall_time,
+ name='benchmark_batch_sparse_dataset_nnz_%d_batch_size_%d' % (
+ non_zeros_per_row, batch_size))
+
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index fc68e945c0..a81ef90513 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -47,8 +47,11 @@ def capture_value(tensor_map, value, dtype, name):
"""Capture a value from outside the function, to pass in as an extra arg."""
captured_value = tensor_map.get(ops.tensor_id(value), None)
if captured_value is None:
- captured_value = graph_placeholder(
- dtype=dtype or value.dtype, shape=value.shape, name=name)
+ # Note: setting ops.control_dependencies(None) ensures we always put
+ # capturing placeholders outside of any control flow context.
+ with ops.control_dependencies(None):
+ captured_value = graph_placeholder(
+ dtype=dtype or value.dtype, shape=value.shape, name=name)
if captured_value.dtype == dtypes_module.resource:
if ops._USE_C_SHAPES: # pylint: disable=protected-access
if isinstance(value, ops.EagerTensor):
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index a5df3ef530..9e5754fc4c 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -210,6 +210,21 @@ class FunctionTest(test.TestCase):
compiled = function.defun(f)
compiled()
+ def testVariableInLoopInFunction(self):
+
+ @function.defun
+ def test_function():
+
+ def loop_test(_):
+ return False
+
+ def loop_body(_):
+ return variable_scope.get_variable('a', shape=())
+
+ return control_flow_ops.while_loop(loop_test, loop_body, [0.0])
+
+ self.assertEqual(test_function().shape, [])
+
def testDefunShapeInferenceWithCapturedResourceVariableInGraphMode(self):
with context.graph_mode():
v = resource_variable_ops.ResourceVariable([[1, 2], [3, 4]])
diff --git a/tensorflow/python/estimator/canned/dnn.py b/tensorflow/python/estimator/canned/dnn.py
index 90889e3e5d..2c7c4285ca 100644
--- a/tensorflow/python/estimator/canned/dnn.py
+++ b/tensorflow/python/estimator/canned/dnn.py
@@ -230,6 +230,17 @@ class DNNClassifier(estimator.Estimator):
l1_regularization_strength=0.001
))
+ # Or estimator using an optimizer with a learning rate decay.
+ estimator = DNNClassifier(
+ feature_columns=[categorical_feature_a_emb, categorical_feature_b_emb],
+ hidden_units=[1024, 512, 256],
+      optimizer=lambda: tf.train.AdamOptimizer(
+          learning_rate=tf.train.exponential_decay(
+ learning_rate=0.1,
+ global_step=tf.get_global_step(),
+ decay_steps=10000,
+ decay_rate=0.96))
+
# Or estimator with warm-starting from a previous checkpoint.
estimator = DNNClassifier(
feature_columns=[categorical_feature_a_emb, categorical_feature_b_emb],
@@ -317,8 +328,9 @@ class DNNClassifier(estimator.Estimator):
encoded as integer values in {0, 1,..., n_classes-1} for `n_classes`>2 .
Also there will be errors if vocabulary is not provided and labels are
string.
- optimizer: An instance of `tf.Optimizer` used to train the model. Defaults
- to Adagrad optimizer.
+ optimizer: An instance of `tf.Optimizer` used to train the model. Can also
+ be a string (one of 'Adagrad', 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or
+ callable. Defaults to Adagrad optimizer.
activation_fn: Activation function applied to each layer. If `None`, will
use `tf.nn.relu`.
dropout: When not `None`, the probability we will drop out a given
@@ -385,6 +397,17 @@ class DNNRegressor(estimator.Estimator):
l1_regularization_strength=0.001
))
+ # Or estimator using an optimizer with a learning rate decay.
+ estimator = DNNRegressor(
+ feature_columns=[categorical_feature_a_emb, categorical_feature_b_emb],
+ hidden_units=[1024, 512, 256],
+      optimizer=lambda: tf.train.AdamOptimizer(
+          learning_rate=tf.train.exponential_decay(
+ learning_rate=0.1,
+ global_step=tf.get_global_step(),
+ decay_steps=10000,
+ decay_rate=0.96))
+
# Or estimator with warm-starting from a previous checkpoint.
estimator = DNNRegressor(
feature_columns=[categorical_feature_a_emb, categorical_feature_b_emb],
@@ -465,8 +488,9 @@ class DNNRegressor(estimator.Estimator):
used as a key to fetch weight tensor from the `features`. If it is a
`_NumericColumn`, raw tensor is fetched by key `weight_column.key`,
then weight_column.normalizer_fn is applied on it to get weight tensor.
- optimizer: An instance of `tf.Optimizer` used to train the model. Defaults
- to Adagrad optimizer.
+ optimizer: An instance of `tf.Optimizer` used to train the model. Can also
+ be a string (one of 'Adagrad', 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or
+ callable. Defaults to Adagrad optimizer.
activation_fn: Activation function applied to each layer. If `None`, will
use `tf.nn.relu`.
dropout: When not `None`, the probability we will drop out a given
diff --git a/tensorflow/python/estimator/canned/dnn_linear_combined.py b/tensorflow/python/estimator/canned/dnn_linear_combined.py
index 3d1ad1365b..2f20e4b289 100644
--- a/tensorflow/python/estimator/canned/dnn_linear_combined.py
+++ b/tensorflow/python/estimator/canned/dnn_linear_combined.py
@@ -257,12 +257,19 @@ class DNNLinearCombinedClassifier(estimator.Estimator):
# warm-start settings
warm_start_from="/path/to/checkpoint/dir")
- # To apply L1 and L2 regularization, you can set optimizers as follows:
+ # To apply L1 and L2 regularization, you can set dnn_optimizer to:
tf.train.ProximalAdagradOptimizer(
learning_rate=0.1,
l1_regularization_strength=0.001,
l2_regularization_strength=0.001)
- # It is same for FtrlOptimizer.
+ # To apply learning rate decay, you can set dnn_optimizer to a callable:
+  lambda: tf.train.AdamOptimizer(
+      learning_rate=tf.train.exponential_decay(
+          learning_rate=0.1,
+          global_step=tf.get_global_step(),
+          decay_steps=10000,
+          decay_rate=0.96))
+ # It is the same for linear_optimizer.
# Input builders
def input_fn_train: # returns x, y
@@ -325,12 +332,16 @@ class DNNLinearCombinedClassifier(estimator.Estimator):
used by linear part of the model. All items in the set must be
instances of classes derived from `FeatureColumn`.
linear_optimizer: An instance of `tf.Optimizer` used to apply gradients to
- the linear part of the model. Defaults to FTRL optimizer.
+ the linear part of the model. Can also be a string (one of 'Adagrad',
+ 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or callable. Defaults to FTRL
+ optimizer.
dnn_feature_columns: An iterable containing all the feature columns used
by deep part of the model. All items in the set must be instances of
classes derived from `FeatureColumn`.
dnn_optimizer: An instance of `tf.Optimizer` used to apply gradients to
- the deep part of the model. Defaults to Adagrad optimizer.
+ the deep part of the model. Can also be a string (one of 'Adagrad',
+ 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or callable. Defaults to Adagrad
+ optimizer.
dnn_hidden_units: List of hidden units per layer. All layers are fully
connected.
dnn_activation_fn: Activation function applied to each layer. If None,
@@ -441,12 +452,19 @@ class DNNLinearCombinedRegressor(estimator.Estimator):
# warm-start settings
warm_start_from="/path/to/checkpoint/dir")
- # To apply L1 and L2 regularization, you can set optimizers as follows:
+ # To apply L1 and L2 regularization, you can set dnn_optimizer to:
tf.train.ProximalAdagradOptimizer(
learning_rate=0.1,
l1_regularization_strength=0.001,
l2_regularization_strength=0.001)
- # It is same for FtrlOptimizer.
+ # To apply learning rate decay, you can set dnn_optimizer to a callable:
+  lambda: tf.train.AdamOptimizer(
+      learning_rate=tf.train.exponential_decay(
+          learning_rate=0.1,
+          global_step=tf.get_global_step(),
+          decay_steps=10000,
+          decay_rate=0.96))
+ # It is the same for linear_optimizer.
# Input builders
def input_fn_train: # returns x, y
@@ -508,12 +526,16 @@ class DNNLinearCombinedRegressor(estimator.Estimator):
used by linear part of the model. All items in the set must be
instances of classes derived from `FeatureColumn`.
linear_optimizer: An instance of `tf.Optimizer` used to apply gradients to
- the linear part of the model. Defaults to FTRL optimizer.
+ the linear part of the model. Can also be a string (one of 'Adagrad',
+ 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or callable. Defaults to FTRL
+ optimizer.
dnn_feature_columns: An iterable containing all the feature columns used
by deep part of the model. All items in the set must be instances of
classes derived from `FeatureColumn`.
dnn_optimizer: An instance of `tf.Optimizer` used to apply gradients to
- the deep part of the model. Defaults to Adagrad optimizer.
+ the deep part of the model. Can also be a string (one of 'Adagrad',
+ 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or callable. Defaults to Adagrad
+ optimizer.
dnn_hidden_units: List of hidden units per layer. All layers are fully
connected.
dnn_activation_fn: Activation function applied to each layer. If None,
diff --git a/tensorflow/python/estimator/canned/linear.py b/tensorflow/python/estimator/canned/linear.py
index ac59e786c4..e22df849e5 100644
--- a/tensorflow/python/estimator/canned/linear.py
+++ b/tensorflow/python/estimator/canned/linear.py
@@ -193,6 +193,17 @@ class LinearClassifier(estimator.Estimator):
l1_regularization_strength=0.001
))
+ # Or estimator using an optimizer with a learning rate decay.
+ estimator = LinearClassifier(
+ feature_columns=[categorical_column_a,
+ categorical_feature_a_x_categorical_feature_b],
+      optimizer=lambda: tf.train.FtrlOptimizer(
+          learning_rate=tf.train.exponential_decay(
+ learning_rate=0.1,
+ global_step=tf.get_global_step(),
+ decay_steps=10000,
+ decay_rate=0.96))
+
# Or estimator with warm-starting from a previous checkpoint.
estimator = LinearClassifier(
feature_columns=[categorical_column_a,
@@ -272,8 +283,9 @@ class LinearClassifier(estimator.Estimator):
encoded as integer values in {0, 1,..., n_classes-1} for `n_classes`>2 .
Also there will be errors if vocabulary is not provided and labels are
string.
- optimizer: An instance of `tf.Optimizer` used to train the model. Defaults
- to FTRL optimizer.
+ optimizer: An instance of `tf.Optimizer` used to train the model. Can also
+ be a string (one of 'Adagrad', 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or
+ callable. Defaults to FTRL optimizer.
config: `RunConfig` object to configure the runtime settings.
partitioner: Optional. Partitioner for input layer.
warm_start_from: A string filepath to a checkpoint to warm-start from, or
@@ -335,10 +347,31 @@ class LinearRegressor(estimator.Estimator):
categorical_feature_a_x_categorical_feature_b = crossed_column(...)
+ # Estimator using the default optimizer.
estimator = LinearRegressor(
feature_columns=[categorical_column_a,
categorical_feature_a_x_categorical_feature_b])
+ # Or estimator using the FTRL optimizer with regularization.
+ estimator = LinearRegressor(
+ feature_columns=[categorical_column_a,
+ categorical_feature_a_x_categorical_feature_b],
+ optimizer=tf.train.FtrlOptimizer(
+ learning_rate=0.1,
+ l1_regularization_strength=0.001
+ ))
+
+ # Or estimator using an optimizer with a learning rate decay.
+ estimator = LinearRegressor(
+ feature_columns=[categorical_column_a,
+ categorical_feature_a_x_categorical_feature_b],
+      optimizer=lambda: tf.train.FtrlOptimizer(
+          learning_rate=tf.train.exponential_decay(
+ learning_rate=0.1,
+ global_step=tf.get_global_step(),
+ decay_steps=10000,
+ decay_rate=0.96))
+
# Or estimator with warm-starting from a previous checkpoint.
estimator = LinearRegressor(
feature_columns=[categorical_column_a,
@@ -409,8 +442,9 @@ class LinearRegressor(estimator.Estimator):
used as a key to fetch weight tensor from the `features`. If it is a
`_NumericColumn`, raw tensor is fetched by key `weight_column.key`,
then weight_column.normalizer_fn is applied on it to get weight tensor.
- optimizer: An instance of `tf.Optimizer` used to train the model. Defaults
- to FTRL optimizer.
+ optimizer: An instance of `tf.Optimizer` used to train the model. Can also
+ be a string (one of 'Adagrad', 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or
+ callable. Defaults to FTRL optimizer.
config: `RunConfig` object to configure the runtime settings.
partitioner: Optional. Partitioner for input layer.
warm_start_from: A string filepath to a checkpoint to warm-start from, or
diff --git a/tensorflow/python/estimator/canned/optimizers.py b/tensorflow/python/estimator/canned/optimizers.py
index f72c5ca5cb..8f51cc3a80 100644
--- a/tensorflow/python/estimator/canned/optimizers.py
+++ b/tensorflow/python/estimator/canned/optimizers.py
@@ -72,6 +72,8 @@ def get_optimizer_instance(opt, learning_rate=None):
raise ValueError(
'Unsupported optimizer name: {}. Supported names are: {}'.format(
opt, tuple(sorted(six.iterkeys(_OPTIMIZER_CLS_NAMES)))))
+ if callable(opt):
+ opt = opt()
if not isinstance(opt, optimizer_lib.Optimizer):
raise ValueError(
'The given object is not an Optimizer instance. Given: {}'.format(opt))
diff --git a/tensorflow/python/estimator/canned/optimizers_test.py b/tensorflow/python/estimator/canned/optimizers_test.py
index ee28756155..eadabdbc49 100644
--- a/tensorflow/python/estimator/canned/optimizers_test.py
+++ b/tensorflow/python/estimator/canned/optimizers_test.py
@@ -28,6 +28,13 @@ from tensorflow.python.training import optimizer as optimizer_lib
from tensorflow.python.training import rmsprop
+class _TestOptimizer(optimizer_lib.Optimizer):
+
+ def __init__(self):
+ super(_TestOptimizer, self).__init__(
+ use_locking=False, name='TestOptimizer')
+
+
class GetOptimizerInstance(test.TestCase):
def test_unsupported_name(self):
@@ -66,12 +73,6 @@ class GetOptimizerInstance(test.TestCase):
self.assertAlmostEqual(0.1, opt._learning_rate)
def test_object(self):
- class _TestOptimizer(optimizer_lib.Optimizer):
-
- def __init__(self):
- super(_TestOptimizer, self).__init__(
- use_locking=False, name='TestOptimizer')
-
opt = optimizers.get_optimizer_instance(_TestOptimizer())
self.assertIsInstance(opt, _TestOptimizer)
@@ -80,6 +81,23 @@ class GetOptimizerInstance(test.TestCase):
ValueError, 'The given object is not an Optimizer instance'):
optimizers.get_optimizer_instance((1, 2, 3))
+ def test_callable(self):
+ def _optimizer_fn():
+ return _TestOptimizer()
+ opt = optimizers.get_optimizer_instance(_optimizer_fn)
+ self.assertIsInstance(opt, _TestOptimizer)
+
+ def test_lambda(self):
+ opt = optimizers.get_optimizer_instance(lambda: _TestOptimizer()) # pylint: disable=unnecessary-lambda
+ self.assertIsInstance(opt, _TestOptimizer)
+
+ def test_callable_returns_invalid(self):
+ def _optimizer_fn():
+ return (1, 2, 3)
+ with self.assertRaisesRegexp(
+ ValueError, 'The given object is not an Optimizer instance'):
+ optimizers.get_optimizer_instance(_optimizer_fn)
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/framework/importer.py b/tensorflow/python/framework/importer.py
index 72eb7e0eeb..699d2b70d1 100644
--- a/tensorflow/python/framework/importer.py
+++ b/tensorflow/python/framework/importer.py
@@ -407,11 +407,11 @@ def import_graph_def(graph_def,
_PopulateTFImportGraphDefOptions(options, prefix, input_map,
return_elements)
- # _ProcessNewOps mutates the new operations. _lock ensures a Session.run
- # call cannot occur between creating the TF_Operations in the
+ # _ProcessNewOps mutates the new operations. _mutation_lock ensures a
+ # Session.run call cannot occur between creating the TF_Operations in the
# TF_GraphImportGraphDefWithResults call and mutating the them in
# _ProcessNewOps.
- with graph._lock: # pylint: disable=protected-access
+ with graph._mutation_lock(): # pylint: disable=protected-access
with c_api_util.tf_buffer(graph_def.SerializeToString()) as serialized:
try:
results = c_api.TF_GraphImportGraphDefWithResults(
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index 05f9ae21b1..cf0b1e36fb 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -55,6 +55,7 @@ from tensorflow.python.platform import app
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat
from tensorflow.python.util import decorator_utils
+from tensorflow.python.util import lock_util
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util.deprecation import deprecated_args
from tensorflow.python.util.tf_export import tf_export
@@ -2599,6 +2600,10 @@ def _name_from_scope_name(name):
return name[:-1] if (name and name[-1] == "/") else name
+_MUTATION_LOCK_GROUP = 0
+_SESSION_RUN_LOCK_GROUP = 1
+
+
@tf_export("Graph")
class Graph(object):
"""A TensorFlow computation, represented as a dataflow graph.
@@ -2648,20 +2653,21 @@ class Graph(object):
def __init__(self):
"""Creates a new, empty Graph."""
- # Protects core state that can be returned via public accessors, as well as
- # synchronizes Session.run calls with methods that create and mutate ops
- # (e.g. Graph.create_op()). This synchronization is necessary because it's
- # illegal to modify an operation after it's been run. Thread-safety is
- # provided on a best-effort basis to support buggy programs, and is not
- # guaranteed by the public `tf.Graph` API.
- #
- # The lock must be reentrant because create_op can be called recursively due
- # to control flow. Without a reentrant lock, many methods would also need a
- # "locked" version or parameter (including generated code).
+ # Protects core state that can be returned via public accessors.
+ # Thread-safety is provided on a best-effort basis to support buggy
+ # programs, and is not guaranteed by the public `tf.Graph` API.
#
# NOTE(mrry): This does not protect the various stacks. A warning will
# be reported if these are used from multiple threads
self._lock = threading.RLock()
+ # The group lock synchronizes Session.run calls with methods that create
+ # and mutate ops (e.g. Graph.create_op()). This synchronization is
+ # necessary because it's illegal to modify an operation after it's been run.
+ # The group lock allows any number of threads to mutate ops at the same time
+ # but if any modification is going on, all Session.run calls have to wait.
+ # Similarly, if one or more Session.run calls are going on, all mutate ops
+ # have to wait until all Session.run calls have finished.
+ self._group_lock = lock_util.GroupLock(num_groups=2)
self._nodes_by_id = dict() # GUARDED_BY(self._lock)
self._next_id_counter = 0 # GUARDED_BY(self._lock)
self._nodes_by_name = dict() # GUARDED_BY(self._lock)
@@ -3192,9 +3198,9 @@ class Graph(object):
input_ops = set([t.op for t in inputs])
control_inputs = self._control_dependencies_for_inputs(input_ops)
- # _create_op_helper mutates the new Operation. _lock ensures a Session.run
- # call cannot occur between creating and mutating the op.
- with self._lock:
+ # _create_op_helper mutates the new Operation. `_mutation_lock` ensures a
+ # Session.run call cannot occur between creating and mutating the op.
+ with self._mutation_lock():
ret = Operation(
node_def,
self,
@@ -4727,6 +4733,20 @@ class Graph(object):
else:
self._graph_control_dependencies_stack = control_dependencies
+ def _mutation_lock(self):
+ """Returns a lock to guard code that creates & mutates ops.
+
+ See the comment for self._group_lock for more info.
+ """
+ return self._group_lock.group(_MUTATION_LOCK_GROUP)
+
+ def _session_run_lock(self):
+ """Returns a lock to guard code for Session.run.
+
+ See the comment for self._group_lock for more info.
+ """
+ return self._group_lock.group(_SESSION_RUN_LOCK_GROUP)
+
# TODO(agarwal): currently device directives in an outer eager scope will not
# apply to inner graph mode code. Fix that.
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index 3988238609..1b5db17ae7 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -414,8 +414,28 @@ def assert_no_new_pyobjects_executing_eagerly(f):
f(self, **kwargs)
gc.collect()
previous_count = len(gc.get_objects())
+ collection_sizes_before = {
+ collection: len(ops.get_collection(collection))
+ for collection in ops.get_default_graph().collections}
for _ in range(3):
f(self, **kwargs)
+ # Note that gc.get_objects misses anything that isn't subject to garbage
+ # collection (C types). Collections are a common source of leaks, so we
+ # test for collection sizes explicitly.
+ for collection_key in ops.get_default_graph().collections:
+ collection = ops.get_collection(collection_key)
+ size_before = collection_sizes_before.get(collection_key, 0)
+ if len(collection) > size_before:
+ raise AssertionError(
+ ("Collection %s increased in size from "
+ "%d to %d (current items %s).")
+ % (collection_key, size_before, len(collection), collection))
+ # Make sure our collection checks don't show up as leaked memory by
+ # removing references to temporary variables.
+ del collection
+ del collection_key
+ del size_before
+ del collection_sizes_before
gc.collect()
# There should be no new Python objects hanging around.
new_count = len(gc.get_objects())
diff --git a/tensorflow/python/keras/engine/network.py b/tensorflow/python/keras/engine/network.py
index 3edb8033ff..aa84eaa8ab 100644
--- a/tensorflow/python/keras/engine/network.py
+++ b/tensorflow/python/keras/engine/network.py
@@ -44,6 +44,7 @@ from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.training.checkpointable import data_structures
+from tensorflow.python.training.checkpointable import layer_utils as checkpointable_layer_utils
from tensorflow.python.training.checkpointable import util as checkpointable_utils
from tensorflow.python.util import nest
from tensorflow.python.util import tf_inspect
@@ -665,14 +666,14 @@ class Network(base_layer.Layer):
@property
def trainable_weights(self):
- return layer_utils.gather_trainable_weights(
+ return checkpointable_layer_utils.gather_trainable_weights(
trainable=self.trainable,
sub_layers=self.layers,
extra_variables=self._extra_variables)
@property
def non_trainable_weights(self):
- return layer_utils.gather_non_trainable_weights(
+ return checkpointable_layer_utils.gather_non_trainable_weights(
trainable=self.trainable,
sub_layers=self.layers,
extra_variables=self._extra_variables)
diff --git a/tensorflow/python/kernel_tests/functional_ops_test.py b/tensorflow/python/kernel_tests/functional_ops_test.py
index 1beb0e396e..671508ab4e 100644
--- a/tensorflow/python/kernel_tests/functional_ops_test.py
+++ b/tensorflow/python/kernel_tests/functional_ops_test.py
@@ -604,6 +604,25 @@ class FunctionalOpsTest(test.TestCase):
mul = sess.run(remote_op)
self.assertEqual(mul, [6])
+ def testRemoteFunctionSameDeviceDirectSession(self):
+
+ @function.Defun(dtypes.int32, dtypes.int32)
+ def _remote_fn(a, b):
+ return math_ops.multiply(a, b)
+
+ with ops.device("/cpu:0"):
+ a = variables.Variable(2, dtype=dtypes.int32)
+ b = variables.Variable(3, dtype=dtypes.int32)
+
+ with ops.device("/cpu:0"):
+ remote_op = functional_ops.remote_call(
+ args=[a, b], Tout=[dtypes.int32], f=_remote_fn, target="/cpu:0")
+
+ with self.test_session() as sess:
+ sess.run(variables.global_variables_initializer())
+ mul = sess.run(remote_op)
+ self.assertEqual(mul, [6])
+
def testRemoteFunctionCPUGPU(self):
if not test_util.is_gpu_available():
self.skipTest("No GPU available")
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index 837c144467..c8442b42d5 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -2943,9 +2943,10 @@ class WhileContext(ControlFlowContext):
loop_vars = ops.convert_n_to_tensor_or_indexed_slices(loop_vars)
try:
self.Enter()
- # _BuildLoop calls _update_input in several places. _lock ensures a
- # Session.run call cannot occur between creating and mutating new ops.
- with ops.get_default_graph()._lock: # pylint: disable=protected-access
+ # _BuildLoop calls _update_input in several places. _mutation_lock()
+ # ensures a Session.run call cannot occur between creating and mutating
+ # new ops.
+ with ops.get_default_graph()._mutation_lock(): # pylint: disable=protected-access
original_body_result, exit_vars = self._BuildLoop(
pred, body, original_loop_vars, loop_vars, shape_invariants)
finally:
diff --git a/tensorflow/python/ops/gradients_impl.py b/tensorflow/python/ops/gradients_impl.py
index 99909ac38e..250b9285c9 100644
--- a/tensorflow/python/ops/gradients_impl.py
+++ b/tensorflow/python/ops/gradients_impl.py
@@ -534,10 +534,10 @@ def gradients(ys,
RuntimeError: if called in Eager mode.
"""
- # Creating the gradient graph for control flow mutates Operations. _lock
- # ensures a Session.run call cannot occur between creating and mutating new
- # ops.
- with ops.get_default_graph()._lock: # pylint: disable=protected-access
+ # Creating the gradient graph for control flow mutates Operations.
+ # _mutation_lock ensures a Session.run call cannot occur between creating and
+ # mutating new ops.
+ with ops.get_default_graph()._mutation_lock(): # pylint: disable=protected-access
return _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops,
gate_gradients, aggregation_method, stop_gradients)
diff --git a/tensorflow/python/ops/summary_ops_v2.py b/tensorflow/python/ops/summary_ops_v2.py
index b80f84eb7c..00150fe688 100644
--- a/tensorflow/python/ops/summary_ops_v2.py
+++ b/tensorflow/python/ops/summary_ops_v2.py
@@ -306,10 +306,11 @@ def create_db_writer(db_uri,
def _make_summary_writer(name, factory, **kwargs):
resource = gen_summary_ops.summary_writer(shared_name=name)
init_op_fn = lambda: factory(resource, **kwargs)
- # TODO(apassos): Consider doing this instead.
- # if not context.executing_eagerly():
- # ops.get_default_session().run(init_op)
- ops.add_to_collection(_SUMMARY_WRITER_INIT_COLLECTION_NAME, init_op_fn())
+ init_op = init_op_fn()
+ if not context.executing_eagerly():
+ # TODO(apassos): Consider doing this instead.
+ # ops.get_default_session().run(init_op)
+ ops.add_to_collection(_SUMMARY_WRITER_INIT_COLLECTION_NAME, init_op)
return SummaryWriter(resource, init_op_fn)
@@ -380,7 +381,8 @@ def summary_writer_function(name, tensor, function, family=None):
with ops.device("cpu:0"):
op = smart_cond.smart_cond(
should_record_summaries(), record, _nothing, name="")
- ops.add_to_collection(ops.GraphKeys._SUMMARY_COLLECTION, op) # pylint: disable=protected-access
+ if not context.executing_eagerly():
+ ops.add_to_collection(ops.GraphKeys._SUMMARY_COLLECTION, op) # pylint: disable=protected-access
return op
diff --git a/tensorflow/python/training/checkpointable/BUILD b/tensorflow/python/training/checkpointable/BUILD
index 9232b6089a..54f359489e 100644
--- a/tensorflow/python/training/checkpointable/BUILD
+++ b/tensorflow/python/training/checkpointable/BUILD
@@ -62,11 +62,18 @@ py_test(
)
py_library(
+ name = "layer_utils",
+ srcs = ["layer_utils.py"],
+ srcs_version = "PY2AND3",
+)
+
+py_library(
name = "data_structures",
srcs = ["data_structures.py"],
srcs_version = "PY2AND3",
deps = [
":base",
+ ":layer_utils",
],
)
diff --git a/tensorflow/python/training/checkpointable/data_structures.py b/tensorflow/python/training/checkpointable/data_structures.py
index 680cf3441f..c46585b417 100644
--- a/tensorflow/python/training/checkpointable/data_structures.py
+++ b/tensorflow/python/training/checkpointable/data_structures.py
@@ -21,10 +21,9 @@ import collections
import six
-from tensorflow.python.keras.engine import base_layer
-from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.ops import variables
from tensorflow.python.training.checkpointable import base as checkpointable_lib
+from tensorflow.python.training.checkpointable import layer_utils
# TODO(allenl): We could track regular Python data structures which get assigned
@@ -54,7 +53,8 @@ class CheckpointableDataStructure(checkpointable_lib.CheckpointableBase):
("Only checkpointable objects (such as Layers or Optimizers) may be "
"stored in a List object. Got %s, which does not inherit from "
"CheckpointableBase.") % (value,))
- if isinstance(value, (base_layer.Layer, CheckpointableDataStructure)):
+ if (isinstance(value, CheckpointableDataStructure)
+ or layer_utils.is_layer(value)):
if value not in self._layers:
self._layers.append(value)
if hasattr(value, "_use_resource_variables"):
diff --git a/tensorflow/python/training/checkpointable/layer_utils.py b/tensorflow/python/training/checkpointable/layer_utils.py
new file mode 100644
index 0000000000..fdcf963d32
--- /dev/null
+++ b/tensorflow/python/training/checkpointable/layer_utils.py
@@ -0,0 +1,85 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utilities related to layer/model functionality."""
+
+# TODO(b/110718070): Move these functions back to tensorflow/python/keras/utils
+# once __init__ files no longer require all of tf.keras to be imported together.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+def is_layer(obj):
+ """Implicit check for Layer-like objects."""
+ # TODO(b/110718070): Replace with isinstance(obj, base_layer.Layer).
+ return (hasattr(obj, "call")
+ and hasattr(obj, "build")
+ and hasattr(obj, "variables"))
+
+
+def gather_trainable_weights(trainable, sub_layers, extra_variables):
+ """Lists the trainable weights for an object with sub-layers.
+
+ Args:
+ trainable: Whether the object collecting the variables is trainable.
+ sub_layers: A flat list of Layer objects owned by this object, to collect
+ variables from.
+ extra_variables: Any extra variables to include. Their `.trainable` property
+ is used to categorize them.
+
+ Returns:
+ A list of collected trainable weights/variables.
+ """
+ if not trainable:
+ return []
+ weights = []
+ for layer in sub_layers:
+ weights += layer.trainable_weights
+ trainable_extra_variables = [
+ v for v in extra_variables if v.trainable]
+ return weights + trainable_extra_variables
+
+
+def gather_non_trainable_weights(trainable, sub_layers, extra_variables):
+ """Lists the non-trainable weights for an object with sub-layers.
+
+ Args:
+ trainable: Whether the object collecting the variables is trainable.
+ sub_layers: A flat list of Layer objects owned by this object, to collect
+ variables from.
+ extra_variables: Any extra variables to include. Their `.trainable` property
+ is used to categorize them.
+
+ Returns:
+ A list of collected non-trainable weights/variables.
+ """
+ trainable_extra_variables = []
+ non_trainable_extra_variables = []
+ for v in extra_variables:
+ if v.trainable:
+ trainable_extra_variables.append(v)
+ else:
+ non_trainable_extra_variables.append(v)
+ weights = []
+ for layer in sub_layers:
+ weights += layer.non_trainable_weights
+ if not trainable:
+ trainable_weights = []
+ for layer in sub_layers:
+ trainable_weights += layer.trainable_weights
+ return (trainable_weights + trainable_extra_variables
+ + weights + non_trainable_extra_variables)
+ return weights + non_trainable_extra_variables
diff --git a/tensorflow/python/util/lock_util_test.py b/tensorflow/python/util/lock_util_test.py
index 2ac640ff99..cda8f95225 100644
--- a/tensorflow/python/util/lock_util_test.py
+++ b/tensorflow/python/util/lock_util_test.py
@@ -19,7 +19,6 @@ from __future__ import division
from __future__ import print_function
import random
-import threading
import time
from absl.testing import parameterized
@@ -48,7 +47,7 @@ class GroupLockTest(test.TestCase, parameterized.TestCase):
finished.add(thread_id)
threads = [
- threading.Thread(target=thread_fn, args=(i,))
+ self.checkedThread(target=thread_fn, args=(i,))
for i in range(num_threads)
]
diff --git a/tensorflow/tools/api/generator/doc_srcs_test.py b/tensorflow/tools/api/generator/doc_srcs_test.py
index 7b8f27c1b1..dbff904abe 100644
--- a/tensorflow/tools/api/generator/doc_srcs_test.py
+++ b/tensorflow/tools/api/generator/doc_srcs_test.py
@@ -39,27 +39,27 @@ class DocSrcsTest(test.TestCase):
file_path += '/'
file_path += '__init__.py'
- if file_path not in FLAGS.outputs:
- self.assertFalse('%s is not a valid API module' % module_name)
+ self.assertIn(
+ file_path, FLAGS.outputs,
+ msg='%s is not a valid API module' % module_name)
def testHaveDocstringOrDocstringModule(self):
for module_name, docsrc in doc_srcs.get_doc_sources(FLAGS.api_name).items():
- if docsrc.docstring and docsrc.docstring_module_name:
- self.assertFalse(
- '%s contains DocSource has both a docstring and a '
- 'docstring_module_name. '
- 'Only one of "docstring" or "docstring_module_name" should be set.'
- % (module_name))
+ self.assertFalse(
+ docsrc.docstring and docsrc.docstring_module_name,
+ msg=('%s contains DocSource that has both a docstring and a '
+ 'docstring_module_name. Only one of "docstring" or '
+ '"docstring_module_name" should be set.') % (module_name))
def testDocstringModulesAreValidModules(self):
for _, docsrc in doc_srcs.get_doc_sources(FLAGS.api_name).items():
if docsrc.docstring_module_name:
doc_module_name = '.'.join([
FLAGS.package, docsrc.docstring_module_name])
- if doc_module_name not in sys.modules:
- self.assertFalse(
- 'docsources_module %s is not a valid module under %s.' %
- (docsrc.docstring_module_name, FLAGS.package))
+ self.assertIn(
+ doc_module_name, sys.modules,
+ msg=('docsources_module %s is not a valid module under %s.' %
+ (docsrc.docstring_module_name, FLAGS.package)))
if __name__ == '__main__':
diff --git a/tensorflow/tools/docs/BUILD b/tensorflow/tools/docs/BUILD
index eea712c279..2403e2d966 100644
--- a/tensorflow/tools/docs/BUILD
+++ b/tensorflow/tools/docs/BUILD
@@ -39,6 +39,7 @@ py_library(
visibility = ["//visibility:public"],
deps = [
"//tensorflow/python:platform",
+ "//tensorflow/python:util",
"@astor_archive//:astor",
],
)
@@ -95,6 +96,7 @@ py_binary(
deps = [
":generate_lib",
"//tensorflow:tensorflow_py",
+ "//tensorflow/python:util",
"//tensorflow/python/debug:debug_py",
],
)
diff --git a/tensorflow/tools/docs/generate_lib.py b/tensorflow/tools/docs/generate_lib.py
index 67c413cccb..e7634cd5dc 100644
--- a/tensorflow/tools/docs/generate_lib.py
+++ b/tensorflow/tools/docs/generate_lib.py
@@ -388,16 +388,40 @@ def _build_guide_index(guide_src_dir):
class _UpdateTags(py_guide_parser.PyGuideParser):
- """Rewrites a Python guide so that each section has an explicit tag."""
+ """Rewrites a Python guide so that each section has an explicit id tag.
+
+ "section" here refers to blocks delimited by second-level headings.
+ """
def process_section(self, line_number, section_title, tag):
self.replace_line(line_number, '<h2 id="%s">%s</h2>' % (tag, section_title))
+def update_id_tags_inplace(src_dir):
+ """Set explicit ids on all second-level headings to ensure back-links work.
+
+ Args:
+ src_dir: The directory of md-files to convert (inplace).
+ """
+ tag_updater = _UpdateTags()
+
+ for dirpath, _, filenames in os.walk(src_dir):
+ for base_name in filenames:
+ if not base_name.endswith('.md'):
+ continue
+ full_path = os.path.join(src_dir, dirpath, base_name)
+
+ # Tag updater loads the file, makes the replacements, and returns the
+ # modified file contents.
+ content = tag_updater.process(full_path)
+ with open(full_path, 'w') as f:
+ f.write(content)
+
+
EXCLUDED = set(['__init__.py', 'OWNERS', 'README.txt'])
-def _other_docs(src_dir, output_dir, reference_resolver, file_pattern='*.md'):
+def replace_refs(src_dir, output_dir, reference_resolver, file_pattern='*.md'):
"""Fix @{} references in all files under `src_dir` matching `file_pattern`.
A matching directory structure, with the modified files is
@@ -418,7 +442,6 @@ def _other_docs(src_dir, output_dir, reference_resolver, file_pattern='*.md'):
using fnmatch. Non-matching files are copied unchanged.
"""
# Iterate through all the source files and process them.
- tag_updater = _UpdateTags()
for dirpath, _, filenames in os.walk(src_dir):
# How to get from `dirpath` to api_docs/python/
relative_path_to_root = os.path.relpath(
@@ -435,24 +458,25 @@ def _other_docs(src_dir, output_dir, reference_resolver, file_pattern='*.md'):
continue
full_in_path = os.path.join(dirpath, base_name)
+ # Set the `current_doc_full_name` so bad files can be reported on errors.
reference_resolver.current_doc_full_name = full_in_path
suffix = os.path.relpath(path=full_in_path, start=src_dir)
full_out_path = os.path.join(output_dir, suffix)
+ # Copy files that do not match the file_pattern, unmodified.
if not fnmatch.fnmatch(base_name, file_pattern):
shutil.copyfile(full_in_path, full_out_path)
continue
- if dirpath.endswith('/api_guides/python'):
- content = tag_updater.process(full_in_path)
- else:
- with open(full_in_path, 'rb') as f:
- content = f.read().decode('utf-8')
+
+ with open(full_in_path, 'rb') as f:
+ content = f.read().decode('utf-8')
content = reference_resolver.replace_references(content,
relative_path_to_root)
with open(full_out_path, 'wb') as f:
f.write(content.encode('utf-8'))
+
class DocGenerator(object):
"""Main entry point for generating docs."""
@@ -538,15 +562,43 @@ class DocGenerator(object):
self._do_not_descend_map)
def build(self, flags):
- """Actually build the docs."""
+ """Build all the docs.
+
+ This produces two outputs:
+
+ python api docs:
+
+ * generated from modules set with `set_py_modules`.
+ * written to '{FLAGS.output_dir}/api_docs/python/'
+
+ non-api docs:
+
+ * Everything in '{FLAGS.src_dir}' is copied to '{FLAGS.output_dir}'.
+ * '@{}' references in '.md' files are replaced with links.
+ * '.md' files under 'api_guides/python' have explicit ids set for their
+ second level headings.
+
+ Args:
+ flags:
+ * src_dir: Where to fetch the non-api-docs.
+ * base_dir: Base of the docs directory (Used to build correct
+ relative links).
+ * output_dir: Where to write the resulting docs.
+
+ Returns:
+ The number of errors encountered while processing.
+ """
+ # Extract the python api from the _py_modules
doc_index = build_doc_index(flags.src_dir)
visitor = self.run_extraction()
reference_resolver = self.make_reference_resolver(visitor, doc_index)
+ # Build the guide_index for the api_docs back links.
root_title = getattr(flags, 'root_title', 'TensorFlow')
guide_index = _build_guide_index(
os.path.join(flags.src_dir, 'api_guides/python'))
+ # Write the api docs.
parser_config = self.make_parser_config(visitor, reference_resolver,
guide_index, flags.base_dir)
output_dir = os.path.join(flags.output_dir, 'api_docs/python')
@@ -557,8 +609,16 @@ class DocGenerator(object):
yaml_toc=self.yaml_toc,
root_title=root_title,
search_hints=getattr(flags, 'search_hints', True))
- _other_docs(flags.src_dir, flags.output_dir, reference_resolver)
+ # Replace all the @{} references in files under `FLAGS.src_dir`
+ replace_refs(flags.src_dir, flags.output_dir, reference_resolver, '*.md')
+ # Fix the tags in the guide dir.
+ guide_dir = os.path.join(flags.output_dir, 'api_guides/python')
+ if os.path.exists(guide_dir):
+ update_id_tags_inplace(guide_dir)
+
+ # Report all errors found by the reference resolver, and return the error
+ # code.
parser_config.reference_resolver.log_errors()
return parser_config.reference_resolver.num_errors()
diff --git a/tensorflow/tools/docs/generate_lib_test.py b/tensorflow/tools/docs/generate_lib_test.py
index ea6d28a02b..7a6f9fd9f7 100644
--- a/tensorflow/tools/docs/generate_lib_test.py
+++ b/tensorflow/tools/docs/generate_lib_test.py
@@ -51,7 +51,9 @@ class DummyVisitor(object):
class GenerateTest(googletest.TestCase):
- def test_write(self):
+ def get_test_objects(self):
+ # These are all mutable objects, so rebuild them for each test.
+ # Don't cache the objects.
module = sys.modules[__name__]
index = {
@@ -98,6 +100,11 @@ class GenerateTest(googletest.TestCase):
guide_index={},
base_dir=base_dir)
+ return reference_resolver, parser_config
+
+ def test_write(self):
+ _, parser_config = self.get_test_objects()
+
output_dir = googletest.GetTempDir()
generate_lib.write_docs(output_dir, parser_config, yaml_toc=True)
@@ -127,6 +134,107 @@ class GenerateTest(googletest.TestCase):
os.path.exists(
os.path.join(output_dir, 'tf/TestModule/test_function.md')))
+ def test_update_id_tags_inplace(self):
+ test_dir = googletest.GetTempDir()
+ test_sub_dir = os.path.join(test_dir, 'a/b')
+ os.makedirs(test_sub_dir)
+
+ test_path1 = os.path.join(test_dir, 'file1.md')
+ test_path2 = os.path.join(test_sub_dir, 'file2.md')
+ test_path3 = os.path.join(test_sub_dir, 'file3.notmd')
+
+ with open(test_path1, 'w') as f:
+ f.write('## abc&123')
+
+ with open(test_path2, 'w') as f:
+ f.write('# A Level 1 Heading\n')
+ f.write('## A Level 2 Heading')
+
+ with open(test_path3, 'w') as f:
+ f.write("## don\'t change this")
+
+ generate_lib.update_id_tags_inplace(test_dir)
+
+ with open(test_path1) as f:
+ content = f.read()
+
+ self.assertEqual(content, '<h2 id="abc_123">abc&123</h2>')
+
+ with open(test_path2) as f:
+ content = f.read()
+
+ self.assertEqual(
+ content, '# A Level 1 Heading\n'
+ '<h2 id="A_Level_2_Heading">A Level 2 Heading</h2>')
+
+ with open(test_path3) as f:
+ content = f.read()
+
+ self.assertEqual(content, "## don\'t change this")
+
+ def test_replace_refs(self):
+ test_dir = googletest.GetTempDir()
+ test_in_dir = os.path.join(test_dir, 'in')
+ test_in_dir_a = os.path.join(test_dir, 'in/a')
+ test_in_dir_b = os.path.join(test_dir, 'in/b')
+ os.makedirs(test_in_dir)
+ os.makedirs(test_in_dir_a)
+ os.makedirs(test_in_dir_b)
+
+ test_out_dir = os.path.join(test_dir, 'out')
+ os.makedirs(test_out_dir)
+
+ test_path1 = os.path.join(test_in_dir_a, 'file1.md')
+ test_path2 = os.path.join(test_in_dir_b, 'file2.md')
+ test_path3 = os.path.join(test_in_dir_b, 'file3.notmd')
+ test_path4 = os.path.join(test_in_dir_b, 'OWNERS')
+
+ with open(test_path1, 'w') as f:
+ f.write('Use `tf.test_function` to test things.')
+
+ with open(test_path2, 'w') as f:
+ f.write('Use @{tf.TestModule.TestClass.ChildClass} to test things.\n'
+ "`tf.whatever` doesn't exist")
+
+ with open(test_path3, 'w') as f:
+ file3_content = (
+ 'Not a .md file. Should be copied unchanged:'
+ '@{tf.TestModule.TestClass.ChildClass}, `tf.test_function`')
+ f.write(file3_content)
+
+ with open(test_path4, 'w') as f:
+ f.write('')
+
+ reference_resolver, _ = self.get_test_objects()
+ generate_lib.replace_refs(test_in_dir, test_out_dir, reference_resolver,
+ '*.md')
+
+ with open(os.path.join(test_out_dir, 'a/file1.md')) as f:
+ content = f.read()
+ self.assertEqual(
+ content,
+ 'Use <a href="../api_docs/python/tf/TestModule/test_function.md">'
+ '<code>tf.test_function</code></a> to test things.')
+
+ with open(os.path.join(test_out_dir, 'b/file2.md')) as f:
+ content = f.read()
+ self.assertEqual(
+ content,
+ 'Use '
+ '<a href="../api_docs/python/tf/TestModule/TestClass/ChildClass.md">'
+ '<code>tf.TestModule.TestClass.ChildClass</code></a> '
+ 'to test things.\n'
+ '`tf.whatever` doesn\'t exist')
+
+ with open(os.path.join(test_out_dir, 'b/file3.notmd')) as f:
+ content = f.read()
+ self.assertEqual(content, file3_content)
+
+ with self.assertRaises(IOError):
+ # This should fail. The OWNERS file should not be copied.
+ with open(os.path.join(test_out_dir, 'b/OWNERS')) as f:
+ content = f.read()
+
if __name__ == '__main__':
googletest.main()
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index 5cefe37782..7e4676e522 100644
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -452,11 +452,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
tf_http_archive(
name = "llvm",
urls = [
- "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/7f7cea53068238fca7b7e4299793a0c77bea7219.tar.gz",
- "https://github.com/llvm-mirror/llvm/archive/7f7cea53068238fca7b7e4299793a0c77bea7219.tar.gz",
+ "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/8a152c54c401f9a9370bedf05049ac5b847bc965.tar.gz",
+ "https://github.com/llvm-mirror/llvm/archive/8a152c54c401f9a9370bedf05049ac5b847bc965.tar.gz",
],
- sha256 = "b645507080e07c845607f212d45e4ee79253c3c9b762531f51fbaeceb6b47391",
- strip_prefix = "llvm-7f7cea53068238fca7b7e4299793a0c77bea7219",
+ sha256 = "dad37678abffa4f3001b1789a89f64f245bc50721f8d37b4f8b31b0695e90015",
+ strip_prefix = "llvm-8a152c54c401f9a9370bedf05049ac5b847bc965",
build_file = clean_dep("//third_party/llvm:llvm.autogenerated.BUILD"),
)
diff --git a/third_party/toolchains/clang6/CROSSTOOL.tpl b/third_party/toolchains/clang6/CROSSTOOL.tpl
index 6b7e5a8808..ffba9850bb 100644
--- a/third_party/toolchains/clang6/CROSSTOOL.tpl
+++ b/third_party/toolchains/clang6/CROSSTOOL.tpl
@@ -76,9 +76,6 @@ toolchain {
# This adds a little bit more durability to our Clang build.
#
- # At the moment, this only only be needed for:
- # - add_boringssl_s390x.patch: --Wa,--noexecstack
- #
# Folks who do maintenance work on TF Bazel Clang should consider
# commenting out these lines, while doing that work, to gain a better
# understanding of what the intersection of support looks like between GCC